Upload folder using huggingface_hub
Browse files- api.py +49 -5
- artifact.py +28 -18
- collections_operators.py +60 -2
- dataclass.py +59 -0
- dataset.py +1 -1
- dialog_operators.py +10 -1
- dict_utils.py +1 -1
- error_utils.py +254 -12
- evaluate_cli.py +56 -3
- formats.py +53 -9
- fusion.py +14 -16
- inference.py +229 -126
- llm_as_judge_constants.py +1 -2
- loaders.py +107 -81
- metric.py +1 -1
- metric_utils.py +19 -12
- metrics.py +548 -654
- operator.py +23 -13
- operators.py +79 -58
- processors.py +11 -1
- schema.py +1 -1
- serializers.py +18 -2
- settings_utils.py +4 -0
- struct_data_operators.py +49 -0
- task.py +13 -8
- templates.py +13 -1
- sql_utils.py → text2sql_utils.py +488 -2
- type_utils.py +18 -2
- types.py +56 -26
- version.py +1 -1
api.py
CHANGED
|
@@ -11,6 +11,7 @@ from datasets.exceptions import DatasetGenerationError
|
|
| 11 |
from .artifact import fetch_artifact
|
| 12 |
from .benchmark import Benchmark
|
| 13 |
from .card import TaskCard
|
|
|
|
| 14 |
from .dataset_utils import get_dataset_artifact
|
| 15 |
from .error_utils import UnitxtError
|
| 16 |
from .inference import (
|
|
@@ -149,6 +150,36 @@ def create_dataset(
|
|
| 149 |
return load_dataset(card=card, split=split, **kwargs)
|
| 150 |
|
| 151 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 152 |
def _source_to_dataset(
|
| 153 |
source: SourceOperator,
|
| 154 |
split=None,
|
|
@@ -157,22 +188,35 @@ def _source_to_dataset(
|
|
| 157 |
):
|
| 158 |
from .dataset import Dataset as UnitxtDataset
|
| 159 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 160 |
stream = source()
|
| 161 |
|
| 162 |
try:
|
| 163 |
ds_builder = UnitxtDataset(
|
| 164 |
dataset_name="unitxt",
|
| 165 |
-
config_name=
|
| 166 |
version=constants.version,
|
| 167 |
)
|
| 168 |
if split is not None:
|
| 169 |
stream = {split: stream[split]}
|
| 170 |
ds_builder._generators = stream
|
| 171 |
|
| 172 |
-
|
| 173 |
-
|
| 174 |
-
|
| 175 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 176 |
|
| 177 |
if streaming:
|
| 178 |
return ds_builder.as_streaming_dataset(split=split)
|
|
|
|
| 11 |
from .artifact import fetch_artifact
|
| 12 |
from .benchmark import Benchmark
|
| 13 |
from .card import TaskCard
|
| 14 |
+
from .dataclass import to_dict
|
| 15 |
from .dataset_utils import get_dataset_artifact
|
| 16 |
from .error_utils import UnitxtError
|
| 17 |
from .inference import (
|
|
|
|
| 150 |
return load_dataset(card=card, split=split, **kwargs)
|
| 151 |
|
| 152 |
|
| 153 |
+
def object_to_str_without_addresses(obj):
    """Return ``str(obj)`` with every memory-address reference removed.

    Default object representations embed the object's memory address
    (e.g. ``<object_name at 0x123abc>``), which varies between executions.
    Stripping the addresses yields a representation that is stable across
    runs and can therefore be hashed or compared.

    Every `` at 0x<hex>`` occurrence is removed, not only the first one, so
    representations that contain several objects (for example the ``str`` of
    a list of objects) keep all of their non-address content.

    Args:
        obj: Any Python object to be converted to a string representation.

    Returns:
        str: ``str(obj)`` with all memory addresses removed. Strings without
        an address are returned unchanged.

    Example:
        ```python
        class MyClass:
            pass

        obj = MyClass()
        print(str(obj))  # "<__main__.MyClass object at 0x7f8b9d4d6e20>"
        print(object_to_str_without_addresses(obj))  # "<__main__.MyClass object>"
        ```
    """
    # Imported locally so the function does not depend on module-level imports.
    import re

    return re.sub(r" at 0x[0-9a-fA-F]+", "", str(obj))
|
| 181 |
+
|
| 182 |
+
|
| 183 |
def _source_to_dataset(
|
| 184 |
source: SourceOperator,
|
| 185 |
split=None,
|
|
|
|
| 188 |
):
|
| 189 |
from .dataset import Dataset as UnitxtDataset
|
| 190 |
|
| 191 |
+
# Generate a unique signature for the source
|
| 192 |
+
source_signature = json.dumps(
|
| 193 |
+
to_dict(source, object_to_str_without_addresses), sort_keys=True
|
| 194 |
+
)
|
| 195 |
+
config_name = "recipe-" + short_hex_hash(source_signature)
|
| 196 |
+
# Obtain data stream from the source
|
| 197 |
stream = source()
|
| 198 |
|
| 199 |
try:
|
| 200 |
ds_builder = UnitxtDataset(
|
| 201 |
dataset_name="unitxt",
|
| 202 |
+
config_name=config_name, # Dictate the cache name
|
| 203 |
version=constants.version,
|
| 204 |
)
|
| 205 |
if split is not None:
|
| 206 |
stream = {split: stream[split]}
|
| 207 |
ds_builder._generators = stream
|
| 208 |
|
| 209 |
+
try:
|
| 210 |
+
ds_builder.download_and_prepare(
|
| 211 |
+
verification_mode="no_checks",
|
| 212 |
+
download_mode=None if use_cache else "force_redownload",
|
| 213 |
+
)
|
| 214 |
+
except DatasetGenerationError as e:
|
| 215 |
+
if e.__cause__:
|
| 216 |
+
raise e.__cause__ from None
|
| 217 |
+
if e.__context__:
|
| 218 |
+
raise e.__context__ from None
|
| 219 |
+
raise
|
| 220 |
|
| 221 |
if streaming:
|
| 222 |
return ds_builder.as_streaming_dataset(split=split)
|
artifact.py
CHANGED
|
@@ -16,13 +16,13 @@ from .dataclass import (
|
|
| 16 |
NonPositionalField,
|
| 17 |
fields,
|
| 18 |
)
|
| 19 |
-
from .error_utils import Documentation, UnitxtError, UnitxtWarning
|
| 20 |
from .logging_utils import get_logger
|
| 21 |
from .parsing_utils import (
|
| 22 |
separate_inside_and_outside_square_brackets,
|
| 23 |
)
|
| 24 |
from .settings_utils import get_constants, get_settings
|
| 25 |
-
from .text_utils import camel_to_snake_case, is_camel_case
|
| 26 |
from .type_utils import isoftype, issubtype
|
| 27 |
from .utils import (
|
| 28 |
artifacts_json_cache,
|
|
@@ -342,8 +342,10 @@ class Artifact(Dataclass):
|
|
| 342 |
self.verify_data_classification_policy()
|
| 343 |
self.prepare_args()
|
| 344 |
if not settings.skip_artifacts_prepare_and_verify:
|
| 345 |
-
self
|
| 346 |
-
|
|
|
|
|
|
|
| 347 |
|
| 348 |
def _to_raw_dict(self):
|
| 349 |
return {
|
|
@@ -367,11 +369,14 @@ class Artifact(Dataclass):
|
|
| 367 |
|
| 368 |
def to_json(self):
|
| 369 |
data = self.to_dict()
|
|
|
|
| 370 |
return json_dump(data)
|
| 371 |
|
| 372 |
def to_yaml(self):
|
|
|
|
|
|
|
| 373 |
data = self.to_dict()
|
| 374 |
-
return
|
| 375 |
|
| 376 |
def serialize(self):
|
| 377 |
if self.__id__ is not None:
|
|
@@ -449,20 +454,25 @@ class Artifact(Dataclass):
|
|
| 449 |
)
|
| 450 |
return instance
|
| 451 |
|
| 452 |
-
|
| 453 |
-
|
| 454 |
-
|
|
|
|
| 455 |
):
|
| 456 |
-
|
| 457 |
-
|
| 458 |
-
|
| 459 |
-
|
| 460 |
-
|
| 461 |
-
|
| 462 |
-
|
| 463 |
-
|
| 464 |
-
|
| 465 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 466 |
|
| 467 |
return instance
|
| 468 |
|
|
|
|
| 16 |
NonPositionalField,
|
| 17 |
fields,
|
| 18 |
)
|
| 19 |
+
from .error_utils import Documentation, UnitxtError, UnitxtWarning, error_context
|
| 20 |
from .logging_utils import get_logger
|
| 21 |
from .parsing_utils import (
|
| 22 |
separate_inside_and_outside_square_brackets,
|
| 23 |
)
|
| 24 |
from .settings_utils import get_constants, get_settings
|
| 25 |
+
from .text_utils import camel_to_snake_case, is_camel_case
|
| 26 |
from .type_utils import isoftype, issubtype
|
| 27 |
from .utils import (
|
| 28 |
artifacts_json_cache,
|
|
|
|
| 342 |
self.verify_data_classification_policy()
|
| 343 |
self.prepare_args()
|
| 344 |
if not settings.skip_artifacts_prepare_and_verify:
|
| 345 |
+
with error_context(self, action="Prepare Object"):
|
| 346 |
+
self.prepare()
|
| 347 |
+
with error_context(self, action="Verify Object"):
|
| 348 |
+
self.verify()
|
| 349 |
|
| 350 |
def _to_raw_dict(self):
|
| 351 |
return {
|
|
|
|
| 369 |
|
| 370 |
def to_json(self):
|
| 371 |
data = self.to_dict()
|
| 372 |
+
|
| 373 |
return json_dump(data)
|
| 374 |
|
| 375 |
def to_yaml(self):
|
| 376 |
+
import yaml
|
| 377 |
+
|
| 378 |
data = self.to_dict()
|
| 379 |
+
return yaml.dump(data)
|
| 380 |
|
| 381 |
def serialize(self):
|
| 382 |
if self.__id__ is not None:
|
|
|
|
| 454 |
)
|
| 455 |
return instance
|
| 456 |
|
| 457 |
+
with error_context(
|
| 458 |
+
self,
|
| 459 |
+
action="Sensitive Data Verification",
|
| 460 |
+
help="https://www.unitxt.ai/en/latest/docs/data_classification_policy.html",
|
| 461 |
):
|
| 462 |
+
if not any(
|
| 463 |
+
data_classification in data_classification_policy
|
| 464 |
+
for data_classification in instance_data_classification
|
| 465 |
+
):
|
| 466 |
+
raise UnitxtError(
|
| 467 |
+
f"The instance '{instance} 'has the following data classification policy "
|
| 468 |
+
f"'{instance_data_classification}', however, the artifact '{name}' "
|
| 469 |
+
f"is only configured to support the data with classification "
|
| 470 |
+
f"'{data_classification_policy}'. To enable this either change "
|
| 471 |
+
f"the 'data_classification_policy' attribute of the artifact, "
|
| 472 |
+
f"or modify the environment variable "
|
| 473 |
+
f"'UNITXT_DATA_CLASSIFICATION_POLICY' accordingly.",
|
| 474 |
+
Documentation.DATA_CLASSIFICATION_POLICY,
|
| 475 |
+
)
|
| 476 |
|
| 477 |
return instance
|
| 478 |
|
collections_operators.py
CHANGED
|
@@ -1,6 +1,8 @@
|
|
|
|
|
| 1 |
from typing import Any, Dict, Generator, List, Optional
|
| 2 |
|
| 3 |
from .dict_utils import dict_get, dict_set
|
|
|
|
| 4 |
from .operators import FieldOperator, StreamOperator
|
| 5 |
from .stream import Stream
|
| 6 |
from .utils import recursive_shallow_copy
|
|
@@ -13,11 +15,52 @@ class Dictify(FieldOperator):
|
|
| 13 |
return dict(zip(self.with_keys, tup))
|
| 14 |
|
| 15 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 16 |
class DictToTuplesList(FieldOperator):
|
| 17 |
def process_value(self, dic: Dict) -> Any:
|
| 18 |
return list(dic.items())
|
| 19 |
|
| 20 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 21 |
class Wrap(FieldOperator):
|
| 22 |
inside: str
|
| 23 |
|
|
@@ -64,6 +107,13 @@ class Get(FieldOperator):
|
|
| 64 |
return collection[self.item]
|
| 65 |
|
| 66 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 67 |
class DuplicateByList(StreamOperator):
|
| 68 |
field: str
|
| 69 |
to_field: Optional[str] = None
|
|
@@ -91,12 +141,16 @@ class DuplicateBySubLists(StreamOperator):
|
|
| 91 |
field: str
|
| 92 |
to_field: Optional[str] = None
|
| 93 |
use_deep_copy: bool = False
|
|
|
|
|
|
|
|
|
|
| 94 |
|
| 95 |
def process(self, stream: Stream, stream_name: Optional[str] = None) -> Generator:
|
| 96 |
to_field = self.field if self.to_field is None else self.to_field
|
| 97 |
for instance in stream:
|
| 98 |
-
elements = instance
|
| 99 |
-
|
|
|
|
| 100 |
if self.use_deep_copy:
|
| 101 |
instance_copy = recursive_shallow_copy(instance)
|
| 102 |
instance_copy[to_field] = elements[:i]
|
|
@@ -109,6 +163,10 @@ class DuplicateBySubLists(StreamOperator):
|
|
| 109 |
yield instance_copy
|
| 110 |
|
| 111 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 112 |
class GetLength(FieldOperator):
|
| 113 |
def process_value(self, collection: Any) -> Any:
|
| 114 |
return len(collection)
|
|
|
|
| 1 |
+
from itertools import zip_longest
|
| 2 |
from typing import Any, Dict, Generator, List, Optional
|
| 3 |
|
| 4 |
from .dict_utils import dict_get, dict_set
|
| 5 |
+
from .operator import InstanceOperator
|
| 6 |
from .operators import FieldOperator, StreamOperator
|
| 7 |
from .stream import Stream
|
| 8 |
from .utils import recursive_shallow_copy
|
|
|
|
| 15 |
return dict(zip(self.with_keys, tup))
|
| 16 |
|
| 17 |
|
| 18 |
+
class Zip(InstanceOperator):
    """Zip the values of several instance fields into a single field.

    The values found under ``fields`` are read with ``dict_get``, combined
    by :meth:`zip`, and the resulting list of tuples is written to
    ``to_field`` with ``dict_set``.
    """

    fields: List[str]
    to_field: str

    def zip(self, values):
        # Inside a method body the bare name ``zip`` is the builtin.
        combined = zip(*values)
        return list(combined)

    def process(
        self, instance: Dict[str, Any], stream_name: Optional[str] = None
    ) -> Dict[str, Any]:
        columns = [dict_get(instance, name) for name in self.fields]
        dict_set(instance, self.to_field, self.zip(columns))
        return instance
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
class ZipLongest(Zip):
    """Variant of ``Zip`` that pads shorter inputs with ``fill_value``.

    Uses ``itertools.zip_longest`` so the result is as long as the longest
    input sequence.
    """

    fields: List[str]
    fill_value: Any = None

    def zip(self, values):
        padded = zip_longest(*values, fillvalue=self.fill_value)
        return list(padded)
|
| 41 |
+
|
| 42 |
+
|
| 43 |
class DictToTuplesList(FieldOperator):
|
| 44 |
def process_value(self, dic: Dict) -> Any:
|
| 45 |
return list(dic.items())
|
| 46 |
|
| 47 |
|
| 48 |
+
def flatten(container):
    """Flatten arbitrarily nested lists/tuples into a single-level container.

    Only ``list`` and ``tuple`` items are descended into; every other item
    (including strings and dicts) is kept as a leaf. The result is built by
    calling ``type(container)`` on the stream of leaves, so a list input
    yields a list and a tuple input yields a tuple.

    The traversal uses an explicit stack of iterators instead of recursion,
    so deeply nested input does not hit the interpreter recursion limit.

    Args:
        container: The (possibly nested) iterable to flatten.

    Returns:
        A container of the same type as ``container`` holding the leaves in
        depth-first, left-to-right order.
    """

    def _iter_leaves(root):
        pending = [iter(root)]
        while pending:
            for item in pending[-1]:
                if isinstance(item, (list, tuple)):
                    # Descend; the parent iterator is resumed once the
                    # nested one is exhausted.
                    pending.append(iter(item))
                    break
                yield item
            else:
                pending.pop()

    return type(container)(_iter_leaves(container))
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
class Flatten(FieldOperator):
    """Field operator that replaces a value with the result of ``flatten``."""

    def process_value(self, value: Any) -> Any:
        flattened = flatten(value)
        return flattened
|
| 62 |
+
|
| 63 |
+
|
| 64 |
class Wrap(FieldOperator):
|
| 65 |
inside: str
|
| 66 |
|
|
|
|
| 107 |
return collection[self.item]
|
| 108 |
|
| 109 |
|
| 110 |
+
class Pop(FieldOperator):
    """Pop an element from the collection held in the field.

    The value returned by ``process_value`` is the popped element.

    ``item`` is the index (for sequences) or key (for dicts) to pop. When it
    is left as ``None`` and the collection is not a dict, ``pop()`` is called
    without arguments, because ``list.pop(None)`` raises ``TypeError``. For
    dicts ``None`` is still passed through as the key, as before.
    """

    # Index or key to pop; None means "use the collection's default pop()".
    item: Any = None

    def process_value(self, collection: Any) -> Any:
        if self.item is None and not isinstance(collection, dict):
            return collection.pop()
        return collection.pop(self.item)
|
| 115 |
+
|
| 116 |
+
|
| 117 |
class DuplicateByList(StreamOperator):
|
| 118 |
field: str
|
| 119 |
to_field: Optional[str] = None
|
|
|
|
| 141 |
field: str
|
| 142 |
to_field: Optional[str] = None
|
| 143 |
use_deep_copy: bool = False
|
| 144 |
+
start: int = 1
|
| 145 |
+
end: int = 0
|
| 146 |
+
step: int = 1
|
| 147 |
|
| 148 |
def process(self, stream: Stream, stream_name: Optional[str] = None) -> Generator:
|
| 149 |
to_field = self.field if self.to_field is None else self.to_field
|
| 150 |
for instance in stream:
|
| 151 |
+
elements = dict_get(instance, self.field)
|
| 152 |
+
end = len(elements) + 1 + self.end
|
| 153 |
+
for i in range(self.start, end, self.step):
|
| 154 |
if self.use_deep_copy:
|
| 155 |
instance_copy = recursive_shallow_copy(instance)
|
| 156 |
instance_copy[to_field] = elements[:i]
|
|
|
|
| 163 |
yield instance_copy
|
| 164 |
|
| 165 |
|
| 166 |
+
class ExplodeSubLists(DuplicateBySubLists):
    """Alias of ``DuplicateBySubLists``; adds no behavior of its own."""
|
| 168 |
+
|
| 169 |
+
|
| 170 |
class GetLength(FieldOperator):
|
| 171 |
def process_value(self, collection: Any) -> Any:
|
| 172 |
return len(collection)
|
dataclass.py
CHANGED
|
@@ -297,6 +297,65 @@ def _asdict_inner(obj):
|
|
| 297 |
return copy.deepcopy(obj)
|
| 298 |
|
| 299 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 300 |
class DataclassMeta(ABCMeta):
|
| 301 |
"""Metaclass for Dataclass.
|
| 302 |
|
|
|
|
| 297 |
return copy.deepcopy(obj)
|
| 298 |
|
| 299 |
|
| 300 |
+
def to_dict(obj, func=copy.deepcopy, _visited=None):
    """Recursively convert an object into plain containers, guarding against circular references.

    Args:
        obj: Any Python object to be converted.
        func (Callable, optional): Applied to every leaf (non-container) object, and to an
            object that is reached again while it is still being converted (a cycle).
            Defaults to `copy.deepcopy`.
        _visited (set, optional): Internal. Ids of the containers on the current recursion
            path. Callers should not pass it.

    Returns:
        A dict for dataclasses; for named tuples, lists, tuples and dicts a new container of
        the same type with converted contents; otherwise ``func(obj)``.

    Notes:
        - Supports dataclasses, named tuples, lists, tuples, and dictionaries.
        - Only objects on the *current* recursion path count as visited. An object that is
          merely referenced from several places (shared, but not circular) is converted in
          full at each occurrence. Ids are dropped when a container is finished, so a
          later object that happens to reuse a freed id is not mistaken for a cycle.
        - Named tuples retain their original type instead of being converted to dictionaries.
    """
    if _visited is None:
        _visited = set()

    obj_id = id(obj)

    # Reaching an object that is still being converted means a genuine cycle.
    if obj_id in _visited:
        return func(obj)

    obj_is_dataclass = is_dataclass(obj)
    obj_is_named_tuple = isinstance(obj, tuple) and hasattr(obj, "_fields")
    tracked = obj_is_dataclass or obj_is_named_tuple or isinstance(obj, (dict, list))

    if tracked:
        _visited.add(obj_id)
    try:
        if obj_is_dataclass:
            return {
                field.name: to_dict(getattr(obj, field.name), func, _visited)
                for field in fields(obj)
            }

        if obj_is_named_tuple:
            return type(obj)(*[to_dict(v, func, _visited) for v in obj])

        if isinstance(obj, (list, tuple)):
            return type(obj)([to_dict(v, func, _visited) for v in obj])

        if isinstance(obj, dict):
            return type(obj)(
                {
                    to_dict(k, func, _visited): to_dict(v, func, _visited)
                    for k, v in obj.items()
                }
            )

        return func(obj)
    finally:
        # Leave the recursion path so siblings may convert the same object again.
        if tracked:
            _visited.discard(obj_id)
|
| 357 |
+
|
| 358 |
+
|
| 359 |
class DataclassMeta(ABCMeta):
|
| 360 |
"""Metaclass for Dataclass.
|
| 361 |
|
dataset.py
CHANGED
|
@@ -59,7 +59,6 @@ from .settings_utils import get_constants
|
|
| 59 |
from .span_lableing_operators import __file__ as _
|
| 60 |
from .split_utils import __file__ as _
|
| 61 |
from .splitters import __file__ as _
|
| 62 |
-
from .sql_utils import __file__ as _
|
| 63 |
from .standard import __file__ as _
|
| 64 |
from .stream import __file__ as _
|
| 65 |
from .stream_operators import __file__ as _
|
|
@@ -68,6 +67,7 @@ from .struct_data_operators import __file__ as _
|
|
| 68 |
from .system_prompts import __file__ as _
|
| 69 |
from .task import __file__ as _
|
| 70 |
from .templates import __file__ as _
|
|
|
|
| 71 |
from .text_utils import __file__ as _
|
| 72 |
from .type_utils import __file__ as _
|
| 73 |
from .types import __file__ as _
|
|
|
|
| 59 |
from .span_lableing_operators import __file__ as _
|
| 60 |
from .split_utils import __file__ as _
|
| 61 |
from .splitters import __file__ as _
|
|
|
|
| 62 |
from .standard import __file__ as _
|
| 63 |
from .stream import __file__ as _
|
| 64 |
from .stream_operators import __file__ as _
|
|
|
|
| 67 |
from .system_prompts import __file__ as _
|
| 68 |
from .task import __file__ as _
|
| 69 |
from .templates import __file__ as _
|
| 70 |
+
from .text2sql_utils import __file__ as _
|
| 71 |
from .text_utils import __file__ as _
|
| 72 |
from .type_utils import __file__ as _
|
| 73 |
from .types import __file__ as _
|
dialog_operators.py
CHANGED
|
@@ -17,7 +17,16 @@ The format of the dialog is:
|
|
| 17 |
from typing import Any, Dict, List, Optional
|
| 18 |
|
| 19 |
from .formats import SystemFormat
|
| 20 |
-
from .operators import InstanceFieldOperator
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 21 |
|
| 22 |
|
| 23 |
class SerializeDialog(InstanceFieldOperator):
|
|
|
|
| 17 |
from typing import Any, Dict, List, Optional
|
| 18 |
|
| 19 |
from .formats import SystemFormat
|
| 20 |
+
from .operators import FieldOperator, InstanceFieldOperator
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
class ToDialog(FieldOperator):
    """Turn an iterable of (question, answer) pairs into a list of dialog turns.

    Each pair contributes a ``user`` turn with the question followed by an
    ``agent`` turn with the answer.
    """

    def process_value(self, value: Any) -> Any:
        turns = []
        for user_text, agent_text in value:
            turns.extend(
                (
                    {"role": "user", "content": user_text},
                    {"role": "agent", "content": agent_text},
                )
            )
        return turns
|
| 30 |
|
| 31 |
|
| 32 |
class SerializeDialog(InstanceFieldOperator):
|
dict_utils.py
CHANGED
|
@@ -3,7 +3,7 @@ from typing import Any, List, Tuple
|
|
| 3 |
|
| 4 |
from .text_utils import to_pretty_string
|
| 5 |
|
| 6 |
-
indx = re.compile(r"
|
| 7 |
|
| 8 |
|
| 9 |
def is_index(string):
|
|
|
|
| 3 |
|
| 4 |
from .text_utils import to_pretty_string
|
| 5 |
|
| 6 |
+
indx = re.compile(r"^-?\d+$")
|
| 7 |
|
| 8 |
|
| 9 |
def is_index(string):
|
error_utils.py
CHANGED
|
@@ -1,7 +1,11 @@
|
|
| 1 |
-
|
|
|
|
|
|
|
| 2 |
|
| 3 |
from .logging_utils import get_logger
|
|
|
|
| 4 |
|
|
|
|
| 5 |
logger = get_logger()
|
| 6 |
|
| 7 |
|
|
@@ -29,12 +33,9 @@ class UnitxtError(Exception):
|
|
| 29 |
"""Exception raised for Unitxt errors.
|
| 30 |
|
| 31 |
Args:
|
| 32 |
-
message (str):
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
relative path to additional documentation on web
|
| 36 |
-
If set, should be one of the DOCUMENATION_* constants in the error_utils.py file.
|
| 37 |
-
|
| 38 |
"""
|
| 39 |
|
| 40 |
def __init__(self, message: str, additional_info_id: Optional[str] = None):
|
|
@@ -47,14 +48,255 @@ class UnitxtWarning:
|
|
| 47 |
"""Object to format warning message to log.
|
| 48 |
|
| 49 |
Args:
|
| 50 |
-
message (str):
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
relative path to additional documentation on web
|
| 54 |
-
If set, should be one of the DOCUMENATION_* constants in the error_utils.py file.
|
| 55 |
"""
|
| 56 |
|
| 57 |
def __init__(self, message: str, additional_info_id: Optional[str] = None):
|
| 58 |
if additional_info_id is not None:
|
| 59 |
message += additional_info(additional_info_id)
|
| 60 |
logger.warning(message)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import re
|
| 2 |
+
from contextlib import contextmanager
|
| 3 |
+
from typing import Any, Optional
|
| 4 |
|
| 5 |
from .logging_utils import get_logger
|
| 6 |
+
from .settings_utils import get_constants
|
| 7 |
|
| 8 |
+
constants = get_constants()
|
| 9 |
logger = get_logger()
|
| 10 |
|
| 11 |
|
|
|
|
| 33 |
"""Exception raised for Unitxt errors.
|
| 34 |
|
| 35 |
Args:
|
| 36 |
+
message (str): explanation of the error
|
| 37 |
+
additional_info_id (Optional[str]): relative path to additional documentation on web
|
| 38 |
+
If set, should be one of the DOCUMENTATION_* constants in the error_utils.py file.
|
|
|
|
|
|
|
|
|
|
| 39 |
"""
|
| 40 |
|
| 41 |
def __init__(self, message: str, additional_info_id: Optional[str] = None):
|
|
|
|
| 48 |
"""Object to format warning message to log.
|
| 49 |
|
| 50 |
Args:
|
| 51 |
+
message (str): explanation of the warning
|
| 52 |
+
additional_info_id (Optional[str]): relative path to additional documentation on web
|
| 53 |
+
If set, should be one of the DOCUMENTATION_* constants in the error_utils.py file.
|
|
|
|
|
|
|
| 54 |
"""
|
| 55 |
|
| 56 |
def __init__(self, message: str, additional_info_id: Optional[str] = None):
|
| 57 |
if additional_info_id is not None:
|
| 58 |
message += additional_info(additional_info_id)
|
| 59 |
logger.warning(message)
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
context_block_title = "🦄 Unitxt Error Context"
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
def _visible_length(text: str) -> int:
|
| 66 |
+
import unicodedata
|
| 67 |
+
|
| 68 |
+
ansi_escape = re.compile(r"\x1b\[[0-9;]*[a-zA-Z]|\x1b\]8;;[^\x1b]*\x1b\\")
|
| 69 |
+
clean_text = ansi_escape.sub("", text)
|
| 70 |
+
width = 0
|
| 71 |
+
for char in clean_text:
|
| 72 |
+
if (
|
| 73 |
+
unicodedata.east_asian_width(char) in ("F", "W")
|
| 74 |
+
or 0x1F300 <= ord(char) <= 0x1F9FF
|
| 75 |
+
):
|
| 76 |
+
width += 2
|
| 77 |
+
else:
|
| 78 |
+
width += 1
|
| 79 |
+
return width
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
def _make_object_clickable(
|
| 83 |
+
full_obj_name: str, display_name: Optional[str] = None
|
| 84 |
+
) -> str:
|
| 85 |
+
import os
|
| 86 |
+
|
| 87 |
+
if display_name is None:
|
| 88 |
+
display_name = full_obj_name.split(".")[-1]
|
| 89 |
+
if full_obj_name.startswith("unitxt."):
|
| 90 |
+
parts = full_obj_name.split(".")
|
| 91 |
+
if len(parts) >= 2:
|
| 92 |
+
module_path = ".".join(parts[:2])
|
| 93 |
+
doc_url = f"{Documentation.URL}{module_path}.html#{full_obj_name}"
|
| 94 |
+
if (
|
| 95 |
+
os.environ.get("TERM_PROGRAM") in ["iTerm.app", "vscode"]
|
| 96 |
+
or os.environ.get("TERMINAL_EMULATOR") == "JetBrains-JediTerm"
|
| 97 |
+
):
|
| 98 |
+
return f"\033]8;;{doc_url}\033\\{display_name}\033]8;;\033\\"
|
| 99 |
+
return f"{display_name} ({doc_url})"
|
| 100 |
+
return display_name
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
def _get_existing_context(error: Exception):
|
| 104 |
+
"""Extract existing context from an error if it exists."""
|
| 105 |
+
if hasattr(error, "__error_context__"):
|
| 106 |
+
existing = error.__error_context__
|
| 107 |
+
return (
|
| 108 |
+
existing["original_message"],
|
| 109 |
+
existing["context_object"],
|
| 110 |
+
existing["context"],
|
| 111 |
+
)
|
| 112 |
+
return str(error), None, {}
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
def _format_object_context(obj: Any) -> Optional[str]:
|
| 116 |
+
"""Format an object for display in error context."""
|
| 117 |
+
if obj is None:
|
| 118 |
+
return None
|
| 119 |
+
if hasattr(obj, "__class__"):
|
| 120 |
+
class_name = obj.__class__.__name__
|
| 121 |
+
module_name = getattr(obj.__class__, "__module__", "")
|
| 122 |
+
else:
|
| 123 |
+
obj_type = type(obj)
|
| 124 |
+
class_name = obj_type.__name__
|
| 125 |
+
module_name = getattr(obj_type, "__module__", "")
|
| 126 |
+
if module_name:
|
| 127 |
+
full_name = f"{module_name}.{class_name}"
|
| 128 |
+
clickable_object = _make_object_clickable(full_name, class_name)
|
| 129 |
+
return f"Object: {clickable_object}"
|
| 130 |
+
return f"Object: {class_name}"
|
| 131 |
+
|
| 132 |
+
|
| 133 |
+
def _make_clickable_link(url: str) -> str:
|
| 134 |
+
"""Create a clickable terminal link."""
|
| 135 |
+
import os
|
| 136 |
+
|
| 137 |
+
if (
|
| 138 |
+
os.environ.get("TERM_PROGRAM") in ["iTerm.app", "vscode"]
|
| 139 |
+
or os.environ.get("TERMINAL_EMULATOR") == "JetBrains-JediTerm"
|
| 140 |
+
):
|
| 141 |
+
return f"\033]8;;{url}\033\\link\033]8;;\033\\"
|
| 142 |
+
return url
|
| 143 |
+
|
| 144 |
+
|
| 145 |
+
def _format_help_context(help_docs) -> list:
|
| 146 |
+
"""Format help documentation into context parts."""
|
| 147 |
+
parts = []
|
| 148 |
+
if isinstance(help_docs, str):
|
| 149 |
+
parts.append(f"Help: {_make_clickable_link(help_docs)}")
|
| 150 |
+
elif isinstance(help_docs, dict):
|
| 151 |
+
for label, url in help_docs.items():
|
| 152 |
+
parts.append(f"Help ({label}): {_make_clickable_link(url)}")
|
| 153 |
+
elif isinstance(help_docs, list):
|
| 154 |
+
for item in help_docs:
|
| 155 |
+
if isinstance(item, dict) and len(item) == 1:
|
| 156 |
+
label, url = next(iter(item.items()))
|
| 157 |
+
parts.append(f"Help ({label}): {_make_clickable_link(url)}")
|
| 158 |
+
elif isinstance(item, str):
|
| 159 |
+
parts.append(f"Help: {_make_clickable_link(item)}")
|
| 160 |
+
return parts
|
| 161 |
+
|
| 162 |
+
|
| 163 |
+
def _build_context_parts(context_object: Any, context: dict) -> list:
    """Assemble the ordered list of context lines shown in the error box.

    Order of the result:
      1. Well-known keys (Python, Unitxt, Stage, Stream, Index, Instance,
         Object, Action), matched case-insensitively; the first matching
         key of ``context`` is used for each.
      2. A line describing ``context_object``, unless an "object" key was
         already emitted.
      3. All remaining keys except ``"help"``, in ``context`` order.
      4. Help lines from ``context["help"]``, or a default documentation
         link when no ``"help"`` key is present.

    ``None`` values are shown as ``"unknown"``.
    """

    def describe(key):
        raw = context[key]
        shown = "unknown" if raw is None else raw
        return f"{key.replace('_', ' ').title()}: {shown}"

    priority = (
        "python",
        "unitxt",
        "stage",
        "stream",
        "index",
        "instance",
        "object",
        "action",
    )

    lines = []
    handled = set()
    for wanted in priority:
        match = next((key for key in context if key.lower() == wanted), None)
        if match is not None:
            lines.append(describe(match))
            handled.add(match)

    object_already_listed = any(key.lower() == "object" for key in handled)
    if not object_already_listed:
        described_object = _format_object_context(context_object)
        if described_object:
            lines.append(described_object)

    # "help" is rendered separately at the end, never as a generic entry.
    handled.add("help")
    lines.extend(describe(key) for key in context if key not in handled)

    if "help" in context:
        lines.extend(_format_help_context(context["help"]))
    else:
        lines.append(f"Help: {_make_clickable_link(Documentation.URL)}")

    return lines
|
| 205 |
+
|
| 206 |
+
|
| 207 |
+
def _create_context_box(parts: list) -> str:
|
| 208 |
+
"""Create a formatted box containing context information."""
|
| 209 |
+
if not parts:
|
| 210 |
+
return ""
|
| 211 |
+
max_width = (
|
| 212 |
+
max(
|
| 213 |
+
_visible_length(context_block_title),
|
| 214 |
+
max(_visible_length(part) for part in parts),
|
| 215 |
+
)
|
| 216 |
+
+ 4
|
| 217 |
+
)
|
| 218 |
+
top_line = "┌" + "─" * max_width + "┐"
|
| 219 |
+
bottom_line = "└" + "─" * max_width + "┘"
|
| 220 |
+
lines = [top_line]
|
| 221 |
+
lines.append(
|
| 222 |
+
f"│ {context_block_title}{' ' * (max_width - _visible_length(context_block_title) - 1)}│"
|
| 223 |
+
)
|
| 224 |
+
lines.append(f"│ {'-' * (max_width - 2)} │")
|
| 225 |
+
for part in parts:
|
| 226 |
+
padding = " " * (max_width - _visible_length(part) - 4)
|
| 227 |
+
lines.append(f"│ - {part}{padding}│")
|
| 228 |
+
lines.append(bottom_line)
|
| 229 |
+
return "\n".join(lines)
|
| 230 |
+
|
| 231 |
+
|
| 232 |
+
def _store_context_attributes(
|
| 233 |
+
error: Exception, context_object: Any, context: dict, original_message: str
|
| 234 |
+
):
|
| 235 |
+
"""Store context information in error attributes."""
|
| 236 |
+
error.__error_context__ = {
|
| 237 |
+
"context_object": context_object,
|
| 238 |
+
"context": context,
|
| 239 |
+
"original_message": original_message,
|
| 240 |
+
}
|
| 241 |
+
try:
|
| 242 |
+
error.original_error = type(error)(original_message)
|
| 243 |
+
except (TypeError, ValueError):
|
| 244 |
+
error.original_error = Exception(original_message)
|
| 245 |
+
error.context_object = context_object
|
| 246 |
+
error.context = context
|
| 247 |
+
|
| 248 |
+
|
| 249 |
+
def _add_context_to_exception(
    original_error: Exception, context_object: Any = None, **context
):
    """Rewrite an exception's message so it is preceded by a context box.

    Context already recorded on the exception is merged with the new
    ``context``; on key clashes the newly supplied values win. The merged
    context is stored on the exception and its ``args`` are replaced with a
    single formatted message.

    Args:
        original_error: The exception to annotate (mutated in place).
        context_object: Object to associate with the error; only used when the
            exception does not already carry a truthy one.
        **context: Extra key/value pairs to show in the context box.
    """
    original_message, existing_object, existing_context = _get_existing_context(
        original_error
    )
    final_context_object = existing_object or context_object

    # Version info comes first; later updates override earlier keys.
    final_context = {"Unitxt": constants.version, "Python": constants.python}
    final_context.update(existing_context)
    final_context.update(context)

    context_parts = _build_context_parts(final_context_object, final_context)
    context_message = _create_context_box(context_parts)
    _store_context_attributes(
        original_error, final_context_object, final_context, original_message
    )

    message = (
        f"\n{context_message}\n\n{original_message}"
        if context_parts
        else original_message
    )
    original_error.args = (message,)
|
| 273 |
+
|
| 274 |
+
|
| 275 |
+
@contextmanager
def error_context(context_object: Any = None, **context):
    """Annotate any exception raised inside the ``with`` body, then re-raise it.

    The exception object itself is re-raised (same type, same traceback); only
    its message and attached context attributes are updated.

    Args:
        context_object: The object being processed (optional).
        **context: Arbitrary key-value pairs that help pinpoint where the
            error occurred; they are included in the error message.

    Special context keys:
        - help: Documentation links to help with the error. Can be a string
          (single URL), a dict (label: URL), or a list of URLs/dicts.

    Examples:
        with error_context(self, operation="validation", item_id=42):
            result = process_item(item)

        with error_context(operation="schema_validation", help="https://docs.example.com/schema"):
            validate_schema(data)

        with error_context(processor, step="preprocessing", batch_size=32):
            results = process_batch(batch)
    """
    try:
        yield
    except Exception as caught:
        _add_context_to_exception(caught, context_object, **context)
        raise
|
evaluate_cli.py
CHANGED
|
@@ -298,9 +298,13 @@ def cli_load_dataset(args: argparse.Namespace) -> HFDataset:
|
|
| 298 |
dataset_query=task_str, **overwrite_args
|
| 299 |
)
|
| 300 |
|
| 301 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 302 |
|
| 303 |
-
test_dataset = _source_to_dataset(
|
| 304 |
logger.info(
|
| 305 |
f"Dataset loaded successfully. Number of instances: {len(test_dataset)}"
|
| 306 |
)
|
|
@@ -414,6 +418,8 @@ def initialize_inference_engine(
|
|
| 414 |
chat_kwargs_dict=chat_kwargs_dict,
|
| 415 |
)
|
| 416 |
|
|
|
|
|
|
|
| 417 |
# --- Remote Model (CrossProviderInferenceEngine) ---
|
| 418 |
elif args.model.lower() == "cross_provider":
|
| 419 |
if "model_name" not in model_args_dict:
|
|
@@ -444,6 +450,9 @@ def initialize_inference_engine(
|
|
| 444 |
model=remote_model_name,
|
| 445 |
**model_args_dict,
|
| 446 |
)
|
|
|
|
|
|
|
|
|
|
| 447 |
else:
|
| 448 |
# This case should not be reached due to argparse choices
|
| 449 |
logger.error(
|
|
@@ -682,7 +691,7 @@ def _save_results_to_disk(
|
|
| 682 |
|
| 683 |
# prepend to the results_path name the time in a way like this: 2025-04-04T11:37:32
|
| 684 |
|
| 685 |
-
timestamp = datetime.now().strftime("%Y-%m-%dT%H
|
| 686 |
|
| 687 |
results_path = prepend_timestamp_to_path(results_path, timestamp)
|
| 688 |
samples_path = prepend_timestamp_to_path(samples_path, timestamp)
|
|
@@ -825,5 +834,49 @@ def main():
|
|
| 825 |
logger.info("Unitxt Evaluation CLI finished successfully.")
|
| 826 |
|
| 827 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 828 |
if __name__ == "__main__":
|
| 829 |
main()
|
|
|
|
| 298 |
dataset_query=task_str, **overwrite_args
|
| 299 |
)
|
| 300 |
|
| 301 |
+
# This hack circumvents an issue with multi-level benchmarks (such as Bluebench's translation subset) that fail when wrapped with an additional Benchmark() object.
|
| 302 |
+
if len(benchmark_subsets) == 1:
|
| 303 |
+
source = next(iter(benchmark_subsets.values()))
|
| 304 |
+
else:
|
| 305 |
+
source = Benchmark(subsets=benchmark_subsets)
|
| 306 |
|
| 307 |
+
test_dataset = _source_to_dataset(source, split=args.split)
|
| 308 |
logger.info(
|
| 309 |
f"Dataset loaded successfully. Number of instances: {len(test_dataset)}"
|
| 310 |
)
|
|
|
|
| 418 |
chat_kwargs_dict=chat_kwargs_dict,
|
| 419 |
)
|
| 420 |
|
| 421 |
+
# Keep the actual model name for the results
|
| 422 |
+
args.model = inference_model.model_name
|
| 423 |
# --- Remote Model (CrossProviderInferenceEngine) ---
|
| 424 |
elif args.model.lower() == "cross_provider":
|
| 425 |
if "model_name" not in model_args_dict:
|
|
|
|
| 450 |
model=remote_model_name,
|
| 451 |
**model_args_dict,
|
| 452 |
)
|
| 453 |
+
|
| 454 |
+
# Keep the actual model name for the results
|
| 455 |
+
args.model = inference_model.engine.model
|
| 456 |
else:
|
| 457 |
# This case should not be reached due to argparse choices
|
| 458 |
logger.error(
|
|
|
|
| 691 |
|
| 692 |
# Prepend the time to the results_path name in a format like this: 2025-04-04T11-37-32
|
| 693 |
|
| 694 |
+
timestamp = datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
|
| 695 |
|
| 696 |
results_path = prepend_timestamp_to_path(results_path, timestamp)
|
| 697 |
samples_path = prepend_timestamp_to_path(samples_path, timestamp)
|
|
|
|
| 834 |
logger.info("Unitxt Evaluation CLI finished successfully.")
|
| 835 |
|
| 836 |
|
| 837 |
+
def extract_scores(directory):  # pragma: no cover
    """Collect scores from every evaluation results file in a directory.

    Files whose name ends with ``evaluation_results.json`` are read in sorted
    name order. Each file contributes one row holding the model name, the UTC
    timestamp, the global score, and one column per entry of ``results`` whose
    value is a dict (filled with that entry's ``"score"``). Missing values are
    reported as ``"N/A"``.

    Files that cannot be read or parsed are logged and skipped.

    Args:
        directory: Path of the directory to scan (not recursive).

    Returns:
        A pandas DataFrame sorted by ``Timestamp`` in ascending order. When no
        results file could be loaded, an empty DataFrame with the
        ``Model``, ``Timestamp`` and ``Average`` columns is returned.
    """
    import pandas as pd

    base_columns = ["Model", "Timestamp", "Average"]
    rows = []

    for filename in sorted(os.listdir(directory)):
        if not filename.endswith("evaluation_results.json"):
            continue
        file_path = os.path.join(directory, filename)
        try:
            with open(file_path, encoding="utf-8") as f:
                content = json.load(f)

            env_info = content.get("environment_info", {})
            results = content.get("results", {})

            row = {
                "Model": env_info.get("parsed_arguments", {}).get("model", "N/A"),
                "Timestamp": env_info.get("timestamp_utc", "N/A"),
                "Average": results.get("score", "N/A"),
            }
            # Dict-valued entries of ``results`` each carry their own score.
            for key, value in results.items():
                if isinstance(value, dict):
                    row[key] = value.get("score", "N/A")

            rows.append(row)
        except Exception as e:
            # Best effort: one unreadable file must not abort the summary.
            logger.error(f"Error parsing results file {filename}: {e}.")

    if not rows:
        # Sorting a column-less frame by "Timestamp" would raise KeyError.
        return pd.DataFrame(columns=base_columns)

    return pd.DataFrame(rows).sort_values(by="Timestamp", ascending=True)
|
| 869 |
+
|
| 870 |
+
|
| 871 |
+
def summarize_cli():
    """Log a markdown table of the scores found in a results directory.

    Expects exactly one command-line argument: the directory to scan. With any
    other argument count, a usage message is logged and the process exits with
    status 1.
    """
    arguments = sys.argv
    if len(arguments) != 2:
        logger.error("Usage: python summarize_cli_results.py <results-directory>")
        sys.exit(1)

    summary = extract_scores(arguments[1])
    logger.info(summary.to_markdown(index=False))
|
| 879 |
+
|
| 880 |
+
|
| 881 |
if __name__ == "__main__":
|
| 882 |
main()
|
formats.py
CHANGED
|
@@ -1,3 +1,4 @@
|
|
|
|
|
| 1 |
import re
|
| 2 |
from abc import abstractmethod
|
| 3 |
from typing import (
|
|
@@ -18,6 +19,7 @@ from .image_operators import image_to_data_url
|
|
| 18 |
from .operator import InstanceOperator
|
| 19 |
from .settings_utils import get_constants
|
| 20 |
from .type_utils import isoftype
|
|
|
|
| 21 |
from .utils import retry_connection_with_exponential_backoff
|
| 22 |
|
| 23 |
constants = get_constants()
|
|
@@ -135,6 +137,9 @@ class BaseFormat(Format):
|
|
| 135 |
def _prepare_instance_fields(self, instance) -> Tuple[str]:
|
| 136 |
instance_fields = {}
|
| 137 |
|
|
|
|
|
|
|
|
|
|
| 138 |
for field in (
|
| 139 |
"source",
|
| 140 |
constants.instruction_field,
|
|
@@ -170,6 +175,7 @@ class BaseFormat(Format):
|
|
| 170 |
target_prefix: str,
|
| 171 |
demos: List[Dict[str, Any]],
|
| 172 |
media: Optional[Dict[str, Any]] = None,
|
|
|
|
| 173 |
) -> str:
|
| 174 |
"""Abstract method for formatting instances in different subclasses.
|
| 175 |
|
|
@@ -256,7 +262,10 @@ class SystemFormat(BaseFormat):
|
|
| 256 |
target_prefix: str,
|
| 257 |
demos: List[Dict[str, Any]],
|
| 258 |
media: Optional[Dict[str, Any]] = None,
|
|
|
|
| 259 |
) -> str:
|
|
|
|
|
|
|
| 260 |
demos_string = ""
|
| 261 |
for demo in demos:
|
| 262 |
demo_str = self.demo_format.format(
|
|
@@ -356,8 +365,18 @@ class ChatAPIFormat(BaseFormat):
|
|
| 356 |
)
|
| 357 |
|
| 358 |
The resulting `messages` is now a dictionary ready for sending to the OpenAI API.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 359 |
"""
|
| 360 |
|
|
|
|
|
|
|
| 361 |
def to_content(self, text: str, media: Dict[str, Any]) -> Union[str, List[Content]]:
|
| 362 |
# Regular expression to find <img> tags with src attribute
|
| 363 |
img_tag_pattern = re.compile(
|
|
@@ -419,12 +438,15 @@ class ChatAPIFormat(BaseFormat):
|
|
| 419 |
target_prefix: str,
|
| 420 |
demos: List[Dict[str, Any]],
|
| 421 |
media: Optional[Dict[str, Any]] = None,
|
|
|
|
| 422 |
) -> List[Message]:
|
| 423 |
messages = []
|
| 424 |
|
| 425 |
-
if system_prompt or instruction:
|
| 426 |
system_content = self.to_content(
|
| 427 |
-
system_prompt
|
|
|
|
|
|
|
| 428 |
media,
|
| 429 |
)
|
| 430 |
messages.append(
|
|
@@ -435,13 +457,22 @@ class ChatAPIFormat(BaseFormat):
|
|
| 435 |
)
|
| 436 |
|
| 437 |
for demo_instance in demos:
|
| 438 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 439 |
assistant_content = self.to_content(
|
| 440 |
-
target_prefix + demo_instance["target"],
|
|
|
|
| 441 |
)
|
| 442 |
messages.extend(
|
| 443 |
[
|
| 444 |
-
{"role": "user", "content": user_content},
|
| 445 |
{
|
| 446 |
"role": "assistant",
|
| 447 |
"content": assistant_content,
|
|
@@ -449,9 +480,15 @@ class ChatAPIFormat(BaseFormat):
|
|
| 449 |
]
|
| 450 |
)
|
| 451 |
|
| 452 |
-
|
|
|
|
|
|
|
| 453 |
|
| 454 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 455 |
|
| 456 |
return messages
|
| 457 |
|
|
@@ -463,6 +500,7 @@ class ChatAPIFormat(BaseFormat):
|
|
| 463 |
target_prefix: str,
|
| 464 |
demos: List[Dict[str, Any]],
|
| 465 |
media: Optional[Dict[str, Any]] = None,
|
|
|
|
| 466 |
) -> Union[str, List[Message]]:
|
| 467 |
chat = self.to_chat(
|
| 468 |
system_prompt,
|
|
@@ -471,6 +509,7 @@ class ChatAPIFormat(BaseFormat):
|
|
| 471 |
target_prefix,
|
| 472 |
demos,
|
| 473 |
media,
|
|
|
|
| 474 |
)
|
| 475 |
media["images"] = []
|
| 476 |
return chat
|
|
@@ -492,6 +531,7 @@ class HFSystemFormat(ChatAPIFormat):
|
|
| 492 |
"""
|
| 493 |
|
| 494 |
model_name: str
|
|
|
|
| 495 |
_requirements_list = ["transformers", "Jinja2"]
|
| 496 |
|
| 497 |
@retry_connection_with_exponential_backoff(backoff_factor=2)
|
|
@@ -509,13 +549,17 @@ class HFSystemFormat(ChatAPIFormat):
|
|
| 509 |
target_prefix: str,
|
| 510 |
demos: List[Dict[str, Any]],
|
| 511 |
media: Optional[Dict[str, Any]] = None,
|
|
|
|
| 512 |
) -> str:
|
| 513 |
chat = self.to_chat(
|
| 514 |
-
system_prompt, instruction, source, target_prefix, demos, media
|
| 515 |
)
|
| 516 |
return (
|
| 517 |
self.tokenizer.apply_chat_template(
|
| 518 |
-
chat,
|
|
|
|
|
|
|
|
|
|
| 519 |
)
|
| 520 |
+ target_prefix
|
| 521 |
)
|
|
|
|
| 1 |
+
import json
|
| 2 |
import re
|
| 3 |
from abc import abstractmethod
|
| 4 |
from typing import (
|
|
|
|
| 19 |
from .operator import InstanceOperator
|
| 20 |
from .settings_utils import get_constants
|
| 21 |
from .type_utils import isoftype
|
| 22 |
+
from .types import Dialog
|
| 23 |
from .utils import retry_connection_with_exponential_backoff
|
| 24 |
|
| 25 |
constants = get_constants()
|
|
|
|
| 137 |
def _prepare_instance_fields(self, instance) -> Tuple[str]:
|
| 138 |
instance_fields = {}
|
| 139 |
|
| 140 |
+
if "__turns__" in instance:
|
| 141 |
+
instance_fields["turns"] = instance["__turns__"]
|
| 142 |
+
|
| 143 |
for field in (
|
| 144 |
"source",
|
| 145 |
constants.instruction_field,
|
|
|
|
| 175 |
target_prefix: str,
|
| 176 |
demos: List[Dict[str, Any]],
|
| 177 |
media: Optional[Dict[str, Any]] = None,
|
| 178 |
+
turns: Optional[Dialog] = None,
|
| 179 |
) -> str:
|
| 180 |
"""Abstract method for formatting instances in different subclasses.
|
| 181 |
|
|
|
|
| 262 |
target_prefix: str,
|
| 263 |
demos: List[Dict[str, Any]],
|
| 264 |
media: Optional[Dict[str, Any]] = None,
|
| 265 |
+
turns: Optional[Dialog] = None,
|
| 266 |
) -> str:
|
| 267 |
+
if turns is not None and not source:
|
| 268 |
+
source = json.dumps(turns)
|
| 269 |
demos_string = ""
|
| 270 |
for demo in demos:
|
| 271 |
demo_str = self.demo_format.format(
|
|
|
|
| 365 |
)
|
| 366 |
|
| 367 |
The resulting `messages` is now a dictionary ready for sending to the OpenAI API.
|
| 368 |
+
|
| 369 |
+
By default, the instruction in the template is placed in a turn with a 'system' role.
|
| 370 |
+
However, some chat tokenizers will not place the default system prompt for the model,
|
| 371 |
+
if there is a turn with an explicit 'system' role. To keep the default system prompt,
|
| 372 |
+
set 'place_instruction_in_user_turns=True'. This will cause the instruction of the template
|
| 373 |
+
to be placed in a turn with a 'user' role. Note the instruction will also be placed
|
| 374 |
+
in every demo turn (if demos are generated.)
|
| 375 |
+
|
| 376 |
"""
|
| 377 |
|
| 378 |
+
place_instruction_in_user_turns: bool = False
|
| 379 |
+
|
| 380 |
def to_content(self, text: str, media: Dict[str, Any]) -> Union[str, List[Content]]:
|
| 381 |
# Regular expression to find <img> tags with src attribute
|
| 382 |
img_tag_pattern = re.compile(
|
|
|
|
| 438 |
target_prefix: str,
|
| 439 |
demos: List[Dict[str, Any]],
|
| 440 |
media: Optional[Dict[str, Any]] = None,
|
| 441 |
+
turns: Optional[Dialog] = None,
|
| 442 |
) -> List[Message]:
|
| 443 |
messages = []
|
| 444 |
|
| 445 |
+
if system_prompt or (instruction and not self.place_instruction_in_user_turns):
|
| 446 |
system_content = self.to_content(
|
| 447 |
+
system_prompt
|
| 448 |
+
+ ("\n" if system_prompt != "" else "")
|
| 449 |
+
+ (instruction if not self.place_instruction_in_user_turns else ""),
|
| 450 |
media,
|
| 451 |
)
|
| 452 |
messages.append(
|
|
|
|
| 457 |
)
|
| 458 |
|
| 459 |
for demo_instance in demos:
|
| 460 |
+
if "__turns__" in demo_instance:
|
| 461 |
+
messages.extend(demo_instance["__turns__"])
|
| 462 |
+
else:
|
| 463 |
+
text = demo_instance["source"]
|
| 464 |
+
|
| 465 |
+
if instruction and self.place_instruction_in_user_turns:
|
| 466 |
+
text = f"{instruction}\n{text}"
|
| 467 |
+
source_content = self.to_content(text, media)
|
| 468 |
+
messages.extend([{"role": "user", "content": source_content}])
|
| 469 |
+
|
| 470 |
assistant_content = self.to_content(
|
| 471 |
+
target_prefix + demo_instance["target"],
|
| 472 |
+
media,
|
| 473 |
)
|
| 474 |
messages.extend(
|
| 475 |
[
|
|
|
|
| 476 |
{
|
| 477 |
"role": "assistant",
|
| 478 |
"content": assistant_content,
|
|
|
|
| 480 |
]
|
| 481 |
)
|
| 482 |
|
| 483 |
+
text = source
|
| 484 |
+
if instruction and self.place_instruction_in_user_turns:
|
| 485 |
+
text = f"{instruction}\n{text}"
|
| 486 |
|
| 487 |
+
if turns is None:
|
| 488 |
+
last_user_content = self.to_content(text, media)
|
| 489 |
+
messages.extend([{"role": "user", "content": last_user_content}])
|
| 490 |
+
else:
|
| 491 |
+
messages.extend(turns)
|
| 492 |
|
| 493 |
return messages
|
| 494 |
|
|
|
|
| 500 |
target_prefix: str,
|
| 501 |
demos: List[Dict[str, Any]],
|
| 502 |
media: Optional[Dict[str, Any]] = None,
|
| 503 |
+
turns: Optional[Dialog] = None,
|
| 504 |
) -> Union[str, List[Message]]:
|
| 505 |
chat = self.to_chat(
|
| 506 |
system_prompt,
|
|
|
|
| 509 |
target_prefix,
|
| 510 |
demos,
|
| 511 |
media,
|
| 512 |
+
turns,
|
| 513 |
)
|
| 514 |
media["images"] = []
|
| 515 |
return chat
|
|
|
|
| 531 |
"""
|
| 532 |
|
| 533 |
model_name: str
|
| 534 |
+
chat_kwargs_dict: Dict[str, str] = {}
|
| 535 |
_requirements_list = ["transformers", "Jinja2"]
|
| 536 |
|
| 537 |
@retry_connection_with_exponential_backoff(backoff_factor=2)
|
|
|
|
| 549 |
target_prefix: str,
|
| 550 |
demos: List[Dict[str, Any]],
|
| 551 |
media: Optional[Dict[str, Any]] = None,
|
| 552 |
+
turns: Optional[Dialog] = None,
|
| 553 |
) -> str:
|
| 554 |
chat = self.to_chat(
|
| 555 |
+
system_prompt, instruction, source, target_prefix, demos, media, turns
|
| 556 |
)
|
| 557 |
return (
|
| 558 |
self.tokenizer.apply_chat_template(
|
| 559 |
+
chat,
|
| 560 |
+
tokenize=False,
|
| 561 |
+
add_generation_prompt=True,
|
| 562 |
+
**self.chat_kwargs_dict,
|
| 563 |
)
|
| 564 |
+ target_prefix
|
| 565 |
)
|
fusion.py
CHANGED
|
@@ -2,6 +2,7 @@ from abc import abstractmethod
|
|
| 2 |
from typing import Dict, Generator, List, Optional, Union
|
| 3 |
|
| 4 |
from .dataclass import NonPositionalField
|
|
|
|
| 5 |
from .logging_utils import get_logger
|
| 6 |
from .operator import SourceOperator
|
| 7 |
from .random_utils import new_random_generator
|
|
@@ -92,7 +93,7 @@ class FixedFusion(BaseFusion):
|
|
| 92 |
max_from_this_split = max_per_this_split
|
| 93 |
|
| 94 |
logger.info(f"Processing {split} from {origin_name}...")
|
| 95 |
-
|
| 96 |
for instance in multi_stream[split]:
|
| 97 |
if (
|
| 98 |
max_from_this_split is not None
|
|
@@ -105,8 +106,6 @@ class FixedFusion(BaseFusion):
|
|
| 105 |
instance["subset"].insert(0, origin_name)
|
| 106 |
emitted_from_this_split += 1
|
| 107 |
yield instance
|
| 108 |
-
except Exception as e:
|
| 109 |
-
raise RuntimeError(f"Exception in subset: {origin_name}") from e
|
| 110 |
|
| 111 |
|
| 112 |
class WeightedFusion(BaseFusion):
|
|
@@ -164,16 +163,15 @@ class WeightedFusion(BaseFusion):
|
|
| 164 |
weights=[self.named_weights[name] for name in population],
|
| 165 |
)[0]
|
| 166 |
iterator = iterators[origin_name]
|
| 167 |
-
|
| 168 |
-
|
| 169 |
-
|
| 170 |
-
if
|
| 171 |
-
|
| 172 |
-
|
| 173 |
-
|
| 174 |
-
|
| 175 |
-
|
| 176 |
-
|
| 177 |
-
|
| 178 |
-
|
| 179 |
-
raise RuntimeError(f"Exception in subset: {origin_name}") from e
|
|
|
|
| 2 |
from typing import Dict, Generator, List, Optional, Union
|
| 3 |
|
| 4 |
from .dataclass import NonPositionalField
|
| 5 |
+
from .error_utils import error_context
|
| 6 |
from .logging_utils import get_logger
|
| 7 |
from .operator import SourceOperator
|
| 8 |
from .random_utils import new_random_generator
|
|
|
|
| 93 |
max_from_this_split = max_per_this_split
|
| 94 |
|
| 95 |
logger.info(f"Processing {split} from {origin_name}...")
|
| 96 |
+
with error_context(self, subset=origin_name):
|
| 97 |
for instance in multi_stream[split]:
|
| 98 |
if (
|
| 99 |
max_from_this_split is not None
|
|
|
|
| 106 |
instance["subset"].insert(0, origin_name)
|
| 107 |
emitted_from_this_split += 1
|
| 108 |
yield instance
|
|
|
|
|
|
|
| 109 |
|
| 110 |
|
| 111 |
class WeightedFusion(BaseFusion):
|
|
|
|
| 163 |
weights=[self.named_weights[name] for name in population],
|
| 164 |
)[0]
|
| 165 |
iterator = iterators[origin_name]
|
| 166 |
+
with error_context(self, subset=origin_name):
|
| 167 |
+
try:
|
| 168 |
+
instance = next(iterator)
|
| 169 |
+
if isinstance(origin_name, str):
|
| 170 |
+
if "subset" not in instance:
|
| 171 |
+
instance["subset"] = []
|
| 172 |
+
instance["subset"].insert(0, origin_name)
|
| 173 |
+
total_examples += 1
|
| 174 |
+
yield instance
|
| 175 |
+
|
| 176 |
+
except StopIteration:
|
| 177 |
+
iterators.pop(origin_name)
|
|
|
inference.py
CHANGED
|
@@ -39,7 +39,7 @@ from .artifact import Artifact
|
|
| 39 |
from .base_metric import Metric
|
| 40 |
from .dataclass import InternalField, NonPositionalField
|
| 41 |
from .deprecation_utils import deprecation
|
| 42 |
-
from .error_utils import UnitxtError, UnitxtWarning
|
| 43 |
from .image_operators import (
|
| 44 |
EncodeImageToString,
|
| 45 |
ImageDataString,
|
|
@@ -121,6 +121,8 @@ class TextGenerationInferenceOutput:
|
|
| 121 |
| For example: ``[ {.. "top_tokens": [ {"text": "a", 'logprob': }, {"text": "b", 'logprob': } ....]},
|
| 122 |
{.. "top_tokens": [ {"text": "c", 'logprob': }, {"text": "d", 'logprob': } ....]} ]``
|
| 123 |
|
|
|
|
|
|
|
| 124 |
input_tokens (int) : number of input tokens to the model.
|
| 125 |
|
| 126 |
output_tokens (int) : number of output tokens to the model.
|
|
@@ -137,6 +139,7 @@ class TextGenerationInferenceOutput:
|
|
| 137 |
"""
|
| 138 |
|
| 139 |
prediction: Union[str, List[Dict[str, Any]]]
|
|
|
|
| 140 |
input_tokens: Optional[int] = None
|
| 141 |
output_tokens: Optional[int] = None
|
| 142 |
stop_reason: Optional[str] = None
|
|
@@ -186,12 +189,19 @@ class InferenceEngine(Artifact):
|
|
| 186 |
def prepare(self):
|
| 187 |
if not settings.mock_inference_mode:
|
| 188 |
super().prepare() # no need to prepare a mock
|
| 189 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 190 |
if self.use_cache:
|
| 191 |
from diskcache import Cache
|
| 192 |
|
| 193 |
self._cache = Cache(
|
| 194 |
-
|
|
|
|
|
|
|
| 195 |
)
|
| 196 |
|
| 197 |
def __call__(
|
|
@@ -199,7 +209,12 @@ class InferenceEngine(Artifact):
|
|
| 199 |
dataset: Union[List[Dict[str, Any]], Dataset],
|
| 200 |
return_meta_data: bool = False,
|
| 201 |
) -> Union[ListWithMetadata[str], ListWithMetadata[TextGenerationInferenceOutput]]:
|
| 202 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 203 |
|
| 204 |
def get_instance_cache_key(self, instance):
|
| 205 |
instance_key_fields = ["media", "source", "task_data"]
|
|
@@ -243,54 +258,69 @@ class InferenceEngine(Artifact):
|
|
| 243 |
result = self._mock_infer(dataset)
|
| 244 |
else:
|
| 245 |
if self.use_cache:
|
| 246 |
-
|
| 247 |
-
|
| 248 |
-
|
| 249 |
-
|
| 250 |
):
|
| 251 |
-
|
| 252 |
-
|
| 253 |
-
for
|
| 254 |
-
|
| 255 |
-
|
| 256 |
-
|
| 257 |
-
|
| 258 |
-
|
| 259 |
-
) # each element is index in batch, and value
|
| 260 |
-
else:
|
| 261 |
-
missing_examples.append(
|
| 262 |
-
(i, item)
|
| 263 |
-
) # each element is index in batch and example
|
| 264 |
-
# infare on missing examples only, without indices
|
| 265 |
-
|
| 266 |
-
logger.info(
|
| 267 |
-
f"Inferring batch {batch_index + 1} / {number_of_batches} with {len(missing_examples)} instances (found {len(cached_results)} instances in {self._cache.directory})"
|
| 268 |
-
)
|
| 269 |
-
if len(missing_examples) > 0:
|
| 270 |
-
inferred_results = self._infer(
|
| 271 |
-
[e[1] for e in missing_examples], return_meta_data
|
| 272 |
-
)
|
| 273 |
-
# recombined to index and value
|
| 274 |
-
inferred_results = list(
|
| 275 |
-
zip([e[0] for e in missing_examples], inferred_results)
|
| 276 |
-
)
|
| 277 |
-
# Add missing examples to cache
|
| 278 |
-
for (_, item), (_, prediction) in zip(
|
| 279 |
-
missing_examples, inferred_results
|
| 280 |
-
):
|
| 281 |
-
if prediction is None:
|
| 282 |
-
continue
|
| 283 |
cache_key = self._get_cache_key(item)
|
| 284 |
-
self._cache
|
| 285 |
-
|
| 286 |
-
|
| 287 |
-
|
| 288 |
-
|
| 289 |
-
|
| 290 |
-
|
| 291 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 292 |
else:
|
| 293 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 294 |
return ListWithMetadata(
|
| 295 |
result,
|
| 296 |
metadata={
|
|
@@ -339,7 +369,16 @@ class InferenceEngine(Artifact):
|
|
| 339 |
|
| 340 |
def to_messages(self, instance):
|
| 341 |
if isinstance(instance["source"], list):
|
| 342 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 343 |
return [
|
| 344 |
{
|
| 345 |
"role": "user",
|
|
@@ -521,13 +560,6 @@ class HFInferenceEngineBase(
|
|
| 521 |
def get_engine_id(self):
|
| 522 |
return get_model_and_label_id(self.model_name, self.label)
|
| 523 |
|
| 524 |
-
def decode_tokens(self, tokens: Sequence, inp_length: int) -> List[str]:
|
| 525 |
-
return self.processor.decode(tokens[inp_length:], skip_special_tokens=True)
|
| 526 |
-
|
| 527 |
-
@staticmethod
|
| 528 |
-
def create_string_from_tokens(string_tokens: List[str]) -> str:
|
| 529 |
-
return "".join(token for token in string_tokens)
|
| 530 |
-
|
| 531 |
def make_predictions(self, prepared_inputs: Mapping) -> Mapping:
|
| 532 |
return self.model.generate(
|
| 533 |
**prepared_inputs,
|
|
@@ -598,6 +630,7 @@ class HFInferenceEngineBase(
|
|
| 598 |
def get_return_object(
|
| 599 |
self,
|
| 600 |
output: Union[str, List[Dict[str, Any]]],
|
|
|
|
| 601 |
output_tokens: Optional[int],
|
| 602 |
inp: Optional[str],
|
| 603 |
inp_tokens: Optional[int],
|
|
@@ -606,6 +639,7 @@ class HFInferenceEngineBase(
|
|
| 606 |
if return_meta_data:
|
| 607 |
return TextGenerationInferenceOutput(
|
| 608 |
prediction=output,
|
|
|
|
| 609 |
output_tokens=output_tokens if output_tokens is not None else None,
|
| 610 |
input_text=inp,
|
| 611 |
input_tokens=inp_tokens if inp_tokens is not None else None,
|
|
@@ -689,7 +723,8 @@ class HFAutoModelInferenceEngine(HFInferenceEngineBase):
|
|
| 689 |
# cause an error because the data is always on the gpu
|
| 690 |
# if torch.cuda.device_count() > 1:
|
| 691 |
# assert self.device == torch.device(0)
|
| 692 |
-
|
|
|
|
| 693 |
# else:
|
| 694 |
# if not self.load_in_8bit:
|
| 695 |
# args["device"] = self.device
|
|
@@ -717,15 +752,21 @@ class HFAutoModelInferenceEngine(HFInferenceEngineBase):
|
|
| 717 |
**model_args,
|
| 718 |
)
|
| 719 |
|
| 720 |
-
def prepare_inputs(self, data: Iterable) -> Mapping:
|
| 721 |
tokenizer_kargs = {}
|
| 722 |
if isinstance(data[0], list):
|
| 723 |
-
|
| 724 |
-
|
| 725 |
-
|
| 726 |
-
|
| 727 |
-
|
| 728 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 729 |
tokenizer_kargs["add_special_tokens"] = False
|
| 730 |
|
| 731 |
if self.processor.pad_token is None:
|
|
@@ -766,59 +807,71 @@ class HFAutoModelInferenceEngine(HFInferenceEngineBase):
|
|
| 766 |
total=len(dataset) // self.batch_size,
|
| 767 |
):
|
| 768 |
# Get the current batch
|
| 769 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 770 |
|
| 771 |
-
|
| 772 |
-
# 1. Tokenize inputs for the batch
|
| 773 |
-
tokenized_inputs = self.prepare_inputs(batch_sources)
|
| 774 |
|
| 775 |
-
#
|
| 776 |
input_length = (
|
| 777 |
1
|
| 778 |
if self.model.config.is_encoder_decoder
|
| 779 |
else tokenized_inputs.input_ids.shape[1]
|
| 780 |
)
|
| 781 |
|
| 782 |
-
#
|
| 783 |
predictions = self.make_predictions(tokenized_inputs)
|
| 784 |
sequences = predictions.sequences # Sequences for the current batch
|
| 785 |
|
| 786 |
-
|
| 787 |
-
string_tokens_batch = [
|
| 788 |
-
self.decode_tokens(sequence, input_length) for sequence in sequences
|
| 789 |
-
]
|
| 790 |
|
| 791 |
-
|
| 792 |
-
|
| 793 |
-
|
| 794 |
-
|
| 795 |
-
|
| 796 |
-
|
| 797 |
-
|
| 798 |
-
|
| 799 |
-
)
|
| 800 |
|
| 801 |
-
|
| 802 |
-
|
| 803 |
-
|
| 804 |
-
|
| 805 |
-
|
| 806 |
-
|
| 807 |
-
|
| 808 |
-
|
| 809 |
-
|
| 810 |
-
|
| 811 |
-
|
| 812 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 813 |
)
|
| 814 |
-
for j in range(
|
| 815 |
-
len(sequences)
|
| 816 |
-
) # Iterate through items in the current batch
|
| 817 |
-
]
|
| 818 |
|
| 819 |
-
# Add results from this batch to the overall list
|
| 820 |
all_final_outputs.extend(batch_results)
|
| 821 |
-
# --- End of batch processing ---
|
| 822 |
|
| 823 |
return all_final_outputs
|
| 824 |
|
|
@@ -847,7 +900,10 @@ class HFLlavaInferenceEngine(HFInferenceEngineBase):
|
|
| 847 |
self, sequences: Sequence, scores: Sequence, beam_indices: Optional[int]
|
| 848 |
) -> Sequence:
|
| 849 |
if not hasattr(self.model.config, "vocab_size"):
|
| 850 |
-
|
|
|
|
|
|
|
|
|
|
| 851 |
|
| 852 |
return super().compute_transition_scores(sequences, scores, beam_indices)
|
| 853 |
|
|
@@ -917,18 +973,35 @@ class HFLlavaInferenceEngine(HFInferenceEngineBase):
|
|
| 917 |
|
| 918 |
predictions = self.make_predictions(processed_inputs)
|
| 919 |
|
| 920 |
-
|
| 921 |
|
| 922 |
-
|
| 923 |
-
|
| 924 |
-
|
| 925 |
-
|
| 926 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 927 |
|
| 928 |
results.append(
|
| 929 |
self.get_return_object(
|
| 930 |
-
output=final_outputs,
|
| 931 |
-
|
|
|
|
| 932 |
inp=instance["source"],
|
| 933 |
inp_tokens=None,
|
| 934 |
return_meta_data=return_meta_data,
|
|
@@ -1189,6 +1262,7 @@ class HFPipelineBasedInferenceEngine(
|
|
| 1189 |
if return_meta_data:
|
| 1190 |
return TextGenerationInferenceOutput(
|
| 1191 |
prediction=output["generated_text"],
|
|
|
|
| 1192 |
model_name=self.model_name,
|
| 1193 |
inference_type=self.label,
|
| 1194 |
input_text=inp,
|
|
@@ -1252,10 +1326,13 @@ class MockInferenceEngine(InferenceEngine, LogProbInferenceEngine):
|
|
| 1252 |
for instance in dataset
|
| 1253 |
]
|
| 1254 |
|
| 1255 |
-
def get_return_object(
|
|
|
|
|
|
|
| 1256 |
if return_meta_data:
|
| 1257 |
return TextGenerationInferenceOutput(
|
| 1258 |
prediction=predict_result,
|
|
|
|
| 1259 |
input_tokens=len(instance["source"]),
|
| 1260 |
output_tokens=len(predict_result),
|
| 1261 |
model_name=self.model_name,
|
|
@@ -1369,21 +1446,25 @@ class OllamaInferenceEngine(
|
|
| 1369 |
return get_model_and_label_id(self.model, self.label)
|
| 1370 |
|
| 1371 |
def prepare_engine(self):
|
| 1372 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1373 |
|
| 1374 |
def _infer(
|
| 1375 |
self,
|
| 1376 |
dataset: Union[List[Dict[str, Any]], Dataset],
|
| 1377 |
return_meta_data: bool = False,
|
| 1378 |
) -> Union[List[str], List[TextGenerationInferenceOutput]]:
|
| 1379 |
-
import ollama
|
| 1380 |
-
|
| 1381 |
args = self.to_dict([StandardAPIParamsMixin])
|
| 1382 |
results = []
|
| 1383 |
model = args.pop("model")
|
| 1384 |
for instance in dataset:
|
| 1385 |
messages = self.to_messages(instance)
|
| 1386 |
-
response =
|
| 1387 |
messages=messages,
|
| 1388 |
model=model,
|
| 1389 |
options=args,
|
|
@@ -1877,7 +1958,7 @@ class OpenAiInferenceEngine(
|
|
| 1877 |
f"Error predicting instance {messages}:{e}. Returning empty prediction"
|
| 1878 |
)
|
| 1879 |
return TextGenerationInferenceOutput(
|
| 1880 |
-
prediction="-", input_tokens=0, output_tokens=0
|
| 1881 |
)
|
| 1882 |
|
| 1883 |
@run_with_imap
|
|
@@ -1894,10 +1975,12 @@ class OpenAiInferenceEngine(
|
|
| 1894 |
top_logprobs_response = response.choices[0].logprobs.content
|
| 1895 |
pred_output = [
|
| 1896 |
{
|
|
|
|
|
|
|
| 1897 |
"top_tokens": [
|
| 1898 |
{"text": obj.token, "logprob": obj.logprob}
|
| 1899 |
for obj in generated_token.top_logprobs
|
| 1900 |
-
]
|
| 1901 |
}
|
| 1902 |
for generated_token in top_logprobs_response
|
| 1903 |
]
|
|
@@ -1907,15 +1990,21 @@ class OpenAiInferenceEngine(
|
|
| 1907 |
logging.error(
|
| 1908 |
f"Error predicting instance {messages}:{e}. Returning empty prediction"
|
| 1909 |
)
|
| 1910 |
-
prediction = [
|
|
|
|
|
|
|
| 1911 |
return TextGenerationInferenceOutput(
|
| 1912 |
-
prediction=prediction,
|
|
|
|
|
|
|
|
|
|
| 1913 |
)
|
| 1914 |
|
| 1915 |
def get_return_object(self, predict_result, response, return_meta_data):
|
| 1916 |
if return_meta_data:
|
| 1917 |
return TextGenerationInferenceOutput(
|
| 1918 |
prediction=predict_result,
|
|
|
|
| 1919 |
input_tokens=response.usage.prompt_tokens,
|
| 1920 |
output_tokens=response.usage.completion_tokens,
|
| 1921 |
model_name=self.model_name,
|
|
@@ -1973,7 +2062,12 @@ class RITSInferenceEngine(
|
|
| 1973 |
label: str = "rits"
|
| 1974 |
data_classification_policy = ["public", "proprietary"]
|
| 1975 |
|
| 1976 |
-
model_names_dict = {
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1977 |
|
| 1978 |
def get_default_headers(self):
|
| 1979 |
return {"RITS_API_KEY": self.credentials["api_key"]}
|
|
@@ -2606,6 +2700,7 @@ class WMLInferenceEngineGeneration(WMLInferenceEngineBase, WMLGenerationParamsMi
|
|
| 2606 |
if return_meta_data:
|
| 2607 |
return TextGenerationInferenceOutput(
|
| 2608 |
prediction=predict_result,
|
|
|
|
| 2609 |
input_tokens=result["input_token_count"],
|
| 2610 |
output_tokens=result["generated_token_count"],
|
| 2611 |
model_name=self.model_name or self.deployment_id,
|
|
@@ -2865,6 +2960,8 @@ class WMLInferenceEngineChat(WMLInferenceEngineBase, WMLChatParamsMixin):
|
|
| 2865 |
tool_call = data[idx]["tools"]["tools"] is not None
|
| 2866 |
|
| 2867 |
output = response["choices"][0][output_type]
|
|
|
|
|
|
|
| 2868 |
if tool_call:
|
| 2869 |
if "tool_calls" in output:
|
| 2870 |
func = output["tool_calls"][0]["function"]
|
|
@@ -2877,6 +2974,7 @@ class WMLInferenceEngineChat(WMLInferenceEngineBase, WMLChatParamsMixin):
|
|
| 2877 |
results.append(
|
| 2878 |
self.get_return_object(
|
| 2879 |
prediction,
|
|
|
|
| 2880 |
response,
|
| 2881 |
str(inp),
|
| 2882 |
return_meta_data,
|
|
@@ -2885,10 +2983,13 @@ class WMLInferenceEngineChat(WMLInferenceEngineBase, WMLChatParamsMixin):
|
|
| 2885 |
|
| 2886 |
return results
|
| 2887 |
|
| 2888 |
-
def get_return_object(
|
|
|
|
|
|
|
| 2889 |
if return_meta_data:
|
| 2890 |
return TextGenerationInferenceOutput(
|
| 2891 |
prediction=predict_result,
|
|
|
|
| 2892 |
input_tokens=result["usage"]["prompt_tokens"],
|
| 2893 |
output_tokens=len(predict_result)
|
| 2894 |
if isinstance(predict_result, list)
|
|
@@ -3286,6 +3387,7 @@ class LiteLLMInferenceEngine(
|
|
| 3286 |
prediction = response["choices"][0]["message"]["content"] or ""
|
| 3287 |
return TextGenerationInferenceOutput(
|
| 3288 |
prediction=prediction,
|
|
|
|
| 3289 |
input_tokens=usage.get("prompt_tokens"),
|
| 3290 |
output_tokens=usage.get("completion_tokens"),
|
| 3291 |
model_name=response.get("model", self.model),
|
|
@@ -3436,21 +3538,22 @@ class CrossProviderInferenceEngine(InferenceEngine, StandardAPIParamsMixin):
|
|
| 3436 |
},
|
| 3437 |
"rits": {
|
| 3438 |
"granite-3-8b-instruct": "ibm-granite/granite-3.0-8b-instruct",
|
|
|
|
| 3439 |
"granite-3-2-8b-instruct": "ibm-granite/granite-3.2-8b-instruct",
|
| 3440 |
"granite-3-3-8b-instruct": "ibm-granite/granite-3.3-8b-instruct",
|
| 3441 |
-
"llama-3-1-8b-instruct": "meta-llama/
|
| 3442 |
"llama-3-1-70b-instruct": "meta-llama/llama-3-1-70b-instruct",
|
| 3443 |
"llama-3-1-405b-instruct": "meta-llama/llama-3-1-405b-instruct-fp8",
|
| 3444 |
"llama-3-1-405b-instruct-fp8": "meta-llama/llama-3-1-405b-instruct-fp8",
|
| 3445 |
"llama-3-2-11b-vision-instruct": "meta-llama/Llama-3.2-11B-Vision-Instruct",
|
| 3446 |
"llama-3-2-90b-vision-instruct": "meta-llama/Llama-3.2-90B-Vision-Instruct",
|
| 3447 |
"llama-3-3-70b-instruct": "meta-llama/llama-3-3-70b-instruct",
|
| 3448 |
-
"llama-4-scout": "llama-4-scout-17b-16e",
|
| 3449 |
-
"llama-4-maverick": "llama-4-
|
| 3450 |
"mistral-large-instruct": "mistralai/mistral-large-instruct-2407",
|
| 3451 |
"mixtral-8x7b-instruct": "mistralai/mixtral-8x7B-instruct-v0.1",
|
| 3452 |
"mixtral-8x7b-instruct-v01": "mistralai/mixtral-8x7B-instruct-v0.1",
|
| 3453 |
-
"deepseek-v3": "deepseek-ai/
|
| 3454 |
"granite-guardian-3-2-3b-a800m": "ibm-granite/granite-guardian-3.2-3b-a800m",
|
| 3455 |
"granite-guardian-3-2-5b": "ibm-granite/granite-guardian-3.2-5b",
|
| 3456 |
},
|
|
|
|
| 39 |
from .base_metric import Metric
|
| 40 |
from .dataclass import InternalField, NonPositionalField
|
| 41 |
from .deprecation_utils import deprecation
|
| 42 |
+
from .error_utils import UnitxtError, UnitxtWarning, error_context
|
| 43 |
from .image_operators import (
|
| 44 |
EncodeImageToString,
|
| 45 |
ImageDataString,
|
|
|
|
| 121 |
| For example: ``[ {.. "top_tokens": [ {"text": "a", 'logprob': }, {"text": "b", 'logprob': } ....]},
|
| 122 |
{.. "top_tokens": [ {"text": "c", 'logprob': }, {"text": "d", 'logprob': } ....]} ]``
|
| 123 |
|
| 124 |
+
generated_text (str): The text generated by the model (in both _infer and _infer_log_probs calls).
|
| 125 |
+
|
| 126 |
input_tokens (int) : number of input tokens to the model.
|
| 127 |
|
| 128 |
output_tokens (int) : number of output tokens to the model.
|
|
|
|
| 139 |
"""
|
| 140 |
|
| 141 |
prediction: Union[str, List[Dict[str, Any]]]
|
| 142 |
+
generated_text: str
|
| 143 |
input_tokens: Optional[int] = None
|
| 144 |
output_tokens: Optional[int] = None
|
| 145 |
stop_reason: Optional[str] = None
|
|
|
|
| 189 |
def prepare(self):
|
| 190 |
if not settings.mock_inference_mode:
|
| 191 |
super().prepare() # no need to prepare a mock
|
| 192 |
+
with error_context(
|
| 193 |
+
self,
|
| 194 |
+
stage="Prepare Inference Engine",
|
| 195 |
+
help="https://www.unitxt.ai/en/latest/docs/inference.html",
|
| 196 |
+
):
|
| 197 |
+
self.prepare_engine()
|
| 198 |
if self.use_cache:
|
| 199 |
from diskcache import Cache
|
| 200 |
|
| 201 |
self._cache = Cache(
|
| 202 |
+
os.path.join(
|
| 203 |
+
settings.inference_engine_cache_path, self.__class__.__name__
|
| 204 |
+
)
|
| 205 |
)
|
| 206 |
|
| 207 |
def __call__(
|
|
|
|
| 209 |
dataset: Union[List[Dict[str, Any]], Dataset],
|
| 210 |
return_meta_data: bool = False,
|
| 211 |
) -> Union[ListWithMetadata[str], ListWithMetadata[TextGenerationInferenceOutput]]:
|
| 212 |
+
with error_context(
|
| 213 |
+
self,
|
| 214 |
+
stage="Running Inference",
|
| 215 |
+
help="https://www.unitxt.ai/en/latest/docs/inference.html",
|
| 216 |
+
):
|
| 217 |
+
return self.infer(dataset=dataset, return_meta_data=return_meta_data)
|
| 218 |
|
| 219 |
def get_instance_cache_key(self, instance):
|
| 220 |
instance_key_fields = ["media", "source", "task_data"]
|
|
|
|
| 258 |
result = self._mock_infer(dataset)
|
| 259 |
else:
|
| 260 |
if self.use_cache:
|
| 261 |
+
with error_context(
|
| 262 |
+
self,
|
| 263 |
+
stage="Inference Cache Handling",
|
| 264 |
+
help="https://www.unitxt.ai/en/latest/docs/inference.html",
|
| 265 |
):
|
| 266 |
+
number_of_batches = math.ceil(len(dataset) / self.cache_batch_size)
|
| 267 |
+
result = []
|
| 268 |
+
for batch_index, batch in enumerate(
|
| 269 |
+
batched(dataset, self.cache_batch_size)
|
| 270 |
+
):
|
| 271 |
+
cached_results = []
|
| 272 |
+
missing_examples = []
|
| 273 |
+
for i, item in enumerate(batch):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 274 |
cache_key = self._get_cache_key(item)
|
| 275 |
+
cached_value = self._cache.get(cache_key)
|
| 276 |
+
if cached_value is not None:
|
| 277 |
+
cached_results.append(
|
| 278 |
+
(i, cached_value)
|
| 279 |
+
) # each element is index in batch, and value
|
| 280 |
+
else:
|
| 281 |
+
missing_examples.append(
|
| 282 |
+
(i, item)
|
| 283 |
+
) # each element is index in batch and example
|
| 284 |
+
# infer on missing examples only, without indices
|
| 285 |
+
|
| 286 |
+
logger.info(
|
| 287 |
+
f"Inferring batch {batch_index + 1} / {number_of_batches} with {len(missing_examples)} instances (found {len(cached_results)} instances in {self._cache.directory})"
|
| 288 |
+
)
|
| 289 |
+
if len(missing_examples) > 0:
|
| 290 |
+
with error_context(
|
| 291 |
+
self,
|
| 292 |
+
stage="Running Inference",
|
| 293 |
+
help="https://www.unitxt.ai/en/latest/docs/inference.html",
|
| 294 |
+
):
|
| 295 |
+
inferred_results = self._infer(
|
| 296 |
+
[e[1] for e in missing_examples], return_meta_data
|
| 297 |
+
)
|
| 298 |
+
# recombine into (index, value) pairs
|
| 299 |
+
inferred_results = list(
|
| 300 |
+
zip([e[0] for e in missing_examples], inferred_results)
|
| 301 |
+
)
|
| 302 |
+
# Add missing examples to cache
|
| 303 |
+
for (_, item), (_, prediction) in zip(
|
| 304 |
+
missing_examples, inferred_results
|
| 305 |
+
):
|
| 306 |
+
if prediction is None:
|
| 307 |
+
continue
|
| 308 |
+
cache_key = self._get_cache_key(item)
|
| 309 |
+
self._cache[cache_key] = prediction
|
| 310 |
+
else:
|
| 311 |
+
inferred_results = []
|
| 312 |
+
# Combine cached and inferred results in original order
|
| 313 |
+
batch_predictions = [
|
| 314 |
+
p[1] for p in sorted(cached_results + inferred_results)
|
| 315 |
+
]
|
| 316 |
+
result.extend(batch_predictions)
|
| 317 |
else:
|
| 318 |
+
with error_context(
|
| 319 |
+
self,
|
| 320 |
+
stage="Running Inference",
|
| 321 |
+
help="https://www.unitxt.ai/en/latest/docs/inference.html",
|
| 322 |
+
):
|
| 323 |
+
result = self._infer(dataset, return_meta_data)
|
| 324 |
return ListWithMetadata(
|
| 325 |
result,
|
| 326 |
metadata={
|
|
|
|
| 369 |
|
| 370 |
def to_messages(self, instance):
|
| 371 |
if isinstance(instance["source"], list):
|
| 372 |
+
messages = []
|
| 373 |
+
for message in instance["source"]:
|
| 374 |
+
if "tool_calls" in message:
|
| 375 |
+
for tool_call in message["tool_calls"]:
|
| 376 |
+
if not isinstance(tool_call["function"]["arguments"], str):
|
| 377 |
+
tool_call["function"]["arguments"] = json.dumps(
|
| 378 |
+
tool_call["function"]["arguments"]
|
| 379 |
+
)
|
| 380 |
+
messages.append(message)
|
| 381 |
+
return messages
|
| 382 |
return [
|
| 383 |
{
|
| 384 |
"role": "user",
|
|
|
|
| 560 |
def get_engine_id(self):
|
| 561 |
return get_model_and_label_id(self.model_name, self.label)
|
| 562 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 563 |
def make_predictions(self, prepared_inputs: Mapping) -> Mapping:
|
| 564 |
return self.model.generate(
|
| 565 |
**prepared_inputs,
|
|
|
|
| 630 |
def get_return_object(
|
| 631 |
self,
|
| 632 |
output: Union[str, List[Dict[str, Any]]],
|
| 633 |
+
generated_text: str,
|
| 634 |
output_tokens: Optional[int],
|
| 635 |
inp: Optional[str],
|
| 636 |
inp_tokens: Optional[int],
|
|
|
|
| 639 |
if return_meta_data:
|
| 640 |
return TextGenerationInferenceOutput(
|
| 641 |
prediction=output,
|
| 642 |
+
generated_text=generated_text,
|
| 643 |
output_tokens=output_tokens if output_tokens is not None else None,
|
| 644 |
input_text=inp,
|
| 645 |
input_tokens=inp_tokens if inp_tokens is not None else None,
|
|
|
|
| 723 |
# cause an error because the data is always on the gpu
|
| 724 |
# if torch.cuda.device_count() > 1:
|
| 725 |
# assert self.device == torch.device(0)
|
| 726 |
+
if self.device_map is None:
|
| 727 |
+
args["device_map"] = "auto"
|
| 728 |
# else:
|
| 729 |
# if not self.load_in_8bit:
|
| 730 |
# args["device"] = self.device
|
|
|
|
| 752 |
**model_args,
|
| 753 |
)
|
| 754 |
|
| 755 |
+
def prepare_inputs(self, data: Iterable, tools: Iterable) -> Mapping:
|
| 756 |
tokenizer_kargs = {}
|
| 757 |
if isinstance(data[0], list):
|
| 758 |
+
processed = []
|
| 759 |
+
for item, item_tools in zip(data, tools):
|
| 760 |
+
processed.append(
|
| 761 |
+
self.processor.apply_chat_template(
|
| 762 |
+
item,
|
| 763 |
+
tokenize=False,
|
| 764 |
+
tools=item_tools,
|
| 765 |
+
add_generation_prompt=True,
|
| 766 |
+
**self.chat_kwargs_dict,
|
| 767 |
+
)
|
| 768 |
+
)
|
| 769 |
+
data = processed
|
| 770 |
tokenizer_kargs["add_special_tokens"] = False
|
| 771 |
|
| 772 |
if self.processor.pad_token is None:
|
|
|
|
| 807 |
total=len(dataset) // self.batch_size,
|
| 808 |
):
|
| 809 |
# Get the current batch
|
| 810 |
+
sources = []
|
| 811 |
+
tools = []
|
| 812 |
+
for instance in batch:
|
| 813 |
+
sources.append(instance["source"])
|
| 814 |
+
if "task_data" in instance and "__tools__" in instance["task_data"]:
|
| 815 |
+
task_data = instance["task_data"]
|
| 816 |
+
if isinstance(task_data, str):
|
| 817 |
+
task_data = json.loads(task_data)
|
| 818 |
+
tools.append(task_data["__tools__"])
|
| 819 |
+
else:
|
| 820 |
+
tools.append(None)
|
| 821 |
+
# Tokenize inputs for the batch
|
| 822 |
|
| 823 |
+
tokenized_inputs = self.prepare_inputs(sources, tools)
|
|
|
|
|
|
|
| 824 |
|
| 825 |
+
# Determine input length (handle encoder-decoder models)
|
| 826 |
input_length = (
|
| 827 |
1
|
| 828 |
if self.model.config.is_encoder_decoder
|
| 829 |
else tokenized_inputs.input_ids.shape[1]
|
| 830 |
)
|
| 831 |
|
| 832 |
+
# Make predictions for the batch
|
| 833 |
predictions = self.make_predictions(tokenized_inputs)
|
| 834 |
sequences = predictions.sequences # Sequences for the current batch
|
| 835 |
|
| 836 |
+
output_tokens = sequences[:, input_length:]
|
|
|
|
|
|
|
|
|
|
| 837 |
|
| 838 |
+
output_tokens_strings = []
|
| 839 |
+
for tokens in output_tokens:
|
| 840 |
+
output_tokens_strings.append(
|
| 841 |
+
[
|
| 842 |
+
self.processor.decode(token, skip_special_tokens=True)
|
| 843 |
+
for token in tokens
|
| 844 |
+
]
|
| 845 |
+
)
|
|
|
|
| 846 |
|
| 847 |
+
output_strings = []
|
| 848 |
+
for tokens in output_tokens:
|
| 849 |
+
output_strings.append(
|
| 850 |
+
self.processor.decode(tokens, skip_special_tokens=True)
|
| 851 |
+
)
|
| 852 |
+
|
| 853 |
+
if return_logprobs:
|
| 854 |
+
outputs = self.get_logprobs(predictions, output_tokens_strings)
|
| 855 |
+
else:
|
| 856 |
+
outputs = output_strings
|
| 857 |
+
|
| 858 |
+
# Create return objects for the batch
|
| 859 |
+
batch_results = []
|
| 860 |
+
for i in range(len(sequences)):
|
| 861 |
+
batch_results.append(
|
| 862 |
+
self.get_return_object(
|
| 863 |
+
output=outputs[i],
|
| 864 |
+
generated_text=output_strings[i],
|
| 865 |
+
output_tokens=len(output_tokens_strings[i]),
|
| 866 |
+
inp=sources[i],
|
| 867 |
+
inp_tokens=len(tokenized_inputs.encodings[i].tokens)
|
| 868 |
+
if tokenized_inputs.encodings is not None
|
| 869 |
+
else None,
|
| 870 |
+
return_meta_data=return_meta_data,
|
| 871 |
+
)
|
| 872 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 873 |
|
|
|
|
| 874 |
all_final_outputs.extend(batch_results)
|
|
|
|
| 875 |
|
| 876 |
return all_final_outputs
|
| 877 |
|
|
|
|
| 900 |
self, sequences: Sequence, scores: Sequence, beam_indices: Optional[int]
|
| 901 |
) -> Sequence:
|
| 902 |
if not hasattr(self.model.config, "vocab_size"):
|
| 903 |
+
try:
|
| 904 |
+
self.model.config.vocab_size = self.model.vocab_size
|
| 905 |
+
except:
|
| 906 |
+
self.model.config.vocab_size = self.model.config.text_config.vocab_size
|
| 907 |
|
| 908 |
return super().compute_transition_scores(sequences, scores, beam_indices)
|
| 909 |
|
|
|
|
| 973 |
|
| 974 |
predictions = self.make_predictions(processed_inputs)
|
| 975 |
|
| 976 |
+
sequences = predictions.sequences # Sequences for the current batch
|
| 977 |
|
| 978 |
+
output_tokens = sequences[:, input_len:]
|
| 979 |
+
|
| 980 |
+
output_tokens_strings = []
|
| 981 |
+
for tokens in output_tokens:
|
| 982 |
+
output_tokens_strings.append(
|
| 983 |
+
[
|
| 984 |
+
self.processor.decode(token, skip_special_tokens=True)
|
| 985 |
+
for token in tokens
|
| 986 |
+
]
|
| 987 |
+
)
|
| 988 |
+
|
| 989 |
+
output_strings = []
|
| 990 |
+
for tokens in output_tokens:
|
| 991 |
+
output_strings.append(
|
| 992 |
+
self.processor.decode(tokens, skip_special_tokens=True)
|
| 993 |
+
)
|
| 994 |
+
|
| 995 |
+
if return_logprobs:
|
| 996 |
+
final_outputs = self.get_logprobs(predictions, output_tokens_strings)
|
| 997 |
+
else:
|
| 998 |
+
final_outputs = output_strings
|
| 999 |
|
| 1000 |
results.append(
|
| 1001 |
self.get_return_object(
|
| 1002 |
+
output=final_outputs[0],
|
| 1003 |
+
generated_text=output_strings,
|
| 1004 |
+
output_tokens=len(output_tokens_strings[0]),
|
| 1005 |
inp=instance["source"],
|
| 1006 |
inp_tokens=None,
|
| 1007 |
return_meta_data=return_meta_data,
|
|
|
|
| 1262 |
if return_meta_data:
|
| 1263 |
return TextGenerationInferenceOutput(
|
| 1264 |
prediction=output["generated_text"],
|
| 1265 |
+
generated_text=output["generated_text"],
|
| 1266 |
model_name=self.model_name,
|
| 1267 |
inference_type=self.label,
|
| 1268 |
input_text=inp,
|
|
|
|
| 1326 |
for instance in dataset
|
| 1327 |
]
|
| 1328 |
|
| 1329 |
+
def get_return_object(
|
| 1330 |
+
self, predict_result, generated_text, instance, return_meta_data
|
| 1331 |
+
):
|
| 1332 |
if return_meta_data:
|
| 1333 |
return TextGenerationInferenceOutput(
|
| 1334 |
prediction=predict_result,
|
| 1335 |
+
generated_text=self.default_inference_value,
|
| 1336 |
input_tokens=len(instance["source"]),
|
| 1337 |
output_tokens=len(predict_result),
|
| 1338 |
model_name=self.model_name,
|
|
|
|
| 1446 |
return get_model_and_label_id(self.model, self.label)
|
| 1447 |
|
| 1448 |
def prepare_engine(self):
|
| 1449 |
+
from ollama import Client
|
| 1450 |
+
|
| 1451 |
+
self.client = Client(
|
| 1452 |
+
host=self.credentials["api_base"]
|
| 1453 |
+
if self.credentials is not None and "api_base" in self.credentials
|
| 1454 |
+
else None
|
| 1455 |
+
)
|
| 1456 |
|
| 1457 |
def _infer(
|
| 1458 |
self,
|
| 1459 |
dataset: Union[List[Dict[str, Any]], Dataset],
|
| 1460 |
return_meta_data: bool = False,
|
| 1461 |
) -> Union[List[str], List[TextGenerationInferenceOutput]]:
|
|
|
|
|
|
|
| 1462 |
args = self.to_dict([StandardAPIParamsMixin])
|
| 1463 |
results = []
|
| 1464 |
model = args.pop("model")
|
| 1465 |
for instance in dataset:
|
| 1466 |
messages = self.to_messages(instance)
|
| 1467 |
+
response = self.client.chat(
|
| 1468 |
messages=messages,
|
| 1469 |
model=model,
|
| 1470 |
options=args,
|
|
|
|
| 1958 |
f"Error predicting instance {messages}:{e}. Returning empty prediction"
|
| 1959 |
)
|
| 1960 |
return TextGenerationInferenceOutput(
|
| 1961 |
+
prediction="-", generated_text="-", input_tokens=0, output_tokens=0
|
| 1962 |
)
|
| 1963 |
|
| 1964 |
@run_with_imap
|
|
|
|
| 1975 |
top_logprobs_response = response.choices[0].logprobs.content
|
| 1976 |
pred_output = [
|
| 1977 |
{
|
| 1978 |
+
"text": generated_token.token,
|
| 1979 |
+
"logprob": generated_token.logprob,
|
| 1980 |
"top_tokens": [
|
| 1981 |
{"text": obj.token, "logprob": obj.logprob}
|
| 1982 |
for obj in generated_token.top_logprobs
|
| 1983 |
+
],
|
| 1984 |
}
|
| 1985 |
for generated_token in top_logprobs_response
|
| 1986 |
]
|
|
|
|
| 1990 |
logging.error(
|
| 1991 |
f"Error predicting instance {messages}:{e}. Returning empty prediction"
|
| 1992 |
)
|
| 1993 |
+
prediction = [
|
| 1994 |
+
{"text": "-", "logprob": 0, "top_tokens": [{"text": "-", "logprob": 0}]}
|
| 1995 |
+
]
|
| 1996 |
return TextGenerationInferenceOutput(
|
| 1997 |
+
prediction=prediction,
|
| 1998 |
+
generated_text=prediction,
|
| 1999 |
+
input_tokens=0,
|
| 2000 |
+
output_tokens=0,
|
| 2001 |
)
|
| 2002 |
|
| 2003 |
def get_return_object(self, predict_result, response, return_meta_data):
|
| 2004 |
if return_meta_data:
|
| 2005 |
return TextGenerationInferenceOutput(
|
| 2006 |
prediction=predict_result,
|
| 2007 |
+
generated_text=response.choices[0].message.content,
|
| 2008 |
input_tokens=response.usage.prompt_tokens,
|
| 2009 |
output_tokens=response.usage.completion_tokens,
|
| 2010 |
model_name=self.model_name,
|
|
|
|
| 2062 |
label: str = "rits"
|
| 2063 |
data_classification_policy = ["public", "proprietary"]
|
| 2064 |
|
| 2065 |
+
model_names_dict = {
|
| 2066 |
+
"microsoft/phi-4": "microsoft-phi-4",
|
| 2067 |
+
"meta-llama/llama-4-maverick-17b-128e-instruct-fp8": "llama-4-mvk-17b-128e-fp8",
|
| 2068 |
+
"deepseek-ai/DeepSeek-V3": "deepseek-v3-h200",
|
| 2069 |
+
"meta-llama/Llama-3.1-8B-Instruct": "llama-3-1-8b-instruct",
|
| 2070 |
+
}
|
| 2071 |
|
| 2072 |
def get_default_headers(self):
|
| 2073 |
return {"RITS_API_KEY": self.credentials["api_key"]}
|
|
|
|
| 2700 |
if return_meta_data:
|
| 2701 |
return TextGenerationInferenceOutput(
|
| 2702 |
prediction=predict_result,
|
| 2703 |
+
generated_text=result["generated_text"],
|
| 2704 |
input_tokens=result["input_token_count"],
|
| 2705 |
output_tokens=result["generated_token_count"],
|
| 2706 |
model_name=self.model_name or self.deployment_id,
|
|
|
|
| 2960 |
tool_call = data[idx]["tools"]["tools"] is not None
|
| 2961 |
|
| 2962 |
output = response["choices"][0][output_type]
|
| 2963 |
+
if "content" not in output:
|
| 2964 |
+
output["content"] = ""
|
| 2965 |
if tool_call:
|
| 2966 |
if "tool_calls" in output:
|
| 2967 |
func = output["tool_calls"][0]["function"]
|
|
|
|
| 2974 |
results.append(
|
| 2975 |
self.get_return_object(
|
| 2976 |
prediction,
|
| 2977 |
+
response["choices"][0]["message"]["content"],
|
| 2978 |
response,
|
| 2979 |
str(inp),
|
| 2980 |
return_meta_data,
|
|
|
|
| 2983 |
|
| 2984 |
return results
|
| 2985 |
|
| 2986 |
+
def get_return_object(
|
| 2987 |
+
self, predict_result, generated_text, result, input_text, return_meta_data
|
| 2988 |
+
):
|
| 2989 |
if return_meta_data:
|
| 2990 |
return TextGenerationInferenceOutput(
|
| 2991 |
prediction=predict_result,
|
| 2992 |
+
generated_text=generated_text,
|
| 2993 |
input_tokens=result["usage"]["prompt_tokens"],
|
| 2994 |
output_tokens=len(predict_result)
|
| 2995 |
if isinstance(predict_result, list)
|
|
|
|
| 3387 |
prediction = response["choices"][0]["message"]["content"] or ""
|
| 3388 |
return TextGenerationInferenceOutput(
|
| 3389 |
prediction=prediction,
|
| 3390 |
+
generated_text=response["choices"][0]["message"]["content"],
|
| 3391 |
input_tokens=usage.get("prompt_tokens"),
|
| 3392 |
output_tokens=usage.get("completion_tokens"),
|
| 3393 |
model_name=response.get("model", self.model),
|
|
|
|
| 3538 |
},
|
| 3539 |
"rits": {
|
| 3540 |
"granite-3-8b-instruct": "ibm-granite/granite-3.0-8b-instruct",
|
| 3541 |
+
"granite-3-1-8b-instruct": "ibm-granite/granite-3.1-8b-instruct",
|
| 3542 |
"granite-3-2-8b-instruct": "ibm-granite/granite-3.2-8b-instruct",
|
| 3543 |
"granite-3-3-8b-instruct": "ibm-granite/granite-3.3-8b-instruct",
|
| 3544 |
+
"llama-3-1-8b-instruct": "meta-llama/Llama-3.1-8B-Instruct",
|
| 3545 |
"llama-3-1-70b-instruct": "meta-llama/llama-3-1-70b-instruct",
|
| 3546 |
"llama-3-1-405b-instruct": "meta-llama/llama-3-1-405b-instruct-fp8",
|
| 3547 |
"llama-3-1-405b-instruct-fp8": "meta-llama/llama-3-1-405b-instruct-fp8",
|
| 3548 |
"llama-3-2-11b-vision-instruct": "meta-llama/Llama-3.2-11B-Vision-Instruct",
|
| 3549 |
"llama-3-2-90b-vision-instruct": "meta-llama/Llama-3.2-90B-Vision-Instruct",
|
| 3550 |
"llama-3-3-70b-instruct": "meta-llama/llama-3-3-70b-instruct",
|
| 3551 |
+
"llama-4-scout": "meta-llama/llama-4-scout-17b-16e",
|
| 3552 |
+
"llama-4-maverick": "meta-llama/llama-4-maverick-17b-128e-instruct-fp8",
|
| 3553 |
"mistral-large-instruct": "mistralai/mistral-large-instruct-2407",
|
| 3554 |
"mixtral-8x7b-instruct": "mistralai/mixtral-8x7B-instruct-v0.1",
|
| 3555 |
"mixtral-8x7b-instruct-v01": "mistralai/mixtral-8x7B-instruct-v0.1",
|
| 3556 |
+
"deepseek-v3": "deepseek-ai/DeepSeek-V3",
|
| 3557 |
"granite-guardian-3-2-3b-a800m": "ibm-granite/granite-guardian-3.2-3b-a800m",
|
| 3558 |
"granite-guardian-3-2-5b": "ibm-granite/granite-guardian-3.2-5b",
|
| 3559 |
},
|
llm_as_judge_constants.py
CHANGED
|
@@ -125,7 +125,7 @@ EVALUATOR_TO_MODEL_ID = {
|
|
| 125 |
EvaluatorNameEnum.GRANITE3_1_8B: "granite-3-1-8b-instruct",
|
| 126 |
EvaluatorNameEnum.GRANITE3_2_8B: "granite-3-2-8b-instruct",
|
| 127 |
EvaluatorNameEnum.GRANITE3_3_8B: "granite-3-3-8b-instruct",
|
| 128 |
-
EvaluatorNameEnum.DEEPSEEK_V3: "deepseek-
|
| 129 |
EvaluatorNameEnum.GEMMA_2_5_PRO: "gemma-2-5-pro",
|
| 130 |
EvaluatorNameEnum.GEMINI_2_5_FLASH: "gemini-2-5-flash",
|
| 131 |
}
|
|
@@ -198,7 +198,6 @@ EVALUATORS_METADATA = [
|
|
| 198 |
[
|
| 199 |
ModelProviderEnum.WATSONX,
|
| 200 |
ModelProviderEnum.TOGETHER_AI,
|
| 201 |
-
ModelProviderEnum.RITS,
|
| 202 |
ModelProviderEnum.OLLAMA,
|
| 203 |
],
|
| 204 |
),
|
|
|
|
| 125 |
EvaluatorNameEnum.GRANITE3_1_8B: "granite-3-1-8b-instruct",
|
| 126 |
EvaluatorNameEnum.GRANITE3_2_8B: "granite-3-2-8b-instruct",
|
| 127 |
EvaluatorNameEnum.GRANITE3_3_8B: "granite-3-3-8b-instruct",
|
| 128 |
+
EvaluatorNameEnum.DEEPSEEK_V3: "deepseek-v3",
|
| 129 |
EvaluatorNameEnum.GEMMA_2_5_PRO: "gemma-2-5-pro",
|
| 130 |
EvaluatorNameEnum.GEMINI_2_5_FLASH: "gemini-2-5-flash",
|
| 131 |
}
|
|
|
|
| 198 |
[
|
| 199 |
ModelProviderEnum.WATSONX,
|
| 200 |
ModelProviderEnum.TOGETHER_AI,
|
|
|
|
| 201 |
ModelProviderEnum.OLLAMA,
|
| 202 |
],
|
| 203 |
),
|
loaders.py
CHANGED
|
@@ -66,7 +66,7 @@ from tqdm import tqdm
|
|
| 66 |
|
| 67 |
from .dataclass import NonPositionalField
|
| 68 |
from .dict_utils import dict_get
|
| 69 |
-
from .error_utils import Documentation, UnitxtError, UnitxtWarning
|
| 70 |
from .fusion import FixedFusion
|
| 71 |
from .logging_utils import get_logger
|
| 72 |
from .operator import SourceOperator
|
|
@@ -90,23 +90,27 @@ class UnitxtUnverifiedCodeError(UnitxtError):
|
|
| 90 |
|
| 91 |
@retry_connection_with_exponential_backoff(backoff_factor=2)
|
| 92 |
def hf_load_dataset(path: str, *args, **kwargs):
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
|
| 96 |
-
|
| 97 |
-
|
| 98 |
-
|
| 99 |
-
|
| 100 |
-
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
|
| 105 |
-
|
| 106 |
-
|
| 107 |
-
|
| 108 |
-
|
| 109 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 110 |
|
| 111 |
|
| 112 |
@retry_connection_with_exponential_backoff(backoff_factor=2)
|
|
@@ -218,13 +222,15 @@ class Loader(SourceOperator):
|
|
| 218 |
pass
|
| 219 |
|
| 220 |
def load_data(self) -> MultiStream:
|
| 221 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 222 |
iterables = self.load_iterables()
|
| 223 |
-
|
| 224 |
-
|
| 225 |
-
|
| 226 |
-
return iterables
|
| 227 |
-
return MultiStream.from_iterables(iterables, copying=True)
|
| 228 |
|
| 229 |
def process(self) -> MultiStream:
|
| 230 |
self._maybe_set_classification_policy()
|
|
@@ -514,9 +520,13 @@ class LoadCSV(LoadWithPandas):
|
|
| 514 |
sep: str = ","
|
| 515 |
|
| 516 |
def read_dataframe(self, file) -> pd.DataFrame:
|
| 517 |
-
|
| 518 |
-
|
| 519 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 520 |
|
| 521 |
|
| 522 |
def read_file(source) -> bytes:
|
|
@@ -560,32 +570,36 @@ class LoadJsonFile(LoadWithPandas):
|
|
| 560 |
data_field: Optional[str] = None
|
| 561 |
|
| 562 |
def read_dataframe(self, file) -> pd.DataFrame:
|
| 563 |
-
|
| 564 |
-
|
| 565 |
-
|
| 566 |
-
|
| 567 |
-
|
| 568 |
-
|
| 569 |
-
|
| 570 |
-
|
| 571 |
-
)
|
| 572 |
-
|
| 573 |
-
|
| 574 |
-
|
| 575 |
-
|
| 576 |
-
instances = data
|
| 577 |
else:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 578 |
raise UnitxtError(
|
| 579 |
-
|
| 580 |
)
|
| 581 |
-
|
| 582 |
-
|
| 583 |
-
if self.data_field is not None:
|
| 584 |
-
raise UnitxtError(
|
| 585 |
-
"Can not load from a specific 'data_field' when loading multiple lines (lines=True)"
|
| 586 |
-
)
|
| 587 |
-
dataframe = pd.read_json(file, lines=self.lines, **args)
|
| 588 |
-
return dataframe
|
| 589 |
|
| 590 |
|
| 591 |
class LoadFromSklearn(LazyLoader):
|
|
@@ -631,8 +645,12 @@ class LoadFromSklearn(LazyLoader):
|
|
| 631 |
dataset_id = str(self) + "_" + split
|
| 632 |
dataset = self.__class__._loader_cache.get(dataset_id, None)
|
| 633 |
if dataset is None:
|
| 634 |
-
|
| 635 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 636 |
df = pd.DataFrame([split_data["data"], targets]).T
|
| 637 |
df.columns = ["data", "target"]
|
| 638 |
dataset = df.to_dict("records")
|
|
@@ -851,18 +869,22 @@ class LoadFromIBMCloud(Loader):
|
|
| 851 |
if self.data_dir is not None
|
| 852 |
else data_file
|
| 853 |
)
|
| 854 |
-
with
|
| 855 |
-
|
| 856 |
-
|
| 857 |
-
|
| 858 |
-
|
| 859 |
-
|
| 860 |
-
|
| 861 |
-
|
| 862 |
-
|
| 863 |
-
|
| 864 |
-
|
| 865 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 866 |
|
| 867 |
if isinstance(self.data_files, list):
|
| 868 |
dataset = hf_load_dataset(local_dir, streaming=False, field=self.data_field)
|
|
@@ -946,22 +968,26 @@ class LoadFromDictionary(Loader):
|
|
| 946 |
|
| 947 |
def verify(self):
|
| 948 |
super().verify()
|
| 949 |
-
|
| 950 |
-
|
| 951 |
-
|
| 952 |
-
|
| 953 |
-
|
| 954 |
-
|
| 955 |
-
|
| 956 |
-
|
| 957 |
-
|
| 958 |
-
|
| 959 |
-
for
|
| 960 |
-
if
|
| 961 |
-
raise ValueError(
|
| 962 |
-
|
| 963 |
-
|
| 964 |
-
)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 965 |
|
| 966 |
def _maybe_set_classification_policy(self):
|
| 967 |
self.set_default_data_classification(
|
|
@@ -1127,7 +1153,7 @@ class LoadFromAPI(Loader):
|
|
| 1127 |
chunksize: int = 100000
|
| 1128 |
loader_limit: Optional[int] = None
|
| 1129 |
streaming: bool = False
|
| 1130 |
-
api_key_env_var: Optional[str] =
|
| 1131 |
headers: Optional[Dict[str, Any]] = None
|
| 1132 |
data_field: str = "data"
|
| 1133 |
method: str = "GET"
|
|
|
|
| 66 |
|
| 67 |
from .dataclass import NonPositionalField
|
| 68 |
from .dict_utils import dict_get
|
| 69 |
+
from .error_utils import Documentation, UnitxtError, UnitxtWarning, error_context
|
| 70 |
from .fusion import FixedFusion
|
| 71 |
from .logging_utils import get_logger
|
| 72 |
from .operator import SourceOperator
|
|
|
|
| 90 |
|
| 91 |
@retry_connection_with_exponential_backoff(backoff_factor=2)
|
| 92 |
def hf_load_dataset(path: str, *args, **kwargs):
|
| 93 |
+
with error_context(
|
| 94 |
+
stage="Raw Dataset Download",
|
| 95 |
+
help="https://www.unitxt.ai/en/latest/unitxt.loaders.html#module-unitxt.loaders",
|
| 96 |
+
):
|
| 97 |
+
if settings.hf_offline_datasets_path is not None:
|
| 98 |
+
path = os.path.join(settings.hf_offline_datasets_path, path)
|
| 99 |
+
try:
|
| 100 |
+
return _hf_load_dataset(
|
| 101 |
+
path,
|
| 102 |
+
*args,
|
| 103 |
+
**kwargs,
|
| 104 |
+
verification_mode="no_checks",
|
| 105 |
+
trust_remote_code=settings.allow_unverified_code,
|
| 106 |
+
download_mode="force_redownload"
|
| 107 |
+
if settings.disable_hf_datasets_cache
|
| 108 |
+
else "reuse_dataset_if_exists",
|
| 109 |
+
)
|
| 110 |
+
except ValueError as e:
|
| 111 |
+
if "trust_remote_code" in str(e):
|
| 112 |
+
raise UnitxtUnverifiedCodeError(path) from e
|
| 113 |
+
raise e # Re raise
|
| 114 |
|
| 115 |
|
| 116 |
@retry_connection_with_exponential_backoff(backoff_factor=2)
|
|
|
|
| 222 |
pass
|
| 223 |
|
| 224 |
def load_data(self) -> MultiStream:
|
| 225 |
+
with error_context(
|
| 226 |
+
self,
|
| 227 |
+
stage="Data Loading",
|
| 228 |
+
help="https://www.unitxt.ai/en/latest/unitxt.loaders.html#module-unitxt.loaders",
|
| 229 |
+
):
|
| 230 |
iterables = self.load_iterables()
|
| 231 |
+
if isoftype(iterables, MultiStream):
|
| 232 |
+
return iterables
|
| 233 |
+
return MultiStream.from_iterables(iterables, copying=True)
|
|
|
|
|
|
|
| 234 |
|
| 235 |
def process(self) -> MultiStream:
|
| 236 |
self._maybe_set_classification_policy()
|
|
|
|
| 520 |
sep: str = ","
|
| 521 |
|
| 522 |
def read_dataframe(self, file) -> pd.DataFrame:
|
| 523 |
+
with error_context(
|
| 524 |
+
stage="Raw Dataset Loading",
|
| 525 |
+
help="https://www.unitxt.ai/en/latest/unitxt.loaders.html#module-unitxt.loaders",
|
| 526 |
+
):
|
| 527 |
+
return pd.read_csv(
|
| 528 |
+
file, sep=self.sep, low_memory=self.streaming, **self.get_args()
|
| 529 |
+
)
|
| 530 |
|
| 531 |
|
| 532 |
def read_file(source) -> bytes:
|
|
|
|
| 570 |
data_field: Optional[str] = None
|
| 571 |
|
| 572 |
def read_dataframe(self, file) -> pd.DataFrame:
|
| 573 |
+
with error_context(
|
| 574 |
+
stage="Raw Dataset Loading",
|
| 575 |
+
help="https://www.unitxt.ai/en/latest/unitxt.loaders.html#module-unitxt.loaders",
|
| 576 |
+
):
|
| 577 |
+
args = self.get_args()
|
| 578 |
+
if not self.lines:
|
| 579 |
+
data = json.loads(read_file(file))
|
| 580 |
+
if self.data_field:
|
| 581 |
+
instances = dict_get(data, self.data_field)
|
| 582 |
+
if not isoftype(instances, List[Dict[str, Any]]):
|
| 583 |
+
raise UnitxtError(
|
| 584 |
+
f"{self.data_field} of file {file} is not a list of dictionariess in LoadJsonFile loader"
|
| 585 |
+
)
|
|
|
|
| 586 |
else:
|
| 587 |
+
if isoftype(data, Dict[str, Any]):
|
| 588 |
+
instances = [data]
|
| 589 |
+
elif isoftype(data, List[Dict[str, Any]]):
|
| 590 |
+
instances = data
|
| 591 |
+
else:
|
| 592 |
+
raise UnitxtError(
|
| 593 |
+
f"data of file {file} is not dictionary or a list of dictionaries in LoadJsonFile loader"
|
| 594 |
+
)
|
| 595 |
+
dataframe = pd.DataFrame(instances)
|
| 596 |
+
else:
|
| 597 |
+
if self.data_field is not None:
|
| 598 |
raise UnitxtError(
|
| 599 |
+
"Can not load from a specific 'data_field' when loading multiple lines (lines=True)"
|
| 600 |
)
|
| 601 |
+
dataframe = pd.read_json(file, lines=self.lines, **args)
|
| 602 |
+
return dataframe
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 603 |
|
| 604 |
|
| 605 |
class LoadFromSklearn(LazyLoader):
|
|
|
|
| 645 |
dataset_id = str(self) + "_" + split
|
| 646 |
dataset = self.__class__._loader_cache.get(dataset_id, None)
|
| 647 |
if dataset is None:
|
| 648 |
+
with error_context(
|
| 649 |
+
stage="Raw Dataset Loading",
|
| 650 |
+
help="https://www.unitxt.ai/en/latest/unitxt.loaders.html#module-unitxt.loaders",
|
| 651 |
+
):
|
| 652 |
+
split_data = self.downloader(subset=split)
|
| 653 |
+
targets = [split_data["target_names"][t] for t in split_data["target"]]
|
| 654 |
df = pd.DataFrame([split_data["data"], targets]).T
|
| 655 |
df.columns = ["data", "target"]
|
| 656 |
dataset = df.to_dict("records")
|
|
|
|
| 869 |
if self.data_dir is not None
|
| 870 |
else data_file
|
| 871 |
)
|
| 872 |
+
with error_context(
|
| 873 |
+
stage="Raw Dataset Download",
|
| 874 |
+
help="https://www.unitxt.ai/en/latest/unitxt.loaders.html#module-unitxt.loaders",
|
| 875 |
+
):
|
| 876 |
+
with tempfile.NamedTemporaryFile() as temp_file:
|
| 877 |
+
# Download to a temporary file in same file partition, and then do an atomic move
|
| 878 |
+
self._download_from_cos(
|
| 879 |
+
cos,
|
| 880 |
+
self.bucket_name,
|
| 881 |
+
object_key,
|
| 882 |
+
local_dir + "/" + os.path.basename(temp_file.name),
|
| 883 |
+
)
|
| 884 |
+
os.renames(
|
| 885 |
+
local_dir + "/" + os.path.basename(temp_file.name),
|
| 886 |
+
local_dir + "/" + data_file,
|
| 887 |
+
)
|
| 888 |
|
| 889 |
if isinstance(self.data_files, list):
|
| 890 |
dataset = hf_load_dataset(local_dir, streaming=False, field=self.data_field)
|
|
|
|
| 968 |
|
| 969 |
def verify(self):
|
| 970 |
super().verify()
|
| 971 |
+
with error_context(
|
| 972 |
+
stage="Dataset Loading",
|
| 973 |
+
help="https://www.unitxt.ai/en/latest/unitxt.loaders.html#module-unitxt.loaders",
|
| 974 |
+
):
|
| 975 |
+
if not isoftype(self.data, Dict[str, List[Dict[str, Any]]]):
|
| 976 |
+
raise ValueError(
|
| 977 |
+
f"Passed data to LoadFromDictionary is not of type Dict[str, List[Dict[str, Any]]].\n"
|
| 978 |
+
f"Expected data should map between split name and list of instances.\n"
|
| 979 |
+
f"Received value: {self.data}\n"
|
| 980 |
+
)
|
| 981 |
+
for split in self.data.keys():
|
| 982 |
+
if len(self.data[split]) == 0:
|
| 983 |
+
raise ValueError(f"Split {split} has no instances.")
|
| 984 |
+
first_instance = self.data[split][0]
|
| 985 |
+
for instance in self.data[split]:
|
| 986 |
+
if instance.keys() != first_instance.keys():
|
| 987 |
+
raise ValueError(
|
| 988 |
+
f"Not all instances in split '{split}' have the same fields.\n"
|
| 989 |
+
f"instance {instance} has different fields different from {first_instance}"
|
| 990 |
+
)
|
| 991 |
|
| 992 |
def _maybe_set_classification_policy(self):
|
| 993 |
self.set_default_data_classification(
|
|
|
|
| 1153 |
chunksize: int = 100000
|
| 1154 |
loader_limit: Optional[int] = None
|
| 1155 |
streaming: bool = False
|
| 1156 |
+
api_key_env_var: Optional[str] = None
|
| 1157 |
headers: Optional[Dict[str, Any]] = None
|
| 1158 |
data_field: str = "data"
|
| 1159 |
method: str = "GET"
|
metric.py
CHANGED
|
@@ -56,7 +56,6 @@ from .settings_utils import get_constants
|
|
| 56 |
from .span_lableing_operators import __file__ as _
|
| 57 |
from .split_utils import __file__ as _
|
| 58 |
from .splitters import __file__ as _
|
| 59 |
-
from .sql_utils import __file__ as _
|
| 60 |
from .standard import __file__ as _
|
| 61 |
from .stream import __file__ as _
|
| 62 |
from .stream_operators import __file__ as _
|
|
@@ -65,6 +64,7 @@ from .struct_data_operators import __file__ as _
|
|
| 65 |
from .system_prompts import __file__ as _
|
| 66 |
from .task import __file__ as _
|
| 67 |
from .templates import __file__ as _
|
|
|
|
| 68 |
from .text_utils import __file__ as _
|
| 69 |
from .type_utils import __file__ as _
|
| 70 |
from .types import __file__ as _
|
|
|
|
| 56 |
from .span_lableing_operators import __file__ as _
|
| 57 |
from .split_utils import __file__ as _
|
| 58 |
from .splitters import __file__ as _
|
|
|
|
| 59 |
from .standard import __file__ as _
|
| 60 |
from .stream import __file__ as _
|
| 61 |
from .stream_operators import __file__ as _
|
|
|
|
| 64 |
from .system_prompts import __file__ as _
|
| 65 |
from .task import __file__ as _
|
| 66 |
from .templates import __file__ as _
|
| 67 |
+
from .text2sql_utils import __file__ as _
|
| 68 |
from .text_utils import __file__ as _
|
| 69 |
from .type_utils import __file__ as _
|
| 70 |
from .types import __file__ as _
|
metric_utils.py
CHANGED
|
@@ -9,7 +9,7 @@ import pandas as pd
|
|
| 9 |
from datasets import Features, Value
|
| 10 |
|
| 11 |
from .dataclass import Dataclass
|
| 12 |
-
from .error_utils import Documentation, UnitxtError
|
| 13 |
from .operator import (
|
| 14 |
InstanceOperator,
|
| 15 |
MultiStreamOperator,
|
|
@@ -36,6 +36,9 @@ from .utils import recursive_copy
|
|
| 36 |
|
| 37 |
constants = get_constants()
|
| 38 |
|
|
|
|
|
|
|
|
|
|
| 39 |
|
| 40 |
def nan_mean(scores):
|
| 41 |
result = mean(score for score in scores if score == score)
|
|
@@ -56,7 +59,10 @@ class FromPredictionsAndOriginalData(StreamInitializerOperator):
|
|
| 56 |
yield {**original, "prediction": prediction}
|
| 57 |
|
| 58 |
def process(
|
| 59 |
-
self,
|
|
|
|
|
|
|
|
|
|
| 60 |
) -> MultiStream:
|
| 61 |
return MultiStream(
|
| 62 |
{
|
|
@@ -152,7 +158,7 @@ class SplitSubsetsAndGroups(MultiStreamOperator):
|
|
| 152 |
|
| 153 |
subset_stream_name = (
|
| 154 |
stream_name
|
| 155 |
-
+
|
| 156 |
+ "/".join(instance[self.subsets_field][: self.subset_depth])
|
| 157 |
)
|
| 158 |
|
|
@@ -190,7 +196,7 @@ def group_str_to_key_value(group_str):
|
|
| 190 |
|
| 191 |
@lru_cache(maxsize=None)
|
| 192 |
def stream_name_to_origin_subset_group(stream_name):
|
| 193 |
-
origin, subset_group = stream_name.split(
|
| 194 |
if "?" in subset_group:
|
| 195 |
subset, group = subset_group.split("?")
|
| 196 |
else:
|
|
@@ -734,22 +740,23 @@ def _compute(
|
|
| 734 |
predictions: List[Any],
|
| 735 |
references: Iterable,
|
| 736 |
flatten: bool = False,
|
| 737 |
-
split_name: str =
|
| 738 |
calc_confidence_intervals: bool = True,
|
| 739 |
):
|
| 740 |
_reset_env_local_catalogs()
|
| 741 |
register_all_artifacts()
|
| 742 |
recipe = MetricRecipe(calc_confidence_intervals=calc_confidence_intervals)
|
| 743 |
|
| 744 |
-
|
| 745 |
-
|
| 746 |
-
|
|
|
|
| 747 |
|
| 748 |
-
|
| 749 |
-
|
| 750 |
-
|
| 751 |
|
| 752 |
-
|
| 753 |
return EvaluationResults(stream)
|
| 754 |
|
| 755 |
|
|
|
|
| 9 |
from datasets import Features, Value
|
| 10 |
|
| 11 |
from .dataclass import Dataclass
|
| 12 |
+
from .error_utils import Documentation, UnitxtError, error_context
|
| 13 |
from .operator import (
|
| 14 |
InstanceOperator,
|
| 15 |
MultiStreamOperator,
|
|
|
|
| 36 |
|
| 37 |
constants = get_constants()
|
| 38 |
|
| 39 |
+
DEFAULT_STREAM_NAME = "all_data"
|
| 40 |
+
DEFAULT_STREAM_SUBSET_SEPARATOR = ">>"
|
| 41 |
+
|
| 42 |
|
| 43 |
def nan_mean(scores):
|
| 44 |
result = mean(score for score in scores if score == score)
|
|
|
|
| 59 |
yield {**original, "prediction": prediction}
|
| 60 |
|
| 61 |
def process(
|
| 62 |
+
self,
|
| 63 |
+
predictions: List[str],
|
| 64 |
+
references: Iterable,
|
| 65 |
+
split_name: str = DEFAULT_STREAM_NAME,
|
| 66 |
) -> MultiStream:
|
| 67 |
return MultiStream(
|
| 68 |
{
|
|
|
|
| 158 |
|
| 159 |
subset_stream_name = (
|
| 160 |
stream_name
|
| 161 |
+
+ DEFAULT_STREAM_SUBSET_SEPARATOR
|
| 162 |
+ "/".join(instance[self.subsets_field][: self.subset_depth])
|
| 163 |
)
|
| 164 |
|
|
|
|
| 196 |
|
| 197 |
@lru_cache(maxsize=None)
|
| 198 |
def stream_name_to_origin_subset_group(stream_name):
|
| 199 |
+
origin, subset_group = stream_name.split(DEFAULT_STREAM_SUBSET_SEPARATOR)
|
| 200 |
if "?" in subset_group:
|
| 201 |
subset, group = subset_group.split("?")
|
| 202 |
else:
|
|
|
|
| 740 |
predictions: List[Any],
|
| 741 |
references: Iterable,
|
| 742 |
flatten: bool = False,
|
| 743 |
+
split_name: str = DEFAULT_STREAM_NAME,
|
| 744 |
calc_confidence_intervals: bool = True,
|
| 745 |
):
|
| 746 |
_reset_env_local_catalogs()
|
| 747 |
register_all_artifacts()
|
| 748 |
recipe = MetricRecipe(calc_confidence_intervals=calc_confidence_intervals)
|
| 749 |
|
| 750 |
+
with error_context(stage="Metric Processing"):
|
| 751 |
+
multi_stream = recipe(
|
| 752 |
+
predictions=predictions, references=references, split_name=split_name
|
| 753 |
+
)
|
| 754 |
|
| 755 |
+
if flatten:
|
| 756 |
+
operator = FlattenInstances()
|
| 757 |
+
multi_stream = operator(multi_stream)
|
| 758 |
|
| 759 |
+
stream = multi_stream[split_name]
|
| 760 |
return EvaluationResults(stream)
|
| 761 |
|
| 762 |
|
metrics.py
CHANGED
|
@@ -8,7 +8,8 @@ import uuid
|
|
| 8 |
import warnings
|
| 9 |
from abc import ABC, abstractmethod
|
| 10 |
from collections import Counter, defaultdict
|
| 11 |
-
from dataclasses import field
|
|
|
|
| 12 |
from enum import Enum
|
| 13 |
from functools import lru_cache
|
| 14 |
from typing import (
|
|
@@ -42,7 +43,8 @@ from .dataclass import (
|
|
| 42 |
OptionalField,
|
| 43 |
)
|
| 44 |
from .deprecation_utils import deprecation
|
| 45 |
-
from .
|
|
|
|
| 46 |
from .inference import (
|
| 47 |
HFPipelineBasedInferenceEngine,
|
| 48 |
InferenceEngine,
|
|
@@ -64,6 +66,7 @@ from .operators import ArtifactFetcherMixin, Copy, FieldOperator, Set
|
|
| 64 |
from .random_utils import get_seed
|
| 65 |
from .settings_utils import get_settings
|
| 66 |
from .stream import MultiStream, Stream
|
|
|
|
| 67 |
from .type_utils import isoftype, parse_type_string, to_type_string
|
| 68 |
from .types import ToolCall
|
| 69 |
from .utils import deep_copy, recursive_copy, retry_connection_with_exponential_backoff
|
|
@@ -382,28 +385,35 @@ class MapReduceMetric(
|
|
| 382 |
return intermediates
|
| 383 |
|
| 384 |
def process(self, stream: Stream, stream_name: Optional[str] = None):
|
| 385 |
-
|
| 386 |
-
|
| 387 |
-
|
| 388 |
-
|
| 389 |
-
|
| 390 |
-
|
| 391 |
-
|
| 392 |
-
|
| 393 |
-
|
| 394 |
-
|
| 395 |
-
|
| 396 |
-
|
| 397 |
-
|
| 398 |
-
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 399 |
|
| 400 |
-
|
| 401 |
-
|
| 402 |
|
| 403 |
-
|
| 404 |
-
|
| 405 |
-
|
| 406 |
-
|
| 407 |
|
| 408 |
def compute(self, stream: Stream, stream_name: Optional[str] = None):
|
| 409 |
evaluation_inputs_stream = self._instances_stream_to_evaluation_inputs(stream)
|
|
@@ -453,6 +463,43 @@ class DictReduction(AggregationReduction[Dict[str, float]]):
|
|
| 453 |
return result
|
| 454 |
|
| 455 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 456 |
class MeanReduction(DictReduction):
|
| 457 |
def reduce_list(self, lst: List[float]):
|
| 458 |
return nan_mean(lst)
|
|
@@ -468,6 +515,91 @@ class MaxReduction(DictReduction):
|
|
| 468 |
return float(nan_max(lst))
|
| 469 |
|
| 470 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 471 |
class ReductionInstanceMetric(
|
| 472 |
MapReduceMetric[PredictionType, IntermediateType],
|
| 473 |
Generic[PredictionType, IntermediateType],
|
|
@@ -704,6 +836,52 @@ class ToolCallingMetric(ReductionInstanceMetric[str, Dict[str, float]]):
|
|
| 704 |
}
|
| 705 |
|
| 706 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 707 |
class MetricWithConfidenceInterval(Metric):
|
| 708 |
# The number of resamples used to estimate the confidence intervals of this metric.
|
| 709 |
# Use None to disable confidence interval computation.
|
|
@@ -954,83 +1132,88 @@ class GlobalMetric(StreamOperator, MetricWithConfidenceInterval):
|
|
| 954 |
process_single_instances = True
|
| 955 |
|
| 956 |
def process(self, stream: Stream, stream_name: Optional[str] = None) -> Generator:
|
| 957 |
-
|
| 958 |
-
|
| 959 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 960 |
|
| 961 |
-
|
| 962 |
|
| 963 |
-
|
| 964 |
-
|
| 965 |
|
| 966 |
-
|
| 967 |
-
|
| 968 |
|
| 969 |
-
|
| 970 |
-
|
| 971 |
-
|
| 972 |
-
|
| 973 |
|
| 974 |
-
|
| 975 |
-
|
| 976 |
-
|
| 977 |
|
| 978 |
-
|
| 979 |
-
|
| 980 |
-
|
| 981 |
-
|
| 982 |
-
|
| 983 |
|
| 984 |
-
|
| 985 |
-
|
| 986 |
-
|
| 987 |
-
|
| 988 |
-
|
| 989 |
-
|
| 990 |
-
|
| 991 |
-
|
| 992 |
-
|
| 993 |
-
|
| 994 |
-
|
| 995 |
-
|
| 996 |
-
|
| 997 |
-
|
| 998 |
-
|
| 999 |
-
|
| 1000 |
|
| 1001 |
-
|
| 1002 |
-
|
| 1003 |
|
| 1004 |
-
|
| 1005 |
-
|
| 1006 |
-
|
|
|
|
| 1007 |
)
|
| 1008 |
-
)
|
| 1009 |
-
|
| 1010 |
-
global_score = {"num_of_instances": len(instances)}
|
| 1011 |
|
| 1012 |
-
|
| 1013 |
-
|
| 1014 |
-
|
| 1015 |
-
|
|
|
|
| 1016 |
)
|
| 1017 |
-
|
| 1018 |
-
|
| 1019 |
-
|
| 1020 |
-
|
| 1021 |
-
|
| 1022 |
-
|
| 1023 |
-
score_names = [global_score["score_name"]]
|
| 1024 |
|
| 1025 |
-
|
| 1026 |
-
|
| 1027 |
-
|
| 1028 |
-
|
| 1029 |
-
|
| 1030 |
|
| 1031 |
-
|
| 1032 |
-
|
| 1033 |
-
|
| 1034 |
|
| 1035 |
def _compute(
|
| 1036 |
self,
|
|
@@ -1080,96 +1263,105 @@ class BulkInstanceMetric(StreamOperator, MetricWithConfidenceInterval):
|
|
| 1080 |
return instance
|
| 1081 |
|
| 1082 |
def process(self, stream: Stream, stream_name: Optional[str] = None) -> Generator:
|
| 1083 |
-
|
| 1084 |
-
|
| 1085 |
-
|
| 1086 |
-
|
| 1087 |
-
|
| 1088 |
-
|
| 1089 |
-
|
| 1090 |
-
|
| 1091 |
-
|
| 1092 |
-
|
| 1093 |
-
|
| 1094 |
-
|
| 1095 |
-
|
| 1096 |
-
|
| 1097 |
-
|
| 1098 |
-
|
| 1099 |
-
|
| 1100 |
-
predictions
|
| 1101 |
-
|
| 1102 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1103 |
|
| 1104 |
-
|
| 1105 |
-
|
| 1106 |
-
|
| 1107 |
-
|
| 1108 |
|
| 1109 |
-
|
| 1110 |
-
|
| 1111 |
-
|
| 1112 |
|
| 1113 |
-
|
| 1114 |
-
|
| 1115 |
-
|
|
|
|
| 1116 |
)
|
| 1117 |
-
)
|
| 1118 |
|
| 1119 |
-
|
| 1120 |
-
|
| 1121 |
-
|
| 1122 |
-
|
| 1123 |
-
|
| 1124 |
-
|
| 1125 |
-
|
| 1126 |
-
|
| 1127 |
-
|
| 1128 |
-
|
| 1129 |
-
|
| 1130 |
-
|
| 1131 |
-
|
| 1132 |
-
|
| 1133 |
-
|
| 1134 |
-
|
| 1135 |
-
|
|
|
|
|
|
|
| 1136 |
|
| 1137 |
-
|
| 1138 |
-
|
| 1139 |
-
|
| 1140 |
-
|
| 1141 |
-
|
| 1142 |
-
|
| 1143 |
-
|
| 1144 |
-
|
| 1145 |
-
|
| 1146 |
-
|
| 1147 |
-
|
| 1148 |
-
|
| 1149 |
-
|
| 1150 |
-
|
| 1151 |
-
|
| 1152 |
-
|
| 1153 |
-
|
| 1154 |
-
|
| 1155 |
-
|
| 1156 |
-
|
| 1157 |
-
|
| 1158 |
-
|
| 1159 |
-
|
| 1160 |
-
|
| 1161 |
-
|
| 1162 |
-
|
| 1163 |
-
|
| 1164 |
-
|
| 1165 |
-
|
| 1166 |
-
|
| 1167 |
-
|
| 1168 |
-
|
|
|
|
|
|
|
| 1169 |
|
| 1170 |
-
|
| 1171 |
-
|
| 1172 |
-
|
| 1173 |
|
| 1174 |
@abstractmethod
|
| 1175 |
def compute(
|
|
@@ -1475,91 +1667,97 @@ class InstanceMetric(StreamOperator, MetricWithConfidenceInterval):
|
|
| 1475 |
assert isinstance(fields["score_fields"], list)
|
| 1476 |
|
| 1477 |
def process(self, stream: Stream, stream_name: Optional[str] = None) -> Generator:
|
| 1478 |
-
|
| 1479 |
-
|
| 1480 |
-
|
| 1481 |
-
|
| 1482 |
-
|
| 1483 |
-
|
| 1484 |
-
|
| 1485 |
-
|
| 1486 |
-
|
| 1487 |
-
|
| 1488 |
-
|
| 1489 |
-
|
| 1490 |
-
|
| 1491 |
-
#
|
| 1492 |
-
|
| 1493 |
-
|
| 1494 |
-
|
| 1495 |
-
|
| 1496 |
-
|
| 1497 |
-
|
| 1498 |
-
|
| 1499 |
-
|
| 1500 |
-
|
| 1501 |
-
|
| 1502 |
-
|
| 1503 |
-
|
| 1504 |
-
|
| 1505 |
-
|
| 1506 |
-
|
| 1507 |
-
|
| 1508 |
-
|
| 1509 |
-
|
| 1510 |
-
|
| 1511 |
-
|
| 1512 |
-
|
| 1513 |
-
|
| 1514 |
-
|
| 1515 |
-
|
| 1516 |
-
|
| 1517 |
-
|
| 1518 |
-
|
| 1519 |
-
|
| 1520 |
-
|
| 1521 |
-
|
| 1522 |
-
|
| 1523 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1524 |
|
| 1525 |
-
|
| 1526 |
-
|
| 1527 |
-
|
| 1528 |
-
|
| 1529 |
-
|
| 1530 |
-
|
| 1531 |
-
|
| 1532 |
-
|
| 1533 |
-
|
| 1534 |
-
|
| 1535 |
-
|
| 1536 |
-
|
| 1537 |
-
|
| 1538 |
-
|
| 1539 |
-
|
| 1540 |
-
|
| 1541 |
-
|
| 1542 |
-
|
| 1543 |
-
|
| 1544 |
-
|
| 1545 |
-
|
| 1546 |
-
|
| 1547 |
-
|
| 1548 |
-
|
| 1549 |
-
|
| 1550 |
-
|
| 1551 |
-
|
| 1552 |
-
|
| 1553 |
-
|
| 1554 |
-
|
| 1555 |
-
|
|
|
|
| 1556 |
|
| 1557 |
-
|
| 1558 |
-
|
| 1559 |
|
| 1560 |
-
|
| 1561 |
-
|
| 1562 |
-
|
| 1563 |
|
| 1564 |
def compute_instance_scores(
|
| 1565 |
self, stream: Stream, stream_name: Optional[str] = None
|
|
@@ -6436,391 +6634,102 @@ RISK_TYPE_TO_CLASS: Dict[RiskType, GraniteGuardianBase] = {
|
|
| 6436 |
}
|
| 6437 |
|
| 6438 |
|
| 6439 |
-
class
|
| 6440 |
-
|
| 6441 |
-
|
| 6442 |
-
|
| 6443 |
-
|
| 6444 |
-
"subset_non_empty_execution_result",
|
| 6445 |
-
"non_empty_gold_df",
|
| 6446 |
-
"gold_sql_runtime",
|
| 6447 |
-
"predicted_sql_runtime",
|
| 6448 |
-
"pred_to_gold_runtime_ratio",
|
| 6449 |
-
"gold_error",
|
| 6450 |
-
"predicted_error",
|
| 6451 |
-
]
|
| 6452 |
-
}
|
| 6453 |
main_score = "non_empty_execution_accuracy"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 6454 |
ci_scores = [
|
| 6455 |
"execution_accuracy",
|
| 6456 |
"non_empty_execution_accuracy",
|
| 6457 |
-
"
|
|
|
|
| 6458 |
"gold_sql_runtime",
|
| 6459 |
"predicted_sql_runtime",
|
| 6460 |
]
|
| 6461 |
|
| 6462 |
-
|
| 6463 |
-
|
| 6464 |
-
|
| 6465 |
-
|
| 6466 |
-
|
| 6467 |
-
|
| 6468 |
-
|
| 6469 |
-
"""Compares two DataFrames based on row content, ignoring column names.
|
| 6470 |
-
|
| 6471 |
-
Args:
|
| 6472 |
-
df1 (pd.DataFrame): Pandas DataFrame 1 to compare.
|
| 6473 |
-
df2 (pd.DataFrame): Pandas DataFrame 2 to compare.
|
| 6474 |
-
|
| 6475 |
-
Returns:
|
| 6476 |
-
True if the DataFrames have the same ordered rows (ignoring column names),
|
| 6477 |
-
False otherwise.
|
| 6478 |
-
"""
|
| 6479 |
-
df1.fillna(0, inplace=True)
|
| 6480 |
-
df2.fillna(0, inplace=True)
|
| 6481 |
-
|
| 6482 |
-
# Compare row counts first for a quick check
|
| 6483 |
-
if df1.shape != df2.shape:
|
| 6484 |
-
return False
|
| 6485 |
-
|
| 6486 |
-
# Convert DataFrames to numpy arrays of strings to handle mixed types
|
| 6487 |
-
df1_array = df1.values.astype(str)
|
| 6488 |
-
df2_array = df2.values.astype(str)
|
| 6489 |
-
|
| 6490 |
-
# Sort each row's elements (column order independence)
|
| 6491 |
-
df1_sorted_rows = np.array([np.sort(row) for row in df1_array])
|
| 6492 |
-
df2_sorted_rows = np.array([np.sort(row) for row in df2_array])
|
| 6493 |
-
|
| 6494 |
-
# Compare the sorted rows in order
|
| 6495 |
-
return np.array_equal(df1_sorted_rows, df2_sorted_rows)
|
| 6496 |
-
|
| 6497 |
-
@staticmethod
|
| 6498 |
-
def compare_dfs_ignore_colnames_unordered_rows(df1, df2):
|
| 6499 |
-
"""Compares two DataFrames based on row content, ignoring row order and column names.
|
| 6500 |
-
|
| 6501 |
-
Args:
|
| 6502 |
-
df1 (pd.DataFrame): Pandas DataFrame 1 to compare.
|
| 6503 |
-
df2 (pd.DataFrame): Pandas DataFrame 2 to compare.
|
| 6504 |
-
|
| 6505 |
-
Returns:
|
| 6506 |
-
True if the DataFrames have the same content (ignoring column names and row order),
|
| 6507 |
-
False otherwise.
|
| 6508 |
-
"""
|
| 6509 |
-
# Compare shapes early on
|
| 6510 |
-
if df1.shape != df2.shape:
|
| 6511 |
-
return False
|
| 6512 |
-
|
| 6513 |
-
# Convert DataFrames to numpy arrays of strings (to handle mixed data types)
|
| 6514 |
-
df1_array = df1.values.astype(str)
|
| 6515 |
-
df2_array = df2.values.astype(str)
|
| 6516 |
-
|
| 6517 |
-
# Sort columns first, then sort rows
|
| 6518 |
-
df1_sorted = np.sort(np.sort(df1_array, axis=1), axis=0)
|
| 6519 |
-
df2_sorted = np.sort(np.sort(df2_array, axis=1), axis=0)
|
| 6520 |
-
|
| 6521 |
-
# Compare the sorted arrays
|
| 6522 |
-
return np.array_equal(df1_sorted, df2_sorted)
|
| 6523 |
-
|
| 6524 |
-
@staticmethod
|
| 6525 |
-
def compare_dfs_ignore_colnames_subset(df1, df2, ignore_row_order=True):
|
| 6526 |
-
"""Checks if the values of either DataFrame are a subset of the values in the other DataFrame.
|
| 6527 |
-
|
| 6528 |
-
Comparison is column order independent, and could optionally be row order independent.
|
| 6529 |
-
We interpret "subset" as follows:
|
| 6530 |
-
|
| 6531 |
-
- For each row in df1, there must be a matching (or superset) row in df2, i.e. the set of values
|
| 6532 |
-
in the df1 row is a subset of the set of values in that df2 row. Then do the same check in reverse.
|
| 6533 |
-
- If either condition (df1 is subset of df2 OR df2 is subset of df1) is satisfied, return True.
|
| 6534 |
-
|
| 6535 |
-
We treat an empty dataframe as a subset of nothing, while in theory is a subset of any dataframe.
|
| 6536 |
-
|
| 6537 |
-
Args:
|
| 6538 |
-
df1 (pd.DataFrame): Pandas DataFrame 1 to compare.
|
| 6539 |
-
df2 (pd.DataFrame): Pandas DataFrame 2 to compare.
|
| 6540 |
-
ignore_row_order (bool): If True, row order doesn't matter; if False, row order is respected.
|
| 6541 |
-
|
| 6542 |
-
Returns:
|
| 6543 |
-
bool: True if df1 is a subset of df2 or vice versa, based on the specified row-order condition.
|
| 6544 |
-
|
| 6545 |
-
"""
|
| 6546 |
-
df1_array = df1.values.astype(str)
|
| 6547 |
-
df2_array = df2.values.astype(str)
|
| 6548 |
-
|
| 6549 |
-
df1_sorted_rows = [np.sort(row) for row in df1_array]
|
| 6550 |
-
df2_sorted_rows = [np.sort(row) for row in df2_array]
|
| 6551 |
-
|
| 6552 |
-
def row_is_subset(r_small, r_big):
|
| 6553 |
-
"""Check if all elements of r_small are in r_big."""
|
| 6554 |
-
return set(r_small).issubset(set(r_big))
|
| 6555 |
-
|
| 6556 |
-
def df_is_subset_of_another(rows_small, rows_big, respect_order):
|
| 6557 |
-
"""Check if the rows_small is subset of rows_big under the given order condition."""
|
| 6558 |
-
if not rows_small:
|
| 6559 |
-
return False # DataFrame needs to be non-empty
|
| 6560 |
-
|
| 6561 |
-
# If row order matters:
|
| 6562 |
-
if respect_order:
|
| 6563 |
-
i, j = 0, 0
|
| 6564 |
-
while i < len(rows_small) and j < len(rows_big):
|
| 6565 |
-
if row_is_subset(rows_small[i], rows_big[j]):
|
| 6566 |
-
i += 1
|
| 6567 |
-
j += 1
|
| 6568 |
-
return i == len(rows_small)
|
| 6569 |
-
# Row order doesn't matter:
|
| 6570 |
-
matched_indices = set()
|
| 6571 |
-
for r_small in rows_small:
|
| 6572 |
-
found_match = False
|
| 6573 |
-
for idx, r_big in enumerate(rows_big):
|
| 6574 |
-
if idx not in matched_indices and row_is_subset(r_small, r_big):
|
| 6575 |
-
found_match = True
|
| 6576 |
-
matched_indices.add(idx)
|
| 6577 |
-
break
|
| 6578 |
-
if not found_match:
|
| 6579 |
-
return False
|
| 6580 |
-
return True
|
| 6581 |
-
|
| 6582 |
-
df1_sub_df2 = df_is_subset_of_another(
|
| 6583 |
-
df1_sorted_rows, df2_sorted_rows, not ignore_row_order
|
| 6584 |
-
)
|
| 6585 |
-
df2_sub_df1 = df_is_subset_of_another(
|
| 6586 |
-
df2_sorted_rows, df1_sorted_rows, not ignore_row_order
|
| 6587 |
)
|
| 6588 |
|
| 6589 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 6590 |
|
| 6591 |
-
|
| 6592 |
-
|
| 6593 |
-
|
| 6594 |
-
|
| 6595 |
|
| 6596 |
-
|
| 6597 |
-
|
| 6598 |
-
|
| 6599 |
-
|
|
|
|
| 6600 |
|
| 6601 |
-
Returns:
|
| 6602 |
-
a 12-tuple of
|
| 6603 |
-
1. execution_result: if df responses match
|
| 6604 |
-
2. non_empty_execution_result: if dfs are non-empty and match
|
| 6605 |
-
3. subset_non_empty_execution_result: if non-empty dfs and one is a subset of the other
|
| 6606 |
-
4. non_empty_gold_df: if gt df is non-empty
|
| 6607 |
-
5. gold_sql_runtime: ground truth query runtime
|
| 6608 |
-
6. predicted_sql_runtime: predicted query runtime
|
| 6609 |
-
7. pred_to_gold_runtime_ratio: ratio of predicted query runtime to gt query runtime
|
| 6610 |
-
8. gold_error: if gt has an error
|
| 6611 |
-
9. predicted_error: if predicted query has an error
|
| 6612 |
-
10. ground truth dataframe
|
| 6613 |
-
11. predicted query's dataframe
|
| 6614 |
-
12. error message (if any)
|
| 6615 |
-
"""
|
| 6616 |
-
import time
|
| 6617 |
|
| 6618 |
-
|
| 6619 |
-
|
|
|
|
|
|
|
| 6620 |
|
| 6621 |
-
|
| 6622 |
|
| 6623 |
-
|
| 6624 |
-
|
| 6625 |
-
|
| 6626 |
-
|
| 6627 |
-
|
| 6628 |
-
gold_res, gold_error = func_timeout(
|
| 6629 |
-
self.sql_timeout,
|
| 6630 |
-
connector.execute_query,
|
| 6631 |
-
args=(gold_sql,),
|
| 6632 |
-
)
|
| 6633 |
-
end_time = time.perf_counter()
|
| 6634 |
-
gold_sql_runtime = end_time - start_time
|
| 6635 |
-
except FunctionTimedOut as e:
|
| 6636 |
-
pred_error = f"Timeout error executing gold SQL: {e}"
|
| 6637 |
-
logger.warning(pred_error)
|
| 6638 |
-
except Exception as e:
|
| 6639 |
-
gold_error = f"Error executing gold SQL: {e}"
|
| 6640 |
-
if gold_error is not None:
|
| 6641 |
-
return (
|
| 6642 |
-
0,
|
| 6643 |
-
0,
|
| 6644 |
-
0,
|
| 6645 |
-
0,
|
| 6646 |
-
gold_sql_runtime,
|
| 6647 |
-
0,
|
| 6648 |
-
0,
|
| 6649 |
-
0,
|
| 6650 |
-
0,
|
| 6651 |
-
"",
|
| 6652 |
-
"",
|
| 6653 |
-
gold_error,
|
| 6654 |
-
)
|
| 6655 |
|
| 6656 |
-
|
| 6657 |
-
gold_res = gold_res["results"]
|
| 6658 |
-
gold_df = pd.DataFrame(gold_res)
|
| 6659 |
-
non_empty_gold_df = 0 if gold_df.empty else 1
|
| 6660 |
|
| 6661 |
-
|
| 6662 |
-
|
| 6663 |
-
|
| 6664 |
-
|
| 6665 |
-
|
| 6666 |
-
|
| 6667 |
-
|
| 6668 |
-
|
| 6669 |
-
0,
|
| 6670 |
-
0,
|
| 6671 |
-
gold_df.to_json(),
|
| 6672 |
-
"",
|
| 6673 |
-
"",
|
| 6674 |
-
)
|
| 6675 |
-
if predicted_sql.lower().strip() == gold_sql.lower().strip():
|
| 6676 |
-
return no_execution_match_result
|
| 6677 |
-
try:
|
| 6678 |
-
if sqlglot_optimized_equivalence(gold_sql, predicted_sql):
|
| 6679 |
-
return no_execution_match_result
|
| 6680 |
-
except Exception as e: # Catch specific exceptions if possible
|
| 6681 |
-
logger.info(
|
| 6682 |
-
f"Couldn't test equivalent_sqls: {e}. Treating as non-equivalent and going to test with the db."
|
| 6683 |
-
)
|
| 6684 |
|
| 6685 |
-
|
| 6686 |
-
|
| 6687 |
-
|
| 6688 |
-
|
| 6689 |
-
|
| 6690 |
-
pred_res, pred_error = func_timeout(
|
| 6691 |
-
self.sql_timeout,
|
| 6692 |
-
connector.execute_query,
|
| 6693 |
-
args=(predicted_sql,),
|
| 6694 |
-
)
|
| 6695 |
-
end_time = time.perf_counter()
|
| 6696 |
-
pred_sql_runtime = end_time - start_time
|
| 6697 |
-
except FunctionTimedOut as e:
|
| 6698 |
-
pred_error = f"Timeout error executing predicted SQL: {e}"
|
| 6699 |
-
logger.info(pred_error)
|
| 6700 |
-
except Exception as e:
|
| 6701 |
-
pred_error = f"Error executing predicted SQL: {e}"
|
| 6702 |
-
logger.info(pred_error)
|
| 6703 |
-
|
| 6704 |
-
pred_to_gold_runtime_ratio = (
|
| 6705 |
-
float(pred_sql_runtime) / gold_sql_runtime if gold_sql_runtime > 0 else 0
|
| 6706 |
)
|
| 6707 |
|
| 6708 |
-
|
| 6709 |
-
|
| 6710 |
-
0,
|
| 6711 |
-
0,
|
| 6712 |
-
0,
|
| 6713 |
-
0,
|
| 6714 |
-
gold_sql_runtime,
|
| 6715 |
-
pred_sql_runtime,
|
| 6716 |
-
pred_to_gold_runtime_ratio,
|
| 6717 |
-
0,
|
| 6718 |
-
1,
|
| 6719 |
-
"",
|
| 6720 |
-
"",
|
| 6721 |
-
pred_error,
|
| 6722 |
-
)
|
| 6723 |
-
|
| 6724 |
-
if isinstance(pred_res, dict) and "results" in pred_res:
|
| 6725 |
-
pred_res = pred_res["results"]
|
| 6726 |
-
predicted_df = pd.DataFrame(pred_res)
|
| 6727 |
-
|
| 6728 |
-
subset_non_empty_execution_result = 0
|
| 6729 |
-
non_empty_execution_result = 0
|
| 6730 |
-
if "ORDER BY" in gold_sql.upper():
|
| 6731 |
-
execution_result = (
|
| 6732 |
-
1
|
| 6733 |
-
if self.compare_dfs_ignore_colnames_ordered_rows(predicted_df, gold_df)
|
| 6734 |
-
else 0
|
| 6735 |
-
)
|
| 6736 |
-
if non_empty_gold_df:
|
| 6737 |
-
if execution_result == 1:
|
| 6738 |
-
non_empty_execution_result = 1
|
| 6739 |
-
if self.compare_dfs_ignore_colnames_subset(
|
| 6740 |
-
gold_df, predicted_df, ignore_row_order=False
|
| 6741 |
-
):
|
| 6742 |
-
subset_non_empty_execution_result = 1
|
| 6743 |
-
else:
|
| 6744 |
-
execution_result = (
|
| 6745 |
-
1
|
| 6746 |
-
if self.compare_dfs_ignore_colnames_unordered_rows(
|
| 6747 |
-
predicted_df, gold_df
|
| 6748 |
-
)
|
| 6749 |
-
else 0
|
| 6750 |
-
)
|
| 6751 |
-
if non_empty_gold_df:
|
| 6752 |
-
if execution_result == 1:
|
| 6753 |
-
non_empty_execution_result = 1
|
| 6754 |
-
if self.compare_dfs_ignore_colnames_subset(
|
| 6755 |
-
gold_df, predicted_df, ignore_row_order=True
|
| 6756 |
-
):
|
| 6757 |
-
subset_non_empty_execution_result = 1
|
| 6758 |
|
| 6759 |
-
|
| 6760 |
-
|
| 6761 |
-
|
| 6762 |
-
subset_non_empty_execution_result,
|
| 6763 |
-
non_empty_gold_df,
|
| 6764 |
-
gold_sql_runtime,
|
| 6765 |
-
pred_sql_runtime,
|
| 6766 |
-
pred_to_gold_runtime_ratio,
|
| 6767 |
-
0,
|
| 6768 |
-
0,
|
| 6769 |
-
gold_df.to_json(),
|
| 6770 |
-
predicted_df.to_json(),
|
| 6771 |
-
pred_error,
|
| 6772 |
)
|
| 6773 |
|
| 6774 |
-
|
| 6775 |
-
from .sql_utils import get_db_connector
|
| 6776 |
-
|
| 6777 |
-
predicted_sql = prediction
|
| 6778 |
-
execution_result: float = 0.0
|
| 6779 |
-
|
| 6780 |
-
if predicted_sql and predicted_sql.strip() != "":
|
| 6781 |
-
if not predicted_sql.startswith("SELECT") and "SELECT" in predicted_sql:
|
| 6782 |
-
predicted_sql = predicted_sql[predicted_sql.find("SELECT") :]
|
| 6783 |
-
if ";" in predicted_sql:
|
| 6784 |
-
predicted_sql = predicted_sql[: predicted_sql.find(";") + 1]
|
| 6785 |
-
|
| 6786 |
-
db_connector = get_db_connector(task_data["db"]["db_type"])(task_data["db"])
|
| 6787 |
-
|
| 6788 |
-
logger.debug(
|
| 6789 |
-
f"Starting to get SQL execution results over DB: {task_data['db']}"
|
| 6790 |
-
)
|
| 6791 |
-
(
|
| 6792 |
-
execution_result,
|
| 6793 |
-
non_empty_execution_result,
|
| 6794 |
-
subset_non_empty_execution_result,
|
| 6795 |
-
non_empty_gold_df,
|
| 6796 |
-
gold_sql_runtime,
|
| 6797 |
-
predicted_sql_runtime,
|
| 6798 |
-
pred_to_gold_runtime_ratio,
|
| 6799 |
-
gold_error,
|
| 6800 |
-
predicted_error,
|
| 6801 |
-
gold_df_json,
|
| 6802 |
-
predicted_df_json,
|
| 6803 |
-
error_message,
|
| 6804 |
-
) = self.get_sql_execution_results(
|
| 6805 |
-
predicted_sql, references[0], db_connector
|
| 6806 |
-
)
|
| 6807 |
-
|
| 6808 |
-
result = {
|
| 6809 |
-
"execution_accuracy": float(execution_result),
|
| 6810 |
-
"non_empty_execution_accuracy": float(non_empty_execution_result),
|
| 6811 |
-
"subset_non_empty_execution_result": float(
|
| 6812 |
-
subset_non_empty_execution_result
|
| 6813 |
-
),
|
| 6814 |
-
"non_empty_gold_df": float(non_empty_gold_df),
|
| 6815 |
-
"gold_sql_runtime": float(gold_sql_runtime),
|
| 6816 |
-
"predicted_sql_runtime": float(predicted_sql_runtime),
|
| 6817 |
-
"pred_to_gold_runtime_ratio": float(pred_to_gold_runtime_ratio),
|
| 6818 |
-
"gold_error": float(gold_error),
|
| 6819 |
-
"predicted_error": float(predicted_error),
|
| 6820 |
-
"error_message": str(error_message),
|
| 6821 |
-
"gold_df_json": str(gold_df_json),
|
| 6822 |
-
"predicted_df_json": str(predicted_df_json),
|
| 6823 |
-
}
|
| 6824 |
result["score"] = result[self.main_score]
|
| 6825 |
result["score_name"] = self.main_score
|
| 6826 |
logger.debug(f"SQL Execution Accuracy Result: {result}")
|
|
@@ -6828,34 +6737,22 @@ class SQLExecutionAccuracy(InstanceMetric):
|
|
| 6828 |
|
| 6829 |
|
| 6830 |
class SQLNonExecutionAccuracy(InstanceMetric):
|
| 6831 |
-
|
| 6832 |
-
|
| 6833 |
-
|
| 6834 |
-
|
| 6835 |
-
"sqlglot_equivalence",
|
| 6836 |
-
"sqlglot_optimized_equivalence",
|
| 6837 |
-
"sqlparse_equivalence",
|
| 6838 |
-
"sql_exact_match",
|
| 6839 |
-
"sql_syntactic_equivalence",
|
| 6840 |
-
]
|
| 6841 |
-
}
|
| 6842 |
-
main_score = "sqlglot_equivalence"
|
| 6843 |
-
ci_scores = [
|
| 6844 |
-
"sqlglot_validity",
|
| 6845 |
-
"sqlparse_validity",
|
| 6846 |
-
"sqlglot_equivalence",
|
| 6847 |
-
"sqlglot_optimized_equivalence",
|
| 6848 |
-
"sqlparse_equivalence",
|
| 6849 |
-
"sql_exact_match",
|
| 6850 |
-
"sql_syntactic_equivalence",
|
| 6851 |
]
|
|
|
|
|
|
|
|
|
|
| 6852 |
|
| 6853 |
prediction_type = "Any" # string representation is compared
|
| 6854 |
|
| 6855 |
_requirements_list = ["sqlglot", "sqlparse"]
|
| 6856 |
|
| 6857 |
def compute(self, references: List[Any], prediction: str, task_data: Dict) -> dict:
|
| 6858 |
-
from .
|
|
|
|
| 6859 |
is_sqlglot_parsable,
|
| 6860 |
is_sqlparse_parsable,
|
| 6861 |
sql_exact_match,
|
|
@@ -6864,48 +6761,45 @@ class SQLNonExecutionAccuracy(InstanceMetric):
|
|
| 6864 |
sqlparse_queries_equivalent,
|
| 6865 |
)
|
| 6866 |
|
| 6867 |
-
predicted_sql = prediction
|
| 6868 |
gold_sql = references[0]
|
| 6869 |
-
|
| 6870 |
-
if predicted_sql and predicted_sql.strip() != "":
|
| 6871 |
-
if not predicted_sql.startswith("SELECT") and "SELECT" in predicted_sql:
|
| 6872 |
-
predicted_sql = predicted_sql[predicted_sql.find("SELECT") :]
|
| 6873 |
-
if ";" in predicted_sql:
|
| 6874 |
-
predicted_sql = predicted_sql[: predicted_sql.find(";") + 1]
|
| 6875 |
|
| 6876 |
is_sqlglot_parsable = is_sqlglot_parsable(predicted_sql)
|
| 6877 |
is_sqlparse_parsable = is_sqlparse_parsable(predicted_sql)
|
| 6878 |
-
|
| 6879 |
-
|
| 6880 |
-
|
| 6881 |
-
|
| 6882 |
sqlglot_parsed_queries_equivalent(predicted_sql, gold_sql)
|
| 6883 |
if is_sqlglot_parsable
|
| 6884 |
else 0
|
| 6885 |
),
|
| 6886 |
-
|
| 6887 |
sqlglot_optimized_equivalence(predicted_sql, gold_sql)
|
| 6888 |
if is_sqlglot_parsable
|
| 6889 |
else 0
|
| 6890 |
),
|
| 6891 |
-
|
| 6892 |
sqlparse_queries_equivalent(predicted_sql, gold_sql)
|
| 6893 |
if is_sqlparse_parsable
|
| 6894 |
else 0
|
| 6895 |
),
|
| 6896 |
-
|
| 6897 |
-
|
| 6898 |
-
|
|
|
|
|
|
|
| 6899 |
any(
|
| 6900 |
-
|
| 6901 |
-
|
| 6902 |
-
|
| 6903 |
-
|
| 6904 |
-
|
| 6905 |
-
"sql_exact_match",
|
| 6906 |
]
|
| 6907 |
)
|
| 6908 |
)
|
|
|
|
|
|
|
| 6909 |
logger.debug(f"SQL Non Execution Accuracy Result: {result}")
|
| 6910 |
result["score"] = result[self.main_score]
|
| 6911 |
result["score_name"] = self.main_score
|
|
|
|
| 8 |
import warnings
|
| 9 |
from abc import ABC, abstractmethod
|
| 10 |
from collections import Counter, defaultdict
|
| 11 |
+
from dataclasses import asdict, field
|
| 12 |
+
from dataclasses import fields as dataclasses_fields
|
| 13 |
from enum import Enum
|
| 14 |
from functools import lru_cache
|
| 15 |
from typing import (
|
|
|
|
| 43 |
OptionalField,
|
| 44 |
)
|
| 45 |
from .deprecation_utils import deprecation
|
| 46 |
+
from .dict_utils import dict_get
|
| 47 |
+
from .error_utils import Documentation, UnitxtError, UnitxtWarning, error_context
|
| 48 |
from .inference import (
|
| 49 |
HFPipelineBasedInferenceEngine,
|
| 50 |
InferenceEngine,
|
|
|
|
| 66 |
from .random_utils import get_seed
|
| 67 |
from .settings_utils import get_settings
|
| 68 |
from .stream import MultiStream, Stream
|
| 69 |
+
from .text2sql_utils import SQLExecutionResult, SQLNonExecutionMetricResult
|
| 70 |
from .type_utils import isoftype, parse_type_string, to_type_string
|
| 71 |
from .types import ToolCall
|
| 72 |
from .utils import deep_copy, recursive_copy, retry_connection_with_exponential_backoff
|
|
|
|
| 385 |
return intermediates
|
| 386 |
|
| 387 |
def process(self, stream: Stream, stream_name: Optional[str] = None):
|
| 388 |
+
with error_context(
|
| 389 |
+
self,
|
| 390 |
+
stage="Evaluating Metric",
|
| 391 |
+
help="https://www.unitxt.ai/en/latest/docs/adding_metric.html",
|
| 392 |
+
):
|
| 393 |
+
instances_scores, global_scores = self.compute(stream, stream_name)
|
| 394 |
+
for i, (instance, instance_scores) in enumerate(
|
| 395 |
+
zip(stream, instances_scores)
|
| 396 |
+
):
|
| 397 |
+
previous_score = instance.get("score", {"global": {}, "instance": {}})
|
| 398 |
+
|
| 399 |
+
if i == 0:
|
| 400 |
+
for key in global_scores:
|
| 401 |
+
if is_original_key(key) and key in previous_score["global"]:
|
| 402 |
+
UnitxtWarning(
|
| 403 |
+
message=f"Metric '{key}' that has just been evaluated with value {global_scores[key]}, is already recorded "
|
| 404 |
+
f"to have value {previous_score['global'][key]} by a previous metric evaluation on this instance or stream. "
|
| 405 |
+
f"To avoid overwriting the existing value, add a score_prefix to the metric name (e.g. score_prefix='my_second_' , "
|
| 406 |
+
f"which will yield, in this case, a score named: 'my_second_{key}')",
|
| 407 |
+
additional_info_id=Documentation.MULTIPLE_METRICS_OUTPUTS,
|
| 408 |
+
)
|
| 409 |
|
| 410 |
+
global_scores = {**previous_score["global"], **global_scores}
|
| 411 |
+
instance_scores = {**previous_score["instance"], **instance_scores}
|
| 412 |
|
| 413 |
+
yield {
|
| 414 |
+
**instance,
|
| 415 |
+
"score": {"global": global_scores, "instance": instance_scores},
|
| 416 |
+
}
|
| 417 |
|
| 418 |
def compute(self, stream: Stream, stream_name: Optional[str] = None):
|
| 419 |
evaluation_inputs_stream = self._instances_stream_to_evaluation_inputs(stream)
|
|
|
|
| 463 |
return result
|
| 464 |
|
| 465 |
|
| 466 |
+
class GroupReduction(AggregationReduction[Tuple[str, Dict[str, float]]]):
|
| 467 |
+
def reduce_list(self, lst: List[Tuple[str, float]]):
|
| 468 |
+
pass
|
| 469 |
+
|
| 470 |
+
def reduce(self, intermidates: Tuple[str, Dict[str, float]]):
|
| 471 |
+
lists = {}
|
| 472 |
+
for id, intermidate in intermidates:
|
| 473 |
+
for key, val in intermidate.items():
|
| 474 |
+
if key not in lists:
|
| 475 |
+
lists[key] = []
|
| 476 |
+
lists[key].append((id, val))
|
| 477 |
+
|
| 478 |
+
result = {}
|
| 479 |
+
for key, val_list in lists.items():
|
| 480 |
+
result[key] = self.reduce_list(val_list)
|
| 481 |
+
return result
|
| 482 |
+
|
| 483 |
+
|
| 484 |
+
class GroupMean(GroupReduction):
|
| 485 |
+
def reduce_list(self, lst: List[Tuple[str, float]]):
|
| 486 |
+
return nan_mean([item[1] for item in lst])
|
| 487 |
+
|
| 488 |
+
|
| 489 |
+
class SequentialSuccess(GroupReduction):
|
| 490 |
+
threshold: float = 0.5
|
| 491 |
+
|
| 492 |
+
def reduce_list(self, lst: List[Tuple[str, float]]):
|
| 493 |
+
sorted_items = [item for _, item in sorted(lst, key=lambda x: x[0])]
|
| 494 |
+
successful = 0
|
| 495 |
+
for item in sorted_items:
|
| 496 |
+
if item > self.threshold:
|
| 497 |
+
successful += 1
|
| 498 |
+
else:
|
| 499 |
+
break
|
| 500 |
+
return successful / len(lst)
|
| 501 |
+
|
| 502 |
+
|
| 503 |
class MeanReduction(DictReduction):
|
| 504 |
def reduce_list(self, lst: List[float]):
|
| 505 |
return nan_mean(lst)
|
|
|
|
| 515 |
return float(nan_max(lst))
|
| 516 |
|
| 517 |
|
| 518 |
+
class GroupMetric(
|
| 519 |
+
MapReduceMetric[PredictionType, IntermediateType],
|
| 520 |
+
Generic[PredictionType, IntermediateType],
|
| 521 |
+
):
|
| 522 |
+
main_score: str = None
|
| 523 |
+
metric: MapReduceMetric[PredictionType, IntermediateType]
|
| 524 |
+
group_id_field: str
|
| 525 |
+
item_id_field: str
|
| 526 |
+
in_group_reduction: GroupReduction = GroupMean()
|
| 527 |
+
cross_group_reduction: GroupReduction = GroupMean()
|
| 528 |
+
n_resamples = None
|
| 529 |
+
|
| 530 |
+
def _get_group_id(self, task_data) -> str:
|
| 531 |
+
return str(dict_get(task_data, self.group_id_field))
|
| 532 |
+
|
| 533 |
+
def _get_item_id(self, task_data) -> str:
|
| 534 |
+
return str(dict_get(task_data, self.item_id_field))
|
| 535 |
+
|
| 536 |
+
def prepare(self):
|
| 537 |
+
super().prepare()
|
| 538 |
+
self.main_score = self.metric.main_score
|
| 539 |
+
|
| 540 |
+
def map_stream(
|
| 541 |
+
self,
|
| 542 |
+
evaluation_inputs_stream: Generator[
|
| 543 |
+
EvaluationInput[PredictionType], None, None
|
| 544 |
+
],
|
| 545 |
+
) -> List[Tuple[IntermediateType, str, str]]:
|
| 546 |
+
group_ids: List[str] = []
|
| 547 |
+
item_ids: List[str] = []
|
| 548 |
+
|
| 549 |
+
def multi_turn_stream(
|
| 550 |
+
evaluation_inputs_stream: Generator[
|
| 551 |
+
EvaluationInput[PredictionType], None, None
|
| 552 |
+
],
|
| 553 |
+
) -> Generator[
|
| 554 |
+
Tuple[PredictionType, List[PredictionType], Dict[str, Any]], None, None
|
| 555 |
+
]:
|
| 556 |
+
for prediction, references, task_data in evaluation_inputs_stream:
|
| 557 |
+
group_ids.append(self._get_group_id(task_data))
|
| 558 |
+
item_ids.append(self._get_item_id(task_data))
|
| 559 |
+
yield prediction, references, task_data
|
| 560 |
+
|
| 561 |
+
intermediates: List[IntermediateType] = list(
|
| 562 |
+
self.metric.map_stream(multi_turn_stream(evaluation_inputs_stream))
|
| 563 |
+
)
|
| 564 |
+
|
| 565 |
+
return list(zip(intermediates, group_ids, item_ids))
|
| 566 |
+
|
| 567 |
+
def reduce_group(self, dialog_data: Dict[str, Dict[str, Any]]):
|
| 568 |
+
return self.in_group_reduction.reduce(list(dialog_data.items()))
|
| 569 |
+
|
| 570 |
+
def reduce_one(self, intermidate: Tuple[IntermediateType, str, str]):
|
| 571 |
+
return self.metric.reduce_one(intermidate[0])
|
| 572 |
+
|
| 573 |
+
def reduce(
|
| 574 |
+
self, intermediates: List[Tuple[IntermediateType, str, str]]
|
| 575 |
+
) -> Dict[str, Any]:
|
| 576 |
+
data: Dict[str, Dict[str, Any]] = {}
|
| 577 |
+
for intermediate, group_id, item_id in intermediates:
|
| 578 |
+
if group_id not in data:
|
| 579 |
+
data[group_id] = {}
|
| 580 |
+
data[group_id][item_id] = self.metric.reduce_one(intermediate)
|
| 581 |
+
|
| 582 |
+
group_scores: Dict[str, Dict[str, Any]] = {
|
| 583 |
+
dialog_id: self.reduce_group(dialog_data)
|
| 584 |
+
for dialog_id, dialog_data in data.items()
|
| 585 |
+
}
|
| 586 |
+
|
| 587 |
+
return self.cross_group_reduction.reduce(list(group_scores.items()))
|
| 588 |
+
|
| 589 |
+
|
| 590 |
+
class MultiTurnMetric(
|
| 591 |
+
GroupMetric[PredictionType, IntermediateType],
|
| 592 |
+
Generic[PredictionType, IntermediateType],
|
| 593 |
+
):
|
| 594 |
+
group_id_field = "conversation/id"
|
| 595 |
+
item_id_field = "conversation/dialog"
|
| 596 |
+
|
| 597 |
+
def _get_item_id(self, task_data):
|
| 598 |
+
return "assistant_turn_" + str(
|
| 599 |
+
len(dict_get(task_data, self.item_id_field)) // 2
|
| 600 |
+
)
|
| 601 |
+
|
| 602 |
+
|
| 603 |
class ReductionInstanceMetric(
|
| 604 |
MapReduceMetric[PredictionType, IntermediateType],
|
| 605 |
Generic[PredictionType, IntermediateType],
|
|
|
|
| 836 |
}
|
| 837 |
|
| 838 |
|
| 839 |
+
class MultiTurnToolCallingMetric(ReductionInstanceMetric[str, Dict[str, float]]):
|
| 840 |
+
"""Compares each predicted tool call with list of references tool call."""
|
| 841 |
+
|
| 842 |
+
main_score = "argument_schema_validation"
|
| 843 |
+
reduction = MeanReduction()
|
| 844 |
+
prediction_type = List[ToolCall]
|
| 845 |
+
_requirements_list = ["jsonschema-rs"]
|
| 846 |
+
|
| 847 |
+
def prepare(self):
|
| 848 |
+
super().prepare()
|
| 849 |
+
import jsonschema_rs
|
| 850 |
+
|
| 851 |
+
self._schema = jsonschema_rs
|
| 852 |
+
|
| 853 |
+
def map(
|
| 854 |
+
self,
|
| 855 |
+
prediction: List[ToolCall],
|
| 856 |
+
references: List[List[ToolCall]],
|
| 857 |
+
task_data: Dict[str, Any],
|
| 858 |
+
) -> Dict[str, float]:
|
| 859 |
+
validation_scores = []
|
| 860 |
+
for tool_call in prediction:
|
| 861 |
+
parameters = None
|
| 862 |
+
for tool in task_data["__tools__"]:
|
| 863 |
+
if tool["function"]["name"] == tool_call["name"]:
|
| 864 |
+
parameters = tool["function"]["parameters"]
|
| 865 |
+
|
| 866 |
+
if parameters is None:
|
| 867 |
+
validation_scores.append(0.0)
|
| 868 |
+
else:
|
| 869 |
+
try:
|
| 870 |
+
self._schema.validate(
|
| 871 |
+
parameters,
|
| 872 |
+
tool_call["arguments"],
|
| 873 |
+
)
|
| 874 |
+
validation_scores.append(1.0)
|
| 875 |
+
except self._schema.ValidationError:
|
| 876 |
+
validation_scores.append(0.0)
|
| 877 |
+
|
| 878 |
+
argument_schema_validation = sum(validation_scores) / len(validation_scores)
|
| 879 |
+
|
| 880 |
+
return {
|
| 881 |
+
"argument_schema_validation": argument_schema_validation,
|
| 882 |
+
}
|
| 883 |
+
|
| 884 |
+
|
| 885 |
class MetricWithConfidenceInterval(Metric):
|
| 886 |
# The number of resamples used to estimate the confidence intervals of this metric.
|
| 887 |
# Use None to disable confidence interval computation.
|
|
|
|
| 1132 |
process_single_instances = True
|
| 1133 |
|
| 1134 |
def process(self, stream: Stream, stream_name: Optional[str] = None) -> Generator:
|
| 1135 |
+
with error_context(
|
| 1136 |
+
self,
|
| 1137 |
+
stage="Evaluating Metric",
|
| 1138 |
+
help="https://www.unitxt.ai/en/latest/docs/adding_metric.html",
|
| 1139 |
+
):
|
| 1140 |
+
references = []
|
| 1141 |
+
predictions = []
|
| 1142 |
+
task_data = []
|
| 1143 |
|
| 1144 |
+
instances = []
|
| 1145 |
|
| 1146 |
+
for instance in stream:
|
| 1147 |
+
instance = self.verify_instance(instance)
|
| 1148 |
|
| 1149 |
+
if "score" not in instance:
|
| 1150 |
+
instance["score"] = {"global": {}, "instance": {}}
|
| 1151 |
|
| 1152 |
+
instance_references, instance_prediction = (
|
| 1153 |
+
instance["references"],
|
| 1154 |
+
instance["prediction"],
|
| 1155 |
+
)
|
| 1156 |
|
| 1157 |
+
references.append(instance_references)
|
| 1158 |
+
predictions.append(instance_prediction)
|
| 1159 |
+
instances.append(instance)
|
| 1160 |
|
| 1161 |
+
instance_task_data = (
|
| 1162 |
+
instance["task_data"] if "task_data" in instance else {}
|
| 1163 |
+
)
|
| 1164 |
+
task_data.append(instance_task_data)
|
| 1165 |
+
instance_score = None
|
| 1166 |
|
| 1167 |
+
# for backward compatibility
|
| 1168 |
+
no_score_value = np.nan
|
| 1169 |
+
if self.process_single_instances:
|
| 1170 |
+
try:
|
| 1171 |
+
instance_score = self._compute(
|
| 1172 |
+
[instance_references],
|
| 1173 |
+
[instance_prediction],
|
| 1174 |
+
[instance_task_data],
|
| 1175 |
+
)
|
| 1176 |
+
except:
|
| 1177 |
+
no_score_value = None
|
| 1178 |
+
if not instance_score:
|
| 1179 |
+
instance_score = {
|
| 1180 |
+
"score": no_score_value,
|
| 1181 |
+
"score_name": self.main_score,
|
| 1182 |
+
}
|
| 1183 |
|
| 1184 |
+
if isinstance(self.main_score, str):
|
| 1185 |
+
instance_score[self.main_score] = no_score_value
|
| 1186 |
|
| 1187 |
+
instance["score"]["instance"].update(
|
| 1188 |
+
self._add_score_prefixes_to_score_dict_and_check_against_existing_scores(
|
| 1189 |
+
instance_score, instance["score"]["instance"]
|
| 1190 |
+
)
|
| 1191 |
)
|
| 1192 |
+
self._validate_references_and_prediction(references, predictions)
|
| 1193 |
+
global_score = {"num_of_instances": len(instances)}
|
|
|
|
| 1194 |
|
| 1195 |
+
result = self._compute(references, predictions, task_data)
|
| 1196 |
+
global_score.update(
|
| 1197 |
+
self._add_score_prefixes_to_score_dict_and_check_against_existing_scores(
|
| 1198 |
+
result, global_score
|
| 1199 |
+
)
|
| 1200 |
)
|
| 1201 |
+
if self.ci_scores:
|
| 1202 |
+
score_names = [
|
| 1203 |
+
self._add_score_prefix(score_name) for score_name in self.ci_scores
|
| 1204 |
+
]
|
| 1205 |
+
else:
|
| 1206 |
+
score_names = [global_score["score_name"]]
|
|
|
|
| 1207 |
|
| 1208 |
+
for score_name in score_names:
|
| 1209 |
+
confidence_interval = self.compute_global_confidence_intervals(
|
| 1210 |
+
references, predictions, task_data, score_name
|
| 1211 |
+
)
|
| 1212 |
+
global_score.update(confidence_interval)
|
| 1213 |
|
| 1214 |
+
for instance in instances:
|
| 1215 |
+
self.update_and_adjust_global_score(instance, global_score)
|
| 1216 |
+
yield instance
|
| 1217 |
|
| 1218 |
def _compute(
|
| 1219 |
self,
|
|
|
|
| 1263 |
return instance
|
| 1264 |
|
| 1265 |
def process(self, stream: Stream, stream_name: Optional[str] = None) -> Generator:
|
| 1266 |
+
with error_context(
|
| 1267 |
+
self,
|
| 1268 |
+
stage="Evaluating Metrics",
|
| 1269 |
+
help="https://www.unitxt.ai/en/latest/docs/adding_metric.html",
|
| 1270 |
+
):
|
| 1271 |
+
instances = []
|
| 1272 |
+
for instance in stream:
|
| 1273 |
+
self.verify_instance(instance)
|
| 1274 |
+
instance = self.preprocess_instance(instance)
|
| 1275 |
+
instances.append(instance)
|
| 1276 |
+
|
| 1277 |
+
predictions = [instance["prediction"] for instance in instances]
|
| 1278 |
+
references = [instance["references"] for instance in instances]
|
| 1279 |
+
task_data = [
|
| 1280 |
+
instance["task_data"] if "task_data" in instance else {}
|
| 1281 |
+
for instance in instances
|
| 1282 |
+
]
|
| 1283 |
+
self._validate_references_and_prediction(references, predictions)
|
| 1284 |
+
global_score = {"num_of_instances": len(instances)}
|
| 1285 |
+
# compute the metric over all refs and preds
|
| 1286 |
+
instance_scores = self.compute(
|
| 1287 |
+
references=references,
|
| 1288 |
+
predictions=predictions,
|
| 1289 |
+
task_data=task_data,
|
| 1290 |
+
)
|
| 1291 |
|
| 1292 |
+
# add the score and score_name fields
|
| 1293 |
+
for instance_score in instance_scores:
|
| 1294 |
+
instance_score["score"] = instance_score[self.main_score]
|
| 1295 |
+
instance_score["score_name"] = self.main_score
|
| 1296 |
|
| 1297 |
+
for instance, score in zip(instances, instance_scores):
|
| 1298 |
+
if "score" not in instance:
|
| 1299 |
+
instance["score"] = {"global": {}, "instance": {}}
|
| 1300 |
|
| 1301 |
+
instance["score"]["instance"].update(
|
| 1302 |
+
self._add_score_prefixes_to_score_dict_and_check_against_existing_scores(
|
| 1303 |
+
score, instance["score"]["instance"]
|
| 1304 |
+
)
|
| 1305 |
)
|
|
|
|
| 1306 |
|
| 1307 |
+
for reduction, fields in self.reduction_map.items():
|
| 1308 |
+
assert (
|
| 1309 |
+
reduction in self.implemented_reductions
|
| 1310 |
+
), f"Reduction {reduction} is not implemented, use one of {self.implemented_reductions}"
|
| 1311 |
+
|
| 1312 |
+
if reduction == "mean":
|
| 1313 |
+
for field_name in fields:
|
| 1314 |
+
field_name_with_prefix = self._add_score_prefix(field_name)
|
| 1315 |
+
global_score[field_name_with_prefix] = nan_mean(
|
| 1316 |
+
[
|
| 1317 |
+
instance["score"]["instance"][field_name_with_prefix]
|
| 1318 |
+
for instance in instances
|
| 1319 |
+
]
|
| 1320 |
+
)
|
| 1321 |
+
if field_name == self.main_score:
|
| 1322 |
+
global_score["score"] = global_score[field_name_with_prefix]
|
| 1323 |
+
global_score["score_name"] = (
|
| 1324 |
+
self.score_prefix + self.main_score
|
| 1325 |
+
)
|
| 1326 |
|
| 1327 |
+
ci_fields = (
|
| 1328 |
+
list(set(self.ci_scores))
|
| 1329 |
+
if self.ci_scores is not None
|
| 1330 |
+
else [self.main_score]
|
| 1331 |
+
)
|
| 1332 |
+
ci_fields_with_prefix = [
|
| 1333 |
+
self._add_score_prefix(ci_field) for ci_field in ci_fields
|
| 1334 |
+
]
|
| 1335 |
+
confidence_interval = self.score_based_confidence_interval(
|
| 1336 |
+
instances=instances, score_names=ci_fields_with_prefix
|
| 1337 |
+
)
|
| 1338 |
+
global_score.update(confidence_interval)
|
| 1339 |
+
if reduction == "weighted_win_rate":
|
| 1340 |
+
for field_name in fields:
|
| 1341 |
+
field_name_with_prefix = self._add_score_prefix(field_name)
|
| 1342 |
+
total_battles = 0
|
| 1343 |
+
wins = 0
|
| 1344 |
+
for instance in instances:
|
| 1345 |
+
s = instance["score"]["instance"][field_name_with_prefix]
|
| 1346 |
+
if s > 0:
|
| 1347 |
+
total_battles += s
|
| 1348 |
+
wins += s
|
| 1349 |
+
elif s < 0:
|
| 1350 |
+
total_battles += abs(s)
|
| 1351 |
+
else:
|
| 1352 |
+
total_battles += 2
|
| 1353 |
+
wins += 1
|
| 1354 |
+
|
| 1355 |
+
global_score[field_name_with_prefix] = wins / total_battles
|
| 1356 |
+
if field_name == self.main_score:
|
| 1357 |
+
global_score["score"] = global_score[field_name_with_prefix]
|
| 1358 |
+
global_score["score_name"] = (
|
| 1359 |
+
self.score_prefix + self.main_score
|
| 1360 |
+
)
|
| 1361 |
|
| 1362 |
+
for instance in instances:
|
| 1363 |
+
self.update_and_adjust_global_score(instance, global_score)
|
| 1364 |
+
yield instance
|
| 1365 |
|
| 1366 |
@abstractmethod
|
| 1367 |
def compute(
|
|
|
|
| 1667 |
assert isinstance(fields["score_fields"], list)
|
| 1668 |
|
| 1669 |
def process(self, stream: Stream, stream_name: Optional[str] = None) -> Generator:
|
| 1670 |
+
with error_context(
|
| 1671 |
+
self,
|
| 1672 |
+
stage="Evaluating Metrics",
|
| 1673 |
+
help="https://www.unitxt.ai/en/latest/docs/adding_metric.html",
|
| 1674 |
+
):
|
| 1675 |
+
instance_scores = self.compute_instance_scores(stream)
|
| 1676 |
+
global_score = {"num_of_instances": len(instance_scores)}
|
| 1677 |
+
for reduction_type, reduction_params in self.reduction_map.items():
|
| 1678 |
+
assert (
|
| 1679 |
+
reduction_type in self.implemented_reductions
|
| 1680 |
+
), f"Reduction {reduction_type} is not implemented, use one of {self.implemented_reductions}"
|
| 1681 |
+
|
| 1682 |
+
field_name_full_prefix = ""
|
| 1683 |
+
# used for passing to the bootstrapping, depends on whether the groups are fixed or not
|
| 1684 |
+
aggregation_function = None
|
| 1685 |
+
if reduction_type == "mean":
|
| 1686 |
+
aggregation_function = self.average_item_scores
|
| 1687 |
+
reduction_fields = list(set(reduction_params))
|
| 1688 |
+
# no group reduction, so resample instances individually
|
| 1689 |
+
scores_to_resample = instance_scores
|
| 1690 |
+
elif reduction_type == "max":
|
| 1691 |
+
aggregation_function = self.max_item_scores
|
| 1692 |
+
reduction_fields = list(set(reduction_params))
|
| 1693 |
+
# no group reduction, so resample instances individually
|
| 1694 |
+
scores_to_resample = instance_scores
|
| 1695 |
+
elif reduction_type == "group_mean":
|
| 1696 |
+
aggregation_function = self.average_item_scores
|
| 1697 |
+
self._validate_group_mean_reduction()
|
| 1698 |
+
reduction_fields = (
|
| 1699 |
+
[self.main_score]
|
| 1700 |
+
if "score_fields" not in reduction_params
|
| 1701 |
+
else list(set(reduction_params["score_fields"]))
|
| 1702 |
+
)
|
| 1703 |
+
aggregation_function_name = str(reduction_params["agg_func"][0])
|
| 1704 |
+
field_name_full_prefix = "group_" + aggregation_function_name + "_"
|
| 1705 |
+
do_resample_as_group = reduction_params["agg_func"][2]
|
| 1706 |
+
if do_resample_as_group:
|
| 1707 |
+
# append fixed_ to name because resamples the groups as fixed units
|
| 1708 |
+
field_name_full_prefix = "fixed_" + field_name_full_prefix
|
| 1709 |
+
(
|
| 1710 |
+
scores_to_resample,
|
| 1711 |
+
aggregation_function,
|
| 1712 |
+
) = self._set_up_group_mean_aggregation(
|
| 1713 |
+
instance_scores,
|
| 1714 |
+
reduction_params,
|
| 1715 |
+
reduction_fields,
|
| 1716 |
+
)
|
| 1717 |
+
else:
|
| 1718 |
+
raise ValueError(
|
| 1719 |
+
f"Reduction {reduction_type} is not supported, please specify a valid reduction method in reduction_map {self.reduction_map}."
|
| 1720 |
+
)
|
| 1721 |
|
| 1722 |
+
# calculate global scores for each reduction field
|
| 1723 |
+
for field_name in reduction_fields:
|
| 1724 |
+
field_name_full = (
|
| 1725 |
+
field_name_full_prefix + self.score_prefix + field_name
|
| 1726 |
+
)
|
| 1727 |
+
# if group resampling (3rd element of agg_func parameter) is True, then
|
| 1728 |
+
# 1. scores_to_resample are the group scores, and
|
| 1729 |
+
# 2. aggregation_function is to take the raw mean
|
| 1730 |
+
# if no group resampling (3rd element of agg_func parameter) is False, then
|
| 1731 |
+
# 1. scores_to_resample are the original instance scores, and
|
| 1732 |
+
# 2. aggregation_function is to apply the group aggregation from the instance scores
|
| 1733 |
+
# either way, the application of aggregation_function to scores_to_resample yields the global score
|
| 1734 |
+
global_score[field_name_full] = aggregation_function(
|
| 1735 |
+
scores_to_resample, self.score_prefix + field_name
|
| 1736 |
+
)
|
| 1737 |
+
if field_name == self.main_score:
|
| 1738 |
+
global_score["score"] = global_score[field_name_full]
|
| 1739 |
+
global_score["score_name"] = field_name_full
|
| 1740 |
+
|
| 1741 |
+
# need to specify which fields should have CIs calculated for them through ci_scores
|
| 1742 |
+
# (will not automatically calculate CIs for fields in reduction map)
|
| 1743 |
+
if self.ci_scores is not None:
|
| 1744 |
+
confidence_interval = self.score_based_confidence_interval(
|
| 1745 |
+
instances=scores_to_resample,
|
| 1746 |
+
score_names=[
|
| 1747 |
+
self.score_prefix + ci_score
|
| 1748 |
+
for ci_score in set(self.ci_scores)
|
| 1749 |
+
],
|
| 1750 |
+
ci_score_prefix=field_name_full_prefix,
|
| 1751 |
+
aggregation_func=aggregation_function,
|
| 1752 |
+
)
|
| 1753 |
+
global_score.update(confidence_interval)
|
| 1754 |
|
| 1755 |
+
for instance in instance_scores:
|
| 1756 |
+
self.update_and_adjust_global_score(instance, global_score)
|
| 1757 |
|
| 1758 |
+
for i, instance in enumerate(stream):
|
| 1759 |
+
instance["score"] = recursive_copy(instance_scores[i]["score"])
|
| 1760 |
+
yield instance
|
| 1761 |
|
| 1762 |
def compute_instance_scores(
|
| 1763 |
self, stream: Stream, stream_name: Optional[str] = None
|
|
|
|
| 6634 |
}
|
| 6635 |
|
| 6636 |
|
| 6637 |
+
class SQLExecutionLogicAccuracy(InstanceMetric):
|
| 6638 |
+
sql_timeout: float = 60.0
|
| 6639 |
+
prediction_type = "Any"
|
| 6640 |
+
_requirements_list = ["sqlglot", "func_timeout"]
|
| 6641 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 6642 |
main_score = "non_empty_execution_accuracy"
|
| 6643 |
+
|
| 6644 |
+
all_metrics = [
|
| 6645 |
+
f.name
|
| 6646 |
+
for f in dataclasses_fields(SQLExecutionResult)
|
| 6647 |
+
if isinstance(f.type, type) and f.type in (int, float)
|
| 6648 |
+
]
|
| 6649 |
+
|
| 6650 |
+
reduction_map = {"mean": all_metrics}
|
| 6651 |
+
|
| 6652 |
ci_scores = [
|
| 6653 |
"execution_accuracy",
|
| 6654 |
"non_empty_execution_accuracy",
|
| 6655 |
+
"subset_non_empty_execution_accuracy",
|
| 6656 |
+
"execution_accuracy_bird",
|
| 6657 |
"gold_sql_runtime",
|
| 6658 |
"predicted_sql_runtime",
|
| 6659 |
]
|
| 6660 |
|
| 6661 |
+
def compute(self, references: List[Any], prediction: str, task_data: Dict) -> dict:
|
| 6662 |
+
from .text2sql_utils import (
|
| 6663 |
+
ALL_DIALECTS,
|
| 6664 |
+
extract_sql_from_text,
|
| 6665 |
+
get_db_connector,
|
| 6666 |
+
get_sql_execution_results,
|
| 6667 |
+
replace_select_clause,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 6668 |
)
|
| 6669 |
|
| 6670 |
+
predicted_sql = extract_sql_from_text(prediction)
|
| 6671 |
+
gold_sql = references[0]
|
| 6672 |
+
dialect = task_data["db"]["db_type"]
|
| 6673 |
+
if dialect not in ALL_DIALECTS:
|
| 6674 |
+
dialect = None
|
| 6675 |
+
revised_sql = (
|
| 6676 |
+
replace_select_clause(gold_sql, predicted_sql, dialect)
|
| 6677 |
+
if gold_sql and predicted_sql
|
| 6678 |
+
else ""
|
| 6679 |
+
)
|
| 6680 |
|
| 6681 |
+
db_connector = get_db_connector(task_data["db"]["db_type"])(task_data["db"])
|
| 6682 |
+
result_obj = get_sql_execution_results(
|
| 6683 |
+
revised_sql, gold_sql, db_connector, self.sql_timeout
|
| 6684 |
+
)
|
| 6685 |
|
| 6686 |
+
result = asdict(result_obj)
|
| 6687 |
+
result["score"] = result[self.main_score]
|
| 6688 |
+
result["score_name"] = self.main_score
|
| 6689 |
+
logger.debug(f"SQL Execution Accuracy Result: {result}")
|
| 6690 |
+
return result
|
| 6691 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 6692 |
|
| 6693 |
+
class SQLExecutionAccuracy(InstanceMetric):
|
| 6694 |
+
sql_timeout: float = 60.0
|
| 6695 |
+
prediction_type = "Any"
|
| 6696 |
+
_requirements_list = ["sqlglot", "func_timeout"]
|
| 6697 |
|
| 6698 |
+
main_score = "non_empty_execution_accuracy"
|
| 6699 |
|
| 6700 |
+
all_metrics = [
|
| 6701 |
+
f.name
|
| 6702 |
+
for f in dataclasses_fields(SQLExecutionResult)
|
| 6703 |
+
if isinstance(f.type, type) and f.type in (int, float)
|
| 6704 |
+
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 6705 |
|
| 6706 |
+
reduction_map = {"mean": all_metrics}
|
|
|
|
|
|
|
|
|
|
| 6707 |
|
| 6708 |
+
ci_scores = [
|
| 6709 |
+
"execution_accuracy",
|
| 6710 |
+
"non_empty_execution_accuracy",
|
| 6711 |
+
"subset_non_empty_execution_accuracy",
|
| 6712 |
+
"execution_accuracy_bird",
|
| 6713 |
+
"gold_sql_runtime",
|
| 6714 |
+
"predicted_sql_runtime",
|
| 6715 |
+
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 6716 |
|
| 6717 |
+
def compute(self, references: List[Any], prediction: str, task_data: Dict) -> dict:
|
| 6718 |
+
from .text2sql_utils import (
|
| 6719 |
+
extract_sql_from_text,
|
| 6720 |
+
get_db_connector,
|
| 6721 |
+
get_sql_execution_results,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 6722 |
)
|
| 6723 |
|
| 6724 |
+
predicted_sql = extract_sql_from_text(prediction)
|
| 6725 |
+
gold_sql = references[0]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 6726 |
|
| 6727 |
+
db_connector = get_db_connector(task_data["db"]["db_type"])(task_data["db"])
|
| 6728 |
+
result_obj = get_sql_execution_results(
|
| 6729 |
+
predicted_sql, gold_sql, db_connector, self.sql_timeout
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 6730 |
)
|
| 6731 |
|
| 6732 |
+
result = asdict(result_obj)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 6733 |
result["score"] = result[self.main_score]
|
| 6734 |
result["score_name"] = self.main_score
|
| 6735 |
logger.debug(f"SQL Execution Accuracy Result: {result}")
|
|
|
|
| 6737 |
|
| 6738 |
|
| 6739 |
class SQLNonExecutionAccuracy(InstanceMetric):
|
| 6740 |
+
all_metrics = [
|
| 6741 |
+
f.name
|
| 6742 |
+
for f in dataclasses_fields(SQLNonExecutionMetricResult)
|
| 6743 |
+
if isinstance(f.type, type) and f.type in (int, float)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 6744 |
]
|
| 6745 |
+
reduction_map = {"mean": all_metrics}
|
| 6746 |
+
main_score = "sqlglot_equivalence"
|
| 6747 |
+
ci_scores = all_metrics
|
| 6748 |
|
| 6749 |
prediction_type = "Any" # string representation is compared
|
| 6750 |
|
| 6751 |
_requirements_list = ["sqlglot", "sqlparse"]
|
| 6752 |
|
| 6753 |
def compute(self, references: List[Any], prediction: str, task_data: Dict) -> dict:
|
| 6754 |
+
from .text2sql_utils import (
|
| 6755 |
+
extract_sql_from_text,
|
| 6756 |
is_sqlglot_parsable,
|
| 6757 |
is_sqlparse_parsable,
|
| 6758 |
sql_exact_match,
|
|
|
|
| 6761 |
sqlparse_queries_equivalent,
|
| 6762 |
)
|
| 6763 |
|
|
|
|
| 6764 |
gold_sql = references[0]
|
| 6765 |
+
predicted_sql = extract_sql_from_text(prediction)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 6766 |
|
| 6767 |
is_sqlglot_parsable = is_sqlglot_parsable(predicted_sql)
|
| 6768 |
is_sqlparse_parsable = is_sqlparse_parsable(predicted_sql)
|
| 6769 |
+
result_obj = SQLNonExecutionMetricResult(
|
| 6770 |
+
sqlglot_validity=int(is_sqlglot_parsable),
|
| 6771 |
+
sqlparse_validity=int(is_sqlparse_parsable),
|
| 6772 |
+
sqlglot_equivalence=int(
|
| 6773 |
sqlglot_parsed_queries_equivalent(predicted_sql, gold_sql)
|
| 6774 |
if is_sqlglot_parsable
|
| 6775 |
else 0
|
| 6776 |
),
|
| 6777 |
+
sqlglot_optimized_equivalence=int(
|
| 6778 |
sqlglot_optimized_equivalence(predicted_sql, gold_sql)
|
| 6779 |
if is_sqlglot_parsable
|
| 6780 |
else 0
|
| 6781 |
),
|
| 6782 |
+
sqlparse_equivalence=int(
|
| 6783 |
sqlparse_queries_equivalent(predicted_sql, gold_sql)
|
| 6784 |
if is_sqlparse_parsable
|
| 6785 |
else 0
|
| 6786 |
),
|
| 6787 |
+
sql_exact_match=int(sql_exact_match(predicted_sql, gold_sql)),
|
| 6788 |
+
sql_syntactic_equivalence=0, # will update below
|
| 6789 |
+
)
|
| 6790 |
+
|
| 6791 |
+
result_obj.sql_syntactic_equivalence = int(
|
| 6792 |
any(
|
| 6793 |
+
[
|
| 6794 |
+
result_obj.sqlglot_equivalence,
|
| 6795 |
+
result_obj.sqlglot_optimized_equivalence,
|
| 6796 |
+
result_obj.sqlparse_equivalence,
|
| 6797 |
+
result_obj.sql_exact_match,
|
|
|
|
| 6798 |
]
|
| 6799 |
)
|
| 6800 |
)
|
| 6801 |
+
|
| 6802 |
+
result = asdict(result_obj)
|
| 6803 |
logger.debug(f"SQL Non Execution Accuracy Result: {result}")
|
| 6804 |
result["score"] = result[self.main_score]
|
| 6805 |
result["score_name"] = self.main_score
|
operator.py
CHANGED
|
@@ -6,6 +6,7 @@ from pkg_resources import DistributionNotFound, VersionConflict, require
|
|
| 6 |
|
| 7 |
from .artifact import Artifact
|
| 8 |
from .dataclass import FinalField, InternalField, NonPositionalField
|
|
|
|
| 9 |
from .settings_utils import get_constants
|
| 10 |
from .stream import DynamicStream, EmptyStreamError, MultiStream, Stream
|
| 11 |
|
|
@@ -346,7 +347,8 @@ class StreamOperator(MultiStreamOperator):
|
|
| 346 |
def _process_stream(
|
| 347 |
self, stream: Stream, stream_name: Optional[str] = None
|
| 348 |
) -> Generator:
|
| 349 |
-
|
|
|
|
| 350 |
|
| 351 |
@abstractmethod
|
| 352 |
def process(self, stream: Stream, stream_name: Optional[str] = None) -> Generator:
|
|
@@ -384,12 +386,28 @@ class PagedStreamOperator(StreamOperator):
|
|
| 384 |
self, stream: Stream, stream_name: Optional[str] = None
|
| 385 |
) -> Generator:
|
| 386 |
page = []
|
|
|
|
| 387 |
for instance in stream:
|
| 388 |
page.append(instance)
|
| 389 |
if len(page) >= self.page_size:
|
| 390 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 391 |
page = []
|
| 392 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 393 |
|
| 394 |
def _process_page(
|
| 395 |
self, page: List[Dict], stream_name: Optional[str] = None
|
|
@@ -442,17 +460,9 @@ class InstanceOperator(StreamOperator):
|
|
| 442 |
def _process_stream(
|
| 443 |
self, stream: Stream, stream_name: Optional[str] = None
|
| 444 |
) -> Generator:
|
| 445 |
-
|
| 446 |
-
|
| 447 |
-
for _index, instance in enumerate(stream):
|
| 448 |
yield self._process_instance(instance, stream_name)
|
| 449 |
-
except Exception as e:
|
| 450 |
-
if _index is None:
|
| 451 |
-
raise e
|
| 452 |
-
else:
|
| 453 |
-
raise ValueError(
|
| 454 |
-
f"Error processing instance '{_index}' from stream '{stream_name}' in {self.__class__.__name__} due to the exception above."
|
| 455 |
-
) from e
|
| 456 |
|
| 457 |
def _process_instance(
|
| 458 |
self, instance: Dict[str, Any], stream_name: Optional[str] = None
|
|
|
|
| 6 |
|
| 7 |
from .artifact import Artifact
|
| 8 |
from .dataclass import FinalField, InternalField, NonPositionalField
|
| 9 |
+
from .error_utils import error_context
|
| 10 |
from .settings_utils import get_constants
|
| 11 |
from .stream import DynamicStream, EmptyStreamError, MultiStream, Stream
|
| 12 |
|
|
|
|
| 347 |
def _process_stream(
|
| 348 |
self, stream: Stream, stream_name: Optional[str] = None
|
| 349 |
) -> Generator:
|
| 350 |
+
with error_context(self, stream=stream_name):
|
| 351 |
+
yield from self.process(stream, stream_name)
|
| 352 |
|
| 353 |
@abstractmethod
|
| 354 |
def process(self, stream: Stream, stream_name: Optional[str] = None) -> Generator:
|
|
|
|
| 386 |
self, stream: Stream, stream_name: Optional[str] = None
|
| 387 |
) -> Generator:
|
| 388 |
page = []
|
| 389 |
+
page_number = 0
|
| 390 |
for instance in stream:
|
| 391 |
page.append(instance)
|
| 392 |
if len(page) >= self.page_size:
|
| 393 |
+
with error_context(
|
| 394 |
+
self,
|
| 395 |
+
stream=stream_name,
|
| 396 |
+
page=page_number,
|
| 397 |
+
page_size=len(page),
|
| 398 |
+
):
|
| 399 |
+
yield from self.process(page, stream_name)
|
| 400 |
page = []
|
| 401 |
+
page_number += 1
|
| 402 |
+
if page: # Handle any remaining instances in the last partial page
|
| 403 |
+
with error_context(
|
| 404 |
+
self,
|
| 405 |
+
stream=stream_name,
|
| 406 |
+
page=page_number,
|
| 407 |
+
page_size=len(page),
|
| 408 |
+
final_page=True,
|
| 409 |
+
):
|
| 410 |
+
yield from self._process_page(page, stream_name)
|
| 411 |
|
| 412 |
def _process_page(
|
| 413 |
self, page: List[Dict], stream_name: Optional[str] = None
|
|
|
|
| 460 |
def _process_stream(
|
| 461 |
self, stream: Stream, stream_name: Optional[str] = None
|
| 462 |
) -> Generator:
|
| 463 |
+
for _index, instance in enumerate(stream):
|
| 464 |
+
with error_context(self, stream=stream_name, instance=_index):
|
|
|
|
| 465 |
yield self._process_instance(instance, stream_name)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 466 |
|
| 467 |
def _process_instance(
|
| 468 |
self, instance: Dict[str, Any], stream_name: Optional[str] = None
|
operators.py
CHANGED
|
@@ -67,7 +67,7 @@ from .artifact import Artifact, fetch_artifact
|
|
| 67 |
from .dataclass import NonPositionalField, OptionalField
|
| 68 |
from .deprecation_utils import deprecation
|
| 69 |
from .dict_utils import dict_delete, dict_get, dict_set, is_subpath
|
| 70 |
-
from .error_utils import UnitxtError
|
| 71 |
from .generator_utils import ReusableGenerator
|
| 72 |
from .operator import (
|
| 73 |
InstanceOperator,
|
|
@@ -309,7 +309,9 @@ def recursive_key_value_replace(data, target_key, value_map, value_remove=None):
|
|
| 309 |
if not isinstance(item, dict) and item not in value_remove
|
| 310 |
]
|
| 311 |
elif isinstance(value, dict):
|
| 312 |
-
|
|
|
|
|
|
|
| 313 |
elif value in value_remove:
|
| 314 |
keys_to_delete.append(key)
|
| 315 |
elif value in value_map:
|
|
@@ -436,6 +438,7 @@ class InstanceFieldOperator(InstanceOperator):
|
|
| 436 |
field_to_field: Optional[Union[List[List[str]], Dict[str, str]]] = None
|
| 437 |
use_query: Optional[bool] = None
|
| 438 |
process_every_value: bool = False
|
|
|
|
| 439 |
get_default: Any = None
|
| 440 |
not_exist_ok: bool = False
|
| 441 |
not_exist_do_nothing: bool = False
|
|
@@ -521,7 +524,7 @@ class InstanceFieldOperator(InstanceOperator):
|
|
| 521 |
) -> Dict[str, Any]:
|
| 522 |
self.verify_field_definition()
|
| 523 |
for from_field, to_field in self._field_to_field:
|
| 524 |
-
|
| 525 |
old_value = dict_get(
|
| 526 |
instance,
|
| 527 |
from_field,
|
|
@@ -532,11 +535,8 @@ class InstanceFieldOperator(InstanceOperator):
|
|
| 532 |
if self.not_exist_do_nothing:
|
| 533 |
continue
|
| 534 |
old_value = self.get_default
|
| 535 |
-
|
| 536 |
-
|
| 537 |
-
f"Failed to get '{from_field}' from instance due to the exception above."
|
| 538 |
-
) from e
|
| 539 |
-
try:
|
| 540 |
if self.process_every_value:
|
| 541 |
new_value = [
|
| 542 |
self.process_instance_value(value, instance)
|
|
@@ -544,15 +544,13 @@ class InstanceFieldOperator(InstanceOperator):
|
|
| 544 |
]
|
| 545 |
else:
|
| 546 |
new_value = self.process_instance_value(old_value, instance)
|
| 547 |
-
|
| 548 |
-
raise ValueError(
|
| 549 |
-
f"Failed to process field '{from_field}' from instance due to the exception above."
|
| 550 |
-
) from e
|
| 551 |
dict_set(
|
| 552 |
instance,
|
| 553 |
to_field,
|
| 554 |
new_value,
|
| 555 |
not_exist_ok=True,
|
|
|
|
| 556 |
)
|
| 557 |
return instance
|
| 558 |
|
|
@@ -610,11 +608,29 @@ class Rename(FieldOperator):
|
|
| 610 |
return res
|
| 611 |
|
| 612 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 613 |
@deprecation(version="2.0.0", alternative=Rename)
|
| 614 |
class RenameFields(Rename):
|
| 615 |
pass
|
| 616 |
|
| 617 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 618 |
class AddConstant(FieldOperator):
|
| 619 |
"""Adds a constant, being argument 'add', to the processed value.
|
| 620 |
|
|
@@ -1200,9 +1216,10 @@ class ApplyOperatorsField(InstanceOperator):
|
|
| 1200 |
) -> Dict[str, Any]:
|
| 1201 |
operator_names = instance.get(self.operators_field)
|
| 1202 |
if operator_names is None:
|
| 1203 |
-
|
| 1204 |
-
|
| 1205 |
-
|
|
|
|
| 1206 |
operator_names = self.default_operators
|
| 1207 |
|
| 1208 |
if isinstance(operator_names, str):
|
|
@@ -1436,7 +1453,7 @@ class ExecuteExpression(InstanceOperator, ComputeExpressionMixin):
|
|
| 1436 |
def process(
|
| 1437 |
self, instance: Dict[str, Any], stream_name: Optional[str] = None
|
| 1438 |
) -> Dict[str, Any]:
|
| 1439 |
-
instance
|
| 1440 |
return instance
|
| 1441 |
|
| 1442 |
|
|
@@ -1821,54 +1838,58 @@ class ApplyMetric(StreamOperator, ArtifactFetcherMixin):
|
|
| 1821 |
|
| 1822 |
# to be populated only when two or more metrics
|
| 1823 |
accumulated_scores = []
|
|
|
|
|
|
|
| 1824 |
|
| 1825 |
-
|
| 1826 |
-
|
| 1827 |
-
|
| 1828 |
-
|
| 1829 |
-
raise RuntimeError(
|
| 1830 |
-
f"Missing metric names in field '{self.metric_field}' and instance '{first_instance}'."
|
| 1831 |
-
)
|
| 1832 |
-
|
| 1833 |
-
if isinstance(metric_names, str):
|
| 1834 |
-
metric_names = [metric_names]
|
| 1835 |
-
|
| 1836 |
-
metrics_list = []
|
| 1837 |
-
for metric_name in metric_names:
|
| 1838 |
-
metric = self.get_artifact(metric_name)
|
| 1839 |
-
if isinstance(metric, MetricsList):
|
| 1840 |
-
metrics_list.extend(list(metric.items))
|
| 1841 |
-
elif isinstance(metric, Metric):
|
| 1842 |
-
metrics_list.append(metric)
|
| 1843 |
-
else:
|
| 1844 |
-
raise ValueError(
|
| 1845 |
-
f"Operator {metric_name} must be a Metric or MetricsList"
|
| 1846 |
)
|
| 1847 |
|
| 1848 |
-
|
| 1849 |
-
|
| 1850 |
-
|
| 1851 |
-
|
| 1852 |
-
|
| 1853 |
-
|
| 1854 |
-
|
| 1855 |
-
|
| 1856 |
-
|
| 1857 |
-
|
| 1858 |
-
|
| 1859 |
-
|
| 1860 |
-
|
| 1861 |
-
|
| 1862 |
-
|
|
|
|
|
|
|
|
|
|
| 1863 |
)
|
| 1864 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1865 |
|
| 1866 |
-
|
| 1867 |
|
| 1868 |
-
|
| 1869 |
-
|
| 1870 |
-
|
| 1871 |
-
|
| 1872 |
|
| 1873 |
yield from multi_stream["tmp"]
|
| 1874 |
|
|
|
|
| 67 |
from .dataclass import NonPositionalField, OptionalField
|
| 68 |
from .deprecation_utils import deprecation
|
| 69 |
from .dict_utils import dict_delete, dict_get, dict_set, is_subpath
|
| 70 |
+
from .error_utils import UnitxtError, error_context
|
| 71 |
from .generator_utils import ReusableGenerator
|
| 72 |
from .operator import (
|
| 73 |
InstanceOperator,
|
|
|
|
| 309 |
if not isinstance(item, dict) and item not in value_remove
|
| 310 |
]
|
| 311 |
elif isinstance(value, dict):
|
| 312 |
+
recursive_key_value_replace(
|
| 313 |
+
value, target_key, value_map, value_remove
|
| 314 |
+
)
|
| 315 |
elif value in value_remove:
|
| 316 |
keys_to_delete.append(key)
|
| 317 |
elif value in value_map:
|
|
|
|
| 438 |
field_to_field: Optional[Union[List[List[str]], Dict[str, str]]] = None
|
| 439 |
use_query: Optional[bool] = None
|
| 440 |
process_every_value: bool = False
|
| 441 |
+
set_every_value: bool = NonPositionalField(default=False)
|
| 442 |
get_default: Any = None
|
| 443 |
not_exist_ok: bool = False
|
| 444 |
not_exist_do_nothing: bool = False
|
|
|
|
| 524 |
) -> Dict[str, Any]:
|
| 525 |
self.verify_field_definition()
|
| 526 |
for from_field, to_field in self._field_to_field:
|
| 527 |
+
with error_context(self, field=from_field, action="Read Field"):
|
| 528 |
old_value = dict_get(
|
| 529 |
instance,
|
| 530 |
from_field,
|
|
|
|
| 535 |
if self.not_exist_do_nothing:
|
| 536 |
continue
|
| 537 |
old_value = self.get_default
|
| 538 |
+
|
| 539 |
+
with error_context(self, field=from_field, action="Process Field"):
|
|
|
|
|
|
|
|
|
|
| 540 |
if self.process_every_value:
|
| 541 |
new_value = [
|
| 542 |
self.process_instance_value(value, instance)
|
|
|
|
| 544 |
]
|
| 545 |
else:
|
| 546 |
new_value = self.process_instance_value(old_value, instance)
|
| 547 |
+
|
|
|
|
|
|
|
|
|
|
| 548 |
dict_set(
|
| 549 |
instance,
|
| 550 |
to_field,
|
| 551 |
new_value,
|
| 552 |
not_exist_ok=True,
|
| 553 |
+
set_multiple=self.set_every_value,
|
| 554 |
)
|
| 555 |
return instance
|
| 556 |
|
|
|
|
| 608 |
return res
|
| 609 |
|
| 610 |
|
| 611 |
+
class Move(InstanceOperator):
|
| 612 |
+
field: str
|
| 613 |
+
to_field: str
|
| 614 |
+
|
| 615 |
+
def process(
|
| 616 |
+
self, instance: Dict[str, Any], stream_name: Optional[str] = None
|
| 617 |
+
) -> Dict[str, Any]:
|
| 618 |
+
value = dict_get(instance, self.field)
|
| 619 |
+
dict_delete(instance, self.field)
|
| 620 |
+
dict_set(instance, self.to_field, value=value)
|
| 621 |
+
return instance
|
| 622 |
+
|
| 623 |
+
|
| 624 |
@deprecation(version="2.0.0", alternative=Rename)
|
| 625 |
class RenameFields(Rename):
|
| 626 |
pass
|
| 627 |
|
| 628 |
|
| 629 |
+
class BytesToString(FieldOperator):
|
| 630 |
+
def process_value(self, value: Any) -> Any:
|
| 631 |
+
return str(value)
|
| 632 |
+
|
| 633 |
+
|
| 634 |
class AddConstant(FieldOperator):
|
| 635 |
"""Adds a constant, being argument 'add', to the processed value.
|
| 636 |
|
|
|
|
| 1216 |
) -> Dict[str, Any]:
|
| 1217 |
operator_names = instance.get(self.operators_field)
|
| 1218 |
if operator_names is None:
|
| 1219 |
+
if self.default_operators is None:
|
| 1220 |
+
raise ValueError(
|
| 1221 |
+
f"No operators found in field '{self.operators_field}', and no default operators provided."
|
| 1222 |
+
)
|
| 1223 |
operator_names = self.default_operators
|
| 1224 |
|
| 1225 |
if isinstance(operator_names, str):
|
|
|
|
| 1453 |
def process(
|
| 1454 |
self, instance: Dict[str, Any], stream_name: Optional[str] = None
|
| 1455 |
) -> Dict[str, Any]:
|
| 1456 |
+
dict_set(instance, self.to_field, self.compute_expression(instance))
|
| 1457 |
return instance
|
| 1458 |
|
| 1459 |
|
|
|
|
| 1838 |
|
| 1839 |
# to be populated only when two or more metrics
|
| 1840 |
accumulated_scores = []
|
| 1841 |
+
with error_context(self, stage="Load Metrics"):
|
| 1842 |
+
first_instance = stream.peek()
|
| 1843 |
|
| 1844 |
+
metric_names = first_instance.get(self.metric_field, [])
|
| 1845 |
+
if not metric_names:
|
| 1846 |
+
raise RuntimeError(
|
| 1847 |
+
f"Missing metric names in field '{self.metric_field}' and instance '{first_instance}'."
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1848 |
)
|
| 1849 |
|
| 1850 |
+
if isinstance(metric_names, str):
|
| 1851 |
+
metric_names = [metric_names]
|
| 1852 |
+
|
| 1853 |
+
metrics_list = []
|
| 1854 |
+
for metric_name in metric_names:
|
| 1855 |
+
metric = self.get_artifact(metric_name)
|
| 1856 |
+
if isinstance(metric, MetricsList):
|
| 1857 |
+
metrics_list.extend(list(metric.items))
|
| 1858 |
+
elif isinstance(metric, Metric):
|
| 1859 |
+
metrics_list.append(metric)
|
| 1860 |
+
else:
|
| 1861 |
+
raise ValueError(
|
| 1862 |
+
f"Operator {metric_name} must be a Metric or MetricsList"
|
| 1863 |
+
)
|
| 1864 |
+
with error_context(self, stage="Setup Metrics"):
|
| 1865 |
+
for metric in metrics_list:
|
| 1866 |
+
metric.set_confidence_interval_calculation(
|
| 1867 |
+
self.calc_confidence_intervals
|
| 1868 |
)
|
| 1869 |
+
# Each metric operator computes its score and then sets the main score, overwriting
|
| 1870 |
+
# the previous main score value (if any). So, we need to reverse the order of the listed metrics.
|
| 1871 |
+
# This will cause the first listed metric to run last, and the main score will be set
|
| 1872 |
+
# by the first listed metric (as desired).
|
| 1873 |
+
metrics_list = list(reversed(metrics_list))
|
| 1874 |
+
|
| 1875 |
+
for i, metric in enumerate(metrics_list):
|
| 1876 |
+
if i == 0: # first metric
|
| 1877 |
+
multi_stream = MultiStream({"tmp": stream})
|
| 1878 |
+
else: # metrics with previous scores
|
| 1879 |
+
reusable_generator = ReusableGenerator(
|
| 1880 |
+
generator=update_scores_of_stream_instances,
|
| 1881 |
+
gen_kwargs={"stream": stream, "scores": accumulated_scores},
|
| 1882 |
+
)
|
| 1883 |
+
multi_stream = MultiStream.from_generators(
|
| 1884 |
+
{"tmp": reusable_generator}
|
| 1885 |
+
)
|
| 1886 |
|
| 1887 |
+
multi_stream = metric(multi_stream)
|
| 1888 |
|
| 1889 |
+
if i < len(metrics_list) - 1: # last metric
|
| 1890 |
+
accumulated_scores = []
|
| 1891 |
+
for inst in multi_stream["tmp"]:
|
| 1892 |
+
accumulated_scores.append(recursive_copy(inst["score"]))
|
| 1893 |
|
| 1894 |
yield from multi_stream["tmp"]
|
| 1895 |
|
processors.py
CHANGED
|
@@ -98,6 +98,16 @@ class ExtractWithRegex(RegexParser):
|
|
| 98 |
return ""
|
| 99 |
|
| 100 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 101 |
class ListToEmptyEntitiesTuples(FieldOperator):
|
| 102 |
def process_value(self, lst: Any) -> Any:
|
| 103 |
try:
|
|
@@ -286,7 +296,7 @@ class StringOrNotString(StringEquals):
|
|
| 286 |
|
| 287 |
class ExtractMtBenchRatingJudgment(FieldOperator):
|
| 288 |
def process_value(self, text: Any) -> Any:
|
| 289 |
-
match = re.search(r"\[\[([\d]+\.?[\d]*)
|
| 290 |
try:
|
| 291 |
return float(match.group(1)) / 10
|
| 292 |
except:
|
|
|
|
| 98 |
return ""
|
| 99 |
|
| 100 |
|
| 101 |
+
class GroupDictWithRegex(FieldOperator):
|
| 102 |
+
pattern: str
|
| 103 |
+
|
| 104 |
+
def process_value(self, value: Any) -> Any:
|
| 105 |
+
match = re.match(self.pattern, value)
|
| 106 |
+
if match:
|
| 107 |
+
return match.groupdict()
|
| 108 |
+
return {}
|
| 109 |
+
|
| 110 |
+
|
| 111 |
class ListToEmptyEntitiesTuples(FieldOperator):
|
| 112 |
def process_value(self, lst: Any) -> Any:
|
| 113 |
try:
|
|
|
|
| 296 |
|
| 297 |
class ExtractMtBenchRatingJudgment(FieldOperator):
|
| 298 |
def process_value(self, text: Any) -> Any:
|
| 299 |
+
match = re.search(r"\[\[([\s*\d]+\.?[\d]*\s*)(/\s*10)?\s*\]\]", text)
|
| 300 |
try:
|
| 301 |
return float(match.group(1)) / 10
|
| 302 |
except:
|
schema.py
CHANGED
|
@@ -59,7 +59,7 @@ def get_schema(stream_name):
|
|
| 59 |
def load_chat_source(chat_str):
|
| 60 |
chat = json.loads(chat_str)
|
| 61 |
for turn in chat:
|
| 62 |
-
if isinstance(turn["content"], list):
|
| 63 |
for content in turn["content"]:
|
| 64 |
if content["type"] == "image_url":
|
| 65 |
content["image_url"]["url"] = ImageDataString(
|
|
|
|
| 59 |
def load_chat_source(chat_str):
|
| 60 |
chat = json.loads(chat_str)
|
| 61 |
for turn in chat:
|
| 62 |
+
if "content" in turn and isinstance(turn["content"], list):
|
| 63 |
for content in turn["content"]:
|
| 64 |
if content["type"] == "image_url":
|
| 65 |
content["image_url"]["url"] = ImageDataString(
|
serializers.py
CHANGED
|
@@ -9,6 +9,7 @@ from .operators import InstanceFieldOperator
|
|
| 9 |
from .settings_utils import get_constants
|
| 10 |
from .type_utils import isoftype, to_type_string
|
| 11 |
from .types import (
|
|
|
|
| 12 |
Dialog,
|
| 13 |
Document,
|
| 14 |
Image,
|
|
@@ -75,7 +76,22 @@ class DialogSerializer(SingleTypeSerializer):
|
|
| 75 |
|
| 76 |
def serialize(self, value: Dialog, instance: Dict[str, Any]) -> str:
|
| 77 |
# Convert the Dialog into a string representation, typically combining roles and content
|
| 78 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 79 |
|
| 80 |
|
| 81 |
class NumberSerializer(SingleTypeSerializer):
|
|
@@ -225,7 +241,7 @@ class SQLDatabaseAsSchemaSerializer(SingleTypeSerializer):
|
|
| 225 |
serialized_type = SQLDatabase
|
| 226 |
|
| 227 |
def serialize(self, value: SQLDatabase, instance: Dict[str, Any]) -> str:
|
| 228 |
-
from .
|
| 229 |
|
| 230 |
connector = get_db_connector(value["db_type"])(value)
|
| 231 |
return connector.get_table_schema()
|
|
|
|
| 9 |
from .settings_utils import get_constants
|
| 10 |
from .type_utils import isoftype, to_type_string
|
| 11 |
from .types import (
|
| 12 |
+
Conversation,
|
| 13 |
Dialog,
|
| 14 |
Document,
|
| 15 |
Image,
|
|
|
|
| 76 |
|
| 77 |
def serialize(self, value: Dialog, instance: Dict[str, Any]) -> str:
|
| 78 |
# Convert the Dialog into a string representation, typically combining roles and content
|
| 79 |
+
turns = []
|
| 80 |
+
for turn in value:
|
| 81 |
+
turn_str = f"{turn['role']}: "
|
| 82 |
+
if "content" in turn:
|
| 83 |
+
turn_str += str(turn["content"])
|
| 84 |
+
if "tool_calls" in turn:
|
| 85 |
+
turn_str += "\n" + json.dumps(turn["tool_calls"])
|
| 86 |
+
turns.append(turn_str)
|
| 87 |
+
return "\n".join(turns)
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
class ConversationSerializer(DialogSerializer):
|
| 91 |
+
serialized_type = Conversation
|
| 92 |
+
|
| 93 |
+
def serialize(self, value: Conversation, instance: Dict[str, Any]) -> str:
|
| 94 |
+
return super().serialize(value["dialog"], instance)
|
| 95 |
|
| 96 |
|
| 97 |
class NumberSerializer(SingleTypeSerializer):
|
|
|
|
| 241 |
serialized_type = SQLDatabase
|
| 242 |
|
| 243 |
def serialize(self, value: SQLDatabase, instance: Dict[str, Any]) -> str:
|
| 244 |
+
from .text2sql_utils import get_db_connector
|
| 245 |
|
| 246 |
connector = get_db_connector(value["db_type"])(value)
|
| 247 |
return connector.get_table_schema()
|
settings_utils.py
CHANGED
|
@@ -1,6 +1,7 @@
|
|
| 1 |
import importlib.metadata
|
| 2 |
import importlib.util
|
| 3 |
import os
|
|
|
|
| 4 |
from contextlib import contextmanager
|
| 5 |
|
| 6 |
from .version import version
|
|
@@ -177,6 +178,9 @@ if Constants.is_uninitilized():
|
|
| 177 |
constants.dataset_url = "unitxt/data"
|
| 178 |
constants.metric_url = "unitxt/metric"
|
| 179 |
constants.version = version
|
|
|
|
|
|
|
|
|
|
| 180 |
constants.catalog_hierarchy_sep = "."
|
| 181 |
constants.env_local_catalogs_paths_sep = ":"
|
| 182 |
constants.non_registered_files = [
|
|
|
|
| 1 |
import importlib.metadata
|
| 2 |
import importlib.util
|
| 3 |
import os
|
| 4 |
+
import sys
|
| 5 |
from contextlib import contextmanager
|
| 6 |
|
| 7 |
from .version import version
|
|
|
|
| 178 |
constants.dataset_url = "unitxt/data"
|
| 179 |
constants.metric_url = "unitxt/metric"
|
| 180 |
constants.version = version
|
| 181 |
+
constants.python = (
|
| 182 |
+
f"{sys.version_info.major}.{sys.version_info.minor}.{sys.version_info.micro}"
|
| 183 |
+
)
|
| 184 |
constants.catalog_hierarchy_sep = "."
|
| 185 |
constants.env_local_catalogs_paths_sep = ":"
|
| 186 |
constants.non_registered_files = [
|
struct_data_operators.py
CHANGED
|
@@ -23,6 +23,7 @@ For key-value pairs, expected input format is:
|
|
| 23 |
{"key1": "value1", "key2": value2, "key3": "value3"}
|
| 24 |
"""
|
| 25 |
|
|
|
|
| 26 |
import json
|
| 27 |
import random
|
| 28 |
from abc import ABC, abstractmethod
|
|
@@ -754,11 +755,40 @@ class LoadJson(FieldOperator):
|
|
| 754 |
return json.loads(value, strict=False)
|
| 755 |
|
| 756 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 757 |
class ToolCallPostProcessor(FieldOperator):
|
| 758 |
failure_value: Any = None
|
| 759 |
allow_failure: bool = False
|
| 760 |
|
| 761 |
def process_value(self, value: str) -> ToolCall:
|
|
|
|
|
|
|
|
|
|
| 762 |
if self.allow_failure:
|
| 763 |
try:
|
| 764 |
result = json.loads(value)
|
|
@@ -776,6 +806,25 @@ class ToolCallPostProcessor(FieldOperator):
|
|
| 776 |
return result
|
| 777 |
|
| 778 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 779 |
class DumpJson(FieldOperator):
|
| 780 |
def process_value(self, value: str) -> str:
|
| 781 |
return json.dumps(value)
|
|
|
|
| 23 |
{"key1": "value1", "key2": value2, "key3": "value3"}
|
| 24 |
"""
|
| 25 |
|
| 26 |
+
import ast
|
| 27 |
import json
|
| 28 |
import random
|
| 29 |
from abc import ABC, abstractmethod
|
|
|
|
| 755 |
return json.loads(value, strict=False)
|
| 756 |
|
| 757 |
|
| 758 |
+
class PythonCallProcessor(FieldOperator):
|
| 759 |
+
def process_value(self, value: str) -> ToolCall:
|
| 760 |
+
expr = ast.parse(value, mode="eval").body
|
| 761 |
+
function = expr.func.id
|
| 762 |
+
args = {}
|
| 763 |
+
for kw in expr.keywords:
|
| 764 |
+
args[kw.arg] = ast.literal_eval(kw.value)
|
| 765 |
+
# Handle positional args, if any
|
| 766 |
+
if expr.args:
|
| 767 |
+
args["_args"] = [ast.literal_eval(arg) for arg in expr.args]
|
| 768 |
+
return {"name": function, "arguments": args}
|
| 769 |
+
|
| 770 |
+
|
| 771 |
+
def extract_possible_json_str(text):
|
| 772 |
+
"""Extract potential JSON string from text by finding outermost braces/brackets."""
|
| 773 |
+
# Find first opening delimiter
|
| 774 |
+
start_positions = [pos for pos in [text.find("{"), text.find("[")] if pos != -1]
|
| 775 |
+
start = min(start_positions) if start_positions else 0
|
| 776 |
+
|
| 777 |
+
# Find last closing delimiter
|
| 778 |
+
end_positions = [pos for pos in [text.rfind("}"), text.rfind("]")] if pos != -1]
|
| 779 |
+
end = max(end_positions) if end_positions else len(text) - 1
|
| 780 |
+
|
| 781 |
+
return text[start : end + 1]
|
| 782 |
+
|
| 783 |
+
|
| 784 |
class ToolCallPostProcessor(FieldOperator):
|
| 785 |
failure_value: Any = None
|
| 786 |
allow_failure: bool = False
|
| 787 |
|
| 788 |
def process_value(self, value: str) -> ToolCall:
|
| 789 |
+
value = extract_possible_json_str(
|
| 790 |
+
value
|
| 791 |
+
) # clear tokens such as <tool_call> focusing on the call json itself
|
| 792 |
if self.allow_failure:
|
| 793 |
try:
|
| 794 |
result = json.loads(value)
|
|
|
|
| 806 |
return result
|
| 807 |
|
| 808 |
|
| 809 |
+
class MultipleToolCallPostProcessor(FieldOperator):
|
| 810 |
+
failure_value: Any = None
|
| 811 |
+
allow_failure: bool = False
|
| 812 |
+
|
| 813 |
+
def process_value(self, value: str) -> List[ToolCall]:
|
| 814 |
+
if self.allow_failure:
|
| 815 |
+
try:
|
| 816 |
+
result = json.loads(value)
|
| 817 |
+
except json.JSONDecodeError:
|
| 818 |
+
return self.failure_value
|
| 819 |
+
else:
|
| 820 |
+
result = json.loads(value, strict=False)
|
| 821 |
+
if isoftype(result, List[ToolCall]):
|
| 822 |
+
return result
|
| 823 |
+
if not isoftype(result, ToolCall):
|
| 824 |
+
return self.failure_value
|
| 825 |
+
return [result]
|
| 826 |
+
|
| 827 |
+
|
| 828 |
class DumpJson(FieldOperator):
|
| 829 |
def process_value(self, value: str) -> str:
|
| 830 |
return json.dumps(value)
|
task.py
CHANGED
|
@@ -3,7 +3,7 @@ from typing import Any, Dict, List, Optional, Union
|
|
| 3 |
|
| 4 |
from .artifact import fetch_artifact
|
| 5 |
from .deprecation_utils import deprecation
|
| 6 |
-
from .error_utils import Documentation, UnitxtError, UnitxtWarning
|
| 7 |
from .logging_utils import get_logger
|
| 8 |
from .metrics import MetricsList
|
| 9 |
from .operator import InstanceOperator
|
|
@@ -285,13 +285,18 @@ class Task(InstanceOperator, ArtifactFetcherMixin):
|
|
| 285 |
) -> Dict[str, Any]:
|
| 286 |
instance = self.set_default_values(instance)
|
| 287 |
|
| 288 |
-
|
| 289 |
-
self
|
| 290 |
-
|
| 291 |
-
|
| 292 |
-
|
| 293 |
-
|
| 294 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 295 |
input_fields = {key: instance[key] for key in self.input_fields.keys()}
|
| 296 |
data_classification_policy = instance.get("data_classification_policy", [])
|
| 297 |
|
|
|
|
| 3 |
|
| 4 |
from .artifact import fetch_artifact
|
| 5 |
from .deprecation_utils import deprecation
|
| 6 |
+
from .error_utils import Documentation, UnitxtError, UnitxtWarning, error_context
|
| 7 |
from .logging_utils import get_logger
|
| 8 |
from .metrics import MetricsList
|
| 9 |
from .operator import InstanceOperator
|
|
|
|
| 285 |
) -> Dict[str, Any]:
|
| 286 |
instance = self.set_default_values(instance)
|
| 287 |
|
| 288 |
+
with error_context(
|
| 289 |
+
self,
|
| 290 |
+
stage="Schema Verification",
|
| 291 |
+
help="https://www.unitxt.ai/en/latest/docs/adding_task.html",
|
| 292 |
+
):
|
| 293 |
+
verify_required_schema(
|
| 294 |
+
self.input_fields,
|
| 295 |
+
instance,
|
| 296 |
+
class_name="Task",
|
| 297 |
+
id=self.__id__,
|
| 298 |
+
description=self.__description__,
|
| 299 |
+
)
|
| 300 |
input_fields = {key: instance[key] for key in self.input_fields.keys()}
|
| 301 |
data_classification_policy = instance.get("data_classification_policy", [])
|
| 302 |
|
templates.py
CHANGED
|
@@ -6,11 +6,12 @@ from typing import Any, Dict, List, Optional, Tuple, Union
|
|
| 6 |
from .artifact import Artifact
|
| 7 |
from .collections import DictCollection, ListCollection
|
| 8 |
from .dataclass import NonPositionalField
|
| 9 |
-
from .dict_utils import dict_set
|
| 10 |
from .error_utils import Documentation, UnitxtError
|
| 11 |
from .operator import InstanceOperator, Operator
|
| 12 |
from .random_utils import new_random_generator
|
| 13 |
from .serializers import (
|
|
|
|
| 14 |
DialogSerializer,
|
| 15 |
ImageSerializer,
|
| 16 |
ListSerializer,
|
|
@@ -68,6 +69,7 @@ class Template(InstanceOperator):
|
|
| 68 |
ToolCallSerializer(),
|
| 69 |
ToolsSerializer(),
|
| 70 |
DialogSerializer(),
|
|
|
|
| 71 |
ListSerializer(),
|
| 72 |
SQLDatabaseAsSchemaSerializer(),
|
| 73 |
]
|
|
@@ -942,6 +944,16 @@ class MultiReferenceTemplate(InputOutputTemplate):
|
|
| 942 |
return target, references
|
| 943 |
|
| 944 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 945 |
def escape_chars(s, chars_to_escape):
|
| 946 |
for char in chars_to_escape:
|
| 947 |
s = s.replace(char, f"\\{char}")
|
|
|
|
| 6 |
from .artifact import Artifact
|
| 7 |
from .collections import DictCollection, ListCollection
|
| 8 |
from .dataclass import NonPositionalField
|
| 9 |
+
from .dict_utils import dict_get, dict_set
|
| 10 |
from .error_utils import Documentation, UnitxtError
|
| 11 |
from .operator import InstanceOperator, Operator
|
| 12 |
from .random_utils import new_random_generator
|
| 13 |
from .serializers import (
|
| 14 |
+
ConversationSerializer,
|
| 15 |
DialogSerializer,
|
| 16 |
ImageSerializer,
|
| 17 |
ListSerializer,
|
|
|
|
| 69 |
ToolCallSerializer(),
|
| 70 |
ToolsSerializer(),
|
| 71 |
DialogSerializer(),
|
| 72 |
+
ConversationSerializer(),
|
| 73 |
ListSerializer(),
|
| 74 |
SQLDatabaseAsSchemaSerializer(),
|
| 75 |
]
|
|
|
|
| 944 |
return target, references
|
| 945 |
|
| 946 |
|
| 947 |
+
class MultiTurnTemplate(MultiReferenceTemplate):
|
| 948 |
+
input_format = ""
|
| 949 |
+
turns_field: str
|
| 950 |
+
|
| 951 |
+
def post_process_instance(self, instance):
|
| 952 |
+
turns = dict_get(instance["input_fields"], self.turns_field)
|
| 953 |
+
instance["__turns__"] = turns
|
| 954 |
+
return super().post_process_instance(instance)
|
| 955 |
+
|
| 956 |
+
|
| 957 |
def escape_chars(s, chars_to_escape):
|
| 958 |
for char in chars_to_escape:
|
| 959 |
s = s.replace(char, f"\\{char}")
|
sql_utils.py → text2sql_utils.py
RENAMED
|
@@ -7,9 +7,13 @@ import re
|
|
| 7 |
import sqlite3
|
| 8 |
import time
|
| 9 |
from abc import ABC, abstractmethod
|
|
|
|
|
|
|
| 10 |
from functools import lru_cache
|
| 11 |
-
from typing import Any, List, Optional
|
| 12 |
|
|
|
|
|
|
|
| 13 |
import requests
|
| 14 |
from huggingface_hub import snapshot_download
|
| 15 |
from requests.exceptions import ConnectionError, ReadTimeout
|
|
@@ -539,6 +543,17 @@ def get_db_connector(db_type: str):
|
|
| 539 |
return connector
|
| 540 |
|
| 541 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 542 |
def is_sqlglot_parsable(sql: str, db_type="sqlite") -> bool:
|
| 543 |
"""Returns True if sqlglot does not encounter any error, False otherwise."""
|
| 544 |
from sqlglot import parse
|
|
@@ -695,7 +710,7 @@ def extract_select_info(sql: str):
|
|
| 695 |
|
| 696 |
|
| 697 |
def sqlparse_queries_equivalent(sql1: str, sql2: str) -> bool:
|
| 698 |
-
"""
|
| 699 |
try:
|
| 700 |
info1 = extract_select_info(sql1)
|
| 701 |
info2 = extract_select_info(sql2)
|
|
@@ -713,6 +728,7 @@ def sqlparse_queries_equivalent(sql1: str, sql2: str) -> bool:
|
|
| 713 |
|
| 714 |
|
| 715 |
def sqlglot_parsed_queries_equivalent(sql1: str, sql2: str, dialect: str = "") -> bool:
|
|
|
|
| 716 |
from sqlglot import exp, parse_one
|
| 717 |
|
| 718 |
try:
|
|
@@ -754,3 +770,473 @@ def sql_exact_match(sql1: str, sql2: str) -> bool:
|
|
| 754 |
return s.upper()
|
| 755 |
|
| 756 |
return normalize_sql(sql1) == normalize_sql(sql2)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 7 |
import sqlite3
|
| 8 |
import time
|
| 9 |
from abc import ABC, abstractmethod
|
| 10 |
+
from collections import Counter
|
| 11 |
+
from dataclasses import dataclass
|
| 12 |
from functools import lru_cache
|
| 13 |
+
from typing import Any, List, Optional, Tuple
|
| 14 |
|
| 15 |
+
import numpy as np
|
| 16 |
+
import pandas as pd
|
| 17 |
import requests
|
| 18 |
from huggingface_hub import snapshot_download
|
| 19 |
from requests.exceptions import ConnectionError, ReadTimeout
|
|
|
|
| 543 |
return connector
|
| 544 |
|
| 545 |
|
| 546 |
+
@dataclass
class SQLNonExecutionMetricResult:
    """Scores for a predicted SQL query that are computed without running it.

    Attributes:
        sqlglot_validity: Whether SQL parses with sqlglot.
        sqlparse_validity: Whether SQL parses with sqlparse.
        sqlglot_equivalence: Semantic equivalence using sqlglot AST.
        sqlglot_optimized_equivalence: Equivalence after optimization via sqlglot.
        sqlparse_equivalence: Equivalence using sqlparse AST.
        sql_exact_match: Exact string match of predicted and gold SQL.
        sql_syntactic_equivalence: Any of the above equivalence conditions hold.
    """

    sqlglot_validity: int
    sqlparse_validity: int
    sqlglot_equivalence: int
    sqlglot_optimized_equivalence: int
    sqlparse_equivalence: int
    sql_exact_match: int
    sql_syntactic_equivalence: int
|
| 555 |
+
|
| 556 |
+
|
| 557 |
def is_sqlglot_parsable(sql: str, db_type="sqlite") -> bool:
|
| 558 |
"""Returns True if sqlglot does not encounter any error, False otherwise."""
|
| 559 |
from sqlglot import parse
|
|
|
|
| 710 |
|
| 711 |
|
| 712 |
def sqlparse_queries_equivalent(sql1: str, sql2: str) -> bool:
|
| 713 |
+
"""Returns True if both SQL queries are naively considered equivalent."""
|
| 714 |
try:
|
| 715 |
info1 = extract_select_info(sql1)
|
| 716 |
info2 = extract_select_info(sql2)
|
|
|
|
| 728 |
|
| 729 |
|
| 730 |
def sqlglot_parsed_queries_equivalent(sql1: str, sql2: str, dialect: str = "") -> bool:
|
| 731 |
+
"""Return True if two SQL queries match after parsing with SQLGlot."""
|
| 732 |
from sqlglot import exp, parse_one
|
| 733 |
|
| 734 |
try:
|
|
|
|
| 770 |
return s.upper()
|
| 771 |
|
| 772 |
return normalize_sql(sql1) == normalize_sql(sql2)
|
| 773 |
+
|
| 774 |
+
|
| 775 |
+
@dataclass
class SQLExecutionResult:
    """Outcome of executing a predicted SQL query and its gold reference.

    Attributes:
        execution_accuracy: Whether the predicted and gold SQL results match exactly.
        non_empty_execution_accuracy: Same as execution_accuracy but only if gold
            is non-empty.
        subset_non_empty_execution_accuracy: Whether predicted is a subset of gold
            or vice versa, non-empty only.
        execution_accuracy_bird: Whether the predicted SQL matches gold using BIRD
            evaluation logic.
        non_empty_gold_df: Whether the gold SQL produced a non-empty dataframe.
        gold_sql_runtime: Time taken to execute the gold SQL.
        predicted_sql_runtime: Time taken to execute the predicted SQL.
        pred_to_gold_runtime_ratio: Ratio of predicted runtime to gold runtime.
        gold_error: Whether the gold SQL had an execution error.
        predicted_error: Whether the predicted SQL had an execution error.
        gold_df_json: JSON representation of the gold SQL result dataframe.
        predicted_df_json: JSON representation of the predicted SQL result dataframe.
        error_message: Error message from predicted execution if any.
    """

    execution_accuracy: int
    non_empty_execution_accuracy: int
    subset_non_empty_execution_accuracy: int
    execution_accuracy_bird: int
    non_empty_gold_df: int
    gold_sql_runtime: float
    predicted_sql_runtime: float
    pred_to_gold_runtime_ratio: float
    gold_error: int
    predicted_error: int
    gold_df_json: str
    predicted_df_json: str
    error_message: str
|
| 796 |
+
|
| 797 |
+
|
| 798 |
+
def compare_dfs_ignore_colnames_ordered_rows(
    df1: pd.DataFrame, df2: pd.DataFrame
) -> bool:
    """Compare two DataFrames row by row, ignoring column names and column order.

    Values are stringified and every row is sorted on its own, so the order of
    columns inside a row does not matter, while the order of rows does.
    DataFrames of different shape are never equal.
    """
    if df1.shape != df2.shape:
        return False
    left, right = (
        np.sort(frame.values.astype(str), axis=1) for frame in (df1, df2)
    )
    return np.array_equal(left, right)
|
| 806 |
+
|
| 807 |
+
|
| 808 |
+
def compare_dfs_ignore_colnames_unordered_rows(
    df1: pd.DataFrame, df2: pd.DataFrame
) -> bool:
    """Compare two DataFrames as multisets of rows, ignoring column names.

    Values are stringified, each row is sorted on its own (so column order is
    irrelevant), and the resulting rows are then ordered as whole tuples before
    comparison (so row order is irrelevant).

    Rows are sorted as units on purpose: sorting every column independently
    would mix values coming from different rows and could report two different
    result sets as equal.

    Args:
        df1: First DataFrame.
        df2: Second DataFrame.

    Returns:
        True if both DataFrames have the same shape and contain the same rows.
    """
    if df1.shape != df2.shape:
        return False

    def canonical_rows(frame):
        row_sorted = np.sort(frame.values.astype(str), axis=1)
        return sorted(map(tuple, row_sorted.tolist()))

    return canonical_rows(df1) == canonical_rows(df2)
|
| 816 |
+
|
| 817 |
+
|
| 818 |
+
def compare_dfs_ignore_colnames_subset(
    df1: pd.DataFrame, df2: pd.DataFrame, ignore_row_order: bool = True
) -> bool:
    """Check whether the narrower DataFrame is likely a column subset of the other.

    The comparison is column-based so that Text2SQL predictions with missing or
    additional columns can still be scored. The DataFrame with fewer columns is
    the candidate subset. Every row is turned into a multiset of stringified
    values, and each candidate row must be a multiset subset of the row at the
    same position in the wider DataFrame.

    With ``ignore_row_order=True`` (useful when the gold SQL has no ORDER BY),
    each column is sorted independently before rows are compared. This is a
    heuristic: two DataFrames may pass the check without truly being subsets of
    each other, so the score can overestimate accuracy.

    Args:
        df1 (pd.DataFrame): The first DataFrame to compare.
        df2 (pd.DataFrame): The second DataFrame to compare.
        ignore_row_order (bool, optional): If True, sort column values before
            comparison so that row order is ignored. Defaults to True.

    Returns:
        bool: False if either DataFrame is empty or the row counts differ;
        otherwise True if every row of the narrower DataFrame is a multiset
        subset of the corresponding row of the wider one.
    """

    def as_sorted_columns(frame):
        result = frame.copy()
        for name in result.columns:
            result[name] = result[name].astype(str).sort_values(ignore_index=True)
        return result

    def row_multisets(frame):
        return [Counter(str(cell) for cell in record) for record in frame.values]

    if df1.empty or df2.empty:
        return False
    if len(df1) != len(df2):
        return False

    if df1.shape[1] <= df2.shape[1]:
        narrow, wide = df1, df2
    else:
        narrow, wide = df2, df1

    if ignore_row_order:
        narrow = as_sorted_columns(narrow)
        wide = as_sorted_columns(wide)

    # Counter lookups return 0 for values absent from the wider row.
    return all(
        needed <= available[value]
        for wanted, available in zip(row_multisets(narrow), row_multisets(wide))
        for value, needed in wanted.items()
    )
|
| 874 |
+
|
| 875 |
+
|
| 876 |
+
def compare_dfs_bird_eval_logic(df1: pd.DataFrame, df2: pd.DataFrame):
    """Compare two SQL result sets as sets of rows, as in the BIRD evaluation.

    Each DataFrame is stringified and reduced to the set of its rows (as
    tuples), so duplicate rows and row order are ignored while column order
    matters. The original docstring points to the BIRD evaluation code at
    https://github.com/AlibabaResearch/DAMO-ConvAI/blob/main/bird/llm/src/evaluation.py
    as the source of this logic.

    Returns:
        1 if both row sets are equal, 0 otherwise.
    """
    first_rows = set(map(tuple, df1.values.astype(str)))
    second_rows = set(map(tuple, df2.values.astype(str)))
    return int(first_rows == second_rows)
|
| 888 |
+
|
| 889 |
+
|
| 890 |
+
def compare_result_dfs(
    gold_df: pd.DataFrame, pred_df: pd.DataFrame, gold_sql: str
) -> Tuple[int, int, int]:
    """Score a predicted SQL result DataFrame against the gold one.

    Column names are always ignored. Row order is taken into account only when
    the text ``ORDER BY`` appears (case-insensitively) in ``gold_sql``.

    Args:
        gold_df (pd.DataFrame): The ground truth DataFrame.
        pred_df (pd.DataFrame): The predicted DataFrame.
        gold_sql (str): The ground truth SQL query string.

    Returns:
        Tuple[int, int, int]: ``(match, non_empty_match, subset_match)`` where
            - match is 1 if the predicted DataFrame matches the gold DataFrame,
            - non_empty_match equals match when both DataFrames are non-empty,
              and is 0 otherwise,
            - subset_match is 1 when both DataFrames are non-empty and
              ``compare_dfs_ignore_colnames_subset`` accepts them, else 0.
    """
    rows_are_ordered = "ORDER BY" in gold_sql.upper()
    if rows_are_ordered:
        frames_match = compare_dfs_ignore_colnames_ordered_rows
    else:
        frames_match = compare_dfs_ignore_colnames_unordered_rows

    match = int(frames_match(pred_df, gold_df))

    # The non-empty and subset scores stay 0 when either side has no rows.
    if gold_df.empty or pred_df.empty:
        return match, 0, 0

    is_subset = compare_dfs_ignore_colnames_subset(
        gold_df, pred_df, ignore_row_order=not rows_are_ordered
    )
    return match, match, (1 if is_subset else 0)
|
| 931 |
+
|
| 932 |
+
|
| 933 |
+
def run_query(
    sql: str, connector, sql_timeout: float
) -> Tuple[Optional[pd.DataFrame], float, str]:
    """Execute a SQL query through ``connector`` with a timeout.

    Args:
        sql (str): The SQL query string to execute.
        connector: An object with an ``execute_query`` method; the method is
            called with ``sql`` and must return a ``(result, error)`` pair.
        sql_timeout (float): The maximum time in seconds to allow for query
            execution.

    Returns:
        Tuple[Optional[pd.DataFrame], float, str]:
            - A pandas DataFrame with the query results, or None on error.
            - The execution time in seconds. It is 0.0 for blank input,
              timeouts and raised exceptions; when the connector itself
              reports an error the measured duration is returned.
            - An error message, or an empty string on success.

    Notes:
        - Blank or whitespace-only SQL returns immediately with a message,
          before the optional ``func_timeout`` package is imported.
        - If ``result`` is a dict with a ``"results"`` key, that entry is used
          as the result payload.
        - Timeouts and any other exception are returned as error messages
          rather than raised.
    """
    # Guard first so that trivial input never requires func_timeout.
    if not sql.strip():
        return None, 0.0, "No SQL query found in the prediction."

    from func_timeout import func_timeout
    from func_timeout.exceptions import FunctionTimedOut

    try:
        start = time.perf_counter()
        result, error = func_timeout(sql_timeout, connector.execute_query, args=(sql,))
        duration = time.perf_counter() - start
        if isinstance(result, dict) and "results" in result:
            result = result["results"]
        if error:
            return None, duration, error
        return pd.DataFrame(result), duration, ""
    except FunctionTimedOut as e:
        return None, 0.0, f"Timeout: {e}"
    except Exception as e:
        return None, 0.0, f"Error: {e}"
|
| 981 |
+
|
| 982 |
+
|
| 983 |
+
def get_sql_execution_results(
    predicted_sql: str, gold_sql: str, connector, sql_timeout: float
) -> SQLExecutionResult:
    """Execute and compare predicted and gold SQL queries, returning execution metrics.

    Args:
        predicted_sql (str): The SQL query predicted by the model.
        gold_sql (str): The reference (gold) SQL query.
        connector: Database connector object used to execute the queries.
        sql_timeout (float): Maximum time (in seconds) allowed for query execution.

    Returns:
        SQLExecutionResult: The execution metrics, including:
            - execution_accuracy (int): 1 if predicted and gold queries produce
              equivalent results, else 0.
            - non_empty_execution_accuracy (int): 1 if both queries produce
              non-empty and equivalent results, else 0.
            - subset_non_empty_execution_accuracy (int): 1 if predicted results
              are a subset or superset of gold results and non-empty, else 0.
              The subset comparison is column-based, so the predicted result
              may have missing or additional columns.
            - execution_accuracy_bird (int): 1 if results match according to
              BIRD evaluation logic, else 0.
            - non_empty_gold_df (int): 1 if the gold result is non-empty, else 0.
            - gold_sql_runtime / predicted_sql_runtime (float): Execution times.
            - pred_to_gold_runtime_ratio (float): Predicted / gold runtime, or 0
              when the gold runtime is not positive.
            - gold_error / predicted_error (int): 1 if the query failed, else 0.
            - gold_df_json / predicted_df_json (str): JSON of the result
              DataFrames ("" when unavailable).
            - error_message (str): Error message if any query failed, else "".

    Notes:
        - If the gold query fails, the function returns early with error details.
        - If the predicted query is identical (ignoring case and surrounding
          whitespace) or SQL-equivalent to the gold query, it is accepted
          without being executed; the gold result is reported for both sides.
        - Otherwise the predicted query is executed and both results are compared.
    """
    gold_df, gold_runtime, gold_error_msg = run_query(gold_sql, connector, sql_timeout)

    if gold_error_msg:
        return SQLExecutionResult(
            execution_accuracy=0,
            non_empty_execution_accuracy=0,
            subset_non_empty_execution_accuracy=0,
            execution_accuracy_bird=0,
            non_empty_gold_df=0,
            gold_sql_runtime=gold_runtime,
            predicted_sql_runtime=0,
            pred_to_gold_runtime_ratio=0,
            gold_error=1,
            predicted_error=0,
            gold_df_json="",
            predicted_df_json="",
            error_message=gold_error_msg,
        )

    non_empty_gold_df = int(not gold_df.empty)

    if _sql_equivalent_without_execution(predicted_sql, gold_sql):
        return _result_accepted_without_execution(
            gold_df, gold_runtime, non_empty_gold_df
        )

    pred_df, pred_runtime, pred_error_msg = run_query(
        predicted_sql, connector, sql_timeout
    )
    runtime_ratio = (pred_runtime / gold_runtime) if gold_runtime > 0 else 0

    if pred_df is None:
        return SQLExecutionResult(
            execution_accuracy=0,
            non_empty_execution_accuracy=0,
            subset_non_empty_execution_accuracy=0,
            execution_accuracy_bird=0,
            non_empty_gold_df=non_empty_gold_df,
            gold_sql_runtime=gold_runtime,
            predicted_sql_runtime=pred_runtime,
            pred_to_gold_runtime_ratio=runtime_ratio,
            gold_error=0,
            predicted_error=1 if pred_error_msg else 0,
            gold_df_json=gold_df.to_json(),
            predicted_df_json="",
            error_message=pred_error_msg,
        )

    match, non_empty_match, subset_match = compare_result_dfs(
        gold_df, pred_df, gold_sql
    )
    bird_match = compare_dfs_bird_eval_logic(gold_df, pred_df)

    return SQLExecutionResult(
        execution_accuracy=match,
        non_empty_execution_accuracy=non_empty_match,
        subset_non_empty_execution_accuracy=subset_match,
        execution_accuracy_bird=bird_match,
        non_empty_gold_df=non_empty_gold_df,
        gold_sql_runtime=gold_runtime,
        predicted_sql_runtime=pred_runtime,
        pred_to_gold_runtime_ratio=runtime_ratio,
        gold_error=0,
        predicted_error=0,
        gold_df_json=gold_df.to_json(),
        predicted_df_json=pred_df.to_json(),
        error_message=pred_error_msg,
    )


def _sql_equivalent_without_execution(predicted_sql: str, gold_sql: str) -> bool:
    """Return True if the prediction equals the gold SQL textually or per sqlglot."""
    if predicted_sql.strip().lower() == gold_sql.strip().lower():
        return True
    try:
        return bool(sqlglot_optimized_equivalence(gold_sql, predicted_sql))
    except Exception as e:
        # Best effort: if the check itself fails, fall back to executing the query.
        logger.info(f"Could not check SQL equivalence: {e}")
        return False


def _result_accepted_without_execution(
    gold_df, gold_runtime: float, non_empty_gold_df: int
) -> SQLExecutionResult:
    """Build the result for a prediction accepted without being executed."""
    gold_json = gold_df.to_json()
    return SQLExecutionResult(
        execution_accuracy=1,
        non_empty_execution_accuracy=non_empty_gold_df,
        subset_non_empty_execution_accuracy=non_empty_gold_df,
        execution_accuracy_bird=1,
        non_empty_gold_df=non_empty_gold_df,
        gold_sql_runtime=gold_runtime,
        predicted_sql_runtime=0,
        pred_to_gold_runtime_ratio=0,
        gold_error=0,
        predicted_error=0,
        gold_df_json=gold_json,
        predicted_df_json=gold_json,
        error_message="",
    )
|
| 1131 |
+
|
| 1132 |
+
|
| 1133 |
+
def replace_select_clause(
    source_query: str, target_query: str, dialect: str = "postgres"
) -> str:
    """Return ``target_query`` with its SELECT list swapped for that of ``source_query``.

    Args:
        source_query (str): SQL query whose SELECT clause will be used.
        target_query (str): SQL query whose SELECT clause will be replaced.
        dialect (str): SQL dialect for parsing and rendering. A falsy value
            falls back to "postgres" (default: "postgres").

    Returns:
        str: The SQL of ``target_query`` after its SELECT expressions were
        replaced by those of ``source_query``, rendered in ``dialect``.

    Raises:
        ValueError: If either query does not parse to a SELECT statement.

    Example:
        >>> replace_select_clause(
        ...     "SELECT id FROM employees",
        ...     "SELECT name FROM employees WHERE age > 30"
        ... )
        'SELECT id FROM employees WHERE age > 30'
    """
    from sqlglot import exp, parse_one

    dialect = dialect or "postgres"

    # Both queries are parsed with the same dialect that is later used to render.
    source_ast, target_ast = (
        parse_one(query, read=dialect) for query in (source_query, target_query)
    )

    both_are_selects = isinstance(source_ast, exp.Select) and isinstance(
        target_ast, exp.Select
    )
    if not both_are_selects:
        raise ValueError("Both queries must be valid SELECT statements.")

    target_ast.set("expressions", source_ast.expressions)
    return target_ast.sql(dialect=dialect)
|
| 1173 |
+
|
| 1174 |
+
|
| 1175 |
+
def extract_sql_from_text(text: str) -> str:
    """Extract the first SQL query found in ``text``.

    The patterns below are tried in order and the first one that matches wins:
        1. The content of a fenced block such as ```sql ... ```.
        2. A statement that starts at the beginning of the text, on a new line
           or after a colon, and runs up to the first semicolon.
        3. The same kind of statement without a semicolon, taken to the end of
           the text.

    A statement is recognised by a leading SELECT, INSERT, UPDATE, DELETE or
    WITH keyword (case-insensitive) followed by whitespace.

    Returns:
        The stripped SQL query string, or an empty string if nothing matched.
    """
    flags = re.IGNORECASE | re.DOTALL
    statement_head = r"(?:SELECT|INSERT|UPDATE|DELETE|WITH)\s+"
    # Start of string, newline, or colon label like "Just run:"
    boundary = r"(?:^|\n|:\s*)"

    patterns_by_priority = (
        r"```sql\s+(.*?)```",
        rf"{boundary}({statement_head}.*?;)",
        rf"{boundary}({statement_head}.*)",
    )

    for pattern in patterns_by_priority:
        found = re.search(pattern, text, flags)
        if found:
            return found.group(1).strip()
    return ""
|
| 1213 |
+
|
| 1214 |
+
|
| 1215 |
+
# Names of SQL dialects, in alphabetical order (case-insensitive).
# NOTE(review): presumably the dialects understood by sqlglot; confirm against usage.
ALL_DIALECTS = [
    "Athena",
    "BigQuery",
    "ClickHouse",
    "Databricks",
    "Doris",
    "Drill",
    "Druid",
    "DuckDB",
    "Hive",
    "Materialize",
    "MySQL",
    "Oracle",
    "Postgres",
    "Presto",
    "PRQL",
    "Redshift",
    "RisingWave",
    "Snowflake",
    "Spark",
    "Spark2",
    "SQLite",
    "StarRocks",
    "Tableau",
    "Teradata",
    "Trino",
    "TSQL",
]
|
type_utils.py
CHANGED
|
@@ -503,9 +503,25 @@ def isoftype(object, typing_type):
|
|
| 503 |
if is_typed_dict(typing_type):
|
| 504 |
if not isinstance(object, dict):
|
| 505 |
return False
|
|
|
|
|
|
|
| 506 |
for key, expected_type in typing_type.__annotations__.items():
|
| 507 |
-
|
| 508 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 509 |
return True
|
| 510 |
|
| 511 |
if typing_type == typing.Any:
|
|
|
|
| 503 |
if is_typed_dict(typing_type):
|
| 504 |
if not isinstance(object, dict):
|
| 505 |
return False
|
| 506 |
+
|
| 507 |
+
# Only support total=True, check each field
|
| 508 |
for key, expected_type in typing_type.__annotations__.items():
|
| 509 |
+
# Check if field is Optional (Union with None)
|
| 510 |
+
is_optional = (
|
| 511 |
+
hasattr(expected_type, "__origin__")
|
| 512 |
+
and expected_type.__origin__ is Union
|
| 513 |
+
and type(None) in expected_type.__args__
|
| 514 |
+
)
|
| 515 |
+
|
| 516 |
+
if key not in object:
|
| 517 |
+
# Field is missing - only allowed if it's Optional
|
| 518 |
+
if not is_optional:
|
| 519 |
+
return False
|
| 520 |
+
else:
|
| 521 |
+
# Field is present - check type
|
| 522 |
+
if not isoftype(object[key], expected_type):
|
| 523 |
+
return False
|
| 524 |
+
|
| 525 |
return True
|
| 526 |
|
| 527 |
if typing_type == typing.Any:
|
types.py
CHANGED
|
@@ -6,8 +6,52 @@ Text = NewType("Text", str)
|
|
| 6 |
Number = NewType("Number", Union[float, int])
|
| 7 |
|
| 8 |
|
| 9 |
-
class
|
| 10 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 11 |
content: Text
|
| 12 |
|
| 13 |
|
|
@@ -18,7 +62,12 @@ class RagResponse(TypedDict):
|
|
| 18 |
is_answerable: bool
|
| 19 |
|
| 20 |
|
| 21 |
-
Dialog = NewType("Dialog", List[
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 22 |
|
| 23 |
|
| 24 |
class Image(TypedDict):
|
|
@@ -52,36 +101,17 @@ class SQLDatabase(TypedDict):
|
|
| 52 |
data: Optional[Dict[str, Dict]]
|
| 53 |
|
| 54 |
|
| 55 |
-
class JsonSchema:
|
| 56 |
-
@classmethod
|
| 57 |
-
def __verify_type__(cls, object):
|
| 58 |
-
if not isinstance(object, dict):
|
| 59 |
-
return False
|
| 60 |
-
import jsonschema_rs
|
| 61 |
-
|
| 62 |
-
jsonschema_rs.meta.validate(object)
|
| 63 |
-
return True
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
class Tool(TypedDict):
|
| 67 |
-
name: str
|
| 68 |
-
description: str
|
| 69 |
-
parameters: JsonSchema
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
class ToolCall(TypedDict):
|
| 73 |
-
name: str
|
| 74 |
-
arguments: Dict[str, Any]
|
| 75 |
-
|
| 76 |
-
|
| 77 |
register_type(Text)
|
| 78 |
register_type(Number)
|
| 79 |
-
register_type(
|
|
|
|
|
|
|
| 80 |
register_type(Dialog)
|
| 81 |
register_type(Table)
|
| 82 |
register_type(Audio)
|
| 83 |
register_type(Image)
|
| 84 |
register_type(Video)
|
|
|
|
| 85 |
register_type(Document)
|
| 86 |
register_type(MultiDocument)
|
| 87 |
register_type(RagResponse)
|
|
|
|
| 6 |
Number = NewType("Number", Union[float, int])
|
| 7 |
|
| 8 |
|
| 9 |
+
class JsonSchema:
    """Type marker for dictionaries that pass ``jsonschema_rs`` meta-validation."""

    @classmethod
    def __verify_type__(cls, object):
        """Check that ``object`` is a dict accepted by ``jsonschema_rs.meta.validate``.

        Returns False for anything that is not a dict. For a dict, the
        meta-validation is run and True is returned afterwards; whatever
        ``jsonschema_rs.meta.validate`` raises propagates to the caller.
        """
        if isinstance(object, dict):
            # Imported lazily so the dependency is only needed when validating.
            import jsonschema_rs

            jsonschema_rs.meta.validate(object)
            return True
        return False
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
class Tool(TypedDict):
    """Description of a tool.

    ``name``, ``description`` and ``parameters`` are the original fields;
    ``type`` is the LiteLLM extension and may be None.
    """

    name: str
    description: str
    parameters: JsonSchema
    type: Optional[Literal["function"]]
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
class ToolCall(TypedDict):
    """A call to a tool: the tool ``name`` and a dict of ``arguments``."""

    name: str
    arguments: Dict[str, Any]
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
class ToolCallContext(TypedDict):
    """A ``ToolCall`` wrapped with an ``id`` and the fixed ``type`` "function"."""

    id: str
    type: Literal["function"]
    function: ToolCall
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
class ToolCallTurn(TypedDict):
    """An "assistant" turn carrying tool calls; ``content`` may be None."""

    role: Literal["assistant"]
    content: Optional[str]
    tool_calls: List[ToolCallContext]
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
class ToolOutputTurn(TypedDict):
    """A "tool" turn: the ``content`` for the call identified by ``tool_call_id``."""

    role: Literal["tool"]
    tool_call_id: str
    name: str
    content: str
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
class TextTurn(TypedDict):
    """A textual turn: a ``role`` (system, user, agent or assistant) and its ``content``."""

    role: Literal["system", "user", "agent", "assistant"]
    content: Text
|
| 56 |
|
| 57 |
|
|
|
|
| 62 |
is_answerable: bool
|
| 63 |
|
| 64 |
|
| 65 |
+
Dialog = NewType("Dialog", List[Union[TextTurn, ToolCallTurn, ToolOutputTurn]])
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
class Conversation(TypedDict):
    """A ``Dialog`` together with a string ``id``."""

    id: str
    dialog: Dialog
|
| 71 |
|
| 72 |
|
| 73 |
class Image(TypedDict):
|
|
|
|
| 101 |
data: Optional[Dict[str, Dict]]
|
| 102 |
|
| 103 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 104 |
register_type(Text)
|
| 105 |
register_type(Number)
|
| 106 |
+
register_type(TextTurn)
|
| 107 |
+
register_type(ToolCallTurn)
|
| 108 |
+
register_type(ToolOutputTurn)
|
| 109 |
register_type(Dialog)
|
| 110 |
register_type(Table)
|
| 111 |
register_type(Audio)
|
| 112 |
register_type(Image)
|
| 113 |
register_type(Video)
|
| 114 |
+
register_type(Conversation)
|
| 115 |
register_type(Document)
|
| 116 |
register_type(MultiDocument)
|
| 117 |
register_type(RagResponse)
|
version.py
CHANGED
|
@@ -1 +1 @@
|
|
| 1 |
-
version = "1.
|
|
|
|
| 1 |
+
version = "1.25.0"
|