import spaces
import gradio as gr
from datasets import load_dataset
import os
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer, BitsAndBytesConfig
import torch
from threading import Thread
from sentence_transformers import SentenceTransformer
import faiss
import fitz  # PyMuPDF
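
# Hugging Face access token from the environment (needed to download the gated Gemma weights)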
token = os.environ.get("HF_TOKEN")
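
# Korean sentence-embedding model, shared by both retrieval indexes below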
ST = SentenceTransformer("jhgan/ko-sroberta-multitask")


def extract_text_from_pdf(pdf_path):
    """Extract the full text of a PDF with PyMuPDF."""
    doc = fitz.open(pdf_path)
    text = ""
    for page in doc:
        text += page.get_text()
    return text

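
# Build the statute corpus: one retrieval unit per line of the extracted PDF text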
pdf_path = "laws.pdf"
law_text = extract_text_from_pdf(pdf_path)

law_sentences = law_text.split('\n')
law_embeddings = ST.encode(law_sentences)
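
# Exact nearest-neighbor FAISS index (L2 distance) over the statute line embeddings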
index = faiss.IndexFlatL2(law_embeddings.shape[1])
index.add(law_embeddings)
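
# Past legal QA pairs (LawQA-Ko), indexed by question embedding for answer retrieval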
dataset = load_dataset("jihye-moon/LawQA-Ko")
data = dataset["train"]

data = data.map(lambda x: {"question_embedding": ST.encode(x["question"])}, batched=True)
data.add_faiss_index(column="question_embedding")
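
# Gemma 2 27B instruct, quantized to 4-bit NF4 so it fits in GPU memory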
model_id = "google/gemma-2-27b-it"
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)
tokenizer = AutoTokenizer.from_pretrained(model_id, token=token)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    quantization_config=bnb_config,
    token=token,
)
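
# Gemma's chat template has no separate system role, so this prompt is folded
# into the user turn in talk() below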
SYS_PROMPT = """You are an assistant for answering legal questions.
You are given the extracted parts of legal documents and a question. Provide a conversational answer.
If you don't know the answer, just say "I do not know." Don't make up an answer.
You must answer in Korean."""

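
# spaces.GPU marks functions that need GPU time when running on Hugging Face ZeroGPU Spaces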
@spaces.GPU
def search_law(query, k=5):
    """Return the k statute lines nearest to the query, with their L2 distances."""
    query_embedding = ST.encode([query])
    D, I = index.search(query_embedding, k)
    return [(law_sentences[i], D[0][idx]) for idx, i in enumerate(I[0])]


@spaces.GPU
def search_qa(query, k=3):
    """Return the answers of the k past QA pairs whose questions best match the query."""
    scores, retrieved_examples = data.get_nearest_examples(
        "question_embedding", ST.encode(query), k=k
    )
    return [retrieved_examples["answer"][i] for i in range(k)]

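
# Prompt assembly: the question, then retrieved statute lines, then retrieved QA answers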
def format_prompt(prompt, law_docs, qa_docs):
    PROMPT = f"Question: {prompt}\n\nLegal Context:\n"
    for doc in law_docs:  # law_docs holds plain sentence strings (tuples are unpacked in talk)
        PROMPT += f"{doc}\n"
    PROMPT += "\nLegal QA:\n"
    for doc in qa_docs:
        PROMPT += f"{doc}\n"
    return PROMPT

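
# Main chat handler: retrieve context, build the prompt, and stream the model's answer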
@spaces.GPU
def talk(prompt, history):
    law_results = search_law(prompt, k=3)
    qa_results = search_qa(prompt, k=3)

    # search_law returns (sentence, distance) pairs; keep only the sentences
    retrieved_law_docs = [result[0] for result in law_results]
    formatted_prompt = format_prompt(prompt, retrieved_law_docs, qa_results)
    formatted_prompt = formatted_prompt[:2000]  # crude character cap to bound prompt length

    messages = [{"role": "user", "content": SYS_PROMPT + "\n" + formatted_prompt}]

    input_ids = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        return_tensors="pt"
    ).to(model.device)

    streamer = TextIteratorStreamer(
        tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True
    )

    generate_kwargs = dict(
        input_ids=input_ids,
        streamer=streamer,
        max_new_tokens=1024,
        do_sample=True,
        top_p=0.95,
        temperature=0.2,
        eos_token_id=tokenizer.eos_token_id,
    )

    # Generate in a background thread and stream partial text back to the UI
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()

    outputs = []
    for text in streamer:
        outputs.append(text)
        yield "".join(outputs)

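
# Gradio chat UI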
TITLE = "Legal RAG Chatbot"
DESCRIPTION = """A chatbot that uses Retrieval-Augmented Generation (RAG) for legal consultation.
This chatbot can search legal documents and previous legal QA pairs to provide answers."""

demo = gr.ChatInterface(
    fn=talk,
    chatbot=gr.Chatbot(
        show_label=True,
        show_share_button=True,
        show_copy_button=True,
        likeable=True,
        layout="bubble",
        bubble_full_width=False,
    ),
    theme="soft",
    examples=[["What are the regulations on data privacy?"]],
    title=TITLE,
    description=DESCRIPTION,
)

demo.launch(debug=True)