Spaces:
Runtime error
Runtime error
| from __future__ import annotations | |
| import csv | |
| import datetime | |
| import io | |
| import json | |
| import os | |
| import uuid | |
| from abc import ABC, abstractmethod | |
| from typing import TYPE_CHECKING, Any, List, Optional | |
| import gradio as gr | |
| from gradio import encryptor, utils | |
| from gradio.documentation import document, set_documentation_group | |
| if TYPE_CHECKING: | |
| from gradio.components import IOComponent | |
| set_documentation_group("flagging") | |
| def _get_dataset_features_info(is_new, components): | |
| """ | |
| Takes in a list of components and returns a dataset features info | |
| Parameters: | |
| is_new: boolean, whether the dataset is new or not | |
| components: list of components | |
| Returns: | |
| infos: a dictionary of the dataset features | |
| file_preview_types: dictionary mapping of gradio components to appropriate string. | |
| header: list of header strings | |
| """ | |
| infos = {"flagged": {"features": {}}} | |
| # File previews for certain input and output types | |
| file_preview_types = {gr.Audio: "Audio", gr.Image: "Image"} | |
| headers = [] | |
| # Generate the headers and dataset_infos | |
| if is_new: | |
| for component in components: | |
| headers.append(component.label) | |
| infos["flagged"]["features"][component.label] = { | |
| "dtype": "string", | |
| "_type": "Value", | |
| } | |
| if isinstance(component, tuple(file_preview_types)): | |
| headers.append(component.label + " file") | |
| for _component, _type in file_preview_types.items(): | |
| if isinstance(component, _component): | |
| infos["flagged"]["features"][component.label + " file"] = { | |
| "_type": _type | |
| } | |
| break | |
| headers.append("flag") | |
| infos["flagged"]["features"]["flag"] = { | |
| "dtype": "string", | |
| "_type": "Value", | |
| } | |
| return infos, file_preview_types, headers | |
| class FlaggingCallback(ABC): | |
| """ | |
| An abstract class for defining the methods that any FlaggingCallback should have. | |
| """ | |
| def setup(self, components: List[IOComponent], flagging_dir: str): | |
| """ | |
| This method should be overridden and ensure that everything is set up correctly for flag(). | |
| This method gets called once at the beginning of the Interface.launch() method. | |
| Parameters: | |
| components: Set of components that will provide flagged data. | |
| flagging_dir: A string, typically containing the path to the directory where the flagging file should be storied (provided as an argument to Interface.__init__()). | |
| """ | |
| pass | |
| def flag( | |
| self, | |
| flag_data: List[Any], | |
| flag_option: Optional[str] = None, | |
| flag_index: Optional[int] = None, | |
| username: Optional[str] = None, | |
| ) -> int: | |
| """ | |
| This method should be overridden by the FlaggingCallback subclass and may contain optional additional arguments. | |
| This gets called every time the <flag> button is pressed. | |
| Parameters: | |
| interface: The Interface object that is being used to launch the flagging interface. | |
| flag_data: The data to be flagged. | |
| flag_option (optional): In the case that flagging_options are provided, the flag option that is being used. | |
| flag_index (optional): The index of the sample that is being flagged. | |
| username (optional): The username of the user that is flagging the data, if logged in. | |
| Returns: | |
| (int) The total number of samples that have been flagged. | |
| """ | |
| pass | |
| class SimpleCSVLogger(FlaggingCallback): | |
| """ | |
| A simplified implementation of the FlaggingCallback abstract class | |
| provided for illustrative purposes. Each flagged sample (both the input and output data) | |
| is logged to a CSV file on the machine running the gradio app. | |
| Example: | |
| import gradio as gr | |
| def image_classifier(inp): | |
| return {'cat': 0.3, 'dog': 0.7} | |
| demo = gr.Interface(fn=image_classifier, inputs="image", outputs="label", | |
| flagging_callback=SimpleCSVLogger()) | |
| """ | |
| def __init__(self): | |
| pass | |
| def setup(self, components: List[IOComponent], flagging_dir: str): | |
| self.components = components | |
| self.flagging_dir = flagging_dir | |
| os.makedirs(flagging_dir, exist_ok=True) | |
| def flag( | |
| self, | |
| flag_data: List[Any], | |
| flag_option: Optional[str] = None, | |
| flag_index: Optional[int] = None, | |
| username: Optional[str] = None, | |
| ) -> int: | |
| flagging_dir = self.flagging_dir | |
| log_filepath = os.path.join(flagging_dir, "log.csv") | |
| csv_data = [] | |
| for component, sample in zip(self.components, flag_data): | |
| save_dir = os.path.join( | |
| flagging_dir, utils.strip_invalid_filename_characters(component.label) | |
| ) | |
| csv_data.append( | |
| component.deserialize( | |
| sample, | |
| save_dir, | |
| None, | |
| ) | |
| ) | |
| with open(log_filepath, "a", newline="") as csvfile: | |
| writer = csv.writer(csvfile) | |
| writer.writerow(utils.sanitize_list_for_csv(csv_data)) | |
| with open(log_filepath, "r") as csvfile: | |
| line_count = len([None for row in csv.reader(csvfile)]) - 1 | |
| return line_count | |
| class CSVLogger(FlaggingCallback): | |
| """ | |
| The default implementation of the FlaggingCallback abstract class. Each flagged | |
| sample (both the input and output data) is logged to a CSV file with headers on the machine running the gradio app. | |
| Example: | |
| import gradio as gr | |
| def image_classifier(inp): | |
| return {'cat': 0.3, 'dog': 0.7} | |
| demo = gr.Interface(fn=image_classifier, inputs="image", outputs="label", | |
| flagging_callback=CSVLogger()) | |
| Guides: using_flagging | |
| """ | |
| def __init__(self): | |
| pass | |
| def setup( | |
| self, | |
| components: List[IOComponent], | |
| flagging_dir: str, | |
| encryption_key: Optional[str] = None, | |
| ): | |
| self.components = components | |
| self.flagging_dir = flagging_dir | |
| self.encryption_key = encryption_key | |
| os.makedirs(flagging_dir, exist_ok=True) | |
| def flag( | |
| self, | |
| flag_data: List[Any], | |
| flag_option: Optional[str] = None, | |
| flag_index: Optional[int] = None, | |
| username: Optional[str] = None, | |
| ) -> int: | |
| flagging_dir = self.flagging_dir | |
| log_filepath = os.path.join(flagging_dir, "log.csv") | |
| is_new = not os.path.exists(log_filepath) | |
| if flag_index is None: | |
| csv_data = [] | |
| for idx, (component, sample) in enumerate(zip(self.components, flag_data)): | |
| save_dir = os.path.join( | |
| flagging_dir, | |
| utils.strip_invalid_filename_characters( | |
| component.label or f"component {idx}" | |
| ), | |
| ) | |
| if utils.is_update(sample): | |
| csv_data.append(str(sample)) | |
| else: | |
| csv_data.append( | |
| component.deserialize( | |
| sample, | |
| save_dir=save_dir, | |
| encryption_key=self.encryption_key, | |
| ) | |
| if sample is not None | |
| else "" | |
| ) | |
| csv_data.append(flag_option if flag_option is not None else "") | |
| csv_data.append(username if username is not None else "") | |
| csv_data.append(str(datetime.datetime.now())) | |
| if is_new: | |
| headers = [ | |
| component.label or f"component {idx}" | |
| for idx, component in enumerate(self.components) | |
| ] + [ | |
| "flag", | |
| "username", | |
| "timestamp", | |
| ] | |
| def replace_flag_at_index(file_content): | |
| file_content = io.StringIO(file_content) | |
| content = list(csv.reader(file_content)) | |
| header = content[0] | |
| flag_col_index = header.index("flag") | |
| content[flag_index][flag_col_index] = flag_option | |
| output = io.StringIO() | |
| writer = csv.writer(output) | |
| writer.writerows(utils.sanitize_list_for_csv(content)) | |
| return output.getvalue() | |
| if self.encryption_key: | |
| output = io.StringIO() | |
| if not is_new: | |
| with open(log_filepath, "rb", encoding="utf-8") as csvfile: | |
| encrypted_csv = csvfile.read() | |
| decrypted_csv = encryptor.decrypt( | |
| self.encryption_key, encrypted_csv | |
| ) | |
| file_content = decrypted_csv.decode() | |
| if flag_index is not None: | |
| file_content = replace_flag_at_index(file_content) | |
| output.write(file_content) | |
| writer = csv.writer(output) | |
| if flag_index is None: | |
| if is_new: | |
| writer.writerow(utils.sanitize_list_for_csv(headers)) | |
| writer.writerow(utils.sanitize_list_for_csv(csv_data)) | |
| with open(log_filepath, "wb", encoding="utf-8") as csvfile: | |
| csvfile.write( | |
| encryptor.encrypt(self.encryption_key, output.getvalue().encode()) | |
| ) | |
| else: | |
| if flag_index is None: | |
| with open(log_filepath, "a", newline="", encoding="utf-8") as csvfile: | |
| writer = csv.writer(csvfile) | |
| if is_new: | |
| writer.writerow(utils.sanitize_list_for_csv(headers)) | |
| writer.writerow(utils.sanitize_list_for_csv(csv_data)) | |
| else: | |
| with open(log_filepath, encoding="utf-8") as csvfile: | |
| file_content = csvfile.read() | |
| file_content = replace_flag_at_index(file_content) | |
| with open( | |
| log_filepath, "w", newline="", encoding="utf-8" | |
| ) as csvfile: # newline parameter needed for Windows | |
| csvfile.write(utils.sanitize_list_for_csv(file_content)) | |
| with open(log_filepath, "r", encoding="utf-8") as csvfile: | |
| line_count = len([None for row in csv.reader(csvfile)]) - 1 | |
| return line_count | |
| class HuggingFaceDatasetSaver(FlaggingCallback): | |
| """ | |
| A callback that saves each flagged sample (both the input and output data) | |
| to a HuggingFace dataset. | |
| Example: | |
| import gradio as gr | |
| hf_writer = gr.HuggingFaceDatasetSaver(HF_API_TOKEN, "image-classification-mistakes") | |
| def image_classifier(inp): | |
| return {'cat': 0.3, 'dog': 0.7} | |
| demo = gr.Interface(fn=image_classifier, inputs="image", outputs="label", | |
| allow_flagging="manual", flagging_callback=hf_writer) | |
| Guides: using_flagging | |
| """ | |
| def __init__( | |
| self, | |
| hf_token: str, | |
| dataset_name: str, | |
| organization: Optional[str] = None, | |
| private: bool = False, | |
| ): | |
| """ | |
| Parameters: | |
| hf_token: The HuggingFace token to use to create (and write the flagged sample to) the HuggingFace dataset. | |
| dataset_name: The name of the dataset to save the data to, e.g. "image-classifier-1" | |
| organization: The organization to save the dataset under. The hf_token must provide write access to this organization. If not provided, saved under the name of the user corresponding to the hf_token. | |
| private: Whether the dataset should be private (defaults to False). | |
| """ | |
| self.hf_token = hf_token | |
| self.dataset_name = dataset_name | |
| self.organization_name = organization | |
| self.dataset_private = private | |
| def setup(self, components: List[IOComponent], flagging_dir: str): | |
| """ | |
| Params: | |
| flagging_dir (str): local directory where the dataset is cloned, | |
| updated, and pushed from. | |
| """ | |
| try: | |
| import huggingface_hub | |
| except (ImportError, ModuleNotFoundError): | |
| raise ImportError( | |
| "Package `huggingface_hub` not found is needed " | |
| "for HuggingFaceDatasetSaver. Try 'pip install huggingface_hub'." | |
| ) | |
| path_to_dataset_repo = huggingface_hub.create_repo( | |
| # name=self.dataset_name, | |
| repo_id=self.dataset_name, | |
| token=self.hf_token, | |
| private=self.dataset_private, | |
| repo_type="dataset", | |
| exist_ok=True, | |
| ) | |
| self.path_to_dataset_repo = path_to_dataset_repo # e.g. "https://huggingface.co/datasets/abidlabs/test-audio-10" | |
| self.components = components | |
| self.flagging_dir = flagging_dir | |
| self.dataset_dir = os.path.join(flagging_dir, self.dataset_name) | |
| self.repo = huggingface_hub.Repository( | |
| local_dir=self.dataset_dir, | |
| clone_from=path_to_dataset_repo, | |
| use_auth_token=self.hf_token, | |
| ) | |
| self.repo.git_pull(lfs=True) | |
| # Should filename be user-specified? | |
| self.log_file = os.path.join(self.dataset_dir, "data.csv") | |
| self.infos_file = os.path.join(self.dataset_dir, "dataset_infos.json") | |
| def flag( | |
| self, | |
| flag_data: List[Any], | |
| flag_option: Optional[str] = None, | |
| flag_index: Optional[int] = None, | |
| username: Optional[str] = None, | |
| ) -> int: | |
| self.repo.git_pull(lfs=True) | |
| is_new = not os.path.exists(self.log_file) | |
| with open(self.log_file, "a", newline="", encoding="utf-8") as csvfile: | |
| writer = csv.writer(csvfile) | |
| # File previews for certain input and output types | |
| infos, file_preview_types, headers = _get_dataset_features_info( | |
| is_new, self.components | |
| ) | |
| # Generate the headers and dataset_infos | |
| if is_new: | |
| writer.writerow(utils.sanitize_list_for_csv(headers)) | |
| # Generate the row corresponding to the flagged sample | |
| csv_data = [] | |
| for component, sample in zip(self.components, flag_data): | |
| save_dir = os.path.join( | |
| self.dataset_dir, | |
| utils.strip_invalid_filename_characters(component.label), | |
| ) | |
| # filepath = component.deserialize(sample, save_dir, None) | |
| if sample is not None and str(component)!='image': | |
| filepath = component.deserialize(sample, save_dir, None) | |
| else: | |
| filepath = component.deserialize(sample, None, None) # not saving image | |
| csv_data.append(filepath) | |
| if isinstance(component, tuple(file_preview_types)): | |
| csv_data.append( | |
| "{}/resolve/main/{}".format(self.path_to_dataset_repo, filepath) | |
| ) | |
| csv_data.append(flag_option if flag_option is not None else "") | |
| writer.writerow(utils.sanitize_list_for_csv(csv_data)) | |
| if is_new: | |
| json.dump(infos, open(self.infos_file, "w")) | |
| with open(self.log_file, "r", encoding="utf-8") as csvfile: | |
| line_count = len([None for row in csv.reader(csvfile)]) - 1 | |
| self.repo.push_to_hub(commit_message="Flagged sample #{}".format(line_count)) | |
| return line_count | |
| class HuggingFaceDatasetJSONSaver(FlaggingCallback): | |
| """ | |
| A FlaggingCallback that saves flagged data to a Hugging Face dataset in JSONL format. | |
| Each data sample is saved in a different JSONL file, | |
| allowing multiple users to use flagging simultaneously. | |
| Saving to a single CSV would cause errors as only one user can edit at the same time. | |
| """ | |
| def __init__( | |
| self, | |
| hf_foken: str, | |
| dataset_name: str, | |
| organization: Optional[str] = None, | |
| private: bool = False, | |
| verbose: bool = True, | |
| ): | |
| """ | |
| Params: | |
| hf_token (str): The token to use to access the huggingface API. | |
| dataset_name (str): The name of the dataset to save the data to, e.g. | |
| "image-classifier-1" | |
| organization (str): The name of the organization to which to attach | |
| the datasets. If None, the dataset attaches to the user only. | |
| private (bool): If the dataset does not already exist, whether it | |
| should be created as a private dataset or public. Private datasets | |
| may require paid huggingface.co accounts | |
| verbose (bool): Whether to print out the status of the dataset | |
| creation. | |
| """ | |
| self.hf_foken = hf_foken | |
| self.dataset_name = dataset_name | |
| self.organization_name = organization | |
| self.dataset_private = private | |
| self.verbose = verbose | |
| def setup(self, components: List[IOComponent], flagging_dir: str): | |
| """ | |
| Params: | |
| components List[Component]: list of components for flagging | |
| flagging_dir (str): local directory where the dataset is cloned, | |
| updated, and pushed from. | |
| """ | |
| try: | |
| import huggingface_hub | |
| except (ImportError, ModuleNotFoundError): | |
| raise ImportError( | |
| "Package `huggingface_hub` not found is needed " | |
| "for HuggingFaceDatasetJSONSaver. Try 'pip install huggingface_hub'." | |
| ) | |
| path_to_dataset_repo = huggingface_hub.create_repo( | |
| # name=self.dataset_name, https://github.com/huggingface/huggingface_hub/blob/main/src/huggingface_hub/hf_api.py | |
| repo_id=self.dataset_name, | |
| token=self.hf_foken, | |
| private=self.dataset_private, | |
| repo_type="dataset", | |
| exist_ok=True, | |
| ) | |
| self.path_to_dataset_repo = path_to_dataset_repo # e.g. "https://huggingface.co/datasets/abidlabs/test-audio-10" | |
| self.components = components | |
| self.flagging_dir = flagging_dir | |
| self.dataset_dir = os.path.join(flagging_dir, self.dataset_name) | |
| self.repo = huggingface_hub.Repository( | |
| local_dir=self.dataset_dir, | |
| clone_from=path_to_dataset_repo, | |
| use_auth_token=self.hf_foken, | |
| ) | |
| self.repo.git_pull(lfs=True) | |
| self.infos_file = os.path.join(self.dataset_dir, "dataset_infos.json") | |
| def flag( | |
| self, | |
| flag_data: List[Any], | |
| flag_option: Optional[str] = None, | |
| flag_index: Optional[int] = None, | |
| username: Optional[str] = None, | |
| ) -> int: | |
| self.repo.git_pull(lfs=True) | |
| # Generate unique folder for the flagged sample | |
| unique_name = self.get_unique_name() # unique name for folder | |
| folder_name = os.path.join( | |
| self.dataset_dir, unique_name | |
| ) # unique folder for specific example | |
| os.makedirs(folder_name) | |
| # Now uses the existence of `dataset_infos.json` to determine if new | |
| is_new = not os.path.exists(self.infos_file) | |
| # File previews for certain input and output types | |
| infos, file_preview_types, _ = _get_dataset_features_info( | |
| is_new, self.components | |
| ) | |
| # Generate the row and header corresponding to the flagged sample | |
| csv_data = [] | |
| headers = [] | |
| for component, sample in zip(self.components, flag_data): | |
| headers.append(component.label) | |
| try: | |
| filepath = component.save_flagged( | |
| folder_name, component.label, sample, None | |
| ) | |
| except Exception: | |
| # Could not parse 'sample' (mostly) because it was None and `component.save_flagged` | |
| # does not handle None cases. | |
| # for example: Label (line 3109 of components.py raises an error if data is None) | |
| filepath = None | |
| if isinstance(component, tuple(file_preview_types)): | |
| headers.append(str(component.label) + " file") | |
| csv_data.append( | |
| "{}/resolve/main/{}/{}".format( | |
| self.path_to_dataset_repo, unique_name, filepath | |
| ) | |
| if filepath is not None | |
| else None | |
| ) | |
| csv_data.append(filepath) | |
| headers.append("flag") | |
| csv_data.append(flag_option if flag_option is not None else "") | |
| # Creates metadata dict from row data and dumps it | |
| metadata_dict = { | |
| header: _csv_data for header, _csv_data in zip(headers, csv_data) | |
| } | |
| self.dump_json(metadata_dict, os.path.join(folder_name, "metadata.jsonl")) | |
| if is_new: | |
| json.dump(infos, open(self.infos_file, "w")) | |
| self.repo.push_to_hub(commit_message="Flagged sample {}".format(unique_name)) | |
| return unique_name | |
| def get_unique_name(self): | |
| id = uuid.uuid4() | |
| return str(id) | |
| def dump_json(self, thing: dict, file_path: str) -> None: | |
| with open(file_path, "w+", encoding="utf8") as f: | |
| json.dump(thing, f) |