"""Based on: https://github.com/pytorch/examples/blob/master/mnist/main.py
Mind that this example does not run on Windows at the moment."""
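
# NOTE: one worker process is spawned per GPU (NUM_GPU in total) and the NCCL
# backend is used, so this example needs a machine with at least NUM_GPU CUDA
# devices.
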
import math
import os
import functools
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.optim.lr_scheduler import StepLR
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data.distributed import DistributedSampler
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import (
    size_based_auto_wrap_policy,
)

NUM_GPU = 8  # Number of GPUs to use for FSDP


def setup(rank, world_size):
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "12355"
    # initialize the process group
    dist.init_process_group("nccl", rank=rank, world_size=world_size)


def cleanup():
    dist.destroy_process_group()
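

# The model below is the small CNN from the upstream PyTorch MNIST example.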
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.dropout1 = nn.Dropout(0.25)
        self.dropout2 = nn.Dropout(0.5)
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(x)
        x = self.conv2(x)
        x = F.relu(x)
        x = F.max_pool2d(x, 2)
        x = self.dropout1(x)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.dropout2(x)
        x = self.fc2(x)
        output = F.log_softmax(x, dim=1)
        return output
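

# The train/test loops accumulate the summed loss and sample counts in a small
# tensor on each rank and all-reduce them, so that rank 0 can report metrics
# computed over the whole distributed dataset.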
def train(model, rank, world_size, train_loader, optimizer, epoch, sampler=None):
    model.train()
    # ddp_loss[0] accumulates the summed loss, ddp_loss[1] the number of samples
    ddp_loss = torch.zeros(2).to(rank)
    if sampler:
        sampler.set_epoch(epoch)
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(rank), target.to(rank)
        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target, reduction="sum")
        loss.backward()
        optimizer.step()
        ddp_loss[0] += loss.item()
        ddp_loss[1] += len(data)
    dist.all_reduce(ddp_loss, op=dist.ReduceOp.SUM)
    if rank == 0:
        print("Train Epoch: {} \tLoss: {:.6f}".format(epoch, ddp_loss[0] / ddp_loss[1]))


def test(model, rank, world_size, test_loader):
    model.eval()
    # ddp_loss[0]: summed loss, ddp_loss[1]: number of correct predictions,
    # ddp_loss[2]: number of samples
    ddp_loss = torch.zeros(3).to(rank)
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(rank), target.to(rank)
            output = model(data)
            ddp_loss[0] += F.nll_loss(
                output, target, reduction="sum"
            ).item()  # sum up batch loss
            pred = output.argmax(
                dim=1, keepdim=True
            )  # get the index of the max log-probability
            ddp_loss[1] += pred.eq(target.view_as(pred)).sum().item()
            ddp_loss[2] += len(data)
    dist.all_reduce(ddp_loss, op=dist.ReduceOp.SUM)
    # only rank 0 reports a real value; all other ranks return math.inf
    test_loss = math.inf
    if rank == 0:
        test_loss = (ddp_loss[0] / ddp_loss[2]).item()
        print(
            "Test set: Average loss: {:.4f}, Accuracy: {}/{} ({:.2f}%)\n".format(
                test_loss,
                int(ddp_loss[1]),
                int(ddp_loss[2]),
                100.0 * ddp_loss[1] / ddp_loss[2],
            )
        )
    return test_loss
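

# fsdp_main runs once in each process spawned by mp.spawn; `rank` is the process
# index (used here as the GPU id) that mp.spawn passes as the first argument.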
def fsdp_main(rank, world_size, test_loss_tensor, lr, epochs, save_model=False):
    setup(rank, world_size)
    transform = transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
    )
    dataset1 = datasets.MNIST("./", train=True, download=True, transform=transform)
    dataset2 = datasets.MNIST("./", train=False, transform=transform)
    sampler1 = DistributedSampler(
        dataset1, rank=rank, num_replicas=world_size, shuffle=True
    )
    sampler2 = DistributedSampler(dataset2, rank=rank, num_replicas=world_size)
    train_kwargs = {"batch_size": 64, "sampler": sampler1}
    test_kwargs = {"batch_size": 1000, "sampler": sampler2}
    cuda_kwargs = {"num_workers": 2, "pin_memory": True, "shuffle": False}
    train_kwargs.update(cuda_kwargs)
    test_kwargs.update(cuda_kwargs)
    train_loader = torch.utils.data.DataLoader(dataset1, **train_kwargs)
    test_loader = torch.utils.data.DataLoader(dataset2, **test_kwargs)
    # wrap every submodule with at least 100 parameters in its own FSDP unit
    my_auto_wrap_policy = functools.partial(
        size_based_auto_wrap_policy, min_num_params=100
    )
    torch.cuda.set_device(rank)
    init_start_event = torch.cuda.Event(enable_timing=True)
    init_end_event = torch.cuda.Event(enable_timing=True)
    model = Net().to(rank)
    model = FSDP(model, auto_wrap_policy=my_auto_wrap_policy)
    optimizer = optim.Adadelta(model.parameters(), lr=lr)
    scheduler = StepLR(optimizer, step_size=1, gamma=0.7)
    init_start_event.record()
    test_loss = math.inf
    for epoch in range(1, epochs + 1):
        train(model, rank, world_size, train_loader, optimizer, epoch, sampler=sampler1)
        # calculate test loss for this epoch
        test_loss = test(model, rank, world_size, test_loader)
        scheduler.step()
    if rank == 0:
        # report the final test loss back to the parent process via shared memory
        test_loss_tensor[0] = test_loss
    init_end_event.record()
    if rank == 0:
        init_end_event.synchronize()
        print(
            "CUDA event elapsed time:"
            f" {init_start_event.elapsed_time(init_end_event) / 1000}sec"
        )
    if save_model:
        # use a barrier to make sure training is done on all ranks
        dist.barrier()
        states = model.state_dict()
        if rank == 0:
            torch.save(states, "mnist_cnn.pt")
    cleanup()
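

# evaluate_pipeline is the objective optimized by NePS: it launches one FSDP
# worker per GPU and returns the test loss that rank 0 wrote into the shared
# tensor.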
def evaluate_pipeline(lr=0.1, epoch=20):
    torch.manual_seed(42)
    # shared CPU tensor through which the spawned processes report the test loss
    test_loss_tensor = torch.zeros(1)
    test_loss_tensor.share_memory_()
    mp.spawn(
        fsdp_main, args=(NUM_GPU, test_loss_tensor, lr, epoch), nprocs=NUM_GPU, join=True
    )
    loss = test_loss_tensor.item()
    return loss
if __name__ == "__main__":
import neps
import logging
logging.basicConfig(level=logging.INFO)
class HPOSpace(neps.PipelineSpace):
lr = neps.Float(lower=0.0001, upper=0.1, log=True, prior=0.01)
epoch = neps.IntegerFidelity(lower=1, upper=3)
neps.run(
evaluate_pipeline=evaluate_pipeline,
pipeline_space=HPOSpace(),
root_directory="results/pytorch_fsdp",
fidelities_to_spend=20,
)