File size: 5,130 Bytes
f24563f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
"""
Logging utilities for LLM training.
"""

import os
import time
import json
import logging
from typing import Dict, Any, Optional, List
import numpy as np
import jax
import jax.numpy as jnp
import tensorflow as tf


def setup_logger(
    name: str,
    log_file: Optional[str] = None,
    level: int = logging.INFO
) -> logging.Logger:
    """
    Set up logger.
    
    Args:
        name: Logger name
        log_file: Path to log file
        level: Logging level
        
    Returns:
        Logger
    """
    # Create logger
    logger = logging.getLogger(name)
    logger.setLevel(level)
    
    # Create formatter
    formatter = logging.Formatter(
        "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
    )
    
    # Create console handler
    console_handler = logging.StreamHandler()
    console_handler.setLevel(level)
    console_handler.setFormatter(formatter)
    logger.addHandler(console_handler)
    
    # Create file handler if log file is provided
    if log_file is not None:
        # Create directory if it doesn't exist
        os.makedirs(os.path.dirname(log_file), exist_ok=True)
        
        # Create file handler
        file_handler = logging.FileHandler(log_file)
        file_handler.setLevel(level)
        file_handler.setFormatter(formatter)
        logger.addHandler(file_handler)
    
    return logger


def log_metrics(
    metrics: Dict[str, Any],
    step: int,
    logger: Optional[logging.Logger] = None,
    prefix: str = "",
    log_to_console: bool = True
) -> None:
    """
    Log metrics.
    
    Args:
        metrics: Dictionary of metrics
        step: Training step
        logger: Logger
        prefix: Prefix for metric names
        log_to_console: Whether to log to console
    """
    # Add prefix to metric names
    if prefix:
        metrics = {f"{prefix}/{k}": v for k, v in metrics.items()}
    
    # Convert metrics to Python types
    metrics = {
        k: float(v) if isinstance(v, (np.ndarray, jnp.ndarray)) else v
        for k, v in metrics.items()
    }
    
    # Log to console
    if log_to_console:
        print(f"Step {step}:")
        for k, v in metrics.items():
            if isinstance(v, float):
                print(f"  {k}: {v:.4f}")
            else:
                print(f"  {k}: {v}")
    
    # Log to logger
    if logger is not None:
        logger.info(f"Step {step}: {metrics}")


def create_summary_writer(log_dir: str) -> tf.summary.SummaryWriter:
    """
    Create TensorBoard summary writer.
    
    Args:
        log_dir: Directory for TensorBoard logs
        
    Returns:
        TensorBoard summary writer
    """
    # Create directory if it doesn't exist
    os.makedirs(log_dir, exist_ok=True)
    
    # Create summary writer
    return tf.summary.create_file_writer(log_dir)


def log_metrics_to_tensorboard(
    metrics: Dict[str, Any],
    step: int,
    writer: tf.summary.SummaryWriter,
    prefix: str = ""
) -> None:
    """
    Log metrics to TensorBoard.
    
    Args:
        metrics: Dictionary of metrics
        step: Training step
        writer: TensorBoard summary writer
        prefix: Prefix for metric names
    """
    # Add prefix to metric names
    if prefix:
        metrics = {f"{prefix}/{k}": v for k, v in metrics.items()}
    
    # Convert metrics to Python types
    metrics = {
        k: float(v) if isinstance(v, (np.ndarray, jnp.ndarray)) else v
        for k, v in metrics.items()
    }
    
    # Log metrics to TensorBoard
    with writer.as_default():
        for k, v in metrics.items():
            if isinstance(v, float):
                tf.summary.scalar(k, v, step=step)
            elif isinstance(v, (list, tuple)) and all(isinstance(x, float) for x in v):
                tf.summary.histogram(k, v, step=step)
    
    # Flush writer
    writer.flush()


def log_text_to_tensorboard(
    text: str,
    tag: str,
    step: int,
    writer: tf.summary.SummaryWriter
) -> None:
    """
    Log text to TensorBoard.
    
    Args:
        text: Text to log
        tag: Tag for text
        step: Training step
        writer: TensorBoard summary writer
    """
    # Log text to TensorBoard
    with writer.as_default():
        tf.summary.text(tag, text, step=step)
    
    # Flush writer
    writer.flush()


def log_model_summary(
    model: Any,
    input_shape: tuple,
    logger: Optional[logging.Logger] = None
) -> None:
    """
    Log model summary.
    
    Args:
        model: Model
        input_shape: Input shape
        logger: Logger
    """
    # Create dummy input
    dummy_input = jnp.ones(input_shape, dtype=jnp.int32)
    
    # Initialize model
    params = model.init(jax.random.PRNGKey(0), dummy_input)
    
    # Count parameters
    param_count = sum(
        np.prod(p.shape) for p in jax.tree_util.tree_leaves(params)
    )
    
    # Log model summary
    summary = f"Model summary:\n"
    summary += f"  Input shape: {input_shape}\n"
    summary += f"  Parameter count: {param_count:,}\n"
    
    # Log to console
    print(summary)
    
    # Log to logger
    if logger is not None:
        logger.info(summary)