SAC implementation update #212

Merged 4 commits on Sep 12, 2020
5 changes: 2 additions & 3 deletions examples/box2d/README.md
@@ -1,7 +1,6 @@
# Bipedal-Hardcore-SAC

- Our default choice: remove the done flag penalty, will soon converge to \~250 reward within 100 epochs (10M env steps, 3~4 hours, see the image below)
- Our default choice: remove the done flag penalty, will soon converge to \~270 reward within 100 epochs (10M env steps, 3~4 hours, see the image below)
- If the done penalty is not removed, it converges much slower than before, about 200 epochs (20M env steps) to reach the same performance (\~200 reward)
- Action noise is only necessary in the beginning. It is a negative impact at the end of the training. Removing it can reach \~255 (our best result under the original env, no done penalty removed).

![](results/sac/BipedalHardcore.png)
![](results/sac/BipedalHardcore.png)
10 changes: 9 additions & 1 deletion examples/box2d/bipedal_hardcore_sac.py
@@ -24,6 +24,8 @@ def get_args():
parser.add_argument('--gamma', type=float, default=0.99)
parser.add_argument('--tau', type=float, default=0.005)
parser.add_argument('--alpha', type=float, default=0.1)
parser.add_argument('--auto_alpha', type=int, default=1)
parser.add_argument('--alpha_lr', type=float, default=3e-4)
parser.add_argument('--epoch', type=int, default=100)
parser.add_argument('--step-per-epoch', type=int, default=10000)
parser.add_argument('--collect-per-step', type=int, default=10)
@@ -46,7 +48,7 @@ def get_args():
class EnvWrapper(object):
"""Env wrapper for reward scale, action repeat and action noise"""

def __init__(self, task, action_repeat=3, reward_scale=5, act_noise=0.3):
def __init__(self, task, action_repeat=3, reward_scale=5, act_noise=0.0):
self._env = gym.make(task)
self.action_repeat = action_repeat
self.reward_scale = reward_scale
@@ -109,6 +111,12 @@ def IsStop(reward):
critic2 = Critic(net_c2, args.device).to(args.device)
critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr)

if args.auto_alpha:
target_entropy = -np.prod(env.action_space.shape)
log_alpha = torch.zeros(1, requires_grad=True, device=args.device)
alpha_optim = torch.optim.Adam([log_alpha], lr=args.alpha_lr)
args.alpha = (target_entropy, log_alpha, alpha_optim)

policy = SACPolicy(
actor, actor_optim, critic1, critic1_optim, critic2, critic2_optim,
args.tau, args.gamma, args.alpha,
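For reference, a minimal standalone sketch of the automatic entropy tuning wiring added above. The env name (`Pendulum-v0`) is only a placeholder for any continuous-action task; the learning rate mirrors the new `--alpha_lr` default. The resulting tuple is what `SACPolicy` now accepts in place of a fixed float `alpha`:

```python
import gym
import numpy as np
import torch

env = gym.make('Pendulum-v0')  # placeholder continuous-action env
# SAC heuristic for the entropy target: minus the action dimensionality.
target_entropy = -np.prod(env.action_space.shape)
# Optimize log(alpha) so that alpha = exp(log_alpha) stays positive.
log_alpha = torch.zeros(1, requires_grad=True)
alpha_optim = torch.optim.Adam([log_alpha], lr=3e-4)
# Passed as SACPolicy(..., alpha=(target_entropy, log_alpha, alpha_optim)).
alpha = (target_entropy, log_alpha, alpha_optim)
```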
Binary file modified examples/box2d/results/sac/BipedalHardcore.png
8 changes: 6 additions & 2 deletions test/continuous/test_ppo.py
@@ -5,11 +5,11 @@
import argparse
import numpy as np
from torch.utils.tensorboard import SummaryWriter
from torch.distributions import Independent, Normal

from tianshou.policy import PPOPolicy
from tianshou.env import DummyVectorEnv
from tianshou.utils.net.common import Net
from tianshou.policy.dist import DiagGaussian
from tianshou.trainer import onpolicy_trainer
from tianshou.data import Collector, ReplayBuffer
from tianshou.utils.net.continuous import ActorProb, Critic
@@ -84,7 +84,11 @@ def test_ppo(args=get_args()):
torch.nn.init.zeros_(m.bias)
optim = torch.optim.Adam(list(
actor.parameters()) + list(critic.parameters()), lr=args.lr)
dist = DiagGaussian

# replace DiagGaussian with Independent(Normal) which is equivalent
# pass *logits to be consistent with policy.forward
def dist(*logits):
return Independent(Normal(*logits), 1)
policy = PPOPolicy(
actor, critic, optim, dist, args.gamma,
max_grad_norm=args.max_grad_norm,
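A quick check, not part of the PR, of why the replacement works: wrapping a diagonal `Normal` in `Independent(..., 1)` sums the per-dimension log-probabilities over the last axis, which matches what the deleted `DiagGaussian` computed; the `unsqueeze(-1)` added in `sac.py` below restores the trailing dimension that the old distribution kept.

```python
import torch
from torch.distributions import Independent, Normal

mu, sigma = torch.zeros(4, 3), torch.ones(4, 3)   # batch of 4, action dim 3
dist = Independent(Normal(mu, sigma), 1)
x = dist.rsample()
# One log-probability per batch element, summed over the action dimension.
print(dist.log_prob(x).shape)                                 # torch.Size([4])
print(torch.allclose(dist.log_prob(x),
                     Normal(mu, sigma).log_prob(x).sum(-1)))  # True
```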
11 changes: 0 additions & 11 deletions tianshou/policy/dist.py

This file was deleted.

67 changes: 35 additions & 32 deletions tianshou/policy/modelfree/sac.py
@@ -2,9 +2,9 @@
import numpy as np
from copy import deepcopy
from typing import Dict, Tuple, Union, Optional
from torch.distributions import Normal, Independent

from tianshou.policy import DDPGPolicy
from tianshou.policy.dist import DiagGaussian
from tianshou.data import Batch, to_torch_as, ReplayBuffer
from tianshou.exploration import BaseNoise

@@ -47,23 +47,26 @@ class SACPolicy(DDPGPolicy):
explanation.
"""

def __init__(self,
actor: torch.nn.Module,
actor_optim: torch.optim.Optimizer,
critic1: torch.nn.Module,
critic1_optim: torch.optim.Optimizer,
critic2: torch.nn.Module,
critic2_optim: torch.optim.Optimizer,
tau: float = 0.005,
gamma: float = 0.99,
alpha: Tuple[float, torch.Tensor, torch.optim.Optimizer]
or float = 0.2,
action_range: Optional[Tuple[float, float]] = None,
reward_normalization: bool = False,
ignore_done: bool = False,
estimation_step: int = 1,
exploration_noise: Optional[BaseNoise] = None,
**kwargs) -> None:
def __init__(
self,
actor: torch.nn.Module,
actor_optim: torch.optim.Optimizer,
critic1: torch.nn.Module,
critic1_optim: torch.optim.Optimizer,
critic2: torch.nn.Module,
critic2_optim: torch.optim.Optimizer,
tau: float = 0.005,
gamma: float = 0.99,
alpha: Union[
float, Tuple[float, torch.Tensor, torch.optim.Optimizer]
] = 0.2,
action_range: Optional[Tuple[float, float]] = None,
reward_normalization: bool = False,
ignore_done: bool = False,
estimation_step: int = 1,
exploration_noise: Optional[BaseNoise] = None,
**kwargs
) -> None:
super().__init__(None, None, None, None, tau, gamma, exploration_noise,
action_range, reward_normalization, ignore_done,
estimation_step, **kwargs)
@@ -75,14 +78,12 @@ def __init__(self,
self.critic2_old.eval()
self.critic2_optim = critic2_optim

self._automatic_alpha_tuning = not isinstance(alpha, float)
if self._automatic_alpha_tuning:
self._target_entropy = alpha[0]
assert(alpha[1].shape == torch.Size([1])
and alpha[1].requires_grad)
self._log_alpha = alpha[1]
self._alpha_optim = alpha[2]
self._alpha = self._log_alpha.exp()
self._is_auto_alpha = False
if isinstance(alpha, tuple):
self._is_auto_alpha = True
self._target_entropy, self._log_alpha, self._alpha_optim = alpha
assert alpha[1].shape == torch.Size([1]) and alpha[1].requires_grad
self._alpha = self._log_alpha.detach().exp()
else:
self._alpha = alpha

@@ -111,12 +112,13 @@ def forward(self, batch: Batch,
obs = getattr(batch, input)
logits, h = self.actor(obs, state=state, info=batch.info)
assert isinstance(logits, tuple)
dist = DiagGaussian(*logits)
dist = Independent(Normal(*logits), 1)
x = dist.rsample()
y = torch.tanh(x)
act = y * self._action_scale + self._action_bias
y = self._action_scale * (1 - y.pow(2)) + self.__eps
log_prob = dist.log_prob(x) - torch.log(y).sum(-1, keepdim=True)
log_prob = dist.log_prob(x).unsqueeze(-1)
log_prob = log_prob - torch.log(y).sum(-1, keepdim=True)
if self._noise is not None and self.training and explorating:
act += to_torch_as(self._noise(act.shape), act)
act = act.clamp(self._range[0], self._range[1])
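A self-contained sketch of the change-of-variables correction applied in `forward` above; the shapes, the unit action range, and the `eps` value are assumptions, not values taken from the policy:

```python
import torch
from torch.distributions import Independent, Normal

mu, sigma = torch.zeros(4, 3), torch.ones(4, 3)
action_scale, action_bias, eps = 1.0, 0.0, 1e-8   # assumed action range [-1, 1]
dist = Independent(Normal(mu, sigma), 1)
x = dist.rsample()                                 # pre-squash Gaussian sample
y = torch.tanh(x)
act = y * action_scale + action_bias               # bounded action
# log pi(act) = log N(x) - sum_i log|d act_i / d x_i|,
# where d act_i / d x_i = action_scale * (1 - tanh(x_i)^2).
jacobian = action_scale * (1 - y.pow(2)) + eps
log_prob = dist.log_prob(x).unsqueeze(-1) - torch.log(jacobian).sum(-1, keepdim=True)
print(act.shape, log_prob.shape)                   # torch.Size([4, 3]) torch.Size([4, 1])
```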
@@ -167,13 +169,13 @@ def learn(self, batch: Batch, **kwargs) -> Dict[str, float]:
actor_loss.backward()
self.actor_optim.step()

if self._automatic_alpha_tuning:
log_prob = (obs_result.log_prob + self._target_entropy).detach()
if self._is_auto_alpha:
log_prob = obs_result.log_prob.detach() + self._target_entropy
alpha_loss = -(self._log_alpha * log_prob).mean()
self._alpha_optim.zero_grad()
alpha_loss.backward()
self._alpha_optim.step()
self._alpha = self._log_alpha.exp()
self._alpha = self._log_alpha.detach().exp()

self.sync_weight()

@@ -182,6 +184,7 @@ def learn(self, batch: Batch, **kwargs) -> Dict[str, float]:
'loss/critic1': critic1_loss.item(),
'loss/critic2': critic2_loss.item(),
}
if self._automatic_alpha_tuning:
if self._is_auto_alpha:
result['loss/alpha'] = alpha_loss.item()
result['v/alpha'] = self._alpha.item()
return result
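For reference, a standalone sketch of the temperature update performed in `learn` above, with dummy values standing in for `obs_result.log_prob`. Because the sampled log-probabilities are detached, the gradient reaches only `log_alpha`; the loss pushes alpha up when policy entropy drops below the target and down otherwise, and the final `detach().exp()` keeps alpha out of the actor's computation graph:

```python
import torch

target_entropy = -3.0                                   # e.g. -dim(action space)
log_alpha = torch.zeros(1, requires_grad=True)
alpha_optim = torch.optim.Adam([log_alpha], lr=3e-4)
sampled_log_prob = torch.randn(256, 1)                  # stand-in for log pi(a|s)

log_prob = sampled_log_prob.detach() + target_entropy   # gradient flows only into log_alpha
alpha_loss = -(log_alpha * log_prob).mean()
alpha_optim.zero_grad()
alpha_loss.backward()
alpha_optim.step()
alpha = log_alpha.detach().exp()                        # used as a constant elsewhere
```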
11 changes: 5 additions & 6 deletions tianshou/trainer/offpolicy.py
@@ -77,13 +77,13 @@ def offpolicy_trainer(
start_time = time.time()
test_in_train = test_in_train and train_collector.policy == policy
for epoch in range(1, 1 + max_epoch):
# train
policy.train()
if train_fn:
train_fn(epoch)
with tqdm.tqdm(total=step_per_epoch, desc=f'Epoch #{epoch}',
**tqdm_config) as t:
while t.n < t.total:
# collect
if train_fn:
train_fn(epoch)
policy.eval()
result = train_collector.collect(n_step=collect_per_step)
data = {}
if test_in_train and stop_fn and stop_fn(result['rew']):
@@ -100,10 +100,9 @@
start_time, train_collector, test_collector,
test_result['rew'])
else:
policy.train()
if train_fn:
train_fn(epoch)
# train
policy.train()
for i in range(update_per_step * min(
result['n/st'] // collect_per_step, t.total - t.n)):
global_step += collect_per_step
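A toy illustration, not tianshou code, of why the reordering matters: modules such as `Dropout` or `BatchNorm` behave differently in train and eval mode, so the loop now calls `policy.eval()` before collecting transitions and `policy.train()` before the gradient updates. The `onpolicy_trainer` diff below receives the same change.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
net = nn.Sequential(nn.Linear(4, 4), nn.Dropout(p=0.5))
x = torch.ones(1, 4)

net.eval()       # collection phase: dropout disabled, deterministic output
print(net(x))
net.train()      # update phase: dropout active again, stochastic output
print(net(x))
```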
11 changes: 5 additions & 6 deletions tianshou/trainer/onpolicy.py
@@ -77,13 +77,13 @@ def onpolicy_trainer(
start_time = time.time()
test_in_train = test_in_train and train_collector.policy == policy
for epoch in range(1, 1 + max_epoch):
# train
policy.train()
if train_fn:
train_fn(epoch)
with tqdm.tqdm(total=step_per_epoch, desc=f'Epoch #{epoch}',
**tqdm_config) as t:
while t.n < t.total:
# collect
if train_fn:
train_fn(epoch)
policy.eval()
result = train_collector.collect(n_episode=collect_per_step)
data = {}
if test_in_train and stop_fn and stop_fn(result['rew']):
@@ -100,10 +100,9 @@
start_time, train_collector, test_collector,
test_result['rew'])
else:
policy.train()
if train_fn:
train_fn(epoch)
# train
policy.train()
losses = policy.update(
0, train_collector.buffer, batch_size, repeat_per_collect)
train_collector.reset_buffer()