import os
os.environ['HF_HOME'] = '/tmp'

import time
import json
import re
import logging
from collections import Counter
from typing import Optional, Dict, Any, List
from concurrent.futures import ThreadPoolExecutor, as_completed, TimeoutError

from flask import Flask, request, jsonify, render_template
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline

# -------------------------
# App + logging
# -------------------------
app = Flask(__name__)
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("summarizer")

# -------------------------
# Device selection (CPU by default)
# -------------------------
USE_GPU = False
DEVICE = -1
logger.info("Startup: forcing CPU usage for models (DEVICE=%s)", DEVICE)

# -------------------------
# Model names and caches
# -------------------------
PEGASUS_MODEL = "google/pegasus-large"
LED_MODEL = "allenai/led-large-16384"
DISTILBART_MODEL = "sshleifer/distilbart-cnn-12-6"
PARAM_MODEL = "google/flan-t5-small"

_SUMMARIZER_CACHE: Dict[str, Any] = {}
_PARAM_GENERATOR = None
_PREFERRED_SUMMARIZER_KEY: Optional[str] = None

# -------------------------
# Utilities: chunking, extractive fallback
# -------------------------
_STOPWORDS = {
    "the", "and", "is", "in", "to", "of", "a", "that", "it", "on", "for", "as",
    "are", "with", "was", "be", "by", "this", "an", "or", "from", "at", "which",
    "we", "has", "have"
}


def tokenize_sentences(text: str) -> List[str]:
    sents = re.split(r'(?<=[.!?])\s+', text.strip())
    return [s.strip() for s in sents if s.strip()]


def extractive_prefilter(text: str, top_k: int = 6) -> str:
    """Pick the top_k highest-scoring sentences (word-frequency score), kept in original order."""
    sents = tokenize_sentences(text)
    if len(sents) <= top_k:
        return text
    words = re.findall(r"\w+", text.lower())
    freqs = Counter(w for w in words if w not in _STOPWORDS)
    scored = []
    for i, s in enumerate(sents):
        ws = re.findall(r"\w+", s.lower())
        score = sum(freqs.get(w, 0) for w in ws)
        scored.append((score, i, s))
    scored.sort(reverse=True)
    chosen = [s for _, _, s in sorted(scored[:top_k], key=lambda t: t[1])]
    return " ".join(chosen)


def chunk_text_by_chars(text: str, max_chars: int = 800, overlap: int = 120) -> List[str]:
    """Split text into overlapping character windows, preferring to break at newlines."""
    n = len(text)
    if n <= max_chars:
        return [text]
    parts = []
    start = 0
    prev_start = -1
    while start < n and start != prev_start:
        prev_start = start
        end = min(n, start + max_chars)
        chunk = text[start:end]
        nl = chunk.rfind('\n')
        if nl > int(max_chars * 0.6):
            end = start + nl
            chunk = text[start:end]
        parts.append(chunk.strip())
        start = end - overlap
        if start <= prev_start:
            start = end
    return [p for p in parts if p]


# -------------------------
# safe loader (defined before any calls)
# -------------------------
def safe_load_pipeline(model_name: str):
    """
    Try to load a summarization pipeline robustly:
    - try fast tokenizer first
    - if that fails, try use_fast=False
    - return pipeline or None if both fail
    """
    try:
        logger.info("Loading tokenizer/model for %s (fast)...", model_name)
        tok = AutoTokenizer.from_pretrained(model_name)
        model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
        pipe = pipeline("summarization", model=model, tokenizer=tok, device=DEVICE)
        logger.info("Loaded %s (fast tokenizer)", model_name)
        return pipe
    except Exception as e_fast:
        logger.warning("Fast tokenizer failed for %s: %s. Trying slow tokenizer...",
                       model_name, e_fast)
        try:
            tok = AutoTokenizer.from_pretrained(model_name, use_fast=False)
            model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
            pipe = pipeline("summarization", model=model, tokenizer=tok, device=DEVICE)
            logger.info("Loaded %s (slow tokenizer)", model_name)
            return pipe
        except Exception as e_slow:
            logger.exception("Slow tokenizer failed for %s: %s", model_name, e_slow)
            return None


# -------------------------
# get_summarizer: lazy load + cache + fallback
# -------------------------
def get_summarizer(key: str):
    """
    key: 'pegasus'|'led'|'distilbart'|'auto'
    returns a pipeline (cached), or raises RuntimeError if no pipeline can be loaded.
    """
    key = (key or "auto").lower()
    if key == "auto":
        key = _PREFERRED_SUMMARIZER_KEY or "distilbart"

    # direct mapping
    model_name = {
        "pegasus": PEGASUS_MODEL,
        "led": LED_MODEL,
        "distilbart": DISTILBART_MODEL
    }.get(key, DISTILBART_MODEL)

    if key in _SUMMARIZER_CACHE:
        return _SUMMARIZER_CACHE[key]

    # try to load
    logger.info("Attempting to lazy-load summarizer '%s' -> %s", key, model_name)
    pipe = safe_load_pipeline(model_name)
    if pipe:
        _SUMMARIZER_CACHE[key] = pipe
        return pipe

    # fallback attempts
    logger.warning("Failed to load %s. Trying distilbart fallback.", key)
    if "distilbart" in _SUMMARIZER_CACHE:
        return _SUMMARIZER_CACHE["distilbart"]
    fb = safe_load_pipeline(DISTILBART_MODEL)
    if fb:
        _SUMMARIZER_CACHE["distilbart"] = fb
        return fb

    # nothing works
    raise RuntimeError("No summarizer available. Install required libraries and/or choose a smaller model.")


# -------------------------
# Generation strategy + small helpers
# -------------------------
def summarize_with_model(pipe, text_prompt: str, short_target: bool = False) -> str:
    model_name = getattr(pipe.model.config, "name_or_path", "") or ""
    is_led = "led" in model_name.lower() or "longformer" in model_name.lower()

    # Fast sampling pass
    fast_cfg = {
        "max_new_tokens": 64 if short_target else (120 if not is_led else 240),
        "do_sample": True,
        "top_p": 0.92,
        "temperature": 0.85,
        "num_beams": 1,
        "early_stopping": True,
        "no_repeat_ngram_size": 3,
    }
    try:
        return pipe(text_prompt, **fast_cfg)[0].get("summary_text", "").strip()
    except Exception as e:
        logger.warning("Fast pass failed: %s, trying quality pass...", e)
        quality_cfg = {
            "max_new_tokens": 140 if not is_led else 320,
            "do_sample": False,
            "num_beams": 3,
            "early_stopping": True,
            "no_repeat_ngram_size": 3,
        }
        try:
            return pipe(text_prompt, **quality_cfg)[0].get("summary_text", "").strip()
        except Exception as e:
            logger.exception("Quality pass failed: %s", e)
            # fallback extractive
            try:
                return extractive_prefilter(text_prompt, top_k=3)
            except Exception:
                return "Summarization failed; try shorter input."
# -------------------------
# Param generator (AI decision) - lazy loader
# -------------------------
def get_param_generator():
    global _PARAM_GENERATOR
    if _PARAM_GENERATOR is not None:
        return _PARAM_GENERATOR
    # try to load text2text pipeline for PARAM_MODEL
    try:
        logger.info("Loading param-generator (text2text) lazily: %s", PARAM_MODEL)
        tok = AutoTokenizer.from_pretrained(PARAM_MODEL)
        mod = AutoModelForSeq2SeqLM.from_pretrained(PARAM_MODEL)
        _PARAM_GENERATOR = pipeline("text2text-generation", model=mod, tokenizer=tok, device=DEVICE)
        logger.info("Param-generator loaded.")
        return _PARAM_GENERATOR
    except Exception as e:
        logger.exception("Param-generator lazy load failed: %s", e)
        _PARAM_GENERATOR = None
        return None


def generate_summarization_config(text: str) -> Dict[str, Any]:
    defaults = {"short": (12, 50), "medium": (50, 130), "long": (130, 300)}
    pg = get_param_generator()
    if pg is None:
        words = len(text.split())
        length = "short" if words < 150 else ("medium" if words < 800 else "long")
        mn, mx = defaults[length]
        return {"length": length, "min_length": mn, "max_length": mx, "tone": "neutral"}

    prompt = (
        "Recommend summarization settings for this text. Answer ONLY with JSON like:\n"
        '{"length":"short|medium|long","tone":"neutral|formal|casual|bullet","min_words":MIN,"max_words":MAX}\n\n'
        "Text:\n'''" + text[:3000] + "'''"
    )
    try:
        out_item = pg(prompt, max_new_tokens=64, do_sample=False, num_beams=1)[0]
        out = out_item.get("generated_text") or out_item.get("summary_text") or ""
        out = (out or "").strip()
        if not out:
            raise ValueError("Empty param-generator output")

        # reject noisy echo outputs
        input_words = set(w.lower() for w in re.findall(r"\w+", text)[:200])
        out_words = set(w.lower() for w in re.findall(r"\w+", out)[:200])
        if len(input_words) and (len(input_words & out_words) / max(1, len(input_words))) > 0.4:
            logger.warning("Param-generator appears to echo input; using heuristic fallback.")
            words = len(text.split())
            length = "short" if words < 150 else ("medium" if words < 800 else "long")
            mn, mx = defaults[length]
            return {"length": length, "min_length": mn, "max_length": mx, "tone": "neutral"}

        jmatch = re.search(r"\{.*\}", out, re.DOTALL)
        if jmatch:
            raw = jmatch.group().replace("'", '"')
            cfg = json.loads(raw)
        else:
            cfg = None
        if not cfg or not isinstance(cfg, dict):
            raise ValueError("Param output not parseable")

        length = cfg.get("length", "medium").lower()
        if length not in defaults:
            length = "medium"
        tone = cfg.get("tone", "neutral").lower()
        mn = int(cfg.get("min_words") or cfg.get("min_length") or defaults[length][0])
        mx = int(cfg.get("max_words") or cfg.get("max_length") or defaults[length][1])
        mn = max(5, min(mn, 2000))
        mx = max(mn + 5, min(mx, 4000))
        return {"length": length, "min_length": mn, "max_length": mx, "tone": tone}
    except Exception as e:
        logger.exception("Param-generator parse failed: %s", e)
        words = len(text.split())
        length = "short" if words < 150 else ("medium" if words < 800 else "long")
        mn, mx = defaults[length]
        return {"length": length, "min_length": mn, "max_length": mx, "tone": "neutral"}


# -------------------------
# Threaded chunk summarization with an overall time budget (to prevent hangs)
# -------------------------
executor = ThreadPoolExecutor(max_workers=min(8, max(2, (os.cpu_count() or 2))))
CHUNK_TIMEOUT_SECONDS = 28
REFINE_TIMEOUT_SECONDS = 60


def summarize_chunks_parallel(pipe, chunks: List[str], chunk_target: int) -> List[str]:
    futures = {}
    results = [None] * len(chunks)
    for idx, chunk in enumerate(chunks):
        prompt = apply_tone_instruction(chunk, "neutral", target_sentences=chunk_target)
        fut = executor.submit(summarize_with_model, pipe, prompt, short_target=(chunk_target == 1))
        futures[fut] = idx

    # as_completed() only yields futures that have already finished, so a per-future
    # result(timeout=...) can never fire; put the time budget on as_completed itself
    # and let unfinished chunks fall through to the extractive fallback below.
    try:
        for fut in as_completed(futures, timeout=CHUNK_TIMEOUT_SECONDS):
            idx = futures[fut]
            try:
                results[idx] = fut.result()
            except Exception as e:
                logger.exception("Chunk %d failed: %s; falling back", idx, e)
                results[idx] = extractive_prefilter(chunks[idx], top_k=3)
    except TimeoutError:
        logger.warning("Chunk summarization timed out; using extractive fallback for unfinished chunks.")

    for i, r in enumerate(results):
        if not r:
            results[i] = extractive_prefilter(chunks[i], top_k=3)
    return results


# -------------------------
# Prompt helpers and refine
# -------------------------
def apply_tone_instruction(text: str, tone: str, target_sentences: Optional[int] = None) -> str:
    tone = (tone or "neutral").lower()
    if tone == "bullet":
        instr = "Produce concise bullet points. Each bullet <= 20 words. No extra commentary."
    elif tone == "short":
        ts = target_sentences or 1
        instr = f"Summarize in {ts} sentence{'s' if ts > 1 else ''}. Be abstractive."
    elif tone == "formal":
        instr = "Summarize in a formal, professional tone (2-4 sentences)."
    elif tone == "casual":
        instr = "Summarize in a casual, conversational tone (1-3 sentences)."
    elif tone == "long":
        instr = "Provide a structured summary (4-8 sentences)."
    else:
        instr = "Summarize in 2-3 clear sentences."
    instr += " Do not repeat information. Prefer rephrasing."
    return f"{instr}\n\nText:\n{text}"


def refine_combined(pipe, summaries_list: List[str], tone: str, final_target_sentences: int) -> str:
    combined = "\n\n".join(summaries_list)
    if len(combined.split()) > 1200:
        combined = extractive_prefilter(combined, top_k=20)
    prompt = apply_tone_instruction(combined, tone, target_sentences=final_target_sentences)
    fut = executor.submit(summarize_with_model, pipe, prompt, short_target=False)
    try:
        return fut.result(timeout=REFINE_TIMEOUT_SECONDS)
    except TimeoutError:
        logger.warning("Refine timed out; returning concatenated chunk summaries.")
        return " ".join(summaries_list[:6])
    except Exception as e:
        logger.exception("Refine failed: %s", e)
        return " ".join(summaries_list[:6])


# -------------------------
# Routes
# -------------------------
@app.route("/", methods=["GET"])
def home():
    try:
        return render_template("index.html")
    except Exception:
        return "Summarizer (lazy-load) - POST /summarize with JSON {text:'...'}", 200


@app.route("/preload", methods=["POST"])
def preload_models():
    """
    Explicit endpoint to attempt preloading heavy models.
    Call this only when you want the process to attempt loading Pegasus/LED (may be slow).
""" results = {} for key, model_name in [("pegasus", PEGASUS_MODEL), ("led", LED_MODEL), ("distilbart", DISTILBART_MODEL)]: if key in _SUMMARIZER_CACHE: results[key] = "already_loaded" continue try: p = safe_load_pipeline(model_name) if p: _SUMMARIZER_CACHE[key] = p results[key] = "loaded" else: results[key] = "failed" except Exception as e: results[key] = f"error: {e}" return jsonify(results) @app.route("/summarize", methods=["POST"]) def summarize_route(): t0 = time.time() data = request.get_json(force=True) or {} text = (data.get("text") or "").strip()[:90000] user_model_pref = (data.get("model") or "auto").lower() requested_length = (data.get("length") or "auto").lower() requested_tone = (data.get("tone") or "auto").lower() if not text or len(text.split()) < 5: return jsonify({"error": "Input too short."}), 400 # decide settings if requested_length in ("auto","ai") or requested_tone in ("auto","ai"): cfg = generate_summarization_config(text) length_choice = cfg.get("length","medium") tone_choice = cfg.get("tone","neutral") else: length_choice = requested_length if requested_length in ("short","medium","long") else "medium" tone_choice = requested_tone if requested_tone in ("neutral","formal","casual","bullet") else "neutral" # model selection logic words_len = len(text.split()) prefer_led = False if user_model_pref == "led": prefer_led = True elif user_model_pref == "pegasus": prefer_led = False else: if length_choice == "long" or words_len > 3000: prefer_led = True model_key = "led" if prefer_led else (_PREFERRED_SUMMARIZER_KEY or "distilbart") try: summarizer_pipe = get_summarizer(model_key) except Exception as e: logger.exception("get_summarizer failed (%s). Falling back to distilbart.", e) summarizer_pipe = get_summarizer("distilbart") model_key = "distilbart" # prefilter very long inputs for non-LED if model_key != "led" and words_len > 2500: text_for_chunks = extractive_prefilter(text, top_k=40) else: text_for_chunks = text # chunk sizing if model_key == "led": chunk_max = 6000 overlap = 400 else: chunk_max = 800 overlap = 120 chunks = chunk_text_by_chars(text_for_chunks, max_chars=chunk_max, overlap=overlap) chunk_target = 1 if length_choice == "short" else 2 # summarize chunks in parallel try: chunk_summaries = summarize_chunks_parallel(summarizer_pipe, chunks, chunk_target) except Exception as e: logger.exception("Chunk summarization orchestration failed: %s", e) chunk_summaries = [extractive_prefilter(c, top_k=3) for c in chunks] # refine step — prefer Pegasus if loaded, otherwise use current pipe refine_pipe = _SUMMARIZER_CACHE.get("pegasus") or summarizer_pipe final_target_sentences = {"short":1,"medium":3,"long":6}.get(length_choice, 3) final = refine_combined(refine_pipe, chunk_summaries, tone_choice, final_target_sentences) # bullet postprocess if tone_choice == "bullet": parts = re.split(r'[\n\r]+|(?:\.\s+)|(?:;\s+)', final) bullets = [f"- {p.strip().rstrip('.')}" for p in parts if p.strip()] final = "\n".join(bullets[:20]) elapsed = time.time() - t0 meta = { "length_choice": length_choice, "tone": tone_choice, "model_used": model_key, "chunks": len(chunks), "input_words": words_len, "time_seconds": round(elapsed, 2), "device": "cpu" } return jsonify({"summary": final, "meta": meta}) # ------------------------- # Local run (safe) # ------------------------- if __name__ == "__main__": # For local testing you may call preload_models_at_startup manually or use /preload. app.run(host="0.0.0.0", port=7860, debug=False)