Spaces:
Running
Running
def convert_weight(old_state_dict, new_state_dict, model_size: int = 38):
    """Copy weights from ``old_state_dict`` into ``new_state_dict`` layer by layer.

    For every layer index ``idx`` in ``range(model_size)`` the new keys whose
    first dotted component equals ``idx`` are paired with the old entries whose
    name contains ``model.{idx + shift}.``.  When both groups have the same
    size the values are copied positionally.  Otherwise the old entries are
    treated as head convolutions (``cv2`` .. ``cv5``) and written under
    ``<layer>.heads.<conv_idx>.<anchor_conv|class_conv>.<rest>``; entries whose
    name contains ``dfl`` are skipped.

    Args:
        old_state_dict: Mapping of old weight names to values.
        new_state_dict: Mapping of new weight names to values; updated in place.
        model_size: Number of layer indices to walk.

    Returns:
        The same ``new_state_dict`` object, after the update.
    """
    # TODO: need to refactor
    shift = 1
    for idx in range(model_size):
        layer_tag = str(idx)
        old_marker = f"model.{idx+shift}."
        new_names = [name for name in new_state_dict if name.split(".")[0] == layer_tag]
        old_entries = [(name, tensor) for name, tensor in old_state_dict.items() if old_marker in name]

        if len(new_names) == len(old_entries):
            # Same number of tensors on both sides: pair them by position.
            for new_name, (_, tensor) in zip(new_names, old_entries):
                new_state_dict[new_name] = tensor
            continue

        # Size mismatch: rebuild head weight names from the old naming scheme.
        for old_name, tensor in old_entries:
            if "dfl" in old_name:
                continue
            _, _, conv_name, conv_idx, *details = old_name.split(".")
            if conv_name in ("cv4", "cv5"):
                layer_idx = 22
                # From here on the old layer numbering is offset by two.
                shift = 2
            else:
                layer_idx = 37
            if conv_name in ("cv2", "cv4"):
                conv_task = "anchor_conv"
            if conv_name in ("cv3", "cv5"):
                conv_task = "class_conv"
            head_name = ".".join([str(layer_idx), "heads", conv_idx, conv_task, *details])
            new_state_dict[head_name] = tensor
    return new_state_dict
# Name-translation tables read by convert_weight_v7: each maps a fragment of a
# new-style weight name to the fragment used in the old checkpoint.

# Head sub-module names.
head_converter = dict(head_conv="m", implicit_a="ia", implicit_m="im")

# Fragments rewritten for keys containing pre_conv/post_conv/short_conv/merge_conv.
SPP_converter = dict(
    [
        ("pre_conv.0", "cv1"),
        ("pre_conv.1", "cv3"),
        ("pre_conv.2", "cv4"),
        ("post_conv.0", "cv5"),
        ("post_conv.1", "cv6"),
        ("short_conv", "cv2"),
        ("merge_conv", "cv7"),
    ]
)

# Fragments rewritten for keys containing conv1/conv2.  Insertion order
# matters: "conv1"/"conv2" must be replaced before the bare "conv" entry.
REP_converter = {
    "conv1": "rbr_dense",
    "conv2": "rbr_1x1",
    "conv": "0",
    "bn": "1",
}
def convert_weight_v7(old_state_dict, new_state_dict):
    """Fill ``new_state_dict`` with the matching tensors from ``old_state_dict``.

    Each new key is first looked up as ``"model." + key``.  If that name is not
    in the old checkpoint, the key is translated with the module-level tables:
    ``head_converter`` for keys containing ``heads``, ``SPP_converter`` for keys
    containing ``pre_conv``/``post_conv``/``short_conv``/``merge_conv``, and
    ``REP_converter`` for keys containing ``conv1``/``conv2``.

    Args:
        old_state_dict: Mapping of old weight names to objects with ``.shape``.
        new_state_dict: Mapping of new weight names to objects with ``.shape``;
            its values are overwritten in place.

    Returns:
        The same ``new_state_dict`` object, after the update.

    Raises:
        AssertionError: If a translated name is missing from ``old_state_dict``
            or the old and new shapes differ.
        KeyError: If a ``heads`` key uses a sub-module name that is not in
            ``head_converter``.
    """
    spp_tags = ("pre_conv", "post_conv", "short_conv", "merge_conv")
    # Iterate over a snapshot of the keys; values are replaced inside the loop.
    for new_key_name in list(new_state_dict):
        new_shape = new_state_dict[new_key_name].shape
        key_name = new_key_name
        old_key_name = "model." + key_name
        if old_key_name not in old_state_dict:
            if "heads" in key_name:
                layer_idx, _, conv_idx, conv_name, *details = key_name.split(".")
                old_key_name = ".".join(["model", str(layer_idx), head_converter[conv_name], conv_idx, *details])
            elif any(tag in key_name for tag in spp_tags):
                for key, value in SPP_converter.items():
                    if key in key_name:
                        key_name = key_name.replace(key, value)
                old_key_name = "model." + key_name
            elif "conv1" in key_name or "conv2" in key_name:
                for key, value in REP_converter.items():
                    if key in key_name:
                        key_name = key_name.replace(key, value)
                old_key_name = "model." + key_name
        assert old_key_name in old_state_dict, f"Weight Name Mismatch!! {old_key_name}"
        old_shape = old_state_dict[old_key_name].shape
        # The message is an f-string so the offending key is actually reported.
        assert new_shape == old_shape, f"Weight Shape Mismatch!! {old_key_name}"
        new_state_dict[new_key_name] = old_state_dict[old_key_name]
    return new_state_dict
# Substring renames applied to every old weight name by convert_weight_seg.
replace_dict = {"cv": "conv", ".m.": ".bottleneck."}


def convert_weight_seg(old_state_dict, new_state_dict):
    """Copy weights from ``old_state_dict`` into ``new_state_dict``.

    Old names are expected to look like ``model.<idx>.<...>``.  The layer index
    is shifted by a running offset (``-1`` initially, ``+3`` from the first key
    with index 23, ``-19`` from the first key with index 41) and the fragments
    in ``replace_dict`` are renamed.  If the resulting name is not a key of
    ``new_state_dict``, the old name is treated as a head weight
    (``cv2`` .. ``cv7``) and mapped to
    ``model.<layer>.<heads>.<conv_idx>.<task>.<rest>``.  Names containing
    ``proto`` or ``dfl``, and head names that are not recognised, are skipped.

    A diagnostic is printed when the target key is missing from
    ``new_state_dict`` or its shape differs; the weight is stored regardless.

    Args:
        old_state_dict: Mapping of old weight names to objects with ``.shape``.
        new_state_dict: Mapping of new weight names to objects with ``.shape``;
            updated in place.

    Returns:
        The same ``new_state_dict`` object, after the update.
    """
    # Old head conv name -> (new layer index, heads container, task conv name).
    head_targets = {
        "cv2": (44, "detect.heads", "anchor_conv"),
        "cv3": (44, "detect.heads", "class_conv"),
        "cv6": (44, "heads", "mask_conv"),
        "cv4": (25, "detect.heads", "anchor_conv"),
        "cv5": (25, "detect.heads", "class_conv"),
        "cv7": (25, "heads", "mask_conv"),
    }

    diff = -1
    for old_weight_name, old_weight in old_state_dict.items():
        old_idx = int(old_weight_name.split(".")[1])
        # The offset is sticky: it stays in effect until the next switch point.
        if old_idx == 23:
            diff = 3
        elif old_idx == 41:
            diff = -19
        new_idx = old_idx + diff
        new_weight_name = old_weight_name.replace(f".{old_idx}.", f".{new_idx}.")
        for key, val in replace_dict.items():
            new_weight_name = new_weight_name.replace(key, val)

        if new_weight_name not in new_state_dict:
            _, _, conv_name, conv_idx, *details = old_weight_name.split(".")
            # These weights are not converted.  The previous proto branch built
            # a name from possibly unbound locals and then discarded it.
            if "proto" in conv_name or "dfl" in old_weight_name:
                continue
            target = head_targets.get(conv_name)
            if target is None:
                # Avoid reusing the layer/task of an earlier, unrelated key.
                print(f"skip unrecognized weight: {old_weight_name}")
                continue
            layer_idx, heads, conv_task = target
            new_weight_name = ".".join(["model", str(layer_idx), heads, conv_idx, conv_task, *details])

        if new_weight_name not in new_state_dict:
            # Report the missing target without indexing the absent key.
            print(f"new: {new_weight_name}, old: {old_weight_name}")
            print(f"<missing> {old_weight.shape}")
        elif new_state_dict[new_weight_name].shape != old_weight.shape:
            print(f"new: {new_weight_name}, old: {old_weight_name}")
            print(f"{new_state_dict[new_weight_name].shape} {old_weight.shape}")
        new_state_dict[new_weight_name] = old_weight
    return new_state_dict