
Support deterministic evaluation for onpolicy algorithms #354

Merged: 7 commits, Apr 27, 2021

test/continuous/test_npg.py (3 changes: 2 additions & 1 deletion)

@@ -96,7 +96,8 @@ def dist(*logits):
         gae_lambda=args.gae_lambda,
         action_space=env.action_space,
         optim_critic_iters=args.optim_critic_iters,
-        actor_step_size=args.actor_step_size)
+        actor_step_size=args.actor_step_size,
+        deterministic_eval=True)
     # collector
     train_collector = Collector(
         policy, train_envs,

test/continuous/test_sac_with_il.py (2 changes: 1 addition & 1 deletion)

@@ -145,7 +145,7 @@ def stop_fn(mean_rewards):
     ).to(args.device)
     optim = torch.optim.Adam(net.parameters(), lr=args.il_lr)
     il_policy = ImitationPolicy(
-        net, optim, mode='continuous', action_space=env.action_space,
+        net, optim, action_space=env.action_space,
         action_scaling=True, action_bound_method="clip")
     il_test_collector = Collector(
         il_policy,

test/discrete/test_a2c_with_il.py (2 changes: 1 addition & 1 deletion)

@@ -124,7 +124,7 @@ def stop_fn(mean_rewards):
               device=args.device)
     net = Actor(net, args.action_shape, device=args.device).to(args.device)
     optim = torch.optim.Adam(net.parameters(), lr=args.il_lr)
-    il_policy = ImitationPolicy(net, optim, mode='discrete')
+    il_policy = ImitationPolicy(net, optim, action_space=env.action_space)
     il_test_collector = Collector(
         il_policy,
         DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.test_num)])

test/discrete/test_ppo.py (3 changes: 2 additions & 1 deletion)

@@ -89,7 +89,8 @@ def test_ppo(args=get_args()):
         reward_normalization=args.rew_norm,
         dual_clip=args.dual_clip,
         value_clip=args.value_clip,
-        action_space=env.action_space)
+        action_space=env.action_space,
+        deterministic_eval=True)
     # collector
     train_collector = Collector(
         policy, train_envs,

tianshou/policy/base.py (6 changes: 6 additions & 0 deletions)

@@ -5,6 +5,7 @@
 from numba import njit
 from abc import ABC, abstractmethod
 from typing import Any, Dict, Tuple, Union, Optional, Callable
+from gym.spaces import Box, Discrete, MultiDiscrete, MultiBinary

 from tianshou.data import Batch, ReplayBuffer, to_torch_as, to_numpy

@@ -66,6 +67,11 @@ def __init__(
         super().__init__()
         self.observation_space = observation_space
         self.action_space = action_space
+        self.action_type = ""
+        if isinstance(action_space, (Discrete, MultiDiscrete, MultiBinary)):
+            self.action_type = "discrete"
+        elif isinstance(action_space, Box):
+            self.action_type = "continuous"
         self.agent_id = 0
         self.updating = False
         self.action_scaling = action_scaling

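For reference, a small standalone sketch of the mapping that BasePolicy.__init__ now applies when an action_space is supplied; the helper name infer_action_type is illustrative only and is not part of the library:

from gym.spaces import Box, Discrete, MultiBinary, MultiDiscrete, Space

def infer_action_type(action_space: Space) -> str:
    # Mirrors the isinstance checks added to BasePolicy.__init__ above.
    if isinstance(action_space, (Discrete, MultiDiscrete, MultiBinary)):
        return "discrete"
    if isinstance(action_space, Box):
        return "continuous"
    return ""  # unknown or unspecified action space

assert infer_action_type(Discrete(4)) == "discrete"
assert infer_action_type(Box(low=-1.0, high=1.0, shape=(3,))) == "continuous"
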
tianshou/policy/imitation/base.py (15 changes: 6 additions & 9 deletions)

@@ -13,8 +13,7 @@ class ImitationPolicy(BasePolicy):
     :param torch.nn.Module model: a model following the rules in
         :class:`~tianshou.policy.BasePolicy`. (s -> a)
     :param torch.optim.Optimizer optim: for optimizing the model.
-    :param str mode: indicate the imitation type ("continuous" or "discrete"
-        action space). Default to "continuous".
+    :param gym.Space action_space: env's action space.

     .. seealso::

@@ -26,15 +25,13 @@ def __init__(
         self,
         model: torch.nn.Module,
         optim: torch.optim.Optimizer,
-        mode: str = "continuous",
         **kwargs: Any,
     ) -> None:
         super().__init__(**kwargs)
         self.model = model
         self.optim = optim
-        assert mode in ["continuous", "discrete"], \
-            f"Mode {mode} is not in ['continuous', 'discrete']."
-        self.mode = mode
+        assert self.action_type in ["continuous", "discrete"], \
+            "Please specify action_space."

     def forward(
         self,

@@ -43,19 +40,19 @@ def forward(
         **kwargs: Any,
     ) -> Batch:
         logits, h = self.model(batch.obs, state=state, info=batch.info)
-        if self.mode == "discrete":
+        if self.action_type == "discrete":
             a = logits.max(dim=1)[1]
         else:
             a = logits
         return Batch(logits=logits, act=a, state=h)

     def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]:
         self.optim.zero_grad()
-        if self.mode == "continuous":  # regression
+        if self.action_type == "continuous":  # regression
             a = self(batch).act
             a_ = to_torch(batch.act, dtype=torch.float32, device=a.device)
             loss = F.mse_loss(a, a_)  # type: ignore
-        elif self.mode == "discrete":  # classification
+        elif self.action_type == "discrete":  # classification
             a = F.log_softmax(self(batch).logits, dim=-1)
             a_ = to_torch(batch.act, dtype=torch.long, device=a.device)
             loss = F.nll_loss(a, a_)  # type: ignore

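A minimal construction sketch under the new ImitationPolicy interface; the environment, network sizes and learning rate are illustrative placeholders, and Net is used the same way as in the test scripts above:

import gym
import torch
from tianshou.policy import ImitationPolicy
from tianshou.utils.net.common import Net

env = gym.make("CartPole-v0")
state_shape = env.observation_space.shape or env.observation_space.n
action_shape = env.action_space.shape or env.action_space.n

# Net follows the (obs -> logits, hidden_state) interface ImitationPolicy expects.
net = Net(state_shape, action_shape, hidden_sizes=[64, 64])
optim = torch.optim.Adam(net.parameters(), lr=1e-3)

# No mode argument anymore: the Discrete action space makes learn() a
# classification problem (NLL loss); a Box space would switch it to MSE regression.
policy = ImitationPolicy(net, optim, action_space=env.action_space)
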
tianshou/policy/modelfree/a2c.py (2 changes: 2 additions & 0 deletions)

@@ -39,6 +39,8 @@ class A2CPolicy(PGPolicy):
         to use option "action_scaling" or "action_bound_method". Default to None.
     :param lr_scheduler: a learning rate scheduler that adjusts the learning rate in
         optimizer in each policy.update(). Default to None (no lr_scheduler).
+    :param bool deterministic_eval: whether to use deterministic action instead of
+        stochastic action sampled by the policy. Default to False.

     .. seealso::

tianshou/policy/modelfree/npg.py (2 changes: 2 additions & 0 deletions)

@@ -42,6 +42,8 @@ class NPGPolicy(A2CPolicy):
         to use option "action_scaling" or "action_bound_method". Default to None.
     :param lr_scheduler: a learning rate scheduler that adjusts the learning rate in
         optimizer in each policy.update(). Default to None (no lr_scheduler).
+    :param bool deterministic_eval: whether to use deterministic action instead of
+        stochastic action sampled by the policy. Default to False.
     """

     def __init__(

tianshou/policy/modelfree/pg.py (14 changes: 12 additions & 2 deletions)

@@ -3,8 +3,8 @@
 from typing import Any, Dict, List, Type, Union, Optional

 from tianshou.policy import BasePolicy
-from tianshou.data import Batch, ReplayBuffer, to_torch_as
 from tianshou.utils import RunningMeanStd
+from tianshou.data import Batch, ReplayBuffer, to_torch_as


 class PGPolicy(BasePolicy):

@@ -25,6 +25,8 @@ class PGPolicy(BasePolicy):
         to use option "action_scaling" or "action_bound_method". Default to None.
     :param lr_scheduler: a learning rate scheduler that adjusts the learning rate in
         optimizer in each policy.update(). Default to None (no lr_scheduler).
+    :param bool deterministic_eval: whether to use deterministic action instead of
+        stochastic action sampled by the policy. Default to False.

     .. seealso::

@@ -42,6 +44,7 @@ def __init__(
         action_scaling: bool = True,
         action_bound_method: str = "clip",
         lr_scheduler: Optional[torch.optim.lr_scheduler.LambdaLR] = None,
+        deterministic_eval: bool = False,
         **kwargs: Any,
     ) -> None:
         super().__init__(action_scaling=action_scaling,

@@ -55,6 +58,7 @@
         self._rew_norm = reward_normalization
         self.ret_rms = RunningMeanStd()
         self._eps = 1e-8
+        self._deterministic_eval = deterministic_eval

     def process_fn(
         self, batch: Batch, buffer: ReplayBuffer, indice: np.ndarray

@@ -103,7 +107,13 @@ def forward(
             dist = self.dist_fn(*logits)
         else:
             dist = self.dist_fn(logits)
-        act = dist.sample()
+        if self._deterministic_eval and not self.training:
+            if self.action_type == "discrete":
+                act = logits.argmax(-1)
+            elif self.action_type == "continuous":
+                act = logits[0]
+        else:
+            act = dist.sample()
         return Batch(logits=logits, act=act, state=h, dist=dist)

     def learn(  # type: ignore

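The new branch only takes effect outside training mode, i.e. after policy.eval() has set self.training to False: discrete policies then return the argmax of the logits, and continuous actors that output a (mu, sigma) pair return logits[0], the mean. A minimal sketch of the behavior for a discrete setup; the environment, shapes and hyperparameters are illustrative:

import gym
import numpy as np
import torch
from tianshou.data import Batch
from tianshou.policy import PGPolicy
from tianshou.utils.net.common import Net

env = gym.make("CartPole-v0")  # 4-dim observations, 2 discrete actions
net = Net(4, 2, hidden_sizes=[64], softmax=True)
optim = torch.optim.Adam(net.parameters(), lr=1e-3)

# action_space is required here: it sets action_type = "discrete", which the
# deterministic branch needs in order to pick the argmax rule.
policy = PGPolicy(net, optim, torch.distributions.Categorical,
                  action_space=env.action_space, deterministic_eval=True)

batch = Batch(obs=np.random.rand(8, 4), info={})

policy.train()               # self.training is True -> actions are sampled
sampled = policy(batch).act

policy.eval()                # self.training is False -> argmax over the logits
greedy = policy(batch).act   # identical on every call for the same observations

Note that an action_space has to be provided when deterministic_eval is enabled: with an empty action_type neither branch assigns act, and forward would fail in evaluation mode.
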
tianshou/policy/modelfree/ppo.py (2 changes: 2 additions & 0 deletions)

@@ -49,6 +49,8 @@ class PPOPolicy(A2CPolicy):
         to use option "action_scaling" or "action_bound_method". Default to None.
     :param lr_scheduler: a learning rate scheduler that adjusts the learning rate in
         optimizer in each policy.update(). Default to None (no lr_scheduler).
+    :param bool deterministic_eval: whether to use deterministic action instead of
+        stochastic action sampled by the policy. Default to False.

     .. seealso::

tianshou/policy/modelfree/trpo.py (2 changes: 2 additions & 0 deletions)

@@ -45,6 +45,8 @@ class TRPOPolicy(NPGPolicy):
         to use option "action_scaling" or "action_bound_method". Default to None.
     :param lr_scheduler: a learning rate scheduler that adjusts the learning rate in
         optimizer in each policy.update(). Default to None (no lr_scheduler).
+    :param bool deterministic_eval: whether to use deterministic action instead of
+        stochastic action sampled by the policy. Default to False.
     """

     def __init__(

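Taken together, PG, A2C, NPG, PPO and TRPO now share the same flag. Below is an end-to-end evaluation sketch for the discrete PPO case, mirroring the test_ppo.py change above; the environment, network sizes and hyperparameters are illustrative placeholders rather than the repository's exact test setup:

import gym
import torch
from tianshou.data import Collector
from tianshou.env import DummyVectorEnv
from tianshou.policy import PPOPolicy
from tianshou.utils.net.common import Net
from tianshou.utils.net.discrete import Actor, Critic

env = gym.make("CartPole-v0")
net = Net(env.observation_space.shape, hidden_sizes=[64, 64])
actor = Actor(net, env.action_space.n)
critic = Critic(net)
optim = torch.optim.Adam(
    set(actor.parameters()).union(critic.parameters()), lr=3e-4)
policy = PPOPolicy(
    actor, critic, optim, torch.distributions.Categorical,
    action_space=env.action_space, deterministic_eval=True)

# ... on-policy training (e.g. with onpolicy_trainer) would go here ...

test_envs = DummyVectorEnv([lambda: gym.make("CartPole-v0") for _ in range(4)])
test_collector = Collector(policy, test_envs)

policy.eval()  # eval mode: self.training is False, so actions become deterministic
result = test_collector.collect(n_episode=8)
print(result["rews"].mean(), result["lens"].mean())

The trainer's periodic test phase runs with the policy in eval mode, so with deterministic_eval=True the reported test returns come from greedy actions while training rollouts stay stochastic.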