这是indexloc提供的服务,不要输入任何密码
Skip to content

rm the wrong lines, add the ignored files #7

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 29 commits into from
Mar 28, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
e458845
update atari.py
Mehooz Mar 25, 2020
920b398
fix setup.py
Mehooz Mar 26, 2020
3c6b130
fix setup.py
Mehooz Mar 26, 2020
0a9c3bc
add args "render"
Mehooz Mar 26, 2020
e18979d
add args render
Mehooz Mar 26, 2020
e451ec0
Merge branch 'n+e'
Mehooz Mar 26, 2020
d83c0e7
change the tensorboard writter
Mehooz Mar 26, 2020
dfb7ead
change the tensorboard writter
Mehooz Mar 26, 2020
449b687
change device, render, tensorboard log location
Mehooz Mar 27, 2020
1a36ee3
Merge remote-tracking branch 'origin/master'
Mehooz Mar 27, 2020
41266c1
change device, render, tensorboard log location
Mehooz Mar 27, 2020
d8c4152
remove some wrong local files
Mehooz Mar 27, 2020
38e80f1
fix some tab mistakes and the envs name in continuous/test_xx.py
Mehooz Mar 27, 2020
ecc2e4f
add examples and point robot maze environment
Mehooz Mar 27, 2020
0acba5f
fix some bugs during testing examples
Mehooz Mar 27, 2020
95f0874
add dqn network and fix some args
Mehooz Mar 27, 2020
0be25e5
change back the tensorboard writter's frequency to ensure ppo and a2c…
Mehooz Mar 27, 2020
2255131
add a warning to collector
Mehooz Mar 27, 2020
b30b818
rm some unrelated files
Mehooz Mar 27, 2020
0023427
reformat
Mehooz Mar 27, 2020
c28b19d
fix a bug in test_dqn due to the model wrong selection
Mehooz Mar 27, 2020
2f2d267
change atari frame skip and observation to improve performance
Mehooz Mar 28, 2020
87f429d
readd some files
Mehooz Mar 28, 2020
b3be8ff
change import
Mehooz Mar 28, 2020
0b90d18
Merge remote-tracking branch 'remotes/origin/master'
Mehooz Mar 28, 2020
3d0983c
modified readme
Mehooz Mar 28, 2020
c0ff794
rm tensorboard log
Mehooz Mar 28, 2020
c026a1e
update atari and mujoco which are ignored
Mehooz Mar 28, 2020
f3715d4
rm the wrong lines
Mehooz Mar 28, 2020
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@ __pycache__/
# C extensions
*.so

# .idea folder
.idea/

# Distribution / packaging
.Python
build/
Expand Down
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,7 @@ If you find Tianshou useful, please cite it in your publications.

```latex
@misc{tianshou,
author = {Jiayi Weng},
author = {Jiayi Weng and Minghao Zhang},
title = {Tianshou},
year = {2020},
publisher = {GitHub},
Expand Down
10 changes: 6 additions & 4 deletions examples/discrete_net.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def __init__(self, preprocess_net, action_shape):

def forward(self, s, state=None, info={}):
logits, h = self.preprocess(s, state)
logits = F.softmax(self.last(logits), dim=-1)
logits = F.softmax(logits, dim=-1)
return logits, h


Expand All @@ -56,7 +56,7 @@ def __init__(self, h, w, action_shape, device='cpu'):
super(DQN, self).__init__()
self.device = device

self.conv1 = nn.Conv2d(1, 16, kernel_size=5, stride=2)
self.conv1 = nn.Conv2d(4, 16, kernel_size=5, stride=2)
self.bn1 = nn.BatchNorm2d(16)
self.conv2 = nn.Conv2d(16, 32, kernel_size=5, stride=2)
self.bn2 = nn.BatchNorm2d(32)
Expand All @@ -69,7 +69,8 @@ def conv2d_size_out(size, kernel_size=5, stride=2):
convw = conv2d_size_out(conv2d_size_out(conv2d_size_out(w)))
convh = conv2d_size_out(conv2d_size_out(conv2d_size_out(h)))
linear_input_size = convw * convh * 32
self.head = nn.Linear(linear_input_size, action_shape)
self.fc = nn.Linear(linear_input_size, 512)
self.head = nn.Linear(512, action_shape)

def forward(self, x, state=None, info={}):
if not isinstance(x, torch.Tensor):
Expand All @@ -78,4 +79,5 @@ def forward(self, x, state=None, info={}):
x = F.relu(self.bn1(self.conv1(x)))
x = F.relu(self.bn2(self.conv2(x)))
x = F.relu(self.bn3(self.conv3(x)))
return self.head(x.view(x.size(0), -1)), state
x = self.fc(x.reshape(x.size(0), -1))
return self.head(x), state
11 changes: 7 additions & 4 deletions test/discrete/net.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,11 +51,12 @@ def forward(self, s):


class DQN(nn.Module):

def __init__(self, h, w, action_shape, device='cpu'):
super(DQN, self).__init__()
self.device = device

self.conv1 = nn.Conv2d(3, 16, kernel_size=5, stride=2)
self.conv1 = nn.Conv2d(4, 16, kernel_size=5, stride=2)
self.bn1 = nn.BatchNorm2d(16)
self.conv2 = nn.Conv2d(16, 32, kernel_size=5, stride=2)
self.bn2 = nn.BatchNorm2d(32)
Expand All @@ -68,12 +69,14 @@ def conv2d_size_out(size, kernel_size=5, stride=2):
convw = conv2d_size_out(conv2d_size_out(conv2d_size_out(w)))
convh = conv2d_size_out(conv2d_size_out(conv2d_size_out(h)))
linear_input_size = convw * convh * 32
self.head = nn.Linear(linear_input_size, action_shape)
self.fc = nn.Linear(linear_input_size, 512)
self.head = nn.Linear(512, action_shape)

def forward(self, x, state=None, info={}):
if not isinstance(x, torch.Tensor):
x = torch.tensor(x, device=self.device, dtype=torch.float)
s = torch.tensor(x, device=self.device, dtype=torch.float)
x = F.relu(self.bn1(self.conv1(x)))
x = F.relu(self.bn2(self.conv2(x)))
x = F.relu(self.bn3(self.conv3(x)))
return self.head(x.view(x.size(0), -1)), state
x = self.fc(x.reshape(x.size(0), -1))
return self.head(x), state
124 changes: 124 additions & 0 deletions tianshou/env/atari.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import cv2
import gym
import numpy as np
from gym.spaces.box import Box


def create_atari_environment(name=None, sticky_actions=True,
                             max_episode_steps=2000):
    """Build a preprocessed Atari environment for the given game.

    The gym environment id is ``'<name>NoFrameskip-<version>'`` where the
    version is ``'v0'`` when ``sticky_actions`` is true and ``'v4'``
    otherwise.

    :param name: game name used as the prefix of the gym environment id.
    :param sticky_actions: selects the ``v0`` (True) or ``v4`` (False)
        variant of the environment.
    :param max_episode_steps: forwarded to :class:`preprocessing`.
    :return: a :class:`preprocessing` instance wrapping the environment.
    """
    if sticky_actions:
        game_version = 'v0'
    else:
        game_version = 'v4'
    env_id = '{}NoFrameskip-{}'.format(name, game_version)
    # Use the object one level below what gym.make returns
    # (presumably this drops gym's outer TimeLimit wrapper -- TODO confirm).
    inner_env = gym.make(env_id).env
    return preprocessing(inner_env, max_episode_steps=max_episode_steps)


class preprocessing(object):
    """Atari wrapper doing frame skipping, grayscale pooling and resizing.

    Every call to :meth:`step` repeats the action for up to ``frame_skip``
    emulator frames, sums the rewards and returns the pooled, resized
    grayscale frames stacked along the last axis.

    :param env: the wrapped environment; it must expose ``ale``,
        ``observation_space``, ``action_space``, ``reward_range``,
        ``metadata``, ``reset``, ``step``, ``render`` and ``close``.
    :param frame_skip: number of emulator frames per :meth:`step` and number
        of frames stacked in each observation.
    :param terminal_on_life_loss: if true, losing a life ends the episode.
    :param size: side length of the square resized frame.
    :param max_episode_steps: number of emulator frames after which
        :meth:`step` reports the episode as finished.
    """

    def __init__(self, env, frame_skip=4, terminal_on_life_loss=False,
                 size=84, max_episode_steps=2000):
        self.max_episode_steps = max_episode_steps
        self.env = env
        self.terminal_on_life_loss = terminal_on_life_loss
        self.frame_skip = frame_skip
        self.size = size
        self.count = 0  # emulator frames played since the last reset
        obs_dims = self.env.observation_space

        # Two raw grayscale screens; they are max-pooled in
        # _pool_and_resize.
        self.screen_buffer = [
            np.empty((obs_dims.shape[0], obs_dims.shape[1]), dtype=np.uint8),
            np.empty((obs_dims.shape[0], obs_dims.shape[1]), dtype=np.uint8)
        ]

        self.game_over = False
        self.lives = 0

    @property
    def observation_space(self):
        """Box describing the stacked observation returned by reset/step."""
        # reset() and step() stack ``frame_skip`` frames along the last axis,
        # so the channel count follows ``frame_skip`` (4 by default).
        return Box(low=0, high=255,
                   shape=(self.size, self.size, self.frame_skip),
                   dtype=np.uint8)

    # NOTE(review): unlike observation_space, the three accessors below are
    # plain methods, so ``env.action_space`` is a bound method rather than a
    # space.  Kept as-is to avoid breaking callers that invoke them.
    def action_space(self):
        """Return the wrapped environment's action space."""
        return self.env.action_space

    def reward_range(self):
        """Return the wrapped environment's reward range."""
        return self.env.reward_range

    def metadata(self):
        """Return the wrapped environment's metadata."""
        return self.env.metadata

    def close(self):
        """Close the wrapped environment."""
        return self.env.close()

    def reset(self):
        """Reset the environment and return the initial stacked observation."""
        self.count = 0
        self.env.reset()
        self.lives = self.env.ale.lives()
        self._grayscale_obs(self.screen_buffer[0])
        self.screen_buffer[1].fill(0)

        return np.stack(
            [self._pool_and_resize() for _ in range(self.frame_skip)],
            axis=-1)

    def render(self, mode='human'):
        """Render the wrapped environment.

        ``mode`` defaults to ``'human'`` so that ``render()`` can be called
        without arguments.
        """
        return self.env.render(mode)

    def step(self, action):
        """Repeat ``action`` for up to ``frame_skip`` frames.

        :return: ``(observation, total_reward, is_terminal, info)`` where
            ``observation`` stacks ``frame_skip`` frames on the last axis,
            ``total_reward`` is the sum of the per-frame rewards and
            ``info`` is the one returned by the last emulator step.
            ``is_terminal`` is true when the game ended, when a life was
            lost (if ``terminal_on_life_loss``) or when
            ``max_episode_steps`` frames have been played.
        """
        total_reward = 0.
        observation = []
        for t in range(self.frame_skip):
            self.count += 1
            _, reward, terminal, info = self.env.step(action)
            total_reward += reward

            if self.terminal_on_life_loss:
                lives = self.env.ale.lives()
                is_terminal = terminal or lives < self.lives
                self.lives = lives
            else:
                is_terminal = terminal

            if is_terminal:
                break
            elif t >= self.frame_skip - 2:
                # Capture the last two screens of the skip window.
                t_ = t - (self.frame_skip - 2)
                self._grayscale_obs(self.screen_buffer[t_])

            observation.append(self._pool_and_resize())

        if observation:
            # The episode ended part-way: repeat the last frame as padding.
            while len(observation) < self.frame_skip:
                observation.append(observation[-1])
        else:
            # Terminal on the very first frame: reuse the buffered screen.
            observation = [self._pool_and_resize()
                           for _ in range(self.frame_skip)]
        observation = np.stack(observation, axis=-1)

        if self.count >= self.max_episode_steps:
            # The episode-length cap must be reflected in the returned flag,
            # not only in the local ``terminal`` variable.
            terminal = True
            is_terminal = True
        self.terminal = terminal
        return observation, total_reward, is_terminal, info

    def _grayscale_obs(self, output):
        """Have the ALE write its grayscale screen into ``output``."""
        self.env.ale.getScreenGrayscale(output)
        return output

    def _pool_and_resize(self):
        """Max-pool the two buffered screens and resize to (size, size)."""
        if self.frame_skip > 1:
            # The pooled result is written back into screen_buffer[0].
            np.maximum(self.screen_buffer[0], self.screen_buffer[1],
                       out=self.screen_buffer[0])

        transformed_image = cv2.resize(self.screen_buffer[0],
                                       (self.size, self.size),
                                       interpolation=cv2.INTER_AREA)
        int_image = np.asarray(transformed_image, dtype=np.uint8)
        return int_image


if __name__ == '__main__':
    # NOTE(review): no game name is passed, so ``name`` keeps its default of
    # None and the environment id is formatted from the string 'None' --
    # confirm whether a concrete game name should be supplied here.
    create_atari_environment()
26 changes: 26 additions & 0 deletions tianshou/env/mujoco/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
from gym.envs.registration import register
import gym

# Register two variants of the point-maze environment with gym.  They share
# the same entry point and settings and differ only in their id and in
# ``maze_size_scaling``.
register(
    id='PointMaze-v0',
    entry_point='tianshou.env.mujoco.point_maze_env:PointMazeEnv',
    kwargs=dict(
        maze_size_scaling=4,
        maze_id="Maze2",
        maze_height=0.5,
        manual_collision=True,
        goal=(1, 3),
    ),
)

register(
    id='PointMaze-v1',
    entry_point='tianshou.env.mujoco.point_maze_env:PointMazeEnv',
    kwargs=dict(
        maze_size_scaling=2,
        maze_id="Maze2",
        maze_height=0.5,
        manual_collision=True,
        goal=(1, 3),
    ),
)
34 changes: 34 additions & 0 deletions tianshou/env/mujoco/assets/point.xml
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
<mujoco>
<!-- Point robot model: one "torso" body on a floor plane, moved by two slide joints and one hinge joint. -->
<compiler inertiafromgeom="true" angle="degree" coordinate="local"/>
<option timestep="0.02" integrator="RK4"/>
<!-- Defaults applied to every joint and geom unless overridden on the element. -->
<default>
<joint limited="false" armature="0" damping="0"/>
<geom condim="3" conaffinity="0" margin="0" friction="1 0.5 0.5" rgba="0.8 0.6 0.4 1" density="100"/>
</default>
<!-- Textures and materials: a gradient skybox, the robot material "geom" and the checkered floor material "MatPlane". -->
<asset>
<texture type="skybox" builtin="gradient" width="100" height="100" rgb1="1 1 1" rgb2="0 0 0"/>
<texture name="texgeom" type="cube" builtin="flat" mark="cross" width="127" height="1278" rgb1="0.8 0.6 0.4"
rgb2="0.8 0.6 0.4" markrgb="1 1 1" random="0.01"/>
<texture name="texplane" type="2d" builtin="checker" rgb1="0 0 0" rgb2="0.8 0.8 0.8" width="100" height="100"/>
<material name='MatPlane' texture="texplane" shininess="1" texrepeat="30 30" specular="1" reflectance="0.5"/>
<material name='geom' texture="texgeom" texuniform="true"/>
</asset>
<worldbody>
<light directional="true" cutoff="100" exponent="1" diffuse="1 1 1" specular=".1 .1 .1" pos="0 0 1.3"
dir="-0 0 -1.3"/>
<!-- Floor plane; overrides the default conaffinity of 0 with 1. -->
<geom name='floor' material="MatPlane" pos='0 0 0' size='40 40 40' type='plane' conaffinity='1'
rgba='0.8 0.9 0.8 1' condim='3'/>
<body name="torso" pos="0 0 0">
<!-- The robot is drawn as a sphere with a box "arrow" offset along +x. -->
<geom name="pointbody" type="sphere" size="0.5" pos="0 0 0.5"/>
<geom name="pointarrow" type="box" size="0.5 0.1 0.1" pos="0.6 0 0.5"/>
<!-- Degrees of freedom: slide along x, slide along y, hinge about z. -->
<joint name='ballx' type='slide' axis='1 0 0' pos='0 0 0'/>
<joint name='bally' type='slide' axis='0 1 0' pos='0 0 0'/>
<joint name='rot' type='hinge' axis='0 0 1' pos='0 0 0' limited="false"/>
</body>
</worldbody>
<!-- Only the 'ballx' and 'rot' joints have motors; 'bally' has none. -->
<actuator>
<!-- Those are just dummy actuators for providing ranges -->
<motor joint='ballx' ctrlrange="-1 1" ctrllimited="true"/>
<motor joint='rot' ctrlrange="-0.25 0.25" ctrllimited="true"/>
</actuator>
</mujoco>
Loading