abeergandhi committed on
Commit
5366c36
·
verified ·
1 Parent(s): 1195b25

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +14 -66
app.py CHANGED
@@ -16,7 +16,7 @@ from rouge_score import rouge_scorer
16
  from nltk.tokenize import sent_tokenize
17
 
18
  # Download NLTK data during the build process
19
- nltk.download('punkt_tab') # Changed 'punkt_tab' to 'punkt' for robustness
20
  nltk.download('stopwords')
21
  print(" NLTK data downloaded.")
22
 
@@ -70,32 +70,11 @@ class HybridLegalSummarizer:
70
  self.refinement_model = genai.GenerativeModel('models/gemini-2.5-flash')
71
 
72
  def get_legalbert_embedding(self, text):
73
- # This function is now only used by the slow calculate_all_scores
74
  inputs = legalbert_tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=512)
75
  with torch.no_grad():
76
  outputs = legalbert_model(**inputs)
77
  return torch.mean(outputs.last_hidden_state, dim=1).squeeze().numpy()
78
 
79
- # --- OPTIMIZATION 1: NEW BATCH FUNCTION ---
80
- # This function processes a LIST of sentences at once
81
- def get_legalbert_embeddings_batch(self, sentences_list):
82
- if not sentences_list:
83
- return np.array([])
84
- # 1. Tokenize the whole batch
85
- inputs = legalbert_tokenizer(
86
- sentences_list,
87
- return_tensors="pt",
88
- padding=True,
89
- truncation=True,
90
- max_length=512
91
- )
92
- # 2. Run the model ONCE
93
- with torch.no_grad():
94
- outputs = legalbert_model(**inputs)
95
- # 3. Get all embeddings
96
- embeddings = torch.mean(outputs.last_hidden_state, dim=1)
97
- return embeddings.cpu().numpy()
98
-
99
  def preprocess_text(self, text):
100
  text = re.sub(r'\s+', ' ', text)
101
  sentences = [s.strip() for s in sent_tokenize(text) if s.strip()]
@@ -110,18 +89,9 @@ class HybridLegalSummarizer:
110
  def generate_extractive_draft(self, text, max_words=200):
111
  sentences = sent_tokenize(text)
112
  if not sentences: return ""
113
-
114
- # --- OPTIMIZATION 2: BATCHED & VECTORIZED ---
115
- # 1. Get all embeddings at once (replaces a for loop)
116
- sentence_embeddings = self.get_legalbert_embeddings_batch(sentences)
117
- if sentence_embeddings.size == 0: return ""
118
-
119
  centroid = np.mean(sentence_embeddings, axis=0)
120
-
121
- # 2. Get all scores at once (replaces another for loop)
122
- scores = cosine_similarity(sentence_embeddings, centroid.reshape(1, -1)).flatten()
123
- # --- END OPTIMIZATION ---
124
-
125
  ranked_indices = np.argsort(scores)[::-1]
126
  selected = []
127
  current_count = 0
@@ -138,17 +108,9 @@ class HybridLegalSummarizer:
138
  def generate_rag_draft(self, text, user_query, max_words=600):
139
  sentences = sent_tokenize(text)
140
  if not sentences: return ""
141
-
142
- # --- OPTIMIZATION 2: BATCHED & VECTORIZED ---
143
- # 1. Get query and sentence embeddings at once
144
- query_embedding = self.get_legalbert_embeddings_batch([user_query])[0]
145
- sentence_embeddings = self.get_legalbert_embeddings_batch(sentences)
146
- if sentence_embeddings.size == 0: return ""
147
-
148
- # 2. Get all scores at once (replaces a for loop)
149
- scores = cosine_similarity(sentence_embeddings, query_embedding.reshape(1, -1)).flatten()
150
- # --- END OPTIMIZATION ---
151
-
152
  ranked_indices = np.argsort(scores)[::-1]
153
  selected = []
154
  current_count = 0
@@ -178,23 +140,12 @@ class HybridLegalSummarizer:
178
  if not summary or not original_text: return {}, 0.0, 0.0
179
  rouge = self.rouge_scorer.score(original_text, summary)
180
  rouge_scores = {"rouge1": rouge['rouge1'].fmeasure, "rouge2": rouge['rouge2'].fmeasure, "rougeL": rouge['rougeL'].fmeasure}
181
-
182
- # --- OPTIMIZATION 3: DISABLED SLOW METRICS ---
183
- # The following lines are too slow for a live demo as they run
184
- # the model on the *entire* text.
185
-
186
- # orig_emb = self.get_legalbert_embedding(original_text).reshape(1, -1)
187
- # sum_emb = self.get_legalbert_embedding(summary).reshape(1, -1)
188
- # consistency = cosine_similarity(orig_emb, sum_emb)[0][0]
189
- # orig_kw = self.extract_legal_terms(original_text)
190
- # sum_kw = self.extract_legal_terms(summary)
191
- # coverage = (len(orig_kw.intersection(sum_kw)) / len(orig_kw) * 100) if orig_kw else 0
192
-
193
- # Return dummy values for a fast demo.
194
- consistency = 0.0
195
- coverage = 0.0
196
- # --- END OPTIMIZATION ---
197
-
198
  return rouge_scores, consistency, coverage
199
 
200
  summarizer = HybridLegalSummarizer()
@@ -227,14 +178,11 @@ def process_document(pdf_file, mode, word_limit, query):
227
  else:
228
  return "Error: Invalid mode selected.", ""
229
 
230
- # This part is now fast because calculate_all_scores is fast
231
  final_rouge, final_consistency, final_coverage = summarizer.calculate_all_scores(cleaned_text, final_output)
232
-
233
- # Updated metrics string to show which values are disabled
234
  metrics_str = (
235
  f"ROUGE Scores: R1: {final_rouge.get('rouge1', 0):.3f}, R2: {final_rouge.get('rouge2', 0):.3f}, RL: {final_rouge.get('rougeL', 0):.3f}\n"
236
- f" Factual Consistency (Semantic Similarity): {final_consistency:.3f} (Disabled for demo speed)\n"
237
- f" Legal Keyword Coverage: {final_coverage:.1f}% (Disabled for demo speed)\n"
238
  f"Words in Output: {len(final_output.split())}"
239
  )
240
  return final_output, metrics_str
 
16
  from nltk.tokenize import sent_tokenize
17
 
18
# Fetch the NLTK corpora at build time so the app never downloads at request time.
nltk.download('punkt_tab')  # sentence-tokenizer tables required by sent_tokenize (NLTK >= 3.8.2)
nltk.download('stopwords')  # English stopword list used during preprocessing
print(" NLTK data downloaded.")
22
 
 
70
  self.refinement_model = genai.GenerativeModel('models/gemini-2.5-flash')
71
 
72
def get_legalbert_embedding(self, text):
    """Embed *text* with LegalBERT and mean-pool over tokens.

    Returns a 1-D numpy array (the hidden-state mean across the sequence
    dimension, squeezed). Input is truncated/padded to 512 tokens.
    Relies on the module-level ``legalbert_tokenizer`` / ``legalbert_model``.
    """
    encoded = legalbert_tokenizer(
        text,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=512,
    )
    # Inference only — no autograd graph needed.
    with torch.no_grad():
        model_output = legalbert_model(**encoded)
    pooled = model_output.last_hidden_state.mean(dim=1)
    return pooled.squeeze().numpy()
77
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
78
  def preprocess_text(self, text):
79
  text = re.sub(r'\s+', ' ', text)
80
  sentences = [s.strip() for s in sent_tokenize(text) if s.strip()]
 
89
  def generate_extractive_draft(self, text, max_words=200):
90
  sentences = sent_tokenize(text)
91
  if not sentences: return ""
92
+ sentence_embeddings = np.array([self.get_legalbert_embedding(sent) for sent in sentences])
 
 
 
 
 
93
  centroid = np.mean(sentence_embeddings, axis=0)
94
+ scores = [cosine_similarity(emb.reshape(1, -1), centroid.reshape(1, -1))[0][0] for emb in sentence_embeddings]
 
 
 
 
95
  ranked_indices = np.argsort(scores)[::-1]
96
  selected = []
97
  current_count = 0
 
108
  def generate_rag_draft(self, text, user_query, max_words=600):
109
  sentences = sent_tokenize(text)
110
  if not sentences: return ""
111
+ query_embedding = self.get_legalbert_embedding(user_query)
112
+ sentence_embeddings = np.array([self.get_legalbert_embedding(sent) for sent in sentences])
113
+ scores = [cosine_similarity(emb.reshape(1, -1), query_embedding.reshape(1, -1))[0][0] for emb in sentence_embeddings]
 
 
 
 
 
 
 
 
114
  ranked_indices = np.argsort(scores)[::-1]
115
  selected = []
116
  current_count = 0
 
140
  if not summary or not original_text: return {}, 0.0, 0.0
141
  rouge = self.rouge_scorer.score(original_text, summary)
142
  rouge_scores = {"rouge1": rouge['rouge1'].fmeasure, "rouge2": rouge['rouge2'].fmeasure, "rougeL": rouge['rougeL'].fmeasure}
143
+ orig_emb = self.get_legalbert_embedding(original_text).reshape(1, -1)
144
+ sum_emb = self.get_legalbert_embedding(summary).reshape(1, -1)
145
+ consistency = cosine_similarity(orig_emb, sum_emb)[0][0]
146
+ orig_kw = self.extract_legal_terms(original_text)
147
+ sum_kw = self.extract_legal_terms(summary)
148
+ coverage = (len(orig_kw.intersection(sum_kw)) / len(orig_kw) * 100) if orig_kw else 0
 
 
 
 
 
 
 
 
 
 
 
149
  return rouge_scores, consistency, coverage
150
 
151
  summarizer = HybridLegalSummarizer()
 
178
  else:
179
  return "Error: Invalid mode selected.", ""
180
 
 
181
  final_rouge, final_consistency, final_coverage = summarizer.calculate_all_scores(cleaned_text, final_output)
 
 
182
  metrics_str = (
183
  f"ROUGE Scores: R1: {final_rouge.get('rouge1', 0):.3f}, R2: {final_rouge.get('rouge2', 0):.3f}, RL: {final_rouge.get('rougeL', 0):.3f}\n"
184
+ f" Factual Consistency (Semantic Similarity): {final_consistency:.3f}\n"
185
+ f" Legal Keyword Coverage: {final_coverage:.1f}%\n"
186
  f"Words in Output: {len(final_output.split())}"
187
  )
188
  return final_output, metrics_str