# RAM_plus_plus / app.py: Gradio ZeroGPU demo for RAM++ image restoration (RestormerRFR)
import os
import gradio as gr
import numpy as np
import torch
import spaces
from PIL import Image
from functools import lru_cache
from huggingface_hub import hf_hub_download, snapshot_download
from torchvision.transforms.functional import normalize
import glob
import traceback
from restormerRFR_arch import RestormerRFR
from dino_feature_extractor import DinoFeatureModule
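# Hugging Face Hub location of the pretrained RAM++ (RestormerRFR) weights (7-task checkpoint).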
WEIGHT_REPO_ID = "233zzl/RAM_plus_plus"
WEIGHT_FILENAME = "7task/RestormerRFR.pth"
MODEL_NAME = "RestormerRFR"
def get_device():
return torch.device("cuda" if torch.cuda.is_available() else "cpu")
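# Pre-download the restoration weights and the DINOv2 backbone so the first GPU request
# does not stall on network downloads (triggered via demo.load on page load).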
def warmup():
hf_hub_download(
repo_id=WEIGHT_REPO_ID,
filename=WEIGHT_FILENAME,
repo_type="model",
revision="main"
)
snapshot_download(
repo_id="facebook/dinov2-giant",
repo_type="model",
revision="main"
)
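# RestormerRFR hyper-parameters matching the released 7-task checkpoint.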
def build_model():
model = RestormerRFR(
inp_channels=3,
out_channels=3,
dim=48,
num_blocks=[4, 6, 6, 8],
num_refinement_blocks=4,
heads=[1, 2, 4, 8],
ffn_expansion_factor=2.66,
bias=False,
LayerNorm_type="WithBias",
finetune_type=None,
img_size=128,
)
return model
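# Build the DINOv2 feature extractor once per device and cache it for subsequent calls.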
@lru_cache(maxsize=1)
def get_dino_extractor(device):
extractor = DinoFeatureModule().to(device).eval()
return extractor
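# Download the checkpoint on first use, unwrap the optional "params" key, and cache model + device.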
@lru_cache(maxsize=1)
def get_model_and_device():
device = get_device()
model = build_model()
weight_path = hf_hub_download(
repo_id=WEIGHT_REPO_ID,
filename=WEIGHT_FILENAME,
)
ckpt = torch.load(weight_path, map_location="cpu")
keyname = "params" if "params" in ckpt else None
if keyname is not None:
model.load_state_dict(ckpt[keyname], strict=False)
else:
model.load_state_dict(ckpt, strict=False)
model.eval().to(device)
return model, device
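# ZeroGPU allocates a GPU for up to 120 s per call; model loading and inference both run inside it.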
@spaces.GPU(duration=120)
def restore_image(pil_img: Image.Image) -> Image.Image:
try:
model, device = get_model_and_device()
dino_extractor = get_dino_extractor(device)
        # PIL -> float32 RGB array in [0, 1]; convert() guards against RGBA/grayscale uploads.
        img_np = np.array(pil_img.convert("RGB"), dtype=np.float32) / 255.0
        img = torch.from_numpy(np.transpose(img_np, (2, 0, 1))).float()  # (3, H, W), RGB
        img = img.unsqueeze(0).to(device)  # (1, 3, H, W)
        # Normalize with ImageNet statistics; the matching de-normalization is applied to the output below.
        mean = np.array([0.485, 0.456, 0.406], dtype=np.float32)
        std = np.array([0.229, 0.224, 0.225], dtype=np.float32)
        normalize(img, mean, std, inplace=True)
        with torch.no_grad():
            # DINOv2 features are computed from the normalized input and passed to the restorer.
            dino_features = dino_extractor(img)
            output = model(img, dino_features)
        # Invert the ImageNet normalization, then convert CHW back to an HWC uint8 image.
        output = normalize(output, -1 * mean / std, 1 / std)
        output = output.squeeze(0).float().cpu().clamp_(0, 1).numpy()  # (3, H, W)
        output = np.transpose(output, (1, 2, 0))  # CHW -> HWC for PIL
        output = (output * 255.0).round().astype(np.uint8)
        out_pil = Image.fromarray(output, mode="RGB")
return out_pil
except Exception as e:
raise gr.Error(f"{e}\n{traceback.format_exc()}")
DESCRIPTION = """
# RAM++
Upload a degraded image; RAM++ (RestormerRFR) restores it on a ZeroGPU-backed GPU.
"""
with gr.Blocks(title="RAM++ ZeroGPU Demo") as demo:
gr.Markdown(DESCRIPTION)
with gr.Row():
with gr.Column():
            inp = gr.Image(type="pil", label="Input image (JPEG/PNG)")
btn = gr.Button("Run (ZeroGPU)")
with gr.Column():
            out = gr.Image(type="pil", label="Restored image")
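    # Collect example images from a local examples/ folder, if one is bundled with the Space.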
ex_files = []
for ext in ("*.png", "*.jpg", "*.jpeg", "*.bmp"):
ex_files.extend(glob.glob(os.path.join("examples", ext)))
ex_files = sorted(ex_files)
if ex_files:
        gr.Examples(examples=ex_files, inputs=inp, label="Examples")
btn.click(restore_image, inputs=inp, outputs=out, api_name="run")
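    # api_name="run" also exposes this as an HTTP endpoint. A minimal client sketch, assuming
    # the Space is public and the gradio_client package is installed:
    #   from gradio_client import Client, handle_file
    #   client = Client("Zilong-Zhang003/RAM_plus_plus")
    #   restored_path = client.predict(handle_file("degraded.png"), api_name="/run")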
gr.Markdown("""
**Tips**
- If the queue is long or you hit the quota, please try again later, or upgrade to Pro for a higher ZeroGPU quota and priority.
""")
demo.load(fn=warmup, inputs=None, outputs=None)
if __name__ == "__main__":
demo.launch()