fix atari_bcq #345

Merged (4 commits, Apr 20, 2021)

9 changes: 7 additions & 2 deletions examples/atari/README.md
@@ -56,10 +56,15 @@ One epoch here is equal to 100,000 env step, 100 epochs stand for 10M.

# BCQ

TODO: once the `done` issue is fixed, the results should be re-tuned and placed here.

To run the BCQ algorithm on Atari, you need to do the following:

- Train an expert, using the command listed in the DQN section above;
- Generate a buffer with noise: `python3 atari_dqn.py --task {your_task} --watch --resume-path log/{your_task}/dqn/policy.pth --eps-test 0.2 --buffer-size 1000000 --save-buffer-name expert.hdf5` (note that a 1M Atari buffer cannot be saved in `.pkl` format because it is too large and will cause an error; a minimal save/load sketch follows this list);
- Train BCQ: `python3 atari_bcq.py --task {your_task} --load-buffer-name expert.hdf5`.
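
A minimal sketch of the HDF5 save/load round-trip behind these commands (an illustration only, assuming tianshou's 0.4-style `ReplayBuffer` API; the `expert.hdf5` name matches the commands above, and exact signatures may differ in other versions):

```python
from tianshou.data import ReplayBuffer

# Suppose `buf` has been filled by a Collector running the trained DQN policy
# with --eps-test 0.2 (the "generate a buffer with noise" step above).
buf = ReplayBuffer(size=1_000_000)
# ... collector.collect(n_step=1_000_000) would populate `buf` here ...

# HDF5 avoids the out-of-memory error that pickling a 1M-step Atari buffer hits.
buf.save_hdf5("expert.hdf5")

# atari_bcq.py later reloads the buffer for offline training.
buf = ReplayBuffer.load_hdf5("expert.hdf5")
print(len(buf))
```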

We test our BCQ implementation on two example tasks (unlike the author's version, we use v4 instead of v0; one epoch here means 10k gradient steps):

| Task | Online DQN | Behavioral | BCQ |
| ---------------------- | ---------- | ---------- | --------------------------------- |
| PongNoFrameskip-v4 | 21 | 7.7 | 21 (epoch 5) |
| BreakoutNoFrameskip-v4 | 303 | 61 | 167.4 (epoch 12, could be higher) |
46 changes: 20 additions & 26 deletions examples/atari/atari_bcq.py
@@ -12,7 +12,7 @@
from tianshou.trainer import offline_trainer
from tianshou.utils.net.discrete import Actor
from tianshou.policy import DiscreteBCQPolicy
from tianshou.data import Collector, ReplayBuffer
from tianshou.data import Collector, VectorReplayBuffer

from atari_network import DQN
from atari_wrapper import wrap_deepmind
@@ -25,17 +25,16 @@ def get_args():
parser.add_argument("--eps-test", type=float, default=0.001)
parser.add_argument("--lr", type=float, default=6.25e-5)
parser.add_argument("--gamma", type=float, default=0.99)
parser.add_argument("--n-step", type=int, default=3)
parser.add_argument("--n-step", type=int, default=1)
parser.add_argument("--target-update-freq", type=int, default=8000)
parser.add_argument("--unlikely-action-threshold", type=float, default=0.3)
parser.add_argument("--imitation-logits-penalty", type=float, default=0.01)
parser.add_argument("--epoch", type=int, default=100)
parser.add_argument("--update-per-epoch", type=int, default=10000)
parser.add_argument("--batch-size", type=int, default=32)
parser.add_argument('--hidden-sizes', type=int,
nargs='*', default=[512])
parser.add_argument("--test-num", type=int, default=100)
parser.add_argument('--frames_stack', type=int, default=4)
parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[512])
parser.add_argument("--test-num", type=int, default=10)
parser.add_argument('--frames-stack', type=int, default=4)
parser.add_argument("--logdir", type=str, default="log")
parser.add_argument("--render", type=float, default=0.)
parser.add_argument("--resume-path", type=str, default=None)
@@ -44,12 +43,10 @@ def get_args():
parser.add_argument("--log-interval", type=int, default=100)
parser.add_argument(
"--load-buffer-name", type=str,
default="./expert_DQN_PongNoFrameskip-v4.hdf5",
)
default="./expert_DQN_PongNoFrameskip-v4.hdf5")
parser.add_argument(
"--device", type=str,
default="cuda" if torch.cuda.is_available() else "cpu",
)
default="cuda" if torch.cuda.is_available() else "cpu")
args = parser.parse_known_args()[0]
return args

@@ -81,33 +78,32 @@ def test_discrete_bcq(args=get_args()):
# model
feature_net = DQN(*args.state_shape, args.action_shape,
device=args.device, features_only=True).to(args.device)
policy_net = Actor(feature_net, args.action_shape,
hidden_sizes=args.hidden_sizes).to(args.device)
imitation_net = Actor(feature_net, args.action_shape,
hidden_sizes=args.hidden_sizes).to(args.device)
policy_net = Actor(
feature_net, args.action_shape, device=args.device,
hidden_sizes=args.hidden_sizes, softmax_output=False).to(args.device)
imitation_net = Actor(
feature_net, args.action_shape, device=args.device,
hidden_sizes=args.hidden_sizes, softmax_output=False).to(args.device)
optim = torch.optim.Adam(
set(policy_net.parameters()).union(imitation_net.parameters()),
lr=args.lr,
)
lr=args.lr)
# define policy
policy = DiscreteBCQPolicy(
policy_net, imitation_net, optim, args.gamma, args.n_step,
args.target_update_freq, args.eps_test,
args.unlikely_action_threshold, args.imitation_logits_penalty,
)
args.unlikely_action_threshold, args.imitation_logits_penalty)
# load a previous policy
if args.resume_path:
policy.load_state_dict(torch.load(
args.resume_path, map_location=args.device
))
args.resume_path, map_location=args.device))
print("Loaded agent from: ", args.resume_path)
# buffer
assert os.path.exists(args.load_buffer_name), \
"Please run atari_dqn.py first to get expert's data buffer."
if args.load_buffer_name.endswith('.pkl'):
buffer = pickle.load(open(args.load_buffer_name, "rb"))
elif args.load_buffer_name.endswith('.hdf5'):
buffer = ReplayBuffer.load_hdf5(args.load_buffer_name)
buffer = VectorReplayBuffer.load_hdf5(args.load_buffer_name)
else:
print(f"Unknown buffer format: {args.load_buffer_name}")
exit(0)
@@ -146,11 +142,9 @@ def watch():
exit(0)

result = offline_trainer(
policy, buffer, test_collector,
args.epoch, args.update_per_epoch, args.test_num, args.batch_size,
stop_fn=stop_fn, save_fn=save_fn, logger=logger,
log_interval=args.log_interval,
)
policy, buffer, test_collector, args.epoch,
args.update_per_epoch, args.test_num, args.batch_size,
stop_fn=stop_fn, save_fn=save_fn, logger=logger)

pprint.pprint(result)
watch()
25 changes: 21 additions & 4 deletions examples/atari/atari_c51.py
@@ -46,6 +46,7 @@ def get_args():
parser.add_argument('--resume-path', type=str, default=None)
parser.add_argument('--watch', default=False, action='store_true',
help='watch the play of pre-trained policy only')
parser.add_argument('--save-buffer-name', type=str, default=None)
return parser.parse_args()


@@ -128,13 +129,29 @@ def test_fn(epoch, env_step):

# watch agent's performance
def watch():
print("Testing agent ...")
print("Setup test envs ...")
policy.eval()
policy.set_eps(args.eps_test)
test_envs.seed(args.seed)
test_collector.reset()
result = test_collector.collect(n_episode=args.test_num, render=args.render)
pprint.pprint(result)
if args.save_buffer_name:
print(f"Generate buffer with size {args.buffer_size}")
buffer = VectorReplayBuffer(
args.buffer_size, buffer_num=len(test_envs),
ignore_obs_next=True, save_only_last_obs=True,
stack_num=args.frames_stack)
collector = Collector(policy, test_envs, buffer,
exploration_noise=True)
result = collector.collect(n_step=args.buffer_size)
print(f"Save buffer into {args.save_buffer_name}")
# Unfortunately, pickle will cause oom with 1M buffer size
buffer.save_hdf5(args.save_buffer_name)
else:
print("Testing agent ...")
test_collector.reset()
result = test_collector.collect(n_episode=args.test_num,
render=args.render)
rew = result["rews"].mean()
print(f'Mean reward (over {result["n/ep"]} episodes): {rew}')

if args.watch:
watch()
6 changes: 4 additions & 2 deletions examples/atari/atari_dqn.py
@@ -134,7 +134,8 @@ def watch():
args.buffer_size, buffer_num=len(test_envs),
ignore_obs_next=True, save_only_last_obs=True,
stack_num=args.frames_stack)
collector = Collector(policy, test_envs, buffer)
collector = Collector(policy, test_envs, buffer,
exploration_noise=True)
result = collector.collect(n_step=args.buffer_size)
print(f"Save buffer into {args.save_buffer_name}")
# Unfortunately, pickle will cause oom with 1M buffer size
@@ -144,7 +145,8 @@ def watch():
test_collector.reset()
result = test_collector.collect(n_episode=args.test_num,
render=args.render)
pprint.pprint(result)
rew = result["rews"].mean()
print(f'Mean reward (over {result["n/ep"]} episodes): {rew}')

if args.watch:
watch()
25 changes: 21 additions & 4 deletions examples/atari/atari_qrdqn.py
@@ -44,6 +44,7 @@ def get_args():
parser.add_argument('--resume-path', type=str, default=None)
parser.add_argument('--watch', default=False, action='store_true',
help='watch the play of pre-trained policy only')
parser.add_argument('--save-buffer-name', type=str, default=None)
return parser.parse_args()


@@ -126,13 +127,29 @@ def test_fn(epoch, env_step):

# watch agent's performance
def watch():
print("Testing agent ...")
print("Setup test envs ...")
policy.eval()
policy.set_eps(args.eps_test)
test_envs.seed(args.seed)
test_collector.reset()
result = test_collector.collect(n_episode=args.test_num, render=args.render)
pprint.pprint(result)
if args.save_buffer_name:
print(f"Generate buffer with size {args.buffer_size}")
buffer = VectorReplayBuffer(
args.buffer_size, buffer_num=len(test_envs),
ignore_obs_next=True, save_only_last_obs=True,
stack_num=args.frames_stack)
collector = Collector(policy, test_envs, buffer,
exploration_noise=True)
result = collector.collect(n_step=args.buffer_size)
print(f"Save buffer into {args.save_buffer_name}")
# Unfortunately, pickle will cause oom with 1M buffer size
buffer.save_hdf5(args.save_buffer_name)
else:
print("Testing agent ...")
test_collector.reset()
result = test_collector.collect(n_episode=args.test_num,
render=args.render)
rew = result["rews"].mean()
print(f'Mean reward (over {result["n/ep"]} episodes): {rew}')

if args.watch:
watch()
6 changes: 4 additions & 2 deletions test/discrete/test_dqn.py
@@ -136,9 +136,11 @@ def test_fn(epoch, env_step):

# save buffer in pickle format, for imitation learning unittest
buf = VectorReplayBuffer(args.buffer_size, buffer_num=len(test_envs))
collector = Collector(policy, test_envs, buf)
collector.collect(n_step=args.buffer_size)
policy.set_eps(0.2)
collector = Collector(policy, test_envs, buf, exploration_noise=True)
result = collector.collect(n_step=args.buffer_size)
pickle.dump(buf, open(args.save_buffer_name, "wb"))
print(result["rews"].mean())


def test_pdqn(args=get_args()):
10 changes: 6 additions & 4 deletions test/discrete/test_il_bcq.py
@@ -21,16 +21,16 @@ def get_args():
parser.add_argument("--seed", type=int, default=1626)
parser.add_argument("--eps-test", type=float, default=0.001)
parser.add_argument("--lr", type=float, default=3e-4)
parser.add_argument("--gamma", type=float, default=0.9)
parser.add_argument("--gamma", type=float, default=0.99)
parser.add_argument("--n-step", type=int, default=3)
parser.add_argument("--target-update-freq", type=int, default=320)
parser.add_argument("--unlikely-action-threshold", type=float, default=0.3)
parser.add_argument("--unlikely-action-threshold", type=float, default=0.6)
parser.add_argument("--imitation-logits-penalty", type=float, default=0.01)
parser.add_argument("--epoch", type=int, default=5)
parser.add_argument("--update-per-epoch", type=int, default=1000)
parser.add_argument("--update-per-epoch", type=int, default=2000)
parser.add_argument("--batch-size", type=int, default=64)
parser.add_argument('--hidden-sizes', type=int,
nargs='*', default=[128, 128, 128])
nargs='*', default=[64, 64])
parser.add_argument("--test-num", type=int, default=100)
parser.add_argument("--logdir", type=str, default="log")
parser.add_argument("--render", type=float, default=0.)
@@ -49,6 +49,8 @@ def get_args():
def test_discrete_bcq(args=get_args()):
# envs
env = gym.make(args.task)
if args.task == 'CartPole-v0':
env.spec.reward_threshold = 190 # lower the goal
args.state_shape = env.observation_space.shape or env.observation_space.n
args.action_shape = env.action_space.shape or env.action_space.n
test_envs = DummyVectorEnv(
6 changes: 3 additions & 3 deletions tianshou/policy/imitation/discrete_bcq.py
@@ -118,7 +118,7 @@ def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]:

return {
"loss": loss.item(),
"q_loss": q_loss.item(),
"i_loss": i_loss.item(),
"reg_loss": reg_loss.item(),
"loss/q": q_loss.item(),
"loss/i": i_loss.item(),
"loss/reg": reg_loss.item(),
}