Upload folder using huggingface_hub
- api.py +41 -9
- artifact.py +7 -2
- collections_operators.py +22 -4
- dialog_operators.py +2 -2
- formats.py +1 -0
- generator_utils.py +2 -32
- inference.py +376 -55
- llm_as_judge.py +261 -62
- loaders.py +14 -6
- metric_utils.py +18 -9
- metrics.py +206 -67
- operators.py +79 -47
- processors.py +77 -2
- settings_utils.py +1 -0
- split_utils.py +6 -1
- splitters.py +4 -2
- standard.py +6 -6
- stream.py +4 -3
- stream_operators.py +5 -3
- string_operators.py +9 -0
- struct_data_operators.py +194 -5
- templates.py +1 -1
- type_utils.py +3 -0
- utils.py +84 -1
- version.py +1 -1
api.py
CHANGED
@@ -1,10 +1,10 @@
+import json
 from functools import lru_cache
 from typing import Any, Dict, List, Optional, Union

-from datasets import DatasetDict
-
 from .artifact import fetch_artifact
 from .dataset_utils import get_dataset_artifact
+from .inference import InferenceEngine, LogProbInferenceEngine
 from .logging_utils import get_logger
 from .metric_utils import _compute, _inference_post_process
 from .operator import SourceOperator

@@ -14,7 +14,7 @@ from .standard import StandardRecipe
 logger = get_logger()


-def load(source: Union[SourceOperator, str])
+def load(source: Union[SourceOperator, str]):
     assert isinstance(
         source, (SourceOperator, str)
     ), "source must be a SourceOperator or a string"

@@ -79,7 +79,9 @@ def load_recipe(dataset_query: Optional[str] = None, **kwargs) -> StandardRecipe
     return recipe


-def load_dataset(dataset_query: Optional[str] = None, **kwargs) -> DatasetDict:
+def load_dataset(
+    dataset_query: Optional[str] = None, streaming: bool = False, **kwargs
+):
     """Loads dataset.

     If the 'dataset_query' argument is provided, then dataset is loaded from a card in local

@@ -90,6 +92,7 @@ def load_dataset(dataset_query: Optional[str] = None, **kwargs) -> DatasetDict:
         dataset_query (str, optional): A string query which specifies a dataset to load from local catalog or name of specific recipe or benchmark in the catalog.
             For example:
             "card=cards.wnli,template=templates.classification.multi_class.relation.default".
+        streaming (bool, False): When True yields the data as Unitxt streams dictionary
         **kwargs: Arguments used to load dataset from provided card, which is not present in local catalog.

     Returns:

@@ -107,6 +110,9 @@ def load_dataset(dataset_query: Optional[str] = None, **kwargs) -> DatasetDict:
     """
     recipe = load_recipe(dataset_query, **kwargs)

+    if streaming:
+        return recipe()
+
     return recipe().to_dataset(features=UNITXT_DATASET_SCHEMA)

@@ -135,19 +141,45 @@ def produce(instance_or_instances, dataset_query: Optional[str] = None, **kwargs

 def infer(
     instance_or_instances,
-    engine,
+    engine: InferenceEngine,
     dataset_query: Optional[str] = None,
-    return_data=False,
+    return_data: bool = False,
+    return_log_probs: bool = False,
+    return_meta_data: bool = False,
     **kwargs,
 ):
     dataset = produce(instance_or_instances, dataset_query, **kwargs)
     engine, _ = fetch_artifact(engine)
+    if return_log_probs:
+        if not isinstance(engine, LogProbInferenceEngine):
+            raise NotImplementedError(
+                f"Error in infer: return_log_probs set to True but supplied engine "
+                f"{engine.__class__.__name__} does not support logprobs."
+            )
+        infer_outputs = engine.infer_log_probs(dataset, return_meta_data)
+        raw_predictions = (
+            [output.prediction for output in infer_outputs]
+            if return_meta_data
+            else infer_outputs
+        )
+        raw_predictions = [
+            json.dumps(raw_prediction) for raw_prediction in raw_predictions
+        ]
+    else:
+        infer_outputs = engine.infer(dataset, return_meta_data)
+        raw_predictions = (
+            [output.prediction for output in infer_outputs]
+            if return_meta_data
+            else infer_outputs
+        )
     predictions = post_process(raw_predictions, dataset)
     if return_data:
-        for prediction, raw_prediction, instance in zip(
-            predictions, raw_predictions, dataset
+        for prediction, raw_prediction, instance, infer_output in zip(
+            predictions, raw_predictions, dataset, infer_outputs
         ):
+            if return_meta_data:
+                instance["infer_meta_data"] = infer_output.__dict__
+                del instance["infer_meta_data"]["prediction"]
             instance["prediction"] = prediction
             instance["raw_prediction"] = raw_prediction
     return dataset
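A minimal usage sketch of the new api.py surface, reusing the illustrative catalog query from the docstring above; the instance list is a placeholder and MockInferenceEngine only stands in for a real engine:

from unitxt.api import load_dataset, infer
from unitxt.inference import MockInferenceEngine

# Illustrative dataset query, taken from the docstring example above.
query = "card=cards.wnli,template=templates.classification.multi_class.relation.default"

# streaming=True returns the recipe's Unitxt streams dictionary instead of a DatasetDict.
streams = load_dataset(query, streaming=True)

# return_log_probs would additionally require an engine implementing LogProbInferenceEngine;
# return_meta_data attaches an "infer_meta_data" dict to every returned instance.
dataset = infer(
    my_instances,  # placeholder: a list of raw task instances matching the card's fields
    engine=MockInferenceEngine(model_name="mock"),
    dataset_query=query,
    return_data=True,
)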
artifact.py
CHANGED
@@ -22,7 +22,12 @@ from .parsing_utils import (
 from .settings_utils import get_constants, get_settings
 from .text_utils import camel_to_snake_case, is_camel_case
 from .type_utils import issubtype
+from .utils import (
+    artifacts_json_cache,
+    json_dump,
+    save_to_file,
+    shallow_copy,
+)

 logger = get_logger()
 settings = get_settings()

@@ -405,7 +410,7 @@ def get_raw(obj):
     if isinstance(obj, dict):
         return type(obj)({get_raw(k): get_raw(v) for k, v in obj.items()})

+    return shallow_copy(obj)


 class ArtifactList(list, Artifact):
collections_operators.py
CHANGED
@@ -3,7 +3,7 @@ from typing import Any, Generator, List, Optional
 from .dict_utils import dict_get, dict_set
 from .operators import FieldOperator, StreamOperator
 from .stream import Stream
+from .utils import recursive_shallow_copy


 class Dictify(FieldOperator):

@@ -70,10 +70,10 @@ class DuplicateByList(StreamOperator):
         elements = dict_get(instance, self.field)
         for element in elements:
             if self.use_deep_copy:
+                instance_copy = recursive_shallow_copy(instance)

             else:
+                instance_copy = instance.copy()
             dict_set(instance_copy, to_field, element)
             yield instance_copy

@@ -93,7 +93,7 @@ class DuplicateBySubLists(StreamOperator):
         elements = instance[self.field]
         for i in range(1, len(elements) + 1):
             if self.use_deep_copy:
+                instance_copy = recursive_shallow_copy(instance)
                 instance_copy[to_field] = elements[:i]
             else:
                 instance_copy = {

@@ -107,3 +107,21 @@ class GetLength(FieldOperator):
 class GetLength(FieldOperator):
     def process_value(self, collection: Any) -> Any:
         return len(collection)
+
+
+class Filter(FieldOperator):
+    values: List[Any]
+
+    def process_value(self, collection: Any) -> Any:
+        # If collection is a list, tuple, or set
+        if isinstance(collection, (list, set, tuple)):
+            return type(collection)(
+                item for item in collection if item not in self.values
+            )
+
+        # If collection is a dictionary, filter by keys
+        if isinstance(collection, dict):
+            return {k: v for k, v in collection.items() if k not in self.values}
+
+        # If collection is of an unsupported type
+        raise TypeError(f"Unsupported collection type: {type(collection)}")
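A short sketch of how the new Filter operator behaves on raw collections; the field/to_field wiring follows the usual FieldOperator convention and the values shown are illustrative:

# Filter drops the listed values from lists/tuples/sets, and drops the listed keys from dicts.
f = Filter(field="labels", to_field="labels", values=["skip", "unknown"])

f.process_value(["a", "skip", "b"])    # -> ["a", "b"]   (collection type is preserved)
f.process_value(("skip", "c"))         # -> ("c",)
f.process_value({"a": 1, "skip": 2})   # -> {"a": 1}     (dicts are filtered by keys)
f.process_value(42)                    # raises TypeError: unsupported collection type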
dialog_operators.py
CHANGED
@@ -157,13 +157,13 @@ class SerializeOpenAiFormatDialog(SerializeDialog):
                 f"Entry {i} has a non-string 'content': {entry['content']}. The 'content' value must be a string."
             )

-        if entry["role"] not in {"user", "assistant"}:
+        if entry["role"].lower() not in {"user", "assistant"}:
             raise ValueError(
                 f"Entry {i} has an invalid role: {entry['role']}. Allowed roles are 'user' and 'assistant'."
             )

     first_entry = dialog[0]
-    if first_entry["role"] != "user":
+    if first_entry["role"].lower() != "user":
         raise ValueError(
             f"First entry role is expected to be 'user' It is {first_entry['role']}."
         )
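With the case-insensitive role check, dialogs whose roles are capitalized now pass validation; an illustrative example:

# Roles are compared after .lower(), so "User"/"Assistant" are accepted,
# while a genuinely unknown role such as "system" still raises ValueError.
dialog = [
    {"role": "User", "content": "What is the capital of France?"},
    {"role": "Assistant", "content": "Paris."},
]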
formats.py
CHANGED
@@ -182,6 +182,7 @@ class SystemFormat(BaseFormat):
                 target_prefix=demo_target_prefix,
                 source=demo_source,
                 target=demo_target,
+                instruction=instruction,
                 **self.format_args,
             )
             demos_string += demo_str
generator_utils.py
CHANGED
@@ -1,7 +1,7 @@
 from typing import Any, Dict, List

 from .dataclass import Dataclass, OptionalField
+from .utils import recursive_shallow_copy


 class ReusableGenerator(Dataclass):

@@ -22,34 +22,4 @@ class ReusableGenerator(Dataclass):
 class CopyingReusableGenerator(ReusableGenerator):
     def __iter__(self):
         for instance in self.activate():
+            yield recursive_shallow_copy(instance)
-
-
-# if __name__ == "__main__":
-#     from itertools import chain, islice
-
-#     # Creating objects of MyIterable
-#     iterable1 = ReusableGenerator(range, gen_argv=[1, 4])
-#     iterable2 = ReusableGenerator(range, gen_argv=[4, 7])
-
-#     # Using itertools.chain
-#     chained = list(chain(iterable1, iterable2))
-#     logger.info(chained)  # Prints: [1, 2, 3, 4, 5, 6]
-
-#     # Using itertools.islice
-#     sliced = list(islice(ReusableGenerator(range, gen_argv=[1, 7]), 1, 4))
-#     logger.info(sliced)  # Prints: [2, 3, 4]
-
-#     # now same test with generators
-#     def generator(start, end):
-#         for i in range(start, end):
-#             yield i
-
-#     iterable1 = ReusableGenerator(generator, gen_argv=[1, 4])
-#     iterable2 = ReusableGenerator(generator, gen_argv=[4, 7])
-
-#     chained = list(chain(iterable1, iterable2))
-#     logger.info(chained)  # Prints: [1, 2, 3, 4, 5, 6]
-
-#     sliced = list(islice(ReusableGenerator(generator, gen_argv=[1, 7]), 1, 4))
-#     logger.info(sliced)  # Prints: [2, 3, 4]
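recursive_shallow_copy is imported from the updated utils.py, which is not shown in this view; a rough sketch of the semantics suggested by the name and the call sites above — containers copied level by level, leaf objects left shared — might look as follows (an assumption, not the actual utils.py code):

def recursive_shallow_copy_sketch(obj):
    # Copy dicts and lists recursively so the structure can be mutated safely,
    # while leaf values (strings, numbers, arbitrary objects) stay shared, unlike deepcopy.
    if isinstance(obj, dict):
        return {k: recursive_shallow_copy_sketch(v) for k, v in obj.items()}
    if isinstance(obj, list):
        return [recursive_shallow_copy_sketch(v) for v in obj]
    return obj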
inference.py
CHANGED
@@ -1,8 +1,10 @@
 import abc
+import dataclasses
 import os
 import re
 from typing import Any, Dict, List, Literal, Optional, Union

+from datasets import DatasetDict
 from tqdm import tqdm

 from .artifact import Artifact, fetch_artifact

@@ -16,12 +18,52 @@ from .settings_utils import get_settings
 settings = get_settings()


+def get_model_and_label_id(model_name, label):
+    model_id = model_name.split("/")[-1].replace("-", "_").replace(".", ",").lower()
+    return f"{model_id}_{label}"
+
+
+@dataclasses.dataclass
+class TextGenerationInferenceOutput:
+    """Contains the prediction results and metadata for the inference.
+
+    Args:
+        prediction (Union[str, List[Dict[str, Any]]]): If this is the result of an _infer call, the string predicted by the model.
+            If this is the results of an _infer_log_probs call, a list of dictionaries. The i'th dictionary represents
+            the i'th token in the response. The entry "top_tokens" in the dictionary holds a sorted list of the top tokens
+            for this position and their probabilities.
+            For example: [ {.. "top_tokens": [ {"text": "a", 'logprob': }, {"text": "b", 'logprob': } ....]},
+                          {.. "top_tokens": [ {"text": "c", 'logprob': }, {"text": "d", 'logprob': } ....]} ]
+
+        input_tokens (int): number of input tokens to the model.
+        output_tokens (int): number of output tokens to the model.
+        model_name (str): the model_name as kept in the InferenceEngine.
+        inference_type (str): The label stating the type of the InferenceEngine.
+    """
+
+    prediction: Union[str, List[Dict[str, Any]]]
+    input_tokens: Optional[int] = None
+    output_tokens: Optional[int] = None
+    model_name: Optional[str] = None
+    inference_type: Optional[str] = None
+
+
 class InferenceEngine(abc.ABC, Artifact):
     """Abstract base class for inference."""

     @abc.abstractmethod
+    def _infer(
+        self,
+        dataset: Union[List[Dict[str, Any]], DatasetDict],
+        return_meta_data: bool = False,
+    ) -> Union[List[str], List[TextGenerationInferenceOutput]]:
+        """Perform inference on the input dataset.
+
+        If return_meta_data - returns a list of TextGenerationInferenceOutput, else returns a list of the string
+        predictions.
+        return_meta_data is only supported for some InferenceEngines.
+        """
         pass

     @abc.abstractmethod

@@ -33,12 +75,29 @@ class InferenceEngine(abc.ABC, Artifact):
         if not settings.mock_inference_mode:
             self.prepare_engine()

+    def infer(
+        self,
+        dataset: Union[List[Dict[str, Any]], DatasetDict],
+        return_meta_data: bool = False,
+    ) -> Union[List[str], List[TextGenerationInferenceOutput]]:
+        """Verifies instances of a dataset and perform inference on the input dataset.
+
+        If return_meta_data - returns a list of TextGenerationInferenceOutput, else returns a list of the string
+        predictions.
+        """
+        if return_meta_data and not hasattr(self, "get_return_object"):
+            raise NotImplementedError(
+                f"Inference engine {self.__class__.__name__} does not support return_meta_data as it "
+                f"does not contain a 'get_return_object' method. Please set return_meta_data=False."
+            )
+
         [self.verify_instance(instance) for instance in dataset]
         if settings.mock_inference_mode:
             return [instance["source"] for instance in dataset]
-        return self._infer(dataset)
+        return self._infer(dataset, return_meta_data)
+
+    def get_engine_id(self):
+        raise NotImplementedError()

     @deprecation(version="2.0.0")
     def _set_inference_parameters(self):

@@ -62,19 +121,39 @@ class LogProbInferenceEngine(abc.ABC, Artifact):
     """Abstract base class for inference with log probs."""

     @abc.abstractmethod
+    def _infer_log_probs(
+        self,
+        dataset: Union[List[Dict[str, Any]], DatasetDict],
+        return_meta_data: bool = False,
+    ) -> Union[List[Dict], List[TextGenerationInferenceOutput]]:
+        """Perform inference on the input dataset that returns log probs.
+
+        If return_meta_data - returns a list of TextGenerationInferenceOutput, else returns a list of the logprob dicts.
+        return_meta_data is only supported for some InferenceEngines.
+        """
         pass

+    def infer_log_probs(
+        self,
+        dataset: Union[List[Dict[str, Any]], DatasetDict],
+        return_meta_data: bool = False,
+    ) -> Union[List[Dict], List[TextGenerationInferenceOutput]]:
         """Verifies instances of a dataset and performs inference that returns log probabilities of top tokens.

-        For each instance ,
+        For each instance , generates a list of top tokens per position.
         [ "top_tokens": [ { "text": ..., "logprob": ...} , ... ]
+        If return_meta_data - returns a list of TextGenerationInferenceOutput, else returns the list of the logprob dicts.
+        return_meta_data is only supported for some InferenceEngines.
         """
+        if return_meta_data and not hasattr(self, "get_return_object"):
+            raise NotImplementedError(
+                f"Inference engine {self.__class__.__name__} does not support return_meta_data as it "
+                f"does not contain a 'get_return_object' method. Please set return_meta_data=False."
+            )
+
         [self.verify_instance(instance) for instance in dataset]
-        return self._infer_log_probs(dataset)
+        return self._infer_log_probs(dataset, return_meta_data)


 class LazyLoadMixin(Artifact):

@@ -96,6 +175,9 @@ class HFPipelineBasedInferenceEngine(
         "transformers": "Install huggingface package using 'pip install --upgrade transformers"
     }

+    def get_engine_id(self):
+        return get_model_and_label_id(self.model_name, "hf_pipeline")
+
     def _prepare_pipeline(self):
         import torch
         from transformers import AutoConfig, pipeline

@@ -143,7 +225,11 @@ class HFPipelineBasedInferenceEngine(
     def _is_loaded(self):
         return hasattr(self, "model") and self.model is not None

+    def _infer(
+        self,
+        dataset: Union[List[Dict[str, Any]], DatasetDict],
+        return_meta_data: bool = False,
+    ) -> Union[List[str], List[TextGenerationInferenceOutput]]:
         if not self._is_loaded():
             self._prepare_pipeline()

@@ -157,12 +243,20 @@ class HFPipelineBasedInferenceEngine(

 class MockInferenceEngine(InferenceEngine):
     model_name: str
+    default_inference_value: str = "[[10]]"
+
+    def get_engine_id(self):
+        return get_model_and_label_id(self.model_name, "mock")

     def prepare_engine(self):
         return

+    def _infer(
+        self,
+        dataset: Union[List[Dict[str, Any]], DatasetDict],
+        return_meta_data: bool = False,
+    ) -> Union[List[str], List[TextGenerationInferenceOutput]]:
+        return [self.default_inference_value for instance in dataset]


 class MockModeMixin(Artifact):

@@ -226,7 +320,14 @@ class GenericInferenceEngine(InferenceEngine):
             engine_reference = self.default
         self.engine, _ = fetch_artifact(engine_reference)

+    def get_engine_id(self):
+        return "generic_inference_engine"
+
+    def _infer(
+        self,
+        dataset: Union[List[Dict[str, Any]], DatasetDict],
+        return_meta_data: bool = False,
+    ) -> Union[List[str], List[TextGenerationInferenceOutput]]:
         return self.engine._infer(dataset)

@@ -238,10 +339,17 @@ class OllamaInferenceEngine(InferenceEngine, PackageRequirementsMixin):
     }
     data_classification_policy = ["public", "proprietary"]

+    def get_engine_id(self):
+        return get_model_and_label_id(self.model_name, self.label)
+
     def prepare_engine(self):
         pass

+    def _infer(
+        self,
+        dataset: Union[List[Dict[str, Any]], DatasetDict],
+        return_meta_data: bool = False,
+    ) -> Union[List[str], List[TextGenerationInferenceOutput]]:
         import ollama

         result = [

@@ -260,7 +368,10 @@ class OllamaInferenceEngine(InferenceEngine, PackageRequirementsMixin):


 class IbmGenAiInferenceEngine(
+    InferenceEngine,
+    IbmGenAiInferenceEngineParamsMixin,
+    PackageRequirementsMixin,
+    LogProbInferenceEngine,
 ):
     label: str = "ibm_genai"
     model_name: str

@@ -270,6 +381,9 @@ class IbmGenAiInferenceEngine(
     data_classification_policy = ["public", "proprietary"]
     parameters: Optional[IbmGenAiInferenceEngineParams] = None

+    def get_engine_id(self):
+        return get_model_and_label_id(self.model_name, self.label)
+
     def prepare_engine(self):
         from genai import Client, Credentials

@@ -285,21 +399,88 @@ class IbmGenAiInferenceEngine(

         self._set_inference_parameters()

+    def _infer(
+        self,
+        dataset: Union[List[Dict[str, Any]], DatasetDict],
+        return_meta_data: bool = False,
+    ) -> Union[List[str], List[TextGenerationInferenceOutput]]:
         from genai.schema import TextGenerationParameters

         genai_params = TextGenerationParameters(
             **self.to_dict([IbmGenAiInferenceEngineParamsMixin])
         )

+        results = []
+        responses = self.client.text.generation.create(
+            model_id=self.model_name,
+            inputs=[instance["source"] for instance in dataset],
+            parameters=genai_params,
+        )
+        for response in responses:
+            generated_text = response.results[0].generated_text
+            result = self.get_return_object(
+                generated_text, response.results[0], return_meta_data
             )
+            results.append(result)
+        return results
+
+    def _infer_log_probs(
+        self,
+        dataset: Union[List[Dict[str, Any]], DatasetDict],
+        return_meta_data: bool = False,
+    ) -> Union[List[Dict], List[TextGenerationInferenceOutput]]:
+        from genai.schema import TextGenerationParameters
+
+        logprobs_return_options = {
+            "generated_tokens": True,
+            "input_text": False,
+            "input_tokens": False,
+            "token_logprobs": True,
+            "token_ranks": True,
+            "top_n_tokens": 5,
+        }
+        genai_params = self.to_dict(
+            [IbmGenAiInferenceEngineParamsMixin], keep_empty=False
+        )
+        genai_params = {**genai_params, "return_options": logprobs_return_options}
+        genai_params = TextGenerationParameters(**genai_params)
+        predictions = self.client.text.generation.create(
+            model_id=self.model_name,
+            inputs=[instance["source"] for instance in dataset],
+            parameters=genai_params,
+        )
+
+        predict_results = []
+        for prediction in predictions:
+            result = prediction.results[0]
+            assert isinstance(
+                result.generated_tokens, list
+            ), "result.generated_tokens should be a list"
+
+            predict_result = []
+            for base_token in result.generated_tokens:
+                res = {**base_token.__dict__, **base_token.model_extra}
+                res["top_tokens"] = [
+                    {"logprob": top_token.logprob, "text": top_token.text}
+                    for top_token in res["top_tokens"]
+                ]
+                predict_result.append(res)
+            final_results = self.get_return_object(
+                predict_result, result, return_meta_data
+            )
+            predict_results.append(final_results)
+        return predict_results
+
+    def get_return_object(self, predict_result, result, return_meta_data):
+        if return_meta_data:
+            return TextGenerationInferenceOutput(
+                prediction=predict_result,
+                input_tokens=result.input_token_count,
+                output_tokens=result.generated_token_count,
+                model_name=self.model_name,
+                inference_type=self.label,
+            )
+        return predict_result


 class OpenAiInferenceEngineParamsMixin(Artifact):

@@ -349,18 +530,29 @@ class OpenAiInferenceEngine(
     data_classification_policy = ["public"]
     parameters: Optional[OpenAiInferenceEngineParams] = None

+    def get_engine_id(self):
+        return get_model_and_label_id(self.model_name, self.label)

+    @classmethod
+    def get_api_param(cls, inference_engine: str, api_param_env_var_name: str):
+        api_key = os.environ.get(api_param_env_var_name)
         assert api_key is not None, (
+            f"Error while trying to run {inference_engine}."
+            f" Please set the environment param '{api_param_env_var_name}'."
         )
+        return api_key

+    def create_client(self):
+        from openai import OpenAI

+        api_key = self.get_api_param(
+            inference_engine="OpenAiInferenceEngine",
+            api_param_env_var_name="OPENAI_API_KEY",
+        )
+        return OpenAI(api_key=api_key)
+
+    def prepare_engine(self):
+        self.client = self.create_client()
         self._set_inference_parameters()

     def _get_completion_kwargs(self):

@@ -370,7 +562,11 @@ class OpenAiInferenceEngine(
             if v is not None
         }

+    def _infer(
+        self,
+        dataset: Union[List[Dict[str, Any]], DatasetDict],
+        return_meta_data: bool = False,
+    ) -> Union[List[str], List[TextGenerationInferenceOutput]]:
         outputs = []
         for instance in tqdm(dataset, desc="Inferring with openAI API"):
             response = self.client.chat.completions.create(

@@ -387,13 +583,18 @@ class OpenAiInferenceEngine(
                 model=self.model_name,
                 **self._get_completion_kwargs(),
             )
+            prediction = response.choices[0].message.content
+            output = self.get_return_object(prediction, response, return_meta_data)

             outputs.append(output)

         return outputs

+    def _infer_log_probs(
+        self,
+        dataset: Union[List[Dict[str, Any]], DatasetDict],
+        return_meta_data: bool = False,
+    ) -> Union[List[Dict], List[TextGenerationInferenceOutput]]:
         outputs = []
         for instance in tqdm(dataset, desc="Inferring with openAI API"):
             response = self.client.chat.completions.create(

@@ -411,7 +612,7 @@ class OpenAiInferenceEngine(
                 **self._get_completion_kwargs(),
             )
             top_logprobs_response = response.choices[0].logprobs.content
+            pred_output = [
                 {
                     "top_tokens": [
                         {"text": obj.token, "logprob": obj.logprob}

@@ -420,9 +621,21 @@ class OpenAiInferenceEngine(
                 }
                 for generated_token in top_logprobs_response
             ]
+            output = self.get_return_object(pred_output, response, return_meta_data)
             outputs.append(output)
         return outputs

+    def get_return_object(self, predict_result, response, return_meta_data):
+        if return_meta_data:
+            return TextGenerationInferenceOutput(
+                prediction=predict_result,
+                input_tokens=response.usage.prompt_tokens,
+                output_tokens=response.usage.completion_tokens,
+                model_name=self.model_name,
+                inference_type=self.label,
+            )
+        return predict_result
+

 class TogetherAiInferenceEngineParamsMixin(Artifact):
     max_tokens: Optional[int] = None

@@ -450,6 +663,9 @@ class TogetherAiInferenceEngine(
     data_classification_policy = ["public"]
     parameters: Optional[TogetherAiInferenceEngineParamsMixin] = None

+    def get_engine_id(self):
+        return get_model_and_label_id(self.model_name, self.label)
+
     def prepare_engine(self):
         from together import Together
         from together.types.models import ModelType

@@ -501,7 +717,11 @@ class TogetherAiInferenceEngine(
         )
         return response.choices[0].text

+    def _infer(
+        self,
+        dataset: Union[List[Dict[str, Any]], DatasetDict],
+        return_meta_data: bool = False,
+    ) -> Union[List[str], List[TextGenerationInferenceOutput]]:
         from together.types.models import ModelType

         outputs = []

@@ -514,6 +734,23 @@ class TogetherAiInferenceEngine(
         return outputs


+class VLLMRemoteInferenceEngine(OpenAiInferenceEngine):
+    label: str = "vllm"
+
+    def create_client(self):
+        from openai import OpenAI
+
+        api_key = self.get_api_param(
+            inference_engine="VLLMRemoteInferenceEngine",
+            api_param_env_var_name="VLLM_API_KEY",
+        )
+        api_url = self.get_api_param(
+            inference_engine="VLLMRemoteInferenceEngine",
+            api_param_env_var_name="VLLM_API_URL",
+        )
+        return OpenAI(api_key=api_key, base_url=api_url)
+
+
 class WMLInferenceEngineParamsMixin(Artifact):
     decoding_method: Optional[Literal["greedy", "sample"]] = None
     length_penalty: Optional[Dict[str, Union[int, float]]] = None

@@ -550,7 +787,10 @@ class WMLInferenceEngineParams(Artifact):


 class WMLInferenceEngine(
+    InferenceEngine,
+    WMLInferenceEngineParamsMixin,
+    PackageRequirementsMixin,
+    LogProbInferenceEngine,
 ):
     """Runs inference using ibm-watsonx-ai.

@@ -604,14 +844,17 @@ class WMLInferenceEngine(
     concurrency_limit: int = 10
     _client: Any = InternalField(default=None, name="WML client")

+    def get_engine_id(self):
+        return get_model_and_label_id(self.model_name, self.label)
+
     def verify(self):
         super().verify()

         if self.credentials is not None:
             for key in self.credentials:
-                if key not in ["url", "apikey", "project_id"]:
+                if key not in ["url", "apikey", "project_id", "space_id"]:
                     raise ValueError(
-                        f'Illegal credential key: {key}, use only ["url", "apikey", "project_id"]'
+                        f'Illegal credential key: {key}, use only ["url", "apikey", "project_id", "space_id"]'
                     )

         assert (

@@ -631,10 +874,14 @@ class WMLInferenceEngine(

     @staticmethod
     def _read_wml_credentials_from_env() -> (
-        Dict[Literal["url", "apikey", "project_id"], str]
+        Dict[Literal["url", "apikey", "project_id", "space_id"], str]
     ):
         credentials = {}
+        project_or_deployment_var_name = (
+            "WML_SPACE_ID" if "WML_SPACE_ID" in os.environ else "WML_PROJECT_ID"
+        )
+
+        for env_var_name in ["WML_URL", project_or_deployment_var_name, "WML_APIKEY"]:
             env_var = os.environ.get(env_var_name)
             assert env_var, (
                 f"Error while trying to run 'WMLInferenceEngine'. "

@@ -655,7 +902,10 @@ class WMLInferenceEngine(
             self.credentials = self._read_wml_credentials_from_env()

         client = APIClient(credentials=self.credentials)
+        if "space_id" in self.credentials:
+            client.set.default_space(self.credentials["space_id"])
+        else:
+            client.set.default_project(self.credentials["project_id"])
         return client

     def prepare_engine(self):

@@ -663,7 +913,7 @@ class WMLInferenceEngine(

         self._set_inference_parameters()

+    def _load_model_and_params(self):
         from ibm_watsonx_ai.foundation_models import ModelInference

         model = ModelInference(

@@ -671,20 +921,81 @@ class WMLInferenceEngine(
             deployment_id=self.deployment_id,
             api_client=self._client,
         )
+        params = self.to_dict([WMLInferenceEngineParamsMixin], keep_empty=False)

-        dataset = dataset if isinstance(dataset, list) else [dataset]
+        return model, params
+
+    def _infer(
+        self,
+        dataset: Union[List[Dict[str, Any]], DatasetDict],
+        return_meta_data: bool = False,
+    ) -> Union[List[str], List[TextGenerationInferenceOutput]]:
+        model, params = self._load_model_and_params()
+
+        result = []
+        for instance in dataset:
+            instance_result = model.generate(
                 prompt=instance["source"],
                 params=self.to_dict([WMLInferenceEngineParamsMixin], keep_empty=False),
             )
+            prediction = instance_result["results"][0]["generated_text"]
+            instance_final_results = self.get_return_object(
+                prediction, instance_result, return_meta_data
+            )
+            result.append(instance_final_results)
+
+        return result
+
+    def _infer_log_probs(
+        self,
+        dataset: Union[List[Dict[str, Any]], DatasetDict],
+        return_meta_data: bool = False,
+    ) -> Union[List[Dict], List[TextGenerationInferenceOutput]]:
+        model, params = self._load_model_and_params()
+
+        user_return_options = params.pop("return_options", {})
+        # currently this is the only configuration that returns generated logprobs and behaves as expected
+        logprobs_return_options = {
+            "input_tokens": True,
+            "generated_tokens": True,
+            "token_logprobs": True,
+            "top_n_tokens": user_return_options.get("top_n_tokens", 5),
+        }
+        for key, value in logprobs_return_options.items():
+            if key in user_return_options and user_return_options[key] != value:
+                raise ValueError(
+                    f"'{key}={user_return_options[key]}' is not supported for the 'infer_log_probs' "
+                    f"method of {self.__class__.__name__}. For obtaining the logprobs of generated tokens "
+                    f"please use '{key}={value}'."
+                )
+
+        params = {
+            **params,
+            "return_options": logprobs_return_options,
+        }

+        results = model.generate(
+            prompt=[instance["source"] for instance in dataset],
+            params=params,
+        )
+        final_results = []
+        for result in results:
+            generated_tokens = result["results"][0]["generated_tokens"]
+            final_results.append(
+                self.get_return_object(generated_tokens, result, return_meta_data)
+            )
+        return final_results
+
+    def get_return_object(self, predict_result, result, return_meta_data):
+        if return_meta_data:
+            return TextGenerationInferenceOutput(
+                prediction=predict_result,
+                input_tokens=result["results"][0]["input_token_count"],
+                output_tokens=result["results"][0]["generated_token_count"],
+                model_name=self.model_name,
+                inference_type=self.label,
+            )
+        return predict_result


 class HFLlavaInferenceEngine(InferenceEngine, LazyLoadMixin):

@@ -698,6 +1009,9 @@ class HFLlavaInferenceEngine(InferenceEngine, LazyLoadMixin):
         "accelerate": "pip install accelerate",
     }

+    def get_engine_id(self):
+        return get_model_and_label_id(self.model_name, "hf_lava")
+
     def _prepare_engine(self):
         import torch
         from transformers import AutoProcessor, LlavaForConditionalGeneration

@@ -725,14 +1039,18 @@ class HFLlavaInferenceEngine(InferenceEngine, LazyLoadMixin):
     def _is_loaded(self):
         return hasattr(self, "model") and self.model is not None

+    def _infer(
+        self,
+        dataset: Union[List[Dict[str, Any]], DatasetDict],
+        return_meta_data: bool = False,
+    ) -> Union[List[str], List[TextGenerationInferenceOutput]]:
         if not self._is_loaded():
             self._prepare_engine()

         import torch

         results = []
-        for instance in dataset:
+        for instance in tqdm(dataset):
             text = instance["source"]
             images = extract_images(text, instance)
             # Regular expression to match all <img src="..."> tags

@@ -745,7 +1063,10 @@ class HFLlavaInferenceEngine(InferenceEngine, LazyLoadMixin):
         ).to(self.device, torch.float16)
         input_len = len(inputs["input_ids"][0])
         output = self.model.generate(
             **inputs,
+            max_new_tokens=self.max_new_tokens,
+            do_sample=False,
+            pad_token_id=self.processor.tokenizer.eos_token_id,
         )
         result = self.processor.decode(
             output[0][input_len:], skip_special_tokens=True
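A sketch of how the reworked engines are meant to be driven after this change; the engine choice and model name are illustrative, OPENAI_API_KEY is read from the environment as shown in the diff, and dataset is assumed to be a list of instances produced by a Unitxt recipe:

from unitxt.inference import OpenAiInferenceEngine, TextGenerationInferenceOutput

engine = OpenAiInferenceEngine(model_name="gpt-4o-mini")     # illustrative model name

texts = engine.infer(dataset)                                # List[str]
detailed = engine.infer(dataset, return_meta_data=True)      # List[TextGenerationInferenceOutput]
logprobs = engine.infer_log_probs(dataset)                   # per-token dicts with a "top_tokens" list

print(detailed[0].input_tokens, detailed[0].output_tokens, detailed[0].inference_type)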
llm_as_judge.py
CHANGED
|
@@ -1,10 +1,11 @@
|
|
|
|
|
| 1 |
from typing import Any, Dict, List, Literal, Optional
|
| 2 |
|
| 3 |
from .api import infer
|
| 4 |
from .artifact import fetch_artifact
|
| 5 |
from .dataclass import Field
|
| 6 |
from .formats import Format, SystemFormat
|
| 7 |
-
from .inference import InferenceEngine, OpenAiInferenceEngine
|
| 8 |
from .metrics import BulkInstanceMetric
|
| 9 |
from .operator import SequentialOperator
|
| 10 |
from .settings_utils import get_settings
|
|
@@ -14,38 +15,142 @@ from .templates import Template
|
|
| 14 |
settings = get_settings()
|
| 15 |
|
| 16 |
|
| 17 |
-
|
| 18 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 19 |
|
| 20 |
Attributes:
|
| 21 |
main_score (str): The main score label used for evaluation.
|
| 22 |
-
task (
|
| 23 |
format of the judge model.
|
| 24 |
template (Template): The template used when generating inputs for the judge llm.
|
| 25 |
format (Format): The format used when generating inputs for judge llm.
|
| 26 |
system_prompt (SystemPrompt): The system prompt used when generating inputs for judge llm.
|
| 27 |
-
strip_system_prompt_and_format_from_inputs (bool): Whether to strip the system prompt and formatting from the
|
| 28 |
-
inputs that the models that is being judges received, when they are inserted to the llm-as-judge prompt.
|
| 29 |
inference_model (InferenceEngine): The module that creates the inference of the judge llm.
|
| 30 |
reduction_map (dict): A dictionary specifying the reduction method for the metric.
|
| 31 |
batch_size (int): The size of the bulk.
|
| 32 |
"""
|
| 33 |
|
| 34 |
main_score: str = "llm_as_judge"
|
| 35 |
-
task:
|
| 36 |
-
"rating.single_turn",
|
| 37 |
-
"rating.single_turn_with_reference",
|
| 38 |
-
"pairwise_comparative_rating.single_turn",
|
| 39 |
-
]
|
| 40 |
template: Template
|
| 41 |
system_prompt: SystemPrompt = Field(default_factory=EmptySystemPrompt)
|
| 42 |
format: Format = Field(default_factory=SystemFormat)
|
| 43 |
-
strip_system_prompt_and_format_from_inputs: bool = True
|
| 44 |
inference_model: InferenceEngine
|
| 45 |
reduction_map: Optional[Dict[str, List[str]]] = None
|
| 46 |
batch_size: int = 32
|
| 47 |
prediction_type = Any # Because handled with multiple tasks
|
| 48 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 49 |

    def _get_input_instances(self, task_data: List[Dict]) -> List:
        if self.strip_system_prompt_and_format_from_inputs:
            instances = []

@@ -119,6 +224,7 @@ class LLMAsJudge(BulkInstanceMetric):
        self.reduction_map = {"mean": [self.main_score]}

    def verify(self):
        supported_tasks = [
            "rating.single_turn",
            "rating.single_turn_with_reference",

@@ -129,68 +235,25 @@ class LLMAsJudge(BulkInstanceMetric):
                f"The supported tasks types are: {', '.join(supported_tasks)}."
            )

-        if not isinstance(self.template, Template):
-            raise ValueError(
-                f"Provided template argument to 'LLMAsJudge' metric is not of type Template, but {type(self.template)}"
-            )
-        if self.format and not isinstance(self.format, Format):
-            raise ValueError(
-                f"Provided format argument to 'LLMAsJudge' metric is not of type Format, but {type(self.format)}"
-            )
-
-        if self.system_prompt and not isinstance(self.system_prompt, SystemPrompt):
-            raise ValueError(
-                f"Provided system_prompt argument to 'LLMAsJudge' metric is not of type SystemPrompt, but {type(self.system_prompt)}"
-            )
-
-        if isinstance(self.inference_model, OpenAiInferenceEngine):
-            if self.format and type(self.format) is not SystemFormat:
-                raise ValueError(
-                    "Error in 'LLMAsJudge' metric. Inference model 'OpenAiInferenceEngine' does "
-                    "not support formatting. Please remove the format definition from the recipe"
-                    " (OpenAi Chat API take care of the formatting automatically)."
-                )
-            if self.system_prompt and type(self.system_prompt) is not EmptySystemPrompt:
-                raise ValueError(
-                    "Error in 'LLMAsJudge' metric. Inference model 'OpenAiInferenceEngine' does "
-                    "not support system prompt. Please remove the system_prompt definition from the recipe"
-                    " (Current implementation of Unitxt does not support this."
-                    " Support will be added in future updates)."
-                )

-    def compute(
-        self,
-        references: List[List[Any]],
-        predictions: List[Any],
-        task_data: List[Dict],
-    ) -> List[Dict[str, Any]]:
-        input_instances = self._get_input_instances(task_data)
-        instances = self._get_instance_for_judge_model(
-            input_instances, predictions, references
-        )
-        outputs = infer(
            instances,
            engine=self.inference_model,
-            task=
            template=self.template,
            system_prompt=self.system_prompt,
            format=self.format,
            return_data=True,
        )

        results = []
        for instance in outputs:
            if self.task == "pairwise_comparative_rating.single_turn":
-                # seems like the task data sometimes comes as a string, not a dict
-                # this fixes it
-                task_data = (
-                    json.loads(instance["task_data"])
-                    if isinstance(instance["task_data"], str)
-                    else instance["task_data"]
-                )
                is_model_b_the_baseline = task_data["model_b"] == "baseline_model"
                if is_model_b_the_baseline:
                    model_a_preference_score = instance["prediction"]

@@ -209,5 +272,141 @@ class LLMAsJudge(BulkInstanceMetric):
                "judge_raw_input": instance["source"],
            }
            results.append(result)

        return results
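
The hunks above strip LLMAsJudge of its inline validation and its monolithic compute(); as the new listing below shows, that logic moves to a new LLMAsJudgeBase whose compute() simply chains three hooks: prepare_instances, infer_instances and get_metric_results_from_prediction_outputs. A toy sketch of that template-method flow follows; the class name, task name and fixed rating are invented, and a real subclass would call infer() with a template and an InferenceEngine.

    # Hypothetical subclass; assumes LLMAsJudgeBase is importable from this module.
    from unitxt.llm_as_judge import LLMAsJudgeBase

    class ToyJudge(LLMAsJudgeBase):
        def get_full_task_name(self):
            return "tasks.toy_judging_task"  # placeholder task name

        def prepare_instances(self, references, predictions, task_data):
            # one judge input per prediction
            return [{"answer": str(p), **td} for p, td in zip(predictions, task_data)]

        def infer_instances(self, instances):
            # stand-in for infer(...): pretend the judge always answers 1.0
            return [{**instance, "prediction": 1.0} for instance in instances]

        def get_metric_results_from_prediction_outputs(self, outputs):
            return [{self.main_score: output["prediction"]} for output in outputs]

compute() then runs prepare_instances, infer_instances and get_metric_results_from_prediction_outputs in that order, exactly as in the base class added by this commit.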
+from abc import abstractmethod
from typing import Any, Dict, List, Literal, Optional

from .api import infer
from .artifact import fetch_artifact
from .dataclass import Field
from .formats import Format, SystemFormat
+from .inference import InferenceEngine, LogProbInferenceEngine, OpenAiInferenceEngine
from .metrics import BulkInstanceMetric
from .operator import SequentialOperator
from .settings_utils import get_settings

settings = get_settings()


+def get_task_data_dict(task_data):
+    import json
+
+    # seems like the task data sometimes comes as a string, not a dict
+    # this fixes it
+    return json.loads(task_data) if isinstance(task_data, str) else task_data
+
+
+class LLMAsJudgeBase(BulkInstanceMetric):
+    """LLM-as-judge-base metric class for evaluating correctness of generated predictions.

    Attributes:
        main_score (str): The main score label used for evaluation.
+        task (str): The type of task the llm as judge runs. This defines the output and input
            format of the judge model.
        template (Template): The template used when generating inputs for the judge llm.
        format (Format): The format used when generating inputs for judge llm.
        system_prompt (SystemPrompt): The system prompt used when generating inputs for judge llm.
        inference_model (InferenceEngine): The module that creates the inference of the judge llm.
        reduction_map (dict): A dictionary specifying the reduction method for the metric.
        batch_size (int): The size of the bulk.
    """

    main_score: str = "llm_as_judge"
+    task: str
    template: Template
    system_prompt: SystemPrompt = Field(default_factory=EmptySystemPrompt)
    format: Format = Field(default_factory=SystemFormat)
    inference_model: InferenceEngine
    reduction_map: Optional[Dict[str, List[str]]] = None
    batch_size: int = 32
    prediction_type = Any  # Because handled with multiple tasks

+    def verify(self):
+        if not isinstance(self.template, Template):
+            raise ValueError(
+                f"Provided template argument to 'LLMAsJudge' metric is not of type Template, but {type(self.template)}"
+            )
+        if self.format and not isinstance(self.format, Format):
+            raise ValueError(
+                f"Provided format argument to 'LLMAsJudge' metric is not of type Format, but {type(self.format)}"
+            )
+
+        if self.system_prompt and not isinstance(self.system_prompt, SystemPrompt):
+            raise ValueError(
+                f"Provided system_prompt argument to 'LLMAsJudge' metric is not of type SystemPrompt, but {type(self.system_prompt)}"
+            )
+
+        if isinstance(self.inference_model, OpenAiInferenceEngine):
+            if self.format and type(self.format) is not SystemFormat:
+                raise ValueError(
+                    "Error in 'LLMAsJudge' metric. Inference model 'OpenAiInferenceEngine' does "
+                    "not support formatting. Please remove the format definition from the recipe"
+                    " (OpenAi Chat API take care of the formatting automatically)."
+                )
+            if self.system_prompt and type(self.system_prompt) is not EmptySystemPrompt:
+                raise ValueError(
+                    "Error in 'LLMAsJudge' metric. Inference model 'OpenAiInferenceEngine' does "
+                    "not support system prompt. Please remove the system_prompt definition from the recipe"
+                    " (Current implementation of Unitxt does not support this."
+                    " Support will be added in future updates)."
+                )
+
+    @abstractmethod
+    def get_full_task_name(self):
+        pass
+
+    def compute(
+        self,
+        references: List[List[Any]],
+        predictions: List[Any],
+        task_data: List[Dict],
+    ) -> List[Dict[str, Any]]:
+        instances = self.prepare_instances(references, predictions, task_data)
+        outputs = self.infer_instances(instances)
+        return self.get_metric_results_from_prediction_outputs(outputs)
+
+    @abstractmethod
+    def prepare_instances(
+        self, references, predictions, task_data
+    ) -> List[Dict[str, Any]]:
+        """Generate a list of instances for inference.
+
+        Each generated instance should include all the fields required by the metrics' task and template, to
+        create the source prompt for the judge.
+        """
+        pass
+
+    @abstractmethod
+    def infer_instances(self, instances: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
+        """Generate the dataset and call the inference engine to generate the judges' predictions.
+
+        Return the list of the produced instances with their generated judge predictions.
+        """
+        pass
+
+    @abstractmethod
+    def get_metric_results_from_prediction_outputs(
+        self, outputs: List[Dict[str, Any]]
+    ) -> List[Dict[str, Any]]:
+        """Generate a scores' dictionary for each instance.
+
+        Return the list of scores dictionaries for the input instances.
+        """
+        pass
+
+
+class LLMAsJudge(LLMAsJudgeBase):
+    """LLM-as-judge-based metric class for evaluating correctness of generated predictions.
+
+    This class uses the source prompt given to the generator and the generator's predictions to evaluate
+    correctness using one of three supported tasks (rating.single_turn, rating.single_turn_with_reference,
+    pairwise_comparative_rating.single_turn).
+
+    Attributes:
+        main_score (str): The main score label used for evaluation.
+        task (Literal["rating.single_turn", "rating.single_turn_with_reference",
+            "pairwise_comparative_rating.single_turn"]): The type of task the llm as judge runs.
+            This defines the output and input format of the judge model.
+        template (Template): The template used when generating inputs for the judge llm.
+        format (Format): The format used when generating inputs for judge llm.
+        system_prompt (SystemPrompt): The system prompt used when generating inputs for judge llm.
+        strip_system_prompt_and_format_from_inputs (bool): Whether to strip the system prompt and formatting from the
+            inputs that the model being judged received, when they are inserted into the llm-as-judge prompt.
+        inference_model (InferenceEngine): The module that creates the inference of the judge llm.
+        reduction_map (dict): A dictionary specifying the reduction method for the metric.
+        batch_size (int): The size of the bulk.
+    """
+
+    task: Literal[
+        "rating.single_turn",
+        "rating.single_turn_with_reference",
+        "pairwise_comparative_rating.single_turn",
+    ]
+    strip_system_prompt_and_format_from_inputs: bool = True
+
    def _get_input_instances(self, task_data: List[Dict]) -> List:
        if self.strip_system_prompt_and_format_from_inputs:
            instances = []
        self.reduction_map = {"mean": [self.main_score]}

    def verify(self):
+        super().verify()
        supported_tasks = [
            "rating.single_turn",
            "rating.single_turn_with_reference",

            f"The supported tasks types are: {', '.join(supported_tasks)}."
        )

+    def get_full_task_name(self):
+        return f"tasks.response_assessment.{self.task}"
+
+    def infer_instances(self, instances):
+        return infer(
            instances,
            engine=self.inference_model,
+            task=self.get_full_task_name(),
            template=self.template,
            system_prompt=self.system_prompt,
            format=self.format,
            return_data=True,
        )

+    def get_metric_results_from_prediction_outputs(self, outputs):
        results = []
        for instance in outputs:
            if self.task == "pairwise_comparative_rating.single_turn":
+                task_data = get_task_data_dict(instance["task_data"])
                is_model_b_the_baseline = task_data["model_b"] == "baseline_model"
                if is_model_b_the_baseline:
                    model_a_preference_score = instance["prediction"]

                "judge_raw_input": instance["source"],
            }
            results.append(result)
+        return results
+
+    def prepare_instances(self, references, predictions, task_data):
+        input_instances = self._get_input_instances(task_data)
+        return self._get_instance_for_judge_model(
+            input_instances, predictions, references
+        )

+
+class TaskBasedLLMasJudge(LLMAsJudgeBase):
+    """LLM-as-judge-based metric class for evaluating correctness of generated predictions.
+
+    This class can use any task and matching template to evaluate the predictions. All
+    task/template fields are taken from the instance's task_data.
+    The instances sent to the judge can either be: 1. a unitxt dataset, in which case the predictions are
+    copied to a specified field of the task. 2. dictionaries with the fields required by the task and template.
+
+    Attributes:
+        main_score (str): The main score label used for evaluation.
+        task (str): The type of task the llm as judge runs.
+            This defines the output and input format of the judge model.
+        template (Template): The template used when generating inputs for the judge llm.
+        format (Format): The format used when generating inputs for judge llm.
+        system_prompt (SystemPrompt): The system prompt used when generating inputs for judge llm.
+        strip_system_prompt_and_format_from_inputs (bool): Whether to strip the system prompt and formatting from the
+            inputs that the model being judged received, when they are inserted into the llm-as-judge prompt.
+        inference_model (InferenceEngine): The module that creates the inference of the judge llm.
+        reduction_map (dict): A dictionary specifying the reduction method for the metric.
+        batch_size (int): The size of the bulk.
+        infer_log_probs (bool): whether to perform the inference using logprobs. If true, the template's
+            post-processing must support the logprobs output.
+        judge_to_generator_fields_mapping (Dict[str, str]): optional mapping between the names of the fields in the generator task and the
+            judge task. For example, if the generator task uses "reference_answers" and the judge task expects "ground_truth",
+            include {"ground_truth": "reference_answers"} in this dictionary.
+        prediction_field: if indicated, and a prediction exists, copy the prediction to this field name in task_data.
+        include_meta_data (bool): whether to include the inference per-instance metadata in the returned results.
+
+    """
+
+    infer_log_probs: bool = False
+    judge_to_generator_fields_mapping: Dict[str, str] = {}
+    prediction_field: Optional[str] = None
+    include_meta_data: bool = True
+
+    # Allow for input which is a dictionary of all input fields. In this case, all input fields are
+    # treated as the task data, and the predictions and references are taken directly from there
+    # by the judge's template
+    def preprocess_instance(self, instance):
+        if "task_data" not in instance:
+            instance["task_data"] = instance.copy()
+        if "prediction" not in instance:
+            instance["prediction"] = None
+        if "references" not in instance:
+            instance["references"] = [""]
+        return instance
+
+    def verify(self):
+        super().verify()
+        if self.infer_log_probs and not isinstance(
+            self.inference_model, LogProbInferenceEngine
+        ):
+            raise NotImplementedError(
+                f"Error in TaskBasedLLMasJudge: return_log_probs set to True but supplied engine "
+                f"{self.inference_model.__class__.__name__} does not support logprobs."
+            )
+        if self.include_meta_data and not hasattr(
+            self.inference_model, "get_return_object"
+        ):
+            Warning(
+                f"Supplied inference engine {self.inference_model.__class__.__name__} does not support "
+                "return_meta_data. Setting return_meta_data to False. Metadata scores will not appear "
+                "in returned instances scores."
+            )
+            self.include_meta_data = False
+
+    def prepare(self):
+        super().prepare()
+        self.reduction_map = {"mean": [self.main_score]}
+        self.score_prefix = f"{self.inference_model.get_engine_id()}_"
+
+    def get_full_task_name(self):
+        return self.task
+
+    def get_metric_results_from_prediction_outputs(self, outputs):
+        results = []
+        for instance in outputs:
+            result = {
+                self.main_score: instance["prediction"],
+                f"{self.main_score}_judge_raw_output": instance["raw_prediction"],
+                f"{self.main_score}_judge_raw_input": instance["source"],
+            }
+            if self.include_meta_data:
+                meta_data = {
+                    f"{self.main_score}_{k}": v
+                    for k, v in instance["infer_meta_data"].items()
+                }
+                result.update(meta_data)
+            results.append(result)
        return results
+
+    def prepare_instances(self, references, predictions, task_data):
+        from . import get_from_catalog
+
+        instances = []
+        judge_task = get_from_catalog(self.get_full_task_name())
+        judge_task_input_fields = judge_task.input_fields
+
+        for input_instance, prediction, _ in zip(task_data, predictions, references):
+            input_instance = get_task_data_dict(input_instance)
+
+            instance_task_data = {}
+            for judge_task_input_field in judge_task_input_fields:
+                orig_task_field_name = self.judge_to_generator_fields_mapping.get(
+                    judge_task_input_field, judge_task_input_field
+                )
+                new_val = input_instance.get(orig_task_field_name)
+                if new_val:
+                    instance_task_data[judge_task_input_field] = new_val
+
+            if self.prediction_field and prediction:
+                instance_task_data[self.prediction_field] = str(prediction)
+            instance_task_data = judge_task.process(instance_task_data)["input_fields"]
+            instances.append(instance_task_data)
+
+        return instances
+
+    def infer_instances(self, instances):
+        return infer(
+            instances,
+            engine=self.inference_model,
+            task=self.get_full_task_name(),
+            template=self.template,
+            system_prompt=self.system_prompt,
+            format=self.format,
+            return_data=True,
+            return_log_probs=self.infer_log_probs,
+            return_meta_data=self.include_meta_data,
+        )
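
To make the judge_to_generator_fields_mapping behaviour of TaskBasedLLMasJudge concrete, here is a standalone sketch of the renaming step performed in prepare_instances; the field names and values are examples only, not catalog definitions.

    judge_to_generator_fields_mapping = {"ground_truth": "reference_answers"}
    judge_task_input_fields = ["question", "ground_truth", "answer"]

    generator_instance = {
        "question": "What is the capital of France?",
        "reference_answers": "Paris",
    }
    prediction = "Paris is the capital of France."

    # For every field the judge task expects, look it up under its (possibly
    # different) name in the generator's task data.
    instance_task_data = {}
    for judge_field in judge_task_input_fields:
        source_field = judge_to_generator_fields_mapping.get(judge_field, judge_field)
        value = generator_instance.get(source_field)
        if value:
            instance_task_data[judge_field] = value

    # With prediction_field="answer", the model prediction is copied in as well.
    instance_task_data["answer"] = str(prediction)
    print(instance_task_data)
    # {'question': 'What is the capital of France?', 'ground_truth': 'Paris',
    #  'answer': 'Paris is the capital of France.'}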
loaders.py
CHANGED

@@ -53,7 +53,7 @@ from .operators import Set
from .settings_utils import get_settings
from .stream import DynamicStream, MultiStream
from .type_utils import isoftype
-from .utils import

logger = get_logger()
settings = get_settings()

@@ -195,6 +195,10 @@ class LoadHF(Loader):
    def stream_dataset(self):
        if self._cache is None:
            with tempfile.TemporaryDirectory() as dir_to_be_deleted:
                try:
                    dataset = hf_load_dataset(
                        self.path,

@@ -203,7 +207,7 @@ class LoadHF(Loader):
                        data_files=self.data_files,
                        revision=self.revision,
                        streaming=self.streaming,
-                        cache_dir=
                        split=self.split,
                        trust_remote_code=settings.allow_unverified_code,
                        num_proc=self.num_proc,

@@ -231,6 +235,10 @@ class LoadHF(Loader):
    def load_dataset(self):
        if self._cache is None:
            with tempfile.TemporaryDirectory() as dir_to_be_deleted:
                try:
                    dataset = hf_load_dataset(
                        self.path,

@@ -239,7 +247,7 @@ class LoadHF(Loader):
                        data_files=self.data_files,
                        streaming=False,
                        keep_in_memory=True,
-                        cache_dir=
                        split=self.split,
                        trust_remote_code=settings.allow_unverified_code,
                        num_proc=self.num_proc,

@@ -664,7 +672,7 @@ class MultipleSourceLoader(Loader):

    .. code-block:: python

-        MultipleSourceLoader(


@@ -672,7 +680,7 @@ class MultipleSourceLoader(Loader):

    .. code-block:: python

-        MultipleSourceLoader(
    """

    sources: List[Loader]

@@ -737,7 +745,7 @@ class LoadFromDictionary(Loader):
        self.sef_default_data_classification(
            ["proprietary"], "when loading from python dictionary"
        )
-        return MultiStream.from_iterables(


class LoadFromHFSpace(LoadHF):
from .settings_utils import get_settings
from .stream import DynamicStream, MultiStream
from .type_utils import isoftype
+from .utils import recursive_copy

logger = get_logger()
settings = get_settings()

    def stream_dataset(self):
        if self._cache is None:
            with tempfile.TemporaryDirectory() as dir_to_be_deleted:
+                if settings.disable_hf_datasets_cache and not self.streaming:
+                    cache_dir = dir_to_be_deleted
+                else:
+                    cache_dir = None
                try:
                    dataset = hf_load_dataset(
                        self.path,

                        data_files=self.data_files,
                        revision=self.revision,
                        streaming=self.streaming,
+                        cache_dir=cache_dir,
                        split=self.split,
                        trust_remote_code=settings.allow_unverified_code,
                        num_proc=self.num_proc,

    def load_dataset(self):
        if self._cache is None:
            with tempfile.TemporaryDirectory() as dir_to_be_deleted:
+                if settings.disable_hf_datasets_cache:
+                    cache_dir = dir_to_be_deleted
+                else:
+                    cache_dir = None
                try:
                    dataset = hf_load_dataset(
                        self.path,

                        data_files=self.data_files,
                        streaming=False,
                        keep_in_memory=True,
+                        cache_dir=cache_dir,
                        split=self.split,
                        trust_remote_code=settings.allow_unverified_code,
                        num_proc=self.num_proc,

    .. code-block:: python

+        MultipleSourceLoader(sources = [ LoadHF(path="public/data",split="train"), LoadCSV({"test": "mytest.csv"}) ])


    .. code-block:: python

+        MultipleSourceLoader(sources = [ LoadCSV({"test": "mytest1.csv"}, LoadCSV({"test": "mytest2.csv"}) ])
    """

    sources: List[Loader]

        self.sef_default_data_classification(
            ["proprietary"], "when loading from python dictionary"
        )
+        return MultiStream.from_iterables(recursive_copy(self.data))


class LoadFromHFSpace(LoadHF):
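
The cache_dir logic added to LoadHF above is easy to read in isolation: when the disable_hf_datasets_cache setting is on, the HuggingFace datasets cache is redirected into a temporary directory that disappears when loading finishes; when streaming, no cache directory is used at all. A standalone sketch of the same decision:

    import tempfile

    def resolve_cache_dir(disable_hf_datasets_cache: bool, streaming: bool, tmp_dir: str):
        # Mirrors the branch added to LoadHF.stream_dataset above.
        if disable_hf_datasets_cache and not streaming:
            return tmp_dir
        return None

    with tempfile.TemporaryDirectory() as dir_to_be_deleted:
        print(resolve_cache_dir(True, False, dir_to_be_deleted))   # the throwaway dir
        print(resolve_cache_dir(True, True, dir_to_be_deleted))    # None (streaming)
        print(resolve_cache_dir(False, False, dir_to_be_deleted))  # None (normal caching)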
metric_utils.py
CHANGED

@@ -16,8 +16,8 @@ from .operator import (
from .operators import (
    ApplyMetric,
    ApplyOperatorsField,
-    Copy,
    FlattenInstances,
    Rename,
)
from .register import _reset_env_local_catalogs, register_all_artifacts

@@ -25,7 +25,7 @@ from .schema import UNITXT_DATASET_SCHEMA
from .settings_utils import get_constants, get_settings
from .stream import DynamicStream, MultiStream
from .struct_data_operators import LoadJson
-from .utils import

constants = get_constants()

@@ -54,27 +54,27 @@ class FromPredictionsAndOriginalData(StreamInitializerOperator):

_post_process_steps = SequentialOperator(
    steps=[
-        Copy(
            field="prediction",
            to_field="raw_prediction",
        ),
-        Copy(
            field="references",
            to_field="raw_references",
            dont_apply_to_streams=[constants.inference_stream],
        ),
-        Copy(
            field="source",
            to_field="task_data/source",
        ),
        ApplyOperatorsField(
            operators_field="postprocessors",
        ),
-        Copy(
            field="prediction",
            to_field="processed_prediction",
        ),
-        Copy(
            field="references",
            to_field="processed_references",
            dont_apply_to_streams=[constants.inference_stream],

@@ -213,14 +213,19 @@ class JoinSubsetsAndGroups(MultiStreamOperator):

        result = {}
        all_scores = []
        for k, v in dic.items():
            score = recursive_mean(v)
            if score is not None:
                all_scores.append(score["score"])
                result[k] = score

        result["score"] = nan_mean(all_scores)
        result["score_name"] = "subsets_mean"

        if result:
            return result

@@ -237,11 +242,15 @@ class JoinSubsetsAndGroups(MultiStreamOperator):
            "score": score["subsets"]["score"],
            "score_name": score["subsets"]["score_name"],
        }

        sorted_instances = []
        for key in sorted(stream_instances.keys()):
            instance = stream_instances[key]
-            instance["score"].update(
            sorted_instances.append(instance)
        result[stream_name] = sorted_instances

@@ -299,7 +308,7 @@ class MetricRecipe(SequentialOperatorInitializer):
            field="raw_references",
            to_field="references",
        ),
-        Copy(
            field="source",
            to_field="task_data/source",
        ),
from .operators import (
    ApplyMetric,
    ApplyOperatorsField,
    FlattenInstances,
+    RecursiveCopy,
    Rename,
)
from .register import _reset_env_local_catalogs, register_all_artifacts

from .settings_utils import get_constants, get_settings
from .stream import DynamicStream, MultiStream
from .struct_data_operators import LoadJson
+from .utils import recursive_shallow_copy

constants = get_constants()

_post_process_steps = SequentialOperator(
    steps=[
+        RecursiveCopy(
            field="prediction",
            to_field="raw_prediction",
        ),
+        RecursiveCopy(
            field="references",
            to_field="raw_references",
            dont_apply_to_streams=[constants.inference_stream],
        ),
+        RecursiveCopy(
            field="source",
            to_field="task_data/source",
        ),
        ApplyOperatorsField(
            operators_field="postprocessors",
        ),
+        RecursiveCopy(
            field="prediction",
            to_field="processed_prediction",
        ),
+        RecursiveCopy(
            field="references",
            to_field="processed_references",
            dont_apply_to_streams=[constants.inference_stream],

        result = {}
        all_scores = []
+        all_num_of_instances = []
        for k, v in dic.items():
            score = recursive_mean(v)
            if score is not None:
                all_scores.append(score["score"])
+                if "num_of_instances" in score:
+                    all_num_of_instances.append(score["num_of_instances"])
                result[k] = score

        result["score"] = nan_mean(all_scores)
        result["score_name"] = "subsets_mean"
+        if all_num_of_instances:
+            result["num_of_instances"] = sum(all_num_of_instances)

        if result:
            return result

            "score": score["subsets"]["score"],
            "score_name": score["subsets"]["score_name"],
        }
+        if "num_of_instances" in score["subsets"]:
+            score["global"]["num_of_instances"] = score["subsets"][
+                "num_of_instances"
+            ]

        sorted_instances = []
        for key in sorted(stream_instances.keys()):
            instance = stream_instances[key]
+            instance["score"].update(recursive_shallow_copy(score))
            sorted_instances.append(instance)
        result[stream_name] = sorted_instances

            field="raw_references",
            to_field="references",
        ),
+        RecursiveCopy(
            field="source",
            to_field="task_data/source",
        ),
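
The subset aggregation above averages the per-subset scores and, when they carry one, sums their num_of_instances counts. A small self-contained illustration of that arithmetic (subset names and numbers are invented):

    from math import isnan

    def nan_mean(values):
        # average that skips NaNs, loosely mirroring unitxt's nan_mean helper
        clean = [v for v in values if not isnan(v)]
        return sum(clean) / len(clean) if clean else float("nan")

    subset_scores = {
        "qa": {"score": 0.80, "num_of_instances": 100},
        "summarization": {"score": 0.60, "num_of_instances": 50},
    }

    result = dict(subset_scores)
    result["score"] = nan_mean([v["score"] for v in subset_scores.values()])  # 0.7
    result["score_name"] = "subsets_mean"
    result["num_of_instances"] = sum(v["num_of_instances"] for v in subset_scores.values())  # 150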
metrics.py
CHANGED

@@ -8,10 +8,9 @@ import warnings
from abc import ABC, abstractmethod
from collections import Counter, defaultdict
from dataclasses import field
-from
from typing import Any, Dict, Generator, List, Optional, Tuple, Union

-import evaluate
import numpy
import numpy as np
import pandas as pd

@@ -37,20 +36,18 @@ from .operator import (
    StreamingOperator,
    StreamOperator,
)
-from .operators import Copy
from .random_utils import get_seed
from .settings_utils import get_settings
from .stream import MultiStream, Stream
from .type_utils import Type, isoftype, parse_type_string, to_type_string
-from .utils import

logger = get_logger()
settings = get_settings()

warnings.filterwarnings("ignore", category=DegenerateDataWarning)

-warnings.filterwarnings("ignore", category=DegenerateDataWarning)
-

def abstract_factory():
    return {}
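
The module-level `import evaluate` removed here does not disappear: later hunks in this file re-introduce it locally inside the various prepare() methods, so the dependency is only loaded when a HuggingFace-backed metric is actually prepared. A minimal sketch of that deferred-import pattern, assuming the evaluate package is installed by the time prepare() is finally called:

    class HuggingfaceBackedMetric:
        """Sketch only: keep module import cheap, defer `evaluate` to prepare()."""

        def __init__(self, hf_metric_name: str):
            self.hf_metric_name = hf_metric_name
            self.metric = None

        def prepare(self):
            import evaluate  # imported lazily, as in the prepare() methods below

            self.metric = evaluate.load(self.hf_metric_name)

    m = HuggingfaceBackedMetric("accuracy")  # nothing heavy imported yet
    # m.prepare()  # would import evaluate and load the metric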
@@ -139,6 +136,7 @@ class Metric(Artifact):
|
|
| 139 |
return (
|
| 140 |
self.score_prefix + score_name
|
| 141 |
if score_name not in ["score", "score_name"]
|
|
|
|
| 142 |
else score_name
|
| 143 |
)
|
| 144 |
|
|
@@ -147,18 +145,24 @@ class Metric(Artifact):
|
|
| 147 |
) -> Dict[str, Any]:
|
| 148 |
new_scores = {}
|
| 149 |
for score_name, score in scores.items():
|
|
|
|
|
|
|
|
|
|
| 150 |
score_with_prefix = self._add_score_prefix(score_name)
|
| 151 |
new_scores[score_with_prefix] = (
|
| 152 |
score if score_name not in ["score_name"] else self.score_prefix + score
|
| 153 |
)
|
| 154 |
for new_score_name in new_scores:
|
| 155 |
-
if new_score_name in ["score", "score_name"]
|
|
|
|
|
|
|
| 156 |
continue
|
| 157 |
if new_score_name in existing_scores:
|
| 158 |
UnitxtWarning(
|
| 159 |
message=f"Metric '{new_score_name}' that has just been evaluated to {new_scores[new_score_name]}, is already recorded "
|
| 160 |
f"to have value {existing_scores[new_score_name]} by a previous metric evaluation on this instance or stream. "
|
| 161 |
-
f"To avoid overwriting the existing value, add a score_prefix to the metric (e.g. score_prefix='my_second_'
|
|
|
|
| 162 |
additional_info_id=Documentation.MULTIPLE_METRICS_OUTPUTS,
|
| 163 |
)
|
| 164 |
return new_scores
|
|
@@ -279,7 +283,12 @@ class Metric(Artifact):
|
|
| 279 |
self, instance: Dict[str, Any], global_score: dict
|
| 280 |
):
|
| 281 |
for score_name in global_score:
|
| 282 |
-
if score_name in [
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 283 |
continue
|
| 284 |
if score_name in instance["score"]["global"]:
|
| 285 |
UnitxtWarning(
|
|
@@ -469,11 +478,17 @@ class MetricWithConfidenceInterval(Metric):
|
|
| 469 |
# iterate over the rows and compute the metric on each resampling
|
| 470 |
def metric(sample_refs, sample_preds, sample_task_data):
|
| 471 |
try:
|
| 472 |
-
|
| 473 |
references=sample_refs,
|
| 474 |
predictions=sample_preds,
|
| 475 |
task_data=sample_task_data,
|
| 476 |
-
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 477 |
except Exception as e:
|
| 478 |
# this happens in edge cases, for example, when the sampling creates a
|
| 479 |
# sample where all strings are empty and this fails bleu.
|
|
@@ -538,7 +553,6 @@ class GlobalMetric(StreamOperator, MetricWithConfidenceInterval):
|
|
| 538 |
references = []
|
| 539 |
predictions = []
|
| 540 |
task_data = []
|
| 541 |
-
global_score = {}
|
| 542 |
|
| 543 |
instances = []
|
| 544 |
|
|
@@ -589,6 +603,7 @@ class GlobalMetric(StreamOperator, MetricWithConfidenceInterval):
|
|
| 589 |
)
|
| 590 |
)
|
| 591 |
self._validate_references_and_prediction(references, predictions)
|
|
|
|
| 592 |
|
| 593 |
result = self._compute(references, predictions, task_data)
|
| 594 |
global_score.update(
|
|
@@ -596,11 +611,18 @@ class GlobalMetric(StreamOperator, MetricWithConfidenceInterval):
|
|
| 596 |
result, global_score
|
| 597 |
)
|
| 598 |
)
|
| 599 |
-
|
| 600 |
-
|
| 601 |
-
|
| 602 |
-
|
| 603 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 604 |
|
| 605 |
for instance in instances:
|
| 606 |
self.update_and_adjust_global_score(instance, global_score)
|
|
@@ -649,28 +671,24 @@ class BulkInstanceMetric(StreamOperator, MetricWithConfidenceInterval):
|
|
| 649 |
default_factory=lambda: ["mean", "weighted_win_rate"]
|
| 650 |
)
|
| 651 |
|
|
|
|
|
|
|
|
|
|
| 652 |
def process(self, stream: Stream, stream_name: Optional[str] = None) -> Generator:
|
| 653 |
-
global_score = {}
|
| 654 |
instances = []
|
|
|
|
|
|
|
|
|
|
|
|
|
| 655 |
|
| 656 |
-
|
| 657 |
-
references
|
| 658 |
-
list,
|
| 659 |
-
zip(
|
| 660 |
-
*[
|
| 661 |
-
itemgetter("references", "prediction")(
|
| 662 |
-
self.verify_instance(instance)
|
| 663 |
-
)
|
| 664 |
-
for instance in stream
|
| 665 |
-
]
|
| 666 |
-
),
|
| 667 |
-
)
|
| 668 |
-
|
| 669 |
task_data = [
|
| 670 |
instance["task_data"] if "task_data" in instance else {}
|
| 671 |
-
for instance in
|
| 672 |
]
|
| 673 |
self._validate_references_and_prediction(references, predictions)
|
|
|
|
| 674 |
# compute the metric over all refs and preds
|
| 675 |
instance_scores = self.compute(
|
| 676 |
references=references,
|
|
@@ -683,7 +701,7 @@ class BulkInstanceMetric(StreamOperator, MetricWithConfidenceInterval):
|
|
| 683 |
instance_score["score"] = instance_score[self.main_score]
|
| 684 |
instance_score["score_name"] = self.main_score
|
| 685 |
|
| 686 |
-
for instance, score in zip(
|
| 687 |
if "score" not in instance:
|
| 688 |
instance["score"] = {"global": {}, "instance": {}}
|
| 689 |
|
|
@@ -692,7 +710,6 @@ class BulkInstanceMetric(StreamOperator, MetricWithConfidenceInterval):
|
|
| 692 |
score, instance["score"]["instance"]
|
| 693 |
)
|
| 694 |
)
|
| 695 |
-
instances.append(instance)
|
| 696 |
|
| 697 |
for reduction, fields in self.reduction_map.items():
|
| 698 |
assert (
|
|
@@ -1059,7 +1076,7 @@ class InstanceMetric(StreamOperator, MetricWithConfidenceInterval):
|
|
| 1059 |
|
| 1060 |
def process(self, stream: Stream, stream_name: Optional[str] = None) -> Generator:
|
| 1061 |
instances = self.compute_instance_scores(stream)
|
| 1062 |
-
global_score = {}
|
| 1063 |
for reduction_type, reduction_params in self.reduction_map.items():
|
| 1064 |
assert (
|
| 1065 |
reduction_type in self.implemented_reductions
|
|
@@ -1096,7 +1113,10 @@ class InstanceMetric(StreamOperator, MetricWithConfidenceInterval):
|
|
| 1096 |
scores_to_resample,
|
| 1097 |
aggregation_function,
|
| 1098 |
) = self._set_up_group_mean_aggregation(
|
| 1099 |
-
instances,
|
|
|
|
|
|
|
|
|
|
| 1100 |
)
|
| 1101 |
else:
|
| 1102 |
raise ValueError(
|
|
@@ -1171,13 +1191,16 @@ class InstanceMetric(StreamOperator, MetricWithConfidenceInterval):
|
|
| 1171 |
instance_score["score_name"] = self.main_score
|
| 1172 |
if "score" not in instance:
|
| 1173 |
instance["score"] = {"global": {}, "instance": {}}
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1174 |
|
| 1175 |
instance["score"]["instance"].update(
|
| 1176 |
self._add_score_prefixes_to_score_dict_and_check_against_existing_scores(
|
| 1177 |
instance_score, instance["score"]["instance"]
|
| 1178 |
)
|
| 1179 |
)
|
| 1180 |
-
|
| 1181 |
instances.append(instance)
|
| 1182 |
|
| 1183 |
return instances
|
|
@@ -1187,7 +1210,9 @@ class InstanceMetric(StreamOperator, MetricWithConfidenceInterval):
|
|
| 1187 |
instances: List[dict],
|
| 1188 |
score_names: List[str],
|
| 1189 |
group_aggregation_func,
|
| 1190 |
-
prepend_score_prefix: bool
|
|
|
|
|
|
|
| 1191 |
):
|
| 1192 |
"""Group scores by the group_id and subgroup_type fields of each instance, and compute group_aggregation_func by group.
|
| 1193 |
|
|
@@ -1199,6 +1224,8 @@ class InstanceMetric(StreamOperator, MetricWithConfidenceInterval):
|
|
| 1199 |
callable function returns a single score for the group
|
| 1200 |
prepend_score_prefix: if True - prepend the score_prefix to the score names in the returned dicts. Set to False
|
| 1201 |
if down the stream such a prepending is expected.
|
|
|
|
|
|
|
| 1202 |
|
| 1203 |
Returns:
|
| 1204 |
List of dicts, each corresponding to a group of instances (defined by 'group_id'),
|
|
@@ -1233,8 +1260,27 @@ class InstanceMetric(StreamOperator, MetricWithConfidenceInterval):
|
|
| 1233 |
]
|
| 1234 |
)
|
| 1235 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1236 |
# if group_aggregation_func expects a subgroup-types score dict, pass it; otherwise pass the default type list of scores
|
| 1237 |
-
|
| 1238 |
{
|
| 1239 |
"score": {
|
| 1240 |
"instance": {
|
|
@@ -1255,9 +1301,25 @@ class InstanceMetric(StreamOperator, MetricWithConfidenceInterval):
|
|
| 1255 |
) # sorted for consistency
|
| 1256 |
]
|
| 1257 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1258 |
def _set_up_group_mean_aggregation(
|
| 1259 |
-
self,
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1260 |
):
|
|
|
|
| 1261 |
group_aggregation_func = reduction_params["agg_func"][1]
|
| 1262 |
# if treat groups as units
|
| 1263 |
do_resample_as_group = reduction_params["agg_func"][2]
|
|
@@ -1265,7 +1327,12 @@ class InstanceMetric(StreamOperator, MetricWithConfidenceInterval):
|
|
| 1265 |
# pass the group aggregate---not instance---scores to resample as usual
|
| 1266 |
aggregation_function = self.average_item_scores
|
| 1267 |
scores_to_resample = self.get_group_scores(
|
| 1268 |
-
instances,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1269 |
)
|
| 1270 |
else:
|
| 1271 |
# pass the instance scores to resample, and calculate the group aggregation on the resamplings
|
|
@@ -1277,7 +1344,12 @@ class InstanceMetric(StreamOperator, MetricWithConfidenceInterval):
|
|
| 1277 |
group_aggregation_func=group_aggregation_func,
|
| 1278 |
):
|
| 1279 |
group_scores = self.get_group_scores(
|
| 1280 |
-
instances,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1281 |
)
|
| 1282 |
return nan_mean(
|
| 1283 |
[group["score"]["instance"][field_name] for group in group_scores]
|
|
@@ -1315,6 +1387,19 @@ class ANLS(InstanceMetric):
|
|
| 1315 |
reduction_map = {"mean": ["anls"]}
|
| 1316 |
prediction_type = Any # string representation is compared
|
| 1317 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1318 |
def compute(
|
| 1319 |
self,
|
| 1320 |
references: List[Any],
|
|
@@ -1324,20 +1409,14 @@ class ANLS(InstanceMetric):
|
|
| 1324 |
) -> dict:
|
| 1325 |
"""ANLS image-text accuracy metric."""
|
| 1326 |
values = []
|
| 1327 |
-
for
|
| 1328 |
-
|
| 1329 |
-
gt_answer = " ".join(answer.strip().lower().split())
|
| 1330 |
-
det_answer = " ".join(prediction.strip().lower().split())
|
| 1331 |
-
|
| 1332 |
-
# dist = levenshtein_distance(answer.lower(), detObject['answer'].lower())
|
| 1333 |
-
dist = self.levenshtein_distance(gt_answer, det_answer)
|
| 1334 |
-
length = max(len(answer.upper()), len(prediction.upper()))
|
| 1335 |
-
values.append(0.0 if length == 0 else float(dist) / float(length))
|
| 1336 |
|
| 1337 |
question_result = 1.0 - min(values)
|
| 1338 |
|
| 1339 |
if question_result < threshold:
|
| 1340 |
question_result = 0.0
|
|
|
|
| 1341 |
result = {}
|
| 1342 |
result["score"] = question_result
|
| 1343 |
result[self.main_score] = question_result
|
|
@@ -1345,6 +1424,7 @@ class ANLS(InstanceMetric):
|
|
| 1345 |
return result
|
| 1346 |
|
| 1347 |
@staticmethod
|
|
|
|
| 1348 |
def levenshtein_distance(s1, s2):
|
| 1349 |
if len(s1) > len(s2):
|
| 1350 |
s1, s2 = s2, s1
|
|
@@ -1526,16 +1606,40 @@ class MetricPipeline(MultiStreamOperator, Metric):
|
|
| 1526 |
), "Must define at most one of postpreprocess_steps (which is deprecated) and postprocess_steps (to be used from now on)"
|
| 1527 |
if has_postpreprocess:
|
| 1528 |
self.postprocess_steps = self.postpreprocess_steps
|
| 1529 |
-
self.prepare_score =
|
| 1530 |
-
|
| 1531 |
-
|
| 1532 |
-
f"score/instance/{self.metric._add_score_prefix(self.main_score)}",
|
| 1533 |
-
"score/instance/score",
|
| 1534 |
-
|
| 1535 |
-
|
| 1536 |
-
f"score/global/{self.metric._add_score_prefix(self.main_score)}",
|
| 1537 |
-
"score/global/score",
|
| 1538 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1539 |
],
|
| 1540 |
)
|
| 1541 |
|
|
@@ -1589,6 +1693,8 @@ class HuggingfaceMetric(GlobalMetric):
|
|
| 1589 |
|
| 1590 |
def prepare(self):
|
| 1591 |
super().prepare()
|
|
|
|
|
|
|
| 1592 |
self.metric = evaluate.load(
|
| 1593 |
self.hf_metric_name, experiment_id=self.experiment_id
|
| 1594 |
)
|
|
@@ -1663,6 +1769,8 @@ class HuggingfaceBulkMetric(BulkInstanceMetric):
|
|
| 1663 |
|
| 1664 |
def prepare(self):
|
| 1665 |
super().prepare()
|
|
|
|
|
|
|
| 1666 |
self.metric = evaluate.load(
|
| 1667 |
self.hf_metric_name, experiment_id=str(uuid.uuid4())
|
| 1668 |
)
|
|
@@ -1709,6 +1817,8 @@ class HuggingfaceInstanceMetric(InstanceMetric):
|
|
| 1709 |
|
| 1710 |
def prepare(self):
|
| 1711 |
super().prepare()
|
|
|
|
|
|
|
| 1712 |
self.metric = evaluate.load(
|
| 1713 |
self.hf_metric_name, experiment_id=str(uuid.uuid4())
|
| 1714 |
)
|
|
@@ -1788,6 +1898,8 @@ class F1(GlobalMetric):
|
|
| 1788 |
|
| 1789 |
def prepare(self):
|
| 1790 |
super().prepare()
|
|
|
|
|
|
|
| 1791 |
self._metric = evaluate.load(self.metric, experiment_id=str(uuid.uuid4()))
|
| 1792 |
|
| 1793 |
def get_str_id(self, str):
|
|
@@ -1847,6 +1959,7 @@ class F1Binary(GlobalMetric):
|
|
| 1847 |
_metric = None
|
| 1848 |
metric = "f1"
|
| 1849 |
single_reference_per_prediction = True
|
|
|
|
| 1850 |
_requirements_list: List[str] = ["sklearn"]
|
| 1851 |
|
| 1852 |
def prepare(self):
|
|
@@ -2064,6 +2177,8 @@ class F1MultiLabel(GlobalMetric):
|
|
| 2064 |
|
| 2065 |
def prepare(self):
|
| 2066 |
super().prepare()
|
|
|
|
|
|
|
| 2067 |
self._metric = evaluate.load(
|
| 2068 |
self.metric, "multilabel", experiment_id=str(uuid.uuid4())
|
| 2069 |
)
|
|
@@ -3033,7 +3148,7 @@ class SafetyMetric(GlobalMetric):
|
|
| 3033 |
class LlamaIndexLLMMetric(InstanceMetric):
|
| 3034 |
model_name: str = ""
|
| 3035 |
main_score: str = ""
|
| 3036 |
-
prediction_type
|
| 3037 |
reduction_map: Dict[str, List[str]] = None
|
| 3038 |
openai_models: List[str] = ["gpt-3.5-turbo"]
|
| 3039 |
anthropic_models: List[
|
|
@@ -3679,6 +3794,7 @@ class RetrievalAtK(RetrievalMetric):
|
|
| 3679 |
(recall_at_k, "recall"),
|
| 3680 |
(match_at_k, "match"),
|
| 3681 |
]:
|
|
|
|
| 3682 |
max_k = max(measure_array.keys())
|
| 3683 |
for k in self.k_list:
|
| 3684 |
result[self.score_name(measure_name, k)] = measure_array[min(k, max_k)]
|
|
@@ -3725,7 +3841,7 @@ class RemoteMetric(StreamOperator, Metric):
|
|
| 3725 |
remotely (pre and post processing steps in the MetricPipeline will be computed locally).
|
| 3726 |
"""
|
| 3727 |
local_inner_metric = metric_pipeline.metric
|
| 3728 |
-
metric_pipeline =
|
| 3729 |
metric_pipeline
|
| 3730 |
) # To avoid unintentional changes to the catalog contents
|
| 3731 |
metric_pipeline.metric = RemoteMetric(
|
|
@@ -4376,6 +4492,7 @@ class BinaryMaxF1(F1Binary):
|
|
| 4376 |
main_score = "max_f1_binary"
|
| 4377 |
single_reference_per_prediction = True
|
| 4378 |
average = None
|
|
|
|
| 4379 |
|
| 4380 |
def compute(
|
| 4381 |
self,
|
|
@@ -4799,17 +4916,22 @@ class F1Strings(InstanceMetric):
|
|
| 4799 |
"spacy": "Please pip install spacy",
|
| 4800 |
}
|
| 4801 |
|
| 4802 |
-
def
|
| 4803 |
-
super().prepare()
|
| 4804 |
import spacy
|
| 4805 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4806 |
try:
|
| 4807 |
-
self.
|
| 4808 |
except OSError:
|
| 4809 |
from spacy.cli import download
|
| 4810 |
|
| 4811 |
download("en_core_web_sm")
|
| 4812 |
-
self.
|
| 4813 |
|
| 4814 |
def compute(
|
| 4815 |
self,
|
|
@@ -4955,3 +5077,20 @@ class RandomForestMetricsEnsemble(MetricsEnsemble):
|
|
| 4955 |
)
|
| 4956 |
score = ensemble_model.predict([prediction_lst])
|
| 4957 |
return score.tolist()[0]
|
|
|
|
|
|
|
|
| 8 |
from abc import ABC, abstractmethod
|
| 9 |
from collections import Counter, defaultdict
|
| 10 |
from dataclasses import field
|
| 11 |
+
from functools import lru_cache
|
| 12 |
from typing import Any, Dict, Generator, List, Optional, Tuple, Union
|
| 13 |
|
|
|
|
| 14 |
import numpy
|
| 15 |
import numpy as np
|
| 16 |
import pandas as pd
|
|
|
|
| 36 |
StreamingOperator,
|
| 37 |
StreamOperator,
|
| 38 |
)
|
| 39 |
+
from .operators import Copy, Set
|
| 40 |
from .random_utils import get_seed
|
| 41 |
from .settings_utils import get_settings
|
| 42 |
from .stream import MultiStream, Stream
|
| 43 |
from .type_utils import Type, isoftype, parse_type_string, to_type_string
|
| 44 |
+
from .utils import deep_copy
|
| 45 |
|
| 46 |
logger = get_logger()
|
| 47 |
settings = get_settings()
|
| 48 |
|
| 49 |
warnings.filterwarnings("ignore", category=DegenerateDataWarning)
|
| 50 |
|
|
|
|
|
|
|
| 51 |
|
| 52 |
def abstract_factory():
|
| 53 |
return {}
|
|
|
|
| 136 |
return (
|
| 137 |
self.score_prefix + score_name
|
| 138 |
if score_name not in ["score", "score_name"]
|
| 139 |
+
and not score_name.startswith("num_of_instances")
|
| 140 |
else score_name
|
| 141 |
)
|
| 142 |
|
|
|
|
| 145 |
) -> Dict[str, Any]:
|
| 146 |
new_scores = {}
|
| 147 |
for score_name, score in scores.items():
|
| 148 |
+
if isinstance(score, dict):
|
| 149 |
+
new_scores[score_name] = score
|
| 150 |
+
continue # do not prefix group names
|
| 151 |
score_with_prefix = self._add_score_prefix(score_name)
|
| 152 |
new_scores[score_with_prefix] = (
|
| 153 |
score if score_name not in ["score_name"] else self.score_prefix + score
|
| 154 |
)
|
| 155 |
for new_score_name in new_scores:
|
| 156 |
+
if new_score_name in ["score", "score_name"] or new_score_name.startswith(
|
| 157 |
+
"num_of_instances"
|
| 158 |
+
):
|
| 159 |
continue
|
| 160 |
if new_score_name in existing_scores:
|
| 161 |
UnitxtWarning(
|
| 162 |
message=f"Metric '{new_score_name}' that has just been evaluated to {new_scores[new_score_name]}, is already recorded "
|
| 163 |
f"to have value {existing_scores[new_score_name]} by a previous metric evaluation on this instance or stream. "
|
| 164 |
+
f"To avoid overwriting the existing value, add a score_prefix to the metric name (e.g. score_prefix='my_second_' , "
|
| 165 |
+
f"which will yield, in this case, a score named: 'my_second_{new_score_name}')",
|
| 166 |
additional_info_id=Documentation.MULTIPLE_METRICS_OUTPUTS,
|
| 167 |
)
|
| 168 |
return new_scores
|
|
|
|
| 283 |
self, instance: Dict[str, Any], global_score: dict
|
| 284 |
):
|
| 285 |
for score_name in global_score:
|
| 286 |
+
if score_name in [
|
| 287 |
+
"score",
|
| 288 |
+
"score_name",
|
| 289 |
+
"score_ci_low",
|
| 290 |
+
"score_ci_high",
|
| 291 |
+
] or score_name.startswith("num_of_instances"):
|
| 292 |
continue
|
| 293 |
if score_name in instance["score"]["global"]:
|
| 294 |
UnitxtWarning(
|
|
|
|
| 478 |
# iterate over the rows and compute the metric on each resampling
|
| 479 |
def metric(sample_refs, sample_preds, sample_task_data):
|
| 480 |
try:
|
| 481 |
+
results = self._compute(
|
| 482 |
references=sample_refs,
|
| 483 |
predictions=sample_preds,
|
| 484 |
task_data=sample_task_data,
|
| 485 |
+
)
|
| 486 |
+
results.update(
|
| 487 |
+
self._add_score_prefixes_to_score_dict_and_check_against_existing_scores(
|
| 488 |
+
results, {}
|
| 489 |
+
)
|
| 490 |
+
)
|
| 491 |
+
return results[score_name]
|
| 492 |
except Exception as e:
|
| 493 |
# this happens in edge cases, for example, when the sampling creates a
|
| 494 |
# sample where all strings are empty and this fails bleu.
|
|
|
|
| 553 |
references = []
|
| 554 |
predictions = []
|
| 555 |
task_data = []
|
|
|
|
| 556 |
|
| 557 |
instances = []
|
| 558 |
|
|
|
|
| 603 |
)
|
| 604 |
)
|
| 605 |
self._validate_references_and_prediction(references, predictions)
|
| 606 |
+
global_score = {"num_of_instances": len(instances)}
|
| 607 |
|
| 608 |
result = self._compute(references, predictions, task_data)
|
| 609 |
global_score.update(
|
|
|
|
| 611 |
result, global_score
|
| 612 |
)
|
| 613 |
)
|
| 614 |
+
if self.ci_scores:
|
| 615 |
+
score_names = [
|
| 616 |
+
self._add_score_prefix(score_name) for score_name in self.ci_scores
|
| 617 |
+
]
|
| 618 |
+
else:
|
| 619 |
+
score_names = [global_score["score_name"]]
|
| 620 |
+
|
| 621 |
+
for score_name in score_names:
|
| 622 |
+
confidence_interval = self.compute_global_confidence_intervals(
|
| 623 |
+
references, predictions, task_data, score_name
|
| 624 |
+
)
|
| 625 |
+
global_score.update(confidence_interval)
|
| 626 |
|
| 627 |
for instance in instances:
|
| 628 |
self.update_and_adjust_global_score(instance, global_score)
|
|
|
|
| 671 |
default_factory=lambda: ["mean", "weighted_win_rate"]
|
| 672 |
)
|
| 673 |
|
| 674 |
+
def preprocess_instance(self, instance):
|
| 675 |
+
return instance
|
| 676 |
+
|
| 677 |
def process(self, stream: Stream, stream_name: Optional[str] = None) -> Generator:
|
|
|
|
| 678 |
instances = []
|
| 679 |
+
for instance in stream:
|
| 680 |
+
self.verify_instance(instance)
|
| 681 |
+
instance = self.preprocess_instance(instance)
|
| 682 |
+
instances.append(instance)
|
| 683 |
|
| 684 |
+
predictions = [instance["prediction"] for instance in instances]
|
| 685 |
+
references = [instance["references"] for instance in instances]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 686 |
task_data = [
|
| 687 |
instance["task_data"] if "task_data" in instance else {}
|
| 688 |
+
for instance in instances
|
| 689 |
]
|
| 690 |
self._validate_references_and_prediction(references, predictions)
|
| 691 |
+
global_score = {"num_of_instances": len(instances)}
|
| 692 |
# compute the metric over all refs and preds
|
| 693 |
instance_scores = self.compute(
|
| 694 |
references=references,
|
|
|
|
| 701 |
instance_score["score"] = instance_score[self.main_score]
|
| 702 |
instance_score["score_name"] = self.main_score
|
| 703 |
|
| 704 |
+
for instance, score in zip(instances, instance_scores):
|
| 705 |
if "score" not in instance:
|
| 706 |
instance["score"] = {"global": {}, "instance": {}}
|
| 707 |
|
|
|
|
| 710 |
score, instance["score"]["instance"]
|
| 711 |
)
|
| 712 |
)
|
|
|
|
| 713 |
|
| 714 |
for reduction, fields in self.reduction_map.items():
|
| 715 |
assert (
|
|
|
|
| 1076 |
|
| 1077 |
def process(self, stream: Stream, stream_name: Optional[str] = None) -> Generator:
|
| 1078 |
instances = self.compute_instance_scores(stream)
|
| 1079 |
+
global_score = {"num_of_instances": len(instances)}
|
| 1080 |
for reduction_type, reduction_params in self.reduction_map.items():
|
| 1081 |
assert (
|
| 1082 |
reduction_type in self.implemented_reductions
|
|
|
|
| 1113 |
scores_to_resample,
|
| 1114 |
aggregation_function,
|
| 1115 |
) = self._set_up_group_mean_aggregation(
|
| 1116 |
+
instances,
|
| 1117 |
+
reduction_params,
|
| 1118 |
+
reduction_fields,
|
| 1119 |
+
global_score,
|
| 1120 |
)
|
| 1121 |
else:
|
| 1122 |
raise ValueError(
|
|
|
|
| 1191 |
instance_score["score_name"] = self.main_score
|
| 1192 |
if "score" not in instance:
|
| 1193 |
instance["score"] = {"global": {}, "instance": {}}
|
| 1194 |
+
if "global" not in instance["score"]:
|
| 1195 |
+
instance["score"]["global"] = {}
|
| 1196 |
+
if "instance" not in instance["score"]:
|
| 1197 |
+
instance["score"]["instance"] = {}
|
| 1198 |
|
| 1199 |
instance["score"]["instance"].update(
|
| 1200 |
self._add_score_prefixes_to_score_dict_and_check_against_existing_scores(
|
| 1201 |
instance_score, instance["score"]["instance"]
|
| 1202 |
)
|
| 1203 |
)
|
|
|
|
| 1204 |
instances.append(instance)
|
| 1205 |
|
| 1206 |
return instances
|
|
|
|
| 1210 |
instances: List[dict],
|
| 1211 |
score_names: List[str],
|
| 1212 |
group_aggregation_func,
|
| 1213 |
+
prepend_score_prefix: bool,
|
| 1214 |
+
global_score: dict,
|
| 1215 |
+
aggregation_function_name: str,
|
| 1216 |
):
|
| 1217 |
"""Group scores by the group_id and subgroup_type fields of each instance, and compute group_aggregation_func by group.
|
| 1218 |
|
|
|
|
| 1224 |
callable function returns a single score for the group
|
| 1225 |
prepend_score_prefix: if True - prepend the score_prefix to the score names in the returned dicts. Set to False
|
| 1226 |
if down the stream such a prepending is expected.
|
| 1227 |
+
global_score: the being built up global score. It will be filled here with number of instances per each group, and group scores.
|
| 1228 |
+
aggregation_function_name: used to annotate the groups' global scores.
|
| 1229 |
|
| 1230 |
Returns:
|
| 1231 |
List of dicts, each corresponding to a group of instances (defined by 'group_id'),
|
|
|
|
| 1260 |
]
|
| 1261 |
)
|
| 1262 |
|
| 1263 |
+
# count the instances in each group and subgroup.
|
| 1264 |
+
# Each instance goes into group_to_instances per each score_name.
|
| 1265 |
+
# So we count over the first score_name only
|
| 1266 |
+
for group_key in group_to_instance_scores:
|
| 1267 |
+
if group_key not in global_score:
|
| 1268 |
+
global_score[group_key] = {}
|
| 1269 |
+
global_score[group_key]["num_of_instances"] = sum(
|
| 1270 |
+
[
|
| 1271 |
+
len(
|
| 1272 |
+
group_to_instance_scores[group_key][score_names[0]][
|
| 1273 |
+
subgroup_type
|
| 1274 |
+
]
|
| 1275 |
+
)
|
| 1276 |
+
for subgroup_type in group_to_instance_scores[group_key][
|
| 1277 |
+
score_names[0]
|
| 1278 |
+
]
|
| 1279 |
+
]
|
| 1280 |
+
)
|
| 1281 |
+
|
| 1282 |
# if group_aggregation_func expects a subgroup-types score dict, pass it; otherwise pass the default type list of scores
|
| 1283 |
+
to_return = [
|
| 1284 |
{
|
| 1285 |
"score": {
|
| 1286 |
"instance": {
|
|
|
|
| 1301 |
) # sorted for consistency
|
| 1302 |
]
|
| 1303 |
|
| 1304 |
+
# update each group section in global_score
|
| 1305 |
+
for i, group_name in enumerate(sorted(group_to_instance_scores.keys())):
|
| 1306 |
+
global_score[group_name].update(
|
| 1307 |
+
{
|
| 1308 |
+
aggregation_function_name + "_" + k: v
|
| 1309 |
+
for k, v in to_return[i]["score"]["instance"].items()
|
| 1310 |
+
}
|
| 1311 |
+
)
|
| 1312 |
+
|
| 1313 |
+
return to_return
|
| 1314 |
+
|
| 1315 |
def _set_up_group_mean_aggregation(
|
| 1316 |
+
self,
|
| 1317 |
+
instances,
|
| 1318 |
+
reduction_params,
|
| 1319 |
+
reduction_fields,
|
| 1320 |
+
global_score,
|
| 1321 |
):
|
| 1322 |
+
aggregation_function_name = str(reduction_params["agg_func"][0])
|
| 1323 |
group_aggregation_func = reduction_params["agg_func"][1]
|
| 1324 |
# if treat groups as units
|
| 1325 |
do_resample_as_group = reduction_params["agg_func"][2]
|
|
|
|
| 1327 |
# pass the group aggregate---not instance---scores to resample as usual
|
| 1328 |
aggregation_function = self.average_item_scores
|
| 1329 |
scores_to_resample = self.get_group_scores(
|
| 1330 |
+
instances=instances,
|
| 1331 |
+
score_names=reduction_fields,
|
| 1332 |
+
group_aggregation_func=group_aggregation_func,
|
| 1333 |
+
prepend_score_prefix=True,
|
| 1334 |
+
global_score=global_score,
|
| 1335 |
+
aggregation_function_name=aggregation_function_name,
|
| 1336 |
)
|
| 1337 |
else:
|
| 1338 |
# pass the instance scores to resample, and calculate the group aggregation on the resamplings
|
|
|
|
| 1344 |
group_aggregation_func=group_aggregation_func,
|
| 1345 |
):
|
| 1346 |
group_scores = self.get_group_scores(
|
| 1347 |
+
instances=instances,
|
| 1348 |
+
score_names=[field_name],
|
| 1349 |
+
group_aggregation_func=group_aggregation_func,
|
| 1350 |
+
prepend_score_prefix=False,
|
| 1351 |
+
global_score=global_score,
|
| 1352 |
+
aggregation_function_name=aggregation_function_name,
|
| 1353 |
)
|
| 1354 |
return nan_mean(
|
| 1355 |
[group["score"]["instance"][field_name] for group in group_scores]
|
|
|
|
| 1387 |
reduction_map = {"mean": ["anls"]}
|
| 1388 |
prediction_type = Any # string representation is compared
|
| 1389 |
|
| 1390 |
+
@staticmethod
|
| 1391 |
+
@lru_cache(maxsize=10000)
|
| 1392 |
+
def preprocess_text(text):
|
| 1393 |
+
return " ".join(text.strip().lower().split()), len(text.upper())
|
| 1394 |
+
|
| 1395 |
+
def distance(self, prediction, reference):
|
| 1396 |
+
processed_reference, len_reference = self.preprocess_text(reference)
|
| 1397 |
+
processed_prediction, len_prediction = self.preprocess_text(prediction)
|
| 1398 |
+
|
| 1399 |
+
dist = self.levenshtein_distance(processed_reference, processed_prediction)
|
| 1400 |
+
length = max(len_reference, len_prediction)
|
| 1401 |
+
return 0.0 if length == 0 else float(dist) / float(length)
|
| 1402 |
+
|
| 1403 |
def compute(
|
| 1404 |
self,
|
| 1405 |
references: List[Any],
|
|
|
|
| 1409 |
) -> dict:
|
| 1410 |
"""ANLS image-text accuracy metric."""
|
| 1411 |
values = []
|
| 1412 |
+
for reference in references:
|
| 1413 |
+
values.append(self.distance(prediction, reference))
|
|
| 1414 |
|
| 1415 |
question_result = 1.0 - min(values)
|
| 1416 |
|
| 1417 |
if question_result < threshold:
|
| 1418 |
question_result = 0.0
|
| 1419 |
+
|
| 1420 |
result = {}
|
| 1421 |
result["score"] = question_result
|
| 1422 |
result[self.main_score] = question_result
|
|
|
|
| 1424 |
return result
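A standalone sketch of the ANLS computation above. The 0.5 threshold is an assumption (in the metric it comes from the configured threshold), and for simplicity the lengths here are taken from the normalized strings, whereas the code above uses the raw text lengths.

# Minimal ANLS illustration: normalized edit distance, best reference wins,
# scores below the threshold are zeroed.
def levenshtein(s1, s2):
    if len(s1) > len(s2):
        s1, s2 = s2, s1
    previous = list(range(len(s1) + 1))
    for j, c2 in enumerate(s2):
        current = [j + 1]
        for i, c1 in enumerate(s1):
            current.append(min(previous[i + 1] + 1, current[i] + 1, previous[i] + (c1 != c2)))
        previous = current
    return previous[-1]

def anls(prediction, references, threshold=0.5):
    def norm(text):
        return " ".join(text.strip().lower().split())
    distances = []
    for ref in references:
        p, r = norm(prediction), norm(ref)
        length = max(len(p), len(r))
        distances.append(0.0 if length == 0 else levenshtein(p, r) / length)
    score = 1.0 - min(distances)
    return score if score >= threshold else 0.0

print(anls("San  Francisco", ["san francisco", "SF"]))  # 1.0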
|
| 1425 |
|
| 1426 |
@staticmethod
|
| 1427 |
+
@lru_cache(maxsize=10000)
|
| 1428 |
def levenshtein_distance(s1, s2):
|
| 1429 |
if len(s1) > len(s2):
|
| 1430 |
s1, s2 = s2, s1
|
|
|
|
| 1606 |
), "Must define at most one of postpreprocess_steps (which is deprecated) and postprocess_steps (to be used from now on)"
|
| 1607 |
if has_postpreprocess:
|
| 1608 |
self.postprocess_steps = self.postpreprocess_steps
|
| 1609 |
+
self.prepare_score = SequentialOperator(
|
| 1610 |
+
steps=[
|
| 1611 |
+
Copy(
|
| 1612 |
+
field=f"score/instance/{self.metric._add_score_prefix(self.main_score)}",
|
| 1613 |
+
to_field="score/instance/score",
|
| 1614 |
+
),
|
| 1615 |
+
Copy(
|
| 1616 |
+
field=f"score/global/{self.metric._add_score_prefix(self.main_score)}",
|
| 1617 |
+
to_field="score/global/score",
|
| 1618 |
+
),
|
| 1619 |
+
Copy(
|
| 1620 |
+
field=f"score/global/{self.metric._add_score_prefix(self.main_score)}_ci_low",
|
| 1621 |
+
to_field="score/global/score_ci_low",
|
| 1622 |
+
not_exist_do_nothing=True,
|
| 1623 |
+
),
|
| 1624 |
+
Copy(
|
| 1625 |
+
field=f"score/global/{self.metric._add_score_prefix(self.main_score)}_ci_high",
|
| 1626 |
+
to_field="score/global/score_ci_high",
|
| 1627 |
+
not_exist_do_nothing=True,
|
| 1628 |
+
),
|
| 1629 |
+
Set(
|
| 1630 |
+
fields={
|
| 1631 |
+
"score/instance/score_name": self.metric._add_score_prefix(
|
| 1632 |
+
self.main_score
|
| 1633 |
+
)
|
| 1634 |
+
}
|
| 1635 |
+
),
|
| 1636 |
+
Set(
|
| 1637 |
+
fields={
|
| 1638 |
+
"score/global/score_name": self.metric._add_score_prefix(
|
| 1639 |
+
self.main_score
|
| 1640 |
+
)
|
| 1641 |
+
}
|
| 1642 |
+
),
|
| 1643 |
],
|
| 1644 |
)
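Illustration only (the score name "my_prefix_f1" is an assumption): the Copy/Set steps above expose the metric's prefixed main score under generic keys, so downstream consumers can always read score/instance/score regardless of the metric's own score name.

# Before prepare_score: the metric wrote only its own (possibly prefixed) keys.
instance = {"score": {"instance": {"my_prefix_f1": 0.82}, "global": {"my_prefix_f1": 0.79}}}

main = "my_prefix_f1"
instance["score"]["instance"]["score"] = instance["score"]["instance"][main]
instance["score"]["global"]["score"] = instance["score"]["global"][main]
instance["score"]["instance"]["score_name"] = main
instance["score"]["global"]["score_name"] = main
# *_ci_low / *_ci_high are copied the same way, but only if they exist
# (that is what not_exist_do_nothing=True is for).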
|
| 1645 |
|
|
|
|
| 1693 |
|
| 1694 |
def prepare(self):
|
| 1695 |
super().prepare()
|
| 1696 |
+
import evaluate
|
| 1697 |
+
|
| 1698 |
self.metric = evaluate.load(
|
| 1699 |
self.hf_metric_name, experiment_id=self.experiment_id
|
| 1700 |
)
|
|
|
|
| 1769 |
|
| 1770 |
def prepare(self):
|
| 1771 |
super().prepare()
|
| 1772 |
+
import evaluate
|
| 1773 |
+
|
| 1774 |
self.metric = evaluate.load(
|
| 1775 |
self.hf_metric_name, experiment_id=str(uuid.uuid4())
|
| 1776 |
)
|
|
|
|
| 1817 |
|
| 1818 |
def prepare(self):
|
| 1819 |
super().prepare()
|
| 1820 |
+
import evaluate
|
| 1821 |
+
|
| 1822 |
self.metric = evaluate.load(
|
| 1823 |
self.hf_metric_name, experiment_id=str(uuid.uuid4())
|
| 1824 |
)
|
|
|
|
| 1898 |
|
| 1899 |
def prepare(self):
|
| 1900 |
super().prepare()
|
| 1901 |
+
import evaluate
|
| 1902 |
+
|
| 1903 |
self._metric = evaluate.load(self.metric, experiment_id=str(uuid.uuid4()))
|
| 1904 |
|
| 1905 |
def get_str_id(self, str):
|
|
|
|
| 1959 |
_metric = None
|
| 1960 |
metric = "f1"
|
| 1961 |
single_reference_per_prediction = True
|
| 1962 |
+
ci_scores = [main_score, "f1_binary_neg"]
|
| 1963 |
_requirements_list: List[str] = ["sklearn"]
|
| 1964 |
|
| 1965 |
def prepare(self):
|
|
|
|
| 2177 |
|
| 2178 |
def prepare(self):
|
| 2179 |
super().prepare()
|
| 2180 |
+
import evaluate
|
| 2181 |
+
|
| 2182 |
self._metric = evaluate.load(
|
| 2183 |
self.metric, "multilabel", experiment_id=str(uuid.uuid4())
|
| 2184 |
)
|
|
|
|
| 3148 |
class LlamaIndexLLMMetric(InstanceMetric):
|
| 3149 |
model_name: str = ""
|
| 3150 |
main_score: str = ""
|
| 3151 |
+
prediction_type = str
|
| 3152 |
reduction_map: Dict[str, List[str]] = None
|
| 3153 |
openai_models: List[str] = ["gpt-3.5-turbo"]
|
| 3154 |
anthropic_models: List[
|
|
|
|
| 3794 |
(recall_at_k, "recall"),
|
| 3795 |
(match_at_k, "match"),
|
| 3796 |
]:
|
| 3797 |
+
measure_array[0] = 0.0 # to support cases where the prediction is empty.
|
| 3798 |
max_k = max(measure_array.keys())
|
| 3799 |
for k in self.k_list:
|
| 3800 |
result[self.score_name(measure_name, k)] = measure_array[min(k, max_k)]
|
|
|
|
| 3841 |
remotely (pre and post processing steps in the MetricPipeline will be computed locally).
|
| 3842 |
"""
|
| 3843 |
local_inner_metric = metric_pipeline.metric
|
| 3844 |
+
metric_pipeline = deep_copy(
|
| 3845 |
metric_pipeline
|
| 3846 |
) # To avoid unintentional changes to the catalog contents
|
| 3847 |
metric_pipeline.metric = RemoteMetric(
|
|
|
|
| 4492 |
main_score = "max_f1_binary"
|
| 4493 |
single_reference_per_prediction = True
|
| 4494 |
average = None
|
| 4495 |
+
ci_scores = [main_score, "max_f1_binary_neg"]
|
| 4496 |
|
| 4497 |
def compute(
|
| 4498 |
self,
|
|
|
|
| 4916 |
"spacy": "Please pip install spacy",
|
| 4917 |
}
|
| 4918 |
|
| 4919 |
+
def load_spacy(self):
|
|
|
|
| 4920 |
import spacy
|
| 4921 |
|
| 4922 |
+
self.nlp = spacy.load(
|
| 4923 |
+
"en_core_web_sm", disable=["tagger", "parser", "ner", "lemmatizer"]
|
| 4924 |
+
)
|
| 4925 |
+
|
| 4926 |
+
def prepare(self):
|
| 4927 |
+
super().prepare()
|
| 4928 |
try:
|
| 4929 |
+
self.load_spacy()
|
| 4930 |
except OSError:
|
| 4931 |
from spacy.cli import download
|
| 4932 |
|
| 4933 |
download("en_core_web_sm")
|
| 4934 |
+
self.load_spacy()
|
| 4935 |
|
| 4936 |
def compute(
|
| 4937 |
self,
|
|
|
|
| 5077 |
)
|
| 5078 |
score = ensemble_model.predict([prediction_lst])
|
| 5079 |
return score.tolist()[0]
|
| 5080 |
+
|
| 5081 |
+
|
| 5082 |
+
class PredictionLength(InstanceMetric):
|
| 5083 |
+
"""Returns the length of the prediction."""
|
| 5084 |
+
|
| 5085 |
+
main_score = "prediction_length"
|
| 5086 |
+
reduction_map = {"mean": ["prediction_length"]}
|
| 5087 |
+
prediction_type = str
|
| 5088 |
+
single_reference_per_prediction = True
|
| 5089 |
+
|
| 5090 |
+
def compute(
|
| 5091 |
+
self,
|
| 5092 |
+
references: List[str],
|
| 5093 |
+
prediction: str,
|
| 5094 |
+
task_data: List[Dict],
|
| 5095 |
+
) -> dict:
|
| 5096 |
+
return {self.main_score: [len(prediction)], "score_name": self.main_score}
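So a prediction of "hello world", for example, is scored as its character length, and the mean reduction then averages these lengths over the dataset:

prediction = "hello world"
print({"prediction_length": [len(prediction)], "score_name": "prediction_length"})
# {'prediction_length': [11], 'score_name': 'prediction_length'}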
|
operators.py
CHANGED
|
@@ -39,7 +39,6 @@ General Operators List:
|
|
| 39 |
------------------------
|
| 40 |
"""
|
| 41 |
|
| 42 |
-
import copy
|
| 43 |
import operator
|
| 44 |
import uuid
|
| 45 |
import warnings
|
|
@@ -82,14 +81,19 @@ from .operator import (
|
|
| 82 |
StreamOperator,
|
| 83 |
)
|
| 84 |
from .random_utils import new_random_generator
|
| 85 |
-
from .settings_utils import
|
| 86 |
-
from .stream import DynamicStream, Stream
|
| 87 |
from .text_utils import nested_tuple_to_string
|
| 88 |
from .type_utils import isoftype
|
| 89 |
-
from .utils import
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 90 |
|
| 91 |
settings = get_settings()
|
| 92 |
-
constants = get_constants()
|
| 93 |
|
| 94 |
|
| 95 |
class FromIterables(StreamInitializerOperator):
|
|
@@ -132,8 +136,8 @@ class MapInstanceValues(InstanceOperator):
|
|
| 132 |
it maps values of instances in a stream using predefined mappers.
|
| 133 |
|
| 134 |
Attributes:
|
| 135 |
-
mappers (Dict[str, Dict[str,
|
| 136 |
-
Keys are the names of the fields to
|
| 137 |
that define the mapping from old values to new values.
|
| 138 |
strict (bool): If True, the mapping is applied strictly. That means if a value
|
| 139 |
does not exist in the mapper, it will raise a KeyError. If False, values
|
|
@@ -203,13 +207,12 @@ class MapInstanceValues(InstanceOperator):
|
|
| 203 |
|
| 204 |
def get_mapped_value(self, instance, key, mapper, val):
|
| 205 |
val_as_str = str(val) # make sure the value is a string
|
| 206 |
-
if
|
|
|
|
|
|
|
| 207 |
raise KeyError(
|
| 208 |
f"value '{val}' in instance '{instance}' is not found in mapper '{mapper}', associated with field '{key}'."
|
| 209 |
)
|
| 210 |
-
# By default deep copy the value in mapper to avoid shared modifications
|
| 211 |
-
if val_as_str in mapper:
|
| 212 |
-
return deepcopy(mapper[val_as_str])
|
| 213 |
return val
|
| 214 |
|
| 215 |
|
|
@@ -269,7 +272,7 @@ class Set(InstanceOperator):
|
|
| 269 |
) -> Dict[str, Any]:
|
| 270 |
for key, value in self.fields.items():
|
| 271 |
if self.use_deepcopy:
|
| 272 |
-
value =
|
| 273 |
dict_set(instance, key, value)
|
| 274 |
return instance
|
| 275 |
|
|
@@ -318,6 +321,13 @@ class SelectFields(InstanceOperator):
|
|
| 318 |
return new_instance
|
| 319 |
|
| 320 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 321 |
class InstanceFieldOperator(InstanceOperator):
|
| 322 |
"""A general stream instance operator that processes the values of a field (or multiple ones).
|
| 323 |
|
|
@@ -348,6 +358,7 @@ class InstanceFieldOperator(InstanceOperator):
|
|
| 348 |
process_every_value: bool = False
|
| 349 |
get_default: Any = None
|
| 350 |
not_exist_ok: bool = False
|
|
|
|
| 351 |
|
| 352 |
def verify(self):
|
| 353 |
super().verify()
|
|
@@ -429,19 +440,18 @@ class InstanceFieldOperator(InstanceOperator):
|
|
| 429 |
self, instance: Dict[str, Any], stream_name: Optional[str] = None
|
| 430 |
) -> Dict[str, Any]:
|
| 431 |
self.verify_field_definition()
|
| 432 |
-
# Need to deep copy instance, because when assigning two dictionary fields,
|
| 433 |
-
# dict_set() the target field dictionary fields.
|
| 434 |
-
# This means that if this target field was assigned to another field before,
|
| 435 |
-
# the field is updated as well.
|
| 436 |
-
instance = deepcopy(instance)
|
| 437 |
for from_field, to_field in self._field_to_field:
|
| 438 |
try:
|
| 439 |
old_value = dict_get(
|
| 440 |
instance,
|
| 441 |
from_field,
|
| 442 |
-
default=
|
| 443 |
-
not_exist_ok=self.not_exist_ok,
|
| 444 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 445 |
except Exception as e:
|
| 446 |
raise ValueError(
|
| 447 |
f"Failed to get '{from_field}' from {instance} due to : {e}"
|
|
@@ -476,6 +486,13 @@ class FieldOperator(InstanceFieldOperator):
|
|
| 476 |
pass
|
| 477 |
|
| 478 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 479 |
class Rename(FieldOperator):
|
| 480 |
"""Renames fields.
|
| 481 |
|
|
@@ -643,7 +660,9 @@ class ListFieldValues(InstanceOperator):
|
|
| 643 |
values = []
|
| 644 |
for field_name in self.fields:
|
| 645 |
values.append(dict_get(instance, field_name))
|
| 646 |
-
|
|
|
|
|
|
|
| 647 |
return instance
|
| 648 |
|
| 649 |
|
|
@@ -680,7 +699,7 @@ class ZipFieldValues(InstanceOperator):
|
|
| 680 |
zipped = zip_longest(*values)
|
| 681 |
else:
|
| 682 |
zipped = zip(*values)
|
| 683 |
-
instance
|
| 684 |
return instance
|
| 685 |
|
| 686 |
|
|
@@ -847,14 +866,15 @@ class Copy(FieldOperator):
|
|
| 847 |
|
| 848 |
"""
|
| 849 |
|
| 850 |
-
use_deep_copy: bool = True
|
| 851 |
-
|
| 852 |
def process_value(self, value: Any) -> Any:
|
| 853 |
-
if self.use_deep_copy:
|
| 854 |
-
return copy.deepcopy(value)
|
| 855 |
return value
|
| 856 |
|
| 857 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 858 |
@deprecation(version="2.0.0", alternative=Copy)
|
| 859 |
class CopyFields(Copy):
|
| 860 |
pass
|
|
@@ -1022,7 +1042,7 @@ class ArtifactFetcherMixin:
|
|
| 1022 |
if artifact_identifier not in cls.cache:
|
| 1023 |
artifact, artifactory = fetch_artifact(artifact_identifier)
|
| 1024 |
cls.cache[artifact_identifier] = artifact
|
| 1025 |
-
return
|
| 1026 |
|
| 1027 |
|
| 1028 |
class ApplyOperatorsField(InstanceOperator):
|
|
@@ -1602,7 +1622,23 @@ class ApplyMetric(StreamOperator, ArtifactFetcherMixin):
|
|
| 1602 |
def process(self, stream: Stream, stream_name: Optional[str] = None) -> Generator:
|
| 1603 |
from .metrics import Metric
|
| 1604 |
|
| 1605 |
-
|
|
|
|
| 1606 |
|
| 1607 |
metric_names = first_instance.get(self.metric_field, [])
|
| 1608 |
if not metric_names:
|
|
@@ -1619,16 +1655,6 @@ class ApplyMetric(StreamOperator, ArtifactFetcherMixin):
|
|
| 1619 |
# by the first listed metric (as desired).
|
| 1620 |
metric_names = list(reversed(metric_names))
|
| 1621 |
|
| 1622 |
-
# Workaround: The metric/MetricPipeline modifies the stream itself, sometimes making it incompatible
|
| 1623 |
-
# for further metrics' processing, instead of just modifying the score field.
|
| 1624 |
-
# Here we keep all the fields besides the score, and restore them after the metric finishes.
|
| 1625 |
-
first_instance = stream.peek()
|
| 1626 |
-
keys_to_restore = set(first_instance.keys()).difference({"score"})
|
| 1627 |
-
multi_stream = MultiStream({stream_name: stream})
|
| 1628 |
-
multi_stream = CopyFields(
|
| 1629 |
-
field_to_field={k: f"{k}_orig" for k in keys_to_restore}
|
| 1630 |
-
)(multi_stream)
|
| 1631 |
-
|
| 1632 |
for metric_name in metric_names:
|
| 1633 |
metric = self.get_artifact(metric_name)
|
| 1634 |
assert isinstance(
|
|
@@ -1637,17 +1663,23 @@ class ApplyMetric(StreamOperator, ArtifactFetcherMixin):
|
|
| 1637 |
|
| 1638 |
if not self.calc_confidence_intervals:
|
| 1639 |
metric.disable_confidence_interval_calculation()
|
| 1640 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1641 |
multi_stream = metric(multi_stream)
|
| 1642 |
-
|
| 1643 |
-
|
| 1644 |
-
)
|
|
|
|
|
|
|
|
|
|
| 1645 |
|
| 1646 |
-
|
| 1647 |
-
multi_stream
|
| 1648 |
-
)
|
| 1649 |
-
stream = multi_stream[stream_name]
|
| 1650 |
-
yield from stream
|
| 1651 |
|
| 1652 |
|
| 1653 |
class MergeStreams(MultiStreamOperator):
|
|
@@ -2066,7 +2098,7 @@ class DuplicateInstances(StreamOperator):
|
|
| 2066 |
def process(self, stream: Stream, stream_name: Optional[str] = None) -> Generator:
|
| 2067 |
for instance in stream:
|
| 2068 |
for idx in range(self.num_duplications):
|
| 2069 |
-
duplicate =
|
| 2070 |
if self.duplication_index_field:
|
| 2071 |
duplicate.update({self.duplication_index_field: idx})
|
| 2072 |
yield duplicate
|
|
|
|
| 39 |
------------------------
|
| 40 |
"""
|
| 41 |
|
|
|
|
| 42 |
import operator
|
| 43 |
import uuid
|
| 44 |
import warnings
|
|
|
|
| 81 |
StreamOperator,
|
| 82 |
)
|
| 83 |
from .random_utils import new_random_generator
|
| 84 |
+
from .settings_utils import get_settings
|
| 85 |
+
from .stream import DynamicStream, ListStream, Stream
|
| 86 |
from .text_utils import nested_tuple_to_string
|
| 87 |
from .type_utils import isoftype
|
| 88 |
+
from .utils import (
|
| 89 |
+
deep_copy,
|
| 90 |
+
flatten_dict,
|
| 91 |
+
recursive_copy,
|
| 92 |
+
recursive_shallow_copy,
|
| 93 |
+
shallow_copy,
|
| 94 |
+
)
|
| 95 |
|
| 96 |
settings = get_settings()
|
|
|
|
| 97 |
|
| 98 |
|
| 99 |
class FromIterables(StreamInitializerOperator):
|
|
|
|
| 136 |
it maps values of instances in a stream using predefined mappers.
|
| 137 |
|
| 138 |
Attributes:
|
| 139 |
+
mappers (Dict[str, Dict[str, Any]]): The mappers to use for mapping instance values.
|
| 140 |
+
Keys are the names of the fields to undergo mapping, and values are dictionaries
|
| 141 |
that define the mapping from old values to new values.
|
| 142 |
strict (bool): If True, the mapping is applied strictly. That means if a value
|
| 143 |
does not exist in the mapper, it will raise a KeyError. If False, values
|
|
|
|
| 207 |
|
| 208 |
def get_mapped_value(self, instance, key, mapper, val):
|
| 209 |
val_as_str = str(val) # make sure the value is a string
|
| 210 |
+
if val_as_str in mapper:
|
| 211 |
+
return recursive_copy(mapper[val_as_str])
|
| 212 |
+
if self.strict:
|
| 213 |
raise KeyError(
|
| 214 |
f"value '{val}' in instance '{instance}' is not found in mapper '{mapper}', associated with field '{key}'."
|
| 215 |
)
|
|
|
|
|
|
|
|
|
|
| 216 |
return val
|
| 217 |
|
| 218 |
|
|
|
|
| 272 |
) -> Dict[str, Any]:
|
| 273 |
for key, value in self.fields.items():
|
| 274 |
if self.use_deepcopy:
|
| 275 |
+
value = deep_copy(value)
|
| 276 |
dict_set(instance, key, value)
|
| 277 |
return instance
|
| 278 |
|
|
|
|
| 321 |
return new_instance
|
| 322 |
|
| 323 |
|
| 324 |
+
class DefaultPlaceHolder:
|
| 325 |
+
pass
|
| 326 |
+
|
| 327 |
+
|
| 328 |
+
default_place_holder = DefaultPlaceHolder()
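The point of this sentinel, as a sketch (not part of the diff): a dedicated object lets dict_get distinguish "field is missing" from "field exists and is None", which a plain None default cannot do.

# Hypothetical stand-in for dict_get, just to show why a unique sentinel is used.
_MISSING = object()

def get_field(instance, key, default=_MISSING):
    return instance.get(key, default)

instance = {"label": None}
assert get_field(instance, "label") is None          # present, value happens to be None
assert get_field(instance, "missing") is _MISSING    # genuinely absent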
|
| 329 |
+
|
| 330 |
+
|
| 331 |
class InstanceFieldOperator(InstanceOperator):
|
| 332 |
"""A general stream instance operator that processes the values of a field (or multiple ones).
|
| 333 |
|
|
|
|
| 358 |
process_every_value: bool = False
|
| 359 |
get_default: Any = None
|
| 360 |
not_exist_ok: bool = False
|
| 361 |
+
not_exist_do_nothing: bool = False
|
| 362 |
|
| 363 |
def verify(self):
|
| 364 |
super().verify()
|
|
|
|
| 440 |
self, instance: Dict[str, Any], stream_name: Optional[str] = None
|
| 441 |
) -> Dict[str, Any]:
|
| 442 |
self.verify_field_definition()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 443 |
for from_field, to_field in self._field_to_field:
|
| 444 |
try:
|
| 445 |
old_value = dict_get(
|
| 446 |
instance,
|
| 447 |
from_field,
|
| 448 |
+
default=default_place_holder,
|
| 449 |
+
not_exist_ok=self.not_exist_ok or self.not_exist_do_nothing,
|
| 450 |
)
|
| 451 |
+
if old_value is default_place_holder:
|
| 452 |
+
if self.not_exist_do_nothing:
|
| 453 |
+
return instance
|
| 454 |
+
old_value = self.get_default
|
| 455 |
except Exception as e:
|
| 456 |
raise ValueError(
|
| 457 |
f"Failed to get '{from_field}' from {instance} due to : {e}"
|
|
|
|
| 486 |
pass
|
| 487 |
|
| 488 |
|
| 489 |
+
class MapValues(FieldOperator):
|
| 490 |
+
mapping: Dict[str, str]
|
| 491 |
+
|
| 492 |
+
def process_value(self, value: Any) -> Any:
|
| 493 |
+
return self.mapping[str(value)]
|
| 494 |
+
|
| 495 |
+
|
| 496 |
class Rename(FieldOperator):
|
| 497 |
"""Renames fields.
|
| 498 |
|
|
|
|
| 660 |
values = []
|
| 661 |
for field_name in self.fields:
|
| 662 |
values.append(dict_get(instance, field_name))
|
| 663 |
+
|
| 664 |
+
dict_set(instance, self.to_field, values)
|
| 665 |
+
|
| 666 |
return instance
|
| 667 |
|
| 668 |
|
|
|
|
| 699 |
zipped = zip_longest(*values)
|
| 700 |
else:
|
| 701 |
zipped = zip(*values)
|
| 702 |
+
dict_set(instance, self.to_field, list(zipped))
|
| 703 |
return instance
|
| 704 |
|
| 705 |
|
|
|
|
| 866 |
|
| 867 |
"""
|
| 868 |
|
|
|
|
|
|
|
| 869 |
def process_value(self, value: Any) -> Any:
|
|
|
|
|
|
|
| 870 |
return value
|
| 871 |
|
| 872 |
|
| 873 |
+
class RecursiveCopy(FieldOperator):
|
| 874 |
+
def process_value(self, value: Any) -> Any:
|
| 875 |
+
return recursive_copy(value)
|
| 876 |
+
|
| 877 |
+
|
| 878 |
@deprecation(version="2.0.0", alternative=Copy)
|
| 879 |
class CopyFields(Copy):
|
| 880 |
pass
|
|
|
|
| 1042 |
if artifact_identifier not in cls.cache:
|
| 1043 |
artifact, artifactory = fetch_artifact(artifact_identifier)
|
| 1044 |
cls.cache[artifact_identifier] = artifact
|
| 1045 |
+
return shallow_copy(cls.cache[artifact_identifier])
|
| 1046 |
|
| 1047 |
|
| 1048 |
class ApplyOperatorsField(InstanceOperator):
|
|
|
|
| 1622 |
def process(self, stream: Stream, stream_name: Optional[str] = None) -> Generator:
|
| 1623 |
from .metrics import Metric
|
| 1624 |
|
| 1625 |
+
# Number of instances in input stream is assumed to be small. This is why
|
| 1626 |
+
# each metric consumes all of them and lays them in its main memory, and even generates
|
| 1627 |
+
# some 1000 copies thereof for the sake of CI.
|
| 1628 |
+
# So we start with deep copying here, to make a 'frozen' status of the stream, having
|
| 1629 |
+
# passed the preprocess_steps of the task, and inference, and now getting to be evaluated,
|
| 1630 |
+
# a frozen status to be fed into each of the metrics listed in metric_field,
|
| 1631 |
+
# so that the evaluation of one does not affect the evaluation of another
|
| 1632 |
+
# (typically, affecting via change of instance as part of
|
| 1633 |
+
# preprocess_steps of MetricPipeline, as illustrated in docs/adding_metrics/Using Metric Pipelines).
|
| 1634 |
+
|
| 1635 |
+
instances_upon_entrance_to_metrics_evaluations = []
|
| 1636 |
+
for instance in stream:
|
| 1637 |
+
instances_upon_entrance_to_metrics_evaluations.append(
|
| 1638 |
+
recursive_copy(instance)
|
| 1639 |
+
)
|
| 1640 |
+
|
| 1641 |
+
first_instance = instances_upon_entrance_to_metrics_evaluations[0]
|
| 1642 |
|
| 1643 |
metric_names = first_instance.get(self.metric_field, [])
|
| 1644 |
if not metric_names:
|
|
|
|
| 1655 |
# by the first listed metric (as desired).
|
| 1656 |
metric_names = list(reversed(metric_names))
|
| 1657 |
|
|
|
| 1658 |
for metric_name in metric_names:
|
| 1659 |
metric = self.get_artifact(metric_name)
|
| 1660 |
assert isinstance(
|
|
|
|
| 1663 |
|
| 1664 |
if not self.calc_confidence_intervals:
|
| 1665 |
metric.disable_confidence_interval_calculation()
|
| 1666 |
+
multi_stream = MultiStream(
|
| 1667 |
+
{
|
| 1668 |
+
"tmp": ListStream(
|
| 1669 |
+
instances_list=instances_upon_entrance_to_metrics_evaluations,
|
| 1670 |
+
copying=True, # ensures deep copy when iterating over instances
|
| 1671 |
+
)
|
| 1672 |
+
}
|
| 1673 |
+
)
|
| 1674 |
multi_stream = metric(multi_stream)
|
| 1675 |
+
for evaluated_instance, freezed_instance in zip(
|
| 1676 |
+
multi_stream["tmp"], instances_upon_entrance_to_metrics_evaluations
|
| 1677 |
+
):
|
| 1678 |
+
freezed_instance["score"] = recursive_shallow_copy(
|
| 1679 |
+
evaluated_instance["score"]
|
| 1680 |
+
)
|
| 1681 |
|
| 1682 |
+
yield from instances_upon_entrance_to_metrics_evaluations
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1683 |
|
| 1684 |
|
| 1685 |
class MergeStreams(MultiStreamOperator):
|
|
|
|
| 2098 |
def process(self, stream: Stream, stream_name: Optional[str] = None) -> Generator:
|
| 2099 |
for instance in stream:
|
| 2100 |
for idx in range(self.num_duplications):
|
| 2101 |
+
duplicate = recursive_shallow_copy(instance)
|
| 2102 |
if self.duplication_index_field:
|
| 2103 |
duplicate.update({self.duplication_index_field: idx})
|
| 2104 |
yield duplicate
|
processors.py
CHANGED
|
@@ -2,9 +2,12 @@ import ast
|
|
| 2 |
import copy
|
| 3 |
import json
|
| 4 |
import re
|
|
|
|
| 5 |
from difflib import get_close_matches
|
| 6 |
from typing import Any, Dict
|
| 7 |
|
|
|
|
|
|
|
| 8 |
from .deprecation_utils import deprecation
|
| 9 |
from .operator import MultiStreamOperator
|
| 10 |
from .operators import FieldOperator, InstanceFieldOperator
|
|
@@ -20,9 +23,9 @@ class PostProcess(MultiStreamOperator):
|
|
| 20 |
|
| 21 |
def prepare(self):
|
| 22 |
super().prepare()
|
| 23 |
-
self.prediction_operator = copy.
|
| 24 |
self.prediction_operator.field = "prediction"
|
| 25 |
-
self.references_operator = copy.
|
| 26 |
self.references_operator.field = "references"
|
| 27 |
self.references_operator.process_every_value = True
|
| 28 |
self.references_operator.dont_apply_to_streams = [constants.inference_stream]
|
|
@@ -315,3 +318,75 @@ class ExtractArenaHardNumericalJudgment(FieldOperator):
|
|
| 315 |
|
| 316 |
except:
|
| 317 |
return 0
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2 |
import copy
|
| 3 |
import json
|
| 4 |
import re
|
| 5 |
+
import string
|
| 6 |
from difflib import get_close_matches
|
| 7 |
from typing import Any, Dict
|
| 8 |
|
| 9 |
+
import numpy as np
|
| 10 |
+
|
| 11 |
from .deprecation_utils import deprecation
|
| 12 |
from .operator import MultiStreamOperator
|
| 13 |
from .operators import FieldOperator, InstanceFieldOperator
|
|
|
|
| 23 |
|
| 24 |
def prepare(self):
|
| 25 |
super().prepare()
|
| 26 |
+
self.prediction_operator = copy.copy(self.operator)
|
| 27 |
self.prediction_operator.field = "prediction"
|
| 28 |
+
self.references_operator = copy.copy(self.operator)
|
| 29 |
self.references_operator.field = "references"
|
| 30 |
self.references_operator.process_every_value = True
|
| 31 |
self.references_operator.dont_apply_to_streams = [constants.inference_stream]
|
|
|
|
| 318 |
|
| 319 |
except:
|
| 320 |
return 0
|
| 321 |
+
|
| 322 |
+
|
| 323 |
+
class InferDictsToBinaryLogprobs(FieldOperator):
|
| 324 |
+
neg_class_name: str
|
| 325 |
+
pos_class_name: str
|
| 326 |
+
|
| 327 |
+
take_logprobs_from_end: bool = False
|
| 328 |
+
num_logprobs_to_take: int = 3
|
| 329 |
+
min_probability_mass = 0.0001
|
| 330 |
+
|
| 331 |
+
def verify(self):
|
| 332 |
+
super().verify()
|
| 333 |
+
if (
|
| 334 |
+
self.neg_class_name.lower() in self.pos_class_name.lower()
|
| 335 |
+
or self.pos_class_name.lower() in self.neg_class_name.lower()
|
| 336 |
+
):
|
| 337 |
+
raise ValueError(
|
| 338 |
+
f"""Class names in {self.__class__.__name__} should not overlap, got "{self.pos_class_name}" and "{self.neg_class_name}"""
|
| 339 |
+
)
|
| 340 |
+
|
| 341 |
+
def process_value(self, obj: Any) -> Any:
|
| 342 |
+
for i in self.get_token_range(obj):
|
| 343 |
+
try:
|
| 344 |
+
pos_probs, neg_probs = self.get_pos_neg_probs(pred_dict=obj[i])
|
| 345 |
+
if pos_probs or neg_probs:
|
| 346 |
+
sum_probs = sum(pos_probs) + sum(neg_probs)
|
| 347 |
+
if sum_probs > self.min_probability_mass:
|
| 348 |
+
return sum(pos_probs) / sum_probs
|
| 349 |
+
except:
|
| 350 |
+
pass
|
| 351 |
+
return 0
|
| 352 |
+
|
| 353 |
+
def get_pos_neg_probs(self, pred_dict):
|
| 354 |
+
token_logprobs = pred_dict["top_tokens"]
|
| 355 |
+
|
| 356 |
+
pos_and_neg_probs = []
|
| 357 |
+
for class_name in [self.pos_class_name, self.neg_class_name]:
|
| 358 |
+
# We need to capture different variants of model behavior and tokenizers, for example with opening space,
|
| 359 |
+
# punctuation etc. but avoid longer words that contain the class name.
|
| 360 |
+
# For example, for class "yes" we would capture "YES," and " Yes" but not "yesterday".
|
| 361 |
+
name_regex = re.compile(
|
| 362 |
+
rf"(\W|Ġ|_)*{class_name}(\W|Ġ|_)*", flags=re.IGNORECASE
|
| 363 |
+
)
|
| 364 |
+
class_probs = [
|
| 365 |
+
np.exp(d["logprob"])
|
| 366 |
+
for d in token_logprobs
|
| 367 |
+
if name_regex.fullmatch(d["text"])
|
| 368 |
+
]
|
| 369 |
+
pos_and_neg_probs.append(class_probs)
|
| 370 |
+
return pos_and_neg_probs
|
| 371 |
+
|
| 372 |
+
def get_token_range(self, obj: Any) -> range:
|
| 373 |
+
n_tokens = min([self.num_logprobs_to_take, len(obj)])
|
| 374 |
+
if self.take_logprobs_from_end:
|
| 375 |
+
return range(-1, -(n_tokens + 1), -1)
|
| 376 |
+
return range(n_tokens)
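A worked example of the conversion above, under the assumed top_tokens structure (a list of candidate tokens with "text" and "logprob" for one generated position); numbers are illustrative.

import math
import re

top_tokens = [
    {"text": " Yes", "logprob": -0.105},
    {"text": "No",   "logprob": -2.30},
    {"text": "yes,", "logprob": -4.61},
]

def class_probs(class_name):
    # Same matching idea as above: tolerate spaces/punctuation around the class name.
    pattern = re.compile(rf"(\W|Ġ|_)*{class_name}(\W|Ġ|_)*", flags=re.IGNORECASE)
    return [math.exp(d["logprob"]) for d in top_tokens if pattern.fullmatch(d["text"])]

pos, neg = class_probs("yes"), class_probs("no")
print(sum(pos) / (sum(pos) + sum(neg)))  # ~0.90: probability mass assigned to the positive class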
|
| 377 |
+
|
| 378 |
+
|
| 379 |
+
class RemoveArticles(FieldOperator):
|
| 380 |
+
def process_value(self, text: Any) -> Any:
|
| 381 |
+
return re.sub(r"\b(a|an|the)\b", " ", text)
|
| 382 |
+
|
| 383 |
+
|
| 384 |
+
class RemovePunctuations(FieldOperator):
|
| 385 |
+
def process_value(self, text: Any) -> Any:
|
| 386 |
+
puncs_to_exclude = set(string.punctuation)
|
| 387 |
+
return "".join(c for c in text if c not in puncs_to_exclude)
|
| 388 |
+
|
| 389 |
+
|
| 390 |
+
class FixWhiteSpace(FieldOperator):
|
| 391 |
+
def process_value(self, text: Any) -> Any:
|
| 392 |
+
return " ".join(text.split())
|
settings_utils.py
CHANGED
|
@@ -147,6 +147,7 @@ if Settings.is_uninitilized():
     settings.skip_artifacts_prepare_and_verify = (bool, False)
     settings.data_classification_policy = None
     settings.mock_inference_mode = (bool, False)
+    settings.disable_hf_datasets_cache = (bool, True)
 
 if Constants.is_uninitilized():
     constants = Constants()
|
split_utils.py
CHANGED
|
@@ -226,7 +226,12 @@ def rename_split(input_streams: Dict[str, Stream], mapping: Dict[str, str]):
|
|
| 226 |
dict: A dictionary containing the generated new streams, where each key is the name
|
| 227 |
of the new stream and the value is a generator representing the stream.
|
| 228 |
"""
|
| 229 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 230 |
|
| 231 |
|
| 232 |
def random_mix_generator(
|
|
|
|
| 226 |
dict: A dictionary containing the generated new streams, where each key is the name
|
| 227 |
of the new stream and the value is a generator representing the stream.
|
| 228 |
"""
|
| 229 |
+
new_streams = {}
|
| 230 |
+
for key, val in mapping.items():
|
| 231 |
+
if key not in input_streams:
|
| 232 |
+
raise ValueError("Wrong stream name")
|
| 233 |
+
new_streams[val] = input_streams.pop(key)
|
| 234 |
+
return {**input_streams, **new_streams}
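For example (a sketch; the stream values here are plain lists rather than real Stream objects):

# rename_split moves a stream to a new key and leaves the other streams untouched.
input_streams = {"validation": ["v1", "v2"], "train": ["t1"]}
mapping = {"validation": "test"}

new_streams = {}
for key, val in mapping.items():
    if key not in input_streams:
        raise ValueError(f"Stream '{key}' not found in input streams")
    new_streams[val] = input_streams.pop(key)

print({**input_streams, **new_streams})  # {'train': ['t1'], 'test': ['v1', 'v2']}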
|
| 235 |
|
| 236 |
|
| 237 |
def random_mix_generator(
|
splitters.py
CHANGED
|
@@ -16,7 +16,7 @@ from .split_utils import (
|
|
| 16 |
)
|
| 17 |
from .stream import EmptyStreamError, FaultyStreamError, MultiStream
|
| 18 |
from .type_utils import isoftype
|
| 19 |
-
from .utils import
|
| 20 |
|
| 21 |
|
| 22 |
class Splitter(MultiStreamOperator):
|
|
@@ -353,7 +353,9 @@ class Sample(InstanceOperatorWithMultiStreamAccess):
|
|
| 353 |
sample_size = self.get_sample_size(instance)
|
| 354 |
try:
|
| 355 |
if self.local_cache is None:
|
| 356 |
-
self.local_cache =
|
|
|
|
|
|
|
| 357 |
|
| 358 |
source_stream = self.local_cache
|
| 359 |
source_stream = self.sampler.filter_source_by_instance(
|
|
|
|
| 16 |
)
|
| 17 |
from .stream import EmptyStreamError, FaultyStreamError, MultiStream
|
| 18 |
from .type_utils import isoftype
|
| 19 |
+
from .utils import recursive_shallow_copy
|
| 20 |
|
| 21 |
|
| 22 |
class Splitter(MultiStreamOperator):
|
|
|
|
| 353 |
sample_size = self.get_sample_size(instance)
|
| 354 |
try:
|
| 355 |
if self.local_cache is None:
|
| 356 |
+
self.local_cache = recursive_shallow_copy(
|
| 357 |
+
list(multi_stream[self.from_stream])
|
| 358 |
+
)
|
| 359 |
|
| 360 |
source_stream = self.local_cache
|
| 361 |
source_stream = self.sampler.filter_source_by_instance(
|
standard.py
CHANGED
|
@@ -249,12 +249,12 @@ class BaseRecipe(Recipe, SourceSequentialOperator):
|
|
| 249 |
def produce(self, task_instances):
|
| 250 |
"""Use the recipe in production to produce model ready query from standard task instance."""
|
| 251 |
self.before_process_multi_stream()
|
| 252 |
-
|
| 253 |
-
|
| 254 |
-
|
| 255 |
-
|
| 256 |
-
|
| 257 |
-
)
|
| 258 |
multi_stream = self.inference(multi_stream)
|
| 259 |
return list(multi_stream[constants.inference_stream])
|
| 260 |
|
|
|
|
| 249 |
def produce(self, task_instances):
|
| 250 |
"""Use the recipe in production to produce model ready query from standard task instance."""
|
| 251 |
self.before_process_multi_stream()
|
| 252 |
+
streams = {
|
| 253 |
+
constants.inference_stream: self.production_preprocess(task_instances),
|
| 254 |
+
}
|
| 255 |
+
if self.use_demos:
|
| 256 |
+
streams[self.demos_pool_name] = self.production_demos_pool()
|
| 257 |
+
multi_stream = MultiStream.from_iterables(streams)
|
| 258 |
multi_stream = self.inference(multi_stream)
|
| 259 |
return list(multi_stream[constants.inference_stream])
|
| 260 |
|
stream.py
CHANGED
|
@@ -10,7 +10,7 @@ from .dataclass import Dataclass, OptionalField
|
|
| 10 |
from .generator_utils import CopyingReusableGenerator, ReusableGenerator
|
| 11 |
from .logging_utils import get_logger
|
| 12 |
from .settings_utils import get_settings
|
| 13 |
-
from .utils import
|
| 14 |
|
| 15 |
settings = get_settings()
|
| 16 |
logger = get_logger()
|
|
@@ -40,7 +40,7 @@ class ListStream(Stream):
|
|
| 40 |
|
| 41 |
def __iter__(self):
|
| 42 |
if self.copying:
|
| 43 |
-
return iter(
|
| 44 |
return iter(self.instances_list)
|
| 45 |
|
| 46 |
def peek(self):
|
|
@@ -244,7 +244,8 @@ class MultiStream(dict):
|
|
| 244 |
return IterableDatasetDict(
|
| 245 |
{
|
| 246 |
key: IterableDataset.from_generator(
|
| 247 |
-
self.get_generator,
|
|
|
|
| 248 |
)
|
| 249 |
for key in self.keys()
|
| 250 |
}
|
|
|
|
| 10 |
from .generator_utils import CopyingReusableGenerator, ReusableGenerator
|
| 11 |
from .logging_utils import get_logger
|
| 12 |
from .settings_utils import get_settings
|
| 13 |
+
from .utils import recursive_copy
|
| 14 |
|
| 15 |
settings = get_settings()
|
| 16 |
logger = get_logger()
|
|
|
|
| 40 |
|
| 41 |
def __iter__(self):
|
| 42 |
if self.copying:
|
| 43 |
+
return iter(recursive_copy(self.instances_list))
|
| 44 |
return iter(self.instances_list)
|
| 45 |
|
| 46 |
def peek(self):
|
|
|
|
| 244 |
return IterableDatasetDict(
|
| 245 |
{
|
| 246 |
key: IterableDataset.from_generator(
|
| 247 |
+
self.get_generator,
|
| 248 |
+
gen_kwargs={"key": key},
|
| 249 |
)
|
| 250 |
for key in self.keys()
|
| 251 |
}
|
stream_operators.py
CHANGED
|
@@ -31,6 +31,7 @@ The rest of this section is dedicated for operators that operates on streams.
|
|
| 31 |
|
| 32 |
"""
|
| 33 |
|
|
|
|
| 34 |
from typing import (
|
| 35 |
List,
|
| 36 |
Literal,
|
|
@@ -154,6 +155,7 @@ class DuplicateSplit(MultiStreamOperator):
|
|
| 154 |
|
| 155 |
def process(self, multi_stream: MultiStream) -> MultiStream:
|
| 156 |
assert self.split in multi_stream
|
| 157 |
-
|
| 158 |
-
|
| 159 |
-
|
|
|
|
|
|
| 31 |
|
| 32 |
"""
|
| 33 |
|
| 34 |
+
import copy
|
| 35 |
from typing import (
|
| 36 |
List,
|
| 37 |
Literal,
|
|
|
|
| 155 |
|
| 156 |
def process(self, multi_stream: MultiStream) -> MultiStream:
|
| 157 |
assert self.split in multi_stream
|
| 158 |
+
new_stream = copy.deepcopy(multi_stream[self.split])
|
| 159 |
+
new_stream.set_copying(copying=True)
|
| 160 |
+
multi_stream[self.to_split] = new_stream
|
| 161 |
+
return multi_stream
|
string_operators.py
CHANGED
|
@@ -87,3 +87,12 @@ class Replace(FieldOperator):
|
|
| 87 |
|
| 88 |
def process_value(self, value: str) -> str:
|
| 89 |
return value.replace(self.old, self.new)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 87 |
|
| 88 |
def process_value(self, value: str) -> str:
|
| 89 |
return value.replace(self.old, self.new)
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
class MapReplace(FieldOperator):
|
| 93 |
+
mapping: Dict[str, str]
|
| 94 |
+
|
| 95 |
+
def process_value(self, value: Any) -> Any:
|
| 96 |
+
for key, val in self.mapping.items():
|
| 97 |
+
value = value.replace(key, val)
|
| 98 |
+
return value
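Usage sketch: replacements are applied one after another in the mapping's insertion order, so later pairs see the output of earlier ones.

mapping = {"colour": "color", "centre": "center"}

value = "The colour chart is in the centre."
for key, val in mapping.items():
    value = value.replace(key, val)
print(value)  # "The color chart is in the center."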
|
struct_data_operators.py
CHANGED
|
@@ -32,7 +32,7 @@ from .operators import FieldOperator, InstanceOperator
|
|
| 32 |
from .random_utils import new_random_generator
|
| 33 |
from .serializers import TableSerializer
|
| 34 |
from .types import Table
|
| 35 |
-
from .utils import
|
| 36 |
|
| 37 |
|
| 38 |
def shuffle_columns(table: Table, seed=0) -> Table:
|
|
@@ -76,7 +76,7 @@ class SerializeTable(ABC, TableSerializer):
|
|
| 76 |
shuffle_columns: bool = False
|
| 77 |
|
| 78 |
def serialize(self, value: Table, instance: Dict[str, Any]) -> str:
|
| 79 |
-
value =
|
| 80 |
if self.shuffle_columns:
|
| 81 |
value = shuffle_columns(table=value, seed=self.seed)
|
| 82 |
|
|
@@ -207,6 +207,12 @@ class SerializeTableAsDFLoader(SerializeTable):
|
|
| 207 |
|
| 208 |
assert header and rows, "Incorrect input table format"
|
| 209 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 210 |
# Create a pandas DataFrame
|
| 211 |
df = pd.DataFrame(rows, columns=header)
|
| 212 |
|
|
@@ -252,6 +258,59 @@ class SerializeTableAsJson(SerializeTable):
|
|
| 252 |
return json.dumps(output_dict)
|
| 253 |
|
| 254 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 255 |
# truncate cell value to maximum allowed length
|
| 256 |
def truncate_cell(cell_value, max_len):
|
| 257 |
if cell_value is None:
|
|
@@ -490,7 +549,7 @@ class ConvertTableColNamesToSequential(FieldOperator):
|
|
| 490 |
"""
|
| 491 |
|
| 492 |
def process_value(self, table: Any) -> Any:
|
| 493 |
-
table_input =
|
| 494 |
return self.replace_header(table_content=table_input)
|
| 495 |
|
| 496 |
# replaces header with sequential column names
|
|
@@ -523,7 +582,7 @@ class ShuffleTableRows(FieldOperator):
|
|
| 523 |
"""
|
| 524 |
|
| 525 |
def process_value(self, table: Any) -> Any:
|
| 526 |
-
table_input =
|
| 527 |
return shuffle_rows(table_input)
|
| 528 |
|
| 529 |
|
|
@@ -544,7 +603,7 @@ class ShuffleTableColumns(FieldOperator):
|
|
| 544 |
"""
|
| 545 |
|
| 546 |
def process_value(self, table: Any) -> Any:
|
| 547 |
-
table_input =
|
| 548 |
return shuffle_columns(table_input)
|
| 549 |
|
| 550 |
|
|
@@ -658,3 +717,133 @@ class ConstructTableFromRowsCols(InstanceOperator):
|
|
| 658 |
instance[self.to_field] = output_dict
|
| 659 |
|
| 660 |
return instance
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 32 |
from .random_utils import new_random_generator
|
| 33 |
from .serializers import TableSerializer
|
| 34 |
from .types import Table
|
| 35 |
+
from .utils import recursive_copy
|
| 36 |
|
| 37 |
|
| 38 |
def shuffle_columns(table: Table, seed=0) -> Table:
|
|
|
|
| 76 |
shuffle_columns: bool = False
|
| 77 |
|
| 78 |
def serialize(self, value: Table, instance: Dict[str, Any]) -> str:
|
| 79 |
+
value = recursive_copy(value)
|
| 80 |
if self.shuffle_columns:
|
| 81 |
value = shuffle_columns(table=value, seed=self.seed)
|
| 82 |
|
|
|
|
| 207 |
|
| 208 |
assert header and rows, "Incorrect input table format"
|
| 209 |
|
| 210 |
+
# Fix duplicate columns, ensuring the first occurrence has no suffix
|
| 211 |
+
header = [
|
| 212 |
+
f"{col}_{header[:i].count(col)}" if header[:i].count(col) > 0 else col
|
| 213 |
+
for i, col in enumerate(header)
|
| 214 |
+
]
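For instance (illustration only), a header with repeated names is disambiguated before the DataFrame is built, keeping the first occurrence unsuffixed:

header = ["name", "score", "name", "score"]
header = [
    f"{col}_{header[:i].count(col)}" if header[:i].count(col) > 0 else col
    for i, col in enumerate(header)
]
print(header)  # ['name', 'score', 'name_1', 'score_1']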
|
| 215 |
+
|
| 216 |
# Create a pandas DataFrame
|
| 217 |
df = pd.DataFrame(rows, columns=header)
|
| 218 |
|
|
|
|
| 258 |
return json.dumps(output_dict)
|
| 259 |
|
| 260 |
|
| 261 |
+
class SerializeTableAsHTML(SerializeTable):
|
| 262 |
+
"""HTML Table Serializer.
|
| 263 |
+
|
| 264 |
+
HTML table format used for rendering tables in web pages.
|
| 265 |
+
Format(Sample):
|
| 266 |
+
<table>
|
| 267 |
+
<thead>
|
| 268 |
+
<tr><th>name</th><th>age</th><th>sex</th></tr>
|
| 269 |
+
</thead>
|
| 270 |
+
<tbody>
|
| 271 |
+
<tr><td>Alice</td><td>26</td><td>F</td></tr>
|
| 272 |
+
<tr><td>Raj</td><td>34</td><td>M</td></tr>
|
| 273 |
+
</tbody>
|
| 274 |
+
</table>
|
| 275 |
+
"""
|
| 276 |
+
|
| 277 |
+
# main method that serializes a table.
|
| 278 |
+
# table_content must be in the prescribed input format.
|
| 279 |
+
def serialize_table(self, table_content: Dict) -> str:
|
| 280 |
+
# Extract headers and rows from the dictionary
|
| 281 |
+
header = table_content.get("header", [])
|
| 282 |
+
rows = table_content.get("rows", [])
|
| 283 |
+
|
| 284 |
+
assert header and rows, "Incorrect input table format"
|
| 285 |
+
|
| 286 |
+
# Build the HTML table structure
|
| 287 |
+
serialized_tbl_str = "<table>\n"
|
| 288 |
+
serialized_tbl_str += self.process_header(header) + "\n"
|
| 289 |
+
serialized_tbl_str += self.process_rows(rows) + "\n"
|
| 290 |
+
serialized_tbl_str += "</table>"
|
| 291 |
+
|
| 292 |
+
return serialized_tbl_str.strip()
|
| 293 |
+
|
| 294 |
+
# serialize the header into an HTML <thead> section
|
| 295 |
+
def process_header(self, header: List) -> str:
|
| 296 |
+
header_html = " <thead>\n <tr>"
|
| 297 |
+
for col in header:
|
| 298 |
+
header_html += f"<th>{col}</th>"
|
| 299 |
+
header_html += "</tr>\n </thead>"
|
| 300 |
+
return header_html
|
| 301 |
+
|
| 302 |
+
# serialize the rows into an HTML <tbody> section
|
| 303 |
+
def process_rows(self, rows: List[List]) -> str:
|
| 304 |
+
rows_html = " <tbody>"
|
| 305 |
+
for row in rows:
|
| 306 |
+
rows_html += "\n <tr>"
|
| 307 |
+
for cell in row:
|
| 308 |
+
rows_html += f"<td>{cell}</td>"
|
| 309 |
+
rows_html += "</tr>"
|
| 310 |
+
rows_html += "\n </tbody>"
|
| 311 |
+
return rows_html
|
| 312 |
+
|
| 313 |
+
|
| 314 |
# truncate cell value to maximum allowed length
|
| 315 |
def truncate_cell(cell_value, max_len):
|
| 316 |
if cell_value is None:
|
|
|
|
| 549 |
"""
|
| 550 |
|
| 551 |
def process_value(self, table: Any) -> Any:
|
| 552 |
+
table_input = recursive_copy(table)
|
| 553 |
return self.replace_header(table_content=table_input)
|
| 554 |
|
| 555 |
# replaces header with sequential column names
|
|
|
|
| 582 |
"""
|
| 583 |
|
| 584 |
def process_value(self, table: Any) -> Any:
|
| 585 |
+
table_input = recursive_copy(table)
|
| 586 |
return shuffle_rows(table_input)
|
| 587 |
|
| 588 |
|
|
|
|
| 603 |
"""
|
| 604 |
|
| 605 |
def process_value(self, table: Any) -> Any:
|
| 606 |
+
table_input = recursive_copy(table)
|
| 607 |
return shuffle_columns(table_input)
|
| 608 |
|
| 609 |
|
|
|
|
| 717 |
instance[self.to_field] = output_dict
|
| 718 |
|
| 719 |
return instance
|
| 720 |
+
|
| 721 |
+
|
| 722 |
+
class TransposeTable(FieldOperator):
|
| 723 |
+
"""Transpose a table.
|
| 724 |
+
|
| 725 |
+
Sample Input:
|
| 726 |
+
{
|
| 727 |
+
"header": ["name", "age", "sex"],
|
| 728 |
+
"rows": [["Alice", 26, "F"], ["Raj", 34, "M"], ["Donald", 39, "M"]],
|
| 729 |
+
}
|
| 730 |
+
|
| 731 |
+
Sample Output:
|
| 732 |
+
{
|
| 733 |
+
"header": [" ", "0", "1", "2"],
|
| 734 |
+
"rows": [["name", "Alice", "Raj", "Donald"], ["age", 26, 34, 39], ["sex", "F", "M", "M"]],
|
| 735 |
+
}
|
| 736 |
+
"""
|
| 737 |
+
|
| 738 |
+
def process_value(self, table: Any) -> Any:
|
| 739 |
+
return self.transpose_table(table)
|
| 740 |
+
|
| 741 |
+
def transpose_table(self, table: Dict) -> Dict:
|
| 742 |
+
# Extract the header and rows from the table object
|
| 743 |
+
header = table["header"]
|
| 744 |
+
rows = table["rows"]
|
| 745 |
+
|
| 746 |
+
# Transpose the table by converting rows as columns and vice versa
|
| 747 |
+
transposed_header = [" "] + [str(i) for i in range(len(rows))]
|
| 748 |
+
transposed_rows = [
|
| 749 |
+
[header[i]] + [row[i] for row in rows] for i in range(len(header))
|
| 750 |
+
]
|
| 751 |
+
|
| 752 |
+
return {"header": transposed_header, "rows": transposed_rows}
|
| 753 |
+
|
| 754 |
+
|
| 755 |
+
class DuplicateTableRows(FieldOperator):
|
| 756 |
+
"""Duplicates specific rows of a table for the given number of times.
|
| 757 |
+
|
| 758 |
+
Args:
|
| 759 |
+
row_indices (List[int]) - rows to be duplicated
|
| 760 |
+
times(int) - how many times to duplicate
|
| 761 |
+
"""
|
| 762 |
+
|
| 763 |
+
row_indices: List[int] = []
|
| 764 |
+
times: int = 1
|
| 765 |
+
|
| 766 |
+
def process_value(self, table: Any) -> Any:
|
| 767 |
+
# Extract the header and rows from the table
|
| 768 |
+
header = table["header"]
|
| 769 |
+
rows = table["rows"]
|
| 770 |
+
|
| 771 |
+
# Duplicate only the specified rows
|
| 772 |
+
duplicated_rows = []
|
| 773 |
+
for i, row in enumerate(rows):
|
| 774 |
+
if i in self.row_indices:
|
| 775 |
+
duplicated_rows.extend(
|
| 776 |
+
[row] * self.times
|
| 777 |
+
) # Duplicate the selected rows
|
| 778 |
+
else:
|
| 779 |
+
duplicated_rows.append(row) # Leave other rows unchanged
|
| 780 |
+
|
| 781 |
+
# Return the new table with selectively duplicated rows
|
| 782 |
+
return {"header": header, "rows": duplicated_rows}
|
| 783 |
+
|
| 784 |
+
|
| 785 |
+
class DuplicateTableColumns(FieldOperator):
|
| 786 |
+
"""Duplicates specific columns of a table for the given number of times.
|
| 787 |
+
|
| 788 |
+
Args:
|
| 789 |
+
column_indices (List[int]) - columns to be duplicated
|
| 790 |
+
times(int) - how many times to duplicate
|
| 791 |
+
"""
|
| 792 |
+
|
| 793 |
+
column_indices: List[int] = []
|
| 794 |
+
times: int = 1
|
| 795 |
+
|
| 796 |
+
def process_value(self, table: Any) -> Any:
|
| 797 |
+
# Extract the header and rows from the table
|
| 798 |
+
header = table["header"]
|
| 799 |
+
rows = table["rows"]
|
| 800 |
+
|
| 801 |
+
# Duplicate the specified columns in the header
|
| 802 |
+
duplicated_header = []
|
| 803 |
+
for i, col in enumerate(header):
|
| 804 |
+
if i in self.column_indices:
|
| 805 |
+
duplicated_header.extend([col] * self.times)
|
| 806 |
+
else:
|
| 807 |
+
duplicated_header.append(col)
|
| 808 |
+
|
| 809 |
+
# Duplicate the specified columns in each row
|
| 810 |
+
duplicated_rows = []
|
| 811 |
+
for row in rows:
|
| 812 |
+
new_row = []
|
| 813 |
+
for i, value in enumerate(row):
|
| 814 |
+
if i in self.column_indices:
|
| 815 |
+
new_row.extend([value] * self.times)
|
| 816 |
+
else:
|
| 817 |
+
new_row.append(value)
|
| 818 |
+
duplicated_rows.append(new_row)
|
| 819 |
+
|
| 820 |
+
# Return the new table with selectively duplicated columns
|
| 821 |
+
return {"header": duplicated_header, "rows": duplicated_rows}
|
| 822 |
+
|
| 823 |
+
|
| 824 |
+
class InsertEmptyTableRows(FieldOperator):
|
| 825 |
+
"""Inserts empty rows in a table randomly for the given number of times.
|
| 826 |
+
|
| 827 |
+
Args:
|
| 828 |
+
times(int) - how many times to insert
|
| 829 |
+
"""
|
| 830 |
+
|
| 831 |
+
times: int = 0
|
| 832 |
+
|
| 833 |
+
def process_value(self, table: Any) -> Any:
|
| 834 |
+
# Extract the header and rows from the table
|
| 835 |
+
header = table["header"]
|
| 836 |
+
rows = table["rows"]
|
| 837 |
+
|
| 838 |
+
# Insert empty rows at random positions
|
| 839 |
+
for _ in range(self.times):
|
| 840 |
+
empty_row = [""] * len(
|
| 841 |
+
header
|
| 842 |
+
) # Create an empty row with the same number of columns
|
| 843 |
+
insert_pos = random.randint(
|
| 844 |
+
0, len(rows)
|
| 845 |
+
) # Get a random position to insert the empty row created
|
| 846 |
+
rows.insert(insert_pos, empty_row)
|
| 847 |
+
|
| 848 |
+
# Return the modified table
|
| 849 |
+
return {"header": header, "rows": rows}
|
templates.py
CHANGED
|
@@ -210,7 +210,7 @@ class ApplyTemplate(InstanceOperator):
         if self.demos_field not in instance:
             raise ValueError("Demos field is missing.")
         instance[self.demos_field] = [
-            self.apply(template, demo_instance
+            self.apply(template, demo_instance)
             for demo_instance in instance[self.demos_field]
         ]
         dict_set(instance, "recipe_metadata/template", template)
|
type_utils.py
CHANGED
|
@@ -4,6 +4,7 @@ import io
|
|
| 4 |
import itertools
|
| 5 |
import re
|
| 6 |
import typing
|
|
|
|
| 7 |
from typing import Any, Dict, List, Literal, Optional, Tuple, TypedDict, Union
|
| 8 |
|
| 9 |
from .utils import safe_eval
|
|
@@ -810,6 +811,7 @@ class NormalizedType(typing.NamedTuple):
|
|
| 810 |
return f"{self.origin}[{self.args}])"
|
| 811 |
|
| 812 |
|
|
|
|
| 813 |
def _normalize_args(tps: TypeArgs):
|
| 814 |
if isinstance(tps, str):
|
| 815 |
return tps
|
|
@@ -918,6 +920,7 @@ def _is_origin_subtype_args(
|
|
| 918 |
return _is_normal_subtype(left, right, forward_refs)
|
| 919 |
|
| 920 |
|
|
|
|
| 921 |
def _is_normal_subtype(
|
| 922 |
left: NormalizedType,
|
| 923 |
right: NormalizedType,
|
|
|
|
| 4 |
import itertools
|
| 5 |
import re
|
| 6 |
import typing
|
| 7 |
+
from functools import lru_cache
|
| 8 |
from typing import Any, Dict, List, Literal, Optional, Tuple, TypedDict, Union
|
| 9 |
|
| 10 |
from .utils import safe_eval
|
|
|
|
| 811 |
return f"{self.origin}[{self.args}])"
|
| 812 |
|
| 813 |
|
| 814 |
+
@lru_cache(maxsize=None)
|
| 815 |
def _normalize_args(tps: TypeArgs):
|
| 816 |
if isinstance(tps, str):
|
| 817 |
return tps
|
|
|
|
| 920 |
return _is_normal_subtype(left, right, forward_refs)
|
| 921 |
|
| 922 |
|
| 923 |
+
@lru_cache(maxsize=None)
|
| 924 |
def _is_normal_subtype(
|
| 925 |
left: NormalizedType,
|
| 926 |
right: NormalizedType,
|
utils.py
CHANGED
|
@@ -148,5 +148,88 @@ def import_module_from_file(file_path):
|
|
| 148 |
return module
|
| 149 |
|
| 150 |
|
| 151 |
-
def
|
|
|
|
|
|
|
| 152 |
return copy.deepcopy(obj)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 148 |
return module
|
| 149 |
|
| 150 |
|
| 151 |
+
def deep_copy(obj):
|
| 152 |
+
"""Creates a deep copy of the given object.
|
| 153 |
+
|
| 154 |
+
Args:
|
| 155 |
+
obj: The object to be deep copied.
|
| 156 |
+
|
| 157 |
+
Returns:
|
| 158 |
+
A deep copy of the original object.
|
| 159 |
+
"""
|
| 160 |
return copy.deepcopy(obj)
|
| 161 |
+
|
| 162 |
+
|
| 163 |
+
def shallow_copy(obj):
|
| 164 |
+
"""Creates a shallow copy of the given object.
|
| 165 |
+
|
| 166 |
+
Args:
|
| 167 |
+
obj: The object to be shallow copied.
|
| 168 |
+
|
| 169 |
+
Returns:
|
| 170 |
+
A shallow copy of the original object.
|
| 171 |
+
"""
|
| 172 |
+
return copy.copy(obj)
|
| 173 |
+
|
| 174 |
+
|
| 175 |
+
def recursive_copy(obj, internal_copy=None):
|
| 176 |
+
"""Recursively copies an object with a selective copy method.
|
| 177 |
+
|
| 178 |
+
For `list`, `dict`, and `tuple` types, it recursively copies their contents.
|
| 179 |
+
For other types, it uses the provided `internal_copy` function if available.
|
| 180 |
+
Objects without a `copy` method are returned as is.
|
| 181 |
+
|
| 182 |
+
Args:
|
| 183 |
+
obj: The object to be copied.
|
| 184 |
+
internal_copy (callable, optional): The copy function to use for non-container objects.
|
| 185 |
+
If `None`, objects without a `copy` method are returned as is.
|
| 186 |
+
|
| 187 |
+
Returns:
|
| 188 |
+
The recursively copied object.
|
| 189 |
+
"""
|
| 190 |
+
# Handle dictionaries
|
| 191 |
+
if isinstance(obj, dict):
|
| 192 |
+
return type(obj)(
|
| 193 |
+
{key: recursive_copy(value, internal_copy) for key, value in obj.items()}
|
| 194 |
+
)
|
| 195 |
+
|
| 196 |
+
# Handle named tuples
|
| 197 |
+
if isinstance(obj, tuple) and hasattr(obj, "_fields"):
|
| 198 |
+
return type(obj)(*(recursive_copy(item, internal_copy) for item in obj))
|
| 199 |
+
|
| 200 |
+
# Handle tuples and lists
|
| 201 |
+
if isinstance(obj, (tuple, list)):
|
| 202 |
+
return type(obj)(recursive_copy(item, internal_copy) for item in obj)
|
| 203 |
+
|
| 204 |
+
if internal_copy is None:
|
| 205 |
+
return obj
|
| 206 |
+
|
| 207 |
+
return internal_copy(obj)
|
| 208 |
+
|
| 209 |
+
|
| 210 |
+
def recursive_deep_copy(obj):
|
| 211 |
+
"""Performs a recursive deep copy of the given object.
|
| 212 |
+
|
| 213 |
+
This function uses `deep_copy` as the internal copy method for non-container objects.
|
| 214 |
+
|
| 215 |
+
Args:
|
| 216 |
+
obj: The object to be deep copied.
|
| 217 |
+
|
| 218 |
+
Returns:
|
| 219 |
+
A recursively deep-copied version of the original object.
|
| 220 |
+
"""
|
| 221 |
+
return recursive_copy(obj, deep_copy)
|
| 222 |
+
|
| 223 |
+
|
| 224 |
+
def recursive_shallow_copy(obj):
|
| 225 |
+
"""Performs a recursive shallow copy of the given object.
|
| 226 |
+
|
| 227 |
+
This function uses `shallow_copy` as the internal copy method for non-container objects.
|
| 228 |
+
|
| 229 |
+
Args:
|
| 230 |
+
obj: The object to be shallow copied.
|
| 231 |
+
|
| 232 |
+
Returns:
|
| 233 |
+
A recursively shallow-copied version of the original object.
|
| 234 |
+
"""
|
| 235 |
+
return recursive_copy(obj, shallow_copy)
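A quick comparison of the helpers above, as a self-contained sketch (the recursive_copy body is restated here, minus the namedtuple case, so the snippet runs on its own):

import copy

def recursive_copy(obj, internal_copy=None):
    if isinstance(obj, dict):
        return type(obj)({k: recursive_copy(v, internal_copy) for k, v in obj.items()})
    if isinstance(obj, (tuple, list)):
        return type(obj)(recursive_copy(item, internal_copy) for item in obj)
    return obj if internal_copy is None else internal_copy(obj)

nested = {"rows": [[1, 2], [3, 4]]}
frozen = recursive_copy(nested, copy.copy)   # recursive_shallow_copy: containers rebuilt
alias = copy.copy(nested)                    # shallow_copy: inner lists still shared

nested["rows"][0][0] = 99
print(frozen["rows"][0][0])  # 1  - edits to the original do not leak in
print(alias["rows"][0][0])   # 99 - only the outer dict was copied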
|
version.py
CHANGED
|
@@ -1 +1 @@
-version = "1.
+version = "1.14.0"