diff --git a/examples/atari/README.md b/examples/atari/README.md
index 281f72ea1..24840f25d 100644
--- a/examples/atari/README.md
+++ b/examples/atari/README.md
@@ -56,10 +56,15 @@ One epoch here is equal to 100,000 env step, 100 epochs stand for 10M.
 
 # BCQ
 
-TODO: after the `done` issue fixed, the result should be re-tuned and place here.
-
 To running BCQ algorithm on Atari, you need to do the following things:
 
 - Train an expert, by using the command listed in the above DQN section;
 - Generate buffer with noise: `python3 atari_dqn.py --task {your_task} --watch --resume-path log/{your_task}/dqn/policy.pth --eps-test 0.2 --buffer-size 1000000 --save-buffer-name expert.hdf5` (note that 1M Atari buffer cannot be saved as `.pkl` format because it is too large and will cause error);
 - Train BCQ: `python3 atari_bcq.py --task {your_task} --load-buffer-name expert.hdf5`.
+
+We test our BCQ implementation on two example tasks (different from the author's version, we use v4 instead of v0; one epoch means 10k gradient steps):
+
+| Task                   | Online DQN | Behavioral | BCQ                               |
+| ---------------------- | ---------- | ---------- | --------------------------------- |
+| PongNoFrameskip-v4     | 21         | 7.7        | 21 (epoch 5)                      |
+| BreakoutNoFrameskip-v4 | 303        | 61         | 167.4 (epoch 12, could be higher) |
\ No newline at end of file
diff --git a/examples/atari/atari_bcq.py b/examples/atari/atari_bcq.py
index 05633bc4a..2c42c46c7 100644
--- a/examples/atari/atari_bcq.py
+++ b/examples/atari/atari_bcq.py
@@ -12,7 +12,7 @@
 from tianshou.trainer import offline_trainer
 from tianshou.utils.net.discrete import Actor
 from tianshou.policy import DiscreteBCQPolicy
-from tianshou.data import Collector, ReplayBuffer
+from tianshou.data import Collector, VectorReplayBuffer
 
 from atari_network import DQN
 from atari_wrapper import wrap_deepmind
@@ -25,17 +25,16 @@ def get_args():
     parser.add_argument("--eps-test", type=float, default=0.001)
     parser.add_argument("--lr", type=float, default=6.25e-5)
     parser.add_argument("--gamma", type=float, default=0.99)
-    parser.add_argument("--n-step", type=int, default=3)
+    parser.add_argument("--n-step", type=int, default=1)
     parser.add_argument("--target-update-freq", type=int, default=8000)
     parser.add_argument("--unlikely-action-threshold", type=float, default=0.3)
     parser.add_argument("--imitation-logits-penalty", type=float, default=0.01)
     parser.add_argument("--epoch", type=int, default=100)
     parser.add_argument("--update-per-epoch", type=int, default=10000)
     parser.add_argument("--batch-size", type=int, default=32)
-    parser.add_argument('--hidden-sizes', type=int,
-                        nargs='*', default=[512])
-    parser.add_argument("--test-num", type=int, default=100)
-    parser.add_argument('--frames_stack', type=int, default=4)
+    parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[512])
+    parser.add_argument("--test-num", type=int, default=10)
+    parser.add_argument('--frames-stack', type=int, default=4)
     parser.add_argument("--logdir", type=str, default="log")
     parser.add_argument("--render", type=float, default=0.)
parser.add_argument("--resume-path", type=str, default=None) @@ -44,12 +43,10 @@ def get_args(): parser.add_argument("--log-interval", type=int, default=100) parser.add_argument( "--load-buffer-name", type=str, - default="./expert_DQN_PongNoFrameskip-v4.hdf5", - ) + default="./expert_DQN_PongNoFrameskip-v4.hdf5") parser.add_argument( "--device", type=str, - default="cuda" if torch.cuda.is_available() else "cpu", - ) + default="cuda" if torch.cuda.is_available() else "cpu") args = parser.parse_known_args()[0] return args @@ -81,25 +78,24 @@ def test_discrete_bcq(args=get_args()): # model feature_net = DQN(*args.state_shape, args.action_shape, device=args.device, features_only=True).to(args.device) - policy_net = Actor(feature_net, args.action_shape, - hidden_sizes=args.hidden_sizes).to(args.device) - imitation_net = Actor(feature_net, args.action_shape, - hidden_sizes=args.hidden_sizes).to(args.device) + policy_net = Actor( + feature_net, args.action_shape, device=args.device, + hidden_sizes=args.hidden_sizes, softmax_output=False).to(args.device) + imitation_net = Actor( + feature_net, args.action_shape, device=args.device, + hidden_sizes=args.hidden_sizes, softmax_output=False).to(args.device) optim = torch.optim.Adam( set(policy_net.parameters()).union(imitation_net.parameters()), - lr=args.lr, - ) + lr=args.lr) # define policy policy = DiscreteBCQPolicy( policy_net, imitation_net, optim, args.gamma, args.n_step, args.target_update_freq, args.eps_test, - args.unlikely_action_threshold, args.imitation_logits_penalty, - ) + args.unlikely_action_threshold, args.imitation_logits_penalty) # load a previous policy if args.resume_path: policy.load_state_dict(torch.load( - args.resume_path, map_location=args.device - )) + args.resume_path, map_location=args.device)) print("Loaded agent from: ", args.resume_path) # buffer assert os.path.exists(args.load_buffer_name), \ @@ -107,7 +103,7 @@ def test_discrete_bcq(args=get_args()): if args.load_buffer_name.endswith('.pkl'): buffer = pickle.load(open(args.load_buffer_name, "rb")) elif args.load_buffer_name.endswith('.hdf5'): - buffer = ReplayBuffer.load_hdf5(args.load_buffer_name) + buffer = VectorReplayBuffer.load_hdf5(args.load_buffer_name) else: print(f"Unknown buffer format: {args.load_buffer_name}") exit(0) @@ -146,11 +142,9 @@ def watch(): exit(0) result = offline_trainer( - policy, buffer, test_collector, - args.epoch, args.update_per_epoch, args.test_num, args.batch_size, - stop_fn=stop_fn, save_fn=save_fn, logger=logger, - log_interval=args.log_interval, - ) + policy, buffer, test_collector, args.epoch, + args.update_per_epoch, args.test_num, args.batch_size, + stop_fn=stop_fn, save_fn=save_fn, logger=logger) pprint.pprint(result) watch() diff --git a/examples/atari/atari_c51.py b/examples/atari/atari_c51.py index 0956d691c..50b1a83b9 100644 --- a/examples/atari/atari_c51.py +++ b/examples/atari/atari_c51.py @@ -46,6 +46,7 @@ def get_args(): parser.add_argument('--resume-path', type=str, default=None) parser.add_argument('--watch', default=False, action='store_true', help='watch the play of pre-trained policy only') + parser.add_argument('--save-buffer-name', type=str, default=None) return parser.parse_args() @@ -128,13 +129,29 @@ def test_fn(epoch, env_step): # watch agent's performance def watch(): - print("Testing agent ...") + print("Setup test envs ...") policy.eval() policy.set_eps(args.eps_test) test_envs.seed(args.seed) - test_collector.reset() - result = test_collector.collect(n_episode=args.test_num, render=args.render) - 
-        pprint.pprint(result)
+        if args.save_buffer_name:
+            print(f"Generate buffer with size {args.buffer_size}")
+            buffer = VectorReplayBuffer(
+                args.buffer_size, buffer_num=len(test_envs),
+                ignore_obs_next=True, save_only_last_obs=True,
+                stack_num=args.frames_stack)
+            collector = Collector(policy, test_envs, buffer,
+                                  exploration_noise=True)
+            result = collector.collect(n_step=args.buffer_size)
+            print(f"Save buffer into {args.save_buffer_name}")
+            # Unfortunately, pickle will cause oom with 1M buffer size
+            buffer.save_hdf5(args.save_buffer_name)
+        else:
+            print("Testing agent ...")
+            test_collector.reset()
+            result = test_collector.collect(n_episode=args.test_num,
+                                            render=args.render)
+        rew = result["rews"].mean()
+        print(f'Mean reward (over {result["n/ep"]} episodes): {rew}')
 
     if args.watch:
         watch()
diff --git a/examples/atari/atari_dqn.py b/examples/atari/atari_dqn.py
index b3f36c893..20bb94f30 100644
--- a/examples/atari/atari_dqn.py
+++ b/examples/atari/atari_dqn.py
@@ -134,7 +134,8 @@ def watch():
                 args.buffer_size, buffer_num=len(test_envs),
                 ignore_obs_next=True, save_only_last_obs=True,
                 stack_num=args.frames_stack)
-            collector = Collector(policy, test_envs, buffer)
+            collector = Collector(policy, test_envs, buffer,
+                                  exploration_noise=True)
             result = collector.collect(n_step=args.buffer_size)
             print(f"Save buffer into {args.save_buffer_name}")
             # Unfortunately, pickle will cause oom with 1M buffer size
@@ -144,7 +145,8 @@ def watch():
             test_collector.reset()
             result = test_collector.collect(n_episode=args.test_num,
                                             render=args.render)
-        pprint.pprint(result)
+        rew = result["rews"].mean()
+        print(f'Mean reward (over {result["n/ep"]} episodes): {rew}')
 
     if args.watch:
         watch()
diff --git a/examples/atari/atari_qrdqn.py b/examples/atari/atari_qrdqn.py
index ae2a26f4f..1aa7af9da 100644
--- a/examples/atari/atari_qrdqn.py
+++ b/examples/atari/atari_qrdqn.py
@@ -44,6 +44,7 @@ def get_args():
     parser.add_argument('--resume-path', type=str, default=None)
     parser.add_argument('--watch', default=False, action='store_true',
                         help='watch the play of pre-trained policy only')
+    parser.add_argument('--save-buffer-name', type=str, default=None)
     return parser.parse_args()
 
 
@@ -126,13 +127,29 @@ def test_fn(epoch, env_step):
 
     # watch agent's performance
     def watch():
-        print("Testing agent ...")
+        print("Setup test envs ...")
         policy.eval()
         policy.set_eps(args.eps_test)
         test_envs.seed(args.seed)
-        test_collector.reset()
-        result = test_collector.collect(n_episode=args.test_num, render=args.render)
-        pprint.pprint(result)
+        if args.save_buffer_name:
+            print(f"Generate buffer with size {args.buffer_size}")
+            buffer = VectorReplayBuffer(
+                args.buffer_size, buffer_num=len(test_envs),
+                ignore_obs_next=True, save_only_last_obs=True,
+                stack_num=args.frames_stack)
+            collector = Collector(policy, test_envs, buffer,
+                                  exploration_noise=True)
+            result = collector.collect(n_step=args.buffer_size)
+            print(f"Save buffer into {args.save_buffer_name}")
+            # Unfortunately, pickle will cause oom with 1M buffer size
+            buffer.save_hdf5(args.save_buffer_name)
+        else:
+            print("Testing agent ...")
+            test_collector.reset()
+            result = test_collector.collect(n_episode=args.test_num,
+                                            render=args.render)
+        rew = result["rews"].mean()
+        print(f'Mean reward (over {result["n/ep"]} episodes): {rew}')
 
     if args.watch:
         watch()
diff --git a/test/discrete/test_dqn.py b/test/discrete/test_dqn.py
index c59910e84..bd88b2b5b 100644
--- a/test/discrete/test_dqn.py
+++ b/test/discrete/test_dqn.py
@@ -136,9 +136,11 @@ def test_fn(epoch, env_step):
 
     # save buffer in pickle format, for imitation learning unittest
     buf = VectorReplayBuffer(args.buffer_size, buffer_num=len(test_envs))
-    collector = Collector(policy, test_envs, buf)
-    collector.collect(n_step=args.buffer_size)
+    policy.set_eps(0.2)
+    collector = Collector(policy, test_envs, buf, exploration_noise=True)
+    result = collector.collect(n_step=args.buffer_size)
     pickle.dump(buf, open(args.save_buffer_name, "wb"))
+    print(result["rews"].mean())
 
 
 def test_pdqn(args=get_args()):
diff --git a/test/discrete/test_il_bcq.py b/test/discrete/test_il_bcq.py
index 996f7d599..c3bebb53a 100644
--- a/test/discrete/test_il_bcq.py
+++ b/test/discrete/test_il_bcq.py
@@ -21,16 +21,16 @@ def get_args():
     parser.add_argument("--seed", type=int, default=1626)
     parser.add_argument("--eps-test", type=float, default=0.001)
     parser.add_argument("--lr", type=float, default=3e-4)
-    parser.add_argument("--gamma", type=float, default=0.9)
+    parser.add_argument("--gamma", type=float, default=0.99)
     parser.add_argument("--n-step", type=int, default=3)
     parser.add_argument("--target-update-freq", type=int, default=320)
-    parser.add_argument("--unlikely-action-threshold", type=float, default=0.3)
+    parser.add_argument("--unlikely-action-threshold", type=float, default=0.6)
     parser.add_argument("--imitation-logits-penalty", type=float, default=0.01)
    parser.add_argument("--epoch", type=int, default=5)
-    parser.add_argument("--update-per-epoch", type=int, default=1000)
+    parser.add_argument("--update-per-epoch", type=int, default=2000)
     parser.add_argument("--batch-size", type=int, default=64)
     parser.add_argument('--hidden-sizes', type=int,
-                        nargs='*', default=[128, 128, 128])
+                        nargs='*', default=[64, 64])
     parser.add_argument("--test-num", type=int, default=100)
     parser.add_argument("--logdir", type=str, default="log")
     parser.add_argument("--render", type=float, default=0.)
@@ -49,6 +49,8 @@ def get_args():
 def test_discrete_bcq(args=get_args()):
     # envs
     env = gym.make(args.task)
+    if args.task == 'CartPole-v0':
+        env.spec.reward_threshold = 190  # lower the goal
     args.state_shape = env.observation_space.shape or env.observation_space.n
     args.action_shape = env.action_space.shape or env.action_space.n
     test_envs = DummyVectorEnv(
diff --git a/tianshou/policy/imitation/discrete_bcq.py b/tianshou/policy/imitation/discrete_bcq.py
index 38ae0c4f9..5ea68473a 100644
--- a/tianshou/policy/imitation/discrete_bcq.py
+++ b/tianshou/policy/imitation/discrete_bcq.py
@@ -118,7 +118,7 @@ def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]:
 
         return {
             "loss": loss.item(),
-            "q_loss": q_loss.item(),
-            "i_loss": i_loss.item(),
-            "reg_loss": reg_loss.item(),
+            "loss/q": q_loss.item(),
+            "loss/i": i_loss.item(),
+            "loss/reg": reg_loss.item(),
         }
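Usage note (not part of the diff): a minimal sketch of how a buffer written by the new `--save-buffer-name` option is consumed on the offline side. It only uses calls that appear in the hunks above (`VectorReplayBuffer.load_hdf5` and plain attribute access on the buffer); the `expert.hdf5` path is illustrative and corresponds to the generation command in the README section.

```python
# Sketch: load an expert buffer written by
#   python3 atari_dqn.py --task PongNoFrameskip-v4 --watch ... --save-buffer-name expert.hdf5
# and run a few sanity checks before handing it to the offline trainer,
# mirroring the `.hdf5` branch added to atari_bcq.py above.
# The file name "expert.hdf5" is an assumption, not a fixed path.
from tianshou.data import VectorReplayBuffer

buf = VectorReplayBuffer.load_hdf5("expert.hdf5")
print(len(buf))         # number of stored transitions
print(buf.obs.shape)    # only the last frame per step is stored (save_only_last_obs=True)
print(buf.rew.mean())   # rough per-step reward of the noisy behavioral policy
```

Frame stacking is reapplied at sample time via `stack_num`, which is why the generation code stores only the last observation and why the README insists on `.hdf5` rather than `.pkl` for the 1M-transition Atari buffer.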