import os
import json

import cv2
import numpy as np
import torch
import trimesh
import pymeshlab
import gradio as gr
import lrm
from lrm.utils.config import load_config
from datetime import datetime
from pygltflib import GLTF2
from PIL import Image
from huggingface_hub import hf_hub_download, list_repo_files

from refine import refine

device = "cuda"
# Fetch the 3D-stage checkpoints and assets from the Hugging Face Hub into the
# parent directory, skipping any files that are already present locally.
repo_id = "zjpshadow/CharacterGen"
all_files = list_repo_files(repo_id, revision="main")

for file in all_files:
    if os.path.exists("../" + file):
        continue
    if file.startswith("3D_Stage"):
        hf_hub_download(repo_id, file, local_dir="../")
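# Alternative: the same fetch can be expressed as a single call with
# snapshot_download and an allow_patterns filter. This is only a sketch of an
# equivalent approach, not what the app runs:
#   from huggingface_hub import snapshot_download
#   snapshot_download(repo_id, revision="main", allow_patterns="3D_Stage/*", local_dir="../")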
def traverse(path, back_proj):
    """Post-process the exported mesh: reorient, smooth, and bake the texture
    into per-vertex colors."""
    mesh = trimesh.load(f"{path}/model-00.obj")
    # Reorient the mesh: -90 degrees about X, then 180 degrees about Y.
    mesh.apply_transform(trimesh.transformations.rotation_matrix(np.radians(90.0), [-1, 0, 0]))
    mesh.apply_transform(trimesh.transformations.rotation_matrix(np.radians(180.0), [0, 1, 0]))
    # Smooth the geometry with four Laplacian steps, then copy the smoothed
    # vertices back onto the trimesh object.
    cmesh = pymeshlab.Mesh(mesh.vertices, mesh.faces)
    ms = pymeshlab.MeshSet()
    ms.add_mesh(cmesh)
    ms.apply_coord_laplacian_smoothing(stepsmoothnum=4)
    mesh.vertices = ms.current_mesh().vertex_matrix()
    mesh.export(f'{path}/output.glb', file_type='glb')
    # Sample the texture at each vertex's UV coordinate to build vertex colors.
    image = Image.open(f"{path}/{'refined_texture_kd.jpg' if back_proj else 'texture_kd.jpg'}")
    texture = np.array(image)
    vertex_colors = np.zeros((mesh.vertices.shape[0], 4), dtype=np.uint8)
    for vertex_index in range(len(mesh.visual.uv)):
        uv = mesh.visual.uv[vertex_index]
        x = int(uv[0] * (texture.shape[1] - 1))
        y = int((1 - uv[1]) * (texture.shape[0] - 1))  # flip V to image rows
        color = texture[y, x, :3]
        vertex_colors[vertex_index] = [color[0], color[1], color[2], 255]
    return trimesh.Trimesh(vertices=mesh.vertices, faces=mesh.faces, vertex_colors=vertex_colors)
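# The per-vertex loop above runs in pure Python and can be slow on dense
# meshes. A vectorized NumPy version of the same UV lookup is sketched below;
# like the loop, it assumes one UV pair per vertex. It is illustrative only
# and is not called anywhere in this app.
def sample_vertex_colors(uv, texture):
    # Map UVs to integer pixel coordinates, flipping V, and clamp to bounds.
    x = np.clip((uv[:, 0] * (texture.shape[1] - 1)).astype(np.int64), 0, texture.shape[1] - 1)
    y = np.clip(((1 - uv[:, 1]) * (texture.shape[0] - 1)).astype(np.int64), 0, texture.shape[0] - 1)
    colors = np.empty((uv.shape[0], 4), dtype=np.uint8)
    colors[:, :3] = texture[y, x, :3]  # fancy-indexed gather, shape (V, 3)
    colors[:, 3] = 255                 # fully opaque
    return colors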
class Inference_API:
    def __init__(self):
        # Load config
        self.cfg = load_config("configs/infer.yaml", makedirs=False)
        # Load system
        print("Loading system")
        self.system = lrm.find(self.cfg.system_cls)(self.cfg.system).to(device)
        self.system.eval()
    def process_images(self, img_input0, img_input1, img_input2, img_input3, back_proj):
        files = [img_input0, img_input1, img_input2, img_input3]
        if any(file is None for file in files):
            raise gr.Error("Please provide all four images.")
        # Load the fixed camera-to-world matrices for the four canonical views.
        meta = json.load(open("material/meta.json"))
        c2w_cond = [np.array(loc["transform_matrix"]) for loc in meta["locations"]]
        c2w_cond = torch.from_numpy(np.stack(c2w_cond, axis=0)).float()[None].to(device)
        # Prepare input data
        rgb_cond = []
        new_image = []
        for file in files:
            image = np.array(file)
            image = cv2.cvtColor(image, cv2.COLOR_RGBA2RGB)
            new_image.append(Image.fromarray(image.astype(np.uint8)).convert("RGB"))
            rgb = cv2.resize(image, (self.cfg.data.cond_width, self.cfg.data.cond_height)).astype(np.float32) / 255.0
            rgb_cond.append(rgb)
        rgb_cond = torch.from_numpy(np.stack(rgb_cond, axis=0)).float()[None].to(device)
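        # rgb_cond now has shape (1, 4, cond_height, cond_width, 3):
        # a batch of one sample with four conditioning views.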
        # Run inference
        with torch.no_grad():
            scene_codes = self.system({"rgb_cond": rgb_cond, "c2w_cond": c2w_cond})
            exporter_output = self.system.exporter([f"{i:02d}" for i in range(rgb_cond.shape[0])], scene_codes)
        # Save output
        save_dir = os.path.join("./outputs", datetime.now().strftime("@%Y%m%d-%H%M%S"))
        os.makedirs(save_dir, exist_ok=True)
        self.system.set_save_dir(save_dir)
        for out in exporter_output:
            save_func_name = f"save_{out.save_type}"
            save_func = getattr(self.system, save_func_name)
            save_func(f"{out.save_name}", **out.params)
        if back_proj:
            # The UI collects views as (back, front, right, left); this call
            # reorders them to (front, back, left, right) for refine.
            refine(save_dir, new_image[1], new_image[0], new_image[3], new_image[2])
        new_obj = traverse(save_dir, back_proj)
        new_obj.export(f'{save_dir}/output.obj', file_type='obj')
        # Flatten the GLB material so the baked texture is shown unshaded:
        # force a white base color and disable metallic/roughness response.
        # Note: the glTF spec restricts baseColorFactor components to [0, 1];
        # the 100.0 alpha below is out of range and viewers typically clamp it.
        gltf = GLTF2().load(f'{save_dir}/output.glb')
        for material in gltf.materials:
            if material.pbrMetallicRoughness:
                material.pbrMetallicRoughness.baseColorFactor = [1.0, 1.0, 1.0, 100.0]
                material.pbrMetallicRoughness.metallicFactor = 0.0
                material.pbrMetallicRoughness.roughnessFactor = 1.0
        gltf.save(f'{save_dir}/output.glb')
        return save_dir, f"{save_dir}/output.obj", f"{save_dir}/output.glb"

inferapi = Inference_API()
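# The API can also be driven headlessly, bypassing Gradio. A hypothetical
# example (assumes the bundled sample images load as RGBA, as the UI enforces):
#   imgs = [Image.open(f"material/examples/1/{i}.png").convert("RGBA") for i in (1, 2, 3, 4)]
#   out_dir, obj_path, glb_path = inferapi.process_images(*imgs, back_proj=False)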
# Define the interface
with gr.Blocks() as demo:
    gr.Markdown("# [SIGGRAPH'24] CharacterGen: Efficient 3D Character Generation from Single Images with Multi-View Pose Calibration")
    gr.Markdown("# 3D Stage: Four View Images to 3D Mesh")
    with gr.Row(variant="panel"):
        with gr.Column():
            with gr.Row():
                img_input0 = gr.Image(type="pil", label="Back Image", image_mode="RGBA", width=256, height=384)
                img_input1 = gr.Image(type="pil", label="Front Image", image_mode="RGBA", width=256, height=384)
            with gr.Row():
                img_input2 = gr.Image(type="pil", label="Right Image", image_mode="RGBA", width=256, height=384)
                img_input3 = gr.Image(type="pil", label="Left Image", image_mode="RGBA", width=256, height=384)
            with gr.Row():
                gr.Examples(
                    examples=[[
                        "material/examples/1/1.png",
                        "material/examples/1/2.png",
                        "material/examples/1/3.png",
                        "material/examples/1/4.png",
                    ]],
                    label="Example Images",
                    inputs=[img_input0, img_input1, img_input2, img_input3],
                )
        with gr.Column():
            with gr.Row():
                back_proj = gr.Checkbox(label="Back Projection")
                submit_button = gr.Button("Process")
            output_dir = gr.Textbox(label="Output Directory")
        with gr.Column():
            with gr.Tab("GLB"):
                output_model_glb = gr.Model3D(label="Output Model (GLB Format)", height=768)
                gr.Markdown("Note: The model shown here has a darker appearance. Download to get correct results.")
            with gr.Tab("OBJ"):
                output_model_obj = gr.Model3D(label="Output Model (OBJ Format)", height=768)
                gr.Markdown("Note: The model shown here is flipped. Download to get correct results.")
    submit_button.click(
        inferapi.process_images,
        inputs=[img_input0, img_input1, img_input2, img_input3, back_proj],
        outputs=[output_dir, output_model_obj, output_model_glb],
    )
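# On shared GPU hardware it is common to enable request queuing so concurrent
# users do not contend for the device, e.g. demo.queue().launch(). The plain
# launch() below also works; queuing is just an optional hardening step.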
# Run the interface
demo.launch()