xieli committed on
Commit
822d1fc
·
1 Parent(s): 91b9368

feat: update hifigan

Browse files
stepvocoder/cosyvoice2/hifigan/f0_predictor.py CHANGED
@@ -13,8 +13,10 @@
13
  # limitations under the License.
14
  import torch
15
  import torch.nn as nn
16
- from torch.nn.utils import weight_norm
17
-
 
 
18
 
19
  class ConvRNNF0Predictor(nn.Module):
20
  def __init__(self,
 
13
  # limitations under the License.
14
  import torch
15
  import torch.nn as nn
16
+ try:
17
+ from torch.nn.utils.parametrizations import weight_norm
18
+ except ImportError:
19
+ from torch.nn.utils import weight_norm
20
 
21
  class ConvRNNF0Predictor(nn.Module):
22
  def __init__(self,
stepvocoder/cosyvoice2/hifigan/generator.py CHANGED
@@ -25,7 +25,10 @@ from torch.nn import ConvTranspose1d
25
  from torch.nn.utils import remove_weight_norm
26
  from torch.nn.utils import weight_norm
27
  from torch.distributions.uniform import Uniform
28
-
 
 
 
29
  from stepvocoder.cosyvoice2.hifigan.activation import Snake
30
  from stepvocoder.cosyvoice2.utils.common import get_padding, init_weights
31
 
@@ -133,7 +136,7 @@ class SineGen(torch.nn.Module):
133
 
134
  def _f02uv(self, f0):
135
  # generate uv signal
136
- uv = (f0 > self.voiced_threshold).type(f0.dtype)
137
  return uv
138
 
139
  @torch.no_grad()
@@ -142,13 +145,14 @@ class SineGen(torch.nn.Module):
142
  :param f0: [B, 1, sample_len], Hz
143
  :return: [B, 1, sample_len]
144
  """
145
- F_mat = torch.zeros((f0.size(0), self.harmonic_num + 1, f0.size(-1)),dtype=f0.dtype).to(f0.device)
 
146
  for i in range(self.harmonic_num + 1):
147
  F_mat[:, i: i + 1, :] = f0 * (i + 1) / self.sampling_rate
148
 
149
  theta_mat = 2 * np.pi * (torch.cumsum(F_mat, dim=-1) % 1)
150
  u_dist = Uniform(low=-np.pi, high=np.pi)
151
- phase_vec = u_dist.sample(sample_shape=(f0.size(0), self.harmonic_num + 1, 1)).to(F_mat.device).to(f0.dtype)
152
  phase_vec[:, 0, :] = 0
153
  # generate sine waveforms
154
  sine_waves = self.sine_amp * torch.sin(theta_mat + phase_vec)
@@ -211,6 +215,172 @@ class SourceModuleHnNSF(torch.nn.Module):
211
  sine_wavs = sine_wavs.transpose(1, 2)
212
  uv = uv.transpose(1, 2)
213
  sine_merge = self.l_tanh(self.l_linear(sine_wavs))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
214
  # source for noise branch, in the same shape as uv
215
  noise = torch.randn_like(uv) * self.sine_amp / 3
216
  return sine_merge, noise, uv
@@ -252,7 +422,10 @@ class HiFTGenerator(nn.Module):
252
 
253
  self.num_kernels = len(resblock_kernel_sizes)
254
  self.num_upsamples = len(upsample_rates)
255
- self.m_source = SourceModuleHnNSF(
 
 
 
256
  sampling_rate=sampling_rate,
257
  upsample_scale=np.prod(upsample_rates) * istft_params["hop_len"],
258
  harmonic_num=nb_harmonics,
@@ -312,7 +485,7 @@ class HiFTGenerator(nn.Module):
312
  self.stft_window = torch.from_numpy(get_window("hann", istft_params["n_fft"], fftbins=True).astype(np.float32))
313
  self.f0_predictor = f0_predictor
314
 
315
- # for cuda graph
316
  self.use_cuda_graph = False
317
  self.graph = {}
318
  self.inference_buffers = {}
@@ -347,6 +520,25 @@ class HiFTGenerator(nn.Module):
347
  self.istft_params["n_fft"], window=self.stft_window.to(magnitude.device))
348
  return inverse_transform
349
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
350
  def decode_without_stft(self, x: torch.Tensor, s_stft: torch.Tensor) -> torch.Tensor:
351
  x = self.conv_pre(x)
352
  for i in range(self.num_upsamples):
@@ -373,24 +565,6 @@ class HiFTGenerator(nn.Module):
373
  x = self.conv_post(x)
374
  return x
375
 
376
- def decode(self, x: torch.Tensor, s: torch.Tensor = torch.zeros(1, 1, 0)) -> torch.Tensor:
377
- s_stft_real, s_stft_imag = self._stft(s.squeeze(1))
378
- s_stft = torch.cat([s_stft_real, s_stft_imag], dim=1).to(s.dtype)
379
- if self.use_cuda_graph and x.shape[-1] in self.graph:
380
- self.inference_buffers[x.shape[-1]]['static_inputs']['static_x'].copy_(x)
381
- self.inference_buffers[x.shape[-1]]['static_inputs']['static_s_stft'].copy_(s_stft)
382
- self.graph[x.shape[-1]].replay()
383
- x = self.inference_buffers[x.shape[-1]]['static_outputs']['static_output_x']
384
- else:
385
- x = self.decode_without_stft(x, s_stft)
386
- magnitude = torch.exp(x[:, :self.istft_params["n_fft"] // 2 + 1, :])
387
- phase = torch.sin(x[:, self.istft_params["n_fft"] // 2 + 1:, :]) # actually, sin is redundancy
388
- magnitude = magnitude.to(torch.float32)
389
- phase = phase.to(torch.float32)
390
- x = self._istft(magnitude, phase).to(x.dtype)
391
- x = torch.clamp(x, -self.audio_limit, self.audio_limit)
392
- return x
393
-
394
  def forward(
395
  self,
396
  batch: dict,
@@ -406,7 +580,7 @@ class HiFTGenerator(nn.Module):
406
  # mel+source->speech
407
  generated_speech = self.decode(x=speech_feat, s=s)
408
  return generated_speech, f0
409
-
410
  def _init_cuda_graph(self):
411
  self.use_cuda_graph = True
412
  dummy_param = next(self.parameters())
@@ -435,7 +609,7 @@ class HiFTGenerator(nn.Module):
435
  print(f"CUDA Graph initialized successfully for chunk generator")
436
 
437
  @torch.inference_mode()
438
- def inference(self, speech_feat: torch.Tensor, cache_source: torch.Tensor = None) -> torch.Tensor:
439
  # mel->f0
440
  f0 = self.f0_predictor(speech_feat)
441
  # f0->source
@@ -443,7 +617,7 @@ class HiFTGenerator(nn.Module):
443
  s, _, _ = self.m_source(s)
444
  s = s.transpose(1, 2)
445
  # use cache_source to avoid glitch
446
- if cache_source is not None:
447
  s[:, :, :cache_source.shape[2]] = cache_source
448
  generated_speech = self.decode(x=speech_feat, s=s)
449
- return generated_speech, s
 
25
  from torch.nn.utils import remove_weight_norm
26
  from torch.nn.utils import weight_norm
27
  from torch.distributions.uniform import Uniform
28
+ try:
29
+ from torch.nn.utils.parametrizations import weight_norm
30
+ except ImportError:
31
+ from torch.nn.utils import weight_norm
32
  from stepvocoder.cosyvoice2.hifigan.activation import Snake
33
  from stepvocoder.cosyvoice2.utils.common import get_padding, init_weights
34
 
 
136
 
137
  def _f02uv(self, f0):
138
  # generate uv signal
139
+ uv = (f0 > self.voiced_threshold).type(torch.float32)
140
  return uv
141
 
142
  @torch.no_grad()
 
145
  :param f0: [B, 1, sample_len], Hz
146
  :return: [B, 1, sample_len]
147
  """
148
+
149
+ F_mat = torch.zeros((f0.size(0), self.harmonic_num + 1, f0.size(-1))).to(f0.device)
150
  for i in range(self.harmonic_num + 1):
151
  F_mat[:, i: i + 1, :] = f0 * (i + 1) / self.sampling_rate
152
 
153
  theta_mat = 2 * np.pi * (torch.cumsum(F_mat, dim=-1) % 1)
154
  u_dist = Uniform(low=-np.pi, high=np.pi)
155
+ phase_vec = u_dist.sample(sample_shape=(f0.size(0), self.harmonic_num + 1, 1)).to(F_mat.device)
156
  phase_vec[:, 0, :] = 0
157
  # generate sine waveforms
158
  sine_waves = self.sine_amp * torch.sin(theta_mat + phase_vec)
 
215
  sine_wavs = sine_wavs.transpose(1, 2)
216
  uv = uv.transpose(1, 2)
217
  sine_merge = self.l_tanh(self.l_linear(sine_wavs))
218
+
219
+ # source for noise branch, in the same shape as uv
220
+ noise = torch.randn_like(uv) * self.sine_amp / 3
221
+ return sine_merge, noise, uv
222
+
223
+
224
+ class SineGen2(torch.nn.Module):
225
+ """ Definition of sine generator
226
+ SineGen(samp_rate, harmonic_num = 0,
227
+ sine_amp = 0.1, noise_std = 0.003,
228
+ voiced_threshold = 0,
229
+ flag_for_pulse=False)
230
+ samp_rate: sampling rate in Hz
231
+ harmonic_num: number of harmonic overtones (default 0)
232
+ sine_amp: amplitude of sine-waveform (default 0.1)
233
+ noise_std: std of Gaussian noise (default 0.003)
234
+ voiced_threshold: F0 threshold for U/V classification (default 0)
235
+ flag_for_pulse: this SineGen is used inside PulseGen (default False)
236
+ Note: when flag_for_pulse is True, the first time step of a voiced
237
+ segment is always sin(np.pi) or cos(0)
238
+ """
239
+
240
+ def __init__(self, samp_rate, upsample_scale, harmonic_num=0,
241
+ sine_amp=0.1, noise_std=0.003,
242
+ voiced_threshold=0,
243
+ flag_for_pulse=False):
244
+ super(SineGen2, self).__init__()
245
+ self.sine_amp = sine_amp
246
+ self.noise_std = noise_std
247
+ self.harmonic_num = harmonic_num
248
+ self.dim = self.harmonic_num + 1
249
+ self.sampling_rate = samp_rate
250
+ self.voiced_threshold = voiced_threshold
251
+ self.flag_for_pulse = flag_for_pulse
252
+ self.upsample_scale = upsample_scale
253
+
254
+ def _f02uv(self, f0):
255
+ # generate uv signal
256
+ uv = (f0 > self.voiced_threshold).type(torch.float32)
257
+ return uv
258
+
259
+ def _f02sine(self, f0_values):
260
+ """ f0_values: (batchsize, length, dim)
261
+ where dim indicates fundamental tone and overtones
262
+ """
263
+ # convert to F0 in rad. The integer part n can be ignored
264
+ # because 2 * np.pi * n doesn't affect phase
265
+ rad_values = (f0_values / self.sampling_rate) % 1
266
+
267
+ # initial phase noise (no noise for fundamental component)
268
+ rand_ini = torch.rand(f0_values.shape[0], f0_values.shape[2], device=f0_values.device)
269
+ rand_ini[:, 0] = 0
270
+ rad_values[:, 0, :] = rad_values[:, 0, :] + rand_ini
271
+
272
+ # instantaneous phase: sine[t] = sin(2 * pi * \sum_{i=1}^{t} rad[i])
273
+ if not self.flag_for_pulse:
274
+ rad_values = torch.nn.functional.interpolate(rad_values.transpose(1, 2),
275
+ scale_factor=1 / self.upsample_scale,
276
+ mode="linear").transpose(1, 2)
277
+
278
+ phase = torch.cumsum(rad_values, dim=1) * 2 * np.pi
279
+ phase = torch.nn.functional.interpolate(phase.transpose(1, 2) * self.upsample_scale,
280
+ scale_factor=self.upsample_scale, mode="linear").transpose(1, 2)
281
+ sines = torch.sin(phase)
282
+ else:
283
+ # If necessary, make sure that the first time step of every
284
+ # voiced segments is sin(pi) or cos(0)
285
+ # This is used for pulse-train generation
286
+
287
+ # identify the last time step in unvoiced segments
288
+ uv = self._f02uv(f0_values)
289
+ uv_1 = torch.roll(uv, shifts=-1, dims=1)
290
+ uv_1[:, -1, :] = 1
291
+ u_loc = (uv < 1) * (uv_1 > 0)
292
+
293
+ # get the instantaneous phase
294
+ tmp_cumsum = torch.cumsum(rad_values, dim=1)
295
+ # different batch needs to be processed differently
296
+ for idx in range(f0_values.shape[0]):
297
+ temp_sum = tmp_cumsum[idx, u_loc[idx, :, 0], :]
298
+ temp_sum[1:, :] = temp_sum[1:, :] - temp_sum[0:-1, :]
299
+ # stores the accumulation of i.phase within
300
+ # each voiced segments
301
+ tmp_cumsum[idx, :, :] = 0
302
+ tmp_cumsum[idx, u_loc[idx, :, 0], :] = temp_sum
303
+
304
+ # rad_values - tmp_cumsum: remove the accumulation of i.phase
305
+ # within the previous voiced segment.
306
+ i_phase = torch.cumsum(rad_values - tmp_cumsum, dim=1)
307
+
308
+ # get the sines
309
+ sines = torch.cos(i_phase * 2 * np.pi)
310
+ return sines
311
+
312
+ def forward(self, f0):
313
+ """ sine_tensor, uv = forward(f0)
314
+ input F0: tensor(batchsize=1, length, dim=1)
315
+ f0 for unvoiced steps should be 0
316
+ output sine_tensor: tensor(batchsize=1, length, dim)
317
+ output uv: tensor(batchsize=1, length, 1)
318
+ """
319
+ # fundamental component
320
+ fn = torch.multiply(f0, torch.FloatTensor([[range(1, self.harmonic_num + 2)]]).to(f0.device))
321
+
322
+ # generate sine waveforms
323
+ sine_waves = self._f02sine(fn) * self.sine_amp
324
+
325
+ # generate uv signal
326
+ uv = self._f02uv(f0)
327
+
328
+ # noise: for unvoiced should be similar to sine_amp
329
+ # std = self.sine_amp/3 -> max value ~ self.sine_amp
330
+ # . for voiced regions is self.noise_std
331
+ noise_amp = uv * self.noise_std + (1 - uv) * self.sine_amp / 3
332
+ noise = noise_amp * torch.randn_like(sine_waves)
333
+
334
+ # first: set the unvoiced part to 0 by uv
335
+ # then: additive noise
336
+ sine_waves = sine_waves * uv + noise
337
+ return sine_waves, uv, noise
338
+
339
+ class SourceModuleHnNSF2(torch.nn.Module):
340
+ """ SourceModule for hn-nsf
341
+ SourceModule(sampling_rate, harmonic_num=0, sine_amp=0.1,
342
+ add_noise_std=0.003, voiced_threshod=0)
343
+ sampling_rate: sampling_rate in Hz
344
+ harmonic_num: number of harmonic above F0 (default: 0)
345
+ sine_amp: amplitude of sine source signal (default: 0.1)
346
+ add_noise_std: std of additive Gaussian noise (default: 0.003)
347
+ note that amplitude of noise in unvoiced is decided
348
+ by sine_amp
349
+ voiced_threshold: threshold to set U/V given F0 (default: 0)
350
+ Sine_source, noise_source = SourceModuleHnNSF(F0_sampled)
351
+ F0_sampled (batchsize, length, 1)
352
+ Sine_source (batchsize, length, 1)
353
+ noise_source (batchsize, length, 1)
354
+ uv (batchsize, length, 1)
355
+ """
356
+
357
+ def __init__(self, sampling_rate, upsample_scale, harmonic_num=0, sine_amp=0.1,
358
+ add_noise_std=0.003, voiced_threshod=0):
359
+ super(SourceModuleHnNSF2, self).__init__()
360
+
361
+ self.sine_amp = sine_amp
362
+ self.noise_std = add_noise_std
363
+
364
+ # to produce sine waveforms
365
+ self.l_sin_gen = SineGen2(sampling_rate, upsample_scale, harmonic_num,
366
+ sine_amp, add_noise_std, voiced_threshod)
367
+
368
+ # to merge source harmonics into a single excitation
369
+ self.l_linear = torch.nn.Linear(harmonic_num + 1, 1)
370
+ self.l_tanh = torch.nn.Tanh()
371
+
372
+ def forward(self, x):
373
+ """
374
+ Sine_source, noise_source = SourceModuleHnNSF(F0_sampled)
375
+ F0_sampled (batchsize, length, 1)
376
+ Sine_source (batchsize, length, 1)
377
+ noise_source (batchsize, length, 1)
378
+ """
379
+ # source for harmonic branch
380
+ with torch.no_grad():
381
+ sine_wavs, uv, _ = self.l_sin_gen(x)
382
+ sine_merge = self.l_tanh(self.l_linear(sine_wavs))
383
+
384
  # source for noise branch, in the same shape as uv
385
  noise = torch.randn_like(uv) * self.sine_amp / 3
386
  return sine_merge, noise, uv
 
422
 
423
  self.num_kernels = len(resblock_kernel_sizes)
424
  self.num_upsamples = len(upsample_rates)
425
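+ # NOTE(review): CosyVoice2 is stated to use the original SourceModuleHnNSF implementation, but the assignment below selects SourceModuleHnNSF2 unconditionally - confirm which is intended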
426
+ # this_SourceModuleHnNSF = SourceModuleHnNSF if self.sampling_rate == 22050 else SourceModuleHnNSF2
427
+ this_SourceModuleHnNSF = SourceModuleHnNSF2 # WBY
428
+ self.m_source = this_SourceModuleHnNSF(
429
  sampling_rate=sampling_rate,
430
  upsample_scale=np.prod(upsample_rates) * istft_params["hop_len"],
431
  harmonic_num=nb_harmonics,
 
485
  self.stft_window = torch.from_numpy(get_window("hann", istft_params["n_fft"], fftbins=True).astype(np.float32))
486
  self.f0_predictor = f0_predictor
487
 
488
+ # for cuda graph
489
  self.use_cuda_graph = False
490
  self.graph = {}
491
  self.inference_buffers = {}
 
520
  self.istft_params["n_fft"], window=self.stft_window.to(magnitude.device))
521
  return inverse_transform
522
 
523
+
524
+ def decode(self, x: torch.Tensor, s: torch.Tensor = torch.zeros(1, 1, 0)) -> torch.Tensor:
525
+ s_stft_real, s_stft_imag = self._stft(s.squeeze(1))
526
+ s_stft = torch.cat([s_stft_real, s_stft_imag], dim=1).to(s.dtype)
527
+ if self.use_cuda_graph and x.shape[-1] in self.graph:
528
+ self.inference_buffers[x.shape[-1]]['static_inputs']['static_x'].copy_(x)
529
+ self.inference_buffers[x.shape[-1]]['static_inputs']['static_s_stft'].copy_(s_stft)
530
+ self.graph[x.shape[-1]].replay()
531
+ x = self.inference_buffers[x.shape[-1]]['static_outputs']['static_output_x']
532
+ else:
533
+ x = self.decode_without_stft(x, s_stft)
534
+ magnitude = torch.exp(x[:, :self.istft_params["n_fft"] // 2 + 1, :])
535
+ phase = torch.sin(x[:, self.istft_params["n_fft"] // 2 + 1:, :]) # actually, sin is redundancy
536
+ magnitude = magnitude.to(torch.float32)
537
+ phase = phase.to(torch.float32)
538
+ x = self._istft(magnitude, phase).to(x.dtype)
539
+ x = torch.clamp(x, -self.audio_limit, self.audio_limit)
540
+ return x
541
+
542
  def decode_without_stft(self, x: torch.Tensor, s_stft: torch.Tensor) -> torch.Tensor:
543
  x = self.conv_pre(x)
544
  for i in range(self.num_upsamples):
 
565
  x = self.conv_post(x)
566
  return x
567
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
568
  def forward(
569
  self,
570
  batch: dict,
 
580
  # mel+source->speech
581
  generated_speech = self.decode(x=speech_feat, s=s)
582
  return generated_speech, f0
583
+
584
  def _init_cuda_graph(self):
585
  self.use_cuda_graph = True
586
  dummy_param = next(self.parameters())
 
609
  print(f"CUDA Graph initialized successfully for chunk generator")
610
 
611
  @torch.inference_mode()
612
+ def inference(self, speech_feat: torch.Tensor, cache_source: torch.Tensor = torch.zeros(1, 1, 0)) -> torch.Tensor:
613
  # mel->f0
614
  f0 = self.f0_predictor(speech_feat)
615
  # f0->source
 
617
  s, _, _ = self.m_source(s)
618
  s = s.transpose(1, 2)
619
  # use cache_source to avoid glitch
620
+ if cache_source.shape[2] != 0:
621
  s[:, :, :cache_source.shape[2]] = cache_source
622
  generated_speech = self.decode(x=speech_feat, s=s)
623
+ return generated_speech, s