Spaces:
Runtime error
Runtime error
File size: 4,429 Bytes
c514928 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 |
"""
Gradio interface for TinyStories Llama model chat.
"""
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import warnings
import os
# Suppress every UserWarning for the remainder of the process.
warnings.filterwarnings("ignore", category=UserWarning)

# Hub repository to load; can be overridden through the MODEL_REPO env var.
MODEL_REPO = os.getenv("MODEL_REPO", "sdobson/tinystories-llama-15m")

# Run on the GPU when PyTorch reports one is available, otherwise on the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

print(f"Loading model and tokenizer from {MODEL_REPO}...")
model = AutoModelForCausalLM.from_pretrained(MODEL_REPO)
tokenizer = AutoTokenizer.from_pretrained(MODEL_REPO)

# Move the weights to the selected device and put the module in eval mode;
# both calls return the module itself, so the chain rebinds the same object.
model = model.to(device).eval()

print(f"Model loaded on {device}")
print(f"Model parameters: {sum(map(torch.numel, model.parameters())):,}")
def generate_story(
    prompt,
    max_length=200,
    temperature=0.8,
    top_k=50,
    top_p=0.9,
    do_sample=True
):
    """Generate a story continuation from the prompt.

    Args:
        prompt: Opening text for the story. ``None``, empty, or
            whitespace-only input is rejected with an instruction message
            instead of being sent to the model.
        max_length: Upper bound on the total token count, forwarded to
            ``model.generate`` as ``max_length`` (so it includes the prompt
            tokens). Coerced to ``int``.
        temperature: Sampling temperature; only used when ``do_sample`` is true.
        top_k: Top-k cutoff; only used when ``do_sample`` is true. Coerced to
            ``int`` because ``generate`` rejects a non-integer ``top_k``.
        top_p: Nucleus-sampling cutoff; only used when ``do_sample`` is true.
        do_sample: Sample from the distribution when true, otherwise decode
            greedily.

    Returns:
        The first sequence produced by ``model.generate``, decoded with special
        tokens removed (for a causal LM this normally includes the prompt), or
        ``"Please provide a story prompt!"`` for an empty prompt.
    """
    # Guard against None as well as blank text; None.strip() would raise.
    if not prompt or not prompt.strip():
        return "Please provide a story prompt!"

    # Tokenize input and move the tensors to the model's device.
    inputs = tokenizer(prompt, return_tensors="pt").to(device)

    generation_kwargs = {
        # UI sliders may deliver floats; generate() needs an integer length.
        "max_length": int(max_length),
        "do_sample": bool(do_sample),
        "pad_token_id": tokenizer.eos_token_id,
        "eos_token_id": tokenizer.eos_token_id,
    }
    if do_sample:
        # Sampling knobs do not affect greedy decoding; passing them there only
        # triggers generation-config warnings, so add them for sampling only.
        generation_kwargs.update(
            temperature=float(temperature),
            top_k=int(top_k),
            top_p=float(top_p),
        )

    with torch.no_grad():
        outputs = model.generate(**inputs, **generation_kwargs)

    return tokenizer.decode(outputs[0], skip_special_tokens=True)
# Build the web UI: the prompt and generation settings on the left, the
# generated story on the right, example prompts underneath, and a button that
# sends the widget values to generate_story.
with gr.Blocks(title="TinyStories Story Generator") as demo:
    # Introductory Markdown shown above the controls.
    gr.Markdown(
        "\n"
        "# TinyStories Llama Model Chat\n"
        "This is a small Llama-architecture model trained on the TinyStories dataset.\n"
        "It generates simple, coherent children's stories using vocabulary that a typical 3-4 year old would understand.\n"
        "**Try starting your story with:**\n"
        '- "Once upon a time, there was a..."\n'
        '- "One day, a little boy named..."\n'
        '- "In a small town, there lived a..."\n'
    )

    with gr.Row():
        # Left column: prompt box, collapsible sampling controls, trigger button.
        with gr.Column():
            prompt_input = gr.Textbox(
                lines=3,
                label="Story Prompt",
                placeholder="Once upon a time, there was a",
            )

            with gr.Accordion("Generation Settings", open=False):
                # Initial slider/checkbox values match generate_story's defaults.
                max_length_slider = gr.Slider(
                    minimum=50, maximum=256, step=10, value=200,
                    label="Max Length (tokens)",
                )
                temperature_slider = gr.Slider(
                    minimum=0.1, maximum=2.0, step=0.1, value=0.8,
                    label="Temperature (higher = more creative)",
                )
                top_k_slider = gr.Slider(
                    minimum=1, maximum=100, step=1, value=50,
                    label="Top-k",
                )
                top_p_slider = gr.Slider(
                    minimum=0.1, maximum=1.0, step=0.05, value=0.9,
                    label="Top-p (nucleus sampling)",
                )
                do_sample_checkbox = gr.Checkbox(value=True, label="Use Sampling")

            generate_btn = gr.Button("Generate Story", variant="primary")

        # Right column: read-only-style output area for the generated story.
        with gr.Column():
            # NOTE(review): show_copy_button may not be accepted by every
            # Gradio release — confirm against the pinned Gradio version.
            output_text = gr.Textbox(
                show_copy_button=True,
                lines=15,
                label="Generated Story",
            )

    # One-click example prompts that populate prompt_input.
    gr.Examples(
        examples=[
            [starter]
            for starter in (
                "Once upon a time, there was a little girl named Lily.",
                "One day, a little boy found a magic",
                "The little dog was very happy because",
                "In a small garden, there lived a",
                "Timmy wanted to play with his friend, but",
            )
        ],
        inputs=prompt_input,
        label="Example Prompts",
    )

    # Inputs are listed in generate_story's positional parameter order.
    generate_btn.click(
        generate_story,
        inputs=[
            prompt_input,
            max_length_slider,
            temperature_slider,
            top_k_slider,
            top_p_slider,
            do_sample_checkbox,
        ],
        outputs=output_text,
    )

# Start the server only when executed as a script, not when imported.
if __name__ == "__main__":
    demo.launch()
|