Upload SegmentBorzoi
Browse files - segment_borzoi.py (+20 -11)
segment_borzoi.py
CHANGED
|
@@ -7,7 +7,9 @@ from einops import rearrange
|
|
| 7 |
from torch import einsum
|
| 8 |
from transformers import PretrainedConfig, PreTrainedModel
|
| 9 |
|
| 10 |
-
from genomics_research.segmentnt.layers.
|
|
|
|
|
|
|
| 11 |
|
| 12 |
FEATURES = [
|
| 13 |
"protein_coding_gene",
|
|
@@ -91,7 +93,7 @@ class SegmentBorzoi(PreTrainedModel):
|
|
| 91 |
|
| 92 |
# Correct transformer
|
| 93 |
for layer in self.transformer:
|
| 94 |
-
layer[0].fn[1] = BorzoiAttentionLayer(
|
| 95 |
config.embed_dim,
|
| 96 |
heads=config.num_attention_heads,
|
| 97 |
dim_key=config.attention_dim_key,
|
|
@@ -105,7 +107,7 @@ class SegmentBorzoi(PreTrainedModel):
|
|
| 105 |
self.separable1.conv_layer[1].bias = None
|
| 106 |
self.separable0.conv_layer[1].bias = None
|
| 107 |
|
| 108 |
-
def forward(self, x):
|
| 109 |
# Stem
|
| 110 |
x = x.transpose(1, 2)
|
| 111 |
x = self.stem(x)
|
|
@@ -199,14 +201,14 @@ def relative_shift(x: torch.Tensor) -> torch.Tensor:
|
|
| 199 |
to_pad = torch.zeros_like(x[..., :1])
|
| 200 |
x = torch.cat((to_pad, x), dim=-1)
|
| 201 |
_, h, t1, t2 = x.shape
|
| 202 |
-
x = x.reshape(-1, h, t2, t1)
|
| 203 |
x = x[:, :, 1:, :]
|
| 204 |
-
x = x.reshape(-1, h, t1, t2 - 1)
|
| 205 |
return x[..., : ((t2 + 1) // 2)]
|
| 206 |
|
| 207 |
|
| 208 |
class BorzoiAttentionLayer(nn.Module):
|
| 209 |
-
def __init__(
|
| 210 |
self,
|
| 211 |
dim,
|
| 212 |
*,
|
|
@@ -216,7 +218,7 @@ class BorzoiAttentionLayer(nn.Module):
|
|
| 216 |
dim_value=64,
|
| 217 |
dropout=0.0,
|
| 218 |
pos_dropout=0.0,
|
| 219 |
-
):
|
| 220 |
super().__init__()
|
| 221 |
self.scale = dim_key**-0.5
|
| 222 |
self.heads = heads
|
|
@@ -232,22 +234,29 @@ class BorzoiAttentionLayer(nn.Module):
|
|
| 232 |
self.num_rel_pos_features = num_rel_pos_features
|
| 233 |
|
| 234 |
self.to_rel_k = nn.Linear(num_rel_pos_features, dim_key * heads, bias=False)
|
| 235 |
-
self.rel_content_bias = nn.Parameter(
|
| 236 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 237 |
|
| 238 |
# dropouts
|
| 239 |
|
| 240 |
self.pos_dropout = nn.Dropout(pos_dropout)
|
| 241 |
self.attn_dropout = nn.Dropout(dropout)
|
| 242 |
|
| 243 |
-
def forward(self, x):
|
| 244 |
n, h = x.shape[-2], self.heads
|
| 245 |
|
| 246 |
q = self.to_q(x)
|
| 247 |
k = self.to_k(x)
|
| 248 |
v = self.to_v(x)
|
| 249 |
|
| 250 |
-
q, k, v = map(
|
|
|
|
|
|
|
|
|
|
| 251 |
|
| 252 |
q = q * self.scale
|
| 253 |
|
|
|
|
| 7 |
from torch import einsum
|
| 8 |
from transformers import PretrainedConfig, PreTrainedModel
|
| 9 |
|
| 10 |
+
from genomics_research.segmentnt.porting_to_pytorch.layers.segmentation_head import (
|
| 11 |
+
TorchUNetHead,
|
| 12 |
+
)
|
| 13 |
|
| 14 |
FEATURES = [
|
| 15 |
"protein_coding_gene",
|
|
|
|
| 93 |
|
| 94 |
# Correct transformer
|
| 95 |
for layer in self.transformer:
|
| 96 |
+
layer[0].fn[1] = BorzoiAttentionLayer( # type: ignore
|
| 97 |
config.embed_dim,
|
| 98 |
heads=config.num_attention_heads,
|
| 99 |
dim_key=config.attention_dim_key,
|
|
|
|
| 107 |
self.separable1.conv_layer[1].bias = None
|
| 108 |
self.separable0.conv_layer[1].bias = None
|
| 109 |
|
| 110 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 111 |
# Stem
|
| 112 |
x = x.transpose(1, 2)
|
| 113 |
x = self.stem(x)
|
|
|
|
| 201 |
to_pad = torch.zeros_like(x[..., :1])
|
| 202 |
x = torch.cat((to_pad, x), dim=-1)
|
| 203 |
_, h, t1, t2 = x.shape
|
| 204 |
+
x = x.reshape(-1, h, t2, t1) # noqa: FKA100
|
| 205 |
x = x[:, :, 1:, :]
|
| 206 |
+
x = x.reshape(-1, h, t1, t2 - 1) # noqa: FKA100
|
| 207 |
return x[..., : ((t2 + 1) // 2)]
|
| 208 |
|
| 209 |
|
| 210 |
class BorzoiAttentionLayer(nn.Module):
|
| 211 |
+
def __init__( # type: ignore
|
| 212 |
self,
|
| 213 |
dim,
|
| 214 |
*,
|
|
|
|
| 218 |
dim_value=64,
|
| 219 |
dropout=0.0,
|
| 220 |
pos_dropout=0.0,
|
| 221 |
+
) -> None:
|
| 222 |
super().__init__()
|
| 223 |
self.scale = dim_key**-0.5
|
| 224 |
self.heads = heads
|
|
|
|
| 234 |
self.num_rel_pos_features = num_rel_pos_features
|
| 235 |
|
| 236 |
self.to_rel_k = nn.Linear(num_rel_pos_features, dim_key * heads, bias=False)
|
| 237 |
+
self.rel_content_bias = nn.Parameter(
|
| 238 |
+
torch.randn(1, heads, 1, dim_key) # noqa: FKA100
|
| 239 |
+
)
|
| 240 |
+
self.rel_pos_bias = nn.Parameter(
|
| 241 |
+
torch.randn(1, heads, 1, dim_key) # noqa: FKA100
|
| 242 |
+
)
|
| 243 |
|
| 244 |
# dropouts
|
| 245 |
|
| 246 |
self.pos_dropout = nn.Dropout(pos_dropout)
|
| 247 |
self.attn_dropout = nn.Dropout(dropout)
|
| 248 |
|
| 249 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 250 |
n, h = x.shape[-2], self.heads
|
| 251 |
|
| 252 |
q = self.to_q(x)
|
| 253 |
k = self.to_k(x)
|
| 254 |
v = self.to_v(x)
|
| 255 |
|
| 256 |
+
q, k, v = map( # noqa
|
| 257 |
+
lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), # type: ignore
|
| 258 |
+
(q, k, v),
|
| 259 |
+
)
|
| 260 |
|
| 261 |
q = q * self.scale
|
| 262 |
|