move atari wrapper to examples and publish v0.2.4 #124

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged (5 commits, Jul 10, 2020)
59 changes: 38 additions & 21 deletions tianshou/env/atari.py → examples/atari.py
@@ -2,6 +2,10 @@
import gym
import numpy as np
from gym.spaces.box import Box
from tianshou.data import Batch

SIZE = 84
FRAME = 4


def create_atari_environment(name=None, sticky_actions=True,
@@ -14,6 +18,27 @@ def create_atari_environment(name=None, sticky_actions=True,
return env


def preprocess_fn(obs=None, act=None, rew=None, done=None,
obs_next=None, info=None, policy=None):
if obs_next is not None:
obs_next = np.reshape(obs_next, (-1, *obs_next.shape[2:]))
obs_next = np.moveaxis(obs_next, 0, -1)
obs_next = cv2.resize(obs_next, (SIZE, SIZE))
obs_next = np.asanyarray(obs_next, dtype=np.uint8)
obs_next = np.reshape(obs_next, (-1, FRAME, SIZE, SIZE))
obs_next = np.moveaxis(obs_next, 1, -1)
elif obs is not None:
obs = np.reshape(obs, (-1, *obs.shape[2:]))
obs = np.moveaxis(obs, 0, -1)
obs = cv2.resize(obs, (SIZE, SIZE))
obs = np.asanyarray(obs, dtype=np.uint8)
obs = np.reshape(obs, (-1, FRAME, SIZE, SIZE))
obs = np.moveaxis(obs, 1, -1)

return Batch(obs=obs, act=act, rew=rew, done=done,
obs_next=obs_next, info=info)


class preprocessing(object):
def __init__(self, env, frame_skip=4, terminal_on_life_loss=False,
size=84, max_episode_steps=2000):
@@ -35,7 +60,8 @@ def __init__(self, env, frame_skip=4, terminal_on_life_loss=False,

@property
def observation_space(self):
return Box(low=0, high=255, shape=(self.size, self.size, 4),
return Box(low=0, high=255,
shape=(self.size, self.size, self.frame_skip),
dtype=np.uint8)

def action_space(self):
@@ -57,8 +83,8 @@ def reset(self):
self._grayscale_obs(self.screen_buffer[0])
self.screen_buffer[1].fill(0)

return np.stack([
self._pool_and_resize() for _ in range(self.frame_skip)], axis=-1)
return np.array([self._pool_and_resize()
for _ in range(self.frame_skip)])

def render(self, mode='human'):
return self.env.render(mode)
@@ -85,19 +111,15 @@ def step(self, action):
self._grayscale_obs(self.screen_buffer[t_])

observation.append(self._pool_and_resize())
while len(observation) > 0 and len(observation) < self.frame_skip:
if len(observation) == 0:
observation = [self._pool_and_resize()
for _ in range(self.frame_skip)]
while len(observation) > 0 and \
len(observation) < self.frame_skip:
observation.append(observation[-1])
if len(observation) > 0:
observation = np.stack(observation, axis=-1)
else:
observation = np.stack([
self._pool_and_resize() for _ in range(self.frame_skip)],
axis=-1)
if self.count >= self.max_episode_steps:
terminal = True
else:
terminal = False
return observation, total_reward, (terminal or is_terminal), info
terminal = self.count >= self.max_episode_steps
return np.array(observation), total_reward, \
(terminal or is_terminal), info

def _grayscale_obs(self, output):
self.env.ale.getScreenGrayscale(output)
@@ -108,9 +130,4 @@ def _pool_and_resize(self):
np.maximum(self.screen_buffer[0], self.screen_buffer[1],
out=self.screen_buffer[0])

transformed_image = cv2.resize(self.screen_buffer[0],
(self.size, self.size),
interpolation=cv2.INTER_AREA)
int_image = np.asarray(transformed_image, dtype=np.uint8)
# return np.expand_dims(int_image, axis=2)
return int_image
return self.screen_buffer[0]
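
The new preprocess_fn above moves frame resizing out of the environment wrapper and into the collector. Below is a rough, standalone sketch of the array shapes only (assumptions: cv2 and numpy are installed, SIZE=84 and FRAME=4 as in the diff, and 210x160 is the usual raw ALE screen size); it mirrors the same sequence of reshape/resize calls:

import cv2
import numpy as np

SIZE, FRAME = 84, 4
raw = np.zeros((2, FRAME, 210, 160), dtype=np.uint8)  # 2 envs, FRAME raw grayscale frames each

x = np.reshape(raw, (-1, *raw.shape[2:]))   # (8, 210, 160)
x = np.moveaxis(x, 0, -1)                   # (210, 160, 8): frames become channels
x = cv2.resize(x, (SIZE, SIZE))             # (84, 84, 8): one resize call for all frames
x = np.asanyarray(x, dtype=np.uint8)
x = np.reshape(x, (-1, FRAME, SIZE, SIZE))  # (2, 4, 84, 84)
x = np.moveaxis(x, 1, -1)                   # (2, 84, 84, 4): channels-last
assert x.shape == (2, SIZE, SIZE, FRAME)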
24 changes: 11 additions & 13 deletions examples/pong_a2c.py
@@ -8,11 +8,11 @@
from tianshou.env import SubprocVectorEnv
from tianshou.trainer import onpolicy_trainer
from tianshou.data import Collector, ReplayBuffer
from tianshou.env.atari import create_atari_environment

from tianshou.utils.net.discrete import Actor, Critic
from tianshou.utils.net.common import Net

from atari import create_atari_environment, preprocess_fn


def get_args():
parser = argparse.ArgumentParser()
@@ -45,20 +45,17 @@ def get_args():


def test_a2c(args=get_args()):
env = create_atari_environment(
args.task, max_episode_steps=args.max_episode_steps)
env = create_atari_environment(args.task)
args.state_shape = env.observation_space.shape or env.observation_space.n
args.action_shape = env.env.action_space.shape or env.env.action_space.n
# train_envs = gym.make(args.task)
train_envs = SubprocVectorEnv(
[lambda: create_atari_environment(
args.task, max_episode_steps=args.max_episode_steps)
for _ in range(args.training_num)])
[lambda: create_atari_environment(args.task)
for _ in range(args.training_num)])
# test_envs = gym.make(args.task)
test_envs = SubprocVectorEnv(
[lambda: create_atari_environment(
args.task, max_episode_steps=args.max_episode_steps)
for _ in range(args.test_num)])
[lambda: create_atari_environment(args.task)
for _ in range(args.test_num)])
# seed
np.random.seed(args.seed)
torch.manual_seed(args.seed)
@@ -76,8 +73,9 @@ def test_a2c(args=get_args()):
ent_coef=args.ent_coef, max_grad_norm=args.max_grad_norm)
# collector
train_collector = Collector(
policy, train_envs, ReplayBuffer(args.buffer_size))
test_collector = Collector(policy, test_envs)
policy, train_envs, ReplayBuffer(args.buffer_size),
preprocess_fn=preprocess_fn)
test_collector = Collector(policy, test_envs, preprocess_fn=preprocess_fn)
# log
writer = SummaryWriter(args.logdir + '/' + 'a2c')

@@ -99,7 +97,7 @@ def stop_fn(x):
pprint.pprint(result)
# Let's watch its performance!
env = create_atari_environment(args.task)
collector = Collector(policy, env)
collector = Collector(policy, env, preprocess_fn=preprocess_fn)
result = collector.collect(n_episode=1, render=args.render)
print(f'Final reward: {result["rew"]}, length: {result["len"]}')
collector.close()
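
The examples now pass preprocess_fn to every Collector because the wrapper itself no longer resizes frames, so its raw output differs from its declared observation space. A quick check of that split (a sketch only; it assumes gym with the Atari ROMs installed, examples/atari.py on the import path, and the usual 210x160 ALE screen):

from atari import create_atari_environment

env = create_atari_environment('Pong')
print(env.observation_space.shape)  # (84, 84, 4): size x size x frame_skip
print(env.reset().shape)            # raw stacked frames, e.g. (4, 210, 160), resized later by preprocess_fn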
14 changes: 8 additions & 6 deletions examples/pong_dqn.py
@@ -9,7 +9,8 @@
from tianshou.utils.net.discrete import DQN
from tianshou.trainer import offpolicy_trainer
from tianshou.data import Collector, ReplayBuffer
from tianshou.env.atari import create_atari_environment

from atari import create_atari_environment, preprocess_fn


def get_args():
@@ -49,8 +50,8 @@ def test_dqn(args=get_args()):
for _ in range(args.training_num)])
# test_envs = gym.make(args.task)
test_envs = SubprocVectorEnv([
lambda: create_atari_environment(
args.task) for _ in range(args.test_num)])
lambda: create_atari_environment(args.task)
for _ in range(args.test_num)])
# seed
np.random.seed(args.seed)
torch.manual_seed(args.seed)
@@ -68,8 +69,9 @@ def test_dqn(args=get_args()):
target_update_freq=args.target_update_freq)
# collector
train_collector = Collector(
policy, train_envs, ReplayBuffer(args.buffer_size))
test_collector = Collector(policy, test_envs)
policy, train_envs, ReplayBuffer(args.buffer_size),
preprocess_fn=preprocess_fn)
test_collector = Collector(policy, test_envs, preprocess_fn=preprocess_fn)
# policy.set_eps(1)
train_collector.collect(n_step=args.batch_size * 4)
print(len(train_collector.buffer))
@@ -101,7 +103,7 @@ def test_fn(x):
pprint.pprint(result)
# Let's watch its performance!
env = create_atari_environment(args.task)
collector = Collector(policy, env)
collector = Collector(policy, env, preprocess_fn=preprocess_fn)
result = collector.collect(n_episode=1, render=args.render)
print(f'Final reward: {result["rew"]}, length: {result["len"]}')
collector.close()
21 changes: 11 additions & 10 deletions examples/pong_ppo.py
@@ -8,10 +8,11 @@
from tianshou.env import SubprocVectorEnv
from tianshou.trainer import onpolicy_trainer
from tianshou.data import Collector, ReplayBuffer
from tianshou.env.atari import create_atari_environment
from tianshou.utils.net.discrete import Actor, Critic
from tianshou.utils.net.common import Net

from atari import create_atari_environment, preprocess_fn


def get_args():
parser = argparse.ArgumentParser()
@@ -44,17 +45,16 @@ def get_args():


def test_ppo(args=get_args()):
env = create_atari_environment(
args.task, max_episode_steps=args.max_episode_steps)
env = create_atari_environment(args.task)
args.state_shape = env.observation_space.shape or env.observation_space.n
args.action_shape = env.action_space().shape or env.action_space().n
# train_envs = gym.make(args.task)
train_envs = SubprocVectorEnv([lambda: create_atari_environment(
args.task, max_episode_steps=args.max_episode_steps)
train_envs = SubprocVectorEnv([
lambda: create_atari_environment(args.task)
for _ in range(args.training_num)])
# test_envs = gym.make(args.task)
test_envs = SubprocVectorEnv([lambda: create_atari_environment(
args.task, max_episode_steps=args.max_episode_steps)
test_envs = SubprocVectorEnv([
lambda: create_atari_environment(args.task)
for _ in range(args.test_num)])
# seed
np.random.seed(args.seed)
@@ -77,8 +77,9 @@
action_range=None)
# collector
train_collector = Collector(
policy, train_envs, ReplayBuffer(args.buffer_size))
test_collector = Collector(policy, test_envs)
policy, train_envs, ReplayBuffer(args.buffer_size),
preprocess_fn=preprocess_fn)
test_collector = Collector(policy, test_envs, preprocess_fn=preprocess_fn)
# log
writer = SummaryWriter(args.logdir + '/' + 'ppo')

@@ -100,7 +101,7 @@ def stop_fn(x):
pprint.pprint(result)
# Let's watch its performance!
env = create_atari_environment(args.task)
collector = Collector(policy, env)
collector = Collector(policy, env, preprocess_fn=preprocess_fn)
result = collector.collect(n_step=2000, render=args.render)
print(f'Final reward: {result["rew"]}, length: {result["len"]}')
collector.close()
2 changes: 1 addition & 1 deletion test/discrete/test_drqn.py
@@ -16,7 +16,7 @@
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument('--task', type=str, default='CartPole-v0')
parser.add_argument('--seed', type=int, default=1626)
parser.add_argument('--seed', type=int, default=0)
parser.add_argument('--eps-test', type=float, default=0.05)
parser.add_argument('--eps-train', type=float, default=0.1)
parser.add_argument('--buffer-size', type=int, default=20000)
2 changes: 1 addition & 1 deletion tianshou/__init__.py
@@ -1,7 +1,7 @@
from tianshou import data, env, utils, policy, trainer, \
exploration

__version__ = '0.2.3'
__version__ = '0.2.4'
__all__ = [
'env',
'data',
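
Finally, since the PR bumps the package version, a trivial post-install check (nothing but the version string changed in tianshou/__init__.py):

import tianshou

assert tianshou.__version__ == '0.2.4'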