Spaces:
Build error
Build error
| from typing import List | |
| from pydantic import validator | |
| from my.config import BaseConf, SingleOrList, dispatch | |
| from my.utils.seed import seed_everything | |
| import numpy as np | |
| from voxnerf.vox import VOXRF_REGISTRY | |
| from voxnerf.pipelines import train | |
class VoxConfig(BaseConf):
    """Configuration for a voxel radiance-field model built via ``VOXRF_REGISTRY``.

    ``make()`` looks up the constructor named by ``model_type`` and calls it with
    every remaining field as a keyword argument, except ``bbox_len``, which is
    turned into an ``aabb`` array first.
    """

    model_type: str = "VoxRF"  # registry key for the model constructor
    bbox_len: float = 1.5  # box spans [-bbox_len, bbox_len] on every axis (see make)
    grid_size: SingleOrList(int) = [128, 128, 128]  # an int n is expanded to [n, n, n]
    # The fields below are forwarded unchanged to the model constructor.
    step_ratio: float = 0.5
    density_shift: float = -10.
    ray_march_weight_thres: float = 0.0001
    c: int = 3
    blend_bg_texture: bool = False
    bg_texture_hw: int = 64

    # Without this decorator the method is never registered with pydantic and
    # grid_size is neither expanded nor checked.
    @validator("grid_size")
    def check_gsize(cls, grid_size):
        """Normalize ``grid_size`` to three per-axis sizes.

        A single int ``n`` becomes ``[n, n, n]``; a sequence must contain
        exactly three entries and is returned unchanged.

        Raises:
            ValueError: if a sequence whose length is not 3 is given
                (pydantic surfaces this as a ``ValidationError``).
        """
        if isinstance(grid_size, int):
            return [grid_size] * 3
        # Explicit raise instead of assert: asserts disappear under ``python -O``.
        if len(grid_size) != 3:
            raise ValueError(
                "grid_size must be an int or a sequence of 3 ints, "
                f"got {len(grid_size)} values"
            )
        return grid_size

    def make(self):
        """Instantiate the configured model.

        Returns:
            Whatever the registered constructor returns. It is called with
            ``aabb`` plus all config fields except ``model_type`` and ``bbox_len``.
        """
        params = self.dict()
        m_type = params.pop("model_type")
        # NOTE(review): behavior of VOXRF_REGISTRY.get for an unknown key is
        # defined in voxnerf.vox, not here — confirm it fails loudly.
        model_fn = VOXRF_REGISTRY.get(m_type)

        radius = params.pop('bbox_len')
        # 2x3 array: row 0 is the min corner, row 1 the max corner of a cube
        # centred at the origin with half side length ``radius``.
        aabb = radius * np.array([
            [-1, -1, -1],
            [1, 1, 1]
        ])
        model = model_fn(aabb=aabb, **params)
        return model
class TrainerConfig(BaseConf):
    """Top-level training configuration: a model config plus training settings.

    ``run()`` builds the model from ``model`` and hands every other field to
    ``train`` as a keyword argument.
    """

    model: VoxConfig = VoxConfig()
    scene: str = "lego"
    n_epoch: int = 2
    bs: int = 4096
    lr: float = 0.02

    def run(self):
        """Build the model and launch ``train`` with the remaining fields.

        As declared here those fields are ``scene``, ``n_epoch``, ``bs`` and
        ``lr``. Any extra fields contributed by ``BaseConf`` would be forwarded
        too — TODO confirm.
        """
        # Everything except the nested model config becomes a train() kwarg.
        train_kwargs = {
            name: value
            for name, value in self.dict().items()
            if name != "model"
        }
        vox_model = self.model.make()
        train(vox_model, **train_kwargs)
def _main():
    """Script entry point: seed RNGs with 0, then dispatch ``TrainerConfig``."""
    # Seed before dispatch so everything dispatch triggers runs after seeding.
    seed_everything(0)
    dispatch(TrainerConfig)


if __name__ == "__main__":
    _main()