Spaces:
Build error
Build error
add safety checker
Browse files
app.py
CHANGED
|
@@ -6,6 +6,7 @@ import gradio as gr
|
|
| 6 |
import torch
|
| 7 |
from einops import rearrange
|
| 8 |
from PIL import Image
|
|
|
|
| 9 |
|
| 10 |
from flux.cli import SamplingOptions
|
| 11 |
from flux.sampling import denoise, get_noise, get_schedule, prepare, unpack
|
|
@@ -13,6 +14,7 @@ from flux.util import load_ae, load_clip, load_flow_model, load_t5
|
|
| 13 |
from pulid.pipeline_flux import PuLIDPipeline
|
| 14 |
from pulid.utils import resize_numpy_image_long
|
| 15 |
|
|
|
|
| 16 |
|
| 17 |
def get_models(name: str, device: torch.device, offload: bool):
|
| 18 |
t5 = load_t5(device, max_length=128)
|
|
@@ -20,7 +22,8 @@ def get_models(name: str, device: torch.device, offload: bool):
|
|
| 20 |
model = load_flow_model(name, device="cpu" if offload else device)
|
| 21 |
model.eval()
|
| 22 |
ae = load_ae(name, device="cpu" if offload else device)
|
| 23 |
-
|
|
|
|
| 24 |
|
| 25 |
|
| 26 |
class FluxGenerator:
|
|
@@ -28,7 +31,7 @@ class FluxGenerator:
|
|
| 28 |
self.device = torch.device('cuda')
|
| 29 |
self.offload = False
|
| 30 |
self.model_name = 'flux-dev'
|
| 31 |
-
self.model, self.ae, self.t5, self.clip = get_models(
|
| 32 |
self.model_name,
|
| 33 |
device=self.device,
|
| 34 |
offload=self.offload,
|
|
@@ -147,7 +150,12 @@ def generate_image(
|
|
| 147 |
x = rearrange(x[0], "c h w -> h w c")
|
| 148 |
|
| 149 |
img = Image.fromarray((127.5 * (x + 1.0)).cpu().byte().numpy())
|
| 150 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 151 |
|
| 152 |
_HEADER_ = '''
|
| 153 |
<div style="text-align: center; max-width: 650px; margin: 0 auto;">
|
|
|
|
| 6 |
import torch
|
| 7 |
from einops import rearrange
|
| 8 |
from PIL import Image
|
| 9 |
+
from transformers import pipeline
|
| 10 |
|
| 11 |
from flux.cli import SamplingOptions
|
| 12 |
from flux.sampling import denoise, get_noise, get_schedule, prepare, unpack
|
|
|
|
| 14 |
from pulid.pipeline_flux import PuLIDPipeline
|
| 15 |
from pulid.utils import resize_numpy_image_long
|
| 16 |
|
| 17 |
+
NSFW_THRESHOLD = 0.85
|
| 18 |
|
| 19 |
def get_models(name: str, device: torch.device, offload: bool):
|
| 20 |
t5 = load_t5(device, max_length=128)
|
|
|
|
| 22 |
model = load_flow_model(name, device="cpu" if offload else device)
|
| 23 |
model.eval()
|
| 24 |
ae = load_ae(name, device="cpu" if offload else device)
|
| 25 |
+
nsfw_classifier = pipeline("image-classification", model="Falconsai/nsfw_image_detection", device=device)
|
| 26 |
+
return model, ae, t5, clip, nsfw_classifier
|
| 27 |
|
| 28 |
|
| 29 |
class FluxGenerator:
|
|
|
|
| 31 |
self.device = torch.device('cuda')
|
| 32 |
self.offload = False
|
| 33 |
self.model_name = 'flux-dev'
|
| 34 |
+
self.model, self.ae, self.t5, self.clip, self.nsfw_classifier = get_models(
|
| 35 |
self.model_name,
|
| 36 |
device=self.device,
|
| 37 |
offload=self.offload,
|
|
|
|
| 150 |
x = rearrange(x[0], "c h w -> h w c")
|
| 151 |
|
| 152 |
img = Image.fromarray((127.5 * (x + 1.0)).cpu().byte().numpy())
|
| 153 |
+
nsfw_score = [x["score"] for x in flux_generator.nsfw_classifier(img) if x["label"] == "nsfw"][0]
|
| 154 |
+
if nsfw_score < NSFW_THRESHOLD:
|
| 155 |
+
return img, str(opts.seed), flux_generator.pulid_model.debug_img_list
|
| 156 |
+
else:
|
| 157 |
+
return (None, f"Your generated image may contain NSFW (with nsfw_score: {nsfw_score}) content",
|
| 158 |
+
flux_generator.pulid_model.debug_img_list)
|
| 159 |
|
| 160 |
_HEADER_ = '''
|
| 161 |
<div style="text-align: center; max-width: 650px; margin: 0 auto;">
|