import streamlit as st
import torch
from transformers import AutoTokenizer
from dataclasses import dataclass
from huggingface_hub import hf_hub_download

from src.model import SmolLM
def greedy_decode(model, input_ids, max_length=100, tokenizer=None):
    """Autoregressively append the argmax token until max_length or EOS."""
    current_ids = input_ids
    with torch.no_grad():
        for _ in range(max_length - current_ids.shape[1]):
            outputs = model(current_ids)
            # Logits for the last position: shape (batch, vocab_size)
            last_token_logits = outputs[:, -1, :]
            # keepdim=True gives shape (batch, 1) so it can be concatenated
            next_token = torch.argmax(last_token_logits, dim=-1, keepdim=True)
            current_ids = torch.cat([current_ids, next_token], dim=1)
            # Stop early on end-of-sequence (only if a tokenizer was provided)
            if tokenizer is not None and next_token.item() == tokenizer.eos_token_id:
                break
    return current_ids
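

# Not part of the original app: a minimal temperature / top-k sampling
# alternative to greedy_decode, sketched under the assumption that the model
# exposes the same `model(input_ids) -> logits` interface used above.
def sample_decode(model, input_ids, max_length=100, tokenizer=None,
                  temperature=0.8, top_k=50):
    current_ids = input_ids
    with torch.no_grad():
        for _ in range(max_length - current_ids.shape[1]):
            # Temperature-scaled logits for the last position
            logits = model(current_ids)[:, -1, :] / temperature
            # Restrict sampling to the top_k most likely tokens
            top_logits, top_idx = torch.topk(logits, k=top_k, dim=-1)
            probs = torch.softmax(top_logits, dim=-1)
            # Sample within the top-k set, then map back to vocabulary ids
            next_token = top_idx.gather(-1, torch.multinomial(probs, num_samples=1))
            current_ids = torch.cat([current_ids, next_token], dim=1)
            if tokenizer is not None and next_token.item() == tokenizer.eos_token_id:
                break
    return current_ids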
def generate_prediction(model, prompt, max_length=100):
    # Load tokenizer (note: reloaded on every call; a cached alternative is
    # sketched below)
    tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/SmolLM2-135M")
    tokenizer.pad_token = tokenizer.eos_token

    device = next(model.parameters()).device
    input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)

    model.eval()
    with torch.no_grad():
        generated_ids = greedy_decode(
            model, input_ids, max_length=max_length, tokenizer=tokenizer
        )
    generated_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
    return generated_text
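

# Optional optimization (not in the original app): cache the tokenizer so it
# is loaded once per server process instead of on every generation. Assumes a
# Streamlit version that provides st.cache_resource (1.18+).
@st.cache_resource
def load_tokenizer():
    tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/SmolLM2-135M")
    tokenizer.pad_token = tokenizer.eos_token
    return tokenizer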
def main():
    # Set page configuration
    st.set_page_config(page_title="SmolLM2-TextGen", page_icon="🤖")

    # Title and description
    st.title("SmolLM2-TextGen 🤖")
    st.write("Generate text using the SmolLM2 language model")

    # Build the model from its configuration
    def load_model(config):
        model = SmolLM(config)
        return model

    # Try to load the model
    try:
        @dataclass
        class MainConfig:
            # SmolLM2-135M dimensions; 9 query heads sharing 3 KV heads
            # means grouped-query attention with groups of 3.
            vocab_size: int = 49152
            emb_dim: int = 576
            intermediate_size: int = 1536
            num_layers: int = 30
            n_q_heads: int = 9
            n_kv_heads: int = 3
            max_seq_len: int = 1024
            dropout: float = 0.1
            rms_norm_eps: float = 1e-05
            init_std: float = 0.041666666666666664  # 1/sqrt(emb_dim) = 1/24

        config = MainConfig()
        model = load_model(config)

        # Download the checkpoint from the Hugging Face Hub and load its weights
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        model_repo = "Adityak204/SmolLM2-135-cosmopedia-10k"
        model_filename = "smolLM-v2.pth"
        checkpoint_path = hf_hub_download(repo_id=model_repo, filename=model_filename)
        checkpoint = torch.load(checkpoint_path, map_location=device)[
            "model_state_dict"
        ]
        model.load_state_dict(checkpoint)
        # The original never moved the model to the selected device
        model.to(device)
    except Exception as e:
        st.error(f"Error loading model: {e}")
        return
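
    # Optional sanity check (not in the original): report the parameter count
    # so a mismatched config or checkpoint is easier to spot.
    n_params = sum(p.numel() for p in model.parameters())
    st.caption(f"Loaded {n_params / 1e6:.1f}M parameters on {device}")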
    # Input prompt
    prompt = st.text_input(
        "Enter your prompt:", placeholder="Type a sentence to generate text..."
    )

    # Max length slider
    max_length = st.slider(
        "Maximum Generation Length", min_value=10, max_value=200, value=100, step=10
    )

    # Generate button
    if st.button("Generate Text"):
        if not prompt:
            st.warning("Please enter a prompt.")
            return

        # Show loading spinner while generating
        with st.spinner("Generating text..."):
            try:
                generated_text = generate_prediction(model, prompt, max_length)

                # Display generated text
                st.subheader("Generated Text:")
                st.write(generated_text)
            except Exception as e:
                st.error(f"An error occurred during text generation: {e}")


if __name__ == "__main__":
    main()
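
# To run the app locally (assuming this file is saved as app.py and
# src/model.py provides SmolLM):
#   streamlit run app.py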