🔨 [Add] weight converter for YOLOv7
Browse files
yolo/tools/format_converters.py
CHANGED
|
@@ -31,3 +31,55 @@ def convert_weight(old_state_dict, new_state_dict, model_size: int = 38):
|
|
| 31 |
weight_name = ".".join([str(layer_idx), "heads", conv_idx, conv_task, *details])
|
| 32 |
new_state_dict[weight_name] = weight_value
|
| 33 |
return new_state_dict
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 31 |
weight_name = ".".join([str(layer_idx), "heads", conv_idx, conv_task, *details])
|
| 32 |
new_state_dict[weight_name] = weight_value
|
| 33 |
return new_state_dict
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
head_converter = {
|
| 37 |
+
"head_conv": "m",
|
| 38 |
+
"implicit_a": "ia",
|
| 39 |
+
"implicit_m": "im",
|
| 40 |
+
}
|
| 41 |
+
|
| 42 |
+
SPP_converter = {
|
| 43 |
+
"pre_conv.0": "cv1",
|
| 44 |
+
"pre_conv.1": "cv3",
|
| 45 |
+
"pre_conv.2": "cv4",
|
| 46 |
+
"post_conv.0": "cv5",
|
| 47 |
+
"post_conv.1": "cv6",
|
| 48 |
+
"short_conv": "cv2",
|
| 49 |
+
"merge_conv": "cv7",
|
| 50 |
+
}
|
| 51 |
+
|
| 52 |
+
REP_converter = {"conv1": "rbr_dense", "conv2": "rbr_1x1", "conv": "0", "bn": "1"}
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
def convert_weight_v7(old_state_dict, new_state_dict):
|
| 56 |
+
map_weight = []
|
| 57 |
+
for key_name in new_state_dict.keys():
|
| 58 |
+
new_shape = new_state_dict[key_name].shape
|
| 59 |
+
old_key_name = "model." + key_name
|
| 60 |
+
new_key_name = key_name
|
| 61 |
+
if old_key_name not in old_state_dict.keys():
|
| 62 |
+
if "heads" in key_name:
|
| 63 |
+
layer_idx, _, conv_idx, conv_name, *details = key_name.split(".")
|
| 64 |
+
old_key_name = ".".join(["model", str(layer_idx), head_converter[conv_name], conv_idx, *details])
|
| 65 |
+
elif (
|
| 66 |
+
"pre_conv" in key_name
|
| 67 |
+
or "post_conv" in key_name
|
| 68 |
+
or "short_conv" in key_name
|
| 69 |
+
or "merge_conv" in key_name
|
| 70 |
+
):
|
| 71 |
+
for key, value in SPP_converter.items():
|
| 72 |
+
if key in key_name:
|
| 73 |
+
key_name = key_name.replace(key, value)
|
| 74 |
+
old_key_name = "model." + key_name
|
| 75 |
+
elif "conv1" in key_name or "conv2" in key_name:
|
| 76 |
+
for key, value in REP_converter.items():
|
| 77 |
+
if key in key_name:
|
| 78 |
+
key_name = key_name.replace(key, value)
|
| 79 |
+
old_key_name = "model." + key_name
|
| 80 |
+
map_weight.append(old_key_name)
|
| 81 |
+
assert old_key_name in old_state_dict.keys(), f"Weight Name Mismatch!! {old_key_name}"
|
| 82 |
+
old_shape = old_state_dict[old_key_name].shape
|
| 83 |
+
assert new_shape == old_shape, "Weight Shape Mismatch!! {old_key_name}"
|
| 84 |
+
new_state_dict[new_key_name] = old_state_dict[old_key_name]
|
| 85 |
+
return new_state_dict
|