"""
Gradio interface for TinyStories Llama model chat.
"""
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import warnings
import os

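# Suppress UserWarnings (e.g. from transformers) so startup logs stay readable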
warnings.filterwarnings('ignore', category=UserWarning)

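# Hub repo to load from; override with the MODEL_REPO environment variable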
MODEL_REPO = os.environ.get("MODEL_REPO", "sdobson/tinystories-llama-15m")

print(f"Loading model and tokenizer from {MODEL_REPO}...")
model = AutoModelForCausalLM.from_pretrained(MODEL_REPO)
tokenizer = AutoTokenizer.from_pretrained(MODEL_REPO)

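# Prefer GPU when available; at ~15M parameters the model also runs fine on CPU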
device = "cuda" if torch.cuda.is_available() else "cpu"
model = model.to(device)
model.eval()

print(f"Model loaded on {device}")
print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")


def generate_story(
    prompt,
    max_length=200,
    temperature=0.8,
    top_k=50,
    top_p=0.9,
    do_sample=True
):
    """Generate a story continuation from the prompt."""
    if not prompt.strip():
        return "Please provide a story prompt!"

    # Tokenize input
    inputs = tokenizer(prompt, return_tensors="pt").to(device)

    # Generate
    with torch.no_grad():
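        # Sampling knobs (temperature, top_k, top_p) only take effect when
        # do_sample=True; otherwise generation falls back to greedy decoding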
        outputs = model.generate(
            **inputs,
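            # max_length counts the prompt tokens plus the newly generated ones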
            max_length=max_length,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            do_sample=do_sample,
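            # Llama-style tokenizers ship without a pad token, so reuse EOS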
            pad_token_id=tokenizer.eos_token_id,
            eos_token_id=tokenizer.eos_token_id,
        )

    # Decode and return
    generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return generated_text


with gr.Blocks(title="TinyStories Story Generator") as demo:
    gr.Markdown(
        """
        # TinyStories Story Generator

        This is a small Llama-architecture model trained on the TinyStories dataset.
        It generates simple, coherent children's stories using vocabulary that a typical 3- to 4-year-old would understand.

        **Try starting your story with:**
        - "Once upon a time, there was a..."
        - "One day, a little boy named..."
        - "In a small town, there lived a..."
        """
    )

    with gr.Row():
        with gr.Column():
            prompt_input = gr.Textbox(
                label="Story Prompt",
                placeholder="Once upon a time, there was a",
                lines=3
            )

            with gr.Accordion("Generation Settings", open=False):
                max_length_slider = gr.Slider(
                    minimum=50,
                    maximum=256,
                    value=200,
                    step=10,
                    label="Max Length (tokens)"
                )
                temperature_slider = gr.Slider(
                    minimum=0.1,
                    maximum=2.0,
                    value=0.8,
                    step=0.1,
                    label="Temperature (higher = more creative)"
                )
                top_k_slider = gr.Slider(
                    minimum=1,
                    maximum=100,
                    value=50,
                    step=1,
                    label="Top-k"
                )
                top_p_slider = gr.Slider(
                    minimum=0.1,
                    maximum=1.0,
                    value=0.9,
                    step=0.05,
                    label="Top-p (nucleus sampling)"
                )
                do_sample_checkbox = gr.Checkbox(
                    label="Use Sampling",
                    value=True
                )

            generate_btn = gr.Button("Generate Story", variant="primary")

        with gr.Column():
            output_text = gr.Textbox(
                label="Generated Story",
                lines=15,
                show_copy_button=True
            )

    gr.Examples(
        examples=[
            ["Once upon a time, there was a little girl named Lily."],
            ["One day, a little boy found a magic"],
            ["The little dog was very happy because"],
            ["In a small garden, there lived a"],
            ["Timmy wanted to play with his friend, but"],
        ],
        inputs=prompt_input,
        label="Example Prompts"
    )

    generate_btn.click(
        fn=generate_story,
        inputs=[
            prompt_input,
            max_length_slider,
            temperature_slider,
            top_k_slider,
            top_p_slider,
            do_sample_checkbox
        ],
        outputs=output_text
    )

if __name__ == "__main__":
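    # Serves on localhost by default; pass share=True for a temporary public link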
    demo.launch()