Spaces:
Running
on
Zero
Running
on
Zero
| # This file uses nvdiffrast library, which is licensed under the NVIDIA Source Code License (1-Way Commercial). | |
| # nvdiffrast is available for non-commercial use (research or evaluation purposes only). | |
| # For commercial use, please contact NVIDIA for licensing: https://www.nvidia.com/en-us/research/inquiries/ | |
| # | |
| # nvdiffrast copyright: Copyright (c) 2020, NVIDIA Corporation. All rights reserved. | |
| # Full license: https://github.com/NVlabs/nvdiffrast/blob/main/LICENSE.txt | |
| from typing import Tuple, Union | |
| import nvdiffrast.torch as dr | |
| import torch | |
| from jaxtyping import Float, Integer | |
| from torch import Tensor | |
class NVDiffRasterizerContext:
    """Thin wrapper around an nvdiffrast rasterization context.

    The wrapper owns the context created for ``device`` and forwards to the
    ``nvdiffrast.torch`` functions, casting floating inputs with ``.float()``
    and triangle indices with ``.int()`` before each call.
    """

    def __init__(self, context_type: str, device) -> None:
        """Create and store the rasterization context.

        Args:
            context_type: ``"gl"`` or ``"cuda"``.
            device: Device handed to the nvdiffrast context constructor.

        Raises:
            ValueError: If ``context_type`` is neither ``"gl"`` nor ``"cuda"``.
        """
        self.device = device
        self.ctx = self.initialize_context(context_type, device)

    def initialize_context(
        self, context_type: str, device
    ) -> Union[dr.RasterizeGLContext, dr.RasterizeCudaContext]:
        """Build the nvdiffrast context selected by ``context_type``.

        Raises:
            ValueError: If ``context_type`` is neither ``"gl"`` nor ``"cuda"``.
        """
        if context_type == "gl":
            return dr.RasterizeGLContext(device=device)
        if context_type == "cuda":
            return dr.RasterizeCudaContext(device=device)
        raise ValueError(f"Unknown rasterizer context type: {context_type}")

    def vertex_transform(
        self, verts: Float[Tensor, "Nv 3"], mvp_mtx: Float[Tensor, "B 4 4"]
    ) -> Float[Tensor, "B Nv 4"]:
        """Transform vertices by every matrix in ``mvp_mtx``.

        A homogeneous coordinate of 1 is appended to each vertex, then the
        vertices are multiplied by each (transposed) matrix, giving one set of
        transformed vertices per matrix.
        """
        # Autocast is switched off so the matmul is not implicitly down-cast.
        with torch.amp.autocast("cuda", enabled=False):
            # Allocate the w column directly with verts' dtype and device
            # instead of building it on the host and copying it over.
            w_column = verts.new_ones((verts.shape[0], 1))
            verts_homo = torch.cat([verts, w_column], dim=-1)
            verts_clip = torch.matmul(verts_homo, mvp_mtx.permute(0, 2, 1))
        return verts_clip

    def rasterize(
        self,
        pos: Float[Tensor, "B Nv 4"],
        tri: Integer[Tensor, "Nf 3"],
        resolution: Union[int, Tuple[int, int]],
    ) -> Tuple[Float[Tensor, "B H W 4"], Tensor]:
        """Rasterize a batch of vertex sets that share one triangle list.

        Returns:
            The two outputs of ``dr.rasterize`` called with ``grad_db=True``:
            the rasterizer output and its derivative companion.
        """
        # rasterize in instance mode (single topology)
        return dr.rasterize(self.ctx, pos.float(), tri.int(), resolution, grad_db=True)

    def rasterize_one(
        self,
        pos: Float[Tensor, "Nv 4"],
        tri: Integer[Tensor, "Nf 3"],
        resolution: Union[int, Tuple[int, int]],
    ) -> Tuple[Tensor, Tensor]:
        """Rasterize a single mesh under a single viewpoint.

        A batch axis is added for the call and stripped from both outputs.
        """
        rast, rast_db = self.rasterize(pos[None, ...], tri, resolution)
        return rast[0], rast_db[0]

    def antialias(
        self,
        color: Float[Tensor, "B H W C"],
        rast: Float[Tensor, "B H W 4"],
        pos: Float[Tensor, "B Nv 4"],
        tri: Integer[Tensor, "Nf 3"],
    ) -> Float[Tensor, "B H W C"]:
        """Run ``dr.antialias`` on ``color`` using the rasterizer output."""
        return dr.antialias(color.float(), rast, pos.float(), tri.int())

    def interpolate(
        self,
        attr: Float[Tensor, "B Nv C"],
        rast: Float[Tensor, "B H W 4"],
        tri: Integer[Tensor, "Nf 3"],
        rast_db=None,
        diff_attrs=None,
    ) -> Tuple[Float[Tensor, "B H W C"], Tensor]:
        """Interpolate per-vertex attributes over the rasterized pixels.

        Returns:
            The result of ``dr.interpolate``. Callers in this module unpack it
            as ``(interpolated_attributes, derivatives)``.
        """
        return dr.interpolate(
            attr.float(), rast, tri.int(), rast_db=rast_db, diff_attrs=diff_attrs
        )

    def interpolate_one(
        self,
        attr: Float[Tensor, "Nv C"],
        rast: Float[Tensor, "B H W 4"],
        tri: Integer[Tensor, "Nf 3"],
        rast_db=None,
        diff_attrs=None,
    ) -> Tuple[Float[Tensor, "B H W C"], Tensor]:
        """Interpolate one unbatched attribute tensor.

        ``attr`` gets a leading batch axis of size 1 and is forwarded to
        :meth:`interpolate`; the return value is passed through unchanged.
        """
        return self.interpolate(attr[None, ...], rast, tri, rast_db, diff_attrs)
def texture_map_to_rgb(tex_map, uv_coordinates):
    """Sample ``tex_map`` at ``uv_coordinates`` via ``dr.texture``.

    The texture is cast with ``.float()`` before the lookup; the coordinates
    are passed through untouched.
    """
    texture_as_float = tex_map.float()
    sampled = dr.texture(texture_as_float, uv_coordinates)
    return sampled
def render_rgb_from_texture_mesh_with_mask(
    ctx,
    mesh,
    tex_map: Float[Tensor, "1 H W C"],
    mvp_matrix: Float[Tensor, "batch 4 4"],
    image_height: int,
    image_width: int,
    background_color: Tensor = torch.tensor([0.0, 0.0, 0.0]),
):
    """Render a textured mesh to RGB images composited over a flat background.

    Args:
        ctx: Rasterizer wrapper providing ``vertex_transform``, ``rasterize``,
            ``antialias`` and ``interpolate_one``.
        mesh: Mesh exposing ``v_pos``, ``t_pos_idx``, ``_v_tex`` and
            ``_t_tex_idx`` tensors.
        tex_map: Texture to sample; a 3-D tensor gets a leading batch axis.
        mvp_matrix: One transform matrix per rendered view.
        image_height: Output height in pixels.
        image_width: Output width in pixels.
        background_color: Three-component colour used where no triangle is
            rasterized. The shared default tensor is only read here, never
            modified in place.

    Returns:
        A tuple ``(rgb, mask)``: the anti-aliased image of shape
        ``(batch, image_height, image_width, 3)`` and the boolean mask of
        pixels whose last rasterizer channel is positive.
    """
    device = mvp_matrix.device
    frame_shape = (mvp_matrix.shape[0], image_height, image_width, 3)

    tex_map = tex_map.contiguous()
    if tex_map.dim() == 3:
        tex_map = tex_map.unsqueeze(0)  # Add batch dimension if missing

    # Fetch every mesh tensor once, on the device the matrices live on
    # (mirrors render_geo_from_mesh; a no-op when they are already there).
    vertex_positions = mesh.v_pos.to(device)
    position_faces = mesh.t_pos_idx.to(device)
    texture_coords = mesh._v_tex.to(device)
    texture_faces = mesh._t_tex_idx.to(device)

    vertex_positions_clip = ctx.vertex_transform(vertex_positions, mvp_matrix)
    rasterized_output, _ = ctx.rasterize(
        vertex_positions_clip, position_faces, (image_height, image_width)
    )

    # The last rasterizer channel is positive only where a triangle was hit.
    mask = rasterized_output[..., 3:] > 0
    mask_antialiased = ctx.antialias(
        mask.float(), rasterized_output, vertex_positions_clip, position_faces
    )

    interpolated_texture_coords, _ = ctx.interpolate_one(
        texture_coords, rasterized_output, texture_faces
    )
    rgb_foreground = texture_map_to_rgb(tex_map.float(), interpolated_texture_coords)

    # Allocate the compositing buffers directly with the foreground's dtype
    # and device rather than on the host followed by a transfer.
    rgb_foreground_batched = rgb_foreground.new_zeros(frame_shape)
    rgb_background_batched = rgb_foreground.new_zeros(frame_shape)
    rgb_background_batched += background_color.view(1, 1, 1, 3).to(rgb_foreground)

    selector = mask[..., 0]
    rgb_foreground_batched[selector] = rgb_foreground[selector]

    # Use the anti-aliased mask for blending
    final_rgb = torch.lerp(rgb_background_batched, rgb_foreground_batched, mask_antialiased)
    final_rgb_aa = ctx.antialias(
        final_rgb, rasterized_output, vertex_positions_clip, position_faces
    )
    return final_rgb_aa, selector
def render_geo_from_mesh(ctx, mesh, mvp_matrix, image_height, image_width):
    """Render per-pixel position and normal images of a mesh.

    Args:
        ctx: Rasterizer wrapper providing ``vertex_transform``, ``rasterize``,
            ``antialias`` and ``interpolate_one``.
        mesh: Mesh exposing ``v_pos``, ``v_normal`` and ``t_pos_idx`` tensors.
        mvp_matrix: One transform matrix per rendered view; its device is the
            one all mesh tensors are moved to.
        image_height: Output height in pixels.
        image_width: Output width in pixels.

    Returns:
        A tuple ``(positions, normals, mask)``: the anti-aliased interpolated
        vertex positions, the anti-aliased interpolated vertex normals (both
        zero where nothing is rasterized), and the anti-aliased coverage mask.
    """
    device = mvp_matrix.device
    frame_shape = (mvp_matrix.shape[0], image_height, image_width, 3)

    # Move each mesh tensor to the target device once instead of repeating
    # the transfer at every call site.
    vertex_positions = mesh.v_pos.to(device)
    faces = mesh.t_pos_idx.to(device)
    vertex_normals = mesh.v_normal.to(device).contiguous()

    vertex_positions_clip = ctx.vertex_transform(vertex_positions, mvp_matrix)
    rasterized_output, _ = ctx.rasterize(
        vertex_positions_clip, faces, (image_height, image_width)
    )
    interpolated_positions, _ = ctx.interpolate_one(
        vertex_positions, rasterized_output, faces
    )
    interpolated_normals, _ = ctx.interpolate_one(
        vertex_normals, rasterized_output, faces
    )

    # The last rasterizer channel is positive only where a triangle was hit.
    mask = rasterized_output[..., 3:] > 0
    mask_antialiased = ctx.antialias(
        mask.float(), rasterized_output, vertex_positions_clip, faces
    )

    # Buffers are created with the interpolated tensors' dtype and device
    # rather than on the host followed by a transfer.
    rgb_foreground_pos_batched = interpolated_positions.new_zeros(frame_shape)
    rgb_foreground_norm_batched = interpolated_positions.new_zeros(frame_shape)
    rgb_background_batched = interpolated_positions.new_zeros(frame_shape)

    selector = mask[..., 0]
    rgb_foreground_pos_batched[selector] = interpolated_positions[selector]
    rgb_foreground_norm_batched[selector] = interpolated_normals[selector]

    # Blend against the all-zero background with the anti-aliased mask.
    final_pos_rgb = torch.lerp(rgb_background_batched, rgb_foreground_pos_batched, mask_antialiased)
    final_norm_rgb = torch.lerp(rgb_background_batched, rgb_foreground_norm_batched, mask_antialiased)

    final_pos_rgb_aa = ctx.antialias(
        final_pos_rgb, rasterized_output, vertex_positions_clip, faces
    )
    final_norm_rgb_aa = ctx.antialias(
        final_norm_rgb, rasterized_output, vertex_positions_clip, faces
    )
    return final_pos_rgb_aa, final_norm_rgb_aa, mask_antialiased
def rasterize_position_and_normal_maps(ctx, mesh, rasterize_height, rasterize_width):
    """Rasterize the mesh in texture space and interpolate positions and normals.

    The texture coordinates are used as the rasterization positions, so each
    output pixel receives the interpolated ``v_pos`` and ``v_normal`` of the
    triangle rasterized there.

    Returns:
        A tuple ``(position_map, normal_map, rasterization_mask)`` where the
        mask marks pixels whose last rasterizer channel is positive.
    """
    target = ctx.device

    # Bring every mesh tensor onto the context's device.
    positions = mesh.v_pos.to(target)
    position_faces = mesh.t_pos_idx.to(target)
    tex_coords = mesh._v_tex.to(target)
    tex_faces = mesh._t_tex_idx.to(target)
    normals = mesh.v_normal.to(target).contiguous()

    # Rescale UVs with x * 2 - 1 (presumably from [0, 1] to [-1, 1] - confirm
    # against the mesh loader), then append a zero and a one component to
    # form 4-component positions.
    uv_scaled = tex_coords[None, ...] * 2.0 - 1.0
    first_component = uv_scaled[..., :1]
    uv_as_clip = torch.cat(
        (uv_scaled, torch.zeros_like(first_component), torch.ones_like(first_component)),
        dim=-1,
    )

    resolution = (rasterize_height, rasterize_width)
    raster, _ = ctx.rasterize(uv_as_clip, tex_faces.int(), resolution)

    # Attributes are looked up through the position faces, not the UV faces.
    position_map, _ = ctx.interpolate_one(positions, raster, position_faces.int())
    normal_map, _ = ctx.interpolate_one(normals, raster, position_faces.int())

    rasterization_mask = raster[..., 3:4] > 0
    return position_map, normal_map, rasterization_mask