devjas1 committed · Commit 12ab884 · Parent(s): 4b66627
(chore): pre-flight hardening for model expansion (seeds, typo, diagnostics, dtypes, optional deterministic cuDNN)
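The seed handling and the optional deterministic cuDNN switch named in the commit message do not appear in the hunks below. As a rough sketch of what that kind of pre-flight hardening usually looks like in a PyTorch training script (the set_seed helper and the deterministic flag are assumptions for illustration, not code from this commit):

# Sketch only: common reproducibility hardening for PyTorch training scripts.
# The helper name and flag are illustrative; this commit's actual seed/cuDNN code is not shown here.
import random

import numpy as np
import torch

def set_seed(seed: int, deterministic: bool = False) -> None:
    """Seed Python, NumPy and PyTorch RNGs; optionally force deterministic cuDNN kernels."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    if deterministic:
        # Deterministic cuDNN trades throughput for run-to-run reproducibility.
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

# e.g. set_seed(args.seed, deterministic=args.deterministic)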
scripts/train_model.py  CHANGED  (+13 -11)
@@ -1,4 +1,5 @@
-import os
+import os
+import sys
 sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
 from datetime import datetime
 import argparse, numpy as np, torch
@@ -21,8 +22,8 @@ parser.add_argument("--normalize", action="store_true")
 parser.add_argument("--batch-size", type=int, default=16)
 parser.add_argument("--epochs", type=int, default=10)
 parser.add_argument("--learning-rate", type=float, default=1e-3)
-parser.add_argument("--model", type=str, default="figure2",
-
+parser.add_argument("--model", type=str, default="figure2", choices=model_choices())
+
 args = parser.parse_args()

 # Constants
@@ -36,7 +37,8 @@ os.makedirs("outputs", exist_ok=True)
 os.makedirs("outputs/logs", exist_ok=True)

 print("Preprocessing Configuration:")
-print(f"
+print(f" Resample to : {args.target_len}")
+
 print(f" Baseline Correct: {'✅' if args.baseline else '❌'}")
 print(f" Smoothing : {'✅' if args.smooth else '❌'}")
 print(f" Normalization : {'✅' if args.normalize else '❌'}")
@@ -66,14 +68,13 @@ for fold, (train_idx, val_idx) in enumerate(skf.split(X, y), 1):
     y_train, y_val = y[train_idx], y[val_idx]

     train_loader = DataLoader(
-        TensorDataset(torch.tensor(X_train), torch.tensor(y_train)),
+        TensorDataset(torch.tensor(X_train, dtype=torch.float32), torch.tensor(y_train, dtype=torch.long)),
         batch_size=args.batch_size, shuffle=True)
     val_loader = DataLoader(
-        TensorDataset(torch.tensor(X_val), torch.tensor(y_val
+        TensorDataset(torch.tensor(X_val, dtype=torch.float32), torch.tensor(y_val, dtype=torch.long)))

     # Model selection
-    model = (
-        input_length=args.target_len).to(DEVICE)
+    model = build_model(args.model, args.target_len).to(DEVICE)

     optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)
     criterion = torch.nn.CrossEntropyLoss()
@@ -127,9 +128,10 @@ print(f"✅ Model saved to {model_path}")


 def save_diagnostics_log(fold_acc, confs, args_param, output_path):
-    fold_metrics = [
-
-
+    fold_metrics = [
+        {"fold": i + 1, "accuracy": float(a), "confusion_matrix": c.tolist()}
+        for i, (a, c) in enumerate(zip(fold_acc, confs))
+    ]
     log = {
         "timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
         "preprocessing": {
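The new choices=model_choices() and build_model(args.model, args.target_len) calls imply a small model registry defined elsewhere in the repo. That module is not part of this diff, so the following is only a guess at the shape of such an interface, with a placeholder architecture:

# Hypothetical registry sketch matching the calls used in the diff
# (model_choices(), build_model(name, input_length)). The real module is not
# shown in this commit; the class below is a placeholder, not the actual model.
import torch.nn as nn

class Figure2Net(nn.Module):                # placeholder architecture
    def __init__(self, input_length: int):
        super().__init__()
        self.net = nn.Sequential(nn.Flatten(), nn.Linear(input_length, 2))

    def forward(self, x):
        return self.net(x)

_REGISTRY = {"figure2": Figure2Net}

def model_choices():
    """Names accepted by argparse's choices= for --model."""
    return sorted(_REGISTRY)

def build_model(name: str, input_length: int) -> nn.Module:
    """Instantiate the registered model by name."""
    return _REGISTRY[name](input_length=input_length)

Keeping the argparse choices and the constructor lookup behind one registry means adding a model touches a single dict, which fits the "pre-flight hardening for model expansion" intent of the commit.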