Spaces:
Sleeping
Sleeping
Commit
·
7318bea
1
Parent(s):
c8a655a
NameError
Browse files
- app.py +23 -21
- dino_feature_extractor.py +3 -2
app.py
CHANGED
|
@@ -1,4 +1,3 @@
|
|
| 1 |
-
|
| 2 |
import os
|
| 3 |
import io
|
| 4 |
import cv2
|
|
@@ -11,6 +10,7 @@ from functools import lru_cache
|
|
| 11 |
from huggingface_hub import hf_hub_download, snapshot_download
|
| 12 |
from torchvision.transforms.functional import normalize
|
| 13 |
import glob
|
|
|
|
| 14 |
|
| 15 |
|
| 16 |
from restormerRFR_arch import RestormerRFR
|
|
@@ -83,39 +83,41 @@ def get_model_and_device():
|
|
| 83 |
return model, device
|
| 84 |
|
| 85 |
|
| 86 |
-
@spaces.GPU(duration=
|
| 87 |
def restore_image(pil_img: Image.Image) -> Image.Image:
|
| 88 |
"""
|
| 89 |
输入一张图片,输出复原后的图片(与 RAM++ RestormerRFR + DINO 特征推理一致)
|
| 90 |
"""
|
| 91 |
-
|
| 92 |
-
|
| 93 |
-
|
| 94 |
|
| 95 |
-
img_bgr = cv2.cvtColor(np.array(pil_img), cv2.COLOR_RGB2BGR).astype(np.float32) / 255.0
|
| 96 |
-
img = torch.from_numpy(np.transpose(img_bgr[:, :, [2, 1, 0]], (2, 0, 1))).float() # (3,H,W), RGB
|
| 97 |
-
img = img.unsqueeze(0).to(device) # (1,3,H,W)
|
| 98 |
|
|
|
|
|
|
|
|
|
|
| 99 |
|
| 100 |
-
mean = np.array([0.485, 0.456, 0.406], dtype=np.float32)
|
| 101 |
-
std = np.array([0.229, 0.224, 0.225], dtype=np.float32)
|
| 102 |
-
normalize(img, mean, std, inplace=True)
|
| 103 |
|
| 104 |
-
|
| 105 |
-
|
| 106 |
-
|
| 107 |
|
|
|
|
|
|
|
|
|
|
| 108 |
|
| 109 |
-
output = normalize(output, -1 * mean / std, 1 / std)
|
| 110 |
-
output = output.data.squeeze().float().cpu().clamp_(0, 1).numpy() # (3,H,W)
|
| 111 |
-
output = np.transpose(output[[2, 1, 0], :, :], (1, 2, 0)) # (H,W,RGB)
|
| 112 |
-
output = (output * 255.0).round().astype(np.uint8)
|
| 113 |
-
out_pil = Image.fromarray(output, mode="RGB")
|
| 114 |
-
return out_pil
|
| 115 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 116 |
|
| 117 |
DESCRIPTION = """
|
| 118 |
-
# RAM
|
| 119 |
"""
|
| 120 |
|
| 121 |
with gr.Blocks(title="RAM++ ZeroGPU Demo") as demo:
|
|
|
|
|
|
|
| 1 |
import os
|
| 2 |
import io
|
| 3 |
import cv2
|
|
|
|
| 10 |
from huggingface_hub import hf_hub_download, snapshot_download
|
| 11 |
from torchvision.transforms.functional import normalize
|
| 12 |
import glob
|
| 13 |
+
import traceback
|
| 14 |
|
| 15 |
|
| 16 |
from restormerRFR_arch import RestormerRFR
|
|
|
|
| 83 |
return model, device
|
| 84 |
|
| 85 |
|
| 86 |
+
@spaces.GPU(duration=240)
|
| 87 |
def restore_image(pil_img: Image.Image) -> Image.Image:
|
| 88 |
"""
|
| 89 |
输入一张图片,输出复原后的图片(与 RAM++ RestormerRFR + DINO 特征推理一致)
|
| 90 |
"""
|
| 91 |
+
try:
|
| 92 |
+
model, device = get_model_and_device()
|
| 93 |
+
dino_extractor = get_dino_extractor(device)
|
| 94 |
|
|
|
|
|
|
|
|
|
|
| 95 |
|
| 96 |
+
img_bgr = cv2.cvtColor(np.array(pil_img), cv2.COLOR_RGB2BGR).astype(np.float32) / 255.0
|
| 97 |
+
img = torch.from_numpy(np.transpose(img_bgr[:, :, [2, 1, 0]], (2, 0, 1))).float() # (3,H,W), RGB
|
| 98 |
+
img = img.unsqueeze(0).to(device) # (1,3,H,W)
|
| 99 |
|
|
|
|
|
|
|
|
|
|
| 100 |
|
| 101 |
+
mean = np.array([0.485, 0.456, 0.406], dtype=np.float32)
|
| 102 |
+
std = np.array([0.229, 0.224, 0.225], dtype=np.float32)
|
| 103 |
+
normalize(img, mean, std, inplace=True)
|
| 104 |
|
| 105 |
+
with torch.no_grad():
|
| 106 |
+
dino_features = dino_extractor(img)
|
| 107 |
+
output = model(img, dino_features)
|
| 108 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 109 |
|
| 110 |
+
output = normalize(output, -1 * mean / std, 1 / std)
|
| 111 |
+
output = output.data.squeeze().float().cpu().clamp_(0, 1).numpy() # (3,H,W)
|
| 112 |
+
output = np.transpose(output[[2, 1, 0], :, :], (1, 2, 0)) # (H,W,RGB)
|
| 113 |
+
output = (output * 255.0).round().astype(np.uint8)
|
| 114 |
+
out_pil = Image.fromarray(output, mode="RGB")
|
| 115 |
+
return out_pil
|
| 116 |
+
except Exception as e:
|
| 117 |
+
raise gr.Error(f"{e}\n{traceback.format_exc()}")
|
| 118 |
|
| 119 |
DESCRIPTION = """
|
| 120 |
+
# RAM++: Robust Representation Learning via Adaptive Mask for All-in-One Image Restoration
|
| 121 |
"""
|
| 122 |
|
| 123 |
with gr.Blocks(title="RAM++ ZeroGPU Demo") as demo:
|
dino_feature_extractor.py
CHANGED
|
@@ -10,8 +10,9 @@ class DinoFeatureModule(nn.Module):
|
|
| 10 |
def __init__(self, model_id: str = "facebook/dinov2-giant"):
|
| 11 |
super(DinoFeatureModule, self).__init__()
|
| 12 |
dtype = torch.float32
|
|
|
|
| 13 |
self.dino = AutoModel.from_pretrained(
|
| 14 |
-
model_id,
|
| 15 |
torch_dtype=dtype
|
| 16 |
)
|
| 17 |
|
|
@@ -110,7 +111,7 @@ class DinoFeatureModule(nn.Module):
|
|
| 110 |
|
| 111 |
shortest_edge = min(target_h, target_w)
|
| 112 |
processor = AutoImageProcessor.from_pretrained(
|
| 113 |
-
model_id,
|
| 114 |
local_files_only=False,
|
| 115 |
do_rescale=False,
|
| 116 |
do_center_crop=False,
|
|
|
|
| 10 |
def __init__(self, model_id: str = "facebook/dinov2-giant"):
|
| 11 |
super(DinoFeatureModule, self).__init__()
|
| 12 |
dtype = torch.float32
|
| 13 |
+
self.model_id = model_id
|
| 14 |
self.dino = AutoModel.from_pretrained(
|
| 15 |
+
self.model_id,
|
| 16 |
torch_dtype=dtype
|
| 17 |
)
|
| 18 |
|
|
|
|
| 111 |
|
| 112 |
shortest_edge = min(target_h, target_w)
|
| 113 |
processor = AutoImageProcessor.from_pretrained(
|
| 114 |
+
self.model_id,
|
| 115 |
local_files_only=False,
|
| 116 |
do_rescale=False,
|
| 117 |
do_center_crop=False,
|