| """ | |
| ๋ฒกํฐ ๊ฒ์ ๊ตฌํ ๋ชจ๋ | |
| """ | |
import os
import numpy as np
from typing import List, Dict, Any, Optional, Union, Tuple
import logging

from sentence_transformers import SentenceTransformer

from .base_retriever import BaseRetriever

logger = logging.getLogger(__name__)


class VectorRetriever(BaseRetriever):
    """
    Embedding-based vector retrieval implementation
    """

    def __init__(
        self,
        embedding_model: Optional[Union[str, SentenceTransformer]] = "paraphrase-multilingual-MiniLM-L12-v2",
        documents: Optional[List[Dict[str, Any]]] = None,
        embedding_field: str = "text",
        embedding_device: str = "cpu"
    ):
        """
        Initialize the VectorRetriever.

        Args:
            embedding_model: Embedding model name or a SentenceTransformer instance
            documents: Initial list of documents (optional)
            embedding_field: Name of the document field to embed
            embedding_device: Device to run the embedding model on ('cpu' or 'cuda')
        """
        self.embedding_field = embedding_field
        self.model_name = None

        # Load the embedding model
        if isinstance(embedding_model, str):
            logger.info(f"Loading embedding model '{embedding_model}'...")
            self.model_name = embedding_model
            self.embedding_model = SentenceTransformer(embedding_model, device=embedding_device)
        else:
            self.embedding_model = embedding_model
            # The model was passed in as an already-loaded instance; record a placeholder name
            if isinstance(embedding_model, SentenceTransformer):
                self.model_name = "loaded_sentence_transformer"

        # Initialize document storage
        self.documents = []
        self.document_embeddings = None

        # Add initial documents if provided
        if documents:
            self.add_documents(documents)

    def add_documents(self, documents: List[Dict[str, Any]]) -> None:
        """
        Add documents to the retriever and generate their embeddings.

        Args:
            documents: List of documents to add
        """
        if not documents:
            logger.warning("No documents to add.")
            return

        # Collect the documents and the texts to embed
        document_texts = []
        for doc in documents:
            if self.embedding_field not in doc:
                logger.warning(f"Document has no '{self.embedding_field}' field. Skipping it.")
                continue
            self.documents.append(doc)
            document_texts.append(doc[self.embedding_field])

        if not document_texts:
            logger.warning(f"No texts to embed. Check that every document has a '{self.embedding_field}' field.")
            return

        # Generate the document embeddings
        logger.info(f"Generating embeddings for {len(document_texts)} documents...")
        new_embeddings = self.embedding_model.encode(document_texts, show_progress_bar=True)

        # Merge with any existing embeddings
        if self.document_embeddings is None:
            self.document_embeddings = new_embeddings
        else:
            self.document_embeddings = np.vstack([self.document_embeddings, new_embeddings])

        logger.info(f"Stored {len(self.documents)} documents and {self.document_embeddings.shape[0]} embeddings in total")

    def search(self, query: str, top_k: int = 5, **kwargs) -> List[Dict[str, Any]]:
        """
        Run a vector search for the query.

        Args:
            query: Search query
            top_k: Number of top results to return
            **kwargs: Additional search parameters

        Returns:
            List of retrieved documents with their relevance scores
        """
        if not self.documents or self.document_embeddings is None:
            logger.warning("There are no documents to search.")
            return []

        # Generate the query embedding
        query_embedding = self.embedding_model.encode(query)

        # Compute cosine similarity between the query and every document
        scores = np.dot(self.document_embeddings, query_embedding) / (
            np.linalg.norm(self.document_embeddings, axis=1) * np.linalg.norm(query_embedding)
        )

        # Select the top-k results
        top_indices = np.argsort(scores)[-top_k:][::-1]

        # Format the results
        results = []
        for idx in top_indices:
            doc = self.documents[idx].copy()
            doc["score"] = float(scores[idx])
            results.append(doc)

        return results

    def save(self, directory: str) -> None:
        """
        Save the retriever state to disk.

        Args:
            directory: Directory path to save to
        """
        import json

        os.makedirs(directory, exist_ok=True)

        # Save the documents
        with open(os.path.join(directory, "documents.json"), "w", encoding="utf-8") as f:
            json.dump(self.documents, f, ensure_ascii=False, indent=2)

        # Save the embeddings
        if self.document_embeddings is not None:
            np.save(os.path.join(directory, "embeddings.npy"), self.document_embeddings)

        # Save the model info
        model_info = {
            "model_name": self.model_name or "paraphrase-multilingual-MiniLM-L12-v2",  # fall back to the default model name
            "embedding_dim": self.embedding_model.get_sentence_embedding_dimension() if hasattr(self.embedding_model, 'get_sentence_embedding_dimension') else 384
        }
        with open(os.path.join(directory, "model_info.json"), "w") as f:
            json.dump(model_info, f)

        logger.info(f"Saved retriever state to '{directory}'.")

    @classmethod
    def load(cls, directory: str, embedding_model: Optional[Union[str, SentenceTransformer]] = None) -> "VectorRetriever":
        """
        Load the retriever state from disk.

        Args:
            directory: Directory path to load from
            embedding_model: Embedding model to use (falls back to the saved model info if not provided)

        Returns:
            Loaded VectorRetriever instance
        """
        import json

        # Load the model info
        with open(os.path.join(directory, "model_info.json"), "r") as f:
            model_info = json.load(f)

        # Resolve the embedding model
        if embedding_model is None:
            # Instantiate the model from the saved model name
            if "model_name" in model_info and isinstance(model_info["model_name"], str):
                embedding_model = model_info["model_name"]
            else:
                # Safeguard: fall back to the default model when the name is missing or invalid (backwards compatibility)
                logger.warning("No valid model name found. Using the default model.")
                embedding_model = "paraphrase-multilingual-MiniLM-L12-v2"

        # Create a retriever instance (without documents)
        retriever = cls(embedding_model=embedding_model)

        # Load the documents
        with open(os.path.join(directory, "documents.json"), "r", encoding="utf-8") as f:
            retriever.documents = json.load(f)

        # Load the embeddings
        embeddings_path = os.path.join(directory, "embeddings.npy")
        if os.path.exists(embeddings_path):
            retriever.document_embeddings = np.load(embeddings_path)

        logger.info(f"Loaded retriever state from '{directory}'.")
        return retriever
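

if __name__ == "__main__":
    # Minimal usage sketch (illustrative only). The sample documents, query,
    # and index directory below are hypothetical; it assumes the surrounding
    # package (including base_retriever) is importable, e.g. via
    # `python -m <package>.vector_retriever`.
    logging.basicConfig(level=logging.INFO)

    sample_docs = [
        {"id": 1, "text": "Seoul is the capital of South Korea."},
        {"id": 2, "text": "Vector search ranks documents by embedding similarity."},
    ]

    retriever = VectorRetriever(documents=sample_docs)
    for hit in retriever.search("What is the capital of Korea?", top_k=1):
        print(hit["id"], round(hit["score"], 3), hit["text"])

    # Persist the retriever state and restore it into a new instance.
    retriever.save("./vector_index")
    restored = VectorRetriever.load("./vector_index")
    print(f"Restored {len(restored.documents)} documents")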