wassemgtk committed
Commit 43c6ebf · verified · 1 Parent(s): 966dd1e

Update app.py

Files changed (1)
  1. app.py (+23 -29)
app.py CHANGED
@@ -1,7 +1,6 @@
 
 import os
 import json
-import time
 import requests
 import gradio as gr
 
@@ -22,8 +21,7 @@ def _fireworks_stream(payload):
         "Content-Type": "application/json",
         "Authorization": f"Bearer {FIREWORKS_API_KEY}",
     }
-    # Ensure we stream
-    payload = dict(payload)  # shallow copy
+    payload = dict(payload)
     payload["stream"] = True
     with requests.post(FIREWORKS_URL, headers=headers, json=payload, stream=True) as r:
         r.raise_for_status()
@@ -38,33 +36,41 @@ def _fireworks_stream(payload):
             try:
                 obj = json.loads(data)
             except json.JSONDecodeError:
-                # In case of partial line; accumulate
                 buffer += data
                 try:
                     obj = json.loads(buffer)
                     buffer = ""
                 except Exception:
                     continue
-            # Fireworks streams OpenAI-style deltas
             try:
                 delta = obj["choices"][0]["delta"]
                 if "content" in delta and delta["content"]:
                     yield delta["content"]
             except Exception:
-                # Some events may be role changes or tool calls; ignore silently
                 continue
 
+def _normalize_history_to_messages(history):
+    """Normalize history from Gradio into OpenAI-style messages without system prompt."""
+    # Chatbot(type='messages') already gives a list of dicts: [{'role': 'user'|'assistant', 'content': '...'}, ...]
+    if not history:
+        return []
+    if isinstance(history, list) and len(history) > 0 and isinstance(history[0], dict) and "role" in history[0]:
+        # Already messages format; pass through (filter any roles other than user/assistant)
+        return [m for m in history if m.get("role") in ("user", "assistant")]
+    # Back-compat: history may be list of (user, assistant) tuples
+    msgs = []
+    for u, a in history:
+        if u:
+            msgs.append({"role": "user", "content": u})
+        if a:
+            msgs.append({"role": "assistant", "content": a})
+    return msgs
+
 def _build_messages(history, user_message):
     messages = []
-    # Insert a hidden system message from server-side secret
     if SYSTEM_PROMPT:
         messages.append({"role": "system", "content": SYSTEM_PROMPT})
-    # History from Gradio ChatInterface comes as list of (user, assistant) tuples
-    for u, a in history:
-        if u:
-            messages.append({"role": "user", "content": u})
-        if a:
-            messages.append({"role": "assistant", "content": a})
+    messages.extend(_normalize_history_to_messages(history))
     if user_message:
         messages.append({"role": "user", "content": user_message})
     return messages
@@ -80,7 +86,6 @@ def chat_fn(user_message, history, max_tokens, temperature, top_p, top_k, presen
         "frequency_penalty": float(frequency_penalty),
         "messages": _build_messages(history, user_message),
     }
-    # Stream tokens back to the UI
     for token in _fireworks_stream(payload):
         yield token
 
@@ -96,14 +101,15 @@ div.controls { gap: 10px !important; }
     <div style="display:flex; align-items:center; gap:12px; margin: 6px 0 16px;">
       <svg width="28" height="28" viewBox="0 0 24 24" fill="none"><path d="M12 3l7 4v6c0 5-7 8-7 8s-7-3-7-8V7l7-4z" stroke="currentColor" stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round"/></svg>
       <div>
-        <div id="title" style="font-size:1.25rem;">Palmyra-sec Chat Playground</div>
+        <div id="title" style="font-size:1.25rem;">Fireworks Chat Playground</div>
         <div style="opacity:0.7; font-size:0.95rem;">Secure, streamed chat to <code>inference/v1/chat/completions</code></div>
       </div>
     </div>
     """)
     with gr.Row():
         with gr.Column(scale=3):
-            chatbot = gr.Chatbot(height=480, avatar_images=(None, None), bubble_full_width=False)
+            # Use messages format to avoid deprecation
+            chatbot = gr.Chatbot(height=480, type="messages", avatar_images=(None, None))
             with gr.Row(elem_classes=["controls"]):
                 max_tokens = gr.Slider(32, 8192, value=4000, step=16, label="Max tokens")
                 temperature = gr.Slider(0.0, 2.0, value=0.6, step=0.05, label="Temperature")
@@ -115,8 +121,7 @@ div.controls { gap: 10px !important; }
                 frequency_penalty = gr.Slider(-2.0, 2.0, value=0.0, step=0.05, label="frequency_penalty")
             gr.Markdown("""
             **Security notes**
-            - Your API key and system prompt are kept on the server as environment variables.
-            - They are never shown in the UI or sent to the browser.
+            - API key and system prompt are server-side environment variables.
            - Change the model id with `FIREWORKS_MODEL_ID` (env var).
             """)
             clear_btn = gr.Button("Clear", variant="secondary")
@@ -125,22 +130,11 @@ div.controls { gap: 10px !important; }
         chatbot=chatbot,
         additional_inputs=[max_tokens, temperature, top_p, top_k, presence_penalty, frequency_penalty],
         title=None,
-        retry_btn=None,
-        undo_btn="Undo last",
-        clear_btn=None,
         submit_btn="Send",
-        autofocus=True,
-        fill_height=False,
-        cache_examples=False,
-        concurrency_limit=10,
-        multimodal=False,
-        analytics_enabled=False,
-        enable_queue=True,
         examples=["Hello!", "Summarize: Why is retrieval-augmented generation useful for insurers?", "Write a 3-bullet status update for the Palmyra team."],
         description="Start chatting below. Streaming is enabled."
     )
     clear_btn.click(fn=clear_history, outputs=chatbot)
 
 if __name__ == "__main__":
-    # Use 0.0.0.0 for container friendliness; set GRADIO_SERVER_PORT externally if needed
     demo.queue().launch(server_name="0.0.0.0")
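For reference, a minimal sketch (not part of this commit) of how the new _normalize_history_to_messages helper treats the two history shapes Gradio can pass to chat_fn. The sample histories and assertions below are made up for illustration, and it assumes importing app.py only builds the UI, since demo.queue().launch() is guarded by the __main__ check.

# Illustrative only: both history shapes normalize to the same OpenAI-style list.
from app import _normalize_history_to_messages

tuple_history = [("Hi", "Hello! How can I help?")]            # legacy (user, assistant) tuples
messages_history = [                                          # Chatbot(type="messages") format
    {"role": "user", "content": "Hi"},
    {"role": "assistant", "content": "Hello! How can I help?"},
]

assert _normalize_history_to_messages(tuple_history) == _normalize_history_to_messages(messages_history)

# Entries with roles other than user/assistant (e.g. system) are filtered out.
assert _normalize_history_to_messages([{"role": "system", "content": "hidden"}]) == []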
 
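A similarly hedged sketch of the event shape _fireworks_stream expects: each streamed line carries an OpenAI-style chunk, and only choices[0].delta.content is yielded, while role-only or tool-call events are skipped. The literal data line below is fabricated for illustration, not captured Fireworks output.

# Illustrative only: parse one fabricated SSE data line the way _fireworks_stream does.
import json

line = 'data: {"choices": [{"index": 0, "delta": {"content": "Hello"}}]}'
data = line[len("data: "):]                # strip the SSE "data: " prefix
obj = json.loads(data)
delta = obj["choices"][0]["delta"]
if "content" in delta and delta["content"]:
    print(delta["content"])                # -> Hello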