# Copyright 2020 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import itertools
import json
import linecache
import math
import os
import pickle
import socket
from logging import getLogger
from pathlib import Path
from typing import Callable, Dict, Iterable, List, Tuple, Union

import git
import numpy as np
import torch
import torch.distributed as dist
from rouge_score import rouge_scorer, scoring
from sacrebleu import corpus_bleu
from sentence_splitter import add_newline_to_end_of_each_sentence
from torch import nn
from torch.utils.data import Dataset, Sampler

from transformers import BartTokenizer, EvalPrediction, PreTrainedTokenizer, T5Tokenizer
from transformers.models.bart.modeling_bart import shift_tokens_right
from transformers.utils import cached_property


try:
    from fairseq.data.data_utils import batch_by_size

    FAIRSEQ_AVAILABLE = True
except (ImportError, ModuleNotFoundError):
    FAIRSEQ_AVAILABLE = False

def label_smoothed_nll_loss(lprobs, target, epsilon, ignore_index=-100):
    """From fairseq"""
    if target.dim() == lprobs.dim() - 1:
        target = target.unsqueeze(-1)
    nll_loss = -lprobs.gather(dim=-1, index=target)
    smooth_loss = -lprobs.sum(dim=-1, keepdim=True)
    if ignore_index is not None:
        pad_mask = target.eq(ignore_index)
        nll_loss.masked_fill_(pad_mask, 0.0)
        smooth_loss.masked_fill_(pad_mask, 0.0)
    else:
        nll_loss = nll_loss.squeeze(-1)
        smooth_loss = smooth_loss.squeeze(-1)

    nll_loss = nll_loss.sum()  # mean()? Scared to break other math.
    smooth_loss = smooth_loss.sum()
    eps_i = epsilon / lprobs.size(-1)
    loss = (1.0 - epsilon) * nll_loss + eps_i * smooth_loss
    return loss, nll_loss
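
# Illustrative sketch (not executed here): `lprobs` are log-probabilities over the
# vocabulary, typically `model(...).logits.log_softmax(dim=-1)`; the shapes below are
# hypothetical and chosen only to show the expected call pattern.
#
#   lprobs = torch.randn(4, 7, 32).log_softmax(dim=-1)   # (batch, seq_len, vocab)
#   target = torch.randint(0, 32, (4, 7))                # gold token ids
#   loss, nll_loss = label_smoothed_nll_loss(lprobs, target, epsilon=0.1, ignore_index=-100)
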
def lmap(f: Callable, x: Iterable) -> List:
    """list(map(f, x))"""
    return list(map(f, x))


def calculate_bleu(output_lns, refs_lns, **kwargs) -> dict:
    """Uses sacrebleu's corpus_bleu implementation."""
    return {"bleu": round(corpus_bleu(output_lns, [refs_lns], **kwargs).score, 4)}
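
# Minimal usage sketch (toy strings, not project data):
#
#   calculate_bleu(["the cat sat on the mat"], ["the cat sat on the mat"])
#   # -> {"bleu": 100.0}   identical hypothesis and reference give a perfect score
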
def build_compute_metrics_fn(task_name: str, tokenizer: PreTrainedTokenizer) -> Callable[[EvalPrediction], Dict]:
    def non_pad_len(tokens: np.ndarray) -> int:
        return np.count_nonzero(tokens != tokenizer.pad_token_id)

    def decode_pred(pred: EvalPrediction) -> Tuple[List[str], List[str]]:
        pred_ids = pred.predictions
        label_ids = pred.label_ids
        pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
        label_ids[label_ids == -100] = tokenizer.pad_token_id
        label_str = tokenizer.batch_decode(label_ids, skip_special_tokens=True)
        pred_str = lmap(str.strip, pred_str)
        label_str = lmap(str.strip, label_str)
        return pred_str, label_str

    def summarization_metrics(pred: EvalPrediction) -> Dict:
        pred_str, label_str = decode_pred(pred)
        rouge: Dict = calculate_rouge(pred_str, label_str)
        summ_len = np.round(np.mean(lmap(non_pad_len, pred.predictions)), 1)
        rouge.update({"gen_len": summ_len})
        return rouge

    def translation_metrics(pred: EvalPrediction) -> Dict:
        pred_str, label_str = decode_pred(pred)
        bleu: Dict = calculate_bleu(pred_str, label_str)
        gen_len = np.round(np.mean(lmap(non_pad_len, pred.predictions)), 1)
        bleu.update({"gen_len": gen_len})
        return bleu

    compute_metrics_fn = summarization_metrics if "summarization" in task_name else translation_metrics
    return compute_metrics_fn
def trim_batch(
    input_ids,
    pad_token_id,
    attention_mask=None,
):
    """Remove columns that are populated exclusively by pad_token_id"""
    keep_column_mask = input_ids.ne(pad_token_id).any(dim=0)
    if attention_mask is None:
        return input_ids[:, keep_column_mask]
    else:
        return (input_ids[:, keep_column_mask], attention_mask[:, keep_column_mask])
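
# Illustrative sketch, assuming pad_token_id == 0 (the real value comes from the tokenizer):
#
#   ids = torch.tensor([[5, 6, 0, 0],
#                       [7, 0, 0, 0]])
#   trim_batch(ids, pad_token_id=0)
#   # -> tensor([[5, 6],
#   #            [7, 0]])   the two all-pad columns on the right are dropped
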
class AbstractSeq2SeqDataset(Dataset):
    def __init__(
        self,
        tokenizer,
        data_dir,
        max_source_length,
        max_target_length,
        type_path="train",
        n_obs=None,
        prefix="",
        **dataset_kwargs,
    ):
        super().__init__()
        self.src_file = Path(data_dir).joinpath(type_path + ".source")
        self.tgt_file = Path(data_dir).joinpath(type_path + ".target")
        self.len_file = Path(data_dir).joinpath(type_path + ".len")
        if os.path.exists(self.len_file):
            self.src_lens = pickle_load(self.len_file)
            self.used_char_len = False
        else:
            self.src_lens = self.get_char_lens(self.src_file)
            self.used_char_len = True
        self.max_source_length = max_source_length
        self.max_target_length = max_target_length
        assert min(self.src_lens) > 0, f"found empty line in {self.src_file}"
        self.tokenizer = tokenizer
        self.prefix = prefix if prefix is not None else ""

        if n_obs is not None:
            self.src_lens = self.src_lens[:n_obs]
        self.pad_token_id = self.tokenizer.pad_token_id
        self.dataset_kwargs = dataset_kwargs
        dataset_kwargs.update({"add_prefix_space": True} if isinstance(self.tokenizer, BartTokenizer) else {})

    def __len__(self):
        return len(self.src_lens)
    @staticmethod
    def get_char_lens(data_file):
        return [len(x) for x in Path(data_file).open().readlines()]

    @cached_property
    def tgt_lens(self):
        """Length in characters of target documents"""
        return self.get_char_lens(self.tgt_file)
    def make_sortish_sampler(self, batch_size, distributed=False, shuffle=True, **kwargs):
        if distributed:
            return DistributedSortishSampler(self, batch_size, shuffle=shuffle, **kwargs)
        else:
            return SortishSampler(self.src_lens, batch_size, shuffle=shuffle)

    def make_dynamic_sampler(self, max_tokens_per_batch=1024, **kwargs):
        assert FAIRSEQ_AVAILABLE, "Dynamic batch size requires `pip install fairseq`"
        assert not self.used_char_len, "You must call python make_len_file.py before calling make_dynamic_sampler"
        sorted_indices = list(self.make_sortish_sampler(1024, shuffle=False))

        def num_tokens_in_example(i):
            return min(self.src_lens[i], self.max_target_length)

        # call fairseq cython function
        batch_sampler: List[List[int]] = batch_by_size(
            sorted_indices,
            num_tokens_fn=num_tokens_in_example,
            max_tokens=max_tokens_per_batch,
            required_batch_size_multiple=64,
        )
        shuffled_batches = [batch_sampler[i] for i in np.random.permutation(range(len(batch_sampler)))]
        # move the largest batch to the front to OOM quickly (uses an approximation for padding)
        approximate_toks_per_batch = [max(self.src_lens[i] for i in batch) * len(batch) for batch in shuffled_batches]
        largest_batch_idx = np.argmax(approximate_toks_per_batch)
        shuffled_batches[0], shuffled_batches[largest_batch_idx] = (
            shuffled_batches[largest_batch_idx],
            shuffled_batches[0],
        )
        return shuffled_batches
    def __getitem__(self, item):
        raise NotImplementedError("You must implement this")

    def collate_fn(self, batch):
        raise NotImplementedError("You must implement this")

class LegacySeq2SeqDataset(AbstractSeq2SeqDataset):
    def __getitem__(self, index) -> Dict[str, torch.Tensor]:
        """Call tokenizer on src and tgt_lines"""
        index = index + 1  # linecache starts at 1
        source_line = self.prefix + linecache.getline(str(self.src_file), index).rstrip("\n")
        tgt_line = linecache.getline(str(self.tgt_file), index).rstrip("\n")
        assert source_line, f"empty source line for index {index}"
        assert tgt_line, f"empty tgt line for index {index}"
        source_inputs = self.encode_line(self.tokenizer, source_line, self.max_source_length)
        target_inputs = self.encode_line(self.tokenizer, tgt_line, self.max_target_length)

        source_ids = source_inputs["input_ids"].squeeze()
        target_ids = target_inputs["input_ids"].squeeze()
        src_mask = source_inputs["attention_mask"].squeeze()
        return {
            "input_ids": source_ids,
            "attention_mask": src_mask,
            "labels": target_ids,
        }

    def encode_line(self, tokenizer, line, max_length, pad_to_max_length=True, return_tensors="pt"):
        """Only used by LegacyDataset"""
        return tokenizer(
            [line],
            max_length=max_length,
            padding="max_length" if pad_to_max_length else None,
            truncation=True,
            return_tensors=return_tensors,
            **self.dataset_kwargs,
        )

    def collate_fn(self, batch) -> Dict[str, torch.Tensor]:
        input_ids = torch.stack([x["input_ids"] for x in batch])
        masks = torch.stack([x["attention_mask"] for x in batch])
        target_ids = torch.stack([x["labels"] for x in batch])
        pad_token_id = self.pad_token_id
        y = trim_batch(target_ids, pad_token_id)
        source_ids, source_mask = trim_batch(input_ids, pad_token_id, attention_mask=masks)
        batch = {
            "input_ids": source_ids,
            "attention_mask": source_mask,
            "labels": y,
        }
        return batch

class Seq2SeqDataset(AbstractSeq2SeqDataset):
    """A dataset that calls prepare_seq2seq_batch."""

    def __getitem__(self, index) -> Dict[str, str]:
        index = index + 1  # linecache starts at 1
        source_line = self.prefix + linecache.getline(str(self.src_file), index).rstrip("\n")
        tgt_line = linecache.getline(str(self.tgt_file), index).rstrip("\n")
        assert source_line, f"empty source line for index {index}"
        assert tgt_line, f"empty tgt line for index {index}"
        return {"tgt_texts": tgt_line, "src_texts": source_line, "id": index - 1}

    def collate_fn(self, batch) -> Dict[str, torch.Tensor]:
        """Call prepare_seq2seq_batch."""
        batch_encoding: Dict[str, torch.Tensor] = self.tokenizer.prepare_seq2seq_batch(
            [x["src_texts"] for x in batch],
            tgt_texts=[x["tgt_texts"] for x in batch],
            max_length=self.max_source_length,
            max_target_length=self.max_target_length,
            return_tensors="pt",
            **self.dataset_kwargs,
        ).data
        batch_encoding["ids"] = torch.tensor([x["id"] for x in batch])
        return batch_encoding

class Seq2SeqDataCollator:
    def __init__(self, tokenizer, data_args, decoder_start_token_id, tpu_num_cores=None):
        self.tokenizer = tokenizer
        self.pad_token_id = tokenizer.pad_token_id
        self.decoder_start_token_id = decoder_start_token_id
        assert (
            self.pad_token_id is not None
        ), f"pad_token_id is not defined for ({self.tokenizer.__class__.__name__}), it must be defined."
        self.data_args = data_args
        self.tpu_num_cores = tpu_num_cores
        self.dataset_kwargs = {"add_prefix_space": True} if isinstance(tokenizer, BartTokenizer) else {}
        if data_args.src_lang is not None:
            self.dataset_kwargs["src_lang"] = data_args.src_lang
        if data_args.tgt_lang is not None:
            self.dataset_kwargs["tgt_lang"] = data_args.tgt_lang

    def __call__(self, batch) -> Dict[str, torch.Tensor]:
        if hasattr(self.tokenizer, "prepare_seq2seq_batch"):
            batch = self._encode(batch)
            input_ids, attention_mask, labels = (
                batch["input_ids"],
                batch["attention_mask"],
                batch["labels"],
            )
        else:
            input_ids = torch.stack([x["input_ids"] for x in batch])
            attention_mask = torch.stack([x["attention_mask"] for x in batch])
            labels = torch.stack([x["labels"] for x in batch])

            labels = trim_batch(labels, self.pad_token_id)
            input_ids, attention_mask = trim_batch(input_ids, self.pad_token_id, attention_mask=attention_mask)

        if isinstance(self.tokenizer, T5Tokenizer):
            decoder_input_ids = self._shift_right_t5(labels)
        else:
            decoder_input_ids = shift_tokens_right(labels, self.pad_token_id, self.decoder_start_token_id)

        batch = {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "decoder_input_ids": decoder_input_ids,
            "labels": labels,
        }
        return batch

    def _shift_right_t5(self, input_ids):
        # shift inputs to the right
        shifted_input_ids = input_ids.new_zeros(input_ids.shape)
        shifted_input_ids[..., 1:] = input_ids[..., :-1].clone()
        shifted_input_ids[..., 0] = self.pad_token_id
        return shifted_input_ids
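
    # Illustrative sketch of the right-shift (toy ids, assuming pad_token_id == 0):
    #
    #   labels             = [[245,  17, 893,   1]]
    #   _shift_right_t5 -> = [[  0, 245,  17, 893]]
    #
    # i.e. the decoder sees the pad/start token first and each label one step later.
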
    def _encode(self, batch) -> Dict[str, torch.Tensor]:
        batch_encoding = self.tokenizer.prepare_seq2seq_batch(
            [x["src_texts"] for x in batch],
            tgt_texts=[x["tgt_texts"] for x in batch],
            max_length=self.data_args.max_source_length,
            max_target_length=self.data_args.max_target_length,
            padding="max_length" if self.tpu_num_cores is not None else "longest",  # TPU hack
            return_tensors="pt",
            **self.dataset_kwargs,
        )
        return batch_encoding.data

class SortishSampler(Sampler):
    "Go through the text data by order of src length with a bit of randomness. From fastai repo."

    def __init__(self, data, batch_size, shuffle=True):
        self.data, self.bs, self.shuffle = data, batch_size, shuffle

    def __len__(self) -> int:
        return len(self.data)

    def __iter__(self):
        return iter(sortish_sampler_indices(self.data, self.bs, shuffle=self.shuffle))

def sortish_sampler_indices(data: List, bs: int, shuffle=True) -> np.array:
    "Go through the text data by order of src length with a bit of randomness. From fastai repo."
    if not shuffle:
        return np.argsort(np.array(data) * -1)

    def key_fn(i):
        return data[i]

    idxs = np.random.permutation(len(data))
    sz = bs * 50
    ck_idx = [idxs[i : i + sz] for i in range(0, len(idxs), sz)]
    sort_idx = np.concatenate([sorted(s, key=key_fn, reverse=True) for s in ck_idx])
    sz = bs
    ck_idx = [sort_idx[i : i + sz] for i in range(0, len(sort_idx), sz)]
    max_ck = np.argmax([key_fn(ck[0]) for ck in ck_idx])  # find the chunk with the largest key,
    ck_idx[0], ck_idx[max_ck] = ck_idx[max_ck], ck_idx[0]  # then make sure it goes first.
    sort_idx = np.concatenate(np.random.permutation(ck_idx[1:])) if len(ck_idx) > 1 else np.array([], dtype=int)
    sort_idx = np.concatenate((ck_idx[0], sort_idx))
    return sort_idx
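
# Illustrative sketch (toy character lengths, bs chosen arbitrarily): indices come back
# roughly longest-first, so examples of similar length end up in the same batch and
# per-batch padding stays small.
#
#   lens = [3, 50, 7, 42, 12, 9]
#   sortish_sampler_indices(lens, bs=2, shuffle=False)
#   # -> array([1, 3, 4, 5, 2, 0])   deterministic fully-sorted order when shuffle=False
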
class DistributedSortishSampler(Sampler):
    """Copied from torch DistributedSampler"""

    def __init__(self, dataset, batch_size, num_replicas=None, rank=None, add_extra_examples=True, shuffle=True):
        if num_replicas is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            num_replicas = dist.get_world_size()
        if rank is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            rank = dist.get_rank()
        self.dataset = dataset
        self.num_replicas = num_replicas
        self.rank = rank
        self.epoch = 0
        if add_extra_examples:
            self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas))
            self.total_size = self.num_samples * self.num_replicas
        else:
            self.total_size = len(dataset)
            self.num_samples = len(self.available_indices)
        self.batch_size = batch_size
        self.add_extra_examples = add_extra_examples
        self.shuffle = shuffle

    def __iter__(self) -> Iterable:
        g = torch.Generator()
        g.manual_seed(self.epoch)
        sortish_data = [self.dataset.src_lens[i] for i in self.available_indices]
        sortish_indices = sortish_sampler_indices(sortish_data, self.batch_size, shuffle=self.shuffle)
        indices = [self.available_indices[i] for i in sortish_indices]
        assert len(indices) == self.num_samples
        return iter(indices)
    @cached_property
    def available_indices(self) -> np.array:
        indices = list(range(len(self.dataset)))
        # add extra samples to make it evenly divisible
        indices += indices[: (self.total_size - len(indices))]
        assert len(indices) == self.total_size
        # subsample
        available_indices = indices[self.rank : self.total_size : self.num_replicas]
        return available_indices
    def __len__(self):
        return self.num_samples

    def set_epoch(self, epoch):
        self.epoch = epoch

logger = getLogger(__name__)


def use_task_specific_params(model, task):
    """Update config with summarization specific params."""
    task_specific_params = model.config.task_specific_params

    if task_specific_params is not None:
        pars = task_specific_params.get(task, {})
        logger.info(f"setting model.config to task specific params for {task}:\n {pars}")
        logger.info("note: command line args may override some of these")
        model.config.update(pars)


def pickle_load(path):
    """pickle.load(path)"""
    with open(path, "rb") as f:
        return pickle.load(f)


def pickle_save(obj, path):
    """pickle.dump(obj, path)"""
    with open(path, "wb") as f:
        return pickle.dump(obj, f)
def flatten_list(summary_ids: List[List]):
    return list(itertools.chain.from_iterable(summary_ids))


def save_git_info(folder_path: str) -> None:
    """Save git information to output_dir/git_log.json"""
    repo_infos = get_git_info()
    save_json(repo_infos, os.path.join(folder_path, "git_log.json"))


def save_json(content, path, indent=4, **json_dump_kwargs):
    with open(path, "w") as f:
        json.dump(content, f, indent=indent, sort_keys=True, **json_dump_kwargs)


def load_json(path):
    with open(path) as f:
        return json.load(f)


def get_git_info():
    try:
        repo = git.Repo(search_parent_directories=True)
        repo_infos = {
            "repo_id": str(repo),
            "repo_sha": str(repo.head.object.hexsha),
            "repo_branch": str(repo.active_branch),
            "hostname": str(socket.gethostname()),
        }
        return repo_infos
    except TypeError:
        return {
            "repo_id": None,
            "repo_sha": None,
            "repo_branch": None,
            "hostname": None,
        }
ROUGE_KEYS = ["rouge1", "rouge2", "rougeL", "rougeLsum"]


def extract_rouge_mid_statistics(dct):
    new_dict = {}
    for k1, v1 in dct.items():
        mid = v1.mid
        new_dict[k1] = {stat: round(getattr(mid, stat), 4) for stat in ["precision", "recall", "fmeasure"]}
    return new_dict
def calculate_rouge(
    pred_lns: List[str],
    tgt_lns: List[str],
    use_stemmer=True,
    rouge_keys=ROUGE_KEYS,
    return_precision_and_recall=False,
    bootstrap_aggregation=True,
    newline_sep=True,
) -> Dict:
    """Calculate rouge using rouge_scorer package.

    Args:
        pred_lns: list of summaries generated by model
        tgt_lns: list of groundtruth summaries (e.g. contents of val.target)
        use_stemmer: Bool indicating whether Porter stemmer should be used to
            strip word suffixes to improve matching.
        rouge_keys: which metrics to compute, defaults to rouge1, rouge2, rougeL, rougeLsum
        return_precision_and_recall: (False) whether to also return precision and recall.
        bootstrap_aggregation: whether to do the typical bootstrap resampling of scores. Defaults to True. If False,
            this function returns a collections.defaultdict[metric: list of values for each observation for each subscore]
        newline_sep: (default=True) whether to add a newline between sentences. This is essential for calculating
            rougeL on multi-sentence summaries (CNN/DM dataset).

    Returns:
        Dict[score: value] if bootstrap_aggregation else defaultdict(list) keyed by rouge_keys
    """
    scorer = rouge_scorer.RougeScorer(rouge_keys, use_stemmer=use_stemmer)
    aggregator = scoring.BootstrapAggregator()
    for pred, tgt in zip(tgt_lns, pred_lns):
        # rougeLsum expects "\n" separated sentences within a summary
        if newline_sep:
            pred = add_newline_to_end_of_each_sentence(pred)
            tgt = add_newline_to_end_of_each_sentence(tgt)
        scores = scorer.score(pred, tgt)
        aggregator.add_scores(scores)

    if bootstrap_aggregation:
        result = aggregator.aggregate()
        if return_precision_and_recall:
            return extract_rouge_mid_statistics(result)  # here we return dict
        else:
            return {k: round(v.mid.fmeasure * 100, 4) for k, v in result.items()}

    else:
        return aggregator._scores  # here we return defaultdict(list)
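
# Minimal usage sketch (toy strings; scores are f-measures scaled and rounded as above):
#
#   calculate_rouge(["the cat sat"], ["the cat sat"])
#   # -> {"rouge1": 100.0, "rouge2": 100.0, "rougeL": 100.0, "rougeLsum": 100.0}
#
#   calculate_rouge(["the cat sat"], ["a dog ran"])
#   # -> all scores 0.0, since no unigrams or bigrams overlap
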
# Utilities for freezing parameters and checking whether they are frozen


def freeze_params(model: nn.Module):
    """Set requires_grad=False for each of model.parameters()"""
    for par in model.parameters():
        par.requires_grad = False


def freeze_embeds(model):
    """Freeze token embeddings and positional embeddings for bart, just token embeddings for t5."""
    model_type = model.config.model_type

    if model_type in ["t5", "mt5"]:
        freeze_params(model.shared)
        for d in [model.encoder, model.decoder]:
            freeze_params(d.embed_tokens)
    elif model_type == "fsmt":
        for d in [model.model.encoder, model.model.decoder]:
            freeze_params(d.embed_positions)
            freeze_params(d.embed_tokens)
    else:
        freeze_params(model.model.shared)
        for d in [model.model.encoder, model.model.decoder]:
            freeze_params(d.embed_positions)
            freeze_params(d.embed_tokens)
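
# Typical call pattern in a fine-tuning script (the checkpoint name is only an example):
#
#   model = AutoModelForSeq2SeqLM.from_pretrained("facebook/bart-base")
#   freeze_embeds(model)          # embeddings stop receiving gradient updates
#   assert_not_all_frozen(model)  # the rest of the network still trains
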
def grad_status(model: nn.Module) -> Iterable:
    return (par.requires_grad for par in model.parameters())


def any_requires_grad(model: nn.Module) -> bool:
    return any(grad_status(model))


def assert_all_frozen(model):
    model_grads: List[bool] = list(grad_status(model))
    n_require_grad = sum(lmap(int, model_grads))
    npars = len(model_grads)
    assert not any(model_grads), f"{n_require_grad/npars:.1%} of {npars} weights require grad"


def assert_not_all_frozen(model):
    model_grads: List[bool] = list(grad_status(model))
    npars = len(model_grads)
    assert any(model_grads), f"none of {npars} weights require grad"
def parse_numeric_n_bool_cl_kwargs(unparsed_args: List[str]) -> Dict[str, Union[int, float, bool]]:
    """
    Parse an argv list of unspecified command line args to a dict.
    Assumes all values are either numeric or boolean in the form of true/false.
    """
    result = {}
    assert len(unparsed_args) % 2 == 0, f"got odd number of unparsed args: {unparsed_args}"
    num_pairs = len(unparsed_args) // 2
    for pair_num in range(num_pairs):
        i = 2 * pair_num
        assert unparsed_args[i].startswith("--")
        if unparsed_args[i + 1].lower() == "true":
            value = True
        elif unparsed_args[i + 1].lower() == "false":
            value = False
        else:
            try:
                value = int(unparsed_args[i + 1])
            except ValueError:
                value = float(unparsed_args[i + 1])  # this can raise another informative ValueError

        result[unparsed_args[i][2:]] = value
    return result
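
# Illustrative sketch (flag names invented for the example, not tied to any script):
#
#   parse_numeric_n_bool_cl_kwargs(["--num_beams", "4", "--length_penalty", "0.8", "--early_stopping", "true"])
#   # -> {"num_beams": 4, "length_penalty": 0.8, "early_stopping": True}
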
def write_txt_file(ordered_tgt, path):
    f = Path(path).open("w")
    for ln in ordered_tgt:
        f.write(ln + "\n")
        f.flush()


def chunks(lst, n):
    """Yield successive n-sized chunks from lst."""
    for i in range(0, len(lst), n):
        yield lst[i : i + n]
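
# Minimal usage sketch:
#
#   list(chunks([1, 2, 3, 4, 5], 2))
#   # -> [[1, 2], [3, 4], [5]]
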
def check_output_dir(args, expected_items=0):
    """
    Checks whether to bail out if output_dir already exists and has more than expected_items in it

    `args`: needs to have the following attributes of `args`:
      - output_dir
      - do_train
      - overwrite_output_dir

    `expected_items`: normally 0 (default) - i.e. empty dir, but in some cases a few files are expected (e.g. recovery from OOM)
    """
    if (
        os.path.exists(args.output_dir)
        and len(os.listdir(args.output_dir)) > expected_items
        and args.do_train
        and not args.overwrite_output_dir
    ):
        raise ValueError(
            f"Output directory ({args.output_dir}) already exists and "
            f"has {len(os.listdir(args.output_dir))} items in it (expected {expected_items} items). "
            "Use --overwrite_output_dir to overcome."
        )