import os
import torch
import tiktoken
import gradio as gr
from transformers import GPT2Tokenizer
from model import GPTLanguageModel

# Initialize the GPT-2 tokenizer
enc = tiktoken.get_encoding("gpt2")  # tiktoken handles all encoding/decoding below
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")  # Hugging Face tokenizer; loaded but not used in this script
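# Both tokenizers target the same 50257-token GPT-2 BPE vocabulary, and
# tiktoken round-trips text through integer token IDs, e.g.:
#   ids = enc.encode("Once upon a time")  # -> list[int]
#   enc.decode(ids)                       # -> "Once upon a time"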
# Select the compute device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Model hyperparameters (must match the training configuration)
vocab_size = 50257
n_heads = 8
n_layers = 6
n_embd = 512
block_size = 128
dropout = 0.1  # Listed for reference; not passed to the constructor below
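# Sanity check (added, not in the original source): attention splits the
# embedding across heads, so n_embd must divide evenly by n_heads
# (here 512 / 8 = 64 dimensions per head).
assert n_embd % n_heads == 0, "n_embd must be divisible by n_heads"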
# Create the GPT model instance
model = GPTLanguageModel(vocab_size, n_embd, block_size, n_layers, n_heads).to(device)

# Load the trained model weights, then switch to inference mode
if os.path.exists("model_weights.pth"):
    model.load_state_dict(torch.load("model_weights.pth", map_location=device))
model.eval()
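# Hedged addition (not in the original): without trained weights the model
# generates noise, so surface the missing-file case instead of failing silently.
if not os.path.exists("model_weights.pth"):
    print("Warning: model_weights.pth not found; generating with untrained weights.")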
# Generate a response from the user's prompt
def get_response(prompt):
    # Tokenize the input prompt into a (1, T) tensor of token IDs
    context = torch.tensor([enc.encode(prompt)], dtype=torch.long, device=device)
    # Generate tokens from the model
    max_new_tokens = 200  # Number of tokens to generate
    # temperature = 0.8  # Defined in the original but never passed to generate();
    # wire it through only if GPTLanguageModel.generate accepts a temperature argument
    generated_text_idx = model.generate(context, max_new_tokens)
    # Decode the generated token IDs back into text
    generated_text = enc.decode(generated_text_idx[0].tolist())
    return generated_text
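# Quick sanity check (assumes trained weights are present; if generate() is
# written nanoGPT-style it returns the prompt tokens plus the continuation,
# so seeing the prompt echoed at the start of the output is expected):
#   print(get_response("Once upon a time"))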
def main():
    """Main function to run the app."""
    # Set up the Gradio interface
    iface = gr.Interface(fn=get_response, inputs="text", outputs="text", title="StoryCrafterLLM")
    iface.launch()

if __name__ == "__main__":
    main()
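# Running `python app.py` serves the interface on Gradio's default local URL
# (http://127.0.0.1:7860); on a Hugging Face Space, the same launch() call
# backs the hosted demo.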