import os

import gradio as gr
import torch
from datasets import load_dataset
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    TrainingArguments,
    Trainer,
    BitsAndBytesConfig,
)

# Disable wandb logging
os.environ["WANDB_DISABLED"] = "true"

# Global state shared across Gradio callbacks
model = None
tokenizer = None
training_status = "Not started"
def load_model():
    global model, tokenizer
    try:
        # Configure 4-bit quantization for memory efficiency
        quantization_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=torch.float16,
            bnb_4bit_quant_type="nf4",
        )

        # Load model and tokenizer
        model_name = "LLM360/K2-Think"
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            quantization_config=quantization_config,
            device_map="auto",
        )

        # Set padding token (many causal LMs ship without one)
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token

        return "Model loaded successfully!"
    except Exception as e:
        return f"Error loading model: {str(e)}"
def prepare_data():
    try:
        # Load a sample dataset (replace this with your own; see the sketch below)
        dataset = load_dataset("imdb")

        def preprocess_function(examples):
            # Format the reviews as instruction-style prompts
            texts = []
            for text, label in zip(examples["text"], examples["label"]):
                sentiment = "positive" if label == 1 else "negative"
                texts.append(
                    f"Analyze the sentiment of this movie review: {text}\nSentiment: {sentiment}"
                )

            # Tokenize to a fixed length so examples collate cleanly
            tokenized = tokenizer(
                texts, truncation=True, padding="max_length", max_length=256
            )

            # Causal LM labels are the input ids; mask padding with -100
            # so pad positions do not contribute to the loss
            tokenized["labels"] = [
                [tok if tok != tokenizer.pad_token_id else -100 for tok in ids]
                for ids in tokenized["input_ids"]
            ]
            return tokenized

        # Apply preprocessing and drop the raw columns
        tokenized_dataset = dataset.map(
            preprocess_function,
            batched=True,
            remove_columns=dataset["train"].column_names,
        )

        # Use a small subset for the demo
        train_dataset = tokenized_dataset["train"].shuffle().select(range(50))
        return train_dataset, "Data prepared successfully!"
    except Exception as e:
        return None, f"Error preparing data: {str(e)}"
def train_model():
    global model, tokenizer, training_status
    try:
        training_status = "Starting training..."
        yield training_status

        # Prepare data
        train_dataset, status = prepare_data()
        if train_dataset is None:
            training_status = status
            yield training_status
            return
        training_status = status
        yield training_status

        # Set up training arguments
        training_args = TrainingArguments(
            output_dir="./k2-think-finetuned",
            per_device_train_batch_size=1,
            gradient_accumulation_steps=4,
            num_train_epochs=1,
            learning_rate=2e-5,
            fp16=True,
            save_strategy="no",
            logging_steps=5,
        )
        training_status = "Training configuration set up..."
        yield training_status

        # Create trainer
        trainer = Trainer(
            model=model,
            args=training_args,
            train_dataset=train_dataset,
        )
        training_status = "Starting training process..."
        yield training_status

        # Start training
        trainer.train()
        training_status = "Training completed! Saving model..."
        yield training_status

        # Save model and tokenizer
        model.save_pretrained("./k2-think-finetuned")
        tokenizer.save_pretrained("./k2-think-finetuned")
        training_status = "Model saved successfully! Ready for inference."
        yield training_status
    except Exception as e:
        training_status = f"Error during training: {str(e)}"
        yield training_status
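# A minimal sketch of reloading the saved checkpoint in a later session,
# assuming the full model was saved to the output directory above (if only
# LoRA adapters were saved, you would instead load the base model and attach
# the adapters via peft.PeftModel.from_pretrained; not called by the app):
def reload_finetuned(path="./k2-think-finetuned"):
    tok = AutoTokenizer.from_pretrained(path)
    mdl = AutoModelForCausalLM.from_pretrained(path, device_map="auto")
    return mdl, tok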
def generate_text(prompt):
    if model is None or tokenizer is None:
        return "Please load the model first."
    try:
        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
        outputs = model.generate(
            inputs.input_ids,
            attention_mask=inputs.attention_mask,
            max_new_tokens=200,
            do_sample=True,  # required for temperature to take effect
            temperature=0.7,
            pad_token_id=tokenizer.pad_token_id,
        )
        return tokenizer.decode(outputs[0], skip_special_tokens=True)
    except Exception as e:
        return f"Error generating text: {str(e)}"
# Create the Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("# K2-Think Model Training")

    with gr.Tab("Training"):
        gr.Markdown("## Fine-tune K2-Think Model")
        with gr.Row():
            load_btn = gr.Button("Load Model")
            train_btn = gr.Button("Start Training")
        status_output = gr.Textbox(label="Training Status", value=training_status)
        load_btn.click(load_model, outputs=status_output)
        train_btn.click(train_model, outputs=status_output)

    with gr.Tab("Inference"):
        gr.Markdown("## Test Your Fine-tuned Model")
        with gr.Row():
            prompt_input = gr.Textbox(
                label="Enter your prompt",
                placeholder="Analyze the sentiment of this movie review: This movie was amazing!",
            )
            generate_btn = gr.Button("Generate")
        output_text = gr.Textbox(label="Generated Text")
        generate_btn.click(generate_text, inputs=prompt_input, outputs=output_text)

# Enable the queue so the generator-based training handler can stream status
# updates (required on older Gradio versions; harmless on newer ones)
demo.queue()
demo.launch()