Update main.py
main.py CHANGED
@@ -6,47 +6,44 @@ import json
 import re
 import logging
 from collections import Counter
+from typing import Optional
 
 from flask import Flask, request, jsonify, render_template
 import torch
-from transformers import (
-    AutoTokenizer,
-    AutoModelForSeq2SeqLM,
-    pipeline
-)
+from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
 
 # -------------------------
-#
+# App + logging
 # -------------------------
 app = Flask(__name__)
 logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger("summarizer")
 
 # -------------------------
 # Device selection
 # -------------------------
 USE_GPU = torch.cuda.is_available()
 DEVICE = 0 if USE_GPU else -1
 logger.info("CUDA available: %s. Using device: %s", USE_GPU, DEVICE)
 
 # -------------------------
-#
+# Model names (we'll load summarizers lazily)
 # -------------------------
-
-
-
-
-
-
-
-#
-
+PEGASUS_MODEL = "google/pegasus-large"
+LED_MODEL = "allenai/led-large-16384"
+PARAM_MODEL = "google/flan-t5-small"  # instruction model for parameter generation
+
+# caches for lazy-loaded pipelines
+_SUMMARIZER_CACHE = {}
+
+# load the small param-generator right away (keeps it small)
+logger.info("Loading parameter generator model: %s", PARAM_MODEL)
 param_tokenizer = AutoTokenizer.from_pretrained(PARAM_MODEL)
 param_model = AutoModelForSeq2SeqLM.from_pretrained(PARAM_MODEL)
 param_generator = pipeline("text2text-generation", model=param_model, tokenizer=param_tokenizer, device=DEVICE)
 
 # -------------------------
-# Presets &
+# Presets & utils
 # -------------------------
 LENGTH_PRESETS = {
     "short": {"min_length": 20, "max_length": 60},
@@ -54,7 +51,6 @@ LENGTH_PRESETS = {
     "long": {"min_length": 130, "max_length": 300},
 }
 
-# Simple sentence splitter and extractive prefilter (helps focus abstractive model)
 _STOPWORDS = {
     "the","and","is","in","to","of","a","that","it","on","for","as","are","with","was","be","by","this","an","or","from","at","which","we","has","have"
 }
@@ -64,22 +60,18 @@ def tokenize_sentences(text):
     return [s.strip() for s in sents if s.strip()]
 
 def extractive_prefilter(text, top_k=12):
-    """
-    Rank sentences by (non-stopword) word-frequency and return top_k sentences
-    in original order joined. Useful for very long inputs.
-    """
     sents = tokenize_sentences(text)
     if len(sents) <= top_k:
         return text
     words = re.findall(r"\w+", text.lower())
     freqs = Counter(w for w in words if w not in _STOPWORDS)
-
+    scored = []
     for i, s in enumerate(sents):
         ws = re.findall(r"\w+", s.lower())
         score = sum(freqs.get(w, 0) for w in ws)
-
-
-        chosen = [s for _, _, s in sorted(
+        scored.append((score, i, s))
+    scored.sort(reverse=True)
+    chosen = [s for _, _, s in sorted(scored[:top_k], key=lambda t: t[1])]
     return " ".join(chosen)
 
 def chunk_text_by_chars(text, max_chars=1500, overlap=200):
@@ -95,16 +87,47 @@ def chunk_text_by_chars(text, max_chars=1500, overlap=200):
         end = start + nl
         chunk = text[start:end]
         parts.append(chunk.strip())
-        start = max(end - overlap, end)
+        start = max(end - overlap, start + 1)  # step back `overlap` chars for context, but always advance
     return parts
 
-def apply_tone_instruction(text, tone, target_sentences=None):
+def _first_int_from_text(s, fallback=None):
+    m = re.search(r"\d{1,4}", s)
+    return int(m.group()) if m else fallback
+
+# -------------------------
+# Lazy summarizer loader
+# -------------------------
+def get_summarizer(model_key: str):
     """
-
+    Returns a pipeline summarizer for 'pegasus' or 'led', loading it lazily.
+    model_key: "pegasus" or "led"
     """
+    model_key = model_key.lower()
+    if model_key in _SUMMARIZER_CACHE:
+        return _SUMMARIZER_CACHE[model_key]
+
+    if model_key == "pegasus":
+        model_name = PEGASUS_MODEL
+    elif model_key == "led":
+        model_name = LED_MODEL
+    else:
+        raise ValueError("Unknown model_key: " + str(model_key))
+
+    logger.info("Loading summarizer model '%s' (%s) on device %s ...", model_key, model_name, DEVICE)
+    tok = AutoTokenizer.from_pretrained(model_name)
+    model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
+    pipe = pipeline("summarization", model=model, tokenizer=tok, device=DEVICE)
+    _SUMMARIZER_CACHE[model_key] = pipe
+    logger.info("Loaded summarizer '%s' successfully.", model_key)
+    return pipe
+
+# -------------------------
+# Prompt and decision logic
+# -------------------------
+def apply_tone_instruction(text, tone, target_sentences=None):
     tone = (tone or "neutral").lower()
     if tone == "bullet":
-        instr = "Produce concise bullet points. Each bullet
+        instr = "Produce concise bullet points. Each bullet short (<=20 words). No extra commentary."
     elif tone == "short":
         ts = target_sentences or 1
         instr = f"Summarize the content in {ts} sentence{'s' if ts>1 else ''}. Be highly abstractive and avoid copying sentences verbatim."
@@ -113,47 +136,32 @@ def apply_tone_instruction(text, tone, target_sentences=None):
     elif tone == "casual":
         instr = "Summarize in a casual, conversational tone in 1-3 sentences. Use plain, friendly language."
     elif tone == "long":
-        instr = "Provide a clear, structured summary in 4-8 sentences
+        instr = "Provide a clear, structured summary in 4-8 sentences covering key points and context."
     else:
         instr = "Summarize the content in 2-3 sentences. Be clear and concise."
-
-    instr += " Do not repeat the same information. Prefer rephrasing over copying."
-
+    instr += " Do not repeat information; prefer rephrasing."
     return f"{instr}\n\nText:\n{text}"
 
-# helper: extract first integer
-def _first_int_from_text(s, fallback=None):
-    m = re.search(r"\d{1,4}", s)
-    return int(m.group()) if m else fallback
-
-# -------------------------
-# Parameter generator (AI "thinking" module)
-# -------------------------
 def generate_summarization_config(text):
     """
-
-
+    Ask small instruction model for settings; fallback to heuristic.
+    Returns dict with keys: length, min_length, max_length, tone
     """
     prompt = (
-        "You are an assistant that recommends
+        "You are an assistant that recommends summarization settings.\n"
         "Given the text, respond ONLY with single-line JSON EXACTLY like:\n"
         '{"length":"short|medium|long","min_words":MIN,"max_words":MAX,"tone":"neutral|formal|casual|bullet"}\n\n'
         "Text:\n'''"
-        + text[:4000] +
-        "'''"
+        + text[:4000] + "'''"
     )
-
     try:
-
-        gen = param_generator(
+        out = param_generator(
             prompt,
             max_new_tokens=64,
             num_beams=1,
             do_sample=False,
             early_stopping=True
-        )
-        out = gen[0].get("generated_text", "").strip()
-        # attempt JSON parse
+        )[0].get("generated_text","").strip()
         cfg = None
         try:
             cfg = json.loads(out)
@@ -163,56 +171,44 @@ def generate_summarization_config(text):
                 raw = j.group().replace("'", '"')
                 cfg = json.loads(raw)
         if not cfg:
-            raise ValueError("
-
-
-        tone = cfg.get("tone", "").lower()
+            raise ValueError("Unparseable param-generator output")
+        length = cfg.get("length","").lower()
+        tone = cfg.get("tone","").lower()
         min_w = cfg.get("min_words")
         max_w = cfg.get("max_words")
-
-        if length not in ("short", "medium", "long"):
+        if length not in ("short","medium","long"):
             words = len(text.split())
             length = "short" if words < 150 else ("medium" if words < 800 else "long")
-        if tone not in ("neutral",
+        if tone not in ("neutral","formal","casual","bullet"):
             tone = "neutral"
-
-        if not isinstance(min_w, int):
+        if not isinstance(min_w,int):
            min_w = _first_int_from_text(out, fallback=None)
-        if not isinstance(max_w,
+        if not isinstance(max_w,int):
            max_w = _first_int_from_text(out[::-1], fallback=None)
-
-
-
-
-        max_len = int(max_w) if isinstance(max_w, int) else dmax
-
+        defaults = {"short":(15,50),"medium":(50,130),"long":(130,300)}
+        dmin,dmax = defaults.get(length,(50,130))
+        min_len = int(min_w) if isinstance(min_w,int) else dmin
+        max_len = int(max_w) if isinstance(max_w,int) else dmax
         min_len = max(5, min(min_len, 2000))
-        max_len = max(min_len
-
-
-        return {"length": length, "min_length": min_len, "max_length": max_len, "tone": tone}
+        max_len = max(min_len+5, min(max_len, 4000))
+        logger.info("Param-generator chose length=%s tone=%s min=%s max=%s", length, tone, min_len, max_len)
+        return {"length":length,"min_length":min_len,"max_length":max_len,"tone":tone}
     except Exception as e:
-        logger.exception("Param-generator failed
+        logger.exception("Param-generator failed: %s", e)
         words = len(text.split())
         length = "short" if words < 150 else ("medium" if words < 800 else "long")
-        fallback = {"short":
-        mn,
-        return {"length":
+        fallback = {"short":(15,50),"medium":(50,130),"long":(130,300)}
+        mn,mx = fallback[length]
+        return {"length":length,"min_length":mn,"max_length":mx,"tone":"neutral"}
 
 # -------------------------
-# Two-stage summarization
+# Two-stage summarization (chunk -> chunk summaries -> refine)
 # -------------------------
-def refine_and_combine(summaries_list, tone, final_target_sentences=None):
-    """
-    Combine chunk summaries and perform a refinement pass to produce cohesive final summary.
-    """
+def refine_and_combine(summaries_list, tone, final_target_sentences=None, summarizer_pipe=None):
     combined = "\n\n".join(summaries_list)
     if len(combined.split()) > 2000:
         combined = extractive_prefilter(combined, top_k=20)
-
     prompt = apply_tone_instruction(combined, tone, target_sentences=final_target_sentences)
-
-    # heuristics for min/max
     tgt_sent = final_target_sentences or 3
     gen_kwargs = {
         "min_length": max(20, int(tgt_sent * 8)),
@@ -222,87 +218,135 @@ def refine_and_combine(summaries_list, tone, final_target_sentences=None):
         "no_repeat_ngram_size": 3,
         "do_sample": False,
     }
-
     try:
-
+        if summarizer_pipe is None:
+            # fallback to pegasus by default (if pipe not provided)
+            summarizer_pipe = get_summarizer("pegasus")
+        out = summarizer_pipe(prompt, **gen_kwargs)[0]["summary_text"].strip()
         return out
     except Exception as e:
-        logger.exception("Refine
+        logger.exception("Refine failed: %s", e)
         return " ".join(summaries_list[:3])
 
+# -------------------------
+# Model-specific generation helper
+# -------------------------
+def summarize_with_model(pipe, text_prompt, short_target=False):
+    """
+    Use model pipeline with conservative and model-appropriate generation settings.
+    short_target: if True use shorter min/max suitable for concise outputs
+    """
+    # heuristics: if the pipe wraps LED (by model name), allow larger max_length
+    model_name = getattr(pipe.model.config, "name_or_path", "") or ""
+    is_led = ("led" in model_name) or ("longformer" in model_name)
+    if short_target:
+        min_l = 12
+        max_l = 60
+    else:
+        min_l = 24
+        max_l = 140 if not is_led else 400  # LED can handle longer outputs
+    gen_kwargs = {
+        "min_length": min_l,
+        "max_length": max_l,
+        "num_beams": 5 if not is_led else 4,
+        "early_stopping": True,
+        "no_repeat_ngram_size": 3,
+        "do_sample": False,
+    }
+    return pipe(text_prompt, **gen_kwargs)[0]["summary_text"].strip()
+
 # -------------------------
 # Routes
 # -------------------------
 @app.route("/")
 def home():
-    # Ensure you have templates/index.html in place
     return render_template("index.html")
 
 @app.route("/summarize", methods=["POST"])
 def summarize_route():
     t0 = time.time()
-    data = request.get_json(force=True)
-    text = (data.get("text") or "")[:
-
-
+    data = request.get_json(force=True) or {}
+    text = (data.get("text") or "")[:90000]
+    user_model_pref = (data.get("model") or "auto").lower()  # 'pegasus' | 'led' | 'auto'
+    requested_length = (data.get("length") or "auto").lower()  # short|medium|long|auto
+    requested_tone = (data.get("tone") or "auto").lower()  # neutral|formal|casual|bullet|auto
 
     if not text or len(text.split()) < 5:
-        return jsonify({"error":
+        return jsonify({"error":"Input too short."}), 400
 
     # 1) Decide settings (AI or explicit)
-    if requested_length in ("auto",
+    if requested_length in ("auto","ai") or requested_tone in ("auto","ai"):
         cfg = generate_summarization_config(text)
-        length_choice = cfg.get("length",
-        tone_choice = cfg.get("tone",
+        length_choice = cfg.get("length","medium")
+        tone_choice = cfg.get("tone","neutral")
         preset_min = cfg.get("min_length")
         preset_max = cfg.get("max_length")
     else:
         length_choice = requested_length if requested_length in ("short","medium","long") else "medium"
-        tone_choice = requested_tone if requested_tone in ("neutral","formal","casual","bullet"
+        tone_choice = requested_tone if requested_tone in ("neutral","formal","casual","bullet") else "neutral"
         preset_min = LENGTH_PRESETS.get(length_choice, LENGTH_PRESETS["medium"])["min_length"]
         preset_max = LENGTH_PRESETS.get(length_choice, LENGTH_PRESETS["medium"])["max_length"]
 
-    #
-
-    final_target_sentences = sentence_map.get(length_choice, 3)
-
-    # 2) Prefilter if extremely long
+    # 2) Model selection (user preference or auto)
+    # auto rules: if user specifically asked 'led' or param-generator picked long / input is very long -> led
     words_len = len(text.split())
-
+    prefer_led = False
+    if user_model_pref == "led":
+        prefer_led = True
+    elif user_model_pref == "pegasus":
+        prefer_led = False
+    else:  # auto
+        if length_choice == "long" or words_len > 3000:
+            prefer_led = True
+        else:
+            prefer_led = False
+
+    model_key = "led" if prefer_led else "pegasus"
+    # get the pipeline (lazy load)
+    try:
+        summarizer_pipe = get_summarizer(model_key)
+    except Exception as e:
+        logger.exception("Failed to load summarizer '%s': %s", model_key, e)
+        # fallback to pegasus if led fails
+        summarizer_pipe = get_summarizer("pegasus")
+        model_key = "pegasus"
+
+    # 3) Prefilter very long inputs (if not using LED)
+    if not prefer_led and words_len > 2500:
         text_for_chunks = extractive_prefilter(text, top_k=40)
     else:
         text_for_chunks = text
 
-    #
-
-
+    # 4) Chunking: choose chunk size depending on model
+    if model_key == "led":
+        chunk_max_chars = 8000  # LED can handle larger chunks
+        chunk_overlap = 400
+    else:
+        chunk_max_chars = 1400
+        chunk_overlap = 200
+    chunks = chunk_text_by_chars(text_for_chunks, max_chars=chunk_max_chars, overlap=chunk_overlap)
 
-    #
+    # 5) Summarize each chunk
+    chunk_summaries = []
     for chunk in chunks:
         chunk_target = 1 if length_choice == "short" else 2
         chunk_tone = tone_choice if tone_choice in ("formal","casual","bullet") else "neutral"
         prompt = apply_tone_instruction(chunk, chunk_tone, target_sentences=chunk_target)
-
-        gen_kwargs = {
-            "min_length": 12 if chunk_target == 1 else 24,
-            "max_length": 60 if chunk_target == 1 else 120,
-            "num_beams": 5,
-            "early_stopping": True,
-            "no_repeat_ngram_size": 3,
-            "do_sample": False,
-        }
-
         try:
-
+            # choose short_target True for tiny chunk summaries
+            out = summarize_with_model(summarizer_pipe, prompt, short_target=(chunk_target==1))
         except Exception as e:
             logger.exception("Chunk summarization failed, using extractive fallback: %s", e)
             out = extractive_prefilter(chunk, top_k=3)
         chunk_summaries.append(out)
 
-    #
-
+    # 6) Combine + refine using the same model for consistency (or prefer Pegasus for elegant refinement)
+    refine_model_key = model_key if model_key == "led" else "pegasus"
+    refine_pipe = get_summarizer(refine_model_key)
+    final_target_sentences = {"short":1,"medium":3,"long":6}.get(length_choice,3)
+    final = refine_and_combine(chunk_summaries, tone_choice, final_target_sentences, summarizer_pipe=refine_pipe)
 
-    #
+    # 7) Post-process bullet tone
     if tone_choice == "bullet":
         parts = re.split(r'[\n\r]+|(?:\.\s+)|(?:;\s+)', final)
         bullets = [f"- {p.strip().rstrip('.')}" for p in parts if p.strip()]
@@ -312,17 +356,18 @@ def summarize_route():
     meta = {
         "length_choice": length_choice,
         "tone": tone_choice,
+        "model_used": model_key,
+        "refine_model": refine_model_key,
         "chunks": len(chunks),
         "input_words": words_len,
         "time_seconds": round(elapsed, 2),
        "device": ("gpu" if USE_GPU else "cpu")
     }
-
     return jsonify({"summary": final, "meta": meta})
 
 # -------------------------
 # Run
 # -------------------------
 if __name__ == "__main__":
-    #
+    # debug=False for production; use Gunicorn in deployment
     app.run(host="0.0.0.0", port=7860, debug=False)
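
generate_summarization_config asks the small model for single-line JSON, then salvages near-JSON with a regex before falling back to the word-count heuristic. A self-contained illustration of that parse chain; the out string is an invented example of messy model output:

import json
import re

out = "Sure! {'length':'medium','min_words':50,'max_words':130,'tone':'neutral'}"

try:
    cfg = json.loads(out)
except json.JSONDecodeError:
    j = re.search(r"\{.*\}", out)
    cfg = json.loads(j.group().replace("'", '"')) if j else None

print(cfg)  # {'length': 'medium', 'min_words': 50, 'max_words': 130, 'tone': 'neutral'}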
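
chunk_text_by_chars slices by characters with an overlap so that text cut at a chunk boundary reappears at the start of the next chunk; the route picks 1400/200 for Pegasus and 8000/400 for LED. A toy run of the same idea with made-up sizes:

def demo_chunks(text, max_chars=10, overlap=3):
    parts, start = [], 0
    while start < len(text):
        end = min(start + max_chars, len(text))
        parts.append(text[start:end])
        if end == len(text):
            break
        start = end - overlap  # re-read `overlap` chars for context
    return parts

print(demo_chunks("abcdefghijklmnopqrst"))  # ['abcdefghij', 'hijklmnopq', 'opqrst']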
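
To exercise the /summarize route once the Space is running, something like the snippet below works; the URL is illustrative, and the payload fields are the ones summarize_route reads:

import requests

resp = requests.post(
    "http://localhost:7860/summarize",
    json={
        "text": "Paste a long article here to try the endpoint out.",
        "model": "auto",    # 'pegasus' | 'led' | 'auto'
        "length": "auto",   # 'short' | 'medium' | 'long' | 'auto'
        "tone": "auto",     # 'neutral' | 'formal' | 'casual' | 'bullet' | 'auto'
    },
    timeout=600,
)
print(resp.json()["summary"])
print(resp.json()["meta"])  # includes model_used, chunks, input_words, time_seconds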