wassemgtk committed
Commit 104cf09 · verified · 1 Parent(s): f5955b1

Upload 2 files

Files changed (2)
  1. app (3).py +148 -0
  2. requirements.txt +3 -0
app (3).py ADDED
@@ -0,0 +1,148 @@
+
+ import os
+ import json
+ import requests
+ import gradio as gr
+
+ FIREWORKS_URL = "https://api.fireworks.ai/inference/v1/chat/completions"
+ MODEL_ID = os.getenv("FIREWORKS_MODEL_ID", "accounts/waseem-9b447b/models/ft-gdixl08u-sz53t")
+
+ # Secrets (server-side only; never sent to the client UI)
+ FIREWORKS_API_KEY = os.getenv("FIREWORKS_API_KEY")  # required
+ SYSTEM_PROMPT = os.getenv("SYSTEM_PROMPT", "You are a helpful enterprise-grade assistant. Be concise, accurate, and secure.")
+
+ if not FIREWORKS_API_KEY:
+     raise RuntimeError("Missing FIREWORKS_API_KEY environment variable")
+
+ def _fireworks_stream(payload):
+     """Generator that streams tokens from the Fireworks chat completions SSE response."""
+     headers = {
+         "Accept": "application/json",
+         "Content-Type": "application/json",
+         "Authorization": f"Bearer {FIREWORKS_API_KEY}",
+     }
+     # Ensure we stream
+     payload = dict(payload)  # shallow copy
+     payload["stream"] = True
+     # requests has no default timeout; bound the call so a hung connection cannot stall the UI
+     with requests.post(FIREWORKS_URL, headers=headers, json=payload, stream=True, timeout=60) as r:
+         r.raise_for_status()
+         buffer = ""
+         for line in r.iter_lines(decode_unicode=True):
+             if not line:
+                 continue
+             if line.startswith("data:"):
+                 data = line[len("data:"):].strip()
+                 if data == "[DONE]":
+                     break
+                 try:
+                     obj = json.loads(data)
+                 except json.JSONDecodeError:
+                     # In case of partial line; accumulate
+                     buffer += data
+                     try:
+                         obj = json.loads(buffer)
+                         buffer = ""
+                     except Exception:
+                         continue
+                 # Fireworks streams OpenAI-style deltas
+                 try:
+                     delta = obj["choices"][0]["delta"]
+                     if "content" in delta and delta["content"]:
+                         yield delta["content"]
+                 except Exception:
+                     # Some events may be role changes or tool calls; ignore silently
+                     continue
+
+ def _build_messages(history, user_message):
+     messages = []
+     # Insert a hidden system message from server-side secret
+     if SYSTEM_PROMPT:
+         messages.append({"role": "system", "content": SYSTEM_PROMPT})
+     # History from Gradio ChatInterface comes as list of (user, assistant) tuples
+     for u, a in history:
+         if u:
+             messages.append({"role": "user", "content": u})
+         if a:
+             messages.append({"role": "assistant", "content": a})
+     if user_message:
+         messages.append({"role": "user", "content": user_message})
+     return messages
+
+ def chat_fn(user_message, history, max_tokens, temperature, top_p, top_k, presence_penalty, frequency_penalty):
+     payload = {
+         "model": MODEL_ID,
+         "max_tokens": int(max_tokens),
+         "temperature": float(temperature),
+         "top_p": float(top_p),
+         "top_k": int(top_k),
+         "presence_penalty": float(presence_penalty),
+         "frequency_penalty": float(frequency_penalty),
+         "messages": _build_messages(history, user_message),
+     }
+     # Stream back to the UI. gr.ChatInterface expects each yield to be the full
+     # message so far, so accumulate tokens rather than yielding them one by one
+     partial = ""
+     for token in _fireworks_stream(payload):
+         partial += token
+         yield partial
+
+ def clear_history():
+     return None
+
+ with gr.Blocks(theme=gr.themes.Soft(), css="""
+ :root { --radius: 16px; }
+ #title { font-weight: 800; letter-spacing: -0.02em; }
+ div.controls { gap: 10px !important; }
+ """) as demo:
+     gr.HTML("""
+     <div style="display:flex; align-items:center; gap:12px; margin: 6px 0 16px;">
+       <svg width="28" height="28" viewBox="0 0 24 24" fill="none"><path d="M12 3l7 4v6c0 5-7 8-7 8s-7-3-7-8V7l7-4z" stroke="currentColor" stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round"/></svg>
+       <div>
+         <div id="title" style="font-size:1.25rem;">Fireworks Chat Playground</div>
+         <div style="opacity:0.7; font-size:0.95rem;">Secure, streamed chat to <code>inference/v1/chat/completions</code></div>
+       </div>
+     </div>
+     """)
+     with gr.Row():
+         with gr.Column(scale=3):
+             chatbot = gr.Chatbot(height=480, avatar_images=(None, None), bubble_full_width=False, likeable=True)
+             with gr.Row(elem_classes=["controls"]):
+                 max_tokens = gr.Slider(32, 8192, value=4000, step=16, label="Max tokens")
+                 temperature = gr.Slider(0.0, 2.0, value=0.6, step=0.05, label="Temperature")
+         with gr.Column(scale=2):
+             with gr.Group():
+                 top_p = gr.Slider(0.0, 1.0, value=1.0, step=0.01, label="top_p")
+                 top_k = gr.Slider(0, 200, value=40, step=1, label="top_k")
+                 presence_penalty = gr.Slider(-2.0, 2.0, value=0.0, step=0.05, label="presence_penalty")
+                 frequency_penalty = gr.Slider(-2.0, 2.0, value=0.0, step=0.05, label="frequency_penalty")
+             gr.Markdown("""
+             **Security notes**
+             - Your API key and system prompt are kept on the server as environment variables.
+             - They are never shown in the UI or sent to the browser.
+             - Change the model id with `FIREWORKS_MODEL_ID` (env var).
+             """)
+             clear_btn = gr.Button("Clear", variant="secondary")
+     chat = gr.ChatInterface(
+         fn=chat_fn,
+         chatbot=chatbot,
+         additional_inputs=[max_tokens, temperature, top_p, top_k, presence_penalty, frequency_penalty],
+         title=None,
+         retry_btn=None,
+         undo_btn="Undo last",
+         clear_btn=None,
+         submit_btn="Send",
+         autofocus=True,
+         fill_height=False,
+         cache_examples=False,
+         concurrency_limit=10,
+         multimodal=False,
+         analytics_enabled=False,
+         examples=["Hello!", "Summarize: Why is retrieval-augmented generation useful for insurers?", "Write a 3-bullet status update for the Palmyra team."],
+         description="Start chatting below. Streaming is enabled.",
+     )
+     clear_btn.click(fn=clear_history, outputs=chatbot)
+
+ if __name__ == "__main__":
+     # Use 0.0.0.0 for container friendliness; set GRADIO_SERVER_PORT externally if needed
+     demo.queue().launch(server_name="0.0.0.0")
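
For reference, `_fireworks_stream` assumes the endpoint emits OpenAI-style SSE events, one JSON object per `data:` line. A minimal sketch of parsing a single event (the payload below is illustrative, not a captured Fireworks response):

import json

# Illustrative SSE line in the OpenAI-compatible streaming format
# (assumed shape, not captured from the live API)
line = 'data: {"choices": [{"index": 0, "delta": {"content": "Hello"}}]}'

if line.startswith("data:"):
    data = line[len("data:"):].strip()
    if data != "[DONE]":
        obj = json.loads(data)
        delta = obj["choices"][0]["delta"]
        print(delta.get("content", ""))  # -> Hello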
requirements.txt ADDED
@@ -0,0 +1,3 @@
+
+ gradio>=4.36.1
+ requests>=2.31.0
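
To smoke-test the key and model id before launching the UI, a minimal non-streaming call against the same endpoint can help. This is a sketch that reuses the app's environment variables and assumes the standard OpenAI-compatible response shape:

import os
import requests

resp = requests.post(
    "https://api.fireworks.ai/inference/v1/chat/completions",
    headers={
        "Content-Type": "application/json",
        "Authorization": f"Bearer {os.environ['FIREWORKS_API_KEY']}",
    },
    json={
        "model": os.getenv("FIREWORKS_MODEL_ID", "accounts/waseem-9b447b/models/ft-gdixl08u-sz53t"),
        "max_tokens": 64,
        "messages": [{"role": "user", "content": "Hello!"}],
    },
    timeout=30,
)
resp.raise_for_status()
# Non-streaming responses carry the whole message at once (assumed OpenAI-compatible shape)
print(resp.json()["choices"][0]["message"]["content"])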