这是indexloc提供的服务,不要输入任何密码
Skip to content

update utils.network #275

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 16 commits into from
Jan 20, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,7 @@ from tianshou.utils.net.common import Net
env = gym.make(task)
state_shape = env.observation_space.shape or env.observation_space.n
action_shape = env.action_space.shape or env.action_space.n
net = Net(layer_num=2, state_shape=state_shape, action_shape=action_shape)
net = Net(state_shape=state_shape, action_shape=action_shape, hidden_sizes=[128, 128, 128])
optim = torch.optim.Adam(net.parameters(), lr=lr)
```

Expand Down
6 changes: 4 additions & 2 deletions docs/tutorials/tictactoe.rst
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,8 @@ The explanation of each Tianshou class/function will be deferred to their first
parser.add_argument('--step-per-epoch', type=int, default=1000)
parser.add_argument('--collect-per-step', type=int, default=10)
parser.add_argument('--batch-size', type=int, default=64)
parser.add_argument('--layer-num', type=int, default=3)
parser.add_argument('--hidden-sizes', type=int,
nargs='*', default=[128, 128, 128, 128])
parser.add_argument('--training-num', type=int, default=8)
parser.add_argument('--test-num', type=int, default=100)
parser.add_argument('--logdir', type=str, default='log')
Expand Down Expand Up @@ -249,7 +250,8 @@ Here it is:
args.action_shape = env.action_space.shape or env.action_space.n

if agent_learn is None:
net = Net(args.layer_num, args.state_shape, args.action_shape, args.device).to(args.device)
net = Net(args.state_shape, args.action_shape,
hidden_sizes=args.hidden_sizes, device=args.device).to(args.device)
if optim is None:
optim = torch.optim.Adam(net.parameters(), lr=args.lr)
agent_learn = DQNPolicy(
Expand Down
6 changes: 3 additions & 3 deletions examples/atari/atari_c51.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,10 @@

from tianshou.policy import C51Policy
from tianshou.env import SubprocVectorEnv
from tianshou.utils.net.discrete import C51
from tianshou.trainer import offpolicy_trainer
from tianshou.data import Collector, ReplayBuffer

from atari_network import C51
from atari_wrapper import wrap_deepmind


Expand Down Expand Up @@ -40,8 +40,8 @@ def get_args():
parser.add_argument(
'--device', type=str,
default='cuda' if torch.cuda.is_available() else 'cpu')
parser.add_argument('--frames_stack', type=int, default=4)
parser.add_argument('--resume_path', type=str, default=None)
parser.add_argument('--frames-stack', type=int, default=4)
parser.add_argument('--resume-path', type=str, default=None)
parser.add_argument('--watch', default=False, action='store_true',
help='watch the play of pre-trained policy only')
return parser.parse_args()
Expand Down
6 changes: 3 additions & 3 deletions examples/atari/atari_dqn.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,10 @@

from tianshou.policy import DQNPolicy
from tianshou.env import SubprocVectorEnv
from tianshou.utils.net.discrete import DQN
from tianshou.trainer import offpolicy_trainer
from tianshou.data import Collector, ReplayBuffer

from atari_network import DQN
from atari_wrapper import wrap_deepmind


Expand All @@ -37,8 +37,8 @@ def get_args():
parser.add_argument(
'--device', type=str,
default='cuda' if torch.cuda.is_available() else 'cpu')
parser.add_argument('--frames_stack', type=int, default=4)
parser.add_argument('--resume_path', type=str, default=None)
parser.add_argument('--frames-stack', type=int, default=4)
parser.add_argument('--resume-path', type=str, default=None)
parser.add_argument('--watch', default=False, action='store_true',
help='watch the play of pre-trained policy only')
return parser.parse_args()
Expand Down
82 changes: 82 additions & 0 deletions examples/atari/atari_network.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
import torch
import numpy as np
from torch import nn
from typing import Any, Dict, Tuple, Union, Optional, Sequence


class DQN(nn.Module):
    """Reference: Human-level control through deep reinforcement learning.

    Convolutional Q-network: three conv layers followed (unless
    ``features_only``) by a 512-unit hidden layer and a linear head with
    ``prod(action_shape)`` outputs.

    For advanced usage (how to customize the network), please refer to
    :ref:`build_the_network`.

    :param int c: number of input channels.
    :param int h: input height.
    :param int w: input width.
    :param action_shape: shape (or size) of the action space; the output
        head has ``prod(action_shape)`` units.
    :param device: device that :meth:`forward` moves its input onto.
    :param bool features_only: if True, only the convolutional trunk is kept
        and ``output_dim`` is the flattened feature size instead of the
        number of actions.
    """

    def __init__(
        self,
        c: int,
        h: int,
        w: int,
        action_shape: Union[int, Sequence[int]],
        device: Union[str, int, torch.device] = "cpu",
        features_only: bool = False,
    ) -> None:
        super().__init__()
        self.device = device
        self.net = nn.Sequential(
            nn.Conv2d(c, 32, kernel_size=8, stride=4), nn.ReLU(inplace=True),
            nn.Conv2d(32, 64, kernel_size=4, stride=2), nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, kernel_size=3, stride=1), nn.ReLU(inplace=True),
            nn.Flatten())
        # Infer the flattened conv feature size with a dummy (1, c, h, w)
        # forward pass, so any input resolution works. Cast to a plain int so
        # ``output_dim`` is not a numpy scalar.
        with torch.no_grad():
            self.output_dim = int(np.prod(
                self.net(torch.zeros(1, c, h, w)).shape[1:]))
        if not features_only:
            action_dim = int(np.prod(action_shape))
            self.net = nn.Sequential(
                self.net,
                nn.Linear(self.output_dim, 512), nn.ReLU(inplace=True),
                nn.Linear(512, action_dim))
            self.output_dim = action_dim

    def forward(
        self,
        x: Union[np.ndarray, torch.Tensor],
        state: Optional[Any] = None,
        info: Optional[Dict[str, Any]] = None,
    ) -> Tuple[torch.Tensor, Any]:
        r"""Mapping: x -> Q(x, \*).

        :param x: observation batch, expected laid out as (batch, c, h, w)
            like the dummy input used in ``__init__``; converted to a float32
            tensor on ``self.device``.
        :param state: returned unchanged (this network keeps no state).
        :param info: unused; accepted for interface compatibility.
        :return: tuple of the network output and ``state``.
        """
        x = torch.as_tensor(
            x, device=self.device, dtype=torch.float32)  # type: ignore
        return self.net(x), state


class C51(DQN):
    """Reference: A distributional perspective on reinforcement learning.

    Reuses the :class:`DQN` network with an output head of
    ``prod(action_shape) * num_atoms`` units, and turns it into one
    categorical distribution over ``num_atoms`` atoms per action.

    For advanced usage (how to customize the network), please refer to
    :ref:`build_the_network`.

    :param int c: number of input channels.
    :param int h: input height.
    :param int w: input width.
    :param action_shape: shape (or size) of the action space.
    :param int num_atoms: number of atoms of each action's distribution.
    :param device: device that :meth:`forward` moves its input onto.
    """

    def __init__(
        self,
        c: int,
        h: int,
        w: int,
        action_shape: Union[int, Sequence[int]],
        num_atoms: int = 51,
        device: Union[str, int, torch.device] = "cpu",
    ) -> None:
        action_num = int(np.prod(action_shape))
        super().__init__(c, h, w, [action_num * num_atoms], device=device)
        self.action_shape = action_shape
        self.action_num = action_num
        self.num_atoms = num_atoms

    def forward(
        self,
        x: Union[np.ndarray, torch.Tensor],
        state: Optional[Any] = None,
        info: Optional[Dict[str, Any]] = None,
    ) -> Tuple[torch.Tensor, Any]:
        r"""Mapping: x -> Z(x, \*).

        :param x: observation batch, handled as in :meth:`DQN.forward`.
        :param state: forwarded to :meth:`DQN.forward` and returned from it.
        :param info: forwarded to :meth:`DQN.forward`.
        :return: tuple of a (batch, prod(action_shape), num_atoms) tensor,
            normalized over the last axis, and the returned state.
        """
        # Forward state/info instead of dropping them, so a caller-supplied
        # state is returned rather than replaced by None.
        logits, state = super().forward(x, state=state, info=info)
        # Regroup the flat head into (batch, actions, atoms) and normalize
        # over atoms: one categorical distribution per action.
        probs = logits.view(-1, self.action_num, self.num_atoms).softmax(dim=-1)
        return probs, state
14 changes: 8 additions & 6 deletions examples/atari/runnable/pong_a2c.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,10 @@

from tianshou.policy import A2CPolicy
from tianshou.env import SubprocVectorEnv
from tianshou.utils.net.common import Net
from tianshou.trainer import onpolicy_trainer
from tianshou.data import Collector, ReplayBuffer
from tianshou.utils.net.discrete import Actor, Critic
from tianshou.utils.net.common import Net

from atari import create_atari_environment, preprocess_fn

Expand All @@ -27,7 +27,8 @@ def get_args():
parser.add_argument('--collect-per-step', type=int, default=10)
parser.add_argument('--repeat-per-collect', type=int, default=1)
parser.add_argument('--batch-size', type=int, default=64)
parser.add_argument('--layer-num', type=int, default=2)
parser.add_argument('--hidden-sizes', type=int,
nargs='*', default=[128, 128, 128])
parser.add_argument('--training-num', type=int, default=8)
parser.add_argument('--test-num', type=int, default=8)
parser.add_argument('--logdir', type=str, default='log')
Expand All @@ -40,7 +41,7 @@ def get_args():
parser.add_argument('--vf-coef', type=float, default=0.5)
parser.add_argument('--ent-coef', type=float, default=0.001)
parser.add_argument('--max-grad-norm', type=float, default=None)
parser.add_argument('--max_episode_steps', type=int, default=2000)
parser.add_argument('--max-episode-steps', type=int, default=2000)
return parser.parse_args()


Expand All @@ -62,11 +63,12 @@ def test_a2c(args=get_args()):
train_envs.seed(args.seed)
test_envs.seed(args.seed)
# model
net = Net(args.layer_num, args.state_shape, device=args.device)
net = Net(args.state_shape, hidden_sizes=args.hidden_sizes,
device=args.device)
actor = Actor(net, args.action_shape).to(args.device)
critic = Critic(net).to(args.device)
optim = torch.optim.Adam(list(
actor.parameters()) + list(critic.parameters()), lr=args.lr)
optim = torch.optim.Adam(set(
actor.parameters()).union(critic.parameters()), lr=args.lr)
dist = torch.distributions.Categorical
policy = A2CPolicy(
actor, critic, optim, dist, args.gamma, vf_coef=args.vf_coef,
Expand Down
14 changes: 8 additions & 6 deletions examples/atari/runnable/pong_ppo.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,10 @@

from tianshou.policy import PPOPolicy
from tianshou.env import SubprocVectorEnv
from tianshou.utils.net.common import Net
from tianshou.trainer import onpolicy_trainer
from tianshou.data import Collector, ReplayBuffer
from tianshou.utils.net.discrete import Actor, Critic
from tianshou.utils.net.common import Net

from atari import create_atari_environment, preprocess_fn

Expand All @@ -27,7 +27,8 @@ def get_args():
parser.add_argument('--collect-per-step', type=int, default=10)
parser.add_argument('--repeat-per-collect', type=int, default=2)
parser.add_argument('--batch-size', type=int, default=64)
parser.add_argument('--layer-num', type=int, default=1)
parser.add_argument('--hidden-sizes', type=int,
nargs='*', default=[128, 128])
parser.add_argument('--training-num', type=int, default=8)
parser.add_argument('--test-num', type=int, default=8)
parser.add_argument('--logdir', type=str, default='log')
Expand All @@ -40,7 +41,7 @@ def get_args():
parser.add_argument('--ent-coef', type=float, default=0.0)
parser.add_argument('--eps-clip', type=float, default=0.2)
parser.add_argument('--max-grad-norm', type=float, default=0.5)
parser.add_argument('--max_episode_steps', type=int, default=2000)
parser.add_argument('--max-episode-steps', type=int, default=2000)
return parser.parse_args()


Expand All @@ -62,11 +63,12 @@ def test_ppo(args=get_args()):
train_envs.seed(args.seed)
test_envs.seed(args.seed)
# model
net = Net(args.layer_num, args.state_shape, device=args.device)
net = Net(args.state_shape, hidden_sizes=args.hidden_sizes,
device=args.device)
actor = Actor(net, args.action_shape).to(args.device)
critic = Critic(net).to(args.device)
optim = torch.optim.Adam(list(
actor.parameters()) + list(critic.parameters()), lr=args.lr)
optim = torch.optim.Adam(set(
actor.parameters()).union(critic.parameters()), lr=args.lr)
dist = torch.distributions.Categorical
policy = PPOPolicy(
actor, critic, optim, dist, args.gamma,
Expand Down
14 changes: 11 additions & 3 deletions examples/box2d/acrobot_dualdqn.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,12 @@ def get_args():
parser.add_argument('--step-per-epoch', type=int, default=1000)
parser.add_argument('--collect-per-step', type=int, default=100)
parser.add_argument('--batch-size', type=int, default=64)
parser.add_argument('--layer-num', type=int, default=0)
parser.add_argument('--hidden-sizes', type=int,
nargs='*', default=[128])
parser.add_argument('--dueling-q-hidden-sizes', type=int,
nargs='*', default=[128, 128])
parser.add_argument('--dueling-v-hidden-sizes', type=int,
nargs='*', default=[128, 128])
parser.add_argument('--training-num', type=int, default=8)
parser.add_argument('--test-num', type=int, default=100)
parser.add_argument('--logdir', type=str, default='log')
Expand Down Expand Up @@ -56,8 +61,11 @@ def test_dqn(args=get_args()):
train_envs.seed(args.seed)
test_envs.seed(args.seed)
# model
net = Net(args.layer_num, args.state_shape,
args.action_shape, args.device, dueling=(2, 2)).to(args.device)
Q_param = {"hidden_sizes": args.dueling_q_hidden_sizes}
V_param = {"hidden_sizes": args.dueling_v_hidden_sizes}
net = Net(args.state_shape, args.action_shape,
hidden_sizes=args.hidden_sizes, device=args.device,
dueling_param=(Q_param, V_param)).to(args.device)
optim = torch.optim.Adam(net.parameters(), lr=args.lr)
policy = DQNPolicy(
net, optim, args.gamma, args.n_step,
Expand Down
24 changes: 14 additions & 10 deletions examples/box2d/bipedal_hardcore_sac.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,8 @@ def get_args():
parser.add_argument('--step-per-epoch', type=int, default=10000)
parser.add_argument('--collect-per-step', type=int, default=10)
parser.add_argument('--batch-size', type=int, default=128)
parser.add_argument('--layer-num', type=int, default=1)
parser.add_argument('--hidden-sizes', type=int,
nargs='*', default=[128, 128])
parser.add_argument('--training-num', type=int, default=8)
parser.add_argument('--test-num', type=int, default=100)
parser.add_argument('--logdir', type=str, default='log')
Expand Down Expand Up @@ -87,20 +88,23 @@ def test_sac_bipedal(args=get_args()):
test_envs.seed(args.seed)

# model
net_a = Net(args.layer_num, args.state_shape, device=args.device)
net_a = Net(args.state_shape, hidden_sizes=args.hidden_sizes,
device=args.device)
actor = ActorProb(
net_a, args.action_shape, args.max_action, args.device, unbounded=True
).to(args.device)
net_a, args.action_shape, max_action=args.max_action,
device=args.device, unbounded=True).to(args.device)
actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr)

net_c1 = Net(args.layer_num, args.state_shape,
args.action_shape, concat=True, device=args.device)
critic1 = Critic(net_c1, args.device).to(args.device)
net_c1 = Net(args.state_shape, args.action_shape,
hidden_sizes=args.hidden_sizes,
concat=True, device=args.device)
critic1 = Critic(net_c1, device=args.device).to(args.device)
critic1_optim = torch.optim.Adam(critic1.parameters(), lr=args.critic_lr)

net_c2 = Net(args.layer_num, args.state_shape,
args.action_shape, concat=True, device=args.device)
critic2 = Critic(net_c2, args.device).to(args.device)
net_c2 = Net(args.state_shape, args.action_shape,
hidden_sizes=args.hidden_sizes,
concat=True, device=args.device)
critic2 = Critic(net_c2, device=args.device).to(args.device)
critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr)

if args.auto_alpha:
Expand Down
15 changes: 11 additions & 4 deletions examples/box2d/lunarlander_dqn.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,12 @@ def get_args():
parser.add_argument('--step-per-epoch', type=int, default=5000)
parser.add_argument('--collect-per-step', type=int, default=16)
parser.add_argument('--batch-size', type=int, default=128)
parser.add_argument('--layer-num', type=int, default=1)
parser.add_argument('--hidden-sizes', type=int,
nargs='*', default=[128, 128])
parser.add_argument('--dueling-q-hidden-sizes', type=int,
nargs='*', default=[128, 128])
parser.add_argument('--dueling-v-hidden-sizes', type=int,
nargs='*', default=[128, 128])
parser.add_argument('--training-num', type=int, default=10)
parser.add_argument('--test-num', type=int, default=100)
parser.add_argument('--logdir', type=str, default='log')
Expand Down Expand Up @@ -57,9 +62,11 @@ def test_dqn(args=get_args()):
train_envs.seed(args.seed)
test_envs.seed(args.seed)
# model
net = Net(args.layer_num, args.state_shape,
args.action_shape, args.device,
dueling=(2, 2)).to(args.device)
Q_param = {"hidden_sizes": args.dueling_q_hidden_sizes}
V_param = {"hidden_sizes": args.dueling_v_hidden_sizes}
net = Net(args.state_shape, args.action_shape,
hidden_sizes=args.hidden_sizes, device=args.device,
dueling_param=(Q_param, V_param)).to(args.device)
optim = torch.optim.Adam(net.parameters(), lr=args.lr)
policy = DQNPolicy(
net, optim, args.gamma, args.n_step,
Expand Down
22 changes: 13 additions & 9 deletions examples/box2d/mcc_sac.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,8 @@ def get_args():
parser.add_argument('--step-per-epoch', type=int, default=2400)
parser.add_argument('--collect-per-step', type=int, default=5)
parser.add_argument('--batch-size', type=int, default=128)
parser.add_argument('--layer-num', type=int, default=1)
parser.add_argument('--hidden-sizes', type=int,
nargs='*', default=[128, 128])
parser.add_argument('--training-num', type=int, default=16)
parser.add_argument('--test-num', type=int, default=100)
parser.add_argument('--logdir', type=str, default='log')
Expand Down Expand Up @@ -61,19 +62,22 @@ def test_sac(args=get_args()):
train_envs.seed(args.seed)
test_envs.seed(args.seed)
# model
net = Net(args.layer_num, args.state_shape, device=args.device)
net = Net(args.state_shape, hidden_sizes=args.hidden_sizes,
device=args.device)
actor = ActorProb(
net, args.action_shape,
args.max_action, args.device, unbounded=True
max_action=args.max_action, device=args.device, unbounded=True
).to(args.device)
actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr)
net_c1 = Net(args.layer_num, args.state_shape,
args.action_shape, concat=True, device=args.device)
critic1 = Critic(net_c1, args.device).to(args.device)
net_c1 = Net(args.state_shape, args.action_shape,
hidden_sizes=args.hidden_sizes, concat=True,
device=args.device)
critic1 = Critic(net_c1, device=args.device).to(args.device)
critic1_optim = torch.optim.Adam(critic1.parameters(), lr=args.critic_lr)
net_c2 = Net(args.layer_num, args.state_shape,
args.action_shape, concat=True, device=args.device)
critic2 = Critic(net_c2, args.device).to(args.device)
net_c2 = Net(args.state_shape, args.action_shape,
hidden_sizes=args.hidden_sizes, concat=True,
device=args.device)
critic2 = Critic(net_c2, device=args.device).to(args.device)
critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr)

if args.auto_alpha:
Expand Down
Loading