import os
os.environ['HF_HOME'] = '/tmp'
import time
import json
import re
import logging
from collections import Counter
from typing import Optional, Dict, Any, List
from concurrent.futures import ThreadPoolExecutor, as_completed, TimeoutError
from flask import Flask, request, jsonify, render_template
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
# -------------------------
# App + logging
# -------------------------
app = Flask(__name__)
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("summarizer")
# -------------------------
# Device selection (CPU by default)
# -------------------------
USE_GPU = False
DEVICE = -1
logger.info("Startup: forcing CPU usage for models (DEVICE=%s)", DEVICE)
# -------------------------
# Model names and caches
# -------------------------
PEGASUS_MODEL = "google/pegasus-large"
LED_MODEL = "allenai/led-large-16384"
DISTILBART_MODEL = "sshleifer/distilbart-cnn-12-6"
PARAM_MODEL = "google/flan-t5-small"
_SUMMARIZER_CACHE: Dict[str, Any] = {}
_PARAM_GENERATOR = None
_PREFERRED_SUMMARIZER_KEY: Optional[str] = None
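# Summarization pipelines are loaded lazily on first use and cached in
# _SUMMARIZER_CACHE keyed by 'pegasus' / 'led' / 'distilbart'; the flan-t5
# parameter generator is cached separately in _PARAM_GENERATOR.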
# -------------------------
# Utilities: chunking, extractive fallback
# -------------------------
_STOPWORDS = {
"the","and","is","in","to","of","a","that","it","on","for","as","are","with","was","be","by","this","an","or","from","at","which","we","has","have"
}
def tokenize_sentences(text: str) -> List[str]:
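    """Split text into sentences with a simple punctuation-based regex.

    Illustrative only: "Cats sleep. Dogs bark!" -> ["Cats sleep.", "Dogs bark!"];
    abbreviations such as "e.g." are not handled specially.
    """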
sents = re.split(r'(?<=[.!?])\s+', text.strip())
return [s.strip() for s in sents if s.strip()]
def extractive_prefilter(text: str, top_k: int = 6) -> str:
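    """Crude frequency-based extractive fallback.

    Scores each sentence by the summed corpus frequency of its non-stopword
    words and returns the top_k highest-scoring sentences in their original
    order. Used when the abstractive models fail, time out, or the input is
    too long to feed to a non-LED model directly.
    """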
sents = tokenize_sentences(text)
if len(sents) <= top_k:
return text
words = re.findall(r"\w+", text.lower())
freqs = Counter(w for w in words if w not in _STOPWORDS)
scored = []
for i, s in enumerate(sents):
ws = re.findall(r"\w+", s.lower())
score = sum(freqs.get(w, 0) for w in ws)
scored.append((score, i, s))
scored.sort(reverse=True)
chosen = [s for _, _, s in sorted(scored[:top_k], key=lambda t: t[1])]
return " ".join(chosen)
def chunk_text_by_chars(text: str, max_chars: int = 800, overlap: int = 120) -> List[str]:
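    """Split text into character-based chunks of roughly max_chars each.

    A chunk is preferably cut at a newline that falls in the last ~40% of the
    window, and consecutive chunks overlap by `overlap` characters so content
    spanning a boundary is not lost.
    """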
n = len(text)
if n <= max_chars:
return [text]
parts = []
start = 0
prev_start = -1
while start < n and start != prev_start:
prev_start = start
end = min(n, start + max_chars)
chunk = text[start:end]
nl = chunk.rfind('\n')
if nl > int(max_chars * 0.6):
end = start + nl
chunk = text[start:end]
parts.append(chunk.strip())
start = end - overlap
if start <= prev_start:
start = end
return [p for p in parts if p]
# -------------------------
# safe loader (defined before any calls)
# -------------------------
def safe_load_pipeline(model_name: str):
"""
Try to load a summarization pipeline robustly:
- try fast tokenizer first
- if that fails, try use_fast=False
- return pipeline or None if both fail
"""
try:
logger.info("Loading tokenizer/model for %s (fast)...", model_name)
tok = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
pipe = pipeline("summarization", model=model, tokenizer=tok, device=DEVICE)
logger.info("Loaded %s (fast tokenizer)", model_name)
return pipe
except Exception as e_fast:
logger.warning("Fast tokenizer failed for %s: %s. Trying slow tokenizer...", model_name, e_fast)
try:
tok = AutoTokenizer.from_pretrained(model_name, use_fast=False)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
pipe = pipeline("summarization", model=model, tokenizer=tok, device=DEVICE)
logger.info("Loaded %s (slow tokenizer)", model_name)
return pipe
except Exception as e_slow:
logger.exception("Slow tokenizer failed for %s: %s", model_name, e_slow)
return None
# -------------------------
# get_summarizer: lazy load + cache + fallback
# -------------------------
def get_summarizer(key: str):
"""
key: 'pegasus'|'led'|'distilbart'|'auto'
returns a pipeline (cached), or raises RuntimeError if no pipeline can be loaded.
"""
key = (key or "auto").lower()
if key == "auto":
key = _PREFERRED_SUMMARIZER_KEY or "distilbart"
# direct mapping
model_name = {
"pegasus": PEGASUS_MODEL,
"led": LED_MODEL,
"distilbart": DISTILBART_MODEL
}.get(key, DISTILBART_MODEL)
if key in _SUMMARIZER_CACHE:
return _SUMMARIZER_CACHE[key]
# try to load
logger.info("Attempting to lazy-load summarizer '%s' -> %s", key, model_name)
pipe = safe_load_pipeline(model_name)
if pipe:
_SUMMARIZER_CACHE[key] = pipe
return pipe
# fallback attempts
logger.warning("Failed to load %s. Trying distilbart fallback.", key)
if "distilbart" in _SUMMARIZER_CACHE:
return _SUMMARIZER_CACHE["distilbart"]
fb = safe_load_pipeline(DISTILBART_MODEL)
if fb:
_SUMMARIZER_CACHE["distilbart"] = fb
return fb
# nothing works
    raise RuntimeError("No summarizer available. Install the required libraries and/or choose a smaller model.")
# -------------------------
# Generation strategy + small helpers
# -------------------------
def summarize_with_model(pipe, text_prompt: str, short_target: bool = False) -> str:
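    """Two-pass generation: a fast sampling pass first, then a slower
    beam-search quality pass if the fast pass raises, and finally an
    extractive fallback if both fail. LED/Longformer models get larger
    token budgets.
    """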
model_name = getattr(pipe.model.config, "name_or_path", "") or ""
is_led = "led" in model_name.lower() or "longformer" in model_name.lower()
# Fast sampling pass
fast_cfg = {
"max_new_tokens": 64 if short_target else (120 if not is_led else 240),
"do_sample": True,
"top_p": 0.92,
"temperature": 0.85,
"num_beams": 1,
"early_stopping": True,
"no_repeat_ngram_size": 3,
}
try:
return pipe(text_prompt, **fast_cfg)[0].get("summary_text","").strip()
except Exception as e:
logger.warning("Fast pass failed: %s, trying quality pass...", e)
quality_cfg = {
"max_new_tokens": 140 if not is_led else 320,
"do_sample": False,
"num_beams": 3,
"early_stopping": True,
"no_repeat_ngram_size": 3,
}
try:
return pipe(text_prompt, **quality_cfg)[0].get("summary_text","").strip()
except Exception as e:
logger.exception("Quality pass failed: %s", e)
# fallback extractive
try:
return extractive_prefilter(text_prompt, top_k=3)
except Exception:
return "Summarization failed; try shorter input."
# -------------------------
# Param generator (AI decision) - lazy loader
# -------------------------
def get_param_generator():
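    """Lazily load and cache the flan-t5 text2text pipeline used to pick
    summarization settings; returns None if loading fails so callers can
    fall back to the word-count heuristic."""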
global _PARAM_GENERATOR
if _PARAM_GENERATOR is not None:
return _PARAM_GENERATOR
# try to load text2text pipeline for PARAM_MODEL
try:
logger.info("Loading param-generator (text2text) lazily: %s", PARAM_MODEL)
tok = AutoTokenizer.from_pretrained(PARAM_MODEL)
mod = AutoModelForSeq2SeqLM.from_pretrained(PARAM_MODEL)
_PARAM_GENERATOR = pipeline("text2text-generation", model=mod, tokenizer=tok, device=DEVICE)
logger.info("Param-generator loaded.")
return _PARAM_GENERATOR
except Exception as e:
logger.exception("Param-generator lazy load failed: %s", e)
_PARAM_GENERATOR = None
return None
def generate_summarization_config(text: str) -> Dict[str, Any]:
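    """Ask the param generator for length/tone/word bounds as JSON; fall back
    to a simple word-count heuristic when the generator is unavailable,
    echoes the input, or returns unparseable output."""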
defaults = {"short": (12, 50), "medium": (50, 130), "long": (130, 300)}
pg = get_param_generator()
if pg is None:
words = len(text.split())
length = "short" if words < 150 else ("medium" if words < 800 else "long")
mn, mx = defaults[length]
return {"length": length, "min_length": mn, "max_length": mx, "tone": "neutral"}
prompt = (
"Recommend summarization settings for this text. Answer ONLY with JSON like:\n"
'{"length":"short|medium|long","tone":"neutral|formal|casual|bullet","min_words":MIN,"max_words":MAX}\n\n'
"Text:\n'''"
+ text[:3000] + "'''"
)
try:
out_item = pg(prompt, max_new_tokens=64, do_sample=False, num_beams=1)[0]
out = out_item.get("generated_text") or out_item.get("summary_text") or ""
out = (out or "").strip()
if not out:
raise ValueError("Empty param-generator output")
# reject noisy echo outputs
input_words = set(w.lower() for w in re.findall(r"\w+", text)[:200])
out_words = set(w.lower() for w in re.findall(r"\w+", out)[:200])
if len(input_words) and (len(input_words & out_words) / max(1, len(input_words))) > 0.4:
logger.warning("Param-generator appears to echo input; using heuristic fallback.")
words = len(text.split())
length = "short" if words < 150 else ("medium" if words < 800 else "long")
mn, mx = defaults[length]
return {"length": length, "min_length": mn, "max_length": mx, "tone": "neutral"}
jmatch = re.search(r"\{.*\}", out, re.DOTALL)
if jmatch:
raw = jmatch.group().replace("'", '"')
cfg = json.loads(raw)
else:
cfg = None
if not cfg or not isinstance(cfg, dict):
raise ValueError("Param output not parseable")
        length = cfg.get("length", "medium").lower()
        if length not in defaults:
            length = "medium"
        tone = cfg.get("tone", "neutral").lower()
mn = int(cfg.get("min_words") or cfg.get("min_length") or defaults[length][0])
mx = int(cfg.get("max_words") or cfg.get("max_length") or defaults[length][1])
mn = max(5, min(mn, 2000))
mx = max(mn + 5, min(mx, 4000))
return {"length": length, "min_length": mn, "max_length": mx, "tone": tone}
except Exception as e:
logger.exception("Param-generator parse failed: %s", e)
words = len(text.split())
length = "short" if words < 150 else ("medium" if words < 800 else "long")
mn, mx = defaults[length]
return {"length": length, "min_length": mn, "max_length": mx, "tone": "neutral"}
# -------------------------
# Threaded chunk summarization with per-chunk timeout (to prevent hang)
# -------------------------
executor = ThreadPoolExecutor(max_workers=min(8, max(2, (os.cpu_count() or 2))))
CHUNK_TIMEOUT_SECONDS = 28
REFINE_TIMEOUT_SECONDS = 60
def summarize_chunks_parallel(pipe, chunks: List[str], chunk_target: int) -> List[str]:
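    """Summarize each chunk on the shared thread pool, bounded by
    CHUNK_TIMEOUT_SECONDS; any chunk that times out or raises is replaced by
    its extractive_prefilter fallback."""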
futures = {}
results = [None] * len(chunks)
for idx, chunk in enumerate(chunks):
prompt = apply_tone_instruction(chunk, "neutral", target_sentences=chunk_target)
fut = executor.submit(summarize_with_model, pipe, prompt, short_target=(chunk_target==1))
futures[fut] = idx
    # as_completed() needs its own timeout: a per-future result(timeout=...)
    # never fires because as_completed only yields futures that have already
    # finished, so a hung worker would otherwise block the request forever.
    try:
        for fut in as_completed(futures, timeout=CHUNK_TIMEOUT_SECONDS):
            idx = futures[fut]
            try:
                results[idx] = fut.result()
            except Exception as e:
                logger.exception("Chunk %d failed: %s; falling back", idx, e)
                results[idx] = extractive_prefilter(chunks[idx], top_k=3)
    except TimeoutError:
        pending = [i for i, r in enumerate(results) if r is None]
        logger.warning("Chunks %s timed out; using extractive fallback.", pending)
for i, r in enumerate(results):
if not r:
results[i] = extractive_prefilter(chunks[i], top_k=3)
return results
# -------------------------
# Prompt helpers and refine
# -------------------------
def apply_tone_instruction(text: str, tone: str, target_sentences: Optional[int] = None) -> str:
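    """Prepend a tone/length instruction to the text so the seq2seq model is
    nudged toward the requested style (bullet, formal, casual, etc.)."""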
tone = (tone or "neutral").lower()
if tone == "bullet":
instr = "Produce concise bullet points. Each bullet <= 20 words. No extra commentary."
elif tone == "short":
ts = target_sentences or 1
instr = f"Summarize in {ts} sentence{'s' if ts>1 else ''}. Be abstractive."
elif tone == "formal":
instr = "Summarize in a formal, professional tone (2-4 sentences)."
elif tone == "casual":
instr = "Summarize in a casual, conversational tone (1-3 sentences)."
elif tone == "long":
instr = "Provide a structured summary (4-8 sentences)."
else:
instr = "Summarize in 2-3 clear sentences."
instr += " Do not repeat information. Prefer rephrasing."
return f"{instr}\n\nText:\n{text}"
def refine_combined(pipe, summaries_list: List[str], tone: str, final_target_sentences: int) -> str:
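    """Merge per-chunk summaries into one final summary with a second model
    pass (extractively trimmed first if very long); on timeout or error,
    return the first few chunk summaries concatenated."""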
combined = "\n\n".join(summaries_list)
if len(combined.split()) > 1200:
combined = extractive_prefilter(combined, top_k=20)
prompt = apply_tone_instruction(combined, tone, target_sentences=final_target_sentences)
fut = executor.submit(summarize_with_model, pipe, prompt, short_target=False)
try:
return fut.result(timeout=REFINE_TIMEOUT_SECONDS)
except TimeoutError:
logger.warning("Refine timed out; returning concatenated chunk summaries.")
return " ".join(summaries_list[:6])
except Exception as e:
logger.exception("Refine failed: %s", e)
return " ".join(summaries_list[:6])
# -------------------------
# Routes
# -------------------------
@app.route("/", methods=["GET"])
def home():
try:
return render_template("index.html")
except Exception:
return "Summarizer (lazy-load) — POST /summarize with JSON {text:'...'}", 200
@app.route("/preload", methods=["POST"])
def preload_models():
"""
Explicit endpoint to attempt preloading heavy models.
Call this only when you want the process to attempt loading Pegasus/LED (may be slow).
"""
results = {}
for key, model_name in [("pegasus", PEGASUS_MODEL), ("led", LED_MODEL), ("distilbart", DISTILBART_MODEL)]:
if key in _SUMMARIZER_CACHE:
results[key] = "already_loaded"
continue
try:
p = safe_load_pipeline(model_name)
if p:
_SUMMARIZER_CACHE[key] = p
results[key] = "loaded"
else:
results[key] = "failed"
except Exception as e:
results[key] = f"error: {e}"
return jsonify(results)
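# Example request (illustrative; assumes the app is served on port 7860 as in
# the __main__ block below; adjust host/port for your deployment):
#   curl -X POST http://localhost:7860/summarize \
#     -H "Content-Type: application/json" \
#     -d '{"text": "<long article text>", "model": "auto", "length": "auto", "tone": "auto"}'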
@app.route("/summarize", methods=["POST"])
def summarize_route():
t0 = time.time()
data = request.get_json(force=True) or {}
text = (data.get("text") or "").strip()[:90000]
user_model_pref = (data.get("model") or "auto").lower()
requested_length = (data.get("length") or "auto").lower()
requested_tone = (data.get("tone") or "auto").lower()
if not text or len(text.split()) < 5:
return jsonify({"error": "Input too short."}), 400
# decide settings
if requested_length in ("auto","ai") or requested_tone in ("auto","ai"):
cfg = generate_summarization_config(text)
length_choice = cfg.get("length","medium")
tone_choice = cfg.get("tone","neutral")
else:
length_choice = requested_length if requested_length in ("short","medium","long") else "medium"
tone_choice = requested_tone if requested_tone in ("neutral","formal","casual","bullet") else "neutral"
# model selection logic
words_len = len(text.split())
prefer_led = False
if user_model_pref == "led":
prefer_led = True
elif user_model_pref == "pegasus":
prefer_led = False
else:
if length_choice == "long" or words_len > 3000:
prefer_led = True
model_key = "led" if prefer_led else (_PREFERRED_SUMMARIZER_KEY or "distilbart")
try:
summarizer_pipe = get_summarizer(model_key)
except Exception as e:
logger.exception("get_summarizer failed (%s). Falling back to distilbart.", e)
summarizer_pipe = get_summarizer("distilbart")
model_key = "distilbart"
# prefilter very long inputs for non-LED
if model_key != "led" and words_len > 2500:
text_for_chunks = extractive_prefilter(text, top_k=40)
else:
text_for_chunks = text
# chunk sizing
if model_key == "led":
chunk_max = 6000
overlap = 400
else:
chunk_max = 800
overlap = 120
chunks = chunk_text_by_chars(text_for_chunks, max_chars=chunk_max, overlap=overlap)
chunk_target = 1 if length_choice == "short" else 2
# summarize chunks in parallel
try:
chunk_summaries = summarize_chunks_parallel(summarizer_pipe, chunks, chunk_target)
except Exception as e:
logger.exception("Chunk summarization orchestration failed: %s", e)
chunk_summaries = [extractive_prefilter(c, top_k=3) for c in chunks]
# refine step — prefer Pegasus if loaded, otherwise use current pipe
refine_pipe = _SUMMARIZER_CACHE.get("pegasus") or summarizer_pipe
final_target_sentences = {"short":1,"medium":3,"long":6}.get(length_choice, 3)
final = refine_combined(refine_pipe, chunk_summaries, tone_choice, final_target_sentences)
# bullet postprocess
if tone_choice == "bullet":
parts = re.split(r'[\n\r]+|(?:\.\s+)|(?:;\s+)', final)
bullets = [f"- {p.strip().rstrip('.')}" for p in parts if p.strip()]
final = "\n".join(bullets[:20])
elapsed = time.time() - t0
meta = {
"length_choice": length_choice,
"tone": tone_choice,
"model_used": model_key,
"chunks": len(chunks),
"input_words": words_len,
"time_seconds": round(elapsed, 2),
"device": "cpu"
}
return jsonify({"summary": final, "meta": meta})
# -------------------------
# Local run (safe)
# -------------------------
if __name__ == "__main__":
    # For local testing, POST to /preload first if you want Pegasus/LED warmed up before summarizing.
app.run(host="0.0.0.0", port=7860, debug=False)