import json
import math
from dataclasses import dataclass
from pathlib import Path

import streamlit as st
import torch
import torch.nn as nn
import torch.nn.functional as F
from huggingface_hub import hf_hub_download
from tqdm import tqdm

from src.gpt_base import GPT
# Config class for model parameters
@dataclass
class GPTConfig:
    block_size: int = 1024  # max sequence length
    vocab_size: int = 65  # character-level vocabulary size
    num_layer: int = 12  # number of layers
    num_head: int = 12  # number of heads
    emb_dim: int = 768  # embedding dimension
    dropout: float = 0.1  # dropout rate
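# With these defaults each of the 12 heads attends over emb_dim / num_head
# = 768 / 12 = 64 dimensions (the GPT-2-small layout); vocab_size = 65
# matches a small character-level vocabulary.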
# Model architecture: GPT is imported from src.gpt_base above, which defines
# the supporting classes (MultiHeadAttention, FeedForward, TransformerBlock).
# Load the stoi/itos vocabulary mappings from docs
with open("docs/stoi.json") as f:
    stoi = json.load(f)
with open("docs/itos.json") as f:
    itos = json.load(f)  # json.load yields string keys: "0", "1", ...

# Encoding/Decoding functions
def encode(s):
    return [stoi[c] for c in s]

def decode(l):
    return "".join([itos[str(i)] for i in l])  # str() because itos keys are strings
def predict_next_word(text, model, seq_len=50):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.eval()  # disable dropout during generation
    with torch.no_grad():  # inference only; no gradients needed
        for _ in tqdm(range(seq_len)):
            # Keep the context within the model's maximum block size
            context = encode(text)[-GPTConfig.block_size:]
            xb = torch.tensor(context).unsqueeze(0).to(device)
            yb = model(xb)
            next_word = yb[0, -1].argmax().item()
            text += itos[str(next_word)]  # itos keys are strings after json.load
    return text
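# A minimal alternative decoding sketch (not used above, shown for reference):
# sampling from the softmax distribution instead of taking the argmax tends to
# give less repetitive text. F is torch.nn.functional, imported above.
#   probs = F.softmax(yb[0, -1], dim=-1)
#   next_word = torch.multinomial(probs, num_samples=1).item()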
# Streamlit app
st.title("GPT Text Generation")

# Add some usage instructions
st.markdown(
    """
    ### How to use:
    1. Enter your text prompt in the text box below
    2. Adjust the sequence length using the slider
    3. Click 'Generate Text' to see the model's output

    Note: Longer sequence lengths will take more time to generate.
    """
)
# Input text box
input_text = st.text_area("Enter your text prompt:", height=100)

# Sequence length slider
seq_length = st.slider(
    "Select sequence length for prediction:",
    min_value=50,
    max_value=500,
    value=200,
    step=50,
)
# Model loading and prediction
if st.button("Generate Text"):
    if input_text:
        try:
            # Initialize the model on the available device
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
            config = GPTConfig()
            model = GPT(config)
            model = model.to(device)
            # Download the checkpoint from the Hugging Face Hub
            model_repo = "Adityak204/JuliusCaesarGPT"
            model_filename = "gpt_model_and_loss.pth"
            checkpoint_path = hf_hub_download(
                repo_id=model_repo, filename=model_filename
            )

            with st.spinner("Loading model and generating text..."):
                _dict = torch.load(checkpoint_path, map_location=device)
                model_state_dict = _dict["model_state_dict"]
                model.load_state_dict(model_state_dict)

                # Generate text
                generated_text = predict_next_word(input_text, model, seq_length)

            # Display results
            st.subheader("Generated Text:")
            st.write(generated_text)
        except Exception as e:
            st.error(f"An error occurred: {str(e)}")
    else:
        st.warning("Please enter some text first!")