from __future__ import annotations
from typing import Any
import math
import numpy as np
from smac.constants import VERY_SMALL_NUMBER
from smac.model.gaussian_process.priors.abstract_prior import AbstractPrior
__copyright__ = "Copyright 2022, automl.org"
__license__ = "3-clause BSD"
[docs]
class HorseshoePrior(AbstractPrior):
    """Horseshoe prior as it is used in spearmint.

    The horseshoe density has no closed form. The log-probability implemented here is the spearmint
    approximation ``log(log(1 + 3 * (scale / theta) ** 2))``.

    Parameters
    ----------
    scale : float
        Scaling parameter.
    seed : int, defaults to 0
        Seed forwarded to the parent class constructor.
    """

    def __init__(self, scale: float, seed: int = 0):
        super().__init__(seed=seed)
        self._scale = scale
        # Cached because both the log-probability and its gradient use scale ** 2.
        self._scale_square = scale**2

    @property
    def meta(self) -> dict[str, Any]:  # noqa: D102
        # Extend the parent's meta information with the scale of this prior.
        meta = super().meta
        meta.update({"scale": self._scale})

        return meta

    def _sample_from_prior(self, n_samples: int) -> np.ndarray:
        """Draw ``n_samples`` non-negative samples from the prior.

        Parameters
        ----------
        n_samples : int
            Number of samples to draw.

        Returns
        -------
        samples : np.ndarray
            Absolute values of the product of a normal draw, half-Cauchy draws and the scale.
        """
        # This is copied from RoBO - scale is most likely the tau parameter
        lamda = np.abs(self._rng.standard_cauchy(size=n_samples))

        # NOTE(review): `randn()` is called without a size, so presumably one single normal draw is shared
        # by all samples (same as in RoBO/spearmint). Confirm this is intended before changing it, because
        # drawing one normal per sample would alter the seeded sample stream.
        p0 = np.abs(self._rng.randn() * lamda * self._scale)

        return p0

    def _get_log_probability(self, theta: float) -> float:
        """Return the (approximate) log-probability of ``theta``.

        Parameters
        ----------
        theta : float
            Hyperparameter value. According to the note below, the log space transformation has already
            been undone by the parent class.

        Returns
        -------
        log_probability : float
            ``log(log(1 + 3 * scale**2 / theta**2) + VERY_SMALL_NUMBER)``, or ``inf`` for ``theta == 0``.
        """
        # We computed it exactly as in the original spearmint code, they basically say that there's no analytical
        # form of the horseshoe prior, but that the multiplier is bounded between 2 and 4 and that they used the
        # middle.
        # See "The horseshoe estimator for sparse signals" by Carvalho, Polson and Scott (2010), Equation 1.
        # https://www.jstor.org/stable/25734098
        # Compared to the paper by Carvalho, there's a constant multiplier missing.
        # Compared to Spearmint we first have to undo the log space transformation of the theta.
        # Note: "undo log space transformation" is done in parent class
        if theta == 0:
            return np.inf  # POSITIVE infinity (this is the "spike")

        # log1p keeps precision when the ratio is tiny; log(1 + x) would round to exactly 0 there.
        a = math.log1p(3.0 * (self._scale_square / theta**2))

        # The small offset keeps the outer logarithm finite when `a` underflows to 0.
        return math.log(a + VERY_SMALL_NUMBER)

    def _get_gradient(self, theta: float) -> float:
        """Return the gradient of the log-probability at ``theta``.

        The returned value is ``-6 * scale**2 / ((3 * scale**2 + theta**2) * log(1 + 3 * scale**2 / theta**2))``.
        Mathematically this equals ``theta * d/dtheta log(log(1 + 3 * scale**2 / theta**2))``, i.e. the
        derivative with respect to ``log(theta)`` (the VERY_SMALL_NUMBER offset of the log-probability is
        ignored).

        Parameters
        ----------
        theta : float
            Hyperparameter value (not in log space).

        Returns
        -------
        gradient : float
            The gradient described above, or ``inf`` for ``theta == 0``.
        """
        if theta == 0:
            return np.inf  # POSITIVE infinity (this is the "spike")

        a = -(6 * self._scale_square)
        b = 3 * self._scale_square + theta**2

        # log1p instead of log(1 + x): for theta much larger than scale, 1 + x rounds to 1, the logarithm
        # becomes 0, `b` is clamped below and the gradient explodes to about -6e14 * scale**2 although the
        # analytic value approaches -2.
        b *= math.log1p(3 * self._scale_square * theta ** (-2))

        # Guard against division by zero if the product still underflows.
        b = max(b, 1e-14)

        return a / b