lexsum / app.py
abeergandhi's picture
Update app.py
5366c36 verified
# app.py
# --- 1. IMPORTS ---
import os
import re
import io
import nltk
import torch
import numpy as np
import pdfplumber
import gradio as gr
import google.generativeai as genai
from transformers import AutoTokenizer, AutoModel
from sklearn.metrics.pairwise import cosine_similarity
from rouge_score import rouge_scorer
from nltk.tokenize import sent_tokenize
# Fetch the NLTK resources at import time (i.e. on every application start):
# the 'punkt_tab' sentence-tokenizer data and the 'stopwords' corpus.
for _nltk_resource in ('punkt_tab', 'stopwords'):
    nltk.download(_nltk_resource)
print(" NLTK data downloaded.")
# --- 2. MODEL LOADING & BACKEND CLASS ---
# --- IMPORTANT: LOAD API KEY FROM HF SECRETS ---
# The key is read from the GEMINI_API_KEY environment variable. Any failure
# (missing key or a configuration error) is reported on stdout and recorded in
# GEMINI_AVAILABLE rather than aborting start-up.
GEMINI_API_KEY = os.getenv('GEMINI_API_KEY')
try:
    if not GEMINI_API_KEY:
        raise ValueError("GEMINI_API_KEY secret not found!")
    genai.configure(api_key=GEMINI_API_KEY)
    print(" Gemini API configured.")
except Exception as exc:
    print(f" Could not configure Gemini: {exc}")
    GEMINI_AVAILABLE = False
else:
    GEMINI_AVAILABLE = True
# Load LegalBERT
LEGALBERT_MODEL_NAME = "nlpaueb/legal-bert-base-uncased"

# Bind both names before attempting the load so that a failure leaves them
# defined (as None) instead of missing; otherwise the first use later in the
# module fails with a NameError that hides the real cause.
legalbert_tokenizer = None
legalbert_model = None

print("\nLoading LegalBERT model... (This may take a minute on first startup)")
try:
    legalbert_tokenizer = AutoTokenizer.from_pretrained(LEGALBERT_MODEL_NAME)
    legalbert_model = AutoModel.from_pretrained(LEGALBERT_MODEL_NAME)
    print(" LegalBERT model loaded successfully.")
except Exception as e:
    # Reset both: the tokenizer may have loaded before the model failed, and a
    # half-initialised pair is not usable.
    legalbert_tokenizer = None
    legalbert_model = None
    print(f" Error loading LegalBERT: {e}")
# Vocabulary used to measure how much legal terminology a summary retains.
# The section comments are only a rough thematic guide; the order of entries
# is unchanged and carries no meaning for the code that consumes the list.
LEGAL_KEYWORDS = [
    # Parties and court actors
    "plaintiff", "defendant", "respondent", "petitioner", "appellant",
    "accused", "complainant", "counsel", "judge", "bench", "jury",
    "tribunal", "magistrate", "prosecution", "litigant", "witness",
    "party", "victim", "guardian",
    # Proceedings
    "appeal", "trial", "hearing", "proceedings", "petition", "lawsuit",
    "litigation", "jurisdiction", "adjudication", "motion", "plea",
    "summons",
    # Orders, charges and outcomes
    "order", "verdict", "judgment", "injunction", "stay", "ruling",
    "decree", "directive", "sentence", "conviction", "acquittal", "remand",
    "bail", "bond", "arrest", "charge", "claim", "complaint", "indictment",
    "retrial",
    # Evidence
    "cross-examination", "examination", "evidence", "testimony",
    "affidavit", "exhibit", "discovery", "burden", "standard",
    # Sources of law and reasoning
    "precedent", "case law", "statute", "legislation", "constitution",
    "clause", "provision", "contract", "agreement", "treaty", "bill",
    "code", "regulation", "enactment", "doctrine", "principle",
    "interpretation", "finding", "conclusion",
    # Remedies and liability
    "damages", "compensation", "fine", "penalty", "sanction", "relief",
    "settlement", "restitution", "injury", "liability", "negligence",
    "breach", "fault", "guilt", "rights", "obligation", "duty",
    "responsibility", "violation", "remedy",
    # Post-judgment
    "probation", "parole", "dismissal", "overrule", "uphold", "vacate",
    "enforcement",
]
class HybridLegalSummarizer:
    """Two-stage summarizer for legal text.

    Stage 1 ranks sentences with LegalBERT embeddings (against the document
    centroid, or against a user query) and keeps the best ones within a word
    budget. Stage 2 optionally asks a Gemini model to rewrite that draft.
    The class also computes ROUGE, embedding-similarity and keyword-coverage
    metrics for the final output.

    Relies on module-level names: ``legalbert_tokenizer``, ``legalbert_model``,
    ``LEGAL_KEYWORDS``, ``GEMINI_AVAILABLE`` and ``genai``.
    """

    # Removes non-word characters from both ends of a token ("verdict," -> "verdict")
    # while keeping interior ones, so "cross-examination" stays intact.
    _EDGE_PUNCT = re.compile(r'^\W+|\W+$')
    # Removes every punctuation character; kept so tokens that matched under
    # the older, more aggressive normalisation still match.
    _ALL_PUNCT = re.compile(r'[^\w\s]')

    def __init__(self):
        """Create the ROUGE scorer, the keyword set and (if available) the Gemini model."""
        self.rouge_scorer = rouge_scorer.RougeScorer(['rouge1', 'rouge2', 'rougeL'], use_stemmer=True)
        self.legal_keywords = {word.lower() for word in LEGAL_KEYWORDS}
        # Always define the attribute so later checks never hit AttributeError.
        self.refinement_model = (
            genai.GenerativeModel('models/gemini-2.5-flash') if GEMINI_AVAILABLE else None
        )

    def get_legalbert_embedding(self, text):
        """Return a LegalBERT embedding for ``text`` as a NumPy array.

        The input is truncated to 512 tokens, so only the beginning of a long
        text contributes. The embedding is the mean of ``last_hidden_state``
        over dim 1 (presumably the token axis), with singleton dims squeezed.
        """
        inputs = legalbert_tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=512)
        with torch.no_grad():
            outputs = legalbert_model(**inputs)
        return torch.mean(outputs.last_hidden_state, dim=1).squeeze().numpy()

    def preprocess_text(self, text):
        """Collapse whitespace and drop repeated sentences (case-insensitive).

        Returns the cleaned text with first occurrences kept in order, or
        ``None`` when no sentence remains (including empty/``None`` input).
        """
        if not text:
            return None
        text = re.sub(r'\s+', ' ', text)
        seen = set()
        unique_sentences = []
        for sentence in sent_tokenize(text):
            sentence = sentence.strip()
            if not sentence:
                continue
            key = sentence.lower()
            if key in seen:
                continue
            seen.add(key)
            unique_sentences.append(sentence)
        return ' '.join(unique_sentences) if unique_sentences else None

    def extract_legal_terms(self, text):
        """Return the set of legal keywords that occur in ``text``.

        The result contains the canonical keyword (lower-case, no surrounding
        punctuation), so the sets obtained from two different texts can be
        compared directly. Multi-word keywords such as "case law" are matched
        as phrases.
        """
        if not text:
            return set()
        lowered = text.lower()
        found = set()
        for token in lowered.split():
            trimmed = self._EDGE_PUNCT.sub('', token)
            if trimmed in self.legal_keywords:
                found.add(trimmed)
                continue
            squashed = self._ALL_PUNCT.sub('', token)
            if squashed in self.legal_keywords:
                found.add(squashed)
        # Whitespace tokenisation can never produce a keyword containing a
        # space, so phrases are searched for in the full text instead.
        for keyword in self.legal_keywords:
            if ' ' not in keyword:
                continue
            pattern = r'\b' + r'\s+'.join(re.escape(part) for part in keyword.split()) + r'\b'
            if re.search(pattern, lowered):
                found.add(keyword)
        return found

    def _embed_sentences(self, sentences):
        """Embed each sentence separately and stack the vectors into one array."""
        return np.array([self.get_legalbert_embedding(sent) for sent in sentences])

    @staticmethod
    def _select_within_budget(sentences, scores, max_words):
        """Greedily keep the highest-scoring sentences within 1.5 * ``max_words`` words.

        Sentences are visited from highest to lowest score; one that would
        exceed the budget is skipped and lower-ranked (possibly shorter) ones
        are still tried. The kept sentences are returned joined in their
        original document order.
        """
        target_draft_words = int(max_words * 1.5)
        chosen = []
        current_count = 0
        for i in np.argsort(scores)[::-1]:
            word_count = len(sentences[i].split())
            if current_count + word_count <= target_draft_words:
                chosen.append(i)
                current_count += word_count
        chosen.sort()
        return ' '.join(sentences[i] for i in chosen)

    def generate_extractive_draft(self, text, max_words=200):
        """Build a draft from the sentences closest to the document's mean embedding.

        Returns an empty string when ``text`` contains no sentences.
        """
        sentences = sent_tokenize(text)
        if not sentences:
            return ""
        sentence_embeddings = self._embed_sentences(sentences)
        centroid = np.mean(sentence_embeddings, axis=0)
        scores = cosine_similarity(sentence_embeddings, centroid.reshape(1, -1)).ravel()
        return self._select_within_budget(sentences, scores, max_words)

    def generate_rag_draft(self, text, user_query, max_words=600):
        """Build a context draft from the sentences most similar to ``user_query``.

        Returns an empty string when ``text`` contains no sentences.
        """
        sentences = sent_tokenize(text)
        if not sentences:
            return ""
        query_embedding = self.get_legalbert_embedding(user_query)
        sentence_embeddings = self._embed_sentences(sentences)
        scores = cosine_similarity(sentence_embeddings, query_embedding.reshape(1, -1)).ravel()
        return self._select_within_budget(sentences, scores, max_words)

    def refine_with_llm(self, draft_text, max_words, user_query=None):
        """Ask Gemini to answer ``user_query`` from the draft, or to condense the draft.

        Never raises: returns "Skipping refinement stage." when Gemini is not
        configured or the draft is empty, and a "Refinement failed. ..."
        message when the API call fails.
        """
        if not GEMINI_AVAILABLE or self.refinement_model is None or not draft_text:
            return "Skipping refinement stage."
        if user_query:
            prompt = f"""Based ONLY on the provided context, answer the user's question concisely. Question: "{user_query}" Context: --- {draft_text} --- Answer:"""
        else:
            prompt = f"""Rewrite and condense the following draft into a polished legal summary of approximately {max_words} words. Draft: --- {draft_text} --- Summary:"""
        try:
            response = self.refinement_model.generate_content(prompt)
            return response.text.strip()
        except Exception as e:
            return f"Refinement failed. Error: {e}"

    def calculate_all_scores(self, original_text, summary):
        """Score ``summary`` against ``original_text``.

        Returns ``(rouge_scores, consistency, coverage)``:
        * ``rouge_scores`` - dict with the F-measure for rouge1/rouge2/rougeL,
        * ``consistency`` - cosine similarity of the two LegalBERT embeddings
          (each text is truncated to 512 tokens before embedding),
        * ``coverage`` - percentage of the original's legal keywords that
          also appear in the summary.
        Returns ``({}, 0.0, 0.0)`` when either text is empty.
        """
        if not summary or not original_text:
            return {}, 0.0, 0.0
        rouge = self.rouge_scorer.score(original_text, summary)
        rouge_scores = {name: rouge[name].fmeasure for name in ('rouge1', 'rouge2', 'rougeL')}
        orig_emb = self.get_legalbert_embedding(original_text).reshape(1, -1)
        sum_emb = self.get_legalbert_embedding(summary).reshape(1, -1)
        consistency = float(cosine_similarity(orig_emb, sum_emb)[0][0])
        orig_kw = self.extract_legal_terms(original_text)
        sum_kw = self.extract_legal_terms(summary)
        coverage = (len(orig_kw & sum_kw) / len(orig_kw) * 100) if orig_kw else 0.0
        return rouge_scores, consistency, coverage
# Module-level instance, created once at import time; process_document calls
# its methods for every request.
summarizer = HybridLegalSummarizer()
print("\n Backend class is defined and ready.")
# --- 3. THE MAIN FUNCTION & GRADIO UI ---
def process_document(pdf_file, mode, word_limit, query):
    """Summarize an uploaded PDF, or answer a question about it, and report metrics.

    Args:
        pdf_file: The upload from the file input; either an object with a
            ``name`` attribute holding the path, or the path itself.
        mode: ``'Summarizer'`` or ``'RAG / Query'``.
        word_limit: Target summary length; only used in Summarizer mode
            (RAG mode uses a fixed value of 600).
        query: The user's question; required in RAG mode.

    Returns:
        ``(result_text, metrics_text)``. On any validation or PDF error the
        first element is a human-readable message and the second is ``""``.
    """
    if pdf_file is None:
        return "Please upload a PDF file.", ""

    # Validate the cheap inputs before the expensive PDF extraction.
    query = (query or "").strip()
    if mode not in ('Summarizer', 'RAG / Query'):
        return "Error: Invalid mode selected.", ""
    if mode == 'RAG / Query' and not query:
        return "Error: Please enter a query for RAG mode.", ""

    # Accept both a file-like upload object and a plain path string.
    pdf_path = getattr(pdf_file, "name", pdf_file)
    try:
        with pdfplumber.open(pdf_path) as pdf:
            # extract_text() is costly; call it once per page.
            page_texts = [page.extract_text() for page in pdf.pages]
        full_text = " ".join(page_text for page_text in page_texts if page_text)
        cleaned_text = summarizer.preprocess_text(full_text)
        if not cleaned_text:
            return "Error: Could not extract usable text from the PDF.", ""
    except Exception as e:
        return f"Error reading PDF: {e}", ""

    if mode == 'Summarizer':
        word_limit_internal = int(word_limit)
        draft = summarizer.generate_extractive_draft(cleaned_text, word_limit_internal)
        final_output = summarizer.refine_with_llm(draft, word_limit_internal)
    else:  # 'RAG / Query'
        word_limit_internal = 600
        draft = summarizer.generate_rag_draft(cleaned_text, query, word_limit_internal)
        final_output = summarizer.refine_with_llm(draft, word_limit_internal, user_query=query)

    final_rouge, final_consistency, final_coverage = summarizer.calculate_all_scores(cleaned_text, final_output)
    metrics_str = (
        f"ROUGE Scores: R1: {final_rouge.get('rouge1', 0):.3f}, R2: {final_rouge.get('rouge2', 0):.3f}, RL: {final_rouge.get('rougeL', 0):.3f}\n"
        f" Factual Consistency (Semantic Similarity): {final_consistency:.3f}\n"
        f" Legal Keyword Coverage: {final_coverage:.1f}%\n"
        f"Words in Output: {len(final_output.split())}"
    )
    return final_output, metrics_str
# Gradio front end: inputs on the left, result and metrics on the right.
with gr.Blocks() as demo:
    gr.Markdown("# ⚖️ Hybrid Legal Document Analyzer")
    gr.Markdown("Upload a legal PDF to either generate a summary or ask a specific question using RAG.")

    with gr.Row():
        with gr.Column(scale=1):
            pdf_input = gr.File(label="1. Upload PDF Document", file_types=[".pdf"])
            mode_input = gr.Radio(
                ["Summarizer", "RAG / Query"],
                label="2. Select Mode",
                value="Summarizer",
            )
            word_limit_input = gr.Slider(
                50,
                10000,
                value=250,
                step=50,
                label="3. Word Limit (for Summarizer mode)",
                interactive=True,
            )
            # The query box starts hidden because the default mode is Summarizer.
            query_input = gr.Textbox(
                label="4. Your Query (for RAG mode)",
                placeholder="e.g., What was the final verdict and why?",
                interactive=True,
                visible=False,
            )
            submit_btn = gr.Button("Generate", variant="primary")
        with gr.Column(scale=2):
            output_summary = gr.Textbox(label="Final Result", lines=15)
            output_metrics = gr.Textbox(label="Evaluation Metrics for Researchers", lines=5)

    def toggle_inputs(mode):
        """Show the word-limit slider in Summarizer mode and the query box otherwise."""
        summarizing = mode == "Summarizer"
        return {
            word_limit_input: gr.update(visible=summarizing),
            query_input: gr.update(visible=not summarizing),
        }

    mode_input.change(toggle_inputs, inputs=mode_input, outputs=[word_limit_input, query_input])
    submit_btn.click(
        fn=process_document,
        inputs=[pdf_input, mode_input, word_limit_input, query_input],
        outputs=[output_summary, output_metrics],
    )

demo.launch()