import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
# Load your custom model
model_path = "alexdev404/gpt2-finetuned-chat"
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForCausalLM.from_pretrained(model_path)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
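
# Optional sketch (not part of the original Space): run inference on a GPU when one is
# available. If you enable this, the tokenized inputs built in generate_response() would
# also need to be moved to the same device (e.g. inputs.to(model.device)).
# device = "cuda" if torch.cuda.is_available() else "cpu"
# model = model.to(device)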
def generate_response(prompt, system_message, conversation_history=None, max_tokens=75, temperature=0.78, top_p=0.85, repetition_penalty=1.031, top_k=55):
    """Generate using your custom training format"""
    # Build context using your NEW format
    context = ""
    if conversation_history:
        # Use more conversation history to fill GPT-2's context window (1024 tokens).
        # Estimate ~20-30 tokens per exchange, so we can fit ~30-40 exchanges.
        recent = conversation_history[-30:] if len(conversation_history) > 30 else conversation_history
        is_first_message = False
        for i, message in enumerate(recent):
            if i == 0:
                is_first_message = True
                # Prepend the system message and a canned greeting as the first exchange
                context += f"<|start|>User:<|message|>{system_message}<|end|>\n<|start|>Assistant:<|message|>Hey, what's up nice to meet you. I'm glad to be here!<|end|>\n"
            if message['role'] == 'user':
                context += f"<|start|>User:<|message|>{message['content']}<|end|>\n"
            else:
                context += f"<|start|>Assistant:<|message|>{message['content']}<|end|>\n"

    # Format input to match training
    # formatted_input = None
    # if is_first_message:
    #     formatted_input = f"{context}<|start|>User:<|message|>{prompt}<|end|>\n<|start|>Assistant:<|message|>"
    # else:
    formatted_input = f"{context}<|start|>User:<|message|>{prompt}<|end|>\n<|start|>Assistant:<|message|>"

    # Debug: print the formatted input
    print(f"Formatted input: {repr(formatted_input)}")

    inputs = tokenizer(
        formatted_input,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=512
    )

    with torch.no_grad():
        outputs = model.generate(
            inputs.input_ids,
            attention_mask=inputs.attention_mask,
            max_new_tokens=max_tokens,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,  # consider only the top-k most likely tokens
            do_sample=True,
            pad_token_id=tokenizer.pad_token_id,
            repetition_penalty=repetition_penalty,
            # stop generation at the first token of the "<|end|>" marker
            eos_token_id=tokenizer.encode("<|end|>", add_special_tokens=False)[0]
        )

    # Decode only the newly generated tokens
    new_tokens = outputs[0][inputs.input_ids.shape[-1]:]
    response = tokenizer.decode(new_tokens, skip_special_tokens=False)
    return response.strip()
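
# Illustrative, commented-out smoke test for generate_response(). The history below is
# made up; it only mirrors the [{"role": ..., "content": ...}] shape that respond() passes in.
# example_history = [
#     {"role": "user", "content": "Hey, how was your weekend?"},
#     {"role": "assistant", "content": "Pretty quiet, mostly reading. Yours?"},
# ]
# print(generate_response("Nice! What are you reading?", "You are Grace.", example_history))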
def respond(
    message,
    history: list[dict[str, str]],
    system_message,
    max_tokens,
    temperature,
    top_p,
    repetition_penalty,
    top_k,
):
    """
    Modified to use your custom GPT-2 model instead of the Hugging Face Inference API
    """
    # Gradio's "messages"-type history is already in the expected format:
    # [{"role": "user", "content": "..."}, {"role": "assistant", "content": "..."}]
    conversation_history = history  # use history directly

    # Debug: print what is being sent to the model
    print(f"User message: {message}")
    print(f"History length: {len(conversation_history)}")

    # Generate a response using your model
    response = generate_response(
        message,
        system_message,
        conversation_history,
        max_tokens=max_tokens,
        temperature=temperature,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        top_k=top_k
    )
    # print(f"Raw response: {repr(response)}")

    # Clean up the response: keep only the text before the end-of-turn marker
    if "<|end|>" in response:
        response = response.split("<|end|>")[0]

    # Remove any remaining special tokens
    # response = response.replace("<|start|>", "")
    # response = response.replace("<|message|>", "")
    # response = response.replace("User:", "")
    # response = response.replace("Assistant:", "")
    # print(f"Cleaned response: {repr(response)}")
    return response.strip()
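
# Illustrative, commented-out example of calling respond() outside Gradio. The positional
# values mirror the UI defaults defined below; the messages themselves are made up.
# print(respond(
#     "What do you do for work?",
#     [{"role": "user", "content": "Hi!"}, {"role": "assistant", "content": "Hey there!"}],
#     "You are Grace, chatting casually with a peer.",
#     75,     # max_tokens
#     0.7,    # temperature
#     0.85,   # top_p
#     1.031,  # repetition_penalty
#     55,     # top_k
# ))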
"""
Gradio ChatInterface for your custom GPT-2 model
"""
chatbot = gr.ChatInterface(
    respond,
    type="messages",
    title="Chat with the model",
    description="Chat with the GPT-2-based model trained on WhatsApp data",
    additional_inputs=[
        gr.Textbox(value="Hey I\'m Alice and you\'re Grace. You are having a casual peer-to-peer conversation with someone. Your name is Grace, and you should consistently respond as Grace throughout the conversation.\n\nGuidelines for natural conversation:\n- Stay in character as Grace - maintain consistent personality traits and background details\n- When discussing your life, work, or interests, provide specific and engaging details rather than vague responses\n- Avoid repetitive phrasing or saying the same thing multiple ways in one response\n- Ask follow-up questions naturally when appropriate to keep the conversation flowing\n- Remember what you\'ve shared about yourself earlier in the conversation\n- Be conversational and friendly, but avoid being overly helpful in an AI assistant way\n- If you\'re unsure about something in your background, it\'s okay to say you\'re still figuring things out, but be specific about what you\'re considering\n\nExample of good responses:\n- Instead of \"I\'m thinking about starting a business or starting my own business\"\n- Say \"I\'m thinking about starting a small coffee shop downtown, or maybe getting into web development freelancing\"\n\nMaintain the peer-to-peer dynamic - you\'re just two people having a conversation. The user has entered the chat. Introduce yourself.", label="System message"),
        gr.Slider(minimum=10, maximum=150, value=75, step=5, label="Max new tokens"),
        gr.Slider(minimum=0.01, maximum=1.2, value=0.7, step=0.01, label="Temperature"),
        gr.Slider(
            minimum=0.01,
            maximum=1.0,
            value=0.85,
            step=0.01,
            label="Top-p (nucleus sampling)",
        ),
        gr.Slider(
            minimum=1.0,
            maximum=1.5,
            value=1.031,
            step=0.001,
            label="Repetition penalty",
        ),
        gr.Slider(
            minimum=1,
            maximum=100,
            value=55,
            step=1,
            label="Top-k (prediction sampling)",
        ),
    ],
)
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    chatbot.render()
if __name__ == "__main__":
    demo.launch(
        server_name="0.0.0.0",  # makes it accessible from other devices on your network
        server_port=7860,       # default Gradio port
        share=False,            # set to True to get a public shareable link
        debug=True
    )
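
# To try this locally (assuming the file is saved as app.py): run `python app.py`
# and open http://localhost:7860 in a browser.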