# JuliusCaesarGPT / app.py
import json
from dataclasses import dataclass

import streamlit as st
import torch
import torch.nn.functional as F
from huggingface_hub import hf_hub_download
from tqdm import tqdm

from src.gpt_base import GPT

# Config class for model parameters
@dataclass
class GPTConfig:
    block_size: int = 1024  # max sequence length
    vocab_size: int = 65  # character-level vocabulary size
    num_layer: int = 12  # number of layers
    num_head: int = 12  # number of attention heads
    emb_dim: int = 768  # embedding dimension
    dropout: float = 0.1  # dropout rate
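    # Note: these values must match the architecture of the trained checkpoint
    # loaded below; load_state_dict will raise on any shape mismatch.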

# The model architecture (GPT and its submodules) is imported from src.gpt_base above.

# Load character-to-index (stoi) and index-to-character (itos) mappings.
# Note: json.load leaves the itos keys as strings, so lookups below use str(i).
with open("docs/stoi.json") as f:
    stoi = json.load(f)
with open("docs/itos.json") as f:
    itos = json.load(f)

# Encoding/decoding between text and integer token ids
def encode(s):
    return [stoi[c] for c in s]


def decode(l):
    # itos keys are strings after json.load, hence str(i)
    return "".join([itos[str(i)] for i in l])


def predict_next_word(text, model, seq_len=50):
    """Greedily extend `text` by seq_len characters, one argmax step at a time."""
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.eval()  # disable dropout during generation
    with torch.no_grad():
        for _ in tqdm(range(seq_len)):
            xb = torch.tensor(encode(text)).unsqueeze(0).to(device)
            xb = xb[:, -GPTConfig.block_size :]  # stay within the context window
            yb = model(xb)
            next_word = yb[0, -1].argmax().item()
            text += itos[str(next_word)]
    return text
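

# Sketch of an alternative, non-greedy sampler (not used by the app; the
# name and the temperature/top_k defaults are illustrative). Like
# predict_next_word above, it assumes model(xb) returns raw logits of shape
# (batch, seq_len, vocab_size).
def sample_next_words(text, model, seq_len=50, temperature=0.8, top_k=20):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.eval()
    with torch.no_grad():
        for _ in range(seq_len):
            xb = torch.tensor(encode(text)).unsqueeze(0).to(device)
            xb = xb[:, -GPTConfig.block_size :]  # stay within the context window
            logits = model(xb)[0, -1] / temperature  # temperature-scale the logits
            top_vals, top_idx = torch.topk(logits, top_k)
            probs = F.softmax(top_vals, dim=-1)  # renormalize over the top-k
            next_id = top_idx[torch.multinomial(probs, 1)].item()
            text += itos[str(next_id)]
    return text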


# Streamlit app
st.title("GPT Text Generation")

# Usage instructions
st.markdown(
    """
    ### How to use:
    1. Enter your text prompt in the text box below.
    2. Adjust the sequence length using the slider.
    3. Click 'Generate Text' to see the model's output.

    Note: longer sequence lengths take more time to generate.
    """
)

# Input text box
input_text = st.text_area("Enter your text prompt:", height=100)

# Sequence length slider
seq_length = st.slider(
    "Select sequence length for prediction:",
    min_value=50,
    max_value=500,
    value=200,
    step=50,
)

# Model loading and prediction
if st.button("Generate Text"):
    if input_text:
        try:
            # Initialize the model on the available device
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
            config = GPTConfig()
            model = GPT(config)
            model = model.to(device)

            # Download the checkpoint from the Hugging Face Hub
            model_repo = "Adityak204/JuliusCaesarGPT"
            model_filename = "gpt_model_and_loss.pth"
            checkpoint_path = hf_hub_download(
                repo_id=model_repo, filename=model_filename
            )

            with st.spinner("Loading model and generating text..."):
                _dict = torch.load(checkpoint_path, map_location=device)
                model.load_state_dict(_dict["model_state_dict"])

                # Generate text
                generated_text = predict_next_word(input_text, model, seq_length)

            # Display results
            st.subheader("Generated Text:")
            st.write(generated_text)
        except Exception as e:
            st.error(f"An error occurred: {str(e)}")
    else:
        st.warning("Please enter some text first!")