diff --git a/test/continuous/test_ddpg.py b/test/continuous/test_ddpg.py
index ddfcd049d..083994467 100644
--- a/test/continuous/test_ddpg.py
+++ b/test/continuous/test_ddpg.py
@@ -10,6 +10,7 @@
 from tianshou.policy import DDPGPolicy
 from tianshou.trainer import offpolicy_trainer
 from tianshou.data import Collector, ReplayBuffer
+from tianshou.exploration import GaussianNoise
 
 if __name__ == '__main__':
     from net import Actor, Critic
@@ -78,7 +79,7 @@ def test_ddpg(args=get_args()):
     critic_optim = torch.optim.Adam(critic.parameters(), lr=args.critic_lr)
     policy = DDPGPolicy(
         actor, actor_optim, critic, critic_optim,
-        args.tau, args.gamma, args.exploration_noise,
+        args.tau, args.gamma, GaussianNoise(sigma=args.exploration_noise),
         [env.action_space.low[0], env.action_space.high[0]],
         reward_normalization=args.rew_norm,
         ignore_done=args.ignore_done,
diff --git a/test/continuous/test_sac_with_mcc.py b/test/continuous/test_sac_with_mcc.py
new file mode 100644
index 000000000..74e439a84
--- /dev/null
+++ b/test/continuous/test_sac_with_mcc.py
@@ -0,0 +1,130 @@
+import os
+import gym
+import torch
+import pprint
+import argparse
+import numpy as np
+from torch.utils.tensorboard import SummaryWriter
+
+from tianshou.policy import SACPolicy
+from tianshou.trainer import offpolicy_trainer
+from tianshou.data import Collector, ReplayBuffer
+from tianshou.env import VectorEnv
+from tianshou.exploration import OUNoise
+
+if __name__ == '__main__':
+    from net import ActorProb, Critic
+else:  # pytest
+    from test.continuous.net import ActorProb, Critic
+
+
+def get_args():
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--task', type=str, default='MountainCarContinuous-v0')
+    parser.add_argument('--seed', type=int, default=1626)
+    parser.add_argument('--buffer-size', type=int, default=50000)
+    parser.add_argument('--actor-lr', type=float, default=3e-4)
+    parser.add_argument('--critic-lr', type=float, default=3e-4)
+    parser.add_argument('--alpha-lr', type=float, default=3e-4)
+    parser.add_argument('--noise_std', type=float, default=0.5)
+    parser.add_argument('--gamma', type=float, default=0.99)
+    parser.add_argument('--tau', type=float, default=0.005)
+    parser.add_argument('--auto_alpha', type=bool, default=True)
+    parser.add_argument('--alpha', type=float, default=0.2)
+    parser.add_argument('--epoch', type=int, default=20)
+    parser.add_argument('--step-per-epoch', type=int, default=2400)
+    parser.add_argument('--collect-per-step', type=int, default=1)
+    parser.add_argument('--batch-size', type=int, default=128)
+    parser.add_argument('--layer-num', type=int, default=1)
+    parser.add_argument('--training-num', type=int, default=80)
+    parser.add_argument('--test-num', type=int, default=100)
+    parser.add_argument('--logdir', type=str, default='log')
+    parser.add_argument('--render', type=float, default=1.0/35.0)
+    parser.add_argument('--rew-norm', type=bool, default=False)
+    parser.add_argument(
+        '--device', type=str,
+        default='cuda' if torch.cuda.is_available() else 'cpu')
+    args = parser.parse_known_args()[0]
+    return args
+
+
+def test_sac(args=get_args()):
+    env = gym.make(args.task)
+    args.state_shape = env.observation_space.shape or env.observation_space.n
+    args.action_shape = env.action_space.shape or env.action_space.n
+    args.max_action = env.action_space.high[0]
+    # train_envs = gym.make(args.task)
+    train_envs = VectorEnv(
+        [lambda: gym.make(args.task) for _ in range(args.training_num)])
+    # test_envs = gym.make(args.task)
+    test_envs = VectorEnv(
+        [lambda: gym.make(args.task) for _ in range(args.test_num)])
+    # seed
+    np.random.seed(args.seed)
+    torch.manual_seed(args.seed)
+    train_envs.seed(args.seed)
+    test_envs.seed(args.seed)
+    # model
+    actor = ActorProb(
+        args.layer_num, args.state_shape, args.action_shape,
+        args.max_action, args.device
+    ).to(args.device)
+    actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr)
+    critic1 = Critic(
+        args.layer_num, args.state_shape, args.action_shape, args.device
+    ).to(args.device)
+    critic1_optim = torch.optim.Adam(critic1.parameters(), lr=args.critic_lr)
+    critic2 = Critic(
+        args.layer_num, args.state_shape, args.action_shape, args.device
+    ).to(args.device)
+    critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr)
+
+    if args.auto_alpha:
+        target_entropy = -np.prod(env.action_space.shape)
+        log_alpha = torch.zeros(1, requires_grad=True, device=args.device)
+        alpha_optim = torch.optim.Adam([log_alpha], lr=args.alpha_lr)
+        alpha = (target_entropy, log_alpha, alpha_optim)
+    else:
+        alpha = args.alpha
+
+    policy = SACPolicy(
+        actor, actor_optim, critic1, critic1_optim, critic2, critic2_optim,
+        args.tau, args.gamma, alpha,
+        [env.action_space.low[0], env.action_space.high[0]],
+        reward_normalization=args.rew_norm, ignore_done=True,
+        exploration_noise=OUNoise(0.0, args.noise_std))
+    # collector
+    train_collector = Collector(
+        policy, train_envs, ReplayBuffer(args.buffer_size))
+    test_collector = Collector(policy, test_envs)
+    # train_collector.collect(n_step=args.buffer_size)
+    # log
+    log_path = os.path.join(args.logdir, args.task, 'sac')
+    writer = SummaryWriter(log_path)
+
+    def save_fn(policy):
+        torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
+
+    def stop_fn(x):
+        return x >= env.spec.reward_threshold
+
+    # trainer
+    result = offpolicy_trainer(
+        policy, train_collector, test_collector, args.epoch,
+        args.step_per_epoch, args.collect_per_step, args.test_num,
+        args.batch_size, stop_fn=stop_fn, save_fn=save_fn, writer=writer)
+    assert stop_fn(result['best_reward'])
+    train_collector.close()
+    test_collector.close()
+    if __name__ == '__main__':
+        pprint.pprint(result)
+        # Let's watch its performance!
+        env = gym.make(args.task)
+        collector = Collector(policy, env)
+        result = collector.collect(n_episode=1, render=args.render)
+        print(f'Final reward: {result["rew"]}, length: {result["len"]}')
+        collector.close()
+
+
+if __name__ == '__main__':
+    test_sac()
diff --git a/test/continuous/test_td3.py b/test/continuous/test_td3.py
index 5ee3d8022..6d3133bb7 100644
--- a/test/continuous/test_td3.py
+++ b/test/continuous/test_td3.py
@@ -10,6 +10,7 @@
 from tianshou.policy import TD3Policy
 from tianshou.trainer import offpolicy_trainer
 from tianshou.data import Collector, ReplayBuffer
+from tianshou.exploration import GaussianNoise
 
 if __name__ == '__main__':
     from net import Actor, Critic
@@ -85,8 +86,8 @@ def test_td3(args=get_args()):
     critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr)
     policy = TD3Policy(
         actor, actor_optim, critic1, critic1_optim, critic2, critic2_optim,
-        args.tau, args.gamma, args.exploration_noise, args.policy_noise,
-        args.update_actor_freq, args.noise_clip,
+        args.tau, args.gamma, GaussianNoise(sigma=args.exploration_noise),
+        args.policy_noise, args.update_actor_freq, args.noise_clip,
         [env.action_space.low[0], env.action_space.high[0]],
         reward_normalization=args.rew_norm,
         ignore_done=args.ignore_done,
diff --git a/tianshou/exploration/__init__.py b/tianshou/exploration/__init__.py
index 220913ee9..abe6f3806 100644
--- a/tianshou/exploration/__init__.py
+++ b/tianshou/exploration/__init__.py
@@ -1,5 +1,7 @@
-from tianshou.exploration.random import OUNoise
+from tianshou.exploration.random import BaseNoise, GaussianNoise, OUNoise
 
 __all__ = [
+    'BaseNoise',
+    'GaussianNoise',
     'OUNoise',
 ]
diff --git a/tianshou/exploration/random.py b/tianshou/exploration/random.py
index 1e83dfc71..2df5489f9 100644
--- a/tianshou/exploration/random.py
+++ b/tianshou/exploration/random.py
@@ -1,8 +1,42 @@
 import numpy as np
 from typing import Union, Optional
+from abc import ABC, abstractmethod
 
 
-class OUNoise(object):
+class BaseNoise(ABC, object):
+    """The action noise base class."""
+
+    def __init__(self, **kwargs) -> None:
+        super(BaseNoise, self).__init__()
+
+    @abstractmethod
+    def __call__(self, **kwargs) -> np.ndarray:
+        """Generate new noise."""
+        raise NotImplementedError
+
+    def reset(self, **kwargs) -> None:
+        """Reset to the initial state."""
+        pass
+
+
+class GaussianNoise(BaseNoise):
+    """Class for vanilla Gaussian noise,
+    used for exploration in DDPG by default.
+    """
+
+    def __init__(self,
+                 mu: float = 0.0,
+                 sigma: float = 1.0):
+        super().__init__()
+        self._mu = mu
+        assert 0 <= sigma, 'noise std should not be negative'
+        self._sigma = sigma
+
+    def __call__(self, size: tuple) -> np.ndarray:
+        return np.random.normal(self._mu, self._sigma, size)
+
+
+class OUNoise(BaseNoise):
     """Class for Ornstein-Uhlenbeck process, as used for exploration in
     DDPG. Usage:
     ::
@@ -19,26 +53,31 @@ class OUNoise(object):
     """
 
     def __init__(self,
+                 mu: float = 0.0,
                  sigma: float = 0.3,
                  theta: float = 0.15,
                  dt: float = 1e-2,
                  x0: Optional[Union[float, np.ndarray]] = None
                  ) -> None:
-        self.alpha = theta * dt
-        self.beta = sigma * np.sqrt(dt)
-        self.x0 = x0
+        super(BaseNoise, self).__init__()
+        self._mu = mu
+        self._alpha = theta * dt
+        self._beta = sigma * np.sqrt(dt)
+        self._x0 = x0
         self.reset()
 
-    def __call__(self, size: tuple, mu: float = .1) -> np.ndarray:
+    def __call__(self, size: tuple, mu: Optional[float] = None) -> np.ndarray:
         """Generate new noise. Return a ``numpy.ndarray`` which size is equal
         to ``size``.
         """
-        if self.x is None or self.x.shape != size:
-            self.x = 0
-        r = self.beta * np.random.normal(size=size)
-        self.x = self.x + self.alpha * (mu - self.x) + r
-        return self.x
+        if self._x is None or self._x.shape != size:
+            self._x = 0
+        if mu is None:
+            mu = self._mu
+        r = self._beta * np.random.normal(size=size)
+        self._x = self._x + self._alpha * (mu - self._x) + r
+        return self._x
 
     def reset(self) -> None:
         """Reset to the initial state."""
-        self.x = None
+        self._x = self._x0
diff --git a/tianshou/policy/modelfree/ddpg.py b/tianshou/policy/modelfree/ddpg.py
index 5d3dde3b8..ec56123d8 100644
--- a/tianshou/policy/modelfree/ddpg.py
+++ b/tianshou/policy/modelfree/ddpg.py
@@ -5,7 +5,7 @@
 from typing import Dict, Tuple, Union, Optional
 
 from tianshou.policy import BasePolicy
-# from tianshou.exploration import OUNoise
+from tianshou.exploration import BaseNoise, GaussianNoise
 from tianshou.data import Batch, ReplayBuffer, to_torch_as
 
 
@@ -21,8 +21,8 @@ class DDPGPolicy(BasePolicy):
     :param float tau: param for soft update of the target network,
         defaults to 0.005.
     :param float gamma: discount factor, in [0, 1], defaults to 0.99.
-    :param float exploration_noise: the noise intensity, add to the action,
-        defaults to 0.1.
+    :param BaseNoise exploration_noise: the exploration noise,
+        added to the action, defaults to ``GaussianNoise(sigma=0.1)``.
     :param action_range: the action range (minimum, maximum).
     :type action_range: (float, float)
     :param bool reward_normalization: normalize the reward to Normal(0, 1),
@@ -45,7 +45,8 @@ def __init__(self,
                  critic_optim: torch.optim.Optimizer,
                  tau: float = 0.005,
                  gamma: float = 0.99,
-                 exploration_noise: float = 0.1,
+                 exploration_noise: Optional[BaseNoise]
+                 = GaussianNoise(sigma=0.1),
                  action_range: Optional[Tuple[float, float]] = None,
                  reward_normalization: bool = False,
                  ignore_done: bool = False,
@@ -64,8 +65,7 @@ def __init__(self,
         self._tau = tau
         assert 0 <= gamma <= 1, 'gamma should in [0, 1]'
         self._gamma = gamma
-        assert 0 <= exploration_noise, 'noise should not be negative'
-        self._eps = exploration_noise
+        self._noise = exploration_noise
         assert action_range is not None
         self._range = action_range
         self._action_bias = (action_range[0] + action_range[1]) / 2
@@ -77,9 +77,9 @@ def __init__(self,
         assert estimation_step > 0, 'estimation_step should greater than 0'
         self._n_step = estimation_step
 
-    def set_eps(self, eps: float) -> None:
-        """Set the eps for exploration."""
-        self._eps = eps
+    def set_exp_noise(self, noise: Optional[BaseNoise]) -> None:
+        """Set the exploration noise."""
+        self._noise = noise
 
     def train(self) -> None:
         """Set the module in training mode, except for the target network."""
@@ -106,7 +106,8 @@ def _target_q(self, buffer: ReplayBuffer,
         batch = buffer[indice]  # batch.obs_next: s_{t+n}
         with torch.no_grad():
             target_q = self.critic_old(batch.obs_next, self(
-                batch, model='actor_old', input='obs_next', eps=0).act)
+                batch, model='actor_old', input='obs_next',
+                explorating=False).act)
         return target_q
 
     def process_fn(self, batch: Batch, buffer: ReplayBuffer,
@@ -122,7 +123,7 @@ def forward(self, batch: Batch,
                 state: Optional[Union[dict, Batch, np.ndarray]] = None,
                 model: str = 'actor',
                 input: str = 'obs',
-                eps: Optional[float] = None,
+                explorating: bool = True,
                 **kwargs) -> Batch:
         """Compute action over the given batch data.
 
@@ -142,14 +143,8 @@ def forward(self, batch: Batch,
         obs = getattr(batch, input)
         logits, h = model(obs, state=state, info=batch.info)
         logits += self._action_bias
-        if eps is None:
-            eps = self._eps
-        if eps > 0:
-            # noise = np.random.normal(0, eps, size=logits.shape)
-            # logits += to_torch(noise, device=logits.device)
-            # noise = self.noise(logits.shape, eps)
-            logits += torch.randn(
-                size=logits.shape, device=logits.device) * eps
+        if self.training and explorating:
+            logits += to_torch_as(self._noise(logits.shape), logits)
         logits = logits.clamp(self._range[0], self._range[1])
         return Batch(act=logits, state=h)
 
@@ -161,7 +156,8 @@ def learn(self, batch: Batch, **kwargs) -> Dict[str, float]:
         self.critic_optim.zero_grad()
         critic_loss.backward()
         self.critic_optim.step()
-        actor_loss = -self.critic(batch.obs, self(batch, eps=0).act).mean()
+        action = self(batch, explorating=False).act
+        actor_loss = -self.critic(batch.obs, action).mean()
         self.actor_optim.zero_grad()
         actor_loss.backward()
         self.actor_optim.step()
diff --git a/tianshou/policy/modelfree/sac.py b/tianshou/policy/modelfree/sac.py
index 2885fe1eb..8ddb2bfe4 100644
--- a/tianshou/policy/modelfree/sac.py
+++ b/tianshou/policy/modelfree/sac.py
@@ -7,6 +7,7 @@
 from tianshou.policy import DDPGPolicy
 from tianshou.policy.dist import DiagGaussian
 from tianshou.data import Batch, to_torch_as, ReplayBuffer
+from tianshou.exploration import BaseNoise
 
 
 class SACPolicy(DDPGPolicy):
@@ -28,13 +29,18 @@ class SACPolicy(DDPGPolicy):
     :param float gamma: discount factor, in [0, 1], defaults to 0.99.
     :param float exploration_noise: the noise intensity, add to the action,
        defaults to 0.1.
-    :param float alpha: entropy regularization coefficient, default to 0.2.
+    :param (float, torch.Tensor, torch.optim.Optimizer) or float alpha: entropy
+        regularization coefficient, defaults to 0.2.
+        If a tuple (target_entropy, log_alpha, alpha_optim) is provided, then
+        alpha is automatically tuned.
     :param action_range: the action range (minimum, maximum).
     :type action_range: (float, float)
     :param bool reward_normalization: normalize the reward to Normal(0, 1),
         defaults to ``False``.
     :param bool ignore_done: ignore the done flag while training the policy,
         defaults to ``False``.
+    :param BaseNoise exploration_noise: add noise to the action for exploration.
+        This is useful when solving hard-exploration problems.
 
     .. seealso::
 
@@ -51,13 +57,15 @@ def __init__(self,
                  critic2_optim: torch.optim.Optimizer,
                  tau: float = 0.005,
                  gamma: float = 0.99,
-                 alpha: float = 0.2,
+                 alpha: Tuple[float, torch.Tensor, torch.optim.Optimizer]
+                 or float = 0.2,
                  action_range: Optional[Tuple[float, float]] = None,
                  reward_normalization: bool = False,
                  ignore_done: bool = False,
                  estimation_step: int = 1,
+                 exploration_noise: Optional[BaseNoise] = None,
                  **kwargs) -> None:
-        super().__init__(None, None, None, None, tau, gamma, 0,
+        super().__init__(None, None, None, None, tau, gamma, exploration_noise,
                          action_range, reward_normalization, ignore_done,
                          estimation_step, **kwargs)
         self.actor, self.actor_optim = actor, actor_optim
@@ -67,7 +75,18 @@ def __init__(self,
         self.critic2, self.critic2_old = critic2, deepcopy(critic2)
         self.critic2_old.eval()
         self.critic2_optim = critic2_optim
-        self._alpha = alpha
+
+        self._automatic_alpha_tuning = not isinstance(alpha, float)
+        if self._automatic_alpha_tuning:
+            self._target_entropy = alpha[0]
+            assert(alpha[1].shape == torch.Size([1])
+                   and alpha[1].requires_grad)
+            self._log_alpha = alpha[1]
+            self._alpha_optim = alpha[2]
+            self._alpha = self._log_alpha.exp()
+        else:
+            self._alpha = alpha
+
         self.__eps = np.finfo(np.float32).eps.item()
 
     def train(self) -> None:
@@ -92,7 +111,9 @@ def sync_weight(self) -> None:
 
     def forward(self, batch: Batch,
                 state: Optional[Union[dict, Batch, np.ndarray]] = None,
-                input: str = 'obs', **kwargs) -> Batch:
+                input: str = 'obs',
+                explorating: bool = True,
+                **kwargs) -> Batch:
         obs = getattr(batch, input)
         logits, h = self.actor(obs, state=state, info=batch.info)
         assert isinstance(logits, tuple)
@@ -100,9 +121,10 @@ def forward(self, batch: Batch,
         x = dist.rsample()
         y = torch.tanh(x)
         act = y * self._action_scale + self._action_bias
-        log_prob = dist.log_prob(x) - torch.log(
-            self._action_scale * (1 - y.pow(2)) + self.__eps
-        ).sum(-1, keepdim=True)
+        y = self._action_scale * (1 - y.pow(2)) + self.__eps
+        log_prob = dist.log_prob(x) - torch.log(y).sum(-1, keepdim=True)
+        if self._noise is not None and self.training and explorating:
+            act += to_torch_as(self._noise(act.shape), act)
         act = act.clamp(self._range[0], self._range[1])
         return Batch(
             logits=logits, act=act, state=h, dist=dist, log_prob=log_prob)
@@ -111,7 +133,7 @@ def _target_q(self, buffer: ReplayBuffer,
                   indice: np.ndarray) -> torch.Tensor:
         batch = buffer[indice]  # batch.obs: s_{t+n}
         with torch.no_grad():
-            obs_next_result = self(batch, input='obs_next')
+            obs_next_result = self(batch, input='obs_next', explorating=False)
             a_ = obs_next_result.act
             batch.act = to_torch_as(batch.act, a_)
             target_q = torch.min(
@@ -135,7 +157,7 @@ def learn(self, batch: Batch, **kwargs) -> Dict[str, float]:
         critic2_loss.backward()
         self.critic2_optim.step()
         # actor
-        obs_result = self(batch)
+        obs_result = self(batch, explorating=False)
         a = obs_result.act
         current_q1a = self.critic1(batch.obs, a)
         current_q2a = self.critic2(batch.obs, a)
@@ -144,9 +166,22 @@ def learn(self, batch: Batch, **kwargs) -> Dict[str, float]:
         self.actor_optim.zero_grad()
         actor_loss.backward()
         self.actor_optim.step()
+
+        if self._automatic_alpha_tuning:
+            log_prob = (obs_result.log_prob + self._target_entropy).detach()
+            alpha_loss = -(self._log_alpha * log_prob).mean()
+            self._alpha_optim.zero_grad()
+            alpha_loss.backward()
+            self._alpha_optim.step()
+            self._alpha = self._log_alpha.exp()
+
         self.sync_weight()
-        return {
+
+        result = {
             'loss/actor': actor_loss.item(),
             'loss/critic1': critic1_loss.item(),
             'loss/critic2': critic2_loss.item(),
         }
+        if self._automatic_alpha_tuning:
+            result['loss/alpha'] = alpha_loss.item()
+        return result
diff --git a/tianshou/policy/modelfree/td3.py b/tianshou/policy/modelfree/td3.py
index dd3004cce..2223e37ba 100644
--- a/tianshou/policy/modelfree/td3.py
+++ b/tianshou/policy/modelfree/td3.py
@@ -6,6 +6,7 @@
 
 from tianshou.policy import DDPGPolicy
 from tianshou.data import Batch, ReplayBuffer
+from tianshou.exploration import BaseNoise, GaussianNoise
 
 
 class TD3Policy(DDPGPolicy):
@@ -26,8 +27,8 @@ class TD3Policy(DDPGPolicy):
     :param float tau: param for soft update of the target network,
         defaults to 0.005.
     :param float gamma: discount factor, in [0, 1], defaults to 0.99.
-    :param float exploration_noise: the noise intensity, add to the action,
-        defaults to 0.1.
+    :param BaseNoise exploration_noise: the exploration noise,
+        added to the action, defaults to ``GaussianNoise(sigma=0.1)``.
     :param float policy_noise: the noise used in updating policy network,
         default to 0.2.
     :param int update_actor_freq: the update frequency of actor network,
@@ -56,7 +57,8 @@ def __init__(self,
                  critic2_optim: torch.optim.Optimizer,
                  tau: float = 0.005,
                  gamma: float = 0.99,
-                 exploration_noise: float = 0.1,
+                 exploration_noise: Optional[BaseNoise]
+                 = GaussianNoise(sigma=0.1),
                  policy_noise: float = 0.2,
                  update_actor_freq: int = 2,
                  noise_clip: float = 0.5,
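
Usage note (not part of the patch): the snippet below is an illustrative sketch of the API this diff introduces, assuming the patch has been applied and torch/numpy/tianshou are installed. Policy construction is abbreviated to comments because it requires the actor/critic networks built in the test scripts above; variable names such as ``target_entropy`` are placeholders.
::

    import torch
    from tianshou.exploration import GaussianNoise, OUNoise

    # Noise objects are callables that map an output shape to a numpy array.
    gaussian = GaussianNoise(sigma=0.1)   # the new DDPG/TD3 default
    ou = OUNoise(0.0, 0.5)                # Ornstein-Uhlenbeck, mu=0.0, sigma=0.5
    assert gaussian((4, 2)).shape == (4, 2)
    assert ou((4, 2)).shape == (4, 2)     # OU noise keeps internal state across calls
    ou.reset()                            # clear the OU state, e.g. between episodes

    # DDPG/TD3 now accept a BaseNoise object instead of a float, e.g.
    #   DDPGPolicy(actor, actor_optim, critic, critic_optim,
    #              exploration_noise=GaussianNoise(sigma=0.1), ...)
    # and the noise can be swapped later via policy.set_exp_noise(OUNoise(0.0, 0.5)).

    # SAC: passing a (target_entropy, log_alpha, alpha_optim) tuple as ``alpha``
    # enables automatic entropy-coefficient tuning, as in test_sac_with_mcc.py.
    target_entropy = -1.0                 # -np.prod(action_shape) for a 1-d action space
    log_alpha = torch.zeros(1, requires_grad=True)
    alpha_optim = torch.optim.Adam([log_alpha], lr=3e-4)
    alpha = (target_entropy, log_alpha, alpha_optim)
    #   SACPolicy(..., alpha=alpha, exploration_noise=OUNoise(0.0, 0.5))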