Commit 64bf706 · 1 Parent(s): 3aaab28
Add model and infrance app

Files changed:
- VARtext_v1.pth +3 -0
- app.py +236 -4
- dist.py +211 -0
- models/__init__.py +39 -0
- models/basic_vae.py +226 -0
- models/basic_var.py +174 -0
- models/helpers.py +59 -0
- models/quant.py +281 -0
- models/var.py +360 -0
- models/vqvae.py +95 -0
- utils/amp_sc.py +89 -0
- utils/arg_util.py +284 -0
- utils/data.py +54 -0
- utils/data_sampler.py +103 -0
- utils/lr_control.py +108 -0
- utils/misc.py +381 -0
    	
VARtext_v1.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:bbaa03cee25cb0abba7ac5d476f6b800b78dda29c6cb2773a11b584022585fcf
size 1963751390
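The checkpoint above is stored through Git LFS: the pointer only records an oid and a size of roughly 1.96 GB, so a plain clone without LFS yields just this stub. Below is a minimal, hedged sketch of fetching the real weights with huggingface_hub; the repo id is a placeholder, and repo_type may need to be "space" depending on where the file is hosted.

# Hypothetical sketch: resolve the LFS pointer to the actual ~1.96 GB checkpoint.
from huggingface_hub import hf_hub_download

ckpt_path = hf_hub_download(
    repo_id="user/VARtext",        # placeholder; use the actual repo id
    filename="VARtext_v1.pth",
    # repo_type="space",           # uncomment if the file lives in a Space repo
)
print(ckpt_path)                   # local cache path to pass to app.py as the checkpoint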
    	
app.py
CHANGED
@@ -1,7 +1,239 @@
import torch
from models import VQVAE, build_vae_var
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoTokenizer, SiglipTextModel
from peft import LoraConfig, get_peft_model
import random
from torchvision.transforms import ToPILImage
import numpy as np
from moviepy.editor import ImageSequenceClip
import random
import gradio as gr
import tempfile
import os

class SimpleAdapter(nn.Module):
    def __init__(self, input_dim=512, hidden_dim=1024, out_dim=1024):
        super(SimpleAdapter, self).__init__()
        self.layer1 = nn.Linear(input_dim, hidden_dim)
        self.norm0 = nn.LayerNorm(input_dim)
        self.activation1 = nn.GELU()
        self.layer2 = nn.Linear(hidden_dim, out_dim)
        self.norm2 = nn.LayerNorm(out_dim)
        self._initialize_weights()

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight, gain=0.001)
                nn.init.zeros_(m.bias)
            elif isinstance(m, nn.LayerNorm):
                nn.init.ones_(m.weight)
                nn.init.zeros_(m.bias)

    def forward(self, x):
        x = self.norm0(x)
        x = self.layer1(x)
        x = self.activation1(x)
        x = self.layer2(x)
        x = self.norm2(x)
        return x

class InrenceTextVAR(nn.Module):
    def __init__(self, pl_checkpoint=None, start_class_id=578, hugging_face_token=None, siglip_model='google/siglip-base-patch16-224', device="cpu", MODEL_DEPTH=16):
        super(InrenceTextVAR, self).__init__()
        self.device = device
        self.class_id = start_class_id
        # Define layers
        patch_nums = (1, 2, 3, 4, 5, 6, 8, 10, 13, 16)
        self.vae, self.var = build_vae_var(
            V=4096, Cvae=32, ch=160, share_quant_resi=4,
            device=device, patch_nums=patch_nums,
            num_classes=1000, depth=MODEL_DEPTH, shared_aln=False,
        )
        self.text_processor = AutoTokenizer.from_pretrained(siglip_model, token=hugging_face_token)
        self.siglip_text_encoder = SiglipTextModel.from_pretrained(siglip_model, token=hugging_face_token).to(device)
        self.adapter = SimpleAdapter(
            input_dim=self.siglip_text_encoder.config.hidden_size,
            out_dim=self.var.C  # Ensure dimensional consistency
        ).to(device)
        self.apply_lora_to_var()
        if pl_checkpoint is not None:
            state_dict = torch.load(pl_checkpoint, map_location="cpu")['state_dict']
            var_state_dict = {k[len('var.'):]: v for k, v in state_dict.items() if k.startswith('var.')}
            vae_state_dict = {k[len('vae.'):]: v for k, v in state_dict.items() if k.startswith('vae.')}
            adapter_state_dict = {k[len('adapter.'):]: v for k, v in state_dict.items() if k.startswith('adapter.')}
            self.var.load_state_dict(var_state_dict)
            self.vae.load_state_dict(vae_state_dict)
            self.adapter.load_state_dict(adapter_state_dict)
        del self.vae.encoder

    def apply_lora_to_var(self):
        """
        Applies LoRA (Low-Rank Adaptation) to the VAR model.
        """
        def find_linear_module_names(model):
            linear_module_names = []
            for name, module in model.named_modules():
                if isinstance(module, nn.Linear):
                    linear_module_names.append(name)
            return linear_module_names

        linear_module_names = find_linear_module_names(self.var)

        lora_config = LoraConfig(
            r=8,
            lora_alpha=32,
            target_modules=linear_module_names,
            lora_dropout=0.05,
            bias="none",
        )

        self.var = get_peft_model(self.var, lora_config)

    @torch.no_grad()
    def generate_image(self, text, beta=1, seed=None, more_smooth=False, top_k=0, top_p=0.9):
        if seed is None:
            seed = random.randint(0, 2**32 - 1)
        inputs = self.text_processor([text], padding="max_length", return_tensors="pt").to(self.device)
        outputs = self.siglip_text_encoder(**inputs)
        pooled_output = outputs.pooler_output  # pooled (EOS token) states
        pooled_output = F.normalize(pooled_output, p=2, dim=-1)  # Normalize delta condition
        cond_delta = F.normalize(pooled_output, p=2, dim=-1).to(self.device)  # Use correct device
        cond_delta = self.adapter(cond_delta)
        cond_delta = F.normalize(cond_delta, p=2, dim=-1)  # Normalize delta condition
        generated_images = self.var.autoregressive_infer_cfg(
            B=1,
            label_B=self.class_id,
            delta_condition=cond_delta[:1],
            beta=beta,
            alpha=1,
            top_k=top_k,
            top_p=top_p,
            more_smooth=more_smooth,
            g_seed=seed
        )
        image = ToPILImage()(generated_images[0].cpu())
        return image

    @torch.no_grad()
    def generate_video(self, text, start_beta, target_beta, fps, length, top_k=0, top_p=0.9, seed=None,
                       more_smooth=False,
                       output_filename='output_video.mp4'):

        if seed is None:
            seed = random.randint(0, 2 ** 32 - 1)

        num_frames = int(fps * length)
        images = []

        # Define an easing function for smoother interpolation
        def ease_in_out(t):
            return t * t * (3 - 2 * t)

        # Generate t values between 0 and 1
        t_values = np.linspace(0, 1, num_frames)
        # Apply the easing function
        eased_t_values = ease_in_out(t_values)
        # Interpolate beta values using the eased t values
        beta_values = start_beta + (target_beta - start_beta) * eased_t_values

        for beta in beta_values:
            image = self.generate_image(text, beta=beta, seed=seed, more_smooth=more_smooth, top_k=top_k, top_p=top_p)
            images.append(np.array(image))

        # Create a video from images
        clip = ImageSequenceClip(images, fps=fps)
        clip.write_videofile(output_filename, codec='libx264')

if __name__ == '__main__':

    # Initialize the model
    checkpoint = 'VARtext_v1.pth'  # Replace with your actual checkpoint path
    device = 'cpu' if not torch.cuda.is_available() else 'cuda'
    state_dict = torch.load(checkpoint, map_location="cpu")
    model = InrenceTextVAR(device=device)
    model.load_state_dict(state_dict)
    model.to(device)

    def generate_image_gradio(text, beta=1.0, seed=None, more_smooth=False, top_k=0, top_p=0.9):
        print(f"Generating image for text: {text}\n"
              f"beta: {beta}\n"
              f"seed: {seed}\n"
              f"more_smooth: {more_smooth}\n"
              f"top_k: {top_k}\n"
              f"top_p: {top_p}\n")
        image = model.generate_image(text, beta=beta, seed=seed, more_smooth=more_smooth, top_k=int(top_k), top_p=top_p)
        return image

    def generate_video_gradio(text, start_beta=1.0, target_beta=1.0, fps=10, length=5.0, top_k=0, top_p=0.9, seed=None, more_smooth=False, progress=gr.Progress()):
        print(f"Generating video for text: {text}\n"
              f"start_beta: {start_beta}\n"
              f"target_beta: {target_beta}\n"
              f"seed: {seed}\n"
              f"more_smooth: {more_smooth}\n"
              f"top_k: {top_k}\n"
              f"top_p: {top_p}"
              f"fps: {fps}\n"
              f"length: {length}\n")
        with tempfile.NamedTemporaryFile(suffix='.mp4', delete=False) as tmpfile:
            output_filename = tmpfile.name
        num_frames = int(fps * length)
        beta_values = np.linspace(start_beta, target_beta, num_frames)
        images = []

        for i, beta in enumerate(beta_values):
            image = model.generate_image(text, beta=beta, seed=seed, more_smooth=more_smooth, top_k=top_k, top_p=top_p)
            images.append(np.array(image))
            # Update progress
            progress((i + 1) / num_frames)
            # Yield the frame image to update the GUI
            yield image, gr.update()

        # After generating all frames, create the video
        clip = ImageSequenceClip(images, fps=fps)
        clip.write_videofile(output_filename, codec='libx264')

        # Yield the final video output
        yield gr.update(), output_filename

    with gr.Blocks() as demo:
        gr.Markdown("# Text to Image/Video Generator")
        with gr.Tab("Generate Image"):
            text_input = gr.Textbox(label="Input Text")
            beta_input = gr.Slider(label="Beta", minimum=0.0, maximum=2.5, step=0.05, value=1.0)
            seed_input = gr.Number(label="Seed", value=None)
            more_smooth_input = gr.Checkbox(label="More Smooth", value=False)
            top_k_input = gr.Number(label="Top K", value=0)
            top_p_input = gr.Slider(label="Top P", minimum=0.0, maximum=1.0, step=0.01, value=0.9)
            generate_button = gr.Button("Generate Image")
            image_output = gr.Image(label="Generated Image")
            generate_button.click(
                generate_image_gradio,
                inputs=[text_input, beta_input, seed_input, more_smooth_input, top_k_input, top_p_input],
                outputs=image_output
            )

        with gr.Tab("Generate Video"):
            text_input_video = gr.Textbox(label="Input Text")
            start_beta_input = gr.Slider(label="Start Beta", minimum=0.0, maximum=2.5, step=0.05, value=0)
            target_beta_input = gr.Slider(label="Target Beta", minimum=0.0, maximum=2.5, step=0.05, value=1.0)
            fps_input = gr.Number(label="FPS", value=10)
            length_input = gr.Number(label="Length (seconds)", value=5.0)
            seed_input_video = gr.Number(label="Seed", value=None)
            more_smooth_input_video = gr.Checkbox(label="More Smooth", value=False)
            top_k_input_video = gr.Number(label="Top K", value=0)
            top_p_input_video = gr.Slider(label="Top P", minimum=0.0, maximum=1.0, step=0.01, value=0.9)
            generate_video_button = gr.Button("Generate Video")
            frame_output = gr.Image(label="Current Frame")
            video_output = gr.Video(label="Generated Video")

            generate_video_button.click(
                generate_video_gradio,
                inputs=[text_input_video, start_beta_input, target_beta_input, fps_input, length_input, top_k_input_video, top_p_input_video, seed_input_video, more_smooth_input_video],
                outputs=[frame_output, video_output],
                queue=True  # Enable queuing to allow for progress updates
            )

    demo.launch()
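app.py wires the model into a Gradio UI, but the same two entry points can be driven from a script. A short sketch mirroring the __main__ block above; it assumes, as app.py does, that VARtext_v1.pth holds the full module state_dict, and the prompt string is only an example.

# Sketch: scripted use of the inference model, without the Gradio UI.
import torch
from app import InrenceTextVAR

device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = InrenceTextVAR(device=device)
model.load_state_dict(torch.load('VARtext_v1.pth', map_location='cpu'))
model.to(device)

# beta scales the text (delta) condition; seed fixes the sampling noise.
image = model.generate_image("a snowy mountain at sunrise", beta=1.0, seed=0, top_k=0, top_p=0.9)
image.save("sample.png")

# Sweep beta from 0 to 1 over a 3-second, 10 fps clip.
model.generate_video("a snowy mountain at sunrise", start_beta=0.0, target_beta=1.0,
                     fps=10, length=3, output_filename="sample.mp4")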
    	
dist.py
ADDED
@@ -0,0 +1,211 @@
import datetime
import functools
import os
import sys
from typing import List
from typing import Union

import torch
import torch.distributed as tdist
import torch.multiprocessing as mp

__rank, __local_rank, __world_size, __device = 0, 0, 1, 'cuda' if torch.cuda.is_available() else 'cpu'
__initialized = False


def initialized():
    return __initialized


def initialize(fork=False, backend='nccl', gpu_id_if_not_distibuted=0, timeout=30):
    global __device
    if not torch.cuda.is_available():
        print(f'[dist initialize] cuda is not available, use cpu instead', file=sys.stderr)
        return
    elif 'RANK' not in os.environ:
        torch.cuda.set_device(gpu_id_if_not_distibuted)
        __device = torch.empty(1).cuda().device
        print(f'[dist initialize] env variable "RANK" is not set, use {__device} as the device', file=sys.stderr)
        return
    # then 'RANK' must exist
    global_rank, num_gpus = int(os.environ['RANK']), torch.cuda.device_count()
    local_rank = global_rank % num_gpus
    torch.cuda.set_device(local_rank)

    # ref: https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/dist_utils.py#L29
    if mp.get_start_method(allow_none=True) is None:
        method = 'fork' if fork else 'spawn'
        print(f'[dist initialize] mp method={method}')
        mp.set_start_method(method)
    tdist.init_process_group(backend=backend, timeout=datetime.timedelta(seconds=timeout*60))

    global __rank, __local_rank, __world_size, __initialized
    __local_rank = local_rank
    __rank, __world_size = tdist.get_rank(), tdist.get_world_size()
    __device = torch.empty(1).cuda().device
    __initialized = True

    assert tdist.is_initialized(), 'torch.distributed is not initialized!'
    print(f'[lrk={get_local_rank()}, rk={get_rank()}]')


def get_rank():
    return __rank


def get_local_rank():
    return __local_rank


def get_world_size():
    return __world_size


def get_device():
    return __device


def set_gpu_id(gpu_id: int):
    if gpu_id is None: return
    global __device
    if isinstance(gpu_id, (str, int)):
        torch.cuda.set_device(int(gpu_id))
        __device = torch.empty(1).cuda().device
    else:
        raise NotImplementedError


def is_master():
    return __rank == 0


def is_local_master():
    return __local_rank == 0


def new_group(ranks: List[int]):
    if __initialized:
        return tdist.new_group(ranks=ranks)
    return None


def barrier():
    if __initialized:
        tdist.barrier()


def allreduce(t: torch.Tensor, async_op=False):
    if __initialized:
        if not t.is_cuda:
            cu = t.detach().cuda()
            ret = tdist.all_reduce(cu, async_op=async_op)
            t.copy_(cu.cpu())
        else:
            ret = tdist.all_reduce(t, async_op=async_op)
        return ret
    return None


def allgather(t: torch.Tensor, cat=True) -> Union[List[torch.Tensor], torch.Tensor]:
    if __initialized:
        if not t.is_cuda:
            t = t.cuda()
        ls = [torch.empty_like(t) for _ in range(__world_size)]
        tdist.all_gather(ls, t)
    else:
        ls = [t]
    if cat:
        ls = torch.cat(ls, dim=0)
    return ls


def allgather_diff_shape(t: torch.Tensor, cat=True) -> Union[List[torch.Tensor], torch.Tensor]:
    if __initialized:
        if not t.is_cuda:
            t = t.cuda()

        t_size = torch.tensor(t.size(), device=t.device)
        ls_size = [torch.empty_like(t_size) for _ in range(__world_size)]
        tdist.all_gather(ls_size, t_size)

        max_B = max(size[0].item() for size in ls_size)
        pad = max_B - t_size[0].item()
        if pad:
            pad_size = (pad, *t.size()[1:])
            t = torch.cat((t, t.new_empty(pad_size)), dim=0)

        ls_padded = [torch.empty_like(t) for _ in range(__world_size)]
        tdist.all_gather(ls_padded, t)
        ls = []
        for t, size in zip(ls_padded, ls_size):
            ls.append(t[:size[0].item()])
    else:
        ls = [t]
    if cat:
        ls = torch.cat(ls, dim=0)
    return ls


def broadcast(t: torch.Tensor, src_rank) -> None:
    if __initialized:
        if not t.is_cuda:
            cu = t.detach().cuda()
            tdist.broadcast(cu, src=src_rank)
            t.copy_(cu.cpu())
        else:
            tdist.broadcast(t, src=src_rank)


def dist_fmt_vals(val: float, fmt: Union[str, None] = '%.2f') -> Union[torch.Tensor, List]:
    if not initialized():
        return torch.tensor([val]) if fmt is None else [fmt % val]

    ts = torch.zeros(__world_size)
    ts[__rank] = val
    allreduce(ts)
    if fmt is None:
        return ts
    return [fmt % v for v in ts.cpu().numpy().tolist()]


def master_only(func):
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        force = kwargs.pop('force', False)
        if force or is_master():
            ret = func(*args, **kwargs)
        else:
            ret = None
        barrier()
        return ret
    return wrapper


def local_master_only(func):
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        force = kwargs.pop('force', False)
        if force or is_local_master():
            ret = func(*args, **kwargs)
        else:
            ret = None
        barrier()
        return ret
    return wrapper


def for_visualize(func):
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        if is_master():
            # with torch.no_grad():
            ret = func(*args, **kwargs)
        else:
            ret = None
        return ret
    return wrapper


def finalize():
    if __initialized:
        tdist.destroy_process_group()
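dist.py keeps all torch.distributed state behind module-level globals, so the rest of the code can call the same helpers whether or not a launcher is used. A hedged sketch of the intended call pattern, assuming a torchrun launch (which sets the RANK variable that initialize() checks); with RANK unset or no CUDA, every call quietly degrades to single-process behaviour.

# Sketch: typical use of dist.py under `torchrun --nproc_per_node=N your_script.py`.
import dist

dist.initialize(backend='nccl', timeout=30)   # falls back to single-GPU/CPU if RANK is unset

if dist.is_master():
    print(f'world size: {dist.get_world_size()}')

# Collect a per-rank scalar (e.g. a loss) as one formatted string per rank.
print(dist.dist_fmt_vals(0.123, fmt='%.3f'))

dist.barrier()
dist.finalize()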
    	
models/__init__.py
ADDED
@@ -0,0 +1,39 @@
from typing import Tuple
import torch.nn as nn

from .quant import VectorQuantizer2
from .var import VAR
from .vqvae import VQVAE


def build_vae_var(
    # Shared args
    device, patch_nums=(1, 2, 3, 4, 5, 6, 8, 10, 13, 16),   # 10 steps by default
    # VQVAE args
    V=4096, Cvae=32, ch=160, share_quant_resi=4,
    # VAR args
    num_classes=1000, depth=16, shared_aln=False, attn_l2_norm=True,
    flash_if_available=True, fused_if_available=True,
    init_adaln=0.5, init_adaln_gamma=1e-5, init_head=0.02, init_std=-1,    # init_std < 0: automated
) -> Tuple[VQVAE, VAR]:
    heads = depth
    width = depth * 64
    dpr = 0.1 * depth/24

    # disable built-in initialization for speed
    for clz in (nn.Linear, nn.LayerNorm, nn.BatchNorm2d, nn.SyncBatchNorm, nn.Conv1d, nn.Conv2d, nn.ConvTranspose1d, nn.ConvTranspose2d):
        setattr(clz, 'reset_parameters', lambda self: None)

    # build models
    vae_local = VQVAE(vocab_size=V, z_channels=Cvae, ch=ch, test_mode=True, share_quant_resi=share_quant_resi, v_patch_nums=patch_nums).to(device)
    var_wo_ddp = VAR(
        vae_local=vae_local,
        num_classes=num_classes, depth=depth, embed_dim=width, num_heads=heads, drop_rate=0., attn_drop_rate=0., drop_path_rate=dpr,
        norm_eps=1e-6, shared_aln=shared_aln, cond_drop_rate=0.1,
        attn_l2_norm=attn_l2_norm,
        patch_nums=patch_nums,
        flash_if_available=flash_if_available, fused_if_available=fused_if_available,
    ).to(device)
    var_wo_ddp.init_weights(init_adaln=init_adaln, init_adaln_gamma=init_adaln_gamma, init_head=init_head, init_std=init_std)

    return vae_local, var_wo_ddp
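build_vae_var derives the transformer width and head count from depth (width = 64 * depth, heads = depth), so the depth=16 default gives a 1024-wide, 16-head VAR; app.py sizes its adapter output to self.var.C accordingly. A small sketch, using the same arguments app.py passes, to instantiate the pair and check parameter counts:

# Sketch: build the VQVAE/VAR pair exactly as app.py does (depth=16 -> embed_dim=1024, 16 heads).
import torch
from models import build_vae_var

device = 'cuda' if torch.cuda.is_available() else 'cpu'
vae, var = build_vae_var(
    V=4096, Cvae=32, ch=160, share_quant_resi=4,
    device=device, patch_nums=(1, 2, 3, 4, 5, 6, 8, 10, 13, 16),
    num_classes=1000, depth=16, shared_aln=False,
)
print(f'VQVAE params: {sum(p.numel() for p in vae.parameters()) / 1e6:.1f}M')
print(f'VAR params:   {sum(p.numel() for p in var.parameters()) / 1e6:.1f}M')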
    	
        models/basic_vae.py
    ADDED
    
    | @@ -0,0 +1,226 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            import torch
         | 
| 2 | 
            +
            import torch.nn as nn
         | 
| 3 | 
            +
            import torch.nn.functional as F
         | 
| 4 | 
            +
             | 
| 5 | 
            +
             | 
| 6 | 
            +
            # this file only provides the 2 modules used in VQVAE
         | 
| 7 | 
            +
            __all__ = ['Encoder', 'Decoder',]
         | 
| 8 | 
            +
             | 
| 9 | 
            +
             | 
| 10 | 
            +
            """
         | 
| 11 | 
            +
            References: https://github.com/CompVis/stable-diffusion/blob/21f890f9da3cfbeaba8e2ac3c425ee9e998d5229/ldm/modules/diffusionmodules/model.py
         | 
| 12 | 
            +
            """
         | 
| 13 | 
            +
            # swish
         | 
| 14 | 
            +
            def nonlinearity(x):
         | 
| 15 | 
            +
                return x * torch.sigmoid(x)
         | 
| 16 | 
            +
             | 
| 17 | 
            +
             | 
| 18 | 
            +
            def Normalize(in_channels, num_groups=32):
         | 
| 19 | 
            +
                return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True)
         | 
| 20 | 
            +
             | 
| 21 | 
            +
             | 
| 22 | 
            +
            class Upsample2x(nn.Module):
         | 
| 23 | 
            +
                def __init__(self, in_channels):
         | 
| 24 | 
            +
                    super().__init__()
         | 
| 25 | 
            +
                    self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
         | 
| 26 | 
            +
                
         | 
| 27 | 
            +
                def forward(self, x):
         | 
| 28 | 
            +
                    return self.conv(F.interpolate(x, scale_factor=2, mode='nearest'))
         | 
| 29 | 
            +
             | 
| 30 | 
            +
             | 
| 31 | 
            +
            class Downsample2x(nn.Module):
         | 
| 32 | 
            +
                def __init__(self, in_channels):
         | 
| 33 | 
            +
                    super().__init__()
         | 
| 34 | 
            +
                    self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)
         | 
| 35 | 
            +
                
         | 
| 36 | 
            +
                def forward(self, x):
         | 
| 37 | 
            +
                    return self.conv(F.pad(x, pad=(0, 1, 0, 1), mode='constant', value=0))
         | 
| 38 | 
            +
             | 
| 39 | 
            +
             | 
| 40 | 
            +
            class ResnetBlock(nn.Module):
         | 
| 41 | 
            +
                def __init__(self, *, in_channels, out_channels=None, dropout): # conv_shortcut=False,  # conv_shortcut: always False in VAE
         | 
| 42 | 
            +
                    super().__init__()
         | 
| 43 | 
            +
                    self.in_channels = in_channels
         | 
| 44 | 
            +
                    out_channels = in_channels if out_channels is None else out_channels
         | 
| 45 | 
            +
                    self.out_channels = out_channels
         | 
| 46 | 
            +
                    
         | 
| 47 | 
            +
                    self.norm1 = Normalize(in_channels)
         | 
| 48 | 
            +
                    self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
         | 
| 49 | 
            +
                    self.norm2 = Normalize(out_channels)
         | 
| 50 | 
            +
                    self.dropout = torch.nn.Dropout(dropout) if dropout > 1e-6 else nn.Identity()
         | 
| 51 | 
            +
                    self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
         | 
| 52 | 
            +
                    if self.in_channels != self.out_channels:
         | 
| 53 | 
            +
                        self.nin_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
         | 
| 54 | 
            +
                    else:
         | 
| 55 | 
            +
                        self.nin_shortcut = nn.Identity()
         | 
| 56 | 
            +
                
         | 
| 57 | 
            +
                def forward(self, x):
         | 
| 58 | 
            +
                    h = self.conv1(F.silu(self.norm1(x), inplace=True))
         | 
| 59 | 
            +
                    h = self.conv2(self.dropout(F.silu(self.norm2(h), inplace=True)))
         | 
| 60 | 
            +
                    return self.nin_shortcut(x) + h
         | 
| 61 | 
            +
             | 
| 62 | 
            +
             | 
| 63 | 
            +
            class AttnBlock(nn.Module):
         | 
| 64 | 
            +
                def __init__(self, in_channels):
         | 
| 65 | 
            +
                    super().__init__()
         | 
| 66 | 
            +
                    self.C = in_channels
         | 
| 67 | 
            +
                    
         | 
| 68 | 
            +
                    self.norm = Normalize(in_channels)
         | 
| 69 | 
            +
                    self.qkv = torch.nn.Conv2d(in_channels, 3*in_channels, kernel_size=1, stride=1, padding=0)
         | 
| 70 | 
            +
                    self.w_ratio = int(in_channels) ** (-0.5)
         | 
| 71 | 
            +
                    self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
         | 
| 72 | 
            +
                
         | 
| 73 | 
            +
                def forward(self, x):
         | 
| 74 | 
            +
                    qkv = self.qkv(self.norm(x))
         | 
| 75 | 
            +
                    B, _, H, W = qkv.shape  # should be B,3C,H,W
         | 
| 76 | 
            +
                    C = self.C
         | 
| 77 | 
            +
                    q, k, v = qkv.reshape(B, 3, C, H, W).unbind(1)
         | 
| 78 | 
            +
                    
         | 
| 79 | 
            +
                    # compute attention
         | 
| 80 | 
            +
                    q = q.view(B, C, H * W).contiguous()
         | 
| 81 | 
            +
                    q = q.permute(0, 2, 1).contiguous()     # B,HW,C
         | 
| 82 | 
            +
                    k = k.view(B, C, H * W).contiguous()    # B,C,HW
         | 
| 83 | 
            +
                    w = torch.bmm(q, k).mul_(self.w_ratio)  # B,HW,HW    w[B,i,j]=sum_c q[B,i,C]k[B,C,j]
         | 
| 84 | 
            +
                    w = F.softmax(w, dim=2)
         | 
| 85 | 
            +
                    
        # attend to values
        v = v.view(B, C, H * W).contiguous()
        w = w.permute(0, 2, 1).contiguous()  # B,HW,HW (first HW of k, second of q)
        h = torch.bmm(v, w)  # B, C,HW (HW of q) h[B,C,j] = sum_i v[B,C,i] w[B,i,j]
        h = h.view(B, C, H, W).contiguous()
        
        return x + self.proj_out(h)


def make_attn(in_channels, using_sa=True):
    return AttnBlock(in_channels) if using_sa else nn.Identity()


class Encoder(nn.Module):
    def __init__(
        self, *, ch=128, ch_mult=(1, 2, 4, 8), num_res_blocks=2,
        dropout=0.0, in_channels=3,
        z_channels, double_z=False, using_sa=True, using_mid_sa=True,
    ):
        super().__init__()
        self.ch = ch
        self.num_resolutions = len(ch_mult)
        self.downsample_ratio = 2 ** (self.num_resolutions - 1)
        self.num_res_blocks = num_res_blocks
        self.in_channels = in_channels
        
        # downsampling
        self.conv_in = torch.nn.Conv2d(in_channels, self.ch, kernel_size=3, stride=1, padding=1)
        
        in_ch_mult = (1,) + tuple(ch_mult)
        self.down = nn.ModuleList()
        for i_level in range(self.num_resolutions):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_in = ch * in_ch_mult[i_level]
            block_out = ch * ch_mult[i_level]
            for i_block in range(self.num_res_blocks):
                block.append(ResnetBlock(in_channels=block_in, out_channels=block_out, dropout=dropout))
                block_in = block_out
                if i_level == self.num_resolutions - 1 and using_sa:
                    attn.append(make_attn(block_in, using_sa=True))
            down = nn.Module()
            down.block = block
            down.attn = attn
            if i_level != self.num_resolutions - 1:
                down.downsample = Downsample2x(block_in)
            self.down.append(down)
        
        # middle
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in, dropout=dropout)
        self.mid.attn_1 = make_attn(block_in, using_sa=using_mid_sa)
        self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in, dropout=dropout)
        
        # end
        self.norm_out = Normalize(block_in)
        self.conv_out = torch.nn.Conv2d(block_in, (2 * z_channels if double_z else z_channels), kernel_size=3, stride=1, padding=1)
    
    def forward(self, x):
        # downsampling
        h = self.conv_in(x)
        for i_level in range(self.num_resolutions):
            for i_block in range(self.num_res_blocks):
                h = self.down[i_level].block[i_block](h)
                if len(self.down[i_level].attn) > 0:
                    h = self.down[i_level].attn[i_block](h)
            if i_level != self.num_resolutions - 1:
                h = self.down[i_level].downsample(h)
        
        # middle
        h = self.mid.block_2(self.mid.attn_1(self.mid.block_1(h)))
        
        # end
        h = self.conv_out(F.silu(self.norm_out(h), inplace=True))
        return h


class Decoder(nn.Module):
    def __init__(
        self, *, ch=128, ch_mult=(1, 2, 4, 8), num_res_blocks=2,
        dropout=0.0, in_channels=3,  # in_channels: raw img channels
        z_channels, using_sa=True, using_mid_sa=True,
    ):
        super().__init__()
        self.ch = ch
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.in_channels = in_channels
        
        # compute in_ch_mult, block_in and curr_res at lowest res
        in_ch_mult = (1,) + tuple(ch_mult)
        block_in = ch * ch_mult[self.num_resolutions - 1]
        
        # z to block_in
        self.conv_in = torch.nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1)
        
        # middle
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in, dropout=dropout)
        self.mid.attn_1 = make_attn(block_in, using_sa=using_mid_sa)
        self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in, dropout=dropout)
        
        # upsampling
        self.up = nn.ModuleList()
        for i_level in reversed(range(self.num_resolutions)):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_out = ch * ch_mult[i_level]
            for i_block in range(self.num_res_blocks + 1):
                block.append(ResnetBlock(in_channels=block_in, out_channels=block_out, dropout=dropout))
                block_in = block_out
                if i_level == self.num_resolutions - 1 and using_sa:
                    attn.append(make_attn(block_in, using_sa=True))
            up = nn.Module()
            up.block = block
            up.attn = attn
            if i_level != 0:
                up.upsample = Upsample2x(block_in)
            self.up.insert(0, up)  # prepend to get consistent order
        
        # end
        self.norm_out = Normalize(block_in)
        self.conv_out = torch.nn.Conv2d(block_in, in_channels, kernel_size=3, stride=1, padding=1)
    
    def forward(self, z):
        # z to block_in
        # middle
        h = self.mid.block_2(self.mid.attn_1(self.mid.block_1(self.conv_in(z))))
        
        # upsampling
        for i_level in reversed(range(self.num_resolutions)):
            for i_block in range(self.num_res_blocks + 1):
                h = self.up[i_level].block[i_block](h)
                if len(self.up[i_level].attn) > 0:
                    h = self.up[i_level].attn[i_block](h)
            if i_level != 0:
                h = self.up[i_level].upsample(h)
        
        # end
        h = self.conv_out(F.silu(self.norm_out(h), inplace=True))
        return h
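The Encoder halves the spatial resolution at every level except the last (overall downsample ratio 2 ** (len(ch_mult) - 1)), and the Decoder mirrors it back up. A quick shape sanity check, not part of the commit, assuming the Space's repo root is on the import path:

# Illustrative only: round-trip shapes through the Encoder/Decoder pair above.
import torch
from models.basic_vae import Encoder, Decoder

enc = Encoder(ch=64, ch_mult=(1, 2, 4), num_res_blocks=1, z_channels=32, using_sa=False, using_mid_sa=False)
dec = Decoder(ch=64, ch_mult=(1, 2, 4), num_res_blocks=1, z_channels=32, using_sa=False, using_mid_sa=False)

x = torch.randn(1, 3, 64, 64)    # raw image, B=1
z = enc(x)                       # downsampled by 2 ** (len(ch_mult) - 1) = 4 -> (1, 32, 16, 16)
x_rec = dec(z)                   # upsampled back to (1, 3, 64, 64)
print(z.shape, x_rec.shape)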
    	
        models/basic_var.py
    ADDED
    
    | @@ -0,0 +1,174 @@ | |
import math

import torch
import torch.nn as nn
import torch.nn.functional as F

from models.helpers import DropPath, drop_path


# this file only provides the 3 blocks used in VAR transformer
__all__ = ['FFN', 'AdaLNSelfAttn', 'AdaLNBeforeHead']


# automatically import fused operators
dropout_add_layer_norm = fused_mlp_func = memory_efficient_attention = flash_attn_func = None
try:
    from flash_attn.ops.layer_norm import dropout_add_layer_norm
    from flash_attn.ops.fused_dense import fused_mlp_func
except ImportError: pass
# automatically import faster attention implementations
try: from xformers.ops import memory_efficient_attention
except ImportError: pass
try: from flash_attn import flash_attn_func              # qkv: BLHc, ret: BLHcq
except ImportError: pass
try: from torch.nn.functional import scaled_dot_product_attention as slow_attn    # q, k, v: BHLc
except ImportError:
    def slow_attn(query, key, value, scale: float, attn_mask=None, dropout_p=0.0):
        attn = query.mul(scale) @ key.transpose(-2, -1)  # BHLc @ BHcL => BHLL
        if attn_mask is not None: attn.add_(attn_mask)
        return (F.dropout(attn.softmax(dim=-1), p=dropout_p, inplace=True) if dropout_p > 0 else attn.softmax(dim=-1)) @ value


class FFN(nn.Module):
    def __init__(self, in_features, hidden_features=None, out_features=None, drop=0., fused_if_available=True):
        super().__init__()
        self.fused_mlp_func = fused_mlp_func if fused_if_available else None
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = nn.GELU(approximate='tanh')
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop, inplace=True) if drop > 0 else nn.Identity()
    
    def forward(self, x):
        if self.fused_mlp_func is not None:
            return self.drop(self.fused_mlp_func(
                x=x, weight1=self.fc1.weight, weight2=self.fc2.weight, bias1=self.fc1.bias, bias2=self.fc2.bias,
                activation='gelu_approx', save_pre_act=self.training, return_residual=False, checkpoint_lvl=0,
                heuristic=0, process_group=None,
            ))
        else:
            return self.drop(self.fc2(self.act(self.fc1(x))))
    
    def extra_repr(self) -> str:
        return f'fused_mlp_func={self.fused_mlp_func is not None}'


class SelfAttention(nn.Module):
    def __init__(
        self, block_idx, embed_dim=768, num_heads=12,
        attn_drop=0., proj_drop=0., attn_l2_norm=False, flash_if_available=True,
    ):
        super().__init__()
        assert embed_dim % num_heads == 0
        self.block_idx, self.num_heads, self.head_dim = block_idx, num_heads, embed_dim // num_heads  # =64
        self.attn_l2_norm = attn_l2_norm
        if self.attn_l2_norm:
            self.scale = 1
            self.scale_mul_1H11 = nn.Parameter(torch.full(size=(1, self.num_heads, 1, 1), fill_value=4.0).log(), requires_grad=True)
            self.max_scale_mul = torch.log(torch.tensor(100)).item()
        else:
            self.scale = 0.25 / math.sqrt(self.head_dim)
        
        self.mat_qkv = nn.Linear(embed_dim, embed_dim * 3, bias=False)
        self.q_bias, self.v_bias = nn.Parameter(torch.zeros(embed_dim)), nn.Parameter(torch.zeros(embed_dim))
        self.register_buffer('zero_k_bias', torch.zeros(embed_dim))
        
        self.proj = nn.Linear(embed_dim, embed_dim)
        self.proj_drop = nn.Dropout(proj_drop, inplace=True) if proj_drop > 0 else nn.Identity()
        self.attn_drop: float = attn_drop
        self.using_flash = flash_if_available and flash_attn_func is not None
        self.using_xform = flash_if_available and memory_efficient_attention is not None
        
        # only used during inference
        self.caching, self.cached_k, self.cached_v = False, None, None
    
    def kv_caching(self, enable: bool): self.caching, self.cached_k, self.cached_v = enable, None, None
    
    # NOTE: attn_bias is None during inference because kv cache is enabled
    def forward(self, x, attn_bias):
        B, L, C = x.shape
        
        qkv = F.linear(input=x, weight=self.mat_qkv.weight, bias=torch.cat((self.q_bias, self.zero_k_bias, self.v_bias))).view(B, L, 3, self.num_heads, self.head_dim)
        main_type = qkv.dtype
        # qkv: BL3Hc
        
        using_flash = self.using_flash and attn_bias is None and qkv.dtype != torch.float32
        if using_flash or self.using_xform: q, k, v = qkv.unbind(dim=2); dim_cat = 1   # q or k or v: BLHc
        else: q, k, v = qkv.permute(2, 0, 3, 1, 4).unbind(dim=0); dim_cat = 2          # q or k or v: BHLc
        
        if self.attn_l2_norm:
            scale_mul = self.scale_mul_1H11.clamp_max(self.max_scale_mul).exp()
            if using_flash or self.using_xform: scale_mul = scale_mul.transpose(1, 2)  # 1H11 to 11H1
            q = F.normalize(q, dim=-1).mul(scale_mul)
            k = F.normalize(k, dim=-1)
        
        if self.caching:
            if self.cached_k is None: self.cached_k = k; self.cached_v = v
            else: k = self.cached_k = torch.cat((self.cached_k, k), dim=dim_cat); v = self.cached_v = torch.cat((self.cached_v, v), dim=dim_cat)
        
        dropout_p = self.attn_drop if self.training else 0.0
        if using_flash:
            oup = flash_attn_func(q.to(dtype=main_type), k.to(dtype=main_type), v.to(dtype=main_type), dropout_p=dropout_p, softmax_scale=self.scale).view(B, L, C)
        elif self.using_xform:
            oup = memory_efficient_attention(q.to(dtype=main_type), k.to(dtype=main_type), v.to(dtype=main_type), attn_bias=None if attn_bias is None else attn_bias.to(dtype=main_type).expand(B, self.num_heads, -1, -1), p=dropout_p, scale=self.scale).view(B, L, C)
        else:
            oup = slow_attn(query=q, key=k, value=v, scale=self.scale, attn_mask=attn_bias, dropout_p=dropout_p).transpose(1, 2).reshape(B, L, C)
        
        return self.proj_drop(self.proj(oup))
        # attn = (q @ k.transpose(-2, -1)).add_(attn_bias + self.local_rpb())  # BHLc @ BHcL => BHLL
        # attn = self.attn_drop(attn.softmax(dim=-1))
        # oup = (attn @ v).transpose_(1, 2).reshape(B, L, -1)     # BHLL @ BHLc = BHLc => BLHc => BLC
    
    def extra_repr(self) -> str:
        return f'using_flash={self.using_flash}, using_xform={self.using_xform}, attn_l2_norm={self.attn_l2_norm}'


class AdaLNSelfAttn(nn.Module):
    def __init__(
        self, block_idx, last_drop_p, embed_dim, cond_dim, shared_aln: bool, norm_layer,
        num_heads, mlp_ratio=4., drop=0., attn_drop=0., drop_path=0., attn_l2_norm=False,
        flash_if_available=False, fused_if_available=True,
    ):
        super(AdaLNSelfAttn, self).__init__()
        self.block_idx, self.last_drop_p, self.C = block_idx, last_drop_p, embed_dim
        self.C, self.D = embed_dim, cond_dim
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.attn = SelfAttention(block_idx=block_idx, embed_dim=embed_dim, num_heads=num_heads, attn_drop=attn_drop, proj_drop=drop, attn_l2_norm=attn_l2_norm, flash_if_available=flash_if_available)
        self.ffn = FFN(in_features=embed_dim, hidden_features=round(embed_dim * mlp_ratio), drop=drop, fused_if_available=fused_if_available)
        
        self.ln_wo_grad = norm_layer(embed_dim, elementwise_affine=False)
        self.shared_aln = shared_aln
        if self.shared_aln:
            self.ada_gss = nn.Parameter(torch.randn(1, 1, 6, embed_dim) / embed_dim**0.5)
        else:
            lin = nn.Linear(cond_dim, 6*embed_dim)
            self.ada_lin = nn.Sequential(nn.SiLU(inplace=False), lin)
        
        self.fused_add_norm_fn = None
    
    # NOTE: attn_bias is None during inference because kv cache is enabled
    def forward(self, x, cond_BD, attn_bias):   # C: embed_dim, D: cond_dim
        if self.shared_aln:
            gamma1, gamma2, scale1, scale2, shift1, shift2 = (self.ada_gss + cond_BD).unbind(2)  # 116C + B16C =unbind(2)=> 6 B1C
        else:
            gamma1, gamma2, scale1, scale2, shift1, shift2 = self.ada_lin(cond_BD).view(-1, 1, 6, self.C).unbind(2)
        x = x + self.drop_path(self.attn( self.ln_wo_grad(x).mul(scale1.add(1)).add_(shift1), attn_bias=attn_bias ).mul_(gamma1))
        x = x + self.drop_path(self.ffn( self.ln_wo_grad(x).mul(scale2.add(1)).add_(shift2) ).mul(gamma2))  # this mul(gamma2) cannot be in-placed when FusedMLP is used
        return x
    
    def extra_repr(self) -> str:
        return f'shared_aln={self.shared_aln}'


class AdaLNBeforeHead(nn.Module):
    def __init__(self, C, D, norm_layer):   # C: embed_dim, D: cond_dim
        super().__init__()
        self.C, self.D = C, D
        self.ln_wo_grad = norm_layer(C, elementwise_affine=False)
        self.ada_lin = nn.Sequential(nn.SiLU(inplace=False), nn.Linear(D, 2*C))
    
    def forward(self, x_BLC: torch.Tensor, cond_BD: torch.Tensor):
        scale, shift = self.ada_lin(cond_BD).view(-1, 1, 2, self.C).unbind(2)
        return self.ln_wo_grad(x_BLC).mul(scale.add(1)).add_(shift)
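Each AdaLNSelfAttn block maps the condition vector to six per-channel modulation terms (two gains, two scales, two shifts) and applies them around the attention and FFN sub-blocks. A minimal forward-pass sketch, illustrative only; it assumes a recent PyTorch with scaled_dot_product_attention and needs neither flash-attn nor xformers:

# Hypothetical shapes, not taken from the app: B=2 tokens sequences of length 16, width 256.
import torch
import torch.nn as nn
from models.basic_var import AdaLNSelfAttn

blk = AdaLNSelfAttn(
    block_idx=0, last_drop_p=0.0, embed_dim=256, cond_dim=256, shared_aln=False,
    norm_layer=nn.LayerNorm, num_heads=8, flash_if_available=False, fused_if_available=False,
)
x = torch.randn(2, 16, 256)        # (B, L, C) token sequence
cond = torch.randn(2, 256)         # (B, D) condition, expanded to the 6 AdaLN params inside the block
out = blk(x, cond_BD=cond, attn_bias=None)
print(out.shape)                   # torch.Size([2, 16, 256])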
    	
        models/helpers.py
    ADDED
    
    | @@ -0,0 +1,59 @@ | |
|  | |
| 1 | 
            +
            import torch
         | 
| 2 | 
            +
            from torch import nn as nn
         | 
| 3 | 
            +
            from torch.nn import functional as F
         | 
| 4 | 
            +
             | 
| 5 | 
            +
             | 
| 6 | 
            +
            def sample_with_top_k_top_p_(logits_BlV: torch.Tensor, top_k: int = 0, top_p: float = 0.0, rng=None, num_samples=1) -> torch.Tensor:  # return idx, shaped (B, l)
         | 
| 7 | 
            +
                B, l, V = logits_BlV.shape
         | 
| 8 | 
            +
                if top_k > 0:
         | 
| 9 | 
            +
                    idx_to_remove = logits_BlV < logits_BlV.topk(top_k, largest=True, sorted=False, dim=-1)[0].amin(dim=-1, keepdim=True)
         | 
| 10 | 
            +
                    logits_BlV.masked_fill_(idx_to_remove, -torch.inf)
         | 
| 11 | 
            +
                if top_p > 0:
         | 
| 12 | 
            +
                    sorted_logits, sorted_idx = logits_BlV.sort(dim=-1, descending=False)
         | 
| 13 | 
            +
                    sorted_idx_to_remove = sorted_logits.softmax(dim=-1).cumsum_(dim=-1) <= (1 - top_p)
         | 
| 14 | 
            +
                    sorted_idx_to_remove[..., -1:] = False
         | 
| 15 | 
            +
                    logits_BlV.masked_fill_(sorted_idx_to_remove.scatter(sorted_idx.ndim - 1, sorted_idx, sorted_idx_to_remove), -torch.inf)
         | 
| 16 | 
            +
                # sample (have to squeeze cuz torch.multinomial can only be used for 2D tensor)
         | 
| 17 | 
            +
                replacement = num_samples >= 0
         | 
| 18 | 
            +
                num_samples = abs(num_samples)
         | 
| 19 | 
            +
                return torch.multinomial(logits_BlV.softmax(dim=-1).view(-1, V), num_samples=num_samples, replacement=replacement, generator=rng).view(B, l, num_samples)
         | 
| 20 | 
            +
             | 
| 21 | 
            +
             | 
| 22 | 
            +
            def gumbel_softmax_with_rng(logits: torch.Tensor, tau: float = 1, hard: bool = False, eps: float = 1e-10, dim: int = -1, rng: torch.Generator = None) -> torch.Tensor:
         | 
| 23 | 
            +
                if rng is None:
         | 
| 24 | 
            +
                    return F.gumbel_softmax(logits=logits, tau=tau, hard=hard, eps=eps, dim=dim)
         | 
| 25 | 
            +
                
         | 
| 26 | 
            +
                gumbels = (-torch.empty_like(logits, memory_format=torch.legacy_contiguous_format).exponential_(generator=rng).log())
         | 
| 27 | 
            +
                gumbels = (logits + gumbels) / tau
         | 
| 28 | 
            +
                y_soft = gumbels.softmax(dim)
         | 
| 29 | 
            +
                
         | 
| 30 | 
            +
                if hard:
         | 
| 31 | 
            +
                    index = y_soft.max(dim, keepdim=True)[1]
         | 
| 32 | 
            +
                    y_hard = torch.zeros_like(logits, memory_format=torch.legacy_contiguous_format).scatter_(dim, index, 1.0)
         | 
| 33 | 
            +
                    ret = y_hard - y_soft.detach() + y_soft
         | 
| 34 | 
            +
                else:
         | 
| 35 | 
            +
                    ret = y_soft
         | 
| 36 | 
            +
                return ret
         | 
| 37 | 
            +
             | 
| 38 | 
            +
             | 
| 39 | 
            +
            def drop_path(x, drop_prob: float = 0., training: bool = False, scale_by_keep: bool = True):    # taken from timm
         | 
| 40 | 
            +
                if drop_prob == 0. or not training: return x
         | 
| 41 | 
            +
                keep_prob = 1 - drop_prob
         | 
| 42 | 
            +
                shape = (x.shape[0],) + (1,) * (x.ndim - 1)  # work with diff dim tensors, not just 2D ConvNets
         | 
| 43 | 
            +
                random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
         | 
| 44 | 
            +
                if keep_prob > 0.0 and scale_by_keep:
         | 
| 45 | 
            +
                    random_tensor.div_(keep_prob)
         | 
| 46 | 
            +
                return x * random_tensor
         | 
| 47 | 
            +
             | 
| 48 | 
            +
             | 
| 49 | 
            +
            class DropPath(nn.Module):  # taken from timm
         | 
| 50 | 
            +
                def __init__(self, drop_prob: float = 0., scale_by_keep: bool = True):
         | 
| 51 | 
            +
                    super(DropPath, self).__init__()
         | 
| 52 | 
            +
                    self.drop_prob = drop_prob
         | 
| 53 | 
            +
                    self.scale_by_keep = scale_by_keep
         | 
| 54 | 
            +
                
         | 
| 55 | 
            +
                def forward(self, x):
         | 
| 56 | 
            +
                    return drop_path(x, self.drop_prob, self.training, self.scale_by_keep)
         | 
| 57 | 
            +
                
         | 
| 58 | 
            +
                def extra_repr(self):
         | 
| 59 | 
            +
                    return f'(drop_prob=...)'
         | 
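sample_with_top_k_top_p_ masks the logits in place (hence the trailing underscore) before drawing token indices, so pass a copy if the logits are reused. A small usage sketch, not part of the commit:

# Illustrative only: draw one token id per position from dummy (B, l, V) logits.
import torch
from models.helpers import sample_with_top_k_top_p_

logits = torch.randn(2, 4, 1000)   # (B, l, V) next-token logits
idx = sample_with_top_k_top_p_(logits.clone(), top_k=50, top_p=0.9, num_samples=1)
print(idx.shape)                   # torch.Size([2, 4, 1]) sampled token ids per position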
    	
        models/quant.py
    ADDED
    
    | @@ -0,0 +1,281 @@ | |
|  | |
| 1 | 
            +
            from typing import List, Optional, Sequence, Tuple, Union
         | 
| 2 | 
            +
             | 
| 3 | 
            +
            import numpy as np
         | 
| 4 | 
            +
            import torch
         | 
| 5 | 
            +
            from torch import distributed as tdist, nn as nn
         | 
| 6 | 
            +
            from torch.nn import functional as F
         | 
| 7 | 
            +
             | 
| 8 | 
            +
            import dist
         | 
| 9 | 
            +
             | 
| 10 | 
            +
            # this file only provides the VectorQuantizer2 used in VQVAE
         | 
| 11 | 
            +
            __all__ = ['VectorQuantizer2', ]
         | 
| 12 | 
            +
             | 
| 13 | 
            +
             | 
| 14 | 
            +
            class VectorQuantizer2(nn.Module):
         | 
| 15 | 
            +
                # VQGAN originally use beta=1.0, never tried 0.25; SD seems using 0.25
         | 
| 16 | 
            +
                def __init__(
         | 
| 17 | 
            +
                        self, vocab_size, Cvae, using_znorm, beta: float = 0.25,
         | 
| 18 | 
            +
                        default_qresi_counts=0, v_patch_nums=None, quant_resi=0.5, share_quant_resi=4,  # share_quant_resi: args.qsr
         | 
| 19 | 
            +
                ):
         | 
| 20 | 
            +
                    super().__init__()
         | 
| 21 | 
            +
                    self.vocab_size: int = vocab_size
         | 
| 22 | 
            +
                    self.Cvae: int = Cvae
         | 
| 23 | 
            +
                    self.using_znorm: bool = using_znorm
         | 
| 24 | 
            +
                    self.v_patch_nums: Tuple[int] = v_patch_nums
         | 
| 25 | 
            +
             | 
| 26 | 
            +
                    self.quant_resi_ratio = quant_resi
         | 
| 27 | 
            +
                    if share_quant_resi == 0:  # non-shared: \phi_{1 to K} for K scales
         | 
| 28 | 
            +
                        self.quant_resi = PhiNonShared(
         | 
| 29 | 
            +
                            [(Phi(Cvae, quant_resi) if abs(quant_resi) > 1e-6 else nn.Identity()) for _ in
         | 
| 30 | 
            +
                             range(default_qresi_counts or len(self.v_patch_nums))])
         | 
| 31 | 
            +
                    elif share_quant_resi == 1:  # fully shared: only a single \phi for K scales
         | 
| 32 | 
            +
                        self.quant_resi = PhiShared(Phi(Cvae, quant_resi) if abs(quant_resi) > 1e-6 else nn.Identity())
         | 
| 33 | 
            +
                    else:  # partially shared: \phi_{1 to share_quant_resi} for K scales
         | 
| 34 | 
            +
                        self.quant_resi = PhiPartiallyShared(nn.ModuleList(
         | 
| 35 | 
            +
                            [(Phi(Cvae, quant_resi) if abs(quant_resi) > 1e-6 else nn.Identity()) for _ in
         | 
| 36 | 
            +
                             range(share_quant_resi)]))
         | 
| 37 | 
            +
             | 
| 38 | 
            +
                    self.register_buffer('ema_vocab_hit_SV', torch.full((len(self.v_patch_nums), self.vocab_size), fill_value=0.0))
         | 
| 39 | 
            +
                    self.record_hit = 0
         | 
| 40 | 
            +
             | 
| 41 | 
            +
                    self.beta: float = beta
         | 
| 42 | 
            +
                    self.embedding = nn.Embedding(self.vocab_size, self.Cvae)
         | 
| 43 | 
            +
             | 
| 44 | 
            +
                    # only used for progressive training of VAR (not supported yet, will be tested and supported in the future)
         | 
| 45 | 
            +
                    self.prog_si = -1  # progressive training: not supported yet, prog_si always -1
         | 
| 46 | 
            +
             | 
| 47 | 
            +
                def eini(self, eini):
         | 
| 48 | 
            +
                    if eini > 0:
         | 
| 49 | 
            +
                        nn.init.trunc_normal_(self.embedding.weight.data, std=eini)
         | 
| 50 | 
            +
                    elif eini < 0:
         | 
| 51 | 
            +
                        self.embedding.weight.data.uniform_(-abs(eini) / self.vocab_size, abs(eini) / self.vocab_size)
         | 
| 52 | 
            +
             | 
| 53 | 
            +
                def extra_repr(self) -> str:
         | 
| 54 | 
            +
                    return f'{self.v_patch_nums}, znorm={self.using_znorm}, beta={self.beta}  |  S={len(self.v_patch_nums)}, quant_resi={self.quant_resi_ratio}'
         | 
| 55 | 
            +
             | 
| 56 | 
            +
                # ===================== `forward` is only used in VAE training =====================
         | 
| 57 | 
            +
                def forward(self, f_BChw: torch.Tensor, ret_usages=False) -> Tuple[torch.Tensor, List[float], torch.Tensor]:
         | 
| 58 | 
            +
                    dtype = f_BChw.dtype
         | 
| 59 | 
            +
                    if dtype != torch.float32: f_BChw = f_BChw.float()
         | 
| 60 | 
            +
                    B, C, H, W = f_BChw.shape
         | 
| 61 | 
            +
                    f_no_grad = f_BChw.detach()
         | 
| 62 | 
            +
             | 
| 63 | 
            +
                    f_rest = f_no_grad.clone()
         | 
| 64 | 
            +
                    f_hat = torch.zeros_like(f_rest)
         | 
| 65 | 
            +
             | 
| 66 | 
            +
                    with torch.cuda.amp.autocast(enabled=False):
         | 
| 67 | 
            +
                        mean_vq_loss: torch.Tensor = 0.0
         | 
| 68 | 
            +
                        vocab_hit_V = torch.zeros(self.vocab_size, dtype=torch.float, device=f_BChw.device)
         | 
| 69 | 
            +
                        SN = len(self.v_patch_nums)
         | 
| 70 | 
            +
                        for si, pn in enumerate(self.v_patch_nums):  # from small to large
         | 
| 71 | 
            +
                            # find the nearest embedding
         | 
| 72 | 
            +
                            if self.using_znorm:
         | 
| 73 | 
            +
                                rest_NC = F.interpolate(f_rest, size=(pn, pn), mode='bilinear').permute(0, 2, 3, 1).reshape(-1,
         | 
| 74 | 
            +
                                                                                                                            C) if (
         | 
| 75 | 
            +
                                            si != SN - 1) else f_rest.permute(0, 2, 3, 1).reshape(-1, C)
         | 
| 76 | 
            +
                                rest_NC = F.normalize(rest_NC, dim=-1)
         | 
| 77 | 
            +
                                idx_N = torch.argmax(rest_NC @ F.normalize(self.embedding.weight.data.T, dim=0), dim=1)
         | 
| 78 | 
            +
                            else:
         | 
| 79 | 
            +
                                rest_NC = F.interpolate(f_rest, size=(pn, pn), mode='bilinear').permute(0, 2, 3, 1).reshape(-1,
         | 
| 80 | 
            +
                                                                                                                            C) if (
         | 
| 81 | 
            +
                                            si != SN - 1) else f_rest.permute(0, 2, 3, 1).reshape(-1, C)
         | 
| 82 | 
            +
                                d_no_grad = torch.sum(rest_NC.square(), dim=1, keepdim=True) + torch.sum(
         | 
| 83 | 
            +
                                    self.embedding.weight.data.square(), dim=1, keepdim=False)
         | 
| 84 | 
            +
                                d_no_grad.addmm_(rest_NC, self.embedding.weight.data.T, alpha=-2, beta=1)  # (B*h*w, vocab_size)
         | 
| 85 | 
            +
                                idx_N = torch.argmin(d_no_grad, dim=1)
         | 
| 86 | 
            +
             | 
| 87 | 
            +
                            hit_V = idx_N.bincount(minlength=self.vocab_size).float()
         | 
| 88 | 
            +
                            if self.training:
         | 
| 89 | 
            +
                                if dist.initialized(): handler = tdist.all_reduce(hit_V, async_op=True)
         | 
| 90 | 
            +
             | 
| 91 | 
            +
                            # calc loss
         | 
| 92 | 
            +
                            idx_Bhw = idx_N.view(B, pn, pn)
         | 
| 93 | 
            +
                            h_BChw = F.interpolate(self.embedding(idx_Bhw).permute(0, 3, 1, 2), size=(H, W),
         | 
| 94 | 
            +
                                                   mode='bilinear').contiguous() if (si != SN - 1) else self.embedding(
         | 
| 95 | 
            +
                                idx_Bhw).permute(0, 3, 1, 2).contiguous()
         | 
| 96 | 
            +
                            h_BChw = self.quant_resi[si / (SN - 1)](h_BChw)
         | 
| 97 | 
            +
                            f_hat = f_hat + h_BChw
         | 
| 98 | 
            +
                            f_rest -= h_BChw
         | 
| 99 | 
            +
             | 
| 100 | 
            +
                            if self.training and dist.initialized():
         | 
| 101 | 
            +
                                handler.wait()
         | 
| 102 | 
            +
                                if self.record_hit == 0:
         | 
| 103 | 
            +
                                    self.ema_vocab_hit_SV[si].copy_(hit_V)
         | 
| 104 | 
            +
                                elif self.record_hit < 100:
         | 
| 105 | 
            +
                                    self.ema_vocab_hit_SV[si].mul_(0.9).add_(hit_V.mul(0.1))
         | 
| 106 | 
            +
                                else:
         | 
| 107 | 
            +
                                    self.ema_vocab_hit_SV[si].mul_(0.99).add_(hit_V.mul(0.01))
         | 
| 108 | 
            +
                                self.record_hit += 1
         | 
| 109 | 
            +
                            vocab_hit_V.add_(hit_V)
         | 
| 110 | 
            +
                            mean_vq_loss += F.mse_loss(f_hat.data, f_BChw).mul_(self.beta) + F.mse_loss(f_hat, f_no_grad)
         | 
| 111 | 
            +
             | 
| 112 | 
            +
                        mean_vq_loss *= 1. / SN
         | 
| 113 | 
            +
                        f_hat = (f_hat.data - f_no_grad).add_(f_BChw)
         | 
| 114 | 
            +
             | 
| 115 | 
            +
                    margin = tdist.get_world_size() * (f_BChw.numel() / f_BChw.shape[1]) / self.vocab_size * 0.08
         | 
| 116 | 
            +
                    # margin = pn*pn / 100
         | 
| 117 | 
            +
                    if ret_usages:
         | 
| 118 | 
            +
                        usages = [(self.ema_vocab_hit_SV[si] >= margin).float().mean().item() * 100 for si, pn in
         | 
| 119 | 
            +
                                  enumerate(self.v_patch_nums)]
         | 
| 120 | 
            +
                    else:
         | 
| 121 | 
            +
                        usages = None
         | 
| 122 | 
            +
                    return f_hat, usages, mean_vq_loss
         | 
| 123 | 
            +
             | 
| 124 | 
            +
                # ===================== `forward` is only used in VAE training =====================
         | 
| 125 | 
            +
             | 
| 126 | 
            +
                def embed_to_fhat(self, ms_h_BChw: List[torch.Tensor], all_to_max_scale=True, last_one=False) -> Union[
         | 
| 127 | 
            +
                    List[torch.Tensor], torch.Tensor]:
         | 
| 128 | 
            +
                    ls_f_hat_BChw = []
         | 
| 129 | 
            +
                    B = ms_h_BChw[0].shape[0]
         | 
| 130 | 
            +
                    H = W = self.v_patch_nums[-1]
         | 
| 131 | 
            +
                    SN = len(self.v_patch_nums)
         | 
| 132 | 
            +
                    if all_to_max_scale:
         | 
| 133 | 
            +
                        f_hat = ms_h_BChw[0].new_zeros(B, self.Cvae, H, W, dtype=torch.float32)
         | 
| 134 | 
            +
                        for si, pn in enumerate(self.v_patch_nums):  # from small to large
         | 
| 135 | 
            +
                            h_BChw = ms_h_BChw[si]
         | 
| 136 | 
            +
                            if si < len(self.v_patch_nums) - 1:
         | 
| 137 | 
            +
                                h_BChw = F.interpolate(h_BChw, size=(H, W), mode='bilinear')
         | 
| 138 | 
            +
                            h_BChw = self.quant_resi[si / (SN - 1)](h_BChw)
         | 
| 139 | 
            +
                            f_hat.add_(h_BChw)
         | 
| 140 | 
            +
                            if last_one:
         | 
| 141 | 
            +
                                ls_f_hat_BChw = f_hat
         | 
| 142 | 
            +
                            else:
         | 
| 143 | 
            +
                                ls_f_hat_BChw.append(f_hat.clone())
         | 
| 144 | 
            +
                    else:
         | 
| 145 | 
            +
                        # WARNING: this is not the case in VQ-VAE training or inference (we'll interpolate every token map to the max H W, like above)
         | 
| 146 | 
            +
                        # WARNING: this should only be used for experimental purpose
         | 
| 147 | 
            +
                        f_hat = ms_h_BChw[0].new_zeros(B, self.Cvae, self.v_patch_nums[0], self.v_patch_nums[0],
         | 
| 148 | 
            +
                                                       dtype=torch.float32)
         | 
| 149 | 
            +
                        for si, pn in enumerate(self.v_patch_nums):  # from small to large
         | 
| 150 | 
            +
                            f_hat = F.interpolate(f_hat, size=(pn, pn), mode='bilinear')
         | 
| 151 | 
            +
                            h_BChw = self.quant_resi[si / (SN - 1)](ms_h_BChw[si])
         | 
| 152 | 
            +
                            f_hat.add_(h_BChw)
         | 
| 153 | 
            +
                            if last_one:
         | 
| 154 | 
            +
                                ls_f_hat_BChw = f_hat
         | 
| 155 | 
            +
                            else:
         | 
| 156 | 
            +
                                ls_f_hat_BChw.append(f_hat)
         | 
| 157 | 
            +
             | 
| 158 | 
            +
                    return ls_f_hat_BChw
         | 
| 159 | 
            +
             | 
| 160 | 
            +
                def f_to_idxBl_or_fhat(self, f_BChw: torch.Tensor, to_fhat: bool,
         | 
| 161 | 
            +
                                       v_patch_nums: Optional[Sequence[Union[int, Tuple[int, int]]]] = None) -> List[
         | 
| 162 | 
            +
                    Union[torch.Tensor, torch.LongTensor]]:  # z_BChw is the feature from inp_img_no_grad
         | 
| 163 | 
            +
                    B, C, H, W = f_BChw.shape
         | 
| 164 | 
            +
                    f_no_grad = f_BChw.detach()
         | 
| 165 | 
            +
                    f_rest = f_no_grad.clone()
         | 
| 166 | 
            +
                    f_hat = torch.zeros_like(f_rest)
         | 
| 167 | 
            +
             | 
| 168 | 
            +
                    f_hat_or_idx_Bl: List[torch.Tensor] = []
         | 
| 169 | 
            +
             | 
| 170 | 
            +
                    patch_hws = [(pn, pn) if isinstance(pn, int) else (pn[0], pn[1]) for pn in
         | 
| 171 | 
            +
                                 (v_patch_nums or self.v_patch_nums)]  # from small to large
         | 
| 172 | 
            +
                    assert patch_hws[-1][0] == H and patch_hws[-1][1] == W, f'{patch_hws[-1]=} != ({H=}, {W=})'
         | 
| 173 | 
            +
             | 
| 174 | 
            +
                    SN = len(patch_hws)
         | 
| 175 | 
            +
                    for si, (ph, pw) in enumerate(patch_hws):  # from small to large
         | 
| 176 | 
            +
                        if 0 <= self.prog_si < si: break  # progressive training: not supported yet, prog_si always -1
         | 
| 177 | 
            +
                        # find the nearest embedding
         | 
| 178 | 
            +
                        z_NC = F.interpolate(f_rest, size=(ph, pw), mode='bilinear').permute(0, 2, 3, 1).reshape(-1, C) if (
         | 
| 179 | 
            +
                                    si != SN - 1) else f_rest.permute(0, 2, 3, 1).reshape(-1, C)
         | 
| 180 | 
            +
                        if self.using_znorm:
         | 
| 181 | 
            +
                            z_NC = F.normalize(z_NC, dim=-1)
         | 
| 182 | 
            +
                            idx_N = torch.argmax(z_NC @ F.normalize(self.embedding.weight.data.T, dim=0), dim=1)
         | 
| 183 | 
            +
                        else:
         | 
| 184 | 
            +
                            d_no_grad = torch.sum(z_NC.square(), dim=1, keepdim=True) + torch.sum(
         | 
| 185 | 
            +
                                self.embedding.weight.data.square(), dim=1, keepdim=False)
         | 
| 186 | 
            +
                            d_no_grad.addmm_(z_NC, self.embedding.weight.data.T, alpha=-2, beta=1)  # (B*h*w, vocab_size)
         | 
| 187 | 
            +
                            idx_N = torch.argmin(d_no_grad, dim=1)
         | 
| 188 | 
            +
             | 
| 189 | 
            +
                        idx_Bhw = idx_N.view(B, ph, pw)
         | 
| 190 | 
            +
                        h_BChw = F.interpolate(self.embedding(idx_Bhw).permute(0, 3, 1, 2), size=(H, W),
         | 
| 191 | 
            +
                                               mode='bilinear').contiguous() if (si != SN - 1) else self.embedding(idx_Bhw).permute(
         | 
| 192 | 
            +
                            0, 3, 1, 2).contiguous()
         | 
| 193 | 
            +
                        h_BChw = self.quant_resi[si / (SN - 1)](h_BChw)
         | 
| 194 | 
            +
                        f_hat.add_(h_BChw)
         | 
| 195 | 
            +
                        f_rest.sub_(h_BChw)
         | 
| 196 | 
            +
                        f_hat_or_idx_Bl.append(f_hat.clone() if to_fhat else idx_N.reshape(B, ph * pw))
         | 
| 197 | 
            +
             | 
| 198 | 
            +
                    return f_hat_or_idx_Bl
         | 
| 199 | 
            +
             | 
| 200 | 
            +
                # ===================== idxBl_to_var_input: only used in VAR training, for getting teacher-forcing input =====================
         | 
| 201 | 
            +
                def idxBl_to_var_input(self, gt_ms_idx_Bl: List[torch.Tensor]) -> torch.Tensor:
         | 
| 202 | 
            +
                    next_scales = []
         | 
| 203 | 
            +
                    B = gt_ms_idx_Bl[0].shape[0]
         | 
| 204 | 
            +
                    C = self.Cvae
         | 
| 205 | 
            +
                    H = W = self.v_patch_nums[-1]
         | 
| 206 | 
            +
                    SN = len(self.v_patch_nums)
         | 
| 207 | 
            +
             | 
| 208 | 
            +
                    f_hat = gt_ms_idx_Bl[0].new_zeros(B, C, H, W, dtype=torch.float32)
         | 
| 209 | 
            +
                    pn_next: int = self.v_patch_nums[0]
         | 
| 210 | 
            +
                    for si in range(SN - 1):
         | 
| 211 | 
            +
                        if self.prog_si == 0 or (
         | 
| 212 | 
            +
                                0 <= self.prog_si - 1 < si): break  # progressive training: not supported yet, prog_si always -1
         | 
| 213 | 
            +
                        h_BChw = F.interpolate(self.embedding(gt_ms_idx_Bl[si]).transpose_(1, 2).view(B, C, pn_next, pn_next),
         | 
| 214 | 
            +
                                               size=(H, W), mode='bilinear')
         | 
| 215 | 
            +
                        f_hat.add_(self.quant_resi[si / (SN - 1)](h_BChw))
         | 
| 216 | 
            +
                        pn_next = self.v_patch_nums[si + 1]
         | 
| 217 | 
            +
                        next_scales.append(
         | 
| 218 | 
            +
                            F.interpolate(f_hat, size=(pn_next, pn_next), mode='bilinear').view(B, C, -1).transpose(1, 2))
         | 
| 219 | 
            +
                    return torch.cat(next_scales, dim=1) if len(next_scales) else None  # cat BlCs to BLC, this should be float32
         | 
| 220 | 
            +
             | 
| 221 | 
            +
                # ===================== get_next_autoregressive_input: only used in VAR inference, for getting next step's input =====================
         | 
| 222 | 
            +
                def get_next_autoregressive_input(self, si: int, SN: int, f_hat: torch.Tensor, h_BChw: torch.Tensor) -> Tuple[
         | 
| 223 | 
            +
                    Optional[torch.Tensor], torch.Tensor]:  # only used in VAR inference
         | 
| 224 | 
            +
                    HW = self.v_patch_nums[-1]
         | 
| 225 | 
            +
                    if si != SN - 1:
         | 
| 226 | 
            +
                        h = self.quant_resi[si / (SN - 1)](
         | 
| 227 | 
            +
                            F.interpolate(h_BChw, size=(HW, HW), mode='bilinear'))  # conv after upsample
         | 
| 228 | 
            +
                        f_hat.add_(h)
         | 
| 229 | 
            +
                        return f_hat, F.interpolate(f_hat, size=(self.v_patch_nums[si + 1], self.v_patch_nums[si + 1]),
         | 
| 230 | 
            +
                                                    mode='bilinear')
         | 
| 231 | 
            +
                    else:
         | 
| 232 | 
            +
                        h = self.quant_resi[si / (SN - 1)](h_BChw)
         | 
| 233 | 
            +
                        f_hat.add_(h)
         | 
| 234 | 
            +
                        return f_hat, f_hat
         | 
| 235 | 
            +
             | 
| 236 | 
            +
             | 
| 237 | 
            +
            class Phi(nn.Conv2d):
         | 
| 238 | 
            +
                def __init__(self, embed_dim, quant_resi):
         | 
| 239 | 
            +
                    ks = 3
         | 
| 240 | 
            +
                    super().__init__(in_channels=embed_dim, out_channels=embed_dim, kernel_size=ks, stride=1, padding=ks // 2)
         | 
| 241 | 
            +
                    self.resi_ratio = abs(quant_resi)
         | 
| 242 | 
            +
             | 
| 243 | 
            +
                def forward(self, h_BChw):
         | 
| 244 | 
            +
                    return h_BChw.mul(1 - self.resi_ratio) + super().forward(h_BChw).mul_(self.resi_ratio)
         | 
| 245 | 
            +
             | 
| 246 | 
            +
             | 
| 247 | 
            +
            class PhiShared(nn.Module):
         | 
| 248 | 
            +
                def __init__(self, qresi: Phi):
         | 
| 249 | 
            +
                    super().__init__()
         | 
| 250 | 
            +
                    self.qresi: Phi = qresi
         | 
| 251 | 
            +
             | 
| 252 | 
            +
                def __getitem__(self, _) -> Phi:
         | 
| 253 | 
            +
                    return self.qresi
         | 
| 254 | 
            +
             | 
| 255 | 
            +
             | 
| 256 | 
            +
            class PhiPartiallyShared(nn.Module):
         | 
| 257 | 
            +
                def __init__(self, qresi_ls: nn.ModuleList):
         | 
| 258 | 
            +
                    super().__init__()
         | 
| 259 | 
            +
                    self.qresi_ls = qresi_ls
         | 
| 260 | 
            +
                    K = len(qresi_ls)
         | 
| 261 | 
            +
                    self.ticks = np.linspace(1 / 3 / K, 1 - 1 / 3 / K, K) if K == 4 else np.linspace(1 / 2 / K, 1 - 1 / 2 / K, K)
         | 
| 262 | 
            +
             | 
| 263 | 
            +
                def __getitem__(self, at_from_0_to_1: float) -> Phi:
         | 
| 264 | 
            +
                    return self.qresi_ls[np.argmin(np.abs(self.ticks - at_from_0_to_1)).item()]
         | 
| 265 | 
            +
             | 
| 266 | 
            +
                def extra_repr(self) -> str:
         | 
| 267 | 
            +
                    return f'ticks={self.ticks}'
         | 
| 268 | 
            +
             | 
| 269 | 
            +
             | 
| 270 | 
            +
            class PhiNonShared(nn.ModuleList):
         | 
| 271 | 
            +
                def __init__(self, qresi: List):
         | 
| 272 | 
            +
                    super().__init__(qresi)
         | 
| 273 | 
            +
                    # self.qresi = qresi
         | 
| 274 | 
            +
                    K = len(qresi)
         | 
| 275 | 
            +
                    self.ticks = np.linspace(1 / 3 / K, 1 - 1 / 3 / K, K) if K == 4 else np.linspace(1 / 2 / K, 1 - 1 / 2 / K, K)
         | 
| 276 | 
            +
             | 
| 277 | 
            +
                def __getitem__(self, at_from_0_to_1: float) -> Phi:
         | 
| 278 | 
            +
                    return super().__getitem__(np.argmin(np.abs(self.ticks - at_from_0_to_1)).item())
         | 
| 279 | 
            +
             | 
| 280 | 
            +
                def extra_repr(self) -> str:
         | 
| 281 | 
            +
                    return f'ticks={self.ticks}'
         | 
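The fractional index passed to `self.quant_resi[si / (SN - 1)]` above is resolved by `PhiPartiallyShared.__getitem__` (or the shared / non-shared variants), which snaps the scale ratio to the nearest tick. A minimal sketch of that lookup, using made-up sizes (embed_dim=32, K=4 shared convs, SN=10 scales) rather than the checkpoint's actual configuration:

# Hypothetical sizes for illustration only; the real model builds these inside VectorQuantizer2.
embed_dim, K, SN = 32, 4, 10
phis = PhiPartiallyShared(nn.ModuleList([Phi(embed_dim, 0.5) for _ in range(K)]))
for si in range(SN):
    ratio = si / (SN - 1)   # scale index mapped into [0, 1]
    phi = phis[ratio]       # nearest-tick lookup via np.argmin over self.ticks
    # neighbouring scales resolve to the same shared Phi conv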
    	
models/var.py
ADDED
@@ -0,0 +1,360 @@
import math
from functools import partial
from typing import Optional, Tuple, Union

import torch
import torch.nn as nn
from huggingface_hub import PyTorchModelHubMixin

import dist
from models.basic_var import AdaLNBeforeHead, AdaLNSelfAttn
from models.helpers import gumbel_softmax_with_rng, sample_with_top_k_top_p_
from models.vqvae import VQVAE, VectorQuantizer2


class SharedAdaLin(nn.Linear):
    def forward(self, cond_BD):
        C = self.weight.shape[0] // 6
        return super().forward(cond_BD).view(-1, 1, 6, C)  # B16C


class VAR(nn.Module):
    def __init__(
            self, vae_local: VQVAE,
            num_classes=1000, depth=16, embed_dim=1024, num_heads=16, mlp_ratio=4., drop_rate=0., attn_drop_rate=0.,
            drop_path_rate=0.,
            norm_eps=1e-6, shared_aln=False, cond_drop_rate=0.1,
            attn_l2_norm=False,
            patch_nums=(1, 2, 3, 4, 5, 6, 8, 10, 13, 16),  # 10 steps by default
            flash_if_available=True, fused_if_available=True,
    ):
        super().__init__()
        # 0. hyperparameters
        assert embed_dim % num_heads == 0
        self.Cvae, self.V = vae_local.Cvae, vae_local.vocab_size
        self.depth, self.C, self.D, self.num_heads = depth, embed_dim, embed_dim, num_heads

        self.cond_drop_rate = cond_drop_rate
        self.prog_si = -1  # progressive training

        self.patch_nums: Tuple[int] = patch_nums
        self.L = sum(pn ** 2 for pn in self.patch_nums)
        self.first_l = self.patch_nums[0] ** 2
        self.begin_ends = []
        cur = 0
        for i, pn in enumerate(self.patch_nums):
            self.begin_ends.append((cur, cur + pn ** 2))
            cur += pn ** 2

        self.num_stages_minus_1 = len(self.patch_nums) - 1
        self.rng = torch.Generator(device="mps")

        # 1. input (word) embedding
        quant: VectorQuantizer2 = vae_local.quantize
        self.vae_proxy: Tuple[VQVAE] = (vae_local,)
        self.vae_quant_proxy: Tuple[VectorQuantizer2] = (quant,)
        self.word_embed = nn.Linear(self.Cvae, self.C)

        # 2. class embedding
        init_std = math.sqrt(1 / self.C / 3)
        self.num_classes = num_classes
        self.uniform_prob = torch.full((1, num_classes), fill_value=1.0 / num_classes, dtype=torch.float32,
                                       device=dist.get_device())
        self.class_emb = nn.Embedding(self.num_classes + 1, self.C)
        nn.init.trunc_normal_(self.class_emb.weight.data, mean=0, std=init_std)
        self.pos_start = nn.Parameter(torch.empty(1, self.first_l, self.C))
        nn.init.trunc_normal_(self.pos_start.data, mean=0, std=init_std)

        # 3. absolute position embedding
        pos_1LC = []
        for i, pn in enumerate(self.patch_nums):
            pe = torch.empty(1, pn * pn, self.C)
            nn.init.trunc_normal_(pe, mean=0, std=init_std)
            pos_1LC.append(pe)
        pos_1LC = torch.cat(pos_1LC, dim=1)  # 1, L, C
        assert tuple(pos_1LC.shape) == (1, self.L, self.C)
        self.pos_1LC = nn.Parameter(pos_1LC)
        # level embedding (similar to GPT's segment embedding, used to distinguish different levels of token pyramid)
        self.lvl_embed = nn.Embedding(len(self.patch_nums), self.C)
        nn.init.trunc_normal_(self.lvl_embed.weight.data, mean=0, std=init_std)

        # 4. backbone blocks
        self.shared_ada_lin = nn.Sequential(nn.SiLU(inplace=False),
                                            SharedAdaLin(self.D, 6 * self.C)) if shared_aln else nn.Identity()

        norm_layer = partial(nn.LayerNorm, eps=norm_eps)
        self.drop_path_rate = drop_path_rate
        dpr = [x.item() for x in
               torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule (linearly increasing)
        self.blocks = nn.ModuleList([
            AdaLNSelfAttn(
                cond_dim=self.D, shared_aln=shared_aln,
                block_idx=block_idx, embed_dim=self.C, norm_layer=norm_layer, num_heads=num_heads, mlp_ratio=mlp_ratio,
                drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[block_idx],
                last_drop_p=0 if block_idx == 0 else dpr[block_idx - 1],
                attn_l2_norm=attn_l2_norm,
                flash_if_available=flash_if_available, fused_if_available=fused_if_available,
            )
            for block_idx in range(depth)
        ])

        fused_add_norm_fns = [b.fused_add_norm_fn is not None for b in self.blocks]
        self.using_fused_add_norm_fn = any(fused_add_norm_fns)
        print(
            f'\n[constructor]  ==== flash_if_available={flash_if_available} ({sum(b.attn.using_flash for b in self.blocks)}/{self.depth}), fused_if_available={fused_if_available} (fusing_add_ln={sum(fused_add_norm_fns)}/{self.depth}, fusing_mlp={sum(b.ffn.fused_mlp_func is not None for b in self.blocks)}/{self.depth}) ==== \n'
            f'    [VAR config ] embed_dim={embed_dim}, num_heads={num_heads}, depth={depth}, mlp_ratio={mlp_ratio}\n'
            f'    [drop ratios ] drop_rate={drop_rate}, attn_drop_rate={attn_drop_rate}, drop_path_rate={drop_path_rate:g} ({torch.linspace(0, drop_path_rate, depth)})',
            end='\n\n', flush=True
        )

        # 5. attention mask used in training (for masking out the future)
        #    it won't be used in inference, since kv cache is enabled
        d: torch.Tensor = torch.cat([torch.full((pn * pn,), i) for i, pn in enumerate(self.patch_nums)]).view(1, self.L, 1)
        dT = d.transpose(1, 2)  # dT: 11L
        lvl_1L = dT[:, 0].contiguous()
        self.register_buffer('lvl_1L', lvl_1L)
        attn_bias_for_masking = torch.where(d >= dT, 0., -torch.inf).reshape(1, 1, self.L, self.L)
        self.register_buffer('attn_bias_for_masking', attn_bias_for_masking.contiguous())

        # 6. classifier head
        self.head_nm = AdaLNBeforeHead(self.C, self.D, norm_layer=norm_layer)
        self.head = nn.Linear(self.C, self.V)

    def get_logits(self, h_or_h_and_residual: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
                   cond_BD: Optional[torch.Tensor]):
        if not isinstance(h_or_h_and_residual, torch.Tensor):
            h, resi = h_or_h_and_residual  # fused_add_norm must be used
            h = resi + self.blocks[-1].drop_path(h)
        else:  # fused_add_norm is not used
            h = h_or_h_and_residual
        return self.head(self.head_nm(h.float(), cond_BD).float()).float()

    @torch.no_grad()
    def autoregressive_infer_cfg(
            self, B: int, label_B: Optional[Union[int, torch.LongTensor]],
            delta_condition: torch.Tensor, alpha: float, beta: float,
            g_seed: Optional[int] = None, cfg=1.5, top_k=0, top_p=0.0,
            more_smooth=False,
    ) -> torch.Tensor:  # returns reconstructed image (B, 3, H, W) in [0, 1]
        """
        Generate images using autoregressive inference with classifier-free guidance.
        :param B: batch size
        :param label_B: class labels; if None, randomly sampled
        :param delta_condition: tensor of shape (B, D)
        :param alpha: scalar weight for class embedding
        :param beta: scalar weight for delta_condition
        :param g_seed: random seed
        :param cfg: classifier-free guidance ratio
        :param top_k: top-k sampling
        :param top_p: top-p sampling
        :param more_smooth: smoothing the pred using gumbel softmax; only used in visualization, not used in FID/IS benchmarking
        :return: reconstructed images (B, 3, H, W)
        """
        if g_seed is None:
            rng = None
        else:
            self.rng.manual_seed(g_seed)
            rng = self.rng

        device = self.lvl_1L.device
        if label_B is None:
            label_B = torch.multinomial(self.uniform_prob, num_samples=B, replacement=True, generator=rng).reshape(B)
        elif isinstance(label_B, int):
            label_B = torch.full((B,), fill_value=self.num_classes if label_B < 0 else label_B, device=device)

        # Prepare labels for conditioned and unconditioned versions
        label_B_cond = label_B
        label_B_uncond = torch.full_like(label_B, fill_value=self.num_classes)
        label_B = torch.cat((label_B_cond, label_B_uncond), dim=0)  # shape (2B,)

        # Prepare delta_condition for conditioned and unconditioned versions
        delta_condition_uncond = torch.zeros_like(delta_condition)
        delta_condition = torch.cat((delta_condition, delta_condition_uncond), dim=0)  # shape (2B, D)

        class_emb = self.class_emb(label_B)  # shape (2B, D)
        cond_BD = alpha * class_emb + beta * delta_condition  # shape (2B, D)

        sos = cond_BD.unsqueeze(1).expand(2 * B, self.first_l, -1) + self.pos_start.expand(2 * B, self.first_l, -1)

        lvl_pos = self.lvl_embed(self.lvl_1L) + self.pos_1LC
        next_token_map = sos + lvl_pos[:, :self.first_l]

        cur_L = 0
        f_hat = sos.new_zeros(B, self.Cvae, self.patch_nums[-1], self.patch_nums[-1])

        for b in self.blocks:
            b.attn.kv_caching(True)
        for si, pn in enumerate(self.patch_nums):  # si: i-th segment
            ratio = si / self.num_stages_minus_1
            cur_L += pn * pn
            cond_BD_or_gss = self.shared_ada_lin(cond_BD)
            x = next_token_map

            for b in self.blocks:
                x = b(x=x, cond_BD=cond_BD_or_gss, attn_bias=None)
            logits_BlV = self.get_logits(x, cond_BD)

            t = cfg * ratio
            logits_BlV = (1 + t) * logits_BlV[:B] - t * logits_BlV[B:]

            idx_Bl = sample_with_top_k_top_p_(logits_BlV, rng=rng, top_k=top_k, top_p=top_p, num_samples=1)[:, :, 0]
            if not more_smooth:  # this is the default case
                h_BChw = self.vae_quant_proxy[0].embedding(idx_Bl)  # B, l, Cvae
            else:  # not used when evaluating FID/IS/Precision/Recall
                gum_t = max(0.27 * (1 - ratio * 0.95), 0.005)  # refer to mask-git
                h_BChw = gumbel_softmax_with_rng(logits_BlV.mul(1 + ratio), tau=gum_t, hard=False, dim=-1, rng=rng) @ \
                         self.vae_quant_proxy[0].embedding.weight.unsqueeze(0)

            h_BChw = h_BChw.transpose_(1, 2).reshape(B, self.Cvae, pn, pn)
            f_hat, next_token_map = self.vae_quant_proxy[0].get_next_autoregressive_input(si, len(self.patch_nums),
                                                                                          f_hat, h_BChw)
            if si != self.num_stages_minus_1:  # prepare for next stage
                next_token_map = next_token_map.view(B, self.Cvae, -1).transpose(1, 2)
                next_token_map = self.word_embed(next_token_map) + lvl_pos[:, cur_L:cur_L + self.patch_nums[si + 1] ** 2]
                next_token_map = next_token_map.repeat(2, 1, 1)  # double the batch sizes due to CFG

        for b in self.blocks:
            b.attn.kv_caching(False)
        return self.vae_proxy[0].fhat_to_img(f_hat).add_(1).mul_(0.5)  # de-normalize, from [-1, 1] to [0, 1]

    def forward(self, label_B: torch.LongTensor, x_BLCv_wo_first_l: torch.Tensor, delta_condition: torch.Tensor,
                alpha: float, beta: float) -> torch.Tensor:
        """
        :param label_B: class labels of shape (B,)
        :param x_BLCv_wo_first_l: teacher forcing input (B, self.L-self.first_l, self.Cvae)
        :param delta_condition: tensor of shape (B, D)
        :param alpha: scalar weight for class embedding
        :param beta: scalar weight for delta_condition
        :return: logits BLV, V is vocab_size
        """
        bg, ed = self.begin_ends[self.prog_si] if self.prog_si >= 0 else (0, self.L)
        B = x_BLCv_wo_first_l.shape[0]
        with torch.cuda.amp.autocast(enabled=False):
            # Implement conditional dropout
            drop_mask = torch.rand(B, device=label_B.device) < self.cond_drop_rate
            label_B_dropped = torch.where(drop_mask, self.num_classes, label_B)
            delta_condition_dropped = delta_condition.clone()
            delta_condition_dropped[drop_mask] = 0.0  # Drop delta_condition

            class_emb = self.class_emb(label_B_dropped)
            cond_BD = alpha * class_emb + beta * delta_condition_dropped

            sos = cond_BD.unsqueeze(1).expand(B, self.first_l, -1) + self.pos_start.expand(B, self.first_l, -1)

            if self.prog_si == 0:
                x_BLC = sos
            else:
                x_BLC = torch.cat((sos, self.word_embed(x_BLCv_wo_first_l.float())), dim=1)
            x_BLC += self.lvl_embed(self.lvl_1L[:, :ed].expand(B, -1)) + self.pos_1LC[:, :ed]  # lvl: BLC;  pos: 1LC

        attn_bias = self.attn_bias_for_masking[:, :, :ed, :ed]
        cond_BD_or_gss = self.shared_ada_lin(cond_BD)

        # hack: get the dtype if mixed precision is used
        temp = x_BLC.new_ones(8, 8)
        main_type = torch.matmul(temp, temp).dtype

        x_BLC = x_BLC.to(dtype=main_type)
        cond_BD_or_gss = cond_BD_or_gss.to(dtype=main_type)
        attn_bias = attn_bias.to(dtype=main_type)

        AdaLNSelfAttn.forward
        for i, b in enumerate(self.blocks):
            x_BLC = b(x=x_BLC, cond_BD=cond_BD_or_gss, attn_bias=attn_bias)
        x_BLC = self.get_logits(x_BLC.float(), cond_BD)

        if self.prog_si == 0:
            if isinstance(self.word_embed, nn.Linear):
                x_BLC[0, 0, 0] += self.word_embed.weight[0, 0] * 0 + self.word_embed.bias[0] * 0
            else:
                s = 0
                for p in self.word_embed.parameters():
                    if p.requires_grad:
                        s += p.view(-1)[0] * 0
                x_BLC[0, 0, 0] += s
        return x_BLC  # logits BLV, V is vocab_size

    def init_weights(self, init_adaln=0.5, init_adaln_gamma=1e-5, init_head=0.02, init_std=0.02, conv_std_or_gain=0.02):
        if init_std < 0: init_std = (1 / self.C / 3) ** 0.5  # init_std < 0: automated

        print(f'[init_weights] {type(self).__name__} with {init_std=:g}')
        for m in self.modules():
            with_weight = hasattr(m, 'weight') and m.weight is not None
            with_bias = hasattr(m, 'bias') and m.bias is not None
            if isinstance(m, nn.Linear):
                nn.init.trunc_normal_(m.weight.data, std=init_std)
                if with_bias: m.bias.data.zero_()
            elif isinstance(m, nn.Embedding):
                nn.init.trunc_normal_(m.weight.data, std=init_std)
                if m.padding_idx is not None: m.weight.data[m.padding_idx].zero_()
            elif isinstance(m, (
                    nn.LayerNorm, nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, nn.SyncBatchNorm, nn.GroupNorm,
                    nn.InstanceNorm1d, nn.InstanceNorm2d, nn.InstanceNorm3d)):
                if with_weight: m.weight.data.fill_(1.)
                if with_bias: m.bias.data.zero_()
            # conv: VAR has no conv, only VQVAE has conv
            elif isinstance(m, (
                    nn.Conv1d, nn.Conv2d, nn.Conv3d, nn.ConvTranspose1d, nn.ConvTranspose2d, nn.ConvTranspose3d)):
                if conv_std_or_gain > 0:
                    nn.init.trunc_normal_(m.weight.data, std=conv_std_or_gain)
                else:
                    nn.init.xavier_normal_(m.weight.data, gain=-conv_std_or_gain)
                if with_bias: m.bias.data.zero_()

        if init_head >= 0:
            if isinstance(self.head, nn.Linear):
                self.head.weight.data.mul_(init_head)
                self.head.bias.data.zero_()
            elif isinstance(self.head, nn.Sequential):
                self.head[-1].weight.data.mul_(init_head)
                self.head[-1].bias.data.zero_()

        if isinstance(self.head_nm, AdaLNBeforeHead):
            self.head_nm.ada_lin[-1].weight.data.mul_(init_adaln)
            if hasattr(self.head_nm.ada_lin[-1], 'bias') and self.head_nm.ada_lin[-1].bias is not None:
                self.head_nm.ada_lin[-1].bias.data.zero_()

        depth = len(self.blocks)
        for block_idx, sab in enumerate(self.blocks):
            sab: AdaLNSelfAttn
            sab.attn.proj.weight.data.div_(math.sqrt(2 * depth))
            sab.ffn.fc2.weight.data.div_(math.sqrt(2 * depth))
            if hasattr(sab.ffn, 'fcg') and sab.ffn.fcg is not None:
                nn.init.ones_(sab.ffn.fcg.bias)
                nn.init.trunc_normal_(sab.ffn.fcg.weight, std=1e-5)
            if hasattr(sab, 'ada_lin'):
                sab.ada_lin[-1].weight.data[2 * self.C:].mul_(init_adaln)
                sab.ada_lin[-1].weight.data[:2 * self.C].mul_(init_adaln_gamma)
                if hasattr(sab.ada_lin[-1], 'bias') and sab.ada_lin[-1].bias is not None:
                    sab.ada_lin[-1].bias.data.zero_()
            elif hasattr(sab, 'ada_gss'):
                sab.ada_gss.data[:, :, 2:].mul_(init_adaln)
                sab.ada_gss.data[:, :, :2].mul_(init_adaln_gamma)

    def extra_repr(self):
        return f'drop_path_rate={self.drop_path_rate:g}'


class VARHF(VAR, PyTorchModelHubMixin):
    def __init__(
            self,
            vae_kwargs,
            num_classes=1000, depth=16, embed_dim=1024, num_heads=16, mlp_ratio=4., drop_rate=0., attn_drop_rate=0.,
            drop_path_rate=0.,
            norm_eps=1e-6, shared_aln=False, cond_drop_rate=0.1,
            attn_l2_norm=False,
            patch_nums=(1, 2, 3, 4, 5, 6, 8, 10, 13, 16),  # 10 steps by default
            flash_if_available=True, fused_if_available=True,
    ):
        vae_local = VQVAE(**vae_kwargs)
        super().__init__(
            vae_local=vae_local,
            num_classes=num_classes, depth=depth, embed_dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio,
            drop_rate=drop_rate, attn_drop_rate=attn_drop_rate, drop_path_rate=drop_path_rate,
            norm_eps=norm_eps, shared_aln=shared_aln, cond_drop_rate=cond_drop_rate,
            attn_l2_norm=attn_l2_norm,
            patch_nums=patch_nums,
            flash_if_available=flash_if_available, fused_if_available=fused_if_available,
        )
    	
models/vqvae.py
ADDED
@@ -0,0 +1,95 @@
"""
References:
- VectorQuantizer2: https://github.com/CompVis/taming-transformers/blob/3ba01b241669f5ade541ce990f7650a3b8f65318/taming/modules/vqvae/quantize.py#L110
- GumbelQuantize: https://github.com/CompVis/taming-transformers/blob/3ba01b241669f5ade541ce990f7650a3b8f65318/taming/modules/vqvae/quantize.py#L213
- VQVAE (VQModel): https://github.com/CompVis/stable-diffusion/blob/21f890f9da3cfbeaba8e2ac3c425ee9e998d5229/ldm/models/autoencoder.py#L14
"""
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union

import torch
import torch.nn as nn

from .basic_vae import Decoder, Encoder
from .quant import VectorQuantizer2


class VQVAE(nn.Module):
    def __init__(
        self, vocab_size=4096, z_channels=32, ch=128, dropout=0.0,
        beta=0.25,              # commitment loss weight
        using_znorm=False,      # whether to normalize when computing the nearest neighbors
        quant_conv_ks=3,        # quant conv kernel size
        quant_resi=0.5,         # 0.5 means \phi(x) = 0.5conv(x) + (1-0.5)x
        share_quant_resi=4,     # use 4 \phi layers for K scales: partially-shared \phi
        default_qresi_counts=0, # if is 0: automatically set to len(v_patch_nums)
        v_patch_nums=(1, 2, 3, 4, 5, 6, 8, 10, 13, 16), # number of patches for each scale, h_{1 to K} = w_{1 to K} = v_patch_nums[k]
        test_mode=True,
    ):
        super().__init__()
        self.test_mode = test_mode
        self.V, self.Cvae = vocab_size, z_channels
        # ddconfig is copied from https://github.com/CompVis/latent-diffusion/blob/e66308c7f2e64cb581c6d27ab6fbeb846828253b/models/first_stage_models/vq-f16/config.yaml
        ddconfig = dict(
            dropout=dropout, ch=ch, z_channels=z_channels,
            in_channels=3, ch_mult=(1, 1, 2, 2, 4), num_res_blocks=2,   # from vq-f16/config.yaml above
            using_sa=True, using_mid_sa=True,                           # from vq-f16/config.yaml above
            # resamp_with_conv=True,   # always True, removed.
        )
        ddconfig.pop('double_z', None)  # only KL-VAE should use double_z=True
        self.encoder = Encoder(double_z=False, **ddconfig)
        self.decoder = Decoder(**ddconfig)

        self.vocab_size = vocab_size
        self.downsample = 2 ** (len(ddconfig['ch_mult']) - 1)
        self.quantize: VectorQuantizer2 = VectorQuantizer2(
            vocab_size=vocab_size, Cvae=self.Cvae, using_znorm=using_znorm, beta=beta,
            default_qresi_counts=default_qresi_counts, v_patch_nums=v_patch_nums, quant_resi=quant_resi, share_quant_resi=share_quant_resi,
        )
        self.quant_conv = torch.nn.Conv2d(self.Cvae, self.Cvae, quant_conv_ks, stride=1, padding=quant_conv_ks // 2)
        self.post_quant_conv = torch.nn.Conv2d(self.Cvae, self.Cvae, quant_conv_ks, stride=1, padding=quant_conv_ks // 2)

        if self.test_mode:
            self.eval()
            [p.requires_grad_(False) for p in self.parameters()]

    # ===================== `forward` is only used in VAE training =====================
    def forward(self, inp, ret_usages=False):   # -> rec_B3HW, idx_N, loss
        VectorQuantizer2.forward
        f_hat, usages, vq_loss = self.quantize(self.quant_conv(self.encoder(inp)), ret_usages=ret_usages)
        return self.decoder(self.post_quant_conv(f_hat)), usages, vq_loss
    # ===================== `forward` is only used in VAE training =====================

    def fhat_to_img(self, f_hat: torch.Tensor):
        return self.decoder(self.post_quant_conv(f_hat)).clamp_(-1, 1)

    def img_to_idxBl(self, inp_img_no_grad: torch.Tensor, v_patch_nums: Optional[Sequence[Union[int, Tuple[int, int]]]] = None) -> List[torch.LongTensor]:    # return List[Bl]
        f = self.quant_conv(self.encoder(inp_img_no_grad))
        return self.quantize.f_to_idxBl_or_fhat(f, to_fhat=False, v_patch_nums=v_patch_nums)

    def idxBl_to_img(self, ms_idx_Bl: List[torch.Tensor], same_shape: bool, last_one=False) -> Union[List[torch.Tensor], torch.Tensor]:
        B = ms_idx_Bl[0].shape[0]
        ms_h_BChw = []
        for idx_Bl in ms_idx_Bl:
            l = idx_Bl.shape[1]
            pn = round(l ** 0.5)
            ms_h_BChw.append(self.quantize.embedding(idx_Bl).transpose(1, 2).view(B, self.Cvae, pn, pn))
        return self.embed_to_img(ms_h_BChw=ms_h_BChw, all_to_max_scale=same_shape, last_one=last_one)

    def embed_to_img(self, ms_h_BChw: List[torch.Tensor], all_to_max_scale: bool, last_one=False) -> Union[List[torch.Tensor], torch.Tensor]:
        if last_one:
            return self.decoder(self.post_quant_conv(self.quantize.embed_to_fhat(ms_h_BChw, all_to_max_scale=all_to_max_scale, last_one=True))).clamp_(-1, 1)
        else:
            return [self.decoder(self.post_quant_conv(f_hat)).clamp_(-1, 1) for f_hat in self.quantize.embed_to_fhat(ms_h_BChw, all_to_max_scale=all_to_max_scale, last_one=False)]

    def img_to_reconstructed_img(self, x, v_patch_nums: Optional[Sequence[Union[int, Tuple[int, int]]]] = None, last_one=False) -> List[torch.Tensor]:
        f = self.quant_conv(self.encoder(x))
        ls_f_hat_BChw = self.quantize.f_to_idxBl_or_fhat(f, to_fhat=True, v_patch_nums=v_patch_nums)
        if last_one:
            return self.decoder(self.post_quant_conv(ls_f_hat_BChw[-1])).clamp_(-1, 1)
        else:
            return [self.decoder(self.post_quant_conv(f_hat)).clamp_(-1, 1) for f_hat in ls_f_hat_BChw]

    def load_state_dict(self, state_dict: Dict[str, Any], strict=True, assign=False):
        if 'quantize.ema_vocab_hit_SV' in state_dict and state_dict['quantize.ema_vocab_hit_SV'].shape[0] != self.quantize.ema_vocab_hit_SV.shape[0]:
            state_dict['quantize.ema_vocab_hit_SV'] = self.quantize.ema_vocab_hit_SV
        return super().load_state_dict(state_dict=state_dict, strict=strict, assign=assign)
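A minimal round-trip sketch for the wrapper above, assuming a VQVAE instance `vqvae` with loaded weights and the default patch schedule (downsample factor 16, so 256x256 inputs); `img` is a hypothetical batch normalized to [-1, 1]:

import torch

img = torch.randn(1, 3, 256, 256).clamp_(-1, 1)   # hypothetical input batch in [-1, 1]
with torch.no_grad():
    idx_Bl = vqvae.img_to_idxBl(img)               # one (B, pn*pn) LongTensor of code ids per scale
    recon = vqvae.idxBl_to_img(idx_Bl, same_shape=True, last_one=True)   # (B, 3, 256, 256) in [-1, 1]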
    	
utils/amp_sc.py
ADDED
@@ -0,0 +1,89 @@
import math
from typing import List, Optional, Tuple, Union

import torch


class NullCtx:
    def __enter__(self):
        pass

    def __exit__(self, exc_type, exc_val, exc_tb):
        pass


class AmpOptimizer:
    def __init__(
        self,
        mixed_precision: int,
        optimizer: torch.optim.Optimizer, names: List[str], paras: List[torch.nn.Parameter],
        grad_clip: float, n_gradient_accumulation: int = 1,
    ):
        self.enable_amp = mixed_precision > 0
        self.using_fp16_rather_bf16 = mixed_precision == 1

        if self.enable_amp:
            self.amp_ctx = torch.autocast('cuda', enabled=True, dtype=torch.float16 if self.using_fp16_rather_bf16 else torch.bfloat16, cache_enabled=True)
            self.scaler = torch.cuda.amp.GradScaler(init_scale=2. ** 11, growth_interval=1000) if self.using_fp16_rather_bf16 else None  # only fp16 needs a scaler
        else:
            self.amp_ctx = NullCtx()
            self.scaler = None

        self.optimizer, self.names, self.paras = optimizer, names, paras   # paras have been filtered, so every one of them requires grad
        self.grad_clip = grad_clip
        self.early_clipping = self.grad_clip > 0 and not hasattr(optimizer, 'global_grad_norm')
        self.late_clipping = self.grad_clip > 0 and hasattr(optimizer, 'global_grad_norm')

        self.r_accu = 1 / n_gradient_accumulation   # r_accu == 1.0 / n_gradient_accumulation

    def backward_clip_step(
        self, stepping: bool, loss: torch.Tensor,
    ) -> Tuple[Optional[Union[torch.Tensor, float]], Optional[float]]:
        # backward
        loss = loss.mul(self.r_accu)   # r_accu == 1.0 / n_gradient_accumulation
        orig_norm = scaler_sc = None
        if self.scaler is not None:
            self.scaler.scale(loss).backward(retain_graph=False, create_graph=False)
        else:
            loss.backward(retain_graph=False, create_graph=False)

        if stepping:
            if self.scaler is not None: self.scaler.unscale_(self.optimizer)
            if self.early_clipping:
                orig_norm = torch.nn.utils.clip_grad_norm_(self.paras, self.grad_clip)

            if self.scaler is not None:
                self.scaler.step(self.optimizer)
                scaler_sc: float = self.scaler.get_scale()
                if scaler_sc > 32768.:  # fp16 overflows above 65536, so letting the scale grow past 32768 could be dangerous
                    self.scaler.update(new_scale=32768.)
                else:
                    self.scaler.update()
                try:
                    scaler_sc = float(math.log2(scaler_sc))
                except Exception as e:
                    print(f'[scaler_sc = {scaler_sc}]\n' * 15, flush=True)
                    raise e
            else:
                self.optimizer.step()

            if self.late_clipping:
                orig_norm = self.optimizer.global_grad_norm

            self.optimizer.zero_grad(set_to_none=True)

        return orig_norm, scaler_sc

    def state_dict(self):
        return {
            'optimizer': self.optimizer.state_dict()
        } if self.scaler is None else {
            'scaler': self.scaler.state_dict(),
            'optimizer': self.optimizer.state_dict()
        }

    def load_state_dict(self, state, strict=True):
        if self.scaler is not None:
            try: self.scaler.load_state_dict(state['scaler'])
            except Exception as e: print(f'[fp16 load_state_dict err] {e}')
        self.optimizer.load_state_dict(state['optimizer'])
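A minimal usage sketch of the AmpOptimizer above (not part of the commit; it assumes the module lands at utils/amp_sc.py, uses a toy model with plain AdamW, and sets mixed_precision=0 so it also runs without a GPU):

import torch
from utils.amp_sc import AmpOptimizer  # assumed module path

model = torch.nn.Linear(8, 8)
names, paras = zip(*[(n, p) for n, p in model.named_parameters() if p.requires_grad])
base_opt = torch.optim.AdamW(list(paras), lr=1e-4)
amp_opt = AmpOptimizer(mixed_precision=0, optimizer=base_opt, names=list(names), paras=list(paras),
                       grad_clip=2.0, n_gradient_accumulation=1)

x = torch.randn(4, 8)
with amp_opt.amp_ctx:                    # NullCtx here; an autocast context when mixed_precision > 0
    loss = model(x).pow(2).mean()
grad_norm, scale_log2 = amp_opt.backward_clip_step(stepping=True, loss=loss)  # backward + clip + step + zero_grad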
    	
        utils/arg_util.py
    ADDED
    
    | @@ -0,0 +1,284 @@ | |
import json
import os
import random
import re
import subprocess
import sys
import time
from collections import OrderedDict
from typing import Optional, Union

import numpy as np
import torch

try:
    from tap import Tap
except ImportError as e:
    print(f'`>>>>>>>> from tap import Tap` failed, please run:      pip3 install typed-argument-parser     <<<<<<<<', file=sys.stderr, flush=True)
    print(f'`>>>>>>>> from tap import Tap` failed, please run:      pip3 install typed-argument-parser     <<<<<<<<', file=sys.stderr, flush=True)
    time.sleep(5)
    raise e

import dist


class Args(Tap):
    data_path: str = '/path/to/imagenet'
    exp_name: str = 'text'
    
    # VAE
    vfast: int = 0      # torch.compile VAE; 0: do not compile; 1: compile with 'reduce-overhead'; 2: compile with 'max-autotune'
    # VAR
    tfast: int = 0      # torch.compile VAR; 0: do not compile; 1: compile with 'reduce-overhead'; 2: compile with 'max-autotune'
    depth: int = 16     # VAR depth
    # VAR initialization
    ini: float = -1     # -1: automated model parameter initialization
    hd: float = 0.02    # head.w *= hd
    aln: float = 0.5    # the multiplier of ada_lin.w's initialization
    alng: float = 1e-5  # the multiplier of ada_lin.w[gamma channels]'s initialization
    # VAR optimization
    fp16: int = 0           # 1: use fp16, 2: bf16
    tblr: float = 1e-4      # base lr
    tlr: float = None       # lr = base lr * (bs / 256)
    twd: float = 0.05       # initial wd
    twde: float = 0         # final wd, =twde or twd
    tclip: float = 2.       # <=0 for no grad clipping
    ls: float = 0.0         # label smoothing
    
    bs: int = 768           # global batch size
    batch_size: int = 0     # [automatically set; don't specify this] batch size per GPU = round(args.bs / args.ac / dist.get_world_size() / 8) * 8
    glb_batch_size: int = 0 # [automatically set; don't specify this] global batch size = args.batch_size * dist.get_world_size()
    ac: int = 1             # gradient accumulation
    
    ep: int = 250
    wp: float = 0
    wp0: float = 0.005      # initial lr ratio at the beginning of lr warm-up
    wpe: float = 0.01       # final lr ratio at the end of training
    sche: str = 'lin0'      # lr schedule
    
    opt: str = 'adamw'      # lion: https://cloud.tencent.com/developer/article/2336657?areaId=106001 lr=5e-5 (0.25x) wd=0.8 (8x); Lion needs a large bs to work
    afuse: bool = True      # fused adamw
    
    # other hps
    saln: bool = False      # whether to use shared adaln
    anorm: bool = True      # whether to use L2-normalized attention
    fuse: bool = True       # whether to use fused ops like flash attn, xformers, fused MLP, fused LayerNorm, etc.
    
    # data
    pn: str = '1_2_3_4_5_6_8_10_13_16'
    patch_size: int = 16
    patch_nums: tuple = None    # [automatically set; don't specify this] = tuple(map(int, args.pn.replace('-', '_').split('_')))
    resos: tuple = None         # [automatically set; don't specify this] = tuple(pn * args.patch_size for pn in args.patch_nums)
    
    data_load_reso: int = None  # [automatically set; don't specify this] would be max(patch_nums) * patch_size
    mid_reso: float = 1.125     # aug: first resize to mid_reso = 1.125 * data_load_reso, then crop to data_load_reso
    hflip: bool = False         # augmentation: horizontal flip
    workers: int = 0        # num workers; 0: auto, -1: don't use multiprocessing in DataLoader
    
    # progressive training
    pg: float = 0.0         # >0 to use progressive training during [0%, this] of training
    pg0: int = 4            # progressive initial stage, 0: from the 1st token map, 1: from the 2nd token map, etc.
    pgwp: float = 0         # num of warmup epochs at each progressive stage
    
    # would be automatically set at runtime
    cmd: str = ' '.join(sys.argv[1:])  # [automatically set; don't specify this]
    branch: str = subprocess.check_output(f'git symbolic-ref --short HEAD 2>/dev/null || git rev-parse HEAD', shell=True).decode('utf-8').strip() or '[unknown]' # [automatically set; don't specify this]
    commit_id: str = subprocess.check_output(f'git rev-parse HEAD', shell=True).decode('utf-8').strip() or '[unknown]'  # [automatically set; don't specify this]
    commit_msg: str = (subprocess.check_output(f'git log -1', shell=True).decode('utf-8').strip().splitlines() or ['[unknown]'])[-1].strip()    # [automatically set; don't specify this]
    acc_mean: float = None      # [automatically set; don't specify this]
    acc_tail: float = None      # [automatically set; don't specify this]
    L_mean: float = None        # [automatically set; don't specify this]
    L_tail: float = None        # [automatically set; don't specify this]
    vacc_mean: float = None     # [automatically set; don't specify this]
    vacc_tail: float = None     # [automatically set; don't specify this]
    vL_mean: float = None       # [automatically set; don't specify this]
    vL_tail: float = None       # [automatically set; don't specify this]
    grad_norm: float = None     # [automatically set; don't specify this]
    cur_lr: float = None        # [automatically set; don't specify this]
    cur_wd: float = None        # [automatically set; don't specify this]
    cur_it: str = ''            # [automatically set; don't specify this]
    cur_ep: str = ''            # [automatically set; don't specify this]
    remain_time: str = ''       # [automatically set; don't specify this]
    finish_time: str = ''       # [automatically set; don't specify this]
    
    # environment
    local_out_dir_path: str = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'local_output')  # [automatically set; don't specify this]
    tb_log_dir_path: str = '...tb-...'  # [automatically set; don't specify this]
    log_txt_path: str = '...'           # [automatically set; don't specify this]
    last_ckpt_path: str = '...'         # [automatically set; don't specify this]
    
    tf32: bool = True       # whether to use TensorFloat32
    device: str = 'cpu'     # [automatically set; don't specify this]
    seed: int = None        # seed
    def seed_everything(self, benchmark: bool):
        torch.backends.cudnn.enabled = True
        torch.backends.cudnn.benchmark = benchmark
        if self.seed is None:
            torch.backends.cudnn.deterministic = False
        else:
            torch.backends.cudnn.deterministic = True
            seed = self.seed * dist.get_world_size() + dist.get_rank()
            os.environ['PYTHONHASHSEED'] = str(seed)
            random.seed(seed)
            np.random.seed(seed)
            torch.manual_seed(seed)
            if torch.cuda.is_available():
                torch.cuda.manual_seed(seed)
                torch.cuda.manual_seed_all(seed)
    same_seed_for_all_ranks: int = 0     # this is only for the distributed sampler
    def get_different_generator_for_each_rank(self) -> Optional[torch.Generator]:   # for random augmentation
        if self.seed is None:
            return None
        g = torch.Generator()
        g.manual_seed(self.seed * dist.get_world_size() + dist.get_rank())
        return g
    
    local_debug: bool = 'KEVIN_LOCAL' in os.environ
    dbg_nan: bool = False   # 'KEVIN_LOCAL' in os.environ
    
    def compile_model(self, m, fast):
        if fast == 0 or self.local_debug:
            return m
        return torch.compile(m, mode={
            1: 'reduce-overhead',
            2: 'max-autotune',
            3: 'default',
        }[fast]) if hasattr(torch, 'compile') else m
    
    def state_dict(self, key_ordered=True) -> Union[OrderedDict, dict]:
        d = (OrderedDict if key_ordered else dict)()
        # self.as_dict() would contain methods, but we only need variables
        for k in self.class_variables.keys():
            if k not in {'device'}:     # these are not serializable
                d[k] = getattr(self, k)
        return d
    
    def load_state_dict(self, d: Union[OrderedDict, dict, str]):
        if isinstance(d, str):  # for compatibility with old versions
            d: dict = eval('\n'.join([l for l in d.splitlines() if '<bound' not in l and 'device(' not in l]))
        for k in d.keys():
            try:
                setattr(self, k, d[k])
            except Exception as e:
                print(f'k={k}, v={d[k]}')
                raise e
    
    @staticmethod
    def set_tf32(tf32: bool):
        if torch.cuda.is_available():
            torch.backends.cudnn.allow_tf32 = bool(tf32)
            torch.backends.cuda.matmul.allow_tf32 = bool(tf32)
            if hasattr(torch, 'set_float32_matmul_precision'):
                torch.set_float32_matmul_precision('high' if tf32 else 'highest')
                print(f'[tf32] [precis] torch.get_float32_matmul_precision(): {torch.get_float32_matmul_precision()}')
            print(f'[tf32] [ conv ] torch.backends.cudnn.allow_tf32: {torch.backends.cudnn.allow_tf32}')
            print(f'[tf32] [matmul] torch.backends.cuda.matmul.allow_tf32: {torch.backends.cuda.matmul.allow_tf32}')
    
    def dump_log(self):
        if not dist.is_local_master():
            return
        if '1/' in self.cur_ep: # first time dumping the log
            with open(self.log_txt_path, 'w') as fp:
                json.dump({'is_master': dist.is_master(), 'name': self.exp_name, 'cmd': self.cmd, 'commit': self.commit_id, 'branch': self.branch, 'tb_log_dir_path': self.tb_log_dir_path}, fp, indent=0)
                fp.write('\n')
        
        log_dict = {}
        for k, v in {
            'it': self.cur_it, 'ep': self.cur_ep,
            'lr': self.cur_lr, 'wd': self.cur_wd, 'grad_norm': self.grad_norm,
            'L_mean': self.L_mean, 'L_tail': self.L_tail, 'acc_mean': self.acc_mean, 'acc_tail': self.acc_tail,
            'vL_mean': self.vL_mean, 'vL_tail': self.vL_tail, 'vacc_mean': self.vacc_mean, 'vacc_tail': self.vacc_tail,
            'remain_time': self.remain_time, 'finish_time': self.finish_time,
        }.items():
            if hasattr(v, 'item'): v = v.item()
            log_dict[k] = v
        with open(self.log_txt_path, 'a') as fp:
            fp.write(f'{log_dict}\n')
    
    def __str__(self):
        s = []
        for k in self.class_variables.keys():
            if k not in {'device', 'dbg_ks_fp'}:     # these are not serializable
                s.append(f'  {k:20s}: {getattr(self, k)}')
        s = '\n'.join(s)
        return f'{{\n{s}\n}}\n'


def init_dist_and_get_args():
    for i in range(len(sys.argv)):
        if sys.argv[i].startswith('--local-rank=') or sys.argv[i].startswith('--local_rank='):
            del sys.argv[i]
            break
    args = Args(explicit_bool=True).parse_args(known_only=True)
    if args.local_debug:
        args.pn = '1_2_3'
        args.seed = 1
        args.aln = 1e-2
        args.alng = 1e-5
        args.saln = False
        args.afuse = False
        args.pg = 0.8
        args.pg0 = 1
    else:
        if args.data_path == '/path/to/imagenet':
            raise ValueError(f'{"*"*40}  please specify --data_path=/path/to/imagenet  {"*"*40}')
    
    # warn about unexpected extra args
    if len(args.extra_args) > 0:
        print(f'======================================================================================')
        print(f'=========================== WARNING: UNEXPECTED EXTRA ARGS ===========================\n{args.extra_args}')
        print(f'=========================== WARNING: UNEXPECTED EXTRA ARGS ===========================')
        print(f'======================================================================================\n\n')
    
    # init torch distributed
    from utils import misc
    os.makedirs(args.local_out_dir_path, exist_ok=True)
    misc.init_distributed_mode(local_out_path=args.local_out_dir_path, timeout=30)
    
    # set env
    args.set_tf32(args.tf32)
    args.seed_everything(benchmark=args.pg == 0)
    
    # update args: data loading
    args.device = dist.get_device()
    if args.pn == '256':
        args.pn = '1_2_3_4_5_6_8_10_13_16'
    elif args.pn == '512':
        args.pn = '1_2_3_4_6_9_13_18_24_32'
    elif args.pn == '1024':
        args.pn = '1_2_3_4_5_7_9_12_16_21_27_36_48_64'
    args.patch_nums = tuple(map(int, args.pn.replace('-', '_').split('_')))
    args.resos = tuple(pn * args.patch_size for pn in args.patch_nums)
    args.data_load_reso = max(args.resos)
    
    # update args: bs and lr
    bs_per_gpu = round(args.bs / args.ac / dist.get_world_size())
    args.batch_size = bs_per_gpu
    args.bs = args.glb_batch_size = args.batch_size * dist.get_world_size()
    args.workers = min(max(0, args.workers), args.batch_size)
    
    args.tlr = args.ac * args.tblr * args.glb_batch_size / 256
    args.twde = args.twde or args.twd
    
    if args.wp == 0:
        args.wp = args.ep * 1/50
    
    # update args: progressive training
    if args.pgwp == 0:
        args.pgwp = args.ep * 1/300
    if args.pg > 0:
        args.sche = f'lin{args.pg:g}'
    
    # update args: paths
    args.log_txt_path = os.path.join(args.local_out_dir_path, 'log.txt')
    args.last_ckpt_path = os.path.join(args.local_out_dir_path, f'ar-ckpt-last.pth')
    _reg_valid_name = re.compile(r'[^\w\-+,.]')
    tb_name = _reg_valid_name.sub(
        '_',
        f'tb-VARd{args.depth}'
        f'__pn{args.pn}'
        f'__b{args.bs}ep{args.ep}{args.opt[:4]}lr{args.tblr:g}wd{args.twd:g}'
    )
    args.tb_log_dir_path = os.path.join(args.local_out_dir_path, tb_name)
    
    return args
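As a small illustration of the batch-size/lr bookkeeping done in init_dist_and_get_args above (a sketch, not part of the commit; the helper name is hypothetical), the effective peak lr scales linearly with the global batch size relative to a reference of 256:

def scaled_peak_lr(tblr: float, ac: int, glb_batch_size: int) -> float:
    # mirrors: args.tlr = args.ac * args.tblr * args.glb_batch_size / 256
    return ac * tblr * glb_batch_size / 256

print(scaled_peak_lr(1e-4, ac=1, glb_batch_size=768))  # 3e-04 for the default bs=768 and tblr=1e-4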
    	
        utils/data.py
    ADDED
    
    | @@ -0,0 +1,54 @@ | |
import os.path as osp

import PIL.Image as PImage
from torchvision.datasets.folder import DatasetFolder, IMG_EXTENSIONS
from torchvision.transforms import InterpolationMode, transforms


def normalize_01_into_pm1(x):  # normalize x from [0, 1] to [-1, 1] by (x*2) - 1
    return x.add(x).add_(-1)


def build_dataset(
    data_path: str, final_reso: int,
    hflip=False, mid_reso=1.125,
):
    # build augmentations
    mid_reso = round(mid_reso * final_reso)  # first resize to mid_reso, then crop to final_reso
    train_aug, val_aug = [
        transforms.Resize(mid_reso, interpolation=InterpolationMode.LANCZOS),  # transforms.Resize: resize the shorter edge to mid_reso
        transforms.RandomCrop((final_reso, final_reso)),
        transforms.ToTensor(), normalize_01_into_pm1,
    ], [
        transforms.Resize(mid_reso, interpolation=InterpolationMode.LANCZOS),  # transforms.Resize: resize the shorter edge to mid_reso
        transforms.CenterCrop((final_reso, final_reso)),
        transforms.ToTensor(), normalize_01_into_pm1,
    ]
    if hflip: train_aug.insert(0, transforms.RandomHorizontalFlip())
    train_aug, val_aug = transforms.Compose(train_aug), transforms.Compose(val_aug)
    
    # build dataset
    train_set = DatasetFolder(root=osp.join(data_path, 'train'), loader=pil_loader, extensions=IMG_EXTENSIONS, transform=train_aug)
    val_set = DatasetFolder(root=osp.join(data_path, 'val'), loader=pil_loader, extensions=IMG_EXTENSIONS, transform=val_aug)
    num_classes = 1000
    print(f'[Dataset] {len(train_set)=}, {len(val_set)=}, {num_classes=}')
    print_aug(train_aug, '[train]')
    print_aug(val_aug, '[val]')
    
    return num_classes, train_set, val_set


def pil_loader(path):
    with open(path, 'rb') as f:
        img: PImage.Image = PImage.open(f).convert('RGB')
    return img


def print_aug(transform, label):
    print(f'Transform {label} = ')
    if hasattr(transform, 'transforms'):
        for t in transform.transforms:
            print(t)
    else:
        print(transform)
    print('---------------------------\n')
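A minimal usage sketch of build_dataset (not part of the commit; it assumes an ImageNet-style train/val folder layout under data_path):

from utils.data import build_dataset

num_classes, train_set, val_set = build_dataset('/path/to/imagenet', final_reso=256, hflip=True)
img, label = train_set[0]   # img is a 3x256x256 tensor normalized to [-1, 1]
print(num_classes, img.shape, label)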
    	
        utils/data_sampler.py
    ADDED
    
    | @@ -0,0 +1,103 @@ | |
import numpy as np
import torch
from torch.utils.data.sampler import Sampler


class EvalDistributedSampler(Sampler):
    def __init__(self, dataset, num_replicas, rank):
        seps = np.linspace(0, len(dataset), num_replicas+1, dtype=int)
        beg, end = seps[:-1], seps[1:]
        beg, end = beg[rank], end[rank]
        self.indices = tuple(range(beg, end))
    
    def __iter__(self):
        return iter(self.indices)
    
    def __len__(self) -> int:
        return len(self.indices)


class InfiniteBatchSampler(Sampler):
    def __init__(self, dataset_len, batch_size, seed_for_all_rank=0, fill_last=False, shuffle=True, drop_last=False, start_ep=0, start_it=0):
        self.dataset_len = dataset_len
        self.batch_size = batch_size
        self.iters_per_ep = dataset_len // batch_size if drop_last else (dataset_len + batch_size - 1) // batch_size
        self.max_p = self.iters_per_ep * batch_size
        self.fill_last = fill_last
        self.shuffle = shuffle
        self.epoch = start_ep
        self.same_seed_for_all_ranks = seed_for_all_rank
        self.indices = self.gener_indices()
        self.start_ep, self.start_it = start_ep, start_it
    
    def gener_indices(self):
        if self.shuffle:
            g = torch.Generator()
            g.manual_seed(self.epoch + self.same_seed_for_all_ranks)
            indices = torch.randperm(self.dataset_len, generator=g).numpy()
        else:
            indices = torch.arange(self.dataset_len).numpy()
        
        tails = self.batch_size - (self.dataset_len % self.batch_size)
        if tails != self.batch_size and self.fill_last:
            tails = indices[:tails]
            np.random.shuffle(indices)
            indices = np.concatenate((indices, tails))
        
        # built-in list/tuple is faster than np.ndarray (when collating the data via a for-loop)
        # noinspection PyTypeChecker
        return tuple(indices.tolist())
    
    def __iter__(self):
        self.epoch = self.start_ep
        while True:
            self.epoch += 1
            p = (self.start_it * self.batch_size) if self.epoch == self.start_ep else 0
            while p < self.max_p:
                q = p + self.batch_size
                yield self.indices[p:q]
                p = q
            if self.shuffle:
                self.indices = self.gener_indices()
    
    def __len__(self):
        return self.iters_per_ep


class DistInfiniteBatchSampler(InfiniteBatchSampler):
    def __init__(self, world_size, rank, dataset_len, glb_batch_size, same_seed_for_all_ranks=0, repeated_aug=0, fill_last=False, shuffle=True, start_ep=0, start_it=0):
        assert glb_batch_size % world_size == 0
        self.world_size, self.rank = world_size, rank
        self.dataset_len = dataset_len
        self.glb_batch_size = glb_batch_size
        self.batch_size = glb_batch_size // world_size
        
        self.iters_per_ep = (dataset_len + glb_batch_size - 1) // glb_batch_size
        self.fill_last = fill_last
        self.shuffle = shuffle
        self.repeated_aug = repeated_aug
        self.epoch = start_ep
        self.same_seed_for_all_ranks = same_seed_for_all_ranks
        self.indices = self.gener_indices()
        self.start_ep, self.start_it = start_ep, start_it
    
    def gener_indices(self):
        global_max_p = self.iters_per_ep * self.glb_batch_size  # global_max_p % world_size must be 0 because glb_batch_size % world_size == 0
        # print(f'global_max_p = iters_per_ep({self.iters_per_ep}) * glb_batch_size({self.glb_batch_size}) = {global_max_p}')
        if self.shuffle:
            g = torch.Generator()
            g.manual_seed(self.epoch + self.same_seed_for_all_ranks)
            global_indices = torch.randperm(self.dataset_len, generator=g)
            if self.repeated_aug > 1:
                global_indices = global_indices[:(self.dataset_len + self.repeated_aug - 1) // self.repeated_aug].repeat_interleave(self.repeated_aug, dim=0)[:global_max_p]
        else:
            global_indices = torch.arange(self.dataset_len)
        filling = global_max_p - global_indices.shape[0]
        if filling > 0 and self.fill_last:
            global_indices = torch.cat((global_indices, global_indices[:filling]))
        # global_indices = tuple(global_indices.numpy().tolist())
        
        seps = torch.linspace(0, global_indices.shape[0], self.world_size + 1, dtype=torch.int)
        local_indices = global_indices[seps[self.rank].item():seps[self.rank + 1].item()].tolist()
        self.max_p = len(local_indices)
        return local_indices
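A minimal sketch of how the sampler above can drive a DataLoader (not part of the commit; it assumes a single process, i.e. world_size=1 and rank=0, and a toy 10-sample dataset):

import torch
from torch.utils.data import DataLoader, TensorDataset
from utils.data_sampler import DistInfiniteBatchSampler

ds = TensorDataset(torch.arange(10).float())
sampler = DistInfiniteBatchSampler(world_size=1, rank=0, dataset_len=len(ds),
                                   glb_batch_size=4, fill_last=True, shuffle=True)
loader = DataLoader(ds, batch_sampler=sampler, num_workers=0)
print(len(loader))          # iters_per_ep = ceil(10 / 4) = 3
print(next(iter(loader)))   # first batch of 4 samples; iteration never stops on its own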
    	
        utils/lr_control.py
    ADDED
    
    | @@ -0,0 +1,108 @@ | |
| 1 | 
            +
            import math
         | 
| 2 | 
            +
            from pprint import pformat
         | 
| 3 | 
            +
            from typing import Tuple, List, Dict, Union
         | 
| 4 | 
            +
             | 
| 5 | 
            +
            import torch.nn
         | 
| 6 | 
            +
             | 
| 7 | 
            +
            import dist
         | 
| 8 | 
            +
             | 
| 9 | 
            +
             | 
| 10 | 
            +
            def lr_wd_annealing(sche_type: str, optimizer, peak_lr, wd, wd_end, cur_it, wp_it, max_it, wp0=0.005, wpe=0.001):
         | 
| 11 | 
            +
                """Decay the learning rate with half-cycle cosine after warmup"""
         | 
| 12 | 
            +
                wp_it = round(wp_it)
         | 
| 13 | 
            +
                
         | 
| 14 | 
            +
                if cur_it < wp_it:
         | 
| 15 | 
            +
                    cur_lr = wp0 + (1-wp0) * cur_it / wp_it
         | 
| 16 | 
            +
                else:
         | 
| 17 | 
            +
                    pasd = (cur_it - wp_it) / (max_it-1 - wp_it)   # [0, 1]
         | 
| 18 | 
            +
                    rest = 1 - pasd     # [1, 0]
         | 
| 19 | 
            +
                    if sche_type == 'cos':
         | 
| 20 | 
            +
                        cur_lr = wpe + (1-wpe) * (0.5 + 0.5 * math.cos(math.pi * pasd))
         | 
| 21 | 
            +
                    elif sche_type == 'lin':
         | 
| 22 | 
            +
                        T = 0.15; max_rest = 1-T
         | 
| 23 | 
            +
                        if pasd < T: cur_lr = 1
         | 
| 24 | 
            +
                        else: cur_lr = wpe + (1-wpe) * rest / max_rest  # 1 to wpe
         | 
| 25 | 
            +
                    elif sche_type == 'lin0':
         | 
| 26 | 
            +
                        T = 0.05; max_rest = 1-T
         | 
| 27 | 
            +
                        if pasd < T: cur_lr = 1
         | 
| 28 | 
            +
                        else: cur_lr = wpe + (1-wpe) * rest / max_rest
         | 
| 29 | 
            +
                    elif sche_type == 'lin00':
         | 
| 30 | 
            +
                        cur_lr = wpe + (1-wpe) * rest
         | 
| 31 | 
            +
                    elif sche_type.startswith('lin'):
         | 
| 32 | 
            +
                        T = float(sche_type[3:]); max_rest = 1-T
         | 
| 33 | 
            +
                        wpe_mid = wpe + (1-wpe) * max_rest
         | 
| 34 | 
            +
                        wpe_mid = (1 + wpe_mid) / 2
         | 
| 35 | 
            +
                        if pasd < T: cur_lr = 1 + (wpe_mid-1) * pasd / T
         | 
| 36 | 
            +
                        else: cur_lr = wpe + (wpe_mid-wpe) * rest / max_rest
         | 
| 37 | 
            +
                    elif sche_type == 'exp':
         | 
| 38 | 
            +
                        T = 0.15; max_rest = 1-T
         | 
| 39 | 
            +
                        if pasd < T: cur_lr = 1
         | 
| 40 | 
            +
                        else:
         | 
| 41 | 
            +
                            expo = (pasd-T) / max_rest * math.log(wpe)
         | 
| 42 | 
            +
                            cur_lr = math.exp(expo)
         | 
| 43 | 
            +
                    else:
         | 
| 44 | 
            +
                        raise NotImplementedError(f'unknown sche_type {sche_type}')
         | 
| 45 | 
            +
                
         | 
| 46 | 
            +
                cur_lr *= peak_lr
         | 
| 47 | 
            +
                pasd = cur_it / (max_it-1)
         | 
| 48 | 
            +
                cur_wd = wd_end + (wd - wd_end) * (0.5 + 0.5 * math.cos(math.pi * pasd))
         | 
| 49 | 
            +
                
         | 
| 50 | 
            +
                inf = 1e6
         | 
| 51 | 
            +
                min_lr, max_lr = inf, -1
         | 
| 52 | 
            +
                min_wd, max_wd = inf, -1
         | 
| 53 | 
            +
                for param_group in optimizer.param_groups:
        param_group['lr'] = cur_lr * param_group.get('lr_sc', 1)    # 'lr_sc' could be assigned
        max_lr = max(max_lr, param_group['lr'])
        min_lr = min(min_lr, param_group['lr'])
        
        param_group['weight_decay'] = cur_wd * param_group.get('wd_sc', 1)
        max_wd = max(max_wd, param_group['weight_decay'])
        if param_group['weight_decay'] > 0:
            min_wd = min(min_wd, param_group['weight_decay'])
    
    if min_lr == inf: min_lr = -1
    if min_wd == inf: min_wd = -1
    return min_lr, max_lr, min_wd, max_wd


def filter_params(model, nowd_keys=()) -> Tuple[
    List[str], List[torch.nn.Parameter], List[Dict[str, Union[torch.nn.Parameter, float]]]
]:
    para_groups, para_groups_dbg = {}, {}
    names, paras = [], []
    names_no_grad = []
    count, numel = 0, 0
    for name, para in model.named_parameters():
        name = name.replace('_fsdp_wrapped_module.', '')
        if not para.requires_grad:
            names_no_grad.append(name)
            continue  # frozen weights
        count += 1
        numel += para.numel()
        names.append(name)
        paras.append(para)
        
        if para.ndim == 1 or name.endswith('bias') or any(k in name for k in nowd_keys):
            cur_wd_sc, group_name = 0., 'ND'
        else:
            cur_wd_sc, group_name = 1., 'D'
        cur_lr_sc = 1.
        if group_name not in para_groups:
            para_groups[group_name] = {'params': [], 'wd_sc': cur_wd_sc, 'lr_sc': cur_lr_sc}
            para_groups_dbg[group_name] = {'params': [], 'wd_sc': cur_wd_sc, 'lr_sc': cur_lr_sc}
        para_groups[group_name]['params'].append(para)
        para_groups_dbg[group_name]['params'].append(name)
    
    for g in para_groups_dbg.values():
        g['params'] = pformat(', '.join(g['params']), width=200)
    
    print(f'[get_param_groups] param_groups = \n{pformat(para_groups_dbg, indent=2, width=240)}\n')
    
    for rk in range(dist.get_world_size()):
        dist.barrier()
        if dist.get_rank() == rk:
            print(f'[get_param_groups][rank{dist.get_rank()}] {type(model).__name__=} {count=}, {numel=}', flush=True, force=True)
    print('')
    
    assert len(names_no_grad) == 0, f'[get_param_groups] names_no_grad = \n{pformat(names_no_grad, indent=2, width=240)}\n'
    return names, paras, list(para_groups.values())
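The helpers above are meant to be used together: filter_params splits the trainable weights into a decay group ('D') and a no-decay group ('ND'), and the per-group 'lr_sc' / 'wd_sc' scales are exactly what the annealing loop above multiplies into each param_group every step. Below is a minimal sketch of the wiring, assuming the dist helpers (and the patched print from utils/misc.py) are already initialized; the toy model, the nowd_keys values, and the import path are hypothetical, not taken from the training script:

    import torch
    from utils.lr_control import filter_params  # module path assumed from the repository layout

    toy = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.LayerNorm(8))   # hypothetical model
    names, paras, para_groups = filter_params(toy, nowd_keys=('gamma', 'beta'))
    opt = torch.optim.AdamW(para_groups, lr=1e-4, weight_decay=0.05)
    # each optimizer param_group now carries 'lr_sc' and 'wd_sc' for the annealing loop above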
    	
        utils/misc.py
    ADDED
    
import datetime
import functools
import glob
import os
import subprocess
import sys
import time
from collections import defaultdict, deque
from typing import Iterator, List, Tuple

import numpy as np
import pytz
import torch
import torch.distributed as tdist

import dist
from utils import arg_util

os_system = functools.partial(subprocess.call, shell=True)
def echo(info):
    os_system(f'echo "[$(date "+%m-%d-%H:%M:%S")] ({os.path.basename(sys._getframe().f_back.f_code.co_filename)}, line{sys._getframe().f_back.f_lineno})=> {info}"')
def os_system_get_stdout(cmd):
    return subprocess.run(cmd, shell=True, stdout=subprocess.PIPE).stdout.decode('utf-8')
def os_system_get_stdout_stderr(cmd):
    cnt = 0
    while True:
        try:
            sp = subprocess.run(cmd, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE, timeout=30)
        except subprocess.TimeoutExpired:
            cnt += 1
            print(f'[fetch free_port file] timeout cnt={cnt}')
        else:
            return sp.stdout.decode('utf-8'), sp.stderr.decode('utf-8')


def time_str(fmt='[%m-%d %H:%M:%S]'):
    return datetime.datetime.now(tz=pytz.timezone('Asia/Shanghai')).strftime(fmt)


def init_distributed_mode(local_out_path, only_sync_master=False, timeout=30):
    try:
        dist.initialize(fork=False, timeout=timeout)
        dist.barrier()
    except RuntimeError:
        print(f'{">"*75}  NCCL Error  {"<"*75}', flush=True)
        time.sleep(10)
    
    if local_out_path is not None: os.makedirs(local_out_path, exist_ok=True)
    _change_builtin_print(dist.is_local_master())
    if (dist.is_master() if only_sync_master else dist.is_local_master()) and local_out_path is not None and len(local_out_path):
        sys.stdout, sys.stderr = SyncPrint(local_out_path, sync_stdout=True), SyncPrint(local_out_path, sync_stdout=False)

def _change_builtin_print(is_master):
    import builtins as __builtin__
    
    builtin_print = __builtin__.print
    if type(builtin_print) != type(open):
        # print is no longer the C builtin, i.e. it has already been patched; don't wrap it twice
        return
    
    def prt(*args, **kwargs):
        force = kwargs.pop('force', False)
        clean = kwargs.pop('clean', False)
        deeper = kwargs.pop('deeper', False)
        if is_master or force:
            if not clean:
                f_back = sys._getframe().f_back
                if deeper and f_back.f_back is not None:
                    f_back = f_back.f_back
                file_desc = f'{f_back.f_code.co_filename:24s}'[-24:]
                builtin_print(f'{time_str()} ({file_desc}, line{f_back.f_lineno:-4d})=>', *args, **kwargs)
            else:
                builtin_print(*args, **kwargs)
    
    __builtin__.print = prt

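Once init_distributed_mode has installed this wrapper, a plain print call is silently dropped on non-master ranks and, on the master, is prefixed with a timestamp plus the caller's file and line; the extra keyword arguments below are consumed by the wrapper itself. A short sketch of the added behaviour, assuming the patch is active:

    print('master rank only, with time + caller prefix')
    print('printed on every rank', force=True)
    print('no prefix at all', clean=True)
    print('caller reported one stack frame higher', deeper=True)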
class SyncPrint(object):
    def __init__(self, local_output_dir, sync_stdout=True):
        self.sync_stdout = sync_stdout
        self.terminal_stream = sys.stdout if sync_stdout else sys.stderr
        fname = os.path.join(local_output_dir, 'stdout.txt' if sync_stdout else 'stderr.txt')
        existing = os.path.exists(fname)
        self.file_stream = open(fname, 'a')
        if existing:
            self.file_stream.write('\n'*7 + '='*55 + f'   RESTART {time_str()}   ' + '='*55 + '\n')
        self.file_stream.flush()
        self.enabled = True
    
    def write(self, message):
        self.terminal_stream.write(message)
        self.file_stream.write(message)
    
    def flush(self):
        self.terminal_stream.flush()
        self.file_stream.flush()
    
    def close(self):
        if not self.enabled:
            return
        self.enabled = False
        self.file_stream.flush()
        self.file_stream.close()
        if self.sync_stdout:
            sys.stdout = self.terminal_stream
            sys.stdout.flush()
        else:
            sys.stderr = self.terminal_stream
            sys.stderr.flush()
    
    def __del__(self):
        self.close()

class DistLogger(object):
    def __init__(self, lg, verbose):
        self._lg, self._verbose = lg, verbose
    
    @staticmethod
    def do_nothing(*args, **kwargs):
        pass
    
    def __getattr__(self, attr: str):
        return getattr(self._lg, attr) if self._verbose else DistLogger.do_nothing

class TensorboardLogger(object):
    def __init__(self, log_dir, filename_suffix):
        try: import tensorflow_io as tfio
        except: pass
        from torch.utils.tensorboard import SummaryWriter
        self.writer = SummaryWriter(log_dir=log_dir, filename_suffix=filename_suffix)
        self.step = 0
    
    def set_step(self, step=None):
        if step is not None:
            self.step = step
        else:
            self.step += 1
    
    def update(self, head='scalar', step=None, **kwargs):
        for k, v in kwargs.items():
            if v is None:
                continue
            # assert isinstance(v, (float, int)), type(v)
            if step is None:  # iter wise
                it = self.step
                if it == 0 or (it + 1) % 500 == 0:
                    if hasattr(v, 'item'): v = v.item()
                    self.writer.add_scalar(f'{head}/{k}', v, it)
            else:  # epoch wise
                if hasattr(v, 'item'): v = v.item()
                self.writer.add_scalar(f'{head}/{k}', v, step)
    
    def log_tensor_as_distri(self, tag, tensor1d, step=None):
        if step is None:  # iter wise
            step = self.step
            loggable = step == 0 or (step + 1) % 500 == 0
        else:  # epoch wise
            loggable = True
        if loggable:
            try:
                self.writer.add_histogram(tag=tag, values=tensor1d, global_step=step)
            except Exception as e:
                print(f'[log_tensor_as_distri writer.add_histogram failed]: {e}')
    
    def log_image(self, tag, img_chw, step=None):
        if step is None:  # iter wise
            step = self.step
            loggable = step == 0 or (step + 1) % 500 == 0
        else:  # epoch wise
            loggable = True
        if loggable:
            self.writer.add_image(tag, img_chw, step, dataformats='CHW')
    
    def flush(self):
        self.writer.flush()
    
    def close(self):
        self.writer.close()

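In a multi-process run the TensorBoard writer is typically wrapped in DistLogger so that only one rank actually writes events; on the other ranks every attribute resolves to do_nothing. A minimal sketch, assuming dist is initialized and 'tb-logs' is a hypothetical output directory:

    tb_lg = DistLogger(TensorboardLogger(log_dir='tb-logs', filename_suffix='_var'), verbose=dist.is_master())
    tb_lg.set_step(100)
    tb_lg.update(head='train', Lm=0.42, Acc=55.1)   # iter-wise: only written at step 0 or every 500 steps
    tb_lg.update(head='val', step=1, Lm=0.40)       # epoch-wise: written immediately at the given step
    tb_lg.flush()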
class SmoothedValue(object):
    """Track a series of values and provide access to smoothed values over a
    window or the global series average.
    """
    
    def __init__(self, window_size=30, fmt=None):
        if fmt is None:
            fmt = "{median:.4f} ({global_avg:.4f})"
        self.deque = deque(maxlen=window_size)
        self.total = 0.0
        self.count = 0
        self.fmt = fmt
    
    def update(self, value, n=1):
        self.deque.append(value)
        self.count += n
        self.total += value * n
    
    def synchronize_between_processes(self):
        """
        Warning: does not synchronize the deque!
        """
        t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda')
        tdist.barrier()
        tdist.all_reduce(t)
        t = t.tolist()
        self.count = int(t[0])
        self.total = t[1]
    
    @property
    def median(self):
        return np.median(self.deque) if len(self.deque) else 0
    
    @property
    def avg(self):
        return sum(self.deque) / (len(self.deque) or 1)
    
    @property
    def global_avg(self):
        return self.total / (self.count or 1)
    
    @property
    def max(self):
        return max(self.deque)
    
    @property
    def value(self):
        return self.deque[-1] if len(self.deque) else 0
    
    def time_preds(self, counts) -> Tuple[float, str, str]:
        remain_secs = counts * self.median
        return remain_secs, str(datetime.timedelta(seconds=round(remain_secs))), time.strftime("%Y-%m-%d %H:%M", time.localtime(time.time() + remain_secs))
    
    def __str__(self):
        return self.fmt.format(
            median=self.median,
            avg=self.avg,
            global_avg=self.global_avg,
            max=self.max,
            value=self.value)

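SmoothedValue is the building block of MetricLogger below: a sliding window for the recent median/average plus global totals, with time_preds turning the median cost per step into a remaining-time estimate. A small self-contained sketch:

    iter_time = SmoothedValue(window_size=30, fmt='{median:.3f} ({global_avg:.3f})')
    for t in (0.21, 0.19, 0.20):
        iter_time.update(t)
    print(str(iter_time))                                   # -> "0.200 (0.200)"
    secs, dur_str, finish_at = iter_time.time_preds(1000)   # ETA for 1000 more steps at the median speed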
class MetricLogger(object):
    def __init__(self, delimiter='  '):
        self.meters = defaultdict(SmoothedValue)
        self.delimiter = delimiter
        self.iter_end_t = time.time()
        self.log_iters = []
    
    def update(self, **kwargs):
        for k, v in kwargs.items():
            if v is None:
                continue
            if hasattr(v, 'item'): v = v.item()
            # assert isinstance(v, (float, int)), type(v)
            assert isinstance(v, (float, int))
            self.meters[k].update(v)
    
    def __getattr__(self, attr):
        if attr in self.meters:
            return self.meters[attr]
        if attr in self.__dict__:
            return self.__dict__[attr]
        raise AttributeError("'{}' object has no attribute '{}'".format(
            type(self).__name__, attr))
    
    def __str__(self):
        loss_str = []
        for name, meter in self.meters.items():
            if len(meter.deque):
                loss_str.append(
                    "{}: {}".format(name, str(meter))
                )
        return self.delimiter.join(loss_str)
    
    def synchronize_between_processes(self):
        for meter in self.meters.values():
            meter.synchronize_between_processes()
    
    def add_meter(self, name, meter):
        self.meters[name] = meter
    
    def log_every(self, start_it, max_iters, itrt, print_freq, header=None):
        self.log_iters = set(np.linspace(0, max_iters-1, print_freq, dtype=int).tolist())  # print_freq is the number of evenly spaced log points, not a period
        self.log_iters.add(start_it)
        if not header:
            header = ''
        start_time = time.time()
        self.iter_end_t = time.time()
        self.iter_time = SmoothedValue(fmt='{avg:.4f}')
        self.data_time = SmoothedValue(fmt='{avg:.4f}')
        space_fmt = ':' + str(len(str(max_iters))) + 'd'
        log_msg = [
            header,
            '[{0' + space_fmt + '}/{1}]',
            'eta: {eta}',
            '{meters}',
            'time: {time}',
            'data: {data}'
        ]
        log_msg = self.delimiter.join(log_msg)
        
        if isinstance(itrt, Iterator) and not hasattr(itrt, 'preload') and not hasattr(itrt, 'set_epoch'):
            for i in range(start_it, max_iters):
                obj = next(itrt)
                self.data_time.update(time.time() - self.iter_end_t)
                yield i, obj
                self.iter_time.update(time.time() - self.iter_end_t)
                if i in self.log_iters:
                    eta_seconds = self.iter_time.global_avg * (max_iters - i)
                    eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
                    print(log_msg.format(
                        i, max_iters, eta=eta_string,
                        meters=str(self),
                        time=str(self.iter_time), data=str(self.data_time)), flush=True)
                self.iter_end_t = time.time()
        else:
            if isinstance(itrt, int): itrt = range(itrt)
            for i, obj in enumerate(itrt):
                self.data_time.update(time.time() - self.iter_end_t)
                yield i, obj
                self.iter_time.update(time.time() - self.iter_end_t)
                if i in self.log_iters:
                    eta_seconds = self.iter_time.global_avg * (max_iters - i)
                    eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
                    print(log_msg.format(
                        i, max_iters, eta=eta_string,
                        meters=str(self),
                        time=str(self.iter_time), data=str(self.data_time)), flush=True)
                self.iter_end_t = time.time()
        
        total_time = time.time() - start_time
        total_time_str = str(datetime.timedelta(seconds=int(total_time)))
        print('{}   Total time:      {}   ({:.3f} s / it)'.format(
            header, total_time_str, total_time / max_iters), flush=True)

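MetricLogger.update feeds named SmoothedValues, and log_every is a generator that wraps the training iterator, timing data loading and iteration and printing print_freq evenly spaced progress lines with an ETA. A sketch of a typical loop; train_loader and train_step are hypothetical placeholders:

    me = MetricLogger(delimiter='  ')
    max_it = len(train_loader)
    for it, batch in me.log_every(0, max_it, iter(train_loader), 10, header='[Ep 1]'):
        loss = train_step(batch)             # hypothetical step returning a float loss
        me.update(Lm=loss)
    me.synchronize_between_processes()       # distributed runs only: all-reduces counts and totals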
def glob_with_latest_modified_first(pattern, recursive=False):
    return sorted(glob.glob(pattern, recursive=recursive), key=os.path.getmtime, reverse=True)

def auto_resume(args: arg_util.Args, pattern='ckpt*.pth') -> Tuple[List[str], int, int, dict, dict]:
    info = []
    file = os.path.join(args.local_out_dir_path, pattern)
    all_ckpt = glob_with_latest_modified_first(file)
    if len(all_ckpt) == 0:
        info.append(f'[auto_resume] no ckpt found @ {file}')
        info.append(f'[auto_resume quit]')
        return info, 0, 0, {}, {}
    else:
        info.append(f'[auto_resume] load ckpt from @ {all_ckpt[0]} ...')
        ckpt = torch.load(all_ckpt[0], map_location='cpu')
        ep, it = ckpt['epoch'], ckpt['iter']
        info.append(f'[auto_resume success] resume from ep{ep}, it{it}')
        return info, ep, it, ckpt['trainer'], ckpt['args']

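auto_resume picks the newest checkpoint matching pattern under args.local_out_dir_path and hands back enough state to continue training; with no checkpoint it falls back to epoch 0, iteration 0. A sketch of the call site; the trainer restore is hypothetical, since that logic lives in the training script rather than here:

    info, start_ep, start_it, trainer_state, args_state = auto_resume(args, 'ckpt*.pth')
    print('\n'.join(info))
    if trainer_state:
        trainer.load_state_dict(trainer_state)   # hypothetical trainer object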
def create_npz_from_sample_folder(sample_folder: str):
    """
    Builds a single .npz file from a folder of .png samples. Refer to DiT.
    """
    import os, glob
    import numpy as np
    from tqdm import tqdm
    from PIL import Image
    
    samples = []
    pngs = glob.glob(os.path.join(sample_folder, '*.png')) + glob.glob(os.path.join(sample_folder, '*.PNG'))
    assert len(pngs) == 50_000, f'{len(pngs)} png files found in {sample_folder}, but expected 50,000'
    for png in tqdm(pngs, desc='Building .npz file from samples (png only)'):
        with Image.open(png) as sample_pil:
            sample_np = np.asarray(sample_pil).astype(np.uint8)
        samples.append(sample_np)
    samples = np.stack(samples)
    assert samples.shape == (50_000, samples.shape[1], samples.shape[2], 3)
    npz_path = f'{sample_folder}.npz'
    np.savez(npz_path, arr_0=samples)
    print(f'Saved .npz file to {npz_path} [shape={samples.shape}].')
    return npz_path
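The .npz helper follows the DiT evaluation convention: exactly 50,000 PNG samples are packed into a single arr_0 array so the file can be consumed by the standard FID evaluation tooling. For example, with a hypothetical folder name:

    npz_path = create_npz_from_sample_folder('samples/var_50k')   # the folder must hold exactly 50,000 .png files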