这是indexloc提供的服务,不要输入任何密码
Skip to content

Re-add the import of PointMaze, which was wrongly removed #8

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 34 commits into from
Mar 28, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
34 commits
Select commit Hold shift + click to select a range
e458845
update atari.py
Mehooz Mar 25, 2020
920b398
fix setup.py
Mehooz Mar 26, 2020
3c6b130
fix setup.py
Mehooz Mar 26, 2020
0a9c3bc
add args "render"
Mehooz Mar 26, 2020
e18979d
add args render
Mehooz Mar 26, 2020
e451ec0
Merge branch 'n+e'
Mehooz Mar 26, 2020
d83c0e7
change the tensorboard writer
Mehooz Mar 26, 2020
dfb7ead
change the tensorboard writer
Mehooz Mar 26, 2020
449b687
change device, render, tensorboard log location
Mehooz Mar 27, 2020
1a36ee3
Merge remote-tracking branch 'origin/master'
Mehooz Mar 27, 2020
41266c1
change device, render, tensorboard log location
Mehooz Mar 27, 2020
d8c4152
remove some wrong local files
Mehooz Mar 27, 2020
38e80f1
fix some tab mistakes and the envs name in continuous/test_xx.py
Mehooz Mar 27, 2020
ecc2e4f
add examples and point robot maze environment
Mehooz Mar 27, 2020
0acba5f
fix some bugs during testing examples
Mehooz Mar 27, 2020
95f0874
add dqn network and fix some args
Mehooz Mar 27, 2020
0be25e5
change back the tensorboard writer's frequency to ensure ppo and a2c…
Mehooz Mar 27, 2020
2255131
add a warning to collector
Mehooz Mar 27, 2020
b30b818
rm some unrelated files
Mehooz Mar 27, 2020
0023427
reformat
Mehooz Mar 27, 2020
c28b19d
fix a bug in test_dqn due to the model wrong selection
Mehooz Mar 27, 2020
2f2d267
change atari frame skip and observation to improve performance
Mehooz Mar 28, 2020
87f429d
readd some files
Mehooz Mar 28, 2020
b3be8ff
change import
Mehooz Mar 28, 2020
0b90d18
Merge remote-tracking branch 'remotes/origin/master'
Mehooz Mar 28, 2020
3d0983c
modified readme
Mehooz Mar 28, 2020
c0ff794
rm tensorboard log
Mehooz Mar 28, 2020
c026a1e
update atari and mujoco which are ignored
Mehooz Mar 28, 2020
f3715d4
rm the wrong lines
Mehooz Mar 28, 2020
34ae9f0
Merge remote-tracking branch 'origin/master'
Mehooz Mar 28, 2020
a97d821
readd the import of PointMaze
Mehooz Mar 28, 2020
cb478d9
fix a typo in test/discrete/net.py
Mehooz Mar 28, 2020
a21d800
add a class declaration to pass the flake8
Mehooz Mar 28, 2020
8d54e4a
fix flake8 errors
Mehooz Mar 28, 2020
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions examples/point_maze_td3.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,12 @@
from tianshou.trainer import offpolicy_trainer
from tianshou.data import Collector, ReplayBuffer
from tianshou.env import VectorEnv, SubprocVectorEnv

from continuous_net import Actor, Critic


def get_args():
parser = argparse.ArgumentParser()
parser.add_argument('--task', type=str, default='PointMaze-v0')
parser.add_argument('--task', type=str, default='PointMaze-v1')
parser.add_argument('--seed', type=int, default=1626)
parser.add_argument('--buffer-size', type=int, default=20000)
parser.add_argument('--actor-lr', type=float, default=3e-5)
Expand Down
2 changes: 2 additions & 0 deletions tianshou/env/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@
from tianshou.env.common import EnvWrapper, FrameStack
from tianshou.env.vecenv import BaseVectorEnv, VectorEnv, \
SubprocVectorEnv, RayVectorEnv
from tianshou.env import mujoco

__all__ = [
'mujoco',
'EnvWrapper',
'FrameStack',
'BaseVectorEnv',
Expand Down
37 changes: 19 additions & 18 deletions tianshou/env/mujoco/point_maze_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,7 @@ def get_top_down_view(self):

def valid(row, col):
return self._view.shape[0] > row >= 0 \
and self._view.shape[1] > col >= 0
and self._view.shape[1] > col >= 0

def update_view(x, y, d, row=None, col=None):
if row is None or col is None:
Expand All @@ -252,36 +252,36 @@ def update_view(x, y, d, row=None, col=None):

if valid(row, col):
self._view[row, col, d] += (
(min(1., row_frac + 0.5) - max(0., row_frac - 0.5)) *
(min(1., col_frac + 0.5) - max(0., col_frac - 0.5)))
(min(1., row_frac + 0.5) - max(0., row_frac - 0.5)) *
(min(1., col_frac + 0.5) - max(0., col_frac - 0.5)))
if valid(row - 1, col):
self._view[row - 1, col, d] += (
(max(0., 0.5 - row_frac)) *
(min(1., col_frac + 0.5) - max(0., col_frac - 0.5)))
(max(0., 0.5 - row_frac)) *
(min(1., col_frac + 0.5) - max(0., col_frac - 0.5)))
if valid(row + 1, col):
self._view[row + 1, col, d] += (
(max(0., row_frac - 0.5)) *
(min(1., col_frac + 0.5) - max(0., col_frac - 0.5)))
(max(0., row_frac - 0.5)) *
(min(1., col_frac + 0.5) - max(0., col_frac - 0.5)))
if valid(row, col - 1):
self._view[row, col - 1, d] += (
(min(1., row_frac + 0.5) - max(0., row_frac - 0.5)) *
(max(0., 0.5 - col_frac)))
(min(1., row_frac + 0.5) - max(0., row_frac - 0.5)) *
(max(0., 0.5 - col_frac)))
if valid(row, col + 1):
self._view[row, col + 1, d] += (
(min(1., row_frac + 0.5) - max(0., row_frac - 0.5)) *
(max(0., col_frac - 0.5)))
(min(1., row_frac + 0.5) - max(0., row_frac - 0.5)) *
(max(0., col_frac - 0.5)))
if valid(row - 1, col - 1):
self._view[row - 1, col - 1, d] += (
(max(0., 0.5 - row_frac)) * max(0., 0.5 - col_frac))
(max(0., 0.5 - row_frac)) * max(0., 0.5 - col_frac))
if valid(row - 1, col + 1):
self._view[row - 1, col + 1, d] += (
(max(0., 0.5 - row_frac)) * max(0., col_frac - 0.5))
(max(0., 0.5 - row_frac)) * max(0., col_frac - 0.5))
if valid(row + 1, col + 1):
self._view[row + 1, col + 1, d] += (
(max(0., row_frac - 0.5)) * max(0., col_frac - 0.5))
(max(0., row_frac - 0.5)) * max(0., col_frac - 0.5))
if valid(row + 1, col - 1):
self._view[row + 1, col - 1, d] += (
(max(0., row_frac - 0.5)) * max(0., 0.5 - col_frac))
(max(0., row_frac - 0.5)) * max(0., 0.5 - col_frac))

# Draw ant.
robot_x, robot_y = self.wrapped_env.get_body_com("torso")[:2]
Expand Down Expand Up @@ -376,7 +376,8 @@ def get_range_sensor_obs(self):
sensor_readings = np.zeros((self._n_bins, 3))
for ray_idx in range(self._n_bins):
ray_ori = (ori - self._sensor_span * 0.5 + (
2 * ray_idx + 1.0) / (2 * self._n_bins) * self._sensor_span)
2 * ray_idx + 1.0) /
(2 * self._n_bins) * self._sensor_span)
ray_segments = []
# Get all segments that intersect with ray.
for seg in segments:
Expand All @@ -401,8 +402,8 @@ def get_range_sensor_obs(self):
2 if maze_env_utils.can_move(seg_type) else # Block.
None)
if first_seg["distance"] <= self._sensor_range:
sensor_readings[ray_idx][idx] = (
self._sensor_range - first_seg[
sensor_readings[ray_idx][idx] = \
(self._sensor_range - first_seg[
"distance"]) / self._sensor_range
return sensor_readings

Expand Down