import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

# Load the custom fine-tuned model
model_path = "alexdev404/gpt2-finetuned-chat"
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForCausalLM.from_pretrained(model_path)

# GPT-2 has no pad token by default; reuse the EOS token for padding
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
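
# Optional safeguards. GPT-2's tokenizer truncates from the right by default,
# which would drop the newest messages instead of the oldest when the prompt
# exceeds max_length; truncating from the left keeps the most recent turns.
tokenizer.truncation_side = "left"

# Inference only: disable dropout. A GPU move (model.to("cuda")) would also
# require moving the tokenized inputs in generate_response, so it is left out
# here since the Space may be CPU-only.
model.eval()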

def generate_response(prompt, system_message, conversation_history=None,
                      max_tokens=75, temperature=0.78, top_p=0.85,
                      repetition_penalty=1.031, top_k=55):
    """Generate a reply using the custom <|start|>/<|message|>/<|end|> training format."""
    # Seed the context with the system message, framed as an opening
    # user/assistant exchange to match the training data. Doing this
    # unconditionally ensures even the first turn sees the system prompt.
    context = (
        f"<|start|>User:<|message|>{system_message}<|end|>\n"
        "<|start|>Assistant:<|message|>Hey, what's up nice to meet you. I'm glad to be here!<|end|>\n"
    )

    if conversation_history:
        # Use enough history to fill GPT-2's 1024-token context window:
        # at roughly 20-30 tokens per exchange, ~30 messages fit comfortably.
        recent = conversation_history[-30:]
        for message in recent:
            if message['role'] == 'user':
                context += f"<|start|>User:<|message|>{message['content']}<|end|>\n"
            else:
                context += f"<|start|>Assistant:<|message|>{message['content']}<|end|>\n"

    # Format the input to match training: open an Assistant turn for the model to complete
    formatted_input = f"{context}<|start|>User:<|message|>{prompt}<|end|>\n<|start|>Assistant:<|message|>"

    # Debug: print the formatted input
    print(f"Formatted input: {repr(formatted_input)}")

    inputs = tokenizer(
        formatted_input,
        return_tensors="pt",
        padding=True,
        truncation=True,
        # Leave headroom in GPT-2's 1024-token window for the generated tokens
        max_length=1024 - int(max_tokens),
    )

    with torch.no_grad():
        outputs = model.generate(
            inputs.input_ids,
            attention_mask=inputs.attention_mask,
            max_new_tokens=int(max_tokens),  # Gradio sliders return floats; generate expects ints
            temperature=temperature,
            top_p=top_p,
            top_k=int(top_k),  # sample only from the k most likely tokens
            do_sample=True,
            pad_token_id=tokenizer.pad_token_id,
            repetition_penalty=repetition_penalty,
            # Stop at "<|end|>"; [0] assumes it encodes to a single id, which
            # holds if it was added as a special token during fine-tuning
            eos_token_id=tokenizer.encode("<|end|>", add_special_tokens=False)[0],
        )

    # Decode only the newly generated tokens (everything after the prompt)
    new_tokens = outputs[0][inputs.input_ids.shape[-1]:]
    response = tokenizer.decode(new_tokens, skip_special_tokens=False)
    return response.strip()
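
# Hypothetical smoke test (not part of the app flow):
#   print(generate_response("Hi, how are you?", "You are Grace.", conversation_history=[]))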

def respond(
    message,
    history: list[dict[str, str]],
    system_message,
    max_tokens,
    temperature,
    top_p,
    repetition_penalty,
    top_k,
):
    """
    Chat handler backed by the local fine-tuned GPT-2 model instead of the
    Hugging Face Inference API.
    """
    # Gradio's messages-format history already matches what generate_response
    # expects: [{"role": "user", "content": "..."}, {"role": "assistant", "content": "..."}]
    conversation_history = history

    # Debug: log what is being sent to the model
    print(f"User message: {message}")
    print(f"History length: {len(conversation_history)}")

    # Generate the reply with the local model
    response = generate_response(
        message,
        system_message,
        conversation_history,
        max_tokens=max_tokens,
        temperature=temperature,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        top_k=top_k,
    )
| # print(f"Raw response: {repr(response)}") | |
| # Clean up the response | |
| if "<|end|>" in response: | |
| response = response.split("<|end|>")[0] | |
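    # Illustrative (made-up) example: a raw completion such as
    # "Not much, just got home.<|end|>\n<|start|>User:" becomes "Not much, just got home."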

    # Splitting at <|end|> already removes the trailing special tokens, so no
    # further scrubbing is needed here.
    return response.strip()
| """ | |
| Gradio ChatInterface for your custom GPT-2 model | |
| """ | |
| chatbot = gr.ChatInterface( | |
| respond, | |
| type="messages", | |
| title="Chat with the model", | |
| description="Chat with the GPT-2-based model trained on WhatsApp data", | |
| additional_inputs=[ | |
| gr.Textbox(value="Hey I\'m Alice and you\'re Grace. You are having a casual peer-to-peer conversation with someone. Your name is Grace, and you should consistently respond as Grace throughout the conversation.\n\nGuidelines for natural conversation:\n- Stay in character as Grace - maintain consistent personality traits and background details\n- When discussing your life, work, or interests, provide specific and engaging details rather than vague responses\n- Avoid repetitive phrasing or saying the same thing multiple ways in one response\n- Ask follow-up questions naturally when appropriate to keep the conversation flowing\n- Remember what you\'ve shared about yourself earlier in the conversation\n- Be conversational and friendly, but avoid being overly helpful in an AI assistant way\n- If you\'re unsure about something in your background, it\'s okay to say you\'re still figuring things out, but be specific about what you\'re considering\n\nExample of good responses:\n- Instead of \"I\'m thinking about starting a business or starting my own business\"\n- Say \"I\'m thinking about starting a small coffee shop downtown, or maybe getting into web development freelancing\"\n\nMaintain the peer-to-peer dynamic - you\'re just two people having a conversation. The user has entered the chat. Introduce yourself.", label="System message"), | |
        gr.Slider(minimum=10, maximum=150, value=75, step=5, label="Max new tokens"),
        gr.Slider(minimum=0.01, maximum=1.2, value=0.7, step=0.01, label="Temperature"),
        gr.Slider(
            minimum=0.01,
            maximum=1.0,
            value=0.85,
            step=0.01,
            label="Top-p (nucleus sampling)",
        ),
        gr.Slider(
            minimum=1.0,
            maximum=1.5,
            value=1.031,
            step=0.001,
            label="Repetition penalty",
        ),
        gr.Slider(
            minimum=1,
            maximum=100,
            value=55,
            step=1,
            label="Top-k sampling",
        ),
    ],
)

with gr.Blocks(theme=gr.themes.Soft()) as demo:
    chatbot.render()

if __name__ == "__main__":
    demo.launch(
        server_name="0.0.0.0",  # accessible from other devices on the network
        server_port=7860,       # the default Gradio port
        share=False,            # set to True for a public shareable link
        debug=True,
    )