# Rai / app.py — Hugging Face Space for the flashcard generator
# (uploaded by TrabbyPatty, revision 9e5bead)
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import os
# === Model ID and Token ===
# Fine-tuned, 4-bit-quantized Mistral-7B-Instruct flashcard model on the HF Hub.
model_id = "TrabbyPatty/mistral-7b-instruct-finetuned-flashcards-4bit"
# Read the Hub access token from the Space secret named "alluse".
# os.getenv returns None when the secret is not configured.
hf_token = os.getenv("alluse") # Hugging Face token from Space secrets
# === Load tokenizer & model with authentication ===
tokenizer = AutoTokenizer.from_pretrained(
model_id,
token=hf_token,
use_fast=False # force slow tokenizer (works around a fast-tokenizer JSON load error)
)
model = AutoModelForCausalLM.from_pretrained(
model_id,
device_map="auto", # let HF accelerate place layers on GPU/CPU automatically
torch_dtype=torch.float16, # half precision to cut memory use
token=hf_token
)
# === SYSTEM MESSAGE ===
# <<SYS>> ... <</SYS>> follows the Llama/Mistral chat system-prompt convention;
# this text is embedded verbatim into every generation prompt.
SYSTEM_MESSAGE = """<<SYS>>
You are a strict flashcard generator.
- Only extract information from the input.
- Do NOT add outside knowledge, assumptions, or details not mentioned in the input.
- Always follow the requested format exactly.
<</SYS>>"""
def generate_flashcards(user_input, max_new_tokens=600, temperature=0.5):
    """Generate flashcards strictly from *user_input* with the fine-tuned model.

    Args:
        user_input: Study material to turn into flashcards.
        max_new_tokens: Upper bound on the number of generated tokens.
        temperature: Sampling temperature. Values > 0 enable sampling;
            0 or None falls back to greedy decoding.

    Returns:
        str: The text of the model's "Output:" section.
    """
    # [INST] ... [/INST] is the Mistral instruction format; the system
    # message constrains the model to the supplied input only.
    prompt = (
        f"<s>[INST] {SYSTEM_MESSAGE}\n\n"
        f"Create flashcards strictly using only the information provided.\n\n"
        f"Input: {user_input}[/INST]\nOutput:"
    )
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    # Bug fix: the original passed `temperature` alongside do_sample=False,
    # so greedy decoding silently ignored it and the UI's Temperature slider
    # had no effect. Enable sampling whenever a positive temperature is set.
    do_sample = temperature is not None and temperature > 0
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=do_sample,
            temperature=temperature if do_sample else None,
            repetition_penalty=1.05,
            pad_token_id=tokenizer.eos_token_id,
            eos_token_id=tokenizer.eos_token_id,
        )
    response = tokenizer.decode(outputs[0], skip_special_tokens=True)
    # The prompt ends with "Output:", so everything after its last occurrence
    # is the answer; fall back to the full decode if the marker is missing.
    if "Output:" in response:
        return response.split("Output:")[-1].strip()
    return response.strip()
# --- Gradio UI ---
# Input widgets mirror generate_flashcards' parameters in order:
# study text, max_new_tokens, temperature.
input_widgets = [
    gr.Textbox(
        label="Enter study text",
        lines=8,
        placeholder="Paste your study material here...",
    ),
    gr.Slider(100, 1000, value=600, step=50, label="Max New Tokens"),
    gr.Slider(0.1, 1.0, value=0.5, step=0.1, label="Temperature"),
]

# Simple text-in/text-out interface around the generator function.
demo = gr.Interface(
    fn=generate_flashcards,
    inputs=input_widgets,
    outputs="text",
    title="Flashcard Generator (Mistral-7B LoRA)",
    description="Paste study material and generate flashcards. Model strictly extracts only from input.",
)

if __name__ == "__main__":
    demo.launch()