```python
import os
from threading import Thread

import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer, pipeline

hf_token = os.getenv("YOUR_HF_TOKEN")

# Load model and tokenizer
print("Loading model and tokenizer...")
model_path = "microsoft/Phi-4-mini-instruct"  # Can be changed to a local path, e.g. "./Phi-4-Mini-Instruct"

tokenizer = AutoTokenizer.from_pretrained(
    model_path,
    padding_side="left",
    token=hf_token,
    trust_remote_code=True,
)
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    device_map="auto",
    attn_implementation="eager",  # or "flash_attention_2"
    torch_dtype="auto",
    token=hf_token,
    trust_remote_code=True,
)

# Create pipeline for easier inference
pipe = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
)
print("Model and tokenizer loaded successfully!")


# Format chat history into the messages format
def format_chat_history(message, history):
    messages = [
        {"role": "system", "content": "You are a helpful AI assistant."}
    ]
    # Add chat history
    for user_msg, assistant_msg in history:
        messages.append({"role": "user", "content": user_msg})
        messages.append({"role": "assistant", "content": assistant_msg})
    # Add current message
    messages.append({"role": "user", "content": message})
    return messages


# Streaming response generator
def predict(message, history):
    messages = format_chat_history(message, history)
    generation_args = {
        "max_new_tokens": 1024,
        "return_full_text": False,
        "temperature": 0.001,
        "top_p": 1.0,
        "do_sample": True,
        "streamer": None,  # Will be set below
    }

    # Create a TextIteratorStreamer for streaming generation
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    generation_args["streamer"] = streamer

    # Start a separate thread for generation
    thread = Thread(target=pipe, args=(messages,), kwargs=generation_args)
    thread.start()

    # Stream the response
    partial_message = ""
    for new_text in streamer:
        partial_message += new_text
        yield history + [[message, partial_message]]


# Create the Gradio interface
css = """
.chatbot-container {max-width: 800px; margin: auto;}
.chat-header {text-align: center; margin-bottom: 20px;}
"""

with gr.Blocks(css=css) as demo:
    gr.HTML("<div class='chat-header'><h1>Phi-4 Mini Chatbot</h1></div>")
    with gr.Column(elem_classes="chatbot-container"):
        chatbot = gr.Chatbot(height=400)
        msg = gr.Textbox(placeholder="Type your message here...", label="Input")
        clear = gr.Button("Clear Conversation")

        msg.submit(predict, [msg, chatbot], [chatbot], queue=True, api_name="chat").then(
            lambda: "", None, [msg]
        )
        clear.click(lambda: None, None, chatbot, queue=False)

# Launch the app
demo.launch(share=True)  # Set share=False if you don't want a public link
```
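If you want to sanity-check the chat formatting and the model outside of the Gradio UI, a minimal sketch is shown below. It assumes the script above has already run far enough to define `format_chat_history` and `pipe`; the example prompt and the 64-token cap are arbitrary choices for a quick test.

```python
# Build a messages list from a fake one-turn history plus a new user message
history = [("Hi", "Hello! How can I help you today?")]
msgs = format_chat_history("What is Phi-4 Mini?", history)
# msgs is now:
# [{"role": "system", "content": "You are a helpful AI assistant."},
#  {"role": "user", "content": "Hi"},
#  {"role": "assistant", "content": "Hello! How can I help you today?"},
#  {"role": "user", "content": "What is Phi-4 Mini?"}]

# One non-streaming generation through the same pipeline (64 new tokens is an arbitrary cap)
out = pipe(msgs, max_new_tokens=64, return_full_text=False)
print(out[0]["generated_text"])
```

When deploying this as a Space, note that `device_map="auto"` requires the `accelerate` package to be installed alongside `gradio`, `transformers`, and `torch`.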