diff --git a/README.md b/README.md index 1764051cd..c86c40529 100644 --- a/README.md +++ b/README.md @@ -247,6 +247,8 @@ policy.load_state_dict(torch.load('dqn.pth')) Watch the performance with 35 FPS: ```python +policy.eval() +policy.set_eps(eps_test) collector = ts.data.Collector(policy, env) collector.collect(n_episode=1, render=1 / 35) ``` diff --git a/docs/tutorials/cheatsheet.rst b/docs/tutorials/cheatsheet.rst index 0f42f8198..32358439b 100644 --- a/docs/tutorials/cheatsheet.rst +++ b/docs/tutorials/cheatsheet.rst @@ -96,7 +96,7 @@ This is related to `Issue 42 `_. If you want to get log stat from data stream / pre-process batch-image / modify the reward with given env info, use ``preproces_fn`` in :class:`~tianshou.data.Collector`. This is a hook which will be called before the data adding into the buffer. -This function receives typically 7 keys, as listed in :class:`~tianshou.data.Batch`, and returns the modified part within a dict or a Batch. For example, you can write your hook as: +This function receives up to 7 keys ``obs``, ``act``, ``rew``, ``done``, ``obs_next``, ``info``, and ``policy``, as listed in :class:`~tianshou.data.Batch`, and returns the modified part within a :class:`~tianshou.data.Batch`. Only ``obs`` is defined at env reset, while every key is specified for normal steps. For example, you can write your hook as: :: import numpy as np @@ -109,9 +109,11 @@ This function receives typically 7 keys, as listed in :class:`~tianshou.data.Bat self.baseline = 0 def preprocess_fn(**kwargs): """change reward to zero mean""" + # if only obs exist -> reset + # if obs/act/rew/done/... exist -> normal step if 'rew' not in kwargs: # means that it is called after env.reset(), it can only process the obs - return {} # none of the variables are needed to be updated + return Batch() # none of the variables are needed to be updated else: n = len(kwargs['rew']) # the number of envs in collector if self.episode_log is None: @@ -125,7 +127,6 @@ This function receives typically 7 keys, as listed in :class:`~tianshou.data.Bat self.episode_log[i] = [] self.baseline = np.mean(self.main_log) return Batch(rew=kwargs['rew']) - # you can also return with {'rew': kwargs['rew']} And finally, :: diff --git a/docs/tutorials/dqn.rst b/docs/tutorials/dqn.rst index e01760058..9655ee8e1 100644 --- a/docs/tutorials/dqn.rst +++ b/docs/tutorials/dqn.rst @@ -176,6 +176,8 @@ Watch the Agent's Performance :class:`~tianshou.data.Collector` supports rendering. Here is the example of watching the agent's performance in 35 FPS: :: + policy.eval() + policy.set_eps(0.05) collector = ts.data.Collector(policy, env) collector.collect(n_episode=1, render=1 / 35) diff --git a/docs/tutorials/tictactoe.rst b/docs/tutorials/tictactoe.rst index 6ab79d800..0a20bf969 100644 --- a/docs/tutorials/tictactoe.rst +++ b/docs/tutorials/tictactoe.rst @@ -285,6 +285,8 @@ With the above preparation, we are close to the first learned agent. 
The followi env = TicTacToeEnv(args.board_size, args.win_size) policy, optim = get_agents( args, agent_learn=agent_learn, agent_opponent=agent_opponent) + policy.eval() + policy.set_eps(args.eps_test) collector = Collector(policy, env) result = collector.collect(n_episode=1, render=args.render) print(f'Final reward: {result["rew"]}, length: {result["len"]}') diff --git a/examples/atari/pong_dqn.py b/examples/atari/pong_dqn.py deleted file mode 100644 index 6dda89400..000000000 --- a/examples/atari/pong_dqn.py +++ /dev/null @@ -1,108 +0,0 @@ -import torch -import pprint -import argparse -import numpy as np -from torch.utils.tensorboard import SummaryWriter - -from tianshou.policy import DQNPolicy -from tianshou.env import SubprocVectorEnv -from tianshou.utils.net.discrete import DQN -from tianshou.trainer import offpolicy_trainer -from tianshou.data import Collector, ReplayBuffer - -from atari import create_atari_environment, preprocess_fn - - -def get_args(): - parser = argparse.ArgumentParser() - parser.add_argument('--task', type=str, default='Pong') - parser.add_argument('--seed', type=int, default=1626) - parser.add_argument('--eps-test', type=float, default=0.05) - parser.add_argument('--eps-train', type=float, default=0.1) - parser.add_argument('--buffer-size', type=int, default=20000) - parser.add_argument('--lr', type=float, default=1e-3) - parser.add_argument('--gamma', type=float, default=0.9) - parser.add_argument('--n-step', type=int, default=1) - parser.add_argument('--target-update-freq', type=int, default=320) - parser.add_argument('--epoch', type=int, default=100) - parser.add_argument('--step-per-epoch', type=int, default=1000) - parser.add_argument('--collect-per-step', type=int, default=10) - parser.add_argument('--batch-size', type=int, default=64) - parser.add_argument('--layer-num', type=int, default=3) - parser.add_argument('--training-num', type=int, default=8) - parser.add_argument('--test-num', type=int, default=8) - parser.add_argument('--logdir', type=str, default='log') - parser.add_argument('--render', type=float, default=0.) 
- parser.add_argument( - '--device', type=str, - default='cuda' if torch.cuda.is_available() else 'cpu') - return parser.parse_args() - - -def test_dqn(args=get_args()): - env = create_atari_environment(args.task) - args.state_shape = env.observation_space.shape or env.observation_space.n - args.action_shape = env.env.action_space.shape or env.env.action_space.n - # train_envs = gym.make(args.task) - train_envs = SubprocVectorEnv([ - lambda: create_atari_environment(args.task) - for _ in range(args.training_num)]) - # test_envs = gym.make(args.task) - test_envs = SubprocVectorEnv([ - lambda: create_atari_environment(args.task) - for _ in range(args.test_num)]) - # seed - np.random.seed(args.seed) - torch.manual_seed(args.seed) - train_envs.seed(args.seed) - test_envs.seed(args.seed) - # model - net = DQN( - args.state_shape[0], args.state_shape[1], - args.action_shape, args.device) - net = net.to(args.device) - optim = torch.optim.Adam(net.parameters(), lr=args.lr) - policy = DQNPolicy( - net, optim, args.gamma, args.n_step, - target_update_freq=args.target_update_freq) - # collector - train_collector = Collector( - policy, train_envs, ReplayBuffer(args.buffer_size), - preprocess_fn=preprocess_fn) - test_collector = Collector(policy, test_envs, preprocess_fn=preprocess_fn) - # policy.set_eps(1) - train_collector.collect(n_step=args.batch_size * 4) - print(len(train_collector.buffer)) - # log - writer = SummaryWriter(args.logdir + '/' + 'dqn') - - def stop_fn(x): - if env.env.spec.reward_threshold: - return x >= env.spec.reward_threshold - else: - return False - - def train_fn(x): - policy.set_eps(args.eps_train) - - def test_fn(x): - policy.set_eps(args.eps_test) - - # trainer - result = offpolicy_trainer( - policy, train_collector, test_collector, args.epoch, - args.step_per_epoch, args.collect_per_step, args.test_num, - args.batch_size, train_fn=train_fn, test_fn=test_fn, - stop_fn=stop_fn, writer=writer) - - if __name__ == '__main__': - pprint.pprint(result) - # Let's watch its performance! - env = create_atari_environment(args.task) - collector = Collector(policy, env, preprocess_fn=preprocess_fn) - result = collector.collect(n_episode=1, render=args.render) - print(f'Final reward: {result["rew"]}, length: {result["len"]}') - - -if __name__ == '__main__': - test_dqn(get_args()) diff --git a/examples/atari/atari.py b/examples/atari/runnable/atari.py similarity index 100% rename from examples/atari/atari.py rename to examples/atari/runnable/atari.py diff --git a/examples/atari/pong_a2c.py b/examples/atari/runnable/pong_a2c.py similarity index 100% rename from examples/atari/pong_a2c.py rename to examples/atari/runnable/pong_a2c.py diff --git a/examples/atari/pong_ppo.py b/examples/atari/runnable/pong_ppo.py similarity index 100% rename from examples/atari/pong_ppo.py rename to examples/atari/runnable/pong_ppo.py diff --git a/examples/box2d/acrobot_dualdqn.py b/examples/box2d/acrobot_dualdqn.py index e3de12de7..6345d62eb 100644 --- a/examples/box2d/acrobot_dualdqn.py +++ b/examples/box2d/acrobot_dualdqn.py @@ -102,9 +102,12 @@ def test_fn(x): if __name__ == '__main__': pprint.pprint(result) # Let's watch its performance! 
- env = gym.make(args.task) - collector = Collector(policy, env) - result = collector.collect(n_episode=1, render=args.render) + policy.eval() + policy.set_eps(args.eps_test) + test_envs.seed(args.seed) + test_collector.reset() + result = test_collector.collect(n_episode=[1] * args.test_num, + render=args.render) print(f'Final reward: {result["rew"]}, length: {result["len"]}') diff --git a/examples/box2d/bipedal_hardcore_sac.py b/examples/box2d/bipedal_hardcore_sac.py index 31b83f43b..a92963d83 100644 --- a/examples/box2d/bipedal_hardcore_sac.py +++ b/examples/box2d/bipedal_hardcore_sac.py @@ -44,6 +44,7 @@ def get_args(): class EnvWrapper(object): """Env wrapper for reward scale, action repeat and action noise""" + def __init__(self, task, action_repeat=3, reward_scale=5, act_noise=0.3): self._env = gym.make(task) @@ -71,10 +72,11 @@ def step(self, action): def test_sac_bipedal(args=get_args()): torch.set_num_threads(1) # we just need only one thread for NN + env = EnvWrapper(args.task) + def IsStop(reward): - return reward >= 300 * 5 + return reward >= env.spec.reward_threshold - env = EnvWrapper(args.task) args.state_shape = env.observation_space.shape or env.observation_space.n args.action_shape = env.action_space.shape or env.action_space.n args.max_action = env.action_space.high[0] @@ -82,8 +84,8 @@ def IsStop(reward): train_envs = SubprocVectorEnv( [lambda: EnvWrapper(args.task) for _ in range(args.training_num)]) # test_envs = gym.make(args.task) - test_envs = SubprocVectorEnv( - [lambda: EnvWrapper(args.task) for _ in range(args.test_num)]) + test_envs = SubprocVectorEnv([lambda: EnvWrapper(args.task, reward_scale=1) + for _ in range(args.test_num)]) # seed np.random.seed(args.seed) @@ -138,9 +140,11 @@ def save_fn(policy): if __name__ == '__main__': pprint.pprint(result) # Let's watch its performance! - env = EnvWrapper(args.task) - collector = Collector(policy, env) - result = collector.collect(n_episode=16, render=args.render) + policy.eval() + test_envs.seed(args.seed) + test_collector.reset() + result = test_collector.collect(n_episode=[1] * args.test_num, + render=args.render) print(f'Final reward: {result["rew"]}, length: {result["len"]}') diff --git a/examples/box2d/lunarlander_dqn.py b/examples/box2d/lunarlander_dqn.py index 0e66c65f7..aa0f5888c 100644 --- a/examples/box2d/lunarlander_dqn.py +++ b/examples/box2d/lunarlander_dqn.py @@ -99,9 +99,12 @@ def test_fn(x): if __name__ == '__main__': pprint.pprint(result) # Let's watch its performance! - env = gym.make(args.task) - collector = Collector(policy, env) - result = collector.collect(n_episode=1, render=args.render) + policy.eval() + policy.set_eps(args.eps_test) + test_envs.seed(args.seed) + test_collector.reset() + result = test_collector.collect(n_episode=[1] * args.test_num, + render=args.render) print(f'Final reward: {result["rew"]}, length: {result["len"]}') diff --git a/examples/box2d/sac_mcc.py b/examples/box2d/mcc_sac.py similarity index 95% rename from examples/box2d/sac_mcc.py rename to examples/box2d/mcc_sac.py index 845ffcd7b..6e09e6c1f 100644 --- a/examples/box2d/sac_mcc.py +++ b/examples/box2d/mcc_sac.py @@ -112,9 +112,11 @@ def stop_fn(x): if __name__ == '__main__': pprint.pprint(result) # Let's watch its performance! 
- env = gym.make(args.task) - collector = Collector(policy, env) - result = collector.collect(n_episode=1, render=args.render) + policy.eval() + test_envs.seed(args.seed) + test_collector.reset() + result = test_collector.collect(n_episode=[1] * args.test_num, + render=args.render) print(f'Final reward: {result["rew"]}, length: {result["len"]}') diff --git a/examples/mujoco/ant_v2_ddpg.py b/examples/mujoco/ant_v2_ddpg.py index ef7ea6c42..db65f582a 100644 --- a/examples/mujoco/ant_v2_ddpg.py +++ b/examples/mujoco/ant_v2_ddpg.py @@ -88,9 +88,11 @@ def stop_fn(x): if __name__ == '__main__': pprint.pprint(result) # Let's watch its performance! - env = gym.make(args.task) - collector = Collector(policy, env) - result = collector.collect(n_episode=1, render=args.render) + policy.eval() + test_envs.seed(args.seed) + test_collector.reset() + result = test_collector.collect(n_episode=[1] * args.test_num, + render=args.render) print(f'Final reward: {result["rew"]}, length: {result["len"]}') diff --git a/examples/mujoco/ant_v2_sac.py b/examples/mujoco/ant_v2_sac.py index 402784f28..108be79e3 100644 --- a/examples/mujoco/ant_v2_sac.py +++ b/examples/mujoco/ant_v2_sac.py @@ -98,9 +98,11 @@ def stop_fn(x): if __name__ == '__main__': pprint.pprint(result) # Let's watch its performance! - env = gym.make(args.task) - collector = Collector(policy, env) - result = collector.collect(n_episode=1, render=args.render) + policy.eval() + test_envs.seed(args.seed) + test_collector.reset() + result = test_collector.collect(n_episode=[1] * args.test_num, + render=args.render) print(f'Final reward: {result["rew"]}, length: {result["len"]}') diff --git a/examples/mujoco/ant_v2_td3.py b/examples/mujoco/ant_v2_td3.py index fad3f911c..db59e18d5 100644 --- a/examples/mujoco/ant_v2_td3.py +++ b/examples/mujoco/ant_v2_td3.py @@ -98,9 +98,11 @@ def stop_fn(x): if __name__ == '__main__': pprint.pprint(result) # Let's watch its performance! - env = gym.make(args.task) - collector = Collector(policy, env) - result = collector.collect(n_episode=1, render=args.render) + policy.eval() + test_envs.seed(args.seed) + test_collector.reset() + result = test_collector.collect(n_episode=[1] * args.test_num, + render=args.render) print(f'Final reward: {result["rew"]}, length: {result["len"]}') diff --git a/examples/mujoco/halfcheetahBullet_v0_sac.py b/examples/mujoco/halfcheetahBullet_v0_sac.py index 8f1a103e4..3aec4f85a 100644 --- a/examples/mujoco/halfcheetahBullet_v0_sac.py +++ b/examples/mujoco/halfcheetahBullet_v0_sac.py @@ -104,9 +104,11 @@ def stop_fn(x): if __name__ == '__main__': pprint.pprint(result) # Let's watch its performance! - env = gym.make(args.task) - collector = Collector(policy, env) - result = collector.collect(n_episode=1, render=args.render) + policy.eval() + test_envs.seed(args.seed) + test_collector.reset() + result = test_collector.collect(n_episode=[1] * args.test_num, + render=args.render) print(f'Final reward: {result["rew"]}, length: {result["len"]}') diff --git a/examples/mujoco/point_maze_td3.py b/examples/mujoco/point_maze_td3.py index 42e91146c..1f4a217ef 100644 --- a/examples/mujoco/point_maze_td3.py +++ b/examples/mujoco/point_maze_td3.py @@ -104,9 +104,11 @@ def stop_fn(x): if __name__ == '__main__': pprint.pprint(result) # Let's watch its performance! 
- env = gym.make(args.task) - collector = Collector(policy, env) - result = collector.collect(n_step=1000, render=args.render) + policy.eval() + test_envs.seed(args.seed) + test_collector.reset() + result = test_collector.collect(n_episode=[1] * args.test_num, + render=args.render) print(f'Final reward: {result["rew"]}, length: {result["len"]}') diff --git a/setup.py b/setup.py index 64aac40b2..d7487e269 100644 --- a/setup.py +++ b/setup.py @@ -3,18 +3,10 @@ from setuptools import setup, find_packages -import re -from os import path - -here = path.abspath(path.dirname(__file__)) - -# Get the version string -with open(path.join(here, 'tianshou', '__init__.py')) as f: - version = re.search(r'__version__ = \'(.*?)\'', f.read()).group(1) setup( name='tianshou', - version=version, + version='0.2.6', description='A Library for Deep Reinforcement Learning', long_description=open('README.md', encoding='utf8').read(), long_description_content_type='text/markdown', diff --git a/test/base/env.py b/test/base/env.py index cc0991072..f0907f8e1 100644 --- a/test/base/env.py +++ b/test/base/env.py @@ -21,6 +21,8 @@ def __init__(self, size, sleep=0, dict_state=False, recurse_state=False, self.recurse_state = recurse_state self.ma_rew = ma_rew self._md_action = multidiscrete_action + # how many steps this env has stepped + self.steps = 0 if dict_state: self.observation_space = Dict( {"index": Box(shape=(1, ), low=0, high=size - 1), @@ -74,6 +76,7 @@ def _get_state(self): return np.array([self.index], dtype=np.float32) def step(self, action): + self.steps += 1 if self._md_action: action = action[0] if self.done: diff --git a/test/base/test_batch.py b/test/base/test_batch.py index a823491b4..650d56080 100644 --- a/test/base/test_batch.py +++ b/test/base/test_batch.py @@ -26,7 +26,9 @@ def test_batch(): assert np.allclose(b.c, [3, 5]) # mimic the behavior of dict.update, where kwargs can overwrite keys b.update({'a': 2}, a=3) - assert b.a == 3 + assert 'a' in b and b.a == 3 + assert b.pop('a') == 3 + assert 'a' not in b with pytest.raises(AssertionError): Batch({1: 2}) with pytest.raises(TypeError): @@ -41,6 +43,8 @@ def test_batch(): Batch(a=[1, np.zeros((3, 3)), torch.zeros((3, 3))]) batch = Batch(a=[torch.ones(3), torch.ones(3)]) assert torch.allclose(batch.a, torch.ones(2, 3)) + batch.cat_(batch) + assert torch.allclose(batch.a, torch.ones(4, 3)) Batch(a=[]) batch = Batch(obs=[0], np=np.zeros([3, 4])) assert batch.obs == batch["obs"] @@ -60,6 +64,28 @@ def test_batch(): with pytest.raises(AttributeError): b.obs print(batch) + batch = Batch(a=np.arange(10)) + with pytest.raises(AssertionError): + list(batch.split(0)) + data = [ + (1, False, [[0], [1], [2], [3], [4], [5], [6], [7], [8], [9]]), + (1, True, [[0], [1], [2], [3], [4], [5], [6], [7], [8], [9]]), + (3, False, [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]), + (3, True, [[0, 1, 2], [3, 4, 5], [6, 7, 8, 9]]), + (5, False, [[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]]), + (5, True, [[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]]), + (7, False, [[0, 1, 2, 3, 4, 5, 6], [7, 8, 9]]), + (7, True, [[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]]), + (10, False, [[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]]), + (10, True, [[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]]), + (15, False, [[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]]), + (15, True, [[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]]), + (100, False, [[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]]), + (100, True, [[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]]), + ] + for size, merge_last, result in data: + bs = list(batch.split(size, shuffle=False, merge_last=merge_last)) + assert [bs[i].a.tolist() for i in range(len(bs))] == 
result batch_dict = {'b': np.array([1.0]), 'c': 2.0, 'd': torch.Tensor([3.0])} batch_item = Batch({'a': [batch_dict]})[0] assert isinstance(batch_item.a.b, np.ndarray) @@ -75,7 +101,9 @@ def test_batch(): assert len(batch2) == 1 assert Batch().shape == [] assert Batch(a=1).shape == [] + assert Batch(a=set((1, 2, 1))).shape == [] assert batch2.shape[0] == 1 + assert 'a' in batch2 and all([i in batch2.a for i in 'bcd']) with pytest.raises(IndexError): batch2[-2] with pytest.raises(IndexError): @@ -96,15 +124,18 @@ def test_batch(): assert batch2_from_comp.a.b == batch2.a.b assert batch2_from_comp.a.c == batch2.a.c assert batch2_from_comp.a.d.e == batch2.a.d.e - for batch_slice in [ - batch2[slice(0, 1)], batch2[:1], batch2[0:]]: + for batch_slice in [batch2[slice(0, 1)], batch2[:1], batch2[0:]]: assert batch_slice.a.b == batch2.a.b assert batch_slice.a.c == batch2.a.c assert batch_slice.a.d.e == batch2.a.d.e + batch2.a.d.f = {} batch2_sum = (batch2 + 1.0) * 2 assert batch2_sum.a.b == (batch2.a.b + 1.0) * 2 assert batch2_sum.a.c == (batch2.a.c + 1.0) * 2 assert batch2_sum.a.d.e == (batch2.a.d.e + 1.0) * 2 + assert batch2_sum.a.d.f.is_empty() + with pytest.raises(TypeError): + batch2 += [1] batch3 = Batch(a={ 'c': np.zeros(1), 'd': Batch(e=np.array([0.0]), f=np.array([3.0]))}) @@ -171,6 +202,11 @@ def test_batch_over_batch(): batch5[:, -1] += 1 assert np.allclose(batch5.a, [1, 3]) assert np.allclose(batch5.b.c.squeeze(), [[0, 1]] * 3) + with pytest.raises(ValueError): + batch5[:, -1] = 1 + batch5[:, 0] = {'a': -1} + assert np.allclose(batch5.a, [-1, 3]) + assert np.allclose(batch5.b.c.squeeze(), [[0, 1]] * 3) def test_batch_cat_and_stack(): @@ -199,9 +235,9 @@ def test_batch_cat_and_stack(): assert np.allclose(ans.b, np.concatenate([a.b, b.b, a.b])) assert ans.a.t.is_empty() - b12_stack = Batch.stack((b1, b2)) - assert isinstance(b12_stack.a.d.e, np.ndarray) - assert b12_stack.a.d.e.ndim == 2 + assert b1.stack_([b2]) is None + assert isinstance(b1.a.d.e, np.ndarray) + assert b1.a.d.e.ndim == 2 # test cat with incompatible keys b1 = Batch(a=np.random.rand(3, 4), common=Batch(c=np.random.rand(3, 5))) @@ -297,6 +333,16 @@ def test_batch_cat_and_stack(): assert torch.allclose(test.b, ans.b) assert np.allclose(test.common.c, ans.common.c) + # exceptions + assert Batch.cat([]).is_empty() + assert Batch.stack([]).is_empty() + b1 = Batch(e=[4, 5], d=6) + b2 = Batch(e=[4, 6]) + with pytest.raises(ValueError): + Batch.cat([b1, b2]) + with pytest.raises(ValueError): + Batch.stack([b1, b2], axis=1) + def test_batch_over_batch_to_torch(): batch = Batch( @@ -306,17 +352,21 @@ def test_batch_over_batch_to_torch(): d=torch.ones((1,), dtype=torch.float64) ) ) + batch.b.__dict__['e'] = 1 # bypass the check batch.to_torch() assert isinstance(batch.a, torch.Tensor) assert isinstance(batch.b.c, torch.Tensor) assert isinstance(batch.b.d, torch.Tensor) + assert isinstance(batch.b.e, torch.Tensor) assert batch.a.dtype == torch.float64 assert batch.b.c.dtype == torch.float32 assert batch.b.d.dtype == torch.float64 + assert batch.b.e.dtype == torch.int64 batch.to_torch(dtype=torch.float32) assert batch.a.dtype == torch.float32 assert batch.b.c.dtype == torch.float32 assert batch.b.d.dtype == torch.float32 + assert batch.b.e.dtype == torch.float32 def test_utils_to_torch_numpy(): @@ -347,13 +397,13 @@ def test_utils_to_torch_numpy(): assert isinstance(data_list_3_torch, list) assert all(isinstance(e, torch.Tensor) for e in data_list_3_torch) assert all(starmap(np.allclose, - zip(to_numpy(to_torch(data_list_3)), 
data_list_3))) + zip(to_numpy(to_torch(data_list_3)), data_list_3))) data_list_4 = [np.zeros((2, 3)), np.zeros((3, 3))] data_list_4_torch = to_torch(data_list_4) assert isinstance(data_list_4_torch, list) assert all(isinstance(e, torch.Tensor) for e in data_list_4_torch) assert all(starmap(np.allclose, - zip(to_numpy(to_torch(data_list_4)), data_list_4))) + zip(to_numpy(to_torch(data_list_4)), data_list_4))) data_list_5 = [np.zeros(2), np.zeros((3, 3))] data_list_5_torch = to_torch(data_list_5) assert isinstance(data_list_5_torch, list) @@ -366,6 +416,22 @@ def test_utils_to_torch_numpy(): assert isinstance(data_empty_array, np.ndarray) assert data_empty_array.shape == (0, 2, 2) assert np.allclose(to_numpy(to_torch(data_array)), data_array) + # additional test for to_numpy, for code-coverage + assert isinstance(to_numpy(1), np.ndarray) + assert isinstance(to_numpy(1.), np.ndarray) + assert isinstance(to_numpy({'a': torch.tensor(1)})['a'], np.ndarray) + assert isinstance(to_numpy(Batch(a=torch.tensor(1))).a, np.ndarray) + assert to_numpy(None).item() is None + assert to_numpy(to_numpy).item() == to_numpy + # additional test for to_torch, for code-coverage + assert isinstance(to_torch(1), torch.Tensor) + assert to_torch(1).dtype == torch.int64 + assert to_torch(1.).dtype == torch.float64 + assert isinstance(to_torch({'a': [1]})['a'], torch.Tensor) + with pytest.raises(TypeError): + to_torch(None) + with pytest.raises(TypeError): + to_torch(np.array([{}, '2'])) def test_batch_pickle(): diff --git a/test/base/test_buffer.py b/test/base/test_buffer.py index 393534c03..16bf5c34f 100644 --- a/test/base/test_buffer.py +++ b/test/base/test_buffer.py @@ -16,13 +16,19 @@ def test_replaybuffer(size=10, bufsize=20): env = MyTestEnv(size) buf = ReplayBuffer(bufsize) + buf.update(buf) + assert str(buf) == buf.__class__.__name__ + '()' obs = env.reset() action_list = [1] * 5 + [0] * 10 + [1] * 10 for i, a in enumerate(action_list): obs_next, rew, done, info = env.step(a) - buf.add(obs, a, rew, done, obs_next, info) + buf.add(obs, [a], rew, done, obs_next, info) obs = obs_next assert len(buf) == min(bufsize, i + 1) + with pytest.raises(ValueError): + buf._add_to_buffer('rew', np.array([1, 2, 3])) + assert buf.act.dtype == np.object + assert isinstance(buf.act[0], list) data, indice = buf.sample(bufsize * 2) assert (indice < len(buf)).all() assert (data.obs < size).all() @@ -37,6 +43,11 @@ def test_replaybuffer(size=10, bufsize=20): assert np.all(b.info.a[1:] == 0) assert b.info.b.c[0] == 5.0 and b.info.b.c.dtype == np.inexact assert np.all(b.info.b.c[1:] == 0.0) + with pytest.raises(IndexError): + b[22] + b = ListReplayBuffer() + with pytest.raises(NotImplementedError): + b.sample(0) def test_ignore_obs_next(size=10): @@ -89,14 +100,16 @@ def test_stack(size=5, bufsize=9, stack_num=4): if done: obs = env.reset(1) indice = np.arange(len(buf)) - assert np.allclose(buf.get(indice, 'obs'), np.expand_dims( - [[1, 1, 1, 2], [1, 1, 2, 3], [1, 2, 3, 4], - [1, 1, 1, 1], [1, 1, 1, 2], [1, 1, 2, 3], - [1, 2, 3, 4], [4, 4, 4, 4], [1, 1, 1, 1]], axis=-1)) + assert np.allclose(buf.get(indice, 'obs')[..., 0], [ + [1, 1, 1, 2], [1, 1, 2, 3], [1, 2, 3, 4], + [1, 1, 1, 1], [1, 1, 1, 2], [1, 1, 2, 3], + [1, 2, 3, 4], [4, 4, 4, 4], [1, 1, 1, 1]]) _, indice = buf2.sample(0) assert indice.tolist() == [2, 6] _, indice = buf2.sample(1) assert indice in [2, 6] + with pytest.raises(IndexError): + buf[bufsize * 2] def test_priortized_replaybuffer(size=32, bufsize=15): @@ -139,6 +152,7 @@ def test_segtree(): # small test actual_len = 
8 tree = SegmentTree(actual_len, op) # 1-15. 8-15 are leaf nodes + assert len(tree) == actual_len assert np.all([tree[i] == init for i in range(actual_len)]) with pytest.raises(IndexError): tree[actual_len] @@ -154,6 +168,8 @@ def test_segtree(): ref = realop(naive[i:j]) out = tree.reduce(i, j) assert np.allclose(ref, out) + assert np.allclose(tree.reduce(start=1), realop(naive[1:])) + assert np.allclose(tree.reduce(end=-1), realop(naive[:-1])) # batch setitem for _ in range(1000): index = np.random.choice(actual_len, size=4) diff --git a/test/base/test_collector.py b/test/base/test_collector.py index 1026c9407..217531611 100644 --- a/test/base/test_collector.py +++ b/test/base/test_collector.py @@ -42,30 +42,30 @@ def __init__(self, writer): def preprocess_fn(self, **kwargs): # modify info before adding into the buffer, and recorded into tfb - # if info is not provided from env, it will be a ``Batch()``. - if not kwargs.get('info', Batch()).is_empty(): + # if only obs exist -> reset + # if obs/act/rew/done/... exist -> normal step + if 'rew' in kwargs: n = len(kwargs['obs']) info = kwargs['info'] for i in range(n): info[i].update(rew=kwargs['rew'][i]) - self.writer.add_scalar('key', np.mean( - info['key']), global_step=self.cnt) + if 'key' in info.keys(): + self.writer.add_scalar('key', np.mean( + info['key']), global_step=self.cnt) self.cnt += 1 return Batch(info=info) - # or: return {'info': info} else: return Batch() @staticmethod def single_preprocess_fn(**kwargs): # same as above, without tfb - if not kwargs.get('info', Batch()).is_empty(): + if 'rew' in kwargs: n = len(kwargs['obs']) info = kwargs['info'] for i in range(n): info[i].update(rew=kwargs['rew'][i]) return Batch(info=info) - # or: return {'info': info} else: return Batch() @@ -82,43 +82,64 @@ def test_collector(): c0 = Collector(policy, env, ReplayBuffer(size=100, ignore_obs_next=False), logger.preprocess_fn) c0.collect(n_step=3) - assert np.allclose(c0.buffer.obs[:4], - np.expand_dims([0, 1, 0, 1], axis=-1)) - assert np.allclose(c0.buffer[:4].obs_next, - np.expand_dims([1, 2, 1, 2], axis=-1)) + assert np.allclose(c0.buffer.obs[:4, 0], [0, 1, 0, 1]) + assert np.allclose(c0.buffer[:4].obs_next[..., 0], [1, 2, 1, 2]) c0.collect(n_episode=3) - assert np.allclose(c0.buffer.obs[:10], - np.expand_dims([0, 1, 0, 1, 0, 1, 0, 1, 0, 1], axis=-1)) - assert np.allclose(c0.buffer[:10].obs_next, - np.expand_dims([1, 2, 1, 2, 1, 2, 1, 2, 1, 2], axis=-1)) + assert np.allclose(c0.buffer.obs[:10, 0], [0, 1, 0, 1, 0, 1, 0, 1, 0, 1]) + assert np.allclose(c0.buffer[:10].obs_next[..., 0], + [1, 2, 1, 2, 1, 2, 1, 2, 1, 2]) c0.collect(n_step=3, random=True) c1 = Collector(policy, venv, ReplayBuffer(size=100, ignore_obs_next=False), logger.preprocess_fn) c1.collect(n_step=6) - assert np.allclose(c1.buffer.obs[:11], np.expand_dims( - [0, 1, 0, 1, 2, 0, 1, 0, 1, 2, 3], axis=-1)) - assert np.allclose(c1.buffer[:11].obs_next, np.expand_dims( - [1, 2, 1, 2, 3, 1, 2, 1, 2, 3, 4], axis=-1)) + assert np.allclose(c1.buffer.obs[:11, 0], + [0, 1, 0, 1, 2, 0, 1, 0, 1, 2, 3]) + assert np.allclose(c1.buffer[:11].obs_next[..., 0], + [1, 2, 1, 2, 3, 1, 2, 1, 2, 3, 4]) c1.collect(n_episode=2) - assert np.allclose(c1.buffer.obs[11:21], - np.expand_dims([0, 1, 2, 3, 4, 0, 1, 0, 1, 2], axis=-1)) - assert np.allclose(c1.buffer[11:21].obs_next, - np.expand_dims([1, 2, 3, 4, 5, 1, 2, 1, 2, 3], axis=-1)) + assert np.allclose(c1.buffer.obs[11:21, 0], [0, 1, 2, 3, 4, 0, 1, 0, 1, 2]) + assert np.allclose(c1.buffer[11:21].obs_next[..., 0], + [1, 2, 3, 4, 5, 1, 2, 1, 2, 3]) 
c1.collect(n_episode=3, random=True) c2 = Collector(policy, dum, ReplayBuffer(size=100, ignore_obs_next=False), logger.preprocess_fn) c2.collect(n_episode=[1, 2, 2, 2]) - assert np.allclose(c2.buffer.obs_next[:26], np.expand_dims([ + assert np.allclose(c2.buffer.obs_next[:26, 0], [ 1, 2, 1, 2, 3, 1, 2, 3, 4, 1, 2, 3, 4, 5, - 1, 2, 3, 1, 2, 3, 4, 1, 2, 3, 4, 5], axis=-1)) + 1, 2, 3, 1, 2, 3, 4, 1, 2, 3, 4, 5]) c2.reset_env() c2.collect(n_episode=[2, 2, 2, 2]) - assert np.allclose(c2.buffer.obs_next[26:54], np.expand_dims([ + assert np.allclose(c2.buffer.obs_next[26:54, 0], [ 1, 2, 1, 2, 3, 1, 2, 1, 2, 3, 4, 1, 2, 3, 4, 5, - 1, 2, 3, 1, 2, 3, 4, 1, 2, 3, 4, 5], axis=-1)) + 1, 2, 3, 1, 2, 3, 4, 1, 2, 3, 4, 5]) c2.collect(n_episode=[1, 1, 1, 1], random=True) +def test_collector_with_exact_episodes(): + env_lens = [2, 6, 3, 10] + writer = SummaryWriter('log/exact_collector') + logger = Logger(writer) + env_fns = [lambda x=i: MyTestEnv(size=x, sleep=0.1, random_sleep=True) + for i in env_lens] + + venv = SubprocVectorEnv(env_fns, wait_num=len(env_fns) - 1) + policy = MyPolicy() + c1 = Collector(policy, venv, + ReplayBuffer(size=1000, ignore_obs_next=False), + logger.preprocess_fn) + n_episode1 = [2, 0, 5, 1] + n_episode2 = [1, 3, 2, 0] + c1.collect(n_episode=n_episode1) + expected_steps = sum([a * b for a, b in zip(env_lens, n_episode1)]) + actual_steps = sum(venv.steps) + assert expected_steps == actual_steps + c1.collect(n_episode=n_episode2) + expected_steps = sum( + [a * (b + c) for a, b, c in zip(env_lens, n_episode1, n_episode2)]) + actual_steps = sum(venv.steps) + assert expected_steps == actual_steps + + def test_collector_with_async(): env_lens = [2, 3, 4, 5] writer = SummaryWriter('log/async_collector') @@ -185,10 +206,10 @@ def test_collector_with_dict_state(): batch, _ = c1.buffer.sample(10) print(batch) c0.buffer.update(c1.buffer) - assert np.allclose(c0.buffer[:len(c0.buffer)].obs.index, np.expand_dims([ + assert np.allclose(c0.buffer[:len(c0.buffer)].obs.index[..., 0], [ 0., 1., 2., 3., 4., 0., 1., 2., 3., 4., 0., 1., 2., 3., 4., 0., 1., 0., 1., 2., 0., 1., 0., 1., 2., 3., 0., 1., 2., 3., 4., 0., 1., 0., - 1., 2., 0., 1., 0., 1., 2., 3., 0., 1., 2., 3., 4.], axis=-1)) + 1., 2., 0., 1., 0., 1., 2., 3., 0., 1., 2., 3., 4.]) c2 = Collector(policy, envs, ReplayBuffer(size=100, stack_num=4), Logger.single_preprocess_fn) c2.collect(n_episode=[0, 0, 0, 10]) @@ -219,11 +240,10 @@ def reward_metric(x): batch, _ = c1.buffer.sample(10) print(batch) c0.buffer.update(c1.buffer) - obs = np.array(np.expand_dims([ + assert np.allclose(c0.buffer[:len(c0.buffer)].obs[..., 0], [ 0., 1., 2., 3., 4., 0., 1., 2., 3., 4., 0., 1., 2., 3., 4., 0., 1., 0., 1., 2., 0., 1., 0., 1., 2., 3., 0., 1., 2., 3., 4., 0., 1., 0., - 1., 2., 0., 1., 0., 1., 2., 3., 0., 1., 2., 3., 4.], axis=-1)) - assert np.allclose(c0.buffer[:len(c0.buffer)].obs, obs) + 1., 2., 0., 1., 0., 1., 2., 3., 0., 1., 2., 3., 4.]) rew = [0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 1, 0, 0, 1, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 1, 0, 0, 1, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 1] @@ -241,3 +261,4 @@ def reward_metric(x): test_collector_with_dict_state() test_collector_with_ma() test_collector_with_async() + test_collector_with_exact_episodes() diff --git a/test/base/test_env.py b/test/base/test_env.py index 96de70236..6f67df4b2 100644 --- a/test/base/test_env.py +++ b/test/base/test_env.py @@ -48,6 +48,7 @@ def test_async_env(size=10000, num=8, sleep=0.1): test_cls += [RayVectorEnv] for cls in test_cls: v = cls(env_fns, wait_num=num // 2, 
timeout=1e-3) + v.seed(None) v.reset() # for a random variable u ~ U[0, 1], let v = max{u1, u2, ..., un} # P(v <= x) = x^n (0 <= x <= 1), pdf of v is nx^{n-1} diff --git a/test/base/test_returns.py b/test/base/test_returns.py new file mode 100644 index 000000000..664968541 --- /dev/null +++ b/test/base/test_returns.py @@ -0,0 +1,104 @@ +import time +import torch +import numpy as np + +from tianshou.policy import BasePolicy +from tianshou.data import Batch, ReplayBuffer + + +def compute_episodic_return_base(batch, gamma): + returns = np.zeros_like(batch.rew) + last = 0 + for i in reversed(range(len(batch.rew))): + returns[i] = batch.rew[i] + if not batch.done[i]: + returns[i] += last * gamma + last = returns[i] + batch.returns = returns + return batch + + +def test_episodic_returns(size=2560): + fn = BasePolicy.compute_episodic_return + batch = Batch( + done=np.array([1, 0, 0, 1, 0, 1, 0, 1.]), + rew=np.array([0, 1, 2, 3, 4, 5, 6, 7.]), + ) + batch = fn(batch, None, gamma=.1, gae_lambda=1) + ans = np.array([0, 1.23, 2.3, 3, 4.5, 5, 6.7, 7]) + assert np.allclose(batch.returns, ans) + batch = Batch( + done=np.array([0, 1, 0, 1, 0, 1, 0.]), + rew=np.array([7, 6, 1, 2, 3, 4, 5.]), + ) + batch = fn(batch, None, gamma=.1, gae_lambda=1) + ans = np.array([7.6, 6, 1.2, 2, 3.4, 4, 5]) + assert np.allclose(batch.returns, ans) + batch = Batch( + done=np.array([0, 1, 0, 1, 0, 0, 1.]), + rew=np.array([7, 6, 1, 2, 3, 4, 5.]), + ) + batch = fn(batch, None, gamma=.1, gae_lambda=1) + ans = np.array([7.6, 6, 1.2, 2, 3.45, 4.5, 5]) + assert np.allclose(batch.returns, ans) + batch = Batch( + done=np.array([0, 0, 0, 1., 0, 0, 0, 1, 0, 0, 0, 1]), + rew=np.array([ + 101, 102, 103., 200, 104, 105, 106, 201, 107, 108, 109, 202]) + ) + v = np.array([2., 3., 4, -1, 5., 6., 7, -2, 8., 9., 10, -3]) + ret = fn(batch, v, gamma=0.99, gae_lambda=0.95) + returns = np.array([ + 454.8344, 376.1143, 291.298, 200., + 464.5610, 383.1085, 295.387, 201., + 474.2876, 390.1027, 299.476, 202.]) + assert np.allclose(ret.returns, returns) + if __name__ == '__main__': + batch = Batch( + done=np.random.randint(100, size=size) == 0, + rew=np.random.random(size), + ) + cnt = 3000 + t = time.time() + for _ in range(cnt): + compute_episodic_return_base(batch, gamma=.1) + print(f'vanilla: {(time.time() - t) / cnt}') + t = time.time() + for _ in range(cnt): + fn(batch, None, gamma=.1, gae_lambda=1) + print(f'policy: {(time.time() - t) / cnt}') + + +def target_q_fn(buffer, indice): + # return the next reward + indice = (indice + 1 - buffer.done[indice]) % len(buffer) + return torch.tensor(-buffer.rew[indice], dtype=torch.float32) + + +def test_nstep_returns(): + buf = ReplayBuffer(10) + for i in range(12): + buf.add(obs=0, act=0, rew=i + 1, done=i % 4 == 3) + batch, indice = buf.sample(0) + assert np.allclose(indice, [2, 3, 4, 5, 6, 7, 8, 9, 0, 1]) + # rew: [10, 11, 2, 3, 4, 5, 6, 7, 8, 9] + # done: [ 0, 1, 0, 1, 0, 0, 0, 1, 0, 0] + # test nstep = 1 + returns = BasePolicy.compute_nstep_return( + batch, buf, indice, target_q_fn, gamma=.1, n_step=1).pop('returns') + assert np.allclose(returns, [2.6, 4, 4.4, 5.3, 6.2, 8, 8, 8.9, 9.8, 12]) + # test nstep = 2 + returns = BasePolicy.compute_nstep_return( + batch, buf, indice, target_q_fn, gamma=.1, n_step=2).pop('returns') + assert np.allclose(returns, [ + 3.4, 4, 5.53, 6.62, 7.8, 8, 9.89, 10.98, 12.2, 12]) + # test nstep = 10 + returns = BasePolicy.compute_nstep_return( + batch, buf, indice, target_q_fn, gamma=.1, n_step=10).pop('returns') + assert np.allclose(returns, [ + 3.4, 4, 5.678, 6.78, 7.8, 
8, 10.122, 11.22, 12.2, 12]) + + +if __name__ == '__main__': + test_nstep_returns() + test_episodic_returns() diff --git a/test/base/test_utils.py b/test/base/test_utils.py new file mode 100644 index 000000000..5944bfbe5 --- /dev/null +++ b/test/base/test_utils.py @@ -0,0 +1,68 @@ +import torch +import numpy as np + +from tianshou.utils import MovAvg +from tianshou.exploration import GaussianNoise, OUNoise +from tianshou.utils.net.common import Net +from tianshou.utils.net.discrete import DQN +from tianshou.utils.net.continuous import RecurrentActorProb, RecurrentCritic + + +def test_noise(): + noise = GaussianNoise() + size = (3, 4, 5) + assert np.allclose(noise(size).shape, size) + noise = OUNoise() + noise.reset() + assert np.allclose(noise(size).shape, size) + + +def test_moving_average(): + stat = MovAvg(10) + assert np.allclose(stat.get(), 0) + assert np.allclose(stat.mean(), 0) + assert np.allclose(stat.std() ** 2, 0) + stat.add(torch.tensor([1])) + stat.add(np.array([2])) + stat.add([3, 4]) + stat.add(5.) + assert np.allclose(stat.get(), 3) + assert np.allclose(stat.mean(), 3) + assert np.allclose(stat.std() ** 2, 2) + + +def test_net(): + # here test the networks that does not appear in the other script + bsz = 64 + # common net + state_shape = (10, 2) + action_shape = (5, ) + data = torch.rand([bsz, *state_shape]) + expect_output_shape = [bsz, *action_shape] + net = Net(3, state_shape, action_shape, norm_layer=torch.nn.LayerNorm) + assert list(net(data)[0].shape) == expect_output_shape + net = Net(3, state_shape, action_shape, dueling=(2, 2)) + assert list(net(data)[0].shape) == expect_output_shape + # recurrent actor/critic + data = data.flatten(1) + net = RecurrentActorProb(3, state_shape, action_shape) + mu, sigma = net(data)[0] + assert mu.shape == sigma.shape + assert list(mu.shape) == [bsz, 5] + net = RecurrentCritic(3, state_shape, action_shape) + data = torch.rand([bsz, 8, np.prod(state_shape)]) + act = torch.rand(expect_output_shape) + assert list(net(data, act).shape) == [bsz, 1] + # DQN + state_shape = (4, 84, 84) + action_shape = (6, ) + data = np.random.rand(bsz, *state_shape) + expect_output_shape = [bsz, *action_shape] + net = DQN(*state_shape, action_shape) + assert list(net(data)[0].shape) == expect_output_shape + + +if __name__ == '__main__': + test_noise() + test_moving_average() + test_net() diff --git a/test/continuous/test_ddpg.py b/test/continuous/test_ddpg.py index 457fcd592..5d8a7bc82 100644 --- a/test/continuous/test_ddpg.py +++ b/test/continuous/test_ddpg.py @@ -108,6 +108,7 @@ def stop_fn(x): pprint.pprint(result) # Let's watch its performance! env = gym.make(args.task) + policy.eval() collector = Collector(policy, env) result = collector.collect(n_episode=1, render=args.render) print(f'Final reward: {result["rew"]}, length: {result["len"]}') diff --git a/test/continuous/test_ppo.py b/test/continuous/test_ppo.py index ed42e7901..eba53789e 100644 --- a/test/continuous/test_ppo.py +++ b/test/continuous/test_ppo.py @@ -123,6 +123,7 @@ def stop_fn(x): pprint.pprint(result) # Let's watch its performance! 
env = gym.make(args.task) + policy.eval() collector = Collector(policy, env) result = collector.collect(n_episode=1, render=args.render) print(f'Final reward: {result["rew"]}, length: {result["len"]}') diff --git a/test/continuous/test_sac_with_il.py b/test/continuous/test_sac_with_il.py index dffebc70e..cab46a9c1 100644 --- a/test/continuous/test_sac_with_il.py +++ b/test/continuous/test_sac_with_il.py @@ -109,11 +109,13 @@ def stop_fn(x): pprint.pprint(result) # Let's watch its performance! env = gym.make(args.task) + policy.eval() collector = Collector(policy, env) result = collector.collect(n_episode=1, render=args.render) print(f'Final reward: {result["rew"]}, length: {result["len"]}') # here we define an imitation collector with a trivial policy + policy.eval() if args.task == 'Pendulum-v0': env.spec.reward_threshold = -300 # lower the goal net = Actor(Net(1, args.state_shape, device=args.device), @@ -136,6 +138,7 @@ def stop_fn(x): pprint.pprint(result) # Let's watch its performance! env = gym.make(args.task) + il_policy.eval() collector = Collector(il_policy, env) result = collector.collect(n_episode=1, render=args.render) print(f'Final reward: {result["rew"]}, length: {result["len"]}') diff --git a/test/continuous/test_td3.py b/test/continuous/test_td3.py index d2b95421e..e3a325f7e 100644 --- a/test/continuous/test_td3.py +++ b/test/continuous/test_td3.py @@ -113,6 +113,7 @@ def stop_fn(x): pprint.pprint(result) # Let's watch its performance! env = gym.make(args.task) + policy.eval() collector = Collector(policy, env) result = collector.collect(n_episode=1, render=args.render) print(f'Final reward: {result["rew"]}, length: {result["len"]}') diff --git a/test/discrete/test_a2c_with_il.py b/test/discrete/test_a2c_with_il.py index d99bc1448..3eafd0e42 100644 --- a/test/discrete/test_a2c_with_il.py +++ b/test/discrete/test_a2c_with_il.py @@ -98,10 +98,12 @@ def stop_fn(x): pprint.pprint(result) # Let's watch its performance! env = gym.make(args.task) + policy.eval() collector = Collector(policy, env) result = collector.collect(n_episode=1, render=args.render) print(f'Final reward: {result["rew"]}, length: {result["len"]}') + policy.eval() # here we define an imitation collector with a trivial policy if args.task == 'CartPole-v0': env.spec.reward_threshold = 190 # lower the goal @@ -124,6 +126,7 @@ def stop_fn(x): pprint.pprint(result) # Let's watch its performance! env = gym.make(args.task) + il_policy.eval() collector = Collector(il_policy, env) result = collector.collect(n_episode=1, render=args.render) print(f'Final reward: {result["rew"]}, length: {result["len"]}') diff --git a/test/discrete/test_dqn.py b/test/discrete/test_dqn.py index aeb849f41..bcf193ff9 100644 --- a/test/discrete/test_dqn.py +++ b/test/discrete/test_dqn.py @@ -114,6 +114,8 @@ def test_fn(x): pprint.pprint(result) # Let's watch its performance! 
env = gym.make(args.task) + policy.eval() + policy.set_eps(args.eps_test) collector = Collector(policy, env) result = collector.collect(n_episode=1, render=args.render) print(f'Final reward: {result["rew"]}, length: {result["len"]}') diff --git a/test/discrete/test_drqn.py b/test/discrete/test_drqn.py index c4d976715..e403c21a1 100644 --- a/test/discrete/test_drqn.py +++ b/test/discrete/test_drqn.py @@ -16,13 +16,13 @@ def get_args(): parser = argparse.ArgumentParser() parser.add_argument('--task', type=str, default='CartPole-v0') - parser.add_argument('--seed', type=int, default=0) + parser.add_argument('--seed', type=int, default=1) parser.add_argument('--eps-test', type=float, default=0.05) parser.add_argument('--eps-train', type=float, default=0.1) parser.add_argument('--buffer-size', type=int, default=20000) parser.add_argument('--stack-num', type=int, default=4) parser.add_argument('--lr', type=float, default=1e-3) - parser.add_argument('--gamma', type=float, default=0.9) + parser.add_argument('--gamma', type=float, default=0.95) parser.add_argument('--n-step', type=int, default=4) parser.add_argument('--target-update-freq', type=int, default=320) parser.add_argument('--epoch', type=int, default=10) @@ -100,6 +100,7 @@ def test_fn(x): pprint.pprint(result) # Let's watch its performance! env = gym.make(args.task) + policy.eval() collector = Collector(policy, env) result = collector.collect(n_episode=1, render=args.render) print(f'Final reward: {result["rew"]}, length: {result["len"]}') diff --git a/test/discrete/test_pg.py b/test/discrete/test_pg.py index ee934a340..3604adbc6 100644 --- a/test/discrete/test_pg.py +++ b/test/discrete/test_pg.py @@ -1,93 +1,25 @@ import os import gym -import time import torch import pprint import argparse import numpy as np from torch.utils.tensorboard import SummaryWriter -from tianshou.utils.net.common import Net -from tianshou.env import DummyVectorEnv from tianshou.policy import PGPolicy +from tianshou.env import DummyVectorEnv +from tianshou.utils.net.common import Net from tianshou.trainer import onpolicy_trainer -from tianshou.data import Batch, Collector, ReplayBuffer - - -def compute_return_base(batch, aa=None, bb=None, gamma=0.1): - returns = np.zeros_like(batch.rew) - last = 0 - for i in reversed(range(len(batch.rew))): - returns[i] = batch.rew[i] - if not batch.done[i]: - returns[i] += last * gamma - last = returns[i] - batch.returns = returns - return batch - - -def test_fn(size=2560): - policy = PGPolicy(None, None, None, discount_factor=0.1) - buf = ReplayBuffer(100) - buf.add(1, 1, 1, 1, 1) - fn = policy.process_fn - # fn = compute_return_base - batch = Batch( - done=np.array([1, 0, 0, 1, 0, 1, 0, 1.]), - rew=np.array([0, 1, 2, 3, 4, 5, 6, 7.]), - ) - batch = fn(batch, buf, 0) - ans = np.array([0, 1.23, 2.3, 3, 4.5, 5, 6.7, 7]) - assert np.allclose(batch.returns, ans) - batch = Batch( - done=np.array([0, 1, 0, 1, 0, 1, 0.]), - rew=np.array([7, 6, 1, 2, 3, 4, 5.]), - ) - batch = fn(batch, buf, 0) - ans = np.array([7.6, 6, 1.2, 2, 3.4, 4, 5]) - assert np.allclose(batch.returns, ans) - batch = Batch( - done=np.array([0, 1, 0, 1, 0, 0, 1.]), - rew=np.array([7, 6, 1, 2, 3, 4, 5.]), - ) - batch = fn(batch, buf, 0) - ans = np.array([7.6, 6, 1.2, 2, 3.45, 4.5, 5]) - assert np.allclose(batch.returns, ans) - batch = Batch( - done=np.array([0, 0, 0, 1., 0, 0, 0, 1, 0, 0, 0, 1]), - rew=np.array([ - 101, 102, 103., 200, 104, 105, 106, 201, 107, 108, 109, 202]) - ) - v = np.array([2., 3., 4, -1, 5., 6., 7, -2, 8., 9., 10, -3]) - ret = 
policy.compute_episodic_return(batch, v, gamma=0.99, gae_lambda=0.95) - returns = np.array([ - 454.8344, 376.1143, 291.298, 200., - 464.5610, 383.1085, 295.387, 201., - 474.2876, 390.1027, 299.476, 202.]) - assert np.allclose(ret.returns, returns) - if __name__ == '__main__': - batch = Batch( - done=np.random.randint(100, size=size) == 0, - rew=np.random.random(size), - ) - cnt = 3000 - t = time.time() - for _ in range(cnt): - compute_return_base(batch) - print(f'vanilla: {(time.time() - t) / cnt}') - t = time.time() - for _ in range(cnt): - policy.process_fn(batch, buf, 0) - print(f'policy: {(time.time() - t) / cnt}') +from tianshou.data import Collector, ReplayBuffer def get_args(): parser = argparse.ArgumentParser() parser.add_argument('--task', type=str, default='CartPole-v0') - parser.add_argument('--seed', type=int, default=1626) + parser.add_argument('--seed', type=int, default=0) parser.add_argument('--buffer-size', type=int, default=20000) parser.add_argument('--lr', type=float, default=1e-3) - parser.add_argument('--gamma', type=float, default=0.9) + parser.add_argument('--gamma', type=float, default=0.95) parser.add_argument('--epoch', type=int, default=10) parser.add_argument('--step-per-epoch', type=int, default=1000) parser.add_argument('--collect-per-step', type=int, default=10) @@ -155,11 +87,11 @@ def stop_fn(x): pprint.pprint(result) # Let's watch its performance! env = gym.make(args.task) + policy.eval() collector = Collector(policy, env) result = collector.collect(n_episode=1, render=args.render) print(f'Final reward: {result["rew"]}, length: {result["len"]}') if __name__ == '__main__': - # test_fn() test_pg() diff --git a/test/discrete/test_ppo.py b/test/discrete/test_ppo.py index 515e2f225..0c52c899a 100644 --- a/test/discrete/test_ppo.py +++ b/test/discrete/test_ppo.py @@ -112,6 +112,7 @@ def stop_fn(x): pprint.pprint(result) # Let's watch its performance! 
env = gym.make(args.task) + policy.eval() collector = Collector(policy, env) result = collector.collect(n_episode=1, render=args.render) print(f'Final reward: {result["rew"]}, length: {result["len"]}') diff --git a/test/multiagent/tic_tac_toe.py b/test/multiagent/tic_tac_toe.py index 5422c6e3b..96383ae3b 100644 --- a/test/multiagent/tic_tac_toe.py +++ b/test/multiagent/tic_tac_toe.py @@ -169,6 +169,8 @@ def watch(args: argparse.Namespace = get_args(), env = TicTacToeEnv(args.board_size, args.win_size) policy, optim = get_agents( args, agent_learn=agent_learn, agent_opponent=agent_opponent) + policy.eval() + policy.set_eps(args.eps_test) collector = Collector(policy, env) result = collector.collect(n_episode=1, render=args.render) print(f'Final reward: {result["rew"]}, length: {result["len"]}') diff --git a/test/throughput/test_buffer_profile.py b/test/throughput/test_buffer_profile.py index b37444681..3134004f1 100644 --- a/test/throughput/test_buffer_profile.py +++ b/test/throughput/test_buffer_profile.py @@ -28,9 +28,7 @@ def data(): def test_init(): for _ in np.arange(1e5): _ = ReplayBuffer(1e5) - _ = PrioritizedReplayBuffer( - size=int(1e5), alpha=0.5, - beta=0.5, repeat_sample=True) + _ = PrioritizedReplayBuffer(size=int(1e5), alpha=0.5, beta=0.5) _ = ListReplayBuffer() diff --git a/test/throughput/test_collector_profile.py b/test/throughput/test_collector_profile.py index f9d8a3e4e..21260f5ec 100644 --- a/test/throughput/test_collector_profile.py +++ b/test/throughput/test_collector_profile.py @@ -22,8 +22,7 @@ def __init__(self): def reset(self): self._index = 0 self.done = np.random.randint(3, high=200) - return {'observable': np.zeros((10, 10, 1)), - 'hidden': self._index} + return {'observable': np.zeros((10, 10, 1)), 'hidden': self._index} def step(self, action): if self._index == self.done: @@ -56,11 +55,9 @@ def data(): np.random.seed(0) env = SimpleEnv() env.seed(0) - env_vec = DummyVectorEnv( - [lambda: SimpleEnv() for _ in range(100)]) + env_vec = DummyVectorEnv([lambda: SimpleEnv() for _ in range(100)]) env_vec.seed(np.random.randint(1000, size=100).tolist()) - env_subproc = SubprocVectorEnv( - [lambda: SimpleEnv() for _ in range(8)]) + env_subproc = SubprocVectorEnv([lambda: SimpleEnv() for _ in range(8)]) env_subproc.seed(np.random.randint(1000, size=100).tolist()) env_subproc_init = SubprocVectorEnv( [lambda: SimpleEnv() for _ in range(8)]) diff --git a/tianshou/__init__.py b/tianshou/__init__.py index d73c93fc0..77016dc85 100644 --- a/tianshou/__init__.py +++ b/tianshou/__init__.py @@ -1,6 +1,7 @@ from tianshou import data, env, utils, policy, trainer, \ exploration + __version__ = '0.2.6' __all__ = [ 'env', diff --git a/tianshou/data/batch.py b/tianshou/data/batch.py index ae07023ea..fe70d1f7a 100644 --- a/tianshou/data/batch.py +++ b/tianshou/data/batch.py @@ -5,7 +5,8 @@ from copy import deepcopy from numbers import Number from collections.abc import Collection -from typing import Any, List, Tuple, Union, Iterator, Optional +from typing import Any, List, Tuple, Union, Iterator, KeysView, ValuesView, \ + ItemsView, Optional # Disable pickle warning related to torch, since it has been removed # on torch master branch. 
See Pull Request #39003 for details: @@ -18,14 +19,14 @@ def _is_batch_set(data: Any) -> bool: # Batch set is a list/tuple of dict/Batch objects, # or 1-D np.ndarray with np.object type, # where each element is a dict/Batch object - if isinstance(data, (list, tuple)): - if len(data) > 0 and all(isinstance(e, (dict, Batch)) for e in data): - return True - elif isinstance(data, np.ndarray) and data.dtype == np.object: + if isinstance(data, np.ndarray): # most often case # ``for e in data`` will just unpack the first dimension, # but data.tolist() will flatten ndarray of objects # so do not use data.tolist() - if all(isinstance(e, (dict, Batch)) for e in data): + return data.dtype == np.object and \ + all(isinstance(e, (dict, Batch)) for e in data) + elif isinstance(data, (list, tuple)): + if len(data) > 0 and all(isinstance(e, (dict, Batch)) for e in data): return True return False @@ -48,13 +49,14 @@ def _is_number(value: Any) -> bool: # isinstance(value, Number) checks 1, 1.0, np.int(1), np.float(1.0), etc. # isinstance(value, np.nummber) checks np.int32(1), np.float64(1.0), etc. # isinstance(value, np.bool_) checks np.bool_(True), etc. - is_number = isinstance(value, Number) - is_number = is_number or isinstance(value, np.number) - is_number = is_number or isinstance(value, np.bool_) - return is_number + # similar to np.isscalar but np.isscalar('st') returns True + return isinstance(value, (Number, np.number, np.bool_)) def _to_array_with_correct_type(v: Any) -> np.ndarray: + if isinstance(v, np.ndarray) and \ + issubclass(v.dtype.type, (np.bool_, np.number)): # most often case + return v # convert the value to np.ndarray # convert to np.object data type if neither bool nor number # raises an exception if array's elements are tensors themself @@ -85,6 +87,7 @@ def _create_value(inst: Any, size: int, stack=True) -> Union[ has_shape = isinstance(inst, (np.ndarray, torch.Tensor)) is_scalar = _is_scalar(inst) if not stack and is_scalar: + # should never hit since it has already checked in Batch.cat_ # here we do not consider scalar types, following the behavior of numpy # which does not support concatenation of zero-dimensional arrays # (scalars) @@ -101,9 +104,7 @@ def _create_value(inst: Any, size: int, stack=True) -> Union[ dtype=target_type) elif isinstance(inst, torch.Tensor): return torch.full(shape, - fill_value=0, - device=inst.device, - dtype=inst.dtype) + fill_value=0, device=inst.device, dtype=inst.dtype) elif isinstance(inst, (dict, Batch)): zero_batch = Batch() for key, val in inst.items(): @@ -115,17 +116,22 @@ def _create_value(inst: Any, size: int, stack=True) -> Union[ return np.array([None for _ in range(size)]) -def _assert_type_keys(keys): - keys = list(keys) +def _assert_type_keys(keys) -> None: assert all(isinstance(e, str) for e in keys), \ f"keys should all be string, but got {keys}" def _parse_value(v: Any): - if isinstance(v, dict): - v = Batch(v) - elif isinstance(v, (Batch, torch.Tensor)): - pass + if isinstance(v, Batch): # most often case + return v + elif (isinstance(v, np.ndarray) and + issubclass(v.dtype.type, (np.bool_, np.number))) or \ + isinstance(v, torch.Tensor) or v is None: # third often case + return v + elif _is_number(v): # second often case, but it is more time-consuming + return np.asanyarray(v) + elif isinstance(v, dict): + return Batch(v) else: if not isinstance(v, np.ndarray) and isinstance(v, Collection) and \ len(v) > 0 and all(isinstance(e, torch.Tensor) for e in v): @@ -134,18 +140,17 @@ def _parse_value(v: Any): except RuntimeError as e: 
raise TypeError("Batch does not support non-stackable iterable" " of torch.Tensor as unique value yet.") from e - try: - v_ = _to_array_with_correct_type(v) - except ValueError as e: - raise TypeError("Batch does not support heterogeneous list/tuple" - " of tensors as unique value yet.") from e if _is_batch_set(v): v = Batch(v) # list of dict / Batch else: # None, scalar, normal data list (main case) # or an actual list of objects - v = v_ - return v + try: + v = _to_array_with_correct_type(v) + except ValueError as e: + raise TypeError("Batch does not support heterogeneous list/" + "tuple of tensors as unique value yet.") from e + return v class Batch: @@ -155,6 +160,7 @@ class Batch: For a detailed description, please refer to :ref:`batch_concept`. """ + def __init__(self, batch_dict: Optional[Union[ dict, 'Batch', Tuple[Union[dict, 'Batch']], @@ -173,11 +179,11 @@ def __init__(self, if len(kwargs) > 0: self.__init__(kwargs, copy=copy) - def __setattr__(self, key: str, value: Any): + def __setattr__(self, key: str, value: Any) -> None: """self.key = value""" self.__dict__[key] = _parse_value(value) - def __getstate__(self): + def __getstate__(self) -> dict: """Pickling interface. Only the actual data are serialized for both efficiency and simplicity. """ @@ -188,7 +194,7 @@ def __getstate__(self): state[k] = v return state - def __setstate__(self, state): + def __setstate__(self, state) -> None: """Unpickling interface. At this point, self is an empty Batch instance that has not been initialized, so it can safely be initialized by the pickle state. @@ -216,13 +222,13 @@ def __setitem__(self, index: Union[ str, slice, int, np.integer, np.ndarray, List[int]], value: Any) -> None: """Assign value to self[index].""" + value = _parse_value(value) if isinstance(index, str): - self.__dict__[index] = _parse_value(value) + self.__dict__[index] = value return - value = _parse_value(value) if isinstance(value, (np.ndarray, torch.Tensor)): - raise ValueError("Batch does not supported tensor assignment." - " Use a compatible Batch or dict instead.") + raise ValueError("Batch does not supported tensor assignment. " + "Use a compatible Batch or dict instead.") if not set(value.keys()).issubset(self.__dict__.keys()): raise KeyError( "Creating keys is not supported by item assignment.") @@ -298,7 +304,7 @@ def __repr__(self) -> str: """Return str(self).""" s = self.__class__.__name__ + '(\n' flag = False - for k, v in self.items(): + for k, v in self.__dict__.items(): rpl = '\n' + ' ' * (6 + len(k)) obj = pprint.pformat(v).replace('\n', rpl) s += f' {k}: {obj},\n' @@ -309,22 +315,32 @@ def __repr__(self) -> str: s = self.__class__.__name__ + '()' return s - def keys(self) -> List[str]: + def __contains__(self, key: str) -> bool: + """Return key in self.""" + return key in self.__dict__ + + def keys(self) -> KeysView[str]: """Return self.keys().""" return self.__dict__.keys() - def values(self) -> List[Any]: + def values(self) -> ValuesView[Any]: """Return self.values().""" return self.__dict__.values() - def items(self) -> List[Tuple[str, Any]]: + def items(self) -> ItemsView[str, Any]: """Return self.items().""" return self.__dict__.items() - def get(self, k: str, d: Optional[Any] = None) -> Union['Batch', Any]: + def get(self, k: str, d: Optional[Any] = None) -> Any: """Return self[k] if k in self else d. d defaults to None.""" return self.__dict__.get(k, d) + def pop(self, k: str, d: Optional[Any] = None) -> Any: + """Return and remove self[k] if k in self else d. d defaults to + None. 
+ """ + return self.__dict__.pop(k, d) + def to_numpy(self) -> None: """Change all torch.Tensor to numpy.ndarray. This is an in-place operation. @@ -364,7 +380,7 @@ def to_torch(self, dtype: Optional[torch.dtype] = None, self.__dict__[k] = v def __cat(self, - batches: Union['Batch', List[Union[dict, 'Batch']]], + batches: List[Union[dict, 'Batch']], lens: List[int]) -> None: """:: @@ -395,7 +411,6 @@ def __cat(self, for batch in batches] keys_shared = set.intersection(*keys_map) values_shared = [[e[k] for e in batches] for k in keys_shared] - _assert_type_keys(keys_shared) for k, v in zip(keys_shared, values_shared): if all(isinstance(e, (dict, Batch)) for e in v): batch_holder = Batch() @@ -407,11 +422,9 @@ def __cat(self, # cat Batch(a=np.zeros((3, 4))) and Batch(a=Batch(b=Batch())) # will fail here v = np.concatenate(v) - v = _to_array_with_correct_type(v) - self.__dict__[k] = v + self.__dict__[k] = _to_array_with_correct_type(v) keys_total = set.union(*[set(b.keys()) for b in batches]) keys_reserve_or_partial = set.difference(keys_total, keys_shared) - _assert_type_keys(keys_reserve_or_partial) # keys that are reserved in all batches keys_reserve = set.difference(keys_total, set.union(*keys_map)) # keys that occur only in some batches, but not all @@ -429,8 +442,8 @@ def __cat(self, try: self.__dict__[k][sum_lens[i]:sum_lens[i + 1]] = val except KeyError: - self.__dict__[k] = \ - _create_value(val, sum_lens[-1], stack=False) + self.__dict__[k] = _create_value( + val, sum_lens[-1], stack=False) self.__dict__[k][sum_lens[i]:sum_lens[i + 1]] = val def cat_(self, @@ -453,11 +466,10 @@ def cat_(self, lens = [0 if x.is_empty(recurse=True) else len(x) for x in batches] except TypeError as e: - e2 = ValueError( - f'Batch.cat_ meets an exception. Maybe because there is ' - f'any scalar in {batches} but Batch.cat_ does not support' - f'the concatenation of scalar.') - raise Exception([e, e2]) + raise ValueError( + f'Batch.cat_ meets an exception. 
Maybe because there is any ' + f'scalar in {batches} but Batch.cat_ does not support the ' + f'concatenation of scalar.') from e if not self.is_empty(): batches = [self] + list(batches) lens = [0 if self.is_empty(recurse=True) else len(self)] + lens @@ -503,16 +515,14 @@ def stack_(self, for batch in batches] keys_shared = set.intersection(*keys_map) values_shared = [[e[k] for e in batches] for k in keys_shared] - _assert_type_keys(keys_shared) for k, v in zip(keys_shared, values_shared): - if all(isinstance(e, (dict, Batch)) for e in v): - self.__dict__[k] = Batch.stack(v, axis) - elif all(isinstance(e, torch.Tensor) for e in v): + if all(isinstance(e, torch.Tensor) for e in v): # second often self.__dict__[k] = torch.stack(v, axis) - else: + elif all(isinstance(e, (Batch, dict)) for e in v): # third often + self.__dict__[k] = Batch.stack(v, axis) + else: # most often case is np.ndarray v = np.stack(v, axis) - v = _to_array_with_correct_type(v) - self.__dict__[k] = v + self.__dict__[k] = _to_array_with_correct_type(v) # all the keys keys_total = set.union(*[set(b.keys()) for b in batches]) # keys that are reserved in all batches @@ -525,7 +535,6 @@ def stack_(self, raise ValueError( f"Stack of Batch with non-shared keys {keys_partial} " f"is only supported with axis=0, but got axis={axis}!") - _assert_type_keys(keys_reserve_or_partial) for k in keys_reserve: # reserved keys self.__dict__[k] = Batch() @@ -539,8 +548,7 @@ def stack_(self, try: self.__dict__[k][i] = val except KeyError: - self.__dict__[k] = \ - _create_value(val, len(batches)) + self.__dict__[k] = _create_value(val, len(batches)) self.__dict__[k][i] = val @staticmethod @@ -597,17 +605,17 @@ def empty_(self, index: Union[ ) """ for k, v in self.items(): - if v is None: - continue - if isinstance(v, Batch): - self.__dict__[k].empty_(index=index) - elif isinstance(v, torch.Tensor): + if isinstance(v, torch.Tensor): # most often case self.__dict__[k][index] = 0 + elif v is None: + continue elif isinstance(v, np.ndarray): if v.dtype == np.object: self.__dict__[k][index] = None else: self.__dict__[k][index] = 0 + elif isinstance(v, Batch): + self.__dict__[k].empty_(index=index) else: # scalar value warnings.warn('You are calling Batch.empty on a NumPy scalar, ' 'which may cause undefined behaviors.') @@ -633,10 +641,8 @@ def update(self, batch: Optional[Union[dict, 'Batch']] = None, if batch is None: self.update(kwargs) return - if isinstance(batch, dict): - batch = Batch(batch) for k, v in batch.items(): - self.__dict__[k] = v + self.__dict__[k] = _parse_value(v) if kwargs: self.update(kwargs) @@ -704,22 +710,26 @@ def shape(self) -> List[int]: return list(map(min, zip(*data_shape))) if len(data_shape) > 1 \ else data_shape[0] - def split(self, size: Optional[int] = None, - shuffle: bool = True) -> Iterator['Batch']: + def split(self, size: int, shuffle: bool = True, + merge_last: bool = False) -> Iterator['Batch']: """Split whole data into multiple small batches. - :param int size: if it is ``None``, it does not split the data batch; - otherwise it will divide the data batch with the given size. - Default to ``None``. + :param int size: divide the data batch with the given size, but one + batch if the length of the batch is smaller than ``size``. :param bool shuffle: randomly shuffle the entire data batch if it is ``True``, otherwise remain in the same. Default to ``True``. + :param bool merge_last: merge the last batch into the previous one. + Default to ``False``. 
""" length = len(self) - if size is None: - size = length + assert 1 <= size # size can be greater than length, return whole batch if shuffle: indices = np.random.permutation(length) else: indices = np.arange(length) - for idx in np.arange(0, length, size): - yield self[indices[idx:(idx + size)]] + merge_last = merge_last and length % size > 0 + for idx in range(0, length, size): + if merge_last and idx + size + size >= length: + yield self[indices[idx:]] + break + yield self[indices[idx:idx + size]] diff --git a/tianshou/data/buffer.py b/tianshou/data/buffer.py index 24aeb0bcc..7d77b3058 100644 --- a/tianshou/data/buffer.py +++ b/tianshou/data/buffer.py @@ -122,7 +122,7 @@ class ReplayBuffer: def __init__(self, size: int, stack_num: int = 1, ignore_obs_next: bool = False, - sample_avail: bool = False, **kwargs) -> None: + sample_avail: bool = False) -> None: super().__init__() self._maxsize = size self._indices = np.arange(size) @@ -163,7 +163,8 @@ def _add_to_buffer(self, name: str, inst: Any) -> None: except KeyError: self._meta.__dict__[name] = _create_value(inst, self._maxsize) value = self._meta.__dict__[name] - if isinstance(inst, np.ndarray) and value.shape[1:] != inst.shape: + if isinstance(inst, (np.ndarray, torch.Tensor)) \ + and value.shape[1:] != inst.shape: raise ValueError( "Cannot add data to a buffer with different shape, with key " f"{name}, expect {value.shape[1:]}, given {inst.shape}.") @@ -198,12 +199,12 @@ def update(self, buffer: 'ReplayBuffer') -> None: buffer.stack_num = stack_num_orig def add(self, - obs: Union[dict, Batch, np.ndarray], - act: Union[np.ndarray, float], + obs: Union[dict, Batch, np.ndarray, float], + act: Union[dict, Batch, np.ndarray, float], rew: Union[int, float], - done: bool, - obs_next: Optional[Union[dict, Batch, np.ndarray]] = None, - info: dict = {}, + done: Union[bool, int], + obs_next: Optional[Union[dict, Batch, np.ndarray, float]] = None, + info: Optional[Union[dict, Batch]] = {}, policy: Optional[Union[dict, Batch]] = {}, **kwargs) -> None: """Add a batch of data into replay buffer.""" @@ -351,9 +352,7 @@ def sample(self, batch_size: int) -> Tuple[Batch, np.ndarray]: def _add_to_buffer( self, name: str, inst: Union[dict, Batch, np.ndarray, float, int, bool]) -> None: - if inst is None: - return - if self._meta.__dict__.get(name, None) is None: + if self._meta.__dict__.get(name) is None: self._meta.__dict__[name] = [] self._meta.__dict__[name].append(inst) @@ -393,14 +392,14 @@ def __getattr__(self, key: str) -> Union['Batch', Any]: return super().__getattr__(key) def add(self, - obs: Union[dict, np.ndarray], - act: Union[np.ndarray, float], + obs: Union[dict, Batch, np.ndarray, float], + act: Union[dict, Batch, np.ndarray, float], rew: Union[int, float], - done: bool, - obs_next: Optional[Union[dict, np.ndarray]] = None, - info: dict = {}, + done: Union[bool, int], + obs_next: Optional[Union[dict, Batch, np.ndarray, float]] = None, + info: Optional[Union[dict, Batch]] = {}, policy: Optional[Union[dict, Batch]] = {}, - weight: float = None, + weight: Optional[float] = None, **kwargs) -> None: """Add a batch of data into replay buffer.""" if weight is None: diff --git a/tianshou/data/collector.py b/tianshou/data/collector.py index 268792a24..57f885d9c 100644 --- a/tianshou/data/collector.py +++ b/tianshou/data/collector.py @@ -3,6 +3,7 @@ import torch import warnings import numpy as np +from copy import deepcopy from typing import Any, Dict, List, Union, Optional, Callable from tianshou.env import BaseVectorEnv, DummyVectorEnv @@ -79,7 
+80,7 @@ def __init__(self, policy: BasePolicy, env: Union[gym.Env, BaseVectorEnv], buffer: Optional[ReplayBuffer] = None, - preprocess_fn: Callable[[Any], Union[dict, Batch]] = None, + preprocess_fn: Callable[[Any], Batch] = None, action_noise: Optional[BaseNoise] = None, reward_metric: Optional[Callable[[np.ndarray], float]] = None, ) -> None: @@ -97,7 +98,6 @@ def __init__(self, self.is_async = env.is_async # need cache buffers before storing in the main buffer self._cached_buf = [ListReplayBuffer() for _ in range(self.env_num)] - self.collect_time, self.collect_step, self.collect_episode = 0., 0, 0 self.buffer = buffer self.policy = policy self.preprocess_fn = preprocess_fn @@ -106,8 +106,6 @@ def __init__(self, self._action_noise = action_noise self._rew_metric = reward_metric or Collector._default_rew_metric # avoid creating attribute outside __init__ - self.data = Batch(state={}, obs={}, act={}, rew={}, done={}, info={}, - obs_next={}, policy={}) self.reset() @staticmethod @@ -202,14 +200,28 @@ def collect(self, * ``rew`` the mean reward over collected episodes. * ``len`` the mean length over collected episodes. """ - assert (n_step and not n_episode) or (not n_step and n_episode), \ - "One and only one collection number specification is permitted!" + assert (n_step is not None and n_episode is None and n_step > 0) or ( + n_step is None and n_episode is not None and np.sum(n_episode) > 0 + ), "Only one of n_step or n_episode is allowed in Collector.collect, " + f"got n_step = {n_step}, n_episode = {n_episode}." start_time = time.time() step_count = 0 # episode of each environment episode_count = np.zeros(self.env_num) + # If n_episode is a list, and some envs have collected the required + # number of episodes, these envs will be recorded in this list, and + # they will not be stepped. 
+ finished_env_ids = [] reward_total = 0.0 whole_data = Batch() + list_n_episode = False + if n_episode is not None and not np.isscalar(n_episode): + assert len(n_episode) == self.get_env_num() + list_n_episode = True + finished_env_ids = [ + i for i in self._ready_env_ids if n_episode[i] <= 0] + self._ready_env_ids = np.array( + [x for x in self._ready_env_ids if x not in finished_env_ids]) while True: if step_count >= 100000 and episode_count.sum() == 0: warnings.warn( @@ -217,12 +229,14 @@ def collect(self, 'You should add a time limitation to your environment!', Warning) - if self.is_async: - # self.data are the data for all environments - # in async simulation, only a subset of data are disposed + is_async = self.is_async or len(finished_env_ids) > 0 + if is_async: + # self.data are the data for all environments in async + # simulation or some envs have finished, + # **only a subset of data are disposed**, # so we store the whole data in ``whole_data``, let self.data - # to be all the data available in ready environments, and - # finally set these back into all the data + # to be the data available in ready environments, and finally + # set these back into all the data whole_data = self.data self.data = self.data[self._ready_env_ids] @@ -247,16 +261,15 @@ def collect(self, state = Batch() self.data.update(state=state, policy=result.get('policy', Batch())) # save hidden state to policy._state, in order to save into buffer - if not (isinstance(self.data.state, Batch) - and self.data.state.is_empty()): + if not (isinstance(state, Batch) and state.is_empty()): self.data.policy._state = self.data.state self.data.act = to_numpy(result.act) - if self._action_noise is not None: + if self._action_noise is not None: # noqa self.data.act += self._action_noise(self.data.act.shape) # step in env - if not self.is_async: + if not is_async: obs_next, rew, done, info = self.env.step(self.data.act) else: # store computed actions, states, etc @@ -264,7 +277,7 @@ def collect(self, self.data, self.env_num) # fetch finished data obs_next, rew, done, info = self.env.step( - action=self.data.act, id=self._ready_env_ids) + self.data.act, id=self._ready_env_ids) self._ready_env_ids = np.array([i['env_id'] for i in info]) # get the stepped data self.data = whole_data[self._ready_env_ids] @@ -279,23 +292,34 @@ def collect(self, if self.preprocess_fn: result = self.preprocess_fn(**self.data) self.data.update(result) + for j, i in enumerate(self._ready_env_ids): # j is the index in current ready_env_ids # i is the index in all environments - self._cached_buf[i].add(**self.data[j]) - if self.data.done[j]: - if n_step or np.isscalar(n_episode) or \ - episode_count[i] < n_episode[i]: + if self.buffer is None: + # users do not want to store data, so we store + # small fake data here to make the code clean + self._cached_buf[i].add(obs=0, act=0, rew=rew[j], done=0) + else: + self._cached_buf[i].add(**self.data[j]) + + if done[j]: + if not (list_n_episode and + episode_count[i] >= n_episode[i]): episode_count[i] += 1 reward_total += np.sum(self._cached_buf[i].rew, axis=0) step_count += len(self._cached_buf[i]) if self.buffer is not None: self.buffer.update(self._cached_buf[i]) + if list_n_episode and \ + episode_count[i] >= n_episode[i]: + # env i has collected enough data, it has finished + finished_env_ids.append(i) self._cached_buf[i].reset() self._reset_state(j) obs_next = self.data.obs_next - if sum(self.data.done): - env_ind_local = np.where(self.data.done)[0] + if sum(done): + env_ind_local = np.where(done)[0] 
env_ind_global = self._ready_env_ids[env_ind_local] obs_reset = self.env.reset(env_ind_global) if self.preprocess_fn: @@ -304,12 +328,15 @@ def collect(self, else: obs_next[env_ind_local] = obs_reset self.data.obs = obs_next - if self.is_async: + if is_async: # set data back + whole_data = deepcopy(whole_data) # avoid reference in ListBuf _batch_set_item(whole_data, self._ready_env_ids, self.data, self.env_num) # let self.data be the data in all environments again self.data = whole_data + self._ready_env_ids = np.array( + [x for x in self._ready_env_ids if x not in finished_env_ids]) if n_step: if step_count >= n_step: break @@ -321,6 +348,10 @@ def collect(self, (episode_count >= n_episode).all(): break + # finished envs are ready, and can be used for the next collection + self._ready_env_ids = np.array( + self._ready_env_ids.tolist() + finished_env_ids) + # generate the statistics episode_count = sum(episode_count) duration = max(time.time() - start_time, 1e-9) @@ -353,6 +384,7 @@ def sample(self, batch_size: int) -> Batch: 'Collector.sample is deprecated and will cause error if you use ' 'prioritized experience replay! Collector.sample will be removed ' 'upon version 0.3. Use policy.update instead!', Warning) + assert self.buffer is not None, "Cannot get sample from empty buffer!" batch_data, indice = self.buffer.sample(batch_size) batch_data = self.process_fn(batch_data, self.buffer, indice) return batch_data diff --git a/tianshou/data/utils/converter.py b/tianshou/data/utils/converter.py index 92a9db0f6..e97b05416 100644 --- a/tianshou/data/utils/converter.py +++ b/tianshou/data/utils/converter.py @@ -10,13 +10,19 @@ def to_numpy(x: Union[ Batch, dict, list, tuple, np.ndarray, torch.Tensor]) -> Union[ Batch, dict, list, tuple, np.ndarray, torch.Tensor]: """Return an object without torch.Tensor.""" - if isinstance(x, torch.Tensor): + if isinstance(x, torch.Tensor): # most often case x = x.detach().cpu().numpy() + elif isinstance(x, np.ndarray): # second often case + pass + elif isinstance(x, (np.number, np.bool_, Number)): + x = np.asanyarray(x) + elif x is None: + x = np.array(None, dtype=np.object) + elif isinstance(x, Batch): + x.to_numpy() elif isinstance(x, dict): for k, v in x.items(): x[k] = to_numpy(v) - elif isinstance(x, Batch): - x.to_numpy() elif isinstance(x, (list, tuple)): try: x = to_numpy(_parse_value(x)) @@ -32,36 +38,35 @@ def to_torch(x: Union[Batch, dict, list, tuple, np.ndarray, torch.Tensor], device: Union[str, int, torch.device] = 'cpu' ) -> Union[Batch, dict, list, tuple, np.ndarray, torch.Tensor]: """Return an object without np.ndarray.""" - if isinstance(x, torch.Tensor): + if isinstance(x, np.ndarray) and \ + issubclass(x.dtype.type, (np.bool_, np.number)): # most often case + x = torch.from_numpy(x).to(device) + if dtype is not None: + x = x.type(dtype) + elif isinstance(x, torch.Tensor): # second often case if dtype is not None: x = x.type(dtype) x = x.to(device) + elif isinstance(x, (np.number, np.bool_, Number)): + x = to_torch(np.asanyarray(x), dtype, device) elif isinstance(x, dict): for k, v in x.items(): x[k] = to_torch(v, dtype, device) elif isinstance(x, Batch): x.to_torch(dtype, device) - elif isinstance(x, (np.number, np.bool_, Number)): - x = to_torch(np.asanyarray(x), dtype, device) elif isinstance(x, (list, tuple)): try: x = to_torch(_parse_value(x), dtype, device) except TypeError: x = [to_torch(e, dtype, device) for e in x] else: # fallback - x = np.asanyarray(x) - if issubclass(x.dtype.type, (np.bool_, np.number)): - x = 
torch.from_numpy(x).to(device) - if dtype is not None: - x = x.type(dtype) - else: - raise TypeError(f"object {x} cannot be converted to torch.") + raise TypeError(f"object {x} cannot be converted to torch.") return x -def to_torch_as(x: Union[torch.Tensor, dict, Batch, np.ndarray], +def to_torch_as(x: Union[Batch, dict, list, tuple, np.ndarray, torch.Tensor], y: torch.Tensor - ) -> Union[dict, Batch, torch.Tensor]: + ) -> Union[Batch, dict, list, tuple, np.ndarray, torch.Tensor]: """Return an object without np.ndarray. Same as ``to_torch(x, dtype=y.dtype, device=y.device)``. """ diff --git a/tianshou/env/venvs.py b/tianshou/env/venvs.py index 504d3e196..04323498d 100644 --- a/tianshou/env/venvs.py +++ b/tianshou/env/venvs.py @@ -1,7 +1,7 @@ import gym import warnings import numpy as np -from typing import List, Tuple, Union, Optional, Callable, Any +from typing import List, Union, Optional, Callable, Any from tianshou.env.worker import EnvWorker, DummyEnvWorker, SubprocEnvWorker, \ RayEnvWorker @@ -116,23 +116,23 @@ def __getattr__(self, key: str) -> Any: """ return [getattr(worker, key) for worker in self.workers] - def _wrap_id( - self, id: Optional[Union[int, List[int]]] = None) -> List[int]: + def _wrap_id(self, id: Optional[Union[int, List[int], np.ndarray]] = None + ) -> List[int]: if id is None: id = list(range(self.env_num)) elif np.isscalar(id): id = [id] return id - def _assert_id( - self, id: Optional[Union[int, List[int]]] = None) -> List[int]: + def _assert_id(self, id: List[int]) -> None: for i in id: assert i not in self.waiting_id, \ f'Cannot interact with environment {i} which is stepping now.' assert i in self.ready_id, \ f'Can only interact with ready environments {self.ready_id}.' - def reset(self, id: Optional[Union[int, List[int]]] = None) -> np.ndarray: + def reset(self, id: Optional[Union[int, List[int], np.ndarray]] = None + ) -> np.ndarray: """Reset the state of all the environments and return initial observations if id is ``None``, otherwise reset the specific environments with the given id, either an int or a list. @@ -145,15 +145,16 @@ def reset(self, id: Optional[Union[int, List[int]]] = None) -> np.ndarray: return obs def step(self, - action: Optional[np.ndarray], - id: Optional[Union[int, List[int]]] = None - ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: + action: np.ndarray, + id: Optional[Union[int, List[int], np.ndarray]] = None + ) -> List[np.ndarray]: """Run one timestep of all the environments’ dynamics if id is "None", otherwise run one timestep for some environments with given id, either an int or a list. When the end of episode is reached, you are responsible for calling reset(id) to reset this environment’s state. - Accept a batch of action and return a tuple (obs, rew, done, info). + Accept a batch of action and return a tuple (batch_obs, batch_rew, + batch_done, batch_info) in numpy format. :param numpy.ndarray action: a batch of action provided by the agent. @@ -182,7 +183,11 @@ def step(self, assert len(action) == len(id) for i, j in enumerate(id): self.workers[j].send_action(action[i]) - result = [self.workers[j].get_result() for j in id] + result = [] + for j in id: + obs, rew, done, info = self.workers[j].get_result() + info["env_id"] = j + result.append((obs, rew, done, info)) else: if action is not None: self._assert_id(id) @@ -218,10 +223,10 @@ def seed(self, which a reproducer pass to "seed". 
""" self._assert_is_not_closed() - if np.isscalar(seed): - seed = [seed + _ for _ in range(self.env_num)] - elif seed is None: + if seed is None: seed = [seed] * self.env_num + elif np.isscalar(seed): + seed = [seed + i for i in range(self.env_num)] return [w.seed(s) for w, s in zip(self.workers, seed)] def render(self, **kwargs) -> List[Any]: diff --git a/tianshou/env/worker/base.py b/tianshou/env/worker/base.py index 87fb6c2e8..2b56dab9b 100644 --- a/tianshou/env/worker/base.py +++ b/tianshou/env/worker/base.py @@ -10,6 +10,7 @@ class EnvWorker(ABC): def __init__(self, env_fn: Callable[[], gym.Env]) -> None: self._env_fn = env_fn self.is_closed = False + self.result = (None, None, None, None) @abstractmethod def __getattr__(self, key: str) -> Any: diff --git a/tianshou/env/worker/dummy.py b/tianshou/env/worker/dummy.py index 97b7087b0..893500b28 100644 --- a/tianshou/env/worker/dummy.py +++ b/tianshou/env/worker/dummy.py @@ -12,10 +12,8 @@ def __init__(self, env_fn: Callable[[], gym.Env]) -> None: super().__init__(env_fn) self.env = env_fn() - def __getattr__(self, key: str): - if hasattr(self.env, key): - return getattr(self.env, key) - return None + def __getattr__(self, key: str) -> Any: + return getattr(self.env, key) def reset(self) -> Any: return self.env.reset() diff --git a/tianshou/env/worker/ray.py b/tianshou/env/worker/ray.py index f9f4fa9ff..3f71d828b 100644 --- a/tianshou/env/worker/ray.py +++ b/tianshou/env/worker/ray.py @@ -17,7 +17,7 @@ def __init__(self, env_fn: Callable[[], gym.Env]) -> None: super().__init__(env_fn) self.env = ray.remote(gym.Wrapper).options(num_cpus=0).remote(env_fn()) - def __getattr__(self, key: str): + def __getattr__(self, key: str) -> Any: return ray.get(self.env.__getattr__.remote(key)) def reset(self) -> Any: diff --git a/tianshou/env/worker/subproc.py b/tianshou/env/worker/subproc.py index 6ba108eba..3186b01db 100644 --- a/tianshou/env/worker/subproc.py +++ b/tianshou/env/worker/subproc.py @@ -127,7 +127,7 @@ def __init__(self, env_fn: Callable[[], gym.Env], self.process.start() self.child_remote.close() - def __getattr__(self, key: str): + def __getattr__(self, key: str) -> Any: self.parent_remote.send(['getattr', key]) return self.parent_remote.recv() @@ -165,11 +165,12 @@ def wait(workers: List['SubprocEnvWorker'], break else: remain_time = timeout - remain_conns = [conn for conn in remain_conns - if conn not in ready_conns] + # connection.wait hangs if the list is empty new_ready_conns = connection.wait( remain_conns, timeout=remain_time) ready_conns.extend(new_ready_conns) + remain_conns = [conn for conn in remain_conns + if conn not in ready_conns] return [workers[conns.index(con)] for con in ready_conns] def send_action(self, action: np.ndarray) -> None: diff --git a/tianshou/exploration/random.py b/tianshou/exploration/random.py index 34fd50399..19f4424cc 100644 --- a/tianshou/exploration/random.py +++ b/tianshou/exploration/random.py @@ -14,7 +14,7 @@ def __call__(self, **kwargs) -> np.ndarray: """Generate new noise.""" raise NotImplementedError - def reset(self, **kwargs) -> None: + def reset(self) -> None: """Reset to the initial state.""" pass diff --git a/tianshou/policy/base.py b/tianshou/policy/base.py index 01398ca3c..5a6c01dd7 100644 --- a/tianshou/policy/base.py +++ b/tianshou/policy/base.py @@ -213,20 +213,18 @@ def compute_nstep_return( returns[done[now] > 0] = 0 returns = (rew[now] - mean) / std + gamma * returns terminal = (indice + n_step - 1) % buf_len - target_q = target_q_fn(buffer, terminal).flatten() # shape: [bsz, 
] + target_q_torch = target_q_fn(buffer, terminal).flatten() # (bsz, ) + target_q = to_numpy(target_q_torch) target_q[gammas != n_step] = 0 - returns = to_torch_as(returns, target_q) - gammas = to_torch_as(gamma ** gammas, target_q) - batch.returns = target_q * gammas + returns + target_q = target_q * (gamma ** gammas) + returns + batch.returns = to_torch_as(target_q, target_q_torch) # prio buffer update if isinstance(buffer, PrioritizedReplayBuffer): - batch.weight = to_torch_as(batch.weight, target_q) - else: - batch.weight = torch.ones_like(target_q) + batch.weight = to_torch_as(batch.weight, target_q_torch) return batch def post_process_fn(self, batch: Batch, - buffer: ReplayBuffer, indice: np.ndarray): + buffer: ReplayBuffer, indice: np.ndarray) -> None: """Post-process the data from the provided replay buffer. Typical usage is to update the sampling weight in prioritized experience replay. Check out :ref:`policy_concept` for more information. @@ -235,7 +233,8 @@ def post_process_fn(self, batch: Batch, and hasattr(batch, 'weight'): buffer.update_weight(indice, batch.weight) - def update(self, batch_size: int, buffer: ReplayBuffer, *args, **kwargs): + def update(self, batch_size: int, buffer: Optional[ReplayBuffer], + *args, **kwargs) -> Dict[str, Union[float, List[float]]]: """Update the policy network and replay buffer (if needed). It includes three function steps: process_fn, learn, and post_process_fn. @@ -243,6 +242,8 @@ def update(self, batch_size: int, buffer: ReplayBuffer, *args, **kwargs): buffer, otherwise it will sample a batch with the given batch_size. :param ReplayBuffer buffer: the corresponding replay buffer. """ + if buffer is None: + return {} batch, indice = buffer.sample(batch_size) batch = self.process_fn(batch, buffer, indice) result = self.learn(batch, *args, **kwargs) diff --git a/tianshou/policy/modelfree/a2c.py b/tianshou/policy/modelfree/a2c.py index 52d8dd248..0f7cffd58 100644 --- a/tianshou/policy/modelfree/a2c.py +++ b/tianshou/policy/modelfree/a2c.py @@ -25,6 +25,12 @@ class A2CPolicy(PGPolicy): defaults to ``None``. :param float gae_lambda: in [0, 1], param for Generalized Advantage Estimation, defaults to 0.95. + :param bool reward_normalization: normalize the reward to Normal(0, 1), + defaults to ``False``. + :param int max_batchsize: the maximum size of the batch when computing GAE, + depends on the size of available memory and the memory cost of the + model; should be as large as possible within the memory constraint; + defaults to 256. .. 
seealso:: @@ -36,14 +42,14 @@ def __init__(self, actor: torch.nn.Module, critic: torch.nn.Module, optim: torch.optim.Optimizer, - dist_fn: torch.distributions.Distribution - = torch.distributions.Categorical, + dist_fn: torch.distributions.Distribution, discount_factor: float = 0.99, vf_coef: float = .5, ent_coef: float = .01, max_grad_norm: Optional[float] = None, gae_lambda: float = 0.95, reward_normalization: bool = False, + max_batchsize: int = 256, **kwargs) -> None: super().__init__(None, optim, dist_fn, discount_factor, **kwargs) self.actor = actor @@ -53,7 +59,7 @@ def __init__(self, self._w_vf = vf_coef self._w_ent = ent_coef self._grad_norm = max_grad_norm - self._batch = 64 + self._batch = max_batchsize self._rew_norm = reward_normalization def process_fn(self, batch: Batch, buffer: ReplayBuffer, @@ -63,7 +69,7 @@ def process_fn(self, batch: Batch, buffer: ReplayBuffer, batch, None, gamma=self._gamma, gae_lambda=self._lambda) v_ = [] with torch.no_grad(): - for b in batch.split(self._batch, shuffle=False): + for b in batch.split(self._batch, shuffle=False, merge_last=True): v_.append(to_numpy(self.critic(b.obs_next))) v_ = np.concatenate(v_, axis=0) return self.compute_episodic_return( @@ -97,10 +103,9 @@ def forward(self, batch: Batch, def learn(self, batch: Batch, batch_size: int, repeat: int, **kwargs) -> Dict[str, List[float]]: - self._batch = batch_size losses, actor_losses, vf_losses, ent_losses = [], [], [], [] for _ in range(repeat): - for b in batch.split(batch_size): + for b in batch.split(batch_size, merge_last=True): self.optim.zero_grad() dist = self(b).dist v = self.critic(b.obs).flatten() diff --git a/tianshou/policy/modelfree/ddpg.py b/tianshou/policy/modelfree/ddpg.py index 79a65d3bb..6c34e34ac 100644 --- a/tianshou/policy/modelfree/ddpg.py +++ b/tianshou/policy/modelfree/ddpg.py @@ -141,10 +141,11 @@ def forward(self, batch: Batch, return Batch(act=actions, state=h) def learn(self, batch: Batch, **kwargs) -> Dict[str, float]: + weight = batch.pop('weight', 1.) current_q = self.critic(batch.obs, batch.act).flatten() target_q = batch.returns.flatten() td = current_q - target_q - critic_loss = (td.pow(2) * batch.weight).mean() + critic_loss = (td.pow(2) * weight).mean() batch.weight = td # prio-buffer self.critic_optim.zero_grad() critic_loss.backward() diff --git a/tianshou/policy/modelfree/dqn.py b/tianshou/policy/modelfree/dqn.py index f1a01a6e7..5c5e45d5a 100644 --- a/tianshou/policy/modelfree/dqn.py +++ b/tianshou/policy/modelfree/dqn.py @@ -37,7 +37,7 @@ def __init__(self, optim: torch.optim.Optimizer, discount_factor: float = 0.99, estimation_step: int = 1, - target_update_freq: Optional[int] = 0, + target_update_freq: int = 0, reward_normalization: bool = False, **kwargs) -> None: super().__init__(**kwargs) @@ -156,11 +156,12 @@ def learn(self, batch: Batch, **kwargs) -> Dict[str, float]: if self._target and self._cnt % self._freq == 0: self.sync_weight() self.optim.zero_grad() + weight = batch.pop('weight', 1.) 
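        # [editorial note, not part of the patch] batch.pop('weight', 1.) reads the
        # importance-sampling weights that PrioritizedReplayBuffer attaches to sampled
        # batches and falls back to 1.0 for a plain ReplayBuffer; the TD error written
        # back to batch.weight a few lines below is what BasePolicy.post_process_fn
        # uses to update the priorities. The same pattern is applied to DDPG, SAC and
        # TD3 elsewhere in this patch.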
q = self(batch, eps=0.).logits q = q[np.arange(len(q)), batch.act] r = to_torch_as(batch.returns, q).flatten() td = r - q - loss = (td.pow(2) * batch.weight).mean() + loss = (td.pow(2) * weight).mean() batch.weight = td # prio-buffer loss.backward() self.optim.step() diff --git a/tianshou/policy/modelfree/pg.py b/tianshou/policy/modelfree/pg.py index 8fded95ec..3eaae641e 100644 --- a/tianshou/policy/modelfree/pg.py +++ b/tianshou/policy/modelfree/pg.py @@ -24,8 +24,7 @@ class PGPolicy(BasePolicy): def __init__(self, model: torch.nn.Module, optim: torch.optim.Optimizer, - dist_fn: torch.distributions.Distribution - = torch.distributions.Categorical, + dist_fn: torch.distributions.Distribution, discount_factor: float = 0.99, reward_normalization: bool = False, **kwargs) -> None: @@ -82,7 +81,7 @@ def learn(self, batch: Batch, batch_size: int, repeat: int, **kwargs) -> Dict[str, List[float]]: losses = [] for _ in range(repeat): - for b in batch.split(batch_size): + for b in batch.split(batch_size, merge_last=True): self.optim.zero_grad() dist = self(b).dist a = to_torch_as(b.act, dist.logits) diff --git a/tianshou/policy/modelfree/ppo.py b/tianshou/policy/modelfree/ppo.py index 3094be82e..2db5baf3b 100644 --- a/tianshou/policy/modelfree/ppo.py +++ b/tianshou/policy/modelfree/ppo.py @@ -34,6 +34,10 @@ class PPOPolicy(PGPolicy): defaults to ``True``. :param bool reward_normalization: normalize the returns to Normal(0, 1), defaults to ``True``. + :param int max_batchsize: the maximum size of the batch when computing GAE, + depends on the size of available memory and the memory cost of the + model; should be as large as possible within the memory constraint; + defaults to 256. .. seealso:: @@ -56,6 +60,7 @@ def __init__(self, dual_clip: Optional[float] = None, value_clip: bool = True, reward_normalization: bool = True, + max_batchsize: int = 256, **kwargs) -> None: super().__init__(None, None, dist_fn, discount_factor, **kwargs) self._max_grad_norm = max_grad_norm @@ -66,7 +71,7 @@ def __init__(self, self.actor = actor self.critic = critic self.optim = optim - self._batch = 64 + self._batch = max_batchsize assert 0 <= gae_lambda <= 1, 'GAE lambda should be in [0, 1].' self._lambda = gae_lambda assert dual_clip is None or dual_clip > 1, \ @@ -83,7 +88,7 @@ def process_fn(self, batch: Batch, buffer: ReplayBuffer, batch.rew = (batch.rew - mean) / std v, v_, old_log_prob = [], [], [] with torch.no_grad(): - for b in batch.split(self._batch, shuffle=False): + for b in batch.split(self._batch, shuffle=False, merge_last=True): v_.append(self.critic(b.obs_next)) v.append(self.critic(b.obs)) old_log_prob.append(self(b).dist.log_prob( @@ -132,10 +137,9 @@ def forward(self, batch: Batch, def learn(self, batch: Batch, batch_size: int, repeat: int, **kwargs) -> Dict[str, List[float]]: - self._batch = batch_size losses, clip_losses, vf_losses, ent_losses = [], [], [], [] for _ in range(repeat): - for b in batch.split(batch_size): + for b in batch.split(batch_size, merge_last=True): dist = self(b).dist value = self.critic(b.obs).flatten() ratio = (dist.log_prob(b.act) - b.logp_old).exp().float() diff --git a/tianshou/policy/modelfree/sac.py b/tianshou/policy/modelfree/sac.py index 341fe7b11..dfbc60e05 100644 --- a/tianshou/policy/modelfree/sac.py +++ b/tianshou/policy/modelfree/sac.py @@ -137,11 +137,12 @@ def _target_q(self, buffer: ReplayBuffer, return target_q def learn(self, batch: Batch, **kwargs) -> Dict[str, float]: + weight = batch.pop('weight', 1.) 
# critic 1 current_q1 = self.critic1(batch.obs, batch.act).flatten() target_q = batch.returns.flatten() td1 = current_q1 - target_q - critic1_loss = (td1.pow(2) * batch.weight).mean() + critic1_loss = (td1.pow(2) * weight).mean() # critic1_loss = F.mse_loss(current_q1, target_q) self.critic1_optim.zero_grad() critic1_loss.backward() @@ -149,7 +150,7 @@ def learn(self, batch: Batch, **kwargs) -> Dict[str, float]: # critic 2 current_q2 = self.critic2(batch.obs, batch.act).flatten() td2 = current_q2 - target_q - critic2_loss = (td2.pow(2) * batch.weight).mean() + critic2_loss = (td2.pow(2) * weight).mean() # critic2_loss = F.mse_loss(current_q2, target_q) self.critic2_optim.zero_grad() critic2_loss.backward() diff --git a/tianshou/policy/modelfree/td3.py b/tianshou/policy/modelfree/td3.py index 9a340950b..9150f3770 100644 --- a/tianshou/policy/modelfree/td3.py +++ b/tianshou/policy/modelfree/td3.py @@ -115,11 +115,12 @@ def _target_q(self, buffer: ReplayBuffer, return target_q def learn(self, batch: Batch, **kwargs) -> Dict[str, float]: + weight = batch.pop('weight', 1.) # critic 1 current_q1 = self.critic1(batch.obs, batch.act).flatten() target_q = batch.returns.flatten() td1 = current_q1 - target_q - critic1_loss = (td1.pow(2) * batch.weight).mean() + critic1_loss = (td1.pow(2) * weight).mean() # critic1_loss = F.mse_loss(current_q1, target_q) self.critic1_optim.zero_grad() critic1_loss.backward() @@ -127,7 +128,7 @@ def learn(self, batch: Batch, **kwargs) -> Dict[str, float]: # critic 2 current_q2 = self.critic2(batch.obs, batch.act).flatten() td2 = current_q2 - target_q - critic2_loss = (td2.pow(2) * batch.weight).mean() + critic2_loss = (td2.pow(2) * weight).mean() # critic2_loss = F.mse_loss(current_q2, target_q) self.critic2_optim.zero_grad() critic2_loss.backward() diff --git a/tianshou/policy/multiagent/mapolicy.py b/tianshou/policy/multiagent/mapolicy.py index f6329888d..c0d991d63 100644 --- a/tianshou/policy/multiagent/mapolicy.py +++ b/tianshou/policy/multiagent/mapolicy.py @@ -64,12 +64,12 @@ def forward(self, batch: Batch, { "act": actions corresponding to the input - "state":{ + "state": { "agent_1": output state of agent_1's policy for the state "agent_2": xxx ... "agent_n": xxx} - "out":{ + "out": { "agent_1": output of agent_1's policy for the input "agent_2": xxx ... diff --git a/tianshou/trainer/offpolicy.py b/tianshou/trainer/offpolicy.py index 171cbb9da..153f94d9c 100644 --- a/tianshou/trainer/offpolicy.py +++ b/tianshou/trainer/offpolicy.py @@ -71,7 +71,7 @@ def offpolicy_trainer( :return: See :func:`~tianshou.trainer.gather_info`. """ global_step = 0 - best_epoch, best_reward = -1, -1 + best_epoch, best_reward = -1, -1. 
stat = {} start_time = time.time() test_in_train = test_in_train and train_collector.policy == policy @@ -88,7 +88,7 @@ def offpolicy_trainer( if test_in_train and stop_fn and stop_fn(result['rew']): test_result = test_episode( policy, test_collector, test_fn, - epoch, episode_per_test) + epoch, episode_per_test, writer, global_step) if stop_fn and stop_fn(test_result['rew']): if save_fn: save_fn(policy) @@ -104,13 +104,13 @@ def offpolicy_trainer( train_fn(epoch) for i in range(update_per_step * min( result['n/st'] // collect_per_step, t.total - t.n)): - global_step += 1 + global_step += collect_per_step losses = policy.update(batch_size, train_collector.buffer) for k in result.keys(): data[k] = f'{result[k]:.2f}' if writer and global_step % log_interval == 0: - writer.add_scalar( - k, result[k], global_step=global_step) + writer.add_scalar('train/' + k, result[k], + global_step=global_step) for k in losses.keys(): if stat.get(k) is None: stat[k] = MovAvg() @@ -124,8 +124,8 @@ def offpolicy_trainer( if t.n <= t.total: t.update() # test - result = test_episode( - policy, test_collector, test_fn, epoch, episode_per_test) + result = test_episode(policy, test_collector, test_fn, epoch, + episode_per_test, writer, global_step) if best_epoch == -1 or best_reward < result['rew']: best_reward = result['rew'] best_epoch = epoch diff --git a/tianshou/trainer/onpolicy.py b/tianshou/trainer/onpolicy.py index e31724d66..ea57ed1c1 100644 --- a/tianshou/trainer/onpolicy.py +++ b/tianshou/trainer/onpolicy.py @@ -71,7 +71,7 @@ def onpolicy_trainer( :return: See :func:`~tianshou.trainer.gather_info`. """ global_step = 0 - best_epoch, best_reward = -1, -1 + best_epoch, best_reward = -1, -1. stat = {} start_time = time.time() test_in_train = test_in_train and train_collector.policy == policy @@ -88,7 +88,7 @@ def onpolicy_trainer( if test_in_train and stop_fn and stop_fn(result['rew']): test_result = test_episode( policy, test_collector, test_fn, - epoch, episode_per_test) + epoch, episode_per_test, writer, global_step) if stop_fn and stop_fn(test_result['rew']): if save_fn: save_fn(policy) @@ -109,12 +109,12 @@ def onpolicy_trainer( for k in losses.keys(): if isinstance(losses[k], list): step = max(step, len(losses[k])) - global_step += step + global_step += step * collect_per_step for k in result.keys(): data[k] = f'{result[k]:.2f}' if writer and global_step % log_interval == 0: writer.add_scalar( - k, result[k], global_step=global_step) + 'train/' + k, result[k], global_step=global_step) for k in losses.keys(): if stat.get(k) is None: stat[k] = MovAvg() @@ -128,8 +128,8 @@ def onpolicy_trainer( if t.n <= t.total: t.update() # test - result = test_episode( - policy, test_collector, test_fn, epoch, episode_per_test) + result = test_episode(policy, test_collector, test_fn, epoch, + episode_per_test, writer, global_step) if best_epoch == -1 or best_reward < result['rew']: best_reward = result['rew'] best_epoch = epoch diff --git a/tianshou/trainer/utils.py b/tianshou/trainer/utils.py index eb9bd3245..ba914d842 100644 --- a/tianshou/trainer/utils.py +++ b/tianshou/trainer/utils.py @@ -1,6 +1,7 @@ import time import numpy as np -from typing import Dict, List, Union, Callable +from torch.utils.tensorboard import SummaryWriter +from typing import Dict, List, Union, Callable, Optional from tianshou.data import Collector from tianshou.policy import BasePolicy @@ -9,9 +10,11 @@ def test_episode( policy: BasePolicy, collector: Collector, - test_fn: Callable[[int], None], + test_fn: Optional[Callable[[int], None]], 
epoch: int, - n_episode: Union[int, List[int]]) -> Dict[str, float]: + n_episode: Union[int, List[int]], + writer: SummaryWriter = None, + global_step: int = None) -> Dict[str, float]: """A simple wrapper of testing policy in collector.""" collector.reset_env() collector.reset_buffer() @@ -23,7 +26,11 @@ def test_episode( n_ = np.zeros(n) + n_episode // n n_[:n_episode % n] += 1 n_episode = list(n_) - return collector.collect(n_episode=n_episode) + result = collector.collect(n_episode=n_episode) + if writer is not None and global_step is not None: + for k in result.keys(): + writer.add_scalar('test/' + k, result[k], global_step=global_step) + return result def gather_info(start_time: float, diff --git a/tianshou/utils/net/common.py b/tianshou/utils/net/common.py index a84a7e7cd..eb68a9710 100644 --- a/tianshou/utils/net/common.py +++ b/tianshou/utils/net/common.py @@ -1,12 +1,13 @@ import torch import numpy as np from torch import nn -from typing import Tuple, Union, Optional +from typing import List, Tuple, Union, Optional from tianshou.data import to_torch -def miniblock(inp: int, oup: int, norm_layer: nn.modules.Module): +def miniblock(inp: int, oup: int, + norm_layer: nn.modules.Module) -> List[nn.modules.Module]: ret = [nn.Linear(inp, oup)] if norm_layer is not None: ret += [norm_layer(oup)] @@ -28,7 +29,7 @@ class Net(nn.Module): """ def __init__(self, layer_num: int, state_shape: tuple, - action_shape: Optional[tuple] = 0, + action_shape: Optional[Union[tuple, int]] = 0, device: Union[str, torch.device] = 'cpu', softmax: bool = False, concat: bool = False, diff --git a/tianshou/utils/net/continuous.py b/tianshou/utils/net/continuous.py index bbd2d9655..03a11f59c 100644 --- a/tianshou/utils/net/continuous.py +++ b/tianshou/utils/net/continuous.py @@ -2,7 +2,7 @@ import numpy as np from torch import nn -from tianshou.data import to_torch +from tianshou.data import to_torch, to_torch_as class Actor(nn.Module): @@ -10,8 +10,8 @@ class Actor(nn.Module): :ref:`build_the_network`. """ - def __init__(self, preprocess_net, action_shape, - max_action, device='cpu', hidden_layer_size=128): + def __init__(self, preprocess_net, action_shape, max_action=1., + device='cpu', hidden_layer_size=128): super().__init__() self.preprocess = preprocess_net self.last = nn.Linear(hidden_layer_size, np.prod(action_shape)) @@ -35,7 +35,7 @@ def __init__(self, preprocess_net, device='cpu', hidden_layer_size=128): self.preprocess = preprocess_net self.last = nn.Linear(hidden_layer_size, 1) - def forward(self, s, a=None, **kwargs): + def forward(self, s, a=None, info={}): """(s, a) -> logits -> Q(s, a)""" s = to_torch(s, device=self.device, dtype=torch.float32) s = s.flatten(1) @@ -53,7 +53,7 @@ class ActorProb(nn.Module): :ref:`build_the_network`. """ - def __init__(self, preprocess_net, action_shape, max_action, + def __init__(self, preprocess_net, action_shape, max_action=1., device='cpu', unbounded=False, hidden_layer_size=128): super().__init__() self.preprocess = preprocess_net @@ -63,7 +63,7 @@ def __init__(self, preprocess_net, action_shape, max_action, self._max = max_action self._unbounded = unbounded - def forward(self, s, state=None, **kwargs): + def forward(self, s, state=None, info={}): """s -> logits -> (mu, sigma)""" logits, h = self.preprocess(s, state) mu = self.mu(logits) @@ -80,8 +80,8 @@ class RecurrentActorProb(nn.Module): :ref:`build_the_network`. 
""" - def __init__(self, layer_num, state_shape, action_shape, - max_action, device='cpu', hidden_layer_size=128): + def __init__(self, layer_num, state_shape, action_shape, max_action=1., + device='cpu', unbounded=False, hidden_layer_size=128): super().__init__() self.device = device self.nn = nn.LSTM(input_size=np.prod(state_shape), @@ -89,8 +89,10 @@ def __init__(self, layer_num, state_shape, action_shape, num_layers=layer_num, batch_first=True) self.mu = nn.Linear(hidden_layer_size, np.prod(action_shape)) self.sigma = nn.Parameter(torch.zeros(np.prod(action_shape), 1)) + self._max = max_action + self._unbounded = unbounded - def forward(self, s, **kwargs): + def forward(self, s, state=None, info={}): """Almost the same as :class:`~tianshou.utils.net.common.Recurrent`.""" s = to_torch(s, device=self.device, dtype=torch.float32) # s [bsz, len, dim] (training) or [bsz, dim] (evaluation) @@ -98,13 +100,24 @@ def forward(self, s, **kwargs): # in evaluation phase. if len(s.shape) == 2: s = s.unsqueeze(-2) - logits, _ = self.nn(s) - logits = logits[:, -1] + self.nn.flatten_parameters() + if state is None: + s, (h, c) = self.nn(s) + else: + # we store the stack data in [bsz, len, ...] format + # but pytorch rnn needs [len, bsz, ...] + s, (h, c) = self.nn(s, (state['h'].transpose(0, 1).contiguous(), + state['c'].transpose(0, 1).contiguous())) + logits = s[:, -1] mu = self.mu(logits) + if not self._unbounded: + mu = self._max * torch.tanh(mu) shape = [1] * len(mu.shape) shape[1] = -1 sigma = (self.sigma.view(shape) + torch.zeros_like(mu)).exp() - return (mu, sigma), None + # please ensure the first dim is batch size: [bsz, len, ...] + return (mu, sigma), {'h': h.transpose(0, 1).detach(), + 'c': c.transpose(0, 1).detach()} class RecurrentCritic(nn.Module): @@ -134,8 +147,7 @@ def forward(self, s, a=None): s, (h, c) = self.nn(s) s = s[:, -1] if a is not None: - if not isinstance(a, torch.Tensor): - a = torch.tensor(a, device=self.device, dtype=torch.float32) + a = to_torch_as(a, s) s = torch.cat([s, a], dim=1) s = self.fc2(s) return s diff --git a/tianshou/utils/net/discrete.py b/tianshou/utils/net/discrete.py index afed6dfb5..c7fed2bcb 100644 --- a/tianshou/utils/net/discrete.py +++ b/tianshou/utils/net/discrete.py @@ -45,7 +45,7 @@ class DQN(nn.Module): Reference paper: "Human-level control through deep reinforcement learning". """ - def __init__(self, h, w, action_shape, device='cpu'): + def __init__(self, c, h, w, action_shape, device='cpu'): super(DQN, self).__init__() self.device = device @@ -66,7 +66,7 @@ def conv2d_layers_size_out(size, linear_input_size = convw * convh * 64 self.net = nn.Sequential( - nn.Conv2d(4, 32, kernel_size=8, stride=4), + nn.Conv2d(c, 32, kernel_size=8, stride=4), nn.ReLU(inplace=True), nn.Conv2d(32, 64, kernel_size=4, stride=2), nn.ReLU(inplace=True), @@ -74,12 +74,11 @@ def conv2d_layers_size_out(size, nn.ReLU(inplace=True), nn.Flatten(), nn.Linear(linear_input_size, 512), - nn.Linear(512, action_shape) + nn.Linear(512, np.prod(action_shape)) ) def forward(self, x, state=None, info={}): r"""x -> Q(x, \*)""" if not isinstance(x, torch.Tensor): x = torch.tensor(x, device=self.device, dtype=torch.float32) - x = x.permute(0, 3, 1, 2) return self.net(x), state