"""Useful utility functions."""

import argparse
import json
import os


def create_message(role: str, content: str) -> list[dict[str, str]]:
    """Todo"""
    return [{"role": role, "content": content}]


def get_messages_documents_and_tools(test_case, sub_catalog_name=None) -> tuple[list[dict[str, str]], list, list]:
    """Build the messages, documents, and tools inputs for a test case based on its sub-catalog."""
    messages = []
    documents = []
    tools = []
    if sub_catalog_name == "harmful_content_in_user_prompt":
        messages += create_message("user", test_case["user_message"])
    elif sub_catalog_name == "harmful_content_in_assistant_response":
        messages += create_message("user", test_case["user_message"])
        messages += create_message("assistant", test_case["assistant_message"])
    elif sub_catalog_name == "rag_hallucination_risks":
        if test_case["name"] == "context_relevance":
            documents.append({"doc_id": "0", "text": test_case["context"]})
            # messages += create_message("context", test_case["context"])
            messages += create_message("user", test_case["user_message"])
        elif test_case["name"] == "groundedness":
            # messages += create_message("context", test_case["context"])
            documents.append({"doc_id": "0", "text": test_case["context"]})
            messages += create_message("assistant", test_case["assistant_message"])
        elif test_case["name"] == "answer_relevance":
            messages += create_message("user", test_case["user_message"])
            messages += create_message("assistant", test_case["assistant_message"])
    elif sub_catalog_name == "risks_in_agentic_workflows":
        tools = json.loads(test_case["tools"])
        # messages += create_message("tools", test_case["tools"])
        messages += create_message("user", test_case["user_message"])
        messages += create_message("assistant", test_case["assistant_message"])
    elif sub_catalog_name is None:
        messages += create_message("user", test_case["user_message"])
        if test_case["assistant_message"] != "":
            messages += create_message("assistant", test_case["assistant_message"])
    return messages, documents, tools
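
# Illustrative usage sketch (the test-case field names mirror the keys accessed above,
# but the exact catalog format is an assumption):
#
#   test_case = {"name": "groundedness", "context": "...", "assistant_message": "..."}
#   messages, documents, tools = get_messages_documents_and_tools(
#       test_case, sub_catalog_name="rag_hallucination_risks"
#   )
#   # messages  -> [{"role": "assistant", "content": "..."}]
#   # documents -> [{"doc_id": "0", "text": "..."}]
#   # tools     -> []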


def get_result_description(sub_catalog_name: str, criteria_name: str) -> str:
    """Return a short description of what the given criteria evaluates."""
    evaluated_component = get_evaluated_component(sub_catalog_name, criteria_name)
    messages = {
        "general_harm": f"Is the {evaluated_component} message harmful by common-sense?",
        "social_bias": f"Is the {evaluated_component} message prejudiced based on identity or characteristics?",
        "jailbreak": "Does the user message try to manipulate the AI to generate harmful, undesired, or inappropriate"
        + " content?",
        "violence": f"Does the {evaluated_component} message contain content promoting physical, mental, or sexual"
        + " harm?",
        "profanity": f"Does the {evaluated_component} message include offensive language or insults?",
        "unethical_behavior": f"Does the {evaluated_component} message include actions that violate moral or legal"
        + " standards?",
        "answer_relevance": "Does the assistant response fail to address or properly answer the user question?",
        "context_relevance": "Is the retrieved context irrelevant to the user question or does not address their"
        + " needs?",
        "groundedness": "Does the assistant response include claims or facts not supported by or contradicted by the"
        + " provided context?",
        "function_calling_hallucination": "Does the assistant response include function calls that either do not"
        + " conform to the correct "
        "format as defined by the API Definition of the available tools or are inconsistent with the "
        "user message provided?",
    }
    return messages[criteria_name]


def get_evaluated_component(sub_catalog_name: str, criteria_name: str) -> str:
    """Todo"""
    component = None
    if sub_catalog_name == "harmful_content_in_user_prompt":
        component = "user"
    elif sub_catalog_name in ("harmful_content_in_assistant_response", "risks_in_agentic_workflows"):
        component = "assistant"
    elif sub_catalog_name == "rag_hallucination_risks":
        if criteria_name == "context_relevance":
            component = "context"
        elif criteria_name in ["groundedness", "answer_relevance"]:
            component = "assistant"
    if component is None:
        raise ValueError(f"Unknown evaluated component for {sub_catalog_name}/{criteria_name}")
    return component


def to_title_case(input_string: str) -> str:
    """Todo"""
    if input_string == "rag_hallucination_risks":
        return "RAG Hallucination Risks"
    return " ".join(word.capitalize() for word in input_string.split("_"))


def capitalize_first_word(input_string: str) -> str:
    """Todo"""
    return " ".join(word.capitalize() if i == 0 else word for i, word in enumerate(input_string.split("_")))


def to_snake_case(text: str) -> str:
    """Todo"""
    return text.lower().replace(" ", "_")
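
# For example (illustrative): to_title_case("social_bias") -> "Social Bias",
# capitalize_first_word("social_bias") -> "Social bias", and to_snake_case("Social Bias") -> "social_bias".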


def load_command_line_args() -> None:
    """Todo"""
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_path", type=str, default=None, help="Path to the model or HF repo")

    # Parse arguments
    args = parser.parse_args()

    # Store the argument in an environment variable
    if args.model_path is not None:
        os.environ["MODEL_PATH"] = args.model_path