✨ [New] inference code and refactor train example
Browse files- examples/example_inference.py +35 -0
- examples/example_train.py +8 -10
examples/example_inference.py
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import sys
|
| 2 |
+
from pathlib import Path
|
| 3 |
+
|
| 4 |
+
import hydra
|
| 5 |
+
import torch
|
| 6 |
+
|
| 7 |
+
project_root = Path(__file__).resolve().parent.parent
|
| 8 |
+
sys.path.append(str(project_root))
|
| 9 |
+
|
| 10 |
+
from yolo.config.config import Config
|
| 11 |
+
from yolo.model.yolo import get_model
|
| 12 |
+
from yolo.tools.data_loader import create_dataloader
|
| 13 |
+
from yolo.tools.solver import ModelTester
|
| 14 |
+
from yolo.utils.logging_utils import custom_logger, validate_log_directory
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
@hydra.main(config_path="../yolo/config", config_name="config", version_base=None)
|
| 18 |
+
def main(cfg: Config):
|
| 19 |
+
custom_logger()
|
| 20 |
+
save_path = validate_log_directory(cfg, cfg.name)
|
| 21 |
+
|
| 22 |
+
device = torch.device(cfg.device)
|
| 23 |
+
model = get_model(cfg).to(device)
|
| 24 |
+
|
| 25 |
+
save_path = validate_log_directory(cfg, cfg.name)
|
| 26 |
+
dataloader = create_dataloader(cfg)
|
| 27 |
+
device = torch.device(cfg.device)
|
| 28 |
+
model = get_model(cfg).to(device)
|
| 29 |
+
|
| 30 |
+
tester = ModelTester(cfg, model, save_path, device)
|
| 31 |
+
tester.solve(dataloader, cfg.task.epoch)
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
if __name__ == "__main__":
|
| 35 |
+
main()
|
examples/example_train.py
CHANGED
|
@@ -3,30 +3,28 @@ from pathlib import Path
|
|
| 3 |
|
| 4 |
import hydra
|
| 5 |
import torch
|
| 6 |
-
from loguru import logger
|
| 7 |
|
| 8 |
project_root = Path(__file__).resolve().parent.parent
|
| 9 |
sys.path.append(str(project_root))
|
| 10 |
|
| 11 |
from yolo.config.config import Config
|
|
|
|
| 12 |
from yolo.tools.data_loader import create_dataloader
|
| 13 |
-
from yolo.tools.
|
| 14 |
-
from yolo.tools.trainer import ModelTrainer
|
| 15 |
from yolo.utils.logging_utils import custom_logger, validate_log_directory
|
| 16 |
|
| 17 |
|
| 18 |
@hydra.main(config_path="../yolo/config", config_name="config", version_base=None)
|
| 19 |
def main(cfg: Config):
|
| 20 |
custom_logger()
|
| 21 |
-
save_path = validate_log_directory(cfg
|
| 22 |
-
if cfg.download.auto:
|
| 23 |
-
prepare_dataset(cfg.download)
|
| 24 |
-
|
| 25 |
dataloader = create_dataloader(cfg)
|
| 26 |
# TODO: get_device or rank, for DDP mode
|
| 27 |
-
device = torch.device(
|
| 28 |
-
|
| 29 |
-
|
|
|
|
|
|
|
| 30 |
|
| 31 |
|
| 32 |
if __name__ == "__main__":
|
|
|
|
| 3 |
|
| 4 |
import hydra
|
| 5 |
import torch
|
|
|
|
| 6 |
|
| 7 |
project_root = Path(__file__).resolve().parent.parent
|
| 8 |
sys.path.append(str(project_root))
|
| 9 |
|
| 10 |
from yolo.config.config import Config
|
| 11 |
+
from yolo.model.yolo import get_model
|
| 12 |
from yolo.tools.data_loader import create_dataloader
|
| 13 |
+
from yolo.tools.solver import ModelTrainer
|
|
|
|
| 14 |
from yolo.utils.logging_utils import custom_logger, validate_log_directory
|
| 15 |
|
| 16 |
|
| 17 |
@hydra.main(config_path="../yolo/config", config_name="config", version_base=None)
|
| 18 |
def main(cfg: Config):
|
| 19 |
custom_logger()
|
| 20 |
+
save_path = validate_log_directory(cfg, cfg.name)
|
|
|
|
|
|
|
|
|
|
| 21 |
dataloader = create_dataloader(cfg)
|
| 22 |
# TODO: get_device or rank, for DDP mode
|
| 23 |
+
device = torch.device(cfg.device)
|
| 24 |
+
model = get_model(cfg).to(device)
|
| 25 |
+
|
| 26 |
+
trainer = ModelTrainer(cfg, model, save_path, device)
|
| 27 |
+
trainer.solve(dataloader, cfg.task.epoch)
|
| 28 |
|
| 29 |
|
| 30 |
if __name__ == "__main__":
|