"""
A number of functions that help with evaluating a base model.
"""

import math
import torch
import torch.distributed as dist


@torch.no_grad()
def evaluate_bpb(model, batches, steps, token_bytes):
"""
    Instead of the naive 'mean loss', this function returns the bits per byte (bpb),
    a metric that is independent of the tokenizer's vocab size, meaning you are still
    comparing apples:apples if you change the vocab size. The way this works is that
    instead of just calculating the average loss as usual, you calculate the sum loss,
    and independently also the sum of bytes (over all the target tokens), and divide.
    This normalizes the loss by the number of bytes that the target tokens represent.
The added complexity is so that:
1) All "normal" tokens are normalized by the length of the token in bytes
2) No special tokens (e.g. <|bos|>) are included in the metric - they are masked out.
3) No actively masked tokens (using ignore_index of e.g. -1) are included in the metric.
    Compared to evaluate_loss, this function additionally needs the token_bytes tensor:
    a 1D tensor of shape (vocab_size,), giving the number of bytes of
    each token id, or 0 if the token is not to be counted (e.g. special tokens).
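    Concretely, with L_i the per-token loss in nats and b_i the byte length of
    target token i (where b_i = 0 for special/ignored tokens), this computes:
        bpb = sum_i L_i * 1[b_i > 0] / (ln(2) * sum_i b_i)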
"""
    # accumulators: total loss in nats, and total bytes across all target tokens
total_nats = torch.tensor(0.0, dtype=torch.float32, device=model.get_device())
total_bytes = torch.tensor(0, dtype=torch.int64, device=model.get_device())
batch_iter = iter(batches)
for _ in range(steps):
x, y = next(batch_iter)
loss2d = model(x, y, loss_reduction='none') # (B, T)
loss2d = loss2d.view(-1) # flatten
y = y.view(-1) # flatten
        if (y < 0).any():
            # slightly more complex code path if some target tokens are ignore_index (e.g. -1)
            # any target token < 0 is to be ignored: do NOT index token_bytes with negatives
            valid = y >= 0
            y_safe = torch.where(valid, y, torch.zeros_like(y))
            # map valid targets to their byte length; ignored targets contribute 0 bytes
            num_bytes2d = torch.where(
                valid,
                token_bytes[y_safe],
                torch.zeros_like(y, dtype=token_bytes.dtype)
            )
        else:
            # fast path: no ignored targets, safe to index directly
            num_bytes2d = token_bytes[y]
        # accumulate nats only over tokens that count (nonzero byte length), plus the bytes
        total_nats += (loss2d * (num_bytes2d > 0)).sum()
        total_bytes += num_bytes2d.sum()
# sum reduce across all ranks
world_size = dist.get_world_size() if dist.is_initialized() else 1
if world_size > 1:
dist.all_reduce(total_nats, op=dist.ReduceOp.SUM)
dist.all_reduce(total_bytes, op=dist.ReduceOp.SUM)
# move both to cpu, calculate bpb and return
total_nats = total_nats.item()
total_bytes = total_bytes.item()
bpb = total_nats / (math.log(2) * total_bytes)
return bpb
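

# -----------------------------------------------------------------------------
# Tiny self-contained sanity check (an illustrative sketch, not part of the
# original module). It builds a fake "model" exposing the interface that
# evaluate_bpb expects (get_device() and model(x, y, loss_reduction='none')
# returning per-token losses of shape (B, T)) which always reports a loss of
# ln(2) nats per token. With a token_bytes table where every token is exactly
# 1 byte, the result must be exactly 1.0 bits per byte. In real use,
# token_bytes would instead be derived from the tokenizer: the UTF-8 byte
# length of each token id, with 0 for special tokens.

class _ConstLossModel:
    """Stand-in model returning a constant per-token loss of ln(2) nats."""

    def get_device(self):
        return "cpu"

    def __call__(self, x, y, loss_reduction='none'):
        assert loss_reduction == 'none'
        return torch.full(y.shape, math.log(2), dtype=torch.float32)


if __name__ == "__main__":
    B, T, steps, vocab_size = 2, 8, 4, 16
    # every token counts as 1 byte; a 0 entry would exclude that token id
    token_bytes = torch.ones(vocab_size, dtype=torch.int64)
    batches = (
        (torch.zeros(B, T, dtype=torch.int64), torch.randint(vocab_size, (B, T)))
        for _ in range(steps)
    )
    bpb = evaluate_bpb(_ConstLossModel(), batches, steps, token_bytes)
    print(f"bpb = {bpb:.4f}")  # expect 1.0000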