diff --git a/README.md b/README.md
index b2a207997..eebbfde64 100644
--- a/README.md
+++ b/README.md
@@ -36,7 +36,7 @@ Here is Tianshou's other features:
 - Elegant framework, using only ~2000 lines of code
 - Support parallel environment simulation (synchronous or asynchronous) for all algorithms [Usage](https://tianshou.readthedocs.io/en/latest/tutorials/cheatsheet.html#parallel-sampling)
 - Support recurrent state representation in actor network and critic network (RNN-style training for POMDP) [Usage](https://tianshou.readthedocs.io/en/latest/tutorials/cheatsheet.html#rnn-style-training)
-- Support any type of environment state (e.g. a dict, a self-defined class, ...) [Usage](https://tianshou.readthedocs.io/en/latest/tutorials/cheatsheet.html#user-defined-environment-and-different-state-representation)
+- Support any type of environment state/action (e.g. a dict, a self-defined class, ...) [Usage](https://tianshou.readthedocs.io/en/latest/tutorials/cheatsheet.html#user-defined-environment-and-different-state-representation)
 - Support customized training process [Usage](https://tianshou.readthedocs.io/en/latest/tutorials/cheatsheet.html#customize-training-process)
 - Support n-step returns estimation and prioritized experience replay for all Q-learning based algorithms; GAE, nstep and PER are very fast thanks to numba jit function and vectorized numpy operation
 - Support multi-agent RL [Usage](https://tianshou.readthedocs.io/en/latest/tutorials/cheatsheet.html##multi-agent-reinforcement-learning)
@@ -74,8 +74,8 @@ $ pip install tianshou
 After installation, open your python console and type
 
 ```python
-import tianshou as ts
-print(ts.__version__)
+import tianshou
+print(tianshou.__version__)
 ```
 
 If no error occurs, you have successfully installed Tianshou.
diff --git a/docs/index.rst b/docs/index.rst
index dba58bd72..bfb2ddfdf 100644
--- a/docs/index.rst
+++ b/docs/index.rst
@@ -24,11 +24,11 @@ Welcome to Tianshou!
 Here is Tianshou's other features:
 
 * Elegant framework, using only ~2000 lines of code
-* Support parallel environment sampling for all algorithms: :ref:`parallel_sampling`
-* Support recurrent state representation in actor network and critic network (RNN-style training for POMDP): :ref:`rnn_training`
+* Support parallel environment simulation (synchronous or asynchronous) for all algorithms: :ref:`parallel_sampling`
+* Support recurrent state/action representation in actor network and critic network (RNN-style training for POMDP): :ref:`rnn_training`
 * Support any type of environment state (e.g. a dict, a self-defined class, ...): :ref:`self_defined_env`
 * Support customized training process: :ref:`customize_training`
-* Support n-step returns estimation :meth:`~tianshou.policy.BasePolicy.compute_nstep_return` and prioritized experience replay for all Q-learning based algorithms
+* Support n-step returns estimation :meth:`~tianshou.policy.BasePolicy.compute_nstep_return` and prioritized experience replay :class:`~tianshou.data.PrioritizedReplayBuffer` for all Q-learning based algorithms; GAE, nstep and PER are very fast thanks to numba jit function and vectorized numpy operation
 * Support multi-agent RL: :doc:`/tutorials/tictactoe`
 
 中文文档位于 `https://tianshou.readthedocs.io/zh/latest/ <https://tianshou.readthedocs.io/zh/latest/>`_
@@ -63,8 +63,8 @@ If you use Anaconda or Miniconda, you can install Tianshou through the following
 After installation, open your python console and type
 ::
 
-    import tianshou as ts
-    print(ts.__version__)
+    import tianshou
+    print(tianshou.__version__)
 
 If no error occurs, you have successfully installed Tianshou.
 
diff --git a/examples/box2d/README.md b/examples/box2d/README.md
new file mode 100644
index 000000000..0935534b4
--- /dev/null
+++ b/examples/box2d/README.md
@@ -0,0 +1,7 @@
+# Bipedal-Hardcore-SAC
+
+- Our default choice: remove the done flag penalty, will soon converge to \~250 reward within 100 epochs (10M env steps, 3~4 hours, see the image below)
+- If the done penalty is not removed, it converges much slower than before, about 200 epochs (20M env steps) to reach the same performance (\~200 reward)
+- Action noise is only necessary in the beginning. It is a negative impact at the end of the training. Removing it can reach \~255 (our best result under the original env, no done penalty removed).
+
+![](results/sac/BipedalHardcore.png)
\ No newline at end of file
diff --git a/examples/box2d/bipedal_hardcore_sac.py b/examples/box2d/bipedal_hardcore_sac.py
index a92963d83..4e123719b 100644
--- a/examples/box2d/bipedal_hardcore_sac.py
+++ b/examples/box2d/bipedal_hardcore_sac.py
@@ -24,13 +24,13 @@ def get_args():
     parser.add_argument('--gamma', type=float, default=0.99)
     parser.add_argument('--tau', type=float, default=0.005)
     parser.add_argument('--alpha', type=float, default=0.1)
-    parser.add_argument('--epoch', type=int, default=1000)
-    parser.add_argument('--step-per-epoch', type=int, default=2400)
+    parser.add_argument('--epoch', type=int, default=100)
+    parser.add_argument('--step-per-epoch', type=int, default=10000)
     parser.add_argument('--collect-per-step', type=int, default=10)
     parser.add_argument('--batch-size', type=int, default=128)
     parser.add_argument('--layer-num', type=int, default=1)
     parser.add_argument('--training-num', type=int, default=8)
-    parser.add_argument('--test-num', type=int, default=8)
+    parser.add_argument('--test-num', type=int, default=100)
     parser.add_argument('--logdir', type=str, default='log')
     parser.add_argument('--render', type=float, default=0.)
     parser.add_argument('--rew-norm', type=int, default=0)
@@ -39,14 +39,14 @@ def get_args():
     parser.add_argument(
         '--device', type=str,
         default='cuda' if torch.cuda.is_available() else 'cpu')
+    parser.add_argument('--resume_path', type=str, default=None)
     return parser.parse_args()
 
 
 class EnvWrapper(object):
     """Env wrapper for reward scale, action repeat and action noise"""
 
-    def __init__(self, task, action_repeat=3,
-                 reward_scale=5, act_noise=0.3):
+    def __init__(self, task, action_repeat=3, reward_scale=5, act_noise=0.3):
         self._env = gym.make(task)
         self.action_repeat = action_repeat
         self.reward_scale = reward_scale
@@ -70,8 +70,6 @@ def step(self, action):
 
 
 def test_sac_bipedal(args=get_args()):
-    torch.set_num_threads(1)  # we just need only one thread for NN
-
     env = EnvWrapper(args.task)
 
     def IsStop(reward):
@@ -118,6 +116,10 @@ def IsStop(reward):
         reward_normalization=args.rew_norm,
         ignore_done=args.ignore_done,
         estimation_step=args.n_step)
+    # load a previous policy
+    if args.resume_path:
+        policy.load_state_dict(torch.load(args.resume_path))
+        print("Loaded agent from: ", args.resume_path)
 
     # collector
     train_collector = Collector(
@@ -135,7 +137,8 @@ def save_fn(policy):
     result = offpolicy_trainer(
         policy, train_collector, test_collector, args.epoch,
         args.step_per_epoch, args.collect_per_step, args.test_num,
-        args.batch_size, stop_fn=IsStop, save_fn=save_fn, writer=writer)
+        args.batch_size, stop_fn=IsStop, save_fn=save_fn, writer=writer,
+        test_in_train=False)
 
     if __name__ == '__main__':
         pprint.pprint(result)
diff --git a/examples/box2d/results/sac/BipedalHardcore.png b/examples/box2d/results/sac/BipedalHardcore.png
new file mode 100644
index 000000000..0b4196955
Binary files /dev/null and b/examples/box2d/results/sac/BipedalHardcore.png differ
diff --git a/setup.py b/setup.py
index 175112c44..789ea2d57 100644
--- a/setup.py
+++ b/setup.py
@@ -1,12 +1,19 @@
 #!/usr/bin/env python3
 # -*- coding: utf-8 -*-
 
+import os
 from setuptools import setup, find_packages
 
 
+def get_version() -> str:
+    # https://packaging.python.org/guides/single-sourcing-package-version/
+    init = open(os.path.join("tianshou", "__init__.py"), "r").read().split()
+    return init[init.index("__version__") + 2][1:-1]
+
+
 setup(
     name='tianshou',
-    version='0.2.6',
+    version=get_version(),
     description='A Library for Deep Reinforcement Learning',
     long_description=open('README.md', encoding='utf8').read(),
     long_description_content_type='text/markdown',
diff --git a/tianshou/__init__.py b/tianshou/__init__.py
index d44b4dc5a..e03d8640c 100644
--- a/tianshou/__init__.py
+++ b/tianshou/__init__.py
@@ -5,7 +5,7 @@
 
 utils.pre_compile()
 
-__version__ = '0.2.6'
+__version__ = '0.2.7'
 
 __all__ = [
     'env',
diff --git a/tianshou/policy/imitation/base.py b/tianshou/policy/imitation/base.py
index c01e45fd0..674de0429 100644
--- a/tianshou/policy/imitation/base.py
+++ b/tianshou/policy/imitation/base.py
@@ -8,7 +8,7 @@
 
 
 class ImitationPolicy(BasePolicy):
-    """Implementation of vanilla imitation learning (for continuous action space).
+    """Implementation of vanilla imitation learning.
 
     :param torch.nn.Module model: a model following the rules in
         :class:`~tianshou.policy.BasePolicy`. (s -> a)
diff --git a/tianshou/policy/multiagent/mapolicy.py b/tianshou/policy/multiagent/mapolicy.py
index c0d991d63..74ab0ecac 100644
--- a/tianshou/policy/multiagent/mapolicy.py
+++ b/tianshou/policy/multiagent/mapolicy.py
@@ -36,7 +36,9 @@ def process_fn(self, batch: Batch, buffer: ReplayBuffer,
         # reward can be empty Batch (after initial reset) or nparray.
         has_rew = isinstance(buffer.rew, np.ndarray)
         if has_rew:  # save the original reward in save_rew
-            save_rew, buffer.rew = buffer.rew, Batch()
+            # Since we do not override buffer.__setattr__, here we use _meta to
+            # change buffer.rew, otherwise buffer.rew = Batch() has no effect.
+            save_rew, buffer._meta.rew = buffer.rew, Batch()
         for policy in self.policies:
             agent_index = np.nonzero(batch.obs.agent_id == policy.agent_id)[0]
             if len(agent_index) == 0:
@@ -45,11 +47,11 @@ def process_fn(self, batch: Batch, buffer: ReplayBuffer,
             tmp_batch, tmp_indice = batch[agent_index], indice[agent_index]
             if has_rew:
                 tmp_batch.rew = tmp_batch.rew[:, policy.agent_id - 1]
-                buffer.rew = save_rew[:, policy.agent_id - 1]
+                buffer._meta.rew = save_rew[:, policy.agent_id - 1]
             results[f'agent_{policy.agent_id}'] = \
                 policy.process_fn(tmp_batch, buffer, tmp_indice)
         if has_rew:  # restore from save_rew
-            buffer.rew = save_rew
+            buffer._meta.rew = save_rew
         return Batch(results)
 
     def forward(self, batch: Batch,
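For reference, the `get_version()` helper added to `setup.py` above reads the version by whitespace-splitting `tianshou/__init__.py` and taking the token two places after `__version__`, then stripping the surrounding quotes. A minimal standalone sketch of that parsing follows; the `sample` string is a stand-in, not the real file contents.

```python
# Standalone sketch of the token-splitting trick used by get_version() in setup.py.
# `sample` stands in for the text of tianshou/__init__.py.
sample = "utils.pre_compile()\n\n__version__ = '0.2.7'\n"

tokens = sample.split()  # ['utils.pre_compile()', '__version__', '=', "'0.2.7'"]
version_literal = tokens[tokens.index("__version__") + 2]  # the quoted literal "'0.2.7'"
print(version_literal[1:-1])  # -> 0.2.7
```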