sayanAIAI committed on
Commit 743f7ef · verified · 1 Parent(s): 08b9e9f

Update main.py

Files changed (1)
  1. main.py +213 -99
main.py CHANGED
@@ -1,32 +1,52 @@
  import os
  os.environ['HF_HOME'] = '/tmp'

- from flask import Flask, request, jsonify, render_template
- from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
- import json, re, time
  import logging
- # ... keep your existing imports

- logger = logging.getLogger(__name__)
  app = Flask(__name__)

  # -------------------------
- # Models (CPU as requested)
  # -------------------------
- # Primary summarizer: higher-quality model
- MODEL_NAME = "facebook/bart-large-cnn"
- tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
- model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME)
- summarizer = pipeline("summarization", model=model, tokenizer=tokenizer, device=-1)  # CPU

- # Small instruction model to choose length/tone when "auto" is requested
- PARAM_MODEL_NAME = "google/flan-t5-small"
- param_tokenizer = AutoTokenizer.from_pretrained(PARAM_MODEL_NAME)
- param_model = AutoModelForSeq2SeqLM.from_pretrained(PARAM_MODEL_NAME)
- param_generator = pipeline("text2text-generation", model=param_model, tokenizer=param_tokenizer, device=-1)  # CPU

  # -------------------------
- # Presets & helpers
  # -------------------------
  LENGTH_PRESETS = {
      "short": {"min_length": 20, "max_length": 60},
@@ -34,6 +54,34 @@ LENGTH_PRESETS = {
      "long": {"min_length": 130, "max_length": 300},
  }

  def chunk_text_by_chars(text, max_chars=1500, overlap=200):
      if len(text) <= max_chars:
          return [text]
@@ -47,168 +95,234 @@ def chunk_text_by_chars(text, max_chars=1500, overlap=200):
          end = start + nl
          chunk = text[start:end]
          parts.append(chunk.strip())
-         start = end - overlap
      return parts

- def apply_tone_instruction(text, tone):
      tone = (tone or "neutral").lower()
-     if tone == "formal":
-         instr = "Summarize in a formal, professional tone:"
      elif tone == "casual":
-         instr = "Summarize in a casual, conversational tone:"
-     elif tone == "bullet":
-         instr = "Summarize into short bullet points:"
      else:
-         instr = "Summarize:"
-     return f"{instr}\n\n{text}"

- # small regex int extractor
- def _first_int_from_text(s, fallback=None):
-     m = re.search(r"\d{1,5}", s)
-     return int(m.group()) if m else fallback

  def generate_summarization_config(text):
      """
-     Fast, robust parameter generator that prefers quick generation settings to avoid worker timeouts.
-     If the param model fails or is slow, we fall back to heuristics.
      """
-     # short prompt (keeps prompt length bounded)
      prompt = (
-         "You are a helpful assistant that recommends summarization settings.\n"
-         "Given the following source text, pick a summary LENGTH category (short/medium/long), "
-         "an estimated MIN and MAX length in words for the summary, and a TONE (neutral/formal/casual/bullet).\n"
-         "Respond ONLY in compact JSON (single line):\n"
          '{"length":"short|medium|long","min_words":MIN,"max_words":MAX,"tone":"neutral|formal|casual|bullet"}\n\n'
          "Text:\n'''"
-         + (text[:3000])  # limit prompt size so generation is fast
-         + "'''"
      )

      try:
-         # IMPORTANT: use max_new_tokens (not max_length), small beam or sampling off,
-         # and a small token limit to keep latency low on CPU.
          gen = param_generator(
              prompt,
-             max_new_tokens=64,  # keep short
-             num_beams=1,        # avoid expensive beam search
-             do_sample=False     # deterministic and typically faster for small models
          )
          out = gen[0].get("generated_text", "").strip()
-         # try to extract JSON substring
          cfg = None
          try:
              cfg = json.loads(out)
          except Exception:
-             jmatch = re.search(r"\{.*\}", out, re.DOTALL)
-             if jmatch:
-                 raw = jmatch.group().replace("'", '"')
                  cfg = json.loads(raw)
          if not cfg:
-             raise ValueError("Failed to parse param-generator output")

          length = cfg.get("length", "").lower()
          tone = cfg.get("tone", "").lower()
          min_w = cfg.get("min_words")
          max_w = cfg.get("max_words")

-         # normalize & fallback rules
          if length not in ("short", "medium", "long"):
              words = len(text.split())
              length = "short" if words < 150 else ("medium" if words < 800 else "long")
          if tone not in ("neutral", "formal", "casual", "bullet"):
              tone = "neutral"

          defaults = {"short": (15, 50), "medium": (50, 130), "long": (130, 300)}
-         dmin, dmax = defaults[length]
          min_len = int(min_w) if isinstance(min_w, int) else dmin
          max_len = int(max_w) if isinstance(max_w, int) else dmax

-         # clamp to sane bounds
          min_len = max(5, min(min_len, 2000))
          max_len = max(min_len + 5, min(max_len, 4000))

          return {"length": length, "min_length": min_len, "max_length": max_len, "tone": tone}
      except Exception as e:
-         # log the error and fall back to a quick heuristic
-         logger.exception("param-generator failed or timed out, falling back to heuristic: %s", str(e))
          words = len(text.split())
          length = "short" if words < 150 else ("medium" if words < 800 else "long")
          fallback = {"short": (15, 50), "medium": (50, 130), "long": (130, 300)}
          mn, mx = fallback[length]
          return {"length": length, "min_length": mn, "max_length": mx, "tone": "neutral"}

  # -------------------------
  # Routes
  # -------------------------
  @app.route("/")
  def home():
-     # expects templates/index.html to exist (your frontend)
      return render_template("index.html")

  @app.route("/summarize", methods=["POST"])
  def summarize_route():
-     start_time = time.time()
      data = request.get_json(force=True)
-     text = data.get("text", "")[:20000]  # cap input
      requested_length = (data.get("length") or "medium").lower()
      requested_tone = (data.get("tone") or "neutral").lower()

      if not text or len(text.split()) < 5:
          return jsonify({"error": "Input too short."}), 400

-     # If the user asks the AI to choose settings
      if requested_length in ("auto", "ai") or requested_tone in ("auto", "ai"):
          cfg = generate_summarization_config(text)
-         length = cfg.get("length", "medium")
-         tone = cfg.get("tone", "neutral")
          preset_min = cfg.get("min_length")
          preset_max = cfg.get("max_length")
-         preset = LENGTH_PRESETS.get(length, LENGTH_PRESETS["medium"])
      else:
-         length = requested_length if requested_length in LENGTH_PRESETS else "medium"
-         tone = requested_tone if requested_tone in ("neutral", "formal", "casual", "bullet") else "neutral"
-         preset = LENGTH_PRESETS.get(length, LENGTH_PRESETS["medium"])
-         preset_min = preset["min_length"]
-         preset_max = preset["max_length"]

-     # chunk input for long texts
-     chunks = chunk_text_by_chars(text, max_chars=1500, overlap=200)
-     summaries = []

-     for chunk in chunks:
-         prompted = apply_tone_instruction(chunk, tone)
-         min_l = int(preset_min) if preset_min is not None else preset["min_length"]
-         max_l = int(preset_max) if preset_max is not None else preset["max_length"]
-
-         out = summarizer(
-             prompted,
-             min_length=min_l,
-             max_length=max_l,
-             truncation=True
-         )[0]["summary_text"]
-         summaries.append(out.strip())
-
-     if len(summaries) == 1:
-         final = summaries[0]
      else:
-         combined = "\n\n".join(summaries)
-         prompted = apply_tone_instruction(combined, tone)
-         final = summarizer(
-             prompted,
-             min_length=preset["min_length"],
-             max_length=preset["max_length"],
-             truncation=True
-         )[0]["summary_text"]

-     if tone == "bullet":
-         lines = [l.strip() for s in final.splitlines() for l in s.split(". ") if l.strip()]
-         final = "\n".join(f"- {l.rstrip('.')}" for l in lines[:20])

-     elapsed = time.time() - start_time
-     return jsonify({"summary": final, "meta": {"length_choice": length, "tone": tone, "time_seconds": round(elapsed, 2)}})

  if __name__ == "__main__":
-     # keep debug off in production; using CPU as requested
-     app.run(host="0.0.0.0", port=7860, debug=True)
 
  import os
  os.environ['HF_HOME'] = '/tmp'

+ import time
+ import json
+ import re
  import logging
+ from collections import Counter
+
+ from flask import Flask, request, jsonify, render_template
+ import torch
+ from transformers import (
+     AutoTokenizer,
+     AutoModelForSeq2SeqLM,
+     pipeline
+ )

+ # -------------------------
+ # Basic app + logging
+ # -------------------------
  app = Flask(__name__)
+ logging.basicConfig(level=logging.INFO)
+ logger = logging.getLogger("summarizer")
+
+ # -------------------------
+ # Device selection (GPU if available)
+ # -------------------------
+ USE_GPU = torch.cuda.is_available()
+ DEVICE = 0 if USE_GPU else -1
+ logger.info("CUDA available: %s. Using device: %s", USE_GPU, DEVICE)

  # -------------------------
+ # Models (quality-first)
  # -------------------------
+ # Primary summarizer (higher-quality model)
+ SUMMARIZER_MODEL = "facebook/bart-large-cnn"  # quality-focused
+ summ_tokenizer = AutoTokenizer.from_pretrained(SUMMARIZER_MODEL)
+ summ_model = AutoModelForSeq2SeqLM.from_pretrained(SUMMARIZER_MODEL)
+ summarizer = pipeline("summarization", model=summ_model, tokenizer=summ_tokenizer, device=DEVICE)

+ # Parameter generator (small instruction model to "think" and choose settings).
+ # We keep this compact but capable. If you later want stronger reasoning, swap to flan-t5-base.
+ PARAM_MODEL = "google/flan-t5-small"
+ param_tokenizer = AutoTokenizer.from_pretrained(PARAM_MODEL)
+ param_model = AutoModelForSeq2SeqLM.from_pretrained(PARAM_MODEL)
+ param_generator = pipeline("text2text-generation", model=param_model, tokenizer=param_tokenizer, device=DEVICE)

  # -------------------------
+ # Presets & utilities
  # -------------------------
  LENGTH_PRESETS = {
      "short": {"min_length": 20, "max_length": 60},

      "long": {"min_length": 130, "max_length": 300},
  }
 
+ # Simple sentence splitter and extractive prefilter (helps focus the abstractive model)
+ _STOPWORDS = {
+     "the","and","is","in","to","of","a","that","it","on","for","as","are","with","was","be","by","this","an","or","from","at","which","we","has","have"
+ }
+
+ def tokenize_sentences(text):
+     sents = re.split(r'(?<=[.!?])\s+', text.strip())
+     return [s.strip() for s in sents if s.strip()]
+
+ def extractive_prefilter(text, top_k=12):
+     """
+     Rank sentences by (non-stopword) word frequency and return the top_k
+     sentences, joined in their original order. Useful for very long inputs.
+     """
+     sents = tokenize_sentences(text)
+     if len(sents) <= top_k:
+         return text
+     words = re.findall(r"\w+", text.lower())
+     freqs = Counter(w for w in words if w not in _STOPWORDS)
+     scores = []
+     for i, s in enumerate(sents):
+         ws = re.findall(r"\w+", s.lower())
+         score = sum(freqs.get(w, 0) for w in ws)
+         scores.append((score, i, s))
+     scores.sort(reverse=True)
+     chosen = [s for _, _, s in sorted(scores[:top_k], key=lambda t: t[1])]
+     return " ".join(chosen)
+
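A quick sketch of the prefilter's behavior on a hypothetical input (not part of the commit):

sample = (
    "Transformers drive modern NLP. The cafeteria reopened today. "
    "Attention lets transformers weigh context. The weather was mild."
)
print(extractive_prefilter(sample, top_k=2))
# Likely keeps the two transformer-related sentences (highest word-frequency
# scores), joined in their original order.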
  def chunk_text_by_chars(text, max_chars=1500, overlap=200):
      if len(text) <= max_chars:
          return [text]

          end = start + nl
          chunk = text[start:end]
          parts.append(chunk.strip())
+         start = max(end - overlap, start + 1)  # step back by `overlap` chars, but always move forward
      return parts
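A brief usage sketch of the chunker (the loop body above is partially collapsed in the diff, so sizes are approximate):

# Hypothetical usage:
long_text = "word " * 1200  # roughly 6000 characters
chunks = chunk_text_by_chars(long_text, max_chars=1400, overlap=200)
# Each chunk should be at most ~1400 characters, with ~200 characters repeated
# between consecutive chunks so content at the boundaries is not lost.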
 
+ def apply_tone_instruction(text, tone, target_sentences=None):
+     """
+     Build a clear instruction prompt for the summarizer based on tone/length.
+     """
      tone = (tone or "neutral").lower()
+     if tone == "bullet":
+         instr = "Produce concise bullet points. Each bullet should be short (<=20 words) and focused. No extra commentary."
+     elif tone == "short":
+         ts = target_sentences or 1
+         instr = f"Summarize the content in {ts} sentence{'s' if ts > 1 else ''}. Be highly abstractive and avoid copying sentences verbatim."
+     elif tone == "formal":
+         instr = "Summarize in a formal, professional tone in 2-4 sentences. Keep it precise and well-structured."
      elif tone == "casual":
+         instr = "Summarize in a casual, conversational tone in 1-3 sentences. Use plain, friendly language."
+     elif tone == "long":
+         instr = "Provide a clear, structured summary in 4-8 sentences, covering key points and relevant context."
      else:
+         instr = "Summarize the content in 2-3 sentences. Be clear and concise."
+
+     instr += " Do not repeat the same information. Prefer rephrasing over copying."
+
+     return f"{instr}\n\nText:\n{text}"

+ # helper: extract the first integer in a string
+ def _first_int_from_text(s, fallback=None):
+     m = re.search(r"\d{1,4}", s)
+     return int(m.group()) if m else fallback
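For reference, a sketch of the prompt this builder emits for the bullet tone:

# Hypothetical check of the prompt shape:
p = apply_tone_instruction("Some source text.", "bullet")
# p == ("Produce concise bullet points. Each bullet should be short (<=20 words) and focused. "
#       "No extra commentary. Do not repeat the same information. Prefer rephrasing over copying."
#       "\n\nText:\nSome source text.")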
 
+ # -------------------------
+ # Parameter generator (AI "thinking" module)
+ # -------------------------
  def generate_summarization_config(text):
      """
+     Use the instruction model to recommend: length(short|medium|long), min_words, max_words, tone.
+     Falls back to heuristics on failure.
      """
      prompt = (
+         "You are an assistant that recommends optimal summarization settings.\n"
+         "Given the text, respond ONLY with single-line JSON EXACTLY like:\n"
          '{"length":"short|medium|long","min_words":MIN,"max_words":MAX,"tone":"neutral|formal|casual|bullet"}\n\n'
          "Text:\n'''"
+         + text[:4000] +
+         "'''"
      )

      try:
+         # keep generation short and deterministic; use max_new_tokens (avoid max_length)
          gen = param_generator(
              prompt,
+             max_new_tokens=64,
+             num_beams=1,
+             do_sample=False,
+             early_stopping=True
          )
          out = gen[0].get("generated_text", "").strip()
+         # attempt JSON parse
          cfg = None
          try:
              cfg = json.loads(out)
          except Exception:
+             j = re.search(r"\{.*\}", out, re.DOTALL)
+             if j:
+                 raw = j.group().replace("'", '"')
                  cfg = json.loads(raw)
          if not cfg:
+             raise ValueError("Param-generator output not parseable")

          length = cfg.get("length", "").lower()
          tone = cfg.get("tone", "").lower()
          min_w = cfg.get("min_words")
          max_w = cfg.get("max_words")

          if length not in ("short", "medium", "long"):
              words = len(text.split())
              length = "short" if words < 150 else ("medium" if words < 800 else "long")
          if tone not in ("neutral", "formal", "casual", "bullet"):
              tone = "neutral"

+         if not isinstance(min_w, int):
+             min_w = _first_int_from_text(out, fallback=None)
+         if not isinstance(max_w, int):
+             # take the LAST integer in the output (reversing the string would also reverse the digits)
+             nums = re.findall(r"\d{1,4}", out)
+             max_w = int(nums[-1]) if nums else None
+
          defaults = {"short": (15, 50), "medium": (50, 130), "long": (130, 300)}
+         dmin, dmax = defaults.get(length, (50, 130))
          min_len = int(min_w) if isinstance(min_w, int) else dmin
          max_len = int(max_w) if isinstance(max_w, int) else dmax

          min_len = max(5, min(min_len, 2000))
          max_len = max(min_len + 5, min(max_len, 4000))

+         logger.info("Param-generator chose: length=%s tone=%s min=%s max=%s", length, tone, min_len, max_len)
          return {"length": length, "min_length": min_len, "max_length": max_len, "tone": tone}
      except Exception as e:
+         logger.exception("Param-generator failed; falling back to heuristic: %s", str(e))
          words = len(text.split())
          length = "short" if words < 150 else ("medium" if words < 800 else "long")
          fallback = {"short": (15, 50), "medium": (50, 130), "long": (130, 300)}
          mn, mx = fallback[length]
          return {"length": length, "min_length": mn, "max_length": mx, "tone": "neutral"}
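To make the JSON contract concrete, a sketch of a well-formed param-model reply and the config it maps to (values hypothetical):

# Hypothetical round trip through generate_summarization_config:
#   model output: {"length":"medium","min_words":50,"max_words":120,"tone":"formal"}
#   returned cfg: {"length": "medium", "min_length": 50, "max_length": 120, "tone": "formal"}
# Malformed or missing replies fall back to the word-count heuristic and a neutral tone.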
 
+ # -------------------------
+ # Two-stage summarization helpers
+ # -------------------------
+ def refine_and_combine(summaries_list, tone, final_target_sentences=None):
+     """
+     Combine chunk summaries and run a refinement pass to produce a cohesive final summary.
+     """
+     combined = "\n\n".join(summaries_list)
+     if len(combined.split()) > 2000:
+         combined = extractive_prefilter(combined, top_k=20)
+
+     prompt = apply_tone_instruction(combined, tone, target_sentences=final_target_sentences)
+
+     # heuristics for min/max generation lengths
+     tgt_sent = final_target_sentences or 3
+     gen_kwargs = {
+         "min_length": max(20, int(tgt_sent * 8)),
+         "max_length": max(60, int(tgt_sent * 30)),
+         "num_beams": 6,
+         "early_stopping": True,
+         "no_repeat_ngram_size": 3,
+         "do_sample": False,
+         "truncation": True,  # guard against prompts longer than the model's context window
+     }
+
+     try:
+         out = summarizer(prompt, **gen_kwargs)[0]["summary_text"].strip()
+         return out
+     except Exception as e:
+         logger.exception("Refine step failed: %s", e)
+         return " ".join(summaries_list[:3])
+
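A sketch of the refinement pass on hypothetical chunk summaries:

parts = ["The paper introduces method X.", "Experiments show X outperforms Y on benchmark Z."]
final = refine_and_combine(parts, "formal", final_target_sentences=2)
# The joined chunk summaries are re-summarized with beam search (num_beams=6),
# so the result reads as one cohesive paragraph rather than stitched fragments.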
  # -------------------------
  # Routes
  # -------------------------
  @app.route("/")
  def home():
+     # Ensure you have templates/index.html in place
      return render_template("index.html")

  @app.route("/summarize", methods=["POST"])
  def summarize_route():
+     t0 = time.time()
      data = request.get_json(force=True)
+     text = (data.get("text") or "")[:60000]  # cap input to a reasonable size
      requested_length = (data.get("length") or "medium").lower()
      requested_tone = (data.get("tone") or "neutral").lower()

      if not text or len(text.split()) < 5:
          return jsonify({"error": "Input too short."}), 400

+     # 1) Decide settings (AI or explicit)
      if requested_length in ("auto", "ai") or requested_tone in ("auto", "ai"):
          cfg = generate_summarization_config(text)
+         length_choice = cfg.get("length", "medium")
+         tone_choice = cfg.get("tone", "neutral")
          preset_min = cfg.get("min_length")
          preset_max = cfg.get("max_length")
      else:
+         length_choice = requested_length if requested_length in ("short", "medium", "long") else "medium"
+         tone_choice = requested_tone if requested_tone in ("neutral", "formal", "casual", "bullet", "short", "long") else "neutral"
+         preset_min = LENGTH_PRESETS.get(length_choice, LENGTH_PRESETS["medium"])["min_length"]
+         preset_max = LENGTH_PRESETS.get(length_choice, LENGTH_PRESETS["medium"])["max_length"]
+     # Note: preset_min/preset_max are collected but not yet wired into generation below;
+     # per-chunk and refine-pass lengths are derived from sentence targets instead.

+     # Map the chosen length to a target number of final sentences
+     sentence_map = {"short": 1, "medium": 3, "long": 6}
+     final_target_sentences = sentence_map.get(length_choice, 3)

+     # 2) Prefilter if extremely long
+     words_len = len(text.split())
+     if words_len > 3500:
+         text_for_chunks = extractive_prefilter(text, top_k=40)
      else:
+         text_for_chunks = text

+     # 3) Chunking
+     chunks = chunk_text_by_chars(text_for_chunks, max_chars=1400, overlap=200)
+     chunk_summaries = []
+
+     # 4) Summarize each chunk
+     for chunk in chunks:
+         chunk_target = 1 if length_choice == "short" else 2
+         chunk_tone = tone_choice if tone_choice in ("formal", "casual", "bullet") else "neutral"
+         prompt = apply_tone_instruction(chunk, chunk_tone, target_sentences=chunk_target)
+
+         gen_kwargs = {
+             "min_length": 12 if chunk_target == 1 else 24,
+             "max_length": 60 if chunk_target == 1 else 120,
+             "num_beams": 5,
+             "early_stopping": True,
+             "no_repeat_ngram_size": 3,
+             "do_sample": False,
+         }
+
+         try:
+             out = summarizer(prompt, **gen_kwargs)[0]["summary_text"].strip()
+         except Exception as e:
+             logger.exception("Chunk summarization failed, using extractive fallback: %s", e)
+             out = extractive_prefilter(chunk, top_k=3)
+         chunk_summaries.append(out)

+     # 5) Combine & refine
+     final = refine_and_combine(chunk_summaries, tone_choice, final_target_sentences=final_target_sentences)

+     # 6) Post-process for bullet tone
+     if tone_choice == "bullet":
+         parts = re.split(r'[\n\r]+|(?:\.\s+)|(?:;\s+)', final)
+         bullets = [f"- {p.strip().rstrip('.')}" for p in parts if p.strip()]
+         final = "\n".join(bullets[:20])
+
+     elapsed = time.time() - t0
+     meta = {
+         "length_choice": length_choice,
+         "tone": tone_choice,
+         "chunks": len(chunks),
+         "input_words": words_len,
+         "time_seconds": round(elapsed, 2),
+         "device": ("gpu" if USE_GPU else "cpu")
+     }
+
+     return jsonify({"summary": final, "meta": meta})
+ # -------------------------
+ # Run
+ # -------------------------
  if __name__ == "__main__":
+     # In production use Gunicorn; keep debug=False and enable it only for local testing
+     app.run(host="0.0.0.0", port=7860, debug=False)
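Once running, the endpoint can be exercised with a minimal client sketch (assumes the server is on localhost:7860 and `requests` is installed; payload values are hypothetical):

import requests  # hypothetical client, not part of main.py

resp = requests.post(
    "http://localhost:7860/summarize",
    json={"text": "<long article text here>", "length": "auto", "tone": "bullet"},
)
print(resp.json())
# Expected shape: {"summary": "- ...", "meta": {"length_choice": ..., "tone": ..., ...}}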