Spaces:
Runtime error
Runtime error
import os

import torch

# Resample the position parameters of a SAM ViT-L checkpoint trained at
# ORI_IMAGE_SIZE pixels so it can be loaded into an image encoder built for
# IMAGE_SIZE pixels, then save the result next to the source checkpoint.

ORI_IMAGE_SIZE = 1024
IMAGE_SIZE = 256
# NOTE(review): assumes the encoder uses 16-px patches (SAM's ViT default) --
# confirm. The patch grid is then IMAGE_SIZE // PATCH_SIZE cells per side.
PATCH_SIZE = 16
# Length of a relative-position table for a grid of G patches is 2 * G - 1.
# Derived instead of hard-coded so it tracks IMAGE_SIZE (== 31 for 256 px).
REL_POS = 2 * (IMAGE_SIZE // PATCH_SIZE) - 1
# Blocks whose rel_pos tables get resized; presumably the global-attention
# blocks of ViT-L (tables of all other blocks are left untouched) -- confirm
# before using with other backbones.
GLOBAL_ATTN_INDEXES = (5, 11, 17, 23)
CHECKPOINT_PATH = "/gpfs/u/home/LMCG/LMCGljnn/scratch/code/checkpoint/sam_vit_l_0b3195.pth"


def _resize_pos_embed(pos_embed, grid_size):
    """Bilinearly resample a (B, H, W, C) embedding to (B, grid, grid, C)."""
    # interpolate() operates on channels-first tensors, hence the permutes.
    # An explicit integer size avoids relying on float scale-factor flooring.
    return torch.nn.functional.interpolate(
        pos_embed.permute(0, 3, 1, 2),
        size=(grid_size, grid_size),
        mode="bilinear",
        align_corners=True,
    ).permute(0, 2, 3, 1)


def _resize_rel_pos(rel_pos, length):
    """Linearly resample a (L, C) relative-position table to (length, C)."""
    # 1-D interpolate() expects (N, C, L): move positions last, add a batch axis.
    return torch.nn.functional.interpolate(
        rel_pos.permute(1, 0).unsqueeze(0),
        size=length,
        mode="linear",
        align_corners=True,
    ).squeeze(0).permute(1, 0)


def resize_checkpoint(
    checkpoint,
    image_size=IMAGE_SIZE,
    ori_image_size=ORI_IMAGE_SIZE,
    patch_size=PATCH_SIZE,
    global_attn_indexes=GLOBAL_ATTN_INDEXES,
):
    """Resize the position parameters of a SAM image-encoder state dict in place.

    Args:
        checkpoint: state dict holding ``image_encoder.pos_embed`` (rank 4,
            laid out (B, H, W, C)) and ``image_encoder.blocks.{i}.attn.rel_pos_h``
            / ``rel_pos_w`` (rank 2, (L, C)) for every ``i`` in
            ``global_attn_indexes``.
        image_size: target input resolution in pixels.
        ori_image_size: resolution the checkpoint was trained at.
        patch_size: pixels per patch side.
        global_attn_indexes: block indices whose rel_pos tables are resized.

    Returns:
        The same ``checkpoint`` dict, with the resized tensors stored back.

    Raises:
        ValueError: if the stored pos_embed grid is not
            ``ori_image_size // patch_size`` on each side.
        KeyError: if an expected key is missing from ``checkpoint``.
    """
    src_grid = ori_image_size // patch_size
    dst_grid = image_size // patch_size
    rel_pos_len = 2 * dst_grid - 1

    pos_key = "image_encoder.pos_embed"
    pos_embed = checkpoint[pos_key]
    # Guard the patch-size assumption before silently producing a wrong grid.
    if tuple(pos_embed.shape[1:3]) != (src_grid, src_grid):
        raise ValueError(
            f"{pos_key} has grid {tuple(pos_embed.shape[1:3])}, expected "
            f"({src_grid}, {src_grid}) for {ori_image_size}px / {patch_size}px patches"
        )
    pos_embed = _resize_pos_embed(pos_embed, dst_grid)
    checkpoint[pos_key] = pos_embed
    print(pos_embed.shape)

    for idx in global_attn_indexes:
        prefix = f"image_encoder.blocks.{idx}.attn."
        rel_pos_h = _resize_rel_pos(checkpoint[prefix + "rel_pos_h"], rel_pos_len)
        rel_pos_w = _resize_rel_pos(checkpoint[prefix + "rel_pos_w"], rel_pos_len)
        checkpoint[prefix + "rel_pos_h"] = rel_pos_h
        checkpoint[prefix + "rel_pos_w"] = rel_pos_w
        print(rel_pos_h.shape, rel_pos_w.shape)
    return checkpoint


def main(src=CHECKPOINT_PATH, dst=None):
    """Load ``src``, resize it for IMAGE_SIZE, save it, and return the output path.

    By default the output is written beside ``src`` as
    ``<name>_{IMAGE_SIZE}x{IMAGE_SIZE}<ext>``.
    """
    if dst is None:
        root, ext = os.path.splitext(src)
        dst = f"{root}_{IMAGE_SIZE}x{IMAGE_SIZE}{ext}"
    # Load onto CPU: this script only resamples tensors, so it must not
    # require the device the checkpoint was saved from (e.g. on a CPU-only host).
    checkpoint = torch.load(src, map_location="cpu")
    resize_checkpoint(checkpoint)
    torch.save(checkpoint, dst)
    return dst


if __name__ == "__main__":
    main()