|
|
from typing import List, Union |
|
|
from dotenv import find_dotenv, load_dotenv |
|
|
from langchain.chains import RetrievalQA |
|
|
from langchain.chat_models import init_chat_model |
|
|
from langchain.schema import Document |
|
|
from langchain.text_splitter import RecursiveCharacterTextSplitter |
|
|
from langchain_community.vectorstores import FAISS |
|
|
from langchain_huggingface.embeddings import HuggingFaceEmbeddings |
|
|
|
|
|
|
|
|
def get_default_splitter() -> RecursiveCharacterTextSplitter:
    """Build the text splitter used when a caller does not supply one.

    Separator preference, in list order: level-3, level-2 and level-1
    Markdown headings, then blank lines, single newlines and spaces.
    Chunks are sized with ``chunk_size=1000`` and consecutive chunks share
    ``chunk_overlap=200``.
    """
    # NOTE(review): unlike the library default, there is no "" fallback
    # separator, so a run containing none of these separators that is longer
    # than chunk_size is presumably kept as one oversized chunk -- confirm
    # this is intended.
    markdown_aware_separators = ["\n### ", "\n## ", "\n# ", "\n\n", "\n", " "]
    return RecursiveCharacterTextSplitter(
        chunk_size=1000,
        chunk_overlap=200,
        separators=markdown_aware_separators,
    )
|
|
|
|
|
def get_default_embeddings() -> HuggingFaceEmbeddings:
    """Build the embedding model used when a caller does not supply one.

    Uses the ``sentence-transformers/all-MiniLM-L6-v2`` checkpoint and pins
    it to the CPU device. A new model object is created on every call.
    """
    model_id = "sentence-transformers/all-MiniLM-L6-v2"
    device_settings = {"device": "cpu"}
    return HuggingFaceEmbeddings(model_kwargs=device_settings, model_name=model_id)
|
|
|
|
|
|
|
|
def build_retriever(
    data: Union[str, List[Document]],
    splitter: Union[RecursiveCharacterTextSplitter, None] = None,
    embeddings: Union[HuggingFaceEmbeddings, None] = None,
    top_k: int = 5,
):
    """Build a FAISS-backed retriever from a raw text string or a list of documents.

    ``data`` is split into chunks, the chunks are embedded into a FAISS
    index, and the index is returned as a retriever that yields the
    ``top_k`` most similar chunks for each query.

    Args:
        data (Union[str, List[Document]]): The source data to build the
            retriever from. A string is split with ``splitter.split_text``
            and each chunk is wrapped in a ``Document``; a list of
            ``Document`` objects is split with ``splitter.split_documents``.
        splitter (RecursiveCharacterTextSplitter, optional): The text
            splitter to use. Defaults to ``get_default_splitter()``.
        embeddings (HuggingFaceEmbeddings, optional): The embedding model to
            use. Defaults to ``get_default_embeddings()``. The default model
            is only created after the input has been validated and split.
        top_k (int, optional): The number of results the retriever returns
            per query. Must be at least 1. Defaults to 5.

    Returns:
        The retriever produced by ``FAISS.as_retriever`` with
        ``search_kwargs={"k": top_k}``.

    Raises:
        ValueError: If ``data`` is not a ``str`` or a list of ``Document``
            objects, if ``top_k`` is less than 1, or if splitting ``data``
            produces no chunks (e.g. an empty string or an empty list).
    """
    # Validate cheap arguments first, so bad input fails before the
    # (comparatively expensive) default embedding model is instantiated.
    if isinstance(data, list):
        for item in data:
            if not isinstance(item, Document):
                raise ValueError(
                    f"Unsupported list element type: {type(item)}. "
                    "Must be str or List[Document]."
                )
    elif not isinstance(data, str):
        raise ValueError(f"Unsupported data type: {type(data)}. Must be str or List[Document].")

    if top_k < 1:
        raise ValueError(f"top_k must be a positive integer, got {top_k!r}.")

    if splitter is None:
        splitter = get_default_splitter()

    if isinstance(data, str):
        docs = [Document(page_content=chunk) for chunk in splitter.split_text(data)]
    else:
        docs = splitter.split_documents(data)

    # FAISS needs at least one embedded vector to size its index; fail with
    # a clear message instead of an opaque error from index construction.
    if not docs:
        raise ValueError(
            "Input data produced no text chunks; cannot build a FAISS index "
            "from an empty document set."
        )

    if embeddings is None:
        embeddings = get_default_embeddings()

    index = FAISS.from_documents(docs, embeddings)
    return index.as_retriever(search_kwargs={"k": top_k})
|
|
|
|
|
|
|
|
def create_retrieval_qa(
    retriever,
    llm=None,
    *,
    model: str = "groq:meta-llama/llama-4-scout-17b-16e-instruct",
) -> RetrievalQA:
    """Create a RetrievalQA chain from a given retriever and LLM.

    Args:
        retriever (BaseRetriever): The retriever the QA chain queries for
            context documents.
        llm (LLM, optional): The language model to use. If omitted, a chat
            model is created with ``init_chat_model(model)`` after loading
            environment variables from the ``.env`` file located by
            ``find_dotenv()`` (presumably so the provider can read its API
            key -- confirm).
        model (str, optional): Keyword-only. The ``init_chat_model`` model
            identifier, used only when ``llm`` is None. Defaults to
            ``"groq:meta-llama/llama-4-scout-17b-16e-instruct"``.

    Returns:
        RetrievalQA: A chain built with ``chain_type="stuff"`` and
        ``return_source_documents=True``, so its output also carries the
        retrieved source documents.
    """
    if llm is None:
        # Only touch the environment when building the default model; a
        # caller-supplied llm is expected to be configured already.
        load_dotenv(find_dotenv())
        llm = init_chat_model(model)

    return RetrievalQA.from_chain_type(
        llm=llm,
        chain_type="stuff",
        retriever=retriever,
        return_source_documents=True,
    )
|
|
|