added refine summary chain
- app_modules/init.py +34 -30
- app_modules/llm_summarize_chain.py +20 -0
- summarize.py +70 -0
app_modules/init.py
CHANGED

@@ -23,55 +23,59 @@ load_dotenv(found_dotenv, override=False)
 init_settings()
 
 
-def app_init():
+def app_init(initQAChain: bool = True):
     # https://github.com/huggingface/transformers/issues/17611
     os.environ["CURL_CA_BUNDLE"] = ""
 
+    llm_model_type = os.environ.get("LLM_MODEL_TYPE")
+    n_threds = int(os.environ.get("NUMBER_OF_CPU_CORES") or "4")
+
     hf_embeddings_device_type, hf_pipeline_device_type = get_device_types()
     print(f"hf_embeddings_device_type: {hf_embeddings_device_type}")
     print(f"hf_pipeline_device_type: {hf_pipeline_device_type}")
 
-    hf_embeddings_model_name = (
-        os.environ.get("HF_EMBEDDINGS_MODEL_NAME") or "hkunlp/instructor-xl"
-    )
-
-    n_threds = int(os.environ.get("NUMBER_OF_CPU_CORES") or "4")
-    index_path = os.environ.get("FAISS_INDEX_PATH") or os.environ.get(
-        "CHROMADB_INDEX_PATH"
-    )
-    using_faiss = os.environ.get("FAISS_INDEX_PATH") is not None
-    llm_model_type = os.environ.get("LLM_MODEL_TYPE")
-
-    start = timer()
-    embeddings = HuggingFaceInstructEmbeddings(
-        model_name=hf_embeddings_model_name,
-        model_kwargs={"device": hf_embeddings_device_type},
-    )
-    end = timer()
-
-    print(f"Completed in {end - start:.3f}s")
-
-    start = timer()
-
-    print(f"Load index from {index_path} with {'FAISS' if using_faiss else 'Chroma'}")
-
-    if not os.path.isdir(index_path):
-        raise ValueError(f"{index_path} does not exist!")
-    elif using_faiss:
-        vectorstore = FAISS.load_local(index_path, embeddings)
-    else:
-        vectorstore = Chroma(
-            embedding_function=embeddings, persist_directory=index_path
-        )
-
-    end = timer()
-
-    print(f"Completed in {end - start:.3f}s")
+    if initQAChain:
+        hf_embeddings_model_name = (
+            os.environ.get("HF_EMBEDDINGS_MODEL_NAME") or "hkunlp/instructor-xl"
+        )
+
+        index_path = os.environ.get("FAISS_INDEX_PATH") or os.environ.get(
+            "CHROMADB_INDEX_PATH"
+        )
+        using_faiss = os.environ.get("FAISS_INDEX_PATH") is not None
+
+        start = timer()
+        embeddings = HuggingFaceInstructEmbeddings(
+            model_name=hf_embeddings_model_name,
+            model_kwargs={"device": hf_embeddings_device_type},
+        )
+        end = timer()
+
+        print(f"Completed in {end - start:.3f}s")
+
+        start = timer()
+
+        print(
+            f"Load index from {index_path} with {'FAISS' if using_faiss else 'Chroma'}"
+        )
+
+        if not os.path.isdir(index_path):
+            raise ValueError(f"{index_path} does not exist!")
+        elif using_faiss:
+            vectorstore = FAISS.load_local(index_path, embeddings)
+        else:
+            vectorstore = Chroma(
+                embedding_function=embeddings, persist_directory=index_path
+            )
+
+        end = timer()
+
+        print(f"Completed in {end - start:.3f}s")
 
     start = timer()
     llm_loader = LLMLoader(llm_model_type)
     llm_loader.init(n_threds=n_threds, hf_pipeline_device_type=hf_pipeline_device_type)
-    qa_chain = QAChain(vectorstore, llm_loader)
+    qa_chain = QAChain(vectorstore, llm_loader) if initQAChain else None
     end = timer()
     print(f"Completed in {end - start:.3f}s")
 
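The new initQAChain flag lets callers skip the embeddings and vectorstore setup entirely when only the LLM is needed. The function's return statement sits below this hunk, but summarize.py (added below) calls app_init(False)[0], which suggests app_init returns a tuple whose first element is llm_loader (presumably (llm_loader, qa_chain), with qa_chain left as None when the flag is off). A minimal sketch of the two call patterns, under that assumption:

# Default: load embeddings, the FAISS/Chroma index, and the QA chain
llm_loader, qa_chain = app_init()

# LLM-only: no embeddings or index are loaded; qa_chain comes back as None
llm_loader = app_init(False)[0]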
app_modules/llm_summarize_chain.py
ADDED

@@ -0,0 +1,20 @@
+import os
+from typing import List, Optional
+
+from langchain.chains.base import Chain
+from langchain.chains.summarize import load_summarize_chain
+
+from app_modules.llm_inference import LLMInference
+
+
+class SummarizeChain(LLMInference):
+    def __init__(self, llm_loader):
+        super().__init__(llm_loader)
+
+    def create_chain(self) -> Chain:
+        chain = load_summarize_chain(self.llm_loader.llm, chain_type="refine")
+        return chain
+
+    def run_chain(self, chain, inputs, callbacks: Optional[List] = []):
+        result = chain(inputs, return_only_outputs=True)
+        return result
summarize.py
ADDED

@@ -0,0 +1,70 @@
+# setting device on GPU if available, else CPU
+import os
+import sys
+from timeit import default_timer as timer
+from typing import List
+
+from langchain.document_loaders import PyPDFDirectoryLoader
+from langchain.embeddings import HuggingFaceInstructEmbeddings
+from langchain.text_splitter import RecursiveCharacterTextSplitter
+from langchain.vectorstores.base import VectorStore
+from langchain.vectorstores.chroma import Chroma
+from langchain.vectorstores.faiss import FAISS
+
+from app_modules.init import app_init, get_device_types
+from app_modules.llm_summarize_chain import SummarizeChain
+
+
+def load_documents(source_pdfs_path, urls) -> List:
+    loader = PyPDFDirectoryLoader(source_pdfs_path, silent_errors=True)
+    documents = loader.load()
+    if urls is not None and len(urls) > 0:
+        for doc in documents:
+            source = doc.metadata["source"]
+            filename = source.split("/")[-1]
+            for url in urls:
+                if url.endswith(filename):
+                    doc.metadata["url"] = url
+                    break
+    return documents
+
+
+def split_chunks(documents: List, chunk_size, chunk_overlap) -> List:
+    text_splitter = RecursiveCharacterTextSplitter(
+        chunk_size=chunk_size, chunk_overlap=chunk_overlap
+    )
+    return text_splitter.split_documents(documents)
+
+
+llm_loader = app_init(False)[0]
+
+source_pdfs_path = (
+    sys.argv[1] if len(sys.argv) > 1 else os.environ.get("SOURCE_PDFS_PATH")
+)
+chunk_size = os.environ.get("CHUNCK_SIZE")
+chunk_overlap = os.environ.get("CHUNK_OVERLAP")
+
+sources = load_documents(source_pdfs_path, None)
+
+print(f"Splitting {len(sources)} PDF pages in to chunks ...")
+
+chunks = split_chunks(
+    sources, chunk_size=int(chunk_size), chunk_overlap=int(chunk_overlap)
+)
+
+print(f"Summarizing {len(chunks)} chunks ...")
+start = timer()
+
+summarize_chain = SummarizeChain(llm_loader)
+result = summarize_chain.call_chain(
+    {"input_documents": chunks},
+    None,
+    None,
+    True,
+)
+
+end = timer()
+print(f"Completed in {end - start:.3f}s")
+
+print("\n\n***Summary:")
+print(result["output_text"])