diff --git a/test/base/test_batch.py b/test/base/test_batch.py index f11a8d60e..801ab4484 100644 --- a/test/base/test_batch.py +++ b/test/base/test_batch.py @@ -556,7 +556,7 @@ def test_batch_standard_compatibility() -> None: batch = Batch(a=np.array([[1.0, 2.0], [3.0, 4.0]]), b=Batch(), c=np.array([5.0, 6.0])) batch_mean = np.mean(batch) assert isinstance(batch_mean, Batch) # type: ignore # mypy doesn't know but it works, cf. `batch.rst` - assert sorted(batch_mean.keys()) == ["a", "b", "c"] # type: ignore + assert sorted(batch_mean.get_keys()) == ["a", "b", "c"] # type: ignore with pytest.raises(TypeError): len(batch_mean) assert np.all(batch_mean.a == np.mean(batch.a, axis=0)) diff --git a/test/base/test_buffer.py b/test/base/test_buffer.py index 31265f664..5488ff365 100644 --- a/test/base/test_buffer.py +++ b/test/base/test_buffer.py @@ -1379,11 +1379,14 @@ def test_custom_key() -> None: buffer.add(batch) sampled_batch, _ = buffer.sample(1) # Check if they have the same keys - assert set(batch.keys()) == set( - sampled_batch.keys(), - ), "Batches have different keys: {} and {}".format(set(batch.keys()), set(sampled_batch.keys())) + assert set(batch.get_keys()) == set( + sampled_batch.get_keys(), + ), "Batches have different keys: {} and {}".format( + set(batch.get_keys()), + set(sampled_batch.get_keys()), + ) # Compare the values for each key - for key in batch.keys(): + for key in batch.get_keys(): if isinstance(batch.__dict__[key], np.ndarray) and isinstance( sampled_batch.__dict__[key], np.ndarray, diff --git a/tianshou/data/batch.py b/tianshou/data/batch.py index 5c7fa036e..1136beb23 100644 --- a/tianshou/data/batch.py +++ b/tianshou/data/batch.py @@ -1,6 +1,6 @@ import pprint import warnings -from collections.abc import Collection, Iterable, Iterator, Sequence +from collections.abc import Collection, Iterable, Iterator, KeysView, Sequence from copy import deepcopy from numbers import Number from types import EllipsisType @@ -185,8 +185,8 @@ def alloc_by_keys_diff( This mainly is an internal method, use it only if you know what you are doing. """ - for key in batch.keys(): - if key in meta.keys(): + for key in batch.get_keys(): + if key in meta.get_keys(): if isinstance(meta[key], Batch) and isinstance(batch[key], Batch): alloc_by_keys_diff(meta[key], batch[key], size, stack) elif isinstance(meta[key], Batch) and meta[key].is_empty(): @@ -441,6 +441,9 @@ def to_dict(self) -> dict[str, Any]: result[k] = v return result + def get_keys(self) -> KeysView: + return self.__dict__.keys() + def to_list_of_dicts(self) -> list[dict[str, Any]]: return [entry.to_dict() for entry in self] diff --git a/tianshou/data/buffer/base.py b/tianshou/data/buffer/base.py index 53f9bd8eb..a34964f5a 100644 --- a/tianshou/data/buffer/base.py +++ b/tianshou/data/buffer/base.py @@ -253,12 +253,12 @@ def add( """ # preprocess batch new_batch = Batch() - for key in batch.keys(): + for key in batch.get_keys(): new_batch.__dict__[key] = batch[key] batch = new_batch batch.__dict__["done"] = np.logical_or(batch.terminated, batch.truncated) assert {"obs", "act", "rew", "terminated", "truncated", "done"}.issubset( - batch.keys(), + batch.get_keys(), ) # important to do after preprocess batch stacked_batch = buffer_ids is not None if stacked_batch: diff --git a/tianshou/data/buffer/manager.py b/tianshou/data/buffer/manager.py index a495b0ada..90480257a 100644 --- a/tianshou/data/buffer/manager.py +++ b/tianshou/data/buffer/manager.py @@ -127,11 +127,11 @@ def add( """ # preprocess batch new_batch = Batch() - for key in set(self._reserved_keys).intersection(batch.keys()): + for key in set(self._reserved_keys).intersection(batch.get_keys()): new_batch.__dict__[key] = batch[key] batch = new_batch batch.__dict__["done"] = np.logical_or(batch.terminated, batch.truncated) - assert {"obs", "act", "rew", "terminated", "truncated", "done"}.issubset(batch.keys()) + assert {"obs", "act", "rew", "terminated", "truncated", "done"}.issubset(batch.get_keys()) if self._save_only_last_obs: batch.obs = batch.obs[:, -1] if not self._save_obs_next: diff --git a/tianshou/policy/multiagent/mapolicy.py b/tianshou/policy/multiagent/mapolicy.py index 7a7be5888..81cfe0a6d 100644 --- a/tianshou/policy/multiagent/mapolicy.py +++ b/tianshou/policy/multiagent/mapolicy.py @@ -225,7 +225,7 @@ def forward( # type: ignore results.append((False, np.array([-1]), Batch(), Batch(), Batch())) continue tmp_batch = batch[agent_index] - if "rew" in tmp_batch.keys() and isinstance(tmp_batch.rew, np.ndarray): + if "rew" in tmp_batch.get_keys() and isinstance(tmp_batch.rew, np.ndarray): # reward can be empty Batch (after initial reset) or nparray. tmp_batch.rew = tmp_batch.rew[:, self.agent_idx[agent_id]] if not hasattr(tmp_batch.obs, "mask"):