Spaces:
Runtime error
Runtime error
Handle multiple configs
Browse files
app.py
CHANGED
|
@@ -16,6 +16,7 @@ from utils import (
|
|
| 16 |
create_autotrain_project_name,
|
| 17 |
format_col_mapping,
|
| 18 |
get_compatible_models,
|
|
|
|
| 19 |
get_dataset_card_url,
|
| 20 |
get_key,
|
| 21 |
get_metadata,
|
|
@@ -123,16 +124,6 @@ SUPPORTED_METRICS = [
|
|
| 123 |
]
|
| 124 |
|
| 125 |
|
| 126 |
-
def get_config_metadata(config, metadata=None):
|
| 127 |
-
if metadata is None:
|
| 128 |
-
return None
|
| 129 |
-
config_metadata = [m for m in metadata if m["config"] == config]
|
| 130 |
-
if len(config_metadata) == 1:
|
| 131 |
-
return config_metadata[0]
|
| 132 |
-
else:
|
| 133 |
-
return None
|
| 134 |
-
|
| 135 |
-
|
| 136 |
#######
|
| 137 |
# APP #
|
| 138 |
#######
|
|
@@ -190,10 +181,6 @@ if metadata is None:
|
|
| 190 |
|
| 191 |
with st.expander("Advanced configuration"):
|
| 192 |
# Select task
|
| 193 |
-
# Hack to filter for unsupported tasks
|
| 194 |
-
# TODO(lewtun): remove this once we have SQuAD metrics support
|
| 195 |
-
if metadata is not None and metadata[0]["task_id"] in UNSUPPORTED_TASKS:
|
| 196 |
-
metadata = None
|
| 197 |
selected_task = st.selectbox(
|
| 198 |
"Select a task",
|
| 199 |
SUPPORTED_TASKS,
|
|
@@ -211,6 +198,9 @@ with st.expander("Advanced configuration"):
|
|
| 211 |
See the [docs](https://huggingface.co/docs/datasets/master/en/load_hub#configurations) for more details.
|
| 212 |
""",
|
| 213 |
)
|
|
|
|
|
|
|
|
|
|
| 214 |
|
| 215 |
# Select splits
|
| 216 |
splits_resp = http_get(
|
|
@@ -225,8 +215,8 @@ with st.expander("Advanced configuration"):
|
|
| 225 |
if split["config"] == selected_config:
|
| 226 |
split_names.append(split["split"])
|
| 227 |
|
| 228 |
-
if
|
| 229 |
-
eval_split =
|
| 230 |
else:
|
| 231 |
eval_split = None
|
| 232 |
selected_split = st.selectbox(
|
|
@@ -270,12 +260,16 @@ with st.expander("Advanced configuration"):
|
|
| 270 |
text_col = st.selectbox(
|
| 271 |
"This column should contain the text to be classified",
|
| 272 |
col_names,
|
| 273 |
-
index=col_names.index(get_key(
|
|
|
|
|
|
|
| 274 |
)
|
| 275 |
target_col = st.selectbox(
|
| 276 |
"This column should contain the labels associated with the text",
|
| 277 |
col_names,
|
| 278 |
-
index=col_names.index(get_key(
|
|
|
|
|
|
|
| 279 |
)
|
| 280 |
col_mapping[text_col] = "text"
|
| 281 |
col_mapping[target_col] = "target"
|
|
@@ -289,11 +283,13 @@ with st.expander("Advanced configuration"):
|
|
| 289 |
st.text("")
|
| 290 |
st.text("")
|
| 291 |
st.text("")
|
|
|
|
| 292 |
st.markdown("`text2` column")
|
| 293 |
st.text("")
|
| 294 |
st.text("")
|
| 295 |
st.text("")
|
| 296 |
st.text("")
|
|
|
|
| 297 |
st.markdown("`target` column")
|
| 298 |
with col2:
|
| 299 |
text1_col = st.selectbox(
|
|
@@ -333,12 +329,16 @@ with st.expander("Advanced configuration"):
|
|
| 333 |
tokens_col = st.selectbox(
|
| 334 |
"This column should contain the array of tokens to be classified",
|
| 335 |
col_names,
|
| 336 |
-
index=col_names.index(get_key(
|
|
|
|
|
|
|
| 337 |
)
|
| 338 |
tags_col = st.selectbox(
|
| 339 |
"This column should contain the labels associated with each part of the text",
|
| 340 |
col_names,
|
| 341 |
-
index=col_names.index(get_key(
|
|
|
|
|
|
|
| 342 |
)
|
| 343 |
col_mapping[tokens_col] = "tokens"
|
| 344 |
col_mapping[tags_col] = "tags"
|
|
@@ -355,12 +355,16 @@ with st.expander("Advanced configuration"):
|
|
| 355 |
text_col = st.selectbox(
|
| 356 |
"This column should contain the text to be translated",
|
| 357 |
col_names,
|
| 358 |
-
index=col_names.index(get_key(
|
|
|
|
|
|
|
| 359 |
)
|
| 360 |
target_col = st.selectbox(
|
| 361 |
"This column should contain the target translation",
|
| 362 |
col_names,
|
| 363 |
-
index=col_names.index(get_key(
|
|
|
|
|
|
|
| 364 |
)
|
| 365 |
col_mapping[text_col] = "source"
|
| 366 |
col_mapping[target_col] = "target"
|
|
@@ -377,19 +381,23 @@ with st.expander("Advanced configuration"):
|
|
| 377 |
text_col = st.selectbox(
|
| 378 |
"This column should contain the text to be summarized",
|
| 379 |
col_names,
|
| 380 |
-
index=col_names.index(get_key(
|
|
|
|
|
|
|
| 381 |
)
|
| 382 |
target_col = st.selectbox(
|
| 383 |
"This column should contain the target summary",
|
| 384 |
col_names,
|
| 385 |
-
index=col_names.index(get_key(
|
|
|
|
|
|
|
| 386 |
)
|
| 387 |
col_mapping[text_col] = "text"
|
| 388 |
col_mapping[target_col] = "target"
|
| 389 |
|
| 390 |
elif selected_task == "extractive_question_answering":
|
| 391 |
-
if
|
| 392 |
-
col_mapping =
|
| 393 |
# Hub YAML parser converts periods to hyphens, so we remap them here
|
| 394 |
col_mapping = format_col_mapping(col_mapping)
|
| 395 |
with col1:
|
|
@@ -413,22 +421,24 @@ with st.expander("Advanced configuration"):
|
|
| 413 |
context_col = st.selectbox(
|
| 414 |
"This column should contain the question's context",
|
| 415 |
col_names,
|
| 416 |
-
index=col_names.index(get_key(col_mapping, "context")) if
|
| 417 |
)
|
| 418 |
question_col = st.selectbox(
|
| 419 |
"This column should contain the question to be answered, given the context",
|
| 420 |
col_names,
|
| 421 |
-
index=col_names.index(get_key(col_mapping, "question")) if
|
| 422 |
)
|
| 423 |
answers_text_col = st.selectbox(
|
| 424 |
"This column should contain example answers to the question, extracted from the context",
|
| 425 |
col_names,
|
| 426 |
-
index=col_names.index(get_key(col_mapping, "answers.text")) if
|
| 427 |
)
|
| 428 |
answers_start_col = st.selectbox(
|
| 429 |
"This column should contain the indices in the context of the first character of each `answers.text`",
|
| 430 |
col_names,
|
| 431 |
-
index=col_names.index(get_key(col_mapping, "answers.answer_start"))
|
|
|
|
|
|
|
| 432 |
)
|
| 433 |
col_mapping[context_col] = "context"
|
| 434 |
col_mapping[question_col] = "question"
|
|
@@ -446,12 +456,16 @@ with st.expander("Advanced configuration"):
|
|
| 446 |
image_col = st.selectbox(
|
| 447 |
"This column should contain the images to be classified",
|
| 448 |
col_names,
|
| 449 |
-
index=col_names.index(get_key(
|
|
|
|
|
|
|
| 450 |
)
|
| 451 |
target_col = st.selectbox(
|
| 452 |
"This column should contain the labels associated with the images",
|
| 453 |
col_names,
|
| 454 |
-
index=col_names.index(get_key(
|
|
|
|
|
|
|
| 455 |
)
|
| 456 |
col_mapping[image_col] = "image"
|
| 457 |
col_mapping[target_col] = "target"
|
|
|
|
| 16 |
create_autotrain_project_name,
|
| 17 |
format_col_mapping,
|
| 18 |
get_compatible_models,
|
| 19 |
+
get_config_metadata,
|
| 20 |
get_dataset_card_url,
|
| 21 |
get_key,
|
| 22 |
get_metadata,
|
|
|
|
| 124 |
]
|
| 125 |
|
| 126 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 127 |
#######
|
| 128 |
# APP #
|
| 129 |
#######
|
|
|
|
| 181 |
|
| 182 |
with st.expander("Advanced configuration"):
|
| 183 |
# Select task
|
|
|
|
|
|
|
|
|
|
|
|
|
| 184 |
selected_task = st.selectbox(
|
| 185 |
"Select a task",
|
| 186 |
SUPPORTED_TASKS,
|
|
|
|
| 198 |
See the [docs](https://huggingface.co/docs/datasets/master/en/load_hub#configurations) for more details.
|
| 199 |
""",
|
| 200 |
)
|
| 201 |
+
# Get metadata for config
|
| 202 |
+
config_metadata = get_config_metadata(selected_config, metadata)
|
| 203 |
+
print(f"INFO -- Config metadata: {config_metadata}")
|
| 204 |
|
| 205 |
# Select splits
|
| 206 |
splits_resp = http_get(
|
|
|
|
| 215 |
if split["config"] == selected_config:
|
| 216 |
split_names.append(split["split"])
|
| 217 |
|
| 218 |
+
if config_metadata is not None:
|
| 219 |
+
eval_split = config_metadata["splits"].get("eval_split", None)
|
| 220 |
else:
|
| 221 |
eval_split = None
|
| 222 |
selected_split = st.selectbox(
|
|
|
|
| 260 |
text_col = st.selectbox(
|
| 261 |
"This column should contain the text to be classified",
|
| 262 |
col_names,
|
| 263 |
+
index=col_names.index(get_key(config_metadata["col_mapping"], "text"))
|
| 264 |
+
if config_metadata is not None
|
| 265 |
+
else 0,
|
| 266 |
)
|
| 267 |
target_col = st.selectbox(
|
| 268 |
"This column should contain the labels associated with the text",
|
| 269 |
col_names,
|
| 270 |
+
index=col_names.index(get_key(config_metadata["col_mapping"], "target"))
|
| 271 |
+
if config_metadata is not None
|
| 272 |
+
else 0,
|
| 273 |
)
|
| 274 |
col_mapping[text_col] = "text"
|
| 275 |
col_mapping[target_col] = "target"
|
|
|
|
| 283 |
st.text("")
|
| 284 |
st.text("")
|
| 285 |
st.text("")
|
| 286 |
+
st.text("")
|
| 287 |
st.markdown("`text2` column")
|
| 288 |
st.text("")
|
| 289 |
st.text("")
|
| 290 |
st.text("")
|
| 291 |
st.text("")
|
| 292 |
+
st.text("")
|
| 293 |
st.markdown("`target` column")
|
| 294 |
with col2:
|
| 295 |
text1_col = st.selectbox(
|
|
|
|
| 329 |
tokens_col = st.selectbox(
|
| 330 |
"This column should contain the array of tokens to be classified",
|
| 331 |
col_names,
|
| 332 |
+
index=col_names.index(get_key(config_metadata["col_mapping"], "tokens"))
|
| 333 |
+
if config_metadata is not None
|
| 334 |
+
else 0,
|
| 335 |
)
|
| 336 |
tags_col = st.selectbox(
|
| 337 |
"This column should contain the labels associated with each part of the text",
|
| 338 |
col_names,
|
| 339 |
+
index=col_names.index(get_key(config_metadata["col_mapping"], "tags"))
|
| 340 |
+
if config_metadata is not None
|
| 341 |
+
else 0,
|
| 342 |
)
|
| 343 |
col_mapping[tokens_col] = "tokens"
|
| 344 |
col_mapping[tags_col] = "tags"
|
|
|
|
| 355 |
text_col = st.selectbox(
|
| 356 |
"This column should contain the text to be translated",
|
| 357 |
col_names,
|
| 358 |
+
index=col_names.index(get_key(config_metadata["col_mapping"], "source"))
|
| 359 |
+
if config_metadata is not None
|
| 360 |
+
else 0,
|
| 361 |
)
|
| 362 |
target_col = st.selectbox(
|
| 363 |
"This column should contain the target translation",
|
| 364 |
col_names,
|
| 365 |
+
index=col_names.index(get_key(config_metadata["col_mapping"], "target"))
|
| 366 |
+
if config_metadata is not None
|
| 367 |
+
else 0,
|
| 368 |
)
|
| 369 |
col_mapping[text_col] = "source"
|
| 370 |
col_mapping[target_col] = "target"
|
|
|
|
| 381 |
text_col = st.selectbox(
|
| 382 |
"This column should contain the text to be summarized",
|
| 383 |
col_names,
|
| 384 |
+
index=col_names.index(get_key(config_metadata["col_mapping"], "text"))
|
| 385 |
+
if config_metadata is not None
|
| 386 |
+
else 0,
|
| 387 |
)
|
| 388 |
target_col = st.selectbox(
|
| 389 |
"This column should contain the target summary",
|
| 390 |
col_names,
|
| 391 |
+
index=col_names.index(get_key(config_metadata["col_mapping"], "target"))
|
| 392 |
+
if config_metadata is not None
|
| 393 |
+
else 0,
|
| 394 |
)
|
| 395 |
col_mapping[text_col] = "text"
|
| 396 |
col_mapping[target_col] = "target"
|
| 397 |
|
| 398 |
elif selected_task == "extractive_question_answering":
|
| 399 |
+
if config_metadata is not None:
|
| 400 |
+
col_mapping = config_metadata["col_mapping"]
|
| 401 |
# Hub YAML parser converts periods to hyphens, so we remap them here
|
| 402 |
col_mapping = format_col_mapping(col_mapping)
|
| 403 |
with col1:
|
|
|
|
| 421 |
context_col = st.selectbox(
|
| 422 |
"This column should contain the question's context",
|
| 423 |
col_names,
|
| 424 |
+
index=col_names.index(get_key(col_mapping, "context")) if config_metadata is not None else 0,
|
| 425 |
)
|
| 426 |
question_col = st.selectbox(
|
| 427 |
"This column should contain the question to be answered, given the context",
|
| 428 |
col_names,
|
| 429 |
+
index=col_names.index(get_key(col_mapping, "question")) if config_metadata is not None else 0,
|
| 430 |
)
|
| 431 |
answers_text_col = st.selectbox(
|
| 432 |
"This column should contain example answers to the question, extracted from the context",
|
| 433 |
col_names,
|
| 434 |
+
index=col_names.index(get_key(col_mapping, "answers.text")) if config_metadata is not None else 0,
|
| 435 |
)
|
| 436 |
answers_start_col = st.selectbox(
|
| 437 |
"This column should contain the indices in the context of the first character of each `answers.text`",
|
| 438 |
col_names,
|
| 439 |
+
index=col_names.index(get_key(col_mapping, "answers.answer_start"))
|
| 440 |
+
if config_metadata is not None
|
| 441 |
+
else 0,
|
| 442 |
)
|
| 443 |
col_mapping[context_col] = "context"
|
| 444 |
col_mapping[question_col] = "question"
|
|
|
|
| 456 |
image_col = st.selectbox(
|
| 457 |
"This column should contain the images to be classified",
|
| 458 |
col_names,
|
| 459 |
+
index=col_names.index(get_key(config_metadata["col_mapping"], "image"))
|
| 460 |
+
if config_metadata is not None
|
| 461 |
+
else 0,
|
| 462 |
)
|
| 463 |
target_col = st.selectbox(
|
| 464 |
"This column should contain the labels associated with the images",
|
| 465 |
col_names,
|
| 466 |
+
index=col_names.index(get_key(config_metadata["col_mapping"], "target"))
|
| 467 |
+
if config_metadata is not None
|
| 468 |
+
else 0,
|
| 469 |
)
|
| 470 |
col_mapping[image_col] = "image"
|
| 471 |
col_mapping[target_col] = "target"
|
utils.py
CHANGED
|
@@ -198,3 +198,14 @@ def create_autotrain_project_name(dataset_id: str) -> str:
|
|
| 198 |
# Project names need to be unique, so we append a random string to guarantee this
|
| 199 |
project_id = str(uuid.uuid4())[:8]
|
| 200 |
return f"eval-project-{dataset_id_formatted}-{project_id}"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 198 |
# Project names need to be unique, so we append a random string to guarantee this
|
| 199 |
project_id = str(uuid.uuid4())[:8]
|
| 200 |
return f"eval-project-{dataset_id_formatted}-{project_id}"
|
| 201 |
+
|
| 202 |
+
|
| 203 |
+
def get_config_metadata(config: str, metadata: List[Dict] = None) -> Union[Dict, None]:
|
| 204 |
+
"""Gets the dataset card metadata for the given config."""
|
| 205 |
+
if metadata is None:
|
| 206 |
+
return None
|
| 207 |
+
config_metadata = [m for m in metadata if m["config"] == config]
|
| 208 |
+
if len(config_metadata) == 1:
|
| 209 |
+
return config_metadata[0]
|
| 210 |
+
else:
|
| 211 |
+
return None
|