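"""End-to-end tests for the FastAPI chat server in `main`.

A FakeEngine is patched in place of the real model so the API surface
(health, chat completions, SSE streaming, cancel) can be exercised quickly.
Assuming this file follows pytest's test_*.py naming convention, it can be
run with: pytest -q
"""
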
import time
from contextlib import contextmanager

from fastapi.testclient import TestClient

import main


class FakeEngine:
    def __init__(self, model_id="fake-model"):
        self.model_id = model_id
        self.last_context_info = {
            "compressed": False,
            "prompt_tokens": 5,
            "max_context": 8192,
            "budget": 7900,
            "strategy": "truncate",
            "dropped_messages": 0,
        }

    def infer(self, messages, max_tokens, temperature):
        # Simulate parse error pathway when special trigger is present
        if messages and isinstance(messages[0].get("content"), str) and "PARSE_ERR" in messages[0]["content"]:
            raise ValueError("Simulated parse error")
        # Return echo content for deterministic test
        parts = []
        for m in messages:
            c = m.get("content", "")
            if isinstance(c, list):
                for p in c:
                    if isinstance(p, dict) and p.get("type") == "text":
                        parts.append(p.get("text", ""))
            elif isinstance(c, str):
                parts.append(c)
        txt = " ".join(parts) or "OK"
        # Simulate context accounting changing with request
        self.last_context_info = {
            "compressed": False,
            "prompt_tokens": max(1, len(txt.split())),
            "max_context": 8192,
            "budget": 7900,
            "strategy": "truncate",
            "dropped_messages": 0,
        }
        return f"OK: {txt}"

    def infer_stream(self, messages, max_tokens, temperature, cancel_event=None):
        # simple two-piece stream; respects cancel_event if set during streaming
        outputs = ["hello", " world"]
        for piece in outputs:
            if cancel_event is not None and cancel_event.is_set():
                break
            yield piece
            # tiny delay to allow cancel test to interleave
            time.sleep(0.01)

    def get_context_report(self):
        return {
            "compressionEnabled": True,
            "strategy": "truncate",
            "safetyMargin": 256,
            "modelMaxContext": 8192,
            "tokenizerModelMaxLength": 8192,
            "last": self.last_context_info,
        }
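

# Sanity check for the test double itself (a small added check, exercising
# only code defined in this file): infer() echoes text content and updates
# last_context_info's token accounting.
def test_fake_engine_echo_and_token_accounting():
    eng = FakeEngine()
    out = eng.infer([{"role": "user", "content": "two words"}], max_tokens=4, temperature=0.0)
    assert out == "OK: two words"
    assert eng.last_context_info["prompt_tokens"] == 2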


@contextmanager
def patched_engine():
    # Patch global engine so server does not load real model
    prev_engine = main._engine
    prev_err = main._engine_error
    fake = FakeEngine()
    main._engine = fake
    main._engine_error = None
    try:
        yield fake
    finally:
        main._engine = prev_engine
        main._engine_error = prev_err


def get_client():
    return TestClient(main.app)


def test_health_ready_and_context():
    with patched_engine():
        client = get_client()
        r = client.get("/health")
        assert r.status_code == 200
        body = r.json()
        assert body["ok"] is True
        assert body["modelReady"] is True
        assert body["modelId"] == "fake-model"
        # context block exists with required fields
        ctx = body["context"]
        assert ctx["compressionEnabled"] is True
        assert "last" in ctx
        assert isinstance(ctx["last"].get("prompt_tokens"), int)


def test_health_with_engine_error():
    # simulate model load error path
    prev_engine = main._engine
    prev_err = main._engine_error
    try:
        main._engine = None
        main._engine_error = "boom"
        client = get_client()
        r = client.get("/health")
        assert r.status_code == 200
        body = r.json()
        assert body["modelReady"] is False
        assert body["error"] == "boom"
    finally:
        main._engine = prev_engine
        main._engine_error = prev_err


def test_chat_non_stream_validation():
    with patched_engine():
        client = get_client()
        # an empty messages list should yield a 400
        r = client.post("/v1/chat/completions", json={"messages": []})
        assert r.status_code == 400


def test_chat_non_stream_success_and_usage_context():
    with patched_engine():
        client = get_client()
        payload = {
            "messages": [{"role": "user", "content": "Hello Qwen"}],
            "max_tokens": 8,
            "temperature": 0.0,
        }
        r = client.post("/v1/chat/completions", json=payload)
        assert r.status_code == 200
        body = r.json()
        assert body["object"] == "chat.completion"
        assert body["choices"][0]["message"]["content"].startswith("OK:")
        # usage prompt_tokens filled from engine.last_context_info
        assert body["usage"]["prompt_tokens"] >= 1
        # response includes context echo
        assert "context" in body
        assert "prompt_tokens" in body["context"]


def test_chat_non_stream_parse_error_to_400():
    with patched_engine():
        client = get_client()
        payload = {
            "messages": [{"role": "user", "content": "PARSE_ERR trigger"}],
            "max_tokens": 4,
        }
        r = client.post("/v1/chat/completions", json=payload)
        # ValueError in engine -> 400 per API contract
        assert r.status_code == 400


def read_sse_lines(resp):
    # Utility to parse an event-stream into a list of data payloads (including [DONE])
    lines = []
    buf = b""
    # Starlette TestClient (httpx) responses expose iter_bytes()/iter_raw(),
    # not requests' iter_content(). Fall back to an available iterator, or to
    # the full content if streaming isn't supported.
    iterator = None
    for name in ("iter_bytes", "iter_raw", "iter_content"):
        it = getattr(resp, name, None)
        if callable(it):
            iterator = it
            break
    if iterator is None:
        data = getattr(resp, "content", b"")
        if isinstance(data, str):
            data = data.encode("utf-8", "ignore")
        buf = data
    else:
        for chunk in iterator():
            if not chunk:
                continue
            if isinstance(chunk, str):
                chunk = chunk.encode("utf-8", "ignore")
            buf += chunk
    while b"\n\n" in buf:
        frame, buf = buf.split(b"\n\n", 1)
        # keep original frame text for asserts
        lines.append(frame.decode("utf-8", errors="ignore"))
    # Drain any leftover
    if buf:
        lines.append(buf.decode("utf-8", errors="ignore"))
    return lines
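

# Focused check of read_sse_lines itself, using a hypothetical minimal
# response stand-in (_FakeSSEResp is not part of main; it only supplies the
# iter_bytes() hook that read_sse_lines probes for).
class _FakeSSEResp:
    def __init__(self, chunks):
        self._chunks = chunks

    def iter_bytes(self):
        yield from self._chunks


def test_read_sse_lines_splits_frames():
    # Frames may be split across chunks; the parser reassembles on "\n\n".
    resp = _FakeSSEResp([b'data: {"x": 1}\n\ndata: [DO', b"NE]\n\n"])
    frames = read_sse_lines(resp)
    assert frames[0] == 'data: {"x": 1}'
    assert frames[1] == "data: [DONE]"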


def test_chat_stream_sse_flow_and_resume():
    with patched_engine():
        client = get_client()
        payload = {
            "session_id": "s1",
            "stream": True,
            "messages": [{"role": "user", "content": "stream please"}],
            "max_tokens": 8,
            "temperature": 0.2,
        }
        with client.stream("POST", "/v1/chat/completions", json=payload) as resp:
            assert resp.status_code == 200
            lines = read_sse_lines(resp)
            # Must contain role delta, content pieces, finish chunk, and [DONE]
            joined = "\n".join(lines)
            assert "delta" in joined
            assert "[DONE]" in joined
        # Resume from event index 0 should receive at least one subsequent event
        headers = {"Last-Event-ID": "s1:0"}
        with client.stream("POST", "/v1/chat/completions", headers=headers, json=payload) as resp2:
            assert resp2.status_code == 200
            lines2 = read_sse_lines(resp2)
            assert any("data:" in frame for frame in lines2)
            assert "[DONE]" in "\n".join(lines2)
        # Invalid Last-Event-ID format should not crash (covered by try/except)
        headers_bad = {"Last-Event-ID": "not-an-index"}
        with client.stream("POST", "/v1/chat/completions", headers=headers_bad, json=payload) as resp3:
            assert resp3.status_code == 200
            _ = read_sse_lines(resp3)  # just ensure no crash


def test_cancel_endpoint_stops_generation():
    with patched_engine():
        client = get_client()
        payload = {
            "session_id": "to-cancel",
            "stream": True,
            "messages": [{"role": "user", "content": "cancel me"}],
        }
        # Open the stream without draining it (client.stream keeps the
        # response open until the body is read)
        with client.stream("POST", "/v1/chat/completions", json=payload) as resp:
            # Cancel immediately, before reading the stream
            rc = client.post("/v1/cancel/to-cancel")
            assert rc.status_code == 200
            # Stream should end with [DONE] without hanging
            lines = read_sse_lines(resp)
            assert "[DONE]" in "\n".join(lines)


def test_cancel_unknown_session_is_ok():
    with patched_engine():
        client = get_client()
        rc = client.post("/v1/cancel/does-not-exist")
        # Endpoint returns ok regardless (idempotent, operationally safe)
        assert rc.status_code == 200


def test_edge_large_last_event_id_after_finish_yields_done():
    with patched_engine():
        client = get_client()
        payload = {
            "session_id": "done-session",
            "stream": True,
            "messages": [{"role": "user", "content": "edge"}],
        }
        # Complete a run
        with client.stream("POST", "/v1/chat/completions", json=payload) as resp:
            _ = read_sse_lines(resp)
        # Resume with a huge index; should return [DONE] quickly
        headers = {"Last-Event-ID": "done-session:99999"}
        with client.stream("POST", "/v1/chat/completions", headers=headers, json=payload) as resp2:
            lines2 = read_sse_lines(resp2)
            assert "[DONE]" in "\n".join(lines2)