diff --git a/docs/_static/images/Ant-v2.png b/docs/_static/images/Ant-v2.png deleted file mode 100644 index b5497592a..000000000 Binary files a/docs/_static/images/Ant-v2.png and /dev/null differ diff --git a/docs/_static/images/concepts_arch2.png b/docs/_static/images/concepts_arch2.png index 134bad337..58d8c62e7 100644 Binary files a/docs/_static/images/concepts_arch2.png and b/docs/_static/images/concepts_arch2.png differ diff --git a/examples/atari/atari_dqn.py b/examples/atari/atari_dqn.py index e44890308..b7d8ffef4 100644 --- a/examples/atari/atari_dqn.py +++ b/examples/atari/atari_dqn.py @@ -58,8 +58,8 @@ def test_dqn(args=get_args()): args.state_shape = env.observation_space.shape or env.observation_space.n args.action_shape = env.env.action_space.shape or env.env.action_space.n # should be N_FRAMES x H x W - print("Observations shape: ", args.state_shape) - print("Actions shape: ", args.action_shape) + print("Observations shape:", args.state_shape) + print("Actions shape:", args.action_shape) # make environments train_envs = SubprocVectorEnv([lambda: make_atari_env(args) for _ in range(args.training_num)]) @@ -79,7 +79,9 @@ def test_dqn(args=get_args()): target_update_freq=args.target_update_freq) # load a previous policy if args.resume_path: - policy.load_state_dict(torch.load(args.resume_path)) + policy.load_state_dict(torch.load( + args.resume_path, map_location=args.device + )) print("Loaded agent from: ", args.resume_path) # replay buffer: `save_last_obs` and `stack_num` can be removed together # when you have enough RAM diff --git a/examples/box2d/README.md b/examples/box2d/README.md index 53300dd1c..f438b2a4b 100644 --- a/examples/box2d/README.md +++ b/examples/box2d/README.md @@ -1,6 +1,6 @@ # Bipedal-Hardcore-SAC -- Our default choice: remove the done flag penalty, will soon converge to \~270 reward within 100 epochs (10M env steps, 3~4 hours, see the image below) +- Our default choice: remove the done flag penalty; it then converges to \~280 reward within 100 epochs (10M env steps, 3~4 hours, see the image below) - If the done penalty is not removed, it converges much slower than before, about 200 epochs (20M env steps) to reach the same performance (\~200 reward) ![](results/sac/BipedalHardcore.png) diff --git a/examples/box2d/bipedal_hardcore_sac.py b/examples/box2d/bipedal_hardcore_sac.py index 0ddc682f3..bd38524e6 100644 --- a/examples/box2d/bipedal_hardcore_sac.py +++ b/examples/box2d/bipedal_hardcore_sac.py @@ -6,11 +6,11 @@ import numpy as np from torch.utils.tensorboard import SummaryWriter +from tianshou.policy import SACPolicy +from tianshou.utils.net.common import Net from tianshou.env import SubprocVectorEnv from tianshou.trainer import offpolicy_trainer from tianshou.data import Collector, ReplayBuffer -from tianshou.policy import SACPolicy -from tianshou.utils.net.common import Net from tianshou.utils.net.continuous import ActorProb, Critic @@ -24,8 +24,8 @@ def get_args(): parser.add_argument('--gamma', type=float, default=0.99) parser.add_argument('--tau', type=float, default=0.005) parser.add_argument('--alpha', type=float, default=0.1) - parser.add_argument('--auto_alpha', type=int, default=1) - parser.add_argument('--alpha_lr', type=float, default=3e-4) + parser.add_argument('--auto-alpha', type=int, default=1) + parser.add_argument('--alpha-lr', type=float, default=3e-4) parser.add_argument('--epoch', type=int, default=100) parser.add_argument('--step-per-epoch', type=int, default=10000) parser.add_argument('--collect-per-step', type=int, default=10)
@@ -35,54 +35,50 @@ def get_args(): parser.add_argument('--test-num', type=int, default=100) parser.add_argument('--logdir', type=str, default='log') parser.add_argument('--render', type=float, default=0.) - parser.add_argument('--rew-norm', type=int, default=0) - parser.add_argument('--ignore-done', type=int, default=0) parser.add_argument('--n-step', type=int, default=4) parser.add_argument( '--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu') - parser.add_argument('--resume_path', type=str, default=None) + parser.add_argument('--resume-path', type=str, default=None) return parser.parse_args() -class EnvWrapper(object): - """Env wrapper for reward scale, action repeat and action noise""" +class Wrapper(gym.Wrapper): + """Env wrapper for reward scale, action repeat and removing done penalty""" - def __init__(self, task, action_repeat=3, reward_scale=5, act_noise=0.0): - self._env = gym.make(task) + def __init__(self, env, action_repeat=3, reward_scale=5, rm_done=True): + super().__init__(env) self.action_repeat = action_repeat self.reward_scale = reward_scale - self.act_noise = act_noise - - def __getattr__(self, name): - return getattr(self._env, name) + self.rm_done = rm_done def step(self, action): - # add action noise - action += self.act_noise * (-2 * np.random.random(4) + 1) r = 0.0 for _ in range(self.action_repeat): - obs_, reward_, done_, info_ = self._env.step(action) + obs, reward, done, info = self.env.step(action) # remove done reward penalty - if done_: + if not done or not self.rm_done: + r = r + reward + if done: break - r = r + reward_ # scale reward - return obs_, self.reward_scale * r, done_, info_ + return obs, self.reward_scale * r, done, info def test_sac_bipedal(args=get_args()): - env = EnvWrapper(args.task) + env = Wrapper(gym.make(args.task)) args.state_shape = env.observation_space.shape or env.observation_space.n args.action_shape = env.action_space.shape or env.action_space.n args.max_action = env.action_space.high[0] - train_envs = SubprocVectorEnv( - [lambda: EnvWrapper(args.task) for _ in range(args.training_num)]) + train_envs = SubprocVectorEnv([ + lambda: Wrapper(gym.make(args.task)) + for _ in range(args.training_num)]) # test_envs = gym.make(args.task) - test_envs = SubprocVectorEnv([lambda: EnvWrapper(args.task, reward_scale=1) - for _ in range(args.test_num)]) + test_envs = SubprocVectorEnv([ + lambda: Wrapper(gym.make(args.task), reward_scale=1, rm_done=False) + for _ in range(args.test_num)]) # seed np.random.seed(args.seed) @@ -117,8 +113,6 @@ def test_sac_bipedal(args=get_args()): actor, actor_optim, critic1, critic1_optim, critic2, critic2_optim, action_range=[env.action_space.low[0], env.action_space.high[0]], tau=args.tau, gamma=args.gamma, alpha=args.alpha, - reward_normalization=args.rew_norm, - ignore_done=args.ignore_done, estimation_step=args.n_step) # load a previous policy if args.resume_path: diff --git a/examples/box2d/mcc_sac.py b/examples/box2d/mcc_sac.py index b9481a5cf..3614fb836 100644 --- a/examples/box2d/mcc_sac.py +++ b/examples/box2d/mcc_sac.py @@ -67,11 +67,13 @@ def test_sac(args=get_args()): args.max_action, args.device, unbounded=True ).to(args.device) actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr) - net = Net(args.layer_num, args.state_shape, - args.action_shape, concat=True, device=args.device) - critic1 = Critic(net, args.device).to(args.device) + net_c1 = Net(args.layer_num, args.state_shape, + args.action_shape, concat=True, device=args.device) + critic1 = Critic(net_c1, 
args.device).to(args.device) critic1_optim = torch.optim.Adam(critic1.parameters(), lr=args.critic_lr) - critic2 = Critic(net, args.device).to(args.device) + net_c2 = Net(args.layer_num, args.state_shape, + args.action_shape, concat=True, device=args.device) + critic2 = Critic(net_c2, args.device).to(args.device) critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr) if args.auto_alpha: diff --git a/examples/box2d/results/sac/BipedalHardcore.png b/examples/box2d/results/sac/BipedalHardcore.png index 86367e591..d31867fcc 100644 Binary files a/examples/box2d/results/sac/BipedalHardcore.png and b/examples/box2d/results/sac/BipedalHardcore.png differ diff --git a/examples/mujoco/README.md b/examples/mujoco/README.md index 8a63d2729..3f116fd6b 100644 --- a/examples/mujoco/README.md +++ b/examples/mujoco/README.md @@ -1,3 +1,27 @@ -Result of Ant-v2: +# Mujoco Results + + + +## SAC (single run) + +The best reward is computed from the returns of 100 episodes in the test phase. + +SAC on Swimmer-v3 always plateaus at 47\~48. + +| task | 3M best reward | parameters | time cost (3M) | | -------------- | ----------------- | ------------------------------------------------------- | -------------- | | HalfCheetah-v3 | 10157.70 ± 171.70 | `python3 mujoco_sac.py --task HalfCheetah-v3` | 2~3h | | Walker2d-v3 | 5143.04 ± 15.57 | `python3 mujoco_sac.py --task Walker2d-v3` | 2~3h | | Hopper-v3 | 3604.19 ± 169.55 | `python3 mujoco_sac.py --task Hopper-v3` | 2~3h | | Humanoid-v3 | 6579.20 ± 1470.57 | `python3 mujoco_sac.py --task Humanoid-v3 --alpha 0.05` | 2~3h | | Ant-v3 | 6281.65 ± 686.28 | `python3 mujoco_sac.py --task Ant-v3` | 2~3h | + +![](results/sac/all.png) + +### Which parts are important? + +0. DO NOT share the same network between the two critic networks. +1. The sigma (of the Gaussian policy) MUST be conditioned on input. +2. The network size (hidden layer size) should not be less than 256. +3.
The deterministic evaluation helps a lot :) -![](/docs/_static/images/Ant-v2.png) \ No newline at end of file diff --git a/examples/mujoco/ant_v2_sac.py b/examples/mujoco/mujoco_sac.py similarity index 53% rename from examples/mujoco/ant_v2_sac.py rename to examples/mujoco/mujoco_sac.py index 819c7454e..2a80d4fc8 100644 --- a/examples/mujoco/ant_v2_sac.py +++ b/examples/mujoco/mujoco_sac.py @@ -16,27 +16,36 @@ def get_args(): parser = argparse.ArgumentParser() - parser.add_argument('--task', type=str, default='Ant-v2') + parser.add_argument('--task', type=str, default='Ant-v3') parser.add_argument('--seed', type=int, default=1626) - parser.add_argument('--buffer-size', type=int, default=20000) + parser.add_argument('--buffer-size', type=int, default=1000000) parser.add_argument('--actor-lr', type=float, default=3e-4) - parser.add_argument('--critic-lr', type=float, default=1e-3) + parser.add_argument('--critic-lr', type=float, default=3e-4) parser.add_argument('--gamma', type=float, default=0.99) parser.add_argument('--tau', type=float, default=0.005) parser.add_argument('--alpha', type=float, default=0.2) + parser.add_argument('--auto-alpha', default=False, action='store_true') + parser.add_argument('--alpha-lr', type=float, default=3e-4) + parser.add_argument('--n-step', type=int, default=2) parser.add_argument('--epoch', type=int, default=100) - parser.add_argument('--step-per-epoch', type=int, default=2400) - parser.add_argument('--collect-per-step', type=int, default=10) - parser.add_argument('--batch-size', type=int, default=128) + parser.add_argument('--step-per-epoch', type=int, default=10000) + parser.add_argument('--collect-per-step', type=int, default=4) + parser.add_argument('--update-per-step', type=int, default=1) + parser.add_argument('--pre-collect-step', type=int, default=10000) + parser.add_argument('--batch-size', type=int, default=256) + parser.add_argument('--hidden-layer-size', type=int, default=256) parser.add_argument('--layer-num', type=int, default=1) - parser.add_argument('--training-num', type=int, default=8) + parser.add_argument('--training-num', type=int, default=16) parser.add_argument('--test-num', type=int, default=100) parser.add_argument('--logdir', type=str, default='log') parser.add_argument('--render', type=float, default=0.) 
- parser.add_argument('--rew-norm', type=bool, default=True) + parser.add_argument('--log-interval', type=int, default=1000) parser.add_argument( '--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu') + parser.add_argument('--resume-path', type=str, default=None) + parser.add_argument('--watch', default=False, action='store_true', + help='watch the play of pre-trained policy only') return parser.parse_args() @@ -45,6 +54,10 @@ def test_sac(args=get_args()): args.state_shape = env.observation_space.shape or env.observation_space.n args.action_shape = env.action_space.shape or env.action_space.n args.max_action = env.action_space.high[0] + print("Observations shape:", args.state_shape) + print("Actions shape:", args.action_shape) + print("Action range:", np.min(env.action_space.low), + np.max(env.action_space.high)) # train_envs = gym.make(args.task) train_envs = SubprocVectorEnv( [lambda: gym.make(args.task) for _ in range(args.training_num)]) @@ -57,53 +70,84 @@ def test_sac(args=get_args()): train_envs.seed(args.seed) test_envs.seed(args.seed) # model - net = Net(args.layer_num, args.state_shape, device=args.device) + net = Net(args.layer_num, args.state_shape, device=args.device, + hidden_layer_size=args.hidden_layer_size) actor = ActorProb( - net, args.action_shape, - args.max_action, args.device, unbounded=True + net, args.action_shape, args.max_action, args.device, unbounded=True, + hidden_layer_size=args.hidden_layer_size, conditioned_sigma=True, ).to(args.device) actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr) - net = Net(args.layer_num, args.state_shape, - args.action_shape, concat=True, device=args.device) - critic1 = Critic(net, args.device).to(args.device) + net_c1 = Net(args.layer_num, args.state_shape, args.action_shape, + concat=True, device=args.device, + hidden_layer_size=args.hidden_layer_size) + critic1 = Critic( + net_c1, args.device, hidden_layer_size=args.hidden_layer_size + ).to(args.device) critic1_optim = torch.optim.Adam(critic1.parameters(), lr=args.critic_lr) - critic2 = Critic(net, args.device).to(args.device) + net_c2 = Net(args.layer_num, args.state_shape, args.action_shape, + concat=True, device=args.device, + hidden_layer_size=args.hidden_layer_size) + critic2 = Critic( + net_c2, args.device, hidden_layer_size=args.hidden_layer_size + ).to(args.device) critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr) + + if args.auto_alpha: + target_entropy = -np.prod(env.action_space.shape) + log_alpha = torch.zeros(1, requires_grad=True, device=args.device) + alpha_optim = torch.optim.Adam([log_alpha], lr=args.alpha_lr) + args.alpha = (target_entropy, log_alpha, alpha_optim) + policy = SACPolicy( actor, actor_optim, critic1, critic1_optim, critic2, critic2_optim, action_range=[env.action_space.low[0], env.action_space.high[0]], tau=args.tau, gamma=args.gamma, alpha=args.alpha, - reward_normalization=args.rew_norm, ignore_done=True) + estimation_step=args.n_step) + # load a previous policy + if args.resume_path: + policy.load_state_dict(torch.load( + args.resume_path, map_location=args.device + )) + print("Loaded agent from: ", args.resume_path) + # collector train_collector = Collector( policy, train_envs, ReplayBuffer(args.buffer_size)) test_collector = Collector(policy, test_envs) - # train_collector.collect(n_step=args.buffer_size) # log log_path = os.path.join(args.logdir, args.task, 'sac') writer = SummaryWriter(log_path) + def watch(): + # watch agent's performance + print("Testing agent ...") + 
policy.eval() + test_envs.seed(args.seed) + test_collector.reset() + result = test_collector.collect(n_episode=[1] * args.test_num, + render=args.render) + pprint.pprint(result) + def save_fn(policy): torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth')) def stop_fn(mean_rewards): - return mean_rewards >= env.spec.reward_threshold + return False + + if args.watch: + watch() + exit(0) # trainer + train_collector.collect(n_step=args.pre_collect_step, random=True) result = offpolicy_trainer( policy, train_collector, test_collector, args.epoch, args.step_per_epoch, args.collect_per_step, args.test_num, - args.batch_size, stop_fn=stop_fn, save_fn=save_fn, writer=writer) - assert stop_fn(result['best_reward']) - if __name__ == '__main__': - pprint.pprint(result) - # Let's watch its performance! - policy.eval() - test_envs.seed(args.seed) - test_collector.reset() - result = test_collector.collect(n_episode=[1] * args.test_num, - render=args.render) - print(f'Final reward: {result["rew"]}, length: {result["len"]}') + args.batch_size, args.update_per_step, + stop_fn=stop_fn, save_fn=save_fn, writer=writer, + log_interval=args.log_interval) + pprint.pprint(result) + watch() if __name__ == '__main__': diff --git a/examples/mujoco/results/sac/all.png b/examples/mujoco/results/sac/all.png new file mode 100644 index 000000000..7f314f46f Binary files /dev/null and b/examples/mujoco/results/sac/all.png differ diff --git a/examples/mujoco/ant_v2_ddpg.py b/examples/mujoco/runnable/ant_v2_ddpg.py similarity index 100% rename from examples/mujoco/ant_v2_ddpg.py rename to examples/mujoco/runnable/ant_v2_ddpg.py diff --git a/examples/mujoco/ant_v2_td3.py b/examples/mujoco/runnable/ant_v2_td3.py similarity index 100% rename from examples/mujoco/ant_v2_td3.py rename to examples/mujoco/runnable/ant_v2_td3.py diff --git a/examples/mujoco/halfcheetahBullet_v0_sac.py b/examples/mujoco/runnable/halfcheetahBullet_v0_sac.py similarity index 97% rename from examples/mujoco/halfcheetahBullet_v0_sac.py rename to examples/mujoco/runnable/halfcheetahBullet_v0_sac.py index 05591676f..9ff9c48ae 100644 --- a/examples/mujoco/halfcheetahBullet_v0_sac.py +++ b/examples/mujoco/runnable/halfcheetahBullet_v0_sac.py @@ -71,6 +71,8 @@ def test_sac(args=get_args()): args.action_shape, concat=True, device=args.device) critic1 = Critic(net, args.device).to(args.device) critic1_optim = torch.optim.Adam(critic1.parameters(), lr=args.critic_lr) + net = Net(args.layer_num, args.state_shape, + args.action_shape, concat=True, device=args.device) critic2 = Critic(net, args.device).to(args.device) critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr) policy = SACPolicy( diff --git a/examples/mujoco/mujoco/__init__.py b/examples/mujoco/runnable/mujoco/__init__.py similarity index 100% rename from examples/mujoco/mujoco/__init__.py rename to examples/mujoco/runnable/mujoco/__init__.py diff --git a/examples/mujoco/mujoco/assets/point.xml b/examples/mujoco/runnable/mujoco/assets/point.xml similarity index 100% rename from examples/mujoco/mujoco/assets/point.xml rename to examples/mujoco/runnable/mujoco/assets/point.xml diff --git a/examples/mujoco/mujoco/maze_env_utils.py b/examples/mujoco/runnable/mujoco/maze_env_utils.py similarity index 100% rename from examples/mujoco/mujoco/maze_env_utils.py rename to examples/mujoco/runnable/mujoco/maze_env_utils.py diff --git a/examples/mujoco/mujoco/point.py b/examples/mujoco/runnable/mujoco/point.py similarity index 100% rename from examples/mujoco/mujoco/point.py rename 
to examples/mujoco/runnable/mujoco/point.py diff --git a/examples/mujoco/mujoco/point_maze_env.py b/examples/mujoco/runnable/mujoco/point_maze_env.py similarity index 100% rename from examples/mujoco/mujoco/point_maze_env.py rename to examples/mujoco/runnable/mujoco/point_maze_env.py diff --git a/examples/mujoco/mujoco/register.py b/examples/mujoco/runnable/mujoco/register.py similarity index 100% rename from examples/mujoco/mujoco/register.py rename to examples/mujoco/runnable/mujoco/register.py diff --git a/examples/mujoco/point_maze_td3.py b/examples/mujoco/runnable/point_maze_td3.py similarity index 97% rename from examples/mujoco/point_maze_td3.py rename to examples/mujoco/runnable/point_maze_td3.py index ff42716c5..a02a4c309 100644 --- a/examples/mujoco/point_maze_td3.py +++ b/examples/mujoco/runnable/point_maze_td3.py @@ -72,6 +72,8 @@ def test_td3(args=get_args()): args.action_shape, concat=True, device=args.device) critic1 = Critic(net, args.device).to(args.device) critic1_optim = torch.optim.Adam(critic1.parameters(), lr=args.critic_lr) + net = Net(args.layer_num, args.state_shape, + args.action_shape, concat=True, device=args.device) critic2 = Critic(net, args.device).to(args.device) critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr) policy = TD3Policy( diff --git a/test/continuous/test_sac_with_il.py b/test/continuous/test_sac_with_il.py index 865a1c8e4..1c9b794e3 100644 --- a/test/continuous/test_sac_with_il.py +++ b/test/continuous/test_sac_with_il.py @@ -70,11 +70,13 @@ def test_sac_with_il(args=get_args()): net, args.action_shape, args.max_action, args.device, unbounded=True ).to(args.device) actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr) - net = Net(args.layer_num, args.state_shape, - args.action_shape, concat=True, device=args.device) - critic1 = Critic(net, args.device).to(args.device) + net_c1 = Net(args.layer_num, args.state_shape, + args.action_shape, concat=True, device=args.device) + critic1 = Critic(net_c1, args.device).to(args.device) critic1_optim = torch.optim.Adam(critic1.parameters(), lr=args.critic_lr) - critic2 = Critic(net, args.device).to(args.device) + net_c2 = Net(args.layer_num, args.state_shape, + args.action_shape, concat=True, device=args.device) + critic2 = Critic(net_c2, args.device).to(args.device) critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr) policy = SACPolicy( actor, actor_optim, critic1, critic1_optim, critic2, critic2_optim, diff --git a/test/continuous/test_td3.py b/test/continuous/test_td3.py index 847971474..6c4378f6d 100644 --- a/test/continuous/test_td3.py +++ b/test/continuous/test_td3.py @@ -74,11 +74,13 @@ def test_td3(args=get_args()): args.max_action, args.device ).to(args.device) actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr) - net = Net(args.layer_num, args.state_shape, - args.action_shape, concat=True, device=args.device) - critic1 = Critic(net, args.device).to(args.device) + net_c1 = Net(args.layer_num, args.state_shape, + args.action_shape, concat=True, device=args.device) + critic1 = Critic(net_c1, args.device).to(args.device) critic1_optim = torch.optim.Adam(critic1.parameters(), lr=args.critic_lr) - critic2 = Critic(net, args.device).to(args.device) + net_c2 = Net(args.layer_num, args.state_shape, + args.action_shape, concat=True, device=args.device) + critic2 = Critic(net_c2, args.device).to(args.device) critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr) policy = TD3Policy( actor, actor_optim, critic1, 
critic1_optim, critic2, critic2_optim, diff --git a/test/discrete/test_sac.py b/test/discrete/test_sac.py index 4b8607386..fcfa3d1a9 100644 --- a/test/discrete/test_sac.py +++ b/test/discrete/test_sac.py @@ -62,11 +62,11 @@ def test_discrete_sac(args=get_args()): net = Net(args.layer_num, args.state_shape, device=args.device) actor = Actor(net, args.action_shape, softmax_output=False).to(args.device) actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr) - net = Net(args.layer_num, args.state_shape, device=args.device) - critic1 = Critic(net, last_size=args.action_shape).to(args.device) + net_c1 = Net(args.layer_num, args.state_shape, device=args.device) + critic1 = Critic(net_c1, last_size=args.action_shape).to(args.device) critic1_optim = torch.optim.Adam(critic1.parameters(), lr=args.critic_lr) - net = Net(args.layer_num, args.state_shape, device=args.device) - critic2 = Critic(net, last_size=args.action_shape).to(args.device) + net_c2 = Net(args.layer_num, args.state_shape, device=args.device) + critic2 = Critic(net_c2, last_size=args.action_shape).to(args.device) critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr) # better not to use auto alpha in CartPole diff --git a/tianshou/policy/modelfree/sac.py b/tianshou/policy/modelfree/sac.py index 8d1d72369..014fbe6d5 100644 --- a/tianshou/policy/modelfree/sac.py +++ b/tianshou/policy/modelfree/sac.py @@ -38,6 +38,9 @@ class SACPolicy(DDPGPolicy): defaults to False. :param BaseNoise exploration_noise: add a noise to action for exploration, defaults to None. This is useful when solving hard-exploration problem. + :param bool deterministic_eval: whether to use deterministic action (mean + of Gaussian policy) instead of stochastic action sampled by the policy, + defaults to True. .. seealso:: @@ -63,6 +66,7 @@ def __init__( ignore_done: bool = False, estimation_step: int = 1, exploration_noise: Optional[BaseNoise] = None, + deterministic_eval: bool = True, **kwargs: Any, ) -> None: super().__init__(None, None, None, None, action_range, tau, gamma, @@ -86,6 +90,7 @@ def __init__( else: self._alpha = alpha + self._deterministic_eval = deterministic_eval self.__eps = np.finfo(np.float32).eps.item() def train(self, mode: bool = True) -> "SACPolicy": @@ -116,13 +121,16 @@ def forward( # type: ignore logits, h = self.actor(obs, state=state, info=batch.info) assert isinstance(logits, tuple) dist = Independent(Normal(*logits), 1) - x = dist.rsample() + if self._deterministic_eval and not self.training: + x = logits[0] + else: + x = dist.rsample() y = torch.tanh(x) act = y * self._action_scale + self._action_bias y = self._action_scale * (1 - y.pow(2)) + self.__eps log_prob = dist.log_prob(x).unsqueeze(-1) log_prob = log_prob - torch.log(y).sum(-1, keepdim=True) - if self._noise is not None and not self.updating: + if self._noise is not None and self.training and not self.updating: act += to_torch_as(self._noise(act.shape), act) act = act.clamp(self._range[0], self._range[1]) return Batch( diff --git a/tianshou/utils/net/continuous.py b/tianshou/utils/net/continuous.py index 7afd9eb9c..3ab0977b7 100644 --- a/tianshou/utils/net/continuous.py +++ b/tianshou/utils/net/continuous.py @@ -6,6 +6,10 @@ from tianshou.data import to_torch, to_torch_as +SIGMA_MIN = -20 +SIGMA_MAX = 2 + + class Actor(nn.Module): """Simple actor network with MLP. 
@@ -89,12 +93,17 @@ def __init__( device: Union[str, int, torch.device] = "cpu", unbounded: bool = False, hidden_layer_size: int = 128, + conditioned_sigma: bool = False, ) -> None: super().__init__() self.preprocess = preprocess_net self.device = device self.mu = nn.Linear(hidden_layer_size, np.prod(action_shape)) - self.sigma = nn.Parameter(torch.zeros(np.prod(action_shape), 1)) + self._c_sigma = conditioned_sigma + if conditioned_sigma: + self.sigma = nn.Linear(hidden_layer_size, np.prod(action_shape)) + else: + self.sigma = nn.Parameter(torch.zeros(np.prod(action_shape), 1)) self._max = max_action self._unbounded = unbounded @@ -109,9 +118,14 @@ def forward( mu = self.mu(logits) if not self._unbounded: mu = self._max * torch.tanh(mu) - shape = [1] * len(mu.shape) - shape[1] = -1 - sigma = (self.sigma.view(shape) + torch.zeros_like(mu)).exp() + if self._c_sigma: + sigma = torch.clamp( + self.sigma(logits), min=SIGMA_MIN, max=SIGMA_MAX + ).exp() + else: + shape = [1] * len(mu.shape) + shape[1] = -1 + sigma = (self.sigma.view(shape) + torch.zeros_like(mu)).exp() return (mu, sigma), state @@ -131,6 +145,7 @@ def __init__( device: Union[str, int, torch.device] = "cpu", unbounded: bool = False, hidden_layer_size: int = 128, + conditioned_sigma: bool = False, ) -> None: super().__init__() self.device = device @@ -141,7 +156,11 @@ def __init__( batch_first=True, ) self.mu = nn.Linear(hidden_layer_size, np.prod(action_shape)) - self.sigma = nn.Parameter(torch.zeros(np.prod(action_shape), 1)) + self._c_sigma = conditioned_sigma + if conditioned_sigma: + self.sigma = nn.Linear(hidden_layer_size, np.prod(action_shape)) + else: + self.sigma = nn.Parameter(torch.zeros(np.prod(action_shape), 1)) self._max = max_action self._unbounded = unbounded @@ -170,9 +189,14 @@ def forward( mu = self.mu(logits) if not self._unbounded: mu = self._max * torch.tanh(mu) - shape = [1] * len(mu.shape) - shape[1] = -1 - sigma = (self.sigma.view(shape) + torch.zeros_like(mu)).exp() + if self._c_sigma: + sigma = torch.clamp( + self.sigma(logits), min=SIGMA_MIN, max=SIGMA_MAX + ).exp() + else: + shape = [1] * len(mu.shape) + shape[1] = -1 + sigma = (self.sigma.view(shape) + torch.zeros_like(mu)).exp() # please ensure the first dim is batch size: [bsz, len, ...] return (mu, sigma), {"h": h.transpose(0, 1).detach(), "c": c.transpose(0, 1).detach()}
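For reference, here is how the "Which parts are important?" checklist in `examples/mujoco/README.md` maps onto the code introduced by this diff. This is only a minimal sketch, assembled from calls that already appear above (`Net`, `ActorProb`, `Critic`, `SACPolicy`); the hyperparameter values are the defaults of `mujoco_sac.py`, and the task name is just an example.

```python
# Minimal SAC model-construction sketch, mirroring examples/mujoco/mujoco_sac.py.
# Assumes the tianshou API at this revision; values are the script defaults.
import gym
import torch

from tianshou.policy import SACPolicy
from tianshou.utils.net.common import Net
from tianshou.utils.net.continuous import ActorProb, Critic

device = 'cuda' if torch.cuda.is_available() else 'cpu'
env = gym.make('Ant-v3')  # any continuous-control task works here
state_shape = env.observation_space.shape
action_shape = env.action_space.shape
max_action = env.action_space.high[0]

# Items 1 and 2: Gaussian policy with sigma conditioned on the observation,
# hidden layer size 256.
net_a = Net(1, state_shape, device=device, hidden_layer_size=256)
actor = ActorProb(net_a, action_shape, max_action, device, unbounded=True,
                  hidden_layer_size=256, conditioned_sigma=True).to(device)
actor_optim = torch.optim.Adam(actor.parameters(), lr=3e-4)

# Item 0: two critics built on two *separate* preprocessing networks,
# never on a single shared one.
net_c1 = Net(1, state_shape, action_shape, concat=True, device=device,
             hidden_layer_size=256)
net_c2 = Net(1, state_shape, action_shape, concat=True, device=device,
             hidden_layer_size=256)
critic1 = Critic(net_c1, device, hidden_layer_size=256).to(device)
critic2 = Critic(net_c2, device, hidden_layer_size=256).to(device)
critic1_optim = torch.optim.Adam(critic1.parameters(), lr=3e-4)
critic2_optim = torch.optim.Adam(critic2.parameters(), lr=3e-4)

# Item 3: deterministic evaluation is the default (deterministic_eval=True),
# so at test time the policy outputs tanh(mu) instead of a sampled action.
policy = SACPolicy(
    actor, actor_optim, critic1, critic1_optim, critic2, critic2_optim,
    action_range=[env.action_space.low[0], env.action_space.high[0]],
    tau=0.005, gamma=0.99, alpha=0.2, estimation_step=2)
```

Constructing `net_c1` and `net_c2` separately, instead of reusing one `Net` instance for both critics, is exactly the change this diff applies to `mujoco_sac.py`, `mcc_sac.py`, `halfcheetahBullet_v0_sac.py`, `point_maze_td3.py`, and the SAC/TD3 tests.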