
Collector sampling with multiple environments does not seem to be unbiased with n_episodes #1042

@utkarshp

Description

  • I have marked all applicable categories:
    • exception-raising bug
    • RL algorithm bug
    • documentation request (i.e. "X is missing from the documentation.")
    • new feature request
    • design request (i.e. "X should be changed to Y.")
  • I have visited the source website
  • I have searched through the issue tracker for duplicates
  • I have mentioned version numbers, operating system and environment, where applicable:
```python
>>> import tianshou, gymnasium as gym, torch, numpy, sys
>>> print(tianshou.__version__, gym.__version__, torch.__version__, numpy.__version__, sys.version, sys.platform)
0.5.1 0.29.1 2.1.2 1.26.3 3.11.0 (main, Mar  1 2023, 18:26:19) [GCC 11.2.0] linux
```

I have observed that when Collector.collect is called with n_episode, it tends to sample more episodes from environments that take fewer steps to finish than from the others. The collect method contains some logic intended to prevent this, but it does not seem to be enough. Here is the code that tries to deal with this by removing surplus envs, from data/collector.py, lines 367-373 on the master branch:

```python
# remove surplus env id from ready_env_ids
# to avoid bias in selecting environments
if n_episode:
    surplus_env_num = len(ready_env_ids) - (n_episode - episode_count)
    if surplus_env_num > 0:
        mask = np.ones_like(ready_env_ids, dtype=bool)
        mask[env_ind_local[:surplus_env_num]] = False
        ready_env_ids = ready_env_ids[mask]
        self.data = self.data[mask]
```
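
To see what the mask step does in isolation, here is a minimal standalone sketch of a single iteration with made-up numbers (10 requested episodes, 6 ready envs, 3 of which have just finished). It reproduces only the ready_env_ids bookkeeping from the snippet above and omits self.data. Note that envs are only removed once surplus_env_num turns positive, i.e. near the end of collection, and that the envs dropped are simply the finished ones with the lowest local indices, not the ones that have already contributed the most episodes.

```python
import numpy as np

# Hypothetical state for a single iteration of the quoted loop.
n_episode = 10                       # episodes requested from collect()
episode_count = 6                    # episodes finished so far (3 of them just now)
ready_env_ids = np.arange(6)         # global ids of the envs still being stepped
env_ind_local = np.array([0, 2, 5])  # local indices of the envs that just finished

# 6 ready envs, but only 10 - 6 = 4 more episodes are needed -> 2 surplus envs.
surplus_env_num = len(ready_env_ids) - (n_episode - episode_count)
if surplus_env_num > 0:
    mask = np.ones_like(ready_env_ids, dtype=bool)
    # Drop the first surplus_env_num finished envs (in local index order).
    mask[env_ind_local[:surplus_env_num]] = False
    ready_env_ids = ready_env_ids[mask]

print(surplus_env_num, ready_env_ids)  # 2 [1 3 4 5]
```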

Here is the problem. Suppose we have 50 envs and want to collect 100 episodes. Ideally, each env would be stepped through exactly two episodes, regardless of episode length. Now suppose 30 of the 50 envs are slow and take 31 steps per episode, while the remaining 20 are fast and take 10 steps per episode. After 10 steps, episode_count = 20, so surplus_env_num = 50 - (100 - 20) = -30 < 0 and nothing changes. Each fast env then finishes another episode, giving episode_count = 40 after 20 steps and surplus_env_num = -10 < 0. The same thing happens once more: after 30 steps, episode_count = 60 and surplus_env_num = 10 > 0. At this point we have already gathered 60 episodes from the fast envs (3 each), but not a single slow env has finished an episode. The mask now drops 10 of the fast envs. At step 31 the 30 slow envs finish (episode_count = 90) and are all dropped as surplus (surplus_env_num = 40 - 10 = 30), so the 10 remaining fast envs supply the last 10 episodes at step 40. In the end, the 20 fast envs contribute 70 of the 100 episodes (3 or 4 each), while each slow env contributes just 1.
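
The walkthrough above can be checked with a small toy replay of the quoted selection logic. This is a sketch under simplifying assumptions, not Tianshou code: simulate_surplus_removal is a hypothetical helper, every env has a fixed episode length, all ready envs advance one step per loop iteration, and only the env bookkeeping (no observations, buffers or policies) is modelled.

```python
import numpy as np


def simulate_surplus_removal(ep_lens, n_episode):
    """Toy replay of the n_episode env-selection logic quoted above.

    Assumptions (not Tianshou code): env i always needs ep_lens[i] steps per
    episode, and every ready env advances exactly one step per iteration.
    Returns the number of finished episodes contributed by each env.
    """
    ep_lens = np.asarray(ep_lens)
    env_num = len(ep_lens)
    ready_env_ids = np.arange(min(env_num, n_episode))
    steps_in_episode = np.zeros(env_num, dtype=int)
    episodes_per_env = np.zeros(env_num, dtype=int)
    episode_count = 0
    while episode_count < n_episode:
        steps_in_episode[ready_env_ids] += 1
        done = steps_in_episode[ready_env_ids] >= ep_lens[ready_env_ids]
        env_ind_local = np.where(done)[0]
        if len(env_ind_local) == 0:
            continue
        env_ind_global = ready_env_ids[env_ind_local]
        episode_count += len(env_ind_local)
        episodes_per_env[env_ind_global] += 1
        steps_in_episode[env_ind_global] = 0  # finished envs are reset immediately
        # Same surplus-removal rule as in the quoted collector.py snippet.
        surplus_env_num = len(ready_env_ids) - (n_episode - episode_count)
        if surplus_env_num > 0:
            mask = np.ones_like(ready_env_ids, dtype=bool)
            mask[env_ind_local[:surplus_env_num]] = False
            ready_env_ids = ready_env_ids[mask]
    return episodes_per_env


# 20 fast envs (10 steps/episode) followed by 30 slow envs (31 steps/episode).
counts = simulate_surplus_removal([10] * 20 + [31] * 30, n_episode=100)
print(counts[:20].sum(), counts[20:].sum())  # 70 30
```

With the fast envs listed first, envs 0-9 end with 3 episodes, envs 10-19 with 4, and each of the 30 slow envs with 1. Listing the slow envs first changes which fast envs get the extra episode, but not the 70/30 split.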

I attempted to fix this as follows. The patch makes the envs that have finished their episode wait for the ones that have not, and then resets all envs at once.

```diff
+    def _reset_env_to_next(self, gym_reset_kwargs: Optional[Dict[str, Any]] = None) -> None:
+        gym_reset_kwargs = gym_reset_kwargs if gym_reset_kwargs else {}
+        obs, info = self.env.reset(**gym_reset_kwargs)
+        if self.preprocess_fn:
+            processed_data = self.preprocess_fn(
+                obs=obs, info=info, env_id=np.arange(self.env_num)
+            )
+            obs = processed_data.get("obs", obs)
+            info = processed_data.get("info", info)
+        self.data = Batch(
+            obs={},
+            act={},
+            rew={},
+            terminated={},
+            truncated={},
+            done={},
+            obs_next={},
+            info={},
+            policy={}
+        )
+        self.data.info = info
+        self.data.obs_next = obs
+
     def collect(
             self,
             n_step: Optional[int] = None,
@@ -69,8 +92,10 @@
                     "which may cause extra transitions collected into the buffer."
                 )
             ready_env_ids = np.arange(self.env_num)
+            init_ready_env_ids = np.arange(self.env_num)
         elif n_episode is not None:
             assert n_episode > 0
+            init_ready_env_ids = np.arange(min(self.env_num, n_episode))
             ready_env_ids = np.arange(min(self.env_num, n_episode))
             self.data = self.data[:min(self.env_num, n_episode)]
         else:
@@ -170,24 +195,18 @@
                 episode_lens.append(ep_len[env_ind_local])
                 episode_rews.append(ep_rew[env_ind_local])
                 episode_start_indices.append(ep_idx[env_ind_local])
-                # now we copy obs_next to obs, but since there might be
-                # finished episodes, we have to reset finished envs first.
-                self._reset_env_with_ids(
-                    env_ind_local, env_ind_global, gym_reset_kwargs
-                )
-                for i in env_ind_local:
-                    self._reset_state(i)
+                unfinished_ind_local = np.where(~done)[0]
 
                 # remove surplus env id from ready_env_ids
                 # to avoid bias in selecting environments
                 if n_episode:
-                    surplus_env_num = len(ready_env_ids) - (n_episode - episode_count)
-                    if surplus_env_num > 0:
-                        mask = np.ones_like(ready_env_ids, dtype=bool)
-                        mask[env_ind_local[:surplus_env_num]] = False
-                        ready_env_ids = ready_env_ids[mask]
-                        self.data = self.data[mask]
-
+                    ready_env_ids = ready_env_ids[unfinished_ind_local]
+                    self.data = self.data[unfinished_ind_local]
+                    if len(unfinished_ind_local) == 0:
+                        self._reset_env_to_next(gym_reset_kwargs)
+                        ready_env_ids = init_ready_env_ids
+                        for i in ready_env_ids:
+                            self._reset_state(i)
             self.data.obs = self.data.obs_next
 
             if (n_step and step_count >= n_step) or \
```
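
The effect of the proposed wait-then-reset behaviour on the same 50-env scenario can be sketched with a toy model. This is illustrative only: simulate_wait_then_reset is a hypothetical helper that assumes fixed episode lengths and lockstep stepping; it mimics the patch by dropping finished envs from ready_env_ids and restoring the initial set of env ids once every env has finished.

```python
import numpy as np


def simulate_wait_then_reset(ep_lens, n_episode):
    """Toy model of the proposed scheme (not Tianshou code).

    Finished envs leave the ready set and idle until every env in the current
    round has finished; then all envs are reset together. Assumes env i always
    needs ep_lens[i] steps per episode and ready envs advance in lockstep.
    Returns (episodes contributed per env, number of lockstep iterations).
    """
    ep_lens = np.asarray(ep_lens)
    init_ready_env_ids = np.arange(min(len(ep_lens), n_episode))
    ready_env_ids = init_ready_env_ids.copy()
    steps_in_episode = np.zeros(len(ep_lens), dtype=int)
    episodes_per_env = np.zeros(len(ep_lens), dtype=int)
    episode_count = 0
    iterations = 0
    while episode_count < n_episode:
        iterations += 1
        steps_in_episode[ready_env_ids] += 1
        done = steps_in_episode[ready_env_ids] >= ep_lens[ready_env_ids]
        finished = ready_env_ids[done]
        episode_count += len(finished)
        episodes_per_env[finished] += 1
        # Keep only unfinished envs, like ready_env_ids[unfinished_ind_local].
        ready_env_ids = ready_env_ids[~done]
        if len(ready_env_ids) == 0:
            # Every env has finished this round: reset all envs at once.
            steps_in_episode[:] = 0
            ready_env_ids = init_ready_env_ids.copy()
    return episodes_per_env, iterations


counts, iterations = simulate_wait_then_reset([10] * 20 + [31] * 30, n_episode=100)
print(np.unique(counts), iterations)  # [2] 62
```

Every env now contributes exactly two episodes, but each round lasts as long as the slowest env, so in this example the fast envs sit idle for 21 of every 31 steps.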

Metadata

• Labels: algorithm enhancement (Not quite a new algorithm, but an enhancement to algo. functionality); question (Further information is requested)
• Type: No type
• Projects: Status To do
• Milestone: No milestone
• Relationships: None yet
• Development: No branches or pull requests