Fix shape inconsistency in A2CPolicy and PPOPolicy #155

Merged
merged 27 commits on Jul 21, 2020
Changes from all commits (27 commits)
e976d74
Improve Batch (#126)
youkaichao Jul 11, 2020
d1a2037
Improve Batch (#128)
youkaichao Jul 11, 2020
a55ad33
Fix padding of inconsistent keys with Batch.stack and Batch.cat (#130)
youkaichao Jul 12, 2020
885fbc1
Improve collector (#125)
youkaichao Jul 12, 2020
cee8088
Vector env enable select worker (#132)
duburcqa Jul 13, 2020
f8ad6df
Standardized behavior of Batch.cat and misc code refactor (#137)
youkaichao Jul 16, 2020
fa542f8
write tutorials to specify the standard of Batch (#142)
youkaichao Jul 19, 2020
c198c60
Vector env enable select worker (#132)
duburcqa Jul 13, 2020
db2a4c9
Standardized behavior of Batch.cat and misc code refactor (#137)
youkaichao Jul 16, 2020
c7ecb4a
write tutorials to specify the standard of Batch (#142)
youkaichao Jul 19, 2020
977b627
Merge branch 'dev' into dev
youkaichao Jul 20, 2020
988a13d
Vector env enable select worker (#132)
duburcqa Jul 13, 2020
5e2af35
Standardized behavior of Batch.cat and misc code refactor (#137)
youkaichao Jul 16, 2020
e03c49b
write tutorials to specify the standard of Batch (#142)
youkaichao Jul 19, 2020
9d31801
Merge branch 'dev' of github.com:trinkle23897/tianshou into dev
Trinkle23897 Jul 21, 2020
20db334
fix a2c
Trinkle23897 Jul 21, 2020
58057f5
Merge branch 'dev' into fix-policy-shape
Trinkle23897 Jul 21, 2020
a802766
minor update
Trinkle23897 Jul 21, 2020
fb2d490
Merge branch 'dev' into fix-policy-shape
Trinkle23897 Jul 21, 2020
f30dc25
minor update
Trinkle23897 Jul 21, 2020
160db40
Merge branch 'dev' into fix-policy-shape
Trinkle23897 Jul 21, 2020
877b324
fix a2c bug
Trinkle23897 Jul 21, 2020
42287b7
fix ppo
Trinkle23897 Jul 21, 2020
e88e100
all squeeze
Trinkle23897 Jul 21, 2020
4c52ef3
squeeze with dim
Trinkle23897 Jul 21, 2020
ac83c4f
remove squeeze()
Trinkle23897 Jul 21, 2020
a736526
add a warning
Trinkle23897 Jul 21, 2020
24 changes: 17 additions & 7 deletions tianshou/policy/base.py
@@ -105,15 +105,24 @@ def learn(self, batch: Batch, **kwargs
"""Update policy with a given batch of data.

:return: A dict which includes loss and its corresponding label.

.. warning::

If you use ``torch.distributions.Normal`` and
``torch.distributions.Categorical`` to calculate the log_prob,
please be careful about the shape: Categorical distribution gives
"[batch_size]" shape while Normal distribution gives "[batch_size,
1]" shape. The auto-broadcasting of numerical operation with torch
tensors will amplify this error.
"""
pass

@staticmethod
def compute_episodic_return(
batch: Batch,
v_s_: Optional[Union[np.ndarray, torch.Tensor]] = None,
gamma: float = 0.99,
gae_lambda: float = 0.95,
batch: Batch,
v_s_: Optional[Union[np.ndarray, torch.Tensor]] = None,
gamma: float = 0.99,
gae_lambda: float = 0.95,
) -> Batch:
"""Compute returns over given full-length episodes, including the
implementation of Generalized Advantage Estimator (arXiv:1506.02438).
@@ -128,7 +137,8 @@ def compute_episodic_return(
:param float gae_lambda: the parameter for Generalized Advantage
Estimation, should be in [0, 1], defaults to 0.95.

:return: a Batch. The result will be stored in batch.returns.
:return: a Batch. The result will be stored in batch.returns as a numpy
array.
"""
rew = batch.rew
if v_s_ is None:
@@ -157,7 +167,7 @@ def compute_nstep_return(
gamma: float = 0.99,
n_step: int = 1,
rew_norm: bool = False,
) -> np.ndarray:
) -> Batch:
r"""Compute n-step return for Q-learning targets:

.. math::
@@ -204,7 +214,7 @@ def compute_nstep_return(
returns[done[now] > 0] = 0
returns = (rew[now] - mean) / std + gamma * returns
terminal = (indice + n_step - 1) % buf_len
target_q = target_q_fn(buffer, terminal).squeeze()
target_q = target_q_fn(buffer, terminal).flatten() # shape: [bsz, ]
target_q[gammas != n_step] = 0
returns = to_torch_as(returns, target_q)
gammas = to_torch_as(gamma ** gammas, target_q)
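To make the docstring warning above concrete, here is a minimal standalone sketch (not part of the diff) showing the shape difference between Categorical and Normal log-probs and how auto-broadcasting can silently blow a loss term up into a [bsz, bsz] matrix:

import torch
from torch.distributions import Categorical, Normal

bsz = 4
# Categorical: log_prob of a [bsz] action tensor has shape [bsz]
cat = Categorical(logits=torch.zeros(bsz, 3))
print(cat.log_prob(torch.zeros(bsz, dtype=torch.long)).shape)  # torch.Size([4])

# Normal with [bsz, 1] parameters: log_prob keeps the trailing dim -> [bsz, 1]
norm = Normal(torch.zeros(bsz, 1), torch.ones(bsz, 1))
log_prob = norm.log_prob(torch.zeros(bsz, 1))
print(log_prob.shape)                             # torch.Size([4, 1])

# Broadcasting turns a [bsz, 1] * [bsz] product into a [bsz, bsz] matrix,
# so a policy-gradient loss would average bsz * bsz terms instead of bsz.
adv = torch.ones(bsz)                             # e.g. advantages, shape [bsz]
print((log_prob * adv).shape)                     # torch.Size([4, 4]) -- silent bug
print((log_prob.reshape(adv.shape) * adv).shape)  # torch.Size([4])    -- intended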
7 changes: 4 additions & 3 deletions tianshou/policy/modelfree/a2c.py
@@ -105,11 +105,12 @@ def learn(self, batch: Batch, batch_size: int, repeat: int,
for b in batch.split(batch_size):
self.optim.zero_grad()
dist = self(b).dist
v = self.critic(b.obs)
v = self.critic(b.obs).squeeze(-1)
a = to_torch_as(b.act, v)
r = to_torch_as(b.returns, v)
a_loss = -(dist.log_prob(a) * (r - v).detach()).mean()
vf_loss = F.mse_loss(r[:, None], v)
a_loss = -(dist.log_prob(a).reshape(v.shape) * (r - v).detach()
).mean()
vf_loss = F.mse_loss(r, v)
ent_loss = dist.entropy().mean()
loss = a_loss + self._w_vf * vf_loss - self._w_ent * ent_loss
loss.backward()
5 changes: 2 additions & 3 deletions tianshou/policy/modelfree/ddpg.py
@@ -142,9 +142,8 @@ def forward(self, batch: Batch,
return Batch(act=actions, state=h)

def learn(self, batch: Batch, **kwargs) -> Dict[str, float]:
current_q = self.critic(batch.obs, batch.act)
target_q = to_torch_as(batch.returns, current_q)
target_q = target_q[:, None]
current_q = self.critic(batch.obs, batch.act).squeeze(-1)
target_q = batch.returns
critic_loss = F.mse_loss(current_q, target_q)
self.critic_optim.zero_grad()
critic_loss.backward()
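The DDPG change above (and the matching SAC and TD3 changes below) standardizes the critic loss on flat [bsz] tensors for both the Q prediction and the target. A minimal sketch, purely illustrative and not taken from the PR, of the broadcasting pitfall that mixing a [bsz, 1] prediction with a [bsz] target in F.mse_loss would cause:

import torch
import torch.nn.functional as F

bsz = 4
current_q = torch.randn(bsz, 1)   # critic output before .squeeze(-1)
target_q = torch.randn(bsz)       # returns kept as a flat [bsz] tensor

# Mismatched shapes broadcast to [bsz, bsz] inside mse_loss (recent PyTorch
# versions emit a UserWarning), so the loss averages bsz * bsz pairs.
loss_wrong = F.mse_loss(current_q, target_q)

# After squeezing, both tensors are [bsz] and the loss is elementwise.
loss_right = F.mse_loss(current_q.squeeze(-1), target_q)
print(loss_wrong.item(), loss_right.item())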
12 changes: 6 additions & 6 deletions tianshou/policy/modelfree/ppo.py
@@ -130,11 +130,10 @@ def learn(self, batch: Batch, batch_size: int, repeat: int,
v.append(self.critic(b.obs))
old_log_prob.append(self(b).dist.log_prob(
to_torch_as(b.act, v[0])))
batch.v = torch.cat(v, dim=0) # old value
batch.v = torch.cat(v, dim=0).squeeze(-1) # old value
batch.act = to_torch_as(batch.act, v[0])
batch.logp_old = torch.cat(old_log_prob, dim=0)
batch.returns = to_torch_as(
batch.returns, v[0]).reshape(batch.v.shape)
batch.logp_old = torch.cat(old_log_prob, dim=0).reshape(batch.v.shape)
batch.returns = to_torch_as(batch.returns, v[0])
if self._rew_norm:
mean, std = batch.returns.mean(), batch.returns.std()
if not np.isclose(std.item(), 0):
@@ -147,8 +146,9 @@ def learn(self, batch: Batch, batch_size: int, repeat: int,
for _ in range(repeat):
for b in batch.split(batch_size):
dist = self(b).dist
value = self.critic(b.obs)
ratio = (dist.log_prob(b.act) - b.logp_old).exp().float()
value = self.critic(b.obs).squeeze(-1)
ratio = (dist.log_prob(b.act).reshape(value.shape) - b.logp_old
).exp().float()
surr1 = ratio * b.adv
surr2 = ratio.clamp(
1. - self._eps_clip, 1. + self._eps_clip) * b.adv
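For PPO the same pitfall would hit the importance ratio: with a Normal policy, dist.log_prob(b.act) has shape [bsz, 1] while b.logp_old is stored flat, so subtracting them without the reshape would broadcast into a [bsz, bsz] ratio matrix. A small illustrative sketch (assuming a Normal action distribution and bsz = 4; not taken from the PR):

import torch
from torch.distributions import Normal

bsz = 4
dist = Normal(torch.zeros(bsz, 1), torch.ones(bsz, 1))
act = torch.zeros(bsz, 1)
value = torch.randn(bsz)                            # critic output, [bsz]
logp_old = dist.log_prob(act).reshape(value.shape)  # stored as [bsz]

bad_ratio = (dist.log_prob(act) - logp_old).exp()   # [bsz, 1] - [bsz] -> [bsz, bsz]
good_ratio = (dist.log_prob(act).reshape(value.shape) - logp_old).exp()
print(bad_ratio.shape, good_ratio.shape)            # torch.Size([4, 4]) torch.Size([4])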
14 changes: 7 additions & 7 deletions tianshou/policy/modelfree/sac.py
@@ -139,25 +139,25 @@ def _target_q(self, buffer: ReplayBuffer,

def learn(self, batch: Batch, **kwargs) -> Dict[str, float]:
# critic 1
current_q1 = self.critic1(batch.obs, batch.act)
target_q = to_torch_as(batch.returns, current_q1)[:, None]
current_q1 = self.critic1(batch.obs, batch.act).squeeze(-1)
target_q = batch.returns
critic1_loss = F.mse_loss(current_q1, target_q)
self.critic1_optim.zero_grad()
critic1_loss.backward()
self.critic1_optim.step()
# critic 2
current_q2 = self.critic2(batch.obs, batch.act)
current_q2 = self.critic2(batch.obs, batch.act).squeeze(-1)
critic2_loss = F.mse_loss(current_q2, target_q)
self.critic2_optim.zero_grad()
critic2_loss.backward()
self.critic2_optim.step()
# actor
obs_result = self(batch, explorating=False)
a = obs_result.act
current_q1a = self.critic1(batch.obs, a)
current_q2a = self.critic2(batch.obs, a)
actor_loss = (self._alpha * obs_result.log_prob - torch.min(
current_q1a, current_q2a)).mean()
current_q1a = self.critic1(batch.obs, a).squeeze(-1)
current_q2a = self.critic2(batch.obs, a).squeeze(-1)
actor_loss = (self._alpha * obs_result.log_prob.reshape(
target_q.shape) - torch.min(current_q1a, current_q2a)).mean()
self.actor_optim.zero_grad()
actor_loss.backward()
self.actor_optim.step()
6 changes: 3 additions & 3 deletions tianshou/policy/modelfree/td3.py
@@ -117,14 +117,14 @@ def _target_q(self, buffer: ReplayBuffer,

def learn(self, batch: Batch, **kwargs) -> Dict[str, float]:
# critic 1
current_q1 = self.critic1(batch.obs, batch.act)
target_q = batch.returns[:, None]
current_q1 = self.critic1(batch.obs, batch.act).squeeze(-1)
target_q = batch.returns
critic1_loss = F.mse_loss(current_q1, target_q)
self.critic1_optim.zero_grad()
critic1_loss.backward()
self.critic1_optim.step()
# critic 2
current_q2 = self.critic2(batch.obs, batch.act)
current_q2 = self.critic2(batch.obs, batch.act).squeeze(-1)
critic2_loss = F.mse_loss(current_q2, target_q)
self.critic2_optim.zero_grad()
critic2_loss.backward()