mahmoudsaber0 committed
Commit 54cad9d · verified · 1 Parent(s): dec54ad

Update app.py

Files changed (1)
app.py +157 -118
app.py CHANGED
@@ -1,153 +1,192 @@
 
 
  import torch
  import re
- import io
- import base64
  import matplotlib
  matplotlib.use("Agg")
  import matplotlib.pyplot as plt
- from transformers import AutoTokenizer, AutoModelForSequenceClassification
- import gradio as gr
-
- # ===============================
- # Safe model loading
- # ===============================
- device = "cuda" if torch.cuda.is_available() else "cpu"
-
- def safe_load_model(name):
-     try:
-         return AutoModelForSequenceClassification.from_pretrained(name).to(device)
-     except Exception as e:
-         print(f"[WARN] Failed to load {name}: {e}")
-         return None

  print("Loading models...")
-
- model_1 = safe_load_model("roberta-base-openai-detector")
- model_2 = safe_load_model("roberta-large-openai-detector")
- model_3 = safe_load_model("Hello-SimpleAI/chatgpt-detector-roberta")
- tokenizer = AutoTokenizer.from_pretrained("roberta-base-openai-detector")
-
- label_mapping = {i: f"model_{i}" for i in range(25)}
- label_mapping[24] = "Human"
-
- # ===============================
- # Helper functions
- # ===============================
- def clean_text(text):
-     return re.sub(r'\s+', ' ', text).strip()
-
- def plot_to_base64(fig):
-     buf = io.BytesIO()
-     fig.savefig(buf, format="png", bbox_inches="tight")
-     buf.seek(0)
-     img_base64 = base64.b64encode(buf.getvalue()).decode("utf-8")
-     plt.close(fig)
-     return f"<img src='data:image/png;base64,{img_base64}' style='max-width:100%; border-radius:8px;'>"
-
- # ===============================
- # Main classification logic
- # ===============================
  def classify_text(text):
      cleaned_text = clean_text(text)
-     if not cleaned_text:
-         return "<b style='color:red'>Please enter some text.</b>"
-
-     paragraphs = re.split(r'\n{2,}', cleaned_text)
-     if len(paragraphs) == 1 and len(cleaned_text.split()) > 300:
-         words = cleaned_text.split()
-         paragraphs = [' '.join(words[i:i + 300]) for i in range(0, len(words), 300)]

-     paragraph_scores = []
      all_probabilities = []

-     for i, para in enumerate(paragraphs):
-         inputs = tokenizer(para, return_tensors="pt", truncation=True, padding=True).to(device)
-         softmax_outputs = []

-         for m in [model_1, model_2, model_3]:
-             if m is None:
-                 continue
-             with torch.no_grad():
-                 logits = m(**inputs).logits
-             softmax_outputs.append(torch.softmax(logits, dim=1))

-         if not softmax_outputs:
-             return "<b style='color:red'>Error: No models loaded successfully.</b>"
-
-         avg_probs = sum(softmax_outputs) / len(softmax_outputs)
-         probabilities = avg_probs[0]
-         all_probabilities.append(probabilities.cpu())

          human_prob = probabilities[24].item()
-         ai_probs = probabilities.clone()
-         ai_probs[24] = 0
-         ai_prob = ai_probs.sum().item()

-         total = human_prob + ai_prob
          human_pct = (human_prob / total) * 100
-         ai_pct = (ai_prob / total) * 100
-         ai_model = label_mapping[torch.argmax(ai_probs).item()]

-         preview = para[:180].strip() + ("..." if len(para) > 180 else "")
-         paragraph_scores.append({
-             "id": i + 1,
              "human": human_pct,
              "ai": ai_pct,
-             "model": ai_model,
-             "preview": preview
          })

-     avg_human = sum(p["human"] for p in paragraph_scores) / len(paragraph_scores)
-     avg_ai = sum(p["ai"] for p in paragraph_scores) / len(paragraph_scores)
-
      if avg_human > avg_ai:
-         overall = f"<b>Overall Result:</b> <span class='highlight-human'>{avg_human:.2f}% Human-written</span>"
      else:
-         top_model = max(paragraph_scores, key=lambda p: p["ai"])["model"]
-         overall = f"<b>Overall Result:</b> <span class='highlight-ai'>{avg_ai:.2f}% AI-generated (likely {top_model})</span>"

-     # --- Top 5 chart ---
      mean_probs = torch.mean(torch.stack(all_probabilities), dim=0)
      top_5_probs, top_5_indices = torch.topk(mean_probs, 5)
-     labels = [label_mapping[i.item()] for i in top_5_indices]
-     values = top_5_probs.cpu().numpy()

-     fig, ax = plt.subplots(figsize=(8, 4))
-     ax.barh(labels, values, color='#4CAF50')
-     ax.set_xlabel("Probability")
-     ax.set_title("Top 5 Model Predictions")
      ax.invert_yaxis()
-     chart_html = plot_to_base64(fig)
-
-     # --- Paragraph breakdown ---
-     html = f"<div style='font-family:Arial, sans-serif; line-height:1.6'>{overall}<br><br>"
-     html += "<h3>Paragraph Analysis:</h3>"
-
-     for p in paragraph_scores:
-         color = "#28a745" if p["human"] > p["ai"] else "#FF5733"
-         html += f"""
-         <div style='margin-bottom:10px; border-left:5px solid {color}; padding-left:10px; background:#f9f9f9; border-radius:6px;'>
-             <b>Paragraph {p["id"]}</b>: {p["human"]:.2f}% Human | {p["ai"]:.2f}% AI → <i>{p["model"]}</i><br>
-             <small>{p["preview"]}</small>
-         </div>
-         """

-     html += "<br><h3>Top 5 Models:</h3>" + chart_html + "</div>"
-     return html

- # ===============================
- # Gradio UI
- # ===============================
- css = """
- .highlight-ai { color: #FF5733; font-weight: bold; }
- .highlight-human { color: #28a745; font-weight: bold; }
  """

- with gr.Blocks(css=css, theme="soft") as demo:
-     gr.Markdown("# 🧠 AI vs Human Text Detector")
-     txt = gr.Textbox(label="Paste your article", lines=12, placeholder="Enter your full text here...")
-     btn = gr.Button("Analyze", variant="primary")
-     out = gr.HTML(label="Results", elem_id="result-box")
-     btn.click(classify_text, inputs=txt, outputs=out)

- demo.launch()
+ import gradio as gr
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
  import torch
  import re
  import matplotlib
  matplotlib.use("Agg")
  import matplotlib.pyplot as plt
+ from tokenizers.normalizers import Sequence, Replace, Strip
+ from tokenizers import Regex
+
+ # -------------------------
+ # Device setup
+ # -------------------------
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
+ # -------------------------
+ # Model and Tokenizer Setup
+ # -------------------------
+ model1_path = "modernbert.bin"
+ model2_path = "https://huggingface.co/mihalykiss/modernbert_2/resolve/main/Model_groups_3class_seed12"
+ model3_path = "https://huggingface.co/mihalykiss/modernbert_2/resolve/main/Model_groups_3class_seed22"
+
+ tokenizer = AutoTokenizer.from_pretrained("answerdotai/ModernBERT-base")
+
+ def safe_load_model(base_name, weights_path):
+     # Build the 41-class ModernBERT classifier, then overlay the fine-tuned
+     # weights from either a remote URL or a local checkpoint file.
+     model = AutoModelForSequenceClassification.from_pretrained(base_name, num_labels=41)
+     if weights_path.startswith("http"):
+         state_dict = torch.hub.load_state_dict_from_url(weights_path, map_location=device)
+     else:
+         state_dict = torch.load(weights_path, map_location=device)
+     model.load_state_dict(state_dict)
+     model.to(device).eval()
+     return model
 
  print("Loading models...")
+ model_1 = safe_load_model("answerdotai/ModernBERT-base", model1_path)
+ model_2 = safe_load_model("answerdotai/ModernBERT-base", model2_path)
+ model_3 = safe_load_model("answerdotai/ModernBERT-base", model3_path)
+
+ # -------------------------
+ # Label Mapping
+ # -------------------------
+ label_mapping = {
+     0: '13B', 1: '30B', 2: '65B', 3: '7B', 4: 'GLM130B', 5: 'bloom_7b',
+     6: 'bloomz', 7: 'cohere', 8: 'davinci', 9: 'dolly', 10: 'dolly-v2-12b',
+     11: 'flan_t5_base', 12: 'flan_t5_large', 13: 'flan_t5_small',
+     14: 'flan_t5_xl', 15: 'flan_t5_xxl', 16: 'gemma-7b-it', 17: 'gemma2-9b-it',
+     18: 'gpt-3.5-turbo', 19: 'gpt-35', 20: 'gpt4', 21: 'gpt4o',
+     22: 'gpt_j', 23: 'gpt_neox', 24: 'human', 25: 'llama3-70b', 26: 'llama3-8b',
+     27: 'mixtral-8x7b', 28: 'opt_1.3b', 29: 'opt_125m', 30: 'opt_13b',
+     31: 'opt_2.7b', 32: 'opt_30b', 33: 'opt_350m', 34: 'opt_6.7b',
+     35: 'opt_iml_30b', 36: 'opt_iml_max_1.3b', 37: 't0_11b', 38: 't0_3b',
+     39: 'text-davinci-002', 40: 'text-davinci-003'
+ }
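+
+ # Index 24 ('human') is the only non-AI class; every other index names a
+ # candidate generator, so the AI score is the probability mass outside 24.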
+
+ # -------------------------
+ # Text Cleaning
+ # -------------------------
+ def clean_text(text: str) -> str:
+     # Collapse runs of spaces/tabs but keep newlines, so the paragraph
+     # split on '\n{2,}' in classify_text still sees blank lines.
+     text = re.sub(r'[ \t]{2,}', ' ', text)
+     text = re.sub(r'\s+([,.;:?!])', r'\1', text)
+     return text
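+
+ # e.g. clean_text("Hello  ,  world !") -> "Hello, world!"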
+
+ newline_to_space = Replace(Regex(r'\s*\n\s*'), " ")
+ tokenizer.backend_tokenizer.normalizer = Sequence([
+     tokenizer.backend_tokenizer.normalizer,
+     newline_to_space,
+     Strip()
+ ])
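+
+ # Chaining the stock normalizer with a newline-to-space Replace and a final
+ # Strip flattens any line breaks left inside a paragraph at tokenization
+ # time, so the models never see raw newline characters.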
+
+ # -------------------------
+ # Classification Function
+ # -------------------------
  def classify_text(text):
      cleaned_text = clean_text(text)
+     if not cleaned_text.strip():
+         return "<b style='color:red;'>Please enter some text to analyze.</b>", None
+
+     paragraphs = [p.strip() for p in re.split(r'\n{2,}', cleaned_text) if p.strip()]
+     chunk_scores = []
      all_probabilities = []

+     for paragraph in paragraphs:
+         inputs = tokenizer(paragraph, return_tensors="pt", truncation=True, padding=True).to(device)
 
+         with torch.no_grad():
+             logits_1 = model_1(**inputs).logits
+             logits_2 = model_2(**inputs).logits
+             logits_3 = model_3(**inputs).logits
+
+         # Soft-voting ensemble: average the three checkpoints' softmax
+         # distributions rather than their raw logits or hard votes.
+         softmax_1 = torch.softmax(logits_1, dim=1)
+         softmax_2 = torch.softmax(logits_2, dim=1)
+         softmax_3 = torch.softmax(logits_3, dim=1)
+         averaged_probabilities = (softmax_1 + softmax_2 + softmax_3) / 3
+         probabilities = averaged_probabilities[0]
+         all_probabilities.append(probabilities.cpu())

          human_prob = probabilities[24].item()
+         ai_probs_clone = probabilities.clone()
+         ai_probs_clone[24] = 0
+         ai_total_prob = ai_probs_clone.sum().item()
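+         # The 41-way softmax sums to 1, so `total` below is ~1.0 and the
+         # percentage split mainly guards against numeric drift.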
+
+         total = human_prob + ai_total_prob
          human_pct = (human_prob / total) * 100
+         ai_pct = (ai_total_prob / total) * 100
+         ai_model = label_mapping[torch.argmax(ai_probs_clone).item()]
+
+         chunk_scores.append({
+             "paragraph": paragraph[:150] + ("..." if len(paragraph) > 150 else ""),
              "human": human_pct,
              "ai": ai_pct,
+             "model": ai_model
          })

+     # --- Overall ---
+     avg_human = sum(c["human"] for c in chunk_scores) / len(chunk_scores)
+     avg_ai = sum(c["ai"] for c in chunk_scores) / len(chunk_scores)
      if avg_human > avg_ai:
+         result_message = f"<b>Overall Result:</b> <span class='highlight-human'>{avg_human:.2f}% Human-written</span>"
      else:
+         top_model = max(chunk_scores, key=lambda c: c['ai'])['model']
+         result_message = f"<b>Overall Result:</b> <span class='highlight-ai'>{avg_ai:.2f}% AI-generated (likely {top_model})</span>"
+
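+     # The overall verdict is an unweighted mean over paragraphs, so a short
+     # paragraph counts as much as a long one.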
+     # --- Paragraph Breakdown ---
+     paragraph_html = "<h3>Paragraph Analysis:</h3>"
+     for idx, c in enumerate(chunk_scores, 1):
+         color = "#4CAF50" if c['human'] > c['ai'] else "#FF5733"
+         paragraph_html += f"""
+         <div style='margin-bottom:10px; border-left:4px solid {color}; padding-left:10px;'>
+             <b>Paragraph {idx}</b>: {c['human']:.2f}% Human | {c['ai']:.2f}% AI → <i>{c['model']}</i><br>
+             <small>{c['paragraph']}</small></div>
+         """

+     # --- Plot ---
      mean_probs = torch.mean(torch.stack(all_probabilities), dim=0)
      top_5_probs, top_5_indices = torch.topk(mean_probs, 5)
+     top_5_probs = top_5_probs.cpu().numpy()
+     top_5_labels = [label_mapping[i.item()] for i in top_5_indices]
+
+     fig, ax = plt.subplots(figsize=(10, 5))
+     bars = ax.barh(top_5_labels, top_5_probs, color='#4CAF50')
+     ax.set_xlabel('Probability')
+     ax.set_title('Top 5 Model Predictions')
      ax.invert_yaxis()
+     for bar in bars:
+         width = bar.get_width()
+         ax.text(width + 0.005, bar.get_y() + bar.get_height() / 2, f'{width:.2%}', va='center')
+     plt.tight_layout()
+
+     return result_message + "<br><br>" + paragraph_html, fig

+ # -------------------------
+ # UI Setup
+ # -------------------------
+ title = "AI Text Detector"
+ description = """
+ This tool uses <b>ModernBERT</b> to detect AI-generated text.<br>
+ Each paragraph is analyzed separately to show which parts are likely AI-generated.
  """
+ bottom_text = "**Developed by SzegedAI – Extended by Saber**"
+
+ AI_texts = [
+     "Artificial intelligence (AI) is reshaping industries by automating tasks, enhancing decision-making, and driving innovation. From predictive analytics in finance to autonomous vehicles in transportation, AI technologies are becoming integral to daily operations."
+ ]
+
+ Human_texts = [
+     "Mathematics has always been a cornerstone of scientific discovery. It provides a precise language for describing natural phenomena, from the orbit of planets to the behavior of subatomic particles."
+ ]
+
+ iface = gr.Blocks(css="""
+ @import url('https://fonts.googleapis.com/css2?family=Roboto+Mono:wght@400;700&display=swap');
+ body { font-family: 'Roboto Mono', monospace !important; }
+ .highlight-human { color: #4CAF50; font-weight: bold; }
+ .highlight-ai { color: #FF5733; font-weight: bold; }
+ """)
+
+ with iface:
+     gr.Markdown(f"# {title}")
+     gr.Markdown(description)
+     text_input = gr.Textbox(label="", placeholder="Paste your article here...", lines=10)
+     analyze_btn = gr.Button("🔍 Analyze", variant="primary")
+     result_output = gr.HTML(label="Result")
+     plot_output = gr.Plot(label="Model Probability Distribution")
+
+     analyze_btn.click(classify_text, inputs=text_input, outputs=[result_output, plot_output])
+
+     with gr.Tab("AI Examples"):
+         gr.Examples(AI_texts, inputs=text_input)
+     with gr.Tab("Human Examples"):
+         gr.Examples(Human_texts, inputs=text_input)

+     gr.Markdown(bottom_text)
+
+ iface.launch(share=True)
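
A quick way to exercise the new classify_text path without launching the UI is to call it directly — a minimal sketch, assuming the three checkpoints loaded successfully (the sample text and output filename are illustrative):

    sample = (
        "First paragraph of a test article.\n\n"
        "Second paragraph, which is scored on its own."
    )
    html, fig = classify_text(sample)
    print(html[:200])        # overall verdict plus per-paragraph breakdown
    fig.savefig("top5.png")  # the Top-5 model-probability chart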