"""Gradio Space app: generate flashcards from pasted study text using a
fine-tuned Mistral-7B-Instruct model. Strictly extractive: the system
prompt forbids adding information not present in the input."""

import os

import torch
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer

# === Model ID and Token ===
model_id = "TrabbyPatty/mistral-7b-instruct-finetuned-flashcards-4bit"
hf_token = os.getenv("alluse")  # Hugging Face token from Space secrets

# === Load tokenizer & model with authentication ===
tokenizer = AutoTokenizer.from_pretrained(
    model_id,
    token=hf_token,
    use_fast=False,  # force slow tokenizer (fixes JSON error)
)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",  # let HF map to GPU/CPU
    torch_dtype=torch.float16,
    token=hf_token,
)

# === SYSTEM MESSAGE ===
SYSTEM_MESSAGE = """<> You are a strict flashcard generator. - Only extract information from the input. - Do NOT add outside knowledge, assumptions, or details not mentioned in the input. - Always follow the requested format exactly. <>"""


def generate_flashcards(user_input, max_new_tokens=600, temperature=0.5):
    """Generate flashcards strictly from ``user_input``.

    Args:
        user_input: Study material pasted by the user.
        max_new_tokens: Generation budget. Gradio sliders can deliver
            floats, so the value is coerced to ``int`` before generation.
        temperature: Sampling temperature. Values > 0 enable sampling.
            BUGFIX: the original passed ``temperature`` together with
            ``do_sample=False``, so greedy decoding silently ignored the
            UI's Temperature slider (and newer transformers versions warn
            about the conflicting flags).

    Returns:
        The text following the final "Output:" marker in the decoded
        response, or the whole stripped response if the marker is absent.
    """
    # Guard: nothing to extract from an empty/whitespace-only prompt.
    if not user_input or not user_input.strip():
        return ""

    # Mistral-instruct style prompt: system message + task + user input.
    prompt = (
        f"[INST] {SYSTEM_MESSAGE}\n\n"
        f"Create flashcards strictly using only the information provided.\n\n"
        f"Input: {user_input}[/INST]\nOutput:"
    )
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

    # Sampling is only meaningful with a positive temperature; otherwise
    # fall back to greedy decoding and omit temperature entirely.
    do_sample = temperature is not None and temperature > 0
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=int(max_new_tokens),
            temperature=temperature if do_sample else None,
            do_sample=do_sample,
            repetition_penalty=1.05,
            pad_token_id=tokenizer.eos_token_id,
            eos_token_id=tokenizer.eos_token_id,
        )

    response = tokenizer.decode(outputs[0], skip_special_tokens=True)

    # Keep only the generated answer after the "Output:" marker; the
    # prompt itself is echoed back in the decoded sequence.
    if "Output:" in response:
        return response.split("Output:")[-1].strip()
    return response.strip()


# Gradio UI
demo = gr.Interface(
    fn=generate_flashcards,
    inputs=[
        gr.Textbox(label="Enter study text", lines=8, placeholder="Paste your study material here..."),
        gr.Slider(100, 1000, value=600, step=50, label="Max New Tokens"),
        gr.Slider(0.1, 1.0, value=0.5, step=0.1, label="Temperature"),
    ],
    outputs="text",
    title="Flashcard Generator (Mistral-7B LoRA)",
    description="Paste study material and generate flashcards. Model strictly extracts only from input.",
)

if __name__ == "__main__":
    demo.launch()