From 71ef846c8c64dc5ca21fad811284e18dd5e30bd9 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Thu, 6 Aug 2020 14:02:53 +0800 Subject: [PATCH 01/11] add policy.update to enable post process and remove collector.sample --- test/base/test_collector.py | 8 ++++---- tianshou/data/collector.py | 5 +++++ tianshou/policy/base.py | 18 ++++++++++++++++-- tianshou/policy/modelfree/ddpg.py | 4 +--- tianshou/policy/modelfree/dqn.py | 4 +--- tianshou/policy/modelfree/sac.py | 4 +--- tianshou/policy/modelfree/td3.py | 3 +-- tianshou/trainer/offpolicy.py | 2 +- tianshou/trainer/onpolicy.py | 4 ++-- 9 files changed, 32 insertions(+), 20 deletions(-) diff --git a/test/base/test_collector.py b/test/base/test_collector.py index 68cf5c74c..9eab3744b 100644 --- a/test/base/test_collector.py +++ b/test/base/test_collector.py @@ -174,7 +174,7 @@ def test_collector_with_dict_state(): c1.seed(0) c1.collect(n_step=10) c1.collect(n_episode=[2, 1, 1, 2]) - batch = c1.sample(10) + batch = c1.buffer.sample(10)[0] print(batch) c0.buffer.update(c1.buffer) assert np.allclose(c0.buffer[:len(c0.buffer)].obs.index, np.expand_dims([ @@ -184,7 +184,7 @@ def test_collector_with_dict_state(): c2 = Collector(policy, envs, ReplayBuffer(size=100, stack_num=4), Logger.single_preprocess_fn) c2.collect(n_episode=[0, 0, 0, 10]) - batch = c2.sample(10) + batch = c2.buffer.sample(10)[0] print(batch['obs_next']['index']) @@ -209,7 +209,7 @@ def reward_metric(x): assert np.asanyarray(r).size == 1 and r == 4. r = c1.collect(n_episode=[2, 1, 1, 2])['rew'] assert np.asanyarray(r).size == 1 and r == 4. - batch = c1.sample(10) + batch = c1.buffer.sample(10)[0] print(batch) c0.buffer.update(c1.buffer) obs = np.array(np.expand_dims([ @@ -226,7 +226,7 @@ def reward_metric(x): Logger.single_preprocess_fn, reward_metric=reward_metric) r = c2.collect(n_episode=[0, 0, 0, 10])['rew'] assert np.asanyarray(r).size == 1 and r == 4. - batch = c2.sample(10) + batch = c2.buffer.sample(10)[0] print(batch['obs_next']) diff --git a/tianshou/data/collector.py b/tianshou/data/collector.py index 9105374bf..4426c0629 100644 --- a/tianshou/data/collector.py +++ b/tianshou/data/collector.py @@ -364,6 +364,11 @@ def sample(self, batch_size: int) -> Batch: the buffer, otherwise it will extract the data with the given batch_size. """ + import warnings + warnings.warn( + 'Collector.sample is deprecated and will cause error if you use ' + 'prioritized experience replay! Collector.sample will be removed' + ' after 0.3. Use policy.update instead!', DeprecationWarning) batch_data, indice = self.buffer.sample(batch_size) batch_data = self.process_fn(batch_data, self.buffer, indice) return batch_data diff --git a/tianshou/policy/base.py b/tianshou/policy/base.py index cc1a59326..a73e833da 100644 --- a/tianshou/policy/base.py +++ b/tianshou/policy/base.py @@ -216,9 +216,23 @@ def compute_nstep_return( batch.returns = target_q * gammas + returns # prio buffer update if isinstance(buffer, PrioritizedReplayBuffer): - batch.update_weight = buffer.update_weight - batch.indice = indice batch.weight = to_torch_as(batch.weight, target_q) else: batch.weight = torch.ones_like(target_q) return batch + + def post_process_fn(self, batch: Batch, + buffer: ReplayBuffer, indice: np.ndarray): + """Post-process the data from the provided replay buffer. Typical + usage is to update the sampling weight in prioritized experience + replay. Check out :ref:`policy_concept` for more information. 
+ """ + if isinstance(buffer, PrioritizedReplayBuffer): + buffer.update_weight(indice, batch.weight) + + def update(self, buffer: ReplayBuffer, sample_size: int, *args, **kwargs): + batch, indice = buffer.sample(sample_size) + batch = self.process_fn(batch, buffer, indice) + result = self.learn(batch, *args, **kwargs) + self.post_process_fn(batch, buffer, indice) + return result diff --git a/tianshou/policy/modelfree/ddpg.py b/tianshou/policy/modelfree/ddpg.py index 2205102f2..b6787652f 100644 --- a/tianshou/policy/modelfree/ddpg.py +++ b/tianshou/policy/modelfree/ddpg.py @@ -144,10 +144,8 @@ def learn(self, batch: Batch, **kwargs) -> Dict[str, float]: current_q = self.critic(batch.obs, batch.act).flatten() target_q = batch.returns.flatten() td = current_q - target_q - if hasattr(batch, 'update_weight'): # prio-buffer - batch.update_weight(batch.indice, td) critic_loss = (td.pow(2) * batch.weight).mean() - # critic_loss = F.mse_loss(current_q, target_q) + batch.weight = td self.critic_optim.zero_grad() critic_loss.backward() self.critic_optim.step() diff --git a/tianshou/policy/modelfree/dqn.py b/tianshou/policy/modelfree/dqn.py index c37dac515..8b8fef1ce 100644 --- a/tianshou/policy/modelfree/dqn.py +++ b/tianshou/policy/modelfree/dqn.py @@ -160,10 +160,8 @@ def learn(self, batch: Batch, **kwargs) -> Dict[str, float]: q = q[np.arange(len(q)), batch.act] r = to_torch_as(batch.returns, q).flatten() td = r - q - if hasattr(batch, 'update_weight'): # prio-buffer - batch.update_weight(batch.indice, td) loss = (td.pow(2) * batch.weight).mean() - # loss = F.mse_loss(q, r) + batch.weight = td loss.backward() self.optim.step() self._cnt += 1 diff --git a/tianshou/policy/modelfree/sac.py b/tianshou/policy/modelfree/sac.py index ce4a5baf0..ce5ccbd1d 100644 --- a/tianshou/policy/modelfree/sac.py +++ b/tianshou/policy/modelfree/sac.py @@ -154,9 +154,7 @@ def learn(self, batch: Batch, **kwargs) -> Dict[str, float]: self.critic2_optim.zero_grad() critic2_loss.backward() self.critic2_optim.step() - # prio-buffer - if hasattr(batch, 'update_weight'): - batch.update_weight(batch.indice, (td1 + td2) / 2.) + batch.weight = (td1 + td2) / 2. # actor obs_result = self(batch, explorating=False) a = obs_result.act diff --git a/tianshou/policy/modelfree/td3.py b/tianshou/policy/modelfree/td3.py index 698145f1d..1e33534ad 100644 --- a/tianshou/policy/modelfree/td3.py +++ b/tianshou/policy/modelfree/td3.py @@ -132,8 +132,7 @@ def learn(self, batch: Batch, **kwargs) -> Dict[str, float]: self.critic2_optim.zero_grad() critic2_loss.backward() self.critic2_optim.step() - if hasattr(batch, 'update_weight'): # prio-buffer - batch.update_weight(batch.indice, (td1 + td2) / 2.) + batch.weight = (td1 + td2) / 2. 
if self._cnt % self._freq == 0: actor_loss = -self.critic1( batch.obs, self(batch, eps=0).act).mean() diff --git a/tianshou/trainer/offpolicy.py b/tianshou/trainer/offpolicy.py index 408d4e738..2321ec219 100644 --- a/tianshou/trainer/offpolicy.py +++ b/tianshou/trainer/offpolicy.py @@ -103,7 +103,7 @@ def offpolicy_trainer( for i in range(update_per_step * min( result['n/st'] // collect_per_step, t.total - t.n)): global_step += 1 - losses = policy.learn(train_collector.sample(batch_size)) + losses = policy.update(train_collector.buffer, batch_size) for k in result.keys(): data[k] = f'{result[k]:.2f}' if writer and global_step % log_interval == 0: diff --git a/tianshou/trainer/onpolicy.py b/tianshou/trainer/onpolicy.py index dec42b256..574a197f5 100644 --- a/tianshou/trainer/onpolicy.py +++ b/tianshou/trainer/onpolicy.py @@ -101,8 +101,8 @@ def onpolicy_trainer( policy.train() if train_fn: train_fn(epoch) - losses = policy.learn( - train_collector.sample(0), batch_size, repeat_per_collect) + losses = policy.update( + train_collector.buffer, 0, batch_size, repeat_per_collect) train_collector.reset_buffer() step = 1 for k in losses.keys(): From 247bb4ddeeb4cbff3f8133b3eafcc7b392263bf1 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Thu, 6 Aug 2020 14:34:07 +0800 Subject: [PATCH 02/11] update doc in policy concept --- docs/tutorials/concepts.rst | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/docs/tutorials/concepts.rst b/docs/tutorials/concepts.rst index 3f033ead7..ea716740c 100644 --- a/docs/tutorials/concepts.rst +++ b/docs/tutorials/concepts.rst @@ -67,9 +67,11 @@ Tianshou aims to modularizing RL algorithms. It comes into several classes of po A policy class typically has four parts: * :meth:`~tianshou.policy.BasePolicy.__init__`: initialize the policy, including coping the target network and so on; +* :meth:`~tianshou.policy.BasePolicy.process_fn`: pre-process data from the replay buffer; * :meth:`~tianshou.policy.BasePolicy.forward`: compute action with given observation; -* :meth:`~tianshou.policy.BasePolicy.process_fn`: pre-process data from the replay buffer (this function can interact with replay buffer); * :meth:`~tianshou.policy.BasePolicy.learn`: update policy with a given batch of data. +* :meth:`~tianshou.policy.BasePolicy.post_process_fn`: update the buffer with a given batch of data. +* :meth:`~tianshou.policy.BasePolicy.update`: the main interface for training. This function samples data from buffer, pre-process data (such as computing n-step return), learn with the data, and finally post-process the data (can update buffer). Take 2-step return DQN as an example. The 2-step return DQN compute each frame's return as: @@ -125,10 +127,9 @@ Collector --------- The :class:`~tianshou.data.Collector` enables the policy to interact with different types of environments conveniently. -In short, :class:`~tianshou.data.Collector` has two main methods: +In short, :class:`~tianshou.data.Collector` has one main method: * :meth:`~tianshou.data.Collector.collect`: let the policy perform (at least) a specified number of step ``n_step`` or episode ``n_episode`` and store the data in the replay buffer; -* :meth:`~tianshou.data.Collector.sample`: sample a data batch from replay buffer; it will call :meth:`~tianshou.policy.BasePolicy.process_fn` before returning the final batch data. Why do we mention **at least** here? 
For multiple environments, we could not directly store the collected data into the replay buffer, since it breaks the principle of storing data chronologically. From 1f2c4caa98b44cfc81a8cf5ac2a39f86d53d6000 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Thu, 6 Aug 2020 22:08:26 +0800 Subject: [PATCH 03/11] remove collector.sample in doc --- tianshou/data/collector.py | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/tianshou/data/collector.py b/tianshou/data/collector.py index 4426c0629..e31d4a2d7 100644 --- a/tianshou/data/collector.py +++ b/tianshou/data/collector.py @@ -64,16 +64,6 @@ class Collector(object): # sleep time between rendering consecutive frames) collector.collect(n_episode=1, render=0.03) - # sample data with a given number of batch-size: - batch_data = collector.sample(batch_size=64) - # policy.learn(batch_data) # btw, vanilla policy gradient only - # supports on-policy training, so here we pick all data in the buffer - batch_data = collector.sample(batch_size=0) - policy.learn(batch_data) - # on-policy algorithms use the collected data only once, so here we - # clear the buffer - collector.reset_buffer() - Collected data always consist of full episodes. So if only ``n_step`` argument is give, the collector may return the data more than the ``n_step`` limitation. Same as ``n_episode`` for the multiple environment From 2fb5a1caf50d13c7b6b4ee06f3a6366bf8457548 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Thu, 6 Aug 2020 22:09:47 +0800 Subject: [PATCH 04/11] doc update of concepts --- docs/tutorials/concepts.rst | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/tutorials/concepts.rst b/docs/tutorials/concepts.rst index ea716740c..35d162f6c 100644 --- a/docs/tutorials/concepts.rst +++ b/docs/tutorials/concepts.rst @@ -64,9 +64,9 @@ Policy Tianshou aims to modularizing RL algorithms. It comes into several classes of policies in Tianshou. All of the policy classes must inherit :class:`~tianshou.policy.BasePolicy`. -A policy class typically has four parts: +A policy class typically has the following parts: -* :meth:`~tianshou.policy.BasePolicy.__init__`: initialize the policy, including coping the target network and so on; +* :meth:`~tianshou.policy.BasePolicy.__init__`: initialize the policy, including copying the target network and so on; * :meth:`~tianshou.policy.BasePolicy.process_fn`: pre-process data from the replay buffer; * :meth:`~tianshou.policy.BasePolicy.forward`: compute action with given observation; * :meth:`~tianshou.policy.BasePolicy.learn`: update policy with a given batch of data. From 724668bff5765c6ddc603609ef09b92097f66888 Mon Sep 17 00:00:00 2001 From: Trinkle23897 <463003665@qq.com> Date: Tue, 11 Aug 2020 15:23:29 +0800 Subject: [PATCH 05/11] docs --- README.md | 6 ++++-- docs/tutorials/concepts.rst | 9 +++------ 2 files changed, 7 insertions(+), 8 deletions(-) diff --git a/README.md b/README.md index 2c027d506..d8c6a3543 100644 --- a/README.md +++ b/README.md @@ -139,12 +139,14 @@ Check out the [GitHub Actions](https://github.com/thu-ml/tianshou/actions) page ### Modularized Policy -We decouple all of the algorithms into 4 parts: +We decouple all of the algorithms into roughly the following parts: - `__init__`: initialize the policy; - `forward`: to compute actions over given observations; - `process_fn`: to preprocess data from replay buffer (since we have reformulated all algorithms to replay-buffer based algorithms); -- `learn`: to learn from a given batch data. 
+- `learn`: to learn from a given batch data; +- `post_process_fn`: to update the replay buffer from the learning process (e.g., prioritized replay buffer needs to update the weight); +- `update`: the main interface for training, i.e., `process_fn -> learn -> post_process_fn`. Within this API, we can interact with different policies conveniently. diff --git a/docs/tutorials/concepts.rst b/docs/tutorials/concepts.rst index 35d162f6c..eaa0bbca9 100644 --- a/docs/tutorials/concepts.rst +++ b/docs/tutorials/concepts.rst @@ -67,11 +67,11 @@ Tianshou aims to modularizing RL algorithms. It comes into several classes of po A policy class typically has the following parts: * :meth:`~tianshou.policy.BasePolicy.__init__`: initialize the policy, including copying the target network and so on; -* :meth:`~tianshou.policy.BasePolicy.process_fn`: pre-process data from the replay buffer; * :meth:`~tianshou.policy.BasePolicy.forward`: compute action with given observation; +* :meth:`~tianshou.policy.BasePolicy.process_fn`: pre-process data from the replay buffer; * :meth:`~tianshou.policy.BasePolicy.learn`: update policy with a given batch of data. * :meth:`~tianshou.policy.BasePolicy.post_process_fn`: update the buffer with a given batch of data. -* :meth:`~tianshou.policy.BasePolicy.update`: the main interface for training. This function samples data from buffer, pre-process data (such as computing n-step return), learn with the data, and finally post-process the data (can update buffer). +* :meth:`~tianshou.policy.BasePolicy.update`: the main interface for training. This function samples data from buffer, pre-process data (such as computing n-step return), learn with the data, and finally post-process the data (such as updating prioritized replay buffer). Take 2-step return DQN as an example. The 2-step return DQN compute each frame's return as: @@ -127,9 +127,8 @@ Collector --------- The :class:`~tianshou.data.Collector` enables the policy to interact with different types of environments conveniently. -In short, :class:`~tianshou.data.Collector` has one main method: -* :meth:`~tianshou.data.Collector.collect`: let the policy perform (at least) a specified number of step ``n_step`` or episode ``n_episode`` and store the data in the replay buffer; +:class:`~tianshou.data.Collector` has one main method :meth:`~tianshou.data.Collector.collect`: it let the policy perform (at least) a specified number of step ``n_step`` or episode ``n_episode`` and store the data in the replay buffer. Why do we mention **at least** here? For multiple environments, we could not directly store the collected data into the replay buffer, since it breaks the principle of storing data chronologically. @@ -145,8 +144,6 @@ Once you have a collector and a policy, you can start writing the training metho Tianshou has two types of trainer: :func:`~tianshou.trainer.onpolicy_trainer` and :func:`~tianshou.trainer.offpolicy_trainer`, corresponding to on-policy algorithms (such as Policy Gradient) and off-policy algorithms (such as DQN). Please check out :doc:`/api/tianshou.trainer` for the usage. -There will be more types of trainers, for instance, multi-agent trainer. - .. 
_pseudocode: From ce8c916829e21a2b879faafa61f3f194aeb8cac1 Mon Sep 17 00:00:00 2001 From: Trinkle23897 <463003665@qq.com> Date: Tue, 11 Aug 2020 15:51:27 +0800 Subject: [PATCH 06/11] polish --- test/base/test_collector.py | 8 ++++---- tianshou/data/collector.py | 8 ++++---- tianshou/policy/base.py | 5 ++--- tianshou/policy/modelfree/ddpg.py | 2 +- tianshou/policy/modelfree/dqn.py | 2 +- tianshou/policy/modelfree/sac.py | 2 +- tianshou/policy/modelfree/td3.py | 2 +- tianshou/trainer/offpolicy.py | 2 +- tianshou/trainer/onpolicy.py | 2 +- 9 files changed, 16 insertions(+), 17 deletions(-) diff --git a/test/base/test_collector.py b/test/base/test_collector.py index 9eab3744b..38e5d9378 100644 --- a/test/base/test_collector.py +++ b/test/base/test_collector.py @@ -174,7 +174,7 @@ def test_collector_with_dict_state(): c1.seed(0) c1.collect(n_step=10) c1.collect(n_episode=[2, 1, 1, 2]) - batch = c1.buffer.sample(10)[0] + batch, _ = c1.buffer.sample(10) print(batch) c0.buffer.update(c1.buffer) assert np.allclose(c0.buffer[:len(c0.buffer)].obs.index, np.expand_dims([ @@ -184,7 +184,7 @@ def test_collector_with_dict_state(): c2 = Collector(policy, envs, ReplayBuffer(size=100, stack_num=4), Logger.single_preprocess_fn) c2.collect(n_episode=[0, 0, 0, 10]) - batch = c2.buffer.sample(10)[0] + batch, _ = c2.buffer.sample(10) print(batch['obs_next']['index']) @@ -209,7 +209,7 @@ def reward_metric(x): assert np.asanyarray(r).size == 1 and r == 4. r = c1.collect(n_episode=[2, 1, 1, 2])['rew'] assert np.asanyarray(r).size == 1 and r == 4. - batch = c1.buffer.sample(10)[0] + batch, _ = c1.buffer.sample(10) print(batch) c0.buffer.update(c1.buffer) obs = np.array(np.expand_dims([ @@ -226,7 +226,7 @@ def reward_metric(x): Logger.single_preprocess_fn, reward_metric=reward_metric) r = c2.collect(n_episode=[0, 0, 0, 10])['rew'] assert np.asanyarray(r).size == 1 and r == 4. - batch = c2.buffer.sample(10)[0] + batch, _ = c2.buffer.sample(10) print(batch['obs_next']) diff --git a/tianshou/data/collector.py b/tianshou/data/collector.py index e31d4a2d7..0c4792727 100644 --- a/tianshou/data/collector.py +++ b/tianshou/data/collector.py @@ -347,8 +347,8 @@ def collect(self, def sample(self, batch_size: int) -> Batch: """Sample a data batch from the internal replay buffer. It will call - :meth:`~tianshou.policy.BasePolicy.process_fn` before returning - the final batch data. + :meth:`~tianshou.policy.BasePolicy.process_fn` before returning the + final batch data. :param int batch_size: ``0`` means it will extract all the data from the buffer, otherwise it will extract the data with the given @@ -357,8 +357,8 @@ def sample(self, batch_size: int) -> Batch: import warnings warnings.warn( 'Collector.sample is deprecated and will cause error if you use ' - 'prioritized experience replay! Collector.sample will be removed' - ' after 0.3. Use policy.update instead!', DeprecationWarning) + 'prioritized experience replay! Collector.sample will be removed ' + 'after version 0.3. 
Use policy.update instead!', Warning) batch_data, indice = self.buffer.sample(batch_size) batch_data = self.process_fn(batch_data, self.buffer, indice) return batch_data diff --git a/tianshou/policy/base.py b/tianshou/policy/base.py index a73e833da..8ba133c7a 100644 --- a/tianshou/policy/base.py +++ b/tianshou/policy/base.py @@ -94,8 +94,7 @@ def forward(self, batch: Batch, # some code return Batch(..., policy=Batch(log_prob=dist.log_prob(act))) # and in the sampled data batch, you can directly call - # batch.policy.log_prob to get your data, although it is stored in - # np.ndarray. + # batch.policy.log_prob to get your data. """ pass @@ -230,7 +229,7 @@ def post_process_fn(self, batch: Batch, if isinstance(buffer, PrioritizedReplayBuffer): buffer.update_weight(indice, batch.weight) - def update(self, buffer: ReplayBuffer, sample_size: int, *args, **kwargs): + def update(self, sample_size: int, buffer: ReplayBuffer, *args, **kwargs): batch, indice = buffer.sample(sample_size) batch = self.process_fn(batch, buffer, indice) result = self.learn(batch, *args, **kwargs) diff --git a/tianshou/policy/modelfree/ddpg.py b/tianshou/policy/modelfree/ddpg.py index b6787652f..79a65d3bb 100644 --- a/tianshou/policy/modelfree/ddpg.py +++ b/tianshou/policy/modelfree/ddpg.py @@ -145,7 +145,7 @@ def learn(self, batch: Batch, **kwargs) -> Dict[str, float]: target_q = batch.returns.flatten() td = current_q - target_q critic_loss = (td.pow(2) * batch.weight).mean() - batch.weight = td + batch.weight = td # prio-buffer self.critic_optim.zero_grad() critic_loss.backward() self.critic_optim.step() diff --git a/tianshou/policy/modelfree/dqn.py b/tianshou/policy/modelfree/dqn.py index 8b8fef1ce..f1a01a6e7 100644 --- a/tianshou/policy/modelfree/dqn.py +++ b/tianshou/policy/modelfree/dqn.py @@ -161,7 +161,7 @@ def learn(self, batch: Batch, **kwargs) -> Dict[str, float]: r = to_torch_as(batch.returns, q).flatten() td = r - q loss = (td.pow(2) * batch.weight).mean() - batch.weight = td + batch.weight = td # prio-buffer loss.backward() self.optim.step() self._cnt += 1 diff --git a/tianshou/policy/modelfree/sac.py b/tianshou/policy/modelfree/sac.py index ce5ccbd1d..341fe7b11 100644 --- a/tianshou/policy/modelfree/sac.py +++ b/tianshou/policy/modelfree/sac.py @@ -154,7 +154,7 @@ def learn(self, batch: Batch, **kwargs) -> Dict[str, float]: self.critic2_optim.zero_grad() critic2_loss.backward() self.critic2_optim.step() - batch.weight = (td1 + td2) / 2. + batch.weight = (td1 + td2) / 2. # prio-buffer # actor obs_result = self(batch, explorating=False) a = obs_result.act diff --git a/tianshou/policy/modelfree/td3.py b/tianshou/policy/modelfree/td3.py index 1e33534ad..9a340950b 100644 --- a/tianshou/policy/modelfree/td3.py +++ b/tianshou/policy/modelfree/td3.py @@ -132,7 +132,7 @@ def learn(self, batch: Batch, **kwargs) -> Dict[str, float]: self.critic2_optim.zero_grad() critic2_loss.backward() self.critic2_optim.step() - batch.weight = (td1 + td2) / 2. + batch.weight = (td1 + td2) / 2. 
# prio-buffer if self._cnt % self._freq == 0: actor_loss = -self.critic1( batch.obs, self(batch, eps=0).act).mean() diff --git a/tianshou/trainer/offpolicy.py b/tianshou/trainer/offpolicy.py index 2321ec219..c11d08d6a 100644 --- a/tianshou/trainer/offpolicy.py +++ b/tianshou/trainer/offpolicy.py @@ -103,7 +103,7 @@ def offpolicy_trainer( for i in range(update_per_step * min( result['n/st'] // collect_per_step, t.total - t.n)): global_step += 1 - losses = policy.update(train_collector.buffer, batch_size) + losses = policy.update(batch_size, train_collector.buffer) for k in result.keys(): data[k] = f'{result[k]:.2f}' if writer and global_step % log_interval == 0: diff --git a/tianshou/trainer/onpolicy.py b/tianshou/trainer/onpolicy.py index 574a197f5..1c1e6a289 100644 --- a/tianshou/trainer/onpolicy.py +++ b/tianshou/trainer/onpolicy.py @@ -102,7 +102,7 @@ def onpolicy_trainer( if train_fn: train_fn(epoch) losses = policy.update( - train_collector.buffer, 0, batch_size, repeat_per_collect) + 0, train_collector.buffer, batch_size, repeat_per_collect) train_collector.reset_buffer() step = 1 for k in losses.keys(): From 6230fe19409701107055c7c782f5a6af46e069c6 Mon Sep 17 00:00:00 2001 From: Trinkle23897 <463003665@qq.com> Date: Tue, 11 Aug 2020 16:42:33 +0800 Subject: [PATCH 07/11] polish policy --- tianshou/policy/base.py | 5 ++++ tianshou/policy/modelfree/a2c.py | 6 ++--- tianshou/policy/modelfree/pg.py | 5 +--- tianshou/policy/modelfree/ppo.py | 43 +++++++++++++------------------- 4 files changed, 26 insertions(+), 33 deletions(-) diff --git a/tianshou/policy/base.py b/tianshou/policy/base.py index 8ba133c7a..25587f0f6 100644 --- a/tianshou/policy/base.py +++ b/tianshou/policy/base.py @@ -122,6 +122,7 @@ def compute_episodic_return( v_s_: Optional[Union[np.ndarray, torch.Tensor]] = None, gamma: float = 0.99, gae_lambda: float = 0.95, + rew_norm: bool = False, ) -> Batch: """Compute returns over given full-length episodes, including the implementation of Generalized Advantage Estimator (arXiv:1506.02438). @@ -135,6 +136,8 @@ def compute_episodic_return( to 0.99. :param float gae_lambda: the parameter for Generalized Advantage Estimation, should be in [0, 1], defaults to 0.95. + :param bool rew_norm: normalize the reward to Normal(0, 1), defaults + to ``False``. :return: a Batch. The result will be stored in batch.returns as a numpy array with shape (bsz, ). 
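The ``rew_norm`` flag documented above is applied in the return computation touched by the next hunk. For reference, a self-contained sketch of the same idea, GAE(lambda) returns with a std-guarded normalization, is given below. This is generic numpy code, not the library routine: it takes the state values ``v_s`` explicitly, whereas ``compute_episodic_return`` receives only ``v_s_``.

```python
import numpy as np


def gae_return(rew, done, v_s, v_next, gamma=0.99, gae_lambda=0.95,
               rew_norm=False):
    """Generic GAE(lambda) returns for a trajectory slice.

    rew, done, v_s, v_next are 1-D arrays of equal length; v_s[i] is the
    critic value of the i-th observation and v_next[i] the value of its
    successor (taken as 0 here when the step terminates an episode).
    """
    rew, done, v_s, v_next = map(np.asarray, (rew, done, v_s, v_next))
    not_done = 1.0 - done.astype(float)
    delta = rew + gamma * v_next * not_done - v_s      # TD residuals
    adv = np.zeros_like(delta)
    gae = 0.0
    for i in range(len(rew) - 1, -1, -1):              # backward recursion
        gae = delta[i] + gamma * gae_lambda * not_done[i] * gae
        adv[i] = gae
    returns = adv + v_s
    # the newly added ``rew_norm`` switch: standardize the returns, but skip
    # the division when they are (nearly) constant
    if rew_norm and not np.isclose(returns.std(), 0):
        returns = (returns - returns.mean()) / returns.std()
    return returns


print(gae_return([1., 1., 1.], [0, 0, 1], [.5, .4, .3], [.4, .3, 0.]))
```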
@@ -149,6 +152,8 @@ def compute_episodic_return( for i in range(len(rew) - 1, -1, -1): gae = delta[i] + m[i] * gae returns[i] += gae + if rew_norm and not np.isclose(returns.std(), 0): + returns = (returns - returns.mean()) / returns.std() batch.returns = returns return batch diff --git a/tianshou/policy/modelfree/a2c.py b/tianshou/policy/modelfree/a2c.py index 2a5c123c0..52d8dd248 100644 --- a/tianshou/policy/modelfree/a2c.py +++ b/tianshou/policy/modelfree/a2c.py @@ -67,7 +67,8 @@ def process_fn(self, batch: Batch, buffer: ReplayBuffer, v_.append(to_numpy(self.critic(b.obs_next))) v_ = np.concatenate(v_, axis=0) return self.compute_episodic_return( - batch, v_, gamma=self._gamma, gae_lambda=self._lambda) + batch, v_, gamma=self._gamma, gae_lambda=self._lambda, + rew_norm=self._rew_norm) def forward(self, batch: Batch, state: Optional[Union[dict, Batch, np.ndarray]] = None, @@ -97,9 +98,6 @@ def forward(self, batch: Batch, def learn(self, batch: Batch, batch_size: int, repeat: int, **kwargs) -> Dict[str, List[float]]: self._batch = batch_size - r = batch.returns - if self._rew_norm and not np.isclose(r.std(), 0): - batch.returns = (r - r.mean()) / r.std() losses, actor_losses, vf_losses, ent_losses = [], [], [], [] for _ in range(repeat): for b in batch.split(batch_size): diff --git a/tianshou/policy/modelfree/pg.py b/tianshou/policy/modelfree/pg.py index d6176e68d..8fded95ec 100644 --- a/tianshou/policy/modelfree/pg.py +++ b/tianshou/policy/modelfree/pg.py @@ -51,7 +51,7 @@ def process_fn(self, batch: Batch, buffer: ReplayBuffer, # batch.returns = self._vectorized_returns(batch) # return batch return self.compute_episodic_return( - batch, gamma=self._gamma, gae_lambda=1.) + batch, gamma=self._gamma, gae_lambda=1., rew_norm=self._rew_norm) def forward(self, batch: Batch, state: Optional[Union[dict, Batch, np.ndarray]] = None, @@ -81,9 +81,6 @@ def forward(self, batch: Batch, def learn(self, batch: Batch, batch_size: int, repeat: int, **kwargs) -> Dict[str, List[float]]: losses = [] - r = batch.returns - if self._rew_norm and not np.isclose(r.std(), 0): - batch.returns = (r - r.mean()) / r.std() for _ in range(repeat): for b in batch.split(batch_size): self.optim.zero_grad() diff --git a/tianshou/policy/modelfree/ppo.py b/tianshou/policy/modelfree/ppo.py index 2d1decea7..b4c5b79b8 100644 --- a/tianshou/policy/modelfree/ppo.py +++ b/tianshou/policy/modelfree/ppo.py @@ -81,16 +81,29 @@ def process_fn(self, batch: Batch, buffer: ReplayBuffer, mean, std = batch.rew.mean(), batch.rew.std() if not np.isclose(std, 0): batch.rew = (batch.rew - mean) / std - if self._lambda in [0, 1]: - return self.compute_episodic_return( - batch, None, gamma=self._gamma, gae_lambda=self._lambda) v_ = [] + v = [] + old_log_prob = [] with torch.no_grad(): for b in batch.split(self._batch, shuffle=False): v_.append(self.critic(b.obs_next)) + v.append(self.critic(b.obs)) + old_log_prob.append(self(b).dist.log_prob( + to_torch_as(b.act, v[0]))) v_ = to_numpy(torch.cat(v_, dim=0)) - return self.compute_episodic_return( - batch, v_, gamma=self._gamma, gae_lambda=self._lambda) + batch = self.compute_episodic_return( + batch, v_, gamma=self._gamma, gae_lambda=self._lambda, + rew_norm=self._rew_norm) + batch.v = torch.cat(v, dim=0).flatten() # old value + batch.act = to_torch_as(batch.act, v[0]) + batch.logp_old = torch.cat(old_log_prob, dim=0) + batch.returns = to_torch_as(batch.returns, v[0]) + batch.adv = batch.returns - batch.v + if self._rew_norm: + mean, std = batch.adv.mean(), batch.adv.std() + if not 
np.isclose(std.item(), 0): + batch.adv = (batch.adv - mean) / std + return batch def forward(self, batch: Batch, state: Optional[Union[dict, Batch, np.ndarray]] = None, @@ -123,26 +136,6 @@ def learn(self, batch: Batch, batch_size: int, repeat: int, **kwargs) -> Dict[str, List[float]]: self._batch = batch_size losses, clip_losses, vf_losses, ent_losses = [], [], [], [] - v = [] - old_log_prob = [] - with torch.no_grad(): - for b in batch.split(batch_size, shuffle=False): - v.append(self.critic(b.obs)) - old_log_prob.append(self(b).dist.log_prob( - to_torch_as(b.act, v[0]))) - batch.v = torch.cat(v, dim=0).flatten() # old value - batch.act = to_torch_as(batch.act, v[0]) - batch.logp_old = torch.cat(old_log_prob, dim=0) - batch.returns = to_torch_as(batch.returns, v[0]) - if self._rew_norm: - mean, std = batch.returns.mean(), batch.returns.std() - if not np.isclose(std.item(), 0): - batch.returns = (batch.returns - mean) / std - batch.adv = batch.returns - batch.v - if self._rew_norm: - mean, std = batch.adv.mean(), batch.adv.std() - if not np.isclose(std.item(), 0): - batch.adv = (batch.adv - mean) / std for _ in range(repeat): for b in batch.split(batch_size): dist = self(b).dist From e8d0ed90095ca3d1ba8088f4a544703fe55e8b80 Mon Sep 17 00:00:00 2001 From: Trinkle23897 <463003665@qq.com> Date: Tue, 11 Aug 2020 17:04:07 +0800 Subject: [PATCH 08/11] remove collector.sample in docs --- README.md | 2 +- docs/tutorials/concepts.rst | 3 ++- docs/tutorials/dqn.rst | 4 ++-- tianshou/policy/modelfree/ppo.py | 4 +--- 4 files changed, 6 insertions(+), 7 deletions(-) diff --git a/README.md b/README.md index d8c6a3543..8aba0f132 100644 --- a/README.md +++ b/README.md @@ -167,7 +167,7 @@ result = collector.collect(n_episode=[1, 0, 3]) If you want to train the given policy with a sampled batch: ```python -result = policy.learn(collector.sample(batch_size)) +result = policy.update(batch_size, collector.buffer) ``` You can check out the [documentation](https://tianshou.readthedocs.io) for further usage. diff --git a/docs/tutorials/concepts.rst b/docs/tutorials/concepts.rst index eaa0bbca9..8d202d792 100644 --- a/docs/tutorials/concepts.rst +++ b/docs/tutorials/concepts.rst @@ -163,7 +163,8 @@ We give a high-level explanation through the pseudocode used in section :ref:`po buffer.store(s, a, s_, r, d) # collector.collect(...) s = s_ # collector.collect(...) if i % 1000 == 0: # done in trainer - b_s, b_a, b_s_, b_r, b_d = buffer.get(size=64) # collector.sample(batch_size) + # the following is done in policy.update(batch_size, buffer) + b_s, b_a, b_s_, b_r, b_d = buffer.get(size=64) # buffer.sample(batch_size) # compute 2-step returns. How? b_ret = compute_2_step_return(buffer, b_r, b_d, ...) # policy.process_fn(batch, buffer, indice) # update DQN policy diff --git a/docs/tutorials/dqn.rst b/docs/tutorials/dqn.rst index e5edbfd41..764bb4591 100644 --- a/docs/tutorials/dqn.rst +++ b/docs/tutorials/dqn.rst @@ -210,8 +210,8 @@ Tianshou supports user-defined training code. Here is the code snippet: # back to training eps policy.set_eps(0.1) - # train policy with a sampled batch data - losses = policy.learn(train_collector.sample(batch_size=64)) + # train policy with a sampled batch data from buffer + losses = policy.update(64, train_collector.buffer) For further usage, you can refer to the :doc:`/tutorials/cheatsheet`. 
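For reference, the hand-written loop from the tutorial above can be fleshed out roughly as follows. This is only a sketch: ``policy`` and ``train_collector`` are assumed to be constructed as earlier in the DQN tutorial, and the loop bounds and batch size are placeholders rather than recommended settings.

```python
# Rough sketch of a hand-written training loop on top of the new interface.
# It assumes `policy` (e.g. a DQNPolicy) and `train_collector` have already
# been constructed as earlier in the DQN tutorial; the epoch/step/batch
# numbers below are placeholders, not recommended settings.
for epoch in range(10):
    for step in range(1000):
        policy.set_eps(0.1)                    # exploration noise while collecting
        train_collector.collect(n_step=10)     # push new transitions into the buffer
        # one update from a 64-sample batch; internally this runs
        # process_fn -> learn -> post_process_fn on the sampled data
        losses = policy.update(64, train_collector.buffer)
```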
diff --git a/tianshou/policy/modelfree/ppo.py b/tianshou/policy/modelfree/ppo.py index b4c5b79b8..6858f3054 100644 --- a/tianshou/policy/modelfree/ppo.py +++ b/tianshou/policy/modelfree/ppo.py @@ -81,9 +81,7 @@ def process_fn(self, batch: Batch, buffer: ReplayBuffer, mean, std = batch.rew.mean(), batch.rew.std() if not np.isclose(std, 0): batch.rew = (batch.rew - mean) / std - v_ = [] - v = [] - old_log_prob = [] + v, v_, old_log_prob = [], [], [] with torch.no_grad(): for b in batch.split(self._batch, shuffle=False): v_.append(self.critic(b.obs_next)) From 2a7d296bf366b37f3778fdc0876ee0d6e6579791 Mon Sep 17 00:00:00 2001 From: Trinkle23897 <463003665@qq.com> Date: Tue, 11 Aug 2020 22:02:26 +0800 Subject: [PATCH 09/11] minor fix --- tianshou/data/collector.py | 2 +- tianshou/policy/base.py | 4 ++-- tianshou/policy/modelfree/ppo.py | 4 ++-- tianshou/trainer/offpolicy.py | 8 +++++--- tianshou/trainer/onpolicy.py | 3 ++- 5 files changed, 12 insertions(+), 9 deletions(-) diff --git a/tianshou/data/collector.py b/tianshou/data/collector.py index 0c4792727..8d2863c1a 100644 --- a/tianshou/data/collector.py +++ b/tianshou/data/collector.py @@ -358,7 +358,7 @@ def sample(self, batch_size: int) -> Batch: warnings.warn( 'Collector.sample is deprecated and will cause error if you use ' 'prioritized experience replay! Collector.sample will be removed ' - 'after version 0.3. Use policy.update instead!', Warning) + 'upon version 0.3. Use policy.update instead!', Warning) batch_data, indice = self.buffer.sample(batch_size) batch_data = self.process_fn(batch_data, self.buffer, indice) return batch_data diff --git a/tianshou/policy/base.py b/tianshou/policy/base.py index 25587f0f6..a87eee8ab 100644 --- a/tianshou/policy/base.py +++ b/tianshou/policy/base.py @@ -152,7 +152,7 @@ def compute_episodic_return( for i in range(len(rew) - 1, -1, -1): gae = delta[i] + m[i] * gae returns[i] += gae - if rew_norm and not np.isclose(returns.std(), 0): + if rew_norm and not np.isclose(returns.std(), 0, 1e-2): returns = (returns - returns.mean()) / returns.std() batch.returns = returns return batch @@ -200,7 +200,7 @@ def compute_nstep_return( if rew_norm: bfr = rew[:min(len(buffer), 1000)] # avoid large buffer mean, std = bfr.mean(), bfr.std() - if np.isclose(std, 0): + if np.isclose(std, 0, 1e-2): mean, std = 0, 1 else: mean, std = 0, 1 diff --git a/tianshou/policy/modelfree/ppo.py b/tianshou/policy/modelfree/ppo.py index 6858f3054..3094be82e 100644 --- a/tianshou/policy/modelfree/ppo.py +++ b/tianshou/policy/modelfree/ppo.py @@ -79,7 +79,7 @@ def process_fn(self, batch: Batch, buffer: ReplayBuffer, indice: np.ndarray) -> Batch: if self._rew_norm: mean, std = batch.rew.mean(), batch.rew.std() - if not np.isclose(std, 0): + if not np.isclose(std, 0, 1e-2): batch.rew = (batch.rew - mean) / std v, v_, old_log_prob = [], [], [] with torch.no_grad(): @@ -99,7 +99,7 @@ def process_fn(self, batch: Batch, buffer: ReplayBuffer, batch.adv = batch.returns - batch.v if self._rew_norm: mean, std = batch.adv.mean(), batch.adv.std() - if not np.isclose(std.item(), 0): + if not np.isclose(std.item(), 0, 1e-2): batch.adv = (batch.adv - mean) / std return batch diff --git a/tianshou/trainer/offpolicy.py b/tianshou/trainer/offpolicy.py index c11d08d6a..171cbb9da 100644 --- a/tianshou/trainer/offpolicy.py +++ b/tianshou/trainer/offpolicy.py @@ -28,7 +28,8 @@ def offpolicy_trainer( verbose: bool = True, test_in_train: bool = True, ) -> Dict[str, Union[float, str]]: - """A wrapper for off-policy trainer procedure. 
+ """A wrapper for off-policy trainer procedure. The ``step`` in trainer + means a policy network update. :param policy: an instance of the :class:`~tianshou.policy.BasePolicy` class. @@ -47,8 +48,9 @@ def offpolicy_trainer( :param int batch_size: the batch size of sample data, which is going to feed in the policy network. :param int update_per_step: the number of times the policy network would - be updated after frames be collected. In other words, collect some - frames and do some policy network update. + be updated after frames are collected, for example, set it to 256 means + it updates policy 256 times once after ``collect_per_step`` frames are + collected. :param function train_fn: a function receives the current number of epoch index and performs some operations at the beginning of training in this epoch. diff --git a/tianshou/trainer/onpolicy.py b/tianshou/trainer/onpolicy.py index 1c1e6a289..e31724d66 100644 --- a/tianshou/trainer/onpolicy.py +++ b/tianshou/trainer/onpolicy.py @@ -28,7 +28,8 @@ def onpolicy_trainer( verbose: bool = True, test_in_train: bool = True, ) -> Dict[str, Union[float, str]]: - """A wrapper for on-policy trainer procedure. + """A wrapper for on-policy trainer procedure. The ``step`` in trainer means + a policy network update. :param policy: an instance of the :class:`~tianshou.policy.BasePolicy` class. From b7b0b3b80208031779fe7f2a355bbce38c444aa0 Mon Sep 17 00:00:00 2001 From: n+e <463003665@qq.com> Date: Tue, 11 Aug 2020 22:28:04 +0800 Subject: [PATCH 10/11] Apply suggestions from code review just a test --- README.md | 2 +- docs/tutorials/concepts.rst | 2 +- tianshou/policy/base.py | 11 +++++++++-- 3 files changed, 11 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index 8aba0f132..2c5db38df 100644 --- a/README.md +++ b/README.md @@ -139,7 +139,7 @@ Check out the [GitHub Actions](https://github.com/thu-ml/tianshou/actions) page ### Modularized Policy -We decouple all of the algorithms into roughly the following parts: +We decouple all of the algorithms roughly into the following parts: - `__init__`: initialize the policy; - `forward`: to compute actions over given observations; diff --git a/docs/tutorials/concepts.rst b/docs/tutorials/concepts.rst index 8d202d792..ba771ad55 100644 --- a/docs/tutorials/concepts.rst +++ b/docs/tutorials/concepts.rst @@ -71,7 +71,7 @@ A policy class typically has the following parts: * :meth:`~tianshou.policy.BasePolicy.process_fn`: pre-process data from the replay buffer; * :meth:`~tianshou.policy.BasePolicy.learn`: update policy with a given batch of data. * :meth:`~tianshou.policy.BasePolicy.post_process_fn`: update the buffer with a given batch of data. -* :meth:`~tianshou.policy.BasePolicy.update`: the main interface for training. This function samples data from buffer, pre-process data (such as computing n-step return), learn with the data, and finally post-process the data (such as updating prioritized replay buffer). +* :meth:`~tianshou.policy.BasePolicy.update`: the main interface for training. This function samples data from buffer, pre-process data (such as computing n-step return), learn with the data, and finally post-process the data (such as updating prioritized replay buffer); in short, ``process_fn -> learn -> post_process_fn``. Take 2-step return DQN as an example. 
The 2-step return DQN compute each frame's return as: diff --git a/tianshou/policy/base.py b/tianshou/policy/base.py index a87eee8ab..7794315a6 100644 --- a/tianshou/policy/base.py +++ b/tianshou/policy/base.py @@ -234,8 +234,15 @@ def post_process_fn(self, batch: Batch, if isinstance(buffer, PrioritizedReplayBuffer): buffer.update_weight(indice, batch.weight) - def update(self, sample_size: int, buffer: ReplayBuffer, *args, **kwargs): - batch, indice = buffer.sample(sample_size) + def update(self, batch_size: int, buffer: ReplayBuffer, *args, **kwargs): + """Update the policy network and replay buffer (if needed). It includes + three function steps: process_fn, learn, and post_process_fn. + + :param int batch_size: 0 means it will extract all the data from the + buffer, otherwise it will sample a batch with the given batch_size. + :param ReplayBuffer buffer: the corresponding replay buffer. + """ + batch, indice = buffer.sample(batch_size) batch = self.process_fn(batch, buffer, indice) result = self.learn(batch, *args, **kwargs) self.post_process_fn(batch, buffer, indice) From 41d3aa0290938bc843dd79de214c88592695fc63 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Wed, 12 Aug 2020 13:20:23 +0800 Subject: [PATCH 11/11] doc fix --- tianshou/policy/base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tianshou/policy/base.py b/tianshou/policy/base.py index 7794315a6..a2f545e48 100644 --- a/tianshou/policy/base.py +++ b/tianshou/policy/base.py @@ -93,7 +93,7 @@ def forward(self, batch: Batch, # some code return Batch(..., policy=Batch(log_prob=dist.log_prob(act))) - # and in the sampled data batch, you can directly call + # and in the sampled data batch, you can directly use # batch.policy.log_prob to get your data. """ pass
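Taken together, the series replaces the old ``collector.sample`` + ``batch.update_weight`` path with a single ``policy.update(batch_size, buffer)`` that runs ``process_fn -> learn -> post_process_fn`` and lets ``learn`` hand per-sample priorities back through ``batch.weight``. The round trip can be seen end to end in the sketch below, which uses plain-Python stand-ins rather than the Tianshou classes.

```python
import numpy as np


class TinyPrioritizedBuffer:
    """Stand-in for PrioritizedReplayBuffer: keeps a per-sample priority that
    update_weight() overwrites after every learning step."""

    def __init__(self, size):
        self.rew = np.zeros(size)
        self.prio = np.ones(size)
        self.size = 0

    def add(self, rew):
        self.rew[self.size] = rew
        self.size += 1

    def sample(self, batch_size):
        p = self.prio[:self.size] / self.prio[:self.size].sum()
        indice = np.random.choice(self.size, batch_size, p=p)
        # importance-sampling weights compensate for the biased sampling
        weight = 1.0 / (self.size * p[indice])
        return {'rew': self.rew[indice], 'weight': weight}, indice

    def update_weight(self, indice, new_weight):
        self.prio[indice] = np.abs(new_weight) + 1e-6


class TinyPolicy:
    """Stand-in for BasePolicy showing the contract: learn() consumes
    batch['weight'] in its loss and leaves the new priorities there."""

    def process_fn(self, batch, buffer, indice):
        batch['returns'] = batch['rew']          # pretend 1-step return
        return batch

    def learn(self, batch):
        td = batch['returns']                    # fake TD error (critic == 0)
        loss = float((td ** 2 * batch['weight']).mean())
        batch['weight'] = td                     # hand the priorities back
        return {'loss': loss}

    def post_process_fn(self, batch, buffer, indice):
        buffer.update_weight(indice, batch['weight'])

    def update(self, batch_size, buffer):
        batch, indice = buffer.sample(batch_size)
        batch = self.process_fn(batch, buffer, indice)
        result = self.learn(batch)
        self.post_process_fn(batch, buffer, indice)
        return result


buf, policy = TinyPrioritizedBuffer(16), TinyPolicy()
for r in np.random.randn(16):
    buf.add(float(r))
print(policy.update(4, buf))
```

With the real classes, the trainers now call ``policy.update(batch_size, train_collector.buffer)`` directly, and ``BasePolicy.post_process_fn`` only writes weights back when the buffer is a ``PrioritizedReplayBuffer``.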