WIP
- app_modules/llm_inference.py +96 -0
- app_modules/llm_loader.py +1 -0
- app_modules/llm_qa_chain.py +23 -0
app_modules/llm_inference.py
ADDED
@@ -0,0 +1,96 @@
+import abc
+import os
+import time
+import urllib.parse
+from queue import Queue
+from threading import Thread
+
+from langchain.callbacks.tracers import LangChainTracer
+from langchain.chains.base import Chain
+
+from app_modules.llm_loader import *
+from app_modules.utils import *
+
+
+class LLMInference(metaclass=abc.ABCMeta):
+    llm_loader: LLMLoader
+    chain: Chain
+
+    def __init__(self, llm_loader):
+        self.llm_loader = llm_loader
+        self.chain = None
+
+    @abc.abstractmethod
+    def create_chain(self) -> Chain:
+        pass
+
+    def get_chain(self, tracing: bool = False) -> Chain:
+        if self.chain is None:
+            if tracing:
+                tracer = LangChainTracer()
+                tracer.load_default_session()
+
+            self.chain = self.create_chain()
+
+        return self.chain
+
+    def call_chain(
+        self, inputs, streaming_handler, q: Queue = None, tracing: bool = False
+    ):
+        print(inputs)
+
+        if self.llm_loader.streamer is not None and isinstance(
+            self.llm_loader.streamer, TextIteratorStreamer
+        ):
+            self.llm_loader.streamer.reset(q)
+
+        chain = self.get_chain(tracing)
+        result = (
+            self._run_qa_chain(
+                chain,
+                inputs,
+                streaming_handler,
+            )
+            if streaming_handler is not None
+            else chain(inputs)
+        )
+
+        result["answer"] = remove_extra_spaces(result["answer"])
+
+        # when PDF_FILE_BASE_URL is set, attach a public URL to each source document
+        base_url = os.environ.get("PDF_FILE_BASE_URL")
+        if base_url is not None and len(base_url) > 0:
+            documents = result["source_documents"]
+            for doc in documents:
+                source = doc.metadata["source"]
+                title = source.split("/")[-1]
+                doc.metadata["url"] = f"{base_url}{urllib.parse.quote(title)}"
+
+        return result
+
+    def _run_qa_chain(self, qa, inputs, streaming_handler):
+        que = Queue()
+
+        # run the chain in a background thread so tokens can be drained here
+        t = Thread(
+            target=lambda qa, inputs, q, sh: q.put(qa(inputs, callbacks=[sh])),
+            args=(qa, inputs, que, streaming_handler),
+        )
+        t.start()
+
+        if self.llm_loader.streamer is not None and isinstance(
+            self.llm_loader.streamer, TextIteratorStreamer
+        ):
+            # with chat history the chain makes two LLM calls (question
+            # condensing + answering), so the streamer must be drained twice
+            count = 2 if len(inputs.get("chat_history")) > 0 else 1
+
+            while count > 0:
+                try:
+                    for token in self.llm_loader.streamer:
+                        streaming_handler.on_llm_new_token(token)
+
+                    self.llm_loader.streamer.reset()
+                    count -= 1
+                except Exception:
+                    print("nothing generated yet - retry in 0.5s")
+                    time.sleep(0.5)
+
+        t.join()
+        return que.get()
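For reference, a minimal sketch of a streaming handler that could be passed to call_chain. BaseCallbackHandler is LangChain's standard callback hook; the class name and the print-based output are illustrative assumptions, not part of this commit.

# a minimal streaming-handler sketch, assuming LangChain's BaseCallbackHandler;
# the class name and print-based output are assumptions, not part of this commit
from langchain.callbacks.base import BaseCallbackHandler


class PrintStreamingHandler(BaseCallbackHandler):
    def on_llm_new_token(self, token: str, **kwargs) -> None:
        # _run_qa_chain forwards each token drained from the
        # TextIteratorStreamer to this method
        print(token, end="", flush=True)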
app_modules/llm_loader.py
CHANGED
@@ -88,6 +88,7 @@ class LLMLoader:
     llm_model_type: str
     llm: any
     streamer: any
+    max_tokens_limit: int

     def __init__(self, llm_model_type):
         self.llm_model_type = llm_model_type
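The new max_tokens_limit field is read by QAChain below, where it caps how many tokens of retrieved context ConversationalRetrievalChain may pack into the prompt. The hunk declares the attribute but does not show where it is assigned; a hedged sketch of one way the loader might set it (the 2048 default is an assumption):

# illustrative only -- the assignment is not shown in this hunk
class LLMLoader:
    llm_model_type: str
    max_tokens_limit: int

    def __init__(self, llm_model_type):
        self.llm_model_type = llm_model_type
        self.max_tokens_limit = 2048  # assumed default, not from this commit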
app_modules/llm_qa_chain.py
ADDED
@@ -0,0 +1,23 @@
+from langchain.chains import ConversationalRetrievalChain
+from langchain.chains.base import Chain
+from langchain.vectorstores.base import VectorStore
+
+from app_modules.llm_inference import LLMInference
+
+
+class QAChain(LLMInference):
+    vectorstore: VectorStore
+
+    def __init__(self, vectorstore, llm_loader):
+        super().__init__(llm_loader)
+        self.vectorstore = vectorstore
+
+    def create_chain(self) -> Chain:
+        qa = ConversationalRetrievalChain.from_llm(
+            self.llm_loader.llm,
+            self.vectorstore.as_retriever(search_kwargs=self.llm_loader.search_kwargs),
+            max_tokens_limit=self.llm_loader.max_tokens_limit,
+            return_source_documents=True,
+        )
+
+        return qa
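Putting the three files together, a hedged usage sketch. The model-type string, the vectorstore construction, and the question text are assumptions; the "question"/"chat_history" input keys and the "answer"/"source_documents" result keys follow ConversationalRetrievalChain's standard contract, which the code above relies on.

# a usage sketch under assumptions: the LLMLoader is fully initialized
# elsewhere (llm, streamer, search_kwargs are set by code not in this commit),
# and `load_vectorstore` is a hypothetical helper standing in for however
# the app builds its LangChain VectorStore
from app_modules.llm_loader import LLMLoader
from app_modules.llm_qa_chain import QAChain

llm_loader = LLMLoader("huggingface")  # model-type string is an assumption
vectorstore = load_vectorstore()  # hypothetical helper; not part of this commit

qa_chain = QAChain(vectorstore, llm_loader)

# streaming_handler=None makes call_chain run the chain synchronously
result = qa_chain.call_chain(
    {"question": "What does the document say?", "chat_history": []},
    streaming_handler=None,
)
print(result["answer"])
for doc in result["source_documents"]:
    # "url" is present only when PDF_FILE_BASE_URL is set
    print(doc.metadata.get("url", doc.metadata["source"]))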