"""Minimal ChromaDB-backed document store for game docs: load documents from
disk, embed them, and retrieve them by semantic similarity and/or metadata."""

import json
import os
from typing import Any, Dict, List, Optional

from chromadb import PersistentClient

# Persistent on-disk store; the collection is created on first use.
_client = PersistentClient(path="./rag")
_collection = _client.get_or_create_collection(name="game_docs")

# Any object exposing .encode(text) -> vector (e.g. a SentenceTransformer).
# Must be set via set_embedder() before adding or querying documents.
_embedder: Optional[Any] = None

def set_embedder(embedder: Any) -> None:
    global _embedder
    _embedder = embedder

def chroma_initialized() -> bool:
    # True once the persistent collection already holds documents. Checking
    # the directory on disk is unreliable, since PersistentClient creates the
    # "./rag" directory as soon as this module is imported.
    return _collection.count() > 0

def load_game_docs_from_disk(path: str) -> List[Dict[str, Any]]:
    """Load .json and .txt files from `path` into {"id", "content", "metadata"} dicts."""
    docs = []
    for filename in os.listdir(path):
        full = os.path.join(path, filename)
        if filename.endswith(".json"):
            with open(full, "r", encoding="utf-8") as f:
                data = json.load(f)
            if isinstance(data, list):
                # A JSON file may hold a list of documents; derive any missing
                # ids from the filename and position.
                for i, doc in enumerate(data):
                    if "id" not in doc:
                        doc["id"] = f"{filename}_{i}"
                    docs.append(doc)
            else:
                if "id" not in data:
                    data["id"] = filename
                docs.append(data)
        elif filename.endswith(".txt"):
            with open(full, "r", encoding="utf-8") as f:
                content = f.read()
            # Chroma rejects empty metadata dicts, so record the source file.
            docs.append({
                "id": filename,
                "content": content,
                "metadata": {"source": filename},
            })
    return docs

def add_docs(docs: List[Dict[str, Any]], batch_size: int = 32) -> None:
    """Embed documents and add them to the collection in batches."""
    if _embedder is None:
        raise RuntimeError("Embedder not initialized; call set_embedder() first")
    for i in range(0, len(docs), batch_size):
        batch = docs[i:i + batch_size]
        ids = []
        contents = []
        embeddings = []
        metadatas = []
        for doc in batch:
            if "id" not in doc or "content" not in doc:
                raise ValueError("each doc requires 'id' and 'content'")
            ids.append(str(doc["id"]))
            contents.append(doc["content"])
            # Chroma rejects empty metadata dicts; fall back to the doc id.
            metadatas.append(doc.get("metadata") or {"source": str(doc["id"])})
            embeddings.append(_embedder.encode(doc["content"]).tolist())
        _collection.add(
            documents=contents,
            embeddings=embeddings,
            metadatas=metadatas,
            ids=ids,
        )

def retrieve(query: Optional[str] = None, filters: Optional[Dict[str, Any]] = None, top_k: int = 5) -> List[Dict[str, Any]]:
    """Return up to `top_k` documents, ranked by semantic similarity when
    `query` is given, otherwise selected by metadata filters alone."""
    if query:
        if _embedder is None:
            raise RuntimeError("Embedder not initialized; call set_embedder() first")
        q_emb = _embedder.encode(query).tolist()
        # Pass None rather than {}: Chroma rejects an empty `where` clause.
        res = _collection.query(
            query_embeddings=[q_emb],
            n_results=top_k,
            where=filters or None,
        )
        docs = res.get("documents", [[]])[0]
        metas = res.get("metadatas", [[]])[0]
    else:
        # No query text: a plain metadata lookup needs no embedder.
        res = _collection.get(
            where=filters or None,
            limit=top_k,
        )
        docs = res.get("documents", [])
        metas = res.get("metadatas", [])
    return [{"content": d, "metadata": m} for d, m in zip(docs, metas)]