Allow custom (non-reserved) keys to be stored in and read back from the replay buffer.

Summary of the changes carried by this patch:
  * docs/spelling_wordlist.txt: adds ten words to the spelling word list.
  * test/base/test_buffer.py: adds test_custom_key(), which builds a Batch
    carrying an extra 'returns' key, adds it to a ReplayBuffer, samples one
    element and compares keys and values; the test is also registered in the
    __main__ block.
  * tianshou/data/buffer/base.py:
      - add() copies every key of the incoming batch instead of only the keys
        in self._input_keys; its docstring no longer claims that keys must
        belong to the input keys.
      - __getitem__ builds the result from a dict and additionally returns
        every key of self._meta that is not listed in self._input_keys.

NOTE(review): this patch had been collapsed onto three physical lines. The
line structure below was restored so that every hunk matches the old/new line
counts in its header; leading indentation of context lines was reconstructed
and should be checked against the target tree before applying (or apply with
whitespace-insensitive matching).

NOTE(review): in test_custom_key the value-comparison loop only asserts when
both sides are np.ndarray or both are Batch; a key whose two values have
different types is not asserted on at all, so a type mismatch passes silently.

diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt
index 7eaf69004..1a831d75c 100644
--- a/docs/spelling_wordlist.txt
+++ b/docs/spelling_wordlist.txt
@@ -166,3 +166,13 @@ isort
 yapf
 pydocstyle
 Args
+tuples
+tuple
+Multi
+multi
+parameterized
+Proximal
+metadata
+GPU
+Dopamine
+builtin
diff --git a/test/base/test_buffer.py b/test/base/test_buffer.py
index 203d8e4fb..3ffdeccb9 100644
--- a/test/base/test_buffer.py
+++ b/test/base/test_buffer.py
@@ -1336,6 +1336,65 @@ def test_from_data():
     os.remove(path)
 
 
+def test_custom_key():
+    batch = Batch(
+        **{
+            'obs_next':
+            np.array(
+                [
+                    [
+                        1.174, -0.1151, -0.609, -0.5205, -0.9316, 3.236, -2.418, 0.386,
+                        0.2227, -0.5117, 2.293
+                    ]
+                ]
+            ),
+            'rew':
+            np.array([4.28125]),
+            'act':
+            np.array([[-0.3088, -0.4636, 0.4956]]),
+            'truncated':
+            np.array([False]),
+            'obs':
+            np.array(
+                [
+                    [
+                        1.193, -0.1203, -0.6123, -0.519, -0.9434, 3.32, -2.266, 0.9116,
+                        0.623, 0.1259, 0.363
+                    ]
+                ]
+            ),
+            'terminated':
+            np.array([False]),
+            'done':
+            np.array([False]),
+            'returns':
+            np.array([74.70343082]),
+            'info':
+            Batch(),
+            'policy':
+            Batch(),
+        }
+    )
+    buffer_size = len(batch.rew)
+    buffer = ReplayBuffer(buffer_size)
+    buffer.add(batch)
+    sampled_batch, _ = buffer.sample(1)
+    # Check if they have the same keys
+    assert set(batch.keys()) == set(sampled_batch.keys()), \
+        "Batches have different keys: {} and {}".format(
+            set(batch.keys()), set(sampled_batch.keys()))
+    # Compare the values for each key
+    for key in batch.keys():
+        if isinstance(batch.__dict__[key], np.ndarray
+                      ) and isinstance(sampled_batch.__dict__[key], np.ndarray):
+            assert np.allclose(batch.__dict__[key], sampled_batch.__dict__[key]), \
+                "Value mismatch for key: {}".format(key)
+        if isinstance(batch.__dict__[key],
+                      Batch) and isinstance(sampled_batch.__dict__[key], Batch):
+            assert batch.__dict__[key].is_empty()
+            assert sampled_batch.__dict__[key].is_empty()
+
+
 if __name__ == '__main__':
     test_replaybuffer()
     test_ignore_obs_next()
@@ -1351,3 +1410,4 @@ def test_from_data():
     test_multibuf_hdf5()
     test_from_data()
     test_herreplaybuffer()
+    test_custom_key()
diff --git a/tianshou/data/buffer/base.py b/tianshou/data/buffer/base.py
index aa159bd34..6294ec36d 100644
--- a/tianshou/data/buffer/base.py
+++ b/tianshou/data/buffer/base.py
@@ -220,9 +220,8 @@ def add(
     ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
         """Add a batch of data into replay buffer.
 
-        :param Batch batch: the input data batch. Its keys must belong to the 7
-            input keys, and "obs", "act", "rew", "terminated", "truncated" is
-            required.
+        :param Batch batch: the input data batch. "obs", "act", "rew",
+            "terminated", "truncated" are required keys.
         :param buffer_ids: to make consistent with other buffer's add function;
             if it is not None, we assume the input batch's first dimension is
             always 1.
@@ -232,12 +231,12 @@ def add(
         """
         # preprocess batch
         new_batch = Batch()
-        for key in set(self._input_keys).intersection(batch.keys()):
+        for key in batch.keys():
             new_batch.__dict__[key] = batch[key]
         batch = new_batch
         batch.__dict__["done"] = np.logical_or(batch.terminated, batch.truncated)
-        assert set(["obs", "act", "rew", "terminated", "truncated",
-                    "done"]).issubset(batch.keys())
+        assert set(["obs", "act", "rew", "terminated", "truncated", "done"]
+                   ).issubset(batch.keys())  # important to do after preprocess batch
         stacked_batch = buffer_ids is not None
         if stacked_batch:
             assert len(batch) == 1
@@ -376,14 +375,18 @@ def __getitem__(self, index: Union[slice, int, List[int], np.ndarray]) -> Batch:
             obs_next = self.get(indices, "obs_next", Batch())
         else:
             obs_next = self.get(self.next(indices), "obs", Batch())
-        return Batch(
-            obs=obs,
-            act=self.act[indices],
-            rew=self.rew[indices],
-            terminated=self.terminated[indices],
-            truncated=self.truncated[indices],
-            done=self.done[indices],
-            obs_next=obs_next,
-            info=self.get(indices, "info", Batch()),
-            policy=self.get(indices, "policy", Batch()),
-        )
+        batch_dict = {
+            "obs": obs,
+            "act": self.act[indices],
+            "rew": self.rew[indices],
+            "terminated": self.terminated[indices],
+            "truncated": self.truncated[indices],
+            "done": self.done[indices],
+            "obs_next": obs_next,
+            "info": self.get(indices, "info", Batch()),
+            "policy": self.get(indices, "policy", Batch()),
+        }
+        for key in self._meta.__dict__.keys():
+            if key not in self._input_keys:
+                batch_dict[key] = self._meta[key][indices]
+        return Batch(batch_dict)