File size: 2,219 Bytes
7a8e236
f0385bc
 
7a8e236
f0385bc
 
7fa395b
e6a3517
 
 
 
 
 
 
 
 
 
 
 
 
 
f0385bc
 
 
 
a702d2b
7fa395b
a702d2b
 
f0385bc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e6a3517
7a8e236
 
7e58c6c
7a8e236
7e58c6c
 
7a8e236
f0385bc
 
 
 
 
7a8e236
f0385bc
 
7a8e236
 
7e58c6c
 
 
 
 
 
 
 
7a8e236
7e58c6c
7a8e236
 
67dc437
3180549
67dc437
 
7a8e236
67dc437
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
import gradio as gr
import torch
from transformers import AutoTokenizer

from model import GPT, GPTConfig

# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
# Decorative rule drawn above and below the ASCII-art banner.
line_divider = (
    ".~^~._.~^~._.~^~._.~^~._.~^~._.~^~._.~^~._.~^~._.~^~._.~^~._.~^~._.~^~._.~^~._-~"
)
# Banner shown at the top of the page. The art contains backslashes, so this
# is a *raw* f-string: sequences such as "\/" and "\_" are not valid escapes
# and trigger a SyntaxWarning (Python 3.12+) in a regular string literal.
# The resulting text is identical to the non-raw version.
header = rf"""
{line_divider}
                       ______       __ __ _                 
                      / ____/____  / // /(_)___  ____  ___  
                     / /    / __ \/ // // / __ \/ __ \/ _ \ 
                     | |___/ /_/ / // // / /_/ / /_/ / ___/ 
                     \____/\__._/_//_//_/\____/ .___/\___/  
                                             /_/            
                            -- your personal muse --
{line_divider}
"""


def setup(model_path: str):
    """Load the GPT-2 tokenizer and the model stored in a checkpoint file.

    Args:
        model_path: Path to a checkpoint readable by ``torch.load``. It must
            contain a ``"model_args"`` mapping (keyword arguments for
            ``GPTConfig``) and a ``"model"`` state dict.

    Returns:
        A ``(model, tokenizer)`` tuple. The model has the checkpoint weights
        loaded, lives on ``DEVICE`` and is in eval mode.
    """
    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    # Remap every storage onto the device we actually run on. Without
    # map_location a checkpoint saved on e.g. "cuda:1" fails to load on a
    # machine that lacks that device, even when another GPU is available.
    checkpoint = torch.load(model_path, map_location=DEVICE)
    model = GPT(GPTConfig(**checkpoint["model_args"]))

    # rename keys because of torch >=2.1: parameter names may carry an
    # "_orig_mod." prefix (added when the model was wrapped by torch.compile)
    # that a plain GPT instance does not expect. Strip exactly that prefix
    # instead of slicing by a hard-coded length.
    prefix = "_orig_mod."
    state_dict = {
        (key[len(prefix):] if key.startswith(prefix) else key): val
        for key, val in checkpoint["model"].items()
    }
    model.load_state_dict(state_dict)
    model.to(DEVICE)
    model.eval()
    return model, tokenizer


# Load the weights once at import time; generate() reads these module globals.
model, tokenizer = setup(model_path="checkpoints/Calliope-nano.pt")


def generate(
    message,
    max_tokens=128,
    temperature=0.8,
):
    """Continue ``message`` with text sampled from the loaded model.

    Args:
        message: Prompt text to continue.
        max_tokens: Passed to ``model.generate`` as ``max_new_tokens``.
        temperature: Sampling temperature passed to ``model.generate``.

    Returns:
        The decoded token sequence returned by ``model.generate``, or an
        empty string when ``message`` is empty or ``None``.
    """
    # An empty prompt encodes to zero tokens; torch.tensor([[]]) would be a
    # float tensor of shape (1, 0), which is not usable as token-id input.
    # Return early instead of surfacing a stack trace in the UI.
    if not message:
        return ""

    prompt_ids = tokenizer.encode(message, add_special_tokens=False)
    # Shape (1, n_tokens): a batch holding the single prompt.
    prompt = torch.tensor([prompt_ids], dtype=torch.long, device=DEVICE)
    # Inference only: make sure no autograd graph is built while sampling.
    with torch.no_grad():
        idx = model.generate(
            prompt,
            max_new_tokens=max_tokens,
            temperature=temperature,
        )
    return tokenizer.decode(idx[0].cpu().numpy())



# Single-prompt UI: one textbox for the prompt, one for generate()'s output.
app = gr.Interface(
    fn=generate,
    inputs=[
        gr.Textbox(
            label="your starting point..",
            placeholder="rain falls slowly on your darling cheeks",
            lines=16,
        ),
    ],
    outputs=[
        gr.Textbox(label="calliope continues..", lines=16),
    ],
    allow_flagging="never",  # flagging is disabled for this demo
)

# Wrap the interface in a Blocks page so the ASCII banner renders above it.
with gr.Blocks() as demo:
    gr.HTML("<div style='text-align: center;'><pre>" + header + "</pre></div>")
    app.render()

# Start the web server only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()