Merge pull request #51 from huggingface/add-nli
app.py
CHANGED
@@ -16,6 +16,7 @@ from utils import (
     create_autotrain_project_name,
     format_col_mapping,
     get_compatible_models,
+    get_config_metadata,
     get_dataset_card_url,
     get_key,
     get_metadata,
@@ -37,6 +38,7 @@ TASK_TO_ID = {
     "image_multi_class_classification": 18,
     "binary_classification": 1,
     "multi_class_classification": 2,
+    "natural_language_inference": 22,
     "entity_extraction": 4,
     "extractive_question_answering": 5,
     "translation": 6,
@@ -51,6 +53,7 @@ TASK_TO_DEFAULT_METRICS = {
         "recall",
         "accuracy",
     ],
+    "natural_language_inference": ["f1", "precision", "recall", "auc", "accuracy"],
     "entity_extraction": ["precision", "recall", "f1", "accuracy"],
     "extractive_question_answering": ["f1", "exact_match"],
     "translation": ["sacrebleu"],
@@ -72,7 +75,6 @@ AUTOTRAIN_TASK_TO_LANG = {


 SUPPORTED_TASKS = list(TASK_TO_ID.keys())
-UNSUPPORTED_TASKS = []

 # Extracted from utils.get_supported_metrics
 # Hardcoded for now due to speed / caching constraints
@@ -118,8 +120,6 @@ SUPPORTED_METRICS = [
     "jordyvl/ece",
     "lvwerra/ai4code",
     "lvwerra/amex",
-    "lvwerra/test",
-    "lvwerra/test_metric",
 ]


@@ -180,10 +180,6 @@ if metadata is None:

 with st.expander("Advanced configuration"):
     # Select task
-    # Hack to filter for unsupported tasks
-    # TODO(lewtun): remove this once we have SQuAD metrics support
-    if metadata is not None and metadata[0]["task_id"] in UNSUPPORTED_TASKS:
-        metadata = None
     selected_task = st.selectbox(
         "Select a task",
         SUPPORTED_TASKS,
@@ -201,6 +197,9 @@ with st.expander("Advanced configuration"):
     See the [docs](https://huggingface.co/docs/datasets/master/en/load_hub#configurations) for more details.
     """,
     )
+    # Some datasets have multiple metadata (one per config), so we grab the one associated with the selected config
+    config_metadata = get_config_metadata(selected_config, metadata)
+    print(f"INFO -- Config metadata: {config_metadata}")

     # Select splits
     splits_resp = http_get(
@@ -215,8 +214,8 @@ with st.expander("Advanced configuration"):
         if split["config"] == selected_config:
             split_names.append(split["split"])

-    if metadata is not None:
-        eval_split = metadata[0]["splits"].get("eval_split", None)
+    if config_metadata is not None:
+        eval_split = config_metadata["splits"].get("eval_split", None)
     else:
         eval_split = None
     selected_split = st.selectbox(
@@ -260,16 +259,62 @@ with st.expander("Advanced configuration"):
             text_col = st.selectbox(
                 "This column should contain the text to be classified",
                 col_names,
-                index=col_names.index(get_key(metadata[0]["col_mapping"], "text")) if metadata is not None else 0,
+                index=col_names.index(get_key(config_metadata["col_mapping"], "text"))
+                if config_metadata is not None
+                else 0,
             )
             target_col = st.selectbox(
                 "This column should contain the labels associated with the text",
                 col_names,
-                index=col_names.index(get_key(metadata[0]["col_mapping"], "target")) if metadata is not None else 0,
+                index=col_names.index(get_key(config_metadata["col_mapping"], "target"))
+                if config_metadata is not None
+                else 0,
             )
             col_mapping[text_col] = "text"
             col_mapping[target_col] = "target"

+    if selected_task in ["natural_language_inference"]:
+        config_metadata = get_config_metadata(selected_config, metadata)
+        with col1:
+            st.markdown("`text1` column")
+            st.text("")
+            st.text("")
+            st.text("")
+            st.text("")
+            st.text("")
+            st.markdown("`text2` column")
+            st.text("")
+            st.text("")
+            st.text("")
+            st.text("")
+            st.text("")
+            st.markdown("`target` column")
+        with col2:
+            text1_col = st.selectbox(
+                "This column should contain the first text passage to be classified",
+                col_names,
+                index=col_names.index(get_key(config_metadata["col_mapping"], "text1"))
+                if config_metadata is not None
+                else 0,
+            )
+            text2_col = st.selectbox(
+                "This column should contain the second text passage to be classified",
+                col_names,
+                index=col_names.index(get_key(config_metadata["col_mapping"], "text2"))
+                if config_metadata is not None
+                else 0,
+            )
+            target_col = st.selectbox(
+                "This column should contain the labels associated with the text",
+                col_names,
+                index=col_names.index(get_key(config_metadata["col_mapping"], "target"))
+                if config_metadata is not None
+                else 0,
+            )
+            col_mapping[text1_col] = "text1"
+            col_mapping[text2_col] = "text2"
+            col_mapping[target_col] = "target"
+
     elif selected_task == "entity_extraction":
         with col1:
             st.markdown("`tokens` column")
@@ -282,12 +327,16 @@ with st.expander("Advanced configuration"):
             tokens_col = st.selectbox(
                 "This column should contain the array of tokens to be classified",
                 col_names,
-                index=col_names.index(get_key(metadata[0]["col_mapping"], "tokens")) if metadata is not None else 0,
+                index=col_names.index(get_key(config_metadata["col_mapping"], "tokens"))
+                if config_metadata is not None
+                else 0,
             )
             tags_col = st.selectbox(
                 "This column should contain the labels associated with each part of the text",
                 col_names,
-                index=col_names.index(get_key(metadata[0]["col_mapping"], "tags")) if metadata is not None else 0,
+                index=col_names.index(get_key(config_metadata["col_mapping"], "tags"))
+                if config_metadata is not None
+                else 0,
             )
             col_mapping[tokens_col] = "tokens"
             col_mapping[tags_col] = "tags"
@@ -304,12 +353,16 @@ with st.expander("Advanced configuration"):
             text_col = st.selectbox(
                 "This column should contain the text to be translated",
                 col_names,
-                index=col_names.index(get_key(metadata[0]["col_mapping"], "source")) if metadata is not None else 0,
+                index=col_names.index(get_key(config_metadata["col_mapping"], "source"))
+                if config_metadata is not None
+                else 0,
             )
             target_col = st.selectbox(
                 "This column should contain the target translation",
                 col_names,
-                index=col_names.index(get_key(metadata[0]["col_mapping"], "target")) if metadata is not None else 0,
+                index=col_names.index(get_key(config_metadata["col_mapping"], "target"))
+                if config_metadata is not None
+                else 0,
             )
             col_mapping[text_col] = "source"
             col_mapping[target_col] = "target"
@@ -326,19 +379,23 @@ with st.expander("Advanced configuration"):
             text_col = st.selectbox(
                 "This column should contain the text to be summarized",
                 col_names,
-                index=col_names.index(get_key(metadata[0]["col_mapping"], "text")) if metadata is not None else 0,
+                index=col_names.index(get_key(config_metadata["col_mapping"], "text"))
+                if config_metadata is not None
+                else 0,
             )
             target_col = st.selectbox(
                 "This column should contain the target summary",
                 col_names,
-                index=col_names.index(get_key(metadata[0]["col_mapping"], "target")) if metadata is not None else 0,
+                index=col_names.index(get_key(config_metadata["col_mapping"], "target"))
+                if config_metadata is not None
+                else 0,
             )
             col_mapping[text_col] = "text"
             col_mapping[target_col] = "target"

     elif selected_task == "extractive_question_answering":
-        if metadata is not None:
-            col_mapping = metadata[0]["col_mapping"]
+        if config_metadata is not None:
+            col_mapping = config_metadata["col_mapping"]
         # Hub YAML parser converts periods to hyphens, so we remap them here
         col_mapping = format_col_mapping(col_mapping)
         with col1:
@@ -362,22 +419,24 @@ with st.expander("Advanced configuration"):
             context_col = st.selectbox(
                 "This column should contain the question's context",
                 col_names,
-                index=col_names.index(get_key(col_mapping, "context")) if metadata is not None else 0,
+                index=col_names.index(get_key(col_mapping, "context")) if config_metadata is not None else 0,
             )
             question_col = st.selectbox(
                 "This column should contain the question to be answered, given the context",
                 col_names,
-                index=col_names.index(get_key(col_mapping, "question")) if metadata is not None else 0,
+                index=col_names.index(get_key(col_mapping, "question")) if config_metadata is not None else 0,
             )
             answers_text_col = st.selectbox(
                 "This column should contain example answers to the question, extracted from the context",
                 col_names,
-                index=col_names.index(get_key(col_mapping, "answers.text")) if metadata is not None else 0,
+                index=col_names.index(get_key(col_mapping, "answers.text")) if config_metadata is not None else 0,
             )
             answers_start_col = st.selectbox(
                 "This column should contain the indices in the context of the first character of each `answers.text`",
                 col_names,
-                index=col_names.index(get_key(col_mapping, "answers.answer_start")) if metadata is not None else 0,
+                index=col_names.index(get_key(col_mapping, "answers.answer_start"))
+                if config_metadata is not None
+                else 0,
             )
             col_mapping[context_col] = "context"
             col_mapping[question_col] = "question"
@@ -395,12 +454,16 @@ with st.expander("Advanced configuration"):
             image_col = st.selectbox(
                 "This column should contain the images to be classified",
                 col_names,
-                index=col_names.index(get_key(metadata[0]["col_mapping"], "image")) if metadata is not None else 0,
+                index=col_names.index(get_key(config_metadata["col_mapping"], "image"))
+                if config_metadata is not None
+                else 0,
             )
             target_col = st.selectbox(
                 "This column should contain the labels associated with the images",
                 col_names,
-                index=col_names.index(get_key(metadata[0]["col_mapping"], "target")) if metadata is not None else 0,
+                index=col_names.index(get_key(config_metadata["col_mapping"], "target"))
+                if config_metadata is not None
+                else 0,
             )
             col_mapping[image_col] = "image"
             col_mapping[target_col] = "target"
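For context, a minimal sketch (not part of the diff) of the col_mapping that the new natural_language_inference branch builds, assuming a hypothetical NLI dataset whose columns are named premise, hypothesis, and label:

# Hypothetical column selections for an NLI dataset; the column names are assumptions
col_mapping = {}
text1_col, text2_col, target_col = "premise", "hypothesis", "label"
col_mapping[text1_col] = "text1"
col_mapping[text2_col] = "text2"
col_mapping[target_col] = "target"
assert col_mapping == {"premise": "text1", "hypothesis": "text2", "label": "target"}

The evaluation backend consumes this mapping to know which dataset columns play the text1, text2, and target roles.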
utils.py
CHANGED
@@ -12,6 +12,7 @@ from tqdm import tqdm
 AUTOTRAIN_TASK_TO_HUB_TASK = {
     "binary_classification": "text-classification",
     "multi_class_classification": "text-classification",
+    "natural_language_inference": "text-classification",
     "entity_extraction": "token-classification",
     "extractive_question_answering": "question-answering",
     "translation": "translation",
@@ -197,3 +198,14 @@ def create_autotrain_project_name(dataset_id: str) -> str:
     # Project names need to be unique, so we append a random string to guarantee this
     project_id = str(uuid.uuid4())[:8]
     return f"eval-project-{dataset_id_formatted}-{project_id}"
+
+
+def get_config_metadata(config: str, metadata: List[Dict] = None) -> Union[Dict, None]:
+    """Gets the dataset card metadata for the given config."""
+    if metadata is None:
+        return None
+    config_metadata = [m for m in metadata if m["config"] == config]
+    if len(config_metadata) >= 1:
+        return config_metadata[0]
+    else:
+        return None
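A quick sanity check of the new helper (the function body is copied from the hunk above; the metadata entries and config names are hypothetical examples of per-config dataset card metadata):

from typing import Dict, List, Union


def get_config_metadata(config: str, metadata: List[Dict] = None) -> Union[Dict, None]:
    """Gets the dataset card metadata for the given config."""
    if metadata is None:
        return None
    config_metadata = [m for m in metadata if m["config"] == config]
    if len(config_metadata) >= 1:
        return config_metadata[0]
    else:
        return None


# Hypothetical per-config metadata entries, as parsed from a dataset card
metadata = [
    {"config": "matched", "splits": {"eval_split": "validation_matched"}},
    {"config": "mismatched", "splits": {"eval_split": "validation_mismatched"}},
]
assert get_config_metadata("matched", metadata)["splits"]["eval_split"] == "validation_matched"
assert get_config_metadata("unknown", metadata) is None  # no entry for this config
assert get_config_metadata("matched", None) is None  # dataset card has no metadata

This is what lets app.py pick the eval_split and col_mapping that match the selected config rather than a fixed metadata entry.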
|