From ba87cb490704803fc659f6b958d06210c1443172 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Thu, 9 Jul 2020 17:30:57 +0800 Subject: [PATCH 01/11] remove dummy net; delete two files --- examples/ant_v2_ddpg.py | 3 +- examples/ant_v2_sac.py | 3 +- examples/ant_v2_td3.py | 3 +- examples/continuous_net.py | 81 ------------------ examples/discrete_net.py | 83 ------------------- examples/halfcheetahBullet_v0_sac.py | 2 +- examples/point_maze_td3.py | 2 +- examples/pong_a2c.py | 6 +- examples/pong_dqn.py | 3 +- examples/pong_ppo.py | 7 +- examples/sac_mcc.py | 2 +- test/continuous/test_ddpg.py | 6 +- test/continuous/test_ppo.py | 6 +- test/continuous/test_sac_with_il.py | 6 +- test/continuous/test_td3.py | 6 +- test/discrete/test_a2c_with_il.py | 12 +-- test/discrete/test_dqn.py | 6 +- test/discrete/test_drqn.py | 6 +- test/discrete/test_pdqn.py | 6 +- test/discrete/test_pg.py | 6 +- test/discrete/test_ppo.py | 10 +-- tianshou/utils/net/__init__.py | 0 .../utils/net/continuous.py | 5 +- .../net.py => tianshou/utils/net/discrete.py | 41 ++++++++- 24 files changed, 71 insertions(+), 240 deletions(-) delete mode 100644 examples/continuous_net.py delete mode 100644 examples/discrete_net.py create mode 100644 tianshou/utils/net/__init__.py rename test/continuous/net.py => tianshou/utils/net/continuous.py (96%) rename test/discrete/net.py => tianshou/utils/net/discrete.py (69%) diff --git a/examples/ant_v2_ddpg.py b/examples/ant_v2_ddpg.py index 16b029908..5ea3b778e 100644 --- a/examples/ant_v2_ddpg.py +++ b/examples/ant_v2_ddpg.py @@ -10,8 +10,7 @@ from tianshou.data import Collector, ReplayBuffer from tianshou.env import VectorEnv, SubprocVectorEnv from tianshou.exploration import GaussianNoise - -from continuous_net import Actor, Critic +from tianshou.utils.net.continuous import Actor, Critic def get_args(): diff --git a/examples/ant_v2_sac.py b/examples/ant_v2_sac.py index 1d28615ec..947ef5fbb 100644 --- a/examples/ant_v2_sac.py +++ b/examples/ant_v2_sac.py @@ -10,8 +10,7 @@ from tianshou.trainer import offpolicy_trainer from tianshou.data import Collector, ReplayBuffer from tianshou.env import VectorEnv, SubprocVectorEnv - -from continuous_net import ActorProb, Critic +from tianshou.utils.net.continuous import ActorProb, Critic def get_args(): diff --git a/examples/ant_v2_td3.py b/examples/ant_v2_td3.py index 45770b28d..c17c0f94f 100644 --- a/examples/ant_v2_td3.py +++ b/examples/ant_v2_td3.py @@ -10,8 +10,7 @@ from tianshou.data import Collector, ReplayBuffer from tianshou.env import VectorEnv, SubprocVectorEnv from tianshou.exploration import GaussianNoise - -from continuous_net import Actor, Critic +from tianshou.utils.net.continuous import Actor, Critic def get_args(): diff --git a/examples/continuous_net.py b/examples/continuous_net.py deleted file mode 100644 index c76ab1745..000000000 --- a/examples/continuous_net.py +++ /dev/null @@ -1,81 +0,0 @@ -import torch -import numpy as np -from torch import nn - - -class Actor(nn.Module): - def __init__(self, layer_num, state_shape, action_shape, - max_action, device='cpu'): - super().__init__() - self.device = device - self.model = [ - nn.Linear(np.prod(state_shape), 128), - nn.ReLU(inplace=True)] - for i in range(layer_num): - self.model += [nn.Linear(128, 128), nn.ReLU(inplace=True)] - self.model += [nn.Linear(128, np.prod(action_shape))] - self.model = nn.Sequential(*self.model) - self._max = max_action - - def forward(self, s, **kwargs): - s = torch.tensor(s, device=self.device, dtype=torch.float) - batch = s.shape[0] - s = s.view(batch, 
-1) - logits = self.model(s) - logits = self._max * torch.tanh(logits) - return logits, None - - -class ActorProb(nn.Module): - def __init__(self, layer_num, state_shape, action_shape, - max_action, device='cpu', unbounded=False): - super().__init__() - self.device = device - self.model = [ - nn.Linear(np.prod(state_shape), 128), - nn.ReLU(inplace=True)] - for i in range(layer_num): - self.model += [nn.Linear(128, 128), nn.ReLU(inplace=True)] - self.model = nn.Sequential(*self.model) - self.mu = nn.Linear(128, np.prod(action_shape)) - self.sigma = nn.Linear(128, np.prod(action_shape)) - self._max = max_action - self._unbounded = unbounded - - def forward(self, s, **kwargs): - if not isinstance(s, torch.Tensor): - s = torch.tensor(s, device=self.device, dtype=torch.float) - batch = s.shape[0] - s = s.view(batch, -1) - logits = self.model(s) - if not self._unbounded: - mu = self._max * torch.tanh(self.mu(logits)) - sigma = torch.exp(self.sigma(logits)) - return (mu, sigma), None - - -class Critic(nn.Module): - def __init__(self, layer_num, state_shape, action_shape=0, device='cpu'): - super().__init__() - self.device = device - self.model = [ - nn.Linear(np.prod(state_shape) + np.prod(action_shape), 128), - nn.ReLU(inplace=True)] - for i in range(layer_num): - self.model += [nn.Linear(128, 128), nn.ReLU(inplace=True)] - self.model += [nn.Linear(128, 1)] - self.model = nn.Sequential(*self.model) - - def forward(self, s, a=None): - if not isinstance(s, torch.Tensor): - s = torch.tensor(s, device=self.device, dtype=torch.float) - if a is not None and not isinstance(a, torch.Tensor): - a = torch.tensor(a, device=self.device, dtype=torch.float) - batch = s.shape[0] - s = s.view(batch, -1) - if a is None: - logits = self.model(s) - else: - a = a.view(batch, -1) - logits = self.model(torch.cat([s, a], dim=1)) - return logits diff --git a/examples/discrete_net.py b/examples/discrete_net.py deleted file mode 100644 index eda710866..000000000 --- a/examples/discrete_net.py +++ /dev/null @@ -1,83 +0,0 @@ -import torch -import numpy as np -from torch import nn -import torch.nn.functional as F - - -class Net(nn.Module): - def __init__(self, layer_num, state_shape, action_shape=0, device='cpu'): - super().__init__() - self.device = device - self.model = [ - nn.Linear(np.prod(state_shape), 128), - nn.ReLU(inplace=True)] - for i in range(layer_num): - self.model += [nn.Linear(128, 128), nn.ReLU(inplace=True)] - if action_shape: - self.model += [nn.Linear(128, np.prod(action_shape))] - self.model = nn.Sequential(*self.model) - - def forward(self, s, state=None, info={}): - if not isinstance(s, torch.Tensor): - s = torch.tensor(s, device=self.device, dtype=torch.float) - batch = s.shape[0] - s = s.view(batch, -1) - logits = self.model(s) - return logits, state - - -class Actor(nn.Module): - def __init__(self, preprocess_net, action_shape): - super().__init__() - self.preprocess = preprocess_net - self.last = nn.Linear(128, np.prod(action_shape)) - - def forward(self, s, state=None, info={}): - logits, h = self.preprocess(s, state) - logits = F.softmax(self.last(logits), dim=-1) - return logits, h - - -class Critic(nn.Module): - def __init__(self, preprocess_net): - super().__init__() - self.preprocess = preprocess_net - self.last = nn.Linear(128, 1) - - def forward(self, s): - logits, h = self.preprocess(s, None) - logits = self.last(logits) - return logits - - -class DQN(nn.Module): - - def __init__(self, h, w, action_shape, device='cpu'): - super(DQN, self).__init__() - self.device = device - - self.conv1 = 
nn.Conv2d(4, 16, kernel_size=5, stride=2) - self.bn1 = nn.BatchNorm2d(16) - self.conv2 = nn.Conv2d(16, 32, kernel_size=5, stride=2) - self.bn2 = nn.BatchNorm2d(32) - self.conv3 = nn.Conv2d(32, 32, kernel_size=5, stride=2) - self.bn3 = nn.BatchNorm2d(32) - - def conv2d_size_out(size, kernel_size=5, stride=2): - return (size - (kernel_size - 1) - 1) // stride + 1 - - convw = conv2d_size_out(conv2d_size_out(conv2d_size_out(w))) - convh = conv2d_size_out(conv2d_size_out(conv2d_size_out(h))) - linear_input_size = convw * convh * 32 - self.fc = nn.Linear(linear_input_size, 512) - self.head = nn.Linear(512, action_shape) - - def forward(self, x, state=None, info={}): - if not isinstance(x, torch.Tensor): - x = torch.tensor(x, device=self.device, dtype=torch.float) - x = x.permute(0, 3, 1, 2) - x = F.relu(self.bn1(self.conv1(x))) - x = F.relu(self.bn2(self.conv2(x))) - x = F.relu(self.bn3(self.conv3(x))) - x = self.fc(x.reshape(x.size(0), -1)) - return self.head(x), state diff --git a/examples/halfcheetahBullet_v0_sac.py b/examples/halfcheetahBullet_v0_sac.py index 57ca8bae7..66f39653e 100644 --- a/examples/halfcheetahBullet_v0_sac.py +++ b/examples/halfcheetahBullet_v0_sac.py @@ -16,7 +16,7 @@ except ImportError: pass -from continuous_net import ActorProb, Critic +from tianshou.utils.net.continuous import ActorProb, Critic def get_args(): diff --git a/examples/point_maze_td3.py b/examples/point_maze_td3.py index b3f2c95a7..012eb60d5 100644 --- a/examples/point_maze_td3.py +++ b/examples/point_maze_td3.py @@ -10,7 +10,7 @@ from tianshou.data import Collector, ReplayBuffer from tianshou.env import VectorEnv, SubprocVectorEnv from tianshou.exploration import GaussianNoise -from continuous_net import Actor, Critic +from tianshou.utils.net.continuous import Actor, Critic from mujoco.register import reg diff --git a/examples/pong_a2c.py b/examples/pong_a2c.py index 0490a43ac..1cfc71fed 100644 --- a/examples/pong_a2c.py +++ b/examples/pong_a2c.py @@ -10,7 +10,7 @@ from tianshou.data import Collector, ReplayBuffer from tianshou.env.atari import create_atari_environment -from discrete_net import Net, Actor, Critic +from tianshou.utils.net.discrete import Net, ActorHead, CriticHead def get_args(): @@ -65,8 +65,8 @@ def test_a2c(args=get_args()): test_envs.seed(args.seed) # model net = Net(args.layer_num, args.state_shape, device=args.device) - actor = Actor(net, args.action_shape).to(args.device) - critic = Critic(net).to(args.device) + actor = ActorHead(net, args.action_shape).to(args.device) + critic = CriticHead(net).to(args.device) optim = torch.optim.Adam(list( actor.parameters()) + list(critic.parameters()), lr=args.lr) dist = torch.distributions.Categorical diff --git a/examples/pong_dqn.py b/examples/pong_dqn.py index 98c49a8a8..85a7e6ec9 100644 --- a/examples/pong_dqn.py +++ b/examples/pong_dqn.py @@ -6,12 +6,11 @@ from tianshou.policy import DQNPolicy from tianshou.env import SubprocVectorEnv +from tianshou.utils.net.discrete import DQN from tianshou.trainer import offpolicy_trainer from tianshou.data import Collector, ReplayBuffer from tianshou.env.atari import create_atari_environment -from discrete_net import DQN - def get_args(): parser = argparse.ArgumentParser() diff --git a/examples/pong_ppo.py b/examples/pong_ppo.py index f9976b570..4b7213b32 100644 --- a/examples/pong_ppo.py +++ b/examples/pong_ppo.py @@ -9,8 +9,7 @@ from tianshou.trainer import onpolicy_trainer from tianshou.data import Collector, ReplayBuffer from tianshou.env.atari import create_atari_environment - -from discrete_net 
import Net, Actor, Critic +from tianshou.utils.net.discrete import Net, ActorHead, CriticHead def get_args(): @@ -63,8 +62,8 @@ def test_ppo(args=get_args()): test_envs.seed(args.seed) # model net = Net(args.layer_num, args.state_shape, device=args.device) - actor = Actor(net, args.action_shape).to(args.device) - critic = Critic(net).to(args.device) + actor = ActorHead(net, args.action_shape).to(args.device) + critic = CriticHead(net).to(args.device) optim = torch.optim.Adam(list( actor.parameters()) + list(critic.parameters()), lr=args.lr) dist = torch.distributions.Categorical diff --git a/examples/sac_mcc.py b/examples/sac_mcc.py index 6455975be..6ccbd5815 100644 --- a/examples/sac_mcc.py +++ b/examples/sac_mcc.py @@ -12,7 +12,7 @@ from tianshou.env import VectorEnv from tianshou.exploration import OUNoise -from continuous_net import ActorProb, Critic +from tianshou.utils.net.continuous import ActorProb, Critic def get_args(): diff --git a/test/continuous/test_ddpg.py b/test/continuous/test_ddpg.py index 9428f1152..9853f9f3c 100644 --- a/test/continuous/test_ddpg.py +++ b/test/continuous/test_ddpg.py @@ -11,11 +11,7 @@ from tianshou.trainer import offpolicy_trainer from tianshou.data import Collector, ReplayBuffer from tianshou.exploration import GaussianNoise - -if __name__ == '__main__': - from net import Actor, Critic -else: # pytest - from test.continuous.net import Actor, Critic +from tianshou.utils.net.continuous import Actor, Critic def get_args(): diff --git a/test/continuous/test_ppo.py b/test/continuous/test_ppo.py index dd0e765bc..013109d96 100644 --- a/test/continuous/test_ppo.py +++ b/test/continuous/test_ppo.py @@ -11,11 +11,7 @@ from tianshou.policy.dist import DiagGaussian from tianshou.trainer import onpolicy_trainer from tianshou.data import Collector, ReplayBuffer - -if __name__ == '__main__': - from net import ActorProb, Critic -else: # pytest - from test.continuous.net import ActorProb, Critic +from tianshou.utils.net.continuous import ActorProb, Critic def get_args(): diff --git a/test/continuous/test_sac_with_il.py b/test/continuous/test_sac_with_il.py index 8b1a3d1b2..55cf910bc 100644 --- a/test/continuous/test_sac_with_il.py +++ b/test/continuous/test_sac_with_il.py @@ -10,11 +10,7 @@ from tianshou.trainer import offpolicy_trainer from tianshou.data import Collector, ReplayBuffer from tianshou.policy import SACPolicy, ImitationPolicy - -if __name__ == '__main__': - from net import Actor, ActorProb, Critic -else: # pytest - from test.continuous.net import Actor, ActorProb, Critic +from tianshou.utils.net.continuous import Actor, ActorProb, Critic def get_args(): diff --git a/test/continuous/test_td3.py b/test/continuous/test_td3.py index 6d3133bb7..f7ed5402e 100644 --- a/test/continuous/test_td3.py +++ b/test/continuous/test_td3.py @@ -11,11 +11,7 @@ from tianshou.trainer import offpolicy_trainer from tianshou.data import Collector, ReplayBuffer from tianshou.exploration import GaussianNoise - -if __name__ == '__main__': - from net import Actor, Critic -else: # pytest - from test.continuous.net import Actor, Critic +from tianshou.utils.net.continuous import Actor, Critic def get_args(): diff --git a/test/discrete/test_a2c_with_il.py b/test/discrete/test_a2c_with_il.py index 267732f23..95ae72b8f 100644 --- a/test/discrete/test_a2c_with_il.py +++ b/test/discrete/test_a2c_with_il.py @@ -10,11 +10,7 @@ from tianshou.data import Collector, ReplayBuffer from tianshou.policy import A2CPolicy, ImitationPolicy from tianshou.trainer import onpolicy_trainer, 
offpolicy_trainer - -if __name__ == '__main__': - from net import Net, Actor, Critic -else: # pytest - from test.discrete.net import Net, Actor, Critic +from tianshou.utils.net.discrete import Net, ActorHead, CriticHead def get_args(): @@ -67,8 +63,8 @@ def test_a2c_with_il(args=get_args()): test_envs.seed(args.seed) # model net = Net(args.layer_num, args.state_shape, device=args.device) - actor = Actor(net, args.action_shape).to(args.device) - critic = Critic(net).to(args.device) + actor = ActorHead(net, args.action_shape).to(args.device) + critic = CriticHead(net).to(args.device) optim = torch.optim.Adam(list( actor.parameters()) + list(critic.parameters()), lr=args.lr) dist = torch.distributions.Categorical @@ -111,7 +107,7 @@ def stop_fn(x): if args.task == 'CartPole-v0': env.spec.reward_threshold = 190 # lower the goal net = Net(1, args.state_shape, device=args.device) - net = Actor(net, args.action_shape).to(args.device) + net = ActorHead(net, args.action_shape).to(args.device) optim = torch.optim.Adam(net.parameters(), lr=args.il_lr) il_policy = ImitationPolicy(net, optim, mode='discrete') il_test_collector = Collector(il_policy, test_envs) diff --git a/test/discrete/test_dqn.py b/test/discrete/test_dqn.py index b37819438..6125abd6c 100644 --- a/test/discrete/test_dqn.py +++ b/test/discrete/test_dqn.py @@ -10,11 +10,7 @@ from tianshou.policy import DQNPolicy from tianshou.trainer import offpolicy_trainer from tianshou.data import Collector, ReplayBuffer - -if __name__ == '__main__': - from net import Net -else: # pytest - from test.discrete.net import Net +from tianshou.utils.net.discrete import Net def get_args(): diff --git a/test/discrete/test_drqn.py b/test/discrete/test_drqn.py index 8c15e3b64..d2f160b8d 100644 --- a/test/discrete/test_drqn.py +++ b/test/discrete/test_drqn.py @@ -10,11 +10,7 @@ from tianshou.policy import DQNPolicy from tianshou.trainer import offpolicy_trainer from tianshou.data import Collector, ReplayBuffer - -if __name__ == '__main__': - from net import Recurrent -else: # pytest - from test.discrete.net import Recurrent +from tianshou.utils.net.discrete import Recurrent def get_args(): diff --git a/test/discrete/test_pdqn.py b/test/discrete/test_pdqn.py index 70bfdbb47..49682679c 100644 --- a/test/discrete/test_pdqn.py +++ b/test/discrete/test_pdqn.py @@ -6,16 +6,12 @@ import numpy as np from torch.utils.tensorboard import SummaryWriter +from tianshou.utils.net.discrete import Net from tianshou.env import VectorEnv from tianshou.policy import DQNPolicy from tianshou.trainer import offpolicy_trainer from tianshou.data import Collector, ReplayBuffer, PrioritizedReplayBuffer -if __name__ == '__main__': - from net import Net -else: # pytest - from test.discrete.net import Net - def get_args(): parser = argparse.ArgumentParser() diff --git a/test/discrete/test_pg.py b/test/discrete/test_pg.py index d2817a759..2e87c813d 100644 --- a/test/discrete/test_pg.py +++ b/test/discrete/test_pg.py @@ -7,16 +7,12 @@ import numpy as np from torch.utils.tensorboard import SummaryWriter +from tianshou.utils.net.discrete import Net from tianshou.env import VectorEnv from tianshou.policy import PGPolicy from tianshou.trainer import onpolicy_trainer from tianshou.data import Batch, Collector, ReplayBuffer -if __name__ == '__main__': - from net import Net -else: # pytest - from test.discrete.net import Net - def compute_return_base(batch, aa=None, bb=None, gamma=0.1): returns = np.zeros_like(batch.rew) diff --git a/test/discrete/test_ppo.py b/test/discrete/test_ppo.py index 
44850a85f..b18c9374b 100644 --- a/test/discrete/test_ppo.py +++ b/test/discrete/test_ppo.py @@ -10,11 +10,7 @@ from tianshou.policy import PPOPolicy from tianshou.trainer import onpolicy_trainer from tianshou.data import Collector, ReplayBuffer - -if __name__ == '__main__': - from net import Net, Actor, Critic -else: # pytest - from test.discrete.net import Net, Actor, Critic +from tianshou.utils.net.discrete import Net, ActorHead, CriticHead def get_args(): @@ -69,8 +65,8 @@ def test_ppo(args=get_args()): test_envs.seed(args.seed) # model net = Net(args.layer_num, args.state_shape, device=args.device) - actor = Actor(net, args.action_shape).to(args.device) - critic = Critic(net).to(args.device) + actor = ActorHead(net, args.action_shape).to(args.device) + critic = CriticHead(net).to(args.device) # orthogonal initialization for m in list(actor.modules()) + list(critic.modules()): if isinstance(m, torch.nn.Linear): diff --git a/tianshou/utils/net/__init__.py b/tianshou/utils/net/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/test/continuous/net.py b/tianshou/utils/net/continuous.py similarity index 96% rename from test/continuous/net.py rename to tianshou/utils/net/continuous.py index 043044a50..c4033961c 100644 --- a/test/continuous/net.py +++ b/tianshou/utils/net/continuous.py @@ -30,7 +30,7 @@ def forward(self, s, **kwargs): class ActorProb(nn.Module): def __init__(self, layer_num, state_shape, action_shape, - max_action, device='cpu'): + max_action, device='cpu', unbounded=False): super().__init__() self.device = device self.model = [ @@ -43,6 +43,7 @@ def __init__(self, layer_num, state_shape, action_shape, self.sigma = nn.Parameter(torch.zeros(np.prod(action_shape), 1)) # self.sigma = nn.Linear(128, np.prod(action_shape)) self._max = max_action + self._unbounded = unbounded def forward(self, s, **kwargs): s = to_torch(s, device=self.device, dtype=torch.float) @@ -50,6 +51,8 @@ def forward(self, s, **kwargs): s = s.view(batch, -1) logits = self.model(s) mu = self.mu(logits) + if not self._unbounded: + mu = self._max * torch.tanh(mu) shape = [1] * len(mu.shape) shape[1] = -1 sigma = (self.sigma.view(shape) + torch.zeros_like(mu)).exp() diff --git a/test/discrete/net.py b/tianshou/utils/net/discrete.py similarity index 69% rename from test/discrete/net.py rename to tianshou/utils/net/discrete.py index 1dcf783a6..3a9afde66 100644 --- a/test/discrete/net.py +++ b/tianshou/utils/net/discrete.py @@ -1,3 +1,7 @@ +""" +Commonly used network modules for discrete output. +""" + import torch import numpy as np from torch import nn @@ -30,7 +34,7 @@ def forward(self, s, state=None, info={}): return logits, state -class Actor(nn.Module): +class ActorHead(nn.Module): def __init__(self, preprocess_net, action_shape): super().__init__() self.preprocess = preprocess_net @@ -42,7 +46,7 @@ def forward(self, s, state=None, info={}): return logits, h -class Critic(nn.Module): +class CriticHead(nn.Module): def __init__(self, preprocess_net): super().__init__() self.preprocess = preprocess_net @@ -89,3 +93,36 @@ def forward(self, s, state=None, info={}): # please ensure the first dim is batch size: [bsz, len, ...] 
return s, {'h': h.transpose(0, 1).detach(), 'c': c.transpose(0, 1).detach()} + + +class DQN(nn.Module): + + def __init__(self, h, w, action_shape, device='cpu'): + super(DQN, self).__init__() + self.device = device + + self.conv1 = nn.Conv2d(4, 16, kernel_size=5, stride=2) + self.bn1 = nn.BatchNorm2d(16) + self.conv2 = nn.Conv2d(16, 32, kernel_size=5, stride=2) + self.bn2 = nn.BatchNorm2d(32) + self.conv3 = nn.Conv2d(32, 32, kernel_size=5, stride=2) + self.bn3 = nn.BatchNorm2d(32) + + def conv2d_size_out(size, kernel_size=5, stride=2): + return (size - (kernel_size - 1) - 1) // stride + 1 + + convw = conv2d_size_out(conv2d_size_out(conv2d_size_out(w))) + convh = conv2d_size_out(conv2d_size_out(conv2d_size_out(h))) + linear_input_size = convw * convh * 32 + self.fc = nn.Linear(linear_input_size, 512) + self.head = nn.Linear(512, action_shape) + + def forward(self, x, state=None, info={}): + if not isinstance(x, torch.Tensor): + x = torch.tensor(x, device=self.device, dtype=torch.float) + x = x.permute(0, 3, 1, 2) + x = F.relu(self.bn1(self.conv1(x))) + x = F.relu(self.bn2(self.conv2(x))) + x = F.relu(self.bn3(self.conv3(x))) + x = self.fc(x.reshape(x.size(0), -1)) + return self.head(x), state From e4153203cd24af7639fe5486c88b907cd9247d1a Mon Sep 17 00:00:00 2001 From: youkaichao Date: Thu, 9 Jul 2020 19:13:13 +0800 Subject: [PATCH 02/11] split code to have backbone and head --- examples/ant_v2_ddpg.py | 15 +++-- examples/ant_v2_sac.py | 17 +++--- examples/ant_v2_td3.py | 17 +++--- examples/halfcheetahBullet_v0_sac.py | 18 +++--- examples/point_maze_td3.py | 17 +++--- examples/pong_a2c.py | 3 +- examples/pong_ppo.py | 3 +- examples/sac_mcc.py | 18 +++--- test/continuous/test_ddpg.py | 13 ++-- test/continuous/test_ppo.py | 12 ++-- test/continuous/test_sac_with_il.py | 21 ++++--- test/continuous/test_td3.py | 17 +++--- test/discrete/test_a2c_with_il.py | 3 +- test/discrete/test_dqn.py | 2 +- test/discrete/test_drqn.py | 2 +- test/discrete/test_pdqn.py | 2 +- test/discrete/test_pg.py | 2 +- test/discrete/test_ppo.py | 3 +- tianshou/utils/net/common.py | 78 ++++++++++++++++++++++++ tianshou/utils/net/continuous.py | 90 +++++++++++----------------- tianshou/utils/net/discrete.py | 63 ------------------- 21 files changed, 218 insertions(+), 198 deletions(-) create mode 100644 tianshou/utils/net/common.py diff --git a/examples/ant_v2_ddpg.py b/examples/ant_v2_ddpg.py index 5ea3b778e..206b00a0c 100644 --- a/examples/ant_v2_ddpg.py +++ b/examples/ant_v2_ddpg.py @@ -10,7 +10,8 @@ from tianshou.data import Collector, ReplayBuffer from tianshou.env import VectorEnv, SubprocVectorEnv from tianshou.exploration import GaussianNoise -from tianshou.utils.net.continuous import Actor, Critic +from tianshou.utils.net.common import Net +from tianshou.utils.net.continuous import ActorHead, CriticHead def get_args(): @@ -56,14 +57,12 @@ def test_ddpg(args=get_args()): train_envs.seed(args.seed) test_envs.seed(args.seed) # model - actor = Actor( - args.layer_num, args.state_shape, args.action_shape, - args.max_action, args.device - ).to(args.device) + net = Net(args.layer_num, args.state_shape, device=args.device) + actor = ActorHead(net, args.action_shape, args.max_action, + args.device).to(args.device) actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr) - critic = Critic( - args.layer_num, args.state_shape, args.action_shape, args.device - ).to(args.device) + net = Net(args.layer_num, args.state_shape, args.action_shape, concat=True) + critic = CriticHead(net, args.device).to(args.device) 
critic_optim = torch.optim.Adam(critic.parameters(), lr=args.critic_lr) policy = DDPGPolicy( actor, actor_optim, critic, critic_optim, diff --git a/examples/ant_v2_sac.py b/examples/ant_v2_sac.py index 947ef5fbb..a8e6f6536 100644 --- a/examples/ant_v2_sac.py +++ b/examples/ant_v2_sac.py @@ -10,7 +10,8 @@ from tianshou.trainer import offpolicy_trainer from tianshou.data import Collector, ReplayBuffer from tianshou.env import VectorEnv, SubprocVectorEnv -from tianshou.utils.net.continuous import ActorProb, Critic +from tianshou.utils.net.common import Net +from tianshou.utils.net.continuous import ActorHeadProb, CriticHead def get_args(): @@ -57,17 +58,19 @@ def test_sac(args=get_args()): train_envs.seed(args.seed) test_envs.seed(args.seed) # model - actor = ActorProb( - args.layer_num, args.state_shape, args.action_shape, + net = Net(args.layer_num, args.state_shape, device=args.device) + actor = ActorHeadProb( + net, args.action_shape, args.max_action, args.device, unbounded=True ).to(args.device) actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr) - critic1 = Critic( - args.layer_num, args.state_shape, args.action_shape, args.device + net = Net(args.layer_num, args.state_shape, args.action_shape, concat=True) + critic1 = CriticHead( + net, args.device ).to(args.device) critic1_optim = torch.optim.Adam(critic1.parameters(), lr=args.critic_lr) - critic2 = Critic( - args.layer_num, args.state_shape, args.action_shape, args.device + critic2 = CriticHead( + net, args.device ).to(args.device) critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr) policy = SACPolicy( diff --git a/examples/ant_v2_td3.py b/examples/ant_v2_td3.py index c17c0f94f..d251566fd 100644 --- a/examples/ant_v2_td3.py +++ b/examples/ant_v2_td3.py @@ -10,7 +10,8 @@ from tianshou.data import Collector, ReplayBuffer from tianshou.env import VectorEnv, SubprocVectorEnv from tianshou.exploration import GaussianNoise -from tianshou.utils.net.continuous import Actor, Critic +from tianshou.utils.net.common import Net +from tianshou.utils.net.continuous import ActorHead, CriticHead def get_args(): @@ -59,17 +60,19 @@ def test_td3(args=get_args()): train_envs.seed(args.seed) test_envs.seed(args.seed) # model - actor = Actor( - args.layer_num, args.state_shape, args.action_shape, + net = Net(args.layer_num, args.state_shape, device=args.device) + actor = ActorHead( + net, args.action_shape, args.max_action, args.device ).to(args.device) actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr) - critic1 = Critic( - args.layer_num, args.state_shape, args.action_shape, args.device + net = Net(args.layer_num, args.state_shape, args.action_shape, concat=True) + critic1 = CriticHead( + net, args.device ).to(args.device) critic1_optim = torch.optim.Adam(critic1.parameters(), lr=args.critic_lr) - critic2 = Critic( - args.layer_num, args.state_shape, args.action_shape, args.device + critic2 = CriticHead( + net, args.device ).to(args.device) critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr) policy = TD3Policy( diff --git a/examples/halfcheetahBullet_v0_sac.py b/examples/halfcheetahBullet_v0_sac.py index 66f39653e..054105d60 100644 --- a/examples/halfcheetahBullet_v0_sac.py +++ b/examples/halfcheetahBullet_v0_sac.py @@ -15,8 +15,8 @@ import pybullet_envs except ImportError: pass - -from tianshou.utils.net.continuous import ActorProb, Critic +from tianshou.utils.net.common import Net +from tianshou.utils.net.continuous import ActorHeadProb, CriticHead def get_args(): @@ -66,17 
+66,19 @@ def test_sac(args=get_args()): train_envs.seed(args.seed) test_envs.seed(args.seed) # model - actor = ActorProb( - args.layer_num, args.state_shape, args.action_shape, + net = Net(args.layer_num, args.state_shape, device=args.device) + actor = ActorHeadProb( + net, args.action_shape, args.max_action, args.device, unbounded=True ).to(args.device) actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr) - critic1 = Critic( - args.layer_num, args.state_shape, args.action_shape, args.device + net = Net(args.layer_num, args.state_shape, args.action_shape, concat=True) + critic1 = CriticHead( + net, args.device ).to(args.device) critic1_optim = torch.optim.Adam(critic1.parameters(), lr=args.critic_lr) - critic2 = Critic( - args.layer_num, args.state_shape, args.action_shape, args.device + critic2 = CriticHead( + net, args.device ).to(args.device) critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr) policy = SACPolicy( diff --git a/examples/point_maze_td3.py b/examples/point_maze_td3.py index 012eb60d5..c603dccca 100644 --- a/examples/point_maze_td3.py +++ b/examples/point_maze_td3.py @@ -10,7 +10,8 @@ from tianshou.data import Collector, ReplayBuffer from tianshou.env import VectorEnv, SubprocVectorEnv from tianshou.exploration import GaussianNoise -from tianshou.utils.net.continuous import Actor, Critic +from tianshou.utils.net.common import Net +from tianshou.utils.net.continuous import ActorHead, CriticHead from mujoco.register import reg @@ -63,17 +64,19 @@ def test_td3(args=get_args()): train_envs.seed(args.seed) test_envs.seed(args.seed) # model - actor = Actor( - args.layer_num, args.state_shape, args.action_shape, + net = Net(args.layer_num, args.state_shape, device=args.device) + actor = ActorHead( + net, args.action_shape, args.max_action, args.device ).to(args.device) actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr) - critic1 = Critic( - args.layer_num, args.state_shape, args.action_shape, args.device + net = Net(args.layer_num, args.state_shape, args.action_shape, concat=True) + critic1 = CriticHead( + net, args.device ).to(args.device) critic1_optim = torch.optim.Adam(critic1.parameters(), lr=args.critic_lr) - critic2 = Critic( - args.layer_num, args.state_shape, args.action_shape, args.device + critic2 = CriticHead( + net, args.device ).to(args.device) critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr) policy = TD3Policy( diff --git a/examples/pong_a2c.py b/examples/pong_a2c.py index 1cfc71fed..c6424aba1 100644 --- a/examples/pong_a2c.py +++ b/examples/pong_a2c.py @@ -10,7 +10,8 @@ from tianshou.data import Collector, ReplayBuffer from tianshou.env.atari import create_atari_environment -from tianshou.utils.net.discrete import Net, ActorHead, CriticHead +from tianshou.utils.net.discrete import ActorHead, CriticHead +from tianshou.utils.net.common import Net def get_args(): diff --git a/examples/pong_ppo.py b/examples/pong_ppo.py index 4b7213b32..e5d4fcab1 100644 --- a/examples/pong_ppo.py +++ b/examples/pong_ppo.py @@ -9,7 +9,8 @@ from tianshou.trainer import onpolicy_trainer from tianshou.data import Collector, ReplayBuffer from tianshou.env.atari import create_atari_environment -from tianshou.utils.net.discrete import Net, ActorHead, CriticHead +from tianshou.utils.net.discrete import ActorHead, CriticHead +from tianshou.utils.net.common import Net def get_args(): diff --git a/examples/sac_mcc.py b/examples/sac_mcc.py index 6ccbd5815..28b305fa3 100644 --- a/examples/sac_mcc.py +++ 
b/examples/sac_mcc.py @@ -11,8 +11,8 @@ from tianshou.data import Collector, ReplayBuffer from tianshou.env import VectorEnv from tianshou.exploration import OUNoise - -from tianshou.utils.net.continuous import ActorProb, Critic +from tianshou.utils.net.common import Net +from tianshou.utils.net.continuous import ActorHeadProb, CriticHead def get_args(): @@ -62,17 +62,19 @@ def test_sac(args=get_args()): train_envs.seed(args.seed) test_envs.seed(args.seed) # model - actor = ActorProb( - args.layer_num, args.state_shape, args.action_shape, + net = Net(args.layer_num, args.state_shape, device=args.device) + actor = ActorHeadProb( + net, args.action_shape, args.max_action, args.device, unbounded=True ).to(args.device) actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr) - critic1 = Critic( - args.layer_num, args.state_shape, args.action_shape, args.device + net = Net(args.layer_num, args.state_shape, args.action_shape, concat=True) + critic1 = CriticHead( + net, args.device ).to(args.device) critic1_optim = torch.optim.Adam(critic1.parameters(), lr=args.critic_lr) - critic2 = Critic( - args.layer_num, args.state_shape, args.action_shape, args.device + critic2 = CriticHead( + net, args.device ).to(args.device) critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr) diff --git a/test/continuous/test_ddpg.py b/test/continuous/test_ddpg.py index 9853f9f3c..aee7e04a6 100644 --- a/test/continuous/test_ddpg.py +++ b/test/continuous/test_ddpg.py @@ -11,7 +11,8 @@ from tianshou.trainer import offpolicy_trainer from tianshou.data import Collector, ReplayBuffer from tianshou.exploration import GaussianNoise -from tianshou.utils.net.continuous import Actor, Critic +from tianshou.utils.net.common import Net +from tianshou.utils.net.continuous import ActorHead, CriticHead def get_args(): @@ -65,13 +66,15 @@ def test_ddpg(args=get_args()): train_envs.seed(args.seed) test_envs.seed(args.seed) # model - actor = Actor( - args.layer_num, args.state_shape, args.action_shape, + net = Net(args.layer_num, args.state_shape, device=args.device) + actor = ActorHead( + net, args.action_shape, args.max_action, args.device ).to(args.device) actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr) - critic = Critic( - args.layer_num, args.state_shape, args.action_shape, args.device + net = Net(args.layer_num, args.state_shape, args.action_shape, concat=True) + critic = CriticHead( + net, args.device ).to(args.device) critic_optim = torch.optim.Adam(critic.parameters(), lr=args.critic_lr) policy = DDPGPolicy( diff --git a/test/continuous/test_ppo.py b/test/continuous/test_ppo.py index 013109d96..6d90b614f 100644 --- a/test/continuous/test_ppo.py +++ b/test/continuous/test_ppo.py @@ -11,7 +11,8 @@ from tianshou.policy.dist import DiagGaussian from tianshou.trainer import onpolicy_trainer from tianshou.data import Collector, ReplayBuffer -from tianshou.utils.net.continuous import ActorProb, Critic +from tianshou.utils.net.common import Net +from tianshou.utils.net.continuous import ActorHeadProb, CriticHead def get_args(): @@ -68,12 +69,13 @@ def test_ppo(args=get_args()): train_envs.seed(args.seed) test_envs.seed(args.seed) # model - actor = ActorProb( - args.layer_num, args.state_shape, args.action_shape, + net = Net(args.layer_num, args.state_shape, device=args.device) + actor = ActorHeadProb( + net, args.action_shape, args.max_action, args.device ).to(args.device) - critic = Critic( - args.layer_num, args.state_shape, device=args.device + critic = CriticHead( + 
Net(args.layer_num, args.state_shape), device=args.device ).to(args.device) # orthogonal initialization for m in list(actor.modules()) + list(critic.modules()): diff --git a/test/continuous/test_sac_with_il.py b/test/continuous/test_sac_with_il.py index 55cf910bc..c01e59e4a 100644 --- a/test/continuous/test_sac_with_il.py +++ b/test/continuous/test_sac_with_il.py @@ -10,7 +10,8 @@ from tianshou.trainer import offpolicy_trainer from tianshou.data import Collector, ReplayBuffer from tianshou.policy import SACPolicy, ImitationPolicy -from tianshou.utils.net.continuous import Actor, ActorProb, Critic +from tianshou.utils.net.common import Net +from tianshou.utils.net.continuous import ActorHead, ActorHeadProb, CriticHead def get_args(): @@ -64,17 +65,19 @@ def test_sac_with_il(args=get_args()): train_envs.seed(args.seed) test_envs.seed(args.seed) # model - actor = ActorProb( - args.layer_num, args.state_shape, args.action_shape, + net = Net(args.layer_num, args.state_shape, device=args.device) + actor = ActorHeadProb( + net, args.action_shape, args.max_action, args.device ).to(args.device) actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr) - critic1 = Critic( - args.layer_num, args.state_shape, args.action_shape, args.device + net = Net(args.layer_num, args.state_shape, args.action_shape, concat=True) + critic1 = CriticHead( + net, args.device ).to(args.device) critic1_optim = torch.optim.Adam(critic1.parameters(), lr=args.critic_lr) - critic2 = Critic( - args.layer_num, args.state_shape, args.action_shape, args.device + critic2 = CriticHead( + net, args.device ).to(args.device) critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr) policy = SACPolicy( @@ -118,8 +121,8 @@ def stop_fn(x): # here we define an imitation collector with a trivial policy if args.task == 'Pendulum-v0': env.spec.reward_threshold = -300 # lower the goal - net = Actor(1, args.state_shape, args.action_shape, - args.max_action, args.device).to(args.device) + net = ActorHead(Net(1, args.state_shape), args.action_shape, + args.max_action, args.device).to(args.device) optim = torch.optim.Adam(net.parameters(), lr=args.il_lr) il_policy = ImitationPolicy(net, optim, mode='continuous') il_test_collector = Collector(il_policy, test_envs) diff --git a/test/continuous/test_td3.py b/test/continuous/test_td3.py index f7ed5402e..8377885a2 100644 --- a/test/continuous/test_td3.py +++ b/test/continuous/test_td3.py @@ -11,7 +11,8 @@ from tianshou.trainer import offpolicy_trainer from tianshou.data import Collector, ReplayBuffer from tianshou.exploration import GaussianNoise -from tianshou.utils.net.continuous import Actor, Critic +from tianshou.utils.net.common import Net +from tianshou.utils.net.continuous import ActorHead, CriticHead def get_args(): @@ -67,17 +68,19 @@ def test_td3(args=get_args()): train_envs.seed(args.seed) test_envs.seed(args.seed) # model - actor = Actor( - args.layer_num, args.state_shape, args.action_shape, + net = Net(args.layer_num, args.state_shape, device=args.device) + actor = ActorHead( + net, args.action_shape, args.max_action, args.device ).to(args.device) actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr) - critic1 = Critic( - args.layer_num, args.state_shape, args.action_shape, args.device + net = Net(args.layer_num, args.state_shape, args.action_shape, concat=True) + critic1 = CriticHead( + net, args.device ).to(args.device) critic1_optim = torch.optim.Adam(critic1.parameters(), lr=args.critic_lr) - critic2 = Critic( - args.layer_num, 
args.state_shape, args.action_shape, args.device + critic2 = CriticHead( + net, args.device ).to(args.device) critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr) policy = TD3Policy( diff --git a/test/discrete/test_a2c_with_il.py b/test/discrete/test_a2c_with_il.py index 95ae72b8f..427e3d2a0 100644 --- a/test/discrete/test_a2c_with_il.py +++ b/test/discrete/test_a2c_with_il.py @@ -10,7 +10,8 @@ from tianshou.data import Collector, ReplayBuffer from tianshou.policy import A2CPolicy, ImitationPolicy from tianshou.trainer import onpolicy_trainer, offpolicy_trainer -from tianshou.utils.net.discrete import Net, ActorHead, CriticHead +from tianshou.utils.net.discrete import ActorHead, CriticHead +from tianshou.utils.net.common import Net def get_args(): diff --git a/test/discrete/test_dqn.py b/test/discrete/test_dqn.py index 6125abd6c..1863f5b6a 100644 --- a/test/discrete/test_dqn.py +++ b/test/discrete/test_dqn.py @@ -10,7 +10,7 @@ from tianshou.policy import DQNPolicy from tianshou.trainer import offpolicy_trainer from tianshou.data import Collector, ReplayBuffer -from tianshou.utils.net.discrete import Net +from tianshou.utils.net.common import Net def get_args(): diff --git a/test/discrete/test_drqn.py b/test/discrete/test_drqn.py index d2f160b8d..627a14da9 100644 --- a/test/discrete/test_drqn.py +++ b/test/discrete/test_drqn.py @@ -10,7 +10,7 @@ from tianshou.policy import DQNPolicy from tianshou.trainer import offpolicy_trainer from tianshou.data import Collector, ReplayBuffer -from tianshou.utils.net.discrete import Recurrent +from tianshou.utils.net.common import Recurrent def get_args(): diff --git a/test/discrete/test_pdqn.py b/test/discrete/test_pdqn.py index 49682679c..98a8a8899 100644 --- a/test/discrete/test_pdqn.py +++ b/test/discrete/test_pdqn.py @@ -6,7 +6,7 @@ import numpy as np from torch.utils.tensorboard import SummaryWriter -from tianshou.utils.net.discrete import Net +from tianshou.utils.net.common import Net from tianshou.env import VectorEnv from tianshou.policy import DQNPolicy from tianshou.trainer import offpolicy_trainer diff --git a/test/discrete/test_pg.py b/test/discrete/test_pg.py index 2e87c813d..8f8f5dc65 100644 --- a/test/discrete/test_pg.py +++ b/test/discrete/test_pg.py @@ -7,7 +7,7 @@ import numpy as np from torch.utils.tensorboard import SummaryWriter -from tianshou.utils.net.discrete import Net +from tianshou.utils.net.common import Net from tianshou.env import VectorEnv from tianshou.policy import PGPolicy from tianshou.trainer import onpolicy_trainer diff --git a/test/discrete/test_ppo.py b/test/discrete/test_ppo.py index b18c9374b..1000e3dfb 100644 --- a/test/discrete/test_ppo.py +++ b/test/discrete/test_ppo.py @@ -10,7 +10,8 @@ from tianshou.policy import PPOPolicy from tianshou.trainer import onpolicy_trainer from tianshou.data import Collector, ReplayBuffer -from tianshou.utils.net.discrete import Net, ActorHead, CriticHead +from tianshou.utils.net.discrete import ActorHead, CriticHead +from tianshou.utils.net.common import Net def get_args(): diff --git a/tianshou/utils/net/common.py b/tianshou/utils/net/common.py new file mode 100644 index 000000000..61984c536 --- /dev/null +++ b/tianshou/utils/net/common.py @@ -0,0 +1,78 @@ +""" +Commonly used MLP-backbone. +""" +import numpy as np +import torch +from torch import nn + +from tianshou.data import to_torch + + +class Net(nn.Module): + def __init__(self, layer_num, state_shape, action_shape=0, device='cpu', + softmax=False, concat=False): + """ + Simple MLP backbone. 
+ :param concat: whether the input shape is concatenated by state_shape + and action_shape. If it is True, ``action_shape`` is not the output + shape, but affects the input shape. + """ + super().__init__() + self.device = device + input_size = np.prod(state_shape) + if concat: + input_size += np.prod(action_shape) + self.model = [ + nn.Linear(input_size, 128), + nn.ReLU(inplace=True)] + for i in range(layer_num): + self.model += [nn.Linear(128, 128), nn.ReLU(inplace=True)] + if action_shape and not concat: + self.model += [nn.Linear(128, np.prod(action_shape))] + if softmax: + self.model += [nn.Softmax(dim=-1)] + self.model = nn.Sequential(*self.model) + + def forward(self, s, state=None, info={}): + s = to_torch(s, device=self.device, dtype=torch.float) + batch = s.shape[0] + s = s.view(batch, -1) + logits = self.model(s) + return logits, state + + +class Recurrent(nn.Module): + def __init__(self, layer_num, state_shape, action_shape, device='cpu'): + super().__init__() + self.state_shape = state_shape + self.action_shape = action_shape + self.device = device + self.nn = nn.LSTM(input_size=128, hidden_size=128, + num_layers=layer_num, batch_first=True) + self.fc1 = nn.Linear(np.prod(state_shape), 128) + self.fc2 = nn.Linear(128, np.prod(action_shape)) + + def forward(self, s, state=None, info={}): + s = to_torch(s, device=self.device, dtype=torch.float) + # s [bsz, len, dim] (training) or [bsz, dim] (evaluation) + # In short, the tensor's shape in training phase is longer than which + # in evaluation phase. + if len(s.shape) == 2: + bsz, dim = s.shape + length = 1 + else: + bsz, length, dim = s.shape + s = self.fc1(s.view([bsz * length, dim])) + s = s.view(bsz, length, -1) + self.nn.flatten_parameters() + if state is None: + s, (h, c) = self.nn(s) + else: + # we store the stack data in [bsz, len, ...] format + # but pytorch rnn needs [len, bsz, ...] + s, (h, c) = self.nn(s, (state['h'].transpose(0, 1).contiguous(), + state['c'].transpose(0, 1).contiguous())) + s = self.fc2(s[:, -1]) + # please ensure the first dim is batch size: [bsz, len, ...] + return s, {'h': h.transpose(0, 1).detach(), + 'c': c.transpose(0, 1).detach()} diff --git a/tianshou/utils/net/continuous.py b/tianshou/utils/net/continuous.py index c4033961c..0f2ef87b4 100644 --- a/tianshou/utils/net/continuous.py +++ b/tianshou/utils/net/continuous.py @@ -1,3 +1,7 @@ +""" +Commonly used network modules for continuous output. 
+""" + import torch import numpy as np from torch import nn @@ -5,88 +9,62 @@ from tianshou.data import to_torch -class Actor(nn.Module): - def __init__(self, layer_num, state_shape, action_shape, +class ActorHead(nn.Module): + def __init__(self, preprocess_net, action_shape, max_action, device='cpu'): super().__init__() - self.device = device - self.model = [ - nn.Linear(np.prod(state_shape), 128), - nn.ReLU(inplace=True)] - for i in range(layer_num): - self.model += [nn.Linear(128, 128), nn.ReLU(inplace=True)] - self.model += [nn.Linear(128, np.prod(action_shape))] - self.model = nn.Sequential(*self.model) + self.preprocess = preprocess_net + self.last = nn.Linear(128, np.prod(action_shape)) self._max = max_action - def forward(self, s, **kwargs): + def forward(self, s, state=None, info={}): + logits, h = self.preprocess(s, state) + logits = self._max * torch.tanh(self.last(logits)) + return logits, h + + +class CriticHead(nn.Module): + def __init__(self, preprocess_net, device='cpu'): + super().__init__() + self.device = device + self.preprocess = preprocess_net + self.last = nn.Linear(128, 1) + + def forward(self, s, a=None, **kwargs): s = to_torch(s, device=self.device, dtype=torch.float) batch = s.shape[0] s = s.view(batch, -1) - logits = self.model(s) - logits = self._max * torch.tanh(logits) - return logits, None + if a is not None: + a = to_torch(a, device=self.device, dtype=torch.float) + a = a.view(batch, -1) + s = torch.cat([s, a], dim=1) + logits, h = self.preprocess(s) + logits = self.last(logits) + return logits -class ActorProb(nn.Module): - def __init__(self, layer_num, state_shape, action_shape, +class ActorHeadProb(nn.Module): + def __init__(self, preprocess_net, action_shape, max_action, device='cpu', unbounded=False): super().__init__() + self.preprocess = preprocess_net self.device = device - self.model = [ - nn.Linear(np.prod(state_shape), 128), - nn.ReLU(inplace=True)] - for i in range(layer_num): - self.model += [nn.Linear(128, 128), nn.ReLU(inplace=True)] - self.model = nn.Sequential(*self.model) self.mu = nn.Linear(128, np.prod(action_shape)) self.sigma = nn.Parameter(torch.zeros(np.prod(action_shape), 1)) - # self.sigma = nn.Linear(128, np.prod(action_shape)) self._max = max_action self._unbounded = unbounded - def forward(self, s, **kwargs): - s = to_torch(s, device=self.device, dtype=torch.float) - batch = s.shape[0] - s = s.view(batch, -1) - logits = self.model(s) + def forward(self, s, state=None, **kwargs): + logits, h = self.preprocess(s, state) mu = self.mu(logits) if not self._unbounded: mu = self._max * torch.tanh(mu) shape = [1] * len(mu.shape) shape[1] = -1 sigma = (self.sigma.view(shape) + torch.zeros_like(mu)).exp() - # assert sigma.shape == mu.shape - # mu = self._max * torch.tanh(self.mu(logits)) - # sigma = torch.exp(self.sigma(logits)) return (mu, sigma), None -class Critic(nn.Module): - def __init__(self, layer_num, state_shape, action_shape=0, device='cpu'): - super().__init__() - self.device = device - self.model = [ - nn.Linear(np.prod(state_shape) + np.prod(action_shape), 128), - nn.ReLU(inplace=True)] - for i in range(layer_num): - self.model += [nn.Linear(128, 128), nn.ReLU(inplace=True)] - self.model += [nn.Linear(128, 1)] - self.model = nn.Sequential(*self.model) - - def forward(self, s, a=None, **kwargs): - s = to_torch(s, device=self.device, dtype=torch.float) - batch = s.shape[0] - s = s.view(batch, -1) - if a is not None: - if not isinstance(a, torch.Tensor): - a = torch.tensor(a, device=self.device, dtype=torch.float) - a = 
a.view(batch, -1) - s = torch.cat([s, a], dim=1) - logits = self.model(s) - return logits - - class RecurrentActorProb(nn.Module): def __init__(self, layer_num, state_shape, action_shape, max_action, device='cpu'): diff --git a/tianshou/utils/net/discrete.py b/tianshou/utils/net/discrete.py index 3a9afde66..fff8c2a59 100644 --- a/tianshou/utils/net/discrete.py +++ b/tianshou/utils/net/discrete.py @@ -7,32 +7,6 @@ from torch import nn import torch.nn.functional as F -from tianshou.data import to_torch - - -class Net(nn.Module): - def __init__(self, layer_num, state_shape, action_shape=0, device='cpu', - softmax=False): - super().__init__() - self.device = device - self.model = [ - nn.Linear(np.prod(state_shape), 128), - nn.ReLU(inplace=True)] - for i in range(layer_num): - self.model += [nn.Linear(128, 128), nn.ReLU(inplace=True)] - if action_shape: - self.model += [nn.Linear(128, np.prod(action_shape))] - if softmax: - self.model += [nn.Softmax(dim=-1)] - self.model = nn.Sequential(*self.model) - - def forward(self, s, state=None, info={}): - s = to_torch(s, device=self.device, dtype=torch.float) - batch = s.shape[0] - s = s.view(batch, -1) - logits = self.model(s) - return logits, state - class ActorHead(nn.Module): def __init__(self, preprocess_net, action_shape): @@ -58,43 +32,6 @@ def forward(self, s, **kwargs): return logits -class Recurrent(nn.Module): - def __init__(self, layer_num, state_shape, action_shape, device='cpu'): - super().__init__() - self.state_shape = state_shape - self.action_shape = action_shape - self.device = device - self.fc1 = nn.Linear(np.prod(state_shape), 128) - self.nn = nn.LSTM(input_size=128, hidden_size=128, - num_layers=layer_num, batch_first=True) - self.fc2 = nn.Linear(128, np.prod(action_shape)) - - def forward(self, s, state=None, info={}): - s = to_torch(s, device=self.device, dtype=torch.float) - # s [bsz, len, dim] (training) or [bsz, dim] (evaluation) - # In short, the tensor's shape in training phase is longer than which - # in evaluation phase. - if len(s.shape) == 2: - bsz, dim = s.shape - length = 1 - else: - bsz, length, dim = s.shape - s = self.fc1(s.view([bsz * length, dim])) - s = s.view(bsz, length, -1) - self.nn.flatten_parameters() - if state is None: - s, (h, c) = self.nn(s) - else: - # we store the stack data in [bsz, len, ...] format - # but pytorch rnn needs [len, bsz, ...] - s, (h, c) = self.nn(s, (state['h'].transpose(0, 1).contiguous(), - state['c'].transpose(0, 1).contiguous())) - s = self.fc2(s[:, -1]) - # please ensure the first dim is batch size: [bsz, len, ...] 
- return s, {'h': h.transpose(0, 1).detach(), - 'c': c.transpose(0, 1).detach()} - - class DQN(nn.Module): def __init__(self, h, w, action_shape, device='cpu'): From 277fb4e308b375e3547bfb53bd7a29493185294b Mon Sep 17 00:00:00 2001 From: youkaichao Date: Thu, 9 Jul 2020 20:24:18 +0800 Subject: [PATCH 03/11] rename class --- examples/ant_v2_ddpg.py | 8 ++++---- examples/ant_v2_sac.py | 8 ++++---- examples/ant_v2_td3.py | 8 ++++---- examples/halfcheetahBullet_v0_sac.py | 8 ++++---- examples/point_maze_td3.py | 8 ++++---- examples/pong_a2c.py | 6 +++--- examples/pong_ppo.py | 6 +++--- examples/sac_mcc.py | 8 ++++---- test/continuous/test_ddpg.py | 6 +++--- test/continuous/test_ppo.py | 6 +++--- test/continuous/test_sac_with_il.py | 12 ++++++------ test/continuous/test_td3.py | 8 ++++---- test/discrete/test_a2c_with_il.py | 8 ++++---- test/discrete/test_ppo.py | 6 +++--- tianshou/utils/net/continuous.py | 6 +++--- tianshou/utils/net/discrete.py | 4 ++-- 16 files changed, 58 insertions(+), 58 deletions(-) diff --git a/examples/ant_v2_ddpg.py b/examples/ant_v2_ddpg.py index 206b00a0c..3a5fd13a7 100644 --- a/examples/ant_v2_ddpg.py +++ b/examples/ant_v2_ddpg.py @@ -11,7 +11,7 @@ from tianshou.env import VectorEnv, SubprocVectorEnv from tianshou.exploration import GaussianNoise from tianshou.utils.net.common import Net -from tianshou.utils.net.continuous import ActorHead, CriticHead +from tianshou.utils.net.continuous import Actor, Critic def get_args(): @@ -58,11 +58,11 @@ def test_ddpg(args=get_args()): test_envs.seed(args.seed) # model net = Net(args.layer_num, args.state_shape, device=args.device) - actor = ActorHead(net, args.action_shape, args.max_action, - args.device).to(args.device) + actor = Actor(net, args.action_shape, args.max_action, + args.device).to(args.device) actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr) net = Net(args.layer_num, args.state_shape, args.action_shape, concat=True) - critic = CriticHead(net, args.device).to(args.device) + critic = Critic(net, args.device).to(args.device) critic_optim = torch.optim.Adam(critic.parameters(), lr=args.critic_lr) policy = DDPGPolicy( actor, actor_optim, critic, critic_optim, diff --git a/examples/ant_v2_sac.py b/examples/ant_v2_sac.py index a8e6f6536..02f5a3b2e 100644 --- a/examples/ant_v2_sac.py +++ b/examples/ant_v2_sac.py @@ -11,7 +11,7 @@ from tianshou.data import Collector, ReplayBuffer from tianshou.env import VectorEnv, SubprocVectorEnv from tianshou.utils.net.common import Net -from tianshou.utils.net.continuous import ActorHeadProb, CriticHead +from tianshou.utils.net.continuous import ActorProb, Critic def get_args(): @@ -59,17 +59,17 @@ def test_sac(args=get_args()): test_envs.seed(args.seed) # model net = Net(args.layer_num, args.state_shape, device=args.device) - actor = ActorHeadProb( + actor = ActorProb( net, args.action_shape, args.max_action, args.device, unbounded=True ).to(args.device) actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr) net = Net(args.layer_num, args.state_shape, args.action_shape, concat=True) - critic1 = CriticHead( + critic1 = Critic( net, args.device ).to(args.device) critic1_optim = torch.optim.Adam(critic1.parameters(), lr=args.critic_lr) - critic2 = CriticHead( + critic2 = Critic( net, args.device ).to(args.device) critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr) diff --git a/examples/ant_v2_td3.py b/examples/ant_v2_td3.py index d251566fd..d76d6b63f 100644 --- a/examples/ant_v2_td3.py +++ b/examples/ant_v2_td3.py @@ -11,7 +11,7 @@ 
from tianshou.env import VectorEnv, SubprocVectorEnv from tianshou.exploration import GaussianNoise from tianshou.utils.net.common import Net -from tianshou.utils.net.continuous import ActorHead, CriticHead +from tianshou.utils.net.continuous import Actor, Critic def get_args(): @@ -61,17 +61,17 @@ def test_td3(args=get_args()): test_envs.seed(args.seed) # model net = Net(args.layer_num, args.state_shape, device=args.device) - actor = ActorHead( + actor = Actor( net, args.action_shape, args.max_action, args.device ).to(args.device) actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr) net = Net(args.layer_num, args.state_shape, args.action_shape, concat=True) - critic1 = CriticHead( + critic1 = Critic( net, args.device ).to(args.device) critic1_optim = torch.optim.Adam(critic1.parameters(), lr=args.critic_lr) - critic2 = CriticHead( + critic2 = Critic( net, args.device ).to(args.device) critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr) diff --git a/examples/halfcheetahBullet_v0_sac.py b/examples/halfcheetahBullet_v0_sac.py index 054105d60..56694b075 100644 --- a/examples/halfcheetahBullet_v0_sac.py +++ b/examples/halfcheetahBullet_v0_sac.py @@ -16,7 +16,7 @@ except ImportError: pass from tianshou.utils.net.common import Net -from tianshou.utils.net.continuous import ActorHeadProb, CriticHead +from tianshou.utils.net.continuous import ActorProb, Critic def get_args(): @@ -67,17 +67,17 @@ def test_sac(args=get_args()): test_envs.seed(args.seed) # model net = Net(args.layer_num, args.state_shape, device=args.device) - actor = ActorHeadProb( + actor = ActorProb( net, args.action_shape, args.max_action, args.device, unbounded=True ).to(args.device) actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr) net = Net(args.layer_num, args.state_shape, args.action_shape, concat=True) - critic1 = CriticHead( + critic1 = Critic( net, args.device ).to(args.device) critic1_optim = torch.optim.Adam(critic1.parameters(), lr=args.critic_lr) - critic2 = CriticHead( + critic2 = Critic( net, args.device ).to(args.device) critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr) diff --git a/examples/point_maze_td3.py b/examples/point_maze_td3.py index c603dccca..978fcd52d 100644 --- a/examples/point_maze_td3.py +++ b/examples/point_maze_td3.py @@ -11,7 +11,7 @@ from tianshou.env import VectorEnv, SubprocVectorEnv from tianshou.exploration import GaussianNoise from tianshou.utils.net.common import Net -from tianshou.utils.net.continuous import ActorHead, CriticHead +from tianshou.utils.net.continuous import Actor, Critic from mujoco.register import reg @@ -65,17 +65,17 @@ def test_td3(args=get_args()): test_envs.seed(args.seed) # model net = Net(args.layer_num, args.state_shape, device=args.device) - actor = ActorHead( + actor = Actor( net, args.action_shape, args.max_action, args.device ).to(args.device) actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr) net = Net(args.layer_num, args.state_shape, args.action_shape, concat=True) - critic1 = CriticHead( + critic1 = Critic( net, args.device ).to(args.device) critic1_optim = torch.optim.Adam(critic1.parameters(), lr=args.critic_lr) - critic2 = CriticHead( + critic2 = Critic( net, args.device ).to(args.device) critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr) diff --git a/examples/pong_a2c.py b/examples/pong_a2c.py index c6424aba1..544153ef6 100644 --- a/examples/pong_a2c.py +++ b/examples/pong_a2c.py @@ -10,7 +10,7 @@ from tianshou.data import Collector, 
ReplayBuffer from tianshou.env.atari import create_atari_environment -from tianshou.utils.net.discrete import ActorHead, CriticHead +from tianshou.utils.net.discrete import Actor, Critic from tianshou.utils.net.common import Net @@ -66,8 +66,8 @@ def test_a2c(args=get_args()): test_envs.seed(args.seed) # model net = Net(args.layer_num, args.state_shape, device=args.device) - actor = ActorHead(net, args.action_shape).to(args.device) - critic = CriticHead(net).to(args.device) + actor = Actor(net, args.action_shape).to(args.device) + critic = Critic(net).to(args.device) optim = torch.optim.Adam(list( actor.parameters()) + list(critic.parameters()), lr=args.lr) dist = torch.distributions.Categorical diff --git a/examples/pong_ppo.py b/examples/pong_ppo.py index e5d4fcab1..e47534456 100644 --- a/examples/pong_ppo.py +++ b/examples/pong_ppo.py @@ -9,7 +9,7 @@ from tianshou.trainer import onpolicy_trainer from tianshou.data import Collector, ReplayBuffer from tianshou.env.atari import create_atari_environment -from tianshou.utils.net.discrete import ActorHead, CriticHead +from tianshou.utils.net.discrete import Actor, Critic from tianshou.utils.net.common import Net @@ -63,8 +63,8 @@ def test_ppo(args=get_args()): test_envs.seed(args.seed) # model net = Net(args.layer_num, args.state_shape, device=args.device) - actor = ActorHead(net, args.action_shape).to(args.device) - critic = CriticHead(net).to(args.device) + actor = Actor(net, args.action_shape).to(args.device) + critic = Critic(net).to(args.device) optim = torch.optim.Adam(list( actor.parameters()) + list(critic.parameters()), lr=args.lr) dist = torch.distributions.Categorical diff --git a/examples/sac_mcc.py b/examples/sac_mcc.py index 28b305fa3..4b3b36b39 100644 --- a/examples/sac_mcc.py +++ b/examples/sac_mcc.py @@ -12,7 +12,7 @@ from tianshou.env import VectorEnv from tianshou.exploration import OUNoise from tianshou.utils.net.common import Net -from tianshou.utils.net.continuous import ActorHeadProb, CriticHead +from tianshou.utils.net.continuous import ActorProb, Critic def get_args(): @@ -63,17 +63,17 @@ def test_sac(args=get_args()): test_envs.seed(args.seed) # model net = Net(args.layer_num, args.state_shape, device=args.device) - actor = ActorHeadProb( + actor = ActorProb( net, args.action_shape, args.max_action, args.device, unbounded=True ).to(args.device) actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr) net = Net(args.layer_num, args.state_shape, args.action_shape, concat=True) - critic1 = CriticHead( + critic1 = Critic( net, args.device ).to(args.device) critic1_optim = torch.optim.Adam(critic1.parameters(), lr=args.critic_lr) - critic2 = CriticHead( + critic2 = Critic( net, args.device ).to(args.device) critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr) diff --git a/test/continuous/test_ddpg.py b/test/continuous/test_ddpg.py index aee7e04a6..f1d7c4f76 100644 --- a/test/continuous/test_ddpg.py +++ b/test/continuous/test_ddpg.py @@ -12,7 +12,7 @@ from tianshou.data import Collector, ReplayBuffer from tianshou.exploration import GaussianNoise from tianshou.utils.net.common import Net -from tianshou.utils.net.continuous import ActorHead, CriticHead +from tianshou.utils.net.continuous import Actor, Critic def get_args(): @@ -67,13 +67,13 @@ def test_ddpg(args=get_args()): test_envs.seed(args.seed) # model net = Net(args.layer_num, args.state_shape, device=args.device) - actor = ActorHead( + actor = Actor( net, args.action_shape, args.max_action, args.device ).to(args.device) actor_optim = 
torch.optim.Adam(actor.parameters(), lr=args.actor_lr) net = Net(args.layer_num, args.state_shape, args.action_shape, concat=True) - critic = CriticHead( + critic = Critic( net, args.device ).to(args.device) critic_optim = torch.optim.Adam(critic.parameters(), lr=args.critic_lr) diff --git a/test/continuous/test_ppo.py b/test/continuous/test_ppo.py index 6d90b614f..d81b826af 100644 --- a/test/continuous/test_ppo.py +++ b/test/continuous/test_ppo.py @@ -12,7 +12,7 @@ from tianshou.trainer import onpolicy_trainer from tianshou.data import Collector, ReplayBuffer from tianshou.utils.net.common import Net -from tianshou.utils.net.continuous import ActorHeadProb, CriticHead +from tianshou.utils.net.continuous import ActorProb, Critic def get_args(): @@ -70,11 +70,11 @@ def test_ppo(args=get_args()): test_envs.seed(args.seed) # model net = Net(args.layer_num, args.state_shape, device=args.device) - actor = ActorHeadProb( + actor = ActorProb( net, args.action_shape, args.max_action, args.device ).to(args.device) - critic = CriticHead( + critic = Critic( Net(args.layer_num, args.state_shape), device=args.device ).to(args.device) # orthogonal initialization diff --git a/test/continuous/test_sac_with_il.py b/test/continuous/test_sac_with_il.py index c01e59e4a..ccd038d5b 100644 --- a/test/continuous/test_sac_with_il.py +++ b/test/continuous/test_sac_with_il.py @@ -11,7 +11,7 @@ from tianshou.data import Collector, ReplayBuffer from tianshou.policy import SACPolicy, ImitationPolicy from tianshou.utils.net.common import Net -from tianshou.utils.net.continuous import ActorHead, ActorHeadProb, CriticHead +from tianshou.utils.net.continuous import Actor, ActorProb, Critic def get_args(): @@ -66,17 +66,17 @@ def test_sac_with_il(args=get_args()): test_envs.seed(args.seed) # model net = Net(args.layer_num, args.state_shape, device=args.device) - actor = ActorHeadProb( + actor = ActorProb( net, args.action_shape, args.max_action, args.device ).to(args.device) actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr) net = Net(args.layer_num, args.state_shape, args.action_shape, concat=True) - critic1 = CriticHead( + critic1 = Critic( net, args.device ).to(args.device) critic1_optim = torch.optim.Adam(critic1.parameters(), lr=args.critic_lr) - critic2 = CriticHead( + critic2 = Critic( net, args.device ).to(args.device) critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr) @@ -121,8 +121,8 @@ def stop_fn(x): # here we define an imitation collector with a trivial policy if args.task == 'Pendulum-v0': env.spec.reward_threshold = -300 # lower the goal - net = ActorHead(Net(1, args.state_shape), args.action_shape, - args.max_action, args.device).to(args.device) + net = Actor(Net(1, args.state_shape), args.action_shape, + args.max_action, args.device).to(args.device) optim = torch.optim.Adam(net.parameters(), lr=args.il_lr) il_policy = ImitationPolicy(net, optim, mode='continuous') il_test_collector = Collector(il_policy, test_envs) diff --git a/test/continuous/test_td3.py b/test/continuous/test_td3.py index 8377885a2..49493b43d 100644 --- a/test/continuous/test_td3.py +++ b/test/continuous/test_td3.py @@ -12,7 +12,7 @@ from tianshou.data import Collector, ReplayBuffer from tianshou.exploration import GaussianNoise from tianshou.utils.net.common import Net -from tianshou.utils.net.continuous import ActorHead, CriticHead +from tianshou.utils.net.continuous import Actor, Critic def get_args(): @@ -69,17 +69,17 @@ def test_td3(args=get_args()): test_envs.seed(args.seed) # model net = 
Net(args.layer_num, args.state_shape, device=args.device) - actor = ActorHead( + actor = Actor( net, args.action_shape, args.max_action, args.device ).to(args.device) actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr) net = Net(args.layer_num, args.state_shape, args.action_shape, concat=True) - critic1 = CriticHead( + critic1 = Critic( net, args.device ).to(args.device) critic1_optim = torch.optim.Adam(critic1.parameters(), lr=args.critic_lr) - critic2 = CriticHead( + critic2 = Critic( net, args.device ).to(args.device) critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr) diff --git a/test/discrete/test_a2c_with_il.py b/test/discrete/test_a2c_with_il.py index 427e3d2a0..365fb1234 100644 --- a/test/discrete/test_a2c_with_il.py +++ b/test/discrete/test_a2c_with_il.py @@ -10,7 +10,7 @@ from tianshou.data import Collector, ReplayBuffer from tianshou.policy import A2CPolicy, ImitationPolicy from tianshou.trainer import onpolicy_trainer, offpolicy_trainer -from tianshou.utils.net.discrete import ActorHead, CriticHead +from tianshou.utils.net.discrete import Actor, Critic from tianshou.utils.net.common import Net @@ -64,8 +64,8 @@ def test_a2c_with_il(args=get_args()): test_envs.seed(args.seed) # model net = Net(args.layer_num, args.state_shape, device=args.device) - actor = ActorHead(net, args.action_shape).to(args.device) - critic = CriticHead(net).to(args.device) + actor = Actor(net, args.action_shape).to(args.device) + critic = Critic(net).to(args.device) optim = torch.optim.Adam(list( actor.parameters()) + list(critic.parameters()), lr=args.lr) dist = torch.distributions.Categorical @@ -108,7 +108,7 @@ def stop_fn(x): if args.task == 'CartPole-v0': env.spec.reward_threshold = 190 # lower the goal net = Net(1, args.state_shape, device=args.device) - net = ActorHead(net, args.action_shape).to(args.device) + net = Actor(net, args.action_shape).to(args.device) optim = torch.optim.Adam(net.parameters(), lr=args.il_lr) il_policy = ImitationPolicy(net, optim, mode='discrete') il_test_collector = Collector(il_policy, test_envs) diff --git a/test/discrete/test_ppo.py b/test/discrete/test_ppo.py index 1000e3dfb..ca0e87930 100644 --- a/test/discrete/test_ppo.py +++ b/test/discrete/test_ppo.py @@ -10,7 +10,7 @@ from tianshou.policy import PPOPolicy from tianshou.trainer import onpolicy_trainer from tianshou.data import Collector, ReplayBuffer -from tianshou.utils.net.discrete import ActorHead, CriticHead +from tianshou.utils.net.discrete import Actor, Critic from tianshou.utils.net.common import Net @@ -66,8 +66,8 @@ def test_ppo(args=get_args()): test_envs.seed(args.seed) # model net = Net(args.layer_num, args.state_shape, device=args.device) - actor = ActorHead(net, args.action_shape).to(args.device) - critic = CriticHead(net).to(args.device) + actor = Actor(net, args.action_shape).to(args.device) + critic = Critic(net).to(args.device) # orthogonal initialization for m in list(actor.modules()) + list(critic.modules()): if isinstance(m, torch.nn.Linear): diff --git a/tianshou/utils/net/continuous.py b/tianshou/utils/net/continuous.py index 0f2ef87b4..a3ca82e3e 100644 --- a/tianshou/utils/net/continuous.py +++ b/tianshou/utils/net/continuous.py @@ -9,7 +9,7 @@ from tianshou.data import to_torch -class ActorHead(nn.Module): +class Actor(nn.Module): def __init__(self, preprocess_net, action_shape, max_action, device='cpu'): super().__init__() @@ -23,7 +23,7 @@ def forward(self, s, state=None, info={}): return logits, h -class CriticHead(nn.Module): +class 
Critic(nn.Module): def __init__(self, preprocess_net, device='cpu'): super().__init__() self.device = device @@ -43,7 +43,7 @@ def forward(self, s, a=None, **kwargs): return logits -class ActorHeadProb(nn.Module): +class ActorProb(nn.Module): def __init__(self, preprocess_net, action_shape, max_action, device='cpu', unbounded=False): super().__init__() diff --git a/tianshou/utils/net/discrete.py b/tianshou/utils/net/discrete.py index fff8c2a59..5f3f1e29c 100644 --- a/tianshou/utils/net/discrete.py +++ b/tianshou/utils/net/discrete.py @@ -8,7 +8,7 @@ import torch.nn.functional as F -class ActorHead(nn.Module): +class Actor(nn.Module): def __init__(self, preprocess_net, action_shape): super().__init__() self.preprocess = preprocess_net @@ -20,7 +20,7 @@ def forward(self, s, state=None, info={}): return logits, h -class CriticHead(nn.Module): +class Critic(nn.Module): def __init__(self, preprocess_net): super().__init__() self.preprocess = preprocess_net From 185f5dc07bf874d6be63307b6f010f9b8e2d7ad4 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Thu, 9 Jul 2020 20:33:39 +0800 Subject: [PATCH 04/11] change torch.float to torch.float32 --- README.md | 2 +- docs/tutorials/dqn.rst | 2 +- tianshou/policy/imitation/base.py | 2 +- tianshou/utils/net/common.py | 4 ++-- tianshou/utils/net/continuous.py | 10 +++++----- tianshou/utils/net/discrete.py | 2 +- 6 files changed, 11 insertions(+), 11 deletions(-) diff --git a/README.md b/README.md index 150489afe..63332ba2c 100644 --- a/README.md +++ b/README.md @@ -217,7 +217,7 @@ class Net(nn.Module): ]) def forward(self, s, state=None, info={}): if not isinstance(s, torch.Tensor): - s = torch.tensor(s, dtype=torch.float) + s = torch.tensor(s, dtype=torch.float32) batch = s.shape[0] logits = self.model(s.view(batch, -1)) return logits, state diff --git a/docs/tutorials/dqn.rst b/docs/tutorials/dqn.rst index 0d430316f..7c3ef10c4 100644 --- a/docs/tutorials/dqn.rst +++ b/docs/tutorials/dqn.rst @@ -64,7 +64,7 @@ Tianshou supports any user-defined PyTorch networks and optimizers but with the ]) def forward(self, obs, state=None, info={}): if not isinstance(obs, torch.Tensor): - obs = torch.tensor(obs, dtype=torch.float) + obs = torch.tensor(obs, dtype=torch.float32) batch = obs.shape[0] logits = self.model(obs.view(batch, -1)) return logits, state diff --git a/tianshou/policy/imitation/base.py b/tianshou/policy/imitation/base.py index aeb0eb3ba..57bdba933 100644 --- a/tianshou/policy/imitation/base.py +++ b/tianshou/policy/imitation/base.py @@ -46,7 +46,7 @@ def learn(self, batch: Batch, **kwargs) -> Dict[str, float]: self.optim.zero_grad() if self.mode == 'continuous': a = self(batch).act - a_ = to_torch(batch.act, dtype=torch.float, device=a.device) + a_ = to_torch(batch.act, dtype=torch.float32, device=a.device) loss = F.mse_loss(a, a_) elif self.mode == 'discrete': # classification a = self(batch).logits diff --git a/tianshou/utils/net/common.py b/tianshou/utils/net/common.py index 61984c536..36ca9f81c 100644 --- a/tianshou/utils/net/common.py +++ b/tianshou/utils/net/common.py @@ -34,7 +34,7 @@ def __init__(self, layer_num, state_shape, action_shape=0, device='cpu', self.model = nn.Sequential(*self.model) def forward(self, s, state=None, info={}): - s = to_torch(s, device=self.device, dtype=torch.float) + s = to_torch(s, device=self.device, dtype=torch.float32) batch = s.shape[0] s = s.view(batch, -1) logits = self.model(s) @@ -53,7 +53,7 @@ def __init__(self, layer_num, state_shape, action_shape, device='cpu'): self.fc2 = nn.Linear(128, 
np.prod(action_shape)) def forward(self, s, state=None, info={}): - s = to_torch(s, device=self.device, dtype=torch.float) + s = to_torch(s, device=self.device, dtype=torch.float32) # s [bsz, len, dim] (training) or [bsz, dim] (evaluation) # In short, the tensor's shape in training phase is longer than which # in evaluation phase. diff --git a/tianshou/utils/net/continuous.py b/tianshou/utils/net/continuous.py index a3ca82e3e..e6f94264b 100644 --- a/tianshou/utils/net/continuous.py +++ b/tianshou/utils/net/continuous.py @@ -31,11 +31,11 @@ def __init__(self, preprocess_net, device='cpu'): self.last = nn.Linear(128, 1) def forward(self, s, a=None, **kwargs): - s = to_torch(s, device=self.device, dtype=torch.float) + s = to_torch(s, device=self.device, dtype=torch.float32) batch = s.shape[0] s = s.view(batch, -1) if a is not None: - a = to_torch(a, device=self.device, dtype=torch.float) + a = to_torch(a, device=self.device, dtype=torch.float32) a = a.view(batch, -1) s = torch.cat([s, a], dim=1) logits, h = self.preprocess(s) @@ -76,7 +76,7 @@ def __init__(self, layer_num, state_shape, action_shape, self.sigma = nn.Parameter(torch.zeros(np.prod(action_shape), 1)) def forward(self, s, **kwargs): - s = to_torch(s, device=self.device, dtype=torch.float) + s = to_torch(s, device=self.device, dtype=torch.float32) # s [bsz, len, dim] (training) or [bsz, dim] (evaluation) # In short, the tensor's shape in training phase is longer than which # in evaluation phase. @@ -106,7 +106,7 @@ def __init__(self, layer_num, state_shape, action_shape=0, device='cpu'): self.fc2 = nn.Linear(128 + np.prod(action_shape), 1) def forward(self, s, a=None): - s = to_torch(s, device=self.device, dtype=torch.float) + s = to_torch(s, device=self.device, dtype=torch.float32) # s [bsz, len, dim] (training) or [bsz, dim] (evaluation) # In short, the tensor's shape in training phase is longer than which # in evaluation phase. 
@@ -116,7 +116,7 @@ def forward(self, s, a=None): s = s[:, -1] if a is not None: if not isinstance(a, torch.Tensor): - a = torch.tensor(a, device=self.device, dtype=torch.float) + a = torch.tensor(a, device=self.device, dtype=torch.float32) s = torch.cat([s, a], dim=1) s = self.fc2(s) return s diff --git a/tianshou/utils/net/discrete.py b/tianshou/utils/net/discrete.py index 5f3f1e29c..d291c11a4 100644 --- a/tianshou/utils/net/discrete.py +++ b/tianshou/utils/net/discrete.py @@ -56,7 +56,7 @@ def conv2d_size_out(size, kernel_size=5, stride=2): def forward(self, x, state=None, info={}): if not isinstance(x, torch.Tensor): - x = torch.tensor(x, device=self.device, dtype=torch.float) + x = torch.tensor(x, device=self.device, dtype=torch.float32) x = x.permute(0, 3, 1, 2) x = F.relu(self.bn1(self.conv1(x))) x = F.relu(self.bn2(self.conv2(x))) From ba7c184b83bbdc9ac960aaab99b9835f78d1c42e Mon Sep 17 00:00:00 2001 From: youkaichao Date: Thu, 9 Jul 2020 21:07:29 +0800 Subject: [PATCH 05/11] use flatten(1) instead of view(batch, -1) --- README.md | 3 +-- docs/tutorials/dqn.rst | 3 +-- tianshou/utils/net/common.py | 11 ++--------- tianshou/utils/net/continuous.py | 11 +++-------- 4 files changed, 7 insertions(+), 21 deletions(-) diff --git a/README.md b/README.md index 63332ba2c..32cb5e90d 100644 --- a/README.md +++ b/README.md @@ -218,8 +218,7 @@ class Net(nn.Module): def forward(self, s, state=None, info={}): if not isinstance(s, torch.Tensor): s = torch.tensor(s, dtype=torch.float32) - batch = s.shape[0] - logits = self.model(s.view(batch, -1)) + logits = self.model(s.flatten(1)) return logits, state env = gym.make(task) diff --git a/docs/tutorials/dqn.rst b/docs/tutorials/dqn.rst index 7c3ef10c4..7bf48a4bf 100644 --- a/docs/tutorials/dqn.rst +++ b/docs/tutorials/dqn.rst @@ -65,8 +65,7 @@ Tianshou supports any user-defined PyTorch networks and optimizers but with the def forward(self, obs, state=None, info={}): if not isinstance(obs, torch.Tensor): obs = torch.tensor(obs, dtype=torch.float32) - batch = obs.shape[0] - logits = self.model(obs.view(batch, -1)) + logits = self.model(obs.flatten(1)) return logits, state state_shape = env.observation_space.shape or env.observation_space.n diff --git a/tianshou/utils/net/common.py b/tianshou/utils/net/common.py index 36ca9f81c..9287b747d 100644 --- a/tianshou/utils/net/common.py +++ b/tianshou/utils/net/common.py @@ -35,8 +35,7 @@ def __init__(self, layer_num, state_shape, action_shape=0, device='cpu', def forward(self, s, state=None, info={}): s = to_torch(s, device=self.device, dtype=torch.float32) - batch = s.shape[0] - s = s.view(batch, -1) + s = s.flatten(1) logits = self.model(s) return logits, state @@ -57,13 +56,7 @@ def forward(self, s, state=None, info={}): # s [bsz, len, dim] (training) or [bsz, dim] (evaluation) # In short, the tensor's shape in training phase is longer than which # in evaluation phase. 
- if len(s.shape) == 2: - bsz, dim = s.shape - length = 1 - else: - bsz, length, dim = s.shape - s = self.fc1(s.view([bsz * length, dim])) - s = s.view(bsz, length, -1) + s = self.fc1(s) self.nn.flatten_parameters() if state is None: s, (h, c) = self.nn(s) diff --git a/tianshou/utils/net/continuous.py b/tianshou/utils/net/continuous.py index e6f94264b..8e03cb574 100644 --- a/tianshou/utils/net/continuous.py +++ b/tianshou/utils/net/continuous.py @@ -32,11 +32,10 @@ def __init__(self, preprocess_net, device='cpu'): def forward(self, s, a=None, **kwargs): s = to_torch(s, device=self.device, dtype=torch.float32) - batch = s.shape[0] - s = s.view(batch, -1) + s = s.flatten(1) if a is not None: a = to_torch(a, device=self.device, dtype=torch.float32) - a = a.view(batch, -1) + a = a.flatten(1) s = torch.cat([s, a], dim=1) logits, h = self.preprocess(s) logits = self.last(logits) @@ -81,11 +80,7 @@ def forward(self, s, **kwargs): # In short, the tensor's shape in training phase is longer than which # in evaluation phase. if len(s.shape) == 2: - bsz, dim = s.shape - length = 1 - else: - bsz, length, dim = s.shape - s = s.view(bsz, length, -1) + s = s.unsqueeze(-2) logits, _ = self.nn(s) logits = logits[:, -1] mu = self.mu(logits) From 1474a6ce692b43b796886a740802e878ab210634 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Thu, 9 Jul 2020 21:18:09 +0800 Subject: [PATCH 06/11] remove dummy net in docs --- README.md | 17 ++--------------- docs/tutorials/dqn.rst | 20 +++----------------- 2 files changed, 5 insertions(+), 32 deletions(-) diff --git a/README.md b/README.md index 32cb5e90d..fbc8bbbcd 100644 --- a/README.md +++ b/README.md @@ -206,25 +206,12 @@ test_envs = ts.env.VectorEnv([lambda: gym.make(task) for _ in range(test_num)]) Define the network: ```python -class Net(nn.Module): - def __init__(self, state_shape, action_shape): - super().__init__() - self.model = nn.Sequential(*[ - nn.Linear(np.prod(state_shape), 128), nn.ReLU(inplace=True), - nn.Linear(128, 128), nn.ReLU(inplace=True), - nn.Linear(128, 128), nn.ReLU(inplace=True), - nn.Linear(128, np.prod(action_shape)) - ]) - def forward(self, s, state=None, info={}): - if not isinstance(s, torch.Tensor): - s = torch.tensor(s, dtype=torch.float32) - logits = self.model(s.flatten(1)) - return logits, state +from tianshou.utils.net.common import Net env = gym.make(task) state_shape = env.observation_space.shape or env.observation_space.n action_shape = env.action_space.shape or env.action_space.n -net = Net(state_shape, action_shape) +net = Net(layer_num=2, state_shape=state_shape, action_shape=action_shape) optim = torch.optim.Adam(net.parameters(), lr=lr) ``` diff --git a/docs/tutorials/dqn.rst b/docs/tutorials/dqn.rst index 7bf48a4bf..6e7199d58 100644 --- a/docs/tutorials/dqn.rst +++ b/docs/tutorials/dqn.rst @@ -50,27 +50,13 @@ Build the Network Tianshou supports any user-defined PyTorch networks and optimizers but with the limitation of input and output API. 
Here is an example code: :: - import torch, numpy as np from torch import nn - - class Net(nn.Module): - def __init__(self, state_shape, action_shape): - super().__init__() - self.model = nn.Sequential(*[ - nn.Linear(np.prod(state_shape), 128), nn.ReLU(inplace=True), - nn.Linear(128, 128), nn.ReLU(inplace=True), - nn.Linear(128, 128), nn.ReLU(inplace=True), - nn.Linear(128, np.prod(action_shape)) - ]) - def forward(self, obs, state=None, info={}): - if not isinstance(obs, torch.Tensor): - obs = torch.tensor(obs, dtype=torch.float32) - logits = self.model(obs.flatten(1)) - return logits, state + import torch, numpy as np + from tianshou.utils.net.common import Net state_shape = env.observation_space.shape or env.observation_space.n action_shape = env.action_space.shape or env.action_space.n - net = Net(state_shape, action_shape) + net = Net(layer_num=2, state_shape=state_shape, action_shape=action_shape) optim = torch.optim.Adam(net.parameters(), lr=1e-3) The rules of self-defined networks are: From 6d0f9afec58a5412fc0a86a7c14b70b2164f24be Mon Sep 17 00:00:00 2001 From: youkaichao Date: Thu, 9 Jul 2020 21:23:20 +0800 Subject: [PATCH 07/11] bugfix for rnn --- tianshou/utils/net/common.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tianshou/utils/net/common.py b/tianshou/utils/net/common.py index 9287b747d..478952c85 100644 --- a/tianshou/utils/net/common.py +++ b/tianshou/utils/net/common.py @@ -56,6 +56,8 @@ def forward(self, s, state=None, info={}): # s [bsz, len, dim] (training) or [bsz, dim] (evaluation) # In short, the tensor's shape in training phase is longer than which # in evaluation phase. + if len(s.shape) == 2: + s = s.unsqueeze(-2) s = self.fc1(s) self.nn.flatten_parameters() if state is None: From cd722c8993c380834e58b74891adba013a4ef6c7 Mon Sep 17 00:00:00 2001 From: Trinkle23897 <463003665@qq.com> Date: Thu, 9 Jul 2020 22:20:43 +0800 Subject: [PATCH 08/11] fix cuda error --- examples/ant_v2_ddpg.py | 3 ++- examples/ant_v2_sac.py | 11 ++++------- examples/ant_v2_td3.py | 11 ++++------- examples/halfcheetahBullet_v0_sac.py | 11 ++++------- examples/point_maze_td3.py | 11 ++++------- examples/sac_mcc.py | 11 ++++------- test/continuous/test_ddpg.py | 7 +++---- test/continuous/test_ppo.py | 6 +++--- test/continuous/test_sac_with_il.py | 16 +++++++--------- test/continuous/test_td3.py | 11 ++++------- test/discrete/test_dqn.py | 4 ++-- test/discrete/test_drqn.py | 3 +-- test/discrete/test_pdqn.py | 4 ++-- test/discrete/test_pg.py | 3 +-- 14 files changed, 45 insertions(+), 67 deletions(-) diff --git a/examples/ant_v2_ddpg.py b/examples/ant_v2_ddpg.py index 3a5fd13a7..6b9ba0a53 100644 --- a/examples/ant_v2_ddpg.py +++ b/examples/ant_v2_ddpg.py @@ -61,7 +61,8 @@ def test_ddpg(args=get_args()): actor = Actor(net, args.action_shape, args.max_action, args.device).to(args.device) actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr) - net = Net(args.layer_num, args.state_shape, args.action_shape, concat=True) + net = Net(args.layer_num, args.state_shape, + args.action_shape, concat=True, device=args.device) critic = Critic(net, args.device).to(args.device) critic_optim = torch.optim.Adam(critic.parameters(), lr=args.critic_lr) policy = DDPGPolicy( diff --git a/examples/ant_v2_sac.py b/examples/ant_v2_sac.py index 02f5a3b2e..cdfc8138f 100644 --- a/examples/ant_v2_sac.py +++ b/examples/ant_v2_sac.py @@ -64,14 +64,11 @@ def test_sac(args=get_args()): args.max_action, args.device, unbounded=True ).to(args.device) actor_optim = torch.optim.Adam(actor.parameters(), 
lr=args.actor_lr) - net = Net(args.layer_num, args.state_shape, args.action_shape, concat=True) - critic1 = Critic( - net, args.device - ).to(args.device) + net = Net(args.layer_num, args.state_shape, + args.action_shape, concat=True, device=args.device) + critic1 = Critic(net, args.device).to(args.device) critic1_optim = torch.optim.Adam(critic1.parameters(), lr=args.critic_lr) - critic2 = Critic( - net, args.device - ).to(args.device) + critic2 = Critic(net, args.device).to(args.device) critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr) policy = SACPolicy( actor, actor_optim, critic1, critic1_optim, critic2, critic2_optim, diff --git a/examples/ant_v2_td3.py b/examples/ant_v2_td3.py index d76d6b63f..495891bd4 100644 --- a/examples/ant_v2_td3.py +++ b/examples/ant_v2_td3.py @@ -66,14 +66,11 @@ def test_td3(args=get_args()): args.max_action, args.device ).to(args.device) actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr) - net = Net(args.layer_num, args.state_shape, args.action_shape, concat=True) - critic1 = Critic( - net, args.device - ).to(args.device) + net = Net(args.layer_num, args.state_shape, + args.action_shape, concat=True, device=args.device) + critic1 = Critic(net, args.device).to(args.device) critic1_optim = torch.optim.Adam(critic1.parameters(), lr=args.critic_lr) - critic2 = Critic( - net, args.device - ).to(args.device) + critic2 = Critic(net, args.device).to(args.device) critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr) policy = TD3Policy( actor, actor_optim, critic1, critic1_optim, critic2, critic2_optim, diff --git a/examples/halfcheetahBullet_v0_sac.py b/examples/halfcheetahBullet_v0_sac.py index 56694b075..3da77ccaa 100644 --- a/examples/halfcheetahBullet_v0_sac.py +++ b/examples/halfcheetahBullet_v0_sac.py @@ -72,14 +72,11 @@ def test_sac(args=get_args()): args.max_action, args.device, unbounded=True ).to(args.device) actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr) - net = Net(args.layer_num, args.state_shape, args.action_shape, concat=True) - critic1 = Critic( - net, args.device - ).to(args.device) + net = Net(args.layer_num, args.state_shape, + args.action_shape, concat=True, device=args.device) + critic1 = Critic(net, args.device).to(args.device) critic1_optim = torch.optim.Adam(critic1.parameters(), lr=args.critic_lr) - critic2 = Critic( - net, args.device - ).to(args.device) + critic2 = Critic(net, args.device).to(args.device) critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr) policy = SACPolicy( actor, actor_optim, critic1, critic1_optim, critic2, critic2_optim, diff --git a/examples/point_maze_td3.py b/examples/point_maze_td3.py index 978fcd52d..ce045993f 100644 --- a/examples/point_maze_td3.py +++ b/examples/point_maze_td3.py @@ -70,14 +70,11 @@ def test_td3(args=get_args()): args.max_action, args.device ).to(args.device) actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr) - net = Net(args.layer_num, args.state_shape, args.action_shape, concat=True) - critic1 = Critic( - net, args.device - ).to(args.device) + net = Net(args.layer_num, args.state_shape, + args.action_shape, concat=True, device=args.device) + critic1 = Critic(net, args.device).to(args.device) critic1_optim = torch.optim.Adam(critic1.parameters(), lr=args.critic_lr) - critic2 = Critic( - net, args.device - ).to(args.device) + critic2 = Critic(net, args.device).to(args.device) critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr) policy = TD3Policy( 
actor, actor_optim, critic1, critic1_optim, critic2, critic2_optim, diff --git a/examples/sac_mcc.py b/examples/sac_mcc.py index 4b3b36b39..fcd2ce447 100644 --- a/examples/sac_mcc.py +++ b/examples/sac_mcc.py @@ -68,14 +68,11 @@ def test_sac(args=get_args()): args.max_action, args.device, unbounded=True ).to(args.device) actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr) - net = Net(args.layer_num, args.state_shape, args.action_shape, concat=True) - critic1 = Critic( - net, args.device - ).to(args.device) + net = Net(args.layer_num, args.state_shape, + args.action_shape, concat=True, device=args.device) + critic1 = Critic(net, args.device).to(args.device) critic1_optim = torch.optim.Adam(critic1.parameters(), lr=args.critic_lr) - critic2 = Critic( - net, args.device - ).to(args.device) + critic2 = Critic(net, args.device).to(args.device) critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr) if args.auto_alpha: diff --git a/test/continuous/test_ddpg.py b/test/continuous/test_ddpg.py index f1d7c4f76..6f078bcaa 100644 --- a/test/continuous/test_ddpg.py +++ b/test/continuous/test_ddpg.py @@ -72,10 +72,9 @@ def test_ddpg(args=get_args()): args.max_action, args.device ).to(args.device) actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr) - net = Net(args.layer_num, args.state_shape, args.action_shape, concat=True) - critic = Critic( - net, args.device - ).to(args.device) + net = Net(args.layer_num, args.state_shape, + args.action_shape, concat=True, device=args.device) + critic = Critic(net, args.device).to(args.device) critic_optim = torch.optim.Adam(critic.parameters(), lr=args.critic_lr) policy = DDPGPolicy( actor, actor_optim, critic, critic_optim, diff --git a/test/continuous/test_ppo.py b/test/continuous/test_ppo.py index d81b826af..daa7b0661 100644 --- a/test/continuous/test_ppo.py +++ b/test/continuous/test_ppo.py @@ -74,9 +74,9 @@ def test_ppo(args=get_args()): net, args.action_shape, args.max_action, args.device ).to(args.device) - critic = Critic( - Net(args.layer_num, args.state_shape), device=args.device - ).to(args.device) + critic = Critic(Net( + args.layer_num, args.state_shape, device=args.device + ), device=args.device).to(args.device) # orthogonal initialization for m in list(actor.modules()) + list(critic.modules()): if isinstance(m, torch.nn.Linear): diff --git a/test/continuous/test_sac_with_il.py b/test/continuous/test_sac_with_il.py index ccd038d5b..96e034054 100644 --- a/test/continuous/test_sac_with_il.py +++ b/test/continuous/test_sac_with_il.py @@ -71,14 +71,11 @@ def test_sac_with_il(args=get_args()): args.max_action, args.device ).to(args.device) actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr) - net = Net(args.layer_num, args.state_shape, args.action_shape, concat=True) - critic1 = Critic( - net, args.device - ).to(args.device) + net = Net(args.layer_num, args.state_shape, + args.action_shape, concat=True, device=args.device) + critic1 = Critic(net, args.device).to(args.device) critic1_optim = torch.optim.Adam(critic1.parameters(), lr=args.critic_lr) - critic2 = Critic( - net, args.device - ).to(args.device) + critic2 = Critic(net, args.device).to(args.device) critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr) policy = SACPolicy( actor, actor_optim, critic1, critic1_optim, critic2, critic2_optim, @@ -121,8 +118,9 @@ def stop_fn(x): # here we define an imitation collector with a trivial policy if args.task == 'Pendulum-v0': env.spec.reward_threshold = -300 # lower the 
goal - net = Actor(Net(1, args.state_shape), args.action_shape, - args.max_action, args.device).to(args.device) + net = Actor(Net(1, args.state_shape, device=args.device), + args.action_shape, args.max_action, args.device + ).to(args.device) optim = torch.optim.Adam(net.parameters(), lr=args.il_lr) il_policy = ImitationPolicy(net, optim, mode='continuous') il_test_collector = Collector(il_policy, test_envs) diff --git a/test/continuous/test_td3.py b/test/continuous/test_td3.py index 49493b43d..096290b6b 100644 --- a/test/continuous/test_td3.py +++ b/test/continuous/test_td3.py @@ -74,14 +74,11 @@ def test_td3(args=get_args()): args.max_action, args.device ).to(args.device) actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr) - net = Net(args.layer_num, args.state_shape, args.action_shape, concat=True) - critic1 = Critic( - net, args.device - ).to(args.device) + net = Net(args.layer_num, args.state_shape, + args.action_shape, concat=True, device=args.device) + critic1 = Critic(net, args.device).to(args.device) critic1_optim = torch.optim.Adam(critic1.parameters(), lr=args.critic_lr) - critic2 = Critic( - net, args.device - ).to(args.device) + critic2 = Critic(net, args.device).to(args.device) critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr) policy = TD3Policy( actor, actor_optim, critic1, critic1_optim, critic2, critic2_optim, diff --git a/test/discrete/test_dqn.py b/test/discrete/test_dqn.py index 1863f5b6a..96ddb70e2 100644 --- a/test/discrete/test_dqn.py +++ b/test/discrete/test_dqn.py @@ -57,8 +57,8 @@ def test_dqn(args=get_args()): train_envs.seed(args.seed) test_envs.seed(args.seed) # model - net = Net(args.layer_num, args.state_shape, args.action_shape, args.device) - net = net.to(args.device) + net = Net(args.layer_num, args.state_shape, + args.action_shape, args.device).to(args.device) optim = torch.optim.Adam(net.parameters(), lr=args.lr) policy = DQNPolicy( net, optim, args.gamma, args.n_step, diff --git a/test/discrete/test_drqn.py b/test/discrete/test_drqn.py index 627a14da9..f0f34ed4c 100644 --- a/test/discrete/test_drqn.py +++ b/test/discrete/test_drqn.py @@ -59,8 +59,7 @@ def test_drqn(args=get_args()): test_envs.seed(args.seed) # model net = Recurrent(args.layer_num, args.state_shape, - args.action_shape, args.device) - net = net.to(args.device) + args.action_shape, args.device).to(args.device) optim = torch.optim.Adam(net.parameters(), lr=args.lr) policy = DQNPolicy( net, optim, args.gamma, args.n_step, diff --git a/test/discrete/test_pdqn.py b/test/discrete/test_pdqn.py index 98a8a8899..22fa34764 100644 --- a/test/discrete/test_pdqn.py +++ b/test/discrete/test_pdqn.py @@ -60,8 +60,8 @@ def test_pdqn(args=get_args()): train_envs.seed(args.seed) test_envs.seed(args.seed) # model - net = Net(args.layer_num, args.state_shape, args.action_shape, args.device) - net = net.to(args.device) + net = Net(args.layer_num, args.state_shape, + args.action_shape, args.device).to(args.device) optim = torch.optim.Adam(net.parameters(), lr=args.lr) policy = DQNPolicy( net, optim, args.gamma, args.n_step, diff --git a/test/discrete/test_pg.py b/test/discrete/test_pg.py index 8f8f5dc65..c82076204 100644 --- a/test/discrete/test_pg.py +++ b/test/discrete/test_pg.py @@ -125,8 +125,7 @@ def test_pg(args=get_args()): # model net = Net( args.layer_num, args.state_shape, args.action_shape, - device=args.device, softmax=True) - net = net.to(args.device) + device=args.device, softmax=True).to(args.device) optim = torch.optim.Adam(net.parameters(), lr=args.lr) 
     dist = torch.distributions.Categorical
     policy = PGPolicy(net, optim, dist, args.gamma,

From b9a314eaa2e514779a0b2290e31c5fb4376928aa Mon Sep 17 00:00:00 2001
From: Trinkle23897 <463003665@qq.com>
Date: Thu, 9 Jul 2020 22:35:28 +0800
Subject: [PATCH 09/11] minor fix of docs

---
 docs/api/tianshou.utils.rst      | 15 +++++++++++++++
 docs/tutorials/dqn.rst           |  2 ++
 tianshou/utils/net/common.py     | 21 ++++++++++++---------
 tianshou/utils/net/continuous.py |  4 ----
 tianshou/utils/net/discrete.py   |  4 ----
 5 files changed, 29 insertions(+), 17 deletions(-)

diff --git a/docs/api/tianshou.utils.rst b/docs/api/tianshou.utils.rst
index 82c195259..3a293b1c1 100644
--- a/docs/api/tianshou.utils.rst
+++ b/docs/api/tianshou.utils.rst
@@ -5,3 +5,18 @@ tianshou.utils
     :members:
     :undoc-members:
     :show-inheritance:
+
+.. automodule:: tianshou.utils.net.common
+    :members:
+    :undoc-members:
+    :show-inheritance:
+
+.. automodule:: tianshou.utils.net.discrete
+    :members:
+    :undoc-members:
+    :show-inheritance:
+
+.. automodule:: tianshou.utils.net.continuous
+    :members:
+    :undoc-members:
+    :show-inheritance:
diff --git a/docs/tutorials/dqn.rst b/docs/tutorials/dqn.rst
index 6e7199d58..0d48e075d 100644
--- a/docs/tutorials/dqn.rst
+++ b/docs/tutorials/dqn.rst
@@ -59,6 +59,8 @@ Tianshou supports any user-defined PyTorch networks and optimizers but with the
     net = Net(layer_num=2, state_shape=state_shape, action_shape=action_shape)
     optim = torch.optim.Adam(net.parameters(), lr=1e-3)
 
+where the `Net` is a simple `torch.nn.Module` obeys the following rule.
+
 The rules of self-defined networks are:
diff --git a/tianshou/utils/net/common.py b/tianshou/utils/net/common.py
index 478952c85..8dd376b92 100644
--- a/tianshou/utils/net/common.py
+++ b/tianshou/utils/net/common.py
@@ -1,6 +1,3 @@
-"""
-Commonly used MLP-backbone.
-"""
 import numpy as np
 import torch
 from torch import nn
@@ -9,14 +6,16 @@
 
 
 class Net(nn.Module):
+    """Simple MLP backbone. For advanced usage (how to customize the network),
+    please refer to :ref:`build_the_network`.
+
+    :param concat: whether the input shape is concatenated by state_shape
+    and action_shape. If it is True, ``action_shape`` is not the output
+    shape, but affects the input shape.
+    """
+
     def __init__(self, layer_num, state_shape, action_shape=0, device='cpu',
                  softmax=False, concat=False):
-        """
-        Simple MLP backbone.
-        :param concat: whether the input shape is concatenated by state_shape
-        and action_shape. If it is True, ``action_shape`` is not the output
-        shape, but affects the input shape.
-        """
         super().__init__()
         self.device = device
         input_size = np.prod(state_shape)
@@ -41,6 +40,10 @@ def forward(self, s, state=None, info={}):
 
 
 class Recurrent(nn.Module):
+    """Simple Recurrent network based on LSTM. For advanced usage (how to
+    customize the network), please refer to :ref:`build_the_network`.
+    """
+
     def __init__(self, layer_num, state_shape, action_shape, device='cpu'):
         super().__init__()
         self.state_shape = state_shape
diff --git a/tianshou/utils/net/continuous.py b/tianshou/utils/net/continuous.py
index 8e03cb574..a36a704a3 100644
--- a/tianshou/utils/net/continuous.py
+++ b/tianshou/utils/net/continuous.py
@@ -1,7 +1,3 @@
-"""
-Commonly used network modules for continuous output.
-""" - import torch import numpy as np from torch import nn diff --git a/tianshou/utils/net/discrete.py b/tianshou/utils/net/discrete.py index d291c11a4..ebf2d1c34 100644 --- a/tianshou/utils/net/discrete.py +++ b/tianshou/utils/net/discrete.py @@ -1,7 +1,3 @@ -""" -Commonly used network modules for discrete output. -""" - import torch import numpy as np from torch import nn From 8e8ec66585b0ba509e440b2c3a941cd6260a0ff6 Mon Sep 17 00:00:00 2001 From: Trinkle23897 <463003665@qq.com> Date: Thu, 9 Jul 2020 22:38:20 +0800 Subject: [PATCH 10/11] minor fix of docs --- tianshou/utils/net/continuous.py | 20 ++++++++++++++++++++ tianshou/utils/net/discrete.py | 11 +++++++++++ 2 files changed, 31 insertions(+) diff --git a/tianshou/utils/net/continuous.py b/tianshou/utils/net/continuous.py index a36a704a3..aac6fc4f7 100644 --- a/tianshou/utils/net/continuous.py +++ b/tianshou/utils/net/continuous.py @@ -6,6 +6,10 @@ class Actor(nn.Module): + """For advanced usage (how to customize the network), please refer to + :ref:`build_the_network`. + """ + def __init__(self, preprocess_net, action_shape, max_action, device='cpu'): super().__init__() @@ -20,6 +24,10 @@ def forward(self, s, state=None, info={}): class Critic(nn.Module): + """For advanced usage (how to customize the network), please refer to + :ref:`build_the_network`. + """ + def __init__(self, preprocess_net, device='cpu'): super().__init__() self.device = device @@ -39,6 +47,10 @@ def forward(self, s, a=None, **kwargs): class ActorProb(nn.Module): + """For advanced usage (how to customize the network), please refer to + :ref:`build_the_network`. + """ + def __init__(self, preprocess_net, action_shape, max_action, device='cpu', unbounded=False): super().__init__() @@ -61,6 +73,10 @@ def forward(self, s, state=None, **kwargs): class RecurrentActorProb(nn.Module): + """For advanced usage (how to customize the network), please refer to + :ref:`build_the_network`. + """ + def __init__(self, layer_num, state_shape, action_shape, max_action, device='cpu'): super().__init__() @@ -87,6 +103,10 @@ def forward(self, s, **kwargs): class RecurrentCritic(nn.Module): + """For advanced usage (how to customize the network), please refer to + :ref:`build_the_network`. + """ + def __init__(self, layer_num, state_shape, action_shape=0, device='cpu'): super().__init__() self.state_shape = state_shape diff --git a/tianshou/utils/net/discrete.py b/tianshou/utils/net/discrete.py index ebf2d1c34..4cad50d28 100644 --- a/tianshou/utils/net/discrete.py +++ b/tianshou/utils/net/discrete.py @@ -5,6 +5,10 @@ class Actor(nn.Module): + """For advanced usage (how to customize the network), please refer to + :ref:`build_the_network`. + """ + def __init__(self, preprocess_net, action_shape): super().__init__() self.preprocess = preprocess_net @@ -17,6 +21,10 @@ def forward(self, s, state=None, info={}): class Critic(nn.Module): + """For advanced usage (how to customize the network), please refer to + :ref:`build_the_network`. + """ + def __init__(self, preprocess_net): super().__init__() self.preprocess = preprocess_net @@ -29,6 +37,9 @@ def forward(self, s, **kwargs): class DQN(nn.Module): + """For advanced usage (how to customize the network), please refer to + :ref:`build_the_network`. 
+ """ def __init__(self, h, w, action_shape, device='cpu'): super(DQN, self).__init__() From b507bb5d04d511513a42fb04dabfcb51ecbd8f66 Mon Sep 17 00:00:00 2001 From: Trinkle23897 <463003665@qq.com> Date: Thu, 9 Jul 2020 22:50:13 +0800 Subject: [PATCH 11/11] do not change the example code in dqn tutorial, since it is for demonstration --- docs/tutorials/dqn.rst | 25 +++++++++++++++++++------ 1 file changed, 19 insertions(+), 6 deletions(-) diff --git a/docs/tutorials/dqn.rst b/docs/tutorials/dqn.rst index 0d48e075d..d981a1cb3 100644 --- a/docs/tutorials/dqn.rst +++ b/docs/tutorials/dqn.rst @@ -50,18 +50,31 @@ Build the Network Tianshou supports any user-defined PyTorch networks and optimizers but with the limitation of input and output API. Here is an example code: :: - from torch import nn import torch, numpy as np - from tianshou.utils.net.common import Net + from torch import nn + + class Net(nn.Module): + def __init__(self, state_shape, action_shape): + super().__init__() + self.model = nn.Sequential(*[ + nn.Linear(np.prod(state_shape), 128), nn.ReLU(inplace=True), + nn.Linear(128, 128), nn.ReLU(inplace=True), + nn.Linear(128, 128), nn.ReLU(inplace=True), + nn.Linear(128, np.prod(action_shape)) + ]) + def forward(self, obs, state=None, info={}): + if not isinstance(obs, torch.Tensor): + obs = torch.tensor(obs, dtype=torch.float) + batch = obs.shape[0] + logits = self.model(obs.view(batch, -1)) + return logits, state state_shape = env.observation_space.shape or env.observation_space.n action_shape = env.action_space.shape or env.action_space.n - net = Net(layer_num=2, state_shape=state_shape, action_shape=action_shape) + net = Net(state_shape, action_shape) optim = torch.optim.Adam(net.parameters(), lr=1e-3) -where the `Net` is a simple `torch.nn.Module` obeys the following rule. - -The rules of self-defined networks are: +You can also have a try with those pre-defined networks in :mod:`~tianshou.utils.net.common`, :mod:`~tianshou.utils.net.discrete`, and :mod:`~tianshou.utils.net.continuous`. The rules of self-defined networks are: 1. Input: observation ``obs`` (may be a ``numpy.ndarray``, ``torch.Tensor``, dict, or self-defined class), hidden state ``state`` (for RNN usage), and other information ``info`` provided by the environment. 2. Output: some ``logits``, the next hidden state ``state``, and intermediate result during the policy forwarding procedure ``policy``. The logits could be a tuple instead of a ``torch.Tensor``. It depends on how the policy process the network output. For example, in PPO :cite:`PPO`, the return of the network might be ``(mu, sigma), state`` for Gaussian policy. The ``policy`` can be a Batch of torch.Tensor or other things, which will be stored in the replay buffer, and can be accessed in the policy update process (e.g. in ``policy.learn()``, the ``batch.policy`` is what you need).
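
For instance, a Gaussian policy network for continuous control can follow the same convention by returning a tuple as its ``logits``. The following sketch is only an illustration of this interface (the ``GaussianActor`` name and the hidden sizes are arbitrary choices for the example, not part of Tianshou's API): ::

    import torch, numpy as np
    from torch import nn

    class GaussianActor(nn.Module):
        """Toy network whose forward returns ``(mu, sigma), state``."""
        def __init__(self, state_shape, action_shape):
            super().__init__()
            self.backbone = nn.Sequential(*[
                nn.Linear(np.prod(state_shape), 128), nn.ReLU(inplace=True),
                nn.Linear(128, 128), nn.ReLU(inplace=True),
            ])
            self.mu = nn.Linear(128, np.prod(action_shape))
            self.sigma = nn.Linear(128, np.prod(action_shape))

        def forward(self, obs, state=None, info={}):
            # rule 1: obs may arrive as a numpy array, so convert it explicitly
            if not isinstance(obs, torch.Tensor):
                obs = torch.tensor(obs, dtype=torch.float32)
            feature = self.backbone(obs.flatten(1))
            mu = self.mu(feature)
            sigma = torch.exp(self.sigma(feature))  # keep the std positive
            # rule 2: the returned "logits" may be a tuple such as (mu, sigma)
            return (mu, sigma), state

Such a network is then paired with a Gaussian distribution (e.g. ``torch.distributions.Normal``), in the same way the discrete examples above pair their logits with ``torch.distributions.Categorical``.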