Over the last few weeks, I tried to dig deeply into tokenization and learn how to build a BPE tokenizer from the ground up. In this post I document both my understanding of tokenization and how to implement a performant version.
I like to think of tokenization as a trade-off between computational efficiency and learnability.
Efficiency vs. Meaning
At its core, tokenization involves creating a vocabulary (a look-up table) that maps chunks of text to integer IDs. The choice of what constitutes a "chunk" (the granularity of text to use) matters a great deal.
Efficiency
Transformer models, the architecture behind most LLMs, have a computational complexity that scales quadratically with the sequence length (O(n²) for a sequence of n tokens).
This puts a huge premium on representing text with the fewest tokens possible. We measure this using the Bytes per Token ratio: the number of UTF-8 bytes in the text divided by the number of tokens it becomes.
A lower ratio means more compression and shorter sequences, leading to faster training, quicker inference, and the ability to handle longer contexts.
Let's consider an example sentence: "The apparent meaning is for the masses, and the inner meaning is for the learned." (93 bytes)
- Word-level tokenization is highly efficient. It splits the sentence into 15 tokens, giving us a great 6.2 Bytes/Token ratio. But it's brittle. The vocabulary explodes with morphological variants ("run", "runs", "running"), and it completely fails on any word not seen during training (Out-of-Vocabulary, or OOV), forcing us to use a generic <UNK> token that throws away information.
- Byte-level tokenization solves the OOV problem perfectly: any text can be represented as a sequence of its 256 possible byte values. However, it's incredibly inefficient. Our 93-byte sentence becomes 93 tokens, yielding a 1.0 Bytes/Token ratio. This is no compression at all, leading to painfully long sequences.
We need a middle ground: great compression without the OOV problem. This is where subword tokenization comes in.
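As a concrete reference point (an optional aside, assuming the third-party tiktoken package is installed), you can measure the Bytes/Token ratio of a production subword tokenizer such as GPT-2's; it typically lands well between the byte-level and word-level extremes:
import tiktoken  # pip install tiktoken

def bytes_per_token(text: str, encoding_name: str = "gpt2") -> float:
    # Bytes/Token = UTF-8 bytes of the text divided by the number of tokens produced.
    enc = tiktoken.get_encoding(encoding_name)
    return len(text.encode("utf-8")) / len(enc.encode(text))

print(bytes_per_token("The apparent meaning is for the masses, and the inner meaning is for the learned."))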
The Learning Problem:
Beyond efficiency, tokenization injects a powerful inductive bias into the model.
- Byte-level tokenization offers a weak inductive bias. We're telling the model, "Here's a long stream of numbers; figure out from scratch that (100, 111, 103) represents the concept of a 'dog'." This requires a massive amount of data and computation for the model to learn basic linguistic structures.
- Subword tokenization provides a stronger, more helpful bias. By pre-chunking text into statistically common units (like "token" and "ization"), we give the model a head start. We're hinting that these chunks are meaningful building blocks. It's analogous to how a Vision Transformer first breaks an image into patches instead of processing one pixel at a time.
Both efficiency and learnability point to subword tokenization as the superior strategy. The most popular algorithm for achieving this is Byte-Pair Encoding (BPE).
Byte-Pair Encoding (BPE): A Naive Implementation
BPE is a simple, greedy algorithm. It starts with the most basic units (bytes) and iteratively merges the most frequent adjacent pairs to build a vocabulary of meaningful subwords. Let's build a BPE tokenizer from the ground up to see exactly how it works.
Step 1: Vocabulary Initialization
We start with a base vocabulary of all 256 possible byte values. This ensures we have 100% coverage: any text can be represented.
vocab = {i : bytes([i]) for i in range(256)}
Step 2: Special Tokens
Next, we inject special tokens like <|endoftext|>. These are treated as indivisible atomic units and are excluded from the merge process to preserve their role as control signals.
for i, token in enumerate(special_tokens):
    vocab[256 + i] = token.encode("utf-8")
Step 3: Pre-tokenization
Feeding raw text directly to the BPE algorithm is a mistake. To guide the algorithm, we first perform pre-tokenization, splitting the input corpus into "word-like" units using a regular expression. This ensures merges respect linguistic boundaries.
Pre-tokenization is often ignored in discussions about BPE, but it plays a crucial role in shaping what the algorithm sees. A simple difference in the regex can lead to substantive differences in the final vocabulary.
import regex as re  # the \p{L} / \p{N} classes require the third-party `regex` package, not the stdlib `re`

def pre_tokenization(training_data, special_tokens):
    # GPT-2's regex pattern for splitting text into initial chunks.
    PAT = r"""'(?:[sdmt]|ll|ve|re)| ?\p{N}+| ?\p{L}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""
    if not special_tokens:
        return [m.group() for m in re.finditer(PAT, training_data)]
    # Split the corpus by special tokens to prevent merges across them.
    escaped_list = [re.escape(special_token) for special_token in special_tokens]
    split_PAT = r"({})".format("|".join(escaped_list))
    split_corpus = re.split(split_PAT, training_data)
    pretokenized_train_data = []
    for segment in split_corpus:
        if not segment:
            continue
        # If the segment is a special token itself, add it directly.
        if segment in special_tokens:
            pretokenized_train_data.append(segment)
        else:
            # Otherwise, apply the main regex pattern.
            for m in re.finditer(PAT, segment):
                pretokenized_train_data.append(m.group())
    return pretokenized_train_data
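A quick sanity check of the pre-tokenizer on a toy string (output shown as a comment; the exact chunks are determined by the regex above):
chunks = pre_tokenization("Hello world!<|endoftext|>New doc", ["<|endoftext|>"])
print(chunks)
# ['Hello', ' world', '!', '<|endoftext|>', 'New', ' doc']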
Step 4: Iterative Merging
With the pre-tokenized corpus, the core BPE process begins. We first count the occurrences of each pre-tokenized "word" and represent them as tuples of bytes.
from collections import Counter

merges = []
pretokenized_train_data = pre_tokenization(training_data, special_tokens)
# Count initial "word" frequencies
word_counts = Counter(pretokenized_train_data)
# Convert words to tuples of bytes and store their counts
constructed_vocab = {}
for word, count in word_counts.items():
    # Exclude special tokens from the merge corpus, and skip single-character words.
    if word in special_tokens or len(word) == 1:
        continue
    constructed_vocab[tuple(word.encode("utf-8"))] = count
Now, we iteratively find the most frequent pair and merge it.
# vocab_size is a hyperparameter.
token_id = len(vocab)  # Start ID for new merged tokens
while token_id < vocab_size:
    best_pair = find_best_pair(constructed_vocab, vocab)
    if best_pair is None:  # Nothing left to merge
        break
    merge_best_pair(constructed_vocab, best_pair, token_id)
    # Bookkeeping
    merges.append(best_pair)
    vocab[token_id] = vocab[best_pair[0]] + vocab[best_pair[1]]  # Concatenate the merged tokens' bytes
    token_id += 1
The heavy lifting happens in helper functions. find_best_pair scans all words to count all adjacent pairs.
def find_best_pair(constructed_vocab, vocab):
    potential_merges_count = {}
    for word_tuple, count in constructed_vocab.items():
        for i in range(len(word_tuple) - 1):
            key = (word_tuple[i], word_tuple[i+1])
            potential_merges_count[key] = potential_merges_count.get(key, 0) + count
    if not potential_merges_count:
        return None
    # Tie-breaking rule: if frequencies are equal, choose the lexicographically
    # greatest pair. The comparison must be done on the tokens' byte
    # representations (looked up in vocab), not on their integer IDs.
    max_freq = max(potential_merges_count.values())
    candidate_pairs = [k for k, v in potential_merges_count.items() if v == max_freq]
    best_pair = max(candidate_pairs, key=lambda pair: (vocab[pair[0]], vocab[pair[1]]))
    return best_pair
Gotcha: Tie-breaking isn't just an edge case. Consistent tie-breaking (e.g., by lexicographical order of the merged token bytes) is essential for reproducing tokenizers like TikToken exactly. This subtle detail took me longer than I care to admit to get right.
After finding the best pair, merge_best_pair replaces it across the entire vocabulary. This involves scanning every single word again.
# Helper to merge a pair within a single word tuple
def merge_in_word(word_tuple, best_pair, new_token_id):
    new_word_list = []
    k = 0
    while k < len(word_tuple):
        if k < len(word_tuple) - 1 and (word_tuple[k], word_tuple[k+1]) == best_pair:
            new_word_list.append(new_token_id)
            k += 2
        else:
            new_word_list.append(word_tuple[k])
            k += 1
    return tuple(new_word_list)

# Main function to apply the merge across the whole vocabulary
def merge_best_pair(constructed_vocab, best_pair, new_token_id):
    to_add = {}
    to_delete = []
    for word_tuple, count in constructed_vocab.items():
        # A simple check to see if the pair exists in the word
        if any(word_tuple[i:i+2] == best_pair for i in range(len(word_tuple) - 1)):
            new_word_tuple = merge_in_word(word_tuple, best_pair, new_token_id)
            to_add[new_word_tuple] = to_add.get(new_word_tuple, 0) + count
            to_delete.append(word_tuple)
    for word in to_delete:
        del constructed_vocab[word]
    constructed_vocab.update(to_add)
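To make a single merge concrete, here's a tiny hand-made example with made-up counts ('t' = 116, 'h' = 104, 'e' = 101, 'a' = 97): merging the pair (116, 104) into a new token with ID 257 rewrites every word that contains it.
toy_counts = {
    (116, 104, 101): 5,      # "the"
    (116, 104, 97, 116): 3,  # "that"
}
merge_best_pair(toy_counts, (116, 104), 257)
print(toy_counts)
# {(257, 101): 5, (257, 97, 116): 3}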
This naive implementation works, but it's incredibly slow: a run on a tiny 65 KB text file takes ~3 seconds. To scale, we need to optimize.
Optimizing BPE:
Profiling reveals two major bottlenecks: the repeated scanning in the merge loop and the single-threaded pre-tokenization.
Bottleneck 1: Optimizing the Merge Loop
In our naive approach, finding the best_pair and updating the corpus requires iterating over our entire word-count dictionary again and again.
The Fix: Instead of re-scanning, we use smarter data structures. We'll maintain a dictionary of potential_merges (counts of all pairs) and another, bigram_locations, that maps each pair to the set of words it appears in.
When we merge ('t', 'h') -> 257 (let's call it 'th'), we know the only counts that can change are those involving pairs immediately adjacent to it. For a word like ('t', 'h', 'e'):
- The count for ('t', 'h') decreases.
- The count for ('h', 'e') decreases.
- A new count for (257, 'e') is created and increases.
By only updating the stats for the words directly affected by the merge (which we can look up instantly in bigram_locations), we avoid redundant work.
The logic becomes (the full loop is sketched after the helpers below):
- Get the list of all words affected by the best_pair from bigram_locations.
- For each affected word: a. Decrement the counts of all pairs in the old version of the word. b. Create the new version of the word by merging the pair. c. Increment the counts of all pairs in the new version. d. Update our main vocabulary count and bigram_locations.
Here are the key helper functions for this efficient approach:
from collections import defaultdict

# ... (after pre-tokenization and initial word counts) ...
potential_merges = defaultdict(int)
bigram_locations = defaultdict(set)
for word_tuple, count in constructed_vocab.items():
    for i in range(len(word_tuple) - 1):
        bigram = (word_tuple[i], word_tuple[i + 1])
        potential_merges[bigram] += count
        bigram_locations[bigram].add(word_tuple)

# --- In the main loop ---
# words_affected = list(bigram_locations[best_pair])
# for word_tuple in words_affected:
#     if word_tuple not in constructed_vocab: continue
#     count = constructed_vocab[word_tuple]
#     decrement_counts(...)
#     new_word_tuple = merge_in_word(...)
#     increment_counts(...)
#     ... (bookkeeping) ...
def decrement_counts(word_tuple, potential_merges, bigram_locations, count):
    for j in range(len(word_tuple) - 1):
        bigram = (word_tuple[j], word_tuple[j + 1])
        potential_merges[bigram] -= count
        if bigram in bigram_locations and word_tuple in bigram_locations[bigram]:
            bigram_locations[bigram].discard(word_tuple)
            if not bigram_locations[bigram]:
                del bigram_locations[bigram]
        if potential_merges.get(bigram, 0) <= 0:
            potential_merges.pop(bigram, None)

def increment_counts(new_word_tuple, potential_merges, bigram_locations, count):
    for j in range(len(new_word_tuple) - 1):
        bigram = (new_word_tuple[j], new_word_tuple[j + 1])
        potential_merges[bigram] += count
        bigram_locations[bigram].add(new_word_tuple)
This bookkeeping is more complex, but it makes each merge step an efficient, localized update instead of a global scan.
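Putting it together, here's a sketch of how the optimized main loop might use these helpers. It reuses vocab, merges, vocab_size, constructed_vocab and merge_in_word from the naive version, plus the structures defined just above; treat it as an illustration of the bookkeeping rather than a drop-in implementation.
token_id = len(vocab)
while token_id < vocab_size and potential_merges:
    # Pick the most frequent pair; break ties on the tokens' byte representations.
    max_freq = max(potential_merges.values())
    best_pair = max(
        (p for p, c in potential_merges.items() if c == max_freq),
        key=lambda p: (vocab[p[0]], vocab[p[1]]),
    )
    # Only words that actually contain best_pair need their stats updated.
    for word_tuple in list(bigram_locations.get(best_pair, ())):
        if word_tuple not in constructed_vocab:
            continue
        count = constructed_vocab.pop(word_tuple)
        decrement_counts(word_tuple, potential_merges, bigram_locations, count)
        new_word_tuple = merge_in_word(word_tuple, best_pair, token_id)
        increment_counts(new_word_tuple, potential_merges, bigram_locations, count)
        constructed_vocab[new_word_tuple] = constructed_vocab.get(new_word_tuple, 0) + count
    # Bookkeeping, same as the naive version.
    merges.append(best_pair)
    vocab[token_id] = vocab[best_pair[0]] + vocab[best_pair[1]]
    token_id += 1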
Bottleneck 2: Parallelizing Pre-tokenization
For large datasets, just reading the file and running the regex is a significant, single-threaded bottleneck.
The Fix: This task is "embarrassingly parallel." We can split our large text file into multiple chunks and process them across all available CPU cores using Python's multiprocessing module.
The trick is splitting the file cleanly. We can't just cut at random byte offsets; we might slice a multi-byte character in half. A robust solution is to find "clean seams", like a special token such as b"<|endoftext|>", and split the file there.
import os
from typing import BinaryIO

def find_chunk_boundaries(file: BinaryIO, desired_num_chunks: int, split_special_token: bytes) -> list[int]:
    file.seek(0, os.SEEK_END)
    file_size = file.tell()
    file.seek(0)
    chunk_size = file_size // desired_num_chunks
    # Initial guesses, evenly spaced; the last boundary is pinned to the end of the file.
    chunk_boundaries = [i * chunk_size for i in range(desired_num_chunks + 1)]
    chunk_boundaries[-1] = file_size
    mini_chunk_size = 4096
    # Nudge each interior boundary forward to the next occurrence of the special token.
    for i in range(1, len(chunk_boundaries) - 1):
        initial_position = chunk_boundaries[i]
        file.seek(initial_position)
        while True:
            mini_chunk = file.read(mini_chunk_size)
            if mini_chunk == b"":
                # Hit EOF before finding the token; collapse this boundary to the end.
                chunk_boundaries[i] = file_size
                break
            found_at = mini_chunk.find(split_special_token)
            if found_at != -1:
                chunk_boundaries[i] = initial_position + found_at
                break
            initial_position += mini_chunk_size
    return sorted(set(chunk_boundaries))
With these boundaries, we can spin up a pool of processes to handle the pre-tokenization and counting in parallel.
import multiprocessing as mp

def process_chunk(file_path, start, end, special_tokens):
    # Each worker reads its own byte range and pre-tokenizes it independently.
    with open(file_path, "rb") as f:
        f.seek(start)
        text = f.read(end - start).decode("utf-8", errors="ignore")
    tokens = pre_tokenization(text, special_tokens)
    return Counter(tokens)

word_counts = Counter()
with open(file_path, "rb") as f:
    boundaries = find_chunk_boundaries(f, desired_num_chunks=os.cpu_count(), split_special_token=b"<|endoftext|>")
chunks = [(file_path, start, end, special_tokens) for start, end in zip(boundaries[:-1], boundaries[1:])]
with mp.Pool(processes=os.cpu_count()) as pool:
    results = pool.starmap(process_chunk, chunks)
for result in results:
    word_counts.update(result)
Crucial Engineering Notes:
- We use multiprocessing, not threading, because of Python's Global Interpreter Lock (GIL), which prevents true parallelism for CPU-bound tasks in threads.
- This works best on Linux/macOS. The process creation overhead on Windows (spawn vs. fork) can make multiprocessing slower for smaller tasks (see the sketch below).
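On spawn-based platforms in particular, the pool setup needs to live under an if __name__ == "__main__": guard so that worker processes can import the module without re-running it. A minimal sketch of how the driver code above might be wrapped, reusing process_chunk and find_chunk_boundaries (the function name and file path here are just illustrative):
def count_words_parallel(file_path, special_tokens):
    # Chunk at clean seams, then fan the chunks out to one worker per core.
    with open(file_path, "rb") as f:
        boundaries = find_chunk_boundaries(f, os.cpu_count(), b"<|endoftext|>")
    chunks = [(file_path, start, end, special_tokens) for start, end in zip(boundaries[:-1], boundaries[1:])]
    with mp.Pool(processes=os.cpu_count()) as pool:
        results = pool.starmap(process_chunk, chunks)
    word_counts = Counter()
    for result in results:
        word_counts.update(result)
    return word_counts

if __name__ == "__main__":
    word_counts = count_words_parallel("corpus.txt", ["<|endoftext|>"])  # hypothetical file path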
With these two optimizations, our BPE trainer goes from 3 seconds to 0.3 seconds on the same 65 KB file, a 10x speedup whose benefits only grow on massive datasets.