Spaces:
Running
Running
import os
import json
import time  # NOTE(review): appears unused in this file — confirm before removing
import requests
import gradio as gr

# Fireworks chat-completions endpoint (OpenAI-compatible SSE streaming).
FIREWORKS_URL = "https://api.fireworks.ai/inference/v1/chat/completions"
# Fine-tuned model id; override via the FIREWORKS_MODEL_ID env var.
MODEL_ID = os.getenv("FIREWORKS_MODEL_ID", "accounts/waseem-9b447b/models/ft-gdixl08u-sz53t")
# Secrets (server-side only; never sent to the client UI)
FIREWORKS_API_KEY = os.getenv("FIREWORKS_API_KEY")  # required
SYSTEM_PROMPT = os.getenv("SYSTEM_PROMPT")  # optional hidden system message
# Fail fast at startup instead of on the first chat request.
if not FIREWORKS_API_KEY:
    raise RuntimeError("Missing FIREWORKS_API_KEY environment variable")
def _fireworks_stream(payload):
    """Yield assistant content tokens from a Fireworks SSE streaming response.

    Args:
        payload: Chat-completions request body (dict). A shallow copy is
            taken and ``stream`` is forced to ``True`` before sending.

    Yields:
        str: Non-empty incremental content deltas, in arrival order.

    Raises:
        requests.HTTPError: If the API responds with a non-2xx status.
        requests.Timeout: If connecting or reading stalls past the timeout.
    """
    headers = {
        "Accept": "application/json",
        "Content-Type": "application/json",
        "Authorization": f"Bearer {FIREWORKS_API_KEY}",
    }
    # Shallow copy so the caller's dict is not mutated; ensure streaming.
    payload = dict(payload)
    payload["stream"] = True
    # timeout=(connect, read): without a timeout a dead connection hangs the
    # worker forever. The read timeout is per-chunk, so long generations
    # that keep producing tokens are unaffected.
    with requests.post(
        FIREWORKS_URL,
        headers=headers,
        json=payload,
        stream=True,
        timeout=(10, 300),
    ) as r:
        r.raise_for_status()
        buffer = ""
        for line in r.iter_lines(decode_unicode=True):
            if not line:
                continue  # SSE keep-alives / event separators
            if not line.startswith("data:"):
                continue  # ignore non-data SSE fields (event:, id:, ...)
            data = line[len("data:"):].strip()
            if data == "[DONE]":
                break
            try:
                obj = json.loads(data)
                buffer = ""  # full event parsed; drop any stale fragment
            except json.JSONDecodeError:
                # Partial JSON split across lines: accumulate and retry.
                buffer += data
                try:
                    obj = json.loads(buffer)
                    buffer = ""
                except json.JSONDecodeError:
                    continue
            # Fireworks streams OpenAI-style deltas: choices[0].delta.content
            try:
                content = obj["choices"][0]["delta"].get("content")
            except (KeyError, IndexError, TypeError):
                # Role-change or tool-call events carry no content; skip.
                continue
            if content:
                yield content
| def _build_messages(history, user_message): | |
| messages = [] | |
| # Insert a hidden system message from server-side secret | |
| if SYSTEM_PROMPT: | |
| messages.append({"role": "system", "content": SYSTEM_PROMPT}) | |
| # History from Gradio ChatInterface comes as list of (user, assistant) tuples | |
| for u, a in history: | |
| if u: | |
| messages.append({"role": "user", "content": u}) | |
| if a: | |
| messages.append({"role": "assistant", "content": a}) | |
| if user_message: | |
| messages.append({"role": "user", "content": user_message}) | |
| return messages | |
def chat_fn(user_message, history, max_tokens, temperature, top_p, top_k, presence_penalty, frequency_penalty):
    """Stream an assistant reply for Gradio's ChatInterface.

    Gradio replaces the displayed assistant message with each yielded
    value, so this generator yields the ACCUMULATED partial reply rather
    than individual token deltas — yielding bare tokens would make the UI
    show only the most recent token.

    Args:
        user_message: The new user turn.
        history: Prior turns as (user, assistant) tuples.
        max_tokens/temperature/top_p/top_k/presence_penalty/
        frequency_penalty: Sampling controls bound to the UI sliders;
            coerced to int/float since sliders may deliver floats.

    Yields:
        str: The growing assistant reply, one yield per received token.
    """
    payload = {
        "model": MODEL_ID,
        "max_tokens": int(max_tokens),
        "temperature": float(temperature),
        "top_p": float(top_p),
        "top_k": int(top_k),
        "presence_penalty": float(presence_penalty),
        "frequency_penalty": float(frequency_penalty),
        "messages": _build_messages(history, user_message),
    }
    partial = ""
    for token in _fireworks_stream(payload):
        partial += token
        yield partial
def clear_history():
    """Return ``None`` so the bound Chatbot component is emptied."""
    # Gradio interprets a None output value as "clear this component".
    return None
# Declarative UI: a two-column layout (chat + sampling controls) wired to
# chat_fn via gr.ChatInterface. Indentation reconstructed from the layout
# nesting — TODO confirm against the deployed app.
with gr.Blocks(theme=gr.themes.Soft(), css="""
:root { --radius: 16px; }
#title { font-weight: 800; letter-spacing: -0.02em; }
div.controls { gap: 10px !important; }
""") as demo:
    # Static header: shield icon + title/subtitle.
    gr.HTML("""
<div style="display:flex; align-items:center; gap:12px; margin: 6px 0 16px;">
<svg width="28" height="28" viewBox="0 0 24 24" fill="none"><path d="M12 3l7 4v6c0 5-7 8-7 8s-7-3-7-8V7l7-4z" stroke="currentColor" stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round"/></svg>
<div>
<div id="title" style="font-size:1.25rem;">Palmyra-sec Chat Playground</div>
<div style="opacity:0.7; font-size:0.95rem;">Secure, streamed chat to <code>inference/v1/chat/completions</code></div>
</div>
</div>
""")
    with gr.Row():
        # Left column: the chat transcript plus the two most-used knobs.
        with gr.Column(scale=3):
            # NOTE(review): bubble_full_width/likeable are deprecated in
            # newer Gradio releases — confirm the pinned version supports them.
            chatbot = gr.Chatbot(height=480, avatar_images=(None, None), bubble_full_width=False, likeable=True)
            with gr.Row(elem_classes=["controls"]):
                max_tokens = gr.Slider(32, 8192, value=4000, step=16, label="Max tokens")
                temperature = gr.Slider(0.0, 2.0, value=0.6, step=0.05, label="Temperature")
        # Right column: advanced sampling parameters and housekeeping.
        with gr.Column(scale=2):
            with gr.Group():
                top_p = gr.Slider(0.0, 1.0, value=1.0, step=0.01, label="top_p")
                top_k = gr.Slider(0, 200, value=40, step=1, label="top_k")
                presence_penalty = gr.Slider(-2.0, 2.0, value=0.0, step=0.05, label="presence_penalty")
                frequency_penalty = gr.Slider(-2.0, 2.0, value=0.0, step=0.05, label="frequency_penalty")
            gr.Markdown("""
**Security notes**
- Your API key and system prompt are kept on the server as environment variables.
- They are never shown in the UI or sent to the browser.
- Change the model id with `FIREWORKS_MODEL_ID` (env var).
""")
            clear_btn = gr.Button("Clear", variant="secondary")
    # ChatInterface drives chat_fn; the slider components are passed through
    # as additional_inputs in the same order chat_fn declares them.
    # NOTE(review): retry_btn/undo_btn/clear_btn/submit_btn-as-text and
    # enable_queue were removed in Gradio 4+ — confirm the pinned version.
    chat = gr.ChatInterface(
        fn=chat_fn,
        chatbot=chatbot,
        additional_inputs=[max_tokens, temperature, top_p, top_k, presence_penalty, frequency_penalty],
        title=None,
        retry_btn=None,
        undo_btn="Undo last",
        clear_btn=None,
        submit_btn="Send",
        autofocus=True,
        fill_height=False,
        cache_examples=False,
        concurrency_limit=10,
        multimodal=False,
        analytics_enabled=False,
        enable_queue=True,
        examples=["Hello!", "Summarize: Why is retrieval-augmented generation useful for insurers?", "Write a 3-bullet status update for the Palmyra team."],
        description="Start chatting below. Streaming is enabled."
    )
    # Custom Clear button: clear_history returns None, which empties chatbot.
    clear_btn.click(fn=clear_history, outputs=chatbot)
if __name__ == "__main__":
    # Bind to all interfaces for container deployments; the port can be
    # overridden externally via GRADIO_SERVER_PORT.
    app = demo.queue()
    app.launch(server_name="0.0.0.0")