Update app.py

app.py CHANGED
@@ -1,15 +1,7 @@
-
-
-
-
- # Override tie_weights with an empty function for every image-segmentation model class
- from transformers.models.auto.modeling_auto import MODEL_FOR_IMAGE_SEGMENTATION_MAPPING
- for model_class in MODEL_FOR_IMAGE_SEGMENTATION_MAPPING.values():
-     model_class.tie_weights = lambda self: None
- # --- end of patch ---
-
- from transformers import AutoModelForImageSegmentation
- from transformers import PreTrainedModel  # (for reference)
import os
import cv2
import numpy as np

@@ -23,18 +15,50 @@ from typing import Tuple, Optional
from PIL import Image
from gradio_imageslider import ImageSlider
from torchvision import transforms
-
import requests
from io import BytesIO
import zipfile
import random

-
-

device = "cuda" if torch.cuda.is_available() else "cpu"

- ### Image post-processing functions ###
def refine_foreground(image, mask, r=90):
    if mask.size != image.size:
        mask = mask.resize(image.size)

@@ -61,6 +85,7 @@ def FB_blur_fusion_foreground_estimator(image, F, B, alpha, r=90):
    F = np.clip(F, 0, 1)
    return F, blurred_B

class ImagePreprocessor():
    def __init__(self, resolution: Tuple[int, int] = (1024, 1024)) -> None:
        self.transform_image = transforms.Compose([

@@ -72,6 +97,11 @@ class ImagePreprocessor():
        image = self.transform_image(image)
        return image

usage_to_weights_file = {
    'General': 'BiRefNet',
    'General-HR': 'BiRefNet_HR',

@@ -86,105 +116,113 @@ usage_to_weights_file = {
    'General-legacy': 'BiRefNet-legacy'
}

-
-
-
-
)
-
-

@spaces.GPU
def predict(images, resolution, weights_file):
    assert images is not None, 'Images cannot be None.'
-     global birefnet
-     # Reload the model with the selected weights
-     _weights_file = '/'.join(('zhengpeng7', usage_to_weights_file[weights_file] if weights_file is not None else usage_to_weights_file['General']))
-     print('Using weights: {}.'.format(_weights_file))
-     birefnet = AutoModelForImageSegmentation.from_pretrained(_weights_file, trust_remote_code=True)
-     birefnet.to(device)
-     birefnet.eval(); birefnet.half()

    try:
-
    except:
-
-
-         elif weights_file == 'General-Lite-2K':
-             resolution_list = [2560, 1440]
-         else:
-             resolution_list = [1024, 1024]
-         print('Invalid resolution input. Automatically changed to default.')

-     # Images
    if isinstance(images, list):
-
    else:
        images = [images]
-
-
-     save_paths = []
-     save_dir = 'preds-BiRefNet'
-     if tab_is_batch and not os.path.exists(save_dir):
-         os.makedirs(save_dir)
-
-     outputs = []
    for idx, image_src in enumerate(images):
        if isinstance(image_src, str):
            if os.path.isfile(image_src):
                image_ori = Image.open(image_src)
            else:
-
-
-
        else:
-
-
-             else:
-                 image_ori = image_src.convert('RGB')
        image = image_ori.convert('RGB')
-
-         image_proc =
-
-
-
-
        image_masked = refine_foreground(image, pred_pil)
        image_masked.putalpha(pred_pil.resize(image.size))
-
-         if
-
-             os.path.splitext(os.path.basename(image_src))[0]
-
-
-
            outputs.append(image_masked)
        else:
            outputs = [image_masked, image_ori]
-
-
-
-
-
-
-
    else:
        return outputs

- # Example data (images, URLs, batch)
- examples_image = [[path, "1024x1024", "General"] for path in glob('examples/*')]
- examples_text = [[url, "1024x1024", "General"] for url in ["https://hips.hearstapps.com/hmg-prod/images/gettyimages-1229892983-square.jpg"]]
- examples_batch = [[file, "1024x1024", "General"] for file in glob('examples/*')]

-
-
-
-     "`2048x2048` is suggested for BiRefNet_HR.\n"
-     "Our codes can be found at https://github.com/ZhengPeng7/BiRefNet.\n"
-     "We also maintain the HF model of BiRefNet at https://huggingface.co/ZhengPeng7/BiRefNet for easier access."
- )

- #
css = """
body {
    background: linear-gradient(135deg, #667eea, #764ba2);

@@ -239,16 +277,17 @@ button:hover, .btn:hover {
}
"""

-
- <h1 align="center" style="margin-bottom: 0.2em;">BiRefNet Demo
<p align="center" style="font-size:1.1em; color:#555;">
-
</p>
"""

with gr.Blocks(css=css, title="BiRefNet Demo") as demo:
-     gr.Markdown(
    with gr.Tabs():
        with gr.Tab("Image"):
            with gr.Row():
                with gr.Column(scale=1):

@@ -257,8 +296,14 @@ with gr.Blocks(css=css, title="BiRefNet Demo") as demo:
                    weights_radio = gr.Radio(list(usage_to_weights_file.keys()), value="General", label="Weights")
                    predict_btn = gr.Button("Predict")
                with gr.Column(scale=2):
-                     output_slider = ImageSlider(label="
-                     gr.Examples(
        with gr.Tab("Text"):
            with gr.Row():
                with gr.Column(scale=1):

@@ -267,23 +312,37 @@ with gr.Blocks(css=css, title="BiRefNet Demo") as demo:
                    weights_radio_text = gr.Radio(list(usage_to_weights_file.keys()), value="General", label="Weights")
                    predict_btn_text = gr.Button("Predict")
                with gr.Column(scale=2):
-                     output_slider_text = ImageSlider(label="
-                     gr.Examples(
        with gr.Tab("Batch"):
            with gr.Row():
                with gr.Column(scale=1):
-                     file_input = gr.File(
                    resolution_input_batch = gr.Textbox(lines=1, placeholder="e.g., 1024x1024", label="Resolution")
                    weights_radio_batch = gr.Radio(list(usage_to_weights_file.keys()), value="General", label="Weights")
                    predict_btn_batch = gr.Button("Predict")
                with gr.Column(scale=2):
-                     output_gallery = gr.Gallery(label="
-                     zip_output = gr.File(label="Download
-                     gr.Examples(
-
-

-     #
    predict_btn.click(
        fn=predict,
        inputs=[image_input, resolution_input, weights_radio],

@@ -1,15 +1,7 @@
+ ##########################################################
+ # 0. Environment setup and library imports
+ ##########################################################
+
import os
import cv2
import numpy as np

@@ -23,18 +15,50 @@ from typing import Tuple, Optional
from PIL import Image
from gradio_imageslider import ImageSlider
from torchvision import transforms
import requests
from io import BytesIO
import zipfile
import random

+ # Transformers
+ from transformers import (
+     AutoConfig,
+     AutoModelForImageSegmentation,
+ )
+
+ # 1) Load the config first to avoid the tie_weights conflict
+ config = AutoConfig.from_pretrained(
+     "zhengpeng7/BiRefNet",  # 👉 the Hugging Face model repo you want
+     trust_remote_code=True
+ )
+
+ # 2) Give config.get_text_config a dummy method (tie_word_embeddings=False)
+ def dummy_get_text_config(decoder=True):
+     return type("DummyTextConfig", (), {"tie_word_embeddings": False})()

+ config.get_text_config = dummy_get_text_config
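A quick way to confirm the override above behaves as intended (illustrative check, not part of the commit; it only calls the dummy function defined in this diff):

# Sanity check: the dummy text config reports untied word embeddings,
# which is the flag this patch is forcing off.
assert config.get_text_config().tie_word_embeddings is False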
+
+ # 3) Build only the model structure (from_config) -> tie_weights is not called automatically
+ birefnet = AutoModelForImageSegmentation.from_config(config, trust_remote_code=True)
+ birefnet.eval()
device = "cuda" if torch.cuda.is_available() else "cpu"
+ birefnet.to(device)
+ birefnet.half()
+
+ # 4) Load the state_dict (weights) - example using a local file
+ # In practice, download "model.safetensors" first with hf_hub_download / snapshot_download, then use it
+ print("Loading BiRefNet weights from local file: model.safetensors")
+ state_dict = torch.load("model.safetensors", map_location="cpu")  # example
+ missing, unexpected = birefnet.load_state_dict(state_dict, strict=False)
+ print("[Info] Missing keys:", missing)
+ print("[Info] Unexpected keys:", unexpected)
+ torch.cuda.empty_cache()
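As the comments above note, the checkpoint is expected to already exist locally; also, torch.load reads pickle-based checkpoints rather than the safetensors format, so the download-and-load step would look closer to the sketch below (illustrative only, not part of the commit; assumes the huggingface_hub and safetensors packages are installed and that the repo hosts a model.safetensors file):

# Illustrative sketch: fetch the checkpoint from the Hub and read it with the
# safetensors loader instead of torch.load.
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file

ckpt_path = hf_hub_download(repo_id="zhengpeng7/BiRefNet", filename="model.safetensors")
state_dict = load_file(ckpt_path, device="cpu")
missing, unexpected = birefnet.load_state_dict(state_dict, strict=False)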
+
+
+ ##########################################################
+ # 1. Image post-processing functions
+ ##########################################################

def refine_foreground(image, mask, r=90):
    if mask.size != image.size:
        mask = mask.resize(image.size)

@@ -61,6 +85,7 @@ def FB_blur_fusion_foreground_estimator(image, F, B, alpha, r=90):
    F = np.clip(F, 0, 1)
    return F, blurred_B

+
class ImagePreprocessor():
    def __init__(self, resolution: Tuple[int, int] = (1024, 1024)) -> None:
        self.transform_image = transforms.Compose([

@@ -72,6 +97,11 @@ class ImagePreprocessor():
        image = self.transform_image(image)
        return image

+
+ ##########################################################
+ # 2. Example setup and utilities
+ ##########################################################
+
usage_to_weights_file = {
    'General': 'BiRefNet',
    'General-HR': 'BiRefNet_HR',

@@ -86,105 +116,113 @@ usage_to_weights_file = {
    'General-legacy': 'BiRefNet-legacy'
}

+ examples_image = [[path, "1024x1024", "General"] for path in glob('examples/*')]
+ examples_text = [[url, "1024x1024", "General"] for url in [
+     "https://hips.hearstapps.com/hmg-prod/images/gettyimages-1229892983-square.jpg"
+ ]]
+ examples_batch = [[file, "1024x1024", "General"] for file in glob('examples/*')]
+
+ descriptions = (
+     "Upload a picture, our model will extract a highly accurate segmentation of the subject in it.\n"
+     "The resolution used in our training was `1024x1024`, which is suggested for good results! "
+     "`2048x2048` is suggested for BiRefNet_HR.\n"
+     "Our codes can be found at https://github.com/ZhengPeng7/BiRefNet.\n"
+     "We also maintain the HF model of BiRefNet at https://huggingface.co/ZhengPeng7/BiRefNet for easier access."
)
+
+
+ ##########################################################
+ # 3. Inference function (uses the already-loaded birefnet model)
+ ##########################################################

@spaces.GPU
def predict(images, resolution, weights_file):
+     """
+     Only a single birefnet model is kept here; even if weights_file changes,
+     the already-loaded 'birefnet' model is the one actually used.
+     (To load other weights, a local state_dict swap like the one above could be added.)
+     """
    assert images is not None, 'Images cannot be None.'

+     # Resolution parse
    try:
+         w, h = resolution.strip().split('x')
+         w, h = int(int(w)//32*32), int(int(h)//32*32)
+         resolution_list = (w, h)
    except:
+         print('[WARN] Invalid resolution input. Fallback to 1024x1024.')
+         resolution_list = (1024, 1024)
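(Worked example of the snapping above: an input of "2000x1500" becomes (1984, 1472), since 2000//32*32 = 1984 and 1500//32*32 = 1472.)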

+     # There may be multiple images, so handle them as a list
    if isinstance(images, list):
+         is_batch = True
+         outputs, save_paths = [], []
+         save_dir = 'preds-BiRefNet'
+         os.makedirs(save_dir, exist_ok=True)
    else:
        images = [images]
+         is_batch = False
+
    for idx, image_src in enumerate(images):
+         # If it is a str, it is either a file path or a URL
        if isinstance(image_src, str):
            if os.path.isfile(image_src):
                image_ori = Image.open(image_src)
            else:
+                 resp = requests.get(image_src)
+                 image_ori = Image.open(BytesIO(resp.content))
+         # If it is a numpy array, convert it to a Pillow image
+         elif isinstance(image_src, np.ndarray):
+             image_ori = Image.fromarray(image_src)
        else:
+             image_ori = image_src.convert('RGB')
+
        image = image_ori.convert('RGB')
+         preproc = ImagePreprocessor(resolution_list)
+         image_proc = preproc.proc(image).unsqueeze(0).to(device).half()
+
+         # Actual inference
+         with torch.inference_mode():
+             # preds from the model's final output layer
+             preds = birefnet(image_proc)[-1].sigmoid().cpu()
+         pred_mask = preds[0].squeeze()
+
+         # Post-processing
+         pred_pil = transforms.ToPILImage()(pred_mask)
        image_masked = refine_foreground(image, pred_pil)
        image_masked.putalpha(pred_pil.resize(image.size))
+
+         if is_batch:
+             file_name = (
+                 os.path.splitext(os.path.basename(image_src))[0]
+                 if isinstance(image_src, str)
+                 else f"img_{idx}"
+             )
+             out_path = os.path.join(save_dir, f"{file_name}.png")
+             image_masked.save(out_path)
+             save_paths.append(out_path)
            outputs.append(image_masked)
        else:
            outputs = [image_masked, image_ori]
+
+     torch.cuda.empty_cache()
+
+     # For a batch, return the gallery results plus a ZIP archive
+     if is_batch:
+         zip_path = os.path.join(save_dir, f"{save_dir}.zip")
+         with zipfile.ZipFile(zip_path, 'w') as zipf:
+             for fpath in save_paths:
+                 zipf.write(fpath, os.path.basename(fpath))
+         return (save_paths, zip_path)
    else:
        return outputs


+ ##########################################################
+ # 4. Gradio UI
+ ##########################################################

+ # Custom CSS
css = """
body {
    background: linear-gradient(135deg, #667eea, #764ba2);

@@ -239,16 +277,17 @@ button:hover, .btn:hover {
}
"""

+ title_html = """
+ <h1 align="center" style="margin-bottom: 0.2em;">BiRefNet Demo (No Tie-Weights Crash)</h1>
<p align="center" style="font-size:1.1em; color:#555;">
+ Using <code>from_config()</code> + local <code>state_dict</code> to bypass tie_weights issues
</p>
"""

with gr.Blocks(css=css, title="BiRefNet Demo") as demo:
+     gr.Markdown(title_html)
    with gr.Tabs():
+         # Tab 1: Image
        with gr.Tab("Image"):
            with gr.Row():
                with gr.Column(scale=1):

@@ -257,8 +296,14 @@ with gr.Blocks(css=css, title="BiRefNet Demo") as demo:
                    weights_radio = gr.Radio(list(usage_to_weights_file.keys()), value="General", label="Weights")
                    predict_btn = gr.Button("Predict")
                with gr.Column(scale=2):
+                     output_slider = ImageSlider(label="Result", type="pil")
+                     gr.Examples(
+                         examples=examples_image,
+                         inputs=[image_input, resolution_input, weights_radio],
+                         label="Examples"
+                     )
+
+         # Tab 2: Text (URL)
        with gr.Tab("Text"):
            with gr.Row():
                with gr.Column(scale=1):

@@ -267,23 +312,37 @@ with gr.Blocks(css=css, title="BiRefNet Demo") as demo:
                    weights_radio_text = gr.Radio(list(usage_to_weights_file.keys()), value="General", label="Weights")
                    predict_btn_text = gr.Button("Predict")
                with gr.Column(scale=2):
+                     output_slider_text = ImageSlider(label="Result", type="pil")
+                     gr.Examples(
+                         examples=examples_text,
+                         inputs=[image_url, resolution_input_text, weights_radio_text],
+                         label="Examples"
+                     )
+
+         # Tab 3: Batch
        with gr.Tab("Batch"):
            with gr.Row():
                with gr.Column(scale=1):
+                     file_input = gr.File(
+                         label="Upload Multiple Images",
+                         type="filepath",
+                         file_count="multiple"
+                     )
                    resolution_input_batch = gr.Textbox(lines=1, placeholder="e.g., 1024x1024", label="Resolution")
                    weights_radio_batch = gr.Radio(list(usage_to_weights_file.keys()), value="General", label="Weights")
                    predict_btn_batch = gr.Button("Predict")
                with gr.Column(scale=2):
+                     output_gallery = gr.Gallery(label="Results", scale=1)
+                     zip_output = gr.File(label="Zip Download")
+                     gr.Examples(
+                         examples=examples_batch,
+                         inputs=[file_input, resolution_input_batch, weights_radio_batch],
+                         label="Examples"
+                     )
+
+     gr.Markdown("<p align='center'>Model by <a href='https://huggingface.co/ZhengPeng7/BiRefNet'>ZhengPeng7/BiRefNet</a></p>")

+     # Wire up the button events
    predict_btn.click(
        fn=predict,
        inputs=[image_input, resolution_input, weights_radio],