def __init__(self, obs_size, action_size, hidden_sizes=None, activation="relu"):
    """Build the policy, twin Q-value and value networks.

    Every network is the first element returned by
    ``make_feature_extractor`` (configured with ``architecture="mlp"``)
    followed by a final ``nn.Linear`` output layer.

    Args:
        obs_size: Observation size, forwarded as ``obs_shape`` to the
            feature extractor of the policy and value networks.
        action_size: Action size. The Q-networks receive
            ``obs_size + action_size`` as ``obs_shape`` (assumes both are
            plain integers -- TODO confirm against callers).
        hidden_sizes: Sequence of hidden layer widths. Its length is used
            as ``n_layers`` and its last entry as the input width of each
            output layer. Defaults to ``[64, 64]`` when ``None``.
        activation: Activation name forwarded to
            ``make_feature_extractor``.

    Raises:
        ValueError: If ``hidden_sizes`` is empty.
    """
    super().__init__()

    # A fresh list per call: a literal list default would be a single
    # object shared by every instance (and by the helper it is passed to).
    if hidden_sizes is None:
        hidden_sizes = [64, 64]
    if len(hidden_sizes) == 0:
        # The output layers read hidden_sizes[-1]; fail early and clearly.
        raise ValueError("hidden_sizes must contain at least one layer size")

    self.obs_size = obs_size
    self.action_size = action_size

    def build_network(input_size, output_size):
        """Return an MLP feature extractor followed by a linear head."""
        # The extractor is built before the head so that parameter
        # initialisation happens in the same order as a direct
        # ``nn.Sequential(extractor, head)`` expression.
        extractor = make_feature_extractor(
            architecture="mlp",
            obs_shape=input_size,
            n_layers=len(hidden_sizes),
            hidden_sizes=hidden_sizes,
            activation=activation,
        )[0]
        head = nn.Linear(hidden_sizes[-1], output_size)
        return nn.Sequential(extractor, head)

    # Policy network mapping observations to actions.
    # NOTE(review): the head width is hard-coded to 2 and does not depend on
    # action_size -- presumably two distribution parameters for a 1-D
    # action; confirm against the code that consumes policy_net's output.
    self.policy_net = build_network(obs_size, 2)

    # Q-networks mapping observation and actions to Q-values
    self.q_net1 = build_network(obs_size + action_size, 1)
    self.q_net2 = build_network(obs_size + action_size, 1)

    # Value network
    self.value_net = build_network(obs_size, 1)