Karpathy recently released the nanochat repo, which contains code for training the best ChatGPT you can get for under $100. While skimming the high-level code, I noticed it reports bits per byte (BPB) instead of the typical cross-entropy loss. I found that interesting, so I decided to dig in.

TL;DR

  • Bits per byte (BPB) is just cross-entropy measured per byte of text instead of per token. Cross-entropy is usually computed with the natural log (in nats), so we divide it by ln(2) to convert to bits.
  • Because it’s per byte, BPB is largely tokenizer-agnostic and lets you compare models fairly even when they use different vocabularies and tokenization rules.
  • Perplexity and token-level loss change when you change the tokenizer; BPB largely doesn’t.
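Here is a quick check of the conversion, using made-up numbers. The function names are hypothetical. The sketch assumes you already have three quantities for some text: the summed cross-entropy (in nats), the number of tokens, and the number of UTF-8 bytes.

```python
import math

def nats_to_bits(nats: float) -> float:
    """Convert an information quantity from nats (natural log) to bits (log base 2)."""
    return nats / math.log(2)

def bits_per_byte(total_nats: float, total_bytes: int) -> float:
    """BPB = total cross-entropy in bits divided by the number of UTF-8 bytes it covers."""
    return nats_to_bits(total_nats) / total_bytes

def per_token_perplexity(total_nats: float, num_tokens: int) -> float:
    """Perplexity = exp(mean per-token cross-entropy in nats)."""
    return math.exp(total_nats / num_tokens)

# Hypothetical text: 1,000 tokens covering 4,000 bytes, mean loss ln(16) nats (= 4 bits) per token.
total_nats = 1000 * math.log(16)
print(per_token_perplexity(total_nats, 1000))  # ≈ 16.0
print(bits_per_byte(total_nats, 4000))         # ≈ 1.0
```

At 4 bits per token and 4 bytes per token, the model spends 1 bit per byte. If a different tokenizer split the same text into a different number of tokens, the same total loss would give a different perplexity, but the BPB would not change.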

An LLM doesn’t predict text directly; it predicts the next token. But what counts as a token depends on the tokenizer (BPE, Unigram, merges, special tokens, etc.). Swap tokenizers and the same sentence can become more or fewer tokens. So per-token metrics (average cross-entropy, perplexity) change even if the underlying modeling quality doesn’t.
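You can see the token-count effect without any real tokenizer. Both tokenizers in this sketch are hypothetical stand-ins: one emits a token per character, the other a token per space-separated word. Neither is BPE or Unigram. They still show the key point: the token count changes, but the UTF-8 byte count of the text stays the same.

```python
def char_tokenize(text: str) -> list[str]:
    """Toy tokenizer (hypothetical): one token per Unicode character."""
    return list(text)

def word_tokenize(text: str) -> list[str]:
    """Toy tokenizer (hypothetical): split on spaces, attaching each space to the following word."""
    words = text.split(" ")
    return [words[0]] + [" " + w for w in words[1:]]

sentence = "the cat sat on the mat"
for name, tokenize in [("char", char_tokenize), ("word", word_tokenize)]:
    tokens = tokenize(sentence)
    # Both tokenizations concatenate back to the same text, so the byte count is fixed.
    assert "".join(tokens) == sentence
    n_bytes = sum(len(t.encode("utf-8")) for t in tokens)
    print(name, len(tokens), "tokens,", n_bytes, "bytes")
# char 22 tokens, 22 bytes
# word 6 tokens, 22 bytes
```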

Some popular tokenizer choices are:

Model        Tokenizer                 Vocab size
GPT-4        cl100k_base (BPE)         100,256
LLaMA 3      tiktoken (BPE)            128,000
Gemini 2.5   SentencePiece (Unigram)   256,000
Claude       closed-source             undisclosed

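Tokens from different tokenizers are not comparable units. For the same text and roughly the same total loss, a model with a finer tokenizer (more, shorter tokens) can appear to have a lower per-token loss or perplexity. A coarser tokenizer (fewer, longer tokens) can appear worse. In both cases the only thing that changed is the denominator.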
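Here is a small sketch of the denominator effect under an idealized assumption. Two models assign exactly the same total loss (in nats) to the same 22-byte text and differ only in how many tokens they split it into. The numbers are made up for illustration.

```python
import math

# Assumption for illustration: identical total loss on the same 22-byte text.
TOTAL_NATS = 30.0
N_BYTES = 22

def per_token_loss(total_nats: float, n_tokens: int) -> float:
    """Mean cross-entropy per token, in nats."""
    return total_nats / n_tokens

def bpb(total_nats: float, n_bytes: int) -> float:
    """Bits per byte: total loss converted to bits, divided by the byte count."""
    return total_nats / (math.log(2) * n_bytes)

for name, n_tokens in [("fine (22 tokens)", 22), ("coarse (6 tokens)", 6)]:
    loss = per_token_loss(TOTAL_NATS, n_tokens)
    # Per-token loss and perplexity depend on n_tokens; BPB depends only on bytes.
    print(f"{name}: loss/token={loss:.3f} nats, ppl={math.exp(loss):.2f}, "
          f"bpb={bpb(TOTAL_NATS, N_BYTES):.3f}")
```

The finer split looks much better per token (a perplexity of about 3.9 versus about 148). Both models report the same BPB.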

Instead of averaging the loss over tokens, normalize it by the number of UTF-8 bytes of text those tokens represent. Then, no matter how you split words into tokens, you’re still asking the same question: how many bits, on average, does the model need to encode each byte of text?
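Here is the per-byte bookkeeping on a tiny example. The token split and per-token losses are invented for illustration. Bytes are counted from the UTF-8 encoding of each token's text, so characters like ï and é count as 2 bytes each, not 1.

```python
import math

# Hypothetical token split and per-token losses (nats); not from any real model.
tokens = ["na", "ïve", " ca", "fé"]
losses_nats = [2.1, 3.4, 1.2, 2.9]

text = "".join(tokens)
n_chars = len(text)
# ï (U+00EF) and é (U+00E9) each encode to 2 bytes in UTF-8.
n_bytes = sum(len(t.encode("utf-8")) for t in tokens)

total_bits = sum(losses_nats) / math.log(2)
print(n_chars, n_bytes)      # 10 12
print(total_bits / n_bytes)  # ≈ 1.154 bits per byte
```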

Below is a simplified, more readable version of the original nanochat code.

```python
import math
import torch
import torch.distributed as dist

@torch.no_grad()
def evaluate_bpb(model, batches, steps: int, token_bytes: torch.Tensor) -> float:
    """
    Compute Bits-Per-Byte (BPB) over `steps` batches.

    Shapes (your mental model):
      B  = batch size
      Seq = sequence length
      V  = vocab size

    Inputs:
      - model: callable like model(x, y, loss_reduction='none') -> loss per token.
               Expects:
                 x: (B, Seq) token ids (int64)
                 y: (B, Seq) target token ids (int64), may contain ignore_index (<0)
               Returns:
                 loss2d: (B, Seq) per-token loss in nats (float32/float16)
      - batches: iterable yielding (x, y) as above.
      - steps: number of batches to evaluate.
      - token_bytes: (V,) int64 — byte length of each token id; 0 for special tokens
                     (those should not count toward BPB).

    Notes:
      - BPB = (sum of losses in nats over *counted* tokens) / (ln(2) * total_counted_bytes)
      - Tokens contribute to the denominator by their byte length; tokens with 0 bytes
        (specials) and ignored targets (<0) are excluded from both numerator & denominator.
    """
    device = model.get_device() if hasattr(model, "get_device") else next(model.parameters()).device

    # Accumulators across steps (and later across ranks)
    sum_nats  = torch.tensor(0.0, dtype=torch.float32, device=device)  # scalar
    sum_bytes = torch.tensor(0,   dtype=torch.int64,   device=device)  # scalar

    token_bytes = token_bytes.to(device=device, dtype=torch.int64)     # (V,)

    batch_iter = iter(batches)
    for _ in range(steps):
        x, y = next(batch_iter)                  # x: (B, Seq), y: (B, Seq)
        x = x.to(device)
        y = y.to(device)

        loss2d = model(x, y, loss_reduction='none')  # (B, Seq) nats
        loss1d = loss2d.reshape(-1).float()          # (B*Seq,) upcast so fp16 losses are summed in fp32
        y1d    = y.reshape(-1)                       # (B*Seq,)

        if (y1d < 0).any():
            # Mask out ignore_index (<0) before indexing into token_bytes
            valid  = (y1d >= 0)                                      # (B*Seq,)
            ysafe  = torch.where(valid, y1d, torch.zeros_like(y1d))  # (B*Seq,)
            nb     = torch.where(valid, token_bytes[ysafe], torch.zeros_like(y1d))  # (B*Seq,) int64
        else:
            nb = token_bytes[y1d]  # (B*Seq,) int64

        # Count only tokens with positive byte length
        counted = (nb > 0)                             # (B*Seq,) bool
        sum_nats  += (loss1d[counted]).sum()           # scalar
        sum_bytes += nb[counted].sum()                 # scalar int64

    # Distributed sum over all ranks, if initialized
    if dist.is_initialized() and dist.get_world_size() > 1:
        dist.all_reduce(sum_nats,  op=dist.ReduceOp.SUM)
        dist.all_reduce(sum_bytes, op=dist.ReduceOp.SUM)

    total_nats  = float(sum_nats.item())
    total_bytes = int(sum_bytes.item())

    # Guard against division by zero (e.g., all tokens were special/ignored)
    if total_bytes == 0:
        return float("nan")

    bpb = total_nats / (math.log(2.0) * total_bytes)
    return bpb
```
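To cross-check the masking rules in `evaluate_bpb` on tiny inputs, you can mirror the same accounting for one flattened batch in pure Python. Several pieces of this sketch are assumptions for illustration and are not nanochat's actual helpers: `build_token_bytes`, `reference_bpb`, the toy vocabulary, and the use of -1 as the ignore index.

```python
import math

IGNORE_INDEX = -1  # assumption: any negative target id is ignored, as in the docstring above

def build_token_bytes(vocab: dict[int, bytes], special_ids: set[int]) -> list[int]:
    """Hypothetical helper: UTF-8 byte length per token id, 0 for special tokens."""
    out = [0] * (max(vocab) + 1)
    for tid, raw in vocab.items():
        out[tid] = 0 if tid in special_ids else len(raw)
    return out

def reference_bpb(losses_nats: list[float], targets: list[int], token_bytes: list[int]) -> float:
    """Pure-Python mirror of evaluate_bpb's masking for a single flattened batch."""
    total_nats, total_bytes = 0.0, 0
    for loss, tid in zip(losses_nats, targets):
        if tid < 0:
            continue  # ignored target: excluded from numerator and denominator
        nb = token_bytes[tid]
        if nb == 0:
            continue  # special token: no bytes, so its loss is dropped too
        total_nats += loss
        total_bytes += nb
    if total_bytes == 0:
        return float("nan")  # same guard as the tensor version
    return total_nats / (math.log(2) * total_bytes)

# Toy vocabulary (hypothetical): id 0 is a special token.
vocab = {0: b"<|bos|>", 1: b"the", 2: b" cat", 3: b" sat"}
token_bytes = build_token_bytes(vocab, special_ids={0})
print(token_bytes)  # [0, 3, 4, 4]

targets = [0, 1, 2, 3, IGNORE_INDEX]
losses = [5.0, 2.0, 3.0, 1.5, 9.9]
print(reference_bpb(losses, targets, token_bytes))  # ≈ 0.8525
```

Both the special token's loss (5.0) and the ignored target's loss (9.9) are dropped, which leaves 6.5 nats over 11 bytes. Running the same losses and targets through `evaluate_bpb` with a stub model should produce the same value.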