add zero shot classification task (#45)
* add zero shot classification task
* fix default metric list for zero shot classification
* Update enum
* Rename to text_zero_shot_classification
* Resolve merge conflict
* Incorporate Lewis' refactor
* Adhere to Hub naming conventions
* Incorporate AutoTrain changes for deprecated data endpoint
* SageMaker update
* SageMaker changes
Co-authored-by: mathemakitten <helen.ngo14@gmail.com>
Co-authored-by: Lewis Tunstall <lewis.c.tunstall@gmail.com>
Co-authored-by: helen <31600291+mathemakitten@users.noreply.github.com>
app.py CHANGED
@@ -43,6 +43,7 @@ TASK_TO_ID = {
     "extractive_question_answering": 5,
     "translation": 6,
     "summarization": 8,
+    "text_zero_shot_classification": 23,
 }
 
 TASK_TO_DEFAULT_METRICS = {
@@ -65,6 +66,7 @@ TASK_TO_DEFAULT_METRICS = {
         "recall",
         "accuracy",
     ],
+    "text_zero_shot_classification": ["accuracy", "loss"],
 }
 
 AUTOTRAIN_TASK_TO_LANG = {
@@ -73,6 +75,8 @@ AUTOTRAIN_TASK_TO_LANG = {
     "image_multi_class_classification": "unk",
 }
 
+AUTOTRAIN_MACHINE = {"text_zero_shot_classification": "r5.16x"}
+
 
 SUPPORTED_TASKS = list(TASK_TO_ID.keys())
 
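Taken together, these three additions register the new task end to end: an AutoTrAIN task ID, a default metric list, and a dedicated machine type. A minimal sketch of the lookups, trimmed to just the entries this diff adds (the full app defines many more tasks):

# Trimmed to the entries added in this diff
TASK_TO_ID = {"text_zero_shot_classification": 23}
TASK_TO_DEFAULT_METRICS = {"text_zero_shot_classification": ["accuracy", "loss"]}
AUTOTRAIN_MACHINE = {"text_zero_shot_classification": "r5.16x"}

task = "text_zero_shot_classification"
print(TASK_TO_ID[task])               # 23
print(TASK_TO_DEFAULT_METRICS[task])  # ['accuracy', 'loss']
print(AUTOTRAIN_MACHINE[task])        # r5.16x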
@@ -273,6 +277,45 @@ with st.expander("Advanced configuration"):
         col_mapping[text_col] = "text"
         col_mapping[target_col] = "target"
 
+    elif selected_task == "text_zero_shot_classification":
+        with col1:
+            st.markdown("`text` column")
+            st.text("")
+            st.text("")
+            st.text("")
+            st.text("")
+            st.markdown("`classes` column")
+            st.text("")
+            st.text("")
+            st.text("")
+            st.text("")
+            st.markdown("`target` column")
+        with col2:
+            text_col = st.selectbox(
+                "This column should contain the text to be classified",
+                col_names,
+                index=col_names.index(get_key(config_metadata["col_mapping"], "text"))
+                if config_metadata is not None
+                else 0,
+            )
+            classes_col = st.selectbox(
+                "This column should contain the classes associated with the text",
+                col_names,
+                index=col_names.index(get_key(config_metadata["col_mapping"], "classes"))
+                if config_metadata is not None
+                else 0,
+            )
+            target_col = st.selectbox(
+                "This column should contain the index of the correct class",
+                col_names,
+                index=col_names.index(get_key(config_metadata["col_mapping"], "target"))
+                if config_metadata is not None
+                else 0,
+            )
+        col_mapping[text_col] = "text"
+        col_mapping[classes_col] = "classes"
+        col_mapping[target_col] = "target"
+
     if selected_task in ["natural_language_inference"]:
         config_metadata = get_config_metadata(selected_config, metadata)
         with col1:
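The net effect of the new branch is a three-entry column mapping from dataset columns to the task schema. A minimal sketch with hypothetical column names (`review`, `label_options`, and `label_idx` are invented for illustration):

# Hypothetical selectbox results; in the app these come from st.selectbox above
text_col, classes_col, target_col = "review", "label_options", "label_idx"
col_mapping = {}
col_mapping[text_col] = "text"
col_mapping[classes_col] = "classes"
col_mapping[target_col] = "target"
print(col_mapping)
# {'review': 'text', 'label_options': 'classes', 'label_idx': 'target'}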
@@ -533,8 +576,10 @@ with st.form(key="form"):
                     else "en",
                     "max_models": 5,
                     "instance": {
-                        "provider": "
-                        "instance_type":
+                        "provider": "sagemaker",
+                        "instance_type": AUTOTRAIN_MACHINE[selected_task]
+                        if selected_task in AUTOTRAIN_MACHINE.keys()
+                        else "p3",
                         "max_runtime_seconds": 172800,
                         "num_instances": 1,
                         "disk_size_gb": 150,
@@ -560,17 +605,15 @@ with st.form(key="form"):
                 "split": 4,  # use "auto" split choice in AutoTrain
                 "col_mapping": col_mapping,
                 "load_config": {"max_size_bytes": 0, "shuffle": False},
+                "dataset_id": selected_dataset,
+                "dataset_config": selected_config,
+                "dataset_split": selected_split,
             }
             data_json_resp = http_post(
-                path=f"/projects/{project_json_resp['id']}/data/
+                path=f"/projects/{project_json_resp['id']}/data/dataset",
                 payload=data_payload,
                 token=HF_TOKEN,
                 domain=AUTOTRAIN_BACKEND_API,
-                params={
-                    "type": "dataset",
-                    "config_name": selected_config,
-                    "split_name": selected_split,
-                },
             ).json()
             print(f"INFO -- Dataset creation response: {data_json_resp}")
             if data_json_resp["download_status"] == 1:
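With the deprecated data endpoint removed, the dataset coordinates now travel in the request body rather than in query parameters, and the call targets the `/data/dataset` path. A sketch of the resulting payload with hypothetical dataset values:

data_payload = {
    "split": 4,  # "auto" split choice in AutoTrain
    "col_mapping": {"review": "text", "label_options": "classes", "label_idx": "target"},
    "load_config": {"max_size_bytes": 0, "shuffle": False},
    "dataset_id": "imdb",            # hypothetical
    "dataset_config": "plain_text",  # hypothetical
    "dataset_split": "test",         # hypothetical
}
# POSTed to /projects/{project_id}/data/dataset; the old params={"type": "dataset", ...} are gone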
utils.py CHANGED
@@ -19,6 +19,7 @@ AUTOTRAIN_TASK_TO_HUB_TASK = {
     "summarization": "summarization",
     "image_binary_classification": "image-classification",
     "image_multi_class_classification": "image-classification",
+    "text_zero_shot_classification": "text-generation",
 }
 
 
@@ -82,7 +83,8 @@ def get_compatible_models(task: str, dataset_ids: List[str]) -> List[str]:
     """
     compatible_models = []
     # Allow any summarization model to be used for summarization tasks
-    if task == "summarization":
+    # and allow any text-generation model to be used for text_zero_shot_classification
+    if task in ("summarization", "text_zero_shot_classification"):
         model_filter = ModelFilter(
             task=AUTOTRAIN_TASK_TO_HUB_TASK[task],
             library=["transformers", "pytorch"],
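For context, a sketch of what the widened filter matches, assuming a huggingface_hub release contemporary with this PR (`ModelFilter` has since been deprecated upstream). Zero-shot classification maps to the Hub's `text-generation` pipeline tag, so any transformers/pytorch text-generation model counts as compatible:

from huggingface_hub import HfApi, ModelFilter

# List Hub models whose pipeline tag matches the task and that ship
# transformers/pytorch weights, mirroring the filter in the hunk above
model_filter = ModelFilter(task="text-generation", library=["transformers", "pytorch"])
compatible_models = [model.modelId for model in HfApi().list_models(filter=model_filter)]
print(compatible_models[:5])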
@@ -195,9 +197,11 @@ def create_autotrain_project_name(dataset_id: str, dataset_config: str) -> str:
     """Creates an AutoTrain project name for the given dataset ID."""
     # Project names cannot have "/", so we need to format community datasets accordingly
     dataset_id_formatted = dataset_id.replace("/", "__")
-
-
-
+    dataset_config_formatted = dataset_config.replace("--", "__")
+    # Project names need to be unique, so we append a random string to guarantee this while adhering to naming rules
+    basename = f"eval-{dataset_id_formatted}-{dataset_config_formatted}"
+    basename = basename[:60] if len(basename) > 60 else basename  # Hub naming limitation
+    return f"{basename}-{str(uuid.uuid4())[:6]}"
 
 
 def get_config_metadata(config: str, metadata: List[Dict] = None) -> Union[Dict, None]:
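End to end, the naming scheme now reads: reformat the dataset ID and config, truncate to the Hub's 60-character limit, and append a random suffix for uniqueness. The post-diff function as a runnable snippet (the example arguments are hypothetical):

import uuid

def create_autotrain_project_name(dataset_id: str, dataset_config: str) -> str:
    # Project names cannot contain "/", so community datasets are reformatted
    dataset_id_formatted = dataset_id.replace("/", "__")
    dataset_config_formatted = dataset_config.replace("--", "__")
    basename = f"eval-{dataset_id_formatted}-{dataset_config_formatted}"
    basename = basename[:60] if len(basename) > 60 else basename  # Hub naming limitation
    return f"{basename}-{str(uuid.uuid4())[:6]}"

print(create_autotrain_project_name("facebook/anli", "plain_text"))
# e.g. eval-facebook__anli-plain_text-4c1d9a (suffix varies per call)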