Upload folder using huggingface_hub
Browse files- api.py +10 -8
- artifact.py +11 -2
- dataset.py +0 -1
- dataset_utils.py +8 -5
- inference.py +86 -29
- metric.py +0 -1
- metrics.py +76 -32
- operators.py +47 -0
- serializers.py +1 -6
- struct_data_operators.py +21 -1
- tool_calling.py +0 -119
- type_utils.py +15 -3
- types.py +12 -6
- version.py +1 -1
api.py
CHANGED
|
@@ -37,12 +37,11 @@ def short_hex_hash(value, length=8):
|
|
| 37 |
return h[:length]
|
| 38 |
|
| 39 |
|
| 40 |
-
def _get_recipe_from_query(dataset_query: str) -> DatasetRecipe:
|
| 41 |
-
dataset_query = dataset_query.replace("sys_prompt", "instruction")
|
| 42 |
try:
|
| 43 |
-
dataset_stream, _ = fetch_artifact(dataset_query)
|
| 44 |
except:
|
| 45 |
-
dataset_stream = get_dataset_artifact(dataset_query)
|
| 46 |
return dataset_stream
|
| 47 |
|
| 48 |
|
|
@@ -82,14 +81,15 @@ def load_recipe(dataset_query: Optional[str] = None, **kwargs) -> DatasetRecipe:
|
|
| 82 |
if isinstance(dataset_query, (DatasetRecipe, Benchmark)):
|
| 83 |
return dataset_query
|
| 84 |
|
| 85 |
-
_verify_dataset_args(dataset_query, kwargs)
|
| 86 |
-
|
| 87 |
if dataset_query:
|
| 88 |
-
recipe = _get_recipe_from_query(dataset_query)
|
| 89 |
|
| 90 |
-
|
| 91 |
recipe = _get_recipe_from_dict(kwargs)
|
| 92 |
|
|
|
|
|
|
|
|
|
|
| 93 |
return recipe
|
| 94 |
|
| 95 |
|
|
@@ -187,6 +187,8 @@ def load_dataset(
|
|
| 187 |
Alternatively, dataset is loaded from a provided card based on explicitly
|
| 188 |
given parameters.
|
| 189 |
|
|
|
|
|
|
|
| 190 |
Args:
|
| 191 |
dataset_query (str, optional):
|
| 192 |
A string query which specifies a dataset to load from
|
|
|
|
| 37 |
return h[:length]
|
| 38 |
|
| 39 |
|
| 40 |
+
def _get_recipe_from_query(dataset_query: str, overwrite_kwargs: Optional[Dict[str, Any]]=None) -> DatasetRecipe:
|
|
|
|
| 41 |
try:
|
| 42 |
+
dataset_stream, _ = fetch_artifact(dataset_query, overwrite_kwargs=overwrite_kwargs)
|
| 43 |
except:
|
| 44 |
+
dataset_stream = get_dataset_artifact(dataset_query, overwrite_kwargs=overwrite_kwargs)
|
| 45 |
return dataset_stream
|
| 46 |
|
| 47 |
|
|
|
|
| 81 |
if isinstance(dataset_query, (DatasetRecipe, Benchmark)):
|
| 82 |
return dataset_query
|
| 83 |
|
|
|
|
|
|
|
| 84 |
if dataset_query:
|
| 85 |
+
recipe = _get_recipe_from_query(dataset_query, kwargs)
|
| 86 |
|
| 87 |
+
elif kwargs:
|
| 88 |
recipe = _get_recipe_from_dict(kwargs)
|
| 89 |
|
| 90 |
+
else:
|
| 91 |
+
raise UnitxtError("Specify either dataset recipe string artifact name or recipe args.")
|
| 92 |
+
|
| 93 |
return recipe
|
| 94 |
|
| 95 |
|
|
|
|
| 187 |
Alternatively, dataset is loaded from a provided card based on explicitly
|
| 188 |
given parameters.
|
| 189 |
|
| 190 |
+
If both are given, then the textual recipe is loaded with the key word args overriding the textual recipe args.
|
| 191 |
+
|
| 192 |
Args:
|
| 193 |
dataset_query (str, optional):
|
| 194 |
A string query which specifies a dataset to load from
|
artifact.py
CHANGED
|
@@ -22,7 +22,7 @@ from .parsing_utils import (
|
|
| 22 |
separate_inside_and_outside_square_brackets,
|
| 23 |
)
|
| 24 |
from .settings_utils import get_constants, get_settings
|
| 25 |
-
from .text_utils import camel_to_snake_case, is_camel_case
|
| 26 |
from .type_utils import isoftype, issubtype
|
| 27 |
from .utils import (
|
| 28 |
artifacts_json_cache,
|
|
@@ -369,6 +369,10 @@ class Artifact(Dataclass):
|
|
| 369 |
data = self.to_dict()
|
| 370 |
return json_dump(data)
|
| 371 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 372 |
def serialize(self):
|
| 373 |
if self.__id__ is not None:
|
| 374 |
return self.__id__
|
|
@@ -528,7 +532,7 @@ class UnitxtArtifactNotFoundError(UnitxtError):
|
|
| 528 |
super().__init__(msg)
|
| 529 |
|
| 530 |
|
| 531 |
-
def fetch_artifact(artifact_rep) -> Tuple[Artifact, Union[AbstractCatalog, None]]:
|
| 532 |
"""Loads an artifict from one of possible representations.
|
| 533 |
|
| 534 |
(1) If artifact representation is already an Artifact object, return it.
|
|
@@ -553,6 +557,11 @@ def fetch_artifact(artifact_rep) -> Tuple[Artifact, Union[AbstractCatalog, None]
|
|
| 553 |
name, _ = separate_inside_and_outside_square_brackets(artifact_rep)
|
| 554 |
if is_name_legal_for_catalog(name):
|
| 555 |
catalog, artifact_rep, args = get_catalog_name_and_args(name=artifact_rep)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 556 |
artifact_to_return = catalog.get_with_overwrite(
|
| 557 |
artifact_rep, overwrite_args=args
|
| 558 |
)
|
|
|
|
| 22 |
separate_inside_and_outside_square_brackets,
|
| 23 |
)
|
| 24 |
from .settings_utils import get_constants, get_settings
|
| 25 |
+
from .text_utils import camel_to_snake_case, is_camel_case, print_dict_as_yaml
|
| 26 |
from .type_utils import isoftype, issubtype
|
| 27 |
from .utils import (
|
| 28 |
artifacts_json_cache,
|
|
|
|
| 369 |
data = self.to_dict()
|
| 370 |
return json_dump(data)
|
| 371 |
|
| 372 |
+
def to_yaml(self):
|
| 373 |
+
data = self.to_dict()
|
| 374 |
+
return print_dict_as_yaml(data)
|
| 375 |
+
|
| 376 |
def serialize(self):
|
| 377 |
if self.__id__ is not None:
|
| 378 |
return self.__id__
|
|
|
|
| 532 |
super().__init__(msg)
|
| 533 |
|
| 534 |
|
| 535 |
+
def fetch_artifact(artifact_rep, overwrite_kwargs: Optional[Dict[str, Any]]=None) -> Tuple[Artifact, Union[AbstractCatalog, None]]:
|
| 536 |
"""Loads an artifict from one of possible representations.
|
| 537 |
|
| 538 |
(1) If artifact representation is already an Artifact object, return it.
|
|
|
|
| 557 |
name, _ = separate_inside_and_outside_square_brackets(artifact_rep)
|
| 558 |
if is_name_legal_for_catalog(name):
|
| 559 |
catalog, artifact_rep, args = get_catalog_name_and_args(name=artifact_rep)
|
| 560 |
+
if overwrite_kwargs is not None:
|
| 561 |
+
if args is None:
|
| 562 |
+
args = overwrite_kwargs
|
| 563 |
+
else:
|
| 564 |
+
args.update(overwrite_kwargs)
|
| 565 |
artifact_to_return = catalog.get_with_overwrite(
|
| 566 |
artifact_rep, overwrite_args=args
|
| 567 |
)
|
dataset.py
CHANGED
|
@@ -68,7 +68,6 @@ from .system_prompts import __file__ as _
|
|
| 68 |
from .task import __file__ as _
|
| 69 |
from .templates import __file__ as _
|
| 70 |
from .text_utils import __file__ as _
|
| 71 |
-
from .tool_calling import __file__ as _
|
| 72 |
from .type_utils import __file__ as _
|
| 73 |
from .types import __file__ as _
|
| 74 |
from .utils import __file__ as _
|
|
|
|
| 68 |
from .task import __file__ as _
|
| 69 |
from .templates import __file__ as _
|
| 70 |
from .text_utils import __file__ as _
|
|
|
|
| 71 |
from .type_utils import __file__ as _
|
| 72 |
from .types import __file__ as _
|
| 73 |
from .utils import __file__ as _
|
dataset_utils.py
CHANGED
|
@@ -1,4 +1,5 @@
|
|
| 1 |
from json.decoder import JSONDecodeError
|
|
|
|
| 2 |
|
| 3 |
from .artifact import Artifact, UnitxtArtifactNotFoundError, fetch_artifact
|
| 4 |
from .logging_utils import get_logger
|
|
@@ -11,19 +12,19 @@ logger = get_logger()
|
|
| 11 |
settings = get_settings()
|
| 12 |
|
| 13 |
|
| 14 |
-
def fetch(artifact_name):
|
| 15 |
try:
|
| 16 |
-
artifact, _ = fetch_artifact(artifact_name)
|
| 17 |
return artifact
|
| 18 |
except (UnitxtArtifactNotFoundError, JSONDecodeError):
|
| 19 |
return None
|
| 20 |
|
| 21 |
|
| 22 |
-
def parse(query: str):
|
| 23 |
return parse_key_equals_value_string_to_dict(query)
|
| 24 |
|
| 25 |
|
| 26 |
-
def get_dataset_artifact(dataset):
|
| 27 |
if isinstance(dataset, DatasetRecipe):
|
| 28 |
return dataset
|
| 29 |
assert isinstance(
|
|
@@ -31,10 +32,12 @@ def get_dataset_artifact(dataset):
|
|
| 31 |
), "dataset should be string description of recipe, or recipe object."
|
| 32 |
_reset_env_local_catalogs()
|
| 33 |
register_all_artifacts()
|
| 34 |
-
recipe = fetch(dataset)
|
| 35 |
if recipe is None:
|
| 36 |
args = parse(dataset)
|
| 37 |
if "__type__" not in args:
|
| 38 |
args["__type__"] = settings.default_recipe
|
|
|
|
|
|
|
| 39 |
recipe = Artifact.from_dict(args)
|
| 40 |
return recipe
|
|
|
|
| 1 |
from json.decoder import JSONDecodeError
|
| 2 |
+
from typing import Any, Dict, Optional
|
| 3 |
|
| 4 |
from .artifact import Artifact, UnitxtArtifactNotFoundError, fetch_artifact
|
| 5 |
from .logging_utils import get_logger
|
|
|
|
| 12 |
settings = get_settings()
|
| 13 |
|
| 14 |
|
| 15 |
+
def fetch(artifact_name: str, overwrite_kwargs: Optional[Dict[str, Any]]=None):
|
| 16 |
try:
|
| 17 |
+
artifact, _ = fetch_artifact(artifact_name, overwrite_kwargs=overwrite_kwargs)
|
| 18 |
return artifact
|
| 19 |
except (UnitxtArtifactNotFoundError, JSONDecodeError):
|
| 20 |
return None
|
| 21 |
|
| 22 |
|
| 23 |
+
def parse(query: str) -> dict:
|
| 24 |
return parse_key_equals_value_string_to_dict(query)
|
| 25 |
|
| 26 |
|
| 27 |
+
def get_dataset_artifact(dataset, overwrite_kwargs: Optional[Dict[str, Any]]=None):
|
| 28 |
if isinstance(dataset, DatasetRecipe):
|
| 29 |
return dataset
|
| 30 |
assert isinstance(
|
|
|
|
| 32 |
), "dataset should be string description of recipe, or recipe object."
|
| 33 |
_reset_env_local_catalogs()
|
| 34 |
register_all_artifacts()
|
| 35 |
+
recipe = fetch(dataset, overwrite_kwargs=overwrite_kwargs)
|
| 36 |
if recipe is None:
|
| 37 |
args = parse(dataset)
|
| 38 |
if "__type__" not in args:
|
| 39 |
args["__type__"] = settings.default_recipe
|
| 40 |
+
if overwrite_kwargs is not None:
|
| 41 |
+
args.update(overwrite_kwargs)
|
| 42 |
recipe = Artifact.from_dict(args)
|
| 43 |
return recipe
|
inference.py
CHANGED
|
@@ -344,6 +344,8 @@ class InferenceEngine(Artifact):
|
|
| 344 |
|
| 345 |
def to_tools(self, instance):
|
| 346 |
task_data = instance.get("task_data")
|
|
|
|
|
|
|
| 347 |
if isinstance(task_data, str):
|
| 348 |
task_data = json.loads(task_data)
|
| 349 |
if "__tools__" in task_data:
|
|
@@ -445,6 +447,8 @@ class HFInferenceEngineBase(
|
|
| 445 |
model: Any = InternalField(default=None, name="Inference object")
|
| 446 |
processor: Any = InternalField(default=None, name="Input processor (tokenizer)")
|
| 447 |
|
|
|
|
|
|
|
| 448 |
_requirements_list = {
|
| 449 |
"transformers": "Install huggingface package using 'pip install --upgrade transformers",
|
| 450 |
"torch": "Install torch, go on PyTorch website for mode details.",
|
|
@@ -655,8 +659,6 @@ class HFAutoModelInferenceEngine(HFInferenceEngineBase):
|
|
| 655 |
truncation: bool = True
|
| 656 |
padding_side: str = "left" # for decoder only models
|
| 657 |
|
| 658 |
-
chat_kwargs_dict: dict = {}
|
| 659 |
-
|
| 660 |
def _init_processor(self):
|
| 661 |
from transformers import AutoTokenizer
|
| 662 |
|
|
@@ -712,10 +714,9 @@ class HFAutoModelInferenceEngine(HFInferenceEngineBase):
|
|
| 712 |
trust_remote_code=True,
|
| 713 |
**model_args,
|
| 714 |
)
|
| 715 |
-
if self.device_map is None:
|
| 716 |
-
self.model.to(self.device)
|
| 717 |
|
| 718 |
def prepare_inputs(self, data: Iterable) -> Mapping:
|
|
|
|
| 719 |
if isinstance(data[0], list):
|
| 720 |
data = self.processor.apply_chat_template(
|
| 721 |
data,
|
|
@@ -723,6 +724,7 @@ class HFAutoModelInferenceEngine(HFInferenceEngineBase):
|
|
| 723 |
add_generation_prompt=True,
|
| 724 |
**self.chat_kwargs_dict,
|
| 725 |
)
|
|
|
|
| 726 |
|
| 727 |
if self.processor.pad_token is None:
|
| 728 |
self.processor.pad_token_id = self.model.config.eos_token_id[0]
|
|
@@ -733,6 +735,8 @@ class HFAutoModelInferenceEngine(HFInferenceEngineBase):
|
|
| 733 |
padding=self.padding,
|
| 734 |
truncation=self.truncation,
|
| 735 |
padding_side=self.padding_side,
|
|
|
|
|
|
|
| 736 |
).to(self.device or self.device_map)
|
| 737 |
|
| 738 |
def _infer_fn(
|
|
@@ -755,13 +759,14 @@ class HFAutoModelInferenceEngine(HFInferenceEngineBase):
|
|
| 755 |
"""
|
| 756 |
all_final_outputs = [] # List to store results from all batches
|
| 757 |
|
| 758 |
-
for
|
| 759 |
-
|
| 760 |
desc=f"Running inference in batches of {self.batch_size}",
|
|
|
|
| 761 |
):
|
|
|
|
| 762 |
# Get the current batch
|
| 763 |
-
|
| 764 |
-
batch_sources = [instance["source"] for instance in batch_data]
|
| 765 |
|
| 766 |
# --- Process the current batch ---
|
| 767 |
# 1. Tokenize inputs for the batch
|
|
@@ -800,7 +805,7 @@ class HFAutoModelInferenceEngine(HFInferenceEngineBase):
|
|
| 800 |
j
|
| 801 |
], # Output for the j-th item in the batch
|
| 802 |
output_tokens=len(string_tokens_batch[j]),
|
| 803 |
-
inp=
|
| 804 |
inp_tokens=len(tokenized_inputs.encodings[j].tokens)
|
| 805 |
if tokenized_inputs.encodings is not None
|
| 806 |
else None,
|
|
@@ -1840,15 +1845,26 @@ class OpenAiInferenceEngine(
|
|
| 1840 |
@run_with_imap
|
| 1841 |
def _get_chat_completion(self, instance, return_meta_data):
|
| 1842 |
import openai
|
| 1843 |
-
|
| 1844 |
messages = self.to_messages(instance)
|
| 1845 |
try:
|
| 1846 |
response = self.client.chat.completions.create(
|
| 1847 |
messages=messages,
|
|
|
|
| 1848 |
model=self.get_client_model_name(),
|
| 1849 |
**self._get_completion_kwargs(),
|
|
|
|
| 1850 |
)
|
| 1851 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1852 |
return self.get_return_object(prediction, response, return_meta_data)
|
| 1853 |
# catch in case of content_filtering failure
|
| 1854 |
except openai.BadRequestError as e:
|
|
@@ -2742,14 +2758,37 @@ class WMLInferenceEngineChat(WMLInferenceEngineBase, WMLChatParamsMixin):
|
|
| 2742 |
# images as SDK allows sending only one image per message.
|
| 2743 |
return [messages]
|
| 2744 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2745 |
def _handle_async_requests(
|
| 2746 |
self,
|
| 2747 |
-
|
| 2748 |
params: Dict[str, Any],
|
| 2749 |
) -> List[Dict[str, Any]]:
|
| 2750 |
async def handle_async_requests(start_idx, end_idx):
|
| 2751 |
coroutines = [
|
| 2752 |
-
self._model.achat(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2753 |
for idx in range(start_idx, end_idx)
|
| 2754 |
]
|
| 2755 |
batch_results = await asyncio.gather(*coroutines)
|
|
@@ -2758,10 +2797,10 @@ class WMLInferenceEngineChat(WMLInferenceEngineBase, WMLChatParamsMixin):
|
|
| 2758 |
loop = asyncio.get_event_loop()
|
| 2759 |
results = []
|
| 2760 |
|
| 2761 |
-
for batch_idx in range(0, len(
|
| 2762 |
batch_results = loop.run_until_complete(
|
| 2763 |
handle_async_requests(
|
| 2764 |
-
batch_idx, min(batch_idx + self.concurrency_limit, len(
|
| 2765 |
)
|
| 2766 |
)
|
| 2767 |
results.extend(batch_results)
|
|
@@ -2783,25 +2822,43 @@ class WMLInferenceEngineChat(WMLInferenceEngineBase, WMLChatParamsMixin):
|
|
| 2783 |
output_type = "message"
|
| 2784 |
params["logprobs"] = False
|
| 2785 |
|
| 2786 |
-
|
| 2787 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2788 |
for i in range(len(dataset))
|
| 2789 |
for message in self.to_messages(dataset[i])
|
| 2790 |
]
|
| 2791 |
|
| 2792 |
-
|
| 2793 |
-
[msg[1] for msg in indexed_messages], params
|
| 2794 |
-
)
|
| 2795 |
|
| 2796 |
-
|
| 2797 |
-
|
| 2798 |
-
|
| 2799 |
-
|
| 2800 |
-
|
| 2801 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2802 |
)
|
| 2803 |
-
|
| 2804 |
-
|
| 2805 |
|
| 2806 |
def get_return_object(self, predict_result, result, input_text, return_meta_data):
|
| 2807 |
if return_meta_data:
|
|
@@ -3439,7 +3496,7 @@ class CrossProviderInferenceEngine(InferenceEngine, StandardAPIParamsMixin):
|
|
| 3439 |
"aws": LiteLLMInferenceEngine,
|
| 3440 |
"ollama": OllamaInferenceEngine,
|
| 3441 |
"bam": IbmGenAiInferenceEngine,
|
| 3442 |
-
"watsonx-sdk":
|
| 3443 |
"rits": RITSInferenceEngine,
|
| 3444 |
"azure": LiteLLMInferenceEngine,
|
| 3445 |
"vertex-ai": LiteLLMInferenceEngine,
|
|
|
|
| 344 |
|
| 345 |
def to_tools(self, instance):
|
| 346 |
task_data = instance.get("task_data")
|
| 347 |
+
if task_data is None:
|
| 348 |
+
return None
|
| 349 |
if isinstance(task_data, str):
|
| 350 |
task_data = json.loads(task_data)
|
| 351 |
if "__tools__" in task_data:
|
|
|
|
| 447 |
model: Any = InternalField(default=None, name="Inference object")
|
| 448 |
processor: Any = InternalField(default=None, name="Input processor (tokenizer)")
|
| 449 |
|
| 450 |
+
chat_kwargs_dict: dict = {}
|
| 451 |
+
|
| 452 |
_requirements_list = {
|
| 453 |
"transformers": "Install huggingface package using 'pip install --upgrade transformers",
|
| 454 |
"torch": "Install torch, go on PyTorch website for mode details.",
|
|
|
|
| 659 |
truncation: bool = True
|
| 660 |
padding_side: str = "left" # for decoder only models
|
| 661 |
|
|
|
|
|
|
|
| 662 |
def _init_processor(self):
|
| 663 |
from transformers import AutoTokenizer
|
| 664 |
|
|
|
|
| 714 |
trust_remote_code=True,
|
| 715 |
**model_args,
|
| 716 |
)
|
|
|
|
|
|
|
| 717 |
|
| 718 |
def prepare_inputs(self, data: Iterable) -> Mapping:
|
| 719 |
+
tokenizer_kargs = {}
|
| 720 |
if isinstance(data[0], list):
|
| 721 |
data = self.processor.apply_chat_template(
|
| 722 |
data,
|
|
|
|
| 724 |
add_generation_prompt=True,
|
| 725 |
**self.chat_kwargs_dict,
|
| 726 |
)
|
| 727 |
+
tokenizer_kargs["add_special_tokens"] = False
|
| 728 |
|
| 729 |
if self.processor.pad_token is None:
|
| 730 |
self.processor.pad_token_id = self.model.config.eos_token_id[0]
|
|
|
|
| 735 |
padding=self.padding,
|
| 736 |
truncation=self.truncation,
|
| 737 |
padding_side=self.padding_side,
|
| 738 |
+
**tokenizer_kargs
|
| 739 |
+
|
| 740 |
).to(self.device or self.device_map)
|
| 741 |
|
| 742 |
def _infer_fn(
|
|
|
|
| 759 |
"""
|
| 760 |
all_final_outputs = [] # List to store results from all batches
|
| 761 |
|
| 762 |
+
for batch in tqdm(
|
| 763 |
+
batched(dataset, self.batch_size),
|
| 764 |
desc=f"Running inference in batches of {self.batch_size}",
|
| 765 |
+
total=len(dataset) // self.batch_size,
|
| 766 |
):
|
| 767 |
+
|
| 768 |
# Get the current batch
|
| 769 |
+
batch_sources = [instance["source"] for instance in batch]
|
|
|
|
| 770 |
|
| 771 |
# --- Process the current batch ---
|
| 772 |
# 1. Tokenize inputs for the batch
|
|
|
|
| 805 |
j
|
| 806 |
], # Output for the j-th item in the batch
|
| 807 |
output_tokens=len(string_tokens_batch[j]),
|
| 808 |
+
inp=batch[j]["source"], # Original input for the j-th item
|
| 809 |
inp_tokens=len(tokenized_inputs.encodings[j].tokens)
|
| 810 |
if tokenized_inputs.encodings is not None
|
| 811 |
else None,
|
|
|
|
| 1845 |
@run_with_imap
|
| 1846 |
def _get_chat_completion(self, instance, return_meta_data):
|
| 1847 |
import openai
|
| 1848 |
+
tools = self.to_tools(instance)
|
| 1849 |
messages = self.to_messages(instance)
|
| 1850 |
try:
|
| 1851 |
response = self.client.chat.completions.create(
|
| 1852 |
messages=messages,
|
| 1853 |
+
tools=tools,
|
| 1854 |
model=self.get_client_model_name(),
|
| 1855 |
**self._get_completion_kwargs(),
|
| 1856 |
+
# tool_choice="auto"
|
| 1857 |
)
|
| 1858 |
+
|
| 1859 |
+
if tools is None:
|
| 1860 |
+
prediction = response.choices[0].message.content
|
| 1861 |
+
else:
|
| 1862 |
+
try:
|
| 1863 |
+
func_call = response.choices[0].message.tool_calls[0].function
|
| 1864 |
+
prediction = f'{{"name": "{func_call.name}", "arguments": {func_call.arguments}}}'
|
| 1865 |
+
except:
|
| 1866 |
+
prediction = response.choices[0].message.content or ""
|
| 1867 |
+
|
| 1868 |
return self.get_return_object(prediction, response, return_meta_data)
|
| 1869 |
# catch in case of content_filtering failure
|
| 1870 |
except openai.BadRequestError as e:
|
|
|
|
| 2758 |
# images as SDK allows sending only one image per message.
|
| 2759 |
return [messages]
|
| 2760 |
|
| 2761 |
+
def to_tools(
|
| 2762 |
+
self,
|
| 2763 |
+
instance: Dict[str, Any]
|
| 2764 |
+
) -> Dict[str, Union[Optional[List[Dict[str, str]]], Optional[Dict[str, str]]]]:
|
| 2765 |
+
"""watsonx.ai chat also allows specifying which tools models must use."""
|
| 2766 |
+
task_data = instance.get("task_data")
|
| 2767 |
+
if task_data is None:
|
| 2768 |
+
return {"tools": None, "tool_choice": None}
|
| 2769 |
+
|
| 2770 |
+
if isinstance(task_data, str):
|
| 2771 |
+
task_data = json.loads(task_data)
|
| 2772 |
+
if "__tools__" in task_data:
|
| 2773 |
+
tools: List[Dict[str, str]] = task_data["__tools__"]
|
| 2774 |
+
tool_choice: Optional[Dict[str, str]] = task_data.get("__tool_choice__")
|
| 2775 |
+
return {"tools": tools, "tool_choice": tool_choice}
|
| 2776 |
+
|
| 2777 |
+
return {"tools": None, "tool_choice": None}
|
| 2778 |
+
|
| 2779 |
def _handle_async_requests(
|
| 2780 |
self,
|
| 2781 |
+
data: List[Dict[str, Any]],
|
| 2782 |
params: Dict[str, Any],
|
| 2783 |
) -> List[Dict[str, Any]]:
|
| 2784 |
async def handle_async_requests(start_idx, end_idx):
|
| 2785 |
coroutines = [
|
| 2786 |
+
self._model.achat(
|
| 2787 |
+
messages=data[idx]["msg"],
|
| 2788 |
+
params=params,
|
| 2789 |
+
tools=data[idx]["tools"]["tools"],
|
| 2790 |
+
tool_choice=data[idx]["tools"]["tool_choice"],
|
| 2791 |
+
)
|
| 2792 |
for idx in range(start_idx, end_idx)
|
| 2793 |
]
|
| 2794 |
batch_results = await asyncio.gather(*coroutines)
|
|
|
|
| 2797 |
loop = asyncio.get_event_loop()
|
| 2798 |
results = []
|
| 2799 |
|
| 2800 |
+
for batch_idx in range(0, len(data), self.concurrency_limit):
|
| 2801 |
batch_results = loop.run_until_complete(
|
| 2802 |
handle_async_requests(
|
| 2803 |
+
batch_idx, min(batch_idx + self.concurrency_limit, len(data))
|
| 2804 |
)
|
| 2805 |
)
|
| 2806 |
results.extend(batch_results)
|
|
|
|
| 2822 |
output_type = "message"
|
| 2823 |
params["logprobs"] = False
|
| 2824 |
|
| 2825 |
+
data = [
|
| 2826 |
+
{
|
| 2827 |
+
"idx": i,
|
| 2828 |
+
"msg": message,
|
| 2829 |
+
"tools": self.to_tools(dataset[i]),
|
| 2830 |
+
}
|
| 2831 |
for i in range(len(dataset))
|
| 2832 |
for message in self.to_messages(dataset[i])
|
| 2833 |
]
|
| 2834 |
|
| 2835 |
+
responses = self._handle_async_requests(data, params)
|
|
|
|
|
|
|
| 2836 |
|
| 2837 |
+
results = []
|
| 2838 |
+
for inp, response in zip(data, responses):
|
| 2839 |
+
idx = inp["idx"]
|
| 2840 |
+
tool_call = data[idx]["tools"]["tools"] is not None
|
| 2841 |
+
|
| 2842 |
+
output = response["choices"][0][output_type]
|
| 2843 |
+
if tool_call:
|
| 2844 |
+
if "tool_calls" in output:
|
| 2845 |
+
func = output["tool_calls"][0]["function"]
|
| 2846 |
+
prediction = f'{{"name": "{func["name"]}", "arguments": {func["arguments"]}}}'
|
| 2847 |
+
else:
|
| 2848 |
+
prediction = output["content"]
|
| 2849 |
+
else:
|
| 2850 |
+
prediction = output["content"]
|
| 2851 |
+
|
| 2852 |
+
results.append(
|
| 2853 |
+
self.get_return_object(
|
| 2854 |
+
prediction,
|
| 2855 |
+
response,
|
| 2856 |
+
str(inp),
|
| 2857 |
+
return_meta_data,
|
| 2858 |
+
)
|
| 2859 |
)
|
| 2860 |
+
|
| 2861 |
+
return results
|
| 2862 |
|
| 2863 |
def get_return_object(self, predict_result, result, input_text, return_meta_data):
|
| 2864 |
if return_meta_data:
|
|
|
|
| 3496 |
"aws": LiteLLMInferenceEngine,
|
| 3497 |
"ollama": OllamaInferenceEngine,
|
| 3498 |
"bam": IbmGenAiInferenceEngine,
|
| 3499 |
+
"watsonx-sdk": WMLInferenceEngineChat,
|
| 3500 |
"rits": RITSInferenceEngine,
|
| 3501 |
"azure": LiteLLMInferenceEngine,
|
| 3502 |
"vertex-ai": LiteLLMInferenceEngine,
|
metric.py
CHANGED
|
@@ -65,7 +65,6 @@ from .system_prompts import __file__ as _
|
|
| 65 |
from .task import __file__ as _
|
| 66 |
from .templates import __file__ as _
|
| 67 |
from .text_utils import __file__ as _
|
| 68 |
-
from .tool_calling import __file__ as _
|
| 69 |
from .type_utils import __file__ as _
|
| 70 |
from .types import __file__ as _
|
| 71 |
from .utils import __file__ as _
|
|
|
|
| 65 |
from .task import __file__ as _
|
| 66 |
from .templates import __file__ as _
|
| 67 |
from .text_utils import __file__ as _
|
|
|
|
| 68 |
from .type_utils import __file__ as _
|
| 69 |
from .types import __file__ as _
|
| 70 |
from .utils import __file__ as _
|
metrics.py
CHANGED
|
@@ -63,7 +63,6 @@ from .operators import ArtifactFetcherMixin, Copy, Set
|
|
| 63 |
from .random_utils import get_seed
|
| 64 |
from .settings_utils import get_settings
|
| 65 |
from .stream import MultiStream, Stream
|
| 66 |
-
from .tool_calling import convert_chat_api_format_to_tool
|
| 67 |
from .type_utils import Type, isoftype, parse_type_string, to_type_string
|
| 68 |
from .types import ToolCall
|
| 69 |
from .utils import deep_copy, recursive_copy, retry_connection_with_exponential_backoff
|
|
@@ -789,74 +788,92 @@ class F1Fast(MapReduceMetric[str, Tuple[int, int]]):
|
|
| 789 |
return result
|
| 790 |
|
| 791 |
class ToolCallingMetric(ReductionInstanceMetric[str, Dict[str, float]]):
|
|
|
|
| 792 |
main_score = "exact_match"
|
| 793 |
reduction = MeanReduction()
|
| 794 |
prediction_type = ToolCall
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 795 |
|
| 796 |
def map(
|
| 797 |
self, prediction: ToolCall, references: List[ToolCall], task_data: Dict[str, Any]
|
| 798 |
) -> Dict[str, float]:
|
| 799 |
|
| 800 |
-
|
| 801 |
exact_match = float(
|
| 802 |
-
|
| 803 |
)
|
| 804 |
|
| 805 |
-
|
| 806 |
str(prediction["name"]) in [str(reference["name"]) for reference in references]
|
| 807 |
)
|
| 808 |
|
| 809 |
-
|
| 810 |
for reference in references:
|
| 811 |
-
if len(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 812 |
|
|
|
|
|
|
|
|
|
|
| 813 |
score = len(set(prediction["arguments"]).intersection(set(reference["arguments"]))) / len(set(prediction["arguments"]))
|
| 814 |
-
|
| 815 |
score = 1.0
|
| 816 |
-
|
| 817 |
-
|
|
|
|
|
|
|
|
|
|
| 818 |
|
|
|
|
| 819 |
|
| 820 |
-
parameter_values = 0.0
|
| 821 |
for reference in references:
|
| 822 |
value_matches = 0
|
|
|
|
| 823 |
for key, val in prediction["arguments"].items():
|
| 824 |
try:
|
| 825 |
-
|
|
|
|
|
|
|
| 826 |
value_matches += 1
|
| 827 |
except:
|
| 828 |
pass
|
| 829 |
|
| 830 |
if len(prediction["arguments"]) > 0:
|
| 831 |
-
|
| 832 |
score = value_matches / len(prediction["arguments"])
|
| 833 |
else:
|
| 834 |
score = 1.0
|
| 835 |
-
if score >
|
| 836 |
-
|
| 837 |
|
|
|
|
| 838 |
for tool in task_data["__tools__"]:
|
| 839 |
-
tool
|
| 840 |
-
|
| 841 |
-
for param in tool["parameters"]:
|
| 842 |
-
tool_params_types[param["name"]] = param["type"]
|
| 843 |
-
correct_parameters_types = 0
|
| 844 |
-
for key, value in prediction["arguments"].items():
|
| 845 |
-
typing_type = tool_params_types.get(key, Any)
|
| 846 |
-
if isoftype(value, typing_type):
|
| 847 |
-
correct_parameters_types += 1
|
| 848 |
-
if len(prediction["arguments"]) > 0:
|
| 849 |
-
parameters_types = correct_parameters_types / len(prediction["arguments"])
|
| 850 |
-
else:
|
| 851 |
-
parameters_types = 1.0
|
| 852 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 853 |
|
| 854 |
return {
|
| 855 |
self.main_score: exact_match,
|
| 856 |
-
"
|
| 857 |
-
"
|
| 858 |
-
"
|
| 859 |
-
"
|
|
|
|
| 860 |
}
|
| 861 |
|
| 862 |
|
|
@@ -3499,7 +3516,7 @@ class CustomF1(GlobalMetric):
|
|
| 3499 |
class KeyValueExtraction(GlobalMetric):
|
| 3500 |
prediction_type = Dict[str, str]
|
| 3501 |
metric: Metric
|
| 3502 |
-
single_reference_per_prediction =
|
| 3503 |
main_score = ""
|
| 3504 |
|
| 3505 |
def prepare(self):
|
|
@@ -3575,6 +3592,33 @@ class KeyValueExtraction(GlobalMetric):
|
|
| 3575 |
|
| 3576 |
return result
|
| 3577 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3578 |
|
| 3579 |
class NER(CustomF1):
|
| 3580 |
"""F1 Metrics that receives as input a list of (Entity,EntityType) pairs."""
|
|
|
|
| 63 |
from .random_utils import get_seed
|
| 64 |
from .settings_utils import get_settings
|
| 65 |
from .stream import MultiStream, Stream
|
|
|
|
| 66 |
from .type_utils import Type, isoftype, parse_type_string, to_type_string
|
| 67 |
from .types import ToolCall
|
| 68 |
from .utils import deep_copy, recursive_copy, retry_connection_with_exponential_backoff
|
|
|
|
| 788 |
return result
|
| 789 |
|
| 790 |
class ToolCallingMetric(ReductionInstanceMetric[str, Dict[str, float]]):
|
| 791 |
+
"""Compares each predicted tool call with list of references tool call."""
|
| 792 |
main_score = "exact_match"
|
| 793 |
reduction = MeanReduction()
|
| 794 |
prediction_type = ToolCall
|
| 795 |
+
_requirements_list = ["jsonschema-rs"]
|
| 796 |
+
|
| 797 |
+
def prepare(self):
|
| 798 |
+
super().prepare()
|
| 799 |
+
import jsonschema_rs
|
| 800 |
+
self._schema = jsonschema_rs
|
| 801 |
|
| 802 |
def map(
|
| 803 |
self, prediction: ToolCall, references: List[ToolCall], task_data: Dict[str, Any]
|
| 804 |
) -> Dict[str, float]:
|
| 805 |
|
|
|
|
| 806 |
exact_match = float(
|
| 807 |
+
json.dumps(prediction, sort_keys=True) in [json.dumps(reference, sort_keys=True) for reference in references]
|
| 808 |
)
|
| 809 |
|
| 810 |
+
tool_name_accuracy = float(
|
| 811 |
str(prediction["name"]) in [str(reference["name"]) for reference in references]
|
| 812 |
)
|
| 813 |
|
| 814 |
+
argument_name_recall = 0.0
|
| 815 |
for reference in references:
|
| 816 |
+
if len(reference["arguments"]) > 0:
|
| 817 |
+
score = len(set(prediction["arguments"]).intersection(set(reference["arguments"]))) / len(set(reference["arguments"]))
|
| 818 |
+
else:
|
| 819 |
+
score = 1.0
|
| 820 |
+
if score > argument_name_recall:
|
| 821 |
+
argument_name_recall = score
|
| 822 |
|
| 823 |
+
argument_name_precision = 0.0
|
| 824 |
+
for reference in references:
|
| 825 |
+
if len(prediction["arguments"]) > 0:
|
| 826 |
score = len(set(prediction["arguments"]).intersection(set(reference["arguments"]))) / len(set(prediction["arguments"]))
|
| 827 |
+
elif len(reference["arguments"]) == 0:
|
| 828 |
score = 1.0
|
| 829 |
+
else:
|
| 830 |
+
score = 0.0
|
| 831 |
+
if score > argument_name_precision:
|
| 832 |
+
argument_name_precision = score
|
| 833 |
+
|
| 834 |
|
| 835 |
+
argument_value_precision = 0.0
|
| 836 |
|
|
|
|
| 837 |
for reference in references:
|
| 838 |
value_matches = 0
|
| 839 |
+
|
| 840 |
for key, val in prediction["arguments"].items():
|
| 841 |
try:
|
| 842 |
+
predicted = json.dumps(val, sort_keys=True)
|
| 843 |
+
target = json.dumps(reference["arguments"][key], sort_keys=True)
|
| 844 |
+
if predicted == target:
|
| 845 |
value_matches += 1
|
| 846 |
except:
|
| 847 |
pass
|
| 848 |
|
| 849 |
if len(prediction["arguments"]) > 0:
|
|
|
|
| 850 |
score = value_matches / len(prediction["arguments"])
|
| 851 |
else:
|
| 852 |
score = 1.0
|
| 853 |
+
if score > argument_value_precision:
|
| 854 |
+
argument_value_precision = score
|
| 855 |
|
| 856 |
+
parameters = None
|
| 857 |
for tool in task_data["__tools__"]:
|
| 858 |
+
if tool["function"]["name"] == prediction["name"]:
|
| 859 |
+
parameters = tool["function"]["parameters"]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 860 |
|
| 861 |
+
if parameters is None:
|
| 862 |
+
argument_schema_validation = 0.0
|
| 863 |
+
else:
|
| 864 |
+
try:
|
| 865 |
+
self._schema.validate(parameters, prediction["arguments"], )
|
| 866 |
+
argument_schema_validation = 1.0
|
| 867 |
+
except self._schema.ValidationError:
|
| 868 |
+
argument_schema_validation = 0.0
|
| 869 |
|
| 870 |
return {
|
| 871 |
self.main_score: exact_match,
|
| 872 |
+
"tool_name_accuracy": tool_name_accuracy,
|
| 873 |
+
"argument_name_recall": argument_name_recall,
|
| 874 |
+
"argument_name_precision": argument_name_precision,
|
| 875 |
+
"argument_value_precision": argument_value_precision,
|
| 876 |
+
"argument_schema_validation": argument_schema_validation,
|
| 877 |
}
|
| 878 |
|
| 879 |
|
|
|
|
| 3516 |
class KeyValueExtraction(GlobalMetric):
|
| 3517 |
prediction_type = Dict[str, str]
|
| 3518 |
metric: Metric
|
| 3519 |
+
single_reference_per_prediction = False
|
| 3520 |
main_score = ""
|
| 3521 |
|
| 3522 |
def prepare(self):
|
|
|
|
| 3592 |
|
| 3593 |
return result
|
| 3594 |
|
| 3595 |
+
class ToolCallKeyValueExtraction(KeyValueExtraction):
|
| 3596 |
+
prediction_type = ToolCall
|
| 3597 |
+
|
| 3598 |
+
def flatten_dict(self,nested_dict, parent_key="", sep="."):
|
| 3599 |
+
flat_dict = {}
|
| 3600 |
+
for k, v in nested_dict.items():
|
| 3601 |
+
new_key = f"{parent_key}{sep}{k}" if parent_key else k
|
| 3602 |
+
if isinstance(v, list):
|
| 3603 |
+
for e in v:
|
| 3604 |
+
if isinstance(e,dict):
|
| 3605 |
+
flat_dict.update(self.flatten_dict(e, new_key, sep=sep))
|
| 3606 |
+
elif isinstance(v, dict):
|
| 3607 |
+
flat_dict.update(self.flatten_dict(v, new_key, sep=sep))
|
| 3608 |
+
else:
|
| 3609 |
+
flat_dict[new_key] = v
|
| 3610 |
+
return flat_dict
|
| 3611 |
+
|
| 3612 |
+
def compute(
|
| 3613 |
+
self,
|
| 3614 |
+
references: List[List[ToolCall]],
|
| 3615 |
+
predictions: List[ToolCall],
|
| 3616 |
+
task_data: List[Dict],
|
| 3617 |
+
) -> dict:
|
| 3618 |
+
return super().compute([[ self.flatten_dict(r) for r in ref ] for ref in references],
|
| 3619 |
+
[ self.flatten_dict(p) for p in predictions],task_data)
|
| 3620 |
+
|
| 3621 |
+
|
| 3622 |
|
| 3623 |
class NER(CustomF1):
|
| 3624 |
"""F1 Metrics that receives as input a list of (Entity,EntityType) pairs."""
|
operators.py
CHANGED
|
@@ -283,6 +283,53 @@ class Set(InstanceOperator):
|
|
| 283 |
dict_set(instance, key, value)
|
| 284 |
return instance
|
| 285 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 286 |
|
| 287 |
@deprecation(version="2.0.0", alternative=Set)
|
| 288 |
class AddFields(Set):
|
|
|
|
| 283 |
dict_set(instance, key, value)
|
| 284 |
return instance
|
| 285 |
|
| 286 |
+
def recursive_key_value_replace(data, target_key, value_map, value_remove=None):
|
| 287 |
+
"""Recursively traverses a data structure (dicts and lists), replaces values of target_key using value_map, and removes values listed in value_remove.
|
| 288 |
+
|
| 289 |
+
Args:
|
| 290 |
+
data: The data structure (dict or list) to traverse.
|
| 291 |
+
target_key: The specific key whose value needs to be checked and replaced or removed.
|
| 292 |
+
value_map: A dictionary mapping old values to new values.
|
| 293 |
+
value_remove: A list of values to completely remove if found as values of target_key.
|
| 294 |
+
|
| 295 |
+
Returns:
|
| 296 |
+
The modified data structure. Modification is done in-place.
|
| 297 |
+
"""
|
| 298 |
+
if value_remove is None:
|
| 299 |
+
value_remove = []
|
| 300 |
+
|
| 301 |
+
if isinstance(data, dict):
|
| 302 |
+
keys_to_delete = []
|
| 303 |
+
for key, value in data.items():
|
| 304 |
+
if key == target_key:
|
| 305 |
+
if isinstance(value, list):
|
| 306 |
+
data[key] = [
|
| 307 |
+
value_map.get(item, item)
|
| 308 |
+
for item in value
|
| 309 |
+
if not isinstance(item, dict) and item not in value_remove
|
| 310 |
+
]
|
| 311 |
+
elif isinstance(value, dict):
|
| 312 |
+
pass # Skip or handle dict values if needed
|
| 313 |
+
elif value in value_remove:
|
| 314 |
+
keys_to_delete.append(key)
|
| 315 |
+
elif value in value_map:
|
| 316 |
+
data[key] = value_map[value]
|
| 317 |
+
else:
|
| 318 |
+
recursive_key_value_replace(value, target_key, value_map, value_remove)
|
| 319 |
+
for key in keys_to_delete:
|
| 320 |
+
del data[key]
|
| 321 |
+
elif isinstance(data, list):
|
| 322 |
+
for item in data:
|
| 323 |
+
recursive_key_value_replace(item, target_key, value_map, value_remove)
|
| 324 |
+
return data
|
| 325 |
+
|
| 326 |
+
class RecursiveReplace(InstanceOperator):
|
| 327 |
+
key: str
|
| 328 |
+
map_values: dict
|
| 329 |
+
remove_values: Optional[list] = None
|
| 330 |
+
|
| 331 |
+
def process(self, instance: Dict[str, Any], stream_name: Optional[str] = None) -> Dict[str, Any]:
|
| 332 |
+
return recursive_key_value_replace(instance, self.key, self.map_values, self.remove_values)
|
| 333 |
|
| 334 |
@deprecation(version="2.0.0", alternative=Set)
|
| 335 |
class AddFields(Set):
|
serializers.py
CHANGED
|
@@ -7,7 +7,6 @@ from typing import Any, Dict, List, Union
|
|
| 7 |
from .dataclass import AbstractField, Field
|
| 8 |
from .operators import InstanceFieldOperator
|
| 9 |
from .settings_utils import get_constants
|
| 10 |
-
from .tool_calling import convert_to_chat_api_format
|
| 11 |
from .type_utils import isoftype, to_type_string
|
| 12 |
from .types import (
|
| 13 |
Dialog,
|
|
@@ -168,24 +167,20 @@ class MultiDocumentSerializer(DocumentSerializer):
|
|
| 168 |
class ToolsSerializer(SingleTypeSerializer):
|
| 169 |
|
| 170 |
serialized_type = List[Tool]
|
| 171 |
-
_requirements_list: List[str] = ["pydantic"]
|
| 172 |
|
| 173 |
def serialize(self, value: List[Tool], instance: Dict[str, Any]) -> str:
|
| 174 |
if "__tools__" not in instance:
|
| 175 |
instance["__tools__"] = []
|
| 176 |
tool = []
|
| 177 |
for tool in value:
|
| 178 |
-
chat_api_tool = convert_to_chat_api_format(tool=tool)
|
| 179 |
instance["__tools__"].append(
|
| 180 |
-
|
| 181 |
)
|
| 182 |
-
tool["parameters"] = chat_api_tool["function"]["parameters"]
|
| 183 |
return json.dumps(instance["__tools__"], indent=4)
|
| 184 |
|
| 185 |
class ToolCallSerializer(SingleTypeSerializer):
|
| 186 |
|
| 187 |
serialized_type = ToolCall
|
| 188 |
-
_requirements_list: List[str] = ["pydantic"]
|
| 189 |
|
| 190 |
def serialize(self, value: ToolCall, instance: Dict[str, Any]) -> str:
|
| 191 |
return json.dumps(value)
|
|
|
|
| 7 |
from .dataclass import AbstractField, Field
|
| 8 |
from .operators import InstanceFieldOperator
|
| 9 |
from .settings_utils import get_constants
|
|
|
|
| 10 |
from .type_utils import isoftype, to_type_string
|
| 11 |
from .types import (
|
| 12 |
Dialog,
|
|
|
|
| 167 |
class ToolsSerializer(SingleTypeSerializer):
|
| 168 |
|
| 169 |
serialized_type = List[Tool]
|
|
|
|
| 170 |
|
| 171 |
def serialize(self, value: List[Tool], instance: Dict[str, Any]) -> str:
|
| 172 |
if "__tools__" not in instance:
|
| 173 |
instance["__tools__"] = []
|
| 174 |
tool = []
|
| 175 |
for tool in value:
|
|
|
|
| 176 |
instance["__tools__"].append(
|
| 177 |
+
{"type": "function", "function": tool}
|
| 178 |
)
|
|
|
|
| 179 |
return json.dumps(instance["__tools__"], indent=4)
|
| 180 |
|
| 181 |
class ToolCallSerializer(SingleTypeSerializer):
|
| 182 |
|
| 183 |
serialized_type = ToolCall
|
|
|
|
| 184 |
|
| 185 |
def serialize(self, value: ToolCall, instance: Dict[str, Any]) -> str:
|
| 186 |
return json.dumps(value)
|
struct_data_operators.py
CHANGED
|
@@ -43,7 +43,7 @@ from .operators import FieldOperator, InstanceOperator
|
|
| 43 |
from .random_utils import new_random_generator
|
| 44 |
from .serializers import ImageSerializer, TableSerializer
|
| 45 |
from .type_utils import isoftype
|
| 46 |
-
from .types import Table
|
| 47 |
from .utils import recursive_copy
|
| 48 |
|
| 49 |
|
|
@@ -754,6 +754,26 @@ class LoadJson(FieldOperator):
|
|
| 754 |
return json.loads(value, strict=False)
|
| 755 |
|
| 756 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 757 |
class DumpJson(FieldOperator):
|
| 758 |
def process_value(self, value: str) -> str:
|
| 759 |
return json.dumps(value)
|
|
|
|
| 43 |
from .random_utils import new_random_generator
|
| 44 |
from .serializers import ImageSerializer, TableSerializer
|
| 45 |
from .type_utils import isoftype
|
| 46 |
+
from .types import Table, ToolCall
|
| 47 |
from .utils import recursive_copy
|
| 48 |
|
| 49 |
|
|
|
|
| 754 |
return json.loads(value, strict=False)
|
| 755 |
|
| 756 |
|
| 757 |
+
class ToolCallPostProcessor(FieldOperator):
|
| 758 |
+
failure_value: Any = None
|
| 759 |
+
allow_failure: bool = False
|
| 760 |
+
def process_value(self, value: str) -> ToolCall:
|
| 761 |
+
if self.allow_failure:
|
| 762 |
+
try:
|
| 763 |
+
result = json.loads(value)
|
| 764 |
+
except json.JSONDecodeError:
|
| 765 |
+
return self.failure_value
|
| 766 |
+
else:
|
| 767 |
+
result = json.loads(value, strict=False)
|
| 768 |
+
if isoftype(result, List[ToolCall]):
|
| 769 |
+
if len(result) > 1:
|
| 770 |
+
UnitxtWarning(f"More than one tool returned from model: {result}" )
|
| 771 |
+
return self.failure_value
|
| 772 |
+
return result[0]
|
| 773 |
+
if not isoftype(result, ToolCall):
|
| 774 |
+
return self.failure_value
|
| 775 |
+
return result
|
| 776 |
+
|
| 777 |
class DumpJson(FieldOperator):
|
| 778 |
def process_value(self, value: str) -> str:
|
| 779 |
return json.dumps(value)
|
tool_calling.py
DELETED
|
@@ -1,119 +0,0 @@
|
|
| 1 |
-
from typing import Any, Dict, List, Type
|
| 2 |
-
|
| 3 |
-
from .operators import FieldOperator
|
| 4 |
-
from .types import Parameter, Tool
|
| 5 |
-
|
| 6 |
-
|
| 7 |
-
def convert_to_chat_api_format(tool: Tool) -> Dict[str, Any]:
|
| 8 |
-
|
| 9 |
-
from pydantic import create_model
|
| 10 |
-
|
| 11 |
-
field_definitions = {}
|
| 12 |
-
for param in tool["parameters"]:
|
| 13 |
-
param_name = param["name"]
|
| 14 |
-
param_type = param.get("type", Any)
|
| 15 |
-
field_definitions[param_name] = (param_type, ...) # ... means required in Pydantic
|
| 16 |
-
|
| 17 |
-
model = create_model(f"{tool['name']}Params", **field_definitions)
|
| 18 |
-
|
| 19 |
-
schema = model.model_json_schema()
|
| 20 |
-
|
| 21 |
-
return {
|
| 22 |
-
"type": "function",
|
| 23 |
-
"function": {
|
| 24 |
-
"name": tool["name"],
|
| 25 |
-
"description": tool["description"],
|
| 26 |
-
"parameters": schema
|
| 27 |
-
}
|
| 28 |
-
}
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
def convert_chat_api_format_to_tool(chat_api_tool: Dict[str, Any]) -> Tool:
|
| 32 |
-
"""Convert a Chat API formatted tool back to the original Tool structure.
|
| 33 |
-
|
| 34 |
-
Args:
|
| 35 |
-
chat_api_tool: A dictionary representing a tool in Chat API format
|
| 36 |
-
|
| 37 |
-
Returns:
|
| 38 |
-
A Tool dictionary with name, description, and parameters
|
| 39 |
-
"""
|
| 40 |
-
# Extract function information
|
| 41 |
-
function_info = chat_api_tool.get("function", {})
|
| 42 |
-
name = function_info.get("name", chat_api_tool.get("name", ""))
|
| 43 |
-
description = function_info.get("description", chat_api_tool.get("description", ""))
|
| 44 |
-
|
| 45 |
-
# Extract parameters from schema
|
| 46 |
-
parameters: List[Parameter] = []
|
| 47 |
-
schema = function_info.get("parameters", chat_api_tool.get("parameters", ""))
|
| 48 |
-
properties = schema.get("properties", {})
|
| 49 |
-
|
| 50 |
-
for param_name, param_schema in properties.items():
|
| 51 |
-
# Map JSON schema type to Python type
|
| 52 |
-
param_type = json_schema_to_python_type(param_schema)
|
| 53 |
-
|
| 54 |
-
parameter: Parameter = {
|
| 55 |
-
"name": param_name,
|
| 56 |
-
"type": param_type
|
| 57 |
-
}
|
| 58 |
-
parameters.append(parameter)
|
| 59 |
-
|
| 60 |
-
# Construct and return the Tool
|
| 61 |
-
tool: Tool = {
|
| 62 |
-
"name": name,
|
| 63 |
-
"description": description,
|
| 64 |
-
"parameters": parameters
|
| 65 |
-
}
|
| 66 |
-
|
| 67 |
-
return tool
|
| 68 |
-
|
| 69 |
-
def json_schema_to_python_type(schema: Dict[str, Any]) -> Type:
|
| 70 |
-
"""Convert JSON schema type to Python type."""
|
| 71 |
-
from typing import Any, Dict, List, Union
|
| 72 |
-
|
| 73 |
-
schema_type = schema.get("type")
|
| 74 |
-
|
| 75 |
-
# Handle simple types
|
| 76 |
-
simple_types = {
|
| 77 |
-
"string": str,
|
| 78 |
-
"integer": int,
|
| 79 |
-
"number": float,
|
| 80 |
-
"boolean": bool,
|
| 81 |
-
"null": type(None)
|
| 82 |
-
}
|
| 83 |
-
|
| 84 |
-
if schema_type in simple_types:
|
| 85 |
-
return simple_types[schema_type]
|
| 86 |
-
|
| 87 |
-
# Handle arrays
|
| 88 |
-
if schema_type == "array":
|
| 89 |
-
items = schema.get("items", {})
|
| 90 |
-
if not items:
|
| 91 |
-
return List[Any]
|
| 92 |
-
|
| 93 |
-
item_type = json_schema_to_python_type(items)
|
| 94 |
-
return List[item_type]
|
| 95 |
-
|
| 96 |
-
# Handle objects
|
| 97 |
-
if schema_type == "object":
|
| 98 |
-
return Dict[str, Any]
|
| 99 |
-
|
| 100 |
-
# Handle unions with anyOf/oneOf
|
| 101 |
-
if "anyOf" in schema or "oneOf" in schema:
|
| 102 |
-
union_schemas = schema.get("anyOf", []) or schema.get("oneOf", [])
|
| 103 |
-
union_types = [json_schema_to_python_type(s) for s in union_schemas]
|
| 104 |
-
# Use Union for Python 3.9+ or create Union using typing module
|
| 105 |
-
return Union[tuple(union_types)] if union_types else Any
|
| 106 |
-
|
| 107 |
-
# Handle references (simplified)
|
| 108 |
-
if "$ref" in schema:
|
| 109 |
-
# In a real implementation, you'd resolve references
|
| 110 |
-
return Any
|
| 111 |
-
|
| 112 |
-
# Default to Any for unrecognized schema types
|
| 113 |
-
return Any
|
| 114 |
-
|
| 115 |
-
|
| 116 |
-
class ToTool(FieldOperator):
|
| 117 |
-
|
| 118 |
-
def process_value(self, value: Dict[str, Any]) -> Tool:
|
| 119 |
-
return convert_chat_api_format_to_tool(value)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
type_utils.py
CHANGED
|
@@ -27,7 +27,7 @@ _registered_types = {
|
|
| 27 |
def register_type(new_type):
|
| 28 |
assert is_new_type(new_type) or is_typed_dict(
|
| 29 |
new_type
|
| 30 |
-
), "Can register only typing.NewType or typing.TypedDict"
|
| 31 |
_registered_types[new_type.__name__] = new_type
|
| 32 |
|
| 33 |
|
|
@@ -489,6 +489,9 @@ def isoftype(object, typing_type):
|
|
| 489 |
if not is_type(typing_type):
|
| 490 |
raise UnsupportedTypeError(typing_type)
|
| 491 |
|
|
|
|
|
|
|
|
|
|
| 492 |
if typing_type is typing.Type:
|
| 493 |
return is_type(object)
|
| 494 |
|
|
@@ -1066,9 +1069,18 @@ def verify_required_schema(
|
|
| 1066 |
f"{class_name} description: {description}"
|
| 1067 |
) from e
|
| 1068 |
|
| 1069 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1070 |
raise ValueError(
|
| 1071 |
-
f"Passed value
|
| 1072 |
f"of required type: ({to_type_string(data_type)}) in {class_name} ('{id}').\n"
|
| 1073 |
f"{class_name} description: {description}"
|
| 1074 |
)
|
|
|
|
| 27 |
def register_type(new_type):
|
| 28 |
assert is_new_type(new_type) or is_typed_dict(
|
| 29 |
new_type
|
| 30 |
+
) or hasattr(new_type, "__verify_type__"), "Can register only typing.NewType or typing.TypedDict or object with __verify_type__ class function"
|
| 31 |
_registered_types[new_type.__name__] = new_type
|
| 32 |
|
| 33 |
|
|
|
|
| 489 |
if not is_type(typing_type):
|
| 490 |
raise UnsupportedTypeError(typing_type)
|
| 491 |
|
| 492 |
+
if hasattr(typing_type, "__verify_type__"):
|
| 493 |
+
return typing_type.__verify_type__(object)
|
| 494 |
+
|
| 495 |
if typing_type is typing.Type:
|
| 496 |
return is_type(object)
|
| 497 |
|
|
|
|
| 1069 |
f"{class_name} description: {description}"
|
| 1070 |
) from e
|
| 1071 |
|
| 1072 |
+
try:
|
| 1073 |
+
valid = isoftype(value, data_type)
|
| 1074 |
+
except Exception as e:
|
| 1075 |
+
raise ValueError(
|
| 1076 |
+
f"Passed value {value} of field '{field_name}' is not "
|
| 1077 |
+
f"of required type: ({to_type_string(data_type)}) in {class_name} ('{id}').\n"
|
| 1078 |
+
f"{class_name} description: {description}\nReason:\n{e}"
|
| 1079 |
+
) from e
|
| 1080 |
+
|
| 1081 |
+
if not valid:
|
| 1082 |
raise ValueError(
|
| 1083 |
+
f"Passed value {value} of field '{field_name}' is not "
|
| 1084 |
f"of required type: ({to_type_string(data_type)}) in {class_name} ('{id}').\n"
|
| 1085 |
f"{class_name} description: {description}"
|
| 1086 |
)
|
types.py
CHANGED
|
@@ -1,4 +1,4 @@
|
|
| 1 |
-
from typing import Any, Dict, List, Literal, NewType, Optional,
|
| 2 |
|
| 3 |
from .type_utils import register_type
|
| 4 |
|
|
@@ -51,14 +51,20 @@ class SQLDatabase(TypedDict):
|
|
| 51 |
dbms: Optional[str]
|
| 52 |
data: Optional[Dict[str, Dict]]
|
| 53 |
|
| 54 |
-
class
|
| 55 |
-
|
| 56 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 57 |
|
| 58 |
class Tool(TypedDict):
|
| 59 |
name: str
|
| 60 |
description: str
|
| 61 |
-
parameters:
|
| 62 |
|
| 63 |
class ToolCall(TypedDict):
|
| 64 |
name: str
|
|
@@ -76,7 +82,7 @@ register_type(Document)
|
|
| 76 |
register_type(MultiDocument)
|
| 77 |
register_type(RagResponse)
|
| 78 |
register_type(SQLDatabase)
|
| 79 |
-
register_type(Parameter)
|
| 80 |
register_type(Tool)
|
|
|
|
| 81 |
register_type(ToolCall)
|
| 82 |
|
|
|
|
| 1 |
+
from typing import Any, Dict, List, Literal, NewType, Optional, TypedDict, Union
|
| 2 |
|
| 3 |
from .type_utils import register_type
|
| 4 |
|
|
|
|
| 51 |
dbms: Optional[str]
|
| 52 |
data: Optional[Dict[str, Dict]]
|
| 53 |
|
| 54 |
+
class JsonSchema:
|
| 55 |
+
|
| 56 |
+
@classmethod
|
| 57 |
+
def __verify_type__(cls, object):
|
| 58 |
+
if not isinstance(object, dict):
|
| 59 |
+
return False
|
| 60 |
+
import jsonschema_rs
|
| 61 |
+
jsonschema_rs.meta.validate(object)
|
| 62 |
+
return True
|
| 63 |
|
| 64 |
class Tool(TypedDict):
|
| 65 |
name: str
|
| 66 |
description: str
|
| 67 |
+
parameters: JsonSchema
|
| 68 |
|
| 69 |
class ToolCall(TypedDict):
|
| 70 |
name: str
|
|
|
|
| 82 |
register_type(MultiDocument)
|
| 83 |
register_type(RagResponse)
|
| 84 |
register_type(SQLDatabase)
|
|
|
|
| 85 |
register_type(Tool)
|
| 86 |
+
register_type(JsonSchema)
|
| 87 |
register_type(ToolCall)
|
| 88 |
|
version.py
CHANGED
|
@@ -1 +1 @@
|
|
| 1 |
-
version = "1.
|
|
|
|
| 1 |
+
version = "1.23.0"
|