Asynchronous sampling vector environment #134

Merged: 15 commits, Jul 26, 2020
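
For context before the file-by-file diff: this PR lets a Collector drive an AsyncVectorEnv, stepping only the environments that finish first instead of waiting for the slowest one every round. Below is a minimal usage sketch (not part of the diff); it assumes the tianshou 0.2.x API added here and that the test helpers MyTestEnv and MyPolicy from test/base are importable.

from tianshou.data import Collector, ReplayBuffer
from tianshou.env import AsyncVectorEnv
from test.base.env import MyTestEnv            # toy env used throughout this PR
from test.base.test_collector import MyPolicy  # trivial policy from the tests

# four toy environments with different episode lengths and random step latency
env_fns = [lambda i=i: MyTestEnv(size=i, sleep=0.1, random_sleep=True)
           for i in [2, 3, 4, 5]]
venv = AsyncVectorEnv(env_fns, wait_num=2)  # step() returns once 2 envs finish
collector = Collector(MyPolicy(), venv, ReplayBuffer(size=1000))
result = collector.collect(n_episode=10)    # episodes stay contiguous per env
collector.close()
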
8 changes: 6 additions & 2 deletions test/base/env.py
@@ -1,5 +1,6 @@
import gym
import time
import random
import numpy as np
from gym.spaces import Discrete, MultiDiscrete, Box

@@ -9,9 +10,10 @@ class MyTestEnv(gym.Env):
"""

def __init__(self, size, sleep=0, dict_state=False, ma_rew=0,
multidiscrete_action=False):
multidiscrete_action=False, random_sleep=False):
self.size = size
self.sleep = sleep
self.random_sleep = random_sleep
self.dict_state = dict_state
self.ma_rew = ma_rew
self._md_action = multidiscrete_action
@@ -48,7 +50,9 @@ def step(self, action):
if self.done:
raise ValueError('step after done !!!')
if self.sleep > 0:
time.sleep(self.sleep)
sleep_time = random.random() if self.random_sleep else 1
sleep_time *= self.sleep
time.sleep(sleep_time)
if self.index == self.size:
self.done = True
return self._get_dict_state(), self._get_reward(), self.done, {}
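
The new random_sleep flag, in isolation: each step now sleeps for a uniform draw in [0, sleep] instead of a fixed sleep, so parallel environments finish at different times and asynchronous stepping has something to exploit. A standalone sketch of the same logic (illustrative only, not part of the test env):

import random

def step_latency(sleep: float, random_sleep: bool) -> float:
    # mirrors the branch added to MyTestEnv.step() above
    sleep_time = random.random() if random_sleep else 1
    return sleep_time * sleep

samples = [step_latency(0.1, True) for _ in range(10000)]
print(sum(samples) / len(samples))  # about 0.05 on average, versus a fixed 0.1
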
48 changes: 47 additions & 1 deletion test/base/test_collector.py
@@ -2,7 +2,7 @@
from torch.utils.tensorboard import SummaryWriter

from tianshou.policy import BasePolicy
from tianshou.env import VectorEnv, SubprocVectorEnv
from tianshou.env import VectorEnv, SubprocVectorEnv, AsyncVectorEnv
from tianshou.data import Collector, Batch, ReplayBuffer

if __name__ == '__main__':
@@ -103,6 +103,51 @@ def test_collector():
c2.collect(n_episode=[1, 1, 1, 1], random=True)


def test_collector_with_async():
env_lens = [2, 3, 4, 5]
writer = SummaryWriter('log/async_collector')
logger = Logger(writer)
env_fns = [lambda x=i: MyTestEnv(size=x, sleep=0.1, random_sleep=True)
for i in env_lens]

venv = AsyncVectorEnv(env_fns)
policy = MyPolicy()
c1 = Collector(policy, venv,
ReplayBuffer(size=1000, ignore_obs_next=False),
logger.preprocess_fn)
c1.collect(n_episode=10)
# check if the data in the buffer is chronological
# i.e. data in the buffer are full episodes, and each episode is
# returned by the same environment
env_id = c1.buffer.info['env_id']
size = len(c1.buffer)
obs = c1.buffer.obs[:size]
done = c1.buffer.done[:size]
print(env_id[:size])
print(obs)
obs_ground_truth = []
i = 0
while i < size:
# i is the start of an episode
if done[i]:
# this episode has one transition
assert env_lens[env_id[i]] == 1
i += 1
continue
j = i
while True:
j += 1
# in one episode, the environment id is the same
assert env_id[j] == env_id[i]
if done[j]:
break
j = j + 1 # j is the start of the next episode
assert j - i == env_lens[env_id[i]]
obs_ground_truth += list(range(j - i))
i = j
assert np.allclose(obs, obs_ground_truth)


def test_collector_with_dict_state():
env = MyTestEnv(size=5, sleep=0, dict_state=True)
policy = MyPolicy(dict_state=True)
@@ -181,3 +226,4 @@ def reward_metric(x):
test_collector()
test_collector_with_dict_state()
test_collector_with_ma()
test_collector_with_async()
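
What test_collector_with_async asserts, made concrete: with env_lens = [2, 3, 4, 5], the replay buffer must consist of whole episodes, each produced by a single environment, even though the environments finish their steps out of order. A hypothetical buffer layout (values made up for illustration) that satisfies the check:

env_lens = [2, 3, 4, 5]
env_id = [2, 2, 2, 2, 0, 0, 3, 3, 3, 3, 3]  # one contiguous block per episode
done   = [0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 1]  # each block ends with done=True
obs    = [0, 1, 2, 3, 0, 1, 0, 1, 2, 3, 4]  # obs counts 0..len-1 inside a block
# every block has length env_lens[env_id], and concatenating range(block_len)
# over the blocks reproduces obs, which is what obs_ground_truth rebuilds.
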
42 changes: 41 additions & 1 deletion test/base/test_env.py
@@ -1,14 +1,53 @@
import time
import numpy as np
from gym.spaces.discrete import Discrete
from tianshou.env import VectorEnv, SubprocVectorEnv, RayVectorEnv
from tianshou.data import Batch
from tianshou.env import VectorEnv, SubprocVectorEnv, \
RayVectorEnv, AsyncVectorEnv

if __name__ == '__main__':
from env import MyTestEnv
else: # pytest
from test.base.env import MyTestEnv


def test_async_env(num=8, sleep=0.1):
# simplify the test case, just keep stepping
size = 10000
env_fns = [
lambda i=i: MyTestEnv(size=i, sleep=sleep, random_sleep=True)
for i in range(size, size + num)
]
v = AsyncVectorEnv(env_fns, wait_num=num // 2)
v.seed()
v.reset()
# for a random variable u ~ U[0, 1], let v = max{u1, u2, ..., un}
# P(v <= x) = x^n (0 <= x <= 1), pdf of v is nx^{n-1}
# expectation of v is n / (n + 1)
# for a synchronous environment, the following actions should take
# about 7 * sleep * num / (num + 1) seconds
# for AsyncVectorEnv, the analysis is complicated, but the time cost
# should be smaller
action_list = [1] * num + [0] * (num * 2) + [1] * (num * 4)
current_index_start = 0
action = action_list[:num]
env_ids = list(range(num))
o = []
spent_time = time.time()
while current_index_start < len(action_list):
A, B, C, D = v.step(action=action, id=env_ids)
b = Batch({'obs': A, 'rew': B, 'done': C, 'info': D})
env_ids = b.info.env_id
o.append(b)
current_index_start += len(action)
action = action_list[current_index_start: current_index_start + len(A)]
spent_time = time.time() - spent_time
data = Batch.cat(o)
# require at least a 1/7 improvement over the synchronous estimate above
assert spent_time < 6.0 * sleep * num / (num + 1)
return spent_time, data


def test_vecenv(size=10, num=8, sleep=0.001):
verbose = __name__ == '__main__'
env_fns = [
@@ -60,3 +99,4 @@ def test_vecenv(size=10, num=8, sleep=0.001):

if __name__ == '__main__':
test_vecenv()
test_async_env()
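
The timing bound in test_async_env comes from a standard fact: for n i.i.d. draws u_1, ..., u_n from U[0, 1], P(max <= x) = x^n, so E[max] = integral_0^1 n * x^n dx = n / (n + 1). A synchronous vector env waits for the slowest of num environments in each of the 7 rounds implied by action_list, giving roughly 7 * sleep * num / (num + 1) seconds, and the assert demands the asynchronous version beat 6/7 of that. A quick Monte-Carlo check of the expectation (not part of the test):

import numpy as np

n, trials = 8, 100000
v = np.random.default_rng(0).random((trials, n)).max(axis=1)
print(v.mean(), n / (n + 1))  # both are about 0.889 for n = 8
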
96 changes: 83 additions & 13 deletions tianshou/data/collector.py
@@ -5,10 +5,11 @@
import numpy as np
from typing import Any, Dict, List, Union, Optional, Callable

from tianshou.env import BaseVectorEnv, VectorEnv
from tianshou.env import BaseVectorEnv, VectorEnv, AsyncVectorEnv
from tianshou.policy import BasePolicy
from tianshou.exploration import BaseNoise
from tianshou.data import Batch, ReplayBuffer, ListReplayBuffer, to_numpy
from tianshou.data.batch import _create_value


class Collector(object):
@@ -96,6 +97,13 @@ def __init__(self,
env = VectorEnv([lambda: env])
self.env = env
self.env_num = len(env)
# environments that are available in step()
# this means all environments in synchronous simulation
# but only a subset of environments in asynchronous simulation
self._ready_env_ids = np.arange(self.env_num)
# self.is_async is a flag to indicate whether this collector works
# with asynchronous simulation
self.is_async = isinstance(env, AsyncVectorEnv)
# need cache buffers before storing in the main buffer
self._cached_buf = [ListReplayBuffer() for _ in range(self.env_num)]
self.collect_time, self.collect_step, self.collect_episode = 0., 0, 0
@@ -105,6 +113,9 @@ def __init__(self,
self.process_fn = policy.process_fn
self._action_noise = action_noise
self._rew_metric = reward_metric or Collector._default_rew_metric
# avoid creating attribute outside __init__
self.data = Batch(state={}, obs={}, act={}, rew={}, done={}, info={},
obs_next={}, policy={})
self.reset()

@staticmethod
@@ -139,6 +150,7 @@ def reset_env(self) -> None:
"""Reset all of the environment(s)' states and reset all of the cache
buffers (if needed).
"""
self._ready_env_ids = np.arange(self.env_num)
obs = self.env.reset()
if self.preprocess_fn:
obs = self.preprocess_fn(obs=obs).get('obs', obs)
@@ -159,7 +171,7 @@ def close(self) -> None:
self.env.close()

def _reset_state(self, id: Union[int, List[int]]) -> None:
"""Reset self.data.state[id]."""
"""Reset the hidden state: self.data.state[id]."""
state = self.data.state # it is a reference
if isinstance(state, torch.Tensor):
state[id].zero_()
@@ -207,13 +219,23 @@ def collect(self,
# episode of each environment
episode_count = np.zeros(self.env_num)
reward_total = 0.0
whole_data = Batch()
while True:
if step_count >= 100000 and episode_count.sum() == 0:
warnings.warn(
'There are already many steps in an episode. '
'You should add a time limitation to your environment!',
Warning)

if self.is_async:
# self.data holds the data for all environments
# in async simulation, only a subset of environments is ready at
# each step, so we save the full batch in ``whole_data``, let
# self.data be only the data from the ready environments, and
# finally write that slice back into whole_data
whole_data = self.data
self.data = self.data[self._ready_env_ids]

# restore the state and the input data
last_state = self.data.state
if last_state.is_empty():
@@ -222,8 +244,16 @@

# calculate the next action
if random:
if self.is_async:
# TODO self.env.action_space will invoke a remote call for
# all environments, which may hang in async simulation.
# This can be avoided with a random policy, but not at the
# collector level; leave it as future work.
raise RuntimeError("cannot use random "
"sampling in async simulation!")
spaces = self.env.action_space
result = Batch(
act=[a.sample() for a in self.env.action_space])
act=[spaces[i].sample() for i in self._ready_env_ids])
else:
with torch.no_grad():
result = self.policy(self.data, last_state)
@@ -243,8 +273,18 @@
self.data.act += self._action_noise(self.data.act.shape)

# step in env
obs_next, rew, done, info = self.env.step(self.data.act)

if not self.is_async:
obs_next, rew, done, info = self.env.step(self.data.act)
else:
# store computed actions, states, etc
_batch_set_item(whole_data, self._ready_env_ids,
self.data, self.env_num)
# fetch finished data
obs_next, rew, done, info = self.env.step(
action=self.data.act, id=self._ready_env_ids)
self._ready_env_ids = np.array([i['env_id'] for i in info])
# get the stepped data
self.data = whole_data[self._ready_env_ids]
# move data to self.data
self.data.update(obs_next=obs_next, rew=rew, done=done, info=info)

@@ -256,9 +296,11 @@
if self.preprocess_fn:
result = self.preprocess_fn(**self.data)
self.data.update(result)
for i in range(self.env_num):
self._cached_buf[i].add(**self.data[i])
if self.data.done[i]:
for j, i in enumerate(self._ready_env_ids):
# j is the index in current ready_env_ids
# i is the index in all environments
self._cached_buf[i].add(**self.data[j])
if self.data.done[j]:
if n_step or np.isscalar(n_episode) or \
episode_count[i] < n_episode[i]:
episode_count[i] += 1
@@ -267,17 +309,24 @@
if self.buffer is not None:
self.buffer.update(self._cached_buf[i])
self._cached_buf[i].reset()
self._reset_state(i)
self._reset_state(j)
obs_next = self.data.obs_next
if sum(self.data.done):
env_ind = np.where(self.data.done)[0]
obs_reset = self.env.reset(env_ind)
env_ind_local = np.where(self.data.done)[0]
env_ind_global = self._ready_env_ids[env_ind_local]
obs_reset = self.env.reset(env_ind_global)
if self.preprocess_fn:
obs_next[env_ind] = self.preprocess_fn(
obs_next[env_ind_local] = self.preprocess_fn(
obs=obs_reset).get('obs', obs_reset)
else:
obs_next[env_ind] = obs_reset
obs_next[env_ind_local] = obs_reset
self.data.obs = obs_next
if self.is_async:
# set data back
_batch_set_item(whole_data, self._ready_env_ids,
self.data, self.env_num)
# let self.data be the data in all environments again
self.data = whole_data
if n_step:
if step_count >= n_step:
break
@@ -320,3 +369,24 @@ def sample(self, batch_size: int) -> Batch:
batch_data, indice = self.buffer.sample(batch_size)
batch_data = self.process_fn(batch_data, self.buffer, indice)
return batch_data


def _batch_set_item(source: Batch, indices: np.ndarray,
target: Batch, size: int):
# for any key chain k, there are three cases
# 1. source[k] is non-reserved, but target[k] does not exist or is reserved
# 2. source[k] does not exist or is reserved, but target[k] is non-reserved
# 3. both source[k] and target[k] are non-reserved
for k, v in target.items():
if not isinstance(v, Batch) or not v.is_empty():
# target[k] is non-reserved
vs = source.get(k, Batch())
if isinstance(vs, Batch) and vs.is_empty():
# case 2
# use __dict__ to avoid many type checks
source.__dict__[k] = _create_value(v[0], size)
else:
# target[k] is reserved
# case 1
continue
source.__dict__[k][indices] = v
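
To make the whole_data / ready-slice bookkeeping concrete: _batch_set_item scatters the batch computed for the ready environments back into the full-size batch, allocating full-size storage the first time a key shows up (case 2 above). A small illustration, assuming the helper stays importable from tianshou.data.collector and the usual Batch semantics where an empty Batch() marks a reserved key:

import numpy as np
from tianshou.data import Batch
from tianshou.data.collector import _batch_set_item

# full-size batch for 4 environments; 'act' is still reserved (no data yet)
whole = Batch(obs=np.zeros((4, 3)), act=Batch())
# data just computed for the two ready environments, with global ids 1 and 3
ready = Batch(obs=np.ones((2, 3)), act=np.array([1, 0]))
ready_ids = np.array([1, 3])

_batch_set_item(whole, ready_ids, ready, size=4)
print(whole.obs)  # rows 1 and 3 are ones, rows 0 and 2 untouched
print(whole.act)  # full-size array allocated, entries 1 and 3 filled
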
8 changes: 6 additions & 2 deletions tianshou/env/__init__.py
@@ -1,11 +1,15 @@
from tianshou.env.basevecenv import BaseVectorEnv
from tianshou.env.vecenv import VectorEnv, SubprocVectorEnv, RayVectorEnv
from tianshou.env.vecenv.base import BaseVectorEnv
from tianshou.env.vecenv.dummy import VectorEnv
from tianshou.env.vecenv.subproc import SubprocVectorEnv
from tianshou.env.vecenv.asyncenv import AsyncVectorEnv
from tianshou.env.vecenv.rayenv import RayVectorEnv
from tianshou.env.maenv import MultiAgentEnv

__all__ = [
'BaseVectorEnv',
'VectorEnv',
'SubprocVectorEnv',
'AsyncVectorEnv',
'RayVectorEnv',
'MultiAgentEnv',
]
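
The __init__.py change only splits the vecenv module into submodules and exports the new class; the public import path stays the same. A quick sanity check, assuming the layout shown above:

from tianshou.env import AsyncVectorEnv, SubprocVectorEnv, VectorEnv

print(AsyncVectorEnv.__module__)  # tianshou.env.vecenv.asyncenv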