from typing import List
import torch
from methods.base.model import BaseModel
import tqdm
from torch import nn
import torch.nn.functional as F
from abc import abstractmethod
from methods.elasticdnn.model.base import ElasticDNNUtil
from methods.elasticdnn.pipeline.offline.fm_to_md.base import FM_to_MD_Util
from methods.elasticdnn.pipeline.offline.fm_lora.base import FMLoRA_Util
from utils.dl.common.model import LayerActivation

class ElasticDNN_OfflineFMModel(BaseModel):
    def get_required_model_components(self) -> List[str]:
        return ['main']

    def generate_md_by_reducing_width(self, reducing_width_ratio, samples: torch.Tensor):
        pass

    def forward_to_get_task_loss(self, x, y, *args, **kwargs):
        raise NotImplementedError

    def get_feature_hook(self) -> LayerActivation:
        pass

    def get_elastic_dnn_util(self) -> ElasticDNNUtil:
        pass

    def get_lora_util(self) -> FMLoRA_Util:
        pass

    def get_task_head_params(self):
        pass

class ElasticDNN_OfflineClsFMModel(ElasticDNN_OfflineFMModel):
    def get_accuracy(self, test_loader, *args, **kwargs):
        acc = 0
        sample_num = 0

        self.to_eval_mode()

        with torch.no_grad():
            pbar = tqdm.tqdm(enumerate(test_loader), total=len(test_loader), dynamic_ncols=True, leave=False)
            for batch_index, (x, y) in pbar:
                x, y = x.to(self.device), y.to(self.device)
                output = self.infer(x)
                pred = F.softmax(output.logits, dim=1).argmax(dim=1)
                correct = torch.eq(pred, y).sum().item()

                acc += correct
                sample_num += len(y)

                pbar.set_description(f'cur_batch_total: {len(y)}, cur_batch_correct: {correct}, '
                                     f'cur_batch_acc: {(correct / len(y)):.4f}')

        acc /= sample_num
        return acc

    def infer(self, x, *args, **kwargs):
        return self.models_dict['main'](x)

import numpy as np


class StreamSegMetrics:
    """
    Stream Metrics for Semantic Segmentation Task
    """
    def __init__(self, n_classes):
        self.n_classes = n_classes
        self.confusion_matrix = np.zeros((n_classes, n_classes))

    def update(self, label_trues, label_preds):
        for lt, lp in zip(label_trues, label_preds):
            self.confusion_matrix += self._fast_hist(lt.flatten(), lp.flatten())

    @staticmethod
    def to_str(results):
        string = "\n"
        for k, v in results.items():
            if k != "Class IoU":
                string += "%s: %f\n" % (k, v)
        return string

    def _fast_hist(self, label_true, label_pred):
        # count (true, pred) label pairs into an n_classes x n_classes confusion matrix
        mask = (label_true >= 0) & (label_true < self.n_classes)
        hist = np.bincount(
            self.n_classes * label_true[mask].astype(int) + label_pred[mask],
            minlength=self.n_classes ** 2,
        ).reshape(self.n_classes, self.n_classes)
        return hist

    def get_results(self):
        """Returns accuracy score evaluation result.
            - overall accuracy
            - mean accuracy
            - mean IoU
            - frequency-weighted accuracy
        """
        hist = self.confusion_matrix
        acc = np.diag(hist).sum() / hist.sum()
        acc_cls = np.diag(hist) / hist.sum(axis=1)
        acc_cls = np.nanmean(acc_cls)
        iu = np.diag(hist) / (hist.sum(axis=1) + hist.sum(axis=0) - np.diag(hist))
        mean_iu = np.nanmean(iu)
        freq = hist.sum(axis=1) / hist.sum()
        fwavacc = (freq[freq > 0] * iu[freq > 0]).sum()
        cls_iu = dict(zip(range(self.n_classes), iu))

        return {
            "Overall Acc": acc,
            "Mean Acc": acc_cls,
            "FreqW Acc": fwavacc,
            "Mean IoU": mean_iu,
            "Class IoU": cls_iu,
        }

    def reset(self):
        self.confusion_matrix = np.zeros((self.n_classes, self.n_classes))
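
# Illustrative sketch (not used anywhere in the pipeline): how StreamSegMetrics
# accumulates a confusion matrix and derives mIoU. The toy label maps below are
# made-up values, not real dataset samples.
def _example_stream_seg_metrics():
    toy_metrics = StreamSegMetrics(n_classes=3)
    label_true = np.array([[0, 1, 2, 2]])  # one "image"; update() flattens each sample
    label_pred = np.array([[0, 1, 1, 2]])
    toy_metrics.update(label_true, label_pred)
    results = toy_metrics.get_results()
    print(StreamSegMetrics.to_str(results))  # Overall Acc, Mean Acc, FreqW Acc, Mean IoU
    return results['Mean IoU']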

class ElasticDNN_OfflineSegFMModel(ElasticDNN_OfflineFMModel):
    def __init__(self, name: str, models_dict_path: str, device: str, num_classes):
        super().__init__(name, models_dict_path, device)
        self.num_classes = num_classes

    def get_accuracy(self, test_loader, *args, **kwargs):
        device = self.device
        self.to_eval_mode()

        metrics = StreamSegMetrics(self.num_classes)
        metrics.reset()

        pbar = tqdm.tqdm(enumerate(test_loader), total=len(test_loader), leave=False, dynamic_ncols=True)
        with torch.no_grad():
            for batch_index, (x, y) in pbar:
                x, y = x.to(device, dtype=x.dtype, non_blocking=True, copy=False), \
                       y.to(device, dtype=y.dtype, non_blocking=True, copy=False)
                output = self.infer(x)
                pred = output.detach().max(dim=1)[1].cpu().numpy()
                metrics.update((y + 0).cpu().numpy(), pred)

                res = metrics.get_results()
                pbar.set_description(f'cur batch mIoU: {res["Mean IoU"]:.4f}')

        res = metrics.get_results()
        return res['Mean IoU']

    def infer(self, x, *args, **kwargs):
        return self.models_dict['main'](x)

class ElasticDNN_OfflineDetFMModel(ElasticDNN_OfflineFMModel):
    def __init__(self, name: str, models_dict_path: str, device: str, num_classes):
        super().__init__(name, models_dict_path, device)
        self.num_classes = num_classes

    def get_accuracy(self, test_loader, *args, **kwargs):
        _d = test_loader.dataset
        from data import build_dataloader
        if _d.__class__.__name__ == 'MergedDataset':
            # evaluate on each sub-dataset of a merged dataset and average the results
            datasets = _d.datasets
            test_loaders = [build_dataloader(d, test_loader.batch_size, test_loader.num_workers, False, None) for d in datasets]
            accs = [self.get_accuracy(loader) for loader in test_loaders]
            return sum(accs) / len(accs)

        model = self.models_dict['main']
        device = self.device
        model.eval()
        model = model.to(device)

        from dnns.yolov3.coco_evaluator import COCOEvaluator
        from utils.common.others import HiddenPrints
        with torch.no_grad():
            with HiddenPrints():
                evaluator = COCOEvaluator(
                    dataloader=test_loader,
                    img_size=(224, 224),
                    confthre=0.01,
                    nmsthre=0.65,
                    num_classes=self.num_classes,
                    testdev=False
                )
                res = evaluator.evaluate(model, False, False)
                map50 = res[1]
        return map50

    def infer(self, x, *args, **kwargs):
        if len(args) > 0:
            return self.models_dict['main'](x, *args)  # forward(x, label)
        return self.models_dict['main'](x)

class ElasticDNN_OfflineSenClsFMModel(ElasticDNN_OfflineFMModel):
    def get_accuracy(self, test_loader, *args, **kwargs):
        acc = 0
        sample_num = 0

        self.to_eval_mode()

        with torch.no_grad():
            pbar = tqdm.tqdm(enumerate(test_loader), total=len(test_loader), dynamic_ncols=True, leave=False)
            for batch_index, (x, y) in pbar:
                for k, v in x.items():
                    if isinstance(v, torch.Tensor):
                        x[k] = v.to(self.device)
                y = y.to(self.device)
                output = self.infer(x)
                pred = F.softmax(output, dim=1).argmax(dim=1)
                correct = torch.eq(pred, y).sum().item()

                acc += correct
                sample_num += len(y)

                pbar.set_description(f'cur_batch_total: {len(y)}, cur_batch_correct: {correct}, '
                                     f'cur_batch_acc: {(correct / len(y)):.4f}')

        acc /= sample_num
        return acc

    def infer(self, x, *args, **kwargs):
        return self.models_dict['main'].forward(**x)

from accelerate.utils.operations import pad_across_processes


class ElasticDNN_OfflineTrFMModel(ElasticDNN_OfflineFMModel):
    def get_accuracy(self, test_loader, *args, **kwargs):
        # report corpus BLEU averaged over batches
        from sacrebleu import corpus_bleu
        acc = 0
        num_batches = 0

        self.to_eval_mode()

        from data.datasets.sentiment_classification.global_bert_tokenizer import get_tokenizer
        tokenizer = get_tokenizer()

        def _decode(o):
            # https://github.com/huggingface/transformers/blob/main/examples/research_projects/seq2seq-distillation/finetune.py#L133
            o = tokenizer.batch_decode(o, skip_special_tokens=True)
            return [oi.strip() for oi in o]

        with torch.no_grad():
            pbar = tqdm.tqdm(enumerate(test_loader), total=len(test_loader), dynamic_ncols=True, leave=False)
            for batch_index, (x, y) in pbar:
                for k, v in x.items():
                    if isinstance(v, torch.Tensor):
                        x[k] = v.to(self.device)
                label = y.to(self.device)

                # generated_tokens = self.infer(x, generate=True)
                generated_tokens = self.infer(x).logits.argmax(-1)

                # pad generated tokens and labels to a common length
                generated_tokens = pad_across_processes(
                    generated_tokens, dim=1, pad_index=tokenizer.pad_token_id
                )
                label = pad_across_processes(
                    label, dim=1, pad_index=tokenizer.pad_token_id
                )
                label = label.cpu().numpy()
                label = np.where(label != -100, label, tokenizer.pad_token_id)

                decoded_output = _decode(generated_tokens)
                decoded_y = _decode(label)
                decoded_y = [decoded_y]  # corpus_bleu expects a list of reference streams

                bleu = corpus_bleu(decoded_output, decoded_y).score
                pbar.set_description(f'cur_batch_bleu: {bleu:.4f}')

                acc += bleu
                num_batches += 1

        acc /= num_batches
        return acc

    def infer(self, x, *args, **kwargs):
        if 'token_type_ids' in x.keys():
            del x['token_type_ids']
        if 'generate' in kwargs:
            return self.models_dict['main'].generate(
                x['input_ids'],
                attention_mask=x["attention_mask"],
                max_length=512
            )
        return self.models_dict['main'](**x)
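
# Illustrative sketch (not used anywhere in the pipeline): why the references are
# wrapped as `[decoded_y]` above. sacrebleu.corpus_bleu takes a list of hypothesis
# strings and a list of reference *streams* (each stream holds one reference per
# hypothesis), so a single reference set must sit inside an outer list.
def _example_corpus_bleu():
    from sacrebleu import corpus_bleu
    hypotheses = ['the cat sat on the mat', 'a dog barks']
    references = [['the cat is on the mat', 'the dog barks']]  # one reference stream
    return corpus_bleu(hypotheses, references).score  # a float in [0, 100]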

from nltk.metrics import accuracy as nltk_acc


class ElasticDNN_OfflineTokenClsFMModel(ElasticDNN_OfflineFMModel):
    def get_accuracy(self, test_loader, *args, **kwargs):
        acc = 0
        sample_num = 0

        self.to_eval_mode()

        with torch.no_grad():
            pbar = tqdm.tqdm(enumerate(test_loader), total=len(test_loader), dynamic_ncols=True, leave=False)
            for batch_index, (x, y) in pbar:
                for k, v in x.items():
                    if isinstance(v, torch.Tensor):
                        x[k] = v.to(self.device)
                y = y.to(self.device)
                output = self.infer(x)

                # output: (batch_size, seq_len, num_token_classes), y: (batch_size, seq_len)
                for oi, yi, xi in zip(output, y, x['input_ids']):
                    seq_len = xi.nonzero().size(0)  # number of non-padding tokens
                    pred = F.softmax(oi, dim=-1).argmax(dim=-1)
                    # skip position 0 ([CLS]) and ignore padding beyond seq_len
                    correct = torch.eq(pred[1: seq_len], yi[1: seq_len]).sum().item()

                    acc += correct
                    sample_num += seq_len

                    pbar.set_description(f'seq_len: {seq_len}, cur_seq_acc: {(correct / seq_len):.4f}')

        acc /= sample_num
        return acc

    def infer(self, x, *args, **kwargs):
        return self.models_dict['main'](**x)
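
# Illustrative sketch (not used anywhere in the pipeline): the per-sequence token
# accuracy used above, with made-up tensors. Position 0 (typically [CLS]) is skipped
# in the comparison, and the sequence length is taken from the non-zero input ids.
def _example_token_cls_accuracy():
    input_ids = torch.tensor([101, 2023, 2003, 102, 0, 0])  # 4 non-padding tokens
    labels = torch.tensor([0, 3, 5, 0, 0, 0])
    logits = torch.zeros(6, 43)                              # (seq_len, num_token_classes)
    logits[torch.arange(6), labels] = 1.0                    # make the "model" predict the labels
    seq_len = input_ids.nonzero().size(0)                    # == 4
    pred = F.softmax(logits, dim=-1).argmax(dim=-1)
    correct = torch.eq(pred[1:seq_len], labels[1:seq_len]).sum().item()  # 3 correct
    return correct / seq_len                                 # 0.75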

class ElasticDNN_OfflineMMClsFMModel(ElasticDNN_OfflineFMModel):
    def get_accuracy(self, test_loader, *args, **kwargs):
        acc = 0
        sample_num = 0

        self.to_eval_mode()

        batch_size = 1
        with torch.no_grad():
            pbar = tqdm.tqdm(enumerate(test_loader), total=len(test_loader), dynamic_ncols=True, leave=False)
            for batch_index, (x, y) in pbar:
                if batch_index * batch_size > 2000:
                    break

                for k, v in x.items():
                    if isinstance(v, torch.Tensor):
                        x[k] = v.to(self.device)
                y = y.to(self.device)

                # deduplicate the candidate texts and re-index each sample's label accordingly
                raw_texts = x['texts'][:]
                x['texts'] = list(set(x['texts']))
                batch_size = len(y)

                x['for_training'] = False
                output = self.infer(x)
                output = output.logits_per_image

                y = torch.LongTensor([x['texts'].index(rt) for rt in raw_texts]).to(self.device)

                pred = F.softmax(output, dim=1).argmax(dim=1)
                correct = torch.eq(pred, y).sum().item()

                acc += correct
                sample_num += len(y)

                pbar.set_description(f'cur_batch_total: {len(y)}, cur_batch_correct: {correct}, '
                                     f'cur_batch_acc: {(correct / len(y)):.4f}')

        acc /= sample_num
        return acc

    def infer(self, x, *args, **kwargs):
        x['for_training'] = self.models_dict['main'].training
        return self.models_dict['main'](**x)
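
# Illustrative sketch (not used anywhere in the pipeline): the label construction
# used above for CLIP-style image-text matching. Duplicate candidate texts are
# collapsed, and each image's target becomes the index of its own text in the
# deduplicated list; the texts below are made-up placeholders.
def _example_mm_cls_labels():
    raw_texts = ['a photo of a cat', 'a photo of a dog', 'a photo of a cat']
    unique_texts = list(set(raw_texts))
    targets = torch.LongTensor([unique_texts.index(t) for t in raw_texts])
    # logits_per_image would have shape (len(raw_texts), len(unique_texts));
    # here we only return the targets for inspection.
    return unique_texts, targets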

class VQAScore:
    def __init__(self):
        self.score = torch.tensor(0.0)
        self.total = torch.tensor(0.0)

    def update(self, logits, target):
        logits, target = (
            logits.detach().float().to(self.score.device),
            target.detach().float().to(self.score.device),
        )
        logits = torch.max(logits, 1)[1]
        one_hots = torch.zeros(*target.size()).to(target)
        one_hots.scatter_(1, logits.view(-1, 1), 1)
        scores = one_hots * target

        self.score += scores.sum()
        self.total += len(logits)

    def compute(self):
        return self.score / self.total
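
# Illustrative sketch (not used anywhere in the pipeline): VQAScore with made-up
# logits and VQA-style soft targets (each answer weighted by annotator agreement).
# The predicted answer's soft weight is credited for every sample.
def _example_vqa_score():
    score = VQAScore()
    logits = torch.tensor([[2.0, 0.5, 0.1],   # predicts answer 0
                           [0.2, 0.1, 3.0]])  # predicts answer 2
    target = torch.tensor([[1.0, 0.0, 0.0],   # answer 0 has full agreement
                           [0.0, 0.6, 0.3]])  # answer 2 has partial agreement
    score.update(logits, target)
    return score.compute()  # (1.0 + 0.3) / 2 = 0.65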

class ElasticDNN_OfflineVQAFMModel(ElasticDNN_OfflineFMModel):
    def get_accuracy(self, test_loader, *args, **kwargs):
        vqa_score = VQAScore()
        self.to_eval_mode()

        with torch.no_grad():
            pbar = tqdm.tqdm(enumerate(test_loader), total=len(test_loader), dynamic_ncols=True, leave=False)
            for batch_index, (x, y) in pbar:
                for k, v in x.items():
                    if isinstance(v, torch.Tensor):
                        x[k] = v.to(self.device)
                y = y.to(self.device)
                output = self.infer(x).logits
                vqa_score.update(output, y)

                pbar.set_description(f'cur_batch_total: {len(y)}, cur_batch_acc: {vqa_score.compute():.4f}')

        return vqa_score.compute()

    def infer(self, x, *args, **kwargs):
        return self.models_dict['main'](**x)

class ElasticDNN_OfflineMDModel(BaseModel):
    def get_required_model_components(self) -> List[str]:
        return ['main']

    def forward_to_get_task_loss(self, x, y, *args, **kwargs):
        raise NotImplementedError

    def get_feature_hook(self) -> LayerActivation:
        pass

    def get_distill_loss(self, student_output, teacher_output):
        pass

    def get_matched_param_of_fm(self, self_param_name, fm: nn.Module):
        pass

class ElasticDNN_OfflineClsMDModel(ElasticDNN_OfflineMDModel):
    def get_accuracy(self, test_loader, *args, **kwargs):
        acc = 0
        sample_num = 0

        self.to_eval_mode()

        with torch.no_grad():
            pbar = tqdm.tqdm(enumerate(test_loader), total=len(test_loader), dynamic_ncols=True, leave=False)
            for batch_index, (x, y) in pbar:
                x, y = x.to(self.device), y.to(self.device)
                output = self.infer(x)
                pred = F.softmax(output, dim=1).argmax(dim=1)
                correct = torch.eq(pred, y).sum().item()

                acc += correct
                sample_num += len(y)

                pbar.set_description(f'cur_batch_total: {len(y)}, cur_batch_correct: {correct}, '
                                     f'cur_batch_acc: {(correct / len(y)):.4f}')

        acc /= sample_num
        return acc

    def infer(self, x, *args, **kwargs):
        return self.models_dict['main'](x)

class ElasticDNN_OfflineSegMDModel(ElasticDNN_OfflineMDModel):
    def __init__(self, name: str, models_dict_path: str, device: str, num_classes):
        super().__init__(name, models_dict_path, device)
        self.num_classes = num_classes

    def get_accuracy(self, test_loader, *args, **kwargs):
        device = self.device
        self.to_eval_mode()

        metrics = StreamSegMetrics(self.num_classes)
        metrics.reset()

        pbar = tqdm.tqdm(enumerate(test_loader), total=len(test_loader), leave=False, dynamic_ncols=True)
        with torch.no_grad():
            for batch_index, (x, y) in pbar:
                x, y = x.to(device, dtype=x.dtype, non_blocking=True, copy=False), \
                       y.to(device, dtype=y.dtype, non_blocking=True, copy=False)
                output = self.infer(x)
                pred = output.detach().max(dim=1)[1].cpu().numpy()
                metrics.update((y + 0).cpu().numpy(), pred)

                res = metrics.get_results()
                pbar.set_description(f'cur batch mIoU: {res["Mean IoU"]:.4f}')

        res = metrics.get_results()
        return res['Mean IoU']

    def infer(self, x, *args, **kwargs):
        return self.models_dict['main'](x)

class ElasticDNN_OfflineDetMDModel(ElasticDNN_OfflineMDModel):
    def __init__(self, name: str, models_dict_path: str, device: str, num_classes):
        super().__init__(name, models_dict_path, device)
        self.num_classes = num_classes

    def get_accuracy(self, test_loader, *args, **kwargs):
        _d = test_loader.dataset
        from data import build_dataloader
        if _d.__class__.__name__ == 'MergedDataset':
            # evaluate on each sub-dataset of a merged dataset and average the results
            datasets = _d.datasets
            test_loaders = [build_dataloader(d, test_loader.batch_size, test_loader.num_workers, False, None) for d in datasets]
            accs = [self.get_accuracy(loader) for loader in test_loaders]
            return sum(accs) / len(accs)

        model = self.models_dict['main']
        device = self.device
        model.eval()
        model = model.to(device)

        from dnns.yolov3.coco_evaluator import COCOEvaluator
        from utils.common.others import HiddenPrints
        with torch.no_grad():
            with HiddenPrints():
                evaluator = COCOEvaluator(
                    dataloader=test_loader,
                    img_size=(224, 224),
                    confthre=0.01,
                    nmsthre=0.65,
                    num_classes=self.num_classes,
                    testdev=False
                )
                res = evaluator.evaluate(model, False, False)
                map50 = res[1]
        return map50

    def infer(self, x, *args, **kwargs):
        if len(args) > 0:
            return self.models_dict['main'](x, *args)  # forward(x, label)
        return self.models_dict['main'](x)

class ElasticDNN_OfflineSenClsMDModel(ElasticDNN_OfflineMDModel):
    def get_accuracy(self, test_loader, *args, **kwargs):
        acc = 0
        sample_num = 0

        self.to_eval_mode()

        with torch.no_grad():
            pbar = tqdm.tqdm(enumerate(test_loader), total=len(test_loader), dynamic_ncols=True, leave=False)
            for batch_index, (x, y) in pbar:
                for k, v in x.items():
                    if isinstance(v, torch.Tensor):
                        x[k] = v.to(self.device)
                y = y.to(self.device)
                output = self.infer(x)
                pred = F.softmax(output, dim=1).argmax(dim=1)
                correct = torch.eq(pred, y).sum().item()

                acc += correct
                sample_num += len(y)

                pbar.set_description(f'cur_batch_total: {len(y)}, cur_batch_correct: {correct}, '
                                     f'cur_batch_acc: {(correct / len(y)):.4f}')

        acc /= sample_num
        return acc

    def infer(self, x, *args, **kwargs):
        return self.models_dict['main'](**x)

class ElasticDNN_OfflineTrMDModel(ElasticDNN_OfflineMDModel):
    def get_accuracy(self, test_loader, *args, **kwargs):
        # report corpus BLEU averaged over batches
        from sacrebleu import corpus_bleu
        acc = 0
        num_batches = 0

        self.to_eval_mode()

        from data.datasets.sentiment_classification.global_bert_tokenizer import get_tokenizer
        tokenizer = get_tokenizer()

        def _decode(o):
            # https://github.com/huggingface/transformers/blob/main/examples/research_projects/seq2seq-distillation/finetune.py#L133
            o = tokenizer.batch_decode(o, skip_special_tokens=True, clean_up_tokenization_spaces=True)
            return [oi.strip().replace(' ', '') for oi in o]

        with torch.no_grad():
            pbar = tqdm.tqdm(enumerate(test_loader), total=len(test_loader), dynamic_ncols=True, leave=False)
            for batch_index, (x, y) in pbar:
                for k, v in x.items():
                    if isinstance(v, torch.Tensor):
                        x[k] = v.to(self.device)
                y = y.to(self.device)
                output = self.infer(x)

                decoded_output = _decode(output.argmax(-1))
                decoded_y = _decode(y)
                decoded_y = [decoded_y]  # corpus_bleu expects a list of reference streams

                bleu = corpus_bleu(decoded_output, decoded_y).score
                pbar.set_description(f'cur_batch_bleu: {bleu:.4f}')

                acc += bleu
                num_batches += 1

        acc /= num_batches
        return acc

    def infer(self, x, *args, **kwargs):
        return self.models_dict['main'](**x)

class ElasticDNN_OfflineTokenClsMDModel(ElasticDNN_OfflineMDModel):
    def get_accuracy(self, test_loader, *args, **kwargs):
        acc = 0
        sample_num = 0

        self.to_eval_mode()

        with torch.no_grad():
            pbar = tqdm.tqdm(enumerate(test_loader), total=len(test_loader), dynamic_ncols=True, leave=False)
            for batch_index, (x, y) in pbar:
                for k, v in x.items():
                    if isinstance(v, torch.Tensor):
                        x[k] = v.to(self.device)
                y = y.to(self.device)
                output = self.infer(x)

                # output: (batch_size, seq_len, num_token_classes), y: (batch_size, seq_len)
                for oi, yi, xi in zip(output, y, x['input_ids']):
                    seq_len = xi.nonzero().size(0)  # number of non-padding tokens
                    pred = F.softmax(oi, dim=-1).argmax(dim=-1)
                    # skip position 0 ([CLS]) and ignore padding beyond seq_len
                    correct = torch.eq(pred[1: seq_len], yi[1: seq_len]).sum().item()

                    acc += correct
                    sample_num += seq_len

                    pbar.set_description(f'seq_len: {seq_len}, cur_seq_acc: {(correct / seq_len):.4f}')

        acc /= sample_num
        return acc

    def infer(self, x, *args, **kwargs):
        return self.models_dict['main'](**x)

class ElasticDNN_OfflineMMClsMDModel(ElasticDNN_OfflineMDModel):
    def get_accuracy(self, test_loader, *args, **kwargs):
        acc = 0
        sample_num = 0

        self.to_eval_mode()

        batch_size = 1
        with torch.no_grad():
            pbar = tqdm.tqdm(enumerate(test_loader), total=len(test_loader), dynamic_ncols=True, leave=False)
            for batch_index, (x, y) in pbar:
                if batch_index * batch_size > 2000:
                    break

                for k, v in x.items():
                    if isinstance(v, torch.Tensor):
                        x[k] = v.to(self.device)
                y = y.to(self.device)

                # deduplicate the candidate texts and re-index each sample's label accordingly
                raw_texts = x['texts'][:]
                x['texts'] = list(set(x['texts']))
                batch_size = len(y)

                x['for_training'] = False
                output = self.infer(x)
                output = output.logits_per_image

                y = torch.LongTensor([x['texts'].index(rt) for rt in raw_texts]).to(self.device)

                pred = F.softmax(output, dim=1).argmax(dim=1)
                correct = torch.eq(pred, y).sum().item()

                acc += correct
                sample_num += len(y)

                pbar.set_description(f'cur_batch_total: {len(y)}, cur_batch_correct: {correct}, '
                                     f'cur_batch_acc: {(correct / len(y)):.4f}')

        acc /= sample_num
        return acc

    def infer(self, x, *args, **kwargs):
        return self.models_dict['main'](**x)

class ElasticDNN_OfflineVQAMDModel(ElasticDNN_OfflineMDModel):
    def get_accuracy(self, test_loader, *args, **kwargs):
        vqa_score = VQAScore()
        self.to_eval_mode()

        with torch.no_grad():
            pbar = tqdm.tqdm(enumerate(test_loader), total=len(test_loader), dynamic_ncols=True, leave=False)
            for batch_index, (x, y) in pbar:
                for k, v in x.items():
                    if isinstance(v, torch.Tensor):
                        x[k] = v.to(self.device)
                y = y.to(self.device)
                output = self.infer(x).logits
                vqa_score.update(output, y)

                pbar.set_description(f'cur_batch_total: {len(y)}, cur_batch_acc: {vqa_score.compute():.4f}')

        return vqa_score.compute()

    def infer(self, x, *args, **kwargs):
        return self.models_dict['main'](**x)
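
# Illustrative sketch (not used anywhere in the pipeline): a rough picture of how
# these offline wrappers are typically driven. The checkpoint paths and dataloader
# are placeholders, and a concrete subclass still has to implement the remaining
# hooks (task loss, feature hook, LoRA/elastic utils, ...) before training.
def _example_offline_eval(test_loader):
    fm_model = ElasticDNN_OfflineClsFMModel(
        name='fm', models_dict_path='/path/to/fm_ckpt.pt', device='cuda'
    )
    md_model = ElasticDNN_OfflineClsMDModel(
        name='md', models_dict_path='/path/to/md_ckpt.pt', device='cuda'
    )
    # both wrappers expose the same evaluation entry point
    return fm_model.get_accuracy(test_loader), md_model.get_accuracy(test_loader)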