minhvtt commited on
Commit
eda7f22
·
verified ·
1 Parent(s): 05351f2

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +359 -9
main.py CHANGED
@@ -1,33 +1,63 @@
1
  from fastapi import FastAPI, UploadFile, File, Form, HTTPException
2
  from fastapi.responses import JSONResponse
 
3
  from pydantic import BaseModel
4
- from typing import Optional, List
5
  from PIL import Image
6
  import io
7
  import numpy as np
 
 
 
 
8
 
9
  from embedding_service import JinaClipEmbeddingService
10
  from qdrant_service import QdrantVectorService
11
 
12
  # Initialize FastAPI app
13
  app = FastAPI(
14
- title="Event Social Media Embeddings API",
15
- description="API để embeddings search text + images từ events & social media với Jina CLIP v2 + Qdrant",
16
- version="1.0.0"
 
 
 
 
 
 
 
 
 
17
  )
18
 
19
  # Initialize services
20
  print("Initializing services...")
21
  embedding_service = JinaClipEmbeddingService(model_path="jinaai/jina-clip-v2")
 
 
22
  qdrant_service = QdrantVectorService(
23
- # URL và API key sẽ lấy từ environment variables
24
- collection_name="event_social_media",
25
  vector_size=embedding_service.get_embedding_dimension()
26
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27
  print("✓ Services initialized successfully")
28
 
29
 
30
- # Pydantic models
31
  class SearchRequest(BaseModel):
32
  text: Optional[str] = None
33
  limit: int = 10
@@ -48,15 +78,62 @@ class IndexResponse(BaseModel):
48
  message: str
49
 
50
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
51
  @app.get("/")
52
  async def root():
53
  """Health check endpoint"""
54
  return {
55
  "status": "running",
56
- "service": "Event Social Media Embeddings API",
57
  "embedding_model": "Jina CLIP v2",
58
  "vector_db": "Qdrant",
59
- "language_support": "Vietnamese + 88 other languages"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
60
  }
61
 
62
 
@@ -342,6 +419,279 @@ async def get_stats():
342
  raise HTTPException(status_code=500, detail=f"Lỗi khi lấy stats: {str(e)}")
343
 
344
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
345
  if __name__ == "__main__":
346
  import uvicorn
347
  uvicorn.run(
 
1
  from fastapi import FastAPI, UploadFile, File, Form, HTTPException
2
  from fastapi.responses import JSONResponse
3
+ from fastapi.middleware.cors import CORSMiddleware
4
  from pydantic import BaseModel
5
+ from typing import Optional, List, Dict
6
  from PIL import Image
7
  import io
8
  import numpy as np
9
+ import os
10
+ from datetime import datetime
11
+ from pymongo import MongoClient
12
+ from huggingface_hub import InferenceClient
13
 
14
  from embedding_service import JinaClipEmbeddingService
15
  from qdrant_service import QdrantVectorService
16
 
17
  # Initialize FastAPI app
18
  app = FastAPI(
19
+ title="Event Social Media Embeddings & ChatbotRAG API",
20
+ description="API để embeddings, search ChatbotRAG với Jina CLIP v2 + Qdrant + MongoDB + LLM",
21
+ version="2.0.0"
22
+ )
23
+
24
+ # CORS middleware
25
+ app.add_middleware(
26
+ CORSMiddleware,
27
+ allow_origins=["*"],
28
+ allow_credentials=True,
29
+ allow_methods=["*"],
30
+ allow_headers=["*"],
31
  )
32
 
33
  # Initialize services
34
  print("Initializing services...")
35
  embedding_service = JinaClipEmbeddingService(model_path="jinaai/jina-clip-v2")
36
+
37
+ collection_name = os.getenv("COLLECTION_NAME", "event_social_media")
38
  qdrant_service = QdrantVectorService(
39
+ collection_name=collection_name,
 
40
  vector_size=embedding_service.get_embedding_dimension()
41
  )
42
+ print(f"✓ Qdrant collection: {collection_name}")
43
+
44
+ # MongoDB connection
45
+ mongodb_uri = os.getenv("MONGODB_URI", "mongodb+srv://truongtn7122003:7KaI9OT5KTUxWjVI@truongtn7122003.xogin4q.mongodb.net/")
46
+ mongo_client = MongoClient(mongodb_uri)
47
+ db = mongo_client[os.getenv("MONGODB_DB_NAME", "chatbot_rag")]
48
+ documents_collection = db["documents"]
49
+ chat_history_collection = db["chat_history"]
50
+ print("✓ MongoDB connected")
51
+
52
+ # Hugging Face token
53
+ hf_token = os.getenv("HUGGINGFACE_TOKEN")
54
+ if hf_token:
55
+ print("✓ Hugging Face token configured")
56
+
57
  print("✓ Services initialized successfully")
58
 
59
 
60
+ # Pydantic models for embeddings
61
  class SearchRequest(BaseModel):
62
  text: Optional[str] = None
63
  limit: int = 10
 
78
  message: str
79
 
80
 
81
+ # Pydantic models for ChatbotRAG
82
+ class ChatRequest(BaseModel):
83
+ message: str
84
+ use_rag: bool = True
85
+ top_k: int = 3
86
+ system_message: Optional[str] = "You are a helpful AI assistant."
87
+ max_tokens: int = 512
88
+ temperature: float = 0.7
89
+ top_p: float = 0.95
90
+ hf_token: Optional[str] = None
91
+
92
+
93
+ class ChatResponse(BaseModel):
94
+ response: str
95
+ context_used: List[Dict]
96
+ timestamp: str
97
+
98
+
99
+ class AddDocumentRequest(BaseModel):
100
+ text: str
101
+ metadata: Optional[Dict] = None
102
+
103
+
104
+ class AddDocumentResponse(BaseModel):
105
+ success: bool
106
+ doc_id: str
107
+ message: str
108
+
109
+
110
  @app.get("/")
111
  async def root():
112
  """Health check endpoint"""
113
  return {
114
  "status": "running",
115
+ "service": "Event Social Media Embeddings & ChatbotRAG API",
116
  "embedding_model": "Jina CLIP v2",
117
  "vector_db": "Qdrant",
118
+ "language_support": "Vietnamese + 88 other languages",
119
+ "endpoints": {
120
+ "embeddings": {
121
+ "POST /index": "Index data với text/image",
122
+ "POST /search": "Hybrid search",
123
+ "POST /search/text": "Text search",
124
+ "POST /search/image": "Image search",
125
+ "DELETE /delete/{doc_id}": "Delete document",
126
+ "GET /document/{doc_id}": "Get document",
127
+ "GET /stats": "Collection statistics"
128
+ },
129
+ "chatbot_rag": {
130
+ "POST /chat": "Chat với RAG",
131
+ "POST /documents": "Add document to knowledge base",
132
+ "POST /rag/search": "Search in knowledge base",
133
+ "GET /history": "Get chat history",
134
+ "DELETE /documents/{doc_id}": "Delete document from knowledge base"
135
+ }
136
+ }
137
  }
138
 
139
 
 
419
  raise HTTPException(status_code=500, detail=f"Lỗi khi lấy stats: {str(e)}")
420
 
421
 
422
+ # ============================================
423
+ # ChatbotRAG Endpoints
424
+ # ============================================
425
+
426
+ @app.post("/chat", response_model=ChatResponse)
427
+ async def chat(request: ChatRequest):
428
+ """
429
+ Chat endpoint với RAG
430
+
431
+ Body:
432
+ - message: User message
433
+ - use_rag: Enable RAG retrieval (default: true)
434
+ - top_k: Number of documents to retrieve (default: 3)
435
+ - system_message: System prompt (optional)
436
+ - max_tokens: Max tokens for response (default: 512)
437
+ - temperature: Temperature for generation (default: 0.7)
438
+ - hf_token: Hugging Face token (optional, sẽ dùng env nếu không truyền)
439
+
440
+ Returns:
441
+ - response: Generated response
442
+ - context_used: Retrieved context documents
443
+ - timestamp: Response timestamp
444
+ """
445
+ try:
446
+ # Retrieve context if RAG enabled
447
+ context_used = []
448
+ if request.use_rag:
449
+ # Generate query embedding
450
+ query_embedding = embedding_service.encode_text(request.message)
451
+
452
+ # Search in Qdrant
453
+ results = qdrant_service.search(
454
+ query_embedding=query_embedding,
455
+ limit=request.top_k,
456
+ score_threshold=0.5
457
+ )
458
+ context_used = results
459
+
460
+ # Build context text
461
+ context_text = ""
462
+ if context_used:
463
+ context_text = "\n\nRelevant Context:\n"
464
+ for i, doc in enumerate(context_used, 1):
465
+ doc_text = doc["metadata"].get("text", "")
466
+ confidence = doc["confidence"]
467
+ context_text += f"\n[{i}] (Confidence: {confidence:.2f})\n{doc_text}\n"
468
+
469
+ # Add context to system message
470
+ system_message = f"{request.system_message}\n{context_text}\n\nPlease use the above context to answer the user's question when relevant."
471
+ else:
472
+ system_message = request.system_message
473
+
474
+ # Use token from request or fallback to env
475
+ token = request.hf_token or hf_token
476
+
477
+ # Generate response
478
+ if not token:
479
+ response = f"""[LLM Response Placeholder]
480
+
481
+ Context retrieved: {len(context_used)} documents
482
+ User question: {request.message}
483
+
484
+ To enable actual LLM generation:
485
+ 1. Set HUGGINGFACE_TOKEN environment variable, OR
486
+ 2. Pass hf_token in request body
487
+
488
+ Example:
489
+ {{
490
+ "message": "Your question",
491
+ "hf_token": "hf_xxxxxxxxxxxxx"
492
+ }}
493
+ """
494
+ else:
495
+ try:
496
+ client = InferenceClient(
497
+ token=token,
498
+ model="openai/gpt-oss-20b"
499
+ )
500
+
501
+ # Build messages
502
+ messages = [
503
+ {"role": "system", "content": system_message},
504
+ {"role": "user", "content": request.message}
505
+ ]
506
+
507
+ # Generate response
508
+ response = ""
509
+ for msg in client.chat_completion(
510
+ messages,
511
+ max_tokens=request.max_tokens,
512
+ stream=True,
513
+ temperature=request.temperature,
514
+ top_p=request.top_p,
515
+ ):
516
+ choices = msg.choices
517
+ if len(choices) and choices[0].delta.content:
518
+ response += choices[0].delta.content
519
+
520
+ except Exception as e:
521
+ response = f"Error generating response with LLM: {str(e)}\n\nContext was retrieved successfully, but LLM generation failed."
522
+
523
+ # Save to history
524
+ chat_data = {
525
+ "user_message": request.message,
526
+ "assistant_response": response,
527
+ "context_used": context_used,
528
+ "timestamp": datetime.utcnow()
529
+ }
530
+ chat_history_collection.insert_one(chat_data)
531
+
532
+ return ChatResponse(
533
+ response=response,
534
+ context_used=context_used,
535
+ timestamp=datetime.utcnow().isoformat()
536
+ )
537
+
538
+ except Exception as e:
539
+ raise HTTPException(status_code=500, detail=f"Error: {str(e)}")
540
+
541
+
542
+ @app.post("/documents", response_model=AddDocumentResponse)
543
+ async def add_document(request: AddDocumentRequest):
544
+ """
545
+ Add document to knowledge base
546
+
547
+ Body:
548
+ - text: Document text
549
+ - metadata: Additional metadata (optional)
550
+
551
+ Returns:
552
+ - success: True/False
553
+ - doc_id: MongoDB document ID
554
+ - message: Status message
555
+ """
556
+ try:
557
+ # Save to MongoDB
558
+ doc_data = {
559
+ "text": request.text,
560
+ "metadata": request.metadata or {},
561
+ "created_at": datetime.utcnow()
562
+ }
563
+ result = documents_collection.insert_one(doc_data)
564
+ doc_id = str(result.inserted_id)
565
+
566
+ # Generate embedding
567
+ embedding = embedding_service.encode_text(request.text)
568
+
569
+ # Index to Qdrant
570
+ qdrant_service.index_data(
571
+ doc_id=doc_id,
572
+ embedding=embedding,
573
+ metadata={
574
+ "text": request.text,
575
+ "source": "api",
576
+ **(request.metadata or {})
577
+ }
578
+ )
579
+
580
+ return AddDocumentResponse(
581
+ success=True,
582
+ doc_id=doc_id,
583
+ message=f"Document added successfully with ID: {doc_id}"
584
+ )
585
+
586
+ except Exception as e:
587
+ raise HTTPException(status_code=500, detail=f"Error: {str(e)}")
588
+
589
+
590
+ @app.post("/rag/search", response_model=List[SearchResponse])
591
+ async def rag_search(
592
+ query: str = Form(...),
593
+ top_k: int = Form(5),
594
+ score_threshold: Optional[float] = Form(0.5)
595
+ ):
596
+ """
597
+ Search in knowledge base
598
+
599
+ Body:
600
+ - query: Search query
601
+ - top_k: Number of results (default: 5)
602
+ - score_threshold: Minimum score (default: 0.5)
603
+
604
+ Returns:
605
+ - results: List of matching documents
606
+ """
607
+ try:
608
+ # Generate query embedding
609
+ query_embedding = embedding_service.encode_text(query)
610
+
611
+ # Search in Qdrant
612
+ results = qdrant_service.search(
613
+ query_embedding=query_embedding,
614
+ limit=top_k,
615
+ score_threshold=score_threshold
616
+ )
617
+
618
+ return [
619
+ SearchResponse(
620
+ id=result["id"],
621
+ confidence=result["confidence"],
622
+ metadata=result["metadata"]
623
+ )
624
+ for result in results
625
+ ]
626
+
627
+ except Exception as e:
628
+ raise HTTPException(status_code=500, detail=f"Error: {str(e)}")
629
+
630
+
631
+ @app.get("/history")
632
+ async def get_history(limit: int = 10, skip: int = 0):
633
+ """
634
+ Get chat history
635
+
636
+ Query params:
637
+ - limit: Number of messages to return (default: 10)
638
+ - skip: Number of messages to skip (default: 0)
639
+
640
+ Returns:
641
+ - history: List of chat messages
642
+ """
643
+ try:
644
+ history = list(
645
+ chat_history_collection
646
+ .find({}, {"_id": 0})
647
+ .sort("timestamp", -1)
648
+ .skip(skip)
649
+ .limit(limit)
650
+ )
651
+
652
+ # Convert datetime to string
653
+ for msg in history:
654
+ if "timestamp" in msg:
655
+ msg["timestamp"] = msg["timestamp"].isoformat()
656
+
657
+ return {
658
+ "history": history,
659
+ "total": chat_history_collection.count_documents({})
660
+ }
661
+
662
+ except Exception as e:
663
+ raise HTTPException(status_code=500, detail=f"Error: {str(e)}")
664
+
665
+
666
+ @app.delete("/documents/{doc_id}")
667
+ async def delete_document_from_kb(doc_id: str):
668
+ """
669
+ Delete document from knowledge base
670
+
671
+ Args:
672
+ - doc_id: Document ID (MongoDB ObjectId)
673
+
674
+ Returns:
675
+ - success: True/False
676
+ - message: Status message
677
+ """
678
+ try:
679
+ # Delete from MongoDB
680
+ result = documents_collection.delete_one({"_id": doc_id})
681
+
682
+ # Delete from Qdrant
683
+ if result.deleted_count > 0:
684
+ qdrant_service.delete_by_id(doc_id)
685
+ return {"success": True, "message": f"Document {doc_id} deleted from knowledge base"}
686
+ else:
687
+ raise HTTPException(status_code=404, detail=f"Document {doc_id} not found")
688
+
689
+ except HTTPException:
690
+ raise
691
+ except Exception as e:
692
+ raise HTTPException(status_code=500, detail=f"Error: {str(e)}")
693
+
694
+
695
  if __name__ == "__main__":
696
  import uvicorn
697
  uvicorn.run(