SDPose / models /ModifiedUNet.py
teemosliang's picture
realease: publish the huggingface space demo of SDPose
d5532b9 verified
# Author: T. S. Liang @ Rama Alpaca
# Emails: tsliang2001@gmail.com | shuangliang@ramaalpaca.com
# Date: Feb. 2025
"""
Description:
SDPose UNet Forward Modifier
Temporarily overrides `UNet.forward` to expose decoder features via forward hooks.
- keypoint_scheme: "body" (COCO-17) or "wholebody" (COCO-WholeBody-133)
- return_decoder_feats: if True, returns a selected decoder feature (Tensor);
otherwise returns the original UNet output.
"""
def new_forward_kpt17(self, *args, return_decoder_feats=False, **kwargs):
self._decoder_feats = []
def hook_fn(module, input, output):
self._decoder_feats.append(output)
handles = [blk.register_forward_hook(hook_fn) for blk in self.up_blocks]
out = self._old_forward(*args, **kwargs)
for h in handles:
h.remove()
if return_decoder_feats:
feats = self._decoder_feats[::-1]
return feats[0]
else:
return out
def new_forward_kpt133(self, *args, return_decoder_feats=False, **kwargs):
self._decoder_feats = []
def hook_fn(module, input, output):
self._decoder_feats.append(output)
handles = [blk.register_forward_hook(hook_fn) for blk in self.up_blocks]
out = self._old_forward(*args, **kwargs)
for h in handles:
h.remove()
if return_decoder_feats:
feats = self._decoder_feats[::-1]
return feats[1]
else:
return out
def Modified_forward(unet, keypoint_scheme = "body"):
if keypoint_scheme == "body":
unet._old_forward = unet.forward
unet.forward = new_forward_kpt17.__get__(unet, unet.__class__)
elif keypoint_scheme == "wholebody":
unet._old_forward = unet.forward
unet.forward = new_forward_kpt133.__get__(unet, unet.__class__)
return unet