import os
import io
import json
import uuid
import base64
import time
import random
import math
from typing import List, Dict, Tuple, Optional

import gradio as gr
import spaces

try:
    from ollama import Client
except Exception as e:
    raise RuntimeError(
        "Failed to import the 'ollama' Python client. Ensure it's in requirements.txt."
    ) from e
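
# The imports above imply (roughly) the following requirements.txt entries; the exact
# set and any version pins depend on your Space and are only a suggestion:
#
#   gradio
#   spaces
#   ollama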

DEFAULT_PORT = int(os.getenv("PORT", 7860))
DEFAULT_OLLAMA_HOST = os.getenv("OLLAMA_HOST", "").strip() or os.getenv("OLLAMA_BASE_URL", "").strip() or ""
DEFAULT_MODEL = os.getenv("OLLAMA_MODEL", "llama3.1")
APP_TITLE = "Ollama Chat (Gradio + Docker)"
APP_DESCRIPTION = """
A lightweight, fully functional chat UI for Ollama, designed to run on Hugging Face Spaces (Docker).
- Bring your own Ollama host (set OLLAMA_HOST in repo secrets or via the UI).
- Streamed responses, model management (list/pull), and basic vision support (image input).
- Compatible with Spaces ZeroGPU via a @spaces.GPU-decorated function (see GPU Tools panel).
"""


def ensure_scheme(host: str) -> str:
    """Normalize a host string: add an http:// scheme if missing and strip trailing slashes."""
    if not host:
        return host
    host = host.strip()
    if not host.startswith(("http://", "https://")):
        host = "http://" + host

    # Drop trailing slashes so the client gets a clean base URL.
    while host.endswith("/"):
        host = host[:-1]
    return host


def get_client(host: str) -> Client:
    """Create an Ollama client for the given host, falling back to the client's own env default when blank."""
    host = ensure_scheme(host)
    if not host:
        # No host provided: let the ollama client use its default/env configuration.
        return Client()
    return Client(host=host)


def list_models(host: str) -> Tuple[List[str], Optional[str]]:
    try:
        client = get_client(host)
        data = client.list()
        names = sorted(m.get("name", "") for m in data.get("models", []) if m.get("name"))
        return names, None
    except Exception as e:
        return [], f"Unable to list models from {host or '(env default)'}: {e}"


def test_connection(host: str) -> Tuple[bool, str]:
    names, err = list_models(host)
    if err:
        return False, err
    if not names:
        return True, f"Connected to {host or '(env default)'} but no models found. Pull one to continue."
    return True, f"Connected to {host or '(env default)'}; found {len(names)} models."


def show_model(host: str, model: str) -> Tuple[Optional[dict], Optional[str]]:
    try:
        client = get_client(host)
        info = client.show(model=model)
        return info, None
    except Exception as e:
        return None, f"Unable to show model '{model}': {e}"


def pull_model(host: str, model: str):
    """
    Generator that pulls a model on the remote Ollama host, yielding progress strings.
    """
    if not model:
        yield "Provide a model name to pull (e.g., llama3.1, mistral, qwen2.5:latest)"
        return
    try:
        client = get_client(host)
        already, _ = show_model(host, model)
        if already:
            yield f"Model '{model}' already present on the host."
            return

        yield f"Pulling '{model}' from registry..."
        for part in client.pull(model=model, stream=True):
            status = part.get("status", "")
            total = part.get("total", 0)
            completed = part.get("completed", 0)
            pct = f"{(completed / total * 100):.1f}%" if total else ""
            line = status
            if pct:
                line += f" ({pct})"
            yield line
        yield f"Finished pulling '{model}'."
    except Exception as e:
        yield f"Error pulling '{model}': {e}"


def encode_image_to_base64(path: str) -> Optional[str]:
    try:
        with open(path, "rb") as f:
            return base64.b64encode(f.read()).decode("utf-8")
    except Exception:
        return None


def build_ollama_messages(
    system_prompt: str,
    convo_messages: List[Dict],
    user_text: str,
    image_paths: Optional[List[str]] = None,
) -> List[Dict]:
    """
    Returns the full message list to send to Ollama, including system prompt (if provided),
    past conversation, and the new user message.
    """
    messages = []
    if system_prompt.strip():
        messages.append({"role": "system", "content": system_prompt.strip()})

    messages.extend(convo_messages or [])

    msg: Dict = {"role": "user", "content": user_text or ""}
    if image_paths:
        images_b64 = []
        for p in image_paths:
            b64 = encode_image_to_base64(p)
            if b64:
                images_b64.append(b64)
        if images_b64:
            msg["images"] = images_b64
    messages.append(msg)
    return messages
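
# Illustrative example of the list this helper produces (values are made up):
#
#   [
#       {"role": "system", "content": "You are a helpful assistant."},
#       {"role": "user", "content": "Hi"},
#       {"role": "assistant", "content": "Hello!"},
#       {"role": "user", "content": "What is in this image?", "images": ["<base64 data>"]},
#   ]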


def messages_for_chatbot(
    text: str,
    image_paths: Optional[List[str]] = None,
    role: str = "user",
) -> Dict:
    """
    Build a Gradio Chatbot message in "messages" mode:
    {"role": "user"|"assistant", "content": [{"type": "text", "text": ...}, {"type": "image", "image": <file path>}, ...]}
    """
    content = []
    t = (text or "").strip()
    if t:
        content.append({"type": "text", "text": t})

    if image_paths:
        # Image entries carry the file path as-is; no decoding is done here.
        for p in image_paths:
            content.append({"type": "image", "image": p})
    return {"role": role, "content": content if content else [{"type": "text", "text": ""}]}


def stream_chat(
    host: str,
    model: str,
    system_prompt: str,
    temperature: float,
    top_p: float,
    top_k: int,
    repeat_penalty: float,
    num_ctx: int,
    max_tokens: Optional[int],
    seed: Optional[int],
    convo_messages: List[Dict],
    chatbot_history: List[Dict],
    user_text: str,
    image_files: Optional[List[str]],
):
    """
    Stream a chat completion from Ollama and update the Gradio Chatbot incrementally.
    Yields (chatbot_history, status_markdown, convo_messages) tuples.
    """
    # Append the user's message to the visible chat history.
    user_msg_for_bot = messages_for_chatbot(user_text, image_files, role="user")
    chatbot_history = chatbot_history + [user_msg_for_bot]

    # Build the payload for the Ollama API (system prompt + prior turns + new message).
    ollama_messages = build_ollama_messages(system_prompt, convo_messages, user_text, image_files)

    options = {
        "temperature": temperature,
        "top_p": top_p,
        "top_k": top_k,
        "repeat_penalty": repeat_penalty,
        "num_ctx": num_ctx,
    }
    if max_tokens is not None and max_tokens > 0:
        options["num_predict"] = max_tokens
    if seed is not None:
        options["seed"] = seed

    client = get_client(host)
    assistant_text_accum = ""
    start_time = time.time()

    # Add an empty assistant placeholder that gets filled in as tokens stream back.
    assistant_msg_for_bot = messages_for_chatbot("", None, role="assistant")
    chatbot_history = chatbot_history + [assistant_msg_for_bot]
    status_md = f"Model: {model} | Host: {ensure_scheme(host) or '(env default)'} | Streaming..."

    yield chatbot_history, status_md, convo_messages

    try:
        for part in client.chat(
            model=model,
            messages=ollama_messages,
            stream=True,
            options=options,
        ):
            msg = part.get("message", {}) or {}
            delta = msg.get("content", "")
            if delta:
                assistant_text_accum += delta
                chatbot_history[-1] = messages_for_chatbot(assistant_text_accum, None, role="assistant")

            done = part.get("done", False)
            if done:
                eval_count = part.get("eval_count", 0)
                prompt_eval_count = part.get("prompt_eval_count", 0)
                total = time.time() - start_time
                tok_s = (eval_count / total) if total > 0 else 0.0
                status_md = (
                    f"Model: {model} | Host: {ensure_scheme(host) or '(env default)'} | "
                    f"Prompt tokens: {prompt_eval_count} | Output tokens: {eval_count} | "
                    f"Time: {total:.2f}s | Speed: {tok_s:.1f} tok/s"
                )
            yield chatbot_history, status_md, convo_messages

        # Persist the completed turn in the Ollama-format conversation state.
        user_images = [b for b in (encode_image_to_base64(p) for p in (image_files or [])) if b]
        user_record: Dict = {"role": "user", "content": user_text or ""}
        if user_images:
            user_record["images"] = user_images
        convo_messages = convo_messages + [
            user_record,
            {"role": "assistant", "content": assistant_text_accum},
        ]

        yield chatbot_history, status_md, convo_messages

    except Exception as e:
        err_msg = f"Error during generation: {e}"
        chatbot_history[-1] = messages_for_chatbot(err_msg, None, role="assistant")
        yield chatbot_history, err_msg, convo_messages


def clear_conversation():
    return [], [], ""


def export_conversation(history: List[Dict], convo_messages: List[Dict]) -> Tuple[str, str]:
    export_blob = {
        "chat_messages": history,
        "ollama_messages": convo_messages,
        "meta": {
            "title": APP_TITLE,
            "exported_at": time.strftime("%Y-%m-%d %H:%M:%S", time.gmtime()),
            "version": "1.1",
        },
    }
    path = f"chat_export_{int(time.time())}.json"
    with open(path, "w", encoding="utf-8") as f:
        json.dump(export_blob, f, ensure_ascii=False, indent=2)
    return path, f"Exported {len(history)} messages to {path}"
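
# The exported JSON looks roughly like this (contents depend on the conversation):
#
#   {
#     "chat_messages": [...Gradio-format messages...],
#     "ollama_messages": [...Ollama-format messages...],
#     "meta": {"title": "...", "exported_at": "<UTC timestamp>", "version": "1.1"}
#   }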


@spaces.GPU
def gpu_ping(workload: int = 256) -> dict:
    """
    Minimal function to satisfy the ZeroGPU Spaces requirement and optionally exercise the GPU.
    Always runs a small CPU loop; if torch with CUDA is available, it also performs a tiny matmul on the GPU.
    """
    t0 = time.time()

    # CPU workload: accumulate some trig results so the loop does real work.
    acc = 0.0
    for i in range(max(1, workload)):
        x = random.random() * 1000.0
        s = math.sin(x)
        c = math.cos(x)
        t = math.tan(x) if abs(math.cos(x)) > 1e-9 else 1.0
        acc += s * c / t

    info = {"mode": "cpu", "ops": workload}

    try:
        import torch
        if torch.cuda.is_available():
            a = torch.randn((256, 256), device="cuda")
            b = torch.mm(a, a)
            _ = float(b.mean().item())
            info["mode"] = "cuda"
            info["device"] = torch.cuda.get_device_name(torch.cuda.current_device())
            info["cuda"] = True
        else:
            info["cuda"] = False
    except Exception:
        # torch not installed or CUDA not usable.
        info["cuda"] = "unavailable"

    elapsed = time.time() - t0
    return {"ok": True, "elapsed_s": round(elapsed, 4), "acc_checksum": float(acc % 1.0), "info": info}


def ui() -> gr.Blocks:
    with gr.Blocks(title=APP_TITLE, theme=gr.themes.Soft()) as demo:
        gr.Markdown(f"# {APP_TITLE}")
        gr.Markdown(APP_DESCRIPTION)

        state_convo = gr.State([])
        state_history = gr.State([])
        state_system_prompt = gr.State("")
        state_host = gr.State(DEFAULT_OLLAMA_HOST)
        state_session = gr.State(str(uuid.uuid4()))

        with gr.Row():
            with gr.Column(scale=3):
                chatbot = gr.Chatbot(label="Chat", type="messages", height=520, avatar_images=(None, None))
                with gr.Row():
                    txt = gr.Textbox(
                        label="Your message",
                        placeholder="Ask anything...",
                        autofocus=True,
                        scale=4,
                    )
                    image_files = gr.Files(
                        label="Optional image(s)",
                        file_types=["image"],
                        type="filepath",
                        visible=True,
                    )
                with gr.Row():
                    send_btn = gr.Button("Send", variant="primary")
                    stop_btn = gr.Button("Stop")
                    clear_btn = gr.Button("Clear")
                    export_btn = gr.Button("Export")

                status = gr.Markdown("Ready.", elem_id="status_box")

            with gr.Column(scale=2):
                gr.Markdown("## Connection")
                host_in = gr.Textbox(
                    label="Ollama Host URL",
                    placeholder="http://127.0.0.1:11434 (or leave blank to use server env OLLAMA_HOST)",
                    value=DEFAULT_OLLAMA_HOST,
                )
                with gr.Row():
                    test_btn = gr.Button("Test Connection")
                    refresh_models_btn = gr.Button("Refresh Models")

                models_dd = gr.Dropdown(
                    choices=[],
                    value=None,
                    label="Model",
                    allow_custom_value=True,
                    info="Select a model from the server or type a name (e.g., llama3.1, mistral, phi4:latest)",
                )
                pull_model_txt = gr.Textbox(
                    label="Pull Model (by name)",
                    placeholder="e.g., llama3.1, mistral, qwen2.5:latest",
                )
                pull_btn = gr.Button("Pull Model")
                pull_log = gr.Textbox(label="Pull Progress", interactive=False, lines=6)

                gr.Markdown("## System Prompt")
                sys_prompt = gr.Textbox(
                    label="System Prompt",
                    placeholder="You are a helpful assistant...",
                    lines=4,
                    value=os.getenv("SYSTEM_PROMPT", ""),
                )

                gr.Markdown("## Generation Settings")
                with gr.Row():
                    temperature = gr.Slider(0.0, 2.0, value=0.7, step=0.05, label="Temperature")
                    top_p = gr.Slider(0.0, 1.0, value=0.9, step=0.01, label="Top-p")
                with gr.Row():
                    top_k = gr.Slider(0, 200, value=40, step=1, label="Top-k")
                    repeat_penalty = gr.Slider(0.0, 2.0, value=1.1, step=0.01, label="Repeat Penalty")
                with gr.Row():
                    num_ctx = gr.Slider(256, 8192, value=4096, step=256, label="Context Window (num_ctx)")
                    max_tokens = gr.Slider(0, 8192, value=0, step=16, label="Max New Tokens (0 = auto)")
                seed = gr.Number(value=None, label="Seed (optional)", precision=0)

                gr.Markdown("## GPU Tools (ZeroGPU compatible)")
                with gr.Row():
                    gpu_workload = gr.Slider(64, 4096, value=256, step=64, label="GPU Ping Workload")
                with gr.Row():
                    gpu_btn = gr.Button("Run GPU Ping")
                    gpu_out = gr.Textbox(label="GPU Ping Result", lines=6, interactive=False)

        def _on_load():
            # Populate the model dropdown and status message on page load.
            host = DEFAULT_OLLAMA_HOST
            names, err = list_models(host)
            if err:
                status_msg = f"Note: {err}"
            else:
                status_msg = f"Loaded {len(names)} models from {ensure_scheme(host) or '(env default)'}."

            value = DEFAULT_MODEL if DEFAULT_MODEL in names else (names[0] if names else None)
            return (
                gr.update(choices=names, value=value),
                host,
                status_msg,
                [], [], "",
            )

        load_outputs = [
            models_dd,
            host_in,
            status,
            state_history, state_convo, state_system_prompt,
        ]
        demo.load(_on_load, outputs=load_outputs)

        def set_host(h):
            return ensure_scheme(h)

        host_in.change(set_host, inputs=host_in, outputs=state_host)

        def _test(h, current_model):
            ok, msg = test_connection(h)
            names, err = list_models(h) if ok else ([], None)
            # Keep the currently selected model if it still exists on the host.
            model_val = current_model if ok and current_model in names else (names[0] if names else None)
            if err:
                msg += f"\nAlso: {err}"
            return gr.update(choices=names, value=model_val), msg

        test_btn.click(_test, inputs=[host_in, models_dd], outputs=[models_dd, status])

        refresh_models_btn.click(_test, inputs=[host_in, models_dd], outputs=[models_dd, status])

        def _pull(h, name):
            if not name:
                yield "Please enter a model name to pull."
                return
            for line in pull_model(h, name.strip()):
                yield line

        pull_btn.click(_pull, inputs=[host_in, pull_model_txt], outputs=pull_log)

        clear_btn.click(clear_conversation, outputs=[chatbot, state_convo, status])

        export_file = gr.File(label="Download Conversation", visible=True)
        export_btn.click(export_conversation, inputs=[state_history, state_convo], outputs=[export_file, status])

        def _submit(
            h, m, sp, t, tp, tk, rp, ctx, mx, sd, convo, history, text, files
        ):
            # Normalize optional numeric inputs before delegating to the streaming generator.
            mx_int = int(mx) if mx and int(mx) > 0 else None
            sd_int = int(sd) if sd is not None else None
            yield from stream_chat(
                host=h,
                model=m or DEFAULT_MODEL,
                system_prompt=sp or "",
                temperature=float(t),
                top_p=float(tp),
                top_k=int(tk),
                repeat_penalty=float(rp),
                num_ctx=int(ctx),
                max_tokens=mx_int,
                seed=sd_int,
                convo_messages=convo,
                chatbot_history=history,
                user_text=text,
                image_files=files,
            )

        submit_inputs = [
            host_in, models_dd, sys_prompt, temperature, top_p, top_k, repeat_penalty,
            num_ctx, max_tokens, seed, state_convo, state_history, txt, image_files,
        ]
        submit_event = send_btn.click(
            _submit,
            inputs=submit_inputs,
            outputs=[chatbot, status, state_convo],
        )

        txt_submit_event = txt.submit(
            _submit,
            inputs=submit_inputs,
            outputs=[chatbot, status, state_convo],
        )

        # Stop cancels both the button-click and the Enter-key submissions.
        stop_btn.click(None, None, None, cancels=[submit_event, txt_submit_event])

        def _post_send():
            return "", None

        send_btn.click(_post_send, outputs=[txt, image_files])
        txt.submit(_post_send, outputs=[txt, image_files])

        def _sync_chatbot_state(history):
            return history

        chatbot.change(_sync_chatbot_state, inputs=chatbot, outputs=state_history)

        def _gpu_ping_ui(n):
            try:
                res = gpu_ping(int(n))
                try:
                    return json.dumps(res, indent=2)
                except Exception:
                    return str(res)
            except Exception as e:
                return f"GPU ping failed: {e}"

        gpu_btn.click(_gpu_ping_ui, inputs=[gpu_workload], outputs=[gpu_out])

    return demo


if __name__ == "__main__":
    demo = ui()
    demo.queue(default_concurrency_limit=10)
    demo.launch(server_name="0.0.0.0", server_port=DEFAULT_PORT, show_api=True, ssr_mode=False)
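
# Example of launching locally (assuming this file is saved as app.py; the filename and
# host value below are illustrative, not requirements):
#
#   OLLAMA_HOST=http://127.0.0.1:11434 python app.py
#
# The server then listens on 0.0.0.0:$PORT (7860 by default).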