"""Image segmentation HPO with NePS.

Tunes learning rate, momentum, weight decay, batch size, and training epochs
(the fidelity) for an FCN-ResNet50 segmentation model trained on PASCAL VOC
2012 with PyTorch Lightning.
"""

# Example pipeline adapted from: https://lightning.ai/lightning-ai/studios/image-segmentation-with-pytorch-lightning

import os

import torch
from torchvision import transforms, datasets, models
import lightning as L
from lightning.pytorch.strategies import DDPStrategy
from torch.optim.lr_scheduler import PolynomialLR


class LitSegmentation(L.LightningModule):
    def __init__(self, iters_per_epoch, lr, momentum, weight_decay):
        super().__init__()
        # 21 classes: the 20 VOC object categories plus background.
        self.model = models.segmentation.fcn_resnet50(num_classes=21, aux_loss=True)
        # VOC masks mark void/boundary pixels with 255; keep them out of the loss.
        self.loss_fn = torch.nn.CrossEntropyLoss(ignore_index=255)
        self.iters_per_epoch = iters_per_epoch
        self.lr = lr
        self.momentum = momentum
        self.weight_decay = weight_decay

    def training_step(self, batch):
        images, targets = batch
        outputs = self.model(images)["out"]
        # Masks arrive as [B, 1, H, W]; CrossEntropyLoss expects [B, H, W] class ids.
        loss = self.loss_fn(outputs, targets.long().squeeze(1))
        self.log("train_loss", loss, sync_dist=True)
        return loss

    def validation_step(self, batch):
        images, targets = batch
        outputs = self.model(images)["out"]
        loss = self.loss_fn(outputs, targets.long().squeeze(1))
        self.log("val_loss", loss, sync_dist=True)
        return loss

    def configure_optimizers(self):
        optimizer = torch.optim.SGD(
            self.model.parameters(),
            lr=self.lr,
            momentum=self.momentum,
            weight_decay=self.weight_decay,
        )
        scheduler = PolynomialLR(
            optimizer,
            total_iters=self.iters_per_epoch * self.trainer.max_epochs,
            power=0.9,
        )
        # total_iters counts optimizer steps, so the scheduler must be stepped
        # per batch rather than with Lightning's per-epoch default.
        return {
            "optimizer": optimizer,
            "lr_scheduler": {"scheduler": scheduler, "interval": "step"},
        }
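
# Hypothetical shape check for the module above (commented out so the script
# stays import-safe); fcn_resnet50 returns a dict of logit maps:
#
#   model = LitSegmentation(iters_per_epoch=1, lr=0.02, momentum=0.9, weight_decay=1e-4)
#   out = model.model(torch.randn(2, 3, 256, 256))["out"]
#   print(out.shape)  # torch.Size([2, 21, 256, 256]) - per-pixel logits for 21 classes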


class SegmentationData(L.LightningDataModule):
    def __init__(self, batch_size=4):
        super().__init__()
        self.batch_size = batch_size

    def prepare_data(self):
        # Download VOC 2012 once; the existence check skips torchvision's
        # archive verification on repeated runs.
        dataset_path = ".data/VOC/VOCtrainval_11-May-2012.tar"
        if not os.path.exists(dataset_path):
            datasets.VOCSegmentation(root=".data/VOC", download=True)

    @staticmethod
    def _transforms():
        # Shared by train/val. Inputs get ImageNet normalization; masks keep
        # their raw class ids (PILToTensor, no 0-1 scaling) and are resized
        # with nearest-neighbor interpolation so resizing cannot blend or
        # invent class indices.
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Resize((256, 256), antialias=True),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])
        target_transform = transforms.Compose([
            transforms.PILToTensor(),
            transforms.Resize((256, 256), interpolation=transforms.InterpolationMode.NEAREST),
        ])
        return transform, target_transform

    def train_dataloader(self):
        transform, target_transform = self._transforms()
        train_dataset = datasets.VOCSegmentation(
            root=".data/VOC",
            year="2012",
            image_set="train",
            transform=transform,
            target_transform=target_transform,
        )
        return torch.utils.data.DataLoader(
            train_dataset,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=16,
            persistent_workers=True,
        )

    def val_dataloader(self):
        transform, target_transform = self._transforms()
        val_dataset = datasets.VOCSegmentation(
            root=".data/VOC",
            year="2012",
            image_set="val",
            transform=transform,
            target_transform=target_transform,
        )
        return torch.utils.data.DataLoader(
            val_dataset,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=16,
            persistent_workers=True,
        )
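
# Hypothetical sanity check of the data pipeline (commented out; requires the
# VOC download to exist): images come out as float [B, 3, 256, 256], masks as
# integer [B, 1, 256, 256] with class ids 0..20 plus 255 for void pixels.
#
#   dm = SegmentationData(batch_size=2)
#   dm.prepare_data()
#   images, masks = next(iter(dm.train_dataloader()))
#   print(images.shape, images.dtype)  # torch.Size([2, 3, 256, 256]) torch.float32
#   print(masks.shape, masks.dtype)    # torch.Size([2, 1, 256, 256]) torch.uint8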


def evaluate_pipeline(**kwargs):
    # NePS calls this once per sampled configuration; the returned validation
    # loss is the objective being minimized.
    data = SegmentationData(kwargs.get("batch_size", 4))
    data.prepare_data()
    # The LR schedule needs the number of optimizer steps per epoch.
    iters_per_epoch = len(data.train_dataloader())
    model = LitSegmentation(
        iters_per_epoch,
        kwargs.get("lr", 0.02),
        kwargs.get("momentum", 0.9),
        kwargs.get("weight_decay", 1e-4),
    )
    trainer = L.Trainer(
        max_epochs=kwargs.get("epoch", 30),
        # The aux classifier output is not used in the loss, so DDP needs
        # find_unused_parameters=True to avoid gradient-sync errors.
        strategy=DDPStrategy(find_unused_parameters=True),
        enable_checkpointing=False,
    )
    trainer.fit(model, data)
    # Report the last logged (rank-synced) validation loss to NePS.
    val_loss = trainer.logged_metrics["val_loss"].detach().item()
    return val_loss
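
# Hypothetical standalone smoke test outside of NePS (commented out; it
# trains for real, so expect it to take a while without a GPU):
#
#   val_loss = evaluate_pipeline(lr=0.02, momentum=0.9, weight_decay=1e-4,
#                                batch_size=4, epoch=1)
#   print(val_loss)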


if __name__ == "__main__":
    import neps
    import logging

    logging.basicConfig(level=logging.INFO)

    # Search space: priors bias sampling toward reasonable defaults, and
    # `epoch` is marked as the fidelity so multi-fidelity optimizers can
    # try cheap low-epoch configurations before training for the full budget.
    pipeline_space = dict(
        lr=neps.HPOFloat(lower=0.0001, upper=0.1, log=True, prior=0.02),
        momentum=neps.HPOFloat(lower=0.1, upper=0.9, prior=0.5),
        weight_decay=neps.HPOFloat(lower=1e-5, upper=1e-3, log=True, prior=1e-4),
        epoch=neps.HPOInteger(lower=10, upper=30, is_fidelity=True),
        batch_size=neps.HPOInteger(lower=4, upper=12, prior=4),
    )
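
    # Each sampled configuration reaches evaluate_pipeline as keyword
    # arguments, e.g. evaluate_pipeline(lr=0.0123, momentum=0.7,
    # weight_decay=2e-4, epoch=15, batch_size=8).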

    neps.run(
        evaluate_pipeline=evaluate_pipeline,
        pipeline_space=pipeline_space,
        root_directory="results/hpo_image_segmentation",
        fidelities_to_spend=500,
    )
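
    # NePS records every evaluated configuration and its objective under
    # root_directory; re-running with the same root_directory resumes the
    # search from that state instead of starting over.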