optimize training procedure and improve code coverage #189

Merged
merged 60 commits on Aug 27, 2020

60 commits
ea62500
collect fake data when buffer is None in Collector
Trinkle23897 Aug 20, 2020
b8a797c
add env_id in info for all environments
youkaichao Aug 20, 2020
d47c38d
fix test in collector preprocess_fn
Trinkle23897 Aug 20, 2020
43cbed7
potential bugfix for subproc.wait
youkaichao Aug 20, 2020
cea56cb
add steps count for test env; copy data for list buffer
youkaichao Aug 20, 2020
b0ae34c
enable exact n_episode for each env.
youkaichao Aug 20, 2020
6725bfc
.keys()
Trinkle23897 Aug 20, 2020
9489e7a
test __contains__
Trinkle23897 Aug 20, 2020
4d943e7
fix atari test
Trinkle23897 Aug 21, 2020
c1e4fbd
move deepcopy to collector (whole_data inplace modification cause Lis…
Trinkle23897 Aug 21, 2020
4516b90
bypass the attr check for batch.weight, test_dqn training fps 1870 ->…
Trinkle23897 Aug 21, 2020
c8ca9e5
change nstep to_torch and reach 1950+ (near v0.2.4.post1)
Trinkle23897 Aug 21, 2020
c50de09
fix a bug in per
Trinkle23897 Aug 21, 2020
b0e3cd5
batch.pop
Trinkle23897 Aug 21, 2020
d52f4f3
move previous script to runnable/
Trinkle23897 Aug 21, 2020
6be957c
rename
Trinkle23897 Aug 21, 2020
1367e2c
little enhancement by modifying _parse_value
Trinkle23897 Aug 21, 2020
cdbf8f6
add max_batchsize in a2c and ppo
Trinkle23897 Aug 21, 2020
7585f49
move test_gae to test/base/test_returns
Trinkle23897 Aug 21, 2020
543a7d9
add test_nstep
Trinkle23897 Aug 21, 2020
7681742
increase drqn gamma
Trinkle23897 Aug 21, 2020
a9c1b2e
performance improvement (+50) by analyzing traces
Trinkle23897 Aug 21, 2020
4d24149
add policy.eval() before watching its performance
Trinkle23897 Aug 21, 2020
dbddda7
remove previous atari script
Trinkle23897 Aug 22, 2020
88db97b
find a bug in exact n_episode
Trinkle23897 Aug 22, 2020
226a518
fix 0 in n_episode
Trinkle23897 Aug 22, 2020
5fe0850
improve little coverage
Trinkle23897 Aug 22, 2020
8c0f414
add missing test
Trinkle23897 Aug 22, 2020
0a168ce
add missing test for buffer, to_numpy and to_torch
Trinkle23897 Aug 22, 2020
a13bffa
add missing test for venv and utils
Trinkle23897 Aug 22, 2020
e920906
fix test
Trinkle23897 Aug 22, 2020
a68fc99
fix RecurrentActorProb and add test
Trinkle23897 Aug 22, 2020
b265265
little increase
Trinkle23897 Aug 22, 2020
51336a8
add a little test
Trinkle23897 Aug 22, 2020
e7387ff
minor fix
Trinkle23897 Aug 22, 2020
b710aed
merge_last in batch.split() (#185)
Trinkle23897 Aug 22, 2020
d98ed79
fix batch.split
Trinkle23897 Aug 22, 2020
25ddfd7
add merge_last in policy
Trinkle23897 Aug 22, 2020
e288d8c
change merge_last logic and add docs in preprocess_fn
Trinkle23897 Aug 23, 2020
c3d2bad
test_pg is too slow
Trinkle23897 Aug 24, 2020
da6277c
fix tensorboard logging
Trinkle23897 Aug 24, 2020
50aa5cf
add a check of buffer
Trinkle23897 Aug 24, 2020
4b90396
size 2000 -> 256
Trinkle23897 Aug 24, 2020
c8ed82d
simplify test batch.split
Trinkle23897 Aug 24, 2020
96bc690
fix docstring
Trinkle23897 Aug 24, 2020
20b7b48
simplify batch.split
Trinkle23897 Aug 25, 2020
340931f
optimize for batch.{cat/stack/empty}
Trinkle23897 Aug 26, 2020
1ca2d98
remove buffer **kwargs
Trinkle23897 Aug 26, 2020
9dce2c7
fix some type
Trinkle23897 Aug 26, 2020
df4129f
fix
Trinkle23897 Aug 26, 2020
c1dade5
reserve only one bypass in collector
Trinkle23897 Aug 26, 2020
efb4e00
minor fix
Trinkle23897 Aug 26, 2020
089d784
minor fix for some docstrings
Trinkle23897 Aug 26, 2020
225ab6a
simplify setup.py
Trinkle23897 Aug 26, 2020
5c2c506
test new init
Trinkle23897 Aug 26, 2020
9eee84f
version file
Trinkle23897 Aug 27, 2020
c1ac1dc
version file
Trinkle23897 Aug 27, 2020
9b643cf
fix some type annotation
Trinkle23897 Aug 27, 2020
12bc4cd
docstring
Trinkle23897 Aug 27, 2020
d1b1831
docstring
Trinkle23897 Aug 27, 2020
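
Several of the commits above (b710aed, d98ed79, 25ddfd7, 20b7b48) revolve around the new merge_last option of Batch.split(), which folds a trailing mini-batch smaller than the requested size into the previous chunk instead of yielding it separately. A minimal sketch of the behaviour, assuming the Batch.split(size, shuffle=..., merge_last=...) signature that these commits introduce:

```python
import numpy as np
from tianshou.data import Batch

b = Batch(a=np.arange(10))
# merge_last=False (default): chunk sizes 4, 4, 2
print([len(x.a) for x in b.split(4, shuffle=False)])
# merge_last=True: the 2-element remainder is folded into the last chunk -> 4, 6
print([len(x.a) for x in b.split(4, shuffle=False, merge_last=True)])
```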

Changes from all commits

2 changes: 2 additions & 0 deletions README.md
@@ -247,6 +247,8 @@ policy.load_state_dict(torch.load('dqn.pth'))
Watch the performance at 35 FPS:

```python
policy.eval()
policy.set_eps(eps_test)
collector = ts.data.Collector(policy, env)
collector.collect(n_episode=1, render=1 / 35)
```
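
The two added lines switch the policy to evaluation mode and to the test-time epsilon before rendering. If training continues afterwards, the flags presumably need to be restored; a hedged sketch (eps_train is assumed to be the training-time exploration rate used elsewhere in the README example, it is not shown in this diff):

```python
# sketch only: restore training behaviour after watching the agent
policy.train()
policy.set_eps(eps_train)
```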
7 changes: 4 additions & 3 deletions docs/tutorials/cheatsheet.rst
@@ -96,7 +96,7 @@ This is related to `Issue 42 <https://github.com/thu-ml/tianshou/issues/42>`_.

If you want to get log stats from the data stream, pre-process batch images, or modify the reward with the given env info, use ``preprocess_fn`` in :class:`~tianshou.data.Collector`. This is a hook which will be called before the data is added to the buffer.

This function receives typically 7 keys, as listed in :class:`~tianshou.data.Batch`, and returns the modified part within a dict or a Batch. For example, you can write your hook as:
This function receives up to 7 keys ``obs``, ``act``, ``rew``, ``done``, ``obs_next``, ``info``, and ``policy``, as listed in :class:`~tianshou.data.Batch`, and returns the modified part within a :class:`~tianshou.data.Batch`. Only ``obs`` is defined at env reset, while every key is specified for normal steps. For example, you can write your hook as:
::

import numpy as np
@@ -109,9 +109,11 @@ This function receives typically 7 keys, as listed in :class:`~tianshou.data.Bat
self.baseline = 0
def preprocess_fn(**kwargs):
"""change reward to zero mean"""
# if only obs exist -> reset
# if obs/act/rew/done/... exist -> normal step
if 'rew' not in kwargs:
# means that it is called after env.reset(), it can only process the obs
return {} # none of the variables are needed to be updated
return Batch() # none of the variables are needed to be updated
else:
n = len(kwargs['rew']) # the number of envs in collector
if self.episode_log is None:
@@ -125,7 +127,6 @@ This function receives typically 7 keys, as listed in :class:`~tianshou.data.Bat
self.episode_log[i] = []
self.baseline = np.mean(self.main_log)
return Batch(rew=kwargs['rew'])
# you can also return with {'rew': kwargs['rew']}

And finally,
::
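    # A minimal sketch of the lines collapsed below: wiring the hook into a
    # Collector. The Collector(policy, env, buffer, preprocess_fn=...)
    # signature is assumed from tianshou v0.2.x, and the class name
    # MyProcessor is an assumption based on the (collapsed) class definition
    # above; adjust both to your setup.
    from tianshou.data import Collector, ReplayBuffer

    buffer = ReplayBuffer(size=20000)
    collector = Collector(policy, env, buffer,
                          preprocess_fn=MyProcessor().preprocess_fn)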
2 changes: 2 additions & 0 deletions docs/tutorials/dqn.rst
@@ -176,6 +176,8 @@ Watch the Agent's Performance
:class:`~tianshou.data.Collector` supports rendering. Here is an example of watching the agent's performance at 35 FPS:
::

policy.eval()
policy.set_eps(0.05)
collector = ts.data.Collector(policy, env)
collector.collect(n_episode=1, render=1 / 35)

2 changes: 2 additions & 0 deletions docs/tutorials/tictactoe.rst
@@ -285,6 +285,8 @@ With the above preparation, we are close to the first learned agent. The followi
env = TicTacToeEnv(args.board_size, args.win_size)
policy, optim = get_agents(
args, agent_learn=agent_learn, agent_opponent=agent_opponent)
policy.eval()
policy.set_eps(args.eps_test)
collector = Collector(policy, env)
result = collector.collect(n_episode=1, render=args.render)
print(f'Final reward: {result["rew"]}, length: {result["len"]}')
108 changes: 0 additions & 108 deletions examples/atari/pong_dqn.py

This file was deleted.

File renamed without changes.
File renamed without changes.
File renamed without changes.
9 changes: 6 additions & 3 deletions examples/box2d/acrobot_dualdqn.py
@@ -102,9 +102,12 @@ def test_fn(x):
if __name__ == '__main__':
pprint.pprint(result)
# Let's watch its performance!
env = gym.make(args.task)
collector = Collector(policy, env)
result = collector.collect(n_episode=1, render=args.render)
policy.eval()
policy.set_eps(args.eps_test)
test_envs.seed(args.seed)
test_collector.reset()
result = test_collector.collect(n_episode=[1] * args.test_num,
render=args.render)
print(f'Final reward: {result["rew"]}, length: {result["len"]}')


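The watch-time snippets in these examples now call test_collector.collect(n_episode=[1] * args.test_num, ...), relying on the exact per-env episode counting added in this PR (commit b0ae34c): passing a list asks each environment in the vectorized collector for exactly that many episodes rather than an overall total. A hedged sketch of the idea, reusing the SubprocVectorEnv pattern from the files below; the CartPole task and the four-env size are placeholders:

```python
# sketch only: collect exactly one episode from each of four test envs
import gym
from tianshou.data import Collector
from tianshou.env import SubprocVectorEnv

test_envs = SubprocVectorEnv(
    [lambda: gym.make('CartPole-v0') for _ in range(4)])
test_collector = Collector(policy, test_envs)  # `policy` as in the examples above
result = test_collector.collect(n_episode=[1, 1, 1, 1])  # one episode per env
print(result['rew'], result['len'])
```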
18 changes: 11 additions & 7 deletions examples/box2d/bipedal_hardcore_sac.py
@@ -44,6 +44,7 @@ def get_args():

class EnvWrapper(object):
"""Env wrapper for reward scale, action repeat and action noise"""

def __init__(self, task, action_repeat=3,
reward_scale=5, act_noise=0.3):
self._env = gym.make(task)
@@ -71,19 +72,20 @@ def step(self, action):
def test_sac_bipedal(args=get_args()):
torch.set_num_threads(1) # we just need only one thread for NN

env = EnvWrapper(args.task)

def IsStop(reward):
return reward >= 300 * 5
return reward >= env.spec.reward_threshold

env = EnvWrapper(args.task)
args.state_shape = env.observation_space.shape or env.observation_space.n
args.action_shape = env.action_space.shape or env.action_space.n
args.max_action = env.action_space.high[0]

train_envs = SubprocVectorEnv(
[lambda: EnvWrapper(args.task) for _ in range(args.training_num)])
# test_envs = gym.make(args.task)
test_envs = SubprocVectorEnv(
[lambda: EnvWrapper(args.task) for _ in range(args.test_num)])
test_envs = SubprocVectorEnv([lambda: EnvWrapper(args.task, reward_scale=1)
for _ in range(args.test_num)])

# seed
np.random.seed(args.seed)
@@ -138,9 +140,11 @@ def save_fn(policy):
if __name__ == '__main__':
pprint.pprint(result)
# Let's watch its performance!
env = EnvWrapper(args.task)
collector = Collector(policy, env)
result = collector.collect(n_episode=16, render=args.render)
policy.eval()
test_envs.seed(args.seed)
test_collector.reset()
result = test_collector.collect(n_episode=[1] * args.test_num,
render=args.render)
print(f'Final reward: {result["rew"]}, length: {result["len"]}')


9 changes: 6 additions & 3 deletions examples/box2d/lunarlander_dqn.py
@@ -99,9 +99,12 @@ def test_fn(x):
if __name__ == '__main__':
pprint.pprint(result)
# Let's watch its performance!
env = gym.make(args.task)
collector = Collector(policy, env)
result = collector.collect(n_episode=1, render=args.render)
policy.eval()
policy.set_eps(args.eps_test)
test_envs.seed(args.seed)
test_collector.reset()
result = test_collector.collect(n_episode=[1] * args.test_num,
render=args.render)
print(f'Final reward: {result["rew"]}, length: {result["len"]}')


8 changes: 5 additions & 3 deletions examples/box2d/sac_mcc.py → examples/box2d/mcc_sac.py
@@ -112,9 +112,11 @@ def stop_fn(x):
if __name__ == '__main__':
pprint.pprint(result)
# Let's watch its performance!
env = gym.make(args.task)
collector = Collector(policy, env)
result = collector.collect(n_episode=1, render=args.render)
policy.eval()
test_envs.seed(args.seed)
test_collector.reset()
result = test_collector.collect(n_episode=[1] * args.test_num,
render=args.render)
print(f'Final reward: {result["rew"]}, length: {result["len"]}')


8 changes: 5 additions & 3 deletions examples/mujoco/ant_v2_ddpg.py
@@ -88,9 +88,11 @@ def stop_fn(x):
if __name__ == '__main__':
pprint.pprint(result)
# Let's watch its performance!
env = gym.make(args.task)
collector = Collector(policy, env)
result = collector.collect(n_episode=1, render=args.render)
policy.eval()
test_envs.seed(args.seed)
test_collector.reset()
result = test_collector.collect(n_episode=[1] * args.test_num,
render=args.render)
print(f'Final reward: {result["rew"]}, length: {result["len"]}')


8 changes: 5 additions & 3 deletions examples/mujoco/ant_v2_sac.py
@@ -98,9 +98,11 @@ def stop_fn(x):
if __name__ == '__main__':
pprint.pprint(result)
# Let's watch its performance!
env = gym.make(args.task)
collector = Collector(policy, env)
result = collector.collect(n_episode=1, render=args.render)
policy.eval()
test_envs.seed(args.seed)
test_collector.reset()
result = test_collector.collect(n_episode=[1] * args.test_num,
render=args.render)
print(f'Final reward: {result["rew"]}, length: {result["len"]}')


8 changes: 5 additions & 3 deletions examples/mujoco/ant_v2_td3.py
@@ -98,9 +98,11 @@ def stop_fn(x):
if __name__ == '__main__':
pprint.pprint(result)
# Let's watch its performance!
env = gym.make(args.task)
collector = Collector(policy, env)
result = collector.collect(n_episode=1, render=args.render)
policy.eval()
test_envs.seed(args.seed)
test_collector.reset()
result = test_collector.collect(n_episode=[1] * args.test_num,
render=args.render)
print(f'Final reward: {result["rew"]}, length: {result["len"]}')


8 changes: 5 additions & 3 deletions examples/mujoco/halfcheetahBullet_v0_sac.py
@@ -104,9 +104,11 @@ def stop_fn(x):
if __name__ == '__main__':
pprint.pprint(result)
# Let's watch its performance!
env = gym.make(args.task)
collector = Collector(policy, env)
result = collector.collect(n_episode=1, render=args.render)
policy.eval()
test_envs.seed(args.seed)
test_collector.reset()
result = test_collector.collect(n_episode=[1] * args.test_num,
render=args.render)
print(f'Final reward: {result["rew"]}, length: {result["len"]}')


8 changes: 5 additions & 3 deletions examples/mujoco/point_maze_td3.py
@@ -104,9 +104,11 @@ def stop_fn(x):
if __name__ == '__main__':
pprint.pprint(result)
# Let's watch its performance!
env = gym.make(args.task)
collector = Collector(policy, env)
result = collector.collect(n_step=1000, render=args.render)
policy.eval()
test_envs.seed(args.seed)
test_collector.reset()
result = test_collector.collect(n_episode=[1] * args.test_num,
render=args.render)
print(f'Final reward: {result["rew"]}, length: {result["len"]}')


10 changes: 1 addition & 9 deletions setup.py
@@ -3,18 +3,10 @@

from setuptools import setup, find_packages

import re
from os import path

here = path.abspath(path.dirname(__file__))

# Get the version string
with open(path.join(here, 'tianshou', '__init__.py')) as f:
version = re.search(r'__version__ = \'(.*?)\'', f.read()).group(1)

setup(
name='tianshou',
version=version,
version='0.2.6',
description='A Library for Deep Reinforcement Learning',
long_description=open('README.md', encoding='utf8').read(),
long_description_content_type='text/markdown',
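
The simplified setup.py above hard-codes version='0.2.6', and the two commits named "version file" (9eee84f, c1ac1dc) suggest the version string also lives in a dedicated module. A minimal sketch of that single-source pattern; the file name and the re-export are assumptions, not shown in this diff:

```python
# tianshou/version.py (assumed file name)
__version__ = '0.2.6'

# tianshou/__init__.py could then re-export it so that
# `import tianshou; tianshou.__version__` keeps working:
# from tianshou.version import __version__
```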
3 changes: 3 additions & 0 deletions test/base/env.py
@@ -21,6 +21,8 @@ def __init__(self, size, sleep=0, dict_state=False, recurse_state=False,
self.recurse_state = recurse_state
self.ma_rew = ma_rew
self._md_action = multidiscrete_action
# how many steps this env has stepped
self.steps = 0
if dict_state:
self.observation_space = Dict(
{"index": Box(shape=(1, ), low=0, high=size - 1),
@@ -74,6 +76,7 @@ def _get_state(self):
return np.array([self.index], dtype=np.float32)

def step(self, action):
self.steps += 1
if self._md_action:
action = action[0]
if self.done: