diff --git a/test/base/test_batch.py b/test/base/test_batch.py index 7e287e269..4ba5d8d5a 100644 --- a/test/base/test_batch.py +++ b/test/base/test_batch.py @@ -10,7 +10,12 @@ def test_batch(): assert list(Batch()) == [] assert Batch().is_empty() - assert Batch(b={'c': {}}).is_empty() + assert not Batch(b={'c': {}}).is_empty() + assert Batch(b={'c': {}}).is_empty(recurse=True) + assert not Batch(a=Batch(), b=Batch(c=Batch())).is_empty() + assert Batch(a=Batch(), b=Batch(c=Batch())).is_empty(recurse=True) + assert not Batch(d=1).is_empty() + assert not Batch(a=np.float64(1.0)).is_empty() assert len(Batch(a=[1, 2, 3], b={'c': {}})) == 3 assert not Batch(a=[1, 2, 3]).is_empty() b = Batch() @@ -109,6 +114,11 @@ def test_batch(): assert isinstance(batch5.b, Batch) assert np.allclose(batch5.b.index, [1]) + # None is a valid object and can be stored in Batch + a = Batch.stack([Batch(a=None), Batch(b=None)]) + assert a.a[0] is None and a.a[1] is None + assert a.b[0] is None and a.b[1] is None + def test_batch_over_batch(): batch = Batch(a=[3, 4, 5], b=[4, 5, 6]) @@ -162,6 +172,20 @@ def test_batch_cat_and_stack(): assert isinstance(b12_cat_in.a.d.e, np.ndarray) assert b12_cat_in.a.d.e.ndim == 1 + a = Batch(a=Batch(a=np.random.randn(3, 4))) + assert np.allclose( + np.concatenate([a.a.a, a.a.a]), + Batch.cat([a, Batch(a=Batch(a=Batch())), a]).a.a) + + # test cat with lens infer + a = Batch(a=Batch(a=np.random.randn(3, 4)), b=np.random.randn(3, 4)) + b = Batch(a=Batch(a=Batch(), t=Batch()), b=np.random.randn(3, 4)) + ans = Batch.cat([a, b, a]) + assert np.allclose(ans.a.a, + np.concatenate([a.a.a, np.zeros((3, 4)), a.a.a])) + assert np.allclose(ans.b, np.concatenate([a.b, b.b, a.b])) + assert ans.a.t.is_empty() + b12_stack = Batch.stack((b1, b2)) assert isinstance(b12_stack.a.d.e, np.ndarray) assert b12_stack.a.d.e.ndim == 2 @@ -177,6 +201,32 @@ def test_batch_cat_and_stack(): assert torch.allclose(test.b, ans.b) assert np.allclose(test.common.c, ans.common.c) + # test cat 
with reserved keys (values are Batch()) + b1 = Batch(a=np.random.rand(3, 4), common=Batch(c=np.random.rand(3, 5))) + b2 = Batch(a=Batch(), + b=torch.rand(4, 3), + common=Batch(c=np.random.rand(4, 5))) + test = Batch.cat([b1, b2]) + ans = Batch(a=np.concatenate([b1.a, np.zeros((4, 4))]), + b=torch.cat([torch.zeros(3, 3), b2.b]), + common=Batch(c=np.concatenate([b1.common.c, b2.common.c]))) + assert np.allclose(test.a, ans.a) + assert torch.allclose(test.b, ans.b) + assert np.allclose(test.common.c, ans.common.c) + + # test cat with all reserved keys (values are Batch()) + b1 = Batch(a=Batch(), common=Batch(c=np.random.rand(3, 5))) + b2 = Batch(a=Batch(), + b=torch.rand(4, 3), + common=Batch(c=np.random.rand(4, 5))) + test = Batch.cat([b1, b2]) + ans = Batch(a=Batch(), + b=torch.cat([torch.zeros(3, 3), b2.b]), + common=Batch(c=np.concatenate([b1.common.c, b2.common.c]))) + assert test.a.is_empty() + assert torch.allclose(test.b, ans.b) + assert np.allclose(test.common.c, ans.common.c) + # test stack with compatible keys b3 = Batch(a=np.zeros((3, 4)), b=torch.ones((2, 5)), @@ -205,6 +255,25 @@ def test_batch_cat_and_stack(): assert np.allclose(d.c, [3, 0, 7]) assert np.allclose(d.d, [0, 6, 9]) + # test stack with empty Batch() + assert Batch.stack([Batch(), Batch(), Batch()]).is_empty() + a = Batch(a=1, b=2, c=3, d=Batch(), e=Batch()) + b = Batch(a=4, b=5, d=6, e=Batch()) + c = Batch(c=7, b=6, d=9, e=Batch()) + d = Batch.stack([a, b, c]) + assert np.allclose(d.a, [1, 4, 0]) + assert np.allclose(d.b, [2, 5, 6]) + assert np.allclose(d.c, [3, 0, 7]) + assert np.allclose(d.d, [0, 6, 9]) + assert d.e.is_empty() + b1 = Batch(a=Batch(), common=Batch(c=np.random.rand(4, 5))) + b2 = Batch(b=Batch(), common=Batch(c=np.random.rand(4, 5))) + test = Batch.stack([b1, b2], axis=-1) + assert test.a.is_empty() + assert test.b.is_empty() + assert np.allclose(test.common.c, + np.stack([b1.common.c, b2.common.c], axis=-1)) + b1 = Batch(a=np.random.rand(4, 4),
common=Batch(c=np.random.rand(4, 5))) b2 = Batch(b=torch.rand(4, 6), common=Batch(c=np.random.rand(4, 5))) test = Batch.stack([b1, b2]) diff --git a/tianshou/data/batch.py b/tianshou/data/batch.py index 23d859999..7eaca1144 100644 --- a/tianshou/data/batch.py +++ b/tianshou/data/batch.py @@ -14,11 +14,17 @@ def _is_batch_set(data: Any) -> bool: + # Batch set is a list/tuple of dict/Batch objects, + # or 1-D np.ndarray with np.object type, + # where each element is a dict/Batch object if isinstance(data, (list, tuple)): if len(data) > 0 and all(isinstance(e, (dict, Batch)) for e in data): return True elif isinstance(data, np.ndarray) and data.dtype == np.object: - if all(isinstance(e, (dict, Batch)) for e in data.tolist()): + # ``for e in data`` will just unpack the first dimension, + # but data.tolist() will flatten ndarray of objects + # so do not use data.tolist() + if all(isinstance(e, (dict, Batch)) for e in data): return True return False @@ -39,7 +45,7 @@ def _create_value(inst: Any, size: int, stack=True) -> Union[ # here we do not consider scalar types, following the # behavior of numpy which does not support concatenation # of zero-dimensional arrays (scalars) - raise TypeError(f"cannot cat {inst} with which is scalar") + raise TypeError(f"cannot concatenate with {inst} which is scalar") if has_shape: shape = (size, *inst.shape) if stack else (size, *inst.shape[1:]) if isinstance(inst, np.ndarray): @@ -95,9 +101,9 @@ class Batch: In short, you can define a :class:`Batch` with any key-value pair. - For Numpy arrays, only data types with ``np.object``, bool, and number - are supported. For strings or other data types, however, they can be - held in ``np.object`` arrays. + For Numpy arrays, only data types with ``np.object``, bool, and number are + supported. For strings or other data types, however, they can be held in + ``np.object`` arrays. 
The current implementation of Tianshou typically use 7 reserved keys in :class:`~tianshou.data.Batch`: @@ -108,9 +114,39 @@ class Batch: * ``done`` the done flag of step :math:`t` ; * ``obs_next`` the observation of step :math:`t+1` ; * ``info`` the info of step :math:`t` (in ``gym.Env``, the ``env.step()``\ - function returns 4 arguments, and the last one is ``info``); + function returns 4 arguments, and the last one is ``info``); * ``policy`` the data computed by policy in step :math:`t`; + For convenience, :class:`~tianshou.data.Batch` supports the mechanism of + key reservation: one can specify a key without any value, which serves as + a placeholder for the Batch object. For example, you know there will be a + key named ``obs``, but do not know the value until the simulator runs. Then + you can reserve the key ``obs``. This is done by setting the value to + ``Batch()``. + + For a Batch object, we call it "incomplete" if: (i) it is ``Batch()``; (ii) + it has reserved keys; (iii) any of its sub-Batch is incomplete. Otherwise, + the Batch object is finalized. + + Key reservation is convenient, but it also causes problems in + aggregation operators like ``stack`` or ``cat`` of Batch objects. We say + that Batch objects are compatible for aggregation in three cases: + + 1. finalized Batch objects are compatible if and only if there exists a \ + way to extend keys so that their structures are exactly the same. + + 2. incomplete Batch objects and other finalized objects are compatible if \ + there exists a way to extend keys so that incomplete Batch objects can \ + have the same structure as finalized objects. + + 3. incomplete Batch objects themselves are compatible if there exists a \ + way to extend keys so that their structure can be the same. + + In a word, incomplete Batch objects have a set of possible structures + in the future, but finalized Batch objects only have a finalized structure.
+ Batch objects are compatible if and only if they share at least one + commonly possible structure by extending keys. + :class:`~tianshou.data.Batch` object can be initialized by a wide variety of arguments, ranging from the key/value pairs or dictionary, to list and Numpy arrays of :class:`dict` or Batch instances where each element is @@ -126,8 +162,8 @@ class Batch: ) :class:`~tianshou.data.Batch` has the same API as a native Python - :class:`dict`. In this regard, one can access stored data using string - key, or iterate over stored data: + :class:`dict`. In this regard, one can access stored data using string key, + or iterate over stored data: :: >>> data = Batch(a=4, b=[5, 5]) @@ -153,7 +189,7 @@ class Batch: ) >>> for sample in data: >>> print(sample.a) - [0., 2.] + [0. 2.] >>> print(data.shape) [1, 2] @@ -341,7 +377,7 @@ def __getitem__(self, index: Union[ if len(batch_items) > 0: b = Batch() for k, v in batch_items: - if isinstance(v, Batch) and len(v.__dict__) == 0: + if isinstance(v, Batch) and v.is_empty(): b.__dict__[k] = Batch() else: b.__dict__[k] = v[index] @@ -376,8 +412,9 @@ def __setitem__(self, index: Union[ except KeyError: if isinstance(val, Batch): self.__dict__[key][index] = Batch() - elif isinstance(val, np.ndarray) and \ - issubclass(val.dtype.type, (np.bool_, np.number)): + elif isinstance(val, torch.Tensor) or \ + (isinstance(val, np.ndarray) and + issubclass(val.dtype.type, (np.bool_, np.number))): self.__dict__[key][index] = 0 else: self.__dict__[key][index] = None @@ -389,14 +426,14 @@ def __iadd__(self, other: Union['Batch', Number, np.number]): for (k, r), v in zip(self.__dict__.items(), other.__dict__.values()): # TODO are keys consistent? 
- if r is None: + if isinstance(r, Batch) and r.is_empty(): continue else: self.__dict__[k] += v return self elif isinstance(other, (Number, np.number)): for k, r in self.items(): - if r is None: + if isinstance(r, Batch) and r.is_empty(): continue else: self.__dict__[k] += other @@ -413,7 +450,9 @@ def __imul__(self, val: Union[Number, np.number]): """Algebraic multiplication with a scalar value in-place.""" assert isinstance(val, (Number, np.number)), \ "Only multiplication by a number is supported." - for k in self.__dict__.keys(): + for k, r in self.__dict__.items(): + if isinstance(r, Batch) and r.is_empty(): + continue self.__dict__[k] *= val return self @@ -425,7 +464,9 @@ def __itruediv__(self, val: Union[Number, np.number]): """Algebraic division with a scalar value in-place.""" assert isinstance(val, (Number, np.number)), \ "Only division by a number is supported." - for k in self.__dict__.keys(): + for k, r in self.__dict__.items(): + if isinstance(r, Batch) and r.is_empty(): + continue self.__dict__[k] /= val return self @@ -501,50 +542,106 @@ def to_torch(self, dtype: Optional[torch.dtype] = None, elif isinstance(v, Batch): v.to_torch(dtype, device) - def cat_(self, - batches: Union['Batch', List[Union[dict, 'Batch']]]) -> None: - """Concatenate a list of (or one) :class:`~tianshou.data.Batch` objects - into current batch. + def __cat(self, + batches: Union['Batch', List[Union[dict, 'Batch']]], + lens: List[int]) -> None: + """:: + + >>> a = Batch(a=np.random.randn(3, 4)) + >>> x = Batch(a=a, b=np.random.randn(4, 4)) + >>> y = Batch(a=Batch(a=Batch()), b=np.random.randn(4, 4)) + + If we want to concatenate x and y, we want to pad y.a.a with zeros. + Without ``lens`` as a hint, when we concatenate x.a and y.a, we would + not be able to know how to pad y.a. So ``Batch.cat_`` should compute + the ``lens`` to give ``Batch.__cat`` a hint. 
+ :: + + >>> ans = Batch.cat([x, y]) + >>> # this is equivalent to the following line + >>> ans = Batch(); ans.__cat([x, y], lens=[3, 4]) + >>> # this lens is equal to [len(x), len(y)] """ - if isinstance(batches, Batch): - batches = [batches] - if len(batches) == 0: - return - batches = [x if isinstance(x, Batch) else Batch(x) for x in batches] - if len(self.__dict__) > 0: - batches = [self] + list(batches) # partial keys will be padded by zeros # with the shape of [len, rest_shape] - lens = [len(x) for x in batches] sum_lens = [0] for x in lens: sum_lens.append(sum_lens[-1] + x) - keys_map = list(map(lambda e: set(e.keys()), batches)) + # collect non-empty keys + keys_map = [ + set(k for k, v in batch.items() + if not (isinstance(v, Batch) and v.is_empty())) + for batch in batches] keys_shared = set.intersection(*keys_map) values_shared = [[e[k] for e in batches] for k in keys_shared] _assert_type_keys(keys_shared) for k, v in zip(keys_shared, values_shared): if all(isinstance(e, (dict, Batch)) for e in v): - self.__dict__[k] = Batch.cat(v) + batch_holder = Batch() + batch_holder.__cat(v, lens=lens) + self.__dict__[k] = batch_holder elif all(isinstance(e, torch.Tensor) for e in v): self.__dict__[k] = torch.cat(v) else: + # cat Batch(a=np.zeros((3, 4))) and Batch(a=Batch(b=Batch())) + # will fail here v = np.concatenate(v) if not issubclass(v.dtype.type, (np.bool_, np.number)): v = v.astype(np.object) self.__dict__[k] = v - keys_partial = set.union(*keys_map) - keys_shared - _assert_type_keys(keys_partial) + keys_total = set.union(*[set(b.keys()) for b in batches]) + keys_reserve_or_partial = set.difference(keys_total, keys_shared) + _assert_type_keys(keys_reserve_or_partial) + # keys that are reserved in all batches + keys_reserve = set.difference(keys_total, set.union(*keys_map)) + # keys that occur only in some batches, but not all + keys_partial = keys_reserve_or_partial.difference(keys_reserve) + for k in keys_reserve: + # reserved keys + self.__dict__[k] =
Batch() for k in keys_partial: for i, e in enumerate(batches): - val = e.get(k, None) - if val is not None: - try: - self.__dict__[k][sum_lens[i]:sum_lens[i + 1]] = val - except KeyError: - self.__dict__[k] = \ - _create_value(val, sum_lens[-1], stack=False) - self.__dict__[k][sum_lens[i]:sum_lens[i + 1]] = val + if k not in e.__dict__: + continue + val = e.get(k) + if isinstance(val, Batch) and val.is_empty(): + continue + try: + self.__dict__[k][sum_lens[i]:sum_lens[i + 1]] = val + except KeyError: + self.__dict__[k] = \ + _create_value(val, sum_lens[-1], stack=False) + self.__dict__[k][sum_lens[i]:sum_lens[i + 1]] = val + + def cat_(self, + batches: Union['Batch', List[Union[dict, 'Batch']]]) -> None: + """Concatenate a list of (or one) :class:`~tianshou.data.Batch` objects + into current batch. + """ + if isinstance(batches, Batch): + batches = [batches] + if len(batches) == 0: + return + batches = [x if isinstance(x, Batch) else Batch(x) for x in batches] + + # x.is_empty() means that x is Batch() and should be ignored + batches = [x for x in batches if not x.is_empty()] + try: + # x.is_empty(recurse=True) here means x is a nested empty batch + # like Batch(a=Batch), and we have to treat it as length zero and + # keep it. + lens = [0 if x.is_empty(recurse=True) else len(x) + for x in batches] + except TypeError as e: + e2 = ValueError( + f'Batch.cat_ meets an exception. 
Maybe because there is ' + f'any scalar in {batches} but Batch.cat_ does not support ' + f'the concatenation of scalar.') + raise e2 from e + if not self.is_empty(): + batches = [self] + list(batches) + lens = [0 if self.is_empty(recurse=True) else len(self)] + lens + return self.__cat(batches, lens) if batches else None @staticmethod def cat(batches: List[Union[dict, 'Batch']]) -> 'Batch': @@ -577,9 +674,13 @@ def stack_(self, if len(batches) == 0: return batches = [x if isinstance(x, Batch) else Batch(x) for x in batches] - if len(self.__dict__) > 0: + if not self.is_empty(): batches = [self] + list(batches) - keys_map = list(map(lambda e: set(e.keys()), batches)) + # collect non-empty keys + keys_map = [ + set(k for k, v in batch.items() + if not (isinstance(v, Batch) and v.is_empty())) + for batch in batches] keys_shared = set.intersection(*keys_map) values_shared = [[e[k] for e in batches] for k in keys_shared] _assert_type_keys(keys_shared) @@ -593,22 +694,35 @@ def stack_(self, if not issubclass(v.dtype.type, (np.bool_, np.number)): v = v.astype(np.object) self.__dict__[k] = v - keys_partial = set.difference(set.union(*keys_map), keys_shared) + # all the keys + keys_total = set.union(*[set(b.keys()) for b in batches]) + # keys that are reserved in all batches + keys_reserve = set.difference(keys_total, set.union(*keys_map)) + # keys that are either partial or reserved + keys_reserve_or_partial = set.difference(keys_total, keys_shared) + # keys that occur only in some batches, but not all + keys_partial = keys_reserve_or_partial.difference(keys_reserve) if keys_partial and axis != 0: raise ValueError( f"Stack of Batch with non-shared keys {keys_partial} " f"is only supported with axis=0, but got axis={axis}!") - _assert_type_keys(keys_partial) + _assert_type_keys(keys_reserve_or_partial) + for k in keys_reserve: + # reserved keys + self.__dict__[k] = Batch() for k in keys_partial: for i, e in enumerate(batches): - val = e.get(k, None) - if val is not None: - try: -
self.__dict__[k][i] = val - except KeyError: - self.__dict__[k] = \ - _create_value(val, len(batches)) - self.__dict__[k][i] = val + if k not in e.__dict__: + continue + val = e.get(k) + if isinstance(val, Batch) and val.is_empty(): + continue + try: + self.__dict__[k][i] = val + except KeyError: + self.__dict__[k] = \ + _create_value(val, len(batches)) + self.__dict__[k][i] = val @staticmethod def stack(batches: List[Union[dict, 'Batch']], axis: int = 0) -> 'Batch': @@ -691,26 +805,53 @@ def __len__(self) -> int: """Return len(self).""" r = [] for v in self.__dict__.values(): - if isinstance(v, Batch) and v.is_empty(): + if isinstance(v, Batch) and v.is_empty(recurse=True): continue elif hasattr(v, '__len__') and (not isinstance( v, (np.ndarray, torch.Tensor)) or v.ndim > 0): r.append(len(v)) else: - raise TypeError("Object of type 'Batch' has no len()") + raise TypeError(f"Object {v} in {self} has no len()") if len(r) == 0: - raise TypeError("Object of type 'Batch' has no len()") + raise TypeError(f"Object {self} has no len()") return min(r) - def is_empty(self): - return not any( - not x.is_empty() if isinstance(x, Batch) - else hasattr(x, '__len__') and len(x) > 0 for x in self.values()) + def is_empty(self, recurse: bool = False): + """ + Test if a Batch is empty. If ``recurse=True``, it further tests the + values of the object; else it only tests the existence of any key. + + ``b.is_empty(recurse=True)`` is mainly used to distinguish + ``Batch(a=Batch(a=Batch()))`` and ``Batch(a=1)``. They both raise + exceptions when applied to ``len()``, but the former can be used in + ``cat``, while the latter is a scalar and cannot be used in ``cat``. + + Another usage is in ``__len__``, where we have to skip checking the + length of recursively empty Batch objects.
+ :: + + >>> Batch().is_empty() + True + >>> Batch(a=Batch(), b=Batch(c=Batch())).is_empty() + False + >>> Batch(a=Batch(), b=Batch(c=Batch())).is_empty(recurse=True) + True + >>> Batch(d=1).is_empty() + False + >>> Batch(a=np.float64(1.0)).is_empty() + False + """ + if len(self.__dict__) == 0: + return True + if not recurse: + return False + return all(False if not isinstance(x, Batch) + else x.is_empty(recurse=True) for x in self.values()) @property def shape(self) -> List[int]: """Return self.shape.""" - if len(self.__dict__.keys()) == 0: + if self.is_empty(): return [] else: data_shape = [] diff --git a/tianshou/data/collector.py b/tianshou/data/collector.py index 2bacf14f0..98c62daa2 100644 --- a/tianshou/data/collector.py +++ b/tianshou/data/collector.py @@ -98,7 +98,7 @@ def __init__(self, stat_size: Optional[int] = 100, action_noise: Optional[BaseNoise] = None, reward_metric: Optional[Callable[[np.ndarray], float]] = None, - **kwargs) -> None: + ) -> None: super().__init__() self.env = env self.env_num = 1 diff --git a/tianshou/policy/base.py b/tianshou/policy/base.py index 335d7e86a..e75374228 100644 --- a/tianshou/policy/base.py +++ b/tianshou/policy/base.py @@ -108,7 +108,8 @@ def compute_episodic_return( batch: Batch, v_s_: Optional[Union[np.ndarray, torch.Tensor]] = None, gamma: float = 0.99, - gae_lambda: float = 0.95) -> Batch: + gae_lambda: float = 0.95, + ) -> Batch: """Compute returns over given full-length episodes, including the implementation of Generalized Advantage Estimator (arXiv:1506.02438). @@ -124,18 +125,19 @@ def compute_episodic_return( :return: a Batch. The result will be stored in batch.returns. """ + rew = batch.rew if v_s_ is None: - v_s_ = batch.rew * 0. + v_s_ = rew * 0. else: if not isinstance(v_s_, np.ndarray): v_s_ = np.array(v_s_, np.float) - v_s_ = v_s_.reshape(batch.rew.shape) + v_s_ = v_s_.reshape(rew.shape) returns = np.roll(v_s_, 1, axis=0) m = (1. 
- batch.done) * gamma - delta = batch.rew + v_s_ * m - returns + delta = rew + v_s_ * m - returns m *= gae_lambda gae = 0. - for i in range(len(batch.rew) - 1, -1, -1): + for i in range(len(rew) - 1, -1, -1): gae = delta[i] + m[i] * gae returns[i] += gae batch.returns = returns @@ -149,7 +151,7 @@ def compute_nstep_return( target_q_fn: Callable[[ReplayBuffer, np.ndarray], torch.Tensor], gamma: float = 0.99, n_step: int = 1, - rew_norm: bool = False + rew_norm: bool = False, ) -> np.ndarray: r"""Compute n-step return for Q-learning targets: @@ -180,8 +182,9 @@ def compute_nstep_return( :return: a Batch. The result will be stored in batch.returns as a torch.Tensor with shape (bsz, ). """ + rew = buffer.rew if rew_norm: - bfr = buffer.rew[:min(len(buffer), 1000)] # avoid large buffer + bfr = rew[:min(len(buffer), 1000)] # avoid large buffer mean, std = bfr.mean(), bfr.std() if np.isclose(std, 0): mean, std = 0, 1 @@ -189,7 +192,7 @@ def compute_nstep_return( mean, std = 0, 1 returns = np.zeros_like(indice) gammas = np.zeros_like(indice) + n_step - done, rew, buf_len = buffer.done, buffer.rew, len(buffer) + done, buf_len = buffer.done, len(buffer) for n in range(n_step - 1, -1, -1): now = (indice + n) % buf_len gammas[done[now] > 0] = n diff --git a/tianshou/policy/imitation/base.py b/tianshou/policy/imitation/base.py index 57bdba933..c01e45fd0 100644 --- a/tianshou/policy/imitation/base.py +++ b/tianshou/policy/imitation/base.py @@ -23,7 +23,7 @@ class ImitationPolicy(BasePolicy): """ def __init__(self, model: torch.nn.Module, optim: torch.optim.Optimizer, - mode: str = 'continuous', **kwargs) -> None: + mode: str = 'continuous') -> None: super().__init__() self.model = model self.optim = optim diff --git a/tianshou/policy/modelfree/dqn.py b/tianshou/policy/modelfree/dqn.py index c34ba4e04..eb6f29878 100644 --- a/tianshou/policy/modelfree/dqn.py +++ b/tianshou/policy/modelfree/dqn.py @@ -21,6 +21,8 @@ class DQNPolicy(BasePolicy): ahead. 
:param int target_update_freq: the target network update frequency (``0`` if you do not use the target network). + :param bool reward_normalization: normalize the reward to Normal(0, 1), + defaults to ``False``. .. seealso:: @@ -34,6 +36,7 @@ def __init__(self, discount_factor: float = 0.99, estimation_step: int = 1, target_update_freq: Optional[int] = 0, + reward_normalization: bool = False, **kwargs) -> None: super().__init__(**kwargs) self.model = model @@ -49,6 +52,7 @@ def __init__(self, if self._target: self.model_old = deepcopy(self.model) self.model_old.eval() + self._rew_norm = reward_normalization def set_eps(self, eps: float) -> None: """Set the eps for epsilon-greedy exploration.""" @@ -94,7 +98,8 @@ def process_fn(self, batch: Batch, buffer: ReplayBuffer, to :math:`Q_{new}`. """ batch = self.compute_nstep_return( - batch, buffer, indice, self._target_q, self._gamma, self._n_step) + batch, buffer, indice, self._target_q, + self._gamma, self._n_step, self._rew_norm) if isinstance(buffer, PrioritizedReplayBuffer): batch.update_weight = buffer.update_weight batch.indice = indice diff --git a/tianshou/trainer/offpolicy.py b/tianshou/trainer/offpolicy.py index 1e1272157..5cecd0570 100644 --- a/tianshou/trainer/offpolicy.py +++ b/tianshou/trainer/offpolicy.py @@ -27,7 +27,6 @@ def offpolicy_trainer( writer: Optional[SummaryWriter] = None, log_interval: int = 1, verbose: bool = True, - **kwargs ) -> Dict[str, Union[float, str]]: """A wrapper for off-policy trainer procedure. diff --git a/tianshou/trainer/onpolicy.py b/tianshou/trainer/onpolicy.py index e3849fe88..5f7ae7694 100644 --- a/tianshou/trainer/onpolicy.py +++ b/tianshou/trainer/onpolicy.py @@ -27,7 +27,6 @@ def onpolicy_trainer( writer: Optional[SummaryWriter] = None, log_interval: int = 1, verbose: bool = True, - **kwargs ) -> Dict[str, Union[float, str]]: """A wrapper for on-policy trainer procedure.