Spaces:
Runtime error
Runtime error
Commit
·
65321e4
1
Parent(s):
351c0ef
update: docstring
Browse files
guardrails_genie/train_classifier.py
CHANGED
|
@@ -79,24 +79,18 @@ def train_binary_classifier(
|
|
| 79 |
entity_name (str): The Weights & Biases entity (user or team).
|
| 80 |
run_name (str): The name of the Weights & Biases run.
|
| 81 |
dataset_repo (str, optional): The Hugging Face dataset repository to load.
|
| 82 |
-
|
| 83 |
-
model_name (str, optional): The pre-trained model to use. Defaults to
|
| 84 |
-
"distilbert/distilbert-base-uncased".
|
| 85 |
prompt_column_name (str, optional): The column name in the dataset containing
|
| 86 |
-
the text prompts.
|
| 87 |
id2label (dict[int, str], optional): Mapping from label IDs to label names.
|
| 88 |
-
Defaults to {0: "SAFE", 1: "INJECTION"}.
|
| 89 |
label2id (dict[str, int], optional): Mapping from label names to label IDs.
|
| 90 |
-
|
| 91 |
-
learning_rate (float, optional): The learning rate for training. Defaults to 1e-5.
|
| 92 |
batch_size (int, optional): The batch size for training and evaluation.
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
weight_decay (float, optional): The weight decay for the optimizer. Defaults to 0.01.
|
| 96 |
save_steps (int, optional): The number of steps between model checkpoints.
|
| 97 |
-
Defaults to 1000.
|
| 98 |
streamlit_mode (bool, optional): If True, integrates with Streamlit to display
|
| 99 |
-
a progress bar.
|
| 100 |
|
| 101 |
Returns:
|
| 102 |
dict: The output of the training process, including metrics and model state.
|
|
|
|
| 79 |
entity_name (str): The Weights & Biases entity (user or team).
|
| 80 |
run_name (str): The name of the Weights & Biases run.
|
| 81 |
dataset_repo (str, optional): The Hugging Face dataset repository to load.
|
| 82 |
+
model_name (str, optional): The pre-trained model to use.
|
|
|
|
|
|
|
| 83 |
prompt_column_name (str, optional): The column name in the dataset containing
|
| 84 |
+
the text prompts.
|
| 85 |
id2label (dict[int, str], optional): Mapping from label IDs to label names.
|
|
|
|
| 86 |
label2id (dict[str, int], optional): Mapping from label names to label IDs.
|
| 87 |
+
learning_rate (float, optional): The learning rate for training.
|
|
|
|
| 88 |
batch_size (int, optional): The batch size for training and evaluation.
|
| 89 |
+
num_epochs (int, optional): The number of training epochs.
|
| 90 |
+
weight_decay (float, optional): The weight decay for the optimizer.
|
|
|
|
| 91 |
save_steps (int, optional): The number of steps between model checkpoints.
|
|
|
|
| 92 |
streamlit_mode (bool, optional): If True, integrates with Streamlit to display
|
| 93 |
+
a progress bar.
|
| 94 |
|
| 95 |
Returns:
|
| 96 |
dict: The output of the training process, including metrics and model state.
|