Upload modeling_moment.py
Browse files- modeling_moment.py +15 -0
modeling_moment.py
CHANGED
|
@@ -503,6 +503,21 @@ class MomentEmbeddingModel(MomentPreTrainedModel):
|
|
| 503 |
input_mask = torch.ones_like(time_series_values[:, 0, :])
|
| 504 |
|
| 505 |
return self.embed(x_enc=time_series_values, input_mask=input_mask, **kwargs)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 506 |
|
| 507 |
|
| 508 |
# refers: https://github.com/moment-timeseries-foundation-model/moment/blob/088b253a1138ac7e48a7efc9bf902336c9eec8d9/momentfm/models/moment.py#L601
|
|
|
|
| 503 |
input_mask = torch.ones_like(time_series_values[:, 0, :])
|
| 504 |
|
| 505 |
return self.embed(x_enc=time_series_values, input_mask=input_mask, **kwargs)
|
| 506 |
+
|
| 507 |
+
def calculate_n_patches(self, seq_len: int) -> int:
|
| 508 |
+
"""
|
| 509 |
+
時系列の長さ(seq_len)を与えて、モデルのself.patch_lenとself.strideを使ってn_patchesを計算して返します。
|
| 510 |
+
strideがNoneの場合はpatch_lenを使用します。
|
| 511 |
+
|
| 512 |
+
Args:
|
| 513 |
+
seq_len (int): 時系列の長さ
|
| 514 |
+
|
| 515 |
+
Returns:
|
| 516 |
+
int: 計算されたn_patchesの数
|
| 517 |
+
"""
|
| 518 |
+
stride = self.stride if self.stride is not None else self.patch_len
|
| 519 |
+
n_patches = (seq_len - self.patch_len) // stride + 1
|
| 520 |
+
return n_patches
|
| 521 |
|
| 522 |
|
| 523 |
# refers: https://github.com/moment-timeseries-foundation-model/moment/blob/088b253a1138ac7e48a7efc9bf902336c9eec8d9/momentfm/models/moment.py#L601
|