File size: 3,299 Bytes
d83d604
98ac441
d1155a6
5813702
020246b
7245c1f
 
ff5002a
d1155a6
7245c1f
020246b
7245c1f
020246b
73fdc8f
020246b
5813702
861709c
daec533
861709c
020246b
 
5813702
73fdc8f
5813702
0c65a3b
7245c1f
98ac441
7245c1f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d83d604
020246b
 
6ed741d
5813702
 
 
98ac441
7245c1f
 
 
 
 
d1155a6
7245c1f
5813702
 
 
 
 
53d5734
7b5b68f
fd8e8ce
 
7245c1f
 
fd8e8ce
 
020246b
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
import gradio as gr
import re
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from spaces import GPU  # Required for ZeroGPU Spaces
from nltk.corpus import stopwords
from nltk.tokenize import word_tokenize, sent_tokenize
import nltk

# Download the NLTK data this module relies on (no-op when already up to date).
nltk.download("stopwords")
nltk.download('punkt')
# NLTK 3.8.2+ loads the sentence tokenizer from the "punkt_tab" package instead
# of the pickled "punkt" one, so sent_tokenize() raises LookupError there unless
# it is fetched too. nltk.download() reports failure by return value rather
# than raising, so this is expected to be harmless on older releases.
nltk.download("punkt_tab")
# English stop words as a set for O(1) membership tests in clean_text().
stop_words = set(stopwords.words("english"))

# Model list
# Maps each dropdown label, formatted as "<display name> (<model id>)", to the
# model id it stands for. Insertion order is the order shown in the dropdown.
model_choices = {
    f"{display_name} ({model_id})": model_id
    for display_name, model_id in (
        ("DistilBART CNN", "sshleifer/distilbart-cnn-12-6"),
        ("T5 Small", "t5-small"),
        ("T5 Base", "t5-base"),
        ("Pegasus XSum", "google/pegasus-xsum"),
        ("BART CNN", "facebook/bart-large-cnn"),
    )
}

# Starts empty; nothing in the visible code reads or writes it.
model_cache = dict()

# Clean text: remove special characters, stop words, SKU codes, and short words
def clean_text(input_text):
    """Reduce free text to its meaningful ASCII alphanumeric words.

    Processing steps:
      1. Every character that is not an ASCII letter, digit or whitespace is
         replaced by a space.
      2. The text is split on whitespace and a word is dropped when it is
         two characters or shorter, is an English stop word
         (case-insensitive), or looks like a product/SKU code: two or more
         letters followed by three or more digits (e.g. ST1642, AB1234).
      3. The surviving words are joined with single spaces.

    Args:
        input_text: Raw text to clean. ``None`` or an empty string yields "".

    Returns:
        The cleaned text, with no leading, trailing or repeated spaces.
    """
    if not input_text:
        return ""

    # Step 1: blank out anything that is not ASCII alphanumeric or whitespace.
    ascii_only = re.sub(r"[^A-Za-z0-9\s]", " ", input_text)

    # Step 2: filter word by word. SKU codes are rejected here, per token,
    # rather than deleted from the joined string afterwards: deleting them
    # from the joined string left a double space wherever a code sat between
    # two kept words. After step 1 every token is purely alphanumeric, so a
    # whole-token match is equivalent to the former \b-delimited search.
    kept_words = [
        word
        for word in ascii_only.split()
        if len(word) > 2
        and word.lower() not in stop_words
        and not re.fullmatch(r"[A-Za-z]{2,}[0-9]{3,}", word)
    ]

    # Step 3: rebuild the text; join() cannot produce stray whitespace.
    return " ".join(kept_words)

# Extractive Summarization: Select sentences directly from the input text
def extractive_summary(input_text, num_sentences=2):
    """Return the leading sentences of ``input_text`` joined by single spaces.

    The text is split with NLTK's ``sent_tokenize``. Sentences made of two or
    fewer whitespace-separated words are discarded, and the first
    ``num_sentences`` of the remaining ones are kept in their original order.
    """
    substantial = []
    for candidate in sent_tokenize(input_text):
        if len(candidate.split()) <= 2:
            continue  # too short to carry meaning on its own
        substantial.append(candidate)

    leading = substantial[:num_sentences]
    return " ".join(leading)

# Main function triggered by Gradio
@GPU  # 👈 Required for ZeroGPU to trigger GPU spin-up
def summarize_text(input_text, model_label, char_limit):
    """Build a short extractive summary of ``input_text``.

    Args:
        input_text: Text entered by the user.
        model_label: Label chosen in the model dropdown. It is accepted to
            match the UI inputs but is not used: the summary is purely
            extractive and no model is loaded.
        char_limit: Maximum number of characters in the result. Converted
            with ``int()`` before slicing, so a float slider value is
            accepted.

    Returns:
        "Please enter some text." when the input is missing or blank;
        otherwise the cleaned leading sentences, truncated to ``char_limit``
        characters and stripped of surrounding whitespace.
    """
    if not input_text or not input_text.strip():
        return "Please enter some text."

    # For extractive summarization, we don't use the models that generate new tokens.
    # Sentences must be selected from the RAW text: clean_text() strips all
    # punctuation, so tokenizing its output would find no sentence boundaries
    # and treat the whole input as a single sentence.
    selected_sentences = extractive_summary(input_text)
    summary = clean_text(selected_sentences)

    # Truncate summary based on the character limit; slice indices must be
    # integers, so coerce in case the slider delivers a float.
    return summary[:int(char_limit)].strip()

# Gradio UI
# Components are created in the same order as they are wired into the
# interface: the three inputs first, then the output box.
_text_input = gr.Textbox(lines=6, label="Enter text to summarize")
_model_picker = gr.Dropdown(
    choices=list(model_choices.keys()),
    label="Choose summarization model",
    value="DistilBART CNN (sshleifer/distilbart-cnn-12-6)",
)
_limit_slider = gr.Slider(
    minimum=30,
    maximum=200,
    value=65,
    step=1,
    label="Max Character Limit",
)
_summary_output = gr.Textbox(lines=3, label="Summary (truncated to character limit)")

# The input order matches summarize_text(input_text, model_label, char_limit).
iface = gr.Interface(
    fn=summarize_text,
    inputs=[_text_input, _model_picker, _limit_slider],
    outputs=_summary_output,
    title="🔥 Fast Summarizer (Extractive Only)",
    description="Summarizes input by selecting key sentences from the input text, without generating new tokens.",
)

# Start serving the interface as soon as the script runs.
iface.launch()