# PyTorch Lightning FSDP example

# Based on: https://lightning.ai/docs/pytorch/stable/advanced/model_parallel/fsdp.html

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader

import lightning as L
from lightning.pytorch.strategies import FSDPStrategy
from lightning.pytorch.demos import Transformer, WikiText2


class LanguageModel(L.LightningModule):
    def __init__(self, vocab_size, lr):
        super().__init__()
        self.model = Transformer(  # 1B parameters
            vocab_size=vocab_size,
            nlayers=32,
            nhid=4096,
            ninp=1024,
            nhead=64,
        )
        self.lr = lr

    def training_step(self, batch):
        input, target = batch
        output = self.model(input, target)
        loss = F.nll_loss(output, target.view(-1))  # demo Transformer returns log-probs
        self.log("train_loss", loss, prog_bar=True)
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.lr)


def evaluate_pipeline(lr=0.1, epoch=20):
    L.seed_everything(42)

    # Data
    dataset = WikiText2()
    train_dataloader = DataLoader(dataset)  # default batch_size=1; increase for real runs

    # Model
    model = LanguageModel(vocab_size=dataset.vocab_size, lr=lr)

    # Trainer
    trainer = L.Trainer(accelerator="cuda", strategy=FSDPStrategy(), max_epochs=epoch)
    trainer.fit(model, train_dataloader)
    return trainer.logged_metrics["train_loss"].detach().item()
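

# Optional: a more memory-efficient FSDP setup, following the Lightning docs
# linked above. Wrapping each Transformer block as its own FSDP unit and
# checkpointing its activations trades recompute for lower peak memory.
# `make_fsdp_strategy` is a sketch for illustration, not part of the Lightning
# or NePS API; pass its result as `strategy=` to the Trainer above.
def make_fsdp_strategy():
    policy = {nn.TransformerEncoderLayer, nn.TransformerDecoderLayer}
    return FSDPStrategy(
        auto_wrap_policy=policy,  # shard at the Transformer-block level
        activation_checkpointing_policy=policy,  # recompute activations in backward
    )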


if __name__ == "__main__":
    import neps
    import logging

    logging.basicConfig(level=logging.INFO)

    pipeline_space = dict(
        lr=neps.Float(
            lower=0.0001,
            upper=0.1,
            log=True,
            prior=0.01,
        ),
        epoch=neps.Integer(
            lower=1,
            upper=3,
            is_fidelity=True,
        ),
    )

    neps.run(
        evaluate_pipeline=evaluate_pipeline,
        pipeline_space=pipeline_space,
        root_directory="results/pytorch_lightning_fsdp",
        max_evaluations_total=5,
    )
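
    # Note: FSDP shards the model across all visible GPUs. To pin the device
    # count explicitly, pass `devices=N` to the Trainer in evaluate_pipeline,
    # e.g. L.Trainer(accelerator="cuda", devices=4, strategy=FSDPStrategy()).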