import os
os.environ['HF_HOME'] = '/tmp'  # set the HF cache dir before transformers is imported
import time
import json
import re
import logging
from collections import Counter
from typing import Optional, Dict, Any, List
from concurrent.futures import ThreadPoolExecutor, as_completed, TimeoutError
from flask import Flask, request, jsonify, render_template
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
# -------------------------
# App + logging
# -------------------------
app = Flask(__name__)
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("summarizer")

# -------------------------
# Device selection (CPU by default)
# -------------------------
USE_GPU = False
DEVICE = -1  # device=-1 tells transformers pipelines to run on CPU
logger.info("Startup: forcing CPU usage for models (DEVICE=%s)", DEVICE)
# -------------------------
# Model names and caches
# -------------------------
PEGASUS_MODEL = "google/pegasus-large"
LED_MODEL = "allenai/led-large-16384"
DISTILBART_MODEL = "sshleifer/distilbart-cnn-12-6"
PARAM_MODEL = "google/flan-t5-small"

_SUMMARIZER_CACHE: Dict[str, Any] = {}
_PARAM_GENERATOR = None
_PREFERRED_SUMMARIZER_KEY: Optional[str] = None

# -------------------------
# Utilities: chunking, extractive fallback
# -------------------------
_STOPWORDS = {
    "the", "and", "is", "in", "to", "of", "a", "that", "it", "on", "for", "as",
    "are", "with", "was", "be", "by", "this", "an", "or", "from", "at", "which",
    "we", "has", "have",
}

def tokenize_sentences(text: str) -> List[str]:
    sents = re.split(r'(?<=[.!?])\s+', text.strip())
    return [s.strip() for s in sents if s.strip()]

def extractive_prefilter(text: str, top_k: int = 6) -> str:
    """Frequency-based extractive fallback: keep the top_k highest-scoring sentences."""
    sents = tokenize_sentences(text)
    if len(sents) <= top_k:
        return text
    words = re.findall(r"\w+", text.lower())
    freqs = Counter(w for w in words if w not in _STOPWORDS)
    scored = []
    for i, s in enumerate(sents):
        ws = re.findall(r"\w+", s.lower())
        # Score each sentence by the summed document-wide frequency of its words.
        score = sum(freqs.get(w, 0) for w in ws)
        scored.append((score, i, s))
    scored.sort(reverse=True)
    # Re-emit the chosen sentences in their original document order.
    chosen = [s for _, _, s in sorted(scored[:top_k], key=lambda t: t[1])]
    return " ".join(chosen)

def chunk_text_by_chars(text: str, max_chars: int = 800, overlap: int = 120) -> List[str]:
    n = len(text)
    if n <= max_chars:
        return [text]
    parts = []
    start = 0
    prev_start = -1
    while start < n and start != prev_start:
        prev_start = start
        end = min(n, start + max_chars)
        chunk = text[start:end]
        # Prefer to break at a newline if one falls in the last 40% of the chunk.
        nl = chunk.rfind('\n')
        if nl > int(max_chars * 0.6):
            end = start + nl
            chunk = text[start:end]
        parts.append(chunk.strip())
        if end >= n:
            # Reached the end of the text; stop here rather than re-emitting
            # the final overlap region as a duplicate tail chunk.
            break
        start = end - overlap
        if start <= prev_start:
            start = end
    return [p for p in parts if p]
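# Illustrative: a 2,000-char string with the defaults (max_chars=800,
# overlap=120) and no late newlines yields three chunks covering
# [0:800], [680:1480], and [1360:2000]; each chunk restarts 120 chars before
# the previous end, so sentences straddling a boundary appear in full.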

# -------------------------
# Safe loader (defined before any calls)
# -------------------------
def safe_load_pipeline(model_name: str):
    """
    Try to load a summarization pipeline robustly:
    - try fast tokenizer first
    - if that fails, try use_fast=False
    - return pipeline or None if both fail
    """
    try:
        logger.info("Loading tokenizer/model for %s (fast)...", model_name)
        tok = AutoTokenizer.from_pretrained(model_name)
        model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
        pipe = pipeline("summarization", model=model, tokenizer=tok, device=DEVICE)
        logger.info("Loaded %s (fast tokenizer)", model_name)
        return pipe
    except Exception as e_fast:
        logger.warning("Fast tokenizer failed for %s: %s. Trying slow tokenizer...", model_name, e_fast)
        try:
            tok = AutoTokenizer.from_pretrained(model_name, use_fast=False)
            model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
            pipe = pipeline("summarization", model=model, tokenizer=tok, device=DEVICE)
            logger.info("Loaded %s (slow tokenizer)", model_name)
            return pipe
        except Exception as e_slow:
            logger.exception("Slow tokenizer failed for %s: %s", model_name, e_slow)
            return None
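# Illustrative: because a failed load returns None instead of raising, callers
# can chain fallbacks directly, e.g.
#   pipe = safe_load_pipeline(PEGASUS_MODEL) or safe_load_pipeline(DISTILBART_MODEL)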

# -------------------------
# get_summarizer: lazy load + cache + fallback
# -------------------------
def get_summarizer(key: str):
    """
    key: 'pegasus' | 'led' | 'distilbart' | 'auto'
    Returns a pipeline (cached), or raises RuntimeError if no pipeline can be loaded.
    """
    key = (key or "auto").lower()
    if key == "auto":
        key = _PREFERRED_SUMMARIZER_KEY or "distilbart"
    # direct mapping
    model_name = {
        "pegasus": PEGASUS_MODEL,
        "led": LED_MODEL,
        "distilbart": DISTILBART_MODEL,
    }.get(key, DISTILBART_MODEL)
    if key in _SUMMARIZER_CACHE:
        return _SUMMARIZER_CACHE[key]
    # try to load
    logger.info("Attempting to lazy-load summarizer '%s' -> %s", key, model_name)
    pipe = safe_load_pipeline(model_name)
    if pipe:
        _SUMMARIZER_CACHE[key] = pipe
        return pipe
    # fallback attempts
    logger.warning("Failed to load %s. Trying distilbart fallback.", key)
    if "distilbart" in _SUMMARIZER_CACHE:
        return _SUMMARIZER_CACHE["distilbart"]
    fb = safe_load_pipeline(DISTILBART_MODEL)
    if fb:
        _SUMMARIZER_CACHE["distilbart"] = fb
        return fb
    # nothing works
    raise RuntimeError("No summarizer available. Install the required libraries and/or choose a smaller model.")
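# Illustrative: get_summarizer("auto") resolves to _PREFERRED_SUMMARIZER_KEY
# (or "distilbart"), loads that model on first use, and serves it from
# _SUMMARIZER_CACHE afterwards, so only the first request pays the load cost.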

# -------------------------
# Generation strategy + small helpers
# -------------------------
def summarize_with_model(pipe, text_prompt: str, short_target: bool = False) -> str:
    model_name = getattr(pipe.model.config, "name_or_path", "") or ""
    is_led = "led" in model_name.lower() or "longformer" in model_name.lower()
    # Fast sampling pass (beam-search options like early_stopping don't apply here)
    fast_cfg = {
        "max_new_tokens": 64 if short_target else (120 if not is_led else 240),
        "do_sample": True,
        "top_p": 0.92,
        "temperature": 0.85,
        "no_repeat_ngram_size": 3,
    }
    try:
        return pipe(text_prompt, **fast_cfg)[0].get("summary_text", "").strip()
    except Exception as e:
        logger.warning("Fast pass failed: %s, trying quality pass...", e)
    # Slower deterministic beam-search pass
    quality_cfg = {
        "max_new_tokens": 140 if not is_led else 320,
        "do_sample": False,
        "num_beams": 3,
        "early_stopping": True,
        "no_repeat_ngram_size": 3,
    }
    try:
        return pipe(text_prompt, **quality_cfg)[0].get("summary_text", "").strip()
    except Exception as e:
        logger.exception("Quality pass failed: %s", e)
    # fallback extractive
    try:
        return extractive_prefilter(text_prompt, top_k=3)
    except Exception:
        return "Summarization failed; try a shorter input."

# -------------------------
# Param generator (AI decision) - lazy loader
# -------------------------
def get_param_generator():
    global _PARAM_GENERATOR
    if _PARAM_GENERATOR is not None:
        return _PARAM_GENERATOR
    # try to load text2text pipeline for PARAM_MODEL
    try:
        logger.info("Loading param-generator (text2text) lazily: %s", PARAM_MODEL)
        tok = AutoTokenizer.from_pretrained(PARAM_MODEL)
        mod = AutoModelForSeq2SeqLM.from_pretrained(PARAM_MODEL)
        _PARAM_GENERATOR = pipeline("text2text-generation", model=mod, tokenizer=tok, device=DEVICE)
        logger.info("Param-generator loaded.")
        return _PARAM_GENERATOR
    except Exception as e:
        logger.exception("Param-generator lazy load failed: %s", e)
        _PARAM_GENERATOR = None
        return None

def generate_summarization_config(text: str) -> Dict[str, Any]:
    defaults = {"short": (12, 50), "medium": (50, 130), "long": (130, 300)}

    def _heuristic() -> Dict[str, Any]:
        # Word-count heuristic used whenever the param-generator is unavailable
        # or its output cannot be trusted.
        words = len(text.split())
        length = "short" if words < 150 else ("medium" if words < 800 else "long")
        mn, mx = defaults[length]
        return {"length": length, "min_length": mn, "max_length": mx, "tone": "neutral"}

    pg = get_param_generator()
    if pg is None:
        return _heuristic()
    prompt = (
        "Recommend summarization settings for this text. Answer ONLY with JSON like:\n"
        '{"length":"short|medium|long","tone":"neutral|formal|casual|bullet","min_words":MIN,"max_words":MAX}\n\n'
        "Text:\n'''"
        + text[:3000] + "'''"
    )
    try:
        out_item = pg(prompt, max_new_tokens=64, do_sample=False, num_beams=1)[0]
        out = (out_item.get("generated_text") or out_item.get("summary_text") or "").strip()
        if not out:
            raise ValueError("Empty param-generator output")
        # reject noisy echo outputs
        input_words = set(w.lower() for w in re.findall(r"\w+", text)[:200])
        out_words = set(w.lower() for w in re.findall(r"\w+", out)[:200])
        if len(input_words) and (len(input_words & out_words) / max(1, len(input_words))) > 0.4:
            logger.warning("Param-generator appears to echo input; using heuristic fallback.")
            return _heuristic()
        jmatch = re.search(r"\{.*\}", out, re.DOTALL)
        cfg = json.loads(jmatch.group().replace("'", '"')) if jmatch else None
        if not cfg or not isinstance(cfg, dict):
            raise ValueError("Param output not parseable")
        length = cfg.get("length", "medium").lower()
        if length not in defaults:
            length = "medium"  # guard against unexpected labels from the model
        tone = cfg.get("tone", "neutral").lower()
        mn = int(cfg.get("min_words") or cfg.get("min_length") or defaults[length][0])
        mx = int(cfg.get("max_words") or cfg.get("max_length") or defaults[length][1])
        mn = max(5, min(mn, 2000))
        mx = max(mn + 5, min(mx, 4000))
        return {"length": length, "min_length": mn, "max_length": mx, "tone": tone}
    except Exception as e:
        logger.exception("Param-generator parse failed: %s", e)
        return _heuristic()

# -------------------------
# Threaded chunk summarization with a shared timeout budget (to prevent hangs)
# -------------------------
executor = ThreadPoolExecutor(max_workers=min(8, max(2, (os.cpu_count() or 2))))
CHUNK_TIMEOUT_SECONDS = 28
REFINE_TIMEOUT_SECONDS = 60

def summarize_chunks_parallel(pipe, chunks: List[str], chunk_target: int) -> List[str]:
    futures = {}
    results: List[Optional[str]] = [None] * len(chunks)
    for idx, chunk in enumerate(chunks):
        prompt = apply_tone_instruction(chunk, "neutral", target_sentences=chunk_target)
        fut = executor.submit(summarize_with_model, pipe, prompt, short_target=(chunk_target == 1))
        futures[fut] = idx
    try:
        # All chunks share one CHUNK_TIMEOUT_SECONDS budget. Passing the timeout
        # to as_completed() is what actually enforces it; calling fut.result()
        # on a future that as_completed() has already yielded can never time out.
        for fut in as_completed(futures, timeout=CHUNK_TIMEOUT_SECONDS):
            idx = futures[fut]
            try:
                results[idx] = fut.result()
            except Exception as e:
                logger.exception("Chunk %d failed: %s; falling back", idx, e)
                results[idx] = extractive_prefilter(chunks[idx], top_k=3)
    except TimeoutError:
        logger.warning("Chunk summarization timed out; using extractive fallback for unfinished chunks.")
    for i, r in enumerate(results):
        if not r:
            results[i] = extractive_prefilter(chunks[i], top_k=3)
    return results

# -------------------------
# Prompt helpers and refine
# -------------------------
def apply_tone_instruction(text: str, tone: str, target_sentences: Optional[int] = None) -> str:
    tone = (tone or "neutral").lower()
    if tone == "bullet":
        instr = "Produce concise bullet points. Each bullet <= 20 words. No extra commentary."
    elif tone == "short":
        ts = target_sentences or 1
        instr = f"Summarize in {ts} sentence{'s' if ts > 1 else ''}. Be abstractive."
    elif tone == "formal":
        instr = "Summarize in a formal, professional tone (2-4 sentences)."
    elif tone == "casual":
        instr = "Summarize in a casual, conversational tone (1-3 sentences)."
    elif tone == "long":
        instr = "Provide a structured summary (4-8 sentences)."
    else:
        instr = "Summarize in 2-3 clear sentences."
    instr += " Do not repeat information. Prefer rephrasing."
    return f"{instr}\n\nText:\n{text}"

def refine_combined(pipe, summaries_list: List[str], tone: str, final_target_sentences: int) -> str:
    combined = "\n\n".join(summaries_list)
    if len(combined.split()) > 1200:
        combined = extractive_prefilter(combined, top_k=20)
    prompt = apply_tone_instruction(combined, tone, target_sentences=final_target_sentences)
    fut = executor.submit(summarize_with_model, pipe, prompt, short_target=False)
    try:
        return fut.result(timeout=REFINE_TIMEOUT_SECONDS)
    except TimeoutError:
        logger.warning("Refine timed out; returning concatenated chunk summaries.")
        return " ".join(summaries_list[:6])
    except Exception as e:
        logger.exception("Refine failed: %s", e)
        return " ".join(summaries_list[:6])

# -------------------------
# Routes
# -------------------------
@app.route("/")
def home():
    try:
        return render_template("index.html")
    except Exception:
        return "Summarizer (lazy-load) — POST /summarize with JSON {text:'...'}", 200

@app.route("/preload", methods=["POST"])  # path per the local-run note at the bottom of this file
def preload_models():
    """
    Explicit endpoint to attempt preloading heavy models.
    Call this only when you want the process to attempt loading Pegasus/LED (may be slow).
    """
    results = {}
    for key, model_name in [("pegasus", PEGASUS_MODEL), ("led", LED_MODEL), ("distilbart", DISTILBART_MODEL)]:
        if key in _SUMMARIZER_CACHE:
            results[key] = "already_loaded"
            continue
        try:
            p = safe_load_pipeline(model_name)
            if p:
                _SUMMARIZER_CACHE[key] = p
                results[key] = "loaded"
            else:
                results[key] = "failed"
        except Exception as e:
            results[key] = f"error: {e}"
    return jsonify(results)
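# Illustrative call (route path as declared above; may be slow on first use):
#   curl -X POST http://localhost:7860/preload
# returning per-model statuses such as
#   {"pegasus": "loaded", "led": "failed", "distilbart": "already_loaded"}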

@app.route("/summarize", methods=["POST"])
def summarize_route():
    t0 = time.time()
    data = request.get_json(force=True) or {}
    text = (data.get("text") or "").strip()[:90000]
    user_model_pref = (data.get("model") or "auto").lower()
    requested_length = (data.get("length") or "auto").lower()
    requested_tone = (data.get("tone") or "auto").lower()
    if not text or len(text.split()) < 5:
        return jsonify({"error": "Input too short."}), 400
    # decide settings
    if requested_length in ("auto", "ai") or requested_tone in ("auto", "ai"):
        cfg = generate_summarization_config(text)
        length_choice = cfg.get("length", "medium")
        tone_choice = cfg.get("tone", "neutral")
    else:
        length_choice = requested_length if requested_length in ("short", "medium", "long") else "medium"
        tone_choice = requested_tone if requested_tone in ("neutral", "formal", "casual", "bullet") else "neutral"
    # model selection logic
    words_len = len(text.split())
    prefer_led = False
    if user_model_pref == "led":
        prefer_led = True
    elif user_model_pref == "pegasus":
        prefer_led = False
    else:
        if length_choice == "long" or words_len > 3000:
            prefer_led = True
    model_key = "led" if prefer_led else (_PREFERRED_SUMMARIZER_KEY or "distilbart")
    try:
        summarizer_pipe = get_summarizer(model_key)
    except Exception as e:
        logger.exception("get_summarizer failed (%s). Falling back to distilbart.", e)
        summarizer_pipe = get_summarizer("distilbart")
        model_key = "distilbart"
    # prefilter very long inputs for non-LED models
    if model_key != "led" and words_len > 2500:
        text_for_chunks = extractive_prefilter(text, top_k=40)
    else:
        text_for_chunks = text
    # chunk sizing: LED handles much longer inputs per chunk
    if model_key == "led":
        chunk_max = 6000
        overlap = 400
    else:
        chunk_max = 800
        overlap = 120
    chunks = chunk_text_by_chars(text_for_chunks, max_chars=chunk_max, overlap=overlap)
    chunk_target = 1 if length_choice == "short" else 2
    # summarize chunks in parallel
    try:
        chunk_summaries = summarize_chunks_parallel(summarizer_pipe, chunks, chunk_target)
    except Exception as e:
        logger.exception("Chunk summarization orchestration failed: %s", e)
        chunk_summaries = [extractive_prefilter(c, top_k=3) for c in chunks]
    # refine step — prefer Pegasus if loaded, otherwise use the current pipe
    refine_pipe = _SUMMARIZER_CACHE.get("pegasus") or summarizer_pipe
    final_target_sentences = {"short": 1, "medium": 3, "long": 6}.get(length_choice, 3)
    final = refine_combined(refine_pipe, chunk_summaries, tone_choice, final_target_sentences)
    # bullet postprocess
    if tone_choice == "bullet":
        parts = re.split(r'[\n\r]+|(?:\.\s+)|(?:;\s+)', final)
        bullets = [f"- {p.strip().rstrip('.')}" for p in parts if p.strip()]
        final = "\n".join(bullets[:20])
    elapsed = time.time() - t0
    meta = {
        "length_choice": length_choice,
        "tone": tone_choice,
        "model_used": model_key,
        "chunks": len(chunks),
        "input_words": words_len,
        "time_seconds": round(elapsed, 2),
        "device": "cpu",
    }
    return jsonify({"summary": final, "meta": meta})

# -------------------------
# Local run (safe)
# -------------------------
if __name__ == "__main__":
    # For local testing you may call preload_models() manually or use the /preload endpoint.
    app.run(host="0.0.0.0", port=7860, debug=False)