Skip to content

Wandb

Wandb plugin.

Todo

This plugin is experimental and out of date.

class WandbParams
dataclass
#

Parameters for initializing a wandb run.

This class is a dataclass that contains all the parameters that are used to initialize a wandb run. It is used by the WandbPlugin to initialize a run. It can be modified using the modify() method.

Please refer to the documentation of the wandb.init() method for more information on the parameters.

def modify(**kwargs) #

Modify the parameters of this instance.

This method returns a new instance of this class with the given parameters modified. This is useful, for example, when you want to add tags or notes to a run.

Source code in src/amltk/scheduling/plugins/wandb.py
def modify(self, **kwargs: Any) -> WandbParams:
    """Return a copy of these parameters with some fields overridden.

    The instance itself is left untouched; a new instance is built with
    `dataclasses.replace`, applying the given keyword arguments. This is
    handy for, e.g., adding tags or notes to a single run.

    Args:
        **kwargs: Field names mapped to their new values.

    Returns:
        A new instance with the requested fields replaced.
    """
    updated = replace(self, **kwargs)
    return updated

def run(name, config=None) #

Initialize a wandb run.

This method initializes a wandb run using the parameters of this instance. It returns the wandb run object.

PARAMETER DESCRIPTION
name

The name of the run.

TYPE: str

config

The configuration of the run.

TYPE: Mapping[str, Any] | None DEFAULT: None

RETURNS DESCRIPTION
WRun

The wandb run object.

Source code in src/amltk/scheduling/plugins/wandb.py
def run(
    self,
    name: str,
    config: Mapping[str, Any] | None = None,
) -> WRun:
    """Initialize a wandb run.

    Calls `wandb.init()` with the fields of this instance and returns the
    resulting run object.

    Args:
        name: The name of the run.
        config: The configuration of the run. An empty or missing mapping
            is passed to wandb as `None`.

    Returns:
        The wandb run object.

    Raises:
        RuntimeError: If `wandb.init()` returns `None`.
    """
    wandb_run = wandb.init(
        config=dict(config) if config else None,
        name=name,
        project=self.project,
        group=self.group,
        # `job_type` and `resume` are set on these params by
        # `WandbPlugin.trial_tracker()` / `WandbPlugin.run()`; forward them
        # so they are not silently dropped.
        job_type=self.job_type,
        resume=self.resume,
        tags=self.tags,
        entity=self.entity,
        notes=self.notes,
        reinit=self.reinit,
        dir=self.dir,
        config_exclude_keys=self.config_exclude_keys,
        config_include_keys=self.config_include_keys,
        mode=self.mode,
        allow_val_change=self.allow_val_change,
        force=self.force,
    )
    if wandb_run is None:
        raise RuntimeError("Wandb run was not initialized")

    return wandb_run

class WandbLiveRunWrap(params, fn, *, modify=None) #

Bases: Generic[P]

Wrap a function to log the results to a wandb run.

This class wraps a function that returns a trial report, so that the report's results are logged to a wandb run. The WandbTrialTracker uses it to wrap the target function.

PARAMETER DESCRIPTION
params

The parameters to initialize the wandb run.

TYPE: WandbParams

fn

The function to wrap.

TYPE: Callable[Concatenate[Trial, P], Report]

modify

A function that modifies the parameters of the wandb run before each trial.

TYPE: Callable[[Trial, WandbParams], WandbParams] | None DEFAULT: None

Source code in src/amltk/scheduling/plugins/wandb.py
def __init__(
    self,
    params: WandbParams,
    fn: Callable[Concatenate[Trial, P], Trial.Report],
    *,
    modify: Callable[[Trial, WandbParams], WandbParams] | None = None,
):
    """Store the run parameters, the target function and the optional hook.

    Args:
        params: Base parameters used to initialize the wandb run.
        fn: The function to wrap. It receives the trial first and returns
            a trial report.
        modify: Optional hook called with the trial and the base parameters
            before each call. It returns the parameters to use for that run.
    """
    super().__init__()
    self.params, self.fn, self.modify = params, fn, modify

def __call__(trial, *args, **kwargs) #

Call the wrapped function and log the results to a wandb run.

Source code in src/amltk/scheduling/plugins/wandb.py
def __call__(self, trial: Trial, *args: P.args, **kwargs: P.kwargs) -> Trial.Report:
    """Run the wrapped function inside a wandb run and log its report.

    The run parameters are the base ones, or the result of the `modify`
    hook when one was given. The live run is stored in
    `trial.extras["wandb"]` before the wrapped function is called. The
    report's dataframe is logged as a table, and the numeric entries of
    the report summary are written into the run summary.

    Returns:
        The report produced by the wrapped function.
    """
    if self.modify is None:
        params = self.params
    else:
        params = self.modify(trial, self.params)

    with params.run(name=trial.name, config=trial.config) as wandb_run:
        # Expose the live run to the wrapped function through the trial.
        trial.extras["wandb"] = wandb_run

        report = self.fn(trial, *args, **kwargs)

        # Log the full report as a wandb table.
        wandb_run.log({"table": wandb.Table(dataframe=report.df())})

        # Only numeric summary values go into the run summary.
        numeric_types = (int, float, np.number)
        numeric_summary = {}
        for key, value in report.summary.items():
            if isinstance(value, numeric_types):
                numeric_summary[key] = value
        wandb_run.summary.update(numeric_summary)

    wandb.finish()
    return report

class WandbTrialTracker(params, *, modify=None) #

Bases: Plugin

Track trials using wandb.

This class is a task plugin that tracks trials using wandb.

PARAMETER DESCRIPTION
params

The parameters to initialize the wandb run.

TYPE: WandbParams

modify

A function that modifies the parameters of the wandb run before each trial.

TYPE: Callable[[Trial, WandbParams], WandbParams] | None DEFAULT: None

Source code in src/amltk/scheduling/plugins/wandb.py
def __init__(
    self,
    params: WandbParams,
    *,
    modify: Callable[[Trial, WandbParams], WandbParams] | None = None,
):
    """Store the base run parameters and the optional per-trial hook.

    Args:
        params: Base parameters used to initialize each wandb run.
        modify: Optional hook called with the trial and the base parameters
            before each trial. It returns the parameters to use for that run.
    """
    super().__init__()
    self.params, self.modify = params, modify

name: str
classvar
#

The name of the plugin.

def attach_task(task) #

Attach the plugin to a task. No callbacks are registered; this only runs `_check_explicit_reinit_arg_with_executor` on the task's scheduler.

Source code in src/amltk/scheduling/plugins/wandb.py
@override
def attach_task(self, task: Task) -> None:
    """Use the task to register several callbacks."""
    self._check_explicit_reinit_arg_with_executor(task.scheduler)

def pre_submit(fn, *args, **kwargs) #

Wrap the target function to log the results to a wandb run.

This method wraps the target function to log the results to a wandb run and returns the wrapped function.

PARAMETER DESCRIPTION
fn

The target function.

TYPE: Callable[P, R]

args

The positional arguments of the target function.

TYPE: Any DEFAULT: ()

kwargs

The keyword arguments of the target function.

TYPE: Any DEFAULT: {}

RETURNS DESCRIPTION
tuple[Callable[P, R], tuple, dict] | None

The wrapped function, the positional arguments and the keyword arguments.

Source code in src/amltk/scheduling/plugins/wandb.py
@override
def pre_submit(
    self,
    fn: Callable[P, R],
    *args: Any,
    **kwargs: Any,
) -> tuple[Callable[P, R], tuple, dict] | None:
    """Wrap the target function so its results are logged to wandb.

    The function is wrapped in a `WandbLiveRunWrap` built from this
    plugin's params and `modify` hook. The arguments are passed through
    unchanged.

    Args:
        fn: The target function.
        args: The positional arguments of the target function.
        kwargs: The keyword arguments of the target function.

    Returns:
        A tuple of the wrapped function, the positional arguments and the
        keyword arguments.
    """
    wrapped = WandbLiveRunWrap(self.params, fn, modify=self.modify)  # type: ignore
    return wrapped, args, kwargs

def copy() #

Copy the plugin.

Source code in src/amltk/scheduling/plugins/wandb.py
@override
def copy(self) -> Self:
    """Copy the plugin."""
    return self.__class__(modify=self.modify, params=replace(self.params))

class WandbPlugin(*, project, group=None, entity=None, dir=None, mode='online') #

Log trials using wandb.

This class is the entry point for logging trials with wandb. It can create a trial_tracker() to pass into a Task(plugins=...), or create wandb.Run objects for custom purposes with run().

PARAMETER DESCRIPTION
project

The name of the project.

TYPE: str

group

The name of the group.

TYPE: str | None DEFAULT: None

entity

The name of the entity.

TYPE: str | None DEFAULT: None

dir

The directory to store the runs in.

TYPE: str | Path | None DEFAULT: None

mode

The mode to use for the runs.

TYPE: Literal['online', 'offline', 'disabled'] DEFAULT: 'online'

Source code in src/amltk/scheduling/plugins/wandb.py
def __init__(
    self,
    *,
    project: str,
    group: str | None = None,
    entity: str | None = None,
    dir: str | Path | None = None,  # noqa: A002
    mode: Literal["online", "offline", "disabled"] = "online",
):
    """Configure the plugin and make sure the run directory exists.

    Args:
        project: The name of the project.
        group: The name of the group.
        entity: The name of the entity.
        dir: The directory to store the runs in. If `None`, a directory
            named after `project` is used. It is created if it does not
            exist and stored as an absolute, resolved path.
        mode: The mode to use for the runs.
    """
    super().__init__()
    run_dir = Path(dir if dir is not None else project)
    run_dir.mkdir(parents=True, exist_ok=True)

    self.dir = run_dir.resolve().absolute()
    self.project = project
    self.group = group
    self.entity = entity
    self.mode = mode

def trial_tracker(job_type='trial', *, modify=None) #

Create a live tracker.

PARAMETER DESCRIPTION
job_type

The job type to use for the runs.

TYPE: str DEFAULT: 'trial'

modify

A function that modifies the parameters of the wandb run before each trial.

TYPE: Callable[[Trial, WandbParams], WandbParams] | None DEFAULT: None

RETURNS DESCRIPTION
WandbTrialTracker

A live tracker.

Source code in src/amltk/scheduling/plugins/wandb.py
def trial_tracker(
    self,
    job_type: str = "trial",
    *,
    modify: Callable[[Trial, WandbParams], WandbParams] | None = None,
) -> WandbTrialTracker:
    """Create a live tracker that logs each trial to its own wandb run.

    The tracker's base parameters take the project, entity, group,
    directory and mode configured on this plugin.

    Args:
        job_type: The job type to use for the runs.
        modify: A function that modifies the parameters of the wandb run
            before each trial.

    Returns:
        A live tracker.
    """
    base_params = WandbParams(
        project=self.project,
        entity=self.entity,
        group=self.group,
        dir=self.dir,
        mode=self.mode,  # type: ignore
        job_type=job_type,
    )
    tracker = WandbTrialTracker(base_params, modify=modify)
    return tracker

def run(*, name, job_type=None, group=None, config=None, tags=None, resume=None, notes=None) #

Create a wandb run.

See wandb.init() for more.

Source code in src/amltk/scheduling/plugins/wandb.py
def run(
    self,
    *,
    name: str,
    job_type: str | None = None,
    group: str | None = None,
    config: Mapping[str, Any] | None = None,
    tags: list[str] | None = None,
    resume: bool | str | None = None,
    notes: str | None = None,
) -> WRun:
    """Create a wandb run.

    The project, entity, directory and mode configured on this plugin are
    always used. See [`wandb.init()`](https://docs.wandb.ai/ref/python/init)
    for more.

    Args:
        name: The name of the run.
        job_type: The job type of the run.
        group: The group of the run. If `None`, the group configured on
            this plugin is used, consistent with `trial_tracker()`.
        config: The configuration of the run.
        tags: Tags to attach to the run.
        resume: Resume behaviour, forwarded to wandb.
        notes: Notes to attach to the run.

    Returns:
        The wandb run object.
    """
    params = WandbParams(
        project=self.project,
        entity=self.entity,
        # Previously `self.group` was ignored here, so a plugin-level group
        # never applied to custom runs.
        group=self.group if group is None else group,
        dir=self.dir,
        mode=self.mode,  # type: ignore
        job_type=job_type,
        tags=tags,
        resume=resume,
        notes=notes,
    )
    return params.run(name=name, config=config)