这是indexloc提供的服务,不要输入任何密码
Skip to content

Pickle compatible for replay buffer and improve buffer.get #182

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 20 commits into from
Aug 16, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 56 additions & 10 deletions test/base/test_buffer.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import torch
import pickle
import pytest
import numpy as np
from timeit import timeit

from tianshou.data import Batch, PrioritizedReplayBuffer, \
ReplayBuffer, SegmentTree
from tianshou.data import Batch, SegmentTree, \
ReplayBuffer, ListReplayBuffer, PrioritizedReplayBuffer

if __name__ == '__main__':
from env import MyTestEnv
Expand Down Expand Up @@ -39,10 +41,11 @@ def test_replaybuffer(size=10, bufsize=20):

def test_ignore_obs_next(size=10):
# Issue 82
buf = ReplayBuffer(size, ignore_obs_net=True)
buf = ReplayBuffer(size, ignore_obs_next=True)
for i in range(size):
buf.add(obs={'mask1': np.array([i, 1, 1, 0, 0]),
'mask2': np.array([i + 4, 0, 1, 0, 0])},
'mask2': np.array([i + 4, 0, 1, 0, 0]),
'mask': i},
act={'act_id': i,
'position_id': i + 3},
rew=i,
Expand All @@ -55,14 +58,30 @@ def test_ignore_obs_next(size=10):
assert isinstance(data, Batch)
assert isinstance(data2, Batch)
assert np.allclose(indice, orig)
assert np.allclose(data.obs_next.mask, data2.obs_next.mask)
assert np.allclose(data.obs_next.mask, [0, 2, 3, 3, 5, 6, 6, 8, 9, 9])
buf.stack_num = 4
data = buf[indice]
data2 = buf[indice]
assert np.allclose(data.obs_next.mask, data2.obs_next.mask)
assert np.allclose(data.obs_next.mask, np.array([
[0, 0, 0, 0], [1, 1, 1, 2], [1, 1, 2, 3], [1, 1, 2, 3],
[4, 4, 4, 5], [4, 4, 5, 6], [4, 4, 5, 6],
[7, 7, 7, 8], [7, 7, 8, 9], [7, 7, 8, 9]]))
assert np.allclose(data.info['if'], data2.info['if'])
assert np.allclose(data.info['if'], np.array([
[0, 0, 0, 0], [1, 1, 1, 1], [1, 1, 1, 2], [1, 1, 2, 3],
[4, 4, 4, 4], [4, 4, 4, 5], [4, 4, 5, 6],
[7, 7, 7, 7], [7, 7, 7, 8], [7, 7, 8, 9]]))
assert data.obs_next


def test_stack(size=5, bufsize=9, stack_num=4):
env = MyTestEnv(size)
buf = ReplayBuffer(bufsize, stack_num=stack_num)
buf2 = ReplayBuffer(bufsize, stack_num=stack_num, sample_avail=True)
obs = env.reset(1)
for i in range(15):
for i in range(16):
obs_next, rew, done, info = env.step(1)
buf.add(obs, 1, rew, done, None, info)
buf2.add(obs, 1, rew, done, None, info)
Expand All @@ -73,12 +92,11 @@ def test_stack(size=5, bufsize=9, stack_num=4):
assert np.allclose(buf.get(indice, 'obs'), np.expand_dims(
[[1, 1, 1, 2], [1, 1, 2, 3], [1, 2, 3, 4],
[1, 1, 1, 1], [1, 1, 1, 2], [1, 1, 2, 3],
[3, 3, 3, 3], [3, 3, 3, 4], [1, 1, 1, 1]], axis=-1))
print(buf)
[1, 2, 3, 4], [4, 4, 4, 4], [1, 1, 1, 1]], axis=-1))
_, indice = buf2.sample(0)
assert indice == [2]
assert indice.tolist() == [2, 6]
_, indice = buf2.sample(1)
assert indice.sum() == 2
assert indice in [2, 6]


def test_priortized_replaybuffer(size=32, bufsize=15):
Expand Down Expand Up @@ -107,7 +125,7 @@ def test_update():
buf2 = ReplayBuffer(4, stack_num=2)
for i in range(5):
buf1.add(obs=np.array([i]), act=float(i), rew=i * i,
done=False, info={'incident': 'found'})
done=i % 2 == 0, info={'incident': 'found'})
assert len(buf1) > len(buf2)
buf2.update(buf1)
assert len(buf1) == len(buf2)
Expand Down Expand Up @@ -214,10 +232,38 @@ def sample_tree():
print('tree', timeit(sample_tree, setup=sample_tree, number=1000))


def test_pickle():
    """Check that every buffer flavour survives a pickle round trip.

    A ``ReplayBuffer`` (with frame stacking), a ``ListReplayBuffer`` and a
    ``PrioritizedReplayBuffer`` are filled with a few transitions, dumped
    to bytes, loaded back, and compared with their originals: length,
    stored actions, the ``stack_num`` meta variable, and the priority
    weights of the prioritized buffer.
    """
    capacity = 100
    vbuf = ReplayBuffer(capacity, stack_num=2)
    lbuf = ListReplayBuffer()
    pbuf = PrioritizedReplayBuffer(capacity, 0.6, 0.4)

    # The reward is a torch tensor placed on the GPU when one is available.
    if torch.cuda.is_available():
        device = 'cuda'
    else:
        device = 'cpu'
    rew = torch.tensor([1.]).to(device)

    # Fill each buffer with a different number of transitions and a
    # distinct action value.
    for step in range(4):
        vbuf.add(obs=Batch(index=np.array([step])), act=0, rew=rew, done=0)
    for step in range(3):
        lbuf.add(obs=Batch(index=np.array([step])), act=1, rew=rew, done=0)
    for step in range(5):
        pbuf.add(obs=Batch(index=np.array([step])),
                 act=2, rew=rew, done=0, weight=np.random.rand())

    # save & load
    originals = (vbuf, lbuf, pbuf)
    restored = [pickle.loads(pickle.dumps(source)) for source in originals]
    for loaded, source in zip(restored, originals):
        assert len(loaded) == len(source)
        assert np.allclose(loaded.act, source.act)

    # make sure the meta var is identical
    loaded_vbuf, _, loaded_pbuf = restored
    assert loaded_vbuf.stack_num == vbuf.stack_num
    loaded_weights = loaded_pbuf.weight[np.arange(len(loaded_pbuf))]
    source_weights = pbuf.weight[np.arange(len(pbuf))]
    assert np.allclose(loaded_weights, source_weights)


if __name__ == '__main__':
test_replaybuffer()
test_ignore_obs_next()
test_stack()
test_pickle()
test_segtree()
test_priortized_replaybuffer()
test_priortized_replaybuffer(233333, 200000)
Expand Down
94 changes: 52 additions & 42 deletions tianshou/data/buffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ class ReplayBuffer:
The following code snippet illustrates its usage:
::

>>> import numpy as np
>>> import pickle, numpy as np
>>> from tianshou.data import ReplayBuffer
>>> buf = ReplayBuffer(size=20)
>>> for i in range(3):
Expand All @@ -35,6 +35,7 @@ class ReplayBuffer:
>>> # but there are only three valid items, so len(buf) == 3.
>>> len(buf)
3
>>> pickle.dump(buf, open('buf.pkl', 'wb')) # save to file "buf.pkl"
>>> buf2 = ReplayBuffer(size=10)
>>> for i in range(15):
... buf2.add(obs=i, act=i, rew=i, done=i, obs_next=i + 1, info={})
Expand All @@ -54,6 +55,11 @@ class ReplayBuffer:
>>> batch_data, indice = buf.sample(batch_size=4)
>>> batch_data.obs == buf[indice].obs
array([ True, True, True, True])
>>> len(buf)
13
>>> buf = pickle.load(open('buf.pkl', 'rb')) # load from "buf.pkl"
>>> len(buf)
3

:class:`~tianshou.data.ReplayBuffer` also supports frame_stack sampling
(typically for RNN usage, see issue#19), ignoring storing the next
Expand Down Expand Up @@ -119,6 +125,7 @@ def __init__(self, size: int, stack_num: int = 1,
sample_avail: bool = False, **kwargs) -> None:
super().__init__()
self._maxsize = size
self._indices = np.arange(size)
self._stack = None
self.stack_num = stack_num
self._avail = sample_avail and stack_num > 1
Expand All @@ -137,9 +144,18 @@ def __repr__(self) -> str:
"""Return str(self)."""
return self.__class__.__name__ + self._meta.__repr__()[5:]

def __getattr__(self, key: str) -> Union['Batch', Any]:
def __getattr__(self, key: str) -> Any:
"""Return self.key"""
return self._meta.__dict__[key]
try:
return self._meta[key]
except KeyError as e:
raise AttributeError from e

def __setstate__(self, state):
"""Unpickling interface: restore the buffer from ``state``.

``state`` is merged directly into the instance ``__dict__``. This hook
is defined explicitly because pickling the buffer does not work
out-of-the-box (``buffer.__getattr__`` is customized).
"""
self.__dict__.update(state)

def _add_to_buffer(self, name: str, inst: Any) -> None:
try:
Expand All @@ -149,9 +165,8 @@ def _add_to_buffer(self, name: str, inst: Any) -> None:
value = self._meta.__dict__[name]
if isinstance(inst, np.ndarray) and value.shape[1:] != inst.shape:
raise ValueError(
"Cannot add data to a buffer with different shape, key: "
f"{name}, expect shape: {value.shape[1:]}, "
f"given shape: {inst.shape}.")
"Cannot add data to a buffer with different shape, with key "
f"{name}, expect {value.shape[1:]}, given {inst.shape}.")
try:
value[self._index] = inst
except KeyError:
Expand Down Expand Up @@ -261,47 +276,42 @@ def get(self, indice: Union[slice, int, np.integer, np.ndarray], key: str,
"""
if stack_num is None:
stack_num = self.stack_num
if isinstance(indice, slice):
indice = np.arange(
0 if indice.start is None
else self._size - indice.start if indice.start < 0
else indice.start,
self._size if indice.stop is None
else self._size - indice.stop if indice.stop < 0
else indice.stop,
1 if indice.step is None else indice.step)
else:
indice = np.array(indice, copy=True)
# set last frame done to True
last_index = (self._index - 1 + self._size) % self._size
last_done, self.done[last_index] = self.done[last_index], True
if key == 'obs_next' and (not self._save_s_ or self.obs_next is None):
indice += 1 - self.done[indice].astype(np.int)
if stack_num == 1: # the most often case
if key != 'obs_next' or self._save_s_:
val = self._meta.__dict__[key]
try:
return val[indice]
except IndexError as e:
if not (isinstance(val, Batch) and val.is_empty()):
raise e # val != Batch()
return Batch()
indice = self._indices[:self._size][indice]
done = self._meta.__dict__['done']
if key == 'obs_next' and not self._save_s_:
indice += 1 - done[indice].astype(np.int)
indice[indice == self._size] = 0
key = 'obs'
val = self._meta.__dict__[key]
try:
if stack_num > 1:
stack = []
for _ in range(stack_num):
stack = [val[indice]] + stack
pre_indice = np.asarray(indice - 1)
pre_indice[pre_indice == -1] = self._size - 1
indice = np.asarray(
pre_indice + self.done[pre_indice].astype(np.int))
indice[indice == self._size] = 0
if isinstance(val, Batch):
stack = Batch.stack(stack, axis=indice.ndim)
else:
stack = np.stack(stack, axis=indice.ndim)
if stack_num == 1:
return val[indice]
stack = []
for _ in range(stack_num):
stack = [val[indice]] + stack
pre_indice = np.asarray(indice - 1)
pre_indice[pre_indice == -1] = self._size - 1
indice = np.asarray(
pre_indice + done[pre_indice].astype(np.int))
indice[indice == self._size] = 0
if isinstance(val, Batch):
stack = Batch.stack(stack, axis=indice.ndim)
else:
stack = val[indice]
stack = np.stack(stack, axis=indice.ndim)
return stack
except IndexError as e:
stack = Batch()
if not isinstance(val, Batch) or len(val.__dict__) > 0:
raise e
self.done[last_index] = last_done
return stack
if not (isinstance(val, Batch) and val.is_empty()):
raise e # val != Batch()
return Batch()

def __getitem__(self, index: Union[
slice, int, np.integer, np.ndarray]) -> Batch:
Expand Down Expand Up @@ -380,7 +390,7 @@ def __getattr__(self, key: str) -> Union['Batch', Any]:
"""Return self.key"""
if key == 'weight':
return self._weight
return self._meta.__dict__[key]
return super().__getattr__(key)

def add(self,
obs: Union[dict, np.ndarray],
Expand Down
3 changes: 2 additions & 1 deletion tianshou/policy/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,8 @@ def post_process_fn(self, batch: Batch,
usage is to update the sampling weight in prioritized experience
replay. Check out :ref:`policy_concept` for more information.
"""
if isinstance(buffer, PrioritizedReplayBuffer):
if isinstance(buffer, PrioritizedReplayBuffer) \
and hasattr(batch, 'weight'):
buffer.update_weight(indice, batch.weight)

def update(self, batch_size: int, buffer: ReplayBuffer, *args, **kwargs):
Expand Down