import math
import warnings

import numpy as np
import torch
from sklearn.preprocessing import PowerTransformer, StandardScaler

import pfns4bo
from pfns4bo import bar_distribution
from pfns4bo.scripts.acquisition_functions import TransformerBOMethod

warnings.filterwarnings('ignore')

device = torch.device("cpu")
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dtype = torch.float32
def Rosen_PFN(model_name,
              trained_X,
              trained_Y,
              X_pen,
              transform_type,
              what_do_you_want):
    """Score the candidate points X_pen with a pretrained PFN surrogate.

    transform_type: 'std' (StandardScaler) or 'power' (Yeo-Johnson) transform
    of the targets; any other value leaves trained_Y untouched.
    what_do_you_want: one of 'mean', 'ei', 'ucb', 'variance', 'mode', 'ts'.
    """
    PFN = TransformerBOMethod(torch.load(model_name).requires_grad_(False),
                              device=device)

    x_given = trained_X
    x_eval = X_pen
    # Condition on x_given, then predict at both the training points and the
    # candidates; single_eval_pos below marks where the context ends.
    x_full_feed = torch.cat([x_given, x_given, x_eval], dim=0).unsqueeze(1)

    pt = None
    if transform_type == 'std':
        pt = StandardScaler()
        pt.fit(trained_Y.detach().numpy())
        trained_Y = torch.tensor(pt.transform(trained_Y.detach().numpy()),
                                 dtype=trained_Y.dtype).reshape(trained_Y.shape)
    elif transform_type == 'power':
        trained_Y, pt = general_power_transform(trained_Y, trained_Y, .0,
                                                less_safe=False)

    y_given = trained_Y.reshape(-1)
    y_full_feed = y_given.unsqueeze(1)

    def undo_transform(vals):
        # Map model outputs back to the original target scale, staying in
        # torch (assumes general_power_transform did not pre-normalize).
        if transform_type == 'std':
            return torch_std_inverse_transform(
                vals.clone(),
                torch.from_numpy(pt.scale_).to(vals.dtype),
                torch.from_numpy(pt.mean_).to(vals.dtype))
        if transform_type == 'power':
            XX = vals.clone()
            if pt.standardize:
                XX = torch_std_inverse_transform(
                    XX,
                    torch.from_numpy(pt._scaler.scale_).to(XX.dtype),
                    torch.from_numpy(pt._scaler.mean_).to(XX.dtype))
            for lmbda in pt.lambdas_:  # Y is single-column, so one lambda
                XX = torch_power_inverse_transform(XX, lmbda)
            return XX
        return vals

    criterion: bar_distribution.BarDistribution = PFN.model.criterion
    style = None
    logits = PFN.model(
        (style,
         x_full_feed.repeat_interleave(dim=1, repeats=y_full_feed.shape[1]),
         y_full_feed.repeat(1, x_full_feed.shape[1])),
        single_eval_pos=len(x_given)
    )
    # Renormalize to log-probabilities over the bar distribution's buckets.
    logits = logits.softmax(-1).log()
    logits_eval = logits[len(x_given):]

    best_f = torch.max(y_given)

    if what_do_you_want == 'mean':
        return undo_transform(criterion.mean(logits_eval))
    elif what_do_you_want == 'ei':
        output = criterion.ei(logits_eval, best_f)
    elif what_do_you_want == 'ucb':
        ucb_rest_prob = .05
        acq_function = lambda *args: criterion.ucb(*args, rest_prob=ucb_rest_prob)
        output = acq_ensembling(acq_function(logits_eval, best_f))
    elif what_do_you_want == 'variance':
        output = criterion.variance(logits_eval)
    elif what_do_you_want == 'mode':
        output = criterion.mode(logits_eval)
    elif what_do_you_want == 'ts':
        # Back-transformed mean plus variance, e.g. for Thompson-style sampling.
        return undo_transform(criterion.mean(logits_eval)), criterion.variance(logits_eval)
    return output
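# Illustrative usage sketch, not part of the original module. The checkpoint
# handle `pfns4bo.hebo_plus_model` is an assumption taken from the pfns4bo
# package; point it at whichever PFN checkpoint you actually use.
def _demo_rosen_pfn():
    train_x = torch.rand(10, 2, dtype=dtype)  # observed inputs, scaled to [0, 1]^d
    train_y = torch.rand(10, 1, dtype=dtype)  # observed objective values
    cand_x = torch.rand(50, 2, dtype=dtype)   # candidate points to score
    ei = Rosen_PFN(pfns4bo.hebo_plus_model, train_x, train_y, cand_x,
                   'power', 'ei')
    print('best candidate:', cand_x[ei.argmax()])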
def Rosen_PFN_Parallel(model_name,
                       trained_X,
                       trained_Y,
                       GX,
                       X_pen,
                       transform_type,
                       what_do_you_want):
    """Constrained variant: jointly scores the objective and the constraint
    columns in GX, feeding one PFN channel per output."""
    PFN = TransformerBOMethod(torch.load(model_name), device=device)
    with torch.no_grad():
        x_given = trained_X
        x_eval = X_pen
        x_full_feed = torch.cat([x_given, x_given, x_eval], dim=0).unsqueeze(1)
        y_given = trained_Y.reshape(-1)

        ######################################################################
        # Objective Power Transform
        y_given, pt_y = general_power_transform(y_given.unsqueeze(1),
                                                y_given.unsqueeze(1),
                                                .0,
                                                less_safe=False)
        y_given = y_given.squeeze(1)
        ######################################################################

        ######################################################################
        # Constraints Power Transform
        # GX is negated so that larger is better, and the feasibility
        # threshold 0 is mapped into the same transformed space.
        GX = -GX
        GX_t, pt_GX = general_power_transform(GX, GX, .0, less_safe=False)
        G_thres, _ = general_power_transform(GX,
                                             torch.zeros((1, GX.shape[1])).to(GX.device),
                                             .0,
                                             less_safe=False)
        GX = GX_t
        ######################################################################

        y_full_feed = y_given.unsqueeze(1)
        criterion: bar_distribution.BarDistribution = PFN.model.criterion
        style = None
        # One batch channel per output: the objective plus each constraint.
        logits = PFN.model(
            (style,
             x_full_feed.repeat_interleave(dim=1, repeats=y_full_feed.shape[1] + GX.shape[1]),
             torch.cat([y_full_feed, GX], dim=1).unsqueeze(2)),
            single_eval_pos=len(x_given)
        )
        logits = logits.softmax(-1).log_()
        logits_eval = logits[len(x_given):]

        # Channel 0 is the objective; channels 1..K are the constraints.
        objective_eval = logits_eval[:, 0, :].unsqueeze(1)
        constraint_eval = logits_eval[:, 1:, :]

        if what_do_you_want == 'mean':
            obj_output = criterion.mean(objective_eval)
            con_output = criterion.mean(constraint_eval)
            return obj_output, con_output
        elif what_do_you_want == 'ei':
            # Constrained-EI ingredients: EI for the objective, PI with respect
            # to the transformed zero threshold for each constraint.
            tau = torch.max(y_given)
            objective_acq_value = acq_ensembling(criterion.ei(objective_eval, tau))
            constraints_acq_value = torch.stack(
                [acq_ensembling(criterion.pi(constraint_eval[:, jj, :].unsqueeze(1),
                                             G_thres[0, jj].item()))
                 for jj in range(constraint_eval.shape[1])],
                dim=1)
            return objective_acq_value, constraints_acq_value
        elif what_do_you_want == 'variance':
            return criterion.variance(logits_eval)
        elif what_do_you_want == 'mode':
            return criterion.mode(logits_eval)
        elif what_do_you_want == 'cts':
            # Back-transformed means plus transformed-space variances for both
            # the objective and the constraints.
            obj_mnn = torch.from_numpy(pt_y.inverse_transform(criterion.mean(objective_eval)))
            con_mnn = torch.from_numpy(-pt_GX.inverse_transform(criterion.mean(constraint_eval)))
            obj_varr = criterion.variance(objective_eval)
            con_varr = criterion.variance(constraint_eval)
            return obj_mnn, obj_varr, con_mnn, con_varr
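# Illustrative sketch for the constrained variant, under the same checkpoint
# assumption as above. GX carries one column per constraint; the sign flip
# inside Rosen_PFN_Parallel suggests the convention that GX <= 0 is feasible.
def _demo_rosen_pfn_parallel():
    train_x = torch.rand(10, 2, dtype=dtype)
    train_y = torch.rand(10, 1, dtype=dtype)
    gx = torch.rand(10, 2, dtype=dtype) - 0.5  # two constraint columns
    cand_x = torch.rand(50, 2, dtype=dtype)
    obj_ei, con_pi = Rosen_PFN_Parallel(pfns4bo.hebo_plus_model, train_x,
                                        train_y, gx, cand_x, 'power', 'ei')
    cei = obj_ei * con_pi.prod(dim=1)  # EI weighted by feasibility probabilities
    print('best candidate:', cand_x[cei.argmax()])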
def acq_ensembling(acq_values):  # (points, ensemble dim)
    return acq_values.max(1).values
def torch_std_inverse_transform(X, scale, mean):
    """Inverse of a StandardScaler, in place: X * scale + mean."""
    X *= scale
    X += mean
    return X
def torch_power_inverse_transform(x, lmbda):
    """Inverse of the Yeo-Johnson transform for a single lambda, in torch."""
    out = torch.zeros_like(x)
    pos = x >= 0
    # when x >= 0
    if abs(lmbda) < np.spacing(1.0):  # lmbda == 0
        out[pos] = torch.exp(x[pos]) - 1
    else:  # lmbda != 0
        out[pos] = torch.pow(x[pos] * lmbda + 1, 1 / lmbda) - 1
    # when x < 0
    if abs(lmbda - 2) > np.spacing(1.0):  # lmbda != 2
        out[~pos] = 1 - torch.pow(-(2 - lmbda) * x[~pos] + 1, 1 / (2 - lmbda))
    else:  # lmbda == 2
        out[~pos] = 1 - torch.exp(-x[~pos])
    return out
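# Sanity-check sketch (an addition, not original code): the manual inverse
# above should match sklearn's own Yeo-Johnson inverse when standardization
# is disabled.
def _demo_power_inverse_transform():
    y = np.random.randn(100, 1)
    pt = PowerTransformer(method='yeo-johnson', standardize=False)
    z = pt.fit_transform(y)
    y_back = torch_power_inverse_transform(torch.from_numpy(z), pt.lambdas_[0])
    assert np.allclose(y_back.numpy(), y, atol=1e-6)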
################################################################################
## PFN defined functions
################################################################################

def log01(x, eps=.0000001, input_between_zero_and_one=False):
    logx = torch.log(x + eps)
    if input_between_zero_and_one:
        return (logx - math.log(eps)) / (math.log(1 + eps) - math.log(eps))
    return (logx - logx.min(0)[0]) / (logx.max(0)[0] - logx.min(0)[0])

def log01_batch(x, eps=.0000001, input_between_zero_and_one=False):
    x = x.repeat(1, x.shape[-1] + 1, 1)
    for b in range(x.shape[-1]):
        x[:, b, b] = log01(x[:, b, b], eps=eps,
                           input_between_zero_and_one=input_between_zero_and_one)
    return x

def lognormed_batch(x, eval_pos, eps=.0000001):
    x = x.repeat(1, x.shape[-1] + 1, 1)
    for b in range(x.shape[-1]):
        logx = torch.log(x[:, b, b] + eps)
        x[:, b, b] = (logx - logx[:eval_pos].mean(0)) / logx[:eval_pos].std(0)
    return x
def _rank_transform(x_train, x):
    assert len(x_train.shape) == len(x.shape) == 1
    relative_to = torch.cat((torch.zeros_like(x_train[:1]),
                             x_train.unique(sorted=True),
                             torch.ones_like(x_train[-1:])), -1)
    higher_comparison = (relative_to < x[..., None]).sum(-1).clamp(min=1)
    pos_inside_interval = (x - relative_to[higher_comparison - 1]) \
        / (relative_to[higher_comparison] - relative_to[higher_comparison - 1])
    x_transformed = higher_comparison - 1 + pos_inside_interval
    return x_transformed / (len(relative_to) - 1.)

def rank_transform(x_train, x):
    assert x.shape[1] == x_train.shape[1], f"{x.shape=} and {x_train.shape=}"
    # make sure everything is between 0 and 1
    assert (x_train >= 0.).all() and (x_train <= 1.).all(), f"{x_train=}"
    assert (x >= 0.).all() and (x <= 1.).all(), f"{x=}"
    return_x = x.clone()
    for feature_dim in range(x.shape[1]):
        return_x[:, feature_dim] = _rank_transform(x_train[:, feature_dim], x[:, feature_dim])
    return return_x
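# Small usage sketch (illustrative addition): rank_transform maps each feature
# to its interpolated rank among the training values, staying inside [0, 1].
def _demo_rank_transform():
    x_train = torch.rand(16, 3)
    x_new = torch.rand(5, 3)
    z = rank_transform(x_train, x_new)
    assert z.shape == x_new.shape
    assert (z >= 0).all() and (z <= 1).all()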
def general_power_transform(x_train, x_apply, eps, less_safe=False):
    """Fit a power transform on x_train and apply it to x_apply.

    eps > 0 selects Box-Cox (inputs shifted by eps must be positive),
    eps == 0 selects Yeo-Johnson. Returns the transformed x_apply and the
    fitted PowerTransformer.
    """
    if eps > 0:
        try:
            pt = PowerTransformer(method='box-cox')
            pt.fit(x_train.cpu() + eps)
            x_out = torch.tensor(pt.transform(x_apply.cpu() + eps),
                                 dtype=x_apply.dtype, device=x_apply.device)
        except Exception as e:
            print('box-cox fit failed, falling back to mean-centering:', e)
            x_out = x_apply - x_train.mean(0)
    else:
        pt = PowerTransformer(method='yeo-johnson')
        if not less_safe and (x_train.std() > 1_000 or x_train.mean().abs() > 1_000):
            # Very large inputs make the fit unstable; normalize them first.
            x_apply = (x_apply - x_train.mean(0)) / x_train.std(0)
            x_train = (x_train - x_train.mean(0)) / x_train.std(0)
        try:
            pt.fit(x_train.cpu().double())
        except Exception:
            # Retry after normalizing; the statistics must be taken before
            # x_train is overwritten, or x_apply ends up on a different scale.
            mean, std = x_train.mean(0), x_train.std(0)
            if less_safe:
                x_train = (x_train - mean) / std
                x_apply = (x_apply - mean) / std
            else:
                x_train = x_train - mean
                x_apply = x_apply - mean
            pt.fit(x_train.cpu().double())
        x_out = torch.tensor(pt.transform(x_apply.cpu()),
                             dtype=x_apply.dtype, device=x_apply.device)
    if torch.isnan(x_out).any() or torch.isinf(x_out).any():
        print('WARNING: power transform failed')
        print(f"{x_train=} and {x_apply=}")
        x_out = x_apply - x_train.mean(0)
    return x_out, pt
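# Round-trip sketch (illustrative addition): the returned PowerTransformer can
# undo the transform via its standard sklearn inverse_transform.
def _demo_general_power_transform():
    y = torch.rand(20, 1) * 100.
    y_t, pt = general_power_transform(y, y, eps=.0, less_safe=False)
    y_back = torch.from_numpy(pt.inverse_transform(y_t.numpy())).to(y.dtype)
    assert torch.allclose(y_back, y, atol=1e-3)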