update files for demo
Browse files
- configs/finetune_triplane_diffusion.yaml +1 -1
- datasets/taxonomy.py +25 -16
- demo/api.py +151 -0
- demo/extract_vit_features.py +57 -0
- demo/process_data.py +340 -0
- demo/simple_dataset.py +182 -0
- train_VAE.sh +1 -1
- util/misc.py +1 -1
- util/simple_image_loader.py +0 -4
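The new demo/ scripts form a small pipeline: process_data.py converts a raw RGB-D scan folder (color/, depth/, pose/, intrinsic/intrinsic_color.txt, objects.npy, alignment_matrix.txt, fused_mesh.ply) into per-object image crops (6_images/), partial point clouds (5_partial_points/), per-frame projection matrices (8_proj_matrix/) and per-object scene-alignment matrices (10_tran_matrix/); extract_vit_features.py caches CLIP ViT features for every crop (7_img_feature/); api.py then loads the per-category VAE and diffusion checkpoints and exports one world-space mesh per object. A plausible invocation order is sketched below; the scene name "my_scene" and the single-process torchrun launch are illustrative assumptions, only the flags themselves come from the argparse definitions in the diffs:

cd demo
# 1. preprocess one captured scene into ../example_process_data/my_scene
python process_data.py --data_folder /path/to/raw_scans/my_scene --save_dir ../example_process_data
# 2. cache ViT features for every object crop (the script calls misc.init_distributed_mode, so launch it through a torch.distributed launcher such as torchrun)
torchrun --nproc_per_node=1 extract_vit_features.py --data_dir ../example_process_data --scene_id my_scene
# 3. reconstruct meshes with the per-category AE/diffusion checkpoints and write one .ply per object
python api.py --data_dir ../example_process_data --scene_id my_scene --save_dir ../example_output_data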
configs/finetune_triplane_diffusion.yaml
CHANGED
@@ -37,7 +37,7 @@ model:
   norm: "batch"
   img_in_channels: 1280
   vit_reso: 16
-  use_cat_embedding:
+  use_cat_embedding: False #only use category embedding when all categories are trained
   block_type: multiview_local
   par_point_encoder:
     plane_reso: 64
datasets/taxonomy.py
CHANGED
@@ -1,20 +1,29 @@
+# category_map={
+#     "bathtub":0,
+#     "bed":1,
+#     "cabinet":2,
+#     "chair":3,
+#     "dishwasher":4,
+#     "fireplace":5,
+#     "oven":6,
+#     "refrigerator":7,
+#     "shelf":8,
+#     "sink":9,
+#     "sofa":10,
+#     "stool":11,
+#     "stove":12,
+#     "table":13,
+#     "toilet":14,
+#     "washer":15
+# }
+
 category_map={
-    "bathtub":0,
-    "bed":1,
-    "cabinet":2,
-    "chair":3,
-    "dishwasher":4,
-    "fireplace":5,
-    "oven":6,
-    "refrigerator":7,
-    "shelf":8,
-    "sink":9,
-    "sofa":10,
-    "stool":11,
-    "stove":12,
-    "table":13,
-    "toilet":14,
-    "washer":15
+    "chair":0,
+    "sofa":1,
+    "table":2,
+    "cabinet":3,
+    "bed":4,
+    "shelf":5
 }
 
 category_map_from_synthetic={
demo/api.py
ADDED
@@ -0,0 +1,151 @@
+import os,sys
+sys.path.append("..")
+from configs.config_utils import CONFIG
+from models import get_model
+import torch
+import numpy as np
+import open3d as o3d
+import timm
+from PIL import Image
+from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
+from simple_dataset import InTheWild_Dataset,classname_remap,classname_map
+try:
+    from torchvision.transforms import InterpolationMode
+    BICUBIC = InterpolationMode.BICUBIC
+except ImportError:
+    BICUBIC = Image.BICUBIC
+import mcubes
+import trimesh
+from torch.utils.data import DataLoader
+
+def image_transform(n_px):
+    return Compose([
+        Resize(n_px, interpolation=BICUBIC),
+        CenterCrop(n_px),
+        ToTensor(),
+        Normalize((0.48145466, 0.4578275, 0.40821073),
+                  (0.26862954, 0.26130258, 0.27577711)),
+    ])
+
+MAX_IMG_LENGTH=5 #take up to 5 images as inputs
+
+ae_paths={
+    "chair":"../checkpoint/ae/chair/best-checkpoint.pth",
+    "table":"../checkpoint/ae/table/best-checkpoint.pth",
+    "cabinet":"../checkpoint/ae/cabinet/best-checkpoint.pth",
+    "shelf":"../checkpoint/ae/shelf/best-checkpoint.pth",
+    "sofa":"../checkpoint/ae/sofa/best-checkpoint.pth",
+    "bed":"../checkpoint/ae/bed/best-checkpoint.pth"
+}
+dm_paths={
+    "chair":"../checkpoint/finetune_dm/chair/best-checkpoint.pth",
+    "table":"../checkpoint/finetune_dm/table/best-checkpoint.pth",
+    "cabinet":"../checkpoint/finetune_dm/cabinet/best-checkpoint.pth",
+    "shelf":"../checkpoint/finetune_dm/shelf/best-checkpoint.pth",
+    "sofa":"../checkpoint/finetune_dm/sofa/best-checkpoint.pth",
+    "bed":"../checkpoint/finetune_dm/bed/best-checkpoint.pth"
+}
+
+def inference(ae_model,dm_model,data_batch,device,reso=256):
+    density = reso
+    gap = 2.2 / density
+    x = np.linspace(-1.1, 1.1, int(density + 1))
+    y = np.linspace(-1.1, 1.1, int(density + 1))
+    z = np.linspace(-1.1, 1.1, int(density + 1))
+    xv, yv, zv = np.meshgrid(x, y, z, indexing='ij')
+    grid = torch.from_numpy(np.stack([xv, yv, zv]).astype(np.float32)).view(3, -1).transpose(0, 1)[None].to(device,
+                                                                                                            non_blocking=True)
+    with torch.no_grad():
+        sample_input = dm_model.prepare_sample_data(data_batch)
+        sampled_array = dm_model.sample(sample_input, num_steps=36).float()
+        sampled_array = torch.nn.functional.interpolate(sampled_array, scale_factor=2, mode="bilinear")
+
+    model_ids = data_batch['model_id']
+    tran_mats = data_batch['tran_mat']
+
+    output_meshes={}
+
+    for j in range(sampled_array.shape[0]):
+        grid_list = torch.split(grid, 128 ** 3, dim=1)
+        output_list = []
+        with torch.no_grad():
+            for sub_grid in grid_list:
+                output_list.append(ae_model.decode(sampled_array[j:j + 1], sub_grid))
+        output = torch.cat(output_list, dim=1)
+        logits = output[j].detach()
+
+        volume = logits.view(density + 1, density + 1, density + 1).cpu().numpy()
+        verts, faces = mcubes.marching_cubes(volume, 0)
+
+        verts *= gap
+        verts -= 1.1
+
+        tran_mat = tran_mats[j].numpy()
+        verts_homo = np.concatenate([verts, np.ones((verts.shape[0], 1))], axis=1)
+        verts_inwrd = np.dot(verts_homo, tran_mat.T)[:, 0:3]
+        m_inwrd = trimesh.Trimesh(verts_inwrd, faces[:, ::-1]) #transform the mesh into world coordinate
+
+        output_meshes[model_ids[j]]=m_inwrd
+    return output_meshes
+
+if __name__=="__main__":
+    import argparse
+    parser=argparse.ArgumentParser()
+    parser.add_argument("--data_dir", type=str, default="../example_process_data")
+    parser.add_argument('--scene_id', default="all", type=str)
+    parser.add_argument("--save_dir", type=str,default="../example_output_data")
+    args = parser.parse_args()
+
+    config_path="../configs/finetune_triplane_diffusion.yaml"
+    config=CONFIG(config_path).config
+
+    '''creating save folder'''
+    save_folder=os.path.join(args.save_dir,args.scene_id)
+    os.makedirs(save_folder,exist_ok=True)
+
+    '''prepare model'''
+    device=torch.device("cuda")
+    ae_config=config['model']['ae']
+    dm_config=config['model']['dm']
+    dm_model=get_model(dm_config).to(device)
+    ae_model=get_model(ae_config).to(device)
+    dm_model.eval()
+    ae_model.eval()
+
+    '''preparing data'''
+    '''find out how many classes are there in the whole scene'''
+    images_folder=os.path.join(args.data_dir,args.scene_id,"6_images")
+    object_id_list=os.listdir(images_folder)
+    object_class_list=[item.split("_")[0] for item in object_id_list]
+    all_object_class=list(set(object_class_list))
+
+    exist_super_categories=[]
+    for object_class in all_object_class:
+        if object_class not in classname_remap:
+            continue
+        else:
+            exist_super_categories.append(classname_remap[object_class]) #find which category specific models should be employed
+    exist_super_categories=list(set(exist_super_categories))
+    for super_category in exist_super_categories:
+        print("processing %s"%(super_category))
+        ae_ckpt=torch.load(ae_paths[super_category],map_location="cpu")["model"]
+        dm_ckpt=torch.load(dm_paths[super_category],map_location="cpu")["model"]
+        ae_model.load_state_dict(ae_ckpt)
+        dm_model.load_state_dict(dm_ckpt)
+        dataset = InTheWild_Dataset(data_dir=args.data_dir, scene_id=args.scene_id, category=super_category, max_n_imgs=5)
+        dataloader=DataLoader(
+            dataset=dataset,
+            num_workers=1,
+            batch_size=1,
+            shuffle=False
+        )
+        for data_batch in dataloader:
+            output_meshes=inference(ae_model,dm_model,data_batch,device)
+            #print(output_meshes)
+            for model_id in output_meshes:
+                mesh=output_meshes[model_id]
+                save_path=os.path.join(save_folder,model_id+".ply")
+                print("saving to %s"%(save_path))
+                mesh.export(save_path)
demo/extract_vit_features.py
ADDED
@@ -0,0 +1,57 @@
+import os,sys
+sys.path.append("..")
+import numpy
+from simple_dataset import Simple_InTheWild_dataset
+import argparse
+from torch.utils.data import DataLoader
+import timm
+import torch
+import numpy as np
+from util import misc
+
+parser=argparse.ArgumentParser()
+parser.add_argument("--data_dir",type=str,default="../example_process_data")
+parser.add_argument('--world_size', default=1, type=int,
+                    help='number of distributed processes')
+parser.add_argument('--local_rank', default=-1, type=int)
+parser.add_argument('--dist_on_itp', action='store_true')
+parser.add_argument('--dist_url', default='env://',
+                    help='url used to set up distributed training')
+parser.add_argument('--scene_id',default="all",type=str)
+args=parser.parse_args()
+
+
+misc.init_distributed_mode(args)
+dataset=Simple_InTheWild_dataset(dataset_dir=args.data_dir,scene_id=args.scene_id,n_px=224)
+num_tasks = misc.get_world_size()
+global_rank = misc.get_rank()
+print(num_tasks,global_rank)
+sampler = torch.utils.data.DistributedSampler(
+    dataset, num_replicas=num_tasks, rank=global_rank,
+    shuffle=False)  # shuffle=True to reduce monitor bias
+
+dataloader=DataLoader(
+    dataset,
+    sampler=sampler,
+    batch_size=10,
+    num_workers=4,
+    pin_memory=True,
+    drop_last=False
+)
+VIT_MODEL = 'vit_huge_patch14_224_clip_laion2b'
+model=timm.create_model(VIT_MODEL, pretrained=True,pretrained_cfg_overlay=dict(file="./open_clip_pytorch_model.bin"))
+model=model.eval().float().cuda()
+for idx,data_batch in enumerate(dataloader):
+    if idx%10==0:
+        print("{}/{}".format(dataloader.__len__(),idx))
+    images = data_batch["images"].cuda().float()
+    model_id= data_batch["model_id"]
+    image_name=data_batch["image_name"]
+    scene_id=data_batch["scene_id"]
+    with torch.no_grad():
+        output_features=model.forward_features(images)
+    for j in range(output_features.shape[0]):
+        save_folder=os.path.join(args.data_dir,scene_id[j],"7_img_feature",model_id[j])
+        os.makedirs(save_folder,exist_ok=True)
+        save_path=os.path.join(save_folder,image_name[j]+".npz")
+        np.savez_compressed(save_path,img_features=output_features[j].detach().cpu().numpy().astype(np.float32))
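Note that extract_vit_features.py builds the timm model with pretrained_cfg_overlay=dict(file="./open_clip_pytorch_model.bin"), so the CLIP ViT-H/14 weights must already sit in demo/ under that exact filename; the checkpoint itself is not part of this commit and has to be downloaded separately. Each frame is written to <scene>/7_img_feature/<object>/<frame>.npz under the key img_features, which is exactly what InTheWild_Dataset in simple_dataset.py reads back at inference time.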
demo/process_data.py
ADDED
@@ -0,0 +1,340 @@
+import numpy as np
+import os
+import argparse
+import open3d as o3d
+import glob
+import cv2
+import copy
+
+def get_roll_rot(angle):
+    ca=np.cos(angle)
+    sa=np.sin(angle)
+    rot=np.array([
+        [ca,-sa,0,0],
+        [sa,ca,0,0],
+        [0,0,1,0],
+        [0,0,0,1]
+    ])
+    return rot
+def rotate_mat(direction):
+    if direction == 'Up':
+        return np.eye(4)
+    elif direction == 'Left':
+        rot_mat=get_roll_rot(np.pi/2)
+    elif direction == 'Right':
+        rot_mat=get_roll_rot(-np.pi/2)
+    elif direction == 'Down':
+        rot_mat=get_roll_rot(np.pi)
+    else:
+        raise Exception(f'No such direction (={direction}) rotation')
+    return rot_mat
+
+def rotate_K(K,direction):
+    if direction == 'Up' or direction=="Down":
+        new_K4=np.eye(4)
+        new_K4[0:3,0:3]=copy.deepcopy(K)
+        return new_K4
+    elif direction == 'Left' or direction =="Right":
+        fx,fy,cx,cy=K[0,0],K[1,1],K[0,2],K[1,2]
+        new_K4 = np.array([
+            [fy, 0, cy, 0],
+            [0, fx, cx, 0],
+            [0, 0, 1, 0],
+            [0, 0, 0, 1]
+        ])
+        return new_K4
+
+def rotate_bbox(bbox,direction, H,W):
+
+    x_min,y_min,x_max,y_max=bbox[0:4]
+    if direction == 'Up':
+        return bbox
+    elif direction == 'Left':
+        #print(W-bbox[1],W-bbox[3])
+        new_bbox=[min(H-bbox[1],H-bbox[3]),bbox[0],max(H-bbox[1],H-bbox[3]),bbox[2]]
+    elif direction == 'Right':
+        new_bbox=[bbox[1],min(W-bbox[0],W-bbox[2]),bbox[3],max(W-bbox[0],W-bbox[2])]
+    elif direction == 'Down':
+        new_bbox=[min(W-x_min,W-x_max),min(H-y_min,H-y_max),max(W-x_min,W-x_max),max(H-y_min,H-y_max)]
+    else:
+        raise Exception(f'No such direction (={direction}) rotation')
+    return new_bbox
+
+def rotate_image(img, direction):
+    if direction == 'Up':
+        pass
+    elif direction == 'Left':
+        img = cv2.rotate(img, cv2.ROTATE_90_CLOCKWISE)
+    elif direction == 'Right':
+        img = cv2.rotate(img, cv2.ROTATE_90_COUNTERCLOCKWISE)
+    elif direction == 'Down':
+        img = cv2.rotate(img, cv2.ROTATE_180)
+    else:
+        raise Exception(f'No such direction (={direction}) rotation')
+    return img
+
+parser=argparse.ArgumentParser()
+parser.add_argument("--data_folder",type=str,required=True)
+parser.add_argument("--save_dir",type=str,default=r"../example_process_data")
+parser.add_argument("--debug",action="store_true",default=False)
+args=parser.parse_args()
+
+print("processing %s"%(args.data_folder))
+
+data_folder=args.data_folder
+scene_name=os.path.basename(data_folder)
+save_folder=os.path.join(args.save_dir,scene_name)
+os.makedirs(save_folder,exist_ok=True)
+color_folder=os.path.join(data_folder,"color")
+depth_folder=os.path.join(data_folder,"depth")
+pose_folder=os.path.join(data_folder,"pose")
+
+print(color_folder)
+
+color_list=glob.glob(color_folder+"/*.jpg")
+image_id_list=[os.path.basename(item)[0:-4] for item in color_list]
+image_id_list.sort()
+
+bbox_path=os.path.join(data_folder,"objects.npy")
+bboxes_dict=np.load(bbox_path,allow_pickle=True).item()
+
+intrinsic_path=os.path.join(data_folder,"intrinsic","intrinsic_color.txt")
+K=np.loadtxt(intrinsic_path)
+
+align_path=os.path.join(data_folder,"alignment_matrix.txt")
+align_matrix=np.loadtxt(align_path)
+if align_matrix.shape[0]==3:
+    new_align_matrix=np.eye(4)
+    new_align_matrix[0:3,0:3]=align_matrix
+    align_matrix=new_align_matrix
+
+mesh_path=os.path.join(data_folder,"fused_mesh.ply")
+o3d_mesh=o3d.io.read_triangle_mesh(mesh_path)
+o3d_vertices = np.array(o3d_mesh.vertices)
+o3d_vert_homo=np.concatenate([o3d_vertices,np.ones([o3d_vertices.shape[0],1])],axis=1)
+align_o3d_vertices = np.dot(o3d_vert_homo,align_matrix)[:,0:3]
+o3d_mesh.vertices = o3d.utility.Vector3dVector(align_o3d_vertices)
+align_mesh_save_path=os.path.join(save_folder,"align_mesh.ply")
+o3d.io.write_triangle_mesh(align_mesh_save_path,o3d_mesh)
+
+x=np.linspace(-1,1,10)
+y=np.linspace(-1,1,10)
+z=np.linspace(-1,1,10)
+X,Y,Z=np.meshgrid(x,y,z,indexing='ij')
+vox_coor=np.concatenate([X[:,:,:,np.newaxis],Y[:,:,:,np.newaxis],Z[:,:,:,np.newaxis]],axis=-1)
+vox_coor=np.reshape(vox_coor,(-1,3))
+#print(np.amin(vox_coor,axis=0),np.amax(vox_coor,axis=0))
+
+pre_proj_mates={}
+obj_points_dict={}
+trans_mats={}
+point_save_folder=os.path.join(save_folder,"5_partial_points")
+os.makedirs(point_save_folder,exist_ok=True)
+tran_save_folder=os.path.join(save_folder,"10_tran_matrix")
+os.makedirs(tran_save_folder,exist_ok=True)
+for object_id in bboxes_dict:
+    object = bboxes_dict[object_id]
+    category = object['category']
+    sizes = object['size']
+    sizes *= 1.1
+    transform_matrix_t = np.array(object['transform']).reshape([4, 4])
+    translate = transform_matrix_t[:3, 3]
+    rotation = transform_matrix_t[:3, :3]
+
+    bbox_o3d = o3d.geometry.OrientedBoundingBox(translate.reshape([3, 1]),
+                                                rotation,
+                                                np.array(sizes).reshape([3, 1]))
+    crop_pcd = o3d_mesh.crop(bbox_o3d)
+    crop_vert = np.asarray(crop_pcd.vertices)
+    org_crop_vert = crop_vert[:, :]
+    crop_vert = crop_vert - translate
+    crop_vert = np.dot(crop_vert,np.linalg.inv(rotation).T)
+    crop_vert[:, 2] *= -1
+    bb_min, bb_max = np.amin(crop_vert, axis=0), np.amax(crop_vert, axis=0)
+    max_length = (bb_max - bb_min).max()
+    center = (bb_max + bb_min) / 2
+    crop_vert = (crop_vert - center) / max_length * 2
+
+    obj_points_dict[object_id]=crop_vert
+    crop_pcd.vertices=o3d.utility.Vector3dVector(crop_vert)
+    save_path=os.path.join(point_save_folder,category+"_%d.ply"%(object_id))
+    o3d.io.write_triangle_mesh(save_path,crop_pcd)
+
+    proj_mat = np.eye(4)
+    scale_tran = np.eye(4)
+    scale_tran[0, 0], scale_tran[1, 1], scale_tran[2, 2] = max_length / 2, max_length / 2, max_length / 2
+    proj_mat = np.dot(proj_mat, scale_tran)
+    center_tran = np.eye(4)
+    center_tran[0:3, 3] = center
+    proj_mat = np.dot(center_tran, proj_mat)
+    invert_mat = np.eye(4)
+    invert_mat[2, 2] *= -1
+    proj_mat = np.dot(invert_mat, proj_mat)
+    proj_mat[0:3, 0:3] = np.dot(rotation,proj_mat[0:3, 0:3])
+    translate_mat = np.eye(4)
+    translate_mat[0:3, 3] = translate
+    proj_mat = np.dot(translate_mat, proj_mat)
+
+    '''tran mat is to align output to scene space'''
+    tran_mat=copy.deepcopy(proj_mat)
+    trans_mats[object_id]=tran_mat
+    tran_save_path=os.path.join(tran_save_folder,category+"_%d.npy"%(object_id))
+    np.save(tran_save_path,tran_mat)
+
+    unalign_mat = np.linalg.inv(align_matrix)
+    proj_mat = np.dot(unalign_mat.T, proj_mat)
+    pre_proj_mates[object_id]=proj_mat
+
+ref=np.array([
+    [0,1.0], #Up
+    [-1.0,0],#Left
+    [0,1.0], #Right
+    [0.0,-1.0] #Down
+]) #4*2
+dir_list=[
+    "Down",
+    "Left",
+    "Right",
+    "Up"
+]
+
+for image_id in image_id_list:
+    color_path=os.path.join(color_folder,image_id+".jpg")
+    depth_path=os.path.join(depth_folder,image_id+".png")
+    pose_path=os.path.join(pose_folder,image_id+".txt")
+
+    color=cv2.imread(color_path)
+    height,width=color.shape[0:2]
+    depth=cv2.imread(depth_path,cv2.IMREAD_ANYCOLOR|cv2.IMREAD_ANYDEPTH)/1000.0
+    pose=np.loadtxt(pose_path)
+    for object_id in bboxes_dict:
+        object=bboxes_dict[object_id]
+        category=object['category']
+        sizes=object['size']
+        object_vox_coor=vox_coor*sizes[np.newaxis,:]
+        #print(np.amin(object_vox_coor,axis=0),np.amax(object_vox_coor,axis=0))
+        #print(sizes)
+
+        prev_proj_mat=pre_proj_mates[object_id]
+        wrd2cam_pose = np.linalg.inv(pose)
+        current_proj_mat = np.dot(wrd2cam_pose, prev_proj_mat)
+        K4=np.eye(4)
+        K4[0:3,0:3]=K
+
+        '''calibrate proj_mat'''
+        up_vectors = np.array([[0, 0, 0, 1.0],
+                               [0, 0.5, 0, 1.0]])
+        up_vec_inimg = np.dot(up_vectors, current_proj_mat.T)
+        up_vec_inimg = np.dot(up_vec_inimg,K4.T)
+        up_x = up_vec_inimg[:, 0] / up_vec_inimg[:, 2]
+        up_y = up_vec_inimg[:, 1] / up_vec_inimg[:, 2]
+        pt1 = np.array((up_x[0], up_y[0]))
+        pt2 = np.array((up_x[1], up_y[1]))
+        up_dir = pt2 - pt1
+        # print(up_dir)
+
+        product = np.sum(up_dir[np.newaxis, :] * ref, axis=1)
+        max_ind = np.argmax(product)
+        direction = dir_list[max_ind]
+        sky_rot = rotate_mat(direction)
+        #final_proj_mat = np.dot(K4,final_proj_mat)
+
+        vox_homo=np.concatenate([object_vox_coor,np.ones((object_vox_coor.shape[0],1))],axis=1)
+        vox_proj=np.dot(vox_homo,current_proj_mat.T)
+        vox_proj=np.dot(vox_proj,K4.T)
+        vox_x=vox_proj[:,0]/vox_proj[:,2]
+        vox_y=vox_proj[:,1]/vox_proj[:,2]
+
+        if np.mean(vox_proj[:,2])>5:
+            continue
+
+        inside_mask=((vox_x<width-1) &(vox_x>0) &(vox_y<height-1) &(vox_y>0)).astype(np.float32)
+        infrustum_ratio=np.sum(inside_mask)/vox_x.shape[0]
+        if infrustum_ratio < 0.4 and category in ["chair", "stool"]:
+            continue
+        elif infrustum_ratio <0.2:
+            continue
+        #print(object_id,image_id,infrustum_ratio)
+
+        '''objects visibility check for every frame'''
+        vox_x_inside=vox_x[inside_mask>0].astype(np.int32)
+        vox_y_inside=vox_y[inside_mask>0].astype(np.int32)
+        vox_depth=vox_proj[inside_mask>0,2]
+        #print(depth.shape,np.amax(vox_y_inside),np.amax(vox_x_inside))
+        depth_sample=depth[vox_y_inside,vox_x_inside]
+        depth_mask=(depth_sample>0)&(depth_sample<10.0)
+        depth_sample=depth_sample[depth_mask]
+        vox_depth=vox_depth[depth_mask]
+
+        if vox_depth.shape[0]<100:
+            continue
+
+        occluded_ratio=np.sum(((vox_depth-depth_sample)>0.2).astype(np.float32))/vox_depth.shape[0]
+        if occluded_ratio>0.6 and category in ["chair"]: #chair is easily occluded, while table is not
+            continue
+
+        depth_near_ratio = np.sum((np.abs(vox_depth - depth_sample) < sizes.max() * 0.5).astype(np.float32)) / \
+                           vox_depth.shape[0]
+        if depth_near_ratio < 0.2:
+            continue
+
+        '''make sure in every image, the object is upward'''
+        bbox=(np.amin(vox_x_inside),np.amin(vox_y_inside),np.amax(vox_x_inside),np.amax(vox_y_inside))
+        rot_image=rotate_image(color,direction)
+        bbox = rotate_bbox(bbox, direction, height, width)
+        crop_image=rot_image[bbox[1]:bbox[3],bbox[0]:bbox[2]]
+        crop_h, crop_w = crop_image.shape[0:2]
+        max_length = max(crop_h, crop_w)
+        if max_length<100:
+            continue
+        pad_image = np.zeros((max_length, max_length, 3))
+        if crop_h > crop_w:
+            margin = crop_h - crop_w
+            pad_image[:, margin // 2:margin // 2 + crop_w] = crop_image[:, :, :]
+            x_start, x_end = bbox[0] - margin // 2, margin // 2 + bbox[2]
+            y_start, y_end = bbox[1], bbox[3]
+        else:
+            margin = crop_w - crop_h
+            pad_image[margin // 2:margin // 2 + crop_h, :] = crop_image[:, :, :]
+
+            y_start, y_end = bbox[1] - margin // 2, bbox[3] + margin // 2
+            x_start, x_end = bbox[0], bbox[2]
+
+        pad_image=cv2.resize(pad_image,dsize=(224,224),interpolation=cv2.INTER_LINEAR)
+        image_save_folder = os.path.join(save_folder, "6_images", category + "_%d" % (object_id))
+        os.makedirs(image_save_folder, exist_ok=True)
+        image_save_path=os.path.join(image_save_folder,image_id+".jpg")
+        #print("saving to %s"%(image_save_path))
+        cv2.imwrite(image_save_path,pad_image)
+
+        proj_mat=np.dot(sky_rot,current_proj_mat)
+        new_K4 = rotate_K(K, direction)
+        new_K4[0, 2] -= x_start
+        new_K4[1, 2] -= y_start
+        new_K4[0] = new_K4[0] / max_length * 224
+        new_K4[1] = new_K4[1] / max_length * 224
+        proj_mat = np.dot(new_K4, proj_mat)
+
+        proj_save_folder=os.path.join(save_folder,"8_proj_matrix",category+"_%d"%(object_id))
+        os.makedirs(proj_save_folder,exist_ok=True)
+        proj_save_path=os.path.join(proj_save_folder,image_id+".npy")
+        np.save(proj_save_path,proj_mat)
+
+        '''debug proj matrix'''
+        if args.debug:
+            proj_save_folder=os.path.join(save_folder,"9_proj_images",category+"_%d"%(object_id))
+            os.makedirs(proj_save_folder,exist_ok=True)
+            canvas=copy.deepcopy(pad_image)
+            par_points=obj_points_dict[object_id]
+            par_homo=np.concatenate([par_points,np.ones((par_points.shape[0],1))],axis=1)
+            par_inimg=np.dot(par_homo,proj_mat.T)
+            x=par_inimg[:,0]/par_inimg[:,2]
+            y=par_inimg[:,1]/par_inimg[:,2]
+            x=np.clip(x,a_min=0,a_max=223).astype(np.int32)
+            y=np.clip(y,a_min=0,a_max=223).astype(np.int32)
+            canvas[y,x]=np.array([[0,255,0]])
+            proj_save_path=os.path.join(proj_save_folder,image_id+".jpg")
+            cv2.imwrite(proj_save_path,canvas)
demo/simple_dataset.py
ADDED
@@ -0,0 +1,182 @@
+import torch
+import torch.nn as nn
+from torch.utils import data
+import os
+from PIL import Image
+from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
+try:
+    from torchvision.transforms import InterpolationMode
+    BICUBIC = InterpolationMode.BICUBIC
+except ImportError:
+    BICUBIC = Image.BICUBIC
+import glob
+import numpy as np
+import open3d as o3d
+import cv2
+from datasets.taxonomy import category_map as category_ids
+
+classname_map={
+    "chair":["chair","stool"],
+    "cabinet":["dishwasher","cabinet","oven","refrigerator",'storage'],
+    "sofa":["sofa"],
+    "table":["table"],
+    "bed":["bed"],
+    "shelf":["shelf"]
+}
+classname_remap={ #map small categories to six large categories
+    "chair":"chair",
+    "stool":"chair",
+    "dishwasher":"cabinet",
+    "cabinet":"cabinet",
+    "oven":"cabinet",
+    "refrigerator":"cabinet",
+    "storage":"cabinet",
+    "sofa":"sofa",
+    "table":"table",
+    "bed":"bed",
+    "shelf":"shelf"
+}
+
+def image_transform(n_px):
+    return Compose([
+        Resize(n_px, interpolation=BICUBIC),
+        CenterCrop(n_px),
+        ToTensor(),
+        Normalize((0.48145466, 0.4578275, 0.40821073),
+                  (0.26862954, 0.26130258, 0.27577711)),
+    ])
+class Simple_InTheWild_dataset(data.Dataset):
+    def __init__(self,dataset_dir="/data1/haolin/data/real_scene_process_data",scene_id="letian-310",n_px=224):
+        self.dataset_dir=dataset_dir
+        self.preprocess = image_transform(n_px)
+        self.image_path = []
+        if scene_id=="all":
+            scene_list=os.listdir(self.dataset_dir)
+            for id in scene_list:
+                image_folder=os.path.join(self.dataset_dir,id,"6_images")
+                self.image_path+=glob.glob(image_folder+"/*/*jpg")
+        else:
+            image_folder = os.path.join(self.dataset_dir, scene_id, "6_images")
+            self.image_path += glob.glob(image_folder + "/*/*jpg")
+    def __len__(self):
+        return len(self.image_path)
+
+    def __getitem__(self,index):
+        path=self.image_path[index]
+        basename=os.path.basename(path)[:-4]
+        model_id=path.split(os.sep)[-2]
+        scene_id=path.split(os.sep)[-4]
+        image=Image.open(path)
+        image_tensor=self.preprocess(image)
+
+        return {"images":image_tensor,"image_name":basename,"model_id":model_id,"scene_id":scene_id}
+
+class InTheWild_Dataset(data.Dataset):
+    def __init__(self,data_dir="/data1/haolin/data/real_scene_process_data/letian-310",scene_id="letian-310",
+                 par_pc_size=2048,category="chair",max_n_imgs=5):
+        self.par_pc_size=par_pc_size
+        self.data_dir=data_dir
+        self.category=category
+        self.max_n_imgs=max_n_imgs
+
+        self.models=[]
+        category_list=classname_map[category]
+        modelid_list=[]
+        for cat in category_list:
+            if scene_id=="all":
+                scene_list=os.listdir(self.data_dir)
+                for id in scene_list:
+                    data_folder=os.path.join(self.data_dir,id)
+                    modelid_list+=glob.glob(data_folder+"/6_images/%s*"%(cat))
+            else:
+                data_folder=os.path.join(self.data_dir,scene_id)
+                modelid_list+=glob.glob(data_folder+"/6_images/%s*"%(cat))
+        sceneid_list = [item.split("/")[-3] for item in modelid_list]
+        modelid_list=[item.split("/")[-1] for item in modelid_list]
+        for idx,modelid in enumerate(modelid_list):
+            scene_id=sceneid_list[idx]
+            image_folder=os.path.join(self.data_dir,scene_id,"6_images",modelid)
+            image_list=os.listdir(image_folder)
+            if len(image_list)==0:
+                continue
+            imageid_list=[item[0:-4] for item in image_list]
+            imageid_list.sort(key=lambda x:int(x))
+            partial_path=os.path.join(self.data_dir,scene_id,"5_partial_points",modelid+".ply")
+            if os.path.exists(partial_path)==False: continue
+            self.models+=[
+                {'model_id':modelid,
+                 "scene_id":scene_id,
+                 "partial_path":partial_path,
+                 "imageid_list":imageid_list,
+                 }
+            ]
+    def __len__(self):
+        return len(self.models)
+
+    def __getitem__(self,idx):
+        model = self.models[idx]['model_id']
+        scene_id=self.models[idx]['scene_id']
+        imageid_list = self.models[idx]['imageid_list']
+        partial_path=self.models[idx]['partial_path']
+        n_frames=min(len(imageid_list),self.max_n_imgs)
+        img_indexes=np.linspace(start=0,stop=len(imageid_list)-1,num=n_frames).astype(np.int32)
+
+        '''load partial points'''
+        par_point_o3d = o3d.io.read_point_cloud(partial_path)
+        par_points = np.asarray(par_point_o3d.points)
+        replace = par_points.shape[0] < self.par_pc_size
+        ind = np.random.default_rng().choice(par_points.shape[0], self.par_pc_size, replace=replace)
+        par_points=par_points[ind]
+        par_points=torch.from_numpy(par_points).float()
+
+        '''load image features'''
+        image_list=[]
+        valid_frames = []
+        image_namelist=[]
+        for img_index in img_indexes:
+            image_name = imageid_list[img_index]
+            image_feat_path = os.path.join(self.data_dir,scene_id, "7_img_feature", model,image_name + '.npz')
+            image = np.load(image_feat_path)["img_features"]
+            image_list.append(torch.from_numpy(image).float())
+            image_namelist.append(image_name)
+            valid_frames.append(True)
+        '''load original image'''
+        org_img_list=[]
+        for img_index in img_indexes:
+            image_name = imageid_list[img_index]
+            image_path = os.path.join(self.data_dir,scene_id, "6_images", model,image_name+".jpg")
+            org_image = cv2.imread(image_path)
+            org_image = cv2.resize(org_image, dsize=(224, 224), interpolation=cv2.INTER_LINEAR)
+            org_img_list.append(org_image)
+
+        '''load project matrix'''
+        proj_mat_list=[]
+        for img_index in img_indexes:
+            image_name = imageid_list[img_index]
+            proj_mat_path = os.path.join(self.data_dir,scene_id, "8_proj_matrix", model, image_name + ".npy")
+            proj_mat = np.load(proj_mat_path)
+            proj_mat_list.append(proj_mat)
+
+        '''load transformation matrix'''
+        tran_mat_path = os.path.join(self.data_dir,scene_id, "10_tran_matrix", model+".npy")
+        tran_mat = np.load(tran_mat_path)
+
+        '''category code, not used for category specific models'''
+        category_id = category_ids[self.category]
+        one_hot = torch.zeros((6)).float()
+        one_hot[category_id] = 1.0
+
+        ret_dict={
+            "model_id":model,
+            "scene_id":scene_id,
+            "par_points":par_points,
+            "proj_mat":torch.stack([torch.from_numpy(mat) for mat in proj_mat_list], dim=0),
+            "tran_mat":torch.from_numpy(tran_mat).float(),
+            "image":torch.stack(image_list,dim=0),
+            "org_image":org_img_list,
+            "valid_frames":torch.tensor(valid_frames).bool(),
+            "category_ids": category_id,
+            "category_code":one_hot,
+        }
+        return ret_dict
train_VAE.sh
CHANGED
@@ -12,4 +12,4 @@ train_triplane_vae.py \
 --clip_grad 0.35 \
 --category chair \
 --data-pth ../data \
---replica 5
+--replica 5 7
util/misc.py
CHANGED
@@ -15,7 +15,7 @@ from pathlib import Path
 import torch
 import torch.distributed as dist
 #from torch._six import inf
-import
+import math
 import numpy as np
 
 def log_codefiles(data_root,save_root):
util/simple_image_loader.py
CHANGED
@@ -16,12 +16,8 @@ def image_transform(n_px):
         Resize(n_px, interpolation=BICUBIC),
         CenterCrop(n_px),
         ToTensor(),
-        # Normalize((123.675/255.0,116.28/255.0,103.53/255.0),
-        #           (58.395/255.0,57.12/255.0,57.375/255.0))
         Normalize((0.48145466, 0.4578275, 0.40821073),
                   (0.26862954, 0.26130258, 0.27577711)),
-        # Normalize((0.5, 0.5, 0.5),
-        #           (0.5, 0.5, 0.5)),
     ])
 
 class Image_dataset(data.Dataset):