Spaces:
Sleeping
Sleeping
"""
Re-ranking retrieval implementation module.

Provides ``ReRanker``, which runs a base retriever to fetch candidate
documents and then re-orders them with either a user-supplied function or a
cross-encoder-style model.
"""
| import logging | |
| from typing import List, Dict, Any, Optional, Union, Callable | |
| from .base_retriever import BaseRetriever | |
# Module-level logger named after this module (``__name__``).
logger = logging.getLogger(__name__)
class ReRanker(BaseRetriever):
    """
    Search-result re-ranking retriever.

    Performs two-stage retrieval: candidates are fetched from ``base_retriever``
    and then re-ordered. Re-ranking uses, in order of precedence:

    1. ``rerank_fn(query, docs)`` if one was supplied;
    2. ``rerank_model.predict(...)`` over ``[query, doc[rerank_field]]`` pairs;
    3. otherwise, the first-stage order is returned unchanged.
    """

    def __init__(
        self,
        base_retriever: BaseRetriever,
        rerank_model: Optional[Union[str, Any]] = None,
        rerank_fn: Optional[Callable] = None,
        rerank_field: str = "text",
        rerank_batch_size: int = 32,
    ):
        """
        Initialize the ReRanker.

        Args:
            base_retriever: First-stage retriever instance.
            rerank_model: Re-ranking model (cross-encoder). Either a model name,
                which is loaded with ``sentence_transformers.CrossEncoder``, or an
                already-constructed object exposing ``predict``.
            rerank_fn: Custom re-ranking function. When provided it is used
                instead of ``rerank_model`` (which is then ignored).
            rerank_field: Document key whose value is paired with the query
                for model scoring.
            rerank_batch_size: Batch size passed to ``rerank_model.predict``.

        Raises:
            ImportError: If ``rerank_model`` is a model name and the
                ``sentence-transformers`` package is not installed.
        """
        self.base_retriever = base_retriever
        self.rerank_field = rerank_field
        self.rerank_batch_size = rerank_batch_size
        self.rerank_fn = rerank_fn
        self.rerank_model = None

        # A custom function takes precedence; only load a model when none is given.
        if rerank_fn is not None or rerank_model is None:
            return

        if not isinstance(rerank_model, str):
            # Pre-built model instance: use it directly. No need to require
            # sentence-transformers just to accept an object with `predict`.
            self.rerank_model = rerank_model
            return

        try:
            from sentence_transformers import CrossEncoder
        except ImportError:
            logger.warning("sentence-transformers ํจํค์ง๊ฐ ์ค์น๋์ง ์์์ต๋๋ค. pip install sentence-transformers ๋ช ๋ น์ผ๋ก ์ค์นํ์ธ์.")
            raise

        logger.info(f"์ฌ์์ํ ๋ชจ๋ธ '{rerank_model}' ๋ก๋ ์ค...")
        self.rerank_model = CrossEncoder(rerank_model)

    def add_documents(self, documents: List[Dict[str, Any]]) -> None:
        """
        Add documents to the underlying base retriever.

        Args:
            documents: Documents to add; forwarded unchanged to
                ``base_retriever.add_documents``.
        """
        self.base_retriever.add_documents(documents)

    def search(
        self,
        query: str,
        top_k: int = 5,
        first_stage_k: int = 30,
        **kwargs,
    ) -> List[Dict[str, Any]]:
        """
        Run two-stage retrieval: first-stage search followed by re-ranking.

        Args:
            query: Search query.
            top_k: Number of results to return after re-ranking.
            first_stage_k: Number of candidates to fetch from the base
                retriever. Raised to ``top_k`` if smaller, so that enough
                candidates exist to fill the requested result count.
            **kwargs: Extra arguments forwarded to ``base_retriever.search``.

        Returns:
            At most ``top_k`` documents.

            - With ``rerank_fn``: the first ``top_k`` items of its return
              value (assumed to be ordered best-first; not checked here).
            - With ``rerank_model``: shallow copies of the candidates that
              contain ``rerank_field``, each with an added ``"rerank_score"``
              float, sorted by descending score. Candidates lacking the field
              are skipped.
            - Otherwise: the first ``top_k`` first-stage results.
        """
        # Never fetch fewer candidates than the caller asked to get back.
        first_stage_k = max(first_stage_k, top_k)

        # Stage 1: retrieve candidates with the base retriever.
        logger.info(f"๊ธฐ๋ณธ ๊ฒ์๊ธฐ๋ก {first_stage_k}๊ฐ ๋ฌธ์ ๊ฒ์ ์ค...")
        initial_results = self.base_retriever.search(query, top_k=first_stage_k, **kwargs)

        if not initial_results:
            logger.warning("์ฒซ ๋ฒ์งธ ๋จ๊ณ ๊ฒ์ ๊ฒฐ๊ณผ๊ฐ ์์ต๋๋ค.")
            return []

        if len(initial_results) < first_stage_k:
            logger.info(f"์์ฒญํ {first_stage_k}๊ฐ๋ณด๋ค ์ ์ {len(initial_results)}๊ฐ ๊ฒฐ๊ณผ๋ฅผ ๊ฒ์ํ์ต๋๋ค.")

        # Stage 2a: user-supplied re-ranking function.
        if self.rerank_fn is not None:
            logger.info("์ฌ์ฉ์ ์ ์ ์ฌ์์ํ ํจ์๋ก ์ฌ์์ํ ์ค...")
            reranked_results = self.rerank_fn(query, initial_results)
            return reranked_results[:top_k]

        # Stage 2b: model-based re-ranking.
        if self.rerank_model is not None:
            return self._rerank_with_model(query, initial_results, top_k)

        # No re-ranker configured: return first-stage results as-is.
        logger.info("์ฌ์์ํ ๋ชจ๋ธ/ํจ์๊ฐ ์์ด ์ด๊ธฐ ๊ฒ์ ๊ฒฐ๊ณผ๋ฅผ ๊ทธ๋๋ก ๋ฐํํฉ๋๋ค.")
        return initial_results[:top_k]

    def _rerank_with_model(
        self,
        query: str,
        candidates: List[Dict[str, Any]],
        top_k: int,
    ) -> List[Dict[str, Any]]:
        """Score (query, doc[rerank_field]) pairs with the model and return the best ``top_k`` docs."""
        logger.info("CrossEncoder ๋ชจ๋ธ๋ก ์ฌ์์ํ ์ค...")

        # Keep exactly the documents that are scored, so each score stays
        # aligned with its own document even when some candidates are skipped.
        scorable = []
        for doc in candidates:
            if self.rerank_field not in doc:
                logger.warning(f"๋ฌธ์์ ํ๋ '{self.rerank_field}'๊ฐ ์์ต๋๋ค.")
                continue
            scorable.append(doc)

        if not scorable:
            # Nothing to score; avoid calling the model with an empty batch.
            return []

        text_pairs = [[query, doc[self.rerank_field]] for doc in scorable]
        scores = self.rerank_model.predict(
            text_pairs,
            batch_size=self.rerank_batch_size,
            show_progress_bar=len(text_pairs) > 10,
        )

        # Shallow-copy each document so the base retriever's own dicts are not
        # mutated with a per-query score.
        scored = [
            {**doc, "rerank_score": float(score)}
            for doc, score in zip(scorable, scores)
        ]
        # Stable sort: ties keep their first-stage order.
        scored.sort(key=lambda d: d["rerank_score"], reverse=True)
        return scored[:top_k]