From b17c6dc8b8a86c6757adcb3d61c4a2866752a03b Mon Sep 17 00:00:00 2001 From: daniel <1534513+dantp-ai@users.noreply.github.com> Date: Mon, 19 Feb 2024 22:05:35 +0100 Subject: [PATCH 001/115] Remove mutable logger_factory used by examples * Creating a new instance of LoggerFactoryDefault within each module --- examples/atari/atari_c51.py | 3 ++- examples/atari/atari_dqn.py | 3 ++- examples/atari/atari_fqf.py | 3 ++- examples/atari/atari_iqn.py | 3 ++- examples/atari/atari_ppo.py | 3 ++- examples/atari/atari_qrdqn.py | 3 ++- examples/atari/atari_rainbow.py | 3 ++- examples/atari/atari_sac.py | 3 ++- examples/common.py | 3 --- examples/mujoco/fetch_her_ddpg.py | 28 +++++++++++++--------------- examples/mujoco/mujoco_a2c.py | 3 ++- examples/mujoco/mujoco_ddpg.py | 3 ++- examples/mujoco/mujoco_npg.py | 3 ++- examples/mujoco/mujoco_ppo.py | 3 ++- examples/mujoco/mujoco_redq.py | 3 ++- examples/mujoco/mujoco_reinforce.py | 3 ++- examples/mujoco/mujoco_sac.py | 3 ++- examples/mujoco/mujoco_td3.py | 3 ++- examples/mujoco/mujoco_trpo.py | 3 ++- examples/offline/atari_bcq.py | 3 ++- examples/offline/atari_cql.py | 3 ++- examples/offline/atari_crr.py | 3 ++- examples/offline/atari_il.py | 3 ++- examples/vizdoom/vizdoom_c51.py | 3 ++- examples/vizdoom/vizdoom_ppo.py | 3 ++- 25 files changed, 59 insertions(+), 41 deletions(-) delete mode 100644 examples/common.py diff --git a/examples/atari/atari_c51.py b/examples/atari/atari_c51.py index 1946cc790..cde4d18b1 100644 --- a/examples/atari/atari_c51.py +++ b/examples/atari/atari_c51.py @@ -9,8 +9,8 @@ from atari_network import C51 from atari_wrapper import make_atari_env -from examples.common import logger_factory from tianshou.data import Collector, VectorReplayBuffer +from tianshou.highlevel.logger import LoggerFactoryDefault from tianshou.policy import C51Policy from tianshou.policy.base import BasePolicy from tianshou.trainer import OffpolicyTrainer @@ -122,6 +122,7 @@ def test_c51(args: argparse.Namespace = get_args()) -> 
None: log_path = os.path.join(args.logdir, log_name) # logger + logger_factory = LoggerFactoryDefault() if args.logger == "wandb": logger_factory.logger_type = "wandb" logger_factory.wandb_project = args.wandb_project diff --git a/examples/atari/atari_dqn.py b/examples/atari/atari_dqn.py index c669fa714..7517d6f9b 100644 --- a/examples/atari/atari_dqn.py +++ b/examples/atari/atari_dqn.py @@ -9,8 +9,8 @@ from atari_network import DQN from atari_wrapper import make_atari_env -from examples.common import logger_factory from tianshou.data import Collector, VectorReplayBuffer +from tianshou.highlevel.logger import LoggerFactoryDefault from tianshou.policy import DQNPolicy from tianshou.policy.base import BasePolicy from tianshou.policy.modelbased.icm import ICMPolicy @@ -157,6 +157,7 @@ def test_dqn(args: argparse.Namespace = get_args()) -> None: log_path = os.path.join(args.logdir, log_name) # logger + logger_factory = LoggerFactoryDefault() if args.logger == "wandb": logger_factory.logger_type = "wandb" logger_factory.wandb_project = args.wandb_project diff --git a/examples/atari/atari_fqf.py b/examples/atari/atari_fqf.py index 31adf9efd..c7a744bbc 100644 --- a/examples/atari/atari_fqf.py +++ b/examples/atari/atari_fqf.py @@ -9,8 +9,8 @@ from atari_network import DQN from atari_wrapper import make_atari_env -from examples.common import logger_factory from tianshou.data import Collector, VectorReplayBuffer +from tianshou.highlevel.logger import LoggerFactoryDefault from tianshou.policy import FQFPolicy from tianshou.policy.base import BasePolicy from tianshou.trainer import OffpolicyTrainer @@ -135,6 +135,7 @@ def test_fqf(args: argparse.Namespace = get_args()) -> None: log_path = os.path.join(args.logdir, log_name) # logger + logger_factory = LoggerFactoryDefault() if args.logger == "wandb": logger_factory.logger_type = "wandb" logger_factory.wandb_project = args.wandb_project diff --git a/examples/atari/atari_iqn.py b/examples/atari/atari_iqn.py index 
d9832b9ea..f468f5457 100644 --- a/examples/atari/atari_iqn.py +++ b/examples/atari/atari_iqn.py @@ -9,8 +9,8 @@ from atari_network import DQN from atari_wrapper import make_atari_env -from examples.common import logger_factory from tianshou.data import Collector, VectorReplayBuffer +from tianshou.highlevel.logger import LoggerFactoryDefault from tianshou.policy import IQNPolicy from tianshou.policy.base import BasePolicy from tianshou.trainer import OffpolicyTrainer @@ -132,6 +132,7 @@ def test_iqn(args: argparse.Namespace = get_args()) -> None: log_path = os.path.join(args.logdir, log_name) # logger + logger_factory = LoggerFactoryDefault() if args.logger == "wandb": logger_factory.logger_type = "wandb" logger_factory.wandb_project = args.wandb_project diff --git a/examples/atari/atari_ppo.py b/examples/atari/atari_ppo.py index 86f54d4d7..6fa4eca3c 100644 --- a/examples/atari/atari_ppo.py +++ b/examples/atari/atari_ppo.py @@ -11,8 +11,8 @@ from torch.distributions import Categorical, Distribution from torch.optim.lr_scheduler import LambdaLR -from examples.common import logger_factory from tianshou.data import Collector, VectorReplayBuffer +from tianshou.highlevel.logger import LoggerFactoryDefault from tianshou.policy import ICMPolicy, PPOPolicy from tianshou.policy.base import BasePolicy from tianshou.trainer import OnpolicyTrainer @@ -200,6 +200,7 @@ def dist(logits: torch.Tensor) -> Distribution: log_path = os.path.join(args.logdir, log_name) # logger + logger_factory = LoggerFactoryDefault() if args.logger == "wandb": logger_factory.logger_type = "wandb" logger_factory.wandb_project = args.wandb_project diff --git a/examples/atari/atari_qrdqn.py b/examples/atari/atari_qrdqn.py index cef1c4247..a4a29d84c 100644 --- a/examples/atari/atari_qrdqn.py +++ b/examples/atari/atari_qrdqn.py @@ -9,8 +9,8 @@ from atari_network import QRDQN from atari_wrapper import make_atari_env -from examples.common import logger_factory from tianshou.data import Collector, 
VectorReplayBuffer +from tianshou.highlevel.logger import LoggerFactoryDefault from tianshou.policy import QRDQNPolicy from tianshou.policy.base import BasePolicy from tianshou.trainer import OffpolicyTrainer @@ -118,6 +118,7 @@ def test_qrdqn(args: argparse.Namespace = get_args()) -> None: log_path = os.path.join(args.logdir, log_name) # logger + logger_factory = LoggerFactoryDefault() if args.logger == "wandb": logger_factory.logger_type = "wandb" logger_factory.wandb_project = args.wandb_project diff --git a/examples/atari/atari_rainbow.py b/examples/atari/atari_rainbow.py index dbea00688..fcf6a19ca 100644 --- a/examples/atari/atari_rainbow.py +++ b/examples/atari/atari_rainbow.py @@ -9,8 +9,8 @@ from atari_network import Rainbow from atari_wrapper import make_atari_env -from examples.common import logger_factory from tianshou.data import Collector, PrioritizedVectorReplayBuffer, VectorReplayBuffer +from tianshou.highlevel.logger import LoggerFactoryDefault from tianshou.policy import C51Policy, RainbowPolicy from tianshou.policy.base import BasePolicy from tianshou.trainer import OffpolicyTrainer @@ -152,6 +152,7 @@ def test_rainbow(args: argparse.Namespace = get_args()) -> None: log_path = os.path.join(args.logdir, log_name) # logger + logger_factory = LoggerFactoryDefault() if args.logger == "wandb": logger_factory.logger_type = "wandb" logger_factory.wandb_project = args.wandb_project diff --git a/examples/atari/atari_sac.py b/examples/atari/atari_sac.py index 7dc60c0e8..72d356ec1 100644 --- a/examples/atari/atari_sac.py +++ b/examples/atari/atari_sac.py @@ -9,8 +9,8 @@ from atari_network import DQN from atari_wrapper import make_atari_env -from examples.common import logger_factory from tianshou.data import Collector, VectorReplayBuffer +from tianshou.highlevel.logger import LoggerFactoryDefault from tianshou.policy import DiscreteSACPolicy, ICMPolicy from tianshou.policy.base import BasePolicy from tianshou.trainer import OffpolicyTrainer @@ -182,6 +182,7 
@@ def test_discrete_sac(args: argparse.Namespace = get_args()) -> None: log_path = os.path.join(args.logdir, log_name) # logger + logger_factory = LoggerFactoryDefault() if args.logger == "wandb": logger_factory.logger_type = "wandb" logger_factory.wandb_project = args.wandb_project diff --git a/examples/common.py b/examples/common.py deleted file mode 100644 index a86115c2a..000000000 --- a/examples/common.py +++ /dev/null @@ -1,3 +0,0 @@ -from tianshou.highlevel.logger import LoggerFactoryDefault - -logger_factory = LoggerFactoryDefault() diff --git a/examples/mujoco/fetch_her_ddpg.py b/examples/mujoco/fetch_her_ddpg.py index 0804b1a34..224e1fa98 100644 --- a/examples/mujoco/fetch_her_ddpg.py +++ b/examples/mujoco/fetch_her_ddpg.py @@ -9,7 +9,6 @@ import gymnasium as gym import numpy as np import torch -from torch.utils.tensorboard import SummaryWriter from tianshou.data import ( @@ -19,12 +18,12 @@ ReplayBuffer, VectorReplayBuffer, ) +from tianshou.highlevel.logger import LoggerFactoryDefault from tianshou.env import ShmemVectorEnv, TruncatedAsTerminated from tianshou.exploration import GaussianNoise from tianshou.policy import DDPGPolicy from tianshou.policy.base import BasePolicy from tianshou.trainer import OffpolicyTrainer -from tianshou.utils import TensorboardLogger, WandbLogger from tianshou.utils.net.common import Net, get_dict_state_decorator from tianshou.utils.net.continuous import Actor, Critic @@ -96,20 +95,19 @@ def test_ddpg(args: argparse.Namespace = get_args()) -> None: log_path = os.path.join(args.logdir, log_name) # logger + logger_factory = LoggerFactoryDefault() if args.logger == "wandb": - logger = WandbLogger( - save_interval=1, - name=log_name.replace(os.path.sep, "__"), - run_id=args.resume_id, - config=args, - project=args.wandb_project, - ) - writer = SummaryWriter(log_path) - writer.add_text("args", str(args)) - if args.logger == "tensorboard": - logger = TensorboardLogger(writer) - else: # wandb - logger.load(writer) + 
logger_factory.logger_type = "wandb" + logger_factory.wandb_project = args.wandb_project + else: + logger_factory.logger_type = "tensorboard" + + logger = logger_factory.create_logger( + log_dir=log_path, + experiment_name=log_name, + run_id=args.resume_id, + config_dict=vars(args), + ) env, train_envs, test_envs = make_fetch_env(args.task, args.training_num, args.test_num) args.state_shape = { diff --git a/examples/mujoco/mujoco_a2c.py b/examples/mujoco/mujoco_a2c.py index 95b645dc3..734767121 100755 --- a/examples/mujoco/mujoco_a2c.py +++ b/examples/mujoco/mujoco_a2c.py @@ -12,8 +12,8 @@ from torch.distributions import Distribution, Independent, Normal from torch.optim.lr_scheduler import LambdaLR -from examples.common import logger_factory from tianshou.data import Collector, ReplayBuffer, VectorReplayBuffer +from tianshou.highlevel.logger import LoggerFactoryDefault from tianshou.policy import A2CPolicy from tianshou.policy.base import BasePolicy from tianshou.trainer import OnpolicyTrainer @@ -181,6 +181,7 @@ def dist(*logits: torch.Tensor) -> Distribution: log_path = os.path.join(args.logdir, log_name) # logger + logger_factory = LoggerFactoryDefault() if args.logger == "wandb": logger_factory.logger_type = "wandb" logger_factory.wandb_project = args.wandb_project diff --git a/examples/mujoco/mujoco_ddpg.py b/examples/mujoco/mujoco_ddpg.py index 04fc0109b..ae1046cad 100755 --- a/examples/mujoco/mujoco_ddpg.py +++ b/examples/mujoco/mujoco_ddpg.py @@ -9,9 +9,9 @@ import torch from mujoco_env import make_mujoco_env -from examples.common import logger_factory from tianshou.data import Collector, ReplayBuffer, VectorReplayBuffer from tianshou.exploration import GaussianNoise +from tianshou.highlevel.logger import LoggerFactoryDefault from tianshou.policy import DDPGPolicy from tianshou.policy.base import BasePolicy from tianshou.trainer import OffpolicyTrainer @@ -130,6 +130,7 @@ def test_ddpg(args: argparse.Namespace = get_args()) -> None: log_path = 
os.path.join(args.logdir, log_name) # logger + logger_factory = LoggerFactoryDefault() if args.logger == "wandb": logger_factory.logger_type = "wandb" logger_factory.wandb_project = args.wandb_project diff --git a/examples/mujoco/mujoco_npg.py b/examples/mujoco/mujoco_npg.py index 454565a46..5601ea26b 100755 --- a/examples/mujoco/mujoco_npg.py +++ b/examples/mujoco/mujoco_npg.py @@ -12,8 +12,8 @@ from torch.distributions import Distribution, Independent, Normal from torch.optim.lr_scheduler import LambdaLR -from examples.common import logger_factory from tianshou.data import Collector, ReplayBuffer, VectorReplayBuffer +from tianshou.highlevel.logger import LoggerFactoryDefault from tianshou.policy import NPGPolicy from tianshou.policy.base import BasePolicy from tianshou.trainer import OnpolicyTrainer @@ -178,6 +178,7 @@ def dist(*logits: torch.Tensor) -> Distribution: log_path = os.path.join(args.logdir, log_name) # logger + logger_factory = LoggerFactoryDefault() if args.logger == "wandb": logger_factory.logger_type = "wandb" logger_factory.wandb_project = args.wandb_project diff --git a/examples/mujoco/mujoco_ppo.py b/examples/mujoco/mujoco_ppo.py index c0d868cf2..eb2fe03b4 100755 --- a/examples/mujoco/mujoco_ppo.py +++ b/examples/mujoco/mujoco_ppo.py @@ -12,8 +12,8 @@ from torch.distributions import Distribution, Independent, Normal from torch.optim.lr_scheduler import LambdaLR -from examples.common import logger_factory from tianshou.data import Collector, ReplayBuffer, VectorReplayBuffer +from tianshou.highlevel.logger import LoggerFactoryDefault from tianshou.policy import PPOPolicy from tianshou.policy.base import BasePolicy from tianshou.trainer import OnpolicyTrainer @@ -186,6 +186,7 @@ def dist(*logits: torch.Tensor) -> Distribution: log_path = os.path.join(args.logdir, log_name) # logger + logger_factory = LoggerFactoryDefault() if args.logger == "wandb": logger_factory.logger_type = "wandb" logger_factory.wandb_project = args.wandb_project diff --git 
a/examples/mujoco/mujoco_redq.py b/examples/mujoco/mujoco_redq.py index 66c9f7db6..7a82a04b7 100755 --- a/examples/mujoco/mujoco_redq.py +++ b/examples/mujoco/mujoco_redq.py @@ -9,8 +9,8 @@ import torch from mujoco_env import make_mujoco_env -from examples.common import logger_factory from tianshou.data import Collector, ReplayBuffer, VectorReplayBuffer +from tianshou.highlevel.logger import LoggerFactoryDefault from tianshou.policy import REDQPolicy from tianshou.policy.base import BasePolicy from tianshou.trainer import OffpolicyTrainer @@ -159,6 +159,7 @@ def linear(x, y): log_path = os.path.join(args.logdir, log_name) # logger + logger_factory = LoggerFactoryDefault() if args.logger == "wandb": logger_factory.logger_type = "wandb" logger_factory.wandb_project = args.wandb_project diff --git a/examples/mujoco/mujoco_reinforce.py b/examples/mujoco/mujoco_reinforce.py index f4a86934a..5c3c227e5 100755 --- a/examples/mujoco/mujoco_reinforce.py +++ b/examples/mujoco/mujoco_reinforce.py @@ -12,8 +12,8 @@ from torch.distributions import Distribution, Independent, Normal from torch.optim.lr_scheduler import LambdaLR -from examples.common import logger_factory from tianshou.data import Collector, ReplayBuffer, VectorReplayBuffer +from tianshou.highlevel.logger import LoggerFactoryDefault from tianshou.policy import PGPolicy from tianshou.policy.base import BasePolicy from tianshou.trainer import OnpolicyTrainer @@ -158,6 +158,7 @@ def dist(*logits: torch.Tensor) -> Distribution: log_path = os.path.join(args.logdir, log_name) # logger + logger_factory = LoggerFactoryDefault() if args.logger == "wandb": logger_factory.logger_type = "wandb" logger_factory.wandb_project = args.wandb_project diff --git a/examples/mujoco/mujoco_sac.py b/examples/mujoco/mujoco_sac.py index a09118979..ed37a6025 100755 --- a/examples/mujoco/mujoco_sac.py +++ b/examples/mujoco/mujoco_sac.py @@ -9,8 +9,8 @@ import torch from mujoco_env import make_mujoco_env -from examples.common import 
logger_factory from tianshou.data import Collector, ReplayBuffer, VectorReplayBuffer +from tianshou.highlevel.logger import LoggerFactoryDefault from tianshou.policy import SACPolicy from tianshou.policy.base import BasePolicy from tianshou.trainer import OffpolicyTrainer @@ -153,6 +153,7 @@ def test_sac(args: argparse.Namespace = get_args()) -> None: log_path = os.path.join(args.logdir, log_name) # logger + logger_factory = LoggerFactoryDefault() if args.logger == "wandb": logger_factory.logger_type = "wandb" logger_factory.wandb_project = args.wandb_project diff --git a/examples/mujoco/mujoco_td3.py b/examples/mujoco/mujoco_td3.py index 057905c64..82b5e3bcc 100755 --- a/examples/mujoco/mujoco_td3.py +++ b/examples/mujoco/mujoco_td3.py @@ -9,9 +9,9 @@ import torch from mujoco_env import make_mujoco_env -from examples.common import logger_factory from tianshou.data import Collector, ReplayBuffer, VectorReplayBuffer from tianshou.exploration import GaussianNoise +from tianshou.highlevel.logger import LoggerFactoryDefault from tianshou.policy import TD3Policy from tianshou.policy.base import BasePolicy from tianshou.trainer import OffpolicyTrainer @@ -151,6 +151,7 @@ def test_td3(args: argparse.Namespace = get_args()) -> None: log_path = os.path.join(args.logdir, log_name) # logger + logger_factory = LoggerFactoryDefault() if args.logger == "wandb": logger_factory.logger_type = "wandb" logger_factory.wandb_project = args.wandb_project diff --git a/examples/mujoco/mujoco_trpo.py b/examples/mujoco/mujoco_trpo.py index b001fd04c..dc9e4cb69 100755 --- a/examples/mujoco/mujoco_trpo.py +++ b/examples/mujoco/mujoco_trpo.py @@ -12,8 +12,8 @@ from torch.distributions import Distribution, Independent, Normal from torch.optim.lr_scheduler import LambdaLR -from examples.common import logger_factory from tianshou.data import Collector, ReplayBuffer, VectorReplayBuffer +from tianshou.highlevel.logger import LoggerFactoryDefault from tianshou.policy import TRPOPolicy from 
tianshou.policy.base import BasePolicy from tianshou.trainer import OnpolicyTrainer @@ -183,6 +183,7 @@ def dist(*logits: torch.Tensor) -> Distribution: log_path = os.path.join(args.logdir, log_name) # logger + logger_factory = LoggerFactoryDefault() if args.logger == "wandb": logger_factory.logger_type = "wandb" logger_factory.wandb_project = args.wandb_project diff --git a/examples/offline/atari_bcq.py b/examples/offline/atari_bcq.py index 766662a07..8a2f41517 100644 --- a/examples/offline/atari_bcq.py +++ b/examples/offline/atari_bcq.py @@ -12,9 +12,9 @@ from examples.atari.atari_network import DQN from examples.atari.atari_wrapper import make_atari_env -from examples.common import logger_factory from examples.offline.utils import load_buffer from tianshou.data import Collector, VectorReplayBuffer +from tianshou.highlevel.logger import LoggerFactoryDefault from tianshou.policy import DiscreteBCQPolicy from tianshou.policy.base import BasePolicy from tianshou.trainer import OfflineTrainer @@ -157,6 +157,7 @@ def test_discrete_bcq(args: argparse.Namespace = get_args()) -> None: log_path = os.path.join(args.logdir, log_name) # logger + logger_factory = LoggerFactoryDefault() if args.logger == "wandb": logger_factory.logger_type = "wandb" logger_factory.wandb_project = args.wandb_project diff --git a/examples/offline/atari_cql.py b/examples/offline/atari_cql.py index 5f1afcdcd..5d20342bf 100644 --- a/examples/offline/atari_cql.py +++ b/examples/offline/atari_cql.py @@ -12,9 +12,9 @@ from examples.atari.atari_network import QRDQN from examples.atari.atari_wrapper import make_atari_env -from examples.common import logger_factory from examples.offline.utils import load_buffer from tianshou.data import Collector, VectorReplayBuffer +from tianshou.highlevel.logger import LoggerFactoryDefault from tianshou.policy import DiscreteCQLPolicy from tianshou.policy.base import BasePolicy from tianshou.trainer import OfflineTrainer @@ -133,6 +133,7 @@ def test_discrete_cql(args: 
argparse.Namespace = get_args()) -> None: log_path = os.path.join(args.logdir, log_name) # logger + logger_factory = LoggerFactoryDefault() if args.logger == "wandb": logger_factory.logger_type = "wandb" logger_factory.wandb_project = args.wandb_project diff --git a/examples/offline/atari_crr.py b/examples/offline/atari_crr.py index 97622b6d5..8fe2642fb 100644 --- a/examples/offline/atari_crr.py +++ b/examples/offline/atari_crr.py @@ -12,9 +12,9 @@ from examples.atari.atari_network import DQN from examples.atari.atari_wrapper import make_atari_env -from examples.common import logger_factory from examples.offline.utils import load_buffer from tianshou.data import Collector, VectorReplayBuffer +from tianshou.highlevel.logger import LoggerFactoryDefault from tianshou.policy import DiscreteCRRPolicy from tianshou.policy.base import BasePolicy from tianshou.trainer import OfflineTrainer @@ -156,6 +156,7 @@ def test_discrete_crr(args: argparse.Namespace = get_args()) -> None: log_path = os.path.join(args.logdir, log_name) # logger + logger_factory = LoggerFactoryDefault() if args.logger == "wandb": logger_factory.logger_type = "wandb" logger_factory.wandb_project = args.wandb_project diff --git a/examples/offline/atari_il.py b/examples/offline/atari_il.py index 615d38ec0..d897d7fe7 100644 --- a/examples/offline/atari_il.py +++ b/examples/offline/atari_il.py @@ -12,9 +12,9 @@ from examples.atari.atari_network import DQN from examples.atari.atari_wrapper import make_atari_env -from examples.common import logger_factory from examples.offline.utils import load_buffer from tianshou.data import Collector, VectorReplayBuffer +from tianshou.highlevel.logger import LoggerFactoryDefault from tianshou.policy import ImitationPolicy from tianshou.policy.base import BasePolicy from tianshou.trainer import OfflineTrainer @@ -117,6 +117,7 @@ def test_il(args: argparse.Namespace = get_args()) -> None: log_path = os.path.join(args.logdir, log_name) # logger + logger_factory = 
LoggerFactoryDefault() if args.logger == "wandb": logger_factory.logger_type = "wandb" logger_factory.wandb_project = args.wandb_project diff --git a/examples/vizdoom/vizdoom_c51.py b/examples/vizdoom/vizdoom_c51.py index 9ba82a10f..997996b32 100644 --- a/examples/vizdoom/vizdoom_c51.py +++ b/examples/vizdoom/vizdoom_c51.py @@ -9,8 +9,8 @@ from env import make_vizdoom_env from network import C51 -from examples.common import logger_factory from tianshou.data import Collector, VectorReplayBuffer +from tianshou.highlevel.logger import LoggerFactoryDefault from tianshou.policy import C51Policy from tianshou.policy.base import BasePolicy from tianshou.trainer import OffpolicyTrainer @@ -130,6 +130,7 @@ def test_c51(args: argparse.Namespace = get_args()) -> None: log_path = os.path.join(args.logdir, log_name) # logger + logger_factory = LoggerFactoryDefault() if args.logger == "wandb": logger_factory.logger_type = "wandb" logger_factory.wandb_project = args.wandb_project diff --git a/examples/vizdoom/vizdoom_ppo.py b/examples/vizdoom/vizdoom_ppo.py index 5c2a9e1f7..c58281d89 100644 --- a/examples/vizdoom/vizdoom_ppo.py +++ b/examples/vizdoom/vizdoom_ppo.py @@ -11,8 +11,8 @@ from torch.distributions import Categorical, Distribution from torch.optim.lr_scheduler import LambdaLR -from examples.common import logger_factory from tianshou.data import Collector, VectorReplayBuffer +from tianshou.highlevel.logger import LoggerFactoryDefault from tianshou.policy import ICMPolicy, PPOPolicy from tianshou.policy.base import BasePolicy from tianshou.trainer import OnpolicyTrainer @@ -210,6 +210,7 @@ def dist(logits: torch.Tensor) -> Distribution: log_path = os.path.join(args.logdir, log_name) # logger + logger_factory = LoggerFactoryDefault() if args.logger == "wandb": logger_factory.logger_type = "wandb" logger_factory.wandb_project = args.wandb_project From 933b27a02714b89fd2325606cc844b4f51f87c61 Mon Sep 17 00:00:00 2001 From: daniel <1534513+dantp-ai@users.noreply.github.com> 
Date: Wed, 21 Feb 2024 10:09:38 +0100 Subject: [PATCH 002/115] Fix mypy issues --- test/base/test_returns.py | 40 ++++++++++++++++++++++----------------- 1 file changed, 23 insertions(+), 17 deletions(-) diff --git a/test/base/test_returns.py b/test/base/test_returns.py index 2dbf47c29..0196a415b 100644 --- a/test/base/test_returns.py +++ b/test/base/test_returns.py @@ -4,10 +4,11 @@ import torch from tianshou.data import Batch, ReplayBuffer, to_numpy +from tianshou.data.types import BatchWithReturnsProtocol from tianshou.policy import BasePolicy -def compute_episodic_return_base(batch, gamma): +def compute_episodic_return_base(batch: Batch, gamma: float) -> Batch: returns = np.zeros_like(batch.rew) last = 0 for i in reversed(range(len(batch.rew))): @@ -19,7 +20,7 @@ def compute_episodic_return_base(batch, gamma): return batch -def test_episodic_returns(size=2560) -> None: +def test_episodic_returns(size: int = 2560) -> None: fn = BasePolicy.compute_episodic_return buf = ReplayBuffer(20) batch = Batch( @@ -34,7 +35,7 @@ def test_episodic_returns(size=2560) -> None: }, ), ) - for b in batch: + for b in iter(batch): b.obs = b.act = 1 buf.add(b) returns, _ = fn(batch, buf, buf.sample_indices(0), gamma=0.1, gae_lambda=1) @@ -46,7 +47,7 @@ def test_episodic_returns(size=2560) -> None: truncated=np.array([0, 0, 0, 0, 0, 0, 0.0]), rew=np.array([7, 6, 1, 2, 3, 4, 5.0]), ) - for b in batch: + for b in iter(batch): b.obs = b.act = 1 buf.add(b) returns, _ = fn(batch, buf, buf.sample_indices(0), gamma=0.1, gae_lambda=1) @@ -58,7 +59,7 @@ def test_episodic_returns(size=2560) -> None: truncated=np.array([0, 0, 0, 0, 0, 0, 0]), rew=np.array([7, 6, 1, 2, 3, 4, 5.0]), ) - for b in batch: + for b in iter(batch): b.obs = b.act = 1 buf.add(b) returns, _ = fn(batch, buf, buf.sample_indices(0), gamma=0.1, gae_lambda=1) @@ -70,7 +71,7 @@ def test_episodic_returns(size=2560) -> None: truncated=np.array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]), rew=np.array([101, 102, 103.0, 200, 104, 105, 
106, 201, 107, 108, 109, 202]), ) - for b in batch: + for b in iter(batch): b.obs = b.act = 1 buf.add(b) v = np.array([2.0, 3.0, 4, -1, 5.0, 6.0, 7, -2, 8.0, 9.0, 10, -3]) @@ -118,7 +119,7 @@ def test_episodic_returns(size=2560) -> None: }, ), ) - for b in batch: + for b in iter(batch): b.obs = b.act = 1 buf.add(b) v = np.array([2.0, 3.0, 4, -1, 5.0, 6.0, 7, -2, 8.0, 9.0, 10, -3]) @@ -148,15 +149,15 @@ def test_episodic_returns(size=2560) -> None: truncated=np.zeros(size), rew=np.random.random(size), ) - for b in batch: + for b in iter(batch): b.obs = b.act = 1 buf.add(b) indices = buf.sample_indices(0) - def vanilla(): + def vanilla() -> Batch: return compute_episodic_return_base(batch, gamma=0.1) - def optimized(): + def optimized() -> tuple[np.ndarray, np.ndarray]: return fn(batch, buf, indices, gamma=0.1, gae_lambda=1.0) cnt = 3000 @@ -164,17 +165,22 @@ def optimized(): print("GAE optim ", timeit(optimized, setup=optimized, number=cnt)) -def target_q_fn(buffer, indices): +def target_q_fn(buffer: ReplayBuffer, indices: np.ndarray) -> torch.Tensor: # return the next reward indices = buffer.next(indices) return torch.tensor(-buffer.rew[indices], dtype=torch.float32) -def target_q_fn_multidim(buffer, indices): +def target_q_fn_multidim(buffer: ReplayBuffer, indices: np.ndarray) -> torch.Tensor: return target_q_fn(buffer, indices).unsqueeze(1).repeat(1, 51) -def compute_nstep_return_base(nstep, gamma, buffer, indices): +def compute_nstep_return_base( + nstep: int, + gamma: float, + buffer: ReplayBuffer, + indices: np.ndarray, +) -> np.ndarray: returns = np.zeros_like(indices, dtype=float) buf_len = len(buffer) for i in range(len(indices)): @@ -195,7 +201,7 @@ def compute_nstep_return_base(nstep, gamma, buffer, indices): return returns -def test_nstep_returns(size=10000) -> None: +def test_nstep_returns(size: int = 10000) -> None: buf = ReplayBuffer(10) for i in range(12): buf.add( @@ -273,7 +279,7 @@ def test_nstep_returns(size=10000) -> None: assert 
np.allclose(returns_multidim, returns[:, np.newaxis]) -def test_nstep_returns_with_timelimit(size=10000) -> None: +def test_nstep_returns_with_timelimit(size: int = 10000) -> None: buf = ReplayBuffer(10) for i in range(12): buf.add( @@ -366,10 +372,10 @@ def test_nstep_returns_with_timelimit(size=10000) -> None: ) batch, indices = buf.sample(256) - def vanilla(): + def vanilla() -> np.ndarray: return compute_nstep_return_base(3, 0.1, buf, indices) - def optimized(): + def optimized() -> BatchWithReturnsProtocol: return BasePolicy.compute_nstep_return( batch, buf, From 7947f534a819962e1f3488af26aaf457103566a5 Mon Sep 17 00:00:00 2001 From: daniel <1534513+dantp-ai@users.noreply.github.com> Date: Fri, 23 Feb 2024 14:08:01 +0100 Subject: [PATCH 003/115] Add type annotations to function signatures * Remove func `compute_reward_fn` as it was already defined earlier --- test/base/test_buffer.py | 19 ++++++++----------- 1 file changed, 8 insertions(+), 11 deletions(-) diff --git a/test/base/test_buffer.py b/test/base/test_buffer.py index 99154bbdf..717d16415 100644 --- a/test/base/test_buffer.py +++ b/test/base/test_buffer.py @@ -27,7 +27,7 @@ from test.base.env import MyGoalEnv, MyTestEnv -def test_replaybuffer(size=10, bufsize=20) -> None: +def test_replaybuffer(size: int = 10, bufsize: int = 20) -> None: env = MyTestEnv(size) buf = ReplayBuffer(bufsize) buf.update(buf) @@ -139,7 +139,7 @@ def test_replaybuffer(size=10, bufsize=20) -> None: assert np.all(b.next(np.array([0, 1, 2, 3])) == [0, 2, 2, 3]) -def test_ignore_obs_next(size=10) -> None: +def test_ignore_obs_next(size: int = 10) -> None: # Issue 82 buf = ReplayBuffer(size, ignore_obs_next=True) for i in range(size): @@ -208,7 +208,7 @@ def test_ignore_obs_next(size=10) -> None: assert data.obs_next -def test_stack(size=5, bufsize=9, stack_num=4, cached_num=3) -> None: +def test_stack(size: int = 5, bufsize: int = 9, stack_num: int = 4, cached_num: int = 3) -> None: env = MyTestEnv(size) buf = ReplayBuffer(bufsize,
stack_num=stack_num) buf2 = ReplayBuffer(bufsize, stack_num=stack_num, sample_avail=True) @@ -279,7 +279,7 @@ def test_stack(size=5, bufsize=9, stack_num=4, cached_num=3) -> None: buf[bufsize * 2] -def test_priortized_replaybuffer(size=32, bufsize=15) -> None: +def test_priortized_replaybuffer(size: int = 32, bufsize: int = 15) -> None: env = MyTestEnv(size) buf = PrioritizedReplayBuffer(bufsize, 0.5, 0.5) buf2 = PrioritizedVectorReplayBuffer(bufsize, buffer_num=3, alpha=0.5, beta=0.5) @@ -329,11 +329,11 @@ def test_priortized_replaybuffer(size=32, bufsize=15) -> None: assert weight[mask][0] <= 1 -def test_herreplaybuffer(size=10, bufsize=100, sample_sz=4) -> None: +def test_herreplaybuffer(size: int = 10, bufsize: int = 100, sample_sz: int = 4) -> None: env_size = size env = MyGoalEnv(env_size, array_state=True) - def compute_reward_fn(ag, g): + def compute_reward_fn(ag: np.ndarray, g: np.ndarray) -> np.ndarray: return env.compute_reward_fn(ag, g, {}) buf = HERReplayBuffer(bufsize, compute_reward_fn=compute_reward_fn, horizon=30, future_k=8) @@ -431,9 +431,6 @@ def compute_reward_fn(ag, g): bufsize = 15 env = MyGoalEnv(env_size, array_state=False) - def compute_reward_fn(ag, g): - return env.compute_reward_fn(ag, g, {}) - buf = HERReplayBuffer(bufsize, compute_reward_fn=compute_reward_fn, horizon=30, future_k=8) buf._index = 5 # shifted start index buf.future_p = 1 @@ -605,10 +602,10 @@ def test_segtree() -> None: tree = SegmentTree(size) tree[np.arange(size)] = naive - def sample_npbuf(): + def sample_npbuf() -> np.ndarray: return np.random.choice(size, bsz, p=naive / naive.sum()) - def sample_tree(): + def sample_tree() -> int | np.ndarray: scalar = np.random.rand(bsz) * tree.reduce() return tree.get_prefix_sum_idx(scalar) From 116349423acd95569ec15b8cc276c6ff255fb590 Mon Sep 17 00:00:00 2001 From: daniel <1534513+dantp-ai@users.noreply.github.com> Date: Fri, 23 Feb 2024 14:37:38 +0100 Subject: [PATCH 004/115] Rename variable to resolve type conflict * Same 
variable name used multiple times for objects of different types --- test/base/test_buffer.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/test/base/test_buffer.py b/test/base/test_buffer.py index 717d16415..4bfd0cea9 100644 --- a/test/base/test_buffer.py +++ b/test/base/test_buffer.py @@ -319,8 +319,8 @@ def test_priortized_replaybuffer(size: int = 32, bufsize: int = 15) -> None: assert np.allclose(buf.weight[indices], np.abs(-data.weight / 2) ** buf._alpha) # check multi buffer's data assert np.allclose(buf2[np.arange(buf2.maxsize)].weight, 1) - batch, indices = buf2.sample(10) - buf2.update_weight(indices, batch.weight * 0) + batch_sample, indices = buf2.sample(10) + buf2.update_weight(indices, batch_sample.weight * 0) weight = buf2[np.arange(buf2.maxsize)].weight mask = np.isin(np.arange(buf2.maxsize), indices) assert np.all(weight[mask] == weight[mask][0]) @@ -368,7 +368,7 @@ def compute_reward_fn(ag: np.ndarray, g: np.ndarray) -> np.ndarray: assert len(buf) == min(bufsize, i + 1) assert len(buf2) == min(bufsize, 3 * (i + 1)) - batch, indices = buf.sample(sample_sz) + batch_sample, indices = buf.sample(sample_sz) # Check that goals are the same for the episode (only 1 ep in buffer) tmp_indices = indices.copy() @@ -398,7 +398,7 @@ def compute_reward_fn(ag: np.ndarray, g: np.ndarray) -> np.ndarray: tmp_indices = buf.next(tmp_indices) # Test vector buffer - batch, indices = buf2.sample(sample_sz) + batch_sample, indices = buf2.sample(sample_sz) # Check that goals are the same for the episode (only 1 ep in buffer) tmp_indices = indices.copy() @@ -451,7 +451,7 @@ def compute_reward_fn(ag: np.ndarray, g: np.ndarray) -> np.ndarray: ) buf.add(batch) obs = obs_next - batch, indices = buf.sample(0) + batch_sample, indices = buf.sample(0) assert np.all(buf[:5].obs.desired_goal == buf[0].obs.desired_goal) assert np.all(buf[5:10].obs.desired_goal == buf[5].obs.desired_goal) assert np.all(buf[10:].obs.desired_goal == 
buf[0].obs.desired_goal) # (same ep) @@ -1179,7 +1179,7 @@ def test_multibuf_stack() -> None: buf.stack_num = 2 indices = buf5.sample_indices(0) assert np.allclose(sorted(indices), [0, 1, 2, 5, 6, 7, 10, 15, 20]) - batch, _ = buf5.sample(0) + batch_sample, _ = buf5.sample(0) # test Atari with CachedReplayBuffer, save_only_last_obs + ignore_obs_next buf6 = CachedReplayBuffer( ReplayBuffer(bufsize, stack_num=stack_num, save_only_last_obs=True, ignore_obs_next=True), From aa0f1310fe5af0ec60f74d873287b3841b467148 Mon Sep 17 00:00:00 2001 From: daniel <1534513+dantp-ai@users.noreply.github.com> Date: Fri, 23 Feb 2024 15:41:48 +0100 Subject: [PATCH 005/115] Provide right input type to buffer methods --- test/base/test_buffer.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/test/base/test_buffer.py b/test/base/test_buffer.py index 4bfd0cea9..7d4e1d8fd 100644 --- a/test/base/test_buffer.py +++ b/test/base/test_buffer.py @@ -857,7 +857,7 @@ def test_replaybuffermanager() -> None: assert np.all(ptr == [10]) assert np.all(ep_idx == [13]) assert np.allclose(buf.unfinished_index(), [4]) - indices = sorted(buf.sample_indices(0)) + indices = np.array(sorted(buf.sample_indices(0))) assert np.allclose(indices, np.arange(len(buf))) assert np.allclose( buf.prev(indices), @@ -910,8 +910,8 @@ def test_replaybuffermanager() -> None: ], ) # corner case: list, int and -1 - assert buf.prev(-1) == buf.prev([buf.maxsize - 1])[0] - assert buf.next(-1) == buf.next([buf.maxsize - 1])[0] + assert buf.prev(-1) == buf.prev(np.array([buf.maxsize - 1]))[0] + assert buf.next(-1) == buf.next(np.array([buf.maxsize - 1]))[0] batch = buf._meta batch.info = np.ones(buf.maxsize) buf.set_batch(batch) @@ -1128,7 +1128,7 @@ def test_multibuf_stack() -> None: ], ), buf4.done assert np.allclose(buf4.unfinished_index(), [10, 15, 20]) - indices = sorted(buf4.sample_indices(0)) + indices = np.array(sorted(buf4.sample_indices(0))) assert np.allclose(indices, [*list(range(bufsize)), 9, 
10, 14, 15, 19, 20]) assert np.allclose( buf4[indices].obs[..., 0], From 7c91575a69721270b4fe82c67439d70a8fd28ec3 Mon Sep 17 00:00:00 2001 From: daniel <1534513+dantp-ai@users.noreply.github.com> Date: Fri, 23 Feb 2024 16:12:41 +0100 Subject: [PATCH 006/115] Index sample from buffer according to cheatsheet recommendations --- test/base/test_buffer.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/test/base/test_buffer.py b/test/base/test_buffer.py index 7d4e1d8fd..649d5f5a9 100644 --- a/test/base/test_buffer.py +++ b/test/base/test_buffer.py @@ -452,10 +452,10 @@ def compute_reward_fn(ag: np.ndarray, g: np.ndarray) -> np.ndarray: buf.add(batch) obs = obs_next batch_sample, indices = buf.sample(0) - assert np.all(buf[:5].obs.desired_goal == buf[0].obs.desired_goal) - assert np.all(buf[5:10].obs.desired_goal == buf[5].obs.desired_goal) - assert np.all(buf[10:].obs.desired_goal == buf[0].obs.desired_goal) # (same ep) - assert np.all(buf[0].obs.desired_goal != buf[5].obs.desired_goal) # (diff ep) + assert np.all(buf.obs.desired_goal[:5] == buf.obs.desired_goal[0]) + assert np.all(buf.obs.desired_goal[5:10] == buf.obs.desired_goal[5]) + assert np.all(buf.obs.desired_goal[10:] == buf.obs.desired_goal[0]) # (same ep) + assert np.all(buf.obs.desired_goal[0] != buf.obs.desired_goal[5]) # (diff ep) # Another test case for cycled indices env_size = 99 From 76fe01e4aeaadc2b5f467431fcf88a7c5d3b4327 Mon Sep 17 00:00:00 2001 From: daniel <1534513+dantp-ai@users.noreply.github.com> Date: Fri, 23 Feb 2024 16:53:57 +0100 Subject: [PATCH 007/115] Type index to make mypy happy * index can be int or np.ndarray. 
since it's defined twice with the same name, mypy gets confused --- test/base/test_buffer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/base/test_buffer.py b/test/base/test_buffer.py index 649d5f5a9..fbbcde884 100644 --- a/test/base/test_buffer.py +++ b/test/base/test_buffer.py @@ -521,10 +521,10 @@ def test_segtree() -> None: assert np.all([tree[i] == 0.0 for i in range(actual_len)]) with pytest.raises(IndexError): tree[actual_len] - naive = np.zeros([actual_len]) + naive = np.zeros(actual_len) for _ in range(1000): # random choose a place to perform single update - index = np.random.randint(actual_len) + index: int | np.ndarray = np.random.randint(actual_len) value = np.random.rand() naive[index] = value tree[index] = value From 33df04c9c70ad4187ed6216fa511384a1a913db9 Mon Sep 17 00:00:00 2001 From: daniel <1534513+dantp-ai@users.noreply.github.com> Date: Fri, 23 Feb 2024 17:52:43 +0100 Subject: [PATCH 008/115] Make mypy happy and check for mask attribute before asserting tests --- test/base/test_buffer.py | 42 +++++++++++++++++++++------------------- 1 file changed, 22 insertions(+), 20 deletions(-) diff --git a/test/base/test_buffer.py b/test/base/test_buffer.py index fbbcde884..5a042ed23 100644 --- a/test/base/test_buffer.py +++ b/test/base/test_buffer.py @@ -164,29 +164,31 @@ def test_ignore_obs_next(size: int = 10) -> None: assert isinstance(data, Batch) assert isinstance(data2, Batch) assert np.allclose(indices, orig) - assert np.allclose(data.obs_next.mask, data2.obs_next.mask) - assert np.allclose(data.obs_next.mask, [0, 2, 3, 3, 5, 6, 6, 8, 9, 9]) + if hasattr(data.obs_next, "mask") and hasattr(data2.obs_next, "mask"): + assert np.allclose(data.obs_next.mask, data2.obs_next.mask) + assert np.allclose(data.obs_next.mask, [0, 2, 3, 3, 5, 6, 6, 8, 9, 9]) buf.stack_num = 4 data = buf[indices] data2 = buf[indices] - assert np.allclose(data.obs_next.mask, data2.obs_next.mask) - assert np.allclose( - data.obs_next.mask, - 
np.array( - [ - [0, 0, 0, 0], - [1, 1, 1, 2], - [1, 1, 2, 3], - [1, 1, 2, 3], - [4, 4, 4, 5], - [4, 4, 5, 6], - [4, 4, 5, 6], - [7, 7, 7, 8], - [7, 7, 8, 9], - [7, 7, 8, 9], - ], - ), - ) + if hasattr(data.obs_next, "mask") and hasattr(data2.obs_next, "mask"): + assert np.allclose(data.obs_next.mask, data2.obs_next.mask) + assert np.allclose( + data.obs_next.mask, + np.array( + [ + [0, 0, 0, 0], + [1, 1, 1, 2], + [1, 1, 2, 3], + [1, 1, 2, 3], + [4, 4, 4, 5], + [4, 4, 5, 6], + [4, 4, 5, 6], + [7, 7, 7, 8], + [7, 7, 8, 9], + [7, 7, 8, 9], + ], + ), + ) assert np.allclose(data.info["if"], data2.info["if"]) assert np.allclose( data.info["if"], From ba5d74e0bdae904c565a08bd058aeebb8e81ffaa Mon Sep 17 00:00:00 2001 From: daniel <1534513+dantp-ai@users.noreply.github.com> Date: Fri, 23 Feb 2024 18:06:29 +0100 Subject: [PATCH 009/115] Make mypy happy and specify union of types --- test/base/test_buffer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/base/test_buffer.py b/test/base/test_buffer.py index 5a042ed23..d6c6224a8 100644 --- a/test/base/test_buffer.py +++ b/test/base/test_buffer.py @@ -527,7 +527,7 @@ def test_segtree() -> None: for _ in range(1000): # random choose a place to perform single update index: int | np.ndarray = np.random.randint(actual_len) - value = np.random.rand() + value: float | np.ndarray = np.random.rand() naive[index] = value tree[index] = value for i in range(actual_len): From 8d03ea4d7e736017de30b8edd01c6da1a97556c5 Mon Sep 17 00:00:00 2001 From: daniel <1534513+dantp-ai@users.noreply.github.com> Date: Fri, 23 Feb 2024 18:38:25 +0100 Subject: [PATCH 010/115] Make mypy happy and type ndarray --- test/base/test_buffer.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/test/base/test_buffer.py b/test/base/test_buffer.py index d6c6224a8..ada6d2a8f 100644 --- a/test/base/test_buffer.py +++ b/test/base/test_buffer.py @@ -5,6 +5,7 @@ import h5py import numpy as np +import numpy.typing as npt import 
pytest import torch @@ -1284,7 +1285,7 @@ def test_multibuf_hdf5() -> None: def test_from_data() -> None: - obs_data = np.ndarray((10, 3, 3), dtype="uint8") + obs_data: npt.NDArray[np.uint8] = np.ndarray((10, 3, 3), dtype="uint8") for i in range(10): obs_data[i] = i * np.ones((3, 3), dtype="uint8") obs_next_data = np.zeros_like(obs_data) From 13ae7a891443b5388e5dcccffafaf45ecf1d0452 Mon Sep 17 00:00:00 2001 From: daniel <1534513+dantp-ai@users.noreply.github.com> Date: Fri, 23 Feb 2024 18:39:07 +0100 Subject: [PATCH 011/115] Use recommended outer buffer indexing --- test/base/test_buffer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/base/test_buffer.py b/test/base/test_buffer.py index ada6d2a8f..0dfbbf59e 100644 --- a/test/base/test_buffer.py +++ b/test/base/test_buffer.py @@ -508,8 +508,8 @@ def test_update() -> None: assert len(buf1) > len(buf2) buf2.update(buf1) assert len(buf1) == len(buf2) - assert (buf2[0].obs == buf1[1].obs).all() - assert (buf2[-1].obs == buf1[0].obs).all() + assert (buf2.obs[0] == buf1.obs[1]).all() + assert (buf2.obs[-1] == buf1.obs[0]).all() b = CachedReplayBuffer(ReplayBuffer(10), 4, 5) with pytest.raises(NotImplementedError): b.update(b) From f282d814afb3b1a76d4cb97535051de752e9527f Mon Sep 17 00:00:00 2001 From: daniel <1534513+dantp-ai@users.noreply.github.com> Date: Fri, 23 Feb 2024 21:19:30 +0100 Subject: [PATCH 012/115] Access buffer attrs with __getattr__ --- test/base/test_buffer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/base/test_buffer.py b/test/base/test_buffer.py index 0dfbbf59e..47ecd22d8 100644 --- a/test/base/test_buffer.py +++ b/test/base/test_buffer.py @@ -699,8 +699,8 @@ def test_hdf5() -> None: assert isinstance(buffers[k].get(0, "info"), Batch) assert isinstance(_buffers[k].get(0, "info"), Batch) for k in ["array"]: - assert np.all(buffers[k][:].info.number.n == _buffers[k][:].info.number.n) - assert np.all(buffers[k][:].info.extra == 
_buffers[k][:].info.extra) + assert np.all(buffers[k][:]["info"].number.n == _buffers[k][:]["info"].number.n) + assert np.all(buffers[k][:]["info"]["extra"] == _buffers[k][:]["info"]["extra"]) # raise exception when value cannot be pickled data = {"not_supported": lambda x: x * x} From 280596721df9923dbf8748e58468ad1263e9f289 Mon Sep 17 00:00:00 2001 From: daniel <1534513+dantp-ai@users.noreply.github.com> Date: Thu, 29 Feb 2024 17:51:55 +0100 Subject: [PATCH 013/115] Ignore mypy issue as it explicitly tests for invalid type --- test/base/test_batch.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/base/test_batch.py b/test/base/test_batch.py index aaacffdd4..e22b0b235 100644 --- a/test/base/test_batch.py +++ b/test/base/test_batch.py @@ -350,9 +350,9 @@ def test_batch_cat_and_stack() -> None: # test with illegal input format with pytest.raises(ValueError): - Batch.cat([[Batch(a=1)], [Batch(a=1)]]) + Batch.cat([[Batch(a=1)], [Batch(a=1)]]) # type: ignore # cat() tested with invalid inp with pytest.raises(ValueError): - Batch.stack([[Batch(a=1)], [Batch(a=1)]]) + Batch.stack([[Batch(a=1)], [Batch(a=1)]]) # type: ignore # stack() tested with invalid inp # exceptions assert Batch.cat([]).is_empty() From eaa45034a2b67999dcc29e1c1dcf87cb5f8d1978 Mon Sep 17 00:00:00 2001 From: daniel <1534513+dantp-ai@users.noreply.github.com> Date: Thu, 29 Feb 2024 22:05:13 +0100 Subject: [PATCH 014/115] Make mypy happy & use explicit var typing and ignore --- test/base/test_batch.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/test/base/test_batch.py b/test/base/test_batch.py index e22b0b235..3c5f1c4cc 100644 --- a/test/base/test_batch.py +++ b/test/base/test_batch.py @@ -135,7 +135,7 @@ def test_batch() -> None: assert batch_slice.a.c == batch2.a.c assert batch_slice.a.d.e == batch2.a.d.e batch2.a.d.f = {} - batch2_sum = (batch2 + 1.0) * 2 + batch2_sum = (batch2 + 1.0) * 2 # type: ignore # __add__ supports Number as input 
type assert batch2_sum.a.b == (batch2.a.b + 1.0) * 2 assert batch2_sum.a.c == (batch2.a.c + 1.0) * 2 assert batch2_sum.a.d.e == (batch2.a.d.e + 1.0) * 2 @@ -355,8 +355,10 @@ def test_batch_cat_and_stack() -> None: Batch.stack([[Batch(a=1)], [Batch(a=1)]]) # type: ignore # stack() tested with invalid inp # exceptions - assert Batch.cat([]).is_empty() - assert Batch.stack([]).is_empty() + batch_cat: Batch = Batch.cat([]) + assert batch_cat.is_empty() + batch_stack: Batch = Batch.stack([]) + assert batch_stack.is_empty() b1 = Batch(e=[4, 5], d=6) b2 = Batch(e=[4, 6]) with pytest.raises(ValueError): @@ -548,7 +550,7 @@ def test_batch_empty() -> None: def test_batch_standard_compatibility() -> None: batch = Batch(a=np.array([[1.0, 2.0], [3.0, 4.0]]), b=Batch(), c=np.array([5.0, 6.0])) batch_mean = np.mean(batch) - assert isinstance(batch_mean, Batch) + assert isinstance(batch_mean, Batch) # type: ignore # mypy doesn't know but it works, cf. `batch.rst` assert sorted(batch_mean.keys()) == ["a", "b", "c"] with pytest.raises(TypeError): len(batch_mean) From cb765b4b9eaa894b09c9231d54f7d9d48bfe53c9 Mon Sep 17 00:00:00 2001 From: daniel <1534513+dantp-ai@users.noreply.github.com> Date: Thu, 29 Feb 2024 22:45:56 +0100 Subject: [PATCH 015/115] Ignore type on explicit error --- test/base/test_batch.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/base/test_batch.py b/test/base/test_batch.py index 3c5f1c4cc..ef5812203 100644 --- a/test/base/test_batch.py +++ b/test/base/test_batch.py @@ -141,7 +141,7 @@ def test_batch() -> None: assert batch2_sum.a.d.e == (batch2.a.d.e + 1.0) * 2 assert batch2_sum.a.d.f.is_empty() with pytest.raises(TypeError): - batch2 += [1] + batch2 += [1] # type: ignore # error is raised explicitly batch3 = Batch(a={"c": np.zeros(1), "d": Batch(e=np.array([0.0]), f=np.array([3.0]))}) batch3.a.d[0] = {"e": 4.0} assert batch3.a.d.e[0] == 4.0 From 1aacd102ce21388862bd8a0bd8506111c66a822e Mon Sep 17 00:00:00 2001 From: daniel 
<1534513+dantp-ai@users.noreply.github.com> Date: Thu, 29 Feb 2024 22:46:35 +0100 Subject: [PATCH 016/115] Remove redundant assert --- test/base/test_batch.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/base/test_batch.py b/test/base/test_batch.py index ef5812203..86958184f 100644 --- a/test/base/test_batch.py +++ b/test/base/test_batch.py @@ -251,7 +251,7 @@ def test_batch_cat_and_stack() -> None: assert np.allclose(ans.b, np.concatenate([a.b, b.b, a.b])) assert ans.a.t.is_empty() - assert b1.stack_([b2]) is None + b1.stack_([b2]) assert isinstance(b1.a.d.e, np.ndarray) assert b1.a.d.e.ndim == 2 From ef9581bc291fb89b695afd95b1462819e6910f55 Mon Sep 17 00:00:00 2001 From: daniel <1534513+dantp-ai@users.noreply.github.com> Date: Sun, 3 Mar 2024 20:54:44 +0100 Subject: [PATCH 017/115] Add typing to func args --- examples/vizdoom/env.py | 19 +++++++++++++------ examples/vizdoom/replay.py | 2 +- 2 files changed, 14 insertions(+), 7 deletions(-) diff --git a/examples/vizdoom/env.py b/examples/vizdoom/env.py index f5c974fa0..55348d902 100644 --- a/examples/vizdoom/env.py +++ b/examples/vizdoom/env.py @@ -1,4 +1,5 @@ import os +from collections.abc import Sequence import cv2 import gymnasium as gym @@ -13,7 +14,7 @@ envpool = None -def normal_button_comb(): +def normal_button_comb() -> list: actions = [] m_forward = [[0.0], [1.0]] t_left_right = [[0.0, 0.0], [0.0, 1.0], [1.0, 0.0]] @@ -23,7 +24,7 @@ def normal_button_comb(): return actions -def battle_button_comb(): +def battle_button_comb() -> list: actions = [] m_forward_backward = [[0.0, 0.0], [0.0, 1.0], [1.0, 0.0]] m_left_right = [[0.0, 0.0], [0.0, 1.0], [1.0, 0.0]] @@ -41,7 +42,13 @@ def battle_button_comb(): class Env(gym.Env): - def __init__(self, cfg_path, frameskip=4, res=(4, 40, 60), save_lmp=False) -> None: + def __init__( + self, + cfg_path: str, + frameskip: int = 4, + res: Sequence[int] = (4, 40, 60), + save_lmp: bool = False, + ) -> None: super().__init__() self.save_lmp = 
save_lmp self.health_setting = "battle" in cfg_path @@ -62,7 +69,7 @@ def __init__(self, cfg_path, frameskip=4, res=(4, 40, 60), save_lmp=False) -> No self.spec = gym.envs.registration.EnvSpec("vizdoom-v0") self.count = 0 - def get_obs(self): + def get_obs(self) -> None: state = self.game.get_state() if state is None: return @@ -107,10 +114,10 @@ def step(self, action): info["TimeLimit.truncated"] = True return self.obs_buffer, reward, done, info - def render(self): + def render(self) -> None: pass - def close(self): + def close(self) -> None: self.game.close() diff --git a/examples/vizdoom/replay.py b/examples/vizdoom/replay.py index 4437a08ba..ac6e183e4 100755 --- a/examples/vizdoom/replay.py +++ b/examples/vizdoom/replay.py @@ -6,7 +6,7 @@ import vizdoom as vzd -def main(cfg_path="maps/D3_battle.cfg", lmp_path="test.lmp") -> None: +def main(cfg_path: str = "maps/D3_battle.cfg", lmp_path: str = "test.lmp") -> None: game = vzd.DoomGame() game.load_config(cfg_path) game.set_screen_format(vzd.ScreenFormat.CRCGCB) From 0356d8d2e032a175c474db71f414864bc42d02f0 Mon Sep 17 00:00:00 2001 From: daniel <1534513+dantp-ai@users.noreply.github.com> Date: Mon, 4 Mar 2024 10:43:54 +0100 Subject: [PATCH 018/115] Add type annotations to funcs --- examples/mujoco/gen_json.py | 3 ++- examples/mujoco/tools.py | 20 +++++++++++++++----- 2 files changed, 17 insertions(+), 6 deletions(-) diff --git a/examples/mujoco/gen_json.py b/examples/mujoco/gen_json.py index b41b06b15..54ee3bb90 100755 --- a/examples/mujoco/gen_json.py +++ b/examples/mujoco/gen_json.py @@ -4,9 +4,10 @@ import json import os import sys +from os import PathLike -def merge(rootdir): +def merge(rootdir: str | PathLike[str]) -> None: """format: $rootdir/$algo/*.csv.""" result = [] for path, _, filenames in os.walk(rootdir): diff --git a/examples/mujoco/tools.py b/examples/mujoco/tools.py index be289e33a..e0db8162b 100755 --- a/examples/mujoco/tools.py +++ b/examples/mujoco/tools.py @@ -5,13 +5,16 @@ import os import re 
from collections import defaultdict +from os import PathLike +from re import Pattern +from typing import Any import numpy as np import tqdm from tensorboard.backend.event_processing import event_accumulator -def find_all_files(root_dir, pattern): +def find_all_files(root_dir: str | PathLike[str], pattern: str | Pattern[str]) -> list: """Find all files under root_dir according to relative pattern.""" file_list = [] for dirname, _, files in os.walk(root_dir): @@ -22,7 +25,7 @@ def find_all_files(root_dir, pattern): return file_list -def group_files(file_list, pattern): +def group_files(file_list: list[str], pattern: str | Pattern[str]) -> dict[str, list]: res = defaultdict(list) for f in file_list: match = re.search(pattern, f) @@ -31,7 +34,7 @@ def group_files(file_list, pattern): return res -def csv2numpy(csv_file): +def csv2numpy(csv_file: str) -> dict[Any, np.ndarray]: csv_dict = defaultdict(list) with open(csv_file) as f: for row in csv.DictReader(f): @@ -40,7 +43,10 @@ def csv2numpy(csv_file): return {k: np.array(v) for k, v in csv_dict.items()} -def convert_tfevents_to_csv(root_dir, refresh=False): +def convert_tfevents_to_csv( + root_dir: str | PathLike[str], + refresh: bool = False, +) -> dict[str, list]: """Recursively convert test/reward from all tfevent file under root_dir to csv. 
This function assumes that there is at most one tfevents file in each directory @@ -81,7 +87,11 @@ def convert_tfevents_to_csv(root_dir, refresh=False): return result -def merge_csv(csv_files, root_dir, remove_zero=False): +def merge_csv( + csv_files: dict[str, list], + root_dir: str | PathLike[str], + remove_zero: bool = False, +) -> None: """Merge result in csv_files into a single csv file.""" assert len(csv_files) > 0 if remove_zero: From c48d50f4ce25302cd9a06e9df89dcb98d4cb2094 Mon Sep 17 00:00:00 2001 From: daniel <1534513+dantp-ai@users.noreply.github.com> Date: Tue, 5 Mar 2024 23:20:35 +0100 Subject: [PATCH 019/115] Fix mypy issues --- examples/mujoco/plotter.py | 58 ++++++++++++++++++++++---------------- 1 file changed, 33 insertions(+), 25 deletions(-) diff --git a/examples/mujoco/plotter.py b/examples/mujoco/plotter.py index 60840decf..ee86f712f 100755 --- a/examples/mujoco/plotter.py +++ b/examples/mujoco/plotter.py @@ -3,6 +3,7 @@ import argparse import os import re +from typing import Any import matplotlib.pyplot as plt import matplotlib.ticker as mticker @@ -10,7 +11,12 @@ from tools import csv2numpy, find_all_files, group_files -def smooth(y, radius, mode="two_sided", valid_only=False): +def smooth( + y: np.ndarray, + radius: int, + mode: str = "two_sided", + valid_only: bool = False, +) -> np.ndarray: """Smooth signal y, where radius is determines the size of the window. 
mode='twosided': @@ -88,23 +94,25 @@ def smooth(y, radius, mode="two_sided", valid_only=False): def plot_ax( - ax, - file_lists, - legend_pattern=".*", - xlabel=None, - ylabel=None, - title=None, - xlim=None, - xkey="env_step", - ykey="reward", - smooth_radius=0, - shaded_std=True, - legend_outside=False, -): - def legend_fn(x): + ax: plt.Axes, + file_lists: list[str], + legend_pattern: str = ".*", + xlabel: str | None = None, + ylabel: str | None = None, + title: str = "", + xlim: float | None = None, + xkey: str = "env_step", + ykey: str = "reward", + smooth_radius: int = 0, + shaded_std: bool = True, + legend_outside: bool = False, +) -> None: + def legend_fn(x: str) -> str: # return os.path.split(os.path.join( # args.root_dir, x))[0].replace('/', '_') + " (10)" - return re.search(legend_pattern, x).group(0) + match = re.search(legend_pattern, x) + assert match is not None # for mypy + return match.group(0) legneds = map(legend_fn, file_lists) # sort filelist according to legends @@ -139,15 +147,15 @@ def legend_fn(x): def plot_figure( - file_lists, - group_pattern=None, - fig_length=6, - fig_width=6, - sharex=False, - sharey=False, - title=None, - **kwargs, -): + file_lists: list[str], + group_pattern: str | None = None, + fig_length: int = 6, + fig_width: int = 6, + sharex: bool = False, + sharey: bool = False, + title: str = "", + **kwargs: Any, +) -> None: if not group_pattern: fig, ax = plt.subplots(figsize=(fig_length, fig_width)) plot_ax(ax, file_lists, title=title, **kwargs) From ebacf998a7f03c56dc2fb91d563cb556b120df59 Mon Sep 17 00:00:00 2001 From: daniel <1534513+dantp-ai@users.noreply.github.com> Date: Wed, 6 Mar 2024 13:38:09 +0100 Subject: [PATCH 020/115] Use DataclassPPrintMixin to print collector stats --- examples/atari/atari_c51.py | 3 +-- examples/atari/atari_dqn.py | 3 +-- examples/atari/atari_fqf.py | 3 +-- examples/atari/atari_iqn.py | 3 +-- examples/atari/atari_ppo.py | 3 +-- examples/atari/atari_qrdqn.py | 3 +-- 
examples/atari/atari_rainbow.py | 3 +-- examples/atari/atari_sac.py | 3 +-- examples/offline/atari_il.py | 4 +--- examples/vizdoom/vizdoom_c51.py | 5 +---- examples/vizdoom/vizdoom_ppo.py | 5 +---- test/discrete/test_bdq.py | 2 +- test/modelbased/test_psrl.py | 2 +- 13 files changed, 13 insertions(+), 29 deletions(-) diff --git a/examples/atari/atari_c51.py b/examples/atari/atari_c51.py index cde4d18b1..c6fe6dd04 100644 --- a/examples/atari/atari_c51.py +++ b/examples/atari/atari_c51.py @@ -183,8 +183,7 @@ def watch() -> None: print("Testing agent ...") test_collector.reset() result = test_collector.collect(n_episode=args.test_num, render=args.render) - rew = result.returns_stat.mean - print(f"Mean reward (over {result['n/ep']} episodes): {rew}") + result.pprint_asdict() if args.watch: watch() diff --git a/examples/atari/atari_dqn.py b/examples/atari/atari_dqn.py index 7517d6f9b..1577315cf 100644 --- a/examples/atari/atari_dqn.py +++ b/examples/atari/atari_dqn.py @@ -224,8 +224,7 @@ def watch() -> None: print("Testing agent ...") test_collector.reset() result = test_collector.collect(n_episode=args.test_num, render=args.render) - rew = result.returns_stat.mean - print(f"Mean reward (over {result['n/ep']} episodes): {rew}") + result.pprint_asdict() if args.watch: watch() diff --git a/examples/atari/atari_fqf.py b/examples/atari/atari_fqf.py index c7a744bbc..f616a6838 100644 --- a/examples/atari/atari_fqf.py +++ b/examples/atari/atari_fqf.py @@ -196,8 +196,7 @@ def watch() -> None: print("Testing agent ...") test_collector.reset() result = test_collector.collect(n_episode=args.test_num, render=args.render) - rew = result.returns_stat.mean - print(f"Mean reward (over {result['n/ep']} episodes): {rew}") + result.pprint_asdict() if args.watch: watch() diff --git a/examples/atari/atari_iqn.py b/examples/atari/atari_iqn.py index f468f5457..911069400 100644 --- a/examples/atari/atari_iqn.py +++ b/examples/atari/atari_iqn.py @@ -193,8 +193,7 @@ def watch() -> None: 
print("Testing agent ...") test_collector.reset() result = test_collector.collect(n_episode=args.test_num, render=args.render) - rew = result.returns_stat.mean - print(f"Mean reward (over {result['n/ep']} episodes): {rew}") + result.pprint_asdict() if args.watch: watch() diff --git a/examples/atari/atari_ppo.py b/examples/atari/atari_ppo.py index 6fa4eca3c..1b635a48e 100644 --- a/examples/atari/atari_ppo.py +++ b/examples/atari/atari_ppo.py @@ -253,8 +253,7 @@ def watch() -> None: print("Testing agent ...") test_collector.reset() result = test_collector.collect(n_episode=args.test_num, render=args.render) - rew = result.returns_stat.mean - print(f"Mean reward (over {result['n/ep']} episodes): {rew}") + result.pprint_asdict() if args.watch: watch() diff --git a/examples/atari/atari_qrdqn.py b/examples/atari/atari_qrdqn.py index a4a29d84c..f33c5da5b 100644 --- a/examples/atari/atari_qrdqn.py +++ b/examples/atari/atari_qrdqn.py @@ -179,8 +179,7 @@ def watch() -> None: print("Testing agent ...") test_collector.reset() result = test_collector.collect(n_episode=args.test_num, render=args.render) - rew = result.returns_stat.mean - print(f"Mean reward (over {result['n/ep']} episodes): {rew}") + result.pprint_asdict() if args.watch: watch() diff --git a/examples/atari/atari_rainbow.py b/examples/atari/atari_rainbow.py index fcf6a19ca..86e7fe0e1 100644 --- a/examples/atari/atari_rainbow.py +++ b/examples/atari/atari_rainbow.py @@ -223,8 +223,7 @@ def watch() -> None: print("Testing agent ...") test_collector.reset() result = test_collector.collect(n_episode=args.test_num, render=args.render) - rew = result.returns_stat.mean - print(f"Mean reward (over {result['n/ep']} episodes): {rew}") + result.pprint_asdict() if args.watch: watch() diff --git a/examples/atari/atari_sac.py b/examples/atari/atari_sac.py index 72d356ec1..f4dbe57c1 100644 --- a/examples/atari/atari_sac.py +++ b/examples/atari/atari_sac.py @@ -235,8 +235,7 @@ def watch() -> None: print("Testing agent ...") 
test_collector.reset() result = test_collector.collect(n_episode=args.test_num, render=args.render) - rew = result.returns_stat.mean - print(f"Mean reward (over {result['n/ep']} episodes): {rew}") + result.pprint_asdict() if args.watch: watch() diff --git a/examples/offline/atari_il.py b/examples/offline/atari_il.py index d897d7fe7..a1bb62f70 100644 --- a/examples/offline/atari_il.py +++ b/examples/offline/atari_il.py @@ -145,9 +145,7 @@ def watch() -> None: print("Testing agent ...") test_collector.reset() result = test_collector.collect(n_episode=args.test_num, render=args.render) - pprint.pprint(result) - rew = result.returns_stat.mean - print(f"Mean reward (over {result.n_collected_episodes} episodes): {rew}") + result.pprint_asdict() if args.watch: watch() diff --git a/examples/vizdoom/vizdoom_c51.py b/examples/vizdoom/vizdoom_c51.py index 997996b32..62daaf64f 100644 --- a/examples/vizdoom/vizdoom_c51.py +++ b/examples/vizdoom/vizdoom_c51.py @@ -189,10 +189,7 @@ def watch() -> None: print("Testing agent ...") test_collector.reset() result = test_collector.collect(n_episode=args.test_num, render=args.render) - rew = result.returns_stat.mean - lens = result.lens_stat.mean * args.skip_num - print(f"Mean reward (over {result.n_collected_episodes} episodes): {rew}") - print(f"Mean length (over {result.n_collected_episodes} episodes): {lens}") + result.pprint_asdict() if args.watch: watch() diff --git a/examples/vizdoom/vizdoom_ppo.py b/examples/vizdoom/vizdoom_ppo.py index c58281d89..08538036a 100644 --- a/examples/vizdoom/vizdoom_ppo.py +++ b/examples/vizdoom/vizdoom_ppo.py @@ -255,10 +255,7 @@ def watch() -> None: print("Testing agent ...") test_collector.reset() result = test_collector.collect(n_episode=args.test_num, render=args.render) - rew = result.returns_stat.mean - lens = result.lens_stat.mean * args.skip_num - print(f"Mean reward (over {result.n_collected_episodes} episodes): {rew}") - print(f"Mean length (over {result.n_collected_episodes} episodes): 
{lens}") + result.pprint_asdict() if args.watch: watch() diff --git a/test/discrete/test_bdq.py b/test/discrete/test_bdq.py index a45b02114..b1750d49d 100644 --- a/test/discrete/test_bdq.py +++ b/test/discrete/test_bdq.py @@ -145,7 +145,7 @@ def stop_fn(mean_rewards: float) -> bool: test_envs.seed(args.seed) test_collector.reset() collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render) - print(collector_stats) + collector_stats.pprint_asdict() if __name__ == "__main__": diff --git a/test/modelbased/test_psrl.py b/test/modelbased/test_psrl.py index 55719f47e..05691645f 100644 --- a/test/modelbased/test_psrl.py +++ b/test/modelbased/test_psrl.py @@ -122,7 +122,7 @@ def stop_fn(mean_rewards: float) -> bool: test_envs.seed(args.seed) test_collector.reset() result = test_collector.collect(n_episode=args.test_num, render=args.render) - print(f"Final reward: {result.rew_mean}, length: {result.len_mean}") + result.pprint_asdict() elif env.spec.reward_threshold: assert result.best_reward >= env.spec.reward_threshold From b6accd4228f61b8a583b136d92ed7216be0dd4c2 Mon Sep 17 00:00:00 2001 From: daniel <1534513+dantp-ai@users.noreply.github.com> Date: Wed, 6 Mar 2024 13:54:55 +0100 Subject: [PATCH 021/115] Add type annotations to func --- examples/mujoco/analysis.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/examples/mujoco/analysis.py b/examples/mujoco/analysis.py index 9c2983e04..b881cdd34 100755 --- a/examples/mujoco/analysis.py +++ b/examples/mujoco/analysis.py @@ -3,13 +3,14 @@ import argparse import re from collections import defaultdict +from os import PathLike import numpy as np from tabulate import tabulate from tools import csv2numpy, find_all_files, group_files -def numerical_analysis(root_dir, xlim, norm=False): +def numerical_analysis(root_dir: str | PathLike, xlim: float, norm: bool = False) -> None: file_pattern = re.compile(r".*/test_reward_\d+seeds.csv$") norm_group_pattern = 
re.compile(r"(/|^)\w+?\-v(\d|$)") output_group_pattern = re.compile(r".*?(?=(/|^)\w+?\-v\d)") From 52bb1e35f3154e714e347e9fd85c0585d9ebaad4 Mon Sep 17 00:00:00 2001 From: daniel <1534513+dantp-ai@users.noreply.github.com> Date: Wed, 6 Mar 2024 14:21:45 +0100 Subject: [PATCH 022/115] Add type annotations to funcs * Make mypy happy and annotate some vars --- examples/offline/convert_rl_unplugged_atari.py | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/examples/offline/convert_rl_unplugged_atari.py b/examples/offline/convert_rl_unplugged_atari.py index a28a35e5f..1afd721a5 100755 --- a/examples/offline/convert_rl_unplugged_atari.py +++ b/examples/offline/convert_rl_unplugged_atari.py @@ -28,10 +28,11 @@ clipping. """ import os -from argparse import ArgumentParser +from argparse import ArgumentParser, Namespace import h5py import numpy as np +import numpy.typing as npt import requests import tensorflow as tf from tqdm import tqdm @@ -172,7 +173,7 @@ def _tf_example_to_tianshou_batch(tf_example: tf.train.Example) -> Batch: # Adapted From https://gist.github.com/yanqd0/c13ed29e29432e3cf3e7c38467f42f51 -def download(url: str, fname: str, chunk_size=1024): +def download(url: str, fname: str, chunk_size: int | None = 1024) -> None: resp = requests.get(url, stream=True) total = int(resp.headers.get("content-length", 0)) if os.path.exists(fname): @@ -192,11 +193,11 @@ def download(url: str, fname: str, chunk_size=1024): def process_shard(url: str, fname: str, ofname: str, maxsize: int = 500000) -> None: download(url, fname) - obs = np.ndarray((maxsize, 4, 84, 84), dtype="uint8") - act = np.ndarray((maxsize,), dtype="int64") - rew = np.ndarray((maxsize,), dtype="float32") - done = np.ndarray((maxsize,), dtype="bool") - obs_next = np.ndarray((maxsize, 4, 84, 84), dtype="uint8") + obs: npt.NDArray[np.uint8] = np.ndarray((maxsize, 4, 84, 84), dtype="uint8") + act: npt.NDArray[np.int64] = np.ndarray((maxsize,), dtype="int64") + rew: 
npt.NDArray[np.float32] = np.ndarray((maxsize,), dtype="float32") + done: npt.NDArray[np.bool_] = np.ndarray((maxsize,), dtype="bool") + obs_next: npt.NDArray[np.uint8] = np.ndarray((maxsize, 4, 84, 84), dtype="uint8") i = 0 file_ds = tf.data.TFRecordDataset(fname, compression_type="GZIP") for example in file_ds: @@ -238,7 +239,7 @@ def process_dataset( process_shard(url, filepath, ofname) -def main(args) -> None: +def main(args: Namespace) -> None: if args.task not in ALL_GAMES: raise KeyError(f"`{args.task}` is not in the list of games.") fn = _filename(args.run_id, args.shard_id, total_num_shards=args.total_num_shards) From 14b769f6d285b150f261c9e81e5a7aa8810feef0 Mon Sep 17 00:00:00 2001 From: daniel <1534513+dantp-ai@users.noreply.github.com> Date: Wed, 6 Mar 2024 15:05:49 +0100 Subject: [PATCH 023/115] Fix policy annotation --- examples/vizdoom/vizdoom_ppo.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/vizdoom/vizdoom_ppo.py b/examples/vizdoom/vizdoom_ppo.py index 08538036a..f698497b8 100644 --- a/examples/vizdoom/vizdoom_ppo.py +++ b/examples/vizdoom/vizdoom_ppo.py @@ -140,7 +140,7 @@ def test_ppo(args: argparse.Namespace = get_args()) -> None: def dist(logits: torch.Tensor) -> Distribution: return Categorical(logits=logits) - policy: PPOPolicy = PPOPolicy( + policy: BasePolicy = PPOPolicy( actor=actor, critic=critic, optim=optim, From 59f9bb10795c33fb3f5df404c0f9477a819da1df Mon Sep 17 00:00:00 2001 From: daniel <1534513+dantp-ai@users.noreply.github.com> Date: Wed, 6 Mar 2024 15:11:04 +0100 Subject: [PATCH 024/115] Add type annotations func --- examples/mujoco/mujoco_redq.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/mujoco/mujoco_redq.py b/examples/mujoco/mujoco_redq.py index 7a82a04b7..8eb9c3658 100755 --- a/examples/mujoco/mujoco_redq.py +++ b/examples/mujoco/mujoco_redq.py @@ -96,7 +96,7 @@ def test_redq(args: argparse.Namespace = get_args()) -> None: ).to(args.device) actor_optim = 
torch.optim.Adam(actor.parameters(), lr=args.actor_lr) - def linear(x, y): + def linear(x: int, y: int) -> EnsembleLinear: return EnsembleLinear(args.ensemble_size, x, y) net_c = Net( From ba6df442ebe451540369a17ab954aaaedb3297d4 Mon Sep 17 00:00:00 2001 From: daniel <1534513+dantp-ai@users.noreply.github.com> Date: Wed, 6 Mar 2024 15:59:24 +0100 Subject: [PATCH 025/115] Type policy var to resolve mypy confusion --- examples/atari/atari_dqn.py | 3 ++- examples/atari/atari_ppo.py | 3 ++- examples/atari/atari_sac.py | 3 ++- examples/box2d/bipedal_bdq.py | 2 +- examples/mujoco/mujoco_ddpg.py | 1 + examples/vizdoom/vizdoom_ppo.py | 3 ++- 6 files changed, 10 insertions(+), 5 deletions(-) diff --git a/examples/atari/atari_dqn.py b/examples/atari/atari_dqn.py index 1577315cf..765463cd3 100644 --- a/examples/atari/atari_dqn.py +++ b/examples/atari/atari_dqn.py @@ -104,7 +104,8 @@ def test_dqn(args: argparse.Namespace = get_args()) -> None: net = DQN(*args.state_shape, args.action_shape, args.device).to(args.device) optim = torch.optim.Adam(net.parameters(), lr=args.lr) # define policy - policy: DQNPolicy = DQNPolicy( + policy: DQNPolicy | ICMPolicy + policy = DQNPolicy( model=net, optim=optim, action_space=env.action_space, diff --git a/examples/atari/atari_ppo.py b/examples/atari/atari_ppo.py index 1b635a48e..cd94eca18 100644 --- a/examples/atari/atari_ppo.py +++ b/examples/atari/atari_ppo.py @@ -135,7 +135,8 @@ def test_ppo(args: argparse.Namespace = get_args()) -> None: def dist(logits: torch.Tensor) -> Distribution: return Categorical(logits=logits) - policy: PPOPolicy = PPOPolicy( + policy: PPOPolicy | ICMPolicy + policy = PPOPolicy( actor=actor, critic=critic, optim=optim, diff --git a/examples/atari/atari_sac.py b/examples/atari/atari_sac.py index f4dbe57c1..78d05e7be 100644 --- a/examples/atari/atari_sac.py +++ b/examples/atari/atari_sac.py @@ -124,7 +124,8 @@ def test_discrete_sac(args: argparse.Namespace = get_args()) -> None: alpha_optim = 
torch.optim.Adam([log_alpha], lr=args.alpha_lr) args.alpha = (target_entropy, log_alpha, alpha_optim) - policy: DiscreteSACPolicy = DiscreteSACPolicy( + policy: DiscreteSACPolicy | ICMPolicy + policy = DiscreteSACPolicy( actor=actor, actor_optim=actor_optim, critic=critic1, diff --git a/examples/box2d/bipedal_bdq.py b/examples/box2d/bipedal_bdq.py index 16dfaf097..f782785a5 100644 --- a/examples/box2d/bipedal_bdq.py +++ b/examples/box2d/bipedal_bdq.py @@ -98,7 +98,7 @@ def test_bdq(args: argparse.Namespace = get_args()) -> None: device=args.device, ).to(args.device) optim = torch.optim.Adam(net.parameters(), lr=args.lr) - policy = BranchingDQNPolicy( + policy: BranchingDQNPolicy = BranchingDQNPolicy( model=net, optim=optim, discount_factor=args.gamma, diff --git a/examples/mujoco/mujoco_ddpg.py b/examples/mujoco/mujoco_ddpg.py index ae1046cad..ec065e728 100755 --- a/examples/mujoco/mujoco_ddpg.py +++ b/examples/mujoco/mujoco_ddpg.py @@ -115,6 +115,7 @@ def test_ddpg(args: argparse.Namespace = get_args()) -> None: print("Loaded agent from: ", args.resume_path) # collector + buffer: VectorReplayBuffer | ReplayBuffer if args.training_num > 1: buffer = VectorReplayBuffer(args.buffer_size, len(train_envs)) else: diff --git a/examples/vizdoom/vizdoom_ppo.py b/examples/vizdoom/vizdoom_ppo.py index f698497b8..010bb28ec 100644 --- a/examples/vizdoom/vizdoom_ppo.py +++ b/examples/vizdoom/vizdoom_ppo.py @@ -140,7 +140,8 @@ def test_ppo(args: argparse.Namespace = get_args()) -> None: def dist(logits: torch.Tensor) -> Distribution: return Categorical(logits=logits) - policy: BasePolicy = PPOPolicy( + policy: PPOPolicy | ICMPolicy + policy = PPOPolicy( actor=actor, critic=critic, optim=optim, From ebd140f9eff2be2335d4319d3d96bb6cc92dff21 Mon Sep 17 00:00:00 2001 From: daniel <1534513+dantp-ai@users.noreply.github.com> Date: Wed, 6 Mar 2024 17:53:31 +0100 Subject: [PATCH 026/115] Fix mypy issues * Use ActionSpaceInfo to type action space attrs * Print collector stats using 
DataclassPPrintMixin * Add missing type annotations to funcs --- examples/mujoco/fetch_her_ddpg.py | 22 ++++++++++++++++------ 1 file changed, 16 insertions(+), 6 deletions(-) diff --git a/examples/mujoco/fetch_her_ddpg.py b/examples/mujoco/fetch_her_ddpg.py index 224e1fa98..d79d69bda 100644 --- a/examples/mujoco/fetch_her_ddpg.py +++ b/examples/mujoco/fetch_her_ddpg.py @@ -26,6 +26,8 @@ from tianshou.trainer import OffpolicyTrainer from tianshou.utils.net.common import Net, get_dict_state_decorator from tianshou.utils.net.continuous import Actor, Critic +from tianshou.env.venvs import BaseVectorEnv +from tianshou.utils.space_info import ActionSpaceInfo def get_args() -> argparse.Namespace: @@ -76,7 +78,11 @@ def get_args() -> argparse.Namespace: return parser.parse_args() -def make_fetch_env(task, training_num, test_num): +def make_fetch_env( + task: str, + training_num: int, + test_num: int, +) -> tuple[gym.Env, BaseVectorEnv, BaseVectorEnv]: env = TruncatedAsTerminated(gym.make(task)) train_envs = ShmemVectorEnv( [lambda: TruncatedAsTerminated(gym.make(task)) for _ in range(training_num)], @@ -110,17 +116,21 @@ def test_ddpg(args: argparse.Namespace = get_args()) -> None: ) env, train_envs, test_envs = make_fetch_env(args.task, args.training_num, args.test_num) + # The method HER works with goal-based environments + assert isinstance(env.observation_space, gym.spaces.Dict) args.state_shape = { "observation": env.observation_space["observation"].shape, "achieved_goal": env.observation_space["achieved_goal"].shape, "desired_goal": env.observation_space["desired_goal"].shape, } - args.action_shape = env.action_space.shape or env.action_space.n - args.max_action = env.action_space.high[0] + action_info = ActionSpaceInfo.from_space(env.action_space) + args.action_shape = action_info.action_shape + args.max_action = action_info.max_action + args.exploration_noise = args.exploration_noise * args.max_action print("Observations shape:", args.state_shape) print("Actions 
shape:", args.action_shape) - print("Action range:", np.min(env.action_space.low), np.max(env.action_space.high)) + print("Action range:", action_info.min_action, action_info.max_action) # seed np.random.seed(args.seed) torch.manual_seed(args.seed) @@ -168,7 +178,7 @@ def test_ddpg(args: argparse.Namespace = get_args()) -> None: print("Loaded agent from: ", args.resume_path) # collector - def compute_reward_fn(ag: np.ndarray, g: np.ndarray): + def compute_reward_fn(ag: np.ndarray, g: np.ndarray) -> np.ndarray: return env.compute_reward(ag, g, {}) buffer: VectorReplayBuffer | ReplayBuffer | HERReplayBuffer | HERVectorReplayBuffer @@ -223,7 +233,7 @@ def save_best_fn(policy: BasePolicy) -> None: test_envs.seed(args.seed) test_collector.reset() collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render) - print(collector_stats) + collector_stats.pprint_asdict() if __name__ == "__main__": From bce45a70be6d1b2165a90f8dbe53a64ba670a511 Mon Sep 17 00:00:00 2001 From: daniel <1534513+dantp-ai@users.noreply.github.com> Date: Wed, 6 Mar 2024 22:06:33 +0100 Subject: [PATCH 027/115] Check env.spec for none before accessing attrs --- examples/atari/atari_wrapper.py | 2 +- examples/vizdoom/env.py | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/examples/atari/atari_wrapper.py b/examples/atari/atari_wrapper.py index a2fdcca1f..8828d4b21 100644 --- a/examples/atari/atari_wrapper.py +++ b/examples/atari/atari_wrapper.py @@ -416,7 +416,7 @@ def __init__(self, task: str) -> None: def should_stop(self, mean_rewards: float, context: TrainingContext) -> bool: env = context.envs.env - if env.spec.reward_threshold: + if env.spec and env.spec.reward_threshold: return mean_rewards >= env.spec.reward_threshold if "Pong" in self.task: return mean_rewards >= 20 diff --git a/examples/vizdoom/env.py b/examples/vizdoom/env.py index 55348d902..04f4f0563 100644 --- a/examples/vizdoom/env.py +++ b/examples/vizdoom/env.py @@ -176,7 +176,8 @@ def 
make_vizdoom_env(task, frame_skip, res, save_lmp, seed, training_num, test_n print(env.available_actions) action_num = env.action_space.n obs = env.reset() - print(env.spec.reward_threshold) + if env.spec: + print(env.spec.reward_threshold) print(obs.shape, action_num) for _ in range(4000): obs, rew, terminated, truncated, info = env.step(0) From 1dcd319e3fbc13a7493f2581f634910299d7eac3 Mon Sep 17 00:00:00 2001 From: daniel <1534513+dantp-ai@users.noreply.github.com> Date: Wed, 6 Mar 2024 22:07:48 +0100 Subject: [PATCH 028/115] Add type annotation using space info --- examples/discrete/discrete_dqn.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/examples/discrete/discrete_dqn.py b/examples/discrete/discrete_dqn.py index 2e9697adb..c588814ae 100644 --- a/examples/discrete/discrete_dqn.py +++ b/examples/discrete/discrete_dqn.py @@ -1,8 +1,11 @@ +from typing import cast + import gymnasium as gym import torch from torch.utils.tensorboard import SummaryWriter import tianshou as ts +from tianshou.utils.space_info import SpaceInfo def main() -> None: @@ -26,8 +29,10 @@ def main() -> None: # Note: You can easily define other networks. 
# See https://tianshou.readthedocs.io/en/master/01_tutorials/00_dqn.html#build-the-network env = gym.make(task, render_mode="human") - state_shape = env.observation_space.shape or env.observation_space.n - action_shape = env.action_space.shape or env.action_space.n + env.action_space = cast(gym.spaces.Discrete, env.action_space) + space_info = SpaceInfo.from_env(env) + state_shape = space_info.observation_info.obs_shape + action_shape = space_info.action_info.action_shape net = Net(state_shape=state_shape, action_shape=action_shape, hidden_sizes=[128, 128, 128]) optim = torch.optim.Adam(net.parameters(), lr=lr) From 48af507b4f8d358b4eca2685886094b201f36114 Mon Sep 17 00:00:00 2001 From: daniel <1534513+dantp-ai@users.noreply.github.com> Date: Wed, 6 Mar 2024 22:53:33 +0100 Subject: [PATCH 029/115] Add type annotations to NoopResetEnv's methods and helper funcs --- examples/atari/atari_wrapper.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/examples/atari/atari_wrapper.py b/examples/atari/atari_wrapper.py index 8828d4b21..21fd2cbf8 100644 --- a/examples/atari/atari_wrapper.py +++ b/examples/atari/atari_wrapper.py @@ -3,6 +3,7 @@ import logging import warnings from collections import deque +from typing import Any import cv2 import gymnasium as gym @@ -26,7 +27,7 @@ log = logging.getLogger(__name__) -def _parse_reset_result(reset_result): +def _parse_reset_result(reset_result: tuple) -> tuple[tuple, dict, bool]: contains_info = ( isinstance(reset_result, tuple) and len(reset_result) == 2 @@ -46,13 +47,14 @@ class NoopResetEnv(gym.Wrapper): :param int noop_max: the maximum value of no-ops to run. 
""" - def __init__(self, env, noop_max=30) -> None: + def __init__(self, env: gym.Env, noop_max: int = 30) -> None: super().__init__(env) self.noop_max = noop_max self.noop_action = 0 + assert hasattr(env.unwrapped, "get_action_meanings") assert env.unwrapped.get_action_meanings()[0] == "NOOP" - def reset(self, **kwargs): + def reset(self, **kwargs: Any) -> tuple[Any, dict[str, Any]]: _, info, return_info = _parse_reset_result(self.env.reset(**kwargs)) if hasattr(self.unwrapped.np_random, "integers"): noops = self.unwrapped.np_random.integers(1, self.noop_max + 1) @@ -69,7 +71,7 @@ def reset(self, **kwargs): obs, info, _ = _parse_reset_result(self.env.reset()) if return_info: return obs, info - return obs + return obs, {} class MaxAndSkipEnv(gym.Wrapper): From 2ec7c67694ed657d3f712fd95acc3b88c597de4f Mon Sep 17 00:00:00 2001 From: daniel <1534513+dantp-ai@users.noreply.github.com> Date: Wed, 6 Mar 2024 22:54:25 +0100 Subject: [PATCH 030/115] Use only integers with Generator * Generator does not have randint --- examples/atari/atari_wrapper.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/examples/atari/atari_wrapper.py b/examples/atari/atari_wrapper.py index 21fd2cbf8..5d6178da5 100644 --- a/examples/atari/atari_wrapper.py +++ b/examples/atari/atari_wrapper.py @@ -56,10 +56,7 @@ def __init__(self, env: gym.Env, noop_max: int = 30) -> None: def reset(self, **kwargs: Any) -> tuple[Any, dict[str, Any]]: _, info, return_info = _parse_reset_result(self.env.reset(**kwargs)) - if hasattr(self.unwrapped.np_random, "integers"): - noops = self.unwrapped.np_random.integers(1, self.noop_max + 1) - else: - noops = self.unwrapped.np_random.randint(1, self.noop_max + 1) + noops = self.unwrapped.np_random.integers(1, self.noop_max + 1) for _ in range(noops): step_result = self.env.step(self.noop_action) if len(step_result) == 4: From e1a85faca419bb74060a19aaf8088b3b92ecb41e Mon Sep 17 00:00:00 2001 From: daniel <1534513+dantp-ai@users.noreply.github.com> 
Date: Thu, 7 Mar 2024 16:11:56 +0100 Subject: [PATCH 031/115] Use DataclassPPrintMixin to print collect stats --- examples/offline/atari_bcq.py | 4 +--- examples/offline/atari_cql.py | 4 +--- examples/offline/atari_crr.py | 4 +--- 3 files changed, 3 insertions(+), 9 deletions(-) diff --git a/examples/offline/atari_bcq.py b/examples/offline/atari_bcq.py index 8a2f41517..1f29aa5d1 100644 --- a/examples/offline/atari_bcq.py +++ b/examples/offline/atari_bcq.py @@ -186,9 +186,7 @@ def watch() -> None: print("Testing agent ...") test_collector.reset() result = test_collector.collect(n_episode=args.test_num, render=args.render) - pprint.pprint(result) - rew = result.returns_stat.mean - print(f"Mean reward (over {result.n_collected_episodes} episodes): {rew}") + result.pprint_asdict() if args.watch: watch() diff --git a/examples/offline/atari_cql.py b/examples/offline/atari_cql.py index 5d20342bf..07acbe4c2 100644 --- a/examples/offline/atari_cql.py +++ b/examples/offline/atari_cql.py @@ -162,9 +162,7 @@ def watch() -> None: print("Testing agent ...") test_collector.reset() result = test_collector.collect(n_episode=args.test_num, render=args.render) - pprint.pprint(result) - rew = result.returns_stat.mean - print(f"Mean reward (over {result.n_collected_episodes} episodes): {rew}") + result.pprint_asdict() if args.watch: watch() diff --git a/examples/offline/atari_crr.py b/examples/offline/atari_crr.py index 8fe2642fb..8ec57a4bd 100644 --- a/examples/offline/atari_crr.py +++ b/examples/offline/atari_crr.py @@ -184,9 +184,7 @@ def watch() -> None: print("Testing agent ...") test_collector.reset() result = test_collector.collect(n_episode=args.test_num, render=args.render) - pprint.pprint(result) - rew = result.returns_stat.mean - print(f"Mean reward (over {result.n_collected_episodes} episodes): {rew}") + result.pprint_asdict() if args.watch: watch() From 8d9e1680f8deac17506a435cd36c9b155ffcd15e Mon Sep 17 00:00:00 2001 From: daniel 
<1534513+dantp-ai@users.noreply.github.com> Date: Thu, 7 Mar 2024 16:13:23 +0100 Subject: [PATCH 032/115] Respect mypy typing for vars/args --- examples/atari/atari_network.py | 2 +- examples/atari/atari_ppo_hl.py | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/examples/atari/atari_network.py b/examples/atari/atari_network.py index 2b8288a7c..d52fbc7df 100644 --- a/examples/atari/atari_network.py +++ b/examples/atari/atari_network.py @@ -221,7 +221,7 @@ def __init__( num_quantiles: int = 200, device: str | int | torch.device = "cpu", ) -> None: - self.action_num = np.prod(action_shape) + self.action_num = int(np.prod(action_shape)) super().__init__(c, h, w, [self.action_num * num_quantiles], device) self.num_quantiles = num_quantiles diff --git a/examples/atari/atari_ppo_hl.py b/examples/atari/atari_ppo_hl.py index b492b9c84..03272aa5b 100644 --- a/examples/atari/atari_ppo_hl.py +++ b/examples/atari/atari_ppo_hl.py @@ -98,10 +98,11 @@ def main( .with_epoch_stop_callback(AtariEpochStopCallback(task)) ) if icm_lr_scale > 0: + hidden_sizes = [hidden_sizes] if isinstance(hidden_sizes, int) else hidden_sizes builder.with_policy_wrapper_factory( PolicyWrapperFactoryIntrinsicCuriosity( feature_net_factory=IntermediateModuleFactoryAtariDQNFeatures(), - hidden_sizes=[hidden_sizes], + hidden_sizes=hidden_sizes, lr=lr, lr_scale=icm_lr_scale, reward_scale=icm_reward_scale, From 6f31ac12e935820e0d878d4673ef30aaef41daff Mon Sep 17 00:00:00 2001 From: daniel <1534513+dantp-ai@users.noreply.github.com> Date: Thu, 7 Mar 2024 16:14:14 +0100 Subject: [PATCH 033/115] Fix many mypy issues * type annotations funcs * type some vars that mypy asks to * Make mypy happy and check for existing attr of ALE-based envs * Respect super class methods types --- examples/atari/atari_wrapper.py | 107 +++++++++++++++++++------------- 1 file changed, 63 insertions(+), 44 deletions(-) diff --git a/examples/atari/atari_wrapper.py b/examples/atari/atari_wrapper.py index 
5d6178da5..2a9c0cd27 100644 --- a/examples/atari/atari_wrapper.py +++ b/examples/atari/atari_wrapper.py @@ -3,13 +3,15 @@ import logging import warnings from collections import deque -from typing import Any +from typing import Any, SupportsFloat import cv2 import gymnasium as gym import numpy as np +import numpy.typing as npt from gymnasium import Env +from tianshou.env import BaseVectorEnv from tianshou.highlevel.env import ( EnvFactoryRegistered, EnvMode, @@ -78,16 +80,17 @@ class MaxAndSkipEnv(gym.Wrapper): :param int skip: number of `skip`-th frame. """ - def __init__(self, env, skip=4) -> None: + def __init__(self, env: gym.Env, skip: int = 4) -> None: super().__init__(env) self._skip = skip - def step(self, action): + def step(self, action: Any) -> tuple[Any, float, bool, bool, dict[str, Any]]: """Step the environment with the given action. Repeat action, sum reward, and max over last observations. """ - obs_list, total_reward = [], 0.0 + obs_list = [] + total_reward = 0.0 new_step_api = False for _ in range(self._skip): step_result = self.env.step(action) @@ -98,7 +101,7 @@ def step(self, action): done = term or trunc new_step_api = True obs_list.append(obs) - total_reward += reward + total_reward += float(reward) if done: break max_frame = np.max(obs_list[-2:], axis=0) @@ -116,13 +119,13 @@ class EpisodicLifeEnv(gym.Wrapper): :param gym.Env env: the environment to wrap. 
""" - def __init__(self, env) -> None: + def __init__(self, env: gym.Env) -> None: super().__init__(env) self.lives = 0 self.was_real_done = True self._return_info = False - def step(self, action): + def step(self, action: Any) -> tuple[Any, float, bool, bool, dict[str, Any]]: step_result = self.env.step(action) if len(step_result) == 4: obs, reward, done, info = step_result @@ -135,6 +138,7 @@ def step(self, action): self.was_real_done = done # check current lives, make loss of life terminal, then update lives to # handle bonus lives + assert hasattr(self.env.unwrapped, "ale") lives = self.env.unwrapped.ale.lives() if 0 < lives < self.lives: # for Qbert sometimes we stay in lives == 0 condition for a few @@ -147,7 +151,7 @@ def step(self, action): return obs, reward, term, trunc, info return obs, reward, done, info - def reset(self, **kwargs): + def reset(self, **kwargs: Any) -> tuple[Any, dict[str, Any]]: """Calls the Gym environment reset, only when lives are exhausted. This way all states are still reachable even though lives are episodic, and @@ -159,10 +163,11 @@ def reset(self, **kwargs): # no-op step to advance from terminal/lost life state step_result = self.env.step(0) obs, info = step_result[0], step_result[-1] + assert hasattr(self.env.unwrapped, "ale") self.lives = self.env.unwrapped.ale.lives() if self._return_info: return obs, info - return obs + return obs, {} class FireResetEnv(gym.Wrapper): @@ -173,15 +178,16 @@ class FireResetEnv(gym.Wrapper): :param gym.Env env: the environment to wrap. 
""" - def __init__(self, env) -> None: + def __init__(self, env: gym.Env) -> None: super().__init__(env) + assert hasattr(env.unwrapped, "get_action_meanings") assert env.unwrapped.get_action_meanings()[1] == "FIRE" assert len(env.unwrapped.get_action_meanings()) >= 3 - def reset(self, **kwargs): + def reset(self, **kwargs: Any) -> tuple[Any, dict]: _, _, return_info = _parse_reset_result(self.env.reset(**kwargs)) obs = self.env.step(1)[0] - return (obs, {}) if return_info else obs + return obs, {} class WarpFrame(gym.ObservationWrapper): @@ -190,7 +196,7 @@ class WarpFrame(gym.ObservationWrapper): :param gym.Env env: the environment to wrap. """ - def __init__(self, env) -> None: + def __init__(self, env: gym.Env) -> None: super().__init__(env) self.size = 84 self.observation_space = gym.spaces.Box( @@ -200,7 +206,7 @@ def __init__(self, env) -> None: dtype=env.observation_space.dtype, ) - def observation(self, frame): + def observation(self, frame: np.ndarray) -> np.ndarray: """Returns the current observation from a frame.""" frame = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY) return cv2.resize(frame, (self.size, self.size), interpolation=cv2.INTER_AREA) @@ -212,7 +218,7 @@ class ScaledFloatFrame(gym.ObservationWrapper): :param gym.Env env: the environment to wrap. """ - def __init__(self, env) -> None: + def __init__(self, env: gym.Env) -> None: super().__init__(env) low = np.min(env.observation_space.low) high = np.max(env.observation_space.high) @@ -225,7 +231,7 @@ def __init__(self, env) -> None: dtype=np.float32, ) - def observation(self, observation): + def observation(self, observation: np.ndarray) -> np.ndarray: return (observation - self.bias) / self.scale @@ -235,13 +241,13 @@ class ClipRewardEnv(gym.RewardWrapper): :param gym.Env env: the environment to wrap. 
""" - def __init__(self, env) -> None: + def __init__(self, env: gym.Env) -> None: super().__init__(env) self.reward_range = (-1, 1) - def reward(self, reward): + def reward(self, reward: SupportsFloat) -> int: """Bin reward to {+1, 0, -1} by its sign. Note: np.sign(0) == 0.""" - return np.sign(reward) + return np.sign(float(reward)) class FrameStack(gym.Wrapper): @@ -251,10 +257,10 @@ class FrameStack(gym.Wrapper): :param int n_frames: the number of frames to stack. """ - def __init__(self, env, n_frames) -> None: + def __init__(self, env: gym.Env, n_frames: int) -> None: super().__init__(env) - self.n_frames = n_frames - self.frames = deque([], maxlen=n_frames) + self.n_frames: int = n_frames + self.frames: deque[tuple[Any, ...]] = deque([], maxlen=n_frames) shape = (n_frames, *env.observation_space.shape) self.observation_space = gym.spaces.Box( low=np.min(env.observation_space.low), @@ -263,11 +269,11 @@ def __init__(self, env, n_frames) -> None: dtype=env.observation_space.dtype, ) - def reset(self, **kwargs): + def reset(self, **kwargs: Any) -> tuple[npt.NDArray, dict]: obs, info, return_info = _parse_reset_result(self.env.reset(**kwargs)) for _ in range(self.n_frames): self.frames.append(obs) - return (self._get_ob(), info) if return_info else self._get_ob() + return (self._get_ob(), info) if return_info else (self._get_ob(), {}) def step(self, action): step_result = self.env.step(action) @@ -282,19 +288,27 @@ def step(self, action): return self._get_ob(), reward, term, trunc, info return self._get_ob(), reward, done, info - def _get_ob(self): + def _get_ob(self) -> npt.NDArray: # the original wrapper use `LazyFrames` but since we use np buffer, # it has no effect return np.stack(self.frames, axis=0) def wrap_deepmind( - env: Env, - episode_life=True, - clip_rewards=True, - frame_stack=4, - scale=False, - warp_frame=True, + env: gym.Env, + episode_life: bool = True, + clip_rewards: bool = True, + frame_stack: int = 4, + scale: bool = False, + warp_frame: 
bool = True, +) -> ( + MaxAndSkipEnv + | EpisodicLifeEnv + | FireResetEnv + | WarpFrame + | ScaledFloatFrame + | ClipRewardEnv + | FrameStack ): """Configure environment for DeepMind-style Atari. @@ -310,29 +324,34 @@ def wrap_deepmind( """ env = NoopResetEnv(env, noop_max=30) env = MaxAndSkipEnv(env, skip=4) + assert hasattr(env.unwrapped, "get_action_meanings") # for mypy + + wrapped_env: MaxAndSkipEnv | EpisodicLifeEnv | FireResetEnv | WarpFrame | ScaledFloatFrame | ClipRewardEnv | FrameStack = ( + env + ) if episode_life: - env = EpisodicLifeEnv(env) + wrapped_env = EpisodicLifeEnv(wrapped_env) if "FIRE" in env.unwrapped.get_action_meanings(): - env = FireResetEnv(env) + wrapped_env = FireResetEnv(wrapped_env) if warp_frame: - env = WarpFrame(env) + wrapped_env = WarpFrame(wrapped_env) if scale: - env = ScaledFloatFrame(env) + wrapped_env = ScaledFloatFrame(wrapped_env) if clip_rewards: - env = ClipRewardEnv(env) + wrapped_env = ClipRewardEnv(wrapped_env) if frame_stack: - env = FrameStack(env, frame_stack) - return env + wrapped_env = FrameStack(wrapped_env, frame_stack) + return wrapped_env def make_atari_env( - task, - seed, - training_num, - test_num, + task: str, + seed: int, + training_num: int, + test_num: int, scale: int | bool = False, frame_stack: int = 4, -): +) -> tuple[Env, BaseVectorEnv, BaseVectorEnv]: """Wrapper function for Atari env. If EnvPool is installed, it will automatically switch to EnvPool's Atari env. 
@@ -370,7 +389,7 @@ def __init__( envpool_factory=envpool_factory, ) - def create_env(self, mode: EnvMode) -> Env: + def create_env(self, mode: EnvMode) -> gym.Env: env = super().create_env(mode) is_train = mode == EnvMode.TRAIN return wrap_deepmind( From 21e380577187c2b8d33ec4622bbd5f32e86ad219 Mon Sep 17 00:00:00 2001 From: daniel <1534513+dantp-ai@users.noreply.github.com> Date: Fri, 8 Mar 2024 08:49:25 +0100 Subject: [PATCH 034/115] Rename var to resolve ambiguity for mypy --- test/modelbased/test_psrl.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/modelbased/test_psrl.py b/test/modelbased/test_psrl.py index 05691645f..2be879b7a 100644 --- a/test/modelbased/test_psrl.py +++ b/test/modelbased/test_psrl.py @@ -121,8 +121,8 @@ def stop_fn(mean_rewards: float) -> bool: policy.eval() test_envs.seed(args.seed) test_collector.reset() - result = test_collector.collect(n_episode=args.test_num, render=args.render) - result.pprint_asdict() + stats = test_collector.collect(n_episode=args.test_num, render=args.render) + stats.pprint_asdict() elif env.spec.reward_threshold: assert result.best_reward >= env.spec.reward_threshold From 010395d1b044ef490d08eff55b74abd46bbe7e2a Mon Sep 17 00:00:00 2001 From: daniel <1534513+dantp-ai@users.noreply.github.com> Date: Fri, 8 Mar 2024 09:37:56 +0100 Subject: [PATCH 035/115] Fix mypy issues (see below) * Add missing type annotations to func signatures * Type ambiguous vars used in if-else conditional blocks to resolve mypy confusion --- test/base/test_policy.py | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/test/base/test_policy.py b/test/base/test_policy.py index 9fe6f8c3a..6429bf7e5 100644 --- a/test/base/test_policy.py +++ b/test/base/test_policy.py @@ -1,10 +1,12 @@ +from collections.abc import Callable + import gymnasium as gym import numpy as np import pytest import torch from torch.distributions import Categorical, Independent, Normal -from tianshou.policy 
import PPOPolicy +from tianshou.policy import BasePolicy, PPOPolicy from tianshou.utils.net.common import ActorCritic, Net from tianshou.utils.net.continuous import ActorProb, Critic from tianshou.utils.net.discrete import Actor @@ -12,13 +14,16 @@ obs_shape = (5,) -def _to_hashable(x: np.ndarray | int): +def _to_hashable(x: np.ndarray | int) -> int | tuple[list]: return x if isinstance(x, int) else tuple(x.tolist()) @pytest.fixture(params=["continuous", "discrete"]) -def policy(request): +def policy(request: pytest.FixtureRequest) -> PPOPolicy: action_type = request.param + action_space: gym.spaces.Box | gym.spaces.Discrete + actor: Actor | ActorProb + dist_fn: Callable[[torch.Tensor], torch.distributions.Distribution] if action_type == "continuous": action_space = gym.spaces.Box(low=-1, high=1, shape=(3,)) actor = ActorProb( @@ -43,7 +48,8 @@ def policy(request): actor_critic = ActorCritic(actor, critic) optim = torch.optim.Adam(actor_critic.parameters(), lr=1e-3) - policy: PPOPolicy = PPOPolicy( + policy: BasePolicy + policy = PPOPolicy( actor=actor, critic=critic, dist_fn=dist_fn, @@ -56,7 +62,7 @@ def policy(request): class TestPolicyBasics: - def test_get_action(self, policy) -> None: + def test_get_action(self, policy: PPOPolicy) -> None: sample_obs = torch.randn(obs_shape) policy.deterministic_eval = False actions = [policy.compute_action(sample_obs) for _ in range(10)] From ed70e82c211bd7e1cb3863199f5eafb50e64c231 Mon Sep 17 00:00:00 2001 From: daniel <1534513+dantp-ai@users.noreply.github.com> Date: Fri, 8 Mar 2024 11:14:30 +0100 Subject: [PATCH 036/115] Add type annotations to DummyDataset and FiniteEnv --- test/base/test_env_finite.py | 23 +++++++++++++++-------- 1 file changed, 15 insertions(+), 8 deletions(-) diff --git a/test/base/test_env_finite.py b/test/base/test_env_finite.py index d1e780251..9a7da5b9c 100644 --- a/test/base/test_env_finite.py +++ b/test/base/test_env_finite.py @@ -2,6 +2,8 @@ import copy from collections import Counter +from 
collections.abc import Iterator +from typing import Any import gymnasium as gym import numpy as np @@ -14,20 +16,20 @@ class DummyDataset(Dataset): - def __init__(self, length) -> None: + def __init__(self, length: int) -> None: self.length = length self.episodes = [3 * i % 5 + 1 for i in range(self.length)] - def __getitem__(self, index): + def __getitem__(self, index: int) -> tuple[int, int]: assert 0 <= index < self.length return index, self.episodes[index] - def __len__(self): + def __len__(self) -> int: return self.length class FiniteEnv(gym.Env): - def __init__(self, dataset, num_replicas, rank) -> None: + def __init__(self, dataset: Dataset, num_replicas: int | None, rank: int | None) -> None: self.dataset = dataset self.num_replicas = num_replicas self.rank = rank @@ -36,9 +38,14 @@ def __init__(self, dataset, num_replicas, rank) -> None: sampler=DistributedSampler(dataset, num_replicas, rank), batch_size=None, ) - self.iterator = None - - def reset(self): + self.iterator: Iterator | None = None + + def reset( + self, + *, + seed: int | None = None, + options: dict[str, Any] | None = None, + ) -> tuple[Any, dict[str, Any]]: if self.iterator is None: self.iterator = iter(self.loader) try: @@ -49,7 +56,7 @@ def reset(self): self.iterator = None return None, {} - def step(self, action): + def step(self, action: int) -> tuple[int, float, bool, bool, dict[str, Any]]: self.current_step += 1 assert self.current_step <= self.step_count return ( From 678c1c3b78593d96a855bcc0af4df58cfb9ae6e9 Mon Sep 17 00:00:00 2001 From: daniel <1534513+dantp-ai@users.noreply.github.com> Date: Fri, 15 Mar 2024 00:40:30 +0100 Subject: [PATCH 037/115] Fix mypy issues * Add missing type annotations to funcs * Add typing to some vars for clarify for mypy --- test/base/test_env_finite.py | 88 ++++++++++++++++++++++++------------ 1 file changed, 60 insertions(+), 28 deletions(-) diff --git a/test/base/test_env_finite.py b/test/base/test_env_finite.py index 9a7da5b9c..bb0cd5318 100644 --- 
a/test/base/test_env_finite.py +++ b/test/base/test_env_finite.py @@ -2,16 +2,20 @@ import copy from collections import Counter -from collections.abc import Iterator +from collections.abc import Callable, Iterator, Sequence from typing import Any import gymnasium as gym import numpy as np +import numpy.typing as npt +import torch from gymnasium.spaces import Box from torch.utils.data import DataLoader, Dataset, DistributedSampler from tianshou.data import Batch, Collector +from tianshou.data.types import BatchProtocol, ObsBatchProtocol, RolloutBatchProtocol from tianshou.env import BaseVectorEnv, DummyVectorEnv, SubprocVectorEnv +from tianshou.env.utils import ENV_TYPE, gym_new_venv_step_type from tianshou.policy import BasePolicy @@ -69,42 +73,48 @@ def step(self, action: int) -> tuple[int, float, bool, bool, dict[str, Any]]: class FiniteVectorEnv(BaseVectorEnv): - def __init__(self, env_fns, **kwargs) -> None: + def __init__(self, env_fns: Sequence[Callable[[], ENV_TYPE]], **kwargs: Any) -> None: super().__init__(env_fns, **kwargs) - self._alive_env_ids = set() + self._alive_env_ids: set[int] = set() self._reset_alive_envs() - self._default_obs = self._default_info = None + self._default_obs: np.ndarray | None = None + self._default_info: dict | None = None + self.tracker: MetricTracker - def _reset_alive_envs(self): + def _reset_alive_envs(self) -> None: if not self._alive_env_ids: # starting or running out self._alive_env_ids = set(range(self.env_num)) # to workaround with tianshou's buffer and batch - def _set_default_obs(self, obs): + def _set_default_obs(self, obs: np.ndarray) -> None: if obs is not None and self._default_obs is None: self._default_obs = copy.deepcopy(obs) - def _set_default_info(self, info): + def _set_default_info(self, info: dict) -> None: if info is not None and self._default_info is None: self._default_info = copy.deepcopy(info) - def _get_default_obs(self): + def _get_default_obs(self) -> np.ndarray | None: return 
copy.deepcopy(self._default_obs) - def _get_default_info(self): + def _get_default_info(self) -> dict | None: return copy.deepcopy(self._default_info) # END - def reset(self, id=None): + def reset( + self, + id: int | list[int] | np.ndarray | None = None, + **kwargs: Any, + ) -> tuple[np.ndarray, dict | list[dict | None]]: id = self._wrap_id(id) self._reset_alive_envs() # ask super to reset alive envs and remap to current index request_id = list(filter(lambda i: i in self._alive_env_ids, id)) - obs = [None] * len(id) - infos = [None] * len(id) + obs: list[npt.ArrayLike | None] = [None] * len(id) + infos: list[dict | None] = [None] * len(id) id2idx = {i: k for k, i in enumerate(id)} if request_id: for k, o, info in zip(request_id, *super().reset(request_id), strict=True): @@ -128,26 +138,32 @@ def reset(self, id=None): self.reset() raise StopIteration + obs = [o for o in obs if o is not None] + return np.stack(obs), infos - def step(self, action, id=None): - id = self._wrap_id(id) - id2idx = {i: k for k, i in enumerate(id)} - request_id = list(filter(lambda i: i in self._alive_env_ids, id)) - result = [[None, 0.0, False, False, None] for _ in range(len(id))] + def step( + self, + action: np.ndarray | torch.Tensor, + id: int | list[int] | np.ndarray | None = None, + ) -> gym_new_venv_step_type: + ids: list[int] | np.ndarray = self._wrap_id(id) + id2idx = {i: k for k, i in enumerate(ids)} + request_id = list(filter(lambda i: i in self._alive_env_ids, ids)) + result: list[tuple] = [(None, 0.0, False, False, None) for _ in range(len(ids))] # ask super to step alive envs and remap to current index if request_id: valid_act = np.stack([action[id2idx[i]] for i in request_id]) - for i, r in zip( + for i, (r_obs, r_reward, r_term, r_trunc, r_info) in zip( request_id, zip(*super().step(valid_act, request_id), strict=True), strict=True, ): - result[id2idx[i]] = r + result[id2idx[i]] = (r_obs, r_reward, r_term, r_trunc, r_info) # logging - for i, r in zip(id, result, 
strict=True): + for i, r in zip(ids, result, strict=True): if i in self._alive_env_ids: self.tracker.log(*r) @@ -160,7 +176,18 @@ def step(self, action, id=None): if result[i][-1] is None: result[i][-1] = self._get_default_info() - return list(map(np.stack, zip(*result, strict=True))) + obs_list, rew_list, term_list, trunc_list, info_list = zip(*result, strict=True) + try: + obs_stack = np.stack(obs_list) + except ValueError: # different len(obs) + obs_stack = np.array(obs_list, dtype=object) + return ( + obs_stack, + np.stack(rew_list), + np.stack(term_list), + np.stack(trunc_list), + np.stack(info_list), + ) class FiniteDummyVectorEnv(FiniteVectorEnv, DummyVectorEnv): @@ -175,23 +202,28 @@ class AnyPolicy(BasePolicy): def __init__(self) -> None: super().__init__(action_space=Box(-1, 1, (1,))) - def forward(self, batch, state=None): + def forward( + self, + batch: ObsBatchProtocol, + state: dict | BatchProtocol | np.ndarray | None = None, + **kwargs: Any, + ) -> Batch: return Batch(act=np.stack([1] * len(batch))) - def learn(self, batch): + def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> None: pass -def _finite_env_factory(dataset, num_replicas, rank): +def _finite_env_factory(dataset: Dataset, num_replicas: int, rank: int) -> Callable[[], FiniteEnv]: return lambda: FiniteEnv(dataset, num_replicas, rank) class MetricTracker: def __init__(self) -> None: - self.counter = Counter() - self.finished = set() + self.counter: Counter = Counter() + self.finished: set[int] = set() - def log(self, obs, rew, terminated, truncated, info): + def log(self, obs: Any, rew: float, terminated: bool, truncated: bool, info: dict) -> None: assert rew == 1.0 done = terminated or truncated index = info["sample"] @@ -200,7 +232,7 @@ def log(self, obs, rew, terminated, truncated, info): self.finished.add(index) self.counter[index] += 1 - def validate(self): + def validate(self) -> None: assert len(self.finished) == 100 for k, v in self.counter.items(): assert v 
== k * 3 % 5 + 1 From 70e6dc1e7c6aec0a1a86b71588dba07d2e780ca1 Mon Sep 17 00:00:00 2001 From: daniel <1534513+dantp-ai@users.noreply.github.com> Date: Fri, 15 Mar 2024 00:53:36 +0100 Subject: [PATCH 038/115] Fix some mypy issues * Rename vars to prevent confusion for mypy * Add missing function type annotations * Type some vars correctly --- test/base/test_env.py | 6 +++--- test/discrete/test_a2c_with_il.py | 9 +++++---- test/pettingzoo/pistonball.py | 8 ++++---- 3 files changed, 12 insertions(+), 11 deletions(-) diff --git a/test/base/test_env.py b/test/base/test_env.py index edeb3f361..6b6051420 100644 --- a/test/base/test_env.py +++ b/test/base/test_env.py @@ -30,7 +30,7 @@ envpool = None -def has_ray(): +def has_ray() -> bool: try: import ray # noqa: F401 @@ -39,7 +39,7 @@ def has_ray(): return False -def recurse_comp(a, b): +def recurse_comp(a, b) -> np.bool_ | bool | None: try: if isinstance(a, np.ndarray): if a.dtype == object: @@ -232,7 +232,7 @@ def test_env_obs_dtype() -> None: envs = SubprocVectorEnv([lambda i=x, t=obs_type: NXEnv(i, t) for x in [5, 10, 15, 20]]) obs, info = envs.reset() assert obs.dtype == object - obs = envs.step([1, 1, 1, 1])[0] + obs = envs.step(np.array([1, 1, 1, 1]))[0] assert obs.dtype == object diff --git a/test/discrete/test_a2c_with_il.py b/test/discrete/test_a2c_with_il.py index 3ca7ce6cf..3ad2f51ae 100644 --- a/test/discrete/test_a2c_with_il.py +++ b/test/discrete/test_a2c_with_il.py @@ -89,7 +89,8 @@ def test_a2c_with_il(args: argparse.Namespace = get_args()) -> None: critic = Critic(net, device=args.device).to(args.device) optim = torch.optim.Adam(ActorCritic(actor, critic).parameters(), lr=args.lr) dist = torch.distributions.Categorical - policy: A2CPolicy = A2CPolicy( + policy: BasePolicy + policy = A2CPolicy( actor=actor, critic=critic, optim=optim, @@ -152,10 +153,10 @@ def stop_fn(mean_rewards: float) -> bool: # if args.task == 'CartPole-v0': # env.spec.reward_threshold = 190 # lower the goal net = 
Net(args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device) - net = Actor(net, args.action_shape, device=args.device).to(args.device) - optim = torch.optim.Adam(net.parameters(), lr=args.il_lr) + actor = Actor(net, args.action_shape, device=args.device).to(args.device) + optim = torch.optim.Adam(actor.parameters(), lr=args.il_lr) il_policy: ImitationPolicy = ImitationPolicy( - actor=net, + actor=actor, optim=optim, action_space=env.action_space, ) diff --git a/test/pettingzoo/pistonball.py b/test/pettingzoo/pistonball.py index 0dd750b4c..1e9ca2b5c 100644 --- a/test/pettingzoo/pistonball.py +++ b/test/pettingzoo/pistonball.py @@ -8,7 +8,7 @@ from pettingzoo.butterfly import pistonball_v6 from torch.utils.tensorboard import SummaryWriter -from tianshou.data import Collector, VectorReplayBuffer +from tianshou.data import Collector, InfoStats, VectorReplayBuffer from tianshou.env import DummyVectorEnv from tianshou.env.pettingzoo_env import PettingZooEnv from tianshou.policy import BasePolicy, DQNPolicy, MultiAgentPolicyManager @@ -68,7 +68,7 @@ def get_args() -> argparse.Namespace: return parser.parse_known_args()[0] -def get_env(args: argparse.Namespace = get_args()): +def get_env(args: argparse.Namespace = get_args()) -> PettingZooEnv: return PettingZooEnv(pistonball_v6.env(continuous=False, n_pistons=args.n_pistons)) @@ -116,7 +116,7 @@ def train_agent( args: argparse.Namespace = get_args(), agents: list[BasePolicy] | None = None, optims: list[torch.optim.Optimizer] | None = None, -) -> tuple[dict, BasePolicy]: +) -> tuple[InfoStats, BasePolicy]: train_envs = DummyVectorEnv([get_env for _ in range(args.training_num)]) test_envs = DummyVectorEnv([get_env for _ in range(args.test_num)]) # seed @@ -154,7 +154,7 @@ def train_fn(epoch: int, env_step: int) -> None: def test_fn(epoch: int, env_step: int | None) -> None: [agent.set_eps(args.eps_test) for agent in policy.policies.values()] - def reward_metric(rews): + def reward_metric(rews: np.ndarray) -> 
np.ndarray: return rews[:, 0] # trainer From 2db9b20bc1d08e00a430882b714db775fa36f112 Mon Sep 17 00:00:00 2001 From: daniel <1534513+dantp-ai@users.noreply.github.com> Date: Fri, 15 Mar 2024 16:58:17 +0100 Subject: [PATCH 039/115] Fix many mypy issues related to: * Function type annotations * Typing some vars to avoid confusion * Rewrite some assignments that respect input argument types --- test/base/test_collector.py | 47 +++++++++++++++++++++++++++---------- 1 file changed, 34 insertions(+), 13 deletions(-) diff --git a/test/base/test_collector.py b/test/base/test_collector.py index f7a24a86e..8a1d34a08 100644 --- a/test/base/test_collector.py +++ b/test/base/test_collector.py @@ -1,3 +1,6 @@ +from collections.abc import Callable, Sequence +from typing import Any + import gymnasium as gym import numpy as np import pytest @@ -13,8 +16,11 @@ ReplayBuffer, VectorReplayBuffer, ) +from tianshou.data.batch import BatchProtocol +from tianshou.data.types import ObsBatchProtocol, RolloutBatchProtocol from tianshou.env import DummyVectorEnv, SubprocVectorEnv from tianshou.policy import BasePolicy +from tianshou.policy.base import TrainingStats try: import envpool @@ -31,9 +37,9 @@ class MyPolicy(BasePolicy): def __init__( self, action_space: gym.spaces.Space | None = None, - dict_state=False, - need_state=True, - action_shape=None, + dict_state: bool = False, + need_state: bool = True, + action_shape: Sequence[int] | int | None = None, ) -> None: """Mock policy for testing. 
@@ -47,28 +53,38 @@ def __init__( self.need_state = need_state self.action_shape = action_shape - def forward(self, batch, state=None): + def forward( + self, + batch: ObsBatchProtocol, + state: dict | BatchProtocol | np.ndarray | None = None, + **kwargs: Any, + ) -> Batch: if self.need_state: if state is None: state = np.zeros((len(batch.obs), 2)) else: state += 1 if self.dict_state: - action_shape = self.action_shape if self.action_shape else len(batch.obs["index"]) + if self.action_shape: + action_shape = self.action_shape + elif isinstance(batch.obs, BatchProtocol): + action_shape = len(batch.obs["index"]) + else: + action_shape = len(batch.obs) return Batch(act=np.ones(action_shape), state=state) action_shape = self.action_shape if self.action_shape else len(batch.obs) return Batch(act=np.ones(action_shape), state=state) - def learn(self): - pass + def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> TrainingStats: + raise NotImplementedError class Logger: - def __init__(self, writer) -> None: + def __init__(self, writer: SummaryWriter) -> None: self.cnt = 0 self.writer = writer - def preprocess_fn(self, **kwargs): + def preprocess_fn(self, **kwargs: Any) -> Batch: # modify info before adding into the buffer, and recorded into tfb # if obs && env_id exist -> reset # if obs_next/rew/done/info/env_id exist -> normal step @@ -82,7 +98,7 @@ def preprocess_fn(self, **kwargs): return Batch() @staticmethod - def single_preprocess_fn(**kwargs): + def single_preprocess_fn(**kwargs: Any) -> Batch: # same as above, without tfb if "rew" in kwargs: info = kwargs["info"] @@ -92,7 +108,7 @@ def single_preprocess_fn(**kwargs): @pytest.mark.parametrize("gym_reset_kwargs", [None, {}]) -def test_collector(gym_reset_kwargs) -> None: +def test_collector(gym_reset_kwargs: None | dict) -> None: writer = SummaryWriter("log/collector") logger = Logger(writer) env_fns = [lambda x=i: MyTestEnv(size=x, sleep=0) for i in [2, 3, 4, 5]] @@ -212,14 +228,18 @@ def 
test_collector(gym_reset_kwargs) -> None: # test NXEnv for obs_type in ["array", "object"]: - envs = SubprocVectorEnv([lambda i=x, t=obs_type: NXEnv(i, t) for x in [5, 10, 15, 20]]) + + def create_env(i: int, t: str) -> Callable[[], NXEnv]: + return lambda: NXEnv(i, t) + + envs = SubprocVectorEnv([create_env(x, obs_type) for x in [5, 10, 15, 20]]) c3 = Collector(policy, envs, VectorReplayBuffer(total_size=100, buffer_num=4)) c3.collect(n_step=6, gym_reset_kwargs=gym_reset_kwargs) assert c3.buffer.obs.dtype == object @pytest.mark.parametrize("gym_reset_kwargs", [None, {}]) -def test_collector_with_async(gym_reset_kwargs) -> None: +def test_collector_with_async(gym_reset_kwargs: None | dict) -> None: env_lens = [2, 3, 4, 5] writer = SummaryWriter("log/async_collector") logger = Logger(writer) @@ -422,6 +442,7 @@ def test_collector_with_ma() -> None: ) rew = c1.collect(n_step=12).returns assert rew.shape == (2, 4) and np.all(rew == 1), rew + rew: list | np.ndarray rew = c1.collect(n_episode=8).returns assert rew.shape == (8, 4) assert np.all(rew == 1) From b4d9450445f2a82be343f7693ab670e9cdf1b0b5 Mon Sep 17 00:00:00 2001 From: daniel <1534513+dantp-ai@users.noreply.github.com> Date: Fri, 15 Mar 2024 22:53:40 +0100 Subject: [PATCH 040/115] Fix mypy issues: * add missing function type annotations * use mandatory kw args properly * access attrs dataclass with . 
notation --- test/pettingzoo/pistonball_continuous.py | 26 +++++++++++++++--------- 1 file changed, 16 insertions(+), 10 deletions(-) diff --git a/test/pettingzoo/pistonball_continuous.py b/test/pettingzoo/pistonball_continuous.py index 8bbb20cfd..d53c38323 100644 --- a/test/pettingzoo/pistonball_continuous.py +++ b/test/pettingzoo/pistonball_continuous.py @@ -1,6 +1,7 @@ import argparse import os import warnings +from dataclasses import asdict from typing import Any import gymnasium as gym @@ -131,7 +132,7 @@ def get_args() -> argparse.Namespace: return parser.parse_known_args()[0] -def get_env(args: argparse.Namespace = get_args()): +def get_env(args: argparse.Namespace = get_args()) -> PettingZooEnv: return PettingZooEnv(pistonball_v6.env(continuous=True, n_pistons=args.n_pistons)) @@ -139,7 +140,7 @@ def get_agents( args: argparse.Namespace = get_args(), agents: list[BasePolicy] | None = None, optims: list[torch.optim.Optimizer] | None = None, -) -> tuple[BasePolicy, list[torch.optim.Optimizer], list]: +) -> tuple[BasePolicy, list[torch.optim.Optimizer] | None, list]: env = get_env() observation_space = ( env.observation_space["observation"] @@ -185,10 +186,10 @@ def dist(*logits: torch.Tensor) -> Distribution: return Independent(Normal(*logits), 1) agent: PPOPolicy = PPOPolicy( - actor, - critic, - optim, - dist, + actor=actor, + critic=critic, + optim=optim, + dist_fn=dist, discount_factor=args.gamma, max_grad_norm=args.max_grad_norm, eps_clip=args.eps_clip, @@ -207,7 +208,12 @@ def dist(*logits: torch.Tensor) -> Distribution: agents.append(agent) optims.append(optim) - policy = MultiAgentPolicyManager(agents, env, action_scaling=True, action_bound_method="clip") + policy = MultiAgentPolicyManager( + policies=agents, + env=env, + action_scaling=True, + action_bound_method="clip", + ) return policy, optims, env.agents @@ -247,7 +253,7 @@ def save_best_fn(policy: BasePolicy) -> None: def stop_fn(mean_rewards: float) -> bool: return False - def 
reward_metric(rews): + def reward_metric(rews: np.ndarray) -> np.ndarray: return rews[:, 0] # trainer @@ -267,7 +273,7 @@ def reward_metric(rews): resume_from_log=args.resume, ).run() - return result, policy + return asdict(result), policy def watch(args: argparse.Namespace = get_args(), policy: BasePolicy | None = None) -> None: @@ -280,5 +286,5 @@ def watch(args: argparse.Namespace = get_args(), policy: BasePolicy | None = Non policy.eval() collector = Collector(policy, env) collector_result = collector.collect(n_episode=1, render=args.render) - rews, lens = collector_result["rews"], collector_result["lens"] + rews, lens = collector_result["rews"], collector_result.lens print(f"Final reward: {rews[:, 0].mean()}, length: {lens.mean()}") From 62b884d21be95824e9e35c3807a972a3eddc14a0 Mon Sep 17 00:00:00 2001 From: daniel <1534513+dantp-ai@users.noreply.github.com> Date: Fri, 15 Mar 2024 22:56:38 +0100 Subject: [PATCH 041/115] ignore mypy check --- test/base/test_batch.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/base/test_batch.py b/test/base/test_batch.py index 86958184f..4284dcf76 100644 --- a/test/base/test_batch.py +++ b/test/base/test_batch.py @@ -551,7 +551,7 @@ def test_batch_standard_compatibility() -> None: batch = Batch(a=np.array([[1.0, 2.0], [3.0, 4.0]]), b=Batch(), c=np.array([5.0, 6.0])) batch_mean = np.mean(batch) assert isinstance(batch_mean, Batch) # type: ignore # mypy doesn't know but it works, cf. 
`batch.rst` - assert sorted(batch_mean.keys()) == ["a", "b", "c"] + assert sorted(batch_mean.keys()) == ["a", "b", "c"] # type: ignore with pytest.raises(TypeError): len(batch_mean) assert np.all(batch_mean.a == np.mean(batch.a, axis=0)) From 2b4ffa794a804ff41866fddb64fcc10c6a2e9b1f Mon Sep 17 00:00:00 2001 From: daniel <1534513+dantp-ai@users.noreply.github.com> Date: Sat, 16 Mar 2024 10:02:56 +0100 Subject: [PATCH 042/115] Fix mypy issues: * add missing function type annotation * use dot notation to access dataclass attrs * extend some typing * adapt output to respect existing type hints --- test/pettingzoo/tic_tac_toe.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/test/pettingzoo/tic_tac_toe.py b/test/pettingzoo/tic_tac_toe.py index 62b66dfa4..43149b9da 100644 --- a/test/pettingzoo/tic_tac_toe.py +++ b/test/pettingzoo/tic_tac_toe.py @@ -1,6 +1,7 @@ import argparse import os from copy import deepcopy +from dataclasses import asdict from functools import partial import gymnasium @@ -18,7 +19,7 @@ from tianshou.utils.net.common import Net -def get_env(render_mode: str | None = None): +def get_env(render_mode: str | None = None) -> PettingZooEnv: return PettingZooEnv(tictactoe_v3.env(render_mode=render_mode)) @@ -95,7 +96,7 @@ def get_agents( agent_learn: BasePolicy | None = None, agent_opponent: BasePolicy | None = None, optim: torch.optim.Optimizer | None = None, -) -> tuple[BasePolicy, torch.optim.Optimizer, list]: +) -> tuple[BasePolicy, torch.optim.Optimizer | None, list]: env = get_env() observation_space = ( env.observation_space["observation"] @@ -193,7 +194,7 @@ def train_fn(epoch: int, env_step: int) -> None: def test_fn(epoch: int, env_step: int | None) -> None: policy.policies[agents[args.agent_id - 1]].set_eps(args.eps_test) - def reward_metric(rews): + def reward_metric(rews: np.ndarray) -> np.ndarray: return rews[:, args.agent_id - 1] # trainer @@ -216,7 +217,7 @@ def reward_metric(rews): reward_metric=reward_metric, 
).run() - return result, policy.policies[agents[args.agent_id - 1]] + return asdict(result), policy.policies[agents[args.agent_id - 1]] def watch( @@ -230,5 +231,5 @@ def watch( policy.policies[agents[args.agent_id - 1]].set_eps(args.eps_test) collector = Collector(policy, env, exploration_noise=True) result = collector.collect(n_episode=1, render=args.render) - rews, lens = result["rews"], result["lens"] + rews, lens = result["rews"], result.lens print(f"Final reward: {rews[:, args.agent_id - 1].mean()}, length: {lens.mean()}") From 2fb1d95cb6078636dc3b7ad22f2b269c63e8f858 Mon Sep 17 00:00:00 2001 From: daniel <1534513+dantp-ai@users.noreply.github.com> Date: Sat, 16 Mar 2024 10:28:50 +0100 Subject: [PATCH 043/115] Add missing type annotations --- test/pettingzoo/pistonball.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/pettingzoo/pistonball.py b/test/pettingzoo/pistonball.py index 1e9ca2b5c..e9e2189e2 100644 --- a/test/pettingzoo/pistonball.py +++ b/test/pettingzoo/pistonball.py @@ -76,7 +76,7 @@ def get_agents( args: argparse.Namespace = get_args(), agents: list[BasePolicy] | None = None, optims: list[torch.optim.Optimizer] | None = None, -) -> tuple[BasePolicy, list[torch.optim.Optimizer], list]: +) -> tuple[BasePolicy, list[torch.optim.Optimizer] | None, list]: env = get_env() observation_space = ( env.observation_space["observation"] @@ -191,5 +191,5 @@ def watch(args: argparse.Namespace = get_args(), policy: BasePolicy | None = Non [agent.set_eps(args.eps_test) for agent in policy.policies.values()] collector = Collector(policy, env, exploration_noise=True) result = collector.collect(n_episode=1, render=args.render) - rews, lens = result["rews"], result["lens"] + rews, lens = result["rews"], result.lens print(f"Final reward: {rews[:, 0].mean()}, length: {lens.mean()}") From 0398ef7df326e7ebb0e3686e727fc822858ae28e Mon Sep 17 00:00:00 2001 From: daniel <1534513+dantp-ai@users.noreply.github.com> Date: Sat, 16 Mar 2024 11:36:57 
+0100 Subject: [PATCH 044/115] Fix some mypy issues: * Re-type sleep to float (as it seems that is actually the way it is used) * Add some typing to funcs --- test/base/env.py | 2 +- test/base/test_env.py | 5 +++-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/test/base/env.py b/test/base/env.py index 8a2de26cc..820b26846 100644 --- a/test/base/env.py +++ b/test/base/env.py @@ -15,7 +15,7 @@ class MyTestEnv(gym.Env): def __init__( self, size: int, - sleep: int = 0, + sleep: float = 0, dict_state: bool = False, recurse_state: bool = False, ma_rew: int = 0, diff --git a/test/base/test_env.py b/test/base/test_env.py index 6b6051420..69434d7de 100644 --- a/test/base/test_env.py +++ b/test/base/test_env.py @@ -1,5 +1,6 @@ import sys import time +from typing import Any import gymnasium as gym import numpy as np @@ -39,7 +40,7 @@ def has_ray() -> bool: return False -def recurse_comp(a, b) -> np.bool_ | bool | None: +def recurse_comp(a: np.ndarray | list | tuple | dict, b: Any) -> np.bool_ | bool | None: try: if isinstance(a, np.ndarray): if a.dtype == object: @@ -53,7 +54,7 @@ def recurse_comp(a, b) -> np.bool_ | bool | None: return False -def test_async_env(size=10000, num=8, sleep=0.1) -> None: +def test_async_env(size: int = 10000, num: int = 8, sleep: float = 0.1) -> None: # simplify the test case, just keep stepping env_fns = [ lambda i=i: MyTestEnv(size=i, sleep=sleep, random_sleep=True) From cc04d08f64ce6a96cce2090966445aa1857102c4 Mon Sep 17 00:00:00 2001 From: daniel <1534513+dantp-ai@users.noreply.github.com> Date: Sat, 16 Mar 2024 11:49:36 +0100 Subject: [PATCH 045/115] Add missing type annotations to funcs --- test/base/test_env.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/test/base/test_env.py b/test/base/test_env.py index 69434d7de..cdfff6059 100644 --- a/test/base/test_env.py +++ b/test/base/test_env.py @@ -18,6 +18,7 @@ VectorEnvNormObs, ) from tianshou.env.gym_wrappers import TruncatedAsTerminated 
+from tianshou.env.venvs import BaseVectorEnv from tianshou.utils import RunningMeanStd if __name__ == "__main__": @@ -107,7 +108,12 @@ def test_async_env(size: int = 10000, num: int = 8, sleep: float = 0.1) -> None: assert spent_time < 6.0 * sleep * num / (num + 1) -def test_async_check_id(size=100, num=4, sleep=0.2, timeout=0.7) -> None: +def test_async_check_id( + size: int = 100, + num: int = 4, + sleep: float = 0.2, + timeout: float = 0.7, +) -> None: env_fns = [ lambda: MyTestEnv(size=size, sleep=sleep * 2), lambda: MyTestEnv(size=size, sleep=sleep * 3), @@ -155,7 +161,7 @@ def test_async_check_id(size=100, num=4, sleep=0.2, timeout=0.7) -> None: assert total_pass >= 2 -def test_vecenv(size=10, num=8, sleep=0.001) -> None: +def test_vecenv(size: int = 10, num: int = 8, sleep: float = 0.001) -> None: env_fns = [ lambda i=i: MyTestEnv(size=i, sleep=sleep, recurse_state=True) for i in range(size, size + num) @@ -197,7 +203,7 @@ def test_vecenv(size=10, num=8, sleep=0.001) -> None: for i, v in enumerate(venv): print(f"{type(v)}: {t[i]:.6f}s") - def assert_get(v, expected): + def assert_get(v: BaseVectorEnv, expected: list) -> None: assert v.get_env_attr("size") == expected assert v.get_env_attr("size", id=0) == [expected[0]] assert v.get_env_attr("size", id=[0, 1, 2]) == expected[:3] @@ -237,7 +243,7 @@ def test_env_obs_dtype() -> None: assert obs.dtype == object -def test_env_reset_optional_kwargs(size=10000, num=8) -> None: +def test_env_reset_optional_kwargs(size: int = 10000, num: int = 8) -> None: env_fns = [lambda i=i: MyTestEnv(size=i) for i in range(size, size + num)] test_cls = [DummyVectorEnv, SubprocVectorEnv, ShmemVectorEnv] if has_ray(): From d265b2fc3f1639d503ea63e5972543e9cd3a2ee0 Mon Sep 17 00:00:00 2001 From: daniel <1534513+dantp-ai@users.noreply.github.com> Date: Sat, 16 Mar 2024 16:52:49 +0100 Subject: [PATCH 046/115] Fix mypy issues * Use ndarray input to step() * Use proper type for list elements --- test/base/test_env.py | 11 ++++++----- 1 
file changed, 6 insertions(+), 5 deletions(-) diff --git a/test/base/test_env.py b/test/base/test_env.py index cdfff6059..2d8afe0be 100644 --- a/test/base/test_env.py +++ b/test/base/test_env.py @@ -180,7 +180,7 @@ def test_vecenv(size: int = 10, num: int = 8, sleep: float = 0.001) -> None: for a in action_list: o = [] for v in venv: - A, B, C, D, E = v.step([a] * num) + A, B, C, D, E = v.step(np.array([a] * num)) if sum(C + D): A, _ = v.reset(np.where(C + D)[0]) o.append([A, B, C, D, E]) @@ -191,12 +191,12 @@ def test_vecenv(size: int = 10, num: int = 8, sleep: float = 0.001) -> None: assert recurse_comp(infos[0], info) if __name__ == "__main__": - t = [0] * len(venv) + t = [0.0] * len(venv) for i, e in enumerate(venv): t[i] = time.time() e.reset() for a in action_list: - done = e.step([a] * num)[2] + done = e.step(np.array([a] * num))[2] if sum(done) > 0: e.reset(np.where(done)[0]) t[i] = time.time() - t[i] @@ -230,8 +230,9 @@ def test_attr_unwrapped() -> None: train_envs = DummyVectorEnv([lambda: gym.make("CartPole-v1")]) train_envs.set_env_attr("test_attribute", 1337) assert train_envs.get_env_attr("test_attribute") == [1337] - assert hasattr(train_envs.workers[0].env, "test_attribute") - assert hasattr(train_envs.workers[0].env.unwrapped, "test_attribute") + # mypy doesn't know but BaseVectorEnv takes the reserved keys in gym.Env (one of which is env) + assert hasattr(train_envs.workers[0].env, "test_attribute") # type: ignore + assert hasattr(train_envs.workers[0].env.unwrapped, "test_attribute") # type: ignore def test_env_obs_dtype() -> None: From 48af7a536320077c29c12721d61159eba9537990 Mon Sep 17 00:00:00 2001 From: daniel <1534513+dantp-ai@users.noreply.github.com> Date: Sat, 16 Mar 2024 17:09:46 +0100 Subject: [PATCH 047/115] Pass correct typed param env_fns to SubprocVectorEnv --- test/base/test_collector.py | 7 +++---- test/base/test_env.py | 6 +++++- 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/test/base/test_collector.py 
b/test/base/test_collector.py index 8a1d34a08..d1cd7472d 100644 --- a/test/base/test_collector.py +++ b/test/base/test_collector.py @@ -226,12 +226,11 @@ def test_collector(gym_reset_kwargs: None | dict) -> None: with pytest.raises(TypeError): c2.collect() + def create_env(i: int, t: str) -> Callable[[], NXEnv]: + return lambda: NXEnv(i, t) + # test NXEnv for obs_type in ["array", "object"]: - - def create_env(i: int, t: str) -> Callable[[], NXEnv]: - return lambda: NXEnv(i, t) - envs = SubprocVectorEnv([create_env(x, obs_type) for x in [5, 10, 15, 20]]) c3 = Collector(policy, envs, VectorReplayBuffer(total_size=100, buffer_num=4)) c3.collect(n_step=6, gym_reset_kwargs=gym_reset_kwargs) diff --git a/test/base/test_env.py b/test/base/test_env.py index 2d8afe0be..bdf924998 100644 --- a/test/base/test_env.py +++ b/test/base/test_env.py @@ -1,5 +1,6 @@ import sys import time +from collections.abc import Callable from typing import Any import gymnasium as gym @@ -236,8 +237,11 @@ def test_attr_unwrapped() -> None: def test_env_obs_dtype() -> None: + def create_env(i: int, t: str) -> Callable[[], NXEnv]: + return lambda: NXEnv(i, t) + for obs_type in ["array", "object"]: - envs = SubprocVectorEnv([lambda i=x, t=obs_type: NXEnv(i, t) for x in [5, 10, 15, 20]]) + envs = SubprocVectorEnv([create_env(x, obs_type) for x in [5, 10, 15, 20]]) obs, info = envs.reset() assert obs.dtype == object obs = envs.step(np.array([1, 1, 1, 1]))[0] From 4330e33605b50d0393130adc43a4540bd49fb233 Mon Sep 17 00:00:00 2001 From: daniel <1534513+dantp-ai@users.noreply.github.com> Date: Sat, 16 Mar 2024 17:48:54 +0100 Subject: [PATCH 048/115] Fix more mypy issues: run_align_norm_obs() * Mypy complains about unreachable statements but the logic in the conditional is sound and they are in fact reachable * info can be dict or list[dict]. Take into account the latter too. 
* Add type anno to DummyEnv.step() --- test/base/test_env.py | 33 ++++++++++++++++++++++----------- 1 file changed, 22 insertions(+), 11 deletions(-) diff --git a/test/base/test_env.py b/test/base/test_env.py index bdf924998..d66d0799d 100644 --- a/test/base/test_env.py +++ b/test/base/test_env.py @@ -1,7 +1,7 @@ import sys import time from collections.abc import Callable -from typing import Any +from typing import Any, Literal import gymnasium as gym import numpy as np @@ -274,13 +274,18 @@ def test_venv_wrapper_gym(num_envs: int = 4) -> None: assert obs.shape[0] == len(info) == num_envs -def run_align_norm_obs(raw_env, train_env, test_env, action_list): - def reset_result_to_obs(reset_result): +def run_align_norm_obs( + raw_env: DummyVectorEnv, + train_env: VectorEnvNormObs, + test_env: VectorEnvNormObs, + action_list: list[np.ndarray], +) -> None: + def reset_result_to_obs(reset_result: tuple[np.ndarray, dict | list[dict]]) -> np.ndarray: """Extract observation from reset result (result is possibly a tuple containing info).""" if isinstance(reset_result, tuple) and len(reset_result) == 2: obs, _ = reset_result else: - obs = reset_result + obs = reset_result # type: ignore return obs eps = np.finfo(np.float32).eps.item() @@ -295,7 +300,7 @@ def reset_result_to_obs(reset_result): obs, rew, terminated, truncated, info = step_result done = np.logical_or(terminated, truncated) else: - obs, rew, done, info = step_result + obs, rew, done, info = step_result # type: ignore raw_obs.append(obs) if np.any(done): reset_result = raw_env.reset(np.where(done)[0]) @@ -306,7 +311,7 @@ def reset_result_to_obs(reset_result): obs, rew, terminated, truncated, info = step_result done = np.logical_or(terminated, truncated) else: - obs, rew, done, info = step_result + obs, rew, done, info = step_result # type: ignore train_obs.append(obs) if np.any(done): reset_result = train_env.reset(np.where(done)[0]) @@ -330,7 +335,7 @@ def reset_result_to_obs(reset_result): obs, rew, terminated, 
truncated, info = step_result done = np.logical_or(terminated, truncated) else: - obs, rew, done, info = step_result + obs, rew, done, info = step_result # type: ignore test_obs.append(obs) if np.any(done): reset_result = test_env.reset(np.where(done)[0]) @@ -361,7 +366,7 @@ def __init__(self) -> None: self.action_space = gym.spaces.Box(low=-1.0, high=2.0, shape=(4,), dtype=np.float32) self.observation_space = gym.spaces.Discrete(2) - def step(self, act): + def step(self, act: Any) -> tuple[Any, Literal[-1], Literal[False], Literal[True], dict]: return self.observation_space.sample(), -1, False, True, {} bsz = 10 @@ -416,9 +421,15 @@ def test_venv_wrapper_envpool_gym_reset_return_info() -> None: ) obs, info = env.reset() assert obs.shape[0] == num_envs - for _, v in info.items(): - if not isinstance(v, dict): - assert v.shape[0] == num_envs + if isinstance(info, dict): + for _, v in info.items(): + if not isinstance(v, dict): + assert v.shape[0] == num_envs + else: + for _info in info: + for _, v in _info.items(): + if not isinstance(v, dict): + assert v.shape[0] == num_envs if __name__ == "__main__": From 91a3ee2cc7b84aad423ced1be549f692d7cd2899 Mon Sep 17 00:00:00 2001 From: daniel <1534513+dantp-ai@users.noreply.github.com> Date: Sun, 17 Mar 2024 12:01:26 +0100 Subject: [PATCH 049/115] Bugfix: Tuple item assignment --- test/base/test_env_finite.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/base/test_env_finite.py b/test/base/test_env_finite.py index bb0cd5318..0ecdb849a 100644 --- a/test/base/test_env_finite.py +++ b/test/base/test_env_finite.py @@ -150,7 +150,7 @@ def step( ids: list[int] | np.ndarray = self._wrap_id(id) id2idx = {i: k for k, i in enumerate(ids)} request_id = list(filter(lambda i: i in self._alive_env_ids, ids)) - result: list[tuple] = [(None, 0.0, False, False, None) for _ in range(len(ids))] + result: list[list] = [[None, 0.0, False, False, None] for _ in range(len(ids))] # ask super to step alive envs and 
remap to current index if request_id: @@ -160,7 +160,7 @@ def step( zip(*super().step(valid_act, request_id), strict=True), strict=True, ): - result[id2idx[i]] = (r_obs, r_reward, r_term, r_trunc, r_info) + result[id2idx[i]] = [r_obs, r_reward, r_term, r_trunc, r_info] # logging for i, r in zip(ids, result, strict=True): From a481e2f3ec08c045eecb65fd087acaba09625ead Mon Sep 17 00:00:00 2001 From: daniel <1534513+dantp-ai@users.noreply.github.com> Date: Mon, 18 Mar 2024 13:39:10 +0100 Subject: [PATCH 050/115] Make mypy happy and use [] instea of . notation * mypy does not know that value of keys are Batch --- test/base/test_buffer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/base/test_buffer.py b/test/base/test_buffer.py index 47ecd22d8..e609749ae 100644 --- a/test/base/test_buffer.py +++ b/test/base/test_buffer.py @@ -190,9 +190,9 @@ def test_ignore_obs_next(size: int = 10) -> None: ], ), ) - assert np.allclose(data.info["if"], data2.info["if"]) + assert np.allclose(data["info"]["if"], data2["info"]["if"]) assert np.allclose( - data.info["if"], + data["info"]["if"], np.array( [ [0, 0, 0, 0], From c6aa77b87ac32f73967424498f1278f7f2eb691f Mon Sep 17 00:00:00 2001 From: Michael Panchenko Date: Thu, 21 Mar 2024 14:29:02 +0100 Subject: [PATCH 051/115] Typing: extend index type --- tianshou/data/batch.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tianshou/data/batch.py b/tianshou/data/batch.py index 7002b55d6..508e5c9a2 100644 --- a/tianshou/data/batch.py +++ b/tianshou/data/batch.py @@ -3,6 +3,7 @@ from collections.abc import Collection, Iterable, Iterator, Sequence from copy import deepcopy from numbers import Number +from types import EllipsisType from typing import ( Any, Protocol, @@ -17,7 +18,8 @@ import numpy as np import torch -IndexType = slice | int | np.ndarray | list[int] +_SingleIndexType = slice | int | EllipsisType +IndexType = np.ndarray | _SingleIndexType | list[_SingleIndexType] | 
tuple[_SingleIndexType, ...] TBatch = TypeVar("TBatch", bound="BatchProtocol") arr_type = torch.Tensor | np.ndarray From 0de06a49c0b9f19572552457700c7050c9cb28c2 Mon Sep 17 00:00:00 2001 From: Michael Panchenko Date: Thu, 21 Mar 2024 14:37:42 +0100 Subject: [PATCH 052/115] Typing in tests: added asserts and cast to remove some mypy errors --- test/base/test_batch.py | 5 +++++ test/base/test_collector.py | 6 ++++-- 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/test/base/test_batch.py b/test/base/test_batch.py index 4284dcf76..e88dd0587 100644 --- a/test/base/test_batch.py +++ b/test/base/test_batch.py @@ -2,6 +2,7 @@ import pickle import sys from itertools import starmap +from typing import cast import networkx as nx import numpy as np @@ -160,7 +161,11 @@ def test_batch() -> None: batch5 = Batch(a=np.array([{"index": 0}])) assert isinstance(batch5.a, Batch) assert np.allclose(batch5.a.index, [0]) + # We use setattr b/c the setattr of Batch will actually change the type of the field that is being set! 
+ # However, mypy would not understand this, and rightly expect that batch.b = some_array would lead to + # batch.b being an array (which it is not, it's turned into a Batch instead) batch5.b = np.array([{"index": 1}]) + batch5.b = cast(Batch, batch5.b) assert isinstance(batch5.b, Batch) assert np.allclose(batch5.b.index, [1]) diff --git a/test/base/test_collector.py b/test/base/test_collector.py index d1cd7472d..f36e49a2d 100644 --- a/test/base/test_collector.py +++ b/test/base/test_collector.py @@ -314,9 +314,11 @@ def test_collector_with_dict_state() -> None: batch, _ = c1.buffer.sample(10) c0.buffer.update(c1.buffer) assert len(c0.buffer) in [42, 43] + cur_obs = c0.buffer[:].obs + assert isinstance(cur_obs, Batch) if len(c0.buffer) == 42: assert np.all( - c0.buffer[:].obs.index[..., 0] + cur_obs.index[..., 0] == [ 0, 1, @@ -364,7 +366,7 @@ def test_collector_with_dict_state() -> None: ), c0.buffer[:].obs.index[..., 0] else: assert np.all( - c0.buffer[:].obs.index[..., 0] + cur_obs.index[..., 0] == [ 0, 1, From 0596f0484fb7d81a0f9f1334ba2e45edca77a98f Mon Sep 17 00:00:00 2001 From: daniel <1534513+dantp-ai@users.noreply.github.com> Date: Thu, 21 Mar 2024 15:32:06 +0100 Subject: [PATCH 053/115] Revert output type to tuple[InfoStats, BasePolicy] * Maintain consistency used throughout codebase, using .best_reward --- test/pettingzoo/pistonball_continuous.py | 6 +++--- test/pettingzoo/tic_tac_toe.py | 6 +++--- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/test/pettingzoo/pistonball_continuous.py b/test/pettingzoo/pistonball_continuous.py index d53c38323..987f5ca8f 100644 --- a/test/pettingzoo/pistonball_continuous.py +++ b/test/pettingzoo/pistonball_continuous.py @@ -1,7 +1,6 @@ import argparse import os import warnings -from dataclasses import asdict from typing import Any import gymnasium as gym @@ -13,6 +12,7 @@ from torch.utils.tensorboard import SummaryWriter from tianshou.data import Collector, VectorReplayBuffer +from tianshou.data.stats 
import InfoStats from tianshou.env import DummyVectorEnv from tianshou.env.pettingzoo_env import PettingZooEnv from tianshou.policy import BasePolicy, MultiAgentPolicyManager, PPOPolicy @@ -221,7 +221,7 @@ def train_agent( args: argparse.Namespace = get_args(), agents: list[BasePolicy] | None = None, optims: list[torch.optim.Optimizer] | None = None, -) -> tuple[dict, BasePolicy]: +) -> tuple[InfoStats, BasePolicy]: train_envs = DummyVectorEnv([get_env for _ in range(args.training_num)]) test_envs = DummyVectorEnv([get_env for _ in range(args.test_num)]) # seed @@ -273,7 +273,7 @@ def reward_metric(rews: np.ndarray) -> np.ndarray: resume_from_log=args.resume, ).run() - return asdict(result), policy + return result, policy def watch(args: argparse.Namespace = get_args(), policy: BasePolicy | None = None) -> None: diff --git a/test/pettingzoo/tic_tac_toe.py b/test/pettingzoo/tic_tac_toe.py index 43149b9da..a7e3188c6 100644 --- a/test/pettingzoo/tic_tac_toe.py +++ b/test/pettingzoo/tic_tac_toe.py @@ -1,7 +1,6 @@ import argparse import os from copy import deepcopy -from dataclasses import asdict from functools import partial import gymnasium @@ -11,6 +10,7 @@ from torch.utils.tensorboard import SummaryWriter from tianshou.data import Collector, VectorReplayBuffer +from tianshou.data.stats import InfoStats from tianshou.env import DummyVectorEnv from tianshou.env.pettingzoo_env import PettingZooEnv from tianshou.policy import BasePolicy, DQNPolicy, MultiAgentPolicyManager, RandomPolicy @@ -146,7 +146,7 @@ def train_agent( agent_learn: BasePolicy | None = None, agent_opponent: BasePolicy | None = None, optim: torch.optim.Optimizer | None = None, -) -> tuple[dict, BasePolicy]: +) -> tuple[InfoStats, BasePolicy]: train_envs = DummyVectorEnv([get_env for _ in range(args.training_num)]) test_envs = DummyVectorEnv([get_env for _ in range(args.test_num)]) # seed @@ -217,7 +217,7 @@ def reward_metric(rews: np.ndarray) -> np.ndarray: reward_metric=reward_metric, ).run() - return 
asdict(result), policy.policies[agents[args.agent_id - 1]] + return result, policy.policies[agents[args.agent_id - 1]] def watch( From 3031c8f42142f5b50d4d639a123f476ba7e127eb Mon Sep 17 00:00:00 2001 From: daniel <1534513+dantp-ai@users.noreply.github.com> Date: Fri, 22 Mar 2024 14:07:02 +0100 Subject: [PATCH 054/115] For mypy: Store obs/obs_next in new var and assert type --- test/base/test_collector.py | 24 +++++++++++++++++------- 1 file changed, 17 insertions(+), 7 deletions(-) diff --git a/test/base/test_collector.py b/test/base/test_collector.py index f36e49a2d..09ffb648b 100644 --- a/test/base/test_collector.py +++ b/test/base/test_collector.py @@ -126,7 +126,9 @@ def test_collector(gym_reset_kwargs: None | dict) -> None: c0.collect(n_step=3, gym_reset_kwargs=gym_reset_kwargs) assert len(c0.buffer) == 3 assert np.allclose(c0.buffer.obs[:4, 0], [0, 1, 0, 0]) - assert np.allclose(c0.buffer[:].obs_next[..., 0], [1, 2, 1]) + obs_next = c0.buffer[:].obs_next + assert isinstance(obs_next, np.ndarray) + assert np.allclose(obs_next[..., 0], [1, 2, 1]) keys = np.zeros(100) keys[:3] = 1 assert np.allclose(c0.buffer.info["key"], keys) @@ -139,7 +141,9 @@ def test_collector(gym_reset_kwargs: None | dict) -> None: c0.collect(n_episode=3, gym_reset_kwargs=gym_reset_kwargs) assert len(c0.buffer) == 8 assert np.allclose(c0.buffer.obs[:10, 0], [0, 1, 0, 1, 0, 1, 0, 1, 0, 0]) - assert np.allclose(c0.buffer[:].obs_next[..., 0], [1, 2, 1, 2, 1, 2, 1, 2]) + obs_next = c0.buffer[:].obs_next + assert isinstance(obs_next, np.ndarray) + assert np.allclose(obs_next[..., 0], [1, 2, 1, 2, 1, 2, 1, 2]) assert np.allclose(c0.buffer.info["key"][:8], 1) for e in c0.buffer.info["env"][:8]: assert isinstance(e, MyTestEnv) @@ -158,7 +162,9 @@ def test_collector(gym_reset_kwargs: None | dict) -> None: valid_indices = [0, 1, 25, 26, 50, 51, 75, 76] obs[valid_indices] = [0, 1, 0, 1, 0, 1, 0, 1] assert np.allclose(c1.buffer.obs[:, 0], obs) - assert np.allclose(c1.buffer[:].obs_next[..., 0], [1, 
2, 1, 2, 1, 2, 1, 2]) + obs_next = c1.buffer[:].obs_next + assert isinstance(obs_next, np.ndarray) + assert np.allclose(obs_next[..., 0], [1, 2, 1, 2, 1, 2, 1, 2]) keys = np.zeros(100) keys[valid_indices] = [1, 1, 1, 1, 1, 1, 1, 1] assert np.allclose(c1.buffer.info["key"], keys) @@ -175,8 +181,10 @@ def test_collector(gym_reset_kwargs: None | dict) -> None: valid_indices = [2, 3, 27, 52, 53, 77, 78, 79] obs[[2, 3, 27, 52, 53, 77, 78, 79]] = [0, 1, 2, 2, 3, 2, 3, 4] assert np.allclose(c1.buffer.obs[:, 0], obs) + obs_next = c1.buffer[:].obs_next + assert isinstance(obs_next, np.ndarray) assert np.allclose( - c1.buffer[:].obs_next[..., 0], + obs_next[..., 0], [1, 2, 1, 2, 1, 2, 3, 1, 2, 3, 4, 1, 2, 3, 4, 5], ) keys[valid_indices] = [1, 1, 1, 1, 1, 1, 1, 1] @@ -363,7 +371,7 @@ def test_collector_with_dict_state() -> None: 3, 4, ], - ), c0.buffer[:].obs.index[..., 0] + ), cur_obs.index[..., 0] else: assert np.all( cur_obs.index[..., 0] @@ -412,7 +420,7 @@ def test_collector_with_dict_state() -> None: 3, 4, ], - ), c0.buffer[:].obs.index[..., 0] + ), cur_obs.index[..., 0] c2 = Collector( policy, envs, @@ -594,7 +602,9 @@ def test_collector_with_atari_setting() -> None: obs = np.zeros_like(c2.buffer.obs) obs[np.arange(8)] = reference_obs[[0, 1, 2, 3, 4, 0, 1, 2], -1] assert np.all(c2.buffer.obs == obs) - assert np.allclose(c2.buffer[:].obs_next, reference_obs[[1, 2, 3, 4, 4, 1, 2, 2], -1]) + obs_next = c2.buffer[:].obs_next + assert isinstance(obs_next, np.ndarray) + assert np.allclose(obs_next, reference_obs[[1, 2, 3, 4, 4, 1, 2, 2], -1]) # atari multi buffer env_fns = [lambda x=i: MyTestEnv(size=x, sleep=0, array_state=True) for i in [2, 3, 4, 5]] From 5451119fb1121b1124d7c4eb0ce4a4699af6898f Mon Sep 17 00:00:00 2001 From: daniel <1534513+dantp-ai@users.noreply.github.com> Date: Fri, 22 Mar 2024 15:39:40 +0100 Subject: [PATCH 055/115] For mypy: store intermediate vars & assert type --- test/base/test_buffer.py | 17 +++++++++++++---- 1 file changed, 13 insertions(+), 4 
deletions(-) diff --git a/test/base/test_buffer.py b/test/base/test_buffer.py index e609749ae..f55d441b2 100644 --- a/test/base/test_buffer.py +++ b/test/base/test_buffer.py @@ -53,6 +53,7 @@ def test_replaybuffer(size: int = 10, bufsize: int = 20) -> None: assert buf.act.dtype == int assert buf.act.shape == (bufsize, 1) data, indices = buf.sample(bufsize * 2) + assert isinstance(data, Batch) assert (indices < len(buf)).all() assert (data.obs < size).all() assert (data.done >= 0).all() @@ -1133,8 +1134,10 @@ def test_multibuf_stack() -> None: assert np.allclose(buf4.unfinished_index(), [10, 15, 20]) indices = np.array(sorted(buf4.sample_indices(0))) assert np.allclose(indices, [*list(range(bufsize)), 9, 10, 14, 15, 19, 20]) + cur_obs = buf4[indices].obs + assert isinstance(cur_obs, np.ndarray) assert np.allclose( - buf4[indices].obs[..., 0], + cur_obs[..., 0], [ [11, 11, 11, 12], [11, 11, 12, 13], @@ -1153,8 +1156,10 @@ def test_multibuf_stack() -> None: [11, 11, 11, 12], ], ) + next_obs = buf4[indices].obs_next + assert isinstance(next_obs, np.ndarray) assert np.allclose( - buf4[indices].obs_next[..., 0], + next_obs[..., 0], [ [11, 11, 12, 13], [11, 12, 13, 14], @@ -1303,11 +1308,15 @@ def test_from_data() -> None: buf = ReplayBuffer.from_data(obs, act, rew, terminated, truncated, done, obs_next) assert len(buf) == 10 batch = buf[3] - assert np.array_equal(batch.obs, 3 * np.ones((3, 3), dtype="uint8")) + cur_obs = batch.obs + assert isinstance(cur_obs, np.ndarray) + assert np.array_equal(cur_obs, 3 * np.ones((3, 3), dtype="uint8")) assert batch.act == 3 assert batch.rew == 3.0 assert not batch.done - assert np.array_equal(batch.obs_next, 4 * np.ones((3, 3), dtype="uint8")) + next_obs = batch.obs_next + assert isinstance(next_obs, np.ndarray) + assert np.array_equal(next_obs, 4 * np.ones((3, 3), dtype="uint8")) os.remove(path) From a797811d14094d8a399da457aaac1a59570a9f35 Mon Sep 17 00:00:00 2001 From: daniel <1534513+dantp-ai@users.noreply.github.com> Date: Sat, 
23 Mar 2024 10:26:11 +0100 Subject: [PATCH 056/115] Use pprint_asdict() instead --- test/pettingzoo/pistonball.py | 3 +-- test/pettingzoo/pistonball_continuous.py | 3 +-- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/test/pettingzoo/pistonball.py b/test/pettingzoo/pistonball.py index e9e2189e2..8dc77258e 100644 --- a/test/pettingzoo/pistonball.py +++ b/test/pettingzoo/pistonball.py @@ -191,5 +191,4 @@ def watch(args: argparse.Namespace = get_args(), policy: BasePolicy | None = Non [agent.set_eps(args.eps_test) for agent in policy.policies.values()] collector = Collector(policy, env, exploration_noise=True) result = collector.collect(n_episode=1, render=args.render) - rews, lens = result["rews"], result.lens - print(f"Final reward: {rews[:, 0].mean()}, length: {lens.mean()}") + result.pprint_asdict() diff --git a/test/pettingzoo/pistonball_continuous.py b/test/pettingzoo/pistonball_continuous.py index 987f5ca8f..5f02637b3 100644 --- a/test/pettingzoo/pistonball_continuous.py +++ b/test/pettingzoo/pistonball_continuous.py @@ -286,5 +286,4 @@ def watch(args: argparse.Namespace = get_args(), policy: BasePolicy | None = Non policy.eval() collector = Collector(policy, env) collector_result = collector.collect(n_episode=1, render=args.render) - rews, lens = collector_result["rews"], collector_result.lens - print(f"Final reward: {rews[:, 0].mean()}, length: {lens.mean()}") + collector_result.pprint_asdict() From b09b5812f83042dfed3d0058f46356d81a64cff5 Mon Sep 17 00:00:00 2001 From: daniel <1534513+dantp-ai@users.noreply.github.com> Date: Sat, 23 Mar 2024 10:27:05 +0100 Subject: [PATCH 057/115] Check for none --- test/discrete/test_bdq.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/test/discrete/test_bdq.py b/test/discrete/test_bdq.py index b1750d49d..ce6eef4a5 100644 --- a/test/discrete/test_bdq.py +++ b/test/discrete/test_bdq.py @@ -57,7 +57,10 @@ def test_bdq(args: argparse.Namespace = get_args()) -> None: if 
args.reward_threshold is None: default_reward_threshold = {"Pendulum-v0": -250, "Pendulum-v1": -250} - args.reward_threshold = default_reward_threshold.get(args.task, env.spec.reward_threshold) + args.reward_threshold = default_reward_threshold.get( + args.task, + env.spec.reward_threshold if env.spec else None, + ) print("Observations shape:", args.state_shape) print("Num branches:", args.num_branches) From 41748f18c6ff8abed563549d34eae639213736b4 Mon Sep 17 00:00:00 2001 From: daniel <1534513+dantp-ai@users.noreply.github.com> Date: Sat, 23 Mar 2024 10:34:40 +0100 Subject: [PATCH 058/115] Use pprint_asdict to print CollectStats --- test/pettingzoo/tic_tac_toe.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/test/pettingzoo/tic_tac_toe.py b/test/pettingzoo/tic_tac_toe.py index a7e3188c6..e5a38c42a 100644 --- a/test/pettingzoo/tic_tac_toe.py +++ b/test/pettingzoo/tic_tac_toe.py @@ -231,5 +231,4 @@ def watch( policy.policies[agents[args.agent_id - 1]].set_eps(args.eps_test) collector = Collector(policy, env, exploration_noise=True) result = collector.collect(n_episode=1, render=args.render) - rews, lens = result["rews"], result.lens - print(f"Final reward: {rews[:, args.agent_id - 1].mean()}, length: {lens.mean()}") + result.pprint_asdict() From 12d5e680b931ec8b9dd1018aa2dafabdc64d8537 Mon Sep 17 00:00:00 2001 From: daniel <1534513+dantp-ai@users.noreply.github.com> Date: Sat, 23 Mar 2024 10:43:12 +0100 Subject: [PATCH 059/115] Ignore type as mypy doesn't know that it should be wrong here --- test/base/test_buffer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/base/test_buffer.py b/test/base/test_buffer.py index f55d441b2..67e678095 100644 --- a/test/base/test_buffer.py +++ b/test/base/test_buffer.py @@ -707,12 +707,12 @@ def test_hdf5() -> None: data = {"not_supported": lambda x: x * x} grp = h5py.Group with pytest.raises(NotImplementedError): - to_hdf5(data, grp) + to_hdf5(data, grp) # type: ignore # ndarray 
with data type not supported by HDF5 that cannot be pickled data = {"not_supported": np.array(lambda x: x * x)} grp = h5py.Group with pytest.raises(RuntimeError): - to_hdf5(data, grp) + to_hdf5(data, grp) # type: ignore def test_replaybuffermanager() -> None: From 2fdab27eac8430b28e0058bd30a7fd9ddb7842ee Mon Sep 17 00:00:00 2001 From: daniel <1534513+dantp-ai@users.noreply.github.com> Date: Sat, 23 Mar 2024 14:42:15 +0100 Subject: [PATCH 060/115] Ignore type annotation for step() because env can generate non-scalar rewards: * See issue #1080 --- test/base/env.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/base/env.py b/test/base/env.py index 820b26846..2a4392e4f 100644 --- a/test/base/env.py +++ b/test/base/env.py @@ -121,7 +121,7 @@ def do_sleep(self) -> None: sleep_time *= self.sleep time.sleep(sleep_time) - def step(self, action: np.ndarray | int): + def step(self, action: np.ndarray | int): # type: ignore[no-untyped-def] # cf. issue #1080 self.steps += 1 if self._md_action and isinstance(action, np.ndarray): action = action[0] From bda6b10f00e775bb4954399c8ea2e61bb8d66916 Mon Sep 17 00:00:00 2001 From: daniel <1534513+dantp-ai@users.noreply.github.com> Date: Sat, 23 Mar 2024 16:47:54 +0100 Subject: [PATCH 061/115] Make mypy understand what type weight is --- test/base/test_buffer.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/test/base/test_buffer.py b/test/base/test_buffer.py index 67e678095..d1b8e2a3e 100644 --- a/test/base/test_buffer.py +++ b/test/base/test_buffer.py @@ -326,11 +326,14 @@ def test_priortized_replaybuffer(size: int = 32, bufsize: int = 15) -> None: batch_sample, indices = buf2.sample(10) buf2.update_weight(indices, batch_sample.weight * 0) weight = buf2[np.arange(buf2.maxsize)].weight + assert isinstance(weight, np.ndarray) mask = np.isin(np.arange(buf2.maxsize), indices) - assert np.all(weight[mask] == weight[mask][0]) - assert np.all(weight[~mask] == weight[~mask][0]) - assert 
weight[~mask][0] < weight[mask][0] - assert weight[mask][0] <= 1 + selected_weight = weight[mask] + unselected_weight = weight[~mask] + assert np.all(selected_weight == selected_weight[0]) + assert np.all(unselected_weight == unselected_weight[0]) + assert unselected_weight[0] < selected_weight[0] + assert selected_weight[0] <= 1 def test_herreplaybuffer(size: int = 10, bufsize: int = 100, sample_sz: int = 4) -> None: From c4f6c2d8beaee2266aeeef672ccc822fa64ce8c6 Mon Sep 17 00:00:00 2001 From: daniel <1534513+dantp-ai@users.noreply.github.com> Date: Sun, 24 Mar 2024 11:12:53 +0100 Subject: [PATCH 062/115] Make mypy happy and use proper type for adding constant to batch --- test/base/test_batch.py | 2 +- test/base/test_collector.py | 2 ++ 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/test/base/test_batch.py b/test/base/test_batch.py index e88dd0587..f11a8d60e 100644 --- a/test/base/test_batch.py +++ b/test/base/test_batch.py @@ -220,7 +220,7 @@ def test_batch_over_batch() -> None: batch5[:, 3] with pytest.raises(IndexError): batch5[:, :, -1] - batch5[:, -1] += 1 + batch5[:, -1] += np.int_(1) assert np.allclose(batch5.a, [1, 3]) assert np.allclose(batch5.b.c.squeeze(), [[0, 1]] * 3) with pytest.raises(ValueError): diff --git a/test/base/test_collector.py b/test/base/test_collector.py index 09ffb648b..e46ea7f0d 100644 --- a/test/base/test_collector.py +++ b/test/base/test_collector.py @@ -62,6 +62,8 @@ def forward( if self.need_state: if state is None: state = np.zeros((len(batch.obs), 2)) + elif isinstance(state, np.ndarray | BatchProtocol): + state += np.int_(1) else: state += 1 if self.dict_state: From 2d6f5f171df0420ee4914e53a0c47d0ff8176a8c Mon Sep 17 00:00:00 2001 From: daniel <1534513+dantp-ai@users.noreply.github.com> Date: Sun, 24 Mar 2024 11:14:37 +0100 Subject: [PATCH 063/115] Use setattr/getattr because mypy doesn't know * of dynamically added attributes * ignore ruff checks because ruff would reformat back s.t. 
mypy will complain (circular) --- test/base/test_stats.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/test/base/test_stats.py b/test/base/test_stats.py index 537519287..e13f144d7 100644 --- a/test/base/test_stats.py +++ b/test/base/test_stats.py @@ -13,7 +13,8 @@ class TestStats: @staticmethod def test_training_stats_wrapper() -> None: train_stats = TrainingStats(train_time=1.0) - train_stats.loss_field = 12 + + setattr(train_stats, "loss_field", 12) # noqa: B010 wrapped_train_stats = DummyTrainingStatsWrapper(train_stats, dummy_field=42) @@ -37,4 +38,8 @@ def test_training_stats_wrapper() -> None: # existing fields, wrapped and not-wrapped, can be mutated wrapped_train_stats.loss_field = 13 wrapped_train_stats.dummy_field = 43 - assert wrapped_train_stats.wrapped_stats.loss_field == wrapped_train_stats.loss_field == 13 + assert ( + getattr(wrapped_train_stats.wrapped_stats, "loss_field") # noqa: B009 + == getattr(wrapped_train_stats, "loss_field") # noqa: B009 + == 13 + ) From f5084caa30b949469ad2e05cfef6d75f6ac848d3 Mon Sep 17 00:00:00 2001 From: daniel <1534513+dantp-ai@users.noreply.github.com> Date: Sun, 24 Mar 2024 11:16:25 +0100 Subject: [PATCH 064/115] Remove unused var --- test/base/test_env.py | 1 - 1 file changed, 1 deletion(-) diff --git a/test/base/test_env.py b/test/base/test_env.py index d66d0799d..f4edd8ac8 100644 --- a/test/base/test_env.py +++ b/test/base/test_env.py @@ -177,7 +177,6 @@ def test_vecenv(size: int = 10, num: int = 8, sleep: float = 0.001) -> None: for v in venv: v.seed(0) action_list = [1] * 5 + [0] * 10 + [1] * 20 - o = [v.reset()[0] for v in venv] for a in action_list: o = [] for v in venv: From 2e8f3c3e25c161c2b44ff7e207b93f3397604666 Mon Sep 17 00:00:00 2001 From: daniel <1534513+dantp-ai@users.noreply.github.com> Date: Sun, 24 Mar 2024 11:17:14 +0100 Subject: [PATCH 065/115] Use SpaceInfo to determne types action/obs space --- test/discrete/test_bdq.py | 9 ++++++--- 1 file changed, 6 
insertions(+), 3 deletions(-) diff --git a/test/discrete/test_bdq.py b/test/discrete/test_bdq.py index ce6eef4a5..8f5d74e46 100644 --- a/test/discrete/test_bdq.py +++ b/test/discrete/test_bdq.py @@ -1,5 +1,6 @@ import argparse import pprint +from typing import cast import gymnasium as gym import numpy as np @@ -10,6 +11,7 @@ from tianshou.policy import BranchingDQNPolicy from tianshou.trainer import OffpolicyTrainer from tianshou.utils.net.common import BranchingNet +from tianshou.utils.space_info import SpaceInfo def get_args() -> argparse.Namespace: @@ -51,9 +53,10 @@ def get_args() -> argparse.Namespace: def test_bdq(args: argparse.Namespace = get_args()) -> None: env = gym.make(args.task) env = ContinuousToDiscrete(env, args.action_per_branch) - - args.state_shape = env.observation_space.shape or env.observation_space.n - args.num_branches = env.action_space.shape[0] + env.action_space = cast(gym.spaces.Discrete, env.action_space) + space_info = SpaceInfo.from_env(env) + args.state_shape = space_info.observation_info.obs_shape + args.num_branches = space_info.action_info.action_dim if args.reward_threshold is None: default_reward_threshold = {"Pendulum-v0": -250, "Pendulum-v1": -250} From 176ecc0f42b48c669abe5066c2c0641e2c2f36ff Mon Sep 17 00:00:00 2001 From: daniel <1534513+dantp-ai@users.noreply.github.com> Date: Sun, 24 Mar 2024 14:19:57 +0100 Subject: [PATCH 066/115] Revert "Use SpaceInfo to determne types action/obs space" This reverts commit 2e8f3c3e25c161c2b44ff7e207b93f3397604666. 
* SpaceInfo needs to support MultiDiscrete action spaces for the changes in this commit to run successfully --- test/discrete/test_bdq.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/test/discrete/test_bdq.py b/test/discrete/test_bdq.py index 8f5d74e46..ce6eef4a5 100644 --- a/test/discrete/test_bdq.py +++ b/test/discrete/test_bdq.py @@ -1,6 +1,5 @@ import argparse import pprint -from typing import cast import gymnasium as gym import numpy as np @@ -11,7 +10,6 @@ from tianshou.policy import BranchingDQNPolicy from tianshou.trainer import OffpolicyTrainer from tianshou.utils.net.common import BranchingNet -from tianshou.utils.space_info import SpaceInfo def get_args() -> argparse.Namespace: @@ -53,10 +51,9 @@ def get_args() -> argparse.Namespace: def test_bdq(args: argparse.Namespace = get_args()) -> None: env = gym.make(args.task) env = ContinuousToDiscrete(env, args.action_per_branch) - env.action_space = cast(gym.spaces.Discrete, env.action_space) - space_info = SpaceInfo.from_env(env) - args.state_shape = space_info.observation_info.obs_shape - args.num_branches = space_info.action_info.action_dim + + args.state_shape = env.observation_space.shape or env.observation_space.n + args.num_branches = env.action_space.shape[0] if args.reward_threshold is None: default_reward_threshold = {"Pendulum-v0": -250, "Pendulum-v1": -250} From 444b464f40afda394d48a5511651f560f364cdc1 Mon Sep 17 00:00:00 2001 From: daniel <1534513+dantp-ai@users.noreply.github.com> Date: Sun, 24 Mar 2024 15:00:37 +0100 Subject: [PATCH 067/115] Cover case when state is dict * I made the assumption that state has a key "hidden". Are the keys for an RNN-based state: cell, hidden? 
--- test/base/test_collector.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/base/test_collector.py b/test/base/test_collector.py index e46ea7f0d..eda8be149 100644 --- a/test/base/test_collector.py +++ b/test/base/test_collector.py @@ -64,8 +64,8 @@ def forward( state = np.zeros((len(batch.obs), 2)) elif isinstance(state, np.ndarray | BatchProtocol): state += np.int_(1) - else: - state += 1 + elif isinstance(state, dict) and state.get("hidden") is not None: + state["hidden"] += np.int_(1) if self.dict_state: if self.action_shape: action_shape = self.action_shape From 01aad1a1c033f264c4c5facdc4aef5aa770dfc3f Mon Sep 17 00:00:00 2001 From: daniel <1534513+dantp-ai@users.noreply.github.com> Date: Sun, 24 Mar 2024 15:30:10 +0100 Subject: [PATCH 068/115] Set integer default value for batch_size --- examples/mujoco/mujoco_a2c_hl.py | 2 +- examples/mujoco/mujoco_npg_hl.py | 2 +- examples/mujoco/mujoco_reinforce_hl.py | 2 +- examples/mujoco/mujoco_trpo_hl.py | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/examples/mujoco/mujoco_a2c_hl.py b/examples/mujoco/mujoco_a2c_hl.py index fec2e264f..187208757 100644 --- a/examples/mujoco/mujoco_a2c_hl.py +++ b/examples/mujoco/mujoco_a2c_hl.py @@ -30,7 +30,7 @@ def main( step_per_epoch: int = 30000, step_per_collect: int = 80, repeat_per_collect: int = 1, - batch_size: int | None = None, + batch_size: int = 16, training_num: int = 16, test_num: int = 10, rew_norm: bool = True, diff --git a/examples/mujoco/mujoco_npg_hl.py b/examples/mujoco/mujoco_npg_hl.py index 6ab0eb891..18360f779 100644 --- a/examples/mujoco/mujoco_npg_hl.py +++ b/examples/mujoco/mujoco_npg_hl.py @@ -32,7 +32,7 @@ def main( step_per_epoch: int = 30000, step_per_collect: int = 1024, repeat_per_collect: int = 1, - batch_size: int | None = None, + batch_size: int = 16, training_num: int = 16, test_num: int = 10, rew_norm: bool = True, diff --git a/examples/mujoco/mujoco_reinforce_hl.py 
b/examples/mujoco/mujoco_reinforce_hl.py index bc07e050b..5651ee1b8 100644 --- a/examples/mujoco/mujoco_reinforce_hl.py +++ b/examples/mujoco/mujoco_reinforce_hl.py @@ -29,7 +29,7 @@ def main( step_per_epoch: int = 30000, step_per_collect: int = 2048, repeat_per_collect: int = 1, - batch_size: int | None = None, + batch_size: int = 16, training_num: int = 10, test_num: int = 10, rew_norm: bool = True, diff --git a/examples/mujoco/mujoco_trpo_hl.py b/examples/mujoco/mujoco_trpo_hl.py index 2f9a77748..f113645b1 100644 --- a/examples/mujoco/mujoco_trpo_hl.py +++ b/examples/mujoco/mujoco_trpo_hl.py @@ -32,7 +32,7 @@ def main( step_per_epoch: int = 30000, step_per_collect: int = 1024, repeat_per_collect: int = 1, - batch_size: int | None = None, + batch_size: int = 16, training_num: int = 16, test_num: int = 10, rew_norm: bool = True, From 06ae70ca7fe0a1ac57e469ac90ea3783cda30398 Mon Sep 17 00:00:00 2001 From: daniel <1534513+dantp-ai@users.noreply.github.com> Date: Sun, 24 Mar 2024 19:25:58 +0100 Subject: [PATCH 069/115] Add typing to env methods and use Gym API >v0.26 (with terminated, truncated) --- examples/vizdoom/env.py | 26 ++++++++++++++++++++------ 1 file changed, 20 insertions(+), 6 deletions(-) diff --git a/examples/vizdoom/env.py b/examples/vizdoom/env.py index 04f4f0563..042d068d1 100644 --- a/examples/vizdoom/env.py +++ b/examples/vizdoom/env.py @@ -1,10 +1,12 @@ import os from collections.abc import Sequence +from typing import Any import cv2 import gymnasium as gym import numpy as np import vizdoom as vzd +from numpy.typing import NDArray from tianshou.env import ShmemVectorEnv @@ -77,7 +79,11 @@ def get_obs(self) -> None: self.obs_buffer[:-1] = self.obs_buffer[1:] self.obs_buffer[-1] = cv2.resize(obs, (self.res[-1], self.res[-2])) - def reset(self): + def reset( + self, + seed: int | None = None, + options: dict[str, Any] | None = None, + ) -> tuple[NDArray[np.uint8], dict[str, Any]]: if self.save_lmp: 
self.game.new_episode(f"lmps/episode_{self.count}.lmp") else: @@ -88,9 +94,9 @@ def reset(self): self.health = self.game.get_game_variable(vzd.GameVariable.HEALTH) self.killcount = self.game.get_game_variable(vzd.GameVariable.KILLCOUNT) self.ammo2 = self.game.get_game_variable(vzd.GameVariable.AMMO2) - return self.obs_buffer + return self.obs_buffer, {} - def step(self, action): + def step(self, action: int) -> tuple[NDArray[np.uint8], float, bool, bool, dict[str, Any]]: self.game.make_action(self.available_actions[action], self.skip) reward = 0.0 self.get_obs() @@ -112,7 +118,7 @@ def step(self, action): elif self.game.is_episode_finished(): done = True info["TimeLimit.truncated"] = True - return self.obs_buffer, reward, done, info + return self.obs_buffer, reward, done, info.get("TimeLimit.truncated", False), info def render(self) -> None: pass @@ -121,7 +127,15 @@ def close(self) -> None: self.game.close() -def make_vizdoom_env(task, frame_skip, res, save_lmp, seed, training_num, test_num): +def make_vizdoom_env( + task: str, + frame_skip: int, + res: tuple[int], + save_lmp: bool = False, + seed: int | None = None, + training_num: int = 10, + test_num: int = 10, +) -> tuple[Any | Env, Any | ShmemVectorEnv, Any | ShmemVectorEnv]: test_num = min(os.cpu_count() - 1, test_num) if envpool is not None: task_id = "".join([i.capitalize() for i in task.split("_")]) + "-v1" @@ -175,7 +189,7 @@ def make_vizdoom_env(task, frame_skip, res, save_lmp, seed, training_num, test_n env = Env("maps/D3_battle.cfg", 4, (4, 84, 84)) print(env.available_actions) action_num = env.action_space.n - obs = env.reset() + obs, _ = env.reset() if env.spec: print(env.spec.reward_threshold) print(obs.shape, action_num) From aba1d8c371cf91290e30300f1a292e39c09ccd8c Mon Sep 17 00:00:00 2001 From: daniel <1534513+dantp-ai@users.noreply.github.com> Date: Sun, 24 Mar 2024 19:34:07 +0100 Subject: [PATCH 070/115] Treat case when cpu_count() is None --- examples/vizdoom/env.py | 4 +++- 1 file changed, 3 
insertions(+), 1 deletion(-) diff --git a/examples/vizdoom/env.py b/examples/vizdoom/env.py index 042d068d1..414a5a89b 100644 --- a/examples/vizdoom/env.py +++ b/examples/vizdoom/env.py @@ -136,7 +136,9 @@ def make_vizdoom_env( training_num: int = 10, test_num: int = 10, ) -> tuple[Any | Env, Any | ShmemVectorEnv, Any | ShmemVectorEnv]: - test_num = min(os.cpu_count() - 1, test_num) + cpu_count = os.cpu_count() + if cpu_count is not None: + test_num = min(cpu_count - 1, test_num) if envpool is not None: task_id = "".join([i.capitalize() for i in task.split("_")]) + "-v1" lmp_save_dir = "lmps/" if save_lmp else "" From 92e19beebaa72bd9eade9abbbbd18e436805667e Mon Sep 17 00:00:00 2001 From: daniel <1534513+dantp-ai@users.noreply.github.com> Date: Sun, 24 Mar 2024 19:54:42 +0100 Subject: [PATCH 071/115] Fix mypy issues: * Add missing type annotations * Resolve confusion of unreachable return statement by ignoring that mypy check. If enum has more than one member than mypy won't be confused anymore --- examples/mujoco/mujoco_env.py | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/examples/mujoco/mujoco_env.py b/examples/mujoco/mujoco_env.py index dacf91548..fa972ac31 100644 --- a/examples/mujoco/mujoco_env.py +++ b/examples/mujoco/mujoco_env.py @@ -1,6 +1,8 @@ import logging import pickle +from gymnasium import Env + from tianshou.env import BaseVectorEnv, VectorEnvNormObs from tianshou.highlevel.env import ( ContinuousEnvironments, @@ -22,7 +24,13 @@ log = logging.getLogger(__name__) -def make_mujoco_env(task: str, seed: int, num_train_envs: int, num_test_envs: int, obs_norm: bool): +def make_mujoco_env( + task: str, + seed: int, + num_train_envs: int, + num_test_envs: int, + obs_norm: bool, +) -> tuple[Env, BaseVectorEnv, BaseVectorEnv]: """Wrapper function for Mujoco env. If EnvPool is installed, it will automatically switch to EnvPool's Mujoco env. 
@@ -41,16 +49,16 @@ class MujocoEnvObsRmsPersistence(Persistence): def persist(self, event: PersistEvent, world: World) -> None: if event != PersistEvent.PERSIST_POLICY: - return + return # type: ignore[unreachable] # since PersistEvent has only one member, mypy infers that line is unreachable obs_rms = world.envs.train_envs.get_obs_rms() path = world.persist_path(self.FILENAME) log.info(f"Saving environment obs_rms value to {path}") with open(path, "wb") as f: pickle.dump(obs_rms, f) - def restore(self, event: RestoreEvent, world: World): + def restore(self, event: RestoreEvent, world: World) -> None: if event != RestoreEvent.RESTORE_POLICY: - return + return # type: ignore[unreachable] path = world.restore_path(self.FILENAME) log.info(f"Restoring environment obs_rms value from {path}") with open(path, "rb") as f: From 5ae29dfc47fa4f439d90f9a98266e892a0e7d88c Mon Sep 17 00:00:00 2001 From: daniel <1534513+dantp-ai@users.noreply.github.com> Date: Mon, 25 Mar 2024 08:58:14 +0100 Subject: [PATCH 072/115] Ignore mypy typing on lines that use old Gym API: * mypy doesn't know that Gym version <0.26 has only 4 items (no truncation) --- examples/atari/atari_wrapper.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/examples/atari/atari_wrapper.py b/examples/atari/atari_wrapper.py index 2a9c0cd27..2cf5bb6b4 100644 --- a/examples/atari/atari_wrapper.py +++ b/examples/atari/atari_wrapper.py @@ -62,7 +62,7 @@ def reset(self, **kwargs: Any) -> tuple[Any, dict[str, Any]]: for _ in range(noops): step_result = self.env.step(self.noop_action) if len(step_result) == 4: - obs, rew, done, info = step_result + obs, rew, done, info = step_result # type: ignore[unreachable] # mypy doesn't know that Gym version <0.26 has only 4 items (no truncation) else: obs, rew, term, trunc, info = step_result done = term or trunc @@ -95,7 +95,7 @@ def step(self, action: Any) -> tuple[Any, float, bool, bool, dict[str, Any]]: for _ in range(self._skip): step_result = 
self.env.step(action) if len(step_result) == 4: - obs, reward, done, info = step_result + obs, reward, done, info = step_result # type: ignore[unreachable] # mypy doesn't know that Gym version <0.26 has only 4 items (no truncation) else: obs, reward, term, trunc, info = step_result done = term or trunc @@ -128,7 +128,7 @@ def __init__(self, env: gym.Env) -> None: def step(self, action: Any) -> tuple[Any, float, bool, bool, dict[str, Any]]: step_result = self.env.step(action) if len(step_result) == 4: - obs, reward, done, info = step_result + obs, reward, done, info = step_result # type: ignore[unreachable] # mypy doesn't know that Gym version <0.26 has only 4 items (no truncation) new_step_api = False else: obs, reward, term, trunc, info = step_result @@ -278,7 +278,7 @@ def reset(self, **kwargs: Any) -> tuple[npt.NDArray, dict]: def step(self, action): step_result = self.env.step(action) if len(step_result) == 4: - obs, reward, done, info = step_result + obs, reward, done, info = step_result # type: ignore[unreachable] # mypy doesn't know that Gym version <0.26 has only 4 items (no truncation) new_step_api = False else: obs, reward, term, trunc, info = step_result From 65c4aa19e15ca56ae8f9b70ce76335765af5b93f Mon Sep 17 00:00:00 2001 From: daniel <1534513+dantp-ai@users.noreply.github.com> Date: Mon, 25 Mar 2024 09:02:27 +0100 Subject: [PATCH 073/115] Make mypy happy and add type to var --- examples/atari/atari_wrapper.py | 1 + 1 file changed, 1 insertion(+) diff --git a/examples/atari/atari_wrapper.py b/examples/atari/atari_wrapper.py index 2cf5bb6b4..de10d5eb7 100644 --- a/examples/atari/atari_wrapper.py +++ b/examples/atari/atari_wrapper.py @@ -277,6 +277,7 @@ def reset(self, **kwargs: Any) -> tuple[npt.NDArray, dict]: def step(self, action): step_result = self.env.step(action) + done: bool if len(step_result) == 4: obs, reward, done, info = step_result # type: ignore[unreachable] # mypy doesn't know that Gym version <0.26 has only 4 items (no truncation) 
new_step_api = False From 34d94bc9c1dea9b451bdb02588a62af32e50c673 Mon Sep 17 00:00:00 2001 From: daniel <1534513+dantp-ai@users.noreply.github.com> Date: Mon, 25 Mar 2024 09:33:30 +0100 Subject: [PATCH 074/115] Use space_info to type env spaces --- examples/box2d/lunarlander_dqn.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/examples/box2d/lunarlander_dqn.py b/examples/box2d/lunarlander_dqn.py index 46e98cc62..778c932cf 100644 --- a/examples/box2d/lunarlander_dqn.py +++ b/examples/box2d/lunarlander_dqn.py @@ -14,6 +14,7 @@ from tianshou.trainer import OffpolicyTrainer from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Net +from tianshou.utils.space_info import SpaceInfo def get_args() -> argparse.Namespace: @@ -50,9 +51,11 @@ def get_args() -> argparse.Namespace: def test_dqn(args: argparse.Namespace = get_args()) -> None: env = gym.make(args.task) - args.state_shape = env.observation_space.shape or env.observation_space.n - args.action_shape = env.action_space.shape or env.action_space.n - args.max_action = env.action_space.high[0] + assert isinstance(env.action_space, gym.spaces.Discrete) + space_info = SpaceInfo.from_env(env) + args.state_shape = space_info.observation_info.obs_shape + args.action_shape = space_info.action_info.action_shape + args.max_action = space_info.action_info.max_action # train_envs = gym.make(args.task) # you can also use tianshou.env.SubprocVectorEnv train_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.training_num)]) From 216ba7c074ce8427cd28b172df4e5fa45348378a Mon Sep 17 00:00:00 2001 From: daniel <1534513+dantp-ai@users.noreply.github.com> Date: Mon, 25 Mar 2024 10:32:16 +0100 Subject: [PATCH 075/115] Use assert instead of cast to check for obs_space/action_space: * For situation in which algorithms demand a specific space --- examples/box2d/acrobot_dualdqn.py | 3 +-- examples/discrete/discrete_dqn.py | 4 +--- examples/offline/d4rl_cql.py | 3 +-- 
test/continuous/test_redq.py | 3 +-- test/discrete/test_c51.py | 3 +-- test/discrete/test_dqn.py | 3 +-- test/discrete/test_drqn.py | 3 +-- test/discrete/test_fqf.py | 3 +-- test/discrete/test_iqn.py | 3 +-- test/discrete/test_qrdqn.py | 3 +-- test/discrete/test_rainbow.py | 3 +-- test/discrete/test_sac.py | 3 +-- test/modelbased/test_dqn_icm.py | 3 +-- test/offline/gather_cartpole_data.py | 3 +-- test/offline/test_cql.py | 3 +-- test/offline/test_discrete_bcq.py | 3 +-- test/offline/test_discrete_cql.py | 3 +-- test/offline/test_discrete_crr.py | 3 +-- 18 files changed, 18 insertions(+), 37 deletions(-) diff --git a/examples/box2d/acrobot_dualdqn.py b/examples/box2d/acrobot_dualdqn.py index ae5223a01..f28715bd5 100644 --- a/examples/box2d/acrobot_dualdqn.py +++ b/examples/box2d/acrobot_dualdqn.py @@ -1,7 +1,6 @@ import argparse import os import pprint -from typing import cast import gymnasium as gym import numpy as np @@ -51,7 +50,7 @@ def get_args() -> argparse.Namespace: def test_dqn(args: argparse.Namespace = get_args()) -> None: env = gym.make(args.task) - env.action_space = cast(gym.spaces.Discrete, env.action_space) + assert isinstance(env.action_space, gym.spaces.Discrete) space_info = SpaceInfo.from_env(env) args.state_shape = space_info.observation_info.obs_shape args.action_shape = space_info.action_info.action_shape diff --git a/examples/discrete/discrete_dqn.py b/examples/discrete/discrete_dqn.py index c588814ae..4f1a82b12 100644 --- a/examples/discrete/discrete_dqn.py +++ b/examples/discrete/discrete_dqn.py @@ -1,5 +1,3 @@ -from typing import cast - import gymnasium as gym import torch from torch.utils.tensorboard import SummaryWriter @@ -29,7 +27,7 @@ def main() -> None: # Note: You can easily define other networks. 
# See https://tianshou.readthedocs.io/en/master/01_tutorials/00_dqn.html#build-the-network env = gym.make(task, render_mode="human") - env.action_space = cast(gym.spaces.Discrete, env.action_space) + assert isinstance(env.action_space, gym.spaces.Discrete) space_info = SpaceInfo.from_env(env) state_shape = space_info.observation_info.obs_shape action_shape = space_info.action_info.action_shape diff --git a/examples/offline/d4rl_cql.py b/examples/offline/d4rl_cql.py index 0e9fe62cd..11386d683 100644 --- a/examples/offline/d4rl_cql.py +++ b/examples/offline/d4rl_cql.py @@ -4,7 +4,6 @@ import datetime import os import pprint -from typing import cast import gymnasium as gym import numpy as np @@ -220,7 +219,7 @@ def get_args() -> argparse.Namespace: def test_cql() -> None: args = get_args() env = gym.make(args.task) - env.action_space = cast(gym.spaces.Box, env.action_space) + assert isinstance(env.action_space, gym.spaces.Box) space_info = SpaceInfo.from_env(env) args.state_shape = space_info.observation_info.obs_shape args.action_shape = space_info.action_info.action_shape diff --git a/test/continuous/test_redq.py b/test/continuous/test_redq.py index 20177f429..0899bfbbe 100644 --- a/test/continuous/test_redq.py +++ b/test/continuous/test_redq.py @@ -1,7 +1,6 @@ import argparse import os import pprint -from typing import cast import gymnasium as gym import numpy as np @@ -58,7 +57,7 @@ def get_args() -> argparse.Namespace: def test_redq(args: argparse.Namespace = get_args()) -> None: env = gym.make(args.task) - env.action_space = cast(gym.spaces.Box, env.action_space) + assert isinstance(env.action_space, gym.spaces.Box) space_info = SpaceInfo.from_env(env) args.state_shape = space_info.observation_info.obs_shape args.action_shape = space_info.action_info.action_shape diff --git a/test/discrete/test_c51.py b/test/discrete/test_c51.py index 013e2c414..e092a4a0d 100644 --- a/test/discrete/test_c51.py +++ b/test/discrete/test_c51.py @@ -2,7 +2,6 @@ import os import 
pickle import pprint -from typing import cast import gymnasium as gym import numpy as np @@ -64,7 +63,7 @@ def get_args() -> argparse.Namespace: def test_c51(args: argparse.Namespace = get_args()) -> None: env = gym.make(args.task) - env.action_space = cast(gym.spaces.Discrete, env.action_space) + assert isinstance(env.action_space, gym.spaces.Discrete) space_info = SpaceInfo.from_env(env) args.state_shape = space_info.observation_info.obs_shape args.action_shape = space_info.action_info.action_shape diff --git a/test/discrete/test_dqn.py b/test/discrete/test_dqn.py index d0aba10c4..b7563c4c6 100644 --- a/test/discrete/test_dqn.py +++ b/test/discrete/test_dqn.py @@ -1,7 +1,6 @@ import argparse import os import pprint -from typing import cast import gymnasium as gym import numpy as np @@ -58,7 +57,7 @@ def get_args() -> argparse.Namespace: def test_dqn(args: argparse.Namespace = get_args()) -> None: env = gym.make(args.task) - env.action_space = cast(gym.spaces.Discrete, env.action_space) + assert isinstance(env.action_space, gym.spaces.Discrete) space_info = SpaceInfo.from_env(env) args.state_shape = space_info.observation_info.obs_shape args.action_shape = space_info.action_info.action_shape diff --git a/test/discrete/test_drqn.py b/test/discrete/test_drqn.py index 3ed1f4fbe..8c6c049e2 100644 --- a/test/discrete/test_drqn.py +++ b/test/discrete/test_drqn.py @@ -1,7 +1,6 @@ import argparse import os import pprint -from typing import cast import gymnasium as gym import numpy as np @@ -51,7 +50,7 @@ def get_args() -> argparse.Namespace: def test_drqn(args: argparse.Namespace = get_args()) -> None: env = gym.make(args.task) - env.action_space = cast(gym.spaces.Discrete, env.action_space) + assert isinstance(env.action_space, gym.spaces.Discrete) space_info = SpaceInfo.from_env(env) args.state_shape = space_info.observation_info.obs_shape args.action_shape = space_info.action_info.action_shape diff --git a/test/discrete/test_fqf.py b/test/discrete/test_fqf.py index 
7de090119..65c319914 100644 --- a/test/discrete/test_fqf.py +++ b/test/discrete/test_fqf.py @@ -1,7 +1,6 @@ import argparse import os import pprint -from typing import cast import gymnasium as gym import numpy as np @@ -64,7 +63,7 @@ def get_args() -> argparse.Namespace: def test_fqf(args: argparse.Namespace = get_args()) -> None: env = gym.make(args.task) space_info = SpaceInfo.from_env(env) - env.action_space = cast(gym.spaces.Discrete, env.action_space) + assert isinstance(env.action_space, gym.spaces.Discrete) args.state_shape = space_info.observation_info.obs_shape args.action_shape = space_info.action_info.action_shape if args.reward_threshold is None: diff --git a/test/discrete/test_iqn.py b/test/discrete/test_iqn.py index 5ff71f515..40280fe0c 100644 --- a/test/discrete/test_iqn.py +++ b/test/discrete/test_iqn.py @@ -1,7 +1,6 @@ import argparse import os import pprint -from typing import cast import gymnasium as gym import numpy as np @@ -64,7 +63,7 @@ def get_args() -> argparse.Namespace: def test_iqn(args: argparse.Namespace = get_args()) -> None: env = gym.make(args.task) space_info = SpaceInfo.from_env(env) - env.action_space = cast(gym.spaces.Discrete, env.action_space) + assert isinstance(env.action_space, gym.spaces.Discrete) args.state_shape = space_info.observation_info.obs_shape args.action_shape = space_info.action_info.action_shape if args.reward_threshold is None: diff --git a/test/discrete/test_qrdqn.py b/test/discrete/test_qrdqn.py index c1bbcc3fa..a3abed8eb 100644 --- a/test/discrete/test_qrdqn.py +++ b/test/discrete/test_qrdqn.py @@ -1,7 +1,6 @@ import argparse import os import pprint -from typing import cast import gymnasium as gym import numpy as np @@ -55,7 +54,7 @@ def get_args() -> argparse.Namespace: def test_qrdqn(args: argparse.Namespace = get_args()) -> None: env = gym.make(args.task) - env.action_space = cast(gym.spaces.Discrete, env.action_space) + assert isinstance(env.action_space, gym.spaces.Discrete) space_info = 
SpaceInfo.from_env(env) args.state_shape = space_info.observation_info.obs_shape diff --git a/test/discrete/test_rainbow.py b/test/discrete/test_rainbow.py index fafa1e03b..27c442d67 100644 --- a/test/discrete/test_rainbow.py +++ b/test/discrete/test_rainbow.py @@ -2,7 +2,6 @@ import os import pickle import pprint -from typing import cast import gymnasium as gym import numpy as np @@ -63,7 +62,7 @@ def get_args() -> argparse.Namespace: def test_rainbow(args: argparse.Namespace = get_args()) -> None: env = gym.make(args.task) - env.action_space = cast(gym.spaces.Discrete, env.action_space) + assert isinstance(env.action_space, gym.spaces.Discrete) space_info = SpaceInfo.from_env(env) args.state_shape = space_info.observation_info.obs_shape diff --git a/test/discrete/test_sac.py b/test/discrete/test_sac.py index 9d5e27be6..8a397ef8c 100644 --- a/test/discrete/test_sac.py +++ b/test/discrete/test_sac.py @@ -1,7 +1,6 @@ import argparse import os import pprint -from typing import cast import gymnasium as gym import numpy as np @@ -54,7 +53,7 @@ def get_args() -> argparse.Namespace: def test_discrete_sac(args: argparse.Namespace = get_args()) -> None: env = gym.make(args.task) - env.action_space = cast(gym.spaces.Discrete, env.action_space) + assert isinstance(env.action_space, gym.spaces.Discrete) space_info = SpaceInfo.from_env(env) args.state_shape = space_info.observation_info.obs_shape diff --git a/test/modelbased/test_dqn_icm.py b/test/modelbased/test_dqn_icm.py index 0f957c75d..e7360effe 100644 --- a/test/modelbased/test_dqn_icm.py +++ b/test/modelbased/test_dqn_icm.py @@ -1,7 +1,6 @@ import argparse import os import pprint -from typing import cast import gymnasium as gym import numpy as np @@ -73,7 +72,7 @@ def get_args() -> argparse.Namespace: def test_dqn_icm(args: argparse.Namespace = get_args()) -> None: env = gym.make(args.task) - env.action_space = cast(gym.spaces.Discrete, env.action_space) + assert isinstance(env.action_space, gym.spaces.Discrete) 
space_info = SpaceInfo.from_env(env) args.state_shape = space_info.observation_info.obs_shape diff --git a/test/offline/gather_cartpole_data.py b/test/offline/gather_cartpole_data.py index 7f1a3128b..73b726aee 100644 --- a/test/offline/gather_cartpole_data.py +++ b/test/offline/gather_cartpole_data.py @@ -1,7 +1,6 @@ import argparse import os import pickle -from typing import cast import gymnasium as gym import numpy as np @@ -61,7 +60,7 @@ def get_args() -> argparse.Namespace: def gather_data() -> VectorReplayBuffer | PrioritizedVectorReplayBuffer: args = get_args() env = gym.make(args.task) - env.action_space = cast(gym.spaces.Discrete, env.action_space) + assert isinstance(env.action_space, gym.spaces.Discrete) space_info = SpaceInfo.from_env(env) args.state_shape = space_info.observation_info.obs_shape diff --git a/test/offline/test_cql.py b/test/offline/test_cql.py index 5ce5b406d..93eb6e5de 100644 --- a/test/offline/test_cql.py +++ b/test/offline/test_cql.py @@ -3,7 +3,6 @@ import os import pickle import pprint -from typing import cast import gymnasium as gym import numpy as np @@ -81,7 +80,7 @@ def test_cql(args: argparse.Namespace = get_args()) -> None: else: buffer = gather_data() env = gym.make(args.task) - env.action_space = cast(gym.spaces.Box, env.action_space) + assert isinstance(env.action_space, gym.spaces.Box) space_info = SpaceInfo.from_env(env) diff --git a/test/offline/test_discrete_bcq.py b/test/offline/test_discrete_bcq.py index 32e6d5696..512531e40 100644 --- a/test/offline/test_discrete_bcq.py +++ b/test/offline/test_discrete_bcq.py @@ -2,7 +2,6 @@ import os import pickle import pprint -from typing import cast import gymnasium as gym import numpy as np @@ -57,7 +56,7 @@ def get_args() -> argparse.Namespace: def test_discrete_bcq(args: argparse.Namespace = get_args()) -> None: # envs env = gym.make(args.task) - env.action_space = cast(gym.spaces.Discrete, env.action_space) + assert isinstance(env.action_space, gym.spaces.Discrete) space_info 
= SpaceInfo.from_env(env) args.state_shape = space_info.observation_info.obs_shape args.action_shape = space_info.action_info.action_shape diff --git a/test/offline/test_discrete_cql.py b/test/offline/test_discrete_cql.py index 3d8bb4c39..e0f336807 100644 --- a/test/offline/test_discrete_cql.py +++ b/test/offline/test_discrete_cql.py @@ -2,7 +2,6 @@ import os import pickle import pprint -from typing import cast import gymnasium as gym import numpy as np @@ -54,7 +53,7 @@ def get_args() -> argparse.Namespace: def test_discrete_cql(args: argparse.Namespace = get_args()) -> None: # envs env = gym.make(args.task) - env.action_space = cast(gym.spaces.Discrete, env.action_space) + assert isinstance(env.action_space, gym.spaces.Discrete) space_info = SpaceInfo.from_env(env) args.state_shape = space_info.observation_info.obs_shape args.action_shape = space_info.action_info.action_shape diff --git a/test/offline/test_discrete_crr.py b/test/offline/test_discrete_crr.py index beca5467f..6f762d746 100644 --- a/test/offline/test_discrete_crr.py +++ b/test/offline/test_discrete_crr.py @@ -2,7 +2,6 @@ import os import pickle import pprint -from typing import cast import gymnasium as gym import numpy as np @@ -52,7 +51,7 @@ def get_args() -> argparse.Namespace: def test_discrete_crr(args: argparse.Namespace = get_args()) -> None: # envs env = gym.make(args.task) - env.action_space = cast(gym.spaces.Discrete, env.action_space) + assert isinstance(env.action_space, gym.spaces.Discrete) space_info = SpaceInfo.from_env(env) args.state_shape = space_info.observation_info.obs_shape args.action_shape = space_info.action_info.action_shape From 524f0af04fbb661d142d3f5dc6214312013b2ae6 Mon Sep 17 00:00:00 2001 From: daniel <1534513+dantp-ai@users.noreply.github.com> Date: Mon, 25 Mar 2024 15:58:22 +0100 Subject: [PATCH 076/115] Assert action space --- examples/offline/atari_bcq.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/examples/offline/atari_bcq.py 
b/examples/offline/atari_bcq.py index 1f29aa5d1..63a6570b6 100644 --- a/examples/offline/atari_bcq.py +++ b/examples/offline/atari_bcq.py @@ -9,6 +9,7 @@ import numpy as np import torch +from gymnasium.spaces import Discrete from examples.atari.atari_network import DQN from examples.atari.atari_wrapper import make_atari_env @@ -82,6 +83,7 @@ def test_discrete_bcq(args: argparse.Namespace = get_args()) -> None: scale=args.scale_obs, frame_stack=args.frames_stack, ) + assert isinstance(env.action_space, Discrete) args.state_shape = env.observation_space.shape or env.observation_space.n args.action_shape = env.action_space.shape or env.action_space.n # should be N_FRAMES x H x W From fe08d6e588f11395da63fff1ba1edcefe2ff4a21 Mon Sep 17 00:00:00 2001 From: daniel <1534513+dantp-ai@users.noreply.github.com> Date: Tue, 26 Mar 2024 16:52:03 +0100 Subject: [PATCH 077/115] Use assert as mypy doesn't know that FetchReach env has compute_reward method --- examples/mujoco/fetch_her_ddpg.py | 1 + 1 file changed, 1 insertion(+) diff --git a/examples/mujoco/fetch_her_ddpg.py b/examples/mujoco/fetch_her_ddpg.py index d79d69bda..501d6c180 100644 --- a/examples/mujoco/fetch_her_ddpg.py +++ b/examples/mujoco/fetch_her_ddpg.py @@ -118,6 +118,7 @@ def test_ddpg(args: argparse.Namespace = get_args()) -> None: env, train_envs, test_envs = make_fetch_env(args.task, args.training_num, args.test_num) # The method HER works with goal-based environments assert isinstance(env.observation_space, gym.spaces.Dict) + assert hasattr(env, "compute_reward") args.state_shape = { "observation": env.observation_space["observation"].shape, "achieved_goal": env.observation_space["achieved_goal"].shape, From 7a5071cf5c61fbca0a55a56066fc2d8e6b95b113 Mon Sep 17 00:00:00 2001 From: daniel <1534513+dantp-ai@users.noreply.github.com> Date: Tue, 26 Mar 2024 16:55:04 +0100 Subject: [PATCH 078/115] Check for none before comparing mean_rewards to reward_threshold * If none, then check is False. 
Or better to have a default reward_threshold value? --- examples/box2d/bipedal_bdq.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/examples/box2d/bipedal_bdq.py b/examples/box2d/bipedal_bdq.py index f782785a5..c57baf456 100644 --- a/examples/box2d/bipedal_bdq.py +++ b/examples/box2d/bipedal_bdq.py @@ -125,7 +125,9 @@ def save_best_fn(policy: BasePolicy) -> None: torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth")) def stop_fn(mean_rewards: float) -> bool: - return mean_rewards >= getattr(env.spec.reward_threshold) + if env.spec and env.spec.reward_threshold: + return mean_rewards >= env.spec.reward_threshold + return False def train_fn(epoch: int, env_step: int) -> None: # exp decay eps = max(args.eps_train * (1 - args.eps_decay) ** env_step, args.eps_test) From c8e544860f8f98fc17cafc911f347e713aeca7c5 Mon Sep 17 00:00:00 2001 From: daniel <1534513+dantp-ai@users.noreply.github.com> Date: Tue, 26 Mar 2024 16:58:24 +0100 Subject: [PATCH 079/115] Assert action space before accessing attributes specific to that space --- test/base/test_env.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/test/base/test_env.py b/test/base/test_env.py index f4edd8ac8..f5d19feed 100644 --- a/test/base/test_env.py +++ b/test/base/test_env.py @@ -371,10 +371,12 @@ def step(self, act: Any) -> tuple[Any, Literal[-1], Literal[False], Literal[True bsz = 10 action_per_branch = [4, 6, 10, 7] env = DummyEnv() + assert isinstance(env.action_space, gym.spaces.Box) original_act = env.action_space.high # convert continous to multidiscrete action space # with different action number per dimension env_m = ContinuousToDiscrete(env, action_per_branch) + assert isinstance(env_m.action_space, gym.spaces.MultiDiscrete) # check conversion is working properly for one action np.testing.assert_allclose(env_m.action(env_m.action_space.nvec - 1), original_act) # check conversion is working properly for a batch of actions @@ -385,8 +387,12 @@ def 
step(self, act: Any) -> tuple[Any, Literal[-1], Literal[False], Literal[True # convert multidiscrete with different action number per # dimension to discrete action space env_d = MultiDiscreteToDiscrete(env_m) + assert isinstance(env_d.action_space, gym.spaces.Discrete) # check conversion is working properly for one action - np.testing.assert_allclose(env_d.action(env_d.action_space.n - 1), env_m.action_space.nvec - 1) + np.testing.assert_allclose( + env_d.action(np.array(env_d.action_space.n - 1)), + env_m.action_space.nvec - 1, + ) # check conversion is working properly for a batch of actions np.testing.assert_allclose( env_d.action(np.array([env_d.action_space.n - 1] * bsz)), From f71ad85a47a17eb459acca2c42719b7eee331654 Mon Sep 17 00:00:00 2001 From: daniel <1534513+dantp-ai@users.noreply.github.com> Date: Tue, 26 Mar 2024 17:00:22 +0100 Subject: [PATCH 080/115] Refactor way DQN API is used: * Pass c, h, w explicitly * Check that obs_space is a tuple of size 3 * Cast some scalar vars to write type --- examples/atari/atari_network.py | 48 ++++++++++++++++++++++++--------- examples/offline/atari_bcq.py | 11 +++++--- examples/offline/atari_cql.py | 9 +++++-- examples/offline/atari_crr.py | 17 +++++++++--- examples/offline/atari_il.py | 5 +++- 5 files changed, 68 insertions(+), 22 deletions(-) diff --git a/examples/atari/atari_network.py b/examples/atari/atari_network.py index d52fbc7df..0d6461f71 100644 --- a/examples/atari/atari_network.py +++ b/examples/atari/atari_network.py @@ -58,7 +58,7 @@ def __init__( c: int, h: int, w: int, - action_shape: Sequence[int], + action_shape: Sequence[int] | int, device: str | int | torch.device = "cpu", features_only: bool = False, output_dim: int | None = None, @@ -78,13 +78,14 @@ def __init__( with torch.no_grad(): self.output_dim = int(np.prod(self.net(torch.zeros(1, c, h, w)).shape[1:])) if not features_only: + action_dim = int(np.prod(action_shape)) self.net = nn.Sequential( self.net, layer_init(nn.Linear(self.output_dim, 
512)), nn.ReLU(inplace=True), - layer_init(nn.Linear(512, int(np.prod(action_shape)))), + layer_init(nn.Linear(512, action_dim)), ) - self.output_dim = np.prod(action_shape) + self.output_dim = action_dim elif output_dim is not None: self.net = nn.Sequential( self.net, @@ -122,7 +123,7 @@ def __init__( num_atoms: int = 51, device: str | int | torch.device = "cpu", ) -> None: - self.action_num = np.prod(action_shape) + self.action_num = int(np.prod(action_shape)) super().__init__(c, h, w, [self.action_num * num_atoms], device) self.num_atoms = num_atoms @@ -161,10 +162,10 @@ def __init__( is_noisy: bool = True, ) -> None: super().__init__(c, h, w, action_shape, device, features_only=True) - self.action_num = np.prod(action_shape) + self.action_num = int(np.prod(action_shape)) self.num_atoms = num_atoms - def linear(x, y): + def linear(x: int, y: int) -> NoisyLinear | nn.Linear: if is_noisy: return NoisyLinear(x, y, noisy_std) return nn.Linear(x, y) @@ -217,7 +218,7 @@ def __init__( c: int, h: int, w: int, - action_shape: Sequence[int], + action_shape: Sequence[int] | int, num_quantiles: int = 200, device: str | int | torch.device = "cpu", ) -> None: @@ -251,12 +252,25 @@ def __init__( self.features_only = features_only def create_module(self, envs: Environments, device: TDevice) -> Actor: + obs_shape = envs.get_observation_shape() + if isinstance(obs_shape, int): + obs_shape = [obs_shape] + assert len(obs_shape) == 3 + c, h, w = obs_shape + action_shape = envs.get_action_shape() + if isinstance(action_shape, np.int64): + action_shape = int(action_shape) + net: nn.Module net = DQN( - *envs.get_observation_shape(), - envs.get_action_shape(), + c, + h, + w, + action_shape, device=device, features_only=self.features_only, - output_dim=self.hidden_size, + output_dim=self.hidden_size + if isinstance(self.hidden_size, int) + else self.hidden_size[-1], layer_init=layer_init, ) if self.scale_obs: @@ -270,9 +284,19 @@ def __init__(self, features_only: bool = False, net_only: 
bool = False) -> None: self.net_only = net_only def create_intermediate_module(self, envs: Environments, device: TDevice) -> IntermediateModule: + obs_shape = envs.get_observation_shape() + if isinstance(obs_shape, int): + obs_shape = [obs_shape] + assert len(obs_shape) == 3 + c, h, w = obs_shape + action_shape = envs.get_action_shape() + if isinstance(action_shape, np.int64): + action_shape = int(action_shape) dqn = DQN( - *envs.get_observation_shape(), - envs.get_action_shape(), + c, + h, + w, + action_shape, device=device, features_only=self.features_only, ).to(device) diff --git a/examples/offline/atari_bcq.py b/examples/offline/atari_bcq.py index 63a6570b6..1fc0dc7e3 100644 --- a/examples/offline/atari_bcq.py +++ b/examples/offline/atari_bcq.py @@ -84,8 +84,8 @@ def test_discrete_bcq(args: argparse.Namespace = get_args()) -> None: frame_stack=args.frames_stack, ) assert isinstance(env.action_space, Discrete) - args.state_shape = env.observation_space.shape or env.observation_space.n - args.action_shape = env.action_space.shape or env.action_space.n + args.state_shape = env.observation_space.shape + args.action_shape = int(env.action_space.n) # should be N_FRAMES x H x W print("Observations shape:", args.state_shape) print("Actions shape:", args.action_shape) @@ -93,8 +93,13 @@ def test_discrete_bcq(args: argparse.Namespace = get_args()) -> None: np.random.seed(args.seed) torch.manual_seed(args.seed) # model + assert args.state_shape is not None + assert len(args.state_shape) == 3 + c, h, w = args.state_shape feature_net = DQN( - *args.state_shape, + c, + h, + w, args.action_shape, device=args.device, features_only=True, diff --git a/examples/offline/atari_cql.py b/examples/offline/atari_cql.py index 07acbe4c2..b853782e4 100644 --- a/examples/offline/atari_cql.py +++ b/examples/offline/atari_cql.py @@ -9,6 +9,7 @@ import numpy as np import torch +from gymnasium.spaces import Discrete from examples.atari.atari_network import QRDQN from 
examples.atari.atari_wrapper import make_atari_env @@ -80,8 +81,9 @@ def test_discrete_cql(args: argparse.Namespace = get_args()) -> None: scale=args.scale_obs, frame_stack=args.frames_stack, ) + assert isinstance(env.action_space, Discrete) args.state_shape = env.observation_space.shape or env.observation_space.n - args.action_shape = env.action_space.shape or env.action_space.n + args.action_shape = int(env.action_space.n) # should be N_FRAMES x H x W print("Observations shape:", args.state_shape) print("Actions shape:", args.action_shape) @@ -89,7 +91,10 @@ def test_discrete_cql(args: argparse.Namespace = get_args()) -> None: np.random.seed(args.seed) torch.manual_seed(args.seed) # model - net = QRDQN(*args.state_shape, args.action_shape, args.num_quantiles, args.device) + assert args.state_shape is not None + assert len(args.state_shape) == 3 + c, h, w = args.state_shape + net = QRDQN(c, h, w, args.action_shape, args.num_quantiles, args.device) optim = torch.optim.Adam(net.parameters(), lr=args.lr) # define policy policy: DiscreteCQLPolicy = DiscreteCQLPolicy( diff --git a/examples/offline/atari_crr.py b/examples/offline/atari_crr.py index 8ec57a4bd..a4b31c4fb 100644 --- a/examples/offline/atari_crr.py +++ b/examples/offline/atari_crr.py @@ -9,6 +9,7 @@ import numpy as np import torch +from gymnasium.spaces import Discrete from examples.atari.atari_network import DQN from examples.atari.atari_wrapper import make_atari_env @@ -20,6 +21,7 @@ from tianshou.trainer import OfflineTrainer from tianshou.utils.net.common import ActorCritic from tianshou.utils.net.discrete import Actor, Critic +from tianshou.utils.space_info import SpaceInfo def get_args() -> argparse.Namespace: @@ -82,8 +84,10 @@ def test_discrete_crr(args: argparse.Namespace = get_args()) -> None: scale=args.scale_obs, frame_stack=args.frames_stack, ) - args.state_shape = env.observation_space.shape or env.observation_space.n - args.action_shape = env.action_space.shape or env.action_space.n + assert 
isinstance(env.action_space, Discrete) + space_info = SpaceInfo.from_env(env) + args.state_shape = env.observation_space.shape + args.action_shape = space_info.action_info.action_shape # should be N_FRAMES x H x W print("Observations shape:", args.state_shape) print("Actions shape:", args.action_shape) @@ -91,8 +95,13 @@ def test_discrete_crr(args: argparse.Namespace = get_args()) -> None: np.random.seed(args.seed) torch.manual_seed(args.seed) # model + assert args.state_shape is not None + assert len(args.state_shape) == 3 + c, h, w = args.state_shape feature_net = DQN( - *args.state_shape, + c, + h, + w, args.action_shape, device=args.device, features_only=True, @@ -107,7 +116,7 @@ def test_discrete_crr(args: argparse.Namespace = get_args()) -> None: critic = Critic( feature_net, hidden_sizes=args.hidden_sizes, - last_size=np.prod(args.action_shape), + last_size=int(np.prod(args.action_shape)), device=args.device, ).to(args.device) actor_critic = ActorCritic(actor, critic) diff --git a/examples/offline/atari_il.py b/examples/offline/atari_il.py index a1bb62f70..51468e64e 100644 --- a/examples/offline/atari_il.py +++ b/examples/offline/atari_il.py @@ -82,7 +82,10 @@ def test_il(args: argparse.Namespace = get_args()) -> None: np.random.seed(args.seed) torch.manual_seed(args.seed) # model - net = DQN(*args.state_shape, args.action_shape, device=args.device).to(args.device) + assert args.state_shape is not None + assert len(args.state_shape) == 3 + c, h, w = args.state_shape + net = DQN(c, h, w, args.action_shape, device=args.device).to(args.device) optim = torch.optim.Adam(net.parameters(), lr=args.lr) # define policy policy: ImitationPolicy = ImitationPolicy(actor=net, optim=optim, action_space=env.action_space) From 56608dad369dd92cc597396273082687b1a76905 Mon Sep 17 00:00:00 2001 From: daniel <1534513+dantp-ai@users.noreply.github.com> Date: Tue, 26 Mar 2024 17:05:21 +0100 Subject: [PATCH 081/115] Fix FiniteVectorEnv.reset() to satisfy superclass type annotations 
--- test/base/test_env_finite.py | 30 ++++++++++++++++++------------ 1 file changed, 18 insertions(+), 12 deletions(-) diff --git a/test/base/test_env_finite.py b/test/base/test_env_finite.py index 0ecdb849a..aadef3c97 100644 --- a/test/base/test_env_finite.py +++ b/test/base/test_env_finite.py @@ -7,7 +7,6 @@ import gymnasium as gym import numpy as np -import numpy.typing as npt import torch from gymnasium.spaces import Box from torch.utils.data import DataLoader, Dataset, DistributedSampler @@ -107,30 +106,30 @@ def reset( self, id: int | list[int] | np.ndarray | None = None, **kwargs: Any, - ) -> tuple[np.ndarray, dict | list[dict | None]]: - id = self._wrap_id(id) + ) -> tuple[np.ndarray, dict | list[dict]]: + id: list[int] | np.ndarray = self._wrap_id(id) self._reset_alive_envs() # ask super to reset alive envs and remap to current index request_id = list(filter(lambda i: i in self._alive_env_ids, id)) - obs: list[npt.ArrayLike | None] = [None] * len(id) + obs_list: list[np.ndarray | None] = [None] * len(id) infos: list[dict | None] = [None] * len(id) id2idx = {i: k for k, i in enumerate(id)} if request_id: for k, o, info in zip(request_id, *super().reset(request_id), strict=True): - obs[id2idx[k]] = o + obs_list[id2idx[k]] = o infos[id2idx[k]] = info - for i, o in zip(id, obs, strict=True): + for i, o in zip(id, obs_list, strict=True): if o is None and i in self._alive_env_ids: self._alive_env_ids.remove(i) # fill empty observation with default(fake) observation - for o in obs: + for o in obs_list: self._set_default_obs(o) - for i in range(len(obs)): - if obs[i] is None: - obs[i] = self._get_default_obs() + for i in range(len(obs_list)): + if obs_list[i] is None: + obs_list[i] = self._get_default_obs() if infos[i] is None: infos[i] = self._get_default_info() @@ -138,9 +137,16 @@ def reset( self.reset() raise StopIteration - obs = [o for o in obs if o is not None] + obs_list = [o for o in obs_list if o is not None] + infos = [info for info in infos if info is 
not None] - return np.stack(obs), infos + obs: np.ndarray + try: + obs = np.stack(obs_list) + except ValueError: + obs = np.array(obs_list, dtype=object) + + return obs, infos def step( self, From 10352d4a97c19c771696eeefe69e0fb6d03972b0 Mon Sep 17 00:00:00 2001 From: daniel <1534513+dantp-ai@users.noreply.github.com> Date: Tue, 26 Mar 2024 17:11:53 +0100 Subject: [PATCH 082/115] Fix mypy issues for AtariWrappers * Use type annotations of Gym API > 0.26 ver * Rename some vars to prevent confusion from mypy * add missing type annotations * add explicit type to some vars to make mypy happy * assert that shape obs space is not None before using it --- examples/atari/atari_wrapper.py | 41 +++++++++++++++++++-------------- 1 file changed, 24 insertions(+), 17 deletions(-) diff --git a/examples/atari/atari_wrapper.py b/examples/atari/atari_wrapper.py index de10d5eb7..557ca7cd4 100644 --- a/examples/atari/atari_wrapper.py +++ b/examples/atari/atari_wrapper.py @@ -8,7 +8,6 @@ import cv2 import gymnasium as gym import numpy as np -import numpy.typing as npt from gymnasium import Env from tianshou.env import BaseVectorEnv @@ -108,7 +107,7 @@ def step(self, action: Any) -> tuple[Any, float, bool, bool, dict[str, Any]]: if new_step_api: return max_frame, total_reward, term, trunc, info - return max_frame, total_reward, done, info + return max_frame, total_reward, done, info.get("TimeLimit.truncated", False), info class EpisodicLifeEnv(gym.Wrapper): @@ -134,7 +133,7 @@ def step(self, action: Any) -> tuple[Any, float, bool, bool, dict[str, Any]]: obs, reward, term, trunc, info = step_result done = term or trunc new_step_api = True - + reward = float(reward) self.was_real_done = done # check current lives, make loss of life terminal, then update lives to # handle bonus lives @@ -149,7 +148,7 @@ def step(self, action: Any) -> tuple[Any, float, bool, bool, dict[str, Any]]: self.lives = lives if new_step_api: return obs, reward, term, trunc, info - return obs, reward, done, info + 
return obs, reward, done, info.get("TimeLimit.truncated", False), info def reset(self, **kwargs: Any) -> tuple[Any, dict[str, Any]]: """Calls the Gym environment reset, only when lives are exhausted. @@ -199,11 +198,13 @@ class WarpFrame(gym.ObservationWrapper): def __init__(self, env: gym.Env) -> None: super().__init__(env) self.size = 84 + obs_space = env.observation_space + assert isinstance(obs_space, gym.spaces.Box) self.observation_space = gym.spaces.Box( - low=np.min(env.observation_space.low), - high=np.max(env.observation_space.high), + low=np.min(obs_space.low), + high=np.max(obs_space.high), shape=(self.size, self.size), - dtype=env.observation_space.dtype, + dtype=obs_space.dtype, ) def observation(self, frame: np.ndarray) -> np.ndarray: @@ -220,14 +221,16 @@ class ScaledFloatFrame(gym.ObservationWrapper): def __init__(self, env: gym.Env) -> None: super().__init__(env) - low = np.min(env.observation_space.low) - high = np.max(env.observation_space.high) + obs_space = env.observation_space + assert isinstance(obs_space, gym.spaces.Box) + low = np.min(obs_space.low) + high = np.max(obs_space.high) self.bias = low self.scale = high - low self.observation_space = gym.spaces.Box( low=0.0, high=1.0, - shape=env.observation_space.shape, + shape=obs_space.shape, dtype=np.float32, ) @@ -261,7 +264,10 @@ def __init__(self, env: gym.Env, n_frames: int) -> None: super().__init__(env) self.n_frames: int = n_frames self.frames: deque[tuple[Any, ...]] = deque([], maxlen=n_frames) - shape = (n_frames, *env.observation_space.shape) + obs_space_shape = env.observation_space.shape + assert obs_space_shape is not None + shape = (n_frames, *obs_space_shape) + assert isinstance(env.observation_space, gym.spaces.Box) self.observation_space = gym.spaces.Box( low=np.min(env.observation_space.low), high=np.max(env.observation_space.high), @@ -269,13 +275,13 @@ def __init__(self, env: gym.Env, n_frames: int) -> None: dtype=env.observation_space.dtype, ) - def reset(self, 
**kwargs: Any) -> tuple[npt.NDArray, dict]: + def reset(self, **kwargs: Any) -> tuple[np.ndarray, dict]: obs, info, return_info = _parse_reset_result(self.env.reset(**kwargs)) for _ in range(self.n_frames): self.frames.append(obs) return (self._get_ob(), info) if return_info else (self._get_ob(), {}) - def step(self, action): + def step(self, action: Any) -> tuple[np.ndarray, float, bool, bool, dict[str, Any]]: step_result = self.env.step(action) done: bool if len(step_result) == 4: @@ -285,11 +291,12 @@ def step(self, action): obs, reward, term, trunc, info = step_result new_step_api = True self.frames.append(obs) + reward = float(reward) if new_step_api: return self._get_ob(), reward, term, trunc, info - return self._get_ob(), reward, done, info + return self._get_ob(), reward, done, info.get("TimeLimit.truncated", False), info - def _get_ob(self) -> npt.NDArray: + def _get_ob(self) -> np.ndarray: # the original wrapper use `LazyFrames` but since we use np buffer, # it has no effect return np.stack(self.frames, axis=0) @@ -379,7 +386,7 @@ def __init__( envpool_factory = None if use_envpool_if_available: if envpool_is_available: - envpool_factory = self.EnvPoolFactory(self) + envpool_factory = self.EnvPoolFactoryAtari(self) log.info("Using envpool, because it available") else: log.info("Not using envpool, because it is not available") @@ -401,7 +408,7 @@ def create_env(self, mode: EnvMode) -> gym.Env: scale=self.scale, ) - class EnvPoolFactory(EnvPoolFactory): + class EnvPoolFactoryAtari(EnvPoolFactory): """Atari-specific envpool creation. Since envpool internally handles the functions that are implemented through the wrappers in `wrap_deepmind`, it sets the creation keyword arguments accordingly. 
From c21586a9beb06b147a6d9df7e611e9630bff78ae Mon Sep 17 00:00:00 2001 From: daniel <1534513+dantp-ai@users.noreply.github.com> Date: Wed, 27 Mar 2024 14:02:07 +0100 Subject: [PATCH 083/115] Make mypy happy and use typing for obs_space_dtype * Should we move this code in a utils? or maybe SpaceInfo? --- examples/atari/atari_wrapper.py | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/examples/atari/atari_wrapper.py b/examples/atari/atari_wrapper.py index 557ca7cd4..db1b6b2c3 100644 --- a/examples/atari/atari_wrapper.py +++ b/examples/atari/atari_wrapper.py @@ -199,12 +199,17 @@ def __init__(self, env: gym.Env) -> None: super().__init__(env) self.size = 84 obs_space = env.observation_space + obs_space_dtype: type[np.floating[Any]] | type[np.integer[Any]] + if np.issubdtype(type(obs_space.dtype), np.integer): + obs_space_dtype = np.integer + elif np.issubdtype(type(obs_space.dtype), np.floating): + obs_space_dtype = np.floating assert isinstance(obs_space, gym.spaces.Box) self.observation_space = gym.spaces.Box( low=np.min(obs_space.low), high=np.max(obs_space.high), shape=(self.size, self.size), - dtype=obs_space.dtype, + dtype=obs_space_dtype, ) def observation(self, frame: np.ndarray) -> np.ndarray: @@ -264,15 +269,21 @@ def __init__(self, env: gym.Env, n_frames: int) -> None: super().__init__(env) self.n_frames: int = n_frames self.frames: deque[tuple[Any, ...]] = deque([], maxlen=n_frames) + obs_space = env.observation_space obs_space_shape = env.observation_space.shape assert obs_space_shape is not None shape = (n_frames, *obs_space_shape) assert isinstance(env.observation_space, gym.spaces.Box) + obs_space_dtype: type[np.floating[Any]] | type[np.integer[Any]] + if np.issubdtype(type(obs_space.dtype), np.integer): + obs_space_dtype = np.integer + elif np.issubdtype(type(obs_space.dtype), np.floating): + obs_space_dtype = np.floating self.observation_space = gym.spaces.Box( low=np.min(env.observation_space.low), 
high=np.max(env.observation_space.high), shape=shape, - dtype=env.observation_space.dtype, + dtype=obs_space_dtype, ) def reset(self, **kwargs: Any) -> tuple[np.ndarray, dict]: From 75bb1f04808760b7d29ba05ac7af0451eea088d8 Mon Sep 17 00:00:00 2001 From: daniel <1534513+dantp-ai@users.noreply.github.com> Date: Wed, 27 Mar 2024 14:09:01 +0100 Subject: [PATCH 084/115] Assert env.action_space = MultiDiscrete: * For MultiDiscrete, shape can never be None (cf. `shape()` in `MultiDiscrete`) --- test/discrete/test_bdq.py | 1 + 1 file changed, 1 insertion(+) diff --git a/test/discrete/test_bdq.py b/test/discrete/test_bdq.py index ce6eef4a5..3fe3a7075 100644 --- a/test/discrete/test_bdq.py +++ b/test/discrete/test_bdq.py @@ -53,6 +53,7 @@ def test_bdq(args: argparse.Namespace = get_args()) -> None: env = ContinuousToDiscrete(env, args.action_per_branch) args.state_shape = env.observation_space.shape or env.observation_space.n + assert isinstance(env.action_space, gym.spaces.MultiDiscrete) args.num_branches = env.action_space.shape[0] if args.reward_threshold is None: From a0d1427b0170af8aef29d8c6a1ff93cbc3876fb5 Mon Sep 17 00:00:00 2001 From: daniel <1534513+dantp-ai@users.noreply.github.com> Date: Wed, 27 Mar 2024 15:52:47 +0100 Subject: [PATCH 085/115] Extend IndexType to explicitly have list[int]: * mypy can't infer that this type is also permitted and so it complains that Collector._reset_state's id is not compatible when used inside the function (e.g. state.empty_(id) when state is Batch) --- tianshou/data/batch.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/tianshou/data/batch.py b/tianshou/data/batch.py index 508e5c9a2..7a6bb0045 100644 --- a/tianshou/data/batch.py +++ b/tianshou/data/batch.py @@ -19,7 +19,13 @@ import torch _SingleIndexType = slice | int | EllipsisType -IndexType = np.ndarray | _SingleIndexType | list[_SingleIndexType] | tuple[_SingleIndexType, ...] 
+IndexType = ( + np.ndarray + | _SingleIndexType + | list[_SingleIndexType] + | tuple[_SingleIndexType, ...] + | list[int] +) TBatch = TypeVar("TBatch", bound="BatchProtocol") arr_type = torch.Tensor | np.ndarray From 66d92af137709f8a656a98ac84ddf0274e93c241 Mon Sep 17 00:00:00 2001 From: daniel <1534513+dantp-ai@users.noreply.github.com> Date: Wed, 27 Mar 2024 17:32:29 +0100 Subject: [PATCH 086/115] Use SpaceInfo to type env obs/action space: * Unpack c,h,w. star notation is confusing mypy --- examples/offline/atari_cql.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/examples/offline/atari_cql.py b/examples/offline/atari_cql.py index b853782e4..d1f9bc410 100644 --- a/examples/offline/atari_cql.py +++ b/examples/offline/atari_cql.py @@ -19,6 +19,7 @@ from tianshou.policy import DiscreteCQLPolicy from tianshou.policy.base import BasePolicy from tianshou.trainer import OfflineTrainer +from tianshou.utils.space_info import SpaceInfo def get_args() -> argparse.Namespace: @@ -82,8 +83,12 @@ def test_discrete_cql(args: argparse.Namespace = get_args()) -> None: frame_stack=args.frames_stack, ) assert isinstance(env.action_space, Discrete) - args.state_shape = env.observation_space.shape or env.observation_space.n - args.action_shape = int(env.action_space.n) + space_info = SpaceInfo.from_env(env) + args.state_shape = space_info.observation_info.obs_shape + args.action_shape = space_info.action_info.action_shape + assert isinstance(args.state_shape, list[int] | tuple[int]) + assert len(args.state_shape) == 3 + c, h, w = args.state_shape # should be N_FRAMES x H x W print("Observations shape:", args.state_shape) print("Actions shape:", args.action_shape) @@ -91,9 +96,6 @@ def test_discrete_cql(args: argparse.Namespace = get_args()) -> None: np.random.seed(args.seed) torch.manual_seed(args.seed) # model - assert args.state_shape is not None - assert len(args.state_shape) == 3 - c, h, w = args.state_shape net = QRDQN(c, h, w, 
args.action_shape, args.num_quantiles, args.device) optim = torch.optim.Adam(net.parameters(), lr=args.lr) # define policy From d0c8745b566fbe4a3dd1854224381f66f49fc9da Mon Sep 17 00:00:00 2001 From: daniel <1534513+dantp-ai@users.noreply.github.com> Date: Wed, 27 Mar 2024 17:47:04 +0100 Subject: [PATCH 087/115] Use SpaceInfo to type env obs/action space: * Unpack c,h,w. star notation is confusing mypy --- examples/offline/atari_il.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/examples/offline/atari_il.py b/examples/offline/atari_il.py index 51468e64e..bb7822ea9 100644 --- a/examples/offline/atari_il.py +++ b/examples/offline/atari_il.py @@ -18,6 +18,7 @@ from tianshou.policy import ImitationPolicy from tianshou.policy.base import BasePolicy from tianshou.trainer import OfflineTrainer +from tianshou.utils.space_info import SpaceInfo def get_args() -> argparse.Namespace: @@ -73,8 +74,12 @@ def test_il(args: argparse.Namespace = get_args()) -> None: scale=args.scale_obs, frame_stack=args.frames_stack, ) - args.state_shape = env.observation_space.shape or env.observation_space.n - args.action_shape = env.action_space.shape or env.action_space.n + space_info = SpaceInfo.from_env(env) + args.state_shape = space_info.observation_info.obs_shape + args.action_shape = space_info.action_info.action_shape + assert isinstance(args.state_shape, list[int] | tuple[int]) + assert len(args.state_shape) == 3 + c, h, w = args.state_shape # should be N_FRAMES x H x W print("Observations shape:", args.state_shape) print("Actions shape:", args.action_shape) @@ -82,9 +87,6 @@ def test_il(args: argparse.Namespace = get_args()) -> None: np.random.seed(args.seed) torch.manual_seed(args.seed) # model - assert args.state_shape is not None - assert len(args.state_shape) == 3 - c, h, w = args.state_shape net = DQN(c, h, w, args.action_shape, device=args.device).to(args.device) optim = torch.optim.Adam(net.parameters(), lr=args.lr) # define policy From 
d4b5d2388de4f6e0b200d188b4f866466327c13a Mon Sep 17 00:00:00 2001 From: daniel <1534513+dantp-ai@users.noreply.github.com> Date: Wed, 27 Mar 2024 20:37:30 +0100 Subject: [PATCH 088/115] Assert action space before accessing space-specific attrs --- examples/vizdoom/env.py | 1 + 1 file changed, 1 insertion(+) diff --git a/examples/vizdoom/env.py b/examples/vizdoom/env.py index 414a5a89b..99c5b6f3e 100644 --- a/examples/vizdoom/env.py +++ b/examples/vizdoom/env.py @@ -190,6 +190,7 @@ def make_vizdoom_env( # env = Env("maps/D1_basic.cfg", 4, (4, 84, 84)) env = Env("maps/D3_battle.cfg", 4, (4, 84, 84)) print(env.available_actions) + assert isinstance(env.action_space, gym.spaces.Discrete) action_num = env.action_space.n obs, _ = env.reset() if env.spec: From 2255f6e7c72a6a446ef8d0c2d55c3c79e5ccc292 Mon Sep 17 00:00:00 2001 From: daniel <1534513+dantp-ai@users.noreply.github.com> Date: Wed, 27 Mar 2024 21:19:58 +0100 Subject: [PATCH 089/115] Add type hints to obs/action spaces: * Should `BranchingPolicy` support also `MultiDiscrete` action spaces? 
--- examples/box2d/bipedal_bdq.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/examples/box2d/bipedal_bdq.py b/examples/box2d/bipedal_bdq.py index c57baf456..d0ed2ddc3 100644 --- a/examples/box2d/bipedal_bdq.py +++ b/examples/box2d/bipedal_bdq.py @@ -57,11 +57,14 @@ def test_bdq(args: argparse.Namespace = get_args()) -> None: env = gym.make(args.task) env = ContinuousToDiscrete(env, args.action_per_branch) - args.state_shape = env.observation_space.shape or env.observation_space.n - args.action_shape = env.action_space.shape or env.action_space.n - args.num_branches = ( - args.action_shape if isinstance(args.action_shape, int) else args.action_shape[0] - ) + assert isinstance(env.action_space, gym.spaces.MultiDiscrete) + assert isinstance( + env.observation_space, + gym.spaces.Box, + ) # BipedalWalker-v3 has `Box` observation space by design + args.state_shape = env.observation_space.shape + args.action_shape = env.action_space.shape + args.num_branches = args.action_shape[0] print("Observations shape:", args.state_shape) print("Num branches:", args.num_branches) @@ -102,7 +105,7 @@ def test_bdq(args: argparse.Namespace = get_args()) -> None: model=net, optim=optim, discount_factor=args.gamma, - action_space=env.action_space, + action_space=env.action_space, # type: ignore[arg-type] # TODO: should BranchingPolicy support also `MultiDiscrete` action spaces? target_update_freq=args.target_update_freq, ) # collector From 8c5fa778a9d76d4b3b2b8aac3d006c2be6933460 Mon Sep 17 00:00:00 2001 From: daniel <1534513+dantp-ai@users.noreply.github.com> Date: Wed, 27 Mar 2024 21:22:29 +0100 Subject: [PATCH 090/115] Add type hints to action/obs spaces: * Should `BranchingPolicy` support also `MultiDiscrete` action spaces? 
--- test/discrete/test_bdq.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/test/discrete/test_bdq.py b/test/discrete/test_bdq.py index 3fe3a7075..e7abe0b8d 100644 --- a/test/discrete/test_bdq.py +++ b/test/discrete/test_bdq.py @@ -52,7 +52,10 @@ def test_bdq(args: argparse.Namespace = get_args()) -> None: env = gym.make(args.task) env = ContinuousToDiscrete(env, args.action_per_branch) - args.state_shape = env.observation_space.shape or env.observation_space.n + if isinstance(env.observation_space, gym.spaces.Box): + args.state_shape = env.observation_space.shape + elif isinstance(env.observation_space, gym.spaces.Discrete): + args.state_shape = int(env.observation_space.n) assert isinstance(env.action_space, gym.spaces.MultiDiscrete) args.num_branches = env.action_space.shape[0] @@ -100,7 +103,7 @@ def test_bdq(args: argparse.Namespace = get_args()) -> None: model=net, optim=optim, discount_factor=args.gamma, - action_space=env.action_space, + action_space=env.action_space, # type: ignore[arg-type] # TODO: should BranchingPolicy support also `MultiDiscrete` action spaces? target_update_freq=args.target_update_freq, ) # collector From 069a2e620104fe0f338afcb158f56bc8fde50900 Mon Sep 17 00:00:00 2001 From: daniel <1534513+dantp-ai@users.noreply.github.com> Date: Thu, 28 Mar 2024 09:35:38 +0100 Subject: [PATCH 091/115] Minor reformat comments --- examples/box2d/bipedal_bdq.py | 2 +- test/discrete/test_bdq.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/box2d/bipedal_bdq.py b/examples/box2d/bipedal_bdq.py index d0ed2ddc3..8be3bac00 100644 --- a/examples/box2d/bipedal_bdq.py +++ b/examples/box2d/bipedal_bdq.py @@ -105,7 +105,7 @@ def test_bdq(args: argparse.Namespace = get_args()) -> None: model=net, optim=optim, discount_factor=args.gamma, - action_space=env.action_space, # type: ignore[arg-type] # TODO: should BranchingPolicy support also `MultiDiscrete` action spaces? 
+ action_space=env.action_space, # type: ignore[arg-type] # TODO: should `BranchingDQNPolicy` support also `MultiDiscrete` action spaces? target_update_freq=args.target_update_freq, ) # collector diff --git a/test/discrete/test_bdq.py b/test/discrete/test_bdq.py index e7abe0b8d..55434ceb5 100644 --- a/test/discrete/test_bdq.py +++ b/test/discrete/test_bdq.py @@ -103,7 +103,7 @@ def test_bdq(args: argparse.Namespace = get_args()) -> None: model=net, optim=optim, discount_factor=args.gamma, - action_space=env.action_space, # type: ignore[arg-type] # TODO: should BranchingPolicy support also `MultiDiscrete` action spaces? + action_space=env.action_space, # type: ignore[arg-type] # TODO: should `BranchingDQNPolicy` support also `MultiDiscrete` action spaces? target_update_freq=args.target_update_freq, ) # collector From dab400f04c167f1636292d30e5ddc89d4e89e1e5 Mon Sep 17 00:00:00 2001 From: Michael Panchenko Date: Mon, 1 Apr 2024 17:43:05 +0200 Subject: [PATCH 092/115] Post-resolving conflicts in tests and examples --- examples/atari/atari_ppo.py | 13 +- examples/vizdoom/vizdoom_ppo.py | 13 +- test/base/test_collector.py | 317 ++++++++++++++++++-------------- test/base/test_env.py | 18 +- test/base/test_env_finite.py | 18 +- test/base/test_policy.py | 5 +- 6 files changed, 212 insertions(+), 172 deletions(-) diff --git a/examples/atari/atari_ppo.py b/examples/atari/atari_ppo.py index cd94eca18..461926f90 100644 --- a/examples/atari/atari_ppo.py +++ b/examples/atari/atari_ppo.py @@ -8,7 +8,7 @@ import torch from atari_network import DQN, layer_init, scale_obs from atari_wrapper import make_atari_env -from torch.distributions import Categorical, Distribution +from torch.distributions import Categorical from torch.optim.lr_scheduler import LambdaLR from tianshou.data import Collector, VectorReplayBuffer @@ -131,16 +131,11 @@ def test_ppo(args: argparse.Namespace = get_args()) -> None: lr_scheduler = LambdaLR(optim, lr_lambda=lambda epoch: 1 - epoch / max_update_num) - # 
define policy - def dist(logits: torch.Tensor) -> Distribution: - return Categorical(logits=logits) - - policy: PPOPolicy | ICMPolicy - policy = PPOPolicy( + policy: PPOPolicy = PPOPolicy( actor=actor, critic=critic, optim=optim, - dist_fn=dist, + dist_fn=Categorical, discount_factor=args.gamma, gae_lambda=args.gae_lambda, max_grad_norm=args.max_grad_norm, @@ -168,7 +163,7 @@ def dist(logits: torch.Tensor) -> Distribution: device=args.device, ) icm_optim = torch.optim.Adam(icm_net.parameters(), lr=args.lr) - policy = ICMPolicy( + policy: ICMPolicy = ICMPolicy( # type: ignore[no-redef] policy=policy, model=icm_net, optim=icm_optim, diff --git a/examples/vizdoom/vizdoom_ppo.py b/examples/vizdoom/vizdoom_ppo.py index 010bb28ec..7476d4f26 100644 --- a/examples/vizdoom/vizdoom_ppo.py +++ b/examples/vizdoom/vizdoom_ppo.py @@ -8,7 +8,7 @@ import torch from env import make_vizdoom_env from network import DQN -from torch.distributions import Categorical, Distribution +from torch.distributions import Categorical from torch.optim.lr_scheduler import LambdaLR from tianshou.data import Collector, VectorReplayBuffer @@ -136,16 +136,11 @@ def test_ppo(args: argparse.Namespace = get_args()) -> None: lr_scheduler = LambdaLR(optim, lr_lambda=lambda epoch: 1 - epoch / max_update_num) - # define policy - def dist(logits: torch.Tensor) -> Distribution: - return Categorical(logits=logits) - - policy: PPOPolicy | ICMPolicy - policy = PPOPolicy( + policy: PPOPolicy = PPOPolicy( actor=actor, critic=critic, optim=optim, - dist_fn=dist, + dist_fn=Categorical, discount_factor=args.gamma, gae_lambda=args.gae_lambda, max_grad_norm=args.max_grad_norm, @@ -178,7 +173,7 @@ def dist(logits: torch.Tensor) -> Distribution: device=args.device, ) icm_optim = torch.optim.Adam(icm_net.parameters(), lr=args.lr) - policy = ICMPolicy( + policy: ICMPolicy = ICMPolicy( # type: ignore[no-redef] policy=policy, model=icm_net, optim=icm_optim, diff --git a/test/base/test_collector.py b/test/base/test_collector.py 
index 6bc1703f6..d48294212 100644 --- a/test/base/test_collector.py +++ b/test/base/test_collector.py @@ -1,3 +1,6 @@ +from collections.abc import Callable, Sequence +from typing import Any + import gymnasium as gym import numpy as np import pytest @@ -12,8 +15,10 @@ ReplayBuffer, VectorReplayBuffer, ) +from tianshou.data.batch import BatchProtocol +from tianshou.data.types import ObsBatchProtocol, RolloutBatchProtocol from tianshou.env import DummyVectorEnv, SubprocVectorEnv -from tianshou.policy import BasePolicy +from tianshou.policy import BasePolicy, TrainingStats try: import envpool @@ -30,9 +35,9 @@ class MaxActionPolicy(BasePolicy): def __init__( self, action_space: gym.spaces.Space | None = None, - dict_state=False, - need_state=True, - action_shape=None, + dict_state: bool = False, + need_state: bool = True, + action_shape: Sequence[int] | int | None = None, ) -> None: """Mock policy for testing, will always return an array of ones of the shape of the action space. Note that this doesn't make much sense for discrete action space (the output is then intepreted as @@ -48,20 +53,32 @@ def __init__( self.need_state = need_state self.action_shape = action_shape - def forward(self, batch, state=None): + def forward( + self, + batch: ObsBatchProtocol, + state: dict | BatchProtocol | np.ndarray | None = None, + **kwargs: Any, + ) -> Batch: if self.need_state: if state is None: state = np.zeros((len(batch.obs), 2)) - else: - state += 1 + elif isinstance(state, np.ndarray | BatchProtocol): + state += np.int_(1) + elif isinstance(state, dict) and state.get("hidden") is not None: + state["hidden"] += np.int_(1) if self.dict_state: - action_shape = self.action_shape if self.action_shape else len(batch.obs["index"]) + if self.action_shape: + action_shape = self.action_shape + elif isinstance(batch.obs, BatchProtocol): + action_shape = len(batch.obs["index"]) + else: + action_shape = len(batch.obs) return Batch(act=np.ones(action_shape), state=state) action_shape = 
self.action_shape if self.action_shape else len(batch.obs) return Batch(act=np.ones(action_shape), state=state) - def learn(self): - pass + def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> TrainingStats: + raise NotImplementedError def test_collector() -> None: @@ -90,7 +107,9 @@ def test_collector() -> None: # Making one more step results in obs_next=1 # The final 0 in the buffer.obs is because the buffer is initialized with zeros and the direct attr access assert np.allclose(c_single_env.buffer.obs[:4, 0], [0, 1, 0, 0]) - assert np.allclose(c_single_env.buffer[:].obs_next[..., 0], [1, 2, 1]) + obs_next = c_single_env.buffer[:].obs_next[..., 0] + assert isinstance(obs_next, np.ndarray) + assert np.allclose(obs_next, [1, 2, 1]) keys = np.zeros(100) keys[:3] = 1 assert np.allclose(c_single_env.buffer.info["key"], keys) @@ -110,7 +129,9 @@ def test_collector() -> None: c_single_env.collect(n_episode=3) assert len(c_single_env.buffer) == 8 assert np.allclose(c_single_env.buffer.obs[:10, 0], [0, 1, 0, 1, 0, 1, 0, 1, 0, 0]) - assert np.allclose(c_single_env.buffer[:].obs_next[..., 0], [1, 2, 1, 2, 1, 2, 1, 2]) + obs_next = c_single_env.buffer[:].obs_next[..., 0] + assert isinstance(obs_next, np.ndarray) + assert np.allclose(obs_next, [1, 2, 1, 2, 1, 2, 1, 2]) assert np.allclose(c_single_env.buffer.info["key"][:8], 1) for e in c_single_env.buffer.info["env"][:8]: assert isinstance(e, MoveToRightEnv) @@ -131,7 +152,9 @@ def test_collector() -> None: valid_indices = [0, 1, 25, 26, 50, 51, 75, 76] obs[valid_indices] = [0, 1, 0, 1, 0, 1, 0, 1] assert np.allclose(c_subproc_venv_4_envs.buffer.obs[:, 0], obs) - assert np.allclose(c_subproc_venv_4_envs.buffer[:].obs_next[..., 0], [1, 2, 1, 2, 1, 2, 1, 2]) + obs_next = c_subproc_venv_4_envs.buffer[:].obs_next[..., 0] + assert isinstance(obs_next, np.ndarray) + assert np.allclose(obs_next, [1, 2, 1, 2, 1, 2, 1, 2]) keys = np.zeros(100) keys[valid_indices] = [1, 1, 1, 1, 1, 1, 1, 1] assert 
np.allclose(c_subproc_venv_4_envs.buffer.info["key"], keys) @@ -153,8 +176,10 @@ def test_collector() -> None: valid_indices = [2, 3, 27, 52, 53, 77, 78, 79] obs[valid_indices] = [0, 1, 2, 2, 3, 2, 3, 4] assert np.allclose(c_subproc_venv_4_envs.buffer.obs[:, 0], obs) + obs_next = c_subproc_venv_4_envs.buffer[:].obs_next[..., 0] + assert isinstance(obs_next, np.ndarray) assert np.allclose( - c_subproc_venv_4_envs.buffer[:].obs_next[..., 0], + obs_next, [1, 2, 1, 2, 1, 2, 3, 1, 2, 3, 4, 1, 2, 3, 4, 5], ) keys[valid_indices] = [1, 1, 1, 1, 1, 1, 1, 1] @@ -204,9 +229,12 @@ def test_collector() -> None: with pytest.raises(TypeError): c_dummy_venv_4_envs.collect() + def get_env_factory(i: int, t: str) -> Callable[[], NXEnv]: + return lambda: NXEnv(i, t) + # test NXEnv for obs_type in ["array", "object"]: - envs = SubprocVectorEnv([lambda i=x, t=obs_type: NXEnv(i, t) for x in [5, 10, 15, 20]]) + envs = SubprocVectorEnv([get_env_factory(i=i, t=obs_type) for i in [5, 10, 15, 20]]) c_suproc_new = Collector(policy, envs, VectorReplayBuffer(total_size=100, buffer_num=4)) c_suproc_new.reset() c_suproc_new.collect(n_step=6) @@ -214,46 +242,55 @@ def test_collector() -> None: @pytest.fixture() -def get_AsyncCollector(): +def async_collector_and_env_lens() -> tuple[AsyncCollector, list[int]]: env_lens = [2, 3, 4, 5] env_fns = [lambda x=i: MoveToRightEnv(size=x, sleep=0.001, random_sleep=True) for i in env_lens] venv = SubprocVectorEnv(env_fns, wait_num=len(env_fns) - 1) policy = MaxActionPolicy() bufsize = 60 - c1 = AsyncCollector( + async_collector = AsyncCollector( policy, venv, VectorReplayBuffer(total_size=bufsize * 4, buffer_num=4), ) - c1.reset() - return c1, env_lens + async_collector.reset() + return async_collector, env_lens class TestAsyncCollector: - def test_collect_without_argument_gives_error(self, get_AsyncCollector): - c1, env_lens = get_AsyncCollector + def test_collect_without_argument_gives_error( + self, + async_collector_and_env_lens: tuple[AsyncCollector, 
list[int]], + ) -> None: + c1, env_lens = async_collector_and_env_lens with pytest.raises(TypeError): c1.collect() - def test_collect_one_episode_async(self, get_AsyncCollector): - c1, env_lens = get_AsyncCollector + def test_collect_one_episode_async( + self, + async_collector_and_env_lens: tuple[AsyncCollector, list[int]], + ) -> None: + c1, env_lens = async_collector_and_env_lens result = c1.collect(n_episode=1) assert result.n_collected_episodes >= 1 def test_enough_episodes_two_collection_cycles_n_episode_without_reset( self, - get_AsyncCollector, - ): - c1, env_lens = get_AsyncCollector + async_collector_and_env_lens: tuple[AsyncCollector, list[int]], + ) -> None: + c1, env_lens = async_collector_and_env_lens n_episode = 2 result_c1 = c1.collect(n_episode=n_episode, reset_before_collect=False) assert result_c1.n_collected_episodes >= n_episode result_c2 = c1.collect(n_episode=n_episode, reset_before_collect=False) assert result_c2.n_collected_episodes >= n_episode - def test_enough_episodes_two_collection_cycles_n_episode_with_reset(self, get_AsyncCollector): - c1, env_lens = get_AsyncCollector + def test_enough_episodes_two_collection_cycles_n_episode_with_reset( + self, + async_collector_and_env_lens: tuple[AsyncCollector, list[int]], + ) -> None: + c1, env_lens = async_collector_and_env_lens n_episode = 2 result_c1 = c1.collect(n_episode=n_episode, reset_before_collect=True) assert result_c1.n_collected_episodes >= n_episode @@ -262,9 +299,9 @@ def test_enough_episodes_two_collection_cycles_n_episode_with_reset(self, get_As def test_enough_episodes_and_correct_obs_indices_and_obs_next_iterative_collection_cycles_n_episode( self, - get_AsyncCollector, - ): - c1, env_lens = get_AsyncCollector + async_collector_and_env_lens: tuple[AsyncCollector, list[int]], + ) -> None: + c1, env_lens = async_collector_and_env_lens ptr = [0, 0, 0, 0] bufsize = 60 for n_episode in tqdm.trange(1, 30, desc="test async n_episode"): @@ -284,9 +321,9 @@ def 
test_enough_episodes_and_correct_obs_indices_and_obs_next_iterative_collecti def test_enough_episodes_and_correct_obs_indices_and_obs_next_iterative_collection_cycles_n_step( self, - get_AsyncCollector, - ): - c1, env_lens = get_AsyncCollector + async_collector_and_env_lens: tuple[AsyncCollector, list[int]], + ) -> None: + c1, env_lens = async_collector_and_env_lens bufsize = 60 ptr = [0, 0, 0, 0] for n_step in tqdm.trange(1, 15, desc="test async n_step"): @@ -303,17 +340,15 @@ def test_enough_episodes_and_correct_obs_indices_and_obs_next_iterative_collecti assert np.all(buf.obs[indices].reshape(count, env_len) == seq) assert np.all(buf.obs_next[indices].reshape(count, env_len) == seq + 1) - @pytest.mark.parametrize("gym_reset_kwargs", [None, {}]) def test_enough_episodes_and_correct_obs_indices_and_obs_next_iterative_collection_cycles_first_n_episode_then_n_step( self, - get_AsyncCollector, - gym_reset_kwargs, - ): - c1, env_lens = get_AsyncCollector + async_collector_and_env_lens: tuple[AsyncCollector, list[int]], + ) -> None: + c1, env_lens = async_collector_and_env_lens bufsize = 60 ptr = [0, 0, 0, 0] for n_episode in tqdm.trange(1, 30, desc="test async n_episode"): - result = c1.collect(n_episode=n_episode, gym_reset_kwargs=gym_reset_kwargs) + result = c1.collect(n_episode=n_episode) assert result.n_collected_episodes >= n_episode # check buffer data, obs and obs_next, env_id for i, count in enumerate(np.bincount(result.lens, minlength=6)[2:]): @@ -328,7 +363,7 @@ def test_enough_episodes_and_correct_obs_indices_and_obs_next_iterative_collecti assert np.all(buf.obs_next[indices].reshape(count, env_len) == seq + 1) # test async n_step, for now the buffer should be full of data, thus no bincount stuff as above for n_step in tqdm.trange(1, 15, desc="test async n_step"): - result = c1.collect(n_step=n_step, gym_reset_kwargs=gym_reset_kwargs) + result = c1.collect(n_step=n_step) assert result.n_collected_steps >= n_step for i in range(4): env_len = i + 2 @@ -371,9 
+406,11 @@ def test_collector_with_dict_state() -> None: batch, _ = c1.buffer.sample(10) c0.buffer.update(c1.buffer) assert len(c0.buffer) in [42, 43] + cur_obs = c0.buffer[:].obs + assert isinstance(cur_obs, Batch) if len(c0.buffer) == 42: assert np.all( - c0.buffer[:].obs.index[..., 0] + cur_obs.index[..., 0] == [ 0, 1, @@ -418,10 +455,10 @@ def test_collector_with_dict_state() -> None: 3, 4, ], - ), c0.buffer[:].obs.index[..., 0] + ), cur_obs.index[..., 0] else: assert np.all( - c0.buffer[:].obs.index[..., 0] + cur_obs.index[..., 0] == [ 0, 1, @@ -467,7 +504,7 @@ def test_collector_with_dict_state() -> None: 3, 4, ], - ), c0.buffer[:].obs.index[..., 0] + ), cur_obs.index[..., 0] c2 = Collector( policy, envs, @@ -512,96 +549,100 @@ def test_collector_with_multi_agent() -> None: c_single_env.buffer.update(c_multi_env_ma.buffer) assert len(c_single_env.buffer) in [42, 43] if len(c_single_env.buffer) == 42: - multi_env_returns = [ - 0, - 0, - 0, - 0, - 1, - 0, - 0, - 0, - 0, - 1, - 0, - 1, - 0, - 1, - 0, - 1, - 0, - 1, - 0, - 0, - 1, - 0, - 0, - 1, - 0, - 0, - 0, - 1, - 0, - 0, - 0, - 1, - 0, - 0, - 0, - 0, - 1, - 0, - 0, - 0, - 0, - 1, - ] + multi_env_returns = np.ndarray( + [ + 0, + 0, + 0, + 0, + 1, + 0, + 0, + 0, + 0, + 1, + 0, + 1, + 0, + 1, + 0, + 1, + 0, + 1, + 0, + 0, + 1, + 0, + 0, + 1, + 0, + 0, + 0, + 1, + 0, + 0, + 0, + 1, + 0, + 0, + 0, + 0, + 1, + 0, + 0, + 0, + 0, + 1, + ], + ) else: - multi_env_returns = [ - 0, - 0, - 0, - 0, - 1, - 0, - 0, - 0, - 0, - 1, - 0, - 1, - 0, - 1, - 0, - 1, - 0, - 0, - 1, - 0, - 0, - 1, - 0, - 0, - 1, - 0, - 0, - 0, - 1, - 0, - 0, - 0, - 1, - 0, - 0, - 0, - 0, - 1, - 0, - 0, - 0, - 0, - 1, - ] + multi_env_returns = np.ndarray( + [ + 0, + 0, + 0, + 0, + 1, + 0, + 0, + 0, + 0, + 1, + 0, + 1, + 0, + 1, + 0, + 1, + 0, + 0, + 1, + 0, + 0, + 1, + 0, + 0, + 1, + 0, + 0, + 0, + 1, + 0, + 0, + 0, + 1, + 0, + 0, + 0, + 0, + 1, + 0, + 0, + 0, + 0, + 1, + ], + ) assert np.all(c_single_env.buffer[:].rew == [[x] * 4 for x in 
multi_env_returns]) assert np.all(c_single_env.buffer[:].done == multi_env_returns) c2 = Collector( @@ -656,7 +697,9 @@ def test_collector_with_atari_setting() -> None: obs = np.zeros_like(c2.buffer.obs) obs[np.arange(8)] = reference_obs[[0, 1, 2, 3, 4, 0, 1, 2], -1] assert np.all(c2.buffer.obs == obs) - assert np.allclose(c2.buffer[:].obs_next, reference_obs[[1, 2, 3, 4, 4, 1, 2, 2], -1]) + obs_next = c2.buffer[:].obs_next + assert isinstance(obs_next, np.ndarray) + assert np.allclose(obs_next, reference_obs[[1, 2, 3, 4, 4, 1, 2, 2], -1]) # atari multi buffer env_fns = [lambda x=i: MoveToRightEnv(size=x, sleep=0, array_state=True) for i in [2, 3, 4, 5]] @@ -881,7 +924,7 @@ def test_collector_envpool_gym_reset_return_info() -> None: assert np.allclose(c0.buffer.info["env_id"], env_ids) -def test_collector_with_vector_env(): +def test_collector_with_vector_env() -> None: env_fns = [lambda x=i: MoveToRightEnv(size=x, sleep=0) for i in [1, 8, 9, 10]] dum = DummyVectorEnv(env_fns) @@ -905,7 +948,7 @@ def test_collector_with_vector_env(): assert np.array_equal(np.array([1, 1, 1, 8, 1, 9, 1, 10]), c4r.lens) -def test_async_collector_with_vector_env(): +def test_async_collector_with_vector_env() -> None: env_fns = [lambda x=i: MoveToRightEnv(size=x, sleep=0) for i in [1, 8, 9, 10]] dum = DummyVectorEnv(env_fns) diff --git a/test/base/test_env.py b/test/base/test_env.py index 7d93897bd..a476ec5a9 100644 --- a/test/base/test_env.py +++ b/test/base/test_env.py @@ -290,8 +290,8 @@ def reset_result_to_obs(reset_result: tuple[np.ndarray, dict | list[dict]]) -> n eps = np.finfo(np.float32).eps.item() raw_reset_result = raw_env.reset() train_reset_result = train_env.reset() - initial_raw_obs = reset_result_to_obs(raw_reset_result) - initial_train_obs = reset_result_to_obs(train_reset_result) + initial_raw_obs = reset_result_to_obs(raw_reset_result) # type: ignore + initial_train_obs = reset_result_to_obs(train_reset_result) # type: ignore raw_obs, train_obs = [initial_raw_obs], 
[initial_train_obs] for action in action_list: step_result = raw_env.step(action) @@ -303,7 +303,7 @@ def reset_result_to_obs(reset_result: tuple[np.ndarray, dict | list[dict]]) -> n raw_obs.append(obs) if np.any(done): reset_result = raw_env.reset(np.where(done)[0]) - obs = reset_result_to_obs(reset_result) + obs = reset_result_to_obs(reset_result) # type: ignore raw_obs.append(obs) step_result = train_env.step(action) if len(step_result) == 5: @@ -314,7 +314,7 @@ def reset_result_to_obs(reset_result: tuple[np.ndarray, dict | list[dict]]) -> n train_obs.append(obs) if np.any(done): reset_result = train_env.reset(np.where(done)[0]) - obs = reset_result_to_obs(reset_result) + obs = reset_result_to_obs(reset_result) # type: ignore train_obs.append(obs) ref_rms = RunningMeanStd() for ro, to in zip(raw_obs, train_obs, strict=True): @@ -326,7 +326,7 @@ def reset_result_to_obs(reset_result: tuple[np.ndarray, dict | list[dict]]) -> n assert np.allclose(ref_rms.mean, test_env.get_obs_rms().mean) assert np.allclose(ref_rms.var, test_env.get_obs_rms().var) reset_result = test_env.reset() - obs = reset_result_to_obs(reset_result) + obs = reset_result_to_obs(reset_result) # type: ignore test_obs = [obs] for action in action_list: step_result = test_env.step(action) @@ -338,7 +338,7 @@ def reset_result_to_obs(reset_result: tuple[np.ndarray, dict | list[dict]]) -> n test_obs.append(obs) if np.any(done): reset_result = test_env.reset(np.where(done)[0]) - obs = reset_result_to_obs(reset_result) + obs = reset_result_to_obs(reset_result) # type: ignore test_obs.append(obs) for ro, to in zip(raw_obs, test_obs, strict=True): no = (ro - ref_rms.mean) / np.sqrt(ref_rms.var + eps) @@ -408,6 +408,7 @@ def step(self, act: Any) -> tuple[Any, Literal[-1], Literal[False], Literal[True assert truncated +# TODO: old gym envs are no longer supported! 
Replace by Ant-v4 and fix assoticiated tests @pytest.mark.skipif(envpool is None, reason="EnvPool doesn't support this platform") def test_venv_wrapper_envpool() -> None: raw = envpool.make_gymnasium("Ant-v3", num_envs=4) @@ -426,8 +427,9 @@ def test_venv_wrapper_envpool_gym_reset_return_info() -> None: ) obs, info = env.reset() assert obs.shape[0] == num_envs - if isinstance(info, dict): - for _, v in info.items(): + # This is not actually unreachable b/c envpool does not return info in the right format + if isinstance(info, dict): # type: ignore[unreachable] + for _, v in info.items(): # type: ignore[unreachable] if not isinstance(v, dict): assert v.shape[0] == num_envs else: diff --git a/test/base/test_env_finite.py b/test/base/test_env_finite.py index f69079d33..657100554 100644 --- a/test/base/test_env_finite.py +++ b/test/base/test_env_finite.py @@ -3,7 +3,7 @@ import copy from collections import Counter from collections.abc import Callable, Iterator, Sequence -from typing import Any +from typing import Any, cast import gymnasium as gym import numpy as np @@ -102,14 +102,18 @@ def _get_default_info(self) -> dict | None: # END - def reset(self, env_id: int | list[int] | np.ndarray | None = None): + def reset( + self, + env_id: int | list[int] | np.ndarray | None = None, + **kwargs: Any, + ) -> tuple[np.ndarray, np.ndarray]: env_id = self._wrap_id(env_id) self._reset_alive_envs() # ask super to reset alive envs and remap to current index request_id = list(filter(lambda i: i in self._alive_env_ids, env_id)) - obs_list = [None] * len(env_id) - infos = [None] * len(env_id) + obs_list: list[np.ndarray | None] = [None] * len(env_id) + infos: list[dict | None] = [None] * len(env_id) id2idx = {i: k for k, i in enumerate(env_id)} if request_id: for k, o, info in zip(request_id, *super().reset(request_id), strict=True): @@ -133,11 +137,14 @@ def reset(self, env_id: int | list[int] | np.ndarray | None = None): self.reset() raise StopIteration + obs_list = 
cast(list[np.ndarray], obs_list) + infos = cast(list[dict], infos) + return np.stack(obs_list), np.array(infos) def step( self, - action: np.ndarray | torch.Tensor, + action: np.ndarray | torch.Tensor | None, id: int | list[int] | np.ndarray | None = None, ) -> gym_new_venv_step_type: ids: list[int] | np.ndarray = self._wrap_id(id) @@ -146,6 +153,7 @@ def step( result: list[list] = [[None, 0.0, False, False, None] for _ in range(len(ids))] # ask super to step alive envs and remap to current index + assert action is not None if request_id: valid_act = np.stack([action[id2idx[i]] for i in request_id]) for i, (r_obs, r_reward, r_term, r_trunc, r_info) in zip( diff --git a/test/base/test_policy.py b/test/base/test_policy.py index a52344fdf..7c3aacc07 100644 --- a/test/base/test_policy.py +++ b/test/base/test_policy.py @@ -1,5 +1,3 @@ -from collections.abc import Callable - import gymnasium as gym import numpy as np import pytest @@ -23,7 +21,6 @@ def policy(request: pytest.FixtureRequest) -> PPOPolicy: action_type = request.param action_space: gym.spaces.Box | gym.spaces.Discrete actor: Actor | ActorProb - dist_fn: Callable[[torch.Tensor], torch.distributions.Distribution] if action_type == "continuous": action_space = gym.spaces.Box(low=-1, high=1, shape=(3,)) actor = ActorProb( @@ -41,7 +38,7 @@ def dist_fn(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution: Net(state_shape=obs_shape, hidden_sizes=[64, 64], action_shape=action_space.n), action_shape=action_space.n, ) - dist_fn = lambda logits: Categorical(logits=logits) + dist_fn = Categorical else: raise ValueError(f"Unknown action type: {action_type}") From aaa56ea05228c8fdd80d95709ac0090d505509ca Mon Sep 17 00:00:00 2001 From: Michael Panchenko Date: Mon, 1 Apr 2024 18:18:16 +0200 Subject: [PATCH 093/115] Pyproject: added stubs, included tests and examples in type-check --- poetry.lock | 127 ++++++++++++++++++++++++------------------------- pyproject.toml | 7 ++- 2 files changed, 66 insertions(+), 68 
deletions(-) diff --git a/poetry.lock b/poetry.lock index 0aa79b7e7..a6fbf3229 100644 --- a/poetry.lock +++ b/poetry.lock @@ -980,6 +980,13 @@ files = [ {file = "dm_tree-0.1.8-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:fa42a605d099ee7d41ba2b5fb75e21423951fd26e5d50583a00471238fb3021d"}, {file = "dm_tree-0.1.8-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:83b7764de0d855338abefc6e3ee9fe40d301668310aa3baea3f778ff051f4393"}, {file = "dm_tree-0.1.8-cp311-cp311-win_amd64.whl", hash = "sha256:a5d819c38c03f0bb5b3b3703c60e4b170355a0fc6b5819325bf3d4ceb3ae7e80"}, + {file = "dm_tree-0.1.8-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:ea9e59e0451e7d29aece402d9f908f2e2a80922bcde2ebfd5dcb07750fcbfee8"}, + {file = "dm_tree-0.1.8-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:94d3f0826311f45ee19b75f5b48c99466e4218a0489e81c0f0167bda50cacf22"}, + {file = "dm_tree-0.1.8-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:435227cf3c5dc63f4de054cf3d00183790bd9ead4c3623138c74dde7f67f521b"}, + {file = "dm_tree-0.1.8-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:09964470f76a5201aff2e8f9b26842976de7889300676f927930f6285e256760"}, + {file = "dm_tree-0.1.8-cp312-cp312-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:75c5d528bb992981c20793b6b453e91560784215dffb8a5440ba999753c14ceb"}, + {file = "dm_tree-0.1.8-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c0a94aba18a35457a1b5cd716fd7b46c5dafdc4cf7869b4bae665b91c4682a8e"}, + {file = "dm_tree-0.1.8-cp312-cp312-win_amd64.whl", hash = "sha256:96a548a406a6fb15fe58f6a30a57ff2f2aafbf25f05afab00c8f5e5977b6c715"}, {file = "dm_tree-0.1.8-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:8c60a7eadab64c2278861f56bca320b2720f163dca9d7558103c3b77f2416571"}, {file = "dm_tree-0.1.8-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = 
"sha256:af4b3d372f2477dcd89a6e717e4a575ca35ccc20cc4454a8a4b6f8838a00672d"}, {file = "dm_tree-0.1.8-cp37-cp37m-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:de287fabc464b8734be251e46e06aa9aa1001f34198da2b6ce07bd197172b9cb"}, @@ -1083,43 +1090,6 @@ packaging = "*" types-protobuf = ">=3.17.3" typing-extensions = "*" -[[package]] -name = "etils" -version = "1.6.0" -description = "Collection of common python utils" -optional = true -python-versions = ">=3.10" -files = [ - {file = "etils-1.6.0-py3-none-any.whl", hash = "sha256:3da192b057929f2511f9ef713cee7d9c498e741740f8b2a9c0f6392d787201d4"}, - {file = "etils-1.6.0.tar.gz", hash = "sha256:c635fbd02a79fed4ad76825d31306b581d22b40671721daa8bc279cf6333e48a"}, -] - -[package.dependencies] -fsspec = {version = "*", optional = true, markers = "extra == \"epath\""} -importlib_resources = {version = "*", optional = true, markers = "extra == \"epath\""} -typing_extensions = {version = "*", optional = true, markers = "extra == \"epy\""} -zipp = {version = "*", optional = true, markers = "extra == \"epath\""} - -[package.extras] -all = ["etils[array-types]", "etils[eapp]", "etils[ecolab]", "etils[edc]", "etils[enp]", "etils[epath-gcs]", "etils[epath-s3]", "etils[epath]", "etils[epy]", "etils[etqdm]", "etils[etree-dm]", "etils[etree-jax]", "etils[etree-tf]", "etils[etree]"] -array-types = ["etils[enp]"] -dev = ["chex", "dataclass_array", "optree", "pyink", "pylint (>=2.6.0)", "pytest", "pytest-subtests", "pytest-xdist", "torch"] -docs = ["etils[all,dev]", "sphinx-apitree[ext]"] -eapp = ["absl-py", "etils[epy]", "simple_parsing"] -ecolab = ["etils[enp]", "etils[epy]", "etils[etree]", "jupyter", "mediapy", "numpy", "packaging", "protobuf"] -edc = ["etils[epy]"] -enp = ["etils[epy]", "numpy"] -epath = ["etils[epy]", "fsspec", "importlib_resources", "typing_extensions", "zipp"] -epath-gcs = ["etils[epath]", "gcsfs"] -epath-s3 = ["etils[epath]", "s3fs"] -epy = ["typing_extensions"] -etqdm = ["absl-py", "etils[epy]", "tqdm"] 
-etree = ["etils[array-types]", "etils[enp]", "etils[epy]", "etils[etqdm]"] -etree-dm = ["dm-tree", "etils[etree]"] -etree-jax = ["etils[etree]", "jax[cpu]"] -etree-tf = ["etils[etree]", "tensorflow"] -lazy-imports = ["etils[ecolab]"] - [[package]] name = "executing" version = "2.0.1" @@ -2688,42 +2658,40 @@ files = [ [[package]] name = "mujoco" -version = "3.1.1" +version = "2.3.7" description = "MuJoCo Physics Simulator" optional = true python-versions = ">=3.8" files = [ - {file = "mujoco-3.1.1-cp310-cp310-macosx_10_16_x86_64.whl", hash = "sha256:be7aa04f8c91bc77fea6574c80154e62973fda0a959a6add4c9bc426db0ea9de"}, - {file = "mujoco-3.1.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:e35a60ade27b8e074ad7f08496e4a9101da9d358401bcbb08610dcf5066c3622"}, - {file = "mujoco-3.1.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f450b46802fca047e2d19ce8adefa9f4a1787273a27511d76ef717eafaf18d8b"}, - {file = "mujoco-3.1.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:51ac0f9df06e612ee628c571bab0320dc7721b7732e8c025a2289fda17f98a47"}, - {file = "mujoco-3.1.1-cp310-cp310-win_amd64.whl", hash = "sha256:d78a07fd18ae82a4cd4628e062fff1224220a7d86749c02170a0ea8e356c7442"}, - {file = "mujoco-3.1.1-cp311-cp311-macosx_10_16_x86_64.whl", hash = "sha256:34a61d8c1631aa6d85252b04b01fdc98bf7d6829e1aab08182069f29af02617e"}, - {file = "mujoco-3.1.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:34f2b63b9f7e76b10a9a82d085d2637ecccf6f2b2df177d7bc3d16b6857af861"}, - {file = "mujoco-3.1.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:537e6ca9b0896865a8c30da6060b158299450776cd8e5796fd23c1fc54d26aa5"}, - {file = "mujoco-3.1.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3aee8a9af27f5443a0c6fc09dd2384ebb3e2774928fda7213ca9809e552e0010"}, - {file = "mujoco-3.1.1-cp311-cp311-win_amd64.whl", hash = 
"sha256:431fdb51194f5a6dc1b3c2d625410d7468c40ec1091ac4e4e23081ace47d9a15"}, - {file = "mujoco-3.1.1-cp312-cp312-macosx_10_16_x86_64.whl", hash = "sha256:53ca08b1af724104ceeb307b47131e5f244ebb35ff5b5b38cf4f5f3b6b662b9f"}, - {file = "mujoco-3.1.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:f5e6502c1ba6902c276d384fe7dee8a99ca570ef187dc122c60692baf0f068cb"}, - {file = "mujoco-3.1.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:267458ff744cb1a2265ce2cf3f81ecb096883b2003a647de2d9177bb606514bb"}, - {file = "mujoco-3.1.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5731c8e6efb739312ece205fa6932d76e8d6ecd78a19c78da58e58b2abe5b591"}, - {file = "mujoco-3.1.1-cp312-cp312-win_amd64.whl", hash = "sha256:0037ea34af70a5553cf516027e76d3f91b13389a4b01679d5d77d8ea0bc4aaf7"}, - {file = "mujoco-3.1.1-cp38-cp38-macosx_10_16_x86_64.whl", hash = "sha256:70a440463d7ec272085a16057115bd3e2c74c4e91773f4fc809a40edca2b4546"}, - {file = "mujoco-3.1.1-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:b2f471410896a23a0325b240ab535ea6ba170af1a044ff82f6ac34fb5e17f7d6"}, - {file = "mujoco-3.1.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:50930f8bddb81f23b7c01d2beee9b01bb52827f0413c53dd2ff0b0220688e4a3"}, - {file = "mujoco-3.1.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:31aa58202baeafa9f95dac65dc19c7c04b6b5079eaed65113c66235d08a49a98"}, - {file = "mujoco-3.1.1-cp38-cp38-win_amd64.whl", hash = "sha256:d867792d0ca21720337e17e9dda67ada16d03bdc2c300082140aca7d1a2d01f0"}, - {file = "mujoco-3.1.1-cp39-cp39-macosx_10_16_x86_64.whl", hash = "sha256:f9d2e6e3cd857662e1eac7b7ff68074b329ab99bda9c0a5020e2aeb242db00e1"}, - {file = "mujoco-3.1.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:ec29474314726a71f60ed2fa519a9f8df332ae23b638368a7833c851ce0fe500"}, - {file = "mujoco-3.1.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = 
"sha256:3195aa1bfb96cfce4aaf116baf8b74aee7e479cb3c2427ede4d6f9ad91f7c107"}, - {file = "mujoco-3.1.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fd0ebcfc7f4771aeedb5e66321c00e9c8c4393834722385b4a23401f1eee3e8f"}, - {file = "mujoco-3.1.1-cp39-cp39-win_amd64.whl", hash = "sha256:0e76ebd3030aa32fd755e4ec0c1db069ad0a0fb86184b80c12fe5f2ef822bc56"}, - {file = "mujoco-3.1.1.tar.gz", hash = "sha256:1121273de2fbf4ed309e5944a3db39d01f385b220d20e78c460ec4efc06945b3"}, + {file = "mujoco-2.3.7-cp310-cp310-macosx_10_16_x86_64.whl", hash = "sha256:e8714a5ff6a1561b364b7b4648d4c0c8d13e751874cf7401c309b9d23fa9598b"}, + {file = "mujoco-2.3.7-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:a934315f858a4e0c4b90a682fde519471cfdd7baa64435179da8cd20d4ae3f99"}, + {file = "mujoco-2.3.7-cp310-cp310-macosx_11_0_x86_64.whl", hash = "sha256:36513024330f88b5f9a43558efef5692b33599bffd5141029b690a27918ffcbe"}, + {file = "mujoco-2.3.7-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6d4eede8ba8210fbd3d3cd1dbf69e24dd1541aa74c5af5b8adbbbf65504b6dba"}, + {file = "mujoco-2.3.7-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ab85fafc9d5a091c712947573b7e694512d283876bf7f33ae3f8daad3a20c0db"}, + {file = "mujoco-2.3.7-cp310-cp310-win_amd64.whl", hash = "sha256:f8b7e13fef8c813d91b78f975ed0815157692777907ffa4b4be53a4edb75019b"}, + {file = "mujoco-2.3.7-cp311-cp311-macosx_10_16_x86_64.whl", hash = "sha256:779520216f72a8e370e3f0cdd71b45c3b7384c63331a3189194c930a3e7cff5c"}, + {file = "mujoco-2.3.7-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:9d4018053879016282d27ab7a91e292c72d44efb5a88553feacfe5b843dde103"}, + {file = "mujoco-2.3.7-cp311-cp311-macosx_11_0_x86_64.whl", hash = "sha256:3149b16b8122ee62642474bfd2871064e8edc40235471cf5d84be3569afc0312"}, + {file = "mujoco-2.3.7-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c08660a8d52ef3efde76095f0991e807703a950c1e882d2bcd984b9a846626f7"}, + 
{file = "mujoco-2.3.7-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:426af8965f8636d94a0f75740c3024a62b3e585020ee817ef5208ec844a1ad94"}, + {file = "mujoco-2.3.7-cp311-cp311-win_amd64.whl", hash = "sha256:215415a8e98a4b50625beae859079d5e0810b2039e50420f0ba81763c34abb59"}, + {file = "mujoco-2.3.7-cp38-cp38-macosx_10_16_x86_64.whl", hash = "sha256:8b78d14f4c60cea3c58e046bd4de453fb5b9b33aca6a25fc91d39a53f3a5342a"}, + {file = "mujoco-2.3.7-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:5c6f5a51d6f537a4bf294cf73816f3a6384573f8f10a5452b044df2771412a96"}, + {file = "mujoco-2.3.7-cp38-cp38-macosx_11_0_x86_64.whl", hash = "sha256:ea8911e6047f92d7d775701f37e4c093971b6def3160f01d0b6926e29a7e962e"}, + {file = "mujoco-2.3.7-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7473a3de4dd1a8762d569ffb139196b4c5e7eca27d256df97b6cd4c66d2a09b2"}, + {file = "mujoco-2.3.7-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:40e7e2d8f93d2495ec74efec84e5118ecc6e1d85157a844789c73c9ac9a4e28e"}, + {file = "mujoco-2.3.7-cp38-cp38-win_amd64.whl", hash = "sha256:720bc228a2023b3b0ed6af78f5b0f8ea36867be321d473321555c57dbf6e4e5b"}, + {file = "mujoco-2.3.7-cp39-cp39-macosx_10_16_x86_64.whl", hash = "sha256:855e79686366442aa410246043b44f7d842d3900d68fe7e37feb42147db9d707"}, + {file = "mujoco-2.3.7-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:98947f4a742d34d36f3c3f83e9167025bb0414bbaa4bd859b0673bdab9959963"}, + {file = "mujoco-2.3.7-cp39-cp39-macosx_11_0_x86_64.whl", hash = "sha256:d42818f2ee5d1632dbce31d136ed5ff868db54b04e4e9aca0c5a3ac329f8a90f"}, + {file = "mujoco-2.3.7-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9237e1ba14bced9449c31199e6d5be49547f3a4c99bc83b196af7ca45fd73b83"}, + {file = "mujoco-2.3.7-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:39b728ea638245b150e2650c5433e6952e0ed3798c63e47e264574270caea2a3"}, + {file = 
"mujoco-2.3.7-cp39-cp39-win_amd64.whl", hash = "sha256:9c721a5042b99d948d5f0296a534bcce3f142c777c4d7642f503a539513f3912"}, + {file = "mujoco-2.3.7.tar.gz", hash = "sha256:422041f1ce37c6d151fbced1048df626837e94fe3cd9f813585907046336a7d0"}, ] [package.dependencies] absl-py = "*" -etils = {version = "*", extras = ["epath"]} glfw = "*" numpy = "*" pyopengl = "*" @@ -3271,6 +3239,7 @@ optional = false python-versions = ">=3" files = [ {file = "nvidia_nvjitlink_cu12-12.3.101-py3-none-manylinux1_x86_64.whl", hash = "sha256:64335a8088e2b9d196ae8665430bc6a2b7e6ef2eb877a9c735c804bd4ff6467c"}, + {file = "nvidia_nvjitlink_cu12-12.3.101-py3-none-manylinux2014_aarch64.whl", hash = "sha256:211a63e7b30a9d62f1a853e19928fbb1a750e3f17a13a3d1f98ff0ced19478dd"}, {file = "nvidia_nvjitlink_cu12-12.3.101-py3-none-win_amd64.whl", hash = "sha256:1b2e317e437433753530792f13eece58f0aec21a2b05903be7bffe58a606cbd1"}, ] @@ -4220,6 +4189,7 @@ files = [ {file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"}, {file = "PyYAML-6.0.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:855fb52b0dc35af121542a76b9a84f8d1cd886ea97c84703eaa6d88e37a2ad28"}, {file = "PyYAML-6.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:40df9b996c2b73138957fe23a16a4f0ba614f4c0efce1e9406a184b6d07fa3a9"}, + {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a08c6f0fe150303c1c6b71ebcd7213c2858041a7e01975da3a99aed1e7a378ef"}, {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c22bec3fbe2524cde73d7ada88f6566758a8f7227bfbf93a408a9d86bcc12a0"}, {file = "PyYAML-6.0.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8d4e9c88387b0f5c7d5f281e55304de64cf7f9c0021a3525bd3b1c542da3b0e4"}, {file = "PyYAML-6.0.1-cp312-cp312-win32.whl", hash = "sha256:d483d2cdf104e7c9fa60c544d92981f12ad66a457afae824d146093b8c294c54"}, @@ -5935,6 +5905,31 @@ files = [ {file = 
"types_python_dateutil-2.8.19.14-py3-none-any.whl", hash = "sha256:f977b8de27787639986b4e28963263fd0e5158942b3ecef91b9335c130cb1ce9"}, ] +[[package]] +name = "types-requests" +version = "2.31.0.20240311" +description = "Typing stubs for requests" +optional = false +python-versions = ">=3.8" +files = [ + {file = "types-requests-2.31.0.20240311.tar.gz", hash = "sha256:b1c1b66abfb7fa79aae09097a811c4aa97130eb8831c60e47aee4ca344731ca5"}, + {file = "types_requests-2.31.0.20240311-py3-none-any.whl", hash = "sha256:47872893d65a38e282ee9f277a4ee50d1b28bd592040df7d1fdaffdf3779937d"}, +] + +[package.dependencies] +urllib3 = ">=2" + +[[package]] +name = "types-tabulate" +version = "0.9.0.20240106" +description = "Typing stubs for tabulate" +optional = false +python-versions = ">=3.8" +files = [ + {file = "types-tabulate-0.9.0.20240106.tar.gz", hash = "sha256:c9b6db10dd7fcf55bd1712dd3537f86ddce72a08fd62bb1af4338c7096ce947e"}, + {file = "types_tabulate-0.9.0.20240106-py3-none-any.whl", hash = "sha256:0378b7b6fe0ccb4986299496d027a6d4c218298ecad67199bbd0e2d7e9d335a1"}, +] + [[package]] name = "typing-extensions" version = "4.8.0" @@ -6228,4 +6223,4 @@ vizdoom = ["vizdoom"] [metadata] lock-version = "2.0" python-versions = "^3.11" -content-hash = "0d4ff98ed02fe3f34c0d05b5d175822a82ac08a5ed52e57b7f847a48c302add6" +content-hash = "06b9166b2e752fbab564cbc0dbce226844c26dd2b59f9f7e95104570e377c43b" diff --git a/pyproject.toml b/pyproject.toml index 7a795004a..813cbd335 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -57,7 +57,8 @@ envpool = { version = "^0.8.2", optional = true, markers = "sys_platform != 'da gymnasium-robotics = { version = "*", optional = true } imageio = { version = ">=2.14.1", optional = true } jsonargparse = {version = "^4.24.1", optional = true} -mujoco = { version = ">=2.1.5", optional = true } +# we need <3 b/c of https://github.com/Farama-Foundation/Gymnasium/issues/749 +mujoco = { version = ">=2.1.5, <3", optional = true } mujoco-py = { version = 
">=2.1,<2.2", optional = true } opencv_python = { version = "*", optional = true } pybullet = { version = "*", optional = true } @@ -111,6 +112,8 @@ sphinx-togglebutton = "^0.3.2" sphinx-toolbox = "^3.5.0" sphinxcontrib-bibtex = "*" sphinxcontrib-spelling = "^8.0.0" +types-requests = "^2.31.0.20240311" +types-tabulate = "^0.9.0.20240106" wandb = "^0.12.0" [tool.mypy] @@ -219,6 +222,6 @@ doc-clean = "rm -rf docs/_build" doc-generate-files = ["_autogen_rst", "_jb_generate_toc", "_jb_generate_config"] doc-spellcheck = "sphinx-build -W -b spelling docs docs/_build" doc-build = ["doc-generate-files", "doc-spellcheck", "_sphinx_build"] -_mypy = "mypy tianshou" +_mypy = "mypy tianshou test examples" _mypy_nb = "nbqa mypy docs" type-check = ["_mypy", "_mypy_nb"] From 32951b2bcf5707227a81f184628fe006f6003d43 Mon Sep 17 00:00:00 2001 From: Michael Panchenko Date: Mon, 1 Apr 2024 18:24:10 +0200 Subject: [PATCH 094/115] Extended pre-commit type check [skip ci] --- .pre-commit-config.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 54b9a9794..aa00e7474 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -41,7 +41,7 @@ repos: pass_filenames: false - id: mypy name: mypy - entry: poetry run mypy tianshou + entry: poetry run mypy tianshou examples test # filenames should not be passed as they would collide with the config in pyproject.toml pass_filenames: false files: '^tianshou(/[^/]*)*/[^/]*\.py$' From 6235a37f15cec152565ccc2bc9de68e86d894d8c Mon Sep 17 00:00:00 2001 From: Michael Panchenko Date: Mon, 1 Apr 2024 23:58:01 +0200 Subject: [PATCH 095/115] Typo in test --- test/base/test_collector.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/base/test_collector.py b/test/base/test_collector.py index d48294212..6baa6abf3 100644 --- a/test/base/test_collector.py +++ b/test/base/test_collector.py @@ -549,7 +549,7 @@ def test_collector_with_multi_agent() -> 
None: c_single_env.buffer.update(c_multi_env_ma.buffer) assert len(c_single_env.buffer) in [42, 43] if len(c_single_env.buffer) == 42: - multi_env_returns = np.ndarray( + multi_env_returns = np.array( [ 0, 0, @@ -596,7 +596,7 @@ def test_collector_with_multi_agent() -> None: ], ) else: - multi_env_returns = np.ndarray( + multi_env_returns = np.array( [ 0, 0, From 4a867bdade24f1f0fb210ef5d8db04042d58adb3 Mon Sep 17 00:00:00 2001 From: daniel <1534513+dantp-ai@users.noreply.github.com> Date: Tue, 2 Apr 2024 11:43:39 +0200 Subject: [PATCH 096/115] Remove type in IndexType for batch --- tianshou/data/batch.py | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/tianshou/data/batch.py b/tianshou/data/batch.py index 2e9e5090c..7e4395f5a 100644 --- a/tianshou/data/batch.py +++ b/tianshou/data/batch.py @@ -19,13 +19,7 @@ import torch _SingleIndexType = slice | int | EllipsisType -IndexType = ( - np.ndarray - | _SingleIndexType - | list[_SingleIndexType] - | tuple[_SingleIndexType, ...] - | list[int] -) +IndexType = np.ndarray | _SingleIndexType | list[_SingleIndexType] | tuple[_SingleIndexType, ...] 
TBatch = TypeVar("TBatch", bound="BatchProtocol") arr_type = torch.Tensor | np.ndarray From 528eb10fa3af6c75c4c7cdb585b73afbbee83d5e Mon Sep 17 00:00:00 2001 From: daniel <1534513+dantp-ai@users.noreply.github.com> Date: Tue, 2 Apr 2024 12:01:30 +0200 Subject: [PATCH 097/115] Use assert hasattr instead of getattr --- test/base/test_stats.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/test/base/test_stats.py b/test/base/test_stats.py index e13f144d7..9776374ba 100644 --- a/test/base/test_stats.py +++ b/test/base/test_stats.py @@ -38,8 +38,12 @@ def test_training_stats_wrapper() -> None: # existing fields, wrapped and not-wrapped, can be mutated wrapped_train_stats.loss_field = 13 wrapped_train_stats.dummy_field = 43 - assert ( - getattr(wrapped_train_stats.wrapped_stats, "loss_field") # noqa: B009 - == getattr(wrapped_train_stats, "loss_field") # noqa: B009 - == 13 - ) + assert hasattr( + wrapped_train_stats.wrapped_stats, + "loss_field", + ), "Attribute `loss_field` not found in `wrapped_train_stats.wrapped_stats`." + assert hasattr( + wrapped_train_stats, + "loss_field", + ), "Attribute `loss_field` not found in `wrapped_train_stats`." 
+ assert wrapped_train_stats.wrapped_stats.loss_field == wrapped_train_stats.loss_field == 13 From d10b8b271333bb7acdea76df5520d7195d2950c7 Mon Sep 17 00:00:00 2001 From: daniel <1534513+dantp-ai@users.noreply.github.com> Date: Tue, 2 Apr 2024 12:06:04 +0200 Subject: [PATCH 098/115] Remove iter since Batch already implements __iter__ --- test/base/test_returns.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/base/test_returns.py b/test/base/test_returns.py index 0196a415b..23f50fb22 100644 --- a/test/base/test_returns.py +++ b/test/base/test_returns.py @@ -71,7 +71,7 @@ def test_episodic_returns(size: int = 2560) -> None: truncated=np.array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]), rew=np.array([101, 102, 103.0, 200, 104, 105, 106, 201, 107, 108, 109, 202]), ) - for b in iter(batch): + for b in batch: b.obs = b.act = 1 buf.add(b) v = np.array([2.0, 3.0, 4, -1, 5.0, 6.0, 7, -2, 8.0, 9.0, 10, -3]) From 5ed5d500fa7be3c9ebfccb92382ecbdffa2e9aff Mon Sep 17 00:00:00 2001 From: daniel <1534513+dantp-ai@users.noreply.github.com> Date: Tue, 2 Apr 2024 12:12:38 +0200 Subject: [PATCH 099/115] Use Literal instead of asserting members of list --- examples/mujoco/plotter.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/examples/mujoco/plotter.py b/examples/mujoco/plotter.py index ee86f712f..5e2f9e016 100755 --- a/examples/mujoco/plotter.py +++ b/examples/mujoco/plotter.py @@ -3,7 +3,7 @@ import argparse import os import re -from typing import Any +from typing import Any, Literal import matplotlib.pyplot as plt import matplotlib.ticker as mticker @@ -14,7 +14,7 @@ def smooth( y: np.ndarray, radius: int, - mode: str = "two_sided", + mode: Literal["two_sided", "causal"] = "two_sided", valid_only: bool = False, ) -> np.ndarray: """Smooth signal y, where radius is determines the size of the window. 
@@ -25,7 +25,6 @@ def smooth( average over the window [max(index - radius, 0), index] valid_only: put nan in entries where the full-sized window is not available """ - assert mode in ("two_sided", "causal") if len(y) < 2 * radius + 1: return np.ones_like(y) * y.mean() if mode == "two_sided": From 7a18c4d0b378992b542519f54c8db9612542ab03 Mon Sep 17 00:00:00 2001 From: daniel <1534513+dantp-ai@users.noreply.github.com> Date: Tue, 2 Apr 2024 12:57:23 +0200 Subject: [PATCH 100/115] Use stop_fn for running this example * Change default num epochs 1000 -> 25 --- examples/box2d/bipedal_bdq.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/examples/box2d/bipedal_bdq.py b/examples/box2d/bipedal_bdq.py index 8be3bac00..f52f6d5c1 100644 --- a/examples/box2d/bipedal_bdq.py +++ b/examples/box2d/bipedal_bdq.py @@ -35,7 +35,7 @@ def get_args() -> argparse.Namespace: parser.add_argument("--lr", type=float, default=1e-4) parser.add_argument("--gamma", type=float, default=0.99) parser.add_argument("--target-update-freq", type=int, default=1000) - parser.add_argument("--epoch", type=int, default=1000) + parser.add_argument("--epoch", type=int, default=25) parser.add_argument("--step-per-epoch", type=int, default=80000) parser.add_argument("--step-per-collect", type=int, default=16) parser.add_argument("--update-per-step", type=float, default=0.0625) @@ -150,14 +150,14 @@ def test_fn(epoch: int, env_step: int | None) -> None: episode_per_test=args.test_num, batch_size=args.batch_size, update_per_step=args.update_per_step, - # stop_fn=stop_fn, + stop_fn=stop_fn, train_fn=train_fn, test_fn=test_fn, save_best_fn=save_best_fn, logger=logger, ).run() - # assert stop_fn(result.best_reward) + assert stop_fn(result.best_reward) if __name__ == "__main__": pprint.pprint(result) # Let's watch its performance! 
From 4fb294d352a9e89b6852980373a8caccad6613e8 Mon Sep 17 00:00:00 2001 From: daniel <1534513+dantp-ai@users.noreply.github.com> Date: Tue, 2 Apr 2024 12:58:41 +0200 Subject: [PATCH 101/115] Use ValueError to inform user about what type of env is supported. --- examples/mujoco/fetch_her_ddpg.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/examples/mujoco/fetch_her_ddpg.py b/examples/mujoco/fetch_her_ddpg.py index 501d6c180..be6594aa7 100644 --- a/examples/mujoco/fetch_her_ddpg.py +++ b/examples/mujoco/fetch_her_ddpg.py @@ -117,8 +117,15 @@ def test_ddpg(args: argparse.Namespace = get_args()) -> None: env, train_envs, test_envs = make_fetch_env(args.task, args.training_num, args.test_num) # The method HER works with goal-based environments - assert isinstance(env.observation_space, gym.spaces.Dict) - assert hasattr(env, "compute_reward") + if not isinstance(env.observation_space, gym.spaces.Dict): + raise ValueError( + "`env.observation_space` must be of type `gym.spaces.Dict`. Make sure you're using a goal-based environment like `FetchReach-v2`.", + ) + if not hasattr(env, "compute_reward"): + raise ValueError( + "Atrribute `compute_reward` not found in `env`. " + "HER-based algorithms typically require this attribute. 
Make sure you're using a goal-based environment like `FetchReach-v2`.", + ) args.state_shape = { "observation": env.observation_space["observation"].shape, "achieved_goal": env.observation_space["achieved_goal"].shape, From 2ace3ef6263e40bceac0abeabaf68728fc3ab05f Mon Sep 17 00:00:00 2001 From: daniel <1534513+dantp-ai@users.noreply.github.com> Date: Tue, 2 Apr 2024 13:29:19 +0200 Subject: [PATCH 102/115] Use more specific type hint for policy to get access to policy-specific attrs and methods --- docs/02_notebooks/L0_overview.ipynb | 15 +++++++-------- docs/02_notebooks/L5_Collector.ipynb | 5 ++--- docs/02_notebooks/L6_Trainer.ipynb | 5 ++--- docs/02_notebooks/L7_Experiment.ipynb | 5 ++--- test/continuous/test_sac_with_il.py | 4 ++-- test/continuous/test_trpo.py | 2 +- test/discrete/test_dqn.py | 2 +- test/discrete/test_fqf.py | 2 +- test/discrete/test_pg.py | 2 +- test/offline/test_discrete_cql.py | 2 +- test/offline/test_gail.py | 2 +- 11 files changed, 21 insertions(+), 25 deletions(-) diff --git a/docs/02_notebooks/L0_overview.ipynb b/docs/02_notebooks/L0_overview.ipynb index 37cba0be5..59d6fd207 100644 --- a/docs/02_notebooks/L0_overview.ipynb +++ b/docs/02_notebooks/L0_overview.ipynb @@ -17,14 +17,14 @@ }, { "cell_type": "code", - "outputs": [], - "source": [ - "# !pip install tianshou gym" - ], + "execution_count": null, "metadata": { "collapsed": false }, - "execution_count": 0 + "outputs": [], + "source": [ + "# !pip install tianshou gym" + ] }, { "cell_type": "markdown", @@ -71,7 +71,7 @@ "\n", "from tianshou.data import Collector, VectorReplayBuffer\n", "from tianshou.env import DummyVectorEnv\n", - "from tianshou.policy import BasePolicy, PPOPolicy\n", + "from tianshou.policy import PPOPolicy\n", "from tianshou.trainer import OnpolicyTrainer\n", "from tianshou.utils.net.common import ActorCritic, Net\n", "from tianshou.utils.net.discrete import Actor, Critic\n", @@ -106,8 +106,7 @@ "\n", "# PPO policy\n", "dist = torch.distributions.Categorical\n", - 
"policy: BasePolicy\n", - "policy = PPOPolicy(\n", + "policy: PPOPolicy = PPOPolicy(\n", " actor=actor,\n", " critic=critic,\n", " optim=optim,\n", diff --git a/docs/02_notebooks/L5_Collector.ipynb b/docs/02_notebooks/L5_Collector.ipynb index 3e91e0f43..7da98a5cf 100644 --- a/docs/02_notebooks/L5_Collector.ipynb +++ b/docs/02_notebooks/L5_Collector.ipynb @@ -60,7 +60,7 @@ "\n", "from tianshou.data import Collector, VectorReplayBuffer\n", "from tianshou.env import DummyVectorEnv\n", - "from tianshou.policy import BasePolicy, PGPolicy\n", + "from tianshou.policy import PGPolicy\n", "from tianshou.utils.net.common import Net\n", "from tianshou.utils.net.discrete import Actor" ] @@ -87,8 +87,7 @@ "actor = Actor(net, env.action_space.n)\n", "optim = torch.optim.Adam(actor.parameters(), lr=0.0003)\n", "\n", - "policy: BasePolicy\n", - "policy = PGPolicy(\n", + "policy: PGPolicy = PGPolicy(\n", " actor=actor,\n", " optim=optim,\n", " dist_fn=torch.distributions.Categorical,\n", diff --git a/docs/02_notebooks/L6_Trainer.ipynb b/docs/02_notebooks/L6_Trainer.ipynb index d0f1ebf4d..75aea471c 100644 --- a/docs/02_notebooks/L6_Trainer.ipynb +++ b/docs/02_notebooks/L6_Trainer.ipynb @@ -75,7 +75,7 @@ "\n", "from tianshou.data import Collector, VectorReplayBuffer\n", "from tianshou.env import DummyVectorEnv\n", - "from tianshou.policy import BasePolicy, PGPolicy\n", + "from tianshou.policy import PGPolicy\n", "from tianshou.trainer import OnpolicyTrainer\n", "from tianshou.utils.net.common import Net\n", "from tianshou.utils.net.discrete import Actor" @@ -110,9 +110,8 @@ "actor = Actor(net, env.action_space.n)\n", "optim = torch.optim.Adam(actor.parameters(), lr=0.001)\n", "\n", - "policy: BasePolicy\n", "# We choose to use REINFORCE algorithm, also known as Policy Gradient\n", - "policy = PGPolicy(\n", + "policy: PGPolicy = PGPolicy(\n", " actor=actor,\n", " optim=optim,\n", " dist_fn=torch.distributions.Categorical,\n", diff --git a/docs/02_notebooks/L7_Experiment.ipynb 
b/docs/02_notebooks/L7_Experiment.ipynb index 46c065c75..9a97b20cb 100644 --- a/docs/02_notebooks/L7_Experiment.ipynb +++ b/docs/02_notebooks/L7_Experiment.ipynb @@ -73,7 +73,7 @@ "\n", "from tianshou.data import Collector, VectorReplayBuffer\n", "from tianshou.env import DummyVectorEnv\n", - "from tianshou.policy import BasePolicy, PPOPolicy\n", + "from tianshou.policy import PPOPolicy\n", "from tianshou.trainer import OnpolicyTrainer\n", "from tianshou.utils.net.common import ActorCritic, Net\n", "from tianshou.utils.net.discrete import Actor, Critic\n", @@ -164,8 +164,7 @@ "outputs": [], "source": [ "dist = torch.distributions.Categorical\n", - "policy: BasePolicy\n", - "policy = PPOPolicy(\n", + "policy: PPOPolicy = PPOPolicy(\n", " actor=actor,\n", " critic=critic,\n", " optim=optim,\n", diff --git a/test/continuous/test_sac_with_il.py b/test/continuous/test_sac_with_il.py index c347f75b0..5f09416e8 100644 --- a/test/continuous/test_sac_with_il.py +++ b/test/continuous/test_sac_with_il.py @@ -110,7 +110,7 @@ def test_sac_with_il(args: argparse.Namespace = get_args()) -> None: alpha_optim = torch.optim.Adam([log_alpha], lr=args.alpha_lr) args.alpha = (target_entropy, log_alpha, alpha_optim) - policy: BasePolicy = SACPolicy( + policy: SACPolicy = SACPolicy( actor=actor, actor_optim=actor_optim, critic=critic1, @@ -176,7 +176,7 @@ def stop_fn(mean_rewards: float) -> bool: device=args.device, ).to(args.device) optim = torch.optim.Adam(il_actor.parameters(), lr=args.il_lr) - il_policy: BasePolicy = ImitationPolicy( + il_policy: ImitationPolicy = ImitationPolicy( actor=il_actor, optim=optim, action_space=env.action_space, diff --git a/test/continuous/test_trpo.py b/test/continuous/test_trpo.py index 9de81283c..ae788d1cc 100644 --- a/test/continuous/test_trpo.py +++ b/test/continuous/test_trpo.py @@ -106,7 +106,7 @@ def dist(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution: loc, scale = loc_scale return Independent(Normal(loc, scale), 1) - policy: 
BasePolicy = TRPOPolicy( + policy: TRPOPolicy = TRPOPolicy( actor=actor, critic=critic, optim=optim, diff --git a/test/discrete/test_dqn.py b/test/discrete/test_dqn.py index 9dec477f3..52452b675 100644 --- a/test/discrete/test_dqn.py +++ b/test/discrete/test_dqn.py @@ -87,7 +87,7 @@ def test_dqn(args: argparse.Namespace = get_args()) -> None: # dueling=(Q_param, V_param), ).to(args.device) optim = torch.optim.Adam(net.parameters(), lr=args.lr) - policy: BasePolicy = DQNPolicy( + policy: DQNPolicy = DQNPolicy( model=net, optim=optim, discount_factor=args.gamma, diff --git a/test/discrete/test_fqf.py b/test/discrete/test_fqf.py index ca2bc0420..f1af574f7 100644 --- a/test/discrete/test_fqf.py +++ b/test/discrete/test_fqf.py @@ -100,7 +100,7 @@ def test_fqf(args: argparse.Namespace = get_args()) -> None: optim = torch.optim.Adam(net.parameters(), lr=args.lr) fraction_net = FractionProposalNetwork(args.num_fractions, net.input_dim) fraction_optim = torch.optim.RMSprop(fraction_net.parameters(), lr=args.fraction_lr) - policy: BasePolicy = FQFPolicy( + policy: FQFPolicy = FQFPolicy( model=net, optim=optim, fraction_model=fraction_net, diff --git a/test/discrete/test_pg.py b/test/discrete/test_pg.py index 795086d1b..77b48b5ea 100644 --- a/test/discrete/test_pg.py +++ b/test/discrete/test_pg.py @@ -76,7 +76,7 @@ def test_pg(args: argparse.Namespace = get_args()) -> None: ).to(args.device) optim = torch.optim.Adam(net.parameters(), lr=args.lr) dist_fn = torch.distributions.Categorical - policy: BasePolicy = PGPolicy( + policy: PGPolicy = PGPolicy( actor=net, optim=optim, dist_fn=dist_fn, diff --git a/test/offline/test_discrete_cql.py b/test/offline/test_discrete_cql.py index e0f336807..a678b8321 100644 --- a/test/offline/test_discrete_cql.py +++ b/test/offline/test_discrete_cql.py @@ -79,7 +79,7 @@ def test_discrete_cql(args: argparse.Namespace = get_args()) -> None: ) optim = torch.optim.Adam(net.parameters(), lr=args.lr) - policy: BasePolicy = DiscreteCQLPolicy( + policy: 
DiscreteCQLPolicy = DiscreteCQLPolicy( model=net, optim=optim, action_space=env.action_space, diff --git a/test/offline/test_gail.py b/test/offline/test_gail.py index 68fab728f..26281bedd 100644 --- a/test/offline/test_gail.py +++ b/test/offline/test_gail.py @@ -137,7 +137,7 @@ def dist(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution: loc, scale = loc_scale return Independent(Normal(loc, scale), 1) - policy: BasePolicy = GAILPolicy( + policy: GAILPolicy = GAILPolicy( actor=actor, critic=critic, optim=optim, From 6e5038908e717ca9dd15beed645348d5dea6b55f Mon Sep 17 00:00:00 2001 From: daniel <1534513+dantp-ai@users.noreply.github.com> Date: Tue, 2 Apr 2024 16:37:13 +0200 Subject: [PATCH 103/115] Refactor type annotation make_vizdoom_env --- examples/vizdoom/env.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/vizdoom/env.py b/examples/vizdoom/env.py index 99c5b6f3e..fa707387c 100644 --- a/examples/vizdoom/env.py +++ b/examples/vizdoom/env.py @@ -135,7 +135,7 @@ def make_vizdoom_env( seed: int | None = None, training_num: int = 10, test_num: int = 10, -) -> tuple[Any | Env, Any | ShmemVectorEnv, Any | ShmemVectorEnv]: +) -> tuple[Env, ShmemVectorEnv, ShmemVectorEnv]: cpu_count = os.cpu_count() if cpu_count is not None: test_num = min(cpu_count - 1, test_num) From 1ac275bf958ff0a2bdc9de16bb0b05cb1c9da199 Mon Sep 17 00:00:00 2001 From: daniel <1534513+dantp-ai@users.noreply.github.com> Date: Tue, 2 Apr 2024 16:59:09 +0200 Subject: [PATCH 104/115] Use kw-args for better readability --- examples/atari/atari_network.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/examples/atari/atari_network.py b/examples/atari/atari_network.py index 0d6461f71..95f4364f7 100644 --- a/examples/atari/atari_network.py +++ b/examples/atari/atari_network.py @@ -262,10 +262,10 @@ def create_module(self, envs: Environments, device: TDevice) -> Actor: action_shape = int(action_shape) net: nn.Module net = DQN( - c, - 
h, - w, - action_shape, + c=c, + h=h, + w=w, + action_shape=action_shape, device=device, features_only=self.features_only, output_dim=self.hidden_size @@ -293,10 +293,10 @@ def create_intermediate_module(self, envs: Environments, device: TDevice) -> Int if isinstance(action_shape, np.int64): action_shape = int(action_shape) dqn = DQN( - c, - h, - w, - action_shape, + c=c, + h=h, + w=w, + action_shape=action_shape, device=device, features_only=self.features_only, ).to(device) From 322b6aaeac0ef319ae87c6ff6caa70e574906c83 Mon Sep 17 00:00:00 2001 From: daniel <1534513+dantp-ai@users.noreply.github.com> Date: Tue, 2 Apr 2024 17:03:15 +0200 Subject: [PATCH 105/115] Use os.path.join --- examples/vizdoom/replay.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/examples/vizdoom/replay.py b/examples/vizdoom/replay.py index ac6e183e4..45f9df671 100755 --- a/examples/vizdoom/replay.py +++ b/examples/vizdoom/replay.py @@ -1,4 +1,5 @@ # import cv2 +import os import sys import time @@ -6,7 +7,10 @@ import vizdoom as vzd -def main(cfg_path: str = "maps/D3_battle.cfg", lmp_path: str = "test.lmp") -> None: +def main( + cfg_path: str = os.path.join("maps", "D3_battle.cfg"), + lmp_path: str = os.path.join("test.lmp"), +) -> None: game = vzd.DoomGame() game.load_config(cfg_path) game.set_screen_format(vzd.ScreenFormat.CRCGCB) From 0e2babbe5c405e451abb5ef4fd5f9f6afec55f9b Mon Sep 17 00:00:00 2001 From: daniel <1534513+dantp-ai@users.noreply.github.com> Date: Tue, 2 Apr 2024 17:14:32 +0200 Subject: [PATCH 106/115] Remove if and assert hasattr beforehand --- test/base/test_buffer.py | 50 ++++++++++++++++++++++------------------ 1 file changed, 28 insertions(+), 22 deletions(-) diff --git a/test/base/test_buffer.py b/test/base/test_buffer.py index cc2ad01d6..31265f664 100644 --- a/test/base/test_buffer.py +++ b/test/base/test_buffer.py @@ -166,31 +166,37 @@ def test_ignore_obs_next(size: int = 10) -> None: assert isinstance(data, Batch) assert isinstance(data2, 
Batch) assert np.allclose(indices, orig) - if hasattr(data.obs_next, "mask") and hasattr(data2.obs_next, "mask"): - assert np.allclose(data.obs_next.mask, data2.obs_next.mask) - assert np.allclose(data.obs_next.mask, [0, 2, 3, 3, 5, 6, 6, 8, 9, 9]) + assert hasattr(data.obs_next, "mask") and hasattr( + data2.obs_next, + "mask", + ), "Both `data.obs_next` and `data2.obs_next` must have attribute `mask`." + assert np.allclose(data.obs_next.mask, data2.obs_next.mask) + assert np.allclose(data.obs_next.mask, [0, 2, 3, 3, 5, 6, 6, 8, 9, 9]) buf.stack_num = 4 data = buf[indices] data2 = buf[indices] - if hasattr(data.obs_next, "mask") and hasattr(data2.obs_next, "mask"): - assert np.allclose(data.obs_next.mask, data2.obs_next.mask) - assert np.allclose( - data.obs_next.mask, - np.array( - [ - [0, 0, 0, 0], - [1, 1, 1, 2], - [1, 1, 2, 3], - [1, 1, 2, 3], - [4, 4, 4, 5], - [4, 4, 5, 6], - [4, 4, 5, 6], - [7, 7, 7, 8], - [7, 7, 8, 9], - [7, 7, 8, 9], - ], - ), - ) + assert hasattr(data.obs_next, "mask") and hasattr( + data2.obs_next, + "mask", + ), "Both `data.obs_next` and `data2.obs_next` must have attribute `mask`." 
+ assert np.allclose(data.obs_next.mask, data2.obs_next.mask) + assert np.allclose( + data.obs_next.mask, + np.array( + [ + [0, 0, 0, 0], + [1, 1, 1, 2], + [1, 1, 2, 3], + [1, 1, 2, 3], + [4, 4, 4, 5], + [4, 4, 5, 6], + [4, 4, 5, 6], + [7, 7, 7, 8], + [7, 7, 8, 9], + [7, 7, 8, 9], + ], + ), + ) assert np.allclose(data["info"]["if"], data2["info"]["if"]) assert np.allclose( data["info"]["if"], From 71a4006e5308edabf641b2a344b2a242d61fd24e Mon Sep 17 00:00:00 2001 From: daniel <1534513+dantp-ai@users.noreply.github.com> Date: Wed, 3 Apr 2024 00:06:32 +0200 Subject: [PATCH 107/115] Return non-empty dict when reset --- examples/vizdoom/env.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/vizdoom/env.py b/examples/vizdoom/env.py index fa707387c..2869acd1a 100644 --- a/examples/vizdoom/env.py +++ b/examples/vizdoom/env.py @@ -94,7 +94,7 @@ def reset( self.health = self.game.get_game_variable(vzd.GameVariable.HEALTH) self.killcount = self.game.get_game_variable(vzd.GameVariable.KILLCOUNT) self.ammo2 = self.game.get_game_variable(vzd.GameVariable.AMMO2) - return self.obs_buffer, {} + return self.obs_buffer, {"TimeLimit.truncated": False} def step(self, action: int) -> tuple[NDArray[np.uint8], float, bool, bool, dict[str, Any]]: self.game.make_action(self.available_actions[action], self.skip) From 10669a6ef24166958a181a732f522fc3b59cfe22 Mon Sep 17 00:00:00 2001 From: daniel <1534513+dantp-ai@users.noreply.github.com> Date: Wed, 3 Apr 2024 00:18:44 +0200 Subject: [PATCH 108/115] Refactor ActorFactoryAtariDQN hidden_size semantics and output_dim of APIs using it. 
* `hidden_size` is replaced with `output_dim_added_layer` to better reflect its purpose, which is to add an output layer to DQN body net or not (based on whether `features_only`) --- examples/atari/atari_network.py | 30 +++++++++++++++++------------- examples/atari/atari_ppo.py | 2 +- examples/atari/atari_ppo_hl.py | 5 ++--- examples/atari/atari_sac.py | 2 +- examples/atari/atari_sac_hl.py | 7 ++++--- 5 files changed, 25 insertions(+), 21 deletions(-) diff --git a/examples/atari/atari_network.py b/examples/atari/atari_network.py index 95f4364f7..4ec231d0a 100644 --- a/examples/atari/atari_network.py +++ b/examples/atari/atari_network.py @@ -61,9 +61,14 @@ def __init__( action_shape: Sequence[int] | int, device: str | int | torch.device = "cpu", features_only: bool = False, - output_dim: int | None = None, + output_dim_added_layer: int | None = None, layer_init: Callable[[nn.Module], nn.Module] = lambda x: x, ) -> None: + # TODO: Add docstring + if features_only and output_dim_added_layer is not None: + raise ValueError( + "Should not provide explicit output dimension using `output_dim_added_layer` when `features_only` is true.", + ) super().__init__() self.device = device self.net = nn.Sequential( @@ -76,23 +81,24 @@ def __init__( nn.Flatten(), ) with torch.no_grad(): - self.output_dim = int(np.prod(self.net(torch.zeros(1, c, h, w)).shape[1:])) + base_cnn_output_dim = int(np.prod(self.net(torch.zeros(1, c, h, w)).shape[1:])) if not features_only: action_dim = int(np.prod(action_shape)) self.net = nn.Sequential( self.net, - layer_init(nn.Linear(self.output_dim, 512)), + layer_init(nn.Linear(base_cnn_output_dim, 512)), nn.ReLU(inplace=True), layer_init(nn.Linear(512, action_dim)), ) self.output_dim = action_dim - elif output_dim is not None: + elif output_dim_added_layer is not None: self.net = nn.Sequential( self.net, - layer_init(nn.Linear(self.output_dim, output_dim)), + layer_init(nn.Linear(base_cnn_output_dim, output_dim_added_layer)), nn.ReLU(inplace=True), ) - 
self.output_dim = output_dim + else: + self.output_dim = base_cnn_output_dim def forward( self, @@ -243,11 +249,11 @@ def forward( class ActorFactoryAtariDQN(ActorFactory): def __init__( self, - hidden_size: int | Sequence[int], - scale_obs: bool, - features_only: bool, + scale_obs: bool = True, + features_only: bool = False, + output_dim_added_layer: int | None = None, ) -> None: - self.hidden_size = hidden_size + self.output_dim_added_layer = output_dim_added_layer self.scale_obs = scale_obs self.features_only = features_only @@ -268,9 +274,7 @@ def create_module(self, envs: Environments, device: TDevice) -> Actor: action_shape=action_shape, device=device, features_only=self.features_only, - output_dim=self.hidden_size - if isinstance(self.hidden_size, int) - else self.hidden_size[-1], + output_dim_added_layer=self.output_dim_added_layer, layer_init=layer_init, ) if self.scale_obs: diff --git a/examples/atari/atari_ppo.py b/examples/atari/atari_ppo.py index 461926f90..612b54008 100644 --- a/examples/atari/atari_ppo.py +++ b/examples/atari/atari_ppo.py @@ -115,7 +115,7 @@ def test_ppo(args: argparse.Namespace = get_args()) -> None: args.action_shape, device=args.device, features_only=True, - output_dim=args.hidden_size, + output_dim_added_layer=args.hidden_size, layer_init=layer_init, ) if args.scale_obs: diff --git a/examples/atari/atari_ppo_hl.py b/examples/atari/atari_ppo_hl.py index 03272aa5b..53393b05a 100644 --- a/examples/atari/atari_ppo_hl.py +++ b/examples/atari/atari_ppo_hl.py @@ -34,7 +34,7 @@ def main( step_per_collect: int = 1000, repeat_per_collect: int = 4, batch_size: int = 256, - hidden_sizes: int | Sequence[int] = 512, + hidden_sizes: Sequence[int] = (512,), training_num: int = 10, test_num: int = 10, rew_norm: bool = False, @@ -93,12 +93,11 @@ def main( else None, ), ) - .with_actor_factory(ActorFactoryAtariDQN(hidden_sizes, scale_obs, features_only=True)) + .with_actor_factory(ActorFactoryAtariDQN(scale_obs=scale_obs, features_only=True)) 
.with_critic_factory_use_actor() .with_epoch_stop_callback(AtariEpochStopCallback(task)) ) if icm_lr_scale > 0: - hidden_sizes = [hidden_sizes] if isinstance(hidden_sizes, int) else hidden_sizes builder.with_policy_wrapper_factory( PolicyWrapperFactoryIntrinsicCuriosity( feature_net_factory=IntermediateModuleFactoryAtariDQNFeatures(), diff --git a/examples/atari/atari_sac.py b/examples/atari/atari_sac.py index 78d05e7be..d5edf1a9a 100644 --- a/examples/atari/atari_sac.py +++ b/examples/atari/atari_sac.py @@ -108,7 +108,7 @@ def test_discrete_sac(args: argparse.Namespace = get_args()) -> None: args.action_shape, device=args.device, features_only=True, - output_dim=args.hidden_size, + output_dim_added_layer=args.hidden_size, ) actor = Actor(net, args.action_shape, device=args.device, softmax_output=False) actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr) diff --git a/examples/atari/atari_sac_hl.py b/examples/atari/atari_sac_hl.py index 127156777..dd49f49a7 100644 --- a/examples/atari/atari_sac_hl.py +++ b/examples/atari/atari_sac_hl.py @@ -1,6 +1,7 @@ #!/usr/bin/env python3 import os +from collections.abc import Sequence from examples.atari.atari_network import ( ActorFactoryAtariDQN, @@ -39,7 +40,7 @@ def main( step_per_collect: int = 10, update_per_step: float = 0.1, batch_size: int = 64, - hidden_size: int = 512, + hidden_sizes: Sequence[int] = (512,), training_num: int = 10, test_num: int = 10, frames_stack: int = 4, @@ -80,7 +81,7 @@ def main( estimation_step=n_step, ), ) - .with_actor_factory(ActorFactoryAtariDQN(hidden_size, scale_obs=False, features_only=True)) + .with_actor_factory(ActorFactoryAtariDQN(scale_obs=False, features_only=True)) .with_common_critic_factory_use_actor() .with_epoch_stop_callback(AtariEpochStopCallback(task)) ) @@ -88,7 +89,7 @@ def main( builder.with_policy_wrapper_factory( PolicyWrapperFactoryIntrinsicCuriosity( feature_net_factory=IntermediateModuleFactoryAtariDQNFeatures(), - hidden_sizes=[hidden_size], + 
hidden_sizes=hidden_sizes, lr=actor_lr, lr_scale=icm_lr_scale, reward_scale=icm_reward_scale, From 5567ce01e1174fc521d6fc4e63b968dd7db505bd Mon Sep 17 00:00:00 2001 From: daniel <1534513+dantp-ai@users.noreply.github.com> Date: Wed, 3 Apr 2024 00:26:47 +0200 Subject: [PATCH 109/115] Refactor type annotations of scale_obs: * it operates on Nets not Modules --- examples/atari/atari_network.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/examples/atari/atari_network.py b/examples/atari/atari_network.py index 4ec231d0a..e989dd04d 100644 --- a/examples/atari/atari_network.py +++ b/examples/atari/atari_network.py @@ -14,6 +14,7 @@ IntermediateModule, IntermediateModuleFactory, ) +from tianshou.utils.net.common import Net from tianshou.utils.net.discrete import Actor, NoisyLinear @@ -24,7 +25,7 @@ def layer_init(layer: nn.Module, std: float = np.sqrt(2), bias_const: float = 0. class ScaledObsInputModule(torch.nn.Module): - def __init__(self, module: torch.nn.Module, denom: float = 255.0) -> None: + def __init__(self, module: Net, denom: float = 255.0) -> None: super().__init__() self.module = module self.denom = denom @@ -42,7 +43,7 @@ def forward( return self.module.forward(obs / self.denom, state, info) -def scale_obs(module: nn.Module, denom: float = 255.0) -> nn.Module: +def scale_obs(module: Net, denom: float = 255.0) -> ScaledObsInputModule: return ScaledObsInputModule(module, denom=denom) @@ -266,7 +267,7 @@ def create_module(self, envs: Environments, device: TDevice) -> Actor: action_shape = envs.get_action_shape() if isinstance(action_shape, np.int64): action_shape = int(action_shape) - net: nn.Module + net: DQN | ScaledObsInputModule net = DQN( c=c, h=h, From 1c79d19cb0115393760987aadc7eadbc1b2885f2 Mon Sep 17 00:00:00 2001 From: daniel <1534513+dantp-ai@users.noreply.github.com> Date: Wed, 3 Apr 2024 00:28:49 +0200 Subject: [PATCH 110/115] Simplify checks of obs_shape for atari envs --- examples/atari/atari_network.py | 6 +----- 
examples/offline/atari_cql.py | 7 +++++-- 2 files changed, 6 insertions(+), 7 deletions(-) diff --git a/examples/atari/atari_network.py b/examples/atari/atari_network.py index e989dd04d..b1f9874b1 100644 --- a/examples/atari/atari_network.py +++ b/examples/atari/atari_network.py @@ -259,11 +259,7 @@ def __init__( self.features_only = features_only def create_module(self, envs: Environments, device: TDevice) -> Actor: - obs_shape = envs.get_observation_shape() - if isinstance(obs_shape, int): - obs_shape = [obs_shape] - assert len(obs_shape) == 3 - c, h, w = obs_shape + c, h, w = envs.get_observation_shape() # type: ignore # only right shape is a sequence of length 3 action_shape = envs.get_action_shape() if isinstance(action_shape, np.int64): action_shape = int(action_shape) diff --git a/examples/offline/atari_cql.py b/examples/offline/atari_cql.py index d1f9bc410..fc328761c 100644 --- a/examples/offline/atari_cql.py +++ b/examples/offline/atari_cql.py @@ -86,8 +86,11 @@ def test_discrete_cql(args: argparse.Namespace = get_args()) -> None: space_info = SpaceInfo.from_env(env) args.state_shape = space_info.observation_info.obs_shape args.action_shape = space_info.action_info.action_shape - assert isinstance(args.state_shape, list[int] | tuple[int]) - assert len(args.state_shape) == 3 + assert isinstance( + args.state_shape, + list[int] | tuple[int], + ), "state shape must be a sequence of ints." + assert len(args.state_shape) == 3, "state shape must have only 3 dimensions." 
c, h, w = args.state_shape # should be N_FRAMES x H x W print("Observations shape:", args.state_shape) From 2b6722f1c05419e837f54e2d3a3e14153a3faea2 Mon Sep 17 00:00:00 2001 From: daniel <1534513+dantp-ai@users.noreply.github.com> Date: Wed, 3 Apr 2024 00:31:17 +0200 Subject: [PATCH 111/115] Use kw for input arguments to QRDQN --- examples/offline/atari_cql.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/examples/offline/atari_cql.py b/examples/offline/atari_cql.py index fc328761c..0a619e23d 100644 --- a/examples/offline/atari_cql.py +++ b/examples/offline/atari_cql.py @@ -99,7 +99,14 @@ def test_discrete_cql(args: argparse.Namespace = get_args()) -> None: np.random.seed(args.seed) torch.manual_seed(args.seed) # model - net = QRDQN(c, h, w, args.action_shape, args.num_quantiles, args.device) + net = QRDQN( + c=c, + h=h, + w=w, + action_shape=args.action_shape, + num_quantiles=args.num_quantiles, + device=args.device, + ) optim = torch.optim.Adam(net.parameters(), lr=args.lr) # define policy policy: DiscreteCQLPolicy = DiscreteCQLPolicy( From e8ba5add5be18eae748494e38278cb76695c6ab0 Mon Sep 17 00:00:00 2001 From: Michael Panchenko Date: Wed, 3 Apr 2024 15:10:38 +0200 Subject: [PATCH 112/115] Made NetBase generic (explanation below), removed **kwargs from forward The generic was needed because of the type of "state", representing the hidden state The only place where it is ever used is in recurrent version of DQN --- examples/atari/atari_network.py | 21 +++++------ examples/atari/atari_qrdqn.py | 10 +++++- examples/box2d/acrobot_dualdqn.py | 4 +-- examples/box2d/bipedal_hardcore_sac.py | 17 +++++---- examples/box2d/lunarlander_dqn.py | 4 +-- examples/box2d/mcc_sac.py | 10 +++--- examples/inverse/irl_gail.py | 7 +++- examples/mujoco/mujoco_ddpg.py | 6 ++-- examples/mujoco/mujoco_redq.py | 6 ++-- examples/mujoco/mujoco_sac.py | 10 +++--- examples/mujoco/mujoco_td3.py | 10 +++--- examples/offline/atari_cql.py | 6 ++-- 
examples/offline/d4rl_bcq.py | 8 ++--- examples/offline/d4rl_cql.py | 12 +++---- examples/offline/d4rl_il.py | 4 +-- examples/offline/d4rl_td3_bc.py | 8 ++--- test/continuous/test_ddpg.py | 6 ++-- test/continuous/test_ppo.py | 4 +-- test/continuous/test_redq.py | 6 ++-- test/continuous/test_sac_with_il.py | 10 +++--- test/continuous/test_td3.py | 10 +++--- test/discrete/test_a2c_with_il.py | 4 +-- test/discrete/test_c51.py | 4 +-- test/discrete/test_dqn.py | 4 +-- test/discrete/test_pg.py | 4 +-- test/discrete/test_ppo.py | 2 +- test/discrete/test_qrdqn.py | 4 +-- test/discrete/test_rainbow.py | 4 +-- test/discrete/test_sac.py | 4 +-- test/modelbased/test_dqn_icm.py | 4 +-- test/modelbased/test_ppo_icm.py | 2 +- test/offline/gather_cartpole_data.py | 4 +-- test/offline/gather_pendulum_data.py | 6 ++-- test/offline/test_bcq.py | 4 +-- test/offline/test_cql.py | 8 ++--- test/offline/test_discrete_bcq.py | 2 +- test/offline/test_discrete_cql.py | 4 +-- test/offline/test_discrete_crr.py | 6 ++-- test/offline/test_gail.py | 13 ++++--- test/offline/test_td3_bc.py | 8 ++--- test/pettingzoo/pistonball.py | 4 +-- test/pettingzoo/tic_tac_toe.py | 4 +-- tianshou/highlevel/module/actor.py | 14 ++++---- tianshou/highlevel/module/critic.py | 6 ++-- tianshou/utils/net/common.py | 48 ++++++++++++++++---------- 45 files changed, 187 insertions(+), 159 deletions(-) diff --git a/examples/atari/atari_network.py b/examples/atari/atari_network.py index b1f9874b1..ea900e975 100644 --- a/examples/atari/atari_network.py +++ b/examples/atari/atari_network.py @@ -14,7 +14,7 @@ IntermediateModule, IntermediateModuleFactory, ) -from tianshou.utils.net.common import Net +from tianshou.utils.net.common import NetBase from tianshou.utils.net.discrete import Actor, NoisyLinear @@ -25,7 +25,7 @@ def layer_init(layer: nn.Module, std: float = np.sqrt(2), bias_const: float = 0. 
class ScaledObsInputModule(torch.nn.Module): - def __init__(self, module: Net, denom: float = 255.0) -> None: + def __init__(self, module: NetBase, denom: float = 255.0) -> None: super().__init__() self.module = module self.denom = denom @@ -43,11 +43,11 @@ def forward( return self.module.forward(obs / self.denom, state, info) -def scale_obs(module: Net, denom: float = 255.0) -> ScaledObsInputModule: +def scale_obs(module: NetBase, denom: float = 255.0) -> ScaledObsInputModule: return ScaledObsInputModule(module, denom=denom) -class DQN(nn.Module): +class DQN(NetBase[Any]): """Reference: Human-level control through deep reinforcement learning. For advanced usage (how to customize the network), please refer to @@ -106,10 +106,9 @@ def forward( obs: np.ndarray | torch.Tensor, state: Any | None = None, info: dict[str, Any] | None = None, + **kwargs: Any, ) -> tuple[torch.Tensor, Any]: r"""Mapping: s -> Q(s, \*).""" - if info is None: - info = {} obs = torch.as_tensor(obs, device=self.device, dtype=torch.float32) return self.net(obs), state @@ -139,10 +138,9 @@ def forward( obs: np.ndarray | torch.Tensor, state: Any | None = None, info: dict[str, Any] | None = None, + **kwargs: Any, ) -> tuple[torch.Tensor, Any]: r"""Mapping: x -> Z(x, \*).""" - if info is None: - info = {} obs, state = super().forward(obs) obs = obs.view(-1, self.num_atoms).softmax(dim=-1) obs = obs.view(-1, self.action_num, self.num_atoms) @@ -196,10 +194,9 @@ def forward( obs: np.ndarray | torch.Tensor, state: Any | None = None, info: dict[str, Any] | None = None, + **kwargs: Any, ) -> tuple[torch.Tensor, Any]: r"""Mapping: x -> Z(x, \*).""" - if info is None: - info = {} obs, state = super().forward(obs) q = self.Q(obs) q = q.view(-1, self.action_num, self.num_atoms) @@ -222,6 +219,7 @@ class QRDQN(DQN): def __init__( self, + *, c: int, h: int, w: int, @@ -238,10 +236,9 @@ def forward( obs: np.ndarray | torch.Tensor, state: Any | None = None, info: dict[str, Any] | None = None, + **kwargs: Any, ) 
-> tuple[torch.Tensor, Any]: r"""Mapping: x -> Z(x, \*).""" - if info is None: - info = {} obs, state = super().forward(obs) obs = obs.view(-1, self.action_num, self.num_quantiles) return obs, state diff --git a/examples/atari/atari_qrdqn.py b/examples/atari/atari_qrdqn.py index f33c5da5b..7d6330ee1 100644 --- a/examples/atari/atari_qrdqn.py +++ b/examples/atari/atari_qrdqn.py @@ -82,7 +82,15 @@ def test_qrdqn(args: argparse.Namespace = get_args()) -> None: np.random.seed(args.seed) torch.manual_seed(args.seed) # define model - net = QRDQN(*args.state_shape, args.action_shape, args.num_quantiles, args.device) + c, h, w = args.state_shape + net = QRDQN( + c=c, + h=h, + w=w, + action_shape=args.action_shape, + num_quantiles=args.num_quantiles, + device=args.device, + ) optim = torch.optim.Adam(net.parameters(), lr=args.lr) # define policy policy: QRDQNPolicy = QRDQNPolicy( diff --git a/examples/box2d/acrobot_dualdqn.py b/examples/box2d/acrobot_dualdqn.py index f28715bd5..ad53b16da 100644 --- a/examples/box2d/acrobot_dualdqn.py +++ b/examples/box2d/acrobot_dualdqn.py @@ -68,8 +68,8 @@ def test_dqn(args: argparse.Namespace = get_args()) -> None: Q_param = {"hidden_sizes": args.dueling_q_hidden_sizes} V_param = {"hidden_sizes": args.dueling_v_hidden_sizes} net = Net( - args.state_shape, - args.action_shape, + state_shape=args.state_shape, + action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, device=args.device, dueling_param=(Q_param, V_param), diff --git a/examples/box2d/bipedal_hardcore_sac.py b/examples/box2d/bipedal_hardcore_sac.py index e7186915b..2c071bc1c 100644 --- a/examples/box2d/bipedal_hardcore_sac.py +++ b/examples/box2d/bipedal_hardcore_sac.py @@ -108,13 +108,18 @@ def test_sac_bipedal(args: argparse.Namespace = get_args()) -> None: test_envs.seed(args.seed) # model - net_a = Net(args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device) - actor = ActorProb(net_a, args.action_shape, device=args.device, unbounded=True).to(args.device) 
+ net_a = Net(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device) + actor = ActorProb( + preprocess_net=net_a, + action_shape=args.action_shape, + device=args.device, + unbounded=True, + ).to(args.device) actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr) net_c1 = Net( - args.state_shape, - args.action_shape, + state_shape=args.state_shape, + action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, concat=True, device=args.device, @@ -123,8 +128,8 @@ def test_sac_bipedal(args: argparse.Namespace = get_args()) -> None: critic1_optim = torch.optim.Adam(critic1.parameters(), lr=args.critic_lr) net_c2 = Net( - args.state_shape, - args.action_shape, + state_shape=args.state_shape, + action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, concat=True, device=args.device, diff --git a/examples/box2d/lunarlander_dqn.py b/examples/box2d/lunarlander_dqn.py index 778c932cf..47ba9d102 100644 --- a/examples/box2d/lunarlander_dqn.py +++ b/examples/box2d/lunarlander_dqn.py @@ -70,8 +70,8 @@ def test_dqn(args: argparse.Namespace = get_args()) -> None: Q_param = {"hidden_sizes": args.dueling_q_hidden_sizes} V_param = {"hidden_sizes": args.dueling_v_hidden_sizes} net = Net( - args.state_shape, - args.action_shape, + state_shape=args.state_shape, + action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, device=args.device, dueling_param=(Q_param, V_param), diff --git a/examples/box2d/mcc_sac.py b/examples/box2d/mcc_sac.py index 858c70834..5c093093e 100644 --- a/examples/box2d/mcc_sac.py +++ b/examples/box2d/mcc_sac.py @@ -66,12 +66,12 @@ def test_sac(args: argparse.Namespace = get_args()) -> None: train_envs.seed(args.seed) test_envs.seed(args.seed) # model - net = Net(args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device) + net = Net(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device) actor = ActorProb(net, args.action_shape, device=args.device, 
unbounded=True).to(args.device) actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr) net_c1 = Net( - args.state_shape, - args.action_shape, + state_shape=args.state_shape, + action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, concat=True, device=args.device, @@ -79,8 +79,8 @@ def test_sac(args: argparse.Namespace = get_args()) -> None: critic1 = Critic(net_c1, device=args.device).to(args.device) critic1_optim = torch.optim.Adam(critic1.parameters(), lr=args.critic_lr) net_c2 = Net( - args.state_shape, - args.action_shape, + state_shape=args.state_shape, + action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, concat=True, device=args.device, diff --git a/examples/inverse/irl_gail.py b/examples/inverse/irl_gail.py index 3ee3709bd..2d013a01b 100644 --- a/examples/inverse/irl_gail.py +++ b/examples/inverse/irl_gail.py @@ -120,7 +120,12 @@ def test_gail(args: argparse.Namespace = get_args()) -> None: activation=nn.Tanh, device=args.device, ) - actor = ActorProb(net_a, args.action_shape, unbounded=True, device=args.device).to(args.device) + actor = ActorProb( + preprocess_net=net_a, + action_shape=args.action_shape, + unbounded=True, + device=args.device, + ).to(args.device) net_c = Net( args.state_shape, hidden_sizes=args.hidden_sizes, diff --git a/examples/mujoco/mujoco_ddpg.py b/examples/mujoco/mujoco_ddpg.py index ec065e728..ceac47604 100755 --- a/examples/mujoco/mujoco_ddpg.py +++ b/examples/mujoco/mujoco_ddpg.py @@ -83,14 +83,14 @@ def test_ddpg(args: argparse.Namespace = get_args()) -> None: np.random.seed(args.seed) torch.manual_seed(args.seed) # model - net_a = Net(args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device) + net_a = Net(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device) actor = Actor(net_a, args.action_shape, max_action=args.max_action, device=args.device).to( args.device, ) actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr) net_c = Net( - 
args.state_shape, - args.action_shape, + state_shape=args.state_shape, + action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, concat=True, device=args.device, diff --git a/examples/mujoco/mujoco_redq.py b/examples/mujoco/mujoco_redq.py index 8eb9c3658..b300e498a 100755 --- a/examples/mujoco/mujoco_redq.py +++ b/examples/mujoco/mujoco_redq.py @@ -86,7 +86,7 @@ def test_redq(args: argparse.Namespace = get_args()) -> None: np.random.seed(args.seed) torch.manual_seed(args.seed) # model - net_a = Net(args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device) + net_a = Net(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device) actor = ActorProb( net_a, args.action_shape, @@ -100,8 +100,8 @@ def linear(x: int, y: int) -> EnsembleLinear: return EnsembleLinear(args.ensemble_size, x, y) net_c = Net( - args.state_shape, - args.action_shape, + state_shape=args.state_shape, + action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, concat=True, device=args.device, diff --git a/examples/mujoco/mujoco_sac.py b/examples/mujoco/mujoco_sac.py index ed37a6025..2058a71e9 100755 --- a/examples/mujoco/mujoco_sac.py +++ b/examples/mujoco/mujoco_sac.py @@ -83,7 +83,7 @@ def test_sac(args: argparse.Namespace = get_args()) -> None: np.random.seed(args.seed) torch.manual_seed(args.seed) # model - net_a = Net(args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device) + net_a = Net(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device) actor = ActorProb( net_a, args.action_shape, @@ -93,15 +93,15 @@ def test_sac(args: argparse.Namespace = get_args()) -> None: ).to(args.device) actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr) net_c1 = Net( - args.state_shape, - args.action_shape, + state_shape=args.state_shape, + action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, concat=True, device=args.device, ) net_c2 = Net( - args.state_shape, - args.action_shape, + 
state_shape=args.state_shape, + action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, concat=True, device=args.device, diff --git a/examples/mujoco/mujoco_td3.py b/examples/mujoco/mujoco_td3.py index 82b5e3bcc..30e7539c1 100755 --- a/examples/mujoco/mujoco_td3.py +++ b/examples/mujoco/mujoco_td3.py @@ -88,21 +88,21 @@ def test_td3(args: argparse.Namespace = get_args()) -> None: np.random.seed(args.seed) torch.manual_seed(args.seed) # model - net_a = Net(args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device) + net_a = Net(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device) actor = Actor(net_a, args.action_shape, max_action=args.max_action, device=args.device).to( args.device, ) actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr) net_c1 = Net( - args.state_shape, - args.action_shape, + state_shape=args.state_shape, + action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, concat=True, device=args.device, ) net_c2 = Net( - args.state_shape, - args.action_shape, + state_shape=args.state_shape, + action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, concat=True, device=args.device, diff --git a/examples/offline/atari_cql.py b/examples/offline/atari_cql.py index 0a619e23d..40d91c1bb 100644 --- a/examples/offline/atari_cql.py +++ b/examples/offline/atari_cql.py @@ -6,6 +6,7 @@ import pickle import pprint import sys +from collections.abc import Sequence import numpy as np import torch @@ -86,10 +87,7 @@ def test_discrete_cql(args: argparse.Namespace = get_args()) -> None: space_info = SpaceInfo.from_env(env) args.state_shape = space_info.observation_info.obs_shape args.action_shape = space_info.action_info.action_shape - assert isinstance( - args.state_shape, - list[int] | tuple[int], - ), "state shape must be a sequence of ints." + assert isinstance(args.state_shape, Sequence) assert len(args.state_shape) == 3, "state shape must have only 3 dimensions." 
c, h, w = args.state_shape # should be N_FRAMES x H x W diff --git a/examples/offline/d4rl_bcq.py b/examples/offline/d4rl_bcq.py index 7c275b555..80b233cb7 100644 --- a/examples/offline/d4rl_bcq.py +++ b/examples/offline/d4rl_bcq.py @@ -109,15 +109,15 @@ def test_bcq() -> None: actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr) net_c1 = Net( - args.state_shape, - args.action_shape, + state_shape=args.state_shape, + action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, concat=True, device=args.device, ) net_c2 = Net( - args.state_shape, - args.action_shape, + state_shape=args.state_shape, + action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, concat=True, device=args.device, diff --git a/examples/offline/d4rl_cql.py b/examples/offline/d4rl_cql.py index 11386d683..7ca8ae2fb 100644 --- a/examples/offline/d4rl_cql.py +++ b/examples/offline/d4rl_cql.py @@ -244,8 +244,8 @@ def test_cql() -> None: # model # actor network net_a = Net( - args.state_shape, - args.action_shape, + state_shape=args.state_shape, + action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, device=args.device, ) @@ -260,15 +260,15 @@ def test_cql() -> None: # critic network net_c1 = Net( - args.state_shape, - args.action_shape, + state_shape=args.state_shape, + action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, concat=True, device=args.device, ) net_c2 = Net( - args.state_shape, - args.action_shape, + state_shape=args.state_shape, + action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, concat=True, device=args.device, diff --git a/examples/offline/d4rl_il.py b/examples/offline/d4rl_il.py index b7153ed11..c2152a711 100644 --- a/examples/offline/d4rl_il.py +++ b/examples/offline/d4rl_il.py @@ -83,8 +83,8 @@ def test_il() -> None: # model net = Net( - args.state_shape, - args.action_shape, + state_shape=args.state_shape, + action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, device=args.device, ) diff --git 
a/examples/offline/d4rl_td3_bc.py b/examples/offline/d4rl_td3_bc.py index b2a0b24c8..4d6159ff5 100644 --- a/examples/offline/d4rl_td3_bc.py +++ b/examples/offline/d4rl_td3_bc.py @@ -117,15 +117,15 @@ def test_td3_bc() -> None: # critic network net_c1 = Net( - args.state_shape, - args.action_shape, + state_shape=args.state_shape, + action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, concat=True, device=args.device, ) net_c2 = Net( - args.state_shape, - args.action_shape, + state_shape=args.state_shape, + action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, concat=True, device=args.device, diff --git a/test/continuous/test_ddpg.py b/test/continuous/test_ddpg.py index e2de17e85..a17c3b513 100644 --- a/test/continuous/test_ddpg.py +++ b/test/continuous/test_ddpg.py @@ -73,14 +73,14 @@ def test_ddpg(args: argparse.Namespace = get_args()) -> None: train_envs.seed(args.seed) test_envs.seed(args.seed) # model - net = Net(args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device) + net = Net(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device) actor = Actor(net, args.action_shape, max_action=args.max_action, device=args.device).to( args.device, ) actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr) net = Net( - args.state_shape, - args.action_shape, + state_shape=args.state_shape, + action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, concat=True, device=args.device, diff --git a/test/continuous/test_ppo.py b/test/continuous/test_ppo.py index 38ddbe8f0..5a522dedb 100644 --- a/test/continuous/test_ppo.py +++ b/test/continuous/test_ppo.py @@ -84,10 +84,10 @@ def test_ppo(args: argparse.Namespace = get_args()) -> None: train_envs.seed(args.seed) test_envs.seed(args.seed) # model - net = Net(args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device) + net = Net(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device) actor = ActorProb(net, 
args.action_shape, unbounded=True, device=args.device).to(args.device) critic = Critic( - Net(args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device), + Net(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device), device=args.device, ).to(args.device) actor_critic = ActorCritic(actor, critic) diff --git a/test/continuous/test_redq.py b/test/continuous/test_redq.py index f8e098d32..697b59e98 100644 --- a/test/continuous/test_redq.py +++ b/test/continuous/test_redq.py @@ -79,7 +79,7 @@ def test_redq(args: argparse.Namespace = get_args()) -> None: train_envs.seed(args.seed) test_envs.seed(args.seed) # model - net = Net(args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device) + net = Net(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device) actor = ActorProb( net, args.action_shape, @@ -93,8 +93,8 @@ def linear(x: int, y: int) -> nn.Module: return EnsembleLinear(args.ensemble_size, x, y) net_c = Net( - args.state_shape, - args.action_shape, + state_shape=args.state_shape, + action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, concat=True, device=args.device, diff --git a/test/continuous/test_sac_with_il.py b/test/continuous/test_sac_with_il.py index 5f09416e8..fd5b15a9f 100644 --- a/test/continuous/test_sac_with_il.py +++ b/test/continuous/test_sac_with_il.py @@ -81,12 +81,12 @@ def test_sac_with_il(args: argparse.Namespace = get_args()) -> None: train_envs.seed(args.seed) test_envs.seed(args.seed + args.training_num) # model - net = Net(args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device) + net = Net(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device) actor = ActorProb(net, args.action_shape, device=args.device, unbounded=True).to(args.device) actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr) net_c1 = Net( - args.state_shape, - args.action_shape, + state_shape=args.state_shape, + action_shape=args.action_shape, 
hidden_sizes=args.hidden_sizes, concat=True, device=args.device, @@ -94,8 +94,8 @@ def test_sac_with_il(args: argparse.Namespace = get_args()) -> None: critic1 = Critic(net_c1, device=args.device).to(args.device) critic1_optim = torch.optim.Adam(critic1.parameters(), lr=args.critic_lr) net_c2 = Net( - args.state_shape, - args.action_shape, + state_shape=args.state_shape, + action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, concat=True, device=args.device, diff --git a/test/continuous/test_td3.py b/test/continuous/test_td3.py index fb1e28a83..ea55da052 100644 --- a/test/continuous/test_td3.py +++ b/test/continuous/test_td3.py @@ -75,14 +75,14 @@ def test_td3(args: argparse.Namespace = get_args()) -> None: train_envs.seed(args.seed) test_envs.seed(args.seed) # model - net = Net(args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device) + net = Net(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device) actor = Actor(net, args.action_shape, max_action=args.max_action, device=args.device).to( args.device, ) actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr) net_c1 = Net( - args.state_shape, - args.action_shape, + state_shape=args.state_shape, + action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, concat=True, device=args.device, @@ -90,8 +90,8 @@ def test_td3(args: argparse.Namespace = get_args()) -> None: critic1 = Critic(net_c1, device=args.device).to(args.device) critic1_optim = torch.optim.Adam(critic1.parameters(), lr=args.critic_lr) net_c2 = Net( - args.state_shape, - args.action_shape, + state_shape=args.state_shape, + action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, concat=True, device=args.device, diff --git a/test/discrete/test_a2c_with_il.py b/test/discrete/test_a2c_with_il.py index 39435cc2d..f60857ea4 100644 --- a/test/discrete/test_a2c_with_il.py +++ b/test/discrete/test_a2c_with_il.py @@ -84,7 +84,7 @@ def test_a2c_with_il(args: argparse.Namespace = get_args()) 
-> None: np.random.seed(args.seed) torch.manual_seed(args.seed) # model - net = Net(args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device) + net = Net(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device) actor = Actor(net, args.action_shape, device=args.device).to(args.device) critic = Critic(net, device=args.device).to(args.device) optim = torch.optim.Adam(ActorCritic(actor, critic).parameters(), lr=args.lr) @@ -154,7 +154,7 @@ def stop_fn(mean_rewards: float) -> bool: # here we define an imitation collector with a trivial policy # if args.task == 'CartPole-v0': # env.spec.reward_threshold = 190 # lower the goal - net = Net(args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device) + net = Net(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device) actor = Actor(net, args.action_shape, device=args.device).to(args.device) optim = torch.optim.Adam(actor.parameters(), lr=args.il_lr) il_policy: ImitationPolicy = ImitationPolicy( diff --git a/test/discrete/test_c51.py b/test/discrete/test_c51.py index 0f1555859..483aca9c6 100644 --- a/test/discrete/test_c51.py +++ b/test/discrete/test_c51.py @@ -85,8 +85,8 @@ def test_c51(args: argparse.Namespace = get_args()) -> None: test_envs.seed(args.seed) # model net = Net( - args.state_shape, - args.action_shape, + state_shape=args.state_shape, + action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, device=args.device, softmax=True, diff --git a/test/discrete/test_dqn.py b/test/discrete/test_dqn.py index 52452b675..6c588839f 100644 --- a/test/discrete/test_dqn.py +++ b/test/discrete/test_dqn.py @@ -80,8 +80,8 @@ def test_dqn(args: argparse.Namespace = get_args()) -> None: # Q_param = V_param = {"hidden_sizes": [128]} # model net = Net( - args.state_shape, - args.action_shape, + state_shape=args.state_shape, + action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, device=args.device, # dueling=(Q_param, V_param), diff --git 
a/test/discrete/test_pg.py b/test/discrete/test_pg.py index 77b48b5ea..95db43c23 100644 --- a/test/discrete/test_pg.py +++ b/test/discrete/test_pg.py @@ -68,8 +68,8 @@ def test_pg(args: argparse.Namespace = get_args()) -> None: test_envs.seed(args.seed) # model net = Net( - args.state_shape, - args.action_shape, + state_shape=args.state_shape, + action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, device=args.device, softmax=True, diff --git a/test/discrete/test_ppo.py b/test/discrete/test_ppo.py index 63ef55122..132cbea5a 100644 --- a/test/discrete/test_ppo.py +++ b/test/discrete/test_ppo.py @@ -80,7 +80,7 @@ def test_ppo(args: argparse.Namespace = get_args()) -> None: train_envs.seed(args.seed) test_envs.seed(args.seed) # model - net = Net(args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device) + net = Net(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device) actor: nn.Module critic: nn.Module if torch.cuda.is_available(): diff --git a/test/discrete/test_qrdqn.py b/test/discrete/test_qrdqn.py index 081a88300..879717a75 100644 --- a/test/discrete/test_qrdqn.py +++ b/test/discrete/test_qrdqn.py @@ -80,8 +80,8 @@ def test_qrdqn(args: argparse.Namespace = get_args()) -> None: test_envs.seed(args.seed) # model net = Net( - args.state_shape, - args.action_shape, + state_shape=args.state_shape, + action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, device=args.device, softmax=False, diff --git a/test/discrete/test_rainbow.py b/test/discrete/test_rainbow.py index 23ec7de41..c7035345e 100644 --- a/test/discrete/test_rainbow.py +++ b/test/discrete/test_rainbow.py @@ -91,8 +91,8 @@ def noisy_linear(x: int, y: int) -> NoisyLinear: return NoisyLinear(x, y, args.noisy_std) net = Net( - args.state_shape, - args.action_shape, + state_shape=args.state_shape, + action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, device=args.device, softmax=True, diff --git a/test/discrete/test_sac.py 
b/test/discrete/test_sac.py index 8a397ef8c..b2f466f3d 100644 --- a/test/discrete/test_sac.py +++ b/test/discrete/test_sac.py @@ -76,10 +76,10 @@ def test_discrete_sac(args: argparse.Namespace = get_args()) -> None: # model obs_dim = space_info.observation_info.obs_dim action_dim = space_info.action_info.action_dim - net = Net(args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device) + net = Net(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device) actor = Actor(net, args.action_shape, softmax_output=False, device=args.device).to(args.device) actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr) - net_c1 = Net(args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device) + net_c1 = Net(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device) critic1 = Critic(net_c1, last_size=action_dim, device=args.device).to(args.device) critic1_optim = torch.optim.Adam(critic1.parameters(), lr=args.critic_lr) net_c2 = Net(obs_dim, hidden_sizes=args.hidden_sizes, device=args.device) diff --git a/test/modelbased/test_dqn_icm.py b/test/modelbased/test_dqn_icm.py index ee800a287..9ca9c7055 100644 --- a/test/modelbased/test_dqn_icm.py +++ b/test/modelbased/test_dqn_icm.py @@ -97,8 +97,8 @@ def test_dqn_icm(args: argparse.Namespace = get_args()) -> None: # Q_param = V_param = {"hidden_sizes": [128]} # model net = Net( - args.state_shape, - args.action_shape, + state_shape=args.state_shape, + action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, device=args.device, # dueling=(Q_param, V_param), diff --git a/test/modelbased/test_ppo_icm.py b/test/modelbased/test_ppo_icm.py index 3095f7cc9..ebf93cd5a 100644 --- a/test/modelbased/test_ppo_icm.py +++ b/test/modelbased/test_ppo_icm.py @@ -99,7 +99,7 @@ def test_ppo(args: argparse.Namespace = get_args()) -> None: train_envs.seed(args.seed) test_envs.seed(args.seed) # model - net = Net(args.state_shape, hidden_sizes=args.hidden_sizes, 
device=args.device) + net = Net(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device) actor = Actor(net, args.action_shape, device=args.device).to(args.device) critic = Critic(net, device=args.device).to(args.device) actor_critic = ActorCritic(actor, critic) diff --git a/test/offline/gather_cartpole_data.py b/test/offline/gather_cartpole_data.py index b11b35702..e8411b2dd 100644 --- a/test/offline/gather_cartpole_data.py +++ b/test/offline/gather_cartpole_data.py @@ -84,8 +84,8 @@ def gather_data() -> VectorReplayBuffer | PrioritizedVectorReplayBuffer: test_envs.seed(args.seed) # model net = Net( - args.state_shape, - args.action_shape, + state_shape=args.state_shape, + action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, device=args.device, softmax=False, diff --git a/test/offline/gather_pendulum_data.py b/test/offline/gather_pendulum_data.py index e7a76221a..bc46ce4da 100644 --- a/test/offline/gather_pendulum_data.py +++ b/test/offline/gather_pendulum_data.py @@ -91,7 +91,7 @@ def gather_data() -> VectorReplayBuffer: train_envs.seed(args.seed) test_envs.seed(args.seed) # model - net = Net(args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device) + net = Net(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device) actor = ActorProb( net, args.action_shape, @@ -100,8 +100,8 @@ def gather_data() -> VectorReplayBuffer: ).to(args.device) actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr) net_c = Net( - args.state_shape, - args.action_shape, + state_shape=args.state_shape, + action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, concat=True, device=args.device, diff --git a/test/offline/test_bcq.py b/test/offline/test_bcq.py index 425f70b25..1839d863a 100644 --- a/test/offline/test_bcq.py +++ b/test/offline/test_bcq.py @@ -114,8 +114,8 @@ def test_bcq(args: argparse.Namespace = get_args()) -> None: actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr) 
net_c = Net( - args.state_shape, - args.action_shape, + state_shape=args.state_shape, + action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, concat=True, device=args.device, diff --git a/test/offline/test_cql.py b/test/offline/test_cql.py index 93eb6e5de..1e31b1feb 100644 --- a/test/offline/test_cql.py +++ b/test/offline/test_cql.py @@ -109,8 +109,8 @@ def test_cql(args: argparse.Namespace = get_args()) -> None: # model # actor network net_a = Net( - args.state_shape, - args.action_shape, + state_shape=args.state_shape, + action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, device=args.device, ) @@ -125,8 +125,8 @@ def test_cql(args: argparse.Namespace = get_args()) -> None: # critic network net_c = Net( - args.state_shape, - args.action_shape, + state_shape=args.state_shape, + action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, concat=True, device=args.device, diff --git a/test/offline/test_discrete_bcq.py b/test/offline/test_discrete_bcq.py index e993e76c9..77790808b 100644 --- a/test/offline/test_discrete_bcq.py +++ b/test/offline/test_discrete_bcq.py @@ -72,7 +72,7 @@ def test_discrete_bcq(args: argparse.Namespace = get_args()) -> None: torch.manual_seed(args.seed) test_envs.seed(args.seed) # model - net = Net(args.state_shape, args.hidden_sizes[0], device=args.device) + net = Net(state_shape=args.state_shape, action_shape=args.hidden_sizes[0], device=args.device) policy_net = Actor( net, args.action_shape, diff --git a/test/offline/test_discrete_cql.py b/test/offline/test_discrete_cql.py index a678b8321..7323eac13 100644 --- a/test/offline/test_discrete_cql.py +++ b/test/offline/test_discrete_cql.py @@ -70,8 +70,8 @@ def test_discrete_cql(args: argparse.Namespace = get_args()) -> None: test_envs.seed(args.seed) # model net = Net( - args.state_shape, - args.action_shape, + state_shape=args.state_shape, + action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, device=args.device, softmax=False, diff --git 
a/test/offline/test_discrete_crr.py b/test/offline/test_discrete_crr.py index 6f762d746..b3cb64616 100644 --- a/test/offline/test_discrete_crr.py +++ b/test/offline/test_discrete_crr.py @@ -67,10 +67,10 @@ def test_discrete_crr(args: argparse.Namespace = get_args()) -> None: torch.manual_seed(args.seed) test_envs.seed(args.seed) # model - net = Net(args.state_shape, args.hidden_sizes[0], device=args.device) + net = Net(state_shape=args.state_shape, action_shape=args.hidden_sizes[0], device=args.device) actor = Actor( - net, - args.action_shape, + preprocess_net=net, + action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, device=args.device, softmax_output=False, diff --git a/test/offline/test_gail.py b/test/offline/test_gail.py index 26281bedd..256140c41 100644 --- a/test/offline/test_gail.py +++ b/test/offline/test_gail.py @@ -97,12 +97,17 @@ def test_gail(args: argparse.Namespace = get_args()) -> None: train_envs.seed(args.seed) test_envs.seed(args.seed) # model - net = Net(args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device) - actor = ActorProb(net, args.action_shape, max_action=args.max_action, device=args.device).to( + net = Net(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device) + actor = ActorProb( + preprocess_net=net, + action_shape=args.action_shape, + max_action=args.max_action, + device=args.device, + ).to( args.device, ) critic = Critic( - Net(args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device), + Net(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device), device=args.device, ).to(args.device) actor_critic = ActorCritic(actor, critic) @@ -115,7 +120,7 @@ def test_gail(args: argparse.Namespace = get_args()) -> None: # discriminator disc_net = Critic( Net( - args.state_shape, + state_shape=args.state_shape, action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, activation=torch.nn.Tanh, diff --git a/test/offline/test_td3_bc.py 
b/test/offline/test_td3_bc.py index 961af2ab3..18778563c 100644 --- a/test/offline/test_td3_bc.py +++ b/test/offline/test_td3_bc.py @@ -114,15 +114,15 @@ def test_td3_bc(args: argparse.Namespace = get_args()) -> None: # critic network net_c1 = Net( - args.state_shape, - args.action_shape, + state_shape=args.state_shape, + action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, concat=True, device=args.device, ) net_c2 = Net( - args.state_shape, - args.action_shape, + state_shape=args.state_shape, + action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, concat=True, device=args.device, diff --git a/test/pettingzoo/pistonball.py b/test/pettingzoo/pistonball.py index d68989adf..7b3fb4dfc 100644 --- a/test/pettingzoo/pistonball.py +++ b/test/pettingzoo/pistonball.py @@ -91,8 +91,8 @@ def get_agents( for _ in range(args.n_pistons): # model net = Net( - args.state_shape, - args.action_shape, + state_shape=args.state_shape, + action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, device=args.device, ).to(args.device) diff --git a/test/pettingzoo/tic_tac_toe.py b/test/pettingzoo/tic_tac_toe.py index 21720d05e..da580f358 100644 --- a/test/pettingzoo/tic_tac_toe.py +++ b/test/pettingzoo/tic_tac_toe.py @@ -108,8 +108,8 @@ def get_agents( if agent_learn is None: # model net = Net( - args.state_shape, - args.action_shape, + state_shape=args.state_shape, + action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, device=args.device, ).to(args.device) diff --git a/tianshou/highlevel/module/actor.py b/tianshou/highlevel/module/actor.py index faaaa68d0..867ece17a 100644 --- a/tianshou/highlevel/module/actor.py +++ b/tianshou/highlevel/module/actor.py @@ -147,14 +147,14 @@ def __init__(self, hidden_sizes: Sequence[int], activation: ModuleType = nn.ReLU def create_module(self, envs: Environments, device: TDevice) -> BaseActor: net_a = Net( - envs.get_observation_shape(), + state_shape=envs.get_observation_shape(), hidden_sizes=self.hidden_sizes, 
activation=self.activation, device=device, ) return continuous.Actor( - net_a, - envs.get_action_shape(), + preprocess_net=net_a, + action_shape=envs.get_action_shape(), hidden_sizes=(), device=device, ).to(device) @@ -182,14 +182,14 @@ def __init__( def create_module(self, envs: Environments, device: TDevice) -> BaseActor: net_a = Net( - envs.get_observation_shape(), + state_shape=envs.get_observation_shape(), hidden_sizes=self.hidden_sizes, activation=self.activation, device=device, ) actor = continuous.ActorProb( - net_a, - envs.get_action_shape(), + preprocess_net=net_a, + action_shape=envs.get_action_shape(), unbounded=self.unbounded, device=device, conditioned_sigma=self.conditioned_sigma, @@ -216,7 +216,7 @@ def __init__( def create_module(self, envs: Environments, device: TDevice) -> BaseActor: net_a = Net( - envs.get_observation_shape(), + state_shape=envs.get_observation_shape(), hidden_sizes=self.hidden_sizes, activation=self.activation, device=device, diff --git a/tianshou/highlevel/module/critic.py b/tianshou/highlevel/module/critic.py index 6d3a7b107..f1984e4d7 100644 --- a/tianshou/highlevel/module/critic.py +++ b/tianshou/highlevel/module/critic.py @@ -120,7 +120,7 @@ def create_module( ) -> nn.Module: action_shape = envs.get_action_shape() if use_action else 0 net_c = Net( - envs.get_observation_shape(), + state_shape=envs.get_observation_shape(), action_shape=action_shape, hidden_sizes=self.hidden_sizes, concat=use_action, @@ -146,7 +146,7 @@ def create_module( ) -> nn.Module: action_shape = envs.get_action_shape() if use_action else 0 net_c = Net( - envs.get_observation_shape(), + state_shape=envs.get_observation_shape(), action_shape=action_shape, hidden_sizes=self.hidden_sizes, concat=use_action, @@ -275,7 +275,7 @@ def linear_layer(x: int, y: int) -> EnsembleLinear: action_shape = envs.get_action_shape() if use_action else 0 net_c = Net( - envs.get_observation_shape(), + state_shape=envs.get_observation_shape(), action_shape=action_shape, 
hidden_sizes=self.hidden_sizes, concat=use_action, diff --git a/tianshou/utils/net/common.py b/tianshou/utils/net/common.py index dabe24e75..3fb300261 100644 --- a/tianshou/utils/net/common.py +++ b/tianshou/utils/net/common.py @@ -1,6 +1,6 @@ from abc import ABC, abstractmethod from collections.abc import Callable, Sequence -from typing import Any, TypeAlias, TypeVar, no_type_check +from typing import Any, Generic, TypeAlias, TypeVar, cast, no_type_check import numpy as np import torch @@ -140,20 +140,23 @@ def forward(self, obs: np.ndarray | torch.Tensor) -> torch.Tensor: return self.model(obs) -class NetBase(nn.Module, ABC): +TRecurrentState = TypeVar("TRecurrentState", bound=Any) + + +class NetBase(nn.Module, Generic[TRecurrentState], ABC): """Interface for NNs used in policies.""" @abstractmethod def forward( self, obs: np.ndarray | torch.Tensor, - state: Any = None, - **kwargs: Any, - ) -> tuple[torch.Tensor, Any]: + state: TRecurrentState | None = None, + info: dict[str, Any] | None = None, + ) -> tuple[torch.Tensor, TRecurrentState | None]: pass -class Net(NetBase): +class Net(NetBase[Any]): """Wrapper of MLP to support more specific DRL usage. For advanced usage (how to customize the network), please refer to @@ -259,13 +262,13 @@ def forward( self, obs: np.ndarray | torch.Tensor, state: Any = None, - **kwargs: Any, + info: dict[str, Any] | None = None, ) -> tuple[torch.Tensor, Any]: """Mapping: obs -> flatten (inside MLP)-> logits. :param obs: :param state: unused and returned as is - :param kwargs: unused + :param info: unused """ logits = self.model(obs) batch_size = logits.shape[0] @@ -284,7 +287,7 @@ def forward( return logits, state -class Recurrent(NetBase): +class Recurrent(NetBase[RecurrentStateBatch]): """Simple Recurrent network based on LSTM. 
For advanced usage (how to customize the network), please refer to @@ -313,9 +316,9 @@ def __init__( def forward( self, obs: np.ndarray | torch.Tensor, - state: RecurrentStateBatch | dict[str, torch.Tensor] | None = None, - **kwargs: Any, - ) -> tuple[torch.Tensor, dict[str, torch.Tensor]]: + state: RecurrentStateBatch | None = None, + info: dict[str, Any] | None = None, + ) -> tuple[torch.Tensor, RecurrentStateBatch]: """Mapping: obs -> flatten -> logits. In the evaluation mode, `obs` should be with shape ``[bsz, dim]``; in the @@ -324,7 +327,7 @@ def forward( :param obs: :param state: either None or a dict with keys 'hidden' and 'cell' - :param kwargs: unused + :param info: unused :return: predicted action, next state as dict with keys 'hidden' and 'cell' """ # Note: the original type of state is Batch but it might also be a dict @@ -357,10 +360,16 @@ def forward( ) obs = self.fc2(obs[:, -1]) # please ensure the first dim is batch size: [bsz, len, ...] - return obs, { - "hidden": hidden.transpose(0, 1).detach(), - "cell": cell.transpose(0, 1).detach(), - } + rnn_state_batch = cast( + RecurrentStateBatch, + Batch( + { + "hidden": hidden.transpose(0, 1).detach(), + "cell": cell.transpose(0, 1).detach(), + }, + ), + ) + return obs, rnn_state_batch class ActorCritic(nn.Module): @@ -439,7 +448,8 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return x -class BranchingNet(NetBase): +#TODO: fix docstring +class BranchingNet(NetBase[Any]): """Branching dual Q network. 
Network for the BranchingDQNPolicy, it uses a common network module, a value module @@ -539,7 +549,7 @@ def forward( self, obs: np.ndarray | torch.Tensor, state: Any = None, - **kwargs: Any, + info: dict[str, Any] | None = None, ) -> tuple[torch.Tensor, Any]: """Mapping: obs -> model -> logits.""" common_out = self.common(obs) From e4d7d2f6a3db8ae63c7b3a5728d52076d2d4455a Mon Sep 17 00:00:00 2001 From: Michael Panchenko Date: Wed, 3 Apr 2024 15:46:23 +0200 Subject: [PATCH 113/115] SamplingConfig: support for batch_size=None --- examples/mujoco/mujoco_reinforce_hl.py | 2 +- tianshou/highlevel/config.py | 7 ++++++- tianshou/utils/net/common.py | 2 +- 3 files changed, 8 insertions(+), 3 deletions(-) diff --git a/examples/mujoco/mujoco_reinforce_hl.py b/examples/mujoco/mujoco_reinforce_hl.py index 5651ee1b8..bc07e050b 100644 --- a/examples/mujoco/mujoco_reinforce_hl.py +++ b/examples/mujoco/mujoco_reinforce_hl.py @@ -29,7 +29,7 @@ def main( step_per_epoch: int = 30000, step_per_collect: int = 2048, repeat_per_collect: int = 1, - batch_size: int = 16, + batch_size: int | None = None, training_num: int = 10, test_num: int = 10, rew_norm: bool = True, diff --git a/tianshou/highlevel/config.py b/tianshou/highlevel/config.py index 48dde374d..43a4db2e9 100644 --- a/tianshou/highlevel/config.py +++ b/tianshou/highlevel/config.py @@ -37,12 +37,17 @@ class SamplingConfig(ToStringMixin): an explanation of epoch semantics. """ - batch_size: int = 64 + batch_size: int | None = 64 """for off-policy algorithms, this is the number of environment steps/transitions to sample from the buffer for a gradient update; for on-policy algorithms, its use is algorithm-specific. On-policy algorithms use the full buffer that was collected in the preceding collection step but they may use this parameter to perform the gradient update using mini-batches of this size (causing the gradient to be less accurate, a form of regularization). 
+ + ``batch_size=None`` means that the full buffer is used for the gradient update. This rarely + makes sense for off-policy algorithms and is not recommended for them. For on-policy or offline algorithms, + it disables mini-batching (the whole buffer is used in a single gradient update), which + may make sense in some cases. """ num_train_envs: int = -1 diff --git a/tianshou/utils/net/common.py b/tianshou/utils/net/common.py index 3fb300261..eceee100f 100644 --- a/tianshou/utils/net/common.py +++ b/tianshou/utils/net/common.py @@ -448,7 +448,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return x -#TODO: fix docstring +# TODO: fix docstring class BranchingNet(NetBase[Any]): """Branching dual Q network. From c1a4b409b1f2f29a27cbca568d3c1a8bc2acf0ea Mon Sep 17 00:00:00 2001 From: Michael Panchenko Date: Wed, 3 Apr 2024 15:53:12 +0200 Subject: [PATCH 114/115] Changelog [skip ci] --- CHANGELOG.md | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index fb786988d..905bc13b8 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,7 @@ - `Collector`s can now be closed, and their reset is more granular. #1063 - Trainers can control whether collectors should be reset prior to training. #1063 - Convenience constructor for `CollectStats` called `with_autogenerated_stats`. #1063 +- `SamplingConfig` supports `batch_size=None`. #1077 ### Internal Improvements - `Collector`s rely less on state, the few stateful things are stored explicitly instead of through a `.data` attribute. #1063 @@ -19,6 +20,8 @@ instead of just `nn.Module`. #1032 - Added interfaces for most `Actor` and `Critic` classes to enforce the presence of `forward` methods. #1032 - Simplified `PGPolicy` forward by unifying the `dist_fn` interface (see associated breaking change). #1032 - Use `.mode` of distribution instead of relying on knowledge of the distribution type. #1032 +- Tests and examples are now covered by `mypy`.
#1077 +- `NetBase` is used more widely and is now generic, which enables stricter typing. #1077 ### Breaking Changes @@ -29,6 +32,7 @@ expicitly or pass `reset_before_collect=True` . #1063 - Fixed `iter(Batch(...)` which now behaves the same way as `Batch(...).__iter__()`. Can be considered a bugfix. #1063 - Changed interface of `dist_fn` in `PGPolicy` and all subclasses to take a single argument in both continuous and discrete cases. #1032 +- `utils.net.common.Recurrent` now receives and returns a `RecurrentStateBatch` instead of a dict. #1077 ### Tests - Fixed env seeding it test_sac_with_il.py so that the test doesn't fail randomly. #1081 From 4c34a45392de7c9a63cd781aa6b39d19fab5253f Mon Sep 17 00:00:00 2001 From: Michael Panchenko Date: Wed, 3 Apr 2024 15:57:40 +0200 Subject: [PATCH 115/115] Changelog [skip ci] --- CHANGELOG.md | 1 - 1 file changed, 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 0e1f1b1cc..126f81a89 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -40,4 +40,3 @@ continuous and discrete cases. #1032 Started after v1.0.0 -