SAC update
mighty.mighty_update.sac_update
SACUpdate
SACUpdate(
model: SACModel,
policy_lr: float = 0.001,
q_lr: float = 0.001,
value_lr: float = 0.001,
tau: float = 0.005,
alpha: float = 0.2,
gamma: float = 0.99,
)
:param model: The SAC model containing the policy and Q-networks.
:param policy_lr: Learning rate for the policy network.
:param q_lr: Learning rate for the Q-networks.
:param value_lr: Learning rate for the value network.
:param tau: Soft update parameter for the target networks.
:param alpha: Entropy regularization coefficient.
:param gamma: Discount factor for future rewards.
Source code in mighty/mighty_update/sac_update.py
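The `tau` parameter controls a Polyak (soft) update of the target Q-networks. Below is a minimal sketch of that mechanism, assuming PyTorch modules; it is illustrative, not mighty's actual implementation:

```python
import torch

def soft_update(target: torch.nn.Module, source: torch.nn.Module, tau: float) -> None:
    # Polyak averaging: target <- tau * source + (1 - tau) * target.
    # A small tau (default 0.005) makes the target networks trail the
    # online networks slowly, which stabilizes the bootstrapped Q-targets.
    with torch.no_grad():
        for t_param, s_param in zip(target.parameters(), source.parameters()):
            t_param.mul_(1.0 - tau).add_(tau * s_param)
```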
calculate_td_error
calculate_td_error(transition: TransitionBatch) -> Tuple
Calculate the TD error for a given transition.
:param transition: Current transition.
:return: TD error.
Source code in mighty/mighty_update/sac_update.py
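For reference, the TD error in SAC is conventionally computed against an entropy-regularized soft target. The sketch below shows that standard computation; the tensor names are hypothetical and need not match mighty's internals:

```python
import torch

def sac_td_error(
    q_value: torch.Tensor,         # Q(s, a) from an online Q-network
    reward: torch.Tensor,
    done: torch.Tensor,            # 1.0 for terminal transitions, else 0.0
    q1_target_next: torch.Tensor,  # target-network Q'_1(s', a'), a' ~ pi(.|s')
    q2_target_next: torch.Tensor,  # target-network Q'_2(s', a')
    log_prob_next: torch.Tensor,   # log pi(a' | s')
    gamma: float = 0.99,
    alpha: float = 0.2,
) -> torch.Tensor:
    # Soft target: r + gamma * (min(Q'_1, Q'_2) - alpha * log pi) on non-terminal steps.
    soft_next = torch.min(q1_target_next, q2_target_next) - alpha * log_prob_next
    target = reward + gamma * (1.0 - done) * soft_next
    return target - q_value  # squared and averaged, this yields the Q-loss
```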
update
update(batch: TransitionBatch) -> Dict
Perform an update of the SAC model using a batch of experience.
:param batch: A batch of experience data.
:return: A dictionary of loss values for tracking.
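A minimal usage sketch of the update loop. The construction of `model` and `buffer`, the `buffer.sample` call, and the exact keys of the returned loss dictionary are assumptions for illustration; consult mighty's source for the real interface:

```python
from mighty.mighty_update.sac_update import SACUpdate

# `model` is a mighty SACModel and `buffer` a replay buffer yielding
# TransitionBatch objects; their setup is omitted here (hypothetical).
updater = SACUpdate(model, policy_lr=3e-4, q_lr=3e-4, tau=0.005, alpha=0.2, gamma=0.99)

for step in range(10_000):
    batch = buffer.sample(256)      # assumed sampling API
    losses = updater.update(batch)  # dict of loss values for tracking/logging
    if step % 1_000 == 0:
        print(step, losses)
```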