Spaces: Build error
Update app.py

app.py CHANGED
@@ -25,30 +25,49 @@ from transformers import (
     AutoConfig,
     AutoModelForImageSegmentation,
 )
+# Hugging Face Hub
+from huggingface_hub import hf_hub_download
+from safetensors.torch import load_file  # needed to read .safetensors checkpoints
 
+
+##########################################################
+# 1. Load the config and initialize via from_config()
+##########################################################
+
+# 1) Load only the config first
 config = AutoConfig.from_pretrained(
-    "zhengpeng7/BiRefNet",
+    "zhengpeng7/BiRefNet",  # example
     trust_remote_code=True
 )
 
-# 2) config.get_text_config
+# 2) Give config.get_text_config a dummy method (tie_word_embeddings=False)
 def dummy_get_text_config(decoder=True):
     return type("DummyTextConfig", (), {"tie_word_embeddings": False})()
 
 config.get_text_config = dummy_get_text_config
 
 # 3) Build only the model structure
 birefnet = AutoModelForImageSegmentation.from_config(config, trust_remote_code=True)
 birefnet.eval()
 device = "cuda" if torch.cuda.is_available() else "cpu"
 birefnet.to(device)
 birefnet.half()
 
-#
+##########################################################
+# 2. Download & load the model weights
+##########################################################
+
+# Download the safetensors or bin file via huggingface_hub
+# (change repo_id, filename, etc. to match your actual setup)
+weights_path = hf_hub_download(
+    repo_id="zhengpeng7/BiRefNet",   # example
+    filename="model.safetensors",    # or "pytorch_model.bin"
+)
+print("Downloaded weights to:", weights_path)
+
+# Load the state_dict (load_file for safetensors; torch.load for .bin)
+print("Loading BiRefNet weights from HF Hub file:", weights_path)
+state_dict = load_file(weights_path)
 missing, unexpected = birefnet.load_state_dict(state_dict, strict=False)
 print("[Info] Missing keys:", missing)
 print("[Info] Unexpected keys:", unexpected)
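Note that torch.load cannot read a .safetensors file, which is why the checkpoint above is loaded with safetensors' load_file. Since the filename may be either model.safetensors or pytorch_model.bin, a loader that dispatches on the extension is the safer pattern. A minimal sketch (the helper name load_birefnet_state_dict is illustrative, not part of the app):

import torch
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file

def load_birefnet_state_dict(repo_id: str, filename: str) -> dict:
    # Download (or reuse the locally cached copy of) the checkpoint from the Hub.
    path = hf_hub_download(repo_id=repo_id, filename=filename)
    if filename.endswith(".safetensors"):
        # safetensors checkpoints require safetensors' own loader.
        return load_file(path)
    # .bin checkpoints are pickled tensors; load them onto CPU.
    return torch.load(path, map_location="cpu")

Also note that hf_hub_download takes no trust_remote_code argument (that flag belongs to transformers' from_pretrained), so it is dropped from the call in the hunk above.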

@@ -56,7 +75,7 @@ torch.cuda.empty_cache()
 
 
 ##########################################################
-#
+# 3. Image post-processing helpers
 ##########################################################
 
 def refine_foreground(image, mask, r=90):
@@ -85,7 +104,6 @@ def FB_blur_fusion_foreground_estimator(image, F, B, alpha, r=90):
     F = np.clip(F, 0, 1)
     return F, blurred_B
 
-
 class ImagePreprocessor():
     def __init__(self, resolution: Tuple[int, int] = (1024, 1024)) -> None:
         self.transform_image = transforms.Compose([
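The hunk cuts off inside transform_image. For orientation, upstream BiRefNet demos build this transform from a resize plus ImageNet normalization; the following is a sketch under that assumption (the exact Compose list in this app is not shown):

from typing import Tuple
from PIL import Image
import torch
from torchvision import transforms

class ImagePreprocessor():
    def __init__(self, resolution: Tuple[int, int] = (1024, 1024)) -> None:
        self.transform_image = transforms.Compose([
            # resolution is (w, h); torchvision's Resize expects (h, w)
            transforms.Resize(resolution[::-1]),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ])

    def proc(self, image: Image.Image) -> torch.Tensor:
        # Returns a CHW float tensor; predict() adds the batch dim with unsqueeze(0).
        return self.transform_image(image)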
@@ -99,7 +117,7 @@ class ImagePreprocessor():
 
 
 ##########################################################
-#
+# 4. Example settings and misc
 ##########################################################
 
 usage_to_weights_file = {
@@ -130,30 +148,24 @@ descriptions = (
     "We also maintain the HF model of BiRefNet at https://huggingface.co/ZhengPeng7/BiRefNet for easier access."
 )
 
-
 ##########################################################
-#
+# 5. Inference function (uses the already-loaded birefnet model)
 ##########################################################
 
 @spaces.GPU
 def predict(images, resolution, weights_file):
-    """
-    Only a single birefnet model is kept here, so even if weights_file
-    is changed, the already-loaded 'birefnet' model is the one actually used.
-    (To load different weights, a local state_dict swap like below can be added.)
-    """
+    # weights_file is ignored here; the already-loaded birefnet is used
     assert images is not None, 'Images cannot be None.'
 
-    #
+    # Parse resolution
    try:
-        w, h = resolution.strip().split('x')
-        w, h = int(w), int(h)
-        resolution_list = (w, h)
+        w, h = map(int, resolution.strip().split('x'))
+        w, h = int(w//32*32), int(h//32*32)
    except:
+        w, h = 1024, 1024
+    resolution_tuple = (w, h)
 
-    #
+    # Check whether we received a list (batch mode)
    if isinstance(images, list):
        is_batch = True
        outputs, save_paths = [], []
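The rewritten parser snaps both dimensions down to multiples of 32 (a common stride requirement for segmentation backbones) and falls back to 1024x1024 on malformed input. The same logic as a standalone sketch with worked values (parse_resolution is a hypothetical helper, not in the app):

def parse_resolution(resolution: str, fallback=(1024, 1024)):
    # "1000x700" -> (992, 672), since 1000//32*32 == 992 and 700//32*32 == 672
    try:
        w, h = map(int, resolution.strip().split('x'))
        return (w // 32 * 32, h // 32 * 32)
    except (ValueError, AttributeError):
        return fallback

assert parse_resolution("1000x700") == (992, 672)
assert parse_resolution("not-a-size") == (1024, 1024)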
@@ -164,65 +176,57 @@ def predict(images, resolution, weights_file):
         is_batch = False
 
     for idx, image_src in enumerate(images):
-        #
+        # File path or URL
         if isinstance(image_src, str):
             if os.path.isfile(image_src):
                 image_ori = Image.open(image_src)
             else:
                 resp = requests.get(image_src)
                 image_ori = Image.open(BytesIO(resp.content))
-        # numpy
+        # numpy array -> PIL
         elif isinstance(image_src, np.ndarray):
             image_ori = Image.fromarray(image_src)
         else:
             image_ori = image_src.convert('RGB')
 
-        preproc = ImagePreprocessor(
-        image_proc = preproc.proc(
+        # Preprocess
+        preproc = ImagePreprocessor(resolution_tuple)
+        image_proc = preproc.proc(image_ori.convert('RGB')).unsqueeze(0).to(device).half()
 
-        #
+        # Inference
         with torch.inference_mode():
-            # preds from the model's last output layer
             preds = birefnet(image_proc)[-1].sigmoid().cpu()
         pred_mask = preds[0].squeeze()
 
         # Post-process
         pred_pil = transforms.ToPILImage()(pred_mask)
-        image_masked = refine_foreground(
-        image_masked.putalpha(pred_pil.resize(
+        image_masked = refine_foreground(image_ori, pred_pil)
+        image_masked.putalpha(pred_pil.resize(image_ori.size))
 
         if is_batch:
-            )
-            out_path = os.path.join(save_dir, f"{file_name}.png")
-            image_masked.save(out_path)
-            save_paths.append(out_path)
+            fbase = os.path.splitext(os.path.basename(image_src))[0] if isinstance(image_src, str) else f"img_{idx}"
+            outpath = os.path.join(save_dir, f"{fbase}.png")
+            image_masked.save(outpath)
+            save_paths.append(outpath)
             outputs.append(image_masked)
         else:
             outputs = [image_masked, image_ori]
 
     torch.cuda.empty_cache()
 
-    # If batch: return gallery + ZIP
     if is_batch:
-        with zipfile.ZipFile(
+        zippath = os.path.join(save_dir, f"{save_dir}.zip")
+        with zipfile.ZipFile(zippath, 'w') as zipf:
             for fpath in save_paths:
                 zipf.write(fpath, os.path.basename(fpath))
-        return
+        return outputs, zippath
     else:
         return outputs
 
-
 ##########################################################
-#
+# 6. Gradio UI
 ##########################################################
 
-# Custom CSS
 css = """
 body {
     background: linear-gradient(135deg, #667eea, #764ba2);
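In batch mode the rewritten predict now saves each mask under save_dir and returns the pair (outputs, zippath), which lines up with the batch tab's two output components (gallery + file). The ZIP step in isolation, as a minimal sketch (assuming save_paths holds the saved PNG paths; the archive name here is illustrative, the app derives it from save_dir):

import os
import zipfile

def zip_outputs(save_paths, save_dir):
    # Bundle every saved mask into a single flat archive inside save_dir.
    zippath = os.path.join(save_dir, "results.zip")
    with zipfile.ZipFile(zippath, 'w') as zipf:
        for fpath in save_paths:
            # Write under the basename so the archive has no nested directories.
            zipf.write(fpath, os.path.basename(fpath))
    return zippath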
@@ -280,14 +284,13 @@ button:hover, .btn:hover {
 title_html = """
 <h1 align="center" style="margin-bottom: 0.2em;">BiRefNet Demo (No Tie-Weights Crash)</h1>
 <p align="center" style="font-size:1.1em; color:#555;">
-Using <code>from_config()</code> + local <code>state_dict</code> to bypass tie_weights issues
+Using <code>from_config()</code> + local <code>state_dict</code> or <code>hf_hub_download</code> to bypass tie_weights issues
 </p>
 """
 
 with gr.Blocks(css=css, title="BiRefNet Demo") as demo:
     gr.Markdown(title_html)
     with gr.Tabs():
-        # Tab 1: Image
         with gr.Tab("Image"):
             with gr.Row():
                 with gr.Column(scale=1):
@@ -297,13 +300,8 @@ with gr.Blocks(css=css, title="BiRefNet Demo") as demo:
                     predict_btn = gr.Button("Predict")
                 with gr.Column(scale=2):
                     output_slider = ImageSlider(label="Result", type="pil")
-            gr.Examples(
-                examples=examples_image,
-                inputs=[image_input, resolution_input, weights_radio],
-                label="Examples"
-            )
+            gr.Examples(examples=examples_image, inputs=[image_input, resolution_input, weights_radio], label="Examples")
 
-        # Tab 2: Text (URL)
         with gr.Tab("Text"):
             with gr.Row():
                 with gr.Column(scale=1):
@@ -313,36 +311,23 @@ with gr.Blocks(css=css, title="BiRefNet Demo") as demo:
                     predict_btn_text = gr.Button("Predict")
                 with gr.Column(scale=2):
                     output_slider_text = ImageSlider(label="Result", type="pil")
-            gr.Examples(
-                examples=examples_text,
-                inputs=[image_url, resolution_input_text, weights_radio_text],
-                label="Examples"
-            )
+            gr.Examples(examples=examples_text, inputs=[image_url, resolution_input_text, weights_radio_text], label="Examples")
 
-        # Tab 3: Batch
         with gr.Tab("Batch"):
             with gr.Row():
                 with gr.Column(scale=1):
-                    file_input = gr.File(
-                        label="Upload Multiple Images",
-                        type="filepath",
-                        file_count="multiple"
-                    )
+                    file_input = gr.File(label="Upload Multiple Images", type="filepath", file_count="multiple")
                     resolution_input_batch = gr.Textbox(lines=1, placeholder="e.g., 1024x1024", label="Resolution")
                     weights_radio_batch = gr.Radio(list(usage_to_weights_file.keys()), value="General", label="Weights")
                     predict_btn_batch = gr.Button("Predict")
                 with gr.Column(scale=2):
                     output_gallery = gr.Gallery(label="Results", scale=1)
                     zip_output = gr.File(label="Zip Download")
-            gr.Examples(
-                examples=examples_batch,
-                inputs=[file_input, resolution_input_batch, weights_radio_batch],
-                label="Examples"
-            )
+            gr.Examples(examples=examples_batch, inputs=[file_input, resolution_input_batch, weights_radio_batch], label="Examples")
 
     gr.Markdown("<p align='center'>Model by <a href='https://huggingface.co/ZhengPeng7/BiRefNet'>ZhengPeng7/BiRefNet</a></p>")
 
-    #
+    # Wire up events
     predict_btn.click(
         fn=predict,
         inputs=[image_input, resolution_input, weights_radio],
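The hunk ends before predict_btn.click's outputs= argument, so the bindings below are an assumption based on the components defined above, not part of the diff:

# Hypothetical continuation of the event wiring; the diff cuts off before outputs=.
predict_btn.click(
    fn=predict,
    inputs=[image_input, resolution_input, weights_radio],
    outputs=output_slider,  # single mode returns [image_masked, image_ori]
)
predict_btn_text.click(
    fn=predict,
    inputs=[image_url, resolution_input_text, weights_radio_text],
    outputs=output_slider_text,
)
predict_btn_batch.click(
    fn=predict,
    inputs=[file_input, resolution_input_batch, weights_radio_batch],
    outputs=[output_gallery, zip_output],  # batch mode returns (outputs, zippath)
)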