Spaces:
Runtime error
Runtime error
| from ..trainer_videobase import VideoBaseTrainer | |
| import torch.nn.functional as F | |
| from typing import Optional | |
| import os | |
| import torch | |
| from transformers.utils import WEIGHTS_NAME | |
| import json | |
class VQVAETrainer(VideoBaseTrainer):
    """Trainer that optimises a VQ-VAE on video batches.

    The loss is the normalised reconstruction MSE plus the codebook's
    commitment loss.
    """

    # Divisor applied to the reconstruction MSE.
    # NOTE(review): presumably the variance of the training data (VideoGPT
    # uses 0.06); confirm against the dataset. Override on a subclass if needed.
    recon_loss_norm = 0.06

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """Compute the VQ-VAE training loss for one batch.

        Args:
            model: The VQ-VAE. It may be wrapped (e.g. by
                DistributedDataParallel), in which case the real module is
                taken from ``.module``. It must expose ``encoder``,
                ``pre_vq_conv``, ``codebook``, ``post_vq_conv`` and ``decoder``.
            inputs: Batch mapping that must contain a ``"video"`` tensor.
            return_outputs: When True, also return a dict holding the loss and
                the reconstruction. ``transformers.Trainer`` requests this
                during evaluation and unpacks ``(loss, outputs)``.
            num_items_in_batch: Accepted for compatibility with newer
                ``transformers.Trainer`` releases that pass it. Unused here.

        Returns:
            The loss tensor, or ``(loss, outputs)`` when ``return_outputs``
            is True.

        Raises:
            KeyError: If ``inputs`` has no ``"video"`` entry.
        """
        # Unwrap DDP/DataParallel only when a wrapper is present; a bare
        # module has no `.module` attribute.
        model = getattr(model, "module", model)

        x = inputs["video"]
        # NOTE(review): halving presumably maps inputs from [-1, 1] to
        # [-0.5, 0.5]; confirm against the data pipeline.
        x = x / 2

        z = model.pre_vq_conv(model.encoder(x))
        vq_output = model.codebook(z)
        x_recon = model.decoder(model.post_vq_conv(vq_output["embeddings"]))

        recon_loss = F.mse_loss(x_recon, x) / self.recon_loss_norm
        commitment_loss = vq_output["commitment_loss"]
        loss = recon_loss + commitment_loss

        if return_outputs:
            return loss, {"loss": loss, "x_recon": x_recon}
        return loss