Fix Atari PPO example #780

Merged: 9 commits, merged on Dec 4, 2022
14 changes: 7 additions & 7 deletions examples/atari/README.md
@@ -114,13 +114,13 @@ One epoch here is equal to 100,000 env step, 100 epochs stand for 10M.

| task | best reward | reward curve | parameters |
| --------------------------- | ----------- | ------------------------------------- | ------------------------------------------------------------ |
| PongNoFrameskip-v4 | 20.1 | ![](results/ppo/Pong_rew.png) | `python3 atari_ppo.py --task "PongNoFrameskip-v4"` |
| BreakoutNoFrameskip-v4 | 438.5 | ![](results/ppo/Breakout_rew.png) | `python3 atari_ppo.py --task "BreakoutNoFrameskip-v4"` |
| EnduroNoFrameskip-v4 | 1304.8 | ![](results/ppo/Enduro_rew.png) | `python3 atari_ppo.py --task "EnduroNoFrameskip-v4"` |
| QbertNoFrameskip-v4 | 13640 | ![](results/ppo/Qbert_rew.png) | `python3 atari_ppo.py --task "QbertNoFrameskip-v4"` |
| MsPacmanNoFrameskip-v4 | 1930 | ![](results/ppo/MsPacman_rew.png) | `python3 atari_ppo.py --task "MsPacmanNoFrameskip-v4"` |
| SeaquestNoFrameskip-v4 | 904 | ![](results/ppo/Seaquest_rew.png) | `python3 atari_ppo.py --task "SeaquestNoFrameskip-v4" --lr 2.5e-5` |
| SpaceInvadersNoFrameskip-v4 | 843 | ![](results/ppo/SpaceInvaders_rew.png) | `python3 atari_ppo.py --task "SpaceInvadersNoFrameskip-v4"` |
| PongNoFrameskip-v4 | 20.2 | ![](results/ppo/Pong_rew.png) | `python3 atari_ppo.py --task "PongNoFrameskip-v4"` |
| BreakoutNoFrameskip-v4 | 441.8 | ![](results/ppo/Breakout_rew.png) | `python3 atari_ppo.py --task "BreakoutNoFrameskip-v4"` |
| EnduroNoFrameskip-v4 | 1245.4 | ![](results/ppo/Enduro_rew.png) | `python3 atari_ppo.py --task "EnduroNoFrameskip-v4"` |
| QbertNoFrameskip-v4 | 17395 | ![](results/ppo/Qbert_rew.png) | `python3 atari_ppo.py --task "QbertNoFrameskip-v4"` |
| MsPacmanNoFrameskip-v4 | 2098 | ![](results/ppo/MsPacman_rew.png) | `python3 atari_ppo.py --task "MsPacmanNoFrameskip-v4"` |
| SeaquestNoFrameskip-v4 | 882 | ![](results/ppo/Seaquest_rew.png) | `python3 atari_ppo.py --task "SeaquestNoFrameskip-v4" --lr 1e-4` |
| SpaceInvadersNoFrameskip-v4 | 1340.5 | ![](results/ppo/SpaceInvaders_rew.png) | `python3 atari_ppo.py --task "SpaceInvadersNoFrameskip-v4"` |

# SAC (single run)

43 changes: 35 additions & 8 deletions examples/atari/atari_network.py
@@ -1,4 +1,4 @@
from typing import Any, Dict, Optional, Sequence, Tuple, Union
from typing import Any, Callable, Dict, Optional, Sequence, Tuple, Type, Union

import numpy as np
import torch
@@ -7,6 +7,29 @@
from tianshou.utils.net.discrete import NoisyLinear


def layer_init(
layer: nn.Module, std: float = np.sqrt(2), bias_const: float = 0.0
) -> nn.Module:
torch.nn.init.orthogonal_(layer.weight, std)
torch.nn.init.constant_(layer.bias, bias_const)
return layer


def scale_obs(module: Type[nn.Module], denom: float = 255.0) -> Type[nn.Module]:

class scaled_module(module):

def forward(
self,
obs: Union[np.ndarray, torch.Tensor],
state: Optional[Any] = None,
info: Dict[str, Any] = {}
) -> Tuple[torch.Tensor, Any]:
return super().forward(obs / denom, state, info)

return scaled_module


class DQN(nn.Module):
"""Reference: Human-level control through deep reinforcement learning.

@@ -23,26 +23,30 @@ def __init__(
device: Union[str, int, torch.device] = "cpu",
features_only: bool = False,
output_dim: Optional[int] = None,
layer_init: Callable[[nn.Module], nn.Module] = lambda x: x,
) -> None:
super().__init__()
self.device = device
self.net = nn.Sequential(
nn.Conv2d(c, 32, kernel_size=8, stride=4), nn.ReLU(inplace=True),
nn.Conv2d(32, 64, kernel_size=4, stride=2), nn.ReLU(inplace=True),
nn.Conv2d(64, 64, kernel_size=3, stride=1), nn.ReLU(inplace=True),
nn.Flatten()
layer_init(nn.Conv2d(c, 32, kernel_size=8, stride=4)),
nn.ReLU(inplace=True),
layer_init(nn.Conv2d(32, 64, kernel_size=4, stride=2)),
nn.ReLU(inplace=True),
layer_init(nn.Conv2d(64, 64, kernel_size=3, stride=1)),
nn.ReLU(inplace=True), nn.Flatten()
)
with torch.no_grad():
self.output_dim = np.prod(self.net(torch.zeros(1, c, h, w)).shape[1:])
if not features_only:
self.net = nn.Sequential(
self.net, nn.Linear(self.output_dim, 512), nn.ReLU(inplace=True),
nn.Linear(512, np.prod(action_shape))
self.net, layer_init(nn.Linear(self.output_dim, 512)),
nn.ReLU(inplace=True),
layer_init(nn.Linear(512, np.prod(action_shape)))
)
self.output_dim = np.prod(action_shape)
elif output_dim is not None:
self.net = nn.Sequential(
self.net, nn.Linear(self.output_dim, output_dim),
self.net, layer_init(nn.Linear(self.output_dim, output_dim)),
nn.ReLU(inplace=True)
)
self.output_dim = output_dim
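
Taken together, the two new helpers change how the example network is built: `layer_init` applies orthogonal weight initialization (std √2) with zero bias to the layer it wraps, and `scale_obs` returns a subclass whose `forward` divides observations by 255 before calling the wrapped network. A minimal sketch of how they compose, assuming it is run from `examples/atari/`; the `4×84×84` frame shape and the 6-action space are illustrative, not taken from the diff:

```python
import numpy as np
import torch

from atari_network import DQN, layer_init, scale_obs  # helpers added in this PR

# Wrap DQN so observations in [0, 255] are scaled to [0, 1] inside forward().
ScaledDQN = scale_obs(DQN)  # returns a subclass, not an instance

# Illustrative Atari-like setup: 4 stacked 84x84 frames, 6 discrete actions.
net = ScaledDQN(
    4, 84, 84,
    action_shape=6,
    features_only=True,
    output_dim=512,         # feature head size, as in the example script
    layer_init=layer_init,  # orthogonal init (std=sqrt(2)), zero bias
)

obs = np.random.randint(0, 256, size=(1, 4, 84, 84)).astype(np.float32)
feats, state = net(obs)  # forward divides obs by 255 before the conv stack
print(feats.shape)       # expected: torch.Size([1, 512])
```

Because the scaling now happens inside `forward`, the replay buffer can keep raw uint8 frames; see the matching `scale=0` change in `atari_ppo.py` below.
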
24 changes: 14 additions & 10 deletions examples/atari/atari_ppo.py
@@ -5,7 +5,7 @@

import numpy as np
import torch
from atari_network import DQN
from atari_network import DQN, layer_init, scale_obs
from atari_wrapper import make_atari_env
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.tensorboard import SummaryWriter
@@ -22,9 +22,9 @@ def get_args():
parser = argparse.ArgumentParser()
parser.add_argument("--task", type=str, default="PongNoFrameskip-v4")
parser.add_argument("--seed", type=int, default=4213)
parser.add_argument("--scale-obs", type=int, default=0)
parser.add_argument("--scale-obs", type=int, default=1)
parser.add_argument("--buffer-size", type=int, default=100000)
parser.add_argument("--lr", type=float, default=5e-5)
parser.add_argument("--lr", type=float, default=2.5e-4)
parser.add_argument("--gamma", type=float, default=0.99)
parser.add_argument("--epoch", type=int, default=100)
parser.add_argument("--step-per-epoch", type=int, default=100000)
@@ -35,14 +35,14 @@ def get_args():
parser.add_argument("--training-num", type=int, default=10)
parser.add_argument("--test-num", type=int, default=10)
parser.add_argument("--rew-norm", type=int, default=False)
parser.add_argument("--vf-coef", type=float, default=0.5)
parser.add_argument("--vf-coef", type=float, default=0.25)
parser.add_argument("--ent-coef", type=float, default=0.01)
parser.add_argument("--gae-lambda", type=float, default=0.95)
parser.add_argument("--lr-decay", type=int, default=True)
parser.add_argument("--max-grad-norm", type=float, default=0.5)
parser.add_argument("--eps-clip", type=float, default=0.2)
parser.add_argument("--eps-clip", type=float, default=0.1)
parser.add_argument("--dual-clip", type=float, default=None)
parser.add_argument("--value-clip", type=int, default=0)
parser.add_argument("--value-clip", type=int, default=1)
parser.add_argument("--norm-adv", type=int, default=1)
parser.add_argument("--recompute-adv", type=int, default=0)
parser.add_argument("--logdir", type=str, default="log")
@@ -94,7 +94,7 @@ def test_ppo(args=get_args()):
args.seed,
args.training_num,
args.test_num,
scale=args.scale_obs,
scale=0,
frame_stack=args.frames_stack,
)
args.state_shape = env.observation_space.shape or env.observation_space.n
@@ -106,16 +106,20 @@ def test_ppo(args=get_args()):
np.random.seed(args.seed)
torch.manual_seed(args.seed)
# define model
net = DQN(
net_cls = scale_obs(DQN) if args.scale_obs else DQN
net = net_cls(
*args.state_shape,
args.action_shape,
device=args.device,
features_only=True,
output_dim=args.hidden_size
output_dim=args.hidden_size,
layer_init=layer_init,
)
actor = Actor(net, args.action_shape, device=args.device, softmax_output=False)
critic = Critic(net, device=args.device)
optim = torch.optim.Adam(ActorCritic(actor, critic).parameters(), lr=args.lr)
optim = torch.optim.Adam(
ActorCritic(actor, critic).parameters(), lr=args.lr, eps=1e-5
)

lr_scheduler = None
if args.lr_decay:
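
Two of the changes above interact: `--scale-obs` now defaults to 1, but the environment is created with `scale=0`, so frames stay uint8 in the replay buffer and the division by 255 happens inside the network via `scale_obs(DQN)`. The optimizer also gains `eps=1e-5`, matching common PPO implementations. The learning-rate decay itself sits below the fold of this diff; the sketch below reconstructs the usual Tianshou-example pattern, where the `--step-per-collect` value of 1000 and the linear `LambdaLR` schedule are assumptions, not part of the patch:

```python
import numpy as np
import torch
from torch.optim.lr_scheduler import LambdaLR

# Toy stand-ins so the sketch runs on its own; in the script these come from
# the parsed args and from ActorCritic(actor, critic).parameters().
params = [torch.nn.Parameter(torch.zeros(1))]
lr, eps = 2.5e-4, 1e-5                                        # new defaults from this PR
epoch, step_per_epoch, step_per_collect = 100, 100000, 1000   # step_per_collect assumed

optim = torch.optim.Adam(params, lr=lr, eps=eps)

# Linearly anneal the learning rate toward 0 over the total number of policy updates.
max_update_num = np.ceil(step_per_epoch / step_per_collect) * epoch
lr_scheduler = LambdaLR(optim, lr_lambda=lambda e: 1 - e / max_update_num)

optim.step()          # one (dummy) optimizer step
lr_scheduler.step()   # stepped once per policy update during training
print(optim.param_groups[0]["lr"])  # slightly below 2.5e-4
```
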
Binary file modified examples/atari/results/ppo/Breakout_rew.png
Binary file modified examples/atari/results/ppo/Enduro_rew.png
Binary file modified examples/atari/results/ppo/MsPacman_rew.png
Binary file modified examples/atari/results/ppo/Pong_rew.png
Binary file modified examples/atari/results/ppo/Qbert_rew.png
Binary file modified examples/atari/results/ppo/Seaquest_rew.png
Binary file modified examples/atari/results/ppo/SpaceInvaders_rew.png
11 changes: 3 additions & 8 deletions tianshou/policy/modelfree/ppo.py
@@ -79,9 +79,6 @@ def __init__(
"Dual-clip PPO parameter should greater than 1.0."
self._dual_clip = dual_clip
self._value_clip = value_clip
if not self._rew_norm:
assert not self._value_clip, \
"value clip is available only when `reward_normalization` is True"
self._norm_adv = advantage_normalization
self._recompute_adv = recompute_advantage
self._actor_critic: ActorCritic
@@ -94,11 +91,8 @@ def process_fn(
self._buffer, self._indices = buffer, indices
batch = self._compute_returns(batch, buffer, indices)
batch.act = to_torch_as(batch.act, batch.v_s)
old_log_prob = []
with torch.no_grad():
for minibatch in batch.split(self._batch, shuffle=False, merge_last=True):
old_log_prob.append(self(minibatch).dist.log_prob(minibatch.act))
batch.logp_old = torch.cat(old_log_prob, dim=0)
batch.logp_old = self(batch).dist.log_prob(batch.act)
return batch

def learn( # type: ignore
@@ -113,7 +107,8 @@ def learn( # type: ignore
dist = self(minibatch).dist
if self._norm_adv:
mean, std = minibatch.adv.mean(), minibatch.adv.std()
minibatch.adv = (minibatch.adv - mean) / std # per-batch norm
minibatch.adv = (minibatch.adv -
mean) / (std + self._eps) # per-batch norm
ratio = (dist.log_prob(minibatch.act) -
minibatch.logp_old).exp().float()
ratio = ratio.reshape(ratio.size(0), -1).transpose(0, 1)
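
The `process_fn` change computes `logp_old` for the whole batch in a single pass under `torch.no_grad()` instead of looping over minibatches and concatenating, and the `learn` change guards per-minibatch advantage normalization against a zero standard deviation. The removed assertion also decouples `value_clip` from `reward_normalization`, which the new defaults in `atari_ppo.py` rely on (`--value-clip 1` with `--rew-norm False`). A standalone illustration of the normalization failure mode and the fix; the `1e-8` constant is illustrative, `PPOPolicy` uses its own `self._eps`:

```python
import torch

adv = torch.zeros(4)                  # degenerate minibatch: identical advantages
mean, std = adv.mean(), adv.std()

bad = (adv - mean) / std              # std == 0 -> all NaN, which poisons the PPO loss
eps = 1e-8                            # illustrative; the policy uses self._eps
good = (adv - mean) / (std + eps)     # stays finite (all zeros here)

print(bad)    # tensor([nan, nan, nan, nan])
print(good)   # tensor([0., 0., 0., 0.])
```
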
3 changes: 2 additions & 1 deletion tianshou/trainer/base.py
@@ -356,7 +356,8 @@ def test_step(self) -> Tuple[Dict[str, Any], bool]:
print(
f"Epoch #{self.epoch}: test_reward: {rew:.6f} ± {rew_std:.6f},"
f" best_reward: {self.best_reward:.6f} ± "
f"{self.best_reward_std:.6f} in #{self.best_epoch}"
f"{self.best_reward_std:.6f} in #{self.best_epoch}",
flush=True
)
if not self.is_run:
test_stat = {
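
The trainer change only adds `flush=True` to the per-epoch test summary so it appears immediately even when stdout is block-buffered, for example when the training run is piped to a file. A minimal illustration of the same idiom, with an invented loop and message:

```python
import time

for epoch in range(3):
    # Without flush=True this line can sit in the stdout buffer until the
    # process exits when output is redirected to a file or a pipe.
    print(f"Epoch #{epoch}: test_reward: {20.0:.6f} ± {0.5:.6f}", flush=True)
    time.sleep(0.1)
```
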