From bfd4a60745f2abe9c4b8bcffe4d7fb251547f2aa Mon Sep 17 00:00:00 2001 From: Trinkle23897 <463003665@qq.com> Date: Fri, 24 Jul 2020 08:01:04 +0800 Subject: [PATCH 1/8] dqn learn should keep eps=0 --- tianshou/policy/modelfree/dqn.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tianshou/policy/modelfree/dqn.py b/tianshou/policy/modelfree/dqn.py index 162719939..9bdbd1191 100644 --- a/tianshou/policy/modelfree/dqn.py +++ b/tianshou/policy/modelfree/dqn.py @@ -157,7 +157,7 @@ def learn(self, batch: Batch, **kwargs) -> Dict[str, float]: if self._target and self._cnt % self._freq == 0: self.sync_weight() self.optim.zero_grad() - q = self(batch).logits + q = self(batch, eps=0.).logits q = q[np.arange(len(q)), batch.act] r = to_torch_as(batch.returns, q).flatten() if hasattr(batch, 'update_weight'): From 7d73ca23cf2299a6cadda7d9dbd414ef98b81fc2 Mon Sep 17 00:00:00 2001 From: Trinkle23897 <463003665@qq.com> Date: Fri, 24 Jul 2020 08:41:44 +0800 Subject: [PATCH 2/8] add a warning in docs --- docs/tutorials/cheatsheet.rst | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/docs/tutorials/cheatsheet.rst b/docs/tutorials/cheatsheet.rst index bae20b19c..f70935e0c 100644 --- a/docs/tutorials/cheatsheet.rst +++ b/docs/tutorials/cheatsheet.rst @@ -48,6 +48,17 @@ where ``env_fns`` is a list of callable env hooker. The above code can be writte env_fns = [lambda x=i: MyTestEnv(size=x) for i in [2, 3, 4, 5]] venv = SubprocVectorEnv(env_fns) +.. warning:: + + If you use your own environment, please make sure the ``seed`` method is set up properly, e.g., + + :: + + def seed(self, seed): + np.random.seed(seed) + + Otherwise, the outputs of these envs will be the same with each other. + .. _preprocess_fn: Handle Batched Data Stream in Collector From e6719b22dd5a00ae4044a38da6247cb09af422a8 Mon Sep 17 00:00:00 2001 From: Trinkle23897 <463003665@qq.com> Date: Fri, 24 Jul 2020 08:46:10 +0800 Subject: [PATCH 3/8] update the warning in docstring --- tianshou/env/basevecenv.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/tianshou/env/basevecenv.py b/tianshou/env/basevecenv.py index 60394e3de..a2b694277 100644 --- a/tianshou/env/basevecenv.py +++ b/tianshou/env/basevecenv.py @@ -27,6 +27,17 @@ class BaseVectorEnv(ABC, gym.Env): obs, rew, done, info = envs.step([1] * 8) # step synchronously envs.render() # render all environments envs.close() # close all environments + + .. warning:: + + If you use your own environment, please make sure the ``seed`` method + is set up properly, e.g., + :: + + def seed(self, seed): + np.random.seed(seed) + + Otherwise, the outputs of these envs will be the same with each other. """ def __init__(self, env_fns: List[Callable[[], gym.Env]]) -> None: From 025b6f5c92ccb89b8f19c894dcd70f284ca21c23 Mon Sep 17 00:00:00 2001 From: Trinkle23897 <463003665@qq.com> Date: Fri, 24 Jul 2020 10:37:27 +0800 Subject: [PATCH 4/8] fix #162 --- test/base/env.py | 14 +++++++++++--- tianshou/policy/modelfree/a2c.py | 7 +++++-- tianshou/policy/modelfree/pg.py | 6 +++++- tianshou/policy/modelfree/ppo.py | 7 ++++--- tianshou/utils/net/common.py | 5 ++++- 5 files changed, 29 insertions(+), 10 deletions(-) diff --git a/test/base/env.py b/test/base/env.py index b0962154f..8916d5fe0 100644 --- a/test/base/env.py +++ b/test/base/env.py @@ -1,18 +1,24 @@ import gym import time -from gym.spaces.discrete import Discrete +from gym.spaces import Discrete, MultiDiscrete, Box class MyTestEnv(gym.Env): """This is a "going right" task. The task is to go right ``size`` steps. """ - def __init__(self, size, sleep=0, dict_state=False, ma_rew=0): + def __init__(self, size, sleep=0, dict_state=False, ma_rew=0, + multidiscrete_action=False): self.size = size self.sleep = sleep self.dict_state = dict_state self.ma_rew = ma_rew - self.action_space = Discrete(2) + self._md_action = multidiscrete_action + self.observation_space = Box(shape=(1, ), low=0, high=size - 1) + if multidiscrete_action: + self.action_space = MultiDiscrete([2, 2]) + else: + self.action_space = Discrete(2) self.reset() def reset(self, state=0): @@ -32,6 +38,8 @@ def _get_dict_state(self): return {'index': self.index} if self.dict_state else self.index def step(self, action): + if self._md_action: + action = action[0] if self.done: raise ValueError('step after done !!!') if self.sleep > 0: diff --git a/tianshou/policy/modelfree/a2c.py b/tianshou/policy/modelfree/a2c.py index 8b98f70c1..55c7b6482 100644 --- a/tianshou/policy/modelfree/a2c.py +++ b/tianshou/policy/modelfree/a2c.py @@ -108,8 +108,11 @@ def learn(self, batch: Batch, batch_size: int, repeat: int, v = self.critic(b.obs).flatten() a = to_torch_as(b.act, v) r = to_torch_as(b.returns, v) - a_loss = -(dist.log_prob(a).flatten() * (r - v).detach() - ).mean() + log_prob = dist.log_prob(a) + # TODO: torch.movedim in version > 1.5.1 + log_prob = log_prob.permute( + *np.roll(range(len(log_prob.shape)), -1)) + a_loss = -(log_prob * (r - v).detach()).mean() vf_loss = F.mse_loss(r, v) ent_loss = dist.entropy().mean() loss = a_loss + self._w_vf * vf_loss - self._w_ent * ent_loss diff --git a/tianshou/policy/modelfree/pg.py b/tianshou/policy/modelfree/pg.py index 09f90eacc..91883338a 100644 --- a/tianshou/policy/modelfree/pg.py +++ b/tianshou/policy/modelfree/pg.py @@ -90,7 +90,11 @@ def learn(self, batch: Batch, batch_size: int, repeat: int, dist = self(b).dist a = to_torch_as(b.act, dist.logits) r = to_torch_as(b.returns, dist.logits) - loss = -(dist.log_prob(a).flatten() * r).sum() + # TODO: torch.movedim in version > 1.5.1 + log_prob = dist.log_prob(a) + log_prob = log_prob.permute( + *np.roll(range(len(log_prob.shape)), -1)) + loss = -(log_prob * r).sum() loss.backward() self.optim.step() losses.append(loss.item()) diff --git a/tianshou/policy/modelfree/ppo.py b/tianshou/policy/modelfree/ppo.py index bae6907c0..33bf1ed0e 100644 --- a/tianshou/policy/modelfree/ppo.py +++ b/tianshou/policy/modelfree/ppo.py @@ -132,7 +132,7 @@ def learn(self, batch: Batch, batch_size: int, repeat: int, to_torch_as(b.act, v[0]))) batch.v = torch.cat(v, dim=0).flatten() # old value batch.act = to_torch_as(batch.act, v[0]) - batch.logp_old = torch.cat(old_log_prob, dim=0).flatten() + batch.logp_old = torch.cat(old_log_prob, dim=0) batch.returns = to_torch_as(batch.returns, v[0]) if self._rew_norm: mean, std = batch.returns.mean(), batch.returns.std() @@ -147,8 +147,9 @@ def learn(self, batch: Batch, batch_size: int, repeat: int, for b in batch.split(batch_size): dist = self(b).dist value = self.critic(b.obs).flatten() - ratio = (dist.log_prob(b.act).flatten() - - b.logp_old).exp().float() + ratio = (dist.log_prob(b.act) - b.logp_old).exp().float() + # TODO: torch.movedim in version > 1.5.1 + ratio = ratio.permute(*np.roll(range(len(ratio.shape)), -1)) surr1 = ratio * b.adv surr2 = ratio.clamp(1. - self._eps_clip, 1. + self._eps_clip) * b.adv diff --git a/tianshou/utils/net/common.py b/tianshou/utils/net/common.py index dc85ac5c9..7a8ef174f 100644 --- a/tianshou/utils/net/common.py +++ b/tianshou/utils/net/common.py @@ -36,7 +36,10 @@ def __init__(self, layer_num, state_shape, action_shape=0, device='cpu', def forward(self, s, state=None, info={}): """s -> flatten -> logits""" s = to_torch(s, device=self.device, dtype=torch.float32) - s = s.flatten(1) + if len(s.shape) > 1: + s = s.flatten(1) + else: + s = s.unsqueeze(-1) logits = self.model(s) return logits, state From ebc6f6254fc81c5935e18def7fb4700b152023d1 Mon Sep 17 00:00:00 2001 From: Trinkle23897 <463003665@qq.com> Date: Fri, 24 Jul 2020 11:52:51 +0800 Subject: [PATCH 5/8] add test of vecenv seed --- test/base/env.py | 7 ++++++- test/base/test_batch.py | 2 +- test/base/test_collector.py | 4 ++++ 3 files changed, 11 insertions(+), 2 deletions(-) diff --git a/test/base/env.py b/test/base/env.py index 8916d5fe0..40f0597fe 100644 --- a/test/base/env.py +++ b/test/base/env.py @@ -1,5 +1,6 @@ import gym import time +import numpy as np from gym.spaces import Discrete, MultiDiscrete, Box @@ -21,6 +22,9 @@ def __init__(self, size, sleep=0, dict_state=False, ma_rew=0, self.action_space = Discrete(2) self.reset() + def seed(self, seed=0): + np.random.seed(seed) + def reset(self, state=0): self.done = False self.index = state @@ -35,7 +39,8 @@ def _get_reward(self): def _get_dict_state(self): """Generate a dict_state if dict_state is True.""" - return {'index': self.index} if self.dict_state else self.index + return {'index': self.index, 'rand': np.random.rand()} \ + if self.dict_state else self.index def step(self, action): if self._md_action: diff --git a/test/base/test_batch.py b/test/base/test_batch.py index 77d1343ef..a823491b4 100644 --- a/test/base/test_batch.py +++ b/test/base/test_batch.py @@ -1,5 +1,5 @@ -import torch import copy +import torch import pickle import pytest import numpy as np diff --git a/test/base/test_collector.py b/test/base/test_collector.py index 9fa37b66f..b4a69c1ea 100644 --- a/test/base/test_collector.py +++ b/test/base/test_collector.py @@ -113,8 +113,12 @@ def test_collector_with_dict_state(): env_fns = [lambda x=i: MyTestEnv(size=x, sleep=0, dict_state=True) for i in [2, 3, 4, 5]] envs = VectorEnv(env_fns) + envs.seed(666) + obs = envs.reset() + assert not np.isclose(obs[0]['rand'], obs[1]['rand']) c1 = Collector(policy, envs, ReplayBuffer(size=100), Logger.single_preprocess_fn) + c1.seed(0) c1.collect(n_step=10) c1.collect(n_episode=[2, 1, 1, 2]) batch = c1.sample(10) From f7da4206ddc300cabf285d84aea42e3647465f47 Mon Sep 17 00:00:00 2001 From: Trinkle23897 <463003665@qq.com> Date: Fri, 24 Jul 2020 14:20:01 +0800 Subject: [PATCH 6/8] fix to reshape+transpose --- docs/tutorials/cheatsheet.rst | 2 +- docs/tutorials/dqn.rst | 11 +++++++++++ tianshou/env/basevecenv.py | 2 +- tianshou/policy/modelfree/a2c.py | 6 ++---- tianshou/policy/modelfree/pg.py | 6 ++---- tianshou/policy/modelfree/ppo.py | 3 +-- 6 files changed, 18 insertions(+), 12 deletions(-) diff --git a/docs/tutorials/cheatsheet.rst b/docs/tutorials/cheatsheet.rst index f70935e0c..44e109aa0 100644 --- a/docs/tutorials/cheatsheet.rst +++ b/docs/tutorials/cheatsheet.rst @@ -57,7 +57,7 @@ where ``env_fns`` is a list of callable env hooker. The above code can be writte def seed(self, seed): np.random.seed(seed) - Otherwise, the outputs of these envs will be the same with each other. + Otherwise, the outputs of these envs may be the same with each other. .. _preprocess_fn: diff --git a/docs/tutorials/dqn.rst b/docs/tutorials/dqn.rst index 9cbb2434f..6106ca507 100644 --- a/docs/tutorials/dqn.rst +++ b/docs/tutorials/dqn.rst @@ -40,6 +40,17 @@ Here, we set up 8 environments in ``train_envs`` and 100 environments in ``test_ For the demonstration, here we use the second block of codes. +.. warning:: + + If you use your own environment, please make sure the ``seed`` method is set up properly, e.g., + + :: + + def seed(self, seed): + np.random.seed(seed) + + Otherwise, the outputs of these envs may be the same with each other. + .. _build_the_network: Build the Network diff --git a/tianshou/env/basevecenv.py b/tianshou/env/basevecenv.py index a2b694277..b6c160dab 100644 --- a/tianshou/env/basevecenv.py +++ b/tianshou/env/basevecenv.py @@ -37,7 +37,7 @@ class BaseVectorEnv(ABC, gym.Env): def seed(self, seed): np.random.seed(seed) - Otherwise, the outputs of these envs will be the same with each other. + Otherwise, the outputs of these envs may be the same with each other. """ def __init__(self, env_fns: List[Callable[[], gym.Env]]) -> None: diff --git a/tianshou/policy/modelfree/a2c.py b/tianshou/policy/modelfree/a2c.py index 55c7b6482..2a5c123c0 100644 --- a/tianshou/policy/modelfree/a2c.py +++ b/tianshou/policy/modelfree/a2c.py @@ -108,10 +108,8 @@ def learn(self, batch: Batch, batch_size: int, repeat: int, v = self.critic(b.obs).flatten() a = to_torch_as(b.act, v) r = to_torch_as(b.returns, v) - log_prob = dist.log_prob(a) - # TODO: torch.movedim in version > 1.5.1 - log_prob = log_prob.permute( - *np.roll(range(len(log_prob.shape)), -1)) + log_prob = dist.log_prob(a).reshape( + r.shape[0], -1).transpose(0, 1) a_loss = -(log_prob * (r - v).detach()).mean() vf_loss = F.mse_loss(r, v) ent_loss = dist.entropy().mean() diff --git a/tianshou/policy/modelfree/pg.py b/tianshou/policy/modelfree/pg.py index 91883338a..75a9bff3b 100644 --- a/tianshou/policy/modelfree/pg.py +++ b/tianshou/policy/modelfree/pg.py @@ -90,10 +90,8 @@ def learn(self, batch: Batch, batch_size: int, repeat: int, dist = self(b).dist a = to_torch_as(b.act, dist.logits) r = to_torch_as(b.returns, dist.logits) - # TODO: torch.movedim in version > 1.5.1 - log_prob = dist.log_prob(a) - log_prob = log_prob.permute( - *np.roll(range(len(log_prob.shape)), -1)) + log_prob = dist.log_prob(a).reshape( + r.shape[0], -1).transpose(0, 1) loss = -(log_prob * r).sum() loss.backward() self.optim.step() diff --git a/tianshou/policy/modelfree/ppo.py b/tianshou/policy/modelfree/ppo.py index 33bf1ed0e..2d1decea7 100644 --- a/tianshou/policy/modelfree/ppo.py +++ b/tianshou/policy/modelfree/ppo.py @@ -148,8 +148,7 @@ def learn(self, batch: Batch, batch_size: int, repeat: int, dist = self(b).dist value = self.critic(b.obs).flatten() ratio = (dist.log_prob(b.act) - b.logp_old).exp().float() - # TODO: torch.movedim in version > 1.5.1 - ratio = ratio.permute(*np.roll(range(len(ratio.shape)), -1)) + ratio = ratio.reshape(ratio.size(0), -1).transpose(0, 1) surr1 = ratio * b.adv surr2 = ratio.clamp(1. - self._eps_clip, 1. + self._eps_clip) * b.adv From da5f9ecac4017bc622f683b6131d55c24ddb54d4 Mon Sep 17 00:00:00 2001 From: Trinkle23897 <463003665@qq.com> Date: Fri, 24 Jul 2020 15:12:46 +0800 Subject: [PATCH 7/8] change sum to mean in pg --- test/discrete/test_pg.py | 2 +- tianshou/policy/modelfree/pg.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/test/discrete/test_pg.py b/test/discrete/test_pg.py index c82076204..fabfdc9aa 100644 --- a/test/discrete/test_pg.py +++ b/test/discrete/test_pg.py @@ -86,7 +86,7 @@ def get_args(): parser.add_argument('--task', type=str, default='CartPole-v0') parser.add_argument('--seed', type=int, default=1626) parser.add_argument('--buffer-size', type=int, default=20000) - parser.add_argument('--lr', type=float, default=3e-4) + parser.add_argument('--lr', type=float, default=1e-3) parser.add_argument('--gamma', type=float, default=0.9) parser.add_argument('--epoch', type=int, default=10) parser.add_argument('--step-per-epoch', type=int, default=1000) diff --git a/tianshou/policy/modelfree/pg.py b/tianshou/policy/modelfree/pg.py index 75a9bff3b..d6176e68d 100644 --- a/tianshou/policy/modelfree/pg.py +++ b/tianshou/policy/modelfree/pg.py @@ -92,7 +92,7 @@ def learn(self, batch: Batch, batch_size: int, repeat: int, r = to_torch_as(b.returns, dist.logits) log_prob = dist.log_prob(a).reshape( r.shape[0], -1).transpose(0, 1) - loss = -(log_prob * r).sum() + loss = -(log_prob * r).mean() loss.backward() self.optim.step() losses.append(loss.item()) From 6ca97ab921d6891e67e0e3bd2a4a9facd0a248dc Mon Sep 17 00:00:00 2001 From: Trinkle23897 <463003665@qq.com> Date: Fri, 24 Jul 2020 15:57:36 +0800 Subject: [PATCH 8/8] change flatten to reshape --- tianshou/utils/net/common.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/tianshou/utils/net/common.py b/tianshou/utils/net/common.py index 7a8ef174f..2401ebdea 100644 --- a/tianshou/utils/net/common.py +++ b/tianshou/utils/net/common.py @@ -36,10 +36,7 @@ def __init__(self, layer_num, state_shape, action_shape=0, device='cpu', def forward(self, s, state=None, info={}): """s -> flatten -> logits""" s = to_torch(s, device=self.device, dtype=torch.float32) - if len(s.shape) > 1: - s = s.flatten(1) - else: - s = s.unsqueeze(-1) + s = s.reshape(s.size(0), -1) logits = self.model(s) return logits, state