Add NLI support
app.py CHANGED

@@ -36,6 +36,7 @@ TASK_TO_ID = {
     "image_multi_class_classification": 18,
     "binary_classification": 1,
     "multi_class_classification": 2,
+    "natural_language_inference": 22,
     "entity_extraction": 4,
     "extractive_question_answering": 5,
     "translation": 6,
@@ -50,6 +51,7 @@ TASK_TO_DEFAULT_METRICS = {
         "recall",
         "accuracy",
     ],
+    "natural_language_inference": ["f1", "precision", "recall", "auc", "accuracy"],
     "entity_extraction": ["precision", "recall", "f1", "accuracy"],
     "extractive_question_answering": ["f1", "exact_match"],
     "translation": ["sacrebleu"],
@@ -117,11 +119,19 @@ SUPPORTED_METRICS = [
     "jordyvl/ece",
     "lvwerra/ai4code",
     "lvwerra/amex",
-    "lvwerra/test",
-    "lvwerra/test_metric",
 ]


+def get_config_metadata(config, metadata=None):
+    if metadata is None:
+        return None
+    config_metadata = [m for m in metadata if m["config"] == config]
+    if len(config_metadata) == 1:
+        return config_metadata[0]
+    else:
+        return None
+
+
 #######
 # APP #
 #######
@@ -269,6 +279,47 @@ with st.expander("Advanced configuration"):
         col_mapping[text_col] = "text"
         col_mapping[target_col] = "target"

+    col_mapping = {}
+    if selected_task in ["natural_language_inference"]:
+        config_metadata = get_config_metadata(selected_config, metadata)
+        with col1:
+            st.markdown("`text1` column")
+            st.text("")
+            st.text("")
+            st.text("")
+            st.text("")
+            st.markdown("`text2` column")
+            st.text("")
+            st.text("")
+            st.text("")
+            st.text("")
+            st.markdown("`target` column")
+        with col2:
+            text1_col = st.selectbox(
+                "This column should contain the first text passage to be classified",
+                col_names,
+                index=col_names.index(get_key(config_metadata["col_mapping"], "text1"))
+                if config_metadata is not None
+                else 0,
+            )
+            text2_col = st.selectbox(
+                "This column should contain the second text passage to be classified",
+                col_names,
+                index=col_names.index(get_key(config_metadata["col_mapping"], "text2"))
+                if config_metadata is not None
+                else 0,
+            )
+            target_col = st.selectbox(
+                "This column should contain the labels associated with the text",
+                col_names,
+                index=col_names.index(get_key(config_metadata["col_mapping"], "target"))
+                if config_metadata is not None
+                else 0,
+            )
+        col_mapping[text1_col] = "text1"
+        col_mapping[text2_col] = "text2"
+        col_mapping[target_col] = "target"
+
     elif selected_task == "entity_extraction":
         with col1:
             st.markdown("`tokens` column")
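For readers following the app.py change above: `get_config_metadata` picks the metadata entry whose `config` field matches the selected dataset config, and its `col_mapping` is then used, via `get_key` (defined elsewhere in the app and not shown in this diff), to pre-select the dataset columns for `text1`, `text2` and `target`. A minimal sketch of that flow; the metadata shape and the `get_key` implementation are assumptions for illustration, not part of the commit:

# Minimal sketch (not part of the commit) of how get_config_metadata feeds the
# new NLI column selectors. Only get_config_metadata is copied from the diff;
# the metadata shape and the get_key helper are assumptions.

def get_config_metadata(config, metadata=None):
    if metadata is None:
        return None
    config_metadata = [m for m in metadata if m["config"] == config]
    if len(config_metadata) == 1:
        return config_metadata[0]
    else:
        return None


def get_key(col_mapping, val):
    # Assumed behaviour: return the dataset column mapped to the given role.
    return next((col for col, role in col_mapping.items() if role == val), None)


# Hypothetical evaluation metadata for one dataset config.
metadata = [
    {
        "config": "plain_text",
        "task": "natural_language_inference",
        "col_mapping": {"premise": "text1", "hypothesis": "text2", "label": "target"},
    }
]

config_metadata = get_config_metadata("plain_text", metadata)
col_names = ["premise", "hypothesis", "label", "idx"]

# The index passed to st.selectbox so the right column is pre-selected:
default_index = col_names.index(get_key(config_metadata["col_mapping"], "text1"))
print(default_index)  # 0 -> "premise"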
utils.py CHANGED

@@ -12,6 +12,7 @@ from tqdm import tqdm
 AUTOTRAIN_TASK_TO_HUB_TASK = {
     "binary_classification": "text-classification",
     "multi_class_classification": "text-classification",
+    "natural_language_inference": "text-classification",
     "entity_extraction": "token-classification",
     "extractive_question_answering": "question-answering",
     "translation": "translation",
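The utils.py change above maps the new AutoTrain task onto the Hub's existing `text-classification` pipeline type. As an illustrative aside (not part of the commit), such a mapping is typically what an evaluation app uses to find Hub models compatible with the selected task; a rough sketch assuming a recent `huggingface_hub` client:

# Illustrative aside (not from the diff): the Hub task a given AutoTrain task
# maps to can be used to look up compatible models on the Hub.
# Assumes a recent huggingface_hub client (list_models and ModelInfo.id).
from itertools import islice

from huggingface_hub import list_models

AUTOTRAIN_TASK_TO_HUB_TASK = {
    "natural_language_inference": "text-classification",  # added by this commit
}

hub_task = AUTOTRAIN_TASK_TO_HUB_TASK["natural_language_inference"]

# Fetch a handful of Hub models tagged with the corresponding pipeline type.
for model in islice(list_models(filter=hub_task), 5):
    print(model.id)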