How to build an efficient BPE tokenizer in Python

June 20, 2025 | 25 min read

Tokenization, NLP, Language Modelling

Over the last few weeks, I tried to dig deeply into tokenization and learn how to build a BPE tokenizer from the ground up. Here I document both my understanding of tokenization and how to implement a performant version.

I like to think of tokenization as a trade-off between computational efficiency and learnability.


Efficiency vs. Meaning

At its core, tokenization means building a vocabulary (a look-up table) that maps chunks of text to integer IDs. The choice of what counts as a "chunk" (the granularity of the text units) determines where we land on this trade-off.

Efficiency

Transformer models, the architecture behind most LLMs, rely on self-attention, whose cost scales quadratically with the sequence length $n$, i.e. $O(n^2)$.

This puts a huge premium on representing text with the fewest tokens possible. We measure this using the Bytes per Token ratio:

$$\text{Bytes per Token} = \frac{\text{Number of Bytes in Text}}{\text{Number of Tokens}}$$

A higher ratio means more compression and shorter sequences, leading to faster training, quicker inference, and the ability to handle longer contexts.
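To connect the ratio to the quadratic cost, suppose a text of B bytes is encoded at r bytes per token. The sequence length and the attention cost then follow directly (ignoring constant factors and the parts of the model that scale linearly):

```latex
n = \frac{B}{r}, \qquad \text{attention cost} \propto n^2 = \frac{B^2}{r^2}
```

So doubling the bytes-per-token ratio halves the number of tokens and cuts the quadratic attention term by a factor of four.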

Let's consider an example sentence: "The apparent meaning is for the masses, and the inner meaning is for the learned." (81 bytes)

  • Word-level tokenization is highly efficient. Splitting on whitespace gives 15 tokens, a ratio of 5.4 Bytes/Token. But it's brittle. The vocabulary explodes with morphological variants ("run", "runs", "running"), and it fails completely on any word not seen during training (Out-of-Vocabulary, or OOV), forcing us to use a generic <UNK> token that throws away information.
  • Byte-level tokenization solves the OOV problem completely: any text can be represented as a sequence drawn from the 256 possible byte values. However, it's very inefficient. Our 81-byte sentence becomes 81 tokens, a ratio of 1.0 Bytes/Token. That is no compression at all, leading to painfully long sequences.
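These ratios are easy to check directly. The sketch below is a rough stand-in, not a real word-level tokenizer: it splits on whitespace (so punctuation stays attached to words) and treats every UTF-8 byte as one token for the byte-level case. bytes_per_token is a helper name introduced here for illustration.

```python
sentence = "The apparent meaning is for the masses, and the inner meaning is for the learned."

def bytes_per_token(text, tokens):
    """Bytes per Token = UTF-8 bytes in the text / number of tokens."""
    return len(text.encode("utf-8")) / len(tokens)

word_tokens = sentence.split()                # crude word-level split on whitespace
byte_tokens = list(sentence.encode("utf-8"))  # one token per byte

print(len(sentence.encode("utf-8")), len(word_tokens), len(byte_tokens))
print(round(bytes_per_token(sentence, word_tokens), 2), bytes_per_token(sentence, byte_tokens))
```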

We need a middle ground: great compression without the OOV problem. This is where subword tokenization comes in.

The Learning Problem

Beyond efficiency, tokenization injects a powerful inductive bias into the model.

  • Byte-level tokenization offers a weak inductive bias. We're telling the model, "Here's a long stream of numbers; figure out from scratch that (100, 111, 103) represents the concept of a 'dog'." This requires a massive amount of data and computation for the model to learn basic linguistic structures.

  • Subword tokenization provides a stronger, more helpful bias. By pre-chunking text into statistically common units (like token and ization), we give the model a head start. We're hinting that these chunks are meaningful building blocks. It’s analogous to how a Vision Transformer first breaks an image into patches instead of processing one pixel at a time.
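To make the byte-level view concrete, here is what the model actually receives for a couple of short strings. Non-ASCII characters make sequences even longer, since UTF-8 spends several bytes on them.

```python
# Byte-level view: each UTF-8 byte becomes one token ID in the range 0-255.
dog_ids = list("dog".encode("utf-8"))
cafe_ids = list("café".encode("utf-8"))

print(dog_ids)   # [100, 111, 103]
print(cafe_ids)  # [99, 97, 102, 195, 169]: "é" alone takes two bytes
```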

Both efficiency and learnability point to subword tokenization as the superior strategy. The most popular algorithm for achieving this is Byte-Pair Encoding (BPE).


Byte-Pair Encoding (BPE): A Naive Implementation

BPE is a simple, greedy algorithm. It starts with the most basic units (bytes) and iteratively merges the most frequent adjacent pairs to build a vocabulary of meaningful subwords. Let's build a BPE tokenizer from the ground up to see exactly how it works.

Step 1: Vocabulary Initialization

We start with a base vocabulary of all 256 possible byte values. This ensures 100% coverage: any text can be represented.

vocab = {i : bytes([i]) for i in range(256)}
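A quick sanity check of the coverage claim: because every byte value has its own entry, any UTF-8 string can be encoded with this base vocabulary and decoded back without loss. The sample string is arbitrary.

```python
# Base vocabulary: one token per possible byte value.
vocab = {i: bytes([i]) for i in range(256)}

text = "naïve café, 東京 🤖"               # arbitrary text with multi-byte characters
ids = list(text.encode("utf-8"))          # every byte is already a valid token ID
round_trip = b"".join(vocab[i] for i in ids).decode("utf-8")

print(round_trip == text)  # True
```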

Step 2: Special Tokens

Next, we inject special tokens like <|endoftext|>. These are treated as indivisible atomic units and are excluded from the merge process to preserve their role as control signals.

for i, token in enumerate(special_tokens):
    vocab[256 + i] = token.encode("utf-8")
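With a single special token, the ID layout looks like this (the one-element special_tokens list is just an example). The first learned merge then gets ID 257, which matches the ID used for ('t', 'h') in the optimization section.

```python
vocab = {i: bytes([i]) for i in range(256)}
special_tokens = ["<|endoftext|>"]   # example list with one special token

for i, token in enumerate(special_tokens):
    vocab[256 + i] = token.encode("utf-8")

first_merge_id = len(vocab)          # learned merges get IDs after bytes and special tokens
print(vocab[256], first_merge_id)    # b'<|endoftext|>' 257
```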

Step 3: Pre-tokenization

Feeding raw text directly to the BPE algorithm is a mistake. To guide the algorithm, we first perform pre-tokenization, splitting the input corpus into "word-like" units using a regular expression. This ensures merges respect linguistic boundaries.

Pre-tokenization is often ignored in discussions about BPE, but it plays a crucial role in shaping what the algorithm sees. A simple difference in the regex can lead to substantive differences in the final vocabulary.

import regex as re  # third-party "regex" package: the stdlib re module has no \p{L} / \p{N} support

def pre_tokenization(training_data, special_tokens):
    # GPT-2's regex pattern for splitting text into initial chunks.
    PAT = r"""'(?:[sdmt]|ll|ve|re)| ?\p{N}+| ?\p{L}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""

    if not special_tokens:
        return [m.group() for m in re.finditer(PAT, training_data)]

    # Split the corpus by special tokens to prevent merges across them.
    # Sort longest first so that a special token containing another one matches whole.
    escaped_list = [re.escape(tok) for tok in sorted(special_tokens, key=len, reverse=True)]
    split_PAT = r"({})".format("|".join(escaped_list))
    split_corpus = re.split(split_PAT, training_data)

    pretokenized_train_data = []
    for segment in split_corpus:
        if not segment: continue
        # If the segment is a special token itself, add it directly.
        if segment in special_tokens:
            pretokenized_train_data.append(segment)
        else:
            # Otherwise, apply the main regex pattern.
            for m in re.finditer(PAT, segment):
                pretokenized_train_data.append(m.group())

    return pretokenized_train_data
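The capturing group in split_PAT is what keeps each special token in the output of re.split; without it, the delimiters are discarded. This part of the function does not depend on the \p{...} classes, so the standard-library re module is enough to show it. The corpus string is made up.

```python
import re

special_tokens = ["<|endoftext|>"]
corpus = "Hello world.<|endoftext|>Second doc"

escaped = [re.escape(tok) for tok in special_tokens]
split_pat = "({})".format("|".join(escaped))   # with a capturing group, as in split_PAT
no_group_pat = "|".join(escaped)               # same alternatives, no group

with_group = re.split(split_pat, corpus)
without_group = re.split(no_group_pat, corpus)
edge = re.split(split_pat, "<|endoftext|>Second doc")

print(with_group)     # ['Hello world.', '<|endoftext|>', 'Second doc']
print(without_group)  # ['Hello world.', 'Second doc']
print(edge)           # ['', '<|endoftext|>', 'Second doc']
```

The leading empty string in the last case is why the loop skips empty segments.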

Step 4: Iterative Merging

With the pre-tokenized corpus, the core BPE process begins. We first count the occurrences of each pre-tokenized "word" and represent them as tuples of bytes.

from collections import Counter

merges = []
pretokenized_train_data = pre_tokenization(training_data, special_tokens)
# Count initial "word" frequencies
word_counts = Counter(pretokenized_train_data)

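# Convert each word to a tuple of byte values (ints 0-255) and store its count.
# Special tokens are atomic, so they are kept out of the pair statistics.
special_set = set(special_tokens or [])
constructed_vocab = {}
for word, count in word_counts.items():
    if word in special_set:
        continue
    constructed_vocab[tuple(word.encode("utf-8"))] = count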
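A toy version of this step, using a made-up list of pre-tokens, shows the shape of constructed_vocab: keys are tuples of byte values (ints), and "the" and " the" are separate entries because the GPT-2 pattern keeps the leading space attached to the word.

```python
from collections import Counter

# Made-up output of pre-tokenization; note the leading spaces.
pretokens = ["the", " cat", " sat", " the", " cat", " the"]
word_counts = Counter(pretokens)

constructed_vocab = {tuple(w.encode("utf-8")): c for w, c in word_counts.items()}
for word_tuple, count in constructed_vocab.items():
    print(word_tuple, count)
```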

Now, we iteratively find the most frequent pair and merge it.

# vocab_size is a hyperparameter: the target size of the final vocabulary.
token_id = len(vocab)  # First ID for new merged tokens (after bytes and special tokens)

while token_id < vocab_size:
    best_pair = find_best_pair(constructed_vocab, vocab)
    if best_pair is None:  # No adjacent pairs left: nothing more to merge.
        break
    merge_best_pair(constructed_vocab, best_pair, token_id)

    # Bookkeeping: record the merge and the bytes the new token stands for.
    merges.append(best_pair)
    vocab[token_id] = vocab[best_pair[0]] + vocab[best_pair[1]]
    token_id += 1

The heavy lifting happens in helper functions. find_best_pair scans all words and counts every adjacent pair, weighted by the word's frequency.

def find_best_pair(constructed_vocab, vocab):
    potential_merges_count = {}
    for word_tuple, count in constructed_vocab.items():
        for i in range(len(word_tuple) - 1):
            key = (word_tuple[i], word_tuple[i+1])
            potential_merges_count[key] = potential_merges_count.get(key, 0) + count

    if not potential_merges_count:
        return None

    # Tie-breaking rule: if frequencies are equal, choose the lexicographically
    # greatest pair, comparing the tokens' byte strings (via vocab), not their IDs.
    max_freq = max(potential_merges_count.values())
    candidate_pairs = [k for k, v in potential_merges_count.items() if v == max_freq]
    best_pair = max(candidate_pairs, key=lambda pair: (vocab[pair[0]], vocab[pair[1]]))
    return best_pair

Gotcha: Tie-breaking isn't just an edge case. A consistent rule (here, preferring the lexicographically greatest pair when comparing the two tokens' byte strings) is essential for exactly reproducing reference tokenizers like tiktoken. This subtle detail took me longer than I care to admit to get right.
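The difference between comparing IDs and comparing bytes is easy to trigger. In this toy case, ID 300 is a hypothetical learned token standing for b" t", and the two candidate pairs are assumed to be tied on frequency:

```python
# Toy vocabulary: the 256 byte tokens plus one hypothetical learned token, ID 300 = b" t".
vocab = {i: bytes([i]) for i in range(256)}
vocab[300] = b" t"

# Two pairs assumed to be tied at the same (maximum) frequency.
candidate_pairs = [(300, 104), (116, 104)]   # (" t", "h") and ("t", "h")

by_id = max(candidate_pairs)  # compares integer IDs
by_bytes = max(candidate_pairs, key=lambda pair: (vocab[pair[0]], vocab[pair[1]]))

print(by_id)     # (300, 104)
print(by_bytes)  # (116, 104), because b"t" > b" t" (byte 116 vs. the space, byte 32)
```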

After finding the best pair, merge_best_pair replaces it in every word of constructed_vocab, which means scanning every single word again.

# Helper to merge a pair within a single word tuple
def merge_in_word(word_tuple, best_pair, new_token_id):
    new_word_list = []
    k = 0
    while k < len(word_tuple):
        if k < len(word_tuple) - 1 and (word_tuple[k], word_tuple[k+1]) == best_pair:
            new_word_list.append(new_token_id)
            k += 2
        else:
            new_word_list.append(word_tuple[k])
            k += 1
    return tuple(new_word_list)

# Main function to apply the merge across the whole vocabulary
def merge_best_pair(constructed_vocab, best_pair, new_token_id):
    to_add = {}
    to_delete = []
    for word_tuple, count in constructed_vocab.items():
        # A simple check to see if the pair exists in the word
        if any(word_tuple[i:i+2] == best_pair for i in range(len(word_tuple) - 1)):
            new_word_tuple = merge_in_word(word_tuple, best_pair, new_token_id)
            to_add[new_word_tuple] = to_add.get(new_word_tuple, 0) + count
            to_delete.append(word_tuple)

    for word in to_delete:
        del constructed_vocab[word]
    constructed_vocab.update(to_add)
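A quick check of merge_in_word, re-declared so the snippet runs on its own (ID 257 is just an example value). The second call shows a subtlety: find_best_pair counts the overlapping pair (97, 97) twice in "aaa", but the left-to-right scan can only merge it once.

```python
def merge_in_word(word_tuple, best_pair, new_token_id):
    # Same left-to-right scan as above: each match consumes two positions.
    new_word_list = []
    k = 0
    while k < len(word_tuple):
        if k < len(word_tuple) - 1 and (word_tuple[k], word_tuple[k + 1]) == best_pair:
            new_word_list.append(new_token_id)
            k += 2
        else:
            new_word_list.append(word_tuple[k])
            k += 1
    return tuple(new_word_list)

the = tuple("the".encode("utf-8"))   # (116, 104, 101)
aaa = tuple("aaa".encode("utf-8"))   # (97, 97, 97)

print(merge_in_word(the, (116, 104), 257))  # (257, 101)
print(merge_in_word(aaa, (97, 97), 257))    # (257, 97)
```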

This naive implementation works, but it is slow: every merge rescans every word twice, once in find_best_pair and once in merge_best_pair. A run on a tiny 65 KB text file takes ~3 seconds. To scale, we need to optimize.


Optimizing BPE

Profiling reveals two major bottlenecks: the repeated scanning in the merge loop and the single-threaded pre-tokenization.

Bottleneck 1: Optimizing the Merge Loop

In our naive approach, finding the best_pair and updating the corpus requires iterating over our entire word count dictionary again and again.

The Fix: Instead of re-scanning, we use smarter data structures. We'll maintain a dictionary of potential_merges (counts of all pairs) and another, bigram_locations, that maps each pair to the set of words it appears in.

When we merge ('t', 'h') -> 257 (let's call it 'th'), we know the only counts that can change are those involving pairs immediately adjacent to it. For a word like ('t', 'h', 'e'):

  • The count for ('t', 'h') decreases.
  • The count for ('h', 'e') decreases.
  • A new count for (257, 'e') is created and increases.
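The three bullet points can be checked mechanically: subtract the pair counts of the old word from those of the new word, and only pairs adjacent to the merge show a non-zero change. pair_counts is a helper defined just for this sketch, and the frequency of 10 is made up.

```python
from collections import Counter

def pair_counts(word_tuple, count):
    """Adjacent-pair counts contributed by a word that occurs `count` times."""
    counts = Counter()
    for pair in zip(word_tuple, word_tuple[1:]):
        counts[pair] += count
    return counts

TH = 257                                 # ID of the new "th" token, as in the text
old_word = tuple("the".encode("utf-8"))  # (116, 104, 101)
new_word = (TH, ord("e"))                # (257, 101)
freq = 10                                # made-up frequency of "the"

delta = pair_counts(new_word, freq)
delta.subtract(pair_counts(old_word, freq))
print(dict(delta))  # {(257, 101): 10, (116, 104): -10, (104, 101): -10}
```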

By only updating the stats for the words directly affected by the merge (which we can look up instantly in bigram_locations), we avoid redundant work.

The logic becomes:

  1. Get the list of all words affected by the best_pair from bigram_locations.
  2. For each affected word:
     a. Decrement the counts of all pairs in the old version of the word.
     b. Create the new version of the word by merging the pair.
     c. Increment the counts of all pairs in the new version.
     d. Replace the old word with the new one in constructed_vocab and bigram_locations.

Here are the key helper functions for this efficient approach:

from collections import defaultdict

# ... (after pre-tokenization and initial word counts) ...
potential_merges = defaultdict(int)
bigram_locations = defaultdict(set)

for word_tuple, count in constructed_vocab.items():
    for i in range(len(word_tuple) - 1):
        bigram = (word_tuple[i], word_tuple[i + 1])
        potential_merges[bigram] += count
        bigram_locations[bigram].add(word_tuple)

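def decrement_counts(word_tuple, potential_merges, bigram_locations, count):
    """Remove every adjacent pair of word_tuple (weighted by count) from the stats."""
    for j in range(len(word_tuple) - 1):
        bigram = (word_tuple[j], word_tuple[j + 1])
        potential_merges[bigram] -= count

        if bigram in bigram_locations and word_tuple in bigram_locations[bigram]:
            bigram_locations[bigram].discard(word_tuple)
            if not bigram_locations[bigram]:
                del bigram_locations[bigram]

        if potential_merges.get(bigram, 0) <= 0:
            potential_merges.pop(bigram, None)

def increment_counts(new_word_tuple, potential_merges, bigram_locations, count):
    """Add every adjacent pair of new_word_tuple (weighted by count) to the stats."""
    for j in range(len(new_word_tuple) - 1):
        bigram = (new_word_tuple[j], new_word_tuple[j + 1])
        potential_merges[bigram] += count
        bigram_locations[bigram].add(new_word_tuple)

# --- Main loop ---
token_id = len(vocab)
while token_id < vocab_size and potential_merges:
    # Most frequent pair; ties go to the lexicographically greatest byte pair.
    best_pair = max(potential_merges,
                    key=lambda p: (potential_merges[p], vocab[p[0]], vocab[p[1]]))

    # Only the words that contain best_pair need to be touched.
    for word_tuple in list(bigram_locations[best_pair]):
        if word_tuple not in constructed_vocab:
            continue
        count = constructed_vocab.pop(word_tuple)
        decrement_counts(word_tuple, potential_merges, bigram_locations, count)
        new_word_tuple = merge_in_word(word_tuple, best_pair, token_id)
        increment_counts(new_word_tuple, potential_merges, bigram_locations, count)
        constructed_vocab[new_word_tuple] = constructed_vocab.get(new_word_tuple, 0) + count

    # Bookkeeping, as in the naive version.
    merges.append(best_pair)
    vocab[token_id] = vocab[best_pair[0]] + vocab[best_pair[1]]
    token_id += 1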

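This bookkeeping is more complex, but each merge now updates only the words that contain the chosen pair instead of rescanning the whole corpus. Picking the best pair still takes a pass over potential_merges, but that is a single scan of the pair counts rather than of every word.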
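One way to gain confidence in this bookkeeping is to test it against the naive recount. The sketch below is a compact, self-contained re-implementation, not the code above verbatim: train_naive, train_incremental, and pick_best are hypothetical names, there are no special tokens (new IDs start at 256), and the incremental version asserts after every merge that its pair counts match a full recount. That assertion makes it slow, so it is only meant as a debugging aid.

```python
from collections import Counter, defaultdict

def merge_in_word(word, pair, new_id):
    """Replace non-overlapping occurrences of pair, scanning left to right."""
    out, k = [], 0
    while k < len(word):
        if k + 1 < len(word) and (word[k], word[k + 1]) == pair:
            out.append(new_id)
            k += 2
        else:
            out.append(word[k])
            k += 1
    return tuple(out)

def full_pair_counts(words):
    """Reference statistics: recount every adjacent pair from scratch."""
    counts = Counter()
    for word, freq in words.items():
        for pair in zip(word, word[1:]):
            counts[pair] += freq
    return counts

def pick_best(counts, vocab):
    """Most frequent pair; ties go to the lexicographically greatest byte pair."""
    return max(counts, key=lambda p: (counts[p], vocab[p[0]], vocab[p[1]]))

def train_naive(words, num_merges):
    vocab = {i: bytes([i]) for i in range(256)}
    words, merges = dict(words), []
    for new_id in range(256, 256 + num_merges):
        counts = full_pair_counts(words)
        if not counts:
            break
        best = pick_best(counts, vocab)
        merged = Counter()
        for word, freq in words.items():
            merged[merge_in_word(word, best, new_id)] += freq
        words = dict(merged)
        merges.append(best)
        vocab[new_id] = vocab[best[0]] + vocab[best[1]]
    return merges

def train_incremental(words, num_merges):
    vocab = {i: bytes([i]) for i in range(256)}
    words, merges = dict(words), []
    pair_counts = full_pair_counts(words)
    locations = defaultdict(set)  # pair -> words that contain it
    for word in words:
        for pair in zip(word, word[1:]):
            locations[pair].add(word)
    for new_id in range(256, 256 + num_merges):
        pair_counts = +pair_counts  # drop pairs whose count fell to zero
        if not pair_counts:
            break
        best = pick_best(pair_counts, vocab)
        for word in locations.pop(best):  # only the affected words
            freq = words.pop(word)
            for pair in zip(word, word[1:]):  # remove the old word's pairs
                pair_counts[pair] -= freq
                locations[pair].discard(word)
            new_word = merge_in_word(word, best, new_id)
            words[new_word] = words.get(new_word, 0) + freq
            for pair in zip(new_word, new_word[1:]):  # add the new word's pairs
                pair_counts[pair] += freq
                locations[pair].add(new_word)
        # Debug check (slow): incremental counts must equal a full recount.
        assert +pair_counts == full_pair_counts(words)
        merges.append(best)
        vocab[new_id] = vocab[best[0]] + vocab[best[1]]
    return merges
```

On a small toy corpus both versions should produce the same merge list; if they ever diverge, the first failing assertion points at the merge where the incremental counts drifted.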

Bottleneck 2: Parallelizing Pre-tokenization

For large datasets, just reading the file and running the regex is a significant, single-threaded bottleneck.

The Fix: This task is "embarrassingly parallel." We can split our large text file into multiple chunks and process them across all available CPU cores using Python's multiprocessing.

The trick is splitting the file cleanly. We can't just cut at arbitrary byte offsets, because we might slice a multi-byte character in half. A robust solution is to find "clean seams", such as the special token b"<|endoftext|>", and split the file there.
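Here is what a careless cut does to UTF-8 text. With errors="ignore", as process_chunk uses, the damage is silent: the partial character is simply dropped.

```python
data = "naïve café".encode("utf-8")   # "ï" and "é" each take two bytes

cut = data[:3]                        # ends halfway through "ï" (bytes 0xC3 0xAF)
try:
    cut.decode("utf-8")
    strict_ok = True
except UnicodeDecodeError:
    strict_ok = False

lenient = cut.decode("utf-8", errors="ignore")
print(strict_ok, repr(lenient))       # False 'na'
```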

import os
from typing import BinaryIO

def find_chunk_boundaries(file: BinaryIO, desired_num_chunks: int, split_special_token: bytes) -> list[int]:
    file.seek(0, os.SEEK_END)
    file_size = file.tell()
    file.seek(0)

    chunk_size = file_size // desired_num_chunks
    chunk_boundaries = [i * chunk_size for i in range(desired_num_chunks + 1)]
    chunk_boundaries[-1] = file_size

    mini_chunk_size = 4096

    for i in range(1, len(chunk_boundaries) - 1):
        initial_position = chunk_boundaries[i]
        file.seek(initial_position)
        while True:
            mini_chunk = file.read(mini_chunk_size)
            if mini_chunk == b"":
                chunk_boundaries[i] = file_size
                break

            found_at = mini_chunk.find(split_special_token)
            if found_at != -1:
                chunk_boundaries[i] = initial_position + found_at
                break
            initial_position += mini_chunk_size

    return sorted(list(set(chunk_boundaries)))

With these boundaries, we can spin up a pool of processes to handle the pre-tokenization and counting in parallel.

import multiprocessing as mp
import os
from collections import Counter

def process_chunk(file_path, start, end, special_tokens):
    # Each worker opens the file itself and reads only its own byte range.
    with open(file_path, "rb") as f:
        f.seek(start)
        text = f.read(end - start).decode("utf-8", errors="ignore")
    tokens = pre_tokenization(text, special_tokens)
    return Counter(tokens)

if __name__ == "__main__":  # Required when workers are started with "spawn".
    num_processes = os.cpu_count()
    with open(file_path, "rb") as f:
        boundaries = find_chunk_boundaries(f, desired_num_chunks=num_processes, split_special_token=b"<|endoftext|>")

    chunks = [(file_path, start, end, special_tokens) for start, end in zip(boundaries[:-1], boundaries[1:])]

    with mp.Pool(processes=num_processes) as pool:
        results = pool.starmap(process_chunk, chunks)

    # Merge the per-chunk counts into a single table.
    word_counts = Counter()
    for result in results:
        word_counts.update(result)

Crucial Engineering Notes:

  • We use multiprocessing, not threading, because of Python's Global Interpreter Lock (GIL), which prevents true parallelism for CPU-bound tasks in threads.
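  • Start methods matter. Linux has traditionally used fork, while Windows and macOS (since Python 3.8) default to spawn, which launches a fresh interpreter per worker and re-imports the main module. That overhead can make multiprocessing slower for small inputs, and it means the pool code must be guarded by if __name__ == "__main__":.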

With these two optimizations, our BPE trainer goes from 3 seconds to 0.3 seconds on the same 65 KB file, a 10× speedup that scales to massive datasets.