Spaces:
Runtime error
Runtime error
File size: 2,219 Bytes
7a8e236 f0385bc 7a8e236 f0385bc 7fa395b e6a3517 f0385bc a702d2b 7fa395b a702d2b f0385bc e6a3517 7a8e236 7e58c6c 7a8e236 7e58c6c 7a8e236 f0385bc 7a8e236 f0385bc 7a8e236 7e58c6c 7a8e236 7e58c6c 7a8e236 67dc437 3180549 67dc437 7a8e236 67dc437 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 |
import gradio as gr
import torch
from transformers import AutoTokenizer
from model import GPT, GPTConfig
# Run on the GPU when torch can see one, otherwise fall back to the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Decorative rule drawn above and below the banner art.
line_divider = (
".~^~._.~^~._.~^~._.~^~._.~^~._.~^~._.~^~._.~^~._.~^~._.~^~._.~^~._.~^~._.~^~._-~"
)

# Banner shown at the top of the page. This is a *raw* f-string so every
# backslash in the ASCII art is kept literally. In a plain f-string the "\" at
# the end of the third art row is a line continuation that silently merges two
# rows, and sequences such as "\/" and "\_" are invalid escapes that raise
# SyntaxWarning on recent Python versions.
header = rf"""
{line_divider}
______ __ __ _
/ ____/____ / // /(_)___ ____ ___
/ / / __ \/ // // / __ \/ __ \/ _ \
| |___/ /_/ / // // / /_/ / /_/ / ___/
\____/\__._/_//_//_/\____/ .___/\___/
/_/
-- your personal muse --
{line_divider}
"""
def setup(model_path: str, tokenizer_name: str = "gpt2"):
    """Load the GPT checkpoint at *model_path* together with its tokenizer.

    Args:
        model_path: Path to a file readable by ``torch.load`` that holds a
            ``"model_args"`` dict (passed as keyword arguments to ``GPTConfig``)
            and a ``"model"`` state dict.
        tokenizer_name: Name or path handed to
            ``AutoTokenizer.from_pretrained``; defaults to ``"gpt2"``.

    Returns:
        A ``(model, tokenizer)`` tuple with the model moved to ``DEVICE`` and
        switched to eval mode.
    """
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
    # Map every stored tensor straight onto the device we will run on. This also
    # covers GPU checkpoints saved on a CUDA index that does not exist here.
    checkpoint = torch.load(model_path, map_location=DEVICE)
    model = GPT(GPTConfig(**checkpoint["model_args"]))
    # State dicts saved from a torch.compile()-wrapped model carry an
    # "_orig_mod." prefix on every key; strip exactly that prefix so the keys
    # match the plain module's parameter names.
    prefix = "_orig_mod."
    state_dict = {
        (key[len(prefix):] if key.startswith(prefix) else key): val
        for key, val in checkpoint["model"].items()
    }
    model.load_state_dict(state_dict)
    model.to(DEVICE)
    model.eval()
    return model, tokenizer
# Load the model and tokenizer once, at import time, so every request reuses them.
model, tokenizer = setup(model_path="checkpoints/Calliope-nano.pt")
def generate(
    message,
    max_tokens=128,
    temperature=0.8,
):
    """Continue *message* with the loaded model and return the full text.

    Args:
        message: Prompt text. An empty string (or ``None``) requests an
            unprompted sample.
        max_tokens: Number of new tokens to sample; coerced to ``int`` so
            float values coming from UI widgets are accepted.
        temperature: Sampling temperature forwarded to ``model.generate``.

    Returns:
        The decoded prompt followed by the generated continuation.
    """
    prompt_ids = tokenizer.encode(message or "", add_special_tokens=False)
    # An empty prompt would hand the model a zero-length context, which it
    # presumably cannot sample from (TODO confirm against model.GPT.generate).
    # Seed it with the tokenizer's end-of-text token and drop that seed again
    # from the decoded output.
    seeded = not prompt_ids
    if seeded:
        prompt_ids = [tokenizer.eos_token_id]
    context = torch.tensor([prompt_ids], device=DEVICE)
    # Inference only: do not build an autograd graph while sampling.
    with torch.no_grad():
        idx = model.generate(
            context,
            max_new_tokens=int(max_tokens),
            temperature=temperature,
        )
    out_ids = idx[0].tolist()
    if seeded:
        out_ids = out_ids[1:]
    return tokenizer.decode(out_ids)
# One textbox in (the prompt), one textbox out (prompt plus continuation),
# wired to `generate`. It is not launched directly; the Blocks layout renders it.
# NOTE(review): newer Gradio releases deprecate `allow_flagging` in favour of
# `flagging_mode`; confirm against the Gradio version this Space pins.
app = gr.Interface(
    generate,
    [
        gr.Textbox(
            label="your starting point..",
            lines=16,
            placeholder="rain falls slowly on your darling cheeks",
        ),
    ],
    [gr.Textbox(label="calliope continues..", lines=16)],
    allow_flagging="never",
)
# Page layout: the centred ASCII banner on top, followed by the generation
# interface built above.
with gr.Blocks() as demo:
    gr.HTML(f"<div style='text-align: center;'><pre>{header}</pre></div>")
    app.render()

if __name__ == "__main__":
    # Start the Gradio server only when run as a script, not when imported.
    demo.launch()