Spaces:
Running
Running
| import collections | |
| import json | |
| import logging | |
| import os | |
| import threading | |
| import uuid | |
| import datasets | |
| import gradio as gr | |
| from transformers.pipelines import TextClassificationPipeline | |
| from io_utils import (get_yaml_path, read_column_mapping, save_job_to_pipe, | |
| write_column_mapping, write_log_to_user_file) | |
| from text_classification import (check_model, get_example_prediction, | |
| get_labels_and_features_from_dataset) | |
| from wordings import (CHECK_CONFIG_OR_SPLIT_RAW, | |
| CONFIRM_MAPPING_DETAILS_FAIL_RAW, | |
| MAPPING_STYLED_ERROR_WARNING, get_styled_input) | |
# Upper bounds on how many label / feature dropdowns the column-mapping UI
# renders; unused slots are filled with hidden dropdowns.
MAX_LABELS = 40
MAX_FEATURES = 20

# Names of environment variables consulted via os.environ.get() — these are
# the variable NAMES, not the settings themselves.
# Hugging Face publishing targets and credentials.
HF_REPO_ID = "HF_REPO_ID"
HF_SPACE_ID = "SPACE_ID"
HF_WRITE_TOKEN = "HF_WRITE_TOKEN"
# Giskard hub upload settings.
HF_GSK_HUB_URL = "GSK_HUB_URL"
HF_GSK_HUB_PROJECT_KEY = "GSK_HUB_PROJECT_KEY"
HF_GSK_HUB_KEY = "GSK_API_KEY"
HF_GSK_HUB_HF_TOKEN = "GSK_HF_TOKEN"
HF_GSK_HUB_UNLOCK_TOKEN = "GSK_HUB_UNLOCK_TOKEN"
def check_dataset_and_get_config(dataset_id):
    """Return a dropdown listing the configurations of ``dataset_id``.

    The first configuration is pre-selected.

    Args:
        dataset_id: Hugging Face dataset identifier.

    Returns:
        A visible ``gr.Dropdown`` of config names, or ``None`` when the config
        names cannot be fetched (e.g. the dataset does not exist) or the list
        is empty.
    """
    try:
        configs = datasets.get_dataset_config_names(dataset_id)
    except Exception as e:
        # Best-effort UI callback: the dataset may not exist (yet). Record why
        # instead of silently swallowing the failure.
        logging.warning(f"Failed to get configs of dataset {dataset_id}: {e}")
        return None
    if not configs:
        return None
    return gr.Dropdown(configs, value=configs[0], visible=True)
def check_dataset_and_get_split(dataset_id, dataset_config):
    """Return a dropdown listing the splits of ``dataset_id`` / ``dataset_config``.

    The first split is pre-selected.

    Args:
        dataset_id: Hugging Face dataset identifier.
        dataset_config: configuration name of that dataset.

    Returns:
        A visible ``gr.Dropdown`` of split names, or ``None`` when the split
        names cannot be fetched or the list is empty.
    """
    try:
        # Only the split names are needed: query them directly instead of
        # downloading and preparing every split with `load_dataset`.
        splits = datasets.get_dataset_split_names(dataset_id, dataset_config)
    except Exception as e:
        # Best-effort UI callback: the dataset/config may not exist.
        logging.warning(
            f"Failed to load dataset {dataset_id} with config {dataset_config}: {e}"
        )
        return None
    if not splits:
        return None
    return gr.Dropdown(splits, value=splits[0], visible=True)
def select_run_mode(run_inf):
    """Map the "run with Inference API" toggle to a pair of UI updates.

    Returns:
        ``(visibility update, value update)``: the first component is shown
        exactly when ``run_inf`` is truthy. The second component's value is
        set to the opposite of ``run_inf``.
    """
    use_inference_api = bool(run_inf)
    return (
        gr.update(visible=use_inference_api),
        gr.update(value=not use_inference_api),
    )
def deselect_run_inference(run_local):
    """Map the "run locally" toggle to a pair of UI updates.

    Returns:
        ``(visibility update, value update)``. Both are set to the opposite
        of ``run_local``: the first component is shown and the second is set
        to ``True`` only when not running locally.
    """
    not_local = not run_local
    return (gr.update(visible=not_local), gr.update(value=not_local))
def write_column_mapping_to_config(uid, *labels):
    """Persist the current dropdown selections into the column mapping of ``uid``.

    Args:
        uid: session identifier whose column mapping is read and rewritten.
        *labels: values of every mapping dropdown, in UI order. The first
            ``MAX_LABELS`` are the chosen model labels for the saved dataset
            labels. The next ``MAX_FEATURES`` are feature columns; only the
            first feeds the ``"text"`` feature.
    """
    # TODO: Substitute 'text' with more features for zero-shot
    # we are not using ds features because we only support "text" for now

    # `*labels` is always a tuple, so the former `labels is None` test never
    # fired. An empty tuple instead made export_mappings divide by zero.
    if not labels:
        return
    # NOTE(review): read_column_mapping may return None (other helpers guard
    # for it); start from an empty mapping in that case.
    all_mappings = read_column_mapping(uid) or {}
    label_values = labels[:MAX_LABELS]
    feature_values = labels[MAX_LABELS : MAX_LABELS + MAX_FEATURES]
    # subkeys=None: reassign the dataset labels already saved, in saved order.
    all_mappings = export_mappings(all_mappings, "labels", None, label_values)
    # Fewer than MAX_LABELS + 1 values means no feature selection was sent;
    # exporting an empty value list would divide by zero.
    if feature_values:
        all_mappings = export_mappings(
            all_mappings, "features", ["text"], feature_values
        )
    write_column_mapping(all_mappings, uid)
def export_mappings(all_mappings, key, subkeys, values):
    """Assign ``values`` to ``subkeys`` inside ``all_mappings[key]``.

    Args:
        all_mappings: mapping dict; updated in place.
        key: top-level section to update; created as an empty dict if missing.
        subkeys: entry names to assign, in order. ``None`` reuses the
            section's existing keys in their current order. Falsy subkeys are
            skipped but still consume a position.
        values: sequence of values. The i-th subkey receives
            ``values[i % len(values)]``, cycling when there are fewer values
            than subkeys.

    Returns:
        ``all_mappings`` itself, for call chaining.
    """
    section = all_mappings.setdefault(key, {})
    if subkeys is None:
        subkeys = list(section)
    if not subkeys:
        logging.debug(f"subkeys is empty for {key}")
        return all_mappings
    if not values:
        # Nothing to assign; previously `i % len(values)` raised
        # ZeroDivisionError here.
        logging.debug(f"values is empty for {key}")
        return all_mappings
    for i, subkey in enumerate(subkeys):
        if subkey:
            section[subkey] = values[i % len(values)]
    return all_mappings
def list_labels_and_features_from_dataset(ds_labels, ds_features, model_id2label, uid):
    """Build the column-mapping dropdowns and persist default mappings for ``uid``.

    Args:
        ds_labels: label values found in the dataset (not modified).
        ds_features: feature/column names found in the dataset.
        model_id2label: the model's ``id2label`` dict.
        uid: session identifier whose column mapping is read and rewritten.

    Returns:
        ``MAX_LABELS`` label dropdowns followed by ``MAX_FEATURES`` feature
        dropdowns. Slots beyond the real labels/features are hidden.
    """
    model_labels = sorted(model_id2label.values())
    # NOTE(review): read_column_mapping presumably returns None when nothing
    # was saved yet (check_column_mapping_keys_validity guards for None).
    all_mappings = read_column_mapping(uid) or {}

    # For flattened raw datasets with no labels, check if there are shared
    # labels between model and dataset; if so, only those are mapped.
    shared_labels = set(model_labels).intersection(ds_labels)
    # sorted() builds a new list, so the caller's ds_labels is not re-sorted
    # in place. Sorting before truncating makes the kept subset deterministic.
    ds_labels = sorted(shared_labels if shared_labels else ds_labels)
    if len(ds_labels) > MAX_LABELS:
        ds_labels = ds_labels[:MAX_LABELS]
        gr.Warning(f"The number of labels is truncated to length {MAX_LABELS}")

    # One default model label per dataset label, used by BOTH the dropdowns
    # and the persisted mapping. Previously the dropdowns used id2label order
    # while the saved mapping used sorted order, so the two could disagree.
    if shared_labels:
        # Every remaining dataset label is also a model label: map it to itself.
        default_labels = list(ds_labels)
    elif model_labels:
        default_labels = [
            model_labels[i % len(model_labels)] for i in range(len(ds_labels))
        ]
    else:
        # No model labels to offer (avoids a modulo by zero).
        default_labels = [None] * len(ds_labels)

    label_dropdowns = [
        gr.Dropdown(
            label=f"{label}",
            choices=model_labels,
            value=default,
            interactive=True,
            visible=True,
        )
        for label, default in zip(ds_labels, default_labels)
    ]
    label_dropdowns += [
        gr.Dropdown(visible=False) for _ in range(MAX_LABELS - len(label_dropdowns))
    ]
    all_mappings = export_mappings(all_mappings, "labels", ds_labels, default_labels)

    # TODO: Substitute 'text' with more features for zero-shot
    feature_dropdowns = [
        gr.Dropdown(
            label=f"{feature}",
            choices=ds_features,
            value=ds_features[0] if ds_features else None,
            interactive=True,
            visible=True,
        )
        for feature in ["text"]
    ]
    feature_dropdowns += [
        gr.Dropdown(visible=False) for _ in range(MAX_FEATURES - len(feature_dropdowns))
    ]
    if ds_features:
        # "text" receives ds_features[0], matching the dropdown default above.
        all_mappings = export_mappings(all_mappings, "features", ["text"], ds_features)
    write_column_mapping(all_mappings, uid)
    return label_dropdowns + feature_dropdowns
def precheck_model_ds_enable_example_btn(
    model_id, dataset_id, dataset_config, dataset_split
):
    """Return an ``interactive`` update gating the example button.

    Per the function name, the update is presumably for the example-prediction
    button. It is enabled only when ``model_id`` loads as a
    ``TextClassificationPipeline`` and the dataset split yields list-typed
    labels and features. Otherwise a warning is shown and the update disables
    the button.
    """
    pipeline = check_model(model_id)
    # isinstance() is False for None as well, so this covers a missing model.
    if not isinstance(pipeline, TextClassificationPipeline):
        gr.Warning("Please check your model.")
        return gr.update(interactive=False)

    labels, features = get_labels_and_features_from_dataset(
        dataset_id, dataset_config, dataset_split
    )
    dataset_ok = isinstance(labels, list) and isinstance(features, list)
    if not dataset_ok:
        gr.Warning(CHECK_CONFIG_OR_SPLIT_RAW)
    return gr.update(interactive=dataset_ok)
def _hidden_prediction_outputs():
    """Return the output tuple that hides every prediction and mapping component."""
    return (
        gr.update(visible=False),
        gr.update(visible=False),
        gr.update(visible=False, open=False),
        gr.update(interactive=False),
        *[gr.update(visible=False) for _ in range(MAX_LABELS + MAX_FEATURES)],
    )


def align_columns_and_show_prediction(
    model_id, dataset_id, dataset_config, dataset_split, uid
):
    """Check model/dataset label alignment and show an example prediction.

    Args:
        model_id: model identifier passed to ``check_model``.
        dataset_id, dataset_config, dataset_split: dataset selection.
        uid: session identifier whose column mapping is (re)written.

    Returns:
        A tuple of four component updates (two value displays, an accordion
        (``open``) and an ``interactive`` toggle) followed by
        ``MAX_LABELS + MAX_FEATURES`` mapping dropdowns.
        - If the model or dataset is unusable, everything is hidden.
        - If labels/features are misaligned, the manual mapping is opened with
          a warning.
        - Otherwise an example input/prediction is shown.
    """
    ppl = check_model(model_id)
    # isinstance() is False for None too, so a missing model is covered here;
    # the former separate `ppl is None` branch was unreachable.
    if not isinstance(ppl, TextClassificationPipeline):
        gr.Warning("Please check your model.")
        return _hidden_prediction_outputs()

    model_id2label = ppl.model.config.id2label
    ds_labels, ds_features = get_labels_and_features_from_dataset(
        dataset_id, dataset_config, dataset_split
    )
    # when dataset does not have labels or features
    if not isinstance(ds_labels, list) or not isinstance(ds_features, list):
        gr.Warning(CHECK_CONFIG_OR_SPLIT_RAW)
        return _hidden_prediction_outputs()

    column_mappings = list_labels_and_features_from_dataset(
        ds_labels,
        ds_features,
        model_id2label,
        uid,
    )

    # When labels or features are not aligned, show the manual column mapping.
    # Label alignment compares multisets, so ordering does not matter.
    labels_aligned = collections.Counter(
        model_id2label.values()
    ) == collections.Counter(ds_labels)
    # Guard the index: an empty feature list previously raised IndexError.
    features_aligned = bool(ds_features) and ds_features[0] == "text"
    if not (labels_aligned and features_aligned):
        return (
            gr.update(value=MAPPING_STYLED_ERROR_WARNING, visible=True),
            gr.update(visible=False),
            gr.update(visible=True, open=True),
            gr.update(interactive=True),
            *column_mappings,
        )

    prediction_input, prediction_output = get_example_prediction(
        ppl, dataset_id, dataset_config, dataset_split
    )
    return (
        gr.update(value=get_styled_input(prediction_input), visible=True),
        gr.update(value=prediction_output, visible=True),
        gr.update(visible=True, open=False),
        gr.update(interactive=True),
        *column_mappings,
    )
def check_column_mapping_keys_validity(all_mappings):
    """Check that a saved column mapping has every section needed to submit.

    Both ``"labels"`` and ``"features"`` are required: the scanner command is
    built from both. Previously only ``"labels"`` was checked.

    Args:
        all_mappings: the mapping read back for the session; may be ``None``.

    Returns:
        ``None`` when the mapping is usable. Otherwise a warning is shown and
        a pair of UI updates ``(gr.update(interactive=True),
        gr.update(visible=False))`` is returned for the caller to hand back
        (presumably re-enabling the submit button and hiding the log box).
    """
    required_keys = ("labels", "features")
    if all_mappings is None or any(key not in all_mappings for key in required_keys):
        gr.Warning(CONFIRM_MAPPING_DETAILS_FAIL_RAW)
        return (gr.update(interactive=True), gr.update(visible=False))
    return None
def construct_label_and_feature_mapping(all_mappings):
    """Turn a saved column mapping into the scanner's label/feature mappings.

    Args:
        all_mappings: saved mapping with ``"labels"`` and ``"features"`` dicts.

    Returns:
        ``(label_mapping, feature_mapping)``.
        - ``label_mapping`` maps each saved dataset label's position (as a
          string) to that label.
        - ``feature_mapping`` is ``all_mappings["features"]``.

        If ``"features"`` is missing, a warning is shown and a pair of UI
        updates is returned instead. NOTE(review): a caller that unpacks this
        as two mappings would silently receive UI updates, so validate the
        mapping first.
    """
    if "features" not in all_mappings:
        gr.Warning(CONFIRM_MAPPING_DETAILS_FAIL_RAW)
        return (gr.update(interactive=True), gr.update(visible=False))
    # NOTE(review): only the dataset-label KEYS are used. The model labels the
    # user picked (the values of all_mappings["labels"]) are not forwarded;
    # confirm this is what giskard_scanner expects for --label_mapping.
    label_mapping = {
        str(position): label for position, label in enumerate(all_mappings["labels"])
    }
    return label_mapping, all_mappings["features"]
# One lock shared by every submission, so that concurrent calls actually
# exclude each other. The previous per-call `threading.Lock()` excluded nothing.
# NOTE(review): assumes save_job_to_pipe guards the shared job pipe with the
# lock it receives; confirm in io_utils.
_JOB_PIPE_LOCK = threading.Lock()


def _optional_scanner_args(leaderboard_dataset):
    """Return the extra giskard_scanner flags enabled by environment variables."""
    args = []
    # The token to publish post
    if os.environ.get(HF_WRITE_TOKEN):
        args += ["--hf_token", os.environ.get(HF_WRITE_TOKEN)]
    # The repo to publish post
    # TODO: Replace by the model id
    discussion_repo = os.environ.get(HF_REPO_ID) or os.environ.get(HF_SPACE_ID)
    if discussion_repo:
        args += ["--discussion_repo", discussion_repo]
    # The repo to publish for ranking
    if leaderboard_dataset:
        args += ["--leaderboard_dataset", leaderboard_dataset]
    # The info to upload to Giskard hub
    for flag, env_name in (
        ("--giskard_hub_api_key", HF_GSK_HUB_KEY),
        ("--giskard_hub_url", HF_GSK_HUB_URL),
        ("--giskard_hub_project_key", HF_GSK_HUB_PROJECT_KEY),
        ("--giskard_hub_hf_token", HF_GSK_HUB_HF_TOKEN),
        ("--giskard_hub_unlock_token", HF_GSK_HUB_UNLOCK_TOKEN),
    ):
        value = os.environ.get(env_name)
        if value:
            args += [flag, value]
    return args


def try_submit(m_id, d_id, config, split, local, inference, inference_token, uid):
    """Validate the saved column mapping and queue a giskard_scanner job.

    Args:
        m_id: model identifier.
        d_id, config, split: dataset identifier, configuration and split.
        local: truthy to run the model as a local ``hf_pipeline``.
        inference: truthy to use the Hugging Face Inference API
            (requires ``inference_token``).
        inference_token: Inference API token; may be empty.
        uid: session identifier of the current job.

    Returns:
        ``(submit button update, log textbox update, session id)``.
        - After queuing a job, the button is disabled, the log box is shown
          and a fresh ``uuid.uuid4()`` is returned.
        - When submission is refused, the button stays usable, the log box is
          hidden and ``uid`` is returned unchanged.
    """
    all_mappings = read_column_mapping(uid)
    # The validation result used to be ignored, so a bad mapping crashed
    # below with TypeError/KeyError.
    failure = check_column_mapping_keys_validity(all_mappings)
    if failure is not None:
        return (*failure, uid)
    if "features" not in all_mappings:
        # construct_label_and_feature_mapping would return UI updates instead
        # of mappings in this case.
        gr.Warning(CONFIRM_MAPPING_DETAILS_FAIL_RAW)
        return (gr.update(interactive=True), gr.update(visible=False), uid)
    label_mapping, feature_mapping = construct_label_and_feature_mapping(all_mappings)

    # Same precedence as before: a usable Inference API setup wins over local.
    if inference and inference_token:
        inference_type = "hf_inference_api"
    elif local:
        inference_type = "hf_pipeline"
    else:
        # Previously `inference_type` was left unbound here (UnboundLocalError).
        gr.Warning(
            "Please provide a Hugging Face token to use the Inference API, "
            "or choose to run locally."
        )
        return (gr.update(interactive=True), gr.update(visible=False), uid)

    leaderboard_dataset = None
    if os.environ.get(HF_SPACE_ID) == "giskardai/giskard-evaluator":
        leaderboard_dataset = "giskard-bot/evaluator-leaderboard"

    # TODO: Set column mapping for some dataset such as `amazon_polarity`
    command = [
        "giskard_scanner",
        "--loader", "huggingface",
        "--model", m_id,
        "--dataset", d_id,
        "--dataset_config", config,
        "--dataset_split", split,
        "--output_format", "markdown",
        "--output_portal", "huggingface",
        "--feature_mapping", json.dumps(feature_mapping),
        "--label_mapping", json.dumps(label_mapping),
        "--scan_config", get_yaml_path(uid),
        "--inference_type", inference_type,
        # argv entries must be strings; a missing token becomes "".
        "--inference_api_token", inference_token or "",
    ]
    command += _optional_scanner_args(leaderboard_dataset)

    eval_str = f"[{m_id}]<{d_id}({config}, {split} set)>"
    logging.info(f"Start local evaluation on {eval_str}")
    save_job_to_pipe(uid, command, eval_str, _JOB_PIPE_LOCK)
    write_log_to_user_file(
        uid,
        f"Start local evaluation on {eval_str}. Please wait for your job to start...\n",
    )
    gr.Info(f"Start local evaluation on {eval_str}")
    return (
        gr.update(interactive=False),  # Submit button
        gr.update(lines=5, visible=True, interactive=False),
        uuid.uuid4(),  # Allocate a new uuid
    )