Spaces:
Runtime error
Runtime error
Commit
ยท
5d3be21
1
Parent(s):
105084b
update textcat prompt based on multi_label
Browse files
src/synthetic_dataset_generator/apps/textcat.py
CHANGED
|
@@ -104,8 +104,11 @@ def generate_dataset(
|
|
| 104 |
temperature=temperature,
|
| 105 |
is_sample=is_sample,
|
| 106 |
)
|
|
|
|
|
|
|
|
|
|
| 107 |
labeller_generator = get_labeller_generator(
|
| 108 |
-
system_prompt=
|
| 109 |
labels=labels,
|
| 110 |
multi_label=multi_label,
|
| 111 |
)
|
|
@@ -181,16 +184,20 @@ def generate_dataset(
|
|
| 181 |
[
|
| 182 |
label.lower().strip()
|
| 183 |
for label in x
|
| 184 |
-
if label.lower().strip() in labels
|
| 185 |
]
|
| 186 |
)
|
| 187 |
)
|
| 188 |
)
|
|
|
|
| 189 |
else:
|
| 190 |
dataframe = dataframe.rename(columns={"labels": "label"})
|
| 191 |
dataframe["label"] = dataframe["label"].apply(
|
| 192 |
-
lambda x: x.lower().strip()
|
|
|
|
|
|
|
| 193 |
)
|
|
|
|
| 194 |
|
| 195 |
progress(1.0, desc="Dataset created")
|
| 196 |
return dataframe
|
|
|
|
| 104 |
temperature=temperature,
|
| 105 |
is_sample=is_sample,
|
| 106 |
)
|
| 107 |
+
updated_system_prompt = f"{system_prompt}. Optional labels: {', '.join(labels)}."
|
| 108 |
+
if multi_label:
|
| 109 |
+
updated_system_prompt = f"{updated_system_prompt}. Only apply relevant labels. Applying less labels is better than applying too many labels."
|
| 110 |
labeller_generator = get_labeller_generator(
|
| 111 |
+
system_prompt=updated_system_prompt,
|
| 112 |
labels=labels,
|
| 113 |
multi_label=multi_label,
|
| 114 |
)
|
|
|
|
| 184 |
[
|
| 185 |
label.lower().strip()
|
| 186 |
for label in x
|
| 187 |
+
if label is not None and label.lower().strip() in labels
|
| 188 |
]
|
| 189 |
)
|
| 190 |
)
|
| 191 |
)
|
| 192 |
+
dataframe = dataframe[dataframe["labels"].notna()]
|
| 193 |
else:
|
| 194 |
dataframe = dataframe.rename(columns={"labels": "label"})
|
| 195 |
dataframe["label"] = dataframe["label"].apply(
|
| 196 |
+
lambda x: x.lower().strip()
|
| 197 |
+
if x and x.lower().strip() in labels
|
| 198 |
+
else random.choice(labels)
|
| 199 |
)
|
| 200 |
+
dataframe = dataframe[dataframe["text"].notna()]
|
| 201 |
|
| 202 |
progress(1.0, desc="Dataset created")
|
| 203 |
return dataframe
|