import torch
import json
import tiktoken

from your_model_module import GPT, GPTConfig  # assuming your GPT class lives in your_model_module.py

# Load the model configuration from JSON
with open("config.json", "r") as f:
    config_dict = json.load(f)
config = GPTConfig(**config_dict)

# Load the trained weights; map_location keeps everything on the CPU
model = GPT(config)
model.load_state_dict(torch.load("best_model_params.pt", map_location=torch.device("cpu")))
model.eval()  # disable dropout for inference

# Load the tokenizer (GPT-2 BPE)
enc = tiktoken.get_encoding("gpt2")

def generate_text(prompt, max_new_tokens=200, temperature=1.0, top_k=None):
    # Encode the prompt and add a batch dimension: shape (1, T)
    context = torch.tensor(enc.encode_ordinary(prompt), dtype=torch.long).unsqueeze(dim=0)
    with torch.no_grad():
        generated_tokens = model.generate(context, max_new_tokens, temperature=temperature, top_k=top_k)
    return enc.decode(generated_tokens.squeeze().tolist())

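# The call above assumes your GPT class provides a nanoGPT-style generate()
# method. If yours does not, here is a minimal sketch of one, meant to live
# on the GPT class itself. It assumes the forward pass returns a
# (logits, loss) tuple and that the config exposes a block_size context
# limit — both are assumptions about your_model_module, not guarantees.
@torch.no_grad()
def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None):
    for _ in range(max_new_tokens):
        # Crop the running sequence to the model's maximum context length
        idx_cond = idx if idx.size(1) <= self.config.block_size else idx[:, -self.config.block_size:]
        logits, _ = self(idx_cond)
        # Take the logits at the last position and apply temperature scaling
        logits = logits[:, -1, :] / temperature
        if top_k is not None:
            # Mask out everything below the k-th most likely token
            v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
            logits[logits < v[:, [-1]]] = float("-inf")
        probs = torch.nn.functional.softmax(logits, dim=-1)
        # Sample the next token and append it to the sequence
        idx_next = torch.multinomial(probs, num_samples=1)
        idx = torch.cat((idx, idx_next), dim=1)
    return idx
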
if __name__ == "__main__":
    prompt = input("Enter your prompt: ")
    generated_text = generate_text(prompt)
    print(generated_text)
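
# Example usage with explicit sampling controls (values are illustrative):
#   generate_text("Once upon a time", max_new_tokens=100, temperature=0.8, top_k=50)
# Lower temperatures make sampling more deterministic; top_k restricts each
# step to the k most probable tokens.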