Spaces:
Running
Running
| import collections | |
| import logging | |
| import threading | |
| import uuid | |
| import datasets | |
| import gradio as gr | |
| import pandas as pd | |
| import leaderboard | |
| from io_utils import ( | |
| read_column_mapping, | |
| write_column_mapping, | |
| read_scanners, | |
| write_scanners, | |
| ) | |
| from run_jobs import save_job_to_pipe | |
| from text_classification import ( | |
| check_model_task, | |
| preload_hf_inference_api, | |
| get_example_prediction, | |
| get_labels_and_features_from_dataset, | |
| check_hf_token_validity, | |
| HuggingFaceInferenceAPIResponse, | |
| ) | |
| from wordings import ( | |
| EXAMPLE_MODEL_ID, | |
| CHECK_CONFIG_OR_SPLIT_RAW, | |
| CONFIRM_MAPPING_DETAILS_FAIL_RAW, | |
| MAPPING_STYLED_ERROR_WARNING, | |
| NOT_FOUND_DATASET_RAW, | |
| NOT_FOUND_MODEL_RAW, | |
| NOT_TEXT_CLASSIFICATION_MODEL_RAW, | |
| UNMATCHED_MODEL_DATASET_STYLED_ERROR, | |
| CHECK_LOG_SECTION_RAW, | |
| VALIDATED_MODEL_DATASET_STYLED, | |
| get_dataset_fetch_error_raw, | |
| ) | |
| import os | |
| from app_env import HF_WRITE_TOKEN | |
# Upper bounds on the number of label / feature mapping dropdowns the UI
# renders; unused slots are padded with hidden dropdowns.
MAX_LABELS: int = 40
MAX_FEATURES: int = 20

# Module-level dataset state, initialised to None.
# NOTE(review): not referenced by any function in this section — confirm it is
# still needed elsewhere before relying on it.
ds_dict = ds_config = None
def get_related_datasets_from_leaderboard(model_id, dataset_id_input):
    """Offer the datasets that the leaderboard already lists for ``model_id``.

    Args:
        model_id: model whose leaderboard rows are looked up.
        dataset_id_input: the dataset id currently typed/selected in the UI.

    Returns:
        A ``gr.update`` for the dataset selector. The choices are the unique
        ``dataset_id`` values of the model's leaderboard rows. When there are
        choices but the current input is not among them, the value is cleared.
    """
    leaderboard_df = leaderboard.records
    matching_rows = leaderboard_df[leaderboard_df["model_id"] == model_id]
    related_ids = list(matching_rows["dataset_id"].unique())

    if not related_ids:
        return gr.update(choices=[])
    if dataset_id_input not in related_ids:
        # The current selection is not a related dataset: reset it.
        return gr.update(choices=related_ids, value="")
    return gr.update(choices=related_ids)
# Module logger. Named after the module (``__name__``) as recommended by the
# logging HOWTO; ``__file__`` would be a filesystem path whose dots create a
# meaningless logger hierarchy (e.g. a child logger called "py").
logger = logging.getLogger(__name__)
def get_dataset_splits(dataset_id, dataset_config):
    """List the splits of ``dataset_id``/``dataset_config`` for the split selector.

    Returns:
        A ``gr.update`` that shows the selector with the split names as choices
        and the first split preselected. If the lookup fails for any reason
        (including an empty split list), the failure is logged and the
        selector is hidden.
    """
    try:
        split_names = datasets.get_dataset_split_names(
            dataset_id, dataset_config, trust_remote_code=True
        )
        first_split = split_names[0]
    except Exception as err:
        logger.warning(
            f"Check your dataset {dataset_id} and config {dataset_config}: {err}"
        )
        return gr.update(visible=False)
    return gr.update(choices=split_names, value=first_split, visible=True)
def check_dataset(dataset_id):
    """Fetch the configs and splits of a Hub dataset to populate the selectors.

    Args:
        dataset_id: Hub dataset id entered by the user (may be empty).

    Returns:
        A 3-tuple ``(config_update, split_update, "")``. On success both
        dropdowns are shown with the first config and its first split
        preselected. Otherwise both are hidden. Fetch errors are logged and
        surfaced to the user with ``gr.Warning``.
    """
    logger.info(f"Loading {dataset_id}")

    def hide_selectors():
        return (gr.update(visible=False), gr.update(visible=False), "")

    if not dataset_id:
        return hide_selectors()

    try:
        config_names = datasets.get_dataset_config_names(
            dataset_id, trust_remote_code=True
        )
        if not config_names:
            return hide_selectors()
        default_config = config_names[0]
        split_names = datasets.get_dataset_split_names(
            dataset_id, default_config, trust_remote_code=True
        )
        return (
            gr.update(choices=config_names, value=default_config, visible=True),
            gr.update(choices=split_names, value=split_names[0], visible=True),
            "",
        )
    except Exception as err:
        logger.warning(f"Check your dataset {dataset_id}: {err}")
        if "doesn't exist on the Hub or cannot be accessed" in str(err):
            gr.Warning(NOT_FOUND_DATASET_RAW)
        else:
            # Covers GSK-2770 ("forbidden", i.e. an illegal dataset name) as
            # well as unknown errors: both show the generic fetch error.
            gr.Warning(get_dataset_fetch_error_raw(err))
        return hide_selectors()
def empty_column_mapping(uid):
    """Clear the stored column mapping of session ``uid`` by writing ``None``."""
    cleared_mapping = None
    write_column_mapping(cleared_mapping, uid)
def write_column_mapping_to_config(uid, *labels):
    """Persist the current mapping-dropdown values into the column mapping of ``uid``.

    ``labels`` holds the dropdown values positionally:

    * ``labels[:MAX_LABELS]`` are written into the "labels" section, reusing
      that section's existing keys.
    * ``labels[MAX_LABELS:MAX_LABELS + MAX_FEATURES]`` are written into the
      "features" section under the "text" key.

    Does nothing when no values are given or when no mapping is currently
    stored for ``uid``.
    """
    # TODO: Substitute 'text' with more features for zero-shot
    # we are not using ds features because we only support "text" for now

    # ``*labels`` always arrives as a tuple (never None), so test for emptiness.
    if not labels:
        return
    all_mappings = read_column_mapping(uid)
    if all_mappings is None:
        # ``empty_column_mapping`` stores None. Building a partial mapping from
        # possibly stale dropdown values would be misleading, so skip.
        logger.warning(f"No column mapping stored for {uid}; skipping update")
        return

    label_values = labels[:MAX_LABELS]
    feature_values = labels[MAX_LABELS : (MAX_LABELS + MAX_FEATURES)]
    all_mappings = export_mappings(all_mappings, "labels", None, label_values)
    if feature_values:
        # export_mappings indexes values modulo their length, so an empty
        # slice must not be passed.
        all_mappings = export_mappings(
            all_mappings, "features", ["text"], feature_values
        )
    write_column_mapping(all_mappings, uid)
def export_mappings(all_mappings, key, subkeys, values):
    """Write ``values`` into ``all_mappings[key]`` under ``subkeys``.

    Args:
        all_mappings: mapping dict, updated in place. ``None`` is treated as a
            new empty dict.
        key: top-level section to update (e.g. "labels" or "features"); it is
            created if missing.
        subkeys: names to assign. If ``None``, the existing keys of
            ``all_mappings[key]`` are reused. Falsy entries are skipped.
        values: values assigned positionally to ``subkeys``. If there are fewer
            values than subkeys, they are reused cyclically.

    Returns:
        The updated mapping. If there are no subkeys or no values, the section
        is left unchanged.
    """
    if all_mappings is None:
        all_mappings = {}
    section = all_mappings.setdefault(key, {})
    if subkeys is None:
        subkeys = list(section)
    if not subkeys:
        logger.debug(f"subkeys is empty for {key}")
        return all_mappings
    if not values:
        # Indexing modulo len(values) would raise ZeroDivisionError.
        logger.debug(f"values is empty for {key}")
        return all_mappings
    for i, subkey in enumerate(subkeys):
        if subkey:
            section[subkey] = values[i % len(values)]
    return all_mappings
def list_labels_and_features_from_dataset(ds_labels, ds_features, model_labels, uid):
    """Build the label/feature mapping dropdowns and persist the default mapping.

    Args:
        ds_labels: labels found in the dataset.
        ds_features: feature columns found in the dataset.
        model_labels: labels returned by the model's prediction.
        uid: session id whose column mapping is read and rewritten.

    Returns:
        ``MAX_LABELS + MAX_FEATURES`` dropdowns, in this order:

        * one visible dropdown per (possibly truncated) dataset label, with the
          sorted model labels as choices;
        * hidden padding up to ``MAX_LABELS``;
        * one visible "text" dropdown with the dataset features as choices;
        * hidden padding up to ``MAX_FEATURES``.

        The defaults shown in the dropdowns are also written to the stored
        mapping.

    The caller's lists are not modified.
    """
    all_mappings = read_column_mapping(uid)
    if all_mappings is None:
        all_mappings = {}

    # For flattened raw datasets with no labels
    # check if there are shared labels between model and dataset
    shared_labels = set(model_labels).intersection(ds_labels)
    # Work on copies so the caller's lists are not reordered. Sort the set
    # before truncating so that the kept labels are deterministic.
    labels = sorted(shared_labels) if shared_labels else list(ds_labels)
    if len(labels) > MAX_LABELS:
        labels = labels[:MAX_LABELS]
        gr.Warning(
            f"Too many labels to display for this space. We do not support more than {MAX_LABELS} in this space. You can use cli tool at https://github.com/Giskard-AI/cicd."
        )
    # sort labels to make sure the order is consistent
    # prediction gives the order based on probability
    labels.sort()
    model_choices = sorted(model_labels)

    label_dropdowns = [
        gr.Dropdown(
            label=f"{label}",
            choices=model_choices,
            value=model_choices[i % len(model_choices)],
            interactive=True,
            visible=True,
        )
        for i, label in enumerate(labels)
    ]
    label_dropdowns += [
        gr.Dropdown(visible=False) for _ in range(MAX_LABELS - len(label_dropdowns))
    ]
    # Persist the same default pairing the dropdowns display.
    all_mappings = export_mappings(all_mappings, "labels", labels, model_choices)

    # TODO: Substitute 'text' with more features for zero-shot
    feature_dropdowns = [
        gr.Dropdown(
            label=f"{feature}",
            choices=ds_features,
            value=ds_features[0],
            interactive=True,
            visible=True,
        )
        for feature in ["text"]
    ]
    feature_dropdowns += [
        gr.Dropdown(visible=False)
        for _ in range(MAX_FEATURES - len(feature_dropdowns))
    ]
    all_mappings = export_mappings(all_mappings, "features", ["text"], ds_features)

    write_column_mapping(all_mappings, uid)
    return label_dropdowns + feature_dropdowns
def _precheck_outputs(run_btn_interactive=False, example_df=None, error_html=None):
    """Build the 7-tuple of UI updates returned by the model/dataset precheck.

    * Slot 0 sets the interactivity of the first output (presumably the
      example button).
    * Slot 1 shows ``example_df``, or is hidden when it is None.
    * Slots 2-5 are always hidden.
    * Slot 6 shows ``error_html``, or is hidden when it is None.
    """
    if example_df is not None:
        df_update = gr.update(value=example_df, visible=True)
    else:
        df_update = gr.update(visible=False)
    if error_html is not None:
        error_update = gr.update(value=error_html, visible=True)
    else:
        error_update = gr.update(visible=False)
    return (
        gr.update(interactive=run_btn_interactive),
        df_update,
        gr.update(visible=False),
        gr.update(visible=False),
        gr.update(visible=False),
        gr.update(visible=False),
        error_update,
    )


def precheck_model_ds_enable_example_btn(
    model_id, dataset_id, dataset_config, dataset_split
):
    """Check the model and dataset selection and enable the example button.

    The model must be found and be a text-classification model; otherwise an
    error is shown. When a config and split are selected, the first 5 rows of
    that split are previewed. The button is enabled only if the dataset's
    labels and features can be extracted as lists.

    Returns:
        The 7-tuple produced by ``_precheck_outputs``.
    """
    model_task = check_model_task(model_id)
    if not model_task:
        # Model might be not found
        if model_id and model_id.startswith(("http://", "https://")):
            error_msg = f"Please input your model id, such as {EXAMPLE_MODEL_ID}, instead of URL"
        else:
            error_msg = NOT_FOUND_MODEL_RAW
        return _precheck_outputs(error_html=f"<p style='color: red;'>{error_msg}</p>")

    if model_task != "text-classification":
        gr.Warning(NOT_TEXT_CLASSIFICATION_MODEL_RAW)
        # No dataset preview exists yet at this point, so the preview slot is
        # hidden. Referencing the later-assigned ``df`` here raised
        # UnboundLocalError.
        return _precheck_outputs(
            error_html=f"<p style='color: red;'>{NOT_TEXT_CLASSIFICATION_MODEL_RAW}</p>"
        )

    preload_hf_inference_api(model_id)

    if dataset_config is None or dataset_split is None or len(dataset_config) == 0:
        return _precheck_outputs()

    try:
        ds = datasets.load_dataset(dataset_id, dataset_config, trust_remote_code=True)
        df: pd.DataFrame = ds[dataset_split].to_pandas().head(5)
        ds_labels, ds_features, _ = get_labels_and_features_from_dataset(
            ds[dataset_split]
        )
    except Exception as e:
        # Config or split wrong
        logger.warning(
            f"Check your dataset {dataset_id} and config {dataset_config} on split {dataset_split}: {e}"
        )
        return _precheck_outputs()

    if not isinstance(ds_labels, list) or not isinstance(ds_features, list):
        gr.Warning(CHECK_CONFIG_OR_SPLIT_RAW)
        return _precheck_outputs(example_df=df)
    return _precheck_outputs(run_btn_interactive=True, example_df=df)
def _hidden_prediction_outputs(status_html=None, message=""):
    """Build a result of ``align_columns_and_show_prediction`` with everything hidden.

    The 6 leading slots are:

    * 0: status HTML, shown only when ``status_html`` is given;
    * 1: example input;
    * 2: example prediction;
    * 3: an accordion (closed);
    * 4: an interactivity toggle (presumably the run button; disabled);
    * 5: the ``message`` string.

    They are followed by ``MAX_LABELS + MAX_FEATURES`` hidden dropdowns.
    """
    if status_html is not None:
        status_update = gr.update(value=status_html, visible=True)
    else:
        status_update = gr.update(visible=False)
    return (
        status_update,
        gr.update(visible=False),
        gr.update(visible=False),
        gr.update(visible=False, open=False),
        gr.update(interactive=False),
        message,
        *[gr.Dropdown(visible=False) for _ in range(MAX_LABELS + MAX_FEATURES)],
    )


def align_columns_and_show_prediction(
    model_id,
    dataset_id,
    dataset_config,
    dataset_split,
    uid,
    inference_token,
):
    """Run an example prediction and show the label/feature mapping UI.

    Every return path yields the same number of outputs: 6 leading updates
    followed by ``MAX_LABELS + MAX_FEATURES`` dropdown updates.

    When the model and dataset labels do not match exactly, or the first
    dataset feature is not "text", a warning is shown and the mapping
    accordion is opened for manual alignment.
    """
    model_task = check_model_task(model_id)
    # ``!=`` also covers ``model_task is None``.
    if model_task != "text-classification":
        gr.Warning(NOT_TEXT_CLASSIFICATION_MODEL_RAW)
        return _hidden_prediction_outputs()

    hf_token = os.environ.get(HF_WRITE_TOKEN, default="")
    prediction_input, prediction_response = get_example_prediction(
        model_id, dataset_id, dataset_config, dataset_split, hf_token
    )
    if prediction_input is None or prediction_response is None:
        return _hidden_prediction_outputs()
    if isinstance(prediction_response, HuggingFaceInferenceAPIResponse):
        return _hidden_prediction_outputs(
            message=f"Hugging Face Inference API is loading your model. {prediction_response.message}"
        )

    model_labels = list(prediction_response.keys())
    ds = datasets.load_dataset(
        dataset_id, dataset_config, split=dataset_split, trust_remote_code=True
    )
    ds_labels, ds_features, _ = get_labels_and_features_from_dataset(ds)
    # when dataset does not have labels or features (an empty feature list
    # would otherwise crash on ``ds_features[0]`` below)
    if (
        not isinstance(ds_labels, list)
        or not isinstance(ds_features, list)
        or not ds_features
    ):
        gr.Warning(CHECK_CONFIG_OR_SPLIT_RAW)
        return _hidden_prediction_outputs()
    if len(ds_labels) != len(model_labels):
        return _hidden_prediction_outputs(
            status_html=UNMATCHED_MODEL_DATASET_STYLED_ERROR
        )

    column_mappings = list_labels_and_features_from_dataset(
        ds_labels,
        ds_features,
        model_labels,
        uid,
    )
    # when labels or features are not aligned
    # show manually column mapping
    needs_manual_mapping = (
        collections.Counter(model_labels) != collections.Counter(ds_labels)
        or ds_features[0] != "text"
    )
    status_html = (
        MAPPING_STYLED_ERROR_WARNING
        if needs_manual_mapping
        else VALIDATED_MODEL_DATASET_STYLED
    )
    return (
        gr.update(value=status_html, visible=True),
        gr.update(
            value=prediction_input,
            lines=min(len(prediction_input) // 225 + 1, 5),
            visible=True,
        ),
        gr.update(value=prediction_response, visible=True),
        gr.update(visible=True, open=needs_manual_mapping),
        gr.update(interactive=(inference_token != "")),
        "",
        *column_mappings,
    )
def check_column_mapping_keys_validity(all_mappings):
    """Return True when a stored column mapping exists and has a "labels" section.

    A missing mapping (``None``) is logged and reported to the user with
    ``gr.Warning``. A mapping without "labels" is only logged.
    """
    if all_mappings is None:
        logger.warning("all_mapping is None")
        gr.Warning(CONFIRM_MAPPING_DETAILS_FAIL_RAW)
        return False

    has_labels = "labels" in all_mappings
    if not has_labels:
        logger.warning(f"Label mapping is not valid, all_mappings: {all_mappings}")
    return has_labels
def enable_run_btn(
    uid, inference_token, model_id, dataset_id, dataset_config, dataset_split
):
    """Decide whether the run button should be interactive.

    The checks run in order and stop at the first failure, which is logged:

    1. an inference token is provided;
    2. model id, dataset id, config and split are all non-empty;
    3. the stored column mapping for ``uid`` is valid;
    4. the HF token is valid.

    Returns:
        ``gr.update(interactive=True)`` only if every check passes.
    """
    if inference_token == "":
        logger.warning("Inference API is not enabled")
        return gr.update(interactive=False)

    if "" in (model_id, dataset_id, dataset_config, dataset_split):
        logger.warning("Model id or dataset id is not selected")
        return gr.update(interactive=False)

    stored_mapping = read_column_mapping(uid)
    if not check_column_mapping_keys_validity(stored_mapping):
        logger.warning("Column mapping is not valid")
        return gr.update(interactive=False)

    # Checked last: presumably the most expensive check.
    token_ok = check_hf_token_validity(inference_token)
    if not token_ok:
        logger.warning("HF token is not valid")
    return gr.update(interactive=bool(token_ok))
def construct_label_and_feature_mapping(
    all_mappings, ds_labels, ds_features, label_keys=None
):
    """Build the label and feature mappings sent with an evaluation job.

    Args:
        all_mappings: stored column mapping. Its "labels" section maps dataset
            label -> model label. Its "features" section (normally present)
            maps "text" -> dataset feature column.
        ds_labels: dataset labels. Their order defines the integer ids used
            in the returned label mapping.
        ds_features: dataset features. Only used to sanity-check the size of
            the stored feature mapping.
        label_keys: presumably the dataset's label column name(s). If
            non-empty, the first one is added to the feature mapping under
            "label".

    Returns:
        ``(label_mapping, feature_mapping)``:

        * ``label_mapping`` maps ``str(index in ds_labels)`` to the mapped
          model label.
        * ``feature_mapping`` is a copy of the stored feature mapping (plus
          "label"); the input is never mutated.

    Raises:
        KeyError: if a label in ``ds_labels`` has no entry in
            ``all_mappings["labels"]``.
    """
    stored_labels = all_mappings["labels"]
    if len(stored_labels) != len(ds_labels):
        logger.warning(
            f"""Label mapping corrupted: {CONFIRM_MAPPING_DETAILS_FAIL_RAW}.
            \nall_mappings: {all_mappings}\nds_labels: {ds_labels}"""
        )

    if "features" not in all_mappings:
        logger.warning("features not in all_mappings")
        gr.Warning(CONFIRM_MAPPING_DETAILS_FAIL_RAW)
        feature_mapping = {}
    else:
        # Copy so that adding "label" below does not modify the stored mapping.
        feature_mapping = dict(all_mappings["features"])
        if len(feature_mapping) != len(ds_features):
            logger.warning(
                f"""Feature mapping corrupted: {CONFIRM_MAPPING_DETAILS_FAIL_RAW}.
                \nall_mappings: {all_mappings}\nds_features: {ds_features}"""
            )

    # align the saved labels with dataset labels order
    label_mapping = {str(i): stored_labels[label] for i, label in enumerate(ds_labels)}

    # ``label_keys`` defaults to None, so check truthiness rather than len().
    if label_keys:
        feature_mapping["label"] = label_keys[0]
    return label_mapping, feature_mapping
def show_hf_token_info(token):
    """Show the HF-token info component only when ``token`` fails validation."""
    token_is_valid = check_hf_token_validity(token)
    return gr.update(visible=not token_is_valid)
# Shared by every submission so that concurrent ``save_job_to_pipe`` calls are
# actually serialized. A lock created per call can never be contended.
_job_pipe_lock = threading.Lock()


def try_submit(m_id, d_id, config, split, inference_token, uid, verbose):
    """Queue an evaluation job for the selected model/dataset and rotate the uid.

    Returns:
        A 7-tuple of UI updates:

        * the submit button;
        * the log textbox;
        * the session uid to use next;
        * four trailing components.

        If the stored column mapping is invalid, the submit button stays
        enabled, the log box is hidden, the current uid is kept and the
        trailing components are left unchanged.
    """
    all_mappings = read_column_mapping(uid)
    if not check_column_mapping_keys_validity(all_mappings):
        # Gradio expects every return path of a handler to match its outputs,
        # so return the same 7 slots as the success path below.
        return (
            gr.update(interactive=True),
            gr.update(visible=False),
            uid,
            gr.update(),
            gr.update(),
            gr.update(),
            gr.update(),
        )

    # get ds labels and features again for alignment
    ds = datasets.load_dataset(d_id, config, split=split, trust_remote_code=True)
    ds_labels, ds_features, label_keys = get_labels_and_features_from_dataset(ds)
    label_mapping, feature_mapping = construct_label_and_feature_mapping(
        all_mappings, ds_labels, ds_features, label_keys
    )

    eval_str = f"[{m_id}]<{d_id}({config}, {split} set)>"
    job = (
        m_id,
        d_id,
        config,
        split,
        inference_token,
        uid,
        label_mapping,
        feature_mapping,
        verbose,
    )
    save_job_to_pipe(uid, job, eval_str, _job_pipe_lock)
    gr.Info("Your evaluation has been submitted")

    # Allocate a new uid for the next submission and copy the scanner
    # selection over to it.
    new_uid = uuid.uuid4()
    scanners = read_scanners(uid)
    write_scanners(scanners, new_uid)
    return (
        gr.update(interactive=False),  # Submit button
        gr.update(
            value=f"{CHECK_LOG_SECTION_RAW}Your job id is: {uid}. ",
            lines=5,
            visible=True,
            interactive=False,
        ),
        new_uid,  # Allocate a new uuid
        gr.update(visible=False),
        gr.update(visible=False),
        gr.update(visible=False),
        gr.update(visible=False),
    )