import transformers
import torch
import torch.nn as nn
import re
import logging

from nn import FixableDropout
from utils import scr
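
# NOTE: `nn` and `utils` are local project modules, not installed packages.
# `FixableDropout` is assumed to be a dropout layer whose mask can be fixed
# and explicitly resampled (see `replace_dropout` below); `scr()` is assumed
# to return the scratch/cache directory used for model downloads.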

LOG = logging.getLogger(__name__)


class CastModule(nn.Module):
    def __init__(self, module: nn.Module, in_cast: torch.dtype = torch.float32, out_cast: torch.dtype = None):
        super().__init__()

        self.underlying = module
        self.in_cast = in_cast
        self.out_cast = out_cast

    def cast(self, obj, dtype):
        if dtype is None:
            return obj

        if isinstance(obj, torch.Tensor):
            return obj.to(dtype)
        else:
            return obj

    def forward(self, *args, **kwargs):
        args = tuple(self.cast(a, self.in_cast) for a in args)
        kwargs = {k: self.cast(v, self.in_cast) for k, v in kwargs.items()}
        outputs = self.underlying(*args, **kwargs)
        if isinstance(outputs, torch.Tensor):
            outputs = self.cast(outputs, self.out_cast)
        elif isinstance(outputs, tuple):
            outputs = tuple(self.cast(o, self.out_cast) for o in outputs)
        else:
            raise RuntimeError(f"Not sure how to cast type {type(outputs)}")
        return outputs

    def extra_repr(self):
        return f"in_cast: {self.in_cast}\nout_cast: {self.out_cast}"


class BertClassifier(torch.nn.Module):
    def __init__(self, model_name, hidden_dim=768):
        super().__init__()
        if model_name.startswith("bert"):
            self.model = transformers.BertModel.from_pretrained(model_name, cache_dir=scr())
        else:
            self.model = transformers.AutoModel.from_pretrained(model_name, cache_dir=scr())
        self.classifier = torch.nn.Linear(hidden_dim, 1)

    def config(self):
        return self.model.config

    def forward(self, *args, **kwargs):
        filtered_kwargs = {k: v for k, v in kwargs.items() if k != "labels"}
        model_output = self.model(*args, **filtered_kwargs)
        if "pooler_output" in model_output.keys():
            pred = self.classifier(model_output.pooler_output)
        else:
            # Fall back to the hidden state of the first ([CLS]) token.
            pred = self.classifier(model_output.last_hidden_state[:, 0])

        if "output_hidden_states" in kwargs and kwargs["output_hidden_states"]:
            last_hidden_state = model_output.last_hidden_state
            return pred, last_hidden_state
        else:
            return pred
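
# Illustrative usage (assumes the standard transformers tokenizer API; not
# part of the original file):
#
#   tok = transformers.AutoTokenizer.from_pretrained("bert-base-uncased")
#   batch = tok(["an example sentence"], return_tensors="pt")
#   logits = BertClassifier("bert-base-uncased")(**batch)  # shape (1, 1)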


def replace_dropout(model):
    for m in model.modules():
        for n, c in m.named_children():
            if isinstance(c, nn.Dropout):
                setattr(m, n, FixableDropout(c.p))

    def resample(m, seed=None):
        for c in m.children():
            if hasattr(c, "resample"):
                c.resample(seed)
            else:
                resample(c, seed)

    model.resample_dropout = resample.__get__(model)
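
# A minimal sketch of the interface `FixableDropout` is assumed to provide
# (the real class lives in the local `nn` module and is not shown here):
#
#   class FixableDropout(nn.Module):
#       def __init__(self, p): ...          # keep the original dropout prob
#       def resample(self, seed=None): ...  # draw and cache a fresh mask
#       def forward(self, x): ...           # apply the cached mask to x
#
# `replace_dropout` swaps every `nn.Dropout` for a `FixableDropout` and binds
# `resample` to the model as `model.resample_dropout`, so a single call
# refreshes the masks of all dropout layers at once.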


def get_model(config):
    if config.model.class_name == "BertClassifier":
        model = BertClassifier(config.model.name)
    else:
        ModelClass = getattr(transformers, config.model.class_name)
        LOG.info(f"Loading model class {ModelClass} with name {config.model.name} from cache dir {scr()}")
        model = ModelClass.from_pretrained(config.model.name, cache_dir=scr())

    if config.model.pt is not None:
        LOG.info(f"Loading model initialization from {config.model.pt}")
        state_dict = torch.load(config.model.pt, map_location="cpu")

        try:
            model.load_state_dict(state_dict)
        except RuntimeError:
            LOG.info("Default load failed; stripping prefix and trying again.")
            # Escape the dot so only the literal "model." prefix is stripped.
            state_dict = {re.sub(r"^model\.", "", k): v for k, v in state_dict.items()}
            model.load_state_dict(state_dict)

        LOG.info("Loaded model initialization")

    if config.dropout is not None:
        n_reset = 0
        for m in model.modules():
            if isinstance(m, nn.Dropout):
                m.p = config.dropout
                n_reset += 1

            if hasattr(m, "dropout"):  # Required for BART, which stores the prob as a float and uses F.dropout
                if isinstance(m.dropout, float):
                    m.dropout = config.dropout
                    n_reset += 1

            if hasattr(m, "activation_dropout"):  # Required for BART, which stores the prob as a float and uses F.dropout
                if isinstance(m.activation_dropout, float):
                    m.activation_dropout = config.dropout
                    n_reset += 1

        LOG.info(f"Set {n_reset} dropout modules to p={config.dropout}")

    param_names = [n for n, _ in model.named_parameters()]
    bad_inner_params = [p for p in config.model.inner_params if p not in param_names]
    if len(bad_inner_params) != 0:
        raise ValueError(f"Params {bad_inner_params} do not exist in model of type {type(model)}.")

    if config.no_grad_layers is not None:
        if config.half:
            model.bfloat16()

        def upcast(mod):
            modlist = None
            for child in mod.children():
                if isinstance(child, nn.ModuleList):
                    assert modlist is None, f"Found multiple modlists for {mod}"
                    modlist = child

            if modlist is None:
                raise RuntimeError("Couldn't find a ModuleList child")

            LOG.info(f"Setting {len(modlist) - config.no_grad_layers} modules to full precision, with autocasting")
            modlist[config.no_grad_layers:].to(torch.float32)
            modlist[config.no_grad_layers] = CastModule(modlist[config.no_grad_layers])
            modlist[-1] = CastModule(modlist[-1], in_cast=torch.float32, out_cast=torch.bfloat16)
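
        # Assumed layout after `upcast(mod)`, for a 12-layer ModuleList with
        # no_grad_layers=6 (derived from the code above):
        #
        #   layers[0:6]   stay bfloat16 (the no-grad prefix)
        #   layers[6:12]  moved to float32; layers[6] is wrapped in a CastModule
        #                 that upcasts its bfloat16 inputs to float32
        #   layers[11]    wrapped in a CastModule that casts the float32
        #                 outputs back down to bfloat16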

        parents = []
        if hasattr(model, "transformer"):
            parents.append(model.transformer)
        if hasattr(model, "encoder"):
            parents.append(model.encoder)
        if hasattr(model, "decoder"):
            parents.append(model.decoder)
        if hasattr(model, "model"):
            parents.extend([model.model.encoder, model.model.decoder])

        for t in parents:
            t.no_grad_layers = config.no_grad_layers
            if config.half and config.alg != "rep":
                upcast(t)

        if config.half and config.alg != "rep":
            # Wrapping a layer in CastModule renames its parameters
            # (e.g. "layer.11.weight" -> "layer.11.underlying.weight"), so the
            # affected inner_params paths need "underlying" spliced in.
            idxs = []
            for p in config.model.inner_params:
                for comp in p.split('.'):
                    if comp.isdigit():
                        idxs.append(int(comp))
            max_idx, min_idx = str(max(idxs)), str(config.no_grad_layers)

            for pidx, p in enumerate(config.model.inner_params):
                comps = p.split('.')
                if max_idx in comps or min_idx in comps:
                    index = comps.index(max_idx) if max_idx in comps else comps.index(min_idx)
                    comps.insert(index + 1, 'underlying')
                    new_p = '.'.join(comps)
                    LOG.info(f"Replacing config.model.inner_params[{pidx}] '{p}' -> '{new_p}'")
                    config.model.inner_params[pidx] = new_p

    return model
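
# Fields `get_model` reads from its `config` argument (collected from the code
# above; the concrete config object comes from the project's config system):
#   config.model: class_name, name, pt, inner_params
#   config:       dropout, no_grad_layers, half, alg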


def get_tokenizer(config):
    tok_name = config.model.tokenizer_name if config.model.tokenizer_name is not None else config.model.name
    return getattr(transformers, config.model.tokenizer_class).from_pretrained(tok_name, cache_dir=scr())


if __name__ == '__main__':
    # Smoke test: push a dummy batch of token ids through the classifier.
    m = BertClassifier("bert-base-uncased")
    m(torch.arange(5)[None, :])
    import pdb; pdb.set_trace()