# model_img2ph.py
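"""Mel-spectrogram-to-phoneme model.

A strided CNN encoder downsamples the frequency axis of a mel spectrogram
while keeping time resolution, and a unidirectional GRU decoder emits
per-frame phoneme logits.

Shapes: input (B, n_mels, T) -> output (B, T, vocab_size).
"""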
import torch
import torch.nn as nn
import torch.nn.functional as F


class CNNEncoder(nn.Module):
    def __init__(self, in_channels=1, hidden_dim=256, dropout=0.2):
        super().__init__()
        # Convolutions downsample the frequency axis (stride 2) while
        # preserving time resolution (stride 1)
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, 64, kernel_size=3, stride=(2,1), padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Dropout(dropout),

            nn.Conv2d(64, 128, kernel_size=3, stride=(2,1), padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.Dropout(dropout),

            nn.Conv2d(128, 256, kernel_size=3, stride=(2,1), padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.Dropout(dropout),

            nn.Conv2d(256, hidden_dim, kernel_size=3, stride=(2,1), padding=1),
            nn.BatchNorm2d(hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        # x: (B, n_mels, T)
        x = x.unsqueeze(1)  # (B,1,n_mels,T)
        feat = self.conv(x)  # (B, C, H', T)
        B, C, H, T = feat.size()
        # collapse frequency into features, keep time intact
        feat = feat.permute(0, 3, 1, 2).contiguous()  # (B,T,C,H)
        feat = feat.view(B, T, C*H)  # (B,T,features)
        return feat


class PhonemeDecoder(nn.Module):
    def __init__(self, vocab_size, enc_dim=128*5, rnn_hidden=128, num_layers=2, dropout=0.3):
        super().__init__()
        self.rnn = nn.GRU(
            enc_dim, rnn_hidden,
            num_layers=num_layers,
            batch_first=True,
            dropout=dropout,
            bidirectional=False,  # unidirectional: each frame sees only past context
        )
        self.proj = nn.Linear(rnn_hidden, 256)   # project RNN output to a fixed width
        self.norm = nn.LayerNorm(256)            # normalize before the activation
        self.fc_out = nn.Linear(256, vocab_size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, enc_out):
        rnn_out, _ = self.rnn(enc_out)  # (B,T,rnn_hidden)
        dense_out = self.proj(rnn_out)  # (B,T,256)
        dense_out = self.norm(dense_out)  # Normalize
        dense_out = F.relu(dense_out)  # Activation
        dense_out = self.dropout(dense_out)  # Dropout after activation
        logits = self.fc_out(dense_out)  # (B,T,vocab_size)
        return logits


class Image2Phoneme(nn.Module):
    def __init__(self, vocab_size, in_channels=1, enc_hidden=128, rnn_hidden=128):
        super().__init__()
        self.encoder = CNNEncoder(in_channels=in_channels, hidden_dim=enc_hidden)
        # enc_dim = enc_hidden * H'; with n_mels=80 and four stride-(2,1) convs,
        # the frequency axis shrinks 80 -> 40 -> 20 -> 10 -> 5, so H' = 5
        enc_dim = enc_hidden * 5
        self.decoder = PhonemeDecoder(vocab_size, enc_dim=enc_dim, rnn_hidden=rnn_hidden)

    def forward(self, mels):
        # mels: (B,n_mels,T)
        enc_out = self.encoder(mels)     # (B,T,enc_dim)
        logits = self.decoder(enc_out)   # (B,T,vocab_size)
        return logits
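

# A minimal smoke-test sketch, not part of the model definition. The batch
# size, frame count, and vocab_size below are illustrative assumptions
# (n_mels=80 matches the comment in Image2Phoneme.__init__). CTC is one
# plausible objective for per-frame logits; the actual training setup is
# not specified in this file.
if __name__ == "__main__":
    B, n_mels, T = 4, 80, 200
    vocab_size = 50  # assumed size; blank token at index 0 for the CTC example

    model = Image2Phoneme(vocab_size=vocab_size)
    mels = torch.randn(B, n_mels, T)
    logits = model(mels)
    print(logits.shape)  # torch.Size([4, 200, 50]): time resolution preserved

    # nn.CTCLoss expects (T, B, vocab) log-probabilities
    log_probs = F.log_softmax(logits, dim=-1).transpose(0, 1)
    targets = torch.randint(1, vocab_size, (B, 30))          # dummy phoneme IDs
    input_lengths = torch.full((B,), T, dtype=torch.long)
    target_lengths = torch.full((B,), 30, dtype=torch.long)
    loss = nn.CTCLoss(blank=0)(log_probs, targets, input_lengths, target_lengths)
    print(f"dummy CTC loss: {loss.item():.3f}")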