
final fix for actor_critic shared head parameters #458

Merged (3 commits, Oct 4, 2021)
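For context, a minimal sketch (toy modules, not tianshou's real networks) of the shared-head problem this PR closes out: when the actor and the critic share a backbone, concatenating their parameter lists duplicates the shared weights, set.union deduplicates but loses ordering, while wrapping both heads in a single nn.Module, which is what the new ActorCritic class below does, yields each parameter exactly once in a stable registration order.

import torch.nn as nn

# Toy stand-ins: actor and critic sharing one backbone (hypothetical shapes).
backbone = nn.Linear(4, 64)
actor = nn.Sequential(backbone, nn.Linear(64, 2))
critic = nn.Sequential(backbone, nn.Linear(64, 1))

duplicated = list(actor.parameters()) + list(critic.parameters())
deduplicated = set(actor.parameters()).union(critic.parameters())
wrapped = nn.ModuleList([actor, critic])  # same principle as the ActorCritic wrapper

print(len(duplicated))                  # 8 -- the shared backbone is counted twice
print(len(deduplicated))                # 6 -- unique, but set iteration order is unstable
print(len(list(wrapped.parameters())))  # 6 -- unique, deterministic registration order
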
9 changes: 4 additions & 5 deletions test/continuous/test_ppo.py
@@ -13,7 +13,7 @@
 from tianshou.policy import PPOPolicy
 from tianshou.trainer import onpolicy_trainer
 from tianshou.utils import TensorboardLogger
-from tianshou.utils.net.common import Net
+from tianshou.utils.net.common import ActorCritic, Net
 from tianshou.utils.net.continuous import ActorProb, Critic
 
 
@@ -84,14 +84,13 @@ def test_ppo(args=get_args()):
         Net(args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device),
         device=args.device
     ).to(args.device)
+    actor_critic = ActorCritic(actor, critic)
     # orthogonal initialization
-    for m in set(actor.modules()).union(critic.modules()):
+    for m in actor_critic.modules():
         if isinstance(m, torch.nn.Linear):
             torch.nn.init.orthogonal_(m.weight)
             torch.nn.init.zeros_(m.bias)
-    optim = torch.optim.Adam(
-        set(actor.parameters()).union(critic.parameters()), lr=args.lr
-    )
+    optim = torch.optim.Adam(actor_critic.parameters(), lr=args.lr)
 
     # replace DiagGuassian with Independent(Normal) which is equivalent
     # pass *logits to be consistent with policy.forward
6 changes: 2 additions & 4 deletions test/discrete/test_a2c_with_il.py
@@ -12,7 +12,7 @@
 from tianshou.policy import A2CPolicy, ImitationPolicy
 from tianshou.trainer import offpolicy_trainer, onpolicy_trainer
 from tianshou.utils import TensorboardLogger
-from tianshou.utils.net.common import Net
+from tianshou.utils.net.common import ActorCritic, Net
 from tianshou.utils.net.discrete import Actor, Critic
 
 
@@ -74,9 +74,7 @@ def test_a2c_with_il(args=get_args()):
     net = Net(args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device)
     actor = Actor(net, args.action_shape, device=args.device).to(args.device)
     critic = Critic(net, device=args.device).to(args.device)
-    optim = torch.optim.Adam(
-        set(actor.parameters()).union(critic.parameters()), lr=args.lr
-    )
+    optim = torch.optim.Adam(ActorCritic(actor, critic).parameters(), lr=args.lr)
     dist = torch.distributions.Categorical
     policy = A2CPolicy(
         actor,
9 changes: 4 additions & 5 deletions test/discrete/test_ppo.py
@@ -12,7 +12,7 @@
 from tianshou.policy import PPOPolicy
 from tianshou.trainer import onpolicy_trainer
 from tianshou.utils import TensorboardLogger
-from tianshou.utils.net.common import Net
+from tianshou.utils.net.common import ActorCritic, Net
 from tianshou.utils.net.discrete import Actor, Critic
 
 
@@ -73,14 +73,13 @@ def test_ppo(args=get_args()):
     net = Net(args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device)
     actor = Actor(net, args.action_shape, device=args.device).to(args.device)
     critic = Critic(net, device=args.device).to(args.device)
+    actor_critic = ActorCritic(actor, critic)
     # orthogonal initialization
-    for m in set(actor.modules()).union(critic.modules()):
+    for m in actor_critic.modules():
         if isinstance(m, torch.nn.Linear):
             torch.nn.init.orthogonal_(m.weight)
             torch.nn.init.zeros_(m.bias)
-    optim = torch.optim.Adam(
-        set(actor.parameters()).union(critic.parameters()), lr=args.lr
-    )
+    optim = torch.optim.Adam(actor_critic.parameters(), lr=args.lr)
     dist = torch.distributions.Categorical
     policy = PPOPolicy(
         actor,
7 changes: 5 additions & 2 deletions tianshou/data/collector.py
@@ -338,7 +338,7 @@ def __init__(
         preprocess_fn: Optional[Callable[..., Batch]] = None,
         exploration_noise: bool = False,
     ) -> None:
-        assert env.is_async
+        # assert env.is_async
         super().__init__(policy, env, buffer, preprocess_fn, exploration_noise)
 
     def reset_env(self) -> None:
@@ -452,7 +452,10 @@ def collect(
             obs_next, rew, done, info = result
 
             # change self.data here because ready_env_ids has changed
-            ready_env_ids = np.array([i["env_id"] for i in info])
+            try:
+                ready_env_ids = info["env_id"]
+            except Exception:
+                ready_env_ids = np.array([i["env_id"] for i in info])
             self.data = whole_data[ready_env_ids]
 
             self.data.update(obs_next=obs_next, rew=rew, done=done, info=info)
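As a small illustration (hypothetical env ids; Batch comes from tianshou.data), the new try/except covers both layouts that info can take here: an indexable mapping such as a Batch, where info["env_id"] is already the array of ready env ids, and the sequence-of-dicts layout, which still falls back to the list comprehension.

import numpy as np
from tianshou.data import Batch

def ready_ids(info):
    # mirrors the branch added above
    try:
        return info["env_id"]                          # Batch / mapping style info
    except Exception:
        return np.array([i["env_id"] for i in info])   # sequence-of-dicts style info

print(ready_ids(Batch(env_id=np.array([0, 2, 3]))))              # [0 2 3]
print(ready_ids([{"env_id": 0}, {"env_id": 2}, {"env_id": 3}]))  # [0 2 3]
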
5 changes: 3 additions & 2 deletions tianshou/policy/modelfree/a2c.py
@@ -7,6 +7,7 @@
 
 from tianshou.data import Batch, ReplayBuffer, to_torch_as
 from tianshou.policy import PGPolicy
+from tianshou.utils.net.common import ActorCritic
 
 
 class A2CPolicy(PGPolicy):
@@ -70,6 +71,7 @@ def __init__(
         self._weight_ent = ent_coef
         self._grad_norm = max_grad_norm
         self._batch = max_batchsize
+        self._actor_critic = ActorCritic(self.actor, self.critic)
 
     def process_fn(
         self, batch: Batch, buffer: ReplayBuffer, indices: np.ndarray
@@ -136,8 +138,7 @@ def learn(  # type: ignore
                 loss.backward()
                 if self._grad_norm:  # clip large gradient
                     nn.utils.clip_grad_norm_(
-                        set(self.actor.parameters()).union(self.critic.parameters()),
-                        max_norm=self._grad_norm
+                        self._actor_critic.parameters(), max_norm=self._grad_norm
                     )
                 self.optim.step()
                 actor_losses.append(actor_loss.item())
3 changes: 1 addition & 2 deletions tianshou/policy/modelfree/ppo.py
@@ -140,8 +140,7 @@ def learn(  # type: ignore
                 loss.backward()
                 if self._grad_norm:  # clip large gradient
                     nn.utils.clip_grad_norm_(
-                        set(self.actor.parameters()).union(self.critic.parameters()),
-                        max_norm=self._grad_norm
+                        self._actor_critic.parameters(), max_norm=self._grad_norm
                     )
                 self.optim.step()
                 clip_losses.append(clip_loss.item())
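Both learn() changes rely on clip_grad_norm_ accepting any iterable of parameters, so passing self._actor_critic.parameters() clips actor and critic gradients against one global norm. A toy sketch (made-up modules and max_norm, not the policy code):

import torch
import torch.nn as nn

actor, critic = nn.Linear(4, 2), nn.Linear(4, 1)
actor_critic = nn.ModuleList([actor, critic])  # stand-in for the ActorCritic wrapper

loss = actor(torch.randn(8, 4)).sum() + critic(torch.randn(8, 4)).sum()
loss.backward()

# one call rescales every gradient so their joint norm does not exceed max_norm
nn.utils.clip_grad_norm_(actor_critic.parameters(), max_norm=0.5)
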
16 changes: 16 additions & 0 deletions tianshou/utils/net/common.py
@@ -262,3 +262,19 @@ def forward(
         s = self.fc2(s[:, -1])
         # please ensure the first dim is batch size: [bsz, len, ...]
         return s, {"h": h.transpose(0, 1).detach(), "c": c.transpose(0, 1).detach()}
+
+
+class ActorCritic(nn.Module):
+    """An actor-critic network for parsing parameters.
+
+    Using ``actor_critic.parameters()`` instead of set.union or list+list to avoid
+    issue #449.
+
+    :param nn.Module actor: the actor network.
+    :param nn.Module critic: the critic network.
+    """
+
+    def __init__(self, actor: nn.Module, critic: nn.Module) -> None:
+        super().__init__()
+        self.actor = actor
+        self.critic = critic
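
Typical usage mirrors the test changes above; a sketch with assumed shapes and learning rate (discrete Actor/Critic sharing one Net backbone):

import torch
from tianshou.utils.net.common import ActorCritic, Net
from tianshou.utils.net.discrete import Actor, Critic

net = Net(4, hidden_sizes=[64, 64])   # shared backbone; state dim assumed to be 4
actor = Actor(net, 2)                 # action dim assumed to be 2
critic = Critic(net)
actor_critic = ActorCritic(actor, critic)

# the shared Net is yielded once, so a single optimizer covers both heads cleanly
optim = torch.optim.Adam(actor_critic.parameters(), lr=3e-4)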