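"""Tests for the FastAPI chat server in main.py.

The suite covers /health, /v1/chat/completions (non-streaming and SSE
streaming), stream resume via Last-Event-ID, and /v1/cancel/{session_id}.
The real inference engine is replaced with FakeEngine so no model weights
are ever loaded. Run with pytest, e.g. `pytest -q`.
"""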
import json
import time
from contextlib import contextmanager

import pytest
from fastapi.testclient import TestClient

import main
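

# FakeEngine mirrors the engine surface that main.py is expected to call:
# model_id, last_context_info, infer(), infer_stream(), and
# get_context_report(). Only the behavior the tests depend on is simulated.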
class FakeEngine:
    """Stand-in engine so the server under test never loads real weights."""

    def __init__(self, model_id="fake-model"):
        self.model_id = model_id
        self.last_context_info = {
            "compressed": False,
            "prompt_tokens": 5,
            "max_context": 8192,
            "budget": 7900,
            "strategy": "truncate",
            "dropped_messages": 0,
        }

    def infer(self, messages, max_tokens, temperature):
        # Simulate the parse-error pathway when the special trigger is present
        if messages and isinstance(messages[0].get("content"), str) and "PARSE_ERR" in messages[0]["content"]:
            raise ValueError("Simulated parse error")
        # Echo the prompt text back so tests are deterministic
        parts = []
        for m in messages:
            c = m.get("content", "")
            if isinstance(c, list):
                for p in c:
                    if isinstance(p, dict) and p.get("type") == "text":
                        parts.append(p.get("text", ""))
            elif isinstance(c, str):
                parts.append(c)
        txt = " ".join(parts) or "OK"
        # Simulate context accounting changing with the request
        self.last_context_info = {
            "compressed": False,
            "prompt_tokens": max(1, len(txt.split())),
            "max_context": 8192,
            "budget": 7900,
            "strategy": "truncate",
            "dropped_messages": 0,
        }
        return f"OK: {txt}"

    def infer_stream(self, messages, max_tokens, temperature, cancel_event=None):
        # Simple two-piece stream; respects cancel_event if it is set mid-stream
        outputs = ["hello", " world"]
        for piece in outputs:
            if cancel_event is not None and cancel_event.is_set():
                break
            yield piece
            # Tiny delay so the cancel test can interleave
            time.sleep(0.01)

    def get_context_report(self):
        return {
            "compressionEnabled": True,
            "strategy": "truncate",
            "safetyMargin": 256,
            "modelMaxContext": 8192,
            "tokenizerModelMaxLength": 8192,
            "last": self.last_context_info,
        }


@contextmanager
def patched_engine():
    # Patch the global engine so the server never loads a real model
    prev_engine = main._engine
    prev_err = main._engine_error
    fake = FakeEngine()
    main._engine = fake
    main._engine_error = None
    try:
        yield fake
    finally:
        main._engine = prev_engine
        main._engine_error = prev_err
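

# NOTE: patched_engine() yields the FakeEngine instance, so a test that needs
# to inspect engine state can write `with patched_engine() as fake:` and read
# fake.last_context_info after a request completes.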
def get_client():
    # A fresh TestClient per call keeps the tests independent
    return TestClient(main.app)


def test_health_ready_and_context():
    with patched_engine():
        client = get_client()
        r = client.get("/health")
        assert r.status_code == 200
        body = r.json()
        assert body["ok"] is True
        assert body["modelReady"] is True
        assert body["modelId"] == "fake-model"
        # The context block exists with the required fields
        ctx = body["context"]
        assert ctx["compressionEnabled"] is True
        assert "last" in ctx
        assert isinstance(ctx["last"].get("prompt_tokens"), int)


def test_health_with_engine_error():
    # Simulate the model-load-error path
    prev_engine = main._engine
    prev_err = main._engine_error
    try:
        main._engine = None
        main._engine_error = "boom"
        client = get_client()
        r = client.get("/health")
        assert r.status_code == 200
        body = r.json()
        assert body["modelReady"] is False
        assert body["error"] == "boom"
    finally:
        main._engine = prev_engine
        main._engine_error = prev_err


def test_chat_non_stream_validation():
    with patched_engine():
        client = get_client()
        # An empty messages list should yield a 400
        r = client.post("/v1/chat/completions", json={"messages": []})
        assert r.status_code == 400


def test_chat_non_stream_success_and_usage_context():
    with patched_engine():
        client = get_client()
        payload = {
            "messages": [{"role": "user", "content": "Hello Qwen"}],
            "max_tokens": 8,
            "temperature": 0.0,
        }
        r = client.post("/v1/chat/completions", json=payload)
        assert r.status_code == 200
        body = r.json()
        assert body["object"] == "chat.completion"
        assert body["choices"][0]["message"]["content"].startswith("OK:")
        # usage.prompt_tokens is filled from engine.last_context_info
        assert body["usage"]["prompt_tokens"] >= 1
        # The response echoes the context accounting back
        assert "context" in body
        assert "prompt_tokens" in body["context"]


def test_chat_non_stream_parse_error_to_400():
    with patched_engine():
        client = get_client()
        payload = {
            "messages": [{"role": "user", "content": "PARSE_ERR trigger"}],
            "max_tokens": 4,
        }
        r = client.post("/v1/chat/completions", json=payload)
        # A ValueError in the engine maps to 400 per the API contract
        assert r.status_code == 400
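

# The streaming endpoint is expected to emit SSE frames separated by blank
# lines and terminated by a literal "data: [DONE]" frame, e.g.:
#
#   data: {"choices": [{"delta": {"content": "hello"}}]}
#
#   data: [DONE]
#
# The exact chunk schema is assumed from the asserts below; read_sse_lines
# only splits on the "\n\n" frame boundary and keeps each frame's raw text.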
def read_sse_lines(resp):
    # Parse an event-stream response into a list of frame payloads (including [DONE])
    lines = []
    buf = b""
    # Starlette's TestClient (httpx) responses expose iter_bytes()/iter_raw(),
    # not requests' iter_content(). Fall back to whichever iterator is
    # available, or to the full body if streaming isn't supported.
    iterator = None
    for name in ("iter_bytes", "iter_raw", "iter_content"):
        it = getattr(resp, name, None)
        if callable(it):
            iterator = it
            break
    if iterator is None:
        data = getattr(resp, "content", b"")
        if isinstance(data, str):
            data = data.encode("utf-8", "ignore")
        buf = data
    else:
        for chunk in iterator():
            if not chunk:
                continue
            if isinstance(chunk, str):
                chunk = chunk.encode("utf-8", "ignore")
            buf += chunk
    while b"\n\n" in buf:
        frame, buf = buf.split(b"\n\n", 1)
        # Keep the original frame text for the asserts
        lines.append(frame.decode("utf-8", errors="ignore"))
    # Drain any leftover partial frame
    if buf:
        lines.append(buf.decode("utf-8", errors="ignore"))
    return lines


def test_chat_stream_sse_flow_and_resume():
    with patched_engine():
        client = get_client()
        payload = {
            "session_id": "s1",
            "stream": True,
            "messages": [{"role": "user", "content": "stream please"}],
            "max_tokens": 8,
            "temperature": 0.2,
        }
        with client.stream("POST", "/v1/chat/completions", json=payload) as resp:
            assert resp.status_code == 200
            lines = read_sse_lines(resp)
        # Must contain a role delta, content pieces, a finish chunk, and [DONE]
        joined = "\n".join(lines)
        assert "delta" in joined
        assert "[DONE]" in joined
        # Resuming from event index 0 should replay at least one subsequent event
        headers = {"Last-Event-ID": "s1:0"}
        with client.stream("POST", "/v1/chat/completions", headers=headers, json=payload) as resp2:
            assert resp2.status_code == 200
            lines2 = read_sse_lines(resp2)
        assert any("data:" in line for line in lines2)
        assert "[DONE]" in "\n".join(lines2)
        # An invalid Last-Event-ID format should not crash (covered by try/except)
        headers_bad = {"Last-Event-ID": "not-an-index"}
        with client.stream("POST", "/v1/chat/completions", headers=headers_bad, json=payload) as resp3:
            assert resp3.status_code == 200
            _ = read_sse_lines(resp3)  # just ensure no crash


def test_cancel_endpoint_stops_generation():
    with patched_engine():
        client = get_client()
        payload = {
            "session_id": "to-cancel",
            "stream": True,
            "messages": [{"role": "user", "content": "cancel me"}],
        }
        # Open the stream (client.stream keeps the connection open until read)
        with client.stream("POST", "/v1/chat/completions", json=payload) as resp:
            # Cancel immediately, before draining the stream
            rc = client.post("/v1/cancel/to-cancel")
            assert rc.status_code == 200
            # The stream should then end with [DONE] instead of hanging
            lines = read_sse_lines(resp)
        assert "[DONE]" in "\n".join(lines)


def test_cancel_unknown_session_is_ok():
    with patched_engine():
        client = get_client()
        rc = client.post("/v1/cancel/does-not-exist")
        # The endpoint returns ok regardless (idempotent, operationally safe)
        assert rc.status_code == 200


def test_edge_large_last_event_id_after_finish_yields_done():
    with patched_engine():
        client = get_client()
        payload = {
            "session_id": "done-session",
            "stream": True,
            "messages": [{"role": "user", "content": "edge"}],
        }
        # Complete a full run first
        with client.stream("POST", "/v1/chat/completions", json=payload) as resp:
            _ = read_sse_lines(resp)
        # Resume with a huge index; the server should return [DONE] quickly
        headers = {"Last-Event-ID": "done-session:99999"}
        with client.stream("POST", "/v1/chat/completions", headers=headers, json=payload) as resp2:
            lines2 = read_sse_lines(resp2)
        assert "[DONE]" in "\n".join(lines2)