hotmemeh committed on
Commit
a18f23e
·
verified ·
1 Parent(s): cbb85ea

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +25 -8
app.py CHANGED
@@ -1,22 +1,39 @@
1
  import gradio as gr
2
  from transformers import pipeline
 
3
 
4
- # Load a Hugging Face text generation model (swap model as needed)
5
- generator = pipeline("text-generation", model="gpt2")
 
 
 
 
 
 
 
 
 
 
6
 
7
  # Streaming response function
8
  def respond(message, history):
9
- # Simulate token-by-token streaming from the model
10
- output = generator(message, max_length=200, num_return_sequences=1)[0]["generated_text"]
11
-
12
- # Yield chunks progressively
 
 
 
 
 
 
13
  for i in range(0, len(output), 20):
14
  yield {"role": "assistant", "content": output[: i + 20]}
15
 
16
- # Build the UI
17
  chat = gr.ChatInterface(
18
  fn=respond,
19
- type="messages", # use new OpenAI-style format
20
  chatbot=gr.Chatbot(height=600, show_copy_button=True, type="messages"),
21
  )
22
 
 
1
  import gradio as gr
2
  from transformers import pipeline
3
+ import torch
4
 
5
# Pick a generation model that fits the available hardware:
# a 7B instruct model when a CUDA device is present, GPT-2 otherwise.
_has_gpu = torch.cuda.is_available()
MODEL_NAME = "tiiuae/falcon-7b-instruct" if _has_gpu else "gpt2"  # GPU model / CPU fallback
device = 0 if _has_gpu else -1  # transformers convention: 0 = first GPU, -1 = CPU

print(f"Loading model: {MODEL_NAME} on {'GPU' if device == 0 else 'CPU'}")

# Load the Hugging Face text-generation pipeline once at module import time.
generator = pipeline("text-generation", model=MODEL_NAME, device=device)
17
 
18
# Streaming response function
def respond(message, history):
    """Generate a reply to *message* and stream it back as growing prefixes.

    Yields OpenAI-style ``{"role": "assistant", "content": ...}`` dicts,
    each containing 20 more characters than the last, so the Gradio chat
    UI renders the answer progressively. *history* is accepted for the
    ChatInterface contract but not used by the model call.
    """
    result = generator(
        message,
        max_new_tokens=256,   # bound on newly generated tokens (not total length)
        num_return_sequences=1,
        do_sample=True,
        temperature=0.7,
        truncation=True,      # fixes truncation warning
    )
    text = result[0]["generated_text"]

    # Emit progressively longer prefixes, 20 characters at a time.
    sent = 0
    while sent < len(text):
        sent += 20
        yield {"role": "assistant", "content": text[:sent]}
32
 
33
# Build the Gradio chat UI. type="messages" selects the OpenAI-style
# role/content dict format, matching what respond() yields.
chat = gr.ChatInterface(
    fn=respond,
    type="messages",
    chatbot=gr.Chatbot(height=600, show_copy_button=True, type="messages"),
)

if __name__ == "__main__":
    # Bug fix: the interface was built but never launched, so running the
    # script directly exited without serving the UI. The __main__ guard
    # keeps the module importable without side effects.
    chat.launch()
39