import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

# Load tokenizer and model
model_name = "prithivMLmods/rStar-Coder-Qwen3-0.6B"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)
model.eval()

# Move the model to the GPU (if available) before compiling it
if torch.cuda.is_available():
    model = model.to("cuda")
model = torch.compile(model)

history = []


def stream_chat(user_input):
    """Greedy token-by-token generation that yields the partial reply as it grows."""
    global history
    history.append(f"User: {user_input}")
    context = "\n".join(history) + "\nBot:"

    # Tokenize the full conversation context
    input_ids = tokenizer(context, return_tensors="pt").input_ids
    if torch.cuda.is_available():
        input_ids = input_ids.to("cuda")

    # Generate token by token (full forward pass each step; no KV cache)
    output_ids = input_ids.clone()
    bot_reply = ""
    max_new_tokens = 200  # adjust as needed

    for _ in range(max_new_tokens):
        with torch.no_grad():
            outputs = model(output_ids)
        next_token_logits = outputs.logits[0, -1, :]
        next_token = torch.argmax(next_token_logits).unsqueeze(0)

        # Stop at EOS before it is appended to the visible reply
        if next_token.item() == tokenizer.eos_token_id:
            break

        output_ids = torch.cat([output_ids, next_token.unsqueeze(0)], dim=1)
        bot_reply += tokenizer.decode(next_token)

        # Yield streaming output
        yield bot_reply

    history.append(f"Bot: {bot_reply}")


# Gradio interface
with gr.Blocks() as demo:
    chatbot_ui = gr.Chatbot()
    msg = gr.Textbox(placeholder="Type a message...")
    state = gr.State([])

    def respond(user_input, chat_history):
        # Add a placeholder entry, then update it with each streamed chunk
        chat_history.append((user_input, ""))
        for partial in stream_chat(user_input):
            chat_history[-1] = (user_input, partial)
            yield chat_history, chat_history

    msg.submit(respond, [msg, state], [chatbot_ui, state])

demo.launch()
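
# Optional alternative (reference sketch, not wired into the demo above):
# the same streaming effect can be obtained with transformers'
# TextIteratorStreamer and model.generate(), which run the decoding loop
# (with KV caching) internally. It reuses the `model` and `tokenizer`
# objects defined above; the helper name stream_chat_with_streamer is
# illustrative. Note that code placed after demo.launch() does not run
# while the blocking server is up, so in a real script this would be
# defined above the Gradio block.
from threading import Thread
from transformers import TextIteratorStreamer


def stream_chat_with_streamer(context, max_new_tokens=200):
    input_ids = tokenizer(context, return_tensors="pt").input_ids
    if torch.cuda.is_available():
        input_ids = input_ids.to("cuda")

    # The streamer yields decoded text chunks as generate() produces tokens
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    thread = Thread(
        target=model.generate,
        kwargs=dict(input_ids=input_ids, max_new_tokens=max_new_tokens, streamer=streamer),
    )
    thread.start()

    bot_reply = ""
    for text_chunk in streamer:
        bot_reply += text_chunk
        yield bot_reply
    thread.join()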