From ea6250018097ba4118ce58152394bd271904eee1 Mon Sep 17 00:00:00 2001 From: Trinkle23897 <463003665@qq.com> Date: Thu, 20 Aug 2020 08:48:38 +0800 Subject: [PATCH 01/60] collect fake data when buffer is None in Collector --- tianshou/data/buffer.py | 20 ++++++++++---------- tianshou/data/collector.py | 6 +++++- 2 files changed, 15 insertions(+), 11 deletions(-) diff --git a/tianshou/data/buffer.py b/tianshou/data/buffer.py index 24aeb0bcc..0ed135409 100644 --- a/tianshou/data/buffer.py +++ b/tianshou/data/buffer.py @@ -198,12 +198,12 @@ def update(self, buffer: 'ReplayBuffer') -> None: buffer.stack_num = stack_num_orig def add(self, - obs: Union[dict, Batch, np.ndarray], - act: Union[np.ndarray, float], + obs: Union[dict, Batch, np.ndarray, float], + act: Union[dict, Batch, np.ndarray, float], rew: Union[int, float], - done: bool, - obs_next: Optional[Union[dict, Batch, np.ndarray]] = None, - info: dict = {}, + done: Union[bool, int], + obs_next: Optional[Union[dict, Batch, np.ndarray, float]] = None, + info: Optional[Union[dict, Batch]] = {}, policy: Optional[Union[dict, Batch]] = {}, **kwargs) -> None: """Add a batch of data into replay buffer.""" @@ -393,12 +393,12 @@ def __getattr__(self, key: str) -> Union['Batch', Any]: return super().__getattr__(key) def add(self, - obs: Union[dict, np.ndarray], - act: Union[np.ndarray, float], + obs: Union[dict, Batch, np.ndarray, float], + act: Union[dict, Batch, np.ndarray, float], rew: Union[int, float], - done: bool, - obs_next: Optional[Union[dict, np.ndarray]] = None, - info: dict = {}, + done: Union[bool, int], + obs_next: Optional[Union[dict, Batch, np.ndarray, float]] = None, + info: Optional[Union[dict, Batch]] = {}, policy: Optional[Union[dict, Batch]] = {}, weight: float = None, **kwargs) -> None: diff --git a/tianshou/data/collector.py b/tianshou/data/collector.py index 268792a24..29eb34b66 100644 --- a/tianshou/data/collector.py +++ b/tianshou/data/collector.py @@ -282,7 +282,11 @@ def collect(self, for j, i 
in enumerate(self._ready_env_ids): # j is the index in current ready_env_ids # i is the index in all environments - self._cached_buf[i].add(**self.data[j]) + if self.buffer is not None: + self._cached_buf[i].add(**self.data[j]) + else: + self._cached_buf[i].add( + obs=0, act=0, rew=self.data.rew[j], done=0) # fakedata if self.data.done[j]: if n_step or np.isscalar(n_episode) or \ episode_count[i] < n_episode[i]: From b8a797c5f46db0ab0b1d118f14b3358989e99b5d Mon Sep 17 00:00:00 2001 From: youkaichao Date: Thu, 20 Aug 2020 15:48:57 +0800 Subject: [PATCH 02/60] add env_id in info for all environments --- tianshou/env/venvs.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/tianshou/env/venvs.py b/tianshou/env/venvs.py index 504d3e196..7e0f9c17b 100644 --- a/tianshou/env/venvs.py +++ b/tianshou/env/venvs.py @@ -182,7 +182,11 @@ def step(self, assert len(action) == len(id) for i, j in enumerate(id): self.workers[j].send_action(action[i]) - result = [self.workers[j].get_result() for j in id] + result = [] + for j in id: + obs, rew, done, info = self.workers[j].get_result() + info["env_id"] = j + result.append((obs, rew, done, info)) else: if action is not None: self._assert_id(id) From d47c38d3187ce82d9a01fa5dc12fa5dd80806f50 Mon Sep 17 00:00:00 2001 From: Trinkle23897 <463003665@qq.com> Date: Thu, 20 Aug 2020 17:56:33 +0800 Subject: [PATCH 03/60] fix test in collector preprocess_fn --- test/base/test_collector.py | 13 +++++++------ tianshou/policy/multiagent/mapolicy.py | 4 ++-- 2 files changed, 9 insertions(+), 8 deletions(-) diff --git a/test/base/test_collector.py b/test/base/test_collector.py index 1026c9407..1117d592e 100644 --- a/test/base/test_collector.py +++ b/test/base/test_collector.py @@ -42,14 +42,16 @@ def __init__(self, writer): def preprocess_fn(self, **kwargs): # modify info before adding into the buffer, and recorded into tfb - # if info is not provided from env, it will be a ``Batch()``. 
- if not kwargs.get('info', Batch()).is_empty(): + # if obs/act/rew/done/... exist -> normal step + # if only obs exist -> reset + if 'rew' in kwargs: n = len(kwargs['obs']) info = kwargs['info'] for i in range(n): info[i].update(rew=kwargs['rew'][i]) - self.writer.add_scalar('key', np.mean( - info['key']), global_step=self.cnt) + if 'key' in info: + self.writer.add_scalar('key', np.mean( + info['key']), global_step=self.cnt) self.cnt += 1 return Batch(info=info) # or: return {'info': info} @@ -59,13 +61,12 @@ def preprocess_fn(self, **kwargs): @staticmethod def single_preprocess_fn(**kwargs): # same as above, without tfb - if not kwargs.get('info', Batch()).is_empty(): + if 'rew' in kwargs: n = len(kwargs['obs']) info = kwargs['info'] for i in range(n): info[i].update(rew=kwargs['rew'][i]) return Batch(info=info) - # or: return {'info': info} else: return Batch() diff --git a/tianshou/policy/multiagent/mapolicy.py b/tianshou/policy/multiagent/mapolicy.py index f6329888d..c0d991d63 100644 --- a/tianshou/policy/multiagent/mapolicy.py +++ b/tianshou/policy/multiagent/mapolicy.py @@ -64,12 +64,12 @@ def forward(self, batch: Batch, { "act": actions corresponding to the input - "state":{ + "state": { "agent_1": output state of agent_1's policy for the state "agent_2": xxx ... "agent_n": xxx} - "out":{ + "out": { "agent_1": output of agent_1's policy for the input "agent_2": xxx ... 
From 43cbed745c03876469ce5c1fd62fa4bec75cb439 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Thu, 20 Aug 2020 19:47:04 +0800 Subject: [PATCH 04/60] potential bugfix for subproc.wait --- tianshou/env/worker/subproc.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tianshou/env/worker/subproc.py b/tianshou/env/worker/subproc.py index 6ba108eba..3fc992947 100644 --- a/tianshou/env/worker/subproc.py +++ b/tianshou/env/worker/subproc.py @@ -165,11 +165,12 @@ def wait(workers: List['SubprocEnvWorker'], break else: remain_time = timeout - remain_conns = [conn for conn in remain_conns - if conn not in ready_conns] + # connection.wait hangs if the list is empty new_ready_conns = connection.wait( remain_conns, timeout=remain_time) ready_conns.extend(new_ready_conns) + remain_conns = [conn for conn in remain_conns + if conn not in ready_conns] return [workers[conns.index(con)] for con in ready_conns] def send_action(self, action: np.ndarray) -> None: From cea56cbe18801f689823665b0f13c610908b88bd Mon Sep 17 00:00:00 2001 From: youkaichao Date: Thu, 20 Aug 2020 20:25:53 +0800 Subject: [PATCH 05/60] add steps count for test env; copy data for list buffer --- test/base/env.py | 3 +++ tianshou/data/buffer.py | 3 ++- 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/test/base/env.py b/test/base/env.py index cc0991072..f0907f8e1 100644 --- a/test/base/env.py +++ b/test/base/env.py @@ -21,6 +21,8 @@ def __init__(self, size, sleep=0, dict_state=False, recurse_state=False, self.recurse_state = recurse_state self.ma_rew = ma_rew self._md_action = multidiscrete_action + # how many steps this env has stepped + self.steps = 0 if dict_state: self.observation_space = Dict( {"index": Box(shape=(1, ), low=0, high=size - 1), @@ -74,6 +76,7 @@ def _get_state(self): return np.array([self.index], dtype=np.float32) def step(self, action): + self.steps += 1 if self._md_action: action = action[0] if self.done: diff --git a/tianshou/data/buffer.py 
b/tianshou/data/buffer.py index 0ed135409..7672928f8 100644 --- a/tianshou/data/buffer.py +++ b/tianshou/data/buffer.py @@ -1,5 +1,6 @@ import torch import numpy as np +from copy import deepcopy from typing import Any, Tuple, Union, Optional from tianshou.data import Batch, SegmentTree, to_numpy @@ -355,7 +356,7 @@ def _add_to_buffer( return if self._meta.__dict__.get(name, None) is None: self._meta.__dict__[name] = [] - self._meta.__dict__[name].append(inst) + self._meta.__dict__[name].append(deepcopy(inst)) def reset(self) -> None: self._index = self._size = 0 From b0ae34c98a268d25c581e5cec7f87ff95c422b12 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Thu, 20 Aug 2020 20:27:55 +0800 Subject: [PATCH 06/60] enable exact n_episode for each env. --- test/base/test_collector.py | 23 ++++++++++++++++++++++ tianshou/data/collector.py | 39 ++++++++++++++++++++++++++++--------- 2 files changed, 53 insertions(+), 9 deletions(-) diff --git a/test/base/test_collector.py b/test/base/test_collector.py index 1117d592e..8ccce89ac 100644 --- a/test/base/test_collector.py +++ b/test/base/test_collector.py @@ -120,6 +120,28 @@ def test_collector(): c2.collect(n_episode=[1, 1, 1, 1], random=True) +def test_collector_with_exact_episodes(): + env_lens = [2, 6, 3, 10] + writer = SummaryWriter('log/exact_collector') + logger = Logger(writer) + env_fns = [lambda x=i: MyTestEnv(size=x, sleep=0.1, random_sleep=True) + for i in env_lens] + + venv = SubprocVectorEnv(env_fns, wait_num=len(env_fns) - 1) + policy = MyPolicy() + c1 = Collector(policy, venv, + ReplayBuffer(size=1000, ignore_obs_next=False), + logger.preprocess_fn) + n_episode1 = [2, 2, 5, 1] + n_episode2 = [1, 3, 2, 4] + c1.collect(n_episode=n_episode1) + c1.collect(n_episode=n_episode2) + expected_steps = sum( + [a * (b + c) for a, b, c in zip(env_lens, n_episode1, n_episode2)]) + actual_steps = sum(venv.steps) + assert expected_steps == actual_steps + + def test_collector_with_async(): env_lens = [2, 3, 4, 5] writer = 
SummaryWriter('log/async_collector') @@ -242,3 +264,4 @@ def reward_metric(x): test_collector_with_dict_state() test_collector_with_ma() test_collector_with_async() + test_collector_with_exact_episodes() diff --git a/tianshou/data/collector.py b/tianshou/data/collector.py index 29eb34b66..45898d8ef 100644 --- a/tianshou/data/collector.py +++ b/tianshou/data/collector.py @@ -208,6 +208,10 @@ def collect(self, step_count = 0 # episode of each environment episode_count = np.zeros(self.env_num) + # If n_episode is a list, and some envs have collected the required + # number of episodes, these envs will be recorded in this list, + # and they will not be stepped. + finished_env_ids = [] reward_total = 0.0 whole_data = Batch() while True: @@ -217,11 +221,13 @@ def collect(self, 'You should add a time limitation to your environment!', Warning) - if self.is_async: + is_async = self.is_async or len(finished_env_ids) > 0 + if is_async: # self.data are the data for all environments - # in async simulation, only a subset of data are disposed + # in async simulation or some envs have finished, + # **only a subset of data are disposed** # so we store the whole data in ``whole_data``, let self.data - # to be all the data available in ready environments, and + # to be the data available in ready environments, and # finally set these back into all the data whole_data = self.data self.data = self.data[self._ready_env_ids] @@ -256,7 +262,7 @@ def collect(self, self.data.act += self._action_noise(self.data.act.shape) # step in env - if not self.is_async: + if not is_async: obs_next, rew, done, info = self.env.step(self.data.act) else: # store computed actions, states, etc @@ -282,11 +288,15 @@ def collect(self, for j, i in enumerate(self._ready_env_ids): # j is the index in current ready_env_ids # i is the index in all environments - if self.buffer is not None: - self._cached_buf[i].add(**self.data[j]) - else: + if self.buffer is None: + # users do not want to store data + # so we 
store small fake data here + # to make the code clean self._cached_buf[i].add( - obs=0, act=0, rew=self.data.rew[j], done=0) # fakedata + obs=0, act=0, rew=self.data.rew[j], done=0) + else: + self._cached_buf[i].add(**self.data[j]) + if self.data.done[j]: if n_step or np.isscalar(n_episode) or \ episode_count[i] < n_episode[i]: @@ -295,6 +305,11 @@ def collect(self, step_count += len(self._cached_buf[i]) if self.buffer is not None: self.buffer.update(self._cached_buf[i]) + if not n_step and not np.isscalar(n_episode) and \ + episode_count[i] >= n_episode[i]: + # env i has collected enough data + # it has finished + finished_env_ids.append(i) self._cached_buf[i].reset() self._reset_state(j) obs_next = self.data.obs_next @@ -308,12 +323,14 @@ def collect(self, else: obs_next[env_ind_local] = obs_reset self.data.obs = obs_next - if self.is_async: + if is_async: # set data back _batch_set_item(whole_data, self._ready_env_ids, self.data, self.env_num) # let self.data be the data in all environments again self.data = whole_data + self._ready_env_ids = np.asarray( + [x for x in self._ready_env_ids if x not in finished_env_ids]) if n_step: if step_count >= n_step: break @@ -325,6 +342,10 @@ def collect(self, (episode_count >= n_episode).all(): break + # finished envs are ready, and can be used for the next collection + self._ready_env_ids = np.asarray( + self._ready_env_ids.tolist() + finished_env_ids) + # generate the statistics episode_count = sum(episode_count) duration = max(time.time() - start_time, 1e-9) From 6725bfc322f59eabf409bdfb9e8d22736e818eeb Mon Sep 17 00:00:00 2001 From: Trinkle23897 <463003665@qq.com> Date: Thu, 20 Aug 2020 21:05:47 +0800 Subject: [PATCH 07/60] .keys() --- test/base/test_collector.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/test/base/test_collector.py b/test/base/test_collector.py index 8ccce89ac..80a61e603 100644 --- a/test/base/test_collector.py +++ b/test/base/test_collector.py @@ -44,12 +44,12 @@ def 
preprocess_fn(self, **kwargs): # modify info before adding into the buffer, and recorded into tfb # if obs/act/rew/done/... exist -> normal step # if only obs exist -> reset - if 'rew' in kwargs: + if 'rew' in kwargs.keys(): n = len(kwargs['obs']) info = kwargs['info'] for i in range(n): info[i].update(rew=kwargs['rew'][i]) - if 'key' in info: + if 'key' in info.keys(): self.writer.add_scalar('key', np.mean( info['key']), global_step=self.cnt) self.cnt += 1 @@ -61,7 +61,7 @@ def preprocess_fn(self, **kwargs): @staticmethod def single_preprocess_fn(**kwargs): # same as above, without tfb - if 'rew' in kwargs: + if 'rew' in kwargs.keys(): n = len(kwargs['obs']) info = kwargs['info'] for i in range(n): From 9489e7aa33d6f671ccaeda765b161df5df2669cc Mon Sep 17 00:00:00 2001 From: Trinkle23897 <463003665@qq.com> Date: Thu, 20 Aug 2020 21:21:39 +0800 Subject: [PATCH 08/60] test __contains__ --- tianshou/data/batch.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/tianshou/data/batch.py b/tianshou/data/batch.py index ae07023ea..38479d90c 100644 --- a/tianshou/data/batch.py +++ b/tianshou/data/batch.py @@ -173,11 +173,15 @@ def __init__(self, if len(kwargs) > 0: self.__init__(kwargs, copy=copy) - def __setattr__(self, key: str, value: Any): + def __contains__(self, key: str) -> bool: + """Return key in self.""" + return key in self.__dict__ + + def __setattr__(self, key: str, value: Any) -> None: """self.key = value""" self.__dict__[key] = _parse_value(value) - def __getstate__(self): + def __getstate__(self) -> dict: """Pickling interface. Only the actual data are serialized for both efficiency and simplicity. """ @@ -188,7 +192,7 @@ def __getstate__(self): state[k] = v return state - def __setstate__(self, state): + def __setstate__(self, state) -> None: """Unpickling interface. At this point, self is an empty Batch instance that has not been initialized, so it can safely be initialized by the pickle state. 
From 4d943e760c663d3905c171edd1a663f3d82caf66 Mon Sep 17 00:00:00 2001 From: Trinkle23897 <463003665@qq.com> Date: Fri, 21 Aug 2020 08:31:44 +0800 Subject: [PATCH 09/60] fix atari test --- examples/atari/pong_dqn.py | 38 +++++++++++++++++++++++++++++--------- tianshou/data/batch.py | 10 +++++----- 2 files changed, 34 insertions(+), 14 deletions(-) diff --git a/examples/atari/pong_dqn.py b/examples/atari/pong_dqn.py index 6dda89400..8124607de 100644 --- a/examples/atari/pong_dqn.py +++ b/examples/atari/pong_dqn.py @@ -15,15 +15,16 @@ def get_args(): parser = argparse.ArgumentParser() - parser.add_argument('--task', type=str, default='Pong') - parser.add_argument('--seed', type=int, default=1626) - parser.add_argument('--eps-test', type=float, default=0.05) - parser.add_argument('--eps-train', type=float, default=0.1) - parser.add_argument('--buffer-size', type=int, default=20000) - parser.add_argument('--lr', type=float, default=1e-3) - parser.add_argument('--gamma', type=float, default=0.9) - parser.add_argument('--n-step', type=int, default=1) - parser.add_argument('--target-update-freq', type=int, default=320) + parser.add_argument('--task', type=str, default='PongNoFrameskip-v4') + parser.add_argument('--seed', type=int, default=0) + parser.add_argument('--eps_test', type=float, default=0.005) + parser.add_argument('--eps_train', type=float, default=0.5) + parser.add_argument('--eps_train_final', type=float, default=0.05) + parser.add_argument('--buffer-size', type=int, default=100000) + parser.add_argument('--lr', type=float, default=0.0001) + parser.add_argument('--gamma', type=float, default=0.99) + parser.add_argument('--n_step', type=int, default=3) + parser.add_argument('--target_update_freq', type=int, default=500) parser.add_argument('--epoch', type=int, default=100) parser.add_argument('--step-per-epoch', type=int, default=1000) parser.add_argument('--collect-per-step', type=int, default=10) @@ -88,6 +89,25 @@ def train_fn(x): def test_fn(x): 
policy.set_eps(args.eps_test) + # watch agent's performance + def watch(): + print("Testing agent ...") + policy.eval() + policy.set_eps(args.eps_test) + envs = SubprocVectorEnv([lambda: make_atari_env_watch(args) + for _ in range(args.test_num)]) + envs.seed(args.seed) + collector = Collector(policy, envs) + result = collector.collect(n_episode=[1] * args.test_num, + render=args.render) + pprint.pprint(result) + + if args.watch: + watch() + exit(0) + + # test train_collector and start filling replay buffer + train_collector.collect(n_step=args.batch_size * 4) # trainer result = offpolicy_trainer( policy, train_collector, test_collector, args.epoch, diff --git a/tianshou/data/batch.py b/tianshou/data/batch.py index 38479d90c..723a537de 100644 --- a/tianshou/data/batch.py +++ b/tianshou/data/batch.py @@ -173,10 +173,6 @@ def __init__(self, if len(kwargs) > 0: self.__init__(kwargs, copy=copy) - def __contains__(self, key: str) -> bool: - """Return key in self.""" - return key in self.__dict__ - def __setattr__(self, key: str, value: Any) -> None: """self.key = value""" self.__dict__[key] = _parse_value(value) @@ -302,7 +298,7 @@ def __repr__(self) -> str: """Return str(self).""" s = self.__class__.__name__ + '(\n' flag = False - for k, v in self.items(): + for k, v in self.__dict__.items(): rpl = '\n' + ' ' * (6 + len(k)) obj = pprint.pformat(v).replace('\n', rpl) s += f' {k}: {obj},\n' @@ -313,6 +309,10 @@ def __repr__(self) -> str: s = self.__class__.__name__ + '()' return s + def __contains__(self, key: str) -> bool: + """Return key in self.""" + return key in self.__dict__ + def keys(self) -> List[str]: """Return self.keys().""" return self.__dict__.keys() From c1e4fbd65437ac78c533f1a7f52d79bb8935a362 Mon Sep 17 00:00:00 2001 From: Trinkle23897 <463003665@qq.com> Date: Fri, 21 Aug 2020 09:40:29 +0800 Subject: [PATCH 10/60] move deepcopy to collector (whole_data inplace modification cause ListBuffer still keep reference) --- test/base/test_collector.py | 3 +++ 
tianshou/data/buffer.py | 3 +-- tianshou/data/collector.py | 2 ++ 3 files changed, 6 insertions(+), 2 deletions(-) diff --git a/test/base/test_collector.py b/test/base/test_collector.py index 80a61e603..dd2d92ad2 100644 --- a/test/base/test_collector.py +++ b/test/base/test_collector.py @@ -135,6 +135,9 @@ def test_collector_with_exact_episodes(): n_episode1 = [2, 2, 5, 1] n_episode2 = [1, 3, 2, 4] c1.collect(n_episode=n_episode1) + expected_steps = sum([a * b for a, b in zip(env_lens, n_episode1)]) + actual_steps = sum(venv.steps) + assert expected_steps == actual_steps c1.collect(n_episode=n_episode2) expected_steps = sum( [a * (b + c) for a, b, c in zip(env_lens, n_episode1, n_episode2)]) diff --git a/tianshou/data/buffer.py b/tianshou/data/buffer.py index 7672928f8..0ed135409 100644 --- a/tianshou/data/buffer.py +++ b/tianshou/data/buffer.py @@ -1,6 +1,5 @@ import torch import numpy as np -from copy import deepcopy from typing import Any, Tuple, Union, Optional from tianshou.data import Batch, SegmentTree, to_numpy @@ -356,7 +355,7 @@ def _add_to_buffer( return if self._meta.__dict__.get(name, None) is None: self._meta.__dict__[name] = [] - self._meta.__dict__[name].append(deepcopy(inst)) + self._meta.__dict__[name].append(inst) def reset(self) -> None: self._index = self._size = 0 diff --git a/tianshou/data/collector.py b/tianshou/data/collector.py index 45898d8ef..5eb37ea0a 100644 --- a/tianshou/data/collector.py +++ b/tianshou/data/collector.py @@ -3,6 +3,7 @@ import torch import warnings import numpy as np +from copy import deepcopy from typing import Any, Dict, List, Union, Optional, Callable from tianshou.env import BaseVectorEnv, DummyVectorEnv @@ -325,6 +326,7 @@ def collect(self, self.data.obs = obs_next if is_async: # set data back + whole_data = deepcopy(whole_data) # avoid reference in ListBuf _batch_set_item(whole_data, self._ready_env_ids, self.data, self.env_num) # let self.data be the data in all environments again From 
4516b9083de8cedad41d2ea8009abf050d9aa6ca Mon Sep 17 00:00:00 2001 From: Trinkle23897 <463003665@qq.com> Date: Fri, 21 Aug 2020 11:40:57 +0800 Subject: [PATCH 11/60] bypass the attr check for batch.weight, test_dqn training fps 1870 -> 1910 --- tianshou/data/batch.py | 5 ++--- tianshou/data/collector.py | 8 +++----- tianshou/policy/base.py | 4 ++-- 3 files changed, 7 insertions(+), 10 deletions(-) diff --git a/tianshou/data/batch.py b/tianshou/data/batch.py index 723a537de..b86126674 100644 --- a/tianshou/data/batch.py +++ b/tianshou/data/batch.py @@ -101,9 +101,7 @@ def _create_value(inst: Any, size: int, stack=True) -> Union[ dtype=target_type) elif isinstance(inst, torch.Tensor): return torch.full(shape, - fill_value=0, - device=inst.device, - dtype=inst.dtype) + fill_value=0, device=inst.device, dtype=inst.dtype) elif isinstance(inst, (dict, Batch)): zero_batch = Batch() for key, val in inst.items(): @@ -155,6 +153,7 @@ class Batch: For a detailed description, please refer to :ref:`batch_concept`.
""" + def __init__(self, batch_dict: Optional[Union[ dict, 'Batch', Tuple[Union[dict, 'Batch']], diff --git a/tianshou/data/collector.py b/tianshou/data/collector.py index 5eb37ea0a..e6ee4f8d2 100644 --- a/tianshou/data/collector.py +++ b/tianshou/data/collector.py @@ -290,9 +290,8 @@ def collect(self, # j is the index in current ready_env_ids # i is the index in all environments if self.buffer is None: - # users do not want to store data - # so we store small fake data here - # to make the code clean + # users do not want to store data, so we store + # small fake data here to make the code clean self._cached_buf[i].add( obs=0, act=0, rew=self.data.rew[j], done=0) else: @@ -308,8 +307,7 @@ def collect(self, self.buffer.update(self._cached_buf[i]) if not n_step and not np.isscalar(n_episode) and \ episode_count[i] >= n_episode[i]: - # env i has collected enough data - # it has finished + # env i has collected enough data, it has finished finished_env_ids.append(i) self._cached_buf[i].reset() self._reset_state(j) diff --git a/tianshou/policy/base.py b/tianshou/policy/base.py index 01398ca3c..03ab4e36c 100644 --- a/tianshou/policy/base.py +++ b/tianshou/policy/base.py @@ -221,8 +221,8 @@ def compute_nstep_return( # prio buffer update if isinstance(buffer, PrioritizedReplayBuffer): batch.weight = to_torch_as(batch.weight, target_q) - else: - batch.weight = torch.ones_like(target_q) + else: # avoid type check + batch.__dict__['weight'] = 1. 
return batch def post_process_fn(self, batch: Batch, From c8ca9e588111ad685e1c8b71445bb8a96ccff90d Mon Sep 17 00:00:00 2001 From: Trinkle23897 <463003665@qq.com> Date: Fri, 21 Aug 2020 11:55:07 +0800 Subject: [PATCH 12/60] change nstep to_torch and reach 1950+ (near v0.2.4.post1) --- tianshou/policy/base.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tianshou/policy/base.py b/tianshou/policy/base.py index 03ab4e36c..240f61c25 100644 --- a/tianshou/policy/base.py +++ b/tianshou/policy/base.py @@ -213,11 +213,11 @@ def compute_nstep_return( returns[done[now] > 0] = 0 returns = (rew[now] - mean) / std + gamma * returns terminal = (indice + n_step - 1) % buf_len - target_q = target_q_fn(buffer, terminal).flatten() # shape: [bsz, ] + target_q_ = target_q_fn(buffer, terminal).flatten() # shape: (bsz, ) + target_q = to_numpy(target_q_) target_q[gammas != n_step] = 0 - returns = to_torch_as(returns, target_q) - gammas = to_torch_as(gamma ** gammas, target_q) - batch.returns = target_q * gammas + returns + target_q = target_q * (gamma ** gammas) + returns + batch.returns = to_torch_as(target_q, target_q_) # prio buffer update if isinstance(buffer, PrioritizedReplayBuffer): batch.weight = to_torch_as(batch.weight, target_q) From c50de098bfb988e4313625debcb531aede75bb4b Mon Sep 17 00:00:00 2001 From: Trinkle23897 <463003665@qq.com> Date: Fri, 21 Aug 2020 12:10:04 +0800 Subject: [PATCH 13/60] fix a bug in per --- tianshou/policy/base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tianshou/policy/base.py b/tianshou/policy/base.py index 240f61c25..5148cc943 100644 --- a/tianshou/policy/base.py +++ b/tianshou/policy/base.py @@ -220,7 +220,7 @@ def compute_nstep_return( batch.returns = to_torch_as(target_q, target_q_) # prio buffer update if isinstance(buffer, PrioritizedReplayBuffer): - batch.weight = to_torch_as(batch.weight, target_q) + batch.weight = to_torch_as(batch.weight, target_q_) else: # avoid type check 
batch.__dict__['weight'] = 1. return batch From b0e3cd549fec5b1b0dfad8c7e78f42707ac202c1 Mon Sep 17 00:00:00 2001 From: Trinkle23897 <463003665@qq.com> Date: Fri, 21 Aug 2020 13:27:25 +0800 Subject: [PATCH 14/60] batch.pop --- tianshou/data/batch.py | 8 +++++++- tianshou/policy/base.py | 2 -- tianshou/policy/modelfree/ddpg.py | 3 ++- tianshou/policy/modelfree/dqn.py | 3 ++- tianshou/policy/modelfree/sac.py | 5 +++-- tianshou/policy/modelfree/td3.py | 5 +++-- 6 files changed, 17 insertions(+), 9 deletions(-) diff --git a/tianshou/data/batch.py b/tianshou/data/batch.py index b86126674..2541d580c 100644 --- a/tianshou/data/batch.py +++ b/tianshou/data/batch.py @@ -324,10 +324,16 @@ def items(self) -> List[Tuple[str, Any]]: """Return self.items().""" return self.__dict__.items() - def get(self, k: str, d: Optional[Any] = None) -> Union['Batch', Any]: + def get(self, k: str, d: Optional[Any] = None) -> Any: """Return self[k] if k in self else d. d defaults to None.""" return self.__dict__.get(k, d) + def pop(self, k: str, d: Optional[Any] = None) -> Any: + """Return and remove self[k] if k in self else d. d defaults to + None. + """ + return self.__dict__.pop(k, d) + def to_numpy(self) -> None: """Change all torch.Tensor to numpy.ndarray. This is an in-place operation. diff --git a/tianshou/policy/base.py b/tianshou/policy/base.py index 5148cc943..92895b3fa 100644 --- a/tianshou/policy/base.py +++ b/tianshou/policy/base.py @@ -221,8 +221,6 @@ def compute_nstep_return( # prio buffer update if isinstance(buffer, PrioritizedReplayBuffer): batch.weight = to_torch_as(batch.weight, target_q_) - else: # avoid type check - batch.__dict__['weight'] = 1. 
return batch def post_process_fn(self, batch: Batch, diff --git a/tianshou/policy/modelfree/ddpg.py b/tianshou/policy/modelfree/ddpg.py index 79a65d3bb..6c34e34ac 100644 --- a/tianshou/policy/modelfree/ddpg.py +++ b/tianshou/policy/modelfree/ddpg.py @@ -141,10 +141,11 @@ def forward(self, batch: Batch, return Batch(act=actions, state=h) def learn(self, batch: Batch, **kwargs) -> Dict[str, float]: + weight = batch.pop('weight', 1.) current_q = self.critic(batch.obs, batch.act).flatten() target_q = batch.returns.flatten() td = current_q - target_q - critic_loss = (td.pow(2) * batch.weight).mean() + critic_loss = (td.pow(2) * weight).mean() batch.weight = td # prio-buffer self.critic_optim.zero_grad() critic_loss.backward() diff --git a/tianshou/policy/modelfree/dqn.py b/tianshou/policy/modelfree/dqn.py index f1a01a6e7..9d562f1f7 100644 --- a/tianshou/policy/modelfree/dqn.py +++ b/tianshou/policy/modelfree/dqn.py @@ -156,11 +156,12 @@ def learn(self, batch: Batch, **kwargs) -> Dict[str, float]: if self._target and self._cnt % self._freq == 0: self.sync_weight() self.optim.zero_grad() + weight = batch.pop('weight', 1.) q = self(batch, eps=0.).logits q = q[np.arange(len(q)), batch.act] r = to_torch_as(batch.returns, q).flatten() td = r - q - loss = (td.pow(2) * batch.weight).mean() + loss = (td.pow(2) * weight).mean() batch.weight = td # prio-buffer loss.backward() self.optim.step() diff --git a/tianshou/policy/modelfree/sac.py b/tianshou/policy/modelfree/sac.py index 341fe7b11..dfbc60e05 100644 --- a/tianshou/policy/modelfree/sac.py +++ b/tianshou/policy/modelfree/sac.py @@ -137,11 +137,12 @@ def _target_q(self, buffer: ReplayBuffer, return target_q def learn(self, batch: Batch, **kwargs) -> Dict[str, float]: + weight = batch.pop('weight', 1.) 
# critic 1 current_q1 = self.critic1(batch.obs, batch.act).flatten() target_q = batch.returns.flatten() td1 = current_q1 - target_q - critic1_loss = (td1.pow(2) * batch.weight).mean() + critic1_loss = (td1.pow(2) * weight).mean() # critic1_loss = F.mse_loss(current_q1, target_q) self.critic1_optim.zero_grad() critic1_loss.backward() @@ -149,7 +150,7 @@ def learn(self, batch: Batch, **kwargs) -> Dict[str, float]: # critic 2 current_q2 = self.critic2(batch.obs, batch.act).flatten() td2 = current_q2 - target_q - critic2_loss = (td2.pow(2) * batch.weight).mean() + critic2_loss = (td2.pow(2) * weight).mean() # critic2_loss = F.mse_loss(current_q2, target_q) self.critic2_optim.zero_grad() critic2_loss.backward() diff --git a/tianshou/policy/modelfree/td3.py b/tianshou/policy/modelfree/td3.py index 9a340950b..9150f3770 100644 --- a/tianshou/policy/modelfree/td3.py +++ b/tianshou/policy/modelfree/td3.py @@ -115,11 +115,12 @@ def _target_q(self, buffer: ReplayBuffer, return target_q def learn(self, batch: Batch, **kwargs) -> Dict[str, float]: + weight = batch.pop('weight', 1.) 
# critic 1 current_q1 = self.critic1(batch.obs, batch.act).flatten() target_q = batch.returns.flatten() td1 = current_q1 - target_q - critic1_loss = (td1.pow(2) * batch.weight).mean() + critic1_loss = (td1.pow(2) * weight).mean() # critic1_loss = F.mse_loss(current_q1, target_q) self.critic1_optim.zero_grad() critic1_loss.backward() @@ -127,7 +128,7 @@ def learn(self, batch: Batch, **kwargs) -> Dict[str, float]: # critic 2 current_q2 = self.critic2(batch.obs, batch.act).flatten() td2 = current_q2 - target_q - critic2_loss = (td2.pow(2) * batch.weight).mean() + critic2_loss = (td2.pow(2) * weight).mean() # critic2_loss = F.mse_loss(current_q2, target_q) self.critic2_optim.zero_grad() critic2_loss.backward() From d52f4f328c2b51af774ed290b70d33a0485754cf Mon Sep 17 00:00:00 2001 From: Trinkle23897 <463003665@qq.com> Date: Fri, 21 Aug 2020 13:31:20 +0800 Subject: [PATCH 15/60] move previous script to runnable/ --- examples/atari/runnable/atari.py | 133 ++++++++++++++++++++++++++++ examples/atari/runnable/pong_a2c.py | 102 +++++++++++++++++++++ examples/atari/runnable/pong_ppo.py | 106 ++++++++++++++++++++++ 3 files changed, 341 insertions(+) create mode 100644 examples/atari/runnable/atari.py create mode 100644 examples/atari/runnable/pong_a2c.py create mode 100644 examples/atari/runnable/pong_ppo.py diff --git a/examples/atari/runnable/atari.py b/examples/atari/runnable/atari.py new file mode 100644 index 000000000..8e2ea5168 --- /dev/null +++ b/examples/atari/runnable/atari.py @@ -0,0 +1,133 @@ +import cv2 +import gym +import numpy as np +from gym.spaces.box import Box +from tianshou.data import Batch + +SIZE = 84 +FRAME = 4 + + +def create_atari_environment(name=None, sticky_actions=True, + max_episode_steps=2000): + game_version = 'v0' if sticky_actions else 'v4' + name = '{}NoFrameskip-{}'.format(name, game_version) + env = gym.make(name) + env = env.env + env = preprocessing(env, max_episode_steps=max_episode_steps) + return env + + +def preprocess_fn(obs=None, 
act=None, rew=None, done=None, + obs_next=None, info=None, policy=None, **kwargs): + if obs_next is not None: + obs_next = np.reshape(obs_next, (-1, *obs_next.shape[2:])) + obs_next = np.moveaxis(obs_next, 0, -1) + obs_next = cv2.resize(obs_next, (SIZE, SIZE)) + obs_next = np.asanyarray(obs_next, dtype=np.uint8) + obs_next = np.reshape(obs_next, (-1, FRAME, SIZE, SIZE)) + obs_next = np.moveaxis(obs_next, 1, -1) + elif obs is not None: + obs = np.reshape(obs, (-1, *obs.shape[2:])) + obs = np.moveaxis(obs, 0, -1) + obs = cv2.resize(obs, (SIZE, SIZE)) + obs = np.asanyarray(obs, dtype=np.uint8) + obs = np.reshape(obs, (-1, FRAME, SIZE, SIZE)) + obs = np.moveaxis(obs, 1, -1) + + return Batch(obs=obs, act=act, rew=rew, done=done, + obs_next=obs_next, info=info) + + +class preprocessing(object): + def __init__(self, env, frame_skip=4, terminal_on_life_loss=False, + size=84, max_episode_steps=2000): + self.max_episode_steps = max_episode_steps + self.env = env + self.terminal_on_life_loss = terminal_on_life_loss + self.frame_skip = frame_skip + self.size = size + self.count = 0 + obs_dims = self.env.observation_space + + self.screen_buffer = [ + np.empty((obs_dims.shape[0], obs_dims.shape[1]), dtype=np.uint8), + np.empty((obs_dims.shape[0], obs_dims.shape[1]), dtype=np.uint8) + ] + + self.game_over = False + self.lives = 0 + + @property + def observation_space(self): + return Box(low=0, high=255, + shape=(self.size, self.size, self.frame_skip), + dtype=np.uint8) + + def action_space(self): + return self.env.action_space + + def reward_range(self): + return self.env.reward_range + + def metadata(self): + return self.env.metadata + + def close(self): + return self.env.close() + + def reset(self): + self.count = 0 + self.env.reset() + self.lives = self.env.ale.lives() + self._grayscale_obs(self.screen_buffer[0]) + self.screen_buffer[1].fill(0) + + return np.array([self._pool_and_resize() + for _ in range(self.frame_skip)]) + + def render(self, mode='human'): + return 
self.env.render(mode) + + def step(self, action): + total_reward = 0. + observation = [] + for t in range(self.frame_skip): + self.count += 1 + _, reward, terminal, info = self.env.step(action) + total_reward += reward + + if self.terminal_on_life_loss: + lives = self.env.ale.lives() + is_terminal = terminal or lives < self.lives + self.lives = lives + else: + is_terminal = terminal + + if is_terminal: + break + elif t >= self.frame_skip - 2: + t_ = t - (self.frame_skip - 2) + self._grayscale_obs(self.screen_buffer[t_]) + + observation.append(self._pool_and_resize()) + if len(observation) == 0: + observation = [self._pool_and_resize() + for _ in range(self.frame_skip)] + while len(observation) > 0 and \ + len(observation) < self.frame_skip: + observation.append(observation[-1]) + terminal = self.count >= self.max_episode_steps + return np.array(observation), total_reward, \ + (terminal or is_terminal), info + + def _grayscale_obs(self, output): + self.env.ale.getScreenGrayscale(output) + return output + + def _pool_and_resize(self): + if self.frame_skip > 1: + np.maximum(self.screen_buffer[0], self.screen_buffer[1], + out=self.screen_buffer[0]) + + return self.screen_buffer[0] diff --git a/examples/atari/runnable/pong_a2c.py b/examples/atari/runnable/pong_a2c.py new file mode 100644 index 000000000..f4b0a3031 --- /dev/null +++ b/examples/atari/runnable/pong_a2c.py @@ -0,0 +1,102 @@ +import torch +import pprint +import argparse +import numpy as np +from torch.utils.tensorboard import SummaryWriter + +from tianshou.policy import A2CPolicy +from tianshou.env import SubprocVectorEnv +from tianshou.trainer import onpolicy_trainer +from tianshou.data import Collector, ReplayBuffer +from tianshou.utils.net.discrete import Actor, Critic +from tianshou.utils.net.common import Net + +from atari import create_atari_environment, preprocess_fn + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument('--task', type=str, default='Pong') + 
parser.add_argument('--seed', type=int, default=1626) + parser.add_argument('--buffer-size', type=int, default=20000) + parser.add_argument('--lr', type=float, default=3e-4) + parser.add_argument('--gamma', type=float, default=0.9) + parser.add_argument('--epoch', type=int, default=100) + parser.add_argument('--step-per-epoch', type=int, default=1000) + parser.add_argument('--collect-per-step', type=int, default=10) + parser.add_argument('--repeat-per-collect', type=int, default=1) + parser.add_argument('--batch-size', type=int, default=64) + parser.add_argument('--layer-num', type=int, default=2) + parser.add_argument('--training-num', type=int, default=8) + parser.add_argument('--test-num', type=int, default=8) + parser.add_argument('--logdir', type=str, default='log') + parser.add_argument('--render', type=float, default=0.) + + parser.add_argument( + '--device', type=str, + default='cuda' if torch.cuda.is_available() else 'cpu') + # a2c special + parser.add_argument('--vf-coef', type=float, default=0.5) + parser.add_argument('--ent-coef', type=float, default=0.001) + parser.add_argument('--max-grad-norm', type=float, default=None) + parser.add_argument('--max_episode_steps', type=int, default=2000) + return parser.parse_args() + + +def test_a2c(args=get_args()): + env = create_atari_environment(args.task) + args.state_shape = env.observation_space.shape or env.observation_space.n + args.action_shape = env.env.action_space.shape or env.env.action_space.n + # train_envs = gym.make(args.task) + train_envs = SubprocVectorEnv( + [lambda: create_atari_environment(args.task) + for _ in range(args.training_num)]) + # test_envs = gym.make(args.task) + test_envs = SubprocVectorEnv( + [lambda: create_atari_environment(args.task) + for _ in range(args.test_num)]) + # seed + np.random.seed(args.seed) + torch.manual_seed(args.seed) + train_envs.seed(args.seed) + test_envs.seed(args.seed) + # model + net = Net(args.layer_num, args.state_shape, device=args.device) + actor = 
Actor(net, args.action_shape).to(args.device) + critic = Critic(net).to(args.device) + optim = torch.optim.Adam(list( + actor.parameters()) + list(critic.parameters()), lr=args.lr) + dist = torch.distributions.Categorical + policy = A2CPolicy( + actor, critic, optim, dist, args.gamma, vf_coef=args.vf_coef, + ent_coef=args.ent_coef, max_grad_norm=args.max_grad_norm) + # collector + train_collector = Collector( + policy, train_envs, ReplayBuffer(args.buffer_size), + preprocess_fn=preprocess_fn) + test_collector = Collector(policy, test_envs, preprocess_fn=preprocess_fn) + # log + writer = SummaryWriter(args.logdir + '/' + 'a2c') + + def stop_fn(x): + if env.env.spec.reward_threshold: + return x >= env.spec.reward_threshold + else: + return False + + # trainer + result = onpolicy_trainer( + policy, train_collector, test_collector, args.epoch, + args.step_per_epoch, args.collect_per_step, args.repeat_per_collect, + args.test_num, args.batch_size, stop_fn=stop_fn, writer=writer) + if __name__ == '__main__': + pprint.pprint(result) + # Let's watch its performance! 
+ env = create_atari_environment(args.task) + collector = Collector(policy, env, preprocess_fn=preprocess_fn) + result = collector.collect(n_episode=1, render=args.render) + print(f'Final reward: {result["rew"]}, length: {result["len"]}') + + +if __name__ == '__main__': + test_a2c() diff --git a/examples/atari/runnable/pong_ppo.py b/examples/atari/runnable/pong_ppo.py new file mode 100644 index 000000000..9d5563fe1 --- /dev/null +++ b/examples/atari/runnable/pong_ppo.py @@ -0,0 +1,106 @@ +import torch +import pprint +import argparse +import numpy as np +from torch.utils.tensorboard import SummaryWriter + +from tianshou.policy import PPOPolicy +from tianshou.env import SubprocVectorEnv +from tianshou.trainer import onpolicy_trainer +from tianshou.data import Collector, ReplayBuffer +from tianshou.utils.net.discrete import Actor, Critic +from tianshou.utils.net.common import Net + +from atari import create_atari_environment, preprocess_fn + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument('--task', type=str, default='Pong') + parser.add_argument('--seed', type=int, default=1626) + parser.add_argument('--buffer-size', type=int, default=20000) + parser.add_argument('--lr', type=float, default=1e-3) + parser.add_argument('--gamma', type=float, default=0.99) + parser.add_argument('--epoch', type=int, default=100) + parser.add_argument('--step-per-epoch', type=int, default=1000) + parser.add_argument('--collect-per-step', type=int, default=10) + parser.add_argument('--repeat-per-collect', type=int, default=2) + parser.add_argument('--batch-size', type=int, default=64) + parser.add_argument('--layer-num', type=int, default=1) + parser.add_argument('--training-num', type=int, default=8) + parser.add_argument('--test-num', type=int, default=8) + parser.add_argument('--logdir', type=str, default='log') + parser.add_argument('--render', type=float, default=0.) 
+ parser.add_argument( + '--device', type=str, + default='cuda' if torch.cuda.is_available() else 'cpu') + # ppo special + parser.add_argument('--vf-coef', type=float, default=0.5) + parser.add_argument('--ent-coef', type=float, default=0.0) + parser.add_argument('--eps-clip', type=float, default=0.2) + parser.add_argument('--max-grad-norm', type=float, default=0.5) + parser.add_argument('--max_episode_steps', type=int, default=2000) + return parser.parse_args() + + +def test_ppo(args=get_args()): + env = create_atari_environment(args.task) + args.state_shape = env.observation_space.shape or env.observation_space.n + args.action_shape = env.action_space().shape or env.action_space().n + # train_envs = gym.make(args.task) + train_envs = SubprocVectorEnv([ + lambda: create_atari_environment(args.task) + for _ in range(args.training_num)]) + # test_envs = gym.make(args.task) + test_envs = SubprocVectorEnv([ + lambda: create_atari_environment(args.task) + for _ in range(args.test_num)]) + # seed + np.random.seed(args.seed) + torch.manual_seed(args.seed) + train_envs.seed(args.seed) + test_envs.seed(args.seed) + # model + net = Net(args.layer_num, args.state_shape, device=args.device) + actor = Actor(net, args.action_shape).to(args.device) + critic = Critic(net).to(args.device) + optim = torch.optim.Adam(list( + actor.parameters()) + list(critic.parameters()), lr=args.lr) + dist = torch.distributions.Categorical + policy = PPOPolicy( + actor, critic, optim, dist, args.gamma, + max_grad_norm=args.max_grad_norm, + eps_clip=args.eps_clip, + vf_coef=args.vf_coef, + ent_coef=args.ent_coef, + action_range=None) + # collector + train_collector = Collector( + policy, train_envs, ReplayBuffer(args.buffer_size), + preprocess_fn=preprocess_fn) + test_collector = Collector(policy, test_envs, preprocess_fn=preprocess_fn) + # log + writer = SummaryWriter(args.logdir + '/' + 'ppo') + + def stop_fn(x): + if env.env.spec.reward_threshold: + return x >= env.spec.reward_threshold + else: 
+ return False + + # trainer + result = onpolicy_trainer( + policy, train_collector, test_collector, args.epoch, + args.step_per_epoch, args.collect_per_step, args.repeat_per_collect, + args.test_num, args.batch_size, stop_fn=stop_fn, writer=writer) + if __name__ == '__main__': + pprint.pprint(result) + # Let's watch its performance! + env = create_atari_environment(args.task) + collector = Collector(policy, env, preprocess_fn=preprocess_fn) + result = collector.collect(n_step=2000, render=args.render) + print(f'Final reward: {result["rew"]}, length: {result["len"]}') + + +if __name__ == '__main__': + test_ppo() From 6be957c406062dc6f67a0f135c938402c58576b5 Mon Sep 17 00:00:00 2001 From: Trinkle23897 <463003665@qq.com> Date: Fri, 21 Aug 2020 13:32:52 +0800 Subject: [PATCH 16/60] rename --- examples/box2d/{sac_mcc.py => mcc_sac.py} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename examples/box2d/{sac_mcc.py => mcc_sac.py} (100%) diff --git a/examples/box2d/sac_mcc.py b/examples/box2d/mcc_sac.py similarity index 100% rename from examples/box2d/sac_mcc.py rename to examples/box2d/mcc_sac.py From 1367e2c4f1050516222bd127379775a29f320c39 Mon Sep 17 00:00:00 2001 From: Trinkle23897 <463003665@qq.com> Date: Fri, 21 Aug 2020 17:15:24 +0800 Subject: [PATCH 17/60] little enhancement by modifying _parse_value --- docs/tutorials/cheatsheet.rst | 7 ++++--- test/base/test_collector.py | 3 +-- tianshou/data/batch.py | 35 +++++++++++++++++------------------ tianshou/data/buffer.py | 3 ++- tianshou/data/collector.py | 28 +++++++++++++--------------- 5 files changed, 37 insertions(+), 39 deletions(-) diff --git a/docs/tutorials/cheatsheet.rst b/docs/tutorials/cheatsheet.rst index 0f42f8198..18e86a084 100644 --- a/docs/tutorials/cheatsheet.rst +++ b/docs/tutorials/cheatsheet.rst @@ -96,7 +96,7 @@ This is related to `Issue 42 `_. 
If you want to get log stat from data stream / pre-process batch-image / modify the reward with given env info, use ``preproces_fn`` in :class:`~tianshou.data.Collector`. This is a hook which will be called before the data adding into the buffer. -This function receives typically 7 keys, as listed in :class:`~tianshou.data.Batch`, and returns the modified part within a dict or a Batch. For example, you can write your hook as: +This function receives typically up to 7 keys, as listed in :class:`~tianshou.data.Batch`, and returns the modified part within a Batch. For example, you can write your hook as: :: import numpy as np @@ -109,9 +109,11 @@ This function receives typically 7 keys, as listed in :class:`~tianshou.data.Bat self.baseline = 0 def preprocess_fn(**kwargs): """change reward to zero mean""" + # if only obs exist -> reset + # if obs/act/rew/done/... exist -> normal step if 'rew' not in kwargs: # means that it is called after env.reset(), it can only process the obs - return {} # none of the variables are needed to be updated + return Batch() # none of the variables are needed to be updated else: n = len(kwargs['rew']) # the number of envs in collector if self.episode_log is None: @@ -125,7 +127,6 @@ This function receives typically 7 keys, as listed in :class:`~tianshou.data.Bat self.episode_log[i] = [] self.baseline = np.mean(self.main_log) return Batch(rew=kwargs['rew']) - # you can also return with {'rew': kwargs['rew']} And finally, :: diff --git a/test/base/test_collector.py b/test/base/test_collector.py index dd2d92ad2..63279fc69 100644 --- a/test/base/test_collector.py +++ b/test/base/test_collector.py @@ -42,8 +42,8 @@ def __init__(self, writer): def preprocess_fn(self, **kwargs): # modify info before adding into the buffer, and recorded into tfb - # if obs/act/rew/done/... exist -> normal step # if only obs exist -> reset + # if obs/act/rew/done/... 
exist -> normal step if 'rew' in kwargs.keys(): n = len(kwargs['obs']) info = kwargs['info'] @@ -54,7 +54,6 @@ def preprocess_fn(self, **kwargs): info['key']), global_step=self.cnt) self.cnt += 1 return Batch(info=info) - # or: return {'info': info} else: return Batch() diff --git a/tianshou/data/batch.py b/tianshou/data/batch.py index 2541d580c..6769e49a7 100644 --- a/tianshou/data/batch.py +++ b/tianshou/data/batch.py @@ -48,10 +48,8 @@ def _is_number(value: Any) -> bool: # isinstance(value, Number) checks 1, 1.0, np.int(1), np.float(1.0), etc. # isinstance(value, np.nummber) checks np.int32(1), np.float64(1.0), etc. # isinstance(value, np.bool_) checks np.bool_(True), etc. - is_number = isinstance(value, Number) - is_number = is_number or isinstance(value, np.number) - is_number = is_number or isinstance(value, np.bool_) - return is_number + # similar to np.isscalar but np.isscalar('st') returns True + return isinstance(value, (Number, np.number, np.bool_)) def _to_array_with_correct_type(v: Any) -> np.ndarray: @@ -120,10 +118,14 @@ def _assert_type_keys(keys): def _parse_value(v: Any): - if isinstance(v, dict): - v = Batch(v) - elif isinstance(v, (Batch, torch.Tensor)): - pass + if _is_number(v): + return np.asanyarray(v) + elif isinstance(v, np.ndarray) and \ + issubclass(v.dtype.type, (np.bool_, np.number)) or \ + isinstance(v, (Batch, torch.Tensor)): + return v + elif isinstance(v, dict): + return Batch(v) else: if not isinstance(v, np.ndarray) and isinstance(v, Collection) and \ len(v) > 0 and all(isinstance(e, torch.Tensor) for e in v): @@ -132,18 +134,17 @@ def _parse_value(v: Any): except RuntimeError as e: raise TypeError("Batch does not support non-stackable iterable" " of torch.Tensor as unique value yet.") from e - try: - v_ = _to_array_with_correct_type(v) - except ValueError as e: - raise TypeError("Batch does not support heterogeneous list/tuple" - " of tensors as unique value yet.") from e if _is_batch_set(v): v = Batch(v) # list of dict / Batch 
else: # None, scalar, normal data list (main case) # or an actual list of objects - v = v_ - return v + try: + v = _to_array_with_correct_type(v) + except ValueError as e: + raise TypeError("Batch does not support heterogeneous list/" + "tuple of tensors as unique value yet.") from e + return v class Batch: @@ -642,10 +643,8 @@ def update(self, batch: Optional[Union[dict, 'Batch']] = None, if batch is None: self.update(kwargs) return - if isinstance(batch, dict): - batch = Batch(batch) for k, v in batch.items(): - self.__dict__[k] = v + self.__dict__[k] = _parse_value(v) if kwargs: self.update(kwargs) diff --git a/tianshou/data/buffer.py b/tianshou/data/buffer.py index 0ed135409..7e69f7540 100644 --- a/tianshou/data/buffer.py +++ b/tianshou/data/buffer.py @@ -163,7 +163,8 @@ def _add_to_buffer(self, name: str, inst: Any) -> None: except KeyError: self._meta.__dict__[name] = _create_value(inst, self._maxsize) value = self._meta.__dict__[name] - if isinstance(inst, np.ndarray) and value.shape[1:] != inst.shape: + if isinstance(inst, (np.ndarray, torch.Tensor)) \ + and value.shape[1:] != inst.shape: raise ValueError( "Cannot add data to a buffer with different shape, with key " f"{name}, expect {value.shape[1:]}, given {inst.shape}.") diff --git a/tianshou/data/collector.py b/tianshou/data/collector.py index e6ee4f8d2..0ea08e71b 100644 --- a/tianshou/data/collector.py +++ b/tianshou/data/collector.py @@ -98,7 +98,6 @@ def __init__(self, self.is_async = env.is_async # need cache buffers before storing in the main buffer self._cached_buf = [ListReplayBuffer() for _ in range(self.env_num)] - self.collect_time, self.collect_step, self.collect_episode = 0., 0, 0 self.buffer = buffer self.policy = policy self.preprocess_fn = preprocess_fn @@ -107,8 +106,6 @@ def __init__(self, self._action_noise = action_noise self._rew_metric = reward_metric or Collector._default_rew_metric # avoid creating attribute outside __init__ - self.data = Batch(state={}, obs={}, act={}, rew={}, 
done={}, info={}, - obs_next={}, policy={}) self.reset() @staticmethod @@ -252,14 +249,15 @@ def collect(self, # convert None to Batch(), since None is reserved for 0-init if state is None: state = Batch() - self.data.update(state=state, policy=result.get('policy', Batch())) + # since result is a Batch, it can bypass the type check here + self.data.__dict__['state'] = state + self.data.__dict__['policy'] = result.get('policy', Batch()) # save hidden state to policy._state, in order to save into buffer - if not (isinstance(self.data.state, Batch) - and self.data.state.is_empty()): - self.data.policy._state = self.data.state + if not (isinstance(state, Batch) and state.is_empty()): + self.data.policy.__dict__['_state'] = self.data.state self.data.act = to_numpy(result.act) - if self._action_noise is not None: + if self._action_noise is not None: # noqa self.data.act += self._action_noise(self.data.act.shape) # step in env @@ -271,7 +269,7 @@ def collect(self, self.data, self.env_num) # fetch finished data obs_next, rew, done, info = self.env.step( - action=self.data.act, id=self._ready_env_ids) + self.data.act, id=self._ready_env_ids) self._ready_env_ids = np.array([i['env_id'] for i in info]) # get the stepped data self.data = whole_data[self._ready_env_ids] @@ -286,18 +284,18 @@ def collect(self, if self.preprocess_fn: result = self.preprocess_fn(**self.data) self.data.update(result) + for j, i in enumerate(self._ready_env_ids): # j is the index in current ready_env_ids # i is the index in all environments if self.buffer is None: # users do not want to store data, so we store # small fake data here to make the code clean - self._cached_buf[i].add( - obs=0, act=0, rew=self.data.rew[j], done=0) + self._cached_buf[i].add(obs=0, act=0, rew=rew[j], done=0) else: self._cached_buf[i].add(**self.data[j]) - if self.data.done[j]: + if done[j]: if n_step or np.isscalar(n_episode) or \ episode_count[i] < n_episode[i]: episode_count[i] += 1 @@ -312,8 +310,8 @@ def collect(self, 
self._cached_buf[i].reset() self._reset_state(j) obs_next = self.data.obs_next - if sum(self.data.done): - env_ind_local = np.where(self.data.done)[0] + if sum(done): + env_ind_local = np.where(done)[0] env_ind_global = self._ready_env_ids[env_ind_local] obs_reset = self.env.reset(env_ind_global) if self.preprocess_fn: @@ -321,7 +319,7 @@ def collect(self, obs=obs_reset).get('obs', obs_reset) else: obs_next[env_ind_local] = obs_reset - self.data.obs = obs_next + self.data.__dict__['obs'] = obs_next if is_async: # set data back whole_data = deepcopy(whole_data) # avoid reference in ListBuf From cdbf8f65051d25e7181cb22236e0a893942b9718 Mon Sep 17 00:00:00 2001 From: Trinkle23897 <463003665@qq.com> Date: Fri, 21 Aug 2020 17:26:59 +0800 Subject: [PATCH 18/60] add max_batchsize in a2c and ppo --- tianshou/policy/modelfree/a2c.py | 8 ++++++-- tianshou/policy/modelfree/ppo.py | 6 ++++-- 2 files changed, 10 insertions(+), 4 deletions(-) diff --git a/tianshou/policy/modelfree/a2c.py b/tianshou/policy/modelfree/a2c.py index 52d8dd248..1d835de67 100644 --- a/tianshou/policy/modelfree/a2c.py +++ b/tianshou/policy/modelfree/a2c.py @@ -25,6 +25,10 @@ class A2CPolicy(PGPolicy): defaults to ``None``. :param float gae_lambda: in [0, 1], param for Generalized Advantage Estimation, defaults to 0.95. + :param bool reward_normalization: normalize the reward to Normal(0, 1), + defaults to ``False``. + :param int max_batchsize: the maximum number of batchsize when computing + GAE, defaults to 2000. .. 
seealso:: @@ -44,6 +48,7 @@ def __init__(self, max_grad_norm: Optional[float] = None, gae_lambda: float = 0.95, reward_normalization: bool = False, + max_batchsize: int = 2000, **kwargs) -> None: super().__init__(None, optim, dist_fn, discount_factor, **kwargs) self.actor = actor @@ -53,7 +58,7 @@ def __init__(self, self._w_vf = vf_coef self._w_ent = ent_coef self._grad_norm = max_grad_norm - self._batch = 64 + self._batch = max_batchsize self._rew_norm = reward_normalization def process_fn(self, batch: Batch, buffer: ReplayBuffer, @@ -97,7 +102,6 @@ def forward(self, batch: Batch, def learn(self, batch: Batch, batch_size: int, repeat: int, **kwargs) -> Dict[str, List[float]]: - self._batch = batch_size losses, actor_losses, vf_losses, ent_losses = [], [], [], [] for _ in range(repeat): for b in batch.split(batch_size): diff --git a/tianshou/policy/modelfree/ppo.py b/tianshou/policy/modelfree/ppo.py index 3094be82e..3f1fea4fb 100644 --- a/tianshou/policy/modelfree/ppo.py +++ b/tianshou/policy/modelfree/ppo.py @@ -34,6 +34,8 @@ class PPOPolicy(PGPolicy): defaults to ``True``. :param bool reward_normalization: normalize the returns to Normal(0, 1), defaults to ``True``. + :param int max_batchsize: the maximum number of batchsize when computing + GAE, defaults to 2000. .. seealso:: @@ -56,6 +58,7 @@ def __init__(self, dual_clip: Optional[float] = None, value_clip: bool = True, reward_normalization: bool = True, + max_batchsize: int = 2000, **kwargs) -> None: super().__init__(None, None, dist_fn, discount_factor, **kwargs) self._max_grad_norm = max_grad_norm @@ -66,7 +69,7 @@ def __init__(self, self.actor = actor self.critic = critic self.optim = optim - self._batch = 64 + self._batch = max_batchsize assert 0 <= gae_lambda <= 1, 'GAE lambda should be in [0, 1].' 
self._lambda = gae_lambda assert dual_clip is None or dual_clip > 1, \ @@ -132,7 +135,6 @@ def forward(self, batch: Batch, def learn(self, batch: Batch, batch_size: int, repeat: int, **kwargs) -> Dict[str, List[float]]: - self._batch = batch_size losses, clip_losses, vf_losses, ent_losses = [], [], [], [] for _ in range(repeat): for b in batch.split(batch_size): From 7585f49adb85b50b052318dce433d26942d0dc8a Mon Sep 17 00:00:00 2001 From: Trinkle23897 <463003665@qq.com> Date: Fri, 21 Aug 2020 17:50:39 +0800 Subject: [PATCH 19/60] move test_gae to test/base/test_returns --- test/base/test_returns.py | 72 +++++++++++++++++++++++++++++++ test/discrete/test_pg.py | 74 ++------------------------------ tianshou/data/utils/converter.py | 15 +++++-- 3 files changed, 87 insertions(+), 74 deletions(-) create mode 100644 test/base/test_returns.py diff --git a/test/base/test_returns.py b/test/base/test_returns.py new file mode 100644 index 000000000..d74ec5bd3 --- /dev/null +++ b/test/base/test_returns.py @@ -0,0 +1,72 @@ +import time +import numpy as np + +from tianshou.data import Batch +from tianshou.policy import BasePolicy + + +def compute_episodic_return_base(batch, aa=None, bb=None, gamma=0.1): + returns = np.zeros_like(batch.rew) + last = 0 + for i in reversed(range(len(batch.rew))): + returns[i] = batch.rew[i] + if not batch.done[i]: + returns[i] += last * gamma + last = returns[i] + batch.returns = returns + return batch + + +def test_episodic_returns(size=2560): + fn = BasePolicy.compute_episodic_return + batch = Batch( + done=np.array([1, 0, 0, 1, 0, 1, 0, 1.]), + rew=np.array([0, 1, 2, 3, 4, 5, 6, 7.]), + ) + batch = fn(batch, None, gamma=.1, gae_lambda=1) + ans = np.array([0, 1.23, 2.3, 3, 4.5, 5, 6.7, 7]) + assert np.allclose(batch.returns, ans) + batch = Batch( + done=np.array([0, 1, 0, 1, 0, 1, 0.]), + rew=np.array([7, 6, 1, 2, 3, 4, 5.]), + ) + batch = fn(batch, None, gamma=.1, gae_lambda=1) + ans = np.array([7.6, 6, 1.2, 2, 3.4, 4, 5]) + assert 
np.allclose(batch.returns, ans) + batch = Batch( + done=np.array([0, 1, 0, 1, 0, 0, 1.]), + rew=np.array([7, 6, 1, 2, 3, 4, 5.]), + ) + batch = fn(batch, None, gamma=.1, gae_lambda=1) + ans = np.array([7.6, 6, 1.2, 2, 3.45, 4.5, 5]) + assert np.allclose(batch.returns, ans) + batch = Batch( + done=np.array([0, 0, 0, 1., 0, 0, 0, 1, 0, 0, 0, 1]), + rew=np.array([ + 101, 102, 103., 200, 104, 105, 106, 201, 107, 108, 109, 202]) + ) + v = np.array([2., 3., 4, -1, 5., 6., 7, -2, 8., 9., 10, -3]) + ret = fn(batch, v, gamma=0.99, gae_lambda=0.95) + returns = np.array([ + 454.8344, 376.1143, 291.298, 200., + 464.5610, 383.1085, 295.387, 201., + 474.2876, 390.1027, 299.476, 202.]) + assert np.allclose(ret.returns, returns) + if __name__ == '__main__': + batch = Batch( + done=np.random.randint(100, size=size) == 0, + rew=np.random.random(size), + ) + cnt = 3000 + t = time.time() + for _ in range(cnt): + compute_episodic_return_base(batch) + print(f'vanilla: {(time.time() - t) / cnt}') + t = time.time() + for _ in range(cnt): + fn(batch, None, gamma=.1, gae_lambda=1) + print(f'policy: {(time.time() - t) / cnt}') + + +if __name__ == '__main__': + test_episodic_returns() diff --git a/test/discrete/test_pg.py b/test/discrete/test_pg.py index ee934a340..7a3876bc1 100644 --- a/test/discrete/test_pg.py +++ b/test/discrete/test_pg.py @@ -1,84 +1,16 @@ import os import gym -import time import torch import pprint import argparse import numpy as np from torch.utils.tensorboard import SummaryWriter -from tianshou.utils.net.common import Net -from tianshou.env import DummyVectorEnv from tianshou.policy import PGPolicy +from tianshou.env import DummyVectorEnv +from tianshou.utils.net.common import Net from tianshou.trainer import onpolicy_trainer -from tianshou.data import Batch, Collector, ReplayBuffer - - -def compute_return_base(batch, aa=None, bb=None, gamma=0.1): - returns = np.zeros_like(batch.rew) - last = 0 - for i in reversed(range(len(batch.rew))): - returns[i] = batch.rew[i] - 
if not batch.done[i]: - returns[i] += last * gamma - last = returns[i] - batch.returns = returns - return batch - - -def test_fn(size=2560): - policy = PGPolicy(None, None, None, discount_factor=0.1) - buf = ReplayBuffer(100) - buf.add(1, 1, 1, 1, 1) - fn = policy.process_fn - # fn = compute_return_base - batch = Batch( - done=np.array([1, 0, 0, 1, 0, 1, 0, 1.]), - rew=np.array([0, 1, 2, 3, 4, 5, 6, 7.]), - ) - batch = fn(batch, buf, 0) - ans = np.array([0, 1.23, 2.3, 3, 4.5, 5, 6.7, 7]) - assert np.allclose(batch.returns, ans) - batch = Batch( - done=np.array([0, 1, 0, 1, 0, 1, 0.]), - rew=np.array([7, 6, 1, 2, 3, 4, 5.]), - ) - batch = fn(batch, buf, 0) - ans = np.array([7.6, 6, 1.2, 2, 3.4, 4, 5]) - assert np.allclose(batch.returns, ans) - batch = Batch( - done=np.array([0, 1, 0, 1, 0, 0, 1.]), - rew=np.array([7, 6, 1, 2, 3, 4, 5.]), - ) - batch = fn(batch, buf, 0) - ans = np.array([7.6, 6, 1.2, 2, 3.45, 4.5, 5]) - assert np.allclose(batch.returns, ans) - batch = Batch( - done=np.array([0, 0, 0, 1., 0, 0, 0, 1, 0, 0, 0, 1]), - rew=np.array([ - 101, 102, 103., 200, 104, 105, 106, 201, 107, 108, 109, 202]) - ) - v = np.array([2., 3., 4, -1, 5., 6., 7, -2, 8., 9., 10, -3]) - ret = policy.compute_episodic_return(batch, v, gamma=0.99, gae_lambda=0.95) - returns = np.array([ - 454.8344, 376.1143, 291.298, 200., - 464.5610, 383.1085, 295.387, 201., - 474.2876, 390.1027, 299.476, 202.]) - assert np.allclose(ret.returns, returns) - if __name__ == '__main__': - batch = Batch( - done=np.random.randint(100, size=size) == 0, - rew=np.random.random(size), - ) - cnt = 3000 - t = time.time() - for _ in range(cnt): - compute_return_base(batch) - print(f'vanilla: {(time.time() - t) / cnt}') - t = time.time() - for _ in range(cnt): - policy.process_fn(batch, buf, 0) - print(f'policy: {(time.time() - t) / cnt}') +from tianshou.data import Collector, ReplayBuffer def get_args(): diff --git a/tianshou/data/utils/converter.py b/tianshou/data/utils/converter.py index 
92a9db0f6..8e3ff0bf6 100644 --- a/tianshou/data/utils/converter.py +++ b/tianshou/data/utils/converter.py @@ -10,8 +10,12 @@ def to_numpy(x: Union[ Batch, dict, list, tuple, np.ndarray, torch.Tensor]) -> Union[ Batch, dict, list, tuple, np.ndarray, torch.Tensor]: """Return an object without torch.Tensor.""" - if isinstance(x, torch.Tensor): + if isinstance(x, np.ndarray): + pass + elif isinstance(x, torch.Tensor): x = x.detach().cpu().numpy() + elif isinstance(x, (np.number, np.bool_, Number)): + x = np.asanyarray(x) elif isinstance(x, dict): for k, v in x.items(): x[k] = to_numpy(v) @@ -36,13 +40,18 @@ def to_torch(x: Union[Batch, dict, list, tuple, np.ndarray, torch.Tensor], if dtype is not None: x = x.type(dtype) x = x.to(device) + elif isinstance(x, (np.number, np.bool_, Number)): + x = to_torch(np.asanyarray(x), dtype, device) + elif isinstance(x, np.ndarray) and \ + issubclass(x.dtype.type, (np.bool_, np.number)): + x = torch.from_numpy(x).to(device) + if dtype is not None: + x = x.type(dtype) elif isinstance(x, dict): for k, v in x.items(): x[k] = to_torch(v, dtype, device) elif isinstance(x, Batch): x.to_torch(dtype, device) - elif isinstance(x, (np.number, np.bool_, Number)): - x = to_torch(np.asanyarray(x), dtype, device) elif isinstance(x, (list, tuple)): try: x = to_torch(_parse_value(x), dtype, device) From 543a7d9a17f33ca0fdfb0e305553048124291c72 Mon Sep 17 00:00:00 2001 From: Trinkle23897 <463003665@qq.com> Date: Fri, 21 Aug 2020 18:47:02 +0800 Subject: [PATCH 20/60] add test_nstep --- test/base/test_buffer.py | 8 +++---- test/base/test_collector.py | 45 +++++++++++++++++-------------------- test/base/test_returns.py | 38 ++++++++++++++++++++++++++++--- test/discrete/test_drqn.py | 2 +- test/discrete/test_pg.py | 1 - 5 files changed, 60 insertions(+), 34 deletions(-) diff --git a/test/base/test_buffer.py b/test/base/test_buffer.py index 393534c03..1f564f683 100644 --- a/test/base/test_buffer.py +++ b/test/base/test_buffer.py @@ -89,10 +89,10 @@ def 
test_stack(size=5, bufsize=9, stack_num=4): if done: obs = env.reset(1) indice = np.arange(len(buf)) - assert np.allclose(buf.get(indice, 'obs'), np.expand_dims( - [[1, 1, 1, 2], [1, 1, 2, 3], [1, 2, 3, 4], - [1, 1, 1, 1], [1, 1, 1, 2], [1, 1, 2, 3], - [1, 2, 3, 4], [4, 4, 4, 4], [1, 1, 1, 1]], axis=-1)) + assert np.allclose(buf.get(indice, 'obs')[..., 0], [ + [1, 1, 1, 2], [1, 1, 2, 3], [1, 2, 3, 4], + [1, 1, 1, 1], [1, 1, 1, 2], [1, 1, 2, 3], + [1, 2, 3, 4], [4, 4, 4, 4], [1, 1, 1, 1]]) _, indice = buf2.sample(0) assert indice.tolist() == [2, 6] _, indice = buf2.sample(1) diff --git a/test/base/test_collector.py b/test/base/test_collector.py index 63279fc69..c678fdcab 100644 --- a/test/base/test_collector.py +++ b/test/base/test_collector.py @@ -82,40 +82,36 @@ def test_collector(): c0 = Collector(policy, env, ReplayBuffer(size=100, ignore_obs_next=False), logger.preprocess_fn) c0.collect(n_step=3) - assert np.allclose(c0.buffer.obs[:4], - np.expand_dims([0, 1, 0, 1], axis=-1)) - assert np.allclose(c0.buffer[:4].obs_next, - np.expand_dims([1, 2, 1, 2], axis=-1)) + assert np.allclose(c0.buffer.obs[:4, 0], [0, 1, 0, 1]) + assert np.allclose(c0.buffer[:4].obs_next[..., 0], [1, 2, 1, 2]) c0.collect(n_episode=3) - assert np.allclose(c0.buffer.obs[:10], - np.expand_dims([0, 1, 0, 1, 0, 1, 0, 1, 0, 1], axis=-1)) - assert np.allclose(c0.buffer[:10].obs_next, - np.expand_dims([1, 2, 1, 2, 1, 2, 1, 2, 1, 2], axis=-1)) + assert np.allclose(c0.buffer.obs[:10, 0], [0, 1, 0, 1, 0, 1, 0, 1, 0, 1]) + assert np.allclose(c0.buffer[:10].obs_next[..., 0], + [1, 2, 1, 2, 1, 2, 1, 2, 1, 2]) c0.collect(n_step=3, random=True) c1 = Collector(policy, venv, ReplayBuffer(size=100, ignore_obs_next=False), logger.preprocess_fn) c1.collect(n_step=6) - assert np.allclose(c1.buffer.obs[:11], np.expand_dims( - [0, 1, 0, 1, 2, 0, 1, 0, 1, 2, 3], axis=-1)) - assert np.allclose(c1.buffer[:11].obs_next, np.expand_dims( - [1, 2, 1, 2, 3, 1, 2, 1, 2, 3, 4], axis=-1)) + assert 
np.allclose(c1.buffer.obs[:11, 0], + [0, 1, 0, 1, 2, 0, 1, 0, 1, 2, 3]) + assert np.allclose(c1.buffer[:11].obs_next[..., 0], + [1, 2, 1, 2, 3, 1, 2, 1, 2, 3, 4]) c1.collect(n_episode=2) - assert np.allclose(c1.buffer.obs[11:21], - np.expand_dims([0, 1, 2, 3, 4, 0, 1, 0, 1, 2], axis=-1)) - assert np.allclose(c1.buffer[11:21].obs_next, - np.expand_dims([1, 2, 3, 4, 5, 1, 2, 1, 2, 3], axis=-1)) + assert np.allclose(c1.buffer.obs[11:21, 0], [0, 1, 2, 3, 4, 0, 1, 0, 1, 2]) + assert np.allclose(c1.buffer[11:21].obs_next[..., 0], + [1, 2, 3, 4, 5, 1, 2, 1, 2, 3]) c1.collect(n_episode=3, random=True) c2 = Collector(policy, dum, ReplayBuffer(size=100, ignore_obs_next=False), logger.preprocess_fn) c2.collect(n_episode=[1, 2, 2, 2]) - assert np.allclose(c2.buffer.obs_next[:26], np.expand_dims([ + assert np.allclose(c2.buffer.obs_next[:26, 0], [ 1, 2, 1, 2, 3, 1, 2, 3, 4, 1, 2, 3, 4, 5, - 1, 2, 3, 1, 2, 3, 4, 1, 2, 3, 4, 5], axis=-1)) + 1, 2, 3, 1, 2, 3, 4, 1, 2, 3, 4, 5]) c2.reset_env() c2.collect(n_episode=[2, 2, 2, 2]) - assert np.allclose(c2.buffer.obs_next[26:54], np.expand_dims([ + assert np.allclose(c2.buffer.obs_next[26:54, 0], [ 1, 2, 1, 2, 3, 1, 2, 1, 2, 3, 4, 1, 2, 3, 4, 5, - 1, 2, 3, 1, 2, 3, 4, 1, 2, 3, 4, 5], axis=-1)) + 1, 2, 3, 1, 2, 3, 4, 1, 2, 3, 4, 5]) c2.collect(n_episode=[1, 1, 1, 1], random=True) @@ -210,10 +206,10 @@ def test_collector_with_dict_state(): batch, _ = c1.buffer.sample(10) print(batch) c0.buffer.update(c1.buffer) - assert np.allclose(c0.buffer[:len(c0.buffer)].obs.index, np.expand_dims([ + assert np.allclose(c0.buffer[:len(c0.buffer)].obs.index[..., 0], [ 0., 1., 2., 3., 4., 0., 1., 2., 3., 4., 0., 1., 2., 3., 4., 0., 1., 0., 1., 2., 0., 1., 0., 1., 2., 3., 0., 1., 2., 3., 4., 0., 1., 0., - 1., 2., 0., 1., 0., 1., 2., 3., 0., 1., 2., 3., 4.], axis=-1)) + 1., 2., 0., 1., 0., 1., 2., 3., 0., 1., 2., 3., 4.]) c2 = Collector(policy, envs, ReplayBuffer(size=100, stack_num=4), Logger.single_preprocess_fn) c2.collect(n_episode=[0, 0, 0, 10]) @@ 
-244,11 +240,10 @@ def reward_metric(x): batch, _ = c1.buffer.sample(10) print(batch) c0.buffer.update(c1.buffer) - obs = np.array(np.expand_dims([ + assert np.allclose(c0.buffer[:len(c0.buffer)].obs[..., 0], [ 0., 1., 2., 3., 4., 0., 1., 2., 3., 4., 0., 1., 2., 3., 4., 0., 1., 0., 1., 2., 0., 1., 0., 1., 2., 3., 0., 1., 2., 3., 4., 0., 1., 0., - 1., 2., 0., 1., 0., 1., 2., 3., 0., 1., 2., 3., 4.], axis=-1)) - assert np.allclose(c0.buffer[:len(c0.buffer)].obs, obs) + 1., 2., 0., 1., 0., 1., 2., 3., 0., 1., 2., 3., 4.]) rew = [0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 1, 0, 0, 1, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 1, 0, 0, 1, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 1] diff --git a/test/base/test_returns.py b/test/base/test_returns.py index d74ec5bd3..fa4bd7b04 100644 --- a/test/base/test_returns.py +++ b/test/base/test_returns.py @@ -1,11 +1,12 @@ import time +import torch import numpy as np -from tianshou.data import Batch from tianshou.policy import BasePolicy +from tianshou.data import Batch, ReplayBuffer -def compute_episodic_return_base(batch, aa=None, bb=None, gamma=0.1): +def compute_episodic_return_base(batch, gamma): returns = np.zeros_like(batch.rew) last = 0 for i in reversed(range(len(batch.rew))): @@ -60,7 +61,7 @@ def test_episodic_returns(size=2560): cnt = 3000 t = time.time() for _ in range(cnt): - compute_episodic_return_base(batch) + compute_episodic_return_base(batch, gamma=.1) print(f'vanilla: {(time.time() - t) / cnt}') t = time.time() for _ in range(cnt): @@ -68,5 +69,36 @@ def test_episodic_returns(size=2560): print(f'policy: {(time.time() - t) / cnt}') +def target_q_fn(buffer, indice): + print('target_q_fn:', indice) + indice = (indice + 1 - buffer.done[indice]) % len(buffer) + return torch.tensor(-buffer.rew[indice], dtype=torch.float32) + + +def test_nstep_returns(): + buf = ReplayBuffer(10) + for i in range(12): + buf.add(obs=0, act=0, rew=i + 1, done=i % 4 == 3) + batch, indice = buf.sample(0) + assert np.allclose(indice, [2, 3, 4, 5, 6, 7, 
8, 9, 0, 1]) + # rew: [10, 11, 2, 3, 4, 5, 6, 7, 8, 9] + # done: [ 0, 1, 0, 1, 0, 0, 0, 1, 0, 0] + # test nstep = 1 + returns = BasePolicy.compute_nstep_return( + batch, buf, indice, target_q_fn, gamma=.1, n_step=1).pop('returns') + assert np.allclose(returns, [2.6, 4, 4.4, 5.3, 6.2, 8, 8, 8.9, 9.8, 12]) + # test nstep = 2 + returns = BasePolicy.compute_nstep_return( + batch, buf, indice, target_q_fn, gamma=.1, n_step=2).pop('returns') + assert np.allclose(returns, [ + 3.4, 4, 5.53, 6.62, 7.8, 8, 9.89, 10.98, 12.2, 12]) + # test nstep = 10 + returns = BasePolicy.compute_nstep_return( + batch, buf, indice, target_q_fn, gamma=.1, n_step=10).pop('returns') + assert np.allclose(returns, [ + 3.4, 4, 5.678, 6.78, 7.8, 8, 10.122, 11.22, 12.2, 12]) + + if __name__ == '__main__': + test_nstep_returns() test_episodic_returns() diff --git a/test/discrete/test_drqn.py b/test/discrete/test_drqn.py index c4d976715..000a31ec5 100644 --- a/test/discrete/test_drqn.py +++ b/test/discrete/test_drqn.py @@ -16,7 +16,7 @@ def get_args(): parser = argparse.ArgumentParser() parser.add_argument('--task', type=str, default='CartPole-v0') - parser.add_argument('--seed', type=int, default=0) + parser.add_argument('--seed', type=int, default=1) parser.add_argument('--eps-test', type=float, default=0.05) parser.add_argument('--eps-train', type=float, default=0.1) parser.add_argument('--buffer-size', type=int, default=20000) diff --git a/test/discrete/test_pg.py b/test/discrete/test_pg.py index 7a3876bc1..ec5a003ca 100644 --- a/test/discrete/test_pg.py +++ b/test/discrete/test_pg.py @@ -93,5 +93,4 @@ def stop_fn(x): if __name__ == '__main__': - # test_fn() test_pg() From 76817420d7c8946f57e8f566b065a6a2662d11f1 Mon Sep 17 00:00:00 2001 From: Trinkle23897 <463003665@qq.com> Date: Fri, 21 Aug 2020 18:50:10 +0800 Subject: [PATCH 21/60] increase drqn gamma --- test/discrete/test_drqn.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/discrete/test_drqn.py 
b/test/discrete/test_drqn.py index 000a31ec5..a5e8c4338 100644 --- a/test/discrete/test_drqn.py +++ b/test/discrete/test_drqn.py @@ -22,7 +22,7 @@ def get_args(): parser.add_argument('--buffer-size', type=int, default=20000) parser.add_argument('--stack-num', type=int, default=4) parser.add_argument('--lr', type=float, default=1e-3) - parser.add_argument('--gamma', type=float, default=0.9) + parser.add_argument('--gamma', type=float, default=0.95) parser.add_argument('--n-step', type=int, default=4) parser.add_argument('--target-update-freq', type=int, default=320) parser.add_argument('--epoch', type=int, default=10) From a9c1b2e2e402a409bf901d842f59a9a486b300b2 Mon Sep 17 00:00:00 2001 From: Trinkle23897 <463003665@qq.com> Date: Fri, 21 Aug 2020 20:21:08 +0800 Subject: [PATCH 22/60] performance improvement (+50) by analyzing traces --- test/base/test_returns.py | 2 +- tianshou/data/batch.py | 28 ++++++++++++++-------------- tianshou/data/utils/converter.py | 18 +++++++++--------- 3 files changed, 24 insertions(+), 24 deletions(-) diff --git a/test/base/test_returns.py b/test/base/test_returns.py index fa4bd7b04..664968541 100644 --- a/test/base/test_returns.py +++ b/test/base/test_returns.py @@ -70,7 +70,7 @@ def test_episodic_returns(size=2560): def target_q_fn(buffer, indice): - print('target_q_fn:', indice) + # return the next reward indice = (indice + 1 - buffer.done[indice]) % len(buffer) return torch.tensor(-buffer.rew[indice], dtype=torch.float32) diff --git a/tianshou/data/batch.py b/tianshou/data/batch.py index 6769e49a7..29e407b53 100644 --- a/tianshou/data/batch.py +++ b/tianshou/data/batch.py @@ -18,14 +18,14 @@ def _is_batch_set(data: Any) -> bool: # Batch set is a list/tuple of dict/Batch objects, # or 1-D np.ndarray with np.object type, # where each element is a dict/Batch object - if isinstance(data, (list, tuple)): - if len(data) > 0 and all(isinstance(e, (dict, Batch)) for e in data): - return True - elif isinstance(data, np.ndarray) and 
data.dtype == np.object: + if isinstance(data, np.ndarray): # most often case # ``for e in data`` will just unpack the first dimension, # but data.tolist() will flatten ndarray of objects # so do not use data.tolist() - if all(isinstance(e, (dict, Batch)) for e in data): + return data.dtype == np.object and \ + all(isinstance(e, (dict, Batch)) for e in data) + elif isinstance(data, (list, tuple)): + if len(data) > 0 and all(isinstance(e, (dict, Batch)) for e in data): return True return False @@ -53,6 +53,9 @@ def _is_number(value: Any) -> bool: def _to_array_with_correct_type(v: Any) -> np.ndarray: + if isinstance(v, np.ndarray) and \ + issubclass(v.dtype.type, (np.bool_, np.number)): # most often case + return v # convert the value to np.ndarray # convert to np.object data type if neither bool nor number # raises an exception if array's elements are tensors themself @@ -111,18 +114,19 @@ def _create_value(inst: Any, size: int, stack=True) -> Union[ return np.array([None for _ in range(size)]) -def _assert_type_keys(keys): - keys = list(keys) +def _assert_type_keys(keys) -> None: assert all(isinstance(e, str) for e in keys), \ f"keys should all be string, but got {keys}" def _parse_value(v: Any): - if _is_number(v): + if isinstance(v, Batch): # most often case + return v + elif _is_number(v): # second often case return np.asanyarray(v) - elif isinstance(v, np.ndarray) and \ + elif v is None or isinstance(v, np.ndarray) and \ issubclass(v.dtype.type, (np.bool_, np.number)) or \ - isinstance(v, (Batch, torch.Tensor)): + isinstance(v, (Batch, torch.Tensor)): # third often case return v elif isinstance(v, dict): return Batch(v) @@ -405,7 +409,6 @@ def __cat(self, for batch in batches] keys_shared = set.intersection(*keys_map) values_shared = [[e[k] for e in batches] for k in keys_shared] - _assert_type_keys(keys_shared) for k, v in zip(keys_shared, values_shared): if all(isinstance(e, (dict, Batch)) for e in v): batch_holder = Batch() @@ -421,7 +424,6 @@ def 
__cat(self, self.__dict__[k] = v keys_total = set.union(*[set(b.keys()) for b in batches]) keys_reserve_or_partial = set.difference(keys_total, keys_shared) - _assert_type_keys(keys_reserve_or_partial) # keys that are reserved in all batches keys_reserve = set.difference(keys_total, set.union(*keys_map)) # keys that occur only in some batches, but not all @@ -513,7 +515,6 @@ def stack_(self, for batch in batches] keys_shared = set.intersection(*keys_map) values_shared = [[e[k] for e in batches] for k in keys_shared] - _assert_type_keys(keys_shared) for k, v in zip(keys_shared, values_shared): if all(isinstance(e, (dict, Batch)) for e in v): self.__dict__[k] = Batch.stack(v, axis) @@ -535,7 +536,6 @@ def stack_(self, raise ValueError( f"Stack of Batch with non-shared keys {keys_partial} " f"is only supported with axis=0, but got axis={axis}!") - _assert_type_keys(keys_reserve_or_partial) for k in keys_reserve: # reserved keys self.__dict__[k] = Batch() diff --git a/tianshou/data/utils/converter.py b/tianshou/data/utils/converter.py index 8e3ff0bf6..7c90e5f09 100644 --- a/tianshou/data/utils/converter.py +++ b/tianshou/data/utils/converter.py @@ -10,10 +10,10 @@ def to_numpy(x: Union[ Batch, dict, list, tuple, np.ndarray, torch.Tensor]) -> Union[ Batch, dict, list, tuple, np.ndarray, torch.Tensor]: """Return an object without torch.Tensor.""" - if isinstance(x, np.ndarray): - pass - elif isinstance(x, torch.Tensor): + if isinstance(x, torch.Tensor): # most often case x = x.detach().cpu().numpy() + elif isinstance(x, np.ndarray): # second often case + pass elif isinstance(x, (np.number, np.bool_, Number)): x = np.asanyarray(x) elif isinstance(x, dict): @@ -36,17 +36,17 @@ def to_torch(x: Union[Batch, dict, list, tuple, np.ndarray, torch.Tensor], device: Union[str, int, torch.device] = 'cpu' ) -> Union[Batch, dict, list, tuple, np.ndarray, torch.Tensor]: """Return an object without np.ndarray.""" - if isinstance(x, torch.Tensor): + if isinstance(x, np.ndarray) and \ + 
issubclass(x.dtype.type, (np.bool_, np.number)): # most often case + x = torch.from_numpy(x).to(device) + if dtype is not None: + x = x.type(dtype) + elif isinstance(x, torch.Tensor): # second often case if dtype is not None: x = x.type(dtype) x = x.to(device) elif isinstance(x, (np.number, np.bool_, Number)): x = to_torch(np.asanyarray(x), dtype, device) - elif isinstance(x, np.ndarray) and \ - issubclass(x.dtype.type, (np.bool_, np.number)): - x = torch.from_numpy(x).to(device) - if dtype is not None: - x = x.type(dtype) elif isinstance(x, dict): for k, v in x.items(): x[k] = to_torch(v, dtype, device) From 4d24149a9be9cdc7321eaadcc0661c88ba2d4627 Mon Sep 17 00:00:00 2001 From: Trinkle23897 <463003665@qq.com> Date: Fri, 21 Aug 2020 21:11:17 +0800 Subject: [PATCH 23/60] add policy.eval() before watching its performance --- README.md | 2 ++ docs/tutorials/dqn.rst | 2 ++ docs/tutorials/tictactoe.rst | 2 ++ examples/atari/pong_dqn.py | 10 ++++------ examples/box2d/acrobot_dualdqn.py | 9 ++++++--- examples/box2d/bipedal_hardcore_sac.py | 18 +++++++++++------- examples/box2d/lunarlander_dqn.py | 9 ++++++--- examples/box2d/mcc_sac.py | 8 +++++--- examples/mujoco/ant_v2_ddpg.py | 8 +++++--- examples/mujoco/ant_v2_sac.py | 8 +++++--- examples/mujoco/ant_v2_td3.py | 8 +++++--- examples/mujoco/halfcheetahBullet_v0_sac.py | 8 +++++--- examples/mujoco/point_maze_td3.py | 8 +++++--- test/continuous/test_ddpg.py | 1 + test/continuous/test_ppo.py | 1 + test/continuous/test_sac_with_il.py | 3 +++ test/continuous/test_td3.py | 1 + test/discrete/test_a2c_with_il.py | 3 +++ test/discrete/test_dqn.py | 2 ++ test/discrete/test_drqn.py | 1 + test/discrete/test_pg.py | 1 + test/discrete/test_ppo.py | 1 + test/multiagent/tic_tac_toe.py | 2 ++ 23 files changed, 79 insertions(+), 37 deletions(-) diff --git a/README.md b/README.md index 1764051cd..c86c40529 100644 --- a/README.md +++ b/README.md @@ -247,6 +247,8 @@ policy.load_state_dict(torch.load('dqn.pth')) Watch the performance with 35 
FPS: ```python +policy.eval() +policy.set_eps(eps_test) collector = ts.data.Collector(policy, env) collector.collect(n_episode=1, render=1 / 35) ``` diff --git a/docs/tutorials/dqn.rst b/docs/tutorials/dqn.rst index e01760058..9655ee8e1 100644 --- a/docs/tutorials/dqn.rst +++ b/docs/tutorials/dqn.rst @@ -176,6 +176,8 @@ Watch the Agent's Performance :class:`~tianshou.data.Collector` supports rendering. Here is the example of watching the agent's performance in 35 FPS: :: + policy.eval() + policy.set_eps(0.05) collector = ts.data.Collector(policy, env) collector.collect(n_episode=1, render=1 / 35) diff --git a/docs/tutorials/tictactoe.rst b/docs/tutorials/tictactoe.rst index 6ab79d800..0a20bf969 100644 --- a/docs/tutorials/tictactoe.rst +++ b/docs/tutorials/tictactoe.rst @@ -285,6 +285,8 @@ With the above preparation, we are close to the first learned agent. The followi env = TicTacToeEnv(args.board_size, args.win_size) policy, optim = get_agents( args, agent_learn=agent_learn, agent_opponent=agent_opponent) + policy.eval() + policy.set_eps(args.eps_test) collector = Collector(policy, env) result = collector.collect(n_episode=1, render=args.render) print(f'Final reward: {result["rew"]}, length: {result["len"]}') diff --git a/examples/atari/pong_dqn.py b/examples/atari/pong_dqn.py index 8124607de..ccbaee80d 100644 --- a/examples/atari/pong_dqn.py +++ b/examples/atari/pong_dqn.py @@ -94,12 +94,10 @@ def watch(): print("Testing agent ...") policy.eval() policy.set_eps(args.eps_test) - envs = SubprocVectorEnv([lambda: make_atari_env_watch(args) - for _ in range(args.test_num)]) - envs.seed(args.seed) - collector = Collector(policy, envs) - result = collector.collect(n_episode=[1] * args.test_num, - render=args.render) + test_envs.seed(args.seed) + test_collector.reset() + result = test_collector.collect(n_episode=[1] * args.test_num, + render=args.render) pprint.pprint(result) if args.watch: diff --git a/examples/box2d/acrobot_dualdqn.py 
b/examples/box2d/acrobot_dualdqn.py index e3de12de7..6345d62eb 100644 --- a/examples/box2d/acrobot_dualdqn.py +++ b/examples/box2d/acrobot_dualdqn.py @@ -102,9 +102,12 @@ def test_fn(x): if __name__ == '__main__': pprint.pprint(result) # Let's watch its performance! - env = gym.make(args.task) - collector = Collector(policy, env) - result = collector.collect(n_episode=1, render=args.render) + policy.eval() + policy.set_eps(args.eps_test) + test_envs.seed(args.seed) + test_collector.reset() + result = test_collector.collect(n_episode=[1] * args.test_num, + render=args.render) print(f'Final reward: {result["rew"]}, length: {result["len"]}') diff --git a/examples/box2d/bipedal_hardcore_sac.py b/examples/box2d/bipedal_hardcore_sac.py index 31b83f43b..a92963d83 100644 --- a/examples/box2d/bipedal_hardcore_sac.py +++ b/examples/box2d/bipedal_hardcore_sac.py @@ -44,6 +44,7 @@ def get_args(): class EnvWrapper(object): """Env wrapper for reward scale, action repeat and action noise""" + def __init__(self, task, action_repeat=3, reward_scale=5, act_noise=0.3): self._env = gym.make(task) @@ -71,10 +72,11 @@ def step(self, action): def test_sac_bipedal(args=get_args()): torch.set_num_threads(1) # we just need only one thread for NN + env = EnvWrapper(args.task) + def IsStop(reward): - return reward >= 300 * 5 + return reward >= env.spec.reward_threshold - env = EnvWrapper(args.task) args.state_shape = env.observation_space.shape or env.observation_space.n args.action_shape = env.action_space.shape or env.action_space.n args.max_action = env.action_space.high[0] @@ -82,8 +84,8 @@ def IsStop(reward): train_envs = SubprocVectorEnv( [lambda: EnvWrapper(args.task) for _ in range(args.training_num)]) # test_envs = gym.make(args.task) - test_envs = SubprocVectorEnv( - [lambda: EnvWrapper(args.task) for _ in range(args.test_num)]) + test_envs = SubprocVectorEnv([lambda: EnvWrapper(args.task, reward_scale=1) + for _ in range(args.test_num)]) # seed np.random.seed(args.seed) @@ -138,9 
+140,11 @@ def save_fn(policy): if __name__ == '__main__': pprint.pprint(result) # Let's watch its performance! - env = EnvWrapper(args.task) - collector = Collector(policy, env) - result = collector.collect(n_episode=16, render=args.render) + policy.eval() + test_envs.seed(args.seed) + test_collector.reset() + result = test_collector.collect(n_episode=[1] * args.test_num, + render=args.render) print(f'Final reward: {result["rew"]}, length: {result["len"]}') diff --git a/examples/box2d/lunarlander_dqn.py b/examples/box2d/lunarlander_dqn.py index 0e66c65f7..aa0f5888c 100644 --- a/examples/box2d/lunarlander_dqn.py +++ b/examples/box2d/lunarlander_dqn.py @@ -99,9 +99,12 @@ def test_fn(x): if __name__ == '__main__': pprint.pprint(result) # Let's watch its performance! - env = gym.make(args.task) - collector = Collector(policy, env) - result = collector.collect(n_episode=1, render=args.render) + policy.eval() + policy.set_eps(args.eps_test) + test_envs.seed(args.seed) + test_collector.reset() + result = test_collector.collect(n_episode=[1] * args.test_num, + render=args.render) print(f'Final reward: {result["rew"]}, length: {result["len"]}') diff --git a/examples/box2d/mcc_sac.py b/examples/box2d/mcc_sac.py index 845ffcd7b..6e09e6c1f 100644 --- a/examples/box2d/mcc_sac.py +++ b/examples/box2d/mcc_sac.py @@ -112,9 +112,11 @@ def stop_fn(x): if __name__ == '__main__': pprint.pprint(result) # Let's watch its performance! 
- env = gym.make(args.task) - collector = Collector(policy, env) - result = collector.collect(n_episode=1, render=args.render) + policy.eval() + test_envs.seed(args.seed) + test_collector.reset() + result = test_collector.collect(n_episode=[1] * args.test_num, + render=args.render) print(f'Final reward: {result["rew"]}, length: {result["len"]}') diff --git a/examples/mujoco/ant_v2_ddpg.py b/examples/mujoco/ant_v2_ddpg.py index ef7ea6c42..db65f582a 100644 --- a/examples/mujoco/ant_v2_ddpg.py +++ b/examples/mujoco/ant_v2_ddpg.py @@ -88,9 +88,11 @@ def stop_fn(x): if __name__ == '__main__': pprint.pprint(result) # Let's watch its performance! - env = gym.make(args.task) - collector = Collector(policy, env) - result = collector.collect(n_episode=1, render=args.render) + policy.eval() + test_envs.seed(args.seed) + test_collector.reset() + result = test_collector.collect(n_episode=[1] * args.test_num, + render=args.render) print(f'Final reward: {result["rew"]}, length: {result["len"]}') diff --git a/examples/mujoco/ant_v2_sac.py b/examples/mujoco/ant_v2_sac.py index 402784f28..108be79e3 100644 --- a/examples/mujoco/ant_v2_sac.py +++ b/examples/mujoco/ant_v2_sac.py @@ -98,9 +98,11 @@ def stop_fn(x): if __name__ == '__main__': pprint.pprint(result) # Let's watch its performance! - env = gym.make(args.task) - collector = Collector(policy, env) - result = collector.collect(n_episode=1, render=args.render) + policy.eval() + test_envs.seed(args.seed) + test_collector.reset() + result = test_collector.collect(n_episode=[1] * args.test_num, + render=args.render) print(f'Final reward: {result["rew"]}, length: {result["len"]}') diff --git a/examples/mujoco/ant_v2_td3.py b/examples/mujoco/ant_v2_td3.py index fad3f911c..db59e18d5 100644 --- a/examples/mujoco/ant_v2_td3.py +++ b/examples/mujoco/ant_v2_td3.py @@ -98,9 +98,11 @@ def stop_fn(x): if __name__ == '__main__': pprint.pprint(result) # Let's watch its performance! 
- env = gym.make(args.task) - collector = Collector(policy, env) - result = collector.collect(n_episode=1, render=args.render) + policy.eval() + test_envs.seed(args.seed) + test_collector.reset() + result = test_collector.collect(n_episode=[1] * args.test_num, + render=args.render) print(f'Final reward: {result["rew"]}, length: {result["len"]}') diff --git a/examples/mujoco/halfcheetahBullet_v0_sac.py b/examples/mujoco/halfcheetahBullet_v0_sac.py index 8f1a103e4..3aec4f85a 100644 --- a/examples/mujoco/halfcheetahBullet_v0_sac.py +++ b/examples/mujoco/halfcheetahBullet_v0_sac.py @@ -104,9 +104,11 @@ def stop_fn(x): if __name__ == '__main__': pprint.pprint(result) # Let's watch its performance! - env = gym.make(args.task) - collector = Collector(policy, env) - result = collector.collect(n_episode=1, render=args.render) + policy.eval() + test_envs.seed(args.seed) + test_collector.reset() + result = test_collector.collect(n_episode=[1] * args.test_num, + render=args.render) print(f'Final reward: {result["rew"]}, length: {result["len"]}') diff --git a/examples/mujoco/point_maze_td3.py b/examples/mujoco/point_maze_td3.py index 42e91146c..1f4a217ef 100644 --- a/examples/mujoco/point_maze_td3.py +++ b/examples/mujoco/point_maze_td3.py @@ -104,9 +104,11 @@ def stop_fn(x): if __name__ == '__main__': pprint.pprint(result) # Let's watch its performance! - env = gym.make(args.task) - collector = Collector(policy, env) - result = collector.collect(n_step=1000, render=args.render) + policy.eval() + test_envs.seed(args.seed) + test_collector.reset() + result = test_collector.collect(n_episode=[1] * args.test_num, + render=args.render) print(f'Final reward: {result["rew"]}, length: {result["len"]}') diff --git a/test/continuous/test_ddpg.py b/test/continuous/test_ddpg.py index 457fcd592..5d8a7bc82 100644 --- a/test/continuous/test_ddpg.py +++ b/test/continuous/test_ddpg.py @@ -108,6 +108,7 @@ def stop_fn(x): pprint.pprint(result) # Let's watch its performance! 
env = gym.make(args.task) + policy.eval() collector = Collector(policy, env) result = collector.collect(n_episode=1, render=args.render) print(f'Final reward: {result["rew"]}, length: {result["len"]}') diff --git a/test/continuous/test_ppo.py b/test/continuous/test_ppo.py index ed42e7901..eba53789e 100644 --- a/test/continuous/test_ppo.py +++ b/test/continuous/test_ppo.py @@ -123,6 +123,7 @@ def stop_fn(x): pprint.pprint(result) # Let's watch its performance! env = gym.make(args.task) + policy.eval() collector = Collector(policy, env) result = collector.collect(n_episode=1, render=args.render) print(f'Final reward: {result["rew"]}, length: {result["len"]}') diff --git a/test/continuous/test_sac_with_il.py b/test/continuous/test_sac_with_il.py index dffebc70e..cab46a9c1 100644 --- a/test/continuous/test_sac_with_il.py +++ b/test/continuous/test_sac_with_il.py @@ -109,11 +109,13 @@ def stop_fn(x): pprint.pprint(result) # Let's watch its performance! env = gym.make(args.task) + policy.eval() collector = Collector(policy, env) result = collector.collect(n_episode=1, render=args.render) print(f'Final reward: {result["rew"]}, length: {result["len"]}') # here we define an imitation collector with a trivial policy + policy.eval() if args.task == 'Pendulum-v0': env.spec.reward_threshold = -300 # lower the goal net = Actor(Net(1, args.state_shape, device=args.device), @@ -136,6 +138,7 @@ def stop_fn(x): pprint.pprint(result) # Let's watch its performance! env = gym.make(args.task) + il_policy.eval() collector = Collector(il_policy, env) result = collector.collect(n_episode=1, render=args.render) print(f'Final reward: {result["rew"]}, length: {result["len"]}') diff --git a/test/continuous/test_td3.py b/test/continuous/test_td3.py index d2b95421e..e3a325f7e 100644 --- a/test/continuous/test_td3.py +++ b/test/continuous/test_td3.py @@ -113,6 +113,7 @@ def stop_fn(x): pprint.pprint(result) # Let's watch its performance! 
env = gym.make(args.task) + policy.eval() collector = Collector(policy, env) result = collector.collect(n_episode=1, render=args.render) print(f'Final reward: {result["rew"]}, length: {result["len"]}') diff --git a/test/discrete/test_a2c_with_il.py b/test/discrete/test_a2c_with_il.py index d99bc1448..3eafd0e42 100644 --- a/test/discrete/test_a2c_with_il.py +++ b/test/discrete/test_a2c_with_il.py @@ -98,10 +98,12 @@ def stop_fn(x): pprint.pprint(result) # Let's watch its performance! env = gym.make(args.task) + policy.eval() collector = Collector(policy, env) result = collector.collect(n_episode=1, render=args.render) print(f'Final reward: {result["rew"]}, length: {result["len"]}') + policy.eval() # here we define an imitation collector with a trivial policy if args.task == 'CartPole-v0': env.spec.reward_threshold = 190 # lower the goal @@ -124,6 +126,7 @@ def stop_fn(x): pprint.pprint(result) # Let's watch its performance! env = gym.make(args.task) + il_policy.eval() collector = Collector(il_policy, env) result = collector.collect(n_episode=1, render=args.render) print(f'Final reward: {result["rew"]}, length: {result["len"]}') diff --git a/test/discrete/test_dqn.py b/test/discrete/test_dqn.py index aeb849f41..bcf193ff9 100644 --- a/test/discrete/test_dqn.py +++ b/test/discrete/test_dqn.py @@ -114,6 +114,8 @@ def test_fn(x): pprint.pprint(result) # Let's watch its performance! env = gym.make(args.task) + policy.eval() + policy.set_eps(args.eps_test) collector = Collector(policy, env) result = collector.collect(n_episode=1, render=args.render) print(f'Final reward: {result["rew"]}, length: {result["len"]}') diff --git a/test/discrete/test_drqn.py b/test/discrete/test_drqn.py index a5e8c4338..e403c21a1 100644 --- a/test/discrete/test_drqn.py +++ b/test/discrete/test_drqn.py @@ -100,6 +100,7 @@ def test_fn(x): pprint.pprint(result) # Let's watch its performance! 
env = gym.make(args.task) + policy.eval() collector = Collector(policy, env) result = collector.collect(n_episode=1, render=args.render) print(f'Final reward: {result["rew"]}, length: {result["len"]}') diff --git a/test/discrete/test_pg.py b/test/discrete/test_pg.py index ec5a003ca..b3bacd782 100644 --- a/test/discrete/test_pg.py +++ b/test/discrete/test_pg.py @@ -87,6 +87,7 @@ def stop_fn(x): pprint.pprint(result) # Let's watch its performance! env = gym.make(args.task) + policy.eval() collector = Collector(policy, env) result = collector.collect(n_episode=1, render=args.render) print(f'Final reward: {result["rew"]}, length: {result["len"]}') diff --git a/test/discrete/test_ppo.py b/test/discrete/test_ppo.py index 515e2f225..0c52c899a 100644 --- a/test/discrete/test_ppo.py +++ b/test/discrete/test_ppo.py @@ -112,6 +112,7 @@ def stop_fn(x): pprint.pprint(result) # Let's watch its performance! env = gym.make(args.task) + policy.eval() collector = Collector(policy, env) result = collector.collect(n_episode=1, render=args.render) print(f'Final reward: {result["rew"]}, length: {result["len"]}') diff --git a/test/multiagent/tic_tac_toe.py b/test/multiagent/tic_tac_toe.py index 5422c6e3b..96383ae3b 100644 --- a/test/multiagent/tic_tac_toe.py +++ b/test/multiagent/tic_tac_toe.py @@ -169,6 +169,8 @@ def watch(args: argparse.Namespace = get_args(), env = TicTacToeEnv(args.board_size, args.win_size) policy, optim = get_agents( args, agent_learn=agent_learn, agent_opponent=agent_opponent) + policy.eval() + policy.set_eps(args.eps_test) collector = Collector(policy, env) result = collector.collect(n_episode=1, render=args.render) print(f'Final reward: {result["rew"]}, length: {result["len"]}') From dbddda759aac3ce3f76083516f08c9d1c451abe4 Mon Sep 17 00:00:00 2001 From: Trinkle23897 <463003665@qq.com> Date: Sat, 22 Aug 2020 08:17:12 +0800 Subject: [PATCH 24/60] remove previous atari script --- examples/atari/atari.py | 133 ------------------------------------- 
examples/atari/pong_a2c.py | 102 ---------------------------- examples/atari/pong_dqn.py | 126 ----------------------------------- examples/atari/pong_ppo.py | 106 ----------------------------- 4 files changed, 467 deletions(-) delete mode 100644 examples/atari/atari.py delete mode 100644 examples/atari/pong_a2c.py delete mode 100644 examples/atari/pong_dqn.py delete mode 100644 examples/atari/pong_ppo.py diff --git a/examples/atari/atari.py b/examples/atari/atari.py deleted file mode 100644 index 8e2ea5168..000000000 --- a/examples/atari/atari.py +++ /dev/null @@ -1,133 +0,0 @@ -import cv2 -import gym -import numpy as np -from gym.spaces.box import Box -from tianshou.data import Batch - -SIZE = 84 -FRAME = 4 - - -def create_atari_environment(name=None, sticky_actions=True, - max_episode_steps=2000): - game_version = 'v0' if sticky_actions else 'v4' - name = '{}NoFrameskip-{}'.format(name, game_version) - env = gym.make(name) - env = env.env - env = preprocessing(env, max_episode_steps=max_episode_steps) - return env - - -def preprocess_fn(obs=None, act=None, rew=None, done=None, - obs_next=None, info=None, policy=None, **kwargs): - if obs_next is not None: - obs_next = np.reshape(obs_next, (-1, *obs_next.shape[2:])) - obs_next = np.moveaxis(obs_next, 0, -1) - obs_next = cv2.resize(obs_next, (SIZE, SIZE)) - obs_next = np.asanyarray(obs_next, dtype=np.uint8) - obs_next = np.reshape(obs_next, (-1, FRAME, SIZE, SIZE)) - obs_next = np.moveaxis(obs_next, 1, -1) - elif obs is not None: - obs = np.reshape(obs, (-1, *obs.shape[2:])) - obs = np.moveaxis(obs, 0, -1) - obs = cv2.resize(obs, (SIZE, SIZE)) - obs = np.asanyarray(obs, dtype=np.uint8) - obs = np.reshape(obs, (-1, FRAME, SIZE, SIZE)) - obs = np.moveaxis(obs, 1, -1) - - return Batch(obs=obs, act=act, rew=rew, done=done, - obs_next=obs_next, info=info) - - -class preprocessing(object): - def __init__(self, env, frame_skip=4, terminal_on_life_loss=False, - size=84, max_episode_steps=2000): - self.max_episode_steps = 
max_episode_steps - self.env = env - self.terminal_on_life_loss = terminal_on_life_loss - self.frame_skip = frame_skip - self.size = size - self.count = 0 - obs_dims = self.env.observation_space - - self.screen_buffer = [ - np.empty((obs_dims.shape[0], obs_dims.shape[1]), dtype=np.uint8), - np.empty((obs_dims.shape[0], obs_dims.shape[1]), dtype=np.uint8) - ] - - self.game_over = False - self.lives = 0 - - @property - def observation_space(self): - return Box(low=0, high=255, - shape=(self.size, self.size, self.frame_skip), - dtype=np.uint8) - - def action_space(self): - return self.env.action_space - - def reward_range(self): - return self.env.reward_range - - def metadata(self): - return self.env.metadata - - def close(self): - return self.env.close() - - def reset(self): - self.count = 0 - self.env.reset() - self.lives = self.env.ale.lives() - self._grayscale_obs(self.screen_buffer[0]) - self.screen_buffer[1].fill(0) - - return np.array([self._pool_and_resize() - for _ in range(self.frame_skip)]) - - def render(self, mode='human'): - return self.env.render(mode) - - def step(self, action): - total_reward = 0. 
- observation = [] - for t in range(self.frame_skip): - self.count += 1 - _, reward, terminal, info = self.env.step(action) - total_reward += reward - - if self.terminal_on_life_loss: - lives = self.env.ale.lives() - is_terminal = terminal or lives < self.lives - self.lives = lives - else: - is_terminal = terminal - - if is_terminal: - break - elif t >= self.frame_skip - 2: - t_ = t - (self.frame_skip - 2) - self._grayscale_obs(self.screen_buffer[t_]) - - observation.append(self._pool_and_resize()) - if len(observation) == 0: - observation = [self._pool_and_resize() - for _ in range(self.frame_skip)] - while len(observation) > 0 and \ - len(observation) < self.frame_skip: - observation.append(observation[-1]) - terminal = self.count >= self.max_episode_steps - return np.array(observation), total_reward, \ - (terminal or is_terminal), info - - def _grayscale_obs(self, output): - self.env.ale.getScreenGrayscale(output) - return output - - def _pool_and_resize(self): - if self.frame_skip > 1: - np.maximum(self.screen_buffer[0], self.screen_buffer[1], - out=self.screen_buffer[0]) - - return self.screen_buffer[0] diff --git a/examples/atari/pong_a2c.py b/examples/atari/pong_a2c.py deleted file mode 100644 index f4b0a3031..000000000 --- a/examples/atari/pong_a2c.py +++ /dev/null @@ -1,102 +0,0 @@ -import torch -import pprint -import argparse -import numpy as np -from torch.utils.tensorboard import SummaryWriter - -from tianshou.policy import A2CPolicy -from tianshou.env import SubprocVectorEnv -from tianshou.trainer import onpolicy_trainer -from tianshou.data import Collector, ReplayBuffer -from tianshou.utils.net.discrete import Actor, Critic -from tianshou.utils.net.common import Net - -from atari import create_atari_environment, preprocess_fn - - -def get_args(): - parser = argparse.ArgumentParser() - parser.add_argument('--task', type=str, default='Pong') - parser.add_argument('--seed', type=int, default=1626) - parser.add_argument('--buffer-size', type=int, 
default=20000) - parser.add_argument('--lr', type=float, default=3e-4) - parser.add_argument('--gamma', type=float, default=0.9) - parser.add_argument('--epoch', type=int, default=100) - parser.add_argument('--step-per-epoch', type=int, default=1000) - parser.add_argument('--collect-per-step', type=int, default=10) - parser.add_argument('--repeat-per-collect', type=int, default=1) - parser.add_argument('--batch-size', type=int, default=64) - parser.add_argument('--layer-num', type=int, default=2) - parser.add_argument('--training-num', type=int, default=8) - parser.add_argument('--test-num', type=int, default=8) - parser.add_argument('--logdir', type=str, default='log') - parser.add_argument('--render', type=float, default=0.) - - parser.add_argument( - '--device', type=str, - default='cuda' if torch.cuda.is_available() else 'cpu') - # a2c special - parser.add_argument('--vf-coef', type=float, default=0.5) - parser.add_argument('--ent-coef', type=float, default=0.001) - parser.add_argument('--max-grad-norm', type=float, default=None) - parser.add_argument('--max_episode_steps', type=int, default=2000) - return parser.parse_args() - - -def test_a2c(args=get_args()): - env = create_atari_environment(args.task) - args.state_shape = env.observation_space.shape or env.observation_space.n - args.action_shape = env.env.action_space.shape or env.env.action_space.n - # train_envs = gym.make(args.task) - train_envs = SubprocVectorEnv( - [lambda: create_atari_environment(args.task) - for _ in range(args.training_num)]) - # test_envs = gym.make(args.task) - test_envs = SubprocVectorEnv( - [lambda: create_atari_environment(args.task) - for _ in range(args.test_num)]) - # seed - np.random.seed(args.seed) - torch.manual_seed(args.seed) - train_envs.seed(args.seed) - test_envs.seed(args.seed) - # model - net = Net(args.layer_num, args.state_shape, device=args.device) - actor = Actor(net, args.action_shape).to(args.device) - critic = Critic(net).to(args.device) - optim = 
torch.optim.Adam(list( - actor.parameters()) + list(critic.parameters()), lr=args.lr) - dist = torch.distributions.Categorical - policy = A2CPolicy( - actor, critic, optim, dist, args.gamma, vf_coef=args.vf_coef, - ent_coef=args.ent_coef, max_grad_norm=args.max_grad_norm) - # collector - train_collector = Collector( - policy, train_envs, ReplayBuffer(args.buffer_size), - preprocess_fn=preprocess_fn) - test_collector = Collector(policy, test_envs, preprocess_fn=preprocess_fn) - # log - writer = SummaryWriter(args.logdir + '/' + 'a2c') - - def stop_fn(x): - if env.env.spec.reward_threshold: - return x >= env.spec.reward_threshold - else: - return False - - # trainer - result = onpolicy_trainer( - policy, train_collector, test_collector, args.epoch, - args.step_per_epoch, args.collect_per_step, args.repeat_per_collect, - args.test_num, args.batch_size, stop_fn=stop_fn, writer=writer) - if __name__ == '__main__': - pprint.pprint(result) - # Let's watch its performance! - env = create_atari_environment(args.task) - collector = Collector(policy, env, preprocess_fn=preprocess_fn) - result = collector.collect(n_episode=1, render=args.render) - print(f'Final reward: {result["rew"]}, length: {result["len"]}') - - -if __name__ == '__main__': - test_a2c() diff --git a/examples/atari/pong_dqn.py b/examples/atari/pong_dqn.py deleted file mode 100644 index ccbaee80d..000000000 --- a/examples/atari/pong_dqn.py +++ /dev/null @@ -1,126 +0,0 @@ -import torch -import pprint -import argparse -import numpy as np -from torch.utils.tensorboard import SummaryWriter - -from tianshou.policy import DQNPolicy -from tianshou.env import SubprocVectorEnv -from tianshou.utils.net.discrete import DQN -from tianshou.trainer import offpolicy_trainer -from tianshou.data import Collector, ReplayBuffer - -from atari import create_atari_environment, preprocess_fn - - -def get_args(): - parser = argparse.ArgumentParser() - parser.add_argument('--task', type=str, default='PongNoFrameskip-v4') - 
parser.add_argument('--seed', type=int, default=0) - parser.add_argument('--eps_test', type=float, default=0.005) - parser.add_argument('--eps_train', type=float, default=0.5) - parser.add_argument('--eps_train_final', type=float, default=0.05) - parser.add_argument('--buffer-size', type=int, default=100000) - parser.add_argument('--lr', type=float, default=0.0001) - parser.add_argument('--gamma', type=float, default=0.99) - parser.add_argument('--n_step', type=int, default=3) - parser.add_argument('--target_update_freq', type=int, default=500) - parser.add_argument('--epoch', type=int, default=100) - parser.add_argument('--step-per-epoch', type=int, default=1000) - parser.add_argument('--collect-per-step', type=int, default=10) - parser.add_argument('--batch-size', type=int, default=64) - parser.add_argument('--layer-num', type=int, default=3) - parser.add_argument('--training-num', type=int, default=8) - parser.add_argument('--test-num', type=int, default=8) - parser.add_argument('--logdir', type=str, default='log') - parser.add_argument('--render', type=float, default=0.) 
- parser.add_argument( - '--device', type=str, - default='cuda' if torch.cuda.is_available() else 'cpu') - return parser.parse_args() - - -def test_dqn(args=get_args()): - env = create_atari_environment(args.task) - args.state_shape = env.observation_space.shape or env.observation_space.n - args.action_shape = env.env.action_space.shape or env.env.action_space.n - # train_envs = gym.make(args.task) - train_envs = SubprocVectorEnv([ - lambda: create_atari_environment(args.task) - for _ in range(args.training_num)]) - # test_envs = gym.make(args.task) - test_envs = SubprocVectorEnv([ - lambda: create_atari_environment(args.task) - for _ in range(args.test_num)]) - # seed - np.random.seed(args.seed) - torch.manual_seed(args.seed) - train_envs.seed(args.seed) - test_envs.seed(args.seed) - # model - net = DQN( - args.state_shape[0], args.state_shape[1], - args.action_shape, args.device) - net = net.to(args.device) - optim = torch.optim.Adam(net.parameters(), lr=args.lr) - policy = DQNPolicy( - net, optim, args.gamma, args.n_step, - target_update_freq=args.target_update_freq) - # collector - train_collector = Collector( - policy, train_envs, ReplayBuffer(args.buffer_size), - preprocess_fn=preprocess_fn) - test_collector = Collector(policy, test_envs, preprocess_fn=preprocess_fn) - # policy.set_eps(1) - train_collector.collect(n_step=args.batch_size * 4) - print(len(train_collector.buffer)) - # log - writer = SummaryWriter(args.logdir + '/' + 'dqn') - - def stop_fn(x): - if env.env.spec.reward_threshold: - return x >= env.spec.reward_threshold - else: - return False - - def train_fn(x): - policy.set_eps(args.eps_train) - - def test_fn(x): - policy.set_eps(args.eps_test) - - # watch agent's performance - def watch(): - print("Testing agent ...") - policy.eval() - policy.set_eps(args.eps_test) - test_envs.seed(args.seed) - test_collector.reset() - result = test_collector.collect(n_episode=[1] * args.test_num, - render=args.render) - pprint.pprint(result) - - if args.watch: 
- watch() - exit(0) - - # test train_collector and start filling replay buffer - train_collector.collect(n_step=args.batch_size * 4) - # trainer - result = offpolicy_trainer( - policy, train_collector, test_collector, args.epoch, - args.step_per_epoch, args.collect_per_step, args.test_num, - args.batch_size, train_fn=train_fn, test_fn=test_fn, - stop_fn=stop_fn, writer=writer) - - if __name__ == '__main__': - pprint.pprint(result) - # Let's watch its performance! - env = create_atari_environment(args.task) - collector = Collector(policy, env, preprocess_fn=preprocess_fn) - result = collector.collect(n_episode=1, render=args.render) - print(f'Final reward: {result["rew"]}, length: {result["len"]}') - - -if __name__ == '__main__': - test_dqn(get_args()) diff --git a/examples/atari/pong_ppo.py b/examples/atari/pong_ppo.py deleted file mode 100644 index 9d5563fe1..000000000 --- a/examples/atari/pong_ppo.py +++ /dev/null @@ -1,106 +0,0 @@ -import torch -import pprint -import argparse -import numpy as np -from torch.utils.tensorboard import SummaryWriter - -from tianshou.policy import PPOPolicy -from tianshou.env import SubprocVectorEnv -from tianshou.trainer import onpolicy_trainer -from tianshou.data import Collector, ReplayBuffer -from tianshou.utils.net.discrete import Actor, Critic -from tianshou.utils.net.common import Net - -from atari import create_atari_environment, preprocess_fn - - -def get_args(): - parser = argparse.ArgumentParser() - parser.add_argument('--task', type=str, default='Pong') - parser.add_argument('--seed', type=int, default=1626) - parser.add_argument('--buffer-size', type=int, default=20000) - parser.add_argument('--lr', type=float, default=1e-3) - parser.add_argument('--gamma', type=float, default=0.99) - parser.add_argument('--epoch', type=int, default=100) - parser.add_argument('--step-per-epoch', type=int, default=1000) - parser.add_argument('--collect-per-step', type=int, default=10) - parser.add_argument('--repeat-per-collect', 
type=int, default=2) - parser.add_argument('--batch-size', type=int, default=64) - parser.add_argument('--layer-num', type=int, default=1) - parser.add_argument('--training-num', type=int, default=8) - parser.add_argument('--test-num', type=int, default=8) - parser.add_argument('--logdir', type=str, default='log') - parser.add_argument('--render', type=float, default=0.) - parser.add_argument( - '--device', type=str, - default='cuda' if torch.cuda.is_available() else 'cpu') - # ppo special - parser.add_argument('--vf-coef', type=float, default=0.5) - parser.add_argument('--ent-coef', type=float, default=0.0) - parser.add_argument('--eps-clip', type=float, default=0.2) - parser.add_argument('--max-grad-norm', type=float, default=0.5) - parser.add_argument('--max_episode_steps', type=int, default=2000) - return parser.parse_args() - - -def test_ppo(args=get_args()): - env = create_atari_environment(args.task) - args.state_shape = env.observation_space.shape or env.observation_space.n - args.action_shape = env.action_space().shape or env.action_space().n - # train_envs = gym.make(args.task) - train_envs = SubprocVectorEnv([ - lambda: create_atari_environment(args.task) - for _ in range(args.training_num)]) - # test_envs = gym.make(args.task) - test_envs = SubprocVectorEnv([ - lambda: create_atari_environment(args.task) - for _ in range(args.test_num)]) - # seed - np.random.seed(args.seed) - torch.manual_seed(args.seed) - train_envs.seed(args.seed) - test_envs.seed(args.seed) - # model - net = Net(args.layer_num, args.state_shape, device=args.device) - actor = Actor(net, args.action_shape).to(args.device) - critic = Critic(net).to(args.device) - optim = torch.optim.Adam(list( - actor.parameters()) + list(critic.parameters()), lr=args.lr) - dist = torch.distributions.Categorical - policy = PPOPolicy( - actor, critic, optim, dist, args.gamma, - max_grad_norm=args.max_grad_norm, - eps_clip=args.eps_clip, - vf_coef=args.vf_coef, - ent_coef=args.ent_coef, - 
action_range=None) - # collector - train_collector = Collector( - policy, train_envs, ReplayBuffer(args.buffer_size), - preprocess_fn=preprocess_fn) - test_collector = Collector(policy, test_envs, preprocess_fn=preprocess_fn) - # log - writer = SummaryWriter(args.logdir + '/' + 'ppo') - - def stop_fn(x): - if env.env.spec.reward_threshold: - return x >= env.spec.reward_threshold - else: - return False - - # trainer - result = onpolicy_trainer( - policy, train_collector, test_collector, args.epoch, - args.step_per_epoch, args.collect_per_step, args.repeat_per_collect, - args.test_num, args.batch_size, stop_fn=stop_fn, writer=writer) - if __name__ == '__main__': - pprint.pprint(result) - # Let's watch its performance! - env = create_atari_environment(args.task) - collector = Collector(policy, env, preprocess_fn=preprocess_fn) - result = collector.collect(n_step=2000, render=args.render) - print(f'Final reward: {result["rew"]}, length: {result["len"]}') - - -if __name__ == '__main__': - test_ppo() From 88db97ba6081f1b8de45a5ceb226f07d8ede87a4 Mon Sep 17 00:00:00 2001 From: Trinkle23897 <463003665@qq.com> Date: Sat, 22 Aug 2020 08:46:11 +0800 Subject: [PATCH 25/60] find a bug in exact n_episode --- test/base/test_collector.py | 4 ++-- tianshou/data/collector.py | 3 ++- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/test/base/test_collector.py b/test/base/test_collector.py index c678fdcab..0ef3ed2dc 100644 --- a/test/base/test_collector.py +++ b/test/base/test_collector.py @@ -127,8 +127,8 @@ def test_collector_with_exact_episodes(): c1 = Collector(policy, venv, ReplayBuffer(size=1000, ignore_obs_next=False), logger.preprocess_fn) - n_episode1 = [2, 2, 5, 1] - n_episode2 = [1, 3, 2, 4] + n_episode1 = [2, 0, 5, 1] + n_episode2 = [1, 3, 2, 0] c1.collect(n_episode=n_episode1) expected_steps = sum([a * b for a, b in zip(env_lens, n_episode1)]) actual_steps = sum(venv.steps) diff --git a/tianshou/data/collector.py b/tianshou/data/collector.py index 
0ea08e71b..a7162404e 100644 --- a/tianshou/data/collector.py +++ b/tianshou/data/collector.py @@ -200,7 +200,8 @@ def collect(self, * ``rew`` the mean reward over collected episodes. * ``len`` the mean length over collected episodes. """ - assert (n_step and not n_episode) or (not n_step and n_episode), \ + assert (n_step and not n_episode and n_step > 0) or \ + (not n_step and n_episode and np.sum(n_episode) > 0), \ "One and only one collection number specification is permitted!" start_time = time.time() step_count = 0 From 226a518e3c1c47f3a711776aab30fcd948bf286e Mon Sep 17 00:00:00 2001 From: Trinkle23897 <463003665@qq.com> Date: Sat, 22 Aug 2020 09:12:19 +0800 Subject: [PATCH 26/60] fix 0 in n_episode --- tianshou/data/collector.py | 26 ++++++++++++++++---------- 1 file changed, 16 insertions(+), 10 deletions(-) diff --git a/tianshou/data/collector.py b/tianshou/data/collector.py index a7162404e..ed9babb8a 100644 --- a/tianshou/data/collector.py +++ b/tianshou/data/collector.py @@ -200,19 +200,25 @@ def collect(self, * ``rew`` the mean reward over collected episodes. * ``len`` the mean length over collected episodes. """ - assert (n_step and not n_episode and n_step > 0) or \ - (not n_step and n_episode and np.sum(n_episode) > 0), \ - "One and only one collection number specification is permitted!" + assert (n_step is not None and n_episode is None and n_step > 0) or ( + n_step is None and n_episode is not None and np.sum(n_episode) > 0 + ), "One and only one collection number specification is permitted!" start_time = time.time() step_count = 0 # episode of each environment episode_count = np.zeros(self.env_num) # If n_episode is a list, and some envs have collected the required - # number of episodes, these envs will be recorded in this list, - # and they will not be stepped. + # number of episodes, these envs will be recorded in this list, and + # they will not be stepped. 
finished_env_ids = [] reward_total = 0.0 whole_data = Batch() + if n_episode is not None and not np.isscalar(n_episode): + assert len(n_episode) == self.get_env_num() + finished_env_ids = [ + i for i in self._ready_env_ids if n_episode[i] <= 0] + self._ready_env_ids = np.asarray( + [x for x in self._ready_env_ids if x not in finished_env_ids]) while True: if step_count >= 100000 and episode_count.sum() == 0: warnings.warn( @@ -222,12 +228,12 @@ def collect(self, is_async = self.is_async or len(finished_env_ids) > 0 if is_async: - # self.data are the data for all environments - # in async simulation or some envs have finished, - # **only a subset of data are disposed** + # self.data are the data for all environments in async + # simulation or some envs have finished, + # **only a subset of data are disposed**, # so we store the whole data in ``whole_data``, let self.data - # to be the data available in ready environments, and - # finally set these back into all the data + # to be the data available in ready environments, and finally + # set these back into all the data whole_data = self.data self.data = self.data[self._ready_env_ids] From 5fe08506b0c11ca6d8da1aa7e3935295b25f59a9 Mon Sep 17 00:00:00 2001 From: Trinkle23897 <463003665@qq.com> Date: Sat, 22 Aug 2020 09:32:22 +0800 Subject: [PATCH 27/60] improve little coverage --- test/base/test_batch.py | 10 +++++++++- tianshou/data/batch.py | 18 +++++++----------- 2 files changed, 16 insertions(+), 12 deletions(-) diff --git a/test/base/test_batch.py b/test/base/test_batch.py index a823491b4..fe4a48288 100644 --- a/test/base/test_batch.py +++ b/test/base/test_batch.py @@ -26,7 +26,7 @@ def test_batch(): assert np.allclose(b.c, [3, 5]) # mimic the behavior of dict.update, where kwargs can overwrite keys b.update({'a': 2}, a=3) - assert b.a == 3 + assert b.a == 3 and 'a' in b with pytest.raises(AssertionError): Batch({1: 2}) with pytest.raises(TypeError): @@ -76,6 +76,7 @@ def test_batch(): assert Batch().shape == [] 
assert Batch(a=1).shape == [] assert batch2.shape[0] == 1 + assert 'a' in batch2 and all([i in batch2.a for i in 'bcd']) with pytest.raises(IndexError): batch2[-2] with pytest.raises(IndexError): @@ -297,6 +298,13 @@ def test_batch_cat_and_stack(): assert torch.allclose(test.b, ans.b) assert np.allclose(test.common.c, ans.common.c) + # exceptions + assert Batch.cat([]).is_empty() + b1 = Batch(e=[4, 5], d=6) + b2 = Batch(e=[4, 6]) + with pytest.raises(ValueError): + Batch.cat([b1, b2]) + def test_batch_over_batch_to_torch(): batch = Batch( diff --git a/tianshou/data/batch.py b/tianshou/data/batch.py index 29e407b53..9e0484339 100644 --- a/tianshou/data/batch.py +++ b/tianshou/data/batch.py @@ -465,11 +465,10 @@ def cat_(self, lens = [0 if x.is_empty(recurse=True) else len(x) for x in batches] except TypeError as e: - e2 = ValueError( - f'Batch.cat_ meets an exception. Maybe because there is ' - f'any scalar in {batches} but Batch.cat_ does not support' - f'the concatenation of scalar.') - raise Exception([e, e2]) + raise ValueError( + f'Batch.cat_ meets an exception. Maybe because there is any ' + f'scalar in {batches} but Batch.cat_ does not support the ' + f'concatenation of scalar.') from e if not self.is_empty(): batches = [self] + list(batches) lens = [0 if self.is_empty(recurse=True) else len(self)] + lens @@ -712,19 +711,16 @@ def shape(self) -> List[int]: return list(map(min, zip(*data_shape))) if len(data_shape) > 1 \ else data_shape[0] - def split(self, size: Optional[int] = None, + def split(self, size: int, shuffle: bool = True) -> Iterator['Batch']: """Split whole data into multiple small batches. - :param int size: if it is ``None``, it does not split the data batch; - otherwise it will divide the data batch with the given size. - Default to ``None``. + :param int size: divide the data batch with the given size, defaults to + ``None``. :param bool shuffle: randomly shuffle the entire data batch if it is ``True``, otherwise remain in the same. 
Default to ``True``. """ length = len(self) - if size is None: - size = length if shuffle: indices = np.random.permutation(length) else: From 8c0f4149abc3105fbb6a10c1df221bf9516763f1 Mon Sep 17 00:00:00 2001 From: Trinkle23897 <463003665@qq.com> Date: Sat, 22 Aug 2020 12:37:49 +0800 Subject: [PATCH 28/60] add missing test --- test/base/test_batch.py | 25 ++++++++++++++++++++++--- test/base/test_buffer.py | 2 +- tianshou/data/batch.py | 9 +++++---- 3 files changed, 28 insertions(+), 8 deletions(-) diff --git a/test/base/test_batch.py b/test/base/test_batch.py index fe4a48288..2ca019711 100644 --- a/test/base/test_batch.py +++ b/test/base/test_batch.py @@ -41,6 +41,8 @@ def test_batch(): Batch(a=[1, np.zeros((3, 3)), torch.zeros((3, 3))]) batch = Batch(a=[torch.ones(3), torch.ones(3)]) assert torch.allclose(batch.a, torch.ones(2, 3)) + batch.cat_(batch) + assert torch.allclose(batch.a, torch.ones(4, 3)) Batch(a=[]) batch = Batch(obs=[0], np=np.zeros([3, 4])) assert batch.obs == batch["obs"] @@ -75,6 +77,7 @@ def test_batch(): assert len(batch2) == 1 assert Batch().shape == [] assert Batch(a=1).shape == [] + assert Batch(a=set((1, 2, 1))).shape == [] assert batch2.shape[0] == 1 assert 'a' in batch2 and all([i in batch2.a for i in 'bcd']) with pytest.raises(IndexError): @@ -102,10 +105,14 @@ def test_batch(): assert batch_slice.a.b == batch2.a.b assert batch_slice.a.c == batch2.a.c assert batch_slice.a.d.e == batch2.a.d.e + batch2.a.d.f = {} batch2_sum = (batch2 + 1.0) * 2 assert batch2_sum.a.b == (batch2.a.b + 1.0) * 2 assert batch2_sum.a.c == (batch2.a.c + 1.0) * 2 assert batch2_sum.a.d.e == (batch2.a.d.e + 1.0) * 2 + assert batch2_sum.a.d.f.is_empty() + with pytest.raises(TypeError): + batch2 += [1] batch3 = Batch(a={ 'c': np.zeros(1), 'd': Batch(e=np.array([0.0]), f=np.array([3.0]))}) @@ -172,6 +179,11 @@ def test_batch_over_batch(): batch5[:, -1] += 1 assert np.allclose(batch5.a, [1, 3]) assert np.allclose(batch5.b.c.squeeze(), [[0, 1]] * 3) + with 
pytest.raises(ValueError): + batch5[:, -1] = 1 + batch5[:, 0] = {'a': -1} + assert np.allclose(batch5.a, [-1, 3]) + assert np.allclose(batch5.b.c.squeeze(), [[0, 1]] * 3) def test_batch_cat_and_stack(): @@ -200,9 +212,9 @@ def test_batch_cat_and_stack(): assert np.allclose(ans.b, np.concatenate([a.b, b.b, a.b])) assert ans.a.t.is_empty() - b12_stack = Batch.stack((b1, b2)) - assert isinstance(b12_stack.a.d.e, np.ndarray) - assert b12_stack.a.d.e.ndim == 2 + assert b1.stack_([b2]) is None + assert isinstance(b1.a.d.e, np.ndarray) + assert b1.a.d.e.ndim == 2 # test cat with incompatible keys b1 = Batch(a=np.random.rand(3, 4), common=Batch(c=np.random.rand(3, 5))) @@ -300,10 +312,13 @@ def test_batch_cat_and_stack(): # exceptions assert Batch.cat([]).is_empty() + assert Batch.stack([]).is_empty() b1 = Batch(e=[4, 5], d=6) b2 = Batch(e=[4, 6]) with pytest.raises(ValueError): Batch.cat([b1, b2]) + with pytest.raises(ValueError): + Batch.stack([b1, b2], axis=1) def test_batch_over_batch_to_torch(): @@ -314,17 +329,21 @@ def test_batch_over_batch_to_torch(): d=torch.ones((1,), dtype=torch.float64) ) ) + batch.b.__dict__['e'] = 1 # bypass the check batch.to_torch() assert isinstance(batch.a, torch.Tensor) assert isinstance(batch.b.c, torch.Tensor) assert isinstance(batch.b.d, torch.Tensor) + assert isinstance(batch.b.e, torch.Tensor) assert batch.a.dtype == torch.float64 assert batch.b.c.dtype == torch.float32 assert batch.b.d.dtype == torch.float64 + assert batch.b.e.dtype == torch.int64 batch.to_torch(dtype=torch.float32) assert batch.a.dtype == torch.float32 assert batch.b.c.dtype == torch.float32 assert batch.b.d.dtype == torch.float32 + assert batch.b.e.dtype == torch.float32 def test_utils_to_torch_numpy(): diff --git a/test/base/test_buffer.py b/test/base/test_buffer.py index 1f564f683..3f0e0e8d2 100644 --- a/test/base/test_buffer.py +++ b/test/base/test_buffer.py @@ -20,7 +20,7 @@ def test_replaybuffer(size=10, bufsize=20): action_list = [1] * 5 + [0] * 10 + [1] * 
10 for i, a in enumerate(action_list): obs_next, rew, done, info = env.step(a) - buf.add(obs, a, rew, done, obs_next, info) + buf.add(obs, [a], rew, done, obs_next, info) obs = obs_next assert len(buf) == min(bufsize, i + 1) data, indice = buf.sample(bufsize * 2) diff --git a/tianshou/data/batch.py b/tianshou/data/batch.py index 9e0484339..a50365991 100644 --- a/tianshou/data/batch.py +++ b/tianshou/data/batch.py @@ -86,6 +86,7 @@ def _create_value(inst: Any, size: int, stack=True) -> Union[ has_shape = isinstance(inst, (np.ndarray, torch.Tensor)) is_scalar = _is_scalar(inst) if not stack and is_scalar: + # should never hit since it has already checked in Batch.cat_ # here we do not consider scalar types, following the behavior of numpy # which does not support concatenation of zero-dimensional arrays # (scalars) @@ -220,13 +221,13 @@ def __setitem__(self, index: Union[ str, slice, int, np.integer, np.ndarray, List[int]], value: Any) -> None: """Assign value to self[index].""" + value = _parse_value(value) if isinstance(index, str): - self.__dict__[index] = _parse_value(value) + self.__dict__[index] = value return - value = _parse_value(value) if isinstance(value, (np.ndarray, torch.Tensor)): - raise ValueError("Batch does not supported tensor assignment." - " Use a compatible Batch or dict instead.") + raise ValueError("Batch does not supported tensor assignment. 
" + "Use a compatible Batch or dict instead.") if not set(value.keys()).issubset(self.__dict__.keys()): raise KeyError( "Creating keys is not supported by item assignment.") From 0a168cea4e44f483355bd180fd643ed6ce1b61a3 Mon Sep 17 00:00:00 2001 From: Trinkle23897 <463003665@qq.com> Date: Sat, 22 Aug 2020 13:14:32 +0800 Subject: [PATCH 29/60] add missing test for buffer, to_numpy and to_torch --- test/base/test_batch.py | 16 ++++++++++++++++ test/base/test_buffer.py | 14 ++++++++++++++ tianshou/data/buffer.py | 4 +--- tianshou/data/utils/converter.py | 14 ++++---------- 4 files changed, 35 insertions(+), 13 deletions(-) diff --git a/test/base/test_batch.py b/test/base/test_batch.py index 2ca019711..226b26fca 100644 --- a/test/base/test_batch.py +++ b/test/base/test_batch.py @@ -393,6 +393,22 @@ def test_utils_to_torch_numpy(): assert isinstance(data_empty_array, np.ndarray) assert data_empty_array.shape == (0, 2, 2) assert np.allclose(to_numpy(to_torch(data_array)), data_array) + # additional test for to_numpy, for code-coverage + assert isinstance(to_numpy(1), np.ndarray) + assert isinstance(to_numpy(1.), np.ndarray) + assert isinstance(to_numpy({'a': torch.tensor(1)})['a'], np.ndarray) + assert isinstance(to_numpy(Batch(a=torch.tensor(1))).a, np.ndarray) + assert to_numpy(None).item() is None + assert to_numpy(to_numpy).item() == to_numpy + # additional test for to_torch, for code-coverage + assert isinstance(to_torch(1), torch.Tensor) + assert to_torch(1).dtype == torch.int64 + assert to_torch(1.).dtype == torch.float64 + assert isinstance(to_torch({'a': [1]})['a'], torch.Tensor) + with pytest.raises(TypeError): + to_torch(None) + with pytest.raises(TypeError): + to_torch(np.array([{}, '2'])) def test_batch_pickle(): diff --git a/test/base/test_buffer.py b/test/base/test_buffer.py index 3f0e0e8d2..4476d66da 100644 --- a/test/base/test_buffer.py +++ b/test/base/test_buffer.py @@ -16,6 +16,8 @@ def test_replaybuffer(size=10, bufsize=20): env = MyTestEnv(size) buf = 
ReplayBuffer(bufsize) + buf.update(buf) + assert str(buf) == buf.__class__.__name__ + '()' obs = env.reset() action_list = [1] * 5 + [0] * 10 + [1] * 10 for i, a in enumerate(action_list): @@ -23,6 +25,8 @@ def test_replaybuffer(size=10, bufsize=20): buf.add(obs, [a], rew, done, obs_next, info) obs = obs_next assert len(buf) == min(bufsize, i + 1) + with pytest.raises(ValueError): + buf._add_to_buffer('rew', [1, 2, 3]) data, indice = buf.sample(bufsize * 2) assert (indice < len(buf)).all() assert (data.obs < size).all() @@ -37,6 +41,11 @@ def test_replaybuffer(size=10, bufsize=20): assert np.all(b.info.a[1:] == 0) assert b.info.b.c[0] == 5.0 and b.info.b.c.dtype == np.inexact assert np.all(b.info.b.c[1:] == 0.0) + with pytest.raises(IndexError): + b[22] + b = ListReplayBuffer() + with pytest.raises(NotImplementedError): + b.sample(0) def test_ignore_obs_next(size=10): @@ -97,6 +106,8 @@ def test_stack(size=5, bufsize=9, stack_num=4): assert indice.tolist() == [2, 6] _, indice = buf2.sample(1) assert indice in [2, 6] + with pytest.raises(IndexError): + buf[bufsize * 2] def test_priortized_replaybuffer(size=32, bufsize=15): @@ -139,6 +150,7 @@ def test_segtree(): # small test actual_len = 8 tree = SegmentTree(actual_len, op) # 1-15. 
8-15 are leaf nodes + assert len(tree) == actual_len assert np.all([tree[i] == init for i in range(actual_len)]) with pytest.raises(IndexError): tree[actual_len] @@ -154,6 +166,8 @@ def test_segtree(): ref = realop(naive[i:j]) out = tree.reduce(i, j) assert np.allclose(ref, out) + assert np.allclose(tree.reduce(start=1), realop(naive[1:])) + assert np.allclose(tree.reduce(end=-1), realop(naive[:-1])) # batch setitem for _ in range(1000): index = np.random.choice(actual_len, size=4) diff --git a/tianshou/data/buffer.py b/tianshou/data/buffer.py index 7e69f7540..7bea14eff 100644 --- a/tianshou/data/buffer.py +++ b/tianshou/data/buffer.py @@ -352,9 +352,7 @@ def sample(self, batch_size: int) -> Tuple[Batch, np.ndarray]: def _add_to_buffer( self, name: str, inst: Union[dict, Batch, np.ndarray, float, int, bool]) -> None: - if inst is None: - return - if self._meta.__dict__.get(name, None) is None: + if self._meta.__dict__.get(name) is None: self._meta.__dict__[name] = [] self._meta.__dict__[name].append(inst) diff --git a/tianshou/data/utils/converter.py b/tianshou/data/utils/converter.py index 7c90e5f09..45f97f6fd 100644 --- a/tianshou/data/utils/converter.py +++ b/tianshou/data/utils/converter.py @@ -14,13 +14,13 @@ def to_numpy(x: Union[ x = x.detach().cpu().numpy() elif isinstance(x, np.ndarray): # second often case pass - elif isinstance(x, (np.number, np.bool_, Number)): + elif isinstance(x, (np.number, np.bool_, Number)) or x is None: x = np.asanyarray(x) + elif isinstance(x, Batch): + x.to_numpy() elif isinstance(x, dict): for k, v in x.items(): x[k] = to_numpy(v) - elif isinstance(x, Batch): - x.to_numpy() elif isinstance(x, (list, tuple)): try: x = to_numpy(_parse_value(x)) @@ -58,13 +58,7 @@ def to_torch(x: Union[Batch, dict, list, tuple, np.ndarray, torch.Tensor], except TypeError: x = [to_torch(e, dtype, device) for e in x] else: # fallback - x = np.asanyarray(x) - if issubclass(x.dtype.type, (np.bool_, np.number)): - x = torch.from_numpy(x).to(device) - 
if dtype is not None: - x = x.type(dtype) - else: - raise TypeError(f"object {x} cannot be converted to torch.") + raise TypeError(f"object {x} cannot be converted to torch.") return x From a13bffa5d1445e74da21edab9cae5f4591d9ee23 Mon Sep 17 00:00:00 2001 From: Trinkle23897 <463003665@qq.com> Date: Sat, 22 Aug 2020 13:39:25 +0800 Subject: [PATCH 30/60] add missing test for venv and utils --- test/base/test_env.py | 1 + test/base/test_utils.py | 30 ++++++++++++++++++++++++++++++ tianshou/env/worker/dummy.py | 6 ++---- tianshou/env/worker/ray.py | 2 +- tianshou/env/worker/subproc.py | 2 +- 5 files changed, 35 insertions(+), 6 deletions(-) create mode 100644 test/base/test_utils.py diff --git a/test/base/test_env.py b/test/base/test_env.py index 96de70236..6f67df4b2 100644 --- a/test/base/test_env.py +++ b/test/base/test_env.py @@ -48,6 +48,7 @@ def test_async_env(size=10000, num=8, sleep=0.1): test_cls += [RayVectorEnv] for cls in test_cls: v = cls(env_fns, wait_num=num // 2, timeout=1e-3) + v.seed(None) v.reset() # for a random variable u ~ U[0, 1], let v = max{u1, u2, ..., un} # P(v <= x) = x^n (0 <= x <= 1), pdf of v is nx^{n-1} diff --git a/test/base/test_utils.py b/test/base/test_utils.py new file mode 100644 index 000000000..932f92a12 --- /dev/null +++ b/test/base/test_utils.py @@ -0,0 +1,30 @@ +import torch +import numpy as np + +from tianshou.utils import MovAvg +from tianshou.exploration import GaussianNoise, OUNoise + + +def test_noise(): + noise = GaussianNoise() + size = (3, 4, 5) + assert np.allclose(noise(size).shape, size) + noise = OUNoise() + noise.reset() + assert np.allclose(noise(size).shape, size) + + +def test_moving_average(): + stat = MovAvg(10) + stat.add(torch.tensor([1])) + stat.add(np.array([2])) + stat.add([3, 4]) + stat.add(5.) + assert np.allclose(stat.get(), 3.) + assert np.allclose(stat.mean(), 3.) + assert np.allclose(stat.std() ** 2, 2.) 
+ + +if __name__ == '__main__': + test_noise() + test_moving_average() diff --git a/tianshou/env/worker/dummy.py b/tianshou/env/worker/dummy.py index 97b7087b0..893500b28 100644 --- a/tianshou/env/worker/dummy.py +++ b/tianshou/env/worker/dummy.py @@ -12,10 +12,8 @@ def __init__(self, env_fn: Callable[[], gym.Env]) -> None: super().__init__(env_fn) self.env = env_fn() - def __getattr__(self, key: str): - if hasattr(self.env, key): - return getattr(self.env, key) - return None + def __getattr__(self, key: str) -> Any: + return getattr(self.env, key) def reset(self) -> Any: return self.env.reset() diff --git a/tianshou/env/worker/ray.py b/tianshou/env/worker/ray.py index f9f4fa9ff..3f71d828b 100644 --- a/tianshou/env/worker/ray.py +++ b/tianshou/env/worker/ray.py @@ -17,7 +17,7 @@ def __init__(self, env_fn: Callable[[], gym.Env]) -> None: super().__init__(env_fn) self.env = ray.remote(gym.Wrapper).options(num_cpus=0).remote(env_fn()) - def __getattr__(self, key: str): + def __getattr__(self, key: str) -> Any: return ray.get(self.env.__getattr__.remote(key)) def reset(self) -> Any: diff --git a/tianshou/env/worker/subproc.py b/tianshou/env/worker/subproc.py index 3fc992947..3186b01db 100644 --- a/tianshou/env/worker/subproc.py +++ b/tianshou/env/worker/subproc.py @@ -127,7 +127,7 @@ def __init__(self, env_fn: Callable[[], gym.Env], self.process.start() self.child_remote.close() - def __getattr__(self, key: str): + def __getattr__(self, key: str) -> Any: self.parent_remote.send(['getattr', key]) return self.parent_remote.recv() From e920906f47921c374549eb28e804436a6b3979e9 Mon Sep 17 00:00:00 2001 From: Trinkle23897 <463003665@qq.com> Date: Sat, 22 Aug 2020 14:00:20 +0800 Subject: [PATCH 31/60] fix test --- test/base/test_buffer.py | 2 +- test/base/test_utils.py | 9 ++++++--- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/test/base/test_buffer.py b/test/base/test_buffer.py index 4476d66da..8b96466b8 100644 --- a/test/base/test_buffer.py +++ 
b/test/base/test_buffer.py @@ -26,7 +26,7 @@ def test_replaybuffer(size=10, bufsize=20): obs = obs_next assert len(buf) == min(bufsize, i + 1) with pytest.raises(ValueError): - buf._add_to_buffer('rew', [1, 2, 3]) + buf._add_to_buffer('rew', np.array([1, 2, 3])) data, indice = buf.sample(bufsize * 2) assert (indice < len(buf)).all() assert (data.obs < size).all() diff --git a/test/base/test_utils.py b/test/base/test_utils.py index 932f92a12..bdf697e9b 100644 --- a/test/base/test_utils.py +++ b/test/base/test_utils.py @@ -16,13 +16,16 @@ def test_noise(): def test_moving_average(): stat = MovAvg(10) + assert np.allclose(stat.get(), 0) + assert np.allclose(stat.mean(), 0) + assert np.allclose(stat.std() ** 2, 0) stat.add(torch.tensor([1])) stat.add(np.array([2])) stat.add([3, 4]) stat.add(5.) - assert np.allclose(stat.get(), 3.) - assert np.allclose(stat.mean(), 3.) - assert np.allclose(stat.std() ** 2, 2.) + assert np.allclose(stat.get(), 3) + assert np.allclose(stat.mean(), 3) + assert np.allclose(stat.std() ** 2, 2) if __name__ == '__main__': From a68fc99accb365d17ff42145b4c303556ad52f55 Mon Sep 17 00:00:00 2001 From: Trinkle23897 <463003665@qq.com> Date: Sat, 22 Aug 2020 16:01:53 +0800 Subject: [PATCH 32/60] fix RecurrentActorProb and add test --- test/base/test_utils.py | 35 ++++++++++++++++++++++++++++++++ tianshou/utils/net/continuous.py | 35 ++++++++++++++++++++++---------- tianshou/utils/net/discrete.py | 7 +++---- 3 files changed, 62 insertions(+), 15 deletions(-) diff --git a/test/base/test_utils.py b/test/base/test_utils.py index bdf697e9b..1f24dbd33 100644 --- a/test/base/test_utils.py +++ b/test/base/test_utils.py @@ -3,6 +3,9 @@ from tianshou.utils import MovAvg from tianshou.exploration import GaussianNoise, OUNoise +from tianshou.utils.net.common import Net +from tianshou.utils.net.discrete import DQN +from tianshou.utils.net.continuous import RecurrentActorProb, RecurrentCritic def test_noise(): @@ -28,6 +31,38 @@ def test_moving_average(): assert 
np.allclose(stat.std() ** 2, 2) +def test_net(): + # here test the networks that does not appear in the other script + bsz = 64 + # common net + state_shape = (10, 2) + action_shape = (5, ) + data = torch.rand([bsz, *state_shape]) + expect_output_shape = [bsz, *action_shape] + net = Net(3, state_shape, action_shape, norm_layer=torch.nn.LayerNorm) + assert list(net(data)[0].shape) == expect_output_shape + net = Net(3, state_shape, action_shape, dueling=(2, 2)) + assert list(net(data)[0].shape) == expect_output_shape + # recurrent actor/critic + data = data.flatten(1) + net = RecurrentActorProb(3, state_shape, action_shape) + mu, sigma = net(data)[0] + assert mu.shape == sigma.shape + assert list(mu.shape) == [bsz, 5] + net = RecurrentCritic(3, state_shape, action_shape) + data = torch.rand([bsz, 8, np.prod(state_shape)]) + act = torch.rand(expect_output_shape) + assert list(net(data, act).shape) == [bsz, 1] + # DQN + state_shape = (4, 84, 84) + action_shape = (6, ) + data = torch.rand([bsz, *state_shape]) + expect_output_shape = [bsz, *action_shape] + net = DQN(*state_shape, action_shape) + assert list(net(data)[0].shape) == expect_output_shape + + if __name__ == '__main__': test_noise() test_moving_average() + test_net() diff --git a/tianshou/utils/net/continuous.py b/tianshou/utils/net/continuous.py index bbd2d9655..5cde55011 100644 --- a/tianshou/utils/net/continuous.py +++ b/tianshou/utils/net/continuous.py @@ -10,8 +10,8 @@ class Actor(nn.Module): :ref:`build_the_network`. 
""" - def __init__(self, preprocess_net, action_shape, - max_action, device='cpu', hidden_layer_size=128): + def __init__(self, preprocess_net, action_shape, max_action=1., + device='cpu', hidden_layer_size=128): super().__init__() self.preprocess = preprocess_net self.last = nn.Linear(hidden_layer_size, np.prod(action_shape)) @@ -35,7 +35,7 @@ def __init__(self, preprocess_net, device='cpu', hidden_layer_size=128): self.preprocess = preprocess_net self.last = nn.Linear(hidden_layer_size, 1) - def forward(self, s, a=None, **kwargs): + def forward(self, s, a=None, info={}): """(s, a) -> logits -> Q(s, a)""" s = to_torch(s, device=self.device, dtype=torch.float32) s = s.flatten(1) @@ -53,7 +53,7 @@ class ActorProb(nn.Module): :ref:`build_the_network`. """ - def __init__(self, preprocess_net, action_shape, max_action, + def __init__(self, preprocess_net, action_shape, max_action=1., device='cpu', unbounded=False, hidden_layer_size=128): super().__init__() self.preprocess = preprocess_net @@ -63,7 +63,7 @@ def __init__(self, preprocess_net, action_shape, max_action, self._max = max_action self._unbounded = unbounded - def forward(self, s, state=None, **kwargs): + def forward(self, s, state=None, info={}): """s -> logits -> (mu, sigma)""" logits, h = self.preprocess(s, state) mu = self.mu(logits) @@ -80,8 +80,8 @@ class RecurrentActorProb(nn.Module): :ref:`build_the_network`. 
""" - def __init__(self, layer_num, state_shape, action_shape, - max_action, device='cpu', hidden_layer_size=128): + def __init__(self, layer_num, state_shape, action_shape, max_action=1., + device='cpu', unbounded=False, hidden_layer_size=128): super().__init__() self.device = device self.nn = nn.LSTM(input_size=np.prod(state_shape), @@ -89,8 +89,10 @@ def __init__(self, layer_num, state_shape, action_shape, num_layers=layer_num, batch_first=True) self.mu = nn.Linear(hidden_layer_size, np.prod(action_shape)) self.sigma = nn.Parameter(torch.zeros(np.prod(action_shape), 1)) + self._max = max_action + self._unbounded = unbounded - def forward(self, s, **kwargs): + def forward(self, s, state=None, info={}): """Almost the same as :class:`~tianshou.utils.net.common.Recurrent`.""" s = to_torch(s, device=self.device, dtype=torch.float32) # s [bsz, len, dim] (training) or [bsz, dim] (evaluation) @@ -98,13 +100,24 @@ def forward(self, s, **kwargs): # in evaluation phase. if len(s.shape) == 2: s = s.unsqueeze(-2) - logits, _ = self.nn(s) - logits = logits[:, -1] + self.nn.flatten_parameters() + if state is None: + s, (h, c) = self.nn(s) + else: + # we store the stack data in [bsz, len, ...] format + # but pytorch rnn needs [len, bsz, ...] + s, (h, c) = self.nn(s, (state['h'].transpose(0, 1).contiguous(), + state['c'].transpose(0, 1).contiguous())) + logits = s[:, -1] mu = self.mu(logits) + if not self._unbounded: + mu = self._max * torch.tanh(mu) shape = [1] * len(mu.shape) shape[1] = -1 sigma = (self.sigma.view(shape) + torch.zeros_like(mu)).exp() - return (mu, sigma), None + # please ensure the first dim is batch size: [bsz, len, ...] 
+ return (mu, sigma), {'h': h.transpose(0, 1).detach(), + 'c': c.transpose(0, 1).detach()} class RecurrentCritic(nn.Module): diff --git a/tianshou/utils/net/discrete.py b/tianshou/utils/net/discrete.py index afed6dfb5..c7fed2bcb 100644 --- a/tianshou/utils/net/discrete.py +++ b/tianshou/utils/net/discrete.py @@ -45,7 +45,7 @@ class DQN(nn.Module): Reference paper: "Human-level control through deep reinforcement learning". """ - def __init__(self, h, w, action_shape, device='cpu'): + def __init__(self, c, h, w, action_shape, device='cpu'): super(DQN, self).__init__() self.device = device @@ -66,7 +66,7 @@ def conv2d_layers_size_out(size, linear_input_size = convw * convh * 64 self.net = nn.Sequential( - nn.Conv2d(4, 32, kernel_size=8, stride=4), + nn.Conv2d(c, 32, kernel_size=8, stride=4), nn.ReLU(inplace=True), nn.Conv2d(32, 64, kernel_size=4, stride=2), nn.ReLU(inplace=True), @@ -74,12 +74,11 @@ def conv2d_layers_size_out(size, nn.ReLU(inplace=True), nn.Flatten(), nn.Linear(linear_input_size, 512), - nn.Linear(512, action_shape) + nn.Linear(512, np.prod(action_shape)) ) def forward(self, x, state=None, info={}): r"""x -> Q(x, \*)""" if not isinstance(x, torch.Tensor): x = torch.tensor(x, device=self.device, dtype=torch.float32) - x = x.permute(0, 3, 1, 2) return self.net(x), state From b26526572c6cef3f86c9b6c66520efab5909b225 Mon Sep 17 00:00:00 2001 From: Trinkle23897 <463003665@qq.com> Date: Sat, 22 Aug 2020 16:21:14 +0800 Subject: [PATCH 33/60] little increase --- test/base/test_utils.py | 2 +- tianshou/utils/net/continuous.py | 5 ++--- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/test/base/test_utils.py b/test/base/test_utils.py index 1f24dbd33..5944bfbe5 100644 --- a/test/base/test_utils.py +++ b/test/base/test_utils.py @@ -56,7 +56,7 @@ def test_net(): # DQN state_shape = (4, 84, 84) action_shape = (6, ) - data = torch.rand([bsz, *state_shape]) + data = np.random.rand(bsz, *state_shape) expect_output_shape = [bsz, *action_shape] net = 
DQN(*state_shape, action_shape) assert list(net(data)[0].shape) == expect_output_shape diff --git a/tianshou/utils/net/continuous.py b/tianshou/utils/net/continuous.py index 5cde55011..03a11f59c 100644 --- a/tianshou/utils/net/continuous.py +++ b/tianshou/utils/net/continuous.py @@ -2,7 +2,7 @@ import numpy as np from torch import nn -from tianshou.data import to_torch +from tianshou.data import to_torch, to_torch_as class Actor(nn.Module): @@ -147,8 +147,7 @@ def forward(self, s, a=None): s, (h, c) = self.nn(s) s = s[:, -1] if a is not None: - if not isinstance(a, torch.Tensor): - a = torch.tensor(a, device=self.device, dtype=torch.float32) + a = to_torch_as(a, s) s = torch.cat([s, a], dim=1) s = self.fc2(s) return s From 51336a88d917741de95ca1ec5c2e73634f6d36f9 Mon Sep 17 00:00:00 2001 From: Trinkle23897 <463003665@qq.com> Date: Sat, 22 Aug 2020 16:40:05 +0800 Subject: [PATCH 34/60] add a little test --- test/base/test_batch.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/test/base/test_batch.py b/test/base/test_batch.py index 226b26fca..f3f69c187 100644 --- a/test/base/test_batch.py +++ b/test/base/test_batch.py @@ -27,6 +27,8 @@ def test_batch(): # mimic the behavior of dict.update, where kwargs can overwrite keys b.update({'a': 2}, a=3) assert b.a == 3 and 'a' in b + assert b.pop('a') == 3 + assert 'a' not in b with pytest.raises(AssertionError): Batch({1: 2}) with pytest.raises(TypeError): From e7387ffb939c6e828557d80fe90cca7fa472a0cf Mon Sep 17 00:00:00 2001 From: Trinkle23897 <463003665@qq.com> Date: Sat, 22 Aug 2020 19:42:36 +0800 Subject: [PATCH 35/60] minor fix --- test/base/test_batch.py | 2 +- test/base/test_collector.py | 4 ++-- test/discrete/test_drqn.py | 2 +- tianshou/data/batch.py | 5 ++--- 4 files changed, 6 insertions(+), 7 deletions(-) diff --git a/test/base/test_batch.py b/test/base/test_batch.py index f3f69c187..12d3a7e22 100644 --- a/test/base/test_batch.py +++ b/test/base/test_batch.py @@ -26,7 +26,7 @@ def test_batch(): assert 
np.allclose(b.c, [3, 5]) # mimic the behavior of dict.update, where kwargs can overwrite keys b.update({'a': 2}, a=3) - assert b.a == 3 and 'a' in b + assert 'a' in b and b.a == 3 assert b.pop('a') == 3 assert 'a' not in b with pytest.raises(AssertionError): diff --git a/test/base/test_collector.py b/test/base/test_collector.py index 0ef3ed2dc..217531611 100644 --- a/test/base/test_collector.py +++ b/test/base/test_collector.py @@ -44,7 +44,7 @@ def preprocess_fn(self, **kwargs): # modify info before adding into the buffer, and recorded into tfb # if only obs exist -> reset # if obs/act/rew/done/... exist -> normal step - if 'rew' in kwargs.keys(): + if 'rew' in kwargs: n = len(kwargs['obs']) info = kwargs['info'] for i in range(n): @@ -60,7 +60,7 @@ def preprocess_fn(self, **kwargs): @staticmethod def single_preprocess_fn(**kwargs): # same as above, without tfb - if 'rew' in kwargs.keys(): + if 'rew' in kwargs: n = len(kwargs['obs']) info = kwargs['info'] for i in range(n): diff --git a/test/discrete/test_drqn.py b/test/discrete/test_drqn.py index e403c21a1..c22e0b151 100644 --- a/test/discrete/test_drqn.py +++ b/test/discrete/test_drqn.py @@ -16,7 +16,7 @@ def get_args(): parser = argparse.ArgumentParser() parser.add_argument('--task', type=str, default='CartPole-v0') - parser.add_argument('--seed', type=int, default=1) + parser.add_argument('--seed', type=int, default=0) parser.add_argument('--eps-test', type=float, default=0.05) parser.add_argument('--eps-train', type=float, default=0.1) parser.add_argument('--buffer-size', type=int, default=20000) diff --git a/tianshou/data/batch.py b/tianshou/data/batch.py index a50365991..e140b554f 100644 --- a/tianshou/data/batch.py +++ b/tianshou/data/batch.py @@ -127,7 +127,7 @@ def _parse_value(v: Any): return np.asanyarray(v) elif v is None or isinstance(v, np.ndarray) and \ issubclass(v.dtype.type, (np.bool_, np.number)) or \ - isinstance(v, (Batch, torch.Tensor)): # third often case + isinstance(v, torch.Tensor): # 
third often case return v elif isinstance(v, dict): return Batch(v) @@ -716,8 +716,7 @@ def split(self, size: int, shuffle: bool = True) -> Iterator['Batch']: """Split whole data into multiple small batches. - :param int size: divide the data batch with the given size, defaults to - ``None``. + :param int size: divide the data batch with the given size. :param bool shuffle: randomly shuffle the entire data batch if it is ``True``, otherwise remain in the same. Default to ``True``. """ From b710aed45631e84a14a5896a4aed7e831ef0d827 Mon Sep 17 00:00:00 2001 From: Trinkle23897 <463003665@qq.com> Date: Sun, 23 Aug 2020 06:51:41 +0800 Subject: [PATCH 36/60] merge_last in batch.split() (#185) --- test/base/test_batch.py | 25 +++++++++++++++++++++++-- tianshou/data/batch.py | 17 +++++++++++++---- 2 files changed, 36 insertions(+), 6 deletions(-) diff --git a/test/base/test_batch.py b/test/base/test_batch.py index 12d3a7e22..a70bf5407 100644 --- a/test/base/test_batch.py +++ b/test/base/test_batch.py @@ -64,6 +64,27 @@ def test_batch(): with pytest.raises(AttributeError): b.obs print(batch) + batch = Batch(a=np.arange(10)) + with pytest.raises(AssertionError): + list(batch.split(0)) + bs1 = list(batch.split(1, shuffle=False)) + assert len(bs1) == len(batch) + assert [b.a for b in bs1] == list(range(len(batch))) + bs1 = list(batch.split(1, shuffle=False, merge_last=True)) + # since 10 % 1 == 0, the merge_last will not work on split + assert [b.a for b in bs1] == list(range(len(batch))) + bs3 = list(batch.split(3, shuffle=False)) + assert [bs3[i].a.tolist() for i in range(len(bs3))] == [ + [0, 1, 2], [3, 4, 5], [6, 7, 8], [9]] + bs3 = list(batch.split(3, shuffle=False, merge_last=True)) + assert [bs3[i].a.tolist() for i in range(len(bs3))] == [ + [0, 1, 2], [3, 4, 5], [6, 7, 8, 9]] + bs7 = list(batch.split(7, shuffle=False)) + assert [bs7[i].a.tolist() for i in range(len(bs7))] == [ + [0, 1, 2, 3, 4, 5, 6], [7, 8, 9]] + bs7 = list(batch.split(7, shuffle=False, 
merge_last=True)) + assert [bs7[i].a.tolist() for i in range(len(bs7))] == [ + [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]] batch_dict = {'b': np.array([1.0]), 'c': 2.0, 'd': torch.Tensor([3.0])} batch_item = Batch({'a': [batch_dict]})[0] assert isinstance(batch_item.a.b, np.ndarray) @@ -376,13 +397,13 @@ def test_utils_to_torch_numpy(): assert isinstance(data_list_3_torch, list) assert all(isinstance(e, torch.Tensor) for e in data_list_3_torch) assert all(starmap(np.allclose, - zip(to_numpy(to_torch(data_list_3)), data_list_3))) + zip(to_numpy(to_torch(data_list_3)), data_list_3))) data_list_4 = [np.zeros((2, 3)), np.zeros((3, 3))] data_list_4_torch = to_torch(data_list_4) assert isinstance(data_list_4_torch, list) assert all(isinstance(e, torch.Tensor) for e in data_list_4_torch) assert all(starmap(np.allclose, - zip(to_numpy(to_torch(data_list_4)), data_list_4))) + zip(to_numpy(to_torch(data_list_4)), data_list_4))) data_list_5 = [np.zeros(2), np.zeros((3, 3))] data_list_5_torch = to_torch(data_list_5) assert isinstance(data_list_5_torch, list) diff --git a/tianshou/data/batch.py b/tianshou/data/batch.py index e140b554f..897ce06b5 100644 --- a/tianshou/data/batch.py +++ b/tianshou/data/batch.py @@ -712,18 +712,27 @@ def shape(self) -> List[int]: return list(map(min, zip(*data_shape))) if len(data_shape) > 1 \ else data_shape[0] - def split(self, size: int, - shuffle: bool = True) -> Iterator['Batch']: + def split(self, size: int, shuffle: bool = True, + merge_last: bool = False) -> Iterator['Batch']: """Split whole data into multiple small batches. :param int size: divide the data batch with the given size. :param bool shuffle: randomly shuffle the entire data batch if it is ``True``, otherwise remain in the same. Default to ``True``. + :param bool merge_last: merge the last batch into the previous one. + Default to ``False``. 
""" length = len(self) + assert 1 <= size <= length if shuffle: indices = np.random.permutation(length) else: indices = np.arange(length) - for idx in np.arange(0, length, size): - yield self[indices[idx:(idx + size)]] + count = length // size + (length % size > 0 and not merge_last) + if merge_last and length % size == 0: + merge_last = False + for idx in range(count): + if idx == count - 1 and merge_last: + yield self[indices[idx * size:]] + else: + yield self[indices[idx * size:(idx + 1) * size]] From d98ed7911ae65d91eb90914d2bf8f7522f82c306 Mon Sep 17 00:00:00 2001 From: Trinkle23897 <463003665@qq.com> Date: Sun, 23 Aug 2020 07:31:22 +0800 Subject: [PATCH 37/60] fix batch.split --- test/base/test_batch.py | 18 ++++++++++++++++++ tianshou/data/batch.py | 7 ++++--- tianshou/data/collector.py | 8 +++++--- 3 files changed, 27 insertions(+), 6 deletions(-) diff --git a/test/base/test_batch.py b/test/base/test_batch.py index a70bf5407..81fabe0c9 100644 --- a/test/base/test_batch.py +++ b/test/base/test_batch.py @@ -79,12 +79,30 @@ def test_batch(): bs3 = list(batch.split(3, shuffle=False, merge_last=True)) assert [bs3[i].a.tolist() for i in range(len(bs3))] == [ [0, 1, 2], [3, 4, 5], [6, 7, 8, 9]] + bs5 = list(batch.split(5, shuffle=False)) + assert [bs5[i].a.tolist() for i in range(len(bs5))] == [ + [0, 1, 2, 3, 4], [5, 6, 7, 8, 9]] + bs5 = list(batch.split(5, shuffle=False, merge_last=True)) + assert [bs5[i].a.tolist() for i in range(len(bs5))] == [ + [0, 1, 2, 3, 4], [5, 6, 7, 8, 9]] bs7 = list(batch.split(7, shuffle=False)) assert [bs7[i].a.tolist() for i in range(len(bs7))] == [ [0, 1, 2, 3, 4, 5, 6], [7, 8, 9]] bs7 = list(batch.split(7, shuffle=False, merge_last=True)) assert [bs7[i].a.tolist() for i in range(len(bs7))] == [ [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]] + bs10 = list(batch.split(10, shuffle=False)) + assert [bs10[i].a.tolist() for i in range(len(bs10))] == [ + [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]] + bs10 = list(batch.split(10, shuffle=False, merge_last=True)) + 
assert [bs10[i].a.tolist() for i in range(len(bs10))] == [ + [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]] + bs100 = list(batch.split(100, shuffle=False)) + assert [bs100[i].a.tolist() for i in range(len(bs100))] == [ + [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]] + bs100 = list(batch.split(100, shuffle=False, merge_last=True)) + assert [bs100[i].a.tolist() for i in range(len(bs100))] == [ + [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]] batch_dict = {'b': np.array([1.0]), 'c': 2.0, 'd': torch.Tensor([3.0])} batch_item = Batch({'a': [batch_dict]})[0] assert isinstance(batch_item.a.b, np.ndarray) diff --git a/tianshou/data/batch.py b/tianshou/data/batch.py index 897ce06b5..069bb2b56 100644 --- a/tianshou/data/batch.py +++ b/tianshou/data/batch.py @@ -723,13 +723,14 @@ def split(self, size: int, shuffle: bool = True, Default to ``False``. """ length = len(self) - assert 1 <= size <= length + assert 1 <= size # size can be greater than length, return whole batch if shuffle: indices = np.random.permutation(length) else: indices = np.arange(length) - count = length // size + (length % size > 0 and not merge_last) - if merge_last and length % size == 0: + count = length // size + \ + (length % size > 0 and not merge_last or length < size) + if merge_last and (length % size == 0 or count == 0): merge_last = False for idx in range(count): if idx == count - 1 and merge_last: diff --git a/tianshou/data/collector.py b/tianshou/data/collector.py index ed9babb8a..e7ff0d08e 100644 --- a/tianshou/data/collector.py +++ b/tianshou/data/collector.py @@ -213,8 +213,10 @@ def collect(self, finished_env_ids = [] reward_total = 0.0 whole_data = Batch() + list_n_episode = False if n_episode is not None and not np.isscalar(n_episode): assert len(n_episode) == self.get_env_num() + list_n_episode = True finished_env_ids = [ i for i in self._ready_env_ids if n_episode[i] <= 0] self._ready_env_ids = np.asarray( @@ -303,14 +305,14 @@ def collect(self, self._cached_buf[i].add(**self.data[j]) if done[j]: - if n_step or 
np.isscalar(n_episode) or \ - episode_count[i] < n_episode[i]: + if not (list_n_episode and + episode_count[i] >= n_episode[i]): episode_count[i] += 1 reward_total += np.sum(self._cached_buf[i].rew, axis=0) step_count += len(self._cached_buf[i]) if self.buffer is not None: self.buffer.update(self._cached_buf[i]) - if not n_step and not np.isscalar(n_episode) and \ + if list_n_episode and \ episode_count[i] >= n_episode[i]: # env i has collected enough data, it has finished finished_env_ids.append(i) From 25ddfd7974fc2c833cb7b32c810e1f7d3fef08ef Mon Sep 17 00:00:00 2001 From: Trinkle23897 <463003665@qq.com> Date: Sun, 23 Aug 2020 07:31:47 +0800 Subject: [PATCH 38/60] add merge_last in policy --- tianshou/policy/modelfree/a2c.py | 4 ++-- tianshou/policy/modelfree/pg.py | 2 +- tianshou/policy/modelfree/ppo.py | 4 ++-- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/tianshou/policy/modelfree/a2c.py b/tianshou/policy/modelfree/a2c.py index 1d835de67..146a97b47 100644 --- a/tianshou/policy/modelfree/a2c.py +++ b/tianshou/policy/modelfree/a2c.py @@ -68,7 +68,7 @@ def process_fn(self, batch: Batch, buffer: ReplayBuffer, batch, None, gamma=self._gamma, gae_lambda=self._lambda) v_ = [] with torch.no_grad(): - for b in batch.split(self._batch, shuffle=False): + for b in batch.split(self._batch, shuffle=False, merge_last=True): v_.append(to_numpy(self.critic(b.obs_next))) v_ = np.concatenate(v_, axis=0) return self.compute_episodic_return( @@ -104,7 +104,7 @@ def learn(self, batch: Batch, batch_size: int, repeat: int, **kwargs) -> Dict[str, List[float]]: losses, actor_losses, vf_losses, ent_losses = [], [], [], [] for _ in range(repeat): - for b in batch.split(batch_size): + for b in batch.split(batch_size, merge_last=True): self.optim.zero_grad() dist = self(b).dist v = self.critic(b.obs).flatten() diff --git a/tianshou/policy/modelfree/pg.py b/tianshou/policy/modelfree/pg.py index 8fded95ec..d07e66bf0 100644 --- a/tianshou/policy/modelfree/pg.py +++ 
b/tianshou/policy/modelfree/pg.py @@ -82,7 +82,7 @@ def learn(self, batch: Batch, batch_size: int, repeat: int, **kwargs) -> Dict[str, List[float]]: losses = [] for _ in range(repeat): - for b in batch.split(batch_size): + for b in batch.split(batch_size, merge_last=True): self.optim.zero_grad() dist = self(b).dist a = to_torch_as(b.act, dist.logits) diff --git a/tianshou/policy/modelfree/ppo.py b/tianshou/policy/modelfree/ppo.py index 3f1fea4fb..ab565951e 100644 --- a/tianshou/policy/modelfree/ppo.py +++ b/tianshou/policy/modelfree/ppo.py @@ -86,7 +86,7 @@ def process_fn(self, batch: Batch, buffer: ReplayBuffer, batch.rew = (batch.rew - mean) / std v, v_, old_log_prob = [], [], [] with torch.no_grad(): - for b in batch.split(self._batch, shuffle=False): + for b in batch.split(self._batch, shuffle=False, merge_last=True): v_.append(self.critic(b.obs_next)) v.append(self.critic(b.obs)) old_log_prob.append(self(b).dist.log_prob( @@ -137,7 +137,7 @@ def learn(self, batch: Batch, batch_size: int, repeat: int, **kwargs) -> Dict[str, List[float]]: losses, clip_losses, vf_losses, ent_losses = [], [], [], [] for _ in range(repeat): - for b in batch.split(batch_size): + for b in batch.split(batch_size, merge_last=True): dist = self(b).dist value = self.critic(b.obs).flatten() ratio = (dist.log_prob(b.act) - b.logp_old).exp().float() From e288d8c06236d2358a8069ebcce88dea1c1ef228 Mon Sep 17 00:00:00 2001 From: Trinkle23897 <463003665@qq.com> Date: Mon, 24 Aug 2020 07:50:32 +0800 Subject: [PATCH 39/60] change merge_last logic and add docs in preprocess_fn --- docs/tutorials/cheatsheet.rst | 2 +- test/discrete/test_drqn.py | 2 +- tianshou/data/batch.py | 10 +++++----- 3 files changed, 7 insertions(+), 7 deletions(-) diff --git a/docs/tutorials/cheatsheet.rst b/docs/tutorials/cheatsheet.rst index 18e86a084..e9d265a3e 100644 --- a/docs/tutorials/cheatsheet.rst +++ b/docs/tutorials/cheatsheet.rst @@ -96,7 +96,7 @@ This is related to `Issue 42 `_. 
If you want to get log stat from data stream / pre-process batch-image / modify the reward with given env info, use ``preproces_fn`` in :class:`~tianshou.data.Collector`. This is a hook which will be called before the data adding into the buffer. -This function receives typically up to 7 keys, as listed in :class:`~tianshou.data.Batch`, and returns the modified part within a Batch. For example, you can write your hook as: +This function receives typically up to 7 keys, as listed in :class:`~tianshou.data.Batch`, and returns the modified part within a :class:`~tianshou.data.Batch`. If it receives only ``obs``, it is applied in the env reset; otherwise, it is applied in the normal env step. For example, you can write your hook as: :: import numpy as np diff --git a/test/discrete/test_drqn.py b/test/discrete/test_drqn.py index c22e0b151..e403c21a1 100644 --- a/test/discrete/test_drqn.py +++ b/test/discrete/test_drqn.py @@ -16,7 +16,7 @@ def get_args(): parser = argparse.ArgumentParser() parser.add_argument('--task', type=str, default='CartPole-v0') - parser.add_argument('--seed', type=int, default=0) + parser.add_argument('--seed', type=int, default=1) parser.add_argument('--eps-test', type=float, default=0.05) parser.add_argument('--eps-train', type=float, default=0.1) parser.add_argument('--buffer-size', type=int, default=20000) diff --git a/tianshou/data/batch.py b/tianshou/data/batch.py index 069bb2b56..a01264b3d 100644 --- a/tianshou/data/batch.py +++ b/tianshou/data/batch.py @@ -716,7 +716,8 @@ def split(self, size: int, shuffle: bool = True, merge_last: bool = False) -> Iterator['Batch']: """Split whole data into multiple small batches. - :param int size: divide the data batch with the given size. + :param int size: divide the data batch with the given size, but one + batch if the length of the batch is smaller than ``size``. :param bool shuffle: randomly shuffle the entire data batch if it is ``True``, otherwise remain in the same. Default to ``True``. 
:param bool merge_last: merge the last batch into the previous one. @@ -728,10 +729,9 @@ def split(self, size: int, shuffle: bool = True, indices = np.random.permutation(length) else: indices = np.arange(length) - count = length // size + \ - (length % size > 0 and not merge_last or length < size) - if merge_last and (length % size == 0 or count == 0): - merge_last = False + merge_last = merge_last and length % size > 0 and length >= size + count = length // size + (1 - merge_last) * \ + (length % size > 0 or length < size) for idx in range(count): if idx == count - 1 and merge_last: yield self[indices[idx * size:]] From c3d2bad0270d918caa6dd0ec6c0c47d726bebf9c Mon Sep 17 00:00:00 2001 From: Trinkle23897 <463003665@qq.com> Date: Mon, 24 Aug 2020 09:31:49 +0800 Subject: [PATCH 40/60] test_pg is too slow --- test/discrete/test_pg.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/discrete/test_pg.py b/test/discrete/test_pg.py index b3bacd782..3604adbc6 100644 --- a/test/discrete/test_pg.py +++ b/test/discrete/test_pg.py @@ -16,10 +16,10 @@ def get_args(): parser = argparse.ArgumentParser() parser.add_argument('--task', type=str, default='CartPole-v0') - parser.add_argument('--seed', type=int, default=1626) + parser.add_argument('--seed', type=int, default=0) parser.add_argument('--buffer-size', type=int, default=20000) parser.add_argument('--lr', type=float, default=1e-3) - parser.add_argument('--gamma', type=float, default=0.9) + parser.add_argument('--gamma', type=float, default=0.95) parser.add_argument('--epoch', type=int, default=10) parser.add_argument('--step-per-epoch', type=int, default=1000) parser.add_argument('--collect-per-step', type=int, default=10) From da6277cf758e7bf9c8d22f17e0f6ef3b09e7dded Mon Sep 17 00:00:00 2001 From: Trinkle23897 <463003665@qq.com> Date: Mon, 24 Aug 2020 10:26:58 +0800 Subject: [PATCH 41/60] fix tensorboard logging --- tianshou/trainer/offpolicy.py | 12 ++++++------ tianshou/trainer/onpolicy.py | 10 
+++++----- tianshou/trainer/utils.py | 11 +++++++++-- 3 files changed, 20 insertions(+), 13 deletions(-) diff --git a/tianshou/trainer/offpolicy.py b/tianshou/trainer/offpolicy.py index 171cbb9da..e1b722542 100644 --- a/tianshou/trainer/offpolicy.py +++ b/tianshou/trainer/offpolicy.py @@ -88,7 +88,7 @@ def offpolicy_trainer( if test_in_train and stop_fn and stop_fn(result['rew']): test_result = test_episode( policy, test_collector, test_fn, - epoch, episode_per_test) + epoch, episode_per_test, writer, global_step) if stop_fn and stop_fn(test_result['rew']): if save_fn: save_fn(policy) @@ -104,13 +104,13 @@ def offpolicy_trainer( train_fn(epoch) for i in range(update_per_step * min( result['n/st'] // collect_per_step, t.total - t.n)): - global_step += 1 + global_step += collect_per_step losses = policy.update(batch_size, train_collector.buffer) for k in result.keys(): data[k] = f'{result[k]:.2f}' if writer and global_step % log_interval == 0: - writer.add_scalar( - k, result[k], global_step=global_step) + writer.add_scalar('train/' + k, result[k], + global_step=global_step) for k in losses.keys(): if stat.get(k) is None: stat[k] = MovAvg() @@ -124,8 +124,8 @@ def offpolicy_trainer( if t.n <= t.total: t.update() # test - result = test_episode( - policy, test_collector, test_fn, epoch, episode_per_test) + result = test_episode(policy, test_collector, test_fn, epoch, + episode_per_test, writer, global_step) if best_epoch == -1 or best_reward < result['rew']: best_reward = result['rew'] best_epoch = epoch diff --git a/tianshou/trainer/onpolicy.py b/tianshou/trainer/onpolicy.py index e31724d66..f3df25425 100644 --- a/tianshou/trainer/onpolicy.py +++ b/tianshou/trainer/onpolicy.py @@ -88,7 +88,7 @@ def onpolicy_trainer( if test_in_train and stop_fn and stop_fn(result['rew']): test_result = test_episode( policy, test_collector, test_fn, - epoch, episode_per_test) + epoch, episode_per_test, writer, global_step) if stop_fn and stop_fn(test_result['rew']): if save_fn: 
save_fn(policy) @@ -109,12 +109,12 @@ def onpolicy_trainer( for k in losses.keys(): if isinstance(losses[k], list): step = max(step, len(losses[k])) - global_step += step + global_step += step * collect_per_step for k in result.keys(): data[k] = f'{result[k]:.2f}' if writer and global_step % log_interval == 0: writer.add_scalar( - k, result[k], global_step=global_step) + 'train/' + k, result[k], global_step=global_step) for k in losses.keys(): if stat.get(k) is None: stat[k] = MovAvg() @@ -128,8 +128,8 @@ def onpolicy_trainer( if t.n <= t.total: t.update() # test - result = test_episode( - policy, test_collector, test_fn, epoch, episode_per_test) + result = test_episode(policy, test_collector, test_fn, epoch, + episode_per_test, writer, global_step) if best_epoch == -1 or best_reward < result['rew']: best_reward = result['rew'] best_epoch = epoch diff --git a/tianshou/trainer/utils.py b/tianshou/trainer/utils.py index eb9bd3245..32b936ba5 100644 --- a/tianshou/trainer/utils.py +++ b/tianshou/trainer/utils.py @@ -1,6 +1,7 @@ import time import numpy as np from typing import Dict, List, Union, Callable +from torch.utils.tensorboard import SummaryWriter from tianshou.data import Collector from tianshou.policy import BasePolicy @@ -11,7 +12,9 @@ def test_episode( collector: Collector, test_fn: Callable[[int], None], epoch: int, - n_episode: Union[int, List[int]]) -> Dict[str, float]: + n_episode: Union[int, List[int]], + writer: SummaryWriter = None, + global_step: int = None) -> Dict[str, float]: """A simple wrapper of testing policy in collector.""" collector.reset_env() collector.reset_buffer() @@ -23,7 +26,11 @@ def test_episode( n_ = np.zeros(n) + n_episode // n n_[:n_episode % n] += 1 n_episode = list(n_) - return collector.collect(n_episode=n_episode) + result = collector.collect(n_episode=n_episode) + if writer is not None and global_step is not None: + for k in result.keys(): + writer.add_scalar('test/' + k, result[k], global_step=global_step) + return result 
def gather_info(start_time: float, From 50aa5cf620edf4bb88963a8846df242854f764ab Mon Sep 17 00:00:00 2001 From: Trinkle23897 <463003665@qq.com> Date: Mon, 24 Aug 2020 10:34:34 +0800 Subject: [PATCH 42/60] add a check of buffer --- test/base/test_buffer.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/test/base/test_buffer.py b/test/base/test_buffer.py index 8b96466b8..16bf5c34f 100644 --- a/test/base/test_buffer.py +++ b/test/base/test_buffer.py @@ -27,6 +27,8 @@ def test_replaybuffer(size=10, bufsize=20): assert len(buf) == min(bufsize, i + 1) with pytest.raises(ValueError): buf._add_to_buffer('rew', np.array([1, 2, 3])) + assert buf.act.dtype == np.object + assert isinstance(buf.act[0], list) data, indice = buf.sample(bufsize * 2) assert (indice < len(buf)).all() assert (data.obs < size).all() From 4b90396610ca8edfc020e2c8efce2489731b3978 Mon Sep 17 00:00:00 2001 From: Trinkle23897 <463003665@qq.com> Date: Mon, 24 Aug 2020 10:54:09 +0800 Subject: [PATCH 43/60] size 2000 -> 256 --- tianshou/policy/modelfree/a2c.py | 2 +- tianshou/policy/modelfree/ppo.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/tianshou/policy/modelfree/a2c.py b/tianshou/policy/modelfree/a2c.py index 146a97b47..1050ad8ab 100644 --- a/tianshou/policy/modelfree/a2c.py +++ b/tianshou/policy/modelfree/a2c.py @@ -48,7 +48,7 @@ def __init__(self, max_grad_norm: Optional[float] = None, gae_lambda: float = 0.95, reward_normalization: bool = False, - max_batchsize: int = 2000, + max_batchsize: int = 256, **kwargs) -> None: super().__init__(None, optim, dist_fn, discount_factor, **kwargs) self.actor = actor diff --git a/tianshou/policy/modelfree/ppo.py b/tianshou/policy/modelfree/ppo.py index ab565951e..2b0e9586d 100644 --- a/tianshou/policy/modelfree/ppo.py +++ b/tianshou/policy/modelfree/ppo.py @@ -58,7 +58,7 @@ def __init__(self, dual_clip: Optional[float] = None, value_clip: bool = True, reward_normalization: bool = True, - max_batchsize: int = 2000, + max_batchsize: int 
= 256, **kwargs) -> None: super().__init__(None, None, dist_fn, discount_factor, **kwargs) self._max_grad_norm = max_grad_norm From c8ed82d99805db54efeb8f1a6b9c020b66af9053 Mon Sep 17 00:00:00 2001 From: Trinkle23897 <463003665@qq.com> Date: Mon, 24 Aug 2020 12:00:18 +0800 Subject: [PATCH 44/60] simplify test batch.split --- test/base/test_batch.py | 53 +++++++++++++---------------------------- 1 file changed, 17 insertions(+), 36 deletions(-) diff --git a/test/base/test_batch.py b/test/base/test_batch.py index 81fabe0c9..ce85afcdf 100644 --- a/test/base/test_batch.py +++ b/test/base/test_batch.py @@ -67,42 +67,23 @@ def test_batch(): batch = Batch(a=np.arange(10)) with pytest.raises(AssertionError): list(batch.split(0)) - bs1 = list(batch.split(1, shuffle=False)) - assert len(bs1) == len(batch) - assert [b.a for b in bs1] == list(range(len(batch))) - bs1 = list(batch.split(1, shuffle=False, merge_last=True)) - # since 10 % 1 == 0, the merge_last will not work on split - assert [b.a for b in bs1] == list(range(len(batch))) - bs3 = list(batch.split(3, shuffle=False)) - assert [bs3[i].a.tolist() for i in range(len(bs3))] == [ - [0, 1, 2], [3, 4, 5], [6, 7, 8], [9]] - bs3 = list(batch.split(3, shuffle=False, merge_last=True)) - assert [bs3[i].a.tolist() for i in range(len(bs3))] == [ - [0, 1, 2], [3, 4, 5], [6, 7, 8, 9]] - bs5 = list(batch.split(5, shuffle=False)) - assert [bs5[i].a.tolist() for i in range(len(bs5))] == [ - [0, 1, 2, 3, 4], [5, 6, 7, 8, 9]] - bs5 = list(batch.split(5, shuffle=False, merge_last=True)) - assert [bs5[i].a.tolist() for i in range(len(bs5))] == [ - [0, 1, 2, 3, 4], [5, 6, 7, 8, 9]] - bs7 = list(batch.split(7, shuffle=False)) - assert [bs7[i].a.tolist() for i in range(len(bs7))] == [ - [0, 1, 2, 3, 4, 5, 6], [7, 8, 9]] - bs7 = list(batch.split(7, shuffle=False, merge_last=True)) - assert [bs7[i].a.tolist() for i in range(len(bs7))] == [ - [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]] - bs10 = list(batch.split(10, shuffle=False)) - assert 
[bs10[i].a.tolist() for i in range(len(bs10))] == [ - [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]] - bs10 = list(batch.split(10, shuffle=False, merge_last=True)) - assert [bs10[i].a.tolist() for i in range(len(bs10))] == [ - [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]] - bs100 = list(batch.split(100, shuffle=False)) - assert [bs100[i].a.tolist() for i in range(len(bs100))] == [ - [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]] - bs100 = list(batch.split(100, shuffle=False, merge_last=True)) - assert [bs100[i].a.tolist() for i in range(len(bs100))] == [ - [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]] + data = [ + (1, False, [[0], [1], [2], [3], [4], [5], [6], [7], [8], [9]]), + (1, True, [[0], [1], [2], [3], [4], [5], [6], [7], [8], [9]]), + (3, False, [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]), + (3, True, [[0, 1, 2], [3, 4, 5], [6, 7, 8, 9]]), + (5, False, [[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]]), + (5, True, [[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]]), + (7, False, [[0, 1, 2, 3, 4, 5, 6], [7, 8, 9]]), + (7, True, [[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]]), + (10, False, [[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]]), + (10, True, [[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]]), + (100, False, [[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]]), + (100, True, [[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]]), + ] + for size, merge_last, result in data: + bs = list(batch.split(size, shuffle=False, merge_last=merge_last)) + assert [bs[i].a.tolist() for i in range(len(bs))] == result batch_dict = {'b': np.array([1.0]), 'c': 2.0, 'd': torch.Tensor([3.0])} batch_item = Batch({'a': [batch_dict]})[0] assert isinstance(batch_item.a.b, np.ndarray) From 96bc6903c6cf9a598082a468cf7c0a2dee2c3b56 Mon Sep 17 00:00:00 2001 From: Trinkle23897 <463003665@qq.com> Date: Tue, 25 Aug 2020 06:58:49 +0800 Subject: [PATCH 45/60] fix docstring --- tianshou/policy/modelfree/a2c.py | 2 +- tianshou/policy/modelfree/ppo.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/tianshou/policy/modelfree/a2c.py b/tianshou/policy/modelfree/a2c.py index 1050ad8ab..4ae122502 100644 --- a/tianshou/policy/modelfree/a2c.py 
+++ b/tianshou/policy/modelfree/a2c.py @@ -28,7 +28,7 @@ class A2CPolicy(PGPolicy): :param bool reward_normalization: normalize the reward to Normal(0, 1), defaults to ``False``. :param int max_batchsize: the maximum number of batchsize when computing - GAE, defaults to 2000. + GAE, defaults to 256. .. seealso:: diff --git a/tianshou/policy/modelfree/ppo.py b/tianshou/policy/modelfree/ppo.py index 2b0e9586d..1d89188ae 100644 --- a/tianshou/policy/modelfree/ppo.py +++ b/tianshou/policy/modelfree/ppo.py @@ -35,7 +35,7 @@ class PPOPolicy(PGPolicy): :param bool reward_normalization: normalize the returns to Normal(0, 1), defaults to ``True``. :param int max_batchsize: the maximum number of batchsize when computing - GAE, defaults to 2000. + GAE, defaults to 256. .. seealso:: From 20b7b488e9fdc74cd33d1417521b6ab3d09bd757 Mon Sep 17 00:00:00 2001 From: Trinkle23897 <463003665@qq.com> Date: Tue, 25 Aug 2020 13:39:40 +0800 Subject: [PATCH 46/60] simplify batch.split --- test/base/test_batch.py | 2 ++ tianshou/data/batch.py | 14 ++++++-------- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/test/base/test_batch.py b/test/base/test_batch.py index ce85afcdf..25e50d40d 100644 --- a/test/base/test_batch.py +++ b/test/base/test_batch.py @@ -78,6 +78,8 @@ def test_batch(): (7, True, [[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]]), (10, False, [[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]]), (10, True, [[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]]), + (15, False, [[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]]), + (15, True, [[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]]), (100, False, [[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]]), (100, True, [[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]]), ] diff --git a/tianshou/data/batch.py b/tianshou/data/batch.py index a01264b3d..22305ee74 100644 --- a/tianshou/data/batch.py +++ b/tianshou/data/batch.py @@ -729,11 +729,9 @@ def split(self, size: int, shuffle: bool = True, indices = np.random.permutation(length) else: indices = np.arange(length) - merge_last = merge_last and length % size > 0 and length >= size - count = 
length // size + (1 - merge_last) * \ - (length % size > 0 or length < size) - for idx in range(count): - if idx == count - 1 and merge_last: - yield self[indices[idx * size:]] - else: - yield self[indices[idx * size:(idx + 1) * size]] + merge_last = merge_last and length % size > 0 + for idx in range(0, length, size): + if merge_last and idx + size + size >= length: + yield self[indices[idx:]] + break + yield self[indices[idx:idx + size]] From 340931f5a364bc841f70cd1f459b9d5b56153bc2 Mon Sep 17 00:00:00 2001 From: Trinkle23897 <463003665@qq.com> Date: Wed, 26 Aug 2020 18:34:24 +0800 Subject: [PATCH 47/60] optimize for batch.{cat/stack/empty} --- tianshou/data/batch.py | 31 ++++++++++++++----------------- 1 file changed, 14 insertions(+), 17 deletions(-) diff --git a/tianshou/data/batch.py b/tianshou/data/batch.py index 22305ee74..148103b9b 100644 --- a/tianshou/data/batch.py +++ b/tianshou/data/batch.py @@ -421,8 +421,7 @@ def __cat(self, # cat Batch(a=np.zeros((3, 4))) and Batch(a=Batch(b=Batch())) # will fail here v = np.concatenate(v) - v = _to_array_with_correct_type(v) - self.__dict__[k] = v + self.__dict__[k] = _to_array_with_correct_type(v) keys_total = set.union(*[set(b.keys()) for b in batches]) keys_reserve_or_partial = set.difference(keys_total, keys_shared) # keys that are reserved in all batches @@ -442,8 +441,8 @@ def __cat(self, try: self.__dict__[k][sum_lens[i]:sum_lens[i + 1]] = val except KeyError: - self.__dict__[k] = \ - _create_value(val, sum_lens[-1], stack=False) + self.__dict__[k] = _create_value( + val, sum_lens[-1], stack=False) self.__dict__[k][sum_lens[i]:sum_lens[i + 1]] = val def cat_(self, @@ -516,14 +515,13 @@ def stack_(self, keys_shared = set.intersection(*keys_map) values_shared = [[e[k] for e in batches] for k in keys_shared] for k, v in zip(keys_shared, values_shared): - if all(isinstance(e, (dict, Batch)) for e in v): - self.__dict__[k] = Batch.stack(v, axis) - elif all(isinstance(e, torch.Tensor) for e in v): + if 
all(isinstance(e, torch.Tensor) for e in v): # second often self.__dict__[k] = torch.stack(v, axis) - else: + elif all(isinstance(e, (Batch, dict)) for e in v): # third often + self.__dict__[k] = Batch.stack(v, axis) + else: # most often case is np.ndarray v = np.stack(v, axis) - v = _to_array_with_correct_type(v) - self.__dict__[k] = v + self.__dict__[k] = _to_array_with_correct_type(v) # all the keys keys_total = set.union(*[set(b.keys()) for b in batches]) # keys that are reserved in all batches @@ -549,8 +547,7 @@ def stack_(self, try: self.__dict__[k][i] = val except KeyError: - self.__dict__[k] = \ - _create_value(val, len(batches)) + self.__dict__[k] = _create_value(val, len(batches)) self.__dict__[k][i] = val @staticmethod @@ -607,17 +604,17 @@ def empty_(self, index: Union[ ) """ for k, v in self.items(): - if v is None: - continue - if isinstance(v, Batch): - self.__dict__[k].empty_(index=index) - elif isinstance(v, torch.Tensor): + if isinstance(v, torch.Tensor): # most often case self.__dict__[k][index] = 0 + elif v is None: + continue elif isinstance(v, np.ndarray): if v.dtype == np.object: self.__dict__[k][index] = None else: self.__dict__[k][index] = 0 + elif isinstance(v, Batch): + self.__dict__[k].empty_(index=index) else: # scalar value warnings.warn('You are calling Batch.empty on a NumPy scalar, ' 'which may cause undefined behaviors.') From 1ca2d9838b9da9d5d4da1e18ad99e33458319a6c Mon Sep 17 00:00:00 2001 From: Trinkle23897 <463003665@qq.com> Date: Wed, 26 Aug 2020 20:25:12 +0800 Subject: [PATCH 48/60] remove buffer **kwargs --- test/base/test_batch.py | 3 +-- test/throughput/test_buffer_profile.py | 4 +--- test/throughput/test_collector_profile.py | 9 +++------ tianshou/data/buffer.py | 2 +- 4 files changed, 6 insertions(+), 12 deletions(-) diff --git a/test/base/test_batch.py b/test/base/test_batch.py index 25e50d40d..650d56080 100644 --- a/test/base/test_batch.py +++ b/test/base/test_batch.py @@ -124,8 +124,7 @@ def test_batch(): assert 
batch2_from_comp.a.b == batch2.a.b assert batch2_from_comp.a.c == batch2.a.c assert batch2_from_comp.a.d.e == batch2.a.d.e - for batch_slice in [ - batch2[slice(0, 1)], batch2[:1], batch2[0:]]: + for batch_slice in [batch2[slice(0, 1)], batch2[:1], batch2[0:]]: assert batch_slice.a.b == batch2.a.b assert batch_slice.a.c == batch2.a.c assert batch_slice.a.d.e == batch2.a.d.e diff --git a/test/throughput/test_buffer_profile.py b/test/throughput/test_buffer_profile.py index b37444681..3134004f1 100644 --- a/test/throughput/test_buffer_profile.py +++ b/test/throughput/test_buffer_profile.py @@ -28,9 +28,7 @@ def data(): def test_init(): for _ in np.arange(1e5): _ = ReplayBuffer(1e5) - _ = PrioritizedReplayBuffer( - size=int(1e5), alpha=0.5, - beta=0.5, repeat_sample=True) + _ = PrioritizedReplayBuffer(size=int(1e5), alpha=0.5, beta=0.5) _ = ListReplayBuffer() diff --git a/test/throughput/test_collector_profile.py b/test/throughput/test_collector_profile.py index f9d8a3e4e..21260f5ec 100644 --- a/test/throughput/test_collector_profile.py +++ b/test/throughput/test_collector_profile.py @@ -22,8 +22,7 @@ def __init__(self): def reset(self): self._index = 0 self.done = np.random.randint(3, high=200) - return {'observable': np.zeros((10, 10, 1)), - 'hidden': self._index} + return {'observable': np.zeros((10, 10, 1)), 'hidden': self._index} def step(self, action): if self._index == self.done: @@ -56,11 +55,9 @@ def data(): np.random.seed(0) env = SimpleEnv() env.seed(0) - env_vec = DummyVectorEnv( - [lambda: SimpleEnv() for _ in range(100)]) + env_vec = DummyVectorEnv([lambda: SimpleEnv() for _ in range(100)]) env_vec.seed(np.random.randint(1000, size=100).tolist()) - env_subproc = SubprocVectorEnv( - [lambda: SimpleEnv() for _ in range(8)]) + env_subproc = SubprocVectorEnv([lambda: SimpleEnv() for _ in range(8)]) env_subproc.seed(np.random.randint(1000, size=100).tolist()) env_subproc_init = SubprocVectorEnv( [lambda: SimpleEnv() for _ in range(8)]) diff --git 
a/tianshou/data/buffer.py b/tianshou/data/buffer.py index 7bea14eff..f41785066 100644 --- a/tianshou/data/buffer.py +++ b/tianshou/data/buffer.py @@ -122,7 +122,7 @@ class ReplayBuffer: def __init__(self, size: int, stack_num: int = 1, ignore_obs_next: bool = False, - sample_avail: bool = False, **kwargs) -> None: + sample_avail: bool = False) -> None: super().__init__() self._maxsize = size self._indices = np.arange(size) From 9dce2c70d1fb1c01bcfae537989ef312206fe3b8 Mon Sep 17 00:00:00 2001 From: Trinkle23897 <463003665@qq.com> Date: Wed, 26 Aug 2020 21:14:06 +0800 Subject: [PATCH 49/60] fix some type --- setup.py | 4 +++- tianshou/env/venvs.py | 18 +++++++++--------- tianshou/env/worker/base.py | 1 + tianshou/exploration/random.py | 2 +- 4 files changed, 14 insertions(+), 11 deletions(-) diff --git a/setup.py b/setup.py index 64aac40b2..40873b97d 100644 --- a/setup.py +++ b/setup.py @@ -10,7 +10,9 @@ # Get the version string with open(path.join(here, 'tianshou', '__init__.py')) as f: - version = re.search(r'__version__ = \'(.*?)\'', f.read()).group(1) + raw = re.search(r'__version__ = \'(.*?)\'', f.read()) + assert raw is not None + version = raw.group(1) setup( name='tianshou', diff --git a/tianshou/env/venvs.py b/tianshou/env/venvs.py index 7e0f9c17b..4163b2565 100644 --- a/tianshou/env/venvs.py +++ b/tianshou/env/venvs.py @@ -120,12 +120,11 @@ def _wrap_id( self, id: Optional[Union[int, List[int]]] = None) -> List[int]: if id is None: id = list(range(self.env_num)) - elif np.isscalar(id): + elif not isinstance(id, list): id = [id] return id - def _assert_id( - self, id: Optional[Union[int, List[int]]] = None) -> List[int]: + def _assert_id(self, id: List[int]) -> None: for i in id: assert i not in self.waiting_id, \ f'Cannot interact with environment {i} which is stepping now.' 
@@ -145,15 +144,16 @@ def reset(self, id: Optional[Union[int, List[int]]] = None) -> np.ndarray: return obs def step(self, - action: Optional[np.ndarray], + action: np.ndarray, id: Optional[Union[int, List[int]]] = None - ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: + ) -> List[np.ndarray]: """Run one timestep of all the environments’ dynamics if id is "None", otherwise run one timestep for some environments with given id, either an int or a list. When the end of episode is reached, you are responsible for calling reset(id) to reset this environment’s state. - Accept a batch of action and return a tuple (obs, rew, done, info). + Accept a batch of action and return a tuple (obs, rew, done, info) in + numpy format. :param numpy.ndarray action: a batch of action provided by the agent. @@ -222,10 +222,10 @@ def seed(self, which a reproducer pass to "seed". """ self._assert_is_not_closed() - if np.isscalar(seed): - seed = [seed + _ for _ in range(self.env_num)] - elif seed is None: + if seed is None: seed = [seed] * self.env_num + elif not isinstance(seed, list): + seed = [seed + _ for _ in range(self.env_num)] return [w.seed(s) for w, s in zip(self.workers, seed)] def render(self, **kwargs) -> List[Any]: diff --git a/tianshou/env/worker/base.py b/tianshou/env/worker/base.py index 87fb6c2e8..2b56dab9b 100644 --- a/tianshou/env/worker/base.py +++ b/tianshou/env/worker/base.py @@ -10,6 +10,7 @@ class EnvWorker(ABC): def __init__(self, env_fn: Callable[[], gym.Env]) -> None: self._env_fn = env_fn self.is_closed = False + self.result = (None, None, None, None) @abstractmethod def __getattr__(self, key: str) -> Any: diff --git a/tianshou/exploration/random.py b/tianshou/exploration/random.py index 34fd50399..19f4424cc 100644 --- a/tianshou/exploration/random.py +++ b/tianshou/exploration/random.py @@ -14,7 +14,7 @@ def __call__(self, **kwargs) -> np.ndarray: """Generate new noise.""" raise NotImplementedError - def reset(self, **kwargs) -> None: + def 
reset(self) -> None: """Reset to the initial state.""" pass From df4129f5d3595d3eace3469dac368799653c82fb Mon Sep 17 00:00:00 2001 From: Trinkle23897 <463003665@qq.com> Date: Wed, 26 Aug 2020 21:34:21 +0800 Subject: [PATCH 50/60] fix --- tianshou/data/batch.py | 9 +++++---- tianshou/env/venvs.py | 13 +++++++------ 2 files changed, 12 insertions(+), 10 deletions(-) diff --git a/tianshou/data/batch.py b/tianshou/data/batch.py index 148103b9b..60dc7a21d 100644 --- a/tianshou/data/batch.py +++ b/tianshou/data/batch.py @@ -5,7 +5,8 @@ from copy import deepcopy from numbers import Number from collections.abc import Collection -from typing import Any, List, Tuple, Union, Iterator, Optional +from typing import Any, List, Tuple, Union, Iterator, KeysView, ValuesView, \ + ItemsView, Optional # Disable pickle warning related to torch, since it has been removed # on torch master branch. See Pull Request #39003 for details: @@ -318,15 +319,15 @@ def __contains__(self, key: str) -> bool: """Return key in self.""" return key in self.__dict__ - def keys(self) -> List[str]: + def keys(self) -> KeysView[str]: """Return self.keys().""" return self.__dict__.keys() - def values(self) -> List[Any]: + def values(self) -> ValuesView[Any]: """Return self.values().""" return self.__dict__.values() - def items(self) -> List[Tuple[str, Any]]: + def items(self) -> ItemsView[str, Any]: """Return self.items().""" return self.__dict__.items() diff --git a/tianshou/env/venvs.py b/tianshou/env/venvs.py index 4163b2565..f11b86cfe 100644 --- a/tianshou/env/venvs.py +++ b/tianshou/env/venvs.py @@ -1,7 +1,7 @@ import gym import warnings import numpy as np -from typing import List, Tuple, Union, Optional, Callable, Any +from typing import List, Union, Optional, Callable, Any from tianshou.env.worker import EnvWorker, DummyEnvWorker, SubprocEnvWorker, \ RayEnvWorker @@ -116,11 +116,11 @@ def __getattr__(self, key: str) -> Any: """ return [getattr(worker, key) for worker in self.workers] - def _wrap_id( - 
self, id: Optional[Union[int, List[int]]] = None) -> List[int]: + def _wrap_id(self, id: Optional[Union[int, List[int], np.ndarray]] = None + ) -> List[int]: if id is None: id = list(range(self.env_num)) - elif not isinstance(id, list): + elif np.isscalar(id): id = [id] return id @@ -131,7 +131,8 @@ def _assert_id(self, id: List[int]) -> None: assert i in self.ready_id, \ f'Can only interact with ready environments {self.ready_id}.' - def reset(self, id: Optional[Union[int, List[int]]] = None) -> np.ndarray: + def reset(self, id: Optional[Union[int, List[int], np.ndarray]] = None + ) -> np.ndarray: """Reset the state of all the environments and return initial observations if id is ``None``, otherwise reset the specific environments with the given id, either an int or a list. @@ -145,7 +146,7 @@ def reset(self, id: Optional[Union[int, List[int]]] = None) -> np.ndarray: def step(self, action: np.ndarray, - id: Optional[Union[int, List[int]]] = None + id: Optional[Union[int, List[int], np.ndarray]] = None ) -> List[np.ndarray]: """Run one timestep of all the environments’ dynamics if id is "None", otherwise run one timestep for some environments with given id, either From c1dade5da4230be532b5be9d60b2b34d20e852c6 Mon Sep 17 00:00:00 2001 From: Trinkle23897 <463003665@qq.com> Date: Wed, 26 Aug 2020 21:41:55 +0800 Subject: [PATCH 51/60] reserve only one bypass in collector --- tianshou/data/collector.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/tianshou/data/collector.py b/tianshou/data/collector.py index e7ff0d08e..e642ad780 100644 --- a/tianshou/data/collector.py +++ b/tianshou/data/collector.py @@ -258,12 +258,10 @@ def collect(self, # convert None to Batch(), since None is reserved for 0-init if state is None: state = Batch() - # since result is a Batch, it can bypass the type check here - self.data.__dict__['state'] = state - self.data.__dict__['policy'] = result.get('policy', Batch()) + self.data.update(state=state, 
policy=result.get('policy', Batch())) # save hidden state to policy._state, in order to save into buffer if not (isinstance(state, Batch) and state.is_empty()): - self.data.policy.__dict__['_state'] = self.data.state + self.data.policy._state = self.data.state self.data.act = to_numpy(result.act) if self._action_noise is not None: # noqa From efb4e001f18c3dd1db447a1b80598e03a347b5d7 Mon Sep 17 00:00:00 2001 From: Trinkle23897 <463003665@qq.com> Date: Wed, 26 Aug 2020 22:01:21 +0800 Subject: [PATCH 52/60] minor fix --- tianshou/data/collector.py | 6 +++--- tianshou/env/venvs.py | 2 +- tianshou/policy/base.py | 8 ++++---- 3 files changed, 8 insertions(+), 8 deletions(-) diff --git a/tianshou/data/collector.py b/tianshou/data/collector.py index e642ad780..b928d22c3 100644 --- a/tianshou/data/collector.py +++ b/tianshou/data/collector.py @@ -219,7 +219,7 @@ def collect(self, list_n_episode = True finished_env_ids = [ i for i in self._ready_env_ids if n_episode[i] <= 0] - self._ready_env_ids = np.asarray( + self._ready_env_ids = np.array( [x for x in self._ready_env_ids if x not in finished_env_ids]) while True: if step_count >= 100000 and episode_count.sum() == 0: @@ -334,7 +334,7 @@ def collect(self, self.data, self.env_num) # let self.data be the data in all environments again self.data = whole_data - self._ready_env_ids = np.asarray( + self._ready_env_ids = np.array( [x for x in self._ready_env_ids if x not in finished_env_ids]) if n_step: if step_count >= n_step: @@ -348,7 +348,7 @@ def collect(self, break # finished envs are ready, and can be used for the next collection - self._ready_env_ids = np.asarray( + self._ready_env_ids = np.array( self._ready_env_ids.tolist() + finished_env_ids) # generate the statistics diff --git a/tianshou/env/venvs.py b/tianshou/env/venvs.py index f11b86cfe..81ee646b5 100644 --- a/tianshou/env/venvs.py +++ b/tianshou/env/venvs.py @@ -226,7 +226,7 @@ def seed(self, if seed is None: seed = [seed] * self.env_num elif not isinstance(seed, 
list): - seed = [seed + _ for _ in range(self.env_num)] + seed = [seed + i for i in range(self.env_num)] return [w.seed(s) for w, s in zip(self.workers, seed)] def render(self, **kwargs) -> List[Any]: diff --git a/tianshou/policy/base.py b/tianshou/policy/base.py index 92895b3fa..ad35e9675 100644 --- a/tianshou/policy/base.py +++ b/tianshou/policy/base.py @@ -213,14 +213,14 @@ def compute_nstep_return( returns[done[now] > 0] = 0 returns = (rew[now] - mean) / std + gamma * returns terminal = (indice + n_step - 1) % buf_len - target_q_ = target_q_fn(buffer, terminal).flatten() # shape: (bsz, ) - target_q = to_numpy(target_q_) + target_q_torch = target_q_fn(buffer, terminal).flatten() # (bsz, ) + target_q = to_numpy(target_q_torch) target_q[gammas != n_step] = 0 target_q = target_q * (gamma ** gammas) + returns - batch.returns = to_torch_as(target_q, target_q_) + batch.returns = to_torch_as(target_q, target_q_torch) # prio buffer update if isinstance(buffer, PrioritizedReplayBuffer): - batch.weight = to_torch_as(batch.weight, target_q_) + batch.weight = to_torch_as(batch.weight, target_q_torch) return batch def post_process_fn(self, batch: Batch, From 089d784da74c698a11de995cbb75416daca1ae8b Mon Sep 17 00:00:00 2001 From: Trinkle23897 <463003665@qq.com> Date: Wed, 26 Aug 2020 22:21:50 +0800 Subject: [PATCH 53/60] minor fix for some docstrings --- docs/tutorials/cheatsheet.rst | 2 +- tianshou/data/batch.py | 2 +- tianshou/data/collector.py | 3 ++- tianshou/data/utils/converter.py | 4 ++-- tianshou/policy/modelfree/a2c.py | 4 ++-- tianshou/policy/modelfree/ppo.py | 4 ++-- 6 files changed, 10 insertions(+), 9 deletions(-) diff --git a/docs/tutorials/cheatsheet.rst b/docs/tutorials/cheatsheet.rst index e9d265a3e..32358439b 100644 --- a/docs/tutorials/cheatsheet.rst +++ b/docs/tutorials/cheatsheet.rst @@ -96,7 +96,7 @@ This is related to `Issue 42 `_. 
If you want to get log stat from data stream / pre-process batch-image / modify the reward with given env info, use ``preproces_fn`` in :class:`~tianshou.data.Collector`. This is a hook which will be called before the data adding into the buffer. -This function receives typically up to 7 keys, as listed in :class:`~tianshou.data.Batch`, and returns the modified part within a :class:`~tianshou.data.Batch`. If it receives only ``obs``, it is applied in the env reset; otherwise, it is applied in the normal env step. For example, you can write your hook as: +This function receives up to 7 keys ``obs``, ``act``, ``rew``, ``done``, ``obs_next``, ``info``, and ``policy``, as listed in :class:`~tianshou.data.Batch`, and returns the modified part within a :class:`~tianshou.data.Batch`. Only ``obs`` is defined at env reset, while every key is specified for normal steps. For example, you can write your hook as: :: import numpy as np diff --git a/tianshou/data/batch.py b/tianshou/data/batch.py index 60dc7a21d..5308c273d 100644 --- a/tianshou/data/batch.py +++ b/tianshou/data/batch.py @@ -380,7 +380,7 @@ def to_torch(self, dtype: Optional[torch.dtype] = None, self.__dict__[k] = v def __cat(self, - batches: Union['Batch', List[Union[dict, 'Batch']]], + batches: List[Union[dict, 'Batch']], lens: List[int]) -> None: """:: diff --git a/tianshou/data/collector.py b/tianshou/data/collector.py index b928d22c3..f5dd5244b 100644 --- a/tianshou/data/collector.py +++ b/tianshou/data/collector.py @@ -202,7 +202,8 @@ def collect(self, """ assert (n_step is not None and n_episode is None and n_step > 0) or ( n_step is None and n_episode is not None and np.sum(n_episode) > 0 - ), "One and only one collection number specification is permitted!" + ), "Only one of n_step or n_episode is allowed in Collector.collect, " + f"got n_step = {n_step}, n_episode = {n_episode}." 
start_time = time.time() step_count = 0 # episode of each environment diff --git a/tianshou/data/utils/converter.py b/tianshou/data/utils/converter.py index 45f97f6fd..d6d9bd5d8 100644 --- a/tianshou/data/utils/converter.py +++ b/tianshou/data/utils/converter.py @@ -62,9 +62,9 @@ def to_torch(x: Union[Batch, dict, list, tuple, np.ndarray, torch.Tensor], return x -def to_torch_as(x: Union[torch.Tensor, dict, Batch, np.ndarray], +def to_torch_as(x: Union[Batch, dict, list, tuple, np.ndarray, torch.Tensor], y: torch.Tensor - ) -> Union[dict, Batch, torch.Tensor]: + ) -> Union[Batch, dict, list, tuple, np.ndarray, torch.Tensor]: """Return an object without np.ndarray. Same as ``to_torch(x, dtype=y.dtype, device=y.device)``. """ diff --git a/tianshou/policy/modelfree/a2c.py b/tianshou/policy/modelfree/a2c.py index 4ae122502..f3b7a4c23 100644 --- a/tianshou/policy/modelfree/a2c.py +++ b/tianshou/policy/modelfree/a2c.py @@ -27,8 +27,8 @@ class A2CPolicy(PGPolicy): Estimation, defaults to 0.95. :param bool reward_normalization: normalize the reward to Normal(0, 1), defaults to ``False``. - :param int max_batchsize: the maximum number of batchsize when computing - GAE, defaults to 256. + :param int max_batchsize: the maximum size of the batch when computing GAE, + defaults to 256. .. seealso:: diff --git a/tianshou/policy/modelfree/ppo.py b/tianshou/policy/modelfree/ppo.py index 1d89188ae..b64efac19 100644 --- a/tianshou/policy/modelfree/ppo.py +++ b/tianshou/policy/modelfree/ppo.py @@ -34,8 +34,8 @@ class PPOPolicy(PGPolicy): defaults to ``True``. :param bool reward_normalization: normalize the returns to Normal(0, 1), defaults to ``True``. - :param int max_batchsize: the maximum number of batchsize when computing - GAE, defaults to 256. + :param int max_batchsize: the maximum size of the batch when computing GAE, + defaults to 256. .. 
seealso:: From 225ab6a855758e38ddd74e3465eb6c6e9ed3d7b3 Mon Sep 17 00:00:00 2001 From: Trinkle23897 <463003665@qq.com> Date: Wed, 26 Aug 2020 22:28:26 +0800 Subject: [PATCH 54/60] simplify setup.py --- setup.py | 12 ++---------- tianshou/data/utils/converter.py | 4 +++- 2 files changed, 5 insertions(+), 11 deletions(-) diff --git a/setup.py b/setup.py index 40873b97d..904037254 100644 --- a/setup.py +++ b/setup.py @@ -3,20 +3,12 @@ from setuptools import setup, find_packages -import re -from os import path +import tianshou -here = path.abspath(path.dirname(__file__)) - -# Get the version string -with open(path.join(here, 'tianshou', '__init__.py')) as f: - raw = re.search(r'__version__ = \'(.*?)\'', f.read()) - assert raw is not None - version = raw.group(1) setup( name='tianshou', - version=version, + version=tianshou.__version__, description='A Library for Deep Reinforcement Learning', long_description=open('README.md', encoding='utf8').read(), long_description_content_type='text/markdown', diff --git a/tianshou/data/utils/converter.py b/tianshou/data/utils/converter.py index d6d9bd5d8..e97b05416 100644 --- a/tianshou/data/utils/converter.py +++ b/tianshou/data/utils/converter.py @@ -14,8 +14,10 @@ def to_numpy(x: Union[ x = x.detach().cpu().numpy() elif isinstance(x, np.ndarray): # second often case pass - elif isinstance(x, (np.number, np.bool_, Number)) or x is None: + elif isinstance(x, (np.number, np.bool_, Number)): x = np.asanyarray(x) + elif x is None: + x = np.array(None, dtype=np.object) elif isinstance(x, Batch): x.to_numpy() elif isinstance(x, dict): From 5c2c5060ef07fc5e5ebe57a69e931bbaf661a740 Mon Sep 17 00:00:00 2001 From: Trinkle23897 <463003665@qq.com> Date: Wed, 26 Aug 2020 22:47:47 +0800 Subject: [PATCH 55/60] test new init --- tianshou/__init__.py | 11 ----------- tianshou/data/batch.py | 4 ++-- 2 files changed, 2 insertions(+), 13 deletions(-) diff --git a/tianshou/__init__.py b/tianshou/__init__.py index d73c93fc0..44b18069b 100644 --- 
a/tianshou/__init__.py +++ b/tianshou/__init__.py @@ -1,12 +1 @@ -from tianshou import data, env, utils, policy, trainer, \ - exploration - __version__ = '0.2.6' -__all__ = [ - 'env', - 'data', - 'utils', - 'policy', - 'trainer', - 'exploration', -] diff --git a/tianshou/data/batch.py b/tianshou/data/batch.py index 5308c273d..18d7c7276 100644 --- a/tianshou/data/batch.py +++ b/tianshou/data/batch.py @@ -126,8 +126,8 @@ def _parse_value(v: Any): return v elif _is_number(v): # second often case return np.asanyarray(v) - elif v is None or isinstance(v, np.ndarray) and \ - issubclass(v.dtype.type, (np.bool_, np.number)) or \ + elif v is None or (isinstance(v, np.ndarray) and + issubclass(v.dtype.type, (np.bool_, np.number))) or \ isinstance(v, torch.Tensor): # third often case return v elif isinstance(v, dict): From 9eee84f1f79c6852d3107385b0839c43c28ea978 Mon Sep 17 00:00:00 2001 From: Trinkle23897 <463003665@qq.com> Date: Thu, 27 Aug 2020 09:20:37 +0800 Subject: [PATCH 56/60] version file --- setup.py | 5 ++--- tianshou/__init__.py | 17 ++++++++++++++++- tianshou/data/batch.py | 10 +++++----- tianshou/data/collector.py | 2 +- tianshou/version.txt | 1 + 5 files changed, 25 insertions(+), 10 deletions(-) create mode 100644 tianshou/version.txt diff --git a/setup.py b/setup.py index 904037254..d56479895 100644 --- a/setup.py +++ b/setup.py @@ -1,14 +1,13 @@ #!/usr/bin/env python3 # -*- coding: utf-8 -*- +import os from setuptools import setup, find_packages -import tianshou - setup( name='tianshou', - version=tianshou.__version__, + version=open(os.path.join('tianshou', 'version.txt')).read().strip(), description='A Library for Deep Reinforcement Learning', long_description=open('README.md', encoding='utf8').read(), long_description_content_type='text/markdown', diff --git a/tianshou/__init__.py b/tianshou/__init__.py index 44b18069b..31f8c3308 100644 --- a/tianshou/__init__.py +++ b/tianshou/__init__.py @@ -1 +1,16 @@ -__version__ = '0.2.6' +import os + +from tianshou 
import data, env, utils, policy, trainer, \ + exploration + +version_file = os.path.join(os.path.dirname(__file__), "version.txt") + +__version__ = open(version_file).read().strip() +__all__ = [ + 'env', + 'data', + 'utils', + 'policy', + 'trainer', + 'exploration', +] diff --git a/tianshou/data/batch.py b/tianshou/data/batch.py index 18d7c7276..fe70d1f7a 100644 --- a/tianshou/data/batch.py +++ b/tianshou/data/batch.py @@ -124,12 +124,12 @@ def _assert_type_keys(keys) -> None: def _parse_value(v: Any): if isinstance(v, Batch): # most often case return v - elif _is_number(v): # second often case - return np.asanyarray(v) - elif v is None or (isinstance(v, np.ndarray) and - issubclass(v.dtype.type, (np.bool_, np.number))) or \ - isinstance(v, torch.Tensor): # third often case + elif (isinstance(v, np.ndarray) and + issubclass(v.dtype.type, (np.bool_, np.number))) or \ + isinstance(v, torch.Tensor) or v is None: # third often case return v + elif _is_number(v): # second often case, but it is more time-consuming + return np.asanyarray(v) elif isinstance(v, dict): return Batch(v) else: diff --git a/tianshou/data/collector.py b/tianshou/data/collector.py index f5dd5244b..7b51b0266 100644 --- a/tianshou/data/collector.py +++ b/tianshou/data/collector.py @@ -327,7 +327,7 @@ def collect(self, obs=obs_reset).get('obs', obs_reset) else: obs_next[env_ind_local] = obs_reset - self.data.__dict__['obs'] = obs_next + self.data.obs = obs_next if is_async: # set data back whole_data = deepcopy(whole_data) # avoid reference in ListBuf diff --git a/tianshou/version.txt b/tianshou/version.txt new file mode 100644 index 000000000..53a75d673 --- /dev/null +++ b/tianshou/version.txt @@ -0,0 +1 @@ +0.2.6 From c1ac1dcb79f9afdbcf576ca9862f6373f5bbe7bd Mon Sep 17 00:00:00 2001 From: Trinkle23897 <463003665@qq.com> Date: Thu, 27 Aug 2020 09:28:09 +0800 Subject: [PATCH 57/60] version file --- setup.py | 3 +-- tianshou/__init__.py | 5 +---- tianshou/version.txt | 1 - 3 files changed, 2 
insertions(+), 7 deletions(-) delete mode 100644 tianshou/version.txt diff --git a/setup.py b/setup.py index d56479895..d7487e269 100644 --- a/setup.py +++ b/setup.py @@ -1,13 +1,12 @@ #!/usr/bin/env python3 # -*- coding: utf-8 -*- -import os from setuptools import setup, find_packages setup( name='tianshou', - version=open(os.path.join('tianshou', 'version.txt')).read().strip(), + version='0.2.6', description='A Library for Deep Reinforcement Learning', long_description=open('README.md', encoding='utf8').read(), long_description_content_type='text/markdown', diff --git a/tianshou/__init__.py b/tianshou/__init__.py index 31f8c3308..77016dc85 100644 --- a/tianshou/__init__.py +++ b/tianshou/__init__.py @@ -1,11 +1,8 @@ -import os - from tianshou import data, env, utils, policy, trainer, \ exploration -version_file = os.path.join(os.path.dirname(__file__), "version.txt") -__version__ = open(version_file).read().strip() +__version__ = '0.2.6' __all__ = [ 'env', 'data', diff --git a/tianshou/version.txt b/tianshou/version.txt deleted file mode 100644 index 53a75d673..000000000 --- a/tianshou/version.txt +++ /dev/null @@ -1 +0,0 @@ -0.2.6 From 9b643cfeb6402a98ec2bc852287c9ae365f0deb4 Mon Sep 17 00:00:00 2001 From: Trinkle23897 <463003665@qq.com> Date: Thu, 27 Aug 2020 10:09:39 +0800 Subject: [PATCH 58/60] fix some type annotation --- tianshou/data/buffer.py | 2 +- tianshou/data/collector.py | 3 ++- tianshou/policy/base.py | 7 +++++-- tianshou/policy/modelfree/a2c.py | 3 +-- tianshou/policy/modelfree/dqn.py | 2 +- tianshou/policy/modelfree/pg.py | 3 +-- tianshou/trainer/offpolicy.py | 2 +- tianshou/trainer/onpolicy.py | 2 +- tianshou/trainer/utils.py | 4 ++-- tianshou/utils/net/common.py | 7 ++++--- 10 files changed, 19 insertions(+), 16 deletions(-) diff --git a/tianshou/data/buffer.py b/tianshou/data/buffer.py index f41785066..7d77b3058 100644 --- a/tianshou/data/buffer.py +++ b/tianshou/data/buffer.py @@ -399,7 +399,7 @@ def add(self, obs_next: Optional[Union[dict, 
Batch, np.ndarray, float]] = None, info: Optional[Union[dict, Batch]] = {}, policy: Optional[Union[dict, Batch]] = {}, - weight: float = None, + weight: Optional[float] = None, **kwargs) -> None: """Add a batch of data into replay buffer.""" if weight is None: diff --git a/tianshou/data/collector.py b/tianshou/data/collector.py index 7b51b0266..57f885d9c 100644 --- a/tianshou/data/collector.py +++ b/tianshou/data/collector.py @@ -80,7 +80,7 @@ def __init__(self, policy: BasePolicy, env: Union[gym.Env, BaseVectorEnv], buffer: Optional[ReplayBuffer] = None, - preprocess_fn: Callable[[Any], Union[dict, Batch]] = None, + preprocess_fn: Callable[[Any], Batch] = None, action_noise: Optional[BaseNoise] = None, reward_metric: Optional[Callable[[np.ndarray], float]] = None, ) -> None: @@ -384,6 +384,7 @@ def sample(self, batch_size: int) -> Batch: 'Collector.sample is deprecated and will cause error if you use ' 'prioritized experience replay! Collector.sample will be removed ' 'upon version 0.3. Use policy.update instead!', Warning) + assert self.buffer is not None, "Cannot get sample from empty buffer!" batch_data, indice = self.buffer.sample(batch_size) batch_data = self.process_fn(batch_data, self.buffer, indice) return batch_data diff --git a/tianshou/policy/base.py b/tianshou/policy/base.py index ad35e9675..5a6c01dd7 100644 --- a/tianshou/policy/base.py +++ b/tianshou/policy/base.py @@ -224,7 +224,7 @@ def compute_nstep_return( return batch def post_process_fn(self, batch: Batch, - buffer: ReplayBuffer, indice: np.ndarray): + buffer: ReplayBuffer, indice: np.ndarray) -> None: """Post-process the data from the provided replay buffer. Typical usage is to update the sampling weight in prioritized experience replay. Check out :ref:`policy_concept` for more information. 
@@ -233,7 +233,8 @@ def post_process_fn(self, batch: Batch, and hasattr(batch, 'weight'): buffer.update_weight(indice, batch.weight) - def update(self, batch_size: int, buffer: ReplayBuffer, *args, **kwargs): + def update(self, batch_size: int, buffer: Optional[ReplayBuffer], + *args, **kwargs) -> Dict[str, Union[float, List[float]]]: """Update the policy network and replay buffer (if needed). It includes three function steps: process_fn, learn, and post_process_fn. @@ -241,6 +242,8 @@ def update(self, batch_size: int, buffer: ReplayBuffer, *args, **kwargs): buffer, otherwise it will sample a batch with the given batch_size. :param ReplayBuffer buffer: the corresponding replay buffer. """ + if buffer is None: + return {} batch, indice = buffer.sample(batch_size) batch = self.process_fn(batch, buffer, indice) result = self.learn(batch, *args, **kwargs) diff --git a/tianshou/policy/modelfree/a2c.py b/tianshou/policy/modelfree/a2c.py index f3b7a4c23..1a9b7fc44 100644 --- a/tianshou/policy/modelfree/a2c.py +++ b/tianshou/policy/modelfree/a2c.py @@ -40,8 +40,7 @@ def __init__(self, actor: torch.nn.Module, critic: torch.nn.Module, optim: torch.optim.Optimizer, - dist_fn: torch.distributions.Distribution - = torch.distributions.Categorical, + dist_fn: torch.distributions.Distribution, discount_factor: float = 0.99, vf_coef: float = .5, ent_coef: float = .01, diff --git a/tianshou/policy/modelfree/dqn.py b/tianshou/policy/modelfree/dqn.py index 9d562f1f7..5c5e45d5a 100644 --- a/tianshou/policy/modelfree/dqn.py +++ b/tianshou/policy/modelfree/dqn.py @@ -37,7 +37,7 @@ def __init__(self, optim: torch.optim.Optimizer, discount_factor: float = 0.99, estimation_step: int = 1, - target_update_freq: Optional[int] = 0, + target_update_freq: int = 0, reward_normalization: bool = False, **kwargs) -> None: super().__init__(**kwargs) diff --git a/tianshou/policy/modelfree/pg.py b/tianshou/policy/modelfree/pg.py index d07e66bf0..3eaae641e 100644 --- a/tianshou/policy/modelfree/pg.py +++ 
b/tianshou/policy/modelfree/pg.py @@ -24,8 +24,7 @@ class PGPolicy(BasePolicy): def __init__(self, model: torch.nn.Module, optim: torch.optim.Optimizer, - dist_fn: torch.distributions.Distribution - = torch.distributions.Categorical, + dist_fn: torch.distributions.Distribution, discount_factor: float = 0.99, reward_normalization: bool = False, **kwargs) -> None: diff --git a/tianshou/trainer/offpolicy.py b/tianshou/trainer/offpolicy.py index e1b722542..153f94d9c 100644 --- a/tianshou/trainer/offpolicy.py +++ b/tianshou/trainer/offpolicy.py @@ -71,7 +71,7 @@ def offpolicy_trainer( :return: See :func:`~tianshou.trainer.gather_info`. """ global_step = 0 - best_epoch, best_reward = -1, -1 + best_epoch, best_reward = -1, -1. stat = {} start_time = time.time() test_in_train = test_in_train and train_collector.policy == policy diff --git a/tianshou/trainer/onpolicy.py b/tianshou/trainer/onpolicy.py index f3df25425..ea57ed1c1 100644 --- a/tianshou/trainer/onpolicy.py +++ b/tianshou/trainer/onpolicy.py @@ -71,7 +71,7 @@ def onpolicy_trainer( :return: See :func:`~tianshou.trainer.gather_info`. """ global_step = 0 - best_epoch, best_reward = -1, -1 + best_epoch, best_reward = -1, -1. 
stat = {} start_time = time.time() test_in_train = test_in_train and train_collector.policy == policy diff --git a/tianshou/trainer/utils.py b/tianshou/trainer/utils.py index 32b936ba5..ba914d842 100644 --- a/tianshou/trainer/utils.py +++ b/tianshou/trainer/utils.py @@ -1,7 +1,7 @@ import time import numpy as np -from typing import Dict, List, Union, Callable from torch.utils.tensorboard import SummaryWriter +from typing import Dict, List, Union, Callable, Optional from tianshou.data import Collector from tianshou.policy import BasePolicy @@ -10,7 +10,7 @@ def test_episode( policy: BasePolicy, collector: Collector, - test_fn: Callable[[int], None], + test_fn: Optional[Callable[[int], None]], epoch: int, n_episode: Union[int, List[int]], writer: SummaryWriter = None, diff --git a/tianshou/utils/net/common.py b/tianshou/utils/net/common.py index a84a7e7cd..eb68a9710 100644 --- a/tianshou/utils/net/common.py +++ b/tianshou/utils/net/common.py @@ -1,12 +1,13 @@ import torch import numpy as np from torch import nn -from typing import Tuple, Union, Optional +from typing import List, Tuple, Union, Optional from tianshou.data import to_torch -def miniblock(inp: int, oup: int, norm_layer: nn.modules.Module): +def miniblock(inp: int, oup: int, + norm_layer: nn.modules.Module) -> List[nn.modules.Module]: ret = [nn.Linear(inp, oup)] if norm_layer is not None: ret += [norm_layer(oup)] @@ -28,7 +29,7 @@ class Net(nn.Module): """ def __init__(self, layer_num: int, state_shape: tuple, - action_shape: Optional[tuple] = 0, + action_shape: Optional[Union[tuple, int]] = 0, device: Union[str, torch.device] = 'cpu', softmax: bool = False, concat: bool = False, From 12bc4cd775b43e112378341bf9302b3a8efbf17b Mon Sep 17 00:00:00 2001 From: Trinkle23897 <463003665@qq.com> Date: Thu, 27 Aug 2020 11:31:10 +0800 Subject: [PATCH 59/60] docstring --- tianshou/env/venvs.py | 6 +++--- tianshou/policy/modelfree/a2c.py | 2 +- tianshou/policy/modelfree/ppo.py | 2 +- 3 files changed, 5 insertions(+), 5 
deletions(-) diff --git a/tianshou/env/venvs.py b/tianshou/env/venvs.py index 81ee646b5..04323498d 100644 --- a/tianshou/env/venvs.py +++ b/tianshou/env/venvs.py @@ -153,8 +153,8 @@ def step(self, an int or a list. When the end of episode is reached, you are responsible for calling reset(id) to reset this environment’s state. - Accept a batch of action and return a tuple (obs, rew, done, info) in - numpy format. + Accept a batch of action and return a tuple (batch_obs, batch_rew, + batch_done, batch_info) in numpy format. :param numpy.ndarray action: a batch of action provided by the agent. @@ -225,7 +225,7 @@ def seed(self, self._assert_is_not_closed() if seed is None: seed = [seed] * self.env_num - elif not isinstance(seed, list): + elif np.isscalar(seed): seed = [seed + i for i in range(self.env_num)] return [w.seed(s) for w, s in zip(self.workers, seed)] diff --git a/tianshou/policy/modelfree/a2c.py b/tianshou/policy/modelfree/a2c.py index 1a9b7fc44..e3cf6ef58 100644 --- a/tianshou/policy/modelfree/a2c.py +++ b/tianshou/policy/modelfree/a2c.py @@ -28,7 +28,7 @@ class A2CPolicy(PGPolicy): :param bool reward_normalization: normalize the reward to Normal(0, 1), defaults to ``False``. :param int max_batchsize: the maximum size of the batch when computing GAE, - defaults to 256. + should be as large as possible for efficiency; defaults to 256. .. seealso:: diff --git a/tianshou/policy/modelfree/ppo.py b/tianshou/policy/modelfree/ppo.py index b64efac19..70c24f09b 100644 --- a/tianshou/policy/modelfree/ppo.py +++ b/tianshou/policy/modelfree/ppo.py @@ -35,7 +35,7 @@ class PPOPolicy(PGPolicy): :param bool reward_normalization: normalize the returns to Normal(0, 1), defaults to ``True``. :param int max_batchsize: the maximum size of the batch when computing GAE, - defaults to 256. + should be as large as possible for efficiency; defaults to 256. .. 
seealso:: From d1b18313387a6dba9f8a8b3d490bd4043060c2e6 Mon Sep 17 00:00:00 2001 From: Trinkle23897 <463003665@qq.com> Date: Thu, 27 Aug 2020 11:46:40 +0800 Subject: [PATCH 60/60] docstring --- tianshou/policy/modelfree/a2c.py | 4 +++- tianshou/policy/modelfree/ppo.py | 4 +++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/tianshou/policy/modelfree/a2c.py b/tianshou/policy/modelfree/a2c.py index e3cf6ef58..0f7cffd58 100644 --- a/tianshou/policy/modelfree/a2c.py +++ b/tianshou/policy/modelfree/a2c.py @@ -28,7 +28,9 @@ class A2CPolicy(PGPolicy): :param bool reward_normalization: normalize the reward to Normal(0, 1), defaults to ``False``. :param int max_batchsize: the maximum size of the batch when computing GAE, - should be as large as possible for efficiency; defaults to 256. + depends on the size of available memory and the memory cost of the + model; should be as large as possible within the memory constraint; + defaults to 256. .. seealso:: diff --git a/tianshou/policy/modelfree/ppo.py b/tianshou/policy/modelfree/ppo.py index 70c24f09b..2db5baf3b 100644 --- a/tianshou/policy/modelfree/ppo.py +++ b/tianshou/policy/modelfree/ppo.py @@ -35,7 +35,9 @@ class PPOPolicy(PGPolicy): :param bool reward_normalization: normalize the returns to Normal(0, 1), defaults to ``True``. :param int max_batchsize: the maximum size of the batch when computing GAE, - should be as large as possible for efficiency; defaults to 256. + depends on the size of available memory and the memory cost of the + model; should be as large as possible within the memory constraint; + defaults to 256. .. seealso::