From c37bf2495bb4dc447a5ef61653d07bb0e4bc3070 Mon Sep 17 00:00:00 2001 From: chy <308604256@qq.com> Date: Tue, 9 Feb 2021 16:15:13 +0800 Subject: [PATCH 1/4] fix cuda bug --- examples/atari/runnable/pong_a2c.py | 4 ++-- examples/atari/runnable/pong_ppo.py | 4 ++-- test/discrete/test_a2c_with_il.py | 4 ++-- test/discrete/test_ppo.py | 4 ++-- tianshou/utils/net/continuous.py | 11 +++++++---- tianshou/utils/net/discrete.py | 10 ++++++++-- 6 files changed, 23 insertions(+), 14 deletions(-) diff --git a/examples/atari/runnable/pong_a2c.py b/examples/atari/runnable/pong_a2c.py index 100ae24a6..ffed1694d 100644 --- a/examples/atari/runnable/pong_a2c.py +++ b/examples/atari/runnable/pong_a2c.py @@ -65,8 +65,8 @@ def test_a2c(args=get_args()): # model net = Net(args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device) - actor = Actor(net, args.action_shape).to(args.device) - critic = Critic(net).to(args.device) + actor = Actor(net, args.action_shape, device=args.device).to(args.device) + critic = Critic(net, device=args.device).to(args.device) optim = torch.optim.Adam(set( actor.parameters()).union(critic.parameters()), lr=args.lr) dist = torch.distributions.Categorical diff --git a/examples/atari/runnable/pong_ppo.py b/examples/atari/runnable/pong_ppo.py index 55219da68..35ed0e749 100644 --- a/examples/atari/runnable/pong_ppo.py +++ b/examples/atari/runnable/pong_ppo.py @@ -65,8 +65,8 @@ def test_ppo(args=get_args()): # model net = Net(args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device) - actor = Actor(net, args.action_shape).to(args.device) - critic = Critic(net).to(args.device) + actor = Actor(net, args.action_shape, device=args.device).to(args.device) + critic = Critic(net, device=args.device).to(args.device) optim = torch.optim.Adam(set( actor.parameters()).union(critic.parameters()), lr=args.lr) dist = torch.distributions.Categorical diff --git a/test/discrete/test_a2c_with_il.py b/test/discrete/test_a2c_with_il.py index 90f6681fd..c222bf9a3 100644 --- a/test/discrete/test_a2c_with_il.py +++ b/test/discrete/test_a2c_with_il.py @@ -68,8 +68,8 @@ def test_a2c_with_il(args=get_args()): # model net = Net(args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device) - actor = Actor(net, args.action_shape).to(args.device) - critic = Critic(net).to(args.device) + actor = Actor(net, args.action_shape, device=args.device).to(args.device) + critic = Critic(net, device=args.device).to(args.device) optim = torch.optim.Adam(set( actor.parameters()).union(critic.parameters()), lr=args.lr) dist = torch.distributions.Categorical diff --git a/test/discrete/test_ppo.py b/test/discrete/test_ppo.py index 231ad5032..e2e671c99 100644 --- a/test/discrete/test_ppo.py +++ b/test/discrete/test_ppo.py @@ -68,8 +68,8 @@ def test_ppo(args=get_args()): # model net = Net(args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device) - actor = Actor(net, args.action_shape).to(args.device) - critic = Critic(net).to(args.device) + actor = Actor(net, args.action_shape, device=args.device).to(args.device) + critic = Critic(net, device=args.device).to(args.device) # orthogonal initialization for m in list(actor.modules()) + list(critic.modules()): if isinstance(m, torch.nn.Linear): diff --git a/tianshou/utils/net/continuous.py b/tianshou/utils/net/continuous.py index 333c2ab8c..cf1647bf6 100644 --- a/tianshou/utils/net/continuous.py +++ b/tianshou/utils/net/continuous.py @@ -49,7 +49,8 @@ def __init__( self.output_dim = np.prod(action_shape) input_dim = getattr(preprocess_net, "output_dim", preprocess_net_output_dim) - self.last = MLP(input_dim, self.output_dim, hidden_sizes) + self.last = MLP(input_dim, self.output_dim, + hidden_sizes, device = self.device) self._max = max_action def forward( @@ -98,7 +99,7 @@ def __init__( self.output_dim = 1 input_dim = getattr(preprocess_net, "output_dim", preprocess_net_output_dim) - self.last = MLP(input_dim, 1, hidden_sizes) + self.last = MLP(input_dim, 1, hidden_sizes, device = self.device) def forward( self, @@ -164,10 +165,12 @@ def __init__( self.output_dim = np.prod(action_shape) input_dim = getattr(preprocess_net, "output_dim", preprocess_net_output_dim) - self.mu = MLP(input_dim, self.output_dim, hidden_sizes) + self.mu = MLP(input_dim, self.output_dim, + hidden_sizes, device = self.device) self._c_sigma = conditioned_sigma if conditioned_sigma: - self.sigma = MLP(input_dim, self.output_dim, hidden_sizes) + self.sigma = MLP(input_dim, self.output_dim, + hidden_sizes, device = self.device) else: self.sigma_param = nn.Parameter(torch.zeros(self.output_dim, 1)) self._max = max_action diff --git a/tianshou/utils/net/discrete.py b/tianshou/utils/net/discrete.py index 05c02361c..fc7c9b002 100644 --- a/tianshou/utils/net/discrete.py +++ b/tianshou/utils/net/discrete.py @@ -40,13 +40,16 @@ def __init__( hidden_sizes: Sequence[int] = (), softmax_output: bool = True, preprocess_net_output_dim: Optional[int] = None, + device: Union[str, int, torch.device] = "cpu", ) -> None: super().__init__() + self.device = device self.preprocess = preprocess_net self.output_dim = np.prod(action_shape) input_dim = getattr(preprocess_net, "output_dim", preprocess_net_output_dim) - self.last = MLP(input_dim, self.output_dim, hidden_sizes) + self.last = MLP(input_dim, self.output_dim, + hidden_sizes, device=self.device) self.softmax_output = softmax_output def forward( @@ -91,13 +94,16 @@ def __init__( hidden_sizes: Sequence[int] = (), last_size: int = 1, preprocess_net_output_dim: Optional[int] = None, + device: Union[str, int, torch.device] = "cpu", ) -> None: super().__init__() + self.device = device self.preprocess = preprocess_net self.output_dim = last_size input_dim = getattr(preprocess_net, "output_dim", preprocess_net_output_dim) - self.last = MLP(input_dim, last_size, hidden_sizes) + self.last = MLP(input_dim, last_size, + hidden_sizes, device=self.device) def forward( self, s: Union[np.ndarray, torch.Tensor], **kwargs: Any From 9656a1efaf2b26094a3075f0310869bcb3e1e0bd Mon Sep 17 00:00:00 2001 From: chy <308604256@qq.com> Date: Tue, 9 Feb 2021 16:19:38 +0800 Subject: [PATCH 2/4] another fix --- test/discrete/test_sac.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/test/discrete/test_sac.py b/test/discrete/test_sac.py index 67da3fb57..3d3df6f2c 100644 --- a/test/discrete/test_sac.py +++ b/test/discrete/test_sac.py @@ -62,15 +62,18 @@ def test_discrete_sac(args=get_args()): # model net = Net(args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device) - actor = Actor(net, args.action_shape, softmax_output=False).to(args.device) + actor = Actor(net, args.action_shape, + softmax_output=False, device=args.device).to(args.device) actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr) net_c1 = Net(args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device) - critic1 = Critic(net_c1, last_size=args.action_shape).to(args.device) + critic1 = Critic(net_c1, last_size=args.action_shape, + device=args.device).to(args.device) critic1_optim = torch.optim.Adam(critic1.parameters(), lr=args.critic_lr) net_c2 = Net(args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device) - critic2 = Critic(net_c2, last_size=args.action_shape).to(args.device) + critic2 = Critic(net_c2, last_size=args.action_shape, + device=args.device).to(args.device) critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr) # better not to use auto alpha in CartPole From 15bd6e7b53fa941f59652675fed11433e3929b0a Mon Sep 17 00:00:00 2001 From: n+e Date: Tue, 9 Feb 2021 16:42:10 +0800 Subject: [PATCH 3/4] Apply suggestions from code review --- tianshou/utils/net/continuous.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tianshou/utils/net/continuous.py b/tianshou/utils/net/continuous.py index cf1647bf6..a8f667532 100644 --- a/tianshou/utils/net/continuous.py +++ b/tianshou/utils/net/continuous.py @@ -50,7 +50,7 @@ def __init__( input_dim = getattr(preprocess_net, "output_dim", preprocess_net_output_dim) self.last = MLP(input_dim, self.output_dim, - hidden_sizes, device = self.device) + hidden_sizes, device=self.device) self._max = max_action def forward( @@ -99,7 +99,7 @@ def __init__( self.output_dim = 1 input_dim = getattr(preprocess_net, "output_dim", preprocess_net_output_dim) - self.last = MLP(input_dim, 1, hidden_sizes, device = self.device) + self.last = MLP(input_dim, 1, hidden_sizes, device=self.device) def forward( self, @@ -166,11 +166,11 @@ def __init__( input_dim = getattr(preprocess_net, "output_dim", preprocess_net_output_dim) self.mu = MLP(input_dim, self.output_dim, - hidden_sizes, device = self.device) + hidden_sizes, device=self.device) self._c_sigma = conditioned_sigma if conditioned_sigma: self.sigma = MLP(input_dim, self.output_dim, - hidden_sizes, device = self.device) + hidden_sizes, device=self.device) else: self.sigma_param = nn.Parameter(torch.zeros(self.output_dim, 1)) self._max = max_action From 38496cf4664e94f6798d3d8d5e220e47d964e826 Mon Sep 17 00:00:00 2001 From: chy <308604256@qq.com> Date: Tue, 9 Feb 2021 16:51:35 +0800 Subject: [PATCH 4/4] last fix --- test/discrete/test_a2c_with_il.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/discrete/test_a2c_with_il.py b/test/discrete/test_a2c_with_il.py index c222bf9a3..08759f92e 100644 --- a/test/discrete/test_a2c_with_il.py +++ b/test/discrete/test_a2c_with_il.py @@ -113,7 +113,7 @@ def stop_fn(mean_rewards): env.spec.reward_threshold = 190 # lower the goal net = Net(args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device) - net = Actor(net, args.action_shape).to(args.device) + net = Actor(net, args.action_shape, device=args.device).to(args.device) optim = torch.optim.Adam(net.parameters(), lr=args.il_lr) il_policy = ImitationPolicy(net, optim, mode='discrete') il_test_collector = Collector(