SAC mujoco result #246


Merged (15 commits) on Nov 9, 2020
Binary file removed docs/_static/images/Ant-v2.png
Binary file modified docs/_static/images/concepts_arch2.png
8 changes: 5 additions & 3 deletions examples/atari/atari_dqn.py
@@ -58,8 +58,8 @@ def test_dqn(args=get_args()):
args.state_shape = env.observation_space.shape or env.observation_space.n
args.action_shape = env.env.action_space.shape or env.env.action_space.n
# should be N_FRAMES x H x W
print("Observations shape: ", args.state_shape)
print("Actions shape: ", args.action_shape)
print("Observations shape:", args.state_shape)
print("Actions shape:", args.action_shape)
# make environments
train_envs = SubprocVectorEnv([lambda: make_atari_env(args)
for _ in range(args.training_num)])
@@ -79,7 +79,9 @@ def test_dqn(args=get_args()):
target_update_freq=args.target_update_freq)
# load a previous policy
if args.resume_path:
policy.load_state_dict(torch.load(args.resume_path))
policy.load_state_dict(torch.load(
args.resume_path, map_location=args.device
))
print("Loaded agent from: ", args.resume_path)
# replay buffer: `save_last_obs` and `stack_num` can be removed together
# when you have enough RAM
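The `map_location` argument added above matters when a checkpoint trained on a GPU is later loaded on a CPU-only machine. A minimal, self-contained sketch of the pattern (the module and path here are stand-ins, not part of this PR):

```python
import torch
import torch.nn as nn

# Stand-in for the policy object; any nn.Module (e.g. Tianshou's DQNPolicy)
# is saved and restored the same way.
policy = nn.Linear(4, 2)
device = "cuda" if torch.cuda.is_available() else "cpu"
ckpt = "policy.pth"  # hypothetical path

torch.save(policy.state_dict(), ckpt)

# map_location remaps tensors saved on one device (e.g. cuda:0) onto the
# device available at load time, so a GPU-trained checkpoint also loads
# on a CPU-only machine.
policy.load_state_dict(torch.load(ckpt, map_location=device))
```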
2 changes: 1 addition & 1 deletion examples/box2d/README.md
@@ -1,6 +1,6 @@
# Bipedal-Hardcore-SAC

- Our default choice: remove the done flag penalty; the agent then converges to \~270 reward within 100 epochs (10M env steps, 3~4 hours; see the image below)
- Our default choice: remove the done flag penalty; the agent then converges to \~280 reward within 100 epochs (10M env steps, 3~4 hours; see the image below)
- If the done penalty is not removed, convergence is much slower, taking about 200 epochs (20M env steps) to reach the same performance (\~200 reward)

![](results/sac/BipedalHardcore.png)
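As context for the bullets above: BipedalWalkerHardcore gives a large negative reward (-100) when the robot falls, and "removing the done flag penalty" means not adding that terminal-step reward to the return. A minimal sketch of the idea only (the PR's full `Wrapper` in the next file also adds action repeat and reward scaling; the environment id here is just an example):

```python
import gym


class RemoveDonePenalty(gym.Wrapper):
    """Drop the reward of the terminal step, i.e. the -100 falling penalty."""

    def step(self, action):
        obs, reward, done, info = self.env.step(action)
        if done:
            reward = 0.0  # ignore the terminal reward entirely
        return obs, reward, done, info


env = RemoveDonePenalty(gym.make("BipedalWalkerHardcore-v3"))
```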
50 changes: 22 additions & 28 deletions examples/box2d/bipedal_hardcore_sac.py
@@ -6,11 +6,11 @@
import numpy as np
from torch.utils.tensorboard import SummaryWriter

from tianshou.policy import SACPolicy
from tianshou.utils.net.common import Net
from tianshou.env import SubprocVectorEnv
from tianshou.trainer import offpolicy_trainer
from tianshou.data import Collector, ReplayBuffer
from tianshou.policy import SACPolicy
from tianshou.utils.net.common import Net
from tianshou.utils.net.continuous import ActorProb, Critic


@@ -24,8 +24,8 @@ def get_args():
parser.add_argument('--gamma', type=float, default=0.99)
parser.add_argument('--tau', type=float, default=0.005)
parser.add_argument('--alpha', type=float, default=0.1)
parser.add_argument('--auto_alpha', type=int, default=1)
parser.add_argument('--alpha_lr', type=float, default=3e-4)
parser.add_argument('--auto-alpha', type=int, default=1)
parser.add_argument('--alpha-lr', type=float, default=3e-4)
parser.add_argument('--epoch', type=int, default=100)
parser.add_argument('--step-per-epoch', type=int, default=10000)
parser.add_argument('--collect-per-step', type=int, default=10)
@@ -35,54 +35,50 @@
parser.add_argument('--test-num', type=int, default=100)
parser.add_argument('--logdir', type=str, default='log')
parser.add_argument('--render', type=float, default=0.)
parser.add_argument('--rew-norm', type=int, default=0)
parser.add_argument('--ignore-done', type=int, default=0)
parser.add_argument('--n-step', type=int, default=4)
parser.add_argument(
'--device', type=str,
default='cuda' if torch.cuda.is_available() else 'cpu')
parser.add_argument('--resume_path', type=str, default=None)
parser.add_argument('--resume-path', type=str, default=None)
return parser.parse_args()


class EnvWrapper(object):
"""Env wrapper for reward scale, action repeat and action noise"""
class Wrapper(gym.Wrapper):
"""Env wrapper for reward scale, action repeat and removing done penalty"""

def __init__(self, task, action_repeat=3, reward_scale=5, act_noise=0.0):
self._env = gym.make(task)
def __init__(self, env, action_repeat=3, reward_scale=5, rm_done=True):
super().__init__(env)
self.action_repeat = action_repeat
self.reward_scale = reward_scale
self.act_noise = act_noise

def __getattr__(self, name):
return getattr(self._env, name)
self.rm_done = rm_done

def step(self, action):
# add action noise
action += self.act_noise * (-2 * np.random.random(4) + 1)
r = 0.0
for _ in range(self.action_repeat):
obs_, reward_, done_, info_ = self._env.step(action)
obs, reward, done, info = self.env.step(action)
# remove done reward penalty
if done_:
if not done or not self.rm_done:
r = r + reward
if done:
break
r = r + reward_
# scale reward
return obs_, self.reward_scale * r, done_, info_
return obs, self.reward_scale * r, done, info


def test_sac_bipedal(args=get_args()):
env = EnvWrapper(args.task)
env = Wrapper(gym.make(args.task))

args.state_shape = env.observation_space.shape or env.observation_space.n
args.action_shape = env.action_space.shape or env.action_space.n
args.max_action = env.action_space.high[0]

train_envs = SubprocVectorEnv(
[lambda: EnvWrapper(args.task) for _ in range(args.training_num)])
train_envs = SubprocVectorEnv([
lambda: Wrapper(gym.make(args.task))
for _ in range(args.training_num)])
# test_envs = gym.make(args.task)
test_envs = SubprocVectorEnv([lambda: EnvWrapper(args.task, reward_scale=1)
for _ in range(args.test_num)])
test_envs = SubprocVectorEnv([
lambda: Wrapper(gym.make(args.task), reward_scale=1, rm_done=False)
for _ in range(args.test_num)])

# seed
np.random.seed(args.seed)
@@ -117,8 +113,6 @@ def test_sac_bipedal(args=get_args()):
actor, actor_optim, critic1, critic1_optim, critic2, critic2_optim,
action_range=[env.action_space.low[0], env.action_space.high[0]],
tau=args.tau, gamma=args.gamma, alpha=args.alpha,
reward_normalization=args.rew_norm,
ignore_done=args.ignore_done,
estimation_step=args.n_step)
# load a previous policy
if args.resume_path:
10 changes: 6 additions & 4 deletions examples/box2d/mcc_sac.py
@@ -67,11 +67,13 @@ def test_sac(args=get_args()):
args.max_action, args.device, unbounded=True
).to(args.device)
actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr)
net = Net(args.layer_num, args.state_shape,
args.action_shape, concat=True, device=args.device)
critic1 = Critic(net, args.device).to(args.device)
net_c1 = Net(args.layer_num, args.state_shape,
args.action_shape, concat=True, device=args.device)
critic1 = Critic(net_c1, args.device).to(args.device)
critic1_optim = torch.optim.Adam(critic1.parameters(), lr=args.critic_lr)
critic2 = Critic(net, args.device).to(args.device)
net_c2 = Net(args.layer_num, args.state_shape,
args.action_shape, concat=True, device=args.device)
critic2 = Critic(net_c2, args.device).to(args.device)
critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr)

if args.auto_alpha:
Binary file modified examples/box2d/results/sac/BipedalHardcore.png
28 changes: 26 additions & 2 deletions examples/mujoco/README.md
@@ -1,3 +1,27 @@
Result of Ant-v2:
# Mujoco Result



## SAC (single run)

The best reward is computed from 100 episode returns in the test phase.

SAC on Swimmer-v3 consistently plateaus at a reward of 47\~48.

| task | 3M best reward | parameters | time cost (3M) |
| -------------- | ----------------- | ------------------------------------------------------- | -------------- |
| HalfCheetah-v3 | 10157.70 ± 171.70 | `python3 mujoco_sac.py --task HalfCheetah-v3` | 2~3h |
| Walker2d-v3 | 5143.04 ± 15.57 | `python3 mujoco_sac.py --task Walker2d-v3` | 2~3h |
| Hopper-v3 | 3604.19 ± 169.55 | `python3 mujoco_sac.py --task Hopper-v3` | 2~3h |
| Humanoid-v3 | 6579.20 ± 1470.57 | `python3 mujoco_sac.py --task Humanoid-v3 --alpha 0.05` | 2~3h |
| Ant-v3 | 6281.65 ± 686.28 | `python3 mujoco_sac.py --task Ant-v3` | 2~3h |

![](results/sac/all.png)

### Which parts are important?

0. DO NOT share the same network between the two critic networks.
1. The sigma (of the Gaussian policy) MUST be conditioned on input.
2. The network size should not be less than 256.
3. The deterministic evaluation helps a lot :)

![](/docs/_static/images/Ant-v2.png)
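A minimal sketch of points 0 and 1 from the list above, using the same Tianshou constructors that appear in the `mujoco_sac.py` diff below (the shapes here are placeholders, not taken from this PR):

```python
from tianshou.utils.net.common import Net
from tianshou.utils.net.continuous import ActorProb, Critic

device = "cpu"
state_shape, action_shape, max_action = (111,), (8,), 1.0  # placeholder shapes

# Point 1: the sigma head is conditioned on the input via conditioned_sigma=True,
# instead of being a free standalone parameter.
net_a = Net(1, state_shape, device=device, hidden_layer_size=256)
actor = ActorProb(net_a, action_shape, max_action, device, unbounded=True,
                  hidden_layer_size=256, conditioned_sigma=True).to(device)

# Point 0: two *separate* networks for the two critics; wrapping one shared
# Net instance in both Critic modules would tie their body weights.
net_c1 = Net(1, state_shape, action_shape, concat=True, device=device,
             hidden_layer_size=256)
net_c2 = Net(1, state_shape, action_shape, concat=True, device=device,
             hidden_layer_size=256)
critic1 = Critic(net_c1, device, hidden_layer_size=256).to(device)
critic2 = Critic(net_c2, device, hidden_layer_size=256).to(device)
```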
102 changes: 73 additions & 29 deletions examples/mujoco/ant_v2_sac.py → examples/mujoco/mujoco_sac.py
@@ -16,27 +16,36 @@

def get_args():
parser = argparse.ArgumentParser()
parser.add_argument('--task', type=str, default='Ant-v2')
parser.add_argument('--task', type=str, default='Ant-v3')
parser.add_argument('--seed', type=int, default=1626)
parser.add_argument('--buffer-size', type=int, default=20000)
parser.add_argument('--buffer-size', type=int, default=1000000)
parser.add_argument('--actor-lr', type=float, default=3e-4)
parser.add_argument('--critic-lr', type=float, default=1e-3)
parser.add_argument('--critic-lr', type=float, default=3e-4)
parser.add_argument('--gamma', type=float, default=0.99)
parser.add_argument('--tau', type=float, default=0.005)
parser.add_argument('--alpha', type=float, default=0.2)
parser.add_argument('--auto-alpha', default=False, action='store_true')
parser.add_argument('--alpha-lr', type=float, default=3e-4)
parser.add_argument('--n-step', type=int, default=2)
parser.add_argument('--epoch', type=int, default=100)
parser.add_argument('--step-per-epoch', type=int, default=2400)
parser.add_argument('--collect-per-step', type=int, default=10)
parser.add_argument('--batch-size', type=int, default=128)
parser.add_argument('--step-per-epoch', type=int, default=10000)
parser.add_argument('--collect-per-step', type=int, default=4)
parser.add_argument('--update-per-step', type=int, default=1)
parser.add_argument('--pre-collect-step', type=int, default=10000)
parser.add_argument('--batch-size', type=int, default=256)
parser.add_argument('--hidden-layer-size', type=int, default=256)
parser.add_argument('--layer-num', type=int, default=1)
parser.add_argument('--training-num', type=int, default=8)
parser.add_argument('--training-num', type=int, default=16)
parser.add_argument('--test-num', type=int, default=100)
parser.add_argument('--logdir', type=str, default='log')
parser.add_argument('--render', type=float, default=0.)
parser.add_argument('--rew-norm', type=bool, default=True)
parser.add_argument('--log-interval', type=int, default=1000)
parser.add_argument(
'--device', type=str,
default='cuda' if torch.cuda.is_available() else 'cpu')
parser.add_argument('--resume-path', type=str, default=None)
parser.add_argument('--watch', default=False, action='store_true',
help='watch the play of pre-trained policy only')
return parser.parse_args()


@@ -45,6 +54,10 @@ def test_sac(args=get_args()):
args.state_shape = env.observation_space.shape or env.observation_space.n
args.action_shape = env.action_space.shape or env.action_space.n
args.max_action = env.action_space.high[0]
print("Observations shape:", args.state_shape)
print("Actions shape:", args.action_shape)
print("Action range:", np.min(env.action_space.low),
np.max(env.action_space.high))
# train_envs = gym.make(args.task)
train_envs = SubprocVectorEnv(
[lambda: gym.make(args.task) for _ in range(args.training_num)])
@@ -57,53 +70,84 @@ def test_sac(args=get_args()):
train_envs.seed(args.seed)
test_envs.seed(args.seed)
# model
net = Net(args.layer_num, args.state_shape, device=args.device)
net = Net(args.layer_num, args.state_shape, device=args.device,
hidden_layer_size=args.hidden_layer_size)
actor = ActorProb(
net, args.action_shape,
args.max_action, args.device, unbounded=True
net, args.action_shape, args.max_action, args.device, unbounded=True,
hidden_layer_size=args.hidden_layer_size, conditioned_sigma=True,
).to(args.device)
actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr)
net = Net(args.layer_num, args.state_shape,
args.action_shape, concat=True, device=args.device)
critic1 = Critic(net, args.device).to(args.device)
net_c1 = Net(args.layer_num, args.state_shape, args.action_shape,
concat=True, device=args.device,
hidden_layer_size=args.hidden_layer_size)
critic1 = Critic(
net_c1, args.device, hidden_layer_size=args.hidden_layer_size
).to(args.device)
critic1_optim = torch.optim.Adam(critic1.parameters(), lr=args.critic_lr)
critic2 = Critic(net, args.device).to(args.device)
net_c2 = Net(args.layer_num, args.state_shape, args.action_shape,
concat=True, device=args.device,
hidden_layer_size=args.hidden_layer_size)
critic2 = Critic(
net_c2, args.device, hidden_layer_size=args.hidden_layer_size
).to(args.device)
critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr)

if args.auto_alpha:
target_entropy = -np.prod(env.action_space.shape)
log_alpha = torch.zeros(1, requires_grad=True, device=args.device)
alpha_optim = torch.optim.Adam([log_alpha], lr=args.alpha_lr)
args.alpha = (target_entropy, log_alpha, alpha_optim)

policy = SACPolicy(
actor, actor_optim, critic1, critic1_optim, critic2, critic2_optim,
action_range=[env.action_space.low[0], env.action_space.high[0]],
tau=args.tau, gamma=args.gamma, alpha=args.alpha,
reward_normalization=args.rew_norm, ignore_done=True)
estimation_step=args.n_step)
# load a previous policy
if args.resume_path:
policy.load_state_dict(torch.load(
args.resume_path, map_location=args.device
))
print("Loaded agent from: ", args.resume_path)

# collector
train_collector = Collector(
policy, train_envs, ReplayBuffer(args.buffer_size))
test_collector = Collector(policy, test_envs)
# train_collector.collect(n_step=args.buffer_size)
# log
log_path = os.path.join(args.logdir, args.task, 'sac')
writer = SummaryWriter(log_path)

def watch():
# watch agent's performance
print("Testing agent ...")
policy.eval()
test_envs.seed(args.seed)
test_collector.reset()
result = test_collector.collect(n_episode=[1] * args.test_num,
render=args.render)
pprint.pprint(result)

def save_fn(policy):
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))

def stop_fn(mean_rewards):
return mean_rewards >= env.spec.reward_threshold
return False

if args.watch:
watch()
exit(0)

# trainer
train_collector.collect(n_step=args.pre_collect_step, random=True)
result = offpolicy_trainer(
policy, train_collector, test_collector, args.epoch,
args.step_per_epoch, args.collect_per_step, args.test_num,
args.batch_size, stop_fn=stop_fn, save_fn=save_fn, writer=writer)
assert stop_fn(result['best_reward'])
if __name__ == '__main__':
pprint.pprint(result)
# Let's watch its performance!
policy.eval()
test_envs.seed(args.seed)
test_collector.reset()
result = test_collector.collect(n_episode=[1] * args.test_num,
render=args.render)
print(f'Final reward: {result["rew"]}, length: {result["len"]}')
args.batch_size, args.update_per_step,
stop_fn=stop_fn, save_fn=save_fn, writer=writer,
log_interval=args.log_interval)
pprint.pprint(result)
watch()


if __name__ == '__main__':
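Given the new `--watch` and `--resume-path` flags, a saved policy can be replayed with something like `python3 mujoco_sac.py --task Ant-v3 --watch --resume-path log/Ant-v3/sac/policy.pth`, where the checkpoint path is what `save_fn` writes under the default `--logdir`.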
Binary file added examples/mujoco/results/sac/all.png
File renamed without changes.
File renamed without changes.
@@ -71,6 +71,8 @@ def test_sac(args=get_args()):
args.action_shape, concat=True, device=args.device)
critic1 = Critic(net, args.device).to(args.device)
critic1_optim = torch.optim.Adam(critic1.parameters(), lr=args.critic_lr)
net = Net(args.layer_num, args.state_shape,
args.action_shape, concat=True, device=args.device)
critic2 = Critic(net, args.device).to(args.device)
critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr)
policy = SACPolicy(
@@ -72,6 +72,8 @@ def test_td3(args=get_args()):
args.action_shape, concat=True, device=args.device)
critic1 = Critic(net, args.device).to(args.device)
critic1_optim = torch.optim.Adam(critic1.parameters(), lr=args.critic_lr)
net = Net(args.layer_num, args.state_shape,
args.action_shape, concat=True, device=args.device)
critic2 = Critic(net, args.device).to(args.device)
critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr)
policy = TD3Policy(
10 changes: 6 additions & 4 deletions test/continuous/test_sac_with_il.py
@@ -70,11 +70,13 @@ def test_sac_with_il(args=get_args()):
net, args.action_shape, args.max_action, args.device, unbounded=True
).to(args.device)
actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr)
net = Net(args.layer_num, args.state_shape,
args.action_shape, concat=True, device=args.device)
critic1 = Critic(net, args.device).to(args.device)
net_c1 = Net(args.layer_num, args.state_shape,
args.action_shape, concat=True, device=args.device)
critic1 = Critic(net_c1, args.device).to(args.device)
critic1_optim = torch.optim.Adam(critic1.parameters(), lr=args.critic_lr)
critic2 = Critic(net, args.device).to(args.device)
net_c2 = Net(args.layer_num, args.state_shape,
args.action_shape, concat=True, device=args.device)
critic2 = Critic(net_c2, args.device).to(args.device)
critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr)
policy = SACPolicy(
actor, actor_optim, critic1, critic1_optim, critic2, critic2_optim,