import torch
import torch.nn.functional as F
import math
from einops import rearrange
# Return a boolean mask of shape (b, max_len): True at valid positions, False at padding
def lengths_to_mask(lengths, max_len):
    # max_len = max(lengths)
    mask = torch.arange(max_len, device=lengths.device).expand(len(lengths), max_len) < lengths.unsqueeze(1)
    return mask  # (b, len)
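# Example (illustrative values):
#   lengths_to_mask(torch.tensor([2, 4]), max_len=4)
#   -> tensor([[ True,  True, False, False],
#              [ True,  True,  True,  True]])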
# Return a boolean mask of shape (b, 1, s): True where seq is not the pad index, False at padding
def get_pad_mask_idx(seq, pad_idx):
    return (seq != pad_idx).unsqueeze(1)
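# Example (illustrative, pad_idx=0):
#   get_pad_mask_idx(torch.tensor([[5, 3, 0, 0]]), pad_idx=0)
#   -> tensor([[[ True,  True, False, False]]])   # shape (b, 1, s)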
# Given seq: (b, s)
# Return mat: (1, s, s)
# Example Output:
# [[[ True, False, False],
#   [ True,  True, False],
#   [ True,  True,  True]]]
# For causal attention
def get_subsequent_mask(seq):
    sz_b, seq_len = seq.shape
    subsequent_mask = (1 - torch.triu(
        torch.ones((1, seq_len, seq_len)), diagonal=1)).bool()
    return subsequent_mask.to(seq.device)
# Small helpers
def exists(val):
    return val is not None
def default(val, d):
    return val if exists(val) else d
# Run the wrapped method with the model in eval mode, then restore the previous training state
def eval_decorator(fn):
    def inner(model, *args, **kwargs):
        was_training = model.training
        model.eval()
        out = fn(model, *args, **kwargs)
        model.train(was_training)
        return out
    return inner
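# Usage sketch (hypothetical method name, not defined in this file): decorate a
# sampling routine so it always runs in eval mode and the model's previous
# training state is restored afterwards, e.g.
#   @eval_decorator
#   @torch.no_grad()
#   def generate(self, ...):
#       ...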
def l2norm(t):
    return F.normalize(t, dim = -1)
# tensor helpers
# Randomly keep each True position of `mask` with probability `prob`; False positions stay False
def get_mask_subset_prob(mask, prob):
    # each entry is kept with probability `prob`; `& mask` ensures only originally-True positions survive
    subset_mask = torch.bernoulli(mask, p=prob) & mask
    return subset_mask
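# Example (illustrative; the result is random, one possible draw shown):
#   get_mask_subset_prob(torch.tensor([[True, True, True, False]]), prob=0.5)
#   -> tensor([[ True, False,  True, False]])   # False (padding) positions are never selected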
# Return a boolean mask that is True wherever `ids` equals any id in `special_ids`
def get_mask_special_tokens(ids, special_ids):
    mask = torch.zeros_like(ids).bool()
    for special_id in special_ids:
        mask |= (ids == special_id)
    return mask
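# Example (illustrative ids):
#   get_mask_special_tokens(torch.tensor([[7, 1, 4, 2]]), special_ids=[1, 2])
#   -> tensor([[False,  True, False,  True]])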
# network builder helpers
def _get_activation_fn(activation):
    if activation == "relu":
        return F.relu
    elif activation == "gelu":
        return F.gelu
    raise RuntimeError("activation should be relu/gelu, not {}".format(activation))
# classifier free guidance functions
def uniform(shape, device=None):
    return torch.zeros(shape, device=device).float().uniform_(0, 1)
# Boolean mask of `shape` where each position is independently True with probability `prob`
def prob_mask_like(shape, prob, device=None):
    if prob == 1:
        return torch.ones(shape, device=device, dtype=torch.bool)
    elif prob == 0:
        return torch.zeros(shape, device=device, dtype=torch.bool)
    else:
        return uniform(shape, device=device) < prob
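# Example (illustrative; intermediate draws are random):
#   prob_mask_like((2, 3), prob=0.7)   # each entry True with probability 0.7
#   prob_mask_like((2, 3), prob=1)     # all True; prob=0 gives all False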
# sampling helpers
def log(t, eps = 1e-20):
    return torch.log(t.clamp(min = eps))
def gumbel_noise(t):
    noise = torch.zeros_like(t).uniform_(0, 1)
    return -log(-log(noise))
# Sample indices along `dim` from logits `t` via the Gumbel-max trick
def gumbel_sample(t, temperature = 1., dim = 1):
    return ((t / max(temperature, 1e-10)) + gumbel_noise(t)).argmax(dim=dim)
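# Example (illustrative; sampling is stochastic at temperature > 0):
#   logits = torch.tensor([[2.0, 0.5], [0.1, 3.0]])
#   gumbel_sample(logits, temperature=1., dim=1)      # likely tensor([0, 1])
#   gumbel_sample(logits, temperature=1e-10, dim=1)   # effectively argmax: tensor([0, 1])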
# Example input:
# [[ 0.3596,  0.0862,  0.9771, -1.0000, -1.0000, -1.0000],
#  [ 0.4141,  0.1781,  0.6628,  0.5721, -1.0000, -1.0000],
#  [ 0.9428,  0.3586,  0.1659,  0.8172,  0.9273, -1.0000]]
# Example output:
# [[  -inf,   -inf, 0.9771,   -inf,   -inf,   -inf],
#  [  -inf,   -inf, 0.6628,   -inf,   -inf,   -inf],
#  [0.9428,   -inf,   -inf,   -inf,   -inf,   -inf]]
# Keep only the top ceil((1 - thres) * n) logits along `dim`; set the rest to -inf
def top_k(logits, thres = 0.9, dim = 1):
    k = math.ceil((1 - thres) * logits.shape[dim])
    val, ind = logits.topk(k, dim = dim)
    probs = torch.full_like(logits, float('-inf'))
    probs.scatter_(dim, ind, val)
    return probs
# noise schedules
# cos(t * pi/2): decreases from 1 at t = 0 to 0 at t = 1
def cosine_schedule(t):
    return torch.cos(t * math.pi * 0.5)
# Same shape, rescaled to the range [1 - scale, 1] and clipped to [0, 1]
def scale_cosine_schedule(t, scale):
    return torch.clip(scale * torch.cos(t * math.pi * 0.5) + 1 - scale, min=0., max=1.)
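# Worked values (assuming scale in [0, 1]):
#   cosine_schedule(t):             t = 0.0 -> 1.0,  t = 0.5 -> ~0.707,  t = 1.0 -> 0.0
#   scale_cosine_schedule(t, 0.5):  t = 0.0 -> 1.0,  t = 1.0 -> 0.5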
# Sample integers in [low, high - 1], biased toward `low` by the 1 - cosine schedule
def q_schedule(bs, low, high, device):
    noise = uniform((bs,), device=device)
    schedule = 1 - cosine_schedule(noise)
    return torch.round(schedule * (high - low - 1)).long() + low
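# Example (illustrative; output is random): q_schedule(4, low=0, high=10, device='cpu')
# might return tensor([1, 0, 3, 2]) -- values always fall in [0, 9], skewed toward 0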
# Returns (loss, top-1 prediction ids, top-`tk` accuracy over non-ignored positions)
def cal_performance(pred, labels, ignore_index=None, smoothing=0., tk=1):
    loss = cal_loss(pred, labels, ignore_index, smoothing=smoothing)
    pred_id_k = torch.topk(pred, k=tk, dim=1).indices
    pred_id = pred_id_k[:, 0]
    mask = labels.ne(ignore_index)
    n_correct = (pred_id_k == labels.unsqueeze(1)).any(dim=1).masked_select(mask)
    acc = torch.mean(n_correct.float()).item()
    return loss, pred_id, acc
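# Usage sketch (assumed shapes, matching the shape comments in cal_loss below;
# `pad_id` is a hypothetical name for whatever index marks padding):
#   pred:   (b, n_class, seq) raw logits
#   labels: (b, seq) target indices, with `ignore_index` at padded positions
#   loss, pred_id, acc = cal_performance(pred, labels, ignore_index=pad_id, smoothing=0.1, tk=5)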
def cal_loss(pred, labels, ignore_index=None, smoothing=0.):
    '''Calculate cross entropy loss, apply label smoothing if needed.'''
    # pred: (b, n_class, seq) logits, labels: (b, seq) indices
    # e.g. torch.Size([64, 1028, 55]) and torch.Size([64, 55])
    if smoothing:
        # `space` widens the one-hot depth so labels holding ids just past n_class
        # (e.g. a pad/ignore token) do not break F.one_hot; the extra columns are
        # dropped again by the [:, :n_class] slice
        space = 2
        n_class = pred.size(1)
        mask = labels.ne(ignore_index)
        one_hot = rearrange(F.one_hot(labels, n_class + space), 'a ... b -> a b ...')[:, :n_class]
        sm_one_hot = one_hot * (1 - smoothing) + (1 - one_hot) * smoothing / (n_class - 1)
        neg_log_prb = -F.log_softmax(pred, dim=1)
        loss = (sm_one_hot * neg_log_prb).sum(dim=1)
        # average over non-ignored positions only
        loss = torch.mean(loss.masked_select(mask))
    else:
        loss = F.cross_entropy(pred, labels, ignore_index=ignore_index)
    return loss
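# Usage sketch (assumed shapes, consistent with the comments above):
#   pred = torch.randn(2, 8, 5)             # (b, n_class, seq) logits
#   labels = torch.randint(0, 8, (2, 5))    # (b, seq) targets
#   loss = cal_loss(pred, labels, ignore_index=8, smoothing=0.1)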