refine summarize chain
- app_modules/llm_summarize_chain.py  +48 -1
- summarize.py  +11 -9
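In short: SummarizeChain.create_chain now builds LangChain's refine summarization chain with an explicit question prompt and refine prompt, optionally wrapped in Llama 2's [INST]/<<SYS>> chat format when USE_LLAMA_2_PROMPT_TEMPLATE is "true"; summarize.py gains a keep_page_info flag (KEEP_PAGE_INFO) that, when false, merges all loaded PDF pages into a single document before chunking.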
app_modules/llm_summarize_chain.py
CHANGED

@@ -1,5 +1,6 @@
 import os
 from typing import List, Optional
+from langchain import PromptTemplate

 from langchain.chains.base import Chain
 from langchain.chains.summarize import load_summarize_chain
@@ -7,12 +8,58 @@ from langchain.chains.summarize import load_summarize_chain
 from app_modules.llm_inference import LLMInference


+def get_llama_2_prompt_template(instruction):
+    B_INST, E_INST = "[INST]", "[/INST]"
+    B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"
+
+    system_prompt = "You are a helpful assistant, you always only answer for the assistant then you stop. Read the text to get context"
+
+    SYSTEM_PROMPT = B_SYS + system_prompt + E_SYS
+    prompt_template = B_INST + SYSTEM_PROMPT + instruction + E_INST
+    return prompt_template
+
+
 class SummarizeChain(LLMInference):
     def __init__(self, llm_loader):
         super().__init__(llm_loader)

     def create_chain(self) -> Chain:
-
+        use_llama_2_prompt_template = (
+            os.environ.get("USE_LLAMA_2_PROMPT_TEMPLATE") == "true"
+        )
+        prompt_template = """Write a concise summary of the following:
+{text}
+CONCISE SUMMARY:"""
+
+        if use_llama_2_prompt_template:
+            prompt_template = get_llama_2_prompt_template(prompt_template)
+        prompt = PromptTemplate.from_template(prompt_template)
+
+        refine_template = (
+            "Your job is to produce a final summary\n"
+            "We have provided an existing summary up to a certain point: {existing_answer}\n"
+            "We have the opportunity to refine the existing summary"
+            "(only if needed) with some more context below.\n"
+            "------------\n"
+            "{text}\n"
+            "------------\n"
+            "Given the new context, refine the original summary."
+            "If the context isn't useful, return the original summary."
+        )
+
+        if use_llama_2_prompt_template:
+            refine_template = get_llama_2_prompt_template(refine_template)
+        refine_prompt = PromptTemplate.from_template(refine_template)
+
+        chain = load_summarize_chain(
+            llm=self.llm_loader.llm,
+            chain_type="refine",
+            question_prompt=prompt,
+            refine_prompt=refine_prompt,
+            return_intermediate_steps=True,
+            input_key="input_documents",
+            output_key="output_text",
+        )
         return chain

     def run_chain(self, chain, inputs, callbacks: Optional[List] = []):
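Below is a minimal sketch, not part of the commit, of what the new get_llama_2_prompt_template helper produces when given the question prompt from create_chain. The import path is the module changed in this diff; the {text} placeholder is left intact for PromptTemplate to fill in later.

from app_modules.llm_summarize_chain import get_llama_2_prompt_template

question_template = """Write a concise summary of the following:
{text}
CONCISE SUMMARY:"""

# Wraps the instruction in Llama 2 chat markers: [INST]<<SYS>>...<</SYS>>...[/INST]
print(get_llama_2_prompt_template(question_template))
# Expected shape, per the B_SYS/E_SYS constants:
# [INST]<<SYS>>
# You are a helpful assistant, you always only answer for the assistant then you stop. Read the text to get context
# <</SYS>>
#
# Write a concise summary of the following:
# {text}
# CONCISE SUMMARY:[/INST]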
summarize.py
CHANGED

@@ -15,17 +15,16 @@ from app_modules.init import app_init, get_device_types
 from app_modules.llm_summarize_chain import SummarizeChain


-def load_documents(source_pdfs_path,
+def load_documents(source_pdfs_path, keep_page_info) -> List:
     loader = PyPDFDirectoryLoader(source_pdfs_path, silent_errors=True)
     documents = loader.load()
-    if
+    if not keep_page_info:
         for doc in documents:
-
-
-
-
-
-            break
+            if doc is not documents[0]:
+                documents[0].page_content = (
+                    documents[0].page_content + "\n" + doc.page_content
+                )
+        documents = [documents[0]]
     return documents


@@ -43,8 +42,11 @@ source_pdfs_path = (
 )
 chunk_size = sys.argv[2] if len(sys.argv) > 2 else os.environ.get("CHUNCK_SIZE")
 chunk_overlap = sys.argv[3] if len(sys.argv) > 3 else os.environ.get("CHUNK_OVERLAP")
+keep_page_info = (
+    sys.argv[3] if len(sys.argv) > 3 else os.environ.get("KEEP_PAGE_INFO")
+) == "true"

-sources = load_documents(source_pdfs_path,
+sources = load_documents(source_pdfs_path, keep_page_info)

 print(f"Splitting {len(sources)} PDF pages in to chunks ...")
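For reference, a toy sketch, not from the repo, of the new merge behaviour in load_documents when keep_page_info is false: every page after the first is appended onto page one and a single combined document is returned, so the downstream splitter sees one document rather than one per PDF page.

from langchain.docstore.document import Document

# Stand-in pages; PyPDFDirectoryLoader would normally produce these.
documents = [
    Document(page_content="page one"),
    Document(page_content="page two"),
    Document(page_content="page three"),
]

# Same logic as the new branch in load_documents.
for doc in documents:
    if doc is not documents[0]:
        documents[0].page_content = documents[0].page_content + "\n" + doc.page_content
documents = [documents[0]]

print(len(documents))              # 1
print(documents[0].page_content)   # page one / page two / page three, newline-joined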