# Author: T. S. Liang @ Rama Alpaca
# Emails: tsliang2001@gmail.com | shuangliang@ramaalpaca.com
# Date: Feb. 2025

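"""

Description:

    SDPose UNet Forward Modifier

    Overrides `forward` on a UNet instance to expose decoder features via forward hooks.
    The override stays on the instance; the hooks on `up_blocks` are registered and
    removed within each forward call.
    - keypoint_scheme: "body" (COCO-17) selects the output of the last up block;
      "wholebody" (COCO-WholeBody-133) selects the output of the second-to-last up block.
    - return_decoder_feats: if True, returns the selected decoder feature (Tensor);
      otherwise returns the original UNet output.
"""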

def new_forward_kpt17(self, *args, return_decoder_feats=False, **kwargs):
    # Collect the output of every decoder (up) block during this call.
    self._decoder_feats = []

    def hook_fn(module, input, output):
        self._decoder_feats.append(output)

    handles = [blk.register_forward_hook(hook_fn) for blk in self.up_blocks]

    try:
        out = self._old_forward(*args, **kwargs)
    finally:
        # Remove the hooks even if the forward pass raises.
        for h in handles:
            h.remove()

    if return_decoder_feats:
        # "body" scheme: return the output of the last up block.
        return self._decoder_feats[-1]
    return out

def new_forward_kpt133(self, *args, return_decoder_feats=False, **kwargs):
    # Collect the output of every decoder (up) block during this call.
    self._decoder_feats = []

    def hook_fn(module, input, output):
        self._decoder_feats.append(output)

    handles = [blk.register_forward_hook(hook_fn) for blk in self.up_blocks]

    try:
        out = self._old_forward(*args, **kwargs)
    finally:
        # Remove the hooks even if the forward pass raises.
        for h in handles:
            h.remove()

    if return_decoder_feats:
        # "wholebody" scheme: return the output of the second-to-last up block.
        return self._decoder_feats[-2]
    return out

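def Modified_forward(unet, keypoint_scheme="body"):
    # Map each keypoint scheme to its replacement forward.
    forwards = {"body": new_forward_kpt17, "wholebody": new_forward_kpt133}
    if keypoint_scheme not in forwards:
        raise ValueError(
            f'keypoint_scheme must be "body" or "wholebody", got {keypoint_scheme!r}'
        )

    # Save the original forward only once; patching twice would otherwise make
    # _old_forward point at the wrapper and recurse forever.
    if not hasattr(unet, "_old_forward"):
        unet._old_forward = unet.forward

    # Bind the replacement to this instance only.
    unet.forward = forwards[keypoint_scheme].__get__(unet, unet.__class__)

    return unet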
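
# Illustration only: a minimal, PyTorch-free sketch of the instance-level binding
# that Modified_forward relies on (`function.__get__(obj, cls)`), plus one possible
# way to undo the patch. `TinyModel`, `traced_forward`, `patch_forward`, and
# `restore_forward` are hypothetical names introduced for this sketch; they are
# not part of SDPose or of any library.

```python
class TinyModel:
    """Hypothetical stand-in for a UNet: forward() doubles its input."""

    def forward(self, x):
        return x * 2


def traced_forward(self, x, return_trace=False):
    # Delegate to the saved original, mirroring how the new_forward_* functions
    # call self._old_forward.
    out = self._old_forward(x)
    return ("trace", out) if return_trace else out


def patch_forward(model):
    # Save the original only once so a second call cannot make
    # _old_forward point at the wrapper itself.
    if "_old_forward" not in vars(model):
        model._old_forward = model.forward
        # __get__ produces a method bound to this one instance; the class
        # and every other instance keep the original forward.
        model.forward = traced_forward.__get__(model, model.__class__)
    return model


def restore_forward(model):
    # Deleting the instance attributes re-exposes the class-level forward.
    if "_old_forward" in vars(model):
        del model.forward
        del model._old_forward
    return model


patched = patch_forward(TinyModel())
untouched = TinyModel()
```

# Because the override is stored on the instance, `untouched` still uses the
# class-level `forward`. Calling `restore_forward(patched)` removes the override
# and leaves the object as it was before patching.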