Upload operators.py with huggingface_hub
Browse files- operators.py +64 -1
operators.py
CHANGED
|
@@ -1,3 +1,4 @@
|
|
|
|
|
| 1 |
import importlib
|
| 2 |
import inspect
|
| 3 |
import uuid
|
|
@@ -18,7 +19,7 @@ from typing import (
|
|
| 18 |
)
|
| 19 |
|
| 20 |
from .artifact import Artifact, fetch_artifact
|
| 21 |
-
from .dataclass import NonPositionalField
|
| 22 |
from .dict_utils import dict_delete, dict_get, dict_set, is_subpath
|
| 23 |
from .operator import (
|
| 24 |
MultiStream,
|
|
@@ -734,3 +735,65 @@ class EncodeLabels(StreamInstanceOperator):
|
|
| 734 |
dict_set(instance, field, new_values, use_dpath=True, set_multiple=True)
|
| 735 |
|
| 736 |
return instance
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import collections
|
| 2 |
import importlib
|
| 3 |
import inspect
|
| 4 |
import uuid
|
|
|
|
| 19 |
)
|
| 20 |
|
| 21 |
from .artifact import Artifact, fetch_artifact
|
| 22 |
+
from .dataclass import NonPositionalField, OptionalField
|
| 23 |
from .dict_utils import dict_delete, dict_get, dict_set, is_subpath
|
| 24 |
from .operator import (
|
| 25 |
MultiStream,
|
|
|
|
| 735 |
dict_set(instance, field, new_values, use_dpath=True, set_multiple=True)
|
| 736 |
|
| 737 |
return instance
|
| 738 |
+
|
| 739 |
+
|
| 740 |
+
class StreamRefiner(SingleStreamOperator):
|
| 741 |
+
max_instances: int = None
|
| 742 |
+
|
| 743 |
+
def process(self, stream: Stream, stream_name: str = None) -> Generator:
|
| 744 |
+
if self.max_instances is not None:
|
| 745 |
+
yield from stream.take(self.max_instances)
|
| 746 |
+
else:
|
| 747 |
+
yield from stream
|
| 748 |
+
|
| 749 |
+
|
| 750 |
+
class DeterministicBalancer(StreamRefiner):
|
| 751 |
+
"""
|
| 752 |
+
A class used to balance streams deterministically.
|
| 753 |
+
|
| 754 |
+
Attributes:
|
| 755 |
+
fields (List[str]): A list of field names to be used in determining the signature of an instance.
|
| 756 |
+
streams (List[str]): A list of stream names to be processed by the balancer.
|
| 757 |
+
|
| 758 |
+
Usage:
|
| 759 |
+
balancer = DeterministicBalancer(fields=["field1", "field2"], streams=["stream1", "stream2"])
|
| 760 |
+
balanced_stream = balancer.process(stream)
|
| 761 |
+
"""
|
| 762 |
+
|
| 763 |
+
fields: List[str]
|
| 764 |
+
|
| 765 |
+
def signature(self, instance):
|
| 766 |
+
return str(tuple(dict_get(instance, field, use_dpath=True) for field in self.fields))
|
| 767 |
+
|
| 768 |
+
def process(self, stream: Stream, stream_name: str = None) -> Generator:
|
| 769 |
+
counter = collections.Counter()
|
| 770 |
+
|
| 771 |
+
for instance in stream:
|
| 772 |
+
counter[self.signature(instance)] += 1
|
| 773 |
+
|
| 774 |
+
lowest_count = counter.most_common()[-1][-1]
|
| 775 |
+
|
| 776 |
+
max_total_instances_per_sign = lowest_count
|
| 777 |
+
if self.max_instances is not None:
|
| 778 |
+
max_total_instances_per_sign = min(lowest_count, self.max_instances // len(counter))
|
| 779 |
+
|
| 780 |
+
counter = collections.Counter()
|
| 781 |
+
|
| 782 |
+
for instance in stream:
|
| 783 |
+
sign = self.signature(instance)
|
| 784 |
+
if counter[sign] < max_total_instances_per_sign:
|
| 785 |
+
counter[sign] += 1
|
| 786 |
+
yield instance
|
| 787 |
+
|
| 788 |
+
|
| 789 |
+
class LengthBalancer(DeterministicBalancer):
|
| 790 |
+
segments_boundaries: List[int]
|
| 791 |
+
|
| 792 |
+
def signature(self, instance):
|
| 793 |
+
total_len = 0
|
| 794 |
+
for field in self.fields:
|
| 795 |
+
total_len += len(dict_get(instance, field, use_dpath=True))
|
| 796 |
+
for i, val in enumerate(self.segments_boundaries):
|
| 797 |
+
if total_len < val:
|
| 798 |
+
return i
|
| 799 |
+
return i + 1
|