import os
os.environ['HF_HOME'] = '/tmp'
import time
import json
import re
import logging
from collections import Counter
from flask import Flask, request, jsonify, render_template
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
# -------------------------
# App + logging
# -------------------------
app = Flask(__name__)
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("summarizer")
# -------------------------
# Device selection
# -------------------------
USE_GPU = torch.cuda.is_available()
DEVICE = 0 if USE_GPU else -1
logger.info("CUDA available: %s. Using device: %s", USE_GPU, DEVICE)
# -------------------------
# Model names (we'll load summarizers lazily)
# -------------------------
PEGASUS_MODEL = "google/pegasus-large"
LED_MODEL = "allenai/led-large-16384"
PARAM_MODEL = "google/flan-t5-small" # instruction model for parameter generation
# caches for lazy-loaded pipelines
_SUMMARIZER_CACHE = {}
# load the small param-generator right away (keeps it small)
logger.info("Loading parameter generator model: %s", PARAM_MODEL)
param_tokenizer = AutoTokenizer.from_pretrained(PARAM_MODEL)
param_model = AutoModelForSeq2SeqLM.from_pretrained(PARAM_MODEL)
param_generator = pipeline("text2text-generation", model=param_model, tokenizer=param_tokenizer, device=DEVICE)
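# Quick manual sanity check (a hedged sketch; actual output varies by model
# version and is not deterministic across releases):
#   param_generator("Recommend settings for: The sky is blue.", max_new_tokens=16)
#   -> [{"generated_text": "..."}]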
# -------------------------
# Presets & utils
# -------------------------
LENGTH_PRESETS = {
"short": {"min_length": 20, "max_length": 60},
"medium": {"min_length": 60, "max_length": 130},
"long": {"min_length": 130, "max_length": 300},
}
_STOPWORDS = {
    "the", "and", "is", "in", "to", "of", "a", "that", "it", "on", "for",
    "as", "are", "with", "was", "be", "by", "this", "an", "or", "from",
    "at", "which", "we", "has", "have",
}
def tokenize_sentences(text):
sents = re.split(r'(?<=[.!?])\s+', text.strip())
return [s.strip() for s in sents if s.strip()]
def extractive_prefilter(text, top_k=12):
sents = tokenize_sentences(text)
if len(sents) <= top_k:
return text
words = re.findall(r"\w+", text.lower())
freqs = Counter(w for w in words if w not in _STOPWORDS)
scored = []
for i, s in enumerate(sents):
ws = re.findall(r"\w+", s.lower())
score = sum(freqs.get(w, 0) for w in ws)
scored.append((score, i, s))
scored.sort(reverse=True)
chosen = [s for _, _, s in sorted(scored[:top_k], key=lambda t: t[1])]
return " ".join(chosen)
def chunk_text_by_chars(text, max_chars=1500, overlap=200):
if len(text) <= max_chars:
return [text]
parts = []
start = 0
while start < len(text):
end = min(len(text), start + max_chars)
chunk = text[start:end]
nl = chunk.rfind('\n')
if nl > max_chars * 0.6:
end = start + nl
chunk = text[start:end]
parts.append(chunk.strip())
        if end >= len(text):
            break
        # Step back by `overlap` characters so consecutive chunks share context,
        # while still guaranteeing forward progress.
        start = max(end - overlap, start + 1)
return parts
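# Illustration (hedged sketch): with max_chars=1500 and overlap=200, a
# 3000-character input yields chunks covering roughly [0:1500), [1300:2800)
# and [2600:3000) -- each chunk repeats the previous chunk's tail for context.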
def _first_int_from_text(s, fallback=None):
m = re.search(r"\d{1,4}", s)
return int(m.group()) if m else fallback
# -------------------------
# Lazy summarizer loader
# -------------------------
def get_summarizer(model_key: str):
"""
Returns a pipeline summarizer for 'pegasus' or 'led', loading it lazily.
model_key: "pegasus" or "led"
"""
model_key = model_key.lower()
if model_key in _SUMMARIZER_CACHE:
return _SUMMARIZER_CACHE[model_key]
if model_key == "pegasus":
model_name = PEGASUS_MODEL
elif model_key == "led":
model_name = LED_MODEL
else:
raise ValueError("Unknown model_key: " + str(model_key))
logger.info("Loading summarizer model '%s' (%s) on device %s ...", model_key, model_name, DEVICE)
tok = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
pipe = pipeline("summarization", model=model, tokenizer=tok, device=DEVICE)
_SUMMARIZER_CACHE[model_key] = pipe
logger.info("Loaded summarizer '%s' successfully.", model_key)
return pipe
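# Usage (hedged sketch): the first call pays the download/load cost, later
# calls are served from _SUMMARIZER_CACHE:
#   pipe = get_summarizer("pegasus")
#   pipe("Long article text ...", min_length=20, max_length=60)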
# -------------------------
# Prompt and decision logic
# -------------------------
def apply_tone_instruction(text, tone, target_sentences=None):
tone = (tone or "neutral").lower()
if tone == "bullet":
instr = "Produce concise bullet points. Each bullet short (<=20 words). No extra commentary."
elif tone == "short":
ts = target_sentences or 1
instr = f"Summarize the content in {ts} sentence{'s' if ts>1 else ''}. Be highly abstractive and avoid copying sentences verbatim."
elif tone == "formal":
instr = "Summarize in a formal, professional tone in 2-4 sentences. Keep it precise and well-structured."
elif tone == "casual":
instr = "Summarize in a casual, conversational tone in 1-3 sentences. Use plain, friendly language."
elif tone == "long":
instr = "Provide a clear, structured summary in 4-8 sentences covering key points and context."
else:
instr = "Summarize the content in 2-3 sentences. Be clear and concise."
instr += " Do not repeat information; prefer rephrasing."
return f"{instr}\n\nText:\n{text}"
def generate_summarization_config(text):
"""
Ask small instruction model for settings; fallback to heuristic.
Returns dict with keys: length, min_length, max_length, tone
"""
prompt = (
"You are an assistant that recommends summarization settings.\n"
"Given the text, respond ONLY with single-line JSON EXACTLY like:\n"
'{"length":"short|medium|long","min_words":MIN,"max_words":MAX,"tone":"neutral|formal|casual|bullet"}\n\n'
"Text:\n'''"
+ text[:4000] + "'''"
)
try:
out = param_generator(
prompt,
max_new_tokens=64,
num_beams=1,
do_sample=False,
early_stopping=True
)[0].get("generated_text","").strip()
cfg = None
try:
cfg = json.loads(out)
except Exception:
j = re.search(r"\{.*\}", out, re.DOTALL)
if j:
raw = j.group().replace("'", '"')
cfg = json.loads(raw)
if not cfg:
raise ValueError("Unparseable param-generator output")
length = cfg.get("length","").lower()
tone = cfg.get("tone","").lower()
min_w = cfg.get("min_words")
max_w = cfg.get("max_words")
if length not in ("short","medium","long"):
words = len(text.split())
length = "short" if words < 150 else ("medium" if words < 800 else "long")
if tone not in ("neutral","formal","casual","bullet"):
tone = "neutral"
        if not isinstance(min_w, int):
            min_w = _first_int_from_text(out, fallback=None)
        if not isinstance(max_w, int):
            # Take the last integer in the output; searching the reversed string
            # would also reverse the digits (e.g. "130" would parse as 31).
            nums = re.findall(r"\d{1,4}", out)
            max_w = int(nums[-1]) if nums else None
defaults = {"short":(15,50),"medium":(50,130),"long":(130,300)}
dmin,dmax = defaults.get(length,(50,130))
min_len = int(min_w) if isinstance(min_w,int) else dmin
max_len = int(max_w) if isinstance(max_w,int) else dmax
min_len = max(5, min(min_len, 2000))
max_len = max(min_len+5, min(max_len, 4000))
logger.info("Param-generator chose length=%s tone=%s min=%s max=%s", length, tone, min_len, max_len)
return {"length":length,"min_length":min_len,"max_length":max_len,"tone":tone}
except Exception as e:
logger.exception("Param-generator failed: %s", e)
words = len(text.split())
length = "short" if words < 150 else ("medium" if words < 800 else "long")
fallback = {"short":(15,50),"medium":(50,130),"long":(130,300)}
mn,mx = fallback[length]
return {"length":length,"min_length":mn,"max_length":mx,"tone":"neutral"}
# -------------------------
# Two-stage summarization (chunk -> chunk summaries -> refine)
# -------------------------
def refine_and_combine(summaries_list, tone, final_target_sentences=None, summarizer_pipe=None):
combined = "\n\n".join(summaries_list)
if len(combined.split()) > 2000:
combined = extractive_prefilter(combined, top_k=20)
prompt = apply_tone_instruction(combined, tone, target_sentences=final_target_sentences)
tgt_sent = final_target_sentences or 3
gen_kwargs = {
"min_length": max(20, int(tgt_sent * 8)),
"max_length": max(60, int(tgt_sent * 30)),
"num_beams": 6,
"early_stopping": True,
"no_repeat_ngram_size": 3,
"do_sample": False,
}
try:
if summarizer_pipe is None:
# fallback to pegasus by default (if pipe not provided)
summarizer_pipe = get_summarizer("pegasus")
out = summarizer_pipe(prompt, **gen_kwargs)[0]["summary_text"].strip()
return out
except Exception as e:
logger.exception("Refine failed: %s", e)
return " ".join(summaries_list[:3])
# -------------------------
# Model-specific generation helper
# -------------------------
def summarize_with_model(pipe, text_prompt, short_target=False):
"""
Use model pipeline with conservative and model-appropriate generation settings.
short_target: if True use shorter min/max suitable for concise outputs
"""
    # Heuristic: if the pipeline wraps LED (judging by the model's configured
    # name), allow a larger max_length.
    model_name = (getattr(pipe.model.config, "name_or_path", "") or "").lower()
    is_led = "led" in model_name or "longformer" in model_name
if short_target:
min_l = 12
max_l = 60
else:
min_l = 24
max_l = 140 if not is_led else 400 # LED can handle longer outputs
gen_kwargs = {
"min_length": min_l,
"max_length": max_l,
"num_beams": 5 if not is_led else 4,
"early_stopping": True,
"no_repeat_ngram_size": 3,
"do_sample": False,
}
return pipe(text_prompt, **gen_kwargs)[0]["summary_text"].strip()
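# Usage (hedged sketch): per-chunk passes use the tight budget, the final pass
# the wider one (up to ~400 tokens when the pipeline wraps LED):
#   summarize_with_model(get_summarizer("led"), prompt, short_target=False)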
# -------------------------
# Routes
# -------------------------
@app.route("/")
def home():
return render_template("index.html")
@app.route("/summarize", methods=["POST"])
def summarize_route():
t0 = time.time()
data = request.get_json(force=True) or {}
text = (data.get("text") or "")[:90000]
user_model_pref = (data.get("model") or "auto").lower() # 'pegasus' | 'led' | 'auto'
requested_length = (data.get("length") or "auto").lower() # short|medium|long|auto
requested_tone = (data.get("tone") or "auto").lower() # neutral|formal|casual|bullet|auto
if not text or len(text.split()) < 5:
return jsonify({"error":"Input too short."}), 400
# 1) Decide settings (AI or explicit)
if requested_length in ("auto","ai") or requested_tone in ("auto","ai"):
cfg = generate_summarization_config(text)
length_choice = cfg.get("length","medium")
tone_choice = cfg.get("tone","neutral")
preset_min = cfg.get("min_length")
preset_max = cfg.get("max_length")
else:
length_choice = requested_length if requested_length in ("short","medium","long") else "medium"
tone_choice = requested_tone if requested_tone in ("neutral","formal","casual","bullet") else "neutral"
preset_min = LENGTH_PRESETS.get(length_choice, LENGTH_PRESETS["medium"])["min_length"]
preset_max = LENGTH_PRESETS.get(length_choice, LENGTH_PRESETS["medium"])["max_length"]
# 2) Model selection (user preference or auto)
# auto rules: if user specifically asked 'led' or param-generator picked long / input is very long -> led
words_len = len(text.split())
prefer_led = False
if user_model_pref == "led":
prefer_led = True
elif user_model_pref == "pegasus":
prefer_led = False
else: # auto
if length_choice == "long" or words_len > 3000:
prefer_led = True
else:
prefer_led = False
model_key = "led" if prefer_led else "pegasus"
# get the pipeline (lazy load)
try:
summarizer_pipe = get_summarizer(model_key)
except Exception as e:
logger.exception("Failed to load summarizer '%s': %s", model_key, e)
# fallback to pegasus if led fails
summarizer_pipe = get_summarizer("pegasus")
model_key = "pegasus"
# 3) Prefilter very long inputs (if not using LED)
if not prefer_led and words_len > 2500:
text_for_chunks = extractive_prefilter(text, top_k=40)
else:
text_for_chunks = text
# 4) Chunking: choose chunk size depending on model
if model_key == "led":
chunk_max_chars = 8000 # LED can handle larger chunks
chunk_overlap = 400
else:
chunk_max_chars = 1400
chunk_overlap = 200
chunks = chunk_text_by_chars(text_for_chunks, max_chars=chunk_max_chars, overlap=chunk_overlap)
# 5) Summarize each chunk
chunk_summaries = []
for chunk in chunks:
chunk_target = 1 if length_choice == "short" else 2
chunk_tone = tone_choice if tone_choice in ("formal","casual","bullet") else "neutral"
prompt = apply_tone_instruction(chunk, chunk_tone, target_sentences=chunk_target)
try:
# choose short_target True for tiny chunk summaries
out = summarize_with_model(summarizer_pipe, prompt, short_target=(chunk_target==1))
except Exception as e:
logger.exception("Chunk summarization failed, using extractive fallback: %s", e)
out = extractive_prefilter(chunk, top_k=3)
chunk_summaries.append(out)
    # 6) Combine and refine with the same model for consistency
    # (model_key is already either "led" or "pegasus", so no remapping is needed)
    refine_model_key = model_key
refine_pipe = get_summarizer(refine_model_key)
final_target_sentences = {"short":1,"medium":3,"long":6}.get(length_choice,3)
final = refine_and_combine(chunk_summaries, tone_choice, final_target_sentences, summarizer_pipe=refine_pipe)
# 7) Post-process bullet tone
if tone_choice == "bullet":
parts = re.split(r'[\n\r]+|(?:\.\s+)|(?:;\s+)', final)
bullets = [f"- {p.strip().rstrip('.')}" for p in parts if p.strip()]
final = "\n".join(bullets[:20])
elapsed = time.time() - t0
meta = {
"length_choice": length_choice,
"tone": tone_choice,
"model_used": model_key,
"refine_model": refine_model_key,
"chunks": len(chunks),
"input_words": words_len,
"time_seconds": round(elapsed, 2),
"device": ("gpu" if USE_GPU else "cpu")
}
return jsonify({"summary": final, "meta": meta})
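# Example client call (hedged sketch; assumes the server is reachable on the
# port configured below):
#   curl -X POST http://localhost:7860/summarize \
#        -H "Content-Type: application/json" \
#        -d '{"text": "<article text>", "model": "auto", "length": "auto", "tone": "bullet"}'
#
# Example response shape:
#   {"summary": "- point one\n- point two", "meta": {"model_used": "pegasus", ...}}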
# -------------------------
# Run
# -------------------------
if __name__ == "__main__":
# debug=False for production; use Gunicorn in deployment
app.run(host="0.0.0.0", port=7860, debug=False)