Evaluation

amltk.sklearn.evaluation #

This module contains the cross-validation evaluation protocol.

This protocol creates a cross-validation task to be used in parallelization and optimization. It represents a typical cross-validation evaluation for sklearn, handling some of the minor nuances of sklearn and its interaction with optimization and parallelization.

Please see CVEvaluation for more information on usage.

PostSplitSignature module-attribute #

PostSplitSignature: TypeAlias = Callable[
    [Trial, int, "CVEvaluation.PostSplitInfo"],
    "CVEvaluation.PostSplitInfo",
]

A type alias for the post split callback signature.

Please see PostSplitInfo for more information on the information available to this callback.

def my_post_split(
    trial: Trial,
    split_number: int,
    eval: CVEvaluation.PostSplitInfo
) -> CVEvaluation.PostSplitInfo:
    ...

TaskTypeName module-attribute #

TaskTypeName: TypeAlias = Literal[
    "binary",
    "multiclass",
    "multilabel-indicator",
    "multiclass-multioutput",
    "continuous",
    "continuous-multioutput",
]

A type alias for the task type name as defined by sklearn.
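For reference, these values mirror what sklearn's type_of_target returns; a minimal sketch of the correspondence (expected outputs shown as comments):

from sklearn.utils.multiclass import type_of_target

type_of_target([0, 1, 1, 0])                 # "binary"
type_of_target([0, 1, 2, 2])                 # "multiclass"
type_of_target([[0, 1], [1, 1]])             # "multilabel-indicator"
type_of_target([0.5, 1.2, 3.3])              # "continuous"
type_of_target([[0.5, 1.0], [1.2, 2.0]])     # "continuous-multioutput"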

XLike module-attribute #

A type alias for the X input data type as defined by sklearn.

YLike module-attribute #

A type alias for the y input data type as defined by sklearn.

CVEarlyStoppingProtocol #

Bases: Protocol

Protocol for early stopping in cross-validation.

Your class should implement the update() and should_stop() methods. You can optionally inherit from this class, but it is not required.

import numpy as np

from amltk.optimization import Trial
from amltk.sklearn import CVEvaluation

class MyStopper:

    def update(self, report: Trial.Report) -> None:
        if report.status is Trial.Status.SUCCESS:
            ...  # do some update logic

    def should_stop(
        self,
        trial: Trial,
        scores: CVEvaluation.SplitScores,
    ) -> bool | Exception:
        mean_scores_up_to_current_split = np.mean(scores.val["accuracy"])
        if mean_scores_up_to_current_split > 0.9:
            return False  # Keep going
        return True  # Stop evaluating

should_stop #

should_stop(
    trial: Trial, scores: SplitScores
) -> bool | Exception

Determines whether the cross-validation should stop early.

PARAMETER DESCRIPTION
trial

The trial that is currently being evaluated.

TYPE: Trial

scores

The scores from the evaluated splits.

TYPE: SplitScores

RETURNS DESCRIPTION
bool | Exception

True if the cross-validation should stop, False if it should continue, or an Exception if it should stop and you'd like a custom error to be registered with the trial.
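For example, returning an exception instead of True lets you attach a custom error message to the trial's failure report; a minimal sketch, assuming an "accuracy" metric is being tracked and a cutoff of 0.5:

import numpy as np

from amltk.optimization import Trial
from amltk.sklearn import CVEvaluation

class ThresholdStopper:

    def update(self, report: Trial.Report) -> None:
        ...  # No cross-trial state is needed for this simple strategy

    def should_stop(
        self,
        trial: Trial,
        scores: CVEvaluation.SplitScores,
    ) -> bool | Exception:
        mean_acc = float(np.mean(scores.val["accuracy"]))
        if mean_acc < 0.5:
            # Stops the trial and records this message with the failure report
            return RuntimeError(f"Stopped early: mean accuracy {mean_acc:.3f} < 0.5")
        return False  # Keep evaluating the remaining splits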

Source code in src/amltk/sklearn/evaluation.py
def should_stop(
    self,
    trial: Trial,
    scores: CVEvaluation.SplitScores,
) -> bool | Exception:
    """Determines whether the cross-validation should stop early.

    Args:
        trial: The trial that is currently being evaluated.
        scores: The scores from the evlauated splits.

    Returns:
        `True` if the cross-validation should stop, `False` if it should
        continue, or an `Exception` if it should stop and you'd like a custom
        error to be registered with the trial.
    """
    ...

update #

update(report: Report) -> None

Update the protocol with a new report.

This will be called when a trial has been completed, whether it succeeded or failed. You can check for successful trials by using report.status.
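For instance, an update() implementation might keep a running record of the scores of successful trials for later use in should_stop(); a minimal sketch, assuming an "accuracy" metric:

from amltk.optimization import Trial

class RunningScores:

    def __init__(self) -> None:
        self.scores: list[float] = []

    def update(self, report: Trial.Report) -> None:
        # Only successful trials carry reported metric values
        if report.status is Trial.Status.SUCCESS:
            self.scores.append(report.values["accuracy"])

    def should_stop(self, trial, scores) -> bool | Exception:
        return False  # Decision logic would go here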

PARAMETER DESCRIPTION
report

The report from the trial.

TYPE: Report

Source code in src/amltk/sklearn/evaluation.py
def update(self, report: Trial.Report) -> None:
    """Update the protocol with a new report.

    This will be called when a trial has been completed, either successfully
    or failed. You can check for successful trials by using
    [`report.status`][amltk.optimization.Trial.Report.status].

    Args:
        report: The report from the trial.
    """
    ...

CVEvaluation #

CVEvaluation(
    X: XLike,
    y: YLike,
    *,
    X_test: XLike | None = None,
    y_test: YLike | None = None,
    splitter: (
        Literal["holdout", "cv"]
        | BaseShuffleSplit
        | BaseCrossValidator
    ) = "cv",
    n_splits: int = 5,
    holdout_size: float = 0.33,
    train_score: bool = False,
    store_models: bool = False,
    rebalance_if_required_for_stratified_splitting: (
        bool | None
    ) = None,
    additional_scorers: Mapping[str, _Scorer] | None = None,
    random_state: Seed | None = None,
    params: Mapping[str, Any] | None = None,
    task_hint: (
        TaskTypeName
        | Literal["classification", "regression", "auto"]
    ) = "auto",
    working_dir: str | Path | PathBucket | None = None,
    on_error: Literal["raise", "fail"] = "fail",
    post_split: PostSplitSignature | None = None,
    post_processing: (
        Callable[[Report, Node, CompleteEvalInfo], Report]
        | None
    ) = None,
    post_processing_requires_models: bool = False
)

Bases: Emitter

Cross-validation evaluation protocol.

This protocol will create a cross-validation task to be used in parallel and optimization. It represents a typical cross-validation evaluation for sklearn.

Aside from the init parameters, it expects:

  • The pipeline you are optimizing can be made into a sklearn.pipeline.Pipeline by calling .build("sklearn").
  • The seed for the trial will be passed as a param to .configure(). If you have a component that accepts a random_state parameter, you can use a request() so that it will be seeded correctly.

from amltk.sklearn import CVEvaluation
from amltk.pipeline import Component, request
from amltk.optimization import Metric

from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import get_scorer
from sklearn.datasets import load_iris
from pathlib import Path

pipeline = Component(
    RandomForestClassifier,
    config={"random_state": request("random_state")},
    space={"n_estimators": (10, 100), "criterion": ["gini", "entropy"]},
)

working_dir = Path("./some-path")
X, y = load_iris(return_X_y=True)
evaluator = CVEvaluation(
    X,
    y,
    n_splits=3,
    splitter="cv",
    additional_scorers={"roc_auc": get_scorer("roc_auc_ovr")},
    store_models=False,
    train_score=True,
    working_dir=working_dir,
)

history = pipeline.optimize(
    target=evaluator.fn,
    metric=Metric("accuracy", minimize=False, bounds=(0, 1)),
    working_dir=working_dir,
    max_trials=1,
)
print(history.df())
evaluator.bucket.rmdir()  # Cleanup
                                                     status  ...  profile:cv:train_score:time:unit
name                                                         ...                                  
config_id=1_seed=1641201137_budget=None_instanc...  success  ...                           seconds

[1 rows x 114 columns]

If you need to pass specific configuration items to your pipeline during configuration, you can do so using a request() in the config of your pipeline.

In the example below, we allow the pipeline to be configured with "n_jobs" and pass it to the CVEvaluation using the params argument.

from amltk.sklearn import CVEvaluation
from amltk.pipeline import Component, request
from amltk.optimization import Metric

from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import get_scorer
from sklearn.datasets import load_iris
from pathlib import Path

working_dir = Path("./some-path")
X, y = load_iris(return_X_y=True)

pipeline = Component(
    RandomForestClassifier,
    config={
        "random_state": request("random_state"),
        # Allow it to be configured with n_jobs
        "n_jobs": request("n_jobs", default=None)
    },
    space={"n_estimators": (10, 100), "criterion": ["gini", "entropy"]},
)

evaluator = CVEvaluation(
    X,
    y,
    working_dir=working_dir,
    # Use the `configure` keyword in params to pass to the `n_jobs`
    # Anything in the pipeline requesting `n_jobs` will get the value
    params={"configure": {"n_jobs": 2}}
)
history = pipeline.optimize(
    target=evaluator.fn,
    metric=Metric("accuracy"),
    working_dir=working_dir,
    max_trials=1,
)
print(history.df())
evaluator.bucket.rmdir()  # Cleanup
                                                     status  ...  profile:cv:split_4:time:unit
name                                                         ...                              
config_id=1_seed=1348220125_budget=None_instanc...  success  ...                       seconds

[1 rows x 113 columns]

CV Early Stopping

To see more about early stopping, please see CVEvaluation.cv_early_stopping_plugin().

PARAMETER DESCRIPTION
X

The features to use for training.

TYPE: XLike

y

The target to use for training.

TYPE: YLike

X_test

The features to use for testing. If provided, all scorers will be calculated on this data as well. Must be provided with y_test=.

Scorer params for test scoring

Due to nuances of sklearn's metadata routing, if you need to provide parameters to the scorer for the test data, you must prefix these with "test_". For example, if you need to provide pos_label to the scorer for the test data, provide test_pos_label in the params argument.

TYPE: XLike | None DEFAULT: None

y_test

The target to use for testing. If provided, all scorers will be calculated on this data as well. Must be provided with X_test=.

TYPE: YLike | None DEFAULT: None

splitter

The cross-validation splitter to use. This can be either "holdout" or "cv". Please see the related arguments below. If a scikit-learn cross-validator is provided, this will be used directly.
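If the "holdout"/"cv" shortcuts don't fit, you can pass a scikit-learn splitter directly; a minimal sketch (the particular splitter and its settings are illustrative):

from pathlib import Path

from sklearn.datasets import load_iris
from sklearn.model_selection import StratifiedKFold

from amltk.sklearn import CVEvaluation

X, y = load_iris(return_X_y=True)
evaluator = CVEvaluation(
    X,
    y,
    # A custom splitter is used as-is; n_splits= and holdout_size= are ignored
    splitter=StratifiedKFold(n_splits=10, shuffle=True, random_state=42),
    working_dir=Path("./some-path"),
)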

TYPE: Literal['holdout', 'cv'] | BaseShuffleSplit | BaseCrossValidator DEFAULT: 'cv'

n_splits

The number of cross-validation splits to use. This argument will be ignored if splitter="holdout" or a custom splitter is provided for splitter=.

TYPE: int DEFAULT: 5

holdout_size

The size of the holdout set to use. This argument will be ignored if splitter="cv" or a custom splitter is provided for splitter=.

TYPE: float DEFAULT: 0.33

train_score

Whether to score on the training data as well. This will take extra time as predictions will be made on the training data as well.

TYPE: bool DEFAULT: False

store_models

Whether to store the trained models in the trial.

TYPE: bool DEFAULT: False

rebalance_if_required_for_stratified_splitting

Whether the CVEvaluator should rebalance the training data to allow for stratified splitting.

  • If True, rebalancing will be done if required. That is, when the splitter= is "cv" or a StratifiedKFold and there are fewer instances of a minority class than n_splits=.
  • If None, rebalancing will be done if required. Same as True but raises a warning if it occurs.
  • If False, rebalancing will never be done.

TYPE: bool | None DEFAULT: None

additional_scorers

Additional scorers to use.

TYPE: Mapping[str, _Scorer] | None DEFAULT: None

random_state

The random state to use for the cross-validation splitter=. If a custom splitter is provided, this will be ignored.

TYPE: Seed | None DEFAULT: None

params

Parameters to pass to the estimator, splitter or scorers. See scikit-learn.org/stable/metadata_routing.html for more information.

You may additionally include the following as dictionaries:

  • "configure": Parameters to pass to the pipeline for configure(). Please the example in the class docstring for more information.
  • "build": Parameters to pass to the pipeline for build().

    from imblearn.pipeline import Pipeline as ImbalancedPipeline
    CVEvaluator(
        ...,
        params={
            "build": {
                "builder": "sklearn",
                "pipeline_type": ImbalancedPipeline
            }
        }
    )
    
  • "transform_context": The transform context to use for configure().

Scorer params for test scoring

Due to nuances of sklearn's metadata routing, if you need to provide parameters to the scorer for the test data, you must prefix these with "test_". For example, if you need to provide pos_label to the scorer for the test data, you can provide test_pos_label in the params argument.
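A hedged sketch of how these pieces fit together (the sample_weight values and the test_pos_label key are illustrative, and whether routed parameters actually reach your estimator or scorers depends on sklearn's metadata-routing setup):

import numpy as np
from pathlib import Path

from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split

from amltk.sklearn import CVEvaluation

X, X_test, y, y_test = train_test_split(*load_breast_cancer(return_X_y=True), random_state=0)

evaluator = CVEvaluation(
    X,
    y,
    X_test=X_test,
    y_test=y_test,
    params={
        "sample_weight": np.ones(len(y)),  # routed to the estimator/scorers
        "test_pos_label": 1,               # "test_" prefix: only used when scoring the test data
        "configure": {"n_jobs": 2},        # handed to pipeline.configure()
        "build": {"builder": "sklearn"},   # handed to pipeline.build()
    },
    working_dir=Path("./some-path"),
)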

TYPE: Mapping[str, Any] | None DEFAULT: None

task_hint

A string indicating the task type matching those used by sklearn's type_of_target. This can be either "binary", "multiclass", "multilabel-indicator", "continuous", "continuous-multioutput" or "multiclass-multioutput".

You can also provide "classification" or "regression" for a more general hint.

If not provided, this will be inferred from the target data. If you know this value, it is recommended to provide it as sometimes the target is ambiguous and sklearn may infer incorrectly.

TYPE: TaskTypeName | Literal['classification', 'regression', 'auto'] DEFAULT: 'auto'

working_dir

The directory to use for storing data. If not provided, a temporary directory will be used. If provided as a string or a Path, it will be used as the path to the directory.

TYPE: str | Path | PathBucket | None DEFAULT: None

on_error

What to do if an error occurs in the task. This can be either "raise" or "fail". If "raise", the error will be raised and the task will fail. If "fail", the error will be caught and the task will report a failure report with the error message stored inside. Set this to "fail" if you want to continue optimization even if some trials fail.

TYPE: Literal['raise', 'fail'] DEFAULT: 'fail'

post_split

If provided, this callable will be called with a PostSplitInfo.

For example, this could be useful if you'd like to save out-of-fold predictions for later use.

def my_post_split(
    trial: Trial,
    split_number: int,
    info: CVEvaluation.PostSplitInfo,
) -> CVEvaluation.PostSplitInfo:
    # Out-of-fold predictions on this split's validation data
    X_val = info.X[info.i_val]
    oof_preds = info.model.predict(X_val)

    trial.store({f"oof_predictions_{split_number}.npy": oof_preds})
    return info

Run in the worker

This callable will be pickled and sent to the worker that is executing an evaluation. This means you should avoid relying on any large objects if your callable is an object, as the object will be pickled and sent to the worker. It also means you cannot rely on information obtained from other trials, since once the callable has been sent to a worker it is no longer updatable from the main process.

You should also avoid holding on to references to either the model or the large data passed to the function in the PostSplitInfo.

This parameter should primarily be used for callables that rely solely on the output of the current trial and wish to store/add additional information to the trial itself.

TYPE: PostSplitSignature | None DEFAULT: None

post_processing

If provided, this callable will be called with all of the evaluated splits and the final report that will be returned. This can be used to do things such as augment the final scores if required, clean up any resources, or perform any other tasks that should be run after the evaluation has completed. It will be handed a Report and a CompleteEvalInfo, which contains all the information about the evaluation. If your function requires the individual models, you can set post_processing_requires_models=True. By default this is False, as it requires having all models in memory at once.

This can be useful when you'd like to report the score of a bagged model, i.e. an ensemble of all validation models. Another example is adding to the summary the score the model would achieve if refit on all the data.

from amltk.sklearn.voting import voting_with_prefitted_estimators

# Compute the test score of all fold models bagged together
def my_post_processing(
    report: Trial.Report,
    pipeline: Node,
    info: CVEvaluator.CompleteEvalInfo,
) -> Trial.Report:
    bagged_model = voting_with_prefitted_estimators(info.models)
    acc = info.scorers["accuracy"]
    bagged_score = acc(bagged_model, info.X_test, info.y_test)
    report.summary["bagged_test_score"] = bagged_score
    return report

Run in the worker

This callable will be pickled and sent to the worker that is executing an evaluation. This means you should avoid relying on any large objects if your callable is an object, as the object will be pickled and sent to the worker. It also means you cannot rely on information obtained from other trials, since once the callable has been sent to a worker it is no longer updatable from the main process.

This parameter should primarily be used for callables that will augment the report or what is stored with the trial. It should rely solely on the current trial to prevent unexpected issues.

TYPE: Callable[[Report, Node, CompleteEvalInfo], Report] | None DEFAULT: None

post_processing_requires_models

Whether the post_processing function requires the models to be passed to it. If True, the models will be passed to the function in the CompleteEvalInfo object. If False, the models will not be passed to the function. By default this is False as this requires having all models in memory at once.

TYPE: bool DEFAULT: False

Source code in src/amltk/sklearn/evaluation.py
def __init__(  # noqa: PLR0913, C901
    self,
    X: XLike,  # noqa: N803
    y: YLike,
    *,
    X_test: XLike | None = None,  # noqa: N803
    y_test: YLike | None = None,
    splitter: (
        Literal["holdout", "cv"] | BaseShuffleSplit | BaseCrossValidator
    ) = "cv",
    n_splits: int = 5,  # sklearn default
    holdout_size: float = 0.33,
    train_score: bool = False,
    store_models: bool = False,
    rebalance_if_required_for_stratified_splitting: bool | None = None,
    additional_scorers: Mapping[str, _Scorer] | None = None,
    random_state: Seed | None = None,  # Only used if cv is an int/float
    params: Mapping[str, Any] | None = None,
    task_hint: (
        TaskTypeName | Literal["classification", "regression", "auto"]
    ) = "auto",
    working_dir: str | Path | PathBucket | None = None,
    on_error: Literal["raise", "fail"] = "fail",
    post_split: PostSplitSignature | None = None,
    post_processing: (
        Callable[[Trial.Report, Node, CVEvaluation.CompleteEvalInfo], Trial.Report]
        | None
    ) = None,
    post_processing_requires_models: bool = False,
) -> None:
    """Initialize the evaluation protocol.

    Args:
        X: The features to use for training.
        y: The target to use for training.
        X_test: The features to use for testing. If provided, all
            scorers will be calculated on this data as well.
            Must be provided with `y_test=`.

            !!! tip "Scorer params for test scoring"

                Due to nuances of sklearn's metadata routing, if you need to provide
                parameters to the scorer for the test data, you can prefix these
                with `#!python "test_"`. For example, if you need to provide
                `pos_label` to the scorer for the test data, you must provide
                `test_pos_label` in the `params` argument.

        y_test: The target to use for testing. If provided, all
            scorers will be calculated on this data as well.
            Must be provided with `X_test=`.
        splitter: The cross-validation splitter to use. This can be either
            `#!python "holdout"` or `#!python "cv"`. Please see the related
            arguments below. If a scikit-learn cross-validator is provided,
            this will be used directly.
        n_splits: The number of cross-validation splits to use.
            This argument will be ignored if `#!python splitter="holdout"`
            or a custom splitter is provided for `splitter=`.
        holdout_size: The size of the holdout set to use. This argument
            will be ignored if `#!python splitter="cv"` or a custom splitter
            is provided for `splitter=`.
        train_score: Whether to score on the training data as well. This
            will take extra time as predictions will be made on the
            training data as well.
        store_models: Whether to store the trained models in the trial.
        rebalance_if_required_for_stratified_splitting: Whether the CVEvaluator
            should rebalance the training data to allow for stratified splitting.
            * If `True`, rebalancing will be done if required. That is when
                the `splitter=` is `"cv"` or a `StratifiedKFold` and
                there are fewer instances of a minority class than `n_splits=`.
            * If `None`, rebalancing will be done if required it. Same
                as `True` but raises a warning if it occurs.
            * If `False`, rebalancing will never be done.
        additional_scorers: Additional scorers to use.
        random_state: The random state to use for the cross-validation
            `splitter=`. If a custom splitter is provided, this will be
            ignored.
        params: Parameters to pass to the estimator, splitter or scorers.
            See https://scikit-learn.org/stable/metadata_routing.html for
            more information.

            You may also additionally include the following as dictionarys:

            * `#!python "configure"`: Parameters to pass to the pipeline
                for [`configure()`][amltk.pipeline.Node.configure]. Please
                the example in the class docstring for more information.
            * `#!python "build"`: Parameters to pass to the pipeline for
                [`build()`][amltk.pipeline.Node.build].

                ```python
                from imblearn.pipeline import Pipeline as ImbalancedPipeline
                CVEvaluator(
                    ...,
                    params={
                        "build": {
                            "builder": "sklearn",
                            "pipeline_type": ImbalancedPipeline
                        }
                    }
                )
                ```

            * `#!python "transform_context"`: The transform context to use
                for [`configure()`][amltk.pipeline.Node.configure].

            !!! tip "Scorer params for test scoring"

                Due to nuances of sklearn's metadata routing, if you need to provide
                parameters to the scorer for the test data, you must prefix these
                with `#!python "test_"`. For example, if you need to provide
                `pos_label` to the scorer for the test data, you can provide
                `test_pos_label` in the `params` argument.

        task_hint: A string indicating the task type matching those
            use by sklearn's `type_of_target`. This can be either
            `#!python "binary"`, `#!python "multiclass"`,
            `#!python "multilabel-indicator"`, `#!python "continuous"`,
            `#!python "continuous-multioutput"` or
            `#!python "multiclass-multioutput"`.

            You can also provide `#!python "classification"` or
            `#!python "regression"` for a more general hint.

            If not provided, this will be inferred from the target data.
            If you know this value, it is recommended to provide it as
            sometimes the target is ambiguous and sklearn may infer
            incorrectly.
        working_dir: The directory to use for storing data. If not provided,
            a temporary directory will be used. If provided as a string
            or a `Path`, it will be used as the path to the directory.
        on_error: What to do if an error occurs in the task. This can be
            either `#!python "raise"` or `#!python "fail"`. If `#!python "raise"`,
            the error will be raised and the task will fail. If `#!python "fail"`,
            the error will be caught and the task will report a failure report
            with the error message stored inside.
            Set this to `#!python "fail"` if you want to continue optimization
            even if some trials fail.
        post_split: If provided, this callable will be called with a
            [`PostSplitInfo`][amltk.sklearn.evaluation.CVEvaluation.PostSplitInfo].

            For example, this could be useful if you'd like to save out-of-fold
            predictions for later use.

            ```python
            def my_post_split(
                split_number: int,
                info: CVEvaluator.PostSplitInfo,
            ) -> None:
                X_val, y_val = info.val
                oof_preds = fitted_model.predict(X_val)

                split = info.current_split
                info.trial.store({f"oof_predictions_{split}.npy": oof_preds})
                return info
            ```

            !!! warning "Run in the worker"

                This callable will be pickled and sent to the worker that is
                executing an evaluation. This means that you should mitigate
                relying on any large objects if your callalbe is an object, as
                the object will get pickled and sent to the worker. This also means
                you can not rely on information obtained from other trials as when
                sending the callable to a worker, it is no longer updatable from the
                main process.

                You should also avoid holding on to references to either the model
                or large data that is passed in
                [`PostSplitInfo`][amltk.sklearn.evaluation.CVEvaluation.PostSplitInfo]
                to the function.

                This parameter should primarily be used for callables that rely
                solely on the output of the current trial and wish to store/add
                additional information to the trial itself.

        post_processing: If provided, this callable will be called with all of the
            evaluated splits and the final report that will be returned.
            This can be used to do things such as augment the final scores
            if required, cleanup any resources or any other tasks that should be
            run after the evaluation has completed. This will be handed a
            [`Report`][amltk.optimization.trial.Trial.Report] and a
            [`CompleteEvalInfo`][amltk.sklearn.evaluation.CVEvaluation.CompleteEvalInfo],
            which contains all the information about the evaluation. If your
            function requires the individual models, you can set
            `post_processing_requires_models=True`. By default this is `False`
            as this requires having all models in memory at once.

            This can be useful when you'd like to report the score of a bagged
            model, i.e. an ensemble of all validation models. Another example
            is if you'd like to add to the summary, the score of what the model
            would be if refit on all the data.

            ```python
            from amltk.sklearn.voting import voting_with_prefitted_estimators

            # Compute the test score of all fold models bagged together
            def my_post_processing(
                report: Trial.Report,
                pipeline: Node,
                info: CVEvaluator.CompleteEvalInfo,
            ) -> Trial.Report:
                bagged_model = voting_with_prefitted_estimators(info.models)
                acc = info.scorers["accuracy"]
                bagged_score = acc(bagged_model, info.X_test, info.y_test)
                report.summary["bagged_test_score"] = bagged_score
                return report
            ```

            !!! warning "Run in the worker"

                This callable will be pickled and sent to the worker that is
                executing an evaluation. This means that you should mitigate
                relying on any large objects if your callalbe is an object, as
                the object will get pickled and sent to the worker. This also means
                you can not rely on information obtained from other trials as when
                sending the callable to a worker, it is no longer updatable from the
                main process.

                This parameter should primarily be used for callables that will
                augment the report or what is stored with the trial. It should
                rely solely on the current trial to prevent unexpected issues.

        post_processing_requires_models: Whether the `post_processing` function
            requires the models to be passed to it. If `True`, the models will
            be passed to the function in the `CompleteEvalInfo` object. If `False`,
            the models will not be passed to the function. By default this is
            `False` as this requires having all models in memory at once.

    """
    super().__init__()
    if (X_test is not None and y_test is None) or (
        y_test is not None and X_test is None
    ):
        raise ValueError(
            "Both `X_test`, `y_test` must be provided together if one is provided.",
        )

    match working_dir:
        case None:
            tmpdir = Path(
                tempfile.mkdtemp(
                    prefix=self.TMP_DIR_PREFIX,
                    suffix=datetime.now().isoformat(),
                ),
            )
            bucket = PathBucket(tmpdir)
        case str() | Path():
            bucket = PathBucket(working_dir)
        case PathBucket():
            bucket = working_dir

    match task_hint:
        case "classification" | "regression" | "auto":
            task_type = identify_task_type(y, task_hint=task_hint)
        case (
            "binary"
            | "multiclass"
            | "multilabel-indicator"
            | "continuous"
            | "continuous-multioutput"
            | "multiclass-multioutput"  #
        ):
            task_type = task_hint
        case _:
            raise ValueError(
                f"Invalid {task_hint=} provided. Must be in {_valid_task_types}"
                f"\n{type(task_hint)=}",
            )

    match splitter:
        case "cv":
            splitter = _default_cv_resampler(
                task_type,
                n_splits=n_splits,
                random_state=random_state,
            )

        case "holdout":
            splitter = _default_holdout(
                task_type,
                holdout_size=holdout_size,
                random_state=random_state,
            )
        case _:
            splitter = splitter  # noqa: PLW0127

    # This whole block is to check whether we should resample for stratified
    # sampling, in the case of a low minority class.
    if (
        isinstance(splitter, StratifiedKFold)
        and rebalance_if_required_for_stratified_splitting is not False
        and task_type in ("binary", "multiclass")
    ):
        if rebalance_if_required_for_stratified_splitting is None:
            _warning = (
                f"Labels have fewer than `{n_splits=}` instances. Resampling data"
                " to ensure it's possible to have one of each label in each fold."
                " Note that this may cause things to crash if you've provided extra"
                " `params` as the `X` data will have gotten slightly larger. Please"
                " set `rebalance_if_required_for_stratified_splitting=False` if you"
                " do not wish this to be enabled automatically, in which case, you"
                " may either perform resampling yourself or choose a smaller"
                " `n_splits=`."
            )
        else:
            _warning = None

        x_is_frame = isinstance(X, pd.DataFrame)
        y_is_frame = isinstance(y, pd.Series | pd.DataFrame)

        X, y = resample_if_minority_class_too_few_for_n_splits(  # type: ignore
            X if x_is_frame else pd.DataFrame(X),
            y if y_is_frame else pd.Series(y),  # type: ignore
            n_splits=n_splits,
            seed=random_state,
            _warning_if_occurs=_warning,
        )

        if not x_is_frame:
            X = X.to_numpy()  # type: ignore
        if not y_is_frame:
            y = y.to_numpy()  # type: ignore

    self.task_type = task_type
    self.additional_scorers = additional_scorers
    self.bucket = bucket
    self.splitter = splitter
    self.params = dict(params) if params is not None else {}
    self.store_models = store_models
    self.train_score = train_score

    self.X_stored = self.bucket[self._X_FILENAME].put(X)
    self.y_stored = self.bucket[self._Y_FILENAME].put(y)

    self.X_test_stored = None
    self.y_test_stored = None
    if X_test is not None and y_test is not None:
        self.X_test_stored = self.bucket[self._X_TEST_FILENAME].put(X_test)
        self.y_test_stored = self.bucket[self._Y_TEST_FILENAME].put(y_test)

    # We apply a heuristic that "large" parameters, such as sample_weights
    # should be stored to disk as transferring them directly to subprocess as
    # parameters is quite expensive (they must be non-optimally pickled and
    # streamed to the receiving process). By saving it to a file, we can
    # make use of things like numpy/pandas specific efficient pickling
    # protocols and also avoid the need to stream it to the subprocess.
    storable_params = {
        k: v
        for k, v in self.params.items()
        if hasattr(v, "__len__") and len(v) > self.LARGE_PARAM_HEURISTIC  # type: ignore
    }
    for k, v in storable_params.items():
        match subclass_map(v, self.PARAM_EXTENSION_MAPPING, default=None):  # type: ignore
            case (_, extension_to_save_as):
                ext = extension_to_save_as
            case _:
                ext = "pkl"

        self.params[k] = self.bucket[f"{k}.{ext}"].put(v)

    # This is the actual function that will be called in the task
    self.fn = partial(
        cross_validate_task,
        X=self.X_stored,
        y=self.y_stored,
        X_test=self.X_test_stored,
        y_test=self.y_test_stored,
        splitter=self.splitter,
        additional_scorers=self.additional_scorers,
        params=self.params,
        store_models=self.store_models,
        train_score=self.train_score,
        on_error=on_error,
        post_split=post_split,
        post_processing=post_processing,
        post_processing_requires_models=post_processing_requires_models,
    )

LARGE_PARAM_HEURISTIC class-attribute #

LARGE_PARAM_HEURISTIC: int = 100

Any item in params= whose length exceeds this value will be stored to disk when sent to the worker.

When launching tasks, pickling and streaming large data to tasks can be expensive. If an object is larger than this heuristic, it is stored to disk and given to the task as a Stored object instead.

Please feel free to overwrite this class variable as needed.

PARAM_EXTENSION_MAPPING class-attribute #

PARAM_EXTENSION_MAPPING: dict[type[Sized], str] = {
    ndarray: "npy",
    DataFrame: "pdpickle",
    Series: "pdpickle",
}

The mapping from types to extensions in params.

If the parameter is an instance of one of these types, and is larger than LARGE_PARAM_HEURISTIC, then it will be stored to disk and loaded back up in the task.

Please feel free to overwrite this class variable as needed.
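A minimal sketch of overriding both class variables (the threshold is arbitrary, and "pkl" is already the generic fallback extension, so the mapping entry only illustrates the mechanism):

from amltk.sklearn import CVEvaluation

# Only offload parameters with more than 10_000 elements to disk
CVEvaluation.LARGE_PARAM_HEURISTIC = 10_000

# Also offload large plain Python lists, stored with the generic pickle extension
CVEvaluation.PARAM_EXTENSION_MAPPING[list] = "pkl"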

SPLIT_EVALUATED class-attribute instance-attribute #

SPLIT_EVALUATED: Event[
    [Trial, SplitScores], bool | Exception
] = Event("split-evaluated")

Event that is emitted when a split has been evaluated.

Only emitted if the evaluator plugin is being used.

TMP_DIR_PREFIX class-attribute #

TMP_DIR_PREFIX: str = 'amltk-sklearn-cv-evaluation-data-'

Prefix for temporary directory names.

This is only used when working_dir is not specified, in which case you can control the tmp dir location by setting the TMPDIR environment variable. By default this is /tmp.

When using a temporary directory, it will be deleted by default, controlled by the delete_working_dir= argument.

X_stored instance-attribute #

X_stored: Stored[XLike] = put(X)

The stored features.

You can call .load() to load the data.
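A minimal sketch of reading the stored data back:

from pathlib import Path

from sklearn.datasets import load_iris
from amltk.sklearn import CVEvaluation

X, y = load_iris(return_X_y=True)
evaluator = CVEvaluation(X, y, working_dir=Path("./some-path"))

# The data was written to the working directory; load() reads it back
X_loaded = evaluator.X_stored.load()
y_loaded = evaluator.y_stored.load()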

additional_scorers instance-attribute #

additional_scorers: Mapping[str, _Scorer] | None = (
    additional_scorers
)

Additional scorers that will be used.

bucket instance-attribute #

bucket: PathBucket = bucket

The bucket to use for storing data.

For cleanup, you can call bucket.rmdir().

params instance-attribute #

params: Mapping[str, Any | Stored[Any]] = (
    dict(params) if params is not None else {}
)

Parameters to pass to the estimator, splitter or scorers.

Please see scikit-learn.org/stable/metadata_routing.html for more.

splitter instance-attribute #

splitter: BaseShuffleSplit | BaseCrossValidator = splitter

The splitter that will be used.

store_models instance-attribute #

store_models: bool = store_models

Whether models will be stored in the trial.

task_type instance-attribute #

task_type: TaskTypeName = task_type

The inferred task type.

train_score instance-attribute #

train_score: bool = train_score

Whether scores will be calculated on the training data as well.

y_stored instance-attribute #

y_stored: Stored[YLike] = put(y)

The stored target.

You can call .load() to load the data.

CompleteEvalInfo dataclass #

CompleteEvalInfo(
    X: XLike,
    y: YLike,
    X_test: XLike | None,
    y_test: YLike | None,
    splitter: BaseShuffleSplit | BaseCrossValidator,
    max_splits: int,
    scores: SplitScores,
    scorers: dict[str, _Scorer],
    models: list[BaseEstimator] | None,
    splitter_params: Mapping[str, Any],
    fit_params: Mapping[str, Any],
    scorer_params: Mapping[str, Any],
    test_scorer_params: Mapping[str, Any],
)

Information about the final evaluation of a cross-validation task.

This class contains information about the final evaluation of a cross-validation that will be passed to the post-processing function.

X instance-attribute #
X: XLike

The features used for training.

X_test instance-attribute #
X_test: XLike | None

The features used for testing.

fit_params instance-attribute #
fit_params: Mapping[str, Any]

The parameters that were used for fitting the estimator.

Please use select_params() if you need to select the params specific to a split, i.e. for sample_weights.

max_splits instance-attribute #
max_splits: int

The maximum number of splits that were (or could have been) evaluated.

models instance-attribute #
models: list[BaseEstimator] | None

The models that were trained in each split.

This will be None if post_processing_requires_models=False.

scorer_params instance-attribute #
scorer_params: Mapping[str, Any]

The parameters that were used for scoring the estimator.

Please use select_params() if you need to select the params specific to a split, i.e. for sample_weights.

scorers instance-attribute #
scorers: dict[str, _Scorer]

The scorers that were used.

scores instance-attribute #
scores: SplitScores

The scores for the splits that were evaluated.

splitter instance-attribute #
splitter: BaseShuffleSplit | BaseCrossValidator

The splitter that was used.

splitter_params instance-attribute #
splitter_params: Mapping[str, Any]

The parameters that were used for the splitter.

test_scorer_params instance-attribute #
test_scorer_params: Mapping[str, Any]

The parameters that were used for scoring the test data.

Please use select_params() if you need to select the params specific to a split, i.e. for sample_weights.

y instance-attribute #
y: YLike

The targets used for training.

y_test instance-attribute #
y_test: YLike | None

The targets used for testing.

select_params #
select_params(
    params: Mapping[str, Any], indices: ndarray
) -> dict[str, Any]

Convenience method to select parameters for a specific split.
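For example, a post_processing callback could regenerate the fold indices and slice routed parameters (such as a sample_weight array) down to a single split; a hedged sketch (the summary key is illustrative):

from amltk.optimization import Trial
from amltk.pipeline import Node
from amltk.sklearn import CVEvaluation

def my_post_processing(
    report: Trial.Report,
    pipeline: Node,
    info: CVEvaluation.CompleteEvalInfo,
) -> Trial.Report:
    # Re-create the first split and select the fit params belonging to its training rows
    train_idx, _val_idx = next(iter(info.splitter.split(info.X, info.y)))
    first_fold_fit_params = info.select_params(info.fit_params, train_idx)
    report.summary["first_fold_fit_param_keys"] = list(first_fold_fit_params)
    return report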

Source code in src/amltk/sklearn/evaluation.py
def select_params(
    self,
    params: Mapping[str, Any],
    indices: np.ndarray,
) -> dict[str, Any]:
    """Convinience method to select parameters for a specific split."""
    return _check_method_params(self.X, params, indices=indices)

PostSplitInfo #

Bases: NamedTuple

Information about the evaluation of a split.

ATTRIBUTE DESCRIPTION
X

The features used for training.

TYPE: XLike

y

The targets used for training.

TYPE: YLike

X_test

The features used for testing if it was passed in.

TYPE: XLike | None

y_test

The targets used for testing if it was passed in.

TYPE: YLike | None

i_train

The train indices for this split.

TYPE: ndarray

i_val

The validation indices for this split.

TYPE: ndarray

model

The model that was trained in this split.

TYPE: BaseEstimator

train_scores

The training scores for this split if requested.

TYPE: Mapping[str, float] | None

val_scores

The validation scores for this split.

TYPE: Mapping[str, float]

test_scores

The test scores for this split if requested.

TYPE: Mapping[str, float] | None

fitting_params

Any additional fitting parameters that were used.

TYPE: Mapping[str, Any]

train_scorer_params

Any additional scorer parameters used for evaluating scorers on the training set.

TYPE: Mapping[str, Any]

val_scorer_params

Any additional scorer parameters used for evaluating scorers on the validation set.

TYPE: Mapping[str, Any]

test_scorer_params

Any additional scorer parameters used for evaluating scorers on the test set.

TYPE: Mapping[str, Any]
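As a hedged sketch, a post_split callback could use these fields to record per-split information in the trial (the "accuracy" key is an assumption about which metric is being evaluated):

from amltk.optimization import Trial
from amltk.sklearn import CVEvaluation

def log_split(
    trial: Trial,
    split_number: int,
    info: CVEvaluation.PostSplitInfo,
) -> CVEvaluation.PostSplitInfo:
    # Record the validation accuracy and fold size in the trial summary
    trial.summary[f"split_{split_number}:my_val_accuracy"] = info.val_scores["accuracy"]
    trial.summary[f"split_{split_number}:n_val"] = len(info.i_val)
    return info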

SplitScores #

Bases: NamedTuple

The scores for a split.

ATTRIBUTE DESCRIPTION
val

The validation scores for all evaluated splits.

TYPE: Mapping[str, list[float]]

train

The training scores for all evaluated splits if requested.

TYPE: Mapping[str, list[float]] | None

test

The test scores for all evaluated splits if requested.

TYPE: Mapping[str, list[float]] | None

cv_early_stopping_plugin #

cv_early_stopping_plugin(
    strategy: CVEarlyStoppingProtocol | None = None,
    *,
    create_comms: (
        Callable[[], tuple[Comm, Comm]] | None
    ) = None
) -> _CVEarlyStoppingPlugin

Create a plugin for a task to allow for early stopping.

from dataclasses import dataclass
from pathlib import Path

import sklearn.datasets
from sklearn.tree import DecisionTreeClassifier

from amltk.sklearn import CVEvaluation
from amltk.pipeline import Component
from amltk.optimization import Metric, Trial

working_dir = Path("./some-path")
pipeline = Component(DecisionTreeClassifier, space={"max_depth": (1, 10)})
x, y = sklearn.datasets.load_iris(return_X_y=True)
evaluator = CVEvaluation(x, y, n_splits=3, working_dir=working_dir)

# Our early stopping strategy, with `update()` and `should_stop()`
# signatures matching what's expected.

@dataclass
class CVEarlyStopper:
    def update(self, report: Trial.Report) -> None:
        # Normally you would update w.r.t. a finished trial, such
        # as updating a moving average of the scores.
        pass

    def should_stop(self, trial: Trial, scores: CVEvaluation.SplitScores) -> bool | Exception:
        # Return True to stop, False to continue. Alternatively, return a
        # specific exception to attach to the report instead
        return True

history = pipeline.optimize(
    target=evaluator.fn,
    metric=Metric("accuracy", minimize=False, bounds=(0, 1)),
    max_trials=1,
    working_dir=working_dir,

    # Here we insert the plugin to the task that will get created
    plugins=[evaluator.cv_early_stopping_plugin(strategy=CVEarlyStopper())],

    # Notably, we set `on_trial_exception="continue"` to not stop as
    # we expect trials to fail given the early stopping strategy
    on_trial_exception="continue",
)
╭──── Report(config_id=1_seed=1509460901_budget=None_instance=None) - fail ────╮
│ Status(fail)                                                                 │
│ MetricCollection(                                                            │
│     metrics={                                                                │
│         'accuracy': Metric(                                                  │
│             name='accuracy',                                                 │
│             minimize=False,                                                  │
│             bounds=(0.0, 1.0),                                               │
│             fn=None                                                          │
│         )                                                                    │
│     }                                                                        │
│ )                                                                            │
│ ╭─ Metrics ────────────────────────────────────────────────────────────────╮ │
│ │ MetricCollection(                                                        │ │
│ │     metrics={                                                            │ │
│ │         'accuracy': Metric(                                              │ │
│ │             name='accuracy',                                             │ │
│ │             minimize=False,                                              │ │
│ │             bounds=(0.0, 1.0),                                           │ │
│ │             fn=None                                                      │ │
│ │         )                                                                │ │
│ │     }                                                                    │ │
│ │ )                                                                        │ │
│ ╰──────────────────────────────────────────────────────────────────────────╯ │
│ config             {'DecisionTreeClassifier:max_depth': 7}                   │
│ seed               1509460901                                                │
│ bucket             PathBucket(PosixPath('some-path/config_id=1_seed=1509460… │
│ summary            {'split_0:val_accuracy': 0.94}                            │
│ storage            {'exception.txt'}                                         │
│ profile:cv         Interval(                                                 │
│                        memory=Interval(                                      │
│                            start_vms=1229381632.0,                           │
│                            start_rss=244924416.0,                            │
│                            end_vms=1229381632,                               │
│                            end_rss=247676928,                                │
│                            unit=bytes                                        │
│                        ),                                                    │
│                        time=Interval(                                        │
│                            start=1723534477.633168,                          │
│                            end=1723534477.6756542,                           │
│                            kind=wall,                                        │
│                            unit=seconds                                      │
│                        )                                                     │
│                    )                                                         │
│ profile:cv:fit     Interval(                                                 │
│                        memory=Interval(                                      │
│                            start_vms=1229381632.0,                           │
│                            start_rss=245841920.0,                            │
│                            end_vms=1229381632,                               │
│                            end_rss=247545856,                                │
│                            unit=bytes                                        │
│                        ),                                                    │
│                        time=Interval(                                        │
│                            start=1723534477.645106,                          │
│                            end=1723534477.649212,                            │
│                            kind=wall,                                        │
│                            unit=seconds                                      │
│                        )                                                     │
│                    )                                                         │
│ profile:cv:score   Interval(                                                 │
│                        memory=Interval(                                      │
│                            start_vms=1229381632.0,                           │
│                            start_rss=247545856.0,                            │
│                            end_vms=1229381632,                               │
│                            end_rss=247676928,                                │
│                            unit=bytes                                        │
│                        ),                                                    │
│                        time=Interval(                                        │
│                            start=1723534477.6497915,                         │
│                            end=1723534477.651904,                            │
│                            kind=wall,                                        │
│                            unit=seconds                                      │
│                        )                                                     │
│                    )                                                         │
│ profile:cv:split_0 Interval(                                                 │
│                        memory=Interval(                                      │
│                            start_vms=1229381632.0,                           │
│                            start_rss=247676928.0,                            │
│                            end_vms=1229381632,                               │
│                            end_rss=247676928,                                │
│                            unit=bytes                                        │
│                        ),                                                    │
│                        time=Interval(                                        │
│                            start=1723534477.6522815,                         │
│                            end=1723534477.675374,                            │
│                            kind=wall,                                        │
│                            unit=seconds                                      │
│                        )                                                     │
│                    )                                                         │
╰──────────────────────────────────────────────────────────────────────────────╯

Recommended settings for CVEvaluation

When a trial is early stopped, it will be counted as a failed trial. This can conflict with the behaviour of pipeline.optimize, which by default sets on_trial_exception="raise", causing the optimization to end. If using pipeline.optimize, set on_trial_exception="continue" to continue optimization.

This will also add a new event to the task, which you can subscribe to with task.on("split-evaluated"). It will be passed the Trial and a CVEvaluation.SplitScores that you can use to make a decision on whether to continue or stop. The passed-in strategy= simply sets up listening to these events for you. You can also do this manually.

scores = []
evaluator = CVEvaluation(...)
task = scheduler.task(
    evaluator.fn,
    plugins=[evaluator.cv_early_stopping_plugin()]
)

@task.on("split-evaluated")
def should_stop(trial: Trial, split_scores: CVEvaluation.SplitScores) -> bool | Exception:
    # Stop if this trial's mean validation accuracy is below the mean of past trials
    return np.mean(split_scores.val["accuracy"]) < np.mean(scores)

@task.on("result")
def update_scores(_, report: Trial.Report) -> None:
    if report.status is Trial.Status.SUCCESS:
        scores.append(report.values["accuracy"])
PARAMETER DESCRIPTION
strategy

The strategy to use for early stopping. Must implement the update() and should_stop() methods of CVEarlyStoppingProtocol. Please follow the documentation link to find out more.

By default, when no strategy= is passed, this is None and a Comm object will be created, allowing communication between the worker running the task and the main process. This adds a new event to the task that you can subscribe to with task.on("split-evaluated"). This is how a passed-in strategy will be called and updated.

TYPE: CVEarlyStoppingProtocol | None DEFAULT: None

create_comms

A function that creates a pair of comms for the plugin to use. This is useful if you want to create a custom communication channel. If not provided, the default communication channel will be used.

Default communication channel

By default we use a simple multiprocessing.Pipe, which works for parallel processes from a ProcessPoolExecutor. This may not work if the task is being executed on a different filesystem, or depending on the executor which executes the task.

TYPE: Callable[[], tuple[Comm, Comm]] | None DEFAULT: None

RETURNS DESCRIPTION
_CVEarlyStoppingPlugin

The plugin to use for the task.

Source code in src/amltk/sklearn/evaluation.py
def cv_early_stopping_plugin(
    self,
    strategy: CVEarlyStoppingProtocol
    | None = None,  # TODO: Can provide some defaults...
    *,
    create_comms: Callable[[], tuple[Comm, Comm]] | None = None,
) -> CVEvaluation._CVEarlyStoppingPlugin:
    """Create a plugin for a task allow for early stopping.

    ```python exec="true" source="material-block" result="python" html="true"
    from dataclasses import dataclass
    from pathlib import Path

    import sklearn.datasets
    from sklearn.tree import DecisionTreeClassifier

    from amltk.sklearn import CVEvaluation
    from amltk.pipeline import Component
    from amltk.optimization import Metric, Trial

    working_dir = Path("./some-path")
    pipeline = Component(DecisionTreeClassifier, space={"max_depth": (1, 10)})
    x, y = sklearn.datasets.load_iris(return_X_y=True)
    evaluator = CVEvaluation(x, y, n_splits=3, working_dir=working_dir)

    # Our early stopping strategy, with an `update()` and `should_stop()`
    # signature match what's expected.

    @dataclass
    class CVEarlyStopper:
        def update(self, report: Trial.Report) -> None:
            # Normally you would update w.r.t. a finished trial, such
            # as updating a moving average of the scores.
            pass

        def should_stop(self, trial: Trial, scores: CVEvaluation.SplitScores) -> bool | Exception:
            # Return True to stop, False to continue. Alternatively, return a
            # specific exception to attach to the report instead
            return True

    history = pipeline.optimize(
        target=evaluator.fn,
        metric=Metric("accuracy", minimize=False, bounds=(0, 1)),
        max_trials=1,
        working_dir=working_dir,

        # Here we insert the plugin to the task that will get created
        plugins=[evaluator.cv_early_stopping_plugin(strategy=CVEarlyStopper())],

        # Notably, we set `on_trial_exception="continue"` to not stop as
        # we expect trials to fail given the early stopping strategy
        on_trial_exception="continue",
    )
    from amltk._doc import doc_print; doc_print(print, history[0])  # markdown-exec: hide
    evaluator.bucket.rmdir()  # markdown-exec: hide
    ```

    !!! warning "Recommended settings for `CVEvaluation`"

        When a trial is early stopped, it will be counted as a failed trial.
        This can conflict with the behaviour of `pipeline.optimize` which
        by default sets `on_trial_exception="raise"`, causing the optimization
        to end. If using [`pipeline.optimize`][amltk.pipeline.Node.optimize],
        make sure to set `on_trial_exception="continue"` to continue optimization.

    This will also add a new event to the task which you can subscribe to with
    [`task.on("split-evaluated")`][amltk.sklearn.evaluation.CVEvaluation.SPLIT_EVALUATED].
    It will be passed a
    [`CVEvaluation.PostSplitInfo`][amltk.sklearn.evaluation.CVEvaluation.PostSplitInfo]
    that you can use to make a decision on whether to continue or stop. The
    passed in `strategy=` simply sets up listening to these events for you.
    You can also do this manually.

    ```python
    import numpy as np

    scores = []  # Accuracy of previously completed trials
    evaluator = CVEvaluation(...)
    task = scheduler.task(
        evaluator.fn,
        plugins=[evaluator.cv_early_stopping_plugin()]
    )

    @task.on("split-evaluated")
    def should_stop(trial: Trial, split_scores: CVEvaluation.SplitScores) -> bool | Exception:
        # Stop if the mean validation accuracy over the splits evaluated so
        # far falls below the mean accuracy of previously completed trials.
        if not scores:
            return False
        return np.mean(split_scores.val["accuracy"]) < np.mean(scores)

    @task.on("result")
    def update_scores(_, report: Trial.Report) -> None:
        if report.status is Trial.Status.SUCCESS:
            scores.append(report.values["accuracy"])
    ```

    Args:
        strategy: The strategy to use for early stopping. Must implement the
            `update()` and `should_stop()` methods of
            [`CVEarlyStoppingProtocol`][amltk.sklearn.evaluation.CVEarlyStoppingProtocol].
            Please follow the documentation link to find out more.

            By default, when no `strategy=` is passed, this is `None` and
            this will create a [`Comm`][amltk.scheduling.plugins.comm.Comm] object,
            allowing communication between the worker running the task and the main
            process. This adds a new event to the task that you can subscribe
            to with
            [`task.on("split-evaluated")`][amltk.sklearn.evaluation.CVEvaluation.SPLIT_EVALUATED].
            This is how a passed-in strategy will be called and updated.
        create_comms: A function that creates a pair of comms for the
            plugin to use. This is useful if you want to create a
            custom communication channel. If not provided, the default
            communication channel will be used.

            !!! note "Default communication channel"

                By default we use a simple `multiprocessing.Pipe` which works
                for parallel processes from
                [`ProcessPoolExecutor`][concurrent.futures.ProcessPoolExecutor].
                This may not work if the task is executed on a different
                filesystem, or depending on the executor that runs the task.

    Returns:
        The plugin to use for the task.
    """  # noqa: E501
    return CVEvaluation._CVEarlyStoppingPlugin(
        self,
        strategy=strategy,
        create_comms=create_comms,
    )

identify_task_type #

identify_task_type(
    y: YLike,
    *,
    task_hint: Literal[
        "classification", "regression", "auto"
    ] = "auto"
) -> TaskTypeName

Identify the task type from the target data.
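
As a quick usage sketch (illustrative only; assumes `identify_task_type` is imported from this module):

```python
import numpy as np
from amltk.sklearn.evaluation import identify_task_type

y = np.array([0, 1, 1, 2, 0, 2])

# Passing an explicit `task_hint` avoids the automatic-inference warning.
print(identify_task_type(y, task_hint="classification"))  # "multiclass"
```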

Source code in src/amltk/sklearn/evaluation.py
def identify_task_type(  # noqa: PLR0911
    y: YLike,
    *,
    task_hint: Literal["classification", "regression", "auto"] = "auto",
) -> TaskTypeName:
    """Identify the task type from the target data."""
    inferred_type: TaskTypeName = type_of_target(y)
    if task_hint == "auto":
        warnings.warn(
            f"`{task_hint=}` was not provided. The task type was inferred from"
            f" the target data to be '{inferred_type}'."
            " To silence this warning, please provide `task_hint`.",
            AutomaticTaskTypeInferredWarning,
            stacklevel=2,
        )
        return inferred_type

    match task_hint, inferred_type:
        # First two cases are everything is fine
        case (
            "classification",
            "binary"
            | "multiclass"
            | "multilabel-indicator"
            | "multiclass-multioutput",
        ):
            return inferred_type
        case ("regression", "continuous" | "continuous-multioutput"):
            return inferred_type
        # Hinted to be regression but we got a single column classification task
        case ("regression", "binary" | "multiclass"):
            warnings.warn(
                f"`{task_hint=}` but `{inferred_type=}`."
                " Set to `continuous` as there is only one target column.",
                MismatchedTaskTypeWarning,
                stacklevel=2,
            )
            return "continuous"
        # Hinted to be regression but we got multi-column classification task
        case ("regression", "multilabel-indicator" | "multiclass-multioutput"):
            warnings.warn(
                f"`{task_hint=}` but `{inferred_type=}`."
                " Set to `continuous-multiouput` as there are more than 1 target"
                " columns.",
                MismatchedTaskTypeWarning,
                stacklevel=2,
            )
            return "continuous"
        # Hinted to be classification but we got a single column regression task
        case ("classification", "continuous"):
            match len(np.unique(y)):
                case 1:
                    raise ValueError(
                        "The target data has only one unique value. This is"
                        f" not a valid classification task.\n{y=}",
                    )
                case 2:
                    warnings.warn(
                        f"`{task_hint=}` but `{inferred_type=}`."
                        " Set to `binary` as only 2 unique values."
                        " To silence this, provide a specific task type to"
                        f"`task_hint=` from {_valid_task_types}.",
                        MismatchedTaskTypeWarning,
                        stacklevel=2,
                    )
                    return "binary"
                case _:
                    warnings.warn(
                        f"`{task_hint=}` but `{inferred_type=}`."
                        " Set to `multiclass` as >2 unique values."
                        " To silence this, provide a specific task type to"
                        f"`task_hint=` from {_valid_task_types}.",
                        MismatchedTaskTypeWarning,
                        stacklevel=2,
                    )
                    return "multiclass"
        # Hinted to be classification but we got multi-column regression task
        case ("classification", "continuous-multioutput"):
            # NOTE: We check unique values per column here. It's not clear how
            # this should behave for multiclass-multioutput, i.e. whether the
            # decision should be made matrix-wide or per column.
            uniques_per_col = [np.unique(col) for col in y.T]
            binary_columns = all(len(col) <= 2 for col in uniques_per_col)  # noqa: PLR2004
            if binary_columns:
                warnings.warn(
                    f"`{task_hint=}` but `{inferred_type=}`."
                    " Set to `multilabel-indicator` as <=2 unique values per column."
                    " To silence this, provide a specific task type to"
                    f"`task_hint=` from {_valid_task_types}.",
                    MismatchedTaskTypeWarning,
                    stacklevel=2,
                )
                return "multilabel-indicator"
            else:  # noqa: RET505
                warnings.warn(
                    f"`{task_hint=}` but `{inferred_type=}`."
                    " Set to `multiclass-multioutput` as at least one column has"
                    " >2 unique values."
                    " To silence this, provide a specific task type to"
                    f"`task_hint=` from {_valid_task_types}.",
                    MismatchedTaskTypeWarning,
                    stacklevel=2,
                )
                return "multiclass-multioutput"
        case _:
            raise RuntimeError(
                f"Unreachable, please report this bug. {task_hint=}, {inferred_type=}",
            )

resample_if_minority_class_too_few_for_n_splits #

resample_if_minority_class_too_few_for_n_splits(
    X_train: DataFrame,
    y_train: Series,
    *,
    n_splits: int,
    seed: Seed | None = None,
    _warning_if_occurs: str | None = None
) -> tuple[DataFrame, DataFrame | Series]

Rebalance the training data to allow stratification.

If your data contains only a few instances of some class, for example 3 instances of a single class, and you wish to perform 5-fold cross-validation, you will need to rebalance the data to allow for stratification. This function takes the training data and labels and resamples the data to allow for stratification.
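
As a rough usage sketch (hypothetical data; assumes the function is imported from this module):

```python
import pandas as pd
from amltk.sklearn.evaluation import resample_if_minority_class_too_few_for_n_splits

X = pd.DataFrame({"feature": range(8)})
y = pd.Series(["a"] * 6 + ["b"] * 2)  # class "b" has only 2 instances

# With n_splits=5, class "b" cannot be stratified into 5 folds, so its
# instances are resampled (with replacement if needed) until it has 5.
X_res, y_res = resample_if_minority_class_too_few_for_n_splits(
    X, y, n_splits=5, seed=42,
)
print(y_res.value_counts())  # "a": 6, "b": 5
```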

PARAMETER DESCRIPTION
X_train

The training data.

TYPE: DataFrame

y_train

The training labels.

TYPE: Series

n_splits

The number of splits to perform.

TYPE: int

seed

Used for deciding which instances to resample.

TYPE: Seed | None DEFAULT: None

RETURNS DESCRIPTION
tuple[DataFrame, DataFrame | Series]

The rebalanced training data and labels.

Source code in src/amltk/sklearn/evaluation.py
def resample_if_minority_class_too_few_for_n_splits(
    X_train: pd.DataFrame,  # noqa: N803
    y_train: pd.Series,
    *,
    n_splits: int,
    seed: Seed | None = None,
    _warning_if_occurs: str | None = None,
) -> tuple[pd.DataFrame, pd.DataFrame | pd.Series]:
    """Rebalance the training data to allow stratification.

    If your data contains only a few instances of some class, e.g. 3 instances of a
    single class, and you wish to perform 5-fold cross-validation, you will need to
    rebalance the data to allow for stratification. This function takes the training
    data and labels and resamples the data to allow for stratification.

    Args:
        X_train: The training data.
        y_train: The training labels.
        n_splits: The number of splits to perform.
        seed: Used for deciding which instances to resample.

    Returns:
        The rebalanced training data and labels.
    """
    if y_train.ndim != 1:
        raise NotImplementedError(
            "Rebalancing for multi-output classification is not yet supported.",
        )

    # If we are in a binary/multiclass setting and there are not enough instances
    # with a given label to perform stratified sampling with `n_splits`, we first
    # find these labels and then resample instances with those labels until we
    # reach `n_splits` instances for each label.
    indices_to_resample = None
    label_counts = y_train.value_counts()
    under_represented_labels = label_counts[label_counts < n_splits]  # type: ignore

    collected_indices = []
    if any(under_represented_labels):
        if _warning_if_occurs is not None:
            warnings.warn(_warning_if_occurs, UserWarning, stacklevel=2)
        under_rep_instances = y_train[y_train.isin(under_represented_labels.index)]  # type: ignore

        grouped_by_label = under_rep_instances.to_frame("label").groupby(  # type: ignore
            "label",
            observed=True,  # Handles categoricals
        )
        for _label, instances_with_label in grouped_by_label:
            n_to_take = n_splits - len(instances_with_label)

            need_to_sample_repeatedly = n_to_take > len(instances_with_label)
            resampled_instances = instances_with_label.sample(
                n=n_to_take,
                random_state=seed,  # type: ignore
                # It could be that we have to repeat sample if there are not enough
                # instances to hit `n_splits` for a given label.
                replace=need_to_sample_repeatedly,
            )
            collected_indices.append(np.asarray(resampled_instances.index))

        indices_to_resample = np.concatenate(collected_indices)

    if indices_to_resample is not None:
        # Give the new samples a new index to not overlap with the original data.
        new_start_idx = X_train.index.max() + 1  # type: ignore
        new_end_idx = new_start_idx + len(indices_to_resample)
        new_idx = pd.RangeIndex(start=new_start_idx, stop=new_end_idx)
        resampled_X = X_train.loc[indices_to_resample].set_index(new_idx)
        resampled_y = y_train.loc[indices_to_resample].set_axis(new_idx)
        X_train = pd.concat([X_train, resampled_X])
        y_train = pd.concat([y_train, resampled_y])

    return X_train, y_train