abeergandhi committed on
Commit
afc1b64
·
verified ·
1 Parent(s): 194d7c3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +50 -34
app.py CHANGED
@@ -16,7 +16,7 @@ from rouge_score import rouge_scorer
16
  from nltk.tokenize import sent_tokenize
17
 
18
  # Download NLTK data during the build process
19
- nltk.download('punkt_tab')
20
  nltk.download('stopwords')
21
  print(" NLTK data downloaded.")
22
 
@@ -70,11 +70,32 @@ class HybridLegalSummarizer:
70
  self.refinement_model = genai.GenerativeModel('models/gemini-2.5-flash')
71
 
72
  def get_legalbert_embedding(self, text):
 
73
  inputs = legalbert_tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=512)
74
  with torch.no_grad():
75
  outputs = legalbert_model(**inputs)
76
  return torch.mean(outputs.last_hidden_state, dim=1).squeeze().numpy()
77
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
78
  def preprocess_text(self, text):
79
  text = re.sub(r'\s+', ' ', text)
80
  sentences = [s.strip() for s in sent_tokenize(text) if s.strip()]
@@ -89,12 +110,18 @@ class HybridLegalSummarizer:
89
  def generate_extractive_draft(self, text, max_words=200):
90
  sentences = sent_tokenize(text)
91
  if not sentences: return ""
92
- # USE THE BATCH FUNCTION
 
 
93
  sentence_embeddings = self.get_legalbert_embeddings_batch(sentences)
94
  if sentence_embeddings.size == 0: return ""
95
 
96
  centroid = np.mean(sentence_embeddings, axis=0)
97
- scores = [cosine_similarity(emb.reshape(1, -1), centroid.reshape(1, -1))[0][0] for emb in sentence_embeddings]
 
 
 
 
98
  ranked_indices = np.argsort(scores)[::-1]
99
  selected = []
100
  current_count = 0
@@ -111,12 +138,17 @@ class HybridLegalSummarizer:
111
  def generate_rag_draft(self, text, user_query, max_words=600):
112
  sentences = sent_tokenize(text)
113
  if not sentences: return ""
114
- # USE THE BATCH FUNCTION FOR BOTH
 
 
115
  query_embedding = self.get_legalbert_embeddings_batch([user_query])[0]
116
  sentence_embeddings = self.get_legalbert_embeddings_batch(sentences)
117
  if sentence_embeddings.size == 0: return ""
118
-
119
- scores = [cosine_similarity(emb.reshape(1, -1), query_embedding.reshape(1, -1))[0][0] for emb in sentence_embeddings]
 
 
 
120
  ranked_indices = np.argsort(scores)[::-1]
121
  selected = []
122
  current_count = 0
@@ -141,35 +173,16 @@ class HybridLegalSummarizer:
141
  return response.text.strip()
142
  except Exception as e:
143
  return f"Refinement failed. Error: {e}"
144
- # NEW BATCH FUNCTION TO ADD
145
- def get_legalbert_embeddings_batch(self, sentences_list):
146
- if not sentences_list:
147
- return np.array([])
148
-
149
- # 1. Tokenize the whole batch at once
150
- inputs = legalbert_tokenizer(
151
- sentences_list,
152
- return_tensors="pt",
153
- padding=True,
154
- truncation=True,
155
- max_length=512
156
- )
157
-
158
- # 2. Run the model ONCE on the whole batch
159
- with torch.no_grad():
160
- outputs = legalbert_model(**inputs)
161
-
162
- # 3. Use "mean pooling" to get the sentence embeddings
163
- # This averages the tokens for each sentence to get a single vector
164
- embeddings = torch.mean(outputs.last_hidden_state, dim=1)
165
- return embeddings.cpu().numpy()
166
-
167
  def calculate_all_scores(self, original_text, summary):
168
  if not summary or not original_text: return {}, 0.0, 0.0
169
  rouge = self.rouge_scorer.score(original_text, summary)
170
  rouge_scores = {"rouge1": rouge['rouge1'].fmeasure, "rouge2": rouge['rouge2'].fmeasure, "rougeL": rouge['rougeL'].fmeasure}
171
 
172
- # --- TEMPORARILY DISABLED FOR SPEED ---
 
 
 
173
  # orig_emb = self.get_legalbert_embedding(original_text).reshape(1, -1)
174
  # sum_emb = self.get_legalbert_embedding(summary).reshape(1, -1)
175
  # consistency = cosine_similarity(orig_emb, sum_emb)[0][0]
@@ -177,10 +190,10 @@ class HybridLegalSummarizer:
177
  # sum_kw = self.extract_legal_terms(summary)
178
  # coverage = (len(orig_kw.intersection(sum_kw)) / len(orig_kw) * 100) if orig_kw else 0
179
 
180
- # Just return dummy values for the disabled metrics
181
  consistency = 0.0
182
  coverage = 0.0
183
- # --- END TEMPORARY FIX ---
184
 
185
  return rouge_scores, consistency, coverage
186
 
@@ -214,11 +227,14 @@ def process_document(pdf_file, mode, word_limit, query):
214
  else:
215
  return "Error: Invalid mode selected.", ""
216
 
 
217
  final_rouge, final_consistency, final_coverage = summarizer.calculate_all_scores(cleaned_text, final_output)
 
 
218
  metrics_str = (
219
  f"ROUGE Scores: R1: {final_rouge.get('rouge1', 0):.3f}, R2: {final_rouge.get('rouge2', 0):.3f}, RL: {final_rouge.get('rougeL', 0):.3f}\n"
220
- f" Factual Consistency (Semantic Similarity): {final_consistency:.3f}\n"
221
- f" Legal Keyword Coverage: {final_coverage:.1f}%\n"
222
  f"Words in Output: {len(final_output.split())}"
223
  )
224
  return final_output, metrics_str
 
16
  from nltk.tokenize import sent_tokenize
17
 
18
  # Download NLTK data during the build process
19
+ nltk.download('punkt') # Changed 'punkt_tab' to 'punkt' for robustness
20
  nltk.download('stopwords')
21
  print(" NLTK data downloaded.")
22
 
 
70
  self.refinement_model = genai.GenerativeModel('models/gemini-2.5-flash')
71
 
72
  def get_legalbert_embedding(self, text):
73
+ # This function is now only used by the slow calculate_all_scores
74
  inputs = legalbert_tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=512)
75
  with torch.no_grad():
76
  outputs = legalbert_model(**inputs)
77
  return torch.mean(outputs.last_hidden_state, dim=1).squeeze().numpy()
78
 
79
+ # --- OPTIMIZATION 1: NEW BATCH FUNCTION ---
80
+ # This function processes a LIST of sentences at once
81
+ def get_legalbert_embeddings_batch(self, sentences_list):
82
+ if not sentences_list:
83
+ return np.array([])
84
+ # 1. Tokenize the whole batch
85
+ inputs = legalbert_tokenizer(
86
+ sentences_list,
87
+ return_tensors="pt",
88
+ padding=True,
89
+ truncation=True,
90
+ max_length=512
91
+ )
92
+ # 2. Run the model ONCE
93
+ with torch.no_grad():
94
+ outputs = legalbert_model(**inputs)
95
+ # 3. Get all embeddings
96
+ embeddings = torch.mean(outputs.last_hidden_state, dim=1)
97
+ return embeddings.cpu().numpy()
98
+
99
  def preprocess_text(self, text):
100
  text = re.sub(r'\s+', ' ', text)
101
  sentences = [s.strip() for s in sent_tokenize(text) if s.strip()]
 
110
  def generate_extractive_draft(self, text, max_words=200):
111
  sentences = sent_tokenize(text)
112
  if not sentences: return ""
113
+
114
+ # --- OPTIMIZATION 2: BATCHED & VECTORIZED ---
115
+ # 1. Get all embeddings at once (replaces a for loop)
116
  sentence_embeddings = self.get_legalbert_embeddings_batch(sentences)
117
  if sentence_embeddings.size == 0: return ""
118
 
119
  centroid = np.mean(sentence_embeddings, axis=0)
120
+
121
+ # 2. Get all scores at once (replaces another for loop)
122
+ scores = cosine_similarity(sentence_embeddings, centroid.reshape(1, -1)).flatten()
123
+ # --- END OPTIMIZATION ---
124
+
125
  ranked_indices = np.argsort(scores)[::-1]
126
  selected = []
127
  current_count = 0
 
138
  def generate_rag_draft(self, text, user_query, max_words=600):
139
  sentences = sent_tokenize(text)
140
  if not sentences: return ""
141
+
142
+ # --- OPTIMIZATION 2: BATCHED & VECTORIZED ---
143
+ # 1. Get query and sentence embeddings at once
144
  query_embedding = self.get_legalbert_embeddings_batch([user_query])[0]
145
  sentence_embeddings = self.get_legalbert_embeddings_batch(sentences)
146
  if sentence_embeddings.size == 0: return ""
147
+
148
+ # 2. Get all scores at once (replaces a for loop)
149
+ scores = cosine_similarity(sentence_embeddings, query_embedding.reshape(1, -1)).flatten()
150
+ # --- END OPTIMIZATION ---
151
+
152
  ranked_indices = np.argsort(scores)[::-1]
153
  selected = []
154
  current_count = 0
 
173
  return response.text.strip()
174
  except Exception as e:
175
  return f"Refinement failed. Error: {e}"
176
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
177
  def calculate_all_scores(self, original_text, summary):
178
  if not summary or not original_text: return {}, 0.0, 0.0
179
  rouge = self.rouge_scorer.score(original_text, summary)
180
  rouge_scores = {"rouge1": rouge['rouge1'].fmeasure, "rouge2": rouge['rouge2'].fmeasure, "rougeL": rouge['rougeL'].fmeasure}
181
 
182
+ # --- OPTIMIZATION 3: DISABLED SLOW METRICS ---
183
+ # The following lines are too slow for a live demo as they run
184
+ # the model on the *entire* text.
185
+
186
  # orig_emb = self.get_legalbert_embedding(original_text).reshape(1, -1)
187
  # sum_emb = self.get_legalbert_embedding(summary).reshape(1, -1)
188
  # consistency = cosine_similarity(orig_emb, sum_emb)[0][0]
 
190
  # sum_kw = self.extract_legal_terms(summary)
191
  # coverage = (len(orig_kw.intersection(sum_kw)) / len(orig_kw) * 100) if orig_kw else 0
192
 
193
+ # Return dummy values for a fast demo.
194
  consistency = 0.0
195
  coverage = 0.0
196
+ # --- END OPTIMIZATION ---
197
 
198
  return rouge_scores, consistency, coverage
199
 
 
227
  else:
228
  return "Error: Invalid mode selected.", ""
229
 
230
+ # This part is now fast because calculate_all_scores is fast
231
  final_rouge, final_consistency, final_coverage = summarizer.calculate_all_scores(cleaned_text, final_output)
232
+
233
+ # Updated metrics string to show which values are disabled
234
  metrics_str = (
235
  f"ROUGE Scores: R1: {final_rouge.get('rouge1', 0):.3f}, R2: {final_rouge.get('rouge2', 0):.3f}, RL: {final_rouge.get('rougeL', 0):.3f}\n"
236
+ f" Factual Consistency (Semantic Similarity): {final_consistency:.3f} (Disabled for demo speed)\n"
237
+ f" Legal Keyword Coverage: {final_coverage:.1f}% (Disabled for demo speed)\n"
238
  f"Words in Output: {len(final_output.split())}"
239
  )
240
  return final_output, metrics_str