import gradio as gr
import spaces
import json
from transformers import pipeline
import torch
from typing import Optional

# Global cache so each pipeline is only loaded once
model_cache = {}

# Available models
AVAILABLE_MODELS = {
    "Apollo-1-4B": "Loom-Labs/Apollo-1-4B",
    "Apollo-1-8B": "Loom-Labs/Apollo-1-8B",
    "Apollo-1-2B": "Loom-Labs/Apollo-1-2B",
    "Daedalus-1-2B": "Loom-Labs/Daedalus-1-2B",
    "Daedalus-1-8B": "Loom-Labs/Daedalus-1-8B",
}


@spaces.GPU
def initialize_model(model_name):
    """Load the text-generation pipeline for a model, reusing the cache when possible."""
    global model_cache

    if model_name not in AVAILABLE_MODELS:
        raise ValueError(f"Model {model_name} not found in available models")

    model_id = AVAILABLE_MODELS[model_name]

    # Check if the model is already cached
    if model_id not in model_cache:
        try:
            model_cache[model_id] = pipeline(
                "text-generation",
                model=model_id,
                torch_dtype=torch.float16,
                device_map="auto",
                trust_remote_code=True
            )
        except Exception:
            # Fall back to CPU if GPU loading fails
            model_cache[model_id] = pipeline(
                "text-generation",
                model=model_id,
                torch_dtype=torch.float32,
                device_map="cpu",
                trust_remote_code=True
            )

    return model_cache[model_id]


@spaces.GPU
def generate_response(message, history, model_name, max_length=512, temperature=0.7, top_p=0.9):
    """Generate a response using the selected model."""
    # Initialize the model inside the GPU-decorated function
    try:
        model_pipe = initialize_model(model_name)
    except Exception as e:
        return f"Error loading model {model_name}: {str(e)}"

    # Format the conversation history
    messages = []
    for user_msg, assistant_msg in history:
        messages.append({"role": "user", "content": user_msg})
        if assistant_msg:
            messages.append({"role": "assistant", "content": assistant_msg})

    # Add the current message
    messages.append({"role": "user", "content": message})

    # Generate the response
    try:
        # Some models may not support the chat-messages format, so try both approaches
        try:
            # Try the messages format first
            response = model_pipe(
                messages,
                max_length=max_length,
                temperature=temperature,
                top_p=top_p,
                do_sample=True,
                pad_token_id=model_pipe.tokenizer.eos_token_id,
                return_full_text=False
            )
        except Exception:
            # Fall back to a plain-text prompt
            conversation_text = ""
            for msg in messages:
                if msg["role"] == "user":
                    conversation_text += f"User: {msg['content']}\n"
                else:
                    conversation_text += f"Assistant: {msg['content']}\n"
            conversation_text += "Assistant:"

            response = model_pipe(
                conversation_text,
                max_length=max_length,
                temperature=temperature,
                top_p=top_p,
                do_sample=True,
                pad_token_id=model_pipe.tokenizer.eos_token_id,
                return_full_text=False
            )

        # Extract the generated text
        if isinstance(response, list) and len(response) > 0:
            generated_text = response[0]['generated_text']
        else:
            generated_text = str(response)

        # Clean up the response
        if isinstance(generated_text, list):
            assistant_response = generated_text[-1]['content']
        else:
            # Strip the prompt and extract the assistant's reply
            assistant_response = str(generated_text).strip()
            if "Assistant:" in assistant_response:
                assistant_response = assistant_response.split("Assistant:")[-1].strip()

        return assistant_response

    except Exception as e:
        return f"Error generating response: {str(e)}"


@spaces.GPU
def generate(
    model: str,
    user_input: str,
    history: Optional[str] = "",
    temperature: float = 0.7,
    system_prompt: Optional[str] = "",
    max_tokens: int = 512
):
    """
    API endpoint for LLM generation.

    Args:
        model: Model name to use (one of the keys in AVAILABLE_MODELS, e.g.
            Apollo-1-2B, Apollo-1-4B, Apollo-1-8B, Daedalus-1-2B, Daedalus-1-8B)
        user_input: Current user message/input
        history: JSON string of conversation history in the format
            [{"role": "user", "content": "..."}, {"role": "assistant", "content": "..."}]
        temperature: Temperature for generation (0.1-2.0)
        system_prompt: System prompt to guide the model
        max_tokens: Maximum tokens to generate (1-8192)

    Returns:
        Generated response from the model
    """
    # Validate the model name
    if model not in AVAILABLE_MODELS:
        return f"Error: Model {model} not available. Available models: {list(AVAILABLE_MODELS.keys())}"

    # Initialize the model (also validates that it loads)
    try:
        model_pipe = initialize_model(model)
    except Exception as e:
        return f"Error loading model {model}: {str(e)}"

    # Parse the history if provided and convert it to Gradio's pair format
    gradio_history = []
    if history and history.strip():
        try:
            history_list = json.loads(history)
            current_pair = [None, None]
            for msg in history_list:
                if isinstance(msg, dict) and "role" in msg and "content" in msg:
                    if msg["role"] == "user":
                        if current_pair[0] is not None:
                            gradio_history.append([current_pair[0], current_pair[1]])
                        current_pair = [msg["content"], None]
                    elif msg["role"] == "assistant":
                        current_pair[1] = msg["content"]
            if current_pair[0] is not None:
                gradio_history.append([current_pair[0], current_pair[1]])
        except Exception:
            # If history parsing fails, continue without history
            pass

    # Prepend the system prompt to the user input if provided
    final_user_input = user_input
    if system_prompt and system_prompt.strip():
        final_user_input = f"System: {system_prompt}\n\nUser: {user_input}"

    # Delegate to the original generate_response function
    return generate_response(final_user_input, gradio_history, model, max_tokens, temperature, 0.9)
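
# Example call to the `generate` helper (illustrative only; the values below
# are placeholders, not part of the app). Note that `history` is a JSON
# string of {"role": ..., "content": ...} messages:
#
#     reply = generate(
#         model="Apollo-1-2B",
#         user_input="What is the capital of France?",
#         history='[{"role": "user", "content": "Hi"}, '
#                 '{"role": "assistant", "content": "Hello! How can I help?"}]',
#         temperature=0.7,
#         system_prompt="You are a concise assistant.",
#         max_tokens=256,
#     )
#     print(reply)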

# Create the Gradio interface
def create_interface():
    with gr.Blocks(title="Multi-Model Chat") as demo:
        gr.Markdown("""
# 🚀 Loom Labs Model Chat Interface

Chat with the models by Loom Labs.

**Available Models:**
- Apollo-1-4B (4 billion parameters)
- Apollo-1-8B (8 billion parameters)
- Apollo-1-2B (2 billion parameters)
- Daedalus-1-2B (2 billion parameters)
- Daedalus-1-8B (8 billion parameters)
        """)

        with gr.Row():
            model_selector = gr.Dropdown(
                choices=list(AVAILABLE_MODELS.keys()),
                value="Apollo-1-4B",
                label="Select Model",
                info="Choose which model to use for generation"
            )

        chatbot = gr.Chatbot(
            height=400,
            placeholder="Select a model and start chatting...",
            label="Chat"
        )

        msg = gr.Textbox(
            placeholder="Type your message here...",
            label="Message",
            lines=2
        )

        with gr.Row():
            submit_btn = gr.Button("Send", variant="primary")
            clear_btn = gr.Button("Clear Chat", variant="secondary")

        with gr.Accordion("Advanced Settings", open=False):
            max_length = gr.Slider(
                minimum=200,
                maximum=8192,
                value=2048,
                step=50,
                label="Max Length",
                info="Maximum length of generated response"
            )
            temperature = gr.Slider(
                minimum=0.1,
                maximum=2.0,
                value=0.7,
                step=0.1,
                label="Temperature",
                info="Controls randomness in generation"
            )
            top_p = gr.Slider(
                minimum=0.1,
                maximum=1.0,
                value=0.9,
                step=0.1,
                label="Top P",
                info="Controls diversity via nucleus sampling"
            )

        # Event handlers
        def user_message(message, history):
            return "", history + [[message, None]]

        def bot_response(history, model_name, max_len, temp, top_p):
            if history:
                user_message = history[-1][0]
                bot_message = generate_response(
                    user_message,
                    history[:-1],
                    model_name,
                    max_len,
                    temp,
                    top_p
                )
                history[-1][1] = bot_message
            return history

        def model_changed(model_name):
            return gr.update(placeholder=f"Chat with {model_name}...")

        # Wire up the events
        msg.submit(user_message, [msg, chatbot], [msg, chatbot]).then(
            bot_response,
            [chatbot, model_selector, max_length, temperature, top_p],
            chatbot
        )
        submit_btn.click(user_message, [msg, chatbot], [msg, chatbot]).then(
            bot_response,
            [chatbot, model_selector, max_length, temperature, top_p],
            chatbot
        )
        clear_btn.click(lambda: None, None, chatbot, queue=False)
        model_selector.change(model_changed, model_selector, chatbot)

    return demo


# Launch the app
if __name__ == "__main__":
    demo = create_interface()
    # Enable API access and launch
    demo.launch(share=True)
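
# Client-side sketch (assumptions: the Space URL below is a placeholder, and
# the exposed endpoint names depend on how Gradio auto-names the events; run
# Client.view_api() against the deployed app to list them before calling).
#
#     from gradio_client import Client
#
#     client = Client("https://<your-space-url>")  # placeholder URL
#     print(client.view_api())  # discover available endpoints and parameters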