import os
os.environ['HF_HOME'] = '/tmp'

import time
import json
import re
import logging
from collections import Counter
from typing import Optional

from flask import Flask, request, jsonify, render_template
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline

# -------------------------
# App + logging
# -------------------------
app = Flask(__name__)
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("summarizer")

# -------------------------
# Device selection
# -------------------------
USE_GPU = torch.cuda.is_available()
DEVICE = 0 if USE_GPU else -1
logger.info("CUDA available: %s. Using device: %s", USE_GPU, DEVICE)

# -------------------------
# Model names (we'll load summarizers lazily)
# -------------------------
PEGASUS_MODEL = "google/pegasus-large"
LED_MODEL = "allenai/led-large-16384"
PARAM_MODEL = "google/flan-t5-small"   # instruction model for parameter generation

# caches for lazy-loaded pipelines
_SUMMARIZER_CACHE = {}

# load the small param-generator eagerly; it is lightweight enough to keep resident
logger.info("Loading parameter generator model: %s", PARAM_MODEL)
param_tokenizer = AutoTokenizer.from_pretrained(PARAM_MODEL)
param_model = AutoModelForSeq2SeqLM.from_pretrained(PARAM_MODEL)
param_generator = pipeline("text2text-generation", model=param_model, tokenizer=param_tokenizer, device=DEVICE)

# -------------------------
# Presets & utils
# -------------------------
LENGTH_PRESETS = {
    "short": {"min_length": 20, "max_length": 60},
    "medium": {"min_length": 60, "max_length": 130},
    "long": {"min_length": 130, "max_length": 300},
}

_STOPWORDS = {
    "the","and","is","in","to","of","a","that","it","on","for","as","are","with","was","be","by","this","an","or","from","at","which","we","has","have"
}

def tokenize_sentences(text):
    sents = re.split(r'(?<=[.!?])\s+', text.strip())
    return [s.strip() for s in sents if s.strip()]

def extractive_prefilter(text, top_k=12):
    sents = tokenize_sentences(text)
    if len(sents) <= top_k:
        return text
    words = re.findall(r"\w+", text.lower())
    freqs = Counter(w for w in words if w not in _STOPWORDS)
    scored = []
    for i, s in enumerate(sents):
        ws = re.findall(r"\w+", s.lower())
        score = sum(freqs.get(w, 0) for w in ws)
        scored.append((score, i, s))
    scored.sort(reverse=True)
    chosen = [s for _, _, s in sorted(scored[:top_k], key=lambda t: t[1])]
    return " ".join(chosen)

def chunk_text_by_chars(text, max_chars=1500, overlap=200):
    if len(text) <= max_chars:
        return [text]
    parts = []
    start = 0
    while start < len(text):
        end = min(len(text), start + max_chars)
        chunk = text[start:end]
        nl = chunk.rfind('\n')
        if nl > max_chars * 0.6:
            end = start + nl
            chunk = text[start:end]
        parts.append(chunk.strip())
        # step back by `overlap` so consecutive chunks share some context
        start = end - overlap if end < len(text) else end
    return parts
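
# Example (illustrative): chunk_text_by_chars("abcdefghijklmno", max_chars=10, overlap=3)
# yields ["abcdefghij", "hijklmno"] -- the second chunk re-reads the last `overlap`
# characters of the first so content cut at a boundary is not lost.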

def _first_int_from_text(s, fallback=None):
    m = re.search(r"\d{1,4}", s)
    return int(m.group()) if m else fallback

# -------------------------
# Lazy summarizer loader
# -------------------------
def get_summarizer(model_key: str):
    """
    Returns a pipeline summarizer for 'pegasus' or 'led', loading it lazily.
    model_key: "pegasus" or "led"
    """
    model_key = model_key.lower()
    if model_key in _SUMMARIZER_CACHE:
        return _SUMMARIZER_CACHE[model_key]

    if model_key == "pegasus":
        model_name = PEGASUS_MODEL
    elif model_key == "led":
        model_name = LED_MODEL
    else:
        raise ValueError("Unknown model_key: " + str(model_key))

    logger.info("Loading summarizer model '%s' (%s) on device %s ...", model_key, model_name, DEVICE)
    tok = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
    pipe = pipeline("summarization", model=model, tokenizer=tok, device=DEVICE)
    _SUMMARIZER_CACHE[model_key] = pipe
    logger.info("Loaded summarizer '%s' successfully.", model_key)
    return pipe
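
# Illustrative note on the caching behavior above: the first get_summarizer("pegasus")
# call loads google/pegasus-large into _SUMMARIZER_CACHE; subsequent calls return the
# cached pipeline, so repeated requests pay the load cost only once.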

# -------------------------
# Prompt and decision logic
# -------------------------
def apply_tone_instruction(text, tone, target_sentences=None):
    tone = (tone or "neutral").lower()
    if tone == "bullet":
        instr = "Produce concise bullet points. Each bullet short (<=20 words). No extra commentary."
    elif tone == "short":
        ts = target_sentences or 1
        instr = f"Summarize the content in {ts} sentence{'s' if ts>1 else ''}. Be highly abstractive and avoid copying sentences verbatim."
    elif tone == "formal":
        instr = "Summarize in a formal, professional tone in 2-4 sentences. Keep it precise and well-structured."
    elif tone == "casual":
        instr = "Summarize in a casual, conversational tone in 1-3 sentences. Use plain, friendly language."
    elif tone == "long":
        instr = "Provide a clear, structured summary in 4-8 sentences covering key points and context."
    else:
        instr = "Summarize the content in 2-3 sentences. Be clear and concise."
    instr += " Do not repeat information; prefer rephrasing."
    return f"{instr}\n\nText:\n{text}"

def generate_summarization_config(text):
    """
    Ask small instruction model for settings; fallback to heuristic.
    Returns dict with keys: length, min_length, max_length, tone
    """
    prompt = (
        "You are an assistant that recommends summarization settings.\n"
        "Given the text, respond ONLY with single-line JSON EXACTLY like:\n"
        '{"length":"short|medium|long","min_words":MIN,"max_words":MAX,"tone":"neutral|formal|casual|bullet"}\n\n'
        "Text:\n'''"
        + text[:4000] + "'''"
    )
    try:
        out = param_generator(
            prompt,
            max_new_tokens=64,
            num_beams=1,
            do_sample=False,
            early_stopping=True
        )[0].get("generated_text","").strip()
        cfg = None
        try:
            cfg = json.loads(out)
        except Exception:
            j = re.search(r"\{.*\}", out, re.DOTALL)
            if j:
                raw = j.group().replace("'", '"')
                cfg = json.loads(raw)
        if not cfg:
            raise ValueError("Unparseable param-generator output")
        length = cfg.get("length","").lower()
        tone = cfg.get("tone","").lower()
        min_w = cfg.get("min_words")
        max_w = cfg.get("max_words")
        if length not in ("short","medium","long"):
            words = len(text.split())
            length = "short" if words < 150 else ("medium" if words < 800 else "long")
        if tone not in ("neutral","formal","casual","bullet"):
            tone = "neutral"
        if not isinstance(min_w,int):
            min_w = _first_int_from_text(out, fallback=None)
        if not isinstance(max_w,int):
            # take the last integer in the output (searching a reversed string would reverse the digits)
            nums = re.findall(r"\d{1,4}", out)
            max_w = int(nums[-1]) if nums else None
        defaults = {"short":(15,50),"medium":(50,130),"long":(130,300)}
        dmin,dmax = defaults.get(length,(50,130))
        min_len = int(min_w) if isinstance(min_w,int) else dmin
        max_len = int(max_w) if isinstance(max_w,int) else dmax
        min_len = max(5, min(min_len, 2000))
        max_len = max(min_len+5, min(max_len, 4000))
        logger.info("Param-generator chose length=%s tone=%s min=%s max=%s", length, tone, min_len, max_len)
        return {"length":length,"min_length":min_len,"max_length":max_len,"tone":tone}
    except Exception as e:
        logger.exception("Param-generator failed: %s", e)
        words = len(text.split())
        length = "short" if words < 150 else ("medium" if words < 800 else "long")
        fallback = {"short":(15,50),"medium":(50,130),"long":(130,300)}
        mn,mx = fallback[length]
        return {"length":length,"min_length":mn,"max_length":mx,"tone":"neutral"}

# -------------------------
# Two-stage summarization (chunk -> chunk summaries -> refine)
# -------------------------
def refine_and_combine(summaries_list, tone, final_target_sentences=None, summarizer_pipe=None):
    combined = "\n\n".join(summaries_list)
    if len(combined.split()) > 2000:
        combined = extractive_prefilter(combined, top_k=20)
    prompt = apply_tone_instruction(combined, tone, target_sentences=final_target_sentences)
    tgt_sent = final_target_sentences or 3
    gen_kwargs = {
        "min_length": max(20, int(tgt_sent * 8)),
        "max_length": max(60, int(tgt_sent * 30)),
        "num_beams": 6,
        "early_stopping": True,
        "no_repeat_ngram_size": 3,
        "do_sample": False,
    }
    try:
        if summarizer_pipe is None:
            # fallback to pegasus by default (if pipe not provided)
            summarizer_pipe = get_summarizer("pegasus")
        out = summarizer_pipe(prompt, **gen_kwargs)[0]["summary_text"].strip()
        return out
    except Exception as e:
        logger.exception("Refine failed: %s", e)
        return " ".join(summaries_list[:3])

# -------------------------
# Model-specific generation helper
# -------------------------
def summarize_with_model(pipe, text_prompt, short_target=False):
    """
    Use model pipeline with conservative and model-appropriate generation settings.
    short_target: if True use shorter min/max suitable for concise outputs
    """
    # heuristic: if the pipeline wraps LED (judging by the model's name_or_path), allow longer outputs
    model_name = getattr(pipe.model.config, "name_or_path", "") or ""
    is_led = "led" in model_name.lower() or "longformer" in model_name.lower()
    if short_target:
        min_l = 12
        max_l = 60
    else:
        min_l = 24
        max_l = 140 if not is_led else 400  # LED can handle longer outputs
    gen_kwargs = {
        "min_length": min_l,
        "max_length": max_l,
        "num_beams": 5 if not is_led else 4,
        "early_stopping": True,
        "no_repeat_ngram_size": 3,
        "do_sample": False,
    }
    return pipe(text_prompt, **gen_kwargs)[0]["summary_text"].strip()

# -------------------------
# Routes
# -------------------------
@app.route("/")
def home():
    return render_template("index.html")

@app.route("/summarize", methods=["POST"])
def summarize_route():
    t0 = time.time()
    data = request.get_json(force=True) or {}
    text = (data.get("text") or "")[:90000]
    user_model_pref = (data.get("model") or "auto").lower()   # 'pegasus' | 'led' | 'auto'
    requested_length = (data.get("length") or "auto").lower()  # short|medium|long|auto
    requested_tone = (data.get("tone") or "auto").lower()      # neutral|formal|casual|bullet|auto

    if not text or len(text.split()) < 5:
        return jsonify({"error":"Input too short."}), 400

    # 1) Decide settings (AI or explicit)
    if requested_length in ("auto","ai") or requested_tone in ("auto","ai"):
        cfg = generate_summarization_config(text)
        # honor an explicitly requested value; use the AI suggestion only for fields left on "auto"
        length_choice = requested_length if requested_length in ("short","medium","long") else cfg.get("length","medium")
        tone_choice = requested_tone if requested_tone in ("neutral","formal","casual","bullet") else cfg.get("tone","neutral")
        preset_min = cfg.get("min_length")
        preset_max = cfg.get("max_length")
    else:
        length_choice = requested_length if requested_length in ("short","medium","long") else "medium"
        tone_choice = requested_tone if requested_tone in ("neutral","formal","casual","bullet") else "neutral"
        preset_min = LENGTH_PRESETS.get(length_choice, LENGTH_PRESETS["medium"])["min_length"]
        preset_max = LENGTH_PRESETS.get(length_choice, LENGTH_PRESETS["medium"])["max_length"]

    # 2) Model selection (user preference or auto)
    # auto rules: if user specifically asked 'led' or param-generator picked long / input is very long -> led
    words_len = len(text.split())
    prefer_led = False
    if user_model_pref == "led":
        prefer_led = True
    elif user_model_pref == "pegasus":
        prefer_led = False
    else:  # auto
        if length_choice == "long" or words_len > 3000:
            prefer_led = True
        else:
            prefer_led = False

    model_key = "led" if prefer_led else "pegasus"
    # get the pipeline (lazy load)
    try:
        summarizer_pipe = get_summarizer(model_key)
    except Exception as e:
        logger.exception("Failed to load summarizer '%s': %s", model_key, e)
        # fallback to pegasus if led fails
        summarizer_pipe = get_summarizer("pegasus")
        model_key = "pegasus"

    # 3) Prefilter very long inputs (if not using LED)
    if not prefer_led and words_len > 2500:
        text_for_chunks = extractive_prefilter(text, top_k=40)
    else:
        text_for_chunks = text

    # 4) Chunking: choose chunk size depending on model
    if model_key == "led":
        chunk_max_chars = 8000   # LED can handle larger chunks
        chunk_overlap = 400
    else:
        chunk_max_chars = 1400
        chunk_overlap = 200
    chunks = chunk_text_by_chars(text_for_chunks, max_chars=chunk_max_chars, overlap=chunk_overlap)

    # 5) Summarize each chunk
    chunk_summaries = []
    for chunk in chunks:
        chunk_target = 1 if length_choice == "short" else 2
        chunk_tone = tone_choice if tone_choice in ("formal","casual","bullet") else "neutral"
        prompt = apply_tone_instruction(chunk, chunk_tone, target_sentences=chunk_target)
        try:
            # choose short_target True for tiny chunk summaries
            out = summarize_with_model(summarizer_pipe, prompt, short_target=(chunk_target==1))
        except Exception as e:
            logger.exception("Chunk summarization failed, using extractive fallback: %s", e)
            out = extractive_prefilter(chunk, top_k=3)
        chunk_summaries.append(out)

    # 6) Combine + refine using the same model for consistency (or prefer Pegasus for elegant refinement)
    refine_model_key = model_key if model_key == "led" else "pegasus"
    refine_pipe = get_summarizer(refine_model_key)
    final_target_sentences = {"short":1,"medium":3,"long":6}.get(length_choice,3)
    final = refine_and_combine(chunk_summaries, tone_choice, final_target_sentences, summarizer_pipe=refine_pipe)

    # 7) Post-process bullet tone
    if tone_choice == "bullet":
        parts = re.split(r'[\n\r]+|(?:\.\s+)|(?:;\s+)', final)
        bullets = [f"- {p.strip().rstrip('.')}" for p in parts if p.strip()]
        final = "\n".join(bullets[:20])

    elapsed = time.time() - t0
    meta = {
        "length_choice": length_choice,
        "tone": tone_choice,
        "model_used": model_key,
        "refine_model": refine_model_key,
        "chunks": len(chunks),
        "input_words": words_len,
        "time_seconds": round(elapsed, 2),
        "device": ("gpu" if USE_GPU else "cpu")
    }
    return jsonify({"summary": final, "meta": meta})
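
# Example request (a sketch; assumes the app below is reachable on localhost:7860 and
# posts the JSON fields this route reads: text, model, length, tone):
#   curl -X POST http://localhost:7860/summarize \
#        -H "Content-Type: application/json" \
#        -d '{"text": "Long article text ...", "model": "auto", "length": "auto", "tone": "auto"}'
# The JSON response contains "summary" plus a "meta" object (model used, tone, chunk
# count, timing, device).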

# -------------------------
# Run
# -------------------------
if __name__ == "__main__":
    # debug=False for production; use Gunicorn in deployment
    app.run(host="0.0.0.0", port=7860, debug=False)