Q-learning

mighty.mighty_update.q_learning #

Q-learning update.

ClippedDoubleQLearning #

ClippedDoubleQLearning(
    model, gamma, optimizer=Adam, **optimizer_kwargs
)

Bases: QLearning

Clipped Double Q-learning update.

Source code in mighty/mighty_update/q_learning.py
def __init__(
    self, model, gamma, optimizer=torch.optim.Adam, **optimizer_kwargs
) -> None:
    """Initialize the Clipped Double Q-learning update."""
    super().__init__(model, gamma, optimizer, **optimizer_kwargs)
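
The clipped variant follows the Clipped Double Q-learning idea: the bootstrapped value in the target is taken as the element-wise minimum over two value estimates, which damps the overestimation bias of a single max-operator target. The class's actual target computation is not reproduced on this page; the snippet below is only an illustration, with q1_next and q2_next standing in for two hypothetical target estimates.

import torch

def clipped_double_q_target(rewards, dones, gamma, q1_next, q2_next):
    """Illustrative clipped target, not mighty's actual implementation.

    q1_next, q2_next: (batch, n_actions) Q-values from two target estimates.
    rewards, dones:   (batch,) reward tensor and boolean done mask.
    """
    # Element-wise minimum of the two estimates before the max over actions.
    max_next = torch.min(q1_next, q2_next).max(dim=1, keepdim=True).values
    return rewards.unsqueeze(-1) + (~dones.unsqueeze(-1)) * gamma * max_next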

apply_update #

apply_update(preds, targets)

Apply the Q-learning update.

Source code in mighty/mighty_update/q_learning.py
def apply_update(self, preds, targets):
    """Apply the Q-learning update."""
    self.optimizer.zero_grad()
    loss = F.mse_loss(preds, targets)
    loss.backward()
    self.optimizer.step()
    return {"Update/loss": loss.detach().numpy().item()}

td_error #

td_error(batch, q_net, target_net=None)

Compute the TD error for the Q-learning update.

Source code in mighty/mighty_update/q_learning.py
def td_error(self, batch, q_net, target_net=None):
    """Compute the TD error for the Q-learning update."""
    preds, targets = self.get_targets(batch, q_net, target_net)
    return F.mse_loss(preds, targets, reduction="none").detach().mean(axis=1)

DoubleQLearning #

DoubleQLearning(
    model, gamma, optimizer=Adam, **optimizer_kwargs
)

Bases: QLearning

Double Q-learning update.

Source code in mighty/mighty_update/q_learning.py
def __init__(
    self, model, gamma, optimizer=torch.optim.Adam, **optimizer_kwargs
) -> None:
    """Initialize the Double Q-learning update."""
    super().__init__(model, gamma, optimizer, **optimizer_kwargs)
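
In Double Q-learning the action for the bootstrap is selected with the online network but evaluated with the target network, decoupling action selection from value evaluation. The class's own target computation is not shown on this page; the sketch below only illustrates the idea, using the same batch fields as get_targets.

import torch

def double_q_target(batch, q_net, target_net, gamma):
    """Illustrative Double Q-learning target, not mighty's actual code."""
    next_obs = torch.as_tensor(batch.next_obs, dtype=torch.float32)
    # The online network selects the greedy next action ...
    next_actions = q_net(next_obs).argmax(dim=1, keepdim=True)
    # ... while the target network evaluates that action.
    next_values = target_net(next_obs).gather(1, next_actions)
    return batch.rewards.unsqueeze(-1) + (~batch.dones.unsqueeze(-1)) * gamma * next_values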

apply_update #

apply_update(preds, targets)

Apply the Q-learning update.

Source code in mighty/mighty_update/q_learning.py
def apply_update(self, preds, targets):
    """Apply the Q-learning update."""
    self.optimizer.zero_grad()
    loss = F.mse_loss(preds, targets)
    loss.backward()
    self.optimizer.step()
    return {"Update/loss": loss.detach().numpy().item()}

td_error #

td_error(batch, q_net, target_net=None)

Compute the TD error for the Q-learning update.

Source code in mighty/mighty_update/q_learning.py
def td_error(self, batch, q_net, target_net=None):
    """Compute the TD error for the Q-learning update."""
    preds, targets = self.get_targets(batch, q_net, target_net)
    return F.mse_loss(preds, targets, reduction="none").detach().mean(axis=1)

QLearning #

QLearning(model, gamma, optimizer=Adam, **optimizer_kwargs)

Q-learning update.

Source code in mighty/mighty_update/q_learning.py
def __init__(
    self, model, gamma, optimizer=torch.optim.Adam, **optimizer_kwargs
) -> None:
    """Initialize the Q-learning update."""
    self.gamma = gamma
    self.optimizer = optimizer(params=model.parameters(), **optimizer_kwargs)

apply_update #

apply_update(preds, targets)

Apply the Q-learning update.

Source code in mighty/mighty_update/q_learning.py
def apply_update(self, preds, targets):
    """Apply the Q-learning update."""
    self.optimizer.zero_grad()
    loss = F.mse_loss(preds, targets)
    loss.backward()
    self.optimizer.step()
    return {"Update/loss": loss.detach().numpy().item()}

get_targets #

get_targets(batch, q_net, target_net=None)

Get targets for the Q-learning update.

Source code in mighty/mighty_update/q_learning.py
def get_targets(self, batch, q_net, target_net=None):
    """Get targets for the Q-learning update."""
    if target_net is None:
        target_net = q_net
    max_next = (
        target_net(torch.as_tensor(batch.next_obs, dtype=torch.float32))
        .max(1)[0]
        .unsqueeze(1)
    )
    targets = (
        batch.rewards.unsqueeze(-1)
        + (~batch.dones.unsqueeze(-1)) * self.gamma * max_next
    )
    preds = q_net(torch.as_tensor(batch.observations, dtype=torch.float32)).gather(
        1, batch.actions.to(torch.int64).unsqueeze(-1)
    )
    return preds.to(torch.float32), targets.to(torch.float32)
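
In words: preds is Q(s, a) for the action stored in the batch, and the target is the one-step bootstrap

    targets = rewards + (1 - dones) * gamma * max_a' Q_target(next_obs, a')

where the (~dones) mask zeroes the bootstrap term on terminal transitions. As a concrete example, with gamma = 0.99, a reward of 1.0, a non-terminal transition, and max_a' Q_target(next_obs, a') = 5.0, the target is 1.0 + 0.99 * 5.0 = 5.95.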

td_error #

td_error(batch, q_net, target_net=None)

Compute the TD error for the Q-learning update.

Source code in mighty/mighty_update/q_learning.py
def td_error(self, batch, q_net, target_net=None):
    """Compute the TD error for the Q-learning update."""
    preds, targets = self.get_targets(batch, q_net, target_net)
    return F.mse_loss(preds, targets, reduction="none").detach().mean(axis=1)
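
A minimal usage sketch, assuming a feed-forward Q-network and a replay batch with the fields used by get_targets above (observations, actions, rewards, next_obs, dones); QNetwork, the buffer, and the lr keyword are illustrative placeholders, not part of this page.

from mighty.mighty_update.q_learning import QLearning

q_net = QNetwork(obs_dim=4, n_actions=2)         # hypothetical model exposing .parameters()
updater = QLearning(q_net, gamma=0.99, lr=1e-3)  # lr is forwarded to torch.optim.Adam

batch = buffer.sample(64)                        # hypothetical batch with the fields listed above
preds, targets = updater.get_targets(batch, q_net)
metrics = updater.apply_update(preds, targets)   # one gradient step on the MSE TD loss
print(metrics["Update/loss"])

Note that td_error(batch, q_net) returns one non-negative squared error per transition (shape (batch_size,)), which can serve, for example, as a priority signal for prioritized replay.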

SPRQLearning #

SPRQLearning(
    model,
    gamma,
    optimizer=Adam,
    spr_loss_weight=1,
    huber_delta=1,
    **optimizer_kwargs,
)

Bases: QLearning

SPR Q-learning update.

Source code in mighty/mighty_update/q_learning.py
def __init__(
    self,
    model,
    gamma,
    optimizer=torch.optim.Adam,
    spr_loss_weight=1,
    huber_delta=1,
    **optimizer_kwargs,
) -> None:
    """Initialize the SPR Q-learning update."""
    super().__init__(model, gamma, optimizer, **optimizer_kwargs)
    self.spr_loss_weight = spr_loss_weight
    self.huber_delta = huber_delta
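
SPR (Self-Predictive Representations) augments the temporal-difference objective with a self-supervised representation-prediction term: spr_loss_weight scales that auxiliary term and huber_delta sets the threshold of the Huber loss used for the TD part. The combined objective is not shown on this page; the sketch below only illustrates a typical combination, and the cosine-based consistency term is an assumption rather than mighty's exact formulation.

import torch.nn.functional as F

def spr_q_loss(preds, targets, online_latents, target_latents,
               spr_loss_weight=1.0, huber_delta=1.0):
    """Illustrative combined objective: Huber TD loss plus weighted SPR term."""
    td_loss = F.huber_loss(preds, targets, delta=huber_delta)
    # Negative cosine similarity between predicted and target latents, as in
    # the SPR paper; assumed here, not taken from mighty's source.
    spr_loss = -F.cosine_similarity(online_latents, target_latents, dim=-1).mean()
    return td_loss + spr_loss_weight * spr_loss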

td_error #

td_error(batch, q_net, target_net=None)

Compute the TD error for the Q-learning update.

Source code in mighty/mighty_update/q_learning.py
def td_error(self, batch, q_net, target_net=None):
    """Compute the TD error for the Q-learning update."""
    preds, targets = self.get_targets(batch, q_net, target_net)
    return F.mse_loss(preds, targets, reduction="none").detach().mean(axis=1)