import os

import streamlit as st
import torch
from huggingface_hub import login
from transformers import AutoTokenizer, AutoModelForCausalLM
def initialize_model():
    """Initialize the model and tokenizer with CPU support."""
    # Log in to Hugging Face if a token is available
    token = os.environ.get("hf")
    if token:
        login(token)

    model_id = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"

    # Load tokenizer
    tokenizer = AutoTokenizer.from_pretrained(model_id)

    try:
        # Try with regular CPU mode first (simpler and more reliable)
        model = AutoModelForCausalLM.from_pretrained(
            model_id,
            device_map="cpu",
            trust_remote_code=True,
            low_cpu_mem_usage=True,
        )
    except Exception as e:
        print(f"Error loading model: {str(e)}")
        raise

    # Ensure a padding token is defined
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    return model, tokenizer
def format_prompt(user_input, conversation_history=None):
    """Format the prompt according to TinyLlama's expected chat format."""
    # Avoid a mutable default argument
    conversation_history = conversation_history or []

    messages = []
    # Add conversation history
    for turn in conversation_history:
        messages.append({"role": "user", "content": turn["user"]})
        messages.append({"role": "assistant", "content": turn["assistant"]})
    # Add current user input
    messages.append({"role": "user", "content": user_input})

    # Format into TinyLlama chat format
    formatted_prompt = "<|system|>You are a helpful AI assistant.</s>"
    for message in messages:
        if message["role"] == "user":
            formatted_prompt += f"<|user|>{message['content']}</s>"
        else:
            formatted_prompt += f"<|assistant|>{message['content']}</s>"
    formatted_prompt += "<|assistant|>"
    return formatted_prompt
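
# Illustrative example (not part of the original source): for user_input="Hello"
# and an empty history, format_prompt returns a single string of the form
#   "<|system|>You are a helpful AI assistant.</s><|user|>Hello</s><|assistant|>"
# which the model then completes with the assistant's reply.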
def generate_response(model, tokenizer, prompt, conversation_history):
    """Generate a model response for the given prompt and history."""
    try:
        # Format prompt using TinyLlama's chat template
        # (the current turn is the last history entry, so exclude it here)
        formatted_prompt = format_prompt(prompt, conversation_history[:-1])

        # Tokenize input
        inputs = tokenizer(formatted_prompt, return_tensors="pt", padding=True, truncation=True)

        # Move inputs to the same device as the model
        device = next(model.parameters()).device
        inputs = {k: v.to(device) for k, v in inputs.items()}

        # Calculate how many new tokens fit in the context window
        input_length = inputs["input_ids"].shape[1]
        max_model_length = 1024
        max_new_tokens = min(150, max_model_length - input_length)

        # Generate response
        outputs = model.generate(
            inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            max_new_tokens=max_new_tokens,
            temperature=0.7,
            top_p=0.9,
            pad_token_id=tokenizer.pad_token_id,
            do_sample=True,
            min_new_tokens=10,
            no_repeat_ngram_size=3,
            eos_token_id=tokenizer.eos_token_id,  # stop at the </s> end token
        )

        # Decode response and extract only the assistant's message
        full_response = tokenizer.decode(outputs[0], skip_special_tokens=False)

        # Extract only the last assistant response
        assistant_response = full_response.split("<|assistant|>")[-1].split("</s>")[0].strip()
        return assistant_response if assistant_response else "I apologize, but I couldn't generate a proper response."
    except RuntimeError as e:
        if "out of memory" in str(e):
            torch.cuda.empty_cache()
            return "I apologize, but I ran out of memory. Please try a shorter message or clear the chat history."
        return f"An error occurred: {str(e)}"
def main():
    st.set_page_config(
        page_title="LLM Chat Interface",
        page_icon="🤖",
        layout="wide"
    )

    # Add CSS to make the chat interface more compact
    st.markdown("""
        <style>
        .stChat {
            padding-top: 0rem;
        }
        .stChatMessage {
            padding: 0.5rem;
        }
        </style>
    """, unsafe_allow_html=True)

    st.title("Chat with TinyLlama 🤖")

    # Initialize session state for chat history
    if "chat_history" not in st.session_state:
        st.session_state.chat_history = []

    # Initialize model (only once)
    if "model" not in st.session_state:
        with st.spinner("Loading the model... This might take a minute..."):
            try:
                model, tokenizer = initialize_model()
                st.session_state.model = model
                st.session_state.tokenizer = tokenizer
                st.success("Model loaded successfully!")
            except Exception as e:
                st.error(f"Error loading model: {str(e)}")
                return

    # Display chat messages
    for message in st.session_state.chat_history:
        with st.chat_message("user"):
            st.write(message["user"])
        with st.chat_message("assistant"):
            st.write(message["assistant"])

    # Chat input
    if prompt := st.chat_input("What would you like to know?"):
        # Display user message
        with st.chat_message("user"):
            st.write(prompt)

        # Generate and display assistant response
        with st.chat_message("assistant"):
            with st.spinner("Thinking..."):
                current_turn = {"user": prompt, "assistant": ""}
                st.session_state.chat_history.append(current_turn)
                response = generate_response(
                    st.session_state.model,
                    st.session_state.tokenizer,
                    prompt,
                    st.session_state.chat_history
                )
                st.write(response)
                st.session_state.chat_history[-1]["assistant"] = response

        # Manage context window: keep only the last 5 turns
        if len(st.session_state.chat_history) > 5:
            st.session_state.chat_history = st.session_state.chat_history[-5:]

    # Sidebar controls
    with st.sidebar:
        st.title("Controls")
        if st.button("Clear Chat"):
            st.session_state.chat_history = []
            st.rerun()

        st.markdown("---")
        st.markdown("""
        ### Model Info
        - Using TinyLlama 1.1B Chat
        - CPU optimized
        - Context window: 1024 tokens
        """)


if __name__ == "__main__":
    main()
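
To try the Space locally, one possible setup (assuming the script above is saved as app.py; the filename and dependency pins are not shown in the source) is roughly:

    pip install streamlit torch transformers huggingface_hub accelerate
    streamlit run app.py

The accelerate package is included because loading with device_map and low_cpu_mem_usage in transformers depends on it; the "hf" environment variable (a Hugging Face token) only matters for gated models, which TinyLlama is not.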