From ade94e4b8c59e9a4e2b4a6babae498e785de86ae Mon Sep 17 00:00:00 2001 From: youkaichao Date: Thu, 23 Jul 2020 08:11:25 +0800 Subject: [PATCH 01/17] remove _multi_env flag; fix info bug --- tianshou/data/collector.py | 109 +++++++++++++------------------------ 1 file changed, 39 insertions(+), 70 deletions(-) diff --git a/tianshou/data/collector.py b/tianshou/data/collector.py index 98c62daa2..d53505040 100644 --- a/tianshou/data/collector.py +++ b/tianshou/data/collector.py @@ -6,7 +6,7 @@ from typing import Any, Dict, List, Union, Optional, Callable from tianshou.utils import MovAvg -from tianshou.env import BaseVectorEnv +from tianshou.env import BaseVectorEnv, VectorEnv from tianshou.policy import BasePolicy from tianshou.exploration import BaseNoise from tianshou.data import Batch, ReplayBuffer, ListReplayBuffer, to_numpy @@ -100,23 +100,20 @@ def __init__(self, reward_metric: Optional[Callable[[np.ndarray], float]] = None, ) -> None: super().__init__() + if not isinstance(env, BaseVectorEnv): + env = VectorEnv([lambda : env]) self.env = env - self.env_num = 1 + self.env_num = len(env) + # need cache buffers before storing in the main buffer + self._cached_buf = [ListReplayBuffer() + for _ in range(self.env_num)] self.collect_time, self.collect_step, self.collect_episode = 0., 0, 0 self.buffer = buffer self.policy = policy self.preprocess_fn = preprocess_fn self.process_fn = policy.process_fn - self._multi_env = isinstance(env, BaseVectorEnv) - # need multiple cache buffers only if storing in one buffer - self._cached_buf = [] - if self._multi_env: - self.env_num = len(env) - self._cached_buf = [ListReplayBuffer() - for _ in range(self.env_num)] self.stat_size = stat_size self._action_noise = action_noise - self._rew_metric = reward_metric or Collector._default_rew_metric self.reset() @@ -155,8 +152,6 @@ def reset_env(self) -> None: buffers (if need). """ obs = self.env.reset() - if not self._multi_env: - obs = self._make_batch(obs) if self.preprocess_fn: obs = self.preprocess_fn(obs=obs).get('obs', obs) self.data.obs = obs @@ -228,8 +223,6 @@ def collect(self, * ``rew`` the mean reward over collected episodes. * ``len`` the mean length over collected episodes. """ - if not self._multi_env: - n_episode = np.sum(n_episode) start_time = time.time() assert sum([(n_step != 0), (n_episode != 0)]) == 1, \ "One and only one collection number specification is permitted!" 
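For illustration, the wrapping rule introduced in PATCH 01 above can be exercised on its own. The sketch below uses the same ``VectorEnv``/``BaseVectorEnv`` classes this patch imports from ``tianshou.env``; the CartPole environment is only an example, not something the patch touches: ::

    import gym
    from tianshou.env import BaseVectorEnv, VectorEnv

    env = gym.make('CartPole-v0')
    if not isinstance(env, BaseVectorEnv):
        # a single gym.Env becomes a vector env of size 1, so downstream code
        # only ever has to handle the vectorized case
        env = VectorEnv([lambda: env])  # envs are built before `env` is rebound
    assert len(env) == 1  # env_num == len(env) now holds for both cases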
@@ -274,22 +267,16 @@ def collect(self, self.data.act += self._action_noise(self.data.act.shape) # step in env - obs_next, rew, done, info = self.env.step( - self.data.act if self._multi_env else self.data.act[0]) + obs_next, rew, done, info = self.env.step(self.data.act) # move data to self.data - if not self._multi_env: - obs_next = self._make_batch(obs_next) - rew = self._make_batch(rew) - done = self._make_batch(done) - info = self._make_batch(info) self.data.obs_next = obs_next self.data.rew = rew self.data.done = done self.data.info = info if log_fn: - log_fn(info if self._multi_env else info[0]) + log_fn(self.data.info) if render: self.render() if render > 0: @@ -301,55 +288,37 @@ def collect(self, if self.preprocess_fn: result = self.preprocess_fn(**self.data) self.data.update(result) - if self._multi_env: # cache_buffer branch - for i in range(self.env_num): - self._cached_buf[i].add(**self.data[i]) - if self.data.done[i]: - if n_step != 0 or np.isscalar(n_episode) or \ - cur_episode[i] < n_episode[i]: - cur_episode[i] += 1 - reward_sum += self.reward[i] - length_sum += self.length[i] - if self._cached_buf: - cur_step += len(self._cached_buf[i]) - if self.buffer is not None: - self.buffer.update(self._cached_buf[i]) - self.reward[i], self.length[i] = 0., 0 + for i in range(self.env_num): + self._cached_buf[i].add(**self.data[i]) + if self.data.done[i]: + if n_step != 0 or np.isscalar(n_episode) or \ + cur_episode[i] < n_episode[i]: + cur_episode[i] += 1 + reward_sum += self.reward[i] + length_sum += self.length[i] if self._cached_buf: - self._cached_buf[i].reset() - self._reset_state(i) - obs_next = self.data.obs_next - if sum(self.data.done): - env_ind = np.where(self.data.done)[0] - obs_reset = self.env.reset(env_ind) - if self.preprocess_fn: - obs_next[env_ind] = self.preprocess_fn( - obs=obs_reset).get('obs', obs_reset) - else: - obs_next[env_ind] = obs_reset - self.data.obs_next = obs_next - if n_episode != 0: - if isinstance(n_episode, list) and \ - (cur_episode >= np.array(n_episode)).all() or \ - np.isscalar(n_episode) and \ - cur_episode.sum() >= n_episode: - break - else: # single buffer, without cache_buffer - if self.buffer is not None: - self.buffer.add(**self.data[0]) - cur_step += 1 - if self.data.done[0]: - cur_episode += 1 - reward_sum += self.reward[0] - length_sum += self.length[0] - self.reward, self.length = 0., np.zeros(self.env_num) - self.data.state = Batch() - obs_next = self._make_batch(self.env.reset()) - if self.preprocess_fn: - obs_next = self.preprocess_fn(obs=obs_next).get( - 'obs', obs_next) - self.data.obs_next = obs_next - if n_episode != 0 and cur_episode >= n_episode: + cur_step += len(self._cached_buf[i]) + if self.buffer is not None: + self.buffer.update(self._cached_buf[i]) + self.reward[i], self.length[i] = 0., 0 + if self._cached_buf: + self._cached_buf[i].reset() + self._reset_state(i) + obs_next = self.data.obs_next + if sum(self.data.done): + env_ind = np.where(self.data.done)[0] + obs_reset = self.env.reset(env_ind) + if self.preprocess_fn: + obs_next[env_ind] = self.preprocess_fn( + obs=obs_reset).get('obs', obs_reset) + else: + obs_next[env_ind] = obs_reset + self.data.obs_next = obs_next + if n_episode != 0: + if isinstance(n_episode, list) and \ + (cur_episode >= np.array(n_episode)).all() or \ + np.isscalar(n_episode) and \ + cur_episode.sum() >= n_episode: break if n_step != 0 and cur_step >= n_step: break From 51405632f9525972561944ce0fce94ecbf9b5713 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Thu, 23 Jul 2020 10:04:31 +0800 
Subject: [PATCH 02/17] update testcase --- test/base/test_collector.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/test/base/test_collector.py b/test/base/test_collector.py index ead017a01..d99cf4dbc 100644 --- a/test/base/test_collector.py +++ b/test/base/test_collector.py @@ -98,7 +98,7 @@ def test_collector_with_dict_state(): policy = MyPolicy(dict_state=True) c0 = Collector(policy, env, ReplayBuffer(size=100), preprocess_fn) c0.collect(n_step=3) - c0.collect(n_episode=3) + c0.collect(n_episode=2) env_fns = [lambda x=i: MyTestEnv(size=x, sleep=0, dict_state=True) for i in [2, 3, 4, 5]] envs = VectorEnv(env_fns) @@ -126,9 +126,10 @@ def reward_metric(x): policy = MyPolicy() c0 = Collector(policy, env, ReplayBuffer(size=100), preprocess_fn, reward_metric=reward_metric) + # n_step=3 will collect a full episode r = c0.collect(n_step=3)['rew'] - assert np.asanyarray(r).size == 1 and r == 0. - r = c0.collect(n_episode=3)['rew'] + assert np.asanyarray(r).size == 1 and r == 4. + r = c0.collect(n_episode=2)['rew'] assert np.asanyarray(r).size == 1 and r == 4. env_fns = [lambda x=i: MyTestEnv(size=x, sleep=0, ma_rew=4) for i in [2, 3, 4, 5]] From 1e439e0aa820ae5c94d4961a7f4071fd0a8b9ed7 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Thu, 23 Jul 2020 10:04:56 +0800 Subject: [PATCH 03/17] pep8 fix --- tianshou/data/collector.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tianshou/data/collector.py b/tianshou/data/collector.py index d53505040..1ee35d3d8 100644 --- a/tianshou/data/collector.py +++ b/tianshou/data/collector.py @@ -101,7 +101,7 @@ def __init__(self, ) -> None: super().__init__() if not isinstance(env, BaseVectorEnv): - env = VectorEnv([lambda : env]) + env = VectorEnv([lambda: env]) self.env = env self.env_num = len(env) # need cache buffers before storing in the main buffer From 993ec29eded90cbaaeb5d983e4dba2c505488972 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Thu, 23 Jul 2020 10:11:01 +0800 Subject: [PATCH 04/17] remove dummy code --- tianshou/data/collector.py | 14 ++------------ 1 file changed, 2 insertions(+), 12 deletions(-) diff --git a/tianshou/data/collector.py b/tianshou/data/collector.py index 1ee35d3d8..a9109d3ba 100644 --- a/tianshou/data/collector.py +++ b/tianshou/data/collector.py @@ -172,13 +172,6 @@ def close(self) -> None: """Close the environment(s).""" self.env.close() - def _make_batch(self, data: Any) -> np.ndarray: - """Return [data].""" - if isinstance(data, np.ndarray): - return data[None] - else: - return np.array([data]) - def _reset_state(self, id: Union[int, List[int]]) -> None: """Reset self.data.state[id].""" state = self.data.state # it is a reference @@ -243,11 +236,8 @@ def collect(self, # calculate the next action if random: - action_space = self.env.action_space - if isinstance(action_space, list): - result = Batch(act=[a.sample() for a in action_space]) - else: - result = Batch(act=self._make_batch(action_space.sample())) + result = Batch( + act=[a.sample() for a in self.env.action_space]) else: with torch.no_grad(): result = self.policy(self.data, last_state) From 90dd71981b708a26b34ee357f01f799c61da35c4 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Thu, 23 Jul 2020 10:47:03 +0800 Subject: [PATCH 05/17] remove dummy condition: _cached_buf never empty now --- tianshou/data/collector.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/tianshou/data/collector.py b/tianshou/data/collector.py index a9109d3ba..ecfe686c9 100644 --- a/tianshou/data/collector.py +++ 
b/tianshou/data/collector.py @@ -286,13 +286,11 @@ def collect(self, cur_episode[i] += 1 reward_sum += self.reward[i] length_sum += self.length[i] - if self._cached_buf: - cur_step += len(self._cached_buf[i]) - if self.buffer is not None: - self.buffer.update(self._cached_buf[i]) + cur_step += len(self._cached_buf[i]) + if self.buffer is not None: + self.buffer.update(self._cached_buf[i]) self.reward[i], self.length[i] = 0., 0 - if self._cached_buf: - self._cached_buf[i].reset() + self._cached_buf[i].reset() self._reset_state(i) obs_next = self.data.obs_next if sum(self.data.done): From 6846e2e92ae69f48394d2a29fd23786d081e2cbe Mon Sep 17 00:00:00 2001 From: Trinkle23897 <463003665@qq.com> Date: Thu, 23 Jul 2020 11:05:15 +0800 Subject: [PATCH 06/17] minor fix --- tianshou/data/collector.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/tianshou/data/collector.py b/tianshou/data/collector.py index ecfe686c9..363b112de 100644 --- a/tianshou/data/collector.py +++ b/tianshou/data/collector.py @@ -105,8 +105,7 @@ def __init__(self, self.env = env self.env_num = len(env) # need cache buffers before storing in the main buffer - self._cached_buf = [ListReplayBuffer() - for _ in range(self.env_num)] + self._cached_buf = [ListReplayBuffer() for _ in range(self.env_num)] self.collect_time, self.collect_step, self.collect_episode = 0., 0, 0 self.buffer = buffer self.policy = policy @@ -260,10 +259,7 @@ def collect(self, obs_next, rew, done, info = self.env.step(self.data.act) # move data to self.data - self.data.obs_next = obs_next - self.data.rew = rew - self.data.done = done - self.data.info = info + self.data.update(obs_next=obs_next, rew=rew, done=done, info=info) if log_fn: log_fn(self.data.info) From 9939ba322fcf6e55e5efa4071243ff6f3bff90ff Mon Sep 17 00:00:00 2001 From: youkaichao Date: Thu, 23 Jul 2020 11:19:17 +0800 Subject: [PATCH 07/17] update doc --- docs/tutorials/dqn.rst | 2 +- tianshou/data/collector.py | 12 ++++-------- 2 files changed, 5 insertions(+), 9 deletions(-) diff --git a/docs/tutorials/dqn.rst b/docs/tutorials/dqn.rst index cd663e504..9cbb2434f 100644 --- a/docs/tutorials/dqn.rst +++ b/docs/tutorials/dqn.rst @@ -179,7 +179,7 @@ Train a Policy with Customized Codes Tianshou supports user-defined training code. Here is the code snippet: :: - # pre-collect 5000 frames with random action before training + # pre-collect at least 5000 frames with random action before training policy.set_eps(1) train_collector.collect(n_step=5000) diff --git a/tianshou/data/collector.py b/tianshou/data/collector.py index 363b112de..690d853e6 100644 --- a/tianshou/data/collector.py +++ b/tianshou/data/collector.py @@ -21,8 +21,7 @@ class Collector(object): :param env: a ``gym.Env`` environment or an instance of the :class:`~tianshou.env.BaseVectorEnv` class. :param buffer: an instance of the :class:`~tianshou.data.ReplayBuffer` - class, or a list of :class:`~tianshou.data.ReplayBuffer`. If set to - ``None``, it will automatically assign a small-size + class. If set to ``None``, it will automatically assign a small-size :class:`~tianshou.data.ReplayBuffer`. 
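The flush-on-done loop that PATCH 05 trims above is the heart of the cache-buffer design described later in ``concepts.rst``: every environment writes into its own small cache buffer, and only a finished episode is copied into the main buffer. Below is a self-contained sketch of that pattern using the real ``ListReplayBuffer``/``ReplayBuffer`` classes; the three-step episode is made up for illustration: ::

    from tianshou.data import ListReplayBuffer, ReplayBuffer

    main_buf = ReplayBuffer(size=100)
    cached = [ListReplayBuffer() for _ in range(4)]  # one cache per environment

    # pretend environment 0 just produced a three-step episode
    for t in range(3):
        done = (t == 2)
        cached[0].add(obs=t, act=0, rew=1.0, done=done, obs_next=t + 1)
        if done:
            main_buf.update(cached[0])  # move the full trajectory chronologically
            cached[0].reset()

    assert len(main_buf) == 3 and len(cached[0]) == 0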
:param function preprocess_fn: a function called before the data has been added to the buffer, see issue #42 and :ref:`preprocess_fn`, defaults @@ -56,9 +55,6 @@ class Collector(object): # the collector supports vectorized environments as well envs = VectorEnv([lambda: gym.make('CartPole-v0') for _ in range(3)]) - buffers = [ReplayBuffer(size=5000) for _ in range(3)] - # you can also pass a list of replay buffer to collector, for multi-env - # collector = Collector(policy, envs, buffer=buffers) collector = Collector(policy, envs, buffer=replay_buffer) # collect at least 3 episodes @@ -81,9 +77,9 @@ class Collector(object): # clear the buffer collector.reset_buffer() - For the scenario of collecting data from multiple environments to a single - buffer, the cache buffers will turn on automatically. It may return the - data more than the given limitation. + Collected data always consist of full episodes. So if only ``n_step`` + argument is give, the collector may return the data more than the + ``n_step`` limitation. .. note:: From 44ac9bfad5208a1501d6179aef4e0d7ec51b4f0a Mon Sep 17 00:00:00 2001 From: Trinkle23897 <463003665@qq.com> Date: Thu, 23 Jul 2020 11:22:35 +0800 Subject: [PATCH 08/17] minor fix --- docs/tutorials/concepts.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/tutorials/concepts.rst b/docs/tutorials/concepts.rst index 1ba271ec4..3f033ead7 100644 --- a/docs/tutorials/concepts.rst +++ b/docs/tutorials/concepts.rst @@ -130,7 +130,7 @@ In short, :class:`~tianshou.data.Collector` has two main methods: * :meth:`~tianshou.data.Collector.collect`: let the policy perform (at least) a specified number of step ``n_step`` or episode ``n_episode`` and store the data in the replay buffer; * :meth:`~tianshou.data.Collector.sample`: sample a data batch from replay buffer; it will call :meth:`~tianshou.policy.BasePolicy.process_fn` before returning the final batch data. -Why do we mention **at least** here? For a single environment, the collector will finish exactly ``n_step`` or ``n_episode``. However, for multiple environments, we could not directly store the collected data into the replay buffer, since it breaks the principle of storing data chronologically. +Why do we mention **at least** here? For multiple environments, we could not directly store the collected data into the replay buffer, since it breaks the principle of storing data chronologically. The solution is to add some cache buffers inside the collector. Once collecting **a full episode of trajectory**, it will move the stored data from the cache buffer to the main buffer. To satisfy this condition, the collector will interact with environments that may exceed the given step number or episode number. From a6dbff7450e2e9ce28e55cca56db7adfa7229df0 Mon Sep 17 00:00:00 2001 From: Trinkle23897 <463003665@qq.com> Date: Thu, 23 Jul 2020 11:26:38 +0800 Subject: [PATCH 09/17] minor fix --- tianshou/data/collector.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/tianshou/data/collector.py b/tianshou/data/collector.py index 690d853e6..fa45804f2 100644 --- a/tianshou/data/collector.py +++ b/tianshou/data/collector.py @@ -21,8 +21,7 @@ class Collector(object): :param env: a ``gym.Env`` environment or an instance of the :class:`~tianshou.env.BaseVectorEnv` class. :param buffer: an instance of the :class:`~tianshou.data.ReplayBuffer` - class. If set to ``None``, it will automatically assign a small-size - :class:`~tianshou.data.ReplayBuffer`. + class. 
If set to ``None`` (testing phase), it will not store the data. :param function preprocess_fn: a function called before the data has been added to the buffer, see issue #42 and :ref:`preprocess_fn`, defaults to ``None``. @@ -57,7 +56,7 @@ class Collector(object): envs = VectorEnv([lambda: gym.make('CartPole-v0') for _ in range(3)]) collector = Collector(policy, envs, buffer=replay_buffer) - # collect at least 3 episodes + # collect 3 episodes collector.collect(n_episode=3) # collect 1 episode for the first env, 3 for the third env collector.collect(n_episode=[1, 0, 3]) From d3e23acf2cdc3aad0d9038c8d84e735164021e1b Mon Sep 17 00:00:00 2001 From: youkaichao Date: Thu, 23 Jul 2020 12:42:12 +0800 Subject: [PATCH 10/17] remove movavg in collector --- tianshou/data/collector.py | 18 ++---------------- 1 file changed, 2 insertions(+), 16 deletions(-) diff --git a/tianshou/data/collector.py b/tianshou/data/collector.py index fa45804f2..9701b2459 100644 --- a/tianshou/data/collector.py +++ b/tianshou/data/collector.py @@ -5,7 +5,6 @@ import numpy as np from typing import Any, Dict, List, Union, Optional, Callable -from tianshou.utils import MovAvg from tianshou.env import BaseVectorEnv, VectorEnv from tianshou.policy import BasePolicy from tianshou.exploration import BaseNoise @@ -25,8 +24,6 @@ class Collector(object): :param function preprocess_fn: a function called before the data has been added to the buffer, see issue #42 and :ref:`preprocess_fn`, defaults to ``None``. - :param int stat_size: for the moving average of recording speed, defaults - to 100. :param BaseNoise action_noise: add a noise to continuous action. Normally a policy already has a noise param for exploration in training phase, so this is recommended to use in test collector for some purpose. 
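The "testing phase" note in PATCH 09 above is the intended use of ``buffer=None``: an evaluation collector only needs the statistics returned by ``collect``, so nothing has to be stored. A runnable sketch follows; the ``ConstantPolicy`` stand-in is only for illustration and is not part of tianshou: ::

    import gym
    import numpy as np
    from tianshou.data import Batch, Collector
    from tianshou.env import VectorEnv
    from tianshou.policy import BasePolicy


    class ConstantPolicy(BasePolicy):
        # always picks action 0; stands in for a trained policy
        def forward(self, batch, state=None, **kwargs):
            return Batch(act=np.zeros(len(batch.obs), dtype=int))

        def learn(self, batch, **kwargs):
            return {}


    envs = VectorEnv([lambda: gym.make('CartPole-v0') for _ in range(2)])
    test_collector = Collector(ConstantPolicy(), envs)  # buffer=None: store nothing
    print(test_collector.collect(n_episode=4))          # only the statistics dict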
@@ -90,7 +87,6 @@ def __init__(self, env: Union[gym.Env, BaseVectorEnv], buffer: Optional[ReplayBuffer] = None, preprocess_fn: Callable[[Any], Union[dict, Batch]] = None, - stat_size: Optional[int] = 100, action_noise: Optional[BaseNoise] = None, reward_metric: Optional[Callable[[np.ndarray], float]] = None, ) -> None: @@ -101,12 +97,10 @@ def __init__(self, self.env_num = len(env) # need cache buffers before storing in the main buffer self._cached_buf = [ListReplayBuffer() for _ in range(self.env_num)] - self.collect_time, self.collect_step, self.collect_episode = 0., 0, 0 self.buffer = buffer self.policy = policy self.preprocess_fn = preprocess_fn self.process_fn = policy.process_fn - self.stat_size = stat_size self._action_noise = action_noise self._rew_metric = reward_metric or Collector._default_rew_metric self.reset() @@ -126,9 +120,6 @@ def reset(self) -> None: obs_next={}, policy={}) self.reset_env() self.reset_buffer() - self.step_speed = MovAvg(self.stat_size) - self.episode_speed = MovAvg(self.stat_size) - self.collect_time, self.collect_step, self.collect_episode = 0., 0, 0 if self._action_noise is not None: self._action_noise.reset() @@ -307,11 +298,6 @@ def collect(self, # generate the statistics cur_episode = sum(cur_episode) duration = max(time.time() - start_time, 1e-9) - self.step_speed.add(cur_step / duration) - self.episode_speed.add(cur_episode / duration) - self.collect_step += cur_step - self.collect_episode += cur_episode - self.collect_time += duration if isinstance(n_episode, list): n_episode = np.sum(n_episode) else: @@ -322,8 +308,8 @@ def collect(self, return { 'n/ep': cur_episode, 'n/st': cur_step, - 'v/st': self.step_speed.get(), - 'v/ep': self.episode_speed.get(), + 'v/st': cur_step / duration, + 'v/ep': cur_episode / duration, 'rew': reward_sum, 'len': length_sum / n_episode, } From 00a3e1ee32090bc312b20616a51d28bed44d4750 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Thu, 23 Jul 2020 13:00:00 +0800 Subject: [PATCH 11/17] add back collect_time, collect_step and collect_episode --- tianshou/data/collector.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/tianshou/data/collector.py b/tianshou/data/collector.py index 9701b2459..272129afb 100644 --- a/tianshou/data/collector.py +++ b/tianshou/data/collector.py @@ -97,6 +97,7 @@ def __init__(self, self.env_num = len(env) # need cache buffers before storing in the main buffer self._cached_buf = [ListReplayBuffer() for _ in range(self.env_num)] + self.collect_time, self.collect_step, self.collect_episode = 0., 0, 0 self.buffer = buffer self.policy = policy self.preprocess_fn = preprocess_fn @@ -120,6 +121,7 @@ def reset(self) -> None: obs_next={}, policy={}) self.reset_env() self.reset_buffer() + self.collect_time, self.collect_step, self.collect_episode = 0., 0, 0 if self._action_noise is not None: self._action_noise.reset() @@ -298,6 +300,9 @@ def collect(self, # generate the statistics cur_episode = sum(cur_episode) duration = max(time.time() - start_time, 1e-9) + self.collect_step += cur_step + self.collect_episode += cur_episode + self.collect_time += duration if isinstance(n_episode, list): n_episode = np.sum(n_episode) else: From fae0dd6cc3e04a17a5e8c7a2062a8c7eec8fedb4 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Thu, 23 Jul 2020 13:35:22 +0800 Subject: [PATCH 12/17] simplify collector --- tianshou/data/collector.py | 29 ++++++++++++----------------- 1 file changed, 12 insertions(+), 17 deletions(-) diff --git a/tianshou/data/collector.py b/tianshou/data/collector.py index 272129afb..3e7012851 
100644 --- a/tianshou/data/collector.py +++ b/tianshou/data/collector.py @@ -142,8 +142,6 @@ def reset_env(self) -> None: if self.preprocess_fn: obs = self.preprocess_fn(obs=obs).get('obs', obs) self.data.obs = obs - self.reward = 0. # will be specified when the first data is ready - self.length = np.zeros(self.env_num) for b in self._cached_buf: b.reset() @@ -170,8 +168,8 @@ def _reset_state(self, id: Union[int, List[int]]) -> None: state.empty_(id) def collect(self, - n_step: int = 0, - n_episode: Union[int, List[int]] = 0, + n_step: Optional[int] = None, + n_episode: Optional[Union[int, List[int]]] = None, random: bool = False, render: Optional[float] = None, log_fn: Optional[Callable[[dict], None]] = None @@ -203,11 +201,11 @@ def collect(self, * ``rew`` the mean reward over collected episodes. * ``len`` the mean length over collected episodes. """ - start_time = time.time() - assert sum([(n_step != 0), (n_episode != 0)]) == 1, \ + assert (n_step and not n_episode) or (not n_step and n_episode), \ "One and only one collection number specification is permitted!" + start_time = time.time() cur_step, cur_episode = 0, np.zeros(self.env_num) - reward_sum, length_sum = 0., 0 + reward_sum = [0.0] * self.env_num while True: if cur_step >= 100000 and cur_episode.sum() == 0: warnings.warn( @@ -257,23 +255,20 @@ def collect(self, time.sleep(render) # add data into the buffer - self.length += 1 - self.reward += self.data.rew if self.preprocess_fn: result = self.preprocess_fn(**self.data) self.data.update(result) for i in range(self.env_num): self._cached_buf[i].add(**self.data[i]) if self.data.done[i]: - if n_step != 0 or np.isscalar(n_episode) or \ + if n_step or np.isscalar(n_episode) or \ cur_episode[i] < n_episode[i]: cur_episode[i] += 1 - reward_sum += self.reward[i] - length_sum += self.length[i] + reward_sum[i] += np.asarray( + self._cached_buf[i].rew).sum(axis=0) cur_step += len(self._cached_buf[i]) if self.buffer is not None: self.buffer.update(self._cached_buf[i]) - self.reward[i], self.length[i] = 0., 0 self._cached_buf[i].reset() self._reset_state(i) obs_next = self.data.obs_next @@ -286,13 +281,13 @@ def collect(self, else: obs_next[env_ind] = obs_reset self.data.obs_next = obs_next - if n_episode != 0: + if n_episode: if isinstance(n_episode, list) and \ (cur_episode >= np.array(n_episode)).all() or \ np.isscalar(n_episode) and \ cur_episode.sum() >= n_episode: break - if n_step != 0 and cur_step >= n_step: + if n_step and cur_step >= n_step: break self.data.obs = self.data.obs_next self.data.obs = self.data.obs_next @@ -307,7 +302,7 @@ def collect(self, n_episode = np.sum(n_episode) else: n_episode = max(cur_episode, 1) - reward_sum /= n_episode + reward_sum = np.asarray(reward_sum).sum(axis=0) / n_episode if np.asanyarray(reward_sum).size > 1: # non-scalar reward_sum reward_sum = self._rew_metric(reward_sum) return { @@ -316,7 +311,7 @@ def collect(self, 'v/st': cur_step / duration, 'v/ep': cur_episode / duration, 'rew': reward_sum, - 'len': length_sum / n_episode, + 'len': cur_step / n_episode, } def sample(self, batch_size: int) -> Batch: From 1990a232df23c6fa14cc12323e53438da2f390f6 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Thu, 23 Jul 2020 14:33:37 +0800 Subject: [PATCH 13/17] simplify code --- tianshou/data/collector.py | 58 +++++++++++++++++++------------------- 1 file changed, 29 insertions(+), 29 deletions(-) diff --git a/tianshou/data/collector.py b/tianshou/data/collector.py index 3e7012851..624aee759 100644 --- a/tianshou/data/collector.py +++ 
b/tianshou/data/collector.py @@ -177,9 +177,10 @@ def collect(self, """Collect a specified number of step or episode. :param int n_step: how many steps you want to collect. - :param n_episode: how many episodes you want to collect (in each - environment). - :type n_episode: int or list + :param n_episode: how many episodes you want to collect. If it is + an int, it means to collect totally ``n_episode`` episodes; if + it is a list, it means to collect ``n_episode[i]`` episodes in + the i-th environment :param bool random: whether to use random policy for collecting data, defaults to ``False``. :param float render: the sleep time between rendering consecutive @@ -204,10 +205,12 @@ def collect(self, assert (n_step and not n_episode) or (not n_step and n_episode), \ "One and only one collection number specification is permitted!" start_time = time.time() - cur_step, cur_episode = 0, np.zeros(self.env_num) - reward_sum = [0.0] * self.env_num + step_count = 0 + # episode of each environment + episode_count = np.zeros(self.env_num) + reward_total = 0.0 while True: - if cur_step >= 100000 and cur_episode.sum() == 0: + if step_count >= 100000 and episode_count.sum() == 0: warnings.warn( 'There are already many steps in an episode. ' 'You should add a time limitation to your environment!', @@ -262,11 +265,11 @@ def collect(self, self._cached_buf[i].add(**self.data[i]) if self.data.done[i]: if n_step or np.isscalar(n_episode) or \ - cur_episode[i] < n_episode[i]: - cur_episode[i] += 1 - reward_sum[i] += np.asarray( + episode_count[i] < n_episode[i]: + episode_count[i] += 1 + reward_total += np.asarray( self._cached_buf[i].rew).sum(axis=0) - cur_step += len(self._cached_buf[i]) + step_count += len(self._cached_buf[i]) if self.buffer is not None: self.buffer.update(self._cached_buf[i]) self._cached_buf[i].reset() @@ -283,35 +286,32 @@ def collect(self, self.data.obs_next = obs_next if n_episode: if isinstance(n_episode, list) and \ - (cur_episode >= np.array(n_episode)).all() or \ + (episode_count >= np.array(n_episode)).all() or \ np.isscalar(n_episode) and \ - cur_episode.sum() >= n_episode: + episode_count.sum() >= n_episode: break - if n_step and cur_step >= n_step: + if n_step and step_count >= n_step: break self.data.obs = self.data.obs_next self.data.obs = self.data.obs_next # generate the statistics - cur_episode = sum(cur_episode) + episode_count = sum(episode_count) duration = max(time.time() - start_time, 1e-9) - self.collect_step += cur_step - self.collect_episode += cur_episode + self.collect_step += step_count + self.collect_episode += episode_count self.collect_time += duration - if isinstance(n_episode, list): - n_episode = np.sum(n_episode) - else: - n_episode = max(cur_episode, 1) - reward_sum = np.asarray(reward_sum).sum(axis=0) / n_episode - if np.asanyarray(reward_sum).size > 1: # non-scalar reward_sum - reward_sum = self._rew_metric(reward_sum) + # average reward across the number of episodes + reward_avg = reward_total / episode_count + if np.asanyarray(reward_avg).size > 1: # non-scalar reward_avg + reward_avg = self._rew_metric(reward_avg) return { - 'n/ep': cur_episode, - 'n/st': cur_step, - 'v/st': cur_step / duration, - 'v/ep': cur_episode / duration, - 'rew': reward_sum, - 'len': cur_step / n_episode, + 'n/ep': episode_count, + 'n/st': step_count, + 'v/st': step_count / duration, + 'v/ep': episode_count / duration, + 'rew': reward_avg, + 'len': step_count / episode_count, } def sample(self, batch_size: int) -> Batch: From ff98a6f0584ef0789401133ae11d9f3528b58e5b Mon Sep 17 
00:00:00 2001 From: youkaichao Date: Thu, 23 Jul 2020 15:22:43 +0800 Subject: [PATCH 14/17] simplify code --- tianshou/data/collector.py | 27 +++++++++++++-------------- 1 file changed, 13 insertions(+), 14 deletions(-) diff --git a/tianshou/data/collector.py b/tianshou/data/collector.py index 624aee759..3ed3c5937 100644 --- a/tianshou/data/collector.py +++ b/tianshou/data/collector.py @@ -178,9 +178,9 @@ def collect(self, :param int n_step: how many steps you want to collect. :param n_episode: how many episodes you want to collect. If it is - an int, it means to collect totally ``n_episode`` episodes; if - it is a list, it means to collect ``n_episode[i]`` episodes in - the i-th environment + an int, it means to collect at lease ``n_episode`` episodes; if + it is list, it means to collect exactly ``n_episode[i]`` episodes + in the i-th environment :param bool random: whether to use random policy for collecting data, defaults to ``False``. :param float render: the sleep time between rendering consecutive @@ -254,8 +254,7 @@ def collect(self, log_fn(self.data.info) if render: self.render() - if render > 0: - time.sleep(render) + time.sleep(render) # add data into the buffer if self.preprocess_fn: @@ -283,17 +282,17 @@ def collect(self, obs=obs_reset).get('obs', obs_reset) else: obs_next[env_ind] = obs_reset - self.data.obs_next = obs_next - if n_episode: - if isinstance(n_episode, list) and \ - (episode_count >= np.array(n_episode)).all() or \ - np.isscalar(n_episode) and \ + self.data.obs = obs_next + if n_step: + if step_count >= n_step: + break + else: + if isinstance(n_episode, int) and \ episode_count.sum() >= n_episode: break - if n_step and step_count >= n_step: - break - self.data.obs = self.data.obs_next - self.data.obs = self.data.obs_next + if isinstance(n_episode, list) and \ + (episode_count >= n_episode).all(): + break # generate the statistics episode_count = sum(episode_count) From 1b51b336e9057c2d61a48122a7d432fbe7932b32 Mon Sep 17 00:00:00 2001 From: Trinkle23897 <463003665@qq.com> Date: Thu, 23 Jul 2020 15:37:34 +0800 Subject: [PATCH 15/17] remove log_fn since it is overlap with preprocess_fn --- test/base/test_collector.py | 61 ++++++++++++++++++----------------- tianshou/data/collector.py | 5 --- tianshou/trainer/offpolicy.py | 5 +-- tianshou/trainer/onpolicy.py | 5 +-- 4 files changed, 33 insertions(+), 43 deletions(-) diff --git a/test/base/test_collector.py b/test/base/test_collector.py index d99cf4dbc..39af7836b 100644 --- a/test/base/test_collector.py +++ b/test/base/test_collector.py @@ -25,29 +25,26 @@ def learn(self): pass -def preprocess_fn(**kwargs): - # modify info before adding into the buffer - # if info is not provided from env, it will be a ``Batch()``. - if not kwargs.get('info', Batch()).is_empty(): - n = len(kwargs['obs']) - info = kwargs['info'] - for i in range(n): - info[i].update(rew=kwargs['rew'][i]) - return {'info': info} - # or: return Batch(info=info) - else: - return Batch() - - -class Logger(object): +class Logger: def __init__(self, writer): self.cnt = 0 self.writer = writer - def log(self, info): - self.writer.add_scalar( - 'key', np.mean(info['key']), global_step=self.cnt) - self.cnt += 1 + def preprocess_fn(self, **kwargs): + # modify info before adding into the buffer, and recorded into tfb + # if info is not provided from env, it will be a ``Batch()``. 
+ if not kwargs.get('info', Batch()).is_empty(): + n = len(kwargs['obs']) + info = kwargs['info'] + for i in range(n): + info[i].update(rew=kwargs['rew'][i]) + self.writer.add_scalar('key', np.mean( + info['key']), global_step=self.cnt) + self.cnt += 1 + return Batch(info=info) + # or: return {'info': info} + else: + return Batch() def test_collector(): @@ -60,16 +57,16 @@ def test_collector(): policy = MyPolicy() env = env_fns[0]() c0 = Collector(policy, env, ReplayBuffer(size=100, ignore_obs_next=False), - preprocess_fn) - c0.collect(n_step=3, log_fn=logger.log) + logger.preprocess_fn) + c0.collect(n_step=3) assert np.allclose(c0.buffer.obs[:3], [0, 1, 0]) assert np.allclose(c0.buffer[:3].obs_next, [1, 2, 1]) - c0.collect(n_episode=3, log_fn=logger.log) + c0.collect(n_episode=3) assert np.allclose(c0.buffer.obs[:8], [0, 1, 0, 1, 0, 1, 0, 1]) assert np.allclose(c0.buffer[:8].obs_next, [1, 2, 1, 2, 1, 2, 1, 2]) c0.collect(n_step=3, random=True) c1 = Collector(policy, venv, ReplayBuffer(size=100, ignore_obs_next=False), - preprocess_fn) + logger.preprocess_fn) c1.collect(n_step=6) assert np.allclose(c1.buffer.obs[:11], [0, 1, 0, 1, 2, 0, 1, 0, 1, 2, 3]) assert np.allclose(c1.buffer[:11].obs_next, @@ -80,7 +77,7 @@ def test_collector(): [1, 2, 3, 4, 5, 1, 2, 1, 2, 3]) c1.collect(n_episode=3, random=True) c2 = Collector(policy, dum, ReplayBuffer(size=100, ignore_obs_next=False), - preprocess_fn) + logger.preprocess_fn) c2.collect(n_episode=[1, 2, 2, 2]) assert np.allclose(c2.buffer.obs_next[:26], [ 1, 2, 1, 2, 3, 1, 2, 3, 4, 1, 2, 3, 4, 5, @@ -94,15 +91,17 @@ def test_collector(): def test_collector_with_dict_state(): + writer = SummaryWriter('log/ds_collector') + logger = Logger(writer) env = MyTestEnv(size=5, sleep=0, dict_state=True) policy = MyPolicy(dict_state=True) - c0 = Collector(policy, env, ReplayBuffer(size=100), preprocess_fn) + c0 = Collector(policy, env, ReplayBuffer(size=100), logger.preprocess_fn) c0.collect(n_step=3) c0.collect(n_episode=2) env_fns = [lambda x=i: MyTestEnv(size=x, sleep=0, dict_state=True) for i in [2, 3, 4, 5]] envs = VectorEnv(env_fns) - c1 = Collector(policy, envs, ReplayBuffer(size=100), preprocess_fn) + c1 = Collector(policy, envs, ReplayBuffer(size=100), logger.preprocess_fn) c1.collect(n_step=10) c1.collect(n_episode=[2, 1, 1, 2]) batch = c1.sample(10) @@ -113,7 +112,7 @@ def test_collector_with_dict_state(): 0., 1., 2., 0., 1., 0., 1., 2., 3., 0., 1., 2., 3., 4., 0., 1., 0., 1., 2., 0., 1., 0., 1., 2., 3., 0., 1., 2., 3., 4.]) c2 = Collector(policy, envs, ReplayBuffer(size=100, stack_num=4), - preprocess_fn) + logger.preprocess_fn) c2.collect(n_episode=[0, 0, 0, 10]) batch = c2.sample(10) print(batch['obs_next']['index']) @@ -122,10 +121,12 @@ def test_collector_with_dict_state(): def test_collector_with_ma(): def reward_metric(x): return x.sum() + writer = SummaryWriter('log/ma_collector') + logger = Logger(writer) env = MyTestEnv(size=5, sleep=0, ma_rew=4) policy = MyPolicy() c0 = Collector(policy, env, ReplayBuffer(size=100), - preprocess_fn, reward_metric=reward_metric) + logger.preprocess_fn, reward_metric=reward_metric) # n_step=3 will collect a full episode r = c0.collect(n_step=3)['rew'] assert np.asanyarray(r).size == 1 and r == 4. 
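As the updated tests show, ``preprocess_fn`` subsumes the old ``log_fn``: it is called with the freshly collected fields as keyword arguments (``obs``, ``rew``, ``done``, ``info`` and so on, or just ``obs`` right after a reset) and returns only the fields it wants to overwrite, so logging can be a plain side effect. A minimal sketch with an illustrative reward-clipping transform that is not part of this patch: ::

    import numpy as np
    from tianshou.data import Batch


    def clip_reward_fn(**kwargs):
        # overwrite rew with a clipped copy and leave everything else alone
        if 'rew' in kwargs:
            return Batch(rew=np.clip(kwargs['rew'], -1.0, 1.0))
        return Batch()  # e.g. the reset-time call, which only passes obs

    # usage: Collector(policy, envs, ReplayBuffer(size=100), clip_reward_fn)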
@@ -135,7 +136,7 @@ def reward_metric(x): for i in [2, 3, 4, 5]] envs = VectorEnv(env_fns) c1 = Collector(policy, envs, ReplayBuffer(size=100), - preprocess_fn, reward_metric=reward_metric) + logger.preprocess_fn, reward_metric=reward_metric) r = c1.collect(n_step=10)['rew'] assert np.asanyarray(r).size == 1 and r == 4. r = c1.collect(n_episode=[2, 1, 1, 2])['rew'] @@ -154,7 +155,7 @@ def reward_metric(x): assert np.allclose(c0.buffer[:len(c0.buffer)].rew, [[x] * 4 for x in rew]) c2 = Collector(policy, envs, ReplayBuffer(size=100, stack_num=4), - preprocess_fn, reward_metric=reward_metric) + logger.preprocess_fn, reward_metric=reward_metric) r = c2.collect(n_episode=[0, 0, 0, 10])['rew'] assert np.asanyarray(r).size == 1 and r == 4. batch = c2.sample(10) diff --git a/tianshou/data/collector.py b/tianshou/data/collector.py index 3ed3c5937..2e5a8366b 100644 --- a/tianshou/data/collector.py +++ b/tianshou/data/collector.py @@ -172,7 +172,6 @@ def collect(self, n_episode: Optional[Union[int, List[int]]] = None, random: bool = False, render: Optional[float] = None, - log_fn: Optional[Callable[[dict], None]] = None ) -> Dict[str, float]: """Collect a specified number of step or episode. @@ -185,8 +184,6 @@ def collect(self, defaults to ``False``. :param float render: the sleep time between rendering consecutive frames, defaults to ``None`` (no rendering). - :param function log_fn: a function which receives env info, typically - for tensorboard logging. .. note:: @@ -250,8 +247,6 @@ def collect(self, # move data to self.data self.data.update(obs_next=obs_next, rew=rew, done=done, info=info) - if log_fn: - log_fn(self.data.info) if render: self.render() time.sleep(render) diff --git a/tianshou/trainer/offpolicy.py b/tianshou/trainer/offpolicy.py index edcf0fdb8..408d4e738 100644 --- a/tianshou/trainer/offpolicy.py +++ b/tianshou/trainer/offpolicy.py @@ -23,7 +23,6 @@ def offpolicy_trainer( test_fn: Optional[Callable[[int], None]] = None, stop_fn: Optional[Callable[[float], bool]] = None, save_fn: Optional[Callable[[BasePolicy], None]] = None, - log_fn: Optional[Callable[[dict], None]] = None, writer: Optional[SummaryWriter] = None, log_interval: int = 1, verbose: bool = True, @@ -61,7 +60,6 @@ def offpolicy_trainer( :param function stop_fn: a function receives the average undiscounted returns of the testing result, return a boolean which indicates whether reaching the goal. - :param function log_fn: a function receives env info for logging. :param torch.utils.tensorboard.SummaryWriter writer: a TensorBoard SummaryWriter. :param int log_interval: the log interval of the writer. 
@@ -83,8 +81,7 @@ def offpolicy_trainer( with tqdm.tqdm(total=step_per_epoch, desc=f'Epoch #{epoch}', **tqdm_config) as t: while t.n < t.total: - result = train_collector.collect(n_step=collect_per_step, - log_fn=log_fn) + result = train_collector.collect(n_step=collect_per_step) data = {} if test_in_train and stop_fn and stop_fn(result['rew']): test_result = test_episode( diff --git a/tianshou/trainer/onpolicy.py b/tianshou/trainer/onpolicy.py index b0d68ff2a..8f77dd1e4 100644 --- a/tianshou/trainer/onpolicy.py +++ b/tianshou/trainer/onpolicy.py @@ -23,7 +23,6 @@ def onpolicy_trainer( test_fn: Optional[Callable[[int], None]] = None, stop_fn: Optional[Callable[[float], bool]] = None, save_fn: Optional[Callable[[BasePolicy], None]] = None, - log_fn: Optional[Callable[[dict], None]] = None, writer: Optional[SummaryWriter] = None, log_interval: int = 1, verbose: bool = True, @@ -62,7 +61,6 @@ def onpolicy_trainer( :param function stop_fn: a function receives the average undiscounted returns of the testing result, return a boolean which indicates whether reaching the goal. - :param function log_fn: a function receives env info for logging. :param torch.utils.tensorboard.SummaryWriter writer: a TensorBoard SummaryWriter. :param int log_interval: the log interval of the writer. @@ -84,8 +82,7 @@ def onpolicy_trainer( with tqdm.tqdm(total=step_per_epoch, desc=f'Epoch #{epoch}', **tqdm_config) as t: while t.n < t.total: - result = train_collector.collect(n_episode=collect_per_step, - log_fn=log_fn) + result = train_collector.collect(n_episode=collect_per_step) data = {} if test_in_train and stop_fn and stop_fn(result['rew']): test_result = test_episode( From 605b0acf77dfba0e6a2596e7c2b7d48835f88471 Mon Sep 17 00:00:00 2001 From: Trinkle23897 <463003665@qq.com> Date: Thu, 23 Jul 2020 15:43:13 +0800 Subject: [PATCH 16/17] minor fix --- test/base/test_collector.py | 31 +++++++++++++++++++++---------- 1 file changed, 21 insertions(+), 10 deletions(-) diff --git a/test/base/test_collector.py b/test/base/test_collector.py index 39af7836b..702e93a3c 100644 --- a/test/base/test_collector.py +++ b/test/base/test_collector.py @@ -46,6 +46,19 @@ def preprocess_fn(self, **kwargs): else: return Batch() + @staticmethod + def single_preprocess_fn(**kwargs): + # same as above, without tfb + if not kwargs.get('info', Batch()).is_empty(): + n = len(kwargs['obs']) + info = kwargs['info'] + for i in range(n): + info[i].update(rew=kwargs['rew'][i]) + return Batch(info=info) + # or: return {'info': info} + else: + return Batch() + def test_collector(): writer = SummaryWriter('log/collector') @@ -91,17 +104,17 @@ def test_collector(): def test_collector_with_dict_state(): - writer = SummaryWriter('log/ds_collector') - logger = Logger(writer) env = MyTestEnv(size=5, sleep=0, dict_state=True) policy = MyPolicy(dict_state=True) - c0 = Collector(policy, env, ReplayBuffer(size=100), logger.preprocess_fn) + c0 = Collector(policy, env, ReplayBuffer(size=100), + Logger.single_preprocess_fn) c0.collect(n_step=3) c0.collect(n_episode=2) env_fns = [lambda x=i: MyTestEnv(size=x, sleep=0, dict_state=True) for i in [2, 3, 4, 5]] envs = VectorEnv(env_fns) - c1 = Collector(policy, envs, ReplayBuffer(size=100), logger.preprocess_fn) + c1 = Collector(policy, envs, ReplayBuffer(size=100), + Logger.single_preprocess_fn) c1.collect(n_step=10) c1.collect(n_episode=[2, 1, 1, 2]) batch = c1.sample(10) @@ -112,7 +125,7 @@ def test_collector_with_dict_state(): 0., 1., 2., 0., 1., 0., 1., 2., 3., 0., 1., 2., 3., 4., 0., 1., 0., 1., 2., 0., 1., 
0., 1., 2., 3., 0., 1., 2., 3., 4.]) c2 = Collector(policy, envs, ReplayBuffer(size=100, stack_num=4), - logger.preprocess_fn) + Logger.single_preprocess_fn) c2.collect(n_episode=[0, 0, 0, 10]) batch = c2.sample(10) print(batch['obs_next']['index']) @@ -121,12 +134,10 @@ def test_collector_with_dict_state(): def test_collector_with_ma(): def reward_metric(x): return x.sum() - writer = SummaryWriter('log/ma_collector') - logger = Logger(writer) env = MyTestEnv(size=5, sleep=0, ma_rew=4) policy = MyPolicy() c0 = Collector(policy, env, ReplayBuffer(size=100), - logger.preprocess_fn, reward_metric=reward_metric) + Logger.single_preprocess_fn, reward_metric=reward_metric) # n_step=3 will collect a full episode r = c0.collect(n_step=3)['rew'] assert np.asanyarray(r).size == 1 and r == 4. @@ -136,7 +147,7 @@ def reward_metric(x): for i in [2, 3, 4, 5]] envs = VectorEnv(env_fns) c1 = Collector(policy, envs, ReplayBuffer(size=100), - logger.preprocess_fn, reward_metric=reward_metric) + Logger.single_preprocess_fn, reward_metric=reward_metric) r = c1.collect(n_step=10)['rew'] assert np.asanyarray(r).size == 1 and r == 4. r = c1.collect(n_episode=[2, 1, 1, 2])['rew'] @@ -155,7 +166,7 @@ def reward_metric(x): assert np.allclose(c0.buffer[:len(c0.buffer)].rew, [[x] * 4 for x in rew]) c2 = Collector(policy, envs, ReplayBuffer(size=100, stack_num=4), - logger.preprocess_fn, reward_metric=reward_metric) + Logger.single_preprocess_fn, reward_metric=reward_metric) r = c2.collect(n_episode=[0, 0, 0, 10])['rew'] assert np.asanyarray(r).size == 1 and r == 4. batch = c2.sample(10) From c4026bf46816110c89403d5ba9bff8cfb0e773ca Mon Sep 17 00:00:00 2001 From: Trinkle23897 <463003665@qq.com> Date: Thu, 23 Jul 2020 16:04:29 +0800 Subject: [PATCH 17/17] minor fix --- test/base/test_collector.py | 8 ++++---- tianshou/data/collector.py | 14 +++++++------- 2 files changed, 11 insertions(+), 11 deletions(-) diff --git a/test/base/test_collector.py b/test/base/test_collector.py index 702e93a3c..9fa37b66f 100644 --- a/test/base/test_collector.py +++ b/test/base/test_collector.py @@ -72,11 +72,11 @@ def test_collector(): c0 = Collector(policy, env, ReplayBuffer(size=100, ignore_obs_next=False), logger.preprocess_fn) c0.collect(n_step=3) - assert np.allclose(c0.buffer.obs[:3], [0, 1, 0]) - assert np.allclose(c0.buffer[:3].obs_next, [1, 2, 1]) + assert np.allclose(c0.buffer.obs[:4], [0, 1, 0, 1]) + assert np.allclose(c0.buffer[:4].obs_next, [1, 2, 1, 2]) c0.collect(n_episode=3) - assert np.allclose(c0.buffer.obs[:8], [0, 1, 0, 1, 0, 1, 0, 1]) - assert np.allclose(c0.buffer[:8].obs_next, [1, 2, 1, 2, 1, 2, 1, 2]) + assert np.allclose(c0.buffer.obs[:10], [0, 1, 0, 1, 0, 1, 0, 1, 0, 1]) + assert np.allclose(c0.buffer[:10].obs_next, [1, 2, 1, 2, 1, 2, 1, 2, 1, 2]) c0.collect(n_step=3, random=True) c1 = Collector(policy, venv, ReplayBuffer(size=100, ignore_obs_next=False), logger.preprocess_fn) diff --git a/tianshou/data/collector.py b/tianshou/data/collector.py index 2e5a8366b..fa5108be7 100644 --- a/tianshou/data/collector.py +++ b/tianshou/data/collector.py @@ -75,7 +75,8 @@ class Collector(object): Collected data always consist of full episodes. So if only ``n_step`` argument is give, the collector may return the data more than the - ``n_step`` limitation. + ``n_step`` limitation. Same as ``n_episode`` for the multiple environment + case. .. note:: @@ -176,10 +177,10 @@ def collect(self, """Collect a specified number of step or episode. :param int n_step: how many steps you want to collect. 
- :param n_episode: how many episodes you want to collect. If it is - an int, it means to collect at lease ``n_episode`` episodes; if - it is list, it means to collect exactly ``n_episode[i]`` episodes - in the i-th environment + :param n_episode: how many episodes you want to collect. If it is an + int, it means to collect at lease ``n_episode`` episodes; if it is + a list, it means to collect exactly ``n_episode[i]`` episodes in + the i-th environment :param bool random: whether to use random policy for collecting data, defaults to ``False``. :param float render: the sleep time between rendering consecutive @@ -261,8 +262,7 @@ def collect(self, if n_step or np.isscalar(n_episode) or \ episode_count[i] < n_episode[i]: episode_count[i] += 1 - reward_total += np.asarray( - self._cached_buf[i].rew).sum(axis=0) + reward_total += np.sum(self._cached_buf[i].rew, axis=0) step_count += len(self._cached_buf[i]) if self.buffer is not None: self.buffer.update(self._cached_buf[i])
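After the whole series, ``collect`` keeps the semantics spelled out above: an int ``n_episode`` is a total across all environments and, like ``n_step``, may be overshot because only full episodes leave the cache buffers, while a list asks for an exact per-environment count; the reported speeds are plain per-call ratios rather than moving averages. A worked example of the returned dictionary with made-up counts: ::

    # suppose one collect() call finished 6 episodes totalling 120 transitions
    # in 0.5 s, with summed episode returns of 360.0 (all numbers illustrative)
    step_count, episode_count, duration, reward_total = 120, 6, 0.5, 360.0
    stats = {
        'n/ep': episode_count,                # episodes finished in this call
        'n/st': step_count,                   # transitions collected in this call
        'v/st': step_count / duration,        # 240.0 steps per second
        'v/ep': episode_count / duration,     # 12.0 episodes per second
        'rew': reward_total / episode_count,  # 60.0 mean episode return
        'len': step_count / episode_count,    # 20.0 mean episode length
    }
    assert stats['rew'] == 60.0 and stats['len'] == 20.0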