# DQN playing Acrobot-v1 from https://github.com/sgoodfriend/rl-algo-impls/tree/1d4094fbcc9082de7f53f4348dd4c7c354152907
# 16eb59e
import gym | |
import torch | |
import torch.nn as nn | |
from typing import Sequence, Type | |
from shared.module import FeatureExtractor, mlp | |
class CriticHead(nn.Module):
    """Head module that maps input features to one scalar per sample.

    The underlying network is produced by the project's ``mlp`` helper from
    the given hidden sizes followed by a final width of 1.
    """

    def __init__(
        self,
        hidden_sizes: Sequence[int] = (32,),
        activation: Type[nn.Module] = nn.Tanh,
        init_layers_orthogonal: bool = True,
    ) -> None:
        """Build the underlying network.

        Args:
            hidden_sizes: Layer widths handed to ``mlp``; a trailing width of
                1 is appended. Presumably the first entry must equal the
                width of the incoming features -- confirm against ``mlp``.
            activation: Module class forwarded to ``mlp``.
            init_layers_orthogonal: Forwarded to ``mlp`` unchanged.
        """
        super().__init__()
        # NOTE(review): assuming ``mlp`` returns an nn.Module, the attribute
        # name ``_fc`` becomes part of the state_dict keys, so keep it stable.
        self._fc = mlp(
            (*hidden_sizes, 1),
            activation,
            init_layers_orthogonal=init_layers_orthogonal,
            final_layer_gain=1.0,
        )

    def forward(self, obs: torch.Tensor) -> torch.Tensor:
        """Run ``obs`` through the network and squeeze the last dimension.

        Args:
            obs: Input tensor passed straight to the network (despite the
                name, likely extracted features -- TODO confirm with caller).

        Returns:
            The network output after ``squeeze(-1)``; with a final width of 1
            this presumably drops the trailing size-1 dimension.
        """
        return self._fc(obs).squeeze(-1)