	Upload 22 files
 - .gitattributes +14 -0
 - dataloaders/paired_dataset.py +95 -0
 - dataloaders/params_realesrgan.yml +43 -0
 - dataloaders/realesrgan.py +303 -0
 - dataloaders/simple_dataset.py +156 -0
 - figs/bird1.png +3 -0
 - figs/building.png +3 -0
 - figs/data_real.png +3 -0
 - figs/data_real_sup.jpg +3 -0
 - figs/data_real_suppl.jpg +3 -0
 - figs/data_real_suppl.png +3 -0
 - figs/data_syn.png +3 -0
 - figs/figs.md +1 -0
 - figs/framework.png +3 -0
 - figs/gradio.png +0 -0
 - figs/ground.jpg +0 -0
 - figs/logo1.png +0 -0
 - figs/nature.png +3 -0
 - figs/person1.png +3 -0
 - figs/turbo_steps02_building.png +3 -0
 - figs/turbo_steps02_frog.png +3 -0
 - figs/turbo_steps04_building.png +3 -0
 - figs/turbo_steps04_frog.png +3 -0
 
    	
.gitattributes CHANGED
@@ -33,3 +33,17 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+figs/bird1.png filter=lfs diff=lfs merge=lfs -text
+figs/building.png filter=lfs diff=lfs merge=lfs -text
+figs/data_real_sup.jpg filter=lfs diff=lfs merge=lfs -text
+figs/data_real_suppl.jpg filter=lfs diff=lfs merge=lfs -text
+figs/data_real_suppl.png filter=lfs diff=lfs merge=lfs -text
+figs/data_real.png filter=lfs diff=lfs merge=lfs -text
+figs/data_syn.png filter=lfs diff=lfs merge=lfs -text
+figs/framework.png filter=lfs diff=lfs merge=lfs -text
+figs/nature.png filter=lfs diff=lfs merge=lfs -text
+figs/person1.png filter=lfs diff=lfs merge=lfs -text
+figs/turbo_steps02_building.png filter=lfs diff=lfs merge=lfs -text
+figs/turbo_steps02_frog.png filter=lfs diff=lfs merge=lfs -text
+figs/turbo_steps04_building.png filter=lfs diff=lfs merge=lfs -text
+figs/turbo_steps04_frog.png filter=lfs diff=lfs merge=lfs -text
    	
dataloaders/paired_dataset.py ADDED
@@ -0,0 +1,95 @@
import glob
import os
from PIL import Image
import random
import numpy as np

from torch import nn
from torchvision import transforms
from torch.utils import data as data
import torch.nn.functional as F

from .realesrgan import RealESRGAN_degradation

class PairedCaptionDataset(data.Dataset):
    def __init__(
            self,
            root_folders=None,
            tokenizer=None,
            null_text_ratio=0.5,
            # use_ram_encoder=False,
            # use_gt_caption=False,
            # caption_type = 'gt_caption',
    ):
        super(PairedCaptionDataset, self).__init__()

        self.null_text_ratio = null_text_ratio
        self.lr_list = []
        self.gt_list = []
        self.tag_path_list = []

        root_folders = root_folders.split(',')
        for root_folder in root_folders:
            lr_path = root_folder + '/sr_bicubic'
            tag_path = root_folder + '/tag'
            gt_path = root_folder + '/gt'

            self.lr_list += glob.glob(os.path.join(lr_path, '*.png'))
            self.gt_list += glob.glob(os.path.join(gt_path, '*.png'))
            self.tag_path_list += glob.glob(os.path.join(tag_path, '*.txt'))

        assert len(self.lr_list) == len(self.gt_list)
        assert len(self.lr_list) == len(self.tag_path_list)

        self.img_preproc = transforms.Compose([
            transforms.ToTensor(),
        ])

        ram_mean = [0.485, 0.456, 0.406]
        ram_std = [0.229, 0.224, 0.225]
        self.ram_normalize = transforms.Normalize(mean=ram_mean, std=ram_std)

        self.tokenizer = tokenizer

    def tokenize_caption(self, caption=""):
        inputs = self.tokenizer(
            caption, max_length=self.tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt"
        )

        return inputs.input_ids

    def __getitem__(self, index):

        gt_path = self.gt_list[index]
        gt_img = Image.open(gt_path).convert('RGB')
        gt_img = self.img_preproc(gt_img)

        lq_path = self.lr_list[index]
        lq_img = Image.open(lq_path).convert('RGB')
        lq_img = self.img_preproc(lq_img)

        if random.random() < self.null_text_ratio:
            tag = ''
        else:
            tag_path = self.tag_path_list[index]
            file = open(tag_path, 'r')
            tag = file.read()
            file.close()

        example = dict()
        example["conditioning_pixel_values"] = lq_img.squeeze(0)
        example["pixel_values"] = gt_img.squeeze(0) * 2.0 - 1.0
        example["input_ids"] = self.tokenize_caption(caption=tag).squeeze(0)

        lq_img = lq_img.squeeze()

        ram_values = F.interpolate(lq_img.unsqueeze(0), size=(384, 384), mode='bicubic')
        ram_values = ram_values.clamp(0.0, 1.0)
        example["ram_values"] = self.ram_normalize(ram_values.squeeze(0))

        return example

    def __len__(self):
        return len(self.gt_list)
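PairedCaptionDataset pairs each pre-upscaled LQ image (sr_bicubic/) with its GT (gt/) and a text tag (tag/). Per sample it returns the GT scaled to [-1, 1], the LQ in [0, 1], tokenized caption ids (blanked with probability null_text_ratio), and a 384x384 ImageNet-normalized copy of the LQ, apparently intended for a RAM-style tagger. A minimal usage sketch, not part of this commit; the tokenizer checkpoint and the /data/train_pairs root are placeholder assumptions, and the root must contain gt/, sr_bicubic/ and tag/ subfolders with matching files:

# Usage sketch (assumptions: placeholder data root and tokenizer checkpoint).
from torch.utils.data import DataLoader
from transformers import CLIPTokenizer

from dataloaders.paired_dataset import PairedCaptionDataset

tokenizer = CLIPTokenizer.from_pretrained(
    "stabilityai/stable-diffusion-2-1-base", subfolder="tokenizer")

dataset = PairedCaptionDataset(
    root_folders="/data/train_pairs",   # comma-separated list of data roots
    tokenizer=tokenizer,
    null_text_ratio=0.5,                # half of the captions are dropped to ""
)
loader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=2)

batch = next(iter(loader))
# pixel_values: GT in [-1, 1]; conditioning_pixel_values: LQ in [0, 1];
# ram_values: LQ resized to 384x384 and ImageNet-normalized.
print(batch["pixel_values"].shape, batch["conditioning_pixel_values"].shape,
      batch["input_ids"].shape, batch["ram_values"].shape)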
    	
dataloaders/params_realesrgan.yml ADDED
@@ -0,0 +1,43 @@
scale: 4
color_jitter_prob: 0.0
gray_prob: 0.0

# the first degradation process
resize_prob: [0.2, 0.7, 0.1]  # up, down, keep
resize_range: [0.3, 1.5]
gaussian_noise_prob: 0.5
noise_range: [1, 15]
poisson_scale_range: [0.05, 2.0]
gray_noise_prob: 0.4
jpeg_range: [60, 95]

# the second degradation process
second_blur_prob: 0.5
resize_prob2: [0.3, 0.4, 0.3]  # up, down, keep
resize_range2: [0.6, 1.2]
gaussian_noise_prob2: 0.5
noise_range2: [1, 12]
poisson_scale_range2: [0.05, 1.0]
gray_noise_prob2: 0.4
jpeg_range2: [60, 100]

kernel_info:
    blur_kernel_size: 21
    kernel_list: ['iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso']
    kernel_prob: [0.45, 0.25, 0.12, 0.03, 0.12, 0.03]
    sinc_prob: 0.1
    blur_sigma: [0.2, 3]
    betag_range: [0.5, 4]
    betap_range: [1, 2]

    blur_kernel_size2: 21
    kernel_list2: ['iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso']
    kernel_prob2: [0.45, 0.25, 0.12, 0.03, 0.12, 0.03]
    sinc_prob2: 0.1
    blur_sigma2: [0.2, 1.5]
    betag_range2: [0.5, 4]
    betap_range2: [1, 2]

    final_sinc_prob: 0.8
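The config mirrors the Real-ESRGAN two-stage degradation: first- and second-stage resize/noise/JPEG parameters at the top level, with all blur and sinc kernel settings grouped under kernel_info. It is consumed by opt_parse in dataloaders/realesrgan.py below; a quick standalone check with plain PyYAML, a sketch assuming it is run from the repository root:

# Sanity-check sketch: inspect the degradation config with PyYAML.
import yaml

with open("dataloaders/params_realesrgan.yml") as f:
    opt = yaml.safe_load(f)

print(opt["scale"])                             # 4
print(opt["resize_prob"])                       # [0.2, 0.7, 0.1] -> up / down / keep
print(opt["kernel_info"]["blur_kernel_size"])   # 21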
    	
dataloaders/realesrgan.py ADDED
@@ -0,0 +1,303 @@
import os
import numpy as np
import cv2
import glob
import math
import yaml
import random
from collections import OrderedDict
import torch
import torch.nn.functional as F

from basicsr.data.transforms import augment
from basicsr.data.degradations import circular_lowpass_kernel, random_mixed_kernels
from basicsr.utils import DiffJPEG, USMSharp, img2tensor, tensor2img
from basicsr.utils.img_process_util import filter2D
from basicsr.data.degradations import random_add_gaussian_noise_pt, random_add_poisson_noise_pt
from torchvision.transforms.functional import (adjust_brightness, adjust_contrast, adjust_hue, adjust_saturation,
                                               normalize, rgb_to_grayscale)

cur_path = os.path.dirname(os.path.abspath(__file__))


def ordered_yaml():
    """Support OrderedDict for yaml.

    Returns:
        yaml Loader and Dumper.
    """
    try:
        from yaml import CDumper as Dumper
        from yaml import CLoader as Loader
    except ImportError:
        from yaml import Dumper, Loader

    _mapping_tag = yaml.resolver.BaseResolver.DEFAULT_MAPPING_TAG

    def dict_representer(dumper, data):
        return dumper.represent_dict(data.items())

    def dict_constructor(loader, node):
        return OrderedDict(loader.construct_pairs(node))

    Dumper.add_representer(OrderedDict, dict_representer)
    Loader.add_constructor(_mapping_tag, dict_constructor)
    return Loader, Dumper

def opt_parse(opt_path):
    with open(opt_path, mode='r') as f:
        Loader, _ = ordered_yaml()
        opt = yaml.load(f, Loader=Loader)

    return opt

class RealESRGAN_degradation(object):
    def __init__(self, opt_path='', device='cpu'):
        self.opt = opt_parse(opt_path)
        self.device = device #torch.device('cpu')
        optk = self.opt['kernel_info']

        # blur settings for the first degradation
        self.blur_kernel_size = optk['blur_kernel_size']
        self.kernel_list = optk['kernel_list']
        self.kernel_prob = optk['kernel_prob']
        self.blur_sigma = optk['blur_sigma']
        self.betag_range = optk['betag_range']
        self.betap_range = optk['betap_range']
        self.sinc_prob = optk['sinc_prob']

        # blur settings for the second degradation
        self.blur_kernel_size2 = optk['blur_kernel_size2']
        self.kernel_list2 = optk['kernel_list2']
        self.kernel_prob2 = optk['kernel_prob2']
        self.blur_sigma2 = optk['blur_sigma2']
        self.betag_range2 = optk['betag_range2']
        self.betap_range2 = optk['betap_range2']
        self.sinc_prob2 = optk['sinc_prob2']

        # a final sinc filter
        self.final_sinc_prob = optk['final_sinc_prob']

        self.kernel_range = [2 * v + 1 for v in range(3, 11)]  # kernel size ranges from 7 to 21
        self.pulse_tensor = torch.zeros(21, 21).float()  # convolving with pulse tensor brings no blurry effect
        self.pulse_tensor[10, 10] = 1

        self.jpeger = DiffJPEG(differentiable=False).to(self.device)
        self.usm_shaper = USMSharp().to(self.device)

    def color_jitter_pt(self, img, brightness, contrast, saturation, hue):
        fn_idx = torch.randperm(4)
        for fn_id in fn_idx:
            if fn_id == 0 and brightness is not None:
                brightness_factor = torch.tensor(1.0).uniform_(brightness[0], brightness[1]).item()
                img = adjust_brightness(img, brightness_factor)

            if fn_id == 1 and contrast is not None:
                contrast_factor = torch.tensor(1.0).uniform_(contrast[0], contrast[1]).item()
                img = adjust_contrast(img, contrast_factor)

            if fn_id == 2 and saturation is not None:
                saturation_factor = torch.tensor(1.0).uniform_(saturation[0], saturation[1]).item()
                img = adjust_saturation(img, saturation_factor)

            if fn_id == 3 and hue is not None:
                hue_factor = torch.tensor(1.0).uniform_(hue[0], hue[1]).item()
                img = adjust_hue(img, hue_factor)
        return img

    def random_augment(self, img_gt):
        # random horizontal flip
        img_gt, status = augment(img_gt, hflip=True, rotation=False, return_status=True)
        """
        # random color jitter
        if np.random.uniform() < self.opt['color_jitter_prob']:
            jitter_val = np.random.uniform(-shift, shift, 3).astype(np.float32)
            img_gt = img_gt + jitter_val
            img_gt = np.clip(img_gt, 0, 1)

        # random grayscale
        if np.random.uniform() < self.opt['gray_prob']:
            #img_gt = cv2.cvtColor(img_gt, cv2.COLOR_BGR2GRAY)
            img_gt = cv2.cvtColor(img_gt, cv2.COLOR_RGB2GRAY)
            img_gt = np.tile(img_gt[:, :, None], [1, 1, 3])
        """
        # BGR to RGB, HWC to CHW, numpy to tensor
        img_gt = img2tensor([img_gt], bgr2rgb=False, float32=True)[0].unsqueeze(0)

        return img_gt

    def random_kernels(self):
        # ------------------------ Generate kernels (used in the first degradation) ------------------------ #
        kernel_size = random.choice(self.kernel_range)
        if np.random.uniform() < self.sinc_prob:
            # this sinc filter setting is for kernels ranging from [7, 21]
            if kernel_size < 13:
                omega_c = np.random.uniform(np.pi / 3, np.pi)
            else:
                omega_c = np.random.uniform(np.pi / 5, np.pi)
            kernel = circular_lowpass_kernel(omega_c, kernel_size, pad_to=False)
        else:
            kernel = random_mixed_kernels(
                    self.kernel_list,
                    self.kernel_prob,
                    kernel_size,
                    self.blur_sigma,
                    self.blur_sigma, [-math.pi, math.pi],
                    self.betag_range,
                    self.betap_range,
                    noise_range=None)
        # pad kernel
        pad_size = (21 - kernel_size) // 2
        kernel = np.pad(kernel, ((pad_size, pad_size), (pad_size, pad_size)))

        # ------------------------ Generate kernels (used in the second degradation) ------------------------ #
        kernel_size = random.choice(self.kernel_range)
        if np.random.uniform() < self.sinc_prob2:
            if kernel_size < 13:
                omega_c = np.random.uniform(np.pi / 3, np.pi)
            else:
                omega_c = np.random.uniform(np.pi / 5, np.pi)
            kernel2 = circular_lowpass_kernel(omega_c, kernel_size, pad_to=False)
        else:
            kernel2 = random_mixed_kernels(
                self.kernel_list2,
                self.kernel_prob2,
                kernel_size,
                self.blur_sigma2,
                self.blur_sigma2, [-math.pi, math.pi],
                self.betag_range2,
                self.betap_range2,
                noise_range=None)

        # pad kernel
        pad_size = (21 - kernel_size) // 2
        kernel2 = np.pad(kernel2, ((pad_size, pad_size), (pad_size, pad_size)))

        # ------------------------------------- sinc kernel ------------------------------------- #
        if np.random.uniform() < self.final_sinc_prob:
            kernel_size = random.choice(self.kernel_range)
            omega_c = np.random.uniform(np.pi / 3, np.pi)
            sinc_kernel = circular_lowpass_kernel(omega_c, kernel_size, pad_to=21)
            sinc_kernel = torch.FloatTensor(sinc_kernel)
        else:
            sinc_kernel = self.pulse_tensor

        kernel = torch.FloatTensor(kernel)
        kernel2 = torch.FloatTensor(kernel2)

        return kernel, kernel2, sinc_kernel

    @torch.no_grad()
    def degrade_process(self, img_gt, resize_bak=False):
        img_gt = self.random_augment(img_gt)
        kernel1, kernel2, sinc_kernel = self.random_kernels()
        img_gt, kernel1, kernel2, sinc_kernel = img_gt.to(self.device), kernel1.to(self.device), kernel2.to(self.device), sinc_kernel.to(self.device)
        #img_gt = self.usm_shaper(img_gt) # shaper gt
        ori_h, ori_w = img_gt.size()[2:4]

        #scale_final = random.randint(4, 16)
        scale_final = 4

        # ----------------------- The first degradation process ----------------------- #
        # blur
        out = filter2D(img_gt, kernel1)
        # random resize
        updown_type = random.choices(['up', 'down', 'keep'], self.opt['resize_prob'])[0]
        if updown_type == 'up':
            scale = np.random.uniform(1, self.opt['resize_range'][1])
        elif updown_type == 'down':
            scale = np.random.uniform(self.opt['resize_range'][0], 1)
        else:
            scale = 1
        mode = random.choice(['area', 'bilinear', 'bicubic'])
        out = F.interpolate(out, scale_factor=scale, mode=mode)
        # noise
        gray_noise_prob = self.opt['gray_noise_prob']
        if np.random.uniform() < self.opt['gaussian_noise_prob']:
            out = random_add_gaussian_noise_pt(
                out, sigma_range=self.opt['noise_range'], clip=True, rounds=False, gray_prob=gray_noise_prob)
        else:
            out = random_add_poisson_noise_pt(
                out,
                scale_range=self.opt['poisson_scale_range'],
                gray_prob=gray_noise_prob,
                clip=True,
                rounds=False)
        # JPEG compression
        jpeg_p = out.new_zeros(out.size(0)).uniform_(*self.opt['jpeg_range'])
        out = torch.clamp(out, 0, 1)
        out = self.jpeger(out, quality=jpeg_p)

        # ----------------------- The second degradation process ----------------------- #
        # blur
        if np.random.uniform() < self.opt['second_blur_prob']:
            out = filter2D(out, kernel2)
        # random resize
        updown_type = random.choices(['up', 'down', 'keep'], self.opt['resize_prob2'])[0]
        if updown_type == 'up':
            scale = np.random.uniform(1, self.opt['resize_range2'][1])
        elif updown_type == 'down':
            scale = np.random.uniform(self.opt['resize_range2'][0], 1)
        else:
            scale = 1
        mode = random.choice(['area', 'bilinear', 'bicubic'])
        out = F.interpolate(
            out, size=(int(ori_h / scale_final * scale), int(ori_w / scale_final * scale)), mode=mode)
        # noise
        gray_noise_prob = self.opt['gray_noise_prob2']
        if np.random.uniform() < self.opt['gaussian_noise_prob2']:
            out = random_add_gaussian_noise_pt(
                out, sigma_range=self.opt['noise_range2'], clip=True, rounds=False, gray_prob=gray_noise_prob)
        else:
            out = random_add_poisson_noise_pt(
                out,
                scale_range=self.opt['poisson_scale_range2'],
                gray_prob=gray_noise_prob,
                clip=True,
                rounds=False)

        # JPEG compression + the final sinc filter
        # We also need to resize images to desired sizes. We group [resize back + sinc filter] together
        # as one operation.
        # We consider two orders:
        #   1. [resize back + sinc filter] + JPEG compression
        #   2. JPEG compression + [resize back + sinc filter]
        # Empirically, we find other combinations (sinc + JPEG + Resize) will introduce twisted lines.
        if np.random.uniform() < 0.5:
            # resize back + the final sinc filter
            mode = random.choice(['area', 'bilinear', 'bicubic'])
            out = F.interpolate(out, size=(ori_h // scale_final, ori_w // scale_final), mode=mode)
            out = filter2D(out, sinc_kernel)
            # JPEG compression
            jpeg_p = out.new_zeros(out.size(0)).uniform_(*self.opt['jpeg_range2'])
            out = torch.clamp(out, 0, 1)
            out = self.jpeger(out, quality=jpeg_p)
        else:
            # JPEG compression
            jpeg_p = out.new_zeros(out.size(0)).uniform_(*self.opt['jpeg_range2'])
            out = torch.clamp(out, 0, 1)
            out = self.jpeger(out, quality=jpeg_p)
            # resize back + the final sinc filter
            mode = random.choice(['area', 'bilinear', 'bicubic'])
            out = F.interpolate(out, size=(ori_h // scale_final, ori_w // scale_final), mode=mode)
            out = filter2D(out, sinc_kernel)

        if np.random.uniform() < self.opt['gray_prob']:
            out = rgb_to_grayscale(out, num_output_channels=1)

        if np.random.uniform() < self.opt['color_jitter_prob']:
            brightness = self.opt.get('brightness', (0.5, 1.5))
            contrast = self.opt.get('contrast', (0.5, 1.5))
            saturation = self.opt.get('saturation', (0, 1.5))
            hue = self.opt.get('hue', (-0.1, 0.1))
            out = self.color_jitter_pt(out, brightness, contrast, saturation, hue)

        if resize_bak:
            mode = random.choice(['area', 'bilinear', 'bicubic'])
            out = F.interpolate(out, size=(ori_h, ori_w), mode=mode)
        # clamp and round
        img_lq = torch.clamp((out * 255.0).round(), 0, 255) / 255.

        return img_gt, img_lq
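degrade_process takes an HWC float32 RGB image in [0, 1], runs it through blur -> random resize -> noise -> JPEG twice, applies the final sinc filter, and returns the GT tensor together with an LQ tensor downscaled by scale_final = 4 (or resized back to the original size when resize_bak=True). A minimal driver sketch, assuming basicsr is installed and the YAML above sits at the path shown:

# Driver sketch (assumption: run from the repo root with basicsr available).
import numpy as np
from dataloaders.realesrgan import RealESRGAN_degradation

degrader = RealESRGAN_degradation("dataloaders/params_realesrgan.yml", device="cpu")

# random_augment expects an HWC float32 image in [0, 1] (RGB order, bgr2rgb=False).
img_gt = np.random.rand(512, 512, 3).astype(np.float32)

gt, lq = degrader.degrade_process(img_gt, resize_bak=False)
print(gt.shape, lq.shape)  # (1, 3, 512, 512) and (1, 3, 128, 128) with scale_final = 4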
    	
        dataloaders/simple_dataset.py
    ADDED
    
    | 
         @@ -0,0 +1,156 @@ 
import cv2
import os
import glob
import torch
from torch.utils.data import Dataset
from torchvision import transforms
import random
import numpy as np
import math

from basicsr.data.degradations import circular_lowpass_kernel, random_mixed_kernels
from basicsr.data.transforms import augment
from basicsr.utils import FileClient, get_root_logger, imfrombytes, img2tensor

from PIL import Image


class SimpleDataset(Dataset):
    def __init__(self, opt, fix_size=512):

        self.opt = opt
        self.image_root = opt['gt_path']  # expected to be a list of GT image directories
        self.fix_size = fix_size
        exts = ['*.jpg', '*.png']
        self.image_list = []
        for image_root in self.image_root:
            for ext in exts:
                image_list = glob.glob(os.path.join(image_root, ext))
                self.image_list += image_list
                # also pick up images in LSDIR-style subfolders matching '00*'
                image_list = glob.glob(os.path.join(image_root, '00*', ext))
                self.image_list += image_list

        self.crop_preproc = transforms.Compose([
            # transforms.CenterCrop(fix_size),
            transforms.Resize(fix_size)
            # transforms.RandomHorizontalFlip(),
        ])

        self.img_preproc = transforms.Compose([
            transforms.ToTensor(),
        ])

        # blur settings for the first degradation
        self.blur_kernel_size = opt['blur_kernel_size']
        self.kernel_list = opt['kernel_list']
        self.kernel_prob = opt['kernel_prob']  # a list of probabilities, one per kernel type
        self.blur_sigma = opt['blur_sigma']
        self.betag_range = opt['betag_range']  # betag used in generalized Gaussian blur kernels
        self.betap_range = opt['betap_range']  # betap used in plateau blur kernels
        self.sinc_prob = opt['sinc_prob']  # the probability of using sinc filters

        # blur settings for the second degradation
        self.blur_kernel_size2 = opt['blur_kernel_size2']
        self.kernel_list2 = opt['kernel_list2']
        self.kernel_prob2 = opt['kernel_prob2']
        self.blur_sigma2 = opt['blur_sigma2']
        self.betag_range2 = opt['betag_range2']
        self.betap_range2 = opt['betap_range2']
        self.sinc_prob2 = opt['sinc_prob2']

        # a final sinc filter
        self.final_sinc_prob = opt['final_sinc_prob']

        self.kernel_range = [2 * v + 1 for v in range(3, 11)]  # kernel sizes range from 7 to 21
        # TODO: the kernel range is hard-coded; it should come from the config file
        self.pulse_tensor = torch.zeros(21, 21).float()  # convolving with a pulse (identity) kernel adds no blur
        self.pulse_tensor[10, 10] = 1

        print(f'The dataset length: {len(self.image_list)}')

    def __getitem__(self, index):
        image = Image.open(self.image_list[index]).convert('RGB')
        # width, height = image.size
        # if width > height:
        #     width_after = self.fix_size
        #     height_after = int(height * width_after / width)
        # elif height > width:
        #     height_after = self.fix_size
        #     width_after = int(width * height_after / height)
        # elif height == width:
        #     height_after = self.fix_size
        #     width_after = self.fix_size
        image = image.resize((self.fix_size, self.fix_size), Image.LANCZOS)
        # image = self.crop_preproc(image)
        image = self.img_preproc(image)

        # ------------------------ Generate kernels (used in the first degradation) ------------------------ #
        kernel_size = random.choice(self.kernel_range)
        if np.random.uniform() < self.opt['sinc_prob']:
            # this sinc filter setting is for kernel sizes in [7, 21]
            if kernel_size < 13:
                omega_c = np.random.uniform(np.pi / 3, np.pi)
            else:
                omega_c = np.random.uniform(np.pi / 5, np.pi)
            kernel = circular_lowpass_kernel(omega_c, kernel_size, pad_to=False)
        else:
            kernel = random_mixed_kernels(
                self.kernel_list,
                self.kernel_prob,
                kernel_size,
                self.blur_sigma,
                self.blur_sigma, [-math.pi, math.pi],
                self.betag_range,
                self.betap_range,
                noise_range=None)
        # pad kernel to 21 x 21
        pad_size = (21 - kernel_size) // 2
        kernel = np.pad(kernel, ((pad_size, pad_size), (pad_size, pad_size)))

        # ------------------------ Generate kernels (used in the second degradation) ------------------------ #
        kernel_size = random.choice(self.kernel_range)
        if np.random.uniform() < self.opt['sinc_prob2']:
            if kernel_size < 13:
                omega_c = np.random.uniform(np.pi / 3, np.pi)
            else:
                omega_c = np.random.uniform(np.pi / 5, np.pi)
            kernel2 = circular_lowpass_kernel(omega_c, kernel_size, pad_to=False)
        else:
            kernel2 = random_mixed_kernels(
                self.kernel_list2,
                self.kernel_prob2,
                kernel_size,
                self.blur_sigma2,
                self.blur_sigma2, [-math.pi, math.pi],
                self.betag_range2,
                self.betap_range2,
                noise_range=None)

        # pad kernel to 21 x 21
        pad_size = (21 - kernel_size) // 2
        kernel2 = np.pad(kernel2, ((pad_size, pad_size), (pad_size, pad_size)))

        # ------------------------------------- the final sinc kernel ------------------------------------- #
        if np.random.uniform() < self.opt['final_sinc_prob']:
            kernel_size = random.choice(self.kernel_range)
            omega_c = np.random.uniform(np.pi / 3, np.pi)
            sinc_kernel = circular_lowpass_kernel(omega_c, kernel_size, pad_to=21)
            sinc_kernel = torch.FloatTensor(sinc_kernel)
        else:
            sinc_kernel = self.pulse_tensor

        # BGR to RGB, HWC to CHW, numpy to tensor
        # img_gt = img2tensor([img_gt], bgr2rgb=True, float32=True)[0]
        kernel = torch.FloatTensor(kernel)
        kernel2 = torch.FloatTensor(kernel2)

        return_d = {'gt': image, 'kernel1': kernel, 'kernel2': kernel2, 'sinc_kernel': sinc_kernel, 'lq_path': self.image_list[index]}
        return return_d

    def __len__(self):
        return len(self.image_list)
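For context, a minimal sketch of how SimpleDataset might be wired up. The option keys mirror what __init__ reads above, but the values below are illustrative assumptions (the actual settings live in dataloaders/params_realesrgan.yml), and the gt_path is a placeholder:

from torch.utils.data import DataLoader
# from dataloaders.simple_dataset import SimpleDataset

# illustrative options only; see dataloaders/params_realesrgan.yml for the real values
opt = {
    'gt_path': ['/path/to/gt_images'],  # list of GT image folders (placeholder path)
    'blur_kernel_size': 21, 'kernel_list': ['iso', 'aniso'], 'kernel_prob': [0.5, 0.5],
    'blur_sigma': [0.2, 3.0], 'betag_range': [0.5, 4.0], 'betap_range': [1, 2], 'sinc_prob': 0.1,
    'blur_kernel_size2': 21, 'kernel_list2': ['iso', 'aniso'], 'kernel_prob2': [0.5, 0.5],
    'blur_sigma2': [0.2, 1.5], 'betag_range2': [0.5, 4.0], 'betap_range2': [1, 2], 'sinc_prob2': 0.1,
    'final_sinc_prob': 0.8,
}

dataset = SimpleDataset(opt, fix_size=512)
loader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=2)

batch = next(iter(loader))
# batch['gt'] is a (B, 3, 512, 512) tensor; 'kernel1', 'kernel2', and 'sinc_kernel' are (B, 21, 21)
# tensors that feed the Real-ESRGAN-style on-the-fly degradation defined in dataloaders/realesrgan.py.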
    	
figs/bird1.png ADDED (stored with Git LFS)
figs/building.png ADDED (stored with Git LFS)
figs/data_real.png ADDED (stored with Git LFS)
figs/data_real_sup.jpg ADDED (stored with Git LFS)
figs/data_real_suppl.jpg ADDED (stored with Git LFS)
figs/data_real_suppl.png ADDED (stored with Git LFS)
figs/data_syn.png ADDED (stored with Git LFS)
figs/figs.md ADDED
@@ -0,0 +1 @@
figs/framework.png ADDED (stored with Git LFS)
figs/gradio.png ADDED
figs/ground.jpg ADDED
figs/logo1.png ADDED
figs/nature.png ADDED (stored with Git LFS)
figs/person1.png ADDED (stored with Git LFS)
figs/turbo_steps02_building.png ADDED (stored with Git LFS)
figs/turbo_steps02_frog.png ADDED (stored with Git LFS)
figs/turbo_steps04_building.png ADDED (stored with Git LFS)
figs/turbo_steps04_frog.png ADDED (stored with Git LFS)