import os
import re
from os.path import join

import torch


def get_latest(name, checkpoint_dir, extra_args=None):
    '''Return the newest checkpoint (highest "step=" in the filename) under checkpoint_dir/name.'''
    if extra_args is None:
        extra_args = dict()
    files = os.listdir(join(checkpoint_dir, name))
    steps = torch.tensor([int(f.split("step=")[-1].split(".")[0]) for f in files])
    selected = files[steps.argmax()]
    return dict(
        chkpt_name=os.path.join(name, selected),
        extra_args=extra_args)
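

# Usage sketch (the run name and checkpoint directory below are hypothetical):
#
#   info = get_latest("9-5-23-mixed/mixed/train", "../checkpoints", extra_args=dict(loss_leak=0.0))
#   info["chkpt_name"]  # e.g. "9-5-23-mixed/mixed/train/epoch=4-step=39150.ckpt"
#
# The step is parsed from the text between "step=" and the first ".", so filenames are
# expected to look like "epoch=<e>-step=<s>.ckpt".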


DS_PARAM_REGEX = r'_forward_module\.(.+)'


def convert_deepspeed_checkpoint(deepspeed_ckpt_path: str, pl_ckpt_path: str = None):
    '''
    Creates a PyTorch Lightning checkpoint from the DeepSpeed checkpoint directory, while patching
    in parameters which are improperly loaded by the DeepSpeed conversion utility.

    deepspeed_ckpt_path: Path to the DeepSpeed checkpoint folder.
    pl_ckpt_path: Path to the reconstructed PyTorch Lightning checkpoint. If not specified, it will be
        placed in the same directory as the DeepSpeed checkpoint directory, with the same name but
        a .pt extension.

    Returns: the path to the converted checkpoint.
    '''
    from pytorch_lightning.utilities.deepspeed import convert_zero_checkpoint_to_fp32_state_dict

    if not (deepspeed_ckpt_path.endswith('.ckpt') and os.path.isdir(deepspeed_ckpt_path)):
        raise ValueError(
            'deepspeed_ckpt_path should point to the checkpoint directory'
            ' output by DeepSpeed (e.g. "last.ckpt" or "epoch=4-step=39150.ckpt").'
        )

    # Convert the DeepSpeed state dict to a single fp32 PyTorch checkpoint
    if not pl_ckpt_path:
        pl_ckpt_path = f'{deepspeed_ckpt_path[:-4]}pt'  # .ckpt --> .pt
    if not os.path.exists(pl_ckpt_path):
        convert_zero_checkpoint_to_fp32_state_dict(deepspeed_ckpt_path, pl_ckpt_path)

    # Patch in missing parameters that failed to be converted by the DeepSpeed utility
    pl_ckpt = _merge_deepspeed_weights(deepspeed_ckpt_path, pl_ckpt_path)
    torch.save(pl_ckpt, pl_ckpt_path)

    return pl_ckpt_path
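

# Usage sketch (the path below is hypothetical): convert a DeepSpeed ZeRO checkpoint directory
# into a single fp32 checkpoint file that torch.load can read directly.
#
#   pl_path = convert_deepspeed_checkpoint("../checkpoints/my_run/last.ckpt")
#   ckpt = torch.load(pl_path, map_location="cpu")
#   state_dict = ckpt["state_dict"]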


def get_optim_files(checkpoint_dir):
    files = sorted([f for f in os.listdir(checkpoint_dir) if "optim" in f])
    return [join(checkpoint_dir, f) for f in files]


def get_model_state_file(checkpoint_dir, zero_stage):
    f = [f for f in os.listdir(checkpoint_dir) if "model_states" in f][0]
    return join(checkpoint_dir, f)


def _merge_deepspeed_weights(deepspeed_ckpt_path: str, fp32_ckpt_path: str):
    '''
    Merges tensors whose keys are in the DeepSpeed checkpoint but not in the fp32 checkpoint
    into the fp32 state dict.

    deepspeed_ckpt_path: Path to the DeepSpeed checkpoint folder.
    fp32_ckpt_path: Path to the reconstructed fp32 PyTorch Lightning checkpoint.
    '''
    from pytorch_lightning.utilities.deepspeed import ds_checkpoint_dir

    # This first part is based on pytorch_lightning.utilities.deepspeed.convert_zero_checkpoint_to_fp32_state_dict
    checkpoint_dir = ds_checkpoint_dir(deepspeed_ckpt_path)
    optim_files = get_optim_files(checkpoint_dir)
    optim_state = torch.load(optim_files[0], map_location='cpu')
    zero_stage = optim_state["optimizer_state_dict"]["zero_stage"]
    deepspeed_model_file = get_model_state_file(checkpoint_dir, zero_stage)

    # Start adding all parameters from the DeepSpeed ckpt to the generated PyTorch Lightning ckpt
    ds_ckpt = torch.load(deepspeed_model_file, map_location='cpu')
    ds_sd = ds_ckpt['module']

    fp32_ckpt = torch.load(fp32_ckpt_path, map_location='cpu')
    fp32_sd = fp32_ckpt['state_dict']

    for k, v in ds_sd.items():
        # DeepSpeed prefixes parameter names with "_forward_module."; strip it to recover
        # the name used in the PyTorch Lightning state dict.
        match = re.match(DS_PARAM_REGEX, k)
        if match is None:
            print(f'Failed to extract parameter from DeepSpeed key {k}')
            continue
        param_name = match.group(1)

        v = v.to(torch.float32)
        if param_name not in fp32_sd:
            print(f'Adding parameter {param_name} from DeepSpeed state_dict to fp32_sd')
            fp32_sd[param_name] = v
        else:
            assert torch.allclose(v, fp32_sd[param_name].to(torch.float32), atol=1e-2)

    return fp32_ckpt


def get_version_and_step(f, i):
    # Parse "step=<s>" and an optional "-v<version>" suffix from a checkpoint filename.
    step = f.split("step=")[-1].split(".")[0]
    if "-v" in step:
        step, version = step.split("-v")
    else:
        version = 0
    return int(version), int(step), i
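

# Parsing sketch: the "-v<n>" suffix that Lightning appends on filename collisions is treated
# as a version, so newer versions of the same step sort later.
#
#   get_version_and_step("epoch=4-step=39150.ckpt", 0)     ->  (0, 39150, 0)
#   get_version_and_step("epoch=4-step=39150-v1.ckpt", 1)  ->  (1, 39150, 1)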


def get_latest_ds(name, extra_args=None):
    '''
    Finds the newest DeepSpeed checkpoint under ../checkpoints/<name>, converts it to a
    regular fp32 checkpoint on first use, and returns the converted checkpoint's name.
    '''
    if extra_args is None:
        extra_args = dict()
    files = os.listdir(f"../checkpoints/{name}")
    latest = sorted([get_version_and_step(f, i) for i, f in enumerate(files)], reverse=True)[0]
    selected = files[latest[-1]]
    # print(f"Selecting file: {selected}")
    ds_chkpt = join(name, selected)
    reg_chkpt = join(name + "_fp32", selected)
    reg_chkpt_path = join("../checkpoints", reg_chkpt)

    if not os.path.exists(reg_chkpt_path):
        os.makedirs(os.path.dirname(reg_chkpt_path), exist_ok=True)
        print(f"Checkpoint {reg_chkpt} does not exist, converting from deepspeed")
        convert_deepspeed_checkpoint(join("../checkpoints", ds_chkpt), reg_chkpt_path)

    return dict(
        chkpt_name=reg_chkpt,
        extra_args=extra_args)
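

# Usage sketch (the run name is hypothetical; checkpoints are expected under "../checkpoints"):
#
#   info = get_latest_ds("11-8-23/my_run/train", extra_args=dict(loss_leak=0.0))
#   info["chkpt_name"]  # e.g. "11-8-23/my_run/train_fp32/epoch=4-step=39150.ckpt"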


def get_all_models_in_dir(name, checkpoint_dir, extra_args=None):
    ret = {}
    for model_dir in os.listdir(join(checkpoint_dir, name)):
        full_name = f"{name}/{model_dir}/train"
        # print(f'"{full_name}",')
        ret[full_name] = get_latest(full_name, checkpoint_dir, extra_args)
    return ret


def saved_model_dict(checkpoint_dir):
    model_info = {
        **get_all_models_in_dir(
            "9-5-23-mixed",
            checkpoint_dir,
            extra_args=dict(
                mixup_weight=0.0,
                sim_use_cls=False,
                audio_pool_width=1,
                memory_buffer_size=0,
                loss_leak=0.0)
        ),
        **get_all_models_in_dir(
            "1-23-24-rebuttal-heads",
            checkpoint_dir,
            extra_args=dict(
                loss_leak=0.0)
        ),
        **get_all_models_in_dir(
            "11-8-23",
            checkpoint_dir,
            extra_args=dict(loss_leak=0.0)),
        **get_all_models_in_dir(
            "10-30-23-3",
            checkpoint_dir,
            extra_args=dict(loss_leak=0.0)),
        "davenet": dict(
            chkpt_name=None,
            extra_args=dict(
                audio_blur=1,
                image_model_type="davenet",
                image_aligner_type=None,
                audio_model_type="davenet",
                audio_aligner_type=None,
                audio_input="davenet_spec",
                use_cached_embs=False,
                dropout=False,
                sim_agg_heads=1,
                nonneg_sim=False,
                audio_lora=False,
                image_lora=False,
                norm_vectors=False,
            ),
            data_args=dict(
                use_cached_embs=False,
                use_davenet_spec=True,
                override_target_length=20,
                audio_model_type="davenet",
            ),
        ),
        "cavmae": dict(
            chkpt_name=None,
            extra_args=dict(
                audio_blur=1,
                image_model_type="cavmae",
                image_aligner_type=None,
                audio_model_type="cavmae",
                audio_aligner_type=None,
                audio_input="spec",
                use_cached_embs=False,
                sim_agg_heads=1,
                dropout=False,
                nonneg_sim=False,
                audio_lora=False,
                image_lora=False,
                norm_vectors=False,
                learn_audio_cls=False,
                sim_agg_type="cavmae",
            ),
            data_args=dict(
                use_cached_embs=False,
                use_davenet_spec=True,
                audio_model_type="cavmae",
                override_target_length=10,
            ),
        ),
        "imagebind": dict(
            chkpt_name=None,
            extra_args=dict(
                audio_blur=1,
                image_model_type="imagebind",
                image_aligner_type=None,
                audio_model_type="imagebind",
                audio_aligner_type=None,
                audio_input="spec",
                use_cached_embs=False,
                sim_agg_heads=1,
                dropout=False,
                nonneg_sim=False,
                audio_lora=False,
                image_lora=False,
                norm_vectors=False,
                learn_audio_cls=False,
                sim_agg_type="imagebind",
            ),
            data_args=dict(
                use_cached_embs=False,
                use_davenet_spec=True,
                audio_model_type="imagebind",
                override_target_length=10,
            ),
        ),
    }

    model_info["denseav_language"] = model_info["10-30-23-3/places_base/train"]
    model_info["denseav_sound"] = model_info["11-8-23/hubert_1h_asf_cls_full_image_train_small_lr/train"]
    model_info["denseav_2head"] = model_info["1-23-24-rebuttal-heads/mixed-2h/train"]
    return model_info
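

if __name__ == "__main__":
    # Minimal usage sketch: list the registered models and inspect one entry.
    # Assumes a local "../checkpoints" directory containing the run folders referenced above.
    models = saved_model_dict("../checkpoints")
    print(sorted(models.keys()))
    entry = models["denseav_sound"]
    print(entry["chkpt_name"], entry["extra_args"])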