Convert RL Unplugged Atari datasets to tianshou ReplayBuffer #621

Merged · 5 commits · Apr 29, 2022
examples/offline/README.md (22 additions, 1 deletion)
@@ -37,7 +37,7 @@ Tianshou provides an `offline_trainer` for offline reinforcement learning. You c

## Discrete control

-For discrete control, we currently use ad hoc Atari data generated from a trained QRDQN agent. In the future, we can switch to better benchmarks such as the Atari portion of [RL Unplugged](https://github.com/deepmind/deepmind-research/tree/master/rl_unplugged).
+For discrete control, we currently use ad hoc Atari data generated from a trained QRDQN agent.

### Gather Data

@@ -100,3 +100,24 @@ We test our CRR implementation on two example tasks (different from author's ver
| BreakoutNoFrameskip-v4 | 394.3 | 46.9 | 23.3 (epoch 12) | 76.9 (epoch 12) | `python3 atari_crr.py --task BreakoutNoFrameskip-v4 --load-buffer-name log/BreakoutNoFrameskip-v4/qrdqn/expert.hdf5 --epoch 12 --min-q-weight 50` |

Note that CRR itself does not work well in Atari tasks but adding CQL loss/regularizer helps.

### RL Unplugged Data

We provide a script to convert the Atari datasets of [RL Unplugged](https://github.com/deepmind/deepmind-research/tree/master/rl_unplugged) to Tianshou replay buffers.

For example, the following command downloads the first shard of the first run of the Breakout game to `~/.rl_unplugged/datasets/Breakout/run_1-00001-of-00100`, converts it to a `tianshou.data.ReplayBuffer`, and saves it to `~/.rl_unplugged/buffers/Breakout/run_1-00001-of-00100.hdf5` (use `--dataset-dir` and `--buffer-dir` to change the default directories):

```bash
python3 convert_rl_unplugged_atari.py --task Breakout --run-id 1 --shard-id 1
```

Then you can use it to train an agent:

```bash
python3 atari_bcq.py --task BreakoutNoFrameskip-v4 --load-buffer-name ~/.rl_unplugged/buffers/Breakout/run_1-00001-of-00100.hdf5 --buffer-from-rl-unplugged --epoch 12
```

Note:
- Each shard contains about 500k transitions.
- The conversion script depends on TensorFlow.
- Processing one shard takes about 1 hour on my machine. YMMV.
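
To sanity-check a converted shard from Python, load it with the same `ReplayBuffer.load_hdf5` that `atari_bcq.py` uses. A minimal sketch, assuming the default `--buffer-dir`:

```python
import os

from tianshou.data import ReplayBuffer

# Load one converted shard and peek at a transition.
path = os.path.expanduser(
    "~/.rl_unplugged/buffers/Breakout/run_1-00001-of-00100.hdf5"
)
buf = ReplayBuffer.load_hdf5(path)
print(len(buf))          # roughly 500k transitions per shard
print(buf[0].obs.shape)  # (4, 84, 84), tianshou's frame-stacking convention
```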
examples/offline/atari_bcq.py (8 additions, 2 deletions)
@@ -12,7 +12,7 @@

from examples.atari.atari_network import DQN
from examples.atari.atari_wrapper import make_atari_env
-from tianshou.data import Collector, VectorReplayBuffer
+from tianshou.data import Collector, ReplayBuffer, VectorReplayBuffer
from tianshou.policy import DiscreteBCQPolicy
from tianshou.trainer import offline_trainer
from tianshou.utils import TensorboardLogger, WandbLogger
@@ -59,6 +59,9 @@ def get_args():
    parser.add_argument(
        "--load-buffer-name", type=str, default="./expert_DQN_PongNoFrameskip-v4.hdf5"
    )
+    parser.add_argument(
+        "--buffer-from-rl-unplugged", action="store_true", default=False
+    )
    parser.add_argument(
        "--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu"
    )
@@ -120,7 +123,10 @@ def test_discrete_bcq(args=get_args()):
    if args.load_buffer_name.endswith(".pkl"):
        buffer = pickle.load(open(args.load_buffer_name, "rb"))
    elif args.load_buffer_name.endswith(".hdf5"):
-        buffer = VectorReplayBuffer.load_hdf5(args.load_buffer_name)
+        if args.buffer_from_rl_unplugged:
+            buffer = ReplayBuffer.load_hdf5(args.load_buffer_name)
+        else:
+            buffer = VectorReplayBuffer.load_hdf5(args.load_buffer_name)
    else:
        print(f"Unknown buffer format: {args.load_buffer_name}")
        exit(0)
examples/offline/convert_rl_unplugged_atari.py (278 additions, new file)
@@ -0,0 +1,278 @@
#!/usr/bin/env python3
#
# Adapted from
# https://github.com/deepmind/deepmind-research/blob/master/rl_unplugged/atari.py
#
"""Convert Atari RL Unplugged datasets to Tianshou replay buffers.

Examples in the dataset represent SARSA transitions stored during a
DQN training run as described in https://arxiv.org/pdf/1907.04543.

For every training run we have recorded all 50 million transitions corresponding
to 200 million environment steps (4x factor because of frame skipping). There
are 5 separate datasets for each of the 45 games.

Every transition in the dataset is a tuple containing the following features:

* o_t: Observation at time t. Observations have been processed using the
canonical Atari frame processing, including 4x frame stacking. The shape
of a single observation is [84, 84, 4].
* a_t: Action taken at time t.
* r_t: Reward after a_t.
* d_t: Discount after a_t.
* o_tp1: Observation at time t+1.
* a_tp1: Action at time t+1.
* extras:
* episode_id: Episode identifier.
* episode_return: Total episode return computed using per-step [-1, 1]
clipping.
"""
import os
from argparse import ArgumentParser

import requests
import tensorflow as tf
from tqdm import tqdm

from tianshou.data import Batch, ReplayBuffer

tf.config.set_visible_devices([], 'GPU')

# 9 tuning games.
TUNING_SUITE = [
    "BeamRider",
    "DemonAttack",
    "DoubleDunk",
    "IceHockey",
    "MsPacman",
    "Pooyan",
    "RoadRunner",
    "Robotank",
    "Zaxxon",
]

# 37 testing games.
TESTING_SUITE = [
    "Alien",
    "Amidar",
    "Assault",
    "Asterix",
    "Atlantis",
    "BankHeist",
    "BattleZone",
    "Boxing",
    "Breakout",
    "Carnival",
    "Centipede",
    "ChopperCommand",
    "CrazyClimber",
    "Enduro",
    "FishingDerby",
    "Freeway",
    "Frostbite",
    "Gopher",
    "Gravitar",
    "Hero",
    "Jamesbond",
    "Kangaroo",
    "Krull",
    "KungFuMaster",
    "NameThisGame",
    "Phoenix",
    "Pong",
    "Qbert",
    "Riverraid",
    "Seaquest",
    "SpaceInvaders",
    "StarGunner",
    "TimePilot",
    "UpNDown",
    "VideoPinball",
    "WizardOfWor",
    "YarsRevenge",
]

# Total of 46 games.
ALL_GAMES = TUNING_SUITE + TESTING_SUITE
URL_PREFIX = "http://storage.googleapis.com/rl_unplugged/atari"


def _filename(run_id: int, shard_id: int, total_num_shards: int = 100) -> str:
    return f"run_{run_id}-{shard_id:05d}-of-{total_num_shards:05d}"


def _decode_frames(pngs: tf.Tensor) -> tf.Tensor:
    """Decode PNGs.

    Args:
        pngs: String Tensor of size (4,) containing PNG encoded images.

    Returns:
        4 84x84 grayscale images packed in a (4, 84, 84) uint8 Tensor.
    """
    # Statically unroll png decoding.
    frames = [tf.image.decode_png(pngs[i], channels=1) for i in range(4)]
    # NOTE: stack along axis 0 to match tianshou's frame-stacking convention.
    frames = tf.squeeze(tf.stack(frames, axis=0))
    frames.set_shape((4, 84, 84))
    return frames


def _make_tianshou_batch(
    o_t: tf.Tensor,
    a_t: tf.Tensor,
    r_t: tf.Tensor,
    d_t: tf.Tensor,
    o_tp1: tf.Tensor,
    a_tp1: tf.Tensor,
) -> Batch:
    """Create Tianshou batch with offline data.

    Args:
        o_t: Observation at time t.
        a_t: Action at time t.
        r_t: Reward at time t.
        d_t: Discount at time t.
        o_tp1: Observation at time t+1.
        a_tp1: Action at time t+1.

    Returns:
        A tianshou.data.Batch object.
    """
    return Batch(
        obs=o_t.numpy(),
        act=a_t.numpy(),
        rew=r_t.numpy(),
        # The stored discount d_t is 0.0 exactly at episode ends, so
        # done = 1 - d_t recovers the done flag.
        done=1 - d_t.numpy(),
        obs_next=o_tp1.numpy()
    )


def _tf_example_to_tianshou_batch(tf_example: tf.train.Example) -> Batch:
    """Create a tianshou Batch replay sample from a TF example."""

    # Parse tf.Example.
    feature_description = {
        "o_t": tf.io.FixedLenFeature([4], tf.string),
        "o_tp1": tf.io.FixedLenFeature([4], tf.string),
        "a_t": tf.io.FixedLenFeature([], tf.int64),
        "a_tp1": tf.io.FixedLenFeature([], tf.int64),
        "r_t": tf.io.FixedLenFeature([], tf.float32),
        "d_t": tf.io.FixedLenFeature([], tf.float32),
        "episode_id": tf.io.FixedLenFeature([], tf.int64),
        "episode_return": tf.io.FixedLenFeature([], tf.float32),
    }
    data = tf.io.parse_single_example(tf_example, feature_description)

    # Process data.
    o_t = _decode_frames(data["o_t"])
    o_tp1 = _decode_frames(data["o_tp1"])
    a_t = tf.cast(data["a_t"], tf.int32)
    a_tp1 = tf.cast(data["a_tp1"], tf.int32)

    # Build tianshou Batch replay sample.
    return _make_tianshou_batch(o_t, a_t, data["r_t"], data["d_t"], o_tp1, a_tp1)


# Adapted from https://gist.github.com/yanqd0/c13ed29e29432e3cf3e7c38467f42f51
def download(url: str, fname: str, chunk_size: int = 1024) -> None:
    # Check the cache before opening a connection.
    if os.path.exists(fname):
        print(f"Found cached file at {fname}.")
        return
    resp = requests.get(url, stream=True)
    total = int(resp.headers.get('content-length', 0))
    with open(fname, 'wb') as ofile, tqdm(
        desc=fname,
        total=total,
        unit='iB',
        unit_scale=True,
        unit_divisor=1024,
    ) as bar:
        for data in resp.iter_content(chunk_size=chunk_size):
            size = ofile.write(data)
            bar.update(size)


def process_shard(url: str, fname: str, ofname: str) -> None:
    download(url, fname)
    file_ds = tf.data.TFRecordDataset(fname, compression_type="GZIP")
    # One shard holds roughly 500k transitions.
    buffer = ReplayBuffer(500000)
    cnt = 0
    for example in file_ds:
        batch = _tf_example_to_tianshou_batch(example)
        buffer.add(batch)
        cnt += 1
> **Trinkle23897 (Collaborator) commented on Apr 30, 2022, on lines +202 to +204:**
>
> It's better to use a batch-style process instead of a for-loop, e.g., directly set `buffer._meta.obs = ...` (or maybe we should create another API to support this thing?)
>
> For example,
>
> ```python
> buffer = ReplayBuffer(500000)
> # add first batch to initialize memory
> batch = ...
> buffer.add(batch)
> # then directly set, I know it's ugly but this is the general idea
> meta = buffer._meta
> meta.obs = ...
> meta.act = ...
> ...
> buffer._meta = meta
> ```
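
A minimal runnable sketch of that batch-style idea (our illustration, not part of this PR; it assumes the private `ReplayBuffer` attributes `_meta`, `_size`, and `_index` of tianshou 0.4.x and may break on other versions; the helper name `fill_buffer_from_arrays` is hypothetical):

```python
import numpy as np

from tianshou.data import Batch, ReplayBuffer


def fill_buffer_from_arrays(
    obs: np.ndarray,
    act: np.ndarray,
    rew: np.ndarray,
    done: np.ndarray,
    obs_next: np.ndarray,
) -> ReplayBuffer:
    """Fill a ReplayBuffer in one pass instead of n `add` calls."""
    n = len(obs)  # all arrays share the same leading dimension
    buffer = ReplayBuffer(n)
    # Assign the internal storage directly (private API, hence fragile).
    buffer._meta = Batch(
        obs=obs, act=act, rew=rew, done=done, obs_next=obs_next,
        info=Batch(), policy=Batch(),
    )
    buffer._size = n   # mark the buffer as full
    buffer._index = 0  # next write position wraps to the start
    return buffer
```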

        if cnt % 1000 == 0:
            print(f"...{cnt}", end="", flush=True)
    print("\nReplayBuffer size:", len(buffer))
    buffer.save_hdf5(ofname, compression="gzip")


def process_dataset(
    task: str,
    download_path: str,
    dst_path: str,
    run_id: int = 1,
    shard_id: int = 0,
    total_num_shards: int = 100,
) -> None:
    fn = f"{task}/{_filename(run_id, shard_id, total_num_shards=total_num_shards)}"
    url = f"{URL_PREFIX}/{fn}"
    filepath = f"{download_path}/{fn}"
    ofname = f"{dst_path}/{fn}.hdf5"
    process_shard(url, filepath, ofname)


def main(args):
    if args.task not in ALL_GAMES:
        raise KeyError(f"`{args.task}` is not in the list of games.")
    # Apply environment-variable overrides before deriving any paths, so the
    # existing-buffer check below looks in the right directory.
    args.dataset_dir = os.environ.get("RLU_DATASET_DIR", args.dataset_dir)
    args.buffer_dir = os.environ.get("RLU_BUFFER_DIR", args.buffer_dir)
    fn = _filename(args.run_id, args.shard_id, total_num_shards=args.total_num_shards)
    buffer_path = os.path.join(args.buffer_dir, args.task, f"{fn}.hdf5")
    if os.path.exists(buffer_path):
        raise IOError(f"Found existing buffer at {buffer_path}. Will not overwrite.")
    dataset_path = os.path.join(args.dataset_dir, args.task)
    os.makedirs(dataset_path, exist_ok=True)
    dst_path = os.path.join(args.buffer_dir, args.task)
    os.makedirs(dst_path, exist_ok=True)
    process_dataset(
        args.task,
        args.dataset_dir,
        args.buffer_dir,
        run_id=args.run_id,
        shard_id=args.shard_id,
        total_num_shards=args.total_num_shards
    )


if __name__ == "__main__":
    parser = ArgumentParser(description=__doc__)
    parser.add_argument("--task", required=True, help="Name of the Atari game.")
    parser.add_argument(
        "--run-id",
        type=int,
        default=1,
        help="Run id to download and convert. Value in [1..5]."
    )
    parser.add_argument(
        "--shard-id",
        type=int,
        default=0,
        help="Shard id to download and convert. Value in [0..99]."
    )
    parser.add_argument(
        "--total-num-shards", type=int, default=100, help="Total number of shards."
    )
    parser.add_argument(
        "--dataset-dir",
        default=os.path.expanduser("~/.rl_unplugged/datasets"),
        help="Directory for downloaded original datasets.",
    )
    parser.add_argument(
        "--buffer-dir",
        default=os.path.expanduser("~/.rl_unplugged/buffers"),
        help="Directory for converted replay buffers.",
    )
    args = parser.parse_args()
    main(args)
tianshou/data/buffer/base.py (2 additions, 2 deletions)
@@ -86,10 +86,10 @@ def __setattr__(self, key: str, value: Any) -> None:
), "key '{}' is reserved and cannot be assigned".format(key)
super().__setattr__(key, value)

def save_hdf5(self, path: str) -> None:
def save_hdf5(self, path: str, compression: Optional[str] = None) -> None:
"""Save replay buffer to HDF5 file."""
with h5py.File(path, "w") as f:
to_hdf5(self.__dict__, f)
to_hdf5(self.__dict__, f, compression=compression)

@classmethod
def load_hdf5(cls, path: str, device: Optional[str] = None) -> "ReplayBuffer":
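
With this change callers can opt into HDF5 compression, as the conversion script above does with `compression="gzip"`; loading needs no changes. A minimal usage sketch (the file name is illustrative):

```python
from tianshou.data import Batch, ReplayBuffer

buf = ReplayBuffer(1000)
buf.add(Batch(obs=0, act=0, rew=0.0, done=False, obs_next=0))
buf.save_hdf5("buffer.hdf5", compression="gzip")  # smaller file, slower write
buf2 = ReplayBuffer.load_hdf5("buffer.hdf5")      # load path is unchanged
```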