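"""Tests for the FastAPI chat server in main.py.

A FakeEngine stands in for the real model so the suite runs without loading
weights. Covers /health, /v1/chat/completions (non-streaming and SSE
streaming, including Last-Event-ID resume), and /v1/cancel/{session_id}.
"""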
import json
import time
from contextlib import contextmanager

import pytest
from fastapi.testclient import TestClient

import main

class FakeEngine:
    def __init__(self, model_id="fake-model"):
        self.model_id = model_id
        self.last_context_info = {
            "compressed": False,
            "prompt_tokens": 5,
            "max_context": 8192,
            "budget": 7900,
            "strategy": "truncate",
            "dropped_messages": 0,
        }

    def infer(self, messages, max_tokens, temperature):
        # Simulate the parse-error pathway when the special trigger is present.
        if messages and isinstance(messages[0].get("content"), str) and "PARSE_ERR" in messages[0]["content"]:
            raise ValueError("Simulated parse error")
        # Echo the prompt content back so tests are deterministic.
        parts = []
        for m in messages:
            c = m.get("content", "")
            if isinstance(c, list):
                for p in c:
                    if isinstance(p, dict) and p.get("type") == "text":
                        parts.append(p.get("text", ""))
            elif isinstance(c, str):
                parts.append(c)
        txt = " ".join(parts) or "OK"
        # Simulate context accounting that changes with each request.
        self.last_context_info = {
            "compressed": False,
            "prompt_tokens": max(1, len(txt.split())),
            "max_context": 8192,
            "budget": 7900,
            "strategy": "truncate",
            "dropped_messages": 0,
        }
        return f"OK: {txt}"

    def infer_stream(self, messages, max_tokens, temperature, cancel_event=None):
        # Simple two-piece stream; respects cancel_event if it is set mid-stream.
        outputs = ["hello", " world"]
        for piece in outputs:
            if cancel_event is not None and cancel_event.is_set():
                break
            yield piece
            # Tiny delay so the cancel test can interleave with generation.
            time.sleep(0.01)

    def get_context_report(self):
        return {
            "compressionEnabled": True,
            "strategy": "truncate",
            "safetyMargin": 256,
            "modelMaxContext": 8192,
            "tokenizerModelMaxLength": 8192,
            "last": self.last_context_info,
        }
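
# NOTE: the helper below assumes main.py keeps its engine in the module-level
# globals `_engine` / `_engine_error` (as the accesses here imply). If the app
# wires the engine differently (e.g. via app.state), adapt the patching.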
@contextmanager
def patched_engine():
    # Patch the global engine so the server never loads the real model.
    prev_engine = main._engine
    prev_err = main._engine_error
    fake = FakeEngine()
    main._engine = fake
    main._engine_error = None
    try:
        yield fake
    finally:
        main._engine = prev_engine
        main._engine_error = prev_err

def get_client():
    return TestClient(main.app)

def test_health_ready_and_context():
    with patched_engine():
        client = get_client()
        r = client.get("/health")
        assert r.status_code == 200
        body = r.json()
        assert body["ok"] is True
        assert body["modelReady"] is True
        assert body["modelId"] == "fake-model"
        # The context block exists and carries the required fields.
        ctx = body["context"]
        assert ctx["compressionEnabled"] is True
        assert "last" in ctx
        assert isinstance(ctx["last"].get("prompt_tokens"), int)

def test_health_with_engine_error():
    # Simulate the model-load-error path.
    prev_engine = main._engine
    prev_err = main._engine_error
    try:
        main._engine = None
        main._engine_error = "boom"
        client = get_client()
        r = client.get("/health")
        assert r.status_code == 200
        body = r.json()
        assert body["modelReady"] is False
        assert body["error"] == "boom"
    finally:
        main._engine = prev_engine
        main._engine_error = prev_err

def test_chat_non_stream_validation():
    with patched_engine():
        client = get_client()
        # An empty messages list should be rejected with 400.
        r = client.post("/v1/chat/completions", json={"messages": []})
        assert r.status_code == 400

def test_chat_non_stream_success_and_usage_context():
    with patched_engine():
        client = get_client()
        payload = {
            "messages": [{"role": "user", "content": "Hello Qwen"}],
            "max_tokens": 8,
            "temperature": 0.0,
        }
        r = client.post("/v1/chat/completions", json=payload)
        assert r.status_code == 200
        body = r.json()
        assert body["object"] == "chat.completion"
        assert body["choices"][0]["message"]["content"].startswith("OK:")
        # usage.prompt_tokens is filled from engine.last_context_info.
        assert body["usage"]["prompt_tokens"] >= 1
        # The response echoes the context accounting.
        assert "context" in body
        assert "prompt_tokens" in body["context"]

def test_chat_non_stream_parse_error_to_400():
    with patched_engine():
        client = get_client()
        payload = {
            "messages": [{"role": "user", "content": "PARSE_ERR trigger"}],
            "max_tokens": 4,
        }
        r = client.post("/v1/chat/completions", json=payload)
        # A ValueError in the engine maps to 400 per the API contract.
        assert r.status_code == 400

def read_sse_lines(resp):
    # Parse an event-stream body into a list of frame payloads (including [DONE]).
    lines = []
    buf = b""
    # Starlette's TestClient (httpx) responses expose iter_bytes()/iter_raw(),
    # not requests' iter_content(). Use whichever iterator is available, or
    # fall back to the full body if streaming isn't supported.
    iterator = None
    for name in ("iter_bytes", "iter_raw", "iter_content"):
        it = getattr(resp, name, None)
        if callable(it):
            iterator = it
            break
    if iterator is None:
        data = getattr(resp, "content", b"")
        if isinstance(data, str):
            data = data.encode("utf-8", "ignore")
        buf = data
    else:
        for chunk in iterator():
            if not chunk:
                continue
            if isinstance(chunk, str):
                chunk = chunk.encode("utf-8", "ignore")
            buf += chunk
    # SSE frames are delimited by a blank line; split the buffer the same way
    # for both the streamed and the fallback path.
    while b"\n\n" in buf:
        frame, buf = buf.split(b"\n\n", 1)
        # Keep the original frame text for the asserts.
        lines.append(frame.decode("utf-8", errors="ignore"))
    # Drain any leftover partial frame.
    if buf:
        lines.append(buf.decode("utf-8", errors="ignore"))
    return lines
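
# Illustrative frame shapes this parser expects (the exact JSON fields depend
# on the server's SSE contract and are an assumption here, not a guarantee):
#   data: {"choices": [{"delta": {"content": "hello"}}]}
#   data: [DONE]
# Resume uses a "Last-Event-ID: <session_id>:<event_index>" header, matching
# the "s1:0" convention exercised by the tests below.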

def test_chat_stream_sse_flow_and_resume():
    with patched_engine():
        client = get_client()
        payload = {
            "session_id": "s1",
            "stream": True,
            "messages": [{"role": "user", "content": "stream please"}],
            "max_tokens": 8,
            "temperature": 0.2,
        }
        with client.stream("POST", "/v1/chat/completions", json=payload) as resp:
            assert resp.status_code == 200
            lines = read_sse_lines(resp)
        # Must contain a role delta, content pieces, the finish chunk, and [DONE].
        joined = "\n".join(lines)
        assert "delta" in joined
        assert "[DONE]" in joined
        # Resuming from event index 0 should deliver at least one subsequent event.
        headers = {"Last-Event-ID": "s1:0"}
        with client.stream("POST", "/v1/chat/completions", headers=headers, json=payload) as resp2:
            assert resp2.status_code == 200
            lines2 = read_sse_lines(resp2)
        assert any("data:" in ln for ln in lines2)
        assert "[DONE]" in "\n".join(lines2)
        # An invalid Last-Event-ID format should not crash (covered server-side
        # by try/except).
        headers_bad = {"Last-Event-ID": "not-an-index"}
        with client.stream("POST", "/v1/chat/completions", headers=headers_bad, json=payload) as resp3:
            assert resp3.status_code == 200
            _ = read_sse_lines(resp3)  # just ensure no crash

def test_cancel_endpoint_stops_generation():
    with patched_engine():
        client = get_client()
        payload = {
            "session_id": "to-cancel",
            "stream": True,
            "messages": [{"role": "user", "content": "cancel me"}],
        }
        # Start streaming; client.stream keeps the connection open until the
        # body is read, so the cancel request can land mid-generation.
        with client.stream("POST", "/v1/chat/completions", json=payload) as resp:
            # Cancel immediately.
            rc = client.post("/v1/cancel/to-cancel")
            assert rc.status_code == 200
            # The stream should still end with [DONE] without hanging.
            lines = read_sse_lines(resp)
        assert "[DONE]" in "\n".join(lines)

def test_cancel_unknown_session_is_ok():
    with patched_engine():
        client = get_client()
        rc = client.post("/v1/cancel/does-not-exist")
        # The endpoint returns ok regardless (idempotent, operationally safe).
        assert rc.status_code == 200

def test_edge_large_last_event_id_after_finish_yields_done():
    with patched_engine():
        client = get_client()
        payload = {
            "session_id": "done-session",
            "stream": True,
            "messages": [{"role": "user", "content": "edge"}],
        }
        # Complete a full run first.
        with client.stream("POST", "/v1/chat/completions", json=payload) as resp:
            _ = read_sse_lines(resp)
        # Resume with a huge index; the server should return [DONE] quickly.
        headers = {"Last-Event-ID": "done-session:99999"}
        with client.stream("POST", "/v1/chat/completions", headers=headers, json=payload) as resp2:
            lines2 = read_sse_lines(resp2)
        assert "[DONE]" in "\n".join(lines2)