diff --git a/CHANGELOG.md b/CHANGELOG.md
index 886c68f40..2d03d6676 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -22,6 +22,7 @@
 - policy:
     - introduced attribute `in_training_step` that is controlled by the trainer. #1123
     - policy automatically set to `eval` mode when collecting and to `train` mode when updating. #1123
+    - Extended interface of `compute_action` to also support array-like inputs #1169
 - `highlevel`:
     - `SamplingConfig`:
         - Add support for `batch_size=None`. #1077
diff --git a/tianshou/policy/base.py b/tianshou/policy/base.py
index b7ae5f23d..c4fc9af3b 100644
--- a/tianshou/policy/base.py
+++ b/tianshou/policy/base.py
@@ -10,6 +10,7 @@
 import torch
 from gymnasium.spaces import Box, Discrete, MultiBinary, MultiDiscrete
 from numba import njit
+from numpy.typing import ArrayLike
 from overrides import override
 from torch import nn
 
@@ -289,7 +290,7 @@ def soft_update(self, tgt: nn.Module, src: nn.Module, tau: float) -> None:
 
     def compute_action(
         self,
-        obs: arr_type,
+        obs: ArrayLike,
         info: dict[str, Any] | None = None,
         state: dict | BatchProtocol | np.ndarray | None = None,
     ) -> np.ndarray | int:
@@ -300,8 +301,8 @@
         :param state: the hidden state of RNN policy, used for recurrent policy.
         :return: action as int (for discrete env's) or array (for continuous ones).
         """
-        # need to add empty batch dimension
-        obs = obs[None, :]
+        obs = np.array(obs)  # convert array-like to array (e.g. LazyFrames)
+        obs = obs[None, :]  # add batch dimension
         obs_batch = cast(ObsBatchProtocol, Batch(obs=obs, info=info))
         act = self.forward(obs_batch, state=state).act.squeeze()
         if isinstance(act, torch.Tensor):
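
Usage note (not part of the patch): a minimal sketch of how the extended `compute_action` interface can be exercised with an array-like observation such as a plain Python list. The `DQNPolicy`/`Net` setup below is only an illustrative assumption about the 1.x API and is not taken from this diff; any concrete `BasePolicy` subclass should behave the same way.

```python
# Illustrative sketch only: policy construction is an assumption, not part of the patch.
import gymnasium as gym
import torch

from tianshou.policy import DQNPolicy
from tianshou.utils.net.common import Net

env = gym.make("CartPole-v1")
net = Net(
    state_shape=env.observation_space.shape,
    action_shape=env.action_space.n,
    hidden_sizes=[64, 64],
)
policy = DQNPolicy(
    model=net,
    optim=torch.optim.Adam(net.parameters(), lr=1e-3),
    action_space=env.action_space,
)

obs, info = env.reset()
# Before this change, obs had to be an np.ndarray; now any array-like input
# (e.g. a plain list, or a LazyFrames object from frame-stacking wrappers)
# is accepted, because compute_action converts it with np.array(obs) before batching.
action = policy.compute_action(obs=list(obs), info=info)
print(action)  # an int for this discrete action space, per the docstring above
```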