1 change: 1 addition & 0 deletions .github/ISSUE_TEMPLATE.md
@@ -3,6 +3,7 @@
+ [ ] RL algorithm bug
+ [ ] documentation request (i.e. "X is missing from the documentation.")
+ [ ] new feature request
+ [ ] design request (i.e. "X should be changed to Y.")
- [ ] I have visited the [source website](https://github.com/thu-ml/tianshou/)
- [ ] I have searched through the [issue tracker](https://github.com/thu-ml/tianshou/issues) for duplicates
- [ ] I have mentioned version numbers, operating system and environment, where applicable:
18 changes: 9 additions & 9 deletions .github/PULL_REQUEST_TEMPLATE.md
@@ -1,9 +1,9 @@
- [ ] I have marked all applicable categories:
+ [ ] exception-raising fix
+ [ ] algorithm implementation fix
+ [ ] documentation modification
+ [ ] new feature
- [ ] I have reformatted the code using `make format` (**required**)
- [ ] I have checked the code using `make commit-checks` (**required**)
- [ ] If applicable, I have mentioned the relevant/related issue(s)
- [ ] If applicable, I have listed every items in this Pull Request below
- [ ] I have added the correct label(s) to this Pull Request or linked the relevant issue(s)
- [ ] I have provided a description of the changes in this Pull Request
- [ ] I have added documentation for my changes
- [ ] If applicable, I have added tests to cover my changes.
- [ ] I have reformatted the code using `poe format`
- [ ] I have checked style and types with `poe lint` and `poe type-check`
- [ ] (Optional) I ran tests locally with `poe test`
(or a subset of them with `poe test-reduced`) ,and they pass
- [ ] (Optional) I have tested that documentation builds correctly with `poe doc-build`
6 changes: 3 additions & 3 deletions examples/vizdoom/env.py
@@ -172,8 +172,8 @@ def make_vizdoom_env(task, frame_skip, res, save_lmp, seed, training_num, test_n
print(env.spec.reward_threshold)
print(obs.shape, action_num)
for _ in range(4000):
obs, rew, done, info = env.step(0)
if done:
obs, rew, terminated, truncated, info = env.step(0)
if terminated or truncated:
env.reset()
print(obs.shape, rew, done)
print(obs.shape, rew, terminated, truncated)
cv2.imwrite("test.png", obs.transpose(1, 2, 0)[..., :3])
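The loop above is updated to the Gymnasium API, where `env.step` returns a five-tuple and episode termination is split into `terminated` and `truncated`. A minimal sketch of the same pattern on a generic Gymnasium environment ("CartPole-v1" is only an illustrative stand-in, not part of this diff):

```python
import gymnasium as gym

# Minimal sketch of the Gymnasium five-tuple step API used in the updated loop.
env = gym.make("CartPole-v1")
obs, info = env.reset(seed=0)
for _ in range(1000):
    obs, rew, terminated, truncated, info = env.step(env.action_space.sample())
    if terminated or truncated:
        # terminal state reached (terminated) or time limit hit (truncated)
        obs, info = env.reset()
env.close()
```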
71 changes: 71 additions & 0 deletions test/base/test_policy.py
@@ -0,0 +1,71 @@
import gymnasium as gym
import numpy as np
import pytest
import torch
from torch.distributions import Categorical, Independent, Normal

from tianshou.policy import PPOPolicy
from tianshou.utils.net.common import ActorCritic, Net
from tianshou.utils.net.continuous import ActorProb, Critic
from tianshou.utils.net.discrete import Actor

obs_shape = (5,)


def _to_hashable(x: np.ndarray | int):
return x if isinstance(x, int) else tuple(x.tolist())


@pytest.fixture(params=["continuous", "discrete"])
def policy(request):
action_type = request.param
if action_type == "continuous":
action_space = gym.spaces.Box(low=-1, high=1, shape=(3,))
actor = ActorProb(
Net(state_shape=obs_shape, hidden_sizes=[64, 64], action_shape=action_space.shape),
action_shape=action_space.shape,
)
dist_fn = lambda *logits: Independent(Normal(*logits), 1)
elif action_type == "discrete":
action_space = gym.spaces.Discrete(3)
actor = Actor(
Net(state_shape=obs_shape, hidden_sizes=[64, 64], action_shape=action_space.n),
action_shape=action_space.n,
)
dist_fn = lambda logits: Categorical(logits=logits)
else:
raise ValueError(f"Unknown action type: {action_type}")

critic = Critic(
Net(obs_shape, hidden_sizes=[64, 64]),
)

actor_critic = ActorCritic(actor, critic)
optim = torch.optim.Adam(actor_critic.parameters(), lr=1e-3)

policy = PPOPolicy(
actor=actor,
critic=critic,
dist_fn=dist_fn,
optim=optim,
action_space=action_space,
action_scaling=False,
)
policy.eval()
return policy


class TestPolicyBasics:
def test_get_action(self, policy):
sample_obs = torch.randn(obs_shape)
policy.deterministic_eval = False
actions = [policy.compute_action(sample_obs) for _ in range(10)]
assert all(policy.action_space.contains(a) for a in actions)

# check that the actions are different in non-deterministic mode
assert len(set(map(_to_hashable, actions))) > 1

policy.deterministic_eval = True
actions = [policy.compute_action(sample_obs) for _ in range(10)]
# check that the actions are the same in deterministic mode
assert len(set(map(_to_hashable, actions))) == 1
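The test above exercises the `compute_action` method introduced in this PR, which turns a single environment observation into an action. A hedged sketch of how that API might drive a plain Gymnasium rollout, assuming `policy` is a trained tianshou `BasePolicy` whose action space matches the env (the env name is illustrative only):

```python
import gymnasium as gym

# Hedged sketch: acting on single observations with the new compute_action.
# `policy` is assumed to be a trained tianshou policy compatible with the env.
env = gym.make("CartPole-v1")
obs, info = env.reset(seed=0)
done = False
while not done:
    act = policy.compute_action(obs, info=info)  # int for Discrete, ndarray for Box
    obs, rew, terminated, truncated, info = env.step(act)
    done = terminated or truncated
env.close()
```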
26 changes: 20 additions & 6 deletions tianshou/data/types.py
@@ -5,16 +5,24 @@
from tianshou.data.batch import BatchProtocol, arr_type


class RolloutBatchProtocol(BatchProtocol):
"""Typically, the outcome of sampling from a replay buffer."""
class ObsBatchProtocol(BatchProtocol):
"""Observations of an environment that a policy can turn into actions.

Typically used inside a policy's forward
"""

obs: arr_type | BatchProtocol
info: arr_type


class RolloutBatchProtocol(ObsBatchProtocol):
"""Typically, the outcome of sampling from a replay buffer."""

obs_next: arr_type | BatchProtocol
act: arr_type
rew: np.ndarray
terminated: arr_type
truncated: arr_type
info: arr_type


class BatchWithReturnsProtocol(RolloutBatchProtocol):
@@ -39,11 +39,17 @@ class RecurrentStateBatch(BatchProtocol):
class ActBatchProtocol(BatchProtocol):
"""Simplest batch, just containing the action. Useful e.g., for random policy."""

act: np.ndarray
act: arr_type


class ActStateBatchProtocol(ActBatchProtocol):
"""Contains action and state (which can be None), useful for policies that can support RNNs."""

state: dict | BatchProtocol | np.ndarray | None


class ModelOutputBatchProtocol(ActBatchProtocol):
"""Contains model output: (logits) and potentially hidden states."""
class ModelOutputBatchProtocol(ActStateBatchProtocol):
"""In addition to state and action, contains model output: (logits)."""

logits: torch.Tensor
state: dict | BatchProtocol | np.ndarray | None
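These protocol classes are structural types only; concrete objects remain plain `Batch` instances that are `cast` to the protocol documenting which keys callers may rely on. A hedged sketch of the intended usage (shapes are illustrative assumptions):

```python
from typing import cast

import numpy as np

from tianshou.data import Batch
from tianshou.data.types import ObsBatchProtocol, RolloutBatchProtocol

# A batch carrying only `obs` and `info` satisfies ObsBatchProtocol, the new
# input type of BasePolicy.forward; a replay-buffer sample with the extra keys
# satisfies RolloutBatchProtocol, which now inherits from ObsBatchProtocol.
obs_batch = cast(ObsBatchProtocol, Batch(obs=np.zeros((1, 5)), info={}))
rollout_batch = cast(
    RolloutBatchProtocol,
    Batch(
        obs=np.zeros((4, 5)),
        obs_next=np.zeros((4, 5)),
        act=np.zeros((4, 3)),
        rew=np.zeros(4),
        terminated=np.zeros(4, dtype=bool),
        truncated=np.zeros(4, dtype=bool),
        info={},
    ),
)
assert isinstance(obs_batch, Batch)  # casting is purely static, no runtime cost
```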
2 changes: 2 additions & 0 deletions tianshou/data/utils/converter.py
@@ -10,6 +10,8 @@
from tianshou.data.batch import Batch, _parse_value


# TODO: confusing name, could actually return a batch...
# Overrides and generic types should be added
@no_type_check
def to_numpy(x: Any) -> Batch | np.ndarray:
"""Return an object without torch.Tensor."""
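The TODO notes that `to_numpy` is not limited to arrays: given a `Batch`, it returns a `Batch` whose tensors have been converted. A small hedged sketch of the behaviour the comment refers to:

```python
import numpy as np
import torch

from tianshou.data import Batch, to_numpy

# to_numpy turns a torch.Tensor into an ndarray...
arr = to_numpy(torch.zeros(3))
assert isinstance(arr, np.ndarray)

# ...but, as the TODO above points out, it returns a Batch when given one,
# with the tensors inside converted to numpy arrays.
converted = to_numpy(Batch(act=torch.ones(2), rew=np.zeros(2)))
assert isinstance(converted, Batch)
assert isinstance(converted.act, np.ndarray)
```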
97 changes: 62 additions & 35 deletions tianshou/policy/base.py
@@ -1,7 +1,7 @@
import logging
from abc import ABC, abstractmethod
from collections.abc import Callable
from typing import Any, Literal, TypeAlias, cast, overload
from typing import Any, Literal, TypeAlias, cast

import gymnasium as gym
import numpy as np
@@ -11,14 +11,18 @@
from torch import nn

from tianshou.data import ReplayBuffer, to_numpy, to_torch_as
from tianshou.data.batch import BatchProtocol
from tianshou.data.batch import Batch, BatchProtocol, arr_type
from tianshou.data.buffer.base import TBuffer
from tianshou.data.types import BatchWithReturnsProtocol, RolloutBatchProtocol
from tianshou.data.types import (
ActBatchProtocol,
BatchWithReturnsProtocol,
ObsBatchProtocol,
RolloutBatchProtocol,
)
from tianshou.utils import MultipleLRSchedulers

logger = logging.getLogger(__name__)


TLearningRateScheduler: TypeAlias = torch.optim.lr_scheduler.LRScheduler | MultipleLRSchedulers


@@ -149,13 +153,39 @@ def soft_update(self, tgt: nn.Module, src: nn.Module, tau: float) -> None:
for tgt_param, src_param in zip(tgt.parameters(), src.parameters(), strict=True):
tgt_param.data.copy_(tau * src_param.data + (1 - tau) * tgt_param.data)

def compute_action(
self,
obs: arr_type,
info: dict[str, Any] | None = None,
state: dict | BatchProtocol | np.ndarray | None = None,
) -> np.ndarray | int:
"""Get action as int (for discrete env's) or array (for continuous ones) from
an env's observation and info.

:param obs: observation from the gym's env.
:param info: information given by the gym's env.
:param state: the hidden state of RNN policy, used for recurrent policy.
:return: action as int (for discrete env's) or array (for continuous ones).
"""
# need to add empty batch dimension
obs = obs[None, :]
obs_batch = cast(ObsBatchProtocol, Batch(obs=obs, info=info))
act = self.forward(obs_batch, state=state).act.squeeze()
if isinstance(act, torch.Tensor):
act = act.detach().cpu().numpy()
act = self.map_action(act)
if isinstance(self.action_space, Discrete):
# could be an array of shape (), easier to just convert to int
act = int(act) # type: ignore
return act

@abstractmethod
def forward(
self,
batch: RolloutBatchProtocol,
batch: ObsBatchProtocol,
state: dict | BatchProtocol | np.ndarray | None = None,
**kwargs: Any,
) -> BatchProtocol:
) -> ActBatchProtocol:
"""Compute action over the given batch data.

:return: A :class:`~tianshou.data.Batch` which MUST have the following keys:
@@ -190,22 +220,19 @@ def forward(
act = policy.map_action(act, batch)
"""

@overload
def map_action(self, act: BatchProtocol) -> BatchProtocol:
...

@overload
def map_action(self, act: np.ndarray) -> np.ndarray:
...

@overload
def map_action(self, act: torch.Tensor) -> torch.Tensor:
...
@staticmethod
def _action_to_numpy(act: arr_type) -> np.ndarray:
act = to_numpy(act) # NOTE: to_numpy could confusingly also return a Batch
if not isinstance(act, np.ndarray):
raise ValueError(
f"act should have been be a numpy.ndarray, but got {type(act)}.",
)
return act

def map_action(
self,
act: BatchProtocol | np.ndarray | torch.Tensor,
) -> BatchProtocol | np.ndarray | torch.Tensor:
act: arr_type,
) -> np.ndarray:
"""Map raw network output to action range in gym's env.action_space.

This function is called in :meth:`~tianshou.data.Collector.collect` and only
@@ -223,24 +250,24 @@ def map_action(
:return: action in the same form of input "act" but remap to the target action
space.
"""
if isinstance(self.action_space, gym.spaces.Box) and isinstance(act, np.ndarray):
# currently this action mapping only supports np.ndarray action
act = self._action_to_numpy(act)
if isinstance(self.action_space, gym.spaces.Box):
if self.action_bound_method == "clip":
act = np.clip(act, -1.0, 1.0)
elif self.action_bound_method == "tanh":
act = np.tanh(act)
if self.action_scaling:
assert (
np.min(act) >= -1.0 and np.max(act) <= 1.0 # type: ignore
np.min(act) >= -1.0 and np.max(act) <= 1.0
), f"action scaling only accepts raw action range = [-1, 1], but got: {act}"
low, high = self.action_space.low, self.action_space.high
act = low + (high - low) * (act + 1.0) / 2.0 # type: ignore
act = low + (high - low) * (act + 1.0) / 2.0
return act

def map_action_inverse(
self,
act: BatchProtocol | list | np.ndarray,
) -> BatchProtocol | list | np.ndarray:
act: arr_type,
) -> np.ndarray:
"""Inverse operation to :meth:`~tianshou.policy.BasePolicy.map_action`.

This function is called in :meth:`~tianshou.data.Collector.collect` for
@@ -252,17 +279,17 @@ def map_action_inverse(

:return: action remapped.
"""
act = self._action_to_numpy(act)
if isinstance(self.action_space, gym.spaces.Box):
act = to_numpy(act)
if isinstance(act, np.ndarray):
if self.action_scaling:
low, high = self.action_space.low, self.action_space.high
scale = high - low
eps = np.finfo(np.float32).eps.item()
scale[scale < eps] += eps
act = (act - low) * 2.0 / scale - 1.0
if self.action_bound_method == "tanh":
act = (np.log(1.0 + act) - np.log(1.0 - act)) / 2.0 # type: ignore
if self.action_scaling:
low, high = self.action_space.low, self.action_space.high
scale = high - low
eps = np.finfo(np.float32).eps.item()
scale[scale < eps] += eps
act = (act - low) * 2.0 / scale - 1.0
if self.action_bound_method == "tanh":
act = (np.log(1.0 + act) - np.log(1.0 - act)) / 2.0

return act

def process_buffer(self, buffer: TBuffer) -> TBuffer:
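The action mapping above now always works on numpy arrays: the raw network output is optionally bounded (`clip` or `tanh`) and, if `action_scaling` is enabled, mapped affinely from `[-1, 1]` onto the Box bounds. A hedged numeric sketch of that affine step and its inverse (the bounds below are chosen purely for illustration):

```python
import numpy as np

# Hedged sketch of the scaling step in map_action when action_scaling is on:
# a bounded action in [-1, 1] is mapped to [low, high] per dimension via
#   low + (high - low) * (act + 1) / 2
low = np.array([-2.0, 0.0])
high = np.array([2.0, 1.0])
raw_act = np.array([0.0, -1.0])  # assumed already bounded to [-1, 1]
mapped = low + (high - low) * (raw_act + 1.0) / 2.0
print(mapped)  # [0. 0.]: the midpoint of [-2, 2] and the lower bound of [0, 1]

# map_action_inverse applies the reverse affine transform (plus an atanh when
# the bounding method is "tanh"), recovering the raw action.
recovered = (mapped - low) * 2.0 / (high - low) - 1.0
print(recovered)  # [ 0. -1.]
```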
8 changes: 6 additions & 2 deletions tianshou/policy/imitation/base.py
@@ -7,7 +7,11 @@

from tianshou.data import Batch, to_torch
from tianshou.data.batch import BatchProtocol
from tianshou.data.types import ModelOutputBatchProtocol, RolloutBatchProtocol
from tianshou.data.types import (
ModelOutputBatchProtocol,
ObsBatchProtocol,
RolloutBatchProtocol,
)
from tianshou.policy import BasePolicy
from tianshou.policy.base import TLearningRateScheduler

@@ -55,7 +59,7 @@ def __init__(

def forward(
self,
batch: RolloutBatchProtocol,
batch: ObsBatchProtocol,
state: dict | BatchProtocol | np.ndarray | None = None,
**kwargs: Any,
) -> ModelOutputBatchProtocol:
10 changes: 5 additions & 5 deletions tianshou/policy/imitation/bcq.py
@@ -1,5 +1,5 @@
import copy
from typing import Any, Literal, Self
from typing import Any, Literal, Self, cast

import gymnasium as gym
import numpy as np
@@ -8,7 +8,7 @@

from tianshou.data import Batch, to_torch
from tianshou.data.batch import BatchProtocol
from tianshou.data.types import RolloutBatchProtocol
from tianshou.data.types import ActBatchProtocol, ObsBatchProtocol, RolloutBatchProtocol
from tianshou.policy import BasePolicy
from tianshou.policy.base import TLearningRateScheduler
from tianshou.utils.net.continuous import VAE
@@ -112,10 +112,10 @@ def train(self, mode: bool = True) -> Self:

def forward(
self,
batch: RolloutBatchProtocol,
batch: ObsBatchProtocol,
state: dict | BatchProtocol | np.ndarray | None = None,
**kwargs: Any,
) -> Batch:
) -> ActBatchProtocol:
"""Compute action over the given batch data."""
# There is "obs" in the Batch
# obs_group: several groups. Each group has a state.
@@ -134,7 +134,7 @@ def forward(
max_indice = q1.argmax(0)
act_group.append(act[max_indice].cpu().data.numpy().flatten())
act_group = np.array(act_group)
return Batch(act=act_group)
return cast(ActBatchProtocol, Batch(act=act_group))

def sync_weight(self) -> None:
"""Soft-update the weight for the target network."""