Into AI

Learn to train a deep learning model on multiple GPUs with Distributed PyTorch

Part 2: Train a CNN on NVIDIA GPUs with PyTorch's Distributed Data Parallel (DDP)

Dr. Ashish Bamania
Mar 08, 2026

👋🏻 Hey there!

This is the second part of a two-part lesson where we get our hands dirty by writing PyTorch code to train a model across multiple GPUs using the distributed machine learning algorithm Distributed Data Parallel (DDP).

If you missed the first part, where we learned about DDP visually, here is your link to get started.

Learn to train deep learning models on multiple GPUs (Dr. Ashish Bamania, Feb 26)

DDP sounds straightforward conceptually, but the intricate dependencies between GPU computation and communication make it tricky to implement and optimize.

Fortunately, we do not have to implement it from scratch or handle every edge case ourselves. PyTorch already does all of this in its distributed library (torch.distributed) through the DistributedDataParallel class.
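To make the communication side concrete, here is a toy, single-process simulation in plain Python (no PyTorch, no GPUs) of the gradient averaging that DDP performs. Each simulated replica computes the gradient of a mean-squared-error loss on its own shard of the batch; a simulated all-reduce then sums those gradients and divides by the number of replicas, so every replica ends up with the same value. The one-parameter linear model, the data, and the function names are made up for illustration. The real DistributedDataParallel does this on GPU tensors, groups gradients into buckets, and overlaps the all-reduce with the backward pass, which is where the tricky dependencies come from.

```python
# Toy simulation of DDP-style gradient averaging (plain Python, no PyTorch).
# Hypothetical model: y_hat = w * x, loss = mean((w * x - y) ** 2).

def mse_grad(w, xs, ys):
    """Gradient of mean((w*x - y)^2) with respect to w over one shard."""
    n = len(xs)
    return sum(2 * (w * x - y) * x for x, y in zip(xs, ys)) / n


def all_reduce_mean(local_grads):
    """Simulated all-reduce: every replica ends up with the average gradient."""
    total = sum(local_grads)               # the "sum" step of all-reduce
    avg = total / len(local_grads)         # divide by the world size
    return [avg for _ in local_grads]      # each replica receives the same value


def shard(data, rank, world_size):
    """Equal, disjoint slice per replica (assumes len(data) % world_size == 0)."""
    return data[rank::world_size]


xs = [1.0, 2.0, 3.0, 4.0]
ys = [2.0, 4.1, 5.9, 8.2]
w = 0.5
world_size = 2

# Each "GPU" computes a gradient on its own shard only
local = [mse_grad(w, shard(xs, r, world_size), shard(ys, r, world_size))
         for r in range(world_size)]
synced = all_reduce_mean(local)
full_batch = mse_grad(w, xs, ys)

print(local)       # per-replica gradients differ
print(synced)      # after the all-reduce, all replicas agree
print(full_batch)  # and match the full-batch gradient (up to float rounding)
```

With equally sized shards, the averaged gradient equals the full-batch gradient, which is why DDP on N GPUs behaves much like single-GPU training with an N times larger batch.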

To make this lesson easy to follow, I will write all the code in a Jupyter notebook on Kaggle.

Kaggle provides free access to two NVIDIA T4 GPUs for 30 hours each week. We will train a CNN on these two GPUs with DDP using the CIFAR-10 dataset.

The CIFAR-10 dataset contains 60,000 32x32 colour images in 10 classes, with 6,000 images per class, split into 50,000 training images and 10,000 test images. Some examples from the dataset are shown in this image.

Let’s begin!

Blessings from Jensen Huang (Source)

Import packages

We start by importing the necessary packages.

# PyTorch core
import torch                              
import torch.nn as nn                    
import torch.optim as optim  
from torch.utils.data import DataLoader                 

# PyTorch Distributed
import torch.distributed as dist          
import torch.multiprocessing as mp        
from torch.nn.parallel import DistributedDataParallel as DDP  
from torch.utils.data.distributed import DistributedSampler  

# TorchVision
import torchvision                        
import torchvision.transforms as transforms 

# Standard library utils
import os       
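Of these imports, DistributedSampler is the piece that decides which images each GPU sees. The plain-Python function below is a simplified sketch of its default partitioning behaviour, not the PyTorch source: optionally shuffle the indices with a seed that depends on the epoch, pad the list by repeating its first indices until its length divides evenly by the number of processes, then give process `rank` every `num_replicas`-th index starting at `rank`. The name `partition_indices` is hypothetical, and `random.Random` stands in for PyTorch's generator, so the shuffled order will not match PyTorch's exactly. In PyTorch the epoch is supplied through `sampler.set_epoch(epoch)`, which is why training loops call it at the start of every epoch; without it, every epoch reuses the same shuffle.

```python
import math
import random


def partition_indices(dataset_len, num_replicas, rank, shuffle=True, seed=0, epoch=0):
    """Hypothetical, simplified sketch of DistributedSampler's default partitioning."""
    indices = list(range(dataset_len))
    if shuffle:
        # Same seed on every rank -> identical permutation everywhere,
        # so the per-rank slices below stay disjoint.
        random.Random(seed + epoch).shuffle(indices)

    # Pad by repeating indices so the total divides evenly across ranks
    # (the drop_last=False behaviour).
    num_samples = math.ceil(dataset_len / num_replicas)
    total_size = num_samples * num_replicas
    padding = total_size - len(indices)
    indices += (indices * math.ceil(padding / len(indices)))[:padding] if padding else []

    # Strided slice: rank r takes positions r, r + num_replicas, r + 2 * num_replicas, ...
    return indices[rank:total_size:num_replicas]


# CIFAR-10's 50,000 training images split across two GPUs
gpu0 = partition_indices(50_000, num_replicas=2, rank=0)
gpu1 = partition_indices(50_000, num_replicas=2, rank=1)
print(len(gpu0), len(gpu1))  # 25000 25000

# Uneven case: 10 samples over 3 ranks, so index 0 is repeated as padding
print(partition_indices(10, 3, 1, shuffle=False))  # [1, 4, 7, 0]
```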

Define the CNN classifier

Our CNN takes an image as input and classifies it into one of the 10 classes.

We define it exactly the same way we would when training on a single GPU; nothing in this step is specific to distributed training.

class ImageClassifier(nn.Module):
    def __init__(self, num_classes=10):
        super().__init__()
        
        self.cnn = nn.Sequential(
            # Feature extraction
            # Block 1
            nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3, padding=1),
            nn.BatchNorm2d(num_features=64),
            nn.ReLU(),

            nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, padding=1),
            nn.BatchNorm2d(num_features=64),
            nn.ReLU(),

            nn.MaxPool2d(kernel_size=2),

            # Block 2
            nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, padding=1),
            nn.BatchNorm2d(num_features=128),
            nn.ReLU(),

            nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, padding=1),
            nn.BatchNorm2d(num_features=128),
            nn.ReLU(),

            nn.MaxPool2d(kernel_size=2),

            # Block 3
            nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, padding=1),
            nn.BatchNorm2d(num_features=256),
            nn.ReLU(),

            nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, padding=1),
            nn.BatchNorm2d(num_features=256),
            nn.ReLU(),

            nn.MaxPool2d(kernel_size=2),
            
            # Classification head
            nn.Flatten(),
            nn.Linear(in_features=256 * 4 * 4, out_features=512),
            nn.ReLU(),
            nn.Dropout(p=0.5),
            nn.Linear(in_features=512, out_features=num_classes),
        )

    def forward(self, x):
        return self.cnn(x)
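The `in_features=256 * 4 * 4` of the first linear layer follows from the 32x32 input. Each 3x3 convolution with `padding=1` (and the default stride of 1) keeps the spatial size unchanged, and each `MaxPool2d(kernel_size=2)` halves it, because its stride defaults to the kernel size. A quick plain-Python check of that arithmetic, using the standard output-size formula floor((H + 2p - k) / s) + 1; the helper name `conv_out` is only illustrative:

```python
def conv_out(size, kernel, padding=0, stride=1):
    """Output height/width of a Conv2d or MaxPool2d layer (no dilation)."""
    return (size + 2 * padding - kernel) // stride + 1


size, channels = 32, 3          # a CIFAR-10 image is 3 x 32 x 32
for out_channels in (64, 128, 256):
    # Two 3x3 convs with padding=1 per block: spatial size unchanged
    size = conv_out(size, kernel=3, padding=1)
    size = conv_out(size, kernel=3, padding=1)
    channels = out_channels
    # MaxPool2d(kernel_size=2): stride defaults to 2, so the size halves
    size = conv_out(size, kernel=2, stride=2)
    print(f"after block: {channels} x {size} x {size}")

flattened = channels * size * size
print(flattened)  # 4096, i.e. 256 * 4 * 4
```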

Write helper functions for distributed training

Now comes the distributed part.

© 2026 Dr. Ashish Bamania