m97j committed on
Commit
320c4f1
·
1 Parent(s): 72be173

Initial commit

Browse files
Files changed (1) hide show
  1. app.py +36 -41
app.py CHANGED
@@ -1,37 +1,31 @@
1
- import os
2
  from fastapi import FastAPI, Request, HTTPException
3
  from fastapi.middleware.cors import CORSMiddleware
4
  from fastapi.responses import HTMLResponse
 
5
  from manager.dialogue_manager import handle_dialogue
6
  from rag.rag_manager import chroma_initialized, load_game_docs_from_disk, add_docs, set_embedder
7
- from contextlib import asynccontextmanager
8
  from models.model_loader import load_fallback_model, load_embedder
9
  from schemas import AskReq, AskRes
10
- from pathlib import Path
11
  from config import (
12
  FALLBACK_MODEL_NAME, FALLBACK_MODEL_DIR,
13
  EMBEDDER_MODEL_NAME, EMBEDDER_MODEL_DIR,
14
  HF_TOKEN, BASE_DIR
15
  )
16
 
 
17
 
18
- @asynccontextmanager
19
- async def lifespan(app: FastAPI):
20
- print("🚀 서버 시작 중... 모델 로딩 중...")
21
-
22
- # Fallback
23
  fb_tokenizer, fb_model = load_fallback_model(FALLBACK_MODEL_NAME, FALLBACK_MODEL_DIR, token=HF_TOKEN)
24
  app.state.fallback_tokenizer = fb_tokenizer
25
  app.state.fallback_model = fb_model
26
 
27
- # Embedder
28
  embedder = load_embedder(EMBEDDER_MODEL_NAME, EMBEDDER_MODEL_DIR, token=HF_TOKEN)
29
  app.state.embedder = embedder
30
- set_embedder(embedder) # 추가
31
 
32
- print("✅ 모든 모델 로딩 완료")
33
-
34
- # RAG 초기화
35
  docs_path = BASE_DIR / "rag" / "docs"
36
  if not chroma_initialized():
37
  docs = load_game_docs_from_disk(str(docs_path))
@@ -40,14 +34,17 @@ async def lifespan(app: FastAPI):
40
  else:
41
  print("🔄 RAG DB 이미 초기화됨")
42
 
43
- yield # 앱 실행
 
44
 
 
 
 
 
45
  print("🛑 서버 종료 중...")
46
 
47
-
48
  app = FastAPI(title="ai-server", lifespan=lifespan)
49
 
50
- # CORS 설정 (game-server에서 요청 가능하도록)
51
  app.add_middleware(
52
  CORSMiddleware,
53
  allow_origins=["https://fpsgame-rrbc.onrender.com"],
@@ -56,20 +53,39 @@ app.add_middleware(
56
  allow_headers=["*"],
57
  )
58
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
59
 
60
  @app.post("/ask", response_model=AskRes)
61
  async def ask(request: Request, req: AskReq):
 
 
62
  if not req.context:
63
  raise HTTPException(status_code=400, detail="missing context")
64
-
65
  if not (req.session_id and req.npc_id and req.user_input):
66
  raise HTTPException(status_code=400, detail="missing fields")
67
 
68
  context = req.context
69
- npc_config = context.npc_config
70
- npc_config_dict = npc_config.model_dump() if npc_config else None
71
 
72
- result = await handle_dialogue(
73
  request=request,
74
  session_id=req.session_id,
75
  npc_id=req.npc_id,
@@ -77,27 +93,6 @@ async def ask(request: Request, req: AskReq):
77
  context=context.model_dump(),
78
  npc_config=npc_config_dict
79
  )
80
- return result
81
-
82
-
83
-
84
- @app.post("/wake")
85
- async def wake(request: Request):
86
- body = await request.json()
87
- session_id = body.get("session_id", "unknown")
88
- print(f"📡 Wake signal received for session: {session_id}")
89
- return {"status": "awake", "session_id": session_id}
90
-
91
-
92
- from fastapi.responses import HTMLResponse
93
-
94
- @app.get("/", include_in_schema=False)
95
- async def root():
96
- return HTMLResponse("""
97
- <h1>Persona Chat Engine API</h1>
98
- <p>서버가 정상 실행 중입니다.</p>
99
- <p><a href="/docs">Swagger UI로 이동</a></p>
100
- """)
101
 
102
 
103
  '''
 
1
+ import asyncio
2
  from fastapi import FastAPI, Request, HTTPException
3
  from fastapi.middleware.cors import CORSMiddleware
4
  from fastapi.responses import HTMLResponse
5
+ from contextlib import asynccontextmanager
6
  from manager.dialogue_manager import handle_dialogue
7
  from rag.rag_manager import chroma_initialized, load_game_docs_from_disk, add_docs, set_embedder
 
8
  from models.model_loader import load_fallback_model, load_embedder
9
  from schemas import AskReq, AskRes
 
10
  from config import (
11
  FALLBACK_MODEL_NAME, FALLBACK_MODEL_DIR,
12
  EMBEDDER_MODEL_NAME, EMBEDDER_MODEL_DIR,
13
  HF_TOKEN, BASE_DIR
14
  )
15
 
16
+ model_ready = False # 모델 로딩 상태 플래그
17
 
18
+ async def load_models(app: FastAPI):
19
+ global model_ready
20
+ print("🚀 모델 로딩 시작...")
 
 
21
  fb_tokenizer, fb_model = load_fallback_model(FALLBACK_MODEL_NAME, FALLBACK_MODEL_DIR, token=HF_TOKEN)
22
  app.state.fallback_tokenizer = fb_tokenizer
23
  app.state.fallback_model = fb_model
24
 
 
25
  embedder = load_embedder(EMBEDDER_MODEL_NAME, EMBEDDER_MODEL_DIR, token=HF_TOKEN)
26
  app.state.embedder = embedder
27
+ set_embedder(embedder)
28
 
 
 
 
29
  docs_path = BASE_DIR / "rag" / "docs"
30
  if not chroma_initialized():
31
  docs = load_game_docs_from_disk(str(docs_path))
 
34
  else:
35
  print("🔄 RAG DB 이미 초기화됨")
36
 
37
+ model_ready = True
38
+ print("✅ 모든 모델 로딩 완료")
39
 
40
+ @asynccontextmanager
41
+ async def lifespan(app: FastAPI):
42
+ asyncio.create_task(load_models(app)) # 백그라운드 로딩
43
+ yield
44
  print("🛑 서버 종료 중...")
45
 
 
46
  app = FastAPI(title="ai-server", lifespan=lifespan)
47
 
 
48
  app.add_middleware(
49
  CORSMiddleware,
50
  allow_origins=["https://fpsgame-rrbc.onrender.com"],
 
53
  allow_headers=["*"],
54
  )
55
 
56
+ @app.get("/", include_in_schema=False)
57
+ async def root():
58
+ return HTMLResponse("""
59
+ <h1>Persona Chat Engine API</h1>
60
+ <p>서버가 실행 중입니다.</p>
61
+ <p><a href="/docs">Swagger UI로 이동</a></p>
62
+ """)
63
+
64
+ @app.get("/status")
65
+ async def status():
66
+ return {"ready": model_ready}
67
+
68
+ @app.post("/wake")
69
+ async def wake(request: Request):
70
+ session_id = (await request.json()).get("session_id", "unknown")
71
+ print(f"📡 Wake signal received for session: {session_id}")
72
+ if not model_ready:
73
+ asyncio.create_task(load_models(app))
74
+ return {"status": "awake", "model_ready": model_ready}
75
 
76
  @app.post("/ask", response_model=AskRes)
77
  async def ask(request: Request, req: AskReq):
78
+ if not model_ready:
79
+ raise HTTPException(status_code=503, detail="Model not ready")
80
  if not req.context:
81
  raise HTTPException(status_code=400, detail="missing context")
 
82
  if not (req.session_id and req.npc_id and req.user_input):
83
  raise HTTPException(status_code=400, detail="missing fields")
84
 
85
  context = req.context
86
+ npc_config_dict = context.npc_config.model_dump() if context.npc_config else None
 
87
 
88
+ return await handle_dialogue(
89
  request=request,
90
  session_id=req.session_id,
91
  npc_id=req.npc_id,
 
93
  context=context.model_dump(),
94
  npc_config=npc_config_dict
95
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
96
 
97
 
98
  '''