# Image segmentation HPO (hyperparameter optimization)

# Example pipeline adapted from: https://lightning.ai/lightning-ai/studios/image-segmentation-with-pytorch-lightning

import torch
from torchvision import transforms, datasets, models
import lightning as L
from lightning.pytorch.strategies import DDPStrategy
import os
from lightning.pytorch.utilities.rank_zero import rank_zero_only
from torch.optim.lr_scheduler import PolynomialLR


class LitSegmentation(L.LightningModule):
    """FCN-ResNet50 semantic-segmentation model with 21 output classes (Pascal VOC).

    Args:
        iters_per_epoch: Number of training batches per epoch; multiplied by
            ``trainer.max_epochs`` to size the polynomial LR schedule.
        lr: Initial SGD learning rate.
        momentum: SGD momentum.
        weight_decay: SGD weight decay.
    """

    # Pascal VOC masks use 255 for boundary/void pixels; they carry no class
    # label and must not contribute to the loss.
    IGNORE_INDEX = 255

    def __init__(self, iters_per_epoch, lr, momentum, weight_decay):
        super().__init__()
        # aux_loss=True builds an auxiliary head whose output the loss below does
        # not use, so its parameters get no gradient (presumably why this file's
        # trainer runs DDP with find_unused_parameters=True).
        self.model = models.segmentation.fcn_resnet50(num_classes=21, aux_loss=True)
        self.loss_fn = torch.nn.CrossEntropyLoss(ignore_index=self.IGNORE_INDEX)
        self.iters_per_epoch = iters_per_epoch
        self.lr = lr
        self.momentum = momentum
        self.weight_decay = weight_decay

    def _compute_loss(self, batch):
        """Return the cross-entropy loss of the main ("out") head on ``batch``."""
        images, targets = batch
        outputs = self.model(images)['out']
        # SegmentationData yields integer masks with a singleton channel dim
        # (B, 1, H, W); CrossEntropyLoss wants int64 class ids of shape (B, H, W).
        return self.loss_fn(outputs, targets.long().squeeze(1))

    def training_step(self, batch):
        loss = self._compute_loss(batch)
        self.log("train_loss", loss, sync_dist=True)
        return loss

    def validation_step(self, batch):
        loss = self._compute_loss(batch)
        self.log("val_loss", loss, sync_dist=True)
        return loss

    def configure_optimizers(self):
        optimizer = torch.optim.SGD(
            self.model.parameters(),
            lr=self.lr,
            momentum=self.momentum,
            weight_decay=self.weight_decay,
        )
        # total_iters counts optimizer steps, so the scheduler must be stepped
        # after every batch. Lightning's default interval is "epoch", which would
        # decay the LR only max_epochs times and leave it nearly constant.
        # NOTE(review): iters_per_epoch comes from the non-distributed dataloader;
        # under multi-GPU DDP each rank runs fewer batches, so the schedule will
        # not fully reach its end. self.trainer.estimated_stepping_batches would
        # be exact if that matters.
        scheduler = PolynomialLR(
            optimizer, total_iters=self.iters_per_epoch * self.trainer.max_epochs, power=0.9
        )
        return {
            "optimizer": optimizer,
            "lr_scheduler": {"scheduler": scheduler, "interval": "step"},
        }



class SegmentationData(L.LightningDataModule):
    """Pascal VOC 2012 segmentation data: "train" split for training, "val" for validation.

    Args:
        batch_size: Batch size of both dataloaders.
        num_workers: DataLoader worker processes. ``None`` (default) uses
            ``min(63, os.cpu_count())`` rather than a fixed 63, so machines with
            fewer cores are not oversubscribed.
        data_dir: Directory the dataset is downloaded to / read from.
    """

    IMAGE_SIZE = (256, 256)
    # ImageNet statistics, matching the torchvision ResNet backbone convention.
    MEAN = [0.485, 0.456, 0.406]
    STD = [0.229, 0.224, 0.225]

    def __init__(self, batch_size=4, num_workers=None, data_dir="data"):
        super().__init__()
        self.batch_size = batch_size
        if num_workers is None:
            num_workers = min(63, os.cpu_count() or 1)
        self.num_workers = num_workers
        self.data_dir = data_dir

    def prepare_data(self):
        # The downloaded archive doubles as the "already downloaded" marker.
        dataset_path = os.path.join(self.data_dir, "VOCtrainval_11-May-2012.tar")
        if not os.path.exists(dataset_path):
            datasets.VOCSegmentation(root=self.data_dir, download=True)

    def _transforms(self):
        """Return ``(image_transform, target_transform)`` shared by both splits."""
        image_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Resize(self.IMAGE_SIZE, antialias=True),
            transforms.Normalize(mean=self.MEAN, std=self.STD),
        ])
        # Masks hold integer class ids. ToTensor would rescale them by 1/255
        # (so .long() collapses every class to 0), and bilinear resizing would
        # blend neighbouring ids. PILToTensor keeps the raw uint8 ids and
        # NEAREST resizing keeps every pixel a valid label.
        target_transform = transforms.Compose([
            transforms.PILToTensor(),
            transforms.Resize(self.IMAGE_SIZE, interpolation=transforms.InterpolationMode.NEAREST),
        ])
        return image_transform, target_transform

    def _dataloader(self, image_set, shuffle):
        """Build a DataLoader over the given VOC 2012 ``image_set`` split."""
        transform, target_transform = self._transforms()
        dataset = datasets.VOCSegmentation(
            root=self.data_dir,
            year='2012',
            image_set=image_set,
            transform=transform,
            target_transform=target_transform,
        )
        return torch.utils.data.DataLoader(
            dataset, batch_size=self.batch_size, shuffle=shuffle, num_workers=self.num_workers
        )

    def train_dataloader(self):
        return self._dataloader('train', shuffle=True)

    def val_dataloader(self):
        return self._dataloader('val', shuffle=False)



def main(**kwargs):
    """Train one hyperparameter configuration and return its validation loss.

    Passed to ``neps.run`` as ``evaluate_pipeline``. Every keyword is optional:

    Keyword Args:
        batch_size: Batch size (default 4).
        lr: SGD learning rate (default 0.02).
        momentum: SGD momentum (default 0.9).
        weight_decay: SGD weight decay (default 1e-4).
        epoch: Number of training epochs, i.e. the fidelity (default 30).

    Returns:
        float: The ``val_loss`` entry of ``trainer.logged_metrics`` after fit.
    """
    data = SegmentationData(kwargs.get("batch_size", 4))
    # The training dataloader is built here, before trainer.fit() would call
    # prepare_data(). Without downloading first, the very first run fails
    # because the dataset does not exist yet.
    data.prepare_data()
    iters_per_epoch = len(data.train_dataloader())
    model = LitSegmentation(
        iters_per_epoch,
        kwargs.get("lr", 0.02),
        kwargs.get("momentum", 0.9),
        kwargs.get("weight_decay", 1e-4),
    )
    trainer = L.Trainer(
        max_epochs=kwargs.get("epoch", 30),
        # The model's auxiliary head receives no gradient, so DDP must be told
        # to tolerate unused parameters.
        strategy=DDPStrategy(find_unused_parameters=True),
        enable_checkpointing=False,
    )
    trainer.fit(model, data)
    return trainer.logged_metrics["val_loss"].item()


if __name__ == "__main__":
    import logging

    import neps

    logging.basicConfig(level=logging.INFO)

    # NePS search space. "epoch" is the fidelity (training budget handed to
    # main as max_epochs); the remaining entries carry user priors.
    pipeline_space = {
        "lr": neps.Float(lower=0.0001, upper=0.1, log=True, prior=0.02),
        "momentum": neps.Float(lower=0.1, upper=0.9, prior=0.5),
        "weight_decay": neps.Float(lower=1e-5, upper=1e-3, log=True, prior=1e-4),
        "epoch": neps.Integer(lower=10, upper=30, is_fidelity=True),
        "batch_size": neps.Integer(lower=4, upper=12, prior=4),
    }

    neps.run(
        evaluate_pipeline=main,
        pipeline_space=pipeline_space,
        root_directory="hpo_image_segmentation",
        max_evaluations_total=500,
    )