Commit 968f4bc
Parent(s): 0cde3e9

add: integration of training with app
Files changed:
- app.py (+8 -1)
- application_pages/train_classifier.py (+57 -0)
- guardrails_genie/train_classifier.py (+31 -2)
app.py
CHANGED
@@ -13,6 +13,13 @@ evaluation_page = st.Page(
     title="Evaluation",
     icon=":material/monitoring:",
 )
-page_navigation = st.navigation([intro_page, chat_page, evaluation_page])
+train_classifier_page = st.Page(
+    "application_pages/train_classifier.py",
+    title="Train Classifier",
+    icon=":material/fitness_center:",
+)
+page_navigation = st.navigation(
+    [intro_page, chat_page, evaluation_page, train_classifier_page]
+)
 st.set_page_config(page_title="Guardrails Genie", page_icon=":material/guardian:")
 page_navigation.run()
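The navigation list references intro_page, chat_page, and evaluation_page, which are defined earlier in app.py and are not shown in this hunk. A minimal sketch of the assumed pattern, for context only (the file paths and titles below are illustrative guesses; only evaluation_page's title and icon are confirmed by the hunk above):

import streamlit as st

# Assumed shape of the page objects referenced above; paths and titles for
# intro_page and chat_page are hypothetical placeholders, not repository code.
intro_page = st.Page("application_pages/intro_page.py", title="Introduction")
chat_page = st.Page("application_pages/chat_app.py", title="Chat")
evaluation_page = st.Page(
    "application_pages/evaluation_app.py",  # path is a placeholder
    title="Evaluation",
    icon=":material/monitoring:",
)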
application_pages/train_classifier.py
ADDED
@@ -0,0 +1,57 @@
+import streamlit as st
+from dotenv import load_dotenv
+
+import wandb
+from guardrails_genie.train_classifier import train_binary_classifier
+
+
+def initialize_session_state():
+    load_dotenv()
+    if "dataset_name" not in st.session_state:
+        st.session_state.dataset_name = None
+    if "base_model_name" not in st.session_state:
+        st.session_state.base_model_name = None
+    if "batch_size" not in st.session_state:
+        st.session_state.batch_size = 16
+    if "should_start_training" not in st.session_state:
+        st.session_state.should_start_training = False
+    if "training_output" not in st.session_state:
+        st.session_state.training_output = None
+
+
+initialize_session_state()
+st.title(":material/fitness_center: Train Classifier")
+
+dataset_name = st.sidebar.text_input("Dataset Name", value="")
+st.session_state.dataset_name = dataset_name
+
+base_model_name = st.sidebar.selectbox(
+    "Base Model", options=["distilbert/distilbert-base-uncased", "roberta-base"]
+)
+st.session_state.base_model_name = base_model_name
+
+batch_size = st.sidebar.slider(
+    "Batch Size", min_value=4, max_value=256, value=16, step=4
+)
+st.session_state.batch_size = batch_size
+
+train_button = st.sidebar.button("Train")
+st.session_state.should_start_training = (
+    train_button and st.session_state.dataset_name and st.session_state.base_model_name
+)
+
+if st.session_state.should_start_training:
+    with st.expander("Training", expanded=True):
+        st.markdown(
+            f"Explore your training logs on [Weights & Biases]({wandb.run.url})"
+        )
+        training_output = train_binary_classifier(
+            project_name="guardrails-genie",
+            entity_name="geekyrakshit",
+            dataset_repo=st.session_state.dataset_name,
+            model_name=st.session_state.base_model_name,
+            batch_size=st.session_state.batch_size,
+            streamlit_mode=True,
+        )
+        st.session_state.training_output = training_output
+        st.write(training_output)
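The page drives training entirely through train_binary_classifier with streamlit_mode=True. For reference, a minimal sketch of calling the same function from a plain Python script, outside Streamlit (the W&B entity and dataset repository below are placeholders; only keyword arguments that appear in this commit are used):

from guardrails_genie.train_classifier import train_binary_classifier

# Hypothetical direct invocation; entity_name and dataset_repo are placeholders.
training_output = train_binary_classifier(
    project_name="guardrails-genie",
    entity_name="<your-wandb-entity>",
    dataset_repo="<hf-username>/<your-dataset>",
    model_name="distilbert/distilbert-base-uncased",
    batch_size=16,
    streamlit_mode=False,  # skip the Streamlit progress-bar callback
)
print(training_output)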
guardrails_genie/train_classifier.py
CHANGED
@@ -1,14 +1,39 @@
+
 import evaluate
 import numpy as np
-import wandb
+import streamlit as st
 from datasets import load_dataset
 from transformers import (
     AutoModelForSequenceClassification,
     AutoTokenizer,
     DataCollatorWithPadding,
     Trainer,
+    TrainerCallback,
     TrainingArguments,
 )
+from transformers.trainer_callback import TrainerControl, TrainerState
+
+import wandb
+
+
+class StreamlitProgressbarCallback(TrainerCallback):
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.progress_bar = st.progress(0, text="Training")
+
+    def on_step_begin(
+        self,
+        args: TrainingArguments,
+        state: TrainerState,
+        control: TrainerControl,
+        **kwargs,
+    ):
+        super().on_step_begin(args, state, control, **kwargs)
+        self.progress_bar.progress(
+            (state.global_step * 100 // state.max_steps) + 1,
+            text=f"Training {state.global_step} / {state.max_steps}",
+        )
 
 
 def train_binary_classifier(
@@ -20,6 +45,7 @@ def train_binary_classifier(
     batch_size: int = 16,
     num_epochs: int = 2,
     weight_decay: float = 0.01,
+    streamlit_mode: bool = False,
 ):
     wandb.init(project=project_name, entity=entity_name)
     dataset = load_dataset(dataset_repo)
@@ -69,5 +95,8 @@ def train_binary_classifier(
         processing_class=tokenizer,
         data_collator=data_collator,
         compute_metrics=compute_metrics,
+        callbacks=[StreamlitProgressbarCallback()] if streamlit_mode else [],
     )
-    trainer.train()
+    training_output = trainer.train()
+    wandb.finish()
+    return training_output
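StreamlitProgressbarCallback relies on TrainerState.global_step and TrainerState.max_steps, which transformers updates while Trainer.train() runs; on_step_begin fires at the beginning of each training step. A minimal sketch of the same callback pattern without the Streamlit dependency (this console variant is an illustration, not part of the commit):

from transformers import TrainerCallback, TrainerControl, TrainerState, TrainingArguments


class ConsoleProgressCallback(TrainerCallback):
    """Illustrative stand-in for StreamlitProgressbarCallback that prints progress."""

    def on_step_begin(
        self,
        args: TrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        **kwargs,
    ):
        # global_step is the number of training steps completed so far;
        # max_steps is the total the Trainer will run, so the ratio is percent done.
        percent = state.global_step * 100 // max(state.max_steps, 1)
        print(f"Training step {state.global_step} / {state.max_steps} ({percent}%)")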