Spaces:
Paused
Paused
| import torch | |
class CoordStage(object):
    """Fake VQ-model that quantizes a single-channel coordinate map.

    Mimics the ``eval`` / ``encode`` / ``decode`` interface of a VQ model.

    * ``encode`` spatially downsamples its input by ``down_factor`` using
      area interpolation. It then maps values in [0, 1] to integer codes
      in ``0..n_embed`` inclusive.
    * ``decode`` divides codes by ``n_embed`` and upsamples by
      ``down_factor`` using nearest-neighbour interpolation.
    """

    def __init__(self, n_embed, down_factor):
        # Scale for quantization: inputs in [0, 1] are multiplied by this and
        # rounded, so codes lie in 0..n_embed (n_embed + 1 distinct values).
        self.n_embed = n_embed
        # Spatial scale factor: encode shrinks by it, decode enlarges by it.
        self.down_factor = down_factor

    def eval(self):
        """Return ``self`` unchanged (no-op stand-in for a model's ``eval``)."""
        return self

    def encode(self, c: torch.Tensor):
        """Quantize a coordinate map (fake vqmodel interface).

        Args:
            c: Tensor of shape (batch, 1, height, width) with values in
                [0, 1].

        Returns:
            ``(c_quant, None, (None, None, c_ind))``, where:

            * ``c_quant`` holds the rounded codes as floating-point values.
            * ``c_ind`` holds the same codes as ``torch.long``.
            * The ``None`` entries are placeholders, presumably mirroring the
              extra outputs of a real VQ model's ``encode`` (verify against
              callers).

        Raises:
            ValueError: if ``c`` has values outside [0, 1], is not 4-D, or
                does not have exactly one channel.
        """
        # Explicit checks instead of ``assert``: asserts are stripped under
        # ``python -O``, which would silently let invalid input through.
        if not (0.0 <= c.min() and c.max() <= 1.0):
            raise ValueError("CoordStage.encode expects values in [0, 1]")
        # Tuple unpacking also raises ValueError if ``c`` is not 4-D.
        _, ch, _, _ = c.shape
        if ch != 1:
            raise ValueError(
                f"CoordStage.encode expects a single channel, got {ch}"
            )
        c = torch.nn.functional.interpolate(
            c, scale_factor=1 / self.down_factor, mode="area"
        )
        # Guard against values drifting outside [0, 1] after resampling.
        c = c.clamp(0.0, 1.0)
        c = self.n_embed * c
        c_quant = c.round()
        c_ind = c_quant.to(dtype=torch.long)
        info = None, None, c_ind
        return c_quant, None, info

    def decode(self, c: torch.Tensor) -> torch.Tensor:
        """Map codes back to [0, 1] and upsample by ``down_factor``.

        Args:
            c: 4-D code tensor. Presumably this is ``c_quant`` as returned by
                ``encode``; confirm against callers.

        Returns:
            ``c / n_embed``, upsampled spatially by ``down_factor`` with
            nearest-neighbour interpolation.
        """
        c = c / self.n_embed
        c = torch.nn.functional.interpolate(
            c, scale_factor=self.down_factor, mode="nearest"
        )
        return c