"""
A number of functions that help with evaluating a base model.
"""
import math
import torch
import torch.distributed as dist

@torch.no_grad()
def evaluate_bpb(model, batches, steps, token_bytes):
    """
    Instead of the naive 'mean loss', this function returns the bits per byte (bpb),
    which is a tokenization vocab size-independent metric, meaning you are still comparing
    apples:apples if you change the vocab size. The way this works is that instead of just
    calculating the average loss as usual, you calculate the sum loss, and independently
    also the sum bytes (of all the target tokens), and divide. This normalizes the loss by
    the number of bytes that the target tokens represent.

    The added complexity is so that:
    1) All "normal" tokens are normalized by the length of the token in bytes.
    2) No special tokens (e.g. <|bos|>) are included in the metric - they are masked out.
    3) No actively masked tokens (using an ignore_index of e.g. -1) are included in the metric.

    In addition to the inputs of evaluate_loss, we need the token_bytes tensor:
    a 1D tensor of shape (vocab_size,) giving the number of bytes of each
    token id, or 0 if the token is not to be counted (e.g. special tokens).
    """
    total_nats = torch.tensor(0.0, dtype=torch.float32, device=model.get_device())
    total_bytes = torch.tensor(0, dtype=torch.int64, device=model.get_device())
    batch_iter = iter(batches)
    for _ in range(steps):
        x, y = next(batch_iter)
        loss2d = model(x, y, loss_reduction='none')  # (B, T) per-token losses in nats
        loss2d = loss2d.view(-1)  # (B*T,)
        y = y.view(-1)  # (B*T,)
        if (y < 0).any():
            # slower path: some targets are masked out with a negative ignore_index,
            # so guard the token_bytes lookup and zero out the bytes of masked targets
            valid = y >= 0
            y_safe = torch.where(valid, y, torch.zeros_like(y))
            num_bytes2d = torch.where(
                valid,
                token_bytes[y_safe],
                torch.zeros_like(y, dtype=token_bytes.dtype),
            )
            # tokens with 0 bytes (special or masked) contribute neither nats nor bytes
            total_nats += (loss2d * (num_bytes2d > 0)).sum()
            total_bytes += num_bytes2d.sum()
        else:
            # fast path: all targets are valid token ids
            num_bytes2d = token_bytes[y]
            total_nats += (loss2d * (num_bytes2d > 0)).sum()
            total_bytes += num_bytes2d.sum()
    # sum-reduce the two counters across ranks if we are running distributed
    world_size = dist.get_world_size() if dist.is_initialized() else 1
    if world_size > 1:
        dist.all_reduce(total_nats, op=dist.ReduceOp.SUM)
        dist.all_reduce(total_bytes, op=dist.ReduceOp.SUM)
    # convert nats to bits (divide by ln 2) and normalize by the total byte count
    total_nats = total_nats.item()
    total_bytes = total_bytes.item()
    bpb = total_nats / (math.log(2) * total_bytes)
    return bpb
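
# -----------------------------------------------------------------------------
# Example: constructing the token_bytes tensor that evaluate_bpb expects.
# A minimal sketch, not part of the original module: it assumes a
# tiktoken-style tokenizer (with .n_vocab and .decode_single_token_bytes());
# the set of special token ids is whatever your vocabulary defines.

def build_token_bytes(enc, special_token_ids):
    """
    Return a (vocab_size,) int64 tensor with the byte length of each token.
    Special tokens are left at 0 so evaluate_bpb masks them out of the metric.
    """
    token_bytes = torch.zeros(enc.n_vocab, dtype=torch.int64)
    for token_id in range(enc.n_vocab):
        if token_id in special_token_ids:
            continue  # stays 0 bytes: excluded from both sums in evaluate_bpb
        token_bytes[token_id] = len(enc.decode_single_token_bytes(token_id))
    return token_bytes

# Hypothetical usage (model, val_batches, enc are placeholders for your own):
#   token_bytes = build_token_bytes(enc, {enc.eot_token}).to(model.get_device())
#   bpb = evaluate_bpb(model, val_batches, steps=100, token_bytes=token_bytes)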