Description
- I have marked all applicable categories:
- exception-raising bug
- RL algorithm bug
- documentation request (i.e. "X is missing from the documentation.")
- new feature request
- I have visited the source website
- I have searched through the issue tracker for duplicates
- I have mentioned version numbers, operating system and environment, where applicable:
```python
import tianshou, torch, sys
print(tianshou.__version__, torch.__version__, sys.version, sys.platform)
```
I have a question about using Collector to collect only a certain number of steps. Here is a simplified version of my code:
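```python
from tianshou.data import Collector

# policy, train_envs and buf are created elsewhere in my script
batch_size = 32
train_collector = Collector(policy, train_envs, buf)
train_collector.collect(n_step=batch_size)
```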
I expected train_collector to return after `batch_size` env steps had been collected. Instead, it kept collecting until an entire episode had finished.
Each episode in my env has a fixed length of 100k steps, so collecting one takes a very long time. I only want to collect 32 env steps, but train_collector collected a whole episode. After checking the code in Collector.collect, I found that step_count is only updated inside the `if done[j]:` branch:
```python
if done[j]:
    if not (isinstance(n_episode, list)
            and episode_count[i] >= n_episode[i]):
        episode_count[i] += 1
        rewards.append(self._rew_metric(
            np.sum(self._cached_buf[i].rew, axis=0)))
        step_count += len(self._cached_buf[i])
        if self.buffer is not None:
            self.buffer.update(self._cached_buf[i])
        if isinstance(n_episode, list) and \
                episode_count[i] >= n_episode[i]:
            # env i has collected enough data, it has finished
            finished_env_ids.append(i)
    self._cached_buf[i].reset()
    self._reset_state(j)
```
So it seems that the while loop does not break even when the actual number of env steps exceeds n_step, as long as the env is not done (I use only one env in this experiment).
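To check this reading, here is a toy model of the loop as described above. It is not Tianshou's actual code. The hypothetical `collect_like_current` function adds to `step_count` only when an episode ends, and then adds the whole cached episode at once. With a single env whose episodes have a fixed length, a request for `n_step=32` therefore returns only after a full episode.

```python
def collect_like_current(n_step: int, episode_len: int) -> int:
    """Toy model (hypothetical, not Tianshou code): step_count is only
    updated when an episode is done. Returns the real number of env steps
    taken before the loop exits."""
    step_count = 0  # the counter the loop condition checks
    env_steps = 0   # the real number of env.step() calls
    cur_len = 0     # length of the episode cached so far
    while True:
        env_steps += 1
        cur_len += 1
        done = cur_len == episode_len
        if done:
            # Only here is step_count increased, by the whole cached episode.
            step_count += cur_len
            cur_len = 0
        if step_count >= n_step:
            break
    return env_steps


# With 100-step episodes, asking for 32 steps still runs a full episode.
print(collect_like_current(n_step=32, episode_len=100))  # 100
```

With 100k-step episodes, the same model would run all 100k steps before returning.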
Is my understanding right? When n_step is given, is there a way to break the loop as soon as the number of env steps equals n_step? I don't want an entire episode to be collected when I only ask for n_step steps (e.g., one batch). I tried modifying a copy of your code, but I'm not sure whether that would affect the original design.
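One way to get the requested behaviour is to count every step as it happens and keep the env's current observation between calls. A call can then stop at exactly n_step, and the next call resumes the unfinished episode instead of resetting it. The sketch below assumes a single env with a gym-style `reset()`/`step()` that returns a 4-tuple. `FixedLengthEnv` and `StepCollector` are hypothetical names for illustration, not Tianshou's API. The constant action stands in for a policy.

```python
from typing import List, Tuple


class FixedLengthEnv:
    """Hypothetical env whose episodes last exactly `length` steps."""

    def __init__(self, length: int) -> None:
        self.length = length
        self.t = 0

    def reset(self) -> int:
        self.t = 0
        return self.t

    def step(self, action: int) -> Tuple[int, float, bool, dict]:
        self.t += 1
        done = self.t >= self.length
        return self.t, 1.0, done, {}


class StepCollector:
    """Sketch (hypothetical, not Tianshou's Collector): counts every env step
    and stops at n_step. The current observation is kept between calls, so
    an unfinished episode is resumed by the next collect() call."""

    def __init__(self, env: FixedLengthEnv) -> None:
        self.env = env
        self.obs = env.reset()
        # (obs, action, rew, done, obs_next) transitions
        self.buffer: List[Tuple[int, int, float, bool, int]] = []
        self.episode_len = 0  # length of the episode in progress

    def collect(self, n_step: int) -> dict:
        step_count = 0
        finished_lens = []
        while step_count < n_step:
            action = 0  # placeholder: a real collector would query the policy
            obs_next, rew, done, _ = self.env.step(action)
            self.buffer.append((self.obs, action, rew, done, obs_next))
            step_count += 1  # count every step, not only at episode end
            self.episode_len += 1
            if done:
                finished_lens.append(self.episode_len)
                self.episode_len = 0
                obs_next = self.env.reset()
            self.obs = obs_next
        return {"steps": step_count, "episodes": len(finished_lens),
                "lens": finished_lens}


collector = StepCollector(FixedLengthEnv(length=100))
for _ in range(4):
    result = collector.collect(n_step=32)
print(result, len(collector.buffer))  # {'steps': 32, 'episodes': 1, 'lens': [100]} 128
```

The design choice that matters is resetting the env only when `done` is true. If a partial episode were reset or discarded at the end of a call, episodes longer than n_step could never finish. The buffer now holds incomplete episodes, so anything that needs whole-episode returns has to wait until the episode ends.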