Spaces:
Runtime error
Runtime error
| import torch | |
| from torch.utils.data import DataLoader | |
| import torch.nn as nn | |
| from configs.deepsvg.hierarchical_ordered import Config | |
| from deepsvg import utils | |
| from deepsvg.svglib.svg import SVG | |
| from deepsvg.difflib.tensor import SVGTensor | |
| from deepsvg.svglib.geom import Bbox | |
| from deepsvg.svgtensor_dataset import load_dataset, SVGFinetuneDataset | |
| from deepsvg.utils.utils import batchify | |
| from .state.project import DeepSVGProject, Frame | |
| from .utils import easein_easeout | |
# Run on the first CUDA device when one is available, otherwise on the CPU.
_device_name = "cuda:0" if torch.cuda.is_available() else "cpu"
device = torch.device(_device_name)

# Checkpoint that finetune_model() loads into `model` before finetuning.
pretrained_path = "./pretrained/hierarchical_ordered.pth.tar"

cfg = Config()
cfg.model_cfg.dropout = 0.0  # for faster convergence

# Build the network on the selected device and start out in eval mode.
model = cfg.make_model().to(device)
model.eval()

# Dataset object used both to preprocess single SVGs (encode_svg) and as the
# base for the finetuning dataset (finetune_model).
dataset = load_dataset(cfg)
def decode(z):
    """Decode a latent code into an SVG using the module-level model.

    Args:
        z: Latent code passed to ``model.greedy_sample`` (for example the
            value returned by ``encode_svg``).

    Returns:
        The ``SVG`` built from the first element of the sampled batch, with
        a ``Bbox(256)`` viewbox.
    """
    # Sampling is pure inference and the result is converted to an SVG
    # object, so no gradients are ever needed: disable autograd to avoid
    # building (and holding on to) a computation graph.
    with torch.no_grad():
        commands_y, args_y, _ = model.greedy_sample(z=z)

    # Only the first sample of the batch is used.
    tensor_pred = SVGTensor.from_cmd_args(commands_y[0].cpu(), args_y[0].cpu())
    return SVG.from_tensor(tensor_pred.data, viewbox=Bbox(256))
def encode_svg(svg):
    """Encode an SVG into a latent code using the module-level model.

    Args:
        svg: The SVG to encode; it is preprocessed through ``dataset.get``.

    Returns:
        The value returned by the model when called with
        ``encode_mode=True``. It is computed with autograd disabled, so it
        carries no gradient history.
    """
    # "tensor_grouped" is requested in addition to the model inputs, but only
    # the keys listed in cfg.model_args are forwarded to the model.
    data = dataset.get(model_args=[*cfg.model_args, "tensor_grouped"], svg=svg)
    model_args = batchify((data[key] for key in cfg.model_args), device)

    # Encoding here is inference only: without no_grad the returned latent
    # would keep the whole encoder graph alive while it is interpolated.
    with torch.no_grad():
        return model(*model_args, encode_mode=True)
def interpolate_svg(svg1, svg2, n=10, ease=True):
    """Produce ``n`` in-between SVGs by blending the latent codes of two SVGs.

    Args:
        svg1: SVG whose latent code gets weight ``1 - a``.
        svg2: SVG whose latent code gets weight ``a``.
        n: Number of intermediate SVGs to generate; both endpoints are
            excluded.
        ease: When true, the blend weights are passed through
            ``easein_easeout`` before use.

    Returns:
        A list with one decoded SVG per blend weight, in weight order.
    """
    z_start = encode_svg(svg1)
    z_end = encode_svg(svg2)

    # n + 2 evenly spaced weights on [0, 1]; dropping the first and last
    # leaves the n interior ones.
    weights = torch.linspace(0., 1., n + 2)[1:-1]
    if ease:
        weights = easein_easeout(weights)

    blended = []
    for w in weights:
        blended.append((1 - w) * z_start + w * z_end)

    return list(map(decode, blended))
def finetune_model(project: DeepSVGProject, nb_augmentations=3500):
    """Reload the pretrained weights and finetune the model on the keyframes.

    Does nothing when the project has fewer than two keyframes. Otherwise
    the checkpoint at ``pretrained_path`` is loaded into the module-level
    ``model`` and a single pass over the finetuning dataloader is run.

    Args:
        project: Project whose frames flagged ``keyframe`` supply the SVGs.
        nb_augmentations: Forwarded to ``SVGFinetuneDataset``.

    Returns:
        None. The module-level ``model`` is updated in place; ``model.train()``
        is called on every step and eval mode is not restored here.
    """
    keyframe_svgs = [frame.svg for frame in project.frames if frame.keyframe]
    if len(keyframe_svgs) < 2:
        return

    # Always restart from the pretrained checkpoint.
    utils.load_model(pretrained_path, model)
    print("Finetuning...")

    finetune_dataset = SVGFinetuneDataset(
        dataset,
        keyframe_svgs,
        frac=1.0,
        nb_augmentations=nb_augmentations,
    )
    loader = DataLoader(
        finetune_dataset,
        batch_size=cfg.batch_size,
        shuffle=True,
        drop_last=False,
        num_workers=cfg.loader_num_workers,
        collate_fn=cfg.collate_fn,
    )

    # Optimizer, lr & warmup schedulers
    optimizers = cfg.make_optimizers(model)
    scheduler_lrs = cfg.make_schedulers(optimizers, epoch_size=len(loader))
    scheduler_warmups = cfg.make_warmup_schedulers(optimizers, scheduler_lrs)
    loss_fns = [loss_fn.to(device) for loss_fn in cfg.make_losses()]

    # Only one pass over the loader is made, so the epoch index stays at 0.
    epoch = 0
    for step, batch in enumerate(loader):
        model.train()

        model_args = [batch[name].to(device) for name in cfg.model_args]
        if "label" in batch:
            labels = batch["label"].to(device)
        else:
            labels = None

        params_dict = cfg.get_params(step, epoch)
        weights_dict = cfg.get_weights(step, epoch)

        # cfg.optimizer_starts takes part in the zip (so it can bound its
        # length), but its values are not consulted in this loop.
        training_parts = zip(loss_fns, optimizers, scheduler_lrs,
                             scheduler_warmups, cfg.optimizer_starts)
        for loss_fn, optimizer, scheduler_lr, scheduler_warmup, _start in training_parts:
            optimizer.zero_grad()

            output = model(*model_args, params=params_dict)
            loss_dict = loss_fn(output, labels, weights=weights_dict)
            loss_dict["loss"].backward()

            if cfg.grad_clip is not None:
                nn.utils.clip_grad_norm_(model.parameters(), cfg.grad_clip)

            optimizer.step()
            if scheduler_lr is not None:
                scheduler_lr.step()
            if scheduler_warmup is not None:
                scheduler_warmup.step()

            # Progress report every 20 steps.
            if step % 20 == 0:
                print(f"Step {step}: loss: {loss_dict['loss']}")

    print("Finetuning done.")
def compute_interpolation(project: DeepSVGProject):
    """Finetune on the project's keyframes, then fill in the frames between them.

    For every pair of consecutive keyframes, the frames strictly between
    them are replaced by non-keyframe ``Frame`` objects holding SVGs
    interpolated (without easing) between the two keyframe SVGs.

    Args:
        project: Project whose ``frames`` list is updated in place.

    Returns:
        None. Returns early, after the finetuning call, when there are
        fewer than two keyframes.
    """
    finetune_model(project)

    key_indices = []
    for index, frame in enumerate(project.frames):
        if frame.keyframe:
            key_indices.append(index)
    if len(key_indices) < 2:
        return

    # Interpolation is inference: switch the model to eval mode.
    model.eval()

    for left, right in zip(key_indices[:-1], key_indices[1:]):
        gap = right - left - 1
        if gap != 0:
            inbetween = interpolate_svg(
                project.frames[left].svg,
                project.frames[right].svg,
                n=gap,
                ease=False,
            )
            # The first generated SVG belongs right after the left keyframe.
            for position, svg in enumerate(inbetween, left + 1):
                project.frames[position] = Frame(position, keyframe=False, svg=svg)