|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import logging |
|
|
import typing as tp |
|
|
from functools import partial |
|
|
import os |
|
|
from pathlib import Path |
|
|
|
|
|
import flashy |
|
|
from omegaconf import DictConfig |
|
|
import multiprocessing |
|
|
import numpy as np |
|
|
import torch |
|
|
import torch.nn as nn |
|
|
|
|
|
from . import base, builders |
|
|
from ..models.builders import get_watermark_model |
|
|
from ..modules.watermark import pad, mix |
|
|
|
|
|
from ..metrics.miou import calculate_miou |
|
|
from ..metrics.pesq import PesqMetric |
|
|
|
|
|
from ..utils import checkpoint |
|
|
from ..utils.audio_effects import ( |
|
|
compress_with_encodec, |
|
|
get_audio_effects, |
|
|
select_audio_effects, |
|
|
) |
|
|
from ..utils.samples.manager import SampleManager |
|
|
from ..data.audio import save_spectrograms |
|
|
from ..utils.utils import get_pool_executor |
|
|
|
|
|
from torchmetrics.audio.snr import ScaleInvariantSignalNoiseRatio |
|
|
from torchmetrics.audio.stoi import ShortTimeObjectiveIntelligibility |
|
|
|
|
|
|
|
|
if tp.TYPE_CHECKING: |
|
|
from ..models.watermark import WMModel |
|
|
|
|
|
|
|
|
def get_encodec_audio_effect(encodec_cfg: DictConfig, sr: int) -> tp.Dict:
    """Build the EnCodec-based compression augmentations.

    One augmentation is created for every entry of ``encodec_cfg.n_qs``. Each
    one is ``compress_with_encodec`` pre-bound to the loaded codec model, to
    that ``n_q`` value and to the sample rate ``sr``.

    This helper lives here rather than in ``audiocraft.utils.audio_effects``
    because it depends on ``audiocraft.solvers``, which sits one layer above
    ``audiocraft.utils``; keeping it here avoids a circular dependency for
    solvers that use ``audiocraft.utils.audio_effects`` for augmentation.

    Args:
        encodec_cfg (DictConfig): Exposes ``ckpt`` (handed to
            ``CompressionSolver.model_from_checkpoint``) and ``n_qs``
            (iterable of ``n_q`` values).
        sr (int): Sample rate forwarded to ``compress_with_encodec``.
    Returns:
        dict: Maps ``"encodec_nq=<n_q>"`` to the matching partial function.
    """
    # Imported lazily, at call time, for the dependency reason given above.
    from ..solvers.compression import CompressionSolver

    codec_model = CompressionSolver.model_from_checkpoint(encodec_cfg.ckpt)
    # NOTE(review): the codec is put in train mode here; confirm this is the
    # intended mode for a model only used as an augmentation.
    codec_model.train()

    effects: tp.Dict = {}
    for n_q in encodec_cfg.n_qs:
        effects[f"encodec_nq={n_q}"] = partial(
            compress_with_encodec,
            model=codec_model,
            n_q=n_q,
            sample_rate=sr,
        )
    return effects
|
|
|
|
|
|
|
|
def random_message(nbits: int, batch_size: int) -> torch.Tensor:
    """Draw one random binary message per batch item.

    Args:
        nbits (int): Number of bits in each message.
        batch_size (int): Number of messages to draw.
    Returns:
        torch.Tensor: Integer tensor of shape ``[batch_size, nbits]`` with
            values in {0, 1}, or an empty 1-D tensor when ``nbits`` is 0.
    """
    has_payload = nbits != 0
    if has_payload:
        return torch.randint(0, 2, (batch_size, nbits))
    # No payload bits: the result has an empty first dimension, which is what
    # `evaluate_augmentations` checks before computing a bit accuracy.
    return torch.tensor([])
|
|
|
|
|
|
|
|
class WatermarkSolver(base.StandardSolver): |
|
|
"""Solver for different watermarking models""" |
|
|
|
|
|
def __init__(self, cfg: DictConfig):
    """Set up losses, augmentations, the loss balancer and the spectrogram folder.

    Args:
        cfg (DictConfig): Solver configuration. If it has an ``fsdp`` section,
            its ``use`` flag must be falsy.
    Raises:
        AssertionError: If FSDP is requested in the configuration.
    """
    super().__init__(cfg)
    # Declarations for type checkers only; the values are assigned in
    # `run_epoch` (rng) and `build_model` (model).
    self.rng: torch.Generator
    self.model: WMModel
    fsdp_requested = hasattr(cfg, "fsdp") and getattr(cfg.fsdp, "use", False)
    assert not fsdp_requested, "FSDP not supported by WatermarkSolver."
    self._init_losses()
    self._init_augmentations()
    self.balancer = builders.get_balancer(self.loss_weights, self.cfg.balancer)
    # Folder under which `generate` writes one sub-folder of spectrograms per epoch.
    self.path_specs = os.path.join(self.folder, "spectrograms")
    os.makedirs(self.path_specs, exist_ok=True)
|
|
|
|
|
def _init_losses(self):
    """Instantiate the losses declared in ``cfg.losses`` and record their weights.

    Every ``name: weight`` entry of ``cfg.losses`` is routed as follows:
      - weight ``-1``: the entry is skipped.
      - name ``adv`` or ``feat``: the weight is recorded once per adversary
        under ``"<name>_<adversary name>"``.
      - weight ``> 0``: the loss is built with ``builders.get_loss``, stored in
        ``wm_losses`` (names starting with ``wm_``) or ``aux_losses``
        (all others), and its weight is recorded.
      - any other weight: the loss is built and stored in ``info_losses``
        without a weight.

    Sets ``adv_losses`` (registered as stateful), ``aux_losses``,
    ``info_losses``, ``wm_losses`` and ``loss_weights`` on the solver.
    """
    assert hasattr(self.cfg, "losses") and isinstance(
        self.cfg.losses, (DictConfig, tp.Mapping)
    ), "WatermarkSolver must declare training losses in the config"

    self.adv_losses = builders.get_adversarial_losses(self.cfg)
    self.register_stateful("adv_losses")

    self.aux_losses = nn.ModuleDict()
    self.info_losses = nn.ModuleDict()
    self.wm_losses = nn.ModuleDict()

    weights = {}
    for loss_name, weight in self.cfg.losses.items():
        if weight == -1:
            continue

        if loss_name in ("adv", "feat"):
            # One weight entry per adversary, keyed like the losses of `run_step`.
            for adv_name in self.adv_losses.keys():
                weights[f"{loss_name}_{adv_name}"] = weight
            continue

        is_weighted = weight > 0
        criterion = builders.get_loss(loss_name, self.cfg).to(self.device)
        if not is_weighted:
            # Reported only: no weight is recorded for these.
            self.info_losses[loss_name] = criterion
            continue

        registry = self.wm_losses if loss_name[:3] == "wm_" else self.aux_losses
        registry[loss_name] = criterion
        weights[loss_name] = weight

    self.loss_weights = weights
|
|
|
|
|
def _init_augmentations(self): |
|
|
if not hasattr(self.cfg, "aug_weights") or not hasattr( |
|
|
self.cfg, "audio_effects" |
|
|
): |
|
|
return |
|
|
|
|
|
aug_weights = {} |
|
|
cfg_audio_effects = dict(self.cfg.audio_effects) |
|
|
|
|
|
|
|
|
|
|
|
encodec_cfg = cfg_audio_effects.pop("encodec", None) |
|
|
if encodec_cfg: |
|
|
encodec_effects = get_encodec_audio_effect( |
|
|
encodec_cfg, self.cfg.sample_rate |
|
|
) |
|
|
for aug_name in encodec_effects.keys(): |
|
|
aug_weights[aug_name] = getattr(self.cfg.aug_weights, "encodec", -1) |
|
|
else: |
|
|
encodec_effects = {} |
|
|
|
|
|
other_effects = get_audio_effects(self.cfg) |
|
|
for name in other_effects.keys(): |
|
|
aug_weights[name] = self.cfg.aug_weights.get(name, -1) |
|
|
|
|
|
self.aug_weights = aug_weights |
|
|
self.augmentations = {**encodec_effects, **other_effects} |
|
|
|
|
|
@property
def best_metric_name(self) -> tp.Optional[str]:
    """Name of the metric designating the best model; none for this solver.

    NOTE(review): presumably read by the base solver to select the best
    state -- confirm in ``base.StandardSolver``.

    Returns:
        None: Always.
    """
    return None
|
|
|
|
|
def build_model(self):
    """Create the watermark model and its optimizer, then register them.

    The model and the optimizer are registered as stateful; the model is
    additionally registered for best-state tracking and for EMA.
    """
    self.model = get_watermark_model(self.cfg)

    optim_cfg = self.cfg.optim
    self.optimizer = builders.get_optimizer(self.model.parameters(), optim_cfg)

    self.register_stateful("model", "optimizer")
    self.register_best_state("model")
    self.register_ema("model")
|
|
|
|
|
def build_dataloaders(self):
    """Create the audio dataloaders from the config.

    The result of ``builders.get_audio_datasets`` is stored in
    ``self.dataloaders``; `evaluate` and `generate` index it by stage name.
    """
    loaders = builders.get_audio_datasets(self.cfg)
    self.dataloaders = loaders
|
|
|
|
|
def show(self):
    """Log a summary of the watermark model.

    The losses are not printed yet: only a placeholder message is logged
    after the model summary.
    """
    self.log_model_summary(self.model)
    # Placeholder until the losses are actually reported (typo "Sould" fixed).
    self.logger.info("Should print losses here:")
|
|
|
|
|
def crop(
    self, signal: torch.Tensor, watermark: torch.Tensor
) -> tp.Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Randomly edit the signal and watermark to train watermark localization.

    A single uniform draw ``p`` selects at most one transformation, using the
    probabilities of ``cfg.crop``:
      - zero padding (probability ``pad_prob``): the signal and the watermark
        are zeroed before a random start and after a random end, or on the
        complementary region half of the time.
      - shuffle (probability ``shuffle_prob``): random windows are masked, and
        where the mask is 0 the signal is replaced by the same positions of
        batch items drawn with replacement.
      - crop (probability ``prob``): random windows are masked; only the
        watermark is removed there.
      - otherwise nothing is changed and the mask is all ones.
    For shuffle and crop, ``cfg.crop.size`` is the masked fraction of the
    signal, split into 1 to ``cfg.crop.max_n_windows`` windows; the mask is
    inverted half of the time.

    Both tensors are indexed as 3-D (``mask[:, :, a:b]``).
    ``watermark`` is always modified in place, and ``signal`` is modified in
    place in the zero-padding case, so callers must use the returned tensors.

    Args:
        signal (torch.Tensor): Clean signal.
        watermark (torch.Tensor): The watermark to apply on the signal.
    Returns:
        signal (torch.Tensor): Modified signal.
        watermark (torch.Tensor): Modified watermark.
        mask (torch.Tensor): 1 where the watermark is kept, 0 elsewhere; same
            shape as ``watermark``.
    Raises:
        AssertionError: If the three probabilities sum to more than 1.
    """
    assert (
        self.cfg.crop.prob + self.cfg.crop.shuffle_prob + self.cfg.crop.pad_prob
        <= 1
    ), (
        f"The sum of the probabilities {self.cfg.crop.prob=} "
        f"{self.cfg.crop.shuffle_prob=} {self.cfg.crop.pad_prob=} "
        "should be at most 1"
    )
    crop_cfg = self.cfg.crop
    length = watermark.size(-1)
    mask = torch.ones_like(watermark)
    p = torch.rand(1)
    if p < crop_cfg.pad_prob:
        # Zero padding: keep a middle segment, or its complement half of the time.
        start = int(torch.rand(1) * 0.33 * length)
        finish = int((0.66 + torch.rand(1) * 0.33) * length)
        mask[:, :, :start] = 0
        mask[:, :, finish:] = 0
        if torch.rand(1) > 0.5:
            mask = 1 - mask
        signal *= mask

    elif p < crop_cfg.prob + crop_cfg.pad_prob + crop_cfg.shuffle_prob:
        # Build a windowed mask, shared by the crop and shuffle cases.
        mask_size = round(length * crop_cfg.size)
        n_windows = int(
            torch.randint(1, crop_cfg.max_n_windows + 1, (1,)).item()
        )
        window_size = int(mask_size / n_windows)
        # `torch.randint` needs high > low: when a window covers the whole
        # signal the only valid start is 0, so clamp the bound to 1.
        start_bound = max(length - window_size, 1)
        for _ in range(n_windows):
            mask_start = int(torch.randint(0, start_bound, (1,)).item())
            mask[:, :, mask_start: mask_start + window_size] = 0

        if torch.rand(1) > 0.5:
            mask = 1 - mask

        if p < crop_cfg.pad_prob + crop_cfg.shuffle_prob:
            # Shuffle: fill the masked-out parts with audio from other items.
            signal_cloned = signal.clone().detach()
            shuffle_idx = torch.randint(0, signal.size(0), (signal.size(0),))
            signal = signal * mask + signal_cloned[shuffle_idx] * (1 - mask)

    watermark *= mask
    return signal, watermark, mask
|
|
|
|
|
def run_step(self, idx: int, batch: torch.Tensor, metrics: dict):
    """Perform one training or valid step on a given batch.

    The batch is watermarked with a random message, edited by `crop`, then
    scored by the adversarial, auxiliary and watermark losses. When
    ``self.is_training`` is set, gradients are accumulated from both loss
    groups and one optimizer step is taken. Informative losses, PESQ and the
    peak CUDA memory are reported in every case.

    Args:
        idx (int): Step index; not read by this method.
        batch (torch.Tensor): Audio batch; moved to ``self.device``.
        metrics (dict): Metrics dictionary, updated in place.
    Returns:
        dict: The same ``metrics`` object, filled with this step's values.
    """
    x = batch.to(self.device)
    y = x.clone()
    nbits = getattr(self.model, "nbits")
    message = random_message(nbits, y.shape[0]).to(self.device)
    watermark = self.model.get_watermark(x, message=message)
    # `crop` may edit both tensors; `mask` marks where the watermark is kept.
    y, watermark, mask = self.crop(y, watermark)

    y_wm = y + watermark

    if (
        self.cfg.losses.adv != 0 or self.cfg.losses.feat != 0
    ) and self.is_training:
        # Adversary update, run on a random subset of the steps: the draw
        # passes with probability about 1 / cfg.adversarial.every.
        d_losses: dict = {}
        if (
            len(self.adv_losses) > 0
            and torch.rand(1, generator=self.rng).item()
            <= 1 / self.cfg.adversarial.every
        ):
            for adv_name, adversary in self.adv_losses.items():
                disc_loss = adversary.train_adv(y_wm, y)
                d_losses[f"d_{adv_name}"] = disc_loss
            metrics["d_loss"] = torch.sum(torch.stack(list(d_losses.values())))
        metrics.update(d_losses)

    # `balanced_losses` go through the balancer; `other_losses` (watermark
    # losses) are weighted and back-propagated by hand further down.
    balanced_losses: dict = {}
    other_losses: dict = {}

    # Adversarial and feature losses of each adversary.
    if self.cfg.losses.adv != 0 or self.cfg.losses.feat != 0:
        for adv_name, adversary in self.adv_losses.items():
            adv_loss, feat_loss = adversary(y_wm, y)
            balanced_losses[f"adv_{adv_name}"] = adv_loss
            balanced_losses[f"feat_{adv_name}"] = feat_loss

    # Auxiliary losses comparing the watermarked audio to the reference.
    for loss_name, criterion in self.aux_losses.items():
        loss = criterion(y_wm, y)
        balanced_losses[loss_name] = loss

    # Select the augmentations for this step; any mode other than "all" is
    # mapped to "weighted".
    mode = "all" if self.cfg.select_aug_mode == "all" else "weighted"
    selected_augs = select_audio_effects(
        self.augmentations,
        self.aug_weights,
        mode=mode,
        max_length=self.cfg.n_max_aug,
    )
    N_augs = len(selected_augs)
    for (
        augmentation_name,
        augmentation_method,
    ) in selected_augs.items():
        # Clean and watermarked audio are stacked along the batch dimension
        # so the augmentation is called once for both, then split again.
        y_y_wm = torch.cat([y, y_wm], dim=0)
        aug_cat, mask_aug = augmentation_method(y_y_wm, mask=mask)
        aug_y = aug_cat[: y.size(0)]
        aug_y_wm = aug_cat[y.size(0):]
        positive = self.model.detect_watermark(aug_y_wm)
        negative = self.model.detect_watermark(aug_y)
        for loss_name, criterion in self.wm_losses.items():
            loss = criterion(positive, negative, mask_aug, message)
            other_losses[f"{loss_name}_{augmentation_name}"] = loss

    metrics.update(balanced_losses)
    metrics.update(other_losses)
    if self.is_training:
        # Weighted sum of the watermark losses. The keys carry the
        # augmentation name as a suffix, hence the substring tests; the
        # detection and "wm_mb" weights are divided by the number of
        # selected augmentations.
        other_loss = torch.tensor(0.0, device=self.device)
        for name, o_loss in other_losses.items():
            if "wm_detection" in name:
                other_loss += (self.loss_weights["wm_detection"] / N_augs) * o_loss
            elif "wm_mb" in name:
                other_loss += (self.loss_weights["wm_mb"] / N_augs) * o_loss
            else:
                other_loss += self.loss_weights[name] * o_loss
        if other_loss.requires_grad:
            # The graph is kept because the balancer backward below goes
            # through `y_wm` again.
            other_loss.backward(retain_graph=True)
            # L2 norm of the model gradients after the watermark losses only.
            ratio1 = sum(
                p.grad.data.norm(p=2).pow(2)
                for p in self.model.parameters()
                if p.grad is not None
            )
            assert isinstance(ratio1, torch.Tensor)
            metrics["ratio1"] = ratio1.sqrt()

        # Backward of the balanced losses with respect to `y_wm`.
        # NOTE(review): presumably returns the effective training loss --
        # confirm against the balancer implementation.
        metrics["g_loss"] = self.balancer.backward(balanced_losses, y_wm)

        metrics.update(self.balancer.metrics)
        # L2 norm of the model gradients after both backward passes.
        ratio2 = sum(
            p.grad.data.norm(p=2).pow(2)
            for p in self.model.parameters()
            if p.grad is not None
        )
        assert isinstance(ratio2, torch.Tensor)
        metrics["ratio2"] = ratio2.sqrt()

        # Optimizer step, with optional gradient clipping.
        flashy.distrib.sync_model(self.model)
        if self.cfg.optim.max_norm:
            torch.nn.utils.clip_grad_norm_(
                self.model.parameters(), self.cfg.optim.max_norm
            )

        self.optimizer.step()
        self.optimizer.zero_grad()

    # Reported-only values: no gradient is needed.
    info_losses: dict = {}
    with torch.no_grad():
        for loss_name, criterion in self.info_losses.items():
            loss = criterion(y_wm, y)
            info_losses[loss_name] = loss

        metrics["pesq"] = tensor_pesq(y_wm, y, sr=self.cfg.sample_rate)

        # Peak CUDA memory, in units of 1e9 bytes.
        metrics["max_mem"] = torch.cuda.max_memory_allocated() / 1e9

    metrics.update(info_losses)
    if self.cfg.losses.adv != 0 or self.cfg.losses.feat != 0:
        # Aggregate every metric whose key starts with "adv" / "feat" into a
        # single "adv" / "feat" entry.
        adv_losses = [
            loss
            for loss_name, loss in metrics.items()
            if loss_name.startswith("adv")
        ]
        if len(adv_losses) > 0:
            metrics["adv"] = torch.sum(torch.stack(adv_losses))
        feat_losses = [
            loss
            for loss_name, loss in metrics.items()
            if loss_name.startswith("feat")
        ]
        if len(feat_losses) > 0:
            metrics["feat"] = torch.sum(torch.stack(feat_losses))

    return metrics
|
|
|
|
|
def run_epoch(self):
    """Run one epoch with a freshly seeded random generator.

    ``self.rng`` is re-created and seeded with ``1234 + self.epoch`` before
    delegating to the parent ``run_epoch``. `run_step` draws from it to decide
    when the adversaries are updated.
    """
    self.rng = torch.Generator()
    epoch_seed = 1234 + self.epoch
    self.rng.manual_seed(epoch_seed)

    super().run_epoch()
|
|
|
|
|
def evaluate(self) -> dict:
    """Evaluate stage: audio quality, detection under augmentation, localization.

    For every batch of the ``evaluate`` dataloader, the audio is watermarked
    with a random message and the following jobs are submitted to a worker pool:
      - `evaluate_audio_watermark` on the watermarked / original audio;
      - `evaluate_augmentations` for every entry of ``self.augmentations``;
      - `evaluate_localizations` on `mix` outputs (window sizes 0.1 to 0.9,
        with and without shuffle) and on `pad` outputs (default and central).
    Job results are averaged, then averaged across processes with
    ``flashy.distrib.average_metrics``.

    When ``cfg.select_aug_mode == "use_eval_acc"``, each augmentation weight
    other than ``-1`` is reset to ``max(1 - accuracy, 0.05)``.

    All model calls run under ``torch.no_grad()``, and the message is moved
    to ``self.device`` as in `run_step`.

    NOTE: ``metrics`` is only bound once at least one job result has been
    read, so an empty dataloader is not supported.

    Returns:
        dict: Averaged evaluation metrics.
    """
    self.model.eval()
    evaluate_stage_name = str(self.current_stage)

    loader = self.dataloaders["evaluate"]
    updates = len(loader)
    lp = self.log_progress(
        f"{evaluate_stage_name} inference",
        loader,
        total=updates,
        updates=self.log_updates,
    )
    average = flashy.averager()

    pendings = []
    ctx = multiprocessing.get_context("spawn")
    with get_pool_executor(self.cfg.evaluate.num_workers, mp_context=ctx) as pool:

        def submit_localization(mixed, true_predictions, tag):
            # Detect on the edited audio and queue the localization scoring.
            model_predictions = self.model.detect_watermark(mixed)
            pendings.append(
                pool.submit(
                    evaluate_localizations,
                    model_predictions.cpu(),
                    true_predictions.cpu(),
                    tag,
                )
            )

        for batch in lp:
            x = batch.to(self.device)
            # Pure inference: no autograd graph is needed for any model call
            # of this batch, and the tensors sent to the pool carry no grad.
            with torch.no_grad():
                message = random_message(self.model.nbits, x.shape[0]).to(
                    self.device
                )
                watermark = self.model.get_watermark(x, message)
                x_wm = x + watermark

                # Audio quality of the watermarked signal.
                pendings.append(
                    pool.submit(
                        evaluate_audio_watermark,
                        x_wm.cpu(),
                        batch.cpu(),
                        self.cfg,
                    )
                )

                # Detection after each augmentation (all of them are evaluated).
                for (
                    augmentation_name,
                    augmentation_method,
                ) in self.augmentations.items():
                    aug_positive = self.model.detect_watermark(
                        augmentation_method(x_wm)
                    )
                    aug_negative = self.model.detect_watermark(
                        augmentation_method(x)
                    )
                    pendings.append(
                        pool.submit(
                            evaluate_augmentations,
                            aug_positive.cpu(),
                            aug_negative.cpu(),
                            augmentation_name,
                            message.cpu(),
                        )
                    )

                # Localization on mixes of clean and watermarked audio.
                for window_size in np.linspace(0.1, 0.9, 9):
                    mixed, true_predictions = mix(x, x_wm, window_size=window_size)
                    submit_localization(
                        mixed, true_predictions, f"crop_{window_size:0.1f}"
                    )
                    mixed, true_predictions = mix(
                        x, x_wm, window_size=window_size, shuffle=True
                    )
                    submit_localization(
                        mixed, true_predictions, f"shuffle_{window_size:0.1f}"
                    )

                # Localization on padded watermarked audio.
                mixed, true_predictions = pad(x_wm)
                submit_localization(mixed, true_predictions, "padding")
                mixed, true_predictions = pad(x_wm, central=True)
                submit_localization(mixed, true_predictions, "central_padding")

        metrics_lp = self.log_progress(
            f"{evaluate_stage_name} metrics", pendings, updates=self.log_updates
        )
        for pending in metrics_lp:
            metrics = pending.result()
            metrics = average(metrics)

    metrics = flashy.distrib.average_metrics(metrics, len(loader))
    if self.cfg.select_aug_mode == "use_eval_acc":
        # The better an augmentation is detected, the lower its weight; the
        # 0.05 floor keeps it from reaching zero. A weight of -1 is left as is.
        for name in self.augmentations:
            if self.aug_weights[name] != -1:
                self.aug_weights[name] = max(1 - metrics[f"aug_{name}_acc"], 0.05)
    return metrics
|
|
|
|
|
def generate(self):
    """Generate stage: save watermarked samples and first-batch spectrograms.

    Every batch of the ``generate`` dataloader is watermarked with a random
    message and handed to the ``SampleManager`` together with its reference.
    For the first batch, on the rank-zero process only, one spectrogram PDF
    per item (reference, watermarked audio, watermark) is written under
    ``<path_specs>/epoch=<epoch>/spec_<i>.pdf``.

    Each tensor is copied to the CPU once per batch rather than once per
    sample, and the message is moved to ``self.device`` as in `run_step`.
    """
    self.model.eval()
    sample_manager = SampleManager(self.xp, map_reference_to_sample_id=True)
    generate_stage_name = str(self.current_stage)

    loader = self.dataloaders["generate"]
    updates = len(loader)
    lp = self.log_progress(
        generate_stage_name, loader, total=updates, updates=self.log_updates
    )
    path_dir = os.path.join(self.path_specs, f"epoch={self.epoch}")
    os.makedirs(path_dir, exist_ok=True)
    first_batch = True
    for batch in lp:
        reference, _ = batch
        reference = reference.to(self.device)
        with torch.no_grad():
            message = random_message(self.model.nbits, reference.shape[0]).to(
                self.device
            )
            watermark = self.model.get_watermark(reference, message)
            x_wm = reference + watermark

        reference = reference.cpu()
        x_wm = x_wm.cpu()
        sample_manager.add_samples(x_wm, self.epoch, ground_truth_wavs=reference)
        if first_batch and flashy.distrib.is_rank_zero():
            watermark = watermark.cpu()
            for i in range(reference.size(0)):
                ys = [
                    wav[i].squeeze(0).numpy()
                    for wav in (reference, x_wm, watermark)
                ]
                path = os.path.join(path_dir, f"spec_{i}.pdf")
                save_spectrograms(
                    ys,
                    names=["Ground Truth", "Audio Watermarked", "Watermark"],
                    sr=self.cfg.sample_rate,
                    path=path,
                )
        first_batch = False
    flashy.distrib.barrier()
|
|
|
|
|
def load_from_pretrained(self, name: str) -> dict:
    """Reject any request for a pretrained model.

    Args:
        name (str): Requested pretrained model name; not read.
    Raises:
        ValueError: Always.
    """
    error = ValueError("No pretrained model")
    raise error
|
|
|
|
|
@staticmethod
def model_from_checkpoint(
    checkpoint_path: tp.Union[Path, str],
    device: tp.Union[torch.device, str] = "cpu",
) -> "WMModel":
    """Instantiate a WatermarkModel from a given checkpoint path or dora sig.

    The model is built from the config stored in the checkpoint
    (``state["xp.cfg"]``), loaded with the weights of
    ``state["best_state"]["model"]`` and returned in eval mode.

    Args:
        checkpoint_path (Path or str): Path to checkpoint or dora sig from
            where the checkpoint is resolved.
        device (torch.device or str): Device on which the model is loaded.
    Returns:
        WMModel: The loaded watermarking model.
    Raises:
        AssertionError: If the checkpoint cannot be resolved or loaded, has
            no config or no non-empty ``best_state``, or is an exported
            checkpoint.
    """
    checkpoint_path = str(checkpoint_path)
    logger = logging.getLogger(__name__)
    logger.info(f"Loading WatermarkModel from checkpoint: {checkpoint_path}")
    _checkpoint_path = checkpoint.resolve_checkpoint_path(
        checkpoint_path, use_fsdp=False
    )
    assert (
        _checkpoint_path is not None
    ), f"Could not resolve WatermarkModel checkpoint path: {checkpoint_path}"
    state = checkpoint.load_checkpoint(_checkpoint_path)
    assert (
        state is not None and "xp.cfg" in state
    ), f"Could not load WatermarkModel from ckpt: {checkpoint_path}"
    cfg = state["xp.cfg"]
    # Store the device by name: a `torch.device` object is not a value type
    # an OmegaConf config accepts, whereas its string form is.
    # NOTE(review): assumes `xp.cfg` is an OmegaConf config -- confirm.
    cfg.device = str(device)
    watermarking_model = get_watermark_model(cfg).to(device)

    assert "best_state" in state and state["best_state"] != {}
    assert (
        "exported" not in state
    ), "When loading an exported checkpoint, use the //pretrained/ prefix."
    watermarking_model.load_state_dict(state["best_state"]["model"])
    watermarking_model.eval()
    logger.info("Watermarking model loaded!")
    return watermarking_model
|
|
|
|
|
|
|
|
def evaluate_localizations(predictions, true_predictions, name):
    """Score watermark localization against the ground truth.

    Only index 1 of the second dimension of both tensors is used.

    Args:
        predictions (torch.Tensor): Detector outputs; thresholded at 0.5.
        true_predictions (torch.Tensor): Ground-truth localization.
        name (str): Suffix of the metric keys.
    Returns:
        dict: ``localization_acc_<name>`` (fraction of positions where the
            thresholded prediction equals the ground truth, as a float) and
            ``localization_miou_<name>`` (result of ``calculate_miou``).
    """
    scores = predictions[:, 1, :]
    targets = true_predictions[:, 1, :]
    agreement = (scores > 0.5) == targets
    return {
        f"localization_acc_{name}": agreement.float().mean().item(),
        f"localization_miou_{name}": calculate_miou(scores, targets),
    }
|
|
|
|
|
|
|
|
def evaluate_augmentations(
    positive: torch.Tensor,
    negative: torch.Tensor,
    augmentation_name: str,
    message: torch.Tensor,
) -> dict:
    """Compute detection metrics for one augmentation.

    Args:
        positive (torch.Tensor): Detector outputs on augmented watermarked audio.
        negative (torch.Tensor): Detector outputs on augmented clean audio.
        augmentation_name (str): Name of the augmentation applied beforehand;
            used in the metric keys.
        message (torch.Tensor): Embedded message; the bit accuracy is skipped
            when its first dimension is empty.
    Returns:
        dict: ``aug_<name>_acc``, ``aug_<name>_fpr``, ``aug_<name>_fnr``,
            optionally ``aug_<name>_bit_acc``, plus ``all_aug_acc``: the same
            accuracy under a key that does not depend on the augmentation.
    """
    prefix = f"aug_{augmentation_name}"
    metrics = {
        f"{prefix}_acc": compute_accuracy(positive, negative),
        f"{prefix}_fpr": compute_FPR(negative),
        f"{prefix}_fnr": compute_FNR(positive),
    }
    has_message = message.shape[0] != 0
    if has_message:
        metrics[f"{prefix}_bit_acc"] = compute_bit_acc(positive, message)

    # NOTE(review): presumably pooled over all augmentations when the caller
    # averages the job results -- confirm in `evaluate`.
    metrics["all_aug_acc"] = compute_accuracy(positive, negative)

    return metrics
|
|
|
|
|
|
|
|
def evaluate_audio_watermark(
    y_pred: torch.Tensor,
    y: torch.Tensor,
    cfg: DictConfig,
) -> dict:
    """Compute audio quality metrics of watermarked audio against its reference.

    Defined at module level so that it can be submitted to a worker pool.

    Args:
        y_pred (torch.Tensor): Watermarked audio.
        y (torch.Tensor): Reference audio.
        cfg (DictConfig): Reads ``evaluate.metrics.visqol``,
            ``metrics.visqol`` and ``sample_rate``.
    Returns:
        dict: ``sisnr``, ``stoi`` and ``pesq``, plus ``visqol`` when
            ``cfg.evaluate.metrics.visqol`` is truthy.
    """
    metrics = {}
    if cfg.evaluate.metrics.visqol:
        visqol_metric = builders.get_visqol(cfg.metrics.visqol)
        metrics["visqol"] = visqol_metric(y_pred, y, cfg.sample_rate)
    sisnr_metric = ScaleInvariantSignalNoiseRatio().to(y.device)
    stoi_metric = ShortTimeObjectiveIntelligibility(fs=cfg.sample_rate)
    metrics.update(
        sisnr=sisnr_metric(y_pred, y),
        stoi=stoi_metric(y_pred, y),
        pesq=tensor_pesq(y_pred, y, sr=cfg.sample_rate),
    )
    return metrics
|
|
|
|
|
|
|
|
def tensor_pesq(y_pred: torch.Tensor, y: torch.Tensor, sr: int):
    """Return the PESQ score of ``y_pred`` against ``y`` as a Python number.

    Args:
        y_pred (torch.Tensor): Audio to score.
        y (torch.Tensor): Reference audio.
        sr (int): Sample rate passed to ``PesqMetric``.
    Returns:
        The value of ``PesqMetric(sr)(y_pred, y)`` converted with ``.item()``.
    """
    pesq_metric = PesqMetric(sr)
    score = pesq_metric(y_pred, y)
    return score.item()
|
|
|
|
|
|
|
|
def compute_accuracy(positive, negative):
    """Detection accuracy over watermarked and clean samples.

    A watermarked sample counts as correct when the time average of channel 1
    exceeds 0.5; a clean sample counts as correct when the time average of
    channel 0 exceeds 0.5. The count is divided by the total number of
    samples, so the two batches may differ in size (for equal sizes this is
    the same as dividing by twice the batch size).

    Args:
        positive (torch.Tensor): Detector outputs on watermarked audio,
            indexed as ``[batch, channel, time]``.
        negative (torch.Tensor): Detector outputs on clean audio, same layout.
    Returns:
        torch.Tensor: Scalar accuracy.
    """
    detected = (positive[:, 1, :].mean(dim=1) > 0.5).sum()
    rejected = (negative[:, 0, :].mean(dim=1) > 0.5).sum()
    total = positive.size(0) + negative.size(0)
    return (detected + rejected) / total
|
|
|
|
|
|
|
|
def compute_FPR(negative):
    """False positive rate on clean samples.

    A clean sample is a false positive when the time average of channel 1
    exceeds 0.5.

    Args:
        negative (torch.Tensor): Detector outputs on clean audio, indexed as
            ``[batch, channel, time]``.
    Returns:
        torch.Tensor: Scalar rate (false positives / batch size).
    """
    flagged = negative[:, 1, :].mean(dim=1).gt(0.5)
    return flagged.sum() / negative.size(0)
|
|
|
|
|
|
|
|
def compute_FNR(positive):
    """False negative rate on watermarked samples.

    A watermarked sample is a false negative when the time average of
    channel 0 exceeds 0.5.

    Args:
        positive (torch.Tensor): Detector outputs on watermarked audio,
            indexed as ``[batch, channel, time]``.
    Returns:
        torch.Tensor: Scalar rate (false negatives / batch size).
    """
    missed = positive[:, 0, :].mean(dim=1).gt(0.5)
    return missed.sum() / positive.size(0)
|
|
|
|
|
|
|
|
def _bit_acc(decoded, original):
    """Return the fraction of entries where ``decoded`` equals ``original``."""
    bit_acc = (decoded == original).float().mean()
    return bit_acc


def compute_bit_acc(positive, original, mask=None):
    """Compute bit accuracy.

    Each bit is decoded as 1 when its detector output, averaged over time, is
    positive. With a mask, only the time steps where the mask equals 1 are
    taken into account. The decision uses the sign of the masked sum, which
    equals the sign of the masked mean, so the number of kept time steps may
    differ from one batch item to the next. A row with no kept step decodes
    to 0.

    Args:
        positive: detector outputs [bsz, 2+nbits, time_steps]
        original: original message (0 or 1) [bsz, nbits]
        mask: mask of the watermark [bsz, 1, time_steps]
    Returns:
        torch.Tensor: Scalar fraction of correctly decoded bits.
    """
    # The first two channels are detection scores; the rest are message bits.
    decoded = positive[:, 2:, :]
    if mask is not None:
        # Zero the masked-out steps instead of selecting and reshaping, which
        # breaks when rows keep different numbers of steps.
        kept = mask == 1
        decoded = decoded.masked_fill(~kept, 0.0).sum(dim=-1) > 0
    else:
        decoded = decoded.mean(dim=-1) > 0
    return _bit_acc(decoded, original)
|
|
|