|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import torch |
|
|
from audiocraft.models.watermark import AudioSeal |
|
|
from tests.common_utils.wav_utils import get_white_noise |
|
|
|
|
|
|
|
|
class TestWatermarkModel:
    """Smoke tests for the pretrained AudioSeal watermarking model."""

    def test_base(self):
        """Embed a random 16-bit message into white noise and run detection.

        Loads the ``"base"`` pretrained AudioSeal checkpoint, watermarks one
        second of 16 kHz white noise with ``alpha=0.8``, runs the detector on
        the watermarked audio, and checks that the fraction of frames whose
        detector score exceeds 0.5 is a valid ratio in ``[0, 1]``.
        """
        sr = 16_000
        duration = 1.0
        # One channel of white noise; the extra leading dim presumably makes
        # the layout (batch, channels, time) — TODO confirm against
        # AudioSeal's forward signature.
        wav = get_white_noise(1, int(sr * duration)).unsqueeze(0)
        wm = AudioSeal.get_pretrained(name="base")

        # Random 16-bit binary payload, one message for the single batch item.
        secret_message = torch.randint(0, 2, (1, 16), dtype=torch.int32)
        watermarked_wav = wm(wav, message=secret_message, sample_rate=sr, alpha=0.8)
        result = wm.detect_watermark(watermarked_wav)

        # Channel 1 of the detector output is presumably the per-frame
        # "watermark present" score — TODO confirm. Averaging the thresholded
        # mask over all elements (instead of dividing the total hit count by
        # the time length alone) keeps the ratio in [0, 1] for any batch size.
        detected = torch.gt(result[:, 1, :], 0.5).float().mean()
        detect_prob = detected.cpu().item()

        # A ratio is trivially >= 0, so a lower bound alone can never fail;
        # bound it from above too so a malformed detector output is caught.
        assert 0.0 <= detect_prob <= 1.0
|
|
|