nixaut-codelabs commited on
Commit
0d556f2
·
verified ·
1 Parent(s): aa976b8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +87 -180
app.py CHANGED
@@ -9,9 +9,9 @@ import requests
9
  import numpy as np
10
  import asyncio
11
  from typing import List, Dict, Any, Optional, Union
12
- from fastapi import FastAPI, HTTPException, Depends, Request, File, UploadFile
13
  from fastapi.concurrency import run_in_threadpool
14
- from fastapi.responses import HTMLResponse, JSONResponse
15
  from fastapi.staticfiles import StaticFiles
16
  from fastapi.templating import Jinja2Templates
17
  from pydantic import BaseModel, Field
@@ -345,14 +345,23 @@ def classify_image(image_data):
345
  "nsfw_score": 0.0
346
  }
347
 
348
- def process_content_item(item: Union[str, Dict], text_model: str = "gemma") -> Dict:
 
349
  if isinstance(item, str):
350
- item = {"type": "text", "text": item}
 
 
 
 
 
 
 
 
351
 
352
- content_type = item.get("type")
353
 
354
  if content_type == "text":
355
- text = item.get("text", "")
356
  if text_model == "gemma":
357
  gemma_result = classify_text_with_gemma(text)
358
  flagged = gemma_result["classification"] == "u"
@@ -365,12 +374,7 @@ def process_content_item(item: Union[str, Dict], text_model: str = "gemma") -> D
365
  "violence": 0.9 if flagged else 0.1, "violence/graphic": 0.9 if flagged else 0.1,
366
  "nsfw": 0.1,
367
  }
368
- return {
369
- "flagged": flagged,
370
- "categories": {k: (v > 0.5) for k, v in scores.items()},
371
- "category_scores": scores,
372
- "text": text,
373
- }
374
  elif text_model == "detoxify":
375
  d = classify_text_with_detoxify(text)
376
  scores = {
@@ -381,12 +385,7 @@ def process_content_item(item: Union[str, Dict], text_model: str = "gemma") -> D
381
  "violence": d["category_scores"].get("threat", 0.1), "violence/graphic": d["category_scores"].get("threat", 0.1),
382
  "nsfw": d["category_scores"].get("sexual_explicit", 0.1),
383
  }
384
- return {
385
- "flagged": d["flagged"],
386
- "categories": {k: (v > 0.5) for k, v in scores.items()},
387
- "category_scores": scores,
388
- "text": text,
389
- }
390
  elif text_model == "both":
391
  gemma_result = classify_text_with_gemma(text)
392
  detoxify_result = classify_text_with_detoxify(text)
@@ -405,17 +404,12 @@ def process_content_item(item: Union[str, Dict], text_model: str = "gemma") -> D
405
  "violence/graphic": max(0.9 if gemma_result["classification"] == "u" else 0.1, detoxify_result["category_scores"].get("threat", 0.1)),
406
  "nsfw": detoxify_result["category_scores"].get("sexual_explicit", 0.1),
407
  }
408
- return {
409
- "flagged": flagged,
410
- "categories": {k: (v > 0.5) for k, v in scores.items()},
411
- "category_scores": scores,
412
- "text": text,
413
- }
414
 
415
  elif content_type == "image":
416
  image_data = None
417
- image_url = item.get("url")
418
- image_base64 = item.get("base64")
419
 
420
  if image_url:
421
  try:
@@ -449,7 +443,7 @@ def process_content_item(item: Union[str, Dict], text_model: str = "gemma") -> D
449
  "categories": {k: (v > 0.5) for k, v in scores.items()},
450
  "category_scores": scores,
451
  "image_url": image_url,
452
- "image_base64": image_base64[:50] + "..." if isinstance(image_base64, str) and len(image_base64) > 50 else None,
453
  }
454
 
455
  default_scores = {
@@ -462,7 +456,7 @@ def process_content_item(item: Union[str, Dict], text_model: str = "gemma") -> D
462
  "flagged": False,
463
  "categories": {k: False for k in default_scores},
464
  "category_scores": default_scores,
465
- "error": f"Invalid or unprocessable item: {item}"
466
  }
467
 
468
  def get_api_key(request: Request):
@@ -508,8 +502,8 @@ async def moderate_content(
508
  for item in items:
509
  if isinstance(item, str):
510
  total_tokens += count_tokens(item)
511
- elif isinstance(item, dict) and item.get("type") == "text":
512
- total_tokens += count_tokens(item.get("text", ""))
513
  else:
514
  raise HTTPException(status_code=400, detail="Invalid input format")
515
 
@@ -1032,21 +1026,13 @@ with open("templates/index.html", "w", encoding='utf-8') as f:
1032
  }
1033
  });
1034
 
1035
- analyzeTextBtn.addEventListener('click', async () => {
1036
- const text = textInput.value.trim();
1037
- if (!text) {
1038
- showNotification('Please enter text to analyze', 'error');
1039
- return;
1040
- }
1041
-
1042
  const apiKey = apiKeyInput.value.trim();
1043
  if (!apiKey) {
1044
  showNotification('Please enter your API key', 'error');
1045
  return;
1046
  }
1047
 
1048
- const textModel = textModelSelect.value;
1049
-
1050
  showLoading(true);
1051
  try {
1052
  const response = await fetch('/v1/moderations', {
@@ -1055,17 +1041,14 @@ with open("templates/index.html", "w", encoding='utf-8') as f:
1055
  'Content-Type': 'application/json',
1056
  'Authorization': `Bearer ${apiKey}`
1057
  },
1058
- body: JSON.stringify({
1059
- input: text,
1060
- model: textModel
1061
- })
1062
  });
1063
-
1064
  if (!response.ok) {
1065
  const errorData = await response.json();
1066
- throw new Error(errorData.detail || 'An error occurred');
1067
  }
1068
-
1069
  const data = await response.json();
1070
  displayResults(data.results);
1071
  updateMetrics();
@@ -1074,158 +1057,90 @@ with open("templates/index.html", "w", encoding='utf-8') as f:
1074
  } finally {
1075
  showLoading(false);
1076
  }
 
 
 
 
 
 
 
 
 
 
 
 
1077
  });
1078
 
1079
  analyzeImageBtn.addEventListener('click', async () => {
1080
  const url = imageUrl.value.trim();
1081
- const fileInput = document.querySelector('#imageUpload');
1082
- const file = fileInput.files[0];
1083
-
1084
  if (!url && !file) {
1085
  showNotification('Please provide an image URL or upload an image', 'error');
1086
  return;
1087
  }
1088
-
1089
- const apiKey = apiKeyInput.value.trim();
1090
- if (!apiKey) {
1091
- showNotification('Please enter your API key', 'error');
1092
- return;
1093
- }
1094
-
1095
  let imageInput;
1096
-
1097
  if (url) {
1098
- imageInput = {
1099
- type: "image",
1100
- url: url
1101
- };
1102
  } else {
1103
- const reader = new FileReader();
1104
- const base64Promise = new Promise((resolve) => {
1105
  reader.onload = (event) => resolve(event.target.result);
 
1106
  });
1107
- reader.readAsDataURL(file);
1108
- const base64 = await base64Promise;
1109
-
1110
- imageInput = {
1111
- type: "image",
1112
- base64: base64
1113
- };
1114
- }
1115
-
1116
- showLoading(true);
1117
- try {
1118
- const response = await fetch('/v1/moderations', {
1119
- method: 'POST',
1120
- headers: {
1121
- 'Content-Type': 'application/json',
1122
- 'Authorization': `Bearer ${apiKey}`
1123
- },
1124
- body: JSON.stringify({
1125
- input: [imageInput],
1126
- model: textModelSelect.value
1127
- })
1128
- });
1129
-
1130
- if (!response.ok) {
1131
- const errorData = await response.json();
1132
- throw new Error(errorData.detail || 'An error occurred');
1133
- }
1134
-
1135
- const data = await response.json();
1136
- displayResults(data.results);
1137
- updateMetrics();
1138
- } catch (error) {
1139
- showNotification(`Error: ${error.message}`, 'error');
1140
- } finally {
1141
- showLoading(false);
1142
  }
 
 
 
 
 
1143
  });
1144
-
1145
  analyzeMixedBtn.addEventListener('click', async () => {
1146
  const items = Array.from(mixedItemsContainer.querySelectorAll('.mixed-item'));
1147
  if (items.length === 0) {
1148
  showNotification('Please add at least one item to analyze', 'error');
1149
  return;
1150
  }
1151
-
1152
- const apiKey = apiKeyInput.value.trim();
1153
- if (!apiKey) {
1154
- showNotification('Please enter your API key', 'error');
1155
- return;
1156
- }
1157
-
1158
- const inputItems = [];
1159
-
1160
- for (const item of items) {
1161
  const type = item.querySelector('.item-type').value;
1162
  const contentDiv = item.querySelector('.item-content');
1163
-
1164
  if (type === 'text') {
1165
- const textarea = contentDiv.querySelector('textarea');
1166
- const text = textarea.value.trim();
1167
- if (text) {
1168
- inputItems.push({
1169
- type: 'text',
1170
- text: text
1171
- });
1172
- }
1173
  } else {
1174
- const urlInput = contentDiv.querySelector('input[type="text"]');
1175
- const fileInput = contentDiv.querySelector('input[type="file"]');
1176
- const preview = contentDiv.querySelector('.image-preview');
1177
- const previewImg = contentDiv.querySelector('.image-preview img');
1178
-
1179
- const url = urlInput.value.trim();
1180
- const file = fileInput.files[0];
1181
 
1182
  if (url) {
1183
- inputItems.push({
1184
- type: 'image',
1185
- url: url
1186
- });
1187
- } else if (file || !preview.classList.contains('hidden')) {
1188
- const imgSrc = previewImg.src;
1189
- inputItems.push({
1190
- type: 'image',
1191
- base64: imgSrc
1192
  });
 
1193
  }
 
1194
  }
1195
- }
1196
-
 
 
1197
  if (inputItems.length === 0) {
1198
  showNotification('Please add content to at least one item', 'error');
1199
  return;
1200
  }
1201
 
1202
- showLoading(true);
1203
- try {
1204
- const response = await fetch('/v1/moderations', {
1205
- method: 'POST',
1206
- headers: {
1207
- 'Content-Type': 'application/json',
1208
- 'Authorization': `Bearer ${apiKey}`
1209
- },
1210
- body: JSON.stringify({
1211
- input: inputItems,
1212
- model: textModelSelect.value
1213
- })
1214
- });
1215
-
1216
- if (!response.ok) {
1217
- const errorData = await response.json();
1218
- throw new Error(errorData.detail || 'An error occurred');
1219
- }
1220
-
1221
- const data = await response.json();
1222
- displayResults(data.results);
1223
- updateMetrics();
1224
- } catch (error) {
1225
- showNotification(`Error: ${error.message}`, 'error');
1226
- } finally {
1227
- showLoading(false);
1228
- }
1229
  });
1230
 
1231
  function displayResults(results) {
@@ -1240,10 +1155,10 @@ with open("templates/index.html", "w", encoding='utf-8') as f:
1240
  'Content may contain inappropriate or harmful material.' :
1241
  'Content appears to be safe and appropriate.';
1242
 
1243
- const categories = Object.entries(result.categories)
1244
  .filter(([_, value]) => value)
1245
  .map(([key, _]) => key.replace('/', ' '))
1246
- .join(', ');
1247
 
1248
  let contentPreview = '';
1249
 
@@ -1256,7 +1171,7 @@ with open("templates/index.html", "w", encoding='utf-8') as f:
1256
  <img src="${result.image_url}" class="max-h-48 mx-auto rounded" />
1257
  </div>`;
1258
  } else if (result.image_base64) {
1259
- contentPreview = `<div class="bg-black/20 p-4 rounded-lg mb-4 text-center">
1260
  <img src="${result.image_base64}" class="max-h-48 mx-auto rounded" />
1261
  </div>`;
1262
  }
@@ -1282,7 +1197,7 @@ with open("templates/index.html", "w", encoding='utf-8') as f:
1282
  </div>
1283
  ` : ''}
1284
  <div class="grid grid-cols-2 md:grid-cols-4 gap-2 text-xs">
1285
- ${Object.entries(result.category_scores).map(([category, score]) => `
1286
  <div class="bg-black/20 p-2 rounded">
1287
  <div class="font-medium">${category.replace('/', ' ')}</div>
1288
  <div class="w-full bg-gray-700 rounded-full h-1.5 mt-1">
@@ -1290,7 +1205,7 @@ with open("templates/index.html", "w", encoding='utf-8') as f:
1290
  </div>
1291
  <div class="text-right mt-1">${(score * 100).toFixed(0)}%</div>
1292
  </div>
1293
- `).join('')}
1294
  </div>
1295
  </div>
1296
  </div>
@@ -1304,7 +1219,8 @@ with open("templates/index.html", "w", encoding='utf-8') as f:
1304
  }
1305
 
1306
  async function updateMetrics() {
1307
- const apiKey = apiKeyInput.value.trim() || 'temp-key-for-metrics';
 
1308
  try {
1309
  const response = await fetch('/v1/metrics', {
1310
  headers: { 'Authorization': 'Bearer ' + apiKey }
@@ -1323,11 +1239,7 @@ with open("templates/index.html", "w", encoding='utf-8') as f:
1323
  }
1324
 
1325
  function showLoading(show) {
1326
- if (show) {
1327
- loadingModal.classList.remove('hidden');
1328
- } else {
1329
- loadingModal.classList.add('hidden');
1330
- }
1331
  }
1332
 
1333
  function showNotification(message, type = 'info') {
@@ -1335,12 +1247,7 @@ with open("templates/index.html", "w", encoding='utf-8') as f:
1335
  notification.className = `fixed top-4 right-4 p-4 rounded-lg shadow-lg z-50 ${
1336
  type === 'error' ? 'bg-red-500' : 'bg-indigo-500'
1337
  } text-white`;
1338
- notification.innerHTML = `
1339
- <div class="flex items-center">
1340
- <i class="fas ${type === 'error' ? 'fa-exclamation-circle' : 'fa-info-circle'} mr-2"></i>
1341
- <span>${message}</span>
1342
- </div>
1343
- `;
1344
 
1345
  document.body.appendChild(notification);
1346
 
 
9
  import numpy as np
10
  import asyncio
11
  from typing import List, Dict, Any, Optional, Union
12
+ from fastapi import FastAPI, HTTPException, Depends, Request
13
  from fastapi.concurrency import run_in_threadpool
14
+ from fastapi.responses import HTMLResponse
15
  from fastapi.staticfiles import StaticFiles
16
  from fastapi.templating import Jinja2Templates
17
  from pydantic import BaseModel, Field
 
345
  "nsfw_score": 0.0
346
  }
347
 
348
+ def process_content_item(item: Union[str, TextContent, ImageContent], text_model: str = "gemma") -> Dict:
349
+ work_item = {}
350
  if isinstance(item, str):
351
+ work_item = {"type": "text", "text": item}
352
+ elif isinstance(item, (TextContent, ImageContent)):
353
+ work_item = item.model_dump()
354
+ else:
355
+ # This case should ideally not be hit with proper Pydantic validation
356
+ return {
357
+ "flagged": False,
358
+ "error": f"Unsupported item type: {type(item).__name__}"
359
+ }
360
 
361
+ content_type = work_item.get("type")
362
 
363
  if content_type == "text":
364
+ text = work_item.get("text", "")
365
  if text_model == "gemma":
366
  gemma_result = classify_text_with_gemma(text)
367
  flagged = gemma_result["classification"] == "u"
 
374
  "violence": 0.9 if flagged else 0.1, "violence/graphic": 0.9 if flagged else 0.1,
375
  "nsfw": 0.1,
376
  }
377
+ return {"flagged": flagged, "categories": {k: (v > 0.5) for k, v in scores.items()}, "category_scores": scores, "text": text}
 
 
 
 
 
378
  elif text_model == "detoxify":
379
  d = classify_text_with_detoxify(text)
380
  scores = {
 
385
  "violence": d["category_scores"].get("threat", 0.1), "violence/graphic": d["category_scores"].get("threat", 0.1),
386
  "nsfw": d["category_scores"].get("sexual_explicit", 0.1),
387
  }
388
+ return {"flagged": d["flagged"], "categories": {k: (v > 0.5) for k, v in scores.items()}, "category_scores": scores, "text": text}
 
 
 
 
 
389
  elif text_model == "both":
390
  gemma_result = classify_text_with_gemma(text)
391
  detoxify_result = classify_text_with_detoxify(text)
 
404
  "violence/graphic": max(0.9 if gemma_result["classification"] == "u" else 0.1, detoxify_result["category_scores"].get("threat", 0.1)),
405
  "nsfw": detoxify_result["category_scores"].get("sexual_explicit", 0.1),
406
  }
407
+ return {"flagged": flagged, "categories": {k: (v > 0.5) for k, v in scores.items()}, "category_scores": scores, "text": text}
 
 
 
 
 
408
 
409
  elif content_type == "image":
410
  image_data = None
411
+ image_url = work_item.get("url")
412
+ image_base64 = work_item.get("base64")
413
 
414
  if image_url:
415
  try:
 
443
  "categories": {k: (v > 0.5) for k, v in scores.items()},
444
  "category_scores": scores,
445
  "image_url": image_url,
446
+ "image_base64": work_item.get("base64"),
447
  }
448
 
449
  default_scores = {
 
456
  "flagged": False,
457
  "categories": {k: False for k in default_scores},
458
  "category_scores": default_scores,
459
+ "error": f"Invalid or unprocessable item: {work_item}"
460
  }
461
 
462
  def get_api_key(request: Request):
 
502
  for item in items:
503
  if isinstance(item, str):
504
  total_tokens += count_tokens(item)
505
+ elif isinstance(item, TextContent):
506
+ total_tokens += count_tokens(item.text)
507
  else:
508
  raise HTTPException(status_code=400, detail="Invalid input format")
509
 
 
1026
  }
1027
  });
1028
 
1029
+ async function analyze(payload) {
 
 
 
 
 
 
1030
  const apiKey = apiKeyInput.value.trim();
1031
  if (!apiKey) {
1032
  showNotification('Please enter your API key', 'error');
1033
  return;
1034
  }
1035
 
 
 
1036
  showLoading(true);
1037
  try {
1038
  const response = await fetch('/v1/moderations', {
 
1041
  'Content-Type': 'application/json',
1042
  'Authorization': `Bearer ${apiKey}`
1043
  },
1044
+ body: JSON.stringify(payload)
 
 
 
1045
  });
1046
+
1047
  if (!response.ok) {
1048
  const errorData = await response.json();
1049
+ throw new Error(errorData.detail || `HTTP error! status: ${response.status}`);
1050
  }
1051
+
1052
  const data = await response.json();
1053
  displayResults(data.results);
1054
  updateMetrics();
 
1057
  } finally {
1058
  showLoading(false);
1059
  }
1060
+ }
1061
+
1062
+ analyzeTextBtn.addEventListener('click', async () => {
1063
+ const text = textInput.value.trim();
1064
+ if (!text) {
1065
+ showNotification('Please enter text to analyze', 'error');
1066
+ return;
1067
+ }
1068
+ await analyze({
1069
+ input: text,
1070
+ model: textModelSelect.value
1071
+ });
1072
  });
1073
 
1074
  analyzeImageBtn.addEventListener('click', async () => {
1075
  const url = imageUrl.value.trim();
1076
+ const file = imageUpload.files[0];
1077
+
 
1078
  if (!url && !file) {
1079
  showNotification('Please provide an image URL or upload an image', 'error');
1080
  return;
1081
  }
1082
+
 
 
 
 
 
 
1083
  let imageInput;
 
1084
  if (url) {
1085
+ imageInput = { type: "image", url: url };
 
 
 
1086
  } else {
1087
+ const base64 = await new Promise((resolve) => {
1088
+ const reader = new FileReader();
1089
  reader.onload = (event) => resolve(event.target.result);
1090
+ reader.readAsDataURL(file);
1091
  });
1092
+ imageInput = { type: "image", base64: base64 };
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1093
  }
1094
+
1095
+ await analyze({
1096
+ input: [imageInput],
1097
+ model: textModelSelect.value
1098
+ });
1099
  });
1100
+
1101
  analyzeMixedBtn.addEventListener('click', async () => {
1102
  const items = Array.from(mixedItemsContainer.querySelectorAll('.mixed-item'));
1103
  if (items.length === 0) {
1104
  showNotification('Please add at least one item to analyze', 'error');
1105
  return;
1106
  }
1107
+
1108
+ const inputPromises = items.map(async (item) => {
 
 
 
 
 
 
 
 
1109
  const type = item.querySelector('.item-type').value;
1110
  const contentDiv = item.querySelector('.item-content');
1111
+
1112
  if (type === 'text') {
1113
+ const text = contentDiv.querySelector('textarea').value.trim();
1114
+ return text ? { type: 'text', text: text } : null;
 
 
 
 
 
 
1115
  } else {
1116
+ const url = contentDiv.querySelector('input[type="text"]').value.trim();
1117
+ const file = contentDiv.querySelector('input[type="file"]').files[0];
 
 
 
 
 
1118
 
1119
  if (url) {
1120
+ return { type: 'image', url: url };
1121
+ } else if (file) {
1122
+ const base64 = await new Promise((resolve) => {
1123
+ const reader = new FileReader();
1124
+ reader.onload = (event) => resolve(event.target.result);
1125
+ reader.readAsDataURL(file);
 
 
 
1126
  });
1127
+ return { type: 'image', base64: base64 };
1128
  }
1129
+ return null;
1130
  }
1131
+ });
1132
+
1133
+ const inputItems = (await Promise.all(inputPromises)).filter(Boolean);
1134
+
1135
  if (inputItems.length === 0) {
1136
  showNotification('Please add content to at least one item', 'error');
1137
  return;
1138
  }
1139
 
1140
+ await analyze({
1141
+ input: inputItems,
1142
+ model: textModelSelect.value
1143
+ });
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1144
  });
1145
 
1146
  function displayResults(results) {
 
1155
  'Content may contain inappropriate or harmful material.' :
1156
  'Content appears to be safe and appropriate.';
1157
 
1158
+ const categories = result.categories ? Object.entries(result.categories)
1159
  .filter(([_, value]) => value)
1160
  .map(([key, _]) => key.replace('/', ' '))
1161
+ .join(', ') : 'N/A';
1162
 
1163
  let contentPreview = '';
1164
 
 
1171
  <img src="${result.image_url}" class="max-h-48 mx-auto rounded" />
1172
  </div>`;
1173
  } else if (result.image_base64) {
1174
+ contentPreview = `<div class="bg-black/20 p-4 rounded-lg mb-4 text-center">
1175
  <img src="${result.image_base64}" class="max-h-48 mx-auto rounded" />
1176
  </div>`;
1177
  }
 
1197
  </div>
1198
  ` : ''}
1199
  <div class="grid grid-cols-2 md:grid-cols-4 gap-2 text-xs">
1200
+ ${result.category_scores ? Object.entries(result.category_scores).map(([category, score]) => `
1201
  <div class="bg-black/20 p-2 rounded">
1202
  <div class="font-medium">${category.replace('/', ' ')}</div>
1203
  <div class="w-full bg-gray-700 rounded-full h-1.5 mt-1">
 
1205
  </div>
1206
  <div class="text-right mt-1">${(score * 100).toFixed(0)}%</div>
1207
  </div>
1208
+ `).join('') : ''}
1209
  </div>
1210
  </div>
1211
  </div>
 
1219
  }
1220
 
1221
  async function updateMetrics() {
1222
+ const apiKey = apiKeyInput.value.trim();
1223
+ if (!apiKey) return;
1224
  try {
1225
  const response = await fetch('/v1/metrics', {
1226
  headers: { 'Authorization': 'Bearer ' + apiKey }
 
1239
  }
1240
 
1241
  function showLoading(show) {
1242
+ loadingModal.style.display = show ? 'flex' : 'none';
 
 
 
 
1243
  }
1244
 
1245
  function showNotification(message, type = 'info') {
 
1247
  notification.className = `fixed top-4 right-4 p-4 rounded-lg shadow-lg z-50 ${
1248
  type === 'error' ? 'bg-red-500' : 'bg-indigo-500'
1249
  } text-white`;
1250
+ notification.innerHTML = `<div class="flex items-center"><i class="fas ${type === 'error' ? 'fa-exclamation-circle' : 'fa-info-circle'} mr-2"></i><span>${message}</span></div>`;
 
 
 
 
 
1251
 
1252
  document.body.appendChild(notification);
1253