from typing import List, Optional, Union

from dotenv import find_dotenv, load_dotenv
from langchain.chains import RetrievalQA
from langchain.chat_models import init_chat_model
from langchain.schema import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import FAISS
from langchain_huggingface.embeddings import HuggingFaceEmbeddings


def get_default_splitter() -> RecursiveCharacterTextSplitter:
    """Returns a pre-configured text splitter."""
    return RecursiveCharacterTextSplitter(
        # Using markdown headers as separators keeps chunks aligned with section boundaries
        separators=["\n### ", "\n## ", "\n# ", "\n\n", "\n", " "],
        chunk_size=1000,
        chunk_overlap=200,
    )


def get_default_embeddings() -> HuggingFaceEmbeddings:
    """Returns a pre-configured embedding model."""
    return HuggingFaceEmbeddings(
        model_name="sentence-transformers/all-MiniLM-L6-v2",
        model_kwargs={"device": "cpu"},
    )


def build_retriever(
    data: Union[str, List[Document]],
    splitter: Optional[RecursiveCharacterTextSplitter] = None,
    embeddings: Optional[HuggingFaceEmbeddings] = None,
    top_k: int = 5,
):
    """Builds a retriever from either a raw text string or a list of documents.

    Args:
        data (Union[str, List[Document]]): The source data to build the retriever from.
        splitter (RecursiveCharacterTextSplitter, optional): The text splitter to use.
            Defaults to get_default_splitter().
        embeddings (HuggingFaceEmbeddings, optional): The embedding model to use.
            Defaults to get_default_embeddings().
        top_k (int, optional): The number of top results to return. Defaults to 5.

    Returns:
        A FAISS-backed retriever that returns the top_k most similar chunks.
    """
    splitter = splitter or get_default_splitter()
    embeddings = embeddings or get_default_embeddings()

    if isinstance(data, str):
        # If the input is a raw string, split it into chunks first
        chunks = splitter.split_text(data)
        # Then convert those chunks into Document objects
        docs = [Document(page_content=chunk) for chunk in chunks]
    elif isinstance(data, list):
        # If the input is already a list of documents, split them directly
        docs = splitter.split_documents(data)
    else:
        raise ValueError(f"Unsupported data type: {type(data)}. Must be str or List[Document].")

    index = FAISS.from_documents(docs, embeddings)
    return index.as_retriever(search_kwargs={"k": top_k})
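

# A minimal usage sketch for build_retriever (the sample markdown string and
# the query below are made up purely for illustration):
#
#     retriever = build_retriever("# Intro\n\nFAISS stores the embedded chunks.", top_k=3)
#     top_docs = retriever.invoke("What does FAISS store?")
#
# The same call also accepts pre-loaded documents, e.g.
# build_retriever([Document(page_content="...")]).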


def create_retrieval_qa(
    retriever,
    llm=None,
) -> RetrievalQA:
    """Creates a RetrievalQA instance from a given retriever and LLM.

    Args:
        retriever (BaseRetriever): The retriever to be used by the QA chain.
        llm (LLM, optional): The language model to use. If not provided,
            a default model will be initialized.
    """
    if llm is None:
        # Load API keys (e.g. GROQ_API_KEY) from a local .env file before
        # initializing the default Groq-hosted chat model
        load_dotenv(find_dotenv())
        llm = init_chat_model("groq:meta-llama/llama-4-scout-17b-16e-instruct")
    return RetrievalQA.from_chain_type(
        llm=llm,
        chain_type="stuff",
        retriever=retriever,
        return_source_documents=True,
    )
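

if __name__ == "__main__":
    # End-to-end sketch of the two helpers above. The sample text and question
    # are hypothetical, and running this requires a GROQ_API_KEY in the
    # environment (or a .env file) for the default chat model.
    sample_text = (
        "# LangChain\n\n"
        "LangChain is a framework for building applications with language models.\n\n"
        "## Retrieval\n\n"
        "Retrieval-augmented generation grounds answers in indexed documents."
    )
    retriever = build_retriever(sample_text, top_k=2)
    qa_chain = create_retrieval_qa(retriever)
    result = qa_chain.invoke({"query": "What is LangChain used for?"})
    print(result["result"])
    for doc in result["source_documents"]:
        print("-", doc.page_content[:80])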