diff --git a/tianshou/utils/net/common.py b/tianshou/utils/net/common.py
index a6a018535..1ce394697 100644
--- a/tianshou/utils/net/common.py
+++ b/tianshou/utils/net/common.py
@@ -18,22 +18,36 @@ from tianshou.data.batch import Batch
 ModuleType = Type[nn.Module]
+ArgsType = Union[Tuple[Any, ...], Dict[Any, Any], Sequence[Tuple[Any, ...]],
+                 Sequence[Dict[Any, Any]]]
 
 
 def miniblock(
     input_size: int,
     output_size: int = 0,
     norm_layer: Optional[ModuleType] = None,
+    norm_args: Optional[Union[Tuple[Any, ...], Dict[Any, Any]]] = None,
     activation: Optional[ModuleType] = None,
+    act_args: Optional[Union[Tuple[Any, ...], Dict[Any, Any]]] = None,
     linear_layer: Type[nn.Linear] = nn.Linear,
 ) -> List[nn.Module]:
     """Construct a miniblock with given input/output-size, norm layer and \
     activation."""
     layers: List[nn.Module] = [linear_layer(input_size, output_size)]
     if norm_layer is not None:
-        layers += [norm_layer(output_size)]  # type: ignore
+        if isinstance(norm_args, tuple):
+            layers += [norm_layer(output_size, *norm_args)]  # type: ignore
+        elif isinstance(norm_args, dict):
+            layers += [norm_layer(output_size, **norm_args)]  # type: ignore
+        else:
+            layers += [norm_layer(output_size)]  # type: ignore
     if activation is not None:
-        layers += [activation()]
+        if isinstance(act_args, tuple):
+            layers += [activation(*act_args)]
+        elif isinstance(act_args, dict):
+            layers += [activation(**act_args)]
+        else:
+            layers += [activation()]
     return layers
@@ -68,7 +82,9 @@ def __init__(
         output_dim: int = 0,
         hidden_sizes: Sequence[int] = (),
         norm_layer: Optional[Union[ModuleType, Sequence[ModuleType]]] = None,
+        norm_args: Optional[ArgsType] = None,
         activation: Optional[Union[ModuleType, Sequence[ModuleType]]] = nn.ReLU,
+        act_args: Optional[ArgsType] = None,
         device: Optional[Union[str, int, torch.device]] = None,
         linear_layer: Type[nn.Linear] = nn.Linear,
         flatten_input: bool = True,
@@ -79,24 +95,41 @@ def __init__(
             if isinstance(norm_layer, list):
                 assert len(norm_layer) == len(hidden_sizes)
                 norm_layer_list = norm_layer
+                if isinstance(norm_args, list):
+                    assert len(norm_args) == len(hidden_sizes)
+                    norm_args_list = norm_args
+                else:
+                    norm_args_list = [norm_args for _ in range(len(hidden_sizes))]
             else:
                 norm_layer_list = [norm_layer for _ in range(len(hidden_sizes))]
+                norm_args_list = [norm_args for _ in range(len(hidden_sizes))]
         else:
             norm_layer_list = [None] * len(hidden_sizes)
+            norm_args_list = [None] * len(hidden_sizes)
         if activation:
             if isinstance(activation, list):
                 assert len(activation) == len(hidden_sizes)
                 activation_list = activation
+                if isinstance(act_args, list):
+                    assert len(act_args) == len(hidden_sizes)
+                    act_args_list = act_args
+                else:
+                    act_args_list = [act_args for _ in range(len(hidden_sizes))]
             else:
                 activation_list = [activation for _ in range(len(hidden_sizes))]
+                act_args_list = [act_args for _ in range(len(hidden_sizes))]
         else:
             activation_list = [None] * len(hidden_sizes)
+            act_args_list = [None] * len(hidden_sizes)
         hidden_sizes = [input_dim] + list(hidden_sizes)
         model = []
-        for in_dim, out_dim, norm, activ in zip(
-            hidden_sizes[:-1], hidden_sizes[1:], norm_layer_list, activation_list
+        for in_dim, out_dim, norm, norm_args, activ, act_args in zip(
+            hidden_sizes[:-1], hidden_sizes[1:], norm_layer_list, norm_args_list,
+            activation_list, act_args_list
         ):
-            model += miniblock(in_dim, out_dim, norm, activ, linear_layer)
+            model += miniblock(
+                in_dim, out_dim, norm, norm_args, activ, act_args, linear_layer
+            )
         if output_dim > 0:
             model += [linear_layer(hidden_sizes[-1], output_dim)]
         self.output_dim = output_dim or hidden_sizes[-1]
@@ -161,8 +194,10 @@ def __init__(
         state_shape: Union[int, Sequence[int]],
         action_shape: Union[int, Sequence[int]] = 0,
         hidden_sizes: Sequence[int] = (),
-        norm_layer: Optional[ModuleType] = None,
-        activation: Optional[ModuleType] = nn.ReLU,
+        norm_layer: Optional[Union[ModuleType, Sequence[ModuleType]]] = None,
+        norm_args: Optional[ArgsType] = None,
+        activation: Optional[Union[ModuleType, Sequence[ModuleType]]] = nn.ReLU,
+        act_args: Optional[ArgsType] = None,
         device: Union[str, int, torch.device] = "cpu",
         softmax: bool = False,
         concat: bool = False,
@@ -181,8 +216,8 @@ def __init__(
         self.use_dueling = dueling_param is not None
         output_dim = action_dim if not self.use_dueling and not concat else 0
         self.model = MLP(
-            input_dim, output_dim, hidden_sizes, norm_layer, activation, device,
-            linear_layer
+            input_dim, output_dim, hidden_sizes, norm_layer, norm_args, activation,
+            act_args, device, linear_layer
         )
         self.output_dim = self.model.output_dim
         if self.use_dueling:  # dueling DQN
@@ -406,7 +441,9 @@ def __init__(
         value_hidden_sizes: List[int] = [],
         action_hidden_sizes: List[int] = [],
         norm_layer: Optional[ModuleType] = None,
+        norm_args: Optional[ArgsType] = None,
         activation: Optional[ModuleType] = nn.ReLU,
+        act_args: Optional[ArgsType] = None,
         device: Union[str, int, torch.device] = "cpu",
     ) -> None:
         super().__init__()
@@ -418,14 +455,14 @@ def __init__(
         common_output_dim = 0
         self.common = MLP(
             common_input_dim, common_output_dim, common_hidden_sizes, norm_layer,
-            activation, device
+            norm_args, activation, act_args, device
         )
         # value network
         value_input_dim = common_hidden_sizes[-1]
         value_output_dim = 1
         self.value = MLP(
             value_input_dim, value_output_dim, value_hidden_sizes, norm_layer,
-            activation, device
+            norm_args, activation, act_args, device
         )
         # action branching network
         action_input_dim = common_hidden_sizes[-1]
@@ -434,7 +471,7 @@ def __init__(
             [
                 MLP(
                     action_input_dim, action_output_dim, action_hidden_sizes,
-                    norm_layer, activation, device
+                    norm_layer, norm_args, activation, act_args, device
                 ) for _ in range(self.num_branches)
             ]
         )
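
A minimal usage sketch of the new norm_args/act_args parameters (not part of the patch; the argument values below are illustrative, and the calls assume the patched MLP signature plus standard torch.nn modules):

import torch
import torch.nn as nn
from tianshou.utils.net.common import MLP

# Dict args are expanded as keyword arguments, tuple args positionally:
# each hidden layer here gets nn.LayerNorm(out_dim, eps=1e-6) and
# nn.LeakyReLU(negative_slope=0.1).
net = MLP(
    input_dim=4,
    output_dim=2,
    hidden_sizes=[64, 64],
    norm_layer=nn.LayerNorm,
    norm_args={"eps": 1e-6},
    activation=nn.LeakyReLU,
    act_args={"negative_slope": 0.1},
)

# Per-layer lists are also accepted; each entry pairs with one hidden
# layer, and a None entry falls back to the module's defaults.
net2 = MLP(
    input_dim=4,
    output_dim=2,
    hidden_sizes=[64, 64],
    activation=[nn.LeakyReLU, nn.Tanh],
    act_args=[{"negative_slope": 0.2}, None],
)

out = net(torch.zeros(1, 4))  # -> shape (1, 2)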
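
The same pair of parameters propagates through Net (and through all three MLPs inside BranchingNet) unchanged. Again a sketch with illustrative values, assuming the patched Net signature:

from tianshou.utils.net.common import Net

# Net forwards norm_args/act_args straight into its internal MLP.
q_net = Net(
    state_shape=4,
    action_shape=2,
    hidden_sizes=[128, 128],
    norm_layer=nn.LayerNorm,
    activation=nn.ELU,
    act_args={"alpha": 0.9},
    device="cpu",
)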