ThongCoder committed
Commit 15776f9 · verified · 1 Parent(s): 8116812

Update app.py

Files changed (1)
  1. app.py +39 -25
app.py CHANGED
@@ -1,29 +1,48 @@
-from transformers import pipeline
 import gradio as gr
-import time
+import torch
+from transformers import AutoTokenizer, AutoModelForCausalLM

-# Load model
-pipe = pipeline("text-generation", model="prithivMLmods/rStar-Coder-Qwen3-0.6B")
+# Load tokenizer and model
+model_name = "prithivMLmods/rStar-Coder-Qwen3-0.6B"
+tokenizer = AutoTokenizer.from_pretrained(model_name)
+model = AutoModelForCausalLM.from_pretrained(model_name)
+model.eval()
+
+if torch.cuda.is_available():
+    model = model.to("cuda")

 history = []

-def chat_fn_stream(user_input):
+def stream_chat(user_input):
     global history
     history.append(f"User: {user_input}")
     context = "\n".join(history) + "\nBot:"

-    # Use a generator for streaming
-    for i in range(0, 8192, 20): # fake streaming in chunks
-        output = pipe(
-            context,
-            max_new_tokens=i+20,
-            do_sample=True,
-            top_p=0.9,
-            return_full_text=False
-        )[0]['generated_text']
-        bot_reply = output.split("Bot:")[-1].strip()
-        yield bot_reply # stream partial reply
-        time.sleep(0.1) # small delay to simulate streaming
+    # Tokenize input
+    input_ids = tokenizer(context, return_tensors="pt").input_ids
+    if torch.cuda.is_available():
+        input_ids = input_ids.to("cuda")
+
+    # Generate token by token
+    output_ids = input_ids.clone()
+    bot_reply = ""
+    max_new_tokens = 200 # adjust as needed
+
+    for _ in range(max_new_tokens):
+        with torch.no_grad():
+            outputs = model(output_ids)
+        next_token_logits = outputs.logits[0, -1, :]
+        next_token = torch.argmax(next_token_logits).unsqueeze(0)
+        output_ids = torch.cat([output_ids, next_token.unsqueeze(0)], dim=1)
+        token_str = tokenizer.decode(next_token)
+        bot_reply += token_str
+
+        # Yield streaming output
+        yield bot_reply
+
+        # Stop if EOS token
+        if next_token.item() == tokenizer.eos_token_id:
+            break

     history.append(f"Bot: {bot_reply}")

@@ -33,16 +52,11 @@ with gr.Blocks() as demo:
     msg = gr.Textbox(placeholder="Type a message...")

     def respond(user_input, chat_history):
-        bot_reply = ""
-        # Start by adding the user input
-        chat_history.append((user_input, "")) # empty bot reply for now
-        for partial in chat_fn_stream(user_input):
-            bot_reply = partial
-            # Update the last bot reply
-            chat_history[-1] = (user_input, bot_reply)
+        chat_history.append((user_input, ""))
+        for partial in stream_chat(user_input):
+            chat_history[-1] = (user_input, partial)
             yield chat_history, chat_history

-
     state = gr.State([])
     msg.submit(respond, [msg, state], [chatbot_ui, state])

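
Not part of this commit, but for context: the new stream_chat loop re-runs the model over the full output_ids prefix on every step, so each streamed token gets a little slower as the reply grows. A minimal sketch of the same streaming behaviour using transformers' built-in TextIteratorStreamer with generate() in a background thread is shown below; it assumes the same model, tokenizer, and history objects defined in app.py, and the helper name stream_chat_with_streamer is hypothetical.

from threading import Thread
from transformers import TextIteratorStreamer

def stream_chat_with_streamer(user_input):
    # Hypothetical alternative to stream_chat; same history/context handling.
    global history
    history.append(f"User: {user_input}")
    context = "\n".join(history) + "\nBot:"

    inputs = tokenizer(context, return_tensors="pt").to(model.device)
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

    # generate() runs in a background thread and pushes decoded text chunks into the streamer
    Thread(target=model.generate,
           kwargs=dict(**inputs, streamer=streamer, max_new_tokens=200)).start()

    bot_reply = ""
    for new_text in streamer:
        bot_reply += new_text
        yield bot_reply  # stream the growing reply, as stream_chat does

    history.append(f"Bot: {bot_reply}")

The streamer variant also reuses the model's key/value cache inside generate(), which is why it avoids the quadratic re-encoding cost of the manual argmax loop.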