Add Rainbow DQN #386
Merged
Commits (40)
8c0a4c0  implement Rainbow DQN
9d8f565  make linter happy
45d8bb1  make mypy happy
46ed079  fix a bug about #381
403404e  address review comments (Trinkle23897)
1a44548  add a test for rainbow
a22f474  Merge branch 'master' into rainbow
2762945  fix documentation (Trinkle23897)
7f7b136  control the timing of sampling noises
05ac8f2  fix a bug in noisy linear
45874a4  fix doc and test
14c9c18  update exp results
5e6c46d  make pydocstyle happy (nuance1979)
03d3f73  minor fix (nuance1979)
f85a584  minor fix about sample_noise on model_old
c2c12ce  remove eps hack in prio buffer
f9d4347  revert eps hack and scale weights instead
7900450  remove weight scaling by magic number in favor of weight normalization
1ea40a1  fix test failure
3772c0f  use np.max() to maximize compatibility
42b4023  move weight norm to the policy side
104d476  move weight norm back to buffer side as an option
a3fc666  anneal beta parameter of prio buffer
18c1391  cosmetic change
5641d37  change beta annealing schedule
9a458d0  update current rainbow results; still bad on some tasks
f178b0e  fix a minor bug
d16dbb9  Merge branch 'master' into rainbow
0ed3f21  separate log dirs (Trinkle23897)
a204fda  Merge branch 'master' into rainbow
4cb94f2  Merge branch 'rainbow' of https://github.com/nuance1979/tianshou into… (nuance1979)
ed8552f  update results (nuance1979)
96a5b86  update plots
8599d1e  Merge branch 'master' into rainbow (nuance1979)
2211296  fix test failure (nuance1979)
f2384eb  fix test failure again (nuance1979)
5ecf6d3  fix more test failure (nuance1979)
4d2debf  fix a bug about explore_noise (nuance1979)
c946542  update plots
a040a6d  make linter happy (nuance1979)
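
Several commits above concern the noisy-linear layers and when their noise is resampled ("control the timing of sampling noises", "fix a bug in noisy linear", "minor fix about sample_noise on model_old"). For readers unfamiliar with noisy networks, here is a minimal, self-contained sketch of a factorized-Gaussian noisy linear layer in the style of Fortunato et al.; it is illustrative only and is not the implementation added by this PR. The class name NoisyLinearSketch and its reset_noise() hook are invented for the example.

# Illustrative sketch only; NOT the layer added in this PR.
import math
import torch
import torch.nn as nn
import torch.nn.functional as F


class NoisyLinearSketch(nn.Module):
    """y = (mu_w + sigma_w * eps_w) x + (mu_b + sigma_b * eps_b), factorized noise."""

    def __init__(self, in_features: int, out_features: int, std_init: float = 0.1):
        super().__init__()
        self.in_features, self.out_features = in_features, out_features
        self.weight_mu = nn.Parameter(torch.empty(out_features, in_features))
        self.weight_sigma = nn.Parameter(torch.empty(out_features, in_features))
        self.bias_mu = nn.Parameter(torch.empty(out_features))
        self.bias_sigma = nn.Parameter(torch.empty(out_features))
        # noise lives in buffers and is resampled explicitly via reset_noise(),
        # not on every forward pass
        self.register_buffer("weight_eps", torch.zeros(out_features, in_features))
        self.register_buffer("bias_eps", torch.zeros(out_features))
        bound = 1.0 / math.sqrt(in_features)
        nn.init.uniform_(self.weight_mu, -bound, bound)
        nn.init.uniform_(self.bias_mu, -bound, bound)
        nn.init.constant_(self.weight_sigma, std_init / math.sqrt(in_features))
        nn.init.constant_(self.bias_sigma, std_init / math.sqrt(out_features))
        self.reset_noise()

    @staticmethod
    def _scaled_noise(size: int) -> torch.Tensor:
        # f(x) = sign(x) * sqrt(|x|), the usual factorized-noise transform
        x = torch.randn(size)
        return x.sign() * x.abs().sqrt()

    def reset_noise(self) -> None:
        # factorized noise: one vector over inputs, one over outputs
        eps_in = self._scaled_noise(self.in_features)
        eps_out = self._scaled_noise(self.out_features)
        self.weight_eps.copy_(torch.outer(eps_out, eps_in))
        self.bias_eps.copy_(eps_out)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.training:
            weight = self.weight_mu + self.weight_sigma * self.weight_eps
            bias = self.bias_mu + self.bias_sigma * self.bias_eps
        else:  # deterministic weights at evaluation time
            weight, bias = self.weight_mu, self.bias_mu
        return F.linear(x, weight, bias)

In a setup like this, noise would typically be resampled once per gradient update, and on the target network before computing targets, rather than on every forward call; that appears to be the kind of timing question the commits above address.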
@@ -0,0 +1,204 @@
import os
import torch
import pprint
import datetime
import argparse
import numpy as np
from torch.utils.tensorboard import SummaryWriter

from tianshou.policy import RainbowPolicy
from tianshou.utils import BasicLogger
from tianshou.env import SubprocVectorEnv
from tianshou.trainer import offpolicy_trainer
from tianshou.data import Collector, PrioritizedVectorReplayBuffer, VectorReplayBuffer

from atari_network import Rainbow
from atari_wrapper import wrap_deepmind


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--task', type=str, default='PongNoFrameskip-v4')
    parser.add_argument('--seed', type=int, default=0)
    parser.add_argument('--eps-test', type=float, default=0.005)
    parser.add_argument('--eps-train', type=float, default=1.)
    parser.add_argument('--eps-train-final', type=float, default=0.05)
    parser.add_argument('--buffer-size', type=int, default=100000)
    parser.add_argument('--lr', type=float, default=0.0000625)
    parser.add_argument('--gamma', type=float, default=0.99)
    parser.add_argument('--num-atoms', type=int, default=51)
    parser.add_argument('--v-min', type=float, default=-10.)
    parser.add_argument('--v-max', type=float, default=10.)
    parser.add_argument('--noisy-std', type=float, default=0.1)
    parser.add_argument('--no-dueling', action='store_true', default=False)
    parser.add_argument('--no-noisy', action='store_true', default=False)
    parser.add_argument('--no-priority', action='store_true', default=False)
    parser.add_argument('--alpha', type=float, default=0.5)
    parser.add_argument('--beta', type=float, default=0.4)
    parser.add_argument('--beta-final', type=float, default=1.)
    parser.add_argument('--beta-anneal-step', type=int, default=5000000)
    parser.add_argument('--no-weight-norm', action='store_true', default=False)
    parser.add_argument('--n-step', type=int, default=3)
    parser.add_argument('--target-update-freq', type=int, default=500)
    parser.add_argument('--epoch', type=int, default=100)
    parser.add_argument('--step-per-epoch', type=int, default=100000)
    parser.add_argument('--step-per-collect', type=int, default=10)
    parser.add_argument('--update-per-step', type=float, default=0.1)
    parser.add_argument('--batch-size', type=int, default=32)
    parser.add_argument('--training-num', type=int, default=10)
    parser.add_argument('--test-num', type=int, default=10)
    parser.add_argument('--logdir', type=str, default='log')
    parser.add_argument('--render', type=float, default=0.)
    parser.add_argument(
        '--device', type=str,
        default='cuda' if torch.cuda.is_available() else 'cpu')
    parser.add_argument('--frames-stack', type=int, default=4)
    parser.add_argument('--resume-path', type=str, default=None)
    parser.add_argument('--watch', default=False, action='store_true',
                        help='watch the play of pre-trained policy only')
    parser.add_argument('--save-buffer-name', type=str, default=None)
    return parser.parse_args()


def make_atari_env(args):
    return wrap_deepmind(args.task, frame_stack=args.frames_stack)


def make_atari_env_watch(args):
    return wrap_deepmind(args.task, frame_stack=args.frames_stack,
                         episode_life=False, clip_rewards=False)


def test_rainbow(args=get_args()):
    env = make_atari_env(args)
    args.state_shape = env.observation_space.shape or env.observation_space.n
    args.action_shape = env.action_space.shape or env.action_space.n
    # should be N_FRAMES x H x W
    print("Observations shape:", args.state_shape)
    print("Actions shape:", args.action_shape)
    # make environments
    train_envs = SubprocVectorEnv([lambda: make_atari_env(args)
                                   for _ in range(args.training_num)])
    test_envs = SubprocVectorEnv([lambda: make_atari_env_watch(args)
                                  for _ in range(args.test_num)])
    # seed
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    train_envs.seed(args.seed)
    test_envs.seed(args.seed)
    # define model
    net = Rainbow(*args.state_shape, args.action_shape,
                  args.num_atoms, args.noisy_std, args.device,
                  is_dueling=not args.no_dueling,
                  is_noisy=not args.no_noisy)
    optim = torch.optim.Adam(net.parameters(), lr=args.lr)
    # define policy
    policy = RainbowPolicy(
        net, optim, args.gamma, args.num_atoms, args.v_min, args.v_max,
        args.n_step, target_update_freq=args.target_update_freq
    ).to(args.device)
    # load a previous policy
    if args.resume_path:
        policy.load_state_dict(torch.load(args.resume_path, map_location=args.device))
        print("Loaded agent from: ", args.resume_path)
    # replay buffer: `save_last_obs` and `stack_num` can be removed together
    # when you have enough RAM
    if args.no_priority:
        buffer = VectorReplayBuffer(
            args.buffer_size, buffer_num=len(train_envs), ignore_obs_next=True,
            save_only_last_obs=True, stack_num=args.frames_stack)
    else:
        buffer = PrioritizedVectorReplayBuffer(
            args.buffer_size, buffer_num=len(train_envs), ignore_obs_next=True,
            save_only_last_obs=True, stack_num=args.frames_stack, alpha=args.alpha,
            beta=args.beta, weight_norm=not args.no_weight_norm)
    # collector
    train_collector = Collector(policy, train_envs, buffer, exploration_noise=True)
    test_collector = Collector(policy, test_envs, exploration_noise=True)
    # log
    log_path = os.path.join(
        args.logdir, args.task, 'rainbow',
        f'seed_{args.seed}_{datetime.datetime.now().strftime("%m%d-%H%M%S")}')
    writer = SummaryWriter(log_path)
    writer.add_text("args", str(args))
    logger = BasicLogger(writer)

    def save_fn(policy):
        torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))

    def stop_fn(mean_rewards):
        if env.spec.reward_threshold:
            return mean_rewards >= env.spec.reward_threshold
        elif 'Pong' in args.task:
            return mean_rewards >= 20
        else:
            return False

    def train_fn(epoch, env_step):
        # nature DQN setting, linear decay in the first 1M steps
        if env_step <= 1e6:
            eps = args.eps_train - env_step / 1e6 * \
                (args.eps_train - args.eps_train_final)
        else:
            eps = args.eps_train_final
        policy.set_eps(eps)
        logger.write('train/eps', env_step, eps)
        if not args.no_priority:
            if env_step <= args.beta_anneal_step:
                beta = args.beta - env_step / args.beta_anneal_step * \
                    (args.beta - args.beta_final)
            else:
                beta = args.beta_final
            buffer.set_beta(beta)
            logger.write('train/beta', env_step, beta)

    def test_fn(epoch, env_step):
        policy.set_eps(args.eps_test)

    # watch agent's performance
    def watch():
        print("Setup test envs ...")
        policy.eval()
        policy.set_eps(args.eps_test)
        test_envs.seed(args.seed)
        if args.save_buffer_name:
            print(f"Generate buffer with size {args.buffer_size}")
            buffer = PrioritizedVectorReplayBuffer(
                args.buffer_size, buffer_num=len(test_envs),
                ignore_obs_next=True, save_only_last_obs=True,
                stack_num=args.frames_stack, alpha=args.alpha,
                beta=args.beta)
            collector = Collector(policy, test_envs, buffer,
                                  exploration_noise=True)
            result = collector.collect(n_step=args.buffer_size)
            print(f"Save buffer into {args.save_buffer_name}")
            # Unfortunately, pickle will cause oom with 1M buffer size
            buffer.save_hdf5(args.save_buffer_name)
        else:
            print("Testing agent ...")
            test_collector.reset()
            result = test_collector.collect(n_episode=args.test_num,
                                            render=args.render)
        rew = result["rews"].mean()
        print(f'Mean reward (over {result["n/ep"]} episodes): {rew}')

    if args.watch:
        watch()
        exit(0)

    # test train_collector and start filling replay buffer
    train_collector.collect(n_step=args.batch_size * args.training_num)
    # trainer
    result = offpolicy_trainer(
        policy, train_collector, test_collector, args.epoch,
        args.step_per_epoch, args.step_per_collect, args.test_num,
        args.batch_size, train_fn=train_fn, test_fn=test_fn,
        stop_fn=stop_fn, save_fn=save_fn, logger=logger,
        update_per_step=args.update_per_step, test_in_train=False)

    pprint.pprint(result)
    watch()


if __name__ == '__main__':
    test_rainbow(get_args())
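
The script above is the new Atari training example. Assuming it is saved as atari_rainbow.py alongside atari_network.py and atari_wrapper.py (the exact filename is not visible in this view), a typical invocation might look like:

python atari_rainbow.py --task PongNoFrameskip-v4
python atari_rainbow.py --task BreakoutNoFrameskip-v4 --no-priority --no-noisy

The --no-dueling, --no-noisy, and --no-priority flags disable individual Rainbow components; --no-weight-norm turns off the importance-weight normalization that the commits above introduced in place of a fixed scaling factor; --watch skips training and only evaluates the current policy, usually one loaded with --resume-path.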