import os
os.environ['HF_HOME'] = '/tmp'
import time
import json
import re
import logging
from collections import Counter
from typing import Optional
from flask import Flask, request, jsonify, render_template
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline

# -------------------------
# App + logging
# -------------------------
app = Flask(__name__)
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("summarizer")

# -------------------------
# Device selection
# -------------------------
USE_GPU = torch.cuda.is_available()
DEVICE = 0 if USE_GPU else -1
logger.info("CUDA available: %s. Using device: %s", USE_GPU, DEVICE)

# -------------------------
# Model names (summarizers are loaded lazily)
# -------------------------
PEGASUS_MODEL = "google/pegasus-large"
LED_MODEL = "allenai/led-large-16384"
PARAM_MODEL = "google/flan-t5-small"  # instruction model for parameter generation

# cache for lazily loaded summarization pipelines
_SUMMARIZER_CACHE = {}

# load the small parameter-generator eagerly (it is small, so startup cost is low)
logger.info("Loading parameter generator model: %s", PARAM_MODEL)
param_tokenizer = AutoTokenizer.from_pretrained(PARAM_MODEL)
param_model = AutoModelForSeq2SeqLM.from_pretrained(PARAM_MODEL)
param_generator = pipeline("text2text-generation", model=param_model, tokenizer=param_tokenizer, device=DEVICE)

# -------------------------
# Presets & utils
# -------------------------
LENGTH_PRESETS = {
    "short": {"min_length": 20, "max_length": 60},
    "medium": {"min_length": 60, "max_length": 130},
    "long": {"min_length": 130, "max_length": 300},
}

_STOPWORDS = {
    "the", "and", "is", "in", "to", "of", "a", "that", "it", "on", "for", "as", "are", "with",
    "was", "be", "by", "this", "an", "or", "from", "at", "which", "we", "has", "have",
}

def tokenize_sentences(text):
    # naive sentence splitter: break on whitespace that follows ., ! or ?
    sents = re.split(r'(?<=[.!?])\s+', text.strip())
    return [s.strip() for s in sents if s.strip()]

def extractive_prefilter(text, top_k=12):
    # score each sentence by summed document-wide word frequency and keep
    # the top_k sentences, restored to their original order
    sents = tokenize_sentences(text)
    if len(sents) <= top_k:
        return text
    words = re.findall(r"\w+", text.lower())
    freqs = Counter(w for w in words if w not in _STOPWORDS)
    scored = []
    for i, s in enumerate(sents):
        ws = re.findall(r"\w+", s.lower())
        score = sum(freqs.get(w, 0) for w in ws)
        scored.append((score, i, s))
    scored.sort(reverse=True)
    chosen = [s for _, _, s in sorted(scored[:top_k], key=lambda t: t[1])]
    return " ".join(chosen)
def chunk_text_by_chars(text, max_chars=1500, overlap=200):
    # split long text into roughly max_chars-sized chunks, preferring to cut
    # at a newline and overlapping adjacent chunks to preserve context
    if len(text) <= max_chars:
        return [text]
    parts = []
    start = 0
    while start < len(text):
        end = min(len(text), start + max_chars)
        chunk = text[start:end]
        nl = chunk.rfind('\n')
        if nl > max_chars * 0.6:
            end = start + nl
            chunk = text[start:end]
        parts.append(chunk.strip())
        if end >= len(text):
            break
        # step back by `overlap` characters so adjacent chunks share context;
        # max() guarantees forward progress even when overlap is large
        start = max(end - overlap, start + 1)
    return parts
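
# Illustrative example (hypothetical input): with max_chars=1500 and
# overlap=200, a 3,000-character string with no newlines splits into
# [0:1500], [1300:2800] and [2600:3000]; each chunk after the first re-reads
# the last 200 characters of its predecessor so sentences that straddle a
# boundary are not lost.
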
def _first_int_from_text(s, fallback=None):
    # return the first 1-4 digit integer found in s, or fallback
    m = re.search(r"\d{1,4}", s)
    return int(m.group()) if m else fallback

# -------------------------
# Lazy summarizer loader
# -------------------------
def get_summarizer(model_key: str):
    """
    Returns a pipeline summarizer for 'pegasus' or 'led', loading it lazily.
    model_key: "pegasus" or "led"
    """
    model_key = model_key.lower()
    if model_key in _SUMMARIZER_CACHE:
        return _SUMMARIZER_CACHE[model_key]
    if model_key == "pegasus":
        model_name = PEGASUS_MODEL
    elif model_key == "led":
        model_name = LED_MODEL
    else:
        raise ValueError("Unknown model_key: " + str(model_key))
    logger.info("Loading summarizer model '%s' (%s) on device %s ...", model_key, model_name, DEVICE)
    tok = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
    pipe = pipeline("summarization", model=model, tokenizer=tok, device=DEVICE)
    _SUMMARIZER_CACHE[model_key] = pipe
    logger.info("Loaded summarizer '%s' successfully.", model_key)
    return pipe
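
# Usage note: the first get_summarizer("led") call downloads and loads
# allenai/led-large-16384, which can take minutes on a cold start; subsequent
# calls return the cached pipeline immediately.
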
# -------------------------
# Prompt and decision logic
# -------------------------
def apply_tone_instruction(text, tone, target_sentences=None):
    tone = (tone or "neutral").lower()
    if tone == "bullet":
        instr = "Produce concise bullet points. Each bullet short (<=20 words). No extra commentary."
    elif tone == "short":
        ts = target_sentences or 1
        instr = f"Summarize the content in {ts} sentence{'s' if ts > 1 else ''}. Be highly abstractive and avoid copying sentences verbatim."
    elif tone == "formal":
        instr = "Summarize in a formal, professional tone in 2-4 sentences. Keep it precise and well-structured."
    elif tone == "casual":
        instr = "Summarize in a casual, conversational tone in 1-3 sentences. Use plain, friendly language."
    elif tone == "long":
        instr = "Provide a clear, structured summary in 4-8 sentences covering key points and context."
    else:
        instr = "Summarize the content in 2-3 sentences. Be clear and concise."
    instr += " Do not repeat information; prefer rephrasing."
    return f"{instr}\n\nText:\n{text}"
def generate_summarization_config(text):
    """
    Ask the small instruction model for settings; fall back to a heuristic.
    Returns a dict with keys: length, min_length, max_length, tone.
    """
    prompt = (
        "You are an assistant that recommends summarization settings.\n"
        "Given the text, respond ONLY with single-line JSON EXACTLY like:\n"
        '{"length":"short|medium|long","min_words":MIN,"max_words":MAX,"tone":"neutral|formal|casual|bullet"}\n\n'
        "Text:\n'''"
        + text[:4000] + "'''"
    )
    try:
        # greedy decoding is enough here; early_stopping only applies to beam search
        out = param_generator(
            prompt,
            max_new_tokens=64,
            num_beams=1,
            do_sample=False,
        )[0].get("generated_text", "").strip()
        cfg = None
        try:
            cfg = json.loads(out)
        except Exception:
            # tolerate extra prose around the JSON and single-quoted keys
            j = re.search(r"\{.*\}", out, re.DOTALL)
            if j:
                raw = j.group().replace("'", '"')
                cfg = json.loads(raw)
        if not cfg:
            raise ValueError("Unparseable param-generator output")
        length = cfg.get("length", "").lower()
        tone = cfg.get("tone", "").lower()
        min_w = cfg.get("min_words")
        max_w = cfg.get("max_words")
        if length not in ("short", "medium", "long"):
            words = len(text.split())
            length = "short" if words < 150 else ("medium" if words < 800 else "long")
        if tone not in ("neutral", "formal", "casual", "bullet"):
            tone = "neutral"
        if not isinstance(min_w, int):
            min_w = _first_int_from_text(out, fallback=None)
        if not isinstance(max_w, int):
            # take the last integer in the raw output as the max
            ints = re.findall(r"\d{1,4}", out)
            max_w = int(ints[-1]) if ints else None
        defaults = {"short": (15, 50), "medium": (50, 130), "long": (130, 300)}
        dmin, dmax = defaults.get(length, (50, 130))
        min_len = int(min_w) if isinstance(min_w, int) else dmin
        max_len = int(max_w) if isinstance(max_w, int) else dmax
        min_len = max(5, min(min_len, 2000))
        max_len = max(min_len + 5, min(max_len, 4000))
        logger.info("Param-generator chose length=%s tone=%s min=%s max=%s", length, tone, min_len, max_len)
        return {"length": length, "min_length": min_len, "max_length": max_len, "tone": tone}
    except Exception as e:
        logger.exception("Param-generator failed: %s", e)
        words = len(text.split())
        length = "short" if words < 150 else ("medium" if words < 800 else "long")
        fallback = {"short": (15, 50), "medium": (50, 130), "long": (130, 300)}
        mn, mx = fallback[length]
        return {"length": length, "min_length": mn, "max_length": mx, "tone": "neutral"}
# -------------------------
# Two-stage summarization (chunk -> chunk summaries -> refine)
# -------------------------
def refine_and_combine(summaries_list, tone, final_target_sentences=None, summarizer_pipe=None):
    combined = "\n\n".join(summaries_list)
    if len(combined.split()) > 2000:
        combined = extractive_prefilter(combined, top_k=20)
    prompt = apply_tone_instruction(combined, tone, target_sentences=final_target_sentences)
    tgt_sent = final_target_sentences or 3
    gen_kwargs = {
        "min_length": max(20, int(tgt_sent * 8)),
        "max_length": max(60, int(tgt_sent * 30)),
        "num_beams": 6,
        "early_stopping": True,
        "no_repeat_ngram_size": 3,
        "do_sample": False,
    }
    try:
        if summarizer_pipe is None:
            # fall back to Pegasus by default if no pipeline was provided
            summarizer_pipe = get_summarizer("pegasus")
        out = summarizer_pipe(prompt, **gen_kwargs)[0]["summary_text"].strip()
        return out
    except Exception as e:
        logger.exception("Refine failed: %s", e)
        return " ".join(summaries_list[:3])
# -------------------------
# Model-specific generation helper
# -------------------------
def summarize_with_model(pipe, text_prompt, short_target=False):
    """
    Run the pipeline with conservative, model-appropriate generation settings.
    short_target: if True, use a shorter min/max suitable for concise outputs.
    """
    # heuristic: detect LED from the model's configured name, since LED
    # tolerates much longer outputs than Pegasus
    model_name = getattr(pipe.model.config, "name_or_path", "") or ""
    is_led = "led" in model_name or "longformer" in model_name
    if short_target:
        min_l = 12
        max_l = 60
    else:
        min_l = 24
        max_l = 140 if not is_led else 400  # LED can handle longer outputs
    gen_kwargs = {
        "min_length": min_l,
        "max_length": max_l,
        "num_beams": 5 if not is_led else 4,
        "early_stopping": True,
        "no_repeat_ngram_size": 3,
        "do_sample": False,
    }
    return pipe(text_prompt, **gen_kwargs)[0]["summary_text"].strip()
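
# Usage sketch (hypothetical prompt): summarize_with_model(get_summarizer("pegasus"),
# prompt, short_target=True) caps output at 60 tokens, while LED's cap for
# full-length summaries rises to 400 tokens.
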
# -------------------------
# Routes
# -------------------------
@app.route("/")
def home():
    return render_template("index.html")

# NOTE: the route path below is an assumption; adjust it to match the
# front-end's fetch() target if it differs.
@app.route("/summarize", methods=["POST"])
def summarize_route():
    t0 = time.time()
    data = request.get_json(force=True) or {}
    text = (data.get("text") or "")[:90000]
    user_model_pref = (data.get("model") or "auto").lower()    # 'pegasus' | 'led' | 'auto'
    requested_length = (data.get("length") or "auto").lower()  # short|medium|long|auto
    requested_tone = (data.get("tone") or "auto").lower()      # neutral|formal|casual|bullet|auto
    if not text or len(text.split()) < 5:
        return jsonify({"error": "Input too short."}), 400
    # 1) Decide settings (AI or explicit)
    if requested_length in ("auto", "ai") or requested_tone in ("auto", "ai"):
        cfg = generate_summarization_config(text)
        length_choice = cfg.get("length", "medium")
        tone_choice = cfg.get("tone", "neutral")
        preset_min = cfg.get("min_length")
        preset_max = cfg.get("max_length")
    else:
        length_choice = requested_length if requested_length in ("short", "medium", "long") else "medium"
        tone_choice = requested_tone if requested_tone in ("neutral", "formal", "casual", "bullet") else "neutral"
        preset_min = LENGTH_PRESETS.get(length_choice, LENGTH_PRESETS["medium"])["min_length"]
        preset_max = LENGTH_PRESETS.get(length_choice, LENGTH_PRESETS["medium"])["max_length"]

    # 2) Model selection (user preference or auto)
    # auto rule: use LED when the user asked for it, when the param-generator
    # picked "long", or when the input is very long
    words_len = len(text.split())
    if user_model_pref == "led":
        prefer_led = True
    elif user_model_pref == "pegasus":
        prefer_led = False
    else:  # auto
        prefer_led = length_choice == "long" or words_len > 3000
    model_key = "led" if prefer_led else "pegasus"

    # get the pipeline (lazy load)
    try:
        summarizer_pipe = get_summarizer(model_key)
    except Exception as e:
        logger.exception("Failed to load summarizer '%s': %s", model_key, e)
        # fall back to Pegasus if LED fails to load
        summarizer_pipe = get_summarizer("pegasus")
        model_key = "pegasus"

    # 3) Prefilter very long inputs (if not using LED)
    if not prefer_led and words_len > 2500:
        text_for_chunks = extractive_prefilter(text, top_k=40)
    else:
        text_for_chunks = text

    # 4) Chunking: choose chunk size depending on the model
    if model_key == "led":
        chunk_max_chars = 8000  # LED can handle larger chunks
        chunk_overlap = 400
    else:
        chunk_max_chars = 1400
        chunk_overlap = 200
    chunks = chunk_text_by_chars(text_for_chunks, max_chars=chunk_max_chars, overlap=chunk_overlap)

    # 5) Summarize each chunk
    chunk_summaries = []
    for chunk in chunks:
        chunk_target = 1 if length_choice == "short" else 2
        chunk_tone = tone_choice if tone_choice in ("formal", "casual", "bullet") else "neutral"
        prompt = apply_tone_instruction(chunk, chunk_tone, target_sentences=chunk_target)
        try:
            # use the short target for single-sentence chunk summaries
            out = summarize_with_model(summarizer_pipe, prompt, short_target=(chunk_target == 1))
        except Exception as e:
            logger.exception("Chunk summarization failed, using extractive fallback: %s", e)
            out = extractive_prefilter(chunk, top_k=3)
        chunk_summaries.append(out)
    # 6) Combine + refine with the same model for a consistent final style
    refine_model_key = model_key
    refine_pipe = get_summarizer(refine_model_key)
    final_target_sentences = {"short": 1, "medium": 3, "long": 6}.get(length_choice, 3)
    final = refine_and_combine(chunk_summaries, tone_choice, final_target_sentences, summarizer_pipe=refine_pipe)

    # 7) Post-process bullet tone
    if tone_choice == "bullet":
        parts = re.split(r'[\n\r]+|(?:\.\s+)|(?:;\s+)', final)
        bullets = [f"- {p.strip().rstrip('.')}" for p in parts if p.strip()]
        final = "\n".join(bullets[:20])

    elapsed = time.time() - t0
    meta = {
        "length_choice": length_choice,
        "tone": tone_choice,
        "model_used": model_key,
        "refine_model": refine_model_key,
        "chunks": len(chunks),
        "input_words": words_len,
        "time_seconds": round(elapsed, 2),
        "device": "gpu" if USE_GPU else "cpu",
    }
    return jsonify({"summary": final, "meta": meta})
# -------------------------
# Run
# -------------------------
if __name__ == "__main__":
    # debug=False for production; use Gunicorn in deployment
    app.run(host="0.0.0.0", port=7860, debug=False)