| import torch | |
| import torch.nn as nn | |
| from torch.utils import data | |
| import os | |
| from PIL import Image | |
| from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize | |
| try: | |
| from torchvision.transforms import InterpolationMode | |
| BICUBIC = InterpolationMode.BICUBIC | |
| except ImportError: | |
| BICUBIC = Image.BICUBIC | |
| import glob | |
def _convert_image_to_rgb(image):
    """Return ``image`` converted to 3-channel RGB.

    Defined at module level (not as a lambda) so the ``Compose`` built by
    ``image_transform`` stays picklable, e.g. for multi-worker DataLoaders.
    """
    return image.convert("RGB")


def image_transform(n_px):
    """Build the CLIP-style preprocessing pipeline for a PIL image.

    Steps: convert to RGB, resize the shorter side to ``n_px`` with bicubic
    interpolation, center-crop to ``n_px`` x ``n_px``, convert to a tensor and
    normalize with the OpenAI CLIP per-channel mean/std.

    Args:
        n_px: Output side length in pixels.

    Returns:
        A ``Compose`` callable mapping a PIL image to a normalized
        ``(3, n_px, n_px)`` float tensor.
    """
    return Compose([
        # Convert first: grayscale/RGBA/palette inputs would otherwise yield a
        # tensor whose channel count does not match the 3-channel Normalize,
        # and Pillow silently falls back to nearest-neighbour when resizing
        # palette ("P") images.
        _convert_image_to_rgb,
        Resize(n_px, interpolation=BICUBIC),
        CenterCrop(n_px),
        ToTensor(),
        Normalize((0.48145466, 0.4578275, 0.40821073),   # CLIP RGB mean
                  (0.26862954, 0.26130258, 0.27577711)),  # CLIP RGB std
    ])
class Image_dataset(data.Dataset):
    """Dataset of per-model image files, preprocessed with ``image_transform``.

    Expected directory layout::

        <dataset_folder>/other_data/<category>/6_images/<model_id>/<image file>

    Every file found under each ``<model_id>`` folder becomes one sample.
    """

    def __init__(self, dataset_folder="/data1/haolin/datasets",
                 categories=("03001627",), n_px=224):
        """Index all image files for the requested categories.

        Args:
            dataset_folder: Root directory that contains ``other_data``.
            categories: Iterable of category folder names to include. The
                default is a tuple so it cannot be mutated across instances.
            n_px: Output side length forwarded to ``image_transform``.
        """
        self.dataset_folder = dataset_folder
        self.image_folder = os.path.join(self.dataset_folder, "other_data")
        self.preprocess = image_transform(n_px)
        self.image_path = []
        for cat in categories:
            subpath = os.path.join(self.image_folder, cat, "6_images")
            # os.listdir order is arbitrary; sort so that sample indices are
            # reproducible across machines and file systems.
            for model_id in sorted(os.listdir(subpath)):
                model_folder = os.path.join(subpath, model_id)
                self.image_path.extend(
                    os.path.join(model_folder, image_filename)
                    for image_filename in sorted(os.listdir(model_folder))
                )

    def __len__(self):
        """Return the number of indexed image files."""
        return len(self.image_path)

    def __getitem__(self, index):
        """Load and preprocess the image at ``index``.

        Returns:
            dict with ``images`` (preprocessed tensor), ``image_name`` (file
            name without its extension), ``model_id`` and ``category``, the
            latter two parsed from the directory layout described on the class.
        """
        path = self.image_path[index]
        parts = os.path.normpath(path).split(os.sep)
        basename = os.path.splitext(parts[-1])[0]
        model_id = parts[-2]
        category = parts[-4]
        # The context manager closes the file handle; convert("RGB") forces
        # the pixel data to load and guarantees 3 channels for Normalize.
        with Image.open(path) as img:
            image = img.convert("RGB")
        image_tensor = self.preprocess(image)
        return {"images": image_tensor, "image_name": basename,
                "model_id": model_id, "category": category}
class Image_InTheWild_dataset(data.Dataset):
    """Dataset of ``.jpg`` images from per-scene folders, preprocessed with ``image_transform``.

    Expected directory layout::

        <dataset_dir>/<scene_id>/6_images/<model_id>/<name>.jpg
    """

    def __init__(self, dataset_dir="/data1/haolin/data/real_scene_process_data",
                 scene_id="letian-310", n_px=224):
        """Index the ``.jpg`` images of one scene, or of every scene.

        Args:
            dataset_dir: Root directory containing one sub-folder per scene.
            scene_id: Scene folder name, or ``"all"`` to index every entry of
                ``dataset_dir``.
            n_px: Output side length forwarded to ``image_transform``.
        """
        self.dataset_dir = dataset_dir
        self.preprocess = image_transform(n_px)
        self.image_path = []
        scene_ids = os.listdir(self.dataset_dir) if scene_id == "all" else [scene_id]
        # Sorted so that sample indices are reproducible; neither os.listdir
        # nor glob.glob guarantees an order.
        for scene in sorted(scene_ids):
            image_folder = os.path.join(self.dataset_dir, scene, "6_images")
            # Escape the literal directory part so scene names containing glob
            # metacharacters (e.g. "[") are matched verbatim.
            pattern = os.path.join(glob.escape(image_folder), "*", "*.jpg")
            self.image_path.extend(sorted(glob.glob(pattern)))

    def __len__(self):
        """Return the number of indexed image files."""
        return len(self.image_path)

    def __getitem__(self, index):
        """Load and preprocess the image at ``index``.

        Returns:
            dict with ``images`` (preprocessed tensor), ``image_name`` (file
            name without its extension), ``model_id`` and ``scene_id``, the
            latter two parsed from the directory layout described on the class.
        """
        path = self.image_path[index]
        parts = os.path.normpath(path).split(os.sep)
        basename = os.path.splitext(parts[-1])[0]
        model_id = parts[-2]
        scene_id = parts[-4]
        # The context manager closes the file handle; convert("RGB") forces
        # the pixel data to load and guarantees 3 channels for Normalize.
        with Image.open(path) as img:
            image = img.convert("RGB")
        image_tensor = self.preprocess(image)
        return {"images": image_tensor, "image_name": basename,
                "model_id": model_id, "scene_id": scene_id}