# benthic_classification / convert_model.py
# Commit c458c3e by danielhshi8224: "Initialize repo"
import torch
from transformers import ConvNextForImageClassification, AutoConfig, AutoImageProcessor
ckpt_path = "./ConvNextmodel.pth"
base_model = "facebook/convnext-tiny-224"  # the model you started from
num_labels = 7
save_dir = "./convnext-tiny-224-7cls"


def strip_module_prefix(state_dict, prefix="module."):
    """Return a copy of *state_dict* with a leading *prefix* removed from keys.

    Checkpoints saved while a model was wrapped (typically in
    ``torch.nn.DataParallel`` / ``DistributedDataParallel``) carry a
    ``"module."`` prefix on every parameter name. Only a true leading prefix
    is removed; keys that merely contain *prefix* elsewhere are left intact.

    Args:
        state_dict: Mapping of parameter names to values.
        prefix: Leading string to strip from each key.

    Returns:
        A new dict with the same values and prefix-free keys.
    """
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }


def main():
    """Convert the raw checkpoint at ``ckpt_path`` into an HF model directory.

    Rebuilds ``ConvNextForImageClassification`` from ``base_model``'s config
    with a ``num_labels``-way head, loads the checkpoint weights into it, and
    writes the model plus the base model's image processor to ``save_dir``.

    Raises:
        RuntimeError: if the checkpoint keys do not exactly match the model.
    """
    # 1. Load the raw state dict onto the CPU.
    # NOTE(review): torch.load unpickles the file; only run this on trusted
    # checkpoints (consider weights_only=True if the installed torch allows).
    state_dict = torch.load(ckpt_path, map_location="cpu")
    state_dict = strip_module_prefix(state_dict)

    # 2. Rebuild the architecture with a num_labels-way classifier head.
    config = AutoConfig.from_pretrained(base_model)
    # NOTE(review): in recent transformers versions, changing num_labels resets
    # id2label/label2id to generic "LABEL_i" names; set the real class names
    # here if they are known.
    config.num_labels = num_labels
    model = ConvNextForImageClassification(config)

    # Load non-strictly so the mismatched keys can actually be reported
    # (strict=True would raise before the print), then still refuse to save
    # a partially initialised model.
    missing, unexpected = model.load_state_dict(state_dict, strict=False)
    print("Missing:", missing, "Unexpected:", unexpected)
    if missing or unexpected:
        raise RuntimeError(
            f"Checkpoint {ckpt_path!r} does not match {base_model!r} with "
            f"{num_labels} labels: {len(missing)} missing and "
            f"{len(unexpected)} unexpected keys."
        )

    # 3. Save weights and config in HF format.
    model.save_pretrained(save_dir)

    # 4. Also save the base model's image processor (for transforms) so the
    # directory is self-contained.
    processor = AutoImageProcessor.from_pretrained(base_model)
    processor.save_pretrained(save_dir)


if __name__ == "__main__":
    main()