def __init__(
self,
obs_size: int,
action_size: int,
log_std_min: float = -5,
log_std_max: float = 2,
action_low: float = -1,
action_high: float = +1,
**kwargs,
):
super().__init__()
self.obs_size = obs_size
self.action_size = action_size
self.log_std_min = log_std_min
self.log_std_max = log_std_max
# This model supports continuous action spaces only
self.continuous_action = True
# Register the per-dim scale and bias so we can rescale [-1,1]→[low,high].
action_low = torch.as_tensor(action_low, dtype=torch.float32)
action_high = torch.as_tensor(action_high, dtype=torch.float32)
self.register_buffer(
"action_scale", (action_high - action_low) / 2.0
)
self.register_buffer(
"action_bias", (action_high + action_low) / 2.0
)
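# With these buffers, a tanh-squashed sample u in [-1, 1] is expected to be
# rescaled to the environment range as: action = u * action_scale + action_bias.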
head_kwargs = {"hidden_sizes": [256, 256], "activation": "relu"}
feature_extractor_kwargs = {
"obs_shape": self.obs_size,
"activation": "relu",
"hidden_sizes": [256, 256],
"n_layers": 2,
}
# Allow direct specification of hidden_sizes and activation at top level
if "hidden_sizes" in kwargs:
feature_extractor_kwargs["hidden_sizes"] = kwargs["hidden_sizes"]
head_kwargs["hidden_sizes"] = kwargs["hidden_sizes"]
if "activation" in kwargs:
feature_extractor_kwargs["activation"] = kwargs["activation"]
head_kwargs["activation"] = kwargs["activation"]
if "head_kwargs" in kwargs:
head_kwargs.update(kwargs["head_kwargs"])
if "feature_extractor_kwargs" in kwargs:
feature_extractor_kwargs.update(kwargs["feature_extractor_kwargs"])
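# Note: explicit head_kwargs / feature_extractor_kwargs are applied last and
# therefore override the top-level hidden_sizes / activation shortcuts above.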
# Store for Q-network creation
self.hidden_sizes = feature_extractor_kwargs.get("hidden_sizes", [256, 256])
self.activation = feature_extractor_kwargs.get("activation", "relu")
# Policy feature extractor and head
self.policy_feature_extractor, policy_feat_dim = make_feature_extractor(
**feature_extractor_kwargs
)
# Policy head: just the final output layer
self.policy_head = make_policy_head(
in_size=policy_feat_dim,
out_size=self.action_size * 2, # mean and log_std
hidden_sizes=[], # No hidden layers, just final linear layer
activation=head_kwargs["activation"]
)
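# The 2 * action_size outputs are expected to be split downstream into a mean
# and a log_std per action dimension, with log_std clamped to
# [log_std_min, log_std_max] before exponentiation (standard SAC practice).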
# Create policy_net for backward compatibility
self.policy_net = nn.Sequential(self.policy_feature_extractor, self.policy_head)
# Q-networks: feature extractors + heads
q_feature_extractor_kwargs = feature_extractor_kwargs.copy()
q_feature_extractor_kwargs["obs_shape"] = self.obs_size + self.action_size
# Q-network 1
self.q_feature_extractor1, q_feat_dim = make_feature_extractor(**q_feature_extractor_kwargs)
self.q_head1 = make_q_head(
in_size=q_feat_dim,
hidden_sizes=[], # No hidden layers, just final linear layer
activation=head_kwargs["activation"]
)
self.q_net1 = nn.Sequential(self.q_feature_extractor1, self.q_head1)
# Q-network 2
self.q_feature_extractor2, _ = make_feature_extractor(**q_feature_extractor_kwargs)
self.q_head2 = make_q_head(
in_size=q_feat_dim,
hidden_sizes=[], # No hidden layers, just final linear layer
activation=head_kwargs["activation"]
)
self.q_net2 = nn.Sequential(self.q_feature_extractor2, self.q_head2)
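# Two independent Q-networks enable clipped double Q-learning: taking
# min(Q1, Q2) when forming targets curbs overestimation bias (as in SAC/TD3).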
# Target Q-networks
self.target_q_feature_extractor1, _ = make_feature_extractor(**q_feature_extractor_kwargs)
self.target_q_head1 = make_q_head(
in_size=q_feat_dim,
hidden_sizes=[], # No hidden layers, just final linear layer
activation=head_kwargs["activation"]
)
self.target_q_net1 = nn.Sequential(self.target_q_feature_extractor1, self.target_q_head1)
self.target_q_feature_extractor2, _ = make_feature_extractor(**q_feature_extractor_kwargs)
self.target_q_head2 = make_q_head(
in_size=q_feat_dim,
hidden_sizes=[], # No hidden layers, just final linear layer
activation=head_kwargs["activation"]
)
self.target_q_net2 = nn.Sequential(self.target_q_feature_extractor2, self.target_q_head2)
# Copy weights from live to target networks
self.target_q_feature_extractor1.load_state_dict(self.q_feature_extractor1.state_dict())
self.target_q_head1.load_state_dict(self.q_head1.state_dict())
self.target_q_feature_extractor2.load_state_dict(self.q_feature_extractor2.state_dict())
self.target_q_head2.load_state_dict(self.q_head2.state_dict())
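# The targets start as exact copies of the online Q-networks so the first
# bootstrapped targets match the current value estimates.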
# Freeze target networks
for p in self.target_q_feature_extractor1.parameters():
p.requires_grad = False
for p in self.target_q_head1.parameters():
p.requires_grad = False
for p in self.target_q_feature_extractor2.parameters():
p.requires_grad = False
for p in self.target_q_head2.parameters():
p.requires_grad = False
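# Freezing keeps target parameters out of any optimizer's gradient updates;
# they are expected to change only through explicit copies or soft/polyak
# updates performed by the training loop (not shown here).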
# Create a value function wrapper for compatibility
class ValueFunctionWrapper(nn.Module):
def __init__(self, parent_model):
super().__init__()
# Keep a plain (unregistered) reference to the parent. Assigning it as an
# nn.Module attribute would register the parent as a submodule and create
# a parent<->child cycle, breaking recursive ops such as state_dict() and .to().
object.__setattr__(self, "parent_model", parent_model)
def forward(self, x):
# SAC has no separate state-value network; for interface compatibility,
# return the minimum of the two Q-values evaluated at a zero action.
batch_size = x.shape[0]
zero_action = torch.zeros(
batch_size, self.parent_model.action_size, device=x.device
)
state_action = torch.cat([x, zero_action], dim=-1)
q1 = self.parent_model.forward_q1(state_action)
q2 = self.parent_model.forward_q2(state_action)
return torch.min(q1, q2)
self.value_function_module = ValueFunctionWrapper(self)