import gradio as gr
import re
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from spaces import GPU # Required for ZeroGPU Spaces
from nltk.corpus import stopwords
from nltk.tokenize import word_tokenize, sent_tokenize
import nltk

# Download required NLTK data (stopwords and the punkt sentence tokenizer) if not already available
nltk.download("stopwords")
nltk.download("punkt")

stop_words = set(stopwords.words("english"))

# Summarization model choices (UI label -> Hugging Face model ID)
model_choices = {
    "DistilBART CNN (sshleifer/distilbart-cnn-12-6)": "sshleifer/distilbart-cnn-12-6",
    "T5 Small (t5-small)": "t5-small",
    "T5 Base (t5-base)": "t5-base",
    "Pegasus XSum (google/pegasus-xsum)": "google/pegasus-xsum",
    "BART CNN (facebook/bart-large-cnn)": "facebook/bart-large-cnn",
}

model_cache = {}

# Clean text: remove special characters, stop words, SKU/product codes, and very short words
def clean_text(input_text):
    # Step 1: Replace any non-alphanumeric character (special symbols, punctuation, non-Latin characters) with a space
    cleaned_text = re.sub(r"[^A-Za-z0-9\s]", " ", input_text)
    cleaned_text = re.sub(r"\s+", " ", cleaned_text)  # Collapse multiple spaces into one

    # Step 2: Tokenize the text and drop stopwords and words too short to be meaningful
    words = cleaned_text.split()
    filtered_words = [word for word in words if word.lower() not in stop_words and len(word) > 2]

    # Step 3: Rebuild the text from the remaining words
    filtered_text = " ".join(filtered_words)

    # Step 4: Remove product/SKU codes (e.g., ST1642, AB1234)
    filtered_text = re.sub(r"\b[A-Za-z]{2,}[0-9]{3,}\b", "", filtered_text)

    # Step 5: Strip leading/trailing spaces
    filtered_text = filtered_text.strip()
    return filtered_text
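
# Illustrative example (hypothetical input, traced by hand, not part of the original app):
#   clean_text("The new ST1642 heater is on sale!") -> "new  heater sale"
#   "The", "is" and "on" are stopwords, and "ST1642" matches the SKU pattern; note that
#   removing the SKU after whitespace collapsing can leave a double space behind.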

# Extractive summarization: select sentences directly from the input text
def extractive_summary(input_text, num_sentences=2):
    sentences = sent_tokenize(input_text)  # Tokenize into sentences
    # Filter out very short sentences
    filtered_sentences = [sentence for sentence in sentences if len(sentence.split()) > 2]
    return " ".join(filtered_sentences[:num_sentences])  # Return the first `num_sentences` sentences

# Main function triggered by Gradio
@GPU  # 👈 Required for ZeroGPU to trigger GPU spin-up
def summarize_text(input_text, model_label, char_limit):
    if not input_text.strip():
        return "Please enter some text."

    input_text = clean_text(input_text)

    # Extractive-only path: the selected model is ignored and no new tokens are generated
    summary = extractive_summary(input_text)

    # Truncate the summary to the requested character limit
    return summary[:char_limit].strip()

# Gradio UI
iface = gr.Interface(
    fn=summarize_text,
    inputs=[
        gr.Textbox(lines=6, label="Enter text to summarize"),
        gr.Dropdown(
            choices=list(model_choices.keys()),
            label="Choose summarization model",
            value="DistilBART CNN (sshleifer/distilbart-cnn-12-6)",
        ),
        gr.Slider(minimum=30, maximum=200, value=65, step=1, label="Max Character Limit"),
    ],
    outputs=gr.Textbox(lines=3, label="Summary (truncated to character limit)"),
    title="🔥 Fast Summarizer (Extractive Only)",
    description="Summarizes input by selecting key sentences from the input text, without generating new tokens.",
)

iface.launch()