Build and Train a Mixture-of-Experts (MoE) LLM from Scratch
An end-to-end guide to training a Mixture-of-Experts (MoE) LLM from scratch.
Many modern LLMs use the Mixture-of-Experts (MoE) architecture, including Grok-1, DeepSeekMoE, gpt-oss, and Mixtral (as well as many proprietary LLMs whose architectural details aren’t publicly available).
In the previous lessons on Into AI, we learned:
🌈 What the Mixture-of-Experts (MoE) architecture is and how it works
🎈 How to build a Mixture-of-Experts (MoE) decoder-only transformer from scratch
In this lesson, we will:
Build an LLM using the MoE transformer with Grouped Query Attention (GQA)
Train it on a publicly available dataset derived from Wikipedia
Generate text from the trained LLM
The following are the steps that we implement.
Let’s begin!
Import Necessary Packages
# Standard libraries
import math
import random
# Hide deprecation warnings
import warnings
warnings.filterwarnings('ignore')
# PyTorch core
import torch
import torch.nn as nn
import torch.optim as optim
# For processing data
from datasets import load_dataset
from torch.utils.data import Dataset, DataLoader
# BPE tokenizer
import tiktoken
# Training utils
from torch.nn.utils import clip_grad_norm_
from torch.optim.lr_scheduler import CosineAnnealingLR
# Progress bar
from tqdm import tqdm
Getting our data ready
1. Download the dataset
We will use the WikiText dataset from Hugging Face to train our model. This dataset is derived from verified Good and Featured articles on Wikipedia and contains approximately 103 million tokens.
It is downloaded as follows.
# Load WikiText dataset
train_data = load_dataset("wikitext", "wikitext-103-v1", split="train") # Train subset
val_data = load_dataset("wikitext", "wikitext-103-v1", split="validation") # Validation subset
Let’s check out a training example from this dataset.
# Training subset example
print(train_data['text'][807])
"""
Output:
The International Civil Aviation Organization ( ICAO ) defines general aviation ( GA )
as " an aircraft operation other than a commercial air transport operation
or an aerial work operation . " It defines commercial air transport ( CAT )
as " an aircraft operation involving the transport of passengers , cargo
or mail for remuneration or hire " , and aerial work as " an aircraft operation
in which an aircraft is used for specialized services such as agriculture ,
construction , photography , surveying , observation and patrol , search and
rescue , aerial advertisement , etc . "
"""
There’s a small problem with this dataset: some text entries are blank, and headings are formatted as " = = HEADING = = ", as shown below.
# Unwanted training subset example
print(train_data['text'][804:807])
"""
Output:
['', ' = = Definitions = = \n', '']
"""
Alongside this, WikiText-103 was built with a word-level vocabulary (rather than sub-words or characters), and any word that appeared fewer than three times was replaced with <unk>.
This is shown below.
# UNK marker
print(train_data['text'][80])
"""
Output:
<unk> <unk> 2
"""
We need to get rid of these to clean up the dataset.
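The <unk> markers and stray spacing can be stripped with plain string operations. Here is a one-line sketch of that step in isolation (the sample line is hypothetical):

```python
# Strip <unk> markers, then collapse the leftover whitespace
line = " <unk> <unk> 2 "
cleaned = " ".join(line.replace("<unk>", "").split())
print(cleaned)  # -> "2"
```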
We also need an End-of-Text (EOS) token called <|endoftext|> between separate sections and articles, so that the model learns natural stopping boundaries and doesn't try to connect unrelated content.
(We will come back to re-using the EOS token when discussing the helper function for text generation/inference.)
Both of these changes are made using a helper function as follows.
def clean_text(dataset):
    # EOS token
    EOS = "<|endoftext|>"
    cleaned = []
    for text in dataset["text"]:
        # Strip surrounding whitespace
        line = text.strip()
        # Keep only non-empty lines that aren't headings (such as " = = HEADING = = ")
        if line and not line.startswith("="):
            # Remove <unk> markers and normalize whitespace
            cleaned.append(" ".join(line.replace("<unk>", "").split()))
    # Join all entries into a single string, separated by the EOS token
    return EOS.join(cleaned) + EOS

# Clean training and validation text
training_text = clean_text(train_data)
validation_text = clean_text(val_data)
A short subset of training_text is shown below.
# Print a subset of 'training_text'
print(training_text[3457:3700])
"""
Output:
Each character has a field and distance of movement limited by their
Action Gauge . Up to nine characters can be assigned to a single mission .
During gameplay , characters will call out if something happens to them ,
such as their health points ( HP ) getting low or being knocked out
"""
2. Tokenize the dataset
We previously learned how to build a character-level tokenizer from scratch.
We have also used it to train an LLM from scratch.
But because it is character-based, training an LLM using it makes the model learn the meaning of text from individual letters rather than semantic units. This makes training painfully difficult.
For this tutorial, we will instead use tiktoken, a fast BPE tokenizer used in OpenAI’s models. It is a sub-word tokenizer (it breaks words into sub-words) that captures language structure better than a character-level tokenizer, enabling faster LLM training.
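To see why granularity matters, here is a toy comparison in plain Python (not actual BPE): character-level tokenization turns a short phrase into many tiny units, while coarser units are far fewer and each carries more meaning.

```python
text = "general aviation"

# Character-level tokenization: one token per character
char_tokens = list(text)

# A crude word-level split, standing in for coarser sub-word units
word_tokens = text.split()

print(len(char_tokens))  # -> 16
print(len(word_tokens))  # -> 2
```

A sub-word BPE tokenizer sits between these two extremes: common words stay whole, rare words split into a few meaningful pieces.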

We start by creating an instance of Tiktoken as follows.
# Create an instance of the tokenizer used by GPT-2
tokenizer = tiktoken.get_encoding("gpt2")
The vocabulary size of the GPT-2 tiktoken tokenizer is 50,257, meaning the tokenizer recognizes 50,257 unique sub-words.
# Check vocabulary size
vocab_size = tokenizer.n_vocab
print(f"Vocabulary size: {vocab_size}")
# Output: Vocabulary size: 50257
3. Create a dataset required for language modeling
Once we have tokenized our dataset, we need to load and serve the data during training. This is where the TextDataset class comes in, which inherits from PyTorch’s Dataset class, and has the following methods:
__init__: Tokenizes the given text (text to token IDs) and sets the maximum sequence length (max_seq_length)
__len__: Returns the number of training sequences available
__getitem__: Returns a training sequence and its targets (tokens shifted forward by one position) at a given index
# Dataset for language modeling
class TextDataset(Dataset):
    def __init__(self, text, tokenizer, max_seq_length):
        # Convert text to token IDs
        self.tokens = tokenizer.encode(text, allowed_special="all")
        # Maximum length of each training sequence
        self.max_seq_length = max_seq_length

    # Get number of valid training sequences
    def __len__(self):
        num_sequences = (len(self.tokens) - 1) // self.max_seq_length
        return num_sequences

    # Get an input sequence and its targets
    def __getitem__(self, idx):
        # Start index of the sequence
        start = idx * self.max_seq_length
        # End index of the sequence
        end = start + self.max_seq_length
        # Input token sequence
        input_ids = torch.tensor(self.tokens[start:end], dtype=torch.long)
        # Next-token targets/labels (shifted forward by one token)
        target_ids = torch.tensor(self.tokens[start+1:end+1], dtype=torch.long)
        return input_ids, target_ids
Let’s use the TextDataset class to create the training and validation datasets.
# Define maximum sequence length
MAX_SEQ_LENGTH = 128
# Create training and validation datasets
train_dataset = TextDataset(training_text, tokenizer, MAX_SEQ_LENGTH)
val_dataset = TextDataset(validation_text, tokenizer, MAX_SEQ_LENGTH)
The number of training sequences in the training and validation datasets is as follows.
print(f"Number of training sequences: {len(train_dataset):,}")
print(f"Number of validation sequences: {len(val_dataset):,}")
"""
Output:
Number of training sequences: 876,162
Number of validation sequences: 1,828
"""
And this is how a training sequence and its target look.
# Example of a training sequence and its target
input, target = train_dataset[16]
print("Input IDs:\n", input)
print("\nTarget IDs:\n", target)
print("\nDecoded Input:\n", tokenizer.decode(input.tolist()))
print("\nDecoded Target:\n", tokenizer.decode(target.tolist()))
"""
Output:
Input IDs:
tensor([ 262, 1271, 286, 22469, 4991, 583, 4365, 764, 317, 636,
286, 428, 8515, 2950, 4441, 3748, 7514, 14520, 4981, 329,
1123, 2095, 705, 82, 1767, 764, 554, 1502, 284, 4620,
428, 837, 262, 22849, 4847, 16560, 656, 262, 1218, 983,
547, 4615, 837, 355, 484, 1718, 510, 257, 1588, 6903,
286, 4088, 2272, 2622, 329, 262, 8561, 764, 1119, 635,
12328, 262, 8722, 6460, 290, 10152, 286, 711, 523, 484,
714, 5198, 284, 649, 1938, 981, 26645, 262, 6393, 6805,
286, 262, 2168, 705, 11327, 764, 383, 15064, 3341, 547,
3066, 2402, 1903, 287, 2478, 764, 383, 2095, 9824, 547,
1760, 416, 8835, 73, 280, 837, 508, 550, 3111, 319,
262, 2180, 569, 18354, 7496, 17740, 1830, 764, 1649, 4441,
262, 17871, 5321, 11630, 837, 8835, 73, 280])
Target IDs:
tensor([ 1271, 286, 22469, 4991, 583, 4365, 764, 317, 636, 286,
428, 8515, 2950, 4441, 3748, 7514, 14520, 4981, 329, 1123,
2095, 705, 82, 1767, 764, 554, 1502, 284, 4620, 428,
837, 262, 22849, 4847, 16560, 656, 262, 1218, 983, 547,
4615, 837, 355, 484, 1718, 510, 257, 1588, 6903, 286,
4088, 2272, 2622, 329, 262, 8561, 764, 1119, 635, 12328,
262, 8722, 6460, 290, 10152, 286, 711, 523, 484, 714,
5198, 284, 649, 1938, 981, 26645, 262, 6393, 6805, 286,
262, 2168, 705, 11327, 764, 383, 15064, 3341, 547, 3066,
2402, 1903, 287, 2478, 764, 383, 2095, 9824, 547, 1760,
416, 8835, 73, 280, 837, 508, 550, 3111, 319, 262,
2180, 569, 18354, 7496, 17740, 1830, 764, 1649, 4441, 262,
17871, 5321, 11630, 837, 8835, 73, 280, 373])
Decoded Input:
the number of playable units per mission . A part of this upgrade involved creating
unique polygon models for each character 's body . In order to achieve this ,
the cooperative elements incorporated into the second game were removed ,
as they took up a large portion of memory space needed for the improvements .
They also adjusted the difficulty settings and ease of play so they could appeal
to new players while retaining the essential components of the series ' gameplay .
The newer systems were decided upon early in development . The character designs
were done by Honjou , who had worked on the previous Valkyria Chronicles games .
When creating the Nameless Squad , Honjou
Decoded Target:
number of playable units per mission . A part of this upgrade involved creating
unique polygon models for each character 's body . In order to achieve this ,
the cooperative elements incorporated into the second game were removed ,
as they took up a large portion of memory space needed for the improvements .
They also adjusted the difficulty settings and ease of play so they could appeal
to new players while retaining the essential components of the series ' gameplay .
The newer systems were decided upon early in development . The character designs
were done by Honjou , who had worked on the previous Valkyria Chronicles games .
When creating the Nameless Squad , Honjou was
"""
Note how the input and target sequences are shifted by one sub-word. This helps the model learn to predict each next sub-word given all the previous ones.
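The one-position shift, and the reason __len__ subtracts 1 before dividing, can both be seen with a toy token list (the numbers are made up):

```python
# Hypothetical token IDs standing in for a tokenized corpus
tokens = [10, 20, 30, 40, 50, 60, 70]
max_seq_length = 3

# Each sequence needs one extra token to serve as its final target,
# hence the "- 1" in TextDataset.__len__
num_sequences = (len(tokens) - 1) // max_seq_length  # (7 - 1) // 3 = 2

# First sequence: targets are the inputs shifted forward by one position
start, end = 0, max_seq_length
input_ids = tokens[start:end]       # [10, 20, 30]
target_ids = tokens[start+1:end+1]  # [20, 30, 40]
# At position i, the model sees input_ids[:i+1] and must predict target_ids[i]
```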
Setting up the DataLoader for batch training
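A minimal sketch of the pattern this section builds toward, with a hypothetical batch size and a toy dataset standing in for TextDataset so it runs on its own:

```python
import torch
from torch.utils.data import DataLoader, Dataset

# Toy stand-in for TextDataset: each item is an (input_ids, target_ids)
# pair shifted by one token, just like the real dataset above
class ToyDataset(Dataset):
    def __init__(self, num_sequences, seq_len):
        self.rows = torch.arange(num_sequences * (seq_len + 1)).reshape(num_sequences, seq_len + 1)
    def __len__(self):
        return len(self.rows)
    def __getitem__(self, idx):
        row = self.rows[idx]
        return row[:-1], row[1:]

# Hypothetical batch size; shuffling the training set each epoch is standard
loader = DataLoader(ToyDataset(8, 4), batch_size=2, shuffle=True)
inputs, targets = next(iter(loader))
print(inputs.shape, targets.shape)  # torch.Size([2, 4]) torch.Size([2, 4])
```

The DataLoader stacks individual (input, target) pairs into batched tensors of shape (batch_size, seq_len), which is what the training loop consumes.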