|
|
""" |
|
|
Logging utilities for LLM training. |
|
|
""" |
|
|
|
|
|
import os |
|
|
import time |
|
|
import json |
|
|
import logging |
|
|
from typing import Dict, Any, Optional, List |
|
|
import numpy as np |
|
|
import jax |
|
|
import jax.numpy as jnp |
|
|
import tensorflow as tf |
|
|
|
|
|
|
|
|
def setup_logger(
    name: str,
    log_file: Optional[str] = None,
    level: int = logging.INFO,
) -> logging.Logger:
    """
    Set up a logger that writes formatted records to the console and,
    optionally, to a file.

    Calling this again for the same ``name`` replaces the handlers that a
    previous call installed instead of stacking new ones, so each record is
    emitted once per destination.

    Args:
        name: Logger name (passed to ``logging.getLogger``).
        log_file: Path to a log file. Missing parent directories are created;
            a bare filename is written relative to the current directory.
        level: Logging level applied to the logger and to its handlers.

    Returns:
        The configured logger.
    """
    logger = logging.getLogger(name)
    logger.setLevel(level)

    # Handlers added by this function are tagged so a later call can find,
    # detach and close them; otherwise every call would add another console
    # handler and duplicate each record.
    marker = "_installed_by_setup_logger"
    for handler in [h for h in logger.handlers if getattr(h, marker, False)]:
        logger.removeHandler(handler)
        handler.close()

    formatter = logging.Formatter(
        "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
    )

    new_handlers = [logging.StreamHandler()]
    if log_file is not None:
        # os.path.dirname("train.log") == "" and os.makedirs("") raises,
        # so only create a directory when the path actually contains one.
        log_dir = os.path.dirname(log_file)
        if log_dir:
            os.makedirs(log_dir, exist_ok=True)
        new_handlers.append(logging.FileHandler(log_file))

    for handler in new_handlers:
        handler.setLevel(level)
        handler.setFormatter(formatter)
        setattr(handler, marker, True)
        logger.addHandler(handler)

    return logger
|
|
|
|
|
|
|
|
def log_metrics(
    metrics: Dict[str, Any],
    step: int,
    logger: Optional[logging.Logger] = None,
    prefix: str = "",
    log_to_console: bool = True,
) -> None:
    """
    Log metrics for one training step to stdout and/or a logger.

    Array values (NumPy or JAX) with a single element are converted to
    ``float``; larger arrays are converted to (nested) Python lists instead
    of raising. NumPy scalar values are converted to the matching native
    Python type. Floats are printed with four decimals.

    Args:
        metrics: Dictionary of metrics
        step: Training step
        logger: Optional logger; receives one INFO record with the whole dict
        prefix: Prefix for metric names, joined as ``"{prefix}/{name}"``
        log_to_console: Whether to print the metrics to stdout
    """

    def _to_python(value: Any) -> Any:
        # Normalise array / numpy-scalar metrics into plain Python values.
        if isinstance(value, (np.ndarray, jnp.ndarray)):
            array = np.asarray(value)
            if array.size == 1:
                # reshape first: float() on an ndim>0 array is deprecated.
                return float(array.reshape(-1)[0])
            return array.tolist()
        if isinstance(value, np.generic):
            # e.g. np.float32 is neither ndarray nor a float subclass.
            return value.item()
        return value

    if prefix:
        metrics = {f"{prefix}/{k}": v for k, v in metrics.items()}

    metrics = {k: _to_python(v) for k, v in metrics.items()}

    if log_to_console:
        print(f"Step {step}:")
        for k, v in metrics.items():
            if isinstance(v, float):
                print(f" {k}: {v:.4f}")
            else:
                print(f" {k}: {v}")

    if logger is not None:
        logger.info("Step %s: %s", step, metrics)
|
|
|
|
|
|
|
|
def create_summary_writer(log_dir: str) -> tf.summary.SummaryWriter:
    """
    Build a TensorBoard summary writer rooted at ``log_dir``.

    The directory (including any missing parents) is created up front;
    an existing directory is reused.

    Args:
        log_dir: Directory that will hold the TensorBoard event files

    Returns:
        The writer returned by ``tf.summary.create_file_writer``
    """
    os.makedirs(name=log_dir, exist_ok=True)
    writer = tf.summary.create_file_writer(logdir=log_dir)
    return writer
|
|
|
|
|
|
|
|
def log_metrics_to_tensorboard(
    metrics: Dict[str, Any],
    step: int,
    writer: tf.summary.SummaryWriter,
    prefix: str = "",
) -> None:
    """
    Write metrics for one step to TensorBoard, then flush the writer.

    Numeric values (Python ``int``/``float``, NumPy scalars, and NumPy/JAX
    arrays with a single element) are written with ``tf.summary.scalar``.
    Lists/tuples of numbers, and multi-element arrays (flattened), are
    written with ``tf.summary.histogram``. Any other value (strings, bools,
    nested structures, ...) is skipped.

    Args:
        metrics: Dictionary of metrics
        step: Training step
        writer: TensorBoard summary writer
        prefix: Prefix for metric names, joined as ``"{prefix}/{name}"``
    """

    def _is_number(x: Any) -> bool:
        # bool is an int subclass; keep flags out of numeric summaries.
        return isinstance(x, (int, float)) and not isinstance(x, bool)

    def _to_python(value: Any) -> Any:
        # Normalise array / numpy-scalar metrics into plain Python values.
        if isinstance(value, (np.ndarray, jnp.ndarray)):
            array = np.asarray(value)
            if array.size == 1:
                # reshape first: float() on an ndim>0 array is deprecated.
                return float(array.reshape(-1)[0])
            # Multi-element arrays become a flat list -> histogram below,
            # instead of float() raising TypeError.
            return array.ravel().tolist()
        if isinstance(value, np.generic):
            return value.item()
        return value

    if prefix:
        metrics = {f"{prefix}/{k}": v for k, v in metrics.items()}

    metrics = {k: _to_python(v) for k, v in metrics.items()}

    with writer.as_default():
        for k, v in metrics.items():
            if _is_number(v):
                tf.summary.scalar(k, v, step=step)
            elif isinstance(v, (list, tuple)) and all(_is_number(x) for x in v):
                tf.summary.histogram(k, v, step=step)

    writer.flush()
|
|
|
|
|
|
|
|
def log_text_to_tensorboard(
    text: str,
    tag: str,
    step: int,
    writer: tf.summary.SummaryWriter,
) -> None:
    """
    Record a single text summary in TensorBoard and flush the writer.

    Args:
        text: String content to record
        tag: Summary name the text is stored under
        step: Training step attached to the summary
        writer: TensorBoard summary writer to record into
    """
    with writer.as_default():
        tf.summary.text(name=tag, data=text, step=step)
    writer.flush()
|
|
|
|
|
|
|
|
def log_model_summary(
    model: Any,
    input_shape: tuple,
    logger: Optional[logging.Logger] = None,
    input_dtype: Any = jnp.int32,
) -> None:
    """
    Initialise ``model`` on a dummy input and report its parameter count.

    The summary is printed to stdout and, if given, sent to ``logger`` at
    INFO level.

    Args:
        model: Object exposing ``init(rng, inputs)`` that returns a pytree of
            arrays (presumably a Flax module -- the call shape matches
            ``flax.linen.Module.init``).
        input_shape: Shape of the dummy input passed to ``model.init``.
        logger: Optional logger that also receives the summary.
        input_dtype: dtype of the all-ones dummy input; defaults to
            ``jnp.int32`` (token ids), matching the previous behaviour.
    """
    dummy_input = jnp.ones(input_shape, dtype=input_dtype)

    # Fixed key: initialisation here only serves to discover parameter shapes.
    params = model.init(jax.random.PRNGKey(0), dummy_input)

    # np.prod(()) is the float 1.0, so a 0-d leaf would turn the total into a
    # float (printed as "1,234.0"); cast each product to int.
    param_count = sum(
        int(np.prod(np.shape(leaf))) for leaf in jax.tree_util.tree_leaves(params)
    )

    summary = (
        "Model summary:\n"
        f" Input shape: {input_shape}\n"
        f" Parameter count: {param_count:,}\n"
    )

    print(summary)

    if logger is not None:
        logger.info(summary)
|
|
|