examples/box2d/bipedal_hardcore_sac.py (4 changes: 2 additions & 2 deletions)
@@ -119,8 +119,8 @@ def IsStop(reward):
 
     policy = SACPolicy(
         actor, actor_optim, critic1, critic1_optim, critic2, critic2_optim,
-        args.tau, args.gamma, args.alpha,
-        [env.action_space.low[0], env.action_space.high[0]],
+        action_range=[env.action_space.low[0], env.action_space.high[0]],
+        tau=args.tau, gamma=args.gamma, alpha=args.alpha,
         reward_normalization=args.rew_norm,
         ignore_done=args.ignore_done,
         estimation_step=args.n_step)
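For reference, a minimal sketch of the keyword-argument call style this change moves to, mirroring the hunk above. The actor/critic networks and their optimizers are placeholders for whatever the surrounding script builds, and the literal hyper-parameter values are illustrative only:

# Sketch only: assumes actor, critic1, critic2 and their optimizers were built
# earlier in the script, as in the example above.
policy = SACPolicy(
    actor, actor_optim, critic1, critic1_optim, critic2, critic2_optim,
    action_range=[env.action_space.low[0], env.action_space.high[0]],
    tau=0.005, gamma=0.99, alpha=0.2,  # placeholder values, not tuned settings
    reward_normalization=False,
    ignore_done=False,
    estimation_step=1)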
examples/box2d/mcc_sac.py (8 changes: 3 additions & 5 deletions)
@@ -78,14 +78,12 @@ def test_sac(args=get_args()):
         target_entropy = -np.prod(env.action_space.shape)
         log_alpha = torch.zeros(1, requires_grad=True, device=args.device)
         alpha_optim = torch.optim.Adam([log_alpha], lr=args.alpha_lr)
-        alpha = (target_entropy, log_alpha, alpha_optim)
-    else:
-        alpha = args.alpha
+        args.alpha = (target_entropy, log_alpha, alpha_optim)
 
     policy = SACPolicy(
         actor, actor_optim, critic1, critic1_optim, critic2, critic2_optim,
-        args.tau, args.gamma, alpha,
-        [env.action_space.low[0], env.action_space.high[0]],
+        action_range=[env.action_space.low[0], env.action_space.high[0]],
+        tau=args.tau, gamma=args.gamma, alpha=args.alpha,
         reward_normalization=args.rew_norm, ignore_done=True,
         exploration_noise=OUNoise(0.0, args.noise_std))
     # collector
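The mcc_sac.py hunk also shows how automatic entropy tuning is wired in after this change: instead of a fixed float, alpha becomes a (target_entropy, log_alpha, alpha_optim) tuple stored back on args. A minimal sketch of that setup, using the same names as the diff with placeholder device and learning-rate values:

import numpy as np
import torch

# Auto-alpha setup as in the hunk above; "cpu" and 3e-4 are placeholders.
target_entropy = -np.prod(env.action_space.shape)
log_alpha = torch.zeros(1, requires_grad=True, device="cpu")
alpha_optim = torch.optim.Adam([log_alpha], lr=3e-4)
alpha = (target_entropy, log_alpha, alpha_optim)  # then SACPolicy(..., alpha=alpha)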
examples/mujoco/ant_v2_ddpg.py (5 changes: 3 additions & 2 deletions)
@@ -66,8 +66,9 @@ def test_ddpg(args=get_args()):
     critic_optim = torch.optim.Adam(critic.parameters(), lr=args.critic_lr)
     policy = DDPGPolicy(
         actor, actor_optim, critic, critic_optim,
-        args.tau, args.gamma, GaussianNoise(sigma=args.exploration_noise),
-        [env.action_space.low[0], env.action_space.high[0]],
+        action_range=[env.action_space.low[0], env.action_space.high[0]],
+        tau=args.tau, gamma=args.gamma,
+        exploration_noise=GaussianNoise(sigma=args.exploration_noise),
         reward_normalization=True, ignore_done=True)
     # collector
     train_collector = Collector(
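The DDPG examples follow the same pattern: the positional tau, gamma, noise and action-range arguments become keywords. A sketch under the same assumptions as above (networks and optimizers already built, literal values illustrative):

# Sketch only; sigma and the other literals are placeholders.
policy = DDPGPolicy(
    actor, actor_optim, critic, critic_optim,
    action_range=[env.action_space.low[0], env.action_space.high[0]],
    tau=0.005, gamma=0.99,
    exploration_noise=GaussianNoise(sigma=0.1),
    reward_normalization=True, ignore_done=True)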
examples/mujoco/ant_v2_sac.py (4 changes: 2 additions & 2 deletions)
@@ -71,8 +71,8 @@ def test_sac(args=get_args()):
     critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr)
     policy = SACPolicy(
         actor, actor_optim, critic1, critic1_optim, critic2, critic2_optim,
-        args.tau, args.gamma, args.alpha,
-        [env.action_space.low[0], env.action_space.high[0]],
+        action_range=[env.action_space.low[0], env.action_space.high[0]],
+        tau=args.tau, gamma=args.gamma, alpha=args.alpha,
         reward_normalization=args.rew_norm, ignore_done=True)
     # collector
     train_collector = Collector(
examples/mujoco/ant_v2_td3.py (10 changes: 6 additions & 4 deletions)
@@ -73,10 +73,12 @@ def test_td3(args=get_args()):
     critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr)
     policy = TD3Policy(
         actor, actor_optim, critic1, critic1_optim, critic2, critic2_optim,
-        args.tau, args.gamma,
-        GaussianNoise(sigma=args.exploration_noise), args.policy_noise,
-        args.update_actor_freq, args.noise_clip,
-        [env.action_space.low[0], env.action_space.high[0]],
+        action_range=[env.action_space.low[0], env.action_space.high[0]],
+        tau=args.tau, gamma=args.gamma,
+        exploration_noise=GaussianNoise(sigma=args.exploration_noise),
+        policy_noise=args.policy_noise,
+        update_actor_freq=args.update_actor_freq,
+        noise_clip=args.noise_clip,
         reward_normalization=True, ignore_done=True)
     # collector
     train_collector = Collector(
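TD3 benefits most from the new style, since it previously took six positional hyper-parameters in a row. A sketch of the keyword form with placeholder values:

# Sketch only; the hyper-parameter literals are placeholders, not recommendations.
policy = TD3Policy(
    actor, actor_optim, critic1, critic1_optim, critic2, critic2_optim,
    action_range=[env.action_space.low[0], env.action_space.high[0]],
    tau=0.005, gamma=0.99,
    exploration_noise=GaussianNoise(sigma=0.1),
    policy_noise=0.2,
    update_actor_freq=2,
    noise_clip=0.5,
    reward_normalization=True, ignore_done=True)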
examples/mujoco/halfcheetahBullet_v0_sac.py (4 changes: 2 additions & 2 deletions)
@@ -79,8 +79,8 @@ def test_sac(args=get_args()):
     critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr)
     policy = SACPolicy(
         actor, actor_optim, critic1, critic1_optim, critic2, critic2_optim,
-        args.tau, args.gamma, args.alpha,
-        [env.action_space.low[0], env.action_space.high[0]],
+        action_range=[env.action_space.low[0], env.action_space.high[0]],
+        tau=args.tau, gamma=args.gamma, alpha=args.alpha,
         reward_normalization=True, ignore_done=True)
     # collector
     train_collector = Collector(
examples/mujoco/point_maze_td3.py (10 changes: 6 additions & 4 deletions)
@@ -76,10 +76,12 @@ def test_td3(args=get_args()):
     critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr)
     policy = TD3Policy(
         actor, actor_optim, critic1, critic1_optim, critic2, critic2_optim,
-        args.tau, args.gamma,
-        GaussianNoise(sigma=args.exploration_noise), args.policy_noise,
-        args.update_actor_freq, args.noise_clip,
-        [env.action_space.low[0], env.action_space.high[0]],
+        action_range=[env.action_space.low[0], env.action_space.high[0]],
+        tau=args.tau, gamma=args.gamma,
+        exploration_noise=GaussianNoise(sigma=args.exploration_noise),
+        policy_noise=args.policy_noise,
+        update_actor_freq=args.update_actor_freq,
+        noise_clip=args.noise_clip,
         reward_normalization=True, ignore_done=True)
     # collector
     train_collector = Collector(
test/continuous/test_ddpg.py (5 changes: 3 additions & 2 deletions)
@@ -78,8 +78,9 @@ def test_ddpg(args=get_args()):
     critic_optim = torch.optim.Adam(critic.parameters(), lr=args.critic_lr)
     policy = DDPGPolicy(
         actor, actor_optim, critic, critic_optim,
-        args.tau, args.gamma, GaussianNoise(sigma=args.exploration_noise),
-        [env.action_space.low[0], env.action_space.high[0]],
+        action_range=[env.action_space.low[0], env.action_space.high[0]],
+        tau=args.tau, gamma=args.gamma,
+        exploration_noise=GaussianNoise(sigma=args.exploration_noise),
         reward_normalization=args.rew_norm,
         ignore_done=args.ignore_done,
         estimation_step=args.n_step)
test/continuous/test_sac_with_il.py (4 changes: 2 additions & 2 deletions)
@@ -79,8 +79,8 @@ def test_sac_with_il(args=get_args()):
     critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr)
     policy = SACPolicy(
         actor, actor_optim, critic1, critic1_optim, critic2, critic2_optim,
-        args.tau, args.gamma, args.alpha,
-        [env.action_space.low[0], env.action_space.high[0]],
+        action_range=[env.action_space.low[0], env.action_space.high[0]],
+        tau=args.tau, gamma=args.gamma, alpha=args.alpha,
         reward_normalization=args.rew_norm,
         ignore_done=args.ignore_done,
         estimation_step=args.n_step)
test/continuous/test_td3.py (9 changes: 6 additions & 3 deletions)
@@ -82,9 +82,12 @@ def test_td3(args=get_args()):
     critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr)
     policy = TD3Policy(
         actor, actor_optim, critic1, critic1_optim, critic2, critic2_optim,
-        args.tau, args.gamma, GaussianNoise(sigma=args.exploration_noise),
-        args.policy_noise, args.update_actor_freq, args.noise_clip,
-        [env.action_space.low[0], env.action_space.high[0]],
+        action_range=[env.action_space.low[0], env.action_space.high[0]],
+        tau=args.tau, gamma=args.gamma,
+        exploration_noise=GaussianNoise(sigma=args.exploration_noise),
+        policy_noise=args.policy_noise,
+        update_actor_freq=args.update_actor_freq,
+        noise_clip=args.noise_clip,
         reward_normalization=args.rew_norm,
         ignore_done=args.ignore_done,
         estimation_step=args.n_step)
tianshou/__init__.py (18 changes: 7 additions & 11 deletions)
@@ -1,17 +1,13 @@
 from tianshou import data, env, utils, policy, trainer, exploration
 
-# pre-compile some common-type function-call to produce the correct benchmark
-# result: https://github.com/thu-ml/tianshou/pull/193#discussion_r480536371
-utils.pre_compile()
-
 
-__version__ = '0.2.7'
+__version__ = "0.2.7"
 
 __all__ = [
-    'env',
-    'data',
-    'utils',
-    'policy',
-    'trainer',
-    'exploration',
+    "env",
+    "data",
+    "utils",
+    "policy",
+    "trainer",
+    "exploration",
 ]
tianshou/data/__init__.py (21 changes: 10 additions & 11 deletions)
@@ -1,19 +1,18 @@
 from tianshou.data.batch import Batch
-from tianshou.data.utils.converter import to_numpy, to_torch, \
-    to_torch_as
+from tianshou.data.utils.converter import to_numpy, to_torch, to_torch_as
 from tianshou.data.utils.segtree import SegmentTree
 from tianshou.data.buffer import ReplayBuffer, \
     ListReplayBuffer, PrioritizedReplayBuffer
 from tianshou.data.collector import Collector
 
 __all__ = [
-    'Batch',
-    'to_numpy',
-    'to_torch',
-    'to_torch_as',
-    'SegmentTree',
-    'ReplayBuffer',
-    'ListReplayBuffer',
-    'PrioritizedReplayBuffer',
-    'Collector',
+    "Batch",
+    "to_numpy",
+    "to_torch",
+    "to_torch_as",
+    "SegmentTree",
+    "ReplayBuffer",
+    "ListReplayBuffer",
+    "PrioritizedReplayBuffer",
+    "Collector",
 ]