Learn to train a deep learning model on multiple GPUs with Distributed PyTorch
Part 2: Train a CNN on NVIDIA GPUs with PyTorch's Distributed Data Parallel (DDP)
👋🏻 Hey there!
This is the second part of a two-part lesson where we get our hands dirty by writing PyTorch code to train a model across multiple GPUs using the distributed machine learning algorithm Distributed Data Parallel (DDP).
If you missed the first part, where we learned about DDP visually, here is your link to get started.
DDP sounds straightforward conceptually, but the intricate dependencies between GPU computation and communication make it tricky to implement and optimize.
Fortunately, we do not have to write it from scratch or handle every edge case ourselves. PyTorch already implements it in the DistributedDataParallel class, which is part of its PyTorch Distributed library.
To make this lesson easy to follow, I will write all the code in a Jupyter notebook on Kaggle.
Kaggle provides free access to two NVIDIA T4 GPUs for 30 hours each week. We will train a CNN on these two GPUs with DDP using the CIFAR-10 dataset.
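To make the "data parallel" idea concrete before we start, here is a minimal sketch of how a dataset can be split between two GPU processes. It uses DistributedSampler with num_replicas and rank passed explicitly, so no process group is needed. The helper name shard_indices and the toy 10-item dataset are my own assumptions for illustration, not part of the training code in this lesson.

```python
from torch.utils.data.distributed import DistributedSampler


def shard_indices(dataset_size, world_size):
    """Return the dataset indices each rank would see (hypothetical helper)."""
    dataset = list(range(dataset_size))  # toy stand-in for a real dataset
    shards = {}
    for rank in range(world_size):
        # Passing num_replicas and rank explicitly avoids needing an
        # initialized process group; shuffle=False keeps the split readable.
        sampler = DistributedSampler(
            dataset, num_replicas=world_size, rank=rank, shuffle=False
        )
        shards[rank] = list(sampler)
    return shards


shards = shard_indices(dataset_size=10, world_size=2)
print(shards)  # {0: [0, 2, 4, 6, 8], 1: [1, 3, 5, 7, 9]}
```

Each rank gets a disjoint, equally sized slice of the indices, which is why every GPU ends up processing different samples in the same training step.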

Let’s begin!

Import packages
We start by importing the necessary packages.
# PyTorch core
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
# PyTorch Distributed
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data.distributed import DistributedSampler
# TorchVision
import torchvision
import torchvision.transforms as transforms
# Standard library utils
import os

Define the CNN classifier
Our CNN takes an image as input and classifies it into one of the 10 CIFAR-10 classes.
We define it exactly the same way we would when training on a single GPU, and we aren’t doing anything fancy in this step.
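One number in the definition below that may look arbitrary is the 256 * 4 * 4 input size of the first linear layer. Here is a rough sketch of where it comes from, assuming 32x32 CIFAR-10 inputs. The function trace_feature_map is a hypothetical helper written only for this check; it mirrors the three blocks of the network (two 3x3 convolutions with padding 1, then a 2x2 max pool).

```python
def trace_feature_map(size=32, channels=(64, 128, 256),
                      kernel_size=3, padding=1, pool=2):
    """Trace (channels, height, width) after each conv block (illustrative)."""
    shapes = []
    for out_channels in channels:
        # Two stride-1 convolutions: out = in + 2 * padding - kernel_size + 1.
        # With kernel_size=3 and padding=1 the spatial size is unchanged.
        for _ in range(2):
            size = size + 2 * padding - kernel_size + 1
        # The 2x2 max pool halves the height and width.
        size = size // pool
        shapes.append((out_channels, size, size))
    return shapes


shapes = trace_feature_map()
last_channels, last_height, last_width = shapes[-1]
flattened = last_channels * last_height * last_width

print(shapes)     # [(64, 16, 16), (128, 8, 8), (256, 4, 4)]
print(flattened)  # 4096
```

So after the third block each image is a 256 x 4 x 4 tensor, and flattening it gives the 4096 features that the first linear layer expects.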
class ImageClassifier(nn.Module):
    def __init__(self, num_classes=10):
        super().__init__()
        self.cnn = nn.Sequential(
            # Feature extraction
            # Block 1
            nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3, padding=1),
            nn.BatchNorm2d(num_features=64),
            nn.ReLU(),
            nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, padding=1),
            nn.BatchNorm2d(num_features=64),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
            # Block 2
            nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, padding=1),
            nn.BatchNorm2d(num_features=128),
            nn.ReLU(),
            nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, padding=1),
            nn.BatchNorm2d(num_features=128),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
            # Block 3
            nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, padding=1),
            nn.BatchNorm2d(num_features=256),
            nn.ReLU(),
            nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, padding=1),
            nn.BatchNorm2d(num_features=256),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
            nn.Flatten(),
            nn.Linear(in_features=256 * 4 * 4, out_features=512),
            nn.ReLU(),
            nn.Dropout(p=0.5),
            # Classification
            nn.Linear(in_features=512, out_features=num_classes),
        )

    def forward(self, x):
        return self.cnn(x)

Write helper functions for distributed training
Now comes the distributed part.





