|
|
from typing import Optional |
|
|
|
|
|
import matplotlib.pyplot as plt |
|
|
import numpy as np |
|
|
|
|
|
try: |
|
|
from torch.utils.tensorboard import SummaryWriter |
|
|
except ModuleNotFoundError: |
|
|
SummaryWriter = None |
|
|
|
|
|
try: |
|
|
import wandb |
|
|
except ModuleNotFoundError: |
|
|
wandb = None |
|
|
|
|
|
# Switch matplotlib's interactive mode off at import time; the only figure
# this module creates is written to disk by MetricsPlotSink.save and closed.
plt.ioff()

# File name of the image that MetricsPlotSink.save writes inside its output_dir.
PLOT_FILE_NAME = "metrics_plot.png"
|
|
|
|
|
|
|
|
def safe_index(arr, idx):
    """Return ``arr[idx]`` when ``idx`` is in range, otherwise ``None``.

    Args:
        arr: A sized, indexable sequence, or ``None``.
        idx (int): Zero-based position to read. Negative values are treated
            as out of range rather than as "count from the end".

    Returns:
        The element at ``idx``, or ``None`` when ``arr`` is ``None`` or
        ``idx`` is outside ``[0, len(arr))``.
    """
    # A missing sequence is reported the same way as a missing element so
    # callers do not need their own None check before indexing.
    if arr is None:
        return None
    return arr[idx] if 0 <= idx < len(arr) else None
|
|
|
|
|
|
|
|
class MetricsPlotSink:
    """
    The MetricsPlotSink class records training metrics and saves them to a plot.

    Args:
        output_dir (str): Directory where the plot will be saved.
    """

    def __init__(self, output_dir: str):
        self.output_dir = output_dir
        # One dict per update() call, kept in call order.
        self.history = []

    def update(self, values: dict):
        """Append one metrics snapshot to the history.

        Args:
            values (dict): Metrics for one step. ``save`` reads the keys
                ``epoch``, ``train_loss``, ``test_loss``,
                ``test_coco_eval_bbox`` and ``ema_test_coco_eval_bbox``;
                any of them may be absent.
        """
        self.history.append(values)

    def _series(self, key, index=None):
        """Collect one curve as ``(epochs, values)`` lists of equal length.

        Every value is paired with the ``epoch`` of the history entry it came
        from, so a metric that is only reported on some epochs is still drawn
        at the right x positions. Entries without an ``epoch``, without
        ``key``, or whose ``key`` is ``None`` are skipped.

        Args:
            key (str): History key to read.
            index (int, optional): When given, the entry is treated as a
                sequence and the element at ``index`` is taken via
                ``safe_index`` (``None`` if the sequence is too short).

        Returns:
            tuple[list, list]: The epochs and the matching values.
        """
        epochs, points = [], []
        for entry in self.history:
            if 'epoch' not in entry:
                continue
            value = entry.get(key)
            if value is None:
                continue
            if index is not None:
                value = safe_index(value, index)
            epochs.append(entry['epoch'])
            points.append(value)
        return epochs, points

    def save(self):
        """Render the recorded history as a 2x2 grid of plots and save it.

        The panels are loss, AP@0.50, AP@0.50:0.95 and AR@0.50:0.95. A panel
        with no data is left blank. The image is written to
        ``{output_dir}/{PLOT_FILE_NAME}``; ``output_dir`` is created if it
        does not exist. Prints a message and returns without creating a
        figure when nothing has been recorded.
        """
        if not self.history:
            print("No data to plot.")
            return

        # Local import: only this method needs the standard-library os module.
        import os

        def loss(key):
            xs, ys = self._series(key)
            return xs, np.array(ys)

        def metric(key, index):
            xs, ys = self._series(key, index)
            # float32 conversion turns a missing (None) element into NaN,
            # which matplotlib draws as a gap in the line.
            return xs, np.array(ys, dtype=np.float32)

        # (grid position, title, y label, [(legend label, linestyle, (xs, ys)), ...])
        panels = [
            ((0, 0), 'Training and Validation Loss', 'Loss Value', [
                ('Training Loss', '-', loss('train_loss')),
                ('Validation Loss', '--', loss('test_loss')),
            ]),
            ((0, 1), 'Average Precision @0.50', 'AP50', [
                ('Base Model', '-', metric('test_coco_eval_bbox', 1)),
                ('EMA Model', '--', metric('ema_test_coco_eval_bbox', 1)),
            ]),
            ((1, 0), 'Average Precision @0.50:0.95', 'AP', [
                ('Base Model', '-', metric('test_coco_eval_bbox', 0)),
                ('EMA Model', '--', metric('ema_test_coco_eval_bbox', 0)),
            ]),
            ((1, 1), 'Average Recall @0.50:0.95', 'AR', [
                ('Base Model', '-', metric('test_coco_eval_bbox', 8)),
                ('EMA Model', '--', metric('ema_test_coco_eval_bbox', 8)),
            ]),
        ]

        fig, axes = plt.subplots(2, 2, figsize=(18, 12))
        try:
            for (row, col), title, ylabel, curves in panels:
                ax = axes[row][col]
                drawn = False
                for label, linestyle, (xs, ys) in curves:
                    if len(xs) == 0:
                        continue
                    ax.plot(xs, ys, label=label, marker='o', linestyle=linestyle)
                    drawn = True
                if not drawn:
                    # Leave the panel empty; calling legend() here would only
                    # trigger matplotlib's "no labelled artists" warning.
                    continue
                ax.set_title(title)
                ax.set_xlabel('Epoch Number')
                ax.set_ylabel(ylabel)
                ax.legend()
                ax.grid(True)

            # Operate on this figure explicitly rather than on pyplot's
            # implicit "current figure".
            fig.tight_layout()
            if self.output_dir:
                os.makedirs(self.output_dir, exist_ok=True)
            plot_path = f"{self.output_dir}/{PLOT_FILE_NAME}"
            fig.savefig(plot_path)
        finally:
            # Always release the figure, even when plotting or saving fails.
            plt.close(fig)

        print(f"Results saved to {plot_path}")
|
|
|
|
|
|
|
|
class MetricsTensorBoardSink:
    """
    Training metrics via TensorBoard.

    Args:
        output_dir (str): Directory where TensorBoard logs will be written.
    """

    # (key in the values dict, TensorBoard tag)
    _LOSS_TAGS = (
        ('train_loss', "Loss/Train"),
        ('test_loss', "Loss/Test"),
    )
    # (key in the values dict, tag group)
    _EVAL_GROUPS = (
        ('test_coco_eval_bbox', "Base"),
        ('ema_test_coco_eval_bbox', "EMA"),
    )
    # (tag suffix, position inside the eval sequence)
    _EVAL_STATS = (
        ("AP50_90", 0),
        ("AP50", 1),
        ("AR50_90", 8),
    )

    def __init__(self, output_dir: str):
        # SummaryWriter is None when the optional tensorboard import failed;
        # the sink then becomes a no-op instead of raising.
        if SummaryWriter:
            self.writer = SummaryWriter(log_dir=output_dir)
            print(f"TensorBoard logging initialized. To monitor logs, use 'tensorboard --logdir {output_dir}' and open http://localhost:6006/ in browser.")
        else:
            self.writer = None
            print("Unable to initialize TensorBoard. Logging is turned off for this session. Run 'pip install tensorboard' to enable logging.")

    def update(self, values: dict):
        """Write the scalars found in ``values`` and flush the writer.

        Args:
            values (dict): Must contain ``epoch`` (used as the global step).
                ``train_loss``, ``test_loss``, ``test_coco_eval_bbox`` and
                ``ema_test_coco_eval_bbox`` are logged when present. An eval
                entry that is ``None`` or too short for a given statistic is
                skipped.

        Raises:
            KeyError: If ``epoch`` is missing while a writer is active.
        """
        if not self.writer:
            return

        epoch = values['epoch']

        for key, tag in self._LOSS_TAGS:
            if key in values:
                self.writer.add_scalar(tag, values[key], epoch)

        for key, group in self._EVAL_GROUPS:
            stats = values.get(key)
            if stats is None:
                # Key absent or evaluation produced nothing for this step.
                continue
            for name, position in self._EVAL_STATS:
                stat = safe_index(stats, position)
                if stat is not None:
                    self.writer.add_scalar(f"Metrics/{group}/{name}", stat, epoch)

        self.writer.flush()

    def close(self):
        """Close the underlying writer, if one was created."""
        if not self.writer:
            return

        self.writer.close()
|
|
|
|
|
class MetricsWandBSink:
    """
    Training metrics via W&B.

    Args:
        output_dir (str): Directory where W&B logs will be written locally.
        project (str, optional): Associate this training run with a W&B project. If None, W&B will generate a name based on the git repo name.
        run (str, optional): W&B run name. If None, W&B will generate a random name.
        config (dict, optional): Input parameters, like hyperparameters or data preprocessing settings for the run for later comparison.
    """

    # (key in the values dict, W&B metric name)
    _LOSS_TAGS = (
        ('train_loss', "Loss/Train"),
        ('test_loss', "Loss/Test"),
    )
    # (key in the values dict, metric group)
    _EVAL_GROUPS = (
        ('test_coco_eval_bbox', "Base"),
        ('ema_test_coco_eval_bbox', "EMA"),
    )
    # (metric suffix, position inside the eval sequence)
    _EVAL_STATS = (
        ("AP50_90", 0),
        ("AP50", 1),
        ("AR50_90", 8),
    )

    def __init__(self, output_dir: str, project: Optional[str] = None, run: Optional[str] = None, config: Optional[dict] = None):
        self.output_dir = output_dir
        # wandb is None when the optional import failed; the sink then
        # becomes a no-op instead of raising.
        if wandb:
            self.run = wandb.init(
                project=project,
                name=run,
                config=config,
                dir=output_dir,
            )
            # Report the URL of the run this sink owns rather than whatever
            # the process-wide wandb.run currently points at.
            print(f"W&B logging initialized. To monitor logs, open {self.run.url}.")
        else:
            self.run = None
            print("Unable to initialize W&B. Logging is turned off for this session. Run 'pip install wandb' to enable logging.")

    def update(self, values: dict):
        """Log the metrics found in ``values`` to this sink's W&B run.

        Args:
            values (dict): Must contain ``epoch``. ``train_loss``,
                ``test_loss``, ``test_coco_eval_bbox`` and
                ``ema_test_coco_eval_bbox`` are logged when present. An eval
                entry that is ``None`` or too short for a given statistic is
                skipped.

        Raises:
            KeyError: If ``epoch`` is missing while a run is active.
        """
        # self.run is only ever set when wandb imported successfully, so this
        # single check also covers the "wandb not installed" case.
        if not self.run:
            return

        log_dict = {"epoch": values['epoch']}

        for key, tag in self._LOSS_TAGS:
            if key in values:
                log_dict[tag] = values[key]

        for key, group in self._EVAL_GROUPS:
            stats = values.get(key)
            if stats is None:
                # Key absent or evaluation produced nothing for this step.
                continue
            for name, position in self._EVAL_STATS:
                stat = safe_index(stats, position)
                if stat is not None:
                    log_dict[f"Metrics/{group}/{name}"] = stat

        # Log through the owned run, matching close(), instead of the global
        # wandb.log which targets the process-wide current run.
        self.run.log(log_dict)

    def close(self):
        """Finish the W&B run, if one was started."""
        if not self.run:
            return

        self.run.finish()