# Add discrete Conservative Q-Learning for offline RL #359
@@ -67,4 +67,35 @@ We test our BCQ implementation on two example tasks (different from author's ver

| Task                   | Online DQN | Behavioral | BCQ                               |
| ---------------------- | ---------- | ---------- | --------------------------------- |
| PongNoFrameskip-v4     | 21         | 7.7        | 21 (epoch 5)                      |
| BreakoutNoFrameskip-v4 | 303        | 61         | 167.4 (epoch 12, could be higher) |
# CQL

To run the CQL algorithm on Atari, you need to do the following things:

- Train an expert by using the command listed in the above QRDQN section;
- Generate buffer with noise: `python3 atari_qrdqn.py --task {your_task} --watch --resume-path log/{your_task}/qrdqn/policy.pth --eps-test 0.2 --buffer-size 1000000 --save-buffer-name expert.hdf5` (note that a 1M Atari buffer cannot be saved in `.pkl` format because it is too large and will cause an error; a quick sanity check of the saved file is sketched after this list);
- Train CQL: `python3 atari_cql.py --task {your_task} --load-buffer-name expert.hdf5`.
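Before training, it can be worth sanity-checking the generated buffer. The snippet below is only a small sketch (the file name is a placeholder): it loads the saved data back the same way `atari_cql.py` does and prints how many transitions it contains.

```python
from tianshou.data import VectorReplayBuffer

# load the saved expert data the same way atari_cql.py does
buf = VectorReplayBuffer.load_hdf5("expert.hdf5")
# number of stored transitions; should be close to --buffer-size (1M)
print(len(buf))
```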
We test our CQL implementation on two example tasks (different from the author's version, we use v4 instead of v0; one epoch means 10k gradient steps):

| Task                   | Online QRDQN | Behavioral | CQL              | parameters |
| ---------------------- | ------------ | ---------- | ---------------- | ---------- |
| PongNoFrameskip-v4     | 20.5         | 6.8        | 19.5 (epoch 5)   | `python3 atari_cql.py --task "PongNoFrameskip-v4" --load-buffer-name log/PongNoFrameskip-v4/qrdqn/expert.hdf5 --epoch 5` |
| BreakoutNoFrameskip-v4 | 394.3        | 46.9       | 248.3 (epoch 12) | `python3 atari_cql.py --task "BreakoutNoFrameskip-v4" --load-buffer-name log/BreakoutNoFrameskip-v4/qrdqn/expert.hdf5 --epoch 12 --min-q-weight 50` |

We reduce the size of the offline data to 10% and 1% of the above and get:

Buffer size 100000:

| Task                   | Online QRDQN | Behavioral | CQL             | parameters |
| ---------------------- | ------------ | ---------- | --------------- | ---------- |
| PongNoFrameskip-v4     | 20.5         | 5.8        | 21 (epoch 5)    | `python3 atari_cql.py --task "PongNoFrameskip-v4" --load-buffer-name log/PongNoFrameskip-v4/qrdqn/expert.size_1e5.hdf5 --epoch 5` |
| BreakoutNoFrameskip-v4 | 394.3        | 41.4       | 40.8 (epoch 12) | `python3 atari_cql.py --task "BreakoutNoFrameskip-v4" --load-buffer-name log/BreakoutNoFrameskip-v4/qrdqn/expert.size_1e5.hdf5 --epoch 12 --min-q-weight 20` |

Buffer size 10000:

| Task                   | Online QRDQN | Behavioral | CQL             | parameters |
| ---------------------- | ------------ | ---------- | --------------- | ---------- |
| PongNoFrameskip-v4     | 20.5         | nan        | 1.8 (epoch 5)   | `python3 atari_cql.py --task "PongNoFrameskip-v4" --load-buffer-name log/PongNoFrameskip-v4/qrdqn/expert.size_1e4.hdf5 --epoch 5 --min-q-weight 1` |
| BreakoutNoFrameskip-v4 | 394.3        | 31.7       | 22.5 (epoch 12) | `python3 atari_cql.py --task "BreakoutNoFrameskip-v4" --load-buffer-name log/BreakoutNoFrameskip-v4/qrdqn/expert.size_1e4.hdf5 --epoch 12 --min-q-weight 10` |
Comment on lines +100 to +101

There's no need to stop at epoch 5 and 12 -- that's too small and has not converged yet. In fact, the original paper indicates that they use ~25 iterations to get the optimal performance in Breakout; is that (25*1M/64/10000 = 39) epochs in our setting?

I understand that epochs 5 and 12 are arbitrary. I just wanted to compare roughly with the BCQ results above. I reran the two Breakout runs to epoch 40 and the test results were worse than at epoch 12. (However, I do think the evaluation protocol of 10 episodes is too small to smooth out the randomness.) The learning curves show that the loss seemed to be behaving, but the test rewards were all over the place. I double-checked my implementation against the reference but couldn't find any errors. (It was essentially 3 lines of code on top of QRDQN anyway.)

Is this tensorboard or some other visualization tool? I'm quite curious.

It's tensorboard. I usually run it like this:

I don't know. I had many runs in my directory, so tensorboard used all kinds of colors, and it just happened that the two runs that I cared about had nice colors.
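For context, the "3 lines of code on top of QRDQN" mentioned above correspond to the discrete CQL(H) regularizer. The sketch below is only an illustration of that term, not the actual `DiscreteCQLPolicy` code; the tensor names are assumptions, and in the quantile setting `q` would typically be the quantile-averaged Q-values.

```python
import torch


def cql_regularized_loss(q, act, td_loss, min_q_weight=10.0):
    """Add the discrete CQL(H) penalty to an existing Q-learning loss.

    q:        (batch, num_actions) Q-values from the network
    act:      (batch,) actions taken in the offline dataset
    td_loss:  the usual (QR-)DQN temporal-difference loss, computed elsewhere
    """
    # logsumexp pushes down Q-values of all actions, while the gathered term
    # pushes up the Q-values of actions actually present in the dataset
    cql_term = (torch.logsumexp(q, dim=1) -
                q.gather(1, act.unsqueeze(1)).squeeze(1)).mean()
    # the --min-q-weight flag in the scripts above corresponds to this scaling
    return td_loss + min_q_weight * cql_term
```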
@@ -0,0 +1,147 @@
import os
import torch
import pickle
import pprint
import datetime
import argparse
import numpy as np
from torch.utils.tensorboard import SummaryWriter

from tianshou.utils import BasicLogger
from tianshou.env import SubprocVectorEnv
from tianshou.trainer import offline_trainer
from tianshou.policy import DiscreteCQLPolicy
from tianshou.data import Collector, VectorReplayBuffer

from atari_network import QRDQN
from atari_wrapper import wrap_deepmind


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--task", type=str, default="PongNoFrameskip-v4")
    parser.add_argument("--seed", type=int, default=1626)
    parser.add_argument("--eps-test", type=float, default=0.001)
    parser.add_argument("--lr", type=float, default=0.0001)
    parser.add_argument("--gamma", type=float, default=0.99)
    parser.add_argument('--num-quantiles', type=int, default=200)
    parser.add_argument("--n-step", type=int, default=1)
    parser.add_argument("--target-update-freq", type=int, default=500)
    parser.add_argument("--min-q-weight", type=float, default=10.)
    parser.add_argument("--epoch", type=int, default=100)
    parser.add_argument("--update-per-epoch", type=int, default=10000)
    parser.add_argument("--batch-size", type=int, default=32)
    parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[512])
    parser.add_argument("--test-num", type=int, default=10)
    parser.add_argument('--frames-stack', type=int, default=4)
    parser.add_argument("--logdir", type=str, default="log")
    parser.add_argument("--render", type=float, default=0.)
    parser.add_argument("--resume-path", type=str, default=None)
    parser.add_argument("--watch", default=False, action="store_true",
                        help="watch the play of pre-trained policy only")
    parser.add_argument("--log-interval", type=int, default=100)
    parser.add_argument(
        "--load-buffer-name", type=str,
        default="./expert_DQN_PongNoFrameskip-v4.hdf5")
    parser.add_argument(
        "--device", type=str,
        default="cuda" if torch.cuda.is_available() else "cpu")
    args = parser.parse_known_args()[0]
    return args


def make_atari_env(args):
    return wrap_deepmind(args.task, frame_stack=args.frames_stack)


def make_atari_env_watch(args):
    return wrap_deepmind(args.task, frame_stack=args.frames_stack,
                         episode_life=False, clip_rewards=False)


def test_discrete_cql(args=get_args()):
    # envs
    env = make_atari_env(args)
    args.state_shape = env.observation_space.shape or env.observation_space.n
    args.action_shape = env.action_space.shape or env.action_space.n
    # should be N_FRAMES x H x W
    print("Observations shape:", args.state_shape)
    print("Actions shape:", args.action_shape)
    # make environments
    test_envs = SubprocVectorEnv([lambda: make_atari_env_watch(args)
                                  for _ in range(args.test_num)])
    # seed
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    test_envs.seed(args.seed)
    # model
    net = QRDQN(*args.state_shape, args.action_shape,
                args.num_quantiles, args.device)
    optim = torch.optim.Adam(net.parameters(), lr=args.lr)
    # define policy
    policy = DiscreteCQLPolicy(
        net, optim, args.gamma, args.num_quantiles, args.n_step,
        args.target_update_freq, min_q_weight=args.min_q_weight
    ).to(args.device)
    # load a previous policy
    if args.resume_path:
        policy.load_state_dict(torch.load(
            args.resume_path, map_location=args.device))
        print("Loaded agent from: ", args.resume_path)
    # buffer
    assert os.path.exists(args.load_buffer_name), \
        "Please run atari_qrdqn.py first to get expert's data buffer."
    if args.load_buffer_name.endswith('.pkl'):
        buffer = pickle.load(open(args.load_buffer_name, "rb"))
    elif args.load_buffer_name.endswith('.hdf5'):
        buffer = VectorReplayBuffer.load_hdf5(args.load_buffer_name)
    else:
        print(f"Unknown buffer format: {args.load_buffer_name}")
        exit(0)

    # collector
    test_collector = Collector(policy, test_envs, exploration_noise=True)

    # log
    log_path = os.path.join(
        args.logdir, args.task, 'cql',
        f'seed_{args.seed}_{datetime.datetime.now().strftime("%m%d-%H%M%S")}')
    writer = SummaryWriter(log_path)
    writer.add_text("args", str(args))
    logger = BasicLogger(writer, update_interval=args.log_interval)

    def save_fn(policy):
        torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))

    def stop_fn(mean_rewards):
        return False

    # watch agent's performance
    def watch():
        print("Setup test envs ...")
        policy.eval()
        policy.set_eps(args.eps_test)
        test_envs.seed(args.seed)
        print("Testing agent ...")
        test_collector.reset()
        result = test_collector.collect(n_episode=args.test_num,
                                        render=args.render)
        pprint.pprint(result)
        rew = result["rews"].mean()
        print(f'Mean reward (over {result["n/ep"]} episodes): {rew}')

    if args.watch:
        watch()
        exit(0)

    result = offline_trainer(
        policy, buffer, test_collector, args.epoch,
        args.update_per_epoch, args.test_num, args.batch_size,
        stop_fn=stop_fn, save_fn=save_fn, logger=logger)

    pprint.pprint(result)
    watch()


if __name__ == "__main__":
    test_discrete_cql(get_args())
@@ -0,0 +1,110 @@
import os
import gym
import torch
import pickle
import pprint
import argparse
import numpy as np
from torch.utils.tensorboard import SummaryWriter

from tianshou.data import Collector
from tianshou.utils import BasicLogger
from tianshou.env import DummyVectorEnv
from tianshou.utils.net.common import Net
from tianshou.trainer import offline_trainer
from tianshou.policy import DiscreteCQLPolicy


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--task", type=str, default="CartPole-v0")
    parser.add_argument("--seed", type=int, default=1626)
    parser.add_argument("--eps-test", type=float, default=0.001)
    parser.add_argument("--lr", type=float, default=7e-4)
    parser.add_argument("--gamma", type=float, default=0.99)
    parser.add_argument('--num-quantiles', type=int, default=200)
    parser.add_argument("--n-step", type=int, default=3)
    parser.add_argument("--target-update-freq", type=int, default=320)
    parser.add_argument("--min-q-weight", type=float, default=10.)
    parser.add_argument("--epoch", type=int, default=5)
    parser.add_argument("--update-per-epoch", type=int, default=1000)
    parser.add_argument("--batch-size", type=int, default=64)
    parser.add_argument('--hidden-sizes', type=int,
                        nargs='*', default=[64, 64])
    parser.add_argument("--test-num", type=int, default=100)
    parser.add_argument("--logdir", type=str, default="log")
    parser.add_argument("--render", type=float, default=0.)
    parser.add_argument(
        "--load-buffer-name", type=str,
        default="./expert_QRDQN_CartPole-v0.pkl",
    )
    parser.add_argument(
        "--device", type=str,
        default="cuda" if torch.cuda.is_available() else "cpu",
    )
    args = parser.parse_known_args()[0]
    return args


def test_discrete_cql(args=get_args()):
    # envs
    env = gym.make(args.task)
    if args.task == 'CartPole-v0':
        env.spec.reward_threshold = 190  # lower the goal
    args.state_shape = env.observation_space.shape or env.observation_space.n
    args.action_shape = env.action_space.shape or env.action_space.n
    test_envs = DummyVectorEnv(
        [lambda: gym.make(args.task) for _ in range(args.test_num)])
    # seed
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    test_envs.seed(args.seed)
    # model
    net = Net(args.state_shape, args.action_shape,
              hidden_sizes=args.hidden_sizes, device=args.device,
              softmax=False, num_atoms=args.num_quantiles)
    optim = torch.optim.Adam(net.parameters(), lr=args.lr)

    policy = DiscreteCQLPolicy(
        net, optim, args.gamma, args.num_quantiles, args.n_step,
        args.target_update_freq, min_q_weight=args.min_q_weight
    ).to(args.device)
    # buffer
    assert os.path.exists(args.load_buffer_name), \
        "Please run test_qrdqn.py first to get expert's data buffer."
    buffer = pickle.load(open(args.load_buffer_name, "rb"))

    # collector
    test_collector = Collector(policy, test_envs, exploration_noise=True)

    log_path = os.path.join(args.logdir, args.task, 'discrete_cql')
    writer = SummaryWriter(log_path)
    logger = BasicLogger(writer)

    def save_fn(policy):
        torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))

    def stop_fn(mean_rewards):
        return mean_rewards >= env.spec.reward_threshold

    result = offline_trainer(
        policy, buffer, test_collector,
        args.epoch, args.update_per_epoch, args.test_num, args.batch_size,
        stop_fn=stop_fn, save_fn=save_fn, logger=logger)

    assert stop_fn(result['best_reward'])

    if __name__ == '__main__':
        pprint.pprint(result)
        # Let's watch its performance!
        env = gym.make(args.task)
        policy.eval()
        policy.set_eps(args.eps_test)
        collector = Collector(policy, env)
        result = collector.collect(n_episode=1, render=args.render)
        rews, lens = result["rews"], result["lens"]
        print(f"Final reward: {rews.mean()}, length: {lens.mean()}")


if __name__ == "__main__":
    test_discrete_cql(get_args())
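For reference, the `expert_QRDQN_CartPole-v0.pkl` buffer expected above can be produced by rolling out an already trained QRDQN policy into a replay buffer and pickling it. The helper below is only a sketch: the function name, buffer size, epsilon, and output path are illustrative, and it assumes a trained `policy` object (with `set_eps`) is passed in.

```python
import pickle

import gym

from tianshou.data import Collector, ReplayBuffer


def save_expert_buffer(policy, task="CartPole-v0", size=50000, eps=0.2,
                       path="./expert_QRDQN_CartPole-v0.pkl"):
    """Collect `size` transitions with a trained policy and pickle the buffer."""
    policy.eval()
    policy.set_eps(eps)  # add some noise so the data is not purely expert
    buf = ReplayBuffer(size)
    collector = Collector(policy, gym.make(task), buf, exploration_noise=True)
    collector.collect(n_step=size)
    pickle.dump(buf, open(path, "wb"))
```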
I'd like to suggest using `--eps-test 0.9` or `--eps-test 0.99` when generating the buffer to see what happens, because in the original paper the author used 1% and 10% of the expert data to train CQL and got a good result (Table 3).
I've tried `--eps-test 0.9` on Pong and Breakout. Pong can easily achieve +20 reward but it is not very stable; Breakout cannot achieve >20 reward when `min-q-weight` is either 10 or 50. I also tested with only the first 10% of random data instead of the mixed 10% data, and the results are the same. Could you please help check what's going wrong?
I tried `--eps-test 0.9` and `--eps-test 0.99` and got terrible results, basically failing to learn. Then I checked this paper where the data were generated and found in Section 6:

> We train offline QR-DQN and REM with reduced data obtained via randomly subsampling entire trajectories from the logged DQN experiences, thereby maintaining the same data distribution. Figure 6 presents the performance of the offline REM and QR-DQN agents with N% of the tuples in the DQN replay dataset where N ∈ {1, 10, 20, 50, 100}.

So they were simply using less data, as opposed to worse data. I therefore generated 1% and 10% of the 1M data, i.e., 10k and 100k transitions, and tuned the parameters a little. Indeed, smaller datasets need a smaller `--min-q-weight` to work. (I initially tried to sample from my 1M buffer data but couldn't get the format right.) These results are within expectations, IMHO.
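Following the trajectory-subsampling recipe quoted above, one way to build the reduced datasets while keeping the same data distribution is sketched below. It assumes the transitions have already been exported to flat NumPy arrays; the array names and the 10% ratio are illustrative and not the exact procedure used to produce the `expert.size_1e5.hdf5` / `expert.size_1e4.hdf5` files.

```python
import numpy as np


def subsample_trajectories(obs, act, rew, done, obs_next, ratio=0.1, seed=0):
    """Randomly keep `ratio` of whole episodes, preserving the data distribution."""
    rng = np.random.default_rng(seed)
    # each done=True closes one trajectory; keep a trailing unfinished episode too
    ends = np.flatnonzero(done) + 1
    if len(ends) == 0 or ends[-1] != len(done):
        ends = np.append(ends, len(done))
    starts = np.concatenate(([0], ends[:-1]))
    keep = rng.choice(len(starts), size=max(1, int(len(starts) * ratio)),
                      replace=False)
    idx = np.concatenate([np.arange(starts[i], ends[i]) for i in keep])
    return obs[idx], act[idx], rew[idx], done[idx], obs_next[idx]
```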