Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
|
@@ -1,7 +1,6 @@
|
|
| 1 |
|
| 2 |
import os
|
| 3 |
import json
|
| 4 |
-
import time
|
| 5 |
import requests
|
| 6 |
import gradio as gr
|
| 7 |
|
|
@@ -22,8 +21,7 @@ def _fireworks_stream(payload):
|
|
| 22 |
"Content-Type": "application/json",
|
| 23 |
"Authorization": f"Bearer {FIREWORKS_API_KEY}",
|
| 24 |
}
|
| 25 |
-
|
| 26 |
-
payload = dict(payload) # shallow copy
|
| 27 |
payload["stream"] = True
|
| 28 |
with requests.post(FIREWORKS_URL, headers=headers, json=payload, stream=True) as r:
|
| 29 |
r.raise_for_status()
|
|
@@ -38,33 +36,41 @@ def _fireworks_stream(payload):
|
|
| 38 |
try:
|
| 39 |
obj = json.loads(data)
|
| 40 |
except json.JSONDecodeError:
|
| 41 |
-
# In case of partial line; accumulate
|
| 42 |
buffer += data
|
| 43 |
try:
|
| 44 |
obj = json.loads(buffer)
|
| 45 |
buffer = ""
|
| 46 |
except Exception:
|
| 47 |
continue
|
| 48 |
-
# Fireworks streams OpenAI-style deltas
|
| 49 |
try:
|
| 50 |
delta = obj["choices"][0]["delta"]
|
| 51 |
if "content" in delta and delta["content"]:
|
| 52 |
yield delta["content"]
|
| 53 |
except Exception:
|
| 54 |
-
# Some events may be role changes or tool calls; ignore silently
|
| 55 |
continue
|
| 56 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 57 |
def _build_messages(history, user_message):
|
| 58 |
messages = []
|
| 59 |
-
# Insert a hidden system message from server-side secret
|
| 60 |
if SYSTEM_PROMPT:
|
| 61 |
messages.append({"role": "system", "content": SYSTEM_PROMPT})
|
| 62 |
-
|
| 63 |
-
for u, a in history:
|
| 64 |
-
if u:
|
| 65 |
-
messages.append({"role": "user", "content": u})
|
| 66 |
-
if a:
|
| 67 |
-
messages.append({"role": "assistant", "content": a})
|
| 68 |
if user_message:
|
| 69 |
messages.append({"role": "user", "content": user_message})
|
| 70 |
return messages
|
|
@@ -80,7 +86,6 @@ def chat_fn(user_message, history, max_tokens, temperature, top_p, top_k, presen
|
|
| 80 |
"frequency_penalty": float(frequency_penalty),
|
| 81 |
"messages": _build_messages(history, user_message),
|
| 82 |
}
|
| 83 |
-
# Stream tokens back to the UI
|
| 84 |
for token in _fireworks_stream(payload):
|
| 85 |
yield token
|
| 86 |
|
|
@@ -96,14 +101,15 @@ div.controls { gap: 10px !important; }
|
|
| 96 |
<div style="display:flex; align-items:center; gap:12px; margin: 6px 0 16px;">
|
| 97 |
<svg width="28" height="28" viewBox="0 0 24 24" fill="none"><path d="M12 3l7 4v6c0 5-7 8-7 8s-7-3-7-8V7l7-4z" stroke="currentColor" stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round"/></svg>
|
| 98 |
<div>
|
| 99 |
-
<div id="title" style="font-size:1.25rem;">
|
| 100 |
<div style="opacity:0.7; font-size:0.95rem;">Secure, streamed chat to <code>inference/v1/chat/completions</code></div>
|
| 101 |
</div>
|
| 102 |
</div>
|
| 103 |
""")
|
| 104 |
with gr.Row():
|
| 105 |
with gr.Column(scale=3):
|
| 106 |
-
|
|
|
|
| 107 |
with gr.Row(elem_classes=["controls"]):
|
| 108 |
max_tokens = gr.Slider(32, 8192, value=4000, step=16, label="Max tokens")
|
| 109 |
temperature = gr.Slider(0.0, 2.0, value=0.6, step=0.05, label="Temperature")
|
|
@@ -115,8 +121,7 @@ div.controls { gap: 10px !important; }
|
|
| 115 |
frequency_penalty = gr.Slider(-2.0, 2.0, value=0.0, step=0.05, label="frequency_penalty")
|
| 116 |
gr.Markdown("""
|
| 117 |
**Security notes**
|
| 118 |
-
-
|
| 119 |
-
- They are never shown in the UI or sent to the browser.
|
| 120 |
- Change the model id with `FIREWORKS_MODEL_ID` (env var).
|
| 121 |
""")
|
| 122 |
clear_btn = gr.Button("Clear", variant="secondary")
|
|
@@ -125,22 +130,11 @@ div.controls { gap: 10px !important; }
|
|
| 125 |
chatbot=chatbot,
|
| 126 |
additional_inputs=[max_tokens, temperature, top_p, top_k, presence_penalty, frequency_penalty],
|
| 127 |
title=None,
|
| 128 |
-
retry_btn=None,
|
| 129 |
-
undo_btn="Undo last",
|
| 130 |
-
clear_btn=None,
|
| 131 |
submit_btn="Send",
|
| 132 |
-
autofocus=True,
|
| 133 |
-
fill_height=False,
|
| 134 |
-
cache_examples=False,
|
| 135 |
-
concurrency_limit=10,
|
| 136 |
-
multimodal=False,
|
| 137 |
-
analytics_enabled=False,
|
| 138 |
-
enable_queue=True,
|
| 139 |
examples=["Hello!", "Summarize: Why is retrieval-augmented generation useful for insurers?", "Write a 3-bullet status update for the Palmyra team."],
|
| 140 |
description="Start chatting below. Streaming is enabled."
|
| 141 |
)
|
| 142 |
clear_btn.click(fn=clear_history, outputs=chatbot)
|
| 143 |
|
| 144 |
if __name__ == "__main__":
|
| 145 |
-
# Use 0.0.0.0 for container friendliness; set GRADIO_SERVER_PORT externally if needed
|
| 146 |
demo.queue().launch(server_name="0.0.0.0")
|
|
|
|
| 1 |
|
| 2 |
import os
|
| 3 |
import json
|
|
|
|
| 4 |
import requests
|
| 5 |
import gradio as gr
|
| 6 |
|
|
|
|
| 21 |
"Content-Type": "application/json",
|
| 22 |
"Authorization": f"Bearer {FIREWORKS_API_KEY}",
|
| 23 |
}
|
| 24 |
+
payload = dict(payload)
|
|
|
|
| 25 |
payload["stream"] = True
|
| 26 |
with requests.post(FIREWORKS_URL, headers=headers, json=payload, stream=True) as r:
|
| 27 |
r.raise_for_status()
|
|
|
|
| 36 |
try:
|
| 37 |
obj = json.loads(data)
|
| 38 |
except json.JSONDecodeError:
|
|
|
|
| 39 |
buffer += data
|
| 40 |
try:
|
| 41 |
obj = json.loads(buffer)
|
| 42 |
buffer = ""
|
| 43 |
except Exception:
|
| 44 |
continue
|
|
|
|
| 45 |
try:
|
| 46 |
delta = obj["choices"][0]["delta"]
|
| 47 |
if "content" in delta and delta["content"]:
|
| 48 |
yield delta["content"]
|
| 49 |
except Exception:
|
|
|
|
| 50 |
continue
|
| 51 |
|
| 52 |
+
def _normalize_history_to_messages(history):
|
| 53 |
+
"""Normalize history from Gradio into OpenAI-style messages without system prompt."""
|
| 54 |
+
# Chatbot(type='messages') already gives a list of dicts: [{'role': 'user'|'assistant', 'content': '...'}, ...]
|
| 55 |
+
if not history:
|
| 56 |
+
return []
|
| 57 |
+
if isinstance(history, list) and len(history) > 0 and isinstance(history[0], dict) and "role" in history[0]:
|
| 58 |
+
# Already messages format; pass through (filter any roles other than user/assistant)
|
| 59 |
+
return [m for m in history if m.get("role") in ("user", "assistant")]
|
| 60 |
+
# Back-compat: history may be list of (user, assistant) tuples
|
| 61 |
+
msgs = []
|
| 62 |
+
for u, a in history:
|
| 63 |
+
if u:
|
| 64 |
+
msgs.append({"role": "user", "content": u})
|
| 65 |
+
if a:
|
| 66 |
+
msgs.append({"role": "assistant", "content": a})
|
| 67 |
+
return msgs
|
| 68 |
+
|
| 69 |
def _build_messages(history, user_message):
|
| 70 |
messages = []
|
|
|
|
| 71 |
if SYSTEM_PROMPT:
|
| 72 |
messages.append({"role": "system", "content": SYSTEM_PROMPT})
|
| 73 |
+
messages.extend(_normalize_history_to_messages(history))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 74 |
if user_message:
|
| 75 |
messages.append({"role": "user", "content": user_message})
|
| 76 |
return messages
|
|
|
|
| 86 |
"frequency_penalty": float(frequency_penalty),
|
| 87 |
"messages": _build_messages(history, user_message),
|
| 88 |
}
|
|
|
|
| 89 |
for token in _fireworks_stream(payload):
|
| 90 |
yield token
|
| 91 |
|
|
|
|
| 101 |
<div style="display:flex; align-items:center; gap:12px; margin: 6px 0 16px;">
|
| 102 |
<svg width="28" height="28" viewBox="0 0 24 24" fill="none"><path d="M12 3l7 4v6c0 5-7 8-7 8s-7-3-7-8V7l7-4z" stroke="currentColor" stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round"/></svg>
|
| 103 |
<div>
|
| 104 |
+
<div id="title" style="font-size:1.25rem;">Fireworks Chat Playground</div>
|
| 105 |
<div style="opacity:0.7; font-size:0.95rem;">Secure, streamed chat to <code>inference/v1/chat/completions</code></div>
|
| 106 |
</div>
|
| 107 |
</div>
|
| 108 |
""")
|
| 109 |
with gr.Row():
|
| 110 |
with gr.Column(scale=3):
|
| 111 |
+
# Use messages format to avoid deprecation
|
| 112 |
+
chatbot = gr.Chatbot(height=480, type="messages", avatar_images=(None, None))
|
| 113 |
with gr.Row(elem_classes=["controls"]):
|
| 114 |
max_tokens = gr.Slider(32, 8192, value=4000, step=16, label="Max tokens")
|
| 115 |
temperature = gr.Slider(0.0, 2.0, value=0.6, step=0.05, label="Temperature")
|
|
|
|
| 121 |
frequency_penalty = gr.Slider(-2.0, 2.0, value=0.0, step=0.05, label="frequency_penalty")
|
| 122 |
gr.Markdown("""
|
| 123 |
**Security notes**
|
| 124 |
+
- API key and system prompt are server-side environment variables.
|
|
|
|
| 125 |
- Change the model id with `FIREWORKS_MODEL_ID` (env var).
|
| 126 |
""")
|
| 127 |
clear_btn = gr.Button("Clear", variant="secondary")
|
|
|
|
| 130 |
chatbot=chatbot,
|
| 131 |
additional_inputs=[max_tokens, temperature, top_p, top_k, presence_penalty, frequency_penalty],
|
| 132 |
title=None,
|
|
|
|
|
|
|
|
|
|
| 133 |
submit_btn="Send",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 134 |
examples=["Hello!", "Summarize: Why is retrieval-augmented generation useful for insurers?", "Write a 3-bullet status update for the Palmyra team."],
|
| 135 |
description="Start chatting below. Streaming is enabled."
|
| 136 |
)
|
| 137 |
clear_btn.click(fn=clear_history, outputs=chatbot)
|
| 138 |
|
| 139 |
if __name__ == "__main__":
|
|
|
|
| 140 |
demo.queue().launch(server_name="0.0.0.0")
|