Spaces:
Runtime error
Runtime error
vteam27
committed on
Commit
·
5a08d5f
1
Parent(s):
57b1a45
added RAG
Browse files- app.py +147 -0
- requirements.txt +21 -5
- utils.py +59 -1
app.py
CHANGED
|
@@ -14,6 +14,7 @@ from happytransformer import HappyTextToText, TTSettings
|
|
| 14 |
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM,logging
|
| 15 |
from transformers.integrations import deepspeed
|
| 16 |
import re
|
|
|
|
| 17 |
from lang_list import (
|
| 18 |
LANGUAGE_NAME_TO_CODE,
|
| 19 |
T2TT_TARGET_LANGUAGE_NAMES,
|
|
@@ -251,12 +252,158 @@ with gr.Blocks() as demo_t2tt:
|
|
| 251 |
api_name="t2tt",
|
| 252 |
)
|
| 253 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 254 |
with gr.Blocks() as demo:
|
| 255 |
with gr.Tabs():
|
| 256 |
with gr.Tab(label="OCR"):
|
| 257 |
demo_ocr.render()
|
| 258 |
with gr.Tab(label="Translate"):
|
| 259 |
demo_t2tt.render()
|
|
|
|
|
|
|
| 260 |
|
| 261 |
if __name__ == "__main__":
|
| 262 |
demo.launch()
|
|
|
|
| 14 |
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM,logging
|
| 15 |
from transformers.integrations import deepspeed
|
| 16 |
import re
|
| 17 |
+
import torch
|
| 18 |
from lang_list import (
|
| 19 |
LANGUAGE_NAME_TO_CODE,
|
| 20 |
T2TT_TARGET_LANGUAGE_NAMES,
|
|
|
|
| 252 |
api_name="t2tt",
|
| 253 |
)
|
| 254 |
|
| 255 |
+
|
| 256 |
+
#RAG
|
| 257 |
+
import utils
|
| 258 |
+
from langchain_mistralai import ChatMistralAI
|
| 259 |
+
from langchain_core.prompts import ChatPromptTemplate
|
| 260 |
+
from langchain_core.output_parsers import StrOutputParser
|
| 261 |
+
from langchain_community.vectorstores import Chroma
|
| 262 |
+
from langchain_huggingface import HuggingFaceEmbeddings
|
| 263 |
+
from langchain_core.runnables import RunnablePassthrough
|
| 264 |
+
# SECURITY: the Mistral API key must never be committed to source control.
# Provide MISTRAL_API_KEY through the deployment's secret/environment
# configuration instead. The key that was previously hard-coded on this line
# is exposed in the repository history and should be revoked and rotated.
if not os.environ.get('MISTRAL_API_KEY'):
    import warnings

    # Only warn (rather than abort) so the tabs that do not need the
    # Mistral API keep working when the key is absent.
    warnings.warn(
        "MISTRAL_API_KEY is not set; the RAG tab cannot query Mistral "
        "until it is provided via the environment.",
        RuntimeWarning,
    )
|
| 265 |
+
|
| 266 |
+
class VectorData():
    """Hold the vector store, retriever and RAG chain behind the RAG tab.

    Attributes:
        embeddings: HuggingFace embedding model used by the vector store.
        vectorstore: Chroma store opened on the "chroma_db" directory.
        retriever: Retriever derived from the current vector store.
        ingested_files: Names of the files ingested through this object.
        prompt: Chat prompt with ``{context}`` and ``{question}`` slots.
        llm: Mistral chat model that answers the prompt.
        rag_chain: retriever -> prompt -> llm -> string pipeline.
    """

    def __init__(self):
        embedding_model_name = 'l3cube-pune/punjabi-sentence-similarity-sbert'

        model_kwargs = {
            'device': 'cuda' if torch.cuda.is_available() else 'cpu',
            "trust_remote_code": True,
        }

        self.embeddings = HuggingFaceEmbeddings(
            model_name=embedding_model_name,
            model_kwargs=model_kwargs,
        )

        self.vectorstore = Chroma(
            persist_directory="chroma_db",
            embedding_function=self.embeddings,
        )
        self.retriever = self.vectorstore.as_retriever()
        self.ingested_files = []
        self.prompt = ChatPromptTemplate.from_messages(
            [
                (
                    "system",
                    "Answer the question based on the given context. "
                    "Dont give any ans if context is not valid to question. "
                    "Always give the source of context:\n"
                    "{context}\n",
                ),
                ("human", "{question}"),
            ]
        )
        self.llm = ChatMistralAI(model="mistral-large-latest")
        self.rag_chain = self._build_chain()

    def _build_chain(self):
        """Assemble the RAG pipeline around the *current* retriever."""
        return (
            {"context": self.retriever, "question": RunnablePassthrough()}
            | self.prompt
            | self.llm
            | StrOutputParser()
        )

    def _file_rows(self):
        """Return the ingested file names as one-column table rows."""
        return [[name] for name in self.ingested_files]

    def add_file(self, file):
        """Ingest an uploaded file and return the updated file table.

        Args:
            file: Uploaded file object exposing its path as ``file.name``,
                or None when nothing was uploaded (then nothing changes).

        Returns:
            A list of single-item rows, one per ingested file name.
        """
        if file is not None:
            file_name = file.name.split('/')[-1]
            self.retriever, self.vectorstore = utils.add_doc(file, self.vectorstore)
            # Record the name only after ingestion succeeded, and only once:
            # a duplicate entry would survive delete_file_by_name(), which
            # removes every chunk of that name but a single list entry.
            if file_name not in self.ingested_files:
                self.ingested_files.append(file_name)
            self.rag_chain = self._build_chain()
        return self._file_rows()

    def delete_file_by_name(self, file_name):
        """Remove one ingested file's chunks and return the file table.

        Args:
            file_name: Name as shown in the table; unknown names are ignored.

        Returns:
            A list of single-item rows, one per remaining file name.
        """
        if file_name in self.ingested_files:
            self.retriever, self.vectorstore = utils.delete_doc(file_name, self.vectorstore)
            self.ingested_files.remove(file_name)
            # Keep the chain bound to the retriever that was just replaced.
            self.rag_chain = self._build_chain()
        return self._file_rows()

    def delete_all_files(self):
        """Remove every document from the store and return an empty table."""
        self.retriever, self.vectorstore = utils.delete_all_doc(self.vectorstore)
        self.ingested_files.clear()
        # Keep the chain bound to the retriever that was just replaced.
        self.rag_chain = self._build_chain()
        return []
|
| 321 |
+
|
| 322 |
+
data_obj = VectorData()
|
| 323 |
+
|
| 324 |
+
# Function to handle question answering
|
| 325 |
+
def answer_question(question):
    """Run the RAG chain on ``question`` and return its answer as text.

    A blank or whitespace-only question is not sent to the chain; a prompt
    asking the user to enter a question is returned instead.
    """
    if not question.strip():
        return "Please enter a question."
    answer = data_obj.rag_chain.invoke(question)
    return f'{answer}'
|
| 329 |
+
|
| 330 |
+
with gr.Blocks() as rag_interface:
    # Page heading and short description.
    gr.Markdown("# RAG Interface")
    gr.Markdown("Manage documents and ask questions with a Retrieval-Augmented Generation (RAG) system.")

    with gr.Row():
        # Left column: upload, list and delete ingested files.
        with gr.Column():
            gr.Markdown("### File Management")

            # Upload widget and the button that triggers ingestion.
            file_input = gr.File(label="Upload File to Ingest")
            add_file_button = gr.Button("Ingest File")

            # Read-only table listing the ingested file names.
            ingested_files_box = gr.Dataframe(
                headers=["Files"],
                datatype="str",
                row_count=4,  # Limits the visible rows to create a scrollable view
                interactive=False,
            )

            # Deletion controls: pick a mode, optionally name a file.
            delete_option = gr.Radio(
                choices=["Delete by File Name", "Delete All Files"],
                label="Delete Option",
            )
            file_name_input = gr.Textbox(label="Enter File Name to Delete", visible=False)
            delete_button = gr.Button("Delete Selected")

            def toggle_file_input(option):
                """Show the file-name box only in delete-by-name mode."""
                wants_name = option == "Delete by File Name"
                return gr.update(visible=wants_name)

            delete_option.change(
                fn=toggle_file_input,
                inputs=delete_option,
                outputs=file_name_input,
            )

            # Ingest the uploaded file and refresh the table.
            add_file_button.click(
                fn=data_obj.add_file,
                inputs=file_input,
                outputs=ingested_files_box,
            )

            def delete_action(selected_option, file_name):
                """Dispatch to the delete routine matching the chosen mode."""
                if selected_option == "Delete by File Name" and file_name:
                    return data_obj.delete_file_by_name(file_name)
                if selected_option == "Delete All Files":
                    return data_obj.delete_all_files()
                # Nothing actionable selected: re-display the current list.
                return [[name] for name in data_obj.ingested_files]

            delete_button.click(
                fn=delete_action,
                inputs=[delete_option, file_name_input],
                outputs=ingested_files_box,
            )

        # Right column: ask a question against the ingested documents.
        with gr.Column():
            gr.Markdown("### Ask a Question")

            question_input = gr.Textbox(label="Enter your question")

            ask_button = gr.Button("Get Answer")
            answer_output = gr.Textbox(label="Answer", interactive=False)

            ask_button.click(
                fn=answer_question,
                inputs=question_input,
                outputs=answer_output,
            )
|
| 397 |
+
|
| 398 |
+
|
| 399 |
with gr.Blocks() as demo:
|
| 400 |
with gr.Tabs():
|
| 401 |
with gr.Tab(label="OCR"):
|
| 402 |
demo_ocr.render()
|
| 403 |
with gr.Tab(label="Translate"):
|
| 404 |
demo_t2tt.render()
|
| 405 |
+
with gr.Tab(label="RAG"):
|
| 406 |
+
rag_interface.render()
|
| 407 |
|
| 408 |
if __name__ == "__main__":
|
| 409 |
demo.launch()
|
requirements.txt
CHANGED
|
@@ -4,14 +4,30 @@ reportlab>=3.6.2
|
|
| 4 |
PyPDF2==1.26.0
|
| 5 |
happytransformer
|
| 6 |
python-doctr[torch]@git+https://github.com/mindee/doctr.git
|
| 7 |
-
transformers
|
| 8 |
fairseq2==0.1
|
| 9 |
-
pydub
|
| 10 |
yt-dlp
|
| 11 |
sentencepiece
|
| 12 |
nltk
|
| 13 |
-
numpy==1.26.4
|
| 14 |
opencv-python==4.9.0.80
|
| 15 |
-
packaging
|
| 16 |
pillow==10.3.0
|
| 17 |
-
pytesseract==0.3.10
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4 |
PyPDF2==1.26.0
|
| 5 |
happytransformer
|
| 6 |
python-doctr[torch]@git+https://github.com/mindee/doctr.git
|
|
|
|
| 7 |
fairseq2==0.1
|
|
|
|
| 8 |
yt-dlp
|
| 9 |
sentencepiece
|
| 10 |
nltk
|
|
|
|
| 11 |
opencv-python==4.9.0.80
|
|
|
|
| 12 |
pillow==10.3.0
|
| 13 |
+
pytesseract==0.3.10
|
| 14 |
+
packaging
|
| 15 |
+
torch
|
| 16 |
+
fastapi
|
| 17 |
+
uvicorn
|
| 18 |
+
pandas
|
| 19 |
+
numpy
|
| 20 |
+
torch
|
| 21 |
+
transformers
|
| 22 |
+
scikit-learn
|
| 23 |
+
sentence-transformers
|
| 24 |
+
langchain
|
| 25 |
+
langchain-community
|
| 26 |
+
langchain-core
|
| 27 |
+
langchain-huggingface
|
| 28 |
+
langchain-mistralai
|
| 29 |
+
langchain-text-splitters
|
| 30 |
+
langsmith
|
| 31 |
+
chroma-hnswlib
|
| 32 |
+
chromadb
|
| 33 |
+
fastapi
|
utils.py
CHANGED
|
@@ -160,4 +160,62 @@ class HocrParser():
|
|
| 160 |
if image is not None:
|
| 161 |
pdf.drawImage(ImageReader(Image.fromarray(image)),
|
| 162 |
0, 0, width=width, height=height)
|
| 163 |
-
pdf.save()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 160 |
if image is not None:
|
| 161 |
pdf.drawImage(ImageReader(Image.fromarray(image)),
|
| 162 |
0, 0, width=width, height=height)
|
| 163 |
+
pdf.save()
|
| 164 |
+
|
| 165 |
+
|
| 166 |
+
|
| 167 |
+
from langchain_huggingface import HuggingFaceEmbeddings
|
| 168 |
+
from transformers import AutoModel, AutoTokenizer, AutoModelForCausalLM
|
| 169 |
+
from langchain_community.vectorstores import Chroma
|
| 170 |
+
from langchain.schema import Document
|
| 171 |
+
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
| 172 |
+
from langchain_community.llms.huggingface_pipeline import HuggingFacePipeline
|
| 173 |
+
import torch
|
| 174 |
+
|
| 175 |
+
embedding_model_name = 'l3cube-pune/punjabi-sentence-similarity-sbert'
|
| 176 |
+
|
| 177 |
+
model_kwargs = {'device':'cuda' if torch.cuda.is_available() else 'cpu',"trust_remote_code": True}
|
| 178 |
+
|
| 179 |
+
embeddings = HuggingFaceEmbeddings(
|
| 180 |
+
model_name=embedding_model_name,
|
| 181 |
+
model_kwargs=model_kwargs
|
| 182 |
+
)
|
| 183 |
+
|
| 184 |
+
vectorstore = None
|
| 185 |
+
|
| 186 |
+
|
| 187 |
+
|
| 188 |
+
def read_file(data: str) -> Document:
    """Load a whole text file into a single ``Document``.

    Args:
        data: Path of the file to read. An object that exposes the path
            through a ``name`` attribute (app.py reads ``file.name`` on the
            upload it forwards here) is accepted as well.

    Returns:
        A Document whose ``page_content`` is the full file text and whose
        metadata ``"name"`` is the last ``/``-separated path component.

    Raises:
        OSError: If the file cannot be opened.
        UnicodeDecodeError: If the file is not valid UTF-8.
    """
    # Fall back to the argument itself when it is already a plain path.
    path = str(getattr(data, "name", data))
    # Decode explicitly as UTF-8 so non-ASCII text does not depend on the
    # platform's default encoding; ``with`` closes the handle on errors too.
    with open(path, 'r', encoding='utf-8') as f:
        content = f.read()
    return Document(page_content=content, metadata={"name": path.split('/')[-1]})
|
| 194 |
+
|
| 195 |
+
text_splitter = RecursiveCharacterTextSplitter(chunk_size=800, chunk_overlap=100)
|
| 196 |
+
|
| 197 |
+
def add_doc(data, vectorstore):
    """Split a file into chunks and index them in ``vectorstore``.

    Args:
        data: File to ingest, in any form ``read_file`` accepts.
        vectorstore: Existing store to extend, or None to create a new one
            from the chunks using the module-level ``embeddings``.

    Returns:
        A ``(retriever, vectorstore)`` tuple; the retriever returns the
        single best match (``k=1``).
    """
    doc = read_file(data)
    splits = text_splitter.split_documents([doc])
    if vectorstore is None:
        vectorstore = Chroma.from_documents(documents=splits, embedding=embeddings)
    elif splits:
        # Extend the caller's store instead of silently replacing it with a
        # brand-new one, so previously indexed documents are kept. An empty
        # file yields no chunks and leaves the store untouched.
        vectorstore.add_documents(splits)
    retriever = vectorstore.as_retriever(search_kwargs={'k':1})
    return retriever, vectorstore
|
| 203 |
+
|
| 204 |
+
def delete_doc(delete_name,vectorstore):
    """Delete every chunk whose metadata ``"name"`` equals ``delete_name``.

    Args:
        delete_name: File name stored under the ``"name"`` metadata key.
        vectorstore: Store exposing ``get()``, ``delete(ids=...)`` and
            ``as_retriever(...)``.

    Returns:
        A ``(retriever, vectorstore)`` tuple; the retriever returns the
        single best match (``k=1``).
    """
    # Snapshot the store once; ids and metadatas are parallel lists.
    stored = vectorstore.get()
    delete_doc_ids = [
        doc_id
        for doc_id, metadata in zip(stored['ids'], stored['metadatas'])
        # Chunks without metadata (or without a name) can never match.
        if (metadata or {}).get('name') == delete_name
    ]
    # Skip the call when nothing matched rather than sending an empty id list.
    if delete_doc_ids:
        vectorstore.delete(ids=delete_doc_ids)
    retriever = vectorstore.as_retriever(search_kwargs={'k':1})
    return retriever, vectorstore
|
| 214 |
+
|
| 215 |
+
def delete_all_doc(vectorstore):
    """Delete every document currently held in ``vectorstore``.

    Args:
        vectorstore: Store exposing ``get()``, ``delete(ids=...)`` and
            ``as_retriever(...)``.

    Returns:
        A ``(retriever, vectorstore)`` tuple; the retriever returns the
        single best match (``k=1``).
    """
    delete_doc_ids = list(vectorstore.get()['ids'])
    # One batched delete; skipped entirely when the store is already empty
    # rather than sending an empty id list.
    if delete_doc_ids:
        vectorstore.delete(ids=delete_doc_ids)
    retriever = vectorstore.as_retriever(search_kwargs={'k':1})
    return retriever, vectorstore
|