'''
Code adapted from Stitch it in Time by Tzaban et al.
https://github.com/rotemtzaban/STIT
'''
import torch
from tqdm import tqdm

import clip

imagenet_templates = [
    'a bad photo of a {}.',
    'a photo of many {}.',
    'a sculpture of a {}.',
    'a photo of the hard to see {}.',
    'a low resolution photo of the {}.',
    'a rendering of a {}.',
    'graffiti of a {}.',
    'a bad photo of the {}.',
    'a cropped photo of the {}.',
    'a tattoo of a {}.',
    'the embroidered {}.',
    'a photo of a hard to see {}.',
    'a bright photo of a {}.',
    'a photo of a clean {}.',
    'a photo of a dirty {}.',
    'a dark photo of the {}.',
    'a drawing of a {}.',
    'a photo of my {}.',
    'the plastic {}.',
    'a photo of the cool {}.',
    'a close-up photo of a {}.',
    'a black and white photo of the {}.',
    'a painting of the {}.',
    'a painting of a {}.',
    'a pixelated photo of the {}.',
    'a sculpture of the {}.',
    'a bright photo of the {}.',
    'a cropped photo of a {}.',
    'a plastic {}.',
    'a photo of the dirty {}.',
    'a jpeg corrupted photo of a {}.',
    'a blurry photo of the {}.',
    'a photo of the {}.',
    'a good photo of the {}.',
    'a rendering of the {}.',
    'a {} in a video game.',
    'a photo of one {}.',
    'a doodle of a {}.',
    'a close-up photo of the {}.',
    'a photo of a {}.',
    'the origami {}.',
    'the {} in a video game.',
    'a sketch of a {}.',
    'a doodle of the {}.',
    'a origami {}.',
    'a low resolution photo of a {}.',
    'the toy {}.',
    'a rendition of the {}.',
    'a photo of the clean {}.',
    'a photo of a large {}.',
    'a rendition of a {}.',
    'a photo of a nice {}.',
    'a photo of a weird {}.',
    'a blurry photo of a {}.',
    'a cartoon {}.',
    'art of a {}.',
    'a sketch of the {}.',
    'a embroidered {}.',
    'a pixelated photo of a {}.',
    'itap of the {}.',
    'a jpeg corrupted photo of the {}.',
    'a good photo of a {}.',
    'a plushie {}.',
    'a photo of the nice {}.',
    'a photo of the small {}.',
    'a photo of the weird {}.',
    'the cartoon {}.',
    'art of the {}.',
    'a drawing of the {}.',
    'a photo of the large {}.',
    'a black and white photo of a {}.',
    'the plushie {}.',
    'a dark photo of a {}.',
    'itap of a {}.',
    'graffiti of the {}.',
    'a toy {}.',
    'itap of my {}.',
    'a photo of a cool {}.',
    'a photo of a small {}.',
    'a tattoo of the {}.',
]

# Slice ranges into the flattened 9088-channel StyleSpace (S-space) vector of the
# FFHQ StyleGAN2 generator. CONV_CODE_INDICES covers the 6048 channels of the 17
# convolution modulation layers (where the edit direction lives); FFHQ_CODE_INDICES
# additionally includes the tRGB modulation layers, covering all 9088 channels.
# The list order is assumed to match reference_generator.modulation_layers.
CONV_CODE_INDICES = [(0, 512), (1024, 1536), (1536, 2048), (2560, 3072), (3072, 3584), (4096, 4608), (4608, 5120), (5632, 6144), (6144, 6656), (7168, 7680), (7680, 7936), (8192, 8448), (8448, 8576), (8704, 8832), (8832, 8896), (8960, 9024), (9024, 9056)]
FFHQ_CODE_INDICES = [(0, 512), (512, 1024), (1024, 1536), (1536, 2048), (2560, 3072), (3072, 3584), (4096, 4608), (4608, 5120), (5632, 6144), (6144, 6656), (7168, 7680), (7680, 7936), (8192, 8448), (8448, 8576), (8704, 8832), (8832, 8896), (8960, 9024), (9024, 9056)] + \
                    [(2048, 2560), (3584, 4096), (5120, 5632), (6656, 7168), (7936, 8192), (8576, 8704), (8896, 8960), (9056, 9088)]


def zeroshot_classifier(model, classnames, templates, device):
    with torch.no_grad():
        zeroshot_weights = []
        for classname in tqdm(classnames):
            texts = [template.format(classname) for template in templates]  # format with class
            texts = clip.tokenize(texts).to(device)  # tokenize
            class_embeddings = model.encode_text(texts)  # embed with text encoder
            class_embeddings /= class_embeddings.norm(dim=-1, keepdim=True)
            class_embedding = class_embeddings.mean(dim=0)
            class_embedding /= class_embedding.norm()
            zeroshot_weights.append(class_embedding)
        zeroshot_weights = torch.stack(zeroshot_weights, dim=1).to(device)
    return zeroshot_weights
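
# Shape note: with ViT-B/32, the result is a (512, num_classes) tensor -- one
# averaged, L2-normalized text embedding per class. A minimal sketch:
#
#     weights = zeroshot_classifier(clip_model, ['face', 'smiling face'],
#                                   imagenet_templates, device)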


def expand_to_full_dim(partial_tensor):
    # Scatter a 6048-dim conv-channel edit into the full 9088-dim S-space
    # vector; tRGB channels stay zero. Allocated on partial_tensor's device so
    # the assignment below doesn't mix CPU and CUDA tensors.
    full_dim_tensor = torch.zeros(size=(1, 9088), device=partial_tensor.device)

    start_idx = 0
    for conv_start, conv_end in CONV_CODE_INDICES:
        length = conv_end - conv_start
        full_dim_tensor[:, conv_start:conv_end] = partial_tensor[start_idx:start_idx + length]
        start_idx += length

    return full_dim_tensor


def get_direction(neutral_class, target_class, beta, di, clip_model=None):
    device = "cuda" if torch.cuda.is_available() else "cpu"
    if clip_model is None:
        clip_model, _ = clip.load("ViT-B/32", device=device)

    class_names = [neutral_class, target_class]
    class_weights = zeroshot_classifier(clip_model, class_names, imagenet_templates, device)

    # Normalized CLIP-space text direction from the neutral class to the target class.
    dt = class_weights[:, 1] - class_weights[:, 0]
    dt = dt / dt.norm()

    dt = dt.float()
    di = di.float().to(device)  # keep di on the same device as dt for the matmul

    # Relevance of each style channel to the text direction; channels whose
    # |relevance| falls below beta are zeroed, the rest are rescaled to [-1, 1].
    relevance = di @ dt
    mask = relevance.abs() > beta
    direction = relevance * mask
    direction_max = direction.abs().max()
    if direction_max > 0:
        direction = direction / direction_max
    else:
        raise ValueError(f'Beta value {beta} is too high for mapping from {neutral_class} to {target_class},'
                         f' try setting it to a lower value')

    return direction
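
# Usage sketch (hedged): `di` is expected to be a precomputed StyleCLIP
# global-directions matrix of shape (6048, 512) -- one CLIP-space direction
# per conv style channel. The file name below is hypothetical:
#
#     delta_i = torch.load('delta_i.pt')  # hypothetical precomputed matrix
#     direction = get_direction('face', 'smiling face', beta=0.13, di=delta_i)
#     # direction: shape (6048,), sub-threshold channels zeroed out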


def style_tensor_to_style_dict(style_tensor, reference_generator):
    style_layers = reference_generator.modulation_layers

    # Split the flat (1, 9088) S-space tensor back into per-layer slices.
    style_dict = {}
    for layer_idx, layer in enumerate(style_layers):
        style_dict[layer] = style_tensor[:, FFHQ_CODE_INDICES[layer_idx][0]:FFHQ_CODE_INDICES[layer_idx][1]]

    return style_dict


def style_dict_to_style_tensor(style_dict, reference_generator):
    style_layers = reference_generator.modulation_layers

    # Flatten a per-layer style dict into a single (1, 9088) S-space tensor,
    # allocated on the same device as the incoming styles.
    device = next(iter(style_dict.values())).device
    style_tensor = torch.zeros(size=(1, 9088), device=device)
    for layer in style_dict:
        layer_idx = style_layers.index(layer)
        style_tensor[:, FFHQ_CODE_INDICES[layer_idx][0]:FFHQ_CODE_INDICES[layer_idx][1]] = style_dict[layer]

    return style_tensor


def project_code_with_styleclip(source_latent, source_class, target_class, alpha, beta,
                                reference_generator, di, clip_model=None):
    # Shift the source S-space code along the CLIP-derived global edit
    # direction; alpha controls the edit strength, beta the channel sparsity.
    edit_direction = get_direction(source_class, target_class, beta, di, clip_model)
    edit_full_dim = expand_to_full_dim(edit_direction)

    source_s = style_dict_to_style_tensor(source_latent, reference_generator)
    return source_s + alpha * edit_full_dim.to(source_s.device)
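
# End-to-end sketch (hedged: `generator` stands for a StyleGAN2 wrapper exposing
# `.modulation_layers`, as the functions above assume, and `source_latent` for
# an S-space dict keyed by those layers; `delta_i` as in the sketch above):
#
#     edited_s = project_code_with_styleclip(
#         source_latent, 'face', 'smiling face', alpha=4.0, beta=0.13,
#         reference_generator=generator, di=delta_i)
#     edited_dict = style_tensor_to_style_dict(edited_s, generator)
#     # feed `edited_dict` back through the generator's synthesis pass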