imwithye commited on
Commit
90d394b
·
1 Parent(s): 73c1ae2

add residual block

Browse files
rlcube/models/models.py ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import lightning as L
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import torch
5
+
6
+
7
+ class ResidualBlock(nn.Module):
8
+ def __init__(self, input_dim, hidden_dim):
9
+ super(ResidualBlock, self).__init__()
10
+ self.bn1 = nn.BatchNorm1d(input_dim)
11
+ self.fc1 = nn.Linear(input_dim, hidden_dim)
12
+ self.bn2 = nn.BatchNorm1d(hidden_dim)
13
+ self.fc2 = nn.Linear(hidden_dim, input_dim)
14
+
15
+ def forward(self, x):
16
+ residual = x
17
+ out = self.bn1(x)
18
+ out = F.relu(out)
19
+ out = self.fc1(out)
20
+ out = self.bn2(out)
21
+ out = F.relu(out)
22
+ out = self.fc2(out)
23
+ out = out + residual
24
+ return out
25
+
26
+
27
+ if __name__ == "__main__":
28
+ print("Testing ResidualBlock, input_dim=24, hidden_dim=128")
29
+ x = torch.randn(4, 24)
30
+ print("Input shape:", x.shape)
31
+ print("Output shape:", ResidualBlock(24, 128)(x).shape)
rlcube/pyproject.toml CHANGED
@@ -8,5 +8,7 @@ dependencies = [
8
  "fastapi[standard]>=0.116.2",
9
  "gymnasium>=1.2.0",
10
  "ipykernel>=6.30.1",
 
11
  "numpy>=2.3.2",
 
12
  ]
 
8
  "fastapi[standard]>=0.116.2",
9
  "gymnasium>=1.2.0",
10
  "ipykernel>=6.30.1",
11
+ "lightning>=2.5.5",
12
  "numpy>=2.3.2",
13
+ "torch>=2.8.0",
14
  ]
rlcube/uv.lock CHANGED
The diff for this file is too large to render. See raw diff