# trackio/ui/helpers/run_selection.py
# Uploaded by kshitijthakkar via huggingface_hub (commit b2ba7d2).
from dataclasses import dataclass, field
try:
import trackio.utils as utils
except ImportError:
import utils
@dataclass
class RunSelection:
choices: list[str] = field(default_factory=list)
selected: list[str] = field(default_factory=list)
locked: bool = False
def update_choices(
self, runs: list[str], preferred: list[str] | None = None
) -> bool:
if self.choices == runs:
return False
new_choices = set(runs) - set(self.choices)
self.choices = list(runs)
if self.locked:
base = set(self.selected) | new_choices
elif preferred:
base = set(preferred)
else:
base = set(runs)
self.selected = [run for run in self.choices if run in base]
return True
def select(self, runs: list[str]) -> list[str]:
choice_set = set(self.choices)
self.selected = [run for run in runs if run in choice_set]
self.locked = True
return self.selected
def replace_group(
self, group_runs: list[str], new_subset: list[str] | None
) -> tuple[list[str], list[str]]:
new_subset = utils.ordered_subset(group_runs, new_subset)
selection_set = set(self.selected)
selection_set.difference_update(group_runs)
selection_set.update(new_subset)
self.selected = [run for run in self.choices if run in selection_set]
self.locked = True
return new_subset, self.selected