# Example pipeline adapted from: https://lightning.ai/lightning-ai/studios/image-segmentation-with-pytorch-lightning
import torch
from torchvision import transforms, datasets, models
import lightning as L
from lightning.pytorch.strategies import DDPStrategy
import os
from lightning.pytorch.utilities.rank_zero import rank_zero_only
from torch.optim.lr_scheduler import PolynomialLR
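
# LightningModule wrapping torchvision's FCN-ResNet50 for the 21 PASCAL VOC
# classes (20 object classes + background), trained with SGD and polynomial LR decay.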
class LitSegmentation(L.LightningModule):
def __init__(self, iters_per_epoch, lr, momentum, weight_decay):
super().__init__()
self.model = models.segmentation.fcn_resnet50(num_classes=21, aux_loss=True)
        # VOC masks label object boundary pixels 255; exclude them from the loss.
        self.loss_fn = torch.nn.CrossEntropyLoss(ignore_index=255)
self.iters_per_epoch = iters_per_epoch
self.lr = lr
self.momentum = momentum
self.weight_decay = weight_decay
def training_step(self, batch):
images, targets = batch
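        # Only the 'out' head feeds the loss, so the aux head's parameters get
        # no gradient; that is why the Trainer below needs
        # DDPStrategy(find_unused_parameters=True).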
outputs = self.model(images)['out']
loss = self.loss_fn(outputs, targets.long().squeeze(1))
self.log("train_loss", loss, sync_dist=True)
return loss
def validation_step(self, batch):
images, targets = batch
outputs = self.model(images)['out']
loss = self.loss_fn(outputs, targets.long().squeeze(1))
self.log("val_loss", loss, sync_dist=True)
return loss
    def configure_optimizers(self):
        optimizer = torch.optim.SGD(self.model.parameters(), lr=self.lr, momentum=self.momentum, weight_decay=self.weight_decay)
        # total_iters counts optimizer steps, so the scheduler must be stepped
        # per batch ("step"), not per epoch (Lightning's default interval).
        scheduler = PolynomialLR(
            optimizer, total_iters=self.iters_per_epoch * self.trainer.max_epochs, power=0.9
        )
        return {
            "optimizer": optimizer,
            "lr_scheduler": {"scheduler": scheduler, "interval": "step"},
        }
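
# DataModule for PASCAL VOC 2012 segmentation: downloads the archive once and
# serves normalized 256x256 train/val batches.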
class SegmentationData(L.LightningDataModule):
def __init__(self, batch_size=4):
super().__init__()
self.batch_size = batch_size
def prepare_data(self):
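        # Lightning invokes prepare_data on a single process, so DDP workers
        # do not race on the download.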
dataset_path = "data/VOCtrainval_11-May-2012.tar"
if not os.path.exists(dataset_path):
datasets.VOCSegmentation(root="data", download=True)
def train_dataloader(self):
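        # ImageNet mean/std below match the ResNet-50 backbone, whose weights
        # default to ImageNet pretraining in recent torchvision releases.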
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Resize((256, 256), antialias=True),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
        # PILToTensor keeps integer class ids (ToTensor would rescale them to
        # [0, 1]) and NEAREST avoids blending neighboring labels when resizing.
        target_transform = transforms.Compose([transforms.PILToTensor(), transforms.Resize((256, 256), interpolation=transforms.InterpolationMode.NEAREST)])
train_dataset = datasets.VOCSegmentation(root="data", transform=transform, target_transform=target_transform)
return torch.utils.data.DataLoader(train_dataset, batch_size=self.batch_size, shuffle=True, num_workers=63)
def val_dataloader(self):
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Resize((256, 256), antialias=True),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
        target_transform = transforms.Compose([transforms.PILToTensor(), transforms.Resize((256, 256), interpolation=transforms.InterpolationMode.NEAREST)])
val_dataset = datasets.VOCSegmentation(root="data", year='2012', image_set='val', transform=transform, target_transform=target_transform)
return torch.utils.data.DataLoader(val_dataset, batch_size=self.batch_size, shuffle=False, num_workers=63)
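
# NEPS evaluation function: trains one hyperparameter configuration end to end
# and returns the validation loss to be minimized.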
def main(**kwargs):
    data = SegmentationData(kwargs.get("batch_size", 4))
    # Ensure the dataset exists before sizing the train loader for the LR
    # schedule; prepare_data is idempotent, so Trainer.fit can safely call it again.
    data.prepare_data()
    iters_per_epoch = len(data.train_dataloader())
model = LitSegmentation(iters_per_epoch, kwargs.get("lr", 0.02), kwargs.get("momentum", 0.9), kwargs.get("weight_decay", 1e-4))
trainer = L.Trainer(max_epochs=kwargs.get("epoch", 30), strategy=DDPStrategy(find_unused_parameters=True), enable_checkpointing=False)
trainer.fit(model, data)
    # Report the last logged validation loss as the objective NEPS minimizes.
    val_loss = trainer.logged_metrics["val_loss"].item()
    return val_loss
if __name__ == "__main__":
import neps
import logging
logging.basicConfig(level=logging.INFO)
    # Search space: lr and weight_decay on log scales with priors at the
    # script's original defaults; epoch is the fidelity for multi-fidelity
    # optimization; batch_size trades GPU memory against gradient noise.
pipeline_space = dict(
lr=neps.Float(
lower=0.0001,
upper=0.1,
log=True,
prior=0.02
),
momentum=neps.Float(
lower=0.1,
upper=0.9,
prior=0.5
),
weight_decay=neps.Float(
lower=1e-5,
upper=1e-3,
log=True,
prior=1e-4
),
epoch=neps.Integer(
lower=10,
upper=30,
is_fidelity=True
),
batch_size=neps.Integer(
lower=4,
upper=12,
prior=4
),
)
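    # Optional sanity check before the full search (illustrative values, not
    # tuned): set SMOKE_TEST=1 to train a single cheap configuration directly.
    if os.environ.get("SMOKE_TEST"):
        print(main(lr=0.02, momentum=0.9, weight_decay=1e-4, epoch=1, batch_size=4))
        raise SystemExit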
neps.run(
evaluate_pipeline=main,
pipeline_space=pipeline_space,
root_directory="hpo_image_segmentation",
max_evaluations_total=500
)
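    # Per-configuration results are written under "hpo_image_segmentation";
    # rerunning with the same root_directory resumes the search, and multiple
    # launches of this script can share it to evaluate in parallel.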