
Add lr_scheduler option for on-policy algorithms #318


Merged · 3 commits · Mar 22, 2021
6 changes: 6 additions & 0 deletions tianshou/policy/modelfree/a2c.py
@@ -38,6 +38,8 @@ class A2CPolicy(PGPolicy):
squashing) for now, or empty string for no bounding. Default to "clip".
:param Optional[gym.Space] action_space: env's action space, mandatory if you want
to use option "action_scaling" or "action_bound_method". Default to None.
:param lr_scheduler: a learning rate scheduler that adjusts the learning rate in
optimizer in each policy.update(). Default to None (no lr_scheduler).

.. seealso::

@@ -142,6 +144,10 @@ def learn( # type: ignore
vf_losses.append(vf_loss.item())
ent_losses.append(ent_loss.item())
losses.append(loss.item())
# update learning rate if lr_scheduler is given
if self.lr_scheduler is not None:
self.lr_scheduler.step()

return {
"loss": losses,
"loss/actor": actor_losses,
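For context, here is a minimal sketch of how the new option can be wired into an A2C setup. It assumes the `lr_scheduler` keyword is forwarded from `A2CPolicy` to `PGPolicy` through `**kwargs`, as the docstring addition above implies; the toy actor/critic networks, the learning rate, and `max_update_num` are illustrative assumptions, not part of this PR.

```python
import torch
from torch import nn
from torch.optim.lr_scheduler import LambdaLR
from tianshou.policy import A2CPolicy

# Toy actor/critic for a 4-dim observation, 2-action task (illustrative only).
# Tianshou actors return (logits, hidden_state); critics return a value tensor.
class Actor(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(4, 64), nn.Tanh(), nn.Linear(64, 2))

    def forward(self, obs, state=None, info=None):
        obs = torch.as_tensor(obs, dtype=torch.float32)
        return self.net(obs), state

class Critic(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(4, 64), nn.Tanh(), nn.Linear(64, 1))

    def forward(self, obs, **kwargs):
        return self.net(torch.as_tensor(obs, dtype=torch.float32))

actor, critic = Actor(), Critic()
optim = torch.optim.Adam(
    list(actor.parameters()) + list(critic.parameters()), lr=3e-4)

# Decay the learning rate linearly to zero over max_update_num policy updates;
# with this patch, each policy.update() calls lr_scheduler.step() inside learn().
max_update_num = 1000  # illustrative; see the note after pg.py for one way to derive it
lr_scheduler = LambdaLR(optim, lr_lambda=lambda n: max(0.0, 1 - n / max_update_num))

policy = A2CPolicy(
    actor, critic, optim,
    dist_fn=torch.distributions.Categorical,
    lr_scheduler=lr_scheduler,
)
```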
8 changes: 8 additions & 0 deletions tianshou/policy/modelfree/pg.py
@@ -22,6 +22,8 @@ class PGPolicy(BasePolicy):
squashing) for now, or empty string for no bounding. Default to "clip".
:param Optional[gym.Space] action_space: env's action space, mandatory if you want
to use option "action_scaling" or "action_bound_method". Default to None.
:param lr_scheduler: a learning rate scheduler that adjusts the learning rate in
optimizer in each policy.update(). Default to None (no lr_scheduler).

.. seealso::

@@ -38,13 +40,15 @@ def __init__(
reward_normalization: bool = False,
action_scaling: bool = True,
action_bound_method: str = "clip",
lr_scheduler: Optional[torch.optim.lr_scheduler.LambdaLR] = None,
**kwargs: Any,
) -> None:
super().__init__(action_scaling=action_scaling,
action_bound_method=action_bound_method, **kwargs)
if model is not None:
self.model: torch.nn.Module = model
self.optim = optim
self.lr_scheduler = lr_scheduler
self.dist_fn = dist_fn
assert 0.0 <= discount_factor <= 1.0, "discount factor should be in [0, 1]"
self._gamma = discount_factor
@@ -110,6 +114,10 @@ def learn( # type: ignore
loss.backward()
self.optim.step()
losses.append(loss.item())
# update learning rate if lr_scheduler is given
if self.lr_scheduler is not None:
self.lr_scheduler.step()

return {"loss": losses}

# def _vanilla_returns(self, batch):
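Since `lr_scheduler.step()` runs once at the end of every `learn()` call, the schedule is indexed by the number of `policy.update()` calls, not by environment steps or epochs. If the surrounding training loop performs one `policy.update()` per data collection, the horizon for the `LambdaLR` above can be estimated as sketched below; the trainer variable names here are illustrative assumptions, not an API of this repository.

```python
import numpy as np

# Illustrative trainer settings; substitute whatever you pass to your trainer.
epoch, step_per_epoch, collect_per_step = 100, 1000, 10

# Assuming one policy.update() (and hence one lr_scheduler.step()) per collect,
# the total number of scheduler steps over a run is roughly:
max_update_num = int(np.ceil(step_per_epoch / collect_per_step)) * epoch

# Reuse this horizon in the LambdaLR from the sketch above.
lr_lambda = lambda n: max(0.0, 1 - n / max_update_num)
```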
6 changes: 6 additions & 0 deletions tianshou/policy/modelfree/ppo.py
@@ -43,6 +43,8 @@ class PPOPolicy(PGPolicy):
squashing) for now, or empty string for no bounding. Default to "clip".
:param Optional[gym.Space] action_space: env's action space, mandatory if you want
to use option "action_scaling" or "action_bound_method". Default to None.
:param lr_scheduler: a learning rate scheduler that adjusts the learning rate in
optimizer in each policy.update(). Default to None (no lr_scheduler).

.. seealso::

@@ -179,6 +181,10 @@ def learn( # type: ignore
list(self.actor.parameters()) + list(self.critic.parameters()),
self._max_grad_norm)
self.optim.step()
# update learning rate if lr_scheduler is given
if self.lr_scheduler is not None:
self.lr_scheduler.step()

return {
"loss": losses,
"loss/clip": clip_losses,
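A quick way to check the wiring, continuing the hypothetical setup above: after each scheduler step the optimizer's learning rate should shrink, which is exactly what the new `lr_scheduler.step()` calls produce at the end of `learn()`.

```python
# Sanity check of the schedule itself; no environment needed.
# Each lr_scheduler.step() here stands in for one policy.update() call.
print(optim.param_groups[0]["lr"])  # initial value, 3e-4 in the sketch above
for _ in range(10):
    optim.step()          # avoid PyTorch's "lr_scheduler before optimizer" warning
    lr_scheduler.step()
print(optim.param_groups[0]["lr"])  # smaller: 3e-4 * (1 - 10 / max_update_num)
```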