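"""Gradio demo for fine-tuning LLM360/K2-Think on IMDB sentiment data.

Loads the model with 4-bit (bitsandbytes) quantization, attaches a LoRA
adapter so the quantized weights can be fine-tuned, trains briefly on a
small IMDB subset, and provides a simple inference tab.
"""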
import gradio as gr
import torch
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments, Trainer, BitsAndBytesConfig
# peft supplies the LoRA adapter needed to fine-tune a 4-bit quantized model
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
import os
# Disable wandb
os.environ["WANDB_DISABLED"] = "true"
# Global variables
model = None
tokenizer = None
training_status = "Not started"
def load_model():
    global model, tokenizer
    try:
        # Configure 4-bit NF4 quantization for memory efficiency
        quantization_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=torch.float16,
            bnb_4bit_quant_type="nf4",
        )
        # Load model and tokenizer
        model_name = "LLM360/K2-Think"
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            quantization_config=quantization_config,
            device_map="auto"
        )
        # Set padding token
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token
        # 4-bit quantized weights are frozen, so Trainer has nothing to update;
        # attach a small LoRA adapter to give it trainable parameters. The
        # target_modules below assume Qwen/Llama-style attention projection
        # names; adjust them if the architecture names its layers differently.
        model = prepare_model_for_kbit_training(model)
        lora_config = LoraConfig(
            r=8,
            lora_alpha=16,
            lora_dropout=0.05,
            target_modules=["q_proj", "v_proj"],
            bias="none",
            task_type="CAUSAL_LM",
        )
        model = get_peft_model(model, lora_config)
        return "Model loaded successfully!"
    except Exception as e:
        return f"Error loading model: {str(e)}"
def prepare_data():
    try:
        # Load a sample dataset (you can replace this with your own)
        dataset = load_dataset("imdb")
        # Preprocessing function
        def preprocess_function(examples):
            # Format the text for instruction tuning
            texts = []
            for text, label in zip(examples["text"], examples["label"]):
                sentiment = "positive" if label == 1 else "negative"
                texts.append(f"Analyze the sentiment of this movie review: {text}\nSentiment: {sentiment}")
            # Pad every example to a fixed length so the default collator
            # can stack batches of uniform shape
            tokenized = tokenizer(texts, truncation=True, padding="max_length", max_length=256)
            # Causal-LM labels mirror the inputs; mask padding positions with
            # -100 so they are ignored by the loss
            tokenized["labels"] = [
                [tok if m == 1 else -100 for tok, m in zip(ids, mask)]
                for ids, mask in zip(tokenized["input_ids"], tokenized["attention_mask"])
            ]
            return tokenized
        # Apply preprocessing
        tokenized_dataset = dataset.map(
            preprocess_function,
            batched=True,
            remove_columns=dataset["train"].column_names
        )
        # Use a small shuffled subset so the demo finishes quickly
        train_dataset = tokenized_dataset["train"].shuffle().select(range(50))
        return train_dataset, "Data prepared successfully!"
    except Exception as e:
        return None, f"Error preparing data: {str(e)}"
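# For reference, one formatted training example (before tokenization):
#   "Analyze the sentiment of this movie review: This movie was amazing!\nSentiment: positive"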
def train_model():
    global model, tokenizer, training_status
    try:
        training_status = "Starting training..."
        yield training_status
        # Prepare data
        train_dataset, status = prepare_data()
        if train_dataset is None:
            training_status = status
            yield training_status
            return
        training_status = status
        yield training_status
        # Set up training arguments
        training_args = TrainingArguments(
            output_dir="./k2-think-finetuned",
            per_device_train_batch_size=1,
            gradient_accumulation_steps=4,
            num_train_epochs=1,
            learning_rate=2e-5,
            fp16=True,
            save_strategy="no",
            logging_steps=5,
        )
        training_status = "Training configuration set up..."
        yield training_status
        # Create trainer
        trainer = Trainer(
            model=model,
            args=training_args,
            train_dataset=train_dataset,
        )
        training_status = "Starting training process..."
        yield training_status
        # Start training
        trainer.train()
        training_status = "Training completed! Saving model..."
        yield training_status
        # Save the fine-tuned weights (the LoRA adapter) and tokenizer
        model.save_pretrained("./k2-think-finetuned")
        tokenizer.save_pretrained("./k2-think-finetuned")
        training_status = "Model saved successfully! Ready for inference."
        yield training_status
    except Exception as e:
        training_status = f"Error during training: {str(e)}"
        yield training_status
def generate_text(prompt):
    if model is None or tokenizer is None:
        return "Please load the model first."
    try:
        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
        outputs = model.generate(
            **inputs,  # pass attention_mask along with input_ids
            max_new_tokens=200,  # cap new tokens rather than total length
            num_return_sequences=1,
            do_sample=True,  # temperature only takes effect when sampling
            temperature=0.7,
            pad_token_id=tokenizer.pad_token_id,
        )
        return tokenizer.decode(outputs[0], skip_special_tokens=True)
    except Exception as e:
        return f"Error generating text: {str(e)}"
# Create the Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("# K2-Think Model Training")
    with gr.Tab("Training"):
        gr.Markdown("## Fine-tune K2-Think Model")
        with gr.Row():
            load_btn = gr.Button("Load Model")
            train_btn = gr.Button("Start Training")
        status_output = gr.Textbox(label="Training Status", value=training_status)
        load_btn.click(load_model, outputs=status_output)
        train_btn.click(train_model, outputs=status_output)
    with gr.Tab("Inference"):
        gr.Markdown("## Test Your Fine-tuned Model")
        with gr.Row():
            prompt_input = gr.Textbox(label="Enter your prompt", placeholder="Analyze the sentiment of this movie review: This movie was amazing!")
            generate_btn = gr.Button("Generate")
        output_text = gr.Textbox(label="Generated Text")
        generate_btn.click(generate_text, inputs=prompt_input, outputs=output_text)

demo.launch()
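# Run with `python app.py`; launch() serves the interface on Gradio's default
# local port (7860), and a Hugging Face Space picks it up automatically.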