Spaces:
Sleeping
Sleeping
| import torch | |
| from .prior import Batch | |
| from ..utils import default_device | |
# Process-wide cache of models loaded by ``get_model``, keyed by the "group:index" string.
loaded_models = {}


def get_model(model_name, device):
    """Return the model stored in a submitit job result, loading it at most once per name.

    ``model_name`` must look like ``"group:index"``. On a cache miss the submitit executor is asked for
    job ``index`` of group ``group``; element ``[2]`` of the first entry of that job's ``results()`` is
    taken as the model, moved to ``device`` and stored in ``loaded_models``. On a cache hit the cached
    object is returned as-is (``device`` is not consulted).
    """
    if model_name in loaded_models:
        return loaded_models[model_name]

    import submitit  # imported lazily: only needed when a model actually has to be fetched

    group_name, job_index = model_name.split(':')
    executor = submitit.get_executor()
    job = executor.get_group(group_name)[int(job_index)]
    fetched_model = job.results()[0][2]
    fetched_model.to(device)
    loaded_models[model_name] = fetched_model
    return fetched_model
def get_batch(batch_size, seq_len, num_features, get_batch, model, single_eval_pos, epoch, device=default_device,
              hyperparameters=None, **kwargs):
    """
    Wrap a prior's ``get_batch`` so that the targets at the evaluation positions are fantasized by ``model``.

    Every sequence of the batch is assigned a level (returned as ``style``). Sequences of level 0 keep the targets
    produced by the wrapped ``get_batch``. For a sequence of level ``l > 0``, each evaluation position
    ``single_eval_pos <= i < seq_len`` is appended, on its own, to the first ``single_eval_pos`` training points.
    ``model`` is then queried with style ``l - 1``, and its answer replaces the target at position ``i``:

    * By default the wrapped prior is asked for ``eval_seq_len`` extra points past ``seq_len``. These are used as
      query points, and the maximum of the predicted means over them becomes the new target.
    * If ``share_predict_mean_distribution`` is set, or ``model.decoder_dict_once`` contains ``'mean_prediction'``,
      no extra points are sampled. The ``1 - 1/eval_seq_len`` quantile (``criterion.icdf``) of the model's
      ``'mean_prediction'`` output is then used when the model returns one.

    The target computation runs in eval mode without gradient tracking. The model's previous train/eval mode is
    restored afterwards.

    Important assumptions:
    'inf_batch_size', 'max_level', 'sample_only_one_level', 'eval_seq_len' and 'epochs_per_level' are in
    ``hyperparameters``. The highest level that can be sampled is
    ``min(epoch // epochs_per_level, max_level, seq_len - single_eval_pos - 1)``.

    You can train a new model, based on an old one, that only samples from a single level. To do so, set
    ``hyperparameters['level_0_model']`` to a ``"group:index"`` string. The model is then loaded via ``get_model``
    from the corresponding submitit result. It replaces ``model`` both for target generation and for the wrapped
    ``get_batch`` call, and every sequence is assigned level 1. This requires ``sample_only_one_level`` and
    ``max_level == 1``.

    :param batch_size: number of sequences to return.
    :param seq_len: length of the returned sequences.
    :param num_features: forwarded to the wrapped ``get_batch``.
    :param get_batch: the wrapped prior. It must return a ``Batch`` with only ``x``, ``y`` and ``target_y``
        filled, and with ``y is target_y``.
    :param model: model used to fantasize targets (replaced by the level-0 model if one is configured).
    :param single_eval_pos: number of leading training positions; positions from here up to ``seq_len`` are
        evaluation positions.
    :param epoch: current epoch (0-based); controls the highest level that may be sampled.
    :param device: device for the sampled styles; also forwarded to the wrapped ``get_batch``.
    :param hyperparameters: dict of hyperparameters (see assumptions above); also forwarded.
    :param kwargs: forwarded to the wrapped ``get_batch``.
    :return: a ``Batch`` with the following fields:
        ``x``, ``y`` and ``target_y``: cut to length ``seq_len`` along dim 0, with sequences grouped by
        ascending level along dim 1.
        ``style``: the sorted ``(batch_size, 1)`` long tensor of levels.
        ``mean_prediction``: a bool when ``share_predict_mean_distribution`` is set, else ``None``.
    """
    # Keep the configured name and the loaded model apart, so that later checks can use `is not None`
    # instead of the truthiness of an nn.Module (or of an empty string).
    level_0_model = None
    level_0_model_name = hyperparameters.get('level_0_model', None)
    if level_0_model_name:
        assert hyperparameters['sample_only_one_level'], "level_0_model only makes sense if you sample only one level"
        assert hyperparameters['max_level'] == 1, "level_0_model only makes sense if you sample only one level"
        level_0_model = get_model(level_0_model_name, device)
        model = level_0_model

    # The level describes how many fantasized steps are possible. It starts at 0 for the first epochs.
    epochs_per_level = hyperparameters['epochs_per_level']
    share_predict_mean_distribution = hyperparameters.get('share_predict_mean_distribution', 0.)
    decoder_dict_once = getattr(model, 'decoder_dict_once', None)
    use_mean_prediction = share_predict_mean_distribution or \
        (decoder_dict_once is not None and 'mean_prediction' in decoder_dict_once)
    num_evals = seq_len - single_eval_pos
    if level_0_model is not None:
        level = 1
    else:
        level = min(epoch // epochs_per_level, hyperparameters['max_level'], num_evals - 1)

    eval_seq_len = hyperparameters['eval_seq_len']
    # With a mean-prediction head no extra query points past seq_len are needed.
    add_seq_len = 0 if use_mean_prediction else eval_seq_len
    long_seq_len = seq_len + add_seq_len

    # Styles are sorted so that they line up with the per-level concatenation at the end.
    if level_0_model is not None:
        styles = torch.ones(batch_size, 1, device=device, dtype=torch.long)
    elif hyperparameters['sample_only_one_level']:
        styles = torch.randint(level + 1, (1, 1), device=device).repeat(batch_size, 1)
    else:
        styles = torch.randint(level + 1, (batch_size, 1), device=device).sort(0).values

    predict_mean_distribution = None
    if share_predict_mean_distribution:
        # A single tensor reduction (instead of Python's max over tensor rows) gives a plain int here.
        max_used_level = int(styles.max())
        # below code assumes epochs are base 0!
        share_of_training = epoch / epochs_per_level
        predict_mean_distribution = bool(
            share_of_training >= max_used_level + 1. - share_predict_mean_distribution
            and max_used_level < hyperparameters['max_level'])

    x, y, targets = [], [], []
    for considered_level in range(level + 1):
        num_elements = int((styles == considered_level).sum())
        if num_elements == 0:
            continue
        returns: Batch = get_batch(batch_size=num_elements, seq_len=long_seq_len,
                                   num_features=num_features, device=device,
                                   hyperparameters=hyperparameters, model=model,
                                   single_eval_pos=single_eval_pos, epoch=epoch,
                                   **kwargs)
        levels_x, levels_y, levels_targets = returns.x, returns.y, returns.target_y
        assert not returns.other_filled_attributes(), f"Unexpected filled attributes: {returns.other_filled_attributes()}"
        assert levels_y is levels_targets
        # Copy, so that overwriting targets below leaves the inputs `y` untouched.
        levels_targets = levels_targets.clone()
        if len(levels_y.shape) == 2:
            levels_y = levels_y.unsqueeze(2)
            levels_targets = levels_targets.unsqueeze(2)

        if considered_level > 0:
            # Build one context per evaluation position, tiled along the batch dimension. Each context holds:
            # the first single_eval_pos points; then evaluation point i as the last training point; then
            # (unless a mean-prediction head is used) the add_seq_len extra points as queries.
            context_len = single_eval_pos + 1 + add_seq_len
            feed_x = levels_x[:context_len].repeat(1, num_evals, 1)
            feed_y = levels_y[:context_len].repeat(1, num_evals, 1)
            feed_x[single_eval_pos, :] = levels_x[single_eval_pos:seq_len].reshape(-1, *levels_x.shape[2:])
            feed_y[single_eval_pos, :] = levels_y[single_eval_pos:seq_len].reshape(-1, *levels_y.shape[2:])
            if not use_mean_prediction:
                feed_x[single_eval_pos + 1:] = levels_x[seq_len:].repeat(1, num_evals, 1)
                feed_y[single_eval_pos + 1:] = levels_y[seq_len:].repeat(1, num_evals, 1)

            inf_batch_size = hyperparameters['inf_batch_size']
            drop_style = level_0_model is not None and level_0_model.style_encoder is None
            was_training = model.training
            model.eval()
            means = []
            try:
                for feed_x_b, feed_y_b in zip(torch.split(feed_x, inf_batch_size, dim=1),
                                              torch.split(feed_y, inf_batch_size, dim=1)):
                    # no_grad: the fantasized values are labels. Tracking gradients would keep the whole
                    # inference graph alive and attach it (and the model parameters) to target_y.
                    with torch.no_grad(), torch.cuda.amp.autocast():
                        style = None if drop_style else torch.full(
                            (feed_x_b.shape[1], 1), considered_level - 1, dtype=torch.int64, device=device)
                        out = model((style, feed_x_b, feed_y_b),
                                    single_eval_pos=single_eval_pos + 1, only_return_standard_out=False)
                        output, once_output = out if isinstance(out, tuple) else (out, {})
                        if once_output and 'mean_prediction' in once_output:
                            mean_pred_logits = once_output['mean_prediction'].float()
                            assert tuple(mean_pred_logits.shape) == (feed_x_b.shape[1], model.criterion.num_bars), \
                                f"{tuple(mean_pred_logits.shape)} vs {(feed_x_b.shape[1], model.criterion.num_bars)}"
                            means.append(model.criterion.icdf(mean_pred_logits, 1. - 1. / eval_seq_len))
                        else:
                            logits = output['standard'].float()
                            # Largest predicted mean over the query positions (dim 0).
                            means.append(model.criterion.mean(logits).max(0).values)
            finally:
                # Restore the caller's mode instead of unconditionally switching to training mode.
                model.train(was_training)
            levels_targets[single_eval_pos:seq_len] = torch.cat(means, 0).view(num_evals, *levels_y.shape[1:])

        # Drop the extra query points; only the first seq_len positions are returned.
        x.append(levels_x[:seq_len])
        y.append(levels_y[:seq_len])
        targets.append(levels_targets[:seq_len])

    return Batch(x=torch.cat(x, 1), y=torch.cat(y, 1), target_y=torch.cat(targets, 1), style=styles,
                 mean_prediction=predict_mean_distribution)