
Collector.collect runs an entire episode when only n_step is set #255

@eric-liuyd

Description

  • I have marked all applicable categories:
    • exception-raising bug
    • RL algorithm bug
    • documentation request (i.e. "X is missing from the documentation.")
    • new feature request
  • I have visited the source website
  • I have searched through the issue tracker for duplicates
  • I have mentioned version numbers, operating system and environment, where applicable:
    import tianshou, torch, sys
    print(tianshou.__version__, torch.__version__, sys.version, sys.platform)

I have a question about using Collector to collect only a certain number of steps. Here is a simplified example:

batch_size = 32
train_collector = Collector(policy, train_envs, buf)
train_collector.collect(n_step=batch_size)

I expected train_collector to return a result after 'batch_size' env steps had been collected. Instead, it collected an entire episode before returning.

One episode in my env has a fixed length of 100k steps, which takes a very long time to collect, so I only want to collect 32 env steps; nevertheless train_collector collected a whole episode. After reading the code of Collector.collect, I found that step_count is updated only inside the if done[j]: branch:

                if done[j]:
                    if not (isinstance(n_episode, list)
                            and episode_count[i] >= n_episode[i]):
                        episode_count[i] += 1
                        rewards.append(self._rew_metric(
                            np.sum(self._cached_buf[i].rew, axis=0)))
                        step_count += len(self._cached_buf[i])
                        if self.buffer is not None:
                            self.buffer.update(self._cached_buf[i])
                        if isinstance(n_episode, list) and \
                                episode_count[i] >= n_episode[i]:
                            # env i has collected enough data, it has finished
                            finished_env_ids.append(i)
                    self._cached_buf[i].reset()
                    self._reset_state(j)

So it seems the while loop will not break even when the real number of env steps exceeds n_step, as long as the env is not done (I use only one env in my experiment).
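The control flow described above can be reproduced with a toy loop (illustrative only, not Tianshou code): because step_count is incremented only when an episode finishes, the loop condition never sees the per-step progress.

```python
# Toy reproduction of the control flow described above (not Tianshou code).
# step_count is only updated when an episode ends, so the while loop
# cannot stop at n_step = 32 if the episode is longer than that.
n_step = 32
episode_len = 100      # stand-in for the real 100k-step episode
step_count = 0         # what the loop condition checks
env_steps = 0          # what the environment actually executes
while step_count < n_step:
    env_steps += 1
    done = (env_steps % episode_len == 0)
    if done:
        step_count += episode_len  # updated only in the done branch
# the loop exits only after a full episode: env_steps == 100, not 32
```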

Is my understanding correct? Is there a way to break the loop as soon as the number of env steps reaches n_step when n_step is given? I don't want an entire episode to be collected when I only ask for n_step steps (e.g., one batch). I tried modifying a copy of your code, but I'm not sure whether the original design would be affected.
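For comparison, here is a minimal, library-agnostic sketch of the behavior being asked for: stop as soon as n_step transitions have been collected, even mid-episode. The names `collect`, `env_step_fn`, and `make_env` are hypothetical and not part of Tianshou's API.

```python
def collect(env_step_fn, n_step):
    """Collect transitions and stop once n_step steps are gathered,
    even if the current episode has not terminated (hypothetical sketch)."""
    transitions = []
    while len(transitions) < n_step:
        obs, rew, done = env_step_fn()
        transitions.append((obs, rew, done))
    return transitions

def make_env(episode_len=100):
    """Deterministic fake env whose episodes last episode_len steps."""
    t = 0
    def step():
        nonlocal t
        t += 1
        return t, 0.0, t % episode_len == 0
    return step

batch = collect(make_env(), 32)   # returns after exactly 32 steps, mid-episode
```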

Metadata

Labels: enhancement (Feature that is not a new algorithm or an algorithm enhancement)
