# NOTE(review): "Spaces:" / "Runtime error" lines were web-scrape artifacts
# (not part of the program) and have been removed so the file can parse.
# Third-party imports: PyTorch, then PyTorch3D rendering utilities.
import torch
import torch.nn.functional as F

import pytorch3d
from pytorch3d.ops import interpolate_face_attributes
from pytorch3d.renderer import (
    look_at_view_transform,
    FoVPerspectiveCameras,
    AmbientLights,
    PointLights,
    DirectionalLights,
    Materials,  # NOTE(review): was imported twice; duplicate removed
    RasterizationSettings,
    MeshRenderer,
    MeshRasterizer,
    SoftPhongShader,
    SoftSilhouetteShader,
    HardPhongShader,
    TexturesVertex,
    TexturesUV,
)
from pytorch3d.renderer.blending import BlendParams, hard_rgb_blend
from pytorch3d.renderer.utils import convert_to_tensors_and_broadcast, TensorProperties
from pytorch3d.renderer.mesh.shader import ShaderBase
def get_cos_angle(points, normals, camera_position):
    """
    Compute, per point, the cosine of the angle between the
    surface->camera viewing direction and the surface normal.

    Args:
        points: (N, ..., 3) surface points.
        normals: (N, ..., 3) surface normals, same shape as ``points``.
        camera_position: camera centers, broadcast against ``points``
            (typically (N, 3) from ``cameras.get_camera_center()``).

    Returns:
        cos_angle: (N, ..., 1) cosine values clamped to [0, 1];
        back-facing points (negative cosine) yield 0.

    Raises:
        ValueError: if ``points`` and ``normals`` have different shapes.
    """
    if points.shape != normals.shape:
        msg = "Expected points and normals to have the same shape: got %r, %r"
        raise ValueError(msg % (points.shape, normals.shape))

    # Ensure all inputs have same batch dimension as points.
    matched_tensors = convert_to_tensors_and_broadcast(
        points, camera_position, device=points.device
    )
    _, camera_position = matched_tensors

    # Reshape camera_position so it broadcasts across all the arbitrary
    # intermediate dimensions of points. First dim = batch, last dim = 3.
    points_dims = points.shape[1:-1]
    expand_dims = (-1,) + (1,) * len(points_dims)
    if camera_position.shape != normals.shape:
        camera_position = camera_position.view(expand_dims + (3,))

    normals = F.normalize(normals, p=2, dim=-1, eps=1e-6)

    # Cosine between the normalized view direction and the unit normal.
    # (The original trailing comment mentioned a "reflected light ray",
    # which was copy-pasted from specular shading and did not apply here.)
    view_direction = camera_position - points
    view_direction = F.normalize(view_direction, p=2, dim=-1, eps=1e-6)
    cos_angle = torch.sum(view_direction * normals, dim=-1, keepdim=True)
    # Clamp so back-facing surfaces map to 0 instead of going negative.
    cos_angle = cos_angle.clamp(0, 1)
    return cos_angle
def _geometry_shading_with_pixels(
    meshes, fragments, lights, cameras, materials, texels
):
    """
    Build per-pixel geometry buffers from rasterization output:
    interpolated vertex positions, interpolated normals, depth, and the
    cosine of the view/normal angle.

    Args:
        meshes: Batch of meshes.
        fragments: Fragments named tuple with the outputs of rasterization.
        lights: Lights class containing a batch of lights (unused; kept
            for shader-signature compatibility).
        cameras: Cameras class containing a batch of cameras.
        materials: Materials class containing a batch of material
            properties (unused; kept for shader-signature compatibility).
        texels: texture per pixel of shape (N, H, W, K, 3) (unused).

    Returns:
        Tuple of
        - interpolated positions (N, H, W, K, 3) — built from
          ``verts_packed()``, so presumably world/mesh coordinates despite
          the historical "in_camera" naming; TODO confirm against callers,
        - interpolated normals (N, H, W, K, 3),
        - depth buffer (N, H, W, K, 1) from ``fragments.zbuf``,
        - view/normal cosine (N, H, W, K, 1).
    """
    packed_verts = meshes.verts_packed()  # (V, 3)
    packed_faces = meshes.faces_packed()  # (F, 3)
    packed_normals = meshes.verts_normals_packed()  # (V, 3)

    # Gather per-face vertex attributes, then splat them onto pixels via
    # the rasterizer's barycentric coordinates.
    face_verts = packed_verts[packed_faces]
    face_normals = packed_normals[packed_faces]
    pixel_coords = interpolate_face_attributes(
        fragments.pix_to_face, fragments.bary_coords, face_verts
    )
    pixel_normals = interpolate_face_attributes(
        fragments.pix_to_face, fragments.bary_coords, face_normals
    )
    cos_angles = get_cos_angle(
        pixel_coords, pixel_normals, cameras.get_camera_center()
    )
    return pixel_coords, pixel_normals, fragments.zbuf[..., None], cos_angles
class HardGeometryShader(ShaderBase):
    """
    Shader that outputs common geometric information instead of a shaded
    color: per-pixel vertex positions, normals, depth, view/normal
    cosine, and texels, each hard-blended into image-space buffers.
    """

    def forward(self, fragments, meshes, **kwargs):
        """
        Render geometry buffers for the rasterized ``fragments``.

        Returns:
            Tuple of (verts, normals, depths, cos_angles, texels,
            fragments); each buffer is blended with ``hard_rgb_blend``.
        """
        cameras = super()._get_cameras(**kwargs)
        texels = self.texel_from_uv(fragments, meshes)
        lights = kwargs.get("lights", self.lights)
        materials = kwargs.get("materials", self.materials)
        blend_params = kwargs.get("blend_params", self.blend_params)
        verts, normals, depths, cos_angles = _geometry_shading_with_pixels(
            meshes=meshes,
            fragments=fragments,
            texels=texels,
            lights=lights,
            cameras=cameras,
            materials=materials,
        )
        # Re-sample the mesh's real texture for the texel output; the UV
        # texels computed above are only forwarded to the geometry pass.
        texels = meshes.sample_textures(fragments)
        verts = hard_rgb_blend(verts, fragments, blend_params)
        normals = hard_rgb_blend(normals, fragments, blend_params)
        depths = hard_rgb_blend(depths, fragments, blend_params)
        cos_angles = hard_rgb_blend(cos_angles, fragments, blend_params)
        # BUG FIX: removed leftover `from IPython import embed; embed()`
        # debugging calls that halted every forward pass and silently
        # required IPython at runtime.
        texels = hard_rgb_blend(texels, fragments, blend_params)
        return verts, normals, depths, cos_angles, texels, fragments

    def texel_from_uv(self, fragments, meshes):
        """
        Sample per-pixel UV coordinates by temporarily replacing the mesh
        texture with a 2x2 map whose two channels encode (u, v).

        Args:
            fragments: Fragments named tuple from rasterization.
            meshes: Batch of meshes carrying a ``TexturesUV`` texture
                (its textures attribute is swapped and then restored).

        Returns:
            texels of shape (N, H, W, K, 3): sampled (u, v) plus a zero
            third channel so the result blends like an RGB image.
        """
        original_texture = meshes.textures
        maps_padded = original_texture.maps_padded()
        # 2x2 map encoding UV coordinates: channel 0 rises with u,
        # channel 1 rises with v (bilinear sampling interpolates between).
        uv_map = [[[1, 0], [1, 1]], [[0, 0], [0, 1]]]
        uv_map = (
            torch.FloatTensor(uv_map)
            .to(maps_padded[0].device)
            .type(maps_padded[0].dtype)
        )
        uv_texture = TexturesUV(
            [uv_map.clone() for _ in maps_padded],
            original_texture.faces_uvs_padded(),
            original_texture.verts_uvs_padded(),
            sampling_mode="bilinear",
        )
        # Swap in the UV texture, sample, then restore the original.
        meshes.textures = uv_texture
        texels = meshes.sample_textures(fragments)
        meshes.textures = original_texture
        # Pad a zero third channel so downstream hard_rgb_blend sees RGB.
        texels = torch.cat((texels, texels[..., -1:] * 0), dim=-1)
        return texels