# gpt2-chat-base / app.py
# Author: AlexDev404
# Commit 64224f9 (unverified): "Add additional sliders for control"
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
# Fine-tuned GPT-2 chat checkpoint. The tokenizer and the model weights are
# both loaded from this same repo id / path.
model_path = "alexdev404/gpt2-finetuned-chat"
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForCausalLM.from_pretrained(model_path)

# Reuse the EOS token for padding when the tokenizer defines no pad token
# (typical for GPT-2 tokenizers), so pad_token_id is never None downstream.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
# Canned assistant reply that follows the system message in the prompt preamble.
_GREETING_REPLY = "Hey, what's up nice to meet you. I'm glad to be here!"

# Opening of the assistant turn the model is asked to complete.
_ASSISTANT_CUE = "<|start|>Assistant:<|message|>"


def _format_turn(role, content):
    """Render one message in the training format; any role but 'user' is the assistant."""
    speaker = "User" if role == "user" else "Assistant"
    return f"<|start|>{speaker}:<|message|>{content}<|end|>\n"


def _build_prompt(prompt, system_message, history):
    """Assemble preamble, history and the new user turn, ending with an open assistant turn."""
    parts = []
    if system_message:
        # The system message is presented as a completed user/assistant exchange.
        parts.append(_format_turn("user", system_message))
        parts.append(_format_turn("assistant", _GREETING_REPLY))
    parts.extend(_format_turn(message["role"], message["content"]) for message in history)
    parts.append(_format_turn("user", prompt))
    parts.append(_ASSISTANT_CUE)
    return "".join(parts)


def generate_response(
    prompt,
    system_message,
    conversation_history=None,
    max_tokens=75,
    temperature=0.78,
    top_p=0.85,
    repetition_penalty=1.031,
    top_k=55,
    max_history_messages=30,
    max_input_tokens=512,
):
    """Generate one assistant reply using the model's custom chat training format.

    Every turn is rendered as ``<|start|>Role:<|message|>text<|end|>\\n``.
    The prompt is built in this order:

    1. An optional preamble: the system message as a user turn, followed by a
       canned assistant greeting.
    2. The most recent ``max_history_messages`` history messages.
    3. The new user ``prompt``.
    4. An open ``<|start|>Assistant:<|message|>`` turn for the model to
       complete.

    Args:
        prompt: The new user message.
        system_message: Instructions prepended on every turn, including the
            first. Skipped when empty.
        conversation_history: Prior messages as dicts with ``'role'`` and
            ``'content'`` keys, or None. Any role other than ``'user'`` is
            rendered as Assistant.
        max_tokens: Maximum number of new tokens to generate.
        temperature: Sampling temperature forwarded to ``model.generate``.
        top_p: Nucleus-sampling threshold forwarded to ``model.generate``.
        repetition_penalty: Repetition penalty forwarded to ``model.generate``.
        top_k: Top-k sampling cutoff forwarded to ``model.generate``.
        max_history_messages: Upper bound on how many recent history messages
            are included.
        max_input_tokens: Token budget for the prompt. When it is exceeded,
            the oldest history messages are dropped first. If the prompt is
            still too long, only its last ``max_input_tokens`` tokens are
            kept, so the new user turn and the open assistant turn always
            reach the model. Prompt plus ``max_tokens`` should fit the model's
            context window (1024 positions for GPT-2; confirm for this
            checkpoint).

    Returns:
        The decoded newly generated tokens, whitespace-stripped. Special tokens
        are kept, so the text may contain ``<|end|>``.
    """
    history = list(conversation_history or [])
    recent = history[-max_history_messages:] if max_history_messages > 0 else []

    formatted_input = _build_prompt(prompt, system_message, recent)
    input_ids = tokenizer(formatted_input, return_tensors="pt").input_ids
    # Drop the oldest messages (never the preamble or the new turn) until the prompt fits.
    while input_ids.shape[-1] > max_input_tokens and recent:
        recent = recent[1:]
        formatted_input = _build_prompt(prompt, system_message, recent)
        input_ids = tokenizer(formatted_input, return_tensors="pt").input_ids
    # Still too long (very long system message or prompt): keep the most recent
    # tokens. Truncating on the right would cut off the new user turn and the
    # open assistant turn the model is meant to complete.
    if input_ids.shape[-1] > max_input_tokens:
        input_ids = input_ids[:, -max_input_tokens:]
    # A single unpadded sequence, so every position is attended to.
    attention_mask = torch.ones_like(input_ids)

    # Debug: show exactly what the model is prompted with.
    print(f"Formatted input: {repr(formatted_input)}")

    with torch.no_grad():
        outputs = model.generate(
            input_ids,
            attention_mask=attention_mask,
            # Slider values may arrive as floats; these must be integers.
            max_new_tokens=int(max_tokens),
            temperature=temperature,
            top_p=top_p,
            top_k=int(top_k),
            do_sample=True,
            pad_token_id=tokenizer.pad_token_id,
            repetition_penalty=repetition_penalty,
            # NOTE(review): taking [0] is only correct if "<|end|>" is a single
            # token in this tokenizer; otherwise generation stops on its first
            # sub-token. Confirm against the fine-tuned vocabulary.
            eos_token_id=tokenizer.encode("<|end|>", add_special_tokens=False)[0],
        )

    # Decode only the tokens generated after the prompt.
    new_tokens = outputs[0][input_ids.shape[-1]:]
    return tokenizer.decode(new_tokens, skip_special_tokens=False).strip()
def respond(
    message,
    history: list[dict[str, str]],
    system_message,
    max_tokens,
    temperature,
    top_p,
    repetition_penalty,
    top_k,
):
    """Gradio ChatInterface callback: answer ``message`` with the custom GPT-2 model.

    ``history`` arrives in Gradio's "messages" format, a list of
    ``{"role": ..., "content": ...}`` dicts. It is handed to generate_response
    unchanged. The remaining parameters are the ChatInterface's additional
    inputs, in the same order.

    Returns the generated reply, cut at the first ``<|end|>`` marker and
    whitespace-stripped.
    """
    # Debug output for each incoming turn.
    print(f"User message: {message}")
    print(f"History length: {len(history)}")

    raw_reply = generate_response(
        message,
        system_message,
        history,
        max_tokens=max_tokens,
        temperature=temperature,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        top_k=top_k,
    )

    # Keep only the text before the first end-of-turn marker (the whole text if absent).
    reply, _, _ = raw_reply.partition("<|end|>")
    return reply.strip()


# Gradio ChatInterface for your custom GPT-2 model
# Default persona/instructions shown in the editable "System message" box.
_DEFAULT_SYSTEM_MESSAGE = "Hey I\'m Alice and you\'re Grace. You are having a casual peer-to-peer conversation with someone. Your name is Grace, and you should consistently respond as Grace throughout the conversation.\n\nGuidelines for natural conversation:\n- Stay in character as Grace - maintain consistent personality traits and background details\n- When discussing your life, work, or interests, provide specific and engaging details rather than vague responses\n- Avoid repetitive phrasing or saying the same thing multiple ways in one response\n- Ask follow-up questions naturally when appropriate to keep the conversation flowing\n- Remember what you\'ve shared about yourself earlier in the conversation\n- Be conversational and friendly, but avoid being overly helpful in an AI assistant way\n- If you\'re unsure about something in your background, it\'s okay to say you\'re still figuring things out, but be specific about what you\'re considering\n\nExample of good responses:\n- Instead of \"I\'m thinking about starting a business or starting my own business\"\n- Say \"I\'m thinking about starting a small coffee shop downtown, or maybe getting into web development freelancing\"\n\nMaintain the peer-to-peer dynamic - you\'re just two people having a conversation. The user has entered the chat. Introduce yourself."

# Extra controls passed to respond() after (message, history). The order must
# match respond's parameters: system_message, max_tokens, temperature, top_p,
# repetition_penalty, top_k.
_generation_controls = [
    gr.Textbox(value=_DEFAULT_SYSTEM_MESSAGE, label="System message"),
    gr.Slider(minimum=10, maximum=150, value=75, step=5, label="Max new tokens"),
    gr.Slider(minimum=0.01, maximum=1.2, value=0.7, step=0.01, label="Temperature"),
    gr.Slider(minimum=0.01, maximum=1.0, value=0.85, step=0.01, label="Top-p (nucleus sampling)"),
    gr.Slider(minimum=1.0, maximum=1.5, value=1.031, step=0.001, label="Repetition penalty"),
    gr.Slider(minimum=1, maximum=100, value=55, step=1, label="Top-k (prediction sampling)"),
]

chatbot = gr.ChatInterface(
    respond,
    type="messages",
    title="Chat with the model",
    description="Chat with the GPT-2-based model trained on WhatsApp data",
    additional_inputs=_generation_controls,
)
# Host the chat interface inside a Blocks app so the Soft theme is applied.
demo = gr.Blocks(theme=gr.themes.Soft())
with demo:
    chatbot.render()

if __name__ == "__main__":
    launch_options = dict(
        server_name="0.0.0.0",  # Makes it accessible from other devices on your network
        server_port=7860,  # Default gradio port
        share=False,  # Set to True to get a public shareable link
        debug=True,
    )
    demo.launch(**launch_options)