Spaces:
Runtime error
Runtime error
| import gradio as gr | |
| import torch | |
| from transformers import AutoTokenizer | |
| from model import GPT, GPTConfig | |
# Run on the GPU when torch can see one, otherwise fall back to the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Decorative rule drawn above and below the banner.
line_divider = (
    ".~^~._.~^~._.~^~._.~^~._.~^~._.~^~._.~^~._.~^~._.~^~._.~^~._.~^~._.~^~._.~^~._-~"
)

# ASCII-art banner rendered at the top of the page.
# Raw f-string on purpose: the art is full of backslashes. In a non-raw
# literal, sequences such as "\/" and "\_" are invalid escapes, and a backslash
# at the very end of a row is a line continuation that silently glues that row
# to the next one. `rf` keeps every backslash and newline literal while still
# interpolating {line_divider}.
header = rf"""
{line_divider}
______ __ __ _
/ ____/____ / // /(_)___ ____ ___
/ / / __ \/ // // / __ \/ __ \/ _ \
| |___/ /_/ / // // / /_/ / /_/ / ___/
\____/\__._/_//_//_/\____/ .___/\___/
/_/
-- your personal muse --
{line_divider}
"""
def setup(model_path: str):
    """Load the model checkpoint and the GPT-2 tokenizer.

    Args:
        model_path: Path to a checkpoint readable by ``torch.load``. The
            checkpoint must contain ``"model_args"`` (keyword arguments for
            ``GPTConfig``) and ``"model"`` (the state dict).

    Returns:
        A ``(model, tokenizer)`` tuple. The model has been moved to ``DEVICE``
        and put in eval mode.
    """
    tokenizer = AutoTokenizer.from_pretrained("gpt2")

    # Always remap storages onto the device we are going to run on, so a
    # checkpoint saved on a GPU (or on a GPU index that does not exist here)
    # loads the same way everywhere.
    # NOTE(review): torch.load unpickles the file; only point this at
    # checkpoints you trust.
    checkpoint = torch.load(model_path, map_location=DEVICE)
    model = GPT(GPTConfig(**checkpoint["model_args"]))

    # Rename keys because of torch >= 2.1: saved state dicts may carry an
    # "_orig_mod." prefix on each key. Strip exactly that prefix (dot
    # included) and leave every other key untouched.
    prefix = "_orig_mod."
    state_dict = {
        (key[len(prefix):] if key.startswith(prefix) else key): val
        for key, val in checkpoint["model"].items()
    }

    model.load_state_dict(state_dict)
    model.to(DEVICE)
    model.eval()
    return model, tokenizer
# Load the checkpoint once, at import time. `model` and `tokenizer` are the
# module-level objects that generate() reads on every request.
model, tokenizer = setup("checkpoints/Calliope-nano.pt")
def generate(
    message,
    max_tokens=128,
    temperature=0.8,
):
    """Continue ``message`` with text sampled from the module-level model.

    Args:
        message: Prompt text to continue.
        max_tokens: Number of new tokens to sample; coerced to ``int``.
        temperature: Sampling temperature forwarded to ``model.generate``.

    Returns:
        The decoded token sequence returned by ``model.generate``, or ``""``
        when the prompt is empty.
    """
    # An empty prompt encodes to zero tokens; torch.tensor([[]]) would then be
    # a float tensor of shape (1, 0), which is not a usable batch of token
    # ids. Bail out before touching the model.
    prompt_ids = (
        tokenizer.encode(message, add_special_tokens=False) if message else []
    )
    if not prompt_ids:
        return ""

    # Explicit long dtype: token ids are indices, never floats.
    prompt = torch.tensor([prompt_ids], dtype=torch.long, device=DEVICE)

    # Inference only, so no autograd bookkeeping is needed.
    with torch.no_grad():
        idx = model.generate(
            prompt,
            max_new_tokens=int(max_tokens),
            temperature=temperature,
        )
    return tokenizer.decode(idx[0].tolist())
# One textbox in, one textbox out: the prompt and the model's continuation.
_prompt_box = gr.Textbox(
    lines=16,
    label="your starting point..",
    placeholder="rain falls slowly on your darling cheeks",
)
_continuation_box = gr.Textbox(lines=16, label="calliope continues..")

app = gr.Interface(
    fn=generate,
    inputs=[_prompt_box],
    outputs=[_continuation_box],
    allow_flagging="never",
)

# Wrap the interface in a Blocks page so the centered banner sits above it.
_banner_html = f"<div style='text-align: center;'><pre>{header}</pre></div>"
with gr.Blocks() as demo:
    gr.HTML(_banner_html)
    app.render()

if __name__ == "__main__":
    demo.launch()