import os os.environ['HF_HOME'] = '/tmp' import time import json import re import logging from collections import Counter from typing import Optional from flask import Flask, request, jsonify, render_template import torch from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline # ------------------------- # App + logging # ------------------------- app = Flask(__name__) logging.basicConfig(level=logging.INFO) logger = logging.getLogger("summarizer") # ------------------------- # Device selection # ------------------------- USE_GPU = torch.cuda.is_available() DEVICE = 0 if USE_GPU else -1 logger.info("CUDA available: %s. Using device: %s", USE_GPU, DEVICE) # ------------------------- # Model names (we'll load summarizers lazily) # ------------------------- PEGASUS_MODEL = "google/pegasus-large" LED_MODEL = "allenai/led-large-16384" PARAM_MODEL = "google/flan-t5-small" # instruction model for parameter generation # caches for lazy-loaded pipelines _SUMMARIZER_CACHE = {} # load the small param-generator right away (keeps it small) logger.info("Loading parameter generator model: %s", PARAM_MODEL) param_tokenizer = AutoTokenizer.from_pretrained(PARAM_MODEL) param_model = AutoModelForSeq2SeqLM.from_pretrained(PARAM_MODEL) param_generator = pipeline("text2text-generation", model=param_model, tokenizer=param_tokenizer, device=DEVICE) # ------------------------- # Presets & utils # ------------------------- LENGTH_PRESETS = { "short": {"min_length": 20, "max_length": 60}, "medium": {"min_length": 60, "max_length": 130}, "long": {"min_length": 130, "max_length": 300}, } _STOPWORDS = { "the","and","is","in","to","of","a","that","it","on","for","as","are","with","was","be","by","this","an","or","from","at","which","we","has","have" } def tokenize_sentences(text): sents = re.split(r'(?<=[.!?])\s+', text.strip()) return [s.strip() for s in sents if s.strip()] def extractive_prefilter(text, top_k=12): sents = tokenize_sentences(text) if len(sents) <= top_k: return text words = re.findall(r"\w+", text.lower()) freqs = Counter(w for w in words if w not in _STOPWORDS) scored = [] for i, s in enumerate(sents): ws = re.findall(r"\w+", s.lower()) score = sum(freqs.get(w, 0) for w in ws) scored.append((score, i, s)) scored.sort(reverse=True) chosen = [s for _, _, s in sorted(scored[:top_k], key=lambda t: t[1])] return " ".join(chosen) def chunk_text_by_chars(text, max_chars=1500, overlap=200): if len(text) <= max_chars: return [text] parts = [] start = 0 while start < len(text): end = min(len(text), start + max_chars) chunk = text[start:end] nl = chunk.rfind('\n') if nl > max_chars * 0.6: end = start + nl chunk = text[start:end] parts.append(chunk.strip()) start = max(end - overlap, end) return parts def _first_int_from_text(s, fallback=None): m = re.search(r"\d{1,4}", s) return int(m.group()) if m else fallback # ------------------------- # Lazy summarizer loader # ------------------------- def get_summarizer(model_key: str): """ Returns a pipeline summarizer for 'pegasus' or 'led', loading it lazily. model_key: "pegasus" or "led" """ model_key = model_key.lower() if model_key in _SUMMARIZER_CACHE: return _SUMMARIZER_CACHE[model_key] if model_key == "pegasus": model_name = PEGASUS_MODEL elif model_key == "led": model_name = LED_MODEL else: raise ValueError("Unknown model_key: " + str(model_key)) logger.info("Loading summarizer model '%s' (%s) on device %s ...", model_key, model_name, DEVICE) tok = AutoTokenizer.from_pretrained(model_name) model = AutoModelForSeq2SeqLM.from_pretrained(model_name) pipe = pipeline("summarization", model=model, tokenizer=tok, device=DEVICE) _SUMMARIZER_CACHE[model_key] = pipe logger.info("Loaded summarizer '%s' successfully.", model_key) return pipe # ------------------------- # Prompt and decision logic # ------------------------- def apply_tone_instruction(text, tone, target_sentences=None): tone = (tone or "neutral").lower() if tone == "bullet": instr = "Produce concise bullet points. Each bullet short (<=20 words). No extra commentary." elif tone == "short": ts = target_sentences or 1 instr = f"Summarize the content in {ts} sentence{'s' if ts>1 else ''}. Be highly abstractive and avoid copying sentences verbatim." elif tone == "formal": instr = "Summarize in a formal, professional tone in 2-4 sentences. Keep it precise and well-structured." elif tone == "casual": instr = "Summarize in a casual, conversational tone in 1-3 sentences. Use plain, friendly language." elif tone == "long": instr = "Provide a clear, structured summary in 4-8 sentences covering key points and context." else: instr = "Summarize the content in 2-3 sentences. Be clear and concise." instr += " Do not repeat information; prefer rephrasing." return f"{instr}\n\nText:\n{text}" def generate_summarization_config(text): """ Ask small instruction model for settings; fallback to heuristic. Returns dict with keys: length, min_length, max_length, tone """ prompt = ( "You are an assistant that recommends summarization settings.\n" "Given the text, respond ONLY with single-line JSON EXACTLY like:\n" '{"length":"short|medium|long","min_words":MIN,"max_words":MAX,"tone":"neutral|formal|casual|bullet"}\n\n' "Text:\n'''" + text[:4000] + "'''" ) try: out = param_generator( prompt, max_new_tokens=64, num_beams=1, do_sample=False, early_stopping=True )[0].get("generated_text","").strip() cfg = None try: cfg = json.loads(out) except Exception: j = re.search(r"\{.*\}", out, re.DOTALL) if j: raw = j.group().replace("'", '"') cfg = json.loads(raw) if not cfg: raise ValueError("Unparseable param-generator output") length = cfg.get("length","").lower() tone = cfg.get("tone","").lower() min_w = cfg.get("min_words") max_w = cfg.get("max_words") if length not in ("short","medium","long"): words = len(text.split()) length = "short" if words < 150 else ("medium" if words < 800 else "long") if tone not in ("neutral","formal","casual","bullet"): tone = "neutral" if not isinstance(min_w,int): min_w = _first_int_from_text(out, fallback=None) if not isinstance(max_w,int): max_w = _first_int_from_text(out[::-1], fallback=None) defaults = {"short":(15,50),"medium":(50,130),"long":(130,300)} dmin,dmax = defaults.get(length,(50,130)) min_len = int(min_w) if isinstance(min_w,int) else dmin max_len = int(max_w) if isinstance(max_w,int) else dmax min_len = max(5, min(min_len, 2000)) max_len = max(min_len+5, min(max_len, 4000)) logger.info("Param-generator chose length=%s tone=%s min=%s max=%s", length, tone, min_len, max_len) return {"length":length,"min_length":min_len,"max_length":max_len,"tone":tone} except Exception as e: logger.exception("Param-generator failed: %s", e) words = len(text.split()) length = "short" if words < 150 else ("medium" if words < 800 else "long") fallback = {"short":(15,50),"medium":(50,130),"long":(130,300)} mn,mx = fallback[length] return {"length":length,"min_length":mn,"max_length":mx,"tone":"neutral"} # ------------------------- # Two-stage summarization (chunk -> chunk summaries -> refine) # ------------------------- def refine_and_combine(summaries_list, tone, final_target_sentences=None, summarizer_pipe=None): combined = "\n\n".join(summaries_list) if len(combined.split()) > 2000: combined = extractive_prefilter(combined, top_k=20) prompt = apply_tone_instruction(combined, tone, target_sentences=final_target_sentences) tgt_sent = final_target_sentences or 3 gen_kwargs = { "min_length": max(20, int(tgt_sent * 8)), "max_length": max(60, int(tgt_sent * 30)), "num_beams": 6, "early_stopping": True, "no_repeat_ngram_size": 3, "do_sample": False, } try: if summarizer_pipe is None: # fallback to pegasus by default (if pipe not provided) summarizer_pipe = get_summarizer("pegasus") out = summarizer_pipe(prompt, **gen_kwargs)[0]["summary_text"].strip() return out except Exception as e: logger.exception("Refine failed: %s", e) return " ".join(summaries_list[:3]) # ------------------------- # Model-specific generation helper # ------------------------- def summarize_with_model(pipe, text_prompt, short_target=False): """ Use model pipeline with conservative and model-appropriate generation settings. short_target: if True use shorter min/max suitable for concise outputs """ # heuristics: if pipe is LED (model name in tied tokenizer), allow larger max_length model_name = getattr(pipe.model.config, "name_or_path", "") or "" is_led = "led" in model_name or "longformer" in model_name or "allenai" in model_name and "led" in model_name if short_target: min_l = 12 max_l = 60 else: min_l = 24 max_l = 140 if not is_led else 400 # LED can handle longer outputs gen_kwargs = { "min_length": min_l, "max_length": max_l, "num_beams": 5 if not is_led else 4, "early_stopping": True, "no_repeat_ngram_size": 3, "do_sample": False, } return pipe(text_prompt, **gen_kwargs)[0]["summary_text"].strip() # ------------------------- # Routes # ------------------------- @app.route("/") def home(): return render_template("index.html") @app.route("/summarize", methods=["POST"]) def summarize_route(): t0 = time.time() data = request.get_json(force=True) or {} text = (data.get("text") or "")[:90000] user_model_pref = (data.get("model") or "auto").lower() # 'pegasus' | 'led' | 'auto' requested_length = (data.get("length") or "auto").lower() # short|medium|long|auto requested_tone = (data.get("tone") or "auto").lower() # neutral|formal|casual|bullet|auto if not text or len(text.split()) < 5: return jsonify({"error":"Input too short."}), 400 # 1) Decide settings (AI or explicit) if requested_length in ("auto","ai") or requested_tone in ("auto","ai"): cfg = generate_summarization_config(text) length_choice = cfg.get("length","medium") tone_choice = cfg.get("tone","neutral") preset_min = cfg.get("min_length") preset_max = cfg.get("max_length") else: length_choice = requested_length if requested_length in ("short","medium","long") else "medium" tone_choice = requested_tone if requested_tone in ("neutral","formal","casual","bullet") else "neutral" preset_min = LENGTH_PRESETS.get(length_choice, LENGTH_PRESETS["medium"])["min_length"] preset_max = LENGTH_PRESETS.get(length_choice, LENGTH_PRESETS["medium"])["max_length"] # 2) Model selection (user preference or auto) # auto rules: if user specifically asked 'led' or param-generator picked long / input is very long -> led words_len = len(text.split()) prefer_led = False if user_model_pref == "led": prefer_led = True elif user_model_pref == "pegasus": prefer_led = False else: # auto if length_choice == "long" or words_len > 3000: prefer_led = True else: prefer_led = False model_key = "led" if prefer_led else "pegasus" # get the pipeline (lazy load) try: summarizer_pipe = get_summarizer(model_key) except Exception as e: logger.exception("Failed to load summarizer '%s': %s", model_key, e) # fallback to pegasus if led fails summarizer_pipe = get_summarizer("pegasus") model_key = "pegasus" # 3) Prefilter very long inputs (if not using LED) if not prefer_led and words_len > 2500: text_for_chunks = extractive_prefilter(text, top_k=40) else: text_for_chunks = text # 4) Chunking: choose chunk size depending on model if model_key == "led": chunk_max_chars = 8000 # LED can handle larger chunks chunk_overlap = 400 else: chunk_max_chars = 1400 chunk_overlap = 200 chunks = chunk_text_by_chars(text_for_chunks, max_chars=chunk_max_chars, overlap=chunk_overlap) # 5) Summarize each chunk chunk_summaries = [] for chunk in chunks: chunk_target = 1 if length_choice == "short" else 2 chunk_tone = tone_choice if tone_choice in ("formal","casual","bullet") else "neutral" prompt = apply_tone_instruction(chunk, chunk_tone, target_sentences=chunk_target) try: # choose short_target True for tiny chunk summaries out = summarize_with_model(summarizer_pipe, prompt, short_target=(chunk_target==1)) except Exception as e: logger.exception("Chunk summarization failed, using extractive fallback: %s", e) out = extractive_prefilter(chunk, top_k=3) chunk_summaries.append(out) # 6) Combine + refine using the same model for consistency (or prefer Pegasus for elegant refinement) refine_model_key = model_key if model_key == "led" else "pegasus" refine_pipe = get_summarizer(refine_model_key) final_target_sentences = {"short":1,"medium":3,"long":6}.get(length_choice,3) final = refine_and_combine(chunk_summaries, tone_choice, final_target_sentences, summarizer_pipe=refine_pipe) # 7) Post-process bullet tone if tone_choice == "bullet": parts = re.split(r'[\n\r]+|(?:\.\s+)|(?:;\s+)', final) bullets = [f"- {p.strip().rstrip('.')}" for p in parts if p.strip()] final = "\n".join(bullets[:20]) elapsed = time.time() - t0 meta = { "length_choice": length_choice, "tone": tone_choice, "model_used": model_key, "refine_model": refine_model_key, "chunks": len(chunks), "input_words": words_len, "time_seconds": round(elapsed, 2), "device": ("gpu" if USE_GPU else "cpu") } return jsonify({"summary": final, "meta": meta}) # ------------------------- # Run # ------------------------- if __name__ == "__main__": # debug=False for production; use Gunicorn in deployment app.run(host="0.0.0.0", port=7860, debug=False)