add Atari SAC examples #657

Merged: 5 commits, Jun 4, 2022

14 changes: 14 additions & 0 deletions examples/atari/README.md
@@ -121,3 +121,17 @@ One epoch here is equal to 100,000 env step, 100 epochs stand for 10M.
| MsPacmanNoFrameskip-v4 | 1930 | ![](results/ppo/MsPacman_rew.png) | `python3 atari_ppo.py --task "MsPacmanNoFrameskip-v4"` |
| SeaquestNoFrameskip-v4 | 904 | ![](results/ppo/Seaquest_rew.png) | `python3 atari_ppo.py --task "SeaquestNoFrameskip-v4" --lr 2.5e-5` |
| SpaceInvadersNoFrameskip-v4 | 843 | ![](results/ppo/SpaceInvaders_rew.png) | `python3 atari_ppo.py --task "SpaceInvadersNoFrameskip-v4"` |

# SAC (single run)

One epoch here equals 100,000 env steps, so 100 epochs correspond to 10M steps (see the arithmetic sketch after the table).

| task | best reward | reward curve | parameters |
| --------------------------- | ----------- | ------------------------------------- | ------------------------------------------------------------ |
| PongNoFrameskip-v4 | 20.1 | ![](results/discrete_sac/Pong_rew.png) | `python3 atari_sac.py --task "PongNoFrameskip-v4"` |
| BreakoutNoFrameskip-v4 | 211.2 | ![](results/discrete_sac/Breakout_rew.png) | `python3 atari_sac.py --task "BreakoutNoFrameskip-v4" --n-step 1 --actor-lr 1e-4 --critic-lr 1e-4` |
| EnduroNoFrameskip-v4 | 1290.7 | ![](results/discrete_sac/Enduro_rew.png) | `python3 atari_sac.py --task "EnduroNoFrameskip-v4"` |
| QbertNoFrameskip-v4 | 13157.5 | ![](results/discrete_sac/Qbert_rew.png) | `python3 atari_sac.py --task "QbertNoFrameskip-v4"` |
| MsPacmanNoFrameskip-v4 | 3836 | ![](results/discrete_sac/MsPacman_rew.png) | `python3 atari_sac.py --task "MsPacmanNoFrameskip-v4"` |
| SeaquestNoFrameskip-v4 | 1772 | ![](results/discrete_sac/Seaquest_rew.png) | `python3 atari_sac.py --task "SeaquestNoFrameskip-v4"` |
| SpaceInvadersNoFrameskip-v4 | 649 | ![](results/discrete_sac/SpaceInvaders_rew.png) | `python3 atari_sac.py --task "SpaceInvadersNoFrameskip-v4"` |
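
For reference, a minimal sketch of the arithmetic behind the table header, using the default `--epoch` and `--step-per-epoch` values from `atari_sac.py` below:

```python
# Defaults taken from the atari_sac.py argument parser added in this PR.
epoch = 100               # --epoch
step_per_epoch = 100_000  # --step-per-epoch
total_env_steps = epoch * step_per_epoch
print(total_env_steps)    # 10_000_000, i.e. the 10M env steps quoted above
```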
2 changes: 1 addition & 1 deletion examples/atari/atari_ppo.py
@@ -162,7 +162,7 @@ def dist(p):
feature_net.net,
feature_dim,
action_dim,
hidden_sizes=args.hidden_sizes,
hidden_sizes=[args.hidden_size],
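        # atari_ppo.py defines a single --hidden-size int, so it is wrapped in a
        # one-element list here to match the hidden_sizes sequence argument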
device=args.device,
)
icm_optim = torch.optim.Adam(icm_net.parameters(), lr=args.lr)
269 changes: 269 additions & 0 deletions examples/atari/atari_sac.py
@@ -0,0 +1,269 @@
import argparse
import datetime
import os
import pprint

import numpy as np
import torch
from atari_network import DQN
from atari_wrapper import make_atari_env
from torch.utils.tensorboard import SummaryWriter

from tianshou.data import Collector, VectorReplayBuffer
from tianshou.policy import DiscreteSACPolicy, ICMPolicy
from tianshou.trainer import offpolicy_trainer
from tianshou.utils import TensorboardLogger, WandbLogger
from tianshou.utils.net.discrete import Actor, Critic, IntrinsicCuriosityModule


def get_args():
parser = argparse.ArgumentParser()
parser.add_argument("--task", type=str, default="PongNoFrameskip-v4")
parser.add_argument("--seed", type=int, default=4213)
parser.add_argument("--scale-obs", type=int, default=0)
parser.add_argument("--buffer-size", type=int, default=100000)
parser.add_argument("--actor-lr", type=float, default=1e-5)
parser.add_argument("--critic-lr", type=float, default=1e-5)
parser.add_argument("--gamma", type=float, default=0.99)
parser.add_argument("--n-step", type=int, default=3)
parser.add_argument("--tau", type=float, default=0.005)
parser.add_argument("--alpha", type=float, default=0.05)
parser.add_argument("--auto-alpha", action="store_true", default=False)
parser.add_argument("--alpha-lr", type=float, default=3e-4)
parser.add_argument("--epoch", type=int, default=100)
parser.add_argument("--step-per-epoch", type=int, default=100000)
parser.add_argument("--step-per-collect", type=int, default=10)
parser.add_argument("--update-per-step", type=float, default=0.1)
parser.add_argument("--batch-size", type=int, default=64)
parser.add_argument("--hidden-size", type=int, default=512)
parser.add_argument("--training-num", type=int, default=10)
parser.add_argument("--test-num", type=int, default=10)
parser.add_argument("--rew-norm", type=int, default=False)
parser.add_argument("--logdir", type=str, default="log")
parser.add_argument("--render", type=float, default=0.)
parser.add_argument(
"--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu"
)
parser.add_argument("--frames-stack", type=int, default=4)
parser.add_argument("--resume-path", type=str, default=None)
parser.add_argument("--resume-id", type=str, default=None)
parser.add_argument(
"--logger",
type=str,
default="tensorboard",
choices=["tensorboard", "wandb"],
)
parser.add_argument("--wandb-project", type=str, default="atari.benchmark")
parser.add_argument(
"--watch",
default=False,
action="store_true",
help="watch the play of pre-trained policy only"
)
parser.add_argument("--save-buffer-name", type=str, default=None)
parser.add_argument(
"--icm-lr-scale",
type=float,
default=0.,
help="use intrinsic curiosity module with this lr scale"
)
parser.add_argument(
"--icm-reward-scale",
type=float,
default=0.01,
help="scaling factor for intrinsic curiosity reward"
)
parser.add_argument(
"--icm-forward-loss-weight",
type=float,
default=0.2,
help="weight for the forward model loss in ICM"
)
return parser.parse_args()


def test_discrete_sac(args=get_args()):
env, train_envs, test_envs = make_atari_env(
args.task,
args.seed,
args.training_num,
args.test_num,
scale=args.scale_obs,
frame_stack=args.frames_stack,
)
args.state_shape = env.observation_space.shape or env.observation_space.n
args.action_shape = env.action_space.shape or env.action_space.n
# should be N_FRAMES x H x W
print("Observations shape:", args.state_shape)
print("Actions shape:", args.action_shape)
# seed
np.random.seed(args.seed)
torch.manual_seed(args.seed)
# define model
net = DQN(
*args.state_shape,
args.action_shape,
device=args.device,
features_only=True,
output_dim=args.hidden_size
)
actor = Actor(net, args.action_shape, device=args.device, softmax_output=False)
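    # softmax_output=False: DiscreteSACPolicy builds Categorical(logits=...) from the
    # actor output itself (see the discrete_sac.py change further down), so the actor
    # should return raw logits here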
actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr)
critic1 = Critic(net, last_size=args.action_shape, device=args.device)
critic1_optim = torch.optim.Adam(critic1.parameters(), lr=args.critic_lr)
critic2 = Critic(net, last_size=args.action_shape, device=args.device)
critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr)

# define policy
if args.auto_alpha:
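        # 0.98 * log(num_actions) targets 98% of the maximum entropy of a uniform
        # categorical policy, a common target-entropy heuristic for discrete SAC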
target_entropy = 0.98 * np.log(np.prod(args.action_shape))
log_alpha = torch.zeros(1, requires_grad=True, device=args.device)
alpha_optim = torch.optim.Adam([log_alpha], lr=args.alpha_lr)
args.alpha = (target_entropy, log_alpha, alpha_optim)

policy = DiscreteSACPolicy(
actor,
actor_optim,
critic1,
critic1_optim,
critic2,
critic2_optim,
args.tau,
args.gamma,
args.alpha,
estimation_step=args.n_step,
reward_normalization=args.rew_norm,
).to(args.device)
if args.icm_lr_scale > 0:
feature_net = DQN(
*args.state_shape, args.action_shape, args.device, features_only=True
)
action_dim = np.prod(args.action_shape)
feature_dim = feature_net.output_dim
icm_net = IntrinsicCuriosityModule(
feature_net.net,
feature_dim,
action_dim,
hidden_sizes=[args.hidden_size],
device=args.device,
)
icm_optim = torch.optim.Adam(icm_net.parameters(), lr=args.actor_lr)
policy = ICMPolicy(
policy, icm_net, icm_optim, args.icm_lr_scale, args.icm_reward_scale,
args.icm_forward_loss_weight
).to(args.device)
# load a previous policy
if args.resume_path:
policy.load_state_dict(torch.load(args.resume_path, map_location=args.device))
print("Loaded agent from: ", args.resume_path)
# replay buffer: `save_last_obs` and `stack_num` can be removed together
# when you have enough RAM
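    # (save_only_last_obs=True stores a single frame per transition and stack_num
    # re-stacks frames_stack frames at sampling time, trading compute for memory)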
buffer = VectorReplayBuffer(
args.buffer_size,
buffer_num=len(train_envs),
ignore_obs_next=True,
save_only_last_obs=True,
stack_num=args.frames_stack,
)
# collector
train_collector = Collector(policy, train_envs, buffer, exploration_noise=True)
test_collector = Collector(policy, test_envs, exploration_noise=True)

# log
now = datetime.datetime.now().strftime("%y%m%d-%H%M%S")
args.algo_name = "discrete_sac_icm" if args.icm_lr_scale > 0 else "discrete_sac"
log_name = os.path.join(args.task, args.algo_name, str(args.seed), now)
log_path = os.path.join(args.logdir, log_name)

# logger
if args.logger == "wandb":
logger = WandbLogger(
save_interval=1,
name=log_name.replace(os.path.sep, "__"),
run_id=args.resume_id,
config=args,
project=args.wandb_project,
)
writer = SummaryWriter(log_path)
writer.add_text("args", str(args))
if args.logger == "tensorboard":
logger = TensorboardLogger(writer)
else: # wandb
logger.load(writer)

def save_best_fn(policy):
torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth"))

def stop_fn(mean_rewards):
if env.spec.reward_threshold:
return mean_rewards >= env.spec.reward_threshold
elif "Pong" in args.task:
return mean_rewards >= 20
else:
return False

def save_checkpoint_fn(epoch, env_step, gradient_step):
# see also: https://pytorch.org/tutorials/beginner/saving_loading_models.html
ckpt_path = os.path.join(log_path, "checkpoint.pth")
torch.save({"model": policy.state_dict()}, ckpt_path)
return ckpt_path

# watch agent's performance
def watch():
print("Setup test envs ...")
policy.eval()
test_envs.seed(args.seed)
if args.save_buffer_name:
print(f"Generate buffer with size {args.buffer_size}")
buffer = VectorReplayBuffer(
args.buffer_size,
buffer_num=len(test_envs),
ignore_obs_next=True,
save_only_last_obs=True,
stack_num=args.frames_stack,
)
collector = Collector(policy, test_envs, buffer, exploration_noise=True)
result = collector.collect(n_step=args.buffer_size)
print(f"Save buffer into {args.save_buffer_name}")
# Unfortunately, pickle will cause oom with 1M buffer size
buffer.save_hdf5(args.save_buffer_name)
else:
print("Testing agent ...")
test_collector.reset()
result = test_collector.collect(
n_episode=args.test_num, render=args.render
)
rew = result["rews"].mean()
print(f"Mean reward (over {result['n/ep']} episodes): {rew}")

if args.watch:
watch()
exit(0)

# test train_collector and start filling replay buffer
train_collector.collect(n_step=args.batch_size * args.training_num)
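    # batch_size * training_num steps leave each of the training_num sub-buffers with
    # roughly batch_size transitions, so the first gradient updates have data to sample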
# trainer
result = offpolicy_trainer(
policy,
train_collector,
test_collector,
args.epoch,
args.step_per_epoch,
args.step_per_collect,
args.test_num,
args.batch_size,
stop_fn=stop_fn,
save_best_fn=save_best_fn,
logger=logger,
update_per_step=args.update_per_step,
test_in_train=False,
resume_from_log=args.resume_id is not None,
save_checkpoint_fn=save_checkpoint_fn,
)

pprint.pprint(result)
watch()


if __name__ == "__main__":
test_discrete_sac(get_args())
Binary file added examples/atari/results/discrete_sac/Pong_rew.png
2 changes: 1 addition & 1 deletion test/discrete/test_sac.py
@@ -53,7 +53,7 @@ def test_discrete_sac(args=get_args()):
args.state_shape = env.observation_space.shape or env.observation_space.n
args.action_shape = env.action_space.shape or env.action_space.n
if args.reward_threshold is None:
default_reward_threshold = {"CartPole-v0": 180} # lower the goal
default_reward_threshold = {"CartPole-v0": 170} # lower the goal
args.reward_threshold = default_reward_threshold.get(
args.task, env.spec.reward_threshold
)
5 changes: 4 additions & 1 deletion tianshou/policy/modelfree/discrete_sac.py
@@ -80,7 +80,10 @@ def forward(  # type: ignore
obs = batch[input]
logits, hidden = self.actor(obs, state=state, info=batch.info)
dist = Categorical(logits=logits)
act = dist.sample()
if self._deterministic_eval and not self.training:
act = logits.argmax(axis=-1)
else:
act = dist.sample()
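        # greedy (argmax) actions at evaluation time give less noisy test returns than
        # sampling from the categorical distribution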
return Batch(logits=logits, act=act, state=hidden, dist=dist)

def _target_q(self, buffer: ReplayBuffer, indices: np.ndarray) -> torch.Tensor: