diff --git a/README.md b/README.md index 414e288c3..2c027d506 100644 --- a/README.md +++ b/README.md @@ -38,7 +38,7 @@ Here is Tianshou's other features: - Support recurrent state representation in actor network and critic network (RNN-style training for POMDP) [Usage](https://tianshou.readthedocs.io/en/latest/tutorials/cheatsheet.html#rnn-style-training) - Support any type of environment state (e.g. a dict, a self-defined class, ...) [Usage](https://tianshou.readthedocs.io/en/latest/tutorials/cheatsheet.html#user-defined-environment-and-different-state-representation) - Support customized training process [Usage](https://tianshou.readthedocs.io/en/latest/tutorials/cheatsheet.html#customize-training-process) -- Support n-step returns estimation for all Q-learning based algorithms +- Support n-step returns estimation and prioritized experience replay for all Q-learning based algorithms - Support multi-agent RL [Usage](https://tianshou.readthedocs.io/en/latest/tutorials/cheatsheet.html##multi-agent-reinforcement-learning) In Chinese, Tianshou means divinely ordained and is derived to the gift of being born with. Tianshou is a reinforcement learning platform, and the RL algorithm does not learn from humans. So taking "Tianshou" means that there is no teacher to study with, but rather to learn by themselves through constant interaction with the environment. diff --git a/docs/index.rst b/docs/index.rst index 25f041085..9ef598a81 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -28,7 +28,7 @@ Here is Tianshou's other features: * Support recurrent state representation in actor network and critic network (RNN-style training for POMDP): :ref:`rnn_training` * Support any type of environment state (e.g. a dict, a self-defined class, ...): :ref:`self_defined_env` * Support customized training process: :ref:`customize_training` -* Support n-step returns estimation :meth:`~tianshou.policy.BasePolicy.compute_nstep_return` for all Q-learning based algorithms +* Support n-step returns estimation :meth:`~tianshou.policy.BasePolicy.compute_nstep_return` and prioritized experience replay for all Q-learning based algorithms * Support multi-agent RL: :doc:`/tutorials/tictactoe` 中文文档位于 `https://tianshou.readthedocs.io/zh/latest/ `_ diff --git a/test/base/test_buffer.py b/test/base/test_buffer.py index d9f59d2b8..4c6bc7176 100644 --- a/test/base/test_buffer.py +++ b/test/base/test_buffer.py @@ -1,6 +1,9 @@ +import pytest import numpy as np +from timeit import timeit -from tianshou.data import Batch, PrioritizedReplayBuffer, ReplayBuffer +from tianshou.data import Batch, PrioritizedReplayBuffer, \ + ReplayBuffer, SegmentTree if __name__ == '__main__': from env import MyTestEnv @@ -112,9 +115,110 @@ def test_update(): assert (buf2[-1].obs == buf1[0].obs).all() +def test_segtree(): + for op, init in zip(['sum', 'max', 'min'], [0., -np.inf, np.inf]): + realop = getattr(np, op) + # small test + actual_len = 8 + tree = SegmentTree(actual_len, op) # 1-15. 
8-15 are leaf nodes + assert np.all([tree[i] == init for i in range(actual_len)]) + with pytest.raises(IndexError): + tree[actual_len] + naive = np.full([actual_len], init) + for _ in range(1000): + # random choose a place to perform single update + index = np.random.randint(actual_len) + value = np.random.rand() + naive[index] = value + tree[index] = value + for i in range(actual_len): + for j in range(i + 1, actual_len): + ref = realop(naive[i:j]) + out = tree.reduce(i, j) + assert np.allclose(ref, out) + # batch setitem + for _ in range(1000): + index = np.random.choice(actual_len, size=4) + value = np.random.rand(4) + naive[index] = value + tree[index] = value + assert np.allclose(realop(naive), tree.reduce()) + for i in range(10): + left = np.random.randint(actual_len) + right = np.random.randint(left + 1, actual_len + 1) + assert np.allclose(realop(naive[left:right]), + tree.reduce(left, right)) + # large test + actual_len = 16384 + tree = SegmentTree(actual_len, op) + naive = np.full([actual_len], init) + for _ in range(1000): + index = np.random.choice(actual_len, size=64) + value = np.random.rand(64) + naive[index] = value + tree[index] = value + assert np.allclose(realop(naive), tree.reduce()) + for i in range(10): + left = np.random.randint(actual_len) + right = np.random.randint(left + 1, actual_len + 1) + assert np.allclose(realop(naive[left:right]), + tree.reduce(left, right)) + + # test prefix-sum-idx + actual_len = 8 + tree = SegmentTree(actual_len) + naive = np.random.rand(actual_len) + tree[np.arange(actual_len)] = naive + for _ in range(1000): + scalar = np.random.rand() * naive.sum() + index = tree.get_prefix_sum_idx(scalar) + assert naive[:index].sum() <= scalar <= naive[:index + 1].sum() + # corner case here + naive = np.ones(actual_len, np.int) + tree[np.arange(actual_len)] = naive + for scalar in range(actual_len): + index = tree.get_prefix_sum_idx(scalar * 1.) 
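+        # get_prefix_sum_idx returns the smallest i such that
+        # scalar <= sum(arr[:i + 1]); both sides of i must bracket scalar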
+ assert naive[:index].sum() <= scalar <= naive[:index + 1].sum() + tree = SegmentTree(10) + tree[np.arange(3)] = np.array([0.1, 0, 0.1]) + assert np.allclose(tree.get_prefix_sum_idx( + np.array([0, .1, .1 + 1e-6, .2 - 1e-6])), [0, 0, 2, 2]) + with pytest.raises(AssertionError): + tree.get_prefix_sum_idx(.2) + # test large prefix-sum-idx + actual_len = 16384 + tree = SegmentTree(actual_len) + naive = np.random.rand(actual_len) + tree[np.arange(actual_len)] = naive + for _ in range(1000): + scalar = np.random.rand() * naive.sum() + index = tree.get_prefix_sum_idx(scalar) + assert naive[:index].sum() <= scalar <= naive[:index + 1].sum() + + # profile + if __name__ == '__main__': + size = 100000 + bsz = 64 + naive = np.random.rand(size) + tree = SegmentTree(size) + tree[np.arange(size)] = naive + + def sample_npbuf(): + return np.random.choice(size, bsz, p=naive / naive.sum()) + + def sample_tree(): + scalar = np.random.rand(bsz) * tree.reduce() + return tree.get_prefix_sum_idx(scalar) + + print('npbuf', timeit(sample_npbuf, setup=sample_npbuf, number=1000)) + print('tree', timeit(sample_tree, setup=sample_tree, number=1000)) + + if __name__ == '__main__': test_replaybuffer() test_ignore_obs_next() test_stack() + test_segtree() + test_priortized_replaybuffer() test_priortized_replaybuffer(233333, 200000) test_update() diff --git a/test/discrete/test_dqn.py b/test/discrete/test_dqn.py index 0455f7059..ae4c4ce0c 100644 --- a/test/discrete/test_dqn.py +++ b/test/discrete/test_dqn.py @@ -8,9 +8,9 @@ from tianshou.env import VectorEnv from tianshou.policy import DQNPolicy -from tianshou.trainer import offpolicy_trainer -from tianshou.data import Collector, ReplayBuffer from tianshou.utils.net.common import Net +from tianshou.trainer import offpolicy_trainer +from tianshou.data import Collector, ReplayBuffer, PrioritizedReplayBuffer def get_args(): @@ -33,6 +33,9 @@ def get_args(): parser.add_argument('--test-num', type=int, default=100) parser.add_argument('--logdir', type=str, default='log') parser.add_argument('--render', type=float, default=0.) 
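+    # prioritized replay options: pass --prioritized-replay 1 to enable;
+    # alpha is the priority exponent and beta the importance-sampling
+    # exponent of PrioritizedReplayBuffer (arXiv:1511.05952)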
+ parser.add_argument('--prioritized-replay', type=int, default=0) + parser.add_argument('--alpha', type=float, default=0.6) + parser.add_argument('--beta', type=float, default=0.4) parser.add_argument( '--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu') @@ -58,15 +61,20 @@ def test_dqn(args=get_args()): test_envs.seed(args.seed) # model net = Net(args.layer_num, args.state_shape, - args.action_shape, args.device, - dueling=(2, 2)).to(args.device) + args.action_shape, args.device, # dueling=(1, 1) + ).to(args.device) optim = torch.optim.Adam(net.parameters(), lr=args.lr) policy = DQNPolicy( net, optim, args.gamma, args.n_step, target_update_freq=args.target_update_freq) + # buffer + if args.prioritized_replay > 0: + buf = PrioritizedReplayBuffer( + args.buffer_size, alpha=args.alpha, beta=args.beta) + else: + buf = ReplayBuffer(args.buffer_size) # collector - train_collector = Collector( - policy, train_envs, ReplayBuffer(args.buffer_size)) + train_collector = Collector(policy, train_envs, buf) test_collector = Collector(policy, test_envs) # policy.set_eps(1) train_collector.collect(n_step=args.batch_size) @@ -114,5 +122,11 @@ def test_fn(x): collector.close() +def test_pdqn(args=get_args()): + args.prioritized_replay = 1 + args.gamma = .95 + test_dqn(args) + + if __name__ == '__main__': test_dqn(get_args()) diff --git a/test/discrete/test_pdqn.py b/test/discrete/test_pdqn.py deleted file mode 100644 index b614f248a..000000000 --- a/test/discrete/test_pdqn.py +++ /dev/null @@ -1,118 +0,0 @@ -import os -import gym -import torch -import pprint -import argparse -import numpy as np -from torch.utils.tensorboard import SummaryWriter - -from tianshou.utils.net.common import Net -from tianshou.env import VectorEnv -from tianshou.policy import DQNPolicy -from tianshou.trainer import offpolicy_trainer -from tianshou.data import Collector, ReplayBuffer, PrioritizedReplayBuffer - - -def get_args(): - parser = argparse.ArgumentParser() - parser.add_argument('--task', type=str, default='CartPole-v0') - parser.add_argument('--seed', type=int, default=1626) - parser.add_argument('--eps-test', type=float, default=0.05) - parser.add_argument('--eps-train', type=float, default=0.1) - parser.add_argument('--buffer-size', type=int, default=20000) - parser.add_argument('--lr', type=float, default=1e-3) - parser.add_argument('--gamma', type=float, default=0.9) - parser.add_argument('--n-step', type=int, default=3) - parser.add_argument('--target-update-freq', type=int, default=320) - parser.add_argument('--epoch', type=int, default=10) - parser.add_argument('--step-per-epoch', type=int, default=1000) - parser.add_argument('--collect-per-step', type=int, default=10) - parser.add_argument('--batch-size', type=int, default=64) - parser.add_argument('--layer-num', type=int, default=3) - parser.add_argument('--training-num', type=int, default=8) - parser.add_argument('--test-num', type=int, default=100) - parser.add_argument('--logdir', type=str, default='log') - parser.add_argument('--render', type=float, default=0.) 
- parser.add_argument('--prioritized-replay', type=int, default=1) - parser.add_argument('--alpha', type=float, default=0.5) - parser.add_argument('--beta', type=float, default=0.5) - parser.add_argument( - '--device', type=str, - default='cuda' if torch.cuda.is_available() else 'cpu') - args = parser.parse_known_args()[0] - return args - - -def test_pdqn(args=get_args()): - env = gym.make(args.task) - args.state_shape = env.observation_space.shape or env.observation_space.n - args.action_shape = env.action_space.shape or env.action_space.n - # train_envs = gym.make(args.task) - # you can also use tianshou.env.SubprocVectorEnv - train_envs = VectorEnv( - [lambda: gym.make(args.task) for _ in range(args.training_num)]) - # test_envs = gym.make(args.task) - test_envs = VectorEnv( - [lambda: gym.make(args.task) for _ in range(args.test_num)]) - # seed - np.random.seed(args.seed) - torch.manual_seed(args.seed) - train_envs.seed(args.seed) - test_envs.seed(args.seed) - # model - net = Net(args.layer_num, args.state_shape, - args.action_shape, args.device).to(args.device) - optim = torch.optim.Adam(net.parameters(), lr=args.lr) - policy = DQNPolicy( - net, optim, args.gamma, args.n_step, - target_update_freq=args.target_update_freq) - # collector - if args.prioritized_replay > 0: - buf = PrioritizedReplayBuffer( - args.buffer_size, alpha=args.alpha, - beta=args.alpha, repeat_sample=True) - else: - buf = ReplayBuffer(args.buffer_size) - train_collector = Collector( - policy, train_envs, buf) - test_collector = Collector(policy, test_envs) - # policy.set_eps(1) - train_collector.collect(n_step=args.batch_size) - # log - log_path = os.path.join(args.logdir, args.task, 'dqn') - writer = SummaryWriter(log_path) - - def save_fn(policy): - torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth')) - - def stop_fn(x): - return x >= env.spec.reward_threshold - - def train_fn(x): - policy.set_eps(args.eps_train) - - def test_fn(x): - policy.set_eps(args.eps_test) - - # trainer - result = offpolicy_trainer( - policy, train_collector, test_collector, args.epoch, - args.step_per_epoch, args.collect_per_step, args.test_num, - args.batch_size, train_fn=train_fn, test_fn=test_fn, - stop_fn=stop_fn, save_fn=save_fn, writer=writer) - - assert stop_fn(result['best_reward']) - train_collector.close() - test_collector.close() - if __name__ == '__main__': - pprint.pprint(result) - # Let's watch its performance! 
- env = gym.make(args.task) - collector = Collector(policy, env) - result = collector.collect(n_episode=1, render=args.render) - print(f'Final reward: {result["rew"]}, length: {result["len"]}') - collector.close() - - -if __name__ == '__main__': - test_pdqn(get_args()) diff --git a/test/throughput/test_buffer_profile.py b/test/throughput/test_buffer_profile.py index aec32682a..88abdcb64 100644 --- a/test/throughput/test_buffer_profile.py +++ b/test/throughput/test_buffer_profile.py @@ -1,8 +1,8 @@ -import numpy as np import pytest +import numpy as np from tianshou.data import (ListReplayBuffer, PrioritizedReplayBuffer, - ReplayBuffer) + ReplayBuffer, SegmentTree) @pytest.fixture(scope="module") @@ -21,7 +21,7 @@ def data(): 'buffer': buffer, 'buffer2': buffer2, 'slice': slice(-3000, -1000, 2), - 'indexes': indexes + 'indexes': indexes, } @@ -77,5 +77,15 @@ def test_sample(data): buffer.sample(int(1e2)) +def test_segtree(data): + size = 100000 + tree = SegmentTree(size) + tree[np.arange(size)] = np.random.rand(size) + + for i in np.arange(1e5): + scalar = np.random.rand(64) * tree.reduce() + tree.get_prefix_sum_idx(scalar) + + if __name__ == '__main__': pytest.main(["-s", "-k buffer_profile", "--durations=0", "-v"]) diff --git a/tianshou/data/__init__.py b/tianshou/data/__init__.py index 5d097a03b..f5f68e9e0 100644 --- a/tianshou/data/__init__.py +++ b/tianshou/data/__init__.py @@ -1,6 +1,7 @@ from tianshou.data.batch import Batch -from tianshou.data.utils import to_numpy, to_torch, \ +from tianshou.data.utils.converter import to_numpy, to_torch, \ to_torch_as +from tianshou.data.utils.segtree import SegmentTree from tianshou.data.buffer import ReplayBuffer, \ ListReplayBuffer, PrioritizedReplayBuffer from tianshou.data.collector import Collector @@ -10,8 +11,9 @@ 'to_numpy', 'to_torch', 'to_torch_as', + 'SegmentTree', 'ReplayBuffer', 'ListReplayBuffer', 'PrioritizedReplayBuffer', - 'Collector' + 'Collector', ] diff --git a/tianshou/data/buffer.py b/tianshou/data/buffer.py index 4491bee01..1d7a80f3c 100644 --- a/tianshou/data/buffer.py +++ b/tianshou/data/buffer.py @@ -1,7 +1,9 @@ +import torch import numpy as np from typing import Any, Tuple, Union, Optional -from tianshou.data.batch import Batch, _create_value +from tianshou.data import Batch, SegmentTree, to_numpy +from tianshou.data.batch import _create_value class ReplayBuffer: @@ -313,7 +315,7 @@ def __getitem__(self, index: Union[ done=self.done[index], obs_next=self.get(index, 'obs_next'), info=self.get(index, 'info'), - policy=self.get(index, 'policy') + policy=self.get(index, 'policy'), ) @@ -326,8 +328,8 @@ class ListReplayBuffer(ReplayBuffer): .. seealso:: - Please refer to :class:`~tianshou.data.ReplayBuffer` for more - detailed explanation. + Please refer to :class:`~tianshou.data.ReplayBuffer` for more detailed + explanation. """ def __init__(self, **kwargs) -> None: @@ -353,31 +355,32 @@ def reset(self) -> None: class PrioritizedReplayBuffer(ReplayBuffer): - """Prioritized replay buffer implementation. + """Implementation of Prioritized Experience Replay. arXiv:1511.05952 :param float alpha: the prioritization exponent. :param float beta: the importance sample soft coefficient. - :param str mode: defaults to ``weight``. - :param bool replace: whether to sample with replacement .. seealso:: - Please refer to :class:`~tianshou.data.ReplayBuffer` for more - detailed explanation. + Please refer to :class:`~tianshou.data.ReplayBuffer` for more detailed + explanation. 
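+
+    A minimal usage sketch (variable names are illustrative; the values
+    mirror the defaults in ``test/discrete/test_dqn.py``)::
+
+        buf = PrioritizedReplayBuffer(size=20000, alpha=0.6, beta=0.4)
+        # new transitions are stored with the current maximum priority
+        buf.add(obs, act, rew, done, obs_next, info)
+        # sampling is proportional to priority; batch.weight holds the
+        # importance-sampling weights for the loss
+        batch, indice = buf.sample(batch_size=64)
+        # feed the TD errors back as new priorities (absolute value is
+        # taken internally)
+        buf.update_weight(indice, td_error)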
""" - def __init__(self, size: int, alpha: float, beta: float, - mode: str = 'weight', - replace: bool = False, **kwargs) -> None: - if mode != 'weight': - raise NotImplementedError + def __init__(self, size: int, alpha: float, beta: float, **kwargs) -> None: super().__init__(size, **kwargs) - self._alpha = alpha - self._beta = beta - self._weight_sum = 0.0 - self._amortization_freq = 50 - self._replace = replace - self._meta.weight = np.zeros(size, dtype=np.float64) + assert alpha > 0. and beta >= 0. + self._alpha, self._beta = alpha, beta + self._max_prio = 1. + self._min_prio = 1. + # bypass the check + self._weight = SegmentTree(size) + self.__eps = np.finfo(np.float32).eps.item() + + def __getattr__(self, key: str) -> Union['Batch', Any]: + """Return self.key""" + if key == 'weight': + return self._weight + return self._meta.__dict__[key] def add(self, obs: Union[dict, np.ndarray], @@ -387,68 +390,55 @@ def add(self, obs_next: Optional[Union[dict, np.ndarray]] = None, info: dict = {}, policy: Optional[Union[dict, Batch]] = {}, - weight: float = 1.0, + weight: float = None, **kwargs) -> None: """Add a batch of data into replay buffer.""" - # we have to sacrifice some convenience for speed - self._weight_sum += np.abs(weight) ** self._alpha - \ - self._meta.weight[self._index] - self._add_to_buffer('weight', np.abs(weight) ** self._alpha) + if weight is None: + weight = self._max_prio + else: + weight = np.abs(weight) + self._max_prio = max(self._max_prio, weight) + self._min_prio = min(self._min_prio, weight) + self.weight[self._index] = weight ** self._alpha super().add(obs, act, rew, done, obs_next, info, policy) - @property - def replace(self): - return self._replace - - @replace.setter - def replace(self, v: bool): - self._replace = v - def sample(self, batch_size: int) -> Tuple[Batch, np.ndarray]: - """Get a random sample from buffer with priority probability. \ - Return all the data in the buffer if batch_size is ``0``. + """Get a random sample from buffer with priority probability. Return + all the data in the buffer if batch_size is ``0``. :return: Sample data and its corresponding index inside the buffer. + + The ``weight`` in the returned Batch is the weight on loss function + to de-bias the sampling process (some transition tuples are sampled + more often so their losses are weighted less). """ - assert self._size > 0, 'cannot sample a buffer with size == 0 !' - p = None - if batch_size > 0 and (self._replace or batch_size <= self._size): - # sampling weight - p = (self.weight / self.weight.sum())[:self._size] - indice = np.random.choice( - self._size, batch_size, p=p, - replace=self._replace) - p = p[indice] # weight of each sample - elif batch_size == 0: - p = np.full(shape=self._size, fill_value=1.0 / self._size) + assert self._size > 0, 'Cannot sample a buffer with 0 size!' 
+ if batch_size == 0: indice = np.concatenate([ np.arange(self._index, self._size), np.arange(0, self._index), ]) else: - raise ValueError( - f"batch_size should be less than {len(self)}, \ - or set replace=True") + scalar = np.random.rand(batch_size) * self.weight.reduce() + indice = self.weight.get_prefix_sum_idx(scalar) batch = self[indice] - batch["impt_weight"] = (self._size * p) ** (-self._beta) + # impt_weight + # original formula: ((p_j/p_sum*N)**(-beta))/((p_min/p_sum*N)**(-beta)) + # simplified formula: (p_j/p_min)**(-beta) + batch.weight = (batch.weight / self._min_prio) ** (-self._beta) return batch, indice - def update_weight(self, indice: Union[slice, np.ndarray], - new_weight: np.ndarray) -> None: + def update_weight(self, indice: Union[np.ndarray], + new_weight: Union[np.ndarray, torch.Tensor]) -> None: """Update priority weight by indice in this buffer. - :param np.ndarray indice: indice you want to update weight - :param np.ndarray new_weight: new priority weight you want to update + :param np.ndarray indice: indice you want to update weight. + :param np.ndarray new_weight: new priority weight you want to update. """ - if self._replace: - if isinstance(indice, slice): - # convert slice to ndarray - indice = np.arange(indice.stop)[indice] - # remove the same values in indice - indice, unique_indice = np.unique( - indice, return_index=True) - new_weight = new_weight[unique_indice] - self.weight[indice] = np.power(np.abs(new_weight), self._alpha) + weight = np.abs(to_numpy(new_weight)) + self.__eps + self.weight[indice] = weight ** self._alpha + self._max_prio = max(self._max_prio, weight.max()) + self._min_prio = min(self._min_prio, weight.min()) def __getitem__(self, index: Union[ slice, int, np.integer, np.ndarray]) -> Batch: @@ -459,6 +449,6 @@ def __getitem__(self, index: Union[ done=self.done[index], obs_next=self.get(index, 'obs_next'), info=self.get(index, 'info'), - weight=self.weight[index], policy=self.get(index, 'policy'), + weight=self.weight[index], ) diff --git a/tianshou/data/utils/__init__.py b/tianshou/data/utils/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tianshou/data/utils.py b/tianshou/data/utils/converter.py similarity index 100% rename from tianshou/data/utils.py rename to tianshou/data/utils/converter.py diff --git a/tianshou/data/utils/segtree.py b/tianshou/data/utils/segtree.py new file mode 100644 index 000000000..60a60dd50 --- /dev/null +++ b/tianshou/data/utils/segtree.py @@ -0,0 +1,133 @@ +import numpy as np +from typing import Union, Optional +# from numba import njit + + +# numba version, 5x speed up +# with size=100000 and bsz=64 +# first block (vectorized np): 0.0923 (now) -> 0.0251 +# second block (for-loop): 0.2914 -> 0.0192 (future) +# @njit +def _get_prefix_sum_idx(value, bound, sums): + index = np.ones(value.shape, dtype=np.int64) + while index[0] < bound: + index *= 2 + direct = sums[index] < value + value -= sums[index] * direct + index += direct + # for _, s in enumerate(value): + # i = 1 + # while i < bound: + # l = i * 2 + # if sums[l] >= s: + # i = l + # else: + # s = s - sums[l] + # i = l + 1 + # index[_] = i + index -= bound + return index + + +class SegmentTree: + """Implementation of Segment Tree: store an array ``arr`` with size ``n`` + in a segment tree, support value update and fast query of ``min/max/sum`` + for the interval ``[left, right)`` in O(log n) time. + + The detailed procedure is as follows: + + 1. 
Pad the array to have length of power of 2, so that leaf nodes in the\ + segment tree have the same depth. + 2. Store the segment tree in a binary heap. + + :param int size: the size of segment tree. + :param str operation: the operation of segment tree. Choices are "sum", + "min" and "max". Default: "sum". + """ + + def __init__(self, size: int, + operation: str = 'sum') -> None: + bound = 1 + while bound < size: + bound *= 2 + self._size = size + self._bound = bound + assert operation in ['sum', 'min', 'max'], \ + f'Unknown operation {operation}.' + if operation == 'sum': + self._op, self._init_value = np.add, 0. + elif operation == 'min': + self._op, self._init_value = np.minimum, np.inf + else: + self._op, self._init_value = np.maximum, -np.inf + # assert isinstance(self._op, np.ufunc) + self._value = np.full([bound * 2], self._init_value) + + def __len__(self): + return self._size + + def __getitem__(self, index: Union[int, np.ndarray] + ) -> Union[float, np.ndarray]: + """Return self[index]""" + return self._value[index + self._bound] + + def __setitem__(self, index: Union[int, np.ndarray], + value: Union[float, np.ndarray]) -> None: + """Duplicate values in ``index`` are handled by numpy: later index + overwrites previous ones. + + :: + + >>> a = np.array([1, 2, 3, 4]) + >>> a[[0, 1, 0, 1]] = [4, 5, 6, 7] + >>> print(a) + [6 7 3 4] + + """ + # TODO numba njit version + if isinstance(index, int): + index = np.array([index]) + assert np.all(0 <= index) and np.all(index < self._size) + if self._op is np.add: + assert np.all(0. <= value) + index = index + self._bound + self._value[index] = value + while index[0] > 1: + index //= 2 + self._value[index] = self._op( + self._value[index * 2], self._value[index * 2 + 1]) + + def reduce(self, start: Optional[int] = 0, + end: Optional[int] = None) -> float: + """Return operation(value[start:end]).""" + # TODO numba njit version + if start == 0 and end is None: + return self._value[1] + if end is None: + end = self._size + if end < 0: + end += self._size + # nodes in (start, end) should be aggregated + start, end = start + self._bound - 1, end + self._bound + result = self._init_value + while end - start > 1: # (start, end) interval is not empty + if start % 2 == 0: + result = self._op(result, self._value[start + 1]) + if end % 2 == 1: + result = self._op(result, self._value[end - 1]) + start, end = start // 2, end // 2 + return result + + def get_prefix_sum_idx( + self, value: Union[float, np.ndarray]) -> Union[int, np.ndarray]: + """Return the minimum index for each ``v`` in ``value`` so that + ``v <= sums[i]``, where sums[i] = \\sum_{j=0}^{i} arr[j]. + """ + assert self._op is np.add + assert np.all(value >= 0.) 
and np.all(value < self._value[1]) + single = False + if not isinstance(value, np.ndarray): + value = np.array([value]) + single = True + index = _get_prefix_sum_idx(value, self._bound, self._value) + return index.item() if single else index diff --git a/tianshou/policy/base.py b/tianshou/policy/base.py index 9d7711c20..cc1a59326 100644 --- a/tianshou/policy/base.py +++ b/tianshou/policy/base.py @@ -4,7 +4,8 @@ from abc import ABC, abstractmethod from typing import Dict, List, Union, Optional, Callable -from tianshou.data import Batch, ReplayBuffer, to_torch_as, to_numpy +from tianshou.data import Batch, ReplayBuffer, PrioritizedReplayBuffer, \ + to_torch_as, to_numpy class BasePolicy(ABC, nn.Module): @@ -213,4 +214,11 @@ def compute_nstep_return( returns = to_torch_as(returns, target_q) gammas = to_torch_as(gamma ** gammas, target_q) batch.returns = target_q * gammas + returns + # prio buffer update + if isinstance(buffer, PrioritizedReplayBuffer): + batch.update_weight = buffer.update_weight + batch.indice = indice + batch.weight = to_torch_as(batch.weight, target_q) + else: + batch.weight = torch.ones_like(target_q) return batch diff --git a/tianshou/policy/modelfree/ddpg.py b/tianshou/policy/modelfree/ddpg.py index da4833a69..2205102f2 100644 --- a/tianshou/policy/modelfree/ddpg.py +++ b/tianshou/policy/modelfree/ddpg.py @@ -1,7 +1,6 @@ import torch import numpy as np from copy import deepcopy -import torch.nn.functional as F from typing import Dict, Tuple, Union, Optional from tianshou.policy import BasePolicy @@ -144,7 +143,11 @@ def forward(self, batch: Batch, def learn(self, batch: Batch, **kwargs) -> Dict[str, float]: current_q = self.critic(batch.obs, batch.act).flatten() target_q = batch.returns.flatten() - critic_loss = F.mse_loss(current_q, target_q) + td = current_q - target_q + if hasattr(batch, 'update_weight'): # prio-buffer + batch.update_weight(batch.indice, td) + critic_loss = (td.pow(2) * batch.weight).mean() + # critic_loss = F.mse_loss(current_q, target_q) self.critic_optim.zero_grad() critic_loss.backward() self.critic_optim.step() diff --git a/tianshou/policy/modelfree/dqn.py b/tianshou/policy/modelfree/dqn.py index 9bf60a6c9..c37dac515 100644 --- a/tianshou/policy/modelfree/dqn.py +++ b/tianshou/policy/modelfree/dqn.py @@ -1,12 +1,10 @@ import torch import numpy as np from copy import deepcopy -import torch.nn.functional as F from typing import Dict, Union, Optional from tianshou.policy import BasePolicy -from tianshou.data import Batch, ReplayBuffer, PrioritizedReplayBuffer, \ - to_torch_as, to_numpy +from tianshou.data import Batch, ReplayBuffer, to_torch_as, to_numpy class DQNPolicy(BasePolicy): @@ -95,9 +93,6 @@ def process_fn(self, batch: Batch, buffer: ReplayBuffer, batch = self.compute_nstep_return( batch, buffer, indice, self._target_q, self._gamma, self._n_step, self._rew_norm) - if isinstance(buffer, PrioritizedReplayBuffer): - batch.update_weight = buffer.update_weight - batch.indice = indice return batch def forward(self, batch: Batch, @@ -164,13 +159,11 @@ def learn(self, batch: Batch, **kwargs) -> Dict[str, float]: q = self(batch, eps=0.).logits q = q[np.arange(len(q)), batch.act] r = to_torch_as(batch.returns, q).flatten() - if hasattr(batch, 'update_weight'): - td = r - q - batch.update_weight(batch.indice, to_numpy(td)) - impt_weight = to_torch_as(batch.impt_weight, q) - loss = (td.pow(2) * impt_weight).mean() - else: - loss = F.mse_loss(q, r) + td = r - q + if hasattr(batch, 'update_weight'): # prio-buffer + batch.update_weight(batch.indice, td) 
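+        # weighted MSE: batch.weight carries the importance-sampling weights
+        # from the prioritized buffer (all ones with a plain ReplayBuffer)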
+ loss = (td.pow(2) * batch.weight).mean() + # loss = F.mse_loss(q, r) loss.backward() self.optim.step() self._cnt += 1 diff --git a/tianshou/policy/modelfree/sac.py b/tianshou/policy/modelfree/sac.py index b67a95b90..ce4a5baf0 100644 --- a/tianshou/policy/modelfree/sac.py +++ b/tianshou/policy/modelfree/sac.py @@ -1,7 +1,6 @@ import torch import numpy as np from copy import deepcopy -import torch.nn.functional as F from typing import Dict, Tuple, Union, Optional from tianshou.policy import DDPGPolicy @@ -141,16 +140,23 @@ def learn(self, batch: Batch, **kwargs) -> Dict[str, float]: # critic 1 current_q1 = self.critic1(batch.obs, batch.act).flatten() target_q = batch.returns.flatten() - critic1_loss = F.mse_loss(current_q1, target_q) + td1 = current_q1 - target_q + critic1_loss = (td1.pow(2) * batch.weight).mean() + # critic1_loss = F.mse_loss(current_q1, target_q) self.critic1_optim.zero_grad() critic1_loss.backward() self.critic1_optim.step() # critic 2 current_q2 = self.critic2(batch.obs, batch.act).flatten() - critic2_loss = F.mse_loss(current_q2, target_q) + td2 = current_q2 - target_q + critic2_loss = (td2.pow(2) * batch.weight).mean() + # critic2_loss = F.mse_loss(current_q2, target_q) self.critic2_optim.zero_grad() critic2_loss.backward() self.critic2_optim.step() + # prio-buffer + if hasattr(batch, 'update_weight'): + batch.update_weight(batch.indice, (td1 + td2) / 2.) # actor obs_result = self(batch, explorating=False) a = obs_result.act diff --git a/tianshou/policy/modelfree/td3.py b/tianshou/policy/modelfree/td3.py index d90f51087..698145f1d 100644 --- a/tianshou/policy/modelfree/td3.py +++ b/tianshou/policy/modelfree/td3.py @@ -1,7 +1,6 @@ import torch import numpy as np from copy import deepcopy -import torch.nn.functional as F from typing import Dict, Tuple, Optional from tianshou.policy import DDPGPolicy @@ -119,16 +118,22 @@ def learn(self, batch: Batch, **kwargs) -> Dict[str, float]: # critic 1 current_q1 = self.critic1(batch.obs, batch.act).flatten() target_q = batch.returns.flatten() - critic1_loss = F.mse_loss(current_q1, target_q) + td1 = current_q1 - target_q + critic1_loss = (td1.pow(2) * batch.weight).mean() + # critic1_loss = F.mse_loss(current_q1, target_q) self.critic1_optim.zero_grad() critic1_loss.backward() self.critic1_optim.step() # critic 2 current_q2 = self.critic2(batch.obs, batch.act).flatten() - critic2_loss = F.mse_loss(current_q2, target_q) + td2 = current_q2 - target_q + critic2_loss = (td2.pow(2) * batch.weight).mean() + # critic2_loss = F.mse_loss(current_q2, target_q) self.critic2_optim.zero_grad() critic2_loss.backward() self.critic2_optim.step() + if hasattr(batch, 'update_weight'): # prio-buffer + batch.update_weight(batch.indice, (td1 + td2) / 2.) if self._cnt % self._freq == 0: actor_loss = -self.critic1( batch.obs, self(batch, eps=0).act).mean()