sayanAIAI committed
Commit bb6c458 · verified · 1 Parent(s): 743f7ef

Update main.py

Files changed (1)
    main.py +170 -125
main.py CHANGED
@@ -6,47 +6,44 @@ import json
 import re
 import logging
 from collections import Counter

 from flask import Flask, request, jsonify, render_template
 import torch
-from transformers import (
-    AutoTokenizer,
-    AutoModelForSeq2SeqLM,
-    pipeline
-)

 # -------------------------
-# Basic app + logging
 # -------------------------
 app = Flask(__name__)
 logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger("summarizer")

 # -------------------------
-# Device selection (GPU if available)
 # -------------------------
 USE_GPU = torch.cuda.is_available()
 DEVICE = 0 if USE_GPU else -1
 logger.info("CUDA available: %s. Using device: %s", USE_GPU, DEVICE)

 # -------------------------
-# Models (quality-first)
 # -------------------------
-# Primary summarizer (higher-quality model)
-SUMMARIZER_MODEL = "facebook/bart-large-cnn"  # quality-focused
-summ_tokenizer = AutoTokenizer.from_pretrained(SUMMARIZER_MODEL)
-summ_model = AutoModelForSeq2SeqLM.from_pretrained(SUMMARIZER_MODEL)
-summarizer = pipeline("summarization", model=summ_model, tokenizer=summ_tokenizer, device=DEVICE)
-
-# Parameter-generator (small instruction model to "think" and choose settings)
-# We keep this compact but capable. If you later want stronger reasoning, swap to flan-t5-base.
-PARAM_MODEL = "google/flan-t5-small"
 param_tokenizer = AutoTokenizer.from_pretrained(PARAM_MODEL)
 param_model = AutoModelForSeq2SeqLM.from_pretrained(PARAM_MODEL)
 param_generator = pipeline("text2text-generation", model=param_model, tokenizer=param_tokenizer, device=DEVICE)

 # -------------------------
-# Presets & utilities
 # -------------------------
 LENGTH_PRESETS = {
     "short": {"min_length": 20, "max_length": 60},
@@ -54,7 +51,6 @@ LENGTH_PRESETS = {
     "long": {"min_length": 130, "max_length": 300},
 }

-# Simple sentence splitter and extractive prefilter (helps focus abstractive model)
 _STOPWORDS = {
     "the","and","is","in","to","of","a","that","it","on","for","as","are","with","was","be","by","this","an","or","from","at","which","we","has","have"
 }
@@ -64,22 +60,18 @@ def tokenize_sentences(text):
     return [s.strip() for s in sents if s.strip()]

 def extractive_prefilter(text, top_k=12):
-    """
-    Rank sentences by (non-stopword) word-frequency and return top_k sentences
-    in original order joined. Useful for very long inputs.
-    """
     sents = tokenize_sentences(text)
     if len(sents) <= top_k:
         return text
     words = re.findall(r"\w+", text.lower())
     freqs = Counter(w for w in words if w not in _STOPWORDS)
-    scores = []
     for i, s in enumerate(sents):
         ws = re.findall(r"\w+", s.lower())
         score = sum(freqs.get(w, 0) for w in ws)
-        scores.append((score, i, s))
-    scores.sort(reverse=True)
-    chosen = [s for _, _, s in sorted(scores[:top_k], key=lambda t: t[1])]
     return " ".join(chosen)

 def chunk_text_by_chars(text, max_chars=1500, overlap=200):
@@ -95,16 +87,47 @@ def chunk_text_by_chars(text, max_chars=1500, overlap=200):
95
  end = start + nl
96
  chunk = text[start:end]
97
  parts.append(chunk.strip())
98
- start = max(end - overlap, end) # move forward with overlap
99
  return parts
100
 
101
- def apply_tone_instruction(text, tone, target_sentences=None):
 
 
 
 
 
 
 
102
  """
103
- Build a clear instruction prompt for the summarizer based on tone/length.
 
104
  """
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
105
  tone = (tone or "neutral").lower()
106
  if tone == "bullet":
107
- instr = "Produce concise bullet points. Each bullet should be short (<=20 words) and focused. No extra commentary."
108
  elif tone == "short":
109
  ts = target_sentences or 1
110
  instr = f"Summarize the content in {ts} sentence{'s' if ts>1 else ''}. Be highly abstractive and avoid copying sentences verbatim."
@@ -113,47 +136,32 @@ def apply_tone_instruction(text, tone, target_sentences=None):
     elif tone == "casual":
         instr = "Summarize in a casual, conversational tone in 1-3 sentences. Use plain, friendly language."
     elif tone == "long":
-        instr = "Provide a clear, structured summary in 4-8 sentences, covering key points and relevant context."
     else:
         instr = "Summarize the content in 2-3 sentences. Be clear and concise."
-
-    instr += " Do not repeat the same information. Prefer rephrasing over copying."
-
     return f"{instr}\n\nText:\n{text}"

-# helper: extract first integer
-def _first_int_from_text(s, fallback=None):
-    m = re.search(r"\d{1,4}", s)
-    return int(m.group()) if m else fallback
-
-# -------------------------
-# Parameter generator (AI "thinking" module)
-# -------------------------
 def generate_summarization_config(text):
     """
-    Use the instruction model to recommend: length (short|medium|long), min_words, max_words, tone.
-    Falls back to heuristics on failure.
     """
     prompt = (
-        "You are an assistant that recommends optimal summarization settings.\n"
         "Given the text, respond ONLY with single-line JSON EXACTLY like:\n"
         '{"length":"short|medium|long","min_words":MIN,"max_words":MAX,"tone":"neutral|formal|casual|bullet"}\n\n'
         "Text:\n'''"
-        + text[:4000] +
-        "'''"
     )
-
     try:
-        # keep generation short and deterministic; use max_new_tokens (avoid max_length)
-        gen = param_generator(
             prompt,
             max_new_tokens=64,
             num_beams=1,
             do_sample=False,
             early_stopping=True
-        )
-        out = gen[0].get("generated_text", "").strip()
-        # attempt JSON parse
         cfg = None
         try:
             cfg = json.loads(out)
@@ -163,56 +171,44 @@ def generate_summarization_config(text):
                 raw = j.group().replace("'", '"')
                 cfg = json.loads(raw)
         if not cfg:
-            raise ValueError("Param-generator output not parseable")
-
-        length = cfg.get("length", "").lower()
-        tone = cfg.get("tone", "").lower()
         min_w = cfg.get("min_words")
         max_w = cfg.get("max_words")
-
-        if length not in ("short", "medium", "long"):
             words = len(text.split())
             length = "short" if words < 150 else ("medium" if words < 800 else "long")
-        if tone not in ("neutral", "formal", "casual", "bullet"):
             tone = "neutral"
-
-        if not isinstance(min_w, int):
             min_w = _first_int_from_text(out, fallback=None)
-        if not isinstance(max_w, int):
             max_w = _first_int_from_text(out[::-1], fallback=None)
-
-        defaults = {"short": (15, 50), "medium": (50, 130), "long": (130, 300)}
-        dmin, dmax = defaults.get(length, (50, 130))
-        min_len = int(min_w) if isinstance(min_w, int) else dmin
-        max_len = int(max_w) if isinstance(max_w, int) else dmax
-
         min_len = max(5, min(min_len, 2000))
-        max_len = max(min_len + 5, min(max_len, 4000))
-
-        logger.info("Param-generator chose: length=%s tone=%s min=%s max=%s", length, tone, min_len, max_len)
-        return {"length": length, "min_length": min_len, "max_length": max_len, "tone": tone}
     except Exception as e:
-        logger.exception("Param-generator failed; falling back to heuristic: %s", str(e))
         words = len(text.split())
         length = "short" if words < 150 else ("medium" if words < 800 else "long")
-        fallback = {"short": (15, 50), "medium": (50, 130), "long": (130, 300)}
-        mn, mx = fallback[length]
-        return {"length": length, "min_length": mn, "max_length": mx, "tone": "neutral"}

 # -------------------------
-# Two-stage summarization helpers
 # -------------------------
-def refine_and_combine(summaries_list, tone, final_target_sentences=None):
-    """
-    Combine chunk summaries and perform a refinement pass to produce a cohesive final summary.
-    """
     combined = "\n\n".join(summaries_list)
     if len(combined.split()) > 2000:
         combined = extractive_prefilter(combined, top_k=20)
-
     prompt = apply_tone_instruction(combined, tone, target_sentences=final_target_sentences)
-
-    # heuristics for min/max
     tgt_sent = final_target_sentences or 3
     gen_kwargs = {
         "min_length": max(20, int(tgt_sent * 8)),
@@ -222,87 +218,135 @@ def refine_and_combine(summaries_list, tone, final_target_sentences=None):
         "no_repeat_ngram_size": 3,
         "do_sample": False,
     }
-
     try:
-        out = summarizer(prompt, **gen_kwargs)[0]["summary_text"].strip()
         return out
     except Exception as e:
-        logger.exception("Refine step failed: %s", e)
         return " ".join(summaries_list[:3])

 # -------------------------
 # Routes
 # -------------------------
 @app.route("/")
 def home():
-    # Ensure you have templates/index.html in place
     return render_template("index.html")

 @app.route("/summarize", methods=["POST"])
 def summarize_route():
     t0 = time.time()
-    data = request.get_json(force=True)
-    text = (data.get("text") or "")[:60000]  # cap input to a reasonable size
-    requested_length = (data.get("length") or "medium").lower()
-    requested_tone = (data.get("tone") or "neutral").lower()

     if not text or len(text.split()) < 5:
-        return jsonify({"error": "Input too short."}), 400

     # 1) Decide settings (AI or explicit)
-    if requested_length in ("auto", "ai") or requested_tone in ("auto", "ai"):
         cfg = generate_summarization_config(text)
-        length_choice = cfg.get("length", "medium")
-        tone_choice = cfg.get("tone", "neutral")
         preset_min = cfg.get("min_length")
         preset_max = cfg.get("max_length")
     else:
         length_choice = requested_length if requested_length in ("short","medium","long") else "medium"
-        tone_choice = requested_tone if requested_tone in ("neutral","formal","casual","bullet","short","long") else "neutral"
         preset_min = LENGTH_PRESETS.get(length_choice, LENGTH_PRESETS["medium"])["min_length"]
         preset_max = LENGTH_PRESETS.get(length_choice, LENGTH_PRESETS["medium"])["max_length"]

-    # Map chosen length to target final sentences
-    sentence_map = {"short": 1, "medium": 3, "long": 6}
-    final_target_sentences = sentence_map.get(length_choice, 3)
-
-    # 2) Prefilter if extremely long
     words_len = len(text.split())
-    if words_len > 3500:
         text_for_chunks = extractive_prefilter(text, top_k=40)
     else:
         text_for_chunks = text

-    # 3) Chunking
-    chunks = chunk_text_by_chars(text_for_chunks, max_chars=1400, overlap=200)
-    chunk_summaries = []

-    # 4) Summarize each chunk
     for chunk in chunks:
         chunk_target = 1 if length_choice == "short" else 2
         chunk_tone = tone_choice if tone_choice in ("formal","casual","bullet") else "neutral"
         prompt = apply_tone_instruction(chunk, chunk_tone, target_sentences=chunk_target)
-
-        gen_kwargs = {
-            "min_length": 12 if chunk_target == 1 else 24,
-            "max_length": 60 if chunk_target == 1 else 120,
-            "num_beams": 5,
-            "early_stopping": True,
-            "no_repeat_ngram_size": 3,
-            "do_sample": False,
-        }
-
         try:
-            out = summarizer(prompt, **gen_kwargs)[0]["summary_text"].strip()
         except Exception as e:
             logger.exception("Chunk summarization failed, using extractive fallback: %s", e)
             out = extractive_prefilter(chunk, top_k=3)
         chunk_summaries.append(out)

-    # 5) Combine & refine
-    final = refine_and_combine(chunk_summaries, tone_choice, final_target_sentences=final_target_sentences)

-    # 6) Post-process for bullet tone
     if tone_choice == "bullet":
         parts = re.split(r'[\n\r]+|(?:\.\s+)|(?:;\s+)', final)
         bullets = [f"- {p.strip().rstrip('.')}" for p in parts if p.strip()]
@@ -312,17 +356,18 @@ def summarize_route():
     meta = {
         "length_choice": length_choice,
         "tone": tone_choice,
         "chunks": len(chunks),
         "input_words": words_len,
         "time_seconds": round(elapsed, 2),
         "device": ("gpu" if USE_GPU else "cpu")
     }
-
     return jsonify({"summary": final, "meta": meta})

 # -------------------------
 # Run
 # -------------------------
 if __name__ == "__main__":
-    # In production use Gunicorn; debug True here only for local testing
     app.run(host="0.0.0.0", port=7860, debug=False)

 import re
 import logging
 from collections import Counter
+from typing import Optional

 from flask import Flask, request, jsonify, render_template
 import torch
+from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline

 # -------------------------
+# App + logging
 # -------------------------
 app = Flask(__name__)
 logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger("summarizer")

 # -------------------------
+# Device selection
 # -------------------------
 USE_GPU = torch.cuda.is_available()
 DEVICE = 0 if USE_GPU else -1
 logger.info("CUDA available: %s. Using device: %s", USE_GPU, DEVICE)

 # -------------------------
+# Model names (summarizers are loaded lazily)
 # -------------------------
+PEGASUS_MODEL = "google/pegasus-large"
+LED_MODEL = "allenai/led-large-16384"
+PARAM_MODEL = "google/flan-t5-small"  # instruction model for parameter generation
+
+# cache for lazy-loaded pipelines
+_SUMMARIZER_CACHE = {}
+
+# load the small param-generator right away (it is cheap enough to keep resident)
+logger.info("Loading parameter generator model: %s", PARAM_MODEL)
 param_tokenizer = AutoTokenizer.from_pretrained(PARAM_MODEL)
 param_model = AutoModelForSeq2SeqLM.from_pretrained(PARAM_MODEL)
 param_generator = pipeline("text2text-generation", model=param_model, tokenizer=param_tokenizer, device=DEVICE)

 # -------------------------
+# Presets & utils
 # -------------------------
 LENGTH_PRESETS = {
     "short": {"min_length": 20, "max_length": 60},
 
     "long": {"min_length": 130, "max_length": 300},
 }

 _STOPWORDS = {
     "the","and","is","in","to","of","a","that","it","on","for","as","are","with","was","be","by","this","an","or","from","at","which","we","has","have"
 }

     return [s.strip() for s in sents if s.strip()]

 def extractive_prefilter(text, top_k=12):
     sents = tokenize_sentences(text)
     if len(sents) <= top_k:
         return text
     words = re.findall(r"\w+", text.lower())
     freqs = Counter(w for w in words if w not in _STOPWORDS)
+    scored = []
     for i, s in enumerate(sents):
         ws = re.findall(r"\w+", s.lower())
         score = sum(freqs.get(w, 0) for w in ws)
+        scored.append((score, i, s))
+    scored.sort(reverse=True)
+    chosen = [s for _, _, s in sorted(scored[:top_k], key=lambda t: t[1])]
     return " ".join(chosen)

 def chunk_text_by_chars(text, max_chars=1500, overlap=200):
         end = start + nl
         chunk = text[start:end]
         parts.append(chunk.strip())
+        start = max(end - overlap, start + 1)  # step back by `overlap` chars, but always advance
     return parts
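
To illustrate the intended overlap (values here are made up; exact chunk boundaries also depend on the newline-snapping logic elided from this hunk):

    parts = chunk_text_by_chars("x" * 3000, max_chars=1500, overlap=200)
    # the first chunk covers roughly [0, 1500); the next starts near 1300,
    # so consecutive chunks share about 200 characters of context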

+def _first_int_from_text(s, fallback=None):
+    m = re.search(r"\d{1,4}", s)
+    return int(m.group()) if m else fallback
+
+# -------------------------
+# Lazy summarizer loader
+# -------------------------
+def get_summarizer(model_key: str):
     """
+    Returns a pipeline summarizer for 'pegasus' or 'led', loading it lazily.
+    model_key: "pegasus" or "led"
     """
+    model_key = model_key.lower()
+    if model_key in _SUMMARIZER_CACHE:
+        return _SUMMARIZER_CACHE[model_key]
+
+    if model_key == "pegasus":
+        model_name = PEGASUS_MODEL
+    elif model_key == "led":
+        model_name = LED_MODEL
+    else:
+        raise ValueError("Unknown model_key: " + str(model_key))
+
+    logger.info("Loading summarizer model '%s' (%s) on device %s ...", model_key, model_name, DEVICE)
+    tok = AutoTokenizer.from_pretrained(model_name)
+    model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
+    pipe = pipeline("summarization", model=model, tokenizer=tok, device=DEVICE)
+    _SUMMARIZER_CACHE[model_key] = pipe
+    logger.info("Loaded summarizer '%s' successfully.", model_key)
+    return pipe
+
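A minimal sketch of the cache behavior, assuming the loader above (REPL-style calls, illustrative only):

    led = get_summarizer("led")        # first call builds the LED pipeline
    led_again = get_summarizer("led")  # second call is served from _SUMMARIZER_CACHE
    assert led is led_again            # same pipeline object, no reload
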
+# -------------------------
+# Prompt and decision logic
+# -------------------------
+def apply_tone_instruction(text, tone, target_sentences=None):
     tone = (tone or "neutral").lower()
     if tone == "bullet":
+        instr = "Produce concise bullet points. Each bullet short (<=20 words). No extra commentary."
     elif tone == "short":
         ts = target_sentences or 1
         instr = f"Summarize the content in {ts} sentence{'s' if ts>1 else ''}. Be highly abstractive and avoid copying sentences verbatim."
     elif tone == "casual":
         instr = "Summarize in a casual, conversational tone in 1-3 sentences. Use plain, friendly language."
     elif tone == "long":
+        instr = "Provide a clear, structured summary in 4-8 sentences covering key points and context."
     else:
         instr = "Summarize the content in 2-3 sentences. Be clear and concise."
+    instr += " Do not repeat information; prefer rephrasing."
     return f"{instr}\n\nText:\n{text}"
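
For reference, a hypothetical call and the prompt it builds under the rules above:

    prompt = apply_tone_instruction("Quarterly revenue rose 12%.", "bullet")
    # -> "Produce concise bullet points. Each bullet short (<=20 words). No extra commentary."
    #    " Do not repeat information; prefer rephrasing.\n\nText:\nQuarterly revenue rose 12%."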

 def generate_summarization_config(text):
     """
+    Ask the small instruction model for settings; fall back to a heuristic.
+    Returns a dict with keys: length, min_length, max_length, tone.
     """
     prompt = (
+        "You are an assistant that recommends summarization settings.\n"
         "Given the text, respond ONLY with single-line JSON EXACTLY like:\n"
         '{"length":"short|medium|long","min_words":MIN,"max_words":MAX,"tone":"neutral|formal|casual|bullet"}\n\n'
         "Text:\n'''"
+        + text[:4000] + "'''"
     )
     try:
+        out = param_generator(
             prompt,
             max_new_tokens=64,
             num_beams=1,
             do_sample=False,
             early_stopping=True
+        )[0].get("generated_text", "").strip()
         cfg = None
         try:
             cfg = json.loads(out)

                 raw = j.group().replace("'", '"')
                 cfg = json.loads(raw)
         if not cfg:
+            raise ValueError("Unparseable param-generator output")
+        length = cfg.get("length", "").lower()
+        tone = cfg.get("tone", "").lower()
         min_w = cfg.get("min_words")
         max_w = cfg.get("max_words")
+        if length not in ("short", "medium", "long"):
             words = len(text.split())
             length = "short" if words < 150 else ("medium" if words < 800 else "long")
+        if tone not in ("neutral", "formal", "casual", "bullet"):
             tone = "neutral"
+        if not isinstance(min_w, int):
             min_w = _first_int_from_text(out, fallback=None)
+        if not isinstance(max_w, int):
             max_w = _first_int_from_text(out[::-1], fallback=None)
+        defaults = {"short": (15, 50), "medium": (50, 130), "long": (130, 300)}
+        dmin, dmax = defaults.get(length, (50, 130))
+        min_len = int(min_w) if isinstance(min_w, int) else dmin
+        max_len = int(max_w) if isinstance(max_w, int) else dmax
         min_len = max(5, min(min_len, 2000))
+        max_len = max(min_len + 5, min(max_len, 4000))
+        logger.info("Param-generator chose length=%s tone=%s min=%s max=%s", length, tone, min_len, max_len)
+        return {"length": length, "min_length": min_len, "max_length": max_len, "tone": tone}
     except Exception as e:
+        logger.exception("Param-generator failed: %s", e)
         words = len(text.split())
         length = "short" if words < 150 else ("medium" if words < 800 else "long")
+        fallback = {"short": (15, 50), "medium": (50, 130), "long": (130, 300)}
+        mn, mx = fallback[length]
+        return {"length": length, "min_length": mn, "max_length": mx, "tone": "neutral"}
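
For reference, a well-behaved response from the param model is the single JSON line the prompt demands; the values below are illustrative, showing what the function would return after validation and clamping:

    out = '{"length":"medium","min_words":50,"max_words":120,"tone":"neutral"}'
    cfg = json.loads(out)
    # -> {"length": "medium", "min_length": 50, "max_length": 120, "tone": "neutral"}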
 
 # -------------------------
+# Two-stage summarization (chunk -> chunk summaries -> refine)
 # -------------------------
+def refine_and_combine(summaries_list, tone, final_target_sentences=None, summarizer_pipe=None):
     combined = "\n\n".join(summaries_list)
     if len(combined.split()) > 2000:
         combined = extractive_prefilter(combined, top_k=20)
     prompt = apply_tone_instruction(combined, tone, target_sentences=final_target_sentences)
     tgt_sent = final_target_sentences or 3
     gen_kwargs = {
         "min_length": max(20, int(tgt_sent * 8)),

         "no_repeat_ngram_size": 3,
         "do_sample": False,
     }
     try:
+        if summarizer_pipe is None:
+            # fall back to Pegasus by default (if no pipe was provided)
+            summarizer_pipe = get_summarizer("pegasus")
+        out = summarizer_pipe(prompt, **gen_kwargs)[0]["summary_text"].strip()
         return out
     except Exception as e:
+        logger.exception("Refine failed: %s", e)
         return " ".join(summaries_list[:3])
 
+# -------------------------
+# Model-specific generation helper
+# -------------------------
+def summarize_with_model(pipe, text_prompt, short_target=False):
+    """
+    Use the model pipeline with conservative, model-appropriate generation settings.
+    short_target: if True, use shorter min/max suitable for concise outputs.
+    """
+    # heuristic: detect LED via the checkpoint name on the pipeline's model config
+    model_name = getattr(pipe.model.config, "name_or_path", "") or ""
+    is_led = "led" in model_name or "longformer" in model_name
+    if short_target:
+        min_l = 12
+        max_l = 60
+    else:
+        min_l = 24
+        max_l = 140 if not is_led else 400  # LED can handle longer outputs
+    gen_kwargs = {
+        "min_length": min_l,
+        "max_length": max_l,
+        "num_beams": 5 if not is_led else 4,
+        "early_stopping": True,
+        "no_repeat_ngram_size": 3,
+        "do_sample": False,
+    }
+    return pipe(text_prompt, **gen_kwargs)[0]["summary_text"].strip()
+
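Illustrative usage of the helper with the two configured checkpoints (a sketch; `article` is a hypothetical input string):

    pegasus = get_summarizer("pegasus")
    out = summarize_with_model(pegasus, "Summarize the text.\n\nText:\n" + article)
    # pegasus: min_length=24, max_length=140, num_beams=5
    # led:     min_length=24, max_length=400, num_beams=4 (longer outputs allowed)
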
 # -------------------------
 # Routes
 # -------------------------
 @app.route("/")
 def home():
     return render_template("index.html")

 @app.route("/summarize", methods=["POST"])
 def summarize_route():
     t0 = time.time()
+    data = request.get_json(force=True) or {}
+    text = (data.get("text") or "")[:90000]
+    user_model_pref = (data.get("model") or "auto").lower()    # 'pegasus' | 'led' | 'auto'
+    requested_length = (data.get("length") or "auto").lower()  # short|medium|long|auto
+    requested_tone = (data.get("tone") or "auto").lower()      # neutral|formal|casual|bullet|auto

     if not text or len(text.split()) < 5:
+        return jsonify({"error": "Input too short."}), 400

     # 1) Decide settings (AI or explicit)
+    if requested_length in ("auto", "ai") or requested_tone in ("auto", "ai"):
         cfg = generate_summarization_config(text)
+        length_choice = cfg.get("length", "medium")
+        tone_choice = cfg.get("tone", "neutral")
         preset_min = cfg.get("min_length")
         preset_max = cfg.get("max_length")
     else:
         length_choice = requested_length if requested_length in ("short","medium","long") else "medium"
+        tone_choice = requested_tone if requested_tone in ("neutral","formal","casual","bullet") else "neutral"
         preset_min = LENGTH_PRESETS.get(length_choice, LENGTH_PRESETS["medium"])["min_length"]
         preset_max = LENGTH_PRESETS.get(length_choice, LENGTH_PRESETS["medium"])["max_length"]

+    # 2) Model selection (user preference or auto)
+    # auto rule: use LED if the user asked for it, the chosen length is long, or the input is very long
     words_len = len(text.split())
+    prefer_led = False
+    if user_model_pref == "led":
+        prefer_led = True
+    elif user_model_pref == "pegasus":
+        prefer_led = False
+    else:  # auto
+        if length_choice == "long" or words_len > 3000:
+            prefer_led = True
+        else:
+            prefer_led = False
+
+    model_key = "led" if prefer_led else "pegasus"
+    # get the pipeline (lazy load)
+    try:
+        summarizer_pipe = get_summarizer(model_key)
+    except Exception as e:
+        logger.exception("Failed to load summarizer '%s': %s", model_key, e)
+        # fall back to Pegasus if LED fails
+        summarizer_pipe = get_summarizer("pegasus")
+        model_key = "pegasus"
+
+    # 3) Prefilter very long inputs (if not using LED)
+    if not prefer_led and words_len > 2500:
         text_for_chunks = extractive_prefilter(text, top_k=40)
     else:
         text_for_chunks = text

+    # 4) Chunking: choose chunk size depending on the model
+    if model_key == "led":
+        chunk_max_chars = 8000  # LED can handle larger chunks
+        chunk_overlap = 400
+    else:
+        chunk_max_chars = 1400
+        chunk_overlap = 200
+    chunks = chunk_text_by_chars(text_for_chunks, max_chars=chunk_max_chars, overlap=chunk_overlap)

+    # 5) Summarize each chunk
+    chunk_summaries = []
     for chunk in chunks:
         chunk_target = 1 if length_choice == "short" else 2
         chunk_tone = tone_choice if tone_choice in ("formal","casual","bullet") else "neutral"
         prompt = apply_tone_instruction(chunk, chunk_tone, target_sentences=chunk_target)
         try:
+            # short_target=True for tiny chunk summaries
+            out = summarize_with_model(summarizer_pipe, prompt, short_target=(chunk_target == 1))
         except Exception as e:
             logger.exception("Chunk summarization failed, using extractive fallback: %s", e)
             out = extractive_prefilter(chunk, top_k=3)
         chunk_summaries.append(out)

+    # 6) Combine + refine using the same model for consistency
+    refine_model_key = model_key if model_key == "led" else "pegasus"
+    refine_pipe = get_summarizer(refine_model_key)
+    final_target_sentences = {"short": 1, "medium": 3, "long": 6}.get(length_choice, 3)
+    final = refine_and_combine(chunk_summaries, tone_choice, final_target_sentences, summarizer_pipe=refine_pipe)

+    # 7) Post-process bullet tone
     if tone_choice == "bullet":
         parts = re.split(r'[\n\r]+|(?:\.\s+)|(?:;\s+)', final)
         bullets = [f"- {p.strip().rstrip('.')}" for p in parts if p.strip()]

     meta = {
         "length_choice": length_choice,
         "tone": tone_choice,
+        "model_used": model_key,
+        "refine_model": refine_model_key,
         "chunks": len(chunks),
         "input_words": words_len,
         "time_seconds": round(elapsed, 2),
         "device": ("gpu" if USE_GPU else "cpu")
     }
     return jsonify({"summary": final, "meta": meta})

 # -------------------------
 # Run
 # -------------------------
 if __name__ == "__main__":
+    # debug=False for production; use Gunicorn in deployment
     app.run(host="0.0.0.0", port=7860, debug=False)
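
A client-side sketch against the running app (assumes the server is up on localhost:7860 as configured above; `long_article` is a hypothetical input string):

    import requests

    resp = requests.post(
        "http://localhost:7860/summarize",
        json={"text": long_article, "length": "auto", "tone": "auto", "model": "auto"},
    )
    payload = resp.json()
    print(payload["summary"])
    print(payload["meta"])  # includes model_used, refine_model, chunks, time_seconds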