diff --git a/examples/offline/README.md b/examples/offline/README.md
index 04c42686e..61052c7c6 100644
--- a/examples/offline/README.md
+++ b/examples/offline/README.md
@@ -37,7 +37,7 @@ Tianshou provides an `offline_trainer` for offline reinforcement learning. You c
 
 ## Discrete control
 
-For discrete control, we currently use ad hoc Atari data generated from a trained QRDQN agent. In the future, we can switch to better benchmarks such as the Atari portion of [RL Unplugged](https://github.com/deepmind/deepmind-research/tree/master/rl_unplugged).
+For discrete control, we currently use ad hoc Atari data generated from a trained QRDQN agent.
 
 ### Gather Data
 
@@ -100,3 +100,24 @@ We test our CRR implementation on two example tasks (different from author's ver
 | BreakoutNoFrameskip-v4 | 394.3 | 46.9 | 23.3 (epoch 12) | 76.9 (epoch 12) | `python3 atari_crr.py --task BreakoutNoFrameskip-v4 --load-buffer-name log/BreakoutNoFrameskip-v4/qrdqn/expert.hdf5 --epoch 12 --min-q-weight 50` |
 
 Note that CRR itself does not work well in Atari tasks but adding CQL loss/regularizer helps.
+
+### RL Unplugged Data
+
+We provide a script to convert the Atari datasets of [RL Unplugged](https://github.com/deepmind/deepmind-research/tree/master/rl_unplugged) to a Tianshou `ReplayBuffer`.
+
+For example, the following command will download shard 1 of run 1 of the Breakout game to `~/.rl_unplugged/datasets/Breakout/run_1-00001-of-00100`, then convert it to a `tianshou.data.ReplayBuffer` and save it to `~/.rl_unplugged/buffers/Breakout/run_1-00001-of-00100.hdf5` (use `--dataset-dir` and `--buffer-dir` to change the default directories):
+
+```bash
+python3 convert_rl_unplugged_atari.py --task Breakout --run-id 1 --shard-id 1
+```
+
+Then you can use it to train an agent by:
+
+```bash
+python3 atari_bcq.py --task BreakoutNoFrameskip-v4 --load-buffer-name ~/.rl_unplugged/buffers/Breakout/run_1-00001-of-00100.hdf5 --buffer-from-rl-unplugged --epoch 12
+```
+
+Note:
+ - Each shard contains about 500k transitions.
+ - This conversion script depends on TensorFlow.
+ - It takes about 1 hour to process one shard on the author's machine; your mileage may vary.
diff --git a/examples/offline/atari_bcq.py b/examples/offline/atari_bcq.py
index f398f64e9..c37313d37 100644
--- a/examples/offline/atari_bcq.py
+++ b/examples/offline/atari_bcq.py
@@ -12,7 +12,7 @@
 
 from examples.atari.atari_network import DQN
 from examples.atari.atari_wrapper import make_atari_env
-from tianshou.data import Collector, VectorReplayBuffer
+from tianshou.data import Collector, ReplayBuffer, VectorReplayBuffer
 from tianshou.policy import DiscreteBCQPolicy
 from tianshou.trainer import offline_trainer
 from tianshou.utils import TensorboardLogger, WandbLogger
@@ -59,6 +59,9 @@ def get_args():
     parser.add_argument(
         "--load-buffer-name", type=str, default="./expert_DQN_PongNoFrameskip-v4.hdf5"
     )
+    parser.add_argument(
+        "--buffer-from-rl-unplugged", action="store_true", default=False
+    )
     parser.add_argument(
         "--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu"
     )
@@ -120,7 +123,10 @@ def test_discrete_bcq(args=get_args()):
     if args.load_buffer_name.endswith(".pkl"):
         buffer = pickle.load(open(args.load_buffer_name, "rb"))
     elif args.load_buffer_name.endswith(".hdf5"):
-        buffer = VectorReplayBuffer.load_hdf5(args.load_buffer_name)
+        if args.buffer_from_rl_unplugged:
+            buffer = ReplayBuffer.load_hdf5(args.load_buffer_name)
+        else:
+            buffer = VectorReplayBuffer.load_hdf5(args.load_buffer_name)
     else:
         print(f"Unknown buffer format: {args.load_buffer_name}")
         exit(0)
diff --git a/examples/offline/convert_rl_unplugged_atari.py b/examples/offline/convert_rl_unplugged_atari.py
new file mode 100755
index 000000000..d60ddb484
--- /dev/null
+++ b/examples/offline/convert_rl_unplugged_atari.py
@@ -0,0 +1,295 @@
+#!/usr/bin/env python3
+#
+# Adapted from
+# https://github.com/deepmind/deepmind-research/blob/master/rl_unplugged/atari.py
+#
+"""Convert Atari RL Unplugged datasets to Tianshou replay buffers.
+
+Examples in the dataset represent SARSA transitions stored during a
+DQN training run as described in https://arxiv.org/pdf/1907.04543.
+
+For every training run we have recorded all 50 million transitions corresponding
+to 200 million environment steps (4x factor because of frame skipping). There
+are 5 separate datasets for each of the 46 games.
+
+Every transition in the dataset is a tuple containing the following features:
+
+* o_t: Observation at time t. Observations have been processed using the
+    canonical Atari frame processing, including 4x frame stacking. The shape
+    of a single observation is [84, 84, 4].
+* a_t: Action taken at time t.
+* r_t: Reward after a_t.
+* d_t: Discount after a_t.
+* o_tp1: Observation at time t+1.
+* a_tp1: Action at time t+1.
+* extras:
+  * episode_id: Episode identifier.
+  * episode_return: Total episode return computed using per-step [-1, 1]
+      clipping.
+"""
+import os
+from argparse import ArgumentParser
+
+import requests
+import tensorflow as tf
+from tqdm import tqdm
+
+from tianshou.data import Batch, ReplayBuffer
+
+tf.config.set_visible_devices([], 'GPU')
+
+# 9 tuning games.
+TUNING_SUITE = [
+    "BeamRider",
+    "DemonAttack",
+    "DoubleDunk",
+    "IceHockey",
+    "MsPacman",
+    "Pooyan",
+    "RoadRunner",
+    "Robotank",
+    "Zaxxon",
+]
+
+# 37 testing games.
+TESTING_SUITE = [
+    "Alien",
+    "Amidar",
+    "Assault",
+    "Asterix",
+    "Atlantis",
+    "BankHeist",
+    "BattleZone",
+    "Boxing",
+    "Breakout",
+    "Carnival",
+    "Centipede",
+    "ChopperCommand",
+    "CrazyClimber",
+    "Enduro",
+    "FishingDerby",
+    "Freeway",
+    "Frostbite",
+    "Gopher",
+    "Gravitar",
+    "Hero",
+    "Jamesbond",
+    "Kangaroo",
+    "Krull",
+    "KungFuMaster",
+    "NameThisGame",
+    "Phoenix",
+    "Pong",
+    "Qbert",
+    "Riverraid",
+    "Seaquest",
+    "SpaceInvaders",
+    "StarGunner",
+    "TimePilot",
+    "UpNDown",
+    "VideoPinball",
+    "WizardOfWor",
+    "YarsRevenge",
+]
+
+# Total of 46 games.
+ALL_GAMES = TUNING_SUITE + TESTING_SUITE
+URL_PREFIX = "http://storage.googleapis.com/rl_unplugged/atari"
+
+
+def _filename(run_id: int, shard_id: int, total_num_shards: int = 100) -> str:
+    """Return the file name of one shard, e.g. ``run_1-00001-of-00100``."""
+    return f"run_{run_id}-{shard_id:05d}-of-{total_num_shards:05d}"
+
+
+def _decode_frames(pngs: tf.Tensor) -> tf.Tensor:
+    """Decode PNGs.
+
+    Args:
+        pngs: String Tensor of size (4,) containing PNG encoded images.
+
+    Returns:
+        4 84x84 grayscale images packed in a (4, 84, 84) uint8 Tensor.
+    """
+    # Statically unroll png decoding
+    frames = [tf.image.decode_png(pngs[i], channels=1) for i in range(4)]
+    # NOTE: to match tianshou's convention for framestacking
+    frames = tf.squeeze(tf.stack(frames, axis=0))
+    frames.set_shape((4, 84, 84))
+    return frames
+
+
+def _make_tianshou_batch(
+    o_t: tf.Tensor,
+    a_t: tf.Tensor,
+    r_t: tf.Tensor,
+    d_t: tf.Tensor,
+    o_tp1: tf.Tensor,
+    a_tp1: tf.Tensor,
+) -> Batch:
+    """Create Tianshou batch with offline data.
+
+    Args:
+        o_t: Observation at time t.
+        a_t: Action at time t.
+        r_t: Reward at time t.
+        d_t: Discount at time t.
+        o_tp1: Observation at time t+1.
+        a_tp1: Action at time t+1.
+
+    Returns:
+        A tianshou.data.Batch object.
+    """
+    return Batch(
+        obs=o_t.numpy(),
+        act=a_t.numpy(),
+        rew=r_t.numpy(),
+        done=1 - d_t.numpy(),
+        obs_next=o_tp1.numpy()
+    )
+
+
+def _tf_example_to_tianshou_batch(tf_example: tf.train.Example) -> Batch:
+    """Create a tianshou Batch replay sample from a TF example."""
+
+    # Parse tf.Example.
+    feature_description = {
+        "o_t": tf.io.FixedLenFeature([4], tf.string),
+        "o_tp1": tf.io.FixedLenFeature([4], tf.string),
+        "a_t": tf.io.FixedLenFeature([], tf.int64),
+        "a_tp1": tf.io.FixedLenFeature([], tf.int64),
+        "r_t": tf.io.FixedLenFeature([], tf.float32),
+        "d_t": tf.io.FixedLenFeature([], tf.float32),
+        "episode_id": tf.io.FixedLenFeature([], tf.int64),
+        "episode_return": tf.io.FixedLenFeature([], tf.float32),
+    }
+    data = tf.io.parse_single_example(tf_example, feature_description)
+
+    # Process data.
+    o_t = _decode_frames(data["o_t"])
+    o_tp1 = _decode_frames(data["o_tp1"])
+    a_t = tf.cast(data["a_t"], tf.int32)
+    a_tp1 = tf.cast(data["a_tp1"], tf.int32)
+
+    # Build tianshou Batch replay sample.
+    return _make_tianshou_batch(o_t, a_t, data["r_t"], data["d_t"], o_tp1, a_tp1)
+
+
+# Adapted From https://gist.github.com/yanqd0/c13ed29e29432e3cf3e7c38467f42f51
+def download(url: str, fname: str, chunk_size=1024):
+    """Download ``url`` to ``fname`` while showing a progress bar.
+
+    An existing ``fname`` is treated as a cache hit and no request is sent.
+    Data is streamed into ``fname + ".part"`` and renamed only after the
+    transfer finishes, so an interrupted download is not mistaken for a
+    cached file on the next run.
+    """
+    if os.path.exists(fname):
+        print(f"Found cached file at {fname}.")
+        return
+    tmp_fname = f"{fname}.part"
+    with requests.get(url, stream=True) as resp:
+        # Fail loudly instead of saving an HTTP error page as the dataset.
+        resp.raise_for_status()
+        total = int(resp.headers.get('content-length', 0))
+        with open(tmp_fname, 'wb') as ofile, tqdm(
+            desc=fname,
+            total=total,
+            unit='iB',
+            unit_scale=True,
+            unit_divisor=1024,
+        ) as bar:
+            for data in resp.iter_content(chunk_size=chunk_size):
+                size = ofile.write(data)
+                bar.update(size)
+    os.replace(tmp_fname, fname)
+
+
+def process_shard(url: str, fname: str, ofname: str) -> None:
+    """Download one shard, convert it to a ``ReplayBuffer``, and save it as HDF5."""
+    download(url, fname)
+    file_ds = tf.data.TFRecordDataset(fname, compression_type="GZIP")
+    buffer = ReplayBuffer(500000)
+    cnt = 0
+    for example in file_ds:
+        batch = _tf_example_to_tianshou_batch(example)
+        buffer.add(batch)
+        cnt += 1
+        if cnt % 1000 == 0:
+            print(f"...{cnt}", end="", flush=True)
+    print("\nReplayBuffer size:", len(buffer))
+    buffer.save_hdf5(ofname, compression="gzip")
+
+
+def process_dataset(
+    task: str,
+    download_path: str,
+    dst_path: str,
+    run_id: int = 1,
+    shard_id: int = 0,
+    total_num_shards: int = 100,
+) -> None:
+    """Derive the shard URL and local file paths for ``task``, then convert it."""
+    fn = f"{task}/{_filename(run_id, shard_id, total_num_shards=total_num_shards)}"
+    url = f"{URL_PREFIX}/{fn}"
+    filepath = f"{download_path}/{fn}"
+    ofname = f"{dst_path}/{fn}.hdf5"
+    process_shard(url, filepath, ofname)
+
+
+def main(args):
+    """Validate the task, prepare the output directories, and convert one shard."""
+    if args.task not in ALL_GAMES:
+        raise KeyError(f"`{args.task}` is not in the list of games.")
+    # Resolve the environment overrides first so that the existence check
+    # below looks at the directory the buffer is actually written to.
+    args.dataset_dir = os.environ.get("RLU_DATASET_DIR", args.dataset_dir)
+    args.buffer_dir = os.environ.get("RLU_BUFFER_DIR", args.buffer_dir)
+    fn = _filename(args.run_id, args.shard_id, total_num_shards=args.total_num_shards)
+    buffer_path = os.path.join(args.buffer_dir, args.task, f"{fn}.hdf5")
+    if os.path.exists(buffer_path):
+        raise IOError(f"Found existing buffer at {buffer_path}. Will not overwrite.")
+    dataset_path = os.path.join(args.dataset_dir, args.task)
+    os.makedirs(dataset_path, exist_ok=True)
+    dst_path = os.path.join(args.buffer_dir, args.task)
+    os.makedirs(dst_path, exist_ok=True)
+    process_dataset(
+        args.task,
+        args.dataset_dir,
+        args.buffer_dir,
+        run_id=args.run_id,
+        shard_id=args.shard_id,
+        total_num_shards=args.total_num_shards
+    )
+
+
+if __name__ == "__main__":
+    parser = ArgumentParser(usage=__doc__)
+    parser.add_argument("--task", required=True, help="Name of the Atari game.")
+    parser.add_argument(
+        "--run-id",
+        type=int,
+        default=1,
+        help="Run id to download and convert. Value in [1..5]."
+    )
+    parser.add_argument(
+        "--shard-id",
+        type=int,
+        default=0,
+        help="Shard id to download and convert. Value in [0..99]."
+    )
+    parser.add_argument(
+        "--total-num-shards", type=int, default=100, help="Total number of shards."
+    )
+    parser.add_argument(
+        "--dataset-dir",
+        default=os.path.expanduser("~/.rl_unplugged/datasets"),
+        help="Directory for downloaded original datasets.",
+    )
+    parser.add_argument(
+        "--buffer-dir",
+        default=os.path.expanduser("~/.rl_unplugged/buffers"),
+        help="Directory for converted replay buffers.",
+    )
+    args = parser.parse_args()
+    main(args)
diff --git a/tianshou/data/buffer/base.py b/tianshou/data/buffer/base.py
index fc783467d..c9aafc7d0 100644
--- a/tianshou/data/buffer/base.py
+++ b/tianshou/data/buffer/base.py
@@ -86,10 +86,10 @@ def __setattr__(self, key: str, value: Any) -> None:
         ), "key '{}' is reserved and cannot be assigned".format(key)
         super().__setattr__(key, value)
 
-    def save_hdf5(self, path: str) -> None:
+    def save_hdf5(self, path: str, compression: Optional[str] = None) -> None:
         """Save replay buffer to HDF5 file."""
         with h5py.File(path, "w") as f:
-            to_hdf5(self.__dict__, f)
+            to_hdf5(self.__dict__, f, compression=compression)
 
     @classmethod
     def load_hdf5(cls, path: str, device: Optional[str] = None) -> "ReplayBuffer":
diff --git a/tianshou/data/utils/converter.py b/tianshou/data/utils/converter.py
index 12fa72439..e23143b0d 100644
--- a/tianshou/data/utils/converter.py
+++ b/tianshou/data/utils/converter.py
@@ -78,13 +78,17 @@ def to_torch_as(x: Any, y: torch.Tensor) -> Union[Batch, torch.Tensor]:
 Hdf5ConvertibleType = Dict[str, Hdf5ConvertibleValues]  # type: ignore
 
 
-def to_hdf5(x: Hdf5ConvertibleType, y: h5py.Group) -> None:
+def to_hdf5(
+    x: Hdf5ConvertibleType, y: h5py.Group, compression: Optional[str] = None
+) -> None:
     """Copy object into HDF5 group."""
 
-    def to_hdf5_via_pickle(x: object, y: h5py.Group, key: str) -> None:
+    def to_hdf5_via_pickle(
+        x: object, y: h5py.Group, key: str, compression: Optional[str] = None
+    ) -> None:
         """Pickle, convert to numpy array and write to HDF5 dataset."""
         data = np.frombuffer(pickle.dumps(x), dtype=np.byte)
-        y.create_dataset(key, data=data)
+        y.create_dataset(key, data=data, compression=compression)
 
     for k, v in x.items():
         if isinstance(v, (Batch, dict)):
@@ -95,22 +99,22 @@ def to_hdf5_via_pickle(x: object, y: h5py.Group, key: str) -> None:
                 subgrp.attrs["__data_type__"] = "Batch"
             else:
                 subgrp_data = v
-            to_hdf5(subgrp_data, subgrp)
+            to_hdf5(subgrp_data, subgrp, compression=compression)
         elif isinstance(v, torch.Tensor):
             # PyTorch tensors are written to datasets
-            y.create_dataset(k, data=to_numpy(v))
+            y.create_dataset(k, data=to_numpy(v), compression=compression)
             y[k].attrs["__data_type__"] = "Tensor"
         elif isinstance(v, np.ndarray):
             try:
                 # NumPy arrays are written to datasets
-                y.create_dataset(k, data=v)
+                y.create_dataset(k, data=v, compression=compression)
                 y[k].attrs["__data_type__"] = "ndarray"
             except TypeError:
                 # If data type is not supported by HDF5 fall back to pickle.
                 # This happens if dtype=object (e.g. due to entries being None)
                 # and possibly in other cases like structured arrays.
                 try:
-                    to_hdf5_via_pickle(v, y, k)
+                    to_hdf5_via_pickle(v, y, k, compression=compression)
                 except Exception as exception:
                     raise RuntimeError(
                         f"Attempted to pickle {v.__class__.__name__} due to "
@@ -122,7 +126,7 @@ def to_hdf5_via_pickle(x: object, y: h5py.Group, key: str) -> None:
             y.attrs[k] = v
         else:  # resort to pickle for any other type of object
             try:
-                to_hdf5_via_pickle(v, y, k)
+                to_hdf5_via_pickle(v, y, k, compression=compression)
             except Exception as exception:
                 raise NotImplementedError(
                     f"No conversion to HDF5 for object of type '{type(v)}' "