import os
import json
import requests
import subprocess
from typing import Mapping, Optional, Any

import torch
import transformers
from transformers import AutoTokenizer, AutoModelForCausalLM
from huggingface_hub import snapshot_download
from urllib.parse import quote
from langchain import PromptTemplate, HuggingFaceHub, LLMChain
from langchain.llms import HuggingFacePipeline
from langchain.llms.base import LLM
from langchain.embeddings import HuggingFaceEmbeddings, HuggingFaceHubEmbeddings, HuggingFaceInstructEmbeddings
from langchain.vectorstores import FAISS
from sentence_transformers import CrossEncoder
from awq import AutoAWQForCausalLM

from qa_engine import logger
from qa_engine.response import Response
from qa_engine.mocks import MockLocalBinaryModel


class LocalBinaryModel(LLM):
    model_id: str = None
    llm: None = None

    def __init__(self, model_id: str = None):
        super().__init__()
        # pip install llama_cpp_python==0.1.39
        from llama_cpp import Llama

        model_path = f'qa_engine/{model_id}'
        if not os.path.exists(model_path):
            raise ValueError(f'{model_path} does not exist')
        self.model_id = model_id
        self.llm = Llama(model_path=model_path, n_ctx=4096)

    def _call(self, prompt: str, stop: Optional[list[str]] = None) -> str:
        output = self.llm(
            prompt,
            max_tokens=1024,
            stop=['Q:'],
            echo=False
        )
        return output['choices'][0]['text']

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        return {'name_of_model': self.model_id}

    @property
    def _llm_type(self) -> str:
        return self.model_id
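
# A minimal usage sketch for LocalBinaryModel (the GGML file name below is a
# placeholder for illustration, not a file shipped with this module); once
# constructed, the wrapper can be called like any other LangChain LLM:
#
#   llm = LocalBinaryModel(model_id='local_models/llama-2-7b.ggmlv3.q4_0.bin')
#   text = llm('Q: How do I load a dataset?\nA:')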


class TransformersPipelineModel(LLM):
    model_id: str = None
    pipeline: str = None

    def __init__(self, model_id: str = None):
        super().__init__()
        self.model_id = model_id

        tokenizer = AutoTokenizer.from_pretrained(model_id)
        if 'AWQ' in model_id:
            model = AutoAWQForCausalLM.from_quantized(
                model_id,
                fuse_layers=True,
                trust_remote_code=False,
                safetensors=True
            ).model
        else:
            model = AutoModelForCausalLM.from_pretrained(
                model_id,
                torch_dtype=torch.bfloat16,
                trust_remote_code=True,
                load_in_8bit=False,
                device_map='auto',
                resume_download=True,
            )
        self.pipeline = transformers.pipeline(
            'text-generation',
            model=model,
            tokenizer=tokenizer,
            torch_dtype=torch.bfloat16,
            device_map='auto',
            eos_token_id=tokenizer.eos_token_id,
            pad_token_id=tokenizer.eos_token_id,
            min_new_tokens=64,
            max_new_tokens=800,
            temperature=0.5,
        )

    def _call(self, prompt: str, stop: Optional[list[str]] = None) -> str:
        output_text = self.pipeline(prompt)[0]['generated_text']
        # Strip the echoed prompt so only the newly generated text is returned.
        output_text = output_text.replace(prompt + '\n', '')
        return output_text

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        return {'name_of_model': self.model_id}

    @property
    def _llm_type(self) -> str:
        return self.model_id
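
# A usage sketch for TransformersPipelineModel (the Hub model ID below is only
# an example chosen for illustration, not a requirement of this module):
#
#   llm = TransformersPipelineModel(model_id='mistralai/Mistral-7B-Instruct-v0.1')
#   text = llm('What is a tokenizer?')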


class APIServedModel(LLM):
    model_url: str = None
    debug: bool = None

    def __init__(self, model_url: str = None, debug: bool = None):
        super().__init__()
        if model_url[-1] == '/':
            raise ValueError('URL should not end with a slash - "/"')
        self.model_url = model_url
        self.debug = debug

    def _call(self, prompt: str, stop: Optional[list[str]] = None) -> str:
        prompt_encoded = quote(prompt, safe='')
        url = f'{self.model_url}/?prompt={prompt_encoded}'
        if self.debug:
            logger.info(f'URL: {url}')
        try:
            response = requests.get(url, timeout=1200, verify=False)
            response.raise_for_status()
            return json.loads(response.content)['output_text']
        except Exception as err:
            logger.error(f'Error: {err}')
            return f'Error: {err}'

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        return {'name_of_model': f'model url: {self.model_url}'}

    @property
    def _llm_type(self) -> str:
        return 'api_model'
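
# APIServedModel expects a generation service reachable over HTTP GET at
# {model_url}/?prompt=<url-encoded prompt>, returning JSON with an 'output_text'
# field. A sketch (the URL below is a placeholder):
#
#   llm = APIServedModel(model_url='http://localhost:8000', debug=True)
#   text = llm('What is a tokenizer?')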


class QAEngine():
    """
    QAEngine class, used for generating answers to questions.

    Args:
        llm_model_id (str): The ID of the LLM model to be used.
        embedding_model_id (str): The ID of the embedding model to be used.
        index_repo_id (str): The ID of the index repository to be used.
        prompt_template (str): The prompt template to be used, with 'question' and 'context' input variables.
        use_docs_for_context (bool, optional): Whether to use relevant documents as context for generating answers.
            Defaults to True.
        num_relevant_docs (int, optional): The number of reranked documents to include in the context. Defaults to 3.
        add_sources_to_response (bool, optional): Whether to add document sources to the response. Defaults to True.
        use_messages_for_context (bool, optional): Whether to use previous messages as context for generating answers.
            Defaults to True.
        first_stage_docs (int, optional): The number of documents retrieved by similarity search before reranking.
            Defaults to 50.
        debug (bool, optional): Whether to log debug information. Defaults to False.

    Attributes:
        use_docs_for_context (bool): Whether to use relevant documents as context for generating answers.
        use_messages_for_context (bool): Whether to use previous messages as context for generating answers.
        debug (bool): Whether to log debug information.
        llm_model (Union[LocalBinaryModel, HuggingFacePipeline, HuggingFaceHub]): The LLM model to be used.
        embedding_model (Union[HuggingFaceInstructEmbeddings, HuggingFaceHubEmbeddings]): The embedding model to be used.
        prompt_template (PromptTemplate): The prompt template to be used.
        llm_chain (LLMChain): The LLM chain to be used.
        knowledge_index (FAISS): The FAISS index to be used.
    """

    def __init__(
        self,
        llm_model_id: str,
        embedding_model_id: str,
        index_repo_id: str,
        prompt_template: str,
        use_docs_for_context: bool = True,
        num_relevant_docs: int = 3,
        add_sources_to_response: bool = True,
        use_messages_for_context: bool = True,
        first_stage_docs: int = 50,
        debug: bool = False
    ):
        super().__init__()
        self.prompt_template = prompt_template
        self.use_docs_for_context = use_docs_for_context
        self.num_relevant_docs = num_relevant_docs
        self.add_sources_to_response = add_sources_to_response
        self.use_messages_for_context = use_messages_for_context
        self.first_stage_docs = first_stage_docs
        self.debug = debug

        if 'local_models/' in llm_model_id:
            logger.info('using local binary model')
            self.llm_model = LocalBinaryModel(
                model_id=llm_model_id
            )
        elif 'api_models/' in llm_model_id:
            logger.info('using api served model')
            self.llm_model = APIServedModel(
                model_url=llm_model_id.replace('api_models/', ''),
                debug=self.debug
            )
        elif llm_model_id == 'mock':
            logger.info('using mock model')
            self.llm_model = MockLocalBinaryModel()
        else:
            logger.info('using transformers pipeline model')
            self.llm_model = TransformersPipelineModel(
                model_id=llm_model_id
            )

        prompt = PromptTemplate(
            template=prompt_template,
            input_variables=['question', 'context']
        )
        self.llm_chain = LLMChain(prompt=prompt, llm=self.llm_model)

        if self.use_docs_for_context:
            logger.info(f'Downloading {index_repo_id}')
            snapshot_download(
                repo_id=index_repo_id,
                allow_patterns=['*.faiss', '*.pkl'],
                repo_type='dataset',
                local_dir='indexes/run/'
            )
            logger.info('Loading embedding model')
            embed_instruction = 'Represent the Hugging Face library documentation'
            query_instruction = 'Query the most relevant piece of information from the Hugging Face documentation'
            embedding_model = HuggingFaceInstructEmbeddings(
                model_name=embedding_model_id,
                embed_instruction=embed_instruction,
                query_instruction=query_instruction
            )
            logger.info('Loading index')
            self.knowledge_index = FAISS.load_local('./indexes/run/', embedding_model)
            self.reranker = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-12-v2')
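
    # Retrieval is two-stage: the FAISS similarity search returns `first_stage_docs`
    # candidates, which the cross-encoder reranker narrows down to the
    # `num_relevant_docs` passed to the LLM as context (see get_response below).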

    @staticmethod
    def _preprocess_question(question: str) -> str:
        if question[-1] != '?':
            question += '?'
        return question

    @staticmethod
    def _postprocess_answer(answer: str) -> str:
        '''
        Postprocess the answer by removing unnecessary sequences and stop sequences.
        '''
        SEQUENCES_TO_REMOVE = [
            'Factually: ', 'Answer: ', '<<SYS>>', '<</SYS>>', '[INST]', '[/INST]'
        ]
        SEQUENCES_TO_STOP = [
            'User:', 'You:', 'Question:'
        ]
        for seq in SEQUENCES_TO_REMOVE:
            answer = answer.replace(seq, '')
        for seq in SEQUENCES_TO_STOP:
            if seq in answer:
                answer = answer[:answer.index(seq)]
        answer = answer.strip()
        return answer
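
    # Worked example (the raw model output below is assumed, shown only to
    # illustrate the rules above):
    #   _postprocess_answer('<<SYS>>Answer: Use pipeline().\nUser: thanks')
    #   -> 'Use pipeline().'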

    def get_response(self, question: str, messages_context: str = '') -> Response:
        """
        Generate an answer to the specified question.

        Args:
            question (str): The question to be answered.
            messages_context (str, optional): The context to be used for generating the answer. Defaults to ''.

        Returns:
            response (Response): The Response object containing the generated answer and the sources of information
            used to generate the response.
        """

        response = Response()
        context = ''
        relevant_docs = ''
        if self.use_messages_for_context and messages_context:
            messages_context = f'\nPrevious questions and answers:\n{messages_context}'
            context += messages_context
        if self.use_docs_for_context:
            logger.info('Retrieving documents')
            # The messages context is added to the query for better retrieval.
            retrieval_query = messages_context + question
            relevant_docs = self.knowledge_index.similarity_search(
                query=retrieval_query,
                k=self.first_stage_docs
            )
            cross_encoding_predictions = self.reranker.predict(
                [(retrieval_query, doc.page_content) for doc in relevant_docs]
            )
            relevant_docs = [
                doc for _, doc in sorted(
                    zip(cross_encoding_predictions, relevant_docs),
                    reverse=True, key=lambda x: x[0]
                )
            ]
            relevant_docs = relevant_docs[:self.num_relevant_docs]
            context += '\nExtracted documents:\n'
            context += ''.join([doc.page_content for doc in relevant_docs])
            metadata = [doc.metadata for doc in relevant_docs]
            response.set_sources(sources=[str(m['source']) for m in metadata])

        logger.info('Running LLM chain')
        question_processed = QAEngine._preprocess_question(question)
        answer = self.llm_chain.run(question=question_processed, context=context)
        answer_postprocessed = QAEngine._postprocess_answer(answer)
        response.set_answer(answer_postprocessed)
        logger.info('Received answer')

        if self.debug:
            logger.info('\n' + '=' * 100)
            sep = '\n' + '-' * 100
            logger.info(f'question len: {len(question)} {sep}')
            logger.info(f'question: {question} {sep}')
            logger.info(f'answer len: {len(response.get_answer())} {sep}')
            logger.info(f'answer original: {answer} {sep}')
            logger.info(f'answer postprocessed: {response.get_answer()} {sep}')
            logger.info(f'{response.get_sources_as_text()} {sep}')
            logger.info(f'messages_context: {messages_context} {sep}')
            logger.info(f'relevant_docs: {relevant_docs} {sep}')
            logger.info(f'context len: {len(context)} {sep}')
            logger.info(f'context: {context} {sep}')
        return response
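

# A minimal end-to-end sketch. The values below are placeholders chosen for
# illustration: 'mock' selects MockLocalBinaryModel so no weights are downloaded,
# and document retrieval is disabled so the index and embedding model IDs are not used.
if __name__ == '__main__':
    engine = QAEngine(
        llm_model_id='mock',
        embedding_model_id='hkunlp/instructor-large',  # unused when use_docs_for_context=False
        index_repo_id='org/dataset-with-faiss-index',  # placeholder, unused here
        prompt_template='Context:\n{context}\n\nQuestion: {question}\nAnswer:',
        use_docs_for_context=False,
        use_messages_for_context=False,
    )
    print(engine.get_response('How do I load a tokenizer?').get_answer())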