From b6bdedcbda1ec8e8b766119f692fd50c882ab68a Mon Sep 17 00:00:00 2001 From: rocknamx8 Date: Mon, 21 Sep 2020 17:01:23 +0800 Subject: [PATCH 01/11] add description of self.learning --- docs/tutorials/concepts.rst | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/tutorials/concepts.rst b/docs/tutorials/concepts.rst index d7f5971d2..a0ca9d27f 100644 --- a/docs/tutorials/concepts.rst +++ b/docs/tutorials/concepts.rst @@ -72,6 +72,7 @@ A policy class typically has the following parts: * :meth:`~tianshou.policy.BasePolicy.process_fn`: pre-process data from the replay buffer; * :meth:`~tianshou.policy.BasePolicy.learn`: update policy with a given batch of data. * :meth:`~tianshou.policy.BasePolicy.post_process_fn`: update the buffer with a given batch of data. +* :meth:`~tianshou.policy.BasePolicy.learning`: indicate the learning state. * :meth:`~tianshou.policy.BasePolicy.update`: the main interface for training. This function samples data from buffer, pre-process data (such as computing n-step return), learn with the data, and finally post-process the data (such as updating prioritized replay buffer); in short, ``process_fn -> learn -> post_process_fn``. From d6886fb80358498fee0e3f308aab74003d34e93e Mon Sep 17 00:00:00 2001 From: rocknamx8 Date: Mon, 21 Sep 2020 17:03:18 +0800 Subject: [PATCH 02/11] add self.learning for indicating learning state. --- tianshou/policy/base.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/tianshou/policy/base.py b/tianshou/policy/base.py index 8bd0fcbe6..852b2f408 100644 --- a/tianshou/policy/base.py +++ b/tianshou/policy/base.py @@ -61,6 +61,7 @@ def __init__( self.observation_space = observation_space self.action_space = action_space self.agent_id = 0 + self.learning = False self._compile() def set_agent_id(self, agent_id: int) -> None: @@ -127,6 +128,14 @@ def learn( "[batch_size]" shape while Normal distribution gives "[batch_size, 1]" shape. The auto-broadcasting of numerical operation with torch tensors will amplify this error. + .. trick:: + In order to distinguish the training state, learning state and + testing state, you can check the policy state by ``self.training`` + and ``self.learning``. The state setting is as follow: + training: ``self.training=True``. + perform ``self.learn()`` during training: ``self.training=True``, + ``self.learning=True``. + testing: ``self.training=False``, ``self.learning=False`` """ pass @@ -149,6 +158,8 @@ def update( """Update the policy network and replay buffer. It includes 3 function steps: process_fn, learn, and post_process_fn. + In addition, ``self.learning`` will be True before ``self.learn()`` + and ``self.learning`` will be False after ``self.learn()``. :param int sample_size: 0 means it will extract all the data from the buffer, otherwise it will sample a batch with given sample_size. 
@@ -158,7 +169,9 @@ def update( return {} batch, indice = buffer.sample(sample_size) batch = self.process_fn(batch, buffer, indice) + self.learning = True result = self.learn(batch, **kwargs) + self.learning = False self.post_process_fn(batch, buffer, indice) return result From 127436904316d8e16024d4c756919a7c9235df03 Mon Sep 17 00:00:00 2001 From: Trinkle23897 <463003665@qq.com> Date: Mon, 21 Sep 2020 17:38:16 +0800 Subject: [PATCH 03/11] polish --- docs/tutorials/concepts.rst | 19 ++++++++++++++++++- tianshou/policy/base.py | 21 +++++++++++---------- 2 files changed, 29 insertions(+), 11 deletions(-) diff --git a/docs/tutorials/concepts.rst b/docs/tutorials/concepts.rst index a0ca9d27f..4618c252d 100644 --- a/docs/tutorials/concepts.rst +++ b/docs/tutorials/concepts.rst @@ -72,7 +72,6 @@ A policy class typically has the following parts: * :meth:`~tianshou.policy.BasePolicy.process_fn`: pre-process data from the replay buffer; * :meth:`~tianshou.policy.BasePolicy.learn`: update policy with a given batch of data. * :meth:`~tianshou.policy.BasePolicy.post_process_fn`: update the buffer with a given batch of data. -* :meth:`~tianshou.policy.BasePolicy.learning`: indicate the learning state. * :meth:`~tianshou.policy.BasePolicy.update`: the main interface for training. This function samples data from buffer, pre-process data (such as computing n-step return), learn with the data, and finally post-process the data (such as updating prioritized replay buffer); in short, ``process_fn -> learn -> post_process_fn``. @@ -98,6 +97,24 @@ For example, if you try to use your policy to evaluate one episode (and don't wa Here, ``Batch(obs=[obs])`` will automatically create the 0-dimension to be the batch-size. Otherwise, the network cannot determine the batch-size. +.. _policy_state: + +Three states for policy +^^^^^^^^^^^^^^^^^^^^^^^ + +TODO: explain what is training state, learning state and testing state. + +We define the learning state as performing ``policy.learn()`` during training. + +In order to distinguish the training state, learning state and testing state, you can check the policy state by ``policy.training`` and ``policy.learning``. The state setting is as follows: + +TODO: use a table instead of itemize + +* Training: ``policy.training=True`` and ``policy.learning=False``; +* Learning state: ``policy.training=True`` and ``policy.learning=True``; +* Testing: ``policy.training=False`` and ``policy.learning=False``. + + .. _process_fn: policy.process_fn diff --git a/tianshou/policy/base.py b/tianshou/policy/base.py index fc4170413..e13eb47e6 100644 --- a/tianshou/policy/base.py +++ b/tianshou/policy/base.py @@ -119,6 +119,13 @@ def learn( :return: A dict which includes loss and its corresponding label. + .. note:: + + In order to distinguish the training state, learning state and + testing state, you can check the policy state by ``self.training`` + and ``self.learning``. Please refer to :ref:`policy_state` for more + detailed explanation. + .. warning:: If you use ``torch.distributions.Normal`` and @@ -127,14 +134,6 @@ def learn( "[batch_size]" shape while Normal distribution gives "[batch_size, 1]" shape. The auto-broadcasting of numerical operation with torch tensors will amplify this error. - .. trick:: - In order to distinguish the training state, learning state and - testing state, you can check the policy state by ``self.training`` - and ``self.learning``. The state setting is as follow: - training: ``self.training=True``. 
- perform ``self.learn()`` during training: ``self.training=True``, - ``self.learning=True``. - testing: ``self.training=False``, ``self.learning=False`` """ pass @@ -155,8 +154,10 @@ def update( """Update the policy network and replay buffer. It includes 3 function steps: process_fn, learn, and post_process_fn. - In addition, ``self.learning`` will be True before ``self.learn()`` - and ``self.learning`` will be False after ``self.learn()``. + In addition, this function will change the value of ``self.learning``: + it will be True before ``self.learn()`` and will be False after + ``self.learn()``. Please refer to :ref:`policy_state` for more detailed + explanation. :param int sample_size: 0 means it will extract all the data from the buffer, otherwise it will sample a batch with given sample_size. From 93d3bbbbbdc011bb4c2011d7d330c3b84207418f Mon Sep 17 00:00:00 2001 From: Trinkle23897 <463003665@qq.com> Date: Mon, 21 Sep 2020 17:43:21 +0800 Subject: [PATCH 04/11] fix a bug in timing --- README.md | 1 + docs/index.rst | 1 + tianshou/data/collector.py | 6 +++++- tianshou/trainer/offpolicy.py | 2 ++ tianshou/trainer/onpolicy.py | 2 ++ 5 files changed, 11 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index f65837e1c..c68fbf4ad 100644 --- a/README.md +++ b/README.md @@ -27,6 +27,7 @@ - [Proximal Policy Optimization (PPO)](https://arxiv.org/pdf/1707.06347.pdf) - [Twin Delayed DDPG (TD3)](https://arxiv.org/pdf/1802.09477.pdf) - [Soft Actor-Critic (SAC)](https://arxiv.org/pdf/1812.05905.pdf) +- [Discrete Soft Actor-Critic (SAC-Discrete)](https://arxiv.org/pdf/1910.07207.pdf) - Vanilla Imitation Learning - [Prioritized Experience Replay (PER)](https://arxiv.org/pdf/1511.05952.pdf) - [Generalized Advantage Estimator (GAE)](https://arxiv.org/pdf/1506.02438.pdf) diff --git a/docs/index.rst b/docs/index.rst index 587dc7e5c..454997ef7 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -18,6 +18,7 @@ Welcome to Tianshou! 
* :class:`~tianshou.policy.PPOPolicy` `Proximal Policy Optimization `_ * :class:`~tianshou.policy.TD3Policy` `Twin Delayed DDPG `_ * :class:`~tianshou.policy.SACPolicy` `Soft Actor-Critic `_ +* :class:`~tianshou.policy.DiscreteSACPolicy` `Discrete Soft Actor-Critic `_ * :class:`~tianshou.policy.ImitationPolicy` Imitation Learning * :class:`~tianshou.data.PrioritizedReplayBuffer` `Prioritized Experience Replay `_ * :meth:`~tianshou.policy.BasePolicy.compute_episodic_return` `Generalized Advantage Estimator `_ diff --git a/tianshou/data/collector.py b/tianshou/data/collector.py index 531553846..0d755fbc2 100644 --- a/tianshou/data/collector.py +++ b/tianshou/data/collector.py @@ -129,10 +129,14 @@ def reset(self) -> None: obs_next={}, policy={}) self.reset_env() self.reset_buffer() - self.collect_time, self.collect_step, self.collect_episode = 0.0, 0, 0 + self.reset_stat() if self._action_noise is not None: self._action_noise.reset() + def reset_stat(self) -> None: + """Reset the statistic variables.""" + self.collect_time, self.collect_step, self.collect_episode = 0.0, 0, 0 + def reset_buffer(self) -> None: """Reset the main data buffer.""" if self.buffer is not None: diff --git a/tianshou/trainer/offpolicy.py b/tianshou/trainer/offpolicy.py index 01e6530bb..170fd6835 100644 --- a/tianshou/trainer/offpolicy.py +++ b/tianshou/trainer/offpolicy.py @@ -75,6 +75,8 @@ def offpolicy_trainer( best_epoch, best_reward = -1, -1.0 stat: Dict[str, MovAvg] = {} start_time = time.time() + train_collector.reset_stat() + test_collector.reset_stat() test_in_train = test_in_train and train_collector.policy == policy for epoch in range(1, 1 + max_epoch): # train diff --git a/tianshou/trainer/onpolicy.py b/tianshou/trainer/onpolicy.py index 37f427826..877c6348c 100644 --- a/tianshou/trainer/onpolicy.py +++ b/tianshou/trainer/onpolicy.py @@ -75,6 +75,8 @@ def onpolicy_trainer( best_epoch, best_reward = -1, -1.0 stat: Dict[str, MovAvg] = {} start_time = time.time() + train_collector.reset_stat() + test_collector.reset_stat() test_in_train = test_in_train and train_collector.policy == policy for epoch in range(1, 1 + max_epoch): # train From 958ec270fb23dd0d2dddefe14344405e7cef83d9 Mon Sep 17 00:00:00 2001 From: Trinkle23897 <463003665@qq.com> Date: Mon, 21 Sep 2020 17:49:08 +0800 Subject: [PATCH 05/11] add table --- docs/tutorials/concepts.rst | 40 ++++++++++++++++++++----------------- 1 file changed, 22 insertions(+), 18 deletions(-) diff --git a/docs/tutorials/concepts.rst b/docs/tutorials/concepts.rst index 4618c252d..db8be5960 100644 --- a/docs/tutorials/concepts.rst +++ b/docs/tutorials/concepts.rst @@ -75,6 +75,28 @@ A policy class typically has the following parts: * :meth:`~tianshou.policy.BasePolicy.update`: the main interface for training. This function samples data from buffer, pre-process data (such as computing n-step return), learn with the data, and finally post-process the data (such as updating prioritized replay buffer); in short, ``process_fn -> learn -> post_process_fn``. +.. _policy_state: + +Three states for policy +^^^^^^^^^^^^^^^^^^^^^^^ + +TODO: explain what is training state, learning state and testing state. + +We define the learning state as performing ``policy.learn()`` during training. + +In order to distinguish the training state, learning state and testing state, you can check the policy state by ``policy.training`` and ``policy.learning``. 
The state setting is as follows: + ++------------------+-----------------+-----------------+ +| State for policy | policy.learning | policy.training | ++==================+=================+=================+ +| Training state | False | True | ++------------------+-----------------+-----------------+ +| Learning state | True | True | ++------------------+-----------------+-----------------+ +| Testing state | False | False | ++------------------+-----------------+-----------------+ + + policy.forward ^^^^^^^^^^^^^^ @@ -97,24 +119,6 @@ For example, if you try to use your policy to evaluate one episode (and don't wa Here, ``Batch(obs=[obs])`` will automatically create the 0-dimension to be the batch-size. Otherwise, the network cannot determine the batch-size. -.. _policy_state: - -Three states for policy -^^^^^^^^^^^^^^^^^^^^^^^ - -TODO: explain what is training state, learning state and testing state. - -We define the learning state as performing ``policy.learn()`` during training. - -In order to distinguish the training state, learning state and testing state, you can check the policy state by ``policy.training`` and ``policy.learning``. The state setting is as follows: - -TODO: use a table instead of itemize - -* Training: ``policy.training=True`` and ``policy.learning=False``; -* Learning state: ``policy.training=True`` and ``policy.learning=True``; -* Testing: ``policy.training=False`` and ``policy.learning=False``. - - .. _process_fn: policy.process_fn From dd9a3d630147e2a93aad1fc623280e6b51f85c20 Mon Sep 17 00:00:00 2001 From: Trinkle23897 <463003665@qq.com> Date: Mon, 21 Sep 2020 18:27:47 +0800 Subject: [PATCH 06/11] remove exploration argument in forward --- tianshou/policy/base.py | 4 ++-- tianshou/policy/modelfree/ddpg.py | 11 +++++------ tianshou/policy/modelfree/sac.py | 7 +++---- 3 files changed, 10 insertions(+), 12 deletions(-) diff --git a/tianshou/policy/base.py b/tianshou/policy/base.py index e13eb47e6..eea6329db 100644 --- a/tianshou/policy/base.py +++ b/tianshou/policy/base.py @@ -166,11 +166,11 @@ def update( if buffer is None: return {} batch, indice = buffer.sample(sample_size) - batch = self.process_fn(batch, buffer, indice) self.learning = True + batch = self.process_fn(batch, buffer, indice) result = self.learn(batch, **kwargs) - self.learning = False self.post_process_fn(batch, buffer, indice) + self.learning = False return result @staticmethod diff --git a/tianshou/policy/modelfree/ddpg.py b/tianshou/policy/modelfree/ddpg.py index 81bf7f6d7..71112924b 100644 --- a/tianshou/policy/modelfree/ddpg.py +++ b/tianshou/policy/modelfree/ddpg.py @@ -103,9 +103,9 @@ def _target_q( ) -> torch.Tensor: batch = buffer[indice] # batch.obs_next: s_{t+n} with torch.no_grad(): - target_q = self.critic_old(batch.obs_next, self( - batch, model='actor_old', input='obs_next', - explorating=False).act) + target_q = self.critic_old( + batch.obs_next, + self(batch, model='actor_old', input='obs_next').act) return target_q def process_fn( @@ -124,7 +124,6 @@ def forward( state: Optional[Union[dict, Batch, np.ndarray]] = None, model: str = "actor", input: str = "obs", - explorating: bool = True, **kwargs: Any, ) -> Batch: """Compute action over the given batch data. 
@@ -143,7 +142,7 @@ def forward( obs = batch[input] actions, h = model(obs, state=state, info=batch.info) actions += self._action_bias - if self._noise and self.training and explorating: + if self._noise and not self.learning: actions += to_torch_as(self._noise(actions.shape), actions) actions = actions.clamp(self._range[0], self._range[1]) return Batch(act=actions, state=h) @@ -158,7 +157,7 @@ def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]: self.critic_optim.zero_grad() critic_loss.backward() self.critic_optim.step() - action = self(batch, explorating=False).act + action = self(batch).act actor_loss = -self.critic(batch.obs, action).mean() self.actor_optim.zero_grad() actor_loss.backward() diff --git a/tianshou/policy/modelfree/sac.py b/tianshou/policy/modelfree/sac.py index e44f8a124..ec7f41015 100644 --- a/tianshou/policy/modelfree/sac.py +++ b/tianshou/policy/modelfree/sac.py @@ -110,7 +110,6 @@ def forward( # type: ignore batch: Batch, state: Optional[Union[dict, Batch, np.ndarray]] = None, input: str = "obs", - explorating: bool = True, **kwargs: Any, ) -> Batch: obs = batch[input] @@ -123,7 +122,7 @@ def forward( # type: ignore y = self._action_scale * (1 - y.pow(2)) + self.__eps log_prob = dist.log_prob(x).unsqueeze(-1) log_prob = log_prob - torch.log(y).sum(-1, keepdim=True) - if self._noise is not None and self.training and explorating: + if self._noise is not None and not self.learning: act += to_torch_as(self._noise(act.shape), act) act = act.clamp(self._range[0], self._range[1]) return Batch( @@ -134,7 +133,7 @@ def _target_q( ) -> torch.Tensor: batch = buffer[indice] # batch.obs: s_{t+n} with torch.no_grad(): - obs_next_result = self(batch, input='obs_next', explorating=False) + obs_next_result = self(batch, input='obs_next') a_ = obs_next_result.act batch.act = to_torch_as(batch.act, a_) target_q = torch.min( @@ -167,7 +166,7 @@ def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]: batch.weight = (td1 + td2) / 2.0 # prio-buffer # actor - obs_result = self(batch, explorating=False) + obs_result = self(batch) a = obs_result.act current_q1a = self.critic1(batch.obs, a).flatten() current_q2a = self.critic2(batch.obs, a).flatten() From cfe5bd6d56cbba955910166ac2adbad9f807411a Mon Sep 17 00:00:00 2001 From: Trinkle23897 <463003665@qq.com> Date: Mon, 21 Sep 2020 18:41:00 +0800 Subject: [PATCH 07/11] first try dqn --- tianshou/policy/modelfree/dqn.py | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) diff --git a/tianshou/policy/modelfree/dqn.py b/tianshou/policy/modelfree/dqn.py index 71d16f6b6..901480868 100644 --- a/tianshou/policy/modelfree/dqn.py +++ b/tianshou/policy/modelfree/dqn.py @@ -80,7 +80,7 @@ def _target_q( batch = buffer[indice] # batch.obs_next: s_{t+n} if self._target: # target_Q = Q_old(s_, argmax(Q_new(s_, *))) - a = self(batch, input="obs_next", eps=0).act + a = self(batch, input="obs_next").act with torch.no_grad(): target_q = self( batch, model="model_old", input="obs_next" @@ -110,7 +110,6 @@ def forward( state: Optional[Union[dict, Batch, np.ndarray]] = None, model: str = "model", input: str = "obs", - eps: Optional[float] = None, **kwargs: Any, ) -> Batch: """Compute action over the given batch data. 
@@ -152,12 +151,10 @@ def forward( q_: np.ndarray = to_numpy(q) q_[~obs.mask] = -np.inf act = q_.argmax(axis=1) - # add eps to act - if eps is None: - eps = self.eps - if not np.isclose(eps, 0.0): + # add eps to act in training or testing phase + if not self.learning and not np.isclose(self.eps, 0.0): for i in range(len(q)): - if np.random.rand() < eps: + if np.random.rand() < self.eps: q_ = np.random.rand(*q[i].shape) if hasattr(obs, "mask"): q_[~obs.mask[i]] = -np.inf @@ -169,7 +166,7 @@ def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]: self.sync_weight() self.optim.zero_grad() weight = batch.pop("weight", 1.0) - q = self(batch, eps=0.0).logits + q = self(batch).logits q = q[np.arange(len(q)), batch.act] r = to_torch_as(batch.returns.flatten(), q) td = r - q From aa546c363112ec418c9d65e8f5a609fc5e61e8c5 Mon Sep 17 00:00:00 2001 From: Trinkle23897 <463003665@qq.com> Date: Mon, 21 Sep 2020 20:21:31 +0800 Subject: [PATCH 08/11] fix doc --- docs/tutorials/concepts.rst | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/docs/tutorials/concepts.rst b/docs/tutorials/concepts.rst index db8be5960..7136d1292 100644 --- a/docs/tutorials/concepts.rst +++ b/docs/tutorials/concepts.rst @@ -80,9 +80,10 @@ A policy class typically has the following parts: Three states for policy ^^^^^^^^^^^^^^^^^^^^^^^ -TODO: explain what is training state, learning state and testing state. - -We define the learning state as performing ``policy.learn()`` during training. +During the training process, the policy has three state, namely training state, learning state, and testing state: +Training state is defined as interacting with environments and collecting training data into the buffer; +we define the learning state as performing a model update (such as ``policy.learn()``) during training process; +and the testing state is obvious: evaluate the performance of the current policy during training process. In order to distinguish the training state, learning state and testing state, you can check the policy state by ``policy.training`` and ``policy.learning``. The state setting is as follows: @@ -96,6 +97,8 @@ In order to distinguish the training state, learning state and testing state, yo | Testing state | False | False | +------------------+-----------------+-----------------+ +``policy.learning`` is helpful to distinguish the different exploration state, for example, in DQN we don't have to use epsilon-greedy in a pure network update, so ``policy.learning`` is helpful for setting epsilon in this case. 
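To make this pattern concrete, here is a minimal sketch (an illustration under stated assumptions, not code from this patch): a toy ``BasePolicy`` subclass whose ``forward`` applies epsilon-greedy exploration only while ``self.learning`` is False. The class name ``EpsGreedyPolicy``, its ``eps`` attribute, the random stand-in Q-values and the fixed four-action space are all assumptions made for illustration; later commits in this series rename the flag to ``self.updating``, and the same check applies there. ::

    from typing import Any, Dict

    import numpy as np

    from tianshou.data import Batch
    from tianshou.policy import BasePolicy


    class EpsGreedyPolicy(BasePolicy):
        """Toy sketch: epsilon-greedy exploration only outside of learn()."""

        def __init__(self, eps: float = 0.1, **kwargs: Any) -> None:
            super().__init__(**kwargs)
            self.eps = eps  # illustrative attribute, mirrors DQNPolicy's eps

        def forward(self, batch: Batch, state=None, **kwargs: Any) -> Batch:
            # stand-in for a real Q-network: random values over 4 actions
            q = np.random.rand(len(batch.obs), 4)
            act = q.argmax(axis=1)
            # explore while collecting/testing, never inside a pure update
            if not self.learning and not np.isclose(self.eps, 0.0):
                mask = np.random.rand(len(act)) < self.eps
                act[mask] = np.random.randint(q.shape[1], size=mask.sum())
            return Batch(logits=q, act=act, state=state)

        def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]:
            # a real policy would compute a loss and step an optimizer here
            return {"loss": 0.0}

Because :meth:`~tianshou.policy.BasePolicy.update` sets the flag to True around ``process_fn -> learn -> post_process_fn``, any action recomputed inside ``learn()`` (as DQN does) automatically skips the exploration branch.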
+ policy.forward ^^^^^^^^^^^^^^ From f2d1bccb0666f78a941cee62e5ea1840b45f627b Mon Sep 17 00:00:00 2001 From: Trinkle23897 <463003665@qq.com> Date: Mon, 21 Sep 2020 21:43:55 +0800 Subject: [PATCH 09/11] collecting state --- docs/tutorials/concepts.rst | 10 +++++----- tianshou/policy/base.py | 2 +- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/docs/tutorials/concepts.rst b/docs/tutorials/concepts.rst index 7136d1292..632877c9e 100644 --- a/docs/tutorials/concepts.rst +++ b/docs/tutorials/concepts.rst @@ -80,17 +80,17 @@ A policy class typically has the following parts: Three states for policy ^^^^^^^^^^^^^^^^^^^^^^^ -During the training process, the policy has three state, namely training state, learning state, and testing state: -Training state is defined as interacting with environments and collecting training data into the buffer; +During the training process, the policy has three state, namely collecting state, learning state, and testing state: +Collecting state is defined as interacting with environments and collecting training data into the buffer; we define the learning state as performing a model update (such as ``policy.learn()``) during training process; and the testing state is obvious: evaluate the performance of the current policy during training process. -In order to distinguish the training state, learning state and testing state, you can check the policy state by ``policy.training`` and ``policy.learning``. The state setting is as follows: +In order to distinguish the collecting state, learning state and testing state, you can check the policy state by ``policy.training`` and ``policy.learning``. The state setting is as follows: +------------------+-----------------+-----------------+ -| State for policy | policy.learning | policy.training | +| State for policy | policy.training | policy.learning | +==================+=================+=================+ -| Training state | False | True | +| Collecting state | True | False | +------------------+-----------------+-----------------+ | Learning state | True | True | +------------------+-----------------+-----------------+ diff --git a/tianshou/policy/base.py b/tianshou/policy/base.py index eea6329db..882a66b4c 100644 --- a/tianshou/policy/base.py +++ b/tianshou/policy/base.py @@ -121,7 +121,7 @@ def learn( .. note:: - In order to distinguish the training state, learning state and + In order to distinguish the collecting state, learning state and testing state, you can check the policy state by ``self.training`` and ``self.learning``. Please refer to :ref:`policy_state` for more detailed explanation. From df8ccf36f7874be303aa13879aa69d3fe56a55cb Mon Sep 17 00:00:00 2001 From: Trinkle23897 <463003665@qq.com> Date: Tue, 22 Sep 2020 09:35:23 +0800 Subject: [PATCH 10/11] updating state --- docs/tutorials/concepts.rst | 45 ++++++++++++++++--------------- tianshou/policy/base.py | 16 +++++------ tianshou/policy/modelfree/ddpg.py | 2 +- tianshou/policy/modelfree/dqn.py | 2 +- tianshou/policy/modelfree/sac.py | 2 +- 5 files changed, 35 insertions(+), 32 deletions(-) diff --git a/docs/tutorials/concepts.rst b/docs/tutorials/concepts.rst index 632877c9e..7d644587d 100644 --- a/docs/tutorials/concepts.rst +++ b/docs/tutorials/concepts.rst @@ -77,27 +77,30 @@ A policy class typically has the following parts: .. 
_policy_state: -Three states for policy -^^^^^^^^^^^^^^^^^^^^^^^ - -During the training process, the policy has three state, namely collecting state, learning state, and testing state: -Collecting state is defined as interacting with environments and collecting training data into the buffer; -we define the learning state as performing a model update (such as ``policy.learn()``) during training process; -and the testing state is obvious: evaluate the performance of the current policy during training process. - -In order to distinguish the collecting state, learning state and testing state, you can check the policy state by ``policy.training`` and ``policy.learning``. The state setting is as follows: - -+------------------+-----------------+-----------------+ -| State for policy | policy.training | policy.learning | -+==================+=================+=================+ -| Collecting state | True | False | -+------------------+-----------------+-----------------+ -| Learning state | True | True | -+------------------+-----------------+-----------------+ -| Testing state | False | False | -+------------------+-----------------+-----------------+ - -``policy.learning`` is helpful to distinguish the different exploration state, for example, in DQN we don't have to use epsilon-greedy in a pure network update, so ``policy.learning`` is helpful for setting epsilon in this case. +States for policy +^^^^^^^^^^^^^^^^^ + +During the training process, the policy has two main states: training state and testing state. The training state can be further divided into the collecting state and updating state. + +The meaning of training and testing state is obvious: the agent interacts with environment, collects training data and performs update, that's training state; the testing state is to evaluate the performance of the current policy during training process. + +As for the collecting state, it is defined as interacting with environments and collecting training data into the buffer; +we define the updating state as performing a model update by :meth:`~tianshou.policy.BasePolicy.update` during training process. + + +In order to distinguish these states, you can check the policy state by ``policy.training`` and ``policy.updating``. The state setting is as follows: + ++-----------------------------------+-----------------+-----------------+ +| State for policy | policy.training | policy.updating | ++================+==================+=================+=================+ +| | Collecting state | True | False | +| Training state +------------------+-----------------+-----------------+ +| | Updating state | True | True | ++----------------+------------------+-----------------+-----------------+ +| Testing state | False | False | ++-----------------------------------+-----------------+-----------------+ + +``policy.updating`` is helpful to distinguish the different exploration state, for example, in DQN we don't have to use epsilon-greedy in a pure network update, so ``policy.updating`` is helpful for setting epsilon in this case. policy.forward diff --git a/tianshou/policy/base.py b/tianshou/policy/base.py index 882a66b4c..809751fe7 100644 --- a/tianshou/policy/base.py +++ b/tianshou/policy/base.py @@ -60,7 +60,7 @@ def __init__( self.observation_space = observation_space self.action_space = action_space self.agent_id = 0 - self.learning = False + self.updating = False self._compile() def set_agent_id(self, agent_id: int) -> None: @@ -121,9 +121,9 @@ def learn( .. 
note:: - In order to distinguish the collecting state, learning state and + In order to distinguish the collecting state, updating state and testing state, you can check the policy state by ``self.training`` - and ``self.learning``. Please refer to :ref:`policy_state` for more + and ``self.updating``. Please refer to :ref:`policy_state` for more detailed explanation. .. warning:: @@ -154,9 +154,9 @@ def update( """Update the policy network and replay buffer. It includes 3 function steps: process_fn, learn, and post_process_fn. - In addition, this function will change the value of ``self.learning``: - it will be True before ``self.learn()`` and will be False after - ``self.learn()``. Please refer to :ref:`policy_state` for more detailed + In addition, this function will change the value of ``self.updating``: + it will be False before this function and will be True when executing + :meth:`update`. Please refer to :ref:`policy_state` for more detailed explanation. :param int sample_size: 0 means it will extract all the data from the @@ -166,11 +166,11 @@ def update( if buffer is None: return {} batch, indice = buffer.sample(sample_size) - self.learning = True + self.updating = True batch = self.process_fn(batch, buffer, indice) result = self.learn(batch, **kwargs) self.post_process_fn(batch, buffer, indice) - self.learning = False + self.updating = False return result @staticmethod diff --git a/tianshou/policy/modelfree/ddpg.py b/tianshou/policy/modelfree/ddpg.py index 71112924b..ab28b6b7d 100644 --- a/tianshou/policy/modelfree/ddpg.py +++ b/tianshou/policy/modelfree/ddpg.py @@ -142,7 +142,7 @@ def forward( obs = batch[input] actions, h = model(obs, state=state, info=batch.info) actions += self._action_bias - if self._noise and not self.learning: + if self._noise and not self.updating: actions += to_torch_as(self._noise(actions.shape), actions) actions = actions.clamp(self._range[0], self._range[1]) return Batch(act=actions, state=h) diff --git a/tianshou/policy/modelfree/dqn.py b/tianshou/policy/modelfree/dqn.py index 901480868..91cca6139 100644 --- a/tianshou/policy/modelfree/dqn.py +++ b/tianshou/policy/modelfree/dqn.py @@ -152,7 +152,7 @@ def forward( q_[~obs.mask] = -np.inf act = q_.argmax(axis=1) # add eps to act in training or testing phase - if not self.learning and not np.isclose(self.eps, 0.0): + if not self.updating and not np.isclose(self.eps, 0.0): for i in range(len(q)): if np.random.rand() < self.eps: q_ = np.random.rand(*q[i].shape) diff --git a/tianshou/policy/modelfree/sac.py b/tianshou/policy/modelfree/sac.py index ec7f41015..8d1d72369 100644 --- a/tianshou/policy/modelfree/sac.py +++ b/tianshou/policy/modelfree/sac.py @@ -122,7 +122,7 @@ def forward( # type: ignore y = self._action_scale * (1 - y.pow(2)) + self.__eps log_prob = dist.log_prob(x).unsqueeze(-1) log_prob = log_prob - torch.log(y).sum(-1, keepdim=True) - if self._noise is not None and not self.learning: + if self._noise is not None and not self.updating: act += to_torch_as(self._noise(act.shape), act) act = act.clamp(self._range[0], self._range[1]) return Batch( From a048736f8ec6d7b2fe1e675d8822aa3d191afa0b Mon Sep 17 00:00:00 2001 From: Trinkle23897 <463003665@qq.com> Date: Tue, 22 Sep 2020 14:51:18 +0800 Subject: [PATCH 11/11] fix a bug in dqn --- tianshou/utils/net/discrete.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tianshou/utils/net/discrete.py b/tianshou/utils/net/discrete.py index 547593f66..0a8452ced 100644 --- a/tianshou/utils/net/discrete.py +++ b/tianshou/utils/net/discrete.py @@ -116,6 
+116,7 @@ def conv2d_layers_size_out(
             nn.ReLU(inplace=True),
             nn.Flatten(),
             nn.Linear(linear_input_size, 512),
+            nn.ReLU(inplace=True),
             nn.Linear(512, np.prod(action_shape)),
         )
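The one-line fix above is easier to appreciate in isolation: without an activation between them, the two stacked ``nn.Linear`` layers collapse into a single affine map, so the Q-value head loses its non-linearity. Below is a minimal, self-contained sketch of the fixed tail of the network; the feature-map size and the number of actions are illustrative assumptions, not values taken from the repository::

    import numpy as np
    import torch
    from torch import nn

    action_shape = (6,)              # e.g. six discrete actions (assumed)
    linear_input_size = 64 * 7 * 7   # e.g. a flattened conv feature map (assumed)

    head = nn.Sequential(
        nn.Flatten(),
        nn.Linear(linear_input_size, 512),
        nn.ReLU(inplace=True),       # the activation added by this patch
        nn.Linear(512, int(np.prod(action_shape))),
    )

    q_values = head(torch.zeros(4, 64, 7, 7))  # a batch of 4 dummy feature maps
    print(q_values.shape)                      # torch.Size([4, 6])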