Spaces:
Runtime error
Runtime error
| import torch.nn.functional as F | |
| import torch | |
def l1_loss(output, target_rgb, target_raw, weight=1.0):
    """Return the combined L1 reconstruction loss and its two components.

    Args:
        output: Mapping indexed with the keys ``"reconstruct_raw"`` and
            ``"reconstruct_rgb"``; each value is passed to ``F.l1_loss``
            as the prediction.
        target_rgb: Reference compared against ``output["reconstruct_rgb"]``.
        target_raw: Reference compared against ``output["reconstruct_raw"]``.
        weight: Multiplier applied to the RGB term only before summing.

    Returns:
        A tuple ``(total, raw, rgb)`` where ``total = raw + weight * rgb``.
    """
    # Evaluate the raw term first, then the rgb term; both use the
    # default arguments of F.l1_loss.
    raw_term, rgb_term = (
        F.l1_loss(output[key], reference)
        for key, reference in (
            ("reconstruct_raw", target_raw),
            ("reconstruct_rgb", target_rgb),
        )
    )
    return raw_term + weight * rgb_term, raw_term, rgb_term
def l2_loss(output, target_rgb, target_raw, weight=1.0):
    """Return the combined MSE reconstruction loss and its two components.

    Args:
        output: Mapping indexed with the keys ``"reconstruct_raw"`` and
            ``"reconstruct_rgb"``; each value is passed to ``F.mse_loss``
            as the prediction.
        target_rgb: Reference compared against ``output["reconstruct_rgb"]``.
        target_raw: Reference compared against ``output["reconstruct_raw"]``.
        weight: Multiplier applied to the RGB term only before summing.

    Returns:
        A tuple ``(total, raw, rgb)`` where ``total = raw + weight * rgb``.
    """
    # Raw-domain error, using the default arguments of F.mse_loss.
    predicted_raw = output["reconstruct_raw"]
    raw_mse = F.mse_loss(predicted_raw, target_raw)

    # RGB-domain error, computed the same way.
    predicted_rgb = output["reconstruct_rgb"]
    rgb_mse = F.mse_loss(predicted_rgb, target_rgb)

    # Only the RGB contribution is scaled by `weight`.
    weighted_rgb = weight * rgb_mse
    combined = raw_mse + weighted_rgb
    return combined, raw_mse, rgb_mse