import os
import gradio as gr
import torch
from datasets import load_dataset
from huggingface_hub import InferenceClient
from transformers import AutoTokenizer, AutoModelForCausalLM
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
# === 1. Build the FAISS vectorstore from CUAD ===
print("Loading CUAD and building index...")

# CUAD is published as its own SQuAD-style dataset, not as a LexGLUE config,
# so load it directly. Each example carries the full contract in "context";
# deduplicate, since the same contract appears once per question.
cuad_data = load_dataset("cuad")
texts = list({item["context"] for item in cuad_data["train"]})

splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50)
docs = splitter.create_documents(texts)

embedding_model = HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")
vectorstore = FAISS.from_documents(docs, embedding_model)
# === 2. Model setup ===
USE_LLAMA = os.environ.get("USE_LLAMA", "false").lower() == "true"
HF_TOKEN = os.environ.get("HF_TOKEN")

def load_llama():
    # LLaMA-2 is a gated repo: pass the token explicitly (use_auth_token is deprecated).
    tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf", token=HF_TOKEN)
    model = AutoModelForCausalLM.from_pretrained(
        "meta-llama/Llama-2-7b-chat-hf",
        device_map="auto",
        torch_dtype=torch.float16,
        token=HF_TOKEN,
    )
    return tokenizer, model
def generate_llama_response(prompt):
    # Move inputs to wherever device_map placed the model instead of assuming CUDA.
    inputs = llama_tokenizer(prompt, return_tensors="pt").to(llama_model.device)
    outputs = llama_model.generate(**inputs, max_new_tokens=300)
    return llama_tokenizer.decode(outputs[0], skip_special_tokens=True)

def generate_mistral_response(prompt):
    return mistral_client.text_generation(prompt=prompt, max_new_tokens=300).strip()

if USE_LLAMA:
    llama_tokenizer, llama_model = load_llama()
    generate_response = generate_llama_response
else:
    mistral_client = InferenceClient(
        model="mistralai/Mistral-7B-Instruct-v0.1",
        token=HF_TOKEN,
    )
    generate_response = generate_mistral_response
# === 3. Main QA function ===
def answer_question(user_query):
    docs = vectorstore.similarity_search(user_query, k=3)
    context = "\n".join([doc.page_content for doc in docs])
    prompt = f"""[Context]
{context}
[User Question]
{user_query}
[Answer]
"""
    return generate_response(prompt)
# === 4. Gradio UI ===
iface = gr.Interface(
    fn=answer_question,
    inputs=gr.Textbox(placeholder="Ask a question about your contract..."),
    outputs=gr.Textbox(label="Answer"),
    title="LawCounsel AI",
    description="Ask clause-specific questions from CUAD-trained contracts. Powered by RAG using Mistral or LLaMA.",
)

iface.launch()
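
For the Space to build, the packages implied by the imports above also need to be declared. A minimal requirements.txt sketch (package names inferred from the imports; versions unpinned, faiss-cpu and accelerate assumed for a CPU Space running the default Mistral path):

gradio
torch
transformers
datasets
huggingface_hub
langchain
langchain-community
sentence-transformers
faiss-cpu
accelerate

HF_TOKEN (and optionally USE_LLAMA=true) should be set as Space secrets/variables so the script can read them from the environment at startup.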