from typing import Any, Dict, Union

import blobfile as bf
import torch
import torch.nn as nn
import yaml

from shap_e.models.generation.latent_diffusion import SplitVectorDiffusion
from shap_e.models.generation.perceiver import PointDiffusionPerceiver
from shap_e.models.generation.pooled_mlp import PooledMLP
from shap_e.models.generation.transformer import (
    CLIPImageGridPointDiffusionTransformer,
    CLIPImageGridUpsamplePointDiffusionTransformer,
    CLIPImagePointDiffusionTransformer,
    PointDiffusionTransformer,
    UpsamplePointDiffusionTransformer,
)
from shap_e.models.nerf.model import MLPNeRFModel, VoidNeRFModel
from shap_e.models.nerf.renderer import OneStepNeRFRenderer, TwoStepNeRFRenderer
from shap_e.models.nerstf.mlp import MLPDensitySDFModel, MLPNeRSTFModel
from shap_e.models.nerstf.renderer import NeRSTFRenderer
from shap_e.models.nn.meta import batch_meta_state_dict
from shap_e.models.stf.mlp import MLPSDFModel, MLPTextureFieldModel
from shap_e.models.stf.renderer import STFRenderer
from shap_e.models.transmitter.base import ChannelsDecoder, Transmitter, VectorDecoder
from shap_e.models.transmitter.channels_encoder import (
    PointCloudPerceiverChannelsEncoder,
    PointCloudTransformerChannelsEncoder,
)
from shap_e.models.transmitter.multiview_encoder import MultiviewTransformerEncoder
from shap_e.models.transmitter.pc_encoder import (
    PointCloudPerceiverEncoder,
    PointCloudTransformerEncoder,
)
from shap_e.models.volume import BoundingBoxVolume, SphericalVolume, UnboundedVolume


def model_from_config(config: Union[str, Dict[str, Any]], device: torch.device) -> nn.Module:
    """Build a model from a config dict or from the path to a YAML config file.

    The "name" key selects the model class; all remaining keys are passed to
    the class constructor. Nested sub-model configs are resolved recursively.
    """
    if isinstance(config, str):
        # A string config is a path to a YAML file: load it and recurse.
        with bf.BlobFile(config, "rb") as f:
            obj = yaml.load(f, Loader=yaml.SafeLoader)
        return model_from_config(obj, device=device)

    config = config.copy()
    name = config.pop("name")
    if name == "PointCloudTransformerEncoder":
        return PointCloudTransformerEncoder(device=device, dtype=torch.float32, **config)
    elif name == "PointCloudPerceiverEncoder":
        return PointCloudPerceiverEncoder(device=device, dtype=torch.float32, **config)
    elif name == "PointCloudTransformerChannelsEncoder":
        return PointCloudTransformerChannelsEncoder(device=device, dtype=torch.float32, **config)
    elif name == "PointCloudPerceiverChannelsEncoder":
        return PointCloudPerceiverChannelsEncoder(device=device, dtype=torch.float32, **config)
    elif name == "MultiviewTransformerEncoder":
        return MultiviewTransformerEncoder(device=device, dtype=torch.float32, **config)
    elif name == "Transmitter":
        renderer = model_from_config(config.pop("renderer"), device=device)
        # The per-parameter shapes (minus the batch dim) of the renderer's
        # meta state dict tell the encoder what it has to predict.
        param_shapes = {
            k: v.shape[1:] for k, v in batch_meta_state_dict(renderer, batch_size=1).items()
        }
        encoder_config = config.pop("encoder").copy()
        encoder_config["param_shapes"] = param_shapes
        encoder = model_from_config(encoder_config, device=device)
        return Transmitter(encoder=encoder, renderer=renderer, **config)
    elif name == "VectorDecoder":
        renderer = model_from_config(config.pop("renderer"), device=device)
        param_shapes = {
            k: v.shape[1:] for k, v in batch_meta_state_dict(renderer, batch_size=1).items()
        }
        return VectorDecoder(param_shapes=param_shapes, renderer=renderer, device=device, **config)
    elif name == "ChannelsDecoder":
        renderer = model_from_config(config.pop("renderer"), device=device)
        param_shapes = {
            k: v.shape[1:] for k, v in batch_meta_state_dict(renderer, batch_size=1).items()
        }
        return ChannelsDecoder(
            param_shapes=param_shapes, renderer=renderer, device=device, **config
        )
    elif name == "OneStepNeRFRenderer":
        config = config.copy()
        for field in [
            # Required
            "void_model",
            "foreground_model",
            "volume",
            # Optional (used for NeRF++)
            "background_model",
            "outer_volume",
        ]:
            if field in config:
                config[field] = model_from_config(config.pop(field).copy(), device)
        return OneStepNeRFRenderer(device=device, **config)
    elif name == "TwoStepNeRFRenderer":
        config = config.copy()
        for field in [
            # Required
            "void_model",
            "coarse_model",
            "fine_model",
            "volume",
            # Optional (used for NeRF++)
            "coarse_background_model",
            "fine_background_model",
            "outer_volume",
        ]:
            if field in config:
                config[field] = model_from_config(config.pop(field).copy(), device)
        return TwoStepNeRFRenderer(device=device, **config)
    elif name == "PooledMLP":
        return PooledMLP(device, **config)
    elif name == "PointDiffusionTransformer":
        return PointDiffusionTransformer(device=device, dtype=torch.float32, **config)
    elif name == "PointDiffusionPerceiver":
        return PointDiffusionPerceiver(device=device, dtype=torch.float32, **config)
    elif name == "CLIPImagePointDiffusionTransformer":
        return CLIPImagePointDiffusionTransformer(device=device, dtype=torch.float32, **config)
    elif name == "CLIPImageGridPointDiffusionTransformer":
        return CLIPImageGridPointDiffusionTransformer(device=device, dtype=torch.float32, **config)
    elif name == "UpsamplePointDiffusionTransformer":
        return UpsamplePointDiffusionTransformer(device=device, dtype=torch.float32, **config)
    elif name == "CLIPImageGridUpsamplePointDiffusionTransformer":
        return CLIPImageGridUpsamplePointDiffusionTransformer(
            device=device, dtype=torch.float32, **config
        )
    elif name == "SplitVectorDiffusion":
        inner_config = config.pop("inner")
        d_latent = config.pop("d_latent")
        latent_ctx = config.pop("latent_ctx", 1)
        # The flat d_latent vector is viewed as latent_ctx tokens of
        # d_latent // latent_ctx channels each; the inner model's output is
        # doubled so it can emit two values per latent channel (e.g. a
        # prediction and a variance).
        inner_config["input_channels"] = d_latent // latent_ctx
        inner_config["n_ctx"] = latent_ctx
        inner_config["output_channels"] = d_latent // latent_ctx * 2
        inner_model = model_from_config(inner_config, device)
        return SplitVectorDiffusion(
            device=device, wrapped=inner_model, n_ctx=latent_ctx, d_latent=d_latent
        )
    elif name == "STFRenderer":
        config = config.copy()
        for field in ["sdf", "tf", "volume"]:
            config[field] = model_from_config(config.pop(field), device)
        return STFRenderer(device=device, **config)
    elif name == "NeRSTFRenderer":
        config = config.copy()
        for field in ["sdf", "tf", "nerstf", "void", "volume"]:
            if field not in config:
                continue
            config[field] = model_from_config(config.pop(field), device)
        config.setdefault("sdf", None)
        config.setdefault("tf", None)
        config.setdefault("nerstf", None)
        return NeRSTFRenderer(device=device, **config)
    # Remaining names map to simple classes whose constructors take the
    # config keys directly; unknown names raise a KeyError here.
    model_cls = {
        "MLPSDFModel": MLPSDFModel,
        "MLPTextureFieldModel": MLPTextureFieldModel,
        "MLPNeRFModel": MLPNeRFModel,
        "MLPDensitySDFModel": MLPDensitySDFModel,
        "MLPNeRSTFModel": MLPNeRSTFModel,
        "VoidNeRFModel": VoidNeRFModel,
        "BoundingBoxVolume": BoundingBoxVolume,
        "SphericalVolume": SphericalVolume,
        "UnboundedVolume": UnboundedVolume,
    }[name]
    return model_cls(device=device, **config)
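

# A minimal usage sketch, not part of the original file. The config values
# below are hypothetical and only illustrate the two accepted call forms: an
# in-memory dict and a path to a YAML file. Keys other than "name" must match
# the target class's constructor arguments.
if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Dict form: "name" picks the class, the rest become constructor kwargs.
    volume = model_from_config(
        {
            "name": "BoundingBoxVolume",
            "bbox_min": [-1.0, -1.0, -1.0],
            "bbox_max": [1.0, 1.0, 1.0],
        },
        device=device,
    )
    print(type(volume).__name__)

    # String form: the path is read with blobfile, parsed as YAML, and then
    # handled exactly like the dict form (hypothetical path):
    # model = model_from_config("path/to/config.yaml", device=device)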