Spaces:
				
			
			
	
			
			
		Runtime error
		
	
	
	
			
			
	
	
	
	
		
		
		Runtime error
		
	| import torch | |
| import torch.nn as nn | |
| import json | |
| import os | |
class AestheticScorer(nn.Module):
    """MLP head that maps an embedding vector to a single scalar score.

    The network is four hidden ``nn.Linear`` layers followed by a one-unit
    output layer.  ``nn.Dropout`` follows the first three hidden layers, and
    ``nn.ReLU`` is inserted after every hidden layer when ``use_activation``
    is true.  When ``output_activation`` is ``"sigmoid"`` an ``nn.Sigmoid``
    is appended and ``forward`` rescales its output from 0..1 to 1..10.

    All hyper-parameters are kept in ``self.config`` so that ``save`` can
    write them next to the weights and the same architecture can be rebuilt
    later from that file.
    """

    def __init__(self, input_size=0, use_activation=False, dropout=0.2, config=None, hidden_dim=1024, reduce_dims=False, output_activation=None):
        """Build the layer stack.

        Args:
            input_size: Number of input features of the first linear layer.
            use_activation: Insert ``nn.ReLU`` after each hidden linear layer.
            dropout: Dropout probability used by every ``nn.Dropout``.
            config: Optional dict whose entries override the keyword
                arguments above (same keys as ``self.config``).
            hidden_dim: Width of the first hidden layer.
            reduce_dims: When true, hidden widths shrink to
                ``hidden_dim / 2``, ``/ 4`` and ``/ 8`` (rounded); otherwise
                every hidden layer is ``hidden_dim`` wide.
            output_activation: ``"sigmoid"`` to append ``nn.Sigmoid``; any
                other value leaves the output layer linear.
        """
        super().__init__()
        self.config = {
            "input_size": input_size,
            "use_activation": use_activation,
            "dropout": dropout,
            "hidden_dim": hidden_dim,
            "reduce_dims": reduce_dims,
            "output_activation": output_activation,
        }
        if config is not None:
            self.config.update(config)

        cfg = self.config
        hidden = cfg["hidden_dim"]
        # Read the flag from the merged config, not from the constructor
        # argument: a model rebuilt with ``config=...`` must get the same
        # layer shapes as the one that was saved.
        shrink = cfg["reduce_dims"]

        def width(divisor):
            """Return the hidden width for a layer at the given reduction step."""
            return round(hidden / (divisor if shrink else 1))

        def activation():
            """Return a fresh ReLU, or None when activations are disabled."""
            return nn.ReLU() if cfg["use_activation"] else None

        layers = [
            nn.Linear(cfg["input_size"], hidden),
            activation(),
            nn.Dropout(cfg["dropout"]),
            nn.Linear(hidden, width(2)),
            activation(),
            nn.Dropout(cfg["dropout"]),
            nn.Linear(width(2), width(4)),
            activation(),
            nn.Dropout(cfg["dropout"]),
            nn.Linear(width(4), width(8)),
            activation(),
            nn.Linear(width(8), 1),
        ]
        if cfg["output_activation"] == "sigmoid":
            layers.append(nn.Sigmoid())

        # Disabled activations were recorded as None placeholders; drop them
        # so nn.Sequential only receives real modules.
        self.layers = nn.Sequential(*[layer for layer in layers if layer is not None])

    def forward(self, x):
        """Run the network on ``x`` and return the score tensor.

        With a sigmoid output the 0..1 result is mapped linearly to 1..10;
        otherwise the raw output of the last linear layer is returned.
        """
        if self.config["output_activation"] == "sigmoid":
            upper, lower = 10, 1
            scale = upper - lower
            return (self.layers(x) * scale) + lower
        return self.layers(x)

    def save(self, save_name):
        """Write the config and the weights to disk.

        The config is stored as JSON in ``<save_name without extension>.config``
        and the ``state_dict`` is stored at ``save_name``.

        Raises:
            RuntimeError: If ``torch.save`` fails with an error other than
                "cannot be opened", or still fails on the final attempt.
        """
        split_name = os.path.splitext(save_name)
        with open(f"{split_name[0]}.config", "w") as outfile:
            outfile.write(json.dumps(self.config, indent=4))

        # Saving sometimes fails (might be a Windows issue), so allow one
        # initial attempt plus up to five retries.
        for i in range(6):
            try:
                torch.save(self.state_dict(), save_name)
                break
            except RuntimeError as e:
                # Only the transient "cannot be opened" failure is retried.
                if "cannot be opened" in str(e) and i < 5:
                    print("Model save failed, retrying...")
                else:
                    raise
def preprocess(embeddings):
    """Scale ``embeddings`` to unit L2 norm along the last dimension.

    Uses ``torch.nn.functional.normalize``, which clamps the norm to a small
    epsilon before dividing, so an all-zero vector comes back as zeros
    instead of NaN.  Non-zero vectors are divided by their L2 norm.

    Args:
        embeddings: Floating-point tensor; the last dimension is normalized.

    Returns:
        A tensor of the same shape as ``embeddings``.
    """
    return nn.functional.normalize(embeddings, p=2, dim=-1)
def load_model(weight_path, device='cuda' if torch.cuda.is_available() else 'cpu'):
    """Rebuild an ``AestheticScorer`` from saved weights and return it in eval mode.

    The architecture is read from the JSON file ``<weight_path without
    extension>.config``; the weights are read from ``weight_path``.

    Args:
        weight_path: Path to the saved ``state_dict``.
        device: Passed to ``torch.load`` as ``map_location``.

    Returns:
        The loaded model, switched to evaluation mode.
    """
    stem, _ = os.path.splitext(weight_path)
    config_path = f"{stem}.config"
    with open(config_path, "r") as config_file:
        config = json.load(config_file)

    model = AestheticScorer(config=config)

    # NOTE(review): torch.load unpickles the file, so only load checkpoints
    # from a trusted source.
    state = torch.load(weight_path, map_location=device)
    # NOTE(review): `device` only controls where the checkpoint tensors are
    # mapped; the model itself is never moved with .to(device) here.
    # Confirm whether callers are expected to do that themselves.
    model.load_state_dict(state)
    model.eval()
    return model
