mahmoudsaber0 committed on
Commit
d23c0fb
·
verified ·
1 Parent(s): 4edb764

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +231 -192
app.py CHANGED
@@ -200,204 +200,264 @@ class ModelManager:
200
  # استخدام hf_hub_download بدلاً من torch.hub للـ HF repos
201
  logger.info(f"🌐 Downloading weights from HF repo...")
202
  repo_id = "mihalykiss/modernbert_2"
203
- filename = model_url.split('/')[-1] # Extract filename like "Model_groups_3class_seed12"
204
- pt_file = hf_hub_download(
205
  repo_id=repo_id,
206
  filename=filename,
207
- cache_dir=CACHE_DIR,
208
- local_dir_use_symlinks=False
209
  )
210
- state_dict = torch.load(pt_file, map_location=device, weights_only=True)
211
-
212
- # تحميل الأوزان فقط إذا لم نكن في وضع fallback (لأن ModernBERT weights قد لا تتوافق مع BERT القياسي)
213
- if not self.using_fallback:
214
- base_model.load_state_dict(state_dict, strict=False)
215
- logger.info("✅ Weights loaded successfully")
216
- else:
217
- logger.warning("⚠️ Skipping weight load in fallback mode (incompatible architecture)")
218
- else:
219
- logger.info("📊 Using model with random initialization")
220
- except Exception as weight_error:
221
- logger.warning(f"⚠️ Could not load weights: {weight_error}")
222
- logger.info("📊 Continuing with base model (random or pre-trained init)")
223
-
224
- # نقل الموديل للجهاز المناسب
225
- model = base_model.to(device)
226
- model.eval()
227
-
228
- # تنظيف الذاكرة
229
- if 'state_dict' in locals():
230
- del state_dict
231
- gc.collect()
232
- if torch.cuda.is_available():
233
- torch.cuda.empty_cache()
234
-
235
- logger.info(f"✅ {model_name} loaded successfully (fallback: {self.using_fallback})")
236
- return model
237
 
238
def load_models(self, max_models=2):
    """Load the tokenizer and then up to ``max_models`` models.

    A local checkpoint named ``modernbert.bin`` is tried first when that
    file exists; the remaining slots are filled from ``self.model_urls``.
    When CUDA is available, loading stops early once more than 6 GB of GPU
    memory is reported as allocated.

    Args:
        max_models: Upper bound on the number of models kept in
            ``self.models``.

    Returns:
        True when models were already loaded or at least one model is
        available afterwards; False when the tokenizer or every model
        failed to load.
    """
    if self.models_loaded:
        logger.info("✨ Models already loaded")
        return True

    # Nothing can be classified without a tokenizer, so bail out early.
    tokenizer_ready = self.load_tokenizer()
    if not tokenizer_ready:
        logger.error("❌ Tokenizer load failed - cannot proceed")
        return False

    logger.info(f"🚀 Loading up to {max_models} models...")

    # Prefer a checkpoint shipped next to the application, if present.
    local_model_path = "modernbert.bin"
    if os.path.exists(local_model_path):
        local_model = self.load_single_model(
            model_path=local_model_path,
            model_name="Model 1 (Local)",
        )
        if local_model is not None:
            self.models.append(local_model)

    # Fill the remaining slots from the configured URLs.
    remaining_slots = max_models - len(self.models)
    for full_url in self.model_urls[:remaining_slots]:
        if len(self.models) >= max_models:
            break

        # load_single_model receives the full URL as-is.
        remote_model = self.load_single_model(
            model_url=full_url,
            model_name=f"Model {len(self.models) + 1}",
        )
        if remote_model is None:
            continue
        self.models.append(remote_model)

        # NOTE(review): the nesting of this memory check was reconstructed
        # from an indentation-flattened diff; confirm it belongs to the
        # successful-load path.
        if not torch.cuda.is_available():
            continue
        mem_allocated = torch.cuda.memory_allocated() / 1024**3
        mem_reserved = torch.cuda.memory_reserved() / 1024**3
        logger.info(f"💾 GPU Memory: {mem_allocated:.2f}GB allocated, {mem_reserved:.2f}GB reserved")
        if mem_allocated > 6:  # hard cap of 6 GB
            logger.warning("⚠️ Memory limit reached, stopping model loading")
            break

    if not self.models:
        logger.error("❌ No models could be loaded")
        return False

    self.models_loaded = True
    logger.info(f"✅ Successfully loaded {len(self.models)} models (using fallback: {self.using_fallback})")
    return True
294
 
295
def classify_text(self, text: str) -> Dict:
    """Classify ``text`` with every loaded model and soft-vote the results.

    The text is normalized with ``clean_text``, tokenized with truncation
    at 512 tokens, and passed through each model in ``self.models``. The
    per-model softmax outputs are averaged; class index 24 is treated as
    the human class and all other classes count towards the AI share.

    Args:
        text: Raw text to analyze.

    Returns:
        Dict with the keys ``human_percentage``, ``ai_percentage``,
        ``predicted_model``, ``top_5_predictions``, ``is_human``,
        ``models_used`` and ``using_fallback``.

    Raises:
        ValueError: If no models are loaded, the cleaned text is empty,
            tokenization fails, or every model fails to predict.
    """
    if not (self.models_loaded and self.models):
        raise ValueError("No models loaded")

    # Normalize whitespace before tokenizing.
    cleaned_text = clean_text(text)
    if not cleaned_text.strip():
        raise ValueError("Empty text after cleaning")

    # The same limit applies with or without the fallback model.
    max_len = 512
    try:
        inputs = self.tokenizer(
            cleaned_text,
            return_tensors="pt",
            truncation=True,
            max_length=max_len,
            padding=True,
        ).to(device)
    except Exception as e:
        logger.error(f"Tokenization error: {e}")
        raise ValueError(f"Failed to tokenize text: {e}")

    # Collect one probability tensor per model; a failing model is skipped.
    all_probabilities = []
    with torch.no_grad():
        for i, model in enumerate(self.models):
            try:
                model_logits = model(**inputs).logits
                all_probabilities.append(torch.softmax(model_logits, dim=1))
            except Exception as e:
                logger.warning(f"Model {i+1} prediction failed: {e}")

    if not all_probabilities:
        raise ValueError("All models failed to make predictions")

    # Soft voting: average the per-model distributions, keep the first row.
    probabilities = torch.stack(all_probabilities).mean(dim=0)[0]

    # Split the mass into the human class (index 24) and everything else.
    human_prob = probabilities[24].item()
    ai_probs = probabilities.clone()
    ai_probs[24] = 0  # drop the human probability
    ai_total_prob = ai_probs.sum().item()

    # Normalize to percentages; fall back to an even split on zero mass.
    total = human_prob + ai_total_prob
    if total > 0:
        human_percentage = (human_prob / total) * 100
        ai_percentage = (ai_total_prob / total) * 100
    else:
        human_percentage = 50
        ai_percentage = 50

    # Most likely generator among the non-human classes.
    ai_model_idx = torch.argmax(ai_probs).item()
    predicted_model = label_mapping.get(ai_model_idx, "Unknown")

    # Five highest-probability classes overall, as percentages.
    top_5_probs, top_5_indices = torch.topk(probabilities, 5)
    top_5_results = [
        {
            "model": label_mapping.get(idx.item(), "Unknown"),
            "probability": round(prob.item() * 100, 2),
        }
        for prob, idx in zip(top_5_probs, top_5_indices)
    ]

    return {
        "human_percentage": round(human_percentage, 2),
        "ai_percentage": round(ai_percentage, 2),
        "predicted_model": predicted_model,
        "top_5_predictions": top_5_results,
        "is_human": human_percentage > ai_percentage,
        "models_used": len(all_probabilities),
        "using_fallback": self.using_fallback,
    }
376
 
377
  # =====================================================
378
- # 🧹 دوال التنظيف والمعالجة
379
  # =====================================================
380
def clean_text(text: str) -> str:
    """Normalize whitespace in ``text``.

    Runs of two or more whitespace characters become a single space,
    whitespace directly before ``, . ; : ? !`` is removed, and the result
    is stripped at both ends.
    """
    collapsed = re.sub(r'\s{2,}', ' ', text)
    tightened = re.sub(r'\s+([,.;:?!])', r'\1', collapsed)
    return tightened.strip()
 
 
 
 
 
 
 
385
 
386
def split_into_paragraphs(text: str) -> List[str]:
    """Split ``text`` into paragraphs separated by blank lines.

    Returns the stripped, non-empty paragraphs in their original order.
    """
    chunks = []
    for raw in re.split(r'\n\s*\n', text.strip()):
        trimmed = raw.strip()
        if trimmed:
            chunks.append(trimmed)
    return chunks
390
 
391
  # =====================================================
392
  # 🌐 FastAPI Application
393
  # =====================================================
394
  app = FastAPI(
395
- title="ModernBERT AI Text Detector",
396
- description="كشف النصوص المكتوبة بواسطة الذكاء الاصطناعي",
397
- version="2.2.0" # Updated version with UID fix
398
  )
399
 
400
- # إضافة CORS للسماح بالاستخدام من المتصفح
401
  app.add_middleware(
402
  CORSMiddleware,
403
  allow_origins=["*"],
@@ -406,44 +466,18 @@ app.add_middleware(
406
  allow_headers=["*"],
407
  )
408
 
409
- # إنشاء مدير الموديلات
410
  model_manager = ModelManager()
411
 
412
  # =====================================================
413
- # 📝 نماذج البيانات (Pydantic Models)
414
- # =====================================================
415
# Request body accepted by the analyze_text endpoint.
# Documented with comments rather than a docstring so the generated schema
# description is left unchanged.
class TextInput(BaseModel):
    # Text to analyze.
    text: str
    # Optional flag, False by default; presumably enables per-paragraph
    # analysis -- confirm against the endpoint body.
    analyze_paragraphs: Optional[bool] = False
418
-
419
# Minimal request body that carries only the text.
# NOTE(review): the endpoint consuming it is not visible in this excerpt.
class SimpleTextInput(BaseModel):
    # Text to analyze.
    text: str
421
-
422
# Response envelope returned by analyze_text.
class DetectionResult(BaseModel):
    # Whether the request succeeded.
    success: bool
    # Status code reported in the body (analyze_text sends 200 on success).
    code: int
    # Message accompanying the result.
    message: str
    # Payload dictionary with the analysis details.
    data: Dict
427
-
428
- # =====================================================
429
- # 🎯 API Endpoints
430
  # =====================================================
431
  @app.on_event("startup")
432
  async def startup_event():
433
- """تحميل الموديلات عند بداية التشغيل"""
434
- logger.info("=" * 50)
435
- logger.info("🚀 Starting ModernBERT AI Detector...")
436
- logger.info(f"🐍 Python version: {sys.version}")
437
- logger.info(f"🔥 PyTorch version: {torch.__version__}")
438
- import transformers
439
- logger.info(f"🔧 Transformers version: {transformers.__version__}")
440
- logger.info("🛡️ UID Monkey Patch Applied (for Docker/Container)")
441
- logger.info("=" * 50)
442
-
443
- # محاولة تحميل الموديلات
444
- max_models = int(os.environ.get("MAX_MODELS", "2"))
445
- success = model_manager.load_models(max_models=max_models)
446
-
447
  if success:
448
  logger.info("✅ Application ready! (Fallback mode: %s)", model_manager.using_fallback)
449
  else:
@@ -555,13 +589,16 @@ async def analyze_text(data: TextInput):
555
  human_percentage = round(100 - ai_percentage, 2)
556
  ai_words = int(recalc_ai_words)
557
 
 
 
 
558
  # إنشاء رسالة التغذية الراجعة
559
  if ai_percentage > 50:
560
  feedback = "Most of Your Text is AI/GPT Generated"
561
  else:
562
  feedback = "Most of Your Text Appears Human-Written"
563
 
564
- # إرجاع النتائج بنفس تنسيق الكود الأصلي
565
  return DetectionResult(
566
  success=True,
567
  code=200,
@@ -578,7 +615,9 @@ async def analyze_text(data: TextInput):
578
  "detected_language": "en",
579
  "top_5_predictions": result.get("top_5_predictions", []),
580
  "models_used": result.get("models_used", 1),
581
- "using_fallback": result.get("using_fallback", False)
 
 
582
  }
583
  )
584
 
@@ -645,4 +684,4 @@ if __name__ == "__main__":
645
  port=port,
646
  workers=workers,
647
  reload=False # Set to True for dev
648
- )
 
200
  # استخدام hf_hub_download بدلاً من torch.hub للـ HF repos
201
  logger.info(f"🌐 Downloading weights from HF repo...")
202
  repo_id = "mihalykiss/modernbert_2"
203
+ filename = model_url.split("/")[-1]
204
+ local_path = hf_hub_download(
205
  repo_id=repo_id,
206
  filename=filename,
207
+ cache_dir=CACHE_DIR
 
208
  )
209
+ logger.info(f"✅ Downloaded to {local_path}")
210
+ state_dict = torch.load(local_path, map_location=device, weights_only=True)
211
+ base_model.load_state_dict(state_dict, strict=False)
212
+
213
+ logger.info(f"✅ {model_name} weights loaded successfully")
214
+ except Exception as e:
215
+ logger.warning(f"⚠️ Could not load custom weights for {model_name}: {e}")
216
+ logger.info("📌 Using base model without fine-tuned weights")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
217
 
218
+ # نقل للجهاز وضبط الوضع
219
+ try:
220
+ base_model = base_model.to(device)
221
+ base_model.eval()
222
+ logger.info(f"✅ {model_name} moved to {device} and set to eval mode")
223
+ return base_model
224
+ except Exception as e:
225
+ logger.error(f"❌ Failed to prepare {model_name}: {e}")
226
+ return None
227
+
228
def load_models(self):
    """Load the tokenizer and every model listed in ``self.model_urls``.

    Models that fail to load are skipped with a warning. Any exception
    raised along the way is logged and reported as a failure rather than
    propagated.

    Returns:
        True when models were already loaded or at least one model is
        available afterwards; False otherwise.
    """
    if self.models_loaded:
        return True

    try:
        # The tokenizer is a hard requirement for classification.
        tokenizer_ok = self.load_tokenizer()
        if not tokenizer_ok:
            return False

        # Try each configured model, keeping the ones that load.
        for i, model_url in enumerate(self.model_urls):
            loaded = self.load_single_model(
                model_url=model_url,
                model_name=f"Model {i+1}",
            )
            if loaded is not None:
                self.models.append(loaded)
            else:
                logger.warning(f"⚠️ Failed to load model {i+1}")

        if not self.models:
            logger.error("❌ No models loaded successfully")
            return False

        self.models_loaded = True
        logger.info(f"✅ Successfully loaded {len(self.models)} model(s)")
        return True

    except Exception as e:
        logger.error(f"❌ Model loading error: {e}", exc_info=True)
        return False
260
 
261
def classify_text(self, text: str, max_length: int = 512) -> Dict:
    """Classify ``text`` with the ensemble in ``self.models``.

    The logits of all models are averaged and then passed through a
    softmax. Class index 24 is the human class; the AI share is its
    complement.

    Args:
        text: Text to classify.
        max_length: Token limit passed to the tokenizer (with truncation).

    Returns:
        Dict with the keys ``ai_percentage``, ``human_percentage``,
        ``predicted_model``, ``top_5_predictions``, ``models_used`` and
        ``using_fallback``.

    Raises:
        RuntimeError: If the models or the tokenizer are not loaded.
        Exception: Any error raised during tokenization or inference is
            logged and re-raised unchanged.
    """
    if not (self.models_loaded and self.tokenizer):
        raise RuntimeError("Models or tokenizer not loaded")

    try:
        # Tokenize and move the encoded batch to the inference device.
        encoded = self.tokenizer(
            text,
            return_tensors="pt",
            truncation=True,
            max_length=max_length,
            padding=True,
        ).to(device)

        # Run every model without tracking gradients.
        with torch.no_grad():
            per_model_logits = [member(**encoded).logits for member in self.models]

        # Average the logits across models, then convert to probabilities.
        mean_logits = torch.stack(per_model_logits).mean(dim=0)
        probabilities = torch.nn.functional.softmax(mean_logits, dim=-1)
        row = probabilities[0]

        # Five most probable classes for the single input row.
        top_probs, top_indices = torch.topk(row, k=5)

        # Human probability sits at index 24; AI is everything else.
        human_prob = row[24].item()
        ai_prob = 1.0 - human_prob

        # The predicted label is the overall top-1 class.
        predicted_model = label_mapping.get(top_indices[0].item(), "unknown")

        top_5 = []
        for prob, idx in zip(top_probs, top_indices):
            top_5.append({
                "model": label_mapping.get(idx.item(), "unknown"),
                "probability": prob.item(),
            })

        return {
            "ai_percentage": round(ai_prob * 100, 2),
            "human_percentage": round(human_prob * 100, 2),
            "predicted_model": predicted_model,
            "top_5_predictions": top_5,
            "models_used": len(self.models),
            "using_fallback": self.using_fallback,
        }

    except Exception as e:
        logger.error(f"Classification error: {e}", exc_info=True)
        raise
319
+
320
+ # =====================================================
321
+ # 🆕 NEW HELPER FUNCTIONS - Content Cleaning & Splitting
322
+ # =====================================================
323
def clean_content_for_analysis(text: str, min_line_length: int = 30) -> str:
    """Drop short lines (headlines, captions, etc.) and join the rest.

    Each line is stripped of surrounding whitespace. A line is kept when
    it is non-empty and has at least ``min_line_length`` characters after
    stripping. Blank lines are always discarded, even when
    ``min_line_length`` is zero or negative, so the joined result never
    contains doubled spaces from empty entries.

    Args:
        text: Original text.
        min_line_length: Minimum stripped length for a line to be kept
            (default: 30).

    Returns:
        The kept lines joined by single spaces; an empty string when no
        line qualifies.
    """
    stripped_lines = (line.strip() for line in text.split('\n'))
    return ' '.join(
        candidate
        for candidate in stripped_lines
        if candidate and len(candidate) >= min_line_length
    )
344
+
345
+
346
def split_content_in_half(text: str) -> tuple:
    """Split ``text`` into two halves by whitespace-separated word count.

    With an odd number of words the second half receives the extra word.

    Args:
        text: Cleaned text.

    Returns:
        Tuple of (first_half, second_half), each re-joined with single
        spaces.
    """
    tokens = text.split()
    cut = len(tokens) // 2
    head, tail = tokens[:cut], tokens[cut:]
    return ' '.join(head), ' '.join(tail)
363
+
364
+
365
def analyze_content_halves(model_manager, text: str) -> Dict:
    """Clean ``text``, split it into two halves and classify each half.

    Args:
        model_manager: Object exposing ``classify_text(text)``; its result
            must provide ``ai_percentage``, ``human_percentage`` and
            ``predicted_model``.
        text: Original text to analyze.

    Returns:
        On success, a dict with ``halves_analysis_available`` set to True
        plus ``cleaned_content``, ``first_half`` and ``second_half``
        sections. When the cleaned text has fewer than 10 words, or when
        any step raises, ``halves_analysis_available`` is False and a
        ``reason`` or ``error`` entry explains why.
    """

    def _describe(half: str, outcome: Dict, words: int) -> Dict:
        # Per-half summary with a preview capped at 200 characters.
        return {
            "ai_percentage": outcome["ai_percentage"],
            "human_percentage": outcome["human_percentage"],
            "predicted_model": outcome["predicted_model"],
            "word_count": words,
            "preview": half[:200] + "..." if len(half) > 200 else half,
        }

    try:
        # Remove headline-like short lines before measuring the content.
        cleaned_text = clean_content_for_analysis(text)
        total_words = len(cleaned_text.split())

        # An empty cleaned text has zero words, so one check covers both.
        if total_words < 10:
            return {
                "halves_analysis_available": False,
                "reason": "Content too short after cleaning",
            }

        halves = split_content_in_half(cleaned_text)

        # Classify both halves before assembling anything.
        outcomes = [model_manager.classify_text(half) for half in halves]
        counts = [len(half.split()) for half in halves]

        return {
            "halves_analysis_available": True,
            "cleaned_content": {
                "total_words": total_words,
                "first_half_words": counts[0],
                "second_half_words": counts[1],
            },
            "first_half": _describe(halves[0], outcomes[0], counts[0]),
            "second_half": _describe(halves[1], outcomes[1], counts[1]),
        }

    except Exception as e:
        # Best effort: report the failure in the payload instead of raising.
        logger.error(f"Error in halves analysis: {e}", exc_info=True)
        return {
            "halves_analysis_available": False,
            "error": str(e),
        }
426
 
427
  # =====================================================
428
+ # 📝 Pydantic Models
429
  # =====================================================
430
# Request body accepted by the analyze_text endpoint.
# Documented with comments rather than a docstring so the generated schema
# description is left unchanged.
class TextInput(BaseModel):
    # Text to analyze.
    text: str
    # Flag, False by default; presumably enables per-paragraph analysis --
    # confirm against the endpoint body.
    analyze_paragraphs: bool = False
433
+
434
# Minimal request body that carries only the text.
# NOTE(review): the endpoint consuming it is not visible in this excerpt.
class SimpleTextInput(BaseModel):
    # Text to analyze.
    text: str
436
+
437
# Response envelope returned by analyze_text.
class DetectionResult(BaseModel):
    # Whether the request succeeded.
    success: bool
    # Status code reported in the body (analyze_text sends 200 on success).
    code: int
    # Message accompanying the result.
    message: str
    # Payload dictionary with the analysis details.
    data: Dict
442
 
443
+ # =====================================================
444
+ # 🔧 مساعدات
445
+ # =====================================================
446
def split_into_paragraphs(text: str, min_length: int = 100) -> List[str]:
    """Split ``text`` into paragraphs separated by blank lines.

    Each paragraph is stripped of surrounding whitespace. A paragraph is
    returned when it is non-empty and has at least ``min_length``
    characters after stripping. Empty paragraphs are never returned, even
    when ``min_length`` is zero or negative.

    Args:
        text: Text to split.
        min_length: Minimum stripped length for a paragraph to be kept
            (default: 100).

    Returns:
        The qualifying paragraphs in their original order.
    """
    # Strip each piece once, then filter on the stripped form.
    stripped_parts = (part.strip() for part in re.split(r'\n\s*\n', text))
    return [
        paragraph
        for paragraph in stripped_parts
        if paragraph and len(paragraph) >= min_length
    ]
450
 
451
  # =====================================================
452
  # 🌐 FastAPI Application
453
  # =====================================================
454
  app = FastAPI(
455
+ title="ModernBERT AI Text Detector API",
456
+ description="API for detecting AI-generated text using ModernBERT",
457
+ version="2.0.0"
458
  )
459
 
460
+ # CORS
461
  app.add_middleware(
462
  CORSMiddleware,
463
  allow_origins=["*"],
 
466
  allow_headers=["*"],
467
  )
468
 
469
+ # Model Manager Instance
470
  model_manager = ModelManager()
471
 
472
  # =====================================================
473
+ # 🚀 Startup Event
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
474
  # =====================================================
475
  @app.on_event("startup")
476
  async def startup_event():
477
+ """تحميل الموديلات عند بدء التطبيق"""
478
+ logger.info("🚀 Starting application...")
479
+ logger.info("📦 Loading models...")
480
+ success = model_manager.load_models()
 
 
 
 
 
 
 
 
 
 
481
  if success:
482
  logger.info("✅ Application ready! (Fallback mode: %s)", model_manager.using_fallback)
483
  else:
 
589
  human_percentage = round(100 - ai_percentage, 2)
590
  ai_words = int(recalc_ai_words)
591
 
592
+ # 🆕 NEW FEATURE: Analyze content by halves
593
+ halves_analysis = analyze_content_halves(model_manager, text)
594
+
595
  # إنشاء رسالة التغذية الراجعة
596
  if ai_percentage > 50:
597
  feedback = "Most of Your Text is AI/GPT Generated"
598
  else:
599
  feedback = "Most of Your Text Appears Human-Written"
600
 
601
+ # إرجاع النتائج بنفس تنسيق الكود الأصلي + إضافة تحليل النصفين
602
  return DetectionResult(
603
  success=True,
604
  code=200,
 
615
  "detected_language": "en",
616
  "top_5_predictions": result.get("top_5_predictions", []),
617
  "models_used": result.get("models_used", 1),
618
+ "using_fallback": result.get("using_fallback", False),
619
+ # 🆕 NEW: Halves analysis appended to response
620
+ "halves_analysis": halves_analysis
621
  }
622
  )
623
 
 
684
  port=port,
685
  workers=workers,
686
  reload=False # Set to True for dev
687
+ )