Spaces:
Runtime error
Runtime error
validate labels on click, remove debug messages
Browse files
src/synthetic_dataset_generator/apps/textcat.py
CHANGED
|
@@ -177,13 +177,6 @@ def generate_dataset(
|
|
| 177 |
distiset_results.append(record)
|
| 178 |
|
| 179 |
dataframe = pd.DataFrame(distiset_results)
|
| 180 |
-
if (
|
| 181 |
-
not labels
|
| 182 |
-
or len(set(label.lower().strip() for label in labels if label.strip())) < 2
|
| 183 |
-
):
|
| 184 |
-
raise gr.Error(
|
| 185 |
-
"Please provide at least 2 unique, non-empty labels to classify your text."
|
| 186 |
-
)
|
| 187 |
if multi_label:
|
| 188 |
dataframe["labels"] = dataframe["labels"].apply(
|
| 189 |
lambda x: list(
|
|
@@ -222,10 +215,6 @@ def push_dataset_to_hub(
|
|
| 222 |
pipeline_code: str = "",
|
| 223 |
progress=gr.Progress(),
|
| 224 |
):
|
| 225 |
-
gr.Info(
|
| 226 |
-
message=f"Dataframe columns in push dataset to hub: {dataframe.columns}",
|
| 227 |
-
duration=20,
|
| 228 |
-
)
|
| 229 |
progress(0.0, desc="Validating")
|
| 230 |
repo_id = validate_push_to_hub(org_name, repo_name)
|
| 231 |
progress(0.3, desc="Preprocessing")
|
|
@@ -284,7 +273,6 @@ def push_dataset(
|
|
| 284 |
num_rows=num_rows,
|
| 285 |
temperature=temperature,
|
| 286 |
)
|
| 287 |
-
gr.Info(message=f"Dataframe columns: {dataframe.columns}", duration=20)
|
| 288 |
push_dataset_to_hub(
|
| 289 |
dataframe,
|
| 290 |
org_name,
|
|
@@ -393,10 +381,13 @@ def push_dataset(
|
|
| 393 |
return ""
|
| 394 |
|
| 395 |
|
| 396 |
-
def validate_input_labels(labels):
|
| 397 |
-
if
|
|
|
|
|
|
|
|
|
|
| 398 |
raise gr.Error(
|
| 399 |
-
f"Please
|
| 400 |
)
|
| 401 |
return labels
|
| 402 |
|
|
@@ -569,6 +560,11 @@ with gr.Blocks() as app:
|
|
| 569 |
)
|
| 570 |
|
| 571 |
btn_apply_to_sample_dataset.click(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 572 |
fn=generate_sample_dataset,
|
| 573 |
inputs=[system_prompt, difficulty, clarity, labels, multi_label],
|
| 574 |
outputs=[dataframe],
|
|
@@ -585,6 +581,11 @@ with gr.Blocks() as app:
|
|
| 585 |
inputs=[org_name, repo_name],
|
| 586 |
outputs=[success_message],
|
| 587 |
show_progress=True,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 588 |
).success(
|
| 589 |
fn=hide_success_message,
|
| 590 |
outputs=[success_message],
|
|
|
|
| 177 |
distiset_results.append(record)
|
| 178 |
|
| 179 |
dataframe = pd.DataFrame(distiset_results)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 180 |
if multi_label:
|
| 181 |
dataframe["labels"] = dataframe["labels"].apply(
|
| 182 |
lambda x: list(
|
|
|
|
| 215 |
pipeline_code: str = "",
|
| 216 |
progress=gr.Progress(),
|
| 217 |
):
|
|
|
|
|
|
|
|
|
|
|
|
|
| 218 |
progress(0.0, desc="Validating")
|
| 219 |
repo_id = validate_push_to_hub(org_name, repo_name)
|
| 220 |
progress(0.3, desc="Preprocessing")
|
|
|
|
| 273 |
num_rows=num_rows,
|
| 274 |
temperature=temperature,
|
| 275 |
)
|
|
|
|
| 276 |
push_dataset_to_hub(
|
| 277 |
dataframe,
|
| 278 |
org_name,
|
|
|
|
| 381 |
return ""
|
| 382 |
|
| 383 |
|
| 384 |
+
def validate_input_labels(labels: List[str]) -> List[str]:
|
| 385 |
+
if (
|
| 386 |
+
not labels
|
| 387 |
+
or len(set(label.lower().strip() for label in labels if label.strip())) < 2
|
| 388 |
+
):
|
| 389 |
raise gr.Error(
|
| 390 |
+
f"Please provide at least 2 unique, non-empty labels to classify your text. You provided {len(labels) if labels else 0}."
|
| 391 |
)
|
| 392 |
return labels
|
| 393 |
|
|
|
|
| 560 |
)
|
| 561 |
|
| 562 |
btn_apply_to_sample_dataset.click(
|
| 563 |
+
fn=validate_input_labels,
|
| 564 |
+
inputs=[labels],
|
| 565 |
+
outputs=[labels],
|
| 566 |
+
show_progress=True,
|
| 567 |
+
).success(
|
| 568 |
fn=generate_sample_dataset,
|
| 569 |
inputs=[system_prompt, difficulty, clarity, labels, multi_label],
|
| 570 |
outputs=[dataframe],
|
|
|
|
| 581 |
inputs=[org_name, repo_name],
|
| 582 |
outputs=[success_message],
|
| 583 |
show_progress=True,
|
| 584 |
+
).success(
|
| 585 |
+
fn=validate_input_labels,
|
| 586 |
+
inputs=[labels],
|
| 587 |
+
outputs=[labels],
|
| 588 |
+
show_progress=True,
|
| 589 |
).success(
|
| 590 |
fn=hide_success_message,
|
| 591 |
outputs=[success_message],
|