Spaces:
Runtime error
Runtime error
| """Copyright PolyAI Limited.""" | |
| import logging | |
| import pdb | |
| import sys | |
| import traceback | |
| from functools import wraps | |
| from time import time | |
| from typing import List | |
| import torch | |
| from .symbol_table import SymbolTable | |
def load_checkpoint(ckpt_path: str) -> dict:
    """Load a checkpoint's ``state_dict``, dropping every vocoder parameter.

    The checkpoint is loaded onto the CPU and its top-level ``"state_dict"``
    entry is read. Every entry whose name starts with ``"vocoder"`` is
    discarded; all remaining entries are returned unchanged, in their
    original order.

    Args:
        ckpt_path: Path to a file readable by ``torch.load`` that contains a
            top-level ``"state_dict"`` mapping.

    Returns:
        A new dict mapping parameter names to their stored values, without
        any ``vocoder*`` entries.

    Raises:
        KeyError: If the loaded checkpoint has no ``"state_dict"`` entry.
    """
    # NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
    state_dict: dict = torch.load(ckpt_path, map_location="cpu")["state_dict"]
    # Vocoder weights are filtered out here; presumably they are loaded or
    # handled separately by the caller -- confirm.
    return {
        name: value
        for name, value in state_dict.items()
        if not name.startswith("vocoder")
    }
def breakpoint_on_error(fn):
    """Wrap ``fn`` so that an uncaught exception opens a post-mortem debugger.

    Intended to be used as a decorator while debugging. On success, the
    wrapped call's result is returned unchanged. If ``fn`` raises an
    ``Exception``, the wrapper does three things in order:

    1. Prints the exception type and value.
    2. Prints the traceback.
    3. Starts ``pdb.post_mortem`` on the failing traceback.

    Args:
        fn: The callable to wrap.

    Returns:
        The wrapper function. It preserves ``fn``'s name, docstring and other
        metadata via ``functools.wraps``.

    Note:
        The exception is not re-raised. Once the debugger session ends, the
        wrapper returns ``None``.
    """

    @wraps(fn)
    def inner(*args, **kwargs):
        try:
            return fn(*args, **kwargs)
        except Exception:
            # Standard Python way of creating a breakpoint on error.
            extype, value, tb = sys.exc_info()
            print(f"extype={extype},\nvalue={value}")
            traceback.print_exc()
            pdb.post_mortem(tb)

    return inner
def measure_duration(f):
    """Decorator that logs how long each call to ``f`` takes.

    The elapsed time is logged at DEBUG level on the root logger, formatted
    as ``func:'<name>' took: <seconds> sec``. The wrapped function's return
    value is passed through unchanged. If ``f`` raises, the exception
    propagates and nothing is logged.

    Args:
        f: The callable to time.

    Returns:
        A wrapper that preserves ``f``'s name and docstring.
    """
    # perf_counter is monotonic and high-resolution, so unlike time() it is
    # not affected by system clock adjustments.
    from time import perf_counter

    @wraps(f)
    def wrap(*args, **kw):
        ts = perf_counter()
        result = f(*args, **kw)
        te = perf_counter()
        # Pass args lazily so the message is only formatted when DEBUG is on.
        logging.debug("func:%r took: %2.4f sec", f.__name__, te - ts)
        return result

    return wrap
def split_metapath(in_paths: List[str]) -> List[str]:
    """Return the given paths as a new list.

    Despite its name, this function performs no splitting. Each input path
    is kept as-is and in order. The result is a shallow copy, so mutating it
    does not affect ``in_paths``.

    Args:
        in_paths: The paths to copy. Any iterable is accepted.

    Returns:
        A new list containing the same path objects, in the same order.
    """
    return list(in_paths)