
Add discrete Conservative Q-Learning for offline RL #359


Merged: 13 commits, May 12, 2021
1 change: 1 addition & 0 deletions README.md
@@ -35,6 +35,7 @@
- [Discrete Soft Actor-Critic (SAC-Discrete)](https://arxiv.org/pdf/1910.07207.pdf)
- Vanilla Imitation Learning
- [Discrete Batch-Constrained deep Q-Learning (BCQ-Discrete)](https://arxiv.org/pdf/1910.01708.pdf)
- [Discrete Conservative Q-Learning (CQL-Discrete)](https://arxiv.org/pdf/2006.04779.pdf)
- [Prioritized Experience Replay (PER)](https://arxiv.org/pdf/1511.05952.pdf)
- [Generalized Advantage Estimator (GAE)](https://arxiv.org/pdf/1506.02438.pdf)
- [Posterior Sampling Reinforcement Learning (PSRL)](https://www.ece.uvic.ca/~bctill/papers/learning/Strens_2000.pdf)
5 changes: 5 additions & 0 deletions docs/api/tianshou.policy.rst
@@ -99,6 +99,11 @@ Imitation
:undoc-members:
:show-inheritance:

.. autoclass:: tianshou.policy.DiscreteCQLPolicy
:members:
:undoc-members:
:show-inheritance:

Model-based
-----------

1 change: 1 addition & 0 deletions docs/index.rst
@@ -25,6 +25,7 @@ Welcome to Tianshou!
* :class:`~tianshou.policy.DiscreteSACPolicy` `Discrete Soft Actor-Critic <https://arxiv.org/pdf/1910.07207.pdf>`_
* :class:`~tianshou.policy.ImitationPolicy` Imitation Learning
* :class:`~tianshou.policy.DiscreteBCQPolicy` `Discrete Batch-Constrained deep Q-Learning <https://arxiv.org/pdf/1910.01708.pdf>`_
* :class:`~tianshou.policy.DiscreteCQLPolicy` `Discrete Conservative Q-Learning <https://arxiv.org/pdf/2006.04779.pdf>`_
* :class:`~tianshou.policy.PSRLPolicy` `Posterior Sampling Reinforcement Learning <https://www.ece.uvic.ca/~bctill/papers/learning/Strens_2000.pdf>`_
* :class:`~tianshou.data.PrioritizedReplayBuffer` `Prioritized Experience Replay <https://arxiv.org/pdf/1511.05952.pdf>`_
* :meth:`~tianshou.policy.BasePolicy.compute_episodic_return` `Generalized Advantage Estimator <https://arxiv.org/pdf/1506.02438.pdf>`_
33 changes: 32 additions & 1 deletion examples/atari/README.md
@@ -67,4 +67,35 @@ We test our BCQ implementation on two example tasks (different from author's ver
| Task | Online DQN | Behavioral | BCQ |
| ---------------------- | ---------- | ---------- | --------------------------------- |
| PongNoFrameskip-v4 | 21 | 7.7 | 21 (epoch 5) |
| BreakoutNoFrameskip-v4 | 303 | 61 | 167.4 (epoch 12, could be higher) |

# CQL

To run the CQL algorithm on Atari, you need to do the following:

- Train an expert, using the command listed in the QRDQN section above;
- Generate a buffer with noise: `python3 atari_qrdqn.py --task {your_task} --watch --resume-path log/{your_task}/qrdqn/policy.pth --eps-test 0.2 --buffer-size 1000000 --save-buffer-name expert.hdf5` (note that a 1M Atari buffer cannot be saved in `.pkl` format because it is too large and will cause an error);
- Train CQL: `python3 atari_cql.py --task {your_task} --load-buffer-name expert.hdf5` (a condensed Python sketch of this pipeline follows).
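For orientation, here is a condensed sketch of what the new `atari_cql.py` script does end to end. It mirrors that script's defaults and positional argument order, but treat it as an illustrative outline rather than a drop-in replacement:

```python
import torch

from atari_network import QRDQN
from atari_wrapper import wrap_deepmind
from tianshou.data import Collector, VectorReplayBuffer
from tianshou.env import SubprocVectorEnv
from tianshou.policy import DiscreteCQLPolicy
from tianshou.trainer import offline_trainer

task = "PongNoFrameskip-v4"
device = "cuda" if torch.cuda.is_available() else "cpu"

env = wrap_deepmind(task, frame_stack=4)
state_shape = env.observation_space.shape  # N_FRAMES x H x W
action_shape = env.action_space.n

net = QRDQN(*state_shape, action_shape, 200, device)  # 200 quantiles
optim = torch.optim.Adam(net.parameters(), lr=1e-4)
# gamma, num_quantiles, n_step, target_update_freq, then the CQL regularizer weight
policy = DiscreteCQLPolicy(net, optim, 0.99, 200, 1, 500, min_q_weight=10.0).to(device)

# offline data generated by atari_qrdqn.py --watch ... --save-buffer-name expert.hdf5
buffer = VectorReplayBuffer.load_hdf5("expert.hdf5")

test_envs = SubprocVectorEnv([
    lambda: wrap_deepmind(task, frame_stack=4, episode_life=False, clip_rewards=False)
    for _ in range(10)])
test_collector = Collector(policy, test_envs, exploration_noise=True)

# epoch, update_per_epoch, episode_per_test, batch_size, passed positionally as in the script
result = offline_trainer(policy, buffer, test_collector, 5, 10000, 10, 32)
print(result)
```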

We test our CQL implementation on two example tasks (different from the author's version, we use v4 instead of v0; one epoch means 10k gradient steps):

| Task | Online QRDQN | Behavioral | CQL | parameters |
| ---------------------- | ---------- | ---------- | --------------------------------- | ------------------------------------------------------------ |
| PongNoFrameskip-v4 | 20.5 | 6.8 | 19.5 (epoch 5) | `python3 atari_cql.py --task "PongNoFrameskip-v4" --load-buffer-name log/PongNoFrameskip-v4/qrdqn/expert.hdf5 --epoch 5` |
| BreakoutNoFrameskip-v4 | 394.3 | 46.9 | 248.3 (epoch 12) | `python3 atari_cql.py --task "BreakoutNoFrameskip-v4" --load-buffer-name log/BreakoutNoFrameskip-v4/qrdqn/expert.hdf5 --epoch 12 --min-q-weight 50` |
Comment on lines +84 to +85
Collaborator @Trinkle23897 commented on May 7, 2021:

I'd like to suggest using `--eps-test 0.9` or `--eps-test 0.99` when generating the buffer to see what happens, because in the original paper the authors used 1% and 10% of the expert data to train CQL and got good results (Table 3).

Collaborator @Trinkle23897 commented on May 10, 2021:

I've tried `--eps-test 0.9` on Pong and Breakout. Pong can easily achieve +20 reward but it is not very stable; Breakout cannot achieve >20 reward whether `min-q-weight` is 10 or 50. I also tested with only the first 10% of the random data instead of a mixed 10% of the data; the results are the same. Could you please help check what's going wrong?

Collaborator (PR author) replied:

I tried `--eps-test 0.9` and `--eps-test 0.99` and got terrible results; it basically failed to learn. Then I checked this paper, where the data were generated, and found in Section 6: "We train offline QR-DQN and REM with reduced data obtained via randomly subsampling entire trajectories from the logged DQN experiences, thereby maintaining the same data distribution. Figure 6 presents the performance of the offline REM and QR-DQN agents with N% of the tuples in the DQN replay dataset where N ∈ {1, 10, 20, 50, 100}." So they were simply using less data, as opposed to using worse data. I therefore generated 1% and 10% of the 1M data, i.e., 10k and 100k transitions, and tuned the parameters a little. Indeed, smaller datasets need a smaller `--min-q-weight` to work. (I initially tried to subsample from my 1M buffer data but couldn't get the format right.) These results are within expectations, IMHO.


We reduce the size of the offline data to 10% and 1% of the above and get:

Buffer size 100000:

| Task | Online QRDQN | Behavioral | CQL | parameters |
| ---------------------- | ---------- | ---------- | --------------------------------- | ------------------------------------------------------------ |
| PongNoFrameskip-v4 | 20.5 | 5.8 | 21 (epoch 5) | `python3 atari_cql.py --task "PongNoFrameskip-v4" --load-buffer-name log/PongNoFrameskip-v4/qrdqn/expert.size_1e5.hdf5 --epoch 5` |
| BreakoutNoFrameskip-v4 | 394.3 | 41.4 | 40.8 (epoch 12) | `python3 atari_cql.py --task "BreakoutNoFrameskip-v4" --load-buffer-name log/BreakoutNoFrameskip-v4/qrdqn/expert.size_1e5.hdf5 --epoch 12 --min-q-weight 20` |

Buffer size 10000:

| Task | Online QRDQN | Behavioral | CQL | parameters |
| ---------------------- | ---------- | ---------- | --------------------------------- | ------------------------------------------------------------ |
| PongNoFrameskip-v4 | 20.5 | nan | 1.8 (epoch 5) | `python3 atari_cql.py --task "PongNoFrameskip-v4" --load-buffer-name log/PongNoFrameskip-v4/qrdqn/expert.size_1e4.hdf5 --epoch 5 --min-q-weight 1` |
| BreakoutNoFrameskip-v4 | 394.3 | 31.7 | 22.5 (epoch 12) | `python3 atari_cql.py --task "BreakoutNoFrameskip-v4" --load-buffer-name log/BreakoutNoFrameskip-v4/qrdqn/expert.size_1e4.hdf5 --epoch 12 --min-q-weight 10` |
Comment on lines +100 to +101
Collaborator commented:

There's no need to stop at epoch 5 and 12; that's too few epochs and the policy hasn't converged yet. In fact, the original paper indicates that they use ~25 iterations to reach the optimal performance on Breakout. Would that be roughly 25*1M/64/10000 ≈ 39 epochs in our setting?
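(A quick check of that arithmetic, assuming one "iteration" corresponds to one pass over the 1M-transition buffer with batch size 64, and that one epoch here means 10k gradient updates:)

```python
grad_steps = 25 * 1_000_000 / 64  # ~390,625 gradient updates for ~25 passes over 1M transitions
epochs = grad_steps / 10_000      # one epoch in this repo's setting = 10k updates
print(round(epochs))              # 39
```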

Collaborator (PR author) replied:

I understand that epochs 5 and 12 are arbitrary; I just wanted a rough comparison with the BCQ results above. I reran the two Breakout runs to epoch 40 and the test results were worse than at epoch 12. (However, I do think an evaluation protocol of 10 episodes is too small to smooth out the randomness.) The learning curves show that the loss seemed to be behaving, but the test rewards were all over the place:

(screenshot, 2021-05-11: TensorBoard learning curves for the two Breakout runs)

I double-checked my implementation against the reference but couldn't find any errors. (It is essentially three lines of code on top of QRDQN anyway; the sketch below outlines what those lines compute.)
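For readers following along, here is a minimal sketch of what those extra lines compute: the discrete CQL(H) term pushes down the log-sum-exp of the Q-values over all actions while pushing up the Q-value of the action taken in the dataset, scaled by `min_q_weight`. The helper name and shapes below are illustrative; the actual `DiscreteCQLPolicy` may derive the Q-values from the quantiles slightly differently.

```python
import torch


def discrete_cql_penalty(q_values: torch.Tensor, acts: torch.Tensor,
                         min_q_weight: float = 10.0) -> torch.Tensor:
    """CQL(H) regularizer for discrete actions.

    q_values: (batch, n_actions) Q estimates, e.g. the mean over quantiles for QRDQN.
    acts:     (batch,) integer actions stored in the offline dataset.
    """
    logsumexp_q = torch.logsumexp(q_values, dim=1)             # push down: all actions
    data_q = q_values.gather(1, acts.unsqueeze(1)).squeeze(1)  # push up: dataset actions
    return min_q_weight * (logsumexp_q - data_q).mean()


# total loss = quantile-regression (QRDQN) loss + discrete_cql_penalty(q, acts, min_q_weight)
```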

Collaborator commented:

Is this TensorBoard or some other visualization tool? I'm quite curious.

Collaborator (PR author) replied:

It's TensorBoard. I usually run it like this: `tensorboard --logdir log/BreakoutNoFrameskip-v4/cql/ --host $(hostname -i) --port 8086`, then click the printed server address to open it in a browser.

Collaborator @Trinkle23897 commented on May 12, 2021:

May I ask how to change the CSS style? Mine is orange and looks quite different from yours:
(screenshot, 2021-05-12 09:09:58: TensorBoard view with orange curves)

Collaborator (PR author) replied:

I don't know. I had many runs in my log directory, so TensorBoard used all kinds of colors, and it just happened that the two runs I cared about got nice ones.

2 changes: 2 additions & 0 deletions examples/atari/atari_bcq.py
@@ -135,6 +135,8 @@ def watch():
result = test_collector.collect(n_episode=args.test_num,
render=args.render)
pprint.pprint(result)
rew = result["rews"].mean()
print(f'Mean reward (over {result["n/ep"]} episodes): {rew}')

if args.watch:
watch()
147 changes: 147 additions & 0 deletions examples/atari/atari_cql.py
@@ -0,0 +1,147 @@
import os
import torch
import pickle
import pprint
import datetime
import argparse
import numpy as np
from torch.utils.tensorboard import SummaryWriter

from tianshou.utils import BasicLogger
from tianshou.env import SubprocVectorEnv
from tianshou.trainer import offline_trainer
from tianshou.policy import DiscreteCQLPolicy
from tianshou.data import Collector, VectorReplayBuffer

from atari_network import QRDQN
from atari_wrapper import wrap_deepmind


def get_args():
parser = argparse.ArgumentParser()
parser.add_argument("--task", type=str, default="PongNoFrameskip-v4")
parser.add_argument("--seed", type=int, default=1626)
parser.add_argument("--eps-test", type=float, default=0.001)
parser.add_argument("--lr", type=float, default=0.0001)
parser.add_argument("--gamma", type=float, default=0.99)
parser.add_argument('--num-quantiles', type=int, default=200)
parser.add_argument("--n-step", type=int, default=1)
parser.add_argument("--target-update-freq", type=int, default=500)
parser.add_argument("--min-q-weight", type=float, default=10.)
parser.add_argument("--epoch", type=int, default=100)
parser.add_argument("--update-per-epoch", type=int, default=10000)
parser.add_argument("--batch-size", type=int, default=32)
parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[512])
parser.add_argument("--test-num", type=int, default=10)
parser.add_argument('--frames-stack', type=int, default=4)
parser.add_argument("--logdir", type=str, default="log")
parser.add_argument("--render", type=float, default=0.)
parser.add_argument("--resume-path", type=str, default=None)
parser.add_argument("--watch", default=False, action="store_true",
help="watch the play of pre-trained policy only")
parser.add_argument("--log-interval", type=int, default=100)
parser.add_argument(
"--load-buffer-name", type=str,
default="./expert_DQN_PongNoFrameskip-v4.hdf5")
parser.add_argument(
"--device", type=str,
default="cuda" if torch.cuda.is_available() else "cpu")
args = parser.parse_known_args()[0]
return args


def make_atari_env(args):
return wrap_deepmind(args.task, frame_stack=args.frames_stack)


def make_atari_env_watch(args):
return wrap_deepmind(args.task, frame_stack=args.frames_stack,
episode_life=False, clip_rewards=False)


def test_discrete_cql(args=get_args()):
# envs
env = make_atari_env(args)
args.state_shape = env.observation_space.shape or env.observation_space.n
args.action_shape = env.action_space.shape or env.action_space.n
# should be N_FRAMES x H x W
print("Observations shape:", args.state_shape)
print("Actions shape:", args.action_shape)
# make environments
test_envs = SubprocVectorEnv([lambda: make_atari_env_watch(args)
for _ in range(args.test_num)])
# seed
np.random.seed(args.seed)
torch.manual_seed(args.seed)
test_envs.seed(args.seed)
# model
net = QRDQN(*args.state_shape, args.action_shape,
args.num_quantiles, args.device)
optim = torch.optim.Adam(net.parameters(), lr=args.lr)
# define policy
policy = DiscreteCQLPolicy(
net, optim, args.gamma, args.num_quantiles, args.n_step,
args.target_update_freq, min_q_weight=args.min_q_weight
).to(args.device)
# load a previous policy
if args.resume_path:
policy.load_state_dict(torch.load(
args.resume_path, map_location=args.device))
print("Loaded agent from: ", args.resume_path)
# buffer
assert os.path.exists(args.load_buffer_name), \
"Please run atari_qrdqn.py first to get expert's data buffer."
if args.load_buffer_name.endswith('.pkl'):
buffer = pickle.load(open(args.load_buffer_name, "rb"))
elif args.load_buffer_name.endswith('.hdf5'):
buffer = VectorReplayBuffer.load_hdf5(args.load_buffer_name)
else:
print(f"Unknown buffer format: {args.load_buffer_name}")
exit(0)

# collector
test_collector = Collector(policy, test_envs, exploration_noise=True)

# log
log_path = os.path.join(
args.logdir, args.task, 'cql',
f'seed_{args.seed}_{datetime.datetime.now().strftime("%m%d-%H%M%S")}')
writer = SummaryWriter(log_path)
writer.add_text("args", str(args))
logger = BasicLogger(writer, update_interval=args.log_interval)

def save_fn(policy):
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))

def stop_fn(mean_rewards):
return False

# watch agent's performance
def watch():
print("Setup test envs ...")
policy.eval()
policy.set_eps(args.eps_test)
test_envs.seed(args.seed)
print("Testing agent ...")
test_collector.reset()
result = test_collector.collect(n_episode=args.test_num,
render=args.render)
pprint.pprint(result)
rew = result["rews"].mean()
print(f'Mean reward (over {result["n/ep"]} episodes): {rew}')

if args.watch:
watch()
exit(0)

result = offline_trainer(
policy, buffer, test_collector, args.epoch,
args.update_per_epoch, args.test_num, args.batch_size,
stop_fn=stop_fn, save_fn=save_fn, logger=logger)

pprint.pprint(result)
watch()


if __name__ == "__main__":
test_discrete_cql(get_args())
2 changes: 1 addition & 1 deletion test/base/test_env.py
@@ -118,7 +118,7 @@ def test_async_check_id(size=100, num=4, sleep=.2, timeout=.7):
pass_check = 0
break
total_pass += pass_check
if sys.platform != "darwin": # macOS cannot pass this check
if sys.platform == "linux": # Windows/macOS may not pass this check
assert total_pass >= 2


12 changes: 12 additions & 0 deletions test/discrete/test_qrdqn.py
@@ -1,6 +1,7 @@
import os
import gym
import torch
import pickle
import pprint
import argparse
import numpy as np
@@ -41,6 +42,9 @@ def get_args():
action="store_true", default=False)
parser.add_argument('--alpha', type=float, default=0.6)
parser.add_argument('--beta', type=float, default=0.4)
parser.add_argument(
'--save-buffer-name', type=str,
default="./expert_QRDQN_CartPole-v0.pkl")
parser.add_argument(
'--device', type=str,
default='cuda' if torch.cuda.is_available() else 'cpu')
@@ -130,6 +134,14 @@ def test_fn(epoch, env_step):
rews, lens = result["rews"], result["lens"]
print(f"Final reward: {rews.mean()}, length: {lens.mean()}")

# save buffer in pickle format, for imitation learning unittest
buf = VectorReplayBuffer(args.buffer_size, buffer_num=len(test_envs))
policy.set_eps(0.9) # 10% of expert data as demonstrated in the original paper
collector = Collector(policy, test_envs, buf, exploration_noise=True)
result = collector.collect(n_step=args.buffer_size)
pickle.dump(buf, open(args.save_buffer_name, "wb"))
print(result["rews"].mean())


def test_pqrdqn(args=get_args()):
args.prioritized_replay = True
110 changes: 110 additions & 0 deletions test/discrete/test_qrdqn_il_cql.py
@@ -0,0 +1,110 @@
import os
import gym
import torch
import pickle
import pprint
import argparse
import numpy as np
from torch.utils.tensorboard import SummaryWriter

from tianshou.data import Collector
from tianshou.utils import BasicLogger
from tianshou.env import DummyVectorEnv
from tianshou.utils.net.common import Net
from tianshou.trainer import offline_trainer
from tianshou.policy import DiscreteCQLPolicy


def get_args():
parser = argparse.ArgumentParser()
parser.add_argument("--task", type=str, default="CartPole-v0")
parser.add_argument("--seed", type=int, default=1626)
parser.add_argument("--eps-test", type=float, default=0.001)
parser.add_argument("--lr", type=float, default=7e-4)
parser.add_argument("--gamma", type=float, default=0.99)
parser.add_argument('--num-quantiles', type=int, default=200)
parser.add_argument("--n-step", type=int, default=3)
parser.add_argument("--target-update-freq", type=int, default=320)
parser.add_argument("--min-q-weight", type=float, default=10.)
parser.add_argument("--epoch", type=int, default=5)
parser.add_argument("--update-per-epoch", type=int, default=1000)
parser.add_argument("--batch-size", type=int, default=64)
parser.add_argument('--hidden-sizes', type=int,
nargs='*', default=[64, 64])
parser.add_argument("--test-num", type=int, default=100)
parser.add_argument("--logdir", type=str, default="log")
parser.add_argument("--render", type=float, default=0.)
parser.add_argument(
"--load-buffer-name", type=str,
default="./expert_QRDQN_CartPole-v0.pkl",
)
parser.add_argument(
"--device", type=str,
default="cuda" if torch.cuda.is_available() else "cpu",
)
args = parser.parse_known_args()[0]
return args


def test_discrete_cql(args=get_args()):
# envs
env = gym.make(args.task)
if args.task == 'CartPole-v0':
env.spec.reward_threshold = 190 # lower the goal
args.state_shape = env.observation_space.shape or env.observation_space.n
args.action_shape = env.action_space.shape or env.action_space.n
test_envs = DummyVectorEnv(
[lambda: gym.make(args.task) for _ in range(args.test_num)])
# seed
np.random.seed(args.seed)
torch.manual_seed(args.seed)
test_envs.seed(args.seed)
# model
net = Net(args.state_shape, args.action_shape,
hidden_sizes=args.hidden_sizes, device=args.device,
softmax=False, num_atoms=args.num_quantiles)
optim = torch.optim.Adam(net.parameters(), lr=args.lr)

policy = DiscreteCQLPolicy(
net, optim, args.gamma, args.num_quantiles, args.n_step,
args.target_update_freq, min_q_weight=args.min_q_weight
).to(args.device)
# buffer
assert os.path.exists(args.load_buffer_name), \
"Please run test_dqn.py first to get expert's data buffer."
buffer = pickle.load(open(args.load_buffer_name, "rb"))

# collector
test_collector = Collector(policy, test_envs, exploration_noise=True)

log_path = os.path.join(args.logdir, args.task, 'discrete_cql')
writer = SummaryWriter(log_path)
logger = BasicLogger(writer)

def save_fn(policy):
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))

def stop_fn(mean_rewards):
return mean_rewards >= env.spec.reward_threshold

result = offline_trainer(
policy, buffer, test_collector,
args.epoch, args.update_per_epoch, args.test_num, args.batch_size,
stop_fn=stop_fn, save_fn=save_fn, logger=logger)

assert stop_fn(result['best_reward'])

if __name__ == '__main__':
pprint.pprint(result)
# Let's watch its performance!
env = gym.make(args.task)
policy.eval()
policy.set_eps(args.eps_test)
collector = Collector(policy, env)
result = collector.collect(n_episode=1, render=args.render)
rews, lens = result["rews"], result["lens"]
print(f"Final reward: {rews.mean()}, length: {lens.mean()}")


if __name__ == "__main__":
test_discrete_cql(get_args())
2 changes: 1 addition & 1 deletion tianshou/env/venvs.py
@@ -299,7 +299,7 @@ def normalize_obs(self, obs: np.ndarray) -> np.ndarray:
clip_max = 10.0 # this magic number is from openai baselines
# see baselines/common/vec_env/vec_normalize.py#L10
obs = (obs - self.obs_rms.mean) / np.sqrt(self.obs_rms.var + self.__eps)
obs = np.clip(obs, -clip_max, clip_max) # type: ignore
obs = np.clip(obs, -clip_max, clip_max)
return obs

