import torch
import torch.nn as nn


class LVDM(nn.Module):
    """Latent video diffusion model: a denoising UNet conditioned on
    ReferenceNet appearance features and pose-guider spatial features."""

    def __init__(self, referencenet, unet, pose_guider):
        super().__init__()
        self.referencenet = referencenet
        self.unet = unet
        self.pose_guider = pose_guider

    def forward(
        self,
        noisy_latents,             # (B, 4, F, h, w) noisy video latents, e.g. (4, 4, 1, 112, 80)
        timesteps,                 # (B,) diffusion timesteps, e.g. tensor([426, 277, 802, 784])
        ref_image_latents,         # (B, 4, h, w) VAE latents of the reference image, e.g. (4, 4, 112, 80)
        clip_image_embeds,         # (B, 1, 768) CLIP embedding of the reference image
        pose_img,                  # (B, 3, F, H, W) pose images at pixel resolution, e.g. (4, 3, 1, 896, 640)
        uncond_fwd: bool = False,  # True for the unconditional pass in classifier-free guidance training
    ):
        # Encode the pose images into guidance features at latent resolution.
        # Use the latents' device rather than a hard-coded "cuda", which would
        # silently target cuda:0 even when the model runs on another GPU.
        pose_cond_tensor = pose_img.to(device=noisy_latents.device)
        pose_fea = self.pose_guider(pose_cond_tensor)  # (B, 320, F, h, w), e.g. (4, 320, 1, 112, 80)

        # For the unconditional pass, skip ReferenceNet so the UNet receives no
        # reference features; otherwise extract appearance features from the
        # clean reference latents at timestep 0 (no denoising is intended).
        reference_down_block_res_samples = None
        reference_mid_block_res_sample = None
        reference_up_block_res_samples = None
        if not uncond_fwd:
            ref_timesteps = torch.zeros_like(timesteps)
            (reference_down_block_res_samples,
             reference_mid_block_res_sample,
             reference_up_block_res_samples) = self.referencenet(
                ref_image_latents,
                ref_timesteps,
                encoder_hidden_states=clip_image_embeds,
                return_dict=False,
            )

        self.unet.set_do_classifier_free_guidance(do_classifier_free_guidance=False)
        model_pred = self.unet(
            noisy_latents,
            timesteps,
            pose_cond_fea=pose_fea,
            encoder_hidden_states=clip_image_embeds,
            reference_down_block_res_samples=reference_down_block_res_samples,
            reference_mid_block_res_sample=reference_mid_block_res_sample,
            reference_up_block_res_samples=reference_up_block_res_samples,
        ).sample
        return model_pred
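

# ---------------------------------------------------------------------------
# Minimal smoke test: a sketch, not part of the original model. The real
# referencenet / unet / pose_guider come from the surrounding codebase; the
# stubs below are hypothetical stand-ins that only mimic the call signatures
# used in forward(), so the data flow can be exercised on CPU.
if __name__ == "__main__":
    from types import SimpleNamespace

    class _StubPoseGuider(nn.Module):
        def forward(self, pose):
            # (B, 3, F, H, W) -> (B, 320, F, H/8, W/8)
            b, _, f, h, w = pose.shape
            return torch.zeros(b, 320, f, h // 8, w // 8)

    class _StubReferenceNet(nn.Module):
        def forward(self, latents, timesteps, encoder_hidden_states=None,
                    return_dict=False):
            # (down-block residuals, mid-block residual, up-block residuals)
            return [], torch.zeros_like(latents), []

    class _StubUNet(nn.Module):
        def set_do_classifier_free_guidance(self, do_classifier_free_guidance):
            pass

        def forward(self, latents, timesteps, pose_cond_fea=None,
                    encoder_hidden_states=None,
                    reference_down_block_res_samples=None,
                    reference_mid_block_res_sample=None,
                    reference_up_block_res_samples=None):
            # Predict "noise" with the same shape as the input latents.
            return SimpleNamespace(sample=torch.zeros_like(latents))

    model = LVDM(_StubReferenceNet(), _StubUNet(), _StubPoseGuider())
    pred = model(
        noisy_latents=torch.randn(4, 4, 1, 112, 80),
        timesteps=torch.randint(0, 1000, (4,)),
        ref_image_latents=torch.randn(4, 4, 112, 80),
        clip_image_embeds=torch.randn(4, 1, 768),
        pose_img=torch.randn(4, 3, 1, 896, 640),
    )
    print(pred.shape)  # torch.Size([4, 4, 1, 112, 80])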