Spaces:
Runtime error
Runtime error
Commit
·
053730f
1
Parent(s):
2be5f55
update: app
Browse files
application_pages/train_classifier.py
CHANGED
|
@@ -1,3 +1,5 @@
|
|
|
|
|
|
|
|
| 1 |
import streamlit as st
|
| 2 |
from dotenv import load_dotenv
|
| 3 |
|
|
@@ -26,7 +28,12 @@ dataset_name = st.sidebar.text_input("Dataset Name", value="")
|
|
| 26 |
st.session_state.dataset_name = dataset_name
|
| 27 |
|
| 28 |
base_model_name = st.sidebar.selectbox(
|
| 29 |
-
"Base Model",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 30 |
)
|
| 31 |
st.session_state.base_model_name = base_model_name
|
| 32 |
|
|
@@ -46,8 +53,9 @@ if st.session_state.should_start_training:
|
|
| 46 |
f"Explore your training logs on [Weights & Biases]({wandb.run.url})"
|
| 47 |
)
|
| 48 |
training_output = train_binary_classifier(
|
| 49 |
-
project_name="
|
| 50 |
-
entity_name="
|
|
|
|
| 51 |
dataset_repo=st.session_state.dataset_name,
|
| 52 |
model_name=st.session_state.base_model_name,
|
| 53 |
batch_size=st.session_state.batch_size,
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
|
| 3 |
import streamlit as st
|
| 4 |
from dotenv import load_dotenv
|
| 5 |
|
|
|
|
| 28 |
st.session_state.dataset_name = dataset_name
|
| 29 |
|
| 30 |
base_model_name = st.sidebar.selectbox(
|
| 31 |
+
"Base Model",
|
| 32 |
+
options=[
|
| 33 |
+
"distilbert/distilbert-base-uncased",
|
| 34 |
+
"FacebookAI/roberta-base",
|
| 35 |
+
"microsoft/deberta-v3-base",
|
| 36 |
+
],
|
| 37 |
)
|
| 38 |
st.session_state.base_model_name = base_model_name
|
| 39 |
|
|
|
|
| 53 |
f"Explore your training logs on [Weights & Biases]({wandb.run.url})"
|
| 54 |
)
|
| 55 |
training_output = train_binary_classifier(
|
| 56 |
+
project_name=os.getenv("WANDB_PROJECT_NAME"),
|
| 57 |
+
entity_name=os.getenv("WANDB_ENTITY_NAME"),
|
| 58 |
+
run_name=f"{st.session_state.base_model_name}-finetuned",
|
| 59 |
dataset_repo=st.session_state.dataset_name,
|
| 60 |
model_name=st.session_state.base_model_name,
|
| 61 |
batch_size=st.session_state.batch_size,
|
guardrails_genie/train_classifier.py
CHANGED
|
@@ -1,4 +1,3 @@
|
|
| 1 |
-
|
| 2 |
import evaluate
|
| 3 |
import numpy as np
|
| 4 |
import streamlit as st
|
|
@@ -39,6 +38,7 @@ class StreamlitProgressbarCallback(TrainerCallback):
|
|
| 39 |
def train_binary_classifier(
|
| 40 |
project_name: str,
|
| 41 |
entity_name: str,
|
|
|
|
| 42 |
dataset_repo: str = "geekyrakshit/prompt-injection-dataset",
|
| 43 |
model_name: str = "distilbert/distilbert-base-uncased",
|
| 44 |
learning_rate: float = 2e-5,
|
|
@@ -47,7 +47,7 @@ def train_binary_classifier(
|
|
| 47 |
weight_decay: float = 0.01,
|
| 48 |
streamlit_mode: bool = False,
|
| 49 |
):
|
| 50 |
-
wandb.init(project=project_name, entity=entity_name)
|
| 51 |
dataset = load_dataset(dataset_repo)
|
| 52 |
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
| 53 |
|
|
|
|
|
|
|
| 1 |
import evaluate
|
| 2 |
import numpy as np
|
| 3 |
import streamlit as st
|
|
|
|
| 38 |
def train_binary_classifier(
|
| 39 |
project_name: str,
|
| 40 |
entity_name: str,
|
| 41 |
+
run_name: str,
|
| 42 |
dataset_repo: str = "geekyrakshit/prompt-injection-dataset",
|
| 43 |
model_name: str = "distilbert/distilbert-base-uncased",
|
| 44 |
learning_rate: float = 2e-5,
|
|
|
|
| 47 |
weight_decay: float = 0.01,
|
| 48 |
streamlit_mode: bool = False,
|
| 49 |
):
|
| 50 |
+
wandb.init(project=project_name, entity=entity_name, name=run_name)
|
| 51 |
dataset = load_dataset(dataset_repo)
|
| 52 |
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
| 53 |
|