Upload 4 files
Browse files
    	
        README.md
    CHANGED
    
    | @@ -1,13 +1,15 @@ | |
| 1 | 
             
            ---
         | 
| 2 | 
            -
            title:  | 
| 3 | 
            -
            emoji:  | 
| 4 | 
            -
            colorFrom:  | 
| 5 | 
            -
            colorTo:  | 
| 6 | 
             
            sdk: gradio
         | 
| 7 | 
            -
            sdk_version: 4. | 
| 8 | 
             
            app_file: app.py
         | 
| 9 | 
            -
            pinned:  | 
| 10 | 
            -
             | 
|  | |
|  | |
| 11 | 
             
            ---
         | 
| 12 |  | 
| 13 | 
            -
            Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
         | 
|  | |
| 1 | 
             
            ---
         | 
| 2 | 
            +
            title: RAG-Chatbot
         | 
| 3 | 
            +
            emoji: 🌘w🌖
         | 
| 4 | 
            +
            colorFrom: yellow
         | 
| 5 | 
            +
            colorTo: red
         | 
| 6 | 
             
            sdk: gradio
         | 
| 7 | 
            +
            sdk_version: 4.39.0
         | 
| 8 | 
             
            app_file: app.py
         | 
| 9 | 
            +
            pinned: true
         | 
| 10 | 
            +
            short_description: A retrieval system with chatbot integration
         | 
| 11 | 
            +
            thumbnail: >-
         | 
| 12 | 
            +
              https://cdn-uploads.huggingface.co/production/uploads/6527e89a8808d80ccff88b7a/XVgtQiizeFHIUUj1huwdv.png
         | 
| 13 | 
             
            ---
         | 
| 14 |  | 
| 15 | 
            +
            Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
         | 
    	
        app.py
    ADDED
    
    | @@ -0,0 +1,150 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            import gradio as gr
         | 
| 2 | 
            +
            from datasets import load_dataset
         | 
| 3 | 
            +
            import os
         | 
| 4 | 
            +
            from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer, BitsAndBytesConfig
         | 
| 5 | 
            +
            import torch
         | 
| 6 | 
            +
            from threading import Thread
         | 
| 7 | 
            +
            from sentence_transformers import SentenceTransformer
         | 
| 8 | 
            +
            import faiss
         | 
| 9 | 
            +
            import fitz  # PyMuPDF
         | 
| 10 | 
            +
             | 
| 11 | 
            +
            # 환경 변수에서 Hugging Face 토큰 가져오기
         | 
| 12 | 
            +
            token = os.environ.get("HF_TOKEN")
         | 
| 13 | 
            +
             | 
| 14 | 
            +
             | 
| 15 | 
            +
            # 임베딩 모델 로드
         | 
| 16 | 
            +
            ST = SentenceTransformer("mixedbread-ai/mxbai-embed-large-v1")
         | 
| 17 | 
            +
             | 
| 18 | 
            +
            # PDF에서 텍스트 추출
         | 
| 19 | 
            +
            def extract_text_from_pdf(pdf_path):
         | 
| 20 | 
            +
                doc = fitz.open(pdf_path)
         | 
| 21 | 
            +
                text = ""
         | 
| 22 | 
            +
                for page in doc:
         | 
| 23 | 
            +
                    text += page.get_text()
         | 
| 24 | 
            +
                return text
         | 
| 25 | 
            +
             | 
| 26 | 
            +
            # 법률 문서 PDF 경로 지정 및 텍스트 추출
         | 
| 27 | 
            +
            pdf_path = "laws.pdf"  # 여기에 실제 PDF 경로를 입력하세요.
         | 
| 28 | 
            +
            law_text = extract_text_from_pdf(pdf_path)
         | 
| 29 | 
            +
             | 
| 30 | 
            +
            # 법률 문서 텍스트를 문장 단위로 나누고 임베딩
         | 
| 31 | 
            +
            law_sentences = law_text.split('\n')  # Adjust splitting based on your PDF structure
         | 
| 32 | 
            +
            law_embeddings = ST.encode(law_sentences)
         | 
| 33 | 
            +
             | 
| 34 | 
            +
            # FAISS 인덱스 생성 및 임베딩 추가
         | 
| 35 | 
            +
            index = faiss.IndexFlatL2(law_embeddings.shape[1])
         | 
| 36 | 
            +
            index.add(law_embeddings)
         | 
| 37 | 
            +
             | 
| 38 | 
            +
            # Hugging Face에서 법률 상담 데이터셋 로드
         | 
| 39 | 
            +
            dataset = load_dataset("jihye-moon/LawQA-Ko")
         | 
| 40 | 
            +
            data = dataset["train"]
         | 
| 41 | 
            +
             | 
| 42 | 
            +
            # 질문 컬럼을 임베딩하여 새로운 컬럼에 추가
         | 
| 43 | 
            +
            data = data.map(lambda x: {"question_embedding": ST.encode(x["question"])}, batched=True)
         | 
| 44 | 
            +
            data.add_faiss_index(column="question_embedding")
         | 
| 45 | 
            +
             | 
| 46 | 
            +
            # LLaMA 모델 설정
         | 
| 47 | 
            +
            model_id = "meta-llama/Meta-Llama-3-8B-Instruct"
         | 
| 48 | 
            +
            bnb_config = BitsAndBytesConfig(
         | 
| 49 | 
            +
                load_in_4bit=True, bnb_4bit_use_double_quant=True, bnb_4bit_quant_type="nf4", bnb_4bit_compute_dtype=torch.bfloat16
         | 
| 50 | 
            +
            )
         | 
| 51 | 
            +
            tokenizer = AutoTokenizer.from_pretrained(model_id, token=token)
         | 
| 52 | 
            +
            model = AutoModelForCausalLM.from_pretrained(
         | 
| 53 | 
            +
                model_id,
         | 
| 54 | 
            +
                torch_dtype=torch.bfloat16,
         | 
| 55 | 
            +
                device_map="auto",
         | 
| 56 | 
            +
                quantization_config=bnb_config,
         | 
| 57 | 
            +
                token=token
         | 
| 58 | 
            +
            )
         | 
| 59 | 
            +
             | 
| 60 | 
            +
            SYS_PROMPT = """You are an assistant for answering legal questions.
         | 
| 61 | 
            +
            You are given the extracted parts of legal documents and a question. Provide a conversational answer.
         | 
| 62 | 
            +
            If you don't know the answer, just say "I do not know." Don't make up an answer.
         | 
| 63 | 
            +
            you must answer korean."""
         | 
| 64 | 
            +
             | 
| 65 | 
            +
            # 법률 문서 검색 함수
         | 
| 66 | 
            +
            def search_law(query, k=5):
         | 
| 67 | 
            +
                query_embedding = ST.encode([query])
         | 
| 68 | 
            +
                D, I = index.search(query_embedding, k)
         | 
| 69 | 
            +
                return [(law_sentences[i], D[0][idx]) for idx, i in enumerate(I[0])]
         | 
| 70 | 
            +
             | 
| 71 | 
            +
            # 법률 상담 데이터 검색 함수
         | 
| 72 | 
            +
            def search_qa(query, k=3):
         | 
| 73 | 
            +
                scores, retrieved_examples = data.get_nearest_examples(
         | 
| 74 | 
            +
                    "question_embedding", ST.encode(query), k=k
         | 
| 75 | 
            +
                )
         | 
| 76 | 
            +
                return [retrieved_examples["answer"][i] for i in range(k)]
         | 
| 77 | 
            +
             | 
| 78 | 
            +
            # 최종 프롬프트 생성
         | 
| 79 | 
            +
            def format_prompt(prompt, law_docs, qa_docs):
         | 
| 80 | 
            +
                PROMPT = f"Question: {prompt}\n\nLegal Context:\n"
         | 
| 81 | 
            +
                for doc in law_docs:
         | 
| 82 | 
            +
                    PROMPT += f"{doc[0]}\n"  # Assuming doc[0] contains the relevant text
         | 
| 83 | 
            +
                PROMPT += "\nLegal QA:\n"
         | 
| 84 | 
            +
                for doc in qa_docs:
         | 
| 85 | 
            +
                    PROMPT += f"{doc}\n"
         | 
| 86 | 
            +
                return PROMPT
         | 
| 87 | 
            +
             | 
| 88 | 
            +
            # 챗봇 응답 함수
         | 
| 89 | 
            +
            def talk(prompt, history):
         | 
| 90 | 
            +
                law_results = search_law(prompt, k=3)
         | 
| 91 | 
            +
                qa_results = search_qa(prompt, k=3)
         | 
| 92 | 
            +
             | 
| 93 | 
            +
                retrieved_law_docs = [result[0] for result in law_results]
         | 
| 94 | 
            +
                formatted_prompt = format_prompt(prompt, retrieved_law_docs, qa_results)
         | 
| 95 | 
            +
                formatted_prompt = formatted_prompt[:2000]  # GPU 메모리 부족을 피하기 위해 프롬프트 제한
         | 
| 96 | 
            +
             | 
| 97 | 
            +
                messages = [{"role": "system", "content": SYS_PROMPT}, {"role": "user", "content": formatted_prompt}]
         | 
| 98 | 
            +
             | 
| 99 | 
            +
                # 모델에게 생성 지시
         | 
| 100 | 
            +
                input_ids = tokenizer.apply_chat_template(
         | 
| 101 | 
            +
                    messages,
         | 
| 102 | 
            +
                    add_generation_prompt=True,
         | 
| 103 | 
            +
                    return_tensors="pt"
         | 
| 104 | 
            +
                ).to(model.device)
         | 
| 105 | 
            +
             | 
| 106 | 
            +
                streamer = TextIteratorStreamer(
         | 
| 107 | 
            +
                    tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True
         | 
| 108 | 
            +
                )
         | 
| 109 | 
            +
             | 
| 110 | 
            +
                generate_kwargs = dict(
         | 
| 111 | 
            +
                    input_ids=input_ids,
         | 
| 112 | 
            +
                    streamer=streamer,
         | 
| 113 | 
            +
                    max_new_tokens=1024,
         | 
| 114 | 
            +
                    do_sample=True,
         | 
| 115 | 
            +
                    top_p=0.95,
         | 
| 116 | 
            +
                    temperature=0.75,
         | 
| 117 | 
            +
                    eos_token_id=tokenizer.eos_token_id,
         | 
| 118 | 
            +
                )
         | 
| 119 | 
            +
             | 
| 120 | 
            +
                t = Thread(target=model.generate, kwargs=generate_kwargs)
         | 
| 121 | 
            +
                t.start()
         | 
| 122 | 
            +
             | 
| 123 | 
            +
                outputs = []
         | 
| 124 | 
            +
                for text in streamer:
         | 
| 125 | 
            +
                    outputs.append(text)
         | 
| 126 | 
            +
                    yield "".join(outputs)
         | 
| 127 | 
            +
             | 
| 128 | 
            +
            # Gradio 인터페이스 설정
         | 
| 129 | 
            +
            TITLE = "Legal RAG Chatbot"
         | 
| 130 | 
            +
            DESCRIPTION = """A chatbot that uses Retrieval-Augmented Generation (RAG) for legal consultation.
         | 
| 131 | 
            +
            This chatbot can search legal documents and previous legal QA pairs to provide answers."""
         | 
| 132 | 
            +
             | 
| 133 | 
            +
            demo = gr.ChatInterface(
         | 
| 134 | 
            +
                fn=talk,
         | 
| 135 | 
            +
                chatbot=gr.Chatbot(
         | 
| 136 | 
            +
                    show_label=True,
         | 
| 137 | 
            +
                    show_share_button=True,
         | 
| 138 | 
            +
                    show_copy_button=True,
         | 
| 139 | 
            +
                    likeable=True,
         | 
| 140 | 
            +
                    layout="bubble",
         | 
| 141 | 
            +
                    bubble_full_width=False,
         | 
| 142 | 
            +
                ),
         | 
| 143 | 
            +
                theme="Soft",
         | 
| 144 | 
            +
                examples=[["What are the regulations on data privacy?"]],
         | 
| 145 | 
            +
                title=TITLE,
         | 
| 146 | 
            +
                description=DESCRIPTION,
         | 
| 147 | 
            +
            )
         | 
| 148 | 
            +
             | 
| 149 | 
            +
            # Gradio 데모 실행
         | 
| 150 | 
            +
            demo.launch(debug=True)
         | 
    	
        laws.pdf
    ADDED
    
    | Binary file (836 kB). View file | 
|  | 
    	
        requirements.txt
    ADDED
    
    | @@ -0,0 +1,9 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            spaces
         | 
| 2 | 
            +
            torch==2.2.0
         | 
| 3 | 
            +
            transformers
         | 
| 4 | 
            +
            sentence-transformers
         | 
| 5 | 
            +
            faiss-gpu
         | 
| 6 | 
            +
            datasets
         | 
| 7 | 
            +
            accelerate
         | 
| 8 | 
            +
            bitsandbytes
         | 
| 9 | 
            +
            PyMuPDF
         |