#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
FastAPI Inference Server (OpenAI-compatible) for Qwen3-VL multimodal model.

- Default model: Qwen/Qwen3-VL-2B-Thinking
- Endpoints:
  * GET /openapi.yaml     (OpenAPI schema in YAML)
  * GET /health           (readiness + context report)
  * POST /v1/chat/completions (non-stream and streaming SSE)
  * POST /v1/cancel/{session_id} (custom cancel endpoint)

Notes:
- Uses Hugging Face Transformers with trust_remote_code=True.
- Supports OpenAI-style chat messages with text, image_url/input_image, video_url/input_video.
- Streaming SSE supports resume (session_id + Last-Event-ID) with optional SQLite persistence.
- Automatic prompt compression prevents context overflow by dropping the oldest non-system messages ("truncate" strategy).
"""

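# Illustrative quick start, assuming the server runs locally on the default PORT=3000
# (non-streaming shown; set "stream": true to receive an SSE stream instead):
#
#   curl -s http://localhost:3000/v1/chat/completions \
#     -H "Content-Type: application/json" \
#     -d '{"messages": [{"role": "user", "content": "Hello"}], "stream": false}'
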
import os
import io
import re
import base64
import tempfile
import contextlib
from typing import Any, Dict, List, Optional, Tuple, Deque

from fastapi import FastAPI, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from starlette.responses import JSONResponse
from fastapi.responses import StreamingResponse, Response
import json
import yaml
import threading
import time
import uuid
import sqlite3
from collections import deque
import subprocess
import sys
import shutil

# Load env
try:
    from dotenv import load_dotenv
    load_dotenv()
except Exception:
    pass

# Ensure HF cache dirs are relative to this project by default
ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
DEFAULT_HF_CACHE = os.path.join(ROOT_DIR, "hf-cache")
if not os.getenv("HF_HOME"):
    os.environ["HF_HOME"] = DEFAULT_HF_CACHE
if not os.getenv("TRANSFORMERS_CACHE"):
    os.environ["TRANSFORMERS_CACHE"] = DEFAULT_HF_CACHE
# Create directory eagerly to avoid later mkdir races
try:
    os.makedirs(os.environ["HF_HOME"], exist_ok=True)
except Exception:
    pass

# Optional heavy deps are imported lazily inside Engine to improve startup UX
import requests
from PIL import Image
import numpy as np
from huggingface_hub import snapshot_download, list_repo_files, hf_hub_download, hf_hub_url, get_hf_file_metadata

# Server config
PORT = int(os.getenv("PORT", "3000"))
DEFAULT_MODEL_ID = os.getenv("MODEL_REPO_ID", "Qwen/Qwen3-VL-2B-Thinking")
HF_TOKEN = os.getenv("HF_TOKEN", "").strip() or None
DEFAULT_MAX_TOKENS = int(os.getenv("MAX_TOKENS", "256"))
DEFAULT_TEMPERATURE = float(os.getenv("TEMPERATURE", "0.7"))
MAX_VIDEO_FRAMES = int(os.getenv("MAX_VIDEO_FRAMES", "16"))
DEVICE_MAP = os.getenv("DEVICE_MAP", "auto")
TORCH_DTYPE = os.getenv("TORCH_DTYPE", "auto")

# Persistent session store (SQLite)
PERSIST_SESSIONS = str(os.getenv("PERSIST_SESSIONS", "0")).lower() in ("1", "true", "yes", "y")
SESSIONS_DB_PATH = os.getenv("SESSIONS_DB_PATH", "sessions.db")
SESSIONS_TTL_SECONDS = int(os.getenv("SESSIONS_TTL_SECONDS", "600"))
# Auto-cancel generation if all clients remain disconnected for this many seconds. 0 disables it.
CANCEL_AFTER_DISCONNECT_SECONDS = int(os.getenv("CANCEL_AFTER_DISCONNECT_SECONDS", "3600"))

# Auto compression settings
ENABLE_AUTO_COMPRESSION = str(os.getenv("ENABLE_AUTO_COMPRESSION", "1")).lower() in ("1", "true", "yes", "y")
CONTEXT_MAX_TOKENS_AUTO = int(os.getenv("CONTEXT_MAX_TOKENS_AUTO", "0"))  # 0 -> infer from model/tokenizer
CONTEXT_SAFETY_MARGIN = int(os.getenv("CONTEXT_SAFETY_MARGIN", "256"))
COMPRESSION_STRATEGY = os.getenv("COMPRESSION_STRATEGY", "truncate")  # truncate | summarize (future)

# Eager model loading (download/check at startup before serving traffic)
EAGER_LOAD_MODEL = str(os.getenv("EAGER_LOAD_MODEL", "1")).lower() in ("1", "true", "yes", "y")
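
# Illustrative .env for a local run; every variable is optional and the values below are
# examples only (the HF_TOKEN shown is a placeholder, not a real token):
#   MODEL_REPO_ID=Qwen/Qwen3-VL-2B-Thinking
#   PORT=3000
#   MAX_TOKENS=512
#   TEMPERATURE=0.7
#   PERSIST_SESSIONS=1
#   SESSIONS_DB_PATH=sessions.db
#   HF_TOKEN=hf_xxxxxxxxxxxx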

def _log(msg: str):
    # Consistent, flush-immediate startup logs
    print(f"[startup] {msg}", flush=True)

def prefetch_model_assets(repo_id: str, token: Optional[str]) -> Optional[str]:
    """
    Reproducible prefetch driven by huggingface-cli:
    - Downloads the ENTIRE repo using CLI (visible progress bar).
    - Returns the local directory path where the repo is mirrored.
    - If CLI is unavailable, falls back to verbose API prefetch.
    """
    try:
        # Enable accelerated transfer + xet if available
        os.environ.setdefault("HF_HUB_ENABLE_HF_TRANSFER", "1")
        os.environ.setdefault("HF_HUB_ENABLE_XET", "1")

        cache_dir = os.getenv("HF_HOME") or os.getenv("TRANSFORMERS_CACHE") or ""
        if cache_dir:
            os.makedirs(cache_dir, exist_ok=True)

        # Resolve huggingface-cli path (Windows-friendly)
        cli_path = shutil.which("huggingface-cli")
        if not cli_path:
            candidates = []
            appdata = os.getenv("APPDATA")
            if appdata:
                candidates.append(os.path.join(appdata, "Python", "Python312", "Scripts", "huggingface-cli.exe"))
            candidates.append(os.path.join(os.path.dirname(sys.executable), "Scripts", "huggingface-cli.exe"))
            cli_path = next((p for p in candidates if os.path.exists(p)), None)

        # Preferred: one-shot CLI download for the whole repo (shows live progress)
        if cli_path:
            local_root = os.path.join(cache_dir if cache_dir else ".", repo_id.replace("/", "_"))
            os.makedirs(local_root, exist_ok=True)
            _log(f"Using huggingface-cli to download entire repo -> '{local_root}'")
            cmd = [
                cli_path,
                "download",
                repo_id,
                "--repo-type",
                "model",
                "--local-dir",
                local_root,
                "--local-dir-use-symlinks",
                "False",
                # NOTE: interrupted downloads resume by default; no extra flag is passed here
            ]
            if token:
                cmd += ["--token", token]
            # Inherit stdio; users will see a proper progress bar
            subprocess.run(cmd, check=False)
            # Verify we have the essential files
            if os.path.exists(os.path.join(local_root, "config.json")) or os.path.exists(os.path.join(local_root, "model.safetensors")):
                _log("CLI prefetch completed")
                return local_root
            else:
                _log("CLI prefetch finished but essential files not found; will fallback to API mirroring")

        # Fallback: verbose API-driven prefetch with per-file logging
        _log(f"Prefetching (API) repo={repo_id} to cache='{cache_dir}'")
        try:
            files = list_repo_files(repo_id, repo_type="model", token=token)
        except Exception as e:
            _log(f"list_repo_files failed ({type(e).__name__}: {e}); falling back to snapshot_download")
            snapshot_download(repo_id, token=token, local_files_only=False)
            _log("Prefetch completed (snapshot)")
            return None

        total = len(files)
        _log(f"Found {total} files to ensure cached (API)")
        for i, fn in enumerate(files, start=1):
            try:
                meta = get_hf_file_metadata(hf_hub_url(repo_id, fn, repo_type="model"), token=token)
                size_bytes = meta.size or 0
            except Exception:
                size_bytes = 0
            size_mb = size_bytes / (1024 * 1024) if size_bytes else 0.0
            _log(f"[{i}/{total}] fetching '{fn}' (~{size_mb:.2f} MB)")
            _ = hf_hub_download(
                repo_id=repo_id,
                filename=fn,
                repo_type="model",
                token=token,
                local_files_only=False,
                resume_download=True,
            )
            _log(f"[{i}/{total}] done '{fn}'")
        _log("Prefetch completed (API)")
        return None
    except Exception as e:
        _log(f"Prefetch skipped: {type(e).__name__}: {e}")
        return None

def is_data_url(url: str) -> bool:
    return url.startswith("data:") and ";base64," in url


def is_http_url(url: str) -> bool:
    return url.startswith("http://") or url.startswith("https://")


def decode_base64_to_bytes(b64: str) -> bytes:
    # strip possible "data:*;base64," prefix
    if "base64," in b64:
        b64 = b64.split("base64,", 1)[1]
    return base64.b64decode(b64, validate=False)


def fetch_bytes(url: str, headers: Optional[Dict[str, str]] = None, timeout: int = 60) -> bytes:
    if not is_http_url(url):
        raise ValueError(f"Only http(s) URLs supported for fetch, got: {url}")
    resp = requests.get(url, headers=headers or {}, timeout=timeout, stream=True)
    resp.raise_for_status()
    return resp.content


def load_image_from_any(src: Dict[str, Any]) -> Image.Image:
    """
    src can be:
      - { "url": "http(s)://..." } (also supports data URL)
      - { "b64_json": "<base64>" }
      - { "path": "local_path" } (optional)
    """
    if "b64_json" in src and src["b64_json"]:
        data = decode_base64_to_bytes(str(src["b64_json"]))
        return Image.open(io.BytesIO(data)).convert("RGB")

    if "url" in src and src["url"]:
        url = str(src["url"])
        if is_data_url(url):
            data = decode_base64_to_bytes(url)
            return Image.open(io.BytesIO(data)).convert("RGB")
        if is_http_url(url):
            data = fetch_bytes(url)
            return Image.open(io.BytesIO(data)).convert("RGB")
        # treat as local path
        if os.path.exists(url):
            with open(url, "rb") as f:
                return Image.open(io.BytesIO(f.read())).convert("RGB")
        raise ValueError(f"Invalid image url/path: {url}")

    if "path" in src and src["path"]:
        p = str(src["path"])
        if os.path.exists(p):
            with open(p, "rb") as f:
                return Image.open(io.BytesIO(f.read())).convert("RGB")
        raise ValueError(f"Image path not found: {p}")

    raise ValueError("Unsupported image source payload")
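
# Example calls (inputs are hypothetical):
#   load_image_from_any({"url": "https://example.com/cat.jpg"})       # fetched over HTTP
#   load_image_from_any({"b64_json": "<base64 image bytes>"})         # decoded in memory
#   load_image_from_any({"path": "samples/cat.jpg"})                  # read from local disk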


def write_bytes_tempfile(data: bytes, suffix: str) -> str:
    tmp = tempfile.NamedTemporaryFile(delete=False, suffix=suffix)
    with tmp as f:
        f.write(data)
    return tmp.name


def load_video_frames_from_any(src: Dict[str, Any], max_frames: int = MAX_VIDEO_FRAMES) -> List[Image.Image]:
    """
    Returns a list of PIL.Image frames (RGB) sampled up to max_frames.
    src can be:
      - { "url": "http(s)://..." } (mp4/mov/webm/etc.)
      - { "b64_json": "<base64 of a video file>" }
      - { "path": "local_path" }
    """
    # Prefer imageio.v3 if present, falling back to OpenCV.
    # All frames are loaded first, then uniformly sampled if there are too many.
    def _load_all_frames(path: str) -> List[Image.Image]:
        frames: List[Image.Image] = []
        with contextlib.suppress(ImportError):
            import imageio.v3 as iio
            arr_iter = iio.imiter(path)  # yields numpy arrays HxWxC
            for arr in arr_iter:
                if arr is None:
                    continue
                if arr.ndim == 2:
                    arr = np.stack([arr, arr, arr], axis=-1)
                if arr.shape[-1] == 4:
                    arr = arr[..., :3]
                frames.append(Image.fromarray(arr).convert("RGB"))
            return frames

        # Fallback to OpenCV
        import cv2  # type: ignore
        cap = cv2.VideoCapture(path)
        ok, frame = cap.read()
        while ok:
            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            frames.append(Image.fromarray(frame))
            ok, frame = cap.read()
        cap.release()
        return frames

    # Resolve to a local path
    local_path = None
    if "b64_json" in src and src["b64_json"]:
        data = decode_base64_to_bytes(str(src["b64_json"]))
        local_path = write_bytes_tempfile(data, suffix=".mp4")
    elif "url" in src and src["url"]:
        url = str(src["url"])
        if is_data_url(url):
            data = decode_base64_to_bytes(url)
            local_path = write_bytes_tempfile(data, suffix=".mp4")
        elif is_http_url(url):
            data = fetch_bytes(url)
            local_path = write_bytes_tempfile(data, suffix=".mp4")
        elif os.path.exists(url):
            local_path = url
        else:
            raise ValueError(f"Invalid video url/path: {url}")
    elif "path" in src and src["path"]:
        p = str(src["path"])
        if os.path.exists(p):
            local_path = p
        else:
            raise ValueError(f"Video path not found: {p}")
    else:
        raise ValueError("Unsupported video source payload")

    frames = _load_all_frames(local_path)
    # Uniform sample if too many frames
    if len(frames) > max_frames and max_frames > 0:
        idxs = np.linspace(0, len(frames) - 1, max_frames).astype(int).tolist()
        frames = [frames[i] for i in idxs]
    return frames
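
# Example call (input is hypothetical): load_video_frames_from_any({"url": "https://example.com/clip.mp4"})
# downloads the clip to a temp file, decodes it, and returns at most MAX_VIDEO_FRAMES
# uniformly sampled RGB frames.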


class ChatRequest(BaseModel):
    model: Optional[str] = None
    messages: List[Dict[str, Any]]
    max_tokens: Optional[int] = None
    temperature: Optional[float] = None
    stream: Optional[bool] = None
    session_id: Optional[str] = None
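
# Illustrative request body for POST /v1/chat/completions (the image URL is hypothetical):
# {
#   "messages": [
#     {"role": "user", "content": [
#       {"type": "text", "text": "What is in this picture?"},
#       {"type": "image_url", "image_url": {"url": "https://example.com/photo.jpg"}}
#     ]}
#   ],
#   "max_tokens": 128,
#   "temperature": 0.2,
#   "stream": true,
#   "session_id": "sess-demo"
# }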


class Engine:
    def __init__(self, model_id: str, hf_token: Optional[str] = None):
        # Lazy import heavy deps
        from transformers import AutoProcessor, AutoModelForCausalLM, AutoModelForVision2Seq, AutoModel
        # AutoModelForImageTextToText is the v5+ replacement for Vision2Seq in Transformers
        try:
            from transformers import AutoModelForImageTextToText  # type: ignore
        except Exception:
            AutoModelForImageTextToText = None  # type: ignore

        model_kwargs: Dict[str, Any] = {
            "trust_remote_code": True,
        }
        if hf_token:
            # Only pass 'token' (use_auth_token is deprecated and causes conflicts)
            model_kwargs["token"] = hf_token
        # Device and dtype
        model_kwargs["device_map"] = DEVICE_MAP
        model_kwargs["torch_dtype"] = TORCH_DTYPE  # "auto" or an explicit dtype string

        # Processor (handles text + images/videos)
        proc_kwargs: Dict[str, Any] = {"trust_remote_code": True}
        if hf_token:
            proc_kwargs["token"] = hf_token
        self.processor = AutoProcessor.from_pretrained(
            model_id,
            **proc_kwargs,
        )  # pragma: no cover

        # Prefer ImageTextToText (Transformers v5 path), then Vision2Seq, then CausalLM as a last resort
        model = None
        if AutoModelForImageTextToText is not None:
            try:
                model = AutoModelForImageTextToText.from_pretrained(model_id, **model_kwargs)  # pragma: no cover
            except Exception:
                model = None
        if model is None:
            try:
                model = AutoModelForVision2Seq.from_pretrained(model_id, **model_kwargs)  # pragma: no cover
            except Exception:
                model = None
        if model is None:
            try:
                model = AutoModelForCausalLM.from_pretrained(model_id, **model_kwargs)  # pragma: no cover
            except Exception:
                model = None
        if model is None:
            # Generic AutoModel as last-resort with trust_remote_code to load custom architectures
            model = AutoModel.from_pretrained(model_id, **model_kwargs)  # pragma: no cover
        self.model = model.eval()  # pragma: no cover

        self.model_id = model_id
        self.tokenizer = getattr(self.processor, "tokenizer", None)
        self.last_context_info: Dict[str, Any] = {}

    def _model_max_context(self) -> int:
        try:
            cfg = getattr(self.model, "config", None)
            if cfg is not None:
                v = getattr(cfg, "max_position_embeddings", None)
                if isinstance(v, int) and v > 0 and v < 10_000_000:
                    return v
        except Exception:
            pass
        try:
            mx = int(getattr(self.tokenizer, "model_max_length", 0) or 0)
            if mx > 0 and mx < 10_000_000_000:
                return mx
        except Exception:
            pass
        return 32768

    def _count_prompt_tokens(self, text: str) -> int:
        try:
            if self.tokenizer is not None:
                enc = self.tokenizer([text], add_special_tokens=False, return_attention_mask=False)
                ids = enc["input_ids"][0]
                return len(ids)
        except Exception:
            pass
        return max(1, int(len(text.split()) * 1.3))

    def _auto_compress_if_needed(
        self, mm_messages: List[Dict[str, Any]], max_new_tokens: int
    ) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]:
        info: Dict[str, Any] = {}
        # Build once to measure
        text0 = self.processor.apply_chat_template(mm_messages, tokenize=False, add_generation_prompt=True)
        prompt_tokens = self._count_prompt_tokens(text0)
        max_ctx = CONTEXT_MAX_TOKENS_AUTO if CONTEXT_MAX_TOKENS_AUTO > 0 else self._model_max_context()
        budget = max(1024, max_ctx - CONTEXT_SAFETY_MARGIN - int(max_new_tokens))
        if not ENABLE_AUTO_COMPRESSION or prompt_tokens <= budget:
            info = {
                "compressed": False,
                "prompt_tokens": int(prompt_tokens),
                "max_context": int(max_ctx),
                "budget": int(budget),
                "strategy": COMPRESSION_STRATEGY,
                "dropped_messages": 0,
            }
            return mm_messages, info

        # Truncate earliest non-system messages until within budget
        msgs = list(mm_messages)
        dropped = 0
        guard = 0
        while True:
            text = self.processor.apply_chat_template(msgs, tokenize=False, add_generation_prompt=True)
            prompt_tokens = self._count_prompt_tokens(text)
            if prompt_tokens <= budget or len(msgs) <= 1:
                break
            # drop earliest non-system
            drop_idx = None
            for j, m in enumerate(msgs):
                if (m.get("role") or "user") != "system":
                    drop_idx = j
                    break
            if drop_idx is None:
                break
            msgs.pop(drop_idx)
            dropped += 1
            guard += 1
            if guard > 10000:
                break

        info = {
            "compressed": True,
            "prompt_tokens": int(prompt_tokens),
            "max_context": int(max_ctx),
            "budget": int(budget),
            "strategy": "truncate",
            "dropped_messages": int(dropped),
        }
        return msgs, info
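
    # Illustrative budget math for the truncation above (using the module defaults): with a
    # 32,768-token model context, CONTEXT_SAFETY_MARGIN=256 and max_new_tokens=256, the prompt
    # budget is max(1024, 32768 - 256 - 256) = 32,256 tokens; the earliest non-system messages
    # are dropped until the rendered prompt fits within that budget.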

    def get_context_report(self) -> Dict[str, Any]:
        try:
            tk_max = int(getattr(self.tokenizer, "model_max_length", 0) or 0)
        except Exception:
            tk_max = 0
        return {
            "compressionEnabled": ENABLE_AUTO_COMPRESSION,
            "strategy": COMPRESSION_STRATEGY,
            "safetyMargin": CONTEXT_SAFETY_MARGIN,
            "modelMaxContext": self._model_max_context(),
            "tokenizerModelMaxLength": tk_max,
            "last": self.last_context_info or {},
        }

    def build_mm_messages(
        self, openai_messages: List[Dict[str, Any]]
    ) -> Tuple[List[Dict[str, Any]], List[Image.Image], List[List[Image.Image]]]:
        """
        Convert OpenAI-style messages to Qwen multimodal messages.
        Returns:
          - messages for apply_chat_template
          - flat list of images in encounter order
          - list of videos (each is list of PIL frames)
        """
        mm_msgs: List[Dict[str, Any]] = []
        images: List[Image.Image] = []
        videos: List[List[Image.Image]] = []

        for msg in openai_messages:
            role = msg.get("role", "user")
            content = msg.get("content", "")

            parts: List[Dict[str, Any]] = []

            if isinstance(content, str):
                if content:
                    parts.append({"type": "text", "text": content})
            elif isinstance(content, list):
                for p in content:
                    ptype = p.get("type")
                    if ptype == "text":
                        txt = p.get("text", "")
                        if txt:
                            parts.append({"type": "text", "text": txt})
                    elif ptype in ("image_url", "input_image"):
                        src: Dict[str, Any] = {}
                        if ptype == "image_url":
                            u = (p.get("image_url") or {}).get("url") if isinstance(p.get("image_url"), dict) else p.get("image_url")
                            src["url"] = u
                        else:
                            b64 = p.get("image") or p.get("b64_json") or p.get("data") or (p.get("image_url") or {}).get("url")
                            if b64:
                                src["b64_json"] = b64
                        try:
                            img = load_image_from_any(src)
                            images.append(img)
                            parts.append({"type": "image", "image": img})
                        except Exception as e:
                            raise ValueError(f"Failed to parse image part: {e}") from e
                    elif ptype in ("video_url", "input_video"):
                        src = {}
                        if ptype == "video_url":
                            u = (p.get("video_url") or {}).get("url") if isinstance(p.get("video_url"), dict) else p.get("video_url")
                            src["url"] = u
                        else:
                            b64 = p.get("video") or p.get("b64_json") or p.get("data")
                            if b64:
                                src["b64_json"] = b64
                        try:
                            frames = load_video_frames_from_any(src, max_frames=MAX_VIDEO_FRAMES)
                            videos.append(frames)
                            parts.append({"type": "video", "video": frames})
                        except Exception as e:
                            raise ValueError(f"Failed to parse video part: {e}") from e
                    else:
                        if isinstance(p, dict):
                            txt = p.get("text")
                            if isinstance(txt, str) and txt:
                                parts.append({"type": "text", "text": txt})
            else:
                if content:
                    parts.append({"type": "text", "text": str(content)})

            mm_msgs.append({"role": role, "content": parts})

        return mm_msgs, images, videos

    def infer(self, messages: List[Dict[str, Any]], max_tokens: int, temperature: float) -> str:
        mm_messages, images, videos = self.build_mm_messages(messages)
        # Auto-compress if needed based on context budget
        mm_messages, ctx_info = self._auto_compress_if_needed(mm_messages, max_tokens)
        self.last_context_info = ctx_info

        # Build chat template
        text = self.processor.apply_chat_template(
            mm_messages,
            tokenize=False,
            add_generation_prompt=True,
        )

        proc_kwargs: Dict[str, Any] = {"text": [text], "return_tensors": "pt"}
        if images:
            proc_kwargs["images"] = images
        if videos:
            proc_kwargs["videos"] = videos

        inputs = self.processor(**proc_kwargs)
        # Move tensors to model device if present
        try:
            device = getattr(self.model, "device", None) or next(self.model.parameters()).device
            inputs = {k: (v.to(device) if hasattr(v, "to") else v) for k, v in inputs.items()}
        except Exception:
            pass

        do_sample = temperature is not None and float(temperature) > 0.0

        gen_ids = self.model.generate(
            **inputs,
            max_new_tokens=int(max_tokens),
            temperature=float(temperature),
            do_sample=do_sample,
            use_cache=True,
        )
        # Trim the prompt tokens so that only newly generated tokens are decoded
        try:
            in_len = inputs["input_ids"].shape[1]
            gen_ids = gen_ids[:, in_len:]
        except Exception:
            pass
        output = self.processor.batch_decode(
            gen_ids,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=False,
        )[0]

        # Best-effort fallback: strip an echoed chat-template prefix if trimming was not possible
        parts = re.split(r"\n?assistant\s*[:\n]\s*", output, flags=re.IGNORECASE)
        if len(parts) >= 2:
            return parts[-1].strip()
        return output.strip()

    def infer_stream(
        self,
        messages: List[Dict[str, Any]],
        max_tokens: int,
        temperature: float,
        cancel_event: Optional[threading.Event] = None,
    ):
        from transformers import TextIteratorStreamer, StoppingCriteria, StoppingCriteriaList

        mm_messages, images, videos = self.build_mm_messages(messages)
        # Auto-compress if needed based on context budget
        mm_messages, ctx_info = self._auto_compress_if_needed(mm_messages, max_tokens)
        self.last_context_info = ctx_info

        text = self.processor.apply_chat_template(
            mm_messages,
            tokenize=False,
            add_generation_prompt=True,
        )

        proc_kwargs: Dict[str, Any] = {"text": [text], "return_tensors": "pt"}
        if images:
            proc_kwargs["images"] = images
        if videos:
            proc_kwargs["videos"] = videos

        inputs = self.processor(**proc_kwargs)
        try:
            device = getattr(self.model, "device", None) or next(self.model.parameters()).device
            inputs = {k: (v.to(device) if hasattr(v, "to") else v) for k, v in inputs.items()}
        except Exception:
            pass

        do_sample = temperature is not None and float(temperature) > 0.0

        streamer = TextIteratorStreamer(
            getattr(self.processor, "tokenizer", None),
            skip_prompt=True,
            skip_special_tokens=True,
        )

        gen_kwargs = dict(
            **inputs,
            max_new_tokens=int(max_tokens),
            temperature=float(temperature),
            do_sample=do_sample,
            use_cache=True,
            streamer=streamer,
        )

        # Optional cooperative cancellation via StoppingCriteria
        if cancel_event is not None:
            class _CancelCrit(StoppingCriteria):
                def __init__(self, ev: threading.Event):
                    self.ev = ev

                def __call__(self, input_ids, scores, **kwargs):
                    return bool(self.ev.is_set())

            gen_kwargs["stopping_criteria"] = StoppingCriteriaList([_CancelCrit(cancel_event)])

        def _generate():
            try:
                self.model.generate(**gen_kwargs)
            except Exception:
                # Ensure the streamer is closed so the consumer loop below cannot hang
                with contextlib.suppress(Exception):
                    streamer.end()

        th = threading.Thread(target=_generate, daemon=True)
        th.start()

        for piece in streamer:
            if piece:
                yield piece


# Simple in-memory resumable SSE session store + optional SQLite persistence
class _SSESession:
    def __init__(self, maxlen: int = 2048, ttl_seconds: int = 600):
        self.buffer: Deque[Tuple[int, str]] = deque(maxlen=maxlen)  # (idx, sse_line_block)
        self.last_idx: int = -1
        self.created: float = time.time()
        self.finished: bool = False
        self.cond = threading.Condition()
        self.thread: Optional[threading.Thread] = None
        self.ttl_seconds = ttl_seconds
        # Cancellation + client tracking
        self.cancel_event = threading.Event()
        self.listeners: int = 0
        self.cancel_timer = None  # type: ignore


class _SessionStore:
    def __init__(self, ttl_seconds: int = 600, max_sessions: int = 256):
        self._sessions: Dict[str, _SSESession] = {}
        self._lock = threading.Lock()
        self._ttl = ttl_seconds
        self._max_sessions = max_sessions

    def get_or_create(self, sid: str) -> _SSESession:
        with self._lock:
            sess = self._sessions.get(sid)
            if sess is None:
                sess = _SSESession(ttl_seconds=self._ttl)
                self._sessions[sid] = sess
            return sess

    def get(self, sid: str) -> Optional[_SSESession]:
        with self._lock:
            return self._sessions.get(sid)

    def gc(self):
        now = time.time()
        with self._lock:
            # remove expired
            expired = [k for k, v in self._sessions.items() if (now - v.created) > self._ttl or (v.finished and (now - v.created) > self._ttl / 4)]
            for k in expired:
                self._sessions.pop(k, None)
            # bound session count
            if len(self._sessions) > self._max_sessions:
                for k, _ in sorted(self._sessions.items(), key=lambda kv: kv[1].created)[: max(0, len(self._sessions) - self._max_sessions)]:
                    self._sessions.pop(k, None)


class _SQLiteStore:
    def __init__(self, db_path: str):
        self.db_path = db_path
        self._lock = threading.Lock()
        self._conn = sqlite3.connect(self.db_path, check_same_thread=False)
        self._conn.execute("PRAGMA journal_mode=WAL;")
        self._conn.execute("PRAGMA synchronous=NORMAL;")
        self._ensure_schema()

    def _ensure_schema(self):
        cur = self._conn.cursor()
        cur.execute(
            "CREATE TABLE IF NOT EXISTS sessions (session_id TEXT PRIMARY KEY, created REAL, finished INTEGER DEFAULT 0)"
        )
        cur.execute(
            "CREATE TABLE IF NOT EXISTS events (session_id TEXT, idx INTEGER, data TEXT, created REAL, PRIMARY KEY(session_id, idx))"
        )
        cur.execute("CREATE INDEX IF NOT EXISTS idx_events_session ON events(session_id, idx)")
        self._conn.commit()

    def ensure_session(self, session_id: str, created: int):
        with self._lock:
            self._conn.execute(
                "INSERT OR IGNORE INTO sessions(session_id, created, finished) VALUES (?, ?, 0)",
                (session_id, float(created)),
            )
            self._conn.commit()

    def append_event(self, session_id: str, idx: int, payload: Dict[str, Any]):
        data = json.dumps(payload, ensure_ascii=False)
        with self._lock:
            self._conn.execute(
                "INSERT OR REPLACE INTO events(session_id, idx, data, created) VALUES (?, ?, ?, ?)",
                (session_id, idx, data, time.time()),
            )
            self._conn.commit()

    def get_events_after(self, session_id: str, last_idx: int) -> List[Tuple[int, str]]:
        with self._lock:
            cur = self._conn.execute(
                "SELECT idx, data FROM events WHERE session_id=? AND idx>? ORDER BY idx ASC", (session_id, last_idx)
            )
            return [(int(r[0]), str(r[1])) for r in cur.fetchall()]

    def mark_finished(self, session_id: str):
        with self._lock:
            self._conn.execute("UPDATE sessions SET finished=1 WHERE session_id=?", (session_id,))
            self._conn.commit()

    def session_meta(self, session_id: str) -> Tuple[bool, int]:
        with self._lock:
            row = self._conn.execute("SELECT finished FROM sessions WHERE session_id=?", (session_id,)).fetchone()
            finished = bool(row[0]) if row else False
            row2 = self._conn.execute("SELECT MAX(idx) FROM events WHERE session_id=?", (session_id,)).fetchone()
            last_idx = int(row2[0]) if row2 and row2[0] is not None else -1
            return finished, last_idx

    def gc(self, ttl_seconds: int):
        cutoff = time.time() - float(ttl_seconds)
        with self._lock:
            cur = self._conn.execute("SELECT session_id FROM sessions WHERE finished=1 AND created<?", (cutoff,))
            ids = [r[0] for r in cur.fetchall()]
            for sid in ids:
                self._conn.execute("DELETE FROM events WHERE session_id=?", (sid,))
                self._conn.execute("DELETE FROM sessions WHERE session_id=?", (sid,))
            self._conn.commit()


def _sse_event(session_id: str, idx: int, payload: Dict[str, Any]) -> str:
    # Include SSE id line so clients can send Last-Event-ID to resume.
    return f"id: {session_id}:{idx}\n" + f"data: {json.dumps(payload, ensure_ascii=False)}\n\n"
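
# Illustrative wire format produced by _sse_event (payload abbreviated, ids hypothetical):
#   id: sess-ab12cd34ef56:7
#   data: {"id": "sess-ab12cd34ef56", "object": "chat.completion.chunk", ...}
#
# A client that last saw this event can reconnect with the header
# "Last-Event-ID: sess-ab12cd34ef56:7" to have everything after index 7 replayed.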


_STORE = _SessionStore()
_DB_STORE = _SQLiteStore(SESSIONS_DB_PATH) if PERSIST_SESSIONS else None

# FastAPI app and OpenAPI tags
tags_metadata = [
    {"name": "meta", "description": "Service metadata and OpenAPI schema"},
    {"name": "health", "description": "Readiness and runtime info including context window report"},
    {"name": "chat", "description": "OpenAI-compatible chat completions (non-stream and streaming SSE)"},
]

app = FastAPI(
    title="Qwen3-VL Inference Server",
    version="1.0.0",
    description="OpenAI-compatible inference server for Qwen3-VL with multimodal support, streaming SSE with resume, context auto-compression, and optional SQLite persistence.",
    openapi_tags=tags_metadata,
)
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)

# Startup hook is defined after get_engine() so globals are initialized first.

# Engine singletons
_engine: Optional[Engine] = None
_engine_error: Optional[str] = None


def get_engine() -> Engine:
    global _engine, _engine_error
    if _engine is not None:
        return _engine
    try:
        model_id = DEFAULT_MODEL_ID
        _log(f"Preparing model '{model_id}' (HF_HOME={os.getenv('HF_HOME')}, cache={os.getenv('TRANSFORMERS_CACHE')})")
        local_repo_dir = prefetch_model_assets(model_id, HF_TOKEN)
        load_id = local_repo_dir if (local_repo_dir and os.path.exists(os.path.join(local_repo_dir, 'config.json'))) else model_id
        _log(f"Loading processor and model from: {load_id}")
        _engine = Engine(model_id=load_id, hf_token=HF_TOKEN)
        _engine_error = None
        _log(f"Model ready: {_engine.model_id}")
        return _engine
    except Exception as e:
        _engine_error = f"{type(e).__name__}: {e}"
        _log(f"Engine init failed: {_engine_error}")
        raise

# Eager-load model at startup after definitions so it downloads/checks before serving traffic.
@app.on_event("startup")
def _startup_load_model():
    if EAGER_LOAD_MODEL:
        print("[startup] EAGER_LOAD_MODEL=1: initializing model...")
        try:
            _ = get_engine()
            print("[startup] Model loaded:", _engine.model_id if _engine else "unknown")
        except Exception as e:
            # Fail fast if model cannot be initialized
            print("[startup] Model load failed:", e)
            raise


@app.get("/", tags=["meta"])
def root():
    """Liveness check."""
    return JSONResponse({"ok": True})


@app.get("/openapi.yaml", tags=["meta"])
def openapi_yaml():
    """Serve OpenAPI schema as YAML for tooling compatibility."""
    schema = app.openapi()
    yml = yaml.safe_dump(schema, sort_keys=False)
    return Response(yml, media_type="application/yaml")


@app.get("/health", tags=["health"])
def health():
    ready = False
    err = None
    model_id = DEFAULT_MODEL_ID
    global _engine, _engine_error
    if _engine is not None:
        ready = True
        model_id = _engine.model_id
    elif _engine_error:
        err = _engine_error
    ctx = None
    try:
        if _engine is not None:
            ctx = _engine.get_context_report()
    except Exception:
        ctx = None
    return JSONResponse({"ok": True, "modelReady": ready, "modelId": model_id, "error": err, "context": ctx})
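
# Illustrative /health response once the model is loaded (values depend on the deployment):
#   {"ok": true, "modelReady": true, "modelId": "Qwen/Qwen3-VL-2B-Thinking", "error": null,
#    "context": {"compressionEnabled": true, "strategy": "truncate", "safetyMargin": 256, ...}}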


@app.post("/v1/chat/completions", tags=["chat"])
def chat_completions(request: Request, body: ChatRequest):
    # Ensure engine is loaded
    try:
        engine = get_engine()
    except Exception as e:
        raise HTTPException(status_code=503, detail=f"Model not ready: {e}")

    if not body or not isinstance(body.messages, list) or len(body.messages) == 0:
        raise HTTPException(status_code=400, detail="messages must be a non-empty array")

    max_tokens = int(body.max_tokens) if isinstance(body.max_tokens, int) else DEFAULT_MAX_TOKENS
    temperature = float(body.temperature) if body.temperature is not None else DEFAULT_TEMPERATURE
    do_stream = bool(body.stream)

    # Parse Last-Event-ID for resuming and derive/align session_id
    last_event_id_header = request.headers.get("last-event-id")
    sid_from_header: Optional[str] = None
    last_idx_from_header: int = -1
    if last_event_id_header:
        try:
            sid_from_header, idx_str = last_event_id_header.split(":", 1)
            last_idx_from_header = int(idx_str)
        except Exception:
            sid_from_header = None
            last_idx_from_header = -1

    session_id = body.session_id or sid_from_header or f"sess-{uuid.uuid4().hex[:12]}"
    sess = _STORE.get_or_create(session_id)
    created_ts = int(sess.created)
    if _DB_STORE is not None:
        _DB_STORE.ensure_session(session_id, created_ts)

    if not do_stream:
        # Non-streaming path
        try:
            content = engine.infer(body.messages, max_tokens=max_tokens, temperature=temperature)
        except ValueError as e:
            # Parsing/user payload errors from engine -> HTTP 400
            raise HTTPException(status_code=400, detail=str(e))
        except Exception as e:
            raise HTTPException(status_code=500, detail=f"Inference error: {e}")

        now = int(time.time())
        prompt_tokens = int((engine.last_context_info or {}).get("prompt_tokens") or 0)
        completion_tokens = max(1, len((content or "").split()))
        total_tokens = prompt_tokens + completion_tokens
        resp: Dict[str, Any] = {
            "id": f"chatcmpl-{uuid.uuid4().hex[:12]}",
            "object": "chat.completion",
            "created": now,
            "model": engine.model_id,
            "choices": [
                {
                    "index": 0,
                    "message": {"role": "assistant", "content": content},
                    "finish_reason": "stop",
                }
            ],
            "usage": {
                "prompt_tokens": prompt_tokens,
                "completion_tokens": completion_tokens,
                "total_tokens": total_tokens,
            },
            "context": engine.last_context_info or {},
        }
        return JSONResponse(resp)

    # Streaming SSE with resumable support
    def sse_generator():
        # Manage listener count and cancel timer
        sess.listeners += 1
        try:
            # Cancel any pending cancel timer when a listener attaches
            if getattr(sess, "cancel_timer", None):
                try:
                    sess.cancel_timer.cancel()
                except Exception:
                    pass
                sess.cancel_timer = None

            # Replay if Last-Event-ID was provided
            replay_from = last_idx_from_header if sid_from_header == session_id else -1
            if replay_from >= -1:
                # First replay from the in-memory buffer
                replayed_up_to = replay_from
                for idx, block in list(sess.buffer):
                    if idx > replayed_up_to:
                        yield block.encode("utf-8")
                        replayed_up_to = idx
                # Then pull anything newer from SQLite persistence (avoids duplicate replays)
                if _DB_STORE is not None:
                    try:
                        for idx, data in _DB_STORE.get_events_after(session_id, replayed_up_to):
                            block = f"id: {session_id}:{idx}\n" + f"data: {data}\n\n"
                            yield block.encode("utf-8")
                    except Exception:
                        pass
                if sess.finished:
                    # Already finished; send terminal and exit
                    yield b"data: [DONE]\n\n"
                    return

            # Fresh generation path
            # Helper to append to buffers and yield to client
            def push(payload: Dict[str, Any]):
                sess.last_idx += 1
                idx = sess.last_idx
                block = _sse_event(session_id, idx, payload)
                sess.buffer.append((idx, block))
                if _DB_STORE is not None:
                    try:
                        _DB_STORE.append_event(session_id, idx, payload)
                    except Exception:
                        pass
                return block

            # Initial assistant role delta
            head = {
                "id": session_id,
                "object": "chat.completion.chunk",
                "created": int(time.time()),
                "model": engine.model_id,
                "choices": [{"index": 0, "delta": {"role": "assistant"}, "finish_reason": None}],
                "system_fingerprint": "fastapi",
            }
            yield push(head).encode("utf-8")

            # Stream model pieces
            try:
                for piece in engine.infer_stream(
                    body.messages, max_tokens=max_tokens, temperature=temperature, cancel_event=sess.cancel_event
                ):
                    if not piece:
                        continue
                    payload = {
                        "id": session_id,
                        "object": "chat.completion.chunk",
                        "created": int(time.time()),
                        "model": engine.model_id,
                        "choices": [{"index": 0, "delta": {"content": piece}, "finish_reason": None}],
                    }
                    yield push(payload).encode("utf-8")
                    # Cooperative early-exit if cancel requested
                    if sess.cancel_event.is_set():
                        break
            except Exception:
                # On engine error, terminate gracefully
                pass

            # Finish chunk
            finish = {
                "id": session_id,
                "object": "chat.completion.chunk",
                "created": int(time.time()),
                "model": engine.model_id,
                "choices": [{"index": 0, "delta": {}, "finish_reason": "stop"}],
            }
            yield push(finish).encode("utf-8")

            # Terminal [DONE] for clients that are still connected
            yield b"data: [DONE]\n\n"

        finally:
            # Mark finished and persist (this also runs when the client disconnects)
            sess.finished = True
            if _DB_STORE is not None:
                try:
                    _DB_STORE.mark_finished(session_id)
                    # Optionally GC older finished sessions
                    _DB_STORE.gc(SESSIONS_TTL_SECONDS)
                except Exception:
                    pass

            # No yields below this point: yielding inside `finally` after a client
            # disconnect (GeneratorExit) would raise a RuntimeError.

            # Listener bookkeeping and optional auto-cancel if all disconnect
            try:
                sess.listeners = max(0, sess.listeners - 1)
                if sess.listeners == 0 and CANCEL_AFTER_DISCONNECT_SECONDS > 0 and not sess.cancel_event.is_set():
                    def _later_cancel():
                        # If still no listeners, cancel
                        if sess.listeners == 0 and not sess.cancel_event.is_set():
                            sess.cancel_event.set()
                    sess.cancel_timer = threading.Timer(CANCEL_AFTER_DISCONNECT_SECONDS, _later_cancel)
                    sess.cancel_timer.daemon = True
                    sess.cancel_timer.start()
            except Exception:
                pass

            # In-memory store GC
            try:
                _STORE.gc()
            except Exception:
                pass

    headers = {
        "Cache-Control": "no-cache",
        "Connection": "keep-alive",
        "X-Accel-Buffering": "no",
    }
    return StreamingResponse(sse_generator(), media_type="text/event-stream", headers=headers)


@app.post("/v1/cancel/{session_id}", tags=["chat"])
def cancel_session(session_id: str):
    sess = _STORE.get(session_id)
    if sess is not None:
        try:
            sess.cancel_event.set()
            sess.finished = True
            if _DB_STORE is not None:
                _DB_STORE.mark_finished(session_id)
        except Exception:
            pass
    return JSONResponse({"ok": True, "session_id": session_id})
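
# Illustrative cancellation of an in-flight streaming session (the session id is hypothetical):
#   curl -X POST http://localhost:3000/v1/cancel/sess-ab12cd34ef56
# The matching stream stops cooperatively at the next generated token via the session's
# cancel_event, which Engine.infer_stream checks through a StoppingCriteria hook.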


if __name__ == "__main__":
    import uvicorn

    uvicorn.run("main:app", host="0.0.0.0", port=PORT, reload=False)