| from .TriplaneVAE import TriplaneVAE | |
| from .Triplane_Diffusion import Triplane_Diff_MultiImgCond_EDM | |
| from .Triplane_Diffusion import EDMLoss_MultiImgCond | |
| #from .Point_Diffusion_EDM import PointEDM,EDMLoss_PointAug | |
def get_model(model_args):
    """Build the model selected by ``model_args['type']``.

    The full ``model_args`` mapping is forwarded unchanged to the chosen
    model's constructor.

    Args:
        model_args: Mapping with a ``'type'`` key naming the model to build.
            Supported values are ``"TriVAE"`` and
            ``"triplane_diff_multiimg_cond"``. Any further keys are
            consumed by the selected model class itself.

    Returns:
        A ``TriplaneVAE`` or ``Triplane_Diff_MultiImgCond_EDM`` instance.

    Raises:
        KeyError: If ``model_args`` has no ``'type'`` entry.
        NotImplementedError: If ``model_args['type']`` is not a supported
            model name.
    """
    model_type = model_args['type']
    if model_type == "TriVAE":
        return TriplaneVAE(model_args)
    if model_type == "triplane_diff_multiimg_cond":
        return Triplane_Diff_MultiImgCond_EDM(model_args)
    # Same exception type as before, but name the offending value so a
    # config typo can be diagnosed straight from the traceback.
    raise NotImplementedError(
        "Unknown model type {!r}; expected one of: 'TriVAE', "
        "'triplane_diff_multiimg_cond'".format(model_type)
    )
def get_criterion(cri_args):
    """Build the loss criterion selected by ``cri_args['type']``.

    Args:
        cri_args: Mapping with a ``'type'`` key naming the criterion. The
            only supported value is ``"EDMLoss_MultiImgCond"``, which also
            requires a ``'use_par'`` entry; that value is passed to the
            criterion as its ``use_par`` keyword argument.

    Returns:
        An ``EDMLoss_MultiImgCond`` instance.

    Raises:
        KeyError: If ``'type'`` is missing, or if ``'use_par'`` is missing
            for the ``"EDMLoss_MultiImgCond"`` criterion.
        NotImplementedError: If ``cri_args['type']`` is not a supported
            criterion name.
    """
    criterion_type = cri_args['type']
    if criterion_type == "EDMLoss_MultiImgCond":
        return EDMLoss_MultiImgCond(use_par=cri_args['use_par'])
    # Same exception type as before, but report which name was requested
    # so a misconfigured criterion is easy to spot.
    raise NotImplementedError(
        "Unknown criterion type {!r}; expected one of: "
        "'EDMLoss_MultiImgCond'".format(criterion_type)
    )