From af56ffda554a710c1546720d2536ea6d668fef1f Mon Sep 17 00:00:00 2001 From: Dominik Jain Date: Mon, 3 Mar 2025 19:03:30 +0100 Subject: [PATCH 001/230] v2: Separation of Policy and Algorithm, initial implementation of PG/Reinforce --- examples/atari/atari_c51.py | 4 +- examples/atari/atari_fqf.py | 4 +- examples/atari/atari_iqn.py | 4 +- examples/atari/atari_ppo.py | 4 +- examples/atari/atari_qrdqn.py | 4 +- examples/atari/atari_rainbow.py | 4 +- examples/atari/atari_sac.py | 4 +- examples/box2d/acrobot_dualdqn.py | 4 +- examples/box2d/bipedal_bdq.py | 4 +- examples/box2d/bipedal_hardcore_sac.py | 4 +- examples/box2d/lunarlander_dqn.py | 4 +- examples/box2d/mcc_sac.py | 4 +- examples/inverse/irl_gail.py | 4 +- examples/mujoco/fetch_her_ddpg.py | 4 +- examples/mujoco/mujoco_a2c.py | 4 +- examples/mujoco/mujoco_ddpg.py | 4 +- examples/mujoco/mujoco_npg.py | 4 +- examples/mujoco/mujoco_ppo.py | 4 +- examples/mujoco/mujoco_redq.py | 4 +- examples/mujoco/mujoco_reinforce.py | 8 +- examples/mujoco/mujoco_sac.py | 4 +- examples/mujoco/mujoco_td3.py | 4 +- examples/mujoco/mujoco_trpo.py | 4 +- examples/offline/atari_bcq.py | 4 +- examples/offline/atari_cql.py | 4 +- examples/offline/atari_crr.py | 4 +- examples/offline/atari_il.py | 4 +- examples/offline/d4rl_bcq.py | 4 +- examples/offline/d4rl_cql.py | 4 +- examples/offline/d4rl_il.py | 4 +- examples/offline/d4rl_td3_bc.py | 4 +- examples/vizdoom/vizdoom_c51.py | 4 +- examples/vizdoom/vizdoom_ppo.py | 4 +- test/base/test_collector.py | 7 +- test/base/test_env_finite.py | 6 +- test/base/test_policy.py | 4 +- test/base/test_returns.py | 28 +- test/continuous/test_ddpg.py | 4 +- test/continuous/test_npg.py | 4 +- test/continuous/test_ppo.py | 4 +- test/continuous/test_redq.py | 4 +- test/continuous/test_sac_with_il.py | 4 +- test/continuous/test_td3.py | 4 +- test/continuous/test_trpo.py | 4 +- test/discrete/test_a2c_with_il.py | 6 +- test/discrete/test_c51.py | 4 +- test/discrete/test_dqn.py | 4 +- 
test/discrete/test_drqn.py | 4 +- test/discrete/test_fqf.py | 4 +- test/discrete/test_iqn.py | 4 +- test/discrete/test_pg.py | 24 +- test/discrete/test_ppo.py | 4 +- test/discrete/test_qrdqn.py | 4 +- test/discrete/test_rainbow.py | 4 +- test/discrete/test_sac.py | 4 +- test/modelbased/test_dqn_icm.py | 4 +- test/modelbased/test_ppo_icm.py | 4 +- test/offline/gather_cartpole_data.py | 4 +- test/offline/gather_pendulum_data.py | 4 +- test/offline/test_bcq.py | 4 +- test/offline/test_cql.py | 4 +- test/offline/test_discrete_bcq.py | 4 +- test/offline/test_discrete_cql.py | 4 +- test/offline/test_discrete_crr.py | 4 +- test/offline/test_gail.py | 4 +- test/offline/test_td3_bc.py | 4 +- test/pettingzoo/pistonball.py | 14 +- test/pettingzoo/pistonball_continuous.py | 14 +- tianshou/data/collector.py | 26 +- tianshou/highlevel/agent.py | 20 +- tianshou/highlevel/experiment.py | 4 +- tianshou/highlevel/params/policy_wrapper.py | 8 +- tianshou/highlevel/trainer.py | 4 +- tianshou/highlevel/world.py | 4 +- tianshou/policy/__init__.py | 67 ++-- tianshou/policy/base.py | 320 +++++++++++--------- tianshou/policy/imitation/base.py | 6 +- tianshou/policy/imitation/bcq.py | 11 +- tianshou/policy/imitation/cql.py | 2 +- tianshou/policy/imitation/discrete_bcq.py | 2 +- tianshou/policy/imitation/discrete_cql.py | 2 +- tianshou/policy/imitation/discrete_crr.py | 6 +- tianshou/policy/imitation/gail.py | 2 +- tianshou/policy/imitation/td3_bc.py | 2 +- tianshou/policy/modelbased/icm.py | 10 +- tianshou/policy/modelbased/psrl.py | 11 +- tianshou/policy/modelfree/a2c.py | 6 +- tianshou/policy/modelfree/bdq.py | 7 +- tianshou/policy/modelfree/c51.py | 7 +- tianshou/policy/modelfree/ddpg.py | 6 +- tianshou/policy/modelfree/discrete_sac.py | 2 +- tianshou/policy/modelfree/dqn.py | 7 +- tianshou/policy/modelfree/fqf.py | 7 +- tianshou/policy/modelfree/iqn.py | 7 +- tianshou/policy/modelfree/npg.py | 2 +- tianshou/policy/modelfree/pg.py | 144 +++++---- tianshou/policy/modelfree/ppo.py | 2 +- 
tianshou/policy/modelfree/qrdqn.py | 7 +- tianshou/policy/modelfree/rainbow.py | 4 +- tianshou/policy/modelfree/redq.py | 2 +- tianshou/policy/modelfree/sac.py | 2 +- tianshou/policy/modelfree/td3.py | 2 +- tianshou/policy/modelfree/trpo.py | 2 +- tianshou/policy/multiagent/mapolicy.py | 14 +- tianshou/policy/random.py | 6 +- tianshou/trainer/base.py | 8 +- tianshou/utils/torch_utils.py | 4 +- 107 files changed, 599 insertions(+), 503 deletions(-) diff --git a/examples/atari/atari_c51.py b/examples/atari/atari_c51.py index 3c757f9b6..326b89ea2 100644 --- a/examples/atari/atari_c51.py +++ b/examples/atari/atari_c51.py @@ -12,7 +12,7 @@ from tianshou.data import Collector, CollectStats, VectorReplayBuffer from tianshou.highlevel.logger import LoggerFactoryDefault from tianshou.policy import C51Policy -from tianshou.policy.base import BasePolicy +from tianshou.policy.base import Algorithm from tianshou.trainer import OffpolicyTrainer @@ -136,7 +136,7 @@ def test_c51(args: argparse.Namespace = get_args()) -> None: config_dict=vars(args), ) - def save_best_fn(policy: BasePolicy) -> None: + def save_best_fn(policy: Algorithm) -> None: torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth")) def stop_fn(mean_rewards: float) -> bool: diff --git a/examples/atari/atari_fqf.py b/examples/atari/atari_fqf.py index c25002613..0ac0db560 100644 --- a/examples/atari/atari_fqf.py +++ b/examples/atari/atari_fqf.py @@ -12,7 +12,7 @@ from tianshou.data import Collector, CollectStats, VectorReplayBuffer from tianshou.highlevel.logger import LoggerFactoryDefault from tianshou.policy import FQFPolicy -from tianshou.policy.base import BasePolicy +from tianshou.policy.base import Algorithm from tianshou.trainer import OffpolicyTrainer from tianshou.utils.net.discrete import FractionProposalNetwork, FullQuantileFunction @@ -149,7 +149,7 @@ def test_fqf(args: argparse.Namespace = get_args()) -> None: config_dict=vars(args), ) - def save_best_fn(policy: BasePolicy) -> None: + def 
save_best_fn(policy: Algorithm) -> None: torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth")) def stop_fn(mean_rewards: float) -> bool: diff --git a/examples/atari/atari_iqn.py b/examples/atari/atari_iqn.py index 8a30ca75d..3d6ad57c4 100644 --- a/examples/atari/atari_iqn.py +++ b/examples/atari/atari_iqn.py @@ -12,7 +12,7 @@ from tianshou.data import Collector, CollectStats, VectorReplayBuffer from tianshou.highlevel.logger import LoggerFactoryDefault from tianshou.policy import IQNPolicy -from tianshou.policy.base import BasePolicy +from tianshou.policy.base import Algorithm from tianshou.trainer import OffpolicyTrainer from tianshou.utils.net.discrete import ImplicitQuantileNetwork @@ -146,7 +146,7 @@ def test_iqn(args: argparse.Namespace = get_args()) -> None: config_dict=vars(args), ) - def save_best_fn(policy: BasePolicy) -> None: + def save_best_fn(policy: Algorithm) -> None: torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth")) def stop_fn(mean_rewards: float) -> bool: diff --git a/examples/atari/atari_ppo.py b/examples/atari/atari_ppo.py index 2f3832d23..3699aa7f1 100644 --- a/examples/atari/atari_ppo.py +++ b/examples/atari/atari_ppo.py @@ -14,7 +14,7 @@ from tianshou.data import Collector, CollectStats, VectorReplayBuffer from tianshou.highlevel.logger import LoggerFactoryDefault from tianshou.policy import ICMPolicy, PPOPolicy -from tianshou.policy.base import BasePolicy +from tianshou.policy.base import Algorithm from tianshou.trainer import OnpolicyTrainer from tianshou.utils.net.common import ActorCritic from tianshou.utils.net.discrete import Actor, Critic, IntrinsicCuriosityModule @@ -214,7 +214,7 @@ def dist(logits: torch.Tensor) -> Categorical: config_dict=vars(args), ) - def save_best_fn(policy: BasePolicy) -> None: + def save_best_fn(policy: Algorithm) -> None: torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth")) def stop_fn(mean_rewards: float) -> bool: diff --git a/examples/atari/atari_qrdqn.py 
b/examples/atari/atari_qrdqn.py index e47c08d92..c5a658b08 100644 --- a/examples/atari/atari_qrdqn.py +++ b/examples/atari/atari_qrdqn.py @@ -12,7 +12,7 @@ from tianshou.data import Collector, CollectStats, VectorReplayBuffer from tianshou.highlevel.logger import LoggerFactoryDefault from tianshou.policy import QRDQNPolicy -from tianshou.policy.base import BasePolicy +from tianshou.policy.base import Algorithm from tianshou.trainer import OffpolicyTrainer @@ -140,7 +140,7 @@ def test_qrdqn(args: argparse.Namespace = get_args()) -> None: config_dict=vars(args), ) - def save_best_fn(policy: BasePolicy) -> None: + def save_best_fn(policy: Algorithm) -> None: torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth")) def stop_fn(mean_rewards: float) -> bool: diff --git a/examples/atari/atari_rainbow.py b/examples/atari/atari_rainbow.py index 5373d0536..60f07140c 100644 --- a/examples/atari/atari_rainbow.py +++ b/examples/atari/atari_rainbow.py @@ -17,7 +17,7 @@ ) from tianshou.highlevel.logger import LoggerFactoryDefault from tianshou.policy import C51Policy, RainbowPolicy -from tianshou.policy.base import BasePolicy +from tianshou.policy.base import Algorithm from tianshou.trainer import OffpolicyTrainer @@ -171,7 +171,7 @@ def test_rainbow(args: argparse.Namespace = get_args()) -> None: config_dict=vars(args), ) - def save_best_fn(policy: BasePolicy) -> None: + def save_best_fn(policy: Algorithm) -> None: torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth")) def stop_fn(mean_rewards: float) -> bool: diff --git a/examples/atari/atari_sac.py b/examples/atari/atari_sac.py index df43e49ac..48984b5b2 100644 --- a/examples/atari/atari_sac.py +++ b/examples/atari/atari_sac.py @@ -12,7 +12,7 @@ from tianshou.data import Collector, CollectStats, VectorReplayBuffer from tianshou.highlevel.logger import LoggerFactoryDefault from tianshou.policy import DiscreteSACPolicy, ICMPolicy -from tianshou.policy.base import BasePolicy +from tianshou.policy.base 
import Algorithm from tianshou.trainer import OffpolicyTrainer from tianshou.utils.net.discrete import Actor, Critic, IntrinsicCuriosityModule @@ -197,7 +197,7 @@ def test_discrete_sac(args: argparse.Namespace = get_args()) -> None: config_dict=vars(args), ) - def save_best_fn(policy: BasePolicy) -> None: + def save_best_fn(policy: Algorithm) -> None: torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth")) def stop_fn(mean_rewards: float) -> bool: diff --git a/examples/box2d/acrobot_dualdqn.py b/examples/box2d/acrobot_dualdqn.py index b25f35c15..d1709a9ec 100644 --- a/examples/box2d/acrobot_dualdqn.py +++ b/examples/box2d/acrobot_dualdqn.py @@ -10,7 +10,7 @@ from tianshou.data import Collector, CollectStats, VectorReplayBuffer from tianshou.env import DummyVectorEnv from tianshou.policy import DQNPolicy -from tianshou.policy.base import BasePolicy +from tianshou.policy.base import Algorithm from tianshou.trainer import OffpolicyTrainer from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Net @@ -99,7 +99,7 @@ def test_dqn(args: argparse.Namespace = get_args()) -> None: writer = SummaryWriter(log_path) logger = TensorboardLogger(writer) - def save_best_fn(policy: BasePolicy) -> None: + def save_best_fn(policy: Algorithm) -> None: torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth")) def stop_fn(mean_rewards: float) -> bool: diff --git a/examples/box2d/bipedal_bdq.py b/examples/box2d/bipedal_bdq.py index d88379b23..fff509d5f 100644 --- a/examples/box2d/bipedal_bdq.py +++ b/examples/box2d/bipedal_bdq.py @@ -11,7 +11,7 @@ from tianshou.data import Collector, CollectStats, VectorReplayBuffer from tianshou.env import ContinuousToDiscrete, SubprocVectorEnv from tianshou.policy import BranchingDQNPolicy -from tianshou.policy.base import BasePolicy +from tianshou.policy.base import Algorithm from tianshou.trainer import OffpolicyTrainer from tianshou.utils import TensorboardLogger from tianshou.utils.net.common 
import BranchingNet @@ -125,7 +125,7 @@ def test_bdq(args: argparse.Namespace = get_args()) -> None: writer = SummaryWriter(log_path) logger = TensorboardLogger(writer) - def save_best_fn(policy: BasePolicy) -> None: + def save_best_fn(policy: Algorithm) -> None: torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth")) def stop_fn(mean_rewards: float) -> bool: diff --git a/examples/box2d/bipedal_hardcore_sac.py b/examples/box2d/bipedal_hardcore_sac.py index b377d7bb1..9c46b4dc6 100644 --- a/examples/box2d/bipedal_hardcore_sac.py +++ b/examples/box2d/bipedal_hardcore_sac.py @@ -12,7 +12,7 @@ from tianshou.data import Collector, CollectStats, VectorReplayBuffer from tianshou.env import SubprocVectorEnv from tianshou.policy import SACPolicy -from tianshou.policy.base import BasePolicy +from tianshou.policy.base import Algorithm from tianshou.trainer import OffpolicyTrainer from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Net @@ -176,7 +176,7 @@ def test_sac_bipedal(args: argparse.Namespace = get_args()) -> None: writer = SummaryWriter(log_path) logger = TensorboardLogger(writer) - def save_best_fn(policy: BasePolicy) -> None: + def save_best_fn(policy: Algorithm) -> None: torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth")) def stop_fn(mean_rewards: float) -> bool: diff --git a/examples/box2d/lunarlander_dqn.py b/examples/box2d/lunarlander_dqn.py index 347da2cf9..fea3096ba 100644 --- a/examples/box2d/lunarlander_dqn.py +++ b/examples/box2d/lunarlander_dqn.py @@ -10,7 +10,7 @@ from tianshou.data import Collector, CollectStats, VectorReplayBuffer from tianshou.env import DummyVectorEnv, SubprocVectorEnv from tianshou.policy import DQNPolicy -from tianshou.policy.base import BasePolicy +from tianshou.policy.base import Algorithm from tianshou.trainer import OffpolicyTrainer from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Net @@ -101,7 +101,7 @@ def test_dqn(args: 
argparse.Namespace = get_args()) -> None: writer = SummaryWriter(log_path) logger = TensorboardLogger(writer) - def save_best_fn(policy: BasePolicy) -> None: + def save_best_fn(policy: Algorithm) -> None: torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth")) def stop_fn(mean_rewards: float) -> bool: diff --git a/examples/box2d/mcc_sac.py b/examples/box2d/mcc_sac.py index 452eb02d6..838c40a37 100644 --- a/examples/box2d/mcc_sac.py +++ b/examples/box2d/mcc_sac.py @@ -11,7 +11,7 @@ from tianshou.env import DummyVectorEnv from tianshou.exploration import OUNoise from tianshou.policy import SACPolicy -from tianshou.policy.base import BasePolicy +from tianshou.policy.base import Algorithm from tianshou.trainer import OffpolicyTrainer from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Net @@ -122,7 +122,7 @@ def test_sac(args: argparse.Namespace = get_args()) -> None: writer = SummaryWriter(log_path) logger = TensorboardLogger(writer) - def save_best_fn(policy: BasePolicy) -> None: + def save_best_fn(policy: Algorithm) -> None: torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth")) def stop_fn(mean_rewards: float) -> bool: diff --git a/examples/inverse/irl_gail.py b/examples/inverse/irl_gail.py index 815060d1c..10c10de1d 100644 --- a/examples/inverse/irl_gail.py +++ b/examples/inverse/irl_gail.py @@ -25,7 +25,7 @@ from tianshou.data.types import RolloutBatchProtocol from tianshou.env import SubprocVectorEnv, VectorEnvNormObs from tianshou.policy import GAILPolicy -from tianshou.policy.base import BasePolicy +from tianshou.policy.base import Algorithm from tianshou.trainer import OnpolicyTrainer from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import ActorCritic, Net @@ -252,7 +252,7 @@ def dist(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution: writer.add_text("args", str(args)) logger = TensorboardLogger(writer, update_interval=100, train_interval=100) - def 
save_best_fn(policy: BasePolicy) -> None: + def save_best_fn(policy: Algorithm) -> None: torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth")) if not args.watch: diff --git a/examples/mujoco/fetch_her_ddpg.py b/examples/mujoco/fetch_her_ddpg.py index ee9b76e75..b5f9f1319 100644 --- a/examples/mujoco/fetch_her_ddpg.py +++ b/examples/mujoco/fetch_her_ddpg.py @@ -23,7 +23,7 @@ from tianshou.env import ShmemVectorEnv, TruncatedAsTerminated from tianshou.exploration import GaussianNoise from tianshou.policy import DDPGPolicy -from tianshou.policy.base import BasePolicy +from tianshou.policy.base import Algorithm from tianshou.trainer import OffpolicyTrainer from tianshou.utils.net.common import Net, get_dict_state_decorator from tianshou.utils.net.continuous import Actor, Critic @@ -217,7 +217,7 @@ def compute_reward_fn(ag: np.ndarray, g: np.ndarray) -> np.ndarray: train_collector.reset() train_collector.collect(n_step=args.start_timesteps, random=True) - def save_best_fn(policy: BasePolicy) -> None: + def save_best_fn(policy: Algorithm) -> None: torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth")) if not args.watch: diff --git a/examples/mujoco/mujoco_a2c.py b/examples/mujoco/mujoco_a2c.py index 817da079a..80325e40c 100755 --- a/examples/mujoco/mujoco_a2c.py +++ b/examples/mujoco/mujoco_a2c.py @@ -15,7 +15,7 @@ from tianshou.data import Collector, CollectStats, ReplayBuffer, VectorReplayBuffer from tianshou.highlevel.logger import LoggerFactoryDefault from tianshou.policy import A2CPolicy -from tianshou.policy.base import BasePolicy +from tianshou.policy.base import Algorithm from tianshou.trainer import OnpolicyTrainer from tianshou.utils.net.common import ActorCritic, Net from tianshou.utils.net.continuous import ActorProb, Critic @@ -196,7 +196,7 @@ def dist(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution: config_dict=vars(args), ) - def save_best_fn(policy: BasePolicy) -> None: + def save_best_fn(policy: Algorithm) -> 
None: state = {"model": policy.state_dict(), "obs_rms": train_envs.get_obs_rms()} torch.save(state, os.path.join(log_path, "policy.pth")) diff --git a/examples/mujoco/mujoco_ddpg.py b/examples/mujoco/mujoco_ddpg.py index d85a14427..62a4b3d62 100755 --- a/examples/mujoco/mujoco_ddpg.py +++ b/examples/mujoco/mujoco_ddpg.py @@ -13,7 +13,7 @@ from tianshou.exploration import GaussianNoise from tianshou.highlevel.logger import LoggerFactoryDefault from tianshou.policy import DDPGPolicy -from tianshou.policy.base import BasePolicy +from tianshou.policy.base import Algorithm from tianshou.trainer import OffpolicyTrainer from tianshou.utils.net.common import Net from tianshou.utils.net.continuous import Actor, Critic @@ -146,7 +146,7 @@ def test_ddpg(args: argparse.Namespace = get_args()) -> None: config_dict=vars(args), ) - def save_best_fn(policy: BasePolicy) -> None: + def save_best_fn(policy: Algorithm) -> None: torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth")) if not args.watch: diff --git a/examples/mujoco/mujoco_npg.py b/examples/mujoco/mujoco_npg.py index 9416376a1..8b8294100 100755 --- a/examples/mujoco/mujoco_npg.py +++ b/examples/mujoco/mujoco_npg.py @@ -15,7 +15,7 @@ from tianshou.data import Collector, CollectStats, ReplayBuffer, VectorReplayBuffer from tianshou.highlevel.logger import LoggerFactoryDefault from tianshou.policy import NPGPolicy -from tianshou.policy.base import BasePolicy +from tianshou.policy.base import Algorithm from tianshou.trainer import OnpolicyTrainer from tianshou.utils.net.common import Net from tianshou.utils.net.continuous import ActorProb, Critic @@ -193,7 +193,7 @@ def dist(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution: config_dict=vars(args), ) - def save_best_fn(policy: BasePolicy) -> None: + def save_best_fn(policy: Algorithm) -> None: state = {"model": policy.state_dict(), "obs_rms": train_envs.get_obs_rms()} torch.save(state, os.path.join(log_path, "policy.pth")) diff --git 
a/examples/mujoco/mujoco_ppo.py b/examples/mujoco/mujoco_ppo.py index 965ec7739..a385efc59 100755 --- a/examples/mujoco/mujoco_ppo.py +++ b/examples/mujoco/mujoco_ppo.py @@ -15,7 +15,7 @@ from tianshou.data import Collector, CollectStats, ReplayBuffer, VectorReplayBuffer from tianshou.highlevel.logger import LoggerFactoryDefault from tianshou.policy import PPOPolicy -from tianshou.policy.base import BasePolicy +from tianshou.policy.base import Algorithm from tianshou.trainer import OnpolicyTrainer from tianshou.utils.net.common import ActorCritic, Net from tianshou.utils.net.continuous import ActorProb, Critic @@ -201,7 +201,7 @@ def dist(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution: config_dict=vars(args), ) - def save_best_fn(policy: BasePolicy) -> None: + def save_best_fn(policy: Algorithm) -> None: state = {"model": policy.state_dict(), "obs_rms": train_envs.get_obs_rms()} torch.save(state, os.path.join(log_path, "policy.pth")) diff --git a/examples/mujoco/mujoco_redq.py b/examples/mujoco/mujoco_redq.py index 61d85ae1c..444b550fc 100755 --- a/examples/mujoco/mujoco_redq.py +++ b/examples/mujoco/mujoco_redq.py @@ -12,7 +12,7 @@ from tianshou.data import Collector, CollectStats, ReplayBuffer, VectorReplayBuffer from tianshou.highlevel.logger import LoggerFactoryDefault from tianshou.policy import REDQPolicy -from tianshou.policy.base import BasePolicy +from tianshou.policy.base import Algorithm from tianshou.trainer import OffpolicyTrainer from tianshou.utils.net.common import EnsembleLinear, Net from tianshou.utils.net.continuous import ActorProb, Critic @@ -174,7 +174,7 @@ def linear(x: int, y: int) -> EnsembleLinear: config_dict=vars(args), ) - def save_best_fn(policy: BasePolicy) -> None: + def save_best_fn(policy: Algorithm) -> None: torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth")) if not args.watch: diff --git a/examples/mujoco/mujoco_reinforce.py b/examples/mujoco/mujoco_reinforce.py index 95391e1ea..e49124e92 100755 
--- a/examples/mujoco/mujoco_reinforce.py +++ b/examples/mujoco/mujoco_reinforce.py @@ -14,8 +14,8 @@ from tianshou.data import Collector, CollectStats, ReplayBuffer, VectorReplayBuffer from tianshou.highlevel.logger import LoggerFactoryDefault -from tianshou.policy import PGPolicy -from tianshou.policy.base import BasePolicy +from tianshou.policy import Reinforce +from tianshou.policy.base import Algorithm from tianshou.trainer import OnpolicyTrainer from tianshou.utils.net.common import Net from tianshou.utils.net.continuous import ActorProb @@ -123,7 +123,7 @@ def dist(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution: loc, scale = loc_scale return Independent(Normal(loc, scale), 1) - policy: PGPolicy = PGPolicy( + policy: Reinforce = Reinforce( actor=actor, optim=optim, dist_fn=dist, @@ -173,7 +173,7 @@ def dist(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution: config_dict=vars(args), ) - def save_best_fn(policy: BasePolicy) -> None: + def save_best_fn(policy: Algorithm) -> None: state = {"model": policy.state_dict(), "obs_rms": train_envs.get_obs_rms()} torch.save(state, os.path.join(log_path, "policy.pth")) diff --git a/examples/mujoco/mujoco_sac.py b/examples/mujoco/mujoco_sac.py index 151237580..07c91085b 100755 --- a/examples/mujoco/mujoco_sac.py +++ b/examples/mujoco/mujoco_sac.py @@ -12,7 +12,7 @@ from tianshou.data import Collector, CollectStats, ReplayBuffer, VectorReplayBuffer from tianshou.highlevel.logger import LoggerFactoryDefault from tianshou.policy import SACPolicy -from tianshou.policy.base import BasePolicy +from tianshou.policy.base import Algorithm from tianshou.trainer import OffpolicyTrainer from tianshou.utils.net.common import Net from tianshou.utils.net.continuous import ActorProb, Critic @@ -168,7 +168,7 @@ def test_sac(args: argparse.Namespace = get_args()) -> None: config_dict=vars(args), ) - def save_best_fn(policy: BasePolicy) -> None: + def save_best_fn(policy: Algorithm) -> None: 
torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth")) if not args.watch: diff --git a/examples/mujoco/mujoco_td3.py b/examples/mujoco/mujoco_td3.py index a5e3e8cf6..31d7f1370 100755 --- a/examples/mujoco/mujoco_td3.py +++ b/examples/mujoco/mujoco_td3.py @@ -13,7 +13,7 @@ from tianshou.exploration import GaussianNoise from tianshou.highlevel.logger import LoggerFactoryDefault from tianshou.policy import TD3Policy -from tianshou.policy.base import BasePolicy +from tianshou.policy.base import Algorithm from tianshou.trainer import OffpolicyTrainer from tianshou.utils.net.common import Net from tianshou.utils.net.continuous import Actor, Critic @@ -166,7 +166,7 @@ def test_td3(args: argparse.Namespace = get_args()) -> None: config_dict=vars(args), ) - def save_best_fn(policy: BasePolicy) -> None: + def save_best_fn(policy: Algorithm) -> None: torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth")) if not args.watch: diff --git a/examples/mujoco/mujoco_trpo.py b/examples/mujoco/mujoco_trpo.py index 9405b2440..e1357afef 100755 --- a/examples/mujoco/mujoco_trpo.py +++ b/examples/mujoco/mujoco_trpo.py @@ -15,7 +15,7 @@ from tianshou.data import Collector, CollectStats, ReplayBuffer, VectorReplayBuffer from tianshou.highlevel.logger import LoggerFactoryDefault from tianshou.policy import TRPOPolicy -from tianshou.policy.base import BasePolicy +from tianshou.policy.base import Algorithm from tianshou.trainer import OnpolicyTrainer from tianshou.utils.net.common import Net from tianshou.utils.net.continuous import ActorProb, Critic @@ -198,7 +198,7 @@ def dist(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution: config_dict=vars(args), ) - def save_best_fn(policy: BasePolicy) -> None: + def save_best_fn(policy: Algorithm) -> None: state = {"model": policy.state_dict(), "obs_rms": train_envs.get_obs_rms()} torch.save(state, os.path.join(log_path, "policy.pth")) diff --git a/examples/offline/atari_bcq.py b/examples/offline/atari_bcq.py 
index 62c205076..04dde73bb 100644 --- a/examples/offline/atari_bcq.py +++ b/examples/offline/atari_bcq.py @@ -17,7 +17,7 @@ from tianshou.data import Collector, CollectStats, VectorReplayBuffer from tianshou.highlevel.logger import LoggerFactoryDefault from tianshou.policy import DiscreteBCQPolicy -from tianshou.policy.base import BasePolicy +from tianshou.policy.base import Algorithm from tianshou.trainer import OfflineTrainer from tianshou.utils.net.common import ActorCritic from tianshou.utils.net.discrete import Actor @@ -178,7 +178,7 @@ def test_discrete_bcq(args: argparse.Namespace = get_args()) -> None: config_dict=vars(args), ) - def save_best_fn(policy: BasePolicy) -> None: + def save_best_fn(policy: Algorithm) -> None: torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth")) def stop_fn(mean_rewards: float) -> bool: diff --git a/examples/offline/atari_cql.py b/examples/offline/atari_cql.py index 436145c90..3b8bd2783 100644 --- a/examples/offline/atari_cql.py +++ b/examples/offline/atari_cql.py @@ -18,7 +18,7 @@ from tianshou.data import Collector, CollectStats, VectorReplayBuffer from tianshou.highlevel.logger import LoggerFactoryDefault from tianshou.policy import DiscreteCQLPolicy -from tianshou.policy.base import BasePolicy +from tianshou.policy.base import Algorithm from tianshou.trainer import OfflineTrainer from tianshou.utils.space_info import SpaceInfo @@ -162,7 +162,7 @@ def test_discrete_cql(args: argparse.Namespace = get_args()) -> None: config_dict=vars(args), ) - def save_best_fn(policy: BasePolicy) -> None: + def save_best_fn(policy: Algorithm) -> None: torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth")) def stop_fn(mean_rewards: float) -> bool: diff --git a/examples/offline/atari_crr.py b/examples/offline/atari_crr.py index 565908559..833e7a008 100644 --- a/examples/offline/atari_crr.py +++ b/examples/offline/atari_crr.py @@ -17,7 +17,7 @@ from tianshou.data import Collector, CollectStats, VectorReplayBuffer 
from tianshou.highlevel.logger import LoggerFactoryDefault from tianshou.policy import DiscreteCRRPolicy -from tianshou.policy.base import BasePolicy +from tianshou.policy.base import Algorithm from tianshou.trainer import OfflineTrainer from tianshou.utils.net.common import ActorCritic from tianshou.utils.net.discrete import Actor, Critic @@ -179,7 +179,7 @@ def test_discrete_crr(args: argparse.Namespace = get_args()) -> None: config_dict=vars(args), ) - def save_best_fn(policy: BasePolicy) -> None: + def save_best_fn(policy: Algorithm) -> None: torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth")) def stop_fn(mean_rewards: float) -> bool: diff --git a/examples/offline/atari_il.py b/examples/offline/atari_il.py index 16b7cdec3..45d5c5c44 100644 --- a/examples/offline/atari_il.py +++ b/examples/offline/atari_il.py @@ -16,7 +16,7 @@ from tianshou.data import Collector, CollectStats, VectorReplayBuffer from tianshou.highlevel.logger import LoggerFactoryDefault from tianshou.policy import ImitationPolicy -from tianshou.policy.base import BasePolicy +from tianshou.policy.base import Algorithm from tianshou.trainer import OfflineTrainer from tianshou.utils.space_info import SpaceInfo @@ -136,7 +136,7 @@ def test_il(args: argparse.Namespace = get_args()) -> None: config_dict=vars(args), ) - def save_best_fn(policy: BasePolicy) -> None: + def save_best_fn(policy: Algorithm) -> None: torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth")) def stop_fn(mean_rewards: float) -> bool: diff --git a/examples/offline/d4rl_bcq.py b/examples/offline/d4rl_bcq.py index 08b38ded6..37cf18446 100644 --- a/examples/offline/d4rl_bcq.py +++ b/examples/offline/d4rl_bcq.py @@ -14,7 +14,7 @@ from tianshou.data import Collector, CollectStats from tianshou.env import SubprocVectorEnv from tianshou.policy import BCQPolicy -from tianshou.policy.base import BasePolicy +from tianshou.policy.base import Algorithm from tianshou.trainer import OfflineTrainer from 
tianshou.utils import TensorboardLogger, WandbLogger from tianshou.utils.net.common import MLP, Net @@ -198,7 +198,7 @@ def test_bcq() -> None: ) logger.load(writer) - def save_best_fn(policy: BasePolicy) -> None: + def save_best_fn(policy: Algorithm) -> None: torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth")) def watch() -> None: diff --git a/examples/offline/d4rl_cql.py b/examples/offline/d4rl_cql.py index 5b68edf9e..a896fffb7 100644 --- a/examples/offline/d4rl_cql.py +++ b/examples/offline/d4rl_cql.py @@ -14,7 +14,7 @@ from tianshou.data import Collector, CollectStats from tianshou.env import SubprocVectorEnv from tianshou.policy import CQLPolicy -from tianshou.policy.base import BasePolicy +from tianshou.policy.base import Algorithm from tianshou.trainer import OfflineTrainer from tianshou.utils import TensorboardLogger, WandbLogger from tianshou.utils.net.common import Net @@ -336,7 +336,7 @@ def test_cql() -> None: ) logger.load(writer) - def save_best_fn(policy: BasePolicy) -> None: + def save_best_fn(policy: Algorithm) -> None: torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth")) def watch() -> None: diff --git a/examples/offline/d4rl_il.py b/examples/offline/d4rl_il.py index e1e71fd82..91c998cad 100644 --- a/examples/offline/d4rl_il.py +++ b/examples/offline/d4rl_il.py @@ -14,7 +14,7 @@ from tianshou.data import Collector, CollectStats from tianshou.env import SubprocVectorEnv from tianshou.policy import ImitationPolicy -from tianshou.policy.base import BasePolicy +from tianshou.policy.base import Algorithm from tianshou.trainer import OfflineTrainer from tianshou.utils import TensorboardLogger, WandbLogger from tianshou.utils.net.common import Net @@ -134,7 +134,7 @@ def test_il() -> None: ) logger.load(writer) - def save_best_fn(policy: BasePolicy) -> None: + def save_best_fn(policy: Algorithm) -> None: torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth")) def watch() -> None: diff --git 
a/examples/offline/d4rl_td3_bc.py b/examples/offline/d4rl_td3_bc.py index f4e8b38c2..47c05bda8 100644 --- a/examples/offline/d4rl_td3_bc.py +++ b/examples/offline/d4rl_td3_bc.py @@ -15,7 +15,7 @@ from tianshou.env import BaseVectorEnv, SubprocVectorEnv, VectorEnvNormObs from tianshou.exploration import GaussianNoise from tianshou.policy import TD3BCPolicy -from tianshou.policy.base import BasePolicy +from tianshou.policy.base import Algorithm from tianshou.trainer import OfflineTrainer from tianshou.utils import TensorboardLogger, WandbLogger from tianshou.utils.net.common import Net @@ -183,7 +183,7 @@ def test_td3_bc() -> None: ) logger.load(writer) - def save_best_fn(policy: BasePolicy) -> None: + def save_best_fn(policy: Algorithm) -> None: torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth")) def watch() -> None: diff --git a/examples/vizdoom/vizdoom_c51.py b/examples/vizdoom/vizdoom_c51.py index 0c23e6e8e..b741e7d7a 100644 --- a/examples/vizdoom/vizdoom_c51.py +++ b/examples/vizdoom/vizdoom_c51.py @@ -12,7 +12,7 @@ from tianshou.data import Collector, CollectStats, VectorReplayBuffer from tianshou.highlevel.logger import LoggerFactoryDefault from tianshou.policy import C51Policy -from tianshou.policy.base import BasePolicy +from tianshou.policy.base import Algorithm from tianshou.trainer import OffpolicyTrainer @@ -144,7 +144,7 @@ def test_c51(args: argparse.Namespace = get_args()) -> None: config_dict=vars(args), ) - def save_best_fn(policy: BasePolicy) -> None: + def save_best_fn(policy: Algorithm) -> None: torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth")) def stop_fn(mean_rewards: float) -> bool: diff --git a/examples/vizdoom/vizdoom_ppo.py b/examples/vizdoom/vizdoom_ppo.py index 6d4f55d14..d858305b8 100644 --- a/examples/vizdoom/vizdoom_ppo.py +++ b/examples/vizdoom/vizdoom_ppo.py @@ -14,7 +14,7 @@ from tianshou.data import Collector, CollectStats, VectorReplayBuffer from tianshou.highlevel.logger import 
LoggerFactoryDefault from tianshou.policy import ICMPolicy, PPOPolicy -from tianshou.policy.base import BasePolicy +from tianshou.policy.base import Algorithm from tianshou.trainer import OnpolicyTrainer from tianshou.utils.net.common import ActorCritic from tianshou.utils.net.discrete import Actor, Critic, IntrinsicCuriosityModule @@ -224,7 +224,7 @@ def dist(logits: torch.Tensor) -> Categorical: config_dict=vars(args), ) - def save_best_fn(policy: BasePolicy) -> None: + def save_best_fn(policy: Algorithm) -> None: torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth")) def stop_fn(mean_rewards: float) -> bool: diff --git a/test/base/test_collector.py b/test/base/test_collector.py index 6355d8bfc..ed648075d 100644 --- a/test/base/test_collector.py +++ b/test/base/test_collector.py @@ -80,7 +80,12 @@ def forward( action_shape = self.action_shape if self.action_shape else len(batch.obs) return Batch(act=np.ones(action_shape), state=state) - def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> TrainingStats: + def _update_with_batch( + self, + batch: RolloutBatchProtocol, + *args: Any, + **kwargs: Any, + ) -> TrainingStats: raise NotImplementedError diff --git a/test/base/test_env_finite.py b/test/base/test_env_finite.py index 39cf9d3ea..2240e1d53 100644 --- a/test/base/test_env_finite.py +++ b/test/base/test_env_finite.py @@ -20,7 +20,7 @@ ) from tianshou.env import BaseVectorEnv, DummyVectorEnv, SubprocVectorEnv from tianshou.env.utils import ENV_TYPE, gym_new_venv_step_type -from tianshou.policy import BasePolicy +from tianshou.policy import Algorithm class DummyDataset(Dataset): @@ -204,7 +204,7 @@ class FiniteSubprocVectorEnv(FiniteVectorEnv, SubprocVectorEnv): pass -class AnyPolicy(BasePolicy): +class AnyPolicy(Algorithm): def __init__(self) -> None: super().__init__(action_space=Box(-1, 1, (1,))) @@ -216,7 +216,7 @@ def forward( ) -> ActBatchProtocol: return cast(ActBatchProtocol, Batch(act=np.stack([1] * len(batch)))) - def 
learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> None: + def _update_with_batch(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> None: pass diff --git a/test/base/test_policy.py b/test/base/test_policy.py index b918194da..753291a63 100644 --- a/test/base/test_policy.py +++ b/test/base/test_policy.py @@ -5,7 +5,7 @@ from torch.distributions import Categorical, Distribution, Independent, Normal from tianshou.data import Batch -from tianshou.policy import BasePolicy, PPOPolicy +from tianshou.policy import Algorithm, PPOPolicy from tianshou.policy.base import RandomActionPolicy, episode_mc_return_to_go from tianshou.utils.net.common import ActorCritic, Net from tianshou.utils.net.continuous import ActorProb, Critic @@ -58,7 +58,7 @@ def dist_fn(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution: actor_critic = ActorCritic(actor, critic) optim = torch.optim.Adam(actor_critic.parameters(), lr=1e-3) - policy: BasePolicy + policy: Algorithm policy = PPOPolicy( actor=actor, critic=critic, diff --git a/test/base/test_returns.py b/test/base/test_returns.py index 078893113..8b05319d0 100644 --- a/test/base/test_returns.py +++ b/test/base/test_returns.py @@ -5,7 +5,7 @@ from tianshou.data import Batch, ReplayBuffer, to_numpy from tianshou.data.types import RolloutBatchProtocol -from tianshou.policy import BasePolicy +from tianshou.policy import Algorithm def compute_episodic_return_base(batch: Batch, gamma: float) -> Batch: @@ -21,7 +21,7 @@ def compute_episodic_return_base(batch: Batch, gamma: float) -> Batch: def test_episodic_returns(size: int = 2560) -> None: - fn = BasePolicy.compute_episodic_return + fn = Algorithm.compute_episodic_return buf = ReplayBuffer(20) batch = cast( RolloutBatchProtocol, @@ -215,7 +215,7 @@ def test_nstep_returns(size: int = 10000) -> None: # done: [ 0, 1, 0, 1, 0, 0, 0, 1, 0, 0] # test nstep = 1 returns = to_numpy( - BasePolicy.compute_nstep_return(batch, buf, indices, target_q_fn, gamma=0.1, 
n_step=1) + Algorithm.compute_nstep_return(batch, buf, indices, target_q_fn, gamma=0.1, n_step=1) .pop("returns") .reshape(-1), ) @@ -223,7 +223,7 @@ def test_nstep_returns(size: int = 10000) -> None: r_ = compute_nstep_return_base(1, 0.1, buf, indices) assert np.allclose(returns, r_), (r_, returns) returns_multidim = to_numpy( - BasePolicy.compute_nstep_return( + Algorithm.compute_nstep_return( batch, buf, indices, @@ -235,7 +235,7 @@ def test_nstep_returns(size: int = 10000) -> None: assert np.allclose(returns_multidim, returns[:, np.newaxis]) # test nstep = 2 returns = to_numpy( - BasePolicy.compute_nstep_return(batch, buf, indices, target_q_fn, gamma=0.1, n_step=2) + Algorithm.compute_nstep_return(batch, buf, indices, target_q_fn, gamma=0.1, n_step=2) .pop("returns") .reshape(-1), ) @@ -243,7 +243,7 @@ def test_nstep_returns(size: int = 10000) -> None: r_ = compute_nstep_return_base(2, 0.1, buf, indices) assert np.allclose(returns, r_) returns_multidim = to_numpy( - BasePolicy.compute_nstep_return( + Algorithm.compute_nstep_return( batch, buf, indices, @@ -255,7 +255,7 @@ def test_nstep_returns(size: int = 10000) -> None: assert np.allclose(returns_multidim, returns[:, np.newaxis]) # test nstep = 10 returns = to_numpy( - BasePolicy.compute_nstep_return(batch, buf, indices, target_q_fn, gamma=0.1, n_step=10) + Algorithm.compute_nstep_return(batch, buf, indices, target_q_fn, gamma=0.1, n_step=10) .pop("returns") .reshape(-1), ) @@ -263,7 +263,7 @@ def test_nstep_returns(size: int = 10000) -> None: r_ = compute_nstep_return_base(10, 0.1, buf, indices) assert np.allclose(returns, r_) returns_multidim = to_numpy( - BasePolicy.compute_nstep_return( + Algorithm.compute_nstep_return( batch, buf, indices, @@ -297,7 +297,7 @@ def test_nstep_returns_with_timelimit(size: int = 10000) -> None: # done: [ 0, 1, 0, 1, 0, 0, 0, 1, 0, 0] # test nstep = 1 returns = to_numpy( - BasePolicy.compute_nstep_return(batch, buf, indices, target_q_fn, gamma=0.1, n_step=1) + 
Algorithm.compute_nstep_return(batch, buf, indices, target_q_fn, gamma=0.1, n_step=1) .pop("returns") .reshape(-1), ) @@ -305,7 +305,7 @@ def test_nstep_returns_with_timelimit(size: int = 10000) -> None: r_ = compute_nstep_return_base(1, 0.1, buf, indices) assert np.allclose(returns, r_), (r_, returns) returns_multidim = to_numpy( - BasePolicy.compute_nstep_return( + Algorithm.compute_nstep_return( batch, buf, indices, @@ -317,7 +317,7 @@ def test_nstep_returns_with_timelimit(size: int = 10000) -> None: assert np.allclose(returns_multidim, returns[:, np.newaxis]) # test nstep = 2 returns = to_numpy( - BasePolicy.compute_nstep_return(batch, buf, indices, target_q_fn, gamma=0.1, n_step=2) + Algorithm.compute_nstep_return(batch, buf, indices, target_q_fn, gamma=0.1, n_step=2) .pop("returns") .reshape(-1), ) @@ -325,7 +325,7 @@ def test_nstep_returns_with_timelimit(size: int = 10000) -> None: r_ = compute_nstep_return_base(2, 0.1, buf, indices) assert np.allclose(returns, r_) returns_multidim = to_numpy( - BasePolicy.compute_nstep_return( + Algorithm.compute_nstep_return( batch, buf, indices, @@ -337,7 +337,7 @@ def test_nstep_returns_with_timelimit(size: int = 10000) -> None: assert np.allclose(returns_multidim, returns[:, np.newaxis]) # test nstep = 10 returns = to_numpy( - BasePolicy.compute_nstep_return(batch, buf, indices, target_q_fn, gamma=0.1, n_step=10) + Algorithm.compute_nstep_return(batch, buf, indices, target_q_fn, gamma=0.1, n_step=10) .pop("returns") .reshape(-1), ) @@ -345,7 +345,7 @@ def test_nstep_returns_with_timelimit(size: int = 10000) -> None: r_ = compute_nstep_return_base(10, 0.1, buf, indices) assert np.allclose(returns, r_) returns_multidim = to_numpy( - BasePolicy.compute_nstep_return( + Algorithm.compute_nstep_return( batch, buf, indices, diff --git a/test/continuous/test_ddpg.py b/test/continuous/test_ddpg.py index ce7998eff..b7b877dee 100644 --- a/test/continuous/test_ddpg.py +++ b/test/continuous/test_ddpg.py @@ -10,7 +10,7 @@ from 
tianshou.env import DummyVectorEnv from tianshou.exploration import GaussianNoise from tianshou.policy import DDPGPolicy -from tianshou.policy.base import BasePolicy +from tianshou.policy.base import Algorithm from tianshou.trainer import OffpolicyTrainer from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Net @@ -110,7 +110,7 @@ def test_ddpg(args: argparse.Namespace = get_args()) -> None: writer = SummaryWriter(log_path) logger = TensorboardLogger(writer) - def save_best_fn(policy: BasePolicy) -> None: + def save_best_fn(policy: Algorithm) -> None: torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth")) def stop_fn(mean_rewards: float) -> bool: diff --git a/test/continuous/test_npg.py b/test/continuous/test_npg.py index d853e2186..393225e33 100644 --- a/test/continuous/test_npg.py +++ b/test/continuous/test_npg.py @@ -11,7 +11,7 @@ from tianshou.data import Collector, CollectStats, VectorReplayBuffer from tianshou.env import DummyVectorEnv from tianshou.policy import NPGPolicy -from tianshou.policy.base import BasePolicy +from tianshou.policy.base import Algorithm from tianshou.policy.modelfree.npg import NPGTrainingStats from tianshou.trainer import OnpolicyTrainer from tianshou.utils import TensorboardLogger @@ -132,7 +132,7 @@ def dist(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution: writer = SummaryWriter(log_path) logger = TensorboardLogger(writer) - def save_best_fn(policy: BasePolicy) -> None: + def save_best_fn(policy: Algorithm) -> None: torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth")) def stop_fn(mean_rewards: float) -> bool: diff --git a/test/continuous/test_ppo.py b/test/continuous/test_ppo.py index 4b56bd630..576353c53 100644 --- a/test/continuous/test_ppo.py +++ b/test/continuous/test_ppo.py @@ -10,7 +10,7 @@ from tianshou.data import Collector, CollectStats, VectorReplayBuffer from tianshou.env import DummyVectorEnv from tianshou.policy import PPOPolicy -from 
tianshou.policy.base import BasePolicy +from tianshou.policy.base import Algorithm from tianshou.policy.modelfree.ppo import PPOTrainingStats from tianshou.trainer import OnpolicyTrainer from tianshou.utils import TensorboardLogger @@ -133,7 +133,7 @@ def dist(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution: writer = SummaryWriter(log_path) logger = TensorboardLogger(writer, save_interval=args.save_interval) - def save_best_fn(policy: BasePolicy) -> None: + def save_best_fn(policy: Algorithm) -> None: torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth")) def stop_fn(mean_rewards: float) -> bool: diff --git a/test/continuous/test_redq.py b/test/continuous/test_redq.py index 82c8f0637..fa947a3fc 100644 --- a/test/continuous/test_redq.py +++ b/test/continuous/test_redq.py @@ -10,7 +10,7 @@ from tianshou.data import Collector, CollectStats, VectorReplayBuffer from tianshou.env import DummyVectorEnv from tianshou.policy import REDQPolicy -from tianshou.policy.base import BasePolicy +from tianshou.policy.base import Algorithm from tianshou.trainer import OffpolicyTrainer from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import EnsembleLinear, Net @@ -141,7 +141,7 @@ def linear(x: int, y: int) -> nn.Module: writer = SummaryWriter(log_path) logger = TensorboardLogger(writer) - def save_best_fn(policy: BasePolicy) -> None: + def save_best_fn(policy: Algorithm) -> None: torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth")) def stop_fn(mean_rewards: float) -> bool: diff --git a/test/continuous/test_sac_with_il.py b/test/continuous/test_sac_with_il.py index 09fc3ca45..6634501a3 100644 --- a/test/continuous/test_sac_with_il.py +++ b/test/continuous/test_sac_with_il.py @@ -9,7 +9,7 @@ from tianshou.data import Collector, CollectStats, VectorReplayBuffer from tianshou.env import DummyVectorEnv from tianshou.policy import ImitationPolicy, SACPolicy -from tianshou.policy.base import BasePolicy +from 
tianshou.policy.base import Algorithm from tianshou.trainer import OffpolicyTrainer from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Net @@ -137,7 +137,7 @@ def test_sac_with_il(args: argparse.Namespace = get_args()) -> None: writer = SummaryWriter(log_path) logger = TensorboardLogger(writer) - def save_best_fn(policy: BasePolicy) -> None: + def save_best_fn(policy: Algorithm) -> None: torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth")) def stop_fn(mean_rewards: float) -> bool: diff --git a/test/continuous/test_td3.py b/test/continuous/test_td3.py index 6c59ea25a..41a339cb5 100644 --- a/test/continuous/test_td3.py +++ b/test/continuous/test_td3.py @@ -11,7 +11,7 @@ from tianshou.env import DummyVectorEnv from tianshou.exploration import GaussianNoise from tianshou.policy import TD3Policy -from tianshou.policy.base import BasePolicy +from tianshou.policy.base import Algorithm from tianshou.trainer import OffpolicyTrainer from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Net @@ -128,7 +128,7 @@ def test_td3(args: argparse.Namespace = get_args()) -> None: writer = SummaryWriter(log_path) logger = TensorboardLogger(writer) - def save_best_fn(policy: BasePolicy) -> None: + def save_best_fn(policy: Algorithm) -> None: torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth")) def stop_fn(mean_rewards: float) -> bool: diff --git a/test/continuous/test_trpo.py b/test/continuous/test_trpo.py index 91e215116..3d1ffda36 100644 --- a/test/continuous/test_trpo.py +++ b/test/continuous/test_trpo.py @@ -11,7 +11,7 @@ from tianshou.data import Collector, CollectStats, VectorReplayBuffer from tianshou.env import DummyVectorEnv from tianshou.policy import TRPOPolicy -from tianshou.policy.base import BasePolicy +from tianshou.policy.base import Algorithm from tianshou.trainer import OnpolicyTrainer from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Net @@ 
-132,7 +132,7 @@ def dist(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution: writer = SummaryWriter(log_path) logger = TensorboardLogger(writer) - def save_best_fn(policy: BasePolicy) -> None: + def save_best_fn(policy: Algorithm) -> None: torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth")) def stop_fn(mean_rewards: float) -> bool: diff --git a/test/discrete/test_a2c_with_il.py b/test/discrete/test_a2c_with_il.py index 192f24c24..aafbac62a 100644 --- a/test/discrete/test_a2c_with_il.py +++ b/test/discrete/test_a2c_with_il.py @@ -10,7 +10,7 @@ from tianshou.data import Collector, CollectStats, VectorReplayBuffer from tianshou.env import DummyVectorEnv, SubprocVectorEnv from tianshou.policy import A2CPolicy, ImitationPolicy -from tianshou.policy.base import BasePolicy +from tianshou.policy.base import Algorithm from tianshou.trainer import OffpolicyTrainer, OnpolicyTrainer from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import ActorCritic, Net @@ -94,7 +94,7 @@ def test_a2c_with_il(args: argparse.Namespace = get_args()) -> None: critic = Critic(net, device=args.device).to(args.device) optim = torch.optim.Adam(ActorCritic(actor, critic).parameters(), lr=args.lr) dist = torch.distributions.Categorical - policy: BasePolicy + policy: Algorithm policy = A2CPolicy( actor=actor, critic=critic, @@ -123,7 +123,7 @@ def test_a2c_with_il(args: argparse.Namespace = get_args()) -> None: writer = SummaryWriter(log_path) logger = TensorboardLogger(writer) - def save_best_fn(policy: BasePolicy) -> None: + def save_best_fn(policy: Algorithm) -> None: torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth")) def stop_fn(mean_rewards: float) -> bool: diff --git a/test/discrete/test_c51.py b/test/discrete/test_c51.py index 41d6a0260..8d7bd5b6b 100644 --- a/test/discrete/test_c51.py +++ b/test/discrete/test_c51.py @@ -16,7 +16,7 @@ ) from tianshou.env import DummyVectorEnv from tianshou.policy import C51Policy -from 
tianshou.policy.base import BasePolicy +from tianshou.policy.base import Algorithm from tianshou.trainer import OffpolicyTrainer from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Net @@ -126,7 +126,7 @@ def test_c51(args: argparse.Namespace = get_args()) -> None: writer = SummaryWriter(log_path) logger = TensorboardLogger(writer, save_interval=args.save_interval) - def save_best_fn(policy: BasePolicy) -> None: + def save_best_fn(policy: Algorithm) -> None: torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth")) def stop_fn(mean_rewards: float) -> bool: diff --git a/test/discrete/test_dqn.py b/test/discrete/test_dqn.py index f82aca1f6..842555d0c 100644 --- a/test/discrete/test_dqn.py +++ b/test/discrete/test_dqn.py @@ -15,7 +15,7 @@ ) from tianshou.env import DummyVectorEnv from tianshou.policy import DQNPolicy -from tianshou.policy.base import BasePolicy +from tianshou.policy.base import Algorithm from tianshou.trainer import OffpolicyTrainer from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Net @@ -117,7 +117,7 @@ def test_dqn(args: argparse.Namespace = get_args()) -> None: writer = SummaryWriter(log_path) logger = TensorboardLogger(writer) - def save_best_fn(policy: BasePolicy) -> None: + def save_best_fn(policy: Algorithm) -> None: torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth")) def stop_fn(mean_rewards: float) -> bool: diff --git a/test/discrete/test_drqn.py b/test/discrete/test_drqn.py index 4cc0b6bd0..6030fc03c 100644 --- a/test/discrete/test_drqn.py +++ b/test/discrete/test_drqn.py @@ -9,7 +9,7 @@ from tianshou.data import Collector, CollectStats, VectorReplayBuffer from tianshou.env import DummyVectorEnv from tianshou.policy import DQNPolicy -from tianshou.policy.base import BasePolicy +from tianshou.policy.base import Algorithm from tianshou.trainer import OffpolicyTrainer from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import 
Recurrent @@ -100,7 +100,7 @@ def test_drqn(args: argparse.Namespace = get_args()) -> None: writer = SummaryWriter(log_path) logger = TensorboardLogger(writer) - def save_best_fn(policy: BasePolicy) -> None: + def save_best_fn(policy: Algorithm) -> None: torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth")) def stop_fn(mean_rewards: float) -> bool: diff --git a/test/discrete/test_fqf.py b/test/discrete/test_fqf.py index e0899d315..b68ec6157 100644 --- a/test/discrete/test_fqf.py +++ b/test/discrete/test_fqf.py @@ -15,7 +15,7 @@ ) from tianshou.env import DummyVectorEnv from tianshou.policy import FQFPolicy -from tianshou.policy.base import BasePolicy +from tianshou.policy.base import Algorithm from tianshou.trainer import OffpolicyTrainer from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Net @@ -134,7 +134,7 @@ def test_fqf(args: argparse.Namespace = get_args()) -> None: writer = SummaryWriter(log_path) logger = TensorboardLogger(writer) - def save_best_fn(policy: BasePolicy) -> None: + def save_best_fn(policy: Algorithm) -> None: torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth")) def stop_fn(mean_rewards: float) -> bool: diff --git a/test/discrete/test_iqn.py b/test/discrete/test_iqn.py index 08f545b11..bbfc3a71a 100644 --- a/test/discrete/test_iqn.py +++ b/test/discrete/test_iqn.py @@ -15,7 +15,7 @@ ) from tianshou.env import DummyVectorEnv from tianshou.policy import IQNPolicy -from tianshou.policy.base import BasePolicy +from tianshou.policy.base import Algorithm from tianshou.trainer import OffpolicyTrainer from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Net @@ -130,7 +130,7 @@ def test_iqn(args: argparse.Namespace = get_args()) -> None: writer = SummaryWriter(log_path) logger = TensorboardLogger(writer) - def save_best_fn(policy: BasePolicy) -> None: + def save_best_fn(policy: Algorithm) -> None: torch.save(policy.state_dict(), os.path.join(log_path, 
"policy.pth")) def stop_fn(mean_rewards: float) -> bool: diff --git a/test/discrete/test_pg.py b/test/discrete/test_pg.py index 8a681583d..045661ba1 100644 --- a/test/discrete/test_pg.py +++ b/test/discrete/test_pg.py @@ -9,8 +9,9 @@ from tianshou.data import Collector, CollectStats, VectorReplayBuffer from tianshou.env import DummyVectorEnv -from tianshou.policy import PGPolicy -from tianshou.policy.base import BasePolicy +from tianshou.policy import Reinforce +from tianshou.policy.base import Algorithm +from tianshou.policy.modelfree.pg import ActorPolicy from tianshou.trainer import OnpolicyTrainer from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Net @@ -75,13 +76,16 @@ def test_pg(args: argparse.Namespace = get_args()) -> None: ).to(args.device) optim = torch.optim.Adam(net.parameters(), lr=args.lr) dist_fn = torch.distributions.Categorical - policy: PGPolicy = PGPolicy( + policy = ActorPolicy( actor=net, - optim=optim, dist_fn=dist_fn, - discount_factor=args.gamma, action_space=env.action_space, action_scaling=isinstance(env.action_space, Box), + ) + algorithm: Reinforce = Reinforce( + policy=policy, + optim=optim, + discount_factor=args.gamma, reward_normalization=args.rew_norm, ) for m in net.modules(): @@ -91,25 +95,25 @@ def test_pg(args: argparse.Namespace = get_args()) -> None: torch.nn.init.zeros_(m.bias) # collector train_collector = Collector[CollectStats]( - policy, + algorithm, train_envs, VectorReplayBuffer(args.buffer_size, len(train_envs)), ) - test_collector = Collector[CollectStats](policy, test_envs) + test_collector = Collector[CollectStats](algorithm, test_envs) # log log_path = os.path.join(args.logdir, args.task, "pg") writer = SummaryWriter(log_path) logger = TensorboardLogger(writer) - def save_best_fn(policy: BasePolicy) -> None: - torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth")) + def save_best_fn(algorithm: Algorithm) -> None: + torch.save(algorithm.policy.state_dict(), 
os.path.join(log_path, "policy.pth")) def stop_fn(mean_rewards: float) -> bool: return mean_rewards >= args.reward_threshold # trainer result = OnpolicyTrainer( - policy=policy, + policy=algorithm, train_collector=train_collector, test_collector=test_collector, max_epoch=args.epoch, diff --git a/test/discrete/test_ppo.py b/test/discrete/test_ppo.py index 7e541fffb..88b0e2b86 100644 --- a/test/discrete/test_ppo.py +++ b/test/discrete/test_ppo.py @@ -11,7 +11,7 @@ from tianshou.data import Collector, CollectStats, VectorReplayBuffer from tianshou.env import DummyVectorEnv from tianshou.policy import PPOPolicy -from tianshou.policy.base import BasePolicy +from tianshou.policy.base import Algorithm from tianshou.policy.modelfree.ppo import PPOTrainingStats from tianshou.trainer import OnpolicyTrainer from tianshou.utils import TensorboardLogger @@ -128,7 +128,7 @@ def test_ppo(args: argparse.Namespace = get_args()) -> None: writer = SummaryWriter(log_path) logger = TensorboardLogger(writer) - def save_best_fn(policy: BasePolicy) -> None: + def save_best_fn(policy: Algorithm) -> None: torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth")) def stop_fn(mean_rewards: float) -> bool: diff --git a/test/discrete/test_qrdqn.py b/test/discrete/test_qrdqn.py index 5aa543fb5..bf6928e7e 100644 --- a/test/discrete/test_qrdqn.py +++ b/test/discrete/test_qrdqn.py @@ -14,7 +14,7 @@ ) from tianshou.env import DummyVectorEnv from tianshou.policy import QRDQNPolicy -from tianshou.policy.base import BasePolicy +from tianshou.policy.base import Algorithm from tianshou.policy.modelfree.qrdqn import QRDQNTrainingStats from tianshou.trainer import OffpolicyTrainer from tianshou.utils import TensorboardLogger @@ -123,7 +123,7 @@ def test_qrdqn(args: argparse.Namespace = get_args()) -> None: writer = SummaryWriter(log_path) logger = TensorboardLogger(writer) - def save_best_fn(policy: BasePolicy) -> None: + def save_best_fn(policy: Algorithm) -> None: 
torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth")) def stop_fn(mean_rewards: float) -> bool: diff --git a/test/discrete/test_rainbow.py b/test/discrete/test_rainbow.py index d7d4b15b1..cb80c460d 100644 --- a/test/discrete/test_rainbow.py +++ b/test/discrete/test_rainbow.py @@ -15,7 +15,7 @@ ) from tianshou.env import DummyVectorEnv from tianshou.policy import RainbowPolicy -from tianshou.policy.base import BasePolicy +from tianshou.policy.base import Algorithm from tianshou.policy.modelfree.rainbow import RainbowTrainingStats from tianshou.trainer import OffpolicyTrainer from tianshou.utils import TensorboardLogger @@ -138,7 +138,7 @@ def noisy_linear(x: int, y: int) -> NoisyLinear: writer = SummaryWriter(log_path) logger = TensorboardLogger(writer, save_interval=args.save_interval) - def save_best_fn(policy: BasePolicy) -> None: + def save_best_fn(policy: Algorithm) -> None: torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth")) def stop_fn(mean_rewards: float) -> bool: diff --git a/test/discrete/test_sac.py b/test/discrete/test_sac.py index 3409dab0a..432d8c2c8 100644 --- a/test/discrete/test_sac.py +++ b/test/discrete/test_sac.py @@ -9,7 +9,7 @@ from tianshou.data import Collector, CollectStats, VectorReplayBuffer from tianshou.env import DummyVectorEnv from tianshou.policy import DiscreteSACPolicy -from tianshou.policy.base import BasePolicy +from tianshou.policy.base import Algorithm from tianshou.policy.modelfree.discrete_sac import DiscreteSACTrainingStats from tianshou.trainer import OffpolicyTrainer from tianshou.utils import TensorboardLogger @@ -118,7 +118,7 @@ def test_discrete_sac(args: argparse.Namespace = get_args()) -> None: writer = SummaryWriter(log_path) logger = TensorboardLogger(writer) - def save_best_fn(policy: BasePolicy) -> None: + def save_best_fn(policy: Algorithm) -> None: torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth")) def stop_fn(mean_rewards: float) -> bool: diff --git 
a/test/modelbased/test_dqn_icm.py b/test/modelbased/test_dqn_icm.py index c108e7c0f..9d3a717b0 100644 --- a/test/modelbased/test_dqn_icm.py +++ b/test/modelbased/test_dqn_icm.py @@ -14,7 +14,7 @@ ) from tianshou.env import DummyVectorEnv from tianshou.policy import DQNPolicy, ICMPolicy -from tianshou.policy.base import BasePolicy +from tianshou.policy.base import Algorithm from tianshou.policy.modelfree.dqn import DQNTrainingStats from tianshou.trainer import OffpolicyTrainer from tianshou.utils import TensorboardLogger @@ -164,7 +164,7 @@ def test_dqn_icm(args: argparse.Namespace = get_args()) -> None: writer = SummaryWriter(log_path) logger = TensorboardLogger(writer) - def save_best_fn(policy: BasePolicy) -> None: + def save_best_fn(policy: Algorithm) -> None: torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth")) def stop_fn(mean_rewards: float) -> bool: diff --git a/test/modelbased/test_ppo_icm.py b/test/modelbased/test_ppo_icm.py index 7d3780960..beb6561b7 100644 --- a/test/modelbased/test_ppo_icm.py +++ b/test/modelbased/test_ppo_icm.py @@ -10,7 +10,7 @@ from tianshou.data import Collector, CollectStats, VectorReplayBuffer from tianshou.env import DummyVectorEnv from tianshou.policy import ICMPolicy, PPOPolicy -from tianshou.policy.base import BasePolicy +from tianshou.policy.base import Algorithm from tianshou.policy.modelfree.ppo import PPOTrainingStats from tianshou.trainer import OnpolicyTrainer from tianshou.utils import TensorboardLogger @@ -166,7 +166,7 @@ def test_ppo(args: argparse.Namespace = get_args()) -> None: writer = SummaryWriter(log_path) logger = TensorboardLogger(writer) - def save_best_fn(policy: BasePolicy) -> None: + def save_best_fn(policy: Algorithm) -> None: torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth")) def stop_fn(mean_rewards: float) -> bool: diff --git a/test/offline/gather_cartpole_data.py b/test/offline/gather_cartpole_data.py index bee9063ea..4c69d0315 100644 --- 
a/test/offline/gather_cartpole_data.py +++ b/test/offline/gather_cartpole_data.py @@ -15,7 +15,7 @@ ) from tianshou.env import DummyVectorEnv from tianshou.policy import QRDQNPolicy -from tianshou.policy.base import BasePolicy +from tianshou.policy.base import Algorithm from tianshou.policy.modelfree.qrdqn import QRDQNTrainingStats from tianshou.trainer import OffpolicyTrainer from tianshou.utils import TensorboardLogger @@ -129,7 +129,7 @@ def gather_data() -> VectorReplayBuffer | PrioritizedVectorReplayBuffer: writer = SummaryWriter(log_path) logger = TensorboardLogger(writer) - def save_best_fn(policy: BasePolicy) -> None: + def save_best_fn(policy: Algorithm) -> None: torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth")) def stop_fn(mean_rewards: float) -> bool: diff --git a/test/offline/gather_pendulum_data.py b/test/offline/gather_pendulum_data.py index 614ee388f..feca48794 100644 --- a/test/offline/gather_pendulum_data.py +++ b/test/offline/gather_pendulum_data.py @@ -10,7 +10,7 @@ from tianshou.data import Collector, CollectStats, VectorReplayBuffer from tianshou.env import DummyVectorEnv from tianshou.policy import SACPolicy -from tianshou.policy.base import BasePolicy +from tianshou.policy.base import Algorithm from tianshou.policy.modelfree.sac import SACTrainingStats from tianshou.trainer import OffpolicyTrainer from tianshou.utils import TensorboardLogger @@ -137,7 +137,7 @@ def gather_data() -> VectorReplayBuffer: writer = SummaryWriter(log_path) logger = TensorboardLogger(writer) - def save_best_fn(policy: BasePolicy) -> None: + def save_best_fn(policy: Algorithm) -> None: torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth")) def stop_fn(mean_rewards: float) -> bool: diff --git a/test/offline/test_bcq.py b/test/offline/test_bcq.py index 2ed910902..dfa1dfe50 100644 --- a/test/offline/test_bcq.py +++ b/test/offline/test_bcq.py @@ -11,7 +11,7 @@ from tianshou.data import Collector, CollectStats, VectorReplayBuffer from 
tianshou.env import DummyVectorEnv -from tianshou.policy import BasePolicy, BCQPolicy +from tianshou.policy import Algorithm, BCQPolicy from tianshou.policy.imitation.bcq import BCQTrainingStats from tianshou.trainer import OfflineTrainer from tianshou.utils import TensorboardLogger @@ -174,7 +174,7 @@ def test_bcq(args: argparse.Namespace = get_args()) -> None: writer.add_text("args", str(args)) logger = TensorboardLogger(writer) - def save_best_fn(policy: BasePolicy) -> None: + def save_best_fn(policy: Algorithm) -> None: torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth")) def stop_fn(mean_rewards: float) -> bool: diff --git a/test/offline/test_cql.py b/test/offline/test_cql.py index bd84098ba..56be57a9a 100644 --- a/test/offline/test_cql.py +++ b/test/offline/test_cql.py @@ -12,7 +12,7 @@ from tianshou.data import Collector, CollectStats, VectorReplayBuffer from tianshou.env import DummyVectorEnv -from tianshou.policy import BasePolicy, CQLPolicy +from tianshou.policy import Algorithm, CQLPolicy from tianshou.policy.imitation.cql import CQLTrainingStats from tianshou.trainer import OfflineTrainer from tianshou.utils import TensorboardLogger @@ -175,7 +175,7 @@ def test_cql(args: argparse.Namespace = get_args()) -> None: writer.add_text("args", str(args)) logger = TensorboardLogger(writer) - def save_best_fn(policy: BasePolicy) -> None: + def save_best_fn(policy: Algorithm) -> None: torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth")) def stop_fn(mean_rewards: float) -> bool: diff --git a/test/offline/test_discrete_bcq.py b/test/offline/test_discrete_bcq.py index e69e0a1fa..1b3cd701a 100644 --- a/test/offline/test_discrete_bcq.py +++ b/test/offline/test_discrete_bcq.py @@ -15,7 +15,7 @@ VectorReplayBuffer, ) from tianshou.env import DummyVectorEnv -from tianshou.policy import BasePolicy, DiscreteBCQPolicy +from tianshou.policy import Algorithm, DiscreteBCQPolicy from tianshou.trainer import OfflineTrainer from tianshou.utils 
import TensorboardLogger from tianshou.utils.net.common import ActorCritic, Net @@ -118,7 +118,7 @@ def test_discrete_bcq(args: argparse.Namespace = get_args()) -> None: writer = SummaryWriter(log_path) logger = TensorboardLogger(writer, save_interval=args.save_interval) - def save_best_fn(policy: BasePolicy) -> None: + def save_best_fn(policy: Algorithm) -> None: torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth")) def stop_fn(mean_rewards: float) -> bool: diff --git a/test/offline/test_discrete_cql.py b/test/offline/test_discrete_cql.py index 97766d494..2e7c5af70 100644 --- a/test/offline/test_discrete_cql.py +++ b/test/offline/test_discrete_cql.py @@ -15,7 +15,7 @@ VectorReplayBuffer, ) from tianshou.env import DummyVectorEnv -from tianshou.policy import BasePolicy, DiscreteCQLPolicy +from tianshou.policy import Algorithm, DiscreteCQLPolicy from tianshou.trainer import OfflineTrainer from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Net @@ -107,7 +107,7 @@ def test_discrete_cql(args: argparse.Namespace = get_args()) -> None: writer = SummaryWriter(log_path) logger = TensorboardLogger(writer) - def save_best_fn(policy: BasePolicy) -> None: + def save_best_fn(policy: Algorithm) -> None: torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth")) def stop_fn(mean_rewards: float) -> bool: diff --git a/test/offline/test_discrete_crr.py b/test/offline/test_discrete_crr.py index bf9a833a9..f13bc006d 100644 --- a/test/offline/test_discrete_crr.py +++ b/test/offline/test_discrete_crr.py @@ -15,7 +15,7 @@ VectorReplayBuffer, ) from tianshou.env import DummyVectorEnv -from tianshou.policy import BasePolicy, DiscreteCRRPolicy +from tianshou.policy import Algorithm, DiscreteCRRPolicy from tianshou.trainer import OfflineTrainer from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import ActorCritic, Net @@ -111,7 +111,7 @@ def test_discrete_crr(args: argparse.Namespace = get_args()) -> None: 
writer = SummaryWriter(log_path) logger = TensorboardLogger(writer) - def save_best_fn(policy: BasePolicy) -> None: + def save_best_fn(policy: Algorithm) -> None: torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth")) def stop_fn(mean_rewards: float) -> bool: diff --git a/test/offline/test_gail.py b/test/offline/test_gail.py index c7f183587..099b96d0a 100644 --- a/test/offline/test_gail.py +++ b/test/offline/test_gail.py @@ -11,7 +11,7 @@ from tianshou.data import Collector, CollectStats, VectorReplayBuffer from tianshou.env import DummyVectorEnv -from tianshou.policy import BasePolicy, GAILPolicy +from tianshou.policy import Algorithm, GAILPolicy from tianshou.trainer import OnpolicyTrainer from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import ActorCritic, Net @@ -171,7 +171,7 @@ def dist(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution: writer = SummaryWriter(log_path) logger = TensorboardLogger(writer, save_interval=args.save_interval) - def save_best_fn(policy: BasePolicy) -> None: + def save_best_fn(policy: Algorithm) -> None: torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth")) def stop_fn(mean_rewards: float) -> bool: diff --git a/test/offline/test_td3_bc.py b/test/offline/test_td3_bc.py index 17c3afb06..cae6f6f06 100644 --- a/test/offline/test_td3_bc.py +++ b/test/offline/test_td3_bc.py @@ -13,7 +13,7 @@ from tianshou.env import DummyVectorEnv from tianshou.exploration import GaussianNoise from tianshou.policy import TD3BCPolicy -from tianshou.policy.base import BasePolicy +from tianshou.policy.base import Algorithm from tianshou.trainer import OfflineTrainer from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Net @@ -162,7 +162,7 @@ def test_td3_bc(args: argparse.Namespace = get_args()) -> None: writer.add_text("args", str(args)) logger = TensorboardLogger(writer) - def save_best_fn(policy: BasePolicy) -> None: + def save_best_fn(policy: Algorithm) -> 
None: torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth")) def stop_fn(mean_rewards: float) -> bool: diff --git a/test/pettingzoo/pistonball.py b/test/pettingzoo/pistonball.py index 0cf269d4f..88186c177 100644 --- a/test/pettingzoo/pistonball.py +++ b/test/pettingzoo/pistonball.py @@ -11,7 +11,7 @@ from tianshou.data import Collector, CollectStats, InfoStats, VectorReplayBuffer from tianshou.env import DummyVectorEnv from tianshou.env.pettingzoo_env import PettingZooEnv -from tianshou.policy import BasePolicy, DQNPolicy, MultiAgentPolicyManager +from tianshou.policy import Algorithm, DQNPolicy, MultiAgentPolicyManager from tianshou.trainer import OffpolicyTrainer from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Net @@ -74,9 +74,9 @@ def get_env(args: argparse.Namespace = get_args()) -> PettingZooEnv: def get_agents( args: argparse.Namespace = get_args(), - agents: list[BasePolicy] | None = None, + agents: list[Algorithm] | None = None, optims: list[torch.optim.Optimizer] | None = None, -) -> tuple[BasePolicy, list[torch.optim.Optimizer] | None, list]: +) -> tuple[Algorithm, list[torch.optim.Optimizer] | None, list]: env = get_env() observation_space = ( env.observation_space["observation"] @@ -114,9 +114,9 @@ def get_agents( def train_agent( args: argparse.Namespace = get_args(), - agents: list[BasePolicy] | None = None, + agents: list[Algorithm] | None = None, optims: list[torch.optim.Optimizer] | None = None, -) -> tuple[InfoStats, BasePolicy]: +) -> tuple[InfoStats, Algorithm]: train_envs = DummyVectorEnv([get_env for _ in range(args.training_num)]) test_envs = DummyVectorEnv([get_env for _ in range(args.test_num)]) # seed @@ -143,7 +143,7 @@ def train_agent( writer.add_text("args", str(args)) logger = TensorboardLogger(writer) - def save_best_fn(policy: BasePolicy) -> None: + def save_best_fn(policy: Algorithm) -> None: pass def stop_fn(mean_rewards: float) -> bool: @@ -181,7 +181,7 @@ def reward_metric(rews: 
np.ndarray) -> np.ndarray: return result, policy -def watch(args: argparse.Namespace = get_args(), policy: BasePolicy | None = None) -> None: +def watch(args: argparse.Namespace = get_args(), policy: Algorithm | None = None) -> None: env = DummyVectorEnv([get_env]) if not policy: warnings.warn( diff --git a/test/pettingzoo/pistonball_continuous.py b/test/pettingzoo/pistonball_continuous.py index 7beb92fde..935e65a6e 100644 --- a/test/pettingzoo/pistonball_continuous.py +++ b/test/pettingzoo/pistonball_continuous.py @@ -15,7 +15,7 @@ from tianshou.data.stats import InfoStats from tianshou.env import DummyVectorEnv from tianshou.env.pettingzoo_env import PettingZooEnv -from tianshou.policy import BasePolicy, MultiAgentPolicyManager, PPOPolicy +from tianshou.policy import Algorithm, MultiAgentPolicyManager, PPOPolicy from tianshou.trainer import OnpolicyTrainer from tianshou.utils import TensorboardLogger from tianshou.utils.net.continuous import ActorProb, Critic @@ -138,9 +138,9 @@ def get_env(args: argparse.Namespace = get_args()) -> PettingZooEnv: def get_agents( args: argparse.Namespace = get_args(), - agents: list[BasePolicy] | None = None, + agents: list[Algorithm] | None = None, optims: list[torch.optim.Optimizer] | None = None, -) -> tuple[BasePolicy, list[torch.optim.Optimizer] | None, list]: +) -> tuple[Algorithm, list[torch.optim.Optimizer] | None, list]: env = get_env() observation_space = ( env.observation_space["observation"] @@ -220,9 +220,9 @@ def dist(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution: def train_agent( args: argparse.Namespace = get_args(), - agents: list[BasePolicy] | None = None, + agents: list[Algorithm] | None = None, optims: list[torch.optim.Optimizer] | None = None, -) -> tuple[InfoStats, BasePolicy]: +) -> tuple[InfoStats, Algorithm]: train_envs = DummyVectorEnv([get_env for _ in range(args.training_num)]) test_envs = DummyVectorEnv([get_env for _ in range(args.test_num)]) # seed @@ -248,7 +248,7 @@ def train_agent( 
writer.add_text("args", str(args)) logger = TensorboardLogger(writer) - def save_best_fn(policy: BasePolicy) -> None: + def save_best_fn(policy: Algorithm) -> None: pass def stop_fn(mean_rewards: float) -> bool: @@ -277,7 +277,7 @@ def reward_metric(rews: np.ndarray) -> np.ndarray: return result, policy -def watch(args: argparse.Namespace = get_args(), policy: BasePolicy | None = None) -> None: +def watch(args: argparse.Namespace = get_args(), policy: Algorithm | None = None) -> None: env = DummyVectorEnv([get_env]) if not policy: warnings.warn( diff --git a/tianshou/data/collector.py b/tianshou/data/collector.py index 514db4cc0..7fc661064 100644 --- a/tianshou/data/collector.py +++ b/tianshou/data/collector.py @@ -30,7 +30,7 @@ RolloutBatchProtocol, ) from tianshou.env import BaseVectorEnv, DummyVectorEnv -from tianshou.policy import BasePolicy +from tianshou.policy import Algorithm from tianshou.policy.base import episode_mc_return_to_go from tianshou.utils.print import DataclassPPrintMixin from tianshou.utils.torch_utils import torch_train_mode @@ -309,7 +309,7 @@ class BaseCollector(Generic[TCollectStats], ABC): def __init__( self, - policy: BasePolicy, + algorithm: Algorithm, env: BaseVectorEnv | gym.Env, buffer: ReplayBuffer | None = None, exploration_noise: bool = False, @@ -327,7 +327,7 @@ def __init__( self.buffer: ReplayBuffer | ReplayBufferManager = buffer self.raise_on_nan_in_buffer = raise_on_nan_in_buffer - self.policy = policy + self.algorithm = algorithm self.env = cast(BaseVectorEnv, env) self.exploration_noise = exploration_noise self.collect_step, self.collect_episode, self.collect_time = 0, 0, 0.0 @@ -469,7 +469,7 @@ def collect( self.reset(reset_buffer=False, gym_reset_kwargs=gym_reset_kwargs) pre_collect_time = time.time() - with torch_train_mode(self.policy, enabled=False): + with torch_train_mode(self.algorithm.policy, enabled=False): collect_stats = self._collect( n_step=n_step, n_episode=n_episode, @@ -548,7 +548,7 @@ class 
Collector(BaseCollector[TCollectStats], Generic[TCollectStats]): # def __init__( self, - policy: BasePolicy, + algorithm: Algorithm, env: gym.Env | BaseVectorEnv, buffer: ReplayBuffer | None = None, exploration_noise: bool = False, @@ -558,7 +558,7 @@ def __init__( collect_stats_class: type[TCollectStats] = CollectStats, # type: ignore[assignment] ) -> None: """ - :param policy: a tianshou policy, each :class:`BasePolicy` is capable of computing a batch + :param algorithm: a tianshou algorithm, whose policy (a :class:`~tianshou.policy.base.Policy`) is capable of computing a batch of actions from a batch of observations. :param env: a ``gymnasium.Env`` environment or a vectorized instance of the :class:`~tianshou.env.BaseVectorEnv` class. The latter is strongly recommended, as with @@ -599,7 +599,7 @@ def __init__( this is rarely necessary and is mainly done by "power users". """ super().__init__( - policy, + algorithm, env, buffer, exploration_noise=exploration_noise, @@ -691,7 +691,7 @@ def _compute_action_policy_hidden( # TODO: test whether envpool env explicitly except TypeError: # envpool's action space is not for per-env act_normalized_RA = np.array([self._action_space.sample() for _ in ready_env_ids_R]) - act_RA = self.policy.map_action_inverse(np.array(act_normalized_RA)) + act_RA = self.algorithm.map_action_inverse(np.array(act_normalized_RA)) policy_R = Batch() hidden_state_RH = None # TODO: instead use a (uniform) Distribution instance that corresponds to sampling from action_space @@ -701,15 +701,15 @@ def _compute_action_policy_hidden( info_batch = _HACKY_create_info_batch(last_info_R) obs_batch_R = cast(ObsBatchProtocol, Batch(obs=last_obs_RO, info=info_batch)) - act_batch_RA: ActBatchProtocol | DistBatchProtocol = self.policy( + act_batch_RA: ActBatchProtocol | DistBatchProtocol = self.algorithm.policy( obs_batch_R, last_hidden_state_RH, ) act_RA = to_numpy(act_batch_RA.act) if self.exploration_noise: - act_RA = self.policy.exploration_noise(act_RA, 
obs_batch_R) - act_normalized_RA =
self.policy.map_action(act_RA) + act_RA = self.algorithm.exploration_noise(act_RA, obs_batch_R) + act_normalized_RA = self.algorithm.policy.map_action(act_RA) # TODO: cleanup the whole policy in batch thing # todo policy_R can also be none, check @@ -1084,7 +1084,7 @@ class AsyncCollector(Collector[CollectStats]): def __init__( self, - policy: BasePolicy, + algorithm: Algorithm, env: BaseVectorEnv, buffer: ReplayBuffer | None = None, exploration_noise: bool = False, @@ -1099,7 +1099,7 @@ def __init__( # assert env.is_async warnings.warn("Using async setting may collect extra transitions into buffer.") super().__init__( - policy, + algorithm, env, buffer, exploration_noise, diff --git a/tianshou/highlevel/agent.py b/tianshou/highlevel/agent.py index 2ad533983..9f7f1011c 100644 --- a/tianshou/highlevel/agent.py +++ b/tianshou/highlevel/agent.py @@ -46,15 +46,15 @@ from tianshou.highlevel.world import World from tianshou.policy import ( A2CPolicy, - BasePolicy, + Algorithm, DDPGPolicy, DiscreteSACPolicy, DQNPolicy, IQNPolicy, NPGPolicy, - PGPolicy, PPOPolicy, REDQPolicy, + Reinforce, SACPolicy, TD3Policy, TRPOPolicy, @@ -78,7 +78,7 @@ "TDiscreteCriticOnlyParams", bound=Params | ParamsMixinLearningRateWithScheduler, ) -TPolicy = TypeVar("TPolicy", bound=BasePolicy) +TPolicy = TypeVar("TPolicy", bound=Algorithm) log = logging.getLogger(__name__) @@ -93,7 +93,7 @@ def __init__(self, sampling_config: SamplingConfig, optim_factory: OptimizerFact def create_train_test_collector( self, - policy: BasePolicy, + policy: Algorithm, envs: Environments, reset_collectors: bool = True, ) -> tuple[BaseCollector, BaseCollector]: @@ -143,10 +143,10 @@ def set_trainer_callbacks(self, callbacks: TrainerCallbacks) -> None: self.trainer_callbacks = callbacks @abstractmethod - def _create_policy(self, envs: Environments, device: TDevice) -> BasePolicy: + def _create_policy(self, envs: Environments, device: TDevice) -> Algorithm: pass - def create_policy(self, envs: Environments, device: 
TDevice) -> BasePolicy: + def create_policy(self, envs: Environments, device: TDevice) -> Algorithm: policy = self._create_policy(envs, device) if self.policy_wrapper_factory is not None: policy = self.policy_wrapper_factory.create_wrapped_policy( @@ -269,7 +269,7 @@ def __init__( self.actor_factory = actor_factory self.optim_factory = optim_factory - def _create_policy(self, envs: Environments, device: TDevice) -> PGPolicy: + def _create_policy(self, envs: Environments, device: TDevice) -> Reinforce: actor = self.actor_factory.create_module_opt( envs, device, @@ -286,7 +286,7 @@ def _create_policy(self, envs: Environments, device: TDevice) -> PGPolicy: ) dist_fn = self.actor_factory.create_dist_fn(envs) assert dist_fn is not None - return PGPolicy( + return Reinforce( actor=actor.module, optim=actor.optim, action_space=envs.get_action_space(), @@ -444,7 +444,7 @@ def __init__( self.params = params self.optim_factory = optim_factory - def _create_policy(self, envs: Environments, device: TDevice) -> BasePolicy: + def _create_policy(self, envs: Environments, device: TDevice) -> Algorithm: actor = self.actor_factory.create_module_opt( envs, device, @@ -493,7 +493,7 @@ def __init__( self.params = params self.optim_factory = optim_factory - def _create_policy(self, envs: Environments, device: TDevice) -> BasePolicy: + def _create_policy(self, envs: Environments, device: TDevice) -> Algorithm: envs.get_type().assert_continuous(self) actor = self.actor_factory.create_module_opt( envs, diff --git a/tianshou/highlevel/experiment.py b/tianshou/highlevel/experiment.py index c0be23dca..9874d26c9 100644 --- a/tianshou/highlevel/experiment.py +++ b/tianshou/highlevel/experiment.py @@ -108,7 +108,7 @@ TrainerCallbacks, ) from tianshou.highlevel.world import World -from tianshou.policy import BasePolicy +from tianshou.policy import Algorithm from tianshou.utils import LazyLogger from tianshou.utils.net.common import ModuleType from tianshou.utils.print import DataclassPPrintMixin 
@@ -456,7 +456,7 @@ def run( @staticmethod def _watch_agent( num_episodes: int, - policy: BasePolicy, + policy: Algorithm, env: BaseVectorEnv, render: float, ) -> None: diff --git a/tianshou/highlevel/params/policy_wrapper.py b/tianshou/highlevel/params/policy_wrapper.py index 43cbfed1e..ab3994151 100644 --- a/tianshou/highlevel/params/policy_wrapper.py +++ b/tianshou/highlevel/params/policy_wrapper.py @@ -8,17 +8,17 @@ from tianshou.highlevel.module.core import TDevice from tianshou.highlevel.module.intermediate import IntermediateModuleFactory from tianshou.highlevel.optim import OptimizerFactory -from tianshou.policy import BasePolicy, ICMPolicy +from tianshou.policy import Algorithm, ICMPolicy from tianshou.utils.net.discrete import IntrinsicCuriosityModule -TPolicyOut = TypeVar("TPolicyOut", bound=BasePolicy) +TPolicyOut = TypeVar("TPolicyOut", bound=Algorithm) class PolicyWrapperFactory(Generic[TPolicyOut], ToStringMixin, ABC): @abstractmethod def create_wrapped_policy( self, - policy: BasePolicy, + policy: Algorithm, envs: Environments, optim_factory: OptimizerFactory, device: TDevice, @@ -48,7 +48,7 @@ def __init__( def create_wrapped_policy( self, - policy: BasePolicy, + policy: Algorithm, envs: Environments, optim_factory: OptimizerFactory, device: TDevice, diff --git a/tianshou/highlevel/trainer.py b/tianshou/highlevel/trainer.py index 498cc3173..b4c2357bb 100644 --- a/tianshou/highlevel/trainer.py +++ b/tianshou/highlevel/trainer.py @@ -8,9 +8,9 @@ from tianshou.highlevel.env import Environments from tianshou.highlevel.logger import TLogger -from tianshou.policy import BasePolicy, DQNPolicy +from tianshou.policy import Algorithm, DQNPolicy -TPolicy = TypeVar("TPolicy", bound=BasePolicy) +TPolicy = TypeVar("TPolicy", bound=Algorithm) log = logging.getLogger(__name__) diff --git a/tianshou/highlevel/world.py b/tianshou/highlevel/world.py index 6db216b15..439ed0b7e 100644 --- a/tianshou/highlevel/world.py +++ b/tianshou/highlevel/world.py @@ -6,7 +6,7 @@ 
from tianshou.data import BaseCollector from tianshou.highlevel.env import Environments from tianshou.highlevel.logger import TLogger - from tianshou.policy import BasePolicy + from tianshou.policy import Algorithm from tianshou.trainer import BaseTrainer @@ -15,7 +15,7 @@ class World: """Container for instances and configuration items that are relevant to an experiment.""" envs: "Environments" - policy: "BasePolicy" + policy: "Algorithm" train_collector: Optional["BaseCollector"] = None test_collector: Optional["BaseCollector"] = None logger: "TLogger" diff --git a/tianshou/policy/__init__.py b/tianshou/policy/__init__.py index a9b944da8..1716f364e 100644 --- a/tianshou/policy/__init__.py +++ b/tianshou/policy/__init__.py @@ -1,7 +1,10 @@ """Policy package.""" # isort:skip_file -from tianshou.policy.base import BasePolicy, TrainingStats +from tianshou.policy.base import Algorithm, TrainingStats +from tianshou.policy.modelfree.pg import Reinforce + +""" from tianshou.policy.random import MARLRandomPolicy from tianshou.policy.modelfree.dqn import DQNPolicy from tianshou.policy.modelfree.bdq import BranchingDQNPolicy @@ -10,7 +13,6 @@ from tianshou.policy.modelfree.qrdqn import QRDQNPolicy from tianshou.policy.modelfree.iqn import IQNPolicy from tianshou.policy.modelfree.fqf import FQFPolicy -from tianshou.policy.modelfree.pg import PGPolicy from tianshou.policy.modelfree.a2c import A2CPolicy from tianshou.policy.modelfree.npg import NPGPolicy from tianshou.policy.modelfree.ddpg import DDPGPolicy @@ -31,37 +33,38 @@ from tianshou.policy.modelbased.psrl import PSRLPolicy from tianshou.policy.modelbased.icm import ICMPolicy from tianshou.policy.multiagent.mapolicy import MultiAgentPolicyManager +""" __all__ = [ - "BasePolicy", - "MARLRandomPolicy", - "DQNPolicy", - "BranchingDQNPolicy", - "C51Policy", - "RainbowPolicy", - "QRDQNPolicy", - "IQNPolicy", - "FQFPolicy", - "PGPolicy", - "A2CPolicy", - "NPGPolicy", - "DDPGPolicy", - "PPOPolicy", - "TRPOPolicy", - 
"TD3Policy", - "SACPolicy", - "REDQPolicy", - "DiscreteSACPolicy", - "ImitationPolicy", - "BCQPolicy", - "CQLPolicy", - "TD3BCPolicy", - "DiscreteBCQPolicy", - "DiscreteCQLPolicy", - "DiscreteCRRPolicy", - "GAILPolicy", - "PSRLPolicy", - "ICMPolicy", - "MultiAgentPolicyManager", + "Algorithm", + # "MARLRandomPolicy", + # "DQNPolicy", + # "BranchingDQNPolicy", + # "C51Policy", + # "RainbowPolicy", + # "QRDQNPolicy", + # "IQNPolicy", + # "FQFPolicy", + "Reinforce", + # "A2CPolicy", + # "NPGPolicy", + # "DDPGPolicy", + # "PPOPolicy", + # "TRPOPolicy", + # "TD3Policy", + # "SACPolicy", + # "REDQPolicy", + # "DiscreteSACPolicy", + # "ImitationPolicy", + # "BCQPolicy", + # "CQLPolicy", + # "TD3BCPolicy", + # "DiscreteBCQPolicy", + # "DiscreteCQLPolicy", + # "DiscreteCRRPolicy", + # "GAILPolicy", + # "PSRLPolicy", + # "ICMPolicy", + # "MultiAgentPolicyManager", "TrainingStats", ] diff --git a/tianshou/policy/base.py b/tianshou/policy/base.py index 066a23a3b..798b65ac7 100644 --- a/tianshou/policy/base.py +++ b/tianshou/policy/base.py @@ -3,7 +3,7 @@ from abc import ABC, abstractmethod from collections.abc import Callable from dataclasses import dataclass, field -from typing import Any, Generic, Literal, TypeAlias, TypeVar, cast +from typing import TYPE_CHECKING, Any, Generic, Literal, TypeAlias, TypeVar, cast import gymnasium as gym import numpy as np @@ -29,6 +29,9 @@ from tianshou.utils.print import DataclassPPrintMixin from tianshou.utils.torch_utils import policy_within_training_step, torch_train_mode +if TYPE_CHECKING: + from tianshou.highlevel.config import SamplingConfig + logger = logging.getLogger(__name__) TLearningRateScheduler: TypeAlias = torch.optim.lr_scheduler.LRScheduler | MultipleLRSchedulers @@ -133,71 +136,15 @@ def __setattr__(self, name: str, value: Any) -> None: TTrainingStats = TypeVar("TTrainingStats", bound=TrainingStats) -class BasePolicy(nn.Module, Generic[TTrainingStats], ABC): - """The base class for any RL policy. 
- - Tianshou aims to modularize RL algorithms. It comes into several classes of - policies in Tianshou. All policy classes must inherit from - :class:`~tianshou.policy.BasePolicy`. - - A policy class typically has the following parts: - - * :meth:`~tianshou.policy.BasePolicy.__init__`: initialize the policy, including \ - coping the target network and so on; - * :meth:`~tianshou.policy.BasePolicy.forward`: compute action with given \ - observation; - * :meth:`~tianshou.policy.BasePolicy.process_fn`: pre-process data from the \ - replay buffer (this function can interact with replay buffer); - * :meth:`~tianshou.policy.BasePolicy.learn`: update policy with a given batch of \ - data. - * :meth:`~tianshou.policy.BasePolicy.post_process_fn`: update the replay buffer \ - from the learning process (e.g., prioritized replay buffer needs to update \ - the weight); - * :meth:`~tianshou.policy.BasePolicy.update`: the main interface for training, \ - i.e., `process_fn -> learn -> post_process_fn`. - - Most of the policy needs a neural network to predict the action and an - optimizer to optimize the policy. The rules of self-defined networks are: - - 1. Input: observation "obs" (may be a ``numpy.ndarray``, a ``torch.Tensor``, a \ - dict or any others), hidden state "state" (for RNN usage), and other information \ - "info" provided by the environment. - 2. Output: some "logits", the next hidden state "state", and the intermediate \ - result during policy forwarding procedure "policy". The "logits" could be a tuple \ - instead of a ``torch.Tensor``. It depends on how the policy process the network \ - output. For example, in PPO, the return of the network might be \ - ``(mu, sigma), state`` for Gaussian policy. The "policy" can be a Batch of \ - torch.Tensor or other things, which will be stored in the replay buffer, and can \ - be accessed in the policy update process (e.g. in "policy.learn()", the \ - "batch.policy" is what you need). 
- - Since :class:`~tianshou.policy.BasePolicy` inherits ``torch.nn.Module``, you can - use :class:`~tianshou.policy.BasePolicy` almost the same as ``torch.nn.Module``, - for instance, loading and saving the model: - :: - - torch.save(policy.state_dict(), "policy.pth") - policy.load_state_dict(torch.load("policy.pth")) - - :param action_space: Env's action_space. - :param observation_space: Env's observation space. TODO: appears unused... - :param action_scaling: if True, scale the action from [-1, 1] to the range - of action_space. Only used if the action_space is continuous. - :param action_bound_method: method to bound action to range [-1, 1]. - Only used if the action_space is continuous. - :param lr_scheduler: if not None, will be called in `policy.update()`. - """ - +class Policy(nn.Module, ABC): def __init__( self, - *, action_space: gym.Space, # TODO: does the policy actually need the observation space? observation_space: gym.Space | None = None, action_scaling: bool = False, action_bound_method: Literal["clip", "tanh"] | None = "clip", - lr_scheduler: TLearningRateScheduler | None = None, - ) -> None: + ): allowed_action_bound_methods = ("clip", "tanh") if ( action_bound_method is not None @@ -212,7 +159,6 @@ def __init__( f"action_scaling can only be True when action_space is Box but " f"got: {action_space}", ) - super().__init__() self.observation_space = observation_space self.action_space = action_space @@ -227,7 +173,6 @@ def __init__( self.updating = False self.action_scaling = action_scaling self.action_bound_method = action_bound_method - self.lr_scheduler = lr_scheduler self.is_within_training_step = False """ flag indicating whether we are currently within a training step, @@ -245,75 +190,10 @@ def __init__( """ self._compile() - def __setstate__(self, state: dict[str, Any]) -> None: - # TODO Use setstate function once merged - if "is_within_training_step" not in state: - state["is_within_training_step"] = False - self.__dict__ = state - @property 
def action_type(self) -> Literal["discrete", "continuous"]: return self._action_type - def set_agent_id(self, agent_id: int) -> None: - """Set self.agent_id = agent_id, for MARL.""" - self.agent_id = agent_id - - # TODO: needed, since for most of offline algorithm, the algorithm itself doesn't - # have a method to add noise to action. - # So we add the default behavior here. It's a little messy, maybe one can - # find a better way to do this. - - _TArrOrActBatch = TypeVar("_TArrOrActBatch", bound="np.ndarray | ActBatchProtocol") - - def exploration_noise( - self, - act: _TArrOrActBatch, - batch: ObsBatchProtocol, - ) -> _TArrOrActBatch: - """Modify the action from policy.forward with exploration noise. - - NOTE: currently does not add any noise! Needs to be overridden by subclasses - to actually do something. - - :param act: a data batch or numpy.ndarray which is the action taken by - policy.forward. - :param batch: the input batch for policy.forward, kept for advanced usage. - :return: action in the same form of input "act" but with added exploration - noise. - """ - return act - - def soft_update(self, tgt: nn.Module, src: nn.Module, tau: float) -> None: - """Softly update the parameters of target module towards the parameters of source module.""" - for tgt_param, src_param in zip(tgt.parameters(), src.parameters(), strict=True): - tgt_param.data.copy_(tau * src_param.data + (1 - tau) * tgt_param.data) - - def compute_action( - self, - obs: ArrayLike, - info: dict[str, Any] | None = None, - state: dict | BatchProtocol | np.ndarray | None = None, - ) -> np.ndarray | int: - """Get action as int (for discrete env's) or array (for continuous ones) from an env's observation and info. - - :param obs: observation from the gym's env. - :param info: information given by the gym's env. - :param state: the hidden state of RNN policy, used for recurrent policy. - :return: action as int (for discrete env's) or array (for continuous ones). 
- """ - obs = np.array(obs) # convert array-like to array (e.g. LazyFrames) - obs = obs[None, :] # add batch dimension - obs_batch = cast(ObsBatchProtocol, Batch(obs=obs, info=info)) - act = self.forward(obs_batch, state=state).act.squeeze() - if isinstance(act, torch.Tensor): - act = act.detach().cpu().numpy() - act = self.map_action(act) - if isinstance(self.action_space, Discrete): - # could be an array of shape (), easier to just convert to int - act = int(act) # type: ignore - return act - @abstractmethod def forward( self, @@ -427,6 +307,152 @@ def map_action_inverse( return act + def compute_action( + self, + obs: ArrayLike, + info: dict[str, Any] | None = None, + state: dict | BatchProtocol | np.ndarray | None = None, + ) -> np.ndarray | int: + """Get action as int (for discrete env's) or array (for continuous ones) from an env's observation and info. + + :param obs: observation from the gym's env. + :param info: information given by the gym's env. + :param state: the hidden state of RNN policy, used for recurrent policy. + :return: action as int (for discrete env's) or array (for continuous ones). + """ + obs = np.array(obs) # convert array-like to array (e.g. 
LazyFrames) + obs = obs[None, :] # add batch dimension + obs_batch = cast(ObsBatchProtocol, Batch(obs=obs, info=info)) + act = self.forward(obs_batch, state=state).act.squeeze() + if isinstance(act, torch.Tensor): + act = act.detach().cpu().numpy() + act = self.map_action(act) + if isinstance(self.action_space, Discrete): + # could be an array of shape (), easier to just convert to int + act = int(act) # type: ignore + return act + + @staticmethod + def _compile() -> None: + f64 = np.array([0, 1], dtype=np.float64) + f32 = np.array([0, 1], dtype=np.float32) + b = np.array([False, True], dtype=np.bool_) + i64 = np.array([[0, 1]], dtype=np.int64) + _gae_return(f64, f64, f64, b, 0.1, 0.1) + _gae_return(f32, f32, f64, b, 0.1, 0.1) + _nstep_return(f64, b, f32.reshape(-1, 1), i64, 0.1, 1) + + +TPolicy = TypeVar("TPolicy", bound=Policy) + + +class Algorithm(Generic[TPolicy, TTrainingStats], ABC): + """ + TODO fix docstring + The base class for any RL policy. + + Tianshou aims to modularize RL algorithms. It comes into several classes of + policies in Tianshou. All policy classes must inherit from + :class:`~tianshou.policy.BasePolicy`. + + A policy class typically has the following parts: + + * :meth:`~tianshou.policy.BasePolicy.__init__`: initialize the policy, including \ + coping the target network and so on; + * :meth:`~tianshou.policy.BasePolicy.forward`: compute action with given \ + observation; + * :meth:`~tianshou.policy.BasePolicy.process_fn`: pre-process data from the \ + replay buffer (this function can interact with replay buffer); + * :meth:`~tianshou.policy.BasePolicy.learn`: update policy with a given batch of \ + data. + * :meth:`~tianshou.policy.BasePolicy.post_process_fn`: update the replay buffer \ + from the learning process (e.g., prioritized replay buffer needs to update \ + the weight); + * :meth:`~tianshou.policy.BasePolicy.update`: the main interface for training, \ + i.e., `process_fn -> learn -> post_process_fn`. 
+ + Most of the policy needs a neural network to predict the action and an + optimizer to optimize the policy. The rules of self-defined networks are: + + 1. Input: observation "obs" (may be a ``numpy.ndarray``, a ``torch.Tensor``, a \ + dict or any others), hidden state "state" (for RNN usage), and other information \ + "info" provided by the environment. + 2. Output: some "logits", the next hidden state "state", and the intermediate \ + result during policy forwarding procedure "policy". The "logits" could be a tuple \ + instead of a ``torch.Tensor``. It depends on how the policy process the network \ + output. For example, in PPO, the return of the network might be \ + ``(mu, sigma), state`` for Gaussian policy. The "policy" can be a Batch of \ + torch.Tensor or other things, which will be stored in the replay buffer, and can \ + be accessed in the policy update process (e.g. in "policy.learn()", the \ + "batch.policy" is what you need). + + Since :class:`~tianshou.policy.BasePolicy` inherits ``torch.nn.Module``, you can + use :class:`~tianshou.policy.BasePolicy` almost the same as ``torch.nn.Module``, + for instance, loading and saving the model: + :: + + torch.save(policy.state_dict(), "policy.pth") + policy.load_state_dict(torch.load("policy.pth")) + + :param action_space: Env's action_space. + :param observation_space: Env's observation space. TODO: appears unused... + :param action_scaling: if True, scale the action from [-1, 1] to the range + of action_space. Only used if the action_space is continuous. + :param action_bound_method: method to bound action to range [-1, 1]. + Only used if the action_space is continuous. + :param lr_scheduler: if not None, will be called in `policy.update()`. 
+ """ + + def __init__( + self, + *, + policy: TPolicy, + lr_scheduler: TLearningRateScheduler | None = None, + ) -> None: + self.policy = policy + self.lr_scheduler = lr_scheduler + + # TODO delete this + def __setstate__(self, state: dict[str, Any]) -> None: + # TODO Use setstate function once merged + if "is_within_training_step" not in state: + state["is_within_training_step"] = False + self.__dict__ = state + + def set_agent_id(self, agent_id: int) -> None: + """Set self.agent_id = agent_id, for MARL.""" + self.agent_id = agent_id + + # TODO: needed, since for most of offline algorithm, the algorithm itself doesn't + # have a method to add noise to action. + # So we add the default behavior here. It's a little messy, maybe one can + # find a better way to do this. + + _TArrOrActBatch = TypeVar("_TArrOrActBatch", bound="np.ndarray | ActBatchProtocol") + + def exploration_noise( + self, + act: _TArrOrActBatch, + batch: ObsBatchProtocol, + ) -> _TArrOrActBatch: + """Modify the action from policy.forward with exploration noise. + + NOTE: currently does not add any noise! Needs to be overridden by subclasses + to actually do something. + + :param act: a data batch or numpy.ndarray which is the action taken by + policy.forward. + :param batch: the input batch for policy.forward, kept for advanced usage. + :return: action in the same form of input "act" but with added exploration + noise. + """ + return act + + def soft_update(self, tgt: nn.Module, src: nn.Module, tau: float) -> None: + """Softly update the parameters of target module towards the parameters of source module.""" + for tgt_param, src_param in zip(tgt.parameters(), src.parameters(), strict=True): + tgt_param.data.copy_(tau * src_param.data + (1 - tau) * tgt_param.data) + def process_buffer(self, buffer: TBuffer) -> TBuffer: """Pre-process the replay buffer, e.g., to add new keys. 
@@ -457,7 +483,12 @@ def process_fn( return batch @abstractmethod - def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> TTrainingStats: + def _update_with_batch( + self, + batch: RolloutBatchProtocol, + *args: Any, + **kwargs: Any, + ) -> TTrainingStats: """Update policy with a given batch of data. :return: A dataclass object, including the data needed to be logged (e.g., loss). @@ -530,9 +561,9 @@ def update( # TODO: when does this happen? # -> this happens never in practice as update is either called with a collector buffer or an assert before - if not self.is_within_training_step: + if not self.policy.is_within_training_step: raise RuntimeError( - f"update() was called outside of a training step as signalled by {self.is_within_training_step=} " + f"update() was called outside of a training step as signalled by {self.policy.is_within_training_step=} " f"If you want to update the policy without a Trainer, you will have to manage the above-mentioned " f"flag yourself. 
You can to this e.g., by using the contextmanager {policy_within_training_step.__name__}.", ) @@ -543,8 +574,8 @@ def update( batch, indices = buffer.sample(sample_size) self.updating = True batch = self.process_fn(batch, buffer, indices) - with torch_train_mode(self): - training_stat = self.learn(batch, **kwargs) + with torch_train_mode(self.policy): + training_stat = self._update_with_batch(batch, **kwargs) self.post_process_fn(batch, buffer, indices) if self.lr_scheduler is not None: self.lr_scheduler.step() @@ -615,7 +646,7 @@ def compute_episodic_return( v_s_ = np.zeros_like(rew) else: v_s_ = to_numpy(v_s_.flatten()) - v_s_ = v_s_ * BasePolicy.value_mask(buffer, indices) + v_s_ = v_s_ * Algorithm.value_mask(buffer, indices) v_s = np.roll(v_s_, 1) if v_s is None else to_numpy(v_s.flatten()) end_flag = np.logical_or(batch.terminated, batch.truncated) @@ -699,7 +730,7 @@ def compute_nstep_return( target_q_IA = to_numpy(target_q_torch_IA.reshape(I, -1)) """Represents the Q-values (one for each action) of the transition after N steps.""" - target_q_IA *= BasePolicy.value_mask(buffer, indices_after_n_steps_I).reshape(-1, 1) + target_q_IA *= Algorithm.value_mask(buffer, indices_after_n_steps_I).reshape(-1, 1) end_flag_B = buffer.done.copy() end_flag_B[buffer.unfinished_index()] = True n_step_return_IA = _nstep_return( @@ -720,18 +751,14 @@ def compute_nstep_return( return cast(BatchWithReturnsProtocol, batch) - @staticmethod - def _compile() -> None: - f64 = np.array([0, 1], dtype=np.float64) - f32 = np.array([0, 1], dtype=np.float32) - b = np.array([False, True], dtype=np.bool_) - i64 = np.array([[0, 1]], dtype=np.int64) - _gae_return(f64, f64, f64, b, 0.1, 0.1) - _gae_return(f32, f32, f64, b, 0.1, 0.1) - _nstep_return(f64, b, f32.reshape(-1, 1), i64, 0.1, 1) + def _create_trainer(self): + pass + def train(self, sampling_config: "SamplingConfig"): + pass -class RandomActionPolicy(BasePolicy): + +class RandomActionPolicy(Algorithm): def __init__( self, action_space: 
gym.Space, @@ -752,7 +779,12 @@ def forward( act, next_state = self.actor.compute_action_batch(batch.obs), state return cast(ActStateBatchProtocol, Batch(act=act, state=next_state)) - def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> TrainingStats: + def _update_with_batch( + self, + batch: RolloutBatchProtocol, + *args: Any, + **kwargs: Any, + ) -> TrainingStats: return TrainingStats() diff --git a/tianshou/policy/imitation/base.py b/tianshou/policy/imitation/base.py index 6e21016d9..5135295e7 100644 --- a/tianshou/policy/imitation/base.py +++ b/tianshou/policy/imitation/base.py @@ -13,7 +13,7 @@ ObsBatchProtocol, RolloutBatchProtocol, ) -from tianshou.policy import BasePolicy +from tianshou.policy import Algorithm from tianshou.policy.base import TLearningRateScheduler, TrainingStats # Dimension Naming Convention @@ -31,7 +31,7 @@ class ImitationTrainingStats(TrainingStats): TImitationTrainingStats = TypeVar("TImitationTrainingStats", bound=ImitationTrainingStats) -class ImitationPolicy(BasePolicy[TImitationTrainingStats], Generic[TImitationTrainingStats]): +class ImitationPolicy(Algorithm[TImitationTrainingStats], Generic[TImitationTrainingStats]): """Implementation of vanilla imitation learning. 
:param actor: a model following the rules in @@ -94,7 +94,7 @@ def forward( raise RuntimeError(f"Unknown {self.action_type=}, this shouldn't have happened!") return cast(ModelOutputBatchProtocol, result) - def learn( + def _update_with_batch( self, batch: RolloutBatchProtocol, *ags: Any, diff --git a/tianshou/policy/imitation/bcq.py b/tianshou/policy/imitation/bcq.py index 991c4aace..c94c30d5b 100644 --- a/tianshou/policy/imitation/bcq.py +++ b/tianshou/policy/imitation/bcq.py @@ -10,7 +10,7 @@ from tianshou.data import Batch, to_torch from tianshou.data.batch import BatchProtocol from tianshou.data.types import ActBatchProtocol, ObsBatchProtocol, RolloutBatchProtocol -from tianshou.policy import BasePolicy +from tianshou.policy import Algorithm from tianshou.policy.base import TLearningRateScheduler, TrainingStats from tianshou.utils.net.continuous import VAE from tianshou.utils.optim import clone_optimizer @@ -27,7 +27,7 @@ class BCQTrainingStats(TrainingStats): TBCQTrainingStats = TypeVar("TBCQTrainingStats", bound=BCQTrainingStats) -class BCQPolicy(BasePolicy[TBCQTrainingStats], Generic[TBCQTrainingStats]): +class BCQPolicy(Algorithm[TBCQTrainingStats], Generic[TBCQTrainingStats]): """Implementation of BCQ algorithm. arXiv:1812.02900. :param actor_perturbation: the actor perturbation. `(s, a -> perturbed a)` @@ -154,7 +154,12 @@ def sync_weight(self) -> None: self.soft_update(self.critic2_target, self.critic2, self.tau) self.soft_update(self.actor_perturbation_target, self.actor_perturbation, self.tau) - def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> TBCQTrainingStats: + def _update_with_batch( + self, + batch: RolloutBatchProtocol, + *args: Any, + **kwargs: Any, + ) -> TBCQTrainingStats: # batch: obs, act, rew, done, obs_next. 
(numpy array) # (batch_size, state_dim) batch: Batch = to_torch(batch, dtype=torch.float, device=self.device) diff --git a/tianshou/policy/imitation/cql.py b/tianshou/policy/imitation/cql.py index 66438c758..35e3fe51e 100644 --- a/tianshou/policy/imitation/cql.py +++ b/tianshou/policy/imitation/cql.py @@ -247,7 +247,7 @@ def process_fn( # Should probably be fixed! return batch - def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> TCQLTrainingStats: # type: ignore + def _update_with_batch(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> TCQLTrainingStats: # type: ignore batch: Batch = to_torch(batch, dtype=torch.float, device=self.device) obs, act, rew, obs_next = batch.obs, batch.act, batch.rew, batch.obs_next batch_size = obs.shape[0] diff --git a/tianshou/policy/imitation/discrete_bcq.py b/tianshou/policy/imitation/discrete_bcq.py index b5258c141..86a059d20 100644 --- a/tianshou/policy/imitation/discrete_bcq.py +++ b/tianshou/policy/imitation/discrete_bcq.py @@ -147,7 +147,7 @@ def forward( # type: ignore result = Batch(act=act, state=state, q_value=q_value, imitation_logits=imitation_logits) return cast(ImitationBatchProtocol, result) - def learn( + def _update_with_batch( self, batch: RolloutBatchProtocol, *args: Any, diff --git a/tianshou/policy/imitation/discrete_cql.py b/tianshou/policy/imitation/discrete_cql.py index b63f83e11..640bff0d2 100644 --- a/tianshou/policy/imitation/discrete_cql.py +++ b/tianshou/policy/imitation/discrete_cql.py @@ -82,7 +82,7 @@ def __init__( ) self.min_q_weight = min_q_weight - def learn( + def _update_with_batch( self, batch: RolloutBatchProtocol, *args: Any, diff --git a/tianshou/policy/imitation/discrete_crr.py b/tianshou/policy/imitation/discrete_crr.py index 9c54129da..5c7395ff8 100644 --- a/tianshou/policy/imitation/discrete_crr.py +++ b/tianshou/policy/imitation/discrete_crr.py @@ -10,7 +10,7 @@ from tianshou.data import to_torch, to_torch_as from tianshou.data.types import 
RolloutBatchProtocol from tianshou.policy.base import TLearningRateScheduler -from tianshou.policy.modelfree.pg import PGPolicy, PGTrainingStats +from tianshou.policy.modelfree.pg import PGTrainingStats, Reinforce from tianshou.utils.net.discrete import Actor, Critic @@ -24,7 +24,7 @@ class DiscreteCRRTrainingStats(PGTrainingStats): TDiscreteCRRTrainingStats = TypeVar("TDiscreteCRRTrainingStats", bound=DiscreteCRRTrainingStats) -class DiscreteCRRPolicy(PGPolicy[TDiscreteCRRTrainingStats]): +class DiscreteCRRPolicy(Reinforce[TDiscreteCRRTrainingStats]): r"""Implementation of discrete Critic Regularized Regression. arXiv:2006.15134. :param actor: the actor network following the rules: @@ -104,7 +104,7 @@ def sync_weight(self) -> None: self.actor_old.load_state_dict(self.actor.state_dict()) self.critic_old.load_state_dict(self.critic.state_dict()) - def learn( # type: ignore + def _update_with_batch( # type: ignore self, batch: RolloutBatchProtocol, *args: Any, diff --git a/tianshou/policy/imitation/gail.py b/tianshou/policy/imitation/gail.py index 524f04001..9ffd6a6b7 100644 --- a/tianshou/policy/imitation/gail.py +++ b/tianshou/policy/imitation/gail.py @@ -158,7 +158,7 @@ def disc(self, batch: RolloutBatchProtocol) -> torch.Tensor: act = to_torch(batch.act, device=self.disc_net.device) return self.disc_net(torch.cat([obs, act], dim=1)) - def learn( # type: ignore + def _update_with_batch( # type: ignore self, batch: RolloutBatchProtocol, batch_size: int | None, diff --git a/tianshou/policy/imitation/td3_bc.py b/tianshou/policy/imitation/td3_bc.py index f4b2bfe91..f11b88a5a 100644 --- a/tianshou/policy/imitation/td3_bc.py +++ b/tianshou/policy/imitation/td3_bc.py @@ -104,7 +104,7 @@ def __init__( ) self.alpha = alpha - def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> TTD3BCTrainingStats: # type: ignore + def _update_with_batch(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> TTD3BCTrainingStats: # type: ignore # critic 1&2 
td1, critic1_loss = self._mse_optimizer(batch, self.critic, self.critic_optim) td2, critic2_loss = self._mse_optimizer(batch, self.critic2, self.critic2_optim) diff --git a/tianshou/policy/modelbased/icm.py b/tianshou/policy/modelbased/icm.py index 9a603b7de..1bd13bb1b 100644 --- a/tianshou/policy/modelbased/icm.py +++ b/tianshou/policy/modelbased/icm.py @@ -8,7 +8,7 @@ from tianshou.data import Batch, ReplayBuffer, to_numpy, to_torch from tianshou.data.batch import BatchProtocol from tianshou.data.types import ActBatchProtocol, ObsBatchProtocol, RolloutBatchProtocol -from tianshou.policy import BasePolicy +from tianshou.policy import Algorithm from tianshou.policy.base import ( TLearningRateScheduler, TrainingStats, @@ -33,7 +33,7 @@ def __init__( super().__init__(wrapped_stats) -class ICMPolicy(BasePolicy[ICMTrainingStats]): +class ICMPolicy(Algorithm[ICMTrainingStats]): """Implementation of Intrinsic Curiosity Module. arXiv:1705.05363. :param policy: a base policy to add ICM to. @@ -57,7 +57,7 @@ class ICMPolicy(BasePolicy[ICMTrainingStats]): def __init__( self, *, - policy: BasePolicy[TTrainingStats], + policy: Algorithm[TTrainingStats], model: IntrinsicCuriosityModule, optim: torch.optim.Optimizer, lr_scale: float, @@ -150,13 +150,13 @@ def post_process_fn( self.policy.post_process_fn(batch, buffer, indices) batch.rew = batch.policy.orig_rew # restore original reward - def learn( + def _update_with_batch( self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any, ) -> ICMTrainingStats: - training_stat = self.policy.learn(batch, **kwargs) + training_stat = self.policy._update_with_batch(batch, **kwargs) self.optim.zero_grad() act_hat = batch.policy.act_hat act = to_torch(batch.act, dtype=torch.long, device=act_hat.device) diff --git a/tianshou/policy/modelbased/psrl.py b/tianshou/policy/modelbased/psrl.py index 8c1374709..66711c74f 100644 --- a/tianshou/policy/modelbased/psrl.py +++ b/tianshou/policy/modelbased/psrl.py @@ -8,7 +8,7 @@ from tianshou.data 
import Batch from tianshou.data.batch import BatchProtocol from tianshou.data.types import ActBatchProtocol, ObsBatchProtocol, RolloutBatchProtocol -from tianshou.policy import BasePolicy +from tianshou.policy import Algorithm from tianshou.policy.base import TLearningRateScheduler, TrainingStats @@ -150,7 +150,7 @@ def __call__( return self.policy[obs] -class PSRLPolicy(BasePolicy[TPSRLTrainingStats]): +class PSRLPolicy(Algorithm[TPSRLTrainingStats]): """Implementation of Posterior Sampling Reinforcement Learning. Reference: Strens M. A Bayesian framework for reinforcement learning [C] @@ -227,7 +227,12 @@ def forward( act = self.model(batch.obs, state=state, info=batch.info) return cast(ActBatchProtocol, Batch(act=act)) - def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> TPSRLTrainingStats: + def _update_with_batch( + self, + batch: RolloutBatchProtocol, + *args: Any, + **kwargs: Any, + ) -> TPSRLTrainingStats: n_s, n_a = self.model.n_state, self.model.n_action trans_count = np.zeros((n_s, n_a, n_s)) rew_sum = np.zeros((n_s, n_a)) diff --git a/tianshou/policy/modelfree/a2c.py b/tianshou/policy/modelfree/a2c.py index d41ccb463..d0eb10b4c 100644 --- a/tianshou/policy/modelfree/a2c.py +++ b/tianshou/policy/modelfree/a2c.py @@ -9,7 +9,7 @@ from tianshou.data import ReplayBuffer, SequenceSummaryStats, to_torch_as from tianshou.data.types import BatchWithAdvantagesProtocol, RolloutBatchProtocol -from tianshou.policy import PGPolicy +from tianshou.policy import Reinforce from tianshou.policy.base import TLearningRateScheduler, TrainingStats from tianshou.policy.modelfree.pg import TDistFnDiscrOrCont from tianshou.utils.net.common import ActorCritic @@ -30,7 +30,7 @@ class A2CTrainingStats(TrainingStats): # TODO: the type ignore here is needed b/c the hierarchy is actually broken! Should reconsider the inheritance structure. 
-class A2CPolicy(PGPolicy[TA2CTrainingStats], Generic[TA2CTrainingStats]): # type: ignore[type-var] +class A2CPolicy(Reinforce[TA2CTrainingStats], Generic[TA2CTrainingStats]): # type: ignore[type-var] """Implementation of Synchronous Advantage Actor-Critic. arXiv:1602.01783. :param actor: the actor network following the rules: @@ -157,7 +157,7 @@ def _compute_returns( # TODO: mypy complains b/c signature is different from superclass, although # it's compatible. Can this be fixed? - def learn( # type: ignore + def _update_with_batch( # type: ignore self, batch: RolloutBatchProtocol, batch_size: int | None, diff --git a/tianshou/policy/modelfree/bdq.py b/tianshou/policy/modelfree/bdq.py index d7196a92b..3529d159c 100644 --- a/tianshou/policy/modelfree/bdq.py +++ b/tianshou/policy/modelfree/bdq.py @@ -162,7 +162,12 @@ def forward( result = Batch(logits=action_values_BA, act=act_B, state=hidden_BH) return cast(ModelOutputBatchProtocol, result) - def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> TBDQNTrainingStats: + def _update_with_batch( + self, + batch: RolloutBatchProtocol, + *args: Any, + **kwargs: Any, + ) -> TBDQNTrainingStats: if self._target and self._iter % self.freq == 0: self.sync_weight() self.optim.zero_grad() diff --git a/tianshou/policy/modelfree/c51.py b/tianshou/policy/modelfree/c51.py index 5bfdba0c1..49758a900 100644 --- a/tianshou/policy/modelfree/c51.py +++ b/tianshou/policy/modelfree/c51.py @@ -116,7 +116,12 @@ def _target_dist(self, batch: RolloutBatchProtocol) -> torch.Tensor: ).clamp(0, 1) * next_dist.unsqueeze(1) return target_dist.sum(-1) - def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> TC51TrainingStats: + def _update_with_batch( + self, + batch: RolloutBatchProtocol, + *args: Any, + **kwargs: Any, + ) -> TC51TrainingStats: if self._target and self._iter % self.freq == 0: self.sync_weight() self.optim.zero_grad() diff --git a/tianshou/policy/modelfree/ddpg.py 
b/tianshou/policy/modelfree/ddpg.py index f21744f72..5d66291b7 100644 --- a/tianshou/policy/modelfree/ddpg.py +++ b/tianshou/policy/modelfree/ddpg.py @@ -17,7 +17,7 @@ RolloutBatchProtocol, ) from tianshou.exploration import BaseNoise, GaussianNoise -from tianshou.policy import BasePolicy +from tianshou.policy import Algorithm from tianshou.policy.base import TLearningRateScheduler, TrainingStats from tianshou.utils.net.continuous import Actor, Critic @@ -31,7 +31,7 @@ class DDPGTrainingStats(TrainingStats): TDDPGTrainingStats = TypeVar("TDDPGTrainingStats", bound=DDPGTrainingStats) -class DDPGPolicy(BasePolicy[TDDPGTrainingStats], Generic[TDDPGTrainingStats]): +class DDPGPolicy(Algorithm[TDDPGTrainingStats], Generic[TDDPGTrainingStats]): """Implementation of Deep Deterministic Policy Gradient. arXiv:1509.02971. :param actor: The actor network following the rules (s -> actions) @@ -196,7 +196,7 @@ def _mse_optimizer( optimizer.step() return td, critic_loss - def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> TDDPGTrainingStats: # type: ignore + def _update_with_batch(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> TDDPGTrainingStats: # type: ignore # critic td, critic_loss = self._mse_optimizer(batch, self.critic, self.critic_optim) batch.weight = td # prio-buffer diff --git a/tianshou/policy/modelfree/discrete_sac.py b/tianshou/policy/modelfree/discrete_sac.py index d1ce28da9..21575dcf0 100644 --- a/tianshou/policy/modelfree/discrete_sac.py +++ b/tianshou/policy/modelfree/discrete_sac.py @@ -127,7 +127,7 @@ def _target_q(self, buffer: ReplayBuffer, indices: np.ndarray) -> torch.Tensor: ) return target_q.sum(dim=-1) + self.alpha * dist.entropy() - def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> TDiscreteSACTrainingStats: # type: ignore + def _update_with_batch(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> TDiscreteSACTrainingStats: # type: ignore weight = batch.pop("weight", 1.0) 
target_q = batch.returns.flatten() act = to_torch(batch.act[:, np.newaxis], device=target_q.device, dtype=torch.long) diff --git a/tianshou/policy/modelfree/dqn.py b/tianshou/policy/modelfree/dqn.py index e0ada0733..47b97c795 100644 --- a/tianshou/policy/modelfree/dqn.py +++ b/tianshou/policy/modelfree/dqn.py @@ -209,7 +209,12 @@ def forward( result = Batch(logits=action_values_BA, act=act_B, state=hidden_BH) return cast(ModelOutputBatchProtocol, result) - def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> TDQNTrainingStats: + def _update_with_batch( + self, + batch: RolloutBatchProtocol, + *args: Any, + **kwargs: Any, + ) -> TDQNTrainingStats: if self._target and self._iter % self.freq == 0: self.sync_weight() self.optim.zero_grad() diff --git a/tianshou/policy/modelfree/fqf.py b/tianshou/policy/modelfree/fqf.py index 9c87f9cac..99a64953a 100644 --- a/tianshou/policy/modelfree/fqf.py +++ b/tianshou/policy/modelfree/fqf.py @@ -152,7 +152,12 @@ def forward( # type: ignore ) return cast(FQFBatchProtocol, result) - def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> TFQFTrainingStats: + def _update_with_batch( + self, + batch: RolloutBatchProtocol, + *args: Any, + **kwargs: Any, + ) -> TFQFTrainingStats: if self._target and self._iter % self.freq == 0: self.sync_weight() weight = batch.pop("weight", 1.0) diff --git a/tianshou/policy/modelfree/iqn.py b/tianshou/policy/modelfree/iqn.py index 75d76a2dd..1ecfae21f 100644 --- a/tianshou/policy/modelfree/iqn.py +++ b/tianshou/policy/modelfree/iqn.py @@ -130,7 +130,12 @@ def forward( result = Batch(logits=logits, act=act, state=hidden, taus=taus) return cast(QuantileRegressionBatchProtocol, result) - def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> TIQNTrainingStats: + def _update_with_batch( + self, + batch: RolloutBatchProtocol, + *args: Any, + **kwargs: Any, + ) -> TIQNTrainingStats: if self._target and self._iter % self.freq == 0: self.sync_weight() 
self.optim.zero_grad() diff --git a/tianshou/policy/modelfree/npg.py b/tianshou/policy/modelfree/npg.py index 9e04d3feb..ce07e088e 100644 --- a/tianshou/policy/modelfree/npg.py +++ b/tianshou/policy/modelfree/npg.py @@ -123,7 +123,7 @@ def process_fn( batch.adv = (batch.adv - batch.adv.mean()) / batch.adv.std() return batch - def learn( # type: ignore + def _update_with_batch( # type: ignore self, batch: Batch, batch_size: int | None, diff --git a/tianshou/policy/modelfree/pg.py b/tianshou/policy/modelfree/pg.py index 80bcff672..cde4fe545 100644 --- a/tianshou/policy/modelfree/pg.py +++ b/tianshou/policy/modelfree/pg.py @@ -21,8 +21,8 @@ ObsBatchProtocol, RolloutBatchProtocol, ) -from tianshou.policy import BasePolicy -from tianshou.policy.base import TLearningRateScheduler, TrainingStats +from tianshou.policy import Algorithm +from tianshou.policy.base import Policy, TLearningRateScheduler, TrainingStats from tianshou.utils import RunningMeanStd from tianshou.utils.net.continuous import ActorProb from tianshou.utils.net.discrete import Actor @@ -50,8 +50,75 @@ class PGTrainingStats(TrainingStats): TPGTrainingStats = TypeVar("TPGTrainingStats", bound=PGTrainingStats) -class PGPolicy(BasePolicy[TPGTrainingStats], Generic[TPGTrainingStats]): - """Implementation of REINFORCE algorithm. +class ActorPolicy(Policy): + def __init__( + self, + *, + actor: torch.nn.Module | ActorProb | Actor, + dist_fn: TDistFnDiscrOrCont, + action_space: gym.Space, + deterministic_eval: bool = False, + observation_space: gym.Space | None = None, + # TODO: why change the default from the base? 
+ action_scaling: bool = True, + action_bound_method: Literal["clip", "tanh"] | None = "clip", + ) -> None: + super().__init__( + action_space=action_space, + observation_space=observation_space, + action_scaling=action_scaling, + action_bound_method=action_bound_method, + ) + if action_scaling and not np.isclose(actor.max_action, 1.0): + warnings.warn( + "action_scaling and action_bound_method are only intended" + "to deal with unbounded model action space, but find actor model" + f"bound action space with max_action={actor.max_action}." + "Consider using unbounded=True option of the actor model," + "or set action_scaling to False and action_bound_method to None.", + ) + self.actor = actor + self.dist_fn = dist_fn + self._eps = 1e-8 + self.deterministic_eval = deterministic_eval + + def forward( + self, + batch: ObsBatchProtocol, + state: dict | BatchProtocol | np.ndarray | None = None, + **kwargs: Any, + ) -> DistBatchProtocol: + """Compute action over the given batch data by applying the actor. + + Will sample from the dist_fn, if appropriate. + Returns a new object representing the processed batch data + (contrary to other methods that modify the input batch inplace). + + .. seealso:: + + Please refer to :meth:`~tianshou.policy.BasePolicy.forward` for + more detailed explanation. 
+ """ + # TODO - ALGO: marked for algorithm refactoring + action_dist_input_BD, hidden_BH = self.actor(batch.obs, state=state, info=batch.info) + # in the case that self.action_type == "discrete", the dist should always be Categorical, and D=A + # therefore action_dist_input_BD is equivalent to logits_BA + # If discrete, dist_fn will typically map loc, scale to a distribution (usually a Gaussian) + # the action_dist_input_BD in that case is a tuple of loc_B, scale_B and needs to be unpacked + dist = self.dist_fn(action_dist_input_BD) + + act_B = ( + dist.mode + if self.deterministic_eval and not self.is_within_training_step + else dist.sample() + ) + # act is of dimension BA in continuous case and of dimension B in discrete + result = Batch(logits=action_dist_input_BD, act=act_B, state=hidden_BH, dist=dist) + return cast(DistBatchProtocol, result) + + +class Reinforce(Algorithm[ActorPolicy, TPGTrainingStats], Generic[TPGTrainingStats]): + """Implementation of the REINFORCE (a.k.a. vanilla policy gradient) algorithm. :param actor: the actor network following the rules: If `self.action_type == "discrete"`: (`s_B` ->`action_values_BA`). @@ -85,44 +152,23 @@ class PGPolicy(BasePolicy[TPGTrainingStats], Generic[TPGTrainingStats]): def __init__( self, *, - actor: torch.nn.Module | ActorProb | Actor, - optim: torch.optim.Optimizer, - dist_fn: TDistFnDiscrOrCont, - action_space: gym.Space, + policy: ActorPolicy, discount_factor: float = 0.99, # TODO: rename to return_normalization? reward_normalization: bool = False, - deterministic_eval: bool = False, - observation_space: gym.Space | None = None, - # TODO: why change the default from the base? 
- action_scaling: bool = True, - action_bound_method: Literal["clip", "tanh"] | None = "clip", + optim: torch.optim.Optimizer, lr_scheduler: TLearningRateScheduler | None = None, ) -> None: super().__init__( - action_space=action_space, - observation_space=observation_space, - action_scaling=action_scaling, - action_bound_method=action_bound_method, + policy=policy, lr_scheduler=lr_scheduler, ) - if action_scaling and not np.isclose(actor.max_action, 1.0): - warnings.warn( - "action_scaling and action_bound_method are only intended" - "to deal with unbounded model action space, but find actor model" - f"bound action space with max_action={actor.max_action}." - "Consider using unbounded=True option of the actor model," - "or set action_scaling to False and action_bound_method to None.", - ) - self.actor = actor self.optim = optim - self.dist_fn = dist_fn + self.ret_rms = RunningMeanStd() + self._eps = 1e-8 assert 0.0 <= discount_factor <= 1.0, "discount factor should be in [0, 1]" self.gamma = discount_factor self.rew_norm = reward_normalization - self.ret_rms = RunningMeanStd() - self._eps = 1e-8 - self.deterministic_eval = deterministic_eval def process_fn( self, @@ -172,42 +218,8 @@ def process_fn( batch: BatchWithReturnsProtocol return batch - def forward( - self, - batch: ObsBatchProtocol, - state: dict | BatchProtocol | np.ndarray | None = None, - **kwargs: Any, - ) -> DistBatchProtocol: - """Compute action over the given batch data by applying the actor. - - Will sample from the dist_fn, if appropriate. - Returns a new object representing the processed batch data - (contrary to other methods that modify the input batch inplace). - - .. seealso:: - - Please refer to :meth:`~tianshou.policy.BasePolicy.forward` for - more detailed explanation. 
- """ - # TODO - ALGO: marked for algorithm refactoring - action_dist_input_BD, hidden_BH = self.actor(batch.obs, state=state, info=batch.info) - # in the case that self.action_type == "discrete", the dist should always be Categorical, and D=A - # therefore action_dist_input_BD is equivalent to logits_BA - # If discrete, dist_fn will typically map loc, scale to a distribution (usually a Gaussian) - # the action_dist_input_BD in that case is a tuple of loc_B, scale_B and needs to be unpacked - dist = self.dist_fn(action_dist_input_BD) - - act_B = ( - dist.mode - if self.deterministic_eval and not self.is_within_training_step - else dist.sample() - ) - # act is of dimension BA in continuous case and of dimension B in discrete - result = Batch(logits=action_dist_input_BD, act=act_B, state=hidden_BH, dist=dist) - return cast(DistBatchProtocol, result) - # TODO: why does mypy complain? - def learn( # type: ignore + def _update_with_batch( # type: ignore self, batch: BatchWithReturnsProtocol, batch_size: int | None, @@ -220,7 +232,7 @@ def learn( # type: ignore for _ in range(repeat): for minibatch in batch.split(split_batch_size, merge_last=True): self.optim.zero_grad() - result = self(minibatch) + result = self.policy(minibatch) dist = result.dist act = to_torch_as(minibatch.act, result.act) ret = to_torch(minibatch.returns, torch.float, result.act.device) diff --git a/tianshou/policy/modelfree/ppo.py b/tianshou/policy/modelfree/ppo.py index a4694b57b..c1b29fbf4 100644 --- a/tianshou/policy/modelfree/ppo.py +++ b/tianshou/policy/modelfree/ppo.py @@ -166,7 +166,7 @@ def process_fn( return batch # TODO: why does mypy complain? 
- def learn( # type: ignore + def _update_with_batch( # type: ignore self, batch: RolloutBatchProtocol, batch_size: int | None, diff --git a/tianshou/policy/modelfree/qrdqn.py b/tianshou/policy/modelfree/qrdqn.py index 71c36de0c..bec8c7cce 100644 --- a/tianshou/policy/modelfree/qrdqn.py +++ b/tianshou/policy/modelfree/qrdqn.py @@ -104,7 +104,12 @@ def _target_q(self, buffer: ReplayBuffer, indices: np.ndarray) -> torch.Tensor: def compute_q_value(self, logits: torch.Tensor, mask: np.ndarray | None) -> torch.Tensor: return super().compute_q_value(logits.mean(2), mask) - def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> TQRDQNTrainingStats: + def _update_with_batch( + self, + batch: RolloutBatchProtocol, + *args: Any, + **kwargs: Any, + ) -> TQRDQNTrainingStats: if self._target and self._iter % self.freq == 0: self.sync_weight() self.optim.zero_grad() diff --git a/tianshou/policy/modelfree/rainbow.py b/tianshou/policy/modelfree/rainbow.py index fad567cd2..fc5b6637f 100644 --- a/tianshou/policy/modelfree/rainbow.py +++ b/tianshou/policy/modelfree/rainbow.py @@ -47,7 +47,7 @@ class RainbowPolicy(C51Policy[TRainbowTrainingStats]): explanation. 
""" - def learn( + def _update_with_batch( self, batch: RolloutBatchProtocol, *args: Any, @@ -56,4 +56,4 @@ def learn( _sample_noise(self.model) if self._target and _sample_noise(self.model_old): self.model_old.train() # so that NoisyLinear takes effect - return super().learn(batch, **kwargs) + return super()._update_with_batch(batch, **kwargs) diff --git a/tianshou/policy/modelfree/redq.py b/tianshou/policy/modelfree/redq.py index 25f299733..dcfa1c39f 100644 --- a/tianshou/policy/modelfree/redq.py +++ b/tianshou/policy/modelfree/redq.py @@ -192,7 +192,7 @@ def _target_q(self, buffer: ReplayBuffer, indices: np.ndarray) -> torch.Tensor: return target_q - def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> TREDQTrainingStats: # type: ignore + def _update_with_batch(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> TREDQTrainingStats: # type: ignore # critic ensemble weight = getattr(batch, "weight", 1.0) current_qs = self.critic(batch.obs, batch.act).flatten(1) diff --git a/tianshou/policy/modelfree/sac.py b/tianshou/policy/modelfree/sac.py index c1f19eff7..8ff349ea9 100644 --- a/tianshou/policy/modelfree/sac.py +++ b/tianshou/policy/modelfree/sac.py @@ -223,7 +223,7 @@ def _target_q(self, buffer: ReplayBuffer, indices: np.ndarray) -> torch.Tensor: - self.alpha * obs_next_result.log_prob ) - def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> TSACTrainingStats: # type: ignore + def _update_with_batch(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> TSACTrainingStats: # type: ignore # critic 1&2 td1, critic1_loss = self._mse_optimizer(batch, self.critic, self.critic_optim) td2, critic2_loss = self._mse_optimizer(batch, self.critic2, self.critic2_optim) diff --git a/tianshou/policy/modelfree/td3.py b/tianshou/policy/modelfree/td3.py index 8c2ae8c98..e2560f9be 100644 --- a/tianshou/policy/modelfree/td3.py +++ b/tianshou/policy/modelfree/td3.py @@ -140,7 +140,7 @@ def _target_q(self, buffer: 
ReplayBuffer, indices: np.ndarray) -> torch.Tensor: self.critic2_old(obs_next_batch.obs, act_), ) - def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> TTD3TrainingStats: # type: ignore + def _update_with_batch(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> TTD3TrainingStats: # type: ignore # critic 1&2 td1, critic1_loss = self._mse_optimizer(batch, self.critic, self.critic_optim) td2, critic2_loss = self._mse_optimizer(batch, self.critic2, self.critic2_optim) diff --git a/tianshou/policy/modelfree/trpo.py b/tianshou/policy/modelfree/trpo.py index e7aa5cfd5..c8530e258 100644 --- a/tianshou/policy/modelfree/trpo.py +++ b/tianshou/policy/modelfree/trpo.py @@ -103,7 +103,7 @@ def __init__( self.max_kl = max_kl self.backtrack_coeff = backtrack_coeff - def learn( # type: ignore + def _update_with_batch( # type: ignore self, batch: Batch, batch_size: int | None, diff --git a/tianshou/policy/multiagent/mapolicy.py b/tianshou/policy/multiagent/mapolicy.py index 05cc8db8f..083ace04d 100644 --- a/tianshou/policy/multiagent/mapolicy.py +++ b/tianshou/policy/multiagent/mapolicy.py @@ -6,7 +6,7 @@ from tianshou.data import Batch, ReplayBuffer from tianshou.data.batch import BatchProtocol, IndexType from tianshou.data.types import ActBatchProtocol, ObsBatchProtocol, RolloutBatchProtocol -from tianshou.policy import BasePolicy +from tianshou.policy import Algorithm from tianshou.policy.base import TLearningRateScheduler, TrainingStats try: @@ -63,7 +63,7 @@ def __getitem__(self, index: str | IndexType) -> Any: ... -class MultiAgentPolicyManager(BasePolicy): +class MultiAgentPolicyManager(Algorithm): """Multi-agent policy manager for MARL. This multi-agent policy manager accepts a list of @@ -84,7 +84,7 @@ class MultiAgentPolicyManager(BasePolicy): def __init__( self, *, - policies: list[BasePolicy], + policies: list[Algorithm], # TODO: 1 why restrict to PettingZooEnv? 
# TODO: 2 This is the only policy that takes an env in init, is it really needed? env: PettingZooEnv, @@ -107,11 +107,11 @@ def __init__( # (this MultiAgentPolicyManager) policy.set_agent_id(env.agents[i]) - self.policies: dict[str | int, BasePolicy] = dict(zip(env.agents, policies, strict=True)) + self.policies: dict[str | int, Algorithm] = dict(zip(env.agents, policies, strict=True)) """Maps agent_id to policy.""" # TODO: unused - remove it? - def replace_policy(self, policy: BasePolicy, agent_id: int) -> None: + def replace_policy(self, policy: Algorithm, agent_id: int) -> None: """Replace the "agent_id"th policy in this manager.""" policy.set_agent_id(agent_id) self.policies[agent_id] = policy @@ -259,7 +259,7 @@ def forward( # type: ignore return holder # Violates Liskov substitution principle - def learn( # type: ignore + def _update_with_batch( # type: ignore self, batch: MAPRolloutBatchProtocol, *args: Any, @@ -273,7 +273,7 @@ def learn( # type: ignore for agent_id, policy in self.policies.items(): data = batch[agent_id] if len(data.get_keys()) != 0: - train_stats = policy.learn(batch=data, **kwargs) + train_stats = policy._update_with_batch(batch=data, **kwargs) agent_id_to_stats[agent_id] = train_stats return MapTrainingStats(agent_id_to_stats) diff --git a/tianshou/policy/random.py b/tianshou/policy/random.py index bf665bc2b..0f596782e 100644 --- a/tianshou/policy/random.py +++ b/tianshou/policy/random.py @@ -5,7 +5,7 @@ from tianshou.data import Batch from tianshou.data.batch import BatchProtocol from tianshou.data.types import ActBatchProtocol, ObsBatchProtocol, RolloutBatchProtocol -from tianshou.policy import BasePolicy +from tianshou.policy import Algorithm from tianshou.policy.base import TrainingStats @@ -16,7 +16,7 @@ class MARLRandomTrainingStats(TrainingStats): TMARLRandomTrainingStats = TypeVar("TMARLRandomTrainingStats", bound=MARLRandomTrainingStats) -class MARLRandomPolicy(BasePolicy[TMARLRandomTrainingStats]): +class 
MARLRandomPolicy(Algorithm[TMARLRandomTrainingStats]): """A random agent used in multi-agent learning. It randomly chooses an action from the legal action. @@ -49,6 +49,6 @@ def forward( result = Batch(act=logits.argmax(axis=-1)) return cast(ActBatchProtocol, result) - def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> TMARLRandomTrainingStats: # type: ignore + def _update_with_batch(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> TMARLRandomTrainingStats: # type: ignore """Since a random agent learns nothing, it returns an empty dict.""" return MARLRandomTrainingStats() # type: ignore[return-value] diff --git a/tianshou/trainer/base.py b/tianshou/trainer/base.py index 355c9d33b..67be6407c 100644 --- a/tianshou/trainer/base.py +++ b/tianshou/trainer/base.py @@ -19,7 +19,7 @@ ) from tianshou.data.buffer.base import MalformedBufferError from tianshou.data.collector import BaseCollector, CollectStatsBase -from tianshou.policy import BasePolicy +from tianshou.policy import Algorithm from tianshou.policy.base import TrainingStats from tianshou.trainer.utils import gather_info, test_episode from tianshou.utils import ( @@ -151,7 +151,7 @@ def gen_doc(learning_type: str) -> str: def __init__( self, - policy: BasePolicy, + policy: Algorithm, max_epoch: int, batch_size: int | None, train_collector: BaseCollector | None = None, @@ -167,7 +167,7 @@ def __init__( test_fn: Callable[[int, int | None], None] | None = None, stop_fn: Callable[[float], bool] | None = None, compute_score_fn: Callable[[CollectStats], float] | None = None, - save_best_fn: Callable[[BasePolicy], None] | None = None, + save_best_fn: Callable[[Algorithm], None] | None = None, save_checkpoint_fn: Callable[[int, int, int], str] | None = None, resume_from_log: bool = False, reward_metric: Callable[[np.ndarray], np.ndarray] | None = None, @@ -278,7 +278,7 @@ def reset(self, reset_collectors: bool = True, reset_buffer: bool = False) -> No 
self._reset_collectors(reset_buffer=reset_buffer) if self.train_collector is not None and ( - self.train_collector.policy != self.policy or self.test_collector is None + self.train_collector.algorithm != self.policy or self.test_collector is None ): self.test_in_train = False diff --git a/tianshou/utils/torch_utils.py b/tianshou/utils/torch_utils.py index 1ffb9fcd8..9ee90167c 100644 --- a/tianshou/utils/torch_utils.py +++ b/tianshou/utils/torch_utils.py @@ -8,7 +8,7 @@ from torch import nn if TYPE_CHECKING: - from tianshou.policy import BasePolicy + from tianshou.policy.base import Policy @contextmanager @@ -23,7 +23,7 @@ def torch_train_mode(module: nn.Module, enabled: bool = True) -> Iterator[None]: @contextmanager -def policy_within_training_step(policy: "BasePolicy", enabled: bool = True) -> Iterator[None]: +def policy_within_training_step(policy: Policy, enabled: bool = True) -> Iterator[None]: """Temporarily switch to `policy.is_within_training_step=enabled`. Enabling this ensures that the policy is able to adapt its behavior, From 14890a69d187efb91807efd393980473335e1b6e Mon Sep 17 00:00:00 2001 From: Dominik Jain Date: Mon, 3 Mar 2025 20:02:48 +0100 Subject: [PATCH 002/230] v2: Refactoring of Trainer handling, establishing trainer configuration objects; Algorithms can now create and apply their trainer --- test/discrete/test_pg.py | 12 +- tianshou/policy/base.py | 31 +++- tianshou/policy/modelfree/pg.py | 10 +- tianshou/trainer/base.py | 266 +++++++++++++++++++++++--------- tianshou/utils/torch_utils.py | 2 +- 5 files changed, 235 insertions(+), 86 deletions(-) diff --git a/test/discrete/test_pg.py b/test/discrete/test_pg.py index 045661ba1..2fc2b4ed9 100644 --- a/test/discrete/test_pg.py +++ b/test/discrete/test_pg.py @@ -12,7 +12,7 @@ from tianshou.policy import Reinforce from tianshou.policy.base import Algorithm from tianshou.policy.modelfree.pg import ActorPolicy -from tianshou.trainer import OnpolicyTrainer +from tianshou.trainer.base import 
OnPolicyTrainingConfig from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Net from tianshou.utils.space_info import SpaceInfo @@ -111,9 +111,8 @@ def save_best_fn(algorithm: Algorithm) -> None: def stop_fn(mean_rewards: float) -> bool: return mean_rewards >= args.reward_threshold - # trainer - result = OnpolicyTrainer( - policy=algorithm, + # train + training_config = OnPolicyTrainingConfig( train_collector=train_collector, test_collector=test_collector, max_epoch=args.epoch, @@ -122,8 +121,11 @@ def stop_fn(mean_rewards: float) -> bool: episode_per_test=args.test_num, batch_size=args.batch_size, episode_per_collect=args.episode_per_collect, + step_per_collect=None, stop_fn=stop_fn, save_best_fn=save_best_fn, logger=logger, - ).run() + ) + result = algorithm.train(training_config) + assert stop_fn(result.best_reward) diff --git a/tianshou/policy/base.py b/tianshou/policy/base.py index 798b65ac7..3e5b11946 100644 --- a/tianshou/policy/base.py +++ b/tianshou/policy/base.py @@ -30,7 +30,11 @@ from tianshou.utils.torch_utils import policy_within_training_step, torch_train_mode if TYPE_CHECKING: - from tianshou.highlevel.config import SamplingConfig + from tianshou.trainer.base import ( + BaseTrainer, + OnpolicyTrainer, + OnPolicyTrainingConfig, + ) logger = logging.getLogger(__name__) @@ -344,9 +348,12 @@ def _compile() -> None: TPolicy = TypeVar("TPolicy", bound=Policy) +TTrainingConfig = TypeVar( + "TTrainingConfig", +) # TODO Can't use bound=TrainingConfig because of circular import -class Algorithm(Generic[TPolicy, TTrainingStats], ABC): +class Algorithm(Generic[TPolicy, TTrainingConfig, TTrainingStats], ABC): """ TODO fix docstring The base class for any RL policy. 
@@ -751,13 +758,27 @@ def compute_nstep_return( return cast(BatchWithReturnsProtocol, batch) - def _create_trainer(self): + @abstractmethod + def _create_trainer(self, config: TTrainingConfig) -> "BaseTrainer": pass - def train(self, sampling_config: "SamplingConfig"): - pass + def train(self, config: TTrainingConfig): + trainer = self._create_trainer(config) + return trainer.run() + + +class OnPolicyAlgorithm( + Algorithm[TPolicy, "OnPolicyTrainingConfig", TTrainingStats], + Generic[TPolicy, TTrainingStats], + ABC, +): + def _create_trainer(self, config: "OnPolicyTrainingConfig") -> "OnpolicyTrainer": + from tianshou.trainer.base import OnpolicyTrainer + + return OnpolicyTrainer(self, config) +# TODO must become Policy not Algorithm class RandomActionPolicy(Algorithm): def __init__( self, diff --git a/tianshou/policy/modelfree/pg.py b/tianshou/policy/modelfree/pg.py index cde4fe545..ad5efeddb 100644 --- a/tianshou/policy/modelfree/pg.py +++ b/tianshou/policy/modelfree/pg.py @@ -21,8 +21,12 @@ ObsBatchProtocol, RolloutBatchProtocol, ) -from tianshou.policy import Algorithm -from tianshou.policy.base import Policy, TLearningRateScheduler, TrainingStats +from tianshou.policy.base import ( + OnPolicyAlgorithm, + Policy, + TLearningRateScheduler, + TrainingStats, +) from tianshou.utils import RunningMeanStd from tianshou.utils.net.continuous import ActorProb from tianshou.utils.net.discrete import Actor @@ -117,7 +121,7 @@ def forward( return cast(DistBatchProtocol, result) -class Reinforce(Algorithm[ActorPolicy, TPGTrainingStats], Generic[TPGTrainingStats]): +class Reinforce(OnPolicyAlgorithm[ActorPolicy, TPGTrainingStats], Generic[TPGTrainingStats]): """Implementation of the REINFORCE (a.k.a. vanilla policy gradient) algorithm. 
:param actor: the actor network following the rules: diff --git a/tianshou/trainer/base.py b/tianshou/trainer/base.py index 67be6407c..7ccc41b6d 100644 --- a/tianshou/trainer/base.py +++ b/tianshou/trainer/base.py @@ -3,11 +3,14 @@ from abc import ABC, abstractmethod from collections import defaultdict, deque from collections.abc import Callable -from dataclasses import asdict +from dataclasses import asdict, dataclass from functools import partial +from typing import TYPE_CHECKING, Generic, TypeVar import numpy as np import tqdm +from sensai.util.helper import count_none +from sensai.util.string import ToStringMixin from tianshou.data import ( AsyncCollector, @@ -19,7 +22,6 @@ ) from tianshou.data.buffer.base import MalformedBufferError from tianshou.data.collector import BaseCollector, CollectStatsBase -from tianshou.policy import Algorithm from tianshou.policy.base import TrainingStats from tianshou.trainer.utils import gather_info, test_episode from tianshou.utils import ( @@ -30,10 +32,140 @@ from tianshou.utils.logging import set_numerical_fields_to_precision from tianshou.utils.torch_utils import policy_within_training_step +if TYPE_CHECKING: + from tianshou.policy import Algorithm + log = logging.getLogger(__name__) -class BaseTrainer(ABC): +@dataclass +class TrainingConfig(ToStringMixin): + max_epoch: int = 100 + """ + the number of epochs to run training for. An epoch is the outermost iteration level and each + epoch consists of a number of training steps and a test step, where each training step + + * collects environment steps/transitions (collection step), adding them to the (replay) + buffer (see :attr:`step_per_collect`) + * performs one or more gradient updates (see :attr:`update_per_step`), + + and the test step collects :attr:`num_episodes_per_test` test episodes in order to evaluate + agent performance. 
+ + The number of training steps in each epoch is indirectly determined by + :attr:`step_per_epoch`: As many training steps will be performed as are required in + order to reach :attr:`step_per_epoch` total steps in the training environments. + Specifically, if the number of transitions collected per step is `c` (see + :attr:`step_per_collect`) and :attr:`step_per_epoch` is set to `s`, then the number + of training steps per epoch is `ceil(s / c)`. + + Therefore, if `num_epochs = e`, the total number of environment steps taken during training + can be computed as `e * ceil(s / c) * c`. + """ + + step_per_epoch: int = 30000 + """ + the total number of environment steps to be made per epoch. See :attr:`num_epochs` for + an explanation of epoch semantics. + """ + + episode_per_test: int = 1 + """the total number of episodes to collect in each test step (across all test environments). + """ + + buffer_size: int = 4096 + """the total size of the sample/replay buffer, in which environment steps (transitions) are + stored""" + + step_per_collect: int | None = 2048 + """ + the number of environment steps/transitions to collect in each collection step before the + network update within each training step. + + This is mutually exclusive with :attr:`episode_per_collect`, and one of the two must be set. + + Note that the exact number can be reached only if this is a multiple of the number of + training environments being used, as each training environment will produce the same + (non-zero) number of transitions. + Specifically, if this is set to `n` and `m` training environments are used, then the total + number of transitions collected per collection step is `ceil(n / m) * m =: c`. + + See :attr:`num_epochs` for information on the total number of environment steps being + collected during training. + """ + + episode_per_collect: int | None = None + """ + the number of episodes to collect in each collection step before the network update within + each training step. 
If this is set, the number of environment steps collected in each + collection step is the sum of the lengths of the episodes collected. + + This is mutually exclusive with :attr:`step_per_collect`, and one of the two must be set. + """ + + # TODO copy docstrings from BaseTrainer + train_collector: BaseCollector | None = None + test_collector: BaseCollector | None = None + buffer: ReplayBuffer | None = None + train_fn: Callable[[int, int], None] | None = None + test_fn: Callable[[int, int | None], None] | None = None + stop_fn: Callable[[float], bool] | None = None + compute_score_fn: Callable[[CollectStats], float] | None = None + save_best_fn: Callable[["Algorithm"], None] | None = None + save_checkpoint_fn: Callable[[int, int, int], str] | None = None + resume_from_log: bool = False + reward_metric: Callable[[np.ndarray], np.ndarray] | None = None + logger: BaseLogger | None = None + verbose: bool = True + show_progress: bool = True + test_in_train: bool = True + + def __post_init__(self): + if count_none(self.step_per_collect, self.episode_per_collect) != 1: + raise ValueError("Exactly one of {step_per_collect, episode_per_collect} must be set") + + +@dataclass +class OnPolicyTrainingConfig(TrainingConfig): + batch_size: int | None = 64 + """ + Use mini-batches of this size for gradient updates (causing the gradient to be less accurate, a form of regularization). + Set ``batch_size=None`` for the full buffer to be used for the gradient update (no mini-batching). + """ + + repeat_per_collect: int | None = 1 + """ + controls, within one gradient update step of an on-policy algorithm, the number of times an + actual gradient update is applied using the full collected dataset, i.e. if the parameter is + 5, then the collected data shall be used five times to update the policy within the same + training step. 
+ """ + + +@dataclass +class OffPolicyTrainingConfig(TrainingConfig): + batch_size: int = 64 + """ + the the number of environment steps/transitions to sample from the buffer for a gradient update. + """ + + update_per_step: float = 1.0 + """ + the number of gradient steps to perform per sample collected (see :attr:`step_per_collect`). + Specifically, if this is set to `u` and the number of samples collected in the preceding + collection step is `n`, then `round(u * n)` gradient steps will be performed. + """ + + +@dataclass +class OfflineTrainingConfig(OffPolicyTrainingConfig): + pass + + +TConfig = TypeVar("TConfig", bound=TrainingConfig) + + +class BaseTrainer(Generic[TConfig], ABC): """An iterator base class for trainers. Returns an iterator that yields a 3-tuple (epoch, stats, info) of train results @@ -151,40 +283,20 @@ def gen_doc(learning_type: str) -> str: def __init__( self, - policy: Algorithm, - max_epoch: int, - batch_size: int | None, - train_collector: BaseCollector | None = None, - test_collector: BaseCollector | None = None, - buffer: ReplayBuffer | None = None, - step_per_epoch: int | None = None, - repeat_per_collect: int | None = None, - episode_per_test: int | None = None, - update_per_step: float = 1.0, - step_per_collect: int | None = None, - episode_per_collect: int | None = None, - train_fn: Callable[[int, int], None] | None = None, - test_fn: Callable[[int, int | None], None] | None = None, - stop_fn: Callable[[float], bool] | None = None, - compute_score_fn: Callable[[CollectStats], float] | None = None, - save_best_fn: Callable[[Algorithm], None] | None = None, - save_checkpoint_fn: Callable[[int, int, int], str] | None = None, - resume_from_log: bool = False, - reward_metric: Callable[[np.ndarray], np.ndarray] | None = None, - logger: BaseLogger | None = None, - verbose: bool = True, - show_progress: bool = True, - test_in_train: bool = True, + policy: "Algorithm", + config: TConfig, ): + logger = config.logger logger = logger or 
LazyLogger() self.policy = policy + buffer = config.buffer if buffer is not None: buffer = policy.process_buffer(buffer) self.buffer = buffer - self.train_collector = train_collector - self.test_collector = test_collector + self.train_collector = config.train_collector + self.test_collector = config.test_collector self.logger = logger self.start_time = time.time() @@ -199,27 +311,24 @@ def __init__( self.env_step = 0 self.env_episode = 0 self.policy_update_time = 0.0 - self.max_epoch = max_epoch + self.max_epoch = config.max_epoch assert ( - step_per_epoch is not None + config.step_per_epoch is not None ), "The trainer requires step_per_epoch to be set, sorry for the wrong type hint" - self.step_per_epoch: int = step_per_epoch + self.step_per_epoch: int = config.step_per_epoch # either on of these two - self.step_per_collect = step_per_collect - self.episode_per_collect = episode_per_collect - - self.update_per_step = update_per_step - self.repeat_per_collect = repeat_per_collect + self.step_per_collect = config.step_per_collect + self.episode_per_collect = config.episode_per_collect - self.episode_per_test = episode_per_test + self.config = config + self.episode_per_test = config.episode_per_test - self.batch_size = batch_size - - self.train_fn = train_fn - self.test_fn = test_fn - self.stop_fn = stop_fn + self.train_fn = config.train_fn + self.test_fn = config.test_fn + self.stop_fn = config.stop_fn self.compute_score_fn: Callable[[CollectStats], float] + compute_score_fn = config.compute_score_fn if compute_score_fn is None: def compute_score_fn(stat: CollectStats) -> float: @@ -227,14 +336,14 @@ def compute_score_fn(stat: CollectStats) -> float: return stat.returns_stat.mean self.compute_score_fn = compute_score_fn - self.save_best_fn = save_best_fn - self.save_checkpoint_fn = save_checkpoint_fn + self.save_best_fn = config.save_best_fn + self.save_checkpoint_fn = config.save_checkpoint_fn - self.reward_metric = reward_metric - self.verbose = verbose - 
self.show_progress = show_progress - self.test_in_train = test_in_train - self.resume_from_log = resume_from_log + self.reward_metric = config.reward_metric + self.verbose = config.verbose + self.show_progress = config.show_progress + self.test_in_train = config.test_in_train + self.resume_from_log = config.resume_from_log self.is_run = False self.last_rew, self.last_len = 0.0, 0.0 @@ -374,6 +483,7 @@ def __next__(self) -> EpochStats: t.update() steps_done_in_this_epoch += 1 + # TODO What is this doing here? Where to put it? # for offline RL if self.train_collector is None: assert self.buffer is not None @@ -464,7 +574,7 @@ def training_step(self) -> tuple[CollectStatsBase, TrainingStats | None, bool]: :return: the iteration's collect stats, training stats, and a flag indicating whether to stop training. If training is to be stopped, no gradient steps will be performed and the training stats will be `None`. """ - with policy_within_training_step(self.policy): + with policy_within_training_step(self.policy.policy): should_stop_training = False collect_stats: CollectStatsBase | CollectStats @@ -547,7 +657,7 @@ def _update_best_reward_and_return_should_stop_training( should_stop_training = False # Because we need to evaluate the policy, we need to temporarily leave the "is_training_step" semantics - with policy_within_training_step(self.policy, enabled=False): + with policy_within_training_step(self.policy.policy, enabled=False): if ( collect_stats.n_collected_episodes > 0 and self.test_in_train @@ -642,18 +752,8 @@ def run(self, reset_prior_to_run: bool = True, reset_buffer: bool = False) -> In return info - def _sample_and_update(self, buffer: ReplayBuffer) -> TrainingStats: - """Sample a mini-batch, perform one gradient step, and update the _gradient_step counter.""" - self._gradient_step += 1 - # Note: since sample_size=batch_size, this will perform - # exactly one gradient step. 
This is why we don't need to calculate the - # number of gradient steps, like in the on-policy case. - update_stat = self.policy.update(sample_size=self.batch_size, buffer=buffer) - self._update_moving_avg_stats_and_log_update_data(update_stat) - return update_stat - -class OfflineTrainer(BaseTrainer): +class OfflineTrainer(BaseTrainer[OfflineTrainingConfig]): """Offline trainer, samples mini-batches from buffer and passes them to update. Uses a buffer directly and usually does not have a collector. @@ -674,8 +774,18 @@ def policy_update_fn( self.policy_update_time += update_stat.train_time return update_stat + def _sample_and_update(self, buffer: ReplayBuffer) -> TrainingStats: + """Sample a mini-batch, perform one gradient step, and update the _gradient_step counter.""" + self._gradient_step += 1 + # Note: since sample_size=batch_size, this will perform + # exactly one gradient step. This is why we don't need to calculate the + # number of gradient steps, like in the on-policy case. + update_stat = self.policy.update(sample_size=self.config.batch_size, buffer=buffer) + self._update_moving_avg_stats_and_log_update_data(update_stat) + return update_stat + -class OffpolicyTrainer(BaseTrainer): +class OffpolicyTrainer(BaseTrainer[OffPolicyTrainingConfig]): """Offpolicy trainer, samples mini-batches from buffer and passes them to update. 
Note that with this trainer, it is expected that the policy's `learn` method @@ -699,11 +809,11 @@ def policy_update_fn( """ assert self.train_collector is not None n_collected_steps = collect_stats.n_collected_steps - n_gradient_steps = round(self.update_per_step * n_collected_steps) + n_gradient_steps = round(self.config.update_per_step * n_collected_steps) if n_gradient_steps == 0: raise ValueError( f"n_gradient_steps is 0, n_collected_steps={n_collected_steps}, " - f"update_per_step={self.update_per_step}", + f"update_per_step={self.config.update_per_step}", ) for _ in self._pbar( @@ -717,8 +827,18 @@ def policy_update_fn( # TODO: only the last update_stat is returned, should be improved return update_stat + def _sample_and_update(self, buffer: ReplayBuffer) -> TrainingStats: + """Sample a mini-batch, perform one gradient step, and update the _gradient_step counter.""" + self._gradient_step += 1 + # Note: since sample_size=batch_size, this will perform + # exactly one gradient step. This is why we don't need to calculate the + # number of gradient steps, like in the on-policy case. + update_stat = self.policy.update(sample_size=self.config.batch_size, buffer=buffer) + self._update_moving_avg_stats_and_log_update_data(update_stat) + return update_stat + -class OnpolicyTrainer(BaseTrainer): +class OnpolicyTrainer(BaseTrainer[OnPolicyTrainingConfig]): """On-policy trainer, passes the entire buffer to .update and resets it after. 
Note that it is expected that the learn method of a policy will perform @@ -747,8 +867,8 @@ def policy_update_fn( # The kwargs are in the end passed to the .learn method, which uses # batch_size to iterate through the buffer in mini-batches # Off-policy algos typically don't use the batch_size kwarg at all - batch_size=self.batch_size, - repeat=self.repeat_per_collect, + batch_size=self.config.batch_size, + repeat=self.config.repeat_per_collect, ) # just for logging, no functional role @@ -756,10 +876,12 @@ def policy_update_fn( # TODO: remove the gradient step counting in trainers? Doesn't seem like # it's important and it adds complexity self._gradient_step += 1 - if self.batch_size is None: + if self.config.batch_size is None: self._gradient_step += 1 - elif self.batch_size > 0: - self._gradient_step += int((len(self.train_collector.buffer) - 0.1) // self.batch_size) + elif self.config.batch_size > 0: + self._gradient_step += int( + (len(self.train_collector.buffer) - 0.1) // self.config.batch_size, + ) # Note 1: this is the main difference to the off-policy trainer! # The second difference is that batches of data are sampled without replacement diff --git a/tianshou/utils/torch_utils.py b/tianshou/utils/torch_utils.py index 9ee90167c..8526c5303 100644 --- a/tianshou/utils/torch_utils.py +++ b/tianshou/utils/torch_utils.py @@ -23,7 +23,7 @@ def torch_train_mode(module: nn.Module, enabled: bool = True) -> Iterator[None]: @contextmanager -def policy_within_training_step(policy: Policy, enabled: bool = True) -> Iterator[None]: +def policy_within_training_step(policy: "Policy", enabled: bool = True) -> Iterator[None]: """Temporarily switch to `policy.is_within_training_step=enabled`. 
Enabling this ensures that the policy is able to adapt its behavior, From 1f33bf2965fa56bde12bbdb2d27e7347fba5cfe8 Mon Sep 17 00:00:00 2001 From: Dominik Jain Date: Mon, 3 Mar 2025 22:53:00 +0100 Subject: [PATCH 003/230] Disable ruff COM812 --- pyproject.toml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 5640f948b..d2ba919a3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -181,7 +181,8 @@ ignore = [ "PLW2901", # overwrite vars in loop "B027", # empty and non-abstract method in abstract class "D404", # It's fine to start with "This" in docstrings - "D407", "D408", "D409", # Ruff rules for underlines under 'Example:' and so clash with Sphinx + "D407", "D408", "D409", # Ruff rules for underlines under 'Example:' and so clash with Sphinx + "COM812", # missing trailing comma: With this enabled, re-application of "poe format" chain can cause additional commas and subsequent reformatting ] unfixable = [ "F841", # unused variable. ruff keeps the call, but mostly we want to get rid of it all From 9f1e685ef61ad1ad667e98ad547b5d3340a72b46 Mon Sep 17 00:00:00 2001 From: Dominik Jain Date: Tue, 4 Mar 2025 19:42:00 +0100 Subject: [PATCH 004/230] v2: Adapt DQN and example discrete_dqn --- README.md | 2 +- examples/atari/atari_dqn.py | 16 ++- examples/box2d/acrobot_dualdqn.py | 4 +- examples/box2d/lunarlander_dqn.py | 4 +- examples/discrete/discrete_dqn.py | 47 ++++--- test/discrete/test_dqn.py | 4 +- test/discrete/test_drqn.py | 4 +- test/discrete/test_pg.py | 2 +- test/modelbased/test_dqn_icm.py | 4 +- test/pettingzoo/pistonball.py | 4 +- test/pettingzoo/tic_tac_toe.py | 4 +- tianshou/highlevel/agent.py | 8 +- tianshou/highlevel/trainer.py | 8 +- tianshou/policy/__init__.py | 4 +- tianshou/policy/base.py | 20 ++- tianshou/policy/imitation/discrete_bcq.py | 4 +- tianshou/policy/modelfree/bdq.py | 4 +- tianshou/policy/modelfree/c51.py | 4 +- tianshou/policy/modelfree/dqn.py | 159 ++++++++++++---------- 
tianshou/policy/modelfree/fqf.py | 4 +- tianshou/policy/modelfree/qrdqn.py | 4 +- tianshou/trainer/base.py | 8 +- 22 files changed, 180 insertions(+), 142 deletions(-) diff --git a/README.md b/README.md index ee69f6ae6..6582eea5c 100644 --- a/README.md +++ b/README.md @@ -370,7 +370,7 @@ optim = torch.optim.Adam(net.parameters(), lr=lr) Set up the policy and collectors: ```python -policy = ts.policy.DQNPolicy( +policy = ts.policy.DeepQLearning( model=net, optim=optim, discount_factor=gamma, diff --git a/examples/atari/atari_dqn.py b/examples/atari/atari_dqn.py index 023a961f5..f2863a8d0 100644 --- a/examples/atari/atari_dqn.py +++ b/examples/atari/atari_dqn.py @@ -6,13 +6,13 @@ import numpy as np import torch -from atari_network import DQN from atari_wrapper import make_atari_env +from examples.atari.atari_network import DQN from tianshou.data import Collector, CollectStats, VectorReplayBuffer from tianshou.highlevel.logger import LoggerFactoryDefault -from tianshou.policy import DQNPolicy -from tianshou.policy.base import BasePolicy +from tianshou.policy import DeepQLearning +from tianshou.policy.base import Algorithm from tianshou.policy.modelbased.icm import ICMPolicy from tianshou.trainer import OffpolicyTrainer from tianshou.utils.net.discrete import IntrinsicCuriosityModule @@ -104,8 +104,8 @@ def main(args: argparse.Namespace = get_args()) -> None: net = DQN(*args.state_shape, args.action_shape, args.device).to(args.device) optim = torch.optim.Adam(net.parameters(), lr=args.lr) # define policy - policy: DQNPolicy | ICMPolicy - policy = DQNPolicy( + policy: DeepQLearning | ICMPolicy + policy = DeepQLearning( model=net, optim=optim, action_space=env.action_space, @@ -114,7 +114,9 @@ def main(args: argparse.Namespace = get_args()) -> None: target_update_freq=args.target_update_freq, ) if args.icm_lr_scale > 0: - feature_net = DQN(*args.state_shape, args.action_shape, args.device, features_only=True) + feature_net = DeepQLearning( + *args.state_shape, 
args.action_shape, args.device, features_only=True + ) action_dim = np.prod(args.action_shape) feature_dim = feature_net.output_dim icm_net = IntrinsicCuriosityModule( @@ -172,7 +174,7 @@ def main(args: argparse.Namespace = get_args()) -> None: config_dict=vars(args), ) - def save_best_fn(policy: BasePolicy) -> None: + def save_best_fn(policy: Algorithm) -> None: torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth")) def stop_fn(mean_rewards: float) -> bool: diff --git a/examples/box2d/acrobot_dualdqn.py b/examples/box2d/acrobot_dualdqn.py index d1709a9ec..47bee824b 100644 --- a/examples/box2d/acrobot_dualdqn.py +++ b/examples/box2d/acrobot_dualdqn.py @@ -9,7 +9,7 @@ from tianshou.data import Collector, CollectStats, VectorReplayBuffer from tianshou.env import DummyVectorEnv -from tianshou.policy import DQNPolicy +from tianshou.policy import DeepQLearning from tianshou.policy.base import Algorithm from tianshou.trainer import OffpolicyTrainer from tianshou.utils import TensorboardLogger @@ -75,7 +75,7 @@ def test_dqn(args: argparse.Namespace = get_args()) -> None: dueling_param=(Q_param, V_param), ).to(args.device) optim = torch.optim.Adam(net.parameters(), lr=args.lr) - policy: DQNPolicy = DQNPolicy( + policy: DeepQLearning = DeepQLearning( model=net, optim=optim, action_space=env.action_space, diff --git a/examples/box2d/lunarlander_dqn.py b/examples/box2d/lunarlander_dqn.py index fea3096ba..75fc5eaee 100644 --- a/examples/box2d/lunarlander_dqn.py +++ b/examples/box2d/lunarlander_dqn.py @@ -9,7 +9,7 @@ from tianshou.data import Collector, CollectStats, VectorReplayBuffer from tianshou.env import DummyVectorEnv, SubprocVectorEnv -from tianshou.policy import DQNPolicy +from tianshou.policy import DeepQLearning from tianshou.policy.base import Algorithm from tianshou.trainer import OffpolicyTrainer from tianshou.utils import TensorboardLogger @@ -77,7 +77,7 @@ def test_dqn(args: argparse.Namespace = get_args()) -> None: dueling_param=(Q_param, 
V_param), ).to(args.device) optim = torch.optim.Adam(net.parameters(), lr=args.lr) - policy: DQNPolicy = DQNPolicy( + policy: DeepQLearning = DeepQLearning( model=net, optim=optim, action_space=env.action_space, diff --git a/examples/discrete/discrete_dqn.py b/examples/discrete/discrete_dqn.py index 4e52f4ce2..2916aaaa8 100644 --- a/examples/discrete/discrete_dqn.py +++ b/examples/discrete/discrete_dqn.py @@ -4,6 +4,8 @@ import tianshou as ts from tianshou.data import CollectStats +from tianshou.policy.modelfree.dqn import DQNPolicy +from tianshou.trainer.base import OffPolicyTrainingConfig from tianshou.utils.space_info import SpaceInfo @@ -35,22 +37,22 @@ def main() -> None: net = Net(state_shape=state_shape, action_shape=action_shape, hidden_sizes=[128, 128, 128]) optim = torch.optim.Adam(net.parameters(), lr=lr) - policy: ts.policy.DQNPolicy = ts.policy.DQNPolicy( - model=net, + policy = DQNPolicy(model=net, action_space=env.action_space) + algorithm: ts.policy.DeepQLearning = ts.policy.DeepQLearning( + policy=policy, optim=optim, discount_factor=gamma, - action_space=env.action_space, estimation_step=n_step, target_update_freq=target_freq, ) train_collector = ts.data.Collector[CollectStats]( - policy, + algorithm, train_envs, ts.data.VectorReplayBuffer(buffer_size, train_num), exploration_noise=True, ) test_collector = ts.data.Collector[CollectStats]( - policy, + algorithm, test_envs, exploration_noise=True, ) # because DQN uses epsilon-greedy method @@ -63,26 +65,27 @@ def stop_fn(mean_rewards: float) -> bool: return mean_rewards >= env.spec.reward_threshold return False - result = ts.trainer.OffpolicyTrainer( - policy=policy, - train_collector=train_collector, - test_collector=test_collector, - max_epoch=epoch, - step_per_epoch=step_per_epoch, - step_per_collect=step_per_collect, - episode_per_test=test_num, - batch_size=batch_size, - update_per_step=1 / step_per_collect, - train_fn=lambda epoch, env_step: policy.set_eps(eps_train), - test_fn=lambda epoch, 
env_step: policy.set_eps(eps_test), - stop_fn=stop_fn, - logger=logger, - ).run() + result = algorithm.run_training( + OffPolicyTrainingConfig( + train_collector=train_collector, + test_collector=test_collector, + max_epoch=epoch, + step_per_epoch=step_per_epoch, + step_per_collect=step_per_collect, + episode_per_test=test_num, + batch_size=batch_size, + update_per_step=1 / step_per_collect, + train_fn=lambda epoch, env_step: algorithm.set_eps(eps_train), + test_fn=lambda epoch, env_step: algorithm.set_eps(eps_test), + stop_fn=stop_fn, + logger=logger, + ) + ) print(f"Finished training in {result.timing.total_time} seconds") # watch performance - policy.set_eps(eps_test) - collector = ts.data.Collector[CollectStats](policy, env, exploration_noise=True) + algorithm.set_eps(eps_test) + collector = ts.data.Collector[CollectStats](algorithm, env, exploration_noise=True) collector.collect(n_episode=100, render=1 / 35) diff --git a/test/discrete/test_dqn.py b/test/discrete/test_dqn.py index 842555d0c..7c7588518 100644 --- a/test/discrete/test_dqn.py +++ b/test/discrete/test_dqn.py @@ -14,7 +14,7 @@ VectorReplayBuffer, ) from tianshou.env import DummyVectorEnv -from tianshou.policy import DQNPolicy +from tianshou.policy import DeepQLearning from tianshou.policy.base import Algorithm from tianshou.trainer import OffpolicyTrainer from tianshou.utils import TensorboardLogger @@ -87,7 +87,7 @@ def test_dqn(args: argparse.Namespace = get_args()) -> None: # dueling=(Q_param, V_param), ).to(args.device) optim = torch.optim.Adam(net.parameters(), lr=args.lr) - policy: DQNPolicy = DQNPolicy( + policy: DeepQLearning = DeepQLearning( model=net, optim=optim, discount_factor=args.gamma, diff --git a/test/discrete/test_drqn.py b/test/discrete/test_drqn.py index 6030fc03c..b3e8bb381 100644 --- a/test/discrete/test_drqn.py +++ b/test/discrete/test_drqn.py @@ -8,7 +8,7 @@ from tianshou.data import Collector, CollectStats, VectorReplayBuffer from tianshou.env import DummyVectorEnv -from 
tianshou.policy import DQNPolicy +from tianshou.policy import DeepQLearning from tianshou.policy.base import Algorithm from tianshou.trainer import OffpolicyTrainer from tianshou.utils import TensorboardLogger @@ -74,7 +74,7 @@ def test_drqn(args: argparse.Namespace = get_args()) -> None: args.device, ) optim = torch.optim.Adam(net.parameters(), lr=args.lr) - policy: DQNPolicy = DQNPolicy( + policy: DeepQLearning = DeepQLearning( model=net, optim=optim, discount_factor=args.gamma, diff --git a/test/discrete/test_pg.py b/test/discrete/test_pg.py index 2fc2b4ed9..5192c174d 100644 --- a/test/discrete/test_pg.py +++ b/test/discrete/test_pg.py @@ -126,6 +126,6 @@ def stop_fn(mean_rewards: float) -> bool: save_best_fn=save_best_fn, logger=logger, ) - result = algorithm.train(training_config) + result = algorithm.run_training(training_config) assert stop_fn(result.best_reward) diff --git a/test/modelbased/test_dqn_icm.py b/test/modelbased/test_dqn_icm.py index 9d3a717b0..13189cadb 100644 --- a/test/modelbased/test_dqn_icm.py +++ b/test/modelbased/test_dqn_icm.py @@ -13,7 +13,7 @@ VectorReplayBuffer, ) from tianshou.env import DummyVectorEnv -from tianshou.policy import DQNPolicy, ICMPolicy +from tianshou.policy import DeepQLearning, ICMPolicy from tianshou.policy.base import Algorithm from tianshou.policy.modelfree.dqn import DQNTrainingStats from tianshou.trainer import OffpolicyTrainer @@ -108,7 +108,7 @@ def test_dqn_icm(args: argparse.Namespace = get_args()) -> None: # dueling=(Q_param, V_param), ).to(args.device) optim = torch.optim.Adam(net.parameters(), lr=args.lr) - policy: DQNPolicy[DQNTrainingStats] = DQNPolicy( + policy: DeepQLearning[DQNTrainingStats] = DeepQLearning( model=net, optim=optim, action_space=env.action_space, diff --git a/test/pettingzoo/pistonball.py b/test/pettingzoo/pistonball.py index 88186c177..05ef337c8 100644 --- a/test/pettingzoo/pistonball.py +++ b/test/pettingzoo/pistonball.py @@ -11,7 +11,7 @@ from tianshou.data import Collector, 
CollectStats, InfoStats, VectorReplayBuffer from tianshou.env import DummyVectorEnv from tianshou.env.pettingzoo_env import PettingZooEnv -from tianshou.policy import Algorithm, DQNPolicy, MultiAgentPolicyManager +from tianshou.policy import Algorithm, DeepQLearning, MultiAgentPolicyManager from tianshou.trainer import OffpolicyTrainer from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Net @@ -97,7 +97,7 @@ def get_agents( device=args.device, ).to(args.device) optim = torch.optim.Adam(net.parameters(), lr=args.lr) - agent: DQNPolicy = DQNPolicy( + agent: DeepQLearning = DeepQLearning( model=net, optim=optim, action_space=env.action_space, diff --git a/test/pettingzoo/tic_tac_toe.py b/test/pettingzoo/tic_tac_toe.py index 9e74c003e..af17c24b7 100644 --- a/test/pettingzoo/tic_tac_toe.py +++ b/test/pettingzoo/tic_tac_toe.py @@ -15,7 +15,7 @@ from tianshou.env.pettingzoo_env import PettingZooEnv from tianshou.policy import ( BasePolicy, - DQNPolicy, + DeepQLearning, MARLRandomPolicy, MultiAgentPolicyManager, ) @@ -120,7 +120,7 @@ def get_agents( ).to(args.device) if optim is None: optim = torch.optim.Adam(net.parameters(), lr=args.lr) - agent_learn = DQNPolicy( + agent_learn = DeepQLearning( model=net, optim=optim, action_space=env.action_space, diff --git a/tianshou/highlevel/agent.py b/tianshou/highlevel/agent.py index 9f7f1011c..c63f0395e 100644 --- a/tianshou/highlevel/agent.py +++ b/tianshou/highlevel/agent.py @@ -48,8 +48,8 @@ A2CPolicy, Algorithm, DDPGPolicy, + DeepQLearning, DiscreteSACPolicy, - DQNPolicy, IQNPolicy, NPGPolicy, PPOPolicy, @@ -419,9 +419,9 @@ def _create_policy(self, envs: Environments, device: TDevice) -> TPolicy: ) -class DQNAgentFactory(DiscreteCriticOnlyAgentFactory[DQNParams, DQNPolicy]): - def _get_policy_class(self) -> type[DQNPolicy]: - return DQNPolicy +class DQNAgentFactory(DiscreteCriticOnlyAgentFactory[DQNParams, DeepQLearning]): + def _get_policy_class(self) -> type[DeepQLearning]: + return 
DeepQLearning class IQNAgentFactory(DiscreteCriticOnlyAgentFactory[IQNParams, IQNPolicy]): diff --git a/tianshou/highlevel/trainer.py b/tianshou/highlevel/trainer.py index b4c2357bb..04b43a227 100644 --- a/tianshou/highlevel/trainer.py +++ b/tianshou/highlevel/trainer.py @@ -8,7 +8,7 @@ from tianshou.highlevel.env import Environments from tianshou.highlevel.logger import TLogger -from tianshou.policy import Algorithm, DQNPolicy +from tianshou.policy import Algorithm, DeepQLearning TPolicy = TypeVar("TPolicy", bound=Algorithm) log = logging.getLogger(__name__) @@ -90,7 +90,7 @@ def __init__(self, eps_test: float): self.eps_test = eps_test def callback(self, epoch: int, env_step: int, context: TrainingContext) -> None: - policy = cast(DQNPolicy, context.policy) + policy = cast(DeepQLearning, context.policy) policy.set_eps(self.eps_test) @@ -105,7 +105,7 @@ def __init__(self, eps_train: float, eps_train_final: float, decay_steps: int = self.decay_steps = decay_steps def callback(self, epoch: int, env_step: int, context: TrainingContext) -> None: - policy = cast(DQNPolicy, context.policy) + policy = cast(DeepQLearning, context.policy) logger = context.logger if env_step <= self.decay_steps: eps = self.eps_train - env_step / self.decay_steps * ( @@ -126,7 +126,7 @@ def __init__(self, eps_test: float): self.eps_test = eps_test def callback(self, epoch: int, env_step: int | None, context: TrainingContext) -> None: - policy = cast(DQNPolicy, context.policy) + policy = cast(DeepQLearning, context.policy) policy.set_eps(self.eps_test) diff --git a/tianshou/policy/__init__.py b/tianshou/policy/__init__.py index 1716f364e..31c4b1d68 100644 --- a/tianshou/policy/__init__.py +++ b/tianshou/policy/__init__.py @@ -3,10 +3,10 @@ from tianshou.policy.base import Algorithm, TrainingStats from tianshou.policy.modelfree.pg import Reinforce +from tianshou.policy.modelfree.dqn import DeepQLearning """ from tianshou.policy.random import MARLRandomPolicy -from tianshou.policy.modelfree.dqn 
import DQNPolicy from tianshou.policy.modelfree.bdq import BranchingDQNPolicy from tianshou.policy.modelfree.c51 import C51Policy from tianshou.policy.modelfree.rainbow import RainbowPolicy @@ -38,7 +38,7 @@ __all__ = [ "Algorithm", # "MARLRandomPolicy", - # "DQNPolicy", + "DeepQLearning", # "BranchingDQNPolicy", # "C51Policy", # "RainbowPolicy", diff --git a/tianshou/policy/base.py b/tianshou/policy/base.py index 3e5b11946..e5de885e1 100644 --- a/tianshou/policy/base.py +++ b/tianshou/policy/base.py @@ -32,6 +32,8 @@ if TYPE_CHECKING: from tianshou.trainer.base import ( BaseTrainer, + OffpolicyTrainer, + OffPolicyTrainingConfig, OnpolicyTrainer, OnPolicyTrainingConfig, ) @@ -353,7 +355,7 @@ def _compile() -> None: ) # TODO Can't use bound=TrainingConfig because of circular import -class Algorithm(Generic[TPolicy, TTrainingConfig, TTrainingStats], ABC): +class Algorithm(torch.nn.Module, Generic[TPolicy, TTrainingConfig, TTrainingStats], ABC): """ TODO fix docstring The base class for any RL policy. 
@@ -416,7 +418,8 @@ def __init__( policy: TPolicy, lr_scheduler: TLearningRateScheduler | None = None, ) -> None: - self.policy = policy + super().__init__() + self.policy: TPolicy = policy self.lr_scheduler = lr_scheduler # TODO delete this @@ -762,7 +765,7 @@ def compute_nstep_return( def _create_trainer(self, config: TTrainingConfig) -> "BaseTrainer": pass - def train(self, config: TTrainingConfig): + def run_training(self, config: TTrainingConfig): trainer = self._create_trainer(config) return trainer.run() @@ -778,6 +781,17 @@ def _create_trainer(self, config: "OnPolicyTrainingConfig") -> "OnpolicyTrainer" return OnpolicyTrainer(self, config) +class OffPolicyAlgorithm( + Algorithm[TPolicy, "OffPolicyTrainingConfig", TTrainingStats], + Generic[TPolicy, TTrainingStats], + ABC, +): + def _create_trainer(self, config: "OffPolicyTrainingConfig") -> "OffpolicyTrainer": + from tianshou.trainer.base import OffpolicyTrainer + + return OffpolicyTrainer(self, config) + + # TODO must become Policy not Algorithm class RandomActionPolicy(Algorithm): def __init__( diff --git a/tianshou/policy/imitation/discrete_bcq.py b/tianshou/policy/imitation/discrete_bcq.py index 86a059d20..f6b0b4d43 100644 --- a/tianshou/policy/imitation/discrete_bcq.py +++ b/tianshou/policy/imitation/discrete_bcq.py @@ -13,7 +13,7 @@ ObsBatchProtocol, RolloutBatchProtocol, ) -from tianshou.policy import DQNPolicy +from tianshou.policy import DeepQLearning from tianshou.policy.base import TLearningRateScheduler from tianshou.policy.modelfree.dqn import DQNTrainingStats @@ -31,7 +31,7 @@ class DiscreteBCQTrainingStats(DQNTrainingStats): TDiscreteBCQTrainingStats = TypeVar("TDiscreteBCQTrainingStats", bound=DiscreteBCQTrainingStats) -class DiscreteBCQPolicy(DQNPolicy[TDiscreteBCQTrainingStats]): +class DiscreteBCQPolicy(DeepQLearning[TDiscreteBCQTrainingStats]): """Implementation of discrete BCQ algorithm. arXiv:1910.01708. 
:param model: a model following the rules (s_B -> action_values_BA) diff --git a/tianshou/policy/modelfree/bdq.py b/tianshou/policy/modelfree/bdq.py index 3529d159c..455711ce0 100644 --- a/tianshou/policy/modelfree/bdq.py +++ b/tianshou/policy/modelfree/bdq.py @@ -14,7 +14,7 @@ ObsBatchProtocol, RolloutBatchProtocol, ) -from tianshou.policy import DQNPolicy +from tianshou.policy import DeepQLearning from tianshou.policy.base import TLearningRateScheduler from tianshou.policy.modelfree.dqn import DQNTrainingStats from tianshou.utils.net.common import BranchingNet @@ -28,7 +28,7 @@ class BDQNTrainingStats(DQNTrainingStats): TBDQNTrainingStats = TypeVar("TBDQNTrainingStats", bound=BDQNTrainingStats) -class BranchingDQNPolicy(DQNPolicy[TBDQNTrainingStats]): +class BranchingDQNPolicy(DeepQLearning[TBDQNTrainingStats]): """Implementation of the Branching dual Q network arXiv:1711.08946. :param model: BranchingNet mapping (obs, state, info) -> action_values_BA. diff --git a/tianshou/policy/modelfree/c51.py b/tianshou/policy/modelfree/c51.py index 49758a900..1d858de38 100644 --- a/tianshou/policy/modelfree/c51.py +++ b/tianshou/policy/modelfree/c51.py @@ -7,7 +7,7 @@ from tianshou.data import Batch, ReplayBuffer from tianshou.data.types import RolloutBatchProtocol -from tianshou.policy import DQNPolicy +from tianshou.policy import DeepQLearning from tianshou.policy.base import TLearningRateScheduler from tianshou.policy.modelfree.dqn import DQNTrainingStats @@ -20,7 +20,7 @@ class C51TrainingStats(DQNTrainingStats): TC51TrainingStats = TypeVar("TC51TrainingStats", bound=C51TrainingStats) -class C51Policy(DQNPolicy[TC51TrainingStats], Generic[TC51TrainingStats]): +class C51Policy(DeepQLearning[TC51TrainingStats], Generic[TC51TrainingStats]): """Implementation of Categorical Deep Q-Network. arXiv:1707.06887. 
:param model: a model following the rules (s_B -> action_values_BA) diff --git a/tianshou/policy/modelfree/dqn.py b/tianshou/policy/modelfree/dqn.py index 47b97c795..ba0aecb50 100644 --- a/tianshou/policy/modelfree/dqn.py +++ b/tianshou/policy/modelfree/dqn.py @@ -1,10 +1,11 @@ from copy import deepcopy from dataclasses import dataclass -from typing import Any, Generic, Literal, Self, TypeVar, cast +from typing import Any, Generic, Self, TypeVar, cast import gymnasium as gym import numpy as np import torch +from sensai.util.helper import mark_used from tianshou.data import Batch, ReplayBuffer, to_numpy, to_torch_as from tianshou.data.batch import BatchProtocol @@ -15,10 +16,16 @@ ObsBatchProtocol, RolloutBatchProtocol, ) -from tianshou.policy import BasePolicy -from tianshou.policy.base import TLearningRateScheduler, TrainingStats +from tianshou.policy.base import ( + OffPolicyAlgorithm, + Policy, + TLearningRateScheduler, + TrainingStats, +) from tianshou.utils.net.common import Net +mark_used(ActBatchProtocol) + @dataclass(kw_only=True) class DQNTrainingStats(TrainingStats): @@ -28,7 +35,72 @@ class DQNTrainingStats(TrainingStats): TDQNTrainingStats = TypeVar("TDQNTrainingStats", bound=DQNTrainingStats) -class DQNPolicy(BasePolicy[TDQNTrainingStats], Generic[TDQNTrainingStats]): +class DQNPolicy(Policy): + def __init__( + self, + *, + model: torch.nn.Module | Net, + action_space: gym.spaces.Discrete, + observation_space: gym.Space | None = None, + ) -> None: + super().__init__( + action_space=action_space, + observation_space=observation_space, + action_scaling=False, + action_bound_method=None, + ) + self.model = model + self.max_action_num: int | None = None + + def forward( + self, + batch: ObsBatchProtocol, + state: dict | BatchProtocol | np.ndarray | None = None, + model: torch.nn.Module | None = None, + **kwargs: Any, + ) -> ModelOutputBatchProtocol: + """Compute action over the given batch data. 
+ + If you need to mask the action, please add a "mask" into batch.obs, for + example, if we have an environment that has "0/1/2" three actions: + :: + + batch == Batch( + obs=Batch( + obs="original obs, with batch_size=1 for demonstration", + mask=np.array([[False, True, False]]), + # action 1 is available + # action 0 and 2 are unavailable + ), + ... + ) + + :return: A :class:`~tianshou.data.Batch` which has 3 keys: + + * ``act`` the action. + * ``logits`` the network's raw output. + * ``state`` the hidden state. + + .. seealso:: + + Please refer to :meth:`~tianshou.policy.BasePolicy.forward` for + more detailed explanation. + """ + if model is None: + model = self.model + obs = batch.obs + # TODO: this is convoluted! See also other places where this is done. + obs_next = obs.obs if hasattr(obs, "obs") else obs + action_values_BA, hidden_BH = model(obs_next, state=state, info=batch.info) + q = DeepQLearning.compute_q_value(action_values_BA, getattr(obs, "mask", None)) + if self.max_action_num is None: + self.max_action_num = q.shape[1] + act_B = to_numpy(q.argmax(dim=1)) + result = Batch(logits=action_values_BA, act=act_B, state=hidden_BH) + return cast(ModelOutputBatchProtocol, result) + + +class DeepQLearning(OffPolicyAlgorithm[DQNPolicy, TDQNTrainingStats], Generic[TDQNTrainingStats]): """Implementation of Deep Q Network. arXiv:1312.5602. Implementation of Double Q-Learning. arXiv:1509.06461. 
@@ -60,27 +132,21 @@ class DQNPolicy(BasePolicy[TDQNTrainingStats], Generic[TDQNTrainingStats]): def __init__( self, *, - model: torch.nn.Module | Net, + policy: DQNPolicy, optim: torch.optim.Optimizer, # TODO: type violates Liskov substitution principle - action_space: gym.spaces.Discrete, discount_factor: float = 0.99, estimation_step: int = 1, target_update_freq: int = 0, reward_normalization: bool = False, is_double: bool = True, clip_loss_grad: bool = False, - observation_space: gym.Space | None = None, lr_scheduler: TLearningRateScheduler | None = None, ) -> None: super().__init__( - action_space=action_space, - observation_space=observation_space, - action_scaling=False, - action_bound_method=None, + policy=policy, lr_scheduler=lr_scheduler, ) - self.model = model self.optim = optim self.eps = 0.0 assert ( @@ -95,15 +161,13 @@ def __init__( self.freq = target_update_freq self._iter = 0 if self._target: - self.model_old = deepcopy(self.model) + self.model_old = deepcopy(self.policy.model) self.model_old.eval() self.rew_norm = reward_normalization self.is_double = is_double self.clip_loss_grad = clip_loss_grad - # TODO: set in forward, fix this! 
- self.max_action_num: int | None = None - + # TODO: Should use two eps parameters: one for training, one for inference/testing def set_eps(self, eps: float) -> None: """Set the eps for epsilon-greedy exploration.""" self.eps = eps @@ -111,22 +175,22 @@ def set_eps(self, eps: float) -> None: def train(self, mode: bool = True) -> Self: """Set the module in training mode, except for the target network.""" self.training = mode - self.model.train(mode) + self.policy.train(mode) return self def sync_weight(self) -> None: """Synchronize the weight for the target network.""" - self.model_old.load_state_dict(self.model.state_dict()) + self.model_old.load_state_dict(self.policy.model.state_dict()) def _target_q(self, buffer: ReplayBuffer, indices: np.ndarray) -> torch.Tensor: obs_next_batch = Batch( obs=buffer[indices].obs_next, info=[None] * len(indices), ) # obs_next: s_{t+n} - result = self(obs_next_batch) + result = self.policy(obs_next_batch) if self._target: # target_Q = Q_old(s_, argmax(Q_new(s_, *))) - target_q = self(obs_next_batch, model="model_old").logits + target_q = self.policy(obs_next_batch, model=self.model_old).logits else: target_q = result.logits if self.is_double: @@ -155,7 +219,8 @@ def process_fn( rew_norm=self.rew_norm, ) - def compute_q_value(self, logits: torch.Tensor, mask: np.ndarray | None) -> torch.Tensor: + @staticmethod + def compute_q_value(logits: torch.Tensor, mask: np.ndarray | None) -> torch.Tensor: """Compute the q value based on the network's raw output and action mask.""" if mask is not None: # the masked q value should be smaller than logits.min() @@ -163,52 +228,6 @@ def compute_q_value(self, logits: torch.Tensor, mask: np.ndarray | None) -> torc logits = logits + to_torch_as(1 - mask, logits) * min_value return logits - def forward( - self, - batch: ObsBatchProtocol, - state: dict | BatchProtocol | np.ndarray | None = None, - model: Literal["model", "model_old"] = "model", - **kwargs: Any, - ) -> ModelOutputBatchProtocol: - 
"""Compute action over the given batch data. - - If you need to mask the action, please add a "mask" into batch.obs, for - example, if we have an environment that has "0/1/2" three actions: - :: - - batch == Batch( - obs=Batch( - obs="original obs, with batch_size=1 for demonstration", - mask=np.array([[False, True, False]]), - # action 1 is available - # action 0 and 2 are unavailable - ), - ... - ) - - :return: A :class:`~tianshou.data.Batch` which has 3 keys: - - * ``act`` the action. - * ``logits`` the network's raw output. - * ``state`` the hidden state. - - .. seealso:: - - Please refer to :meth:`~tianshou.policy.BasePolicy.forward` for - more detailed explanation. - """ - model = getattr(self, model) - obs = batch.obs - # TODO: this is convoluted! See also other places where this is done. - obs_next = obs.obs if hasattr(obs, "obs") else obs - action_values_BA, hidden_BH = model(obs_next, state=state, info=batch.info) - q = self.compute_q_value(action_values_BA, getattr(obs, "mask", None)) - if self.max_action_num is None: - self.max_action_num = q.shape[1] - act_B = to_numpy(q.argmax(dim=1)) - result = Batch(logits=action_values_BA, act=act_B, state=hidden_BH) - return cast(ModelOutputBatchProtocol, result) - def _update_with_batch( self, batch: RolloutBatchProtocol, @@ -219,7 +238,7 @@ def _update_with_batch( self.sync_weight() self.optim.zero_grad() weight = batch.pop("weight", 1.0) - q = self(batch).logits + q = self.policy(batch).logits q = q[np.arange(len(q)), batch.act] returns = to_torch_as(batch.returns.flatten(), q) td_error = returns - q @@ -249,9 +268,9 @@ def exploration_noise( bsz = len(act) rand_mask = np.random.rand(bsz) < self.eps assert ( - self.max_action_num is not None + self.policy.max_action_num is not None ), "Can't call this method before max_action_num was set in first forward" - q = np.random.rand(bsz, self.max_action_num) # [0, 1] + q = np.random.rand(bsz, self.policy.max_action_num) # [0, 1] if hasattr(batch.obs, "mask"): q += 
batch.obs.mask rand_act = q.argmax(axis=1) diff --git a/tianshou/policy/modelfree/fqf.py b/tianshou/policy/modelfree/fqf.py index 99a64953a..d00e38f09 100644 --- a/tianshou/policy/modelfree/fqf.py +++ b/tianshou/policy/modelfree/fqf.py @@ -8,7 +8,7 @@ from tianshou.data import Batch, ReplayBuffer, to_numpy from tianshou.data.types import FQFBatchProtocol, ObsBatchProtocol, RolloutBatchProtocol -from tianshou.policy import DQNPolicy, QRDQNPolicy +from tianshou.policy import DeepQLearning, QRDQNPolicy from tianshou.policy.base import TLearningRateScheduler from tianshou.policy.modelfree.qrdqn import QRDQNTrainingStats from tianshou.utils.net.discrete import FractionProposalNetwork, FullQuantileFunction @@ -138,7 +138,7 @@ def forward( # type: ignore info=batch.info, ) weighted_logits = (fractions.taus[:, 1:] - fractions.taus[:, :-1]).unsqueeze(1) * logits - q = DQNPolicy.compute_q_value(self, weighted_logits.sum(2), getattr(obs, "mask", None)) + q = DeepQLearning.compute_q_value(self, weighted_logits.sum(2), getattr(obs, "mask", None)) if self.max_action_num is None: # type: ignore # TODO: see same thing in DQNPolicy! Also reduce code duplication. 
self.max_action_num = q.shape[1] diff --git a/tianshou/policy/modelfree/qrdqn.py b/tianshou/policy/modelfree/qrdqn.py index bec8c7cce..2f4d5ba04 100644 --- a/tianshou/policy/modelfree/qrdqn.py +++ b/tianshou/policy/modelfree/qrdqn.py @@ -9,7 +9,7 @@ from tianshou.data import Batch, ReplayBuffer from tianshou.data.types import RolloutBatchProtocol -from tianshou.policy import DQNPolicy +from tianshou.policy import DeepQLearning from tianshou.policy.base import TLearningRateScheduler from tianshou.policy.modelfree.dqn import DQNTrainingStats @@ -22,7 +22,7 @@ class QRDQNTrainingStats(DQNTrainingStats): TQRDQNTrainingStats = TypeVar("TQRDQNTrainingStats", bound=QRDQNTrainingStats) -class QRDQNPolicy(DQNPolicy[TQRDQNTrainingStats], Generic[TQRDQNTrainingStats]): +class QRDQNPolicy(DeepQLearning[TQRDQNTrainingStats], Generic[TQRDQNTrainingStats]): """Implementation of Quantile Regression Deep Q-Network. arXiv:1710.10044. :param model: a model following the rules (s -> action_values_BA) diff --git a/tianshou/trainer/base.py b/tianshou/trainer/base.py index 7ccc41b6d..d21ac7407 100644 --- a/tianshou/trainer/base.py +++ b/tianshou/trainer/base.py @@ -38,7 +38,7 @@ log = logging.getLogger(__name__) -@dataclass +@dataclass(kw_only=True) class TrainingConfig(ToStringMixin): max_epoch: int = 100 """ @@ -125,7 +125,7 @@ def __post_init__(self): raise ValueError("Exactly one of {step_per_collect, episode_per_collect} must be set") -@dataclass +@dataclass(kw_only=True) class OnPolicyTrainingConfig(TrainingConfig): batch_size: int | None = 64 """ @@ -142,7 +142,7 @@ class OnPolicyTrainingConfig(TrainingConfig): """ -@dataclass +@dataclass(kw_only=True) class OffPolicyTrainingConfig(TrainingConfig): batch_size: int = 64 """ @@ -157,7 +157,7 @@ class OffPolicyTrainingConfig(TrainingConfig): """ -@dataclass +@dataclass(kw_only=True) class OfflineTrainingConfig(OffPolicyTrainingConfig): pass From 24a288bb88574bed1dc02a4eb30cc4994c7d004a Mon Sep 17 00:00:00 2001 From: Dominik Jain 
Date: Tue, 4 Mar 2025 20:37:45 +0100 Subject: [PATCH 005/230] v2: Adapt DDPG and test_ddpg --- examples/mujoco/fetch_her_ddpg.py | 6 +- examples/mujoco/mujoco_ddpg.py | 6 +- test/continuous/test_ddpg.py | 49 +++++----- tianshou/highlevel/agent.py | 6 +- tianshou/policy/__init__.py | 4 +- tianshou/policy/base.py | 2 +- tianshou/policy/modelfree/ddpg.py | 152 +++++++++++++++++------------- tianshou/policy/modelfree/redq.py | 10 +- tianshou/policy/modelfree/sac.py | 10 +- tianshou/policy/modelfree/td3.py | 10 +- 10 files changed, 139 insertions(+), 116 deletions(-) diff --git a/examples/mujoco/fetch_her_ddpg.py b/examples/mujoco/fetch_her_ddpg.py index b5f9f1319..888686300 100644 --- a/examples/mujoco/fetch_her_ddpg.py +++ b/examples/mujoco/fetch_her_ddpg.py @@ -22,7 +22,7 @@ from tianshou.highlevel.logger import LoggerFactoryDefault from tianshou.env import ShmemVectorEnv, TruncatedAsTerminated from tianshou.exploration import GaussianNoise -from tianshou.policy import DDPGPolicy +from tianshou.policy import DDPG from tianshou.policy.base import Algorithm from tianshou.trainer import OffpolicyTrainer from tianshou.utils.net.common import Net, get_dict_state_decorator @@ -169,9 +169,9 @@ def test_ddpg(args: argparse.Namespace = get_args()) -> None: ) critic = dict_state_dec(Critic)(net_c, device=args.device).to(args.device) critic_optim = torch.optim.Adam(critic.parameters(), lr=args.critic_lr) - policy: DDPGPolicy = DDPGPolicy( + policy: DDPG = DDPG( actor=actor, - actor_optim=actor_optim, + policy_optim=actor_optim, critic=critic, critic_optim=critic_optim, tau=args.tau, diff --git a/examples/mujoco/mujoco_ddpg.py b/examples/mujoco/mujoco_ddpg.py index 62a4b3d62..d63390be4 100755 --- a/examples/mujoco/mujoco_ddpg.py +++ b/examples/mujoco/mujoco_ddpg.py @@ -12,7 +12,7 @@ from tianshou.data import Collector, CollectStats, ReplayBuffer, VectorReplayBuffer from tianshou.exploration import GaussianNoise from tianshou.highlevel.logger import LoggerFactoryDefault -from 
tianshou.policy import DDPGPolicy +from tianshou.policy import DDPG from tianshou.policy.base import Algorithm from tianshou.trainer import OffpolicyTrainer from tianshou.utils.net.common import Net @@ -97,9 +97,9 @@ def test_ddpg(args: argparse.Namespace = get_args()) -> None: ) critic = Critic(net_c, device=args.device).to(args.device) critic_optim = torch.optim.Adam(critic.parameters(), lr=args.critic_lr) - policy: DDPGPolicy = DDPGPolicy( + policy: DDPG = DDPG( actor=actor, - actor_optim=actor_optim, + policy_optim=actor_optim, critic=critic, critic_optim=critic_optim, tau=args.tau, diff --git a/test/continuous/test_ddpg.py b/test/continuous/test_ddpg.py index b7b877dee..552287a4a 100644 --- a/test/continuous/test_ddpg.py +++ b/test/continuous/test_ddpg.py @@ -9,9 +9,10 @@ from tianshou.data import Collector, CollectStats, VectorReplayBuffer from tianshou.env import DummyVectorEnv from tianshou.exploration import GaussianNoise -from tianshou.policy import DDPGPolicy +from tianshou.policy import DDPG from tianshou.policy.base import Algorithm -from tianshou.trainer import OffpolicyTrainer +from tianshou.policy.modelfree.ddpg import DDPGPolicy +from tianshou.trainer.base import OffPolicyTrainingConfig from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Net from tianshou.utils.net.continuous import Actor, Critic @@ -76,7 +77,6 @@ def test_ddpg(args: argparse.Namespace = get_args()) -> None: actor = Actor(net, args.action_shape, max_action=args.max_action, device=args.device).to( args.device, ) - actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr) net = Net( state_shape=args.state_shape, action_shape=args.action_shape, @@ -86,25 +86,29 @@ def test_ddpg(args: argparse.Namespace = get_args()) -> None: ) critic = Critic(net, device=args.device).to(args.device) critic_optim = torch.optim.Adam(critic.parameters(), lr=args.critic_lr) - policy: DDPGPolicy = DDPGPolicy( + policy = DDPGPolicy( actor=actor, - 
actor_optim=actor_optim, + action_space=env.action_space, + ) + policy_optim = torch.optim.Adam(policy.parameters(), lr=args.actor_lr) + algorithm: DDPG = DDPG( + policy=policy, + policy_optim=policy_optim, critic=critic, critic_optim=critic_optim, tau=args.tau, gamma=args.gamma, exploration_noise=GaussianNoise(sigma=args.exploration_noise), estimation_step=args.n_step, - action_space=env.action_space, ) # collector train_collector = Collector[CollectStats]( - policy, + algorithm, train_envs, VectorReplayBuffer(args.buffer_size, len(train_envs)), exploration_noise=True, ) - test_collector = Collector[CollectStats](policy, test_envs) + test_collector = Collector[CollectStats](algorithm, test_envs) # log log_path = os.path.join(args.logdir, args.task, "ddpg") writer = SummaryWriter(log_path) @@ -117,18 +121,19 @@ def stop_fn(mean_rewards: float) -> bool: return mean_rewards >= args.reward_threshold # trainer - result = OffpolicyTrainer( - policy=policy, - train_collector=train_collector, - test_collector=test_collector, - max_epoch=args.epoch, - step_per_epoch=args.step_per_epoch, - step_per_collect=args.step_per_collect, - episode_per_test=args.test_num, - batch_size=args.batch_size, - update_per_step=args.update_per_step, - stop_fn=stop_fn, - save_best_fn=save_best_fn, - logger=logger, - ).run() + result = algorithm.run_training( + OffPolicyTrainingConfig( + train_collector=train_collector, + test_collector=test_collector, + max_epoch=args.epoch, + step_per_epoch=args.step_per_epoch, + step_per_collect=args.step_per_collect, + episode_per_test=args.test_num, + batch_size=args.batch_size, + update_per_step=args.update_per_step, + stop_fn=stop_fn, + save_best_fn=save_best_fn, + logger=logger, + ) + ) assert stop_fn(result.best_reward) diff --git a/tianshou/highlevel/agent.py b/tianshou/highlevel/agent.py index c63f0395e..1b0a92332 100644 --- a/tianshou/highlevel/agent.py +++ b/tianshou/highlevel/agent.py @@ -45,9 +45,9 @@ from tianshou.highlevel.trainer import 
TrainerCallbacks, TrainingContext from tianshou.highlevel.world import World from tianshou.policy import ( + DDPG, A2CPolicy, Algorithm, - DDPGPolicy, DeepQLearning, DiscreteSACPolicy, IQNPolicy, @@ -467,9 +467,9 @@ def _create_policy(self, envs: Environments, device: TDevice) -> Algorithm: critic1=critic, ), ) - return DDPGPolicy( + return DDPG( actor=actor.module, - actor_optim=actor.optim, + policy_optim=actor.optim, critic=critic.module, critic_optim=critic.optim, action_space=envs.get_action_space(), diff --git a/tianshou/policy/__init__.py b/tianshou/policy/__init__.py index 31c4b1d68..72d47260a 100644 --- a/tianshou/policy/__init__.py +++ b/tianshou/policy/__init__.py @@ -4,6 +4,7 @@ from tianshou.policy.base import Algorithm, TrainingStats from tianshou.policy.modelfree.pg import Reinforce from tianshou.policy.modelfree.dqn import DeepQLearning +from tianshou.policy.modelfree.ddpg import DDPG """ from tianshou.policy.random import MARLRandomPolicy @@ -15,7 +16,6 @@ from tianshou.policy.modelfree.fqf import FQFPolicy from tianshou.policy.modelfree.a2c import A2CPolicy from tianshou.policy.modelfree.npg import NPGPolicy -from tianshou.policy.modelfree.ddpg import DDPGPolicy from tianshou.policy.modelfree.ppo import PPOPolicy from tianshou.policy.modelfree.trpo import TRPOPolicy from tianshou.policy.modelfree.td3 import TD3Policy @@ -48,7 +48,7 @@ "Reinforce", # "A2CPolicy", # "NPGPolicy", - # "DDPGPolicy", + "DDPG", # "PPOPolicy", # "TRPOPolicy", # "TD3Policy", diff --git a/tianshou/policy/base.py b/tianshou/policy/base.py index e5de885e1..d7747cbd5 100644 --- a/tianshou/policy/base.py +++ b/tianshou/policy/base.py @@ -176,7 +176,6 @@ def __init__( raise ValueError(f"Unsupported action space: {action_space}.") self._action_type = cast(Literal["discrete", "continuous"], action_type) self.agent_id = 0 - self.updating = False self.action_scaling = action_scaling self.action_bound_method = action_bound_method self.is_within_training_step = False @@ -421,6 +420,7 
@@ def __init__( super().__init__() self.policy: TPolicy = policy self.lr_scheduler = lr_scheduler + self.updating = False # TODO delete this def __setstate__(self, state: dict[str, Any]) -> None: diff --git a/tianshou/policy/modelfree/ddpg.py b/tianshou/policy/modelfree/ddpg.py index 5d66291b7..ef9c0231f 100644 --- a/tianshou/policy/modelfree/ddpg.py +++ b/tianshou/policy/modelfree/ddpg.py @@ -6,6 +6,7 @@ import gymnasium as gym import numpy as np import torch +from sensai.util.helper import mark_used from tianshou.data import Batch, ReplayBuffer from tianshou.data.batch import BatchProtocol @@ -17,10 +18,16 @@ RolloutBatchProtocol, ) from tianshou.exploration import BaseNoise, GaussianNoise -from tianshou.policy import Algorithm -from tianshou.policy.base import TLearningRateScheduler, TrainingStats +from tianshou.policy.base import ( + OffPolicyAlgorithm, + Policy, + TLearningRateScheduler, + TrainingStats, +) from tianshou.utils.net.continuous import Actor, Critic +mark_used(ActBatchProtocol) + @dataclass(kw_only=True) class DDPGTrainingStats(TrainingStats): @@ -31,24 +38,78 @@ class DDPGTrainingStats(TrainingStats): TDDPGTrainingStats = TypeVar("TDDPGTrainingStats", bound=DDPGTrainingStats) -class DDPGPolicy(Algorithm[TDDPGTrainingStats], Generic[TDDPGTrainingStats]): +class DDPGPolicy(Policy): + def __init__( + self, + *, + actor: torch.nn.Module | Actor, + action_space: gym.Space, + observation_space: gym.Space | None = None, + action_scaling: bool = True, + action_bound_method: Literal["clip"] | None = "clip", + ): + """ + :param actor: The actor network following the rules (s -> actions) + :param action_space: Env's action space. + :param tau: Param for soft update of the target network. + :param observation_space: Env's observation space. + :param action_scaling: if True, scale the action from [-1, 1] to the range + of action_space. Only used if the action_space is continuous. + :param action_bound_method: method to bound action to range [-1, 1]. 
+ """ + if action_scaling and not np.isclose(actor.max_action, 1.0): + warnings.warn( + "action_scaling and action_bound_method are only intended to deal" + "with unbounded model action space, but find actor model bound" + f"action space with max_action={actor.max_action}." + "Consider using unbounded=True option of the actor model," + "or set action_scaling to False and action_bound_method to None.", + ) + super().__init__( + action_space=action_space, + observation_space=observation_space, + action_scaling=action_scaling, + action_bound_method=action_bound_method, + ) + self.actor = actor + + def forward( + self, + batch: ObsBatchProtocol, + state: dict | BatchProtocol | np.ndarray | None = None, + model: torch.nn.Module | None = None, + **kwargs: Any, + ) -> ActStateBatchProtocol: + """Compute action over the given batch data. + + :return: A :class:`~tianshou.data.Batch` which has 2 keys: + + * ``act`` the action. + * ``state`` the hidden state. + + .. seealso:: + + Please refer to :meth:`~tianshou.policy.BasePolicy.forward` for + more detailed explanation. + """ + if model is None: + model = self.actor + actions, hidden = model(batch.obs, state=state, info=batch.info) + return cast(ActStateBatchProtocol, Batch(act=actions, state=hidden)) + + +class DDPG(OffPolicyAlgorithm[DDPGPolicy, TDDPGTrainingStats], Generic[TDDPGTrainingStats]): """Implementation of Deep Deterministic Policy Gradient. arXiv:1509.02971. - :param actor: The actor network following the rules (s -> actions) - :param actor_optim: The optimizer for actor network. + :param policy: the policy + :param policy_optim: The optimizer for actor network. :param critic: The critic network. (s, a -> Q(s, a)) :param critic_optim: The optimizer for critic network. - :param action_space: Env's action space. :param tau: Param for soft update of the target network. :param gamma: Discount factor, in [0, 1]. :param exploration_noise: The exploration noise, added to the action. 
Defaults to ``GaussianNoise(sigma=0.1)``. :param estimation_step: The number of steps to look ahead. - :param observation_space: Env's observation space. - :param action_scaling: if True, scale the action from [-1, 1] to the range - of action_space. Only used if the action_space is continuous. - :param action_bound_method: method to bound action to range [-1, 1]. - Only used if the action_space is continuous. :param lr_scheduler: if not None, will be called in `policy.update()`. .. seealso:: @@ -60,47 +121,25 @@ class DDPGPolicy(Algorithm[TDDPGTrainingStats], Generic[TDDPGTrainingStats]): def __init__( self, *, - actor: torch.nn.Module | Actor, - actor_optim: torch.optim.Optimizer, + policy: DDPGPolicy, + policy_optim: torch.optim.Optimizer, critic: torch.nn.Module | Critic, critic_optim: torch.optim.Optimizer, - action_space: gym.Space, tau: float = 0.005, gamma: float = 0.99, exploration_noise: BaseNoise | Literal["default"] | None = "default", estimation_step: int = 1, - observation_space: gym.Space | None = None, - action_scaling: bool = True, - # tanh not supported, see assert below - action_bound_method: Literal["clip"] | None = "clip", lr_scheduler: TLearningRateScheduler | None = None, ) -> None: assert 0.0 <= tau <= 1.0, f"tau should be in [0, 1] but got: {tau}" assert 0.0 <= gamma <= 1.0, f"gamma should be in [0, 1] but got: {gamma}" - assert action_bound_method != "tanh", ( # type: ignore[comparison-overlap] - "tanh mapping is not supported" - "in policies where action is used as input of critic , because" - "raw action in range (-inf, inf) will cause instability in training" - ) super().__init__( - action_space=action_space, - observation_space=observation_space, - action_scaling=action_scaling, - action_bound_method=action_bound_method, + policy=policy, lr_scheduler=lr_scheduler, ) - if action_scaling and not np.isclose(actor.max_action, 1.0): - warnings.warn( - "action_scaling and action_bound_method are only intended to deal" - "with unbounded model 
action space, but find actor model bound" - f"action space with max_action={actor.max_action}." - "Consider using unbounded=True option of the actor model," - "or set action_scaling to False and action_bound_method to None.", - ) - self.actor = actor - self.actor_old = deepcopy(actor) + self.actor_old = deepcopy(policy.actor) self.actor_old.eval() - self.actor_optim = actor_optim + self.policy_optim = policy_optim self.critic = critic self.critic_old = deepcopy(critic) self.critic_old.eval() @@ -124,13 +163,13 @@ def set_exp_noise(self, noise: BaseNoise | None) -> None: def train(self, mode: bool = True) -> Self: """Set the module in training mode, except for the target network.""" self.training = mode - self.actor.train(mode) + self.policy.actor.train(mode) self.critic.train(mode) return self def sync_weight(self) -> None: """Soft-update the weight for the target network.""" - self.soft_update(self.actor_old, self.actor, self.tau) + self.soft_update(self.actor_old, self.policy.actor, self.tau) self.soft_update(self.critic_old, self.critic, self.tau) def _target_q(self, buffer: ReplayBuffer, indices: np.ndarray) -> torch.Tensor: @@ -138,7 +177,9 @@ def _target_q(self, buffer: ReplayBuffer, indices: np.ndarray) -> torch.Tensor: obs=buffer[indices].obs_next, info=[None] * len(indices), ) # obs_next: s_{t+n} - return self.critic_old(obs_next_batch.obs, self(obs_next_batch, model="actor_old").act) + return self.critic_old( + obs_next_batch.obs, self.policy(obs_next_batch, model=self.actor_old).act + ) def process_fn( self, @@ -155,29 +196,6 @@ def process_fn( n_step=self.estimation_step, ) - def forward( - self, - batch: ObsBatchProtocol, - state: dict | BatchProtocol | np.ndarray | None = None, - model: Literal["actor", "actor_old"] = "actor", - **kwargs: Any, - ) -> ActStateBatchProtocol: - """Compute action over the given batch data. - - :return: A :class:`~tianshou.data.Batch` which has 2 keys: - - * ``act`` the action. - * ``state`` the hidden state. - - .. 
seealso:: - - Please refer to :meth:`~tianshou.policy.BasePolicy.forward` for - more detailed explanation. - """ - model = getattr(self, model) - actions, hidden = model(batch.obs, state=state, info=batch.info) - return cast(ActStateBatchProtocol, Batch(act=actions, state=hidden)) - @staticmethod def _mse_optimizer( batch: RolloutBatchProtocol, @@ -201,10 +219,10 @@ def _update_with_batch(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: td, critic_loss = self._mse_optimizer(batch, self.critic, self.critic_optim) batch.weight = td # prio-buffer # actor - actor_loss = -self.critic(batch.obs, self(batch).act).mean() - self.actor_optim.zero_grad() + actor_loss = -self.critic(batch.obs, self.policy(batch).act).mean() + self.policy_optim.zero_grad() actor_loss.backward() - self.actor_optim.step() + self.policy_optim.step() self.sync_weight() return DDPGTrainingStats(actor_loss=actor_loss.item(), critic_loss=critic_loss.item()) # type: ignore[return-value] diff --git a/tianshou/policy/modelfree/redq.py b/tianshou/policy/modelfree/redq.py index dcfa1c39f..317e4f42d 100644 --- a/tianshou/policy/modelfree/redq.py +++ b/tianshou/policy/modelfree/redq.py @@ -9,7 +9,7 @@ from tianshou.data import Batch, ReplayBuffer from tianshou.data.types import ObsBatchProtocol, RolloutBatchProtocol from tianshou.exploration import BaseNoise -from tianshou.policy import DDPGPolicy +from tianshou.policy import DDPG from tianshou.policy.base import TLearningRateScheduler from tianshou.policy.modelfree.ddpg import DDPGTrainingStats from tianshou.utils.net.continuous import ActorProb @@ -26,7 +26,7 @@ class REDQTrainingStats(DDPGTrainingStats): TREDQTrainingStats = TypeVar("TREDQTrainingStats", bound=REDQTrainingStats) -class REDQPolicy(DDPGPolicy[TREDQTrainingStats]): +class REDQPolicy(DDPG[TREDQTrainingStats]): """Implementation of REDQ. arXiv:2101.05982. 
:param actor: The actor network following the rules in @@ -91,7 +91,7 @@ def __init__( ) super().__init__( actor=actor, - actor_optim=actor_optim, + policy_optim=actor_optim, critic=critic, critic_optim=critic_optim, action_space=action_space, @@ -212,9 +212,9 @@ def _update_with_batch(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: a = obs_result.act current_qa = self.critic(batch.obs, a).mean(dim=0).flatten() actor_loss = (self.alpha * obs_result.log_prob.flatten() - current_qa).mean() - self.actor_optim.zero_grad() + self.policy_optim.zero_grad() actor_loss.backward() - self.actor_optim.step() + self.policy_optim.step() if self.is_auto_alpha: log_prob = obs_result.log_prob.detach() + self._target_entropy diff --git a/tianshou/policy/modelfree/sac.py b/tianshou/policy/modelfree/sac.py index 8ff349ea9..d5678b047 100644 --- a/tianshou/policy/modelfree/sac.py +++ b/tianshou/policy/modelfree/sac.py @@ -14,7 +14,7 @@ RolloutBatchProtocol, ) from tianshou.exploration import BaseNoise -from tianshou.policy import DDPGPolicy +from tianshou.policy import DDPG from tianshou.policy.base import TLearningRateScheduler, TrainingStats from tianshou.utils.conversion import to_optional_float from tianshou.utils.net.continuous import ActorProb @@ -51,7 +51,7 @@ class SACTrainingStats(TrainingStats): # TODO: the type ignore here is needed b/c the hierarchy is actually broken! Should reconsider the inheritance structure. -class SACPolicy(DDPGPolicy[TSACTrainingStats], Generic[TSACTrainingStats]): # type: ignore[type-var] +class SACPolicy(DDPG[TSACTrainingStats], Generic[TSACTrainingStats]): # type: ignore[type-var] """Implementation of Soft Actor-Critic. arXiv:1812.05905. 
:param actor: the actor network following the rules (s -> dist_input_BD) @@ -115,7 +115,7 @@ def __init__( ) -> None: super().__init__( actor=actor, - actor_optim=actor_optim, + policy_optim=actor_optim, critic=critic, critic_optim=critic_optim, action_space=action_space, @@ -237,9 +237,9 @@ def _update_with_batch(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: actor_loss = ( self.alpha * obs_result.log_prob.flatten() - torch.min(current_q1a, current_q2a) ).mean() - self.actor_optim.zero_grad() + self.policy_optim.zero_grad() actor_loss.backward() - self.actor_optim.step() + self.policy_optim.step() alpha_loss = None if self.is_auto_alpha: diff --git a/tianshou/policy/modelfree/td3.py b/tianshou/policy/modelfree/td3.py index e2560f9be..66673a805 100644 --- a/tianshou/policy/modelfree/td3.py +++ b/tianshou/policy/modelfree/td3.py @@ -9,7 +9,7 @@ from tianshou.data import Batch, ReplayBuffer from tianshou.data.types import RolloutBatchProtocol from tianshou.exploration import BaseNoise -from tianshou.policy import DDPGPolicy +from tianshou.policy import DDPG from tianshou.policy.base import TLearningRateScheduler, TrainingStats from tianshou.utils.optim import clone_optimizer @@ -25,7 +25,7 @@ class TD3TrainingStats(TrainingStats): # TODO: the type ignore here is needed b/c the hierarchy is actually broken! Should reconsider the inheritance structure. -class TD3Policy(DDPGPolicy[TTD3TrainingStats], Generic[TTD3TrainingStats]): # type: ignore[type-var] +class TD3Policy(DDPG[TTD3TrainingStats], Generic[TTD3TrainingStats]): # type: ignore[type-var] """Implementation of TD3, arXiv:1802.09477. :param actor: the actor network following the rules in @@ -86,7 +86,7 @@ def __init__( # Some intermediate class, like TwoCriticPolicy? 
super().__init__( actor=actor, - actor_optim=actor_optim, + policy_optim=actor_optim, critic=critic, critic_optim=critic_optim, action_space=action_space, @@ -149,10 +149,10 @@ def _update_with_batch(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: # actor if self._cnt % self.update_actor_freq == 0: actor_loss = -self.critic(batch.obs, self(batch, eps=0.0).act).mean() - self.actor_optim.zero_grad() + self.policy_optim.zero_grad() actor_loss.backward() self._last = actor_loss.item() - self.actor_optim.step() + self.policy_optim.step() self.sync_weight() self._cnt += 1 From 8329bdbb0bafd28eecc609119530f76066add7a8 Mon Sep 17 00:00:00 2001 From: Dominik Jain Date: Tue, 4 Mar 2025 21:48:33 +0100 Subject: [PATCH 006/230] v2: High-level API adaptations, Reinforce support implemented --- tianshou/highlevel/{agent.py => algorithm.py} | 185 ++++++++++-------- tianshou/highlevel/experiment.py | 107 +++++----- tianshou/policy/__init__.py | 54 +++-- tianshou/policy/base.py | 8 +- tianshou/policy/imitation/base.py | 2 +- tianshou/policy/imitation/bcq.py | 2 +- tianshou/policy/modelbased/icm.py | 5 +- tianshou/policy/modelbased/psrl.py | 2 +- tianshou/policy/random.py | 2 +- 9 files changed, 197 insertions(+), 170 deletions(-) rename tianshou/highlevel/{agent.py => algorithm.py} (76%) diff --git a/tianshou/highlevel/agent.py b/tianshou/highlevel/algorithm.py similarity index 76% rename from tianshou/highlevel/agent.py rename to tianshou/highlevel/algorithm.py index 1b0a92332..c35c6c6d9 100644 --- a/tianshou/highlevel/agent.py +++ b/tianshou/highlevel/algorithm.py @@ -59,8 +59,15 @@ TD3Policy, TRPOPolicy, ) -from tianshou.policy.base import RandomActionPolicy +from tianshou.policy.base import ( + OffPolicyAlgorithm, + OnPolicyAlgorithm, + Policy, + RandomActionPolicy, +) +from tianshou.policy.modelfree.pg import ActorPolicy from tianshou.trainer import BaseTrainer, OffpolicyTrainer, OnpolicyTrainer +from tianshou.trainer.base import OffPolicyTrainingConfig, 
OnPolicyTrainingConfig from tianshou.utils.net.common import ActorCritic CHECKPOINT_DICT_KEY_MODEL = "model" @@ -78,12 +85,13 @@ "TDiscreteCriticOnlyParams", bound=Params | ParamsMixinLearningRateWithScheduler, ) -TPolicy = TypeVar("TPolicy", bound=Algorithm) +TAlgorithm = TypeVar("TAlgorithm", bound=Algorithm) +TPolicy = TypeVar("TPolicy", bound=Policy) log = logging.getLogger(__name__) -class AgentFactory(ABC, ToStringMixin): - """Factory for the creation of an agent's policy, its trainer as well as collectors.""" +class AlgorithmFactory(ABC, ToStringMixin): + """Factory for the creation of an :class:`Algorithm` instance, its policy, trainer as well as collectors.""" def __init__(self, sampling_config: SamplingConfig, optim_factory: OptimizerFactory): self.sampling_config = sampling_config @@ -142,12 +150,19 @@ def set_policy_wrapper_factory( def set_trainer_callbacks(self, callbacks: TrainerCallbacks) -> None: self.trainer_callbacks = callbacks + @staticmethod + def _create_policy( + constructor: type[TPolicy], params_dict: dict, policy_params: list[str], **kwargs + ) -> TPolicy: + params = {p: params_dict.pop(p) for p in policy_params} + return constructor(**params, **kwargs) + @abstractmethod - def _create_policy(self, envs: Environments, device: TDevice) -> Algorithm: + def _create_algorithm(self, envs: Environments, device: TDevice) -> Algorithm: pass - def create_policy(self, envs: Environments, device: TDevice) -> Algorithm: - policy = self._create_policy(envs, device) + def create_algorithm(self, envs: Environments, device: TDevice) -> Algorithm: + policy = self._create_algorithm(envs, device) if self.policy_wrapper_factory is not None: policy = self.policy_wrapper_factory.create_wrapped_policy( policy, @@ -162,7 +177,7 @@ def create_trainer(self, world: World, policy_persistence: PolicyPersistence) -> pass -class OnPolicyAgentFactory(AgentFactory, ABC): +class OnPolicyAlgorithmFactory(AlgorithmFactory, ABC): def create_trainer( self, world: World, @@ 
-186,28 +201,30 @@ def create_trainer( if callbacks.epoch_stop_callback else None ) - return OnpolicyTrainer( - policy=world.policy, - train_collector=world.train_collector, - test_collector=world.test_collector, - max_epoch=sampling_config.num_epochs, - step_per_epoch=sampling_config.step_per_epoch, - repeat_per_collect=sampling_config.repeat_per_collect, - episode_per_test=sampling_config.num_test_episodes, - batch_size=sampling_config.batch_size, - step_per_collect=sampling_config.step_per_collect, - save_best_fn=policy_persistence.get_save_best_fn(world), - save_checkpoint_fn=policy_persistence.get_save_checkpoint_fn(world), - logger=world.logger, - test_in_train=False, - train_fn=train_fn, - test_fn=test_fn, - stop_fn=stop_fn, - verbose=False, + algorithm = cast(OnPolicyAlgorithm, world.policy) + return algorithm.create_trainer( + OnPolicyTrainingConfig( + train_collector=world.train_collector, + test_collector=world.test_collector, + max_epoch=sampling_config.num_epochs, + step_per_epoch=sampling_config.step_per_epoch, + repeat_per_collect=sampling_config.repeat_per_collect, + episode_per_test=sampling_config.num_test_episodes, + batch_size=sampling_config.batch_size, + step_per_collect=sampling_config.step_per_collect, + save_best_fn=policy_persistence.get_save_best_fn(world), + save_checkpoint_fn=policy_persistence.get_save_checkpoint_fn(world), + logger=world.logger, + test_in_train=False, + train_fn=train_fn, + test_fn=test_fn, + stop_fn=stop_fn, + verbose=False, + ) ) -class OffPolicyAgentFactory(AgentFactory, ABC): +class OffPolicyAlgorithmFactory(AlgorithmFactory, ABC): def create_trainer( self, world: World, @@ -231,32 +248,34 @@ def create_trainer( if callbacks.epoch_stop_callback else None ) - return OffpolicyTrainer( - policy=world.policy, - train_collector=world.train_collector, - test_collector=world.test_collector, - max_epoch=sampling_config.num_epochs, - step_per_epoch=sampling_config.step_per_epoch, - 
step_per_collect=sampling_config.step_per_collect, - episode_per_test=sampling_config.num_test_episodes, - batch_size=sampling_config.batch_size, - save_best_fn=policy_persistence.get_save_best_fn(world), - logger=world.logger, - update_per_step=sampling_config.update_per_step, - test_in_train=False, - train_fn=train_fn, - test_fn=test_fn, - stop_fn=stop_fn, - verbose=False, + algorithm = cast(OffPolicyAlgorithm, world.policy) + return algorithm.create_trainer( + OffPolicyTrainingConfig( + train_collector=world.train_collector, + test_collector=world.test_collector, + max_epoch=sampling_config.num_epochs, + step_per_epoch=sampling_config.step_per_epoch, + step_per_collect=sampling_config.step_per_collect, + episode_per_test=sampling_config.num_test_episodes, + batch_size=sampling_config.batch_size, + save_best_fn=policy_persistence.get_save_best_fn(world), + logger=world.logger, + update_per_step=sampling_config.update_per_step, + test_in_train=False, + train_fn=train_fn, + test_fn=test_fn, + stop_fn=stop_fn, + verbose=False, + ) ) -class RandomActionAgentFactory(OnPolicyAgentFactory): - def _create_policy(self, envs: Environments, device: TDevice) -> RandomActionPolicy: +class RandomActionAlgorithmFactory(OnPolicyAlgorithmFactory): + def _create_algorithm(self, envs: Environments, device: TDevice) -> RandomActionPolicy: return RandomActionPolicy(envs.get_action_space()) -class PGAgentFactory(OnPolicyAgentFactory): +class PGAlgorithmFactory(OnPolicyAlgorithmFactory): def __init__( self, params: PGParams, @@ -269,7 +288,7 @@ def __init__( self.actor_factory = actor_factory self.optim_factory = optim_factory - def _create_policy(self, envs: Environments, device: TDevice) -> Reinforce: + def _create_algorithm(self, envs: Environments, device: TDevice) -> Reinforce: actor = self.actor_factory.create_module_opt( envs, device, @@ -286,19 +305,25 @@ def _create_policy(self, envs: Environments, device: TDevice) -> Reinforce: ) dist_fn = 
self.actor_factory.create_dist_fn(envs) assert dist_fn is not None - return Reinforce( + policy = self._create_policy( + ActorPolicy, + kwargs, + ["action_scaling", "action_bound_method", "deterministic_eval"], actor=actor.module, - optim=actor.optim, + dist_fn=dist_fn, action_space=envs.get_action_space(), observation_space=envs.get_observation_space(), - dist_fn=dist_fn, + ) + return Reinforce( + policy=policy, + optim=actor.optim, **kwargs, ) -class ActorCriticAgentFactory( - Generic[TActorCriticParams, TPolicy], - OnPolicyAgentFactory, +class ActorCriticAlgorithmFactory( + Generic[TActorCriticParams, TAlgorithm], + OnPolicyAlgorithmFactory, ABC, ): def __init__( @@ -317,7 +342,7 @@ def __init__( self.critic_use_action = False @abstractmethod - def _get_policy_class(self) -> type[TPolicy]: + def _get_policy_class(self) -> type[TAlgorithm]: pass def create_actor_critic_module_opt( @@ -350,34 +375,34 @@ def _create_kwargs(self, envs: Environments, device: TDevice) -> dict[str, Any]: kwargs["dist_fn"] = self.actor_factory.create_dist_fn(envs) return kwargs - def _create_policy(self, envs: Environments, device: TDevice) -> TPolicy: + def _create_algorithm(self, envs: Environments, device: TDevice) -> TAlgorithm: policy_class = self._get_policy_class() return policy_class(**self._create_kwargs(envs, device)) -class A2CAgentFactory(ActorCriticAgentFactory[A2CParams, A2CPolicy]): +class A2CAlgorithmFactory(ActorCriticAlgorithmFactory[A2CParams, A2CPolicy]): def _get_policy_class(self) -> type[A2CPolicy]: return A2CPolicy -class PPOAgentFactory(ActorCriticAgentFactory[PPOParams, PPOPolicy]): +class PPOAlgorithmFactory(ActorCriticAlgorithmFactory[PPOParams, PPOPolicy]): def _get_policy_class(self) -> type[PPOPolicy]: return PPOPolicy -class NPGAgentFactory(ActorCriticAgentFactory[NPGParams, NPGPolicy]): +class NPGAlgorithmFactory(ActorCriticAlgorithmFactory[NPGParams, NPGPolicy]): def _get_policy_class(self) -> type[NPGPolicy]: return NPGPolicy -class 
TRPOAgentFactory(ActorCriticAgentFactory[TRPOParams, TRPOPolicy]): +class TRPOAlgorithmFactory(ActorCriticAlgorithmFactory[TRPOParams, TRPOPolicy]): def _get_policy_class(self) -> type[TRPOPolicy]: return TRPOPolicy -class DiscreteCriticOnlyAgentFactory( - OffPolicyAgentFactory, - Generic[TDiscreteCriticOnlyParams, TPolicy], +class DiscreteCriticOnlyAlgorithmFactory( + OffPolicyAlgorithmFactory, + Generic[TDiscreteCriticOnlyParams, TAlgorithm], ): def __init__( self, @@ -392,11 +417,11 @@ def __init__( self.optim_factory = optim_factory @abstractmethod - def _get_policy_class(self) -> type[TPolicy]: + def _get_policy_class(self) -> type[TAlgorithm]: pass @typing.no_type_check - def _create_policy(self, envs: Environments, device: TDevice) -> TPolicy: + def _create_algorithm(self, envs: Environments, device: TDevice) -> TAlgorithm: model = self.model_factory.create_module(envs, device) optim = self.optim_factory.create_optimizer(model, self.params.lr) kwargs = self.params.create_kwargs( @@ -419,17 +444,17 @@ def _create_policy(self, envs: Environments, device: TDevice) -> TPolicy: ) -class DQNAgentFactory(DiscreteCriticOnlyAgentFactory[DQNParams, DeepQLearning]): +class DQNAlgorithmFactory(DiscreteCriticOnlyAlgorithmFactory[DQNParams, DeepQLearning]): def _get_policy_class(self) -> type[DeepQLearning]: return DeepQLearning -class IQNAgentFactory(DiscreteCriticOnlyAgentFactory[IQNParams, IQNPolicy]): +class IQNAlgorithmFactory(DiscreteCriticOnlyAlgorithmFactory[IQNParams, IQNPolicy]): def _get_policy_class(self) -> type[IQNPolicy]: return IQNPolicy -class DDPGAgentFactory(OffPolicyAgentFactory): +class DDPGAlgorithmFactory(OffPolicyAlgorithmFactory): def __init__( self, params: DDPGParams, @@ -444,7 +469,7 @@ def __init__( self.params = params self.optim_factory = optim_factory - def _create_policy(self, envs: Environments, device: TDevice) -> Algorithm: + def _create_algorithm(self, envs: Environments, device: TDevice) -> Algorithm: actor = 
self.actor_factory.create_module_opt( envs, device, @@ -478,7 +503,7 @@ def _create_policy(self, envs: Environments, device: TDevice) -> Algorithm: ) -class REDQAgentFactory(OffPolicyAgentFactory): +class REDQAlgorithmFactory(OffPolicyAlgorithmFactory): def __init__( self, params: REDQParams, @@ -493,7 +518,7 @@ def __init__( self.params = params self.optim_factory = optim_factory - def _create_policy(self, envs: Environments, device: TDevice) -> Algorithm: + def _create_algorithm(self, envs: Environments, device: TDevice) -> Algorithm: envs.get_type().assert_continuous(self) actor = self.actor_factory.create_module_opt( envs, @@ -530,9 +555,9 @@ def _create_policy(self, envs: Environments, device: TDevice) -> Algorithm: ) -class ActorDualCriticsAgentFactory( - OffPolicyAgentFactory, - Generic[TActorDualCriticsParams, TPolicy], +class ActorDualCriticsAlgorithmFactory( + OffPolicyAlgorithmFactory, + Generic[TActorDualCriticsParams, TAlgorithm], ABC, ): def __init__( @@ -552,7 +577,7 @@ def __init__( self.optim_factory = optim_factory @abstractmethod - def _get_policy_class(self) -> type[TPolicy]: + def _get_policy_class(self) -> type[TAlgorithm]: pass def _get_discrete_last_size_use_action_shape(self) -> bool: @@ -563,7 +588,7 @@ def _get_critic_use_action(envs: Environments) -> bool: return envs.get_type().is_continuous() @typing.no_type_check - def _create_policy(self, envs: Environments, device: TDevice) -> TPolicy: + def _create_algorithm(self, envs: Environments, device: TDevice) -> TAlgorithm: actor = self.actor_factory.create_module_opt( envs, device, @@ -612,16 +637,18 @@ def _create_policy(self, envs: Environments, device: TDevice) -> TPolicy: ) -class SACAgentFactory(ActorDualCriticsAgentFactory[SACParams, SACPolicy]): +class SACAlgorithmFactory(ActorDualCriticsAlgorithmFactory[SACParams, SACPolicy]): def _get_policy_class(self) -> type[SACPolicy]: return SACPolicy -class DiscreteSACAgentFactory(ActorDualCriticsAgentFactory[DiscreteSACParams, 
DiscreteSACPolicy]): +class DiscreteSACAlgorithmFactory( + ActorDualCriticsAlgorithmFactory[DiscreteSACParams, DiscreteSACPolicy] +): def _get_policy_class(self) -> type[DiscreteSACPolicy]: return DiscreteSACPolicy -class TD3AgentFactory(ActorDualCriticsAgentFactory[TD3Params, TD3Policy]): +class TD3AlgorithmFactory(ActorDualCriticsAlgorithmFactory[TD3Params, TD3Policy]): def _get_policy_class(self) -> type[TD3Policy]: return TD3Policy diff --git a/tianshou/highlevel/experiment.py b/tianshou/highlevel/experiment.py index 9874d26c9..6bf3b563c 100644 --- a/tianshou/highlevel/experiment.py +++ b/tianshou/highlevel/experiment.py @@ -38,21 +38,21 @@ from tianshou.data import BaseCollector, Collector, CollectStats, InfoStats from tianshou.env import BaseVectorEnv -from tianshou.highlevel.agent import ( - A2CAgentFactory, - AgentFactory, - DDPGAgentFactory, - DiscreteSACAgentFactory, - DQNAgentFactory, - IQNAgentFactory, - NPGAgentFactory, - PGAgentFactory, - PPOAgentFactory, - RandomActionAgentFactory, - REDQAgentFactory, - SACAgentFactory, - TD3AgentFactory, - TRPOAgentFactory, +from tianshou.highlevel.algorithm import ( + A2CAlgorithmFactory, + AlgorithmFactory, + DDPGAlgorithmFactory, + DiscreteSACAlgorithmFactory, + DQNAlgorithmFactory, + IQNAlgorithmFactory, + NPGAlgorithmFactory, + PGAlgorithmFactory, + PPOAlgorithmFactory, + RandomActionAlgorithmFactory, + REDQAlgorithmFactory, + SACAlgorithmFactory, + TD3AlgorithmFactory, + TRPOAlgorithmFactory, ) from tianshou.highlevel.config import SamplingConfig from tianshou.highlevel.env import EnvFactory @@ -185,7 +185,7 @@ def __init__( self, config: ExperimentConfig, env_factory: EnvFactory, - agent_factory: AgentFactory, + algorithm_factory: AlgorithmFactory, sampling_config: SamplingConfig, name: str, logger_factory: LoggerFactory | None = None, @@ -195,7 +195,7 @@ def __init__( self.config = config self.sampling_config = sampling_config self.env_factory = env_factory - self.agent_factory = agent_factory + 
self.algorithm_factory = algorithm_factory self.logger_factory = logger_factory self.name = name @@ -318,7 +318,7 @@ def create_experiment_world( full_config["experiment_config"] = asdict(self.config) full_config["sampling_config"] = asdict(self.sampling_config) with suppress(AttributeError): - full_config["policy_params"] = asdict(self.agent_factory.params) + full_config["policy_params"] = asdict(self.algorithm_factory.params) logger: TLogger if use_persistence: @@ -333,13 +333,16 @@ def create_experiment_world( # create policy and collectors log.info("Creating policy") - policy = self.agent_factory.create_policy(envs, self.config.device) + policy = self.algorithm_factory.create_algorithm(envs, self.config.device) log.info("Creating collectors") train_collector: BaseCollector | None = None test_collector: BaseCollector | None = None if self.config.train: - train_collector, test_collector = self.agent_factory.create_train_test_collector( + ( + train_collector, + test_collector, + ) = self.algorithm_factory.create_train_test_collector( policy, envs, reset_collectors=reset_collectors, @@ -365,7 +368,7 @@ def create_experiment_world( ) if self.config.train: - trainer = self.agent_factory.create_trainer(world, policy_persistence) + trainer = self.algorithm_factory.create_trainer(world, policy_persistence) world.trainer = trainer return world @@ -632,7 +635,7 @@ def with_name( return self @abstractmethod - def _create_agent_factory(self) -> AgentFactory: + def _create_algorithm_factory(self) -> AlgorithmFactory: pass def _get_optim_factory(self) -> OptimizerFactory: @@ -647,14 +650,14 @@ def build(self) -> Experiment: :return: the experiment """ - agent_factory = self._create_agent_factory() - agent_factory.set_trainer_callbacks(self._trainer_callbacks) + algorithm_factory = self._create_algorithm_factory() + algorithm_factory.set_trainer_callbacks(self._trainer_callbacks) if self._policy_wrapper_factory: - 
agent_factory.set_policy_wrapper_factory(self._policy_wrapper_factory) + algorithm_factory.set_policy_wrapper_factory(self._policy_wrapper_factory) experiment: Experiment = Experiment( config=self._config, env_factory=self._env_factory, - agent_factory=agent_factory, + algorithm_factory=algorithm_factory, sampling_config=self._sampling_config, name=self._name, logger_factory=self._logger_factory, @@ -686,8 +689,8 @@ def build_seeded_collection(self, num_experiments: int) -> ExperimentCollection: class RandomActionExperimentBuilder(ExperimentBuilder): - def _create_agent_factory(self) -> RandomActionAgentFactory: - return RandomActionAgentFactory( + def _create_algorithm_factory(self) -> RandomActionAlgorithmFactory: + return RandomActionAlgorithmFactory( sampling_config=self.sampling_config, optim_factory=self._get_optim_factory(), ) @@ -1038,8 +1041,8 @@ def with_pg_params(self, params: PGParams) -> Self: self._params = params return self - def _create_agent_factory(self) -> AgentFactory: - return PGAgentFactory( + def _create_algorithm_factory(self) -> AlgorithmFactory: + return PGAlgorithmFactory( self._params, self._sampling_config, self._get_actor_factory(), @@ -1068,8 +1071,8 @@ def with_a2c_params(self, params: A2CParams) -> Self: self._params = params return self - def _create_agent_factory(self) -> AgentFactory: - return A2CAgentFactory( + def _create_algorithm_factory(self) -> AlgorithmFactory: + return A2CAlgorithmFactory( self._params, self._sampling_config, self._get_actor_factory(), @@ -1098,8 +1101,8 @@ def with_ppo_params(self, params: PPOParams) -> Self: self._params = params return self - def _create_agent_factory(self) -> AgentFactory: - return PPOAgentFactory( + def _create_algorithm_factory(self) -> AlgorithmFactory: + return PPOAlgorithmFactory( self._params, self._sampling_config, self._get_actor_factory(), @@ -1128,8 +1131,8 @@ def with_npg_params(self, params: NPGParams) -> Self: self._params = params return self - def 
_create_agent_factory(self) -> AgentFactory: - return NPGAgentFactory( + def _create_algorithm_factory(self) -> AlgorithmFactory: + return NPGAlgorithmFactory( self._params, self._sampling_config, self._get_actor_factory(), @@ -1158,8 +1161,8 @@ def with_trpo_params(self, params: TRPOParams) -> Self: self._params = params return self - def _create_agent_factory(self) -> AgentFactory: - return TRPOAgentFactory( + def _create_algorithm_factory(self) -> AlgorithmFactory: + return TRPOAlgorithmFactory( self._params, self._sampling_config, self._get_actor_factory(), @@ -1216,8 +1219,8 @@ def with_model_factory_default( ) return self - def _create_agent_factory(self) -> AgentFactory: - return DQNAgentFactory( + def _create_algorithm_factory(self) -> AlgorithmFactory: + return DQNAlgorithmFactory( self._params, self._sampling_config, self._model_factory, @@ -1248,13 +1251,13 @@ def with_preprocess_network_factory(self, module_factory: IntermediateModuleFact self._preprocess_network_factory = module_factory return self - def _create_agent_factory(self) -> AgentFactory: + def _create_algorithm_factory(self) -> AlgorithmFactory: model_factory = ImplicitQuantileNetworkFactory( self._preprocess_network_factory, hidden_sizes=self._params.hidden_sizes, num_cosines=self._params.num_cosines, ) - return IQNAgentFactory( + return IQNAlgorithmFactory( self._params, self._sampling_config, model_factory, @@ -1282,8 +1285,8 @@ def with_ddpg_params(self, params: DDPGParams) -> Self: self._params = params return self - def _create_agent_factory(self) -> AgentFactory: - return DDPGAgentFactory( + def _create_algorithm_factory(self) -> AlgorithmFactory: + return DDPGAlgorithmFactory( self._params, self._sampling_config, self._get_actor_factory(), @@ -1312,8 +1315,8 @@ def with_redq_params(self, params: REDQParams) -> Self: self._params = params return self - def _create_agent_factory(self) -> AgentFactory: - return REDQAgentFactory( + def _create_algorithm_factory(self) -> AlgorithmFactory: 
+ return REDQAlgorithmFactory( self._params, self._sampling_config, self._get_actor_factory(), @@ -1342,8 +1345,8 @@ def with_sac_params(self, params: SACParams) -> Self: self._params = params return self - def _create_agent_factory(self) -> AgentFactory: - return SACAgentFactory( + def _create_algorithm_factory(self) -> AlgorithmFactory: + return SACAlgorithmFactory( self._params, self._sampling_config, self._get_actor_factory(), @@ -1373,8 +1376,8 @@ def with_sac_params(self, params: DiscreteSACParams) -> Self: self._params = params return self - def _create_agent_factory(self) -> AgentFactory: - return DiscreteSACAgentFactory( + def _create_algorithm_factory(self) -> AlgorithmFactory: + return DiscreteSACAlgorithmFactory( self._params, self._sampling_config, self._get_actor_factory(), @@ -1404,8 +1407,8 @@ def with_td3_params(self, params: TD3Params) -> Self: self._params = params return self - def _create_agent_factory(self) -> AgentFactory: - return TD3AgentFactory( + def _create_algorithm_factory(self) -> AlgorithmFactory: + return TD3AlgorithmFactory( self._params, self._sampling_config, self._get_actor_factory(), diff --git a/tianshou/policy/__init__.py b/tianshou/policy/__init__.py index 72d47260a..3bf28ec5f 100644 --- a/tianshou/policy/__init__.py +++ b/tianshou/policy/__init__.py @@ -6,7 +6,6 @@ from tianshou.policy.modelfree.dqn import DeepQLearning from tianshou.policy.modelfree.ddpg import DDPG -""" from tianshou.policy.random import MARLRandomPolicy from tianshou.policy.modelfree.bdq import BranchingDQNPolicy from tianshou.policy.modelfree.c51 import C51Policy @@ -33,38 +32,37 @@ from tianshou.policy.modelbased.psrl import PSRLPolicy from tianshou.policy.modelbased.icm import ICMPolicy from tianshou.policy.multiagent.mapolicy import MultiAgentPolicyManager -""" __all__ = [ "Algorithm", - # "MARLRandomPolicy", + "MARLRandomPolicy", "DeepQLearning", - # "BranchingDQNPolicy", - # "C51Policy", - # "RainbowPolicy", - # "QRDQNPolicy", - # "IQNPolicy", - # 
"FQFPolicy", + "BranchingDQNPolicy", + "C51Policy", + "RainbowPolicy", + "QRDQNPolicy", + "IQNPolicy", + "FQFPolicy", "Reinforce", - # "A2CPolicy", - # "NPGPolicy", + "A2CPolicy", + "NPGPolicy", "DDPG", - # "PPOPolicy", - # "TRPOPolicy", - # "TD3Policy", - # "SACPolicy", - # "REDQPolicy", - # "DiscreteSACPolicy", - # "ImitationPolicy", - # "BCQPolicy", - # "CQLPolicy", - # "TD3BCPolicy", - # "DiscreteBCQPolicy", - # "DiscreteCQLPolicy", - # "DiscreteCRRPolicy", - # "GAILPolicy", - # "PSRLPolicy", - # "ICMPolicy", - # "MultiAgentPolicyManager", + "PPOPolicy", + "TRPOPolicy", + "TD3Policy", + "SACPolicy", + "REDQPolicy", + "DiscreteSACPolicy", + "ImitationPolicy", + "BCQPolicy", + "CQLPolicy", + "TD3BCPolicy", + "DiscreteBCQPolicy", + "DiscreteCQLPolicy", + "DiscreteCRRPolicy", + "GAILPolicy", + "PSRLPolicy", + "ICMPolicy", + "MultiAgentPolicyManager", "TrainingStats", ] diff --git a/tianshou/policy/base.py b/tianshou/policy/base.py index d7747cbd5..7f10b9f91 100644 --- a/tianshou/policy/base.py +++ b/tianshou/policy/base.py @@ -762,11 +762,11 @@ def compute_nstep_return( return cast(BatchWithReturnsProtocol, batch) @abstractmethod - def _create_trainer(self, config: TTrainingConfig) -> "BaseTrainer": + def create_trainer(self, config: TTrainingConfig) -> "BaseTrainer": pass def run_training(self, config: TTrainingConfig): - trainer = self._create_trainer(config) + trainer = self.create_trainer(config) return trainer.run() @@ -775,7 +775,7 @@ class OnPolicyAlgorithm( Generic[TPolicy, TTrainingStats], ABC, ): - def _create_trainer(self, config: "OnPolicyTrainingConfig") -> "OnpolicyTrainer": + def create_trainer(self, config: "OnPolicyTrainingConfig") -> "OnpolicyTrainer": from tianshou.trainer.base import OnpolicyTrainer return OnpolicyTrainer(self, config) @@ -786,7 +786,7 @@ class OffPolicyAlgorithm( Generic[TPolicy, TTrainingStats], ABC, ): - def _create_trainer(self, config: "OffPolicyTrainingConfig") -> "OffpolicyTrainer": + def create_trainer(self, config: 
"OffPolicyTrainingConfig") -> "OffpolicyTrainer": from tianshou.trainer.base import OffpolicyTrainer return OffpolicyTrainer(self, config) diff --git a/tianshou/policy/imitation/base.py b/tianshou/policy/imitation/base.py index 5135295e7..ecf66578e 100644 --- a/tianshou/policy/imitation/base.py +++ b/tianshou/policy/imitation/base.py @@ -31,7 +31,7 @@ class ImitationTrainingStats(TrainingStats): TImitationTrainingStats = TypeVar("TImitationTrainingStats", bound=ImitationTrainingStats) -class ImitationPolicy(Algorithm[TImitationTrainingStats], Generic[TImitationTrainingStats]): +class ImitationPolicy(Algorithm, Generic[TImitationTrainingStats]): """Implementation of vanilla imitation learning. :param actor: a model following the rules in diff --git a/tianshou/policy/imitation/bcq.py b/tianshou/policy/imitation/bcq.py index c94c30d5b..c98a6b712 100644 --- a/tianshou/policy/imitation/bcq.py +++ b/tianshou/policy/imitation/bcq.py @@ -27,7 +27,7 @@ class BCQTrainingStats(TrainingStats): TBCQTrainingStats = TypeVar("TBCQTrainingStats", bound=BCQTrainingStats) -class BCQPolicy(Algorithm[TBCQTrainingStats], Generic[TBCQTrainingStats]): +class BCQPolicy(Algorithm, Generic[TBCQTrainingStats]): """Implementation of BCQ algorithm. arXiv:1812.02900. :param actor_perturbation: the actor perturbation. `(s, a -> perturbed a)` diff --git a/tianshou/policy/modelbased/icm.py b/tianshou/policy/modelbased/icm.py index 1bd13bb1b..ab6a452e1 100644 --- a/tianshou/policy/modelbased/icm.py +++ b/tianshou/policy/modelbased/icm.py @@ -13,7 +13,6 @@ TLearningRateScheduler, TrainingStats, TrainingStatsWrapper, - TTrainingStats, ) from tianshou.utils.net.discrete import IntrinsicCuriosityModule @@ -33,7 +32,7 @@ def __init__( super().__init__(wrapped_stats) -class ICMPolicy(Algorithm[ICMTrainingStats]): +class ICMPolicy(Algorithm): """Implementation of Intrinsic Curiosity Module. arXiv:1705.05363. :param policy: a base policy to add ICM to. 
@@ -57,7 +56,7 @@ class ICMPolicy(Algorithm[ICMTrainingStats]): def __init__( self, *, - policy: Algorithm[TTrainingStats], + policy: Algorithm, # [TTrainingStats] model: IntrinsicCuriosityModule, optim: torch.optim.Optimizer, lr_scale: float, diff --git a/tianshou/policy/modelbased/psrl.py b/tianshou/policy/modelbased/psrl.py index 66711c74f..c6f433d1f 100644 --- a/tianshou/policy/modelbased/psrl.py +++ b/tianshou/policy/modelbased/psrl.py @@ -150,7 +150,7 @@ def __call__( return self.policy[obs] -class PSRLPolicy(Algorithm[TPSRLTrainingStats]): +class PSRLPolicy(Algorithm): """Implementation of Posterior Sampling Reinforcement Learning. Reference: Strens M. A Bayesian framework for reinforcement learning [C] diff --git a/tianshou/policy/random.py b/tianshou/policy/random.py index 0f596782e..c20f8419c 100644 --- a/tianshou/policy/random.py +++ b/tianshou/policy/random.py @@ -16,7 +16,7 @@ class MARLRandomTrainingStats(TrainingStats): TMARLRandomTrainingStats = TypeVar("TMARLRandomTrainingStats", bound=MARLRandomTrainingStats) -class MARLRandomPolicy(Algorithm[TMARLRandomTrainingStats]): +class MARLRandomPolicy(Algorithm): """A random agent used in multi-agent learning. It randomly chooses an action from the legal action. 
From 62c817c38658293875015cb79b73776bd432ee42 Mon Sep 17 00:00:00 2001 From: Dominik Jain Date: Tue, 4 Mar 2025 22:30:17 +0100 Subject: [PATCH 007/230] ExperimentConfig: Do not inherit from anything (breaks jsonargparse auto-handling with defaults) --- tianshou/highlevel/experiment.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tianshou/highlevel/experiment.py b/tianshou/highlevel/experiment.py index 6bf3b563c..fc6ab566b 100644 --- a/tianshou/highlevel/experiment.py +++ b/tianshou/highlevel/experiment.py @@ -118,7 +118,7 @@ @dataclass -class ExperimentConfig(ToStringMixin, DataclassPPrintMixin): +class ExperimentConfig: """Generic config for setting up the experiment, not RL or training specific.""" seed: int = 42 From 8ec202381948a3b1a3989f9916d9948e5bdb4b19 Mon Sep 17 00:00:00 2001 From: Dominik Jain Date: Tue, 4 Mar 2025 23:10:18 +0100 Subject: [PATCH 008/230] v2: Restore high-level API support for DDPG and DeepQLearning --- CHANGELOG.md | 10 +++++ tianshou/highlevel/algorithm.py | 66 ++++++++++++++++++++++++-------- tianshou/highlevel/experiment.py | 8 ++-- tianshou/policy/modelfree/dqn.py | 1 + 4 files changed, 66 insertions(+), 19 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index c032e5553..30ce4c6f1 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,15 @@ # Changelog +## Release 2.0.0 + +* We now conceptually differentiate between the learning algorithm and the policy being optimised: + * The abstraction `BasePolicy` is thus replaced by `Algorithm` and `Policy`. + Migration information (`BasePolicy` -> `Algorithm`): + * `PGPolicy` -> `Reinforce` + * `DQNPolicy` -> `DeepQLearning` + * `DDPGPolicy` -> `DDPG` + * The `Algorithm` abstraction can directly initiate the learning process via method `run_training`. 
+ ## Unreleased ### Changes/Improvements diff --git a/tianshou/highlevel/algorithm.py b/tianshou/highlevel/algorithm.py index c35c6c6d9..fa0ba4c3e 100644 --- a/tianshou/highlevel/algorithm.py +++ b/tianshou/highlevel/algorithm.py @@ -4,6 +4,7 @@ from typing import Any, Generic, TypeVar, cast import gymnasium +import torch from sensai.util.string import ToStringMixin from tianshou.data import Collector, ReplayBuffer, VectorReplayBuffer @@ -65,6 +66,8 @@ Policy, RandomActionPolicy, ) +from tianshou.policy.modelfree.ddpg import DDPGPolicy +from tianshou.policy.modelfree.dqn import DQNPolicy from tianshou.policy.modelfree.pg import ActorPolicy from tianshou.trainer import BaseTrainer, OffpolicyTrainer, OnpolicyTrainer from tianshou.trainer.base import OffPolicyTrainingConfig, OnPolicyTrainingConfig @@ -275,7 +278,7 @@ def _create_algorithm(self, envs: Environments, device: TDevice) -> RandomAction return RandomActionPolicy(envs.get_action_space()) -class PGAlgorithmFactory(OnPolicyAlgorithmFactory): +class ReinforceAlgorithmFactory(OnPolicyAlgorithmFactory): def __init__( self, params: PGParams, @@ -417,14 +420,24 @@ def __init__( self.optim_factory = optim_factory @abstractmethod - def _get_policy_class(self) -> type[TAlgorithm]: + def _get_algorithm_class(self) -> type[TAlgorithm]: + pass + + @abstractmethod + def _create_discrete_critic_only_policy( + self, + model: torch.nn.Module, + params: dict, + action_space: gymnasium.spaces.Discrete, + observation_space: gymnasium.spaces.Space, + ) -> TPolicy: pass @typing.no_type_check def _create_algorithm(self, envs: Environments, device: TDevice) -> TAlgorithm: model = self.model_factory.create_module(envs, device) optim = self.optim_factory.create_optimizer(model, self.params.lr) - kwargs = self.params.create_kwargs( + params_dict = self.params.create_kwargs( ParamTransformerData( envs=envs, device=device, @@ -434,23 +447,40 @@ def _create_algorithm(self, envs: Environments, device: TDevice) -> TAlgorithm: ) 
envs.get_type().assert_discrete(self) action_space = cast(gymnasium.spaces.Discrete, envs.get_action_space()) - policy_class = self._get_policy_class() - return policy_class( - model=model, + policy = self._create_discrete_critic_only_policy( + model, params_dict, action_space, envs.get_observation_space() + ) + algorithm_class = self._get_algorithm_class() + return algorithm_class( + policy=policy, optim=optim, - action_space=action_space, - observation_space=envs.get_observation_space(), - **kwargs, + **params_dict, ) -class DQNAlgorithmFactory(DiscreteCriticOnlyAlgorithmFactory[DQNParams, DeepQLearning]): - def _get_policy_class(self) -> type[DeepQLearning]: +class DeepQLearningAlgorithmFactory(DiscreteCriticOnlyAlgorithmFactory[DQNParams, DeepQLearning]): + def _create_discrete_critic_only_policy( + self, + model: torch.nn.Module, + params: dict, + action_space: gymnasium.spaces.Discrete, + observation_space: gymnasium.spaces.Space, + ) -> TPolicy: + return self._create_policy( + constructor=DQNPolicy, + params_dict=params, + policy_params=[], + model=model, + action_space=action_space, + observation_space=observation_space, + ) + + def _get_algorithm_class(self) -> type[DeepQLearning]: return DeepQLearning class IQNAlgorithmFactory(DiscreteCriticOnlyAlgorithmFactory[IQNParams, IQNPolicy]): - def _get_policy_class(self) -> type[IQNPolicy]: + def _get_algorithm_class(self) -> type[IQNPolicy]: return IQNPolicy @@ -492,13 +522,19 @@ def _create_algorithm(self, envs: Environments, device: TDevice) -> Algorithm: critic1=critic, ), ) - return DDPG( + policy = self._create_policy( + DDPGPolicy, + kwargs, + ["action_scaling", "action_bound_method"], actor=actor.module, + action_space=envs.get_action_space(), + observation_space=envs.get_observation_space(), + ) + return DDPG( + policy=policy, policy_optim=actor.optim, critic=critic.module, critic_optim=critic.optim, - action_space=envs.get_action_space(), - observation_space=envs.get_observation_space(), **kwargs, ) 
diff --git a/tianshou/highlevel/experiment.py b/tianshou/highlevel/experiment.py index fc6ab566b..d6f86eb95 100644 --- a/tianshou/highlevel/experiment.py +++ b/tianshou/highlevel/experiment.py @@ -42,14 +42,14 @@ A2CAlgorithmFactory, AlgorithmFactory, DDPGAlgorithmFactory, + DeepQLearningAlgorithmFactory, DiscreteSACAlgorithmFactory, - DQNAlgorithmFactory, IQNAlgorithmFactory, NPGAlgorithmFactory, - PGAlgorithmFactory, PPOAlgorithmFactory, RandomActionAlgorithmFactory, REDQAlgorithmFactory, + ReinforceAlgorithmFactory, SACAlgorithmFactory, TD3AlgorithmFactory, TRPOAlgorithmFactory, @@ -1042,7 +1042,7 @@ def with_pg_params(self, params: PGParams) -> Self: return self def _create_algorithm_factory(self) -> AlgorithmFactory: - return PGAlgorithmFactory( + return ReinforceAlgorithmFactory( self._params, self._sampling_config, self._get_actor_factory(), @@ -1220,7 +1220,7 @@ def with_model_factory_default( return self def _create_algorithm_factory(self) -> AlgorithmFactory: - return DQNAlgorithmFactory( + return DeepQLearningAlgorithmFactory( self._params, self._sampling_config, self._model_factory, diff --git a/tianshou/policy/modelfree/dqn.py b/tianshou/policy/modelfree/dqn.py index ba0aecb50..24422d313 100644 --- a/tianshou/policy/modelfree/dqn.py +++ b/tianshou/policy/modelfree/dqn.py @@ -174,6 +174,7 @@ def set_eps(self, eps: float) -> None: def train(self, mode: bool = True) -> Self: """Set the module in training mode, except for the target network.""" + # TODO: Determine whether this is called correctly and who relies on this being called (for all subclasses) self.training = mode self.policy.train(mode) return self From e34d37b463260e1f02a3f6aa6b0e4e0a8bbecf68 Mon Sep 17 00:00:00 2001 From: Dominik Jain Date: Tue, 4 Mar 2025 23:44:30 +0100 Subject: [PATCH 009/230] v2: Set train mode on Algorithm instead of Policy (undoing previous change before Algorithm was an nn.Module) --- tianshou/data/collector.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff 
--git a/tianshou/data/collector.py b/tianshou/data/collector.py index 7fc661064..06307a829 100644 --- a/tianshou/data/collector.py +++ b/tianshou/data/collector.py @@ -469,7 +469,7 @@ def collect( self.reset(reset_buffer=False, gym_reset_kwargs=gym_reset_kwargs) pre_collect_time = time.time() - with torch_train_mode(self.algorithm.policy, enabled=False): + with torch_train_mode(self.algorithm, enabled=False): collect_stats = self._collect( n_step=n_step, n_episode=n_episode, From f32e51b5af5ea782c6bc12e0dae17167d22df66d Mon Sep 17 00:00:00 2001 From: Dominik Jain Date: Tue, 4 Mar 2025 23:29:19 +0100 Subject: [PATCH 010/230] v2: Adjust QRDQN and corresponding test, adapt test_dqn --- examples/atari/atari_network.py | 2 +- examples/atari/atari_qrdqn.py | 8 +-- examples/offline/atari_cql.py | 4 +- test/discrete/test_dqn.py | 58 +++++++++++---------- test/discrete/test_qrdqn.py | 63 ++++++++++++----------- test/offline/gather_cartpole_data.py | 4 +- tianshou/policy/__init__.py | 4 +- tianshou/policy/imitation/discrete_bcq.py | 4 +- tianshou/policy/imitation/discrete_cql.py | 4 +- tianshou/policy/modelfree/bdq.py | 4 +- tianshou/policy/modelfree/c51.py | 2 +- tianshou/policy/modelfree/dqn.py | 28 +++++----- tianshou/policy/modelfree/fqf.py | 4 +- tianshou/policy/modelfree/iqn.py | 4 +- tianshou/policy/modelfree/qrdqn.py | 29 +++++------ 15 files changed, 115 insertions(+), 107 deletions(-) diff --git a/examples/atari/atari_network.py b/examples/atari/atari_network.py index 87797f760..54e81d12f 100644 --- a/examples/atari/atari_network.py +++ b/examples/atari/atari_network.py @@ -213,7 +213,7 @@ def forward( return probs, state -class QRDQN(DQN): +class QRDQNetwork(DQN): """Reference: Distributional Reinforcement Learning with Quantile Regression. 
For advanced usage (how to customize the network), please refer to diff --git a/examples/atari/atari_qrdqn.py b/examples/atari/atari_qrdqn.py index c5a658b08..33ef4f99f 100644 --- a/examples/atari/atari_qrdqn.py +++ b/examples/atari/atari_qrdqn.py @@ -6,12 +6,12 @@ import numpy as np import torch -from atari_network import QRDQN +from atari_network import QRDQNetwork from atari_wrapper import make_atari_env from tianshou.data import Collector, CollectStats, VectorReplayBuffer from tianshou.highlevel.logger import LoggerFactoryDefault -from tianshou.policy import QRDQNPolicy +from tianshou.policy import QRDQN from tianshou.policy.base import Algorithm from tianshou.trainer import OffpolicyTrainer @@ -83,7 +83,7 @@ def test_qrdqn(args: argparse.Namespace = get_args()) -> None: torch.manual_seed(args.seed) # define model c, h, w = args.state_shape - net = QRDQN( + net = QRDQNetwork( c=c, h=h, w=w, @@ -93,7 +93,7 @@ def test_qrdqn(args: argparse.Namespace = get_args()) -> None: ) optim = torch.optim.Adam(net.parameters(), lr=args.lr) # define policy - policy: QRDQNPolicy = QRDQNPolicy( + policy: QRDQN = QRDQN( model=net, optim=optim, action_space=env.action_space, diff --git a/examples/offline/atari_cql.py b/examples/offline/atari_cql.py index 3b8bd2783..e16bfbd81 100644 --- a/examples/offline/atari_cql.py +++ b/examples/offline/atari_cql.py @@ -12,7 +12,7 @@ import torch from gymnasium.spaces import Discrete -from examples.atari.atari_network import QRDQN +from examples.atari.atari_network import QRDQNetwork from examples.atari.atari_wrapper import make_atari_env from examples.offline.utils import load_buffer from tianshou.data import Collector, CollectStats, VectorReplayBuffer @@ -97,7 +97,7 @@ def test_discrete_cql(args: argparse.Namespace = get_args()) -> None: np.random.seed(args.seed) torch.manual_seed(args.seed) # model - net = QRDQN( + net = QRDQNetwork( c=c, h=h, w=w, diff --git a/test/discrete/test_dqn.py b/test/discrete/test_dqn.py index 7c7588518..e6f102120 
100644 --- a/test/discrete/test_dqn.py +++ b/test/discrete/test_dqn.py @@ -16,7 +16,8 @@ from tianshou.env import DummyVectorEnv from tianshou.policy import DeepQLearning from tianshou.policy.base import Algorithm -from tianshou.trainer import OffpolicyTrainer +from tianshou.policy.modelfree.dqn import DQNPolicy +from tianshou.trainer.base import OffPolicyTrainingConfig from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Net from tianshou.utils.space_info import SpaceInfo @@ -87,13 +88,15 @@ def test_dqn(args: argparse.Namespace = get_args()) -> None: # dueling=(Q_param, V_param), ).to(args.device) optim = torch.optim.Adam(net.parameters(), lr=args.lr) - policy: DeepQLearning = DeepQLearning( - model=net, + policy = DQNPolicy( + model=net, action_space=env.action_space, observation_space=env.observation_space + ) + algorithm: DeepQLearning = DeepQLearning( + policy=policy, optim=optim, discount_factor=args.gamma, estimation_step=args.n_step, target_update_freq=args.target_update_freq, - action_space=env.action_space, ) # buffer buf: ReplayBuffer @@ -107,8 +110,8 @@ def test_dqn(args: argparse.Namespace = get_args()) -> None: else: buf = VectorReplayBuffer(args.buffer_size, buffer_num=len(train_envs)) # collector - train_collector = Collector[CollectStats](policy, train_envs, buf, exploration_noise=True) - test_collector = Collector[CollectStats](policy, test_envs, exploration_noise=True) + train_collector = Collector[CollectStats](algorithm, train_envs, buf, exploration_noise=True) + test_collector = Collector[CollectStats](algorithm, test_envs, exploration_noise=True) # policy.set_eps(1) train_collector.reset() train_collector.collect(n_step=args.batch_size * args.training_num) @@ -126,33 +129,34 @@ def stop_fn(mean_rewards: float) -> bool: def train_fn(epoch: int, env_step: int) -> None: # eps annnealing, just a demo if env_step <= 10000: - policy.set_eps(args.eps_train) + algorithm.set_eps(args.eps_train) elif env_step <= 50000: 
eps = args.eps_train - (env_step - 10000) / 40000 * (0.9 * args.eps_train) - policy.set_eps(eps) + algorithm.set_eps(eps) else: - policy.set_eps(0.1 * args.eps_train) + algorithm.set_eps(0.1 * args.eps_train) def test_fn(epoch: int, env_step: int | None) -> None: - policy.set_eps(args.eps_test) + algorithm.set_eps(args.eps_test) - # trainer - result = OffpolicyTrainer( - policy=policy, - train_collector=train_collector, - test_collector=test_collector, - max_epoch=args.epoch, - step_per_epoch=args.step_per_epoch, - step_per_collect=args.step_per_collect, - episode_per_test=args.test_num, - batch_size=args.batch_size, - update_per_step=args.update_per_step, - train_fn=train_fn, - test_fn=test_fn, - stop_fn=stop_fn, - save_best_fn=save_best_fn, - logger=logger, - ).run() + # train + result = algorithm.run_training( + OffPolicyTrainingConfig( + train_collector=train_collector, + test_collector=test_collector, + max_epoch=args.epoch, + step_per_epoch=args.step_per_epoch, + step_per_collect=args.step_per_collect, + episode_per_test=args.test_num, + batch_size=args.batch_size, + update_per_step=args.update_per_step, + train_fn=train_fn, + test_fn=test_fn, + stop_fn=stop_fn, + save_best_fn=save_best_fn, + logger=logger, + ) + ) assert stop_fn(result.best_reward) diff --git a/test/discrete/test_qrdqn.py b/test/discrete/test_qrdqn.py index bf6928e7e..8fa02c417 100644 --- a/test/discrete/test_qrdqn.py +++ b/test/discrete/test_qrdqn.py @@ -13,10 +13,10 @@ VectorReplayBuffer, ) from tianshou.env import DummyVectorEnv -from tianshou.policy import QRDQNPolicy +from tianshou.policy import QRDQN from tianshou.policy.base import Algorithm -from tianshou.policy.modelfree.qrdqn import QRDQNTrainingStats -from tianshou.trainer import OffpolicyTrainer +from tianshou.policy.modelfree.qrdqn import QRDQNPolicy +from tianshou.trainer.base import OffPolicyTrainingConfig from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Net from tianshou.utils.space_info 
import SpaceInfo @@ -92,10 +92,12 @@ def test_qrdqn(args: argparse.Namespace = get_args()) -> None: num_atoms=args.num_quantiles, ) optim = torch.optim.Adam(net.parameters(), lr=args.lr) - policy: QRDQNPolicy[QRDQNTrainingStats] = QRDQNPolicy( - model=net, + policy = QRDQNPolicy( + model=net, action_space=env.action_space, observation_space=env.observation_space + ) + algorithm = QRDQN( + policy=policy, optim=optim, - action_space=env.action_space, discount_factor=args.gamma, num_quantiles=args.num_quantiles, estimation_step=args.n_step, @@ -113,8 +115,8 @@ def test_qrdqn(args: argparse.Namespace = get_args()) -> None: else: buf = VectorReplayBuffer(args.buffer_size, buffer_num=len(train_envs)) # collector - train_collector = Collector[CollectStats](policy, train_envs, buf, exploration_noise=True) - test_collector = Collector[CollectStats](policy, test_envs, exploration_noise=True) + train_collector = Collector[CollectStats](algorithm, train_envs, buf, exploration_noise=True) + test_collector = Collector[CollectStats](algorithm, test_envs, exploration_noise=True) # policy.set_eps(1) train_collector.reset() train_collector.collect(n_step=args.batch_size * args.training_num) @@ -123,8 +125,8 @@ def test_qrdqn(args: argparse.Namespace = get_args()) -> None: writer = SummaryWriter(log_path) logger = TensorboardLogger(writer) - def save_best_fn(policy: Algorithm) -> None: - torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth")) + def save_best_fn(algo: Algorithm) -> None: + torch.save(algo.state_dict(), os.path.join(log_path, "policy.pth")) def stop_fn(mean_rewards: float) -> bool: return mean_rewards >= args.reward_threshold @@ -132,33 +134,34 @@ def stop_fn(mean_rewards: float) -> bool: def train_fn(epoch: int, env_step: int) -> None: # eps annnealing, just a demo if env_step <= 10000: - policy.set_eps(args.eps_train) + algorithm.set_eps(args.eps_train) elif env_step <= 50000: eps = args.eps_train - (env_step - 10000) / 40000 * (0.9 * args.eps_train) - 
policy.set_eps(eps) + algorithm.set_eps(eps) else: - policy.set_eps(0.1 * args.eps_train) + algorithm.set_eps(0.1 * args.eps_train) def test_fn(epoch: int, env_step: int | None) -> None: - policy.set_eps(args.eps_test) + algorithm.set_eps(args.eps_test) # trainer - result = OffpolicyTrainer( - policy=policy, - train_collector=train_collector, - test_collector=test_collector, - max_epoch=args.epoch, - step_per_epoch=args.step_per_epoch, - step_per_collect=args.step_per_collect, - episode_per_test=args.test_num, - batch_size=args.batch_size, - train_fn=train_fn, - test_fn=test_fn, - stop_fn=stop_fn, - save_best_fn=save_best_fn, - logger=logger, - update_per_step=args.update_per_step, - ).run() + result = algorithm.run_training( + OffPolicyTrainingConfig( + train_collector=train_collector, + test_collector=test_collector, + max_epoch=args.epoch, + step_per_epoch=args.step_per_epoch, + step_per_collect=args.step_per_collect, + episode_per_test=args.test_num, + batch_size=args.batch_size, + train_fn=train_fn, + test_fn=test_fn, + stop_fn=stop_fn, + save_best_fn=save_best_fn, + logger=logger, + update_per_step=args.update_per_step, + ) + ) assert stop_fn(result.best_reward) diff --git a/test/offline/gather_cartpole_data.py b/test/offline/gather_cartpole_data.py index 4c69d0315..8c5252f72 100644 --- a/test/offline/gather_cartpole_data.py +++ b/test/offline/gather_cartpole_data.py @@ -14,7 +14,7 @@ VectorReplayBuffer, ) from tianshou.env import DummyVectorEnv -from tianshou.policy import QRDQNPolicy +from tianshou.policy import QRDQN from tianshou.policy.base import Algorithm from tianshou.policy.modelfree.qrdqn import QRDQNTrainingStats from tianshou.trainer import OffpolicyTrainer @@ -97,7 +97,7 @@ def gather_data() -> VectorReplayBuffer | PrioritizedVectorReplayBuffer: num_atoms=args.num_quantiles, ) optim = torch.optim.Adam(net.parameters(), lr=args.lr) - policy: QRDQNPolicy[QRDQNTrainingStats] = QRDQNPolicy( + policy: QRDQN[QRDQNTrainingStats] = QRDQN( model=net, 
optim=optim, action_space=env.action_space, diff --git a/tianshou/policy/__init__.py b/tianshou/policy/__init__.py index 3bf28ec5f..0444c7bbc 100644 --- a/tianshou/policy/__init__.py +++ b/tianshou/policy/__init__.py @@ -10,7 +10,7 @@ from tianshou.policy.modelfree.bdq import BranchingDQNPolicy from tianshou.policy.modelfree.c51 import C51Policy from tianshou.policy.modelfree.rainbow import RainbowPolicy -from tianshou.policy.modelfree.qrdqn import QRDQNPolicy +from tianshou.policy.modelfree.qrdqn import QRDQN from tianshou.policy.modelfree.iqn import IQNPolicy from tianshou.policy.modelfree.fqf import FQFPolicy from tianshou.policy.modelfree.a2c import A2CPolicy @@ -40,7 +40,7 @@ "BranchingDQNPolicy", "C51Policy", "RainbowPolicy", - "QRDQNPolicy", + "QRDQN", "IQNPolicy", "FQFPolicy", "Reinforce", diff --git a/tianshou/policy/imitation/discrete_bcq.py b/tianshou/policy/imitation/discrete_bcq.py index f6b0b4d43..294c4cd9e 100644 --- a/tianshou/policy/imitation/discrete_bcq.py +++ b/tianshou/policy/imitation/discrete_bcq.py @@ -1,6 +1,6 @@ import math from dataclasses import dataclass -from typing import Any, Self, TypeVar, cast +from typing import Any, Generic, Self, TypeVar, cast import gymnasium as gym import numpy as np @@ -31,7 +31,7 @@ class DiscreteBCQTrainingStats(DQNTrainingStats): TDiscreteBCQTrainingStats = TypeVar("TDiscreteBCQTrainingStats", bound=DiscreteBCQTrainingStats) -class DiscreteBCQPolicy(DeepQLearning[TDiscreteBCQTrainingStats]): +class DiscreteBCQPolicy(DeepQLearning, Generic[TDiscreteBCQTrainingStats]): """Implementation of discrete BCQ algorithm. arXiv:1910.01708. 
:param model: a model following the rules (s_B -> action_values_BA) diff --git a/tianshou/policy/imitation/discrete_cql.py b/tianshou/policy/imitation/discrete_cql.py index 640bff0d2..3087b354e 100644 --- a/tianshou/policy/imitation/discrete_cql.py +++ b/tianshou/policy/imitation/discrete_cql.py @@ -8,7 +8,7 @@ from tianshou.data import to_torch from tianshou.data.types import RolloutBatchProtocol -from tianshou.policy import QRDQNPolicy +from tianshou.policy import QRDQN from tianshou.policy.base import TLearningRateScheduler from tianshou.policy.modelfree.qrdqn import QRDQNTrainingStats @@ -22,7 +22,7 @@ class DiscreteCQLTrainingStats(QRDQNTrainingStats): TDiscreteCQLTrainingStats = TypeVar("TDiscreteCQLTrainingStats", bound=DiscreteCQLTrainingStats) -class DiscreteCQLPolicy(QRDQNPolicy[TDiscreteCQLTrainingStats]): +class DiscreteCQLPolicy(QRDQN[TDiscreteCQLTrainingStats]): """Implementation of discrete Conservative Q-Learning algorithm. arXiv:2006.04779. :param model: a model following the rules (s_B -> action_values_BA) diff --git a/tianshou/policy/modelfree/bdq.py b/tianshou/policy/modelfree/bdq.py index 455711ce0..336bbcd0c 100644 --- a/tianshou/policy/modelfree/bdq.py +++ b/tianshou/policy/modelfree/bdq.py @@ -16,7 +16,7 @@ ) from tianshou.policy import DeepQLearning from tianshou.policy.base import TLearningRateScheduler -from tianshou.policy.modelfree.dqn import DQNTrainingStats +from tianshou.policy.modelfree.dqn import DQNTrainingStats, TDQNPolicy from tianshou.utils.net.common import BranchingNet @@ -28,7 +28,7 @@ class BDQNTrainingStats(DQNTrainingStats): TBDQNTrainingStats = TypeVar("TBDQNTrainingStats", bound=BDQNTrainingStats) -class BranchingDQNPolicy(DeepQLearning[TBDQNTrainingStats]): +class BranchingDQNPolicy(DeepQLearning[TDQNPolicy, TBDQNTrainingStats]): """Implementation of the Branching dual Q network arXiv:1711.08946. :param model: BranchingNet mapping (obs, state, info) -> action_values_BA. 
diff --git a/tianshou/policy/modelfree/c51.py b/tianshou/policy/modelfree/c51.py index 1d858de38..f4d4d3156 100644 --- a/tianshou/policy/modelfree/c51.py +++ b/tianshou/policy/modelfree/c51.py @@ -20,7 +20,7 @@ class C51TrainingStats(DQNTrainingStats): TC51TrainingStats = TypeVar("TC51TrainingStats", bound=C51TrainingStats) -class C51Policy(DeepQLearning[TC51TrainingStats], Generic[TC51TrainingStats]): +class C51Policy(DeepQLearning, Generic[TC51TrainingStats]): """Implementation of Categorical Deep Q-Network. arXiv:1707.06887. :param model: a model following the rules (s_B -> action_values_BA) diff --git a/tianshou/policy/modelfree/dqn.py b/tianshou/policy/modelfree/dqn.py index 24422d313..2408b559b 100644 --- a/tianshou/policy/modelfree/dqn.py +++ b/tianshou/policy/modelfree/dqn.py @@ -92,15 +92,28 @@ def forward( # TODO: this is convoluted! See also other places where this is done. obs_next = obs.obs if hasattr(obs, "obs") else obs action_values_BA, hidden_BH = model(obs_next, state=state, info=batch.info) - q = DeepQLearning.compute_q_value(action_values_BA, getattr(obs, "mask", None)) + q = self.compute_q_value(action_values_BA, getattr(obs, "mask", None)) if self.max_action_num is None: self.max_action_num = q.shape[1] act_B = to_numpy(q.argmax(dim=1)) result = Batch(logits=action_values_BA, act=act_B, state=hidden_BH) return cast(ModelOutputBatchProtocol, result) + def compute_q_value(self, logits: torch.Tensor, mask: np.ndarray | None) -> torch.Tensor: + """Compute the q value based on the network's raw output and action mask.""" + if mask is not None: + # the masked q value should be smaller than logits.min() + min_value = logits.min() - logits.max() - 1.0 + logits = logits + to_torch_as(1 - mask, logits) * min_value + return logits + + +TDQNPolicy = TypeVar("TDQNPolicy", bound=DQNPolicy) -class DeepQLearning(OffPolicyAlgorithm[DQNPolicy, TDQNTrainingStats], Generic[TDQNTrainingStats]): + +class DeepQLearning( + OffPolicyAlgorithm[TDQNPolicy, 
TDQNTrainingStats], Generic[TDQNPolicy, TDQNTrainingStats] +): """Implementation of Deep Q Network. arXiv:1312.5602. Implementation of Double Q-Learning. arXiv:1509.06461. @@ -132,7 +145,7 @@ class DeepQLearning(OffPolicyAlgorithm[DQNPolicy, TDQNTrainingStats], Generic[TD def __init__( self, *, - policy: DQNPolicy, + policy: TDQNPolicy, optim: torch.optim.Optimizer, # TODO: type violates Liskov substitution principle discount_factor: float = 0.99, @@ -220,15 +233,6 @@ def process_fn( rew_norm=self.rew_norm, ) - @staticmethod - def compute_q_value(logits: torch.Tensor, mask: np.ndarray | None) -> torch.Tensor: - """Compute the q value based on the network's raw output and action mask.""" - if mask is not None: - # the masked q value should be smaller than logits.min() - min_value = logits.min() - logits.max() - 1.0 - logits = logits + to_torch_as(1 - mask, logits) * min_value - return logits - def _update_with_batch( self, batch: RolloutBatchProtocol, diff --git a/tianshou/policy/modelfree/fqf.py b/tianshou/policy/modelfree/fqf.py index d00e38f09..89cd5d7c5 100644 --- a/tianshou/policy/modelfree/fqf.py +++ b/tianshou/policy/modelfree/fqf.py @@ -8,7 +8,7 @@ from tianshou.data import Batch, ReplayBuffer, to_numpy from tianshou.data.types import FQFBatchProtocol, ObsBatchProtocol, RolloutBatchProtocol -from tianshou.policy import DeepQLearning, QRDQNPolicy +from tianshou.policy import QRDQN, DeepQLearning from tianshou.policy.base import TLearningRateScheduler from tianshou.policy.modelfree.qrdqn import QRDQNTrainingStats from tianshou.utils.net.discrete import FractionProposalNetwork, FullQuantileFunction @@ -24,7 +24,7 @@ class FQFTrainingStats(QRDQNTrainingStats): TFQFTrainingStats = TypeVar("TFQFTrainingStats", bound=FQFTrainingStats) -class FQFPolicy(QRDQNPolicy[TFQFTrainingStats]): +class FQFPolicy(QRDQN[TFQFTrainingStats]): """Implementation of Fully-parameterized Quantile Function. arXiv:1911.02140. 
:param model: a model following the rules (s_B -> action_values_BA) diff --git a/tianshou/policy/modelfree/iqn.py b/tianshou/policy/modelfree/iqn.py index 1ecfae21f..e01d5eada 100644 --- a/tianshou/policy/modelfree/iqn.py +++ b/tianshou/policy/modelfree/iqn.py @@ -13,7 +13,7 @@ QuantileRegressionBatchProtocol, RolloutBatchProtocol, ) -from tianshou.policy import QRDQNPolicy +from tianshou.policy import QRDQN from tianshou.policy.base import TLearningRateScheduler from tianshou.policy.modelfree.qrdqn import QRDQNTrainingStats @@ -26,7 +26,7 @@ class IQNTrainingStats(QRDQNTrainingStats): TIQNTrainingStats = TypeVar("TIQNTrainingStats", bound=IQNTrainingStats) -class IQNPolicy(QRDQNPolicy[TIQNTrainingStats]): +class IQNPolicy(QRDQN[TIQNTrainingStats]): """Implementation of Implicit Quantile Network. arXiv:1806.06923. :param model: a model following the rules (s_B -> action_values_BA) diff --git a/tianshou/policy/modelfree/qrdqn.py b/tianshou/policy/modelfree/qrdqn.py index 2f4d5ba04..1458e0e15 100644 --- a/tianshou/policy/modelfree/qrdqn.py +++ b/tianshou/policy/modelfree/qrdqn.py @@ -2,7 +2,6 @@ from dataclasses import dataclass from typing import Any, Generic, TypeVar -import gymnasium as gym import numpy as np import torch import torch.nn.functional as F @@ -11,7 +10,7 @@ from tianshou.data.types import RolloutBatchProtocol from tianshou.policy import DeepQLearning from tianshou.policy.base import TLearningRateScheduler -from tianshou.policy.modelfree.dqn import DQNTrainingStats +from tianshou.policy.modelfree.dqn import DQNPolicy, DQNTrainingStats @dataclass(kw_only=True) @@ -22,7 +21,12 @@ class QRDQNTrainingStats(DQNTrainingStats): TQRDQNTrainingStats = TypeVar("TQRDQNTrainingStats", bound=QRDQNTrainingStats) -class QRDQNPolicy(DeepQLearning[TQRDQNTrainingStats], Generic[TQRDQNTrainingStats]): +class QRDQNPolicy(DQNPolicy): + def compute_q_value(self, logits: torch.Tensor, mask: np.ndarray | None) -> torch.Tensor: + return super().compute_q_value(logits.mean(2), 
mask) + + +class QRDQN(DeepQLearning[QRDQNPolicy, TQRDQNTrainingStats], Generic[TQRDQNTrainingStats]): """Implementation of Quantile Regression Deep Q-Network. arXiv:1710.10044. :param model: a model following the rules (s -> action_values_BA) @@ -52,9 +56,8 @@ class QRDQNPolicy(DeepQLearning[TQRDQNTrainingStats], Generic[TQRDQNTrainingStat def __init__( self, *, - model: torch.nn.Module, + policy: QRDQNPolicy, optim: torch.optim.Optimizer, - action_space: gym.spaces.Discrete, discount_factor: float = 0.99, num_quantiles: int = 200, estimation_step: int = 1, @@ -62,21 +65,18 @@ def __init__( reward_normalization: bool = False, is_double: bool = True, clip_loss_grad: bool = False, - observation_space: gym.Space | None = None, lr_scheduler: TLearningRateScheduler | None = None, ) -> None: assert num_quantiles > 1, f"num_quantiles should be greater than 1 but got: {num_quantiles}" super().__init__( - model=model, + policy=policy, optim=optim, - action_space=action_space, discount_factor=discount_factor, estimation_step=estimation_step, target_update_freq=target_update_freq, reward_normalization=reward_normalization, is_double=is_double, clip_loss_grad=clip_loss_grad, - observation_space=observation_space, lr_scheduler=lr_scheduler, ) self.num_quantiles = num_quantiles @@ -93,17 +93,14 @@ def _target_q(self, buffer: ReplayBuffer, indices: np.ndarray) -> torch.Tensor: info=[None] * len(indices), ) # obs_next: s_{t+n} if self._target: - act = self(obs_next_batch).act - next_dist = self(obs_next_batch, model="model_old").logits + act = self.policy(obs_next_batch).act + next_dist = self.policy(obs_next_batch, model=self.model_old).logits else: - next_batch = self(obs_next_batch) + next_batch = self.policy(obs_next_batch) act = next_batch.act next_dist = next_batch.logits return next_dist[np.arange(len(act)), act, :] - def compute_q_value(self, logits: torch.Tensor, mask: np.ndarray | None) -> torch.Tensor: - return super().compute_q_value(logits.mean(2), mask) - def 
_update_with_batch( self, batch: RolloutBatchProtocol, @@ -114,7 +111,7 @@ def _update_with_batch( self.sync_weight() self.optim.zero_grad() weight = batch.pop("weight", 1.0) - curr_dist = self(batch).logits + curr_dist = self.policy(batch).logits act = batch.act curr_dist = curr_dist[np.arange(len(act)), act, :].unsqueeze(2) target_dist = batch.returns.unsqueeze(1) From 8d4e182c63d1b2bb66f08f0584f9023a6a8cb5ae Mon Sep 17 00:00:00 2001 From: Dominik Jain Date: Wed, 5 Mar 2025 00:15:56 +0100 Subject: [PATCH 011/230] ActorFactoryDefault: Fix hidden sizes and activation not being passed on for discrete case --- tianshou/highlevel/module/actor.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tianshou/highlevel/module/actor.py b/tianshou/highlevel/module/actor.py index 4a1fe5c2e..ceb1262f7 100644 --- a/tianshou/highlevel/module/actor.py +++ b/tianshou/highlevel/module/actor.py @@ -140,7 +140,8 @@ def _create_factory(self, envs: Environments) -> ActorFactory: raise ValueError(self.continuous_actor_type) elif env_type == EnvType.DISCRETE: factory = ActorFactoryDiscreteNet( - self.DEFAULT_HIDDEN_SIZES, + self.hidden_sizes, + activation=self.hidden_activation, softmax_output=self.discrete_softmax, ) else: From b63dcd5740936816232c04b05defaf7397b40130 Mon Sep 17 00:00:00 2001 From: Dominik Jain Date: Wed, 5 Mar 2025 10:43:54 +0100 Subject: [PATCH 012/230] v2: Adapt C51, test_c51 and atari_c51 Move atari_network and atari_wrapper into the library under tianshou.env.atari (this is more convenient and cleans up the example structure) --- CHANGELOG.md | 2 + examples/atari/atari_c51.py | 79 +++++++++------- examples/atari/atari_dqn.py | 4 +- examples/atari/atari_dqn_hl.py | 4 +- examples/atari/atari_fqf.py | 10 +- examples/atari/atari_iqn.py | 10 +- examples/atari/atari_iqn_hl.py | 4 +- examples/atari/atari_ppo.py | 12 +-- examples/atari/atari_ppo_hl.py | 4 +- examples/atari/atari_qrdqn.py | 10 +- examples/atari/atari_rainbow.py | 8 +- 
examples/atari/atari_sac.py | 8 +- examples/atari/atari_sac_hl.py | 4 +- examples/offline/atari_bcq.py | 4 +- examples/offline/atari_cql.py | 4 +- examples/offline/atari_crr.py | 4 +- examples/offline/atari_il.py | 4 +- examples/vizdoom/network.py | 1 - examples/vizdoom/vizdoom_c51.py | 8 +- examples/vizdoom/vizdoom_ppo.py | 6 +- test/discrete/test_c51.py | 76 ++++++++------- .../env}/atari/atari_network.py | 16 ++-- .../env}/atari/atari_wrapper.py | 1 + tianshou/policy/__init__.py | 4 +- tianshou/policy/modelfree/c51.py | 93 +++++++++++-------- tianshou/policy/modelfree/rainbow.py | 4 +- 26 files changed, 208 insertions(+), 176 deletions(-) delete mode 120000 examples/vizdoom/network.py rename {examples => tianshou/env}/atari/atari_network.py (98%) rename {examples => tianshou/env}/atari/atari_wrapper.py (99%) diff --git a/CHANGELOG.md b/CHANGELOG.md index 30ce4c6f1..b05273746 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,8 @@ * `DDPGPolicy` -> `DDPG` * The `Algorithm` abstraction can directly initiate the learning process via method `run_training`. +* Moved Atari helper modules `atari_network` and `atari_wrapper` to the library under `tianshou.env.atari`. 
+ ## Unreleased ### Changes/Improvements diff --git a/examples/atari/atari_c51.py b/examples/atari/atari_c51.py index 326b89ea2..886f991de 100644 --- a/examples/atari/atari_c51.py +++ b/examples/atari/atari_c51.py @@ -3,17 +3,20 @@ import os import pprint import sys +from typing import cast import numpy as np import torch -from atari_network import C51 -from atari_wrapper import make_atari_env +from gym.spaces import Discrete from tianshou.data import Collector, CollectStats, VectorReplayBuffer +from tianshou.env.atari.atari_network import C51Net +from tianshou.env.atari.atari_wrapper import make_atari_env from tianshou.highlevel.logger import LoggerFactoryDefault -from tianshou.policy import C51Policy +from tianshou.policy import C51 from tianshou.policy.base import Algorithm -from tianshou.trainer import OffpolicyTrainer +from tianshou.policy.modelfree.c51 import C51Policy +from tianshou.trainer.base import OffPolicyTrainingConfig def get_args() -> argparse.Namespace: @@ -66,7 +69,7 @@ def get_args() -> argparse.Namespace: return parser.parse_args() -def test_c51(args: argparse.Namespace = get_args()) -> None: +def main(args: argparse.Namespace = get_args()) -> None: env, train_envs, test_envs = make_atari_env( args.task, args.seed, @@ -84,23 +87,26 @@ def test_c51(args: argparse.Namespace = get_args()) -> None: np.random.seed(args.seed) torch.manual_seed(args.seed) # define model - net = C51(*args.state_shape, args.action_shape, args.num_atoms, args.device) + net = C51Net(*args.state_shape, args.action_shape, args.num_atoms, args.device) optim = torch.optim.Adam(net.parameters(), lr=args.lr) # define policy - policy: C51Policy = C51Policy( + policy = C51Policy( model=net, - optim=optim, - discount_factor=args.gamma, - action_space=env.action_space, + action_space=cast(Discrete, env.action_space), num_atoms=args.num_atoms, v_min=args.v_min, v_max=args.v_max, + ) + algorithm = C51( + policy=policy, + optim=optim, + discount_factor=args.gamma, 
estimation_step=args.n_step, target_update_freq=args.target_update_freq, ).to(args.device) # load a previous policy if args.resume_path: - policy.load_state_dict(torch.load(args.resume_path, map_location=args.device)) + algorithm.load_state_dict(torch.load(args.resume_path, map_location=args.device)) print("Loaded agent from: ", args.resume_path) # replay buffer: `save_last_obs` and `stack_num` can be removed together # when you have enough RAM @@ -112,8 +118,8 @@ def test_c51(args: argparse.Namespace = get_args()) -> None: stack_num=args.frames_stack, ) # collector - train_collector = Collector[CollectStats](policy, train_envs, buffer, exploration_noise=True) - test_collector = Collector[CollectStats](policy, test_envs, exploration_noise=True) + train_collector = Collector[CollectStats](algorithm, train_envs, buffer, exploration_noise=True) + test_collector = Collector[CollectStats](algorithm, test_envs, exploration_noise=True) # log now = datetime.datetime.now().strftime("%y%m%d-%H%M%S") @@ -152,17 +158,17 @@ def train_fn(epoch: int, env_step: int) -> None: eps = args.eps_train - env_step / 1e6 * (args.eps_train - args.eps_train_final) else: eps = args.eps_train_final - policy.set_eps(eps) + algorithm.set_eps(eps) if env_step % 1000 == 0: logger.write("train/env_step", env_step, {"train/eps": eps}) def test_fn(epoch: int, env_step: int | None) -> None: - policy.set_eps(args.eps_test) + algorithm.set_eps(args.eps_test) # watch agent's performance def watch() -> None: print("Setup test envs ...") - policy.set_eps(args.eps_test) + algorithm.set_eps(args.eps_test) test_envs.seed(args.seed) if args.save_buffer_name: print(f"Generate buffer with size {args.buffer_size}") @@ -173,7 +179,9 @@ def watch() -> None: save_only_last_obs=True, stack_num=args.frames_stack, ) - collector = Collector[CollectStats](policy, test_envs, buffer, exploration_noise=True) + collector = Collector[CollectStats]( + algorithm, test_envs, buffer, exploration_noise=True + ) result = 
collector.collect(n_step=args.buffer_size, reset_before_collect=True) print(f"Save buffer into {args.save_buffer_name}") # Unfortunately, pickle will cause oom with 1M buffer size @@ -192,27 +200,28 @@ def watch() -> None: train_collector.reset() train_collector.collect(n_step=args.batch_size * args.training_num) # trainer - result = OffpolicyTrainer( - policy=policy, - train_collector=train_collector, - test_collector=test_collector, - max_epoch=args.epoch, - step_per_epoch=args.step_per_epoch, - step_per_collect=args.step_per_collect, - episode_per_test=args.test_num, - batch_size=args.batch_size, - train_fn=train_fn, - test_fn=test_fn, - stop_fn=stop_fn, - save_best_fn=save_best_fn, - logger=logger, - update_per_step=args.update_per_step, - test_in_train=False, - ).run() + result = algorithm.run_training( + OffPolicyTrainingConfig( + train_collector=train_collector, + test_collector=test_collector, + max_epoch=args.epoch, + step_per_epoch=args.step_per_epoch, + step_per_collect=args.step_per_collect, + episode_per_test=args.test_num, + batch_size=args.batch_size, + train_fn=train_fn, + test_fn=test_fn, + stop_fn=stop_fn, + save_best_fn=save_best_fn, + logger=logger, + update_per_step=args.update_per_step, + test_in_train=False, + ) + ) pprint.pprint(result) watch() if __name__ == "__main__": - test_c51(get_args()) + main(get_args()) diff --git a/examples/atari/atari_dqn.py b/examples/atari/atari_dqn.py index f2863a8d0..18d6b1184 100644 --- a/examples/atari/atari_dqn.py +++ b/examples/atari/atari_dqn.py @@ -8,7 +8,7 @@ import torch from atari_wrapper import make_atari_env -from examples.atari.atari_network import DQN +from examples.atari.atari_network import DQNet from tianshou.data import Collector, CollectStats, VectorReplayBuffer from tianshou.highlevel.logger import LoggerFactoryDefault from tianshou.policy import DeepQLearning @@ -101,7 +101,7 @@ def main(args: argparse.Namespace = get_args()) -> None: np.random.seed(args.seed) torch.manual_seed(args.seed) # 
define model - net = DQN(*args.state_shape, args.action_shape, args.device).to(args.device) + net = DQNet(*args.state_shape, args.action_shape, args.device).to(args.device) optim = torch.optim.Adam(net.parameters(), lr=args.lr) # define policy policy: DeepQLearning | ICMPolicy diff --git a/examples/atari/atari_dqn_hl.py b/examples/atari/atari_dqn_hl.py index 601481523..dc3a9fd26 100644 --- a/examples/atari/atari_dqn_hl.py +++ b/examples/atari/atari_dqn_hl.py @@ -5,11 +5,11 @@ from sensai.util import logging from sensai.util.logging import datetime_tag -from examples.atari.atari_network import ( +from tianshou.env.atari.atari_network import ( IntermediateModuleFactoryAtariDQN, IntermediateModuleFactoryAtariDQNFeatures, ) -from examples.atari.atari_wrapper import AtariEnvFactory, AtariEpochStopCallback +from tianshou.env.atari.atari_wrapper import AtariEnvFactory, AtariEpochStopCallback from tianshou.highlevel.config import SamplingConfig from tianshou.highlevel.experiment import ( DQNExperimentBuilder, diff --git a/examples/atari/atari_fqf.py b/examples/atari/atari_fqf.py index 0ac0db560..04d7905d7 100644 --- a/examples/atari/atari_fqf.py +++ b/examples/atari/atari_fqf.py @@ -6,10 +6,10 @@ import numpy as np import torch -from atari_network import DQN -from atari_wrapper import make_atari_env from tianshou.data import Collector, CollectStats, VectorReplayBuffer +from tianshou.env.atari.atari_network import DQNet +from tianshou.env.atari.atari_wrapper import make_atari_env from tianshou.highlevel.logger import LoggerFactoryDefault from tianshou.policy import FQFPolicy from tianshou.policy.base import Algorithm @@ -69,7 +69,7 @@ def get_args() -> argparse.Namespace: return parser.parse_args() -def test_fqf(args: argparse.Namespace = get_args()) -> None: +def main(args: argparse.Namespace = get_args()) -> None: env, train_envs, test_envs = make_atari_env( args.task, args.seed, @@ -87,7 +87,7 @@ def test_fqf(args: argparse.Namespace = get_args()) -> None: 
np.random.seed(args.seed) torch.manual_seed(args.seed) # define model - feature_net = DQN(*args.state_shape, args.action_shape, args.device, features_only=True) + feature_net = DQNet(*args.state_shape, args.action_shape, args.device, features_only=True) net = FullQuantileFunction( feature_net, args.action_shape, @@ -228,4 +228,4 @@ def watch() -> None: if __name__ == "__main__": - test_fqf(get_args()) + main(get_args()) diff --git a/examples/atari/atari_iqn.py b/examples/atari/atari_iqn.py index 3d6ad57c4..869c0e158 100644 --- a/examples/atari/atari_iqn.py +++ b/examples/atari/atari_iqn.py @@ -6,10 +6,10 @@ import numpy as np import torch -from atari_network import DQN -from atari_wrapper import make_atari_env from tianshou.data import Collector, CollectStats, VectorReplayBuffer +from tianshou.env.atari.atari_network import DQNet +from tianshou.env.atari.atari_wrapper import make_atari_env from tianshou.highlevel.logger import LoggerFactoryDefault from tianshou.policy import IQNPolicy from tianshou.policy.base import Algorithm @@ -69,7 +69,7 @@ def get_args() -> argparse.Namespace: return parser.parse_args() -def test_iqn(args: argparse.Namespace = get_args()) -> None: +def main(args: argparse.Namespace = get_args()) -> None: env, train_envs, test_envs = make_atari_env( args.task, args.seed, @@ -87,7 +87,7 @@ def test_iqn(args: argparse.Namespace = get_args()) -> None: np.random.seed(args.seed) torch.manual_seed(args.seed) # define model - feature_net = DQN(*args.state_shape, args.action_shape, args.device, features_only=True) + feature_net = DQNet(*args.state_shape, args.action_shape, args.device, features_only=True) net = ImplicitQuantileNetwork( feature_net, args.action_shape, @@ -226,4 +226,4 @@ def watch() -> None: if __name__ == "__main__": - test_iqn(get_args()) + main(get_args()) diff --git a/examples/atari/atari_iqn_hl.py b/examples/atari/atari_iqn_hl.py index b71b0eef3..0cdfa4d81 100644 --- a/examples/atari/atari_iqn_hl.py +++ 
b/examples/atari/atari_iqn_hl.py @@ -6,10 +6,10 @@ from sensai.util import logging from sensai.util.logging import datetime_tag -from examples.atari.atari_network import ( +from tianshou.env.atari.atari_network import ( IntermediateModuleFactoryAtariDQN, ) -from examples.atari.atari_wrapper import AtariEnvFactory, AtariEpochStopCallback +from tianshou.env.atari.atari_wrapper import AtariEnvFactory, AtariEpochStopCallback from tianshou.highlevel.config import SamplingConfig from tianshou.highlevel.experiment import ( ExperimentConfig, diff --git a/examples/atari/atari_ppo.py b/examples/atari/atari_ppo.py index 3699aa7f1..7a48dc869 100644 --- a/examples/atari/atari_ppo.py +++ b/examples/atari/atari_ppo.py @@ -6,12 +6,12 @@ import numpy as np import torch -from atari_network import DQN, layer_init, scale_obs -from atari_wrapper import make_atari_env from torch.distributions import Categorical from torch.optim.lr_scheduler import LambdaLR from tianshou.data import Collector, CollectStats, VectorReplayBuffer +from tianshou.env.atari.atari_network import DQNet, layer_init, scale_obs +from tianshou.env.atari.atari_wrapper import make_atari_env from tianshou.highlevel.logger import LoggerFactoryDefault from tianshou.policy import ICMPolicy, PPOPolicy from tianshou.policy.base import Algorithm @@ -92,7 +92,7 @@ def get_args() -> argparse.Namespace: return parser.parse_args() -def test_ppo(args: argparse.Namespace = get_args()) -> None: +def main(args: argparse.Namespace = get_args()) -> None: env, train_envs, test_envs = make_atari_env( args.task, args.seed, @@ -110,7 +110,7 @@ def test_ppo(args: argparse.Namespace = get_args()) -> None: np.random.seed(args.seed) torch.manual_seed(args.seed) # define model - net = DQN( + net = DQNet( *args.state_shape, args.action_shape, device=args.device, @@ -156,7 +156,7 @@ def dist(logits: torch.Tensor) -> Categorical: recompute_advantage=args.recompute_adv, ).to(args.device) if args.icm_lr_scale > 0: - feature_net = 
DQN(*args.state_shape, args.action_shape, args.device, features_only=True) + feature_net = DQNet(*args.state_shape, args.action_shape, args.device, features_only=True) action_dim = np.prod(args.action_shape) feature_dim = feature_net.output_dim icm_net = IntrinsicCuriosityModule( @@ -285,4 +285,4 @@ def watch() -> None: if __name__ == "__main__": - test_ppo(get_args()) + main(get_args()) diff --git a/examples/atari/atari_ppo_hl.py b/examples/atari/atari_ppo_hl.py index 26ebaba08..e0939ecc8 100644 --- a/examples/atari/atari_ppo_hl.py +++ b/examples/atari/atari_ppo_hl.py @@ -6,11 +6,11 @@ from sensai.util import logging from sensai.util.logging import datetime_tag -from examples.atari.atari_network import ( +from tianshou.env.atari.atari_network import ( ActorFactoryAtariDQN, IntermediateModuleFactoryAtariDQNFeatures, ) -from examples.atari.atari_wrapper import AtariEnvFactory, AtariEpochStopCallback +from tianshou.env.atari.atari_wrapper import AtariEnvFactory, AtariEpochStopCallback from tianshou.highlevel.config import SamplingConfig from tianshou.highlevel.experiment import ( ExperimentConfig, diff --git a/examples/atari/atari_qrdqn.py b/examples/atari/atari_qrdqn.py index 33ef4f99f..f4496dbc8 100644 --- a/examples/atari/atari_qrdqn.py +++ b/examples/atari/atari_qrdqn.py @@ -6,10 +6,10 @@ import numpy as np import torch -from atari_network import QRDQNetwork -from atari_wrapper import make_atari_env from tianshou.data import Collector, CollectStats, VectorReplayBuffer +from tianshou.env.atari.atari_network import QRDQNet +from tianshou.env.atari.atari_wrapper import make_atari_env from tianshou.highlevel.logger import LoggerFactoryDefault from tianshou.policy import QRDQN from tianshou.policy.base import Algorithm @@ -64,7 +64,7 @@ def get_args() -> argparse.Namespace: return parser.parse_args() -def test_qrdqn(args: argparse.Namespace = get_args()) -> None: +def main(args: argparse.Namespace = get_args()) -> None: env, train_envs, test_envs = make_atari_env( 
args.task, args.seed, @@ -83,7 +83,7 @@ def test_qrdqn(args: argparse.Namespace = get_args()) -> None: torch.manual_seed(args.seed) # define model c, h, w = args.state_shape - net = QRDQNetwork( + net = QRDQNet( c=c, h=h, w=w, @@ -219,4 +219,4 @@ def watch() -> None: if __name__ == "__main__": - test_qrdqn(get_args()) + main(get_args()) diff --git a/examples/atari/atari_rainbow.py b/examples/atari/atari_rainbow.py index 60f07140c..71a6a97d8 100644 --- a/examples/atari/atari_rainbow.py +++ b/examples/atari/atari_rainbow.py @@ -6,8 +6,6 @@ import numpy as np import torch -from atari_network import Rainbow -from atari_wrapper import make_atari_env from tianshou.data import ( Collector, @@ -15,8 +13,10 @@ PrioritizedVectorReplayBuffer, VectorReplayBuffer, ) +from tianshou.env.atari.atari_network import Rainbow +from tianshou.env.atari.atari_wrapper import make_atari_env from tianshou.highlevel.logger import LoggerFactoryDefault -from tianshou.policy import C51Policy, RainbowPolicy +from tianshou.policy import C51, RainbowPolicy from tianshou.policy.base import Algorithm from tianshou.trainer import OffpolicyTrainer @@ -109,7 +109,7 @@ def test_rainbow(args: argparse.Namespace = get_args()) -> None: ) optim = torch.optim.Adam(net.parameters(), lr=args.lr) # define policy - policy: C51Policy = RainbowPolicy( + policy: C51 = RainbowPolicy( model=net, optim=optim, discount_factor=args.gamma, diff --git a/examples/atari/atari_sac.py b/examples/atari/atari_sac.py index 48984b5b2..24e7d6c86 100644 --- a/examples/atari/atari_sac.py +++ b/examples/atari/atari_sac.py @@ -6,10 +6,10 @@ import numpy as np import torch -from atari_network import DQN -from atari_wrapper import make_atari_env from tianshou.data import Collector, CollectStats, VectorReplayBuffer +from tianshou.env.atari.atari_network import DQNet +from tianshou.env.atari.atari_wrapper import make_atari_env from tianshou.highlevel.logger import LoggerFactoryDefault from tianshou.policy import DiscreteSACPolicy, 
ICMPolicy from tianshou.policy.base import Algorithm @@ -103,7 +103,7 @@ def test_discrete_sac(args: argparse.Namespace = get_args()) -> None: np.random.seed(args.seed) torch.manual_seed(args.seed) # define model - net = DQN( + net = DQNet( *args.state_shape, args.action_shape, device=args.device, @@ -139,7 +139,7 @@ def test_discrete_sac(args: argparse.Namespace = get_args()) -> None: estimation_step=args.n_step, ).to(args.device) if args.icm_lr_scale > 0: - feature_net = DQN(*args.state_shape, args.action_shape, args.device, features_only=True) + feature_net = DQNet(*args.state_shape, args.action_shape, args.device, features_only=True) action_dim = np.prod(args.action_shape) feature_dim = feature_net.output_dim icm_net = IntrinsicCuriosityModule( diff --git a/examples/atari/atari_sac_hl.py b/examples/atari/atari_sac_hl.py index 124def768..23186606a 100644 --- a/examples/atari/atari_sac_hl.py +++ b/examples/atari/atari_sac_hl.py @@ -6,11 +6,11 @@ from sensai.util import logging from sensai.util.logging import datetime_tag -from examples.atari.atari_network import ( +from tianshou.env.atari.atari_network import ( ActorFactoryAtariDQN, IntermediateModuleFactoryAtariDQNFeatures, ) -from examples.atari.atari_wrapper import AtariEnvFactory, AtariEpochStopCallback +from tianshou.env.atari.atari_wrapper import AtariEnvFactory, AtariEpochStopCallback from tianshou.highlevel.config import SamplingConfig from tianshou.highlevel.experiment import ( DiscreteSACExperimentBuilder, diff --git a/examples/offline/atari_bcq.py b/examples/offline/atari_bcq.py index 04dde73bb..310087427 100644 --- a/examples/offline/atari_bcq.py +++ b/examples/offline/atari_bcq.py @@ -11,7 +11,7 @@ import torch from gymnasium.spaces import Discrete -from examples.atari.atari_network import DQN +from examples.atari.atari_network import DQNet from examples.atari.atari_wrapper import make_atari_env from examples.offline.utils import load_buffer from tianshou.data import Collector, CollectStats, 
VectorReplayBuffer @@ -96,7 +96,7 @@ def test_discrete_bcq(args: argparse.Namespace = get_args()) -> None: assert args.state_shape is not None assert len(args.state_shape) == 3 c, h, w = args.state_shape - feature_net = DQN( + feature_net = DQNet( c, h, w, diff --git a/examples/offline/atari_cql.py b/examples/offline/atari_cql.py index e16bfbd81..48ce2faaf 100644 --- a/examples/offline/atari_cql.py +++ b/examples/offline/atari_cql.py @@ -12,7 +12,7 @@ import torch from gymnasium.spaces import Discrete -from examples.atari.atari_network import QRDQNetwork +from examples.atari.atari_network import QRDQNet from examples.atari.atari_wrapper import make_atari_env from examples.offline.utils import load_buffer from tianshou.data import Collector, CollectStats, VectorReplayBuffer @@ -97,7 +97,7 @@ def test_discrete_cql(args: argparse.Namespace = get_args()) -> None: np.random.seed(args.seed) torch.manual_seed(args.seed) # model - net = QRDQNetwork( + net = QRDQNet( c=c, h=h, w=w, diff --git a/examples/offline/atari_crr.py b/examples/offline/atari_crr.py index 833e7a008..cd77730f6 100644 --- a/examples/offline/atari_crr.py +++ b/examples/offline/atari_crr.py @@ -11,7 +11,7 @@ import torch from gymnasium.spaces import Discrete -from examples.atari.atari_network import DQN +from examples.atari.atari_network import DQNet from examples.atari.atari_wrapper import make_atari_env from examples.offline.utils import load_buffer from tianshou.data import Collector, CollectStats, VectorReplayBuffer @@ -98,7 +98,7 @@ def test_discrete_crr(args: argparse.Namespace = get_args()) -> None: assert args.state_shape is not None assert len(args.state_shape) == 3 c, h, w = args.state_shape - feature_net = DQN( + feature_net = DQNet( c, h, w, diff --git a/examples/offline/atari_il.py b/examples/offline/atari_il.py index 45d5c5c44..ffc7c5457 100644 --- a/examples/offline/atari_il.py +++ b/examples/offline/atari_il.py @@ -10,7 +10,7 @@ import numpy as np import torch -from 
examples.atari.atari_network import DQN +from examples.atari.atari_network import DQNet from examples.atari.atari_wrapper import make_atari_env from examples.offline.utils import load_buffer from tianshou.data import Collector, CollectStats, VectorReplayBuffer @@ -87,7 +87,7 @@ def test_il(args: argparse.Namespace = get_args()) -> None: np.random.seed(args.seed) torch.manual_seed(args.seed) # model - net = DQN(c, h, w, args.action_shape, device=args.device).to(args.device) + net = DQNet(c, h, w, args.action_shape, device=args.device).to(args.device) optim = torch.optim.Adam(net.parameters(), lr=args.lr) # define policy policy: ImitationPolicy = ImitationPolicy(actor=net, optim=optim, action_space=env.action_space) diff --git a/examples/vizdoom/network.py b/examples/vizdoom/network.py deleted file mode 120000 index a0c543acb..000000000 --- a/examples/vizdoom/network.py +++ /dev/null @@ -1 +0,0 @@ -../atari/atari_network.py \ No newline at end of file diff --git a/examples/vizdoom/vizdoom_c51.py b/examples/vizdoom/vizdoom_c51.py index b741e7d7a..c81cb4fdf 100644 --- a/examples/vizdoom/vizdoom_c51.py +++ b/examples/vizdoom/vizdoom_c51.py @@ -7,11 +7,11 @@ import numpy as np import torch from env import make_vizdoom_env -from network import C51 from tianshou.data import Collector, CollectStats, VectorReplayBuffer +from tianshou.env.atari.atari_network import C51Net from tianshou.highlevel.logger import LoggerFactoryDefault -from tianshou.policy import C51Policy +from tianshou.policy import C51 from tianshou.policy.base import Algorithm from tianshou.trainer import OffpolicyTrainer @@ -92,10 +92,10 @@ def test_c51(args: argparse.Namespace = get_args()) -> None: np.random.seed(args.seed) torch.manual_seed(args.seed) # define model - net = C51(*args.state_shape, args.action_shape, args.num_atoms, args.device) + net = C51Net(*args.state_shape, args.action_shape, args.num_atoms, args.device) optim = torch.optim.Adam(net.parameters(), lr=args.lr) # define policy - policy: 
C51Policy = C51Policy( + policy: C51 = C51( model=net, optim=optim, discount_factor=args.gamma, diff --git a/examples/vizdoom/vizdoom_ppo.py b/examples/vizdoom/vizdoom_ppo.py index d858305b8..426ca08ea 100644 --- a/examples/vizdoom/vizdoom_ppo.py +++ b/examples/vizdoom/vizdoom_ppo.py @@ -7,11 +7,11 @@ import numpy as np import torch from env import make_vizdoom_env -from network import DQN from torch.distributions import Categorical from torch.optim.lr_scheduler import LambdaLR from tianshou.data import Collector, CollectStats, VectorReplayBuffer +from tianshou.env.atari.atari_network import DQNet from tianshou.highlevel.logger import LoggerFactoryDefault from tianshou.policy import ICMPolicy, PPOPolicy from tianshou.policy.base import Algorithm @@ -118,7 +118,7 @@ def test_ppo(args: argparse.Namespace = get_args()) -> None: np.random.seed(args.seed) torch.manual_seed(args.seed) # define model - net = DQN( + net = DQNet( *args.state_shape, args.action_shape, device=args.device, @@ -161,7 +161,7 @@ def dist(logits: torch.Tensor) -> Categorical: recompute_advantage=args.recompute_adv, ).to(args.device) if args.icm_lr_scale > 0: - feature_net = DQN( + feature_net = DQNet( *args.state_shape, args.action_shape, device=args.device, diff --git a/test/discrete/test_c51.py b/test/discrete/test_c51.py index 8d7bd5b6b..96e4628f9 100644 --- a/test/discrete/test_c51.py +++ b/test/discrete/test_c51.py @@ -15,9 +15,10 @@ VectorReplayBuffer, ) from tianshou.env import DummyVectorEnv -from tianshou.policy import C51Policy +from tianshou.policy import C51 from tianshou.policy.base import Algorithm -from tianshou.trainer import OffpolicyTrainer +from tianshou.policy.modelfree.c51 import C51Policy +from tianshou.trainer.base import OffPolicyTrainingConfig from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Net from tianshou.utils.space_info import SpaceInfo @@ -93,14 +94,18 @@ def test_c51(args: argparse.Namespace = get_args()) -> None: 
num_atoms=args.num_atoms, ) optim = torch.optim.Adam(net.parameters(), lr=args.lr) - policy: C51Policy = C51Policy( + policy = C51Policy( model=net, - optim=optim, action_space=env.action_space, - discount_factor=args.gamma, + observation_space=env.observation_space, num_atoms=args.num_atoms, v_min=args.v_min, v_max=args.v_max, + ) + algorithm: C51 = C51( + policy=policy, + optim=optim, + discount_factor=args.gamma, estimation_step=args.n_step, target_update_freq=args.target_update_freq, ).to(args.device) @@ -116,8 +121,8 @@ def test_c51(args: argparse.Namespace = get_args()) -> None: else: buf = VectorReplayBuffer(args.buffer_size, buffer_num=len(train_envs)) # collector - train_collector = Collector[CollectStats](policy, train_envs, buf, exploration_noise=True) - test_collector = Collector[CollectStats](policy, test_envs, exploration_noise=True) + train_collector = Collector[CollectStats](algorithm, train_envs, buf, exploration_noise=True) + test_collector = Collector[CollectStats](algorithm, test_envs, exploration_noise=True) # policy.set_eps(1) train_collector.reset() train_collector.collect(n_step=args.batch_size * args.training_num) @@ -126,8 +131,8 @@ def test_c51(args: argparse.Namespace = get_args()) -> None: writer = SummaryWriter(log_path) logger = TensorboardLogger(writer, save_interval=args.save_interval) - def save_best_fn(policy: Algorithm) -> None: - torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth")) + def save_best_fn(algorithm: Algorithm) -> None: + torch.save(algorithm.state_dict(), os.path.join(log_path, "policy.pth")) def stop_fn(mean_rewards: float) -> bool: return mean_rewards >= args.reward_threshold @@ -135,15 +140,15 @@ def stop_fn(mean_rewards: float) -> bool: def train_fn(epoch: int, env_step: int) -> None: # eps annnealing, just a demo if env_step <= 10000: - policy.set_eps(args.eps_train) + algorithm.set_eps(args.eps_train) elif env_step <= 50000: eps = args.eps_train - (env_step - 10000) / 40000 * (0.9 * 
args.eps_train) - policy.set_eps(eps) + algorithm.set_eps(eps) else: - policy.set_eps(0.1 * args.eps_train) + algorithm.set_eps(0.1 * args.eps_train) def test_fn(epoch: int, env_step: int | None) -> None: - policy.set_eps(args.eps_test) + algorithm.set_eps(args.eps_test) def save_checkpoint_fn(epoch: int, env_step: int, gradient_step: int) -> str: # see also: https://pytorch.org/tutorials/beginner/saving_loading_models.html @@ -152,7 +157,7 @@ def save_checkpoint_fn(epoch: int, env_step: int, gradient_step: int) -> str: # ckpt_path = os.path.join(log_path, f"checkpoint_{epoch}.pth") torch.save( { - "model": policy.state_dict(), + "model": algorithm.state_dict(), "optim": optim.state_dict(), }, ckpt_path, @@ -168,8 +173,8 @@ def save_checkpoint_fn(epoch: int, env_step: int, gradient_step: int) -> str: ckpt_path = os.path.join(log_path, "checkpoint.pth") if os.path.exists(ckpt_path): checkpoint = torch.load(ckpt_path, map_location=args.device) - policy.load_state_dict(checkpoint["model"]) - policy.optim.load_state_dict(checkpoint["optim"]) + algorithm.load_state_dict(checkpoint["model"]) + algorithm.optim.load_state_dict(checkpoint["optim"]) print("Successfully restore policy and optim.") else: print("Fail to restore policy and optim.") @@ -181,25 +186,26 @@ def save_checkpoint_fn(epoch: int, env_step: int, gradient_step: int) -> str: else: print("Fail to restore buffer.") - # trainer - result = OffpolicyTrainer( - policy=policy, - train_collector=train_collector, - test_collector=test_collector, - max_epoch=args.epoch, - step_per_epoch=args.step_per_epoch, - step_per_collect=args.step_per_collect, - episode_per_test=args.test_num, - batch_size=args.batch_size, - update_per_step=args.update_per_step, - train_fn=train_fn, - test_fn=test_fn, - stop_fn=stop_fn, - save_best_fn=save_best_fn, - logger=logger, - resume_from_log=args.resume, - save_checkpoint_fn=save_checkpoint_fn, - ).run() + # train + result = algorithm.run_training( + OffPolicyTrainingConfig( + 
train_collector=train_collector, + test_collector=test_collector, + max_epoch=args.epoch, + step_per_epoch=args.step_per_epoch, + step_per_collect=args.step_per_collect, + episode_per_test=args.test_num, + batch_size=args.batch_size, + update_per_step=args.update_per_step, + train_fn=train_fn, + test_fn=test_fn, + stop_fn=stop_fn, + save_best_fn=save_best_fn, + logger=logger, + resume_from_log=args.resume, + save_checkpoint_fn=save_checkpoint_fn, + ) + ) assert stop_fn(result.best_reward) diff --git a/examples/atari/atari_network.py b/tianshou/env/atari/atari_network.py similarity index 98% rename from examples/atari/atari_network.py rename to tianshou/env/atari/atari_network.py index 54e81d12f..3b1d90d64 100644 --- a/examples/atari/atari_network.py +++ b/tianshou/env/atari/atari_network.py @@ -21,6 +21,7 @@ def layer_init(layer: nn.Module, std: float = np.sqrt(2), bias_const: float = 0.0) -> nn.Module: + """TODO.""" torch.nn.init.orthogonal_(layer.weight, std) torch.nn.init.constant_(layer.bias, bias_const) return layer @@ -46,10 +47,11 @@ def forward( def scale_obs(module: NetBase, denom: float = 255.0) -> ScaledObsInputModule: + """TODO.""" return ScaledObsInputModule(module, denom=denom) -class DQN(NetBase[Any]): +class DQNet(NetBase[Any]): """Reference: Human-level control through deep reinforcement learning. For advanced usage (how to customize the network), please refer to @@ -116,7 +118,7 @@ def forward( return self.net(obs), state -class C51(DQN): +class C51Net(DQNet): """Reference: A distributional perspective on reinforcement learning. For advanced usage (how to customize the network), please refer to @@ -150,7 +152,7 @@ def forward( return obs, state -class Rainbow(DQN): +class Rainbow(DQNet): """Reference: Rainbow: Combining Improvements in Deep Reinforcement Learning. 
For advanced usage (how to customize the network), please refer to @@ -213,7 +215,7 @@ def forward( return probs, state -class QRDQNetwork(DQN): +class QRDQNet(DQNet): """Reference: Distributional Reinforcement Learning with Quantile Regression. For advanced usage (how to customize the network), please refer to @@ -265,8 +267,8 @@ def create_module(self, envs: Environments, device: TDevice) -> Actor: action_shape = envs.get_action_shape() if isinstance(action_shape, np.int64): action_shape = int(action_shape) - net: DQN | ScaledObsInputModule - net = DQN( + net: DQNet | ScaledObsInputModule + net = DQNet( c=c, h=h, w=w, @@ -305,7 +307,7 @@ def create_intermediate_module(self, envs: Environments, device: TDevice) -> Int action_shape = envs.get_action_shape() if isinstance(action_shape, np.int64): action_shape = int(action_shape) - dqn = DQN( + dqn = DQNet( c=c, h=h, w=w, diff --git a/examples/atari/atari_wrapper.py b/tianshou/env/atari/atari_wrapper.py similarity index 99% rename from examples/atari/atari_wrapper.py rename to tianshou/env/atari/atari_wrapper.py index d7234d863..de375eaa3 100644 --- a/examples/atari/atari_wrapper.py +++ b/tianshou/env/atari/atari_wrapper.py @@ -40,6 +40,7 @@ def _parse_reset_result(reset_result: tuple) -> tuple[tuple, dict, bool]: def get_space_dtype(obs_space: gym.spaces.Box) -> type[np.floating] | type[np.integer]: + """TODO.""" obs_space_dtype: type[np.integer] | type[np.floating] if np.issubdtype(obs_space.dtype, np.integer): obs_space_dtype = np.integer diff --git a/tianshou/policy/__init__.py b/tianshou/policy/__init__.py index 0444c7bbc..dcb314b69 100644 --- a/tianshou/policy/__init__.py +++ b/tianshou/policy/__init__.py @@ -8,7 +8,7 @@ from tianshou.policy.random import MARLRandomPolicy from tianshou.policy.modelfree.bdq import BranchingDQNPolicy -from tianshou.policy.modelfree.c51 import C51Policy +from tianshou.policy.modelfree.c51 import C51 from tianshou.policy.modelfree.rainbow import RainbowPolicy from 
tianshou.policy.modelfree.qrdqn import QRDQN from tianshou.policy.modelfree.iqn import IQNPolicy @@ -38,7 +38,7 @@ "MARLRandomPolicy", "DeepQLearning", "BranchingDQNPolicy", - "C51Policy", + "C51", "RainbowPolicy", "QRDQN", "IQNPolicy", diff --git a/tianshou/policy/modelfree/c51.py b/tianshou/policy/modelfree/c51.py index f4d4d3156..4905ae079 100644 --- a/tianshou/policy/modelfree/c51.py +++ b/tianshou/policy/modelfree/c51.py @@ -9,7 +9,8 @@ from tianshou.data.types import RolloutBatchProtocol from tianshou.policy import DeepQLearning from tianshou.policy.base import TLearningRateScheduler -from tianshou.policy.modelfree.dqn import DQNTrainingStats +from tianshou.policy.modelfree.dqn import DQNPolicy, DQNTrainingStats +from tianshou.utils.net.common import Net @dataclass(kw_only=True) @@ -20,18 +21,47 @@ class C51TrainingStats(DQNTrainingStats): TC51TrainingStats = TypeVar("TC51TrainingStats", bound=C51TrainingStats) -class C51Policy(DeepQLearning, Generic[TC51TrainingStats]): - """Implementation of Categorical Deep Q-Network. arXiv:1707.06887. +class C51Policy(DQNPolicy): + def __init__( + self, + model: torch.nn.Module | Net, + action_space: gym.spaces.Discrete, + observation_space: gym.Space | None = None, + num_atoms: int = 51, + v_min: float = -10.0, + v_max: float = 10.0, + ): + """ + :param model: a model following the rules (s_B -> action_values_BA) + :param num_atoms: the number of atoms in the support set of the + value distribution. Default to 51. + :param v_min: the value of the smallest atom in the support set. + Default to -10.0. + :param v_max: the value of the largest atom in the support set. + Default to 10.0. 
+ """ + super().__init__( + model=model, action_space=action_space, observation_space=observation_space + ) + assert num_atoms > 1, f"num_atoms should be greater than 1 but got: {num_atoms}" + assert v_min < v_max, f"v_max should be larger than v_min, but got {v_min=} and {v_max=}" + self.num_atoms = num_atoms + self.v_min = v_min + self.v_max = v_max + self.support = torch.nn.Parameter( + torch.linspace(self.v_min, self.v_max, self.num_atoms), + requires_grad=False, + ) - :param model: a model following the rules (s_B -> action_values_BA) - :param optim: a torch.optim for optimizing the model. + def compute_q_value(self, logits: torch.Tensor, mask: np.ndarray | None) -> torch.Tensor: + return super().compute_q_value((logits * self.support).sum(2), mask) + + +class C51(DeepQLearning[C51Policy, TC51TrainingStats], Generic[TC51TrainingStats]): + """Implementation of Categorical Deep Q-Network. arXiv:1707.06887. + :param policy: a policy following the rules (s -> action_values_BA) + :param optim: a torch.optim for optimizing the policy. :param discount_factor: in [0, 1]. - :param num_atoms: the number of atoms in the support set of the - value distribution. Default to 51. - :param v_min: the value of the smallest atom in the support set. - Default to -10.0. - :param v_max: the value of the largest atom in the support set. - Default to 10.0. :param estimation_step: the number of steps to look ahead. :param target_update_freq: the target network update frequency (0 if you do not use the target network). 
@@ -53,66 +83,49 @@ class C51Policy(DeepQLearning, Generic[TC51TrainingStats]): def __init__( self, *, - model: torch.nn.Module, + policy: C51Policy, optim: torch.optim.Optimizer, - action_space: gym.spaces.Discrete, discount_factor: float = 0.99, - num_atoms: int = 51, - v_min: float = -10.0, - v_max: float = 10.0, estimation_step: int = 1, target_update_freq: int = 0, reward_normalization: bool = False, is_double: bool = True, clip_loss_grad: bool = False, - observation_space: gym.Space | None = None, lr_scheduler: TLearningRateScheduler | None = None, ) -> None: - assert num_atoms > 1, f"num_atoms should be greater than 1 but got: {num_atoms}" - assert v_min < v_max, f"v_max should be larger than v_min, but got {v_min=} and {v_max=}" super().__init__( - model=model, + policy=policy, optim=optim, - action_space=action_space, discount_factor=discount_factor, estimation_step=estimation_step, target_update_freq=target_update_freq, reward_normalization=reward_normalization, is_double=is_double, clip_loss_grad=clip_loss_grad, - observation_space=observation_space, lr_scheduler=lr_scheduler, ) - self._num_atoms = num_atoms - self._v_min = v_min - self._v_max = v_max - self.support = torch.nn.Parameter( - torch.linspace(self._v_min, self._v_max, self._num_atoms), - requires_grad=False, - ) - self.delta_z = (v_max - v_min) / (num_atoms - 1) + self.delta_z = (policy.v_max - policy.v_min) / (policy.num_atoms - 1) def _target_q(self, buffer: ReplayBuffer, indices: np.ndarray) -> torch.Tensor: - return self.support.repeat(len(indices), 1) # shape: [bsz, num_atoms] - - def compute_q_value(self, logits: torch.Tensor, mask: np.ndarray | None) -> torch.Tensor: - return super().compute_q_value((logits * self.support).sum(2), mask) + return self.policy.support.repeat(len(indices), 1) # shape: [bsz, num_atoms] def _target_dist(self, batch: RolloutBatchProtocol) -> torch.Tensor: obs_next_batch = Batch(obs=batch.obs_next, info=[None] * len(batch)) if self._target: - act = 
self(obs_next_batch).act - next_dist = self(obs_next_batch, model="model_old").logits + act = self.policy(obs_next_batch).act + next_dist = self.policy(obs_next_batch, model=self.model_old).logits else: - next_batch = self(obs_next_batch) + next_batch = self.policy(obs_next_batch) act = next_batch.act next_dist = next_batch.logits next_dist = next_dist[np.arange(len(act)), act, :] - target_support = batch.returns.clamp(self._v_min, self._v_max) + target_support = batch.returns.clamp(self.policy.v_min, self.policy.v_max) # An amazing trick for calculating the projection gracefully. # ref: https://github.com/ShangtongZhang/DeepRL target_dist = ( - 1 - (target_support.unsqueeze(1) - self.support.view(1, -1, 1)).abs() / self.delta_z + 1 + - (target_support.unsqueeze(1) - self.policy.support.view(1, -1, 1)).abs() + / self.delta_z ).clamp(0, 1) * next_dist.unsqueeze(1) return target_dist.sum(-1) @@ -128,7 +141,7 @@ def _update_with_batch( with torch.no_grad(): target_dist = self._target_dist(batch) weight = batch.pop("weight", 1.0) - curr_dist = self(batch).logits + curr_dist = self.policy(batch).logits act = batch.act curr_dist = curr_dist[np.arange(len(act)), act, :] cross_entropy = -(target_dist * torch.log(curr_dist + 1e-8)).sum(1) diff --git a/tianshou/policy/modelfree/rainbow.py b/tianshou/policy/modelfree/rainbow.py index fc5b6637f..fc0af2cf4 100644 --- a/tianshou/policy/modelfree/rainbow.py +++ b/tianshou/policy/modelfree/rainbow.py @@ -4,7 +4,7 @@ from torch import nn from tianshou.data.types import RolloutBatchProtocol -from tianshou.policy import C51Policy +from tianshou.policy import C51 from tianshou.policy.modelfree.c51 import C51TrainingStats from tianshou.utils.net.discrete import NoisyLinear @@ -36,7 +36,7 @@ class RainbowTrainingStats(C51TrainingStats): # TODO: is this class worth keeping? It barely does anything -class RainbowPolicy(C51Policy[TRainbowTrainingStats]): +class RainbowPolicy(C51[TRainbowTrainingStats]): """Implementation of Rainbow DQN. 
arXiv:1710.02298. Same parameters as :class:`~tianshou.policy.C51Policy`. From 6eb3170633d2d6f37b515f5d5ea61e0fd34598d1 Mon Sep 17 00:00:00 2001 From: Dominik Jain Date: Wed, 5 Mar 2025 13:10:31 +0100 Subject: [PATCH 013/230] v2: Adapt A2C and test_a2c_with_il (skipping the il part) --- examples/mujoco/mujoco_a2c.py | 4 +-- test/discrete/test_a2c_with_il.py | 52 +++++++++++++++++-------------- tianshou/highlevel/algorithm.py | 8 ++--- tianshou/policy/__init__.py | 4 +-- tianshou/policy/modelfree/a2c.py | 42 +++++-------------------- tianshou/policy/modelfree/npg.py | 4 +-- tianshou/policy/modelfree/pg.py | 52 +++++++++++++++++-------------- tianshou/policy/modelfree/ppo.py | 4 +-- tianshou/utils/net/discrete.py | 1 + 9 files changed, 78 insertions(+), 93 deletions(-) diff --git a/examples/mujoco/mujoco_a2c.py b/examples/mujoco/mujoco_a2c.py index 80325e40c..4069858c4 100755 --- a/examples/mujoco/mujoco_a2c.py +++ b/examples/mujoco/mujoco_a2c.py @@ -14,7 +14,7 @@ from tianshou.data import Collector, CollectStats, ReplayBuffer, VectorReplayBuffer from tianshou.highlevel.logger import LoggerFactoryDefault -from tianshou.policy import A2CPolicy +from tianshou.policy import A2C from tianshou.policy.base import Algorithm from tianshou.trainer import OnpolicyTrainer from tianshou.utils.net.common import ActorCritic, Net @@ -141,7 +141,7 @@ def dist(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution: loc, scale = loc_scale return Independent(Normal(loc, scale), 1) - policy: A2CPolicy = A2CPolicy( + policy: A2C = A2C( actor=actor, critic=critic, optim=optim, diff --git a/test/discrete/test_a2c_with_il.py b/test/discrete/test_a2c_with_il.py index aafbac62a..eab6407df 100644 --- a/test/discrete/test_a2c_with_il.py +++ b/test/discrete/test_a2c_with_il.py @@ -9,9 +9,11 @@ from tianshou.data import Collector, CollectStats, VectorReplayBuffer from tianshou.env import DummyVectorEnv, SubprocVectorEnv -from tianshou.policy import A2CPolicy, ImitationPolicy +from 
tianshou.policy import A2C, ImitationPolicy from tianshou.policy.base import Algorithm -from tianshou.trainer import OffpolicyTrainer, OnpolicyTrainer +from tianshou.policy.modelfree.pg import ActorPolicy +from tianshou.trainer import OffpolicyTrainer +from tianshou.trainer.base import OnPolicyTrainingConfig from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import ActorCritic, Net from tianshou.utils.net.discrete import Actor, Critic @@ -94,29 +96,31 @@ def test_a2c_with_il(args: argparse.Namespace = get_args()) -> None: critic = Critic(net, device=args.device).to(args.device) optim = torch.optim.Adam(ActorCritic(actor, critic).parameters(), lr=args.lr) dist = torch.distributions.Categorical - policy: Algorithm - policy = A2CPolicy( + policy = ActorPolicy( actor=actor, - critic=critic, - optim=optim, dist_fn=dist, action_scaling=isinstance(env.action_space, Box), + action_space=env.action_space, + ) + algorithm = A2C( + policy=policy, + critic=critic, + optim=optim, discount_factor=args.gamma, gae_lambda=args.gae_lambda, vf_coef=args.vf_coef, ent_coef=args.ent_coef, max_grad_norm=args.max_grad_norm, reward_normalization=args.rew_norm, - action_space=env.action_space, ) # collector train_collector = Collector[CollectStats]( - policy, + algorithm, train_envs, VectorReplayBuffer(args.buffer_size, len(train_envs)), ) train_collector.reset() - test_collector = Collector[CollectStats](policy, test_envs) + test_collector = Collector[CollectStats](algorithm, test_envs) test_collector.reset() # log log_path = os.path.join(args.logdir, args.task, "a2c") @@ -130,20 +134,22 @@ def stop_fn(mean_rewards: float) -> bool: return mean_rewards >= args.reward_threshold # trainer - result = OnpolicyTrainer( - policy=policy, - train_collector=train_collector, - test_collector=test_collector, - max_epoch=args.epoch, - step_per_epoch=args.step_per_epoch, - repeat_per_collect=args.repeat_per_collect, - episode_per_test=args.test_num, - batch_size=args.batch_size, 
- episode_per_collect=args.episode_per_collect, - stop_fn=stop_fn, - save_best_fn=save_best_fn, - logger=logger, - ).run() + result = algorithm.run_training( + OnPolicyTrainingConfig( + train_collector=train_collector, + test_collector=test_collector, + max_epoch=args.epoch, + step_per_epoch=args.step_per_epoch, + repeat_per_collect=args.repeat_per_collect, + episode_per_test=args.test_num, + batch_size=args.batch_size, + episode_per_collect=args.episode_per_collect, + step_per_collect=None, + stop_fn=stop_fn, + save_best_fn=save_best_fn, + logger=logger, + ) + ) assert stop_fn(result.best_reward) # here we define an imitation collector with a trivial policy diff --git a/tianshou/highlevel/algorithm.py b/tianshou/highlevel/algorithm.py index fa0ba4c3e..6601f7c62 100644 --- a/tianshou/highlevel/algorithm.py +++ b/tianshou/highlevel/algorithm.py @@ -46,8 +46,8 @@ from tianshou.highlevel.trainer import TrainerCallbacks, TrainingContext from tianshou.highlevel.world import World from tianshou.policy import ( + A2C, DDPG, - A2CPolicy, Algorithm, DeepQLearning, DiscreteSACPolicy, @@ -383,9 +383,9 @@ def _create_algorithm(self, envs: Environments, device: TDevice) -> TAlgorithm: return policy_class(**self._create_kwargs(envs, device)) -class A2CAlgorithmFactory(ActorCriticAlgorithmFactory[A2CParams, A2CPolicy]): - def _get_policy_class(self) -> type[A2CPolicy]: - return A2CPolicy +class A2CAlgorithmFactory(ActorCriticAlgorithmFactory[A2CParams, A2C]): + def _get_policy_class(self) -> type[A2C]: + return A2C class PPOAlgorithmFactory(ActorCriticAlgorithmFactory[PPOParams, PPOPolicy]): diff --git a/tianshou/policy/__init__.py b/tianshou/policy/__init__.py index dcb314b69..6a37e949d 100644 --- a/tianshou/policy/__init__.py +++ b/tianshou/policy/__init__.py @@ -13,7 +13,7 @@ from tianshou.policy.modelfree.qrdqn import QRDQN from tianshou.policy.modelfree.iqn import IQNPolicy from tianshou.policy.modelfree.fqf import FQFPolicy -from tianshou.policy.modelfree.a2c import 
A2CPolicy +from tianshou.policy.modelfree.a2c import A2C from tianshou.policy.modelfree.npg import NPGPolicy from tianshou.policy.modelfree.ppo import PPOPolicy from tianshou.policy.modelfree.trpo import TRPOPolicy @@ -44,7 +44,7 @@ "IQNPolicy", "FQFPolicy", "Reinforce", - "A2CPolicy", + "A2C", "NPGPolicy", "DDPG", "PPOPolicy", diff --git a/tianshou/policy/modelfree/a2c.py b/tianshou/policy/modelfree/a2c.py index d0eb10b4c..046315509 100644 --- a/tianshou/policy/modelfree/a2c.py +++ b/tianshou/policy/modelfree/a2c.py @@ -1,7 +1,6 @@ from dataclasses import dataclass -from typing import Any, Generic, Literal, TypeVar, cast +from typing import Any, Generic, TypeVar, cast -import gymnasium as gym import numpy as np import torch import torch.nn.functional as F @@ -11,10 +10,9 @@ from tianshou.data.types import BatchWithAdvantagesProtocol, RolloutBatchProtocol from tianshou.policy import Reinforce from tianshou.policy.base import TLearningRateScheduler, TrainingStats -from tianshou.policy.modelfree.pg import TDistFnDiscrOrCont +from tianshou.policy.modelfree.pg import ActorPolicy from tianshou.utils.net.common import ActorCritic -from tianshou.utils.net.continuous import ActorProb, Critic -from tianshou.utils.net.discrete import Actor as DiscreteActor +from tianshou.utils.net.continuous import Critic from tianshou.utils.net.discrete import Critic as DiscreteCritic @@ -29,17 +27,11 @@ class A2CTrainingStats(TrainingStats): TA2CTrainingStats = TypeVar("TA2CTrainingStats", bound=A2CTrainingStats) -# TODO: the type ignore here is needed b/c the hierarchy is actually broken! Should reconsider the inheritance structure. -class A2CPolicy(Reinforce[TA2CTrainingStats], Generic[TA2CTrainingStats]): # type: ignore[type-var] +class A2C(Reinforce[TA2CTrainingStats], Generic[TA2CTrainingStats]): """Implementation of Synchronous Advantage Actor-Critic. arXiv:1602.01783. 
- :param actor: the actor network following the rules: - If `self.action_type == "discrete"`: (`s_B` ->`action_values_BA`). - If `self.action_type == "continuous"`: (`s_B` -> `dist_input_BD`). :param critic: the critic network. (s -> V(s)) :param optim: the optimizer for actor and critic network. - :param dist_fn: distribution class for computing the action. - :param action_space: env's action space :param vf_coef: weight for value loss. :param ent_coef: weight for entropy loss. :param max_grad_norm: clipping gradients in back propagation. @@ -47,12 +39,6 @@ class A2CPolicy(Reinforce[TA2CTrainingStats], Generic[TA2CTrainingStats]): # ty :param max_batchsize: the maximum size of the batch when computing GAE. :param discount_factor: in [0, 1]. :param reward_normalization: normalize estimated values to have std close to 1. - :param deterministic_eval: if True, use deterministic evaluation. - :param observation_space: the space of the observation. - :param action_scaling: if True, scale the action from [-1, 1] to the range of - action_space. Only used if the action_space is continuous. - :param action_bound_method: method to bound action to range [-1, 1]. - Only used if the action_space is continuous. :param lr_scheduler: if not None, will be called in `policy.update()`. .. seealso:: @@ -64,11 +50,9 @@ class A2CPolicy(Reinforce[TA2CTrainingStats], Generic[TA2CTrainingStats]): # ty def __init__( self, *, - actor: torch.nn.Module | ActorProb | DiscreteActor, + policy: ActorPolicy, critic: torch.nn.Module | Critic | DiscreteCritic, optim: torch.optim.Optimizer, - dist_fn: TDistFnDiscrOrCont, - action_space: gym.Space, vf_coef: float = 0.5, ent_coef: float = 0.01, max_grad_norm: float | None = None, @@ -77,23 +61,13 @@ def __init__( discount_factor: float = 0.99, # TODO: rename to return_normalization? 
reward_normalization: bool = False, - deterministic_eval: bool = False, - observation_space: gym.Space | None = None, - action_scaling: bool = True, - action_bound_method: Literal["clip", "tanh"] | None = "clip", lr_scheduler: TLearningRateScheduler | None = None, ) -> None: super().__init__( - actor=actor, + policy=policy, optim=optim, - dist_fn=dist_fn, - action_space=action_space, discount_factor=discount_factor, reward_normalization=reward_normalization, - deterministic_eval=deterministic_eval, - observation_space=observation_space, - action_scaling=action_scaling, - action_bound_method=action_bound_method, lr_scheduler=lr_scheduler, ) self.critic = critic @@ -103,7 +77,7 @@ def __init__( self.ent_coef = ent_coef self.max_grad_norm = max_grad_norm self.max_batchsize = max_batchsize - self._actor_critic = ActorCritic(self.actor, self.critic) + self._actor_critic = ActorCritic(self.policy.actor, self.critic) def process_fn( self, @@ -170,7 +144,7 @@ def _update_with_batch( # type: ignore for _ in range(repeat): for minibatch in batch.split(split_batch_size, merge_last=True): # calculate loss for actor - dist = self(minibatch).dist + dist = self.policy(minibatch).dist log_prob = dist.log_prob(minibatch.act) log_prob = log_prob.reshape(len(minibatch.adv), -1).transpose(0, 1) actor_loss = -(log_prob * minibatch.adv).mean() diff --git a/tianshou/policy/modelfree/npg.py b/tianshou/policy/modelfree/npg.py index ce07e088e..bb1043df2 100644 --- a/tianshou/policy/modelfree/npg.py +++ b/tianshou/policy/modelfree/npg.py @@ -10,7 +10,7 @@ from tianshou.data import Batch, ReplayBuffer, SequenceSummaryStats from tianshou.data.types import BatchWithAdvantagesProtocol, RolloutBatchProtocol -from tianshou.policy import A2CPolicy +from tianshou.policy import A2C from tianshou.policy.base import TLearningRateScheduler, TrainingStats from tianshou.policy.modelfree.pg import TDistFnDiscrOrCont from tianshou.utils.net.continuous import ActorProb, Critic @@ -29,7 +29,7 @@ class 
NPGTrainingStats(TrainingStats): # TODO: the type ignore here is needed b/c the hierarchy is actually broken! Should reconsider the inheritance structure. -class NPGPolicy(A2CPolicy[TNPGTrainingStats], Generic[TNPGTrainingStats]): # type: ignore[type-var] +class NPGPolicy(A2C[TNPGTrainingStats], Generic[TNPGTrainingStats]): # type: ignore[type-var] """Implementation of Natural Policy Gradient. https://proceedings.neurips.cc/paper/2001/file/4b86abe48d358ecf194c56c69108433e-Paper.pdf diff --git a/tianshou/policy/modelfree/pg.py b/tianshou/policy/modelfree/pg.py index ad5efeddb..0f29b7039 100644 --- a/tianshou/policy/modelfree/pg.py +++ b/tianshou/policy/modelfree/pg.py @@ -67,6 +67,25 @@ def __init__( action_scaling: bool = True, action_bound_method: Literal["clip", "tanh"] | None = "clip", ) -> None: + """ + :param actor: the actor network following the rules: + If `self.action_type == "discrete"`: (`s_B` ->`action_values_BA`). + If `self.action_type == "continuous"`: (`s_B` -> `dist_input_BD`). + :param dist_fn: distribution class for computing the action. + Maps model_output -> distribution. Typically, a Gaussian distribution + taking `model_output=mean,std` as input for continuous action spaces, + or a categorical distribution taking `model_output=logits` + for discrete action spaces. Note that as user, you are responsible + for ensuring that the distribution is compatible with the action space. + :param action_space: env's action space. + :param deterministic_eval: if True, will use deterministic action (the dist's mode) + instead of stochastic one during evaluation. Does not affect training. + :param observation_space: Env's observation space. + :param action_scaling: if True, scale the action from [-1, 1] to the range + of action_space. Only used if the action_space is continuous. + :param action_bound_method: method to bound action to range [-1, 1]. + Only used if the action_space is continuous. 
+ """ super().__init__( action_space=action_space, observation_space=observation_space, @@ -124,30 +143,6 @@ def forward( class Reinforce(OnPolicyAlgorithm[ActorPolicy, TPGTrainingStats], Generic[TPGTrainingStats]): """Implementation of the REINFORCE (a.k.a. vanilla policy gradient) algorithm. - :param actor: the actor network following the rules: - If `self.action_type == "discrete"`: (`s_B` ->`action_values_BA`). - If `self.action_type == "continuous"`: (`s_B` -> `dist_input_BD`). - :param optim: optimizer for actor network. - :param dist_fn: distribution class for computing the action. - Maps model_output -> distribution. Typically a Gaussian distribution - taking `model_output=mean,std` as input for continuous action spaces, - or a categorical distribution taking `model_output=logits` - for discrete action spaces. Note that as user, you are responsible - for ensuring that the distribution is compatible with the action space. - :param action_space: env's action space. - :param discount_factor: in [0, 1]. - :param reward_normalization: if True, will normalize the *returns* - by subtracting the running mean and dividing by the running standard deviation. - Can be detrimental to performance! See TODO in process_fn. - :param deterministic_eval: if True, will use deterministic action (the dist's mode) - instead of stochastic one during evaluation. Does not affect training. - :param observation_space: Env's observation space. - :param action_scaling: if True, scale the action from [-1, 1] to the range - of action_space. Only used if the action_space is continuous. - :param action_bound_method: method to bound action to range [-1, 1]. - Only used if the action_space is continuous. - :param lr_scheduler: if not None, will be called in `policy.update()`. - .. seealso:: Please refer to :class:`~tianshou.policy.BasePolicy` for more detailed explanation. 
@@ -163,6 +158,15 @@ def __init__( optim: torch.optim.Optimizer, lr_scheduler: TLearningRateScheduler | None = None, ) -> None: + """ + :param policy: the policy + :param optim: optimizer for actor network. + :param discount_factor: in [0, 1]. + :param reward_normalization: if True, will normalize the *returns* + by subtracting the running mean and dividing by the running standard deviation. + Can be detrimental to performance! See TODO in process_fn. + :param lr_scheduler: if not None, will be called in `policy.update()`. + """ super().__init__( policy=policy, lr_scheduler=lr_scheduler, diff --git a/tianshou/policy/modelfree/ppo.py b/tianshou/policy/modelfree/ppo.py index c1b29fbf4..a11c88509 100644 --- a/tianshou/policy/modelfree/ppo.py +++ b/tianshou/policy/modelfree/ppo.py @@ -9,7 +9,7 @@ from tianshou.data import ReplayBuffer, SequenceSummaryStats, to_torch_as from tianshou.data.types import LogpOldProtocol, RolloutBatchProtocol -from tianshou.policy import A2CPolicy +from tianshou.policy import A2C from tianshou.policy.base import TLearningRateScheduler, TrainingStats from tianshou.policy.modelfree.pg import TDistFnDiscrOrCont from tianshou.utils.net.common import ActorCritic @@ -49,7 +49,7 @@ def from_sequences( # TODO: the type ignore here is needed b/c the hierarchy is actually broken! Should reconsider the inheritance structure. -class PPOPolicy(A2CPolicy[TPPOTrainingStats], Generic[TPPOTrainingStats]): # type: ignore[type-var] +class PPOPolicy(A2C[TPPOTrainingStats], Generic[TPPOTrainingStats]): # type: ignore[type-var] r"""Implementation of Proximal Policy Optimization. arXiv:1707.06347. 
:param actor: the actor network following the rules: diff --git a/tianshou/utils/net/discrete.py b/tianshou/utils/net/discrete.py index ab9069801..6ea654929 100644 --- a/tianshou/utils/net/discrete.py +++ b/tianshou/utils/net/discrete.py @@ -10,6 +10,7 @@ from tianshou.utils.net.common import MLP, BaseActor, Net, TActionShape, get_output_dim +# TODO rename to DiscreteActor? class Actor(BaseActor): """Simple actor network for discrete action spaces. From df9d4bcaaaadda0bfe6605d72ede81c41ae2ad2e Mon Sep 17 00:00:00 2001 From: Dominik Jain Date: Wed, 5 Mar 2025 13:19:24 +0100 Subject: [PATCH 014/230] v2: Adapt PPO and test_ppo --- examples/atari/atari_ppo.py | 4 +- examples/mujoco/mujoco_ppo.py | 4 +- examples/vizdoom/vizdoom_ppo.py | 4 +- test/base/test_policy.py | 8 +-- test/continuous/test_ppo.py | 62 ++++++++--------- test/discrete/test_ppo.py | 4 +- test/modelbased/test_ppo_icm.py | 4 +- test/pettingzoo/pistonball_continuous.py | 4 +- tianshou/highlevel/algorithm.py | 8 +-- tianshou/policy/__init__.py | 4 +- tianshou/policy/imitation/gail.py | 4 +- tianshou/policy/modelfree/ppo.py | 84 +++++++++--------------- 12 files changed, 86 insertions(+), 108 deletions(-) diff --git a/examples/atari/atari_ppo.py b/examples/atari/atari_ppo.py index 7a48dc869..0b441add7 100644 --- a/examples/atari/atari_ppo.py +++ b/examples/atari/atari_ppo.py @@ -13,7 +13,7 @@ from tianshou.env.atari.atari_network import DQNet, layer_init, scale_obs from tianshou.env.atari.atari_wrapper import make_atari_env from tianshou.highlevel.logger import LoggerFactoryDefault -from tianshou.policy import ICMPolicy, PPOPolicy +from tianshou.policy import PPO, ICMPolicy from tianshou.policy.base import Algorithm from tianshou.trainer import OnpolicyTrainer from tianshou.utils.net.common import ActorCritic @@ -135,7 +135,7 @@ def main(args: argparse.Namespace = get_args()) -> None: def dist(logits: torch.Tensor) -> Categorical: return Categorical(logits=logits) - policy: PPOPolicy = PPOPolicy( + policy: 
PPO = PPO( actor=actor, critic=critic, optim=optim, diff --git a/examples/mujoco/mujoco_ppo.py b/examples/mujoco/mujoco_ppo.py index a385efc59..da18ea3ab 100755 --- a/examples/mujoco/mujoco_ppo.py +++ b/examples/mujoco/mujoco_ppo.py @@ -14,7 +14,7 @@ from tianshou.data import Collector, CollectStats, ReplayBuffer, VectorReplayBuffer from tianshou.highlevel.logger import LoggerFactoryDefault -from tianshou.policy import PPOPolicy +from tianshou.policy import PPO from tianshou.policy.base import Algorithm from tianshou.trainer import OnpolicyTrainer from tianshou.utils.net.common import ActorCritic, Net @@ -141,7 +141,7 @@ def dist(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution: loc, scale = loc_scale return Independent(Normal(loc, scale), 1) - policy: PPOPolicy = PPOPolicy( + policy: PPO = PPO( actor=actor, critic=critic, optim=optim, diff --git a/examples/vizdoom/vizdoom_ppo.py b/examples/vizdoom/vizdoom_ppo.py index 426ca08ea..7ba1bf503 100644 --- a/examples/vizdoom/vizdoom_ppo.py +++ b/examples/vizdoom/vizdoom_ppo.py @@ -13,7 +13,7 @@ from tianshou.data import Collector, CollectStats, VectorReplayBuffer from tianshou.env.atari.atari_network import DQNet from tianshou.highlevel.logger import LoggerFactoryDefault -from tianshou.policy import ICMPolicy, PPOPolicy +from tianshou.policy import PPO, ICMPolicy from tianshou.policy.base import Algorithm from tianshou.trainer import OnpolicyTrainer from tianshou.utils.net.common import ActorCritic @@ -140,7 +140,7 @@ def test_ppo(args: argparse.Namespace = get_args()) -> None: def dist(logits: torch.Tensor) -> Categorical: return Categorical(logits=logits) - policy: PPOPolicy = PPOPolicy( + policy: PPO = PPO( actor=actor, critic=critic, optim=optim, diff --git a/test/base/test_policy.py b/test/base/test_policy.py index 753291a63..618958825 100644 --- a/test/base/test_policy.py +++ b/test/base/test_policy.py @@ -5,7 +5,7 @@ from torch.distributions import Categorical, Distribution, Independent, Normal from 
tianshou.data import Batch -from tianshou.policy import Algorithm, PPOPolicy +from tianshou.policy import PPO, Algorithm from tianshou.policy.base import RandomActionPolicy, episode_mc_return_to_go from tianshou.utils.net.common import ActorCritic, Net from tianshou.utils.net.continuous import ActorProb, Critic @@ -26,7 +26,7 @@ def test_calculate_discounted_returns() -> None: @pytest.fixture(params=["continuous", "discrete"]) -def policy(request: pytest.FixtureRequest) -> PPOPolicy: +def policy(request: pytest.FixtureRequest) -> PPO: action_type = request.param action_space: gym.spaces.Box | gym.spaces.Discrete actor: Actor | ActorProb @@ -59,7 +59,7 @@ def dist_fn(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution: optim = torch.optim.Adam(actor_critic.parameters(), lr=1e-3) policy: Algorithm - policy = PPOPolicy( + policy = PPO( actor=actor, critic=critic, dist_fn=dist_fn, @@ -72,7 +72,7 @@ def dist_fn(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution: class TestPolicyBasics: - def test_get_action(self, policy: PPOPolicy) -> None: + def test_get_action(self, policy: PPO) -> None: policy.is_within_training_step = False sample_obs = torch.randn(obs_shape) policy.deterministic_eval = False diff --git a/test/continuous/test_ppo.py b/test/continuous/test_ppo.py index 576353c53..cd9e02e73 100644 --- a/test/continuous/test_ppo.py +++ b/test/continuous/test_ppo.py @@ -9,10 +9,10 @@ from tianshou.data import Collector, CollectStats, VectorReplayBuffer from tianshou.env import DummyVectorEnv -from tianshou.policy import PPOPolicy +from tianshou.policy import PPO from tianshou.policy.base import Algorithm -from tianshou.policy.modelfree.ppo import PPOTrainingStats -from tianshou.trainer import OnpolicyTrainer +from tianshou.policy.modelfree.pg import ActorPolicy +from tianshou.trainer.base import OnPolicyTrainingConfig from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import ActorCritic, Net from 
tianshou.utils.net.continuous import ActorProb, Critic @@ -103,11 +103,15 @@ def dist(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution: loc, scale = loc_scale return Independent(Normal(loc, scale), 1) - policy: PPOPolicy[PPOTrainingStats] = PPOPolicy( + policy = ActorPolicy( actor=actor, + dist_fn=dist, + action_space=env.action_space, + ) + algorithm = PPO( + policy=policy, critic=critic, optim=optim, - dist_fn=dist, discount_factor=args.gamma, max_grad_norm=args.max_grad_norm, eps_clip=args.eps_clip, @@ -119,15 +123,14 @@ def dist(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution: dual_clip=args.dual_clip, value_clip=args.value_clip, gae_lambda=args.gae_lambda, - action_space=env.action_space, ) # collector train_collector = Collector[CollectStats]( - policy, + algorithm, train_envs, VectorReplayBuffer(args.buffer_size, len(train_envs)), ) - test_collector = Collector[CollectStats](policy, test_envs) + test_collector = Collector[CollectStats](algorithm, test_envs) # log log_path = os.path.join(args.logdir, args.task, "ppo") writer = SummaryWriter(log_path) @@ -146,7 +149,7 @@ def save_checkpoint_fn(epoch: int, env_step: int, gradient_step: int) -> str: # ckpt_path = os.path.join(log_path, f"checkpoint_{epoch}.pth") torch.save( { - "model": policy.state_dict(), + "model": algorithm.state_dict(), "optim": optim.state_dict(), }, ckpt_path, @@ -159,36 +162,33 @@ def save_checkpoint_fn(epoch: int, env_step: int, gradient_step: int) -> str: ckpt_path = os.path.join(log_path, "checkpoint.pth") if os.path.exists(ckpt_path): checkpoint = torch.load(ckpt_path, map_location=args.device) - policy.load_state_dict(checkpoint["model"]) + algorithm.load_state_dict(checkpoint["model"]) optim.load_state_dict(checkpoint["optim"]) print("Successfully restore policy and optim.") else: print("Fail to restore policy and optim.") # trainer - trainer = OnpolicyTrainer( - policy=policy, - train_collector=train_collector, - test_collector=test_collector, - 
max_epoch=args.epoch, - step_per_epoch=args.step_per_epoch, - repeat_per_collect=args.repeat_per_collect, - episode_per_test=args.test_num, - batch_size=args.batch_size, - episode_per_collect=args.episode_per_collect, - stop_fn=stop_fn, - save_best_fn=save_best_fn, - logger=logger, - resume_from_log=args.resume, - save_checkpoint_fn=save_checkpoint_fn, + result = algorithm.run_training( + OnPolicyTrainingConfig( + train_collector=train_collector, + test_collector=test_collector, + max_epoch=args.epoch, + step_per_epoch=args.step_per_epoch, + repeat_per_collect=args.repeat_per_collect, + episode_per_test=args.test_num, + batch_size=args.batch_size, + episode_per_collect=args.episode_per_collect, + step_per_collect=None, + stop_fn=stop_fn, + save_best_fn=save_best_fn, + logger=logger, + resume_from_log=args.resume, + save_checkpoint_fn=save_checkpoint_fn, + ) ) - for epoch_stat in trainer: - print(f"Epoch: {epoch_stat.epoch}") - print(epoch_stat) - # print(info) - - assert stop_fn(epoch_stat.info_stat.best_reward) + assert stop_fn(result.best_reward) def test_ppo_resume(args: argparse.Namespace = get_args()) -> None: diff --git a/test/discrete/test_ppo.py b/test/discrete/test_ppo.py index 88b0e2b86..a7c80caa7 100644 --- a/test/discrete/test_ppo.py +++ b/test/discrete/test_ppo.py @@ -10,7 +10,7 @@ from tianshou.data import Collector, CollectStats, VectorReplayBuffer from tianshou.env import DummyVectorEnv -from tianshou.policy import PPOPolicy +from tianshou.policy import PPO from tianshou.policy.base import Algorithm from tianshou.policy.modelfree.ppo import PPOTrainingStats from tianshou.trainer import OnpolicyTrainer @@ -96,7 +96,7 @@ def test_ppo(args: argparse.Namespace = get_args()) -> None: torch.nn.init.zeros_(m.bias) optim = torch.optim.Adam(actor_critic.parameters(), lr=args.lr) dist = torch.distributions.Categorical - policy: PPOPolicy[PPOTrainingStats] = PPOPolicy( + policy: PPO[PPOTrainingStats] = PPO( actor=actor, critic=critic, optim=optim, diff --git 
a/test/modelbased/test_ppo_icm.py b/test/modelbased/test_ppo_icm.py index beb6561b7..544107e51 100644 --- a/test/modelbased/test_ppo_icm.py +++ b/test/modelbased/test_ppo_icm.py @@ -9,7 +9,7 @@ from tianshou.data import Collector, CollectStats, VectorReplayBuffer from tianshou.env import DummyVectorEnv -from tianshou.policy import ICMPolicy, PPOPolicy +from tianshou.policy import PPO, ICMPolicy from tianshou.policy.base import Algorithm from tianshou.policy.modelfree.ppo import PPOTrainingStats from tianshou.trainer import OnpolicyTrainer @@ -109,7 +109,7 @@ def test_ppo(args: argparse.Namespace = get_args()) -> None: torch.nn.init.zeros_(m.bias) optim = torch.optim.Adam(actor_critic.parameters(), lr=args.lr) dist = torch.distributions.Categorical - policy: PPOPolicy[PPOTrainingStats] = PPOPolicy( + policy: PPO[PPOTrainingStats] = PPO( actor=actor, critic=critic, optim=optim, diff --git a/test/pettingzoo/pistonball_continuous.py b/test/pettingzoo/pistonball_continuous.py index 935e65a6e..ef61508fa 100644 --- a/test/pettingzoo/pistonball_continuous.py +++ b/test/pettingzoo/pistonball_continuous.py @@ -15,7 +15,7 @@ from tianshou.data.stats import InfoStats from tianshou.env import DummyVectorEnv from tianshou.env.pettingzoo_env import PettingZooEnv -from tianshou.policy import Algorithm, MultiAgentPolicyManager, PPOPolicy +from tianshou.policy import PPO, Algorithm, MultiAgentPolicyManager from tianshou.trainer import OnpolicyTrainer from tianshou.utils import TensorboardLogger from tianshou.utils.net.continuous import ActorProb, Critic @@ -186,7 +186,7 @@ def dist(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution: loc, scale = loc_scale return Independent(Normal(loc, scale), 1) - agent: PPOPolicy = PPOPolicy( + agent: PPO = PPO( actor=actor, critic=critic, optim=optim, diff --git a/tianshou/highlevel/algorithm.py b/tianshou/highlevel/algorithm.py index 6601f7c62..88eda1b76 100644 --- a/tianshou/highlevel/algorithm.py +++ 
b/tianshou/highlevel/algorithm.py @@ -48,12 +48,12 @@ from tianshou.policy import ( A2C, DDPG, + PPO, Algorithm, DeepQLearning, DiscreteSACPolicy, IQNPolicy, NPGPolicy, - PPOPolicy, REDQPolicy, Reinforce, SACPolicy, @@ -388,9 +388,9 @@ def _get_policy_class(self) -> type[A2C]: return A2C -class PPOAlgorithmFactory(ActorCriticAlgorithmFactory[PPOParams, PPOPolicy]): - def _get_policy_class(self) -> type[PPOPolicy]: - return PPOPolicy +class PPOAlgorithmFactory(ActorCriticAlgorithmFactory[PPOParams, PPO]): + def _get_policy_class(self) -> type[PPO]: + return PPO class NPGAlgorithmFactory(ActorCriticAlgorithmFactory[NPGParams, NPGPolicy]): diff --git a/tianshou/policy/__init__.py b/tianshou/policy/__init__.py index 6a37e949d..e8bb6e75d 100644 --- a/tianshou/policy/__init__.py +++ b/tianshou/policy/__init__.py @@ -15,7 +15,7 @@ from tianshou.policy.modelfree.fqf import FQFPolicy from tianshou.policy.modelfree.a2c import A2C from tianshou.policy.modelfree.npg import NPGPolicy -from tianshou.policy.modelfree.ppo import PPOPolicy +from tianshou.policy.modelfree.ppo import PPO from tianshou.policy.modelfree.trpo import TRPOPolicy from tianshou.policy.modelfree.td3 import TD3Policy from tianshou.policy.modelfree.sac import SACPolicy @@ -47,7 +47,7 @@ "A2C", "NPGPolicy", "DDPG", - "PPOPolicy", + "PPO", "TRPOPolicy", "TD3Policy", "SACPolicy", diff --git a/tianshou/policy/imitation/gail.py b/tianshou/policy/imitation/gail.py index 9ffd6a6b7..49a80084e 100644 --- a/tianshou/policy/imitation/gail.py +++ b/tianshou/policy/imitation/gail.py @@ -13,7 +13,7 @@ to_torch, ) from tianshou.data.types import LogpOldProtocol, RolloutBatchProtocol -from tianshou.policy import PPOPolicy +from tianshou.policy import PPO from tianshou.policy.base import TLearningRateScheduler from tianshou.policy.modelfree.pg import TDistFnDiscrOrCont from tianshou.policy.modelfree.ppo import PPOTrainingStats @@ -32,7 +32,7 @@ class GailTrainingStats(PPOTrainingStats): TGailTrainingStats = 
TypeVar("TGailTrainingStats", bound=GailTrainingStats) -class GAILPolicy(PPOPolicy[TGailTrainingStats]): +class GAILPolicy(PPO[TGailTrainingStats]): r"""Implementation of Generative Adversarial Imitation Learning. arXiv:1606.03476. :param actor: the actor network following the rules: diff --git a/tianshou/policy/modelfree/ppo.py b/tianshou/policy/modelfree/ppo.py index a11c88509..215235f03 100644 --- a/tianshou/policy/modelfree/ppo.py +++ b/tianshou/policy/modelfree/ppo.py @@ -1,8 +1,7 @@ from collections.abc import Sequence from dataclasses import dataclass -from typing import Any, Generic, Literal, Self, TypeVar +from typing import Any, Generic, Self, TypeVar -import gymnasium as gym import numpy as np import torch from torch import nn @@ -11,10 +10,9 @@ from tianshou.data.types import LogpOldProtocol, RolloutBatchProtocol from tianshou.policy import A2C from tianshou.policy.base import TLearningRateScheduler, TrainingStats -from tianshou.policy.modelfree.pg import TDistFnDiscrOrCont +from tianshou.policy.modelfree.pg import ActorPolicy from tianshou.utils.net.common import ActorCritic -from tianshou.utils.net.continuous import ActorProb, Critic -from tianshou.utils.net.discrete import Actor as DiscreteActor +from tianshou.utils.net.continuous import Critic from tianshou.utils.net.discrete import Critic as DiscreteCritic @@ -49,40 +47,9 @@ def from_sequences( # TODO: the type ignore here is needed b/c the hierarchy is actually broken! Should reconsider the inheritance structure. -class PPOPolicy(A2C[TPPOTrainingStats], Generic[TPPOTrainingStats]): # type: ignore[type-var] +class PPO(A2C[TPPOTrainingStats], Generic[TPPOTrainingStats]): # type: ignore[type-var] r"""Implementation of Proximal Policy Optimization. arXiv:1707.06347. - :param actor: the actor network following the rules: - If `self.action_type == "discrete"`: (`s` ->`action_values_BA`). - If `self.action_type == "continuous"`: (`s` -> `dist_input_BD`). - :param critic: the critic network. 
(s -> V(s)) - :param optim: the optimizer for actor and critic network. - :param dist_fn: distribution class for computing the action. - :param action_space: env's action space - :param eps_clip: :math:`\epsilon` in :math:`L_{CLIP}` in the original - paper. - :param dual_clip: a parameter c mentioned in arXiv:1912.09729 Equ. 5, - where c > 1 is a constant indicating the lower bound. Set to None - to disable dual-clip PPO. - :param value_clip: a parameter mentioned in arXiv:1811.02553v3 Sec. 4.1. - :param advantage_normalization: whether to do per mini-batch advantage - normalization. - :param recompute_advantage: whether to recompute advantage every update - repeat according to https://arxiv.org/pdf/2006.05990.pdf Sec. 3.5. - :param vf_coef: weight for value loss. - :param ent_coef: weight for entropy loss. - :param max_grad_norm: clipping gradients in back propagation. - :param gae_lambda: in [0, 1], param for Generalized Advantage Estimation. - :param max_batchsize: the maximum size of the batch when computing GAE. - :param discount_factor: in [0, 1]. - :param reward_normalization: normalize estimated values to have std close to 1. - :param deterministic_eval: if True, use deterministic evaluation. - :param observation_space: the space of the observation. - :param action_scaling: if True, scale the action from [-1, 1] to the range of - action_space. Only used if the action_space is continuous. - :param action_bound_method: method to bound action to range [-1, 1]. - :param lr_scheduler: if not None, will be called in `policy.update()`. - .. 
seealso:: Please refer to :class:`~tianshou.policy.BasePolicy` for more detailed @@ -92,11 +59,9 @@ class PPOPolicy(A2C[TPPOTrainingStats], Generic[TPPOTrainingStats]): # type: ig def __init__( self, *, - actor: torch.nn.Module | ActorProb | DiscreteActor, + policy: ActorPolicy, critic: torch.nn.Module | Critic | DiscreteCritic, optim: torch.optim.Optimizer, - dist_fn: TDistFnDiscrOrCont, - action_space: gym.Space, eps_clip: float = 0.2, dual_clip: float | None = None, value_clip: bool = False, @@ -110,22 +75,39 @@ def __init__( discount_factor: float = 0.99, # TODO: rename to return_normalization? reward_normalization: bool = False, - deterministic_eval: bool = False, - observation_space: gym.Space | None = None, - action_scaling: bool = True, - action_bound_method: Literal["clip", "tanh"] | None = "clip", lr_scheduler: TLearningRateScheduler | None = None, ) -> None: + r""" + :param policy: the policy + :param critic: the critic network. (s -> V(s)) + :param optim: the optimizer for actor and critic network. + :param eps_clip: :math:`\epsilon` in :math:`L_{CLIP}` in the original + paper. + :param dual_clip: a parameter c mentioned in arXiv:1912.09729 Equ. 5, + where c > 1 is a constant indicating the lower bound. Set to None + to disable dual-clip PPO. + :param value_clip: a parameter mentioned in arXiv:1811.02553v3 Sec. 4.1. + :param advantage_normalization: whether to do per mini-batch advantage + normalization. + :param recompute_advantage: whether to recompute advantage every update + repeat according to https://arxiv.org/pdf/2006.05990.pdf Sec. 3.5. + :param vf_coef: weight for value loss. + :param ent_coef: weight for entropy loss. + :param max_grad_norm: clipping gradients in back propagation. + :param gae_lambda: in [0, 1], param for Generalized Advantage Estimation. + :param max_batchsize: the maximum size of the batch when computing GAE. + :param discount_factor: in [0, 1]. 
+ :param reward_normalization: normalize estimated values to have std close to 1. + :param lr_scheduler: if not None, will be called in `policy.update()`. + """ assert ( dual_clip is None or dual_clip > 1.0 ), f"Dual-clip PPO parameter should greater than 1.0 but got {dual_clip}" super().__init__( - actor=actor, + policy=policy, critic=critic, optim=optim, - dist_fn=dist_fn, - action_space=action_space, vf_coef=vf_coef, ent_coef=ent_coef, max_grad_norm=max_grad_norm, @@ -133,10 +115,6 @@ def __init__( max_batchsize=max_batchsize, discount_factor=discount_factor, reward_normalization=reward_normalization, - deterministic_eval=deterministic_eval, - observation_space=observation_space, - action_scaling=action_scaling, - action_bound_method=action_bound_method, lr_scheduler=lr_scheduler, ) self.eps_clip = eps_clip @@ -160,7 +138,7 @@ def process_fn( logp_old = [] with torch.no_grad(): for minibatch in batch.split(self.max_batchsize, shuffle=False, merge_last=True): - logp_old.append(self(minibatch).dist.log_prob(minibatch.act)) + logp_old.append(self.policy(minibatch).dist.log_prob(minibatch.act)) batch.logp_old = torch.cat(logp_old, dim=0).flatten() batch: LogpOldProtocol return batch @@ -184,7 +162,7 @@ def _update_with_batch( # type: ignore gradient_steps += 1 # calculate loss for actor advantages = minibatch.adv - dist = self(minibatch).dist + dist = self.policy(minibatch).dist if self.norm_adv: mean, std = advantages.mean(), advantages.std() advantages = (advantages - mean) / (std + self._eps) # per-batch norm From 60f19e1660c76132f45774fb90085d4f66db3e7e Mon Sep 17 00:00:00 2001 From: Dominik Jain Date: Wed, 5 Mar 2025 13:34:20 +0100 Subject: [PATCH 015/230] v2: Adapt TD3 and test_td3 --- examples/mujoco/mujoco_td3.py | 6 +-- examples/offline/d4rl_td3_bc.py | 2 +- test/continuous/test_td3.py | 54 +++++++++---------- test/offline/test_td3_bc.py | 2 +- tianshou/highlevel/algorithm.py | 8 +-- tianshou/policy/__init__.py | 4 +- tianshou/policy/imitation/td3_bc.py | 
10 ++-- tianshou/policy/modelfree/td3.py | 81 +++++++++++++---------------- 8 files changed, 80 insertions(+), 87 deletions(-) diff --git a/examples/mujoco/mujoco_td3.py b/examples/mujoco/mujoco_td3.py index 31d7f1370..46c8fe509 100755 --- a/examples/mujoco/mujoco_td3.py +++ b/examples/mujoco/mujoco_td3.py @@ -12,7 +12,7 @@ from tianshou.data import Collector, CollectStats, ReplayBuffer, VectorReplayBuffer from tianshou.exploration import GaussianNoise from tianshou.highlevel.logger import LoggerFactoryDefault -from tianshou.policy import TD3Policy +from tianshou.policy import TD3 from tianshou.policy.base import Algorithm from tianshou.trainer import OffpolicyTrainer from tianshou.utils.net.common import Net @@ -112,9 +112,9 @@ def test_td3(args: argparse.Namespace = get_args()) -> None: critic2 = Critic(net_c2, device=args.device).to(args.device) critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr) - policy: TD3Policy = TD3Policy( + policy: TD3 = TD3( actor=actor, - actor_optim=actor_optim, + policy_optim=actor_optim, critic=critic1, critic_optim=critic1_optim, critic2=critic2, diff --git a/examples/offline/d4rl_td3_bc.py b/examples/offline/d4rl_td3_bc.py index 47c05bda8..a9cc7d8ce 100644 --- a/examples/offline/d4rl_td3_bc.py +++ b/examples/offline/d4rl_td3_bc.py @@ -137,7 +137,7 @@ def test_td3_bc() -> None: policy: TD3BCPolicy = TD3BCPolicy( actor=actor, - actor_optim=actor_optim, + policy_optim=actor_optim, critic=critic1, critic_optim=critic1_optim, critic2=critic2, diff --git a/test/continuous/test_td3.py b/test/continuous/test_td3.py index 41a339cb5..13c7a1db9 100644 --- a/test/continuous/test_td3.py +++ b/test/continuous/test_td3.py @@ -1,6 +1,5 @@ import argparse import os -import pprint import gymnasium as gym import numpy as np @@ -10,9 +9,10 @@ from tianshou.data import Collector, CollectStats, VectorReplayBuffer from tianshou.env import DummyVectorEnv from tianshou.exploration import GaussianNoise -from tianshou.policy import 
TD3Policy +from tianshou.policy import TD3 from tianshou.policy.base import Algorithm -from tianshou.trainer import OffpolicyTrainer +from tianshou.policy.modelfree.ddpg import DDPGPolicy +from tianshou.trainer.base import OffPolicyTrainingConfig from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Net from tianshou.utils.net.continuous import Actor, Critic @@ -98,9 +98,13 @@ def test_td3(args: argparse.Namespace = get_args()) -> None: ) critic2 = Critic(net_c2, device=args.device).to(args.device) critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr) - policy: TD3Policy = TD3Policy( + policy = DDPGPolicy( actor=actor, - actor_optim=actor_optim, + action_space=env.action_space, + ) + algorithm: TD3 = TD3( + policy=policy, + policy_optim=actor_optim, critic=critic1, critic_optim=critic1_optim, critic2=critic2, @@ -112,16 +116,15 @@ def test_td3(args: argparse.Namespace = get_args()) -> None: update_actor_freq=args.update_actor_freq, noise_clip=args.noise_clip, estimation_step=args.n_step, - action_space=env.action_space, ) # collector train_collector = Collector[CollectStats]( - policy, + algorithm, train_envs, VectorReplayBuffer(args.buffer_size, len(train_envs)), exploration_noise=True, ) - test_collector = Collector[CollectStats](policy, test_envs) + test_collector = Collector[CollectStats](algorithm, test_envs) # train_collector.collect(n_step=args.buffer_size) # log log_path = os.path.join(args.logdir, args.task, "td3") @@ -134,24 +137,21 @@ def save_best_fn(policy: Algorithm) -> None: def stop_fn(mean_rewards: float) -> bool: return mean_rewards >= args.reward_threshold - # Iterator trainer - trainer = OffpolicyTrainer( - policy=policy, - train_collector=train_collector, - test_collector=test_collector, - max_epoch=args.epoch, - step_per_epoch=args.step_per_epoch, - step_per_collect=args.step_per_collect, - episode_per_test=args.test_num, - batch_size=args.batch_size, - update_per_step=args.update_per_step, - 
stop_fn=stop_fn, - save_best_fn=save_best_fn, - logger=logger, + # train + result = algorithm.run_training( + OffPolicyTrainingConfig( + train_collector=train_collector, + test_collector=test_collector, + max_epoch=args.epoch, + step_per_epoch=args.step_per_epoch, + step_per_collect=args.step_per_collect, + episode_per_test=args.test_num, + batch_size=args.batch_size, + update_per_step=args.update_per_step, + stop_fn=stop_fn, + save_best_fn=save_best_fn, + logger=logger, + ) ) - for epoch_stat in trainer: - print(f"Epoch: {epoch_stat.epoch}") - pprint.pprint(epoch_stat) - # print(info) - assert stop_fn(epoch_stat.info_stat.best_reward) + assert stop_fn(result.best_reward) diff --git a/test/offline/test_td3_bc.py b/test/offline/test_td3_bc.py index cae6f6f06..70aee4ba3 100644 --- a/test/offline/test_td3_bc.py +++ b/test/offline/test_td3_bc.py @@ -129,7 +129,7 @@ def test_td3_bc(args: argparse.Namespace = get_args()) -> None: policy: TD3BCPolicy = TD3BCPolicy( actor=actor, - actor_optim=actor_optim, + policy_optim=actor_optim, critic=critic1, critic_optim=critic1_optim, critic2=critic2, diff --git a/tianshou/highlevel/algorithm.py b/tianshou/highlevel/algorithm.py index 88eda1b76..23916c609 100644 --- a/tianshou/highlevel/algorithm.py +++ b/tianshou/highlevel/algorithm.py @@ -49,6 +49,7 @@ A2C, DDPG, PPO, + TD3, Algorithm, DeepQLearning, DiscreteSACPolicy, @@ -57,7 +58,6 @@ REDQPolicy, Reinforce, SACPolicy, - TD3Policy, TRPOPolicy, ) from tianshou.policy.base import ( @@ -685,6 +685,6 @@ def _get_policy_class(self) -> type[DiscreteSACPolicy]: return DiscreteSACPolicy -class TD3AlgorithmFactory(ActorDualCriticsAlgorithmFactory[TD3Params, TD3Policy]): - def _get_policy_class(self) -> type[TD3Policy]: - return TD3Policy +class TD3AlgorithmFactory(ActorDualCriticsAlgorithmFactory[TD3Params, TD3]): + def _get_policy_class(self) -> type[TD3]: + return TD3 diff --git a/tianshou/policy/__init__.py b/tianshou/policy/__init__.py index e8bb6e75d..b9777616f 100644 --- 
a/tianshou/policy/__init__.py +++ b/tianshou/policy/__init__.py @@ -17,7 +17,7 @@ from tianshou.policy.modelfree.npg import NPGPolicy from tianshou.policy.modelfree.ppo import PPO from tianshou.policy.modelfree.trpo import TRPOPolicy -from tianshou.policy.modelfree.td3 import TD3Policy +from tianshou.policy.modelfree.td3 import TD3 from tianshou.policy.modelfree.sac import SACPolicy from tianshou.policy.modelfree.redq import REDQPolicy from tianshou.policy.modelfree.discrete_sac import DiscreteSACPolicy @@ -49,7 +49,7 @@ "DDPG", "PPO", "TRPOPolicy", - "TD3Policy", + "TD3", "SACPolicy", "REDQPolicy", "DiscreteSACPolicy", diff --git a/tianshou/policy/imitation/td3_bc.py b/tianshou/policy/imitation/td3_bc.py index f11b88a5a..71ed24163 100644 --- a/tianshou/policy/imitation/td3_bc.py +++ b/tianshou/policy/imitation/td3_bc.py @@ -8,7 +8,7 @@ from tianshou.data import to_torch_as from tianshou.data.types import RolloutBatchProtocol from tianshou.exploration import BaseNoise, GaussianNoise -from tianshou.policy import TD3Policy +from tianshou.policy import TD3 from tianshou.policy.base import TLearningRateScheduler from tianshou.policy.modelfree.td3 import TD3TrainingStats @@ -21,12 +21,12 @@ class TD3BCTrainingStats(TD3TrainingStats): TTD3BCTrainingStats = TypeVar("TTD3BCTrainingStats", bound=TD3BCTrainingStats) -class TD3BCPolicy(TD3Policy[TTD3BCTrainingStats]): +class TD3BCPolicy(TD3[TTD3BCTrainingStats]): """Implementation of TD3+BC. arXiv:2106.06860. :param actor: the actor network following the rules in :class:`~tianshou.policy.BasePolicy`. (s -> actions) - :param actor_optim: the optimizer for actor network. + :param policy_optim: the optimizer for actor network. :param critic: the first critic network. (s, a -> Q(s, a)) :param critic_optim: the optimizer for the first critic network. :param action_space: Env's action space. Should be gym.spaces.Box. 
@@ -62,7 +62,7 @@ def __init__( self, *, actor: torch.nn.Module, - actor_optim: torch.optim.Optimizer, + policy_optim: torch.optim.Optimizer, critic: torch.nn.Module, critic_optim: torch.optim.Optimizer, action_space: gym.Space, @@ -84,7 +84,7 @@ def __init__( ) -> None: super().__init__( actor=actor, - actor_optim=actor_optim, + policy_optim=policy_optim, critic=critic, critic_optim=critic_optim, action_space=action_space, diff --git a/tianshou/policy/modelfree/td3.py b/tianshou/policy/modelfree/td3.py index 66673a805..12f60e851 100644 --- a/tianshou/policy/modelfree/td3.py +++ b/tianshou/policy/modelfree/td3.py @@ -2,7 +2,6 @@ from dataclasses import dataclass from typing import Any, Generic, Literal, Self, TypeVar -import gymnasium as gym import numpy as np import torch @@ -11,6 +10,7 @@ from tianshou.exploration import BaseNoise from tianshou.policy import DDPG from tianshou.policy.base import TLearningRateScheduler, TrainingStats +from tianshou.policy.modelfree.ddpg import DDPGPolicy from tianshou.utils.optim import clone_optimizer @@ -25,35 +25,9 @@ class TD3TrainingStats(TrainingStats): # TODO: the type ignore here is needed b/c the hierarchy is actually broken! Should reconsider the inheritance structure. -class TD3Policy(DDPG[TTD3TrainingStats], Generic[TTD3TrainingStats]): # type: ignore[type-var] +class TD3(DDPG[TTD3TrainingStats], Generic[TTD3TrainingStats]): # type: ignore[type-var] """Implementation of TD3, arXiv:1802.09477. - :param actor: the actor network following the rules in - :class:`~tianshou.policy.BasePolicy`. (s -> actions) - :param actor_optim: the optimizer for actor network. - :param critic: the first critic network. (s, a -> Q(s, a)) - :param critic_optim: the optimizer for the first critic network. - :param action_space: Env's action space. Should be gym.spaces.Box. - :param critic2: the second critic network. (s, a -> Q(s, a)). - If None, use the same network as critic (via deepcopy). 
- :param critic2_optim: the optimizer for the second critic network. - If None, clone critic_optim to use for critic2.parameters(). - :param tau: param for soft update of the target network. - :param gamma: discount factor, in [0, 1]. - :param exploration_noise: add noise to action for exploration. - This is useful when solving "hard exploration" problems. - "default" is equivalent to GaussianNoise(sigma=0.1). - :param policy_noise: the noise used in updating policy network. - :param update_actor_freq: the update frequency of actor network. - :param noise_clip: the clipping range used in updating policy network. - :param observation_space: Env's observation space. - :param action_scaling: if True, scale the action from [-1, 1] to the range - of action_space. Only used if the action_space is continuous. - :param action_bound_method: method to bound action to range [-1, 1]. - Only used if the action_space is continuous. - :param lr_scheduler: a learning rate scheduler that adjusts the learning rate - in optimizer in each policy.update() - .. seealso:: Please refer to :class:`~tianshou.policy.BasePolicy` for more detailed @@ -63,11 +37,10 @@ class TD3Policy(DDPG[TTD3TrainingStats], Generic[TTD3TrainingStats]): # type: i def __init__( self, *, - actor: torch.nn.Module, - actor_optim: torch.optim.Optimizer, + policy: DDPGPolicy, + policy_optim: torch.optim.Optimizer, critic: torch.nn.Module, critic_optim: torch.optim.Optimizer, - action_space: gym.Space, critic2: torch.nn.Module | None = None, critic2_optim: torch.optim.Optimizer | None = None, tau: float = 0.005, @@ -77,26 +50,46 @@ def __init__( update_actor_freq: int = 2, noise_clip: float = 0.5, estimation_step: int = 1, - observation_space: gym.Space | None = None, - action_scaling: bool = True, - action_bound_method: Literal["clip"] | None = "clip", lr_scheduler: TLearningRateScheduler | None = None, ) -> None: + """ + :param actor: the actor network following the rules in + :class:`~tianshou.policy.BasePolicy`. 
(s -> actions) + :param policy_optim: the optimizer for actor network. + :param critic: the first critic network. (s, a -> Q(s, a)) + :param critic_optim: the optimizer for the first critic network. + :param action_space: Env's action space. Should be gym.spaces.Box. + :param critic2: the second critic network. (s, a -> Q(s, a)). + If None, use the same network as critic (via deepcopy). + :param critic2_optim: the optimizer for the second critic network. + If None, clone critic_optim to use for critic2.parameters(). + :param tau: param for soft update of the target network. + :param gamma: discount factor, in [0, 1]. + :param exploration_noise: add noise to action for exploration. + This is useful when solving "hard exploration" problems. + "default" is equivalent to GaussianNoise(sigma=0.1). + :param policy_noise: the noise used in updating policy network. + :param update_actor_freq: the update frequency of actor network. + :param noise_clip: the clipping range used in updating policy network. + :param observation_space: Env's observation space. + :param action_scaling: if True, scale the action from [-1, 1] to the range + of action_space. Only used if the action_space is continuous. + :param action_bound_method: method to bound action to range [-1, 1]. + Only used if the action_space is continuous. + :param lr_scheduler: a learning rate scheduler that adjusts the learning rate + in optimizer in each policy.update() + """ # TODO: reduce duplication with SAC. # Some intermediate class, like TwoCriticPolicy? 
super().__init__( - actor=actor, - policy_optim=actor_optim, + policy=policy, + policy_optim=policy_optim, critic=critic, critic_optim=critic_optim, - action_space=action_space, tau=tau, gamma=gamma, exploration_noise=exploration_noise, estimation_step=estimation_step, - action_scaling=action_scaling, - action_bound_method=action_bound_method, - observation_space=observation_space, lr_scheduler=lr_scheduler, ) if critic2 and not critic2_optim: @@ -115,7 +108,7 @@ def __init__( def train(self, mode: bool = True) -> Self: self.training = mode - self.actor.train(mode) + self.policy.train(mode) self.critic.train(mode) self.critic2.train(mode) return self @@ -123,14 +116,14 @@ def train(self, mode: bool = True) -> Self: def sync_weight(self) -> None: self.soft_update(self.critic_old, self.critic, self.tau) self.soft_update(self.critic2_old, self.critic2, self.tau) - self.soft_update(self.actor_old, self.actor, self.tau) + self.soft_update(self.actor_old, self.policy.actor, self.tau) def _target_q(self, buffer: ReplayBuffer, indices: np.ndarray) -> torch.Tensor: obs_next_batch = Batch( obs=buffer[indices].obs_next, info=[None] * len(indices), ) # obs_next: s_{t+n} - act_ = self(obs_next_batch, model="actor_old").act + act_ = self.policy(obs_next_batch, model=self.actor_old).act noise = torch.randn(size=act_.shape, device=act_.device) * self.policy_noise if self.noise_clip > 0.0: noise = noise.clamp(-self.noise_clip, self.noise_clip) @@ -148,7 +141,7 @@ def _update_with_batch(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: # actor if self._cnt % self.update_actor_freq == 0: - actor_loss = -self.critic(batch.obs, self(batch, eps=0.0).act).mean() + actor_loss = -self.critic(batch.obs, self.policy(batch, eps=0.0).act).mean() self.policy_optim.zero_grad() actor_loss.backward() self._last = actor_loss.item() From 9ef9a8461e910e2837f1a8d2264af88dfc4ecdb2 Mon Sep 17 00:00:00 2001 From: Dominik Jain Date: Wed, 5 Mar 2025 13:45:11 +0100 Subject: [PATCH 016/230] v2: Adapt 
NPG and test_npg --- examples/mujoco/mujoco_npg.py | 4 +- test/continuous/test_npg.py | 49 +++++++++--------- tianshou/highlevel/algorithm.py | 8 +-- tianshou/policy/__init__.py | 4 +- tianshou/policy/modelfree/npg.py | 82 ++++++++++++------------------- tianshou/policy/modelfree/trpo.py | 4 +- 6 files changed, 68 insertions(+), 83 deletions(-) diff --git a/examples/mujoco/mujoco_npg.py b/examples/mujoco/mujoco_npg.py index 8b8294100..db5af6818 100755 --- a/examples/mujoco/mujoco_npg.py +++ b/examples/mujoco/mujoco_npg.py @@ -14,7 +14,7 @@ from tianshou.data import Collector, CollectStats, ReplayBuffer, VectorReplayBuffer from tianshou.highlevel.logger import LoggerFactoryDefault -from tianshou.policy import NPGPolicy +from tianshou.policy import NPG from tianshou.policy.base import Algorithm from tianshou.trainer import OnpolicyTrainer from tianshou.utils.net.common import Net @@ -138,7 +138,7 @@ def dist(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution: loc, scale = loc_scale return Independent(Normal(loc, scale), 1) - policy: NPGPolicy = NPGPolicy( + policy: NPG = NPG( actor=actor, critic=critic, optim=optim, diff --git a/test/continuous/test_npg.py b/test/continuous/test_npg.py index 393225e33..4879afec6 100644 --- a/test/continuous/test_npg.py +++ b/test/continuous/test_npg.py @@ -10,10 +10,11 @@ from tianshou.data import Collector, CollectStats, VectorReplayBuffer from tianshou.env import DummyVectorEnv -from tianshou.policy import NPGPolicy +from tianshou.policy import NPG from tianshou.policy.base import Algorithm from tianshou.policy.modelfree.npg import NPGTrainingStats -from tianshou.trainer import OnpolicyTrainer +from tianshou.policy.modelfree.pg import ActorPolicy +from tianshou.trainer.base import OnPolicyTrainingConfig from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Net from tianshou.utils.net.continuous import ActorProb, Critic @@ -106,27 +107,30 @@ def dist(loc_scale: tuple[torch.Tensor, 
torch.Tensor]) -> Distribution: loc, scale = loc_scale return Independent(Normal(loc, scale), 1) - policy: NPGPolicy[NPGTrainingStats] = NPGPolicy( + policy = ActorPolicy( actor=actor, + dist_fn=dist, + action_space=env.action_space, + deterministic_eval=True, + ) + algorithm: NPG[NPGTrainingStats] = NPG( + policy=policy, critic=critic, optim=optim, - dist_fn=dist, discount_factor=args.gamma, reward_normalization=args.rew_norm, advantage_normalization=args.norm_adv, gae_lambda=args.gae_lambda, - action_space=env.action_space, optim_critic_iters=args.optim_critic_iters, actor_step_size=args.actor_step_size, - deterministic_eval=True, ) # collector train_collector = Collector[CollectStats]( - policy, + algorithm, train_envs, VectorReplayBuffer(args.buffer_size, len(train_envs)), ) - test_collector = Collector[CollectStats](policy, test_envs) + test_collector = Collector[CollectStats](algorithm, test_envs) # log log_path = os.path.join(args.logdir, args.task, "npg") writer = SummaryWriter(log_path) @@ -139,18 +143,19 @@ def stop_fn(mean_rewards: float) -> bool: return mean_rewards >= args.reward_threshold # trainer - result = OnpolicyTrainer( - policy=policy, - train_collector=train_collector, - test_collector=test_collector, - max_epoch=args.epoch, - step_per_epoch=args.step_per_epoch, - repeat_per_collect=args.repeat_per_collect, - episode_per_test=args.test_num, - batch_size=args.batch_size, - step_per_collect=args.step_per_collect, - stop_fn=stop_fn, - save_best_fn=save_best_fn, - logger=logger, - ).run() + result = algorithm.run_training( + OnPolicyTrainingConfig( + train_collector=train_collector, + test_collector=test_collector, + max_epoch=args.epoch, + step_per_epoch=args.step_per_epoch, + repeat_per_collect=args.repeat_per_collect, + episode_per_test=args.test_num, + batch_size=args.batch_size, + step_per_collect=args.step_per_collect, + stop_fn=stop_fn, + save_best_fn=save_best_fn, + logger=logger, + ) + ) assert stop_fn(result.best_reward) diff --git 
a/tianshou/highlevel/algorithm.py b/tianshou/highlevel/algorithm.py index 23916c609..6368b4283 100644 --- a/tianshou/highlevel/algorithm.py +++ b/tianshou/highlevel/algorithm.py @@ -48,13 +48,13 @@ from tianshou.policy import ( A2C, DDPG, + NPG, PPO, TD3, Algorithm, DeepQLearning, DiscreteSACPolicy, IQNPolicy, - NPGPolicy, REDQPolicy, Reinforce, SACPolicy, @@ -393,9 +393,9 @@ def _get_policy_class(self) -> type[PPO]: return PPO -class NPGAlgorithmFactory(ActorCriticAlgorithmFactory[NPGParams, NPGPolicy]): - def _get_policy_class(self) -> type[NPGPolicy]: - return NPGPolicy +class NPGAlgorithmFactory(ActorCriticAlgorithmFactory[NPGParams, NPG]): + def _get_policy_class(self) -> type[NPG]: + return NPG class TRPOAlgorithmFactory(ActorCriticAlgorithmFactory[TRPOParams, TRPOPolicy]): diff --git a/tianshou/policy/__init__.py b/tianshou/policy/__init__.py index b9777616f..a1b623012 100644 --- a/tianshou/policy/__init__.py +++ b/tianshou/policy/__init__.py @@ -14,7 +14,7 @@ from tianshou.policy.modelfree.iqn import IQNPolicy from tianshou.policy.modelfree.fqf import FQFPolicy from tianshou.policy.modelfree.a2c import A2C -from tianshou.policy.modelfree.npg import NPGPolicy +from tianshou.policy.modelfree.npg import NPG from tianshou.policy.modelfree.ppo import PPO from tianshou.policy.modelfree.trpo import TRPOPolicy from tianshou.policy.modelfree.td3 import TD3 @@ -45,7 +45,7 @@ "FQFPolicy", "Reinforce", "A2C", - "NPGPolicy", + "NPG", "DDPG", "PPO", "TRPOPolicy", diff --git a/tianshou/policy/modelfree/npg.py b/tianshou/policy/modelfree/npg.py index bb1043df2..fc006deaa 100644 --- a/tianshou/policy/modelfree/npg.py +++ b/tianshou/policy/modelfree/npg.py @@ -1,7 +1,6 @@ from dataclasses import dataclass -from typing import Any, Generic, Literal, TypeVar +from typing import Any, Generic, TypeVar -import gymnasium as gym import numpy as np import torch import torch.nn.functional as F @@ -12,9 +11,8 @@ from tianshou.data.types import BatchWithAdvantagesProtocol, 
RolloutBatchProtocol from tianshou.policy import A2C from tianshou.policy.base import TLearningRateScheduler, TrainingStats -from tianshou.policy.modelfree.pg import TDistFnDiscrOrCont -from tianshou.utils.net.continuous import ActorProb, Critic -from tianshou.utils.net.discrete import Actor as DiscreteActor +from tianshou.policy.modelfree.pg import ActorPolicy +from tianshou.utils.net.continuous import Critic from tianshou.utils.net.discrete import Critic as DiscreteCritic @@ -29,42 +27,18 @@ class NPGTrainingStats(TrainingStats): # TODO: the type ignore here is needed b/c the hierarchy is actually broken! Should reconsider the inheritance structure. -class NPGPolicy(A2C[TNPGTrainingStats], Generic[TNPGTrainingStats]): # type: ignore[type-var] +class NPG(A2C[TNPGTrainingStats], Generic[TNPGTrainingStats]): # type: ignore[type-var] """Implementation of Natural Policy Gradient. https://proceedings.neurips.cc/paper/2001/file/4b86abe48d358ecf194c56c69108433e-Paper.pdf - - :param actor: the actor network following the rules: - If `self.action_type == "discrete"`: (`s` ->`action_values_BA`). - If `self.action_type == "continuous"`: (`s` -> `dist_input_BD`). - :param critic: the critic network. (s -> V(s)) - :param optim: the optimizer for actor and critic network. - :param dist_fn: distribution class for computing the action. - :param action_space: env's action space - :param optim_critic_iters: Number of times to optimize critic network per update. - :param actor_step_size: step size for actor update in natural gradient direction. - :param advantage_normalization: whether to do per mini-batch advantage - normalization. - :param gae_lambda: in [0, 1], param for Generalized Advantage Estimation. - :param max_batchsize: the maximum size of the batch when computing GAE. - :param discount_factor: in [0, 1]. - :param reward_normalization: normalize estimated values to have std close to 1. - :param deterministic_eval: if True, use deterministic evaluation. 
- :param observation_space: the space of the observation. - :param action_scaling: if True, scale the action from [-1, 1] to the range of - action_space. Only used if the action_space is continuous. - :param action_bound_method: method to bound action to range [-1, 1]. - :param lr_scheduler: if not None, will be called in `policy.update()`. """ def __init__( self, *, - actor: torch.nn.Module | ActorProb | DiscreteActor, + policy: ActorPolicy, critic: torch.nn.Module | Critic | DiscreteCritic, optim: torch.optim.Optimizer, - dist_fn: TDistFnDiscrOrCont, - action_space: gym.Space, optim_critic_iters: int = 5, actor_step_size: float = 0.5, advantage_normalization: bool = True, @@ -73,18 +47,26 @@ def __init__( discount_factor: float = 0.99, # TODO: rename to return_normalization? reward_normalization: bool = False, - deterministic_eval: bool = False, - observation_space: gym.Space | None = None, - action_scaling: bool = True, - action_bound_method: Literal["clip", "tanh"] | None = "clip", lr_scheduler: TLearningRateScheduler | None = None, ) -> None: + """ + :param policy: the policy + :param critic: the critic network. (s -> V(s)) + :param optim: the optimizer for actor and critic network. + :param optim_critic_iters: Number of times to optimize critic network per update. + :param actor_step_size: step size for actor update in natural gradient direction. + :param advantage_normalization: whether to do per mini-batch advantage + normalization. + :param gae_lambda: in [0, 1], param for Generalized Advantage Estimation. + :param max_batchsize: the maximum size of the batch when computing GAE. + :param discount_factor: in [0, 1]. + :param reward_normalization: normalize estimated values to have std close to 1. + :param lr_scheduler: if not None, will be called in `policy.update()`. 
+ """ super().__init__( - actor=actor, + policy=policy, critic=critic, optim=optim, - dist_fn=dist_fn, - action_space=action_space, # TODO: violates Liskov substitution principle, see the del statement below vf_coef=None, # type: ignore ent_coef=None, # type: ignore @@ -93,10 +75,6 @@ def __init__( max_batchsize=max_batchsize, discount_factor=discount_factor, reward_normalization=reward_normalization, - deterministic_eval=deterministic_eval, - observation_space=observation_space, - action_scaling=action_scaling, - action_bound_method=action_bound_method, lr_scheduler=lr_scheduler, ) # TODO: see above, it ain't pretty... @@ -117,7 +95,7 @@ def process_fn( old_log_prob = [] with torch.no_grad(): for minibatch in batch.split(self.max_batchsize, shuffle=False, merge_last=True): - old_log_prob.append(self(minibatch).dist.log_prob(minibatch.act)) + old_log_prob.append(self.policy(minibatch).dist.log_prob(minibatch.act)) batch.logp_old = torch.cat(old_log_prob, dim=0) if self.norm_adv: batch.adv = (batch.adv - batch.adv.mean()) / batch.adv.std() @@ -136,29 +114,31 @@ def _update_with_batch( # type: ignore for minibatch in batch.split(split_batch_size, merge_last=True): # optimize actor # direction: calculate villia gradient - dist = self(minibatch).dist + dist = self.policy(minibatch).dist log_prob = dist.log_prob(minibatch.act) log_prob = log_prob.reshape(log_prob.size(0), -1).transpose(0, 1) actor_loss = -(log_prob * minibatch.adv).mean() - flat_grads = self._get_flat_grad(actor_loss, self.actor, retain_graph=True).detach() + flat_grads = self._get_flat_grad( + actor_loss, self.policy.actor, retain_graph=True + ).detach() # direction: calculate natural gradient with torch.no_grad(): - old_dist = self(minibatch).dist + old_dist = self.policy(minibatch).dist kl = kl_divergence(old_dist, dist).mean() # calculate first order gradient of kl with respect to theta - flat_kl_grad = self._get_flat_grad(kl, self.actor, create_graph=True) + flat_kl_grad = self._get_flat_grad(kl, 
self.policy.actor, create_graph=True) search_direction = -self._conjugate_gradients(flat_grads, flat_kl_grad, nsteps=10) # step with torch.no_grad(): flat_params = torch.cat( - [param.data.view(-1) for param in self.actor.parameters()], + [param.data.view(-1) for param in self.policy.actor.parameters()], ) new_flat_params = flat_params + self.actor_step_size * search_direction - self._set_from_flat_params(self.actor, new_flat_params) - new_dist = self(minibatch).dist + self._set_from_flat_params(self.policy.actor, new_flat_params) + new_dist = self.policy(minibatch).dist kl = kl_divergence(old_dist, new_dist).mean() # optimize critic @@ -187,7 +167,7 @@ def _MVP(self, v: torch.Tensor, flat_kl_grad: torch.Tensor) -> torch.Tensor: """Matrix vector product.""" # caculate second order gradient of kl with respect to theta kl_v = (flat_kl_grad * v).sum() - flat_kl_grad_grad = self._get_flat_grad(kl_v, self.actor, retain_graph=True).detach() + flat_kl_grad_grad = self._get_flat_grad(kl_v, self.policy.actor, retain_graph=True).detach() return flat_kl_grad_grad + v * self._damping def _conjugate_gradients( diff --git a/tianshou/policy/modelfree/trpo.py b/tianshou/policy/modelfree/trpo.py index c8530e258..0075fce08 100644 --- a/tianshou/policy/modelfree/trpo.py +++ b/tianshou/policy/modelfree/trpo.py @@ -8,7 +8,7 @@ from torch.distributions import kl_divergence from tianshou.data import Batch, SequenceSummaryStats -from tianshou.policy import NPGPolicy +from tianshou.policy import NPG from tianshou.policy.base import TLearningRateScheduler from tianshou.policy.modelfree.npg import NPGTrainingStats from tianshou.policy.modelfree.pg import TDistFnDiscrOrCont @@ -25,7 +25,7 @@ class TRPOTrainingStats(NPGTrainingStats): TTRPOTrainingStats = TypeVar("TTRPOTrainingStats", bound=TRPOTrainingStats) -class TRPOPolicy(NPGPolicy[TTRPOTrainingStats]): +class TRPOPolicy(NPG[TTRPOTrainingStats]): """Implementation of Trust Region Policy Optimization. arXiv:1502.05477. 
:param actor: the actor network following the rules: From 2555eb8457c1bca4bc92a4e424855d95b70d6223 Mon Sep 17 00:00:00 2001 From: Dominik Jain Date: Wed, 5 Mar 2025 14:20:57 +0100 Subject: [PATCH 017/230] v2: Fix class hierarchy issues: NPG now no longer inherits from A2C but from the new abstract base class AbstractActorCriticWithAdvantage (which A2C also inherits from) --- CHANGELOG.md | 2 + tianshou/policy/modelfree/a2c.py | 103 +++++++++++++++++++------------ tianshou/policy/modelfree/npg.py | 16 ++--- 3 files changed, 72 insertions(+), 49 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index b05273746..cdb95a920 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,6 +9,8 @@ * `DQNPolicy` -> `DeepQLearning` * `DDPGPolicy` -> `DDPG` * The `Algorithm` abstraction can directly initiate the learning process via method `run_training`. + * Fixed issues in the class hierarchy (e.g. violations of the Liskov substitution principle): + * `NPG` no longer inherits from `A2C` but from a new abstract base class * Moved Atari helper modules `atari_network` and `atari_wrapper` to the library under `tianshou.env.atari`.
diff --git a/tianshou/policy/modelfree/a2c.py b/tianshou/policy/modelfree/a2c.py index 046315509..f210800e5 100644 --- a/tianshou/policy/modelfree/a2c.py +++ b/tianshou/policy/modelfree/a2c.py @@ -1,3 +1,4 @@ +from abc import ABC from dataclasses import dataclass from typing import Any, Generic, TypeVar, cast @@ -10,7 +11,7 @@ from tianshou.data.types import BatchWithAdvantagesProtocol, RolloutBatchProtocol from tianshou.policy import Reinforce from tianshou.policy.base import TLearningRateScheduler, TrainingStats -from tianshou.policy.modelfree.pg import ActorPolicy +from tianshou.policy.modelfree.pg import ActorPolicy, TPGTrainingStats from tianshou.utils.net.common import ActorCritic from tianshou.utils.net.continuous import Critic from tianshou.utils.net.discrete import Critic as DiscreteCritic @@ -27,39 +28,16 @@ class A2CTrainingStats(TrainingStats): TA2CTrainingStats = TypeVar("TA2CTrainingStats", bound=A2CTrainingStats) -class A2C(Reinforce[TA2CTrainingStats], Generic[TA2CTrainingStats]): - """Implementation of Synchronous Advantage Actor-Critic. arXiv:1602.01783. - - :param critic: the critic network. (s -> V(s)) - :param optim: the optimizer for actor and critic network. - :param vf_coef: weight for value loss. - :param ent_coef: weight for entropy loss. - :param max_grad_norm: clipping gradients in back propagation. - :param gae_lambda: in [0, 1], param for Generalized Advantage Estimation. - :param max_batchsize: the maximum size of the batch when computing GAE. - :param discount_factor: in [0, 1]. - :param reward_normalization: normalize estimated values to have std close to 1. - :param lr_scheduler: if not None, will be called in `policy.update()`. - - .. seealso:: - - Please refer to :class:`~tianshou.policy.BasePolicy` for more detailed - explanation. 
- """ - +class AbstractActorCriticWithAdvantage(Reinforce[TPGTrainingStats], Generic[TPGTrainingStats], ABC): def __init__( self, *, policy: ActorPolicy, critic: torch.nn.Module | Critic | DiscreteCritic, optim: torch.optim.Optimizer, - vf_coef: float = 0.5, - ent_coef: float = 0.01, - max_grad_norm: float | None = None, gae_lambda: float = 0.95, max_batchsize: int = 256, discount_factor: float = 0.99, - # TODO: rename to return_normalization? reward_normalization: bool = False, lr_scheduler: TLearningRateScheduler | None = None, ) -> None: @@ -73,22 +51,9 @@ def __init__( self.critic = critic assert 0.0 <= gae_lambda <= 1.0, f"GAE lambda should be in [0, 1] but got: {gae_lambda}" self.gae_lambda = gae_lambda - self.vf_coef = vf_coef - self.ent_coef = ent_coef - self.max_grad_norm = max_grad_norm self.max_batchsize = max_batchsize self._actor_critic = ActorCritic(self.policy.actor, self.critic) - def process_fn( - self, - batch: RolloutBatchProtocol, - buffer: ReplayBuffer, - indices: np.ndarray, - ) -> BatchWithAdvantagesProtocol: - batch = self._compute_returns(batch, buffer, indices) - batch.act = to_torch_as(batch.act, batch.v_s) - return batch - def _compute_returns( self, batch: RolloutBatchProtocol, @@ -129,6 +94,68 @@ def _compute_returns( batch.adv = to_torch_as(advantages, batch.v_s) return cast(BatchWithAdvantagesProtocol, batch) + +class A2C(AbstractActorCriticWithAdvantage[TA2CTrainingStats], Generic[TA2CTrainingStats]): + """Implementation of Synchronous Advantage Actor-Critic. arXiv:1602.01783. + + .. seealso:: + + Please refer to :class:`~tianshou.policy.BasePolicy` for more detailed + explanation. 
+ """ + + def __init__( + self, + *, + policy: ActorPolicy, + critic: torch.nn.Module | Critic | DiscreteCritic, + optim: torch.optim.Optimizer, + vf_coef: float = 0.5, + ent_coef: float = 0.01, + max_grad_norm: float | None = None, + gae_lambda: float = 0.95, + max_batchsize: int = 256, + discount_factor: float = 0.99, + # TODO: rename to return_normalization? + reward_normalization: bool = False, + lr_scheduler: TLearningRateScheduler | None = None, + ) -> None: + """ + :param critic: the critic network. (s -> V(s)) + :param optim: the optimizer for actor and critic network. + :param vf_coef: weight for value loss. + :param ent_coef: weight for entropy loss. + :param max_grad_norm: clipping gradients in back propagation. + :param gae_lambda: in [0, 1], param for Generalized Advantage Estimation. + :param max_batchsize: the maximum size of the batch when computing GAE. + :param discount_factor: in [0, 1]. + :param reward_normalization: normalize estimated values to have std close to 1. + :param lr_scheduler: if not None, will be called in `policy.update()`. + """ + super().__init__( + policy=policy, + critic=critic, + optim=optim, + gae_lambda=gae_lambda, + max_batchsize=max_batchsize, + discount_factor=discount_factor, + reward_normalization=reward_normalization, + lr_scheduler=lr_scheduler, + ) + self.vf_coef = vf_coef + self.ent_coef = ent_coef + self.max_grad_norm = max_grad_norm + + def process_fn( + self, + batch: RolloutBatchProtocol, + buffer: ReplayBuffer, + indices: np.ndarray, + ) -> BatchWithAdvantagesProtocol: + batch = self._compute_returns(batch, buffer, indices) + batch.act = to_torch_as(batch.act, batch.v_s) + return batch + # TODO: mypy complains b/c signature is different from superclass, although # it's compatible. Can this be fixed? 
def _update_with_batch( # type: ignore diff --git a/tianshou/policy/modelfree/npg.py b/tianshou/policy/modelfree/npg.py index fc006deaa..1c9571295 100644 --- a/tianshou/policy/modelfree/npg.py +++ b/tianshou/policy/modelfree/npg.py @@ -7,10 +7,10 @@ from torch import nn from torch.distributions import kl_divergence -from tianshou.data import Batch, ReplayBuffer, SequenceSummaryStats +from tianshou.data import Batch, ReplayBuffer, SequenceSummaryStats, to_torch_as from tianshou.data.types import BatchWithAdvantagesProtocol, RolloutBatchProtocol -from tianshou.policy import A2C from tianshou.policy.base import TLearningRateScheduler, TrainingStats +from tianshou.policy.modelfree.a2c import AbstractActorCriticWithAdvantage from tianshou.policy.modelfree.pg import ActorPolicy from tianshou.utils.net.continuous import Critic from tianshou.utils.net.discrete import Critic as DiscreteCritic @@ -26,8 +26,7 @@ class NPGTrainingStats(TrainingStats): TNPGTrainingStats = TypeVar("TNPGTrainingStats", bound=NPGTrainingStats) -# TODO: the type ignore here is needed b/c the hierarchy is actually broken! Should reconsider the inheritance structure. -class NPG(A2C[TNPGTrainingStats], Generic[TNPGTrainingStats]): # type: ignore[type-var] +class NPG(AbstractActorCriticWithAdvantage[TNPGTrainingStats], Generic[TNPGTrainingStats]): """Implementation of Natural Policy Gradient. https://proceedings.neurips.cc/paper/2001/file/4b86abe48d358ecf194c56c69108433e-Paper.pdf @@ -67,18 +66,12 @@ def __init__( policy=policy, critic=critic, optim=optim, - # TODO: violates Liskov substitution principle, see the del statement below - vf_coef=None, # type: ignore - ent_coef=None, # type: ignore - max_grad_norm=None, gae_lambda=gae_lambda, max_batchsize=max_batchsize, discount_factor=discount_factor, reward_normalization=reward_normalization, lr_scheduler=lr_scheduler, ) - # TODO: see above, it ain't pretty... 
- del self.vf_coef, self.ent_coef, self.max_grad_norm self.norm_adv = advantage_normalization self.optim_critic_iters = optim_critic_iters self.actor_step_size = actor_step_size @@ -91,7 +84,8 @@ def process_fn( buffer: ReplayBuffer, indices: np.ndarray, ) -> BatchWithAdvantagesProtocol: - batch = super().process_fn(batch, buffer, indices) + batch = self._compute_returns(batch, buffer, indices) + batch.act = to_torch_as(batch.act, batch.v_s) old_log_prob = [] with torch.no_grad(): for minibatch in batch.split(self.max_batchsize, shuffle=False, merge_last=True): From 629905b92efb9530ebe34f4a95670d7c600878ba Mon Sep 17 00:00:00 2001 From: Dominik Jain Date: Wed, 5 Mar 2025 14:32:43 +0100 Subject: [PATCH 018/230] v2: Adapt TRPO and test_trpo --- examples/mujoco/mujoco_trpo.py | 4 +- test/continuous/test_trpo.py | 49 +++++++++-------- tianshou/highlevel/algorithm.py | 8 +-- tianshou/policy/__init__.py | 4 +- tianshou/policy/modelfree/trpo.py | 90 ++++++++++++------------------- 5 files changed, 69 insertions(+), 86 deletions(-) diff --git a/examples/mujoco/mujoco_trpo.py b/examples/mujoco/mujoco_trpo.py index e1357afef..350439a70 100755 --- a/examples/mujoco/mujoco_trpo.py +++ b/examples/mujoco/mujoco_trpo.py @@ -14,7 +14,7 @@ from tianshou.data import Collector, CollectStats, ReplayBuffer, VectorReplayBuffer from tianshou.highlevel.logger import LoggerFactoryDefault -from tianshou.policy import TRPOPolicy +from tianshou.policy import TRPO from tianshou.policy.base import Algorithm from tianshou.trainer import OnpolicyTrainer from tianshou.utils.net.common import Net @@ -141,7 +141,7 @@ def dist(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution: loc, scale = loc_scale return Independent(Normal(loc, scale), 1) - policy: TRPOPolicy = TRPOPolicy( + policy: TRPO = TRPO( actor=actor, critic=critic, optim=optim, diff --git a/test/continuous/test_trpo.py b/test/continuous/test_trpo.py index 3d1ffda36..1f1d1fecb 100644 --- a/test/continuous/test_trpo.py +++ 
b/test/continuous/test_trpo.py @@ -10,9 +10,10 @@ from tianshou.data import Collector, CollectStats, VectorReplayBuffer from tianshou.env import DummyVectorEnv -from tianshou.policy import TRPOPolicy +from tianshou.policy import TRPO from tianshou.policy.base import Algorithm -from tianshou.trainer import OnpolicyTrainer +from tianshou.policy.modelfree.pg import ActorPolicy +from tianshou.trainer.base import OnPolicyTrainingConfig from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Net from tianshou.utils.net.continuous import ActorProb, Critic @@ -105,16 +106,19 @@ def dist(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution: loc, scale = loc_scale return Independent(Normal(loc, scale), 1) - policy: TRPOPolicy = TRPOPolicy( + policy = ActorPolicy( actor=actor, + dist_fn=dist, + action_space=env.action_space, + ) + algorithm: TRPO = TRPO( + policy=policy, critic=critic, optim=optim, - dist_fn=dist, discount_factor=args.gamma, reward_normalization=args.rew_norm, advantage_normalization=args.norm_adv, gae_lambda=args.gae_lambda, - action_space=env.action_space, optim_critic_iters=args.optim_critic_iters, max_kl=args.max_kl, backtrack_coeff=args.backtrack_coeff, @@ -122,11 +126,11 @@ def dist(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution: ) # collector train_collector = Collector[CollectStats]( - policy, + algorithm, train_envs, VectorReplayBuffer(args.buffer_size, len(train_envs)), ) - test_collector = Collector[CollectStats](policy, test_envs) + test_collector = Collector[CollectStats](algorithm, test_envs) # log log_path = os.path.join(args.logdir, args.task, "trpo") writer = SummaryWriter(log_path) @@ -138,19 +142,20 @@ def save_best_fn(policy: Algorithm) -> None: def stop_fn(mean_rewards: float) -> bool: return mean_rewards >= args.reward_threshold - # trainer - result = OnpolicyTrainer( - policy=policy, - train_collector=train_collector, - test_collector=test_collector, - max_epoch=args.epoch, - 
step_per_epoch=args.step_per_epoch, - repeat_per_collect=args.repeat_per_collect, - episode_per_test=args.test_num, - batch_size=args.batch_size, - step_per_collect=args.step_per_collect, - stop_fn=stop_fn, - save_best_fn=save_best_fn, - logger=logger, - ).run() + # train + result = algorithm.run_training( + OnPolicyTrainingConfig( + train_collector=train_collector, + test_collector=test_collector, + max_epoch=args.epoch, + step_per_epoch=args.step_per_epoch, + repeat_per_collect=args.repeat_per_collect, + episode_per_test=args.test_num, + batch_size=args.batch_size, + step_per_collect=args.step_per_collect, + stop_fn=stop_fn, + save_best_fn=save_best_fn, + logger=logger, + ) + ) assert stop_fn(result.best_reward) diff --git a/tianshou/highlevel/algorithm.py b/tianshou/highlevel/algorithm.py index 6368b4283..39a24877c 100644 --- a/tianshou/highlevel/algorithm.py +++ b/tianshou/highlevel/algorithm.py @@ -51,6 +51,7 @@ NPG, PPO, TD3, + TRPO, Algorithm, DeepQLearning, DiscreteSACPolicy, @@ -58,7 +59,6 @@ REDQPolicy, Reinforce, SACPolicy, - TRPOPolicy, ) from tianshou.policy.base import ( OffPolicyAlgorithm, @@ -398,9 +398,9 @@ def _get_policy_class(self) -> type[NPG]: return NPG -class TRPOAlgorithmFactory(ActorCriticAlgorithmFactory[TRPOParams, TRPOPolicy]): - def _get_policy_class(self) -> type[TRPOPolicy]: - return TRPOPolicy +class TRPOAlgorithmFactory(ActorCriticAlgorithmFactory[TRPOParams, TRPO]): + def _get_policy_class(self) -> type[TRPO]: + return TRPO class DiscreteCriticOnlyAlgorithmFactory( diff --git a/tianshou/policy/__init__.py b/tianshou/policy/__init__.py index a1b623012..ed3ce71d9 100644 --- a/tianshou/policy/__init__.py +++ b/tianshou/policy/__init__.py @@ -16,7 +16,7 @@ from tianshou.policy.modelfree.a2c import A2C from tianshou.policy.modelfree.npg import NPG from tianshou.policy.modelfree.ppo import PPO -from tianshou.policy.modelfree.trpo import TRPOPolicy +from tianshou.policy.modelfree.trpo import TRPO from tianshou.policy.modelfree.td3 import 
TD3 from tianshou.policy.modelfree.sac import SACPolicy from tianshou.policy.modelfree.redq import REDQPolicy @@ -48,7 +48,7 @@ "NPG", "DDPG", "PPO", - "TRPOPolicy", + "TRPO", "TD3", "SACPolicy", "REDQPolicy", diff --git a/tianshou/policy/modelfree/trpo.py b/tianshou/policy/modelfree/trpo.py index 0075fce08..f84433f85 100644 --- a/tianshou/policy/modelfree/trpo.py +++ b/tianshou/policy/modelfree/trpo.py @@ -1,8 +1,7 @@ import warnings from dataclasses import dataclass -from typing import Any, Literal, TypeVar +from typing import Any, TypeVar -import gymnasium as gym import torch import torch.nn.functional as F from torch.distributions import kl_divergence @@ -11,9 +10,8 @@ from tianshou.policy import NPG from tianshou.policy.base import TLearningRateScheduler from tianshou.policy.modelfree.npg import NPGTrainingStats -from tianshou.policy.modelfree.pg import TDistFnDiscrOrCont -from tianshou.utils.net.continuous import ActorProb, Critic -from tianshou.utils.net.discrete import Actor as DiscreteActor +from tianshou.policy.modelfree.pg import ActorPolicy +from tianshou.utils.net.continuous import Critic from tianshou.utils.net.discrete import Critic as DiscreteCritic @@ -25,44 +23,15 @@ class TRPOTrainingStats(NPGTrainingStats): TTRPOTrainingStats = TypeVar("TTRPOTrainingStats", bound=TRPOTrainingStats) -class TRPOPolicy(NPG[TTRPOTrainingStats]): - """Implementation of Trust Region Policy Optimization. arXiv:1502.05477. - - :param actor: the actor network following the rules: - If `self.action_type == "discrete"`: (`s_B` ->`action_values_BA`). - If `self.action_type == "continuous"`: (`s_B` -> `dist_input_BD`). - :param critic: the critic network. (s -> V(s)) - :param optim: the optimizer for actor and critic network. - :param dist_fn: distribution class for computing the action. - :param action_space: env's action space - :param max_kl: max kl-divergence used to constrain each actor network update. 
- :param backtrack_coeff: Coefficient to be multiplied by step size when - constraints are not met. - :param max_backtracks: Max number of backtracking times in linesearch. - :param optim_critic_iters: Number of times to optimize critic network per update. - :param actor_step_size: step size for actor update in natural gradient direction. - :param advantage_normalization: whether to do per mini-batch advantage - normalization. - :param gae_lambda: in [0, 1], param for Generalized Advantage Estimation. - :param max_batchsize: the maximum size of the batch when computing GAE. - :param discount_factor: in [0, 1]. - :param reward_normalization: normalize estimated values to have std close to 1. - :param deterministic_eval: if True, use deterministic evaluation. - :param observation_space: the space of the observation. - :param action_scaling: if True, scale the action from [-1, 1] to the range of - action_space. Only used if the action_space is continuous. - :param action_bound_method: method to bound action to range [-1, 1]. - :param lr_scheduler: if not None, will be called in `policy.update()`. - """ +class TRPO(NPG[TTRPOTrainingStats]): + """Implementation of Trust Region Policy Optimization. arXiv:1502.05477.""" def __init__( self, *, - actor: torch.nn.Module | ActorProb | DiscreteActor, + policy: ActorPolicy, critic: torch.nn.Module | Critic | DiscreteCritic, optim: torch.optim.Optimizer, - dist_fn: TDistFnDiscrOrCont, - action_space: gym.Space, max_kl: float = 0.01, backtrack_coeff: float = 0.8, max_backtracks: int = 10, @@ -74,18 +43,29 @@ def __init__( discount_factor: float = 0.99, # TODO: rename to return_normalization? reward_normalization: bool = False, - deterministic_eval: bool = False, - observation_space: gym.Space | None = None, - action_scaling: bool = True, - action_bound_method: Literal["clip", "tanh"] | None = "clip", lr_scheduler: TLearningRateScheduler | None = None, ) -> None: + """ + :param critic: the critic network. 
(s -> V(s)) + :param optim: the optimizer for actor and critic network. + :param max_kl: max kl-divergence used to constrain each actor network update. + :param backtrack_coeff: Coefficient to be multiplied by step size when + constraints are not met. + :param max_backtracks: Max number of backtracking times in linesearch. + :param optim_critic_iters: Number of times to optimize critic network per update. + :param actor_step_size: step size for actor update in natural gradient direction. + :param advantage_normalization: whether to do per mini-batch advantage + normalization. + :param gae_lambda: in [0, 1], param for Generalized Advantage Estimation. + :param max_batchsize: the maximum size of the batch when computing GAE. + :param discount_factor: in [0, 1]. + :param reward_normalization: normalize estimated values to have std close to 1. + :param lr_scheduler: if not None, will be called in `policy.update()`. + """ super().__init__( - actor=actor, + policy=policy, critic=critic, optim=optim, - dist_fn=dist_fn, - action_space=action_space, optim_critic_iters=optim_critic_iters, actor_step_size=actor_step_size, advantage_normalization=advantage_normalization, @@ -93,10 +73,6 @@ def __init__( max_batchsize=max_batchsize, discount_factor=discount_factor, reward_normalization=reward_normalization, - deterministic_eval=deterministic_eval, - observation_space=observation_space, - action_scaling=action_scaling, - action_bound_method=action_bound_method, lr_scheduler=lr_scheduler, ) self.max_backtracks = max_backtracks @@ -116,19 +92,21 @@ def _update_with_batch( # type: ignore for minibatch in batch.split(split_batch_size, merge_last=True): # optimize actor # direction: calculate villia gradient - dist = self(minibatch).dist # TODO could come from batch + dist = self.policy(minibatch).dist # TODO could come from batch ratio = (dist.log_prob(minibatch.act) - minibatch.logp_old).exp().float() ratio = ratio.reshape(ratio.size(0), -1).transpose(0, 1) actor_loss = -(ratio * 
minibatch.adv).mean() - flat_grads = self._get_flat_grad(actor_loss, self.actor, retain_graph=True).detach() + flat_grads = self._get_flat_grad( + actor_loss, self.policy.actor, retain_graph=True + ).detach() # direction: calculate natural gradient with torch.no_grad(): - old_dist = self(minibatch).dist + old_dist = self.policy(minibatch).dist kl = kl_divergence(old_dist, dist).mean() # calculate first order gradient of kl with respect to theta - flat_kl_grad = self._get_flat_grad(kl, self.actor, create_graph=True) + flat_kl_grad = self._get_flat_grad(kl, self.policy.actor, create_graph=True) search_direction = -self._conjugate_gradients(flat_grads, flat_kl_grad, nsteps=10) # stepsize: calculate max stepsize constrained by kl bound @@ -144,13 +122,13 @@ def _update_with_batch( # type: ignore # stepsize: linesearch stepsize with torch.no_grad(): flat_params = torch.cat( - [param.data.view(-1) for param in self.actor.parameters()], + [param.data.view(-1) for param in self.policy.actor.parameters()], ) for i in range(self.max_backtracks): new_flat_params = flat_params + step_size * search_direction - self._set_from_flat_params(self.actor, new_flat_params) + self._set_from_flat_params(self.policy.actor, new_flat_params) # calculate kl and if in bound, loss actually down - new_dist = self(minibatch).dist + new_dist = self.policy(minibatch).dist new_dratio = ( (new_dist.log_prob(minibatch.act) - minibatch.logp_old).exp().float() ) @@ -165,7 +143,7 @@ def _update_with_batch( # type: ignore if i < self.max_backtracks - 1: step_size = step_size * self.backtrack_coeff else: - self._set_from_flat_params(self.actor, new_flat_params) + self._set_from_flat_params(self.policy.actor, new_flat_params) step_size = torch.tensor([0.0]) warnings.warn( "Line search failed! 
It seems hyperparamters" From 571ce919ba9873a6de1ba9f89ef92e6f407fc63d Mon Sep 17 00:00:00 2001 From: Dominik Jain Date: Wed, 5 Mar 2025 15:00:37 +0100 Subject: [PATCH 019/230] Add registration of log configuration callback to tianshou.__init__ This takes effect for examples using sensai.util.logging --- tianshou/__init__.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/tianshou/__init__.py b/tianshou/__init__.py index a43d87ed8..cfa162a43 100644 --- a/tianshou/__init__.py +++ b/tianshou/__init__.py @@ -2,6 +2,19 @@ __version__ = "1.2.0-dev" + +def _register_log_config_callback(): + from sensai.util import logging + + def configure(): + logging.getLogger("numba").setLevel(logging.INFO) + + logging.set_configure_callback(configure) + + +_register_log_config_callback() + + __all__ = [ "env", "data", From 85d3a199ab2f6001d4f34cb6be7e5588cc68e17e Mon Sep 17 00:00:00 2001 From: Dominik Jain Date: Wed, 5 Mar 2025 15:06:15 +0100 Subject: [PATCH 020/230] v2: Restore high-level API support for A2C, PPO, TRPO, NPG --- tianshou/highlevel/algorithm.py | 39 +++++++++++++++++++++++---------- 1 file changed, 27 insertions(+), 12 deletions(-) diff --git a/tianshou/highlevel/algorithm.py b/tianshou/highlevel/algorithm.py index 39a24877c..72da2850d 100644 --- a/tianshou/highlevel/algorithm.py +++ b/tianshou/highlevel/algorithm.py @@ -345,7 +345,7 @@ def __init__( self.critic_use_action = False @abstractmethod - def _get_policy_class(self) -> type[TAlgorithm]: + def _get_algorithm_class(self) -> type[TAlgorithm]: pass def create_actor_critic_module_opt( @@ -375,31 +375,46 @@ def _create_kwargs(self, envs: Environments, device: TDevice) -> dict[str, Any]: kwargs["critic"] = actor_critic.critic kwargs["optim"] = actor_critic.optim kwargs["action_space"] = envs.get_action_space() + kwargs["observation_space"] = envs.get_observation_space() kwargs["dist_fn"] = self.actor_factory.create_dist_fn(envs) return kwargs def _create_algorithm(self, envs: Environments, device: 
TDevice) -> TAlgorithm: - policy_class = self._get_policy_class() - return policy_class(**self._create_kwargs(envs, device)) + params = self._create_kwargs(envs, device) + policy = self._create_policy( + ActorPolicy, + params, + [ + "actor", + "dist_fn", + "action_space", + "deterministic_eval", + "observation_space", + "action_scaling", + "action_bound_method", + ], + ) + algorithm_class = self._get_algorithm_class() + return algorithm_class(policy=policy, **params) class A2CAlgorithmFactory(ActorCriticAlgorithmFactory[A2CParams, A2C]): - def _get_policy_class(self) -> type[A2C]: + def _get_algorithm_class(self) -> type[A2C]: return A2C class PPOAlgorithmFactory(ActorCriticAlgorithmFactory[PPOParams, PPO]): - def _get_policy_class(self) -> type[PPO]: + def _get_algorithm_class(self) -> type[PPO]: return PPO class NPGAlgorithmFactory(ActorCriticAlgorithmFactory[NPGParams, NPG]): - def _get_policy_class(self) -> type[NPG]: + def _get_algorithm_class(self) -> type[NPG]: return NPG class TRPOAlgorithmFactory(ActorCriticAlgorithmFactory[TRPOParams, TRPO]): - def _get_policy_class(self) -> type[TRPO]: + def _get_algorithm_class(self) -> type[TRPO]: return TRPO @@ -613,7 +628,7 @@ def __init__( self.optim_factory = optim_factory @abstractmethod - def _get_policy_class(self) -> type[TAlgorithm]: + def _get_algorithm_class(self) -> type[TAlgorithm]: pass def _get_discrete_last_size_use_action_shape(self) -> bool: @@ -659,7 +674,7 @@ def _create_algorithm(self, envs: Environments, device: TDevice) -> TAlgorithm: critic2=critic2, ), ) - policy_class = self._get_policy_class() + policy_class = self._get_algorithm_class() return policy_class( actor=actor.module, actor_optim=actor.optim, @@ -674,17 +689,17 @@ def _create_algorithm(self, envs: Environments, device: TDevice) -> TAlgorithm: class SACAlgorithmFactory(ActorDualCriticsAlgorithmFactory[SACParams, SACPolicy]): - def _get_policy_class(self) -> type[SACPolicy]: + def _get_algorithm_class(self) -> type[SACPolicy]: return 
SACPolicy class DiscreteSACAlgorithmFactory( ActorDualCriticsAlgorithmFactory[DiscreteSACParams, DiscreteSACPolicy] ): - def _get_policy_class(self) -> type[DiscreteSACPolicy]: + def _get_algorithm_class(self) -> type[DiscreteSACPolicy]: return DiscreteSACPolicy class TD3AlgorithmFactory(ActorDualCriticsAlgorithmFactory[TD3Params, TD3]): - def _get_policy_class(self) -> type[TD3]: + def _get_algorithm_class(self) -> type[TD3]: return TD3 From 7f79994f3d1405e698e9f0fded1ec4ab51a4b3db Mon Sep 17 00:00:00 2001 From: Dominik Jain Date: Wed, 5 Mar 2025 15:29:53 +0100 Subject: [PATCH 021/230] Fix method reference (map_action_inverse) --- tianshou/data/collector.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tianshou/data/collector.py b/tianshou/data/collector.py index 06307a829..1c34aa130 100644 --- a/tianshou/data/collector.py +++ b/tianshou/data/collector.py @@ -691,7 +691,7 @@ def _compute_action_policy_hidden( # TODO: test whether envpool env explicitly except TypeError: # envpool's action space is not for per-env act_normalized_RA = np.array([self._action_space.sample() for _ in ready_env_ids_R]) - act_RA = self.algorithm.map_action_inverse(np.array(act_normalized_RA)) + act_RA = self.algorithm.policy.map_action_inverse(np.array(act_normalized_RA)) policy_R = Batch() hidden_state_RH = None # TODO: instead use a (uniform) Distribution instance that corresponds to sampling from action_space From 74f595688ba14c9f9bdd242418ac1e3b2efea483 Mon Sep 17 00:00:00 2001 From: Dominik Jain Date: Wed, 5 Mar 2025 15:30:40 +0100 Subject: [PATCH 022/230] v2: Restore high-level API support for TD3 --- tianshou/highlevel/algorithm.py | 56 +++++++++++++++++++++------------ 1 file changed, 36 insertions(+), 20 deletions(-) diff --git a/tianshou/highlevel/algorithm.py b/tianshou/highlevel/algorithm.py index 72da2850d..1e2463343 100644 --- a/tianshou/highlevel/algorithm.py +++ b/tianshou/highlevel/algorithm.py @@ -72,6 +72,7 @@ from tianshou.trainer import 
BaseTrainer, OffpolicyTrainer, OnpolicyTrainer from tianshou.trainer.base import OffPolicyTrainingConfig, OnPolicyTrainingConfig from tianshou.utils.net.common import ActorCritic +from tianshou.utils.net.discrete import Actor CHECKPOINT_DICT_KEY_MODEL = "model" CHECKPOINT_DICT_KEY_OBS_RMS = "obs_rms" @@ -154,7 +155,7 @@ def set_trainer_callbacks(self, callbacks: TrainerCallbacks) -> None: self.trainer_callbacks = callbacks @staticmethod - def _create_policy( + def _create_policy_from_args( constructor: type[TPolicy], params_dict: dict, policy_params: list[str], **kwargs ) -> TPolicy: params = {p: params_dict.pop(p) for p in policy_params} @@ -308,7 +309,7 @@ def _create_algorithm(self, envs: Environments, device: TDevice) -> Reinforce: ) dist_fn = self.actor_factory.create_dist_fn(envs) assert dist_fn is not None - policy = self._create_policy( + policy = self._create_policy_from_args( ActorPolicy, kwargs, ["action_scaling", "action_bound_method", "deterministic_eval"], @@ -381,7 +382,7 @@ def _create_kwargs(self, envs: Environments, device: TDevice) -> dict[str, Any]: def _create_algorithm(self, envs: Environments, device: TDevice) -> TAlgorithm: params = self._create_kwargs(envs, device) - policy = self._create_policy( + policy = self._create_policy_from_args( ActorPolicy, params, [ @@ -439,7 +440,7 @@ def _get_algorithm_class(self) -> type[TAlgorithm]: pass @abstractmethod - def _create_discrete_critic_only_policy( + def _create_policy( self, model: torch.nn.Module, params: dict, @@ -462,9 +463,7 @@ def _create_algorithm(self, envs: Environments, device: TDevice) -> TAlgorithm: ) envs.get_type().assert_discrete(self) action_space = cast(gymnasium.spaces.Discrete, envs.get_action_space()) - policy = self._create_discrete_critic_only_policy( - model, params_dict, action_space, envs.get_observation_space() - ) + policy = self._create_policy(model, params_dict, action_space, envs.get_observation_space()) algorithm_class = self._get_algorithm_class() return 
algorithm_class( policy=policy, @@ -474,14 +473,14 @@ def _create_algorithm(self, envs: Environments, device: TDevice) -> TAlgorithm: class DeepQLearningAlgorithmFactory(DiscreteCriticOnlyAlgorithmFactory[DQNParams, DeepQLearning]): - def _create_discrete_critic_only_policy( + def _create_policy( self, model: torch.nn.Module, params: dict, action_space: gymnasium.spaces.Discrete, observation_space: gymnasium.spaces.Space, ) -> TPolicy: - return self._create_policy( + return self._create_policy_from_args( constructor=DQNPolicy, params_dict=params, policy_params=[], @@ -537,7 +536,7 @@ def _create_algorithm(self, envs: Environments, device: TDevice) -> Algorithm: critic1=critic, ), ) - policy = self._create_policy( + policy = self._create_policy_from_args( DDPGPolicy, kwargs, ["action_scaling", "action_bound_method"], @@ -608,7 +607,7 @@ def _create_algorithm(self, envs: Environments, device: TDevice) -> Algorithm: class ActorDualCriticsAlgorithmFactory( OffPolicyAlgorithmFactory, - Generic[TActorDualCriticsParams, TAlgorithm], + Generic[TActorDualCriticsParams, TAlgorithm, TPolicy], ABC, ): def __init__( @@ -638,6 +637,12 @@ def _get_discrete_last_size_use_action_shape(self) -> bool: def _get_critic_use_action(envs: Environments) -> bool: return envs.get_type().is_continuous() + @abstractmethod + def _create_policy( + self, actor: torch.nn.Module | Actor, envs: Environments, params: dict + ) -> TPolicy: + pass + @typing.no_type_check def _create_algorithm(self, envs: Environments, device: TDevice) -> TAlgorithm: actor = self.actor_factory.create_module_opt( @@ -674,32 +679,43 @@ def _create_algorithm(self, envs: Environments, device: TDevice) -> TAlgorithm: critic2=critic2, ), ) - policy_class = self._get_algorithm_class() - return policy_class( - actor=actor.module, - actor_optim=actor.optim, + policy = self._create_policy(actor.module, envs, kwargs) + algorithm_class = self._get_algorithm_class() + return algorithm_class( + policy=policy, + 
policy_optim=actor.optim, critic=critic1.module, critic_optim=critic1.optim, critic2=critic2.module, critic2_optim=critic2.optim, - action_space=envs.get_action_space(), - observation_space=envs.get_observation_space(), **kwargs, ) -class SACAlgorithmFactory(ActorDualCriticsAlgorithmFactory[SACParams, SACPolicy]): +class SACAlgorithmFactory(ActorDualCriticsAlgorithmFactory[SACParams, SACPolicy, TPolicy]): def _get_algorithm_class(self) -> type[SACPolicy]: return SACPolicy class DiscreteSACAlgorithmFactory( - ActorDualCriticsAlgorithmFactory[DiscreteSACParams, DiscreteSACPolicy] + ActorDualCriticsAlgorithmFactory[DiscreteSACParams, DiscreteSACPolicy, TPolicy] ): def _get_algorithm_class(self) -> type[DiscreteSACPolicy]: return DiscreteSACPolicy -class TD3AlgorithmFactory(ActorDualCriticsAlgorithmFactory[TD3Params, TD3]): +class TD3AlgorithmFactory(ActorDualCriticsAlgorithmFactory[TD3Params, TD3, DDPGPolicy]): + def _create_policy( + self, actor: torch.nn.Module | Actor, envs: Environments, params: dict + ) -> DDPGPolicy: + return self._create_policy_from_args( + DDPGPolicy, + params, + ["action_scaling", "action_bound_method"], + actor=actor, + action_space=envs.get_action_space(), + observation_space=envs.get_observation_space(), + ) + def _get_algorithm_class(self) -> type[TD3]: return TD3 From e36bbd612dd7ec077a3e7353e2c6f823793d0165 Mon Sep 17 00:00:00 2001 From: Dominik Jain Date: Wed, 5 Mar 2025 15:59:07 +0100 Subject: [PATCH 023/230] v2: Adapt SAC and test_sac_with_il (without the il part) --- examples/atari/atari_sac.py | 2 +- examples/box2d/bipedal_hardcore_sac.py | 6 +- examples/box2d/mcc_sac.py | 6 +- examples/mujoco/mujoco_sac.py | 6 +- examples/offline/d4rl_cql.py | 2 +- test/continuous/test_sac_with_il.py | 48 +++--- test/discrete/test_sac.py | 2 +- test/offline/gather_pendulum_data.py | 6 +- test/offline/test_cql.py | 2 +- tianshou/highlevel/algorithm.py | 8 +- tianshou/policy/__init__.py | 4 +- tianshou/policy/base.py | 20 ++- 
tianshou/policy/imitation/cql.py | 10 +- tianshou/policy/modelfree/discrete_sac.py | 10 +- tianshou/policy/modelfree/sac.py | 170 ++++++++++++---------- 15 files changed, 164 insertions(+), 138 deletions(-) diff --git a/examples/atari/atari_sac.py b/examples/atari/atari_sac.py index 24e7d6c86..d9d427f34 100644 --- a/examples/atari/atari_sac.py +++ b/examples/atari/atari_sac.py @@ -127,7 +127,7 @@ def test_discrete_sac(args: argparse.Namespace = get_args()) -> None: policy: DiscreteSACPolicy | ICMPolicy policy = DiscreteSACPolicy( actor=actor, - actor_optim=actor_optim, + policy_optim=actor_optim, critic=critic1, critic_optim=critic1_optim, critic2=critic2, diff --git a/examples/box2d/bipedal_hardcore_sac.py b/examples/box2d/bipedal_hardcore_sac.py index 9c46b4dc6..0961347c5 100644 --- a/examples/box2d/bipedal_hardcore_sac.py +++ b/examples/box2d/bipedal_hardcore_sac.py @@ -11,7 +11,7 @@ from tianshou.data import Collector, CollectStats, VectorReplayBuffer from tianshou.env import SubprocVectorEnv -from tianshou.policy import SACPolicy +from tianshou.policy import SAC from tianshou.policy.base import Algorithm from tianshou.trainer import OffpolicyTrainer from tianshou.utils import TensorboardLogger @@ -144,9 +144,9 @@ def test_sac_bipedal(args: argparse.Namespace = get_args()) -> None: alpha_optim = torch.optim.Adam([log_alpha], lr=args.alpha_lr) args.alpha = (target_entropy, log_alpha, alpha_optim) - policy: SACPolicy = SACPolicy( + policy: SAC = SAC( actor=actor, - actor_optim=actor_optim, + policy_optim=actor_optim, critic=critic1, critic_optim=critic1_optim, critic2=critic2, diff --git a/examples/box2d/mcc_sac.py b/examples/box2d/mcc_sac.py index 838c40a37..45246dc60 100644 --- a/examples/box2d/mcc_sac.py +++ b/examples/box2d/mcc_sac.py @@ -10,7 +10,7 @@ from tianshou.data import Collector, CollectStats, VectorReplayBuffer from tianshou.env import DummyVectorEnv from tianshou.exploration import OUNoise -from tianshou.policy import SACPolicy +from 
tianshou.policy import SAC from tianshou.policy.base import Algorithm from tianshou.trainer import OffpolicyTrainer from tianshou.utils import TensorboardLogger @@ -95,9 +95,9 @@ def test_sac(args: argparse.Namespace = get_args()) -> None: alpha_optim = torch.optim.Adam([log_alpha], lr=args.alpha_lr) args.alpha = (target_entropy, log_alpha, alpha_optim) - policy: SACPolicy = SACPolicy( + policy: SAC = SAC( actor=actor, - actor_optim=actor_optim, + policy_optim=actor_optim, critic=critic1, critic_optim=critic1_optim, critic2=critic2, diff --git a/examples/mujoco/mujoco_sac.py b/examples/mujoco/mujoco_sac.py index 07c91085b..1ca6fb396 100755 --- a/examples/mujoco/mujoco_sac.py +++ b/examples/mujoco/mujoco_sac.py @@ -11,7 +11,7 @@ from tianshou.data import Collector, CollectStats, ReplayBuffer, VectorReplayBuffer from tianshou.highlevel.logger import LoggerFactoryDefault -from tianshou.policy import SACPolicy +from tianshou.policy import SAC from tianshou.policy.base import Algorithm from tianshou.trainer import OffpolicyTrainer from tianshou.utils.net.common import Net @@ -117,9 +117,9 @@ def test_sac(args: argparse.Namespace = get_args()) -> None: alpha_optim = torch.optim.Adam([log_alpha], lr=args.alpha_lr) args.alpha = (target_entropy, log_alpha, alpha_optim) - policy: SACPolicy = SACPolicy( + policy: SAC = SAC( actor=actor, - actor_optim=actor_optim, + policy_optim=actor_optim, critic=critic1, critic_optim=critic1_optim, critic2=critic2, diff --git a/examples/offline/d4rl_cql.py b/examples/offline/d4rl_cql.py index a896fffb7..23e95fcf9 100644 --- a/examples/offline/d4rl_cql.py +++ b/examples/offline/d4rl_cql.py @@ -286,7 +286,7 @@ def test_cql() -> None: policy: CQLPolicy = CQLPolicy( actor=actor, - actor_optim=actor_optim, + policy_optim=actor_optim, critic=critic, critic_optim=critic_optim, action_space=env.action_space, diff --git a/test/continuous/test_sac_with_il.py b/test/continuous/test_sac_with_il.py index 6634501a3..2c80429bf 100644 --- 
a/test/continuous/test_sac_with_il.py +++ b/test/continuous/test_sac_with_il.py @@ -8,9 +8,11 @@ from tianshou.data import Collector, CollectStats, VectorReplayBuffer from tianshou.env import DummyVectorEnv -from tianshou.policy import ImitationPolicy, SACPolicy +from tianshou.policy import SAC, ImitationPolicy from tianshou.policy.base import Algorithm +from tianshou.policy.modelfree.sac import SACPolicy from tianshou.trainer import OffpolicyTrainer +from tianshou.trainer.base import OffPolicyTrainingConfig from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Net from tianshou.utils.net.continuous import Actor, ActorProb, Critic @@ -110,9 +112,13 @@ def test_sac_with_il(args: argparse.Namespace = get_args()) -> None: alpha_optim = torch.optim.Adam([log_alpha], lr=args.alpha_lr) args.alpha = (target_entropy, log_alpha, alpha_optim) - policy: SACPolicy = SACPolicy( + policy = SACPolicy( actor=actor, - actor_optim=actor_optim, + action_space=env.action_space, + ) + algorithm = SAC( + policy=policy, + policy_optim=actor_optim, critic=critic1, critic_optim=critic1_optim, critic2=critic2, @@ -121,16 +127,15 @@ def test_sac_with_il(args: argparse.Namespace = get_args()) -> None: gamma=args.gamma, alpha=args.alpha, estimation_step=args.n_step, - action_space=env.action_space, ) # collector train_collector = Collector[CollectStats]( - policy, + algorithm, train_envs, VectorReplayBuffer(args.buffer_size, len(train_envs)), exploration_noise=True, ) - test_collector = Collector[CollectStats](policy, test_envs) + test_collector = Collector[CollectStats](algorithm, test_envs) # train_collector.collect(n_step=args.buffer_size) # log log_path = os.path.join(args.logdir, args.task, "sac") @@ -143,21 +148,22 @@ def save_best_fn(policy: Algorithm) -> None: def stop_fn(mean_rewards: float) -> bool: return mean_rewards >= args.reward_threshold - # trainer - result = OffpolicyTrainer( - policy=policy, - train_collector=train_collector, - 
test_collector=test_collector, - max_epoch=args.epoch, - step_per_epoch=args.step_per_epoch, - step_per_collect=args.step_per_collect, - episode_per_test=args.test_num, - batch_size=args.batch_size, - update_per_step=args.update_per_step, - stop_fn=stop_fn, - save_best_fn=save_best_fn, - logger=logger, - ).run() + # train + result = algorithm.run_training( + OffPolicyTrainingConfig( + train_collector=train_collector, + test_collector=test_collector, + max_epoch=args.epoch, + step_per_epoch=args.step_per_epoch, + step_per_collect=args.step_per_collect, + episode_per_test=args.test_num, + batch_size=args.batch_size, + update_per_step=args.update_per_step, + stop_fn=stop_fn, + save_best_fn=save_best_fn, + logger=logger, + ) + ) assert stop_fn(result.best_reward) # here we define an imitation collector with a trivial policy diff --git a/test/discrete/test_sac.py b/test/discrete/test_sac.py index 432d8c2c8..06a0e9fd9 100644 --- a/test/discrete/test_sac.py +++ b/test/discrete/test_sac.py @@ -94,7 +94,7 @@ def test_discrete_sac(args: argparse.Namespace = get_args()) -> None: policy: DiscreteSACPolicy[DiscreteSACTrainingStats] = DiscreteSACPolicy( actor=actor, - actor_optim=actor_optim, + policy_optim=actor_optim, critic=critic1, action_space=env.action_space, critic_optim=critic1_optim, diff --git a/test/offline/gather_pendulum_data.py b/test/offline/gather_pendulum_data.py index feca48794..a698530a6 100644 --- a/test/offline/gather_pendulum_data.py +++ b/test/offline/gather_pendulum_data.py @@ -9,7 +9,7 @@ from tianshou.data import Collector, CollectStats, VectorReplayBuffer from tianshou.env import DummyVectorEnv -from tianshou.policy import SACPolicy +from tianshou.policy import SAC from tianshou.policy.base import Algorithm from tianshou.policy.modelfree.sac import SACTrainingStats from tianshou.trainer import OffpolicyTrainer @@ -116,9 +116,9 @@ def gather_data() -> VectorReplayBuffer: alpha_optim = torch.optim.Adam([log_alpha], lr=args.alpha_lr) args.alpha = 
(target_entropy, log_alpha, alpha_optim) - policy: SACPolicy[SACTrainingStats] = SACPolicy( + policy: SAC[SACTrainingStats] = SAC( actor=actor, - actor_optim=actor_optim, + policy_optim=actor_optim, critic=critic, critic_optim=critic_optim, tau=args.tau, diff --git a/test/offline/test_cql.py b/test/offline/test_cql.py index 56be57a9a..8eac8fee3 100644 --- a/test/offline/test_cql.py +++ b/test/offline/test_cql.py @@ -138,7 +138,7 @@ def test_cql(args: argparse.Namespace = get_args()) -> None: policy: CQLPolicy[CQLTrainingStats] = CQLPolicy( actor=actor, - actor_optim=actor_optim, + policy_optim=actor_optim, critic=critic, critic_optim=critic_optim, # CQL seems to perform better without action scaling diff --git a/tianshou/highlevel/algorithm.py b/tianshou/highlevel/algorithm.py index 1e2463343..e18f7771f 100644 --- a/tianshou/highlevel/algorithm.py +++ b/tianshou/highlevel/algorithm.py @@ -50,6 +50,7 @@ DDPG, NPG, PPO, + SAC, TD3, TRPO, Algorithm, @@ -58,7 +59,6 @@ IQNPolicy, REDQPolicy, Reinforce, - SACPolicy, ) from tianshou.policy.base import ( OffPolicyAlgorithm, @@ -692,9 +692,9 @@ def _create_algorithm(self, envs: Environments, device: TDevice) -> TAlgorithm: ) -class SACAlgorithmFactory(ActorDualCriticsAlgorithmFactory[SACParams, SACPolicy, TPolicy]): - def _get_algorithm_class(self) -> type[SACPolicy]: - return SACPolicy +class SACAlgorithmFactory(ActorDualCriticsAlgorithmFactory[SACParams, SAC, TPolicy]): + def _get_algorithm_class(self) -> type[SAC]: + return SAC class DiscreteSACAlgorithmFactory( diff --git a/tianshou/policy/__init__.py b/tianshou/policy/__init__.py index ed3ce71d9..1a347eccf 100644 --- a/tianshou/policy/__init__.py +++ b/tianshou/policy/__init__.py @@ -18,7 +18,7 @@ from tianshou.policy.modelfree.ppo import PPO from tianshou.policy.modelfree.trpo import TRPO from tianshou.policy.modelfree.td3 import TD3 -from tianshou.policy.modelfree.sac import SACPolicy +from tianshou.policy.modelfree.sac import SAC from tianshou.policy.modelfree.redq 
import REDQPolicy from tianshou.policy.modelfree.discrete_sac import DiscreteSACPolicy from tianshou.policy.imitation.base import ImitationPolicy @@ -50,7 +50,7 @@ "PPO", "TRPO", "TD3", - "SACPolicy", + "SAC", "REDQPolicy", "DiscreteSACPolicy", "ImitationPolicy", diff --git a/tianshou/policy/base.py b/tianshou/policy/base.py index 7f10b9f91..5c0fd3471 100644 --- a/tianshou/policy/base.py +++ b/tianshou/policy/base.py @@ -151,6 +151,14 @@ def __init__( action_scaling: bool = False, action_bound_method: Literal["clip", "tanh"] | None = "clip", ): + """ + :param action_space: the environment's action_space. + :param observation_space: the environment's observation space + :param action_scaling: if True, scale the action from [-1, 1] to the range + of action_space. Only used if the action_space is continuous. + :param action_bound_method: method to bound action to range [-1, 1]. + Only used if the action_space is continuous. + """ allowed_action_bound_methods = ("clip", "tanh") if ( action_bound_method is not None @@ -401,14 +409,6 @@ class Algorithm(torch.nn.Module, Generic[TPolicy, TTrainingConfig, TTrainingStat torch.save(policy.state_dict(), "policy.pth") policy.load_state_dict(torch.load("policy.pth")) - - :param action_space: Env's action_space. - :param observation_space: Env's observation space. TODO: appears unused... - :param action_scaling: if True, scale the action from [-1, 1] to the range - of action_space. Only used if the action_space is continuous. - :param action_bound_method: method to bound action to range [-1, 1]. - Only used if the action_space is continuous. - :param lr_scheduler: if not None, will be called in `policy.update()`. """ def __init__( @@ -417,6 +417,10 @@ def __init__( policy: TPolicy, lr_scheduler: TLearningRateScheduler | None = None, ) -> None: + """ + :param policy: the policy + :param lr_scheduler: if not None, will be called in `update()`. 
+ """ super().__init__() self.policy: TPolicy = policy self.lr_scheduler = lr_scheduler diff --git a/tianshou/policy/imitation/cql.py b/tianshou/policy/imitation/cql.py index 35e3fe51e..cb45b9f4d 100644 --- a/tianshou/policy/imitation/cql.py +++ b/tianshou/policy/imitation/cql.py @@ -12,7 +12,7 @@ from tianshou.data.buffer.base import TBuffer from tianshou.data.types import RolloutBatchProtocol from tianshou.exploration import BaseNoise -from tianshou.policy import SACPolicy +from tianshou.policy import SAC from tianshou.policy.base import TLearningRateScheduler from tianshou.policy.modelfree.sac import SACTrainingStats from tianshou.utils.conversion import to_optional_float @@ -30,12 +30,12 @@ class CQLTrainingStats(SACTrainingStats): TCQLTrainingStats = TypeVar("TCQLTrainingStats", bound=CQLTrainingStats) -class CQLPolicy(SACPolicy[TCQLTrainingStats]): +class CQLPolicy(SAC[TCQLTrainingStats]): """Implementation of CQL algorithm. arXiv:2006.04779. :param actor: the actor network following the rules in :class:`~tianshou.policy.BasePolicy`. (s -> a) - :param actor_optim: The optimizer for actor network. + :param policy_optim: The optimizer for actor network. :param critic: The first critic network. :param critic_optim: The optimizer for the first critic network. :param action_space: Env's action space. 
@@ -83,7 +83,7 @@ def __init__( self, *, actor: ActorProb, - actor_optim: torch.optim.Optimizer, + policy_optim: torch.optim.Optimizer, critic: torch.nn.Module, critic_optim: torch.optim.Optimizer, action_space: gym.spaces.Box, @@ -116,7 +116,7 @@ def __init__( ) -> None: super().__init__( actor=actor, - actor_optim=actor_optim, + policy_optim=policy_optim, critic=critic, critic_optim=critic_optim, action_space=action_space, diff --git a/tianshou/policy/modelfree/discrete_sac.py b/tianshou/policy/modelfree/discrete_sac.py index 21575dcf0..d4b38a1e5 100644 --- a/tianshou/policy/modelfree/discrete_sac.py +++ b/tianshou/policy/modelfree/discrete_sac.py @@ -9,7 +9,7 @@ from tianshou.data import Batch, ReplayBuffer, to_torch from tianshou.data.types import ActBatchProtocol, ObsBatchProtocol, RolloutBatchProtocol -from tianshou.policy import SACPolicy +from tianshou.policy import SAC from tianshou.policy.base import TLearningRateScheduler from tianshou.policy.modelfree.sac import SACTrainingStats from tianshou.utils.net.discrete import Actor, Critic @@ -23,11 +23,11 @@ class DiscreteSACTrainingStats(SACTrainingStats): TDiscreteSACTrainingStats = TypeVar("TDiscreteSACTrainingStats", bound=DiscreteSACTrainingStats) -class DiscreteSACPolicy(SACPolicy[TDiscreteSACTrainingStats]): +class DiscreteSACPolicy(SAC[TDiscreteSACTrainingStats]): """Implementation of SAC for Discrete Action Settings. arXiv:1910.07207. :param actor: the actor network following the rules (s_B -> dist_input_BD) - :param actor_optim: the optimizer for actor network. + :param policy_optim: the optimizer for actor network. :param critic: the first critic network. (s, a -> Q(s, a)) :param critic_optim: the optimizer for the first critic network. :param action_space: Env's action space. Should be gym.spaces.Box. 
@@ -55,7 +55,7 @@ def __init__( self, *, actor: torch.nn.Module | Actor, - actor_optim: torch.optim.Optimizer, + policy_optim: torch.optim.Optimizer, critic: torch.nn.Module | Critic, critic_optim: torch.optim.Optimizer, action_space: gym.spaces.Discrete, @@ -70,7 +70,7 @@ def __init__( ) -> None: super().__init__( actor=actor, - actor_optim=actor_optim, + policy_optim=policy_optim, critic=critic, critic_optim=critic_optim, action_space=action_space, diff --git a/tianshou/policy/modelfree/sac.py b/tianshou/policy/modelfree/sac.py index d5678b047..6ee47dfab 100644 --- a/tianshou/policy/modelfree/sac.py +++ b/tianshou/policy/modelfree/sac.py @@ -15,7 +15,7 @@ ) from tianshou.exploration import BaseNoise from tianshou.policy import DDPG -from tianshou.policy.base import TLearningRateScheduler, TrainingStats +from tianshou.policy.base import Policy, TLearningRateScheduler, TrainingStats from tianshou.utils.conversion import to_optional_float from tianshou.utils.net.continuous import ActorProb from tianshou.utils.optim import clone_optimizer @@ -50,42 +50,71 @@ class SACTrainingStats(TrainingStats): TSACTrainingStats = TypeVar("TSACTrainingStats", bound=SACTrainingStats) +class SACPolicy(Policy): + def __init__( + self, + *, + actor: torch.nn.Module | ActorProb, + deterministic_eval: bool = True, + action_scaling: bool = True, + action_bound_method: Literal["clip"] | None = "clip", + action_space: gym.Space, + observation_space: gym.Space | None = None, + ): + """ + :param actor: the actor network following the rules (s -> dist_input_BD) + :param deterministic_eval: whether to use deterministic action + (mode of Gaussian policy) in evaluation mode instead of stochastic + action sampled by the policy. Does not affect training. + :param action_scaling: whether to map actions from range [-1, 1] + to range[action_spaces.low, action_spaces.high]. 
+ :param action_bound_method: method to bound action to range [-1, 1], + can be either "clip" (for simply clipping the action) + or empty string for no bounding. Only used if the action_space is continuous. + This parameter is ignored in SAC, which used tanh squashing after sampling + unbounded from the gaussian policy (as in (arXiv 1801.01290): Equation 21.). + :param action_space: the action space of the environment + :param observation_space: the observation space of the environment + """ + super().__init__( + action_space=action_space, + observation_space=observation_space, + action_scaling=action_scaling, + action_bound_method=action_bound_method, + ) + self.actor = actor + self.deterministic_eval = deterministic_eval + + def forward( # type: ignore + self, + batch: ObsBatchProtocol, + state: dict | Batch | np.ndarray | None = None, + **kwargs: Any, + ) -> DistLogProbBatchProtocol: + (loc_B, scale_B), hidden_BH = self.actor(batch.obs, state=state, info=batch.info) + dist = Independent(Normal(loc=loc_B, scale=scale_B), 1) + if self.deterministic_eval and not self.is_within_training_step: + act_B = dist.mode + else: + act_B = dist.rsample() + log_prob = dist.log_prob(act_B).unsqueeze(-1) + + squashed_action = torch.tanh(act_B) + log_prob = correct_log_prob_gaussian_tanh(log_prob, squashed_action) + result = Batch( + logits=(loc_B, scale_B), + act=squashed_action, + state=hidden_BH, + dist=dist, + log_prob=log_prob, + ) + return cast(DistLogProbBatchProtocol, result) + + # TODO: the type ignore here is needed b/c the hierarchy is actually broken! Should reconsider the inheritance structure. -class SACPolicy(DDPG[TSACTrainingStats], Generic[TSACTrainingStats]): # type: ignore[type-var] +class SAC(DDPG[TSACTrainingStats], Generic[TSACTrainingStats]): # type: ignore[type-var] """Implementation of Soft Actor-Critic. arXiv:1812.05905. - :param actor: the actor network following the rules (s -> dist_input_BD) - :param actor_optim: the optimizer for actor network. 
- :param critic: the first critic network. (s, a -> Q(s, a)) - :param critic_optim: the optimizer for the first critic network. - :param action_space: Env's action space. Should be gym.spaces.Box. - :param critic2: the second critic network. (s, a -> Q(s, a)). - If None, use the same network as critic (via deepcopy). - :param critic2_optim: the optimizer for the second critic network. - If None, clone critic_optim to use for critic2.parameters(). - :param tau: param for soft update of the target network. - :param gamma: discount factor, in [0, 1]. - :param alpha: entropy regularization coefficient. - If a tuple (target_entropy, log_alpha, alpha_optim) is provided, - then alpha is automatically tuned. - :param estimation_step: The number of steps to look ahead. - :param exploration_noise: add noise to action for exploration. - This is useful when solving "hard exploration" problems. - "default" is equivalent to GaussianNoise(sigma=0.1). - :param deterministic_eval: whether to use deterministic action - (mode of Gaussian policy) in evaluation mode instead of stochastic - action sampled by the policy. Does not affect training. - :param action_scaling: whether to map actions from range [-1, 1] - to range[action_spaces.low, action_spaces.high]. - :param action_bound_method: method to bound action to range [-1, 1], - can be either "clip" (for simply clipping the action) - or empty string for no bounding. Only used if the action_space is continuous. - This parameter is ignored in SAC, which used tanh squashing after sampling - unbounded from the gaussian policy (as in (arXiv 1801.01290): Equation 21.). - :param observation_space: Env's observation space. - :param lr_scheduler: a learning rate scheduler that adjusts the learning rate - in optimizer in each policy.update() - .. 
seealso:: Please refer to :class:`~tianshou.policy.BasePolicy` for more detailed @@ -95,11 +124,10 @@ class SACPolicy(DDPG[TSACTrainingStats], Generic[TSACTrainingStats]): # type: i def __init__( self, *, - actor: torch.nn.Module | ActorProb, - actor_optim: torch.optim.Optimizer, + policy: SACPolicy, + policy_optim: torch.optim.Optimizer, critic: torch.nn.Module, critic_optim: torch.optim.Optimizer, - action_space: gym.Space, critic2: torch.nn.Module | None = None, critic2_optim: torch.optim.Optimizer | None = None, tau: float = 0.005, @@ -108,24 +136,39 @@ def __init__( estimation_step: int = 1, exploration_noise: BaseNoise | Literal["default"] | None = None, deterministic_eval: bool = True, - action_scaling: bool = True, - action_bound_method: Literal["clip"] | None = "clip", - observation_space: gym.Space | None = None, lr_scheduler: TLearningRateScheduler | None = None, ) -> None: + """ + :param policy: the policy + :param policy_optim: the optimizer for actor network. + :param critic: the first critic network. (s, a -> Q(s, a)) + :param critic_optim: the optimizer for the first critic network. + :param action_space: Env's action space. Should be gym.spaces.Box. + :param critic2: the second critic network. (s, a -> Q(s, a)). + If None, use the same network as critic (via deepcopy). + :param critic2_optim: the optimizer for the second critic network. + If None, clone critic_optim to use for critic2.parameters(). + :param tau: param for soft update of the target network. + :param gamma: discount factor, in [0, 1]. + :param alpha: entropy regularization coefficient. + If a tuple (target_entropy, log_alpha, alpha_optim) is provided, + then alpha is automatically tuned. + :param estimation_step: The number of steps to look ahead. + :param exploration_noise: add noise to action for exploration. + This is useful when solving "hard exploration" problems. + "default" is equivalent to GaussianNoise(sigma=0.1). 
+ :param lr_scheduler: a learning rate scheduler that adjusts the learning rate + in optimizer in each policy.update() + """ super().__init__( - actor=actor, - policy_optim=actor_optim, + policy=policy, # TODO: violation of Liskov substitution principle + policy_optim=policy_optim, critic=critic, critic_optim=critic_optim, - action_space=action_space, tau=tau, gamma=gamma, exploration_noise=exploration_noise, estimation_step=estimation_step, - action_scaling=action_scaling, - action_bound_method=action_bound_method, - observation_space=observation_space, lr_scheduler=lr_scheduler, ) critic2 = critic2 or deepcopy(critic) @@ -157,11 +200,10 @@ def __init__( ) # can we convert alpha to a constant tensor here? then mypy wouldn't complain self.alpha = alpha - # TODO or not TODO: add to BasePolicy? self._check_field_validity() def _check_field_validity(self) -> None: - if not isinstance(self.action_space, gym.spaces.Box): + if not isinstance(self.policy.action_space, gym.spaces.Box): raise ValueError( f"SACPolicy only supports gym.spaces.Box, but got {self.action_space=}." 
f"Please use DiscreteSACPolicy for discrete action spaces.", @@ -173,7 +215,7 @@ def is_auto_alpha(self) -> bool: def train(self, mode: bool = True) -> Self: self.training = mode - self.actor.train(mode) + self.policy.train(mode) self.critic.train(mode) self.critic2.train(mode) return self @@ -182,38 +224,12 @@ def sync_weight(self) -> None: self.soft_update(self.critic_old, self.critic, self.tau) self.soft_update(self.critic2_old, self.critic2, self.tau) - # TODO: violates Liskov substitution principle - def forward( # type: ignore - self, - batch: ObsBatchProtocol, - state: dict | Batch | np.ndarray | None = None, - **kwargs: Any, - ) -> DistLogProbBatchProtocol: - (loc_B, scale_B), hidden_BH = self.actor(batch.obs, state=state, info=batch.info) - dist = Independent(Normal(loc=loc_B, scale=scale_B), 1) - if self.deterministic_eval and not self.is_within_training_step: - act_B = dist.mode - else: - act_B = dist.rsample() - log_prob = dist.log_prob(act_B).unsqueeze(-1) - - squashed_action = torch.tanh(act_B) - log_prob = correct_log_prob_gaussian_tanh(log_prob, squashed_action) - result = Batch( - logits=(loc_B, scale_B), - act=squashed_action, - state=hidden_BH, - dist=dist, - log_prob=log_prob, - ) - return cast(DistLogProbBatchProtocol, result) - def _target_q(self, buffer: ReplayBuffer, indices: np.ndarray) -> torch.Tensor: obs_next_batch = Batch( obs=buffer[indices].obs_next, info=[None] * len(indices), ) # obs_next: s_{t+n} - obs_next_result = self(obs_next_batch) + obs_next_result = self.policy(obs_next_batch) act_ = obs_next_result.act return ( torch.min( @@ -230,7 +246,7 @@ def _update_with_batch(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: batch.weight = (td1 + td2) / 2.0 # prio-buffer # actor - obs_result = self(batch) + obs_result = self.policy(batch) act = obs_result.act current_q1a = self.critic(batch.obs, act).flatten() current_q2a = self.critic2(batch.obs, act).flatten() From 9b06cd9e749b8014956609f2b196a2773c2cd0a6 Mon Sep 17 00:00:00 
2001 From: Dominik Jain Date: Wed, 5 Mar 2025 16:17:20 +0100 Subject: [PATCH 024/230] v2: Use train mode for full Algorithm in update() [fix] --- tianshou/policy/base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tianshou/policy/base.py b/tianshou/policy/base.py index 5c0fd3471..f85ead605 100644 --- a/tianshou/policy/base.py +++ b/tianshou/policy/base.py @@ -588,7 +588,7 @@ def update( batch, indices = buffer.sample(sample_size) self.updating = True batch = self.process_fn(batch, buffer, indices) - with torch_train_mode(self.policy): + with torch_train_mode(self): training_stat = self._update_with_batch(batch, **kwargs) self.post_process_fn(batch, buffer, indices) if self.lr_scheduler is not None: From d1fd07f821cb6f862f2c08a5ca62e679212599b7 Mon Sep 17 00:00:00 2001 From: Dominik Jain Date: Thu, 6 Mar 2025 12:47:58 +0100 Subject: [PATCH 025/230] v2: Major refactoring of DDPG, TD3 and SAC Introduce appropriate base classes * ActorCriticOffPolicyAlgorithm * ActorDualCriticsOffPolicyAlgorithm eliminating the inheritance issues that caused violations of the Liskov substitution principle: * DDPG inherits from ActorCriticOffPolicyAlgorithm * ActorDualCriticsOffPolicyAlgorithm extends ActorCriticOffPolicyAlgorithm * SAC and TD3 now inherit from ActorDualCriticsOffPolicyAlgorithm instead of DDPG --- CHANGELOG.md | 7 +- tianshou/policy/base.py | 15 +- tianshou/policy/imitation/bcq.py | 8 +- tianshou/policy/imitation/cql.py | 8 +- tianshou/policy/imitation/td3_bc.py | 10 +- tianshou/policy/modelfree/ddpg.py | 232 +++++++++++++++------- tianshou/policy/modelfree/discrete_sac.py | 2 +- tianshou/policy/modelfree/redq.py | 4 +- tianshou/policy/modelfree/sac.py | 71 +++---- tianshou/policy/modelfree/td3.py | 176 +++++++++++----- 10 files changed, 339 insertions(+), 194 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index cdb95a920..a333aeeba 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,7 +10,12 @@ * `DDPGPolicy` -> `DDPG` * The `Algorithm` 
abstraction can directly initiate the learning process via method `run_training`. * Fixed issues in the class hierarchy (e.g. violations of the Liskov substitution principle): - * `NPG` no longer inherits from `A2C` but from a new abstract base class + * Introduced base classes (to retain factorization without abusive inheritance): + * `ActorCriticOffPolicyAlgorithm` + * `ActorDualCriticsOffPolicyAlgorithm` (extends `ActorCriticOffPolicyAlgorithm`) + * `NPG` no longer inherits from `A2C` but from a new abstract base class + * `SAC`: Inherit from `ActorDualCriticsOffPolicyAlgorithm` instead of `DDPG` + * `TD3`: Inherit from `ActorDualCriticsOffPolicyAlgorithm` instead of `DDPG` * Moved Atari helper modules `atari_network` and `atari_wrapper` to the library under `tianshou.env.atari`. diff --git a/tianshou/policy/base.py b/tianshou/policy/base.py index f85ead605..6be8ced61 100644 --- a/tianshou/policy/base.py +++ b/tianshou/policy/base.py @@ -462,8 +462,15 @@ def exploration_noise( """ return act - def soft_update(self, tgt: nn.Module, src: nn.Module, tau: float) -> None: - """Softly update the parameters of target module towards the parameters of source module.""" + def _polyak_parameter_update(self, tgt: nn.Module, src: nn.Module, tau: float) -> None: + """Softly updates the parameters of a target network `tgt` with the parameters of a source network `src` + using Polyak averaging: `tau * src + (1 - tau) * tgt`. + + :param tgt: the target network that receives the parameter update + :param src: the source network whose parameters are used for the update + :param tau: the fraction with which to use the source network's parameters, the inverse `1-tau` being + the fraction with which to retain the target network's parameters. 
+ """ for tgt_param, src_param in zip(tgt.parameters(), src.parameters(), strict=True): tgt_param.data.copy_(tau * src_param.data + (1 - tau) * tgt_param.data) @@ -692,8 +699,8 @@ def compute_nstep_return( :param batch: a data batch, which is equal to buffer[indices]. :param buffer: the data buffer. :param indices: tell batch's location in buffer - :param function target_q_fn: a function which compute target Q value - of "obs_next" given data buffer and wanted indices. + :param target_q_fn: a function which computes the target Q value + of "obs_next" given data buffer and wanted indices (`n_step` steps ahead). :param gamma: the discount factor, should be in [0, 1]. :param n_step: the number of estimation step, should be an int greater than 0. diff --git a/tianshou/policy/imitation/bcq.py b/tianshou/policy/imitation/bcq.py index c98a6b712..6d42f29bb 100644 --- a/tianshou/policy/imitation/bcq.py +++ b/tianshou/policy/imitation/bcq.py @@ -150,9 +150,11 @@ def forward( def sync_weight(self) -> None: """Soft-update the weight for the target network.""" - self.soft_update(self.critic_target, self.critic, self.tau) - self.soft_update(self.critic2_target, self.critic2, self.tau) - self.soft_update(self.actor_perturbation_target, self.actor_perturbation, self.tau) + self._polyak_parameter_update(self.critic_target, self.critic, self.tau) + self._polyak_parameter_update(self.critic2_target, self.critic2, self.tau) + self._polyak_parameter_update( + self.actor_perturbation_target, self.actor_perturbation, self.tau + ) def _update_with_batch( self, diff --git a/tianshou/policy/imitation/cql.py b/tianshou/policy/imitation/cql.py index cb45b9f4d..7b57d1fda 100644 --- a/tianshou/policy/imitation/cql.py +++ b/tianshou/policy/imitation/cql.py @@ -164,10 +164,10 @@ def train(self, mode: bool = True) -> Self: self.critic2.train(mode) return self - def sync_weight(self) -> None: + def _update_lagged_network_weights(self) -> None: """Soft-update the weight for the target network.""" - 
self.soft_update(self.critic_old, self.critic, self.tau) - self.soft_update(self.critic2_old, self.critic2, self.tau) + self._polyak_parameter_update(self.critic_old, self.critic, self.tau) + self._polyak_parameter_update(self.critic2_old, self.critic2, self.tau) def actor_pred(self, obs: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: batch = Batch(obs=obs, info=[None] * len(obs)) @@ -389,7 +389,7 @@ def _update_with_batch(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: clip_grad_norm_(self.critic2.parameters(), self.clip_grad) self.critic2_optim.step() - self.sync_weight() + self._update_lagged_network_weights() return CQLTrainingStats( # type: ignore[return-value] actor_loss=to_optional_float(actor_loss), diff --git a/tianshou/policy/imitation/td3_bc.py b/tianshou/policy/imitation/td3_bc.py index 71ed24163..efde31a8c 100644 --- a/tianshou/policy/imitation/td3_bc.py +++ b/tianshou/policy/imitation/td3_bc.py @@ -106,8 +106,12 @@ def __init__( def _update_with_batch(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> TTD3BCTrainingStats: # type: ignore # critic 1&2 - td1, critic1_loss = self._mse_optimizer(batch, self.critic, self.critic_optim) - td2, critic2_loss = self._mse_optimizer(batch, self.critic2, self.critic2_optim) + td1, critic1_loss = self._minimize_critic_squared_loss( + batch, self.critic, self.critic_optim + ) + td2, critic2_loss = self._minimize_critic_squared_loss( + batch, self.critic2, self.critic2_optim + ) batch.weight = (td1 + td2) / 2.0 # prio-buffer # actor @@ -120,7 +124,7 @@ def _update_with_batch(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: actor_loss.backward() self._last = actor_loss.item() self.actor_optim.step() - self.sync_weight() + self._update_lagged_network_weights() self._cnt += 1 return TD3BCTrainingStats( # type: ignore[return-value] diff --git a/tianshou/policy/modelfree/ddpg.py b/tianshou/policy/modelfree/ddpg.py index ef9c0231f..d867d7214 100644 --- a/tianshou/policy/modelfree/ddpg.py 
+++ b/tianshou/policy/modelfree/ddpg.py @@ -1,4 +1,5 @@ import warnings +from abc import ABC from copy import deepcopy from dataclasses import dataclass from typing import Any, Generic, Literal, Self, TypeVar, cast @@ -22,7 +23,9 @@ OffPolicyAlgorithm, Policy, TLearningRateScheduler, + TPolicy, TrainingStats, + TTrainingStats, ) from tianshou.utils.net.continuous import Actor, Critic @@ -98,36 +101,36 @@ def forward( return cast(ActStateBatchProtocol, Batch(act=actions, state=hidden)) -class DDPG(OffPolicyAlgorithm[DDPGPolicy, TDDPGTrainingStats], Generic[TDDPGTrainingStats]): - """Implementation of Deep Deterministic Policy Gradient. arXiv:1509.02971. +TActBatchProtocol = TypeVar("TActBatchProtocol", bound=ActBatchProtocol) - :param policy: the policy - :param policy_optim: The optimizer for actor network. - :param critic: The critic network. (s, a -> Q(s, a)) - :param critic_optim: The optimizer for critic network. - :param tau: Param for soft update of the target network. - :param gamma: Discount factor, in [0, 1]. - :param exploration_noise: The exploration noise, added to the action. Defaults - to ``GaussianNoise(sigma=0.1)``. - :param estimation_step: The number of steps to look ahead. - :param lr_scheduler: if not None, will be called in `policy.update()`. - .. seealso:: +class ActorCriticOffPolicyAlgorithm( + OffPolicyAlgorithm[TPolicy, TTrainingStats], + Generic[TPolicy, TTrainingStats, TActBatchProtocol], + ABC, +): + """Base class for actor-critic off-policy algorithms that use a lagged critic + as a target network. - Please refer to :class:`~tianshou.policy.BasePolicy` for more detailed - explanation. + Its implementation of `process_fn` adds the n-step return to the batch, using the + Q-values computed by the target network (lagged critic, `critic_old`) in order to compute the + reward-to-go. 
+ + Specializations can override the action computation (`_target_q_compute_action`) or the + Q-value computation based on these actions (`_target_q_compute_value`) to customize the + target Q-value computation. """ def __init__( self, *, - policy: DDPGPolicy, + policy: Any, policy_optim: torch.optim.Optimizer, - critic: torch.nn.Module | Critic, + critic: torch.nn.Module, critic_optim: torch.optim.Optimizer, tau: float = 0.005, gamma: float = 0.99, - exploration_noise: BaseNoise | Literal["default"] | None = "default", + exploration_noise: BaseNoise | Literal["default"] | None = None, estimation_step: int = 1, lr_scheduler: TLearningRateScheduler | None = None, ) -> None: @@ -137,8 +140,6 @@ def __init__( policy=policy, lr_scheduler=lr_scheduler, ) - self.actor_old = deepcopy(policy.actor) - self.actor_old.eval() self.policy_optim = policy_optim self.critic = critic self.critic_old = deepcopy(critic) @@ -152,34 +153,48 @@ def __init__( # there is already a method called exploration_noise() in the base class # Now this method doesn't apply any noise and is also not overridden. 
See TODO there self._exploration_noise = exploration_noise - # it is only a little difference to use GaussianNoise - # self.noise = OUNoise() self.estimation_step = estimation_step def set_exp_noise(self, noise: BaseNoise | None) -> None: """Set the exploration noise.""" self._exploration_noise = noise - def train(self, mode: bool = True) -> Self: - """Set the module in training mode, except for the target network.""" - self.training = mode - self.policy.actor.train(mode) - self.critic.train(mode) - return self + _TArrOrActBatch = TypeVar("_TArrOrActBatch", bound="np.ndarray | ActBatchProtocol") - def sync_weight(self) -> None: - """Soft-update the weight for the target network.""" - self.soft_update(self.actor_old, self.policy.actor, self.tau) - self.soft_update(self.critic_old, self.critic, self.tau) + def exploration_noise( + self, + act: _TArrOrActBatch, + batch: ObsBatchProtocol, + ) -> _TArrOrActBatch: + if self._exploration_noise is None: + return act + if isinstance(act, np.ndarray): + return act + self._exploration_noise(act.shape) + warnings.warn("Cannot add exploration noise to non-numpy_array action.") + return act - def _target_q(self, buffer: ReplayBuffer, indices: np.ndarray) -> torch.Tensor: - obs_next_batch = Batch( - obs=buffer[indices].obs_next, - info=[None] * len(indices), - ) # obs_next: s_{t+n} - return self.critic_old( - obs_next_batch.obs, self.policy(obs_next_batch, model=self.actor_old).act - ) + @staticmethod + def _minimize_critic_squared_loss( + batch: RolloutBatchProtocol, + critic: torch.nn.Module, + optimizer: torch.optim.Optimizer, + ) -> tuple[torch.Tensor, torch.Tensor]: + """Takes an optimizer step to minimize the squared loss of the critic given a batch of data. + + :param batch: the batch containing the observations, actions, returns, and (optionally) weights. + :param critic: the critic network to minimize the loss for. + :param optimizer: the optimizer for the critic's parameters. 
+ :return: a pair (`td`, `loss`), where `td` is the tensor of errors (current - target) and `loss` is the MSE loss. + """ + weight = getattr(batch, "weight", 1.0) + current_q = critic(batch.obs, batch.act).flatten() + target_q = batch.returns.flatten() + td = current_q - target_q + critic_loss = (td.pow(2) * weight).mean() + optimizer.zero_grad() + critic_loss.backward() + optimizer.step() + return td, critic_loss def process_fn( self, @@ -187,6 +202,8 @@ def process_fn( buffer: ReplayBuffer, indices: np.ndarray, ) -> RolloutBatchProtocol | BatchWithReturnsProtocol: + # add the n-step return to the batch, which the critic (Q-functions) seeks to match, + # based the Q-values computed by the target network (lagged critic) return self.compute_nstep_return( batch=batch, buffer=buffer, @@ -196,47 +213,116 @@ def process_fn( n_step=self.estimation_step, ) - @staticmethod - def _mse_optimizer( - batch: RolloutBatchProtocol, - critic: torch.nn.Module, - optimizer: torch.optim.Optimizer, - ) -> tuple[torch.Tensor, torch.Tensor]: - """A simple wrapper script for updating critic network.""" - weight = getattr(batch, "weight", 1.0) - current_q = critic(batch.obs, batch.act).flatten() - target_q = batch.returns.flatten() - td = current_q - target_q - # critic_loss = F.mse_loss(current_q1, target_q) - critic_loss = (td.pow(2) * weight).mean() - optimizer.zero_grad() - critic_loss.backward() - optimizer.step() - return td, critic_loss + def _target_q_compute_action(self, batch: Batch) -> TActBatchProtocol: + """ + Computes the action to be taken for the given batch (containing the observations) + within the context of Q-value target computation. + + :param batch: the batch containing the observations. + :return: batch containing the actions to be taken. + """ + return self.policy(batch) + + def _target_q_compute_value(self, batch: Batch, act_batch: TActBatchProtocol) -> torch.Tensor: + """ + Computes the target Q-value given a batch with observations and actions taken. 
+ + :param batch: the batch containing the observations. + :param act_batch: the batch containing the actions taken. + :return: a tensor containing the target Q-values. + """ + # compute the target Q-value using the lagged critic network (target network) + return self.critic_old(batch.obs, act_batch.act) + + def _target_q(self, buffer: ReplayBuffer, indices: np.ndarray) -> torch.Tensor: + """ + Computes the target Q-value for the given buffer and indices. + + :param buffer: the replay buffer + :param indices: the indices within the buffer to compute the target Q-value for + """ + obs_next_batch = Batch( + obs=buffer[indices].obs_next, + info=[None] * len(indices), + ) # obs_next: s_{t+n} + act_batch = self._target_q_compute_action(obs_next_batch) + return self._target_q_compute_value(obs_next_batch, act_batch) + + def _update_lagged_network_weights(self) -> None: + """Updates the lagged network weights with the current weights using Polyak averaging.""" + self._polyak_parameter_update(self.critic_old, self.critic, self.tau) + + def train(self, mode: bool = True) -> Self: + """Sets the module to training mode, except for the lagged components.""" + # exclude `critic_old` from training + self.training = mode + self.policy.train(mode) + self.critic.train(mode) + return self + + +class DDPG( + ActorCriticOffPolicyAlgorithm[DDPGPolicy, TDDPGTrainingStats, ActBatchProtocol], + Generic[TDDPGTrainingStats], +): + """Implementation of Deep Deterministic Policy Gradient. arXiv:1509.02971.""" + + def __init__( + self, + *, + policy: DDPGPolicy, + policy_optim: torch.optim.Optimizer, + critic: torch.nn.Module | Critic, + critic_optim: torch.optim.Optimizer, + tau: float = 0.005, + gamma: float = 0.99, + exploration_noise: BaseNoise | Literal["default"] | None = "default", + estimation_step: int = 1, + lr_scheduler: TLearningRateScheduler | None = None, + ) -> None: + """ + :param policy: the policy + :param policy_optim: The optimizer for actor network. 
+ :param critic: The critic network. (s, a -> Q(s, a)) + :param critic_optim: The optimizer for critic network. + :param tau: Param for soft update of the target network. + :param gamma: Discount factor, in [0, 1]. + :param exploration_noise: The exploration noise, added to the action. Defaults + to ``GaussianNoise(sigma=0.1)``. + :param estimation_step: The number of steps to look ahead. + :param lr_scheduler: if not None, will be called in `policy.update()`. + """ + super().__init__( + policy=policy, + policy_optim=policy_optim, + lr_scheduler=lr_scheduler, + critic=critic, + critic_optim=critic_optim, + tau=tau, + gamma=gamma, + exploration_noise=exploration_noise, + estimation_step=estimation_step, + ) + self.actor_old = deepcopy(policy.actor) + self.actor_old.eval() + + def _target_q_compute_action(self, batch: Batch) -> ActBatchProtocol: + # compute the action using the lagged actor network + return self.policy(batch, model=self.actor_old) + + def _update_lagged_network_weights(self) -> None: + super()._update_lagged_network_weights() + self._polyak_parameter_update(self.actor_old, self.policy.actor, self.tau) def _update_with_batch(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> TDDPGTrainingStats: # type: ignore # critic - td, critic_loss = self._mse_optimizer(batch, self.critic, self.critic_optim) + td, critic_loss = self._minimize_critic_squared_loss(batch, self.critic, self.critic_optim) batch.weight = td # prio-buffer # actor actor_loss = -self.critic(batch.obs, self.policy(batch).act).mean() self.policy_optim.zero_grad() actor_loss.backward() self.policy_optim.step() - self.sync_weight() + self._update_lagged_network_weights() return DDPGTrainingStats(actor_loss=actor_loss.item(), critic_loss=critic_loss.item()) # type: ignore[return-value] - - _TArrOrActBatch = TypeVar("_TArrOrActBatch", bound="np.ndarray | ActBatchProtocol") - - def exploration_noise( - self, - act: _TArrOrActBatch, - batch: ObsBatchProtocol, - ) -> _TArrOrActBatch: 
- if self._exploration_noise is None: - return act - if isinstance(act, np.ndarray): - return act + self._exploration_noise(act.shape) - warnings.warn("Cannot add exploration noise to non-numpy_array action.") - return act diff --git a/tianshou/policy/modelfree/discrete_sac.py b/tianshou/policy/modelfree/discrete_sac.py index d4b38a1e5..823c476bf 100644 --- a/tianshou/policy/modelfree/discrete_sac.py +++ b/tianshou/policy/modelfree/discrete_sac.py @@ -171,7 +171,7 @@ def _update_with_batch(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: self.alpha_optim.step() self.alpha = self.log_alpha.detach().exp() - self.sync_weight() + self._update_lagged_network_weights() if self.is_auto_alpha: self.alpha = cast(torch.Tensor, self.alpha) diff --git a/tianshou/policy/modelfree/redq.py b/tianshou/policy/modelfree/redq.py index 317e4f42d..8319d3df3 100644 --- a/tianshou/policy/modelfree/redq.py +++ b/tianshou/policy/modelfree/redq.py @@ -141,7 +141,7 @@ def is_auto_alpha(self) -> bool: return self._is_auto_alpha # TODO: why override from the base class? 
- def sync_weight(self) -> None: + def _update_lagged_network_weights(self) -> None: for o, n in zip(self.critic_old.parameters(), self.critic.parameters(), strict=True): o.data.copy_(o.data * (1.0 - self.tau) + n.data * self.tau) @@ -224,7 +224,7 @@ def _update_with_batch(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: self.alpha_optim.step() self.alpha = self.log_alpha.detach().exp() - self.sync_weight() + self._update_lagged_network_weights() if self.critic_gradient_step % self.actor_delay == 0: self._last_actor_loss = actor_loss.item() diff --git a/tianshou/policy/modelfree/sac.py b/tianshou/policy/modelfree/sac.py index 6ee47dfab..ccb9da906 100644 --- a/tianshou/policy/modelfree/sac.py +++ b/tianshou/policy/modelfree/sac.py @@ -1,24 +1,22 @@ -from copy import deepcopy from dataclasses import dataclass -from typing import Any, Generic, Literal, Self, TypeVar, cast +from typing import Any, Generic, Literal, TypeVar, cast import gymnasium as gym import numpy as np import torch from torch.distributions import Independent, Normal -from tianshou.data import Batch, ReplayBuffer +from tianshou.data import Batch from tianshou.data.types import ( DistLogProbBatchProtocol, ObsBatchProtocol, RolloutBatchProtocol, ) from tianshou.exploration import BaseNoise -from tianshou.policy import DDPG from tianshou.policy.base import Policy, TLearningRateScheduler, TrainingStats +from tianshou.policy.modelfree.td3 import ActorDualCriticsOffPolicyAlgorithm from tianshou.utils.conversion import to_optional_float from tianshou.utils.net.continuous import ActorProb -from tianshou.utils.optim import clone_optimizer def correct_log_prob_gaussian_tanh( @@ -111,15 +109,11 @@ def forward( # type: ignore return cast(DistLogProbBatchProtocol, result) -# TODO: the type ignore here is needed b/c the hierarchy is actually broken! Should reconsider the inheritance structure. 
-class SAC(DDPG[TSACTrainingStats], Generic[TSACTrainingStats]): # type: ignore[type-var] - """Implementation of Soft Actor-Critic. arXiv:1812.05905. - - .. seealso:: - - Please refer to :class:`~tianshou.policy.BasePolicy` for more detailed - explanation. - """ +class SAC( + ActorDualCriticsOffPolicyAlgorithm[SACPolicy, TSACTrainingStats, DistLogProbBatchProtocol], + Generic[TSACTrainingStats], +): + """Implementation of Soft Actor-Critic. arXiv:1812.05905.""" def __init__( self, @@ -161,21 +155,18 @@ def __init__( in optimizer in each policy.update() """ super().__init__( - policy=policy, # TODO: violation of Liskov substitution principle + policy=policy, policy_optim=policy_optim, critic=critic, critic_optim=critic_optim, + critic2=critic2, + critic2_optim=critic2_optim, tau=tau, gamma=gamma, exploration_noise=exploration_noise, estimation_step=estimation_step, lr_scheduler=lr_scheduler, ) - critic2 = critic2 or deepcopy(critic) - critic2_optim = critic2_optim or clone_optimizer(critic_optim, critic2.parameters()) - self.critic2, self.critic2_old = critic2, deepcopy(critic2) - self.critic2_old.eval() - self.critic2_optim = critic2_optim self.deterministic_eval = deterministic_eval self.alpha: float | torch.Tensor @@ -213,36 +204,20 @@ def _check_field_validity(self) -> None: def is_auto_alpha(self) -> bool: return self._is_auto_alpha - def train(self, mode: bool = True) -> Self: - self.training = mode - self.policy.train(mode) - self.critic.train(mode) - self.critic2.train(mode) - return self - - def sync_weight(self) -> None: - self.soft_update(self.critic_old, self.critic, self.tau) - self.soft_update(self.critic2_old, self.critic2, self.tau) - - def _target_q(self, buffer: ReplayBuffer, indices: np.ndarray) -> torch.Tensor: - obs_next_batch = Batch( - obs=buffer[indices].obs_next, - info=[None] * len(indices), - ) # obs_next: s_{t+n} - obs_next_result = self.policy(obs_next_batch) - act_ = obs_next_result.act - return ( - torch.min( - 
self.critic_old(obs_next_batch.obs, act_), - self.critic2_old(obs_next_batch.obs, act_), - ) - - self.alpha * obs_next_result.log_prob - ) + def _target_q_compute_value( + self, batch: Batch, act_batch: DistLogProbBatchProtocol + ) -> torch.Tensor: + min_q_value = super()._target_q_compute_value(batch, act_batch) + return min_q_value - self.alpha * act_batch.log_prob def _update_with_batch(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> TSACTrainingStats: # type: ignore # critic 1&2 - td1, critic1_loss = self._mse_optimizer(batch, self.critic, self.critic_optim) - td2, critic2_loss = self._mse_optimizer(batch, self.critic2, self.critic2_optim) + td1, critic1_loss = self._minimize_critic_squared_loss( + batch, self.critic, self.critic_optim + ) + td2, critic2_loss = self._minimize_critic_squared_loss( + batch, self.critic2, self.critic2_optim + ) batch.weight = (td1 + td2) / 2.0 # prio-buffer # actor @@ -267,7 +242,7 @@ def _update_with_batch(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: self.alpha_optim.step() self.alpha = self.log_alpha.detach().exp() - self.sync_weight() + self._update_lagged_network_weights() return SACTrainingStats( # type: ignore[return-value] actor_loss=actor_loss.item(), diff --git a/tianshou/policy/modelfree/td3.py b/tianshou/policy/modelfree/td3.py index 12f60e851..114927a62 100644 --- a/tianshou/policy/modelfree/td3.py +++ b/tianshou/policy/modelfree/td3.py @@ -1,16 +1,27 @@ +from abc import ABC from copy import deepcopy from dataclasses import dataclass from typing import Any, Generic, Literal, Self, TypeVar -import numpy as np import torch -from tianshou.data import Batch, ReplayBuffer -from tianshou.data.types import RolloutBatchProtocol +from tianshou.data import Batch +from tianshou.data.types import ( + ActStateBatchProtocol, + RolloutBatchProtocol, +) from tianshou.exploration import BaseNoise -from tianshou.policy import DDPG -from tianshou.policy.base import TLearningRateScheduler, TrainingStats -from 
tianshou.policy.modelfree.ddpg import DDPGPolicy +from tianshou.policy.base import ( + TLearningRateScheduler, + TPolicy, + TrainingStats, + TTrainingStats, +) +from tianshou.policy.modelfree.ddpg import ( + ActorCriticOffPolicyAlgorithm, + DDPGPolicy, + TActBatchProtocol, +) from tianshou.utils.optim import clone_optimizer @@ -24,15 +35,88 @@ class TD3TrainingStats(TrainingStats): TTD3TrainingStats = TypeVar("TTD3TrainingStats", bound=TD3TrainingStats) -# TODO: the type ignore here is needed b/c the hierarchy is actually broken! Should reconsider the inheritance structure. -class TD3(DDPG[TTD3TrainingStats], Generic[TTD3TrainingStats]): # type: ignore[type-var] - """Implementation of TD3, arXiv:1802.09477. +class ActorDualCriticsOffPolicyAlgorithm( + ActorCriticOffPolicyAlgorithm[TPolicy, TTrainingStats, TActBatchProtocol], + Generic[TPolicy, TTrainingStats, TActBatchProtocol], + ABC, +): + """A base class for off-policy algorithms with two critics, where the target Q-value is computed as the minimum + of the two lagged critics' values. + """ - .. seealso:: + def __init__( + self, + *, + policy: Any, + policy_optim: torch.optim.Optimizer, + critic: torch.nn.Module, + critic_optim: torch.optim.Optimizer, + critic2: torch.nn.Module, + critic2_optim: torch.optim.Optimizer, + tau: float = 0.005, + gamma: float = 0.99, + exploration_noise: BaseNoise | Literal["default"] | None = None, + estimation_step: int = 1, + lr_scheduler: TLearningRateScheduler | None = None, + ) -> None: + """ + :param policy_optim: the optimizer for actor network. + :param critic: the first critic network. (s, a -> Q(s, a)) + :param critic_optim: the optimizer for the first critic network. + :param critic2: the second critic network. (s, a -> Q(s, a)). + If None, use the same network as critic (via deepcopy). + :param critic2_optim: the optimizer for the second critic network. + If None, clone critic_optim to use for critic2.parameters(). 
+ :param tau: param for soft update of the target network. + :param gamma: discount factor, in [0, 1]. + :param exploration_noise: add noise to action for exploration. + This is useful when solving "hard exploration" problems. + "default" is equivalent to GaussianNoise(sigma=0.1). + :param lr_scheduler: a learning rate scheduler that adjusts the learning rate + in optimizer in each policy.update() + """ + super().__init__( + policy=policy, + policy_optim=policy_optim, + lr_scheduler=lr_scheduler, + critic=critic, + critic_optim=critic_optim, + tau=tau, + gamma=gamma, + exploration_noise=exploration_noise, + estimation_step=estimation_step, + ) + if critic2 and not critic2_optim: + raise ValueError("critic2_optim must be provided if critic2 is provided") + critic2 = critic2 or deepcopy(critic) + critic2_optim = critic2_optim or clone_optimizer(critic_optim, critic2.parameters()) + self.critic2, self.critic2_old = critic2, deepcopy(critic2) + self.critic2_old.eval() + self.critic2_optim = critic2_optim + + def _target_q_compute_value(self, batch: Batch, act_batch: TActBatchProtocol) -> torch.Tensor: + # compute the Q-value as the minimum of the two lagged critics + act = act_batch.act + return torch.min( + self.critic_old(batch.obs, act), + self.critic2_old(batch.obs, act), + ) + + def train(self, mode: bool = True) -> Self: + super().train(mode=mode) + self.critic2.train(mode) + return self + + def _update_lagged_network_weights(self) -> None: + super()._update_lagged_network_weights() + self._polyak_parameter_update(self.critic2_old, self.critic2, self.tau) - Please refer to :class:`~tianshou.policy.BasePolicy` for more detailed - explanation. 
- """ + +class TD3( + ActorDualCriticsOffPolicyAlgorithm[DDPGPolicy, TTD3TrainingStats, ActStateBatchProtocol], + Generic[TTD3TrainingStats], +): + """Implementation of TD3, arXiv:1802.09477.""" def __init__( self, @@ -53,12 +137,10 @@ def __init__( lr_scheduler: TLearningRateScheduler | None = None, ) -> None: """ - :param actor: the actor network following the rules in - :class:`~tianshou.policy.BasePolicy`. (s -> actions) + :param policy: the policy :param policy_optim: the optimizer for actor network. :param critic: the first critic network. (s, a -> Q(s, a)) :param critic_optim: the optimizer for the first critic network. - :param action_space: Env's action space. Should be gym.spaces.Box. :param critic2: the second critic network. (s, a -> Q(s, a)). If None, use the same network as critic (via deepcopy). :param critic2_optim: the optimizer for the second critic network. @@ -71,72 +153,56 @@ def __init__( :param policy_noise: the noise used in updating policy network. :param update_actor_freq: the update frequency of actor network. :param noise_clip: the clipping range used in updating policy network. - :param observation_space: Env's observation space. - :param action_scaling: if True, scale the action from [-1, 1] to the range - of action_space. Only used if the action_space is continuous. - :param action_bound_method: method to bound action to range [-1, 1]. - Only used if the action_space is continuous. :param lr_scheduler: a learning rate scheduler that adjusts the learning rate in optimizer in each policy.update() """ - # TODO: reduce duplication with SAC. - # Some intermediate class, like TwoCriticPolicy? 
super().__init__( policy=policy, policy_optim=policy_optim, critic=critic, critic_optim=critic_optim, + critic2=critic2, + critic2_optim=critic2_optim, tau=tau, gamma=gamma, exploration_noise=exploration_noise, estimation_step=estimation_step, lr_scheduler=lr_scheduler, ) - if critic2 and not critic2_optim: - raise ValueError("critic2_optim must be provided if critic2 is provided") - critic2 = critic2 or deepcopy(critic) - critic2_optim = critic2_optim or clone_optimizer(critic_optim, critic2.parameters()) - self.critic2, self.critic2_old = critic2, deepcopy(critic2) - self.critic2_old.eval() - self.critic2_optim = critic2_optim - + self.actor_old = deepcopy(policy.actor) + self.actor_old.eval() self.policy_noise = policy_noise self.update_actor_freq = update_actor_freq self.noise_clip = noise_clip self._cnt = 0 self._last = 0 - def train(self, mode: bool = True) -> Self: - self.training = mode - self.policy.train(mode) - self.critic.train(mode) - self.critic2.train(mode) - return self + def _target_q_compute_action(self, batch: Batch) -> ActStateBatchProtocol: + # compute action using lagged actor + act_batch = self.policy(batch, model=self.actor_old) + act_ = act_batch.act - def sync_weight(self) -> None: - self.soft_update(self.critic_old, self.critic, self.tau) - self.soft_update(self.critic2_old, self.critic2, self.tau) - self.soft_update(self.actor_old, self.policy.actor, self.tau) - - def _target_q(self, buffer: ReplayBuffer, indices: np.ndarray) -> torch.Tensor: - obs_next_batch = Batch( - obs=buffer[indices].obs_next, - info=[None] * len(indices), - ) # obs_next: s_{t+n} - act_ = self.policy(obs_next_batch, model=self.actor_old).act + # add noise noise = torch.randn(size=act_.shape, device=act_.device) * self.policy_noise if self.noise_clip > 0.0: noise = noise.clamp(-self.noise_clip, self.noise_clip) act_ += noise - return torch.min( - self.critic_old(obs_next_batch.obs, act_), - self.critic2_old(obs_next_batch.obs, act_), - ) + + act_batch.act = act_ + 
return act_batch + + def _update_lagged_network_weights(self) -> None: + super()._update_lagged_network_weights() + self._polyak_parameter_update(self.actor_old, self.policy.actor, self.tau) def _update_with_batch(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> TTD3TrainingStats: # type: ignore # critic 1&2 - td1, critic1_loss = self._mse_optimizer(batch, self.critic, self.critic_optim) - td2, critic2_loss = self._mse_optimizer(batch, self.critic2, self.critic2_optim) + td1, critic1_loss = self._minimize_critic_squared_loss( + batch, self.critic, self.critic_optim + ) + td2, critic2_loss = self._minimize_critic_squared_loss( + batch, self.critic2, self.critic2_optim + ) batch.weight = (td1 + td2) / 2.0 # prio-buffer # actor @@ -146,7 +212,7 @@ def _update_with_batch(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: actor_loss.backward() self._last = actor_loss.item() self.policy_optim.step() - self.sync_weight() + self._update_lagged_network_weights() self._cnt += 1 return TD3TrainingStats( # type: ignore[return-value] From 228326da335a5a27e98336451de5d4b54c13753a Mon Sep 17 00:00:00 2001 From: Dominik Jain Date: Thu, 6 Mar 2025 23:19:50 +0100 Subject: [PATCH 026/230] v2: Adapt DiscreteSAC and test_discrete_sac --- examples/atari/atari_sac.py | 6 +- .../{test_sac.py => test_discrete_sac.py} | 53 +++-- tianshou/highlevel/algorithm.py | 8 +- tianshou/policy/__init__.py | 4 +- tianshou/policy/modelfree/ddpg.py | 39 +++- tianshou/policy/modelfree/discrete_sac.py | 209 ++++++++++-------- tianshou/policy/modelfree/sac.py | 4 +- tianshou/policy/modelfree/td3.py | 25 ++- 8 files changed, 204 insertions(+), 144 deletions(-) rename test/discrete/{test_sac.py => test_discrete_sac.py} (84%) diff --git a/examples/atari/atari_sac.py b/examples/atari/atari_sac.py index d9d427f34..08c2c6390 100644 --- a/examples/atari/atari_sac.py +++ b/examples/atari/atari_sac.py @@ -11,7 +11,7 @@ from tianshou.env.atari.atari_network import DQNet from 
tianshou.env.atari.atari_wrapper import make_atari_env from tianshou.highlevel.logger import LoggerFactoryDefault -from tianshou.policy import DiscreteSACPolicy, ICMPolicy +from tianshou.policy import DiscreteSAC, ICMPolicy from tianshou.policy.base import Algorithm from tianshou.trainer import OffpolicyTrainer from tianshou.utils.net.discrete import Actor, Critic, IntrinsicCuriosityModule @@ -124,8 +124,8 @@ def test_discrete_sac(args: argparse.Namespace = get_args()) -> None: alpha_optim = torch.optim.Adam([log_alpha], lr=args.alpha_lr) args.alpha = (target_entropy, log_alpha, alpha_optim) - policy: DiscreteSACPolicy | ICMPolicy - policy = DiscreteSACPolicy( + policy: DiscreteSAC | ICMPolicy + policy = DiscreteSAC( actor=actor, policy_optim=actor_optim, critic=critic1, diff --git a/test/discrete/test_sac.py b/test/discrete/test_discrete_sac.py similarity index 84% rename from test/discrete/test_sac.py rename to test/discrete/test_discrete_sac.py index 06a0e9fd9..8e1f49957 100644 --- a/test/discrete/test_sac.py +++ b/test/discrete/test_discrete_sac.py @@ -8,10 +8,13 @@ from tianshou.data import Collector, CollectStats, VectorReplayBuffer from tianshou.env import DummyVectorEnv -from tianshou.policy import DiscreteSACPolicy +from tianshou.policy import DiscreteSAC from tianshou.policy.base import Algorithm -from tianshou.policy.modelfree.discrete_sac import DiscreteSACTrainingStats -from tianshou.trainer import OffpolicyTrainer +from tianshou.policy.modelfree.discrete_sac import ( + DiscreteSACPolicy, + DiscreteSACTrainingStats, +) +from tianshou.trainer.base import OffPolicyTrainingConfig from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Net from tianshou.utils.net.discrete import Actor, Critic @@ -92,11 +95,14 @@ def test_discrete_sac(args: argparse.Namespace = get_args()) -> None: alpha_optim = torch.optim.Adam([log_alpha], lr=args.alpha_lr) args.alpha = (target_entropy, log_alpha, alpha_optim) - policy: 
DiscreteSACPolicy[DiscreteSACTrainingStats] = DiscreteSACPolicy( + policy = DiscreteSACPolicy( actor=actor, + action_space=env.action_space, + ) + algorithm: DiscreteSAC[DiscreteSACTrainingStats] = DiscreteSAC( + policy=policy, policy_optim=actor_optim, critic=critic1, - action_space=env.action_space, critic_optim=critic1_optim, critic2=critic2, critic2_optim=critic2_optim, @@ -107,11 +113,11 @@ def test_discrete_sac(args: argparse.Namespace = get_args()) -> None: ) # collector train_collector = Collector[CollectStats]( - policy, + algorithm, train_envs, VectorReplayBuffer(args.buffer_size, len(train_envs)), ) - test_collector = Collector[CollectStats](policy, test_envs) + test_collector = Collector[CollectStats](algorithm, test_envs) # train_collector.collect(n_step=args.buffer_size) # log log_path = os.path.join(args.logdir, args.task, "discrete_sac") @@ -124,20 +130,21 @@ def save_best_fn(policy: Algorithm) -> None: def stop_fn(mean_rewards: float) -> bool: return mean_rewards >= args.reward_threshold - # trainer - result = OffpolicyTrainer( - policy=policy, - train_collector=train_collector, - test_collector=test_collector, - max_epoch=args.epoch, - step_per_epoch=args.step_per_epoch, - step_per_collect=args.step_per_collect, - episode_per_test=args.test_num, - batch_size=args.batch_size, - stop_fn=stop_fn, - save_best_fn=save_best_fn, - logger=logger, - update_per_step=args.update_per_step, - test_in_train=False, - ).run() + # train + result = algorithm.run_training( + OffPolicyTrainingConfig( + train_collector=train_collector, + test_collector=test_collector, + max_epoch=args.epoch, + step_per_epoch=args.step_per_epoch, + step_per_collect=args.step_per_collect, + episode_per_test=args.test_num, + batch_size=args.batch_size, + stop_fn=stop_fn, + save_best_fn=save_best_fn, + logger=logger, + update_per_step=args.update_per_step, + test_in_train=False, + ) + ) assert stop_fn(result.best_reward) diff --git a/tianshou/highlevel/algorithm.py 
b/tianshou/highlevel/algorithm.py index e18f7771f..d00ba7b10 100644 --- a/tianshou/highlevel/algorithm.py +++ b/tianshou/highlevel/algorithm.py @@ -55,7 +55,7 @@ TRPO, Algorithm, DeepQLearning, - DiscreteSACPolicy, + DiscreteSAC, IQNPolicy, REDQPolicy, Reinforce, @@ -698,10 +698,10 @@ def _get_algorithm_class(self) -> type[SAC]: class DiscreteSACAlgorithmFactory( - ActorDualCriticsAlgorithmFactory[DiscreteSACParams, DiscreteSACPolicy, TPolicy] + ActorDualCriticsAlgorithmFactory[DiscreteSACParams, DiscreteSAC, TPolicy] ): - def _get_algorithm_class(self) -> type[DiscreteSACPolicy]: - return DiscreteSACPolicy + def _get_algorithm_class(self) -> type[DiscreteSAC]: + return DiscreteSAC class TD3AlgorithmFactory(ActorDualCriticsAlgorithmFactory[TD3Params, TD3, DDPGPolicy]): diff --git a/tianshou/policy/__init__.py b/tianshou/policy/__init__.py index 1a347eccf..4522cf3d8 100644 --- a/tianshou/policy/__init__.py +++ b/tianshou/policy/__init__.py @@ -20,7 +20,7 @@ from tianshou.policy.modelfree.td3 import TD3 from tianshou.policy.modelfree.sac import SAC from tianshou.policy.modelfree.redq import REDQPolicy -from tianshou.policy.modelfree.discrete_sac import DiscreteSACPolicy +from tianshou.policy.modelfree.discrete_sac import DiscreteSAC from tianshou.policy.imitation.base import ImitationPolicy from tianshou.policy.imitation.bcq import BCQPolicy from tianshou.policy.imitation.cql import CQLPolicy @@ -52,7 +52,7 @@ "TD3", "SAC", "REDQPolicy", - "DiscreteSACPolicy", + "DiscreteSAC", "ImitationPolicy", "BCQPolicy", "CQLPolicy", diff --git a/tianshou/policy/modelfree/ddpg.py b/tianshou/policy/modelfree/ddpg.py index d867d7214..22fae5ecb 100644 --- a/tianshou/policy/modelfree/ddpg.py +++ b/tianshou/policy/modelfree/ddpg.py @@ -119,6 +119,9 @@ class ActorCriticOffPolicyAlgorithm( Specializations can override the action computation (`_target_q_compute_action`) or the Q-value computation based on these actions (`_target_q_compute_value`) to customize the target Q-value 
computation. + The default implementation assumes a continuous action space where a critic receives a + state/observation and an action; for discrete action spaces, where the critic receives only + a state/observation, the method `_target_q_compute_value` must be overridden. """ def __init__( @@ -134,6 +137,24 @@ def __init__( estimation_step: int = 1, lr_scheduler: TLearningRateScheduler | None = None, ) -> None: + """ + :param policy: the policy + :param policy_optim: the optimizer for actor network. + :param critic: the critic network. + For continuous action spaces: (s, a -> Q(s, a)). + For discrete action spaces: (s -> ). + NOTE: The default implementation of `_target_q_compute_value` assumes + a continuous action space; override this method if using discrete actions. + :param critic_optim: the optimizer for the critic network. + :param tau: param for soft update of the target network. + :param gamma: discount factor, in [0, 1]. + :param exploration_noise: add noise to continuous actions for exploration; + set to None for discrete action spaces. + This is useful when solving "hard exploration" problems. + "default" is equivalent to GaussianNoise(sigma=0.1). + :param lr_scheduler: a learning rate scheduler that adjusts the learning rate + in optimizer in each policy.update() + """ assert 0.0 <= tau <= 1.0, f"tau should be in [0, 1] but got: {tau}" assert 0.0 <= gamma <= 1.0, f"gamma should be in [0, 1] but got: {gamma}" super().__init__( @@ -213,26 +234,28 @@ def process_fn( n_step=self.estimation_step, ) - def _target_q_compute_action(self, batch: Batch) -> TActBatchProtocol: + def _target_q_compute_action(self, obs_batch: Batch) -> TActBatchProtocol: """ Computes the action to be taken for the given batch (containing the observations) within the context of Q-value target computation. - :param batch: the batch containing the observations. + :param obs_batch: the batch containing the observations. :return: batch containing the actions to be taken. 
""" - return self.policy(batch) + return self.policy(obs_batch) - def _target_q_compute_value(self, batch: Batch, act_batch: TActBatchProtocol) -> torch.Tensor: + def _target_q_compute_value( + self, obs_batch: Batch, act_batch: TActBatchProtocol + ) -> torch.Tensor: """ Computes the target Q-value given a batch with observations and actions taken. - :param batch: the batch containing the observations. + :param obs_batch: the batch containing the observations. :param act_batch: the batch containing the actions taken. :return: a tensor containing the target Q-values. """ # compute the target Q-value using the lagged critic network (target network) - return self.critic_old(batch.obs, act_batch.act) + return self.critic_old(obs_batch.obs, act_batch.act) def _target_q(self, buffer: ReplayBuffer, indices: np.ndarray) -> torch.Tensor: """ @@ -306,9 +329,9 @@ def __init__( self.actor_old = deepcopy(policy.actor) self.actor_old.eval() - def _target_q_compute_action(self, batch: Batch) -> ActBatchProtocol: + def _target_q_compute_action(self, obs_batch: Batch) -> ActBatchProtocol: # compute the action using the lagged actor network - return self.policy(batch, model=self.actor_old) + return self.policy(obs_batch, model=self.actor_old) def _update_lagged_network_weights(self) -> None: super()._update_lagged_network_weights() diff --git a/tianshou/policy/modelfree/discrete_sac.py b/tianshou/policy/modelfree/discrete_sac.py index 823c476bf..8afb929ea 100644 --- a/tianshou/policy/modelfree/discrete_sac.py +++ b/tianshou/policy/modelfree/discrete_sac.py @@ -4,15 +4,18 @@ import gymnasium as gym import numpy as np import torch -from overrides import override from torch.distributions import Categorical -from tianshou.data import Batch, ReplayBuffer, to_torch -from tianshou.data.types import ActBatchProtocol, ObsBatchProtocol, RolloutBatchProtocol -from tianshou.policy import SAC -from tianshou.policy.base import TLearningRateScheduler +from tianshou.data import Batch, to_torch 
+from tianshou.data.types import ( + DistBatchProtocol, + ObsBatchProtocol, + RolloutBatchProtocol, +) +from tianshou.policy.base import Policy, TLearningRateScheduler from tianshou.policy.modelfree.sac import SACTrainingStats -from tianshou.utils.net.discrete import Actor, Critic +from tianshou.policy.modelfree.td3 import ActorDualCriticsOffPolicyAlgorithm +from tianshou.utils.net.discrete import Critic @dataclass @@ -23,107 +26,133 @@ class DiscreteSACTrainingStats(SACTrainingStats): TDiscreteSACTrainingStats = TypeVar("TDiscreteSACTrainingStats", bound=DiscreteSACTrainingStats) -class DiscreteSACPolicy(SAC[TDiscreteSACTrainingStats]): - """Implementation of SAC for Discrete Action Settings. arXiv:1910.07207. - - :param actor: the actor network following the rules (s_B -> dist_input_BD) - :param policy_optim: the optimizer for actor network. - :param critic: the first critic network. (s, a -> Q(s, a)) - :param critic_optim: the optimizer for the first critic network. - :param action_space: Env's action space. Should be gym.spaces.Box. - :param critic2: the second critic network. (s, a -> Q(s, a)). - If None, use the same network as critic (via deepcopy). - :param critic2_optim: the optimizer for the second critic network. - If None, clone critic_optim to use for critic2.parameters(). - :param tau: param for soft update of the target network. - :param gamma: discount factor, in [0, 1]. - :param alpha: entropy regularization coefficient. - If a tuple (target_entropy, log_alpha, alpha_optim) is provided, - then alpha is automatically tuned. - :param estimation_step: the number of steps to look ahead for calculating - :param observation_space: Env's observation space. - :param lr_scheduler: a learning rate scheduler that adjusts the learning rate - in optimizer in each policy.update() - - .. seealso:: - - Please refer to :class:`~tianshou.policy.BasePolicy` for more detailed - explanation. 
- """ +# TODO: This is a vanilla discrete actor policy; we may not need this "specific" class. +class DiscreteSACPolicy(Policy): + def __init__( + self, + *, + actor: torch.nn.Module, + deterministic_eval: bool = True, + action_space: gym.Space, + observation_space: gym.Space | None = None, + ): + """ + :param actor: the actor network following the rules (s -> dist_input_BD), + where the distribution input is for a `Categorical` distribution. + :param deterministic_eval: whether, in evaluation/inference mode, to use always + use the most probable action instead of sampling an action from the + categorical distribution. This setting does not affect data collection + for training, where actions are always sampled. + :param action_space: the action space of the environment + :param observation_space: the observation space of the environment + """ + assert isinstance(action_space, gym.spaces.Discrete) + super().__init__( + action_space=action_space, + observation_space=observation_space, + ) + self.actor = actor + self.deterministic_eval = deterministic_eval + + def forward( + self, + batch: ObsBatchProtocol, + state: dict | Batch | np.ndarray | None = None, + **kwargs: Any, + ) -> Batch: + logits_BA, hidden_BH = self.actor(batch.obs, state=state, info=batch.info) + dist = Categorical(logits=logits_BA) + act_B = ( + dist.mode + if self.deterministic_eval and not self.is_within_training_step + else dist.sample() + ) + return Batch(logits=logits_BA, act=act_B, state=hidden_BH, dist=dist) + + +class DiscreteSAC( + ActorDualCriticsOffPolicyAlgorithm[ + DiscreteSACPolicy, TDiscreteSACTrainingStats, DistBatchProtocol + ] +): + """Implementation of SAC for Discrete Action Settings. 
arXiv:1910.07207.""" def __init__( self, *, - actor: torch.nn.Module | Actor, + policy: DiscreteSACPolicy, policy_optim: torch.optim.Optimizer, critic: torch.nn.Module | Critic, critic_optim: torch.optim.Optimizer, - action_space: gym.spaces.Discrete, critic2: torch.nn.Module | Critic | None = None, critic2_optim: torch.optim.Optimizer | None = None, tau: float = 0.005, gamma: float = 0.99, alpha: float | tuple[float, torch.Tensor, torch.optim.Optimizer] = 0.2, estimation_step: int = 1, - observation_space: gym.Space | None = None, lr_scheduler: TLearningRateScheduler | None = None, ) -> None: + """ + :param policy: the policy + :param policy_optim: the optimizer for actor network. + :param critic: the first critic network. (s -> ). + :param critic_optim: the optimizer for the first critic network. + :param critic2: the second critic network. (s -> ). + If None, use the same network as critic (via deepcopy). + :param critic2_optim: the optimizer for the second critic network. + If None, clone critic_optim to use for critic2.parameters(). + :param tau: param for soft update of the target network. + :param gamma: discount factor, in [0, 1]. + :param alpha: entropy regularization coefficient. + If a tuple (target_entropy, log_alpha, alpha_optim) is provided, + then alpha is automatically tuned. 
+ :param estimation_step: the number of steps to look ahead for calculating + :param lr_scheduler: a learning rate scheduler that adjusts the learning rate + in optimizer in each policy.update() + """ super().__init__( - actor=actor, + policy=policy, policy_optim=policy_optim, critic=critic, critic_optim=critic_optim, - action_space=action_space, critic2=critic2, critic2_optim=critic2_optim, tau=tau, gamma=gamma, - alpha=alpha, estimation_step=estimation_step, - # Note: inheriting from continuous sac reduces code duplication, - # but continuous stuff has to be disabled exploration_noise=None, - action_scaling=False, - action_bound_method=None, - observation_space=observation_space, lr_scheduler=lr_scheduler, ) - # TODO: violates Liskov substitution principle, incompatible action space with SAC - # Not too urgent, but still.. - @override - def _check_field_validity(self) -> None: - if not isinstance(self.action_space, gym.spaces.Discrete): - raise ValueError( - f"DiscreteSACPolicy only supports gym.spaces.Discrete, but got {self.action_space=}." 
- f"Please use SACPolicy for continuous action spaces.", - ) - - def forward( # type: ignore - self, - batch: ObsBatchProtocol, - state: dict | Batch | np.ndarray | None = None, - **kwargs: Any, - ) -> Batch: - logits_BA, hidden_BH = self.actor(batch.obs, state=state, info=batch.info) - dist = Categorical(logits=logits_BA) - act_B = ( - dist.mode - if self.deterministic_eval and not self.is_within_training_step - else dist.sample() - ) - return Batch(logits=logits_BA, act=act_B, state=hidden_BH, dist=dist) - - def _target_q(self, buffer: ReplayBuffer, indices: np.ndarray) -> torch.Tensor: - obs_next_batch = Batch( - obs=buffer[indices].obs_next, - info=[None] * len(indices), - ) # obs_next: s_{t+n} - obs_next_result = self(obs_next_batch) - dist = obs_next_result.dist + self.alpha: float | torch.Tensor + self._is_auto_alpha = not isinstance(alpha, float) + if self._is_auto_alpha: + # TODO: why doesn't mypy understand that this must be a tuple? + alpha = cast(tuple[float, torch.Tensor, torch.optim.Optimizer], alpha) + if alpha[1].shape != torch.Size([1]): + raise ValueError( + f"Expected log_alpha to have shape torch.Size([1]), " + f"but got {alpha[1].shape} instead.", + ) + if not alpha[1].requires_grad: + raise ValueError("Expected log_alpha to require gradient, but it doesn't.") + + self.target_entropy, self.log_alpha, self.alpha_optim = alpha + self.alpha = self.log_alpha.detach().exp() + else: + alpha = cast( + float, + alpha, + ) # can we convert alpha to a constant tensor here? 
then mypy wouldn't complain + self.alpha = alpha + + def _target_q_compute_value( + self, obs_batch: Batch, act_batch: DistBatchProtocol + ) -> torch.Tensor: + dist = cast(Categorical, act_batch.dist) target_q = dist.probs * torch.min( - self.critic_old(obs_next_batch.obs), - self.critic2_old(obs_next_batch.obs), + self.critic_old(obs_batch.obs), + self.critic2_old(obs_batch.obs), ) return target_q.sum(dim=-1) + self.alpha * dist.entropy() @@ -152,28 +181,31 @@ def _update_with_batch(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: batch.weight = (td1 + td2) / 2.0 # prio-buffer # actor - dist = self(batch).dist + dist = self.policy(batch).dist entropy = dist.entropy() with torch.no_grad(): current_q1a = self.critic(batch.obs) current_q2a = self.critic2(batch.obs) q = torch.min(current_q1a, current_q2a) actor_loss = -(self.alpha * entropy + (dist.probs * q).sum(dim=-1)).mean() - self.actor_optim.zero_grad() + self.policy_optim.zero_grad() actor_loss.backward() - self.actor_optim.step() + self.policy_optim.step() - if self.is_auto_alpha: + if self._is_auto_alpha: log_prob = -entropy.detach() + self.target_entropy alpha_loss = -(self.log_alpha * log_prob).mean() self.alpha_optim.zero_grad() alpha_loss.backward() self.alpha_optim.step() self.alpha = self.log_alpha.detach().exp() + alpha_loss_value = alpha_loss.item() + else: + alpha_loss_value = None self._update_lagged_network_weights() - if self.is_auto_alpha: + if self._is_auto_alpha: self.alpha = cast(torch.Tensor, self.alpha) return DiscreteSACTrainingStats( # type: ignore[return-value] @@ -181,14 +213,5 @@ def _update_with_batch(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: critic1_loss=critic1_loss.item(), critic2_loss=critic2_loss.item(), alpha=self.alpha.item() if isinstance(self.alpha, torch.Tensor) else self.alpha, - alpha_loss=None if not self.is_auto_alpha else alpha_loss.item(), + alpha_loss=alpha_loss_value, ) - - _TArrOrActBatch = TypeVar("_TArrOrActBatch", bound="np.ndarray | 
ActBatchProtocol") - - def exploration_noise( - self, - act: _TArrOrActBatch, - batch: ObsBatchProtocol, - ) -> _TArrOrActBatch: - return act diff --git a/tianshou/policy/modelfree/sac.py b/tianshou/policy/modelfree/sac.py index ccb9da906..b21eea17a 100644 --- a/tianshou/policy/modelfree/sac.py +++ b/tianshou/policy/modelfree/sac.py @@ -205,9 +205,9 @@ def is_auto_alpha(self) -> bool: return self._is_auto_alpha def _target_q_compute_value( - self, batch: Batch, act_batch: DistLogProbBatchProtocol + self, obs_batch: Batch, act_batch: DistLogProbBatchProtocol ) -> torch.Tensor: - min_q_value = super()._target_q_compute_value(batch, act_batch) + min_q_value = super()._target_q_compute_value(obs_batch, act_batch) return min_q_value - self.alpha * act_batch.log_prob def _update_with_batch(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> TSACTrainingStats: # type: ignore diff --git a/tianshou/policy/modelfree/td3.py b/tianshou/policy/modelfree/td3.py index 114927a62..139c52840 100644 --- a/tianshou/policy/modelfree/td3.py +++ b/tianshou/policy/modelfree/td3.py @@ -60,16 +60,21 @@ def __init__( lr_scheduler: TLearningRateScheduler | None = None, ) -> None: """ + :param policy: the policy :param policy_optim: the optimizer for actor network. - :param critic: the first critic network. (s, a -> Q(s, a)) + :param critic: the first critic network. + For continuous action spaces: (s, a -> Q(s, a)). + NOTE: The default implementation of `_target_q_compute_value` assumes + a continuous action space; override this method if using discrete actions. :param critic_optim: the optimizer for the first critic network. - :param critic2: the second critic network. (s, a -> Q(s, a)). - If None, use the same network as critic (via deepcopy). + :param critic2: the second critic network (analogous functionality to the first). + If None, use the same network as the first critic (via deepcopy). :param critic2_optim: the optimizer for the second critic network. 
If None, clone critic_optim to use for critic2.parameters(). :param tau: param for soft update of the target network. :param gamma: discount factor, in [0, 1]. - :param exploration_noise: add noise to action for exploration. + :param exploration_noise: add noise to continuous actions for exploration; + set to None for discrete action spaces. This is useful when solving "hard exploration" problems. "default" is equivalent to GaussianNoise(sigma=0.1). :param lr_scheduler: a learning rate scheduler that adjusts the learning rate @@ -94,12 +99,14 @@ def __init__( self.critic2_old.eval() self.critic2_optim = critic2_optim - def _target_q_compute_value(self, batch: Batch, act_batch: TActBatchProtocol) -> torch.Tensor: + def _target_q_compute_value( + self, obs_batch: Batch, act_batch: TActBatchProtocol + ) -> torch.Tensor: # compute the Q-value as the minimum of the two lagged critics act = act_batch.act return torch.min( - self.critic_old(batch.obs, act), - self.critic2_old(batch.obs, act), + self.critic_old(obs_batch.obs, act), + self.critic2_old(obs_batch.obs, act), ) def train(self, mode: bool = True) -> Self: @@ -177,9 +184,9 @@ def __init__( self._cnt = 0 self._last = 0 - def _target_q_compute_action(self, batch: Batch) -> ActStateBatchProtocol: + def _target_q_compute_action(self, obs_batch: Batch) -> ActStateBatchProtocol: # compute action using lagged actor - act_batch = self.policy(batch, model=self.actor_old) + act_batch = self.policy(obs_batch, model=self.actor_old) act_ = act_batch.act # add noise From a22c89a0e9318fe7ae74f8cf117bd72741543aac Mon Sep 17 00:00:00 2001 From: Dominik Jain Date: Fri, 7 Mar 2025 00:50:45 +0100 Subject: [PATCH 027/230] v2: Refactor SAC and DiscreteSAC to use new abstractions for alpha handling This commit introduces a better abstraction for alpha parameter handling in SAC implementations through a dedicated class hierarchy: - Add abstract Alpha base class with value property and update method - Add FixedAlpha for constant entropy 
coefficients - Add AutoAlpha for automatic entropy tuning The refactoring simplifies the API by: - Replacing the complex tuple-based auto-alpha representation with proper classes - Providing a consistent interface for both fixed and auto-tuned parameters - Encapsulating alpha-related logic in dedicated classes - Improving code readability and maintainability Both implementations (continuous and discrete SAC) now share the same alpha abstraction, making the codebase more consistent while preserving the original functionality. --- CHANGELOG.md | 10 ++ tianshou/highlevel/params/alpha.py | 7 +- tianshou/policy/modelfree/discrete_sac.py | 55 ++-------- tianshou/policy/modelfree/sac.py | 126 +++++++++++++++------- 4 files changed, 110 insertions(+), 88 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index a333aeeba..2191ce62c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,6 +9,16 @@ * `DQNPolicy` -> `DeepQLearning` * `DDPGPolicy` -> `DDPG` * The `Algorithm` abstraction can directly initiate the learning process via method `run_training`. + * Internal design improvements: + * Introduced an abstraction for the alpha parameter (coefficient of the entropy term) + in `SAC`, `DiscreteSAC` and other algorithms. + * Class hierarchy: + * Abstract base class `Alpha` base class with value property and update method + * `FixedAlpha` for constant entropy coefficients + * `AutoAlpha` for automatic entropy tuning (replaces the old tuple-based representation) + * The (auto-)updating logic is now completely encapsulated, reducing the complexity of the algorithms. + * Implementations for continuous and discrete cases now share the same abstraction, + making the codebase more consistent while preserving the original functionality. * Fixed issues in the class hierarchy (e.g. 
violations of the Liskov substitution principle): * Introduced base classes (to retain factorization without abusive inheritance): * `ActorCriticOffPolicyAlgorithm` diff --git a/tianshou/highlevel/params/alpha.py b/tianshou/highlevel/params/alpha.py index 2b662eb44..9f276b472 100644 --- a/tianshou/highlevel/params/alpha.py +++ b/tianshou/highlevel/params/alpha.py @@ -7,6 +7,7 @@ from tianshou.highlevel.env import Environments from tianshou.highlevel.module.core import TDevice from tianshou.highlevel.optim import OptimizerFactory +from tianshou.policy.modelfree.sac import AutoAlpha class AutoAlphaFactory(ToStringMixin, ABC): @@ -16,7 +17,7 @@ def create_auto_alpha( envs: Environments, optim_factory: OptimizerFactory, device: TDevice, - ) -> tuple[float, torch.Tensor, torch.optim.Optimizer]: + ) -> AutoAlpha: pass @@ -29,8 +30,8 @@ def create_auto_alpha( envs: Environments, optim_factory: OptimizerFactory, device: TDevice, - ) -> tuple[float, torch.Tensor, torch.optim.Optimizer]: + ) -> AutoAlpha: target_entropy = float(-np.prod(envs.get_action_shape())) log_alpha = torch.zeros(1, requires_grad=True, device=device) alpha_optim = optim_factory.create_optimizer_for_params([log_alpha], self.lr) - return target_entropy, log_alpha, alpha_optim + return AutoAlpha(target_entropy, log_alpha, alpha_optim) diff --git a/tianshou/policy/modelfree/discrete_sac.py b/tianshou/policy/modelfree/discrete_sac.py index 8afb929ea..104b89390 100644 --- a/tianshou/policy/modelfree/discrete_sac.py +++ b/tianshou/policy/modelfree/discrete_sac.py @@ -13,7 +13,7 @@ RolloutBatchProtocol, ) from tianshou.policy.base import Policy, TLearningRateScheduler -from tianshou.policy.modelfree.sac import SACTrainingStats +from tianshou.policy.modelfree.sac import Alpha, FixedAlpha, SACTrainingStats from tianshou.policy.modelfree.td3 import ActorDualCriticsOffPolicyAlgorithm from tianshou.utils.net.discrete import Critic @@ -88,7 +88,7 @@ def __init__( critic2_optim: torch.optim.Optimizer | None = None, 
tau: float = 0.005, gamma: float = 0.99, - alpha: float | tuple[float, torch.Tensor, torch.optim.Optimizer] = 0.2, + alpha: float | Alpha = 0.2, estimation_step: int = 1, lr_scheduler: TLearningRateScheduler | None = None, ) -> None: @@ -103,9 +103,8 @@ def __init__( If None, clone critic_optim to use for critic2.parameters(). :param tau: param for soft update of the target network. :param gamma: discount factor, in [0, 1]. - :param alpha: entropy regularization coefficient. - If a tuple (target_entropy, log_alpha, alpha_optim) is provided, - then alpha is automatically tuned. + :param alpha: entropy regularization coefficient or an object + which can be used to automatically tune alpha (e.g. an instance of `AutoAlpha`). :param estimation_step: the number of steps to look ahead for calculating :param lr_scheduler: a learning rate scheduler that adjusts the learning rate in optimizer in each policy.update() @@ -123,28 +122,8 @@ def __init__( exploration_noise=None, lr_scheduler=lr_scheduler, ) - - self.alpha: float | torch.Tensor - self._is_auto_alpha = not isinstance(alpha, float) - if self._is_auto_alpha: - # TODO: why doesn't mypy understand that this must be a tuple? - alpha = cast(tuple[float, torch.Tensor, torch.optim.Optimizer], alpha) - if alpha[1].shape != torch.Size([1]): - raise ValueError( - f"Expected log_alpha to have shape torch.Size([1]), " - f"but got {alpha[1].shape} instead.", - ) - if not alpha[1].requires_grad: - raise ValueError("Expected log_alpha to require gradient, but it doesn't.") - - self.target_entropy, self.log_alpha, self.alpha_optim = alpha - self.alpha = self.log_alpha.detach().exp() - else: - alpha = cast( - float, - alpha, - ) # can we convert alpha to a constant tensor here? 
then mypy wouldn't complain - self.alpha = alpha + self.alpha = FixedAlpha(alpha) if isinstance(alpha, float) else alpha + assert isinstance(self.alpha, Alpha) def _target_q_compute_value( self, obs_batch: Batch, act_batch: DistBatchProtocol @@ -154,7 +133,7 @@ def _target_q_compute_value( self.critic_old(obs_batch.obs), self.critic2_old(obs_batch.obs), ) - return target_q.sum(dim=-1) + self.alpha * dist.entropy() + return target_q.sum(dim=-1) + self.alpha.value * dist.entropy() def _update_with_batch(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> TDiscreteSACTrainingStats: # type: ignore weight = batch.pop("weight", 1.0) @@ -187,31 +166,19 @@ def _update_with_batch(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: current_q1a = self.critic(batch.obs) current_q2a = self.critic2(batch.obs) q = torch.min(current_q1a, current_q2a) - actor_loss = -(self.alpha * entropy + (dist.probs * q).sum(dim=-1)).mean() + actor_loss = -(self.alpha.value * entropy + (dist.probs * q).sum(dim=-1)).mean() self.policy_optim.zero_grad() actor_loss.backward() self.policy_optim.step() - if self._is_auto_alpha: - log_prob = -entropy.detach() + self.target_entropy - alpha_loss = -(self.log_alpha * log_prob).mean() - self.alpha_optim.zero_grad() - alpha_loss.backward() - self.alpha_optim.step() - self.alpha = self.log_alpha.detach().exp() - alpha_loss_value = alpha_loss.item() - else: - alpha_loss_value = None + alpha_loss = self.alpha.update(entropy.detach()) self._update_lagged_network_weights() - if self._is_auto_alpha: - self.alpha = cast(torch.Tensor, self.alpha) - return DiscreteSACTrainingStats( # type: ignore[return-value] actor_loss=actor_loss.item(), critic1_loss=critic1_loss.item(), critic2_loss=critic2_loss.item(), - alpha=self.alpha.item() if isinstance(self.alpha, torch.Tensor) else self.alpha, - alpha_loss=alpha_loss_value, + alpha=self.alpha.value, + alpha_loss=alpha_loss, ) diff --git a/tianshou/policy/modelfree/sac.py 
b/tianshou/policy/modelfree/sac.py index b21eea17a..58e31fbfc 100644 --- a/tianshou/policy/modelfree/sac.py +++ b/tianshou/policy/modelfree/sac.py @@ -1,3 +1,4 @@ +from abc import ABC, abstractmethod from dataclasses import dataclass from typing import Any, Generic, Literal, TypeVar, cast @@ -109,6 +110,81 @@ def forward( # type: ignore return cast(DistLogProbBatchProtocol, result) +class Alpha(ABC): + """Defines the interface for the entropy regularization coefficient alpha.""" + + @property + @abstractmethod + def value(self) -> float: + """Retrieves the current value of alpha.""" + + @abstractmethod + def update(self, entropy: torch.Tensor) -> float | None: + """ + Updates the alpha value based on the entropy. + + :param entropy: the entropy of the policy. + :return: the loss value if alpha is auto-tuned, otherwise None. + """ + return None + + +class FixedAlpha(Alpha): + """Represents a fixed entropy regularization coefficient alpha.""" + + def __init__(self, alpha: float): + self._value = alpha + + @property + def value(self) -> float: + return self._value + + def update(self, entropy: torch.Tensor) -> float | None: + return None + + +class AutoAlpha(torch.nn.Module, Alpha): + """Represents an entropy regularization coefficient alpha that is automatically tuned.""" + + def __init__( + self, target_entropy: float, log_alpha: torch.Tensor, optim: torch.optim.Optimizer + ): + """ + :param target_entropy: the target entropy value. + For discrete action spaces, it is usually -log(|A|) for a balance between stochasticity + and determinism or -log(1/|A|)=log(|A|) for maximum stochasticity or, more generally, + lambda*log(|A|), e.g. with lambda close to 1 (e.g. 0.98) for pronounced stochasticity. + For continuous action spaces, it is usually -dim(A) for a balance between stochasticity + and determinism, with similar generalizations as for discrete action spaces. + :param log_alpha: the (initial) log of the entropy regularization coefficient alpha. 
+ This must be a scalar tensor with requires_grad=True. + :param optim: the optimizer for `log_alpha`. + """ + super().__init__() + if not log_alpha.requires_grad: + raise ValueError("Expected log_alpha to require gradient, but it doesn't.") + if log_alpha.shape != torch.Size([1]): + raise ValueError( + f"Expected log_alpha to have shape torch.Size([1]), " + f"but got {log_alpha.shape} instead.", + ) + self._target_entropy = target_entropy + self._log_alpha = log_alpha + self._optim = optim + + @property + def value(self) -> float: + return self._log_alpha.detach().exp().item() + + def update(self, entropy: torch.Tensor) -> float: + entropy_deficit = self._target_entropy - entropy + alpha_loss = -(self._log_alpha * entropy_deficit).mean() + self._optim.zero_grad() + alpha_loss.backward() + self._optim.step() + return alpha_loss.item() + + class SAC( ActorDualCriticsOffPolicyAlgorithm[SACPolicy, TSACTrainingStats, DistLogProbBatchProtocol], Generic[TSACTrainingStats], @@ -126,7 +202,7 @@ def __init__( critic2_optim: torch.optim.Optimizer | None = None, tau: float = 0.005, gamma: float = 0.99, - alpha: float | tuple[float, torch.Tensor, torch.optim.Optimizer] = 0.2, + alpha: float | Alpha = 0.2, estimation_step: int = 1, exploration_noise: BaseNoise | Literal["default"] | None = None, deterministic_eval: bool = True, @@ -137,7 +213,6 @@ def __init__( :param policy_optim: the optimizer for actor network. :param critic: the first critic network. (s, a -> Q(s, a)) :param critic_optim: the optimizer for the first critic network. - :param action_space: Env's action space. Should be gym.spaces.Box. :param critic2: the second critic network. (s, a -> Q(s, a)). If None, use the same network as critic (via deepcopy). :param critic2_optim: the optimizer for the second critic network. 
@@ -168,29 +243,8 @@ def __init__( lr_scheduler=lr_scheduler, ) self.deterministic_eval = deterministic_eval - - self.alpha: float | torch.Tensor - self._is_auto_alpha = not isinstance(alpha, float) - if self._is_auto_alpha: - # TODO: why doesn't mypy understand that this must be a tuple? - alpha = cast(tuple[float, torch.Tensor, torch.optim.Optimizer], alpha) - if alpha[1].shape != torch.Size([1]): - raise ValueError( - f"Expected log_alpha to have shape torch.Size([1]), " - f"but got {alpha[1].shape} instead.", - ) - if not alpha[1].requires_grad: - raise ValueError("Expected log_alpha to require gradient, but it doesn't.") - - self.target_entropy, self.log_alpha, self.alpha_optim = alpha - self.alpha = self.log_alpha.detach().exp() - else: - alpha = cast( - float, - alpha, - ) # can we convert alpha to a constant tensor here? then mypy wouldn't complain - self.alpha = alpha - + self.alpha = FixedAlpha(alpha) if isinstance(alpha, float) else alpha + assert isinstance(self.alpha, Alpha) self._check_field_validity() def _check_field_validity(self) -> None: @@ -200,15 +254,11 @@ def _check_field_validity(self) -> None: f"Please use DiscreteSACPolicy for discrete action spaces.", ) - @property - def is_auto_alpha(self) -> bool: - return self._is_auto_alpha - def _target_q_compute_value( self, obs_batch: Batch, act_batch: DistLogProbBatchProtocol ) -> torch.Tensor: min_q_value = super()._target_q_compute_value(obs_batch, act_batch) - return min_q_value - self.alpha * act_batch.log_prob + return min_q_value - self.alpha.value * act_batch.log_prob def _update_with_batch(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> TSACTrainingStats: # type: ignore # critic 1&2 @@ -226,21 +276,15 @@ def _update_with_batch(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: current_q1a = self.critic(batch.obs, act).flatten() current_q2a = self.critic2(batch.obs, act).flatten() actor_loss = ( - self.alpha * obs_result.log_prob.flatten() - torch.min(current_q1a, 
current_q2a) + self.alpha.value * obs_result.log_prob.flatten() - torch.min(current_q1a, current_q2a) ).mean() self.policy_optim.zero_grad() actor_loss.backward() self.policy_optim.step() - alpha_loss = None - if self.is_auto_alpha: - log_prob = obs_result.log_prob.detach() + self.target_entropy - # please take a look at issue #258 if you'd like to change this line - alpha_loss = -(self.log_alpha * log_prob).mean() - self.alpha_optim.zero_grad() - alpha_loss.backward() - self.alpha_optim.step() - self.alpha = self.log_alpha.detach().exp() + # The entropy of a Gaussian policy can be expressed as -log_prob + a constant (which we ignore) + entropy = -obs_result.log_prob.detach() + alpha_loss = self.alpha.update(entropy) self._update_lagged_network_weights() @@ -248,6 +292,6 @@ def _update_with_batch(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: actor_loss=actor_loss.item(), critic1_loss=critic1_loss.item(), critic2_loss=critic2_loss.item(), - alpha=to_optional_float(self.alpha), + alpha=to_optional_float(self.alpha.value), alpha_loss=to_optional_float(alpha_loss), ) From ad747761a2f743bc2cb95dc5644d7ff74b41e796 Mon Sep 17 00:00:00 2001 From: Dominik Jain Date: Fri, 7 Mar 2025 01:52:32 +0100 Subject: [PATCH 028/230] AutoAlphaFactoryDefault: Differentiate discrete and continuous action spaces and allow coefficient to be modified, adding an informative docstring (previous implementation was reasonable only for continuous action spaces) Adjust parametrisation to match procedural example in atari_sac_hl --- examples/atari/atari_sac_hl.py | 5 +++-- tianshou/highlevel/params/alpha.py | 18 ++++++++++++++++-- 2 files changed, 19 insertions(+), 4 deletions(-) diff --git a/examples/atari/atari_sac_hl.py b/examples/atari/atari_sac_hl.py index 23186606a..049928856 100644 --- a/examples/atari/atari_sac_hl.py +++ b/examples/atari/atari_sac_hl.py @@ -45,7 +45,6 @@ def main( training_num: int = 10, test_num: int = 10, frames_stack: int = 4, - save_buffer_name: str | None = 
None, # TODO add support in high-level API? icm_lr_scale: float = 0.0, icm_reward_scale: float = 0.01, icm_forward_loss_weight: float = 0.2, @@ -84,7 +83,9 @@ def main( critic2_lr=critic_lr, gamma=gamma, tau=tau, - alpha=AutoAlphaFactoryDefault(lr=alpha_lr) if auto_alpha else alpha, + alpha=AutoAlphaFactoryDefault(lr=alpha_lr, target_entropy_coefficient=0.98) + if auto_alpha + else alpha, estimation_step=n_step, ), ) diff --git a/tianshou/highlevel/params/alpha.py b/tianshou/highlevel/params/alpha.py index 9f276b472..5f42d10e1 100644 --- a/tianshou/highlevel/params/alpha.py +++ b/tianshou/highlevel/params/alpha.py @@ -22,8 +22,18 @@ def create_auto_alpha( class AutoAlphaFactoryDefault(AutoAlphaFactory): - def __init__(self, lr: float = 3e-4): + def __init__(self, lr: float = 3e-4, target_entropy_coefficient: float = -1.0): + """ + :param lr: the learning rate for the optimizer of the alpha parameter + :param target_entropy_coefficient: the coefficient with which to multiply the target entropy; + The base value being scaled is dim(A) for continuous action spaces and log(|A|) for discrete action spaces, + i.e. with the default coefficient -1, we obtain -dim(A) and -log(dim(A)) for continuous and discrete action + spaces respectively, which gives a reasonable trade-off between exploration and exploitation. + For decidedly stochastic exploration, you can use a positive value closer to 1 (e.g. 0.98); + 1.0 would give full entropy exploration. 
+ """ self.lr = lr + self.target_entropy_coefficient = target_entropy_coefficient def create_auto_alpha( self, @@ -31,7 +41,11 @@ def create_auto_alpha( optim_factory: OptimizerFactory, device: TDevice, ) -> AutoAlpha: - target_entropy = float(-np.prod(envs.get_action_shape())) + action_dim = np.prod(envs.get_action_shape()) + if envs.get_type().is_continuous(): + target_entropy = self.target_entropy_coefficient * float(action_dim) + else: + target_entropy = self.target_entropy_coefficient * np.log(action_dim) log_alpha = torch.zeros(1, requires_grad=True, device=device) alpha_optim = optim_factory.create_optimizer_for_params([log_alpha], self.lr) return AutoAlpha(target_entropy, log_alpha, alpha_optim) From dc4dce2fd96815a70ceb72ecd3adace0a6cc38c6 Mon Sep 17 00:00:00 2001 From: Dominik Jain Date: Fri, 7 Mar 2025 03:27:02 +0100 Subject: [PATCH 029/230] v2: Adapt REDQ (and test_redq), fixing problematic inheritance from DDPG (inherit from ActorCriticOffPolicyAlgorithm instead) --- CHANGELOG.md | 3 +- examples/mujoco/mujoco_redq.py | 8 +- test/continuous/test_redq.py | 49 +++-- tianshou/highlevel/algorithm.py | 8 +- tianshou/policy/__init__.py | 4 +- tianshou/policy/modelfree/discrete_sac.py | 4 +- tianshou/policy/modelfree/redq.py | 252 +++++++++++----------- tianshou/policy/modelfree/sac.py | 5 +- 8 files changed, 163 insertions(+), 170 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 2191ce62c..09fe609f9 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -23,7 +23,8 @@ * Introduced base classes (to retain factorization without abusive inheritance): * `ActorCriticOffPolicyAlgorithm` * `ActorDualCriticsOffPolicyAlgorithm` (extends `ActorCriticOffPolicyAlgorithm`) - * `NPG` no longer inherits from `A2C` but from a new abstract base class + * `NPG`: Inherit from `AbstractActorCriticWithAdvantage` instead of `A2C` (which is now has the same base class) + * `REDQ`: Inherit from `ActorCriticOffPolicyAlgorithm` instead of `DDPG` * `SAC`: Inherit from 
`ActorDualCriticsOffPolicyAlgorithm` instead of `DDPG` * `TD3`: Inherit from `ActorDualCriticsOffPolicyAlgorithm` instead of `DDPG` diff --git a/examples/mujoco/mujoco_redq.py b/examples/mujoco/mujoco_redq.py index 444b550fc..48625ca2e 100755 --- a/examples/mujoco/mujoco_redq.py +++ b/examples/mujoco/mujoco_redq.py @@ -11,7 +11,7 @@ from tianshou.data import Collector, CollectStats, ReplayBuffer, VectorReplayBuffer from tianshou.highlevel.logger import LoggerFactoryDefault -from tianshou.policy import REDQPolicy +from tianshou.policy import REDQ from tianshou.policy.base import Algorithm from tianshou.trainer import OffpolicyTrainer from tianshou.utils.net.common import EnsembleLinear, Net @@ -121,9 +121,9 @@ def linear(x: int, y: int) -> EnsembleLinear: alpha_optim = torch.optim.Adam([log_alpha], lr=args.alpha_lr) args.alpha = (target_entropy, log_alpha, alpha_optim) - policy: REDQPolicy = REDQPolicy( - actor=actor, - actor_optim=actor_optim, + policy: REDQ = REDQ( + policy=actor, + policy_optim=actor_optim, critic=critics, critic_optim=critics_optim, ensemble_size=args.ensemble_size, diff --git a/test/continuous/test_redq.py b/test/continuous/test_redq.py index fa947a3fc..739bdba42 100644 --- a/test/continuous/test_redq.py +++ b/test/continuous/test_redq.py @@ -9,9 +9,10 @@ from tianshou.data import Collector, CollectStats, VectorReplayBuffer from tianshou.env import DummyVectorEnv -from tianshou.policy import REDQPolicy +from tianshou.policy import REDQ from tianshou.policy.base import Algorithm -from tianshou.trainer import OffpolicyTrainer +from tianshou.policy.modelfree.redq import REDQPolicy +from tianshou.trainer.base import OffPolicyTrainingConfig from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import EnsembleLinear, Net from tianshou.utils.net.continuous import ActorProb, Critic @@ -111,9 +112,13 @@ def linear(x: int, y: int) -> nn.Module: alpha_optim = torch.optim.Adam([log_alpha], lr=args.alpha_lr) args.alpha = 
(target_entropy, log_alpha, alpha_optim) - policy: REDQPolicy = REDQPolicy( + policy = REDQPolicy( actor=actor, - actor_optim=actor_optim, + action_space=env.action_space, + ) + algorithm: REDQ = REDQ( + policy=policy, + policy_optim=actor_optim, critic=critic, critic_optim=critic_optim, ensemble_size=args.ensemble_size, @@ -124,16 +129,15 @@ def linear(x: int, y: int) -> nn.Module: estimation_step=args.n_step, actor_delay=args.update_per_step, target_mode=args.target_mode, - action_space=env.action_space, ) # collector train_collector = Collector[CollectStats]( - policy, + algorithm, train_envs, VectorReplayBuffer(args.buffer_size, len(train_envs)), exploration_noise=True, ) - test_collector = Collector[CollectStats](policy, test_envs) + test_collector = Collector[CollectStats](algorithm, test_envs) train_collector.reset() train_collector.collect(n_step=args.start_timesteps, random=True) # log @@ -147,19 +151,20 @@ def save_best_fn(policy: Algorithm) -> None: def stop_fn(mean_rewards: float) -> bool: return mean_rewards >= args.reward_threshold - # trainer - result = OffpolicyTrainer( - policy=policy, - train_collector=train_collector, - test_collector=test_collector, - max_epoch=args.epoch, - step_per_epoch=args.step_per_epoch, - step_per_collect=args.step_per_collect, - episode_per_test=args.test_num, - batch_size=args.batch_size, - update_per_step=args.update_per_step, - stop_fn=stop_fn, - save_best_fn=save_best_fn, - logger=logger, - ).run() + # train + result = algorithm.run_training( + OffPolicyTrainingConfig( + train_collector=train_collector, + test_collector=test_collector, + max_epoch=args.epoch, + step_per_epoch=args.step_per_epoch, + step_per_collect=args.step_per_collect, + episode_per_test=args.test_num, + batch_size=args.batch_size, + update_per_step=args.update_per_step, + stop_fn=stop_fn, + save_best_fn=save_best_fn, + logger=logger, + ) + ) assert stop_fn(result.best_reward) diff --git a/tianshou/highlevel/algorithm.py 
b/tianshou/highlevel/algorithm.py index d00ba7b10..0667f6e24 100644 --- a/tianshou/highlevel/algorithm.py +++ b/tianshou/highlevel/algorithm.py @@ -50,6 +50,7 @@ DDPG, NPG, PPO, + REDQ, SAC, TD3, TRPO, @@ -57,7 +58,6 @@ DeepQLearning, DiscreteSAC, IQNPolicy, - REDQPolicy, Reinforce, ) from tianshou.policy.base import ( @@ -594,9 +594,9 @@ def _create_algorithm(self, envs: Environments, device: TDevice) -> Algorithm: ), ) action_space = cast(gymnasium.spaces.Box, envs.get_action_space()) - return REDQPolicy( - actor=actor.module, - actor_optim=actor.optim, + return REDQ( + policy=actor.module, + policy_optim=actor.optim, critic=critic_ensemble.module, critic_optim=critic_ensemble.optim, action_space=action_space, diff --git a/tianshou/policy/__init__.py b/tianshou/policy/__init__.py index 4522cf3d8..561a76a3b 100644 --- a/tianshou/policy/__init__.py +++ b/tianshou/policy/__init__.py @@ -19,7 +19,7 @@ from tianshou.policy.modelfree.trpo import TRPO from tianshou.policy.modelfree.td3 import TD3 from tianshou.policy.modelfree.sac import SAC -from tianshou.policy.modelfree.redq import REDQPolicy +from tianshou.policy.modelfree.redq import REDQ from tianshou.policy.modelfree.discrete_sac import DiscreteSAC from tianshou.policy.imitation.base import ImitationPolicy from tianshou.policy.imitation.bcq import BCQPolicy @@ -51,7 +51,7 @@ "TRPO", "TD3", "SAC", - "REDQPolicy", + "REDQ", "DiscreteSAC", "ImitationPolicy", "BCQPolicy", diff --git a/tianshou/policy/modelfree/discrete_sac.py b/tianshou/policy/modelfree/discrete_sac.py index 104b89390..a3a879e83 100644 --- a/tianshou/policy/modelfree/discrete_sac.py +++ b/tianshou/policy/modelfree/discrete_sac.py @@ -103,8 +103,8 @@ def __init__( If None, clone critic_optim to use for critic2.parameters(). :param tau: param for soft update of the target network. :param gamma: discount factor, in [0, 1]. - :param alpha: entropy regularization coefficient or an object - which can be used to automatically tune alpha (e.g. 
an instance of `AutoAlpha`). + :param alpha: the entropy regularization coefficient alpha or an object + which can be used to automatically tune it (e.g. an instance of `AutoAlpha`). :param estimation_step: the number of steps to look ahead for calculating :param lr_scheduler: a learning rate scheduler that adjusts the learning rate in optimizer in each policy.update() diff --git a/tianshou/policy/modelfree/redq.py b/tianshou/policy/modelfree/redq.py index 8319d3df3..1638242d8 100644 --- a/tianshou/policy/modelfree/redq.py +++ b/tianshou/policy/modelfree/redq.py @@ -6,12 +6,19 @@ import torch from torch.distributions import Independent, Normal -from tianshou.data import Batch, ReplayBuffer -from tianshou.data.types import ObsBatchProtocol, RolloutBatchProtocol +from tianshou.data import Batch +from tianshou.data.types import ( + DistLogProbBatchProtocol, + ObsBatchProtocol, + RolloutBatchProtocol, +) from tianshou.exploration import BaseNoise -from tianshou.policy import DDPG -from tianshou.policy.base import TLearningRateScheduler -from tianshou.policy.modelfree.ddpg import DDPGTrainingStats +from tianshou.policy.base import Policy, TLearningRateScheduler +from tianshou.policy.modelfree.ddpg import ( + ActorCriticOffPolicyAlgorithm, + DDPGTrainingStats, +) +from tianshou.policy.modelfree.sac import Alpha, FixedAlpha from tianshou.utils.net.continuous import ActorProb @@ -26,62 +33,111 @@ class REDQTrainingStats(DDPGTrainingStats): TREDQTrainingStats = TypeVar("TREDQTrainingStats", bound=REDQTrainingStats) -class REDQPolicy(DDPG[TREDQTrainingStats]): - """Implementation of REDQ. arXiv:2101.05982. - - :param actor: The actor network following the rules in - :class:`~tianshou.policy.BasePolicy`. (s -> model_output) - :param actor_optim: The optimizer for actor network. - :param critic: The critic network. (s, a -> Q(s, a)) - :param critic_optim: The optimizer for critic network. - :param action_space: Env's action space. 
- :param ensemble_size: Number of sub-networks in the critic ensemble. - :param subset_size: Number of networks in the subset. - :param tau: Param for soft update of the target network. - :param gamma: Discount factor, in [0, 1]. - :param alpha: entropy regularization coefficient. - If a tuple (target_entropy, log_alpha, alpha_optim) is provided, then - alpha is automatically tuned. - :param exploration_noise: The exploration noise, added to the action. Defaults - to ``GaussianNoise(sigma=0.1)``. - :param estimation_step: The number of steps to look ahead. - :param actor_delay: Number of critic updates before an actor update. - :param observation_space: Env's observation space. - :param action_scaling: if True, scale the action from [-1, 1] to the range - of action_space. Only used if the action_space is continuous. - :param action_bound_method: method to bound action to range [-1, 1]. - Only used if the action_space is continuous. - :param lr_scheduler: if not None, will be called in `policy.update()`. - - .. seealso:: - - Please refer to :class:`~tianshou.policy.BasePolicy` for more detailed - explanation. - """ - +class REDQPolicy(Policy): def __init__( self, *, actor: torch.nn.Module | ActorProb, - actor_optim: torch.optim.Optimizer, + action_space: gym.spaces.Box, + deterministic_eval: bool = True, + action_scaling: bool = True, + action_bound_method: Literal["clip"] | None = "clip", + observation_space: gym.Space | None = None, + ): + """ + :param actor: The actor network following the rules in + :class:`~tianshou.policy.BasePolicy`. (s -> model_output) + :param action_space: Env's action space. + :param deterministic_eval: whether, in evaluation/inference mode, to use always + use the most probable action instead of sampling an action from the + categorical distribution. This setting does not affect data collection + for training, where actions are always sampled. + :param observation_space: Env's observation space. 
+ :param action_scaling: if True, scale the action from [-1, 1] to the range + of action_space. Only used if the action_space is continuous. + :param action_bound_method: method to bound action to range [-1, 1]. + Only used if the action_space is continuous. + """ + super().__init__( + action_space=action_space, + observation_space=observation_space, + action_scaling=action_scaling, + action_bound_method=action_bound_method, + ) + self.actor = actor + self.deterministic_eval = deterministic_eval + self._eps = np.finfo(np.float32).eps.item() + + def forward( # type: ignore + self, + batch: ObsBatchProtocol, + state: dict | Batch | np.ndarray | None = None, + **kwargs: Any, + ) -> DistLogProbBatchProtocol: + (loc_B, scale_B), h_BH = self.actor(batch.obs, state=state, info=batch.info) + dist = Independent(Normal(loc_B, scale_B), 1) + if self.deterministic_eval and not self.is_within_training_step: + act_B = dist.mode + else: + act_B = dist.rsample() + log_prob = dist.log_prob(act_B).unsqueeze(-1) + # apply correction for Tanh squashing when computing logprob from Gaussian + # You can check out the original SAC paper (arXiv 1801.01290): Eq 21. + # in appendix C to get some understanding of this equation. + squashed_action = torch.tanh(act_B) + log_prob = log_prob - torch.log((1 - squashed_action.pow(2)) + self._eps).sum( + -1, + keepdim=True, + ) + result = Batch( + logits=(loc_B, scale_B), + act=squashed_action, + state=h_BH, + dist=dist, + log_prob=log_prob, + ) + return cast(DistLogProbBatchProtocol, result) + + +class REDQ(ActorCriticOffPolicyAlgorithm[REDQPolicy, TREDQTrainingStats, DistLogProbBatchProtocol]): + """Implementation of REDQ. 
arXiv:2101.05982.""" + + def __init__( + self, + *, + policy: REDQPolicy, + policy_optim: torch.optim.Optimizer, critic: torch.nn.Module, critic_optim: torch.optim.Optimizer, - action_space: gym.spaces.Box, ensemble_size: int = 10, subset_size: int = 2, tau: float = 0.005, gamma: float = 0.99, - alpha: float | tuple[float, torch.Tensor, torch.optim.Optimizer] = 0.2, + alpha: float | Alpha = 0.2, estimation_step: int = 1, actor_delay: int = 20, exploration_noise: BaseNoise | Literal["default"] | None = None, deterministic_eval: bool = True, target_mode: Literal["mean", "min"] = "min", - action_scaling: bool = True, - action_bound_method: Literal["clip"] | None = "clip", - observation_space: gym.Space | None = None, lr_scheduler: TLearningRateScheduler | None = None, ) -> None: + """ + :param policy: the policy + :param policy_optim: The optimizer for actor network. + :param critic: The critic network. (s, a -> Q(s, a)) + :param critic_optim: The optimizer for critic network. + :param ensemble_size: Number of sub-networks in the critic ensemble. + :param subset_size: Number of networks in the subset. + :param tau: Param for soft update of the target network. + :param gamma: Discount factor, in [0, 1]. + :param alpha: the entropy regularization coefficient alpha or an object + which can be used to automatically tune it (e.g. an instance of `AutoAlpha`). + :param exploration_noise: The exploration noise, added to the action. Defaults + to ``GaussianNoise(sigma=0.1)``. + :param estimation_step: The number of steps to look ahead. + :param actor_delay: Number of critic updates before an actor update. + :param lr_scheduler: if not None, will be called in `policy.update()`. 
+ """ if target_mode not in ("min", "mean"): raise ValueError(f"Unsupported target_mode: {target_mode}") if not 0 < subset_size <= ensemble_size: @@ -90,18 +146,14 @@ def __init__( f"Should be 0 < {subset_size=} <= {ensemble_size=}", ) super().__init__( - actor=actor, - policy_optim=actor_optim, + policy=policy, + policy_optim=policy_optim, critic=critic, critic_optim=critic_optim, - action_space=action_space, tau=tau, gamma=gamma, exploration_noise=exploration_noise, estimation_step=estimation_step, - action_scaling=action_scaling, - action_bound_method=action_bound_method, - observation_space=observation_space, lr_scheduler=lr_scheduler, ) self.ensemble_size = ensemble_size @@ -115,80 +167,23 @@ def __init__( self._last_actor_loss = 0.0 # only for logging purposes - # TODO: reduce duplication with SACPolicy - self.alpha: float | torch.Tensor - self._is_auto_alpha = not isinstance(alpha, float) - if self._is_auto_alpha: - # TODO: why doesn't mypy understand that this must be a tuple? - alpha = cast(tuple[float, torch.Tensor, torch.optim.Optimizer], alpha) - if alpha[1].shape != torch.Size([1]): - raise ValueError( - f"Expected log_alpha to have shape torch.Size([1]), " - f"but got {alpha[1].shape} instead.", - ) - if not alpha[1].requires_grad: - raise ValueError("Expected log_alpha to require gradient, but it doesn't.") - - self.target_entropy, self.log_alpha, self.alpha_optim = alpha - self.alpha = self.log_alpha.detach().exp() - else: - # TODO: make mypy undestand this, or switch to something like pyright... - alpha = cast(float, alpha) - self.alpha = alpha - - @property - def is_auto_alpha(self) -> bool: - return self._is_auto_alpha - - # TODO: why override from the base class? 
- def _update_lagged_network_weights(self) -> None: - for o, n in zip(self.critic_old.parameters(), self.critic.parameters(), strict=True): - o.data.copy_(o.data * (1.0 - self.tau) + n.data * self.tau) - - def forward( # type: ignore - self, - batch: ObsBatchProtocol, - state: dict | Batch | np.ndarray | None = None, - **kwargs: Any, - ) -> Batch: - (loc_B, scale_B), h_BH = self.actor(batch.obs, state=state, info=batch.info) - dist = Independent(Normal(loc_B, scale_B), 1) - if self.deterministic_eval and not self.is_within_training_step: - act_B = dist.mode - else: - act_B = dist.rsample() - log_prob = dist.log_prob(act_B).unsqueeze(-1) - # apply correction for Tanh squashing when computing logprob from Gaussian - # You can check out the original SAC paper (arXiv 1801.01290): Eq 21. - # in appendix C to get some understanding of this equation. - squashed_action = torch.tanh(act_B) - log_prob = log_prob - torch.log((1 - squashed_action.pow(2)) + self.__eps).sum( - -1, - keepdim=True, - ) - return Batch( - logits=(loc_B, scale_B), - act=squashed_action, - state=h_BH, - dist=dist, - log_prob=log_prob, - ) + self.alpha = FixedAlpha(alpha) if isinstance(alpha, float) else alpha + assert isinstance(self.alpha, Alpha) - def _target_q(self, buffer: ReplayBuffer, indices: np.ndarray) -> torch.Tensor: - obs_next_batch = Batch( - obs=buffer[indices].obs_next, - info=[None] * len(indices), - ) # obs_next: s_{t+n} - obs_next_result = self(obs_next_batch) - a_ = obs_next_result.act + def _target_q_compute_value( + self, obs_batch: Batch, act_batch: DistLogProbBatchProtocol + ) -> torch.Tensor: + a_ = act_batch.act sample_ensemble_idx = np.random.choice(self.ensemble_size, self.subset_size, replace=False) - qs = self.critic_old(obs_next_batch.obs, a_)[sample_ensemble_idx, ...] + qs = self.critic_old(obs_batch.obs, a_)[sample_ensemble_idx, ...] 
if self.target_mode == "min": target_q, _ = torch.min(qs, dim=0) elif self.target_mode == "mean": target_q = torch.mean(qs, dim=0) + else: + raise ValueError(f"Invalid target_mode: {self.target_mode}") - target_q -= self.alpha * obs_next_result.log_prob + target_q -= self.alpha.value * act_batch.log_prob return target_q @@ -208,32 +203,25 @@ def _update_with_batch(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: alpha_loss = None # actor if self.critic_gradient_step % self.actor_delay == 0: - obs_result = self(batch) + obs_result = self.policy(batch) a = obs_result.act current_qa = self.critic(batch.obs, a).mean(dim=0).flatten() - actor_loss = (self.alpha * obs_result.log_prob.flatten() - current_qa).mean() + actor_loss = (self.alpha.value * obs_result.log_prob.flatten() - current_qa).mean() self.policy_optim.zero_grad() actor_loss.backward() self.policy_optim.step() - if self.is_auto_alpha: - log_prob = obs_result.log_prob.detach() + self._target_entropy - alpha_loss = -(self._log_alpha * log_prob).mean() - self.alpha_optim.zero_grad() - alpha_loss.backward() - self.alpha_optim.step() - self.alpha = self.log_alpha.detach().exp() + # The entropy of a Gaussian policy can be expressed as -log_prob + a constant (which we ignore) + entropy = -obs_result.log_prob.detach() + alpha_loss = self.alpha.update(entropy) - self._update_lagged_network_weights() - - if self.critic_gradient_step % self.actor_delay == 0: self._last_actor_loss = actor_loss.item() - if self.is_auto_alpha: - self.alpha = cast(torch.Tensor, self.alpha) + + self._update_lagged_network_weights() return REDQTrainingStats( # type: ignore[return-value] actor_loss=self._last_actor_loss, critic_loss=critic_loss.item(), - alpha=self.alpha.item() if isinstance(self.alpha, torch.Tensor) else self.alpha, + alpha=self.alpha.value, alpha_loss=alpha_loss, ) diff --git a/tianshou/policy/modelfree/sac.py b/tianshou/policy/modelfree/sac.py index 58e31fbfc..3add77a86 100644 --- a/tianshou/policy/modelfree/sac.py 
+++ b/tianshou/policy/modelfree/sac.py @@ -219,9 +219,8 @@ def __init__( If None, clone critic_optim to use for critic2.parameters(). :param tau: param for soft update of the target network. :param gamma: discount factor, in [0, 1]. - :param alpha: entropy regularization coefficient. - If a tuple (target_entropy, log_alpha, alpha_optim) is provided, - then alpha is automatically tuned. + :param alpha: the entropy regularization coefficient alpha or an object + which can be used to automatically tune it (e.g. an instance of `AutoAlpha`). :param estimation_step: The number of steps to look ahead. :param exploration_noise: add noise to action for exploration. This is useful when solving "hard exploration" problems. From 1d58866d61440efba0f2cd5c19c032dd28f86720 Mon Sep 17 00:00:00 2001 From: Dominik Jain Date: Fri, 7 Mar 2025 12:17:22 +0100 Subject: [PATCH 030/230] v2: Adapt BranchingDuelingQNetwork (BDQN) and test_bdqn, adding a success assertion to the test Improve docstrings of related DQN classes --- examples/box2d/bipedal_bdq.py | 4 +- test/discrete/{test_bdq.py => test_bdqn.py} | 52 ++++---- tianshou/policy/__init__.py | 4 +- tianshou/policy/modelfree/{bdq.py => bdqn.py} | 119 ++++++++++-------- tianshou/policy/modelfree/dqn.py | 40 +++--- tianshou/policy/modelfree/qrdqn.py | 43 +++---- 6 files changed, 137 insertions(+), 125 deletions(-) rename test/discrete/{test_bdq.py => test_bdqn.py} (81%) rename tianshou/policy/modelfree/{bdq.py => bdqn.py} (74%) diff --git a/examples/box2d/bipedal_bdq.py b/examples/box2d/bipedal_bdq.py index fff509d5f..289bab4b8 100644 --- a/examples/box2d/bipedal_bdq.py +++ b/examples/box2d/bipedal_bdq.py @@ -10,7 +10,7 @@ from tianshou.data import Collector, CollectStats, VectorReplayBuffer from tianshou.env import ContinuousToDiscrete, SubprocVectorEnv -from tianshou.policy import BranchingDQNPolicy +from tianshou.policy import BranchingDuelingQNetwork from tianshou.policy.base import Algorithm from tianshou.trainer import 
OffpolicyTrainer from tianshou.utils import TensorboardLogger @@ -101,7 +101,7 @@ def test_bdq(args: argparse.Namespace = get_args()) -> None: device=args.device, ).to(args.device) optim = torch.optim.Adam(net.parameters(), lr=args.lr) - policy: BranchingDQNPolicy = BranchingDQNPolicy( + policy: BranchingDuelingQNetwork = BranchingDuelingQNetwork( model=net, optim=optim, discount_factor=args.gamma, diff --git a/test/discrete/test_bdq.py b/test/discrete/test_bdqn.py similarity index 81% rename from test/discrete/test_bdq.py rename to test/discrete/test_bdqn.py index 16042f622..55ce9bc7b 100644 --- a/test/discrete/test_bdq.py +++ b/test/discrete/test_bdqn.py @@ -6,8 +6,9 @@ from tianshou.data import Collector, CollectStats, VectorReplayBuffer from tianshou.env import ContinuousToDiscrete, DummyVectorEnv -from tianshou.policy import BranchingDQNPolicy -from tianshou.trainer import OffpolicyTrainer +from tianshou.policy import BranchingDuelingQNetwork +from tianshou.policy.modelfree.bdqn import BranchingDuelingQNetworkPolicy +from tianshou.trainer.base import OffPolicyTrainingConfig from tianshou.utils.net.common import BranchingNet @@ -98,47 +99,52 @@ def test_bdq(args: argparse.Namespace = get_args()) -> None: device=args.device, ).to(args.device) optim = torch.optim.Adam(net.parameters(), lr=args.lr) - policy: BranchingDQNPolicy = BranchingDQNPolicy( + policy = BranchingDuelingQNetworkPolicy( model=net, + action_space=env.action_space, # type: ignore[arg-type] # TODO: should `BranchingDQNPolicy` support also `MultiDiscrete` action spaces? + ) + algorithm: BranchingDuelingQNetwork = BranchingDuelingQNetwork( + policy=policy, optim=optim, discount_factor=args.gamma, - action_space=env.action_space, # type: ignore[arg-type] # TODO: should `BranchingDQNPolicy` support also `MultiDiscrete` action spaces? 
target_update_freq=args.target_update_freq, ) # collector train_collector = Collector[CollectStats]( - policy, + algorithm, train_envs, VectorReplayBuffer(args.buffer_size, args.training_num), exploration_noise=True, ) - test_collector = Collector[CollectStats](policy, test_envs, exploration_noise=False) + test_collector = Collector[CollectStats](algorithm, test_envs, exploration_noise=False) # policy.set_eps(1) train_collector.reset() train_collector.collect(n_step=args.batch_size * args.training_num) def train_fn(epoch: int, env_step: int) -> None: # exp decay eps = max(args.eps_train * (1 - args.eps_decay) ** env_step, args.eps_test) - policy.set_eps(eps) + algorithm.set_eps(eps) def test_fn(epoch: int, env_step: int | None) -> None: - policy.set_eps(args.eps_test) + algorithm.set_eps(args.eps_test) def stop_fn(mean_rewards: float) -> bool: return mean_rewards >= args.reward_threshold - # trainer - OffpolicyTrainer( - policy=policy, - train_collector=train_collector, - test_collector=test_collector, - max_epoch=args.epoch, - step_per_epoch=args.step_per_epoch, - step_per_collect=args.step_per_collect, - episode_per_test=args.test_num, - batch_size=args.batch_size, - update_per_step=args.update_per_step, - train_fn=train_fn, - test_fn=test_fn, - stop_fn=stop_fn, - ).run() + # train + result = algorithm.run_training( + OffPolicyTrainingConfig( + train_collector=train_collector, + test_collector=test_collector, + max_epoch=args.epoch, + step_per_epoch=args.step_per_epoch, + step_per_collect=args.step_per_collect, + episode_per_test=args.test_num, + batch_size=args.batch_size, + update_per_step=args.update_per_step, + train_fn=train_fn, + test_fn=test_fn, + stop_fn=stop_fn, + ) + ) + assert stop_fn(result.best_reward) diff --git a/tianshou/policy/__init__.py b/tianshou/policy/__init__.py index 561a76a3b..22212ed31 100644 --- a/tianshou/policy/__init__.py +++ b/tianshou/policy/__init__.py @@ -7,7 +7,7 @@ from tianshou.policy.modelfree.ddpg import DDPG from 
tianshou.policy.random import MARLRandomPolicy -from tianshou.policy.modelfree.bdq import BranchingDQNPolicy +from tianshou.policy.modelfree.bdqn import BranchingDuelingQNetwork from tianshou.policy.modelfree.c51 import C51 from tianshou.policy.modelfree.rainbow import RainbowPolicy from tianshou.policy.modelfree.qrdqn import QRDQN @@ -37,7 +37,7 @@ "Algorithm", "MARLRandomPolicy", "DeepQLearning", - "BranchingDQNPolicy", + "BranchingDuelingQNetwork", "C51", "RainbowPolicy", "QRDQN", diff --git a/tianshou/policy/modelfree/bdq.py b/tianshou/policy/modelfree/bdqn.py similarity index 74% rename from tianshou/policy/modelfree/bdq.py rename to tianshou/policy/modelfree/bdqn.py index 336bbcd0c..e93645a08 100644 --- a/tianshou/policy/modelfree/bdq.py +++ b/tianshou/policy/modelfree/bdqn.py @@ -1,9 +1,10 @@ from dataclasses import dataclass -from typing import Any, Literal, TypeVar, cast +from typing import Any, TypeVar, cast import gymnasium as gym import numpy as np import torch +from sensai.util.helper import mark_used from tianshou.data import Batch, ReplayBuffer, to_numpy, to_torch, to_torch_as from tianshou.data.batch import BatchProtocol @@ -16,9 +17,11 @@ ) from tianshou.policy import DeepQLearning from tianshou.policy.base import TLearningRateScheduler -from tianshou.policy.modelfree.dqn import DQNTrainingStats, TDQNPolicy +from tianshou.policy.modelfree.dqn import DQNPolicy, DQNTrainingStats from tianshou.utils.net.common import BranchingNet +mark_used(ActBatchProtocol) + @dataclass(kw_only=True) class BDQNTrainingStats(DQNTrainingStats): @@ -28,86 +31,112 @@ class BDQNTrainingStats(DQNTrainingStats): TBDQNTrainingStats = TypeVar("TBDQNTrainingStats", bound=BDQNTrainingStats) -class BranchingDQNPolicy(DeepQLearning[TDQNPolicy, TBDQNTrainingStats]): - """Implementation of the Branching dual Q network arXiv:1711.08946. 
+class BranchingDuelingQNetworkPolicy(DQNPolicy): + def __init__( + self, + *, + model: BranchingNet, + action_space: gym.spaces.Discrete, + observation_space: gym.Space | None = None, + ): + """ + :param model: BranchingNet mapping (obs, state, info) -> action_values_BA. + :param action_space: the environment's action space + :param observation_space: the environment's observation space. + """ + super().__init__( + model=model, + action_space=action_space, + observation_space=observation_space, + ) - :param model: BranchingNet mapping (obs, state, info) -> action_values_BA. - :param optim: a torch.optim for optimizing the model. - :param discount_factor: in [0, 1]. - :param estimation_step: the number of steps to look ahead. - :param target_update_freq: the target network update frequency (0 if - you do not use the target network). - :param reward_normalization: normalize the **returns** to Normal(0, 1). - TODO: rename to return_normalization? - :param is_double: use double dqn. - :param clip_loss_grad: clip the gradient of the loss in accordance - with nature14236; this amounts to using the Huber loss instead of - the MSE loss. - :param observation_space: Env's observation space. - :param lr_scheduler: if not None, will be called in `policy.update()`. + def forward( + self, + batch: ObsBatchProtocol, + state: dict | BatchProtocol | np.ndarray | None = None, + model: torch.nn.Module | None = None, + **kwargs: Any, + ) -> ModelOutputBatchProtocol: + if model is None: + model = self.model + obs = batch.obs + # TODO: this is very contrived, see also iqn.py + obs_next_BO = obs.obs if hasattr(obs, "obs") else obs + action_values_BA, hidden_BH = model(obs_next_BO, state=state, info=batch.info) + act_B = to_numpy(action_values_BA.argmax(dim=-1)) + result = Batch(logits=action_values_BA, act=act_B, state=hidden_BH) + return cast(ModelOutputBatchProtocol, result) - .. seealso:: - Please refer to :class:`~tianshou.policy.BasePolicy` for more detailed - explanation. 
- """ +class BranchingDuelingQNetwork(DeepQLearning[BranchingDuelingQNetworkPolicy, TBDQNTrainingStats]): + """Implementation of the Branching Dueling Q-Network algorithm arXiv:1711.08946.""" def __init__( self, *, - model: BranchingNet, + policy: BranchingDuelingQNetworkPolicy, optim: torch.optim.Optimizer, - action_space: gym.spaces.Discrete, discount_factor: float = 0.99, estimation_step: int = 1, target_update_freq: int = 0, reward_normalization: bool = False, is_double: bool = True, clip_loss_grad: bool = False, - observation_space: gym.Space | None = None, lr_scheduler: TLearningRateScheduler | None = None, ) -> None: + """ + :param policy: policy + :param optim: the optimizer for the policy + :param discount_factor: in [0, 1]. + :param estimation_step: the number of steps to look ahead. + :param target_update_freq: the target network update frequency (0 if + you do not use the target network). + :param reward_normalization: normalize the **returns** to Normal(0, 1). + TODO: rename to return_normalization? + :param is_double: use double dqn. + :param clip_loss_grad: clip the gradient of the loss in accordance + with nature14236; this amounts to using the Huber loss instead of + the MSE loss. + :param lr_scheduler: if not None, will be called in `policy.update()`. 
+ """ assert ( estimation_step == 1 ), f"N-step bigger than one is not supported by BDQ but got: {estimation_step}" super().__init__( - model=model, + policy=policy, optim=optim, - action_space=action_space, discount_factor=discount_factor, estimation_step=estimation_step, target_update_freq=target_update_freq, reward_normalization=reward_normalization, is_double=is_double, clip_loss_grad=clip_loss_grad, - observation_space=observation_space, lr_scheduler=lr_scheduler, ) - self.model = cast(BranchingNet, self.model) # TODO: this used to be a public property called max_action_num, # but it collides with an attr of the same name in base class @property def _action_per_branch(self) -> int: - return self.model.action_per_branch + return self.policy.model.action_per_branch @property - def num_branches(self) -> int: - return self.model.num_branches + def _num_branches(self) -> int: + return self.policy.model.num_branches def _target_q(self, buffer: ReplayBuffer, indices: np.ndarray) -> torch.Tensor: obs_next_batch = Batch( obs=buffer[indices].obs_next, info=[None] * len(indices), ) # obs_next: s_{t+n} - result = self(obs_next_batch) + result = self.policy(obs_next_batch) if self._target: # target_Q = Q_old(s_, argmax(Q_new(s_, *))) - target_q = self(obs_next_batch, model="model_old").logits + target_q = self.policy(obs_next_batch, model=self.model_old).logits else: target_q = result.logits if self.is_double: - act = np.expand_dims(self(obs_next_batch).act, -1) + act = np.expand_dims(self.policy(obs_next_batch).act, -1) act = to_torch(act, dtype=torch.long, device=target_q.device) else: act = target_q.max(-1).indices.unsqueeze(-1) @@ -129,7 +158,7 @@ def _compute_return( end_flag = end_flag[indice] mean_target_q = np.mean(target_q, -1) if len(target_q.shape) > 1 else target_q _target_q = rew + gamma * mean_target_q * (1 - end_flag) - target_q = np.repeat(_target_q[..., None], self.num_branches, axis=-1) + target_q = np.repeat(_target_q[..., None], self._num_branches, 
axis=-1) target_q = np.repeat(target_q[..., None], self._action_per_branch, axis=-1) batch.returns = to_torch_as(target_q, target_q_torch) @@ -146,22 +175,6 @@ def process_fn( """Compute the 1-step return for BDQ targets.""" return self._compute_return(batch, buffer, indices) - def forward( - self, - batch: ObsBatchProtocol, - state: dict | BatchProtocol | np.ndarray | None = None, - model: Literal["model", "model_old"] = "model", - **kwargs: Any, - ) -> ModelOutputBatchProtocol: - model = getattr(self, model) - obs = batch.obs - # TODO: this is very contrived, see also iqn.py - obs_next_BO = obs.obs if hasattr(obs, "obs") else obs - action_values_BA, hidden_BH = model(obs_next_BO, state=state, info=batch.info) - act_B = to_numpy(action_values_BA.argmax(dim=-1)) - result = Batch(logits=action_values_BA, act=act_B, state=hidden_BH) - return cast(ModelOutputBatchProtocol, result) - def _update_with_batch( self, batch: RolloutBatchProtocol, @@ -173,7 +186,7 @@ def _update_with_batch( self.optim.zero_grad() weight = batch.pop("weight", 1.0) act = to_torch(batch.act, dtype=torch.long, device=batch.returns.device) - q = self(batch).logits + q = self.policy(batch).logits act_mask = torch.zeros_like(q) act_mask = act_mask.scatter_(-1, act.unsqueeze(-1), 1) act_q = q * act_mask diff --git a/tianshou/policy/modelfree/dqn.py b/tianshou/policy/modelfree/dqn.py index 2408b559b..c348daeb6 100644 --- a/tianshou/policy/modelfree/dqn.py +++ b/tianshou/policy/modelfree/dqn.py @@ -43,6 +43,11 @@ def __init__( action_space: gym.spaces.Discrete, observation_space: gym.Space | None = None, ) -> None: + """ + :param model: a model mapping (obs, state, info) to action_values_BA. + :param action_space: the environment's action space + :param observation_space: the environment's observation space. + """ super().__init__( action_space=action_space, observation_space=observation_space, @@ -120,26 +125,6 @@ class DeepQLearning( Implementation of Dueling DQN. 
arXiv:1511.06581 (the dueling DQN is implemented in the network side, not here). - - :param model: a model following the rules (s -> action_values_BA) - :param optim: a torch.optim for optimizing the model. - :param discount_factor: in [0, 1]. - :param estimation_step: the number of steps to look ahead. - :param target_update_freq: the target network update frequency (0 if - you do not use the target network). - :param reward_normalization: normalize the **returns** to Normal(0, 1). - TODO: rename to return_normalization? - :param is_double: use double dqn. - :param clip_loss_grad: clip the gradient of the loss in accordance - with nature14236; this amounts to using the Huber loss instead of - the MSE loss. - :param observation_space: Env's observation space. - :param lr_scheduler: if not None, will be called in `policy.update()`. - - .. seealso:: - - Please refer to :class:`~tianshou.policy.BasePolicy` for more detailed - explanation. """ def __init__( @@ -156,6 +141,21 @@ def __init__( clip_loss_grad: bool = False, lr_scheduler: TLearningRateScheduler | None = None, ) -> None: + """ + :param policy: the policy + :param optim: the optimizer for the policy + :param discount_factor: in [0, 1]. + :param estimation_step: the number of steps to look ahead. + :param target_update_freq: the target network update frequency (0 if + you do not use the target network). + :param reward_normalization: normalize the **returns** to Normal(0, 1). + TODO: rename to return_normalization? + :param is_double: use double dqn. + :param clip_loss_grad: clip the gradient of the loss in accordance + with nature14236; this amounts to using the Huber loss instead of + the MSE loss. + :param lr_scheduler: if not None, will be called in `policy.update()`. 
+ """ super().__init__( policy=policy, lr_scheduler=lr_scheduler, diff --git a/tianshou/policy/modelfree/qrdqn.py b/tianshou/policy/modelfree/qrdqn.py index 1458e0e15..e1f65dcc6 100644 --- a/tianshou/policy/modelfree/qrdqn.py +++ b/tianshou/policy/modelfree/qrdqn.py @@ -27,31 +27,7 @@ def compute_q_value(self, logits: torch.Tensor, mask: np.ndarray | None) -> torc class QRDQN(DeepQLearning[QRDQNPolicy, TQRDQNTrainingStats], Generic[TQRDQNTrainingStats]): - """Implementation of Quantile Regression Deep Q-Network. arXiv:1710.10044. - - :param model: a model following the rules (s -> action_values_BA) - :param optim: a torch.optim for optimizing the model. - :param action_space: Env's action space. - :param discount_factor: in [0, 1]. - :param num_quantiles: the number of quantile midpoints in the inverse - cumulative distribution function of the value. - :param estimation_step: the number of steps to look ahead. - :param target_update_freq: the target network update frequency (0 if - you do not use the target network). - :param reward_normalization: normalize the **returns** to Normal(0, 1). - TODO: rename to return_normalization? - :param is_double: use double dqn. - :param clip_loss_grad: clip the gradient of the loss in accordance - with nature14236; this amounts to using the Huber loss instead of - the MSE loss. - :param observation_space: Env's observation space. - :param lr_scheduler: if not None, will be called in `policy.update()`. - - .. seealso:: - - Please refer to :class:`~tianshou.policy.DQNPolicy` for more detailed - explanation. - """ + """Implementation of Quantile Regression Deep Q-Network. arXiv:1710.10044.""" def __init__( self, @@ -67,6 +43,23 @@ def __init__( clip_loss_grad: bool = False, lr_scheduler: TLearningRateScheduler | None = None, ) -> None: + """ + :param policy: the policy + :param optim: the optimizer for the policy + :param discount_factor: in [0, 1]. 
+ :param num_quantiles: the number of quantile midpoints in the inverse + cumulative distribution function of the value. + :param estimation_step: the number of steps to look ahead. + :param target_update_freq: the target network update frequency (0 if + you do not use the target network). + :param reward_normalization: normalize the **returns** to Normal(0, 1). + TODO: rename to return_normalization? + :param is_double: use double dqn. + :param clip_loss_grad: clip the gradient of the loss in accordance + with nature14236; this amounts to using the Huber loss instead of + the MSE loss. + :param lr_scheduler: if not None, will be called in `policy.update()`. + """ assert num_quantiles > 1, f"num_quantiles should be greater than 1 but got: {num_quantiles}" super().__init__( policy=policy, From 384590570b2fbcc6716ec89349910b76d53e0e5e Mon Sep 17 00:00:00 2001 From: Dominik Jain Date: Fri, 7 Mar 2025 21:40:53 +0100 Subject: [PATCH 031/230] v2: Adapt FQF and test_fqf --- examples/atari/atari_fqf.py | 4 +- test/discrete/test_fqf.py | 61 ++++---- tianshou/policy/__init__.py | 4 +- tianshou/policy/imitation/discrete_cql.py | 2 +- tianshou/policy/modelfree/fqf.py | 171 ++++++++++++---------- tianshou/policy/modelfree/iqn.py | 2 +- tianshou/policy/modelfree/qrdqn.py | 9 +- 7 files changed, 138 insertions(+), 115 deletions(-) diff --git a/examples/atari/atari_fqf.py b/examples/atari/atari_fqf.py index 04d7905d7..ceac7bb2d 100644 --- a/examples/atari/atari_fqf.py +++ b/examples/atari/atari_fqf.py @@ -11,7 +11,7 @@ from tianshou.env.atari.atari_network import DQNet from tianshou.env.atari.atari_wrapper import make_atari_env from tianshou.highlevel.logger import LoggerFactoryDefault -from tianshou.policy import FQFPolicy +from tianshou.policy import FQF from tianshou.policy.base import Algorithm from tianshou.trainer import OffpolicyTrainer from tianshou.utils.net.discrete import FractionProposalNetwork, FullQuantileFunction @@ -99,7 +99,7 @@ def main(args: argparse.Namespace = 
get_args()) -> None: fraction_net = FractionProposalNetwork(args.num_fractions, net.input_dim) fraction_optim = torch.optim.RMSprop(fraction_net.parameters(), lr=args.fraction_lr) # define policy - policy: FQFPolicy = FQFPolicy( + policy: FQF = FQF( model=net, optim=optim, fraction_model=fraction_net, diff --git a/test/discrete/test_fqf.py b/test/discrete/test_fqf.py index b68ec6157..0f0fec9f9 100644 --- a/test/discrete/test_fqf.py +++ b/test/discrete/test_fqf.py @@ -14,9 +14,10 @@ VectorReplayBuffer, ) from tianshou.env import DummyVectorEnv -from tianshou.policy import FQFPolicy +from tianshou.policy import FQF from tianshou.policy.base import Algorithm -from tianshou.trainer import OffpolicyTrainer +from tianshou.policy.modelfree.fqf import FQFPolicy +from tianshou.trainer.base import OffPolicyTrainingConfig from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Net from tianshou.utils.net.discrete import FractionProposalNetwork, FullQuantileFunction @@ -100,12 +101,15 @@ def test_fqf(args: argparse.Namespace = get_args()) -> None: optim = torch.optim.Adam(net.parameters(), lr=args.lr) fraction_net = FractionProposalNetwork(args.num_fractions, net.input_dim) fraction_optim = torch.optim.RMSprop(fraction_net.parameters(), lr=args.fraction_lr) - policy: FQFPolicy = FQFPolicy( + policy = FQFPolicy( model=net, - optim=optim, fraction_model=fraction_net, - fraction_optim=fraction_optim, action_space=env.action_space, + ) + algorithm: FQF = FQF( + policy=policy, + optim=optim, + fraction_optim=fraction_optim, discount_factor=args.gamma, num_fractions=args.num_fractions, ent_coef=args.ent_coef, @@ -124,8 +128,8 @@ def test_fqf(args: argparse.Namespace = get_args()) -> None: else: buf = VectorReplayBuffer(args.buffer_size, buffer_num=len(train_envs)) # collector - train_collector = Collector[CollectStats](policy, train_envs, buf, exploration_noise=True) - test_collector = Collector[CollectStats](policy, test_envs, exploration_noise=True) + 
train_collector = Collector[CollectStats](algorithm, train_envs, buf, exploration_noise=True) + test_collector = Collector[CollectStats](algorithm, test_envs, exploration_noise=True) # policy.set_eps(1) train_collector.reset() train_collector.collect(n_step=args.batch_size * args.training_num) @@ -143,33 +147,34 @@ def stop_fn(mean_rewards: float) -> bool: def train_fn(epoch: int, env_step: int) -> None: # eps annnealing, just a demo if env_step <= 10000: - policy.set_eps(args.eps_train) + algorithm.set_eps(args.eps_train) elif env_step <= 50000: eps = args.eps_train - (env_step - 10000) / 40000 * (0.9 * args.eps_train) - policy.set_eps(eps) + algorithm.set_eps(eps) else: - policy.set_eps(0.1 * args.eps_train) + algorithm.set_eps(0.1 * args.eps_train) def test_fn(epoch: int, env_step: int | None) -> None: - policy.set_eps(args.eps_test) + algorithm.set_eps(args.eps_test) - # trainer - result = OffpolicyTrainer( - policy=policy, - train_collector=train_collector, - test_collector=test_collector, - max_epoch=args.epoch, - step_per_epoch=args.step_per_epoch, - step_per_collect=args.step_per_collect, - episode_per_test=args.test_num, - batch_size=args.batch_size, - train_fn=train_fn, - test_fn=test_fn, - stop_fn=stop_fn, - save_best_fn=save_best_fn, - logger=logger, - update_per_step=args.update_per_step, - ).run() + # train + result = algorithm.run_training( + OffPolicyTrainingConfig( + train_collector=train_collector, + test_collector=test_collector, + max_epoch=args.epoch, + step_per_epoch=args.step_per_epoch, + step_per_collect=args.step_per_collect, + episode_per_test=args.test_num, + batch_size=args.batch_size, + train_fn=train_fn, + test_fn=test_fn, + stop_fn=stop_fn, + save_best_fn=save_best_fn, + logger=logger, + update_per_step=args.update_per_step, + ) + ) assert stop_fn(result.best_reward) diff --git a/tianshou/policy/__init__.py b/tianshou/policy/__init__.py index 22212ed31..b61099081 100644 --- a/tianshou/policy/__init__.py +++ 
b/tianshou/policy/__init__.py @@ -12,7 +12,7 @@ from tianshou.policy.modelfree.rainbow import RainbowPolicy from tianshou.policy.modelfree.qrdqn import QRDQN from tianshou.policy.modelfree.iqn import IQNPolicy -from tianshou.policy.modelfree.fqf import FQFPolicy +from tianshou.policy.modelfree.fqf import FQF from tianshou.policy.modelfree.a2c import A2C from tianshou.policy.modelfree.npg import NPG from tianshou.policy.modelfree.ppo import PPO @@ -42,7 +42,7 @@ "RainbowPolicy", "QRDQN", "IQNPolicy", - "FQFPolicy", + "FQF", "Reinforce", "A2C", "NPG", diff --git a/tianshou/policy/imitation/discrete_cql.py b/tianshou/policy/imitation/discrete_cql.py index 3087b354e..f9b332128 100644 --- a/tianshou/policy/imitation/discrete_cql.py +++ b/tianshou/policy/imitation/discrete_cql.py @@ -22,7 +22,7 @@ class DiscreteCQLTrainingStats(QRDQNTrainingStats): TDiscreteCQLTrainingStats = TypeVar("TDiscreteCQLTrainingStats", bound=DiscreteCQLTrainingStats) -class DiscreteCQLPolicy(QRDQN[TDiscreteCQLTrainingStats]): +class DiscreteCQLPolicy(QRDQN): """Implementation of discrete Conservative Q-Learning algorithm. arXiv:2006.04779. 
:param model: a model following the rules (s_B -> action_values_BA) diff --git a/tianshou/policy/modelfree/fqf.py b/tianshou/policy/modelfree/fqf.py index 89cd5d7c5..8e4609707 100644 --- a/tianshou/policy/modelfree/fqf.py +++ b/tianshou/policy/modelfree/fqf.py @@ -1,5 +1,5 @@ from dataclasses import dataclass -from typing import Any, Literal, TypeVar, cast +from typing import Any, TypeVar, cast import gymnasium as gym import numpy as np @@ -8,9 +8,10 @@ from tianshou.data import Batch, ReplayBuffer, to_numpy from tianshou.data.types import FQFBatchProtocol, ObsBatchProtocol, RolloutBatchProtocol -from tianshou.policy import QRDQN, DeepQLearning +from tianshou.policy import QRDQN from tianshou.policy.base import TLearningRateScheduler -from tianshou.policy.modelfree.qrdqn import QRDQNTrainingStats +from tianshou.policy.modelfree.dqn import DQNPolicy +from tianshou.policy.modelfree.qrdqn import QRDQNPolicy, QRDQNTrainingStats from tianshou.utils.net.discrete import FractionProposalNetwork, FullQuantileFunction @@ -24,101 +25,39 @@ class FQFTrainingStats(QRDQNTrainingStats): TFQFTrainingStats = TypeVar("TFQFTrainingStats", bound=FQFTrainingStats) -class FQFPolicy(QRDQN[TFQFTrainingStats]): - """Implementation of Fully-parameterized Quantile Function. arXiv:1911.02140. - - :param model: a model following the rules (s_B -> action_values_BA) - :param optim: a torch.optim for optimizing the model. - :param fraction_model: a FractionProposalNetwork for - proposing fractions/quantiles given state. - :param fraction_optim: a torch.optim for optimizing - the fraction model above. - :param action_space: Env's action space. - :param discount_factor: in [0, 1]. - :param num_fractions: the number of fractions to use. - :param ent_coef: the coefficient for entropy loss. - :param estimation_step: the number of steps to look ahead. - :param target_update_freq: the target network update frequency (0 if - you do not use the target network). 
- :param reward_normalization: normalize the **returns** to Normal(0, 1). - TODO: rename to return_normalization? - :param is_double: use double dqn. - :param clip_loss_grad: clip the gradient of the loss in accordance - with nature14236; this amounts to using the Huber loss instead of - the MSE loss. - :param observation_space: Env's observation space. - :param lr_scheduler: if not None, will be called in `policy.update()`. - - .. seealso:: - - Please refer to :class:`~tianshou.policy.QRDQNPolicy` for more detailed - explanation. - """ - +class FQFPolicy(QRDQNPolicy): def __init__( self, *, model: FullQuantileFunction, - optim: torch.optim.Optimizer, fraction_model: FractionProposalNetwork, - fraction_optim: torch.optim.Optimizer, action_space: gym.spaces.Discrete, - discount_factor: float = 0.99, - # TODO: used as num_quantiles in QRDQNPolicy, but num_fractions in FQFPolicy. - # Rename? Or at least explain what happens here. - num_fractions: int = 32, - ent_coef: float = 0.0, - estimation_step: int = 1, - target_update_freq: int = 0, - reward_normalization: bool = False, - is_double: bool = True, - clip_loss_grad: bool = False, observation_space: gym.Space | None = None, - lr_scheduler: TLearningRateScheduler | None = None, - ) -> None: + ): + """ + :param model: a model following the rules (s_B -> action_values_BA) + :param fraction_model: a FractionProposalNetwork for + proposing fractions/quantiles given state. + :param action_space: the environment's action space + :param observation_space: the environment's observation space. 
+ """ super().__init__( model=model, - optim=optim, action_space=action_space, - discount_factor=discount_factor, - num_quantiles=num_fractions, - estimation_step=estimation_step, - target_update_freq=target_update_freq, - reward_normalization=reward_normalization, - is_double=is_double, - clip_loss_grad=clip_loss_grad, observation_space=observation_space, - lr_scheduler=lr_scheduler, ) self.fraction_model = fraction_model - self.ent_coef = ent_coef - self.fraction_optim = fraction_optim - def _target_q(self, buffer: ReplayBuffer, indices: np.ndarray) -> torch.Tensor: - obs_next_batch = Batch( - obs=buffer[indices].obs_next, - info=[None] * len(indices), - ) # obs_next: s_{t+n} - if self._target: - result = self(obs_next_batch) - act, fractions = result.act, result.fractions - next_dist = self(obs_next_batch, model="model_old", fractions=fractions).logits - else: - next_batch = self(obs_next_batch) - act = next_batch.act - next_dist = next_batch.logits - return next_dist[np.arange(len(act)), act, :] - - # TODO: fix Liskov substitution principle violation def forward( # type: ignore self, batch: ObsBatchProtocol, state: dict | Batch | np.ndarray | None = None, - model: Literal["model", "model_old"] = "model", + model: FullQuantileFunction | None = None, fractions: Batch | None = None, **kwargs: Any, ) -> FQFBatchProtocol: - model = getattr(self, model) + if model is None: + model = self.model obs = batch.obs # TODO: this is convoluted! See also other places where this is done obs_next = obs.obs if hasattr(obs, "obs") else obs @@ -138,7 +77,7 @@ def forward( # type: ignore info=batch.info, ) weighted_logits = (fractions.taus[:, 1:] - fractions.taus[:, :-1]).unsqueeze(1) * logits - q = DeepQLearning.compute_q_value(self, weighted_logits.sum(2), getattr(obs, "mask", None)) + q = DQNPolicy.compute_q_value(self, weighted_logits.sum(2), getattr(obs, "mask", None)) if self.max_action_num is None: # type: ignore # TODO: see same thing in DQNPolicy! 
Also reduce code duplication. self.max_action_num = q.shape[1] @@ -152,6 +91,80 @@ def forward( # type: ignore ) return cast(FQFBatchProtocol, result) + +class FQF(QRDQN[FQFPolicy, TFQFTrainingStats]): + """Implementation of Fully Parameterized Quantile Function for Distributional Reinforcement Learning. arXiv:1911.02140.""" + + def __init__( + self, + *, + policy: FQFPolicy, + optim: torch.optim.Optimizer, + fraction_optim: torch.optim.Optimizer, + discount_factor: float = 0.99, + # TODO: used as num_quantiles in QRDQNPolicy, but num_fractions in FQFPolicy. + # Rename? Or at least explain what happens here. + num_fractions: int = 32, + ent_coef: float = 0.0, + estimation_step: int = 1, + target_update_freq: int = 0, + reward_normalization: bool = False, + is_double: bool = True, + clip_loss_grad: bool = False, + lr_scheduler: TLearningRateScheduler | None = None, + ) -> None: + """ + :param policy: the policy + :param optim: the optimizer for the policy's main Q-function model + :param fraction_optim: the optimizer for the policy's fraction model + :param action_space: Env's action space. + :param discount_factor: in [0, 1]. + :param num_fractions: the number of fractions to use. + :param ent_coef: the coefficient for entropy loss. + :param estimation_step: the number of steps to look ahead. + :param target_update_freq: the target network update frequency (0 if + you do not use the target network). + :param reward_normalization: normalize the **returns** to Normal(0, 1). + TODO: rename to return_normalization? + :param is_double: use double dqn. + :param clip_loss_grad: clip the gradient of the loss in accordance + with nature14236; this amounts to using the Huber loss instead of + the MSE loss. + :param observation_space: Env's observation space. + :param lr_scheduler: if not None, will be called in `policy.update()`. 
+ """ + super().__init__( + policy=policy, + optim=optim, + discount_factor=discount_factor, + num_quantiles=num_fractions, + estimation_step=estimation_step, + target_update_freq=target_update_freq, + reward_normalization=reward_normalization, + is_double=is_double, + clip_loss_grad=clip_loss_grad, + lr_scheduler=lr_scheduler, + ) + self.ent_coef = ent_coef + self.fraction_optim = fraction_optim + + def _target_q(self, buffer: ReplayBuffer, indices: np.ndarray) -> torch.Tensor: + obs_next_batch = Batch( + obs=buffer[indices].obs_next, + info=[None] * len(indices), + ) # obs_next: s_{t+n} + if self._target: + result = self.policy(obs_next_batch) + act, fractions = result.act, result.fractions + next_dist = self.policy( + obs_next_batch, model=self.model_old, fractions=fractions + ).logits + else: + next_batch = self.policy(obs_next_batch) + act = next_batch.act + next_dist = next_batch.logits + return next_dist[np.arange(len(act)), act, :] + def _update_with_batch( self, batch: RolloutBatchProtocol, @@ -161,7 +174,7 @@ def _update_with_batch( if self._target and self._iter % self.freq == 0: self.sync_weight() weight = batch.pop("weight", 1.0) - out = self(batch) + out = self.policy(batch) curr_dist_orig = out.logits taus, tau_hats = out.fractions.taus, out.fractions.tau_hats act = batch.act diff --git a/tianshou/policy/modelfree/iqn.py b/tianshou/policy/modelfree/iqn.py index e01d5eada..116f3785a 100644 --- a/tianshou/policy/modelfree/iqn.py +++ b/tianshou/policy/modelfree/iqn.py @@ -26,7 +26,7 @@ class IQNTrainingStats(QRDQNTrainingStats): TIQNTrainingStats = TypeVar("TIQNTrainingStats", bound=IQNTrainingStats) -class IQNPolicy(QRDQN[TIQNTrainingStats]): +class IQNPolicy(QRDQN): """Implementation of Implicit Quantile Network. arXiv:1806.06923. 
:param model: a model following the rules (s_B -> action_values_BA) diff --git a/tianshou/policy/modelfree/qrdqn.py b/tianshou/policy/modelfree/qrdqn.py index e1f65dcc6..b4c2ec9c1 100644 --- a/tianshou/policy/modelfree/qrdqn.py +++ b/tianshou/policy/modelfree/qrdqn.py @@ -26,13 +26,18 @@ def compute_q_value(self, logits: torch.Tensor, mask: np.ndarray | None) -> torc return super().compute_q_value(logits.mean(2), mask) -class QRDQN(DeepQLearning[QRDQNPolicy, TQRDQNTrainingStats], Generic[TQRDQNTrainingStats]): +TQRDQNPolicy = TypeVar("TQRDQNPolicy", bound=QRDQNPolicy) + + +class QRDQN( + DeepQLearning[TQRDQNPolicy, TQRDQNTrainingStats], Generic[TQRDQNPolicy, TQRDQNTrainingStats] +): """Implementation of Quantile Regression Deep Q-Network. arXiv:1710.10044.""" def __init__( self, *, - policy: QRDQNPolicy, + policy: TQRDQNPolicy, optim: torch.optim.Optimizer, discount_factor: float = 0.99, num_quantiles: int = 200, From cbc478e759170290adc794368d86a4a95fc5a090 Mon Sep 17 00:00:00 2001 From: Dominik Jain Date: Fri, 7 Mar 2025 21:53:48 +0100 Subject: [PATCH 032/230] v2: Adapt IQN and test_iqn --- examples/atari/atari_iqn.py | 4 +- test/discrete/test_iqn.py | 61 +++++++++-------- tianshou/highlevel/algorithm.py | 8 +-- tianshou/policy/__init__.py | 4 +- tianshou/policy/modelfree/iqn.py | 112 ++++++++++++++++--------------- 5 files changed, 99 insertions(+), 90 deletions(-) diff --git a/examples/atari/atari_iqn.py b/examples/atari/atari_iqn.py index 869c0e158..587295e6d 100644 --- a/examples/atari/atari_iqn.py +++ b/examples/atari/atari_iqn.py @@ -11,7 +11,7 @@ from tianshou.env.atari.atari_network import DQNet from tianshou.env.atari.atari_wrapper import make_atari_env from tianshou.highlevel.logger import LoggerFactoryDefault -from tianshou.policy import IQNPolicy +from tianshou.policy import IQN from tianshou.policy.base import Algorithm from tianshou.trainer import OffpolicyTrainer from tianshou.utils.net.discrete import ImplicitQuantileNetwork @@ -97,7 +97,7 @@ def 
main(args: argparse.Namespace = get_args()) -> None: ).to(args.device) optim = torch.optim.Adam(net.parameters(), lr=args.lr) # define policy - policy: IQNPolicy = IQNPolicy( + policy: IQN = IQN( model=net, optim=optim, action_space=env.action_space, diff --git a/test/discrete/test_iqn.py b/test/discrete/test_iqn.py index bbfc3a71a..e2f4eac09 100644 --- a/test/discrete/test_iqn.py +++ b/test/discrete/test_iqn.py @@ -14,9 +14,10 @@ VectorReplayBuffer, ) from tianshou.env import DummyVectorEnv -from tianshou.policy import IQNPolicy +from tianshou.policy import IQN from tianshou.policy.base import Algorithm -from tianshou.trainer import OffpolicyTrainer +from tianshou.policy.modelfree.iqn import IQNPolicy +from tianshou.trainer.base import OffPolicyTrainingConfig from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Net from tianshou.utils.net.discrete import ImplicitQuantileNetwork @@ -97,14 +98,17 @@ def test_iqn(args: argparse.Namespace = get_args()) -> None: device=args.device, ) optim = torch.optim.Adam(net.parameters(), lr=args.lr) - policy: IQNPolicy = IQNPolicy( + policy = IQNPolicy( model=net, - optim=optim, action_space=env.action_space, - discount_factor=args.gamma, sample_size=args.sample_size, online_sample_size=args.online_sample_size, target_sample_size=args.target_sample_size, + ) + algorithm: IQN = IQN( + policy=policy, + optim=optim, + discount_factor=args.gamma, estimation_step=args.n_step, target_update_freq=args.target_update_freq, ).to(args.device) @@ -120,8 +124,8 @@ def test_iqn(args: argparse.Namespace = get_args()) -> None: else: buf = VectorReplayBuffer(args.buffer_size, buffer_num=len(train_envs)) # collector - train_collector = Collector[CollectStats](policy, train_envs, buf, exploration_noise=True) - test_collector = Collector[CollectStats](policy, test_envs, exploration_noise=True) + train_collector = Collector[CollectStats](algorithm, train_envs, buf, exploration_noise=True) + test_collector = 
Collector[CollectStats](algorithm, test_envs, exploration_noise=True) # policy.set_eps(1) train_collector.reset() train_collector.collect(n_step=args.batch_size * args.training_num) @@ -139,33 +143,34 @@ def stop_fn(mean_rewards: float) -> bool: def train_fn(epoch: int, env_step: int) -> None: # eps annnealing, just a demo if env_step <= 10000: - policy.set_eps(args.eps_train) + algorithm.set_eps(args.eps_train) elif env_step <= 50000: eps = args.eps_train - (env_step - 10000) / 40000 * (0.9 * args.eps_train) - policy.set_eps(eps) + algorithm.set_eps(eps) else: - policy.set_eps(0.1 * args.eps_train) + algorithm.set_eps(0.1 * args.eps_train) def test_fn(epoch: int, env_step: int | None) -> None: - policy.set_eps(args.eps_test) + algorithm.set_eps(args.eps_test) - # trainer - result = OffpolicyTrainer( - policy=policy, - train_collector=train_collector, - test_collector=test_collector, - max_epoch=args.epoch, - step_per_epoch=args.step_per_epoch, - step_per_collect=args.step_per_collect, - episode_per_test=args.test_num, - batch_size=args.batch_size, - train_fn=train_fn, - test_fn=test_fn, - stop_fn=stop_fn, - save_best_fn=save_best_fn, - logger=logger, - update_per_step=args.update_per_step, - ).run() + # train + result = algorithm.run_training( + OffPolicyTrainingConfig( + train_collector=train_collector, + test_collector=test_collector, + max_epoch=args.epoch, + step_per_epoch=args.step_per_epoch, + step_per_collect=args.step_per_collect, + episode_per_test=args.test_num, + batch_size=args.batch_size, + train_fn=train_fn, + test_fn=test_fn, + stop_fn=stop_fn, + save_best_fn=save_best_fn, + logger=logger, + update_per_step=args.update_per_step, + ) + ) assert stop_fn(result.best_reward) diff --git a/tianshou/highlevel/algorithm.py b/tianshou/highlevel/algorithm.py index 0667f6e24..b6f2b6718 100644 --- a/tianshou/highlevel/algorithm.py +++ b/tianshou/highlevel/algorithm.py @@ -48,6 +48,7 @@ from tianshou.policy import ( A2C, DDPG, + IQN, NPG, PPO, REDQ, @@ -57,7 
+58,6 @@ Algorithm, DeepQLearning, DiscreteSAC, - IQNPolicy, Reinforce, ) from tianshou.policy.base import ( @@ -493,9 +493,9 @@ def _get_algorithm_class(self) -> type[DeepQLearning]: return DeepQLearning -class IQNAlgorithmFactory(DiscreteCriticOnlyAlgorithmFactory[IQNParams, IQNPolicy]): - def _get_algorithm_class(self) -> type[IQNPolicy]: - return IQNPolicy +class IQNAlgorithmFactory(DiscreteCriticOnlyAlgorithmFactory[IQNParams, IQN]): + def _get_algorithm_class(self) -> type[IQN]: + return IQN class DDPGAlgorithmFactory(OffPolicyAlgorithmFactory): diff --git a/tianshou/policy/__init__.py b/tianshou/policy/__init__.py index b61099081..9a4f92b53 100644 --- a/tianshou/policy/__init__.py +++ b/tianshou/policy/__init__.py @@ -11,7 +11,7 @@ from tianshou.policy.modelfree.c51 import C51 from tianshou.policy.modelfree.rainbow import RainbowPolicy from tianshou.policy.modelfree.qrdqn import QRDQN -from tianshou.policy.modelfree.iqn import IQNPolicy +from tianshou.policy.modelfree.iqn import IQN from tianshou.policy.modelfree.fqf import FQF from tianshou.policy.modelfree.a2c import A2C from tianshou.policy.modelfree.npg import NPG @@ -41,7 +41,7 @@ "C51", "RainbowPolicy", "QRDQN", - "IQNPolicy", + "IQN", "FQF", "Reinforce", "A2C", diff --git a/tianshou/policy/modelfree/iqn.py b/tianshou/policy/modelfree/iqn.py index 116f3785a..829d28321 100644 --- a/tianshou/policy/modelfree/iqn.py +++ b/tianshou/policy/modelfree/iqn.py @@ -1,5 +1,5 @@ from dataclasses import dataclass -from typing import Any, Literal, TypeVar, cast +from typing import Any, TypeVar, cast import gymnasium as gym import numpy as np @@ -15,7 +15,7 @@ ) from tianshou.policy import QRDQN from tianshou.policy.base import TLearningRateScheduler -from tianshou.policy.modelfree.qrdqn import QRDQNTrainingStats +from tianshou.policy.modelfree.qrdqn import QRDQNPolicy, QRDQNTrainingStats @dataclass(kw_only=True) @@ -26,53 +26,16 @@ class IQNTrainingStats(QRDQNTrainingStats): TIQNTrainingStats = 
TypeVar("TIQNTrainingStats", bound=IQNTrainingStats) -class IQNPolicy(QRDQN): - """Implementation of Implicit Quantile Network. arXiv:1806.06923. - - :param model: a model following the rules (s_B -> action_values_BA) - :param optim: a torch.optim for optimizing the model. - :param discount_factor: in [0, 1]. - :param sample_size: the number of samples for policy evaluation. - :param online_sample_size: the number of samples for online model - in training. - :param target_sample_size: the number of samples for target model - in training. - :param num_quantiles: the number of quantile midpoints in the inverse - cumulative distribution function of the value. - :param estimation_step: the number of steps to look ahead. - :param target_update_freq: the target network update frequency (0 if - you do not use the target network). - :param reward_normalization: normalize the **returns** to Normal(0, 1). - TODO: rename to return_normalization? - :param is_double: use double dqn. - :param clip_loss_grad: clip the gradient of the loss in accordance - with nature14236; this amounts to using the Huber loss instead of - the MSE loss. - :param observation_space: Env's observation space. - :param lr_scheduler: if not None, will be called in `policy.update()`. - - Please refer to :class:`~tianshou.policy.QRDQNPolicy` for more detailed - explanation. 
- """ - +class IQNPolicy(QRDQNPolicy): def __init__( self, *, model: torch.nn.Module, - optim: torch.optim.Optimizer, action_space: gym.spaces.Discrete, - discount_factor: float = 0.99, sample_size: int = 32, online_sample_size: int = 8, target_sample_size: int = 8, - num_quantiles: int = 200, - estimation_step: int = 1, - target_update_freq: int = 0, - reward_normalization: bool = False, - is_double: bool = True, - clip_loss_grad: bool = False, observation_space: gym.Space | None = None, - lr_scheduler: TLearningRateScheduler | None = None, ) -> None: assert sample_size > 1, f"sample_size should be greater than 1 but got: {sample_size}" assert ( @@ -83,19 +46,10 @@ def __init__( ), f"target_sample_size should be greater than 1 but got: {target_sample_size}" super().__init__( model=model, - optim=optim, action_space=action_space, - discount_factor=discount_factor, - num_quantiles=num_quantiles, - estimation_step=estimation_step, - target_update_freq=target_update_freq, - reward_normalization=reward_normalization, - is_double=is_double, - clip_loss_grad=clip_loss_grad, observation_space=observation_space, - lr_scheduler=lr_scheduler, ) - self.sample_size = sample_size # for policy eval + self.sample_size = sample_size self.online_sample_size = online_sample_size self.target_sample_size = target_sample_size @@ -103,16 +57,18 @@ def forward( self, batch: ObsBatchProtocol, state: dict | BatchProtocol | np.ndarray | None = None, - model: Literal["model", "model_old"] = "model", + model: torch.nn.Module | None = None, **kwargs: Any, ) -> QuantileRegressionBatchProtocol: - if model == "model_old": + is_model_old = model is not None + if is_model_old: sample_size = self.target_sample_size elif self.training: sample_size = self.online_sample_size else: sample_size = self.sample_size - model = getattr(self, model) + if model is None: + model = self.model obs = batch.obs # TODO: this seems very contrived! 
obs_next = obs.obs if hasattr(obs, "obs") else obs @@ -130,6 +86,54 @@ def forward( result = Batch(logits=logits, act=act, state=hidden, taus=taus) return cast(QuantileRegressionBatchProtocol, result) + +class IQN(QRDQN[IQNPolicy, TIQNTrainingStats]): + """Implementation of Implicit Quantile Network. arXiv:1806.06923.""" + + def __init__( + self, + *, + policy: IQNPolicy, + optim: torch.optim.Optimizer, + discount_factor: float = 0.99, + num_quantiles: int = 200, + estimation_step: int = 1, + target_update_freq: int = 0, + reward_normalization: bool = False, + is_double: bool = True, + clip_loss_grad: bool = False, + lr_scheduler: TLearningRateScheduler | None = None, + ) -> None: + """ + :param policy: the policy + :param optim: the optimizer for the policy's model + :param discount_factor: in [0, 1]. + :param num_quantiles: the number of quantile midpoints in the inverse + cumulative distribution function of the value. + :param estimation_step: the number of steps to look ahead. + :param target_update_freq: the target network update frequency (0 if + you do not use the target network). + :param reward_normalization: normalize the **returns** to Normal(0, 1). + TODO: rename to return_normalization? + :param is_double: use double dqn. + :param clip_loss_grad: clip the gradient of the loss in accordance + with nature14236; this amounts to using the Huber loss instead of + the MSE loss. + :param lr_scheduler: if not None, will be called in `policy.update()`. 
+ """ + super().__init__( + policy=policy, + optim=optim, + discount_factor=discount_factor, + num_quantiles=num_quantiles, + estimation_step=estimation_step, + target_update_freq=target_update_freq, + reward_normalization=reward_normalization, + is_double=is_double, + clip_loss_grad=clip_loss_grad, + lr_scheduler=lr_scheduler, + ) + def _update_with_batch( self, batch: RolloutBatchProtocol, @@ -140,7 +144,7 @@ def _update_with_batch( self.sync_weight() self.optim.zero_grad() weight = batch.pop("weight", 1.0) - action_batch = self(batch) + action_batch = self.policy(batch) curr_dist, taus = action_batch.logits, action_batch.taus act = batch.act curr_dist = curr_dist[np.arange(len(act)), act, :].unsqueeze(2) From fb28e479c18524bdc4629937b3fbc5ccedd113c7 Mon Sep 17 00:00:00 2001 From: Dominik Jain Date: Fri, 7 Mar 2025 22:02:27 +0100 Subject: [PATCH 033/230] v2: Adapt RainbowDQN and test_rainbow --- examples/atari/atari_rainbow.py | 7 ++- test/discrete/test_rainbow.py | 70 +++++++++++++++------------- tianshou/policy/__init__.py | 4 +- tianshou/policy/modelfree/rainbow.py | 44 ++++++++--------- 4 files changed, 62 insertions(+), 63 deletions(-) diff --git a/examples/atari/atari_rainbow.py b/examples/atari/atari_rainbow.py index 71a6a97d8..294fdb7e6 100644 --- a/examples/atari/atari_rainbow.py +++ b/examples/atari/atari_rainbow.py @@ -13,10 +13,9 @@ PrioritizedVectorReplayBuffer, VectorReplayBuffer, ) -from tianshou.env.atari.atari_network import Rainbow from tianshou.env.atari.atari_wrapper import make_atari_env from tianshou.highlevel.logger import LoggerFactoryDefault -from tianshou.policy import C51, RainbowPolicy +from tianshou.policy import C51, RainbowDQN from tianshou.policy.base import Algorithm from tianshou.trainer import OffpolicyTrainer @@ -98,7 +97,7 @@ def test_rainbow(args: argparse.Namespace = get_args()) -> None: np.random.seed(args.seed) torch.manual_seed(args.seed) # define model - net = Rainbow( + net = RainbowDQN( *args.state_shape, 
args.action_shape, args.num_atoms, @@ -109,7 +108,7 @@ def test_rainbow(args: argparse.Namespace = get_args()) -> None: ) optim = torch.optim.Adam(net.parameters(), lr=args.lr) # define policy - policy: C51 = RainbowPolicy( + policy: C51 = RainbowDQN( model=net, optim=optim, discount_factor=args.gamma, diff --git a/test/discrete/test_rainbow.py b/test/discrete/test_rainbow.py index cb80c460d..8b6c6fc15 100644 --- a/test/discrete/test_rainbow.py +++ b/test/discrete/test_rainbow.py @@ -14,10 +14,10 @@ VectorReplayBuffer, ) from tianshou.env import DummyVectorEnv -from tianshou.policy import RainbowPolicy +from tianshou.policy import RainbowDQN from tianshou.policy.base import Algorithm -from tianshou.policy.modelfree.rainbow import RainbowTrainingStats -from tianshou.trainer import OffpolicyTrainer +from tianshou.policy.modelfree.c51 import C51Policy +from tianshou.trainer.base import OffPolicyTrainingConfig from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Net from tianshou.utils.net.discrete import NoisyLinear @@ -104,14 +104,17 @@ def noisy_linear(x: int, y: int) -> NoisyLinear: dueling_param=({"linear_layer": noisy_linear}, {"linear_layer": noisy_linear}), ) optim = torch.optim.Adam(net.parameters(), lr=args.lr) - policy: RainbowPolicy[RainbowTrainingStats] = RainbowPolicy( + policy = C51Policy( model=net, - optim=optim, - discount_factor=args.gamma, action_space=env.action_space, num_atoms=args.num_atoms, v_min=args.v_min, v_max=args.v_max, + ) + algorithm = RainbowDQN( + policy=policy, + optim=optim, + discount_factor=args.gamma, estimation_step=args.n_step, target_update_freq=args.target_update_freq, ).to(args.device) @@ -128,8 +131,8 @@ def noisy_linear(x: int, y: int) -> NoisyLinear: else: buf = VectorReplayBuffer(args.buffer_size, buffer_num=len(train_envs)) # collector - train_collector = Collector[CollectStats](policy, train_envs, buf, exploration_noise=True) - test_collector = Collector[CollectStats](policy, test_envs, 
exploration_noise=True) + train_collector = Collector[CollectStats](algorithm, train_envs, buf, exploration_noise=True) + test_collector = Collector[CollectStats](algorithm, test_envs, exploration_noise=True) # policy.set_eps(1) train_collector.reset() train_collector.collect(n_step=args.batch_size * args.training_num) @@ -147,12 +150,12 @@ def stop_fn(mean_rewards: float) -> bool: def train_fn(epoch: int, env_step: int) -> None: # eps annealing, just a demo if env_step <= 10000: - policy.set_eps(args.eps_train) + algorithm.set_eps(args.eps_train) elif env_step <= 50000: eps = args.eps_train - (env_step - 10000) / 40000 * (0.9 * args.eps_train) - policy.set_eps(eps) + algorithm.set_eps(eps) else: - policy.set_eps(0.1 * args.eps_train) + algorithm.set_eps(0.1 * args.eps_train) # beta annealing, just a demo if args.prioritized_replay: if env_step <= 10000: @@ -164,7 +167,7 @@ def train_fn(epoch: int, env_step: int) -> None: buf.set_beta(beta) def test_fn(epoch: int, env_step: int | None) -> None: - policy.set_eps(args.eps_test) + algorithm.set_eps(args.eps_test) def save_checkpoint_fn(epoch: int, env_step: int, gradient_step: int) -> str: # see also: https://pytorch.org/tutorials/beginner/saving_loading_models.html @@ -173,7 +176,7 @@ def save_checkpoint_fn(epoch: int, env_step: int, gradient_step: int) -> str: # ckpt_path = os.path.join(log_path, f"checkpoint_{epoch}.pth") torch.save( { - "model": policy.state_dict(), + "model": algorithm.state_dict(), "optim": optim.state_dict(), }, ckpt_path, @@ -189,8 +192,8 @@ def save_checkpoint_fn(epoch: int, env_step: int, gradient_step: int) -> str: ckpt_path = os.path.join(log_path, "checkpoint.pth") if os.path.exists(ckpt_path): checkpoint = torch.load(ckpt_path, map_location=args.device) - policy.load_state_dict(checkpoint["model"]) - policy.optim.load_state_dict(checkpoint["optim"]) + algorithm.load_state_dict(checkpoint["model"]) + algorithm.optim.load_state_dict(checkpoint["optim"]) print("Successfully restore policy 
and optim.") else: print("Fail to restore policy and optim.") @@ -203,24 +206,25 @@ def save_checkpoint_fn(epoch: int, env_step: int, gradient_step: int) -> str: print("Fail to restore buffer.") # trainer - result = OffpolicyTrainer( - policy=policy, - train_collector=train_collector, - test_collector=test_collector, - max_epoch=args.epoch, - step_per_epoch=args.step_per_epoch, - step_per_collect=args.step_per_collect, - episode_per_test=args.test_num, - batch_size=args.batch_size, - update_per_step=args.update_per_step, - train_fn=train_fn, - test_fn=test_fn, - stop_fn=stop_fn, - save_best_fn=save_best_fn, - logger=logger, - resume_from_log=args.resume, - save_checkpoint_fn=save_checkpoint_fn, - ).run() + result = algorithm.run_training( + OffPolicyTrainingConfig( + train_collector=train_collector, + test_collector=test_collector, + max_epoch=args.epoch, + step_per_epoch=args.step_per_epoch, + step_per_collect=args.step_per_collect, + episode_per_test=args.test_num, + batch_size=args.batch_size, + update_per_step=args.update_per_step, + train_fn=train_fn, + test_fn=test_fn, + stop_fn=stop_fn, + save_best_fn=save_best_fn, + logger=logger, + resume_from_log=args.resume, + save_checkpoint_fn=save_checkpoint_fn, + ) + ) assert stop_fn(result.best_reward) diff --git a/tianshou/policy/__init__.py b/tianshou/policy/__init__.py index 9a4f92b53..5a3930f94 100644 --- a/tianshou/policy/__init__.py +++ b/tianshou/policy/__init__.py @@ -9,7 +9,7 @@ from tianshou.policy.random import MARLRandomPolicy from tianshou.policy.modelfree.bdqn import BranchingDuelingQNetwork from tianshou.policy.modelfree.c51 import C51 -from tianshou.policy.modelfree.rainbow import RainbowPolicy +from tianshou.policy.modelfree.rainbow import RainbowDQN from tianshou.policy.modelfree.qrdqn import QRDQN from tianshou.policy.modelfree.iqn import IQN from tianshou.policy.modelfree.fqf import FQF @@ -39,7 +39,7 @@ "DeepQLearning", "BranchingDuelingQNetwork", "C51", - "RainbowPolicy", + "RainbowDQN", 
"QRDQN", "IQN", "FQF", diff --git a/tianshou/policy/modelfree/rainbow.py b/tianshou/policy/modelfree/rainbow.py index fc0af2cf4..6232045dc 100644 --- a/tianshou/policy/modelfree/rainbow.py +++ b/tianshou/policy/modelfree/rainbow.py @@ -9,24 +9,6 @@ from tianshou.utils.net.discrete import NoisyLinear -# TODO: this is a hacky thing interviewing side-effects and a return. Should improve. -def _sample_noise(model: nn.Module) -> bool: - """Sample the random noises of NoisyLinear modules in the model. - - Returns True if at least one NoisyLinear submodule was found. - - :param model: a PyTorch module which may have NoisyLinear submodules. - :returns: True if model has at least one NoisyLinear submodule; - otherwise, False. - """ - sampled_any_noise = False - for m in model.modules(): - if isinstance(m, NoisyLinear): - m.sample() - sampled_any_noise = True - return sampled_any_noise - - @dataclass(kw_only=True) class RainbowTrainingStats(C51TrainingStats): loss: float @@ -35,25 +17,39 @@ class RainbowTrainingStats(C51TrainingStats): TRainbowTrainingStats = TypeVar("TRainbowTrainingStats", bound=RainbowTrainingStats) -# TODO: is this class worth keeping? It barely does anything -class RainbowPolicy(C51[TRainbowTrainingStats]): +class RainbowDQN(C51[TRainbowTrainingStats]): """Implementation of Rainbow DQN. arXiv:1710.02298. - Same parameters as :class:`~tianshou.policy.C51Policy`. - .. seealso:: Please refer to :class:`~tianshou.policy.C51Policy` for more detailed explanation. """ + @staticmethod + def _sample_noise(model: nn.Module) -> bool: + """Sample the random noises of NoisyLinear modules in the model. + + Returns True if at least one NoisyLinear submodule was found. + + :param model: a PyTorch module which may have NoisyLinear submodules. + :returns: True if model has at least one NoisyLinear submodule; + otherwise, False. 
+ """ + sampled_any_noise = False + for m in model.modules(): + if isinstance(m, NoisyLinear): + m.sample() + sampled_any_noise = True + return sampled_any_noise + def _update_with_batch( self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any, ) -> TRainbowTrainingStats: - _sample_noise(self.model) - if self._target and _sample_noise(self.model_old): + self._sample_noise(self.policy.model) + if self._target and self._sample_noise(self.model_old): self.model_old.train() # so that NoisyLinear takes effect return super()._update_with_batch(batch, **kwargs) From 2d61f78d6d5e7c56f4f5e70497321d827f4a91b4 Mon Sep 17 00:00:00 2001 From: Dominik Jain Date: Fri, 7 Mar 2025 22:12:11 +0100 Subject: [PATCH 034/230] v2: Adapt ImitationLearning, test_a2c_with_il and test_sac_with_il --- examples/offline/atari_il.py | 6 +- examples/offline/d4rl_il.py | 4 +- test/continuous/test_sac_with_il.py | 44 ++++++++------- test/discrete/test_a2c_with_il.py | 47 ++++++++-------- tianshou/policy/__init__.py | 4 +- tianshou/policy/imitation/base.py | 85 ++++++++++++++++++----------- 6 files changed, 111 insertions(+), 79 deletions(-) diff --git a/examples/offline/atari_il.py b/examples/offline/atari_il.py index ffc7c5457..d69a55c89 100644 --- a/examples/offline/atari_il.py +++ b/examples/offline/atari_il.py @@ -15,7 +15,7 @@ from examples.offline.utils import load_buffer from tianshou.data import Collector, CollectStats, VectorReplayBuffer from tianshou.highlevel.logger import LoggerFactoryDefault -from tianshou.policy import ImitationPolicy +from tianshou.policy import ImitationLearning from tianshou.policy.base import Algorithm from tianshou.trainer import OfflineTrainer from tianshou.utils.space_info import SpaceInfo @@ -90,7 +90,9 @@ def test_il(args: argparse.Namespace = get_args()) -> None: net = DQNet(c, h, w, args.action_shape, device=args.device).to(args.device) optim = torch.optim.Adam(net.parameters(), lr=args.lr) # define policy - policy: ImitationPolicy = 
ImitationPolicy(actor=net, optim=optim, action_space=env.action_space) + policy: ImitationLearning = ImitationLearning( + actor=net, optim=optim, action_space=env.action_space + ) # load a previous policy if args.resume_path: policy.load_state_dict(torch.load(args.resume_path, map_location=args.device)) diff --git a/examples/offline/d4rl_il.py b/examples/offline/d4rl_il.py index 91c998cad..59b5d996a 100644 --- a/examples/offline/d4rl_il.py +++ b/examples/offline/d4rl_il.py @@ -13,7 +13,7 @@ from examples.offline.utils import load_buffer_d4rl from tianshou.data import Collector, CollectStats from tianshou.env import SubprocVectorEnv -from tianshou.policy import ImitationPolicy +from tianshou.policy import ImitationLearning from tianshou.policy.base import Algorithm from tianshou.trainer import OfflineTrainer from tianshou.utils import TensorboardLogger, WandbLogger @@ -96,7 +96,7 @@ def test_il() -> None: ).to(args.device) optim = torch.optim.Adam(actor.parameters(), lr=args.lr) - policy: ImitationPolicy = ImitationPolicy( + policy: ImitationLearning = ImitationLearning( actor=actor, optim=optim, action_space=env.action_space, diff --git a/test/continuous/test_sac_with_il.py b/test/continuous/test_sac_with_il.py index 2c80429bf..3ad6f0ad3 100644 --- a/test/continuous/test_sac_with_il.py +++ b/test/continuous/test_sac_with_il.py @@ -8,10 +8,10 @@ from tianshou.data import Collector, CollectStats, VectorReplayBuffer from tianshou.env import DummyVectorEnv -from tianshou.policy import SAC, ImitationPolicy +from tianshou.policy import SAC, ImitationLearning from tianshou.policy.base import Algorithm -from tianshou.policy.modelfree.sac import SACPolicy -from tianshou.trainer import OffpolicyTrainer +from tianshou.policy.imitation.base import ImitationPolicy +from tianshou.policy.modelfree.sac import AutoAlpha, SACPolicy from tianshou.trainer.base import OffPolicyTrainingConfig from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Net @@ 
-110,7 +110,7 @@ def test_sac_with_il(args: argparse.Namespace = get_args()) -> None: target_entropy = -action_dim log_alpha = torch.zeros(1, requires_grad=True, device=args.device) alpha_optim = torch.optim.Adam([log_alpha], lr=args.alpha_lr) - args.alpha = (target_entropy, log_alpha, alpha_optim) + args.alpha = AutoAlpha(target_entropy, log_alpha, alpha_optim) policy = SACPolicy( actor=actor, @@ -181,32 +181,36 @@ def stop_fn(mean_rewards: float) -> bool: device=args.device, ).to(args.device) optim = torch.optim.Adam(il_actor.parameters(), lr=args.il_lr) - il_policy: ImitationPolicy = ImitationPolicy( + il_policy = ImitationPolicy( actor=il_actor, - optim=optim, action_space=env.action_space, action_scaling=True, action_bound_method="clip", ) + il_algorithm: ImitationLearning = ImitationLearning( + policy=il_policy, + optim=optim, + ) il_test_env = gym.make(args.task) il_test_env.reset(seed=args.seed + args.training_num + args.test_num) il_test_collector = Collector[CollectStats]( - il_policy, + il_algorithm, # envpool.make_gymnasium(args.task, num_envs=args.test_num, seed=args.seed), il_test_env, ) train_collector.reset() - result = OffpolicyTrainer( - policy=il_policy, - train_collector=train_collector, - test_collector=il_test_collector, - max_epoch=args.epoch, - step_per_epoch=args.il_step_per_epoch, - step_per_collect=args.step_per_collect, - episode_per_test=args.test_num, - batch_size=args.batch_size, - stop_fn=stop_fn, - save_best_fn=save_best_fn, - logger=logger, - ).run() + result = il_algorithm.run_training( + OffPolicyTrainingConfig( + train_collector=train_collector, + test_collector=il_test_collector, + max_epoch=args.epoch, + step_per_epoch=args.il_step_per_epoch, + step_per_collect=args.step_per_collect, + episode_per_test=args.test_num, + batch_size=args.batch_size, + stop_fn=stop_fn, + save_best_fn=save_best_fn, + logger=logger, + ) + ) assert stop_fn(result.best_reward) diff --git a/test/discrete/test_a2c_with_il.py 
b/test/discrete/test_a2c_with_il.py index eab6407df..f2070ce8e 100644 --- a/test/discrete/test_a2c_with_il.py +++ b/test/discrete/test_a2c_with_il.py @@ -8,12 +8,12 @@ from torch.utils.tensorboard import SummaryWriter from tianshou.data import Collector, CollectStats, VectorReplayBuffer -from tianshou.env import DummyVectorEnv, SubprocVectorEnv -from tianshou.policy import A2C, ImitationPolicy +from tianshou.env import DummyVectorEnv +from tianshou.policy import A2C, ImitationLearning from tianshou.policy.base import Algorithm +from tianshou.policy.imitation.base import ImitationPolicy from tianshou.policy.modelfree.pg import ActorPolicy -from tianshou.trainer import OffpolicyTrainer -from tianshou.trainer.base import OnPolicyTrainingConfig +from tianshou.trainer.base import OffPolicyTrainingConfig, OnPolicyTrainingConfig from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import ActorCritic, Net from tianshou.utils.net.discrete import Actor, Critic @@ -158,11 +158,14 @@ def stop_fn(mean_rewards: float) -> bool: net = Net(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device) actor = Actor(net, args.action_shape, device=args.device).to(args.device) optim = torch.optim.Adam(actor.parameters(), lr=args.il_lr) - il_policy: ImitationPolicy = ImitationPolicy( + il_policy = ImitationPolicy( actor=actor, - optim=optim, action_space=env.action_space, ) + il_algorithm: ImitationLearning = ImitationLearning( + policy=il_policy, + optim=optim, + ) if envpool is not None: il_env = envpool.make( args.task, @@ -171,28 +174,28 @@ def stop_fn(mean_rewards: float) -> bool: seed=args.seed, ) else: - il_env = SubprocVectorEnv( + il_env = DummyVectorEnv( [lambda: gym.make(args.task) for _ in range(args.test_num)], - context="fork", ) il_env.seed(args.seed) il_test_collector = Collector[CollectStats]( - il_policy, + il_algorithm, il_env, ) train_collector.reset() - result = OffpolicyTrainer( - policy=il_policy, - 
train_collector=train_collector, - test_collector=il_test_collector, - max_epoch=args.epoch, - step_per_epoch=args.il_step_per_epoch, - step_per_collect=args.step_per_collect, - episode_per_test=args.test_num, - batch_size=args.batch_size, - stop_fn=stop_fn, - save_best_fn=save_best_fn, - logger=logger, - ).run() + result = il_algorithm.run_training( + OffPolicyTrainingConfig( + train_collector=train_collector, + test_collector=il_test_collector, + max_epoch=args.epoch, + step_per_epoch=args.il_step_per_epoch, + step_per_collect=args.step_per_collect, + episode_per_test=args.test_num, + batch_size=args.batch_size, + stop_fn=stop_fn, + save_best_fn=save_best_fn, + logger=logger, + ) + ) assert stop_fn(result.best_reward) diff --git a/tianshou/policy/__init__.py b/tianshou/policy/__init__.py index 5a3930f94..86925062a 100644 --- a/tianshou/policy/__init__.py +++ b/tianshou/policy/__init__.py @@ -21,7 +21,7 @@ from tianshou.policy.modelfree.sac import SAC from tianshou.policy.modelfree.redq import REDQ from tianshou.policy.modelfree.discrete_sac import DiscreteSAC -from tianshou.policy.imitation.base import ImitationPolicy +from tianshou.policy.imitation.base import ImitationLearning from tianshou.policy.imitation.bcq import BCQPolicy from tianshou.policy.imitation.cql import CQLPolicy from tianshou.policy.imitation.td3_bc import TD3BCPolicy @@ -53,7 +53,7 @@ "SAC", "REDQ", "DiscreteSAC", - "ImitationPolicy", + "ImitationLearning", "BCQPolicy", "CQLPolicy", "TD3BCPolicy", diff --git a/tianshou/policy/imitation/base.py b/tianshou/policy/imitation/base.py index ecf66578e..eb6882440 100644 --- a/tianshou/policy/imitation/base.py +++ b/tianshou/policy/imitation/base.py @@ -13,8 +13,12 @@ ObsBatchProtocol, RolloutBatchProtocol, ) -from tianshou.policy import Algorithm -from tianshou.policy.base import TLearningRateScheduler, TrainingStats +from tianshou.policy.base import ( + OffPolicyAlgorithm, + Policy, + TLearningRateScheduler, + TrainingStats, +) # Dimension Naming 
Convention # B - Batch Size @@ -31,46 +35,34 @@ class ImitationTrainingStats(TrainingStats): TImitationTrainingStats = TypeVar("TImitationTrainingStats", bound=ImitationTrainingStats) -class ImitationPolicy(Algorithm, Generic[TImitationTrainingStats]): - """Implementation of vanilla imitation learning. - - :param actor: a model following the rules in - :class:`~tianshou.policy.BasePolicy`. (s -> a) - :param optim: for optimizing the model. - :param action_space: Env's action_space. - :param observation_space: Env's observation space. - :param action_scaling: if True, scale the action from [-1, 1] to the range - of action_space. Only used if the action_space is continuous. - :param action_bound_method: method to bound action to range [-1, 1]. - Only used if the action_space is continuous. - :param lr_scheduler: if not None, will be called in `policy.update()`. - - .. seealso:: - - Please refer to :class:`~tianshou.policy.BasePolicy` for more detailed - explanation. - """ - +class ImitationPolicy(Policy): def __init__( self, *, actor: torch.nn.Module, - optim: torch.optim.Optimizer, action_space: gym.Space, observation_space: gym.Space | None = None, action_scaling: bool = False, action_bound_method: Literal["clip", "tanh"] | None = "clip", - lr_scheduler: TLearningRateScheduler | None = None, - ) -> None: + ): + """ + :param actor: a model following the rules in + :class:`~tianshou.policy.BasePolicy`. (s -> a) + :param optim: for optimizing the model. + :param action_space: Env's action_space. + :param observation_space: Env's observation space. + :param action_scaling: if True, scale the action from [-1, 1] to the range + of action_space. Only used if the action_space is continuous. + :param action_bound_method: method to bound action to range [-1, 1]. + Only used if the action_space is continuous. 
+ """ super().__init__( action_space=action_space, observation_space=observation_space, action_scaling=action_scaling, action_bound_method=action_bound_method, - lr_scheduler=lr_scheduler, ) self.actor = actor - self.optim = optim def forward( self, @@ -94,6 +86,35 @@ def forward( raise RuntimeError(f"Unknown {self.action_type=}, this shouldn't have happened!") return cast(ModelOutputBatchProtocol, result) + +class ImitationLearning(OffPolicyAlgorithm, Generic[TImitationTrainingStats]): + """Implementation of vanilla imitation learning.""" + + def __init__( + self, + *, + policy: ImitationPolicy, + optim: torch.optim.Optimizer, + lr_scheduler: TLearningRateScheduler | None = None, + ) -> None: + """ + :param policy: the policy + :class:`~tianshou.policy.BasePolicy`. (s -> a) + :param optim: for optimizing the model. + :param action_space: Env's action_space. + :param observation_space: Env's observation space. + :param action_scaling: if True, scale the action from [-1, 1] to the range + of action_space. Only used if the action_space is continuous. + :param action_bound_method: method to bound action to range [-1, 1]. + Only used if the action_space is continuous. + :param lr_scheduler: if not None, will be called in `policy.update()`. 
+ """ + super().__init__( + policy=policy, + lr_scheduler=lr_scheduler, + ) + self.optim = optim + def _update_with_batch( self, batch: RolloutBatchProtocol, @@ -101,14 +122,16 @@ def _update_with_batch( **kwargs: Any, ) -> TImitationTrainingStats: self.optim.zero_grad() - if self.action_type == "continuous": # regression - act = self(batch).act + if self.policy.action_type == "continuous": # regression + act = self.policy(batch).act act_target = to_torch(batch.act, dtype=torch.float32, device=act.device) loss = F.mse_loss(act, act_target) - elif self.action_type == "discrete": # classification - act = F.log_softmax(self(batch).logits, dim=-1) + elif self.policy.action_type == "discrete": # classification + act = F.log_softmax(self.policy(batch).logits, dim=-1) act_target = to_torch(batch.act, dtype=torch.long, device=act.device) loss = F.nll_loss(act, act_target) + else: + raise ValueError(self.policy.action_type) loss.backward() self.optim.step() From d1f3962949af0401d4d9512cac4f7364adb77e84 Mon Sep 17 00:00:00 2001 From: Dominik Jain Date: Sun, 9 Mar 2025 20:24:48 +0100 Subject: [PATCH 035/230] v2: Rename Trainer classes, correcting capitalisation * OffpolicyTrainer -> OffPolicyTrainer * OnpolicyTrainer -> OnPolicyTrainer --- README.md | 2 +- examples/atari/atari_dqn.py | 4 ++-- examples/atari/atari_fqf.py | 4 ++-- examples/atari/atari_iqn.py | 4 ++-- examples/atari/atari_ppo.py | 4 ++-- examples/atari/atari_qrdqn.py | 4 ++-- examples/atari/atari_rainbow.py | 4 ++-- examples/atari/atari_sac.py | 4 ++-- examples/box2d/acrobot_dualdqn.py | 4 ++-- examples/box2d/bipedal_bdq.py | 4 ++-- examples/box2d/bipedal_hardcore_sac.py | 4 ++-- examples/box2d/mcc_sac.py | 4 ++-- examples/inverse/irl_gail.py | 4 ++-- examples/mujoco/fetch_her_ddpg.py | 4 ++-- examples/mujoco/mujoco_a2c.py | 4 ++-- examples/mujoco/mujoco_ddpg.py | 4 ++-- examples/mujoco/mujoco_npg.py | 4 ++-- examples/mujoco/mujoco_ppo.py | 4 ++-- examples/mujoco/mujoco_redq.py | 4 ++-- 
examples/mujoco/mujoco_reinforce.py | 4 ++-- examples/mujoco/mujoco_sac.py | 4 ++-- examples/mujoco/mujoco_td3.py | 4 ++-- examples/mujoco/mujoco_trpo.py | 4 ++-- examples/vizdoom/vizdoom_ppo.py | 4 ++-- test/discrete/test_drqn.py | 4 ++-- test/discrete/test_ppo.py | 4 ++-- test/modelbased/test_dqn_icm.py | 4 ++-- test/modelbased/test_ppo_icm.py | 4 ++-- test/modelbased/test_psrl.py | 4 ++-- test/offline/gather_cartpole_data.py | 4 ++-- test/offline/test_gail.py | 4 ++-- test/pettingzoo/pistonball.py | 4 ++-- test/pettingzoo/pistonball_continuous.py | 4 ++-- test/pettingzoo/tic_tac_toe.py | 4 ++-- tianshou/highlevel/algorithm.py | 6 +++--- tianshou/policy/base.py | 18 ++++++++++-------- tianshou/trainer/__init__.py | 8 ++++---- tianshou/trainer/base.py | 4 ++-- 38 files changed, 86 insertions(+), 84 deletions(-) diff --git a/README.md b/README.md index 6582eea5c..fb6a6a9de 100644 --- a/README.md +++ b/README.md @@ -385,7 +385,7 @@ test_collector = ts.data.Collector(policy, test_envs, exploration_noise=True) # Let's train it: ```python -result = ts.trainer.OffpolicyTrainer( +result = ts.trainer.OffPolicyTrainer( policy=policy, train_collector=train_collector, test_collector=test_collector, diff --git a/examples/atari/atari_dqn.py b/examples/atari/atari_dqn.py index 18d6b1184..7374d7a8d 100644 --- a/examples/atari/atari_dqn.py +++ b/examples/atari/atari_dqn.py @@ -14,7 +14,7 @@ from tianshou.policy import DeepQLearning from tianshou.policy.base import Algorithm from tianshou.policy.modelbased.icm import ICMPolicy -from tianshou.trainer import OffpolicyTrainer +from tianshou.trainer import OffPolicyTrainer from tianshou.utils.net.discrete import IntrinsicCuriosityModule @@ -236,7 +236,7 @@ def watch() -> None: train_collector.reset() train_collector.collect(n_step=args.batch_size * args.training_num) # trainer - result = OffpolicyTrainer( + result = OffPolicyTrainer( policy=policy, train_collector=train_collector, test_collector=test_collector, diff --git 
a/examples/atari/atari_fqf.py b/examples/atari/atari_fqf.py index ceac7bb2d..ff1a81ccf 100644 --- a/examples/atari/atari_fqf.py +++ b/examples/atari/atari_fqf.py @@ -13,7 +13,7 @@ from tianshou.highlevel.logger import LoggerFactoryDefault from tianshou.policy import FQF from tianshou.policy.base import Algorithm -from tianshou.trainer import OffpolicyTrainer +from tianshou.trainer import OffPolicyTrainer from tianshou.utils.net.discrete import FractionProposalNetwork, FullQuantileFunction @@ -205,7 +205,7 @@ def watch() -> None: train_collector.reset() train_collector.collect(n_step=args.batch_size * args.training_num) # trainer - result = OffpolicyTrainer( + result = OffPolicyTrainer( policy=policy, train_collector=train_collector, test_collector=test_collector, diff --git a/examples/atari/atari_iqn.py b/examples/atari/atari_iqn.py index 587295e6d..5c9b4a386 100644 --- a/examples/atari/atari_iqn.py +++ b/examples/atari/atari_iqn.py @@ -13,7 +13,7 @@ from tianshou.highlevel.logger import LoggerFactoryDefault from tianshou.policy import IQN from tianshou.policy.base import Algorithm -from tianshou.trainer import OffpolicyTrainer +from tianshou.trainer import OffPolicyTrainer from tianshou.utils.net.discrete import ImplicitQuantileNetwork @@ -203,7 +203,7 @@ def watch() -> None: train_collector.collect(n_step=args.batch_size * args.training_num) # trainer - result = OffpolicyTrainer( + result = OffPolicyTrainer( policy=policy, train_collector=train_collector, test_collector=test_collector, diff --git a/examples/atari/atari_ppo.py b/examples/atari/atari_ppo.py index 0b441add7..6a5d19f23 100644 --- a/examples/atari/atari_ppo.py +++ b/examples/atari/atari_ppo.py @@ -15,7 +15,7 @@ from tianshou.highlevel.logger import LoggerFactoryDefault from tianshou.policy import PPO, ICMPolicy from tianshou.policy.base import Algorithm -from tianshou.trainer import OnpolicyTrainer +from tianshou.trainer import OnPolicyTrainer from tianshou.utils.net.common import ActorCritic from 
tianshou.utils.net.discrete import Actor, Critic, IntrinsicCuriosityModule @@ -262,7 +262,7 @@ def watch() -> None: train_collector.reset() train_collector.collect(n_step=args.batch_size * args.training_num) # trainer - result = OnpolicyTrainer( + result = OnPolicyTrainer( policy=policy, train_collector=train_collector, test_collector=test_collector, diff --git a/examples/atari/atari_qrdqn.py b/examples/atari/atari_qrdqn.py index f4496dbc8..0cdd9a8a9 100644 --- a/examples/atari/atari_qrdqn.py +++ b/examples/atari/atari_qrdqn.py @@ -13,7 +13,7 @@ from tianshou.highlevel.logger import LoggerFactoryDefault from tianshou.policy import QRDQN from tianshou.policy.base import Algorithm -from tianshou.trainer import OffpolicyTrainer +from tianshou.trainer import OffPolicyTrainer def get_args() -> argparse.Namespace: @@ -196,7 +196,7 @@ def watch() -> None: train_collector.reset() train_collector.collect(n_step=args.batch_size * args.training_num) # trainer - result = OffpolicyTrainer( + result = OffPolicyTrainer( policy=policy, train_collector=train_collector, test_collector=test_collector, diff --git a/examples/atari/atari_rainbow.py b/examples/atari/atari_rainbow.py index 294fdb7e6..9fb0ec7b6 100644 --- a/examples/atari/atari_rainbow.py +++ b/examples/atari/atari_rainbow.py @@ -17,7 +17,7 @@ from tianshou.highlevel.logger import LoggerFactoryDefault from tianshou.policy import C51, RainbowDQN from tianshou.policy.base import Algorithm -from tianshou.trainer import OffpolicyTrainer +from tianshou.trainer import OffPolicyTrainer def get_args() -> argparse.Namespace: @@ -236,7 +236,7 @@ def watch() -> None: train_collector.reset() train_collector.collect(n_step=args.batch_size * args.training_num) # trainer - result = OffpolicyTrainer( + result = OffPolicyTrainer( policy=policy, train_collector=train_collector, test_collector=test_collector, diff --git a/examples/atari/atari_sac.py b/examples/atari/atari_sac.py index 08c2c6390..651d9bcb1 100644 --- 
a/examples/atari/atari_sac.py +++ b/examples/atari/atari_sac.py @@ -13,7 +13,7 @@ from tianshou.highlevel.logger import LoggerFactoryDefault from tianshou.policy import DiscreteSAC, ICMPolicy from tianshou.policy.base import Algorithm -from tianshou.trainer import OffpolicyTrainer +from tianshou.trainer import OffPolicyTrainer from tianshou.utils.net.discrete import Actor, Critic, IntrinsicCuriosityModule @@ -245,7 +245,7 @@ def watch() -> None: train_collector.reset() train_collector.collect(n_step=args.batch_size * args.training_num) # trainer - result = OffpolicyTrainer( + result = OffPolicyTrainer( policy=policy, train_collector=train_collector, test_collector=test_collector, diff --git a/examples/box2d/acrobot_dualdqn.py b/examples/box2d/acrobot_dualdqn.py index 47bee824b..52e290cc1 100644 --- a/examples/box2d/acrobot_dualdqn.py +++ b/examples/box2d/acrobot_dualdqn.py @@ -11,7 +11,7 @@ from tianshou.env import DummyVectorEnv from tianshou.policy import DeepQLearning from tianshou.policy.base import Algorithm -from tianshou.trainer import OffpolicyTrainer +from tianshou.trainer import OffPolicyTrainer from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Net from tianshou.utils.space_info import SpaceInfo @@ -123,7 +123,7 @@ def test_fn(epoch: int, env_step: int | None) -> None: policy.set_eps(args.eps_test) # trainer - result = OffpolicyTrainer( + result = OffPolicyTrainer( policy=policy, train_collector=train_collector, test_collector=test_collector, diff --git a/examples/box2d/bipedal_bdq.py b/examples/box2d/bipedal_bdq.py index 289bab4b8..7b35b16c8 100644 --- a/examples/box2d/bipedal_bdq.py +++ b/examples/box2d/bipedal_bdq.py @@ -12,7 +12,7 @@ from tianshou.env import ContinuousToDiscrete, SubprocVectorEnv from tianshou.policy import BranchingDuelingQNetwork from tianshou.policy.base import Algorithm -from tianshou.trainer import OffpolicyTrainer +from tianshou.trainer import OffPolicyTrainer from tianshou.utils import 
TensorboardLogger from tianshou.utils.net.common import BranchingNet @@ -141,7 +141,7 @@ def test_fn(epoch: int, env_step: int | None) -> None: policy.set_eps(args.eps_test) # trainer - result = OffpolicyTrainer( + result = OffPolicyTrainer( policy=policy, train_collector=train_collector, test_collector=test_collector, diff --git a/examples/box2d/bipedal_hardcore_sac.py b/examples/box2d/bipedal_hardcore_sac.py index 0961347c5..d1a3d0475 100644 --- a/examples/box2d/bipedal_hardcore_sac.py +++ b/examples/box2d/bipedal_hardcore_sac.py @@ -13,7 +13,7 @@ from tianshou.env import SubprocVectorEnv from tianshou.policy import SAC from tianshou.policy.base import Algorithm -from tianshou.trainer import OffpolicyTrainer +from tianshou.trainer import OffPolicyTrainer from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Net from tianshou.utils.net.continuous import ActorProb, Critic @@ -188,7 +188,7 @@ def stop_fn(mean_rewards: float) -> bool: return False # trainer - result = OffpolicyTrainer( + result = OffPolicyTrainer( policy=policy, train_collector=train_collector, test_collector=test_collector, diff --git a/examples/box2d/mcc_sac.py b/examples/box2d/mcc_sac.py index 45246dc60..21336ed3e 100644 --- a/examples/box2d/mcc_sac.py +++ b/examples/box2d/mcc_sac.py @@ -12,7 +12,7 @@ from tianshou.exploration import OUNoise from tianshou.policy import SAC from tianshou.policy.base import Algorithm -from tianshou.trainer import OffpolicyTrainer +from tianshou.trainer import OffPolicyTrainer from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Net from tianshou.utils.net.continuous import ActorProb, Critic @@ -134,7 +134,7 @@ def stop_fn(mean_rewards: float) -> bool: return False # trainer - result = OffpolicyTrainer( + result = OffPolicyTrainer( policy=policy, train_collector=train_collector, test_collector=test_collector, diff --git a/examples/inverse/irl_gail.py b/examples/inverse/irl_gail.py index 10c10de1d..752624504 
100644 --- a/examples/inverse/irl_gail.py +++ b/examples/inverse/irl_gail.py @@ -26,7 +26,7 @@ from tianshou.env import SubprocVectorEnv, VectorEnvNormObs from tianshou.policy import GAILPolicy from tianshou.policy.base import Algorithm -from tianshou.trainer import OnpolicyTrainer +from tianshou.trainer import OnPolicyTrainer from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import ActorCritic, Net from tianshou.utils.net.continuous import ActorProb, Critic @@ -257,7 +257,7 @@ def save_best_fn(policy: Algorithm) -> None: if not args.watch: # trainer - result = OnpolicyTrainer( + result = OnPolicyTrainer( policy=policy, train_collector=train_collector, test_collector=test_collector, diff --git a/examples/mujoco/fetch_her_ddpg.py b/examples/mujoco/fetch_her_ddpg.py index 888686300..3f2c24e32 100644 --- a/examples/mujoco/fetch_her_ddpg.py +++ b/examples/mujoco/fetch_her_ddpg.py @@ -24,7 +24,7 @@ from tianshou.exploration import GaussianNoise from tianshou.policy import DDPG from tianshou.policy.base import Algorithm -from tianshou.trainer import OffpolicyTrainer +from tianshou.trainer import OffPolicyTrainer from tianshou.utils.net.common import Net, get_dict_state_decorator from tianshou.utils.net.continuous import Actor, Critic from tianshou.env.venvs import BaseVectorEnv @@ -222,7 +222,7 @@ def save_best_fn(policy: Algorithm) -> None: if not args.watch: # trainer - result = OffpolicyTrainer( + result = OffPolicyTrainer( policy=policy, train_collector=train_collector, test_collector=test_collector, diff --git a/examples/mujoco/mujoco_a2c.py b/examples/mujoco/mujoco_a2c.py index 4069858c4..e5a18e1ac 100755 --- a/examples/mujoco/mujoco_a2c.py +++ b/examples/mujoco/mujoco_a2c.py @@ -16,7 +16,7 @@ from tianshou.highlevel.logger import LoggerFactoryDefault from tianshou.policy import A2C from tianshou.policy.base import Algorithm -from tianshou.trainer import OnpolicyTrainer +from tianshou.trainer import OnPolicyTrainer from 
tianshou.utils.net.common import ActorCritic, Net from tianshou.utils.net.continuous import ActorProb, Critic @@ -202,7 +202,7 @@ def save_best_fn(policy: Algorithm) -> None: if not args.watch: # trainer - result = OnpolicyTrainer( + result = OnPolicyTrainer( policy=policy, train_collector=train_collector, test_collector=test_collector, diff --git a/examples/mujoco/mujoco_ddpg.py b/examples/mujoco/mujoco_ddpg.py index d63390be4..869267d91 100755 --- a/examples/mujoco/mujoco_ddpg.py +++ b/examples/mujoco/mujoco_ddpg.py @@ -14,7 +14,7 @@ from tianshou.highlevel.logger import LoggerFactoryDefault from tianshou.policy import DDPG from tianshou.policy.base import Algorithm -from tianshou.trainer import OffpolicyTrainer +from tianshou.trainer import OffPolicyTrainer from tianshou.utils.net.common import Net from tianshou.utils.net.continuous import Actor, Critic @@ -151,7 +151,7 @@ def save_best_fn(policy: Algorithm) -> None: if not args.watch: # trainer - result = OffpolicyTrainer( + result = OffPolicyTrainer( policy=policy, train_collector=train_collector, test_collector=test_collector, diff --git a/examples/mujoco/mujoco_npg.py b/examples/mujoco/mujoco_npg.py index db5af6818..bcb55ba59 100755 --- a/examples/mujoco/mujoco_npg.py +++ b/examples/mujoco/mujoco_npg.py @@ -16,7 +16,7 @@ from tianshou.highlevel.logger import LoggerFactoryDefault from tianshou.policy import NPG from tianshou.policy.base import Algorithm -from tianshou.trainer import OnpolicyTrainer +from tianshou.trainer import OnPolicyTrainer from tianshou.utils.net.common import Net from tianshou.utils.net.continuous import ActorProb, Critic @@ -199,7 +199,7 @@ def save_best_fn(policy: Algorithm) -> None: if not args.watch: # trainer - result = OnpolicyTrainer( + result = OnPolicyTrainer( policy=policy, train_collector=train_collector, test_collector=test_collector, diff --git a/examples/mujoco/mujoco_ppo.py b/examples/mujoco/mujoco_ppo.py index da18ea3ab..08a802470 100755 --- 
a/examples/mujoco/mujoco_ppo.py +++ b/examples/mujoco/mujoco_ppo.py @@ -16,7 +16,7 @@ from tianshou.highlevel.logger import LoggerFactoryDefault from tianshou.policy import PPO from tianshou.policy.base import Algorithm -from tianshou.trainer import OnpolicyTrainer +from tianshou.trainer import OnPolicyTrainer from tianshou.utils.net.common import ActorCritic, Net from tianshou.utils.net.continuous import ActorProb, Critic @@ -207,7 +207,7 @@ def save_best_fn(policy: Algorithm) -> None: if not args.watch: # trainer - result = OnpolicyTrainer( + result = OnPolicyTrainer( policy=policy, train_collector=train_collector, test_collector=test_collector, diff --git a/examples/mujoco/mujoco_redq.py b/examples/mujoco/mujoco_redq.py index 48625ca2e..2391a81a5 100755 --- a/examples/mujoco/mujoco_redq.py +++ b/examples/mujoco/mujoco_redq.py @@ -13,7 +13,7 @@ from tianshou.highlevel.logger import LoggerFactoryDefault from tianshou.policy import REDQ from tianshou.policy.base import Algorithm -from tianshou.trainer import OffpolicyTrainer +from tianshou.trainer import OffPolicyTrainer from tianshou.utils.net.common import EnsembleLinear, Net from tianshou.utils.net.continuous import ActorProb, Critic @@ -179,7 +179,7 @@ def save_best_fn(policy: Algorithm) -> None: if not args.watch: # trainer - result = OffpolicyTrainer( + result = OffPolicyTrainer( policy=policy, train_collector=train_collector, test_collector=test_collector, diff --git a/examples/mujoco/mujoco_reinforce.py b/examples/mujoco/mujoco_reinforce.py index e49124e92..c1b6f8e47 100755 --- a/examples/mujoco/mujoco_reinforce.py +++ b/examples/mujoco/mujoco_reinforce.py @@ -16,7 +16,7 @@ from tianshou.highlevel.logger import LoggerFactoryDefault from tianshou.policy import Reinforce from tianshou.policy.base import Algorithm -from tianshou.trainer import OnpolicyTrainer +from tianshou.trainer import OnPolicyTrainer from tianshou.utils.net.common import Net from tianshou.utils.net.continuous import ActorProb @@ -179,7 
+179,7 @@ def save_best_fn(policy: Algorithm) -> None: if not args.watch: # trainer - result = OnpolicyTrainer( + result = OnPolicyTrainer( policy=policy, train_collector=train_collector, test_collector=test_collector, diff --git a/examples/mujoco/mujoco_sac.py b/examples/mujoco/mujoco_sac.py index 1ca6fb396..baae832e9 100755 --- a/examples/mujoco/mujoco_sac.py +++ b/examples/mujoco/mujoco_sac.py @@ -13,7 +13,7 @@ from tianshou.highlevel.logger import LoggerFactoryDefault from tianshou.policy import SAC from tianshou.policy.base import Algorithm -from tianshou.trainer import OffpolicyTrainer +from tianshou.trainer import OffPolicyTrainer from tianshou.utils.net.common import Net from tianshou.utils.net.continuous import ActorProb, Critic @@ -173,7 +173,7 @@ def save_best_fn(policy: Algorithm) -> None: if not args.watch: # trainer - result = OffpolicyTrainer( + result = OffPolicyTrainer( policy=policy, train_collector=train_collector, test_collector=test_collector, diff --git a/examples/mujoco/mujoco_td3.py b/examples/mujoco/mujoco_td3.py index 46c8fe509..8bc814f93 100755 --- a/examples/mujoco/mujoco_td3.py +++ b/examples/mujoco/mujoco_td3.py @@ -14,7 +14,7 @@ from tianshou.highlevel.logger import LoggerFactoryDefault from tianshou.policy import TD3 from tianshou.policy.base import Algorithm -from tianshou.trainer import OffpolicyTrainer +from tianshou.trainer import OffPolicyTrainer from tianshou.utils.net.common import Net from tianshou.utils.net.continuous import Actor, Critic @@ -171,7 +171,7 @@ def save_best_fn(policy: Algorithm) -> None: if not args.watch: # trainer - result = OffpolicyTrainer( + result = OffPolicyTrainer( policy=policy, train_collector=train_collector, test_collector=test_collector, diff --git a/examples/mujoco/mujoco_trpo.py b/examples/mujoco/mujoco_trpo.py index 350439a70..141924118 100755 --- a/examples/mujoco/mujoco_trpo.py +++ b/examples/mujoco/mujoco_trpo.py @@ -16,7 +16,7 @@ from tianshou.highlevel.logger import LoggerFactoryDefault 
from tianshou.policy import TRPO from tianshou.policy.base import Algorithm -from tianshou.trainer import OnpolicyTrainer +from tianshou.trainer import OnPolicyTrainer from tianshou.utils.net.common import Net from tianshou.utils.net.continuous import ActorProb, Critic @@ -204,7 +204,7 @@ def save_best_fn(policy: Algorithm) -> None: if not args.watch: # trainer - result = OnpolicyTrainer( + result = OnPolicyTrainer( policy=policy, train_collector=train_collector, test_collector=test_collector, diff --git a/examples/vizdoom/vizdoom_ppo.py b/examples/vizdoom/vizdoom_ppo.py index 7ba1bf503..2c32562a1 100644 --- a/examples/vizdoom/vizdoom_ppo.py +++ b/examples/vizdoom/vizdoom_ppo.py @@ -15,7 +15,7 @@ from tianshou.highlevel.logger import LoggerFactoryDefault from tianshou.policy import PPO, ICMPolicy from tianshou.policy.base import Algorithm -from tianshou.trainer import OnpolicyTrainer +from tianshou.trainer import OnPolicyTrainer from tianshou.utils.net.common import ActorCritic from tianshou.utils.net.discrete import Actor, Critic, IntrinsicCuriosityModule @@ -264,7 +264,7 @@ def watch() -> None: train_collector.reset() train_collector.collect(n_step=args.batch_size * args.training_num) # trainer - result = OnpolicyTrainer( + result = OnPolicyTrainer( policy=policy, train_collector=train_collector, test_collector=test_collector, diff --git a/test/discrete/test_drqn.py b/test/discrete/test_drqn.py index b3e8bb381..5cbf30325 100644 --- a/test/discrete/test_drqn.py +++ b/test/discrete/test_drqn.py @@ -10,7 +10,7 @@ from tianshou.env import DummyVectorEnv from tianshou.policy import DeepQLearning from tianshou.policy.base import Algorithm -from tianshou.trainer import OffpolicyTrainer +from tianshou.trainer import OffPolicyTrainer from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Recurrent from tianshou.utils.space_info import SpaceInfo @@ -113,7 +113,7 @@ def test_fn(epoch: int, env_step: int | None) -> None: 
policy.set_eps(args.eps_test) # trainer - result = OffpolicyTrainer( + result = OffPolicyTrainer( policy=policy, train_collector=train_collector, test_collector=test_collector, diff --git a/test/discrete/test_ppo.py b/test/discrete/test_ppo.py index a7c80caa7..491a7c3c8 100644 --- a/test/discrete/test_ppo.py +++ b/test/discrete/test_ppo.py @@ -13,7 +13,7 @@ from tianshou.policy import PPO from tianshou.policy.base import Algorithm from tianshou.policy.modelfree.ppo import PPOTrainingStats -from tianshou.trainer import OnpolicyTrainer +from tianshou.trainer import OnPolicyTrainer from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import ActorCritic, DataParallelNet, Net from tianshou.utils.net.discrete import Actor, Critic @@ -135,7 +135,7 @@ def stop_fn(mean_rewards: float) -> bool: return mean_rewards >= args.reward_threshold # trainer - result = OnpolicyTrainer( + result = OnPolicyTrainer( policy=policy, train_collector=train_collector, test_collector=test_collector, diff --git a/test/modelbased/test_dqn_icm.py b/test/modelbased/test_dqn_icm.py index 13189cadb..f7ee3cac2 100644 --- a/test/modelbased/test_dqn_icm.py +++ b/test/modelbased/test_dqn_icm.py @@ -16,7 +16,7 @@ from tianshou.policy import DeepQLearning, ICMPolicy from tianshou.policy.base import Algorithm from tianshou.policy.modelfree.dqn import DQNTrainingStats -from tianshou.trainer import OffpolicyTrainer +from tianshou.trainer import OffPolicyTrainer from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import MLP, Net from tianshou.utils.net.discrete import IntrinsicCuriosityModule @@ -184,7 +184,7 @@ def test_fn(epoch: int, env_step: int | None) -> None: policy.set_eps(args.eps_test) # trainer - result = OffpolicyTrainer( + result = OffPolicyTrainer( policy=policy, train_collector=train_collector, test_collector=test_collector, diff --git a/test/modelbased/test_ppo_icm.py b/test/modelbased/test_ppo_icm.py index 544107e51..ee1c46660 100644 --- 
a/test/modelbased/test_ppo_icm.py +++ b/test/modelbased/test_ppo_icm.py @@ -12,7 +12,7 @@ from tianshou.policy import PPO, ICMPolicy from tianshou.policy.base import Algorithm from tianshou.policy.modelfree.ppo import PPOTrainingStats -from tianshou.trainer import OnpolicyTrainer +from tianshou.trainer import OnPolicyTrainer from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import MLP, ActorCritic, Net from tianshou.utils.net.discrete import Actor, Critic, IntrinsicCuriosityModule @@ -173,7 +173,7 @@ def stop_fn(mean_rewards: float) -> bool: return mean_rewards >= args.reward_threshold # trainer - result = OnpolicyTrainer( + result = OnPolicyTrainer( policy=policy, train_collector=train_collector, test_collector=test_collector, diff --git a/test/modelbased/test_psrl.py b/test/modelbased/test_psrl.py index e79977381..d94f1cd1f 100644 --- a/test/modelbased/test_psrl.py +++ b/test/modelbased/test_psrl.py @@ -8,7 +8,7 @@ from tianshou.data import Collector, CollectStats, VectorReplayBuffer from tianshou.policy import PSRLPolicy -from tianshou.trainer import OnpolicyTrainer +from tianshou.trainer import OnPolicyTrainer from tianshou.utils import LazyLogger, TensorboardLogger, WandbLogger try: @@ -104,7 +104,7 @@ def stop_fn(mean_rewards: float) -> bool: train_collector.collect(n_step=args.buffer_size, random=True) # trainer, test it without logger - result = OnpolicyTrainer( + result = OnPolicyTrainer( policy=policy, train_collector=train_collector, test_collector=test_collector, diff --git a/test/offline/gather_cartpole_data.py b/test/offline/gather_cartpole_data.py index 8c5252f72..11d0f32e7 100644 --- a/test/offline/gather_cartpole_data.py +++ b/test/offline/gather_cartpole_data.py @@ -17,7 +17,7 @@ from tianshou.policy import QRDQN from tianshou.policy.base import Algorithm from tianshou.policy.modelfree.qrdqn import QRDQNTrainingStats -from tianshou.trainer import OffpolicyTrainer +from tianshou.trainer import OffPolicyTrainer from 
tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Net from tianshou.utils.space_info import SpaceInfo @@ -149,7 +149,7 @@ def test_fn(epoch: int, env_step: int | None) -> None: policy.set_eps(args.eps_test) # trainer - result = OffpolicyTrainer( + result = OffPolicyTrainer( policy=policy, train_collector=train_collector, test_collector=test_collector, diff --git a/test/offline/test_gail.py b/test/offline/test_gail.py index 099b96d0a..298cc91f4 100644 --- a/test/offline/test_gail.py +++ b/test/offline/test_gail.py @@ -12,7 +12,7 @@ from tianshou.data import Collector, CollectStats, VectorReplayBuffer from tianshou.env import DummyVectorEnv from tianshou.policy import Algorithm, GAILPolicy -from tianshou.trainer import OnpolicyTrainer +from tianshou.trainer import OnPolicyTrainer from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import ActorCritic, Net from tianshou.utils.net.continuous import ActorProb, Critic @@ -204,7 +204,7 @@ def save_checkpoint_fn(epoch: int, env_step: int, gradient_step: int) -> str: print("Fail to restore policy and optim.") # trainer - result = OnpolicyTrainer( + result = OnPolicyTrainer( policy=policy, train_collector=train_collector, test_collector=test_collector, diff --git a/test/pettingzoo/pistonball.py b/test/pettingzoo/pistonball.py index 05ef337c8..d17db94f7 100644 --- a/test/pettingzoo/pistonball.py +++ b/test/pettingzoo/pistonball.py @@ -12,7 +12,7 @@ from tianshou.env import DummyVectorEnv from tianshou.env.pettingzoo_env import PettingZooEnv from tianshou.policy import Algorithm, DeepQLearning, MultiAgentPolicyManager -from tianshou.trainer import OffpolicyTrainer +from tianshou.trainer import OffPolicyTrainer from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Net @@ -159,7 +159,7 @@ def reward_metric(rews: np.ndarray) -> np.ndarray: return rews[:, 0] # trainer - result = OffpolicyTrainer( + result = OffPolicyTrainer( policy=policy, 
train_collector=train_collector, test_collector=test_collector, diff --git a/test/pettingzoo/pistonball_continuous.py b/test/pettingzoo/pistonball_continuous.py index ef61508fa..4c476512d 100644 --- a/test/pettingzoo/pistonball_continuous.py +++ b/test/pettingzoo/pistonball_continuous.py @@ -16,7 +16,7 @@ from tianshou.env import DummyVectorEnv from tianshou.env.pettingzoo_env import PettingZooEnv from tianshou.policy import PPO, Algorithm, MultiAgentPolicyManager -from tianshou.trainer import OnpolicyTrainer +from tianshou.trainer import OnPolicyTrainer from tianshou.utils import TensorboardLogger from tianshou.utils.net.continuous import ActorProb, Critic @@ -258,7 +258,7 @@ def reward_metric(rews: np.ndarray) -> np.ndarray: return rews[:, 0] # trainer - result = OnpolicyTrainer( + result = OnPolicyTrainer( policy=policy, train_collector=train_collector, test_collector=test_collector, diff --git a/test/pettingzoo/tic_tac_toe.py b/test/pettingzoo/tic_tac_toe.py index af17c24b7..9ad50f9c2 100644 --- a/test/pettingzoo/tic_tac_toe.py +++ b/test/pettingzoo/tic_tac_toe.py @@ -19,7 +19,7 @@ MARLRandomPolicy, MultiAgentPolicyManager, ) -from tianshou.trainer import OffpolicyTrainer +from tianshou.trainer import OffPolicyTrainer from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Net @@ -204,7 +204,7 @@ def reward_metric(rews: np.ndarray) -> np.ndarray: return rews[:, args.agent_id - 1] # trainer - result = OffpolicyTrainer( + result = OffPolicyTrainer( policy=policy, train_collector=train_collector, test_collector=test_collector, diff --git a/tianshou/highlevel/algorithm.py b/tianshou/highlevel/algorithm.py index b6f2b6718..506dec695 100644 --- a/tianshou/highlevel/algorithm.py +++ b/tianshou/highlevel/algorithm.py @@ -69,7 +69,7 @@ from tianshou.policy.modelfree.ddpg import DDPGPolicy from tianshou.policy.modelfree.dqn import DQNPolicy from tianshou.policy.modelfree.pg import ActorPolicy -from tianshou.trainer import BaseTrainer, 
OffpolicyTrainer, OnpolicyTrainer +from tianshou.trainer import BaseTrainer, OffPolicyTrainer, OnPolicyTrainer from tianshou.trainer.base import OffPolicyTrainingConfig, OnPolicyTrainingConfig from tianshou.utils.net.common import ActorCritic from tianshou.utils.net.discrete import Actor @@ -186,7 +186,7 @@ def create_trainer( self, world: World, policy_persistence: PolicyPersistence, - ) -> OnpolicyTrainer: + ) -> OnPolicyTrainer: sampling_config = self.sampling_config callbacks = self.trainer_callbacks context = TrainingContext(world.policy, world.envs, world.logger) @@ -233,7 +233,7 @@ def create_trainer( self, world: World, policy_persistence: PolicyPersistence, - ) -> OffpolicyTrainer: + ) -> OffPolicyTrainer: sampling_config = self.sampling_config callbacks = self.trainer_callbacks context = TrainingContext(world.policy, world.envs, world.logger) diff --git a/tianshou/policy/base.py b/tianshou/policy/base.py index 6be8ced61..775b1df1e 100644 --- a/tianshou/policy/base.py +++ b/tianshou/policy/base.py @@ -32,9 +32,11 @@ if TYPE_CHECKING: from tianshou.trainer.base import ( BaseTrainer, - OffpolicyTrainer, + OfflineTrainer, + OfflineTrainingConfig, + OffPolicyTrainer, OffPolicyTrainingConfig, - OnpolicyTrainer, + OnPolicyTrainer, OnPolicyTrainingConfig, ) @@ -786,10 +788,10 @@ class OnPolicyAlgorithm( Generic[TPolicy, TTrainingStats], ABC, ): - def create_trainer(self, config: "OnPolicyTrainingConfig") -> "OnpolicyTrainer": - from tianshou.trainer.base import OnpolicyTrainer + def create_trainer(self, config: "OnPolicyTrainingConfig") -> "OnPolicyTrainer": + from tianshou.trainer.base import OnPolicyTrainer - return OnpolicyTrainer(self, config) + return OnPolicyTrainer(self, config) class OffPolicyAlgorithm( @@ -797,10 +799,10 @@ class OffPolicyAlgorithm( Generic[TPolicy, TTrainingStats], ABC, ): - def create_trainer(self, config: "OffPolicyTrainingConfig") -> "OffpolicyTrainer": - from tianshou.trainer.base import OffpolicyTrainer + def create_trainer(self, 
config: "OffPolicyTrainingConfig") -> "OffPolicyTrainer": + from tianshou.trainer.base import OffPolicyTrainer - return OffpolicyTrainer(self, config) + return OffPolicyTrainer(self, config) # TODO must become Policy not Algorithm diff --git a/tianshou/trainer/__init__.py b/tianshou/trainer/__init__.py index 5946555a2..a4eaa4118 100644 --- a/tianshou/trainer/__init__.py +++ b/tianshou/trainer/__init__.py @@ -3,15 +3,15 @@ from tianshou.trainer.base import ( BaseTrainer, OfflineTrainer, - OffpolicyTrainer, - OnpolicyTrainer, + OffPolicyTrainer, + OnPolicyTrainer, ) from tianshou.trainer.utils import gather_info, test_episode __all__ = [ "BaseTrainer", - "OffpolicyTrainer", - "OnpolicyTrainer", + "OffPolicyTrainer", + "OnPolicyTrainer", "OfflineTrainer", "test_episode", "gather_info", diff --git a/tianshou/trainer/base.py b/tianshou/trainer/base.py index d21ac7407..ce81409a2 100644 --- a/tianshou/trainer/base.py +++ b/tianshou/trainer/base.py @@ -785,7 +785,7 @@ def _sample_and_update(self, buffer: ReplayBuffer) -> TrainingStats: return update_stat -class OffpolicyTrainer(BaseTrainer[OffPolicyTrainingConfig]): +class OffPolicyTrainer(BaseTrainer[OffPolicyTrainingConfig]): """Offpolicy trainer, samples mini-batches from buffer and passes them to update. Note that with this trainer, it is expected that the policy's `learn` method @@ -838,7 +838,7 @@ def _sample_and_update(self, buffer: ReplayBuffer) -> TrainingStats: return update_stat -class OnpolicyTrainer(BaseTrainer[OnPolicyTrainingConfig]): +class OnPolicyTrainer(BaseTrainer[OnPolicyTrainingConfig]): """On-policy trainer, passes the entire buffer to .update and resets it after. 
    Note that it is expected that the learn method of a policy will perform From 835792f2edd808226075d53c8c7905bf85385ad5 Mon Sep 17 00:00:00 2001 From: Dominik Jain Date: Sat, 8 Mar 2025 19:40:17 +0100 Subject: [PATCH 036/230] v2: Major refactoring of the Trainer classes * The trainer logic and configuration is now properly separated between the three cases of on-policy, off-policy and offline learning: The base class is no longer a "God" class which does it all; logic and functionality has been moved to the respective subclasses (`OnPolicyTrainer`, `OffPolicyTrainer` and `OfflineTrainer`, with `OnlineTrainer` being introduced as a base class for the two former specialisations). * The trainer configuration objects introduced earlier are now fully specific to the respective case, and central documentation is provided for each parameter (with greatly improved detail) * The iterator semantics have been dropped: Method `__next__` has been replaced by `execute_epoch`. * The interface has been streamlined with improved naming of functions/parameters and limiting the public interface to purely the methods and attributes a user can reasonably use directly. * Issues resolved: * Parameter `reset_prior_to_run` of `run` was never respected; changed parametrisation accordingly * Inconsistent configuration now raises exceptions instead of making assumptions about the intention For further details, see changes committed to CHANGELOG.md.
In the context of v2 refactoring, this commit renders OfflineTrainer functional again --- CHANGELOG.md | 49 +- tianshou/data/stats.py | 2 +- tianshou/highlevel/algorithm.py | 4 +- tianshou/highlevel/world.py | 4 +- tianshou/policy/base.py | 4 +- tianshou/trainer/__init__.py | 14 +- tianshou/trainer/base.py | 1229 +++++++++++++++++-------------- tianshou/trainer/utils.py | 3 + 8 files changed, 728 insertions(+), 581 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 09fe609f9..dca302677 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,12 +2,49 @@ ## Release 2.0.0 -* We now conceptually differentiate between the learning algorithm and the policy being optimised: - * The abstraction `BasePolicy` is thus replaced by `Algorithm` and `Policy`. - Migration information (`BasePolicy` -> `Algorithm`): - * `PGPolicy` -> `Reinforce` - * `DQNPolicy` -> `DeepQLearning` - * `DDPGPolicy` -> `DDPG` +* `Trainer` abstraction (formerly `BaseTrainer`): + * The trainer logic and configuration is now properly separated between the three cases of on-policy, off-policy + and offline learning: The base class is no longer a "God" class which does it all; logic and functionality has moved + to the respective subclasses (`OnPolicyTrainer`, `OffPolicyTrainer` and `OfflineTrainer`, with `OnlineTrainer` + being introduced as a base class for the two former specialisations). + * The trainers now use configuration objects with central documentation (which has been greatly improved to enhance + clarity and usability in general); every type of trainer now has a dedicated configuration class which provides + precisely the options that are applicable. + * The interface has been streamlined with improved naming of functions/parameters and limiting the public interface to purely + the methods and attributes a user should reasonably access. + * Further changes affecting usage: + * We dropped the iterator semantics: Method `__next__` has been replaced by `execute_epoch`. 
+ * We no longer report outdated statistics (e.g. on rewards/returns when a training step does not collect any full + episodes) + * Issues resolved: + * Methods `run` and `reset`: Parameter `reset_prior_to_run` of `run` was never respected if it was set to `False`, + because the implementation of `__iter__` (now removed) would call `reset` regardless - and calling `reset` + is indeed necessary, because it initializes the training. The parameter was removed and replaced by + `reset_collectors` (such that `run` now replicates the parameters of `reset`). + * Inconsistent configuration options now raise exceptions rather than silently ignoring the issue in the + hope that default behaviour will achieve what the user intended. + One condition where `test_in_train` was silently set to `False` was removed and replaced by a warning. + * Open issues: + * TODO: For `test_in_train`, the early stopping criterion was computed incorrectly (did not consider `compute_score_fn`, + i.e. it assumed that the default implementation applies) + * TODO: _gradient_step counter is not incorrect; replace it with a simple update step counter + * Migration information at a glance: + * Training parameters are now passed via instances of configuration objects instead of directly as keyword arguments: + `OnPolicyTrainingConfig`, `OffPolicyTrainingConfig`, `OfflineTrainingConfig`. + * Trainer classes have been renamed: + * `OnpolicyTrainer` -> `OnPolicyTrainer` + * `OffpolicyTrainer` -> `OffPolicyTrainer` + * Method `run`: The parameter `reset_prior_to_run` was removed and replaced by `reset_collectors` (see above). + * Methods `run` and `reset`: The parameter `reset_buffer` was renamed to `reset_collector_buffers` for clarity + * Trainers are no longer iterators; manual usage (not using `run`) should simply call `reset` followed by + calls of `execute_epoch`. 
+* `Policy` and `Algorithm` abstractions (formerly unified in `BasePolicy`): + * We now conceptually differentiate between the learning algorithm and the policy being optimised: + * The abstraction `BasePolicy` is thus replaced by `Algorithm` and `Policy`. + Migration information (`BasePolicy` -> `Algorithm`): + * `PGPolicy` -> `Reinforce` + * `DQNPolicy` -> `DeepQLearning` + * `DDPGPolicy` -> `DDPG` * The `Algorithm` abstraction can directly initiate the learning process via method `run_training`. * Internal design improvements: * Introduced an abstraction for the alpha parameter (coefficient of the entropy term) diff --git a/tianshou/data/stats.py b/tianshou/data/stats.py index 11d64c017..ec2cd6703 100644 --- a/tianshou/data/stats.py +++ b/tianshou/data/stats.py @@ -107,7 +107,7 @@ class EpochStats(DataclassPPrintMixin): epoch: int """The current epoch.""" - train_collect_stat: "CollectStatsBase" + train_collect_stat: Optional["CollectStatsBase"] """The statistics of the last call to the training collector.""" test_collect_stat: Optional["CollectStats"] """The statistics of the last call to the test collector.""" diff --git a/tianshou/highlevel/algorithm.py b/tianshou/highlevel/algorithm.py index 506dec695..84793e712 100644 --- a/tianshou/highlevel/algorithm.py +++ b/tianshou/highlevel/algorithm.py @@ -69,7 +69,7 @@ from tianshou.policy.modelfree.ddpg import DDPGPolicy from tianshou.policy.modelfree.dqn import DQNPolicy from tianshou.policy.modelfree.pg import ActorPolicy -from tianshou.trainer import BaseTrainer, OffPolicyTrainer, OnPolicyTrainer +from tianshou.trainer import OffPolicyTrainer, OnPolicyTrainer, Trainer from tianshou.trainer.base import OffPolicyTrainingConfig, OnPolicyTrainingConfig from tianshou.utils.net.common import ActorCritic from tianshou.utils.net.discrete import Actor @@ -177,7 +177,7 @@ def create_algorithm(self, envs: Environments, device: TDevice) -> Algorithm: return policy @abstractmethod - def create_trainer(self, world: World, 
policy_persistence: PolicyPersistence) -> BaseTrainer: + def create_trainer(self, world: World, policy_persistence: PolicyPersistence) -> Trainer: pass diff --git a/tianshou/highlevel/world.py b/tianshou/highlevel/world.py index 439ed0b7e..2a68d3ff8 100644 --- a/tianshou/highlevel/world.py +++ b/tianshou/highlevel/world.py @@ -7,7 +7,7 @@ from tianshou.highlevel.env import Environments from tianshou.highlevel.logger import TLogger from tianshou.policy import Algorithm - from tianshou.trainer import BaseTrainer + from tianshou.trainer import Trainer @dataclass(kw_only=True) @@ -21,7 +21,7 @@ class World: logger: "TLogger" persist_directory: str restore_directory: str | None - trainer: Optional["BaseTrainer"] = None + trainer: Optional["Trainer"] = None def persist_path(self, filename: str) -> str: return os.path.abspath(os.path.join(self.persist_directory, filename)) diff --git a/tianshou/policy/base.py b/tianshou/policy/base.py index 775b1df1e..d7d60d491 100644 --- a/tianshou/policy/base.py +++ b/tianshou/policy/base.py @@ -31,13 +31,13 @@ if TYPE_CHECKING: from tianshou.trainer.base import ( - BaseTrainer, OfflineTrainer, OfflineTrainingConfig, OffPolicyTrainer, OffPolicyTrainingConfig, OnPolicyTrainer, OnPolicyTrainingConfig, + Trainer, ) logger = logging.getLogger(__name__) @@ -775,7 +775,7 @@ def compute_nstep_return( return cast(BatchWithReturnsProtocol, batch) @abstractmethod - def create_trainer(self, config: TTrainingConfig) -> "BaseTrainer": + def create_trainer(self, config: TTrainingConfig) -> "Trainer": pass def run_training(self, config: TTrainingConfig): diff --git a/tianshou/trainer/__init__.py b/tianshou/trainer/__init__.py index a4eaa4118..426f13080 100644 --- a/tianshou/trainer/__init__.py +++ b/tianshou/trainer/__init__.py @@ -1,18 +1,22 @@ """Trainer package.""" -from tianshou.trainer.base import ( - BaseTrainer, +from .base import ( OfflineTrainer, + OfflineTrainingConfig, OffPolicyTrainer, + OffPolicyTrainingConfig, OnPolicyTrainer, + 
OnPolicyTrainingConfig, + Trainer, ) from tianshou.trainer.utils import gather_info, test_episode __all__ = [ - "BaseTrainer", + "Trainer", "OffPolicyTrainer", "OnPolicyTrainer", "OfflineTrainer", - "test_episode", - "gather_info", + "OffPolicyTrainingConfig", + "OnPolicyTrainingConfig", + "OfflineTrainingConfig", ] diff --git a/tianshou/trainer/base.py b/tianshou/trainer/base.py index ce81409a2..3c9b5723d 100644 --- a/tianshou/trainer/base.py +++ b/tianshou/trainer/base.py @@ -1,7 +1,30 @@ +""" +This module contains Tianshou's trainer classes, which orchestrate the training and call upon an RL algorithm's +specific network updating logic to perform the actual gradient updates. + +Training is structured as follows (hierarchical glossary): +- **epoch**: The outermost iteration level of the training loop. Each epoch consists of a number of training steps + and one test step (see :attr:`TrainingConfig.max_epoch` for a detailed explanation): + - **training step**: A training step performs the steps necessary in order to apply a single update of the neural + network components as defined by the underlying RL algorithm (:class:`Algorithm`). This involves the following sub-steps: + - for online learning algorithms: + - **collection step**: collecting environment steps/transitions to be used for training. + - (potentially) a test step (see below) if the early stopping criterion is satisfied based on + the data collected (see :attr:`OnlineTrainingConfig.test_in_train`). + - **update step**: applying the actual gradient updates using the RL algorithm. + The update is based on either ... + - data from only the preceding collection step (on-policy learning), + - data from the collection step and previously collected data (off-policy learning), or + - data from the user-provided replay buffer (offline learning). + For offline learning algorithms, a training step is thus equivalent to an update step. 
+ - **test step**: Collects test episodes from dedicated test environments which are used to evaluate the performance + of the policy. Optionally, the performance result can be used to determine whether training shall stop early + (see :attr:`TrainingConfig.stop_fn`). +""" import logging import time from abc import ABC, abstractmethod -from collections import defaultdict, deque +from collections import defaultdict from collections.abc import Callable from dataclasses import asdict, dataclass from functools import partial @@ -42,40 +65,149 @@ class TrainingConfig(ToStringMixin): max_epoch: int = 100 """ - the number of epochs to run training for. An epoch is the outermost iteration level and each - epoch consists of a number of training steps and a test step, where each training step + the (maximum) number of epochs to run training for. An **epoch** is the outermost iteration level and each + epoch consists of a number of training steps and one test step, where each training step - * collects environment steps/transitions (collection step), adding them to the (replay) - buffer (see :attr:`step_per_collect`) - * performs one or more gradient updates (see :attr:`update_per_step`), + * [for the online case] collects environment steps/transitions (**collection step**), + adding them to the (replay) buffer (see :attr:`step_per_collect` and :attr:`episode_per_collect`) + * performs an **update step** via the RL algorithm being used, which can involve + one or more actual gradient updates, depending on the algorithm and the test step collects :attr:`num_episodes_per_test` test episodes in order to evaluate agent performance. - The number of training steps in each epoch is indirectly determined by + Training may be stopped early if the stop criterion is met (see :attr:`stop_fn`). 
+ + For online training, the number of training steps in each epoch is indirectly determined by :attr:`step_per_epoch`: As many training steps will be performed as are required in order to reach :attr:`step_per_epoch` total steps in the training environments. Specifically, if the number of transitions collected per step is `c` (see :attr:`step_per_collect`) and :attr:`step_per_epoch` is set to `s`, then the number of training steps per epoch is `ceil(s / c)`. - Therefore, if `num_epochs = e`, the total number of environment steps taken during training can be computed as `e * ceil(s / c) * c`. + + For offline training, the number of training steps per epoch is equal to :attr:`step_per_epoch`. """ step_per_epoch: int = 30000 """ - the total number of environment steps to be made per epoch. See :attr:`num_epochs` for - an explanation of epoch semantics. + for an online algorithm, this is the total number of environment steps to be collected per epoch, and, + for an offline algorithm, it is the total number of training steps to take per epoch. + See :attr:`max_epoch` for an explanation of epoch semantics. + """ + + test_collector: BaseCollector | None = None + """ + the collector to use for test episode collection (test steps); if None, perform no test steps. """ episode_per_test: int = 1 - """the total number of episodes to collect in each test step (across all test environments). + """the number of episodes to collect in each test step. + """ + + train_fn: Callable[[int, int], None] | None = None + """ + a callback function which is called at the beginning of each training step. + It can be used to perform custom additional operations, with the + signature ``f(num_epoch: int, step_idx: int) -> None``. + """ + + test_fn: Callable[[int, int | None], None] | None = None + """ + a callback function to be called at the beginning of each test step. + It can be used to perform custom additional operations, with the + signature ``f(num_epoch: int, step_idx: int | None) -> None``.
+ """ + + stop_fn: Callable[[float], bool] | None = None + """ + a callback function with signature ``f(score: float) -> bool``, which + is used to decide whether training shall be stopped early based on the score + achieved in a test step. + The score it receives is computed by the :attr:`compute_score_fn` callback + (which defaults to the mean reward if the function is not provided). + + Requires test steps to be activated and thus :attr:`test_collector` to be set. + + Note: The function is also used when :attr:`OnlineTrainingConfig.test_in_train` is activated (see its docstring). + """ + + compute_score_fn: Callable[[CollectStats], float] | None = None + """ + the callback function to use in order to compute the test batch performance score, which is used to + determine what the best model is (score is maximized); if None, use the mean reward. + """ + + save_best_fn: Callable[["Algorithm"], None] | None = None + """ + the callback function to call in order to save the best model whenever a new best score (see :attr:`compute_score_fn`) + is achieved in a test step. It should have the signature ``f(algorithm: Algorithm) -> None``. + """ + + save_checkpoint_fn: Callable[[int, int, int], str] | None = None """ + the callback function with which to save checkpoint data at the end of each epoch (via the logger; skipped if training stops early), + which can save whatever data is desired to a file and returns the path of the file. + Signature: ``f(epoch: int, env_step: int, gradient_step: int) -> str``. + """ + + resume_from_log: bool = False + """ + whether to load env_step/gradient_step and other metadata from the existing log, + which is given in :attr:`logger`. + """ + + reward_metric: Callable[[np.ndarray], np.ndarray] | None = None + """ + a function with signature + ``f(rewards: np.ndarray with shape (num_episode, agent_num)) -> np.ndarray with shape (num_episode,)``, + which is used in multi-agent RL. We need to return a single scalar for each episode's result + to monitor training in the multi-agent RL setting.
This function specifies what is the desired metric, + e.g., the reward of agent 1 or the average reward over all agents. + """ + + logger: BaseLogger | None = None + """ + the logger with which to log statistics during training/testing/updating. To not log anything, use None. + """ + + verbose: bool = True + """ + whether to print status information to stdout. + If set to False, status information will still be logged (provided that logging is enabled via the + `logging` Python module). + """ + + show_progress: bool = True + """ + whether to display progress bars during training. + """ + + def __post_init__(self): + if self.resume_from_log and self.logger is None: + raise ValueError("Cannot resume from log without a logger being provided") + if self.test_collector is None: + if self.stop_fn is not None: + raise ValueError( + "stop_fn cannot be activated without test steps being enabled (test_collector being set)" + ) + if self.test_fn is not None: + raise ValueError( + "test_fn is set while test steps are disabled (test_collector is None)" + ) + if self.save_best_fn is not None: + raise ValueError( + "save_best_fn is set while test steps are disabled (test_collector is None)" + ) - buffer_size: int = 4096 - """the total size of the sample/replay buffer, in which environment steps (transitions) are - stored""" + +@dataclass(kw_only=True) +class OnlineTrainingConfig(TrainingConfig): + train_collector: BaseCollector + """ + the collector with which to gather new data for training in each training step + """ step_per_collect: int | None = 2048 """ @@ -103,52 +235,50 @@ class TrainingConfig(ToStringMixin): This is mutually exclusive with :attr:`step_per_collect`, and one of the two must be set.
""" - # TODO copy docstrings from BaseTrainer - train_collector: BaseCollector | None = None - test_collector: BaseCollector | None = None - buffer: ReplayBuffer | None = None - train_fn: Callable[[int, int], None] | None = None - test_fn: Callable[[int, int | None], None] | None = None - stop_fn: Callable[[float], bool] | None = None - compute_score_fn: Callable[[CollectStats], float] | None = None - save_best_fn: Callable[["Algorithm"], None] | None = None - save_checkpoint_fn: Callable[[int, int, int], str] | None = None - resume_from_log: bool = False - reward_metric: Callable[[np.ndarray], np.ndarray] | None = None - logger: BaseLogger | None = None - verbose: bool = True - show_progress: bool = True test_in_train: bool = True + """ + Whether to apply an effective test step triggered by the early stopping criterion (given by :attr:`stop_fn`) + being satisfied in the data collected in the collect step within a training step: + If the stop criterion is satisfied, it collects `episode_per_test` test episodes (as in a test step) + and determines whether the stop criterion is also satisfied by the episodes thus collected, + and if so, training stops early. + """ def __post_init__(self): + super().__post_init__() if count_none(self.step_per_collect, self.episode_per_collect) != 1: raise ValueError("Exactly one of {step_per_collect, episode_per_collect} must be set") + if self.test_in_train and (self.test_collector is None or self.stop_fn is None): + raise ValueError("test_in_train requires test_collector and stop_fn to be set") @dataclass(kw_only=True) -class OnPolicyTrainingConfig(TrainingConfig): +class OnPolicyTrainingConfig(OnlineTrainingConfig): batch_size: int | None = 64 """ - Use mini-batches of this size for gradient updates (causing the gradient to be less accurate, a form of regularization). - Set ``batch_size=None`` for the full buffer to be used for the gradient update (no mini-batching). 
+ Use mini-batches of this size for gradient updates (causing the gradient to be less accurate, + a form of regularization). + Set ``batch_size=None`` for the full buffer that was collected within the training step to be + used for the gradient update (no mini-batching). """ - repeat_per_collect: int | None = 1 + repeat_per_collect: int = 1 """ - controls, within one gradient update step of an on-policy algorithm, the number of times an - actual gradient update is applied using the full collected dataset, i.e. if the parameter is + controls, within one update step of an on-policy algorithm, the number of times + the full collected data is applied for gradient updates, i.e. if the parameter is 5, then the collected data shall be used five times to update the policy within the same - training step. + update step. """ @dataclass(kw_only=True) -class OffPolicyTrainingConfig(TrainingConfig): +class OffPolicyTrainingConfig(OnlineTrainingConfig): batch_size: int = 64 """ the the number of environment steps/transitions to sample from the buffer for a gradient update. """ + # TODO: Given our glossary, this is confusingly named. Should definitely contain the word "gradient" update_per_step: float = 1.0 """ the number of gradient steps to perform per sample collected (see :attr:`step_per_collect`). @@ -158,646 +288,621 @@ class OffPolicyTrainingConfig(TrainingConfig): @dataclass(kw_only=True) -class OfflineTrainingConfig(OffPolicyTrainingConfig): - pass - - -TConfig = TypeVar("TConfig", bound=TrainingConfig) - - -class BaseTrainer(Generic[TConfig], ABC): - """An iterator base class for trainers. - - Returns an iterator that yields a 3-tuple (epoch, stats, info) of train results - on every epoch. - - :param policy: an instance of the :class:`~tianshou.policy.BasePolicy` class. - :param batch_size: the batch size of sample data, which is going to feed in - the policy network. If None, will use the whole buffer in each gradient step. 
- :param train_collector: the collector used for training. - :param test_collector: the collector used for testing. If it's None, - then no testing will be performed. - :param buffer: the replay buffer used for off-policy algorithms or for pre-training. - If a policy overrides the ``process_buffer`` method, the replay buffer will - be pre-processed before training. - :param max_epoch: the maximum number of epochs for training. The training - process might be finished before reaching ``max_epoch`` if ``stop_fn`` - is set. - :param step_per_epoch: the number of transitions collected per epoch. - :param repeat_per_collect: the number of repeat time for policy learning, - for example, set it to 2 means the policy needs to learn each given batch - data twice. Only used in on-policy algorithms - :param episode_per_test: the number of episodes for one policy evaluation. - :param update_per_step: only used in off-policy algorithms. - How many gradient steps to perform per step in the environment - (i.e., per sample added to the buffer). - :param step_per_collect: the number of transitions the collector would - collect before the network update, i.e., trainer will collect - "step_per_collect" transitions and do some policy network update repeatedly - in each epoch. - :param episode_per_collect: the number of episodes the collector would - collect before the network update, i.e., trainer will collect - "episode_per_collect" episodes and do some policy network update repeatedly - in each epoch. - :param train_fn: a hook called at the beginning of training in each - epoch. It can be used to perform custom additional operations, with the - signature ``f(num_epoch: int, step_idx: int) -> None``. - :param test_fn: a hook called at the beginning of testing in each - epoch. It can be used to perform custom additional operations, with the - signature ``f(num_epoch: int, step_idx: int) -> None``. 
- :param compute_score_fn: Calculate the test batch performance score to - determine whether it is the best model, the mean reward will be used as score if not provided. - :param save_best_fn: a hook called when the undiscounted average mean - reward in evaluation phase gets better, with the signature - ``f(policy: BasePolicy) -> None``. - :param save_checkpoint_fn: a function to save training process and - return the saved checkpoint path, with the signature ``f(epoch: int, - env_step: int, gradient_step: int) -> str``; you can save whatever you want. - :param resume_from_log: resume env_step/gradient_step and other metadata - from existing tensorboard log. - :param stop_fn: a function with signature ``f(mean_rewards: float) -> - bool``, receives the average undiscounted returns of the testing result, - returns a boolean which indicates whether reaching the goal. - :param reward_metric: a function with signature - ``f(rewards: np.ndarray with shape (num_episode, agent_num)) -> np.ndarray - with shape (num_episode,)``, used in multi-agent RL. We need to return a - single scalar for each episode's result to monitor training in the - multi-agent RL setting. This function specifies what is the desired metric, - e.g., the reward of agent 1 or the average reward over all agents. - :param logger: A logger that logs statistics during - training/testing/updating. To not log anything, keep the default logger. - :param verbose: whether to print status information to stdout. - If set to False, status information will still be logged (provided that - logging is enabled via the `logging` module). - :param show_progress: whether to display a progress bar when training. - :param test_in_train: whether to test in the training phase. 
- """ - - __doc__: str - - @staticmethod - def gen_doc(learning_type: str) -> str: - """Document string for subclass trainer.""" - step_means = f'The "step" in {learning_type} trainer means ' - if learning_type != "offline": - step_means += "an environment step (a.k.a. transition)." - else: # offline - step_means += "a gradient step." - - trainer_name = learning_type.capitalize() + "Trainer" - - return f"""An iterator class for {learning_type} trainer procedure. - - Returns an iterator that yields a 3-tuple (epoch, stats, info) of - train results on every epoch. - - {step_means} - - Example usage: +class OfflineTrainingConfig(TrainingConfig): + buffer: ReplayBuffer + """ + the replay buffer with environment steps to use as training data for offline learning. + This buffer will be pre-processed using the RL algorithm's pre-processing + function (if any) before training. + """ - :: + batch_size: int = 64 + """ + the number of environment steps/transitions to sample from the buffer for a gradient update. + """ - trainer = {trainer_name}(...) - for epoch, epoch_stat, info in trainer: - print("Epoch:", epoch) - print(epoch_stat) - print(info) - do_something_with_policy() - query_something_about_policy() - make_a_plot_with(epoch_stat) - display(info) - - epoch int: the epoch number - - epoch_stat dict: a large collection of metrics of the current epoch - - info dict: result returned from :func:`~tianshou.trainer.gather_info` +TTrainingConfig = TypeVar("TTrainingConfig", bound=TrainingConfig) +TOnlineTrainingConfig = TypeVar("TOnlineTrainingConfig", bound=OnlineTrainingConfig) - You can even iterate on several trainers at the same time: - :: +class Trainer(Generic[TTrainingConfig], ABC): + """ + Base class for trainers in Tianshou, which orchestrate the training process and call upon an RL algorithm's + specific network updating logic to perform the actual gradient updates. - trainer1 = {trainer_name}(...) - trainer2 = {trainer_name}(...) - for result1, result2, ...
in zip(trainer1, trainer2, ...): - compare_results(result1, result2, ...) - """ + The base class already implements the fundamental epoch logic and fully implements the test step + logic, which is common to all trainers. The training step logic is left to be implemented by subclasses. + """ def __init__( self, policy: "Algorithm", - config: TConfig, + config: TTrainingConfig, ): - logger = config.logger - logger = logger or LazyLogger() - self.policy = policy - - buffer = config.buffer - if buffer is not None: - buffer = policy.process_buffer(buffer) - self.buffer = buffer - - self.train_collector = config.train_collector - self.test_collector = config.test_collector - - self.logger = logger - self.start_time = time.time() - self.stat: defaultdict[str, MovAvg] = defaultdict(MovAvg) - self.best_score = 0.0 - self.best_reward = 0.0 - self.best_reward_std = 0.0 - self.start_epoch = 0 - # This is only used for logging but creeps into the implementations - # of the trainers. I believe it would be better to remove - self._gradient_step = 0 - self.env_step = 0 - self.env_episode = 0 - self.policy_update_time = 0.0 - self.max_epoch = config.max_epoch - assert ( - config.step_per_epoch is not None - ), "The trainer requires step_per_epoch to be set, sorry for the wrong type hint" - self.step_per_epoch: int = config.step_per_epoch - - # either on of these two - self.step_per_collect = config.step_per_collect - self.episode_per_collect = config.episode_per_collect - + self.algorithm = policy self.config = config - self.episode_per_test = config.episode_per_test - self.train_fn = config.train_fn - self.test_fn = config.test_fn - self.stop_fn = config.stop_fn - self.compute_score_fn: Callable[[CollectStats], float] - compute_score_fn = config.compute_score_fn - if compute_score_fn is None: + self._logger = config.logger or LazyLogger() - def compute_score_fn(stat: CollectStats) -> float: - assert stat.returns_stat is not None # for mypy - return stat.returns_stat.mean + 
self._start_time = time.time() + self._stat: defaultdict[str, MovAvg] = defaultdict(MovAvg) + self._best_score = 0.0 + self._best_reward = 0.0 + self._best_reward_std = 0.0 + self._start_epoch = 0 + # This is only used for logging but creeps into the implementations + # of the trainers. I believe it would be better to remove + self._gradient_step = 0 + self._env_step = 0 + """ + the step counter which is used to track progress of the training process. + For online learning (i.e. on-policy and off-policy learning), this is the total number of + environment steps collected, and for offline training, it is the total number of environment + steps that have been sampled from the replay buffer to perform gradient updates. + """ + self._policy_update_time = 0.0 - self.compute_score_fn = compute_score_fn - self.save_best_fn = config.save_best_fn - self.save_checkpoint_fn = config.save_checkpoint_fn + self._compute_score_fn: Callable[[CollectStats], float] = ( + config.compute_score_fn or self._compute_score_fn_default + ) - self.reward_metric = config.reward_metric - self.verbose = config.verbose - self.show_progress = config.show_progress - self.test_in_train = config.test_in_train - self.resume_from_log = config.resume_from_log + self._epoch = self._start_epoch + self._best_epoch = self._start_epoch + self._stop_fn_flag = False - self.is_run = False - self.last_rew, self.last_len = 0.0, 0.0 + @staticmethod + def _compute_score_fn_default(stat: CollectStats) -> float: + """ + The default score function, which returns the mean return/reward. - self.epoch = self.start_epoch - self.best_epoch = self.start_epoch - self.stop_fn_flag = False - self.iter_num = 0 + :param stat: the collection stats + :return: the mean return + """ + assert stat.returns_stat is not None # for mypy + return stat.returns_stat.mean @property - def _pbar(self) -> type[tqdm.tqdm]: + def _pbar(self) -> Callable[..., tqdm.tqdm]: """Use as context manager or iterator, i.e., `with self._pbar(...) 
as t:` or `for _ in self._pbar(...):`.""" return partial( tqdm.tqdm, dynamic_ncols=True, ascii=True, - disable=not self.show_progress, - ) # type: ignore[return-value] + disable=not self.config.show_progress, + ) def _reset_collectors(self, reset_buffer: bool = False) -> None: - if self.train_collector is not None: - self.train_collector.reset(reset_buffer=reset_buffer) - if self.test_collector is not None: - self.test_collector.reset(reset_buffer=reset_buffer) - - def reset(self, reset_collectors: bool = True, reset_buffer: bool = False) -> None: - """Initialize or reset the instance to yield a new iterator from zero.""" - self.is_run = False - self.env_step = 0 - if self.resume_from_log: + if self.config.test_collector is not None: + self.config.test_collector.reset(reset_buffer=reset_buffer) + + def reset(self, reset_collectors: bool = True, reset_collector_buffers: bool = False) -> None: + """Initializes the training process. + + :param reset_collectors: whether to reset the collectors prior to starting the training process. + Specifically, this will reset the environments in the collectors (starting new episodes), + and the statistics stored in the collector. Whether the contained buffers will be reset/cleared + is determined by the `reset_collector_buffers` parameter. + :param reset_collector_buffers: whether, for the case where the collectors are reset, to reset/clear the + contained buffers as well. + This has no effect if `reset_collectors` is False.
+ """ + self._env_step = 0 + if self.config.resume_from_log: ( - self.start_epoch, - self.env_step, + self._start_epoch, + self._env_step, self._gradient_step, - ) = self.logger.restore_data() + ) = self._logger.restore_data() - self.last_rew, self.last_len = 0.0, 0.0 - self.start_time = time.time() + self._start_time = time.time() if reset_collectors: - self._reset_collectors(reset_buffer=reset_buffer) - - if self.train_collector is not None and ( - self.train_collector.algorithm != self.policy or self.test_collector is None - ): - self.test_in_train = False + self._reset_collectors(reset_buffer=reset_collector_buffers) - if self.test_collector is not None: - assert self.episode_per_test is not None - assert not isinstance(self.test_collector, AsyncCollector) # Issue 700 + if self.config.test_collector is not None: + assert self.config.episode_per_test is not None + assert not isinstance(self.config.test_collector, AsyncCollector) # Issue 700 test_result = test_episode( - self.test_collector, - self.test_fn, - self.start_epoch, - self.episode_per_test, - self.logger, - self.env_step, - self.reward_metric, + self.config.test_collector, + self.config.test_fn, + self._start_epoch, + self.config.episode_per_test, + self._logger, + self._env_step, + self.config.reward_metric, ) assert test_result.returns_stat is not None # for mypy - self.best_epoch = self.start_epoch - self.best_reward, self.best_reward_std = ( + self._best_epoch = self._start_epoch + self._best_reward, self._best_reward_std = ( test_result.returns_stat.mean, test_result.returns_stat.std, ) - self.best_score = self.compute_score_fn(test_result) - if self.save_best_fn: - self.save_best_fn(self.policy) - - self.epoch = self.start_epoch - self.stop_fn_flag = False - self.iter_num = 0 - - def __iter__(self): # type: ignore - self.reset(reset_collectors=True, reset_buffer=False) - return self - - def __next__(self) -> EpochStats: - """Perform one epoch (both train and eval).""" - self.epoch += 1 - 
self.iter_num += 1 + self._best_score = self._compute_score_fn(test_result) + if self.config.save_best_fn: + self.config.save_best_fn(self.algorithm) + + self._epoch = self._start_epoch + self._stop_fn_flag = False + + class _TrainingStepResult(ABC): + @abstractmethod + def get_steps_in_epoch_advancement(self): + """ + :return: the number of steps that were done within the epoch, where the concrete semantics + of what a step is depends on the type of algorithm. See docstring of `TrainingConfig.step_per_epoch`. + """ + + @abstractmethod + def get_collect_stats(self) -> CollectStats | None: + pass + + @abstractmethod + def get_training_stats(self) -> TrainingStats | None: + pass + + @abstractmethod + def is_training_done(self): + """:return: whether the early stopping criterion is satisfied and training shall stop.""" + + @abstractmethod + def get_env_step_advancement(self) -> int: + """ + :return: the number of steps by which to advance the env_step counter in the trainer (see docstring + of trainer attribute `_env_step`). The semantics depend on the type of the algorithm.
+ """ - if self.iter_num > 1: - # iterator exhaustion check - if self.epoch > self.max_epoch: - raise StopIteration + @abstractmethod + def _create_epoch_pbar_data_dict( + self, training_step_result: _TrainingStepResult + ) -> dict[str, str]: + pass - # exit flag 1, when stop_fn succeeds in train_step or test_step - if self.stop_fn_flag: - raise StopIteration + def execute_epoch(self) -> EpochStats: + self._epoch += 1 - # perform n step_per_epoch + # perform the required number of steps for the epoch (`step_per_epoch`) steps_done_in_this_epoch = 0 - with self._pbar(total=self.step_per_epoch, desc=f"Epoch #{self.epoch}", position=1) as t: - train_stat: CollectStatsBase - while steps_done_in_this_epoch < self.step_per_epoch and not self.stop_fn_flag: - train_stat, update_stat, self.stop_fn_flag = self.training_step() - - if isinstance(train_stat, CollectStats): - pbar_data_dict = { - "env_step": str(self.env_step), - "env_episode": str(self.env_episode), - "rew": f"{self.last_rew:.2f}", - "len": str(int(self.last_len)), - "n/ep": str(train_stat.n_collected_episodes), - "n/st": str(train_stat.n_collected_steps), - } - - # t might be disabled, we track the steps manually - t.update(train_stat.n_collected_steps) - steps_done_in_this_epoch += train_stat.n_collected_steps - - if self.stop_fn_flag: - t.set_postfix(**pbar_data_dict) - else: - # TODO: there is no iteration happening here, it's the offline case - # Code should be restructured! 
- pbar_data_dict = {} - assert self.buffer, "No train_collector or buffer specified" - train_stat = CollectStatsBase( - n_collected_steps=len(self.buffer), - ) - - # t might be disabled, we track the steps manually - t.update() - steps_done_in_this_epoch += 1 - + train_collect_stats, training_stats = None, None + with self._pbar( + total=self.config.step_per_epoch, desc=f"Epoch #{self._epoch}", position=1 + ) as t: + while steps_done_in_this_epoch < self.config.step_per_epoch and not self._stop_fn_flag: + # perform a training step and update progress + training_step_result = self._training_step() + steps_done_in_this_epoch += training_step_result.get_steps_in_epoch_advancement() + t.update(training_step_result.get_steps_in_epoch_advancement()) + self._stop_fn_flag = training_step_result.is_training_done() + self._env_step += training_step_result.get_env_step_advancement() + + collect_stats = training_step_result.get_collect_stats() + if collect_stats is not None: + self._logger.log_train_data(asdict(collect_stats), self._env_step) + training_stats = training_step_result.get_training_stats() + + pbar_data_dict = self._create_epoch_pbar_data_dict(training_step_result) pbar_data_dict = set_numerical_fields_to_precision(pbar_data_dict) pbar_data_dict["gradient_step"] = str(self._gradient_step) t.set_postfix(**pbar_data_dict) - if self.stop_fn_flag: - break - - if steps_done_in_this_epoch <= self.step_per_epoch and not self.stop_fn_flag: - # t might be disabled, we track the steps manually - t.update() - steps_done_in_this_epoch += 1 - - # TODO What is this doing here? Where to put it? 
- # for offline RL - if self.train_collector is None: - assert self.buffer is not None - batch_size = self.batch_size or len(self.buffer) - self.env_step = self._gradient_step * batch_size - - test_stat = None - if not self.stop_fn_flag: - self.logger.save_data( - self.epoch, - self.env_step, + test_collect_stats = None + if not self._stop_fn_flag: + self._logger.save_data( + self._epoch, + self._env_step, self._gradient_step, - self.save_checkpoint_fn, + self.config.save_checkpoint_fn, ) - # test - if self.test_collector is not None: - test_stat, self.stop_fn_flag = self.test_step() - info_stat = gather_info( - start_time=self.start_time, - policy_update_time=self.policy_update_time, + # test step + if self.config.test_collector is not None: + test_collect_stats, self._stop_fn_flag = self._test_step() + + info_stats = gather_info( + start_time=self._start_time, + policy_update_time=self._policy_update_time, gradient_step=self._gradient_step, - best_score=self.best_score, - best_reward=self.best_reward, - best_reward_std=self.best_reward_std, - train_collector=self.train_collector, - test_collector=self.test_collector, + best_score=self._best_score, + best_reward=self._best_reward, + best_reward_std=self._best_reward_std, + train_collector=self.config.train_collector + if isinstance(self.config, OnlineTrainingConfig) + else None, + test_collector=self.config.test_collector, ) - self.logger.log_info_data(asdict(info_stat), self.epoch) + self._logger.log_info_data(asdict(info_stats), self._epoch) - # in case trainer is used with run(), epoch_stat will not be returned return EpochStats( - epoch=self.epoch, - train_collect_stat=train_stat, - test_collect_stat=test_stat, - training_stat=update_stat, - info_stat=info_stat, + epoch=self._epoch, + train_collect_stat=train_collect_stats, + test_collect_stat=test_collect_stats, + training_stat=training_stats, + info_stat=info_stats, ) - def test_step(self) -> tuple[CollectStats, bool]: - """Perform one testing step.""" - 
assert self.episode_per_test is not None - assert self.test_collector is not None + def _test_step(self) -> tuple[CollectStats, bool]: + """Perform one test step.""" + assert self.config.episode_per_test is not None + assert self.config.test_collector is not None stop_fn_flag = False test_stat = test_episode( - self.test_collector, - self.test_fn, - self.epoch, - self.episode_per_test, - self.logger, - self.env_step, - self.reward_metric, + self.config.test_collector, + self.config.test_fn, + self._epoch, + self.config.episode_per_test, + self._logger, + self._env_step, + self.config.reward_metric, ) assert test_stat.returns_stat is not None # for mypy rew, rew_std = test_stat.returns_stat.mean, test_stat.returns_stat.std - score = self.compute_score_fn(test_stat) - if self.best_epoch < 0 or self.best_score < score: - self.best_score = score - self.best_epoch = self.epoch - self.best_reward = float(rew) - self.best_reward_std = rew_std - if self.save_best_fn: - self.save_best_fn(self.policy) + score = self._compute_score_fn(test_stat) + if self._best_epoch < 0 or self._best_score < score: + self._best_score = score + self._best_epoch = self._epoch + self._best_reward = float(rew) + self._best_reward_std = rew_std + if self.config.save_best_fn: + self.config.save_best_fn(self.algorithm) cur_info, best_info = "", "" if score != rew: - cur_info, best_info = f", score: {score: .6f}", f", best_score: {self.best_score:.6f}" + cur_info, best_info = f", score: {score: .6f}", f", best_score: {self._best_score:.6f}" log_msg = ( - f"Epoch #{self.epoch}: test_reward: {rew:.6f} ± {rew_std:.6f},{cur_info}" - f" best_reward: {self.best_reward:.6f} ± " - f"{self.best_reward_std:.6f}{best_info} in #{self.best_epoch}" + f"Epoch #{self._epoch}: test_reward: {rew:.6f} ± {rew_std:.6f},{cur_info}" + f" best_reward: {self._best_reward:.6f} ± " + f"{self._best_reward_std:.6f}{best_info} in #{self._best_epoch}" ) log.info(log_msg) - if self.verbose: + if self.config.verbose: print(log_msg, 
flush=True) - if self.stop_fn and self.stop_fn(self.best_reward): + if self.config.stop_fn and self.config.stop_fn(self._best_reward): stop_fn_flag = True return test_stat, stop_fn_flag - def training_step(self) -> tuple[CollectStatsBase, TrainingStats | None, bool]: - """Perform one training iteration. + @abstractmethod + def _training_step(self) -> _TrainingStepResult: + """Performs one training step.""" - A training iteration includes collecting data (for online RL), determining whether to stop training, - and performing a policy update if the training iteration should continue. + # TODO: move moving average computation and logging into its own logger + # TODO: maybe think about a command line logger instead of always printing data dict + def _update_moving_avg_stats_and_log_update_data(self, update_stat: TrainingStats) -> None: + """Log losses, update moving average stats, and also modify the smoothed_loss in update_stat.""" + cur_losses_dict = update_stat.get_loss_stats_dict() + update_stat.smoothed_loss = self._update_moving_avg_stats_and_get_averaged_data( + cur_losses_dict, + ) + self._logger.log_update_data(asdict(update_stat), self._gradient_step) + + # TODO: seems convoluted, there should be a better way of dealing with the moving average stats + def _update_moving_avg_stats_and_get_averaged_data( + self, + data: dict[str, float], + ) -> dict[str, float]: + """Add entries to the moving average object in the trainer and retrieve the averaged results. + + :param data: any entries to be tracked in the moving average object. + :return: A dictionary containing the averaged values of the tracked entries. - :return: the iteration's collect stats, training stats, and a flag indicating whether to stop training. - If training is to be stopped, no gradient steps will be performed and the training stats will be `None`. 
""" - with policy_within_training_step(self.policy.policy): - should_stop_training = False - - collect_stats: CollectStatsBase | CollectStats - if self.train_collector is not None: - collect_stats = self._collect_training_data() - should_stop_training = self._update_best_reward_and_return_should_stop_training( - collect_stats, - ) - else: - assert self.buffer is not None, "Either train_collector or buffer must be provided." - collect_stats = CollectStatsBase( - n_collected_episodes=len(self.buffer), - ) + smoothed_data = {} + for key, loss_item in data.items(): + self._stat[key].add(loss_item) + smoothed_data[key] = self._stat[key].get() + return smoothed_data + + def run( + self, reset_collectors: bool = True, reset_collector_buffers: bool = False + ) -> InfoStats: + """Runs the training process with the configuration given at construction. + + :param reset_collectors: whether to reset the collectors prior to starting the training process. + Specifically, this will reset the environments in the collectors (starting new episodes), + and the statistics stored in the collector. Whether the contained buffers will be reset/cleared + is determined by the `reset_buffer` parameter. + :param reset_collector_buffers: whether, for the case where the collectors are reset, to reset/clear the + contained buffers as well. + This has no effect if `reset_collectors` is False. 
+ """ + self.reset( + reset_collectors=reset_collectors, reset_collector_buffers=reset_collector_buffers + ) + + while self._epoch < self.config.max_epoch and not self._stop_fn_flag: + self.execute_epoch() + + return gather_info( + start_time=self._start_time, + policy_update_time=self._policy_update_time, + gradient_step=self._gradient_step, + best_score=self._best_score, + best_reward=self._best_reward, + best_reward_std=self._best_reward_std, + train_collector=self.config.train_collector + if isinstance(self.config, OnlineTrainingConfig) + else None, + test_collector=self.config.test_collector, + ) + + +class OfflineTrainer(Trainer[OfflineTrainingConfig]): + """An offline trainer, which samples mini-batches from a given buffer and passes them to + the algorithm's update function. + """ + + def __init__( + self, + policy: "Algorithm", + config: OfflineTrainingConfig, + ): + super().__init__(policy, config) + self._buffer = policy.process_buffer(self.config.buffer) + + class _TrainingStepResult(Trainer._TrainingStepResult): + def __init__(self, training_stats: TrainingStats, env_step_advancement: int): + self._training_stats = training_stats + self._env_step_advancement = env_step_advancement + + def get_steps_in_epoch_advancement(self): + return 1 + + def get_collect_stats(self) -> None: + return None + + def get_training_stats(self) -> TrainingStats: + return self._training_stats + + def is_training_done(self) -> bool: + return False + + def get_env_step_advancement(self) -> int: + return self._env_step_advancement + + def _training_step(self) -> _TrainingStepResult: + with policy_within_training_step(self.algorithm.policy): + self._gradient_step += 1 + # Note: since sample_size=batch_size, this will perform + # exactly one gradient step. This is why we don't need to calculate the + # number of gradient steps, like in the on-policy case. 
+ training_stats = self.algorithm.update( + sample_size=self.config.batch_size, buffer=self._buffer + ) + self._update_moving_avg_stats_and_log_update_data(training_stats) + self._policy_update_time += training_stats.train_time + return self._TrainingStepResult( + training_stats=training_stats, env_step_advancement=self.config.batch_size + ) + + def _create_epoch_pbar_data_dict( + self, training_step_result: _TrainingStepResult + ) -> dict[str, str]: + return {} + + +class OnlineTrainer(Trainer[TOnlineTrainingConfig], Generic[TOnlineTrainingConfig], ABC): + """ + An online trainer, which collects data from the environment in each training step and + uses the collected data to perform an update step, the nature of which is to be defined + in subclasses. + """ + + def __init__( + self, + policy: "Algorithm", + config: OnlineTrainingConfig, + ): + super().__init__(policy, config) + self._env_episode = 0 + """ + the total number of episodes collected in the environment + """ + + def _reset_collectors(self, reset_buffer: bool = False) -> None: + super()._reset_collectors(reset_buffer=reset_buffer) + self.config.train_collector.reset(reset_buffer=reset_buffer) + + def reset(self, reset_collectors: bool = True, reset_collector_buffers: bool = False) -> None: + super().reset( + reset_collectors=reset_collectors, reset_collector_buffers=reset_collector_buffers + ) + + if ( + self.config.test_in_train + and self.config.train_collector.algorithm is not self.algorithm + ): + log.warning( + "The training data collector's algorithm is not the same as the one being trained, " + "yet test_in_train is enabled. This may lead to unexpected results." 
+ ) + + self._env_episode = 0 + + class _TrainingStepResult(Trainer._TrainingStepResult): + def __init__( + self, + collect_stats: CollectStats, + training_stats: TrainingStats | None, + is_training_done: bool, + ): + self._collect_stats = collect_stats + self._training_stats = training_stats + self._is_training_done = is_training_done + + def get_steps_in_epoch_advancement(self): + return self.get_env_step_advancement() + + def get_collect_stats(self) -> CollectStats: + return self._collect_stats + + def get_training_stats(self) -> TrainingStats | None: + return self._training_stats + + def is_training_done(self): + return self._is_training_done + + def get_env_step_advancement(self) -> int: + return self._collect_stats.n_collected_steps + + def _training_step(self) -> _TrainingStepResult: + """Perform one training step. + For an online algorithm, a training step involves: + * collecting data + * for the case where `test_in_train` is activated, + determining whether the stop condition has been reached + (and returning without performing any actual training if so) + * performing a gradient update step + """ + with policy_within_training_step(self.algorithm.policy): + # collect data + collect_stats = self._collect_training_data() + + # determine whether we should stop training based on the data collected + should_stop_training = self._test_in_train( + collect_stats, + ) + + # perform gradient update step (if not already done) + training_stats: TrainingStats | None = None if not should_stop_training: - training_stats = self.policy_update_fn(collect_stats) - else: - training_stats = None + training_stats = self._update_step(collect_stats) - return collect_stats, training_stats, should_stop_training + return self._TrainingStepResult( + collect_stats=collect_stats, + training_stats=training_stats, + is_training_done=should_stop_training, + ) def _collect_training_data(self) -> CollectStats: """Performs training data collection. 
:return: the data collection stats """ - assert self.episode_per_test is not None - assert self.train_collector is not None - if self.train_fn: - self.train_fn(self.epoch, self.env_step) - collect_stats = self.train_collector.collect( - n_step=self.step_per_collect, - n_episode=self.episode_per_collect, + assert self.config.episode_per_test is not None + assert self.config.train_collector is not None + + if self.config.train_fn: + self.config.train_fn(self._epoch, self._env_step) + + collect_stats = self.config.train_collector.collect( + n_step=self.config.step_per_collect, + n_episode=self.config.episode_per_collect, ) - if self.train_collector.buffer.hasnull(): + if self.config.train_collector.buffer.hasnull(): from tianshou.data.collector import EpisodeRolloutHook from tianshou.env import DummyVectorEnv raise MalformedBufferError( - f"Encountered NaNs in buffer after {self.env_step} steps." + f"Encountered NaNs in buffer after {self._env_step} steps." f"Such errors are usually caused by either a bug in the environment or by " f"problematic implementations {EpisodeRolloutHook.__class__.__name__}. 
" f"For debugging such issues it is recommended to run the training in a single process, " f"e.g., by using {DummyVectorEnv.__class__.__name__}.", ) - self.env_step += collect_stats.n_collected_steps - self.env_episode += collect_stats.n_collected_episodes - if collect_stats.n_collected_episodes > 0: assert collect_stats.returns_stat is not None # for mypy assert collect_stats.lens_stat is not None # for mypy - self.last_rew = collect_stats.returns_stat.mean - self.last_len = collect_stats.lens_stat.mean - if self.reward_metric: # TODO: move inside collector - rew = self.reward_metric(collect_stats.returns) + if self.config.reward_metric: # TODO: move inside collector + rew = self.config.reward_metric(collect_stats.returns) collect_stats.returns = rew collect_stats.returns_stat = SequenceSummaryStats.from_sequence(rew) - self.logger.log_train_data(asdict(collect_stats), self.env_step) + # update collection stats specific to this specialization + self._env_episode += collect_stats.n_collected_episodes + return collect_stats - # TODO (maybe): separate out side effect, simplify name? - def _update_best_reward_and_return_should_stop_training( + def _test_in_train( self, collect_stats: CollectStats, ) -> bool: - """If `test_in_train` and `stop_fn` are set, will compute the `stop_fn` on the mean return of the training data. - Then, if the `stop_fn` is True there, will collect test data also compute the stop_fn of the mean return - on it. - Finally, if the latter is also True, will return True. - - **NOTE:** has a side effect of updating the best reward and corresponding std. 
- - - :param collect_stats: the data collection stats - :return: flag indicating whether to stop training + """ + Performs performance testing based on the early stopping criterion being satisfied based on the + data collected in the current training step: + If the stop criterion is satisfied, it collects `episode_per_test` test episodes (as in a test step) + and determines whether the stop criterion is also satisfied by the episodes thus collected, + and if so, returns True, indicating that training stops early. + + Therefore, if the early stopping criterion is satisfied on the data collected for training, + this effectively carries out a test step and updates the respective metrics (best_reward, etc.) + accordingly. + + :param collect_stats: the data collection stats from the preceding collection step + :return: flag indicating whether to stop training early """ should_stop_training = False # Because we need to evaluate the policy, we need to temporarily leave the "is_training_step" semantics - with policy_within_training_step(self.policy.policy, enabled=False): + with policy_within_training_step(self.algorithm.policy, enabled=False): if ( collect_stats.n_collected_episodes > 0 - and self.test_in_train - and self.stop_fn - and self.stop_fn(collect_stats.returns_stat.mean) # type: ignore + and self.config.test_in_train + and self.config.stop_fn + and self.config.stop_fn(collect_stats.returns_stat.mean) # type: ignore ): - assert self.test_collector is not None - assert self.episode_per_test is not None and self.episode_per_test > 0 + assert self.config.test_collector is not None + assert self.config.episode_per_test is not None and self.config.episode_per_test > 0 test_result = test_episode( - self.test_collector, - self.test_fn, - self.epoch, - self.episode_per_test, - self.logger, - self.env_step, + self.config.test_collector, + self.config.test_fn, + self._epoch, + self.config.episode_per_test, + self._logger, + self._env_step, ) assert test_result.returns_stat 
is not None # for mypy - if self.stop_fn(test_result.returns_stat.mean): + if self.config.stop_fn(test_result.returns_stat.mean): should_stop_training = True - self.best_reward = test_result.returns_stat.mean - self.best_reward_std = test_result.returns_stat.std - self.best_score = self.compute_score_fn(test_result) + self._best_reward = test_result.returns_stat.mean + self._best_reward_std = test_result.returns_stat.std + self._best_score = self._compute_score_fn(test_result) return should_stop_training - # TODO: move moving average computation and logging into its own logger - # TODO: maybe think about a command line logger instead of always printing data dict - def _update_moving_avg_stats_and_log_update_data(self, update_stat: TrainingStats) -> None: - """Log losses, update moving average stats, and also modify the smoothed_loss in update_stat.""" - cur_losses_dict = update_stat.get_loss_stats_dict() - update_stat.smoothed_loss = self._update_moving_avg_stats_and_get_averaged_data( - cur_losses_dict, - ) - self.logger.log_update_data(asdict(update_stat), self._gradient_step) - - # TODO: seems convoluted, there should be a better way of dealing with the moving average stats - def _update_moving_avg_stats_and_get_averaged_data( - self, - data: dict[str, float], - ) -> dict[str, float]: - """Add entries to the moving average object in the trainer and retrieve the averaged results. - - :param data: any entries to be tracked in the moving average object. - :return: A dictionary containing the averaged values of the tracked entries. - - """ - smoothed_data = {} - for key, loss_item in data.items(): - self.stat[key].add(loss_item) - smoothed_data[key] = self.stat[key].get() - return smoothed_data - @abstractmethod - def policy_update_fn( + def _update_step( self, collect_stats: CollectStatsBase, ) -> TrainingStats: - """Policy update function for different trainer implementation. + """Performs a gradient update step, calling the algorithm's update method accordingly. 
- :param collect_stats: provides info about the most recent collection. In the offline case, this will contain - stats of the whole dataset + :param collect_stats: provides info about the preceding data collection step. """ - def run(self, reset_prior_to_run: bool = True, reset_buffer: bool = False) -> InfoStats: - """Consume iterator. - - See itertools - recipes. Use functions that consume iterators at C speed - (feed the entire iterator into a zero-length deque). - - :param reset_prior_to_run: whether to reset collectors prior to run - :param reset_buffer: only has effect if `reset_prior_to_run` is True. - Then it will also reset the buffer. This is usually not necessary, use - with caution. - """ - if reset_prior_to_run: - self.reset(reset_buffer=reset_buffer) - try: - self.is_run = True - deque(self, maxlen=0) # feed the entire iterator into a zero-length deque - info = gather_info( - start_time=self.start_time, - policy_update_time=self.policy_update_time, - gradient_step=self._gradient_step, - best_score=self.best_score, - best_reward=self.best_reward, - best_reward_std=self.best_reward_std, - train_collector=self.train_collector, - test_collector=self.test_collector, + def _create_epoch_pbar_data_dict( + self, training_step_result: _TrainingStepResult + ) -> dict[str, str]: + collect_stats = training_step_result.get_collect_stats() + result = { + "env_step": str(self._env_step), + "env_episode": str(self._env_episode), + "n_ep": str(collect_stats.n_collected_episodes), + "n_st": str(collect_stats.n_collected_steps), + } + # return and episode length info is only available if at least one episode was completed + if collect_stats.n_collected_episodes > 0: + result.update( + { + "rew": f"{collect_stats.returns_stat.mean:.2f}", + "len": str(int(collect_stats.lens_stat.mean)), + } ) - finally: - self.is_run = False - - return info + return result -class OfflineTrainer(BaseTrainer[OfflineTrainingConfig]): - """Offline trainer, samples mini-batches from buffer and 
passes them to update. +class OffPolicyTrainer(OnlineTrainer[OffPolicyTrainingConfig]): + """An off-policy trainer, which samples mini-batches from the buffer of collected data and passes them to + algorithm's `update` function. - Uses a buffer directly and usually does not have a collector. + The algorithm's `update` method is expected to not perform additional mini-batching but just update + model parameters from the received mini-batch. """ - # for mypy - assert isinstance(BaseTrainer.__doc__, str) - __doc__ += BaseTrainer.gen_doc("offline") + "\n".join(BaseTrainer.__doc__.split("\n")[1:]) - - def policy_update_fn( - self, - collect_stats: CollectStatsBase | None = None, - ) -> TrainingStats: - """Perform one off-line policy update.""" - assert self.buffer - update_stat = self._sample_and_update(self.buffer) - # logging - self.policy_update_time += update_stat.train_time - return update_stat - - def _sample_and_update(self, buffer: ReplayBuffer) -> TrainingStats: - """Sample a mini-batch, perform one gradient step, and update the _gradient_step counter.""" - self._gradient_step += 1 - # Note: since sample_size=batch_size, this will perform - # exactly one gradient step. This is why we don't need to calculate the - # number of gradient steps, like in the on-policy case. - update_stat = self.policy.update(sample_size=self.config.batch_size, buffer=buffer) - self._update_moving_avg_stats_and_log_update_data(update_stat) - return update_stat - - -class OffPolicyTrainer(BaseTrainer[OffPolicyTrainingConfig]): - """Offpolicy trainer, samples mini-batches from buffer and passes them to update. - - Note that with this trainer, it is expected that the policy's `learn` method - does not perform additional mini-batching but just updates params from the received - mini-batch. 
- """ - - # for mypy - assert isinstance(BaseTrainer.__doc__, str) - __doc__ += BaseTrainer.gen_doc("offpolicy") + "\n".join(BaseTrainer.__doc__.split("\n")[1:]) - - def policy_update_fn( + def _update_step( self, # TODO: this is the only implementation where collect_stats is actually needed. Maybe change interface? collect_stats: CollectStatsBase, @@ -807,7 +912,7 @@ def policy_update_fn( :param collect_stats: the :class:`~TrainingStats` instance returned by the last gradient step. Some values in it will be replaced by their moving averages. """ - assert self.train_collector is not None + assert self.config.train_collector is not None n_collected_steps = collect_stats.n_collected_steps n_gradient_steps = round(self.config.update_per_step * n_collected_steps) if n_gradient_steps == 0: @@ -816,14 +921,16 @@ def policy_update_fn( f"update_per_step={self.config.update_per_step}", ) + update_stat = None for _ in self._pbar( range(n_gradient_steps), desc="Offpolicy gradient update", position=0, leave=False, ): - update_stat = self._sample_and_update(self.train_collector.buffer) - self.policy_update_time += update_stat.train_time + update_stat = self._sample_and_update(self.config.train_collector.buffer) + self._policy_update_time += update_stat.train_time + # TODO: only the last update_stat is returned, should be improved return update_stat @@ -833,36 +940,32 @@ def _sample_and_update(self, buffer: ReplayBuffer) -> TrainingStats: # Note: since sample_size=batch_size, this will perform # exactly one gradient step. This is why we don't need to calculate the # number of gradient steps, like in the on-policy case. 
- update_stat = self.policy.update(sample_size=self.config.batch_size, buffer=buffer) + update_stat = self.algorithm.update(sample_size=self.config.batch_size, buffer=buffer) self._update_moving_avg_stats_and_log_update_data(update_stat) return update_stat -class OnPolicyTrainer(BaseTrainer[OnPolicyTrainingConfig]): - """On-policy trainer, passes the entire buffer to .update and resets it after. +class OnPolicyTrainer(OnlineTrainer[OnPolicyTrainingConfig]): + """An on-policy trainer, which passes the entire buffer to the algorithm's `update` methods and + resets the buffer thereafter. - Note that it is expected that the learn method of a policy will perform + Note that it is expected that the update method of the algorithm will perform batching when using this trainer. """ - # for mypy - assert isinstance(BaseTrainer.__doc__, str) - __doc__ = BaseTrainer.gen_doc("onpolicy") + "\n".join(BaseTrainer.__doc__.split("\n")[1:]) - - def policy_update_fn( + def _update_step( self, result: CollectStatsBase | None = None, ) -> TrainingStats: """Perform one on-policy update by passing the entire buffer to the policy's update method.""" - assert self.train_collector is not None - # TODO: add logging like in off-policy. Iteration over minibatches currently happens in the learn implementation of - # on-policy algos like PG or PPO + assert self.config.train_collector is not None + # TODO: add logging like in off-policy. Iteration over minibatches currently happens in the algorithms themselves. log.info( - f"Performing on-policy update on buffer of length {len(self.train_collector.buffer)}", + f"Performing on-policy update on buffer of length {len(self.config.train_collector.buffer)}", ) - training_stat = self.policy.update( + training_stat = self.algorithm.update( sample_size=0, - buffer=self.train_collector.buffer, + buffer=self.config.train_collector.buffer, # Note: sample_size is None, so the whole buffer is used for the update. 
# The kwargs are in the end passed to the .learn method, which uses # batch_size to iterate through the buffer in mini-batches @@ -872,7 +975,7 @@ def policy_update_fn( ) # just for logging, no functional role - self.policy_update_time += training_stat.train_time + self._policy_update_time += training_stat.train_time # TODO: remove the gradient step counting in trainers? Doesn't seem like # it's important and it adds complexity self._gradient_step += 1 @@ -880,7 +983,7 @@ def policy_update_fn( self._gradient_step += 1 elif self.config.batch_size > 0: self._gradient_step += int( - (len(self.train_collector.buffer) - 0.1) // self.config.batch_size, + (len(self.config.train_collector.buffer) - 0.1) // self.config.batch_size, ) # Note 1: this is the main difference to the off-policy trainer! @@ -892,7 +995,7 @@ def policy_update_fn( # _ep_rew and _ep_len. This means that such quantities can no longer be computed # from samples still contained in the buffer, which is also not clean # TODO: improve this situation - self.train_collector.reset_buffer(keep_statistics=True) + self.config.train_collector.reset_buffer(keep_statistics=True) # The step is the number of mini-batches used for the update, so essentially self._update_moving_avg_stats_and_log_update_data(training_stat) diff --git a/tianshou/trainer/utils.py b/tianshou/trainer/utils.py index 1f4369f72..b3e14350f 100644 --- a/tianshou/trainer/utils.py +++ b/tianshou/trainer/utils.py @@ -13,7 +13,10 @@ from tianshou.data.collector import BaseCollector from tianshou.utils import BaseLogger +# TODO: This module should be eliminated: Move methods into Trainer + +# TODO: Improve name def test_episode( collector: BaseCollector, test_fn: Callable[[int, int | None], None] | None, From 31280fba9d395e9596dcbb4172d26f137707d68b Mon Sep 17 00:00:00 2001 From: Dominik Jain Date: Sat, 8 Mar 2025 13:03:37 +0100 Subject: [PATCH 037/230] v2: Adapt BCQ and test_bcq (identified some issues; see TODOs) --- examples/offline/d4rl_bcq.py | 4 
+- test/offline/gather_pendulum_data.py | 46 ++++--- test/offline/test_bcq.py | 55 ++++---- tianshou/policy/__init__.py | 4 +- tianshou/policy/base.py | 11 ++ tianshou/policy/imitation/bcq.py | 185 +++++++++++++++------------ 6 files changed, 169 insertions(+), 136 deletions(-) diff --git a/examples/offline/d4rl_bcq.py b/examples/offline/d4rl_bcq.py index 37cf18446..2da381f2e 100644 --- a/examples/offline/d4rl_bcq.py +++ b/examples/offline/d4rl_bcq.py @@ -13,7 +13,7 @@ from examples.offline.utils import load_buffer_d4rl from tianshou.data import Collector, CollectStats from tianshou.env import SubprocVectorEnv -from tianshou.policy import BCQPolicy +from tianshou.policy import BCQ from tianshou.policy.base import Algorithm from tianshou.trainer import OfflineTrainer from tianshou.utils import TensorboardLogger, WandbLogger @@ -152,7 +152,7 @@ def test_bcq() -> None: ).to(args.device) vae_optim = torch.optim.Adam(vae.parameters()) - policy: BCQPolicy = BCQPolicy( + policy: BCQ = BCQ( actor_perturbation=actor, actor_perturbation_optim=actor_optim, critic=critic1, diff --git a/test/offline/gather_pendulum_data.py b/test/offline/gather_pendulum_data.py index a698530a6..4aafb88be 100644 --- a/test/offline/gather_pendulum_data.py +++ b/test/offline/gather_pendulum_data.py @@ -11,8 +11,8 @@ from tianshou.env import DummyVectorEnv from tianshou.policy import SAC from tianshou.policy.base import Algorithm -from tianshou.policy.modelfree.sac import SACTrainingStats -from tianshou.trainer import OffpolicyTrainer +from tianshou.policy.modelfree.sac import AutoAlpha, SACPolicy, SACTrainingStats +from tianshou.trainer.base import OffPolicyTrainingConfig from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Net from tianshou.utils.net.continuous import ActorProb, Critic @@ -114,10 +114,14 @@ def gather_data() -> VectorReplayBuffer: target_entropy = -action_dim log_alpha = torch.zeros(1, requires_grad=True, device=args.device) alpha_optim = 
torch.optim.Adam([log_alpha], lr=args.alpha_lr) - args.alpha = (target_entropy, log_alpha, alpha_optim) + args.alpha = AutoAlpha(target_entropy, log_alpha, alpha_optim) - policy: SAC[SACTrainingStats] = SAC( + policy = SACPolicy( actor=actor, + action_space=env.action_space, + ) + algorithm: SAC[SACTrainingStats] = SAC( + policy=policy, policy_optim=actor_optim, critic=critic, critic_optim=critic_optim, @@ -125,12 +129,11 @@ def gather_data() -> VectorReplayBuffer: gamma=args.gamma, alpha=args.alpha, estimation_step=args.n_step, - action_space=env.action_space, ) # collector buffer = VectorReplayBuffer(args.buffer_size, len(train_envs)) - train_collector = Collector[CollectStats](policy, train_envs, buffer, exploration_noise=True) - test_collector = Collector[CollectStats](policy, test_envs) + train_collector = Collector[CollectStats](algorithm, train_envs, buffer, exploration_noise=True) + test_collector = Collector[CollectStats](algorithm, test_envs) # train_collector.collect(n_step=args.buffer_size) # log log_path = os.path.join(args.logdir, args.task, "sac") @@ -144,20 +147,21 @@ def stop_fn(mean_rewards: float) -> bool: return mean_rewards >= args.reward_threshold # trainer - OffpolicyTrainer( - policy=policy, - train_collector=train_collector, - test_collector=test_collector, - max_epoch=args.epoch, - step_per_epoch=args.step_per_epoch, - step_per_collect=args.step_per_collect, - episode_per_test=args.test_num, - batch_size=args.batch_size, - update_per_step=args.update_per_step, - save_best_fn=save_best_fn, - stop_fn=stop_fn, - logger=logger, - ).run() + algorithm.run_training( + OffPolicyTrainingConfig( + train_collector=train_collector, + test_collector=test_collector, + max_epoch=args.epoch, + step_per_epoch=args.step_per_epoch, + step_per_collect=args.step_per_collect, + episode_per_test=args.test_num, + batch_size=args.batch_size, + update_per_step=args.update_per_step, + save_best_fn=save_best_fn, + stop_fn=stop_fn, + logger=logger, + ) + ) 
train_collector.reset() collector_stats = train_collector.collect(n_step=args.buffer_size) print(collector_stats) diff --git a/test/offline/test_bcq.py b/test/offline/test_bcq.py index dfa1dfe50..6c87f64c6 100644 --- a/test/offline/test_bcq.py +++ b/test/offline/test_bcq.py @@ -11,9 +11,9 @@ from tianshou.data import Collector, CollectStats, VectorReplayBuffer from tianshou.env import DummyVectorEnv -from tianshou.policy import Algorithm, BCQPolicy -from tianshou.policy.imitation.bcq import BCQTrainingStats -from tianshou.trainer import OfflineTrainer +from tianshou.policy import BCQ, Algorithm +from tianshou.policy.imitation.bcq import BCQPolicy, BCQTrainingStats +from tianshou.trainer.base import OfflineTrainingConfig from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import MLP, Net from tianshou.utils.net.continuous import VAE, Critic, Perturbation @@ -143,29 +143,31 @@ def test_bcq(args: argparse.Namespace = get_args()) -> None: ).to(args.device) vae_optim = torch.optim.Adam(vae.parameters()) - policy: BCQPolicy[BCQTrainingStats] = BCQPolicy( + policy = BCQPolicy( actor_perturbation=actor, - actor_perturbation_optim=actor_optim, critic=critic, - critic_optim=critic_optim, vae=vae, - vae_optim=vae_optim, action_space=env.action_space, - device=args.device, + ) + algorithm: BCQ[BCQTrainingStats] = BCQ( + policy=policy, + actor_perturbation_optim=actor_optim, + critic_optim=critic_optim, + vae_optim=vae_optim, gamma=args.gamma, tau=args.tau, lmbda=args.lmbda, - ) + ).to(args.device) # load a previous policy if args.resume_path: - policy.load_state_dict(torch.load(args.resume_path, map_location=args.device)) + algorithm.load_state_dict(torch.load(args.resume_path, map_location=args.device)) print("Loaded agent from: ", args.resume_path) # collector # buffer has been gathered # train_collector = Collector[CollectStats](policy, train_envs, buffer, exploration_noise=True) - test_collector = Collector[CollectStats](policy, test_envs) + 
test_collector = Collector[CollectStats](algorithm, test_envs) # log t0 = datetime.datetime.now().strftime("%m%d_%H%M%S") log_file = f'seed_{args.seed}_{t0}-{args.task.replace("-", "_")}_bcq' @@ -181,24 +183,25 @@ def stop_fn(mean_rewards: float) -> bool: return mean_rewards >= args.reward_threshold def watch() -> None: - policy.load_state_dict( + algorithm.load_state_dict( torch.load(os.path.join(log_path, "policy.pth"), map_location=torch.device("cpu")), ) - collector = Collector[CollectStats](policy, env) + collector = Collector[CollectStats](algorithm, env) collector.collect(n_episode=1, render=1 / 35) # trainer - result = OfflineTrainer( - policy=policy, - buffer=buffer, - test_collector=test_collector, - max_epoch=args.epoch, - step_per_epoch=args.step_per_epoch, - episode_per_test=args.test_num, - batch_size=args.batch_size, - save_best_fn=save_best_fn, - stop_fn=stop_fn, - logger=logger, - show_progress=args.show_progress, - ).run() + result = algorithm.run_training( + OfflineTrainingConfig( + buffer=buffer, + test_collector=test_collector, + max_epoch=args.epoch, + step_per_epoch=args.step_per_epoch, + episode_per_test=args.test_num, + batch_size=args.batch_size, + save_best_fn=save_best_fn, + stop_fn=stop_fn, + logger=logger, + show_progress=args.show_progress, + ) + ) assert stop_fn(result.best_reward) diff --git a/tianshou/policy/__init__.py b/tianshou/policy/__init__.py index 86925062a..b8ee50547 100644 --- a/tianshou/policy/__init__.py +++ b/tianshou/policy/__init__.py @@ -22,7 +22,7 @@ from tianshou.policy.modelfree.redq import REDQ from tianshou.policy.modelfree.discrete_sac import DiscreteSAC from tianshou.policy.imitation.base import ImitationLearning -from tianshou.policy.imitation.bcq import BCQPolicy +from tianshou.policy.imitation.bcq import BCQ from tianshou.policy.imitation.cql import CQLPolicy from tianshou.policy.imitation.td3_bc import TD3BCPolicy from tianshou.policy.imitation.discrete_bcq import DiscreteBCQPolicy @@ -54,7 +54,7 @@ 
"REDQ", "DiscreteSAC", "ImitationLearning", - "BCQPolicy", + "BCQ", "CQLPolicy", "TD3BCPolicy", "DiscreteBCQPolicy", diff --git a/tianshou/policy/base.py b/tianshou/policy/base.py index d7d60d491..5b6d20a00 100644 --- a/tianshou/policy/base.py +++ b/tianshou/policy/base.py @@ -805,6 +805,17 @@ def create_trainer(self, config: "OffPolicyTrainingConfig") -> "OffPolicyTrainer return OffPolicyTrainer(self, config) +class OfflineAlgorithm( + Algorithm[TPolicy, "OfflineTrainingConfig", TTrainingStats], + Generic[TPolicy, TTrainingStats], + ABC, +): + def create_trainer(self, config: "OfflineTrainingConfig") -> "OfflineTrainer": + from tianshou.trainer.base import OfflineTrainer + + return OfflineTrainer(self, config) + + # TODO must become Policy not Algorithm class RandomActionPolicy(Algorithm): def __init__( diff --git a/tianshou/policy/imitation/bcq.py b/tianshou/policy/imitation/bcq.py index 6d42f29bb..267f45417 100644 --- a/tianshou/policy/imitation/bcq.py +++ b/tianshou/policy/imitation/bcq.py @@ -10,8 +10,12 @@ from tianshou.data import Batch, to_torch from tianshou.data.batch import BatchProtocol from tianshou.data.types import ActBatchProtocol, ObsBatchProtocol, RolloutBatchProtocol -from tianshou.policy import Algorithm -from tianshou.policy.base import TLearningRateScheduler, TrainingStats +from tianshou.policy.base import ( + OfflineAlgorithm, + Policy, + TLearningRateScheduler, + TrainingStats, +) from tianshou.utils.net.continuous import VAE from tianshou.utils.optim import clone_optimizer @@ -27,100 +31,41 @@ class BCQTrainingStats(TrainingStats): TBCQTrainingStats = TypeVar("TBCQTrainingStats", bound=BCQTrainingStats) -class BCQPolicy(Algorithm, Generic[TBCQTrainingStats]): - """Implementation of BCQ algorithm. arXiv:1812.02900. - - :param actor_perturbation: the actor perturbation. `(s, a -> perturbed a)` - :param actor_perturbation_optim: the optimizer for actor network. - :param critic: the first critic network. 
- :param critic_optim: the optimizer for the first critic network. - :param critic2: the second critic network. - :param critic2_optim: the optimizer for the second critic network. - :param vae: the VAE network, generating actions similar to those in batch. - :param vae_optim: the optimizer for the VAE network. - :param device: which device to create this model on. - :param gamma: discount factor, in [0, 1]. - :param tau: param for soft update of the target network. - :param lmbda: param for Clipped Double Q-learning. - :param forward_sampled_times: the number of sampled actions in forward function. - The policy samples many actions and takes the action with the max value. - :param num_sampled_action: the number of sampled actions in calculating target Q. - The algorithm samples several actions using VAE, and perturbs each action to get the target Q. - :param observation_space: Env's observation space. - :param action_scaling: if True, scale the action from [-1, 1] to the range - of action_space. Only used if the action_space is continuous. - :param action_bound_method: method to bound action to range [-1, 1]. - Only used if the action_space is continuous. - :param lr_scheduler: if not None, will be called in `policy.update()`. - - .. seealso:: - - Please refer to :class:`~tianshou.policy.BasePolicy` for more detailed explanation. - """ - +class BCQPolicy(Policy): def __init__( self, *, actor_perturbation: torch.nn.Module, - actor_perturbation_optim: torch.optim.Optimizer, - critic: torch.nn.Module, - critic_optim: torch.optim.Optimizer, action_space: gym.Space, + critic: torch.nn.Module, vae: VAE, - vae_optim: torch.optim.Optimizer, - critic2: torch.nn.Module | None = None, - critic2_optim: torch.optim.Optimizer | None = None, - # TODO: remove? 
Many policies don't use this - device: str | torch.device = "cpu", - gamma: float = 0.99, - tau: float = 0.005, - lmbda: float = 0.75, forward_sampled_times: int = 100, - num_sampled_action: int = 10, observation_space: gym.Space | None = None, action_scaling: bool = False, action_bound_method: Literal["clip", "tanh"] | None = "clip", - lr_scheduler: TLearningRateScheduler | None = None, ) -> None: - # actor is Perturbation! + """ + :param actor_perturbation: the actor perturbation. `(s, a -> perturbed a)` + :param critic: the first critic network. + :param vae: the VAE network, generating actions similar to those in batch. + :param forward_sampled_times: the number of sampled actions in forward function. + The policy samples many actions and takes the action with the max value. + :param observation_space: Env's observation space. + :param action_scaling: if True, scale the action from [-1, 1] to the range + of action_space. Only used if the action_space is continuous. + :param action_bound_method: method to bound action to range [-1, 1]. + Only used if the action_space is continuous. 
+ """ super().__init__( action_space=action_space, observation_space=observation_space, action_scaling=action_scaling, action_bound_method=action_bound_method, - lr_scheduler=lr_scheduler, ) self.actor_perturbation = actor_perturbation - self.actor_perturbation_target = copy.deepcopy(self.actor_perturbation) - self.actor_perturbation_optim = actor_perturbation_optim - self.critic = critic - self.critic_target = copy.deepcopy(self.critic) - self.critic_optim = critic_optim - - critic2 = critic2 or copy.deepcopy(critic) - critic2_optim = critic2_optim or clone_optimizer(critic_optim, critic2.parameters()) - self.critic2 = critic2 - self.critic2_target = copy.deepcopy(self.critic2) - self.critic2_optim = critic2_optim - self.vae = vae - self.vae_optim = vae_optim - - self.gamma = gamma - self.tau = tau - self.lmbda = lmbda - self.device = device self.forward_sampled_times = forward_sampled_times - self.num_sampled_action = num_sampled_action - - def train(self, mode: bool = True) -> Self: - """Set the module in training mode, except for the target network.""" - self.training = mode - self.actor_perturbation.train(mode) - self.critic.train(mode) - self.critic2.train(mode) - return self def forward( self, @@ -131,7 +76,8 @@ def forward( """Compute action over the given batch data.""" # There is "obs" in the Batch # obs_group: several groups. Each group has a state. - obs_group: torch.Tensor = to_torch(batch.obs, device=self.device) + device = next(self.parameters()).device + obs_group: torch.Tensor = to_torch(batch.obs, device=device) act_group = [] for obs_orig in obs_group: # now obs is (state_dim) @@ -148,12 +94,78 @@ def forward( act_group = np.array(act_group) return cast(ActBatchProtocol, Batch(act=act_group)) + +class BCQ(OfflineAlgorithm[BCQPolicy, TBCQTrainingStats], Generic[TBCQTrainingStats]): + """Implementation of Batch-Constrained Deep Q-learning (BCQ) algorithm. 
arXiv:1812.02900.""" + + def __init__( + self, + *, + policy: BCQPolicy, + actor_perturbation_optim: torch.optim.Optimizer, + critic_optim: torch.optim.Optimizer, + vae_optim: torch.optim.Optimizer, + critic2: torch.nn.Module | None = None, + critic2_optim: torch.optim.Optimizer | None = None, + gamma: float = 0.99, + tau: float = 0.005, + lmbda: float = 0.75, + num_sampled_action: int = 10, + lr_scheduler: TLearningRateScheduler | None = None, + ) -> None: + """ + :param policy: the policy + :param actor_perturbation_optim: the optimizer for the policy's actor perturbation network. + :param critic_optim: the optimizer for the policy's critic network. + :param critic2: the second critic network; if None, clone the critic from the policy + :param critic2_optim: the optimizer for the second critic network; if None, clone optimizer of first critic + :param vae_optim: the optimizer for the VAE network. + :param gamma: discount factor, in [0, 1]. + :param tau: param for soft update of the target network. + :param lmbda: param for Clipped Double Q-learning. + :param num_sampled_action: the number of sampled actions in calculating target Q. + The algorithm samples several actions using VAE, and perturbs each action to get the target Q. + :param lr_scheduler: if not None, will be called in `policy.update()`. + """ + # actor is Perturbation! 
+ super().__init__( + policy=policy, + lr_scheduler=lr_scheduler, + ) + self.actor_perturbation_target = copy.deepcopy(self.policy.actor_perturbation) + self.actor_perturbation_optim = actor_perturbation_optim + + self.critic_target = copy.deepcopy(self.policy.critic) + self.critic_optim = critic_optim + + critic2 = critic2 or copy.deepcopy(self.policy.critic) + critic2_optim = critic2_optim or clone_optimizer(critic_optim, critic2.parameters()) + self.critic2 = critic2 + self.critic2_target = copy.deepcopy(self.critic2) + self.critic2_optim = critic2_optim + + self.vae_optim = vae_optim + + self.gamma = gamma + self.tau = tau + self.lmbda = lmbda + self.num_sampled_action = num_sampled_action + + def train(self, mode: bool = True) -> Self: + """Set the module in training mode, except for the target network.""" + # TODO: vae is not considered; this is probably a bug! + self.training = mode + self.policy.actor_perturbation.train(mode) + self.policy.critic.train(mode) + self.critic2.train(mode) + return self + def sync_weight(self) -> None: """Soft-update the weight for the target network.""" - self._polyak_parameter_update(self.critic_target, self.critic, self.tau) + self._polyak_parameter_update(self.critic_target, self.policy.critic, self.tau) self._polyak_parameter_update(self.critic2_target, self.critic2, self.tau) self._polyak_parameter_update( - self.actor_perturbation_target, self.actor_perturbation, self.tau + self.actor_perturbation_target, self.policy.actor_perturbation, self.tau ) def _update_with_batch( @@ -164,12 +176,15 @@ def _update_with_batch( ) -> TBCQTrainingStats: # batch: obs, act, rew, done, obs_next. 
(numpy array) # (batch_size, state_dim) - batch: Batch = to_torch(batch, dtype=torch.float, device=self.device) + # TODO: This does not use policy.forward but computes things directly, which seems odd + + device = next(self.parameters()).device + batch: Batch = to_torch(batch, dtype=torch.float, device=device) obs, act = batch.obs, batch.act batch_size = obs.shape[0] # mean, std: (state.shape[0], latent_dim) - recon, mean, std = self.vae(obs, act) + recon, mean, std = self.policy.vae(obs, act) recon_loss = F.mse_loss(act, recon) # (....) is D_KL( N(mu, sigma) || N(0,1) ) KL_loss = (-torch.log(std) + (std.pow(2) + mean.pow(2) - 1) / 2).mean() @@ -186,7 +201,7 @@ def _update_with_batch( # now obs_next: (num_sampled_action * batch_size, state_dim) # perturbed action generated by VAE - act_next = self.vae.decode(obs_next) + act_next = self.policy.vae.decode(obs_next) # now obs_next: (num_sampled_action * batch_size, action_dim) target_Q1 = self.critic_target(obs_next, act_next) target_Q2 = self.critic2_target(obs_next, act_next) @@ -208,7 +223,7 @@ def _update_with_batch( ) target_Q = target_Q.float() - current_Q1 = self.critic(obs, act) + current_Q1 = self.policy.critic(obs, act) current_Q2 = self.critic2(obs, act) critic1_loss = F.mse_loss(current_Q1, target_Q) @@ -221,11 +236,11 @@ def _update_with_batch( self.critic_optim.step() self.critic2_optim.step() - sampled_act = self.vae.decode(obs) - perturbed_act = self.actor_perturbation(obs, sampled_act) + sampled_act = self.policy.vae.decode(obs) + perturbed_act = self.policy.actor_perturbation(obs, sampled_act) # max - actor_loss = -self.critic(obs, perturbed_act).mean() + actor_loss = -self.policy.critic(obs, perturbed_act).mean() self.actor_perturbation_optim.zero_grad() actor_loss.backward() From 319469ba47d6d0d6dc00443759955611f8744bd9 Mon Sep 17 00:00:00 2001 From: Dominik Jain Date: Mon, 10 Mar 2025 12:52:15 +0100 Subject: [PATCH 038/230] v2: Move the functions gather_info and test_episode from trainer.utils to 
Trainer The functions are now static methods; code is unchanged. These functions are strongly tied to the trainer and were not used elsewhere. The module trainer.utils was removed. --- CHANGELOG.md | 5 +- tianshou/trainer/__init__.py | 1 - tianshou/trainer/base.py | 85 +++++++++++++++++++++++++++++++--- tianshou/trainer/utils.py | 90 ------------------------------------ 4 files changed, 83 insertions(+), 98 deletions(-) delete mode 100644 tianshou/trainer/utils.py diff --git a/CHANGELOG.md b/CHANGELOG.md index dca302677..812cecd0b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,10 +12,13 @@ precisely the options that are applicable. * The interface has been streamlined with improved naming of functions/parameters and limiting the public interface to purely the methods and attributes a user should reasonably access. - * Further changes affecting usage: + * Further changes potentially affecting usage: * We dropped the iterator semantics: Method `__next__` has been replaced by `execute_epoch`. * We no longer report outdated statistics (e.g. 
on rewards/returns when a training step does not collect any full episodes) + * See also "Issues resolved" below (as issue resolution can result in usage changes) + * Further internal changes unlikely to affect usage: + * Module `trainer.utils` was removed and the functions therein where moved to class `Trainer` * Issues resolved: * Methods `run` and `reset`: Parameter `reset_prior_to_run` of `run` was never respected if it was set to `False`, because the implementation of `__iter__` (now removed) would call `reset` regardless - and calling `reset` diff --git a/tianshou/trainer/__init__.py b/tianshou/trainer/__init__.py index 426f13080..065af3e99 100644 --- a/tianshou/trainer/__init__.py +++ b/tianshou/trainer/__init__.py @@ -9,7 +9,6 @@ OnPolicyTrainingConfig, Trainer, ) -from tianshou.trainer.utils import gather_info, test_episode __all__ = [ "Trainer", diff --git a/tianshou/trainer/base.py b/tianshou/trainer/base.py index 3c9b5723d..d2d1395e5 100644 --- a/tianshou/trainer/base.py +++ b/tianshou/trainer/base.py @@ -42,11 +42,11 @@ InfoStats, ReplayBuffer, SequenceSummaryStats, + TimingStats, ) from tianshou.data.buffer.base import MalformedBufferError from tianshou.data.collector import BaseCollector, CollectStatsBase from tianshou.policy.base import TrainingStats -from tianshou.trainer.utils import gather_info, test_episode from tianshou.utils import ( BaseLogger, LazyLogger, @@ -403,7 +403,7 @@ def reset(self, reset_collectors: bool = True, reset_collector_buffers: bool = F if self.config.test_collector is not None: assert self.config.episode_per_test is not None assert not isinstance(self.config.test_collector, AsyncCollector) # Issue 700 - test_result = test_episode( + test_result = self._test_episode( self.config.test_collector, self.config.test_fn, self._start_epoch, @@ -458,6 +458,55 @@ def _create_epoch_pbar_data_dict( ) -> dict[str, str]: pass + @staticmethod + def _gather_info( + start_time: float, + policy_update_time: float, + gradient_step: int, + 
best_score: float, + best_reward: float, + best_reward_std: float, + train_collector: BaseCollector | None = None, + test_collector: BaseCollector | None = None, + ) -> InfoStats: + """A simple wrapper of gathering information from collectors. + + :return: InfoStats object with times computed based on the `start_time` and + episode/step counts read off the collectors. No computation of + expensive statistics is done here. + """ + duration = max(0.0, time.time() - start_time) + test_time = 0.0 + update_speed = 0.0 + train_time_collect = 0.0 + if test_collector is not None: + test_time = test_collector.collect_time + + if train_collector is not None: + train_time_collect = train_collector.collect_time + update_speed = train_collector.collect_step / (duration - test_time) + + timing_stat = TimingStats( + total_time=duration, + train_time=duration - test_time, + train_time_collect=train_time_collect, + train_time_update=policy_update_time, + test_time=test_time, + update_speed=update_speed, + ) + + return InfoStats( + gradient_step=gradient_step, + best_score=best_score, + best_reward=best_reward, + best_reward_std=best_reward_std, + train_step=train_collector.collect_step if train_collector is not None else 0, + train_episode=train_collector.collect_episode if train_collector is not None else 0, + test_step=test_collector.collect_step if test_collector is not None else 0, + test_episode=test_collector.collect_episode if test_collector is not None else 0, + timing=timing_stat, + ) + def execute_epoch(self) -> EpochStats: self._epoch += 1 @@ -498,7 +547,7 @@ def execute_epoch(self) -> EpochStats: if self.config.test_collector is not None: test_collect_stats, self._stop_fn_flag = self._test_step() - info_stats = gather_info( + info_stats = self._gather_info( start_time=self._start_time, policy_update_time=self._policy_update_time, gradient_step=self._gradient_step, @@ -521,12 +570,36 @@ def execute_epoch(self) -> EpochStats: info_stat=info_stats, ) + @staticmethod + def 
_test_episode( + collector: BaseCollector, + test_fn: Callable[[int, int | None], None] | None, + epoch: int, + n_episode: int, + logger: BaseLogger | None = None, + global_step: int | None = None, + reward_metric: Callable[[np.ndarray], np.ndarray] | None = None, + ) -> CollectStats: + """A simple wrapper of testing policy in collector.""" + collector.reset(reset_stats=False) + if test_fn: + test_fn(epoch, global_step) + result = collector.collect(n_episode=n_episode) + if reward_metric: # TODO: move into collector + rew = reward_metric(result.returns) + result.returns = rew + result.returns_stat = SequenceSummaryStats.from_sequence(rew) + if logger and global_step is not None: + assert result.n_collected_episodes > 0 + logger.log_test_data(asdict(result), global_step) + return result + def _test_step(self) -> tuple[CollectStats, bool]: """Perform one test step.""" assert self.config.episode_per_test is not None assert self.config.test_collector is not None stop_fn_flag = False - test_stat = test_episode( + test_stat = self._test_episode( self.config.test_collector, self.config.test_fn, self._epoch, @@ -613,7 +686,7 @@ def run( while self._epoch < self.config.max_epoch and not self._stop_fn_flag: self.execute_epoch() - return gather_info( + return self._gather_info( start_time=self._start_time, policy_update_time=self._policy_update_time, gradient_step=self._gradient_step, @@ -846,7 +919,7 @@ def _test_in_train( ): assert self.config.test_collector is not None assert self.config.episode_per_test is not None and self.config.episode_per_test > 0 - test_result = test_episode( + test_result = self._test_episode( self.config.test_collector, self.config.test_fn, self._epoch, diff --git a/tianshou/trainer/utils.py b/tianshou/trainer/utils.py deleted file mode 100644 index b3e14350f..000000000 --- a/tianshou/trainer/utils.py +++ /dev/null @@ -1,90 +0,0 @@ -import time -from collections.abc import Callable -from dataclasses import asdict - -import numpy as np - -from 
tianshou.data import ( - CollectStats, - InfoStats, - SequenceSummaryStats, - TimingStats, -) -from tianshou.data.collector import BaseCollector -from tianshou.utils import BaseLogger - -# TODO: This module should be eliminated: Move methods into Trainer - - -# TODO: Improve name -def test_episode( - collector: BaseCollector, - test_fn: Callable[[int, int | None], None] | None, - epoch: int, - n_episode: int, - logger: BaseLogger | None = None, - global_step: int | None = None, - reward_metric: Callable[[np.ndarray], np.ndarray] | None = None, -) -> CollectStats: - """A simple wrapper of testing policy in collector.""" - collector.reset(reset_stats=False) - if test_fn: - test_fn(epoch, global_step) - result = collector.collect(n_episode=n_episode) - if reward_metric: # TODO: move into collector - rew = reward_metric(result.returns) - result.returns = rew - result.returns_stat = SequenceSummaryStats.from_sequence(rew) - if logger and global_step is not None: - assert result.n_collected_episodes > 0 - logger.log_test_data(asdict(result), global_step) - return result - - -def gather_info( - start_time: float, - policy_update_time: float, - gradient_step: int, - best_score: float, - best_reward: float, - best_reward_std: float, - train_collector: BaseCollector | None = None, - test_collector: BaseCollector | None = None, -) -> InfoStats: - """A simple wrapper of gathering information from collectors. - - :return: InfoStats object with times computed based on the `start_time` and - episode/step counts read off the collectors. No computation of - expensive statistics is done here. 
- """ - duration = max(0.0, time.time() - start_time) - test_time = 0.0 - update_speed = 0.0 - train_time_collect = 0.0 - if test_collector is not None: - test_time = test_collector.collect_time - - if train_collector is not None: - train_time_collect = train_collector.collect_time - update_speed = train_collector.collect_step / (duration - test_time) - - timing_stat = TimingStats( - total_time=duration, - train_time=duration - test_time, - train_time_collect=train_time_collect, - train_time_update=policy_update_time, - test_time=test_time, - update_speed=update_speed, - ) - - return InfoStats( - gradient_step=gradient_step, - best_score=best_score, - best_reward=best_reward, - best_reward_std=best_reward_std, - train_step=train_collector.collect_step if train_collector is not None else 0, - train_episode=train_collector.collect_episode if train_collector is not None else 0, - test_step=test_collector.collect_step if test_collector is not None else 0, - test_episode=test_collector.collect_episode if test_collector is not None else 0, - timing=timing_stat, - ) From 928c7a7f588fb9a52f3247256a68d497ac9e7940 Mon Sep 17 00:00:00 2001 From: Dominik Jain Date: Mon, 10 Mar 2025 13:26:40 +0100 Subject: [PATCH 039/230] v2: Trainer: Turn functions moved from trainer.util into methods without arguments * gather_info -> _create_info_stats * test_episode -> _collect_test_episodes --- tianshou/trainer/base.py | 120 ++++++++++----------------------------- 1 file changed, 31 insertions(+), 89 deletions(-) diff --git a/tianshou/trainer/base.py b/tianshou/trainer/base.py index d2d1395e5..66eafff5e 100644 --- a/tianshou/trainer/base.py +++ b/tianshou/trainer/base.py @@ -388,6 +388,7 @@ def reset(self, reset_collectors: bool = True, reset_collector_buffers: bool = F This has no effect if `reset_collectors` is False. 
""" self._env_step = 0 + if self.config.resume_from_log: ( self._start_epoch, @@ -395,6 +396,8 @@ def reset(self, reset_collectors: bool = True, reset_collector_buffers: bool = F self._gradient_step, ) = self._logger.restore_data() + self._epoch = self._start_epoch + self._start_time = time.time() if reset_collectors: @@ -403,15 +406,7 @@ def reset(self, reset_collectors: bool = True, reset_collector_buffers: bool = F if self.config.test_collector is not None: assert self.config.episode_per_test is not None assert not isinstance(self.config.test_collector, AsyncCollector) # Issue 700 - test_result = self._test_episode( - self.config.test_collector, - self.config.test_fn, - self._start_epoch, - self.config.episode_per_test, - self._logger, - self._env_step, - self.config.reward_metric, - ) + test_result = self._collect_test_episodes() assert test_result.returns_stat is not None # for mypy self._best_epoch = self._start_epoch self._best_reward, self._best_reward_std = ( @@ -422,7 +417,6 @@ def reset(self, reset_collectors: bool = True, reset_collector_buffers: bool = F if self.config.save_best_fn: self.config.save_best_fn(self.algorithm) - self._epoch = self._start_epoch self._stop_fn_flag = False class _TrainingStepResult(ABC): @@ -458,24 +452,16 @@ def _create_epoch_pbar_data_dict( ) -> dict[str, str]: pass - @staticmethod - def _gather_info( - start_time: float, - policy_update_time: float, - gradient_step: int, - best_score: float, - best_reward: float, - best_reward_std: float, - train_collector: BaseCollector | None = None, - test_collector: BaseCollector | None = None, + def _create_info_stats( + self, ) -> InfoStats: - """A simple wrapper of gathering information from collectors. 
+ test_collector = self.config.test_collector + if isinstance(self.config, OnlineTrainingConfig): + train_collector = self.config.train_collector + else: + train_collector = None - :return: InfoStats object with times computed based on the `start_time` and - episode/step counts read off the collectors. No computation of - expensive statistics is done here. - """ - duration = max(0.0, time.time() - start_time) + duration = max(0.0, time.time() - self._start_time) test_time = 0.0 update_speed = 0.0 train_time_collect = 0.0 @@ -490,16 +476,16 @@ def _gather_info( total_time=duration, train_time=duration - test_time, train_time_collect=train_time_collect, - train_time_update=policy_update_time, + train_time_update=self._policy_update_time, test_time=test_time, update_speed=update_speed, ) return InfoStats( - gradient_step=gradient_step, - best_score=best_score, - best_reward=best_reward, - best_reward_std=best_reward_std, + gradient_step=self._gradient_step, + best_score=self._best_score, + best_reward=self._best_reward, + best_reward_std=self._best_reward_std, train_step=train_collector.collect_step if train_collector is not None else 0, train_episode=train_collector.collect_episode if train_collector is not None else 0, test_step=test_collector.collect_step if test_collector is not None else 0, @@ -547,18 +533,7 @@ def execute_epoch(self) -> EpochStats: if self.config.test_collector is not None: test_collect_stats, self._stop_fn_flag = self._test_step() - info_stats = self._gather_info( - start_time=self._start_time, - policy_update_time=self._policy_update_time, - gradient_step=self._gradient_step, - best_score=self._best_score, - best_reward=self._best_reward, - best_reward_std=self._best_reward_std, - train_collector=self.config.train_collector - if isinstance(self.config, OnlineTrainingConfig) - else None, - test_collector=self.config.test_collector, - ) + info_stats = self._create_info_stats() self._logger.log_info_data(asdict(info_stats), self._epoch) @@ 
-570,28 +545,21 @@ def execute_epoch(self) -> EpochStats: info_stat=info_stats, ) - @staticmethod - def _test_episode( - collector: BaseCollector, - test_fn: Callable[[int, int | None], None] | None, - epoch: int, - n_episode: int, - logger: BaseLogger | None = None, - global_step: int | None = None, - reward_metric: Callable[[np.ndarray], np.ndarray] | None = None, + def _collect_test_episodes( + self, ) -> CollectStats: - """A simple wrapper of testing policy in collector.""" + collector = self.config.test_collector collector.reset(reset_stats=False) - if test_fn: - test_fn(epoch, global_step) - result = collector.collect(n_episode=n_episode) - if reward_metric: # TODO: move into collector - rew = reward_metric(result.returns) + if self.config.test_fn: + self.config.test_fn(self._epoch, self._env_step) + result = collector.collect(n_episode=self.config.episode_per_test) + if self.config.reward_metric: # TODO: move into collector + rew = self.config.reward_metric(result.returns) result.returns = rew result.returns_stat = SequenceSummaryStats.from_sequence(rew) - if logger and global_step is not None: + if self._logger and self._env_step is not None: assert result.n_collected_episodes > 0 - logger.log_test_data(asdict(result), global_step) + self._logger.log_test_data(asdict(result), self._env_step) return result def _test_step(self) -> tuple[CollectStats, bool]: @@ -599,15 +567,7 @@ def _test_step(self) -> tuple[CollectStats, bool]: assert self.config.episode_per_test is not None assert self.config.test_collector is not None stop_fn_flag = False - test_stat = self._test_episode( - self.config.test_collector, - self.config.test_fn, - self._epoch, - self.config.episode_per_test, - self._logger, - self._env_step, - self.config.reward_metric, - ) + test_stat = self._collect_test_episodes() assert test_stat.returns_stat is not None # for mypy rew, rew_std = test_stat.returns_stat.mean, test_stat.returns_stat.std score = self._compute_score_fn(test_stat) @@ -686,18 
+646,7 @@ def run( while self._epoch < self.config.max_epoch and not self._stop_fn_flag: self.execute_epoch() - return self._gather_info( - start_time=self._start_time, - policy_update_time=self._policy_update_time, - gradient_step=self._gradient_step, - best_score=self._best_score, - best_reward=self._best_reward, - best_reward_std=self._best_reward_std, - train_collector=self.config.train_collector - if isinstance(self.config, OnlineTrainingConfig) - else None, - test_collector=self.config.test_collector, - ) + return self._create_info_stats() class OfflineTrainer(Trainer[OfflineTrainingConfig]): @@ -919,14 +868,7 @@ def _test_in_train( ): assert self.config.test_collector is not None assert self.config.episode_per_test is not None and self.config.episode_per_test > 0 - test_result = self._test_episode( - self.config.test_collector, - self.config.test_fn, - self._epoch, - self.config.episode_per_test, - self._logger, - self._env_step, - ) + test_result = self._collect_test_episodes() assert test_result.returns_stat is not None # for mypy if self.config.stop_fn(test_result.returns_stat.mean): should_stop_training = True From 4cc7aaa155dcef3a44ce9b846012313c69408b25 Mon Sep 17 00:00:00 2001 From: Dominik Jain Date: Mon, 10 Mar 2025 14:49:58 +0100 Subject: [PATCH 040/230] Do not allow ruff to remove unused imports from __init__.py files (removing the need to use __all__ in these files) --- pyproject.toml | 1 + tianshou/trainer/__init__.py | 10 ---------- 2 files changed, 1 insertion(+), 10 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index d2ba919a3..bde53c605 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -204,6 +204,7 @@ max-complexity = 20 "test/**" = ["D103"] "docs/**" = ["D103"] "examples/**" = ["D103"] +"__init__.py" = ["F401"] # do not remove "unused" imports (F401) from __init__.py files [tool.poetry_bumpversion.file."tianshou/__init__.py"] diff --git a/tianshou/trainer/__init__.py b/tianshou/trainer/__init__.py index 065af3e99..2f7d8fabc 
100644 --- a/tianshou/trainer/__init__.py +++ b/tianshou/trainer/__init__.py @@ -9,13 +9,3 @@ OnPolicyTrainingConfig, Trainer, ) - -__all__ = [ - "Trainer", - "OffPolicyTrainer", - "OnPolicyTrainer", - "OfflineTrainer", - "OffPolicyTrainingConfig", - "OnPolicyTrainingConfig", - "OfflineTrainingConfig", -] From 9e2fa6e937cc95eee31c2d3e226414abc71f3d01 Mon Sep 17 00:00:00 2001 From: Dominik Jain Date: Mon, 10 Mar 2025 16:02:46 +0100 Subject: [PATCH 041/230] v2: Trainer: Factorise/resolve inconsistencies in performance evaluation and early stopping * The two places that collected and evaluated test episodes (_test_in_train and _reset) in addition to _test_step were unified to use _test_step (with some minor parametrisation) and now log the results of the test step accordingly. * (Fix) The stop criterion stop_fn did not consider scores as computed by compute_score_fn but instead always used mean returns (i.e. it was assumed that the default implementation of compute_score_fn applies). This is an inconsistency which has been resolved. --- CHANGELOG.md | 8 ++- tianshou/trainer/base.py | 146 ++++++++++++++++++++++++--------------- 2 files changed, 97 insertions(+), 57 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 812cecd0b..53f79d164 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -19,6 +19,9 @@ * See also "Issues resolved" below (as issue resolution can result in usage changes) * Further internal changes unlikely to affect usage: * Module `trainer.utils` was removed and the functions therein where moved to class `Trainer` + * The two places that collected and evaluated test episodes (`_test_in_train` and `_reset`) in addition to + `_test_step` were unified to use `_test_step` (with some minor parametrisation) and now log the results + of the test step accordingly. 
* Issues resolved: * Methods `run` and `reset`: Parameter `reset_prior_to_run` of `run` was never respected if it was set to `False`, because the implementation of `__iter__` (now removed) would call `reset` regardless - and calling `reset` @@ -27,9 +30,10 @@ * Inconsistent configuration options now raise exceptions rather than silently ignoring the issue in the hope that default behaviour will achieve what the user intended. One condition where `test_in_train` was silently set to `False` was removed and replaced by a warning. + * The stop criterion `stop_fn` did not consider scores as computed by `compute_score_fn` but instead always used + mean returns (i.e. it was assumed that the default implementation of `compute_score_fn` applies). + This is an inconsistency which has been resolved. * Open issues: - * TODO: For `test_in_train`, the early stopping criterion was computed incorrectly (did not consider `compute_score_fn`, - i.e. it assumed that the default implementation applies) * TODO: _gradient_step counter is not incorrect; replace it with a simple update step counter * Migration information at a glance: * Training parameters are now passed via instances of configuration objects instead of directly as keyword arguments: diff --git a/tianshou/trainer/base.py b/tianshou/trainer/base.py index 66eafff5e..04b0c7159 100644 --- a/tianshou/trainer/base.py +++ b/tianshou/trainer/base.py @@ -200,6 +200,12 @@ def __post_init__(self): raise ValueError( "save_best_fn is set while test steps are disabled (test_collector is None)" ) + else: + if self.episode_per_test < 1: + raise ValueError( + "episode_per_test must be positive if test steps are enabled " + "(test_collector not None)" + ) @dataclass(kw_only=True) @@ -237,11 +243,13 @@ class OnlineTrainingConfig(TrainingConfig): test_in_train: bool = True """ - Whether to apply an effective test step triggered by the early stopping criterion (given by :attr:`stop_fn`) - being satisfied in the data collected in the collect 
step within a training step: - If the stop criterion is satisfied, it collects `episode_per_test` test episodes (as in a test step) - and determines whether the stop criterion is also satisfied by the episodes thus collected, - and if so, training stops early. + Whether to apply a test step within a training step depending on the early stopping criterion + (given by :attr:`stop_fn`) being satisfied based on the data collected within the training step. + Specifically, after each collect step, we check whether the early stopping criterion (:attr:`stop_fn`) + would be satisfied by data we collected (provided that at least one episode was indeed completed, such + that we can evaluate returns, etc.). If the criterion is satisfied, we perform a full test step + (collecting :attr:`episode_per_test` episodes in order to evaluate performance), and if the early + stopping criterion is also satisfied based on the test data, we stop training early. """ def __post_init__(self): @@ -327,10 +335,18 @@ def __init__( self._start_time = time.time() self._stat: defaultdict[str, MovAvg] = defaultdict(MovAvg) + self._start_epoch = 0 + + self._epoch = self._start_epoch + + # initialize stats on the best model found during a test step + # NOTE: The values don't matter, as in the first test step (which is taken in reset() + # at the beginning of the training process), these will all be updated self._best_score = 0.0 self._best_reward = 0.0 self._best_reward_std = 0.0 - self._start_epoch = 0 + self._best_epoch = self._start_epoch + # This is only used for logging but creeps into the implementations # of the trainers. 
I believe it would be better to remove self._gradient_step = 0 @@ -347,8 +363,6 @@ def __init__( config.compute_score_fn or self._compute_score_fn_default ) - self._epoch = self._start_epoch - self._best_epoch = self._start_epoch self._stop_fn_flag = False @staticmethod @@ -403,19 +417,11 @@ def reset(self, reset_collectors: bool = True, reset_collector_buffers: bool = F if reset_collectors: self._reset_collectors(reset_buffer=reset_collector_buffers) + # make an initial test step to determine the initial best model if self.config.test_collector is not None: assert self.config.episode_per_test is not None assert not isinstance(self.config.test_collector, AsyncCollector) # Issue 700 - test_result = self._collect_test_episodes() - assert test_result.returns_stat is not None # for mypy - self._best_epoch = self._start_epoch - self._best_reward, self._best_reward_std = ( - test_result.returns_stat.mean, - test_result.returns_stat.std, - ) - self._best_score = self._compute_score_fn(test_result) - if self.config.save_best_fn: - self.config.save_best_fn(self.algorithm) + self._test_step(force_update_best=True, log_msg_prefix="Initial test step") self._stop_fn_flag = False @@ -545,6 +551,29 @@ def execute_epoch(self) -> EpochStats: info_stat=info_stats, ) + def _should_stop_training_early( + self, *, score: float | None = None, collect_stats: CollectStats | None = None + ) -> bool: + """ + Determine whether, given the early stopping criterion stop_fn, training shall be stopped early + based on the score achieved or the collection stats (from which the score could be computed). 
+ """ + # If no stop criterion is defined, we can never stop training early + if self.config.stop_fn is None: + return False + + if score is None: + if collect_stats is None: + raise ValueError("Must provide collect_stats if score is not given") + + # If no episodes were collected, we have no episode returns and thus cannot compute a score + if collect_stats.n_collected_episodes == 0: + return False + + score = self._compute_score_fn(collect_stats) + + return self.config.stop_fn(score) + def _collect_test_episodes( self, ) -> CollectStats: @@ -562,27 +591,43 @@ def _collect_test_episodes( self._logger.log_test_data(asdict(result), self._env_step) return result - def _test_step(self) -> tuple[CollectStats, bool]: - """Perform one test step.""" + def _test_step( + self, force_update_best: bool = False, log_msg_prefix: str | None = None + ) -> tuple[CollectStats, bool]: + """Performs one test step. + + :param log_msg_prefix: a prefix to prepend to the log message, which is to establish the context within + which the test step is being carried out + :param force_update_best: whether to force updating of the best model stats (best score, reward, etc.) 
+ and call the `save_best_fn` callback + """ assert self.config.episode_per_test is not None assert self.config.test_collector is not None - stop_fn_flag = False + + # collect test episodes test_stat = self._collect_test_episodes() assert test_stat.returns_stat is not None # for mypy + + # check whether we have a new best score and, if so, update stats and save the model + # (or if forced) rew, rew_std = test_stat.returns_stat.mean, test_stat.returns_stat.std score = self._compute_score_fn(test_stat) - if self._best_epoch < 0 or self._best_score < score: + if score > self._best_score or force_update_best: self._best_score = score self._best_epoch = self._epoch self._best_reward = float(rew) self._best_reward_std = rew_std if self.config.save_best_fn: self.config.save_best_fn(self.algorithm) + + # log results cur_info, best_info = "", "" if score != rew: cur_info, best_info = f", score: {score: .6f}", f", best_score: {self._best_score:.6f}" + if log_msg_prefix is None: + log_msg_prefix = f"Epoch #{self._epoch}" log_msg = ( - f"Epoch #{self._epoch}: test_reward: {rew:.6f} ± {rew_std:.6f},{cur_info}" + f"{log_msg_prefix}: test_reward: {rew:.6f} ± {rew_std:.6f},{cur_info}" f" best_reward: {self._best_reward:.6f} ± " f"{self._best_reward_std:.6f}{best_info} in #{self._best_epoch}" ) @@ -590,8 +635,8 @@ def _test_step(self) -> tuple[CollectStats, bool]: if self.config.verbose: print(log_msg, flush=True) - if self.config.stop_fn and self.config.stop_fn(self._best_reward): - stop_fn_flag = True + # determine whether training shall be stopped early + stop_fn_flag = self._should_stop_training_early(score=self._best_score) return test_stat, stop_fn_flag @@ -782,9 +827,9 @@ def _training_step(self) -> _TrainingStepResult: collect_stats = self._collect_training_data() # determine whether we should stop training based on the data collected - should_stop_training = self._test_in_train( - collect_stats, - ) + should_stop_training = False + if self.config.test_in_train: + 
should_stop_training = self._test_in_train(collect_stats) # perform gradient update step (if not already done) training_stats: TrainingStats | None = None @@ -840,41 +885,32 @@ def _collect_training_data(self) -> CollectStats: def _test_in_train( self, - collect_stats: CollectStats, + train_collect_stats: CollectStats, ) -> bool: """ - Performs performance testing based on the early stopping criterion being satisfied based on the - data collected in the current training step: - If the stop criterion is satisfied, it collects `episode_per_test` test episodes (as in a test step) - and determines whether the stop criterion is also satisfied by the episodes thus collected, - and if so, returns True, indicating that training stops early. + Performs a test step if the data collected in the current training step suggests that performance + is good enough to stop training early. If the test step confirms that performance is indeed good + enough, returns True, and False otherwise. - Therefore, if the early stopping criterion is satisfied on the data collected for training, - this effectively carries out a test step and updates the respective metrics (best_reward, etc.) - accordingly. + Specifically, applies the early stopping criterion to the data collected in the current training step, + and if the criterion is satisfied, performs a test step which returns the relevant result. 
- :param collect_stats: the data collection stats from the preceding collection step + :param train_collect_stats: the data collection stats from the preceding collection step :return: flag indicating whether to stop training early """ should_stop_training = False - # Because we need to evaluate the policy, we need to temporarily leave the "is_training_step" semantics - with policy_within_training_step(self.algorithm.policy, enabled=False): - if ( - collect_stats.n_collected_episodes > 0 - and self.config.test_in_train - and self.config.stop_fn - and self.config.stop_fn(collect_stats.returns_stat.mean) # type: ignore - ): - assert self.config.test_collector is not None - assert self.config.episode_per_test is not None and self.config.episode_per_test > 0 - test_result = self._collect_test_episodes() - assert test_result.returns_stat is not None # for mypy - if self.config.stop_fn(test_result.returns_stat.mean): - should_stop_training = True - self._best_reward = test_result.returns_stat.mean - self._best_reward_std = test_result.returns_stat.std - self._best_score = self._compute_score_fn(test_result) + # check whether the stop criterion is satisfied based on the data collected in the training step + # (if any full episodes were indeed collected) + if train_collect_stats.n_collected_episodes > 0 and self._should_stop_training_early( + collect_stats=train_collect_stats + ): + # apply a test step, temporarily switching out of "is_training_step" semantics such that the policy can + # be evaluated, in order to determine whether we should stop training + with policy_within_training_step(self.algorithm.policy, enabled=False): + _, should_stop_training = self._test_step( + log_msg_prefix=f"Test step triggered by train stats (env_step={self._env_step})" + ) return should_stop_training From 9da2c00930cea71046728e8e096cb48f474e37ab Mon Sep 17 00:00:00 2001 From: Dominik Jain Date: Mon, 10 Mar 2025 16:37:47 +0100 Subject: [PATCH 042/230] v2: Trainer: Replace flawed gradient 
step counter with an update step counter The gradient step counter made incorrect assumptions about the underlying algorithms (such that the count was actually incorrect for many algorithms). Members of `InfoStats` and parameters of `Logger` (and subclasses) were changed to reflect the change in semantics. --- CHANGELOG.md | 5 ++-- tianshou/data/stats.py | 4 +-- tianshou/trainer/base.py | 39 ++++++++++++++-------------- tianshou/utils/logger/base.py | 10 +++---- tianshou/utils/logger/tensorboard.py | 8 +++--- tianshou/utils/logger/wandb.py | 8 +++--- 6 files changed, 37 insertions(+), 37 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 53f79d164..3174626aa 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -33,8 +33,9 @@ * The stop criterion `stop_fn` did not consider scores as computed by `compute_score_fn` but instead always used mean returns (i.e. it was assumed that the default implementation of `compute_score_fn` applies). This is an inconsistency which has been resolved. - * Open issues: - * TODO: _gradient_step counter is not incorrect; replace it with a simple update step counter + * The `gradient_step` counter was flawed (as it made assumptions about the underlying algorithms, which were + not valid). It has been replaced with an update step counter. + Members of `InfoStats` and parameters of `Logger` (and subclasses) were changed accordingly. * Migration information at a glance: * Training parameters are now passed via instances of configuration objects instead of directly as keyword arguments: `OnPolicyTrainingConfig`, `OffPolicyTrainingConfig`, `OfflineTrainingConfig`. 
diff --git a/tianshou/data/stats.py b/tianshou/data/stats.py index ec2cd6703..964479113 100644 --- a/tianshou/data/stats.py +++ b/tianshou/data/stats.py @@ -79,8 +79,8 @@ class TimingStats(DataclassPPrintMixin): class InfoStats(DataclassPPrintMixin): """A data structure for storing information about the learning process.""" - gradient_step: int - """The total gradient step.""" + update_step: int + """The total number of update steps that have been taken.""" best_score: float """The best score over the test results. The one with the highest score will be considered the best model.""" best_reward: float diff --git a/tianshou/trainer/base.py b/tianshou/trainer/base.py index 04b0c7159..d5fc6f78a 100644 --- a/tianshou/trainer/base.py +++ b/tianshou/trainer/base.py @@ -170,6 +170,11 @@ class TrainingConfig(ToStringMixin): logger: BaseLogger | None = None """ the logger with which to log statistics during training/testing/updating. To not log anything, use None. + + Relevant step types for logger update intervals: + * `update_interval`: update step + * `train_interval`: env step + * `test_interval`: env step """ verbose: bool = True @@ -347,9 +352,11 @@ def __init__( self._best_reward_std = 0.0 self._best_epoch = self._start_epoch - # This is only used for logging but creeps into the implementations - # of the trainers. I believe it would be better to remove - self._gradient_step = 0 + self._current_update_step = 0 + """ + the current (1-based) update step/training step number (to be incremented before the actual step is taken) + """ + self._env_step = 0 """ the step counter which is used to track progress of the training process. @@ -357,6 +364,7 @@ def __init__( environment steps collected, and for offline training, it is the total number of environment steps that have been sampled from the replay buffer to perform gradient updates. 
""" + self._policy_update_time = 0.0 self._compute_score_fn: Callable[[CollectStats], float] = ( @@ -402,12 +410,13 @@ def reset(self, reset_collectors: bool = True, reset_collector_buffers: bool = F This has no effect if `reset_collectors` is False. """ self._env_step = 0 + self._current_update_step = 0 if self.config.resume_from_log: ( self._start_epoch, self._env_step, - self._gradient_step, + self._current_update_step, ) = self._logger.restore_data() self._epoch = self._start_epoch @@ -488,7 +497,7 @@ def _create_info_stats( ) return InfoStats( - gradient_step=self._gradient_step, + update_step=self._current_update_step, best_score=self._best_score, best_reward=self._best_reward, best_reward_std=self._best_reward_std, @@ -510,6 +519,7 @@ def execute_epoch(self) -> EpochStats: ) as t: while steps_done_in_this_epoch < self.config.step_per_epoch and not self._stop_fn_flag: # perform a training step and update progress + self._current_update_step += 1 training_step_result = self._training_step() steps_done_in_this_epoch += training_step_result.get_steps_in_epoch_advancement() t.update(training_step_result.get_steps_in_epoch_advancement()) @@ -523,7 +533,7 @@ def execute_epoch(self) -> EpochStats: pbar_data_dict = self._create_epoch_pbar_data_dict(training_step_result) pbar_data_dict = set_numerical_fields_to_precision(pbar_data_dict) - pbar_data_dict["gradient_step"] = str(self._gradient_step) + pbar_data_dict["update_step"] = str(self._current_update_step) t.set_postfix(**pbar_data_dict) test_collect_stats = None @@ -531,7 +541,7 @@ def execute_epoch(self) -> EpochStats: self._logger.save_data( self._epoch, self._env_step, - self._gradient_step, + self._current_update_step, self.config.save_checkpoint_fn, ) @@ -652,7 +662,7 @@ def _update_moving_avg_stats_and_log_update_data(self, update_stat: TrainingStat update_stat.smoothed_loss = self._update_moving_avg_stats_and_get_averaged_data( cur_losses_dict, ) - self._logger.log_update_data(asdict(update_stat), 
self._gradient_step) + self._logger.log_update_data(asdict(update_stat), self._current_update_step) # TODO: seems convoluted, there should be a better way of dealing with the moving average stats def _update_moving_avg_stats_and_get_averaged_data( @@ -729,7 +739,6 @@ def get_env_step_advancement(self) -> int: def _training_step(self) -> _TrainingStepResult: with policy_within_training_step(self.algorithm.policy): - self._gradient_step += 1 # Note: since sample_size=batch_size, this will perform # exactly one gradient step. This is why we don't need to calculate the # number of gradient steps, like in the on-policy case. @@ -987,7 +996,6 @@ def _update_step( def _sample_and_update(self, buffer: ReplayBuffer) -> TrainingStats: """Sample a mini-batch, perform one gradient step, and update the _gradient_step counter.""" - self._gradient_step += 1 # Note: since sample_size=batch_size, this will perform # exactly one gradient step. This is why we don't need to calculate the # number of gradient steps, like in the on-policy case. @@ -1008,7 +1016,7 @@ def _update_step( self, result: CollectStatsBase | None = None, ) -> TrainingStats: - """Perform one on-policy update by passing the entire buffer to the policy's update method.""" + """Perform one on-policy update by passing the entire buffer to the algorithm's update method.""" assert self.config.train_collector is not None # TODO: add logging like in off-policy. Iteration over minibatches currently happens in the algorithms themselves. log.info( @@ -1027,15 +1035,6 @@ def _update_step( # just for logging, no functional role self._policy_update_time += training_stat.train_time - # TODO: remove the gradient step counting in trainers? 
Doesn't seem like - # it's important and it adds complexity - self._gradient_step += 1 - if self.config.batch_size is None: - self._gradient_step += 1 - elif self.config.batch_size > 0: - self._gradient_step += int( - (len(self.config.train_collector.buffer) - 0.1) // self.config.batch_size, - ) # Note 1: this is the main difference to the off-policy trainer! # The second difference is that batches of data are sampled without replacement diff --git a/tianshou/utils/logger/base.py b/tianshou/utils/logger/base.py index 2ff6c6760..305606ef0 100644 --- a/tianshou/utils/logger/base.py +++ b/tianshou/utils/logger/base.py @@ -104,7 +104,7 @@ def log_update_data(self, log_data: dict, step: int) -> None: # TODO: move interval check to calling method if step - self.last_log_update_step >= self.update_interval: log_data = self.prepare_dict_for_logging(log_data) - self.write(f"{DataScope.UPDATE}/gradient_step", step, log_data) + self.write(f"{DataScope.UPDATE}/update_step", step, log_data) self.last_log_update_step = step def log_info_data(self, log_data: dict, step: int) -> None: @@ -125,14 +125,14 @@ def save_data( self, epoch: int, env_step: int, - gradient_step: int, + update_step: int, save_checkpoint_fn: Callable[[int, int, int], str] | None = None, ) -> None: """Use writer to log metadata when calling ``save_checkpoint_fn`` in trainer. :param epoch: the epoch in trainer. :param env_step: the env_step in trainer. - :param gradient_step: the gradient_step in trainer. + :param update_step: the update step count in the trainer. :param function save_checkpoint_fn: a hook defined by user, see trainer documentation for detail. """ @@ -144,7 +144,7 @@ def restore_data(self) -> tuple[int, int, int]: If it finds nothing or an error occurs during the recover process, it will return the default parameters. - :return: epoch, env_step, gradient_step. + :return: epoch, env_step, update_step. 
""" @staticmethod @@ -180,7 +180,7 @@ def save_data( self, epoch: int, env_step: int, - gradient_step: int, + update_step: int, save_checkpoint_fn: Callable[[int, int, int], str] | None = None, ) -> None: pass diff --git a/tianshou/utils/logger/tensorboard.py b/tianshou/utils/logger/tensorboard.py index ef504cb58..1400d8a52 100644 --- a/tianshou/utils/logger/tensorboard.py +++ b/tianshou/utils/logger/tensorboard.py @@ -106,18 +106,18 @@ def save_data( self, epoch: int, env_step: int, - gradient_step: int, + update_step: int, save_checkpoint_fn: Callable[[int, int, int], str] | None = None, ) -> None: if save_checkpoint_fn and epoch - self.last_save_step >= self.save_interval: self.last_save_step = epoch - save_checkpoint_fn(epoch, env_step, gradient_step) + save_checkpoint_fn(epoch, env_step, update_step) self.write("save/epoch", epoch, {"save/epoch": epoch}) self.write("save/env_step", env_step, {"save/env_step": env_step}) self.write( "save/gradient_step", - gradient_step, - {"save/gradient_step": gradient_step}, + update_step, + {"save/gradient_step": update_step}, ) def restore_data(self) -> tuple[int, int, int]: diff --git a/tianshou/utils/logger/wandb.py b/tianshou/utils/logger/wandb.py index 9172bf54b..129e20668 100644 --- a/tianshou/utils/logger/wandb.py +++ b/tianshou/utils/logger/wandb.py @@ -132,20 +132,20 @@ def save_data( self, epoch: int, env_step: int, - gradient_step: int, + update_step: int, save_checkpoint_fn: Callable[[int, int, int], str] | None = None, ) -> None: """Use writer to log metadata when calling ``save_checkpoint_fn`` in trainer. :param epoch: the epoch in trainer. :param env_step: the env_step in trainer. - :param gradient_step: the gradient_step in trainer. + :param update_step: the update step count in the trainer. :param function save_checkpoint_fn: a hook defined by user, see trainer documentation for detail.
""" if save_checkpoint_fn and epoch - self.last_save_step >= self.save_interval: self.last_save_step = epoch - checkpoint_path = save_checkpoint_fn(epoch, env_step, gradient_step) + checkpoint_path = save_checkpoint_fn(epoch, env_step, update_step) checkpoint_artifact = wandb.Artifact( "run_" + self.wandb_run.id + "_checkpoint", # type: ignore @@ -153,7 +153,7 @@ def save_data( metadata={ "save/epoch": epoch, "save/env_step": env_step, - "save/gradient_step": gradient_step, + "save/gradient_step": update_step, "checkpoint_path": str(checkpoint_path), }, ) From fd19577e1f04024d93336ba1c531dc45592ca2dc Mon Sep 17 00:00:00 2001 From: Dominik Jain Date: Mon, 10 Mar 2025 20:34:37 +0100 Subject: [PATCH 043/230] v2: Move function for adding exploration noise from Algorithm to Policy, allowing collectors to require only a Policy, not an Algorithm * Renamed the function: exploration_noise -> add_exploration_noise * Move all required, noise-related attributes and functions to the Policy implementations * Introduce new policy base class ContinuousPolicyWithExplorationNoise to avoid code duplication --- CHANGELOG.md | 2 + examples/atari/atari_c51.py | 6 +- examples/discrete/discrete_dqn.py | 6 +- test/base/test_env_finite.py | 12 +-- test/continuous/test_ddpg.py | 2 +- test/continuous/test_td3.py | 2 +- test/discrete/test_bdqn.py | 4 +- test/discrete/test_c51.py | 8 +- test/discrete/test_dqn.py | 8 +- test/discrete/test_fqf.py | 8 +- test/discrete/test_iqn.py | 8 +- test/discrete/test_qrdqn.py | 8 +- test/discrete/test_rainbow.py | 8 +- tianshou/data/collector.py | 26 +++---- tianshou/policy/base.py | 48 ++++++------ tianshou/policy/modelbased/icm.py | 6 +- tianshou/policy/modelfree/bdqn.py | 57 ++++++--------- tianshou/policy/modelfree/ddpg.py | 89 ++++++++++++++--------- tianshou/policy/modelfree/discrete_sac.py | 1 - tianshou/policy/modelfree/dqn.py | 57 +++++++-------- tianshou/policy/modelfree/redq.py | 9 ++- tianshou/policy/modelfree/sac.py | 15 ++-- 
tianshou/policy/modelfree/td3.py | 11 --- tianshou/policy/multiagent/mapolicy.py | 6 +- tianshou/trainer/base.py | 4 +- 25 files changed, 205 insertions(+), 206 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 3174626aa..417e39cce 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -64,6 +64,8 @@ * The (auto-)updating logic is now completely encapsulated, reducing the complexity of the algorithms. * Implementations for continuous and discrete cases now share the same abstraction, making the codebase more consistent while preserving the original functionality. + * Introduced a policy base class `ContinuousPolicyWithExplorationNoise` which encapsulates noise generation + for continuous action spaces (e.g. relevant to `DDPG`, `SAC` and `REDQ`). * Fixed issues in the class hierarchy (e.g. violations of the Liskov substitution principle): * Introduced base classes (to retain factorization without abusive inheritance): * `ActorCriticOffPolicyAlgorithm` diff --git a/examples/atari/atari_c51.py b/examples/atari/atari_c51.py index 886f991de..c4b52ed44 100644 --- a/examples/atari/atari_c51.py +++ b/examples/atari/atari_c51.py @@ -158,17 +158,17 @@ def train_fn(epoch: int, env_step: int) -> None: eps = args.eps_train - env_step / 1e6 * (args.eps_train - args.eps_train_final) else: eps = args.eps_train_final - algorithm.set_eps(eps) + policy.set_eps(eps) if env_step % 1000 == 0: logger.write("train/env_step", env_step, {"train/eps": eps}) def test_fn(epoch: int, env_step: int | None) -> None: - algorithm.set_eps(args.eps_test) + policy.set_eps(args.eps_test) # watch agent's performance def watch() -> None: print("Setup test envs ...") - algorithm.set_eps(args.eps_test) + policy.set_eps(args.eps_test) test_envs.seed(args.seed) if args.save_buffer_name: print(f"Generate buffer with size {args.buffer_size}") diff --git a/examples/discrete/discrete_dqn.py b/examples/discrete/discrete_dqn.py index 2916aaaa8..3605c0985 100644 --- a/examples/discrete/discrete_dqn.py +++ 
b/examples/discrete/discrete_dqn.py @@ -75,8 +75,8 @@ def stop_fn(mean_rewards: float) -> bool: episode_per_test=test_num, batch_size=batch_size, update_per_step=1 / step_per_collect, - train_fn=lambda epoch, env_step: algorithm.set_eps(eps_train), - test_fn=lambda epoch, env_step: algorithm.set_eps(eps_test), + train_fn=lambda epoch, env_step: policy.set_eps(eps_train), + test_fn=lambda epoch, env_step: policy.set_eps(eps_test), stop_fn=stop_fn, logger=logger, ) @@ -84,7 +84,7 @@ def stop_fn(mean_rewards: float) -> bool: print(f"Finished training in {result.timing.total_time} seconds") # watch performance - algorithm.set_eps(eps_test) + policy.set_eps(eps_test) collector = ts.data.Collector[CollectStats](algorithm, env, exploration_noise=True) collector.collect(n_episode=100, render=1 / 35) diff --git a/test/base/test_env_finite.py b/test/base/test_env_finite.py index 2240e1d53..7e6065ada 100644 --- a/test/base/test_env_finite.py +++ b/test/base/test_env_finite.py @@ -16,11 +16,10 @@ ActBatchProtocol, BatchProtocol, ObsBatchProtocol, - RolloutBatchProtocol, ) from tianshou.env import BaseVectorEnv, DummyVectorEnv, SubprocVectorEnv from tianshou.env.utils import ENV_TYPE, gym_new_venv_step_type -from tianshou.policy import Algorithm +from tianshou.policy.base import Policy class DummyDataset(Dataset): @@ -204,7 +203,7 @@ class FiniteSubprocVectorEnv(FiniteVectorEnv, SubprocVectorEnv): pass -class AnyPolicy(Algorithm): +class DummyPolicy(Policy): def __init__(self) -> None: super().__init__(action_space=Box(-1, 1, (1,))) @@ -216,9 +215,6 @@ def forward( ) -> ActBatchProtocol: return cast(ActBatchProtocol, Batch(act=np.stack([1] * len(batch)))) - def _update_with_batch(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> None: - pass - def _finite_env_factory(dataset: Dataset, num_replicas: int, rank: int) -> Callable[[], FiniteEnv]: return lambda: FiniteEnv(dataset, num_replicas, rank) @@ -247,7 +243,7 @@ def validate(self) -> None: def 
test_finite_dummy_vector_env() -> None: dataset = DummyDataset(100) envs = FiniteSubprocVectorEnv([_finite_env_factory(dataset, 5, i) for i in range(5)]) - policy = AnyPolicy() + policy = DummyPolicy() test_collector = Collector[CollectStats](policy, envs, exploration_noise=True) test_collector.reset() @@ -263,7 +259,7 @@ def test_finite_dummy_vector_env() -> None: def test_finite_subproc_vector_env() -> None: dataset = DummyDataset(100) envs = FiniteSubprocVectorEnv([_finite_env_factory(dataset, 5, i) for i in range(5)]) - policy = AnyPolicy() + policy = DummyPolicy() test_collector = Collector[CollectStats](policy, envs, exploration_noise=True) test_collector.reset() diff --git a/test/continuous/test_ddpg.py b/test/continuous/test_ddpg.py index 552287a4a..1f4b11474 100644 --- a/test/continuous/test_ddpg.py +++ b/test/continuous/test_ddpg.py @@ -88,6 +88,7 @@ def test_ddpg(args: argparse.Namespace = get_args()) -> None: critic_optim = torch.optim.Adam(critic.parameters(), lr=args.critic_lr) policy = DDPGPolicy( actor=actor, + exploration_noise=GaussianNoise(sigma=args.exploration_noise), action_space=env.action_space, ) policy_optim = torch.optim.Adam(policy.parameters(), lr=args.actor_lr) @@ -98,7 +99,6 @@ def test_ddpg(args: argparse.Namespace = get_args()) -> None: critic_optim=critic_optim, tau=args.tau, gamma=args.gamma, - exploration_noise=GaussianNoise(sigma=args.exploration_noise), estimation_step=args.n_step, ) # collector diff --git a/test/continuous/test_td3.py b/test/continuous/test_td3.py index 13c7a1db9..e26f1f1c9 100644 --- a/test/continuous/test_td3.py +++ b/test/continuous/test_td3.py @@ -101,6 +101,7 @@ def test_td3(args: argparse.Namespace = get_args()) -> None: policy = DDPGPolicy( actor=actor, action_space=env.action_space, + exploration_noise=GaussianNoise(sigma=args.exploration_noise), ) algorithm: TD3 = TD3( policy=policy, @@ -111,7 +112,6 @@ def test_td3(args: argparse.Namespace = get_args()) -> None: critic2_optim=critic2_optim, 
tau=args.tau, gamma=args.gamma, - exploration_noise=GaussianNoise(sigma=args.exploration_noise), policy_noise=args.policy_noise, update_actor_freq=args.update_actor_freq, noise_clip=args.noise_clip, diff --git a/test/discrete/test_bdqn.py b/test/discrete/test_bdqn.py index 55ce9bc7b..1a9d9efc4 100644 --- a/test/discrete/test_bdqn.py +++ b/test/discrete/test_bdqn.py @@ -123,10 +123,10 @@ def test_bdq(args: argparse.Namespace = get_args()) -> None: def train_fn(epoch: int, env_step: int) -> None: # exp decay eps = max(args.eps_train * (1 - args.eps_decay) ** env_step, args.eps_test) - algorithm.set_eps(eps) + policy.set_eps(eps) def test_fn(epoch: int, env_step: int | None) -> None: - algorithm.set_eps(args.eps_test) + policy.set_eps(args.eps_test) def stop_fn(mean_rewards: float) -> bool: return mean_rewards >= args.reward_threshold diff --git a/test/discrete/test_c51.py b/test/discrete/test_c51.py index 96e4628f9..0023863d5 100644 --- a/test/discrete/test_c51.py +++ b/test/discrete/test_c51.py @@ -140,15 +140,15 @@ def stop_fn(mean_rewards: float) -> bool: def train_fn(epoch: int, env_step: int) -> None: # eps annnealing, just a demo if env_step <= 10000: - algorithm.set_eps(args.eps_train) + policy.set_eps(args.eps_train) elif env_step <= 50000: eps = args.eps_train - (env_step - 10000) / 40000 * (0.9 * args.eps_train) - algorithm.set_eps(eps) + policy.set_eps(eps) else: - algorithm.set_eps(0.1 * args.eps_train) + policy.set_eps(0.1 * args.eps_train) def test_fn(epoch: int, env_step: int | None) -> None: - algorithm.set_eps(args.eps_test) + policy.set_eps(args.eps_test) def save_checkpoint_fn(epoch: int, env_step: int, gradient_step: int) -> str: # see also: https://pytorch.org/tutorials/beginner/saving_loading_models.html diff --git a/test/discrete/test_dqn.py b/test/discrete/test_dqn.py index e6f102120..0f0b1ebb6 100644 --- a/test/discrete/test_dqn.py +++ b/test/discrete/test_dqn.py @@ -129,15 +129,15 @@ def stop_fn(mean_rewards: float) -> bool: def 
train_fn(epoch: int, env_step: int) -> None: # eps annnealing, just a demo if env_step <= 10000: - algorithm.set_eps(args.eps_train) + policy.set_eps(args.eps_train) elif env_step <= 50000: eps = args.eps_train - (env_step - 10000) / 40000 * (0.9 * args.eps_train) - algorithm.set_eps(eps) + policy.set_eps(eps) else: - algorithm.set_eps(0.1 * args.eps_train) + policy.set_eps(0.1 * args.eps_train) def test_fn(epoch: int, env_step: int | None) -> None: - algorithm.set_eps(args.eps_test) + policy.set_eps(args.eps_test) # train result = algorithm.run_training( diff --git a/test/discrete/test_fqf.py b/test/discrete/test_fqf.py index 0f0fec9f9..34c891bd5 100644 --- a/test/discrete/test_fqf.py +++ b/test/discrete/test_fqf.py @@ -147,15 +147,15 @@ def stop_fn(mean_rewards: float) -> bool: def train_fn(epoch: int, env_step: int) -> None: # eps annnealing, just a demo if env_step <= 10000: - algorithm.set_eps(args.eps_train) + policy.set_eps(args.eps_train) elif env_step <= 50000: eps = args.eps_train - (env_step - 10000) / 40000 * (0.9 * args.eps_train) - algorithm.set_eps(eps) + policy.set_eps(eps) else: - algorithm.set_eps(0.1 * args.eps_train) + policy.set_eps(0.1 * args.eps_train) def test_fn(epoch: int, env_step: int | None) -> None: - algorithm.set_eps(args.eps_test) + policy.set_eps(args.eps_test) # train result = algorithm.run_training( diff --git a/test/discrete/test_iqn.py b/test/discrete/test_iqn.py index e2f4eac09..bd6fd21d5 100644 --- a/test/discrete/test_iqn.py +++ b/test/discrete/test_iqn.py @@ -143,15 +143,15 @@ def stop_fn(mean_rewards: float) -> bool: def train_fn(epoch: int, env_step: int) -> None: # eps annnealing, just a demo if env_step <= 10000: - algorithm.set_eps(args.eps_train) + policy.set_eps(args.eps_train) elif env_step <= 50000: eps = args.eps_train - (env_step - 10000) / 40000 * (0.9 * args.eps_train) - algorithm.set_eps(eps) + policy.set_eps(eps) else: - algorithm.set_eps(0.1 * args.eps_train) + policy.set_eps(0.1 * args.eps_train) def 
test_fn(epoch: int, env_step: int | None) -> None: - algorithm.set_eps(args.eps_test) + policy.set_eps(args.eps_test) # train result = algorithm.run_training( diff --git a/test/discrete/test_qrdqn.py b/test/discrete/test_qrdqn.py index 8fa02c417..cc15effdc 100644 --- a/test/discrete/test_qrdqn.py +++ b/test/discrete/test_qrdqn.py @@ -134,15 +134,15 @@ def stop_fn(mean_rewards: float) -> bool: def train_fn(epoch: int, env_step: int) -> None: # eps annnealing, just a demo if env_step <= 10000: - algorithm.set_eps(args.eps_train) + policy.set_eps(args.eps_train) elif env_step <= 50000: eps = args.eps_train - (env_step - 10000) / 40000 * (0.9 * args.eps_train) - algorithm.set_eps(eps) + policy.set_eps(eps) else: - algorithm.set_eps(0.1 * args.eps_train) + policy.set_eps(0.1 * args.eps_train) def test_fn(epoch: int, env_step: int | None) -> None: - algorithm.set_eps(args.eps_test) + policy.set_eps(args.eps_test) # trainer result = algorithm.run_training( diff --git a/test/discrete/test_rainbow.py b/test/discrete/test_rainbow.py index 8b6c6fc15..f1fdf8bcf 100644 --- a/test/discrete/test_rainbow.py +++ b/test/discrete/test_rainbow.py @@ -150,12 +150,12 @@ def stop_fn(mean_rewards: float) -> bool: def train_fn(epoch: int, env_step: int) -> None: # eps annealing, just a demo if env_step <= 10000: - algorithm.set_eps(args.eps_train) + policy.set_eps(args.eps_train) elif env_step <= 50000: eps = args.eps_train - (env_step - 10000) / 40000 * (0.9 * args.eps_train) - algorithm.set_eps(eps) + policy.set_eps(eps) else: - algorithm.set_eps(0.1 * args.eps_train) + policy.set_eps(0.1 * args.eps_train) # beta annealing, just a demo if args.prioritized_replay: if env_step <= 10000: @@ -167,7 +167,7 @@ def train_fn(epoch: int, env_step: int) -> None: buf.set_beta(beta) def test_fn(epoch: int, env_step: int | None) -> None: - algorithm.set_eps(args.eps_test) + policy.set_eps(args.eps_test) def save_checkpoint_fn(epoch: int, env_step: int, gradient_step: int) -> str: # see also: 
https://pytorch.org/tutorials/beginner/saving_loading_models.html diff --git a/tianshou/data/collector.py b/tianshou/data/collector.py index 1c34aa130..0f3fbaf31 100644 --- a/tianshou/data/collector.py +++ b/tianshou/data/collector.py @@ -31,7 +31,7 @@ ) from tianshou.env import BaseVectorEnv, DummyVectorEnv from tianshou.policy import Algorithm -from tianshou.policy.base import episode_mc_return_to_go +from tianshou.policy.base import Policy, episode_mc_return_to_go from tianshou.utils.print import DataclassPPrintMixin from tianshou.utils.torch_utils import torch_train_mode @@ -309,7 +309,7 @@ class BaseCollector(Generic[TCollectStats], ABC): def __init__( self, - algorithm: Algorithm, + policy: Policy | Algorithm, env: BaseVectorEnv | gym.Env, buffer: ReplayBuffer | None = None, exploration_noise: bool = False, @@ -327,7 +327,7 @@ def __init__( self.buffer: ReplayBuffer | ReplayBufferManager = buffer self.raise_on_nan_in_buffer = raise_on_nan_in_buffer - self.algorithm = algorithm + self.policy = policy.policy if isinstance(policy, Algorithm) else policy self.env = cast(BaseVectorEnv, env) self.exploration_noise = exploration_noise self.collect_step, self.collect_episode, self.collect_time = 0, 0, 0.0 @@ -469,7 +469,7 @@ def collect( self.reset(reset_buffer=False, gym_reset_kwargs=gym_reset_kwargs) pre_collect_time = time.time() - with torch_train_mode(self.algorithm, enabled=False): + with torch_train_mode(self.policy, enabled=False): collect_stats = self._collect( n_step=n_step, n_episode=n_episode, @@ -548,7 +548,7 @@ class Collector(BaseCollector[TCollectStats], Generic[TCollectStats]): # def __init__( self, - algorithm: Algorithm, + policy: Policy | Algorithm, env: gym.Env | BaseVectorEnv, buffer: ReplayBuffer | None = None, exploration_noise: bool = False, @@ -558,7 +558,7 @@ def __init__( collect_stats_class: type[TCollectStats] = CollectStats, # type: ignore[assignment] ) -> None: """ - :param algorithm: a tianshou policy, each :class:`BasePolicy` is 
capable of computing a batch + :param policy: a tianshou policy, each :class:`BasePolicy` is capable of computing a batch of actions from a batch of observations. :param env: a ``gymnasium.Env`` environment or a vectorized instance of the :class:`~tianshou.env.BaseVectorEnv` class. The latter is strongly recommended, as with @@ -599,7 +599,7 @@ def __init__( this is rarely necessary and is mainly done by "power users". """ super().__init__( - algorithm, + policy, env, buffer, exploration_noise=exploration_noise, @@ -691,7 +691,7 @@ def _compute_action_policy_hidden( # TODO: test whether envpool env explicitly except TypeError: # envpool's action space is not for per-env act_normalized_RA = np.array([self._action_space.sample() for _ in ready_env_ids_R]) - act_RA = self.algorithm.policy.map_action_inverse(np.array(act_normalized_RA)) + act_RA = self.policy.map_action_inverse(np.array(act_normalized_RA)) policy_R = Batch() hidden_state_RH = None # TODO: instead use a (uniform) Distribution instance that corresponds to sampling from action_space @@ -701,15 +701,15 @@ def _compute_action_policy_hidden( info_batch = _HACKY_create_info_batch(last_info_R) obs_batch_R = cast(ObsBatchProtocol, Batch(obs=last_obs_RO, info=info_batch)) - act_batch_RA: ActBatchProtocol | DistBatchProtocol = self.algorithm.policy( + act_batch_RA: ActBatchProtocol | DistBatchProtocol = self.policy( obs_batch_R, last_hidden_state_RH, ) act_RA = to_numpy(act_batch_RA.act) if self.exploration_noise: - act_RA = self.algorithm.exploration_noise(act_RA, obs_batch_R) - act_normalized_RA = self.algorithm.policy.map_action(act_RA) + act_RA = self.policy.add_exploration_noise(act_RA, obs_batch_R) + act_normalized_RA = self.policy.map_action(act_RA) # TODO: cleanup the whole policy in batch thing # todo policy_R can also be none, check @@ -1084,7 +1084,7 @@ class AsyncCollector(Collector[CollectStats]): def __init__( self, - algorithm: Algorithm, + policy: Algorithm, env: BaseVectorEnv, buffer: 
ReplayBuffer | None = None, exploration_noise: bool = False, @@ -1099,7 +1099,7 @@ def __init__( # assert env.is_async warnings.warn("Using async setting may collect extra transitions into buffer.") super().__init__( - algorithm, + policy, env, buffer, exploration_noise, diff --git a/tianshou/policy/base.py b/tianshou/policy/base.py index 5b6d20a00..bef0fec22 100644 --- a/tianshou/policy/base.py +++ b/tianshou/policy/base.py @@ -43,6 +43,7 @@ logger = logging.getLogger(__name__) TLearningRateScheduler: TypeAlias = torch.optim.lr_scheduler.LRScheduler | MultipleLRSchedulers +TArrOrActBatch = TypeVar("TArrOrActBatch", bound="np.ndarray | ActBatchProtocol") @dataclass(kw_only=True) @@ -357,6 +358,28 @@ def _compile() -> None: _gae_return(f32, f32, f64, b, 0.1, 0.1) _nstep_return(f64, b, f32.reshape(-1, 1), i64, 0.1, 1) + _TArrOrActBatch = TypeVar("_TArrOrActBatch", bound="np.ndarray | ActBatchProtocol") + + def add_exploration_noise( + self, + act: _TArrOrActBatch, + batch: ObsBatchProtocol, + ) -> _TArrOrActBatch: + """(Optionally) adds noise to an actions computed by the policy's forward method for + exploration purposes. + + NOTE: The base implementation does not add any noise, but subclasses can override + this method to add appropriate mechanisms for adding noise. + + :param act: a data batch or numpy.ndarray containing actions computed by the policy's + forward method. + :param batch: the corresponding input batch that was passed to forward; provided for + advanced usage. + :return: actions in the same format as the input `act` but with added exploration + noise (if implemented - otherwise returns `act` unchanged). 
+ """ + return act + TPolicy = TypeVar("TPolicy", bound=Policy) TTrainingConfig = TypeVar( @@ -439,31 +462,6 @@ def set_agent_id(self, agent_id: int) -> None: """Set self.agent_id = agent_id, for MARL.""" self.agent_id = agent_id - # TODO: needed, since for most of offline algorithm, the algorithm itself doesn't - # have a method to add noise to action. - # So we add the default behavior here. It's a little messy, maybe one can - # find a better way to do this. - - _TArrOrActBatch = TypeVar("_TArrOrActBatch", bound="np.ndarray | ActBatchProtocol") - - def exploration_noise( - self, - act: _TArrOrActBatch, - batch: ObsBatchProtocol, - ) -> _TArrOrActBatch: - """Modify the action from policy.forward with exploration noise. - - NOTE: currently does not add any noise! Needs to be overridden by subclasses - to actually do something. - - :param act: a data batch or numpy.ndarray which is the action taken by - policy.forward. - :param batch: the input batch for policy.forward, kept for advanced usage. - :return: action in the same form of input "act" but with added exploration - noise. - """ - return act - def _polyak_parameter_update(self, tgt: nn.Module, src: nn.Module, tau: float) -> None: """Softly updates the parameters of a target network `tgt` with the parameters of a source network `src` using Polyak averaging: `tau * src + (1 - tau) * tgt`. 
diff --git a/tianshou/policy/modelbased/icm.py b/tianshou/policy/modelbased/icm.py index ab6a452e1..c3611bd31 100644 --- a/tianshou/policy/modelbased/icm.py +++ b/tianshou/policy/modelbased/icm.py @@ -106,12 +106,14 @@ def forward( _TArrOrActBatch = TypeVar("_TArrOrActBatch", bound="np.ndarray | ActBatchProtocol") - def exploration_noise( + # TODO move to policy + # @override + def add_exploration_noise( self, act: _TArrOrActBatch, batch: ObsBatchProtocol, ) -> _TArrOrActBatch: - return self.policy.exploration_noise(act, batch) + return self.policy.add_exploration_noise(act, batch) def set_eps(self, eps: float) -> None: """Set the eps for epsilon-greedy exploration.""" diff --git a/tianshou/policy/modelfree/bdqn.py b/tianshou/policy/modelfree/bdqn.py index e93645a08..08ffa5adf 100644 --- a/tianshou/policy/modelfree/bdqn.py +++ b/tianshou/policy/modelfree/bdqn.py @@ -16,7 +16,7 @@ RolloutBatchProtocol, ) from tianshou.policy import DeepQLearning -from tianshou.policy.base import TLearningRateScheduler +from tianshou.policy.base import TArrOrActBatch, TLearningRateScheduler from tianshou.policy.modelfree.dqn import DQNPolicy, DQNTrainingStats from tianshou.utils.net.common import BranchingNet @@ -31,7 +31,7 @@ class BDQNTrainingStats(DQNTrainingStats): TBDQNTrainingStats = TypeVar("TBDQNTrainingStats", bound=BDQNTrainingStats) -class BranchingDuelingQNetworkPolicy(DQNPolicy): +class BranchingDuelingQNetworkPolicy(DQNPolicy[BranchingNet]): def __init__( self, *, @@ -67,6 +67,25 @@ def forward( result = Batch(logits=action_values_BA, act=act_B, state=hidden_BH) return cast(ModelOutputBatchProtocol, result) + def add_exploration_noise( + self, + act: TArrOrActBatch, + batch: ObsBatchProtocol, + ) -> TArrOrActBatch: + # TODO: This looks problematic; the non-array case is silently ignored + if isinstance(act, np.ndarray) and not np.isclose(self.eps, 0.0): + bsz = len(act) + rand_mask = np.random.rand(bsz) < self.eps + rand_act = np.random.randint( + low=0, + 
high=self.model.action_per_branch, + size=(bsz, act.shape[-1]), + ) + if hasattr(batch.obs, "mask"): + rand_act += batch.obs.mask + act[rand_mask] = rand_act[rand_mask] + return act + class BranchingDuelingQNetwork(DeepQLearning[BranchingDuelingQNetworkPolicy, TBDQNTrainingStats]): """Implementation of the Branching Dueling Q-Network algorithm arXiv:1711.08946.""" @@ -114,16 +133,6 @@ def __init__( lr_scheduler=lr_scheduler, ) - # TODO: this used to be a public property called max_action_num, - # but it collides with an attr of the same name in base class - @property - def _action_per_branch(self) -> int: - return self.policy.model.action_per_branch - - @property - def _num_branches(self) -> int: - return self.policy.model.num_branches - def _target_q(self, buffer: ReplayBuffer, indices: np.ndarray) -> torch.Tensor: obs_next_batch = Batch( obs=buffer[indices].obs_next, @@ -158,8 +167,8 @@ def _compute_return( end_flag = end_flag[indice] mean_target_q = np.mean(target_q, -1) if len(target_q.shape) > 1 else target_q _target_q = rew + gamma * mean_target_q * (1 - end_flag) - target_q = np.repeat(_target_q[..., None], self._num_branches, axis=-1) - target_q = np.repeat(target_q[..., None], self._action_per_branch, axis=-1) + target_q = np.repeat(_target_q[..., None], self.policy.model.num_branches, axis=-1) + target_q = np.repeat(target_q[..., None], self.policy.model.action_per_branch, axis=-1) batch.returns = to_torch_as(target_q, target_q_torch) if hasattr(batch, "weight"): # prio buffer update @@ -200,23 +209,3 @@ def _update_with_batch( self._iter += 1 return BDQNTrainingStats(loss=loss.item()) # type: ignore[return-value] - - _TArrOrActBatch = TypeVar("_TArrOrActBatch", bound="np.ndarray | ActBatchProtocol") - - def exploration_noise( - self, - act: _TArrOrActBatch, - batch: ObsBatchProtocol, - ) -> _TArrOrActBatch: - if isinstance(act, np.ndarray) and not np.isclose(self.eps, 0.0): - bsz = len(act) - rand_mask = np.random.rand(bsz) < self.eps - rand_act = 
np.random.randint( - low=0, - high=self._action_per_branch, - size=(bsz, act.shape[-1]), - ) - if hasattr(batch.obs, "mask"): - rand_act += batch.obs.mask - act[rand_mask] = rand_act[rand_mask] - return act diff --git a/tianshou/policy/modelfree/ddpg.py b/tianshou/policy/modelfree/ddpg.py index 22fae5ecb..5fdc375a0 100644 --- a/tianshou/policy/modelfree/ddpg.py +++ b/tianshou/policy/modelfree/ddpg.py @@ -22,6 +22,7 @@ from tianshou.policy.base import ( OffPolicyAlgorithm, Policy, + TArrOrActBatch, TLearningRateScheduler, TPolicy, TrainingStats, @@ -41,11 +42,59 @@ class DDPGTrainingStats(TrainingStats): TDDPGTrainingStats = TypeVar("TDDPGTrainingStats", bound=DDPGTrainingStats) -class DDPGPolicy(Policy): +class ContinuousPolicyWithExplorationNoise(Policy, ABC): + def __init__( + self, + *, + exploration_noise: BaseNoise | Literal["default"] | None = None, + action_space: gym.Space, + observation_space: gym.Space | None = None, + action_scaling: bool = True, + action_bound_method: Literal["clip"] | None = "clip", + ): + """ + :param exploration_noise: noise model for adding noise to continuous actions + for exploration. This is useful when solving "hard exploration" problems. + "default" is equivalent to GaussianNoise(sigma=0.1). + :param action_space: Env's action space. + :param observation_space: Env's observation space. + :param action_scaling: if True, scale the action from [-1, 1] to the range + of action_space. Only used if the action_space is continuous. + :param action_bound_method: method to bound action to range [-1, 1]. 
+ """ + super().__init__( + action_space=action_space, + observation_space=observation_space, + action_scaling=action_scaling, + action_bound_method=action_bound_method, + ) + if exploration_noise == "default": + exploration_noise = GaussianNoise(sigma=0.1) + self.exploration_noise = exploration_noise + + def set_exploration_noise(self, noise: BaseNoise | None) -> None: + """Set the exploration noise.""" + self.exploration_noise = noise + + def add_exploration_noise( + self, + act: TArrOrActBatch, + batch: ObsBatchProtocol, + ) -> TArrOrActBatch: + if self.exploration_noise is None: + return act + if isinstance(act, np.ndarray): + return act + self.exploration_noise(act.shape) + warnings.warn("Cannot add exploration noise to non-numpy_array action.") + return act + + +class DDPGPolicy(ContinuousPolicyWithExplorationNoise): def __init__( self, *, actor: torch.nn.Module | Actor, + exploration_noise: BaseNoise | Literal["default"] | None = None, action_space: gym.Space, observation_space: gym.Space | None = None, action_scaling: bool = True, @@ -53,6 +102,10 @@ def __init__( ): """ :param actor: The actor network following the rules (s -> actions) + :param exploration_noise: add noise to continuous actions for exploration; + set to None for discrete action spaces. + This is useful when solving "hard exploration" problems. + "default" is equivalent to GaussianNoise(sigma=0.1). :param action_space: Env's action space. :param tau: Param for soft update of the target network. :param observation_space: Env's observation space. 
@@ -69,6 +122,7 @@ def __init__( "or set action_scaling to False and action_bound_method to None.", ) super().__init__( + exploration_noise=exploration_noise, action_space=action_space, observation_space=observation_space, action_scaling=action_scaling, @@ -133,7 +187,6 @@ def __init__( critic_optim: torch.optim.Optimizer, tau: float = 0.005, gamma: float = 0.99, - exploration_noise: BaseNoise | Literal["default"] | None = None, estimation_step: int = 1, lr_scheduler: TLearningRateScheduler | None = None, ) -> None: @@ -148,10 +201,6 @@ def __init__( :param critic_optim: the optimizer for the critic network. :param tau: param for soft update of the target network. :param gamma: discount factor, in [0, 1]. - :param exploration_noise: add noise to continuous actions for exploration; - set to None for discrete action spaces. - This is useful when solving "hard exploration" problems. - "default" is equivalent to GaussianNoise(sigma=0.1). :param lr_scheduler: a learning rate scheduler that adjusts the learning rate in optimizer in each policy.update() """ @@ -168,32 +217,8 @@ def __init__( self.critic_optim = critic_optim self.tau = tau self.gamma = gamma - if exploration_noise == "default": - exploration_noise = GaussianNoise(sigma=0.1) - # TODO: IMPORTANT - can't call this "exploration_noise" because confusingly, - # there is already a method called exploration_noise() in the base class - # Now this method doesn't apply any noise and is also not overridden. 
See TODO there - self._exploration_noise = exploration_noise self.estimation_step = estimation_step - def set_exp_noise(self, noise: BaseNoise | None) -> None: - """Set the exploration noise.""" - self._exploration_noise = noise - - _TArrOrActBatch = TypeVar("_TArrOrActBatch", bound="np.ndarray | ActBatchProtocol") - - def exploration_noise( - self, - act: _TArrOrActBatch, - batch: ObsBatchProtocol, - ) -> _TArrOrActBatch: - if self._exploration_noise is None: - return act - if isinstance(act, np.ndarray): - return act + self._exploration_noise(act.shape) - warnings.warn("Cannot add exploration noise to non-numpy_array action.") - return act - @staticmethod def _minimize_critic_squared_loss( batch: RolloutBatchProtocol, @@ -299,7 +324,6 @@ def __init__( critic_optim: torch.optim.Optimizer, tau: float = 0.005, gamma: float = 0.99, - exploration_noise: BaseNoise | Literal["default"] | None = "default", estimation_step: int = 1, lr_scheduler: TLearningRateScheduler | None = None, ) -> None: @@ -310,8 +334,6 @@ def __init__( :param critic_optim: The optimizer for critic network. :param tau: Param for soft update of the target network. :param gamma: Discount factor, in [0, 1]. - :param exploration_noise: The exploration noise, added to the action. Defaults - to ``GaussianNoise(sigma=0.1)``. :param estimation_step: The number of steps to look ahead. :param lr_scheduler: if not None, will be called in `policy.update()`. 
""" @@ -323,7 +345,6 @@ def __init__( critic_optim=critic_optim, tau=tau, gamma=gamma, - exploration_noise=exploration_noise, estimation_step=estimation_step, ) self.actor_old = deepcopy(policy.actor) diff --git a/tianshou/policy/modelfree/discrete_sac.py b/tianshou/policy/modelfree/discrete_sac.py index a3a879e83..8e808fe22 100644 --- a/tianshou/policy/modelfree/discrete_sac.py +++ b/tianshou/policy/modelfree/discrete_sac.py @@ -119,7 +119,6 @@ def __init__( tau=tau, gamma=gamma, estimation_step=estimation_step, - exploration_noise=None, lr_scheduler=lr_scheduler, ) self.alpha = FixedAlpha(alpha) if isinstance(alpha, float) else alpha diff --git a/tianshou/policy/modelfree/dqn.py b/tianshou/policy/modelfree/dqn.py index c348daeb6..9536e7ca2 100644 --- a/tianshou/policy/modelfree/dqn.py +++ b/tianshou/policy/modelfree/dqn.py @@ -19,6 +19,7 @@ from tianshou.policy.base import ( OffPolicyAlgorithm, Policy, + TArrOrActBatch, TLearningRateScheduler, TrainingStats, ) @@ -33,13 +34,14 @@ class DQNTrainingStats(TrainingStats): TDQNTrainingStats = TypeVar("TDQNTrainingStats", bound=DQNTrainingStats) +TModel = TypeVar("TModel", bound=torch.nn.Module | Net) -class DQNPolicy(Policy): +class DQNPolicy(Policy, Generic[TModel]): def __init__( self, *, - model: torch.nn.Module | Net, + model: TModel, action_space: gym.spaces.Discrete, observation_space: gym.Space | None = None, ) -> None: @@ -56,6 +58,11 @@ def __init__( ) self.model = model self.max_action_num: int | None = None + self.eps = 0.0 + + def set_eps(self, eps: float) -> None: + """Set the eps for epsilon-greedy exploration.""" + self.eps = eps def forward( self, @@ -112,6 +119,25 @@ def compute_q_value(self, logits: torch.Tensor, mask: np.ndarray | None) -> torc logits = logits + to_torch_as(1 - mask, logits) * min_value return logits + def add_exploration_noise( + self, + act: TArrOrActBatch, + batch: ObsBatchProtocol, + ) -> TArrOrActBatch: + # TODO: This looks problematic; the non-array case is silently ignored + 
if isinstance(act, np.ndarray) and not np.isclose(self.eps, 0.0): + bsz = len(act) + rand_mask = np.random.rand(bsz) < self.eps + assert ( + self.max_action_num is not None + ), "Can't call this method before max_action_num was set in first forward" + q = np.random.rand(bsz, self.max_action_num) # [0, 1] + if hasattr(batch.obs, "mask"): + q += batch.obs.mask + rand_act = q.argmax(axis=1) + act[rand_mask] = rand_act[rand_mask] + return act + TDQNPolicy = TypeVar("TDQNPolicy", bound=DQNPolicy) @@ -132,7 +158,6 @@ def __init__( *, policy: TDQNPolicy, optim: torch.optim.Optimizer, - # TODO: type violates Liskov substitution principle discount_factor: float = 0.99, estimation_step: int = 1, target_update_freq: int = 0, @@ -161,7 +186,6 @@ def __init__( lr_scheduler=lr_scheduler, ) self.optim = optim - self.eps = 0.0 assert ( 0.0 <= discount_factor <= 1.0 ), f"discount factor should be in [0, 1] but got: {discount_factor}" @@ -180,11 +204,6 @@ def __init__( self.is_double = is_double self.clip_loss_grad = clip_loss_grad - # TODO: Should use two eps parameters: one for training, one for inference/testing - def set_eps(self, eps: float) -> None: - """Set the eps for epsilon-greedy exploration.""" - self.eps = eps - def train(self, mode: bool = True) -> Self: """Set the module in training mode, except for the target network.""" # TODO: Determine whether this is called correctly and who relies on this being called (for all subclasses) @@ -261,23 +280,3 @@ def _update_with_batch( self._iter += 1 return DQNTrainingStats(loss=loss.item()) # type: ignore[return-value] - - _TArrOrActBatch = TypeVar("_TArrOrActBatch", bound="np.ndarray | ActBatchProtocol") - - def exploration_noise( - self, - act: _TArrOrActBatch, - batch: ObsBatchProtocol, - ) -> _TArrOrActBatch: - if isinstance(act, np.ndarray) and not np.isclose(self.eps, 0.0): - bsz = len(act) - rand_mask = np.random.rand(bsz) < self.eps - assert ( - self.policy.max_action_num is not None - ), "Can't call this method before 
max_action_num was set in first forward" - q = np.random.rand(bsz, self.policy.max_action_num) # [0, 1] - if hasattr(batch.obs, "mask"): - q += batch.obs.mask - rand_act = q.argmax(axis=1) - act[rand_mask] = rand_act[rand_mask] - return act diff --git a/tianshou/policy/modelfree/redq.py b/tianshou/policy/modelfree/redq.py index 1638242d8..92891d50a 100644 --- a/tianshou/policy/modelfree/redq.py +++ b/tianshou/policy/modelfree/redq.py @@ -13,9 +13,10 @@ RolloutBatchProtocol, ) from tianshou.exploration import BaseNoise -from tianshou.policy.base import Policy, TLearningRateScheduler +from tianshou.policy.base import TLearningRateScheduler from tianshou.policy.modelfree.ddpg import ( ActorCriticOffPolicyAlgorithm, + ContinuousPolicyWithExplorationNoise, DDPGTrainingStats, ) from tianshou.policy.modelfree.sac import Alpha, FixedAlpha @@ -33,11 +34,12 @@ class REDQTrainingStats(DDPGTrainingStats): TREDQTrainingStats = TypeVar("TREDQTrainingStats", bound=REDQTrainingStats) -class REDQPolicy(Policy): +class REDQPolicy(ContinuousPolicyWithExplorationNoise): def __init__( self, *, actor: torch.nn.Module | ActorProb, + exploration_noise: BaseNoise | Literal["default"] | None = None, action_space: gym.spaces.Box, deterministic_eval: bool = True, action_scaling: bool = True, @@ -59,6 +61,7 @@ def __init__( Only used if the action_space is continuous. 
""" super().__init__( + exploration_noise=exploration_noise, action_space=action_space, observation_space=observation_space, action_scaling=action_scaling, @@ -116,7 +119,6 @@ def __init__( alpha: float | Alpha = 0.2, estimation_step: int = 1, actor_delay: int = 20, - exploration_noise: BaseNoise | Literal["default"] | None = None, deterministic_eval: bool = True, target_mode: Literal["mean", "min"] = "min", lr_scheduler: TLearningRateScheduler | None = None, @@ -152,7 +154,6 @@ def __init__( critic_optim=critic_optim, tau=tau, gamma=gamma, - exploration_noise=exploration_noise, estimation_step=estimation_step, lr_scheduler=lr_scheduler, ) diff --git a/tianshou/policy/modelfree/sac.py b/tianshou/policy/modelfree/sac.py index 3add77a86..244f1368f 100644 --- a/tianshou/policy/modelfree/sac.py +++ b/tianshou/policy/modelfree/sac.py @@ -14,7 +14,8 @@ RolloutBatchProtocol, ) from tianshou.exploration import BaseNoise -from tianshou.policy.base import Policy, TLearningRateScheduler, TrainingStats +from tianshou.policy.base import TLearningRateScheduler, TrainingStats +from tianshou.policy.modelfree.ddpg import ContinuousPolicyWithExplorationNoise from tianshou.policy.modelfree.td3 import ActorDualCriticsOffPolicyAlgorithm from tianshou.utils.conversion import to_optional_float from tianshou.utils.net.continuous import ActorProb @@ -49,11 +50,12 @@ class SACTrainingStats(TrainingStats): TSACTrainingStats = TypeVar("TSACTrainingStats", bound=SACTrainingStats) -class SACPolicy(Policy): +class SACPolicy(ContinuousPolicyWithExplorationNoise): def __init__( self, *, actor: torch.nn.Module | ActorProb, + exploration_noise: BaseNoise | Literal["default"] | None = None, deterministic_eval: bool = True, action_scaling: bool = True, action_bound_method: Literal["clip"] | None = "clip", @@ -62,6 +64,9 @@ def __init__( ): """ :param actor: the actor network following the rules (s -> dist_input_BD) + :param exploration_noise: add noise to action for exploration. 
+ This is useful when solving "hard exploration" problems. + "default" is equivalent to GaussianNoise(sigma=0.1). :param deterministic_eval: whether to use deterministic action (mode of Gaussian policy) in evaluation mode instead of stochastic action sampled by the policy. Does not affect training. @@ -76,6 +81,7 @@ def __init__( :param observation_space: the observation space of the environment """ super().__init__( + exploration_noise=exploration_noise, action_space=action_space, observation_space=observation_space, action_scaling=action_scaling, @@ -204,7 +210,6 @@ def __init__( gamma: float = 0.99, alpha: float | Alpha = 0.2, estimation_step: int = 1, - exploration_noise: BaseNoise | Literal["default"] | None = None, deterministic_eval: bool = True, lr_scheduler: TLearningRateScheduler | None = None, ) -> None: @@ -222,9 +227,6 @@ def __init__( :param alpha: the entropy regularization coefficient alpha or an object which can be used to automatically tune it (e.g. an instance of `AutoAlpha`). :param estimation_step: The number of steps to look ahead. - :param exploration_noise: add noise to action for exploration. - This is useful when solving "hard exploration" problems. - "default" is equivalent to GaussianNoise(sigma=0.1). 
:param lr_scheduler: a learning rate scheduler that adjusts the learning rate in optimizer in each policy.update() """ @@ -237,7 +239,6 @@ def __init__( critic2_optim=critic2_optim, tau=tau, gamma=gamma, - exploration_noise=exploration_noise, estimation_step=estimation_step, lr_scheduler=lr_scheduler, ) diff --git a/tianshou/policy/modelfree/td3.py b/tianshou/policy/modelfree/td3.py index 139c52840..7ccf9c48b 100644 --- a/tianshou/policy/modelfree/td3.py +++ b/tianshou/policy/modelfree/td3.py @@ -55,7 +55,6 @@ def __init__( critic2_optim: torch.optim.Optimizer, tau: float = 0.005, gamma: float = 0.99, - exploration_noise: BaseNoise | Literal["default"] | None = None, estimation_step: int = 1, lr_scheduler: TLearningRateScheduler | None = None, ) -> None: @@ -73,10 +72,6 @@ def __init__( If None, clone critic_optim to use for critic2.parameters(). :param tau: param for soft update of the target network. :param gamma: discount factor, in [0, 1]. - :param exploration_noise: add noise to continuous actions for exploration; - set to None for discrete action spaces. - This is useful when solving "hard exploration" problems. - "default" is equivalent to GaussianNoise(sigma=0.1). :param lr_scheduler: a learning rate scheduler that adjusts the learning rate in optimizer in each policy.update() """ @@ -88,7 +83,6 @@ def __init__( critic_optim=critic_optim, tau=tau, gamma=gamma, - exploration_noise=exploration_noise, estimation_step=estimation_step, ) if critic2 and not critic2_optim: @@ -136,7 +130,6 @@ def __init__( critic2_optim: torch.optim.Optimizer | None = None, tau: float = 0.005, gamma: float = 0.99, - exploration_noise: BaseNoise | Literal["default"] | None = "default", policy_noise: float = 0.2, update_actor_freq: int = 2, noise_clip: float = 0.5, @@ -154,9 +147,6 @@ def __init__( If None, clone critic_optim to use for critic2.parameters(). :param tau: param for soft update of the target network. :param gamma: discount factor, in [0, 1]. 
- :param exploration_noise: add noise to action for exploration. - This is useful when solving "hard exploration" problems. - "default" is equivalent to GaussianNoise(sigma=0.1). :param policy_noise: the noise used in updating policy network. :param update_actor_freq: the update frequency of actor network. :param noise_clip: the clipping range used in updating policy network. @@ -172,7 +162,6 @@ def __init__( critic2_optim=critic2_optim, tau=tau, gamma=gamma, - exploration_noise=exploration_noise, estimation_step=estimation_step, lr_scheduler=lr_scheduler, ) diff --git a/tianshou/policy/multiagent/mapolicy.py b/tianshou/policy/multiagent/mapolicy.py index 083ace04d..4df09473b 100644 --- a/tianshou/policy/multiagent/mapolicy.py +++ b/tianshou/policy/multiagent/mapolicy.py @@ -162,7 +162,9 @@ def process_fn( # type: ignore _TArrOrActBatch = TypeVar("_TArrOrActBatch", bound="np.ndarray | ActBatchProtocol") - def exploration_noise( + # TODO: Move to policy + # @override + def add_exploration_noise( self, act: _TArrOrActBatch, batch: ObsBatchProtocol, @@ -176,7 +178,7 @@ def exploration_noise( agent_index = np.nonzero(batch.obs.agent_id == agent_id)[0] if len(agent_index) == 0: continue - act[agent_index] = policy.exploration_noise(act[agent_index], batch[agent_index]) + act[agent_index] = policy.add_exploration_noise(act[agent_index], batch[agent_index]) return act def forward( # type: ignore diff --git a/tianshou/trainer/base.py b/tianshou/trainer/base.py index d5fc6f78a..2bd212643 100644 --- a/tianshou/trainer/base.py +++ b/tianshou/trainer/base.py @@ -786,10 +786,10 @@ def reset(self, reset_collectors: bool = True, reset_collector_buffers: bool = F if ( self.config.test_in_train - and self.config.train_collector.algorithm is not self.algorithm + and self.config.train_collector.policy is not self.algorithm.policy ): log.warning( - "The training data collector's algorithm is not the same as the one being trained, " + "The training data collector's policy is not the 
same as the one being trained, " "yet test_in_train is enabled. This may lead to unexpected results." ) From 34965c75bef3a1b366467e9fc4a1ddee84bd007b Mon Sep 17 00:00:00 2001 From: Dominik Jain Date: Tue, 11 Mar 2025 01:16:49 +0100 Subject: [PATCH 044/230] v2: Adapt DiscreteCRR, test_discrete_crr and gather_cartpole_data * Fix hierarchy: DiscreteCRR (offline algorithm) was derived from Reinforce (on-policy) * Factor out DiscountedReturnComputation to avoid code duplication * Add specialization DiscreteActorPolicy and util function for default discrete dist_fn --- CHANGELOG.md | 1 + examples/offline/atari_crr.py | 4 +- test/offline/gather_cartpole_data.py | 52 ++++---- test/offline/test_discrete_crr.py | 39 +++--- tianshou/policy/__init__.py | 4 +- tianshou/policy/imitation/discrete_crr.py | 109 ++++++++-------- tianshou/policy/modelfree/a2c.py | 2 +- tianshou/policy/modelfree/pg.py | 145 +++++++++++++++++----- tianshou/trainer/base.py | 6 +- tianshou/utils/net/discrete.py | 5 + 10 files changed, 233 insertions(+), 134 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 417e39cce..e16f86ac4 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -70,6 +70,7 @@ * Introduced base classes (to retain factorization without abusive inheritance): * `ActorCriticOffPolicyAlgorithm` * `ActorDualCriticsOffPolicyAlgorithm` (extends `ActorCriticOffPolicyAlgorithm`) + * `DiscreteCRR`: Inherit directly from `OfflineAlgorithm` instead of `Reinforce` (on-policy) * `NPG`: Inherit from `AbstractActorCriticWithAdvantage` instead of `A2C` (which is now has the same base class) * `REDQ`: Inherit from `ActorCriticOffPolicyAlgorithm` instead of `DDPG` * `SAC`: Inherit from `ActorDualCriticsOffPolicyAlgorithm` instead of `DDPG` diff --git a/examples/offline/atari_crr.py b/examples/offline/atari_crr.py index cd77730f6..237ac2f44 100644 --- a/examples/offline/atari_crr.py +++ b/examples/offline/atari_crr.py @@ -16,7 +16,7 @@ from examples.offline.utils import load_buffer from tianshou.data 
import Collector, CollectStats, VectorReplayBuffer from tianshou.highlevel.logger import LoggerFactoryDefault -from tianshou.policy import DiscreteCRRPolicy +from tianshou.policy import DiscreteCRR from tianshou.policy.base import Algorithm from tianshou.trainer import OfflineTrainer from tianshou.utils.net.common import ActorCritic @@ -122,7 +122,7 @@ def test_discrete_crr(args: argparse.Namespace = get_args()) -> None: actor_critic = ActorCritic(actor, critic) optim = torch.optim.Adam(actor_critic.parameters(), lr=args.lr) # define policy - policy: DiscreteCRRPolicy = DiscreteCRRPolicy( + policy: DiscreteCRR = DiscreteCRR( actor=actor, critic=critic, optim=optim, diff --git a/test/offline/gather_cartpole_data.py b/test/offline/gather_cartpole_data.py index 11d0f32e7..8ec3c1dec 100644 --- a/test/offline/gather_cartpole_data.py +++ b/test/offline/gather_cartpole_data.py @@ -16,8 +16,8 @@ from tianshou.env import DummyVectorEnv from tianshou.policy import QRDQN from tianshou.policy.base import Algorithm -from tianshou.policy.modelfree.qrdqn import QRDQNTrainingStats -from tianshou.trainer import OffPolicyTrainer +from tianshou.policy.modelfree.qrdqn import QRDQNPolicy +from tianshou.trainer import OffPolicyTrainingConfig from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Net from tianshou.utils.space_info import SpaceInfo @@ -97,10 +97,13 @@ def gather_data() -> VectorReplayBuffer | PrioritizedVectorReplayBuffer: num_atoms=args.num_quantiles, ) optim = torch.optim.Adam(net.parameters(), lr=args.lr) - policy: QRDQN[QRDQNTrainingStats] = QRDQN( + policy = QRDQNPolicy( model=net, - optim=optim, action_space=env.action_space, + ) + algorithm = QRDQN( + policy=policy, + optim=optim, discount_factor=args.gamma, num_quantiles=args.num_quantiles, estimation_step=args.n_step, @@ -118,9 +121,9 @@ def gather_data() -> VectorReplayBuffer | PrioritizedVectorReplayBuffer: else: buf = VectorReplayBuffer(args.buffer_size, buffer_num=len(train_envs)) 
# collector - train_collector = Collector[CollectStats](policy, train_envs, buf, exploration_noise=True) + train_collector = Collector[CollectStats](algorithm, train_envs, buf, exploration_noise=True) train_collector.reset() - test_collector = Collector[CollectStats](policy, test_envs, exploration_noise=True) + test_collector = Collector[CollectStats](algorithm, test_envs, exploration_noise=True) test_collector.reset() # policy.set_eps(1) train_collector.collect(n_step=args.batch_size * args.training_num) @@ -148,29 +151,30 @@ def train_fn(epoch: int, env_step: int) -> None: def test_fn(epoch: int, env_step: int | None) -> None: policy.set_eps(args.eps_test) - # trainer - result = OffPolicyTrainer( - policy=policy, - train_collector=train_collector, - test_collector=test_collector, - max_epoch=args.epoch, - step_per_epoch=args.step_per_epoch, - step_per_collect=args.step_per_collect, - episode_per_test=args.test_num, - batch_size=args.batch_size, - train_fn=train_fn, - test_fn=test_fn, - stop_fn=stop_fn, - save_best_fn=save_best_fn, - logger=logger, - update_per_step=args.update_per_step, - ).run() + # train + result = algorithm.run_training( + OffPolicyTrainingConfig( + train_collector=train_collector, + test_collector=test_collector, + max_epoch=args.epoch, + step_per_epoch=args.step_per_epoch, + step_per_collect=args.step_per_collect, + episode_per_test=args.test_num, + batch_size=args.batch_size, + train_fn=train_fn, + test_fn=test_fn, + stop_fn=stop_fn, + save_best_fn=save_best_fn, + logger=logger, + update_per_step=args.update_per_step, + ) + ) assert stop_fn(result.best_reward) # save buffer in pickle format, for imitation learning unittest buf = VectorReplayBuffer(args.buffer_size, buffer_num=len(test_envs)) policy.set_eps(0.2) - collector = Collector[CollectStats](policy, test_envs, buf, exploration_noise=True) + collector = Collector[CollectStats](algorithm, test_envs, buf, exploration_noise=True) collector.reset() collector_stats = 
collector.collect(n_step=args.buffer_size) if args.save_buffer_name.endswith(".hdf5"): diff --git a/test/offline/test_discrete_crr.py b/test/offline/test_discrete_crr.py index f13bc006d..a873df1a0 100644 --- a/test/offline/test_discrete_crr.py +++ b/test/offline/test_discrete_crr.py @@ -15,8 +15,9 @@ VectorReplayBuffer, ) from tianshou.env import DummyVectorEnv -from tianshou.policy import Algorithm, DiscreteCRRPolicy -from tianshou.trainer import OfflineTrainer +from tianshou.policy import Algorithm, DiscreteCRR +from tianshou.policy.modelfree.pg import DiscreteActorPolicy +from tianshou.trainer import OfflineTrainingConfig from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import ActorCritic, Net from tianshou.utils.net.discrete import Actor, Critic @@ -85,11 +86,14 @@ def test_discrete_crr(args: argparse.Namespace = get_args()) -> None: actor_critic = ActorCritic(actor, critic) optim = torch.optim.Adam(actor_critic.parameters(), lr=args.lr) - policy: DiscreteCRRPolicy = DiscreteCRRPolicy( + policy = DiscreteActorPolicy( actor=actor, + action_space=env.action_space, + ) + algorithm: DiscreteCRR = DiscreteCRR( + policy=policy, critic=critic, optim=optim, - action_space=env.action_space, discount_factor=args.gamma, target_update_freq=args.target_update_freq, ).to(args.device) @@ -105,7 +109,7 @@ def test_discrete_crr(args: argparse.Namespace = get_args()) -> None: buffer = gather_data() # collector - test_collector = Collector[CollectStats](policy, test_envs, exploration_noise=True) + test_collector = Collector[CollectStats](algorithm, test_envs, exploration_noise=True) log_path = os.path.join(args.logdir, args.task, "discrete_crr") writer = SummaryWriter(log_path) @@ -117,17 +121,18 @@ def save_best_fn(policy: Algorithm) -> None: def stop_fn(mean_rewards: float) -> bool: return mean_rewards >= args.reward_threshold - result = OfflineTrainer( - policy=policy, - buffer=buffer, - test_collector=test_collector, - max_epoch=args.epoch, - 
step_per_epoch=args.update_per_epoch, - episode_per_test=args.test_num, - batch_size=args.batch_size, - stop_fn=stop_fn, - save_best_fn=save_best_fn, - logger=logger, - ).run() + result = algorithm.run_training( + OfflineTrainingConfig( + buffer=buffer, + test_collector=test_collector, + max_epoch=args.epoch, + step_per_epoch=args.update_per_epoch, + episode_per_test=args.test_num, + batch_size=args.batch_size, + stop_fn=stop_fn, + save_best_fn=save_best_fn, + logger=logger, + ) + ) assert stop_fn(result.best_reward) diff --git a/tianshou/policy/__init__.py b/tianshou/policy/__init__.py index b8ee50547..aecb82ec7 100644 --- a/tianshou/policy/__init__.py +++ b/tianshou/policy/__init__.py @@ -27,7 +27,7 @@ from tianshou.policy.imitation.td3_bc import TD3BCPolicy from tianshou.policy.imitation.discrete_bcq import DiscreteBCQPolicy from tianshou.policy.imitation.discrete_cql import DiscreteCQLPolicy -from tianshou.policy.imitation.discrete_crr import DiscreteCRRPolicy +from tianshou.policy.imitation.discrete_crr import DiscreteCRR from tianshou.policy.imitation.gail import GAILPolicy from tianshou.policy.modelbased.psrl import PSRLPolicy from tianshou.policy.modelbased.icm import ICMPolicy @@ -59,7 +59,7 @@ "TD3BCPolicy", "DiscreteBCQPolicy", "DiscreteCQLPolicy", - "DiscreteCRRPolicy", + "DiscreteCRR", "GAILPolicy", "PSRLPolicy", "ICMPolicy", diff --git a/tianshou/policy/imitation/discrete_crr.py b/tianshou/policy/imitation/discrete_crr.py index 5c7395ff8..599f19c82 100644 --- a/tianshou/policy/imitation/discrete_crr.py +++ b/tianshou/policy/imitation/discrete_crr.py @@ -2,16 +2,20 @@ from dataclasses import dataclass from typing import Any, Literal, TypeVar -import gymnasium as gym +import numpy as np import torch import torch.nn.functional as F from torch.distributions import Categorical -from tianshou.data import to_torch, to_torch_as -from tianshou.data.types import RolloutBatchProtocol -from tianshou.policy.base import TLearningRateScheduler -from 
tianshou.policy.modelfree.pg import PGTrainingStats, Reinforce -from tianshou.utils.net.discrete import Actor, Critic +from tianshou.data import ReplayBuffer, to_torch, to_torch_as +from tianshou.data.types import BatchWithReturnsProtocol, RolloutBatchProtocol +from tianshou.policy.base import OfflineAlgorithm, TLearningRateScheduler +from tianshou.policy.modelfree.pg import ( + DiscountedReturnComputation, + DiscreteActorPolicy, + PGTrainingStats, +) +from tianshou.utils.net.discrete import Critic @dataclass @@ -24,43 +28,15 @@ class DiscreteCRRTrainingStats(PGTrainingStats): TDiscreteCRRTrainingStats = TypeVar("TDiscreteCRRTrainingStats", bound=DiscreteCRRTrainingStats) -class DiscreteCRRPolicy(Reinforce[TDiscreteCRRTrainingStats]): - r"""Implementation of discrete Critic Regularized Regression. arXiv:2006.15134. - - :param actor: the actor network following the rules: - If `self.action_type == "discrete"`: (`s_B` ->`action_values_BA`). - If `self.action_type == "continuous"`: (`s_B` -> `dist_input_BD`). - :param critic: the action-value critic (i.e., Q function) - network. (s -> Q(s, \*)) - :param optim: a torch.optim for optimizing the model. - :param discount_factor: in [0, 1]. - :param str policy_improvement_mode: type of the weight function f. Possible - values: "binary"/"exp"/"all". - :param ratio_upper_bound: when policy_improvement_mode is "exp", the value - of the exp function is upper-bounded by this parameter. - :param beta: when policy_improvement_mode is "exp", this is the denominator - of the exp function. - :param min_q_weight: weight for CQL loss/regularizer. Default to 10. - :param target_update_freq: the target network update frequency (0 if - you do not use the target network). - :param reward_normalization: if True, will normalize the *returns* - by subtracting the running mean and dividing by the running standard deviation. - Can be detrimental to performance! See TODO in process_fn. - :param observation_space: Env's observation space. 
- :param lr_scheduler: if not None, will be called in `policy.update()`. - - .. seealso:: - Please refer to :class:`~tianshou.policy.PGPolicy` for more detailed - explanation. - """ +class DiscreteCRR(OfflineAlgorithm[DiscreteActorPolicy, TDiscreteCRRTrainingStats]): + r"""Implementation of discrete Critic Regularized Regression. arXiv:2006.15134.""" def __init__( self, *, - actor: torch.nn.Module | Actor, + policy: DiscreteActorPolicy, critic: torch.nn.Module | Critic, optim: torch.optim.Optimizer, - action_space: gym.spaces.Discrete, discount_factor: float = 0.99, policy_improvement_mode: Literal["exp", "binary", "all"] = "exp", ratio_upper_bound: float = 20.0, @@ -68,27 +44,43 @@ def __init__( min_q_weight: float = 10.0, target_update_freq: int = 0, reward_normalization: bool = False, - observation_space: gym.Space | None = None, lr_scheduler: TLearningRateScheduler | None = None, ) -> None: + r""" + :param policy: the policy + :param critic: the action-value critic (i.e., Q function) + network. (s -> Q(s, \*)) + :param optim: the optimizer for the policy's actor and the critic networks. + :param discount_factor: in [0, 1]. + :param str policy_improvement_mode: type of the weight function f. Possible + values: "binary"/"exp"/"all". + :param ratio_upper_bound: when policy_improvement_mode is "exp", the value + of the exp function is upper-bounded by this parameter. + :param beta: when policy_improvement_mode is "exp", this is the denominator + of the exp function. + :param min_q_weight: weight for CQL loss/regularizer. Default to 10. + :param target_update_freq: the target network update frequency (0 if + you do not use the target network). + :param reward_normalization: if True, will normalize the *returns* + by subtracting the running mean and dividing by the running standard deviation. + Can be detrimental to performance! + :param lr_scheduler: if not None, will be called in `policy.update()`. 
+ """ super().__init__( - actor=actor, - optim=optim, - action_space=action_space, - dist_fn=lambda x: Categorical(logits=x), + policy=policy, + lr_scheduler=lr_scheduler, + ) + self.optim = optim + self.discounted_return_computation = DiscountedReturnComputation( discount_factor=discount_factor, reward_normalization=reward_normalization, - observation_space=observation_space, - action_scaling=False, - action_bound_method=None, - lr_scheduler=lr_scheduler, ) self.critic = critic self._target = target_update_freq > 0 self._freq = target_update_freq self._iter = 0 if self._target: - self.actor_old = deepcopy(self.actor) + self.actor_old = deepcopy(self.policy.actor) self.actor_old.eval() self.critic_old = deepcopy(self.critic) self.critic_old.eval() @@ -100,8 +92,20 @@ def __init__( self._beta = beta self._min_q_weight = min_q_weight - def sync_weight(self) -> None: - self.actor_old.load_state_dict(self.actor.state_dict()) + def process_fn( + self, + batch: RolloutBatchProtocol, + buffer: ReplayBuffer, + indices: np.ndarray, + ) -> BatchWithReturnsProtocol: + return self.discounted_return_computation.add_discounted_returns( + batch, + buffer, + indices, + ) + + def _update_lagged_network_weights(self) -> None: + self.actor_old.load_state_dict(self.policy.actor.state_dict()) self.critic_old.load_state_dict(self.critic.state_dict()) def _update_with_batch( # type: ignore @@ -111,7 +115,7 @@ def _update_with_batch( # type: ignore **kwargs: Any, ) -> TDiscreteCRRTrainingStats: if self._target and self._iter % self._freq == 0: - self.sync_weight() + self._update_lagged_network_weights() self.optim.zero_grad() q_t = self.critic(batch.obs) act = to_torch(batch.act, dtype=torch.long, device=q_t.device) @@ -124,10 +128,10 @@ def _update_with_batch( # type: ignore rew = to_torch_as(batch.rew, q_t_target) expected_target_q = (q_t_target * target_m.probs).sum(-1, keepdim=True) expected_target_q[batch.done > 0] = 0.0 - target = rew.unsqueeze(1) + self.gamma * expected_target_q + 
target = rew.unsqueeze(1) + self.discounted_return_computation.gamma * expected_target_q critic_loss = 0.5 * F.mse_loss(qa_t, target) # Actor loss - act_target, _ = self.actor(batch.obs) + act_target, _ = self.policy.actor(batch.obs) dist = Categorical(logits=act_target) expected_policy_q = (q_t * dist.probs).sum(-1, keepdim=True) advantage = qa_t - expected_policy_q @@ -146,6 +150,7 @@ def _update_with_batch( # type: ignore self._iter += 1 return DiscreteCRRTrainingStats( # type: ignore[return-value] + # TODO: Type is wrong loss=loss.item(), actor_loss=actor_loss.item(), critic_loss=critic_loss.item(), diff --git a/tianshou/policy/modelfree/a2c.py b/tianshou/policy/modelfree/a2c.py index f210800e5..6ada890c0 100644 --- a/tianshou/policy/modelfree/a2c.py +++ b/tianshou/policy/modelfree/a2c.py @@ -116,7 +116,7 @@ def __init__( gae_lambda: float = 0.95, max_batchsize: int = 256, discount_factor: float = 0.99, - # TODO: rename to return_normalization? + # TODO: This algorithm does not seem to use the reward_normalization parameter. 
reward_normalization: bool = False, lr_scheduler: TLearningRateScheduler | None = None, ) -> None: diff --git a/tianshou/policy/modelfree/pg.py b/tianshou/policy/modelfree/pg.py index 0f29b7039..3d7d3cdf3 100644 --- a/tianshou/policy/modelfree/pg.py +++ b/tianshou/policy/modelfree/pg.py @@ -21,6 +21,7 @@ ObsBatchProtocol, RolloutBatchProtocol, ) +from tianshou.policy import Algorithm from tianshou.policy.base import ( OnPolicyAlgorithm, Policy, @@ -29,7 +30,7 @@ ) from tianshou.utils import RunningMeanStd from tianshou.utils.net.continuous import ActorProb -from tianshou.utils.net.discrete import Actor +from tianshou.utils.net.discrete import Actor, dist_fn_categorical_from_logits # Dimension Naming Convention # B - Batch Size @@ -60,8 +61,8 @@ def __init__( *, actor: torch.nn.Module | ActorProb | Actor, dist_fn: TDistFnDiscrOrCont, - action_space: gym.Space, deterministic_eval: bool = False, + action_space: gym.Space, observation_space: gym.Space | None = None, # TODO: why change the default from the base? action_scaling: bool = True, @@ -77,9 +78,9 @@ def __init__( or a categorical distribution taking `model_output=logits` for discrete action spaces. Note that as user, you are responsible for ensuring that the distribution is compatible with the action space. - :param action_space: env's action space. :param deterministic_eval: if True, will use deterministic action (the dist's mode) instead of stochastic one during evaluation. Does not affect training. + :param action_space: env's action space. :param observation_space: Env's observation space. :param action_scaling: if True, scale the action from [-1, 1] to the range of action_space. Only used if the action_space is continuous. @@ -140,49 +141,62 @@ def forward( return cast(DistBatchProtocol, result) -class Reinforce(OnPolicyAlgorithm[ActorPolicy, TPGTrainingStats], Generic[TPGTrainingStats]): - """Implementation of the REINFORCE (a.k.a. vanilla policy gradient) algorithm. 
+class DiscreteActorPolicy(ActorPolicy): + def __init__( + self, + *, + actor: torch.nn.Module | Actor, + dist_fn: TDistFnDiscrete = dist_fn_categorical_from_logits, + deterministic_eval: bool = False, + action_space: gym.Space, + observation_space: gym.Space | None = None, + ) -> None: + """ + :param actor: the actor network following the rules: (`s_B` -> `dist_input_BD`). + :param dist_fn: distribution class for computing the action. + Maps model_output -> distribution, typically a categorical distribution + taking `model_output=logits`. + :param deterministic_eval: if True, will use deterministic action (the dist's mode) + instead of stochastic one during evaluation. Does not affect training. + :param action_space: the environment's (discrete) action space. + :param observation_space: the environment's observation space. + """ + if not isinstance(action_space, gym.spaces.Discrete): + raise ValueError(f"Action space must be an instance of Discrete; got {action_space}") + super().__init__( + actor=actor, + dist_fn=dist_fn, + deterministic_eval=deterministic_eval, + action_space=action_space, + observation_space=observation_space, + action_scaling=False, + action_bound_method=None, + ) - .. seealso:: - Please refer to :class:`~tianshou.policy.BasePolicy` for more detailed explanation. - """ +TActorPolicy = TypeVar("TActorPolicy", bound=ActorPolicy) + +class DiscountedReturnComputation: def __init__( self, - *, - policy: ActorPolicy, discount_factor: float = 0.99, - # TODO: rename to return_normalization? reward_normalization: bool = False, - optim: torch.optim.Optimizer, - lr_scheduler: TLearningRateScheduler | None = None, - ) -> None: + ): """ - :param policy: the policy - :param optim: optimizer for actor network. - :param discount_factor: in [0, 1]. + :param discount_factor: the future reward discount factor gamma in [0, 1]. 
:param reward_normalization: if True, will normalize the *returns* by subtracting the running mean and dividing by the running standard deviation. - Can be detrimental to performance! See TODO in process_fn. - :param lr_scheduler: if not None, will be called in `policy.update()`. + Can be detrimental to performance! """ - super().__init__( - policy=policy, - lr_scheduler=lr_scheduler, - ) - self.optim = optim - self.ret_rms = RunningMeanStd() - self._eps = 1e-8 assert 0.0 <= discount_factor <= 1.0, "discount factor should be in [0, 1]" self.gamma = discount_factor self.rew_norm = reward_normalization + self.ret_rms = RunningMeanStd() + self.eps = 1e-8 - def process_fn( - self, - batch: RolloutBatchProtocol, - buffer: ReplayBuffer, - indices: np.ndarray, + def add_discounted_returns( + self, batch: RolloutBatchProtocol, buffer: ReplayBuffer, indices: np.ndarray ) -> BatchWithReturnsProtocol: r"""Compute the discounted returns (Monte Carlo estimates) for each transition. @@ -205,7 +219,7 @@ def process_fn( """ v_s_ = np.full(indices.shape, self.ret_rms.mean) # gae_lambda = 1.0 means we use Monte Carlo estimate - unnormalized_returns, _ = self.compute_episodic_return( + unnormalized_returns, _ = Algorithm.compute_episodic_return( batch, buffer, indices, @@ -218,7 +232,7 @@ def process_fn( # This should be addressed soon! if self.rew_norm: batch.returns = (unnormalized_returns - self.ret_rms.mean) / np.sqrt( - self.ret_rms.var + self._eps, + self.ret_rms.var + self.eps, ) self.ret_rms.update(unnormalized_returns) else: @@ -226,6 +240,71 @@ def process_fn( batch: BatchWithReturnsProtocol return batch + +class Reinforce(OnPolicyAlgorithm[ActorPolicy, TPGTrainingStats], Generic[TPGTrainingStats]): + """Implementation of the REINFORCE (a.k.a. vanilla policy gradient) algorithm. + + .. seealso:: + + Please refer to :class:`~tianshou.policy.BasePolicy` for more detailed explanation. 
+ """ + + def __init__( + self, + *, + policy: TActorPolicy, + discount_factor: float = 0.99, + reward_normalization: bool = False, + optim: torch.optim.Optimizer, + lr_scheduler: TLearningRateScheduler | None = None, + ) -> None: + """ + :param policy: the policy + :param optim: optimizer for the policy's actor network. + :param discount_factor: in [0, 1]. + :param reward_normalization: if True, will normalize the *returns* + by subtracting the running mean and dividing by the running standard deviation. + Can be detrimental to performance! + :param lr_scheduler: if not None, will be called in `policy.update()`. + """ + super().__init__( + policy=policy, + lr_scheduler=lr_scheduler, + ) + self.discounted_return_computation = DiscountedReturnComputation( + discount_factor=discount_factor, + reward_normalization=reward_normalization, + ) + self.optim = optim + + @property + def gamma(self) -> float: + return self.discounted_return_computation.gamma + + @property + def rew_norm(self) -> bool: + return self.discounted_return_computation.rew_norm + + @property + def ret_rms(self) -> RunningMeanStd: + return self.discounted_return_computation.ret_rms + + @property + def _eps(self) -> float: + return self.discounted_return_computation.eps + + def process_fn( + self, + batch: RolloutBatchProtocol, + buffer: ReplayBuffer, + indices: np.ndarray, + ) -> BatchWithReturnsProtocol: + return self.discounted_return_computation.add_discounted_returns( + batch, + buffer, + indices, + ) + # TODO: why does mypy complain? 
def _update_with_batch( # type: ignore self, diff --git a/tianshou/trainer/base.py b/tianshou/trainer/base.py index 2bd212643..b85536df3 100644 --- a/tianshou/trainer/base.py +++ b/tianshou/trainer/base.py @@ -248,11 +248,11 @@ class OnlineTrainingConfig(TrainingConfig): test_in_train: bool = True """ - Whether to apply a test step within a training step depending on the early stopping criterion + Whether to apply a test step within a training step depending on the early stopping criterion (given by :attr:`stop_fn`) being satisfied based on the data collected within the training step. - Specifically, after each collect step, we check whether the early stopping criterion (:attr:`stop_fn`) + Specifically, after each collect step, we check whether the early stopping criterion (:attr:`stop_fn`) would be satisfied by data we collected (provided that at least one episode was indeed completed, such - that we can evaluate returns, etc.). If the criterion is satisfied, we perform a full test step + that we can evaluate returns, etc.). If the criterion is satisfied, we perform a full test step (collecting :attr:`episode_per_test` episodes in order to evaluate performance), and if the early stopping criterion is also satisfied based on the test data, we stop training early. """ diff --git a/tianshou/utils/net/discrete.py b/tianshou/utils/net/discrete.py index 6ea654929..e8b596f65 100644 --- a/tianshou/utils/net/discrete.py +++ b/tianshou/utils/net/discrete.py @@ -10,6 +10,11 @@ from tianshou.utils.net.common import MLP, BaseActor, Net, TActionShape, get_output_dim +def dist_fn_categorical_from_logits(logits: torch.Tensor) -> torch.distributions.Categorical: + """Default distribution function for categorical actors.""" + return torch.distributions.Categorical(logits=logits) + + # TODO rename to DiscreteActor? class Actor(BaseActor): """Simple actor network for discrete action spaces. 
From 5e5ee4e7f4bc1395c600f7c4b977aa695d494d45 Mon Sep 17 00:00:00 2001 From: Dominik Jain Date: Tue, 11 Mar 2025 01:44:14 +0100 Subject: [PATCH 045/230] v2: Adapt discrete/test_ppo --- test/discrete/test_ppo.py | 50 ++++++++++++++++++++------------------- 1 file changed, 26 insertions(+), 24 deletions(-) diff --git a/test/discrete/test_ppo.py b/test/discrete/test_ppo.py index 491a7c3c8..c21045a5c 100644 --- a/test/discrete/test_ppo.py +++ b/test/discrete/test_ppo.py @@ -5,15 +5,14 @@ import numpy as np import torch import torch.nn as nn -from gymnasium.spaces import Box from torch.utils.tensorboard import SummaryWriter from tianshou.data import Collector, CollectStats, VectorReplayBuffer from tianshou.env import DummyVectorEnv from tianshou.policy import PPO from tianshou.policy.base import Algorithm -from tianshou.policy.modelfree.ppo import PPOTrainingStats -from tianshou.trainer import OnPolicyTrainer +from tianshou.policy.modelfree.pg import DiscreteActorPolicy +from tianshou.trainer import OnPolicyTrainingConfig from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import ActorCritic, DataParallelNet, Net from tianshou.utils.net.discrete import Actor, Critic @@ -96,12 +95,16 @@ def test_ppo(args: argparse.Namespace = get_args()) -> None: torch.nn.init.zeros_(m.bias) optim = torch.optim.Adam(actor_critic.parameters(), lr=args.lr) dist = torch.distributions.Categorical - policy: PPO[PPOTrainingStats] = PPO( + policy = DiscreteActorPolicy( actor=actor, + dist_fn=dist, + action_space=env.action_space, + deterministic_eval=True, + ) + algorithm = PPO( + policy=policy, critic=critic, optim=optim, - dist_fn=dist, - action_scaling=isinstance(env.action_space, Box), discount_factor=args.gamma, max_grad_norm=args.max_grad_norm, eps_clip=args.eps_clip, @@ -111,18 +114,16 @@ def test_ppo(args: argparse.Namespace = get_args()) -> None: reward_normalization=args.rew_norm, dual_clip=args.dual_clip, value_clip=args.value_clip, - 
action_space=env.action_space, - deterministic_eval=True, advantage_normalization=args.norm_adv, recompute_advantage=args.recompute_adv, ) # collector train_collector = Collector[CollectStats]( - policy, + algorithm, train_envs, VectorReplayBuffer(args.buffer_size, len(train_envs)), ) - test_collector = Collector[CollectStats](policy, test_envs) + test_collector = Collector[CollectStats](algorithm, test_envs) # log log_path = os.path.join(args.logdir, args.task, "ppo") writer = SummaryWriter(log_path) @@ -135,18 +136,19 @@ def stop_fn(mean_rewards: float) -> bool: return mean_rewards >= args.reward_threshold # trainer - result = OnPolicyTrainer( - policy=policy, - train_collector=train_collector, - test_collector=test_collector, - max_epoch=args.epoch, - step_per_epoch=args.step_per_epoch, - repeat_per_collect=args.repeat_per_collect, - episode_per_test=args.test_num, - batch_size=args.batch_size, - step_per_collect=args.step_per_collect, - stop_fn=stop_fn, - save_best_fn=save_best_fn, - logger=logger, - ).run() + result = algorithm.run_training( + OnPolicyTrainingConfig( + train_collector=train_collector, + test_collector=test_collector, + max_epoch=args.epoch, + step_per_epoch=args.step_per_epoch, + repeat_per_collect=args.repeat_per_collect, + episode_per_test=args.test_num, + batch_size=args.batch_size, + step_per_collect=args.step_per_collect, + stop_fn=stop_fn, + save_best_fn=save_best_fn, + logger=logger, + ) + ) assert stop_fn(result.best_reward) From 3fb465de34263733c87ebf2fc0a3ed349c6fec48 Mon Sep 17 00:00:00 2001 From: Dominik Jain Date: Tue, 11 Mar 2025 02:00:41 +0100 Subject: [PATCH 046/230] v2: Adapt RandomActionPolicy --- tianshou/highlevel/algorithm.py | 6 ------ tianshou/policy/base.py | 11 +---------- 2 files changed, 1 insertion(+), 16 deletions(-) diff --git a/tianshou/highlevel/algorithm.py b/tianshou/highlevel/algorithm.py index 84793e712..b52fb0208 100644 --- a/tianshou/highlevel/algorithm.py +++ b/tianshou/highlevel/algorithm.py @@ -64,7 
+64,6 @@ OffPolicyAlgorithm, OnPolicyAlgorithm, Policy, - RandomActionPolicy, ) from tianshou.policy.modelfree.ddpg import DDPGPolicy from tianshou.policy.modelfree.dqn import DQNPolicy @@ -274,11 +273,6 @@ def create_trainer( ) -class RandomActionAlgorithmFactory(OnPolicyAlgorithmFactory): - def _create_algorithm(self, envs: Environments, device: TDevice) -> RandomActionPolicy: - return RandomActionPolicy(envs.get_action_space()) - - class ReinforceAlgorithmFactory(OnPolicyAlgorithmFactory): def __init__( self, diff --git a/tianshou/policy/base.py b/tianshou/policy/base.py index bef0fec22..ddf8f8524 100644 --- a/tianshou/policy/base.py +++ b/tianshou/policy/base.py @@ -814,8 +814,7 @@ def create_trainer(self, config: "OfflineTrainingConfig") -> "OfflineTrainer": return OfflineTrainer(self, config) -# TODO must become Policy not Algorithm -class RandomActionPolicy(Algorithm): +class RandomActionPolicy(Policy): def __init__( self, action_space: gym.Space, @@ -836,14 +835,6 @@ def forward( act, next_state = self.actor.compute_action_batch(batch.obs), state return cast(ActStateBatchProtocol, Batch(act=act, state=next_state)) - def _update_with_batch( - self, - batch: RolloutBatchProtocol, - *args: Any, - **kwargs: Any, - ) -> TrainingStats: - return TrainingStats() - # TODO: rename? 
See docstring @njit From 5995429b28c66dc5ace3455038d93418773680d0 Mon Sep 17 00:00:00 2001 From: Dominik Jain Date: Tue, 11 Mar 2025 22:45:26 +0100 Subject: [PATCH 047/230] v2: Adapt GAIL and test_gail --- examples/inverse/irl_gail.py | 4 +- test/offline/test_gail.py | 56 +++++++++------- tianshou/policy/__init__.py | 4 +- tianshou/policy/imitation/gail.py | 105 ++++++++++++------------------ 4 files changed, 75 insertions(+), 94 deletions(-) diff --git a/examples/inverse/irl_gail.py b/examples/inverse/irl_gail.py index 752624504..366deee89 100644 --- a/examples/inverse/irl_gail.py +++ b/examples/inverse/irl_gail.py @@ -24,7 +24,7 @@ ) from tianshou.data.types import RolloutBatchProtocol from tianshou.env import SubprocVectorEnv, VectorEnvNormObs -from tianshou.policy import GAILPolicy +from tianshou.policy import GAIL from tianshou.policy.base import Algorithm from tianshou.trainer import OnPolicyTrainer from tianshou.utils import TensorboardLogger @@ -205,7 +205,7 @@ def dist(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution: ) print("dataset loaded") - policy: GAILPolicy = GAILPolicy( + policy: GAIL = GAIL( actor=actor, critic=critic, optim=optim, diff --git a/test/offline/test_gail.py b/test/offline/test_gail.py index 298cc91f4..c83c22aaf 100644 --- a/test/offline/test_gail.py +++ b/test/offline/test_gail.py @@ -11,8 +11,9 @@ from tianshou.data import Collector, CollectStats, VectorReplayBuffer from tianshou.env import DummyVectorEnv -from tianshou.policy import Algorithm, GAILPolicy -from tianshou.trainer import OnPolicyTrainer +from tianshou.policy import GAIL, Algorithm +from tianshou.policy.modelfree.pg import ActorPolicy +from tianshou.trainer import OnPolicyTrainingConfig from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import ActorCritic, Net from tianshou.utils.net.continuous import ActorProb, Critic @@ -137,11 +138,15 @@ def dist(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution: loc, scale = loc_scale 
return Independent(Normal(loc, scale), 1) - policy: GAILPolicy = GAILPolicy( + policy = ActorPolicy( actor=actor, + dist_fn=dist, + action_space=env.action_space, + ) + algorithm = GAIL( + policy=policy, critic=critic, optim=optim, - dist_fn=dist, expert_buffer=buffer, disc_net=disc_net, disc_optim=disc_optim, @@ -157,15 +162,14 @@ def dist(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution: dual_clip=args.dual_clip, value_clip=args.value_clip, gae_lambda=args.gae_lambda, - action_space=env.action_space, ) # collector train_collector = Collector[CollectStats]( - policy, + algorithm, train_envs, VectorReplayBuffer(args.buffer_size, len(train_envs)), ) - test_collector = Collector[CollectStats](policy, test_envs) + test_collector = Collector[CollectStats](algorithm, test_envs) # log log_path = os.path.join(args.logdir, args.task, "gail") writer = SummaryWriter(log_path) @@ -184,7 +188,7 @@ def save_checkpoint_fn(epoch: int, env_step: int, gradient_step: int) -> str: # ckpt_path = os.path.join(log_path, f"checkpoint_{epoch}.pth") torch.save( { - "model": policy.state_dict(), + "model": algorithm.state_dict(), "optim": optim.state_dict(), }, ckpt_path, @@ -197,27 +201,29 @@ def save_checkpoint_fn(epoch: int, env_step: int, gradient_step: int) -> str: ckpt_path = os.path.join(log_path, "checkpoint.pth") if os.path.exists(ckpt_path): checkpoint = torch.load(ckpt_path, map_location=args.device) - policy.load_state_dict(checkpoint["model"]) + algorithm.load_state_dict(checkpoint["model"]) optim.load_state_dict(checkpoint["optim"]) print("Successfully restore policy and optim.") else: print("Fail to restore policy and optim.") # trainer - result = OnPolicyTrainer( - policy=policy, - train_collector=train_collector, - test_collector=test_collector, - max_epoch=args.epoch, - step_per_epoch=args.step_per_epoch, - repeat_per_collect=args.repeat_per_collect, - episode_per_test=args.test_num, - batch_size=args.batch_size, - episode_per_collect=args.episode_per_collect, 
- stop_fn=stop_fn, - save_best_fn=save_best_fn, - logger=logger, - resume_from_log=args.resume, - save_checkpoint_fn=save_checkpoint_fn, - ).run() + result = algorithm.run_training( + OnPolicyTrainingConfig( + train_collector=train_collector, + test_collector=test_collector, + max_epoch=args.epoch, + step_per_epoch=args.step_per_epoch, + repeat_per_collect=args.repeat_per_collect, + episode_per_test=args.test_num, + batch_size=args.batch_size, + episode_per_collect=args.episode_per_collect, + step_per_collect=None, + stop_fn=stop_fn, + save_best_fn=save_best_fn, + logger=logger, + resume_from_log=args.resume, + save_checkpoint_fn=save_checkpoint_fn, + ) + ) assert stop_fn(result.best_reward) diff --git a/tianshou/policy/__init__.py b/tianshou/policy/__init__.py index aecb82ec7..fbba9b217 100644 --- a/tianshou/policy/__init__.py +++ b/tianshou/policy/__init__.py @@ -28,7 +28,7 @@ from tianshou.policy.imitation.discrete_bcq import DiscreteBCQPolicy from tianshou.policy.imitation.discrete_cql import DiscreteCQLPolicy from tianshou.policy.imitation.discrete_crr import DiscreteCRR -from tianshou.policy.imitation.gail import GAILPolicy +from tianshou.policy.imitation.gail import GAIL from tianshou.policy.modelbased.psrl import PSRLPolicy from tianshou.policy.modelbased.icm import ICMPolicy from tianshou.policy.multiagent.mapolicy import MultiAgentPolicyManager @@ -60,7 +60,7 @@ "DiscreteBCQPolicy", "DiscreteCQLPolicy", "DiscreteCRR", - "GAILPolicy", + "GAIL", "PSRLPolicy", "ICMPolicy", "MultiAgentPolicyManager", diff --git a/tianshou/policy/imitation/gail.py b/tianshou/policy/imitation/gail.py index 49a80084e..4b0846719 100644 --- a/tianshou/policy/imitation/gail.py +++ b/tianshou/policy/imitation/gail.py @@ -1,7 +1,6 @@ from dataclasses import dataclass -from typing import Any, Literal, TypeVar +from typing import Any, TypeVar -import gymnasium as gym import numpy as np import torch import torch.nn.functional as F @@ -15,10 +14,9 @@ from tianshou.data.types import 
LogpOldProtocol, RolloutBatchProtocol from tianshou.policy import PPO from tianshou.policy.base import TLearningRateScheduler -from tianshou.policy.modelfree.pg import TDistFnDiscrOrCont +from tianshou.policy.modelfree.pg import ActorPolicy from tianshou.policy.modelfree.ppo import PPOTrainingStats -from tianshou.utils.net.continuous import ActorProb, Critic -from tianshou.utils.net.discrete import Actor as DiscreteActor +from tianshou.utils.net.continuous import Critic from tianshou.utils.net.discrete import Critic as DiscreteCritic @@ -32,59 +30,15 @@ class GailTrainingStats(PPOTrainingStats): TGailTrainingStats = TypeVar("TGailTrainingStats", bound=GailTrainingStats) -class GAILPolicy(PPO[TGailTrainingStats]): - r"""Implementation of Generative Adversarial Imitation Learning. arXiv:1606.03476. - - :param actor: the actor network following the rules: - If `self.action_type == "discrete"`: (`s_B` ->`action_values_BA`). - If `self.action_type == "continuous"`: (`s_B` -> `dist_input_BD`). - :param critic: the critic network. (s -> V(s)) - :param optim: the optimizer for actor and critic network. - :param dist_fn: distribution class for computing the action. - :param action_space: env's action space - :param expert_buffer: the replay buffer containing expert experience. - :param disc_net: the discriminator network with input dim equals - state dim plus action dim and output dim equals 1. - :param disc_optim: the optimizer for the discriminator network. - :param disc_update_num: the number of discriminator grad steps per model grad step. - :param eps_clip: :math:`\epsilon` in :math:`L_{CLIP}` in the original - paper. - :param dual_clip: a parameter c mentioned in arXiv:1912.09729 Equ. 5, - where c > 1 is a constant indicating the lower bound. Set to None - to disable dual-clip PPO. - :param value_clip: a parameter mentioned in arXiv:1811.02553v3 Sec. 4.1. - :param advantage_normalization: whether to do per mini-batch advantage - normalization. 
- :param recompute_advantage: whether to recompute advantage every update - repeat according to https://arxiv.org/pdf/2006.05990.pdf Sec. 3.5. - :param vf_coef: weight for value loss. - :param ent_coef: weight for entropy loss. - :param max_grad_norm: clipping gradients in back propagation. - :param gae_lambda: in [0, 1], param for Generalized Advantage Estimation. - :param max_batchsize: the maximum size of the batch when computing GAE. - :param discount_factor: in [0, 1]. - :param reward_normalization: normalize estimated values to have std close to 1. - :param deterministic_eval: if True, use deterministic evaluation. - :param observation_space: the space of the observation. - :param action_scaling: if True, scale the action from [-1, 1] to the range of - action_space. Only used if the action_space is continuous. - :param action_bound_method: method to bound action to range [-1, 1]. - :param lr_scheduler: if not None, will be called in `policy.update()`. - - .. seealso:: - - Please refer to :class:`~tianshou.policy.PPOPolicy` for more detailed - explanation. - """ +class GAIL(PPO[TGailTrainingStats]): + r"""Implementation of Generative Adversarial Imitation Learning. arXiv:1606.03476.""" def __init__( self, *, - actor: torch.nn.Module | ActorProb | DiscreteActor, + policy: ActorPolicy, critic: torch.nn.Module | Critic | DiscreteCritic, optim: torch.optim.Optimizer, - dist_fn: TDistFnDiscrOrCont, - action_space: gym.Space, expert_buffer: ReplayBuffer, disc_net: torch.nn.Module, disc_optim: torch.optim.Optimizer, @@ -102,18 +56,40 @@ def __init__( discount_factor: float = 0.99, # TODO: rename to return_normalization? reward_normalization: bool = False, - deterministic_eval: bool = False, - observation_space: gym.Space | None = None, - action_scaling: bool = True, - action_bound_method: Literal["clip", "tanh"] | None = "clip", lr_scheduler: TLearningRateScheduler | None = None, ) -> None: + r""" + :param policy: the policy. + :param critic: the critic network. 
(s -> V(s)) + :param optim: the optimizer for actor and critic networks. + :param expert_buffer: the replay buffer containing expert experience. + :param disc_net: the discriminator network with input dim equals + state dim plus action dim and output dim equals 1. + :param disc_optim: the optimizer for the discriminator network. + :param disc_update_num: the number of discriminator grad steps per model grad step. + :param eps_clip: :math:`\epsilon` in :math:`L_{CLIP}` in the original + paper. + :param dual_clip: a parameter c mentioned in arXiv:1912.09729 Equ. 5, + where c > 1 is a constant indicating the lower bound. Set to None + to disable dual-clip PPO. + :param value_clip: a parameter mentioned in arXiv:1811.02553v3 Sec. 4.1. + :param advantage_normalization: whether to do per mini-batch advantage + normalization. + :param recompute_advantage: whether to recompute advantage every update + repeat according to https://arxiv.org/pdf/2006.05990.pdf Sec. 3.5. + :param vf_coef: weight for value loss. + :param ent_coef: weight for entropy loss. + :param max_grad_norm: clipping gradients in back propagation. + :param gae_lambda: in [0, 1], param for Generalized Advantage Estimation. + :param max_batchsize: the maximum size of the batch when computing GAE. + :param discount_factor: in [0, 1]. + :param reward_normalization: normalize estimated values to have std close to 1. + :param lr_scheduler: if not None, will be called in `policy.update()`. 
+ """ super().__init__( - actor=actor, + policy=policy, critic=critic, optim=optim, - dist_fn=dist_fn, - action_space=action_space, eps_clip=eps_clip, dual_clip=dual_clip, value_clip=value_clip, @@ -126,17 +102,16 @@ def __init__( max_batchsize=max_batchsize, discount_factor=discount_factor, reward_normalization=reward_normalization, - deterministic_eval=deterministic_eval, - observation_space=observation_space, - action_scaling=action_scaling, - action_bound_method=action_bound_method, lr_scheduler=lr_scheduler, ) self.disc_net = disc_net self.disc_optim = disc_optim self.disc_update_num = disc_update_num self.expert_buffer = expert_buffer - self.action_dim = actor.output_dim + # TODO: This violates the type requirement; nn.Module does not necessarily have output_dim! + # Use IntermediateModule or perhaps a class more general than BaseActor which defines + # only the output dimension? + self.action_dim = self.policy.actor.output_dim def process_fn( self, @@ -184,7 +159,7 @@ def _update_with_batch( # type: ignore acc_pis.append((logits_pi < 0).float().mean().item()) acc_exps.append((logits_exp > 0).float().mean().item()) # update policy - ppo_loss_stat = super().learn(batch, batch_size, repeat, **kwargs) + ppo_loss_stat = super()._update_with_batch(batch, batch_size, repeat, **kwargs) disc_losses_summary = SequenceSummaryStats.from_sequence(losses) acc_pi_summary = SequenceSummaryStats.from_sequence(acc_pis) From 09d8cf25bba49358b17807bf3dbf781f15d4a0c0 Mon Sep 17 00:00:00 2001 From: Dominik Jain Date: Tue, 11 Mar 2025 23:10:34 +0100 Subject: [PATCH 048/230] v2: Adapt DiscreteBCQ and test_bcq * Fix hierarchy: DiscreteBCQ (offline) was derived from DQN (off-policy) --- examples/offline/atari_bcq.py | 4 +- test/offline/test_discrete_bcq.py | 49 +++--- tianshou/policy/__init__.py | 4 +- tianshou/policy/imitation/discrete_bcq.py | 200 +++++++++++++--------- 4 files changed, 154 insertions(+), 103 deletions(-) diff --git a/examples/offline/atari_bcq.py 
b/examples/offline/atari_bcq.py index 310087427..091b2f7c9 100644 --- a/examples/offline/atari_bcq.py +++ b/examples/offline/atari_bcq.py @@ -16,7 +16,7 @@ from examples.offline.utils import load_buffer from tianshou.data import Collector, CollectStats, VectorReplayBuffer from tianshou.highlevel.logger import LoggerFactoryDefault -from tianshou.policy import DiscreteBCQPolicy +from tianshou.policy import DiscreteBCQ from tianshou.policy.base import Algorithm from tianshou.trainer import OfflineTrainer from tianshou.utils.net.common import ActorCritic @@ -121,7 +121,7 @@ def test_discrete_bcq(args: argparse.Namespace = get_args()) -> None: actor_critic = ActorCritic(policy_net, imitation_net) optim = torch.optim.Adam(actor_critic.parameters(), lr=args.lr) # define policy - policy: DiscreteBCQPolicy = DiscreteBCQPolicy( + policy: DiscreteBCQ = DiscreteBCQ( model=policy_net, imitator=imitation_net, optim=optim, diff --git a/test/offline/test_discrete_bcq.py b/test/offline/test_discrete_bcq.py index 1b3cd701a..1dd61ecb9 100644 --- a/test/offline/test_discrete_bcq.py +++ b/test/offline/test_discrete_bcq.py @@ -15,8 +15,9 @@ VectorReplayBuffer, ) from tianshou.env import DummyVectorEnv -from tianshou.policy import Algorithm, DiscreteBCQPolicy -from tianshou.trainer import OfflineTrainer +from tianshou.policy import Algorithm, DiscreteBCQ +from tianshou.policy.imitation.discrete_bcq import DiscreteBCQPolicy +from tianshou.trainer import OfflineTrainingConfig from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import ActorCritic, Net from tianshou.utils.net.discrete import Actor @@ -88,16 +89,19 @@ def test_discrete_bcq(args: argparse.Namespace = get_args()) -> None: actor_critic = ActorCritic(policy_net, imitation_net) optim = torch.optim.Adam(actor_critic.parameters(), lr=args.lr) - policy: DiscreteBCQPolicy = DiscreteBCQPolicy( + policy = DiscreteBCQPolicy( model=policy_net, imitator=imitation_net, - optim=optim, action_space=env.action_space, + 
unlikely_action_threshold=args.unlikely_action_threshold, + ) + algorithm: DiscreteBCQ = DiscreteBCQ( + policy=policy, + optim=optim, discount_factor=args.gamma, estimation_step=args.n_step, target_update_freq=args.target_update_freq, eval_eps=args.eps_test, - unlikely_action_threshold=args.unlikely_action_threshold, imitation_logits_penalty=args.imitation_logits_penalty, ) # buffer @@ -112,7 +116,7 @@ def test_discrete_bcq(args: argparse.Namespace = get_args()) -> None: buffer = gather_data() # collector - test_collector = Collector[CollectStats](policy, test_envs, exploration_noise=True) + test_collector = Collector[CollectStats](algorithm, test_envs, exploration_noise=True) log_path = os.path.join(args.logdir, args.task, "discrete_bcq") writer = SummaryWriter(log_path) @@ -131,7 +135,7 @@ def save_checkpoint_fn(epoch: int, env_step: int, gradient_step: int) -> str: # ckpt_path = os.path.join(log_path, f"checkpoint_{epoch}.pth") torch.save( { - "model": policy.state_dict(), + "model": algorithm.state_dict(), "optim": optim.state_dict(), }, ckpt_path, @@ -144,26 +148,27 @@ def save_checkpoint_fn(epoch: int, env_step: int, gradient_step: int) -> str: ckpt_path = os.path.join(log_path, "checkpoint.pth") if os.path.exists(ckpt_path): checkpoint = torch.load(ckpt_path, map_location=args.device) - policy.load_state_dict(checkpoint["model"]) + algorithm.load_state_dict(checkpoint["model"]) optim.load_state_dict(checkpoint["optim"]) print("Successfully restore policy and optim.") else: print("Fail to restore policy and optim.") - result = OfflineTrainer( - policy=policy, - buffer=buffer, - test_collector=test_collector, - max_epoch=args.epoch, - step_per_epoch=args.update_per_epoch, - episode_per_test=args.test_num, - batch_size=args.batch_size, - stop_fn=stop_fn, - save_best_fn=save_best_fn, - logger=logger, - resume_from_log=args.resume, - save_checkpoint_fn=save_checkpoint_fn, - ).run() + result = algorithm.run_training( + OfflineTrainingConfig( + buffer=buffer, + 
test_collector=test_collector, + max_epoch=args.epoch, + step_per_epoch=args.update_per_epoch, + episode_per_test=args.test_num, + batch_size=args.batch_size, + stop_fn=stop_fn, + save_best_fn=save_best_fn, + logger=logger, + resume_from_log=args.resume, + save_checkpoint_fn=save_checkpoint_fn, + ) + ) assert stop_fn(result.best_reward) diff --git a/tianshou/policy/__init__.py b/tianshou/policy/__init__.py index fbba9b217..a3ee7ef9e 100644 --- a/tianshou/policy/__init__.py +++ b/tianshou/policy/__init__.py @@ -25,7 +25,7 @@ from tianshou.policy.imitation.bcq import BCQ from tianshou.policy.imitation.cql import CQLPolicy from tianshou.policy.imitation.td3_bc import TD3BCPolicy -from tianshou.policy.imitation.discrete_bcq import DiscreteBCQPolicy +from tianshou.policy.imitation.discrete_bcq import DiscreteBCQ from tianshou.policy.imitation.discrete_cql import DiscreteCQLPolicy from tianshou.policy.imitation.discrete_crr import DiscreteCRR from tianshou.policy.imitation.gail import GAIL @@ -57,7 +57,7 @@ "BCQ", "CQLPolicy", "TD3BCPolicy", - "DiscreteBCQPolicy", + "DiscreteBCQ", "DiscreteCQLPolicy", "DiscreteCRR", "GAIL", diff --git a/tianshou/policy/imitation/discrete_bcq.py b/tianshou/policy/imitation/discrete_bcq.py index 294c4cd9e..c8ece09d1 100644 --- a/tianshou/policy/imitation/discrete_bcq.py +++ b/tianshou/policy/imitation/discrete_bcq.py @@ -1,4 +1,5 @@ import math +from copy import deepcopy from dataclasses import dataclass from typing import Any, Generic, Self, TypeVar, cast @@ -9,12 +10,12 @@ from tianshou.data import Batch, ReplayBuffer, to_torch from tianshou.data.types import ( + BatchWithReturnsProtocol, ImitationBatchProtocol, ObsBatchProtocol, RolloutBatchProtocol, ) -from tianshou.policy import DeepQLearning -from tianshou.policy.base import TLearningRateScheduler +from tianshou.policy.base import OfflineAlgorithm, Policy, TLearningRateScheduler from tianshou.policy.modelfree.dqn import DQNTrainingStats float_info = torch.finfo(torch.float32) @@ 
-31,75 +32,38 @@ class DiscreteBCQTrainingStats(DQNTrainingStats): TDiscreteBCQTrainingStats = TypeVar("TDiscreteBCQTrainingStats", bound=DiscreteBCQTrainingStats) -class DiscreteBCQPolicy(DeepQLearning, Generic[TDiscreteBCQTrainingStats]): - """Implementation of discrete BCQ algorithm. arXiv:1910.01708. - - :param model: a model following the rules (s_B -> action_values_BA) - :param imitator: a model following the rules in - :class:`~tianshou.policy.BasePolicy`. (s -> imitation_logits) - :param optim: a torch.optim for optimizing the model. - :param discount_factor: in [0, 1]. - :param estimation_step: the number of steps to look ahead - :param target_update_freq: the target network update frequency. - :param eval_eps: the epsilon-greedy noise added in evaluation. - :param unlikely_action_threshold: the threshold (tau) for unlikely - actions, as shown in Equ. (17) in the paper. - :param imitation_logits_penalty: regularization weight for imitation - logits. - :param estimation_step: the number of steps to look ahead. - :param target_update_freq: the target network update frequency (0 if - you do not use the target network). - :param reward_normalization: normalize the **returns** to Normal(0, 1). - TODO: rename to return_normalization? - :param is_double: use double dqn. - :param clip_loss_grad: clip the gradient of the loss in accordance - with nature14236; this amounts to using the Huber loss instead of - the MSE loss. - :param observation_space: Env's observation space. - :param lr_scheduler: if not None, will be called in `policy.update()`. - - .. seealso:: - - Please refer to :class:`~tianshou.policy.BasePolicy` for more detailed - explanation. 
- """ - +class DiscreteBCQPolicy(Policy): def __init__( self, *, model: torch.nn.Module, imitator: torch.nn.Module, - optim: torch.optim.Optimizer, - action_space: gym.spaces.Discrete, - discount_factor: float = 0.99, - estimation_step: int = 1, target_update_freq: int = 8000, - eval_eps: float = 1e-3, unlikely_action_threshold: float = 0.3, - imitation_logits_penalty: float = 1e-2, - reward_normalization: bool = False, - is_double: bool = True, - clip_loss_grad: bool = False, + action_space: gym.spaces.Discrete, observation_space: gym.Space | None = None, - lr_scheduler: TLearningRateScheduler | None = None, ) -> None: + """ + :param model: a model following the rules (s_B -> action_values_BA) + :param imitator: a model following the rules in + :class:`~tianshou.policy.BasePolicy`. (s -> imitation_logits) + :param target_update_freq: the target network update frequency. + :param unlikely_action_threshold: the threshold (tau) for unlikely + actions, as shown in Equ. (17) in the paper. + :param target_update_freq: the target network update frequency (0 if + you do not use the target network). + :param action_space: the environment's action space. + :param observation_space: the environment's observation space. + """ super().__init__( - model=model, - optim=optim, action_space=action_space, - discount_factor=discount_factor, - estimation_step=estimation_step, - target_update_freq=target_update_freq, - reward_normalization=reward_normalization, - is_double=is_double, - clip_loss_grad=clip_loss_grad, observation_space=observation_space, - lr_scheduler=lr_scheduler, ) + self.model = model + self.imitator = imitator assert ( target_update_freq > 0 ), f"BCQ needs target_update_freq>0 but got: {target_update_freq}." 
- self.imitator = imitator assert ( 0.0 <= unlikely_action_threshold < 1.0 ), f"unlikely_action_threshold should be in [0, 1) but got: {unlikely_action_threshold}" @@ -107,23 +71,7 @@ def __init__( self._log_tau = math.log(unlikely_action_threshold) else: self._log_tau = -np.inf - assert 0.0 <= eval_eps < 1.0 - self.eps = eval_eps - self._weight_reg = imitation_logits_penalty - - def train(self, mode: bool = True) -> Self: - self.training = mode - self.model.train(mode) - self.imitator.train(mode) - return self - - def _target_q(self, buffer: ReplayBuffer, indices: np.ndarray) -> torch.Tensor: - batch = buffer[indices] # batch.obs_next: s_{t+n} - next_obs_batch = Batch(obs=batch.obs_next, info=[None] * len(batch)) - # target_Q = Q_old(s_, argmax(Q_new(s_, *))) - act = self(next_obs_batch).act - target_q, _ = self.model_old(batch.obs_next) - return target_q[np.arange(len(act)), act] + self.max_action_num: int | None = None def forward( # type: ignore self, @@ -131,9 +79,6 @@ def forward( # type: ignore state: dict | Batch | np.ndarray | None = None, **kwargs: Any, ) -> ImitationBatchProtocol: - # TODO: Liskov substitution principle is violated here, the superclass - # produces a batch with the field logits, but this one doesn't. - # Should be fixed in the future! q_value, state = self.model(batch.obs, state=state, info=batch.info) if self.max_action_num is None: self.max_action_num = q_value.shape[1] @@ -147,6 +92,107 @@ def forward( # type: ignore result = Batch(act=act, state=state, q_value=q_value, imitation_logits=imitation_logits) return cast(ImitationBatchProtocol, result) + +class DiscreteBCQ( + OfflineAlgorithm[DiscreteBCQPolicy, TDiscreteBCQTrainingStats], + Generic[TDiscreteBCQTrainingStats], +): + """Implementation of the discrete batch-constrained deep Q-learning (BCQ) algorithm. 
arXiv:1910.01708.""" + + def __init__( + self, + *, + policy: DiscreteBCQPolicy, + optim: torch.optim.Optimizer, + discount_factor: float = 0.99, + estimation_step: int = 1, + target_update_freq: int = 8000, + eval_eps: float = 1e-3, + imitation_logits_penalty: float = 1e-2, + reward_normalization: bool = False, + is_double: bool = True, + clip_loss_grad: bool = False, + lr_scheduler: TLearningRateScheduler | None = None, + ) -> None: + """ + :param policy: the policy + :param optim: a torch.optim for optimizing the model. + :param discount_factor: in [0, 1]. + :param estimation_step: the number of steps to look ahead + :param target_update_freq: the target network update frequency. + :param eval_eps: the epsilon-greedy noise added in evaluation. + :param imitation_logits_penalty: regularization weight for imitation + logits. + :param estimation_step: the number of steps to look ahead. + :param target_update_freq: the target network update frequency (0 if + you do not use the target network). + :param reward_normalization: normalize the **returns** to Normal(0, 1). + TODO: rename to return_normalization? + :param is_double: use double dqn. + :param clip_loss_grad: clip the gradient of the loss in accordance + with nature14236; this amounts to using the Huber loss instead of + the MSE loss. + :param lr_scheduler: if not None, will be called in `policy.update()`. 
+ """ + super().__init__( + policy=policy, + lr_scheduler=lr_scheduler, + ) + self.optim = optim + assert ( + 0.0 <= discount_factor <= 1.0 + ), f"discount factor should be in [0, 1] but got: {discount_factor}" + self.gamma = discount_factor + assert ( + estimation_step > 0 + ), f"estimation_step should be greater than 0 but got: {estimation_step}" + self.n_step = estimation_step + self._target = target_update_freq > 0 + self.freq = target_update_freq + self._iter = 0 + if self._target: + self.model_old = deepcopy(self.policy.model) + self.model_old.eval() + self.rew_norm = reward_normalization + self.is_double = is_double + self.clip_loss_grad = clip_loss_grad + assert 0.0 <= eval_eps < 1.0 + self.eps = eval_eps + self._weight_reg = imitation_logits_penalty + + def process_fn( + self, + batch: RolloutBatchProtocol, + buffer: ReplayBuffer, + indices: np.ndarray, + ) -> BatchWithReturnsProtocol: + return self.compute_nstep_return( + batch=batch, + buffer=buffer, + indices=indices, + target_q_fn=self._target_q, + gamma=self.gamma, + n_step=self.n_step, + rew_norm=self.rew_norm, + ) + + def train(self, mode: bool = True) -> Self: + self.training = mode + self.policy.model.train(mode) + self.policy.imitator.train(mode) + return self + + def _target_q(self, buffer: ReplayBuffer, indices: np.ndarray) -> torch.Tensor: + batch = buffer[indices] # batch.obs_next: s_{t+n} + next_obs_batch = Batch(obs=batch.obs_next, info=[None] * len(batch)) + # target_Q = Q_old(s_, argmax(Q_new(s_, *))) + act = self.policy(next_obs_batch).act + target_q, _ = self.model_old(batch.obs_next) + return target_q[np.arange(len(act)), act] + + def _update_lagged_network_weights(self) -> None: + self.model_old.load_state_dict(self.policy.model.state_dict()) + def _update_with_batch( self, batch: RolloutBatchProtocol, @@ -154,11 +200,11 @@ def _update_with_batch( **kwargs: Any, ) -> TDiscreteBCQTrainingStats: if self._iter % self.freq == 0: - self.sync_weight() + self._update_lagged_network_weights() 
self._iter += 1 target_q = batch.returns.flatten() - result = self(batch) + result = self.policy(batch) imitation_logits = result.imitation_logits current_q = result.q_value[np.arange(len(target_q)), batch.act] act = to_torch(batch.act, dtype=torch.long, device=target_q.device) From f47e97a520178b4d82dfb77494c54ae1da27a1df Mon Sep 17 00:00:00 2001 From: Dominik Jain Date: Tue, 11 Mar 2025 23:15:51 +0100 Subject: [PATCH 049/230] v2: Adapt CQL and test_cql * Fix hierarchy: CQL (offline) no longer inherits from sac (off-policy), resolving several violations of the Liskov substitution principle * Alpha: Add factory method from_float_or_instance * Move method process_buffer from Algorithm to OfflineAlgorithm * Add util function torch_util.torch_device --- CHANGELOG.md | 6 +- examples/offline/d4rl_cql.py | 4 +- test/offline/test_cql.py | 54 +++--- tianshou/policy/__init__.py | 4 +- tianshou/policy/base.py | 16 +- tianshou/policy/imitation/cql.py | 222 +++++++++------------- tianshou/policy/modelfree/discrete_sac.py | 5 +- tianshou/policy/modelfree/redq.py | 5 +- tianshou/policy/modelfree/sac.py | 14 +- tianshou/utils/torch_utils.py | 5 + 10 files changed, 150 insertions(+), 185 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index e16f86ac4..e592b9ede 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -66,10 +66,14 @@ making the codebase more consistent while preserving the original functionality. * Introduced a policy base class `ContinuousPolicyWithExplorationNoise` which encapsulates noise generation for continuous action spaces (e.g. relevant to `DDPG`, `SAC` and `REDQ`). - * Fixed issues in the class hierarchy (e.g. 
violations of the Liskov substitution principle): + * Fixed issues in the class hierarchy (particularly critical violations of the Liskov substitution principle): * Introduced base classes (to retain factorization without abusive inheritance): * `ActorCriticOffPolicyAlgorithm` * `ActorDualCriticsOffPolicyAlgorithm` (extends `ActorCriticOffPolicyAlgorithm`) + * `CQL`: + * Inherit directly from `OfflineAlgorithm` instead of `SAC` (off-policy). + * Remove parameter `estimation_step`, which was not actually used (it was only passed it on to its + superclass). * `DiscreteCRR`: Inherit directly from `OfflineAlgorithm` instead of `Reinforce` (on-policy) * `NPG`: Inherit from `AbstractActorCriticWithAdvantage` instead of `A2C` (which is now has the same base class) * `REDQ`: Inherit from `ActorCriticOffPolicyAlgorithm` instead of `DDPG` diff --git a/examples/offline/d4rl_cql.py b/examples/offline/d4rl_cql.py index 23e95fcf9..a52962cba 100644 --- a/examples/offline/d4rl_cql.py +++ b/examples/offline/d4rl_cql.py @@ -13,7 +13,7 @@ from examples.offline.utils import load_buffer_d4rl from tianshou.data import Collector, CollectStats from tianshou.env import SubprocVectorEnv -from tianshou.policy import CQLPolicy +from tianshou.policy import CQL from tianshou.policy.base import Algorithm from tianshou.trainer import OfflineTrainer from tianshou.utils import TensorboardLogger, WandbLogger @@ -284,7 +284,7 @@ def test_cql() -> None: alpha_optim = torch.optim.Adam([log_alpha], lr=args.alpha_lr) args.alpha = (target_entropy, log_alpha, alpha_optim) - policy: CQLPolicy = CQLPolicy( + policy: CQL = CQL( actor=actor, policy_optim=actor_optim, critic=critic, diff --git a/test/offline/test_cql.py b/test/offline/test_cql.py index 8eac8fee3..4df3cae4d 100644 --- a/test/offline/test_cql.py +++ b/test/offline/test_cql.py @@ -2,7 +2,6 @@ import datetime import os import pickle -import pprint from test.offline.gather_pendulum_data import expert_file_name, gather_data import gymnasium as gym @@ 
-12,9 +11,10 @@ from tianshou.data import Collector, CollectStats, VectorReplayBuffer from tianshou.env import DummyVectorEnv -from tianshou.policy import Algorithm, CQLPolicy +from tianshou.policy import CQL, Algorithm from tianshou.policy.imitation.cql import CQLTrainingStats -from tianshou.trainer import OfflineTrainer +from tianshou.policy.modelfree.sac import AutoAlpha, SACPolicy +from tianshou.trainer import OfflineTrainingConfig from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Net from tianshou.utils.net.continuous import ActorProb, Critic @@ -134,17 +134,20 @@ def test_cql(args: argparse.Namespace = get_args()) -> None: target_entropy = -np.prod(args.action_shape) log_alpha = torch.zeros(1, requires_grad=True, device=args.device) alpha_optim = torch.optim.Adam([log_alpha], lr=args.alpha_lr) - args.alpha = (target_entropy, log_alpha, alpha_optim) + args.alpha = AutoAlpha(target_entropy, log_alpha, alpha_optim) - policy: CQLPolicy[CQLTrainingStats] = CQLPolicy( + policy = SACPolicy( actor=actor, - policy_optim=actor_optim, - critic=critic, - critic_optim=critic_optim, # CQL seems to perform better without action scaling # TODO: investigate why action_scaling=False, action_space=env.action_space, + ) + algorithm: CQL[CQLTrainingStats] = CQL( + policy=policy, + policy_optim=actor_optim, + critic=critic, + critic_optim=critic_optim, cql_alpha_lr=args.cql_alpha_lr, cql_weight=args.cql_weight, tau=args.tau, @@ -155,18 +158,17 @@ def test_cql(args: argparse.Namespace = get_args()) -> None: lagrange_threshold=args.lagrange_threshold, min_action=args.min_action, max_action=args.max_action, - device=args.device, ) # load a previous policy if args.resume_path: - policy.load_state_dict(torch.load(args.resume_path, map_location=args.device)) + algorithm.load_state_dict(torch.load(args.resume_path, map_location=args.device)) print("Loaded agent from: ", args.resume_path) # collector # buffer has been gathered # train_collector = 
Collector[CollectStats](policy, train_envs, buffer, exploration_noise=True) - test_collector = Collector[CollectStats](policy, test_envs) + test_collector = Collector[CollectStats](algorithm, test_envs) # log t0 = datetime.datetime.now().strftime("%m%d_%H%M%S") log_file = f'seed_{args.seed}_{t0}-{args.task.replace("-", "_")}_cql' @@ -182,22 +184,18 @@ def stop_fn(mean_rewards: float) -> bool: return mean_rewards >= args.reward_threshold # trainer - trainer = OfflineTrainer( - policy=policy, - buffer=buffer, - test_collector=test_collector, - max_epoch=args.epoch, - step_per_epoch=args.step_per_epoch, - episode_per_test=args.test_num, - batch_size=args.batch_size, - save_best_fn=save_best_fn, - stop_fn=stop_fn, - logger=logger, + result = algorithm.run_training( + OfflineTrainingConfig( + buffer=buffer, + test_collector=test_collector, + max_epoch=args.epoch, + step_per_epoch=args.step_per_epoch, + episode_per_test=args.test_num, + batch_size=args.batch_size, + save_best_fn=save_best_fn, + stop_fn=stop_fn, + logger=logger, + ) ) - for epoch_stat in trainer: - print(f"Epoch: {epoch_stat.epoch}") - pprint.pprint(epoch_stat) - # print(info) - - assert stop_fn(epoch_stat.info_stat.best_reward) + assert stop_fn(result.best_reward) diff --git a/tianshou/policy/__init__.py b/tianshou/policy/__init__.py index a3ee7ef9e..3f70daa81 100644 --- a/tianshou/policy/__init__.py +++ b/tianshou/policy/__init__.py @@ -23,7 +23,7 @@ from tianshou.policy.modelfree.discrete_sac import DiscreteSAC from tianshou.policy.imitation.base import ImitationLearning from tianshou.policy.imitation.bcq import BCQ -from tianshou.policy.imitation.cql import CQLPolicy +from tianshou.policy.imitation.cql import CQL from tianshou.policy.imitation.td3_bc import TD3BCPolicy from tianshou.policy.imitation.discrete_bcq import DiscreteBCQ from tianshou.policy.imitation.discrete_cql import DiscreteCQLPolicy @@ -55,7 +55,7 @@ "DiscreteSAC", "ImitationLearning", "BCQ", - "CQLPolicy", + "CQL", "TD3BCPolicy", 
"DiscreteBCQ", "DiscreteCQLPolicy", diff --git a/tianshou/policy/base.py b/tianshou/policy/base.py index ddf8f8524..07703cb15 100644 --- a/tianshou/policy/base.py +++ b/tianshou/policy/base.py @@ -474,18 +474,6 @@ def _polyak_parameter_update(self, tgt: nn.Module, src: nn.Module, tau: float) - for tgt_param, src_param in zip(tgt.parameters(), src.parameters(), strict=True): tgt_param.data.copy_(tau * src_param.data + (1 - tau) * tgt_param.data) - def process_buffer(self, buffer: TBuffer) -> TBuffer: - """Pre-process the replay buffer, e.g., to add new keys. - - Used in BaseTrainer initialization method, usually used by offline trainers. - - Note: this will only be called once, when the trainer is initialized! - If the buffer is empty by then, there will be nothing to process. - This method is meant to be overridden by policies which will be trained - offline at some stage, e.g., in a pre-training step. - """ - return buffer - def process_fn( self, batch: RolloutBatchProtocol, @@ -808,6 +796,10 @@ class OfflineAlgorithm( Generic[TPolicy, TTrainingStats], ABC, ): + def process_buffer(self, buffer: TBuffer) -> TBuffer: + """Pre-process the replay buffer to prepare for offline learning, e.g. 
to add new keys.""" + return buffer + def create_trainer(self, config: "OfflineTrainingConfig") -> "OfflineTrainer": from tianshou.trainer.base import OfflineTrainer diff --git a/tianshou/policy/imitation/cql.py b/tianshou/policy/imitation/cql.py index 7b57d1fda..374520f1b 100644 --- a/tianshou/policy/imitation/cql.py +++ b/tianshou/policy/imitation/cql.py @@ -1,7 +1,7 @@ +from copy import deepcopy from dataclasses import dataclass -from typing import Any, Literal, Self, TypeVar, cast +from typing import Any, Self, TypeVar, cast -import gymnasium as gym import numpy as np import torch import torch.nn.functional as F @@ -11,12 +11,11 @@ from tianshou.data import Batch, ReplayBuffer, to_torch from tianshou.data.buffer.base import TBuffer from tianshou.data.types import RolloutBatchProtocol -from tianshou.exploration import BaseNoise -from tianshou.policy import SAC -from tianshou.policy.base import TLearningRateScheduler -from tianshou.policy.modelfree.sac import SACTrainingStats +from tianshou.policy.base import OfflineAlgorithm, TLearningRateScheduler +from tianshou.policy.modelfree.sac import Alpha, SACPolicy, SACTrainingStats from tianshou.utils.conversion import to_optional_float -from tianshou.utils.net.continuous import ActorProb +from tianshou.utils.optim import clone_optimizer +from tianshou.utils.torch_utils import torch_device @dataclass(kw_only=True) @@ -30,70 +29,24 @@ class CQLTrainingStats(SACTrainingStats): TCQLTrainingStats = TypeVar("TCQLTrainingStats", bound=CQLTrainingStats) -class CQLPolicy(SAC[TCQLTrainingStats]): - """Implementation of CQL algorithm. arXiv:2006.04779. - - :param actor: the actor network following the rules in - :class:`~tianshou.policy.BasePolicy`. (s -> a) - :param policy_optim: The optimizer for actor network. - :param critic: The first critic network. - :param critic_optim: The optimizer for the first critic network. - :param action_space: Env's action space. - :param critic2: the second critic network. (s, a -> Q(s, a)). 
- If None, use the same network as critic (via deepcopy). - :param critic2_optim: the optimizer for the second critic network. - If None, clone critic_optim to use for critic2.parameters(). - :param cql_alpha_lr: The learning rate of cql_log_alpha. - :param cql_weight: - :param tau: Parameter for soft update of the target network. - :param gamma: Discount factor, in [0, 1]. - :param alpha: Entropy regularization coefficient or a tuple - (target_entropy, log_alpha, alpha_optim) for automatic tuning. - :param temperature: - :param with_lagrange: Whether to use Lagrange. - TODO: extend documentation - what does this mean? - :param lagrange_threshold: The value of tau in CQL(Lagrange). - :param min_action: The minimum value of each dimension of action. - :param max_action: The maximum value of each dimension of action. - :param num_repeat_actions: The number of times the action is repeated when calculating log-sum-exp. - :param alpha_min: Lower bound for clipping cql_alpha. - :param alpha_max: Upper bound for clipping cql_alpha. - :param clip_grad: Clip_grad for updating critic network. - :param calibrated: calibrate Q-values as in CalQL paper `arXiv:2303.05479`. - Useful for offline pre-training followed by online training, - and also was observed to achieve better results than vanilla cql. - :param device: Which device to create this model on. - :param estimation_step: Estimation steps. - :param exploration_noise: Type of exploration noise. - :param deterministic_eval: Flag for deterministic evaluation. - :param action_scaling: Flag for action scaling. - :param action_bound_method: Method for action bounding. Only used if the - action_space is continuous. - :param observation_space: Env's Observation space. - :param lr_scheduler: a learning rate scheduler that adjusts the learning rate in - optimizer in each policy.update(). - - .. seealso:: - - Please refer to :class:`~tianshou.policy.BasePolicy` for more detailed - explanation. 
- """ +# TODO: Perhaps SACPolicy should get a more generic name +class CQL(OfflineAlgorithm[SACPolicy, TCQLTrainingStats]): + """Implementation of the conservative Q-learning (CQL) algorithm. arXiv:2006.04779.""" def __init__( self, *, - actor: ActorProb, + policy: SACPolicy, policy_optim: torch.optim.Optimizer, critic: torch.nn.Module, critic_optim: torch.optim.Optimizer, - action_space: gym.spaces.Box, critic2: torch.nn.Module | None = None, critic2_optim: torch.optim.Optimizer | None = None, cql_alpha_lr: float = 1e-4, cql_weight: float = 1.0, tau: float = 0.005, gamma: float = 0.99, - alpha: float | tuple[float, torch.Tensor, torch.optim.Optimizer] = 0.2, + alpha: float | Alpha = 0.2, temperature: float = 1.0, with_lagrange: bool = True, lagrange_threshold: float = 10.0, @@ -104,37 +57,66 @@ def __init__( alpha_max: float = 1e6, clip_grad: float = 1.0, calibrated: bool = True, - # TODO: why does this one have device? Almost no other policies have it - device: str | torch.device = "cpu", - estimation_step: int = 1, - exploration_noise: BaseNoise | Literal["default"] | None = None, - deterministic_eval: bool = True, - action_scaling: bool = True, - action_bound_method: Literal["clip"] | None = "clip", - observation_space: gym.Space | None = None, + estimation_step: int = 1, # TODO remove lr_scheduler: TLearningRateScheduler | None = None, ) -> None: + """ + :param actor: the actor network following the rules in + :class:`~tianshou.policy.BasePolicy`. (s -> a) + :param policy_optim: The optimizer for actor network. + :param critic: The first critic network. + :param critic_optim: The optimizer for the first critic network. + :param action_space: Env's action space. + :param critic2: the second critic network. (s, a -> Q(s, a)). + If None, use the same network as critic (via deepcopy). + :param critic2_optim: the optimizer for the second critic network. + If None, clone critic_optim to use for critic2.parameters(). 
+ :param cql_alpha_lr: The learning rate of cql_log_alpha. + :param cql_weight: + :param tau: Parameter for soft update of the target network. + :param gamma: Discount factor, in [0, 1]. + :param alpha: the entropy regularization coefficient alpha or an object + which can be used to automatically tune it (e.g. an instance of `AutoAlpha`). + :param temperature: + :param with_lagrange: Whether to use Lagrange. + TODO: extend documentation - what does this mean? + :param lagrange_threshold: The value of tau in CQL(Lagrange). + :param min_action: The minimum value of each dimension of action. + :param max_action: The maximum value of each dimension of action. + :param num_repeat_actions: The number of times the action is repeated when calculating log-sum-exp. + :param alpha_min: Lower bound for clipping cql_alpha. + :param alpha_max: Upper bound for clipping cql_alpha. + :param clip_grad: Clip_grad for updating critic network. + :param calibrated: calibrate Q-values as in CalQL paper `arXiv:2303.05479`. + Useful for offline pre-training followed by online training, + and also was observed to achieve better results than vanilla cql. + :param estimation_step: Estimation steps. + :param lr_scheduler: a learning rate scheduler that adjusts the learning rate in + optimizer in each policy.update(). + """ super().__init__( - actor=actor, - policy_optim=policy_optim, - critic=critic, - critic_optim=critic_optim, - action_space=action_space, - critic2=critic2, - critic2_optim=critic2_optim, - tau=tau, - gamma=gamma, - deterministic_eval=deterministic_eval, - alpha=alpha, - exploration_noise=exploration_noise, - estimation_step=estimation_step, - action_scaling=action_scaling, - action_bound_method=action_bound_method, - observation_space=observation_space, + policy=policy, lr_scheduler=lr_scheduler, ) - # There are _target_entropy, _log_alpha, _alpha_optim in SACPolicy. 
- self.device = device + + device = torch_device(policy) + + self.policy_optim = policy_optim + self.critic = critic + self.critic_optim = critic_optim + self.critic2 = critic2 or deepcopy(critic) + self.critic2_optim = critic2_optim or clone_optimizer( + critic_optim, self.critic2.parameters() + ) + self.critic_old = deepcopy(self.critic) + self.critic2_old = deepcopy(self.critic2) + self.critic_old.eval() + self.critic2_old.eval() + + self.tau = tau + self.gamma = gamma + self.alpha = Alpha.from_float_or_instance(alpha) + self.temperature = temperature self.with_lagrange = with_lagrange self.lagrange_threshold = lagrange_threshold @@ -157,9 +139,9 @@ def __init__( self.calibrated = calibrated def train(self, mode: bool = True) -> Self: - """Set the module in training mode, except for the target network.""" + """Sets the module in training mode, except for the lagged networks.""" self.training = mode - self.actor.train(mode) + self.policy.train(mode) self.critic.train(mode) self.critic2.train(mode) return self @@ -169,34 +151,34 @@ def _update_lagged_network_weights(self) -> None: self._polyak_parameter_update(self.critic_old, self.critic, self.tau) self._polyak_parameter_update(self.critic2_old, self.critic2, self.tau) - def actor_pred(self, obs: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + def _policy_pred(self, obs: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: batch = Batch(obs=obs, info=[None] * len(obs)) - obs_result = self(batch) + obs_result = self.policy(batch) return obs_result.act, obs_result.log_prob - def calc_actor_loss(self, obs: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: - act_pred, log_pi = self.actor_pred(obs) + def _calc_policy_loss(self, obs: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + act_pred, log_pi = self._policy_pred(obs) q1 = self.critic(obs, act_pred) q2 = self.critic2(obs, act_pred) min_Q = torch.min(q1, q2) # self.alpha: float | torch.Tensor - actor_loss = (self.alpha * log_pi - min_Q).mean() + 
actor_loss = (self.alpha.value * log_pi - min_Q).mean() # actor_loss.shape: (), log_pi.shape: (batch_size, 1) return actor_loss, log_pi - def calc_pi_values( + def _calc_pi_values( self, obs_pi: torch.Tensor, obs_to_pred: torch.Tensor, ) -> tuple[torch.Tensor, torch.Tensor]: - act_pred, log_pi = self.actor_pred(obs_pi) + act_pred, log_pi = self._policy_pred(obs_pi) q1 = self.critic(obs_to_pred, act_pred) q2 = self.critic2(obs_to_pred, act_pred) return q1 - log_pi.detach(), q2 - log_pi.detach() - def calc_random_values( + def _calc_random_values( self, obs: torch.Tensor, act: torch.Tensor, @@ -234,51 +216,29 @@ def process_buffer(self, buffer: TBuffer) -> TBuffer: ) return buffer - def process_fn( - self, - batch: RolloutBatchProtocol, - buffer: ReplayBuffer, - indices: np.ndarray, - ) -> RolloutBatchProtocol: - # TODO: mypy rightly complains here b/c the design violates - # Liskov Substitution Principle - # DDPGPolicy.process_fn() results in a batch with returns but - # CQLPolicy.process_fn() doesn't add the returns. - # Should probably be fixed! 
- return batch - def _update_with_batch(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> TCQLTrainingStats: # type: ignore - batch: Batch = to_torch(batch, dtype=torch.float, device=self.device) + device = torch_device(self.policy) + batch: Batch = to_torch(batch, dtype=torch.float, device=device) obs, act, rew, obs_next = batch.obs, batch.act, batch.rew, batch.obs_next batch_size = obs.shape[0] # compute actor loss and update actor - actor_loss, log_pi = self.calc_actor_loss(obs) - self.actor_optim.zero_grad() + actor_loss, log_pi = self._calc_policy_loss(obs) + self.policy_optim.zero_grad() actor_loss.backward() - self.actor_optim.step() - - alpha_loss = None - # compute alpha loss - if self.is_auto_alpha: - log_pi = log_pi + self.target_entropy - alpha_loss = -(self.log_alpha * log_pi.detach()).mean() - self.alpha_optim.zero_grad() - # update log_alpha - alpha_loss.backward() - self.alpha_optim.step() - # update alpha - # TODO: it's probably a bad idea to track both alpha and log_alpha in different fields - self.alpha = self.log_alpha.detach().exp() + self.policy_optim.step() + + entropy = -log_pi.detach() + alpha_loss = self.alpha.update(entropy) # compute target_Q with torch.no_grad(): - act_next, new_log_pi = self.actor_pred(obs_next) + act_next, new_log_pi = self._policy_pred(obs_next) target_Q1 = self.critic_old(obs_next, act_next) target_Q2 = self.critic2_old(obs_next, act_next) - target_Q = torch.min(target_Q1, target_Q2) - self.alpha * new_log_pi + target_Q = torch.min(target_Q1, target_Q2) - self.alpha.value * new_log_pi target_Q = rew + torch.logical_not(batch.done) * self.gamma * target_Q.flatten() target_Q = target_Q.float() @@ -296,7 +256,7 @@ def _update_with_batch(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: random_actions = ( torch.FloatTensor(batch_size * self.num_repeat_actions, act.shape[-1]) .uniform_(-self.min_action, self.max_action) - .to(self.device) + .to(device) ) obs_len = len(obs.shape) @@ -306,10 +266,10 
@@ def _update_with_batch(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: tmp_obs_next = obs_next.unsqueeze(1).repeat(*repeat_size).view(*view_size) # tmp_obs & tmp_obs_next: (batch_size * num_repeat, state_dim) - current_pi_value1, current_pi_value2 = self.calc_pi_values(tmp_obs, tmp_obs) - next_pi_value1, next_pi_value2 = self.calc_pi_values(tmp_obs_next, tmp_obs) + current_pi_value1, current_pi_value2 = self._calc_pi_values(tmp_obs, tmp_obs) + next_pi_value1, next_pi_value2 = self._calc_pi_values(tmp_obs_next, tmp_obs) - random_value1, random_value2 = self.calc_random_values(tmp_obs, random_actions) + random_value1, random_value2 = self._calc_random_values(tmp_obs, random_actions) for value in [ current_pi_value1, @@ -395,7 +355,7 @@ def _update_with_batch(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: actor_loss=to_optional_float(actor_loss), critic1_loss=to_optional_float(critic1_loss), critic2_loss=to_optional_float(critic2_loss), - alpha=to_optional_float(self.alpha), + alpha=to_optional_float(self.alpha.value), alpha_loss=to_optional_float(alpha_loss), cql_alpha_loss=to_optional_float(cql_alpha_loss), cql_alpha=to_optional_float(cql_alpha), diff --git a/tianshou/policy/modelfree/discrete_sac.py b/tianshou/policy/modelfree/discrete_sac.py index 8e808fe22..be634a0fd 100644 --- a/tianshou/policy/modelfree/discrete_sac.py +++ b/tianshou/policy/modelfree/discrete_sac.py @@ -13,7 +13,7 @@ RolloutBatchProtocol, ) from tianshou.policy.base import Policy, TLearningRateScheduler -from tianshou.policy.modelfree.sac import Alpha, FixedAlpha, SACTrainingStats +from tianshou.policy.modelfree.sac import Alpha, SACTrainingStats from tianshou.policy.modelfree.td3 import ActorDualCriticsOffPolicyAlgorithm from tianshou.utils.net.discrete import Critic @@ -121,8 +121,7 @@ def __init__( estimation_step=estimation_step, lr_scheduler=lr_scheduler, ) - self.alpha = FixedAlpha(alpha) if isinstance(alpha, float) else alpha - assert isinstance(self.alpha, Alpha) + 
self.alpha = Alpha.from_float_or_instance(alpha) def _target_q_compute_value( self, obs_batch: Batch, act_batch: DistBatchProtocol diff --git a/tianshou/policy/modelfree/redq.py b/tianshou/policy/modelfree/redq.py index 92891d50a..8d774e99a 100644 --- a/tianshou/policy/modelfree/redq.py +++ b/tianshou/policy/modelfree/redq.py @@ -19,7 +19,7 @@ ContinuousPolicyWithExplorationNoise, DDPGTrainingStats, ) -from tianshou.policy.modelfree.sac import Alpha, FixedAlpha +from tianshou.policy.modelfree.sac import Alpha from tianshou.utils.net.continuous import ActorProb @@ -168,8 +168,7 @@ def __init__( self._last_actor_loss = 0.0 # only for logging purposes - self.alpha = FixedAlpha(alpha) if isinstance(alpha, float) else alpha - assert isinstance(self.alpha, Alpha) + self.alpha = Alpha.from_float_or_instance(alpha) def _target_q_compute_value( self, obs_batch: Batch, act_batch: DistLogProbBatchProtocol diff --git a/tianshou/policy/modelfree/sac.py b/tianshou/policy/modelfree/sac.py index 244f1368f..98508e06a 100644 --- a/tianshou/policy/modelfree/sac.py +++ b/tianshou/policy/modelfree/sac.py @@ -1,6 +1,6 @@ from abc import ABC, abstractmethod from dataclasses import dataclass -from typing import Any, Generic, Literal, TypeVar, cast +from typing import Any, Generic, Literal, TypeVar, Union, cast import gymnasium as gym import numpy as np @@ -119,6 +119,15 @@ def forward( # type: ignore class Alpha(ABC): """Defines the interface for the entropy regularization coefficient alpha.""" + @staticmethod + def from_float_or_instance(alpha: Union[float, "Alpha"]) -> "Alpha": + if isinstance(alpha, float): + return FixedAlpha(alpha) + elif isinstance(alpha, Alpha): + return alpha + else: + raise ValueError(f"Expected float or Alpha instance, but got {alpha=}") + @property @abstractmethod def value(self) -> float: @@ -243,8 +252,7 @@ def __init__( lr_scheduler=lr_scheduler, ) self.deterministic_eval = deterministic_eval - self.alpha = FixedAlpha(alpha) if isinstance(alpha, float) else 
alpha - assert isinstance(self.alpha, Alpha) + self.alpha = Alpha.from_float_or_instance(alpha) self._check_field_validity() def _check_field_validity(self) -> None: diff --git a/tianshou/utils/torch_utils.py b/tianshou/utils/torch_utils.py index 8526c5303..1c001a544 100644 --- a/tianshou/utils/torch_utils.py +++ b/tianshou/utils/torch_utils.py @@ -75,3 +75,8 @@ def create_uniform_action_dist( else: raise ValueError(f"Unsupported action space type: {type(action_space)}") + + +def torch_device(module: torch.nn.Module) -> torch.device: + """Gets the device of a torch module.""" + return next(module.parameters()).device From bf85a5841954be3b09c25f7b56fc67a6692c3620 Mon Sep 17 00:00:00 2001 From: Dominik Jain Date: Wed, 12 Mar 2025 16:12:41 +0100 Subject: [PATCH 050/230] v2: Introduce mixins for the handling of lagged networks (target networks) * New classes LaggedNetworkPolyakUpdateAlgorithmMixin and LaggedNetworkFullUpdateAlgorithmMixin are provided, which fully handle lagged networks (target networks) * A lagged network can now simply be added by calling _add_lagged_network * The train method must no longer be overridden to ensure that the target networks are never set to train mode/remain in eval mode * A method which updates all target networks with their source networks is automatically provided and does not need to be implemented specifically for every algorithm * Adapted classes/algorithms: * base classes: * ActorCriticOffPolicyAlgorithm * ActorDualCriticsOffPolicyAlgorithm * BCQ * CQL * DeepQLearning and sub-classes * DiscreteCRR * DiscreteBCQ * DDPG * TD3 * Fix: In BCQ, the train method implementation omitted the VAE network (train mode was never set!); the automatic mechamisms fix this issue --- tianshou/policy/base.py | 50 +++++++++---- tianshou/policy/imitation/bcq.py | 38 ++++------ tianshou/policy/imitation/cql.py | 31 +++------ tianshou/policy/imitation/discrete_bcq.py | 24 +++---- tianshou/policy/imitation/discrete_cql.py | 2 +- 
tianshou/policy/imitation/discrete_crr.py | 23 +++--- tianshou/policy/modelfree/bdqn.py | 2 +- tianshou/policy/modelfree/c51.py | 2 +- tianshou/policy/modelfree/ddpg.py | 29 ++------ tianshou/policy/modelfree/dqn.py | 25 +++---- tianshou/policy/modelfree/fqf.py | 2 +- tianshou/policy/modelfree/iqn.py | 2 +- tianshou/policy/modelfree/qrdqn.py | 2 +- tianshou/policy/modelfree/td3.py | 23 ++---- tianshou/utils/lagged_network.py | 85 +++++++++++++++++++++++ 15 files changed, 192 insertions(+), 148 deletions(-) create mode 100644 tianshou/utils/lagged_network.py diff --git a/tianshou/policy/base.py b/tianshou/policy/base.py index 07703cb15..2b1d55379 100644 --- a/tianshou/policy/base.py +++ b/tianshou/policy/base.py @@ -25,6 +25,9 @@ RolloutBatchProtocol, ) from tianshou.utils import MultipleLRSchedulers +from tianshou.utils.lagged_network import ( + LaggedNetworkCollection, +) from tianshou.utils.net.common import RandomActor from tianshou.utils.print import DataclassPPrintMixin from tianshou.utils.torch_utils import policy_within_training_step, torch_train_mode @@ -381,6 +384,41 @@ def add_exploration_noise( return act +class LaggedNetworkAlgorithmMixin(ABC): + def __init__(self) -> None: + self._lagged_networks = LaggedNetworkCollection() + + def _add_lagged_network(self, src: torch.nn.Module) -> torch.nn.Module: + """ + Adds a lagged network to the collection, returning the target network, which + is forced to eval mode. The target network is a copy of the source network, + which, however, supports only the forward method (hence the type torch.nn.Module); + attribute access is not supported. 
+ + :param source: the source network whose parameters are to be copied to the target network + :return: the target network, which supports only the forward method and is forced to eval mode + """ + return self._lagged_networks.add_lagged_network(src) + + @abstractmethod + def _update_lagged_network_weights(self) -> None: + pass + + +class LaggedNetworkFullUpdateAlgorithmMixin(LaggedNetworkAlgorithmMixin): + def _update_lagged_network_weights(self) -> None: + self._lagged_networks.full_parameter_update() + + +class LaggedNetworkPolyakUpdateAlgorithmMixin(LaggedNetworkAlgorithmMixin): + def __init__(self, tau: float) -> None: + super().__init__() + self.tau = tau + + def _update_lagged_network_weights(self) -> None: + self._lagged_networks.polyak_parameter_update(self.tau) + + TPolicy = TypeVar("TPolicy", bound=Policy) TTrainingConfig = TypeVar( "TTrainingConfig", @@ -462,18 +500,6 @@ def set_agent_id(self, agent_id: int) -> None: """Set self.agent_id = agent_id, for MARL.""" self.agent_id = agent_id - def _polyak_parameter_update(self, tgt: nn.Module, src: nn.Module, tau: float) -> None: - """Softly updates the parameters of a target network `tgt` with the parameters of a source network `src` - using Polyak averaging: `tau * src + (1 - tau) * tgt`. - - :param tgt: the target network that receives the parameter update - :param src: the source network whose parameters are used for the update - :param tau: the fraction with which to use the source network's parameters, the inverse `1-tau` being - the fraction with which to retain the target network's parameters. 
- """ - for tgt_param, src_param in zip(tgt.parameters(), src.parameters(), strict=True): - tgt_param.data.copy_(tau * src_param.data + (1 - tau) * tgt_param.data) - def process_fn( self, batch: RolloutBatchProtocol, diff --git a/tianshou/policy/imitation/bcq.py b/tianshou/policy/imitation/bcq.py index 267f45417..e8c6778f2 100644 --- a/tianshou/policy/imitation/bcq.py +++ b/tianshou/policy/imitation/bcq.py @@ -1,6 +1,6 @@ import copy from dataclasses import dataclass -from typing import Any, Generic, Literal, Self, TypeVar, cast +from typing import Any, Generic, Literal, TypeVar, cast import gymnasium as gym import numpy as np @@ -11,6 +11,7 @@ from tianshou.data.batch import BatchProtocol from tianshou.data.types import ActBatchProtocol, ObsBatchProtocol, RolloutBatchProtocol from tianshou.policy.base import ( + LaggedNetworkPolyakUpdateAlgorithmMixin, OfflineAlgorithm, Policy, TLearningRateScheduler, @@ -95,7 +96,11 @@ def forward( return cast(ActBatchProtocol, Batch(act=act_group)) -class BCQ(OfflineAlgorithm[BCQPolicy, TBCQTrainingStats], Generic[TBCQTrainingStats]): +class BCQ( + OfflineAlgorithm[BCQPolicy, TBCQTrainingStats], + LaggedNetworkPolyakUpdateAlgorithmMixin, + Generic[TBCQTrainingStats], +): """Implementation of Batch-Constrained Deep Q-learning (BCQ) algorithm. 
arXiv:1812.02900.""" def __init__( @@ -132,42 +137,25 @@ def __init__( policy=policy, lr_scheduler=lr_scheduler, ) - self.actor_perturbation_target = copy.deepcopy(self.policy.actor_perturbation) + LaggedNetworkPolyakUpdateAlgorithmMixin.__init__(self, tau=tau) + self.actor_perturbation_target = self._add_lagged_network(self.policy.actor_perturbation) self.actor_perturbation_optim = actor_perturbation_optim - self.critic_target = copy.deepcopy(self.policy.critic) + self.critic_target = self._add_lagged_network(self.policy.critic) self.critic_optim = critic_optim critic2 = critic2 or copy.deepcopy(self.policy.critic) critic2_optim = critic2_optim or clone_optimizer(critic_optim, critic2.parameters()) self.critic2 = critic2 - self.critic2_target = copy.deepcopy(self.critic2) + self.critic2_target = self._add_lagged_network(self.critic2) self.critic2_optim = critic2_optim self.vae_optim = vae_optim self.gamma = gamma - self.tau = tau self.lmbda = lmbda self.num_sampled_action = num_sampled_action - def train(self, mode: bool = True) -> Self: - """Set the module in training mode, except for the target network.""" - # TODO: vae is not considered; this is probably a bug! 
- self.training = mode - self.policy.actor_perturbation.train(mode) - self.policy.critic.train(mode) - self.critic2.train(mode) - return self - - def sync_weight(self) -> None: - """Soft-update the weight for the target network.""" - self._polyak_parameter_update(self.critic_target, self.policy.critic, self.tau) - self._polyak_parameter_update(self.critic2_target, self.critic2, self.tau) - self._polyak_parameter_update( - self.actor_perturbation_target, self.policy.actor_perturbation, self.tau - ) - def _update_with_batch( self, batch: RolloutBatchProtocol, @@ -246,8 +234,8 @@ def _update_with_batch( actor_loss.backward() self.actor_perturbation_optim.step() - # update target network - self.sync_weight() + # update target networks + self._update_lagged_network_weights() return BCQTrainingStats( # type: ignore actor_loss=actor_loss.item(), diff --git a/tianshou/policy/imitation/cql.py b/tianshou/policy/imitation/cql.py index 374520f1b..eb58b59c8 100644 --- a/tianshou/policy/imitation/cql.py +++ b/tianshou/policy/imitation/cql.py @@ -1,6 +1,6 @@ from copy import deepcopy from dataclasses import dataclass -from typing import Any, Self, TypeVar, cast +from typing import Any, TypeVar, cast import numpy as np import torch @@ -11,7 +11,11 @@ from tianshou.data import Batch, ReplayBuffer, to_torch from tianshou.data.buffer.base import TBuffer from tianshou.data.types import RolloutBatchProtocol -from tianshou.policy.base import OfflineAlgorithm, TLearningRateScheduler +from tianshou.policy.base import ( + LaggedNetworkPolyakUpdateAlgorithmMixin, + OfflineAlgorithm, + TLearningRateScheduler, +) from tianshou.policy.modelfree.sac import Alpha, SACPolicy, SACTrainingStats from tianshou.utils.conversion import to_optional_float from tianshou.utils.optim import clone_optimizer @@ -30,7 +34,7 @@ class CQLTrainingStats(SACTrainingStats): # TODO: Perhaps SACPolicy should get a more generic name -class CQL(OfflineAlgorithm[SACPolicy, TCQLTrainingStats]): +class 
CQL(OfflineAlgorithm[SACPolicy, TCQLTrainingStats], LaggedNetworkPolyakUpdateAlgorithmMixin): """Implementation of the conservative Q-learning (CQL) algorithm. arXiv:2006.04779.""" def __init__( @@ -98,6 +102,7 @@ def __init__( policy=policy, lr_scheduler=lr_scheduler, ) + LaggedNetworkPolyakUpdateAlgorithmMixin.__init__(self, tau=tau) device = torch_device(policy) @@ -108,12 +113,9 @@ def __init__( self.critic2_optim = critic2_optim or clone_optimizer( critic_optim, self.critic2.parameters() ) - self.critic_old = deepcopy(self.critic) - self.critic2_old = deepcopy(self.critic2) - self.critic_old.eval() - self.critic2_old.eval() + self.critic_old = self._add_lagged_network(self.critic) + self.critic2_old = self._add_lagged_network(self.critic2) - self.tau = tau self.gamma = gamma self.alpha = Alpha.from_float_or_instance(alpha) @@ -138,19 +140,6 @@ def __init__( self.calibrated = calibrated - def train(self, mode: bool = True) -> Self: - """Sets the module in training mode, except for the lagged networks.""" - self.training = mode - self.policy.train(mode) - self.critic.train(mode) - self.critic2.train(mode) - return self - - def _update_lagged_network_weights(self) -> None: - """Soft-update the weight for the target network.""" - self._polyak_parameter_update(self.critic_old, self.critic, self.tau) - self._polyak_parameter_update(self.critic2_old, self.critic2, self.tau) - def _policy_pred(self, obs: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: batch = Batch(obs=obs, info=[None] * len(obs)) obs_result = self.policy(batch) diff --git a/tianshou/policy/imitation/discrete_bcq.py b/tianshou/policy/imitation/discrete_bcq.py index c8ece09d1..5cce77fc4 100644 --- a/tianshou/policy/imitation/discrete_bcq.py +++ b/tianshou/policy/imitation/discrete_bcq.py @@ -1,7 +1,6 @@ import math -from copy import deepcopy from dataclasses import dataclass -from typing import Any, Generic, Self, TypeVar, cast +from typing import Any, Generic, TypeVar, cast import gymnasium as gym 
import numpy as np @@ -15,7 +14,12 @@ ObsBatchProtocol, RolloutBatchProtocol, ) -from tianshou.policy.base import OfflineAlgorithm, Policy, TLearningRateScheduler +from tianshou.policy.base import ( + LaggedNetworkFullUpdateAlgorithmMixin, + OfflineAlgorithm, + Policy, + TLearningRateScheduler, +) from tianshou.policy.modelfree.dqn import DQNTrainingStats float_info = torch.finfo(torch.float32) @@ -95,6 +99,7 @@ def forward( # type: ignore class DiscreteBCQ( OfflineAlgorithm[DiscreteBCQPolicy, TDiscreteBCQTrainingStats], + LaggedNetworkFullUpdateAlgorithmMixin, Generic[TDiscreteBCQTrainingStats], ): """Implementation of the discrete batch-constrained deep Q-learning (BCQ) algorithm. arXiv:1910.01708.""" @@ -138,6 +143,7 @@ def __init__( policy=policy, lr_scheduler=lr_scheduler, ) + LaggedNetworkFullUpdateAlgorithmMixin.__init__(self) self.optim = optim assert ( 0.0 <= discount_factor <= 1.0 @@ -151,8 +157,7 @@ def __init__( self.freq = target_update_freq self._iter = 0 if self._target: - self.model_old = deepcopy(self.policy.model) - self.model_old.eval() + self.model_old = self._add_lagged_network(self.policy.model) self.rew_norm = reward_normalization self.is_double = is_double self.clip_loss_grad = clip_loss_grad @@ -176,12 +181,6 @@ def process_fn( rew_norm=self.rew_norm, ) - def train(self, mode: bool = True) -> Self: - self.training = mode - self.policy.model.train(mode) - self.policy.imitator.train(mode) - return self - def _target_q(self, buffer: ReplayBuffer, indices: np.ndarray) -> torch.Tensor: batch = buffer[indices] # batch.obs_next: s_{t+n} next_obs_batch = Batch(obs=batch.obs_next, info=[None] * len(batch)) @@ -190,9 +189,6 @@ def _target_q(self, buffer: ReplayBuffer, indices: np.ndarray) -> torch.Tensor: target_q, _ = self.model_old(batch.obs_next) return target_q[np.arange(len(act)), act] - def _update_lagged_network_weights(self) -> None: - self.model_old.load_state_dict(self.policy.model.state_dict()) - def _update_with_batch( self, batch: 
RolloutBatchProtocol, diff --git a/tianshou/policy/imitation/discrete_cql.py b/tianshou/policy/imitation/discrete_cql.py index f9b332128..30f0aafa3 100644 --- a/tianshou/policy/imitation/discrete_cql.py +++ b/tianshou/policy/imitation/discrete_cql.py @@ -89,7 +89,7 @@ def _update_with_batch( **kwargs: Any, ) -> TDiscreteCQLTrainingStats: if self._target and self._iter % self.freq == 0: - self.sync_weight() + self._update_lagged_network_weights() self.optim.zero_grad() weight = batch.pop("weight", 1.0) all_dist = self(batch).logits diff --git a/tianshou/policy/imitation/discrete_crr.py b/tianshou/policy/imitation/discrete_crr.py index 599f19c82..0b5898d1a 100644 --- a/tianshou/policy/imitation/discrete_crr.py +++ b/tianshou/policy/imitation/discrete_crr.py @@ -1,4 +1,3 @@ -from copy import deepcopy from dataclasses import dataclass from typing import Any, Literal, TypeVar @@ -9,7 +8,11 @@ from tianshou.data import ReplayBuffer, to_torch, to_torch_as from tianshou.data.types import BatchWithReturnsProtocol, RolloutBatchProtocol -from tianshou.policy.base import OfflineAlgorithm, TLearningRateScheduler +from tianshou.policy.base import ( + LaggedNetworkFullUpdateAlgorithmMixin, + OfflineAlgorithm, + TLearningRateScheduler, +) from tianshou.policy.modelfree.pg import ( DiscountedReturnComputation, DiscreteActorPolicy, @@ -28,7 +31,10 @@ class DiscreteCRRTrainingStats(PGTrainingStats): TDiscreteCRRTrainingStats = TypeVar("TDiscreteCRRTrainingStats", bound=DiscreteCRRTrainingStats) -class DiscreteCRR(OfflineAlgorithm[DiscreteActorPolicy, TDiscreteCRRTrainingStats]): +class DiscreteCRR( + OfflineAlgorithm[DiscreteActorPolicy, TDiscreteCRRTrainingStats], + LaggedNetworkFullUpdateAlgorithmMixin, +): r"""Implementation of discrete Critic Regularized Regression. 
arXiv:2006.15134.""" def __init__( @@ -70,6 +76,7 @@ def __init__( policy=policy, lr_scheduler=lr_scheduler, ) + LaggedNetworkFullUpdateAlgorithmMixin.__init__(self) self.optim = optim self.discounted_return_computation = DiscountedReturnComputation( discount_factor=discount_factor, @@ -80,10 +87,8 @@ def __init__( self._freq = target_update_freq self._iter = 0 if self._target: - self.actor_old = deepcopy(self.policy.actor) - self.actor_old.eval() - self.critic_old = deepcopy(self.critic) - self.critic_old.eval() + self.actor_old = self._add_lagged_network(self.policy.actor) + self.critic_old = self._add_lagged_network(self.critic) else: self.actor_old = self.actor self.critic_old = self.critic @@ -104,10 +109,6 @@ def process_fn( indices, ) - def _update_lagged_network_weights(self) -> None: - self.actor_old.load_state_dict(self.policy.actor.state_dict()) - self.critic_old.load_state_dict(self.critic.state_dict()) - def _update_with_batch( # type: ignore self, batch: RolloutBatchProtocol, diff --git a/tianshou/policy/modelfree/bdqn.py b/tianshou/policy/modelfree/bdqn.py index 08ffa5adf..26e7da5c3 100644 --- a/tianshou/policy/modelfree/bdqn.py +++ b/tianshou/policy/modelfree/bdqn.py @@ -191,7 +191,7 @@ def _update_with_batch( **kwargs: Any, ) -> TBDQNTrainingStats: if self._target and self._iter % self.freq == 0: - self.sync_weight() + self._update_lagged_network_weights() self.optim.zero_grad() weight = batch.pop("weight", 1.0) act = to_torch(batch.act, dtype=torch.long, device=batch.returns.device) diff --git a/tianshou/policy/modelfree/c51.py b/tianshou/policy/modelfree/c51.py index 4905ae079..24ea7b2fe 100644 --- a/tianshou/policy/modelfree/c51.py +++ b/tianshou/policy/modelfree/c51.py @@ -136,7 +136,7 @@ def _update_with_batch( **kwargs: Any, ) -> TC51TrainingStats: if self._target and self._iter % self.freq == 0: - self.sync_weight() + self._update_lagged_network_weights() self.optim.zero_grad() with torch.no_grad(): target_dist = self._target_dist(batch) 
diff --git a/tianshou/policy/modelfree/ddpg.py b/tianshou/policy/modelfree/ddpg.py index 5fdc375a0..43a21ddf2 100644 --- a/tianshou/policy/modelfree/ddpg.py +++ b/tianshou/policy/modelfree/ddpg.py @@ -1,8 +1,7 @@ import warnings from abc import ABC -from copy import deepcopy from dataclasses import dataclass -from typing import Any, Generic, Literal, Self, TypeVar, cast +from typing import Any, Generic, Literal, TypeVar, cast import gymnasium as gym import numpy as np @@ -20,6 +19,7 @@ ) from tianshou.exploration import BaseNoise, GaussianNoise from tianshou.policy.base import ( + LaggedNetworkPolyakUpdateAlgorithmMixin, OffPolicyAlgorithm, Policy, TArrOrActBatch, @@ -160,6 +160,7 @@ def forward( class ActorCriticOffPolicyAlgorithm( OffPolicyAlgorithm[TPolicy, TTrainingStats], + LaggedNetworkPolyakUpdateAlgorithmMixin, Generic[TPolicy, TTrainingStats, TActBatchProtocol], ABC, ): @@ -210,12 +211,11 @@ def __init__( policy=policy, lr_scheduler=lr_scheduler, ) + LaggedNetworkPolyakUpdateAlgorithmMixin.__init__(self, tau=tau) self.policy_optim = policy_optim self.critic = critic - self.critic_old = deepcopy(critic) - self.critic_old.eval() + self.critic_old = self._add_lagged_network(self.critic) self.critic_optim = critic_optim - self.tau = tau self.gamma = gamma self.estimation_step = estimation_step @@ -296,18 +296,6 @@ def _target_q(self, buffer: ReplayBuffer, indices: np.ndarray) -> torch.Tensor: act_batch = self._target_q_compute_action(obs_next_batch) return self._target_q_compute_value(obs_next_batch, act_batch) - def _update_lagged_network_weights(self) -> None: - """Updates the lagged network weights with the current weights using Polyak averaging.""" - self._polyak_parameter_update(self.critic_old, self.critic, self.tau) - - def train(self, mode: bool = True) -> Self: - """Sets the module to training mode, except for the lagged components.""" - # exclude `critic_old` from training - self.training = mode - self.policy.train(mode) - self.critic.train(mode) - 
return self - class DDPG( ActorCriticOffPolicyAlgorithm[DDPGPolicy, TDDPGTrainingStats, ActBatchProtocol], @@ -347,17 +335,12 @@ def __init__( gamma=gamma, estimation_step=estimation_step, ) - self.actor_old = deepcopy(policy.actor) - self.actor_old.eval() + self.actor_old = self._add_lagged_network(self.policy.actor) def _target_q_compute_action(self, obs_batch: Batch) -> ActBatchProtocol: # compute the action using the lagged actor network return self.policy(obs_batch, model=self.actor_old) - def _update_lagged_network_weights(self) -> None: - super()._update_lagged_network_weights() - self._polyak_parameter_update(self.actor_old, self.policy.actor, self.tau) - def _update_with_batch(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> TDDPGTrainingStats: # type: ignore # critic td, critic_loss = self._minimize_critic_squared_loss(batch, self.critic, self.critic_optim) diff --git a/tianshou/policy/modelfree/dqn.py b/tianshou/policy/modelfree/dqn.py index 9536e7ca2..b1a365730 100644 --- a/tianshou/policy/modelfree/dqn.py +++ b/tianshou/policy/modelfree/dqn.py @@ -1,6 +1,5 @@ -from copy import deepcopy from dataclasses import dataclass -from typing import Any, Generic, Self, TypeVar, cast +from typing import Any, Generic, TypeVar, cast import gymnasium as gym import numpy as np @@ -17,6 +16,7 @@ RolloutBatchProtocol, ) from tianshou.policy.base import ( + LaggedNetworkFullUpdateAlgorithmMixin, OffPolicyAlgorithm, Policy, TArrOrActBatch, @@ -143,7 +143,9 @@ def add_exploration_noise( class DeepQLearning( - OffPolicyAlgorithm[TDQNPolicy, TDQNTrainingStats], Generic[TDQNPolicy, TDQNTrainingStats] + OffPolicyAlgorithm[TDQNPolicy, TDQNTrainingStats], + LaggedNetworkFullUpdateAlgorithmMixin, + Generic[TDQNPolicy, TDQNTrainingStats], ): """Implementation of Deep Q Network. arXiv:1312.5602. 
@@ -185,6 +187,7 @@ def __init__( policy=policy, lr_scheduler=lr_scheduler, ) + LaggedNetworkFullUpdateAlgorithmMixin.__init__(self) self.optim = optim assert ( 0.0 <= discount_factor <= 1.0 @@ -198,23 +201,11 @@ def __init__( self.freq = target_update_freq self._iter = 0 if self._target: - self.model_old = deepcopy(self.policy.model) - self.model_old.eval() + self.model_old = self._add_lagged_network(self.policy.model) self.rew_norm = reward_normalization self.is_double = is_double self.clip_loss_grad = clip_loss_grad - def train(self, mode: bool = True) -> Self: - """Set the module in training mode, except for the target network.""" - # TODO: Determine whether this is called correctly and who relies on this being called (for all subclasses) - self.training = mode - self.policy.train(mode) - return self - - def sync_weight(self) -> None: - """Synchronize the weight for the target network.""" - self.model_old.load_state_dict(self.policy.model.state_dict()) - def _target_q(self, buffer: ReplayBuffer, indices: np.ndarray) -> torch.Tensor: obs_next_batch = Batch( obs=buffer[indices].obs_next, @@ -259,7 +250,7 @@ def _update_with_batch( **kwargs: Any, ) -> TDQNTrainingStats: if self._target and self._iter % self.freq == 0: - self.sync_weight() + self._update_lagged_network_weights() self.optim.zero_grad() weight = batch.pop("weight", 1.0) q = self.policy(batch).logits diff --git a/tianshou/policy/modelfree/fqf.py b/tianshou/policy/modelfree/fqf.py index 8e4609707..015d4f955 100644 --- a/tianshou/policy/modelfree/fqf.py +++ b/tianshou/policy/modelfree/fqf.py @@ -172,7 +172,7 @@ def _update_with_batch( **kwargs: Any, ) -> TFQFTrainingStats: if self._target and self._iter % self.freq == 0: - self.sync_weight() + self._update_lagged_network_weights() weight = batch.pop("weight", 1.0) out = self.policy(batch) curr_dist_orig = out.logits diff --git a/tianshou/policy/modelfree/iqn.py b/tianshou/policy/modelfree/iqn.py index 829d28321..a1245e92d 100644 --- 
a/tianshou/policy/modelfree/iqn.py +++ b/tianshou/policy/modelfree/iqn.py @@ -141,7 +141,7 @@ def _update_with_batch( **kwargs: Any, ) -> TIQNTrainingStats: if self._target and self._iter % self.freq == 0: - self.sync_weight() + self._update_lagged_network_weights() self.optim.zero_grad() weight = batch.pop("weight", 1.0) action_batch = self.policy(batch) diff --git a/tianshou/policy/modelfree/qrdqn.py b/tianshou/policy/modelfree/qrdqn.py index b4c2ec9c1..d8e2ed0b8 100644 --- a/tianshou/policy/modelfree/qrdqn.py +++ b/tianshou/policy/modelfree/qrdqn.py @@ -106,7 +106,7 @@ def _update_with_batch( **kwargs: Any, ) -> TQRDQNTrainingStats: if self._target and self._iter % self.freq == 0: - self.sync_weight() + self._update_lagged_network_weights() self.optim.zero_grad() weight = batch.pop("weight", 1.0) curr_dist = self.policy(batch).logits diff --git a/tianshou/policy/modelfree/td3.py b/tianshou/policy/modelfree/td3.py index 7ccf9c48b..e8398c289 100644 --- a/tianshou/policy/modelfree/td3.py +++ b/tianshou/policy/modelfree/td3.py @@ -1,7 +1,7 @@ from abc import ABC from copy import deepcopy from dataclasses import dataclass -from typing import Any, Generic, Literal, Self, TypeVar +from typing import Any, Generic, TypeVar import torch @@ -10,7 +10,6 @@ ActStateBatchProtocol, RolloutBatchProtocol, ) -from tianshou.exploration import BaseNoise from tianshou.policy.base import ( TLearningRateScheduler, TPolicy, @@ -89,8 +88,8 @@ def __init__( raise ValueError("critic2_optim must be provided if critic2 is provided") critic2 = critic2 or deepcopy(critic) critic2_optim = critic2_optim or clone_optimizer(critic_optim, critic2.parameters()) - self.critic2, self.critic2_old = critic2, deepcopy(critic2) - self.critic2_old.eval() + self.critic2 = critic2 + self.critic2_old = self._add_lagged_network(self.critic2) self.critic2_optim = critic2_optim def _target_q_compute_value( @@ -103,15 +102,6 @@ def _target_q_compute_value( self.critic2_old(obs_batch.obs, act), ) - def 
train(self, mode: bool = True) -> Self: - super().train(mode=mode) - self.critic2.train(mode) - return self - - def _update_lagged_network_weights(self) -> None: - super()._update_lagged_network_weights() - self._polyak_parameter_update(self.critic2_old, self.critic2, self.tau) - class TD3( ActorDualCriticsOffPolicyAlgorithm[DDPGPolicy, TTD3TrainingStats, ActStateBatchProtocol], @@ -165,8 +155,7 @@ def __init__( estimation_step=estimation_step, lr_scheduler=lr_scheduler, ) - self.actor_old = deepcopy(policy.actor) - self.actor_old.eval() + self.actor_old = self._add_lagged_network(self.policy.actor) self.policy_noise = policy_noise self.update_actor_freq = update_actor_freq self.noise_clip = noise_clip @@ -187,10 +176,6 @@ def _target_q_compute_action(self, obs_batch: Batch) -> ActStateBatchProtocol: act_batch.act = act_ return act_batch - def _update_lagged_network_weights(self) -> None: - super()._update_lagged_network_weights() - self._polyak_parameter_update(self.actor_old, self.policy.actor, self.tau) - def _update_with_batch(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> TTD3TrainingStats: # type: ignore # critic 1&2 td1, critic1_loss = self._minimize_critic_squared_loss( diff --git a/tianshou/utils/lagged_network.py b/tianshou/utils/lagged_network.py new file mode 100644 index 000000000..0a1114b3d --- /dev/null +++ b/tianshou/utils/lagged_network.py @@ -0,0 +1,85 @@ +from copy import deepcopy +from dataclasses import dataclass +from typing import Self + +import torch + + +def polyak_parameter_update(tgt: torch.nn.Module, src: torch.nn.Module, tau: float) -> None: + """Softly updates the parameters of a target network `tgt` with the parameters of a source network `src` + using Polyak averaging: `tau * src + (1 - tau) * tgt`. 
+ + :param tgt: the target network that receives the parameter update + :param src: the source network whose parameters are used for the update + :param tau: the fraction with which to use the source network's parameters, the inverse `1-tau` being + the fraction with which to retain the target network's parameters. + """ + for tgt_param, src_param in zip(tgt.parameters(), src.parameters(), strict=True): + tgt_param.data.copy_(tau * src_param.data + (1 - tau) * tgt_param.data) + + +class EvalModeModuleWrapper(torch.nn.Module): + """ + A wrapper around a torch.nn.Module that forces the module to eval mode. + + The wrapped module supports only the forward method; attribute access is not supported. + NOTE: It is not recommended to support attribute/method access beyond this via `__getattr__`, + because torch.nn.Module already heavily relies on `__getattr__` to provide its own attribute access. + Overriding it naively will cause problems! + But it's also not necessary for our use cases; forward is enough. + """ + + def __init__(self, m: torch.nn.Module): + super().__init__() + m.eval() + self.module = m + + def forward(self, *args, **kwargs): # type: ignore + self.module.eval() + return self.module(*args, **kwargs) + + def train(self, mode: bool = True) -> Self: + super().train(mode=mode) + self.module.eval() # force eval mode + return self + + +@dataclass +class LaggedNetworkPair: + target: torch.nn.Module + source: torch.nn.Module + + +class LaggedNetworkCollection: + def __init__(self) -> None: + self._lagged_network_pairs: list[LaggedNetworkPair] = [] + + def add_lagged_network(self, source: torch.nn.Module) -> torch.nn.Module: + """ + Adds a lagged network to the collection, returning the target network, which + is forced to eval mode. The target network is a copy of the source network, + which, however, supports only the forward method (hence the type torch.nn.Module); + attribute access is not supported.
+ + :param source: the source network whose parameters are to be copied to the target network + :return: the target network, which supports only the forward method and is forced to eval mode + """ + target = deepcopy(source) + self._lagged_network_pairs.append(LaggedNetworkPair(target, source)) + return EvalModeModuleWrapper(target) + + def polyak_parameter_update(self, tau: float) -> None: + """Softly updates the parameters of each target network `tgt` with the parameters of a source network `src` + using Polyak averaging: `tau * src + (1 - tau) * tgt`. + + :param tau: the fraction with which to use the source network's parameters, the inverse `1-tau` being + the fraction with which to retain the target network's parameters. + """ + for pair in self._lagged_network_pairs: + polyak_parameter_update(pair.target, pair.source, tau) + + def full_parameter_update(self) -> None: + """Fully updates the target networks with the source networks' parameters (exact copy).""" + for pair in self._lagged_network_pairs: + pair.target.load_state_dict(pair.source.state_dict()) + pair.target.eval() From ffe9d8ecf6d2a501d470cc2769bd1265b163f74f Mon Sep 17 00:00:00 2001 From: Dominik Jain Date: Wed, 12 Mar 2025 20:16:38 +0100 Subject: [PATCH 051/230] v2: Improve on-policy class hierarchy * Refactor base class for actor-critic implementations (now called ActorCriticOnPolicyAlgorithm) to no longer inherit from Reinforce --- CHANGELOG.md | 4 ++- tianshou/policy/modelfree/a2c.py | 47 ++++++++++++++++++++------------ tianshou/policy/modelfree/npg.py | 6 ++-- tianshou/policy/modelfree/pg.py | 16 ----------- tianshou/policy/modelfree/ppo.py | 11 ++++---- 5 files changed, 41 insertions(+), 43 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index e592b9ede..fb38cebb8 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -68,14 +68,16 @@ for continuous action spaces (e.g. relevant to `DDPG`, `SAC` and `REDQ`). 
* Fixed issues in the class hierarchy (particularly critical violations of the Liskov substitution principle): * Introduced base classes (to retain factorization without abusive inheritance): + * `ActorCriticOnPolicyAlgorithm` * `ActorCriticOffPolicyAlgorithm` * `ActorDualCriticsOffPolicyAlgorithm` (extends `ActorCriticOffPolicyAlgorithm`) + * `A2C`: Inherit from `ActorCriticOnPolicyAlgorithm` instead of `Reinforce` * `CQL`: * Inherit directly from `OfflineAlgorithm` instead of `SAC` (off-policy). * Remove parameter `estimation_step`, which was not actually used (it was only passed it on to its superclass). * `DiscreteCRR`: Inherit directly from `OfflineAlgorithm` instead of `Reinforce` (on-policy) - * `NPG`: Inherit from `AbstractActorCriticWithAdvantage` instead of `A2C` (which is now has the same base class) + * `NPG`: Inherit from `ActorCriticOnPolicyAlgorithm` instead of `A2C` * `REDQ`: Inherit from `ActorCriticOffPolicyAlgorithm` instead of `DDPG` * `SAC`: Inherit from `ActorDualCriticsOffPolicyAlgorithm` instead of `DDPG` * `TD3`: Inherit from `ActorDualCriticsOffPolicyAlgorithm` instead of `DDPG` diff --git a/tianshou/policy/modelfree/a2c.py b/tianshou/policy/modelfree/a2c.py index 6ada890c0..e67b04181 100644 --- a/tianshou/policy/modelfree/a2c.py +++ b/tianshou/policy/modelfree/a2c.py @@ -9,9 +9,13 @@ from tianshou.data import ReplayBuffer, SequenceSummaryStats, to_torch_as from tianshou.data.types import BatchWithAdvantagesProtocol, RolloutBatchProtocol -from tianshou.policy import Reinforce -from tianshou.policy.base import TLearningRateScheduler, TrainingStats +from tianshou.policy.base import ( + OnPolicyAlgorithm, + TLearningRateScheduler, + TrainingStats, +) from tianshou.policy.modelfree.pg import ActorPolicy, TPGTrainingStats +from tianshou.utils import RunningMeanStd from tianshou.utils.net.common import ActorCritic from tianshou.utils.net.continuous import Critic from tianshou.utils.net.discrete import Critic as DiscreteCritic @@ -28,7 +32,11 @@ 
class A2CTrainingStats(TrainingStats): TA2CTrainingStats = TypeVar("TA2CTrainingStats", bound=A2CTrainingStats) -class AbstractActorCriticWithAdvantage(Reinforce[TPGTrainingStats], Generic[TPGTrainingStats], ABC): +class ActorCriticOnPolicyAlgorithm( + OnPolicyAlgorithm[ActorPolicy, TPGTrainingStats], Generic[TPGTrainingStats], ABC +): + """Abstract base class for actor-critic algorithms that use generalized advantage estimation (GAE).""" + def __init__( self, *, @@ -41,11 +49,17 @@ def __init__( reward_normalization: bool = False, lr_scheduler: TLearningRateScheduler | None = None, ) -> None: + """ + :param critic: the critic network. (s -> V(s)) + :param optim: the optimizer for actor and critic network. + :param gae_lambda: in [0, 1], param for generalized advantage estimation (GAE). + :param max_batchsize: the maximum size of the batch when computing GAE. + :param discount_factor: in [0, 1]. + :param reward_normalization: normalize estimated values to have std close to 1. + :param lr_scheduler: if not None, will be called in `policy.update()`. 
+ """ super().__init__( policy=policy, - optim=optim, - discount_factor=discount_factor, - reward_normalization=reward_normalization, lr_scheduler=lr_scheduler, ) self.critic = critic @@ -53,13 +67,19 @@ def __init__( self.gae_lambda = gae_lambda self.max_batchsize = max_batchsize self._actor_critic = ActorCritic(self.policy.actor, self.critic) + self.optim = optim + self.gamma = discount_factor + self.rew_norm = reward_normalization + self.ret_rms = RunningMeanStd() + self._eps = 1e-8 - def _compute_returns( + def _add_returns_and_advantages( self, batch: RolloutBatchProtocol, buffer: ReplayBuffer, indices: np.ndarray, ) -> BatchWithAdvantagesProtocol: + """Adds the returns and advantages to the given batch.""" v_s, v_s_ = [], [] with torch.no_grad(): for minibatch in batch.split(self.max_batchsize, shuffle=False, merge_last=True): @@ -72,7 +92,6 @@ def _compute_returns( # consistent with OPENAI baselines' value normalization pipeline. Empirical # study also shows that "minus mean" will harm performances a tiny little bit # due to unknown reasons (on Mujoco envs, not confident, though). - # TODO: see todo in PGPolicy.process_fn if self.rew_norm: # unnormalize v_s & v_s_ v_s = v_s * np.sqrt(self.ret_rms.var + self._eps) v_s_ = v_s_ * np.sqrt(self.ret_rms.var + self._eps) @@ -95,14 +114,8 @@ def _compute_returns( return cast(BatchWithAdvantagesProtocol, batch) -class A2C(AbstractActorCriticWithAdvantage[TA2CTrainingStats], Generic[TA2CTrainingStats]): - """Implementation of Synchronous Advantage Actor-Critic. arXiv:1602.01783. - - .. seealso:: - - Please refer to :class:`~tianshou.policy.BasePolicy` for more detailed - explanation. - """ +class A2C(ActorCriticOnPolicyAlgorithm[TA2CTrainingStats], Generic[TA2CTrainingStats]): + """Implementation of (synchronous) Advantage Actor-Critic (A2C). 
arXiv:1602.01783.""" def __init__( self, @@ -152,7 +165,7 @@ def process_fn( buffer: ReplayBuffer, indices: np.ndarray, ) -> BatchWithAdvantagesProtocol: - batch = self._compute_returns(batch, buffer, indices) + batch = self._add_returns_and_advantages(batch, buffer, indices) batch.act = to_torch_as(batch.act, batch.v_s) return batch diff --git a/tianshou/policy/modelfree/npg.py b/tianshou/policy/modelfree/npg.py index 1c9571295..0f516534d 100644 --- a/tianshou/policy/modelfree/npg.py +++ b/tianshou/policy/modelfree/npg.py @@ -10,7 +10,7 @@ from tianshou.data import Batch, ReplayBuffer, SequenceSummaryStats, to_torch_as from tianshou.data.types import BatchWithAdvantagesProtocol, RolloutBatchProtocol from tianshou.policy.base import TLearningRateScheduler, TrainingStats -from tianshou.policy.modelfree.a2c import AbstractActorCriticWithAdvantage +from tianshou.policy.modelfree.a2c import ActorCriticOnPolicyAlgorithm from tianshou.policy.modelfree.pg import ActorPolicy from tianshou.utils.net.continuous import Critic from tianshou.utils.net.discrete import Critic as DiscreteCritic @@ -26,7 +26,7 @@ class NPGTrainingStats(TrainingStats): TNPGTrainingStats = TypeVar("TNPGTrainingStats", bound=NPGTrainingStats) -class NPG(AbstractActorCriticWithAdvantage[TNPGTrainingStats], Generic[TNPGTrainingStats]): +class NPG(ActorCriticOnPolicyAlgorithm[TNPGTrainingStats], Generic[TNPGTrainingStats]): """Implementation of Natural Policy Gradient. 
https://proceedings.neurips.cc/paper/2001/file/4b86abe48d358ecf194c56c69108433e-Paper.pdf @@ -84,7 +84,7 @@ def process_fn( buffer: ReplayBuffer, indices: np.ndarray, ) -> BatchWithAdvantagesProtocol: - batch = self._compute_returns(batch, buffer, indices) + batch = self._add_returns_and_advantages(batch, buffer, indices) batch.act = to_torch_as(batch.act, batch.v_s) old_log_prob = [] with torch.no_grad(): diff --git a/tianshou/policy/modelfree/pg.py b/tianshou/policy/modelfree/pg.py index 3d7d3cdf3..43937fdf2 100644 --- a/tianshou/policy/modelfree/pg.py +++ b/tianshou/policy/modelfree/pg.py @@ -277,22 +277,6 @@ def __init__( ) self.optim = optim - @property - def gamma(self) -> float: - return self.discounted_return_computation.gamma - - @property - def rew_norm(self) -> bool: - return self.discounted_return_computation.rew_norm - - @property - def ret_rms(self) -> RunningMeanStd: - return self.discounted_return_computation.ret_rms - - @property - def _eps(self) -> float: - return self.discounted_return_computation.eps - def process_fn( self, batch: RolloutBatchProtocol, diff --git a/tianshou/policy/modelfree/ppo.py b/tianshou/policy/modelfree/ppo.py index 215235f03..86e8ecb5a 100644 --- a/tianshou/policy/modelfree/ppo.py +++ b/tianshou/policy/modelfree/ppo.py @@ -1,6 +1,6 @@ from collections.abc import Sequence from dataclasses import dataclass -from typing import Any, Generic, Self, TypeVar +from typing import Any, Generic, Self, TypeVar, cast import numpy as np import torch @@ -131,17 +131,16 @@ def process_fn( indices: np.ndarray, ) -> LogpOldProtocol: if self.recompute_adv: - # buffer input `buffer` and `indices` to be used in `learn()`. + # buffer input `buffer` and `indices` to be used in `_update_with_batch()`. 
self._buffer, self._indices = buffer, indices - batch = self._compute_returns(batch, buffer, indices) + batch = self._add_returns_and_advantages(batch, buffer, indices) batch.act = to_torch_as(batch.act, batch.v_s) logp_old = [] with torch.no_grad(): for minibatch in batch.split(self.max_batchsize, shuffle=False, merge_last=True): logp_old.append(self.policy(minibatch).dist.log_prob(minibatch.act)) batch.logp_old = torch.cat(logp_old, dim=0).flatten() - batch: LogpOldProtocol - return batch + return cast(LogpOldProtocol, batch) # TODO: why does mypy complain? def _update_with_batch( # type: ignore @@ -157,7 +156,7 @@ def _update_with_batch( # type: ignore split_batch_size = batch_size or -1 for step in range(repeat): if self.recompute_adv and step > 0: - batch = self._compute_returns(batch, self._buffer, self._indices) + batch = self._add_returns_and_advantages(batch, self._buffer, self._indices) for minibatch in batch.split(split_batch_size, merge_last=True): gradient_steps += 1 # calculate loss for actor From 7a6c29db7bc6a1a71f46e905a2aae4f54b48bb5c Mon Sep 17 00:00:00 2001 From: Dominik Jain Date: Wed, 12 Mar 2025 21:00:00 +0100 Subject: [PATCH 052/230] v2: Adapt ICM, test_dqn_icm and test_ppo_icm * Add base classes for wrapper algorithms: OnPolicyWrapperAlgorithm, OffPolicyWrapperAlgorithm * Specialize ICM for on-/off-policy (ICMOnPolicyWrapper, ICMOffPolicyWrapper), using mixin _ICMMixin to factor out common code * Adapt high-level API parts pertaining to algorithm wrappers and ICM --- examples/atari/atari_dqn.py | 8 +- examples/atari/atari_dqn_hl.py | 6 +- examples/atari/atari_ppo.py | 8 +- examples/atari/atari_ppo_hl.py | 6 +- examples/atari/atari_sac.py | 9 +- examples/atari/atari_sac_hl.py | 6 +- examples/vizdoom/vizdoom_ppo.py | 8 +- test/modelbased/test_dqn_icm.py | 67 +++--- test/modelbased/test_ppo_icm.py | 77 +++--- tianshou/highlevel/algorithm.py | 12 +- tianshou/highlevel/experiment.py | 18 +- tianshou/highlevel/params/policy_wrapper.py | 34 ++-
tianshou/policy/__init__.py | 4 +- tianshou/policy/base.py | 85 +++++++ tianshou/policy/modelbased/icm.py | 254 ++++++++++++-------- 15 files changed, 384 insertions(+), 218 deletions(-) diff --git a/examples/atari/atari_dqn.py b/examples/atari/atari_dqn.py index 7374d7a8d..606ce9a2c 100644 --- a/examples/atari/atari_dqn.py +++ b/examples/atari/atari_dqn.py @@ -13,7 +13,7 @@ from tianshou.highlevel.logger import LoggerFactoryDefault from tianshou.policy import DeepQLearning from tianshou.policy.base import Algorithm -from tianshou.policy.modelbased.icm import ICMPolicy +from tianshou.policy.modelbased.icm import ICMOffPolicyWrapper from tianshou.trainer import OffPolicyTrainer from tianshou.utils.net.discrete import IntrinsicCuriosityModule @@ -104,7 +104,7 @@ def main(args: argparse.Namespace = get_args()) -> None: net = DQNet(*args.state_shape, args.action_shape, args.device).to(args.device) optim = torch.optim.Adam(net.parameters(), lr=args.lr) # define policy - policy: DeepQLearning | ICMPolicy + policy: DeepQLearning | ICMOffPolicyWrapper policy = DeepQLearning( model=net, optim=optim, @@ -127,8 +127,8 @@ def main(args: argparse.Namespace = get_args()) -> None: device=args.device, ) icm_optim = torch.optim.Adam(icm_net.parameters(), lr=args.lr) - policy = ICMPolicy( - policy=policy, + policy = ICMOffPolicyWrapper( + wrapped_algorithm=policy, model=icm_net, optim=icm_optim, action_space=env.action_space, diff --git a/examples/atari/atari_dqn_hl.py b/examples/atari/atari_dqn_hl.py index dc3a9fd26..1f3469d39 100644 --- a/examples/atari/atari_dqn_hl.py +++ b/examples/atari/atari_dqn_hl.py @@ -17,7 +17,7 @@ ) from tianshou.highlevel.params.policy_params import DQNParams from tianshou.highlevel.params.policy_wrapper import ( - PolicyWrapperFactoryIntrinsicCuriosity, + AlgorithmWrapperFactoryIntrinsicCuriosity, ) from tianshou.highlevel.trainer import ( EpochTestCallbackDQNSetEps, @@ -93,8 +93,8 @@ def main( .with_epoch_stop_callback(AtariEpochStopCallback(task)) ) 
if icm_lr_scale > 0: - builder.with_policy_wrapper_factory( - PolicyWrapperFactoryIntrinsicCuriosity( + builder.with_algorithm_wrapper_factory( + AlgorithmWrapperFactoryIntrinsicCuriosity( feature_net_factory=IntermediateModuleFactoryAtariDQNFeatures(), hidden_sizes=[512], lr=lr, diff --git a/examples/atari/atari_ppo.py b/examples/atari/atari_ppo.py index 6a5d19f23..6493979d0 100644 --- a/examples/atari/atari_ppo.py +++ b/examples/atari/atari_ppo.py @@ -13,8 +13,9 @@ from tianshou.env.atari.atari_network import DQNet, layer_init, scale_obs from tianshou.env.atari.atari_wrapper import make_atari_env from tianshou.highlevel.logger import LoggerFactoryDefault -from tianshou.policy import PPO, ICMPolicy +from tianshou.policy import PPO from tianshou.policy.base import Algorithm +from tianshou.policy.modelbased.icm import ICMOnPolicyWrapper from tianshou.trainer import OnPolicyTrainer from tianshou.utils.net.common import ActorCritic from tianshou.utils.net.discrete import Actor, Critic, IntrinsicCuriosityModule @@ -167,11 +168,10 @@ def dist(logits: torch.Tensor) -> Categorical: device=args.device, ) icm_optim = torch.optim.Adam(icm_net.parameters(), lr=args.lr) - policy: ICMPolicy = ICMPolicy( # type: ignore[no-redef] - policy=policy, + policy = ICMOnPolicyWrapper( # type: ignore[no-redef] + wrapped_algorithm=policy, model=icm_net, optim=icm_optim, - action_space=env.action_space, lr_scale=args.icm_lr_scale, reward_scale=args.icm_reward_scale, forward_loss_weight=args.icm_forward_loss_weight, diff --git a/examples/atari/atari_ppo_hl.py b/examples/atari/atari_ppo_hl.py index e0939ecc8..06d59d555 100644 --- a/examples/atari/atari_ppo_hl.py +++ b/examples/atari/atari_ppo_hl.py @@ -19,7 +19,7 @@ from tianshou.highlevel.params.lr_scheduler import LRSchedulerFactoryLinear from tianshou.highlevel.params.policy_params import PPOParams from tianshou.highlevel.params.policy_wrapper import ( - PolicyWrapperFactoryIntrinsicCuriosity, + AlgorithmWrapperFactoryIntrinsicCuriosity, ) 
@@ -105,8 +105,8 @@ def main( .with_epoch_stop_callback(AtariEpochStopCallback(task)) ) if icm_lr_scale > 0: - builder.with_policy_wrapper_factory( - PolicyWrapperFactoryIntrinsicCuriosity( + builder.with_algorithm_wrapper_factory( + AlgorithmWrapperFactoryIntrinsicCuriosity( feature_net_factory=IntermediateModuleFactoryAtariDQNFeatures(), hidden_sizes=hidden_sizes, lr=lr, diff --git a/examples/atari/atari_sac.py b/examples/atari/atari_sac.py index 651d9bcb1..b50f1c6de 100644 --- a/examples/atari/atari_sac.py +++ b/examples/atari/atari_sac.py @@ -11,7 +11,7 @@ from tianshou.env.atari.atari_network import DQNet from tianshou.env.atari.atari_wrapper import make_atari_env from tianshou.highlevel.logger import LoggerFactoryDefault -from tianshou.policy import DiscreteSAC, ICMPolicy +from tianshou.policy import DiscreteSAC, ICMOffPolicyWrapper from tianshou.policy.base import Algorithm from tianshou.trainer import OffPolicyTrainer from tianshou.utils.net.discrete import Actor, Critic, IntrinsicCuriosityModule @@ -124,7 +124,7 @@ def test_discrete_sac(args: argparse.Namespace = get_args()) -> None: alpha_optim = torch.optim.Adam([log_alpha], lr=args.alpha_lr) args.alpha = (target_entropy, log_alpha, alpha_optim) - policy: DiscreteSAC | ICMPolicy + policy: DiscreteSAC | ICMOffPolicyWrapper policy = DiscreteSAC( actor=actor, policy_optim=actor_optim, @@ -150,11 +150,10 @@ def test_discrete_sac(args: argparse.Namespace = get_args()) -> None: device=args.device, ) icm_optim = torch.optim.Adam(icm_net.parameters(), lr=args.actor_lr) - policy = ICMPolicy( - policy=policy, + policy = ICMOffPolicyWrapper( + wrapped_algorithm=policy, model=icm_net, optim=icm_optim, - action_space=env.action_space, lr_scale=args.icm_lr_scale, reward_scale=args.icm_reward_scale, forward_loss_weight=args.icm_forward_loss_weight, diff --git a/examples/atari/atari_sac_hl.py b/examples/atari/atari_sac_hl.py index 049928856..3ce60aa4d 100644 --- a/examples/atari/atari_sac_hl.py +++ 
b/examples/atari/atari_sac_hl.py @@ -19,7 +19,7 @@ from tianshou.highlevel.params.alpha import AutoAlphaFactoryDefault from tianshou.highlevel.params.policy_params import DiscreteSACParams from tianshou.highlevel.params.policy_wrapper import ( - PolicyWrapperFactoryIntrinsicCuriosity, + AlgorithmWrapperFactoryIntrinsicCuriosity, ) @@ -94,8 +94,8 @@ def main( .with_epoch_stop_callback(AtariEpochStopCallback(task)) ) if icm_lr_scale > 0: - builder.with_policy_wrapper_factory( - PolicyWrapperFactoryIntrinsicCuriosity( + builder.with_algorithm_wrapper_factory( + AlgorithmWrapperFactoryIntrinsicCuriosity( feature_net_factory=IntermediateModuleFactoryAtariDQNFeatures(), hidden_sizes=hidden_sizes, lr=actor_lr, diff --git a/examples/vizdoom/vizdoom_ppo.py b/examples/vizdoom/vizdoom_ppo.py index 2c32562a1..3f76614d8 100644 --- a/examples/vizdoom/vizdoom_ppo.py +++ b/examples/vizdoom/vizdoom_ppo.py @@ -13,8 +13,9 @@ from tianshou.data import Collector, CollectStats, VectorReplayBuffer from tianshou.env.atari.atari_network import DQNet from tianshou.highlevel.logger import LoggerFactoryDefault -from tianshou.policy import PPO, ICMPolicy +from tianshou.policy import PPO from tianshou.policy.base import Algorithm +from tianshou.policy.modelbased.icm import ICMOnPolicyWrapper from tianshou.trainer import OnPolicyTrainer from tianshou.utils.net.common import ActorCritic from tianshou.utils.net.discrete import Actor, Critic, IntrinsicCuriosityModule @@ -177,11 +178,10 @@ def dist(logits: torch.Tensor) -> Categorical: device=args.device, ) icm_optim = torch.optim.Adam(icm_net.parameters(), lr=args.lr) - policy: ICMPolicy = ICMPolicy( # type: ignore[no-redef] - policy=policy, + policy = ICMOnPolicyWrapper( # type: ignore[no-redef] + wrapped_algorithm=policy, model=icm_net, optim=icm_optim, - action_space=env.action_space, lr_scale=args.icm_lr_scale, reward_scale=args.icm_reward_scale, forward_loss_weight=args.icm_forward_loss_weight, diff --git a/test/modelbased/test_dqn_icm.py 
b/test/modelbased/test_dqn_icm.py index f7ee3cac2..c84aa0ce7 100644 --- a/test/modelbased/test_dqn_icm.py +++ b/test/modelbased/test_dqn_icm.py @@ -13,10 +13,9 @@ VectorReplayBuffer, ) from tianshou.env import DummyVectorEnv -from tianshou.policy import DeepQLearning, ICMPolicy -from tianshou.policy.base import Algorithm -from tianshou.policy.modelfree.dqn import DQNTrainingStats -from tianshou.trainer import OffPolicyTrainer +from tianshou.policy import DeepQLearning, ICMOffPolicyWrapper +from tianshou.policy.modelfree.dqn import DQNPolicy +from tianshou.trainer import OffPolicyTrainingConfig from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import MLP, Net from tianshou.utils.net.discrete import IntrinsicCuriosityModule @@ -108,14 +107,19 @@ def test_dqn_icm(args: argparse.Namespace = get_args()) -> None: # dueling=(Q_param, V_param), ).to(args.device) optim = torch.optim.Adam(net.parameters(), lr=args.lr) - policy: DeepQLearning[DQNTrainingStats] = DeepQLearning( + policy = DQNPolicy( model=net, - optim=optim, action_space=env.action_space, + ) + algorithm = DeepQLearning( + policy=policy, + optim=optim, discount_factor=args.gamma, estimation_step=args.n_step, target_update_freq=args.target_update_freq, ) + + # ICM wrapper feature_dim = args.hidden_sizes[-1] obs_dim = space_info.observation_info.obs_dim feature_net = MLP( @@ -133,11 +137,10 @@ def test_dqn_icm(args: argparse.Namespace = get_args()) -> None: device=args.device, ).to(args.device) icm_optim = torch.optim.Adam(icm_net.parameters(), lr=args.lr) - policy: ICMPolicy = ICMPolicy( - policy=policy, + icm_algorithm: ICMOffPolicyWrapper = ICMOffPolicyWrapper( + wrapped_algorithm=algorithm, model=icm_net, optim=icm_optim, - action_space=env.action_space, lr_scale=args.lr_scale, reward_scale=args.reward_scale, forward_loss_weight=args.forward_loss_weight, @@ -153,18 +156,23 @@ def test_dqn_icm(args: argparse.Namespace = get_args()) -> None: ) else: buf = 
VectorReplayBuffer(args.buffer_size, buffer_num=len(train_envs)) + # collector - train_collector = Collector[CollectStats](policy, train_envs, buf, exploration_noise=True) - test_collector = Collector[CollectStats](policy, test_envs, exploration_noise=True) + train_collector = Collector[CollectStats]( + icm_algorithm, train_envs, buf, exploration_noise=True + ) + test_collector = Collector[CollectStats](icm_algorithm, test_envs, exploration_noise=True) + # policy.set_eps(1) train_collector.reset() train_collector.collect(n_step=args.batch_size * args.training_num) + # log log_path = os.path.join(args.logdir, args.task, "dqn_icm") writer = SummaryWriter(log_path) logger = TensorboardLogger(writer) - def save_best_fn(policy: Algorithm) -> None: + def save_best_fn(policy: icm_algorithm) -> None: torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth")) def stop_fn(mean_rewards: float) -> bool: @@ -183,21 +191,22 @@ def train_fn(epoch: int, env_step: int) -> None: def test_fn(epoch: int, env_step: int | None) -> None: policy.set_eps(args.eps_test) - # trainer - result = OffPolicyTrainer( - policy=policy, - train_collector=train_collector, - test_collector=test_collector, - max_epoch=args.epoch, - step_per_epoch=args.step_per_epoch, - step_per_collect=args.step_per_collect, - episode_per_test=args.test_num, - batch_size=args.batch_size, - update_per_step=args.update_per_step, - train_fn=train_fn, - test_fn=test_fn, - stop_fn=stop_fn, - save_best_fn=save_best_fn, - logger=logger, - ).run() + # train + result = icm_algorithm.run_training( + OffPolicyTrainingConfig( + train_collector=train_collector, + test_collector=test_collector, + max_epoch=args.epoch, + step_per_epoch=args.step_per_epoch, + step_per_collect=args.step_per_collect, + episode_per_test=args.test_num, + batch_size=args.batch_size, + update_per_step=args.update_per_step, + train_fn=train_fn, + test_fn=test_fn, + stop_fn=stop_fn, + save_best_fn=save_best_fn, + logger=logger, + ) + ) assert 
stop_fn(result.best_reward) diff --git a/test/modelbased/test_ppo_icm.py b/test/modelbased/test_ppo_icm.py index ee1c46660..1eaff9e1b 100644 --- a/test/modelbased/test_ppo_icm.py +++ b/test/modelbased/test_ppo_icm.py @@ -9,10 +9,11 @@ from tianshou.data import Collector, CollectStats, VectorReplayBuffer from tianshou.env import DummyVectorEnv -from tianshou.policy import PPO, ICMPolicy +from tianshou.policy import PPO from tianshou.policy.base import Algorithm -from tianshou.policy.modelfree.ppo import PPOTrainingStats -from tianshou.trainer import OnPolicyTrainer +from tianshou.policy.modelbased.icm import ICMOnPolicyWrapper +from tianshou.policy.modelfree.pg import ActorPolicy +from tianshou.trainer import OnPolicyTrainingConfig from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import MLP, ActorCritic, Net from tianshou.utils.net.discrete import Actor, Critic, IntrinsicCuriosityModule @@ -87,34 +88,42 @@ def test_ppo(args: argparse.Namespace = get_args()) -> None: args.task, env.spec.reward_threshold if env.spec else None, ) - # train_envs = gym.make(args.task) - # you can also use tianshou.env.SubprocVectorEnv + train_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.training_num)]) - # test_envs = gym.make(args.task) test_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.test_num)]) + # seed np.random.seed(args.seed) torch.manual_seed(args.seed) train_envs.seed(args.seed) test_envs.seed(args.seed) + # model net = Net(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device) actor = Actor(net, args.action_shape, device=args.device).to(args.device) critic = Critic(net, device=args.device).to(args.device) actor_critic = ActorCritic(actor, critic) + # orthogonal initialization for m in actor_critic.modules(): if isinstance(m, torch.nn.Linear): torch.nn.init.orthogonal_(m.weight) torch.nn.init.zeros_(m.bias) + + # base algorithm: PPO optim = 
torch.optim.Adam(actor_critic.parameters(), lr=args.lr) dist = torch.distributions.Categorical - policy: PPO[PPOTrainingStats] = PPO( + policy = ActorPolicy( actor=actor, - critic=critic, - optim=optim, dist_fn=dist, action_scaling=isinstance(env.action_space, Box), + action_space=env.action_space, + deterministic_eval=True, + ) + algorithm = PPO( + policy=policy, + critic=critic, + optim=optim, discount_factor=args.gamma, max_grad_norm=args.max_grad_norm, eps_clip=args.eps_clip, @@ -124,11 +133,11 @@ def test_ppo(args: argparse.Namespace = get_args()) -> None: reward_normalization=args.rew_norm, dual_clip=args.dual_clip, value_clip=args.value_clip, - action_space=env.action_space, - deterministic_eval=True, advantage_normalization=args.norm_adv, recompute_advantage=args.recompute_adv, ) + + # ICM wrapper feature_dim = args.hidden_sizes[-1] feature_net = MLP( space_info.observation_info.obs_dim, @@ -145,46 +154,48 @@ def test_ppo(args: argparse.Namespace = get_args()) -> None: device=args.device, ).to(args.device) icm_optim = torch.optim.Adam(icm_net.parameters(), lr=args.lr) - policy = ICMPolicy( - policy=policy, + icm_algorithm = ICMOnPolicyWrapper( + wrapped_algorithm=algorithm, model=icm_net, optim=icm_optim, - action_space=env.action_space, lr_scale=args.lr_scale, reward_scale=args.reward_scale, forward_loss_weight=args.forward_loss_weight, ) + # collector train_collector = Collector[CollectStats]( - policy, + icm_algorithm, train_envs, VectorReplayBuffer(args.buffer_size, len(train_envs)), ) - test_collector = Collector[CollectStats](policy, test_envs) + test_collector = Collector[CollectStats](icm_algorithm, test_envs) + # log log_path = os.path.join(args.logdir, args.task, "ppo_icm") writer = SummaryWriter(log_path) logger = TensorboardLogger(writer) - def save_best_fn(policy: Algorithm) -> None: - torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth")) + def save_best_fn(alg: Algorithm) -> None: + torch.save(alg.state_dict(), 
os.path.join(log_path, "policy.pth")) def stop_fn(mean_rewards: float) -> bool: return mean_rewards >= args.reward_threshold - # trainer - result = OnPolicyTrainer( - policy=policy, - train_collector=train_collector, - test_collector=test_collector, - max_epoch=args.epoch, - step_per_epoch=args.step_per_epoch, - repeat_per_collect=args.repeat_per_collect, - episode_per_test=args.test_num, - batch_size=args.batch_size, - step_per_collect=args.step_per_collect, - stop_fn=stop_fn, - save_best_fn=save_best_fn, - logger=logger, - ).run() + # train + result = icm_algorithm.run_training( + OnPolicyTrainingConfig( + train_collector=train_collector, + test_collector=test_collector, + max_epoch=args.epoch, + step_per_epoch=args.step_per_epoch, + repeat_per_collect=args.repeat_per_collect, + episode_per_test=args.test_num, + batch_size=args.batch_size, + step_per_collect=args.step_per_collect, + stop_fn=stop_fn, + save_best_fn=save_best_fn, + logger=logger, + ) + ) assert stop_fn(result.best_reward) diff --git a/tianshou/highlevel/algorithm.py b/tianshou/highlevel/algorithm.py index b52fb0208..a97b9f195 100644 --- a/tianshou/highlevel/algorithm.py +++ b/tianshou/highlevel/algorithm.py @@ -41,7 +41,7 @@ TD3Params, TRPOParams, ) -from tianshou.highlevel.params.policy_wrapper import PolicyWrapperFactory +from tianshou.highlevel.params.policy_wrapper import AlgorithmWrapperFactory from tianshou.highlevel.persistence import PolicyPersistence from tianshou.highlevel.trainer import TrainerCallbacks, TrainingContext from tianshou.highlevel.world import World @@ -99,7 +99,7 @@ class AlgorithmFactory(ABC, ToStringMixin): def __init__(self, sampling_config: SamplingConfig, optim_factory: OptimizerFactory): self.sampling_config = sampling_config self.optim_factory = optim_factory - self.policy_wrapper_factory: PolicyWrapperFactory | None = None + self.algorithm_wrapper_factory: AlgorithmWrapperFactory | None = None self.trainer_callbacks: TrainerCallbacks = TrainerCallbacks() def 
create_train_test_collector( @@ -146,9 +146,9 @@ def create_train_test_collector( def set_policy_wrapper_factory( self, - policy_wrapper_factory: PolicyWrapperFactory | None, + policy_wrapper_factory: AlgorithmWrapperFactory | None, ) -> None: - self.policy_wrapper_factory = policy_wrapper_factory + self.algorithm_wrapper_factory = policy_wrapper_factory def set_trainer_callbacks(self, callbacks: TrainerCallbacks) -> None: self.trainer_callbacks = callbacks @@ -166,8 +166,8 @@ def _create_algorithm(self, envs: Environments, device: TDevice) -> Algorithm: def create_algorithm(self, envs: Environments, device: TDevice) -> Algorithm: policy = self._create_algorithm(envs, device) - if self.policy_wrapper_factory is not None: - policy = self.policy_wrapper_factory.create_wrapped_policy( + if self.algorithm_wrapper_factory is not None: + policy = self.algorithm_wrapper_factory.create_wrapped_algorithm( policy, envs, self.optim_factory, diff --git a/tianshou/highlevel/experiment.py b/tianshou/highlevel/experiment.py index d6f86eb95..337c16583 100644 --- a/tianshou/highlevel/experiment.py +++ b/tianshou/highlevel/experiment.py @@ -96,7 +96,7 @@ TD3Params, TRPOParams, ) -from tianshou.highlevel.params.policy_wrapper import PolicyWrapperFactory +from tianshou.highlevel.params.policy_wrapper import AlgorithmWrapperFactory from tianshou.highlevel.persistence import ( PersistenceGroup, PolicyPersistence, @@ -521,7 +521,7 @@ def __init__( self._sampling_config = sampling_config self._logger_factory: LoggerFactory | None = None self._optim_factory: OptimizerFactory | None = None - self._policy_wrapper_factory: PolicyWrapperFactory | None = None + self._algorithm_wrapper_factory: AlgorithmWrapperFactory | None = None self._trainer_callbacks: TrainerCallbacks = TrainerCallbacks() self._name: str = self.__class__.__name__.replace("Builder", "") + "_" + datetime_tag() @@ -555,13 +555,15 @@ def with_logger_factory(self, logger_factory: LoggerFactory) -> Self: self._logger_factory = 
logger_factory return self - def with_policy_wrapper_factory(self, policy_wrapper_factory: PolicyWrapperFactory) -> Self: - """Allows to define a wrapper around the policy that is created, extending the original policy. + def with_algorithm_wrapper_factory( + self, algorithm_wrapper_factory: AlgorithmWrapperFactory + ) -> Self: + """Allows to define a wrapper around the algorithm that is created, extending the original algorithm. - :param policy_wrapper_factory: the factory for the wrapper + :param algorithm_wrapper_factory: the factory for the wrapper :return: the builder """ - self._policy_wrapper_factory = policy_wrapper_factory + self._algorithm_wrapper_factory = algorithm_wrapper_factory return self def with_optim_factory(self, optim_factory: OptimizerFactory) -> Self: @@ -652,8 +654,8 @@ def build(self) -> Experiment: """ algorithm_factory = self._create_algorithm_factory() algorithm_factory.set_trainer_callbacks(self._trainer_callbacks) - if self._policy_wrapper_factory: - algorithm_factory.set_policy_wrapper_factory(self._policy_wrapper_factory) + if self._algorithm_wrapper_factory: + algorithm_factory.set_policy_wrapper_factory(self._algorithm_wrapper_factory) experiment: Experiment = Experiment( config=self._config, env_factory=self._env_factory, diff --git a/tianshou/highlevel/params/policy_wrapper.py b/tianshou/highlevel/params/policy_wrapper.py index ab3994151..51a4438fc 100644 --- a/tianshou/highlevel/params/policy_wrapper.py +++ b/tianshou/highlevel/params/policy_wrapper.py @@ -8,26 +8,28 @@ from tianshou.highlevel.module.core import TDevice from tianshou.highlevel.module.intermediate import IntermediateModuleFactory from tianshou.highlevel.optim import OptimizerFactory -from tianshou.policy import Algorithm, ICMPolicy +from tianshou.policy import Algorithm, ICMOffPolicyWrapper +from tianshou.policy.base import OffPolicyAlgorithm, OnPolicyAlgorithm +from tianshou.policy.modelbased.icm import ICMOnPolicyWrapper from tianshou.utils.net.discrete import 
IntrinsicCuriosityModule -TPolicyOut = TypeVar("TPolicyOut", bound=Algorithm) +TAlgorithmOut = TypeVar("TAlgorithmOut", bound=Algorithm) -class PolicyWrapperFactory(Generic[TPolicyOut], ToStringMixin, ABC): +class AlgorithmWrapperFactory(Generic[TAlgorithmOut], ToStringMixin, ABC): @abstractmethod - def create_wrapped_policy( + def create_wrapped_algorithm( self, policy: Algorithm, envs: Environments, optim_factory: OptimizerFactory, device: TDevice, - ) -> TPolicyOut: + ) -> TAlgorithmOut: pass -class PolicyWrapperFactoryIntrinsicCuriosity( - PolicyWrapperFactory[ICMPolicy], +class AlgorithmWrapperFactoryIntrinsicCuriosity( + AlgorithmWrapperFactory[ICMOffPolicyWrapper | ICMOnPolicyWrapper], ): def __init__( self, @@ -46,13 +48,13 @@ def __init__( self.reward_scale = reward_scale self.forward_loss_weight = forward_loss_weight - def create_wrapped_policy( + def create_wrapped_algorithm( self, - policy: Algorithm, + algorithm: Algorithm, envs: Environments, optim_factory: OptimizerFactory, device: TDevice, - ) -> ICMPolicy: + ) -> ICMOffPolicyWrapper: feature_net = self.feature_net_factory.create_intermediate_module(envs, device) action_dim = envs.get_action_shape() if not isinstance(action_dim, int): @@ -66,11 +68,17 @@ def create_wrapped_policy( device=device, ) icm_optim = optim_factory.create_optimizer(icm_net, lr=self.lr) - return ICMPolicy( - policy=policy, + cls: type[ICMOffPolicyWrapper] | type[ICMOnPolicyWrapper] + if isinstance(algorithm, OffPolicyAlgorithm): + cls = ICMOffPolicyWrapper + elif isinstance(algorithm, OnPolicyAlgorithm): + cls = ICMOnPolicyWrapper + else: + raise ValueError(f"{algorithm} is not supported by ICM") + return cls( + wrapped_algorithm=algorithm, model=icm_net, optim=icm_optim, - action_space=envs.get_action_space(), lr_scale=self.lr_scale, reward_scale=self.reward_scale, forward_loss_weight=self.forward_loss_weight, diff --git a/tianshou/policy/__init__.py b/tianshou/policy/__init__.py index 3f70daa81..bfc71eeaa 100644 --- 
a/tianshou/policy/__init__.py +++ b/tianshou/policy/__init__.py @@ -30,7 +30,7 @@ from tianshou.policy.imitation.discrete_crr import DiscreteCRR from tianshou.policy.imitation.gail import GAIL from tianshou.policy.modelbased.psrl import PSRLPolicy -from tianshou.policy.modelbased.icm import ICMPolicy +from tianshou.policy.modelbased.icm import ICMOffPolicyWrapper from tianshou.policy.multiagent.mapolicy import MultiAgentPolicyManager __all__ = [ @@ -62,7 +62,7 @@ "DiscreteCRR", "GAIL", "PSRLPolicy", - "ICMPolicy", + "ICMOffPolicyWrapper", "MultiAgentPolicyManager", "TrainingStats", ] diff --git a/tianshou/policy/base.py b/tianshou/policy/base.py index 2b1d55379..83c0d2d13 100644 --- a/tianshou/policy/base.py +++ b/tianshou/policy/base.py @@ -832,6 +832,91 @@ def create_trainer(self, config: "OfflineTrainingConfig") -> "OfflineTrainer": return OfflineTrainer(self, config) +TWrappedAlgorthmTrainingStats = TypeVar("TWrappedAlgorthmTrainingStats", bound=TrainingStats) + + +class OnPolicyWrapperAlgorithm( + OnPolicyAlgorithm[TPolicy, TTrainingStats], + Generic[TPolicy, TTrainingStats, TWrappedAlgorthmTrainingStats], + ABC, +): + def __init__( + self, + wrapped_algorithm: OnPolicyAlgorithm[TPolicy, TWrappedAlgorthmTrainingStats], + lr_scheduler: TLearningRateScheduler | None = None, + ): + super().__init__(policy=wrapped_algorithm.policy, lr_scheduler=lr_scheduler) + self.wrapped_algorithm = wrapped_algorithm + + def process_fn( + self, + batch: RolloutBatchProtocol, + buffer: ReplayBuffer, + indices: np.ndarray, + ) -> RolloutBatchProtocol: + """Performs the pre-processing as defined by the wrapped algorithm.""" + return self.wrapped_algorithm.process_fn(batch, buffer, indices) + + def post_process_fn( + self, + batch: BatchProtocol, + buffer: ReplayBuffer, + indices: np.ndarray, + ) -> None: + """Performs the batch post-processing as defined by the wrapped algorithm.""" + self.wrapped_algorithm.post_process_fn(batch, buffer, indices) + + def _update_with_batch( + self, 
+ batch: RolloutBatchProtocol, + *args: Any, + **kwargs: Any, + ) -> TWrappedAlgorthmTrainingStats: + """Performs the update as defined by the wrapped algorithm.""" + return self.wrapped_algorithm._update_with_batch(batch, **kwargs) + + +class OffPolicyWrapperAlgorithm( + OffPolicyAlgorithm[TPolicy, TTrainingStats], + Generic[TPolicy, TTrainingStats, TWrappedAlgorthmTrainingStats], + ABC, +): + def __init__( + self, + wrapped_algorithm: OffPolicyAlgorithm[TPolicy, TWrappedAlgorthmTrainingStats], + lr_scheduler: TLearningRateScheduler | None = None, + ): + super().__init__(policy=wrapped_algorithm.policy, lr_scheduler=lr_scheduler) + self.wrapped_algorithm = wrapped_algorithm + + def process_fn( + self, + batch: RolloutBatchProtocol, + buffer: ReplayBuffer, + indices: np.ndarray, + ) -> RolloutBatchProtocol: + """Performs the pre-processing as defined by the wrapped algorithm.""" + return self.wrapped_algorithm.process_fn(batch, buffer, indices) + + def post_process_fn( + self, + batch: BatchProtocol, + buffer: ReplayBuffer, + indices: np.ndarray, + ) -> None: + """Performs the batch post-processing as defined by the wrapped algorithm.""" + self.wrapped_algorithm.post_process_fn(batch, buffer, indices) + + def _update_with_batch( + self, + batch: RolloutBatchProtocol, + *args: Any, + **kwargs: Any, + ) -> TWrappedAlgorthmTrainingStats: + """Performs the update as defined by the wrapped algorithm.""" + return self.wrapped_algorithm._update_with_batch(batch, **kwargs) + + class RandomActionPolicy(Policy): def __init__( self, diff --git a/tianshou/policy/modelbased/icm.py b/tianshou/policy/modelbased/icm.py index c3611bd31..7de507396 100644 --- a/tianshou/policy/modelbased/icm.py +++ b/tianshou/policy/modelbased/icm.py @@ -1,18 +1,22 @@ -from typing import Any, Literal, Self, TypeVar +from typing import Any -import gymnasium as gym import numpy as np import torch import torch.nn.functional as F from tianshou.data import Batch, ReplayBuffer, to_numpy, to_torch from 
tianshou.data.batch import BatchProtocol -from tianshou.data.types import ActBatchProtocol, ObsBatchProtocol, RolloutBatchProtocol -from tianshou.policy import Algorithm +from tianshou.data.types import RolloutBatchProtocol from tianshou.policy.base import ( + OffPolicyAlgorithm, + OffPolicyWrapperAlgorithm, + OnPolicyAlgorithm, + OnPolicyWrapperAlgorithm, TLearningRateScheduler, + TPolicy, TrainingStats, TrainingStatsWrapper, + TTrainingStats, ) from tianshou.utils.net.discrete import IntrinsicCuriosityModule @@ -32,95 +36,104 @@ def __init__( super().__init__(wrapped_stats) -class ICMPolicy(Algorithm): - """Implementation of Intrinsic Curiosity Module. arXiv:1705.05363. - - :param policy: a base policy to add ICM to. - :param model: the ICM model. - :param optim: a torch.optim for optimizing the model. - :param lr_scale: the scaling factor for ICM learning. - :param forward_loss_weight: the weight for forward model loss. - :param observation_space: Env's observation space. - :param action_scaling: if True, scale the action from [-1, 1] to the range - of action_space. Only used if the action_space is continuous. - :param action_bound_method: method to bound action to range [-1, 1]. - Only used if the action_space is continuous. - :param lr_scheduler: if not None, will be called in `policy.update()`. - - .. seealso:: - - Please refer to :class:`~tianshou.policy.BasePolicy` for more detailed - explanation. - """ +class _ICMMixin: + """Implementation of the Intrinsic Curiosity Module (ICM) algorithm. 
arXiv:1705.05363.""" def __init__( self, *, - policy: Algorithm, # [TTrainingStats] model: IntrinsicCuriosityModule, optim: torch.optim.Optimizer, lr_scale: float, reward_scale: float, forward_loss_weight: float, - action_space: gym.Space, - observation_space: gym.Space | None = None, - action_scaling: bool = False, - action_bound_method: Literal["clip", "tanh"] | None = "clip", - lr_scheduler: TLearningRateScheduler | None = None, ) -> None: - super().__init__( - action_space=action_space, - observation_space=observation_space, - action_scaling=action_scaling, - action_bound_method=action_bound_method, - lr_scheduler=lr_scheduler, - ) - self.policy = policy + """ + :param model: the ICM model. + :param optim: the optimizer for parameter `model`. + :param lr_scale: the scaling factor for ICM learning. + :param forward_loss_weight: the weight for forward model loss. + """ self.model = model self.optim = optim self.lr_scale = lr_scale self.reward_scale = reward_scale self.forward_loss_weight = forward_loss_weight - def train(self, mode: bool = True) -> Self: - """Set the module in training mode.""" - self.policy.train(mode) - self.training = mode - self.model.train(mode) - return self + def _icm_preprocess_batch( + self, + batch: RolloutBatchProtocol, + ) -> None: + mse_loss, act_hat = self.model(batch.obs, batch.act, batch.obs_next) + batch.policy = Batch(orig_rew=batch.rew, act_hat=act_hat, mse_loss=mse_loss) + batch.rew += to_numpy(mse_loss * self.reward_scale) - def forward( + @staticmethod + def _icm_postprocess_batch(batch: BatchProtocol) -> None: + # restore original reward + batch.rew = batch.policy.orig_rew + + def _icm_update( self, - batch: ObsBatchProtocol, - state: dict | BatchProtocol | np.ndarray | None = None, - **kwargs: Any, - ) -> ActBatchProtocol: - """Compute action over the given batch data by inner policy. 
+ batch: RolloutBatchProtocol, + original_stats: TrainingStats, + ) -> ICMTrainingStats: + self.optim.zero_grad() + act_hat = batch.policy.act_hat + act = to_torch(batch.act, dtype=torch.long, device=act_hat.device) + inverse_loss = F.cross_entropy(act_hat, act).mean() + forward_loss = batch.policy.mse_loss.mean() + loss = ( + (1 - self.forward_loss_weight) * inverse_loss + self.forward_loss_weight * forward_loss + ) * self.lr_scale + loss.backward() + self.optim.step() - .. seealso:: + return ICMTrainingStats( + original_stats, + icm_loss=loss.item(), + icm_forward_loss=forward_loss.item(), + icm_inverse_loss=inverse_loss.item(), + ) - Please refer to :meth:`~tianshou.policy.BasePolicy.forward` for - more detailed explanation. - """ - return self.policy.forward(batch, state, **kwargs) - _TArrOrActBatch = TypeVar("_TArrOrActBatch", bound="np.ndarray | ActBatchProtocol") +class ICMOffPolicyWrapper( + OffPolicyWrapperAlgorithm[TPolicy, ICMTrainingStats, TTrainingStats], _ICMMixin +): + """Implementation of the Intrinsic Curiosity Module (ICM) algorithm for off-policy learning. arXiv:1705.05363.""" - # TODO move to policy - # @override - def add_exploration_noise( + def __init__( self, - act: _TArrOrActBatch, - batch: ObsBatchProtocol, - ) -> _TArrOrActBatch: - return self.policy.add_exploration_noise(act, batch) - - def set_eps(self, eps: float) -> None: - """Set the eps for epsilon-greedy exploration.""" - if hasattr(self.policy, "set_eps"): - self.policy.set_eps(eps) - else: - raise NotImplementedError + *, + wrapped_algorithm: OffPolicyAlgorithm[TPolicy, TTrainingStats], + model: IntrinsicCuriosityModule, + optim: torch.optim.Optimizer, + lr_scale: float, + reward_scale: float, + forward_loss_weight: float, + lr_scheduler: TLearningRateScheduler | None = None, + ) -> None: + """ + :param wrapped_algorithm: the base algorithm to which we want to add the ICM. + :param model: the ICM model. + :param optim: the optimizer for parameter `model`. 
+ :param lr_scale: the scaling factor for ICM learning. + :param forward_loss_weight: the weight for forward model loss. + :param lr_scheduler: if not None, will be called in `policy.update()`. + """ + OffPolicyWrapperAlgorithm.__init__( + self, + wrapped_algorithm=wrapped_algorithm, + lr_scheduler=lr_scheduler, + ) + _ICMMixin.__init__( + self, + model=model, + optim=optim, + lr_scale=lr_scale, + reward_scale=reward_scale, + forward_loss_weight=forward_loss_weight, + ) def process_fn( self, @@ -128,14 +141,8 @@ def process_fn( buffer: ReplayBuffer, indices: np.ndarray, ) -> RolloutBatchProtocol: - """Pre-process the data from the provided replay buffer. - - Used in :meth:`update`. Check out :ref:`process_fn` for more information. - """ - mse_loss, act_hat = self.model(batch.obs, batch.act, batch.obs_next) - batch.policy = Batch(orig_rew=batch.rew, act_hat=act_hat, mse_loss=mse_loss) - batch.rew += to_numpy(mse_loss * self.reward_scale) - return self.policy.process_fn(batch, buffer, indices) + self._icm_preprocess_batch(batch) + return super().process_fn(batch, buffer, indices) def post_process_fn( self, @@ -143,13 +150,8 @@ def post_process_fn( buffer: ReplayBuffer, indices: np.ndarray, ) -> None: - """Post-process the data from the provided replay buffer. - - Typical usage is to update the sampling weight in prioritized - experience replay. Used in :meth:`update`. 
- """ - self.policy.post_process_fn(batch, buffer, indices) - batch.rew = batch.policy.orig_rew # restore original reward + super().post_process_fn(batch, buffer, indices) + self._icm_postprocess_batch(batch) def _update_with_batch( self, @@ -157,21 +159,71 @@ def _update_with_batch( *args: Any, **kwargs: Any, ) -> ICMTrainingStats: - training_stat = self.policy._update_with_batch(batch, **kwargs) - self.optim.zero_grad() - act_hat = batch.policy.act_hat - act = to_torch(batch.act, dtype=torch.long, device=act_hat.device) - inverse_loss = F.cross_entropy(act_hat, act).mean() - forward_loss = batch.policy.mse_loss.mean() - loss = ( - (1 - self.forward_loss_weight) * inverse_loss + self.forward_loss_weight * forward_loss - ) * self.lr_scale - loss.backward() - self.optim.step() + wrapped_stats = super()._update_with_batch(batch, *args, **kwargs) + return self._icm_update(batch, wrapped_stats) - return ICMTrainingStats( - training_stat, - icm_loss=loss.item(), - icm_forward_loss=forward_loss.item(), - icm_inverse_loss=inverse_loss.item(), + +class ICMOnPolicyWrapper( + OnPolicyWrapperAlgorithm[TPolicy, ICMTrainingStats, TTrainingStats], _ICMMixin +): + """Implementation of the Intrinsic Curiosity Module (ICM) algorithm for on-policy learning. arXiv:1705.05363.""" + + def __init__( + self, + *, + wrapped_algorithm: OnPolicyAlgorithm[TPolicy, TTrainingStats], + model: IntrinsicCuriosityModule, + optim: torch.optim.Optimizer, + lr_scale: float, + reward_scale: float, + forward_loss_weight: float, + lr_scheduler: TLearningRateScheduler | None = None, + ) -> None: + """ + :param wrapped_algorithm: the base algorithm to which we want to add the ICM. + :param model: the ICM model. + :param optim: the optimizer for parameter `model`. + :param lr_scale: the scaling factor for ICM learning. + :param forward_loss_weight: the weight for forward model loss. + :param lr_scheduler: if not None, will be called in `policy.update()`. 
+ """ + OnPolicyWrapperAlgorithm.__init__( + self, + wrapped_algorithm=wrapped_algorithm, + lr_scheduler=lr_scheduler, ) + _ICMMixin.__init__( + self, + model=model, + optim=optim, + lr_scale=lr_scale, + reward_scale=reward_scale, + forward_loss_weight=forward_loss_weight, + ) + + def process_fn( + self, + batch: RolloutBatchProtocol, + buffer: ReplayBuffer, + indices: np.ndarray, + ) -> RolloutBatchProtocol: + self._icm_preprocess_batch(batch) + return super().process_fn(batch, buffer, indices) + + def post_process_fn( + self, + batch: BatchProtocol, + buffer: ReplayBuffer, + indices: np.ndarray, + ) -> None: + super().post_process_fn(batch, buffer, indices) + self._icm_postprocess_batch(batch) + + def _update_with_batch( + self, + batch: RolloutBatchProtocol, + *args: Any, + **kwargs: Any, + ) -> ICMTrainingStats: + wrapped_stats = super()._update_with_batch(batch, *args, **kwargs) + return self._icm_update(batch, wrapped_stats) From ff5f889f34d6375bb7bc6d546b023bdc9043e3f1 Mon Sep 17 00:00:00 2001 From: Dominik Jain Date: Wed, 12 Mar 2025 21:19:50 +0100 Subject: [PATCH 053/230] v2: Adapt PSRL and test_psrl --- test/modelbased/test_psrl.py | 46 ++++++------ tianshou/policy/__init__.py | 37 +--------- tianshou/policy/modelbased/psrl.py | 109 ++++++++++++++++------------- 3 files changed, 87 insertions(+), 105 deletions(-) diff --git a/test/modelbased/test_psrl.py b/test/modelbased/test_psrl.py index d94f1cd1f..a150df52a 100644 --- a/test/modelbased/test_psrl.py +++ b/test/modelbased/test_psrl.py @@ -7,8 +7,9 @@ from torch.utils.tensorboard import SummaryWriter from tianshou.data import Collector, CollectStats, VectorReplayBuffer -from tianshou.policy import PSRLPolicy -from tianshou.trainer import OnPolicyTrainer +from tianshou.policy import PSRL +from tianshou.policy.modelbased.psrl import PSRLPolicy +from tianshou.trainer import OnPolicyTrainingConfig from tianshou.utils import LazyLogger, TensorboardLogger, WandbLogger try: @@ -67,24 +68,27 @@ def 
test_psrl(args: argparse.Namespace = get_args()) -> None: trans_count_prior = np.ones((n_state, n_action, n_state)) rew_mean_prior = np.full((n_state, n_action), args.rew_mean_prior) rew_std_prior = np.full((n_state, n_action), args.rew_std_prior) - policy: PSRLPolicy = PSRLPolicy( + policy = PSRLPolicy( trans_count_prior=trans_count_prior, rew_mean_prior=rew_mean_prior, rew_std_prior=rew_std_prior, action_space=env.action_space, discount_factor=args.gamma, epsilon=args.eps, + ) + algorithm: PSRL = PSRL( + policy=policy, add_done_loop=args.add_done_loop, ) # collector train_collector = Collector[CollectStats]( - policy, + algorithm, train_envs, VectorReplayBuffer(args.buffer_size, len(train_envs)), exploration_noise=True, ) train_collector.reset() - test_collector = Collector[CollectStats](policy, test_envs) + test_collector = Collector[CollectStats](algorithm, test_envs) test_collector.reset() # Logger log_path = os.path.join(args.logdir, args.task, "psrl") @@ -103,19 +107,21 @@ def stop_fn(mean_rewards: float) -> bool: return mean_rewards >= args.reward_threshold train_collector.collect(n_step=args.buffer_size, random=True) - # trainer, test it without logger - result = OnPolicyTrainer( - policy=policy, - train_collector=train_collector, - test_collector=test_collector, - max_epoch=args.epoch, - step_per_epoch=args.step_per_epoch, - repeat_per_collect=1, - episode_per_test=args.test_num, - batch_size=0, - episode_per_collect=args.episode_per_collect, - stop_fn=stop_fn, - logger=logger, - test_in_train=False, - ).run() + # train (test it without logger) + result = algorithm.run_training( + OnPolicyTrainingConfig( + train_collector=train_collector, + test_collector=test_collector, + max_epoch=args.epoch, + step_per_epoch=args.step_per_epoch, + repeat_per_collect=1, + episode_per_test=args.test_num, + batch_size=0, + episode_per_collect=args.episode_per_collect, + step_per_collect=None, + stop_fn=stop_fn, + logger=logger, + test_in_train=False, + ) + ) assert 
result.best_reward >= args.reward_threshold diff --git a/tianshou/policy/__init__.py b/tianshou/policy/__init__.py index bfc71eeaa..0cef14012 100644 --- a/tianshou/policy/__init__.py +++ b/tianshou/policy/__init__.py @@ -29,40 +29,7 @@ from tianshou.policy.imitation.discrete_cql import DiscreteCQLPolicy from tianshou.policy.imitation.discrete_crr import DiscreteCRR from tianshou.policy.imitation.gail import GAIL -from tianshou.policy.modelbased.psrl import PSRLPolicy +from tianshou.policy.modelbased.psrl import PSRL from tianshou.policy.modelbased.icm import ICMOffPolicyWrapper +from tianshou.policy.modelbased.icm import ICMOnPolicyWrapper from tianshou.policy.multiagent.mapolicy import MultiAgentPolicyManager - -__all__ = [ - "Algorithm", - "MARLRandomPolicy", - "DeepQLearning", - "BranchingDuelingQNetwork", - "C51", - "RainbowDQN", - "QRDQN", - "IQN", - "FQF", - "Reinforce", - "A2C", - "NPG", - "DDPG", - "PPO", - "TRPO", - "TD3", - "SAC", - "REDQ", - "DiscreteSAC", - "ImitationLearning", - "BCQ", - "CQL", - "TD3BCPolicy", - "DiscreteBCQ", - "DiscreteCQLPolicy", - "DiscreteCRR", - "GAIL", - "PSRLPolicy", - "ICMOffPolicyWrapper", - "MultiAgentPolicyManager", - "TrainingStats", -] diff --git a/tianshou/policy/modelbased/psrl.py b/tianshou/policy/modelbased/psrl.py index c6f433d1f..987a41f56 100644 --- a/tianshou/policy/modelbased/psrl.py +++ b/tianshou/policy/modelbased/psrl.py @@ -8,8 +8,12 @@ from tianshou.data import Batch from tianshou.data.batch import BatchProtocol from tianshou.data.types import ActBatchProtocol, ObsBatchProtocol, RolloutBatchProtocol -from tianshou.policy import Algorithm -from tianshou.policy.base import TLearningRateScheduler, TrainingStats +from tianshou.policy.base import ( + OnPolicyAlgorithm, + Policy, + TLearningRateScheduler, + TrainingStats, +) @dataclass(kw_only=True) @@ -22,19 +26,7 @@ class PSRLTrainingStats(TrainingStats): class PSRLModel: - """Implementation of Posterior Sampling Reinforcement Learning Model. 
- - :param trans_count_prior: dirichlet prior (alphas), with shape - (n_state, n_action, n_state). - :param rew_mean_prior: means of the normal priors of rewards, - with shape (n_state, n_action). - :param rew_std_prior: standard deviations of the normal priors - of rewards, with shape (n_state, n_action). - :param discount_factor: in [0, 1]. - :param epsilon: for precision control in value iteration. - :param lr_scheduler: a learning rate scheduler that adjusts the learning rate in - optimizer in each policy.update(). Default to None (no lr_scheduler). - """ + """Implementation of Posterior Sampling Reinforcement Learning Model.""" def __init__( self, @@ -44,6 +36,16 @@ def __init__( discount_factor: float, epsilon: float, ) -> None: + """ + :param trans_count_prior: dirichlet prior (alphas), with shape + (n_state, n_action, n_state). + :param rew_mean_prior: means of the normal priors of rewards, + with shape (n_state, n_action). + :param rew_std_prior: standard deviations of the normal priors + of rewards, with shape (n_state, n_action). + :param discount_factor: in [0, 1]. + :param epsilon: for precision control in value iteration. + """ self.trans_count = trans_count_prior self.n_state, self.n_action = rew_mean_prior.shape self.rew_mean = rew_mean_prior @@ -150,32 +152,7 @@ def __call__( return self.policy[obs] -class PSRLPolicy(Algorithm): - """Implementation of Posterior Sampling Reinforcement Learning. - - Reference: Strens M. A Bayesian framework for reinforcement learning [C] - //ICML. 2000, 2000: 943-950. - - :param trans_count_prior: dirichlet prior (alphas), with shape - (n_state, n_action, n_state). - :param rew_mean_prior: means of the normal priors of rewards, - with shape (n_state, n_action). - :param rew_std_prior: standard deviations of the normal priors - of rewards, with shape (n_state, n_action). - :param action_space: Env's action_space. - :param discount_factor: in [0, 1]. - :param epsilon: for precision control in value iteration. 
- :param add_done_loop: whether to add an extra self-loop for the - terminal state in MDP. Default to False. - :param observation_space: Env's observation space. - :param lr_scheduler: if not None, will be called in `policy.update()`. - - .. seealso:: - - Please refer to :class:`~tianshou.policy.BasePolicy` for more detailed - explanation. - """ - +class PSRLPolicy(Policy): def __init__( self, *, @@ -185,18 +162,25 @@ def __init__( action_space: gym.spaces.Discrete, discount_factor: float = 0.99, epsilon: float = 0.01, - add_done_loop: bool = False, observation_space: gym.Space | None = None, - lr_scheduler: TLearningRateScheduler | None = None, ) -> None: + """ + :param trans_count_prior: dirichlet prior (alphas), with shape + (n_state, n_action, n_state). + :param rew_mean_prior: means of the normal priors of rewards, + with shape (n_state, n_action). + :param rew_std_prior: standard deviations of the normal priors + of rewards, with shape (n_state, n_action). + :param action_space: Env's action_space. + :param epsilon: for precision control in value iteration. + :param observation_space: Env's observation space. + """ super().__init__( action_space=action_space, observation_space=observation_space, action_scaling=False, action_bound_method=None, - lr_scheduler=lr_scheduler, ) - assert 0.0 <= discount_factor <= 1.0, "discount factor should be in [0, 1]" self.model = PSRLModel( trans_count_prior, rew_mean_prior, @@ -204,7 +188,6 @@ def __init__( discount_factor, epsilon, ) - self._add_done_loop = add_done_loop def forward( self, @@ -227,13 +210,39 @@ def forward( act = self.model(batch.obs, state=state, info=batch.info) return cast(ActBatchProtocol, Batch(act=act)) + +class PSRL(OnPolicyAlgorithm[PSRLPolicy, TPSRLTrainingStats]): + """Implementation of Posterior Sampling Reinforcement Learning (PSRL). + + Reference: Strens M., A Bayesian Framework for Reinforcement Learning, ICML, 2000. 
+ """ + + def __init__( + self, + *, + policy: PSRLPolicy, + add_done_loop: bool = False, + lr_scheduler: TLearningRateScheduler | None = None, + ) -> None: + """ + :param policy: the policy + :param add_done_loop: whether to add an extra self-loop for the + terminal state in MDP. Default to False. + :param lr_scheduler: if not None, will be called in `policy.update()`. + """ + super().__init__( + policy=policy, + lr_scheduler=lr_scheduler, + ) + self._add_done_loop = add_done_loop + def _update_with_batch( self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any, ) -> TPSRLTrainingStats: - n_s, n_a = self.model.n_state, self.model.n_action + n_s, n_a = self.policy.model.n_state, self.policy.model.n_action trans_count = np.zeros((n_s, n_a, n_s)) rew_sum = np.zeros((n_s, n_a)) rew_square_sum = np.zeros((n_s, n_a)) @@ -250,9 +259,9 @@ def _update_with_batch( # special operation for terminal states: add a self-loop trans_count[obs_next, :, obs_next] += 1 rew_count[obs_next, :] += 1 - self.model.observe(trans_count, rew_sum, rew_square_sum, rew_count) + self.policy.model.observe(trans_count, rew_sum, rew_square_sum, rew_count) return PSRLTrainingStats( # type: ignore[return-value] - psrl_rew_mean=float(self.model.rew_mean.mean()), - psrl_rew_std=float(self.model.rew_std.mean()), + psrl_rew_mean=float(self.policy.model.rew_mean.mean()), + psrl_rew_std=float(self.policy.model.rew_std.mean()), ) From e4b1ba711898fe05038357066fc96916dcc5865c Mon Sep 17 00:00:00 2001 From: Dominik Jain Date: Fri, 14 Mar 2025 12:45:37 +0100 Subject: [PATCH 054/230] v2: Rename algorithms back to acronym-based name (DQN, BDQN) --- CHANGELOG.md | 2 +- README.md | 8 +++++--- examples/atari/atari_dqn.py | 10 ++++------ examples/box2d/acrobot_dualdqn.py | 4 ++-- examples/box2d/bipedal_bdq.py | 4 ++-- examples/box2d/lunarlander_dqn.py | 4 ++-- examples/discrete/discrete_dqn.py | 2 +- test/discrete/test_bdqn.py | 8 ++++---- test/discrete/test_dqn.py | 4 ++-- test/discrete/test_drqn.py | 4 ++-- 
test/modelbased/test_dqn_icm.py | 4 ++-- test/pettingzoo/pistonball.py | 4 ++-- test/pettingzoo/pistonball_continuous.py | 6 +++--- test/pettingzoo/tic_tac_toe.py | 4 ++-- tianshou/highlevel/algorithm.py | 8 ++++---- tianshou/highlevel/trainer.py | 8 ++++---- tianshou/policy/__init__.py | 4 ++-- tianshou/policy/modelfree/bdqn.py | 10 +++++----- tianshou/policy/modelfree/c51.py | 4 ++-- tianshou/policy/modelfree/dqn.py | 2 +- tianshou/policy/modelfree/qrdqn.py | 6 ++---- 21 files changed, 54 insertions(+), 56 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index fb38cebb8..867e21aed 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -51,7 +51,7 @@ * The abstraction `BasePolicy` is thus replaced by `Algorithm` and `Policy`. Migration information (`BasePolicy` -> `Algorithm`): * `PGPolicy` -> `Reinforce` - * `DQNPolicy` -> `DeepQLearning` + * `DQNPolicy` -> `DQN` * `DDPGPolicy` -> `DDPG` * The `Algorithm` abstraction can directly initiate the learning process via method `run_training`. * Internal design improvements: diff --git a/README.md b/README.md index fb6a6a9de..eb9067c36 100644 --- a/README.md +++ b/README.md @@ -370,7 +370,7 @@ optim = torch.optim.Adam(net.parameters(), lr=lr) Set up the policy and collectors: ```python -policy = ts.policy.DeepQLearning( +policy = ts.policy.DQN( model=net, optim=optim, discount_factor=gamma, @@ -378,8 +378,10 @@ policy = ts.policy.DeepQLearning( estimation_step=n_step, target_update_freq=target_freq ) -train_collector = ts.data.Collector(policy, train_envs, ts.data.VectorReplayBuffer(buffer_size, train_num), exploration_noise=True) -test_collector = ts.data.Collector(policy, test_envs, exploration_noise=True) # because DQN uses epsilon-greedy method +train_collector = ts.data.Collector(policy, train_envs, + ts.data.VectorReplayBuffer(buffer_size, train_num), exploration_noise=True) +test_collector = ts.data.Collector(policy, test_envs, + exploration_noise=True) # because DQN uses epsilon-greedy method ``` Let's train it: 
diff --git a/examples/atari/atari_dqn.py b/examples/atari/atari_dqn.py index 606ce9a2c..07b82ac77 100644 --- a/examples/atari/atari_dqn.py +++ b/examples/atari/atari_dqn.py @@ -11,7 +11,7 @@ from examples.atari.atari_network import DQNet from tianshou.data import Collector, CollectStats, VectorReplayBuffer from tianshou.highlevel.logger import LoggerFactoryDefault -from tianshou.policy import DeepQLearning +from tianshou.policy import DQN from tianshou.policy.base import Algorithm from tianshou.policy.modelbased.icm import ICMOffPolicyWrapper from tianshou.trainer import OffPolicyTrainer @@ -104,8 +104,8 @@ def main(args: argparse.Namespace = get_args()) -> None: net = DQNet(*args.state_shape, args.action_shape, args.device).to(args.device) optim = torch.optim.Adam(net.parameters(), lr=args.lr) # define policy - policy: DeepQLearning | ICMOffPolicyWrapper - policy = DeepQLearning( + policy: DQN | ICMOffPolicyWrapper + policy = DQN( model=net, optim=optim, action_space=env.action_space, @@ -114,9 +114,7 @@ def main(args: argparse.Namespace = get_args()) -> None: target_update_freq=args.target_update_freq, ) if args.icm_lr_scale > 0: - feature_net = DeepQLearning( - *args.state_shape, args.action_shape, args.device, features_only=True - ) + feature_net = DQN(*args.state_shape, args.action_shape, args.device, features_only=True) action_dim = np.prod(args.action_shape) feature_dim = feature_net.output_dim icm_net = IntrinsicCuriosityModule( diff --git a/examples/box2d/acrobot_dualdqn.py b/examples/box2d/acrobot_dualdqn.py index 52e290cc1..b5820d2d4 100644 --- a/examples/box2d/acrobot_dualdqn.py +++ b/examples/box2d/acrobot_dualdqn.py @@ -9,7 +9,7 @@ from tianshou.data import Collector, CollectStats, VectorReplayBuffer from tianshou.env import DummyVectorEnv -from tianshou.policy import DeepQLearning +from tianshou.policy import DQN from tianshou.policy.base import Algorithm from tianshou.trainer import OffPolicyTrainer from tianshou.utils import TensorboardLogger @@ 
-75,7 +75,7 @@ def test_dqn(args: argparse.Namespace = get_args()) -> None: dueling_param=(Q_param, V_param), ).to(args.device) optim = torch.optim.Adam(net.parameters(), lr=args.lr) - policy: DeepQLearning = DeepQLearning( + policy: DQN = DQN( model=net, optim=optim, action_space=env.action_space, diff --git a/examples/box2d/bipedal_bdq.py b/examples/box2d/bipedal_bdq.py index 7b35b16c8..97c4a4683 100644 --- a/examples/box2d/bipedal_bdq.py +++ b/examples/box2d/bipedal_bdq.py @@ -10,7 +10,7 @@ from tianshou.data import Collector, CollectStats, VectorReplayBuffer from tianshou.env import ContinuousToDiscrete, SubprocVectorEnv -from tianshou.policy import BranchingDuelingQNetwork +from tianshou.policy import BDQN from tianshou.policy.base import Algorithm from tianshou.trainer import OffPolicyTrainer from tianshou.utils import TensorboardLogger @@ -101,7 +101,7 @@ def test_bdq(args: argparse.Namespace = get_args()) -> None: device=args.device, ).to(args.device) optim = torch.optim.Adam(net.parameters(), lr=args.lr) - policy: BranchingDuelingQNetwork = BranchingDuelingQNetwork( + policy: BDQN = BDQN( model=net, optim=optim, discount_factor=args.gamma, diff --git a/examples/box2d/lunarlander_dqn.py b/examples/box2d/lunarlander_dqn.py index 75fc5eaee..332df7747 100644 --- a/examples/box2d/lunarlander_dqn.py +++ b/examples/box2d/lunarlander_dqn.py @@ -9,7 +9,7 @@ from tianshou.data import Collector, CollectStats, VectorReplayBuffer from tianshou.env import DummyVectorEnv, SubprocVectorEnv -from tianshou.policy import DeepQLearning +from tianshou.policy import DQN from tianshou.policy.base import Algorithm from tianshou.trainer import OffpolicyTrainer from tianshou.utils import TensorboardLogger @@ -77,7 +77,7 @@ def test_dqn(args: argparse.Namespace = get_args()) -> None: dueling_param=(Q_param, V_param), ).to(args.device) optim = torch.optim.Adam(net.parameters(), lr=args.lr) - policy: DeepQLearning = DeepQLearning( + policy: DQN = DQN( model=net, optim=optim, 
action_space=env.action_space, diff --git a/examples/discrete/discrete_dqn.py b/examples/discrete/discrete_dqn.py index 3605c0985..a2f6596a8 100644 --- a/examples/discrete/discrete_dqn.py +++ b/examples/discrete/discrete_dqn.py @@ -38,7 +38,7 @@ def main() -> None: optim = torch.optim.Adam(net.parameters(), lr=lr) policy = DQNPolicy(model=net, action_space=env.action_space) - algorithm: ts.policy.DeepQLearning = ts.policy.DeepQLearning( + algorithm: ts.policy.DQN = ts.policy.DQN( policy=policy, optim=optim, discount_factor=gamma, diff --git a/test/discrete/test_bdqn.py b/test/discrete/test_bdqn.py index 1a9d9efc4..cbb8239de 100644 --- a/test/discrete/test_bdqn.py +++ b/test/discrete/test_bdqn.py @@ -6,8 +6,8 @@ from tianshou.data import Collector, CollectStats, VectorReplayBuffer from tianshou.env import ContinuousToDiscrete, DummyVectorEnv -from tianshou.policy import BranchingDuelingQNetwork -from tianshou.policy.modelfree.bdqn import BranchingDuelingQNetworkPolicy +from tianshou.policy import BDQN +from tianshou.policy.modelfree.bdqn import BDQNPolicy from tianshou.trainer.base import OffPolicyTrainingConfig from tianshou.utils.net.common import BranchingNet @@ -99,11 +99,11 @@ def test_bdq(args: argparse.Namespace = get_args()) -> None: device=args.device, ).to(args.device) optim = torch.optim.Adam(net.parameters(), lr=args.lr) - policy = BranchingDuelingQNetworkPolicy( + policy = BDQNPolicy( model=net, action_space=env.action_space, # type: ignore[arg-type] # TODO: should `BranchingDQNPolicy` support also `MultiDiscrete` action spaces? 
) - algorithm: BranchingDuelingQNetwork = BranchingDuelingQNetwork( + algorithm: BDQN = BDQN( policy=policy, optim=optim, discount_factor=args.gamma, diff --git a/test/discrete/test_dqn.py b/test/discrete/test_dqn.py index 0f0b1ebb6..6e9e81903 100644 --- a/test/discrete/test_dqn.py +++ b/test/discrete/test_dqn.py @@ -14,7 +14,7 @@ VectorReplayBuffer, ) from tianshou.env import DummyVectorEnv -from tianshou.policy import DeepQLearning +from tianshou.policy import DQN from tianshou.policy.base import Algorithm from tianshou.policy.modelfree.dqn import DQNPolicy from tianshou.trainer.base import OffPolicyTrainingConfig @@ -91,7 +91,7 @@ def test_dqn(args: argparse.Namespace = get_args()) -> None: policy = DQNPolicy( model=net, action_space=env.action_space, observation_space=env.observation_space ) - algorithm: DeepQLearning = DeepQLearning( + algorithm: DQN = DQN( policy=policy, optim=optim, discount_factor=args.gamma, diff --git a/test/discrete/test_drqn.py b/test/discrete/test_drqn.py index 5cbf30325..6e2395303 100644 --- a/test/discrete/test_drqn.py +++ b/test/discrete/test_drqn.py @@ -8,7 +8,7 @@ from tianshou.data import Collector, CollectStats, VectorReplayBuffer from tianshou.env import DummyVectorEnv -from tianshou.policy import DeepQLearning +from tianshou.policy import DQN from tianshou.policy.base import Algorithm from tianshou.trainer import OffPolicyTrainer from tianshou.utils import TensorboardLogger @@ -74,7 +74,7 @@ def test_drqn(args: argparse.Namespace = get_args()) -> None: args.device, ) optim = torch.optim.Adam(net.parameters(), lr=args.lr) - policy: DeepQLearning = DeepQLearning( + policy: DQN = DQN( model=net, optim=optim, discount_factor=args.gamma, diff --git a/test/modelbased/test_dqn_icm.py b/test/modelbased/test_dqn_icm.py index c84aa0ce7..b88379b9d 100644 --- a/test/modelbased/test_dqn_icm.py +++ b/test/modelbased/test_dqn_icm.py @@ -13,7 +13,7 @@ VectorReplayBuffer, ) from tianshou.env import DummyVectorEnv -from tianshou.policy import 
DeepQLearning, ICMOffPolicyWrapper +from tianshou.policy import DQN, ICMOffPolicyWrapper from tianshou.policy.modelfree.dqn import DQNPolicy from tianshou.trainer import OffPolicyTrainingConfig from tianshou.utils import TensorboardLogger @@ -111,7 +111,7 @@ def test_dqn_icm(args: argparse.Namespace = get_args()) -> None: model=net, action_space=env.action_space, ) - algorithm = DeepQLearning( + algorithm = DQN( policy=policy, optim=optim, discount_factor=args.gamma, diff --git a/test/pettingzoo/pistonball.py b/test/pettingzoo/pistonball.py index d17db94f7..1849717eb 100644 --- a/test/pettingzoo/pistonball.py +++ b/test/pettingzoo/pistonball.py @@ -11,7 +11,7 @@ from tianshou.data import Collector, CollectStats, InfoStats, VectorReplayBuffer from tianshou.env import DummyVectorEnv from tianshou.env.pettingzoo_env import PettingZooEnv -from tianshou.policy import Algorithm, DeepQLearning, MultiAgentPolicyManager +from tianshou.policy import DQN, Algorithm, MultiAgentPolicyManager from tianshou.trainer import OffPolicyTrainer from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Net @@ -97,7 +97,7 @@ def get_agents( device=args.device, ).to(args.device) optim = torch.optim.Adam(net.parameters(), lr=args.lr) - agent: DeepQLearning = DeepQLearning( + agent: DQN = DQN( model=net, optim=optim, action_space=env.action_space, diff --git a/test/pettingzoo/pistonball_continuous.py b/test/pettingzoo/pistonball_continuous.py index 4c476512d..9db48cc3f 100644 --- a/test/pettingzoo/pistonball_continuous.py +++ b/test/pettingzoo/pistonball_continuous.py @@ -21,7 +21,7 @@ from tianshou.utils.net.continuous import ActorProb, Critic -class DQN(nn.Module): +class DQNet(nn.Module): """Reference: Human-level control through deep reinforcement learning. 
For advanced usage (how to customize the network), please refer to @@ -156,7 +156,7 @@ def get_agents( optims = [] for _ in range(args.n_pistons): # model - net = DQN( + net = DQNet( observation_space.shape[2], observation_space.shape[1], observation_space.shape[0], @@ -169,7 +169,7 @@ def get_agents( max_action=args.max_action, device=args.device, ).to(args.device) - net2 = DQN( + net2 = DQNet( observation_space.shape[2], observation_space.shape[1], observation_space.shape[0], diff --git a/test/pettingzoo/tic_tac_toe.py b/test/pettingzoo/tic_tac_toe.py index 9ad50f9c2..cf53760dd 100644 --- a/test/pettingzoo/tic_tac_toe.py +++ b/test/pettingzoo/tic_tac_toe.py @@ -14,8 +14,8 @@ from tianshou.env import DummyVectorEnv from tianshou.env.pettingzoo_env import PettingZooEnv from tianshou.policy import ( + DQN, BasePolicy, - DeepQLearning, MARLRandomPolicy, MultiAgentPolicyManager, ) @@ -120,7 +120,7 @@ def get_agents( ).to(args.device) if optim is None: optim = torch.optim.Adam(net.parameters(), lr=args.lr) - agent_learn = DeepQLearning( + agent_learn = DQN( model=net, optim=optim, action_space=env.action_space, diff --git a/tianshou/highlevel/algorithm.py b/tianshou/highlevel/algorithm.py index a97b9f195..1ebfb8653 100644 --- a/tianshou/highlevel/algorithm.py +++ b/tianshou/highlevel/algorithm.py @@ -48,6 +48,7 @@ from tianshou.policy import ( A2C, DDPG, + DQN, IQN, NPG, PPO, @@ -56,7 +57,6 @@ TD3, TRPO, Algorithm, - DeepQLearning, DiscreteSAC, Reinforce, ) @@ -466,7 +466,7 @@ def _create_algorithm(self, envs: Environments, device: TDevice) -> TAlgorithm: ) -class DeepQLearningAlgorithmFactory(DiscreteCriticOnlyAlgorithmFactory[DQNParams, DeepQLearning]): +class DeepQLearningAlgorithmFactory(DiscreteCriticOnlyAlgorithmFactory[DQNParams, DQN]): def _create_policy( self, model: torch.nn.Module, @@ -483,8 +483,8 @@ def _create_policy( observation_space=observation_space, ) - def _get_algorithm_class(self) -> type[DeepQLearning]: - return DeepQLearning + def 
_get_algorithm_class(self) -> type[DQN]: + return DQN class IQNAlgorithmFactory(DiscreteCriticOnlyAlgorithmFactory[IQNParams, IQN]): diff --git a/tianshou/highlevel/trainer.py b/tianshou/highlevel/trainer.py index 04b43a227..4e1397e25 100644 --- a/tianshou/highlevel/trainer.py +++ b/tianshou/highlevel/trainer.py @@ -8,7 +8,7 @@ from tianshou.highlevel.env import Environments from tianshou.highlevel.logger import TLogger -from tianshou.policy import Algorithm, DeepQLearning +from tianshou.policy import DQN, Algorithm TPolicy = TypeVar("TPolicy", bound=Algorithm) log = logging.getLogger(__name__) @@ -90,7 +90,7 @@ def __init__(self, eps_test: float): self.eps_test = eps_test def callback(self, epoch: int, env_step: int, context: TrainingContext) -> None: - policy = cast(DeepQLearning, context.policy) + policy = cast(DQN, context.policy) policy.set_eps(self.eps_test) @@ -105,7 +105,7 @@ def __init__(self, eps_train: float, eps_train_final: float, decay_steps: int = self.decay_steps = decay_steps def callback(self, epoch: int, env_step: int, context: TrainingContext) -> None: - policy = cast(DeepQLearning, context.policy) + policy = cast(DQN, context.policy) logger = context.logger if env_step <= self.decay_steps: eps = self.eps_train - env_step / self.decay_steps * ( @@ -126,7 +126,7 @@ def __init__(self, eps_test: float): self.eps_test = eps_test def callback(self, epoch: int, env_step: int | None, context: TrainingContext) -> None: - policy = cast(DeepQLearning, context.policy) + policy = cast(DQN, context.policy) policy.set_eps(self.eps_test) diff --git a/tianshou/policy/__init__.py b/tianshou/policy/__init__.py index 0cef14012..21d7f0ebe 100644 --- a/tianshou/policy/__init__.py +++ b/tianshou/policy/__init__.py @@ -3,11 +3,11 @@ from tianshou.policy.base import Algorithm, TrainingStats from tianshou.policy.modelfree.pg import Reinforce -from tianshou.policy.modelfree.dqn import DeepQLearning +from tianshou.policy.modelfree.dqn import DQN from 
tianshou.policy.modelfree.ddpg import DDPG from tianshou.policy.random import MARLRandomPolicy -from tianshou.policy.modelfree.bdqn import BranchingDuelingQNetwork +from tianshou.policy.modelfree.bdqn import BDQN from tianshou.policy.modelfree.c51 import C51 from tianshou.policy.modelfree.rainbow import RainbowDQN from tianshou.policy.modelfree.qrdqn import QRDQN diff --git a/tianshou/policy/modelfree/bdqn.py b/tianshou/policy/modelfree/bdqn.py index 26e7da5c3..18a1e46f6 100644 --- a/tianshou/policy/modelfree/bdqn.py +++ b/tianshou/policy/modelfree/bdqn.py @@ -15,7 +15,7 @@ ObsBatchProtocol, RolloutBatchProtocol, ) -from tianshou.policy import DeepQLearning +from tianshou.policy import DQN from tianshou.policy.base import TArrOrActBatch, TLearningRateScheduler from tianshou.policy.modelfree.dqn import DQNPolicy, DQNTrainingStats from tianshou.utils.net.common import BranchingNet @@ -31,7 +31,7 @@ class BDQNTrainingStats(DQNTrainingStats): TBDQNTrainingStats = TypeVar("TBDQNTrainingStats", bound=BDQNTrainingStats) -class BranchingDuelingQNetworkPolicy(DQNPolicy[BranchingNet]): +class BDQNPolicy(DQNPolicy[BranchingNet]): def __init__( self, *, @@ -87,13 +87,13 @@ def add_exploration_noise( return act -class BranchingDuelingQNetwork(DeepQLearning[BranchingDuelingQNetworkPolicy, TBDQNTrainingStats]): - """Implementation of the Branching Dueling Q-Network algorithm arXiv:1711.08946.""" +class BDQN(DQN[BDQNPolicy, TBDQNTrainingStats]): + """Implementation of the Branching Dueling Q-Network (BDQN) algorithm arXiv:1711.08946.""" def __init__( self, *, - policy: BranchingDuelingQNetworkPolicy, + policy: BDQNPolicy, optim: torch.optim.Optimizer, discount_factor: float = 0.99, estimation_step: int = 1, diff --git a/tianshou/policy/modelfree/c51.py b/tianshou/policy/modelfree/c51.py index 24ea7b2fe..39d8718a6 100644 --- a/tianshou/policy/modelfree/c51.py +++ b/tianshou/policy/modelfree/c51.py @@ -7,7 +7,7 @@ from tianshou.data import Batch, ReplayBuffer from 
tianshou.data.types import RolloutBatchProtocol -from tianshou.policy import DeepQLearning +from tianshou.policy import DQN from tianshou.policy.base import TLearningRateScheduler from tianshou.policy.modelfree.dqn import DQNPolicy, DQNTrainingStats from tianshou.utils.net.common import Net @@ -57,7 +57,7 @@ def compute_q_value(self, logits: torch.Tensor, mask: np.ndarray | None) -> torc return super().compute_q_value((logits * self.support).sum(2), mask) -class C51(DeepQLearning[C51Policy, TC51TrainingStats], Generic[TC51TrainingStats]): +class C51(DQN[C51Policy, TC51TrainingStats], Generic[TC51TrainingStats]): """Implementation of Categorical Deep Q-Network. arXiv:1707.06887. :param policy: a policy following the rules (s -> action_values_BA) :param optim: a torch.optim for optimizing the policy. diff --git a/tianshou/policy/modelfree/dqn.py b/tianshou/policy/modelfree/dqn.py index b1a365730..bc70f34a3 100644 --- a/tianshou/policy/modelfree/dqn.py +++ b/tianshou/policy/modelfree/dqn.py @@ -142,7 +142,7 @@ def add_exploration_noise( TDQNPolicy = TypeVar("TDQNPolicy", bound=DQNPolicy) -class DeepQLearning( +class DQN( OffPolicyAlgorithm[TDQNPolicy, TDQNTrainingStats], LaggedNetworkFullUpdateAlgorithmMixin, Generic[TDQNPolicy, TDQNTrainingStats], diff --git a/tianshou/policy/modelfree/qrdqn.py b/tianshou/policy/modelfree/qrdqn.py index d8e2ed0b8..9c10ef583 100644 --- a/tianshou/policy/modelfree/qrdqn.py +++ b/tianshou/policy/modelfree/qrdqn.py @@ -8,7 +8,7 @@ from tianshou.data import Batch, ReplayBuffer from tianshou.data.types import RolloutBatchProtocol -from tianshou.policy import DeepQLearning +from tianshou.policy import DQN from tianshou.policy.base import TLearningRateScheduler from tianshou.policy.modelfree.dqn import DQNPolicy, DQNTrainingStats @@ -29,9 +29,7 @@ def compute_q_value(self, logits: torch.Tensor, mask: np.ndarray | None) -> torc TQRDQNPolicy = TypeVar("TQRDQNPolicy", bound=QRDQNPolicy) -class QRDQN( - DeepQLearning[TQRDQNPolicy, 
TQRDQNTrainingStats], Generic[TQRDQNPolicy, TQRDQNTrainingStats] -): +class QRDQN(DQN[TQRDQNPolicy, TQRDQNTrainingStats], Generic[TQRDQNPolicy, TQRDQNTrainingStats]): """Implementation of Quantile Regression Deep Q-Network. arXiv:1710.10044.""" def __init__( From 26e87b9d7a2391f4a6fe295a65b55128ee92f48b Mon Sep 17 00:00:00 2001 From: Dominik Jain Date: Fri, 14 Mar 2025 13:32:56 +0100 Subject: [PATCH 055/230] v2: Improve class hierarchy of deep Q-learning algorithms * Introduce base class QLearningOffPolicyAlgorithm * QRDQN, C51, BDQN: Inherit from QLearningOffPolicyAlgorithm instead of DQN * Remove algorithm parameters that were not being used and were only passed on to the base class (for no reason; had no effect whatsoever): * QRDQN, C51, BDQN, IQN, FQF, DiscreteCQL: Remove clip_loss_grad * QRDQN, C51, IQN, FQF, DiscreteCQL: Remove is_double --- CHANGELOG.md | 18 ++- tianshou/policy/imitation/discrete_cql.py | 12 +- tianshou/policy/modelfree/bdqn.py | 24 ++-- tianshou/policy/modelfree/c51.py | 56 ++++----- tianshou/policy/modelfree/dqn.py | 147 +++++++++++++++------- tianshou/policy/modelfree/fqf.py | 14 +-- tianshou/policy/modelfree/iqn.py | 12 +- tianshou/policy/modelfree/qrdqn.py | 26 ++-- tianshou/policy/modelfree/rainbow.py | 2 +- 9 files changed, 171 insertions(+), 140 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 867e21aed..2b8af73d0 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -71,13 +71,29 @@ * `ActorCriticOnPolicyAlgorithm` * `ActorCriticOffPolicyAlgorithm` * `ActorDualCriticsOffPolicyAlgorithm` (extends `ActorCriticOffPolicyAlgorithm`) + * `QLearningOffPolicyAlgorithm` * `A2C`: Inherit from `ActorCriticOnPolicyAlgorithm` instead of `Reinforce` + * `BDQN`: + * Inherit from `QLearningOffPolicyAlgorithm` instead of `DQN` + * Remove parameter `clip_loss_grad` (unused; only passed on to former base class) + * `C51`: + * Inherit from `QLearningOffPolicyAlgorithm` instead of `DQN` + * Remove parameters `clip_loss_grad` and `is_double` 
(unused; only passed on to former base class) * `CQL`: * Inherit directly from `OfflineAlgorithm` instead of `SAC` (off-policy). * Remove parameter `estimation_step`, which was not actually used (it was only passed it on to its superclass). + * `DiscreteCQL`: Remove unused parameters `clip_loss_grad` and `is_double`, which were only passed on to + base class `QRDQN` (and unused by it). * `DiscreteCRR`: Inherit directly from `OfflineAlgorithm` instead of `Reinforce` (on-policy) - * `NPG`: Inherit from `ActorCriticOnPolicyAlgorithm` instead of `A2C` + * `FQF`: Remove unused parameters `clip_loss_grad` and `is_double`, which were only passed on to + base class `QRDQN` (and unused by it). + * `IQN`: Remove unused parameters `clip_loss_grad` and `is_double`, which were only passed on to + base class `QRDQN` (and unused by it). + * `NPG`: Inherit from `ActorCriticOnPolicyAlgorithm` instead of `A2C` + * `QRDQN`: + * Inherit from `QLearningOffPolicyAlgorithm` instead of `DQN` + * Remove parameters `clip_loss_grad` and `is_double` (unused; only passed on to former base class) * `REDQ`: Inherit from `ActorCriticOffPolicyAlgorithm` instead of `DDPG` * `SAC`: Inherit from `ActorDualCriticsOffPolicyAlgorithm` instead of `DDPG` * `TD3`: Inherit from `ActorDualCriticsOffPolicyAlgorithm` instead of `DDPG` diff --git a/tianshou/policy/imitation/discrete_cql.py b/tianshou/policy/imitation/discrete_cql.py index 30f0aafa3..2a2836c1c 100644 --- a/tianshou/policy/imitation/discrete_cql.py +++ b/tianshou/policy/imitation/discrete_cql.py @@ -37,10 +37,6 @@ class DiscreteCQLPolicy(QRDQN): you do not use the target network). :param reward_normalization: normalize the **returns** to Normal(0, 1). TODO: rename to return_normalization? - :param is_double: use double dqn. - :param clip_loss_grad: clip the gradient of the loss in accordance - with nature14236; this amounts to using the Huber loss instead of - the MSE loss. :param observation_space: Env's observation space. 
:param lr_scheduler: if not None, will be called in `policy.update()`. @@ -61,8 +57,6 @@ def __init__( estimation_step: int = 1, target_update_freq: int = 0, reward_normalization: bool = False, - is_double: bool = True, - clip_loss_grad: bool = False, observation_space: gym.Space | None = None, lr_scheduler: TLearningRateScheduler | None = None, ) -> None: @@ -75,8 +69,6 @@ def __init__( estimation_step=estimation_step, target_update_freq=target_update_freq, reward_normalization=reward_normalization, - is_double=is_double, - clip_loss_grad=clip_loss_grad, observation_space=observation_space, lr_scheduler=lr_scheduler, ) @@ -88,8 +80,7 @@ def _update_with_batch( *args: Any, **kwargs: Any, ) -> TDiscreteCQLTrainingStats: - if self._target and self._iter % self.freq == 0: - self._update_lagged_network_weights() + self._periodically_update_lagged_network_weights() self.optim.zero_grad() weight = batch.pop("weight", 1.0) all_dist = self(batch).logits @@ -115,7 +106,6 @@ def _update_with_batch( loss = qr_loss + min_q_loss * self.min_q_weight loss.backward() self.optim.step() - self._iter += 1 return DiscreteCQLTrainingStats( # type: ignore[return-value] loss=loss.item(), diff --git a/tianshou/policy/modelfree/bdqn.py b/tianshou/policy/modelfree/bdqn.py index 18a1e46f6..bee04a729 100644 --- a/tianshou/policy/modelfree/bdqn.py +++ b/tianshou/policy/modelfree/bdqn.py @@ -15,9 +15,12 @@ ObsBatchProtocol, RolloutBatchProtocol, ) -from tianshou.policy import DQN from tianshou.policy.base import TArrOrActBatch, TLearningRateScheduler -from tianshou.policy.modelfree.dqn import DQNPolicy, DQNTrainingStats +from tianshou.policy.modelfree.dqn import ( + DQNPolicy, + DQNTrainingStats, + QLearningOffPolicyAlgorithm, +) from tianshou.utils.net.common import BranchingNet mark_used(ActBatchProtocol) @@ -87,7 +90,7 @@ def add_exploration_noise( return act -class BDQN(DQN[BDQNPolicy, TBDQNTrainingStats]): +class BDQN(QLearningOffPolicyAlgorithm[BDQNPolicy, TBDQNTrainingStats]): 
"""Implementation of the Branching Dueling Q-Network (BDQN) algorithm arXiv:1711.08946.""" def __init__( @@ -100,7 +103,6 @@ def __init__( target_update_freq: int = 0, reward_normalization: bool = False, is_double: bool = True, - clip_loss_grad: bool = False, lr_scheduler: TLearningRateScheduler | None = None, ) -> None: """ @@ -112,10 +114,7 @@ def __init__( you do not use the target network). :param reward_normalization: normalize the **returns** to Normal(0, 1). TODO: rename to return_normalization? - :param is_double: use double dqn. - :param clip_loss_grad: clip the gradient of the loss in accordance - with nature14236; this amounts to using the Huber loss instead of - the MSE loss. + :param is_double: whether to use double DQN. :param lr_scheduler: if not None, will be called in `policy.update()`. """ assert ( @@ -128,10 +127,9 @@ def __init__( estimation_step=estimation_step, target_update_freq=target_update_freq, reward_normalization=reward_normalization, - is_double=is_double, - clip_loss_grad=clip_loss_grad, lr_scheduler=lr_scheduler, ) + self.is_double = is_double def _target_q(self, buffer: ReplayBuffer, indices: np.ndarray) -> torch.Tensor: obs_next_batch = Batch( @@ -139,7 +137,7 @@ def _target_q(self, buffer: ReplayBuffer, indices: np.ndarray) -> torch.Tensor: info=[None] * len(indices), ) # obs_next: s_{t+n} result = self.policy(obs_next_batch) - if self._target: + if self.use_target_network: # target_Q = Q_old(s_, argmax(Q_new(s_, *))) target_q = self.policy(obs_next_batch, model=self.model_old).logits else: @@ -190,8 +188,7 @@ def _update_with_batch( *args: Any, **kwargs: Any, ) -> TBDQNTrainingStats: - if self._target and self._iter % self.freq == 0: - self._update_lagged_network_weights() + self._periodically_update_lagged_network_weights() self.optim.zero_grad() weight = batch.pop("weight", 1.0) act = to_torch(batch.act, dtype=torch.long, device=batch.returns.device) @@ -206,6 +203,5 @@ def _update_with_batch( batch.weight = 
td_error.sum(-1).sum(-1) # prio-buffer loss.backward() self.optim.step() - self._iter += 1 return BDQNTrainingStats(loss=loss.item()) # type: ignore[return-value] diff --git a/tianshou/policy/modelfree/c51.py b/tianshou/policy/modelfree/c51.py index 39d8718a6..fdbbb9521 100644 --- a/tianshou/policy/modelfree/c51.py +++ b/tianshou/policy/modelfree/c51.py @@ -7,9 +7,12 @@ from tianshou.data import Batch, ReplayBuffer from tianshou.data.types import RolloutBatchProtocol -from tianshou.policy import DQN from tianshou.policy.base import TLearningRateScheduler -from tianshou.policy.modelfree.dqn import DQNPolicy, DQNTrainingStats +from tianshou.policy.modelfree.dqn import ( + DQNPolicy, + DQNTrainingStats, + QLearningOffPolicyAlgorithm, +) from tianshou.utils.net.common import Net @@ -57,28 +60,8 @@ def compute_q_value(self, logits: torch.Tensor, mask: np.ndarray | None) -> torc return super().compute_q_value((logits * self.support).sum(2), mask) -class C51(DQN[C51Policy, TC51TrainingStats], Generic[TC51TrainingStats]): - """Implementation of Categorical Deep Q-Network. arXiv:1707.06887. - :param policy: a policy following the rules (s -> action_values_BA) - :param optim: a torch.optim for optimizing the policy. - :param discount_factor: in [0, 1]. - :param estimation_step: the number of steps to look ahead. - :param target_update_freq: the target network update frequency (0 if - you do not use the target network). - :param reward_normalization: normalize the **returns** to Normal(0, 1). - TODO: rename to return_normalization? - :param is_double: use double dqn. - :param clip_loss_grad: clip the gradient of the loss in accordance - with nature14236; this amounts to using the Huber loss instead of - the MSE loss. - :param observation_space: Env's observation space. - :param lr_scheduler: if not None, will be called in `policy.update()`. - - .. seealso:: - - Please refer to :class:`~tianshou.policy.DQNPolicy` for more detailed - explanation. 
- """ +class C51(QLearningOffPolicyAlgorithm[C51Policy, TC51TrainingStats], Generic[TC51TrainingStats]): + """Implementation of Categorical Deep Q-Network. arXiv:1707.06887.""" def __init__( self, @@ -89,10 +72,23 @@ def __init__( estimation_step: int = 1, target_update_freq: int = 0, reward_normalization: bool = False, - is_double: bool = True, - clip_loss_grad: bool = False, lr_scheduler: TLearningRateScheduler | None = None, ) -> None: + """ + :param policy: a policy following the rules (s -> action_values_BA) + :param optim: a torch.optim for optimizing the policy. + :param discount_factor: in [0, 1]. + :param estimation_step: the number of steps to look ahead. + :param target_update_freq: the target network update frequency (0 if + you do not use the target network). + :param reward_normalization: normalize the **returns** to Normal(0, 1). + TODO: rename to return_normalization? + :param is_double: use double dqn. + :param clip_loss_grad: clip the gradient of the loss in accordance + with nature14236; this amounts to using the Huber loss instead of + the MSE loss. + :param lr_scheduler: if not None, will be called in `policy.update()`. 
+ """ super().__init__( policy=policy, optim=optim, @@ -100,8 +96,6 @@ def __init__( estimation_step=estimation_step, target_update_freq=target_update_freq, reward_normalization=reward_normalization, - is_double=is_double, - clip_loss_grad=clip_loss_grad, lr_scheduler=lr_scheduler, ) self.delta_z = (policy.v_max - policy.v_min) / (policy.num_atoms - 1) @@ -111,7 +105,7 @@ def _target_q(self, buffer: ReplayBuffer, indices: np.ndarray) -> torch.Tensor: def _target_dist(self, batch: RolloutBatchProtocol) -> torch.Tensor: obs_next_batch = Batch(obs=batch.obs_next, info=[None] * len(batch)) - if self._target: + if self.use_target_network: act = self.policy(obs_next_batch).act next_dist = self.policy(obs_next_batch, model=self.model_old).logits else: @@ -135,8 +129,7 @@ def _update_with_batch( *args: Any, **kwargs: Any, ) -> TC51TrainingStats: - if self._target and self._iter % self.freq == 0: - self._update_lagged_network_weights() + self._periodically_update_lagged_network_weights() self.optim.zero_grad() with torch.no_grad(): target_dist = self._target_dist(batch) @@ -150,6 +143,5 @@ def _update_with_batch( batch.weight = cross_entropy.detach() # prio-buffer loss.backward() self.optim.step() - self._iter += 1 return C51TrainingStats(loss=loss.item()) # type: ignore[return-value] diff --git a/tianshou/policy/modelfree/dqn.py b/tianshou/policy/modelfree/dqn.py index bc70f34a3..886e61dd9 100644 --- a/tianshou/policy/modelfree/dqn.py +++ b/tianshou/policy/modelfree/dqn.py @@ -1,3 +1,4 @@ +from abc import ABC, abstractmethod from dataclasses import dataclass from typing import Any, Generic, TypeVar, cast @@ -22,6 +23,7 @@ TArrOrActBatch, TLearningRateScheduler, TrainingStats, + TTrainingStats, ) from tianshou.utils.net.common import Net @@ -142,17 +144,14 @@ def add_exploration_noise( TDQNPolicy = TypeVar("TDQNPolicy", bound=DQNPolicy) -class DQN( - OffPolicyAlgorithm[TDQNPolicy, TDQNTrainingStats], - LaggedNetworkFullUpdateAlgorithmMixin, - Generic[TDQNPolicy, 
TDQNTrainingStats], +class QLearningOffPolicyAlgorithm( + OffPolicyAlgorithm[TDQNPolicy, TTrainingStats], LaggedNetworkFullUpdateAlgorithmMixin, ABC ): - """Implementation of Deep Q Network. arXiv:1312.5602. - - Implementation of Double Q-Learning. arXiv:1509.06461. - - Implementation of Dueling DQN. arXiv:1511.06581 (the dueling DQN is - implemented in the network side, not here). + """ + Base class for Q-learning off-policy algorithms that use a Q-function to compute the + n-step return. + It optionally uses a lagged model, which is used as a target network and which is + fully updated periodically. """ def __init__( @@ -164,8 +163,6 @@ def __init__( estimation_step: int = 1, target_update_freq: int = 0, reward_normalization: bool = False, - is_double: bool = True, - clip_loss_grad: bool = False, lr_scheduler: TLearningRateScheduler | None = None, ) -> None: """ @@ -173,22 +170,18 @@ def __init__( :param optim: the optimizer for the policy :param discount_factor: in [0, 1]. :param estimation_step: the number of steps to look ahead. - :param target_update_freq: the target network update frequency (0 if - you do not use the target network). + :param target_update_freq: the frequency with which to update the weights of the target network; + 0 if a target network shall not be used. :param reward_normalization: normalize the **returns** to Normal(0, 1). TODO: rename to return_normalization? - :param is_double: use double dqn. - :param clip_loss_grad: clip the gradient of the loss in accordance - with nature14236; this amounts to using the Huber loss instead of - the MSE loss. :param lr_scheduler: if not None, will be called in `policy.update()`. 
""" super().__init__( policy=policy, lr_scheduler=lr_scheduler, ) - LaggedNetworkFullUpdateAlgorithmMixin.__init__(self) self.optim = optim + LaggedNetworkFullUpdateAlgorithmMixin.__init__(self) assert ( 0.0 <= discount_factor <= 1.0 ), f"discount factor should be in [0, 1] but got: {discount_factor}" @@ -197,30 +190,21 @@ def __init__( estimation_step > 0 ), f"estimation_step should be greater than 0 but got: {estimation_step}" self.n_step = estimation_step - self._target = target_update_freq > 0 - self.freq = target_update_freq - self._iter = 0 - if self._target: - self.model_old = self._add_lagged_network(self.policy.model) self.rew_norm = reward_normalization - self.is_double = is_double - self.clip_loss_grad = clip_loss_grad + self.target_update_freq = target_update_freq + # TODO: 1 would be a more reasonable initialization given how it is incremented + self._iter = 0 + self.model_old = ( + self._add_lagged_network(self.policy.model) if self.use_target_network else None + ) + @property + def use_target_network(self) -> bool: + return self.target_update_freq > 0 + + @abstractmethod def _target_q(self, buffer: ReplayBuffer, indices: np.ndarray) -> torch.Tensor: - obs_next_batch = Batch( - obs=buffer[indices].obs_next, - info=[None] * len(indices), - ) # obs_next: s_{t+n} - result = self.policy(obs_next_batch) - if self._target: - # target_Q = Q_old(s_, argmax(Q_new(s_, *))) - target_q = self.policy(obs_next_batch, model=self.model_old).logits - else: - target_q = result.logits - if self.is_double: - return target_q[np.arange(len(result.act)), result.act] - # Nature DQN, over estimate - return target_q.max(dim=1)[0] + pass def process_fn( self, @@ -243,14 +227,92 @@ def process_fn( rew_norm=self.rew_norm, ) + def _periodically_update_lagged_network_weights(self) -> None: + """ + Periodically updates the parameters of the lagged target network (if any), i.e. 
+ every n-th call (where n=`target_update_freq`), the target network's parameters + are fully updated with the model's parameters. + """ + if self.use_target_network and self._iter % self.target_update_freq == 0: + self._update_lagged_network_weights() + self._iter += 1 + + +class DQN( + QLearningOffPolicyAlgorithm[TDQNPolicy, TDQNTrainingStats], + Generic[TDQNPolicy, TDQNTrainingStats], +): + """Implementation of Deep Q Network. arXiv:1312.5602. + + Implementation of Double Q-Learning. arXiv:1509.06461. + + Implementation of Dueling DQN. arXiv:1511.06581 (the dueling DQN is + implemented in the network side, not here). + """ + + def __init__( + self, + *, + policy: TDQNPolicy, + optim: torch.optim.Optimizer, + discount_factor: float = 0.99, + estimation_step: int = 1, + target_update_freq: int = 0, + reward_normalization: bool = False, + is_double: bool = True, + clip_loss_grad: bool = False, + lr_scheduler: TLearningRateScheduler | None = None, + ) -> None: + """ + :param policy: the policy + :param optim: the optimizer for the policy + :param discount_factor: in [0, 1]. + :param estimation_step: the number of steps to look ahead. + :param target_update_freq: the frequency with which to update the weights of the target network; + 0 if a target network shall not be used. + :param reward_normalization: normalize the **returns** to Normal(0, 1). + TODO: rename to return_normalization? + :param is_double: use double dqn. + :param clip_loss_grad: clip the gradient of the loss in accordance + with nature14236; this amounts to using the Huber loss instead of + the MSE loss. + :param lr_scheduler: if not None, will be called in `policy.update()`. 
+ """ + super().__init__( + policy=policy, + optim=optim, + lr_scheduler=lr_scheduler, + discount_factor=discount_factor, + estimation_step=estimation_step, + target_update_freq=target_update_freq, + reward_normalization=reward_normalization, + ) + self.is_double = is_double + self.clip_loss_grad = clip_loss_grad + + def _target_q(self, buffer: ReplayBuffer, indices: np.ndarray) -> torch.Tensor: + obs_next_batch = Batch( + obs=buffer[indices].obs_next, + info=[None] * len(indices), + ) # obs_next: s_{t+n} + result = self.policy(obs_next_batch) + if self.use_target_network: + # target_Q = Q_old(s_, argmax(Q_new(s_, *))) + target_q = self.policy(obs_next_batch, model=self.model_old).logits + else: + target_q = result.logits + if self.is_double: + return target_q[np.arange(len(result.act)), result.act] + # Nature DQN, over estimate + return target_q.max(dim=1)[0] + def _update_with_batch( self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any, ) -> TDQNTrainingStats: - if self._target and self._iter % self.freq == 0: - self._update_lagged_network_weights() + self._periodically_update_lagged_network_weights() self.optim.zero_grad() weight = batch.pop("weight", 1.0) q = self.policy(batch).logits @@ -268,6 +330,5 @@ def _update_with_batch( batch.weight = td_error # prio-buffer loss.backward() self.optim.step() - self._iter += 1 return DQNTrainingStats(loss=loss.item()) # type: ignore[return-value] diff --git a/tianshou/policy/modelfree/fqf.py b/tianshou/policy/modelfree/fqf.py index 015d4f955..3d84b90eb 100644 --- a/tianshou/policy/modelfree/fqf.py +++ b/tianshou/policy/modelfree/fqf.py @@ -109,8 +109,6 @@ def __init__( estimation_step: int = 1, target_update_freq: int = 0, reward_normalization: bool = False, - is_double: bool = True, - clip_loss_grad: bool = False, lr_scheduler: TLearningRateScheduler | None = None, ) -> None: """ @@ -126,10 +124,6 @@ def __init__( you do not use the target network). 
:param reward_normalization: normalize the **returns** to Normal(0, 1). TODO: rename to return_normalization? - :param is_double: use double dqn. - :param clip_loss_grad: clip the gradient of the loss in accordance - with nature14236; this amounts to using the Huber loss instead of - the MSE loss. :param observation_space: Env's observation space. :param lr_scheduler: if not None, will be called in `policy.update()`. """ @@ -141,8 +135,6 @@ def __init__( estimation_step=estimation_step, target_update_freq=target_update_freq, reward_normalization=reward_normalization, - is_double=is_double, - clip_loss_grad=clip_loss_grad, lr_scheduler=lr_scheduler, ) self.ent_coef = ent_coef @@ -153,7 +145,7 @@ def _target_q(self, buffer: ReplayBuffer, indices: np.ndarray) -> torch.Tensor: obs=buffer[indices].obs_next, info=[None] * len(indices), ) # obs_next: s_{t+n} - if self._target: + if self.use_target_network: result = self.policy(obs_next_batch) act, fractions = result.act, result.fractions next_dist = self.policy( @@ -171,8 +163,7 @@ def _update_with_batch( *args: Any, **kwargs: Any, ) -> TFQFTrainingStats: - if self._target and self._iter % self.freq == 0: - self._update_lagged_network_weights() + self._periodically_update_lagged_network_weights() weight = batch.pop("weight", 1.0) out = self.policy(batch) curr_dist_orig = out.logits @@ -227,7 +218,6 @@ def _update_with_batch( self.optim.zero_grad() quantile_loss.backward() self.optim.step() - self._iter += 1 return FQFTrainingStats( # type: ignore[return-value] loss=quantile_loss.item() + fraction_entropy_loss.item(), diff --git a/tianshou/policy/modelfree/iqn.py b/tianshou/policy/modelfree/iqn.py index a1245e92d..0d0b71dfd 100644 --- a/tianshou/policy/modelfree/iqn.py +++ b/tianshou/policy/modelfree/iqn.py @@ -100,8 +100,6 @@ def __init__( estimation_step: int = 1, target_update_freq: int = 0, reward_normalization: bool = False, - is_double: bool = True, - clip_loss_grad: bool = False, lr_scheduler: TLearningRateScheduler 
| None = None, ) -> None: """ @@ -115,10 +113,6 @@ def __init__( you do not use the target network). :param reward_normalization: normalize the **returns** to Normal(0, 1). TODO: rename to return_normalization? - :param is_double: use double dqn. - :param clip_loss_grad: clip the gradient of the loss in accordance - with nature14236; this amounts to using the Huber loss instead of - the MSE loss. :param lr_scheduler: if not None, will be called in `policy.update()`. """ super().__init__( @@ -129,8 +123,6 @@ def __init__( estimation_step=estimation_step, target_update_freq=target_update_freq, reward_normalization=reward_normalization, - is_double=is_double, - clip_loss_grad=clip_loss_grad, lr_scheduler=lr_scheduler, ) @@ -140,8 +132,7 @@ def _update_with_batch( *args: Any, **kwargs: Any, ) -> TIQNTrainingStats: - if self._target and self._iter % self.freq == 0: - self._update_lagged_network_weights() + self._periodically_update_lagged_network_weights() self.optim.zero_grad() weight = batch.pop("weight", 1.0) action_batch = self.policy(batch) @@ -165,6 +156,5 @@ def _update_with_batch( batch.weight = dist_diff.detach().abs().sum(-1).mean(1) # prio-buffer loss.backward() self.optim.step() - self._iter += 1 return IQNTrainingStats(loss=loss.item()) # type: ignore[return-value] diff --git a/tianshou/policy/modelfree/qrdqn.py b/tianshou/policy/modelfree/qrdqn.py index 9c10ef583..ca67b82fa 100644 --- a/tianshou/policy/modelfree/qrdqn.py +++ b/tianshou/policy/modelfree/qrdqn.py @@ -8,9 +8,12 @@ from tianshou.data import Batch, ReplayBuffer from tianshou.data.types import RolloutBatchProtocol -from tianshou.policy import DQN from tianshou.policy.base import TLearningRateScheduler -from tianshou.policy.modelfree.dqn import DQNPolicy, DQNTrainingStats +from tianshou.policy.modelfree.dqn import ( + DQNPolicy, + DQNTrainingStats, + QLearningOffPolicyAlgorithm, +) @dataclass(kw_only=True) @@ -29,7 +32,10 @@ def compute_q_value(self, logits: torch.Tensor, mask: np.ndarray | None) 
-> torc TQRDQNPolicy = TypeVar("TQRDQNPolicy", bound=QRDQNPolicy) -class QRDQN(DQN[TQRDQNPolicy, TQRDQNTrainingStats], Generic[TQRDQNPolicy, TQRDQNTrainingStats]): +class QRDQN( + QLearningOffPolicyAlgorithm[TQRDQNPolicy, TQRDQNTrainingStats], + Generic[TQRDQNPolicy, TQRDQNTrainingStats], +): """Implementation of Quantile Regression Deep Q-Network. arXiv:1710.10044.""" def __init__( @@ -42,8 +48,6 @@ def __init__( estimation_step: int = 1, target_update_freq: int = 0, reward_normalization: bool = False, - is_double: bool = True, - clip_loss_grad: bool = False, lr_scheduler: TLearningRateScheduler | None = None, ) -> None: """ @@ -57,10 +61,6 @@ def __init__( you do not use the target network). :param reward_normalization: normalize the **returns** to Normal(0, 1). TODO: rename to return_normalization? - :param is_double: use double dqn. - :param clip_loss_grad: clip the gradient of the loss in accordance - with nature14236; this amounts to using the Huber loss instead of - the MSE loss. :param lr_scheduler: if not None, will be called in `policy.update()`. 
""" assert num_quantiles > 1, f"num_quantiles should be greater than 1 but got: {num_quantiles}" @@ -71,8 +71,6 @@ def __init__( estimation_step=estimation_step, target_update_freq=target_update_freq, reward_normalization=reward_normalization, - is_double=is_double, - clip_loss_grad=clip_loss_grad, lr_scheduler=lr_scheduler, ) self.num_quantiles = num_quantiles @@ -88,7 +86,7 @@ def _target_q(self, buffer: ReplayBuffer, indices: np.ndarray) -> torch.Tensor: obs=buffer[indices].obs_next, info=[None] * len(indices), ) # obs_next: s_{t+n} - if self._target: + if self.use_target_network: act = self.policy(obs_next_batch).act next_dist = self.policy(obs_next_batch, model=self.model_old).logits else: @@ -103,8 +101,7 @@ def _update_with_batch( *args: Any, **kwargs: Any, ) -> TQRDQNTrainingStats: - if self._target and self._iter % self.freq == 0: - self._update_lagged_network_weights() + self._periodically_update_lagged_network_weights() self.optim.zero_grad() weight = batch.pop("weight", 1.0) curr_dist = self.policy(batch).logits @@ -124,6 +121,5 @@ def _update_with_batch( batch.weight = dist_diff.detach().abs().sum(-1).mean(1) # prio-buffer loss.backward() self.optim.step() - self._iter += 1 return QRDQNTrainingStats(loss=loss.item()) # type: ignore[return-value] diff --git a/tianshou/policy/modelfree/rainbow.py b/tianshou/policy/modelfree/rainbow.py index 6232045dc..5a93228b4 100644 --- a/tianshou/policy/modelfree/rainbow.py +++ b/tianshou/policy/modelfree/rainbow.py @@ -50,6 +50,6 @@ def _update_with_batch( **kwargs: Any, ) -> TRainbowTrainingStats: self._sample_noise(self.policy.model) - if self._target and self._sample_noise(self.model_old): + if self.use_target_network and self._sample_noise(self.model_old): self.model_old.train() # so that NoisyLinear takes effect return super()._update_with_batch(batch, **kwargs) From ed7a43af7bdfeea1e1adbc5d6cfe552a9fe1d5da Mon Sep 17 00:00:00 2001 From: Dominik Jain Date: Fri, 14 Mar 2025 14:51:46 +0100 Subject: [PATCH 056/230] 
v2: Adapt DiscreteCQL (using diamond inheritance to convert from off-policy to offline) and test_discrete_cql --- examples/offline/atari_cql.py | 4 +- test/offline/test_discrete_cql.py | 39 +++++++------- tianshou/policy/__init__.py | 2 +- tianshou/policy/base.py | 9 +++- tianshou/policy/imitation/discrete_cql.py | 62 ++++++++++------------- 5 files changed, 61 insertions(+), 55 deletions(-) diff --git a/examples/offline/atari_cql.py b/examples/offline/atari_cql.py index 48ce2faaf..28e394b18 100644 --- a/examples/offline/atari_cql.py +++ b/examples/offline/atari_cql.py @@ -17,7 +17,7 @@ from examples.offline.utils import load_buffer from tianshou.data import Collector, CollectStats, VectorReplayBuffer from tianshou.highlevel.logger import LoggerFactoryDefault -from tianshou.policy import DiscreteCQLPolicy +from tianshou.policy import DiscreteCQL from tianshou.policy.base import Algorithm from tianshou.trainer import OfflineTrainer from tianshou.utils.space_info import SpaceInfo @@ -107,7 +107,7 @@ def test_discrete_cql(args: argparse.Namespace = get_args()) -> None: ) optim = torch.optim.Adam(net.parameters(), lr=args.lr) # define policy - policy: DiscreteCQLPolicy = DiscreteCQLPolicy( + policy: DiscreteCQL = DiscreteCQL( model=net, optim=optim, action_space=env.action_space, diff --git a/test/offline/test_discrete_cql.py b/test/offline/test_discrete_cql.py index 2e7c5af70..142fffd4f 100644 --- a/test/offline/test_discrete_cql.py +++ b/test/offline/test_discrete_cql.py @@ -15,8 +15,9 @@ VectorReplayBuffer, ) from tianshou.env import DummyVectorEnv -from tianshou.policy import Algorithm, DiscreteCQLPolicy -from tianshou.trainer import OfflineTrainer +from tianshou.policy import Algorithm, DiscreteCQL +from tianshou.policy.modelfree.qrdqn import QRDQNPolicy +from tianshou.trainer import OfflineTrainingConfig from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Net from tianshou.utils.space_info import SpaceInfo @@ -79,10 +80,13 @@ def 
test_discrete_cql(args: argparse.Namespace = get_args()) -> None: ) optim = torch.optim.Adam(net.parameters(), lr=args.lr) - policy: DiscreteCQLPolicy = DiscreteCQLPolicy( + policy = QRDQNPolicy( model=net, - optim=optim, action_space=env.action_space, + ) + algorithm: DiscreteCQL = DiscreteCQL( + policy=policy, + optim=optim, discount_factor=args.gamma, num_quantiles=args.num_quantiles, estimation_step=args.n_step, @@ -101,7 +105,7 @@ def test_discrete_cql(args: argparse.Namespace = get_args()) -> None: buffer = gather_data() # collector - test_collector = Collector[CollectStats](policy, test_envs, exploration_noise=True) + test_collector = Collector[CollectStats](algorithm, test_envs, exploration_noise=True) log_path = os.path.join(args.logdir, args.task, "discrete_cql") writer = SummaryWriter(log_path) @@ -113,17 +117,18 @@ def save_best_fn(policy: Algorithm) -> None: def stop_fn(mean_rewards: float) -> bool: return mean_rewards >= args.reward_threshold - result = OfflineTrainer( - policy=policy, - buffer=buffer, - test_collector=test_collector, - max_epoch=args.epoch, - step_per_epoch=args.update_per_epoch, - episode_per_test=args.test_num, - batch_size=args.batch_size, - stop_fn=stop_fn, - save_best_fn=save_best_fn, - logger=logger, - ).run() + result = algorithm.run_training( + OfflineTrainingConfig( + buffer=buffer, + test_collector=test_collector, + max_epoch=args.epoch, + step_per_epoch=args.update_per_epoch, + episode_per_test=args.test_num, + batch_size=args.batch_size, + stop_fn=stop_fn, + save_best_fn=save_best_fn, + logger=logger, + ) + ) assert stop_fn(result.best_reward) diff --git a/tianshou/policy/__init__.py b/tianshou/policy/__init__.py index 21d7f0ebe..7be5f95ea 100644 --- a/tianshou/policy/__init__.py +++ b/tianshou/policy/__init__.py @@ -26,7 +26,7 @@ from tianshou.policy.imitation.cql import CQL from tianshou.policy.imitation.td3_bc import TD3BCPolicy from tianshou.policy.imitation.discrete_bcq import DiscreteBCQ -from 
tianshou.policy.imitation.discrete_cql import DiscreteCQLPolicy +from tianshou.policy.imitation.discrete_cql import DiscreteCQL from tianshou.policy.imitation.discrete_crr import DiscreteCRR from tianshou.policy.imitation.gail import GAIL from tianshou.policy.modelbased.psrl import PSRL diff --git a/tianshou/policy/base.py b/tianshou/policy/base.py index 83c0d2d13..0197ccdc6 100644 --- a/tianshou/policy/base.py +++ b/tianshou/policy/base.py @@ -34,6 +34,7 @@ if TYPE_CHECKING: from tianshou.trainer.base import ( + InfoStats, OfflineTrainer, OfflineTrainingConfig, OffPolicyTrainer, @@ -790,7 +791,7 @@ def compute_nstep_return( def create_trainer(self, config: TTrainingConfig) -> "Trainer": pass - def run_training(self, config: TTrainingConfig): + def run_training(self, config: TTrainingConfig) -> "InfoStats": trainer = self.create_trainer(config) return trainer.run() @@ -826,6 +827,12 @@ def process_buffer(self, buffer: TBuffer) -> TBuffer: """Pre-process the replay buffer to prepare for offline learning, e.g. to add new keys.""" return buffer + def run_training(self, config: "OfflineTrainingConfig") -> "InfoStats": + # NOTE: This override is required for correct typing when converting + # an algorithm to an offline algorithm using diamond inheritance + # (e.g. 
DiscreteCQL) in order to make it match first in the MRO + return super().run_training(config) + def create_trainer(self, config: "OfflineTrainingConfig") -> "OfflineTrainer": from tianshou.trainer.base import OfflineTrainer diff --git a/tianshou/policy/imitation/discrete_cql.py b/tianshou/policy/imitation/discrete_cql.py index 2a2836c1c..c34c06e79 100644 --- a/tianshou/policy/imitation/discrete_cql.py +++ b/tianshou/policy/imitation/discrete_cql.py @@ -1,7 +1,6 @@ from dataclasses import dataclass from typing import Any, TypeVar -import gymnasium as gym import numpy as np import torch import torch.nn.functional as F @@ -9,8 +8,8 @@ from tianshou.data import to_torch from tianshou.data.types import RolloutBatchProtocol from tianshou.policy import QRDQN -from tianshou.policy.base import TLearningRateScheduler -from tianshou.policy.modelfree.qrdqn import QRDQNTrainingStats +from tianshou.policy.base import OfflineAlgorithm, TLearningRateScheduler +from tianshou.policy.modelfree.qrdqn import QRDQNPolicy, QRDQNTrainingStats @dataclass(kw_only=True) @@ -22,54 +21,49 @@ class DiscreteCQLTrainingStats(QRDQNTrainingStats): TDiscreteCQLTrainingStats = TypeVar("TDiscreteCQLTrainingStats", bound=DiscreteCQLTrainingStats) -class DiscreteCQLPolicy(QRDQN): - """Implementation of discrete Conservative Q-Learning algorithm. arXiv:2006.04779. - - :param model: a model following the rules (s_B -> action_values_BA) - :param optim: a torch.optim for optimizing the model. - :param action_space: Env's action space. - :param min_q_weight: the weight for the cql loss. - :param discount_factor: in [0, 1]. - :param num_quantiles: the number of quantile midpoints in the inverse - cumulative distribution function of the value. - :param estimation_step: the number of steps to look ahead. - :param target_update_freq: the target network update frequency (0 if - you do not use the target network). - :param reward_normalization: normalize the **returns** to Normal(0, 1). 
- TODO: rename to return_normalization? - :param observation_space: Env's observation space. - :param lr_scheduler: if not None, will be called in `policy.update()`. - - .. seealso:: - Please refer to :class:`~tianshou.policy.QRDQNPolicy` for more detailed - explanation. - """ +# NOTE: This uses diamond inheritance to convert from off-policy to offline +class DiscreteCQL( + OfflineAlgorithm[QRDQNPolicy, TDiscreteCQLTrainingStats], + QRDQN[QRDQNPolicy, TDiscreteCQLTrainingStats], +): + """Implementation of discrete Conservative Q-Learning algorithm. arXiv:2006.04779.""" def __init__( self, *, - model: torch.nn.Module, + policy: QRDQNPolicy, optim: torch.optim.Optimizer, - action_space: gym.spaces.Discrete, min_q_weight: float = 10.0, discount_factor: float = 0.99, num_quantiles: int = 200, estimation_step: int = 1, target_update_freq: int = 0, reward_normalization: bool = False, - observation_space: gym.Space | None = None, lr_scheduler: TLearningRateScheduler | None = None, ) -> None: - super().__init__( - model=model, + """ + :param policy: the policy + :param optim: a torch.optim for optimizing the model. + :param min_q_weight: the weight for the cql loss. + :param discount_factor: in [0, 1]. + :param num_quantiles: the number of quantile midpoints in the inverse + cumulative distribution function of the value. + :param estimation_step: the number of steps to look ahead. + :param target_update_freq: the target network update frequency (0 if + you do not use the target network). + :param reward_normalization: normalize the **returns** to Normal(0, 1). + TODO: rename to return_normalization? + :param lr_scheduler: if not None, will be called in `policy.update()`. 
+ """ + QRDQN.__init__( + self, + policy=policy, optim=optim, - action_space=action_space, discount_factor=discount_factor, num_quantiles=num_quantiles, estimation_step=estimation_step, target_update_freq=target_update_freq, reward_normalization=reward_normalization, - observation_space=observation_space, lr_scheduler=lr_scheduler, ) self.min_q_weight = min_q_weight @@ -83,7 +77,7 @@ def _update_with_batch( self._periodically_update_lagged_network_weights() self.optim.zero_grad() weight = batch.pop("weight", 1.0) - all_dist = self(batch).logits + all_dist = self.policy(batch).logits act = to_torch(batch.act, dtype=torch.long, device=all_dist.device) curr_dist = all_dist[np.arange(len(act)), act, :].unsqueeze(2) target_dist = batch.returns.unsqueeze(1) @@ -99,7 +93,7 @@ def _update_with_batch( # blob/master/fqf_iqn_qrdqn/agent/qrdqn_agent.py L130 batch.weight = dist_diff.detach().abs().sum(-1).mean(1) # prio-buffer # add CQL loss - q = self.compute_q_value(all_dist, None) + q = self.policy.compute_q_value(all_dist, None) dataset_expec = q.gather(1, act.unsqueeze(1)).mean() negative_sampling = q.logsumexp(1).mean() min_q_loss = negative_sampling - dataset_expec From cec9a3c6ab22a5df467355ba36ce94ac436c13ef Mon Sep 17 00:00:00 2001 From: Dominik Jain Date: Fri, 14 Mar 2025 15:00:54 +0100 Subject: [PATCH 057/230] v2: Adapt TD3BC (using diamond inheritance to convert from off-policy to offline) and test_td3_bc --- examples/offline/d4rl_td3_bc.py | 4 +- test/offline/test_td3_bc.py | 48 ++++++++-------- tianshou/policy/__init__.py | 2 +- tianshou/policy/imitation/td3_bc.py | 89 ++++++++++++----------------- 4 files changed, 62 insertions(+), 81 deletions(-) diff --git a/examples/offline/d4rl_td3_bc.py b/examples/offline/d4rl_td3_bc.py index a9cc7d8ce..ed58fb179 100644 --- a/examples/offline/d4rl_td3_bc.py +++ b/examples/offline/d4rl_td3_bc.py @@ -14,7 +14,7 @@ from tianshou.data import Collector, CollectStats from tianshou.env import BaseVectorEnv, SubprocVectorEnv, 
VectorEnvNormObs from tianshou.exploration import GaussianNoise -from tianshou.policy import TD3BCPolicy +from tianshou.policy import TD3BC from tianshou.policy.base import Algorithm from tianshou.trainer import OfflineTrainer from tianshou.utils import TensorboardLogger, WandbLogger @@ -135,7 +135,7 @@ def test_td3_bc() -> None: critic2 = Critic(net_c2, device=args.device).to(args.device) critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr) - policy: TD3BCPolicy = TD3BCPolicy( + policy: TD3BC = TD3BC( actor=actor, policy_optim=actor_optim, critic=critic1, diff --git a/test/offline/test_td3_bc.py b/test/offline/test_td3_bc.py index 70aee4ba3..481809d2c 100644 --- a/test/offline/test_td3_bc.py +++ b/test/offline/test_td3_bc.py @@ -12,9 +12,10 @@ from tianshou.data import Collector, CollectStats, VectorReplayBuffer from tianshou.env import DummyVectorEnv from tianshou.exploration import GaussianNoise -from tianshou.policy import TD3BCPolicy +from tianshou.policy import TD3BC from tianshou.policy.base import Algorithm -from tianshou.trainer import OfflineTrainer +from tianshou.policy.modelfree.ddpg import DDPGPolicy +from tianshou.trainer import OfflineTrainingConfig from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Net from tianshou.utils.net.continuous import Actor, Critic @@ -127,8 +128,12 @@ def test_td3_bc(args: argparse.Namespace = get_args()) -> None: critic2 = Critic(net_c2, device=args.device).to(args.device) critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr) - policy: TD3BCPolicy = TD3BCPolicy( + policy = DDPGPolicy( actor=actor, + action_space=env.action_space, + ) + algorithm: TD3BC = TD3BC( + policy=policy, policy_optim=actor_optim, critic=critic1, critic_optim=critic1_optim, @@ -142,18 +147,17 @@ def test_td3_bc(args: argparse.Namespace = get_args()) -> None: noise_clip=args.noise_clip, alpha=args.alpha, estimation_step=args.n_step, - action_space=env.action_space, ) # load a 
previous policy if args.resume_path: - policy.load_state_dict(torch.load(args.resume_path, map_location=args.device)) + algorithm.load_state_dict(torch.load(args.resume_path, map_location=args.device)) print("Loaded agent from: ", args.resume_path) # collector # buffer has been gathered # train_collector = Collector[CollectStats](policy, train_envs, buffer, exploration_noise=True) - test_collector = Collector[CollectStats](policy, test_envs) + test_collector = Collector[CollectStats](algorithm, test_envs) # log t0 = datetime.datetime.now().strftime("%m%d_%H%M%S") log_file = f'seed_{args.seed}_{t0}-{args.task.replace("-", "_")}_td3_bc' @@ -168,23 +172,19 @@ def save_best_fn(policy: Algorithm) -> None: def stop_fn(mean_rewards: float) -> bool: return mean_rewards >= args.reward_threshold - # trainer - trainer = OfflineTrainer( - policy=policy, - buffer=buffer, - test_collector=test_collector, - max_epoch=args.epoch, - step_per_epoch=args.step_per_epoch, - episode_per_test=args.test_num, - batch_size=args.batch_size, - save_best_fn=save_best_fn, - stop_fn=stop_fn, - logger=logger, + # train + result = algorithm.run_training( + OfflineTrainingConfig( + buffer=buffer, + test_collector=test_collector, + max_epoch=args.epoch, + step_per_epoch=args.step_per_epoch, + episode_per_test=args.test_num, + batch_size=args.batch_size, + save_best_fn=save_best_fn, + stop_fn=stop_fn, + logger=logger, + ) ) - for epoch_stat in trainer: - print(f"Epoch: {epoch_stat.epoch}") - print(epoch_stat) - # print(info) - - assert stop_fn(epoch_stat.info_stat.best_reward) + assert stop_fn(result.best_reward) diff --git a/tianshou/policy/__init__.py b/tianshou/policy/__init__.py index 7be5f95ea..cab562423 100644 --- a/tianshou/policy/__init__.py +++ b/tianshou/policy/__init__.py @@ -24,7 +24,7 @@ from tianshou.policy.imitation.base import ImitationLearning from tianshou.policy.imitation.bcq import BCQ from tianshou.policy.imitation.cql import CQL -from tianshou.policy.imitation.td3_bc import 
TD3BCPolicy +from tianshou.policy.imitation.td3_bc import TD3BC from tianshou.policy.imitation.discrete_bcq import DiscreteBCQ from tianshou.policy.imitation.discrete_cql import DiscreteCQL from tianshou.policy.imitation.discrete_crr import DiscreteCRR diff --git a/tianshou/policy/imitation/td3_bc.py b/tianshou/policy/imitation/td3_bc.py index efde31a8c..cb634e690 100644 --- a/tianshou/policy/imitation/td3_bc.py +++ b/tianshou/policy/imitation/td3_bc.py @@ -1,7 +1,6 @@ from dataclasses import dataclass -from typing import Any, Literal, TypeVar +from typing import Any, TypeVar -import gymnasium as gym import torch import torch.nn.functional as F @@ -9,7 +8,8 @@ from tianshou.data.types import RolloutBatchProtocol from tianshou.exploration import BaseNoise, GaussianNoise from tianshou.policy import TD3 -from tianshou.policy.base import TLearningRateScheduler +from tianshou.policy.base import OfflineAlgorithm, TLearningRateScheduler +from tianshou.policy.modelfree.ddpg import DDPGPolicy from tianshou.policy.modelfree.td3 import TD3TrainingStats @@ -21,51 +21,17 @@ class TD3BCTrainingStats(TD3TrainingStats): TTD3BCTrainingStats = TypeVar("TTD3BCTrainingStats", bound=TD3BCTrainingStats) -class TD3BCPolicy(TD3[TTD3BCTrainingStats]): - """Implementation of TD3+BC. arXiv:2106.06860. - - :param actor: the actor network following the rules in - :class:`~tianshou.policy.BasePolicy`. (s -> actions) - :param policy_optim: the optimizer for actor network. - :param critic: the first critic network. (s, a -> Q(s, a)) - :param critic_optim: the optimizer for the first critic network. - :param action_space: Env's action space. Should be gym.spaces.Box. - :param critic2: the second critic network. (s, a -> Q(s, a)). - If None, use the same network as critic (via deepcopy). - :param critic2_optim: the optimizer for the second critic network. - If None, clone critic_optim to use for critic2.parameters(). - :param tau: param for soft update of the target network. 
- :param gamma: discount factor, in [0, 1]. - :param exploration_noise: add noise to action for exploration. - This is useful when solving "hard exploration" problems. - "default" is equivalent to GaussianNoise(sigma=0.1). - :param policy_noise: the noise used in updating policy network. - :param update_actor_freq: the update frequency of actor network. - :param noise_clip: the clipping range used in updating policy network. - :param alpha: the value of alpha, which controls the weight for TD3 learning - relative to behavior cloning. - :param observation_space: Env's observation space. - :param action_scaling: if True, scale the action from [-1, 1] to the range - of action_space. Only used if the action_space is continuous. - :param action_bound_method: method to bound action to range [-1, 1]. - Only used if the action_space is continuous. - :param lr_scheduler: a learning rate scheduler that adjusts the learning rate - in optimizer in each policy.update() - - .. seealso:: - - Please refer to :class:`~tianshou.policy.BasePolicy` for more detailed - explanation. - """ +# NOTE: This uses diamond inheritance to convert from off-policy to offline +class TD3BC(OfflineAlgorithm[DDPGPolicy, TTD3BCTrainingStats], TD3[TTD3BCTrainingStats]): + """Implementation of TD3+BC. arXiv:2106.06860.""" def __init__( self, *, - actor: torch.nn.Module, + policy: DDPGPolicy, policy_optim: torch.optim.Optimizer, critic: torch.nn.Module, critic_optim: torch.optim.Optimizer, - action_space: gym.Space, critic2: torch.nn.Module | None = None, critic2_optim: torch.optim.Optimizer | None = None, tau: float = 0.005, @@ -77,29 +43,44 @@ def __init__( # TODO: same name as alpha in SAC and REDQ, which also inherit from DDPGPolicy. Rename? 
alpha: float = 2.5, estimation_step: int = 1, - observation_space: gym.Space | None = None, - action_scaling: bool = True, - action_bound_method: Literal["clip"] | None = "clip", lr_scheduler: TLearningRateScheduler | None = None, ) -> None: - super().__init__( - actor=actor, + """ + :param policy: the policy + :param policy_optim: the optimizer for policy. + :param critic: the first critic network. (s, a -> Q(s, a)) + :param critic_optim: the optimizer for the first critic network. + :param critic2: the second critic network. (s, a -> Q(s, a)). + If None, use the same network as critic (via deepcopy). + :param critic2_optim: the optimizer for the second critic network. + If None, clone critic_optim to use for critic2.parameters(). + :param tau: param for soft update of the target network. + :param gamma: discount factor, in [0, 1]. + :param exploration_noise: add noise to action for exploration. + This is useful when solving "hard exploration" problems. + "default" is equivalent to GaussianNoise(sigma=0.1). + :param policy_noise: the noise used in updating policy network. + :param update_actor_freq: the update frequency of actor network. + :param noise_clip: the clipping range used in updating policy network. + :param alpha: the value of alpha, which controls the weight for TD3 learning + relative to behavior cloning. 
+ :param lr_scheduler: a learning rate scheduler that adjusts the learning rate + in optimizer in each policy.update() + """ + TD3.__init__( + self, + policy=policy, policy_optim=policy_optim, critic=critic, critic_optim=critic_optim, - action_space=action_space, critic2=critic2, critic2_optim=critic2_optim, tau=tau, gamma=gamma, - exploration_noise=exploration_noise, policy_noise=policy_noise, noise_clip=noise_clip, update_actor_freq=update_actor_freq, estimation_step=estimation_step, - action_scaling=action_scaling, - action_bound_method=action_bound_method, - observation_space=observation_space, lr_scheduler=lr_scheduler, ) self.alpha = alpha @@ -116,14 +97,14 @@ def _update_with_batch(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: # actor if self._cnt % self.update_actor_freq == 0: - act = self(batch, eps=0.0).act + act = self.policy(batch, eps=0.0).act q_value = self.critic(batch.obs, act) lmbda = self.alpha / q_value.abs().mean().detach() actor_loss = -lmbda * q_value.mean() + F.mse_loss(act, to_torch_as(batch.act, act)) - self.actor_optim.zero_grad() + self.policy_optim.zero_grad() actor_loss.backward() self._last = actor_loss.item() - self.actor_optim.step() + self.policy_optim.step() self._update_lagged_network_weights() self._cnt += 1 From 00badd4154f9ac550d2788673c8dd18d718e32f9 Mon Sep 17 00:00:00 2001 From: Dominik Jain Date: Fri, 14 Mar 2025 16:16:39 +0100 Subject: [PATCH 058/230] v2: Fix Algorithm updating/learning interface * Functions update and _update_with_batch (formerly learn) no longer have *args/**kwargs * Instead, the interfaces for the offline, off-policy and on-policy cases are properly differentiated --- CHANGELOG.md | 6 +- tianshou/policy/base.py | 122 ++++++++++++++-------- tianshou/policy/imitation/base.py | 2 - tianshou/policy/imitation/bcq.py | 2 - tianshou/policy/imitation/cql.py | 4 +- tianshou/policy/imitation/discrete_bcq.py | 2 - tianshou/policy/imitation/discrete_cql.py | 4 +- tianshou/policy/imitation/discrete_crr.py 
| 4 +- tianshou/policy/imitation/gail.py | 5 +- tianshou/policy/imitation/td3_bc.py | 4 +- tianshou/policy/modelbased/icm.py | 13 +-- tianshou/policy/modelbased/psrl.py | 10 +- tianshou/policy/modelfree/a2c.py | 8 +- tianshou/policy/modelfree/bdqn.py | 2 - tianshou/policy/modelfree/c51.py | 4 +- tianshou/policy/modelfree/ddpg.py | 2 +- tianshou/policy/modelfree/discrete_sac.py | 2 +- tianshou/policy/modelfree/dqn.py | 2 - tianshou/policy/modelfree/fqf.py | 2 - tianshou/policy/modelfree/iqn.py | 2 - tianshou/policy/modelfree/npg.py | 1 - tianshou/policy/modelfree/pg.py | 2 - tianshou/policy/modelfree/ppo.py | 7 +- tianshou/policy/modelfree/qrdqn.py | 4 +- tianshou/policy/modelfree/rainbow.py | 6 +- tianshou/policy/modelfree/redq.py | 2 +- tianshou/policy/modelfree/sac.py | 2 +- tianshou/policy/modelfree/td3.py | 2 +- tianshou/policy/modelfree/trpo.py | 3 +- tianshou/trainer/base.py | 33 +++--- 30 files changed, 130 insertions(+), 134 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 2b8af73d0..fca38ef19 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -53,7 +53,11 @@ * `PGPolicy` -> `Reinforce` * `DQNPolicy` -> `DQN` * `DDPGPolicy` -> `DDPG` - * The `Algorithm` abstraction can directly initiate the learning process via method `run_training`. + * Interface changes/improvements: + * The updating interface has been cleaned up: + * Functions `update` and `_update_with_batch` (formerly `learn`) no longer have `*args` and `**kwargs`. + * Instead, the interfaces for the offline, off-policy and on-policy cases are properly differentiated. + * New method `run_training`: The `Algorithm` abstraction can now directly initiate the learning process via this method. * Internal design improvements: * Introduced an abstraction for the alpha parameter (coefficient of the entropy term) in `SAC`, `DiscreteSAC` and other algorithms. 
diff --git a/tianshou/policy/base.py b/tianshou/policy/base.py index 0197ccdc6..2101c8b0c 100644 --- a/tianshou/policy/base.py +++ b/tianshou/policy/base.py @@ -518,34 +518,6 @@ def process_fn( """ return batch - @abstractmethod - def _update_with_batch( - self, - batch: RolloutBatchProtocol, - *args: Any, - **kwargs: Any, - ) -> TTrainingStats: - """Update policy with a given batch of data. - - :return: A dataclass object, including the data needed to be logged (e.g., loss). - - .. note:: - - In order to distinguish the collecting state, updating state and - testing state, you can check the policy state by ``self.training`` - and ``self.updating``. Please refer to :ref:`policy_state` for more - detailed explanation. - - .. warning:: - - If you use ``torch.distributions.Normal`` and - ``torch.distributions.Categorical`` to calculate the log_prob, - please be careful about the shape: Categorical distribution gives - "[batch_size]" shape while Normal distribution gives "[batch_size, - 1]" shape. The auto-broadcasting of numerical operation with torch - tensors will amplify this error. - """ - def post_process_fn( self, batch: BatchProtocol, @@ -570,11 +542,11 @@ def post_process_fn( "Prioritized replay is disabled for this batch.", ) - def update( + def _update( self, sample_size: int | None, buffer: ReplayBuffer | None, - **kwargs: Any, + update_with_batch_fn: Callable[[RolloutBatchProtocol], TTrainingStats], ) -> TTrainingStats: """Update the policy network and replay buffer. @@ -588,15 +560,12 @@ def update( :param sample_size: 0 means it will extract all the data from the buffer, otherwise it will sample a batch with given sample_size. None also means it will extract all the data from the buffer, but it will be shuffled - first. TODO: remove the option for 0? + first. :param buffer: the corresponding replay buffer. :return: A dataclass object containing the data needed to be logged (e.g., loss) from ``policy.learn()``. """ - # TODO: when does this happen? 
- # -> this happens never in practice as update is either called with a collector buffer or an assert before - if not self.policy.is_within_training_step: raise RuntimeError( f"update() was called outside of a training step as signalled by {self.policy.is_within_training_step=} " @@ -611,7 +580,7 @@ def update( self.updating = True batch = self.process_fn(batch, buffer, indices) with torch_train_mode(self): - training_stat = self._update_with_batch(batch, **kwargs) + training_stat = update_with_batch_fn(batch) self.post_process_fn(batch, buffer, indices) if self.lr_scheduler is not None: self.lr_scheduler.step() @@ -806,6 +775,25 @@ def create_trainer(self, config: "OnPolicyTrainingConfig") -> "OnPolicyTrainer": return OnPolicyTrainer(self, config) + @abstractmethod + def _update_with_batch( + self, batch: RolloutBatchProtocol, batch_size: int | None, repeat: int + ) -> TTrainingStats: + pass + + def update( + self, + buffer: ReplayBuffer, + batch_size: int, + repeat: int, + ) -> TTrainingStats: + update_with_batch_fn = lambda batch: self._update_with_batch( + batch=batch, batch_size=batch_size, repeat=repeat + ) + return super()._update( + sample_size=0, buffer=buffer, update_with_batch_fn=update_with_batch_fn + ) + class OffPolicyAlgorithm( Algorithm[TPolicy, "OffPolicyTrainingConfig", TTrainingStats], @@ -817,6 +805,29 @@ def create_trainer(self, config: "OffPolicyTrainingConfig") -> "OffPolicyTrainer return OffPolicyTrainer(self, config) + @abstractmethod + def _update_with_batch( + self, + batch: RolloutBatchProtocol, + ) -> TTrainingStats: + """Performs an update step based on the given batch of data, updating the network + parameters. + + :param batch: the batch of data + :return: a dataclas object containing statistics on the learning process, including + the data needed to be logged (e.g. loss values). 
+ """ + + def update( + self, + buffer: ReplayBuffer, + sample_size: int | None, + ) -> TTrainingStats: + update_with_batch_fn = lambda batch: self._update_with_batch(batch) + return super()._update( + sample_size=sample_size, buffer=buffer, update_with_batch_fn=update_with_batch_fn + ) + class OfflineAlgorithm( Algorithm[TPolicy, "OfflineTrainingConfig", TTrainingStats], @@ -838,6 +849,29 @@ def create_trainer(self, config: "OfflineTrainingConfig") -> "OfflineTrainer": return OfflineTrainer(self, config) + @abstractmethod + def _update_with_batch( + self, + batch: RolloutBatchProtocol, + ) -> TTrainingStats: + """Performs an update step based on the given batch of data, updating the network + parameters. + + :param batch: the batch of data + :return: a dataclas object containing statistics on the learning process, including + the data needed to be logged (e.g. loss values). + """ + + def update( + self, + buffer: ReplayBuffer, + sample_size: int | None, + ) -> TTrainingStats: + update_with_batch_fn = lambda batch: self._update_with_batch(batch) + return super()._update( + sample_size=sample_size, buffer=buffer, update_with_batch_fn=update_with_batch_fn + ) + TWrappedAlgorthmTrainingStats = TypeVar("TWrappedAlgorthmTrainingStats", bound=TrainingStats) @@ -874,13 +908,12 @@ def post_process_fn( self.wrapped_algorithm.post_process_fn(batch, buffer, indices) def _update_with_batch( - self, - batch: RolloutBatchProtocol, - *args: Any, - **kwargs: Any, - ) -> TWrappedAlgorthmTrainingStats: + self, batch: RolloutBatchProtocol, batch_size: int | None, repeat: int + ) -> TTrainingStats: """Performs the update as defined by the wrapped algorithm.""" - return self.wrapped_algorithm._update_with_batch(batch, **kwargs) + return self.wrapped_algorithm._update_with_batch( + batch, batch_size=batch_size, repeat=repeat + ) class OffPolicyWrapperAlgorithm( @@ -914,14 +947,13 @@ def post_process_fn( """Performs the batch post-processing as defined by the wrapped algorithm.""" 
self.wrapped_algorithm.post_process_fn(batch, buffer, indices) + @abstractmethod def _update_with_batch( self, batch: RolloutBatchProtocol, - *args: Any, - **kwargs: Any, - ) -> TWrappedAlgorthmTrainingStats: + ) -> TTrainingStats: """Performs the update as defined by the wrapped algorithm.""" - return self.wrapped_algorithm._update_with_batch(batch, **kwargs) + return self.wrapped_algorithm._update_with_batch(batch) class RandomActionPolicy(Policy): diff --git a/tianshou/policy/imitation/base.py b/tianshou/policy/imitation/base.py index eb6882440..51ebd92e1 100644 --- a/tianshou/policy/imitation/base.py +++ b/tianshou/policy/imitation/base.py @@ -118,8 +118,6 @@ def __init__( def _update_with_batch( self, batch: RolloutBatchProtocol, - *ags: Any, - **kwargs: Any, ) -> TImitationTrainingStats: self.optim.zero_grad() if self.policy.action_type == "continuous": # regression diff --git a/tianshou/policy/imitation/bcq.py b/tianshou/policy/imitation/bcq.py index e8c6778f2..3f171c0c8 100644 --- a/tianshou/policy/imitation/bcq.py +++ b/tianshou/policy/imitation/bcq.py @@ -159,8 +159,6 @@ def __init__( def _update_with_batch( self, batch: RolloutBatchProtocol, - *args: Any, - **kwargs: Any, ) -> TBCQTrainingStats: # batch: obs, act, rew, done, obs_next. 
(numpy array) # (batch_size, state_dim) diff --git a/tianshou/policy/imitation/cql.py b/tianshou/policy/imitation/cql.py index eb58b59c8..12cf8577e 100644 --- a/tianshou/policy/imitation/cql.py +++ b/tianshou/policy/imitation/cql.py @@ -1,6 +1,6 @@ from copy import deepcopy from dataclasses import dataclass -from typing import Any, TypeVar, cast +from typing import TypeVar, cast import numpy as np import torch @@ -205,7 +205,7 @@ def process_buffer(self, buffer: TBuffer) -> TBuffer: ) return buffer - def _update_with_batch(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> TCQLTrainingStats: # type: ignore + def _update_with_batch(self, batch: RolloutBatchProtocol) -> TCQLTrainingStats: # type: ignore device = torch_device(self.policy) batch: Batch = to_torch(batch, dtype=torch.float, device=device) obs, act, rew, obs_next = batch.obs, batch.act, batch.rew, batch.obs_next diff --git a/tianshou/policy/imitation/discrete_bcq.py b/tianshou/policy/imitation/discrete_bcq.py index 5cce77fc4..d13eca89b 100644 --- a/tianshou/policy/imitation/discrete_bcq.py +++ b/tianshou/policy/imitation/discrete_bcq.py @@ -192,8 +192,6 @@ def _target_q(self, buffer: ReplayBuffer, indices: np.ndarray) -> torch.Tensor: def _update_with_batch( self, batch: RolloutBatchProtocol, - *args: Any, - **kwargs: Any, ) -> TDiscreteBCQTrainingStats: if self._iter % self.freq == 0: self._update_lagged_network_weights() diff --git a/tianshou/policy/imitation/discrete_cql.py b/tianshou/policy/imitation/discrete_cql.py index c34c06e79..c4703dcff 100644 --- a/tianshou/policy/imitation/discrete_cql.py +++ b/tianshou/policy/imitation/discrete_cql.py @@ -1,5 +1,5 @@ from dataclasses import dataclass -from typing import Any, TypeVar +from typing import TypeVar import numpy as np import torch @@ -71,8 +71,6 @@ def __init__( def _update_with_batch( self, batch: RolloutBatchProtocol, - *args: Any, - **kwargs: Any, ) -> TDiscreteCQLTrainingStats: self._periodically_update_lagged_network_weights() 
self.optim.zero_grad() diff --git a/tianshou/policy/imitation/discrete_crr.py b/tianshou/policy/imitation/discrete_crr.py index 0b5898d1a..e8da15491 100644 --- a/tianshou/policy/imitation/discrete_crr.py +++ b/tianshou/policy/imitation/discrete_crr.py @@ -1,5 +1,5 @@ from dataclasses import dataclass -from typing import Any, Literal, TypeVar +from typing import Literal, TypeVar import numpy as np import torch @@ -112,8 +112,6 @@ def process_fn( def _update_with_batch( # type: ignore self, batch: RolloutBatchProtocol, - *args: Any, - **kwargs: Any, ) -> TDiscreteCRRTrainingStats: if self._target and self._iter % self._freq == 0: self._update_lagged_network_weights() diff --git a/tianshou/policy/imitation/gail.py b/tianshou/policy/imitation/gail.py index 4b0846719..4b3d1829d 100644 --- a/tianshou/policy/imitation/gail.py +++ b/tianshou/policy/imitation/gail.py @@ -1,5 +1,5 @@ from dataclasses import dataclass -from typing import Any, TypeVar +from typing import TypeVar import numpy as np import torch @@ -138,7 +138,6 @@ def _update_with_batch( # type: ignore batch: RolloutBatchProtocol, batch_size: int | None, repeat: int, - **kwargs: Any, ) -> TGailTrainingStats: # update discriminator losses = [] @@ -159,7 +158,7 @@ def _update_with_batch( # type: ignore acc_pis.append((logits_pi < 0).float().mean().item()) acc_exps.append((logits_exp > 0).float().mean().item()) # update policy - ppo_loss_stat = super()._update_with_batch(batch, batch_size, repeat, **kwargs) + ppo_loss_stat = super()._update_with_batch(batch, batch_size, repeat) disc_losses_summary = SequenceSummaryStats.from_sequence(losses) acc_pi_summary = SequenceSummaryStats.from_sequence(acc_pis) diff --git a/tianshou/policy/imitation/td3_bc.py b/tianshou/policy/imitation/td3_bc.py index cb634e690..5877df467 100644 --- a/tianshou/policy/imitation/td3_bc.py +++ b/tianshou/policy/imitation/td3_bc.py @@ -1,5 +1,5 @@ from dataclasses import dataclass -from typing import Any, TypeVar +from typing import TypeVar 
import torch import torch.nn.functional as F @@ -85,7 +85,7 @@ def __init__( ) self.alpha = alpha - def _update_with_batch(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> TTD3BCTrainingStats: # type: ignore + def _update_with_batch(self, batch: RolloutBatchProtocol) -> TTD3BCTrainingStats: # type: ignore # critic 1&2 td1, critic1_loss = self._minimize_critic_squared_loss( batch, self.critic, self.critic_optim diff --git a/tianshou/policy/modelbased/icm.py b/tianshou/policy/modelbased/icm.py index 7de507396..a192d972b 100644 --- a/tianshou/policy/modelbased/icm.py +++ b/tianshou/policy/modelbased/icm.py @@ -1,5 +1,3 @@ -from typing import Any - import numpy as np import torch import torch.nn.functional as F @@ -156,10 +154,8 @@ def post_process_fn( def _update_with_batch( self, batch: RolloutBatchProtocol, - *args: Any, - **kwargs: Any, ) -> ICMTrainingStats: - wrapped_stats = super()._update_with_batch(batch, *args, **kwargs) + wrapped_stats = super()._update_with_batch(batch) return self._icm_update(batch, wrapped_stats) @@ -220,10 +216,7 @@ def post_process_fn( self._icm_postprocess_batch(batch) def _update_with_batch( - self, - batch: RolloutBatchProtocol, - *args: Any, - **kwargs: Any, + self, batch: RolloutBatchProtocol, batch_size: int | None, repeat: int ) -> ICMTrainingStats: - wrapped_stats = super()._update_with_batch(batch, *args, **kwargs) + wrapped_stats = super()._update_with_batch(batch, batch_size=batch_size, repeat=repeat) return self._icm_update(batch, wrapped_stats) diff --git a/tianshou/policy/modelbased/psrl.py b/tianshou/policy/modelbased/psrl.py index 987a41f56..f1eae2102 100644 --- a/tianshou/policy/modelbased/psrl.py +++ b/tianshou/policy/modelbased/psrl.py @@ -237,11 +237,13 @@ def __init__( self._add_done_loop = add_done_loop def _update_with_batch( - self, - batch: RolloutBatchProtocol, - *args: Any, - **kwargs: Any, + self, batch: RolloutBatchProtocol, batch_size: int | None, repeat: int ) -> TPSRLTrainingStats: + # 
NOTE: In contrast to other on-policy algorithms, this algorithm ignores + # the batch_size and repeat arguments. + # PSRL, being a Bayesian approach, updates its posterior distribution of + # the MDP parameters based on the collected transition data as a whole, + # rather than performing gradient-based updates that benefit from mini-batching. n_s, n_a = self.policy.model.n_state, self.policy.model.n_action trans_count = np.zeros((n_s, n_a, n_s)) rew_sum = np.zeros((n_s, n_a)) diff --git a/tianshou/policy/modelfree/a2c.py b/tianshou/policy/modelfree/a2c.py index e67b04181..6f46d862b 100644 --- a/tianshou/policy/modelfree/a2c.py +++ b/tianshou/policy/modelfree/a2c.py @@ -1,6 +1,6 @@ from abc import ABC from dataclasses import dataclass -from typing import Any, Generic, TypeVar, cast +from typing import Generic, TypeVar, cast import numpy as np import torch @@ -169,15 +169,11 @@ def process_fn( batch.act = to_torch_as(batch.act, batch.v_s) return batch - # TODO: mypy complains b/c signature is different from superclass, although - # it's compatible. Can this be fixed? 
- def _update_with_batch( # type: ignore + def _update_with_batch( self, batch: RolloutBatchProtocol, batch_size: int | None, repeat: int, - *args: Any, - **kwargs: Any, ) -> TA2CTrainingStats: losses, actor_losses, vf_losses, ent_losses = [], [], [], [] split_batch_size = batch_size or -1 diff --git a/tianshou/policy/modelfree/bdqn.py b/tianshou/policy/modelfree/bdqn.py index bee04a729..f8818750d 100644 --- a/tianshou/policy/modelfree/bdqn.py +++ b/tianshou/policy/modelfree/bdqn.py @@ -185,8 +185,6 @@ def process_fn( def _update_with_batch( self, batch: RolloutBatchProtocol, - *args: Any, - **kwargs: Any, ) -> TBDQNTrainingStats: self._periodically_update_lagged_network_weights() self.optim.zero_grad() diff --git a/tianshou/policy/modelfree/c51.py b/tianshou/policy/modelfree/c51.py index fdbbb9521..22934f561 100644 --- a/tianshou/policy/modelfree/c51.py +++ b/tianshou/policy/modelfree/c51.py @@ -1,5 +1,5 @@ from dataclasses import dataclass -from typing import Any, Generic, TypeVar +from typing import Generic, TypeVar import gymnasium as gym import numpy as np @@ -126,8 +126,6 @@ def _target_dist(self, batch: RolloutBatchProtocol) -> torch.Tensor: def _update_with_batch( self, batch: RolloutBatchProtocol, - *args: Any, - **kwargs: Any, ) -> TC51TrainingStats: self._periodically_update_lagged_network_weights() self.optim.zero_grad() diff --git a/tianshou/policy/modelfree/ddpg.py b/tianshou/policy/modelfree/ddpg.py index 43a21ddf2..cea89ce86 100644 --- a/tianshou/policy/modelfree/ddpg.py +++ b/tianshou/policy/modelfree/ddpg.py @@ -341,7 +341,7 @@ def _target_q_compute_action(self, obs_batch: Batch) -> ActBatchProtocol: # compute the action using the lagged actor network return self.policy(obs_batch, model=self.actor_old) - def _update_with_batch(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> TDDPGTrainingStats: # type: ignore + def _update_with_batch(self, batch: RolloutBatchProtocol) -> TDDPGTrainingStats: # type: ignore # critic td, critic_loss = 
self._minimize_critic_squared_loss(batch, self.critic, self.critic_optim) batch.weight = td # prio-buffer diff --git a/tianshou/policy/modelfree/discrete_sac.py b/tianshou/policy/modelfree/discrete_sac.py index be634a0fd..fbbd3cbe5 100644 --- a/tianshou/policy/modelfree/discrete_sac.py +++ b/tianshou/policy/modelfree/discrete_sac.py @@ -133,7 +133,7 @@ def _target_q_compute_value( ) return target_q.sum(dim=-1) + self.alpha.value * dist.entropy() - def _update_with_batch(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> TDiscreteSACTrainingStats: # type: ignore + def _update_with_batch(self, batch: RolloutBatchProtocol) -> TDiscreteSACTrainingStats: # type: ignore weight = batch.pop("weight", 1.0) target_q = batch.returns.flatten() act = to_torch(batch.act[:, np.newaxis], device=target_q.device, dtype=torch.long) diff --git a/tianshou/policy/modelfree/dqn.py b/tianshou/policy/modelfree/dqn.py index 886e61dd9..a42ba1fea 100644 --- a/tianshou/policy/modelfree/dqn.py +++ b/tianshou/policy/modelfree/dqn.py @@ -309,8 +309,6 @@ def _target_q(self, buffer: ReplayBuffer, indices: np.ndarray) -> torch.Tensor: def _update_with_batch( self, batch: RolloutBatchProtocol, - *args: Any, - **kwargs: Any, ) -> TDQNTrainingStats: self._periodically_update_lagged_network_weights() self.optim.zero_grad() diff --git a/tianshou/policy/modelfree/fqf.py b/tianshou/policy/modelfree/fqf.py index 3d84b90eb..259fd8e3e 100644 --- a/tianshou/policy/modelfree/fqf.py +++ b/tianshou/policy/modelfree/fqf.py @@ -160,8 +160,6 @@ def _target_q(self, buffer: ReplayBuffer, indices: np.ndarray) -> torch.Tensor: def _update_with_batch( self, batch: RolloutBatchProtocol, - *args: Any, - **kwargs: Any, ) -> TFQFTrainingStats: self._periodically_update_lagged_network_weights() weight = batch.pop("weight", 1.0) diff --git a/tianshou/policy/modelfree/iqn.py b/tianshou/policy/modelfree/iqn.py index 0d0b71dfd..b2d668d0e 100644 --- a/tianshou/policy/modelfree/iqn.py +++ 
b/tianshou/policy/modelfree/iqn.py @@ -129,8 +129,6 @@ def __init__( def _update_with_batch( self, batch: RolloutBatchProtocol, - *args: Any, - **kwargs: Any, ) -> TIQNTrainingStats: self._periodically_update_lagged_network_weights() self.optim.zero_grad() diff --git a/tianshou/policy/modelfree/npg.py b/tianshou/policy/modelfree/npg.py index 0f516534d..f2aa4d1a2 100644 --- a/tianshou/policy/modelfree/npg.py +++ b/tianshou/policy/modelfree/npg.py @@ -100,7 +100,6 @@ def _update_with_batch( # type: ignore batch: Batch, batch_size: int | None, repeat: int, - **kwargs: Any, ) -> TNPGTrainingStats: actor_losses, vf_losses, kls = [], [], [] split_batch_size = batch_size or -1 diff --git a/tianshou/policy/modelfree/pg.py b/tianshou/policy/modelfree/pg.py index 43937fdf2..81f7796f5 100644 --- a/tianshou/policy/modelfree/pg.py +++ b/tianshou/policy/modelfree/pg.py @@ -295,8 +295,6 @@ def _update_with_batch( # type: ignore batch: BatchWithReturnsProtocol, batch_size: int | None, repeat: int, - *args: Any, - **kwargs: Any, ) -> TPGTrainingStats: losses = [] split_batch_size = batch_size or -1 diff --git a/tianshou/policy/modelfree/ppo.py b/tianshou/policy/modelfree/ppo.py index 86e8ecb5a..3d7ec4f4a 100644 --- a/tianshou/policy/modelfree/ppo.py +++ b/tianshou/policy/modelfree/ppo.py @@ -1,6 +1,6 @@ from collections.abc import Sequence from dataclasses import dataclass -from typing import Any, Generic, Self, TypeVar, cast +from typing import Generic, Self, TypeVar, cast import numpy as np import torch @@ -142,14 +142,11 @@ def process_fn( batch.logp_old = torch.cat(logp_old, dim=0).flatten() return cast(LogpOldProtocol, batch) - # TODO: why does mypy complain? 
- def _update_with_batch( # type: ignore + def _update_with_batch( self, batch: RolloutBatchProtocol, batch_size: int | None, repeat: int, - *args: Any, - **kwargs: Any, ) -> TPPOTrainingStats: losses, clip_losses, vf_losses, ent_losses = [], [], [], [] gradient_steps = 0 diff --git a/tianshou/policy/modelfree/qrdqn.py b/tianshou/policy/modelfree/qrdqn.py index ca67b82fa..28bad5ac6 100644 --- a/tianshou/policy/modelfree/qrdqn.py +++ b/tianshou/policy/modelfree/qrdqn.py @@ -1,6 +1,6 @@ import warnings from dataclasses import dataclass -from typing import Any, Generic, TypeVar +from typing import Generic, TypeVar import numpy as np import torch @@ -98,8 +98,6 @@ def _target_q(self, buffer: ReplayBuffer, indices: np.ndarray) -> torch.Tensor: def _update_with_batch( self, batch: RolloutBatchProtocol, - *args: Any, - **kwargs: Any, ) -> TQRDQNTrainingStats: self._periodically_update_lagged_network_weights() self.optim.zero_grad() diff --git a/tianshou/policy/modelfree/rainbow.py b/tianshou/policy/modelfree/rainbow.py index 5a93228b4..3e95bcf97 100644 --- a/tianshou/policy/modelfree/rainbow.py +++ b/tianshou/policy/modelfree/rainbow.py @@ -1,5 +1,5 @@ from dataclasses import dataclass -from typing import Any, TypeVar +from typing import TypeVar from torch import nn @@ -46,10 +46,8 @@ def _sample_noise(model: nn.Module) -> bool: def _update_with_batch( self, batch: RolloutBatchProtocol, - *args: Any, - **kwargs: Any, ) -> TRainbowTrainingStats: self._sample_noise(self.policy.model) if self.use_target_network and self._sample_noise(self.model_old): self.model_old.train() # so that NoisyLinear takes effect - return super()._update_with_batch(batch, **kwargs) + return super()._update_with_batch(batch) diff --git a/tianshou/policy/modelfree/redq.py b/tianshou/policy/modelfree/redq.py index 8d774e99a..07d247b4d 100644 --- a/tianshou/policy/modelfree/redq.py +++ b/tianshou/policy/modelfree/redq.py @@ -187,7 +187,7 @@ def _target_q_compute_value( return target_q - def 
_update_with_batch(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> TREDQTrainingStats: # type: ignore + def _update_with_batch(self, batch: RolloutBatchProtocol) -> TREDQTrainingStats: # type: ignore # critic ensemble weight = getattr(batch, "weight", 1.0) current_qs = self.critic(batch.obs, batch.act).flatten(1) diff --git a/tianshou/policy/modelfree/sac.py b/tianshou/policy/modelfree/sac.py index 98508e06a..2e65cd9b1 100644 --- a/tianshou/policy/modelfree/sac.py +++ b/tianshou/policy/modelfree/sac.py @@ -268,7 +268,7 @@ def _target_q_compute_value( min_q_value = super()._target_q_compute_value(obs_batch, act_batch) return min_q_value - self.alpha.value * act_batch.log_prob - def _update_with_batch(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> TSACTrainingStats: # type: ignore + def _update_with_batch(self, batch: RolloutBatchProtocol) -> TSACTrainingStats: # type: ignore # critic 1&2 td1, critic1_loss = self._minimize_critic_squared_loss( batch, self.critic, self.critic_optim diff --git a/tianshou/policy/modelfree/td3.py b/tianshou/policy/modelfree/td3.py index e8398c289..9998d89de 100644 --- a/tianshou/policy/modelfree/td3.py +++ b/tianshou/policy/modelfree/td3.py @@ -176,7 +176,7 @@ def _target_q_compute_action(self, obs_batch: Batch) -> ActStateBatchProtocol: act_batch.act = act_ return act_batch - def _update_with_batch(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> TTD3TrainingStats: # type: ignore + def _update_with_batch(self, batch: RolloutBatchProtocol) -> TTD3TrainingStats: # type: ignore # critic 1&2 td1, critic1_loss = self._minimize_critic_squared_loss( batch, self.critic, self.critic_optim diff --git a/tianshou/policy/modelfree/trpo.py b/tianshou/policy/modelfree/trpo.py index f84433f85..cc68ace02 100644 --- a/tianshou/policy/modelfree/trpo.py +++ b/tianshou/policy/modelfree/trpo.py @@ -1,6 +1,6 @@ import warnings from dataclasses import dataclass -from typing import Any, TypeVar +from typing 
import TypeVar import torch import torch.nn.functional as F @@ -84,7 +84,6 @@ def _update_with_batch( # type: ignore batch: Batch, batch_size: int | None, repeat: int, - **kwargs: Any, ) -> TTRPOTrainingStats: actor_losses, vf_losses, step_sizes, kls = [], [], [], [] split_batch_size = batch_size or -1 diff --git a/tianshou/trainer/base.py b/tianshou/trainer/base.py index b85536df3..cebaee91c 100644 --- a/tianshou/trainer/base.py +++ b/tianshou/trainer/base.py @@ -28,7 +28,7 @@ from collections.abc import Callable from dataclasses import asdict, dataclass from functools import partial -from typing import TYPE_CHECKING, Generic, TypeVar +from typing import Generic, TypeVar import numpy as np import tqdm @@ -46,7 +46,13 @@ ) from tianshou.data.buffer.base import MalformedBufferError from tianshou.data.collector import BaseCollector, CollectStatsBase -from tianshou.policy.base import TrainingStats +from tianshou.policy.base import ( + Algorithm, + OfflineAlgorithm, + OffPolicyAlgorithm, + OnPolicyAlgorithm, + TrainingStats, +) from tianshou.utils import ( BaseLogger, LazyLogger, @@ -55,9 +61,6 @@ from tianshou.utils.logging import set_numerical_fields_to_precision from tianshou.utils.torch_utils import policy_within_training_step -if TYPE_CHECKING: - from tianshou.policy import Algorithm - log = logging.getLogger(__name__) @@ -317,9 +320,10 @@ class OfflineTrainingConfig(TrainingConfig): TTrainingConfig = TypeVar("TTrainingConfig", bound=TrainingConfig) TOnlineTrainingConfig = TypeVar("TOnlineTrainingConfig", bound=OnlineTrainingConfig) +TAlgorithm = TypeVar("TAlgorithm", bound=Algorithm) -class Trainer(Generic[TTrainingConfig], ABC): +class Trainer(Generic[TAlgorithm, TTrainingConfig], ABC): """ Base class for trainers in Tianshou, which orchestrate the training process and call upon an RL algorithm's specific network updating logic to perform the actual gradient updates. 
@@ -330,7 +334,7 @@ class Trainer(Generic[TTrainingConfig], ABC): def __init__( self, - policy: "Algorithm", + policy: TAlgorithm, config: TTrainingConfig, ): self.algorithm = policy @@ -704,7 +708,7 @@ def run( return self._create_info_stats() -class OfflineTrainer(Trainer[OfflineTrainingConfig]): +class OfflineTrainer(Trainer[OfflineAlgorithm, OfflineTrainingConfig]): """An offline trainer, which samples mini-batches from a given buffer and passes them to the algorithm's update function. """ @@ -757,7 +761,9 @@ def _create_epoch_pbar_data_dict( return {} -class OnlineTrainer(Trainer[TOnlineTrainingConfig], Generic[TOnlineTrainingConfig], ABC): +class OnlineTrainer( + Trainer[TAlgorithm, TOnlineTrainingConfig], Generic[TAlgorithm, TOnlineTrainingConfig], ABC +): """ An online trainer, which collects data from the environment in each training step and uses the collected data to perform an update step, the nature of which is to be defined @@ -954,7 +960,7 @@ def _create_epoch_pbar_data_dict( return result -class OffPolicyTrainer(OnlineTrainer[OffPolicyTrainingConfig]): +class OffPolicyTrainer(OnlineTrainer[OffPolicyAlgorithm, OffPolicyTrainingConfig]): """An off-policy trainer, which samples mini-batches from the buffer of collected data and passes them to algorithm's `update` function. @@ -1004,7 +1010,7 @@ def _sample_and_update(self, buffer: ReplayBuffer) -> TrainingStats: return update_stat -class OnPolicyTrainer(OnlineTrainer[OnPolicyTrainingConfig]): +class OnPolicyTrainer(OnlineTrainer[OnPolicyAlgorithm, OnPolicyTrainingConfig]): """An on-policy trainer, which passes the entire buffer to the algorithm's `update` methods and resets the buffer thereafter. @@ -1023,12 +1029,7 @@ def _update_step( f"Performing on-policy update on buffer of length {len(self.config.train_collector.buffer)}", ) training_stat = self.algorithm.update( - sample_size=0, buffer=self.config.train_collector.buffer, - # Note: sample_size is None, so the whole buffer is used for the update. 
- # The kwargs are in the end passed to the .learn method, which uses - # batch_size to iterate through the buffer in mini-batches - # Off-policy algos typically don't use the batch_size kwarg at all batch_size=self.config.batch_size, repeat=self.config.repeat_per_collect, ) From 20231bd6b57d394427caf0ea6702861f59268743 Mon Sep 17 00:00:00 2001 From: Dominik Jain Date: Fri, 14 Mar 2025 18:04:17 +0100 Subject: [PATCH 059/230] v2: Adapt multi-agent RL algorithms * MultiAgentPolicyManager is replaced by MultiAgent[On|Off]PolicyAlgorithm in conjunction with helper class MARLDispatcher and MultiAgentPolicy * Adapt examples/tests: tic_tac_toe, pistonball, pistonball_continuous * Remove agent_id attribute of Algorithm (formerly BasePolicy), as this can be entirely internal to MARL algorithms * MARLRandomPolicy is replaced by MARLRandomDiscreteMaskedOffPolicyAlgorithm --- CHANGELOG.md | 12 +- test/pettingzoo/pistonball.py | 64 ++--- test/pettingzoo/pistonball_continuous.py | 67 ++--- test/pettingzoo/tic_tac_toe.py | 100 ++++---- tianshou/env/pettingzoo_env.py | 2 +- tianshou/policy/__init__.py | 4 +- tianshou/policy/base.py | 4 - tianshou/policy/multiagent/mapolicy.py | 296 ++++++++++++++--------- tianshou/policy/random.py | 69 +++--- 9 files changed, 360 insertions(+), 258 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index fca38ef19..b8363a14b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -49,10 +49,13 @@ * `Policy` and `Algorithm` abstractions (formerly unified in `BasePolicy`): * We now conceptually differentiate between the learning algorithm and the policy being optimised: * The abstraction `BasePolicy` is thus replaced by `Algorithm` and `Policy`. - Migration information (`BasePolicy` -> `Algorithm`): + Migration information: The instantiation of a policy is replaced by the instantiation of an `Algorithm`, + which is passed a `Policy`. In most cases, the former policy class name `Policy` is replaced by algorithm + class ``; exceptions are noted below. 
* `PGPolicy` -> `Reinforce` - * `DQNPolicy` -> `DQN` - * `DDPGPolicy` -> `DDPG` + * `MultiAgentPolicyManager` -> `MultiAgentOnPolicyAlgorithm`, `MultiAgentOffPolicyAlgorithm` + * `MARLRandomPolicy` -> `MARLRandomDiscreteMaskedOffPolicyAlgorithm` + For the respective subtype of `Policy` to use, see the respective algorithm class' constructor. * Interface changes/improvements: * The updating interface has been cleaned up: * Functions `update` and `_update_with_batch` (formerly `learn`) no longer have `*args` and `**kwargs`. @@ -70,6 +73,9 @@ making the codebase more consistent while preserving the original functionality. * Introduced a policy base class `ContinuousPolicyWithExplorationNoise` which encapsulates noise generation for continuous action spaces (e.g. relevant to `DDPG`, `SAC` and `REDQ`). + * Multi-agent RL methods are now differentiated by the type of the sub-algorithms being employed + (`MultiAgentOnPolicyAlgorithm`, `MultiAgentOffPolicyAlgorithm`), which renders all interfaces clean. + Helper class `MARLDispatcher` has been factored out to manage the dispatching of data to the respective agents. 
* Fixed issues in the class hierarchy (particularly critical violations of the Liskov substitution principle): * Introduced base classes (to retain factorization without abusive inheritance): * `ActorCriticOnPolicyAlgorithm` diff --git a/test/pettingzoo/pistonball.py b/test/pettingzoo/pistonball.py index 1849717eb..ce6f743bf 100644 --- a/test/pettingzoo/pistonball.py +++ b/test/pettingzoo/pistonball.py @@ -11,8 +11,9 @@ from tianshou.data import Collector, CollectStats, InfoStats, VectorReplayBuffer from tianshou.env import DummyVectorEnv from tianshou.env.pettingzoo_env import PettingZooEnv -from tianshou.policy import DQN, Algorithm, MultiAgentPolicyManager -from tianshou.trainer import OffPolicyTrainer +from tianshou.policy import DQN, Algorithm, MultiAgentOffPolicyAlgorithm +from tianshou.policy.modelfree.dqn import DQNPolicy +from tianshou.trainer import OffPolicyTrainingConfig from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Net @@ -97,10 +98,13 @@ def get_agents( device=args.device, ).to(args.device) optim = torch.optim.Adam(net.parameters(), lr=args.lr) - agent: DQN = DQN( + policy = DQNPolicy( model=net, - optim=optim, action_space=env.action_space, + ) + agent: DQN = DQN( + policy=policy, + optim=optim, discount_factor=args.gamma, estimation_step=args.n_step, target_update_freq=args.target_update_freq, @@ -108,7 +112,7 @@ def get_agents( agents.append(agent) optims.append(optim) - policy = MultiAgentPolicyManager(policies=agents, env=env) + policy = MultiAgentOffPolicyAlgorithm(algorithms=agents, env=env) return policy, optims, env.agents @@ -125,16 +129,16 @@ def train_agent( train_envs.seed(args.seed) test_envs.seed(args.seed) - policy, optim, agents = get_agents(args, agents=agents, optims=optims) + marl_algorithm, optim, agents = get_agents(args, agents=agents, optims=optims) # collector train_collector = Collector[CollectStats]( - policy, + marl_algorithm, train_envs, VectorReplayBuffer(args.buffer_size, 
len(train_envs)), exploration_noise=True, ) - test_collector = Collector[CollectStats](policy, test_envs, exploration_noise=True) + test_collector = Collector[CollectStats](marl_algorithm, test_envs, exploration_noise=True) train_collector.reset() train_collector.collect(n_step=args.batch_size * args.training_num) # log @@ -150,35 +154,35 @@ def stop_fn(mean_rewards: float) -> bool: return False def train_fn(epoch: int, env_step: int) -> None: - [agent.set_eps(args.eps_train) for agent in policy.policies.values()] + [agent.set_eps(args.eps_train) for agent in marl_algorithm.policy.policies.values()] def test_fn(epoch: int, env_step: int | None) -> None: - [agent.set_eps(args.eps_test) for agent in policy.policies.values()] + [agent.set_eps(args.eps_test) for agent in marl_algorithm.policy.policies.values()] def reward_metric(rews: np.ndarray) -> np.ndarray: return rews[:, 0] # trainer - result = OffPolicyTrainer( - policy=policy, - train_collector=train_collector, - test_collector=test_collector, - max_epoch=args.epoch, - step_per_epoch=args.step_per_epoch, - step_per_collect=args.step_per_collect, - episode_per_test=args.test_num, - batch_size=args.batch_size, - train_fn=train_fn, - test_fn=test_fn, - stop_fn=stop_fn, - save_best_fn=save_best_fn, - update_per_step=args.update_per_step, - logger=logger, - test_in_train=False, - reward_metric=reward_metric, - ).run() - - return result, policy + result = marl_algorithm.run_training( + OffPolicyTrainingConfig( + train_collector=train_collector, + test_collector=test_collector, + max_epoch=args.epoch, + step_per_epoch=args.step_per_epoch, + step_per_collect=args.step_per_collect, + episode_per_test=args.test_num, + batch_size=args.batch_size, + train_fn=train_fn, + test_fn=test_fn, + stop_fn=stop_fn, + save_best_fn=save_best_fn, + update_per_step=args.update_per_step, + logger=logger, + test_in_train=False, + reward_metric=reward_metric, + ) + ) + return result, marl_algorithm def watch(args: argparse.Namespace = 
get_args(), policy: Algorithm | None = None) -> None: diff --git a/test/pettingzoo/pistonball_continuous.py b/test/pettingzoo/pistonball_continuous.py index 9db48cc3f..471aa1671 100644 --- a/test/pettingzoo/pistonball_continuous.py +++ b/test/pettingzoo/pistonball_continuous.py @@ -15,8 +15,10 @@ from tianshou.data.stats import InfoStats from tianshou.env import DummyVectorEnv from tianshou.env.pettingzoo_env import PettingZooEnv -from tianshou.policy import PPO, Algorithm, MultiAgentPolicyManager -from tianshou.trainer import OnPolicyTrainer +from tianshou.policy import PPO, Algorithm +from tianshou.policy.modelfree.pg import ActorPolicy +from tianshou.policy.multiagent.mapolicy import MultiAgentOnPolicyAlgorithm +from tianshou.trainer import OnPolicyTrainingConfig from tianshou.utils import TensorboardLogger from tianshou.utils.net.continuous import ActorProb, Critic @@ -186,11 +188,17 @@ def dist(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution: loc, scale = loc_scale return Independent(Normal(loc, scale), 1) - agent: PPO = PPO( + policy = ActorPolicy( actor=actor, + dist_fn=dist, + action_space=env.action_space, + action_scaling=True, + action_bound_method="clip", + ) + agent: PPO = PPO( + policy=policy, critic=critic, optim=optim, - dist_fn=dist, discount_factor=args.gamma, max_grad_norm=args.max_grad_norm, eps_clip=args.eps_clip, @@ -203,17 +211,14 @@ def dist(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution: # dual clip cause monotonically increasing log_std :) value_clip=args.value_clip, gae_lambda=args.gae_lambda, - action_space=env.action_space, ) agents.append(agent) optims.append(optim) - policy = MultiAgentPolicyManager( - policies=agents, + policy = MultiAgentOnPolicyAlgorithm( + algorithms=agents, env=env, - action_scaling=True, - action_bound_method="clip", ) return policy, optims, env.agents @@ -231,16 +236,16 @@ def train_agent( train_envs.seed(args.seed) test_envs.seed(args.seed) - policy, optim, agents = 
get_agents(args, agents=agents, optims=optims) + marl_algorithm, optim, agents = get_agents(args, agents=agents, optims=optims) # collector train_collector = Collector[CollectStats]( - policy, + marl_algorithm, train_envs, VectorReplayBuffer(args.buffer_size, len(train_envs)), exploration_noise=False, # True ) - test_collector = Collector[CollectStats](policy, test_envs) + test_collector = Collector[CollectStats](marl_algorithm, test_envs) # train_collector.collect(n_step=args.batch_size * args.training_num, reset_before_collect=True) # log log_path = os.path.join(args.logdir, "pistonball", "dqn") @@ -257,24 +262,26 @@ def stop_fn(mean_rewards: float) -> bool: def reward_metric(rews: np.ndarray) -> np.ndarray: return rews[:, 0] - # trainer - result = OnPolicyTrainer( - policy=policy, - train_collector=train_collector, - test_collector=test_collector, - max_epoch=args.epoch, - step_per_epoch=args.step_per_epoch, - repeat_per_collect=args.repeat_per_collect, - episode_per_test=args.test_num, - batch_size=args.batch_size, - episode_per_collect=args.episode_per_collect, - stop_fn=stop_fn, - save_best_fn=save_best_fn, - logger=logger, - resume_from_log=args.resume, - ).run() - - return result, policy + # train + result = marl_algorithm.run_training( + OnPolicyTrainingConfig( + train_collector=train_collector, + test_collector=test_collector, + max_epoch=args.epoch, + step_per_epoch=args.step_per_epoch, + repeat_per_collect=args.repeat_per_collect, + episode_per_test=args.test_num, + batch_size=args.batch_size, + episode_per_collect=args.episode_per_collect, + step_per_collect=None, + stop_fn=stop_fn, + save_best_fn=save_best_fn, + logger=logger, + resume_from_log=args.resume, + ) + ) + + return result, marl_algorithm def watch(args: argparse.Namespace = get_args(), policy: Algorithm | None = None) -> None: diff --git a/test/pettingzoo/tic_tac_toe.py b/test/pettingzoo/tic_tac_toe.py index cf53760dd..fbc5ca6ab 100644 --- a/test/pettingzoo/tic_tac_toe.py +++ 
b/test/pettingzoo/tic_tac_toe.py @@ -15,11 +15,12 @@ from tianshou.env.pettingzoo_env import PettingZooEnv from tianshou.policy import ( DQN, - BasePolicy, - MARLRandomPolicy, - MultiAgentPolicyManager, + Algorithm, + MARLRandomDiscreteMaskedOffPolicyAlgorithm, + MultiAgentOffPolicyAlgorithm, ) -from tianshou.trainer import OffPolicyTrainer +from tianshou.policy.modelfree.dqn import DQNPolicy +from tianshou.trainer import OffPolicyTrainingConfig from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Net @@ -98,10 +99,10 @@ def get_args() -> argparse.Namespace: def get_agents( args: argparse.Namespace = get_args(), - agent_learn: BasePolicy | None = None, - agent_opponent: BasePolicy | None = None, + agent_learn: Algorithm | None = None, + agent_opponent: Algorithm | None = None, optim: torch.optim.Optimizer | None = None, -) -> tuple[BasePolicy, torch.optim.Optimizer | None, list]: +) -> tuple[MultiAgentOffPolicyAlgorithm, torch.optim.Optimizer | None, list]: env = get_env() observation_space = ( env.observation_space.spaces["observation"] @@ -120,10 +121,13 @@ def get_agents( ).to(args.device) if optim is None: optim = torch.optim.Adam(net.parameters(), lr=args.lr) - agent_learn = DQN( + algorithm = DQNPolicy( model=net, - optim=optim, action_space=env.action_space, + ) + agent_learn = DQN( + policy=algorithm, + optim=optim, estimation_step=args.n_step, discount_factor=args.gamma, target_update_freq=args.target_update_freq, @@ -136,22 +140,24 @@ def get_agents( agent_opponent = deepcopy(agent_learn) agent_opponent.load_state_dict(torch.load(args.opponent_path)) else: - agent_opponent = MARLRandomPolicy(action_space=env.action_space) + agent_opponent = MARLRandomDiscreteMaskedOffPolicyAlgorithm( + action_space=env.action_space + ) if args.agent_id == 1: agents = [agent_learn, agent_opponent] else: agents = [agent_opponent, agent_learn] - policy = MultiAgentPolicyManager(policies=agents, env=env) - return policy, optim, env.agents + 
algorithm = MultiAgentOffPolicyAlgorithm(algorithms=agents, env=env) + return algorithm, optim, env.agents def train_agent( args: argparse.Namespace = get_args(), - agent_learn: BasePolicy | None = None, - agent_opponent: BasePolicy | None = None, + agent_learn: Algorithm | None = None, + agent_opponent: Algorithm | None = None, optim: torch.optim.Optimizer | None = None, -) -> tuple[InfoStats, BasePolicy]: +) -> tuple[InfoStats, Algorithm]: train_envs = DummyVectorEnv([get_env for _ in range(args.training_num)]) test_envs = DummyVectorEnv([get_env for _ in range(args.test_num)]) # seed @@ -160,7 +166,7 @@ def train_agent( train_envs.seed(args.seed) test_envs.seed(args.seed) - policy, optim, agents = get_agents( + marl_algorithm, optim, agents = get_agents( args, agent_learn=agent_learn, agent_opponent=agent_opponent, @@ -169,13 +175,12 @@ def train_agent( # collector train_collector = Collector[CollectStats]( - policy, + marl_algorithm, train_envs, VectorReplayBuffer(args.buffer_size, len(train_envs)), exploration_noise=True, ) - test_collector = Collector[CollectStats](policy, test_envs, exploration_noise=True) - # policy.set_eps(1) + test_collector = Collector[CollectStats](marl_algorithm, test_envs, exploration_noise=True) train_collector.reset() train_collector.collect(n_step=args.batch_size * args.training_num) # log @@ -184,56 +189,59 @@ def train_agent( writer.add_text("args", str(args)) logger = TensorboardLogger(writer) - def save_best_fn(policy: BasePolicy) -> None: + player_agent_id = agents[args.agent_id - 1] + + def save_best_fn(policy: Algorithm) -> None: if hasattr(args, "model_save_path"): model_save_path = args.model_save_path else: model_save_path = os.path.join(args.logdir, "tic_tac_toe", "dqn", "policy.pth") - torch.save(policy.policies[agents[args.agent_id - 1]].state_dict(), model_save_path) + torch.save(policy.get_algorithm(player_agent_id).state_dict(), model_save_path) def stop_fn(mean_rewards: float) -> bool: return mean_rewards >= 
args.win_rate def train_fn(epoch: int, env_step: int) -> None: - policy.policies[agents[args.agent_id - 1]].set_eps(args.eps_train) + marl_algorithm.get_algorithm(player_agent_id).policy.set_eps(args.eps_train) def test_fn(epoch: int, env_step: int | None) -> None: - policy.policies[agents[args.agent_id - 1]].set_eps(args.eps_test) + marl_algorithm.get_algorithm(player_agent_id).policy.set_eps(args.eps_test) def reward_metric(rews: np.ndarray) -> np.ndarray: return rews[:, args.agent_id - 1] # trainer - result = OffPolicyTrainer( - policy=policy, - train_collector=train_collector, - test_collector=test_collector, - max_epoch=args.epoch, - step_per_epoch=args.step_per_epoch, - step_per_collect=args.step_per_collect, - episode_per_test=args.test_num, - batch_size=args.batch_size, - train_fn=train_fn, - test_fn=test_fn, - stop_fn=stop_fn, - save_best_fn=save_best_fn, - update_per_step=args.update_per_step, - logger=logger, - test_in_train=False, - reward_metric=reward_metric, - ).run() - - return result, policy.policies[agents[args.agent_id - 1]] + result = marl_algorithm.run_training( + OffPolicyTrainingConfig( + train_collector=train_collector, + test_collector=test_collector, + max_epoch=args.epoch, + step_per_epoch=args.step_per_epoch, + step_per_collect=args.step_per_collect, + episode_per_test=args.test_num, + batch_size=args.batch_size, + train_fn=train_fn, + test_fn=test_fn, + stop_fn=stop_fn, + save_best_fn=save_best_fn, + update_per_step=args.update_per_step, + logger=logger, + test_in_train=False, + reward_metric=reward_metric, + ) + ) + + return result, marl_algorithm.get_algorithm(player_agent_id) def watch( args: argparse.Namespace = get_args(), - agent_learn: BasePolicy | None = None, - agent_opponent: BasePolicy | None = None, + agent_learn: Algorithm | None = None, + agent_opponent: Algorithm | None = None, ) -> None: env = DummyVectorEnv([partial(get_env, render_mode="human")]) policy, optim, agents = get_agents(args, agent_learn=agent_learn, 
agent_opponent=agent_opponent) - policy.policies[agents[args.agent_id - 1]].set_eps(args.eps_test) + policy.algorithms[agents[args.agent_id - 1]].policy.set_eps(args.eps_test) collector = Collector[CollectStats](policy, env, exploration_noise=True) result = collector.collect(n_episode=1, render=args.render, reset_before_collect=True) result.pprint_asdict() diff --git a/tianshou/env/pettingzoo_env.py b/tianshou/env/pettingzoo_env.py index 196741453..2e8287ba9 100644 --- a/tianshou/env/pettingzoo_env.py +++ b/tianshou/env/pettingzoo_env.py @@ -18,7 +18,7 @@ class PettingZooEnv(AECEnv, ABC): - """The interface for petting zoo environments. + """The interface for petting zoo environments which support multi-agent RL. Multi-agent environments must be wrapped as :class:`~tianshou.env.PettingZooEnv`. Here is the usage: diff --git a/tianshou/policy/__init__.py b/tianshou/policy/__init__.py index cab562423..a12b3fc50 100644 --- a/tianshou/policy/__init__.py +++ b/tianshou/policy/__init__.py @@ -6,7 +6,7 @@ from tianshou.policy.modelfree.dqn import DQN from tianshou.policy.modelfree.ddpg import DDPG -from tianshou.policy.random import MARLRandomPolicy +from tianshou.policy.random import MARLRandomDiscreteMaskedOffPolicyAlgorithm from tianshou.policy.modelfree.bdqn import BDQN from tianshou.policy.modelfree.c51 import C51 from tianshou.policy.modelfree.rainbow import RainbowDQN @@ -32,4 +32,4 @@ from tianshou.policy.modelbased.psrl import PSRL from tianshou.policy.modelbased.icm import ICMOffPolicyWrapper from tianshou.policy.modelbased.icm import ICMOnPolicyWrapper -from tianshou.policy.multiagent.mapolicy import MultiAgentPolicyManager +from tianshou.policy.multiagent.mapolicy import MultiAgentOffPolicyAlgorithm diff --git a/tianshou/policy/base.py b/tianshou/policy/base.py index 2101c8b0c..b3bbba386 100644 --- a/tianshou/policy/base.py +++ b/tianshou/policy/base.py @@ -497,10 +497,6 @@ def __setstate__(self, state: dict[str, Any]) -> None: state["is_within_training_step"] 
= False self.__dict__ = state - def set_agent_id(self, agent_id: int) -> None: - """Set self.agent_id = agent_id, for MARL.""" - self.agent_id = agent_id - def process_fn( self, batch: RolloutBatchProtocol, diff --git a/tianshou/policy/multiagent/mapolicy.py b/tianshou/policy/multiagent/mapolicy.py index 4df09473b..f610d433d 100644 --- a/tianshou/policy/multiagent/mapolicy.py +++ b/tianshou/policy/multiagent/mapolicy.py @@ -1,13 +1,22 @@ -from typing import Any, Literal, Protocol, Self, TypeVar, cast, overload +from collections.abc import Callable +from typing import Any, Generic, Literal, Protocol, Self, TypeVar, cast, overload import numpy as np from overrides import override +from sensai.util.helper import mark_used +from torch.nn import ModuleList from tianshou.data import Batch, ReplayBuffer from tianshou.data.batch import BatchProtocol, IndexType from tianshou.data.types import ActBatchProtocol, ObsBatchProtocol, RolloutBatchProtocol from tianshou.policy import Algorithm -from tianshou.policy.base import TLearningRateScheduler, TrainingStats +from tianshou.policy.base import ( + OffPolicyAlgorithm, + OnPolicyAlgorithm, + Policy, + TLearningRateScheduler, + TrainingStats, +) try: from tianshou.env.pettingzoo_env import PettingZooEnv @@ -15,6 +24,9 @@ PettingZooEnv = None # type: ignore +mark_used(ActBatchProtocol) + + class MapTrainingStats(TrainingStats): def __init__( self, @@ -63,107 +75,20 @@ def __getitem__(self, index: str | IndexType) -> Any: ... -class MultiAgentPolicyManager(Algorithm): - """Multi-agent policy manager for MARL. - - This multi-agent policy manager accepts a list of - :class:`~tianshou.policy.BasePolicy`. It dispatches the batch data to each - of these policies when the "forward" is called. The same as "process_fn" - and "learn": it splits the data and feeds them to each policy. A figure in - :ref:`marl_example` can help you better understand this procedure. - - :param policies: a list of policies. - :param env: a PettingZooEnv. 
- :param action_scaling: if True, scale the action from [-1, 1] to the range - of action_space. Only used if the action_space is continuous. - :param action_bound_method: method to bound action to range [-1, 1]. - Only used if the action_space is continuous. - :param lr_scheduler: if not None, will be called in `policy.update()`. - """ - - def __init__( - self, - *, - policies: list[Algorithm], - # TODO: 1 why restrict to PettingZooEnv? - # TODO: 2 This is the only policy that takes an env in init, is it really needed? - env: PettingZooEnv, - action_scaling: bool = False, - action_bound_method: Literal["clip", "tanh"] | None = "clip", - lr_scheduler: TLearningRateScheduler | None = None, - ) -> None: +class MultiAgentPolicy(Policy): + def __init__(self, policies: dict[str | int, Policy]): + p0 = next(iter(policies.values())) super().__init__( - action_space=env.action_space, - observation_space=env.observation_space, - action_scaling=action_scaling, - action_bound_method=action_bound_method, - lr_scheduler=lr_scheduler, + action_space=p0.action_space, + observation_space=p0.observation_space, + action_scaling=False, + action_bound_method=None, ) - assert len(policies) == len(env.agents), "One policy must be assigned for each agent." - - self.agent_idx = env.agent_idx - for i, policy in enumerate(policies): - # agent_id 0 is reserved for the environment proxy - # (this MultiAgentPolicyManager) - policy.set_agent_id(env.agents[i]) - - self.policies: dict[str | int, Algorithm] = dict(zip(env.agents, policies, strict=True)) - """Maps agent_id to policy.""" - - # TODO: unused - remove it? 
- def replace_policy(self, policy: Algorithm, agent_id: int) -> None: - """Replace the "agent_id"th policy in this manager.""" - policy.set_agent_id(agent_id) - self.policies[agent_id] = policy - - # TODO: violates Liskov substitution principle - def process_fn( # type: ignore - self, - batch: MAPRolloutBatchProtocol, - buffer: ReplayBuffer, - indice: np.ndarray, - ) -> MAPRolloutBatchProtocol: - """Dispatch batch data from `obs.agent_id` to every policy's process_fn. - - Save original multi-dimensional rew in "save_rew", set rew to the - reward of each agent during their "process_fn", and restore the - original reward afterwards. - """ - # TODO: maybe only str is actually allowed as agent_id? See MAPRolloutBatchProtocol - results: dict[str | int, RolloutBatchProtocol] = {} - assert isinstance( - batch.obs, - BatchProtocol, - ), f"here only observations of type Batch are permitted, but got {type(batch.obs)}" - # reward can be empty Batch (after initial reset) or nparray. - has_rew = isinstance(buffer.rew, np.ndarray) - if has_rew: # save the original reward in save_rew - # Since we do not override buffer.__setattr__, here we use _meta to - # change buffer.rew, otherwise buffer.rew = Batch() has no effect. 
- save_rew, buffer._meta.rew = buffer.rew, Batch() # type: ignore - for agent, policy in self.policies.items(): - agent_index = np.nonzero(batch.obs.agent_id == agent)[0] - if len(agent_index) == 0: - results[agent] = cast(RolloutBatchProtocol, Batch()) - continue - tmp_batch, tmp_indice = batch[agent_index], indice[agent_index] - if has_rew: - tmp_batch.rew = tmp_batch.rew[:, self.agent_idx[agent]] - buffer._meta.rew = save_rew[:, self.agent_idx[agent]] - if not hasattr(tmp_batch.obs, "mask"): - if hasattr(tmp_batch.obs, "obs"): - tmp_batch.obs = tmp_batch.obs.obs - if hasattr(tmp_batch.obs_next, "obs"): - tmp_batch.obs_next = tmp_batch.obs_next.obs - results[agent] = policy.process_fn(tmp_batch, buffer, tmp_indice) - if has_rew: # restore from save_rew - buffer._meta.rew = save_rew - return cast(MAPRolloutBatchProtocol, Batch(results)) + self.policies = policies + self._submodules = ModuleList(policies.values()) _TArrOrActBatch = TypeVar("_TArrOrActBatch", bound="np.ndarray | ActBatchProtocol") - # TODO: Move to policy - # @override def add_exploration_noise( self, act: _TArrOrActBatch, @@ -260,29 +185,176 @@ def forward( # type: ignore holder["state"] = state_dict return holder - # Violates Liskov substitution principle - def _update_with_batch( # type: ignore + +TAlgorithm = TypeVar("TAlgorithm", bound=Algorithm) + + +class MARLDispatcher(Generic[TAlgorithm]): + """ + Supports multi-agent learning by dispatching calls to the corresponding + algorithm for each agent. + """ + + def __init__(self, algorithms: list[TAlgorithm], env: PettingZooEnv): + agent_ids = env.agents + assert len(algorithms) == len(agent_ids), "One policy must be assigned for each agent." 
+ self.algorithms: dict[str | int, TAlgorithm] = dict(zip(agent_ids, algorithms, strict=True)) + """maps agent_id to the corresponding algorithm.""" + self.agent_idx = env.agent_idx + """maps agent_id to 0-based index.""" + + def create_policy(self) -> MultiAgentPolicy: + return MultiAgentPolicy({agent_id: a.policy for agent_id, a in self.algorithms.items()}) + + def dispatch_process_fn( # type: ignore self, batch: MAPRolloutBatchProtocol, - *args: Any, - **kwargs: Any, + buffer: ReplayBuffer, + indices: np.ndarray, + ) -> MAPRolloutBatchProtocol: + """Dispatch batch data from `obs.agent_id` to every algorithm's processing function. + + Save original multi-dimensional rew in "save_rew", set rew to the + reward of each agent during their "process_fn", and restore the + original reward afterwards. + """ + # TODO: maybe only str is actually allowed as agent_id? See MAPRolloutBatchProtocol + results: dict[str | int, RolloutBatchProtocol] = {} + assert isinstance( + batch.obs, + BatchProtocol, + ), f"here only observations of type Batch are permitted, but got {type(batch.obs)}" + # reward can be empty Batch (after initial reset) or nparray. + has_rew = isinstance(buffer.rew, np.ndarray) + if has_rew: # save the original reward in save_rew + # Since we do not override buffer.__setattr__, here we use _meta to + # change buffer.rew, otherwise buffer.rew = Batch() has no effect. 
+ save_rew, buffer._meta.rew = buffer.rew, Batch() # type: ignore + for agent, algorithm in self.algorithms.items(): + agent_index = np.nonzero(batch.obs.agent_id == agent)[0] + if len(agent_index) == 0: + results[agent] = cast(RolloutBatchProtocol, Batch()) + continue + tmp_batch, tmp_indice = batch[agent_index], indices[agent_index] + if has_rew: + tmp_batch.rew = tmp_batch.rew[:, self.agent_idx[agent]] + buffer._meta.rew = save_rew[:, self.agent_idx[agent]] + if not hasattr(tmp_batch.obs, "mask"): + if hasattr(tmp_batch.obs, "obs"): + tmp_batch.obs = tmp_batch.obs.obs + if hasattr(tmp_batch.obs_next, "obs"): + tmp_batch.obs_next = tmp_batch.obs_next.obs + results[agent] = algorithm.process_fn(tmp_batch, buffer, tmp_indice) + if has_rew: # restore from save_rew + buffer._meta.rew = save_rew + return cast(MAPRolloutBatchProtocol, Batch(results)) + + def dispatch_update_with_batch( # type: ignore + self, + batch: MAPRolloutBatchProtocol, + algorithm_update_with_batch_fn: Callable[[TAlgorithm, RolloutBatchProtocol], TrainingStats], ) -> MapTrainingStats: - """Dispatch the data to all policies for learning. + """Dispatch the respective subset of the batch data to each algorithm. :param batch: must map agent_ids to rollout batches + :param algorithm_update_with_batch_fn: a function that performs the algorithm-specific + update with the given agent-specific batch data """ agent_id_to_stats = {} - for agent_id, policy in self.policies.items(): + for agent_id, algorithm in self.algorithms.items(): data = batch[agent_id] if len(data.get_keys()) != 0: - train_stats = policy._update_with_batch(batch=data, **kwargs) + train_stats = algorithm_update_with_batch_fn(algorithm, data) agent_id_to_stats[agent_id] = train_stats return MapTrainingStats(agent_id_to_stats) - # Need a train method that set all sub-policies to train mode. - # No need for a similar eval function, as eval internally uses the train function. 
- def train(self, mode: bool = True) -> Self: - """Set each internal policy in training mode.""" - for policy in self.policies.values(): - policy.train(mode) - return self + +class MultiAgentOffPolicyAlgorithm(OffPolicyAlgorithm[MultiAgentPolicy, MapTrainingStats]): + """Multi-agent reinforcement learning where each agent uses off-policy learning.""" + + def __init__( + self, + *, + algorithms: list[OffPolicyAlgorithm], + env: PettingZooEnv, + lr_scheduler: TLearningRateScheduler | None = None, + ) -> None: + """ + :param algorithms: a list of off-policy algorithms. + :param env: the multi-agent RL environment + :param lr_scheduler: if not None, will be called in `policy.update()`. + """ + self._dispatcher: MARLDispatcher[OffPolicyAlgorithm] = MARLDispatcher(algorithms, env) + super().__init__( + policy=self._dispatcher.create_policy(), + lr_scheduler=lr_scheduler, + ) + self._submodules = ModuleList(algorithms) + + def get_algorithm(self, agent_id: str | int) -> OffPolicyAlgorithm: + return self._dispatcher.algorithms[agent_id] + + def process_fn( + self, + batch: RolloutBatchProtocol, + buffer: ReplayBuffer, + indices: np.ndarray, + ) -> RolloutBatchProtocol: + batch = cast(MAPRolloutBatchProtocol, batch) + return self._dispatcher.dispatch_process_fn(batch, buffer, indices) + + def _update_with_batch( + self, + batch: RolloutBatchProtocol, + ) -> MapTrainingStats: + batch = cast(MAPRolloutBatchProtocol, batch) + + def update(algorithm: OffPolicyAlgorithm, data: RolloutBatchProtocol) -> TrainingStats: + return algorithm._update_with_batch(data) + + return self._dispatcher.dispatch_update_with_batch(batch, update) + + +class MultiAgentOnPolicyAlgorithm(OnPolicyAlgorithm[MultiAgentPolicy, MapTrainingStats]): + """Multi-agent reinforcement learning where each agent uses on-policy learning.""" + + def __init__( + self, + *, + algorithms: list[OnPolicyAlgorithm], + env: PettingZooEnv, + lr_scheduler: TLearningRateScheduler | None = None, + ) -> None: + """ + :param 
algorithms: a list of off-policy algorithms. + :param env: the multi-agent RL environment + :param lr_scheduler: if not None, will be called in `policy.update()`. + """ + self._dispatcher: MARLDispatcher[OnPolicyAlgorithm] = MARLDispatcher(algorithms, env) + super().__init__( + policy=self._dispatcher.create_policy(), + lr_scheduler=lr_scheduler, + ) + self._submodules = ModuleList(algorithms) + + def get_algorithm(self, agent_id: str | int) -> OnPolicyAlgorithm: + return self._dispatcher.algorithms[agent_id] + + def process_fn( + self, + batch: RolloutBatchProtocol, + buffer: ReplayBuffer, + indices: np.ndarray, + ) -> RolloutBatchProtocol: + batch = cast(MAPRolloutBatchProtocol, batch) + return self._dispatcher.dispatch_process_fn(batch, buffer, indices) + + def _update_with_batch( + self, batch: RolloutBatchProtocol, batch_size: int | None, repeat: int + ) -> MapTrainingStats: + batch = cast(MAPRolloutBatchProtocol, batch) + + def update(algorithm: OnPolicyAlgorithm, data: RolloutBatchProtocol) -> TrainingStats: + return algorithm._update_with_batch(data, batch_size, repeat) + + return self._dispatcher.dispatch_update_with_batch(batch, update) diff --git a/tianshou/policy/random.py b/tianshou/policy/random.py index c20f8419c..675c570d3 100644 --- a/tianshou/policy/random.py +++ b/tianshou/policy/random.py @@ -1,12 +1,13 @@ -from typing import Any, TypeVar, cast +from typing import TypeVar, cast +import gymnasium as gym import numpy as np from tianshou.data import Batch from tianshou.data.batch import BatchProtocol from tianshou.data.types import ActBatchProtocol, ObsBatchProtocol, RolloutBatchProtocol -from tianshou.policy import Algorithm -from tianshou.policy.base import TrainingStats +from tianshou.policy import base +from tianshou.policy.base import OffPolicyAlgorithm, TrainingStats class MARLRandomTrainingStats(TrainingStats): @@ -16,39 +17,47 @@ class MARLRandomTrainingStats(TrainingStats): TMARLRandomTrainingStats = TypeVar("TMARLRandomTrainingStats", 
bound=MARLRandomTrainingStats) -class MARLRandomPolicy(Algorithm): +class MARLRandomDiscreteMaskedOffPolicyAlgorithm(OffPolicyAlgorithm): """A random agent used in multi-agent learning. - It randomly chooses an action from the legal action. + It randomly chooses an action from the legal actions (according to the given mask). """ - def forward( - self, - batch: ObsBatchProtocol, - state: dict | BatchProtocol | np.ndarray | None = None, - **kwargs: Any, - ) -> ActBatchProtocol: - """Compute the random action over the given batch data. + class Policy(base.Policy): + """A random agent used in multi-agent learning. - The input should contain a mask in batch.obs, with "True" to be - available and "False" to be unavailable. For example, - ``batch.obs.mask == np.array([[False, True, False]])`` means with batch - size 1, action "1" is available but action "0" and "2" are unavailable. - - :return: A :class:`~tianshou.data.Batch` with "act" key, containing - the random action. - - .. seealso:: - - Please refer to :meth:`~tianshou.policy.BasePolicy.forward` for - more detailed explanation. + It randomly chooses an action from the legal actions. """ - mask = batch.obs.mask # type: ignore - logits = np.random.rand(*mask.shape) - logits[~mask] = -np.inf - result = Batch(act=logits.argmax(axis=-1)) - return cast(ActBatchProtocol, result) - def _update_with_batch(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> TMARLRandomTrainingStats: # type: ignore + def __init__(self, action_space: gym.spaces.Space) -> None: + super().__init__(action_space=action_space) + + def forward( + self, + batch: ObsBatchProtocol, + state: dict | BatchProtocol | np.ndarray | None = None, + **kwargs: dict, + ) -> ActBatchProtocol: + """Compute the random action over the given batch data. + + The input should contain a mask in batch.obs, with "True" to be + available and "False" to be unavailable. 
For example, + ``batch.obs.mask == np.array([[False, True, False]])`` means with batch + size 1, action "1" is available but action "0" and "2" are unavailable. + + :return: A :class:`~tianshou.data.Batch` with "act" key, containing + the random action. + """ + mask = batch.obs.mask # type: ignore + logits = np.random.rand(*mask.shape) + logits[~mask] = -np.inf + result = Batch(act=logits.argmax(axis=-1)) + return cast(ActBatchProtocol, result) + + def __init__(self, action_space: gym.spaces.Space) -> None: + """:param action_space: the environment's action space.""" + super().__init__(policy=self.Policy(action_space)) + + def _update_with_batch(self, batch: RolloutBatchProtocol) -> TMARLRandomTrainingStats: # type: ignore """Since a random agent learns nothing, it returns an empty dict.""" return MARLRandomTrainingStats() # type: ignore[return-value] From bd98c3fc4083339dba6baf9c9bff3e7fb11dcbf6 Mon Sep 17 00:00:00 2001 From: Dominik Jain Date: Fri, 14 Mar 2025 20:27:56 +0100 Subject: [PATCH 060/230] v2: Adapt Atari examples --- examples/atari/atari_c51.py | 16 +++--- examples/atari/atari_dqn.py | 91 ++++++++++++++++++-------------- examples/atari/atari_fqf.py | 72 +++++++++++++++---------- examples/atari/atari_iqn.py | 68 ++++++++++++++---------- examples/atari/atari_ppo.py | 73 ++++++++++++------------- examples/atari/atari_qrdqn.py | 66 ++++++++++++++--------- examples/atari/atari_rainbow.py | 74 +++++++++++++++----------- examples/atari/atari_sac.py | 80 ++++++++++++++++------------ tianshou/policy/modelfree/c51.py | 3 +- tianshou/policy/modelfree/dqn.py | 3 +- tianshou/policy/modelfree/fqf.py | 3 +- tianshou/policy/modelfree/iqn.py | 3 +- 12 files changed, 319 insertions(+), 233 deletions(-) diff --git a/examples/atari/atari_c51.py b/examples/atari/atari_c51.py index c4b52ed44..9af7e8a7b 100644 --- a/examples/atari/atari_c51.py +++ b/examples/atari/atari_c51.py @@ -3,11 +3,9 @@ import os import pprint import sys -from typing import cast import numpy as np 
import torch -from gym.spaces import Discrete from tianshou.data import Collector, CollectStats, VectorReplayBuffer from tianshou.env.atari.atari_network import C51Net @@ -86,13 +84,15 @@ def main(args: argparse.Namespace = get_args()) -> None: # seed np.random.seed(args.seed) torch.manual_seed(args.seed) + # define model net = C51Net(*args.state_shape, args.action_shape, args.num_atoms, args.device) + + # define policy and algorithm optim = torch.optim.Adam(net.parameters(), lr=args.lr) - # define policy policy = C51Policy( model=net, - action_space=cast(Discrete, env.action_space), + action_space=env.action_space, num_atoms=args.num_atoms, v_min=args.v_min, v_max=args.v_max, @@ -104,10 +104,12 @@ def main(args: argparse.Namespace = get_args()) -> None: estimation_step=args.n_step, target_update_freq=args.target_update_freq, ).to(args.device) - # load a previous policy + + # load a previous model if args.resume_path: algorithm.load_state_dict(torch.load(args.resume_path, map_location=args.device)) print("Loaded agent from: ", args.resume_path) + # replay buffer: `save_last_obs` and `stack_num` can be removed together # when you have enough RAM buffer = VectorReplayBuffer( @@ -117,7 +119,8 @@ def main(args: argparse.Namespace = get_args()) -> None: save_only_last_obs=True, stack_num=args.frames_stack, ) - # collector + + # collectors train_collector = Collector[CollectStats](algorithm, train_envs, buffer, exploration_noise=True) test_collector = Collector[CollectStats](algorithm, test_envs, exploration_noise=True) @@ -165,7 +168,6 @@ def train_fn(epoch: int, env_step: int) -> None: def test_fn(epoch: int, env_step: int | None) -> None: policy.set_eps(args.eps_test) - # watch agent's performance def watch() -> None: print("Setup test envs ...") policy.set_eps(args.eps_test) diff --git a/examples/atari/atari_dqn.py b/examples/atari/atari_dqn.py index 07b82ac77..3dc77215c 100644 --- a/examples/atari/atari_dqn.py +++ b/examples/atari/atari_dqn.py @@ -6,15 +6,16 @@ 
import numpy as np import torch -from atari_wrapper import make_atari_env -from examples.atari.atari_network import DQNet from tianshou.data import Collector, CollectStats, VectorReplayBuffer +from tianshou.env.atari.atari_network import DQNet +from tianshou.env.atari.atari_wrapper import make_atari_env from tianshou.highlevel.logger import LoggerFactoryDefault from tianshou.policy import DQN from tianshou.policy.base import Algorithm from tianshou.policy.modelbased.icm import ICMOffPolicyWrapper -from tianshou.trainer import OffPolicyTrainer +from tianshou.policy.modelfree.dqn import DQNPolicy +from tianshou.trainer import OffPolicyTrainingConfig from tianshou.utils.net.discrete import IntrinsicCuriosityModule @@ -94,27 +95,34 @@ def main(args: argparse.Namespace = get_args()) -> None: ) args.state_shape = env.observation_space.shape or env.observation_space.n args.action_shape = env.action_space.shape or env.action_space.n + # should be N_FRAMES x H x W print("Observations shape:", args.state_shape) print("Actions shape:", args.action_shape) + # seed np.random.seed(args.seed) torch.manual_seed(args.seed) + # define model net = DQNet(*args.state_shape, args.action_shape, args.device).to(args.device) optim = torch.optim.Adam(net.parameters(), lr=args.lr) - # define policy - policy: DQN | ICMOffPolicyWrapper - policy = DQN( + + # define policy and algorithm + policy = DQNPolicy( model=net, - optim=optim, action_space=env.action_space, + ) + algorithm: DQN | ICMOffPolicyWrapper + algorithm = DQN( + policy=policy, + optim=optim, discount_factor=args.gamma, estimation_step=args.n_step, target_update_freq=args.target_update_freq, ) if args.icm_lr_scale > 0: - feature_net = DQN(*args.state_shape, args.action_shape, args.device, features_only=True) + feature_net = DQNet(*args.state_shape, args.action_shape, args.device, features_only=True) action_dim = np.prod(args.action_shape) feature_dim = feature_net.output_dim icm_net = IntrinsicCuriosityModule( @@ -125,19 +133,20 @@ 
def main(args: argparse.Namespace = get_args()) -> None: device=args.device, ) icm_optim = torch.optim.Adam(icm_net.parameters(), lr=args.lr) - policy = ICMOffPolicyWrapper( - wrapped_algorithm=policy, + algorithm = ICMOffPolicyWrapper( + wrapped_algorithm=algorithm, model=icm_net, optim=icm_optim, - action_space=env.action_space, lr_scale=args.icm_lr_scale, reward_scale=args.icm_reward_scale, forward_loss_weight=args.icm_forward_loss_weight, ).to(args.device) - # load a previous policy + + # load a previous model if args.resume_path: - policy.load_state_dict(torch.load(args.resume_path, map_location=args.device)) + algorithm.load_state_dict(torch.load(args.resume_path, map_location=args.device)) print("Loaded agent from: ", args.resume_path) + # replay buffer: `save_last_obs` and `stack_num` can be removed together # when you have enough RAM buffer = VectorReplayBuffer( @@ -147,9 +156,10 @@ def main(args: argparse.Namespace = get_args()) -> None: save_only_last_obs=True, stack_num=args.frames_stack, ) - # collector - train_collector = Collector[CollectStats](policy, train_envs, buffer, exploration_noise=True) - test_collector = Collector[CollectStats](policy, test_envs, exploration_noise=True) + + # collectors + train_collector = Collector[CollectStats](algorithm, train_envs, buffer, exploration_noise=True) + test_collector = Collector[CollectStats](algorithm, test_envs, exploration_noise=True) # log now = datetime.datetime.now().strftime("%y%m%d-%H%M%S") @@ -198,10 +208,9 @@ def test_fn(epoch: int, env_step: int | None) -> None: def save_checkpoint_fn(epoch: int, env_step: int, gradient_step: int) -> str: # see also: https://pytorch.org/tutorials/beginner/saving_loading_models.html ckpt_path = os.path.join(log_path, f"checkpoint_{epoch}.pth") - torch.save({"model": policy.state_dict()}, ckpt_path) + torch.save({"model": algorithm.state_dict()}, ckpt_path) return ckpt_path - # watch agent's performance def watch() -> None: print("Setup test envs ...") 
policy.set_eps(args.eps_test) @@ -215,7 +224,9 @@ def watch() -> None: save_only_last_obs=True, stack_num=args.frames_stack, ) - collector = Collector[CollectStats](policy, test_envs, buffer, exploration_noise=True) + collector = Collector[CollectStats]( + algorithm, test_envs, buffer, exploration_noise=True + ) result = collector.collect(n_step=args.buffer_size, reset_before_collect=True) print(f"Save buffer into {args.save_buffer_name}") # Unfortunately, pickle will cause oom with 1M buffer size @@ -233,26 +244,28 @@ def watch() -> None: # test train_collector and start filling replay buffer train_collector.reset() train_collector.collect(n_step=args.batch_size * args.training_num) - # trainer - result = OffPolicyTrainer( - policy=policy, - train_collector=train_collector, - test_collector=test_collector, - max_epoch=args.epoch, - step_per_epoch=args.step_per_epoch, - step_per_collect=args.step_per_collect, - episode_per_test=args.test_num, - batch_size=args.batch_size, - train_fn=train_fn, - test_fn=test_fn, - stop_fn=stop_fn, - save_best_fn=save_best_fn, - logger=logger, - update_per_step=args.update_per_step, - test_in_train=False, - resume_from_log=args.resume_id is not None, - save_checkpoint_fn=save_checkpoint_fn, - ).run() + + # train + result = algorithm.run_training( + OffPolicyTrainingConfig( + train_collector=train_collector, + test_collector=test_collector, + max_epoch=args.epoch, + step_per_epoch=args.step_per_epoch, + step_per_collect=args.step_per_collect, + episode_per_test=args.test_num, + batch_size=args.batch_size, + train_fn=train_fn, + test_fn=test_fn, + stop_fn=stop_fn, + save_best_fn=save_best_fn, + logger=logger, + update_per_step=args.update_per_step, + test_in_train=False, + resume_from_log=args.resume_id is not None, + save_checkpoint_fn=save_checkpoint_fn, + ) + ) pprint.pprint(result) watch() diff --git a/examples/atari/atari_fqf.py b/examples/atari/atari_fqf.py index ff1a81ccf..fa61605ab 100644 --- a/examples/atari/atari_fqf.py +++ 
b/examples/atari/atari_fqf.py @@ -13,7 +13,8 @@ from tianshou.highlevel.logger import LoggerFactoryDefault from tianshou.policy import FQF from tianshou.policy.base import Algorithm -from tianshou.trainer import OffPolicyTrainer +from tianshou.policy.modelfree.fqf import FQFPolicy +from tianshou.trainer import OffPolicyTrainingConfig from tianshou.utils.net.discrete import FractionProposalNetwork, FullQuantileFunction @@ -80,12 +81,15 @@ def main(args: argparse.Namespace = get_args()) -> None: ) args.state_shape = env.observation_space.shape or env.observation_space.n args.action_shape = env.action_space.shape or env.action_space.n + # should be N_FRAMES x H x W print("Observations shape:", args.state_shape) print("Actions shape:", args.action_shape) + # seed np.random.seed(args.seed) torch.manual_seed(args.seed) + # define model feature_net = DQNet(*args.state_shape, args.action_shape, args.device, features_only=True) net = FullQuantileFunction( @@ -98,23 +102,29 @@ def main(args: argparse.Namespace = get_args()) -> None: optim = torch.optim.Adam(net.parameters(), lr=args.lr) fraction_net = FractionProposalNetwork(args.num_fractions, net.input_dim) fraction_optim = torch.optim.RMSprop(fraction_net.parameters(), lr=args.fraction_lr) - # define policy - policy: FQF = FQF( + + # define policy and algorithm + policy = FQFPolicy( model=net, - optim=optim, fraction_model=fraction_net, - fraction_optim=fraction_optim, action_space=env.action_space, + ) + algorithm: FQF = FQF( + policy=policy, + optim=optim, + fraction_optim=fraction_optim, discount_factor=args.gamma, num_fractions=args.num_fractions, ent_coef=args.ent_coef, estimation_step=args.n_step, target_update_freq=args.target_update_freq, ).to(args.device) + # load a previous policy if args.resume_path: - policy.load_state_dict(torch.load(args.resume_path, map_location=args.device)) + algorithm.load_state_dict(torch.load(args.resume_path, map_location=args.device)) print("Loaded agent from: ", args.resume_path) + 
# replay buffer: `save_last_obs` and `stack_num` can be removed together # when you have enough RAM buffer = VectorReplayBuffer( @@ -124,9 +134,10 @@ def main(args: argparse.Namespace = get_args()) -> None: save_only_last_obs=True, stack_num=args.frames_stack, ) - # collector - train_collector = Collector[CollectStats](policy, train_envs, buffer, exploration_noise=True) - test_collector = Collector[CollectStats](policy, test_envs, exploration_noise=True) + + # collectors + train_collector = Collector[CollectStats](algorithm, train_envs, buffer, exploration_noise=True) + test_collector = Collector[CollectStats](algorithm, test_envs, exploration_noise=True) # log now = datetime.datetime.now().strftime("%y%m%d-%H%M%S") @@ -172,7 +183,6 @@ def train_fn(epoch: int, env_step: int) -> None: def test_fn(epoch: int, env_step: int | None) -> None: policy.set_eps(args.eps_test) - # watch agent's performance def watch() -> None: print("Setup test envs ...") policy.set_eps(args.eps_test) @@ -186,7 +196,9 @@ def watch() -> None: save_only_last_obs=True, stack_num=args.frames_stack, ) - collector = Collector[CollectStats](policy, test_envs, buffer, exploration_noise=True) + collector = Collector[CollectStats]( + algorithm, test_envs, buffer, exploration_noise=True + ) result = collector.collect(n_step=args.buffer_size, reset_before_collect=True) print(f"Save buffer into {args.save_buffer_name}") # Unfortunately, pickle will cause oom with 1M buffer size @@ -204,24 +216,26 @@ def watch() -> None: # test train_collector and start filling replay buffer train_collector.reset() train_collector.collect(n_step=args.batch_size * args.training_num) - # trainer - result = OffPolicyTrainer( - policy=policy, - train_collector=train_collector, - test_collector=test_collector, - max_epoch=args.epoch, - step_per_epoch=args.step_per_epoch, - step_per_collect=args.step_per_collect, - episode_per_test=args.test_num, - batch_size=args.batch_size, - train_fn=train_fn, - test_fn=test_fn, - 
stop_fn=stop_fn, - save_best_fn=save_best_fn, - logger=logger, - update_per_step=args.update_per_step, - test_in_train=False, - ).run() + + # train + result = algorithm.run_training( + OffPolicyTrainingConfig( + train_collector=train_collector, + test_collector=test_collector, + max_epoch=args.epoch, + step_per_epoch=args.step_per_epoch, + step_per_collect=args.step_per_collect, + episode_per_test=args.test_num, + batch_size=args.batch_size, + train_fn=train_fn, + test_fn=test_fn, + stop_fn=stop_fn, + save_best_fn=save_best_fn, + logger=logger, + update_per_step=args.update_per_step, + test_in_train=False, + ) + ) pprint.pprint(result) watch() diff --git a/examples/atari/atari_iqn.py b/examples/atari/atari_iqn.py index 5c9b4a386..9999c1550 100644 --- a/examples/atari/atari_iqn.py +++ b/examples/atari/atari_iqn.py @@ -13,7 +13,8 @@ from tianshou.highlevel.logger import LoggerFactoryDefault from tianshou.policy import IQN from tianshou.policy.base import Algorithm -from tianshou.trainer import OffPolicyTrainer +from tianshou.policy.modelfree.iqn import IQNPolicy +from tianshou.trainer import OffPolicyTrainingConfig from tianshou.utils.net.discrete import ImplicitQuantileNetwork @@ -83,9 +84,11 @@ def main(args: argparse.Namespace = get_args()) -> None: # should be N_FRAMES x H x W print("Observations shape:", args.state_shape) print("Actions shape:", args.action_shape) + # seed np.random.seed(args.seed) torch.manual_seed(args.seed) + # define model feature_net = DQNet(*args.state_shape, args.action_shape, args.device, features_only=True) net = ImplicitQuantileNetwork( @@ -96,22 +99,28 @@ def main(args: argparse.Namespace = get_args()) -> None: device=args.device, ).to(args.device) optim = torch.optim.Adam(net.parameters(), lr=args.lr) - # define policy - policy: IQN = IQN( + + # define policy and algorithm + policy = IQNPolicy( model=net, - optim=optim, action_space=env.action_space, - discount_factor=args.gamma, sample_size=args.sample_size, 
online_sample_size=args.online_sample_size, target_sample_size=args.target_sample_size, + ) + algorithm: IQN = IQN( + policy=policy, + optim=optim, + discount_factor=args.gamma, estimation_step=args.n_step, target_update_freq=args.target_update_freq, ).to(args.device) - # load a previous policy + + # load previous model if args.resume_path: - policy.load_state_dict(torch.load(args.resume_path, map_location=args.device)) + algorithm.load_state_dict(torch.load(args.resume_path, map_location=args.device)) print("Loaded agent from: ", args.resume_path) + # replay buffer: `save_last_obs` and `stack_num` can be removed together # when you have enough RAM buffer = VectorReplayBuffer( @@ -122,8 +131,8 @@ def main(args: argparse.Namespace = get_args()) -> None: stack_num=args.frames_stack, ) # collector - train_collector = Collector[CollectStats](policy, train_envs, buffer, exploration_noise=True) - test_collector = Collector[CollectStats](policy, test_envs, exploration_noise=True) + train_collector = Collector[CollectStats](algorithm, train_envs, buffer, exploration_noise=True) + test_collector = Collector[CollectStats](algorithm, test_envs, exploration_noise=True) # log now = datetime.datetime.now().strftime("%y%m%d-%H%M%S") @@ -183,7 +192,9 @@ def watch() -> None: save_only_last_obs=True, stack_num=args.frames_stack, ) - collector = Collector[CollectStats](policy, test_envs, buffer, exploration_noise=True) + collector = Collector[CollectStats]( + algorithm, test_envs, buffer, exploration_noise=True + ) result = collector.collect(n_step=args.buffer_size, reset_before_collect=True) print(f"Save buffer into {args.save_buffer_name}") # Unfortunately, pickle will cause oom with 1M buffer size @@ -201,25 +212,26 @@ def watch() -> None: # test train_collector and start filling replay buffer train_collector.reset() train_collector.collect(n_step=args.batch_size * args.training_num) - # trainer - result = OffPolicyTrainer( - policy=policy, - train_collector=train_collector, - 
test_collector=test_collector, - max_epoch=args.epoch, - step_per_epoch=args.step_per_epoch, - step_per_collect=args.step_per_collect, - episode_per_test=args.test_num, - batch_size=args.batch_size, - train_fn=train_fn, - test_fn=test_fn, - stop_fn=stop_fn, - save_best_fn=save_best_fn, - logger=logger, - update_per_step=args.update_per_step, - test_in_train=False, - ).run() + # train + result = algorithm.run_training( + OffPolicyTrainingConfig( + train_collector=train_collector, + test_collector=test_collector, + max_epoch=args.epoch, + step_per_epoch=args.step_per_epoch, + step_per_collect=args.step_per_collect, + episode_per_test=args.test_num, + batch_size=args.batch_size, + train_fn=train_fn, + test_fn=test_fn, + stop_fn=stop_fn, + save_best_fn=save_best_fn, + logger=logger, + update_per_step=args.update_per_step, + test_in_train=False, + ) + ) pprint.pprint(result) watch() diff --git a/examples/atari/atari_ppo.py b/examples/atari/atari_ppo.py index 6493979d0..3c407f522 100644 --- a/examples/atari/atari_ppo.py +++ b/examples/atari/atari_ppo.py @@ -6,7 +6,6 @@ import numpy as np import torch -from torch.distributions import Categorical from torch.optim.lr_scheduler import LambdaLR from tianshou.data import Collector, CollectStats, VectorReplayBuffer @@ -16,7 +15,8 @@ from tianshou.policy import PPO from tianshou.policy.base import Algorithm from tianshou.policy.modelbased.icm import ICMOnPolicyWrapper -from tianshou.trainer import OnPolicyTrainer +from tianshou.policy.modelfree.pg import DiscreteActorPolicy +from tianshou.trainer import OnPolicyTrainingConfig from tianshou.utils.net.common import ActorCritic from tianshou.utils.net.discrete import Actor, Critic, IntrinsicCuriosityModule @@ -132,24 +132,22 @@ def main(args: argparse.Namespace = get_args()) -> None: lr_scheduler = LambdaLR(optim, lr_lambda=lambda epoch: 1 - epoch / max_update_num) - # define policy - def dist(logits: torch.Tensor) -> Categorical: - return Categorical(logits=logits) - - policy: PPO 
= PPO( + # define algorithm + policy = DiscreteActorPolicy( actor=actor, + action_space=env.action_space, + ) + algorithm: PPO = PPO( + policy=policy, critic=critic, optim=optim, - dist_fn=dist, discount_factor=args.gamma, gae_lambda=args.gae_lambda, max_grad_norm=args.max_grad_norm, vf_coef=args.vf_coef, ent_coef=args.ent_coef, reward_normalization=args.rew_norm, - action_scaling=False, lr_scheduler=lr_scheduler, - action_space=env.action_space, eps_clip=args.eps_clip, value_clip=args.value_clip, dual_clip=args.dual_clip, @@ -168,8 +166,8 @@ def dist(logits: torch.Tensor) -> Categorical: device=args.device, ) icm_optim = torch.optim.Adam(icm_net.parameters(), lr=args.lr) - policy = ICMOnPolicyWrapper( # type: ignore[no-redef] - wrapped_algorithm=policy, + algorithm = ICMOnPolicyWrapper( # type: ignore[no-redef] + wrapped_algorithm=algorithm, model=icm_net, optim=icm_optim, lr_scale=args.icm_lr_scale, @@ -178,7 +176,7 @@ def dist(logits: torch.Tensor) -> Categorical: ).to(args.device) # load a previous policy if args.resume_path: - policy.load_state_dict(torch.load(args.resume_path, map_location=args.device)) + algorithm.load_state_dict(torch.load(args.resume_path, map_location=args.device)) print("Loaded agent from: ", args.resume_path) # replay buffer: `save_last_obs` and `stack_num` can be removed together # when you have enough RAM @@ -190,8 +188,8 @@ def dist(logits: torch.Tensor) -> Categorical: stack_num=args.frames_stack, ) # collector - train_collector = Collector[CollectStats](policy, train_envs, buffer, exploration_noise=True) - test_collector = Collector[CollectStats](policy, test_envs, exploration_noise=True) + train_collector = Collector[CollectStats](algorithm, train_envs, buffer, exploration_noise=True) + test_collector = Collector[CollectStats](algorithm, test_envs, exploration_noise=True) # log now = datetime.datetime.now().strftime("%y%m%d-%H%M%S") @@ -227,10 +225,9 @@ def stop_fn(mean_rewards: float) -> bool: def save_checkpoint_fn(epoch: int, 
env_step: int, gradient_step: int) -> str: # see also: https://pytorch.org/tutorials/beginner/saving_loading_models.html ckpt_path = os.path.join(log_path, f"checkpoint_{epoch}.pth") - torch.save({"model": policy.state_dict()}, ckpt_path) + torch.save({"model": algorithm.state_dict()}, ckpt_path) return ckpt_path - # watch agent's performance def watch() -> None: print("Setup test envs ...") test_envs.seed(args.seed) @@ -243,7 +240,9 @@ def watch() -> None: save_only_last_obs=True, stack_num=args.frames_stack, ) - collector = Collector[CollectStats](policy, test_envs, buffer, exploration_noise=True) + collector = Collector[CollectStats]( + algorithm, test_envs, buffer, exploration_noise=True + ) result = collector.collect(n_step=args.buffer_size, reset_before_collect=True) print(f"Save buffer into {args.save_buffer_name}") # Unfortunately, pickle will cause oom with 1M buffer size @@ -261,24 +260,26 @@ def watch() -> None: # test train_collector and start filling replay buffer train_collector.reset() train_collector.collect(n_step=args.batch_size * args.training_num) - # trainer - result = OnPolicyTrainer( - policy=policy, - train_collector=train_collector, - test_collector=test_collector, - max_epoch=args.epoch, - step_per_epoch=args.step_per_epoch, - repeat_per_collect=args.repeat_per_collect, - episode_per_test=args.test_num, - batch_size=args.batch_size, - step_per_collect=args.step_per_collect, - stop_fn=stop_fn, - save_best_fn=save_best_fn, - logger=logger, - test_in_train=False, - resume_from_log=args.resume_id is not None, - save_checkpoint_fn=save_checkpoint_fn, - ).run() + + # train + result = algorithm.run_training( + OnPolicyTrainingConfig( + train_collector=train_collector, + test_collector=test_collector, + max_epoch=args.epoch, + step_per_epoch=args.step_per_epoch, + repeat_per_collect=args.repeat_per_collect, + episode_per_test=args.test_num, + batch_size=args.batch_size, + step_per_collect=args.step_per_collect, + stop_fn=stop_fn, + 
save_best_fn=save_best_fn, + logger=logger, + test_in_train=False, + resume_from_log=args.resume_id is not None, + save_checkpoint_fn=save_checkpoint_fn, + ) + ) pprint.pprint(result) watch() diff --git a/examples/atari/atari_qrdqn.py b/examples/atari/atari_qrdqn.py index 0cdd9a8a9..52398fb58 100644 --- a/examples/atari/atari_qrdqn.py +++ b/examples/atari/atari_qrdqn.py @@ -13,7 +13,8 @@ from tianshou.highlevel.logger import LoggerFactoryDefault from tianshou.policy import QRDQN from tianshou.policy.base import Algorithm -from tianshou.trainer import OffPolicyTrainer +from tianshou.policy.modelfree.qrdqn import QRDQNPolicy +from tianshou.trainer import OffPolicyTrainingConfig def get_args() -> argparse.Namespace: @@ -75,12 +76,15 @@ def main(args: argparse.Namespace = get_args()) -> None: ) args.state_shape = env.observation_space.shape or env.observation_space.n args.action_shape = env.action_space.shape or env.action_space.n + # should be N_FRAMES x H x W print("Observations shape:", args.state_shape) print("Actions shape:", args.action_shape) + # seed np.random.seed(args.seed) torch.manual_seed(args.seed) + # define model c, h, w = args.state_shape net = QRDQNet( @@ -91,20 +95,25 @@ def main(args: argparse.Namespace = get_args()) -> None: num_quantiles=args.num_quantiles, device=args.device, ) + + # define policy and algorithm optim = torch.optim.Adam(net.parameters(), lr=args.lr) - # define policy - policy: QRDQN = QRDQN( + policy = QRDQNPolicy( model=net, - optim=optim, action_space=env.action_space, + ) + algorithm: QRDQN = QRDQN( + policy=policy, + optim=optim, discount_factor=args.gamma, num_quantiles=args.num_quantiles, estimation_step=args.n_step, target_update_freq=args.target_update_freq, ).to(args.device) + # load a previous policy if args.resume_path: - policy.load_state_dict(torch.load(args.resume_path, map_location=args.device)) + algorithm.load_state_dict(torch.load(args.resume_path, map_location=args.device)) print("Loaded agent from: ", 
args.resume_path) # replay buffer: `save_last_obs` and `stack_num` can be removed together # when you have enough RAM @@ -115,9 +124,10 @@ def main(args: argparse.Namespace = get_args()) -> None: save_only_last_obs=True, stack_num=args.frames_stack, ) + # collector - train_collector = Collector[CollectStats](policy, train_envs, buffer, exploration_noise=True) - test_collector = Collector[CollectStats](policy, test_envs, exploration_noise=True) + train_collector = Collector[CollectStats](algorithm, train_envs, buffer, exploration_noise=True) + test_collector = Collector[CollectStats](algorithm, test_envs, exploration_noise=True) # log now = datetime.datetime.now().strftime("%y%m%d-%H%M%S") @@ -177,7 +187,9 @@ def watch() -> None: save_only_last_obs=True, stack_num=args.frames_stack, ) - collector = Collector[CollectStats](policy, test_envs, buffer, exploration_noise=True) + collector = Collector[CollectStats]( + algorithm, test_envs, buffer, exploration_noise=True + ) result = collector.collect(n_step=args.buffer_size, reset_before_collect=True) print(f"Save buffer into {args.save_buffer_name}") # Unfortunately, pickle will cause oom with 1M buffer size @@ -195,24 +207,26 @@ def watch() -> None: # test train_collector and start filling replay buffer train_collector.reset() train_collector.collect(n_step=args.batch_size * args.training_num) - # trainer - result = OffPolicyTrainer( - policy=policy, - train_collector=train_collector, - test_collector=test_collector, - max_epoch=args.epoch, - step_per_epoch=args.step_per_epoch, - step_per_collect=args.step_per_collect, - episode_per_test=args.test_num, - batch_size=args.batch_size, - train_fn=train_fn, - test_fn=test_fn, - stop_fn=stop_fn, - save_best_fn=save_best_fn, - logger=logger, - update_per_step=args.update_per_step, - test_in_train=False, - ).run() + + # train + result = algorithm.run_training( + OffPolicyTrainingConfig( + train_collector=train_collector, + test_collector=test_collector, + max_epoch=args.epoch, 
+ step_per_epoch=args.step_per_epoch, + step_per_collect=args.step_per_collect, + episode_per_test=args.test_num, + batch_size=args.batch_size, + train_fn=train_fn, + test_fn=test_fn, + stop_fn=stop_fn, + save_best_fn=save_best_fn, + logger=logger, + update_per_step=args.update_per_step, + test_in_train=False, + ) + ) pprint.pprint(result) watch() diff --git a/examples/atari/atari_rainbow.py b/examples/atari/atari_rainbow.py index 9fb0ec7b6..8caa9fd50 100644 --- a/examples/atari/atari_rainbow.py +++ b/examples/atari/atari_rainbow.py @@ -13,11 +13,13 @@ PrioritizedVectorReplayBuffer, VectorReplayBuffer, ) +from tianshou.env.atari.atari_network import Rainbow from tianshou.env.atari.atari_wrapper import make_atari_env from tianshou.highlevel.logger import LoggerFactoryDefault from tianshou.policy import C51, RainbowDQN from tianshou.policy.base import Algorithm -from tianshou.trainer import OffPolicyTrainer +from tianshou.policy.modelfree.c51 import C51Policy +from tianshou.trainer import OffPolicyTrainingConfig def get_args() -> argparse.Namespace: @@ -93,11 +95,13 @@ def test_rainbow(args: argparse.Namespace = get_args()) -> None: # should be N_FRAMES x H x W print("Observations shape:", args.state_shape) print("Actions shape:", args.action_shape) + # seed np.random.seed(args.seed) torch.manual_seed(args.seed) + # define model - net = RainbowDQN( + net = Rainbow( *args.state_shape, args.action_shape, args.num_atoms, @@ -106,23 +110,29 @@ def test_rainbow(args: argparse.Namespace = get_args()) -> None: is_dueling=not args.no_dueling, is_noisy=not args.no_noisy, ) - optim = torch.optim.Adam(net.parameters(), lr=args.lr) - # define policy - policy: C51 = RainbowDQN( + + # define policy and algorithm + policy = C51Policy( model=net, - optim=optim, - discount_factor=args.gamma, action_space=env.action_space, num_atoms=args.num_atoms, v_min=args.v_min, v_max=args.v_max, + ) + optim = torch.optim.Adam(net.parameters(), lr=args.lr) + algorithm: C51 = RainbowDQN( + 
policy=policy, + optim=optim, + discount_factor=args.gamma, estimation_step=args.n_step, target_update_freq=args.target_update_freq, ).to(args.device) + # load a previous policy if args.resume_path: - policy.load_state_dict(torch.load(args.resume_path, map_location=args.device)) + algorithm.load_state_dict(torch.load(args.resume_path, map_location=args.device)) print("Loaded agent from: ", args.resume_path) + # replay buffer: `save_last_obs` and `stack_num` can be removed together # when you have enough RAM buffer: VectorReplayBuffer | PrioritizedVectorReplayBuffer @@ -145,9 +155,10 @@ def test_rainbow(args: argparse.Namespace = get_args()) -> None: beta=args.beta, weight_norm=not args.no_weight_norm, ) + # collector - train_collector = Collector[CollectStats](policy, train_envs, buffer, exploration_noise=True) - test_collector = Collector[CollectStats](policy, test_envs, exploration_noise=True) + train_collector = Collector[CollectStats](algorithm, train_envs, buffer, exploration_noise=True) + test_collector = Collector[CollectStats](algorithm, test_envs, exploration_noise=True) # log now = datetime.datetime.now().strftime("%y%m%d-%H%M%S") @@ -201,7 +212,6 @@ def train_fn(epoch: int, env_step: int) -> None: def test_fn(epoch: int, env_step: int | None) -> None: policy.set_eps(args.eps_test) - # watch agent's performance def watch() -> None: print("Setup test envs ...") policy.set_eps(args.eps_test) @@ -217,7 +227,9 @@ def watch() -> None: alpha=args.alpha, beta=args.beta, ) - collector = Collector[CollectStats](policy, test_envs, buffer, exploration_noise=True) + collector = Collector[CollectStats]( + algorithm, test_envs, buffer, exploration_noise=True + ) result = collector.collect(n_step=args.buffer_size, reset_before_collect=True) print(f"Save buffer into {args.save_buffer_name}") # Unfortunately, pickle will cause oom with 1M buffer size @@ -235,24 +247,26 @@ def watch() -> None: # test train_collector and start filling replay buffer train_collector.reset() 
train_collector.collect(n_step=args.batch_size * args.training_num) - # trainer - result = OffPolicyTrainer( - policy=policy, - train_collector=train_collector, - test_collector=test_collector, - max_epoch=args.epoch, - step_per_epoch=args.step_per_epoch, - step_per_collect=args.step_per_collect, - episode_per_test=args.test_num, - batch_size=args.batch_size, - train_fn=train_fn, - test_fn=test_fn, - stop_fn=stop_fn, - save_best_fn=save_best_fn, - logger=logger, - update_per_step=args.update_per_step, - test_in_train=False, - ).run() + + # train + result = algorithm.run_training( + OffPolicyTrainingConfig( + train_collector=train_collector, + test_collector=test_collector, + max_epoch=args.epoch, + step_per_epoch=args.step_per_epoch, + step_per_collect=args.step_per_collect, + episode_per_test=args.test_num, + batch_size=args.batch_size, + train_fn=train_fn, + test_fn=test_fn, + stop_fn=stop_fn, + save_best_fn=save_best_fn, + logger=logger, + update_per_step=args.update_per_step, + test_in_train=False, + ) + ) pprint.pprint(result) watch() diff --git a/examples/atari/atari_sac.py b/examples/atari/atari_sac.py index b50f1c6de..f3407db63 100644 --- a/examples/atari/atari_sac.py +++ b/examples/atari/atari_sac.py @@ -13,7 +13,9 @@ from tianshou.highlevel.logger import LoggerFactoryDefault from tianshou.policy import DiscreteSAC, ICMOffPolicyWrapper from tianshou.policy.base import Algorithm -from tianshou.trainer import OffPolicyTrainer +from tianshou.policy.modelfree.discrete_sac import DiscreteSACPolicy +from tianshou.policy.modelfree.sac import AutoAlpha +from tianshou.trainer import OffPolicyTrainingConfig from tianshou.utils.net.discrete import Actor, Critic, IntrinsicCuriosityModule @@ -96,12 +98,15 @@ def test_discrete_sac(args: argparse.Namespace = get_args()) -> None: ) args.state_shape = env.observation_space.shape or env.observation_space.n args.action_shape = env.action_space.shape or env.action_space.n + # should be N_FRAMES x H x W print("Observations 
shape:", args.state_shape) print("Actions shape:", args.action_shape) + # seed np.random.seed(args.seed) torch.manual_seed(args.seed) + # define model net = DQNet( *args.state_shape, @@ -117,22 +122,24 @@ def test_discrete_sac(args: argparse.Namespace = get_args()) -> None: critic2 = Critic(net, last_size=args.action_shape, device=args.device) critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr) - # define policy + # define policy and algorithm if args.auto_alpha: target_entropy = 0.98 * np.log(np.prod(args.action_shape)) log_alpha = torch.zeros(1, requires_grad=True, device=args.device) alpha_optim = torch.optim.Adam([log_alpha], lr=args.alpha_lr) - args.alpha = (target_entropy, log_alpha, alpha_optim) - - policy: DiscreteSAC | ICMOffPolicyWrapper - policy = DiscreteSAC( + args.alpha = AutoAlpha(target_entropy, log_alpha, alpha_optim) + algorithm: DiscreteSAC | ICMOffPolicyWrapper + policy = DiscreteSACPolicy( actor=actor, + action_space=env.action_space, + ) + algorithm = DiscreteSAC( + policy=policy, policy_optim=actor_optim, critic=critic1, critic_optim=critic1_optim, critic2=critic2, critic2_optim=critic2_optim, - action_space=env.action_space, tau=args.tau, gamma=args.gamma, alpha=args.alpha, @@ -150,18 +157,20 @@ def test_discrete_sac(args: argparse.Namespace = get_args()) -> None: device=args.device, ) icm_optim = torch.optim.Adam(icm_net.parameters(), lr=args.actor_lr) - policy = ICMOffPolicyWrapper( - wrapped_algorithm=policy, + algorithm = ICMOffPolicyWrapper( + wrapped_algorithm=algorithm, model=icm_net, optim=icm_optim, lr_scale=args.icm_lr_scale, reward_scale=args.icm_reward_scale, forward_loss_weight=args.icm_forward_loss_weight, ).to(args.device) - # load a previous policy + + # load a previous model if args.resume_path: - policy.load_state_dict(torch.load(args.resume_path, map_location=args.device)) + algorithm.load_state_dict(torch.load(args.resume_path, map_location=args.device)) print("Loaded agent from: ", args.resume_path) 
+ # replay buffer: `save_last_obs` and `stack_num` can be removed together # when you have enough RAM buffer = VectorReplayBuffer( @@ -172,8 +181,8 @@ def test_discrete_sac(args: argparse.Namespace = get_args()) -> None: stack_num=args.frames_stack, ) # collector - train_collector = Collector[CollectStats](policy, train_envs, buffer, exploration_noise=True) - test_collector = Collector[CollectStats](policy, test_envs, exploration_noise=True) + train_collector = Collector[CollectStats](algorithm, train_envs, buffer, exploration_noise=True) + test_collector = Collector[CollectStats](algorithm, test_envs, exploration_noise=True) # log now = datetime.datetime.now().strftime("%y%m%d-%H%M%S") @@ -209,10 +218,9 @@ def stop_fn(mean_rewards: float) -> bool: def save_checkpoint_fn(epoch: int, env_step: int, gradient_step: int) -> str: # see also: https://pytorch.org/tutorials/beginner/saving_loading_models.html ckpt_path = os.path.join(log_path, "checkpoint.pth") - torch.save({"model": policy.state_dict()}, ckpt_path) + torch.save({"model": algorithm.state_dict()}, ckpt_path) return ckpt_path - # watch agent's performance def watch() -> None: print("Setup test envs ...") test_envs.seed(args.seed) @@ -225,7 +233,9 @@ def watch() -> None: save_only_last_obs=True, stack_num=args.frames_stack, ) - collector = Collector[CollectStats](policy, test_envs, buffer, exploration_noise=True) + collector = Collector[CollectStats]( + algorithm, test_envs, buffer, exploration_noise=True + ) result = collector.collect(n_step=args.buffer_size, reset_before_collect=True) print(f"Save buffer into {args.save_buffer_name}") # Unfortunately, pickle will cause oom with 1M buffer size @@ -243,24 +253,26 @@ def watch() -> None: # test train_collector and start filling replay buffer train_collector.reset() train_collector.collect(n_step=args.batch_size * args.training_num) - # trainer - result = OffPolicyTrainer( - policy=policy, - train_collector=train_collector, - test_collector=test_collector, - 
max_epoch=args.epoch, - step_per_epoch=args.step_per_epoch, - step_per_collect=args.step_per_collect, - episode_per_test=args.test_num, - batch_size=args.batch_size, - stop_fn=stop_fn, - save_best_fn=save_best_fn, - logger=logger, - update_per_step=args.update_per_step, - test_in_train=False, - resume_from_log=args.resume_id is not None, - save_checkpoint_fn=save_checkpoint_fn, - ).run() + + # train + result = algorithm.run_training( + OffPolicyTrainingConfig( + train_collector=train_collector, + test_collector=test_collector, + max_epoch=args.epoch, + step_per_epoch=args.step_per_epoch, + step_per_collect=args.step_per_collect, + episode_per_test=args.test_num, + batch_size=args.batch_size, + stop_fn=stop_fn, + save_best_fn=save_best_fn, + logger=logger, + update_per_step=args.update_per_step, + test_in_train=False, + resume_from_log=args.resume_id is not None, + save_checkpoint_fn=save_checkpoint_fn, + ) + ) pprint.pprint(result) watch() diff --git a/tianshou/policy/modelfree/c51.py b/tianshou/policy/modelfree/c51.py index 22934f561..9999140d5 100644 --- a/tianshou/policy/modelfree/c51.py +++ b/tianshou/policy/modelfree/c51.py @@ -28,7 +28,7 @@ class C51Policy(DQNPolicy): def __init__( self, model: torch.nn.Module | Net, - action_space: gym.spaces.Discrete, + action_space: gym.spaces.Space, observation_space: gym.Space | None = None, num_atoms: int = 51, v_min: float = -10.0, @@ -43,6 +43,7 @@ def __init__( :param v_max: the value of the largest atom in the support set. Default to 10.0. 
""" + assert isinstance(action_space, gym.spaces.Discrete) super().__init__( model=model, action_space=action_space, observation_space=observation_space ) diff --git a/tianshou/policy/modelfree/dqn.py b/tianshou/policy/modelfree/dqn.py index a42ba1fea..52efe6da9 100644 --- a/tianshou/policy/modelfree/dqn.py +++ b/tianshou/policy/modelfree/dqn.py @@ -44,7 +44,7 @@ def __init__( self, *, model: TModel, - action_space: gym.spaces.Discrete, + action_space: gym.spaces.Space, observation_space: gym.Space | None = None, ) -> None: """ @@ -52,6 +52,7 @@ def __init__( :param action_space: the environment's action space :param observation_space: the environment's observation space. """ + assert isinstance(action_space, gym.spaces.Discrete) super().__init__( action_space=action_space, observation_space=observation_space, diff --git a/tianshou/policy/modelfree/fqf.py b/tianshou/policy/modelfree/fqf.py index 259fd8e3e..b886c74b9 100644 --- a/tianshou/policy/modelfree/fqf.py +++ b/tianshou/policy/modelfree/fqf.py @@ -31,7 +31,7 @@ def __init__( *, model: FullQuantileFunction, fraction_model: FractionProposalNetwork, - action_space: gym.spaces.Discrete, + action_space: gym.spaces.Space, observation_space: gym.Space | None = None, ): """ @@ -41,6 +41,7 @@ def __init__( :param action_space: the environment's action space :param observation_space: the environment's observation space. 
""" + assert isinstance(action_space, gym.spaces.Discrete) super().__init__( model=model, action_space=action_space, diff --git a/tianshou/policy/modelfree/iqn.py b/tianshou/policy/modelfree/iqn.py index b2d668d0e..160a5980b 100644 --- a/tianshou/policy/modelfree/iqn.py +++ b/tianshou/policy/modelfree/iqn.py @@ -31,12 +31,13 @@ def __init__( self, *, model: torch.nn.Module, - action_space: gym.spaces.Discrete, + action_space: gym.spaces.Space, sample_size: int = 32, online_sample_size: int = 8, target_sample_size: int = 8, observation_space: gym.Space | None = None, ) -> None: + assert isinstance(action_space, gym.spaces.Discrete) assert sample_size > 1, f"sample_size should be greater than 1 but got: {sample_size}" assert ( online_sample_size > 1 From a82555c0bf323e96c60c516e3899a841bb814b01 Mon Sep 17 00:00:00 2001 From: Dominik Jain Date: Sat, 15 Mar 2025 00:32:00 +0100 Subject: [PATCH 061/230] v2: Use OptimizerFactory instances instead of torch.optim.Optimizer instances in all algorithms * Add OptimizerFactory as well as the necessary specializations to implement our use cases * Replace all optimizer arguments in Algorithm classes and in all tests (but not the examples) * Add LRSchedulerFactory and allow OptimizerFactory to simultaneously handle LRScheduler creation for created optimizers, such that the lr_scheduler parameter could be removed from all algorithms * The class MultipleLRSchedulers is now deprecated, as Algorithm now manages an explicit list of LR schedulers (as this is more convenient); module utils.lr_scheduler will thus be subject to deletion in the future * Update tests and previously adapted examples * Remove obsolete helper function clone_optimizer --- CHANGELOG.md | 8 +- examples/atari/atari_c51.py | 3 +- examples/atari/atari_dqn.py | 11 +- examples/atari/atari_fqf.py | 5 +- examples/atari/atari_iqn.py | 3 +- examples/atari/atari_ppo.py | 20 ++-- examples/atari/atari_qrdqn.py | 3 +- examples/atari/atari_rainbow.py | 3 +- 
examples/atari/atari_sac.py | 9 +- test/continuous/test_ddpg.py | 11 +- test/continuous/test_npg.py | 10 +- test/continuous/test_ppo.py | 10 +- test/continuous/test_redq.py | 5 +- test/continuous/test_sac_with_il.py | 14 +-- test/continuous/test_td3.py | 7 +- test/continuous/test_trpo.py | 3 +- test/discrete/test_a2c_with_il.py | 7 +- test/discrete/test_bdqn.py | 3 +- test/discrete/test_c51.py | 15 ++- test/discrete/test_discrete_sac.py | 7 +- test/discrete/test_dqn.py | 15 ++- test/discrete/test_fqf.py | 17 +-- test/discrete/test_iqn.py | 15 ++- test/discrete/test_pg.py | 10 +- test/discrete/test_qrdqn.py | 12 +- test/discrete/test_rainbow.py | 17 +-- test/modelbased/test_dqn_icm.py | 15 ++- test/modelbased/test_ppo_icm.py | 5 +- test/modelbased/test_psrl.py | 8 +- test/offline/test_bcq.py | 14 ++- test/offline/test_cql.py | 5 +- test/offline/test_discrete_bcq.py | 12 +- test/offline/test_discrete_cql.py | 7 +- test/offline/test_discrete_crr.py | 13 +- test/offline/test_gail.py | 9 +- test/offline/test_td3_bc.py | 18 +-- test/pettingzoo/tic_tac_toe.py | 5 +- tianshou/policy/base.py | 37 +++--- tianshou/policy/imitation/base.py | 8 +- tianshou/policy/imitation/bcq.py | 34 +++--- tianshou/policy/imitation/cql.py | 23 ++-- tianshou/policy/imitation/discrete_bcq.py | 9 +- tianshou/policy/imitation/discrete_cql.py | 8 +- tianshou/policy/imitation/discrete_crr.py | 10 +- tianshou/policy/imitation/gail.py | 15 +-- tianshou/policy/imitation/td3_bc.py | 15 +-- tianshou/policy/modelbased/icm.py | 20 ++-- tianshou/policy/modelbased/psrl.py | 3 - tianshou/policy/modelfree/a2c.py | 26 ++-- tianshou/policy/modelfree/bdqn.py | 8 +- tianshou/policy/modelfree/c51.py | 11 +- tianshou/policy/modelfree/ddpg.py | 18 ++- tianshou/policy/modelfree/discrete_sac.py | 11 +- tianshou/policy/modelfree/dqn.py | 18 ++- tianshou/policy/modelfree/fqf.py | 19 +-- tianshou/policy/modelfree/iqn.py | 7 +- tianshou/policy/modelfree/npg.py | 13 +- tianshou/policy/modelfree/pg.py | 9 +- 
tianshou/policy/modelfree/ppo.py | 12 +- tianshou/policy/modelfree/qrdqn.py | 7 +- tianshou/policy/modelfree/redq.py | 11 +- tianshou/policy/modelfree/sac.py | 13 +- tianshou/policy/modelfree/td3.py | 33 ++--- tianshou/policy/modelfree/trpo.py | 9 +- tianshou/policy/multiagent/mapolicy.py | 6 - tianshou/policy/optim.py | 139 ++++++++++++++++++++++ tianshou/utils/lr_scheduler.py | 1 + tianshou/utils/optim.py | 42 ------- 68 files changed, 507 insertions(+), 442 deletions(-) create mode 100644 tianshou/policy/optim.py diff --git a/CHANGELOG.md b/CHANGELOG.md index b8363a14b..450bd2957 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -60,7 +60,13 @@ * The updating interface has been cleaned up: * Functions `update` and `_update_with_batch` (formerly `learn`) no longer have `*args` and `**kwargs`. * Instead, the interfaces for the offline, off-policy and on-policy cases are properly differentiated. - * New method `run_training`: The `Algorithm` abstraction can now directly initiate the learning process via this method. + * New method `run_training`: The `Algorithm` abstraction can now directly initiate the learning process via this method. + * `Algorithms` no longer require `torch.optim.Optimizer` instances and instead require `OptimizerFactory` + instances, which create the actual optimizers internally. + The new `OptimizerFactory` abstraction simultaneously handles the creation of learning rate schedulers + for the optimizers created (via method `with_lr_scheduler_factory` and accompanying factory abstraction + `LRSchedulerFactory`). + The parameter `lr_scheduler` has thus been removed from all algorithm constructors. * Internal design improvements: * Introduced an abstraction for the alpha parameter (coefficient of the entropy term) in `SAC`, `DiscreteSAC` and other algorithms. 
diff --git a/examples/atari/atari_c51.py b/examples/atari/atari_c51.py index 9af7e8a7b..4b8eb5364 100644 --- a/examples/atari/atari_c51.py +++ b/examples/atari/atari_c51.py @@ -14,6 +14,7 @@ from tianshou.policy import C51 from tianshou.policy.base import Algorithm from tianshou.policy.modelfree.c51 import C51Policy +from tianshou.policy.optim import AdamOptimizerFactory from tianshou.trainer.base import OffPolicyTrainingConfig @@ -89,7 +90,7 @@ def main(args: argparse.Namespace = get_args()) -> None: net = C51Net(*args.state_shape, args.action_shape, args.num_atoms, args.device) # define policy and algorithm - optim = torch.optim.Adam(net.parameters(), lr=args.lr) + optim = AdamOptimizerFactory(lr=args.lr) policy = C51Policy( model=net, action_space=env.action_space, diff --git a/examples/atari/atari_dqn.py b/examples/atari/atari_dqn.py index 3dc77215c..85b48bc2c 100644 --- a/examples/atari/atari_dqn.py +++ b/examples/atari/atari_dqn.py @@ -15,6 +15,7 @@ from tianshou.policy.base import Algorithm from tianshou.policy.modelbased.icm import ICMOffPolicyWrapper from tianshou.policy.modelfree.dqn import DQNPolicy +from tianshou.policy.optim import AdamOptimizerFactory from tianshou.trainer import OffPolicyTrainingConfig from tianshou.utils.net.discrete import IntrinsicCuriosityModule @@ -106,7 +107,7 @@ def main(args: argparse.Namespace = get_args()) -> None: # define model net = DQNet(*args.state_shape, args.action_shape, args.device).to(args.device) - optim = torch.optim.Adam(net.parameters(), lr=args.lr) + optim = AdamOptimizerFactory(lr=args.lr) # define policy and algorithm policy = DQNPolicy( @@ -126,13 +127,13 @@ def main(args: argparse.Namespace = get_args()) -> None: action_dim = np.prod(args.action_shape) feature_dim = feature_net.output_dim icm_net = IntrinsicCuriosityModule( - feature_net.net, - feature_dim, - action_dim, + feature_net=feature_net.net, + feature_dim=feature_dim, + action_dim=action_dim, hidden_sizes=[512], device=args.device, ) - icm_optim 
= torch.optim.Adam(icm_net.parameters(), lr=args.lr) + icm_optim = AdamOptimizerFactory(lr=args.lr) algorithm = ICMOffPolicyWrapper( wrapped_algorithm=algorithm, model=icm_net, diff --git a/examples/atari/atari_fqf.py b/examples/atari/atari_fqf.py index fa61605ab..31014e599 100644 --- a/examples/atari/atari_fqf.py +++ b/examples/atari/atari_fqf.py @@ -14,6 +14,7 @@ from tianshou.policy import FQF from tianshou.policy.base import Algorithm from tianshou.policy.modelfree.fqf import FQFPolicy +from tianshou.policy.optim import AdamOptimizerFactory, RMSpropOptimizerFactory from tianshou.trainer import OffPolicyTrainingConfig from tianshou.utils.net.discrete import FractionProposalNetwork, FullQuantileFunction @@ -99,9 +100,9 @@ def main(args: argparse.Namespace = get_args()) -> None: args.num_cosines, device=args.device, ).to(args.device) - optim = torch.optim.Adam(net.parameters(), lr=args.lr) + optim = AdamOptimizerFactory(lr=args.lr) fraction_net = FractionProposalNetwork(args.num_fractions, net.input_dim) - fraction_optim = torch.optim.RMSprop(fraction_net.parameters(), lr=args.fraction_lr) + fraction_optim = RMSpropOptimizerFactory(lr=args.fraction_lr) # define policy and algorithm policy = FQFPolicy( diff --git a/examples/atari/atari_iqn.py b/examples/atari/atari_iqn.py index 9999c1550..c662ab146 100644 --- a/examples/atari/atari_iqn.py +++ b/examples/atari/atari_iqn.py @@ -14,6 +14,7 @@ from tianshou.policy import IQN from tianshou.policy.base import Algorithm from tianshou.policy.modelfree.iqn import IQNPolicy +from tianshou.policy.optim import AdamOptimizerFactory from tianshou.trainer import OffPolicyTrainingConfig from tianshou.utils.net.discrete import ImplicitQuantileNetwork @@ -98,7 +99,7 @@ def main(args: argparse.Namespace = get_args()) -> None: num_cosines=args.num_cosines, device=args.device, ).to(args.device) - optim = torch.optim.Adam(net.parameters(), lr=args.lr) + optim = AdamOptimizerFactory(lr=args.lr) # define policy and algorithm policy = 
IQNPolicy( diff --git a/examples/atari/atari_ppo.py b/examples/atari/atari_ppo.py index 3c407f522..cc82f589c 100644 --- a/examples/atari/atari_ppo.py +++ b/examples/atari/atari_ppo.py @@ -6,7 +6,6 @@ import numpy as np import torch -from torch.optim.lr_scheduler import LambdaLR from tianshou.data import Collector, CollectStats, VectorReplayBuffer from tianshou.env.atari.atari_network import DQNet, layer_init, scale_obs @@ -16,8 +15,8 @@ from tianshou.policy.base import Algorithm from tianshou.policy.modelbased.icm import ICMOnPolicyWrapper from tianshou.policy.modelfree.pg import DiscreteActorPolicy +from tianshou.policy.optim import AdamOptimizerFactory, LRSchedulerFactoryLinear from tianshou.trainer import OnPolicyTrainingConfig -from tianshou.utils.net.common import ActorCritic from tianshou.utils.net.discrete import Actor, Critic, IntrinsicCuriosityModule @@ -123,14 +122,16 @@ def main(args: argparse.Namespace = get_args()) -> None: net = scale_obs(net) actor = Actor(net, args.action_shape, device=args.device, softmax_output=False) critic = Critic(net, device=args.device) - optim = torch.optim.Adam(ActorCritic(actor, critic).parameters(), lr=args.lr, eps=1e-5) + optim = AdamOptimizerFactory(lr=args.lr, eps=1e-5) - lr_scheduler = None if args.lr_decay: - # decay learning rate to 0 linearly - max_update_num = np.ceil(args.step_per_epoch / args.step_per_collect) * args.epoch - - lr_scheduler = LambdaLR(optim, lr_lambda=lambda epoch: 1 - epoch / max_update_num) + optim.with_lr_scheduler_factory( + LRSchedulerFactoryLinear( + num_epochs=args.epoch, + step_per_epoch=args.step_per_epoch, + step_per_collect=args.step_per_collect, + ) + ) # define algorithm policy = DiscreteActorPolicy( @@ -147,7 +148,6 @@ def main(args: argparse.Namespace = get_args()) -> None: vf_coef=args.vf_coef, ent_coef=args.ent_coef, reward_normalization=args.rew_norm, - lr_scheduler=lr_scheduler, eps_clip=args.eps_clip, value_clip=args.value_clip, dual_clip=args.dual_clip, @@ -165,7 +165,7 @@ 
def main(args: argparse.Namespace = get_args()) -> None: hidden_sizes=[args.hidden_size], device=args.device, ) - icm_optim = torch.optim.Adam(icm_net.parameters(), lr=args.lr) + icm_optim = AdamOptimizerFactory(lr=args.lr) algorithm = ICMOnPolicyWrapper( # type: ignore[no-redef] wrapped_algorithm=algorithm, model=icm_net, diff --git a/examples/atari/atari_qrdqn.py b/examples/atari/atari_qrdqn.py index 52398fb58..579387b23 100644 --- a/examples/atari/atari_qrdqn.py +++ b/examples/atari/atari_qrdqn.py @@ -14,6 +14,7 @@ from tianshou.policy import QRDQN from tianshou.policy.base import Algorithm from tianshou.policy.modelfree.qrdqn import QRDQNPolicy +from tianshou.policy.optim import AdamOptimizerFactory from tianshou.trainer import OffPolicyTrainingConfig @@ -97,7 +98,7 @@ def main(args: argparse.Namespace = get_args()) -> None: ) # define policy and algorithm - optim = torch.optim.Adam(net.parameters(), lr=args.lr) + optim = AdamOptimizerFactory(lr=args.lr) policy = QRDQNPolicy( model=net, action_space=env.action_space, diff --git a/examples/atari/atari_rainbow.py b/examples/atari/atari_rainbow.py index 8caa9fd50..151b03d8f 100644 --- a/examples/atari/atari_rainbow.py +++ b/examples/atari/atari_rainbow.py @@ -19,6 +19,7 @@ from tianshou.policy import C51, RainbowDQN from tianshou.policy.base import Algorithm from tianshou.policy.modelfree.c51 import C51Policy +from tianshou.policy.optim import AdamOptimizerFactory from tianshou.trainer import OffPolicyTrainingConfig @@ -119,7 +120,7 @@ def test_rainbow(args: argparse.Namespace = get_args()) -> None: v_min=args.v_min, v_max=args.v_max, ) - optim = torch.optim.Adam(net.parameters(), lr=args.lr) + optim = AdamOptimizerFactory(lr=args.lr) algorithm: C51 = RainbowDQN( policy=policy, optim=optim, diff --git a/examples/atari/atari_sac.py b/examples/atari/atari_sac.py index f3407db63..646308393 100644 --- a/examples/atari/atari_sac.py +++ b/examples/atari/atari_sac.py @@ -15,6 +15,7 @@ from tianshou.policy.base import 
Algorithm from tianshou.policy.modelfree.discrete_sac import DiscreteSACPolicy from tianshou.policy.modelfree.sac import AutoAlpha +from tianshou.policy.optim import AdamOptimizerFactory from tianshou.trainer import OffPolicyTrainingConfig from tianshou.utils.net.discrete import Actor, Critic, IntrinsicCuriosityModule @@ -116,11 +117,11 @@ def test_discrete_sac(args: argparse.Namespace = get_args()) -> None: output_dim_added_layer=args.hidden_size, ) actor = Actor(net, args.action_shape, device=args.device, softmax_output=False) - actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr) + actor_optim = AdamOptimizerFactory(lr=args.actor_lr) critic1 = Critic(net, last_size=args.action_shape, device=args.device) - critic1_optim = torch.optim.Adam(critic1.parameters(), lr=args.critic_lr) + critic1_optim = AdamOptimizerFactory(lr=args.critic_lr) critic2 = Critic(net, last_size=args.action_shape, device=args.device) - critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr) + critic2_optim = AdamOptimizerFactory(lr=args.critic_lr) # define policy and algorithm if args.auto_alpha: @@ -156,7 +157,7 @@ def test_discrete_sac(args: argparse.Namespace = get_args()) -> None: hidden_sizes=[args.hidden_size], device=args.device, ) - icm_optim = torch.optim.Adam(icm_net.parameters(), lr=args.actor_lr) + icm_optim = AdamOptimizerFactory(lr=args.actor_lr) algorithm = ICMOffPolicyWrapper( wrapped_algorithm=algorithm, model=icm_net, diff --git a/test/continuous/test_ddpg.py b/test/continuous/test_ddpg.py index 1f4b11474..797479448 100644 --- a/test/continuous/test_ddpg.py +++ b/test/continuous/test_ddpg.py @@ -12,6 +12,7 @@ from tianshou.policy import DDPG from tianshou.policy.base import Algorithm from tianshou.policy.modelfree.ddpg import DDPGPolicy +from tianshou.policy.optim import AdamOptimizerFactory from tianshou.trainer.base import OffPolicyTrainingConfig from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Net @@ 
-62,16 +63,15 @@ def test_ddpg(args: argparse.Namespace = get_args()) -> None: args.task, env.spec.reward_threshold if env.spec else None, ) - # you can also use tianshou.env.SubprocVectorEnv - # train_envs = gym.make(args.task) train_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.training_num)]) - # test_envs = gym.make(args.task) test_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.test_num)]) + # seed np.random.seed(args.seed) torch.manual_seed(args.seed) train_envs.seed(args.seed) test_envs.seed(args.seed) + # model net = Net(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device) actor = Actor(net, args.action_shape, max_action=args.max_action, device=args.device).to( @@ -85,13 +85,13 @@ def test_ddpg(args: argparse.Namespace = get_args()) -> None: device=args.device, ) critic = Critic(net, device=args.device).to(args.device) - critic_optim = torch.optim.Adam(critic.parameters(), lr=args.critic_lr) + critic_optim = AdamOptimizerFactory(lr=args.critic_lr) policy = DDPGPolicy( actor=actor, exploration_noise=GaussianNoise(sigma=args.exploration_noise), action_space=env.action_space, ) - policy_optim = torch.optim.Adam(policy.parameters(), lr=args.actor_lr) + policy_optim = AdamOptimizerFactory(lr=args.actor_lr) algorithm: DDPG = DDPG( policy=policy, policy_optim=policy_optim, @@ -101,6 +101,7 @@ def test_ddpg(args: argparse.Namespace = get_args()) -> None: gamma=args.gamma, estimation_step=args.n_step, ) + # collector train_collector = Collector[CollectStats]( algorithm, diff --git a/test/continuous/test_npg.py b/test/continuous/test_npg.py index 4879afec6..966c30a7d 100644 --- a/test/continuous/test_npg.py +++ b/test/continuous/test_npg.py @@ -14,6 +14,7 @@ from tianshou.policy.base import Algorithm from tianshou.policy.modelfree.npg import NPGTrainingStats from tianshou.policy.modelfree.pg import ActorPolicy +from tianshou.policy.optim import AdamOptimizerFactory from tianshou.trainer.base 
import OnPolicyTrainingConfig from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Net @@ -67,16 +68,15 @@ def test_npg(args: argparse.Namespace = get_args()) -> None: args.task, env.spec.reward_threshold if env.spec else None, ) - # you can also use tianshou.env.SubprocVectorEnv - # train_envs = gym.make(args.task) train_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.training_num)]) - # test_envs = gym.make(args.task) test_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.test_num)]) + # seed np.random.seed(args.seed) torch.manual_seed(args.seed) train_envs.seed(args.seed) test_envs.seed(args.seed) + # model net = Net( args.state_shape, @@ -94,12 +94,12 @@ def test_npg(args: argparse.Namespace = get_args()) -> None: ), device=args.device, ).to(args.device) + # orthogonal initialization for m in list(actor.modules()) + list(critic.modules()): if isinstance(m, torch.nn.Linear): torch.nn.init.orthogonal_(m.weight) torch.nn.init.zeros_(m.bias) - optim = torch.optim.Adam(critic.parameters(), lr=args.lr) # replace DiagGuassian with Independent(Normal) which is equivalent # pass *logits to be consistent with policy.forward @@ -116,7 +116,7 @@ def dist(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution: algorithm: NPG[NPGTrainingStats] = NPG( policy=policy, critic=critic, - optim=optim, + optim=AdamOptimizerFactory(lr=args.lr), discount_factor=args.gamma, reward_normalization=args.rew_norm, advantage_normalization=args.norm_adv, diff --git a/test/continuous/test_ppo.py b/test/continuous/test_ppo.py index cd9e02e73..601fcaaae 100644 --- a/test/continuous/test_ppo.py +++ b/test/continuous/test_ppo.py @@ -12,6 +12,7 @@ from tianshou.policy import PPO from tianshou.policy.base import Algorithm from tianshou.policy.modelfree.pg import ActorPolicy +from tianshou.policy.optim import AdamOptimizerFactory from tianshou.trainer.base import OnPolicyTrainingConfig from tianshou.utils import 
TensorboardLogger from tianshou.utils.net.common import ActorCritic, Net @@ -72,16 +73,15 @@ def test_ppo(args: argparse.Namespace = get_args()) -> None: args.task, env.spec.reward_threshold if env.spec else None, ) - # you can also use tianshou.env.SubprocVectorEnv - # train_envs = gym.make(args.task) train_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.training_num)]) - # test_envs = gym.make(args.task) test_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.test_num)]) + # seed np.random.seed(args.seed) torch.manual_seed(args.seed) train_envs.seed(args.seed) test_envs.seed(args.seed) + # model net = Net(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device) actor = ActorProb(net, args.action_shape, unbounded=True, device=args.device).to(args.device) @@ -95,7 +95,7 @@ def test_ppo(args: argparse.Namespace = get_args()) -> None: if isinstance(m, torch.nn.Linear): torch.nn.init.orthogonal_(m.weight) torch.nn.init.zeros_(m.bias) - optim = torch.optim.Adam(actor_critic.parameters(), lr=args.lr) + optim = AdamOptimizerFactory(lr=args.lr) # replace DiagGuassian with Independent(Normal) which is equivalent # pass *logits to be consistent with policy.forward @@ -168,7 +168,7 @@ def save_checkpoint_fn(epoch: int, env_step: int, gradient_step: int) -> str: else: print("Fail to restore policy and optim.") - # trainer + # train result = algorithm.run_training( OnPolicyTrainingConfig( train_collector=train_collector, diff --git a/test/continuous/test_redq.py b/test/continuous/test_redq.py index 739bdba42..972a7bcad 100644 --- a/test/continuous/test_redq.py +++ b/test/continuous/test_redq.py @@ -12,6 +12,7 @@ from tianshou.policy import REDQ from tianshou.policy.base import Algorithm from tianshou.policy.modelfree.redq import REDQPolicy +from tianshou.policy.optim import AdamOptimizerFactory from tianshou.trainer.base import OffPolicyTrainingConfig from tianshou.utils import TensorboardLogger from 
tianshou.utils.net.common import EnsembleLinear, Net @@ -87,7 +88,7 @@ def test_redq(args: argparse.Namespace = get_args()) -> None: unbounded=True, conditioned_sigma=True, ).to(args.device) - actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr) + actor_optim = AdamOptimizerFactory(lr=args.actor_lr) def linear(x: int, y: int) -> nn.Module: return EnsembleLinear(args.ensemble_size, x, y) @@ -103,7 +104,7 @@ def linear(x: int, y: int) -> nn.Module: critic = Critic(net_c, device=args.device, linear_layer=linear, flatten_input=False).to( args.device, ) - critic_optim = torch.optim.Adam(critic.parameters(), lr=args.critic_lr) + critic_optim = AdamOptimizerFactory(lr=args.critic_lr) action_dim = space_info.action_info.action_dim if args.auto_alpha: diff --git a/test/continuous/test_sac_with_il.py b/test/continuous/test_sac_with_il.py index 3ad6f0ad3..5870899a8 100644 --- a/test/continuous/test_sac_with_il.py +++ b/test/continuous/test_sac_with_il.py @@ -12,6 +12,7 @@ from tianshou.policy.base import Algorithm from tianshou.policy.imitation.base import ImitationPolicy from tianshou.policy.modelfree.sac import AutoAlpha, SACPolicy +from tianshou.policy.optim import AdamOptimizerFactory from tianshou.trainer.base import OffPolicyTrainingConfig from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Net @@ -76,16 +77,17 @@ def test_sac_with_il(args: argparse.Namespace = get_args()) -> None: args.task, env.spec.reward_threshold if env.spec else None, ) - # you can also use tianshou.env.SubprocVectorEnv + # seed np.random.seed(args.seed) torch.manual_seed(args.seed) train_envs.seed(args.seed) test_envs.seed(args.seed + args.training_num) + # model net = Net(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device) actor = ActorProb(net, args.action_shape, device=args.device, unbounded=True).to(args.device) - actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr) + actor_optim = 
AdamOptimizerFactory(lr=args.actor_lr) net_c1 = Net( state_shape=args.state_shape, action_shape=args.action_shape, @@ -94,7 +96,7 @@ def test_sac_with_il(args: argparse.Namespace = get_args()) -> None: device=args.device, ) critic1 = Critic(net_c1, device=args.device).to(args.device) - critic1_optim = torch.optim.Adam(critic1.parameters(), lr=args.critic_lr) + critic1_optim = AdamOptimizerFactory(lr=args.critic_lr) net_c2 = Net( state_shape=args.state_shape, action_shape=args.action_shape, @@ -103,15 +105,13 @@ def test_sac_with_il(args: argparse.Namespace = get_args()) -> None: device=args.device, ) critic2 = Critic(net_c2, device=args.device).to(args.device) - critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr) - + critic2_optim = AdamOptimizerFactory(lr=args.critic_lr) action_dim = space_info.action_info.action_dim if args.auto_alpha: target_entropy = -action_dim log_alpha = torch.zeros(1, requires_grad=True, device=args.device) alpha_optim = torch.optim.Adam([log_alpha], lr=args.alpha_lr) args.alpha = AutoAlpha(target_entropy, log_alpha, alpha_optim) - policy = SACPolicy( actor=actor, action_space=env.action_space, @@ -180,7 +180,7 @@ def stop_fn(mean_rewards: float) -> bool: max_action=args.max_action, device=args.device, ).to(args.device) - optim = torch.optim.Adam(il_actor.parameters(), lr=args.il_lr) + optim = AdamOptimizerFactory(lr=args.il_lr) il_policy = ImitationPolicy( actor=il_actor, action_space=env.action_space, diff --git a/test/continuous/test_td3.py b/test/continuous/test_td3.py index e26f1f1c9..4e62caebe 100644 --- a/test/continuous/test_td3.py +++ b/test/continuous/test_td3.py @@ -12,6 +12,7 @@ from tianshou.policy import TD3 from tianshou.policy.base import Algorithm from tianshou.policy.modelfree.ddpg import DDPGPolicy +from tianshou.policy.optim import AdamOptimizerFactory from tianshou.trainer.base import OffPolicyTrainingConfig from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Net @@ 
-79,7 +80,7 @@ def test_td3(args: argparse.Namespace = get_args()) -> None: actor = Actor(net, args.action_shape, max_action=args.max_action, device=args.device).to( args.device, ) - actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr) + actor_optim = AdamOptimizerFactory(lr=args.actor_lr) net_c1 = Net( state_shape=args.state_shape, action_shape=args.action_shape, @@ -88,7 +89,7 @@ def test_td3(args: argparse.Namespace = get_args()) -> None: device=args.device, ) critic1 = Critic(net_c1, device=args.device).to(args.device) - critic1_optim = torch.optim.Adam(critic1.parameters(), lr=args.critic_lr) + critic1_optim = AdamOptimizerFactory(lr=args.critic_lr) net_c2 = Net( state_shape=args.state_shape, action_shape=args.action_shape, @@ -97,7 +98,7 @@ def test_td3(args: argparse.Namespace = get_args()) -> None: device=args.device, ) critic2 = Critic(net_c2, device=args.device).to(args.device) - critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr) + critic2_optim = AdamOptimizerFactory(lr=args.critic_lr) policy = DDPGPolicy( actor=actor, action_space=env.action_space, diff --git a/test/continuous/test_trpo.py b/test/continuous/test_trpo.py index 1f1d1fecb..cbcda10ed 100644 --- a/test/continuous/test_trpo.py +++ b/test/continuous/test_trpo.py @@ -13,6 +13,7 @@ from tianshou.policy import TRPO from tianshou.policy.base import Algorithm from tianshou.policy.modelfree.pg import ActorPolicy +from tianshou.policy.optim import AdamOptimizerFactory from tianshou.trainer.base import OnPolicyTrainingConfig from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Net @@ -98,7 +99,7 @@ def test_trpo(args: argparse.Namespace = get_args()) -> None: if isinstance(m, torch.nn.Linear): torch.nn.init.orthogonal_(m.weight) torch.nn.init.zeros_(m.bias) - optim = torch.optim.Adam(critic.parameters(), lr=args.lr) + optim = AdamOptimizerFactory(lr=args.lr) # replace DiagGuassian with Independent(Normal) which is equivalent # pass 
*logits to be consistent with policy.forward diff --git a/test/discrete/test_a2c_with_il.py b/test/discrete/test_a2c_with_il.py index f2070ce8e..7d0f5a1a2 100644 --- a/test/discrete/test_a2c_with_il.py +++ b/test/discrete/test_a2c_with_il.py @@ -13,9 +13,10 @@ from tianshou.policy.base import Algorithm from tianshou.policy.imitation.base import ImitationPolicy from tianshou.policy.modelfree.pg import ActorPolicy +from tianshou.policy.optim import AdamOptimizerFactory from tianshou.trainer.base import OffPolicyTrainingConfig, OnPolicyTrainingConfig from tianshou.utils import TensorboardLogger -from tianshou.utils.net.common import ActorCritic, Net +from tianshou.utils.net.common import Net from tianshou.utils.net.discrete import Actor, Critic try: @@ -94,7 +95,7 @@ def test_a2c_with_il(args: argparse.Namespace = get_args()) -> None: net = Net(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device) actor = Actor(net, args.action_shape, device=args.device).to(args.device) critic = Critic(net, device=args.device).to(args.device) - optim = torch.optim.Adam(ActorCritic(actor, critic).parameters(), lr=args.lr) + optim = AdamOptimizerFactory(lr=args.lr) dist = torch.distributions.Categorical policy = ActorPolicy( actor=actor, @@ -157,7 +158,7 @@ def stop_fn(mean_rewards: float) -> bool: # env.spec.reward_threshold = 190 # lower the goal net = Net(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device) actor = Actor(net, args.action_shape, device=args.device).to(args.device) - optim = torch.optim.Adam(actor.parameters(), lr=args.il_lr) + optim = AdamOptimizerFactory(lr=args.il_lr) il_policy = ImitationPolicy( actor=actor, action_space=env.action_space, diff --git a/test/discrete/test_bdqn.py b/test/discrete/test_bdqn.py index cbb8239de..dbb5b846d 100644 --- a/test/discrete/test_bdqn.py +++ b/test/discrete/test_bdqn.py @@ -8,6 +8,7 @@ from tianshou.env import ContinuousToDiscrete, DummyVectorEnv from tianshou.policy import 
BDQN from tianshou.policy.modelfree.bdqn import BDQNPolicy +from tianshou.policy.optim import AdamOptimizerFactory from tianshou.trainer.base import OffPolicyTrainingConfig from tianshou.utils.net.common import BranchingNet @@ -98,7 +99,7 @@ def test_bdq(args: argparse.Namespace = get_args()) -> None: args.action_hidden_sizes, device=args.device, ).to(args.device) - optim = torch.optim.Adam(net.parameters(), lr=args.lr) + optim = AdamOptimizerFactory(lr=args.lr) policy = BDQNPolicy( model=net, action_space=env.action_space, # type: ignore[arg-type] # TODO: should `BranchingDQNPolicy` support also `MultiDiscrete` action spaces? diff --git a/test/discrete/test_c51.py b/test/discrete/test_c51.py index 0023863d5..e6890d971 100644 --- a/test/discrete/test_c51.py +++ b/test/discrete/test_c51.py @@ -18,6 +18,7 @@ from tianshou.policy import C51 from tianshou.policy.base import Algorithm from tianshou.policy.modelfree.c51 import C51Policy +from tianshou.policy.optim import AdamOptimizerFactory from tianshou.trainer.base import OffPolicyTrainingConfig from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Net @@ -74,16 +75,15 @@ def test_c51(args: argparse.Namespace = get_args()) -> None: args.task, env.spec.reward_threshold if env.spec else None, ) - # train_envs = gym.make(args.task) - # you can also use tianshou.env.SubprocVectorEnv train_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.training_num)]) - # test_envs = gym.make(args.task) test_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.test_num)]) + # seed np.random.seed(args.seed) torch.manual_seed(args.seed) train_envs.seed(args.seed) test_envs.seed(args.seed) + # model net = Net( state_shape=args.state_shape, @@ -93,7 +93,7 @@ def test_c51(args: argparse.Namespace = get_args()) -> None: softmax=True, num_atoms=args.num_atoms, ) - optim = torch.optim.Adam(net.parameters(), lr=args.lr) + optim = AdamOptimizerFactory(lr=args.lr) policy = 
C51Policy( model=net, action_space=env.action_space, @@ -109,6 +109,7 @@ def test_c51(args: argparse.Namespace = get_args()) -> None: estimation_step=args.n_step, target_update_freq=args.target_update_freq, ).to(args.device) + # buffer buf: ReplayBuffer if args.prioritized_replay: @@ -120,13 +121,15 @@ def test_c51(args: argparse.Namespace = get_args()) -> None: ) else: buf = VectorReplayBuffer(args.buffer_size, buffer_num=len(train_envs)) - # collector + + # collectors train_collector = Collector[CollectStats](algorithm, train_envs, buf, exploration_noise=True) test_collector = Collector[CollectStats](algorithm, test_envs, exploration_noise=True) # policy.set_eps(1) train_collector.reset() train_collector.collect(n_step=args.batch_size * args.training_num) - # log + + # logger log_path = os.path.join(args.logdir, args.task, "c51") writer = SummaryWriter(log_path) logger = TensorboardLogger(writer, save_interval=args.save_interval) diff --git a/test/discrete/test_discrete_sac.py b/test/discrete/test_discrete_sac.py index 8e1f49957..0efdf6f0e 100644 --- a/test/discrete/test_discrete_sac.py +++ b/test/discrete/test_discrete_sac.py @@ -14,6 +14,7 @@ DiscreteSACPolicy, DiscreteSACTrainingStats, ) +from tianshou.policy.optim import AdamOptimizerFactory from tianshou.trainer.base import OffPolicyTrainingConfig from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Net @@ -80,13 +81,13 @@ def test_discrete_sac(args: argparse.Namespace = get_args()) -> None: action_dim = space_info.action_info.action_dim net = Net(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device) actor = Actor(net, args.action_shape, softmax_output=False, device=args.device).to(args.device) - actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr) + actor_optim = AdamOptimizerFactory(lr=args.actor_lr) net_c1 = Net(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device) critic1 = Critic(net_c1, 
last_size=action_dim, device=args.device).to(args.device) - critic1_optim = torch.optim.Adam(critic1.parameters(), lr=args.critic_lr) + critic1_optim = AdamOptimizerFactory(lr=args.critic_lr) net_c2 = Net(obs_dim, hidden_sizes=args.hidden_sizes, device=args.device) critic2 = Critic(net_c2, last_size=action_dim, device=args.device).to(args.device) - critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr) + critic2_optim = AdamOptimizerFactory(lr=args.critic_lr) # better not to use auto alpha in CartPole if args.auto_alpha: diff --git a/test/discrete/test_dqn.py b/test/discrete/test_dqn.py index 6e9e81903..ec0583b6a 100644 --- a/test/discrete/test_dqn.py +++ b/test/discrete/test_dqn.py @@ -17,6 +17,7 @@ from tianshou.policy import DQN from tianshou.policy.base import Algorithm from tianshou.policy.modelfree.dqn import DQNPolicy +from tianshou.policy.optim import AdamOptimizerFactory from tianshou.trainer.base import OffPolicyTrainingConfig from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Net @@ -68,16 +69,15 @@ def test_dqn(args: argparse.Namespace = get_args()) -> None: args.task, env.spec.reward_threshold if env.spec else None, ) - # train_envs = gym.make(args.task) - # you can also use tianshou.env.SubprocVectorEnv train_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.training_num)]) - # test_envs = gym.make(args.task) test_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.test_num)]) + # seed np.random.seed(args.seed) torch.manual_seed(args.seed) train_envs.seed(args.seed) test_envs.seed(args.seed) + # Q_param = V_param = {"hidden_sizes": [128]} # model net = Net( @@ -87,7 +87,7 @@ def test_dqn(args: argparse.Namespace = get_args()) -> None: device=args.device, # dueling=(Q_param, V_param), ).to(args.device) - optim = torch.optim.Adam(net.parameters(), lr=args.lr) + optim = AdamOptimizerFactory(lr=args.lr) policy = DQNPolicy( model=net, action_space=env.action_space, 
observation_space=env.observation_space ) @@ -98,6 +98,7 @@ def test_dqn(args: argparse.Namespace = get_args()) -> None: estimation_step=args.n_step, target_update_freq=args.target_update_freq, ) + # buffer buf: ReplayBuffer if args.prioritized_replay: @@ -109,13 +110,15 @@ def test_dqn(args: argparse.Namespace = get_args()) -> None: ) else: buf = VectorReplayBuffer(args.buffer_size, buffer_num=len(train_envs)) - # collector + + # collectors train_collector = Collector[CollectStats](algorithm, train_envs, buf, exploration_noise=True) test_collector = Collector[CollectStats](algorithm, test_envs, exploration_noise=True) # policy.set_eps(1) train_collector.reset() train_collector.collect(n_step=args.batch_size * args.training_num) - # log + + # logger log_path = os.path.join(args.logdir, args.task, "dqn") writer = SummaryWriter(log_path) logger = TensorboardLogger(writer) diff --git a/test/discrete/test_fqf.py b/test/discrete/test_fqf.py index 34c891bd5..193d7d77c 100644 --- a/test/discrete/test_fqf.py +++ b/test/discrete/test_fqf.py @@ -17,6 +17,7 @@ from tianshou.policy import FQF from tianshou.policy.base import Algorithm from tianshou.policy.modelfree.fqf import FQFPolicy +from tianshou.policy.optim import AdamOptimizerFactory, RMSPropOptimizerFactory from tianshou.trainer.base import OffPolicyTrainingConfig from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Net @@ -73,16 +74,15 @@ def test_fqf(args: argparse.Namespace = get_args()) -> None: args.task, env.spec.reward_threshold if env.spec else None, ) - # train_envs = gym.make(args.task) - # you can also use tianshou.env.SubprocVectorEnv train_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.training_num)]) - # test_envs = gym.make(args.task) test_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.test_num)]) + # seed np.random.seed(args.seed) torch.manual_seed(args.seed) train_envs.seed(args.seed) test_envs.seed(args.seed) + # model 
feature_net = Net( args.state_shape, @@ -98,9 +98,9 @@ def test_fqf(args: argparse.Namespace = get_args()) -> None: num_cosines=args.num_cosines, device=args.device, ) - optim = torch.optim.Adam(net.parameters(), lr=args.lr) + optim = AdamOptimizerFactory(lr=args.lr) fraction_net = FractionProposalNetwork(args.num_fractions, net.input_dim) - fraction_optim = torch.optim.RMSprop(fraction_net.parameters(), lr=args.fraction_lr) + fraction_optim = RMSPropOptimizerFactory(lr=args.fraction_lr) policy = FQFPolicy( model=net, fraction_model=fraction_net, @@ -116,6 +116,7 @@ def test_fqf(args: argparse.Namespace = get_args()) -> None: estimation_step=args.n_step, target_update_freq=args.target_update_freq, ).to(args.device) + # buffer buf: ReplayBuffer if args.prioritized_replay: @@ -127,13 +128,15 @@ def test_fqf(args: argparse.Namespace = get_args()) -> None: ) else: buf = VectorReplayBuffer(args.buffer_size, buffer_num=len(train_envs)) - # collector + + # collectors train_collector = Collector[CollectStats](algorithm, train_envs, buf, exploration_noise=True) test_collector = Collector[CollectStats](algorithm, test_envs, exploration_noise=True) # policy.set_eps(1) train_collector.reset() train_collector.collect(n_step=args.batch_size * args.training_num) - # log + + # logger log_path = os.path.join(args.logdir, args.task, "fqf") writer = SummaryWriter(log_path) logger = TensorboardLogger(writer) diff --git a/test/discrete/test_iqn.py b/test/discrete/test_iqn.py index bd6fd21d5..2e5dda9cd 100644 --- a/test/discrete/test_iqn.py +++ b/test/discrete/test_iqn.py @@ -17,6 +17,7 @@ from tianshou.policy import IQN from tianshou.policy.base import Algorithm from tianshou.policy.modelfree.iqn import IQNPolicy +from tianshou.policy.optim import AdamOptimizerFactory from tianshou.trainer.base import OffPolicyTrainingConfig from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Net @@ -73,16 +74,15 @@ def test_iqn(args: argparse.Namespace = get_args()) -> 
None: args.task, env.spec.reward_threshold if env.spec else None, ) - # train_envs = gym.make(args.task) - # you can also use tianshou.env.SubprocVectorEnv train_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.training_num)]) - # test_envs = gym.make(args.task) test_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.test_num)]) + # seed np.random.seed(args.seed) torch.manual_seed(args.seed) train_envs.seed(args.seed) test_envs.seed(args.seed) + # model feature_net = Net( args.state_shape, @@ -97,7 +97,7 @@ def test_iqn(args: argparse.Namespace = get_args()) -> None: num_cosines=args.num_cosines, device=args.device, ) - optim = torch.optim.Adam(net.parameters(), lr=args.lr) + optim = AdamOptimizerFactory(lr=args.lr) policy = IQNPolicy( model=net, action_space=env.action_space, @@ -112,6 +112,7 @@ def test_iqn(args: argparse.Namespace = get_args()) -> None: estimation_step=args.n_step, target_update_freq=args.target_update_freq, ).to(args.device) + # buffer buf: ReplayBuffer if args.prioritized_replay: @@ -123,13 +124,15 @@ def test_iqn(args: argparse.Namespace = get_args()) -> None: ) else: buf = VectorReplayBuffer(args.buffer_size, buffer_num=len(train_envs)) - # collector + + # collectors train_collector = Collector[CollectStats](algorithm, train_envs, buf, exploration_noise=True) test_collector = Collector[CollectStats](algorithm, test_envs, exploration_noise=True) # policy.set_eps(1) train_collector.reset() train_collector.collect(n_step=args.batch_size * args.training_num) - # log + + # logger log_path = os.path.join(args.logdir, args.task, "iqn") writer = SummaryWriter(log_path) logger = TensorboardLogger(writer) diff --git a/test/discrete/test_pg.py b/test/discrete/test_pg.py index 5192c174d..a3edecb2a 100644 --- a/test/discrete/test_pg.py +++ b/test/discrete/test_pg.py @@ -12,6 +12,7 @@ from tianshou.policy import Reinforce from tianshou.policy.base import Algorithm from tianshou.policy.modelfree.pg import 
ActorPolicy +from tianshou.policy.optim import AdamOptimizerFactory from tianshou.trainer.base import OnPolicyTrainingConfig from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Net @@ -56,16 +57,15 @@ def test_pg(args: argparse.Namespace = get_args()) -> None: args.task, env.spec.reward_threshold if env.spec else None, ) - # train_envs = gym.make(args.task) - # you can also use tianshou.env.SubprocVectorEnv train_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.training_num)]) - # test_envs = gym.make(args.task) test_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.test_num)]) + # seed np.random.seed(args.seed) torch.manual_seed(args.seed) train_envs.seed(args.seed) test_envs.seed(args.seed) + # model net = Net( state_shape=args.state_shape, @@ -74,7 +74,7 @@ def test_pg(args: argparse.Namespace = get_args()) -> None: device=args.device, softmax=True, ).to(args.device) - optim = torch.optim.Adam(net.parameters(), lr=args.lr) + optim = AdamOptimizerFactory(lr=args.lr) dist_fn = torch.distributions.Categorical policy = ActorPolicy( actor=net, @@ -93,6 +93,7 @@ def test_pg(args: argparse.Namespace = get_args()) -> None: # orthogonal initialization torch.nn.init.orthogonal_(m.weight, gain=np.sqrt(2)) torch.nn.init.zeros_(m.bias) + # collector train_collector = Collector[CollectStats]( algorithm, @@ -100,6 +101,7 @@ def test_pg(args: argparse.Namespace = get_args()) -> None: VectorReplayBuffer(args.buffer_size, len(train_envs)), ) test_collector = Collector[CollectStats](algorithm, test_envs) + # log log_path = os.path.join(args.logdir, args.task, "pg") writer = SummaryWriter(log_path) diff --git a/test/discrete/test_qrdqn.py b/test/discrete/test_qrdqn.py index cc15effdc..867c528a8 100644 --- a/test/discrete/test_qrdqn.py +++ b/test/discrete/test_qrdqn.py @@ -16,6 +16,7 @@ from tianshou.policy import QRDQN from tianshou.policy.base import Algorithm from tianshou.policy.modelfree.qrdqn import 
QRDQNPolicy +from tianshou.policy.optim import AdamOptimizerFactory from tianshou.trainer.base import OffPolicyTrainingConfig from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Net @@ -91,7 +92,7 @@ def test_qrdqn(args: argparse.Namespace = get_args()) -> None: softmax=False, num_atoms=args.num_quantiles, ) - optim = torch.optim.Adam(net.parameters(), lr=args.lr) + optim = AdamOptimizerFactory(lr=args.lr) policy = QRDQNPolicy( model=net, action_space=env.action_space, observation_space=env.observation_space ) @@ -103,6 +104,7 @@ def test_qrdqn(args: argparse.Namespace = get_args()) -> None: estimation_step=args.n_step, target_update_freq=args.target_update_freq, ).to(args.device) + # buffer buf: PrioritizedVectorReplayBuffer | VectorReplayBuffer if args.prioritized_replay: @@ -114,13 +116,15 @@ def test_qrdqn(args: argparse.Namespace = get_args()) -> None: ) else: buf = VectorReplayBuffer(args.buffer_size, buffer_num=len(train_envs)) - # collector + + # collectors train_collector = Collector[CollectStats](algorithm, train_envs, buf, exploration_noise=True) test_collector = Collector[CollectStats](algorithm, test_envs, exploration_noise=True) # policy.set_eps(1) train_collector.reset() train_collector.collect(n_step=args.batch_size * args.training_num) - # log + + # logger log_path = os.path.join(args.logdir, args.task, "qrdqn") writer = SummaryWriter(log_path) logger = TensorboardLogger(writer) @@ -144,7 +148,7 @@ def train_fn(epoch: int, env_step: int) -> None: def test_fn(epoch: int, env_step: int | None) -> None: policy.set_eps(args.eps_test) - # trainer + # train result = algorithm.run_training( OffPolicyTrainingConfig( train_collector=train_collector, diff --git a/test/discrete/test_rainbow.py b/test/discrete/test_rainbow.py index f1fdf8bcf..915ef85ab 100644 --- a/test/discrete/test_rainbow.py +++ b/test/discrete/test_rainbow.py @@ -17,6 +17,7 @@ from tianshou.policy import RainbowDQN from tianshou.policy.base import Algorithm 
from tianshou.policy.modelfree.c51 import C51Policy +from tianshou.policy.optim import AdamOptimizerFactory from tianshou.trainer.base import OffPolicyTrainingConfig from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Net @@ -89,11 +90,10 @@ def test_rainbow(args: argparse.Namespace = get_args()) -> None: train_envs.seed(args.seed) test_envs.seed(args.seed) - # model - def noisy_linear(x: int, y: int) -> NoisyLinear: return NoisyLinear(x, y, args.noisy_std) + # model net = Net( state_shape=args.state_shape, action_shape=args.action_shape, @@ -103,7 +103,7 @@ def noisy_linear(x: int, y: int) -> NoisyLinear: num_atoms=args.num_atoms, dueling_param=({"linear_layer": noisy_linear}, {"linear_layer": noisy_linear}), ) - optim = torch.optim.Adam(net.parameters(), lr=args.lr) + optim = AdamOptimizerFactory(lr=args.lr) policy = C51Policy( model=net, action_space=env.action_space, @@ -118,6 +118,7 @@ def noisy_linear(x: int, y: int) -> NoisyLinear: estimation_step=args.n_step, target_update_freq=args.target_update_freq, ).to(args.device) + # buffer buf: PrioritizedVectorReplayBuffer | VectorReplayBuffer if args.prioritized_replay: @@ -130,13 +131,15 @@ def noisy_linear(x: int, y: int) -> NoisyLinear: ) else: buf = VectorReplayBuffer(args.buffer_size, buffer_num=len(train_envs)) - # collector + + # collectors train_collector = Collector[CollectStats](algorithm, train_envs, buf, exploration_noise=True) test_collector = Collector[CollectStats](algorithm, test_envs, exploration_noise=True) # policy.set_eps(1) train_collector.reset() train_collector.collect(n_step=args.batch_size * args.training_num) - # log + + # logger log_path = os.path.join(args.logdir, args.task, "rainbow") writer = SummaryWriter(log_path) logger = TensorboardLogger(writer, save_interval=args.save_interval) @@ -177,7 +180,7 @@ def save_checkpoint_fn(epoch: int, env_step: int, gradient_step: int) -> str: torch.save( { "model": algorithm.state_dict(), - "optim": 
optim.state_dict(), + "optim": algorithm.optim.state_dict(), }, ckpt_path, ) @@ -205,7 +208,7 @@ def save_checkpoint_fn(epoch: int, env_step: int, gradient_step: int) -> str: else: print("Fail to restore buffer.") - # trainer + # train result = algorithm.run_training( OffPolicyTrainingConfig( train_collector=train_collector, diff --git a/test/modelbased/test_dqn_icm.py b/test/modelbased/test_dqn_icm.py index b88379b9d..bf6d0e317 100644 --- a/test/modelbased/test_dqn_icm.py +++ b/test/modelbased/test_dqn_icm.py @@ -15,6 +15,7 @@ from tianshou.env import DummyVectorEnv from tianshou.policy import DQN, ICMOffPolicyWrapper from tianshou.policy.modelfree.dqn import DQNPolicy +from tianshou.policy.optim import AdamOptimizerFactory from tianshou.trainer import OffPolicyTrainingConfig from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import MLP, Net @@ -87,16 +88,15 @@ def test_dqn_icm(args: argparse.Namespace = get_args()) -> None: args.task, env.spec.reward_threshold if env.spec else None, ) - # train_envs = gym.make(args.task) - # you can also use tianshou.env.SubprocVectorEnv train_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.training_num)]) - # test_envs = gym.make(args.task) test_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.test_num)]) + # seed np.random.seed(args.seed) torch.manual_seed(args.seed) train_envs.seed(args.seed) test_envs.seed(args.seed) + # Q_param = V_param = {"hidden_sizes": [128]} # model net = Net( @@ -106,7 +106,7 @@ def test_dqn_icm(args: argparse.Namespace = get_args()) -> None: device=args.device, # dueling=(Q_param, V_param), ).to(args.device) - optim = torch.optim.Adam(net.parameters(), lr=args.lr) + optim = AdamOptimizerFactory(lr=args.lr) policy = DQNPolicy( model=net, action_space=env.action_space, @@ -136,8 +136,8 @@ def test_dqn_icm(args: argparse.Namespace = get_args()) -> None: hidden_sizes=args.hidden_sizes[-1:], device=args.device, ).to(args.device) - 
icm_optim = torch.optim.Adam(icm_net.parameters(), lr=args.lr) - icm_algorithm: ICMOffPolicyWrapper = ICMOffPolicyWrapper( + icm_optim = AdamOptimizerFactory(lr=args.lr) + icm_algorithm = ICMOffPolicyWrapper( wrapped_algorithm=algorithm, model=icm_net, optim=icm_optim, @@ -145,6 +145,7 @@ def test_dqn_icm(args: argparse.Namespace = get_args()) -> None: reward_scale=args.reward_scale, forward_loss_weight=args.forward_loss_weight, ) + # buffer buf: PrioritizedVectorReplayBuffer | VectorReplayBuffer if args.prioritized_replay: @@ -162,8 +163,6 @@ def test_dqn_icm(args: argparse.Namespace = get_args()) -> None: icm_algorithm, train_envs, buf, exploration_noise=True ) test_collector = Collector[CollectStats](icm_algorithm, test_envs, exploration_noise=True) - - # policy.set_eps(1) train_collector.reset() train_collector.collect(n_step=args.batch_size * args.training_num) diff --git a/test/modelbased/test_ppo_icm.py b/test/modelbased/test_ppo_icm.py index 1eaff9e1b..cedcc662d 100644 --- a/test/modelbased/test_ppo_icm.py +++ b/test/modelbased/test_ppo_icm.py @@ -13,6 +13,7 @@ from tianshou.policy.base import Algorithm from tianshou.policy.modelbased.icm import ICMOnPolicyWrapper from tianshou.policy.modelfree.pg import ActorPolicy +from tianshou.policy.optim import AdamOptimizerFactory from tianshou.trainer import OnPolicyTrainingConfig from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import MLP, ActorCritic, Net @@ -111,7 +112,7 @@ def test_ppo(args: argparse.Namespace = get_args()) -> None: torch.nn.init.zeros_(m.bias) # base algorithm: PPO - optim = torch.optim.Adam(actor_critic.parameters(), lr=args.lr) + optim = AdamOptimizerFactory(lr=args.lr) dist = torch.distributions.Categorical policy = ActorPolicy( actor=actor, @@ -153,7 +154,7 @@ def test_ppo(args: argparse.Namespace = get_args()) -> None: hidden_sizes=args.hidden_sizes[-1:], device=args.device, ).to(args.device) - icm_optim = torch.optim.Adam(icm_net.parameters(), lr=args.lr) + 
icm_optim = AdamOptimizerFactory(lr=args.lr) icm_algorithm = ICMOnPolicyWrapper( wrapped_algorithm=algorithm, model=icm_net, diff --git a/test/modelbased/test_psrl.py b/test/modelbased/test_psrl.py index a150df52a..ead82de46 100644 --- a/test/modelbased/test_psrl.py +++ b/test/modelbased/test_psrl.py @@ -50,7 +50,6 @@ def get_args() -> argparse.Namespace: reason="EnvPool is not installed. If on linux, please install it (e.g. as poetry extra)", ) def test_psrl(args: argparse.Namespace = get_args()) -> None: - # if you want to use python vector env, please refer to other test scripts train_envs = env = envpool.make_gymnasium(args.task, num_envs=args.training_num, seed=args.seed) test_envs = envpool.make_gymnasium(args.task, num_envs=args.test_num, seed=args.seed) if args.reward_threshold is None: @@ -59,9 +58,11 @@ def test_psrl(args: argparse.Namespace = get_args()) -> None: print("reward threshold:", args.reward_threshold) args.state_shape = env.observation_space.shape or env.observation_space.n args.action_shape = env.action_space.shape or env.action_space.n + # seed np.random.seed(args.seed) torch.manual_seed(args.seed) + # model n_action = args.action_shape n_state = args.state_shape @@ -80,6 +81,7 @@ def test_psrl(args: argparse.Namespace = get_args()) -> None: policy=policy, add_done_loop=args.add_done_loop, ) + # collector train_collector = Collector[CollectStats]( algorithm, @@ -90,6 +92,7 @@ def test_psrl(args: argparse.Namespace = get_args()) -> None: train_collector.reset() test_collector = Collector[CollectStats](algorithm, test_envs) test_collector.reset() + # Logger log_path = os.path.join(args.logdir, args.task, "psrl") writer = SummaryWriter(log_path) @@ -107,7 +110,8 @@ def stop_fn(mean_rewards: float) -> bool: return mean_rewards >= args.reward_threshold train_collector.collect(n_step=args.buffer_size, random=True) - # train (test it without logger) + + # train result = algorithm.run_training( OnPolicyTrainingConfig( 
train_collector=train_collector, diff --git a/test/offline/test_bcq.py b/test/offline/test_bcq.py index 6c87f64c6..0112ae244 100644 --- a/test/offline/test_bcq.py +++ b/test/offline/test_bcq.py @@ -13,6 +13,7 @@ from tianshou.env import DummyVectorEnv from tianshou.policy import BCQ, Algorithm from tianshou.policy.imitation.bcq import BCQPolicy, BCQTrainingStats +from tianshou.policy.optim import AdamOptimizerFactory from tianshou.trainer.base import OfflineTrainingConfig from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import MLP, Net @@ -90,12 +91,12 @@ def test_bcq(args: argparse.Namespace = get_args()) -> None: # test_envs = gym.make(args.task) test_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.test_num)]) + # seed np.random.seed(args.seed) torch.manual_seed(args.seed) test_envs.seed(args.seed) - # model # perturbation network net_a = MLP( input_dim=args.state_dim + args.action_dim, @@ -106,7 +107,7 @@ def test_bcq(args: argparse.Namespace = get_args()) -> None: actor = Perturbation(net_a, max_action=args.max_action, device=args.device, phi=args.phi).to( args.device, ) - actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr) + actor_optim = AdamOptimizerFactory(lr=args.actor_lr) net_c = Net( state_shape=args.state_shape, @@ -116,7 +117,7 @@ def test_bcq(args: argparse.Namespace = get_args()) -> None: device=args.device, ) critic = Critic(net_c, device=args.device).to(args.device) - critic_optim = torch.optim.Adam(critic.parameters(), lr=args.critic_lr) + critic_optim = AdamOptimizerFactory(lr=args.critic_lr) # vae # output_dim = 0, so the last Module in the encoder is ReLU @@ -141,7 +142,7 @@ def test_bcq(args: argparse.Namespace = get_args()) -> None: max_action=args.max_action, device=args.device, ).to(args.device) - vae_optim = torch.optim.Adam(vae.parameters()) + vae_optim = AdamOptimizerFactory() policy = BCQPolicy( actor_perturbation=actor, @@ -168,7 +169,8 @@ def test_bcq(args: 
argparse.Namespace = get_args()) -> None: # buffer has been gathered # train_collector = Collector[CollectStats](policy, train_envs, buffer, exploration_noise=True) test_collector = Collector[CollectStats](algorithm, test_envs) - # log + + # logger t0 = datetime.datetime.now().strftime("%m%d_%H%M%S") log_file = f'seed_{args.seed}_{t0}-{args.task.replace("-", "_")}_bcq' log_path = os.path.join(args.logdir, args.task, "bcq", log_file) @@ -189,7 +191,7 @@ def watch() -> None: collector = Collector[CollectStats](algorithm, env) collector.collect(n_episode=1, render=1 / 35) - # trainer + # train result = algorithm.run_training( OfflineTrainingConfig( buffer=buffer, diff --git a/test/offline/test_cql.py b/test/offline/test_cql.py index 4df3cae4d..d49134df6 100644 --- a/test/offline/test_cql.py +++ b/test/offline/test_cql.py @@ -14,6 +14,7 @@ from tianshou.policy import CQL, Algorithm from tianshou.policy.imitation.cql import CQLTrainingStats from tianshou.policy.modelfree.sac import AutoAlpha, SACPolicy +from tianshou.policy.optim import AdamOptimizerFactory from tianshou.trainer import OfflineTrainingConfig from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Net @@ -117,7 +118,7 @@ def test_cql(args: argparse.Namespace = get_args()) -> None: unbounded=True, conditioned_sigma=True, ).to(args.device) - actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr) + actor_optim = AdamOptimizerFactory(lr=args.actor_lr) # critic network net_c = Net( @@ -128,7 +129,7 @@ def test_cql(args: argparse.Namespace = get_args()) -> None: device=args.device, ) critic = Critic(net_c, device=args.device).to(args.device) - critic_optim = torch.optim.Adam(critic.parameters(), lr=args.critic_lr) + critic_optim = AdamOptimizerFactory(lr=args.critic_lr) if args.auto_alpha: target_entropy = -np.prod(args.action_shape) diff --git a/test/offline/test_discrete_bcq.py b/test/offline/test_discrete_bcq.py index 1dd61ecb9..773ecd1fe 100644 --- 
a/test/offline/test_discrete_bcq.py +++ b/test/offline/test_discrete_bcq.py @@ -17,9 +17,10 @@ from tianshou.env import DummyVectorEnv from tianshou.policy import Algorithm, DiscreteBCQ from tianshou.policy.imitation.discrete_bcq import DiscreteBCQPolicy +from tianshou.policy.optim import AdamOptimizerFactory from tianshou.trainer import OfflineTrainingConfig from tianshou.utils import TensorboardLogger -from tianshou.utils.net.common import ActorCritic, Net +from tianshou.utils.net.common import Net from tianshou.utils.net.discrete import Actor from tianshou.utils.space_info import SpaceInfo @@ -68,10 +69,12 @@ def test_discrete_bcq(args: argparse.Namespace = get_args()) -> None: env.spec.reward_threshold if env.spec else None, ) test_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.test_num)]) + # seed np.random.seed(args.seed) torch.manual_seed(args.seed) test_envs.seed(args.seed) + # model net = Net(state_shape=args.state_shape, action_shape=args.hidden_sizes[0], device=args.device) policy_net = Actor( @@ -86,9 +89,7 @@ def test_discrete_bcq(args: argparse.Namespace = get_args()) -> None: hidden_sizes=args.hidden_sizes, device=args.device, ).to(args.device) - actor_critic = ActorCritic(policy_net, imitation_net) - optim = torch.optim.Adam(actor_critic.parameters(), lr=args.lr) - + optim = AdamOptimizerFactory(lr=args.lr) policy = DiscreteBCQPolicy( model=policy_net, imitator=imitation_net, @@ -104,6 +105,7 @@ def test_discrete_bcq(args: argparse.Namespace = get_args()) -> None: eval_eps=args.eps_test, imitation_logits_penalty=args.imitation_logits_penalty, ) + # buffer buffer: VectorReplayBuffer | PrioritizedVectorReplayBuffer if os.path.exists(args.load_buffer_name) and os.path.isfile(args.load_buffer_name): @@ -118,6 +120,7 @@ def test_discrete_bcq(args: argparse.Namespace = get_args()) -> None: # collector test_collector = Collector[CollectStats](algorithm, test_envs, exploration_noise=True) + # logger log_path = 
os.path.join(args.logdir, args.task, "discrete_bcq") writer = SummaryWriter(log_path) logger = TensorboardLogger(writer, save_interval=args.save_interval) @@ -154,6 +157,7 @@ def save_checkpoint_fn(epoch: int, env_step: int, gradient_step: int) -> str: else: print("Fail to restore policy and optim.") + # train result = algorithm.run_training( OfflineTrainingConfig( buffer=buffer, diff --git a/test/offline/test_discrete_cql.py b/test/offline/test_discrete_cql.py index 142fffd4f..f071c9d30 100644 --- a/test/offline/test_discrete_cql.py +++ b/test/offline/test_discrete_cql.py @@ -17,6 +17,7 @@ from tianshou.env import DummyVectorEnv from tianshou.policy import Algorithm, DiscreteCQL from tianshou.policy.modelfree.qrdqn import QRDQNPolicy +from tianshou.policy.optim import AdamOptimizerFactory from tianshou.trainer import OfflineTrainingConfig from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Net @@ -65,10 +66,12 @@ def test_discrete_cql(args: argparse.Namespace = get_args()) -> None: env.spec.reward_threshold if env.spec else None, ) test_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.test_num)]) + # seed np.random.seed(args.seed) torch.manual_seed(args.seed) test_envs.seed(args.seed) + # model net = Net( state_shape=args.state_shape, @@ -78,7 +81,7 @@ def test_discrete_cql(args: argparse.Namespace = get_args()) -> None: softmax=False, num_atoms=args.num_quantiles, ) - optim = torch.optim.Adam(net.parameters(), lr=args.lr) + optim = AdamOptimizerFactory(lr=args.lr) policy = QRDQNPolicy( model=net, @@ -93,6 +96,7 @@ def test_discrete_cql(args: argparse.Namespace = get_args()) -> None: target_update_freq=args.target_update_freq, min_q_weight=args.min_q_weight, ).to(args.device) + # buffer buffer: VectorReplayBuffer | PrioritizedVectorReplayBuffer if os.path.exists(args.load_buffer_name) and os.path.isfile(args.load_buffer_name): @@ -117,6 +121,7 @@ def save_best_fn(policy: Algorithm) -> None: def 
stop_fn(mean_rewards: float) -> bool: return mean_rewards >= args.reward_threshold + # train result = algorithm.run_training( OfflineTrainingConfig( buffer=buffer, diff --git a/test/offline/test_discrete_crr.py b/test/offline/test_discrete_crr.py index a873df1a0..2fe4aeb72 100644 --- a/test/offline/test_discrete_crr.py +++ b/test/offline/test_discrete_crr.py @@ -17,9 +17,10 @@ from tianshou.env import DummyVectorEnv from tianshou.policy import Algorithm, DiscreteCRR from tianshou.policy.modelfree.pg import DiscreteActorPolicy +from tianshou.policy.optim import AdamOptimizerFactory from tianshou.trainer import OfflineTrainingConfig from tianshou.utils import TensorboardLogger -from tianshou.utils.net.common import ActorCritic, Net +from tianshou.utils.net.common import Net from tianshou.utils.net.discrete import Actor, Critic from tianshou.utils.space_info import SpaceInfo @@ -63,11 +64,13 @@ def test_discrete_crr(args: argparse.Namespace = get_args()) -> None: env.spec.reward_threshold if env.spec else None, ) test_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.test_num)]) + # seed np.random.seed(args.seed) torch.manual_seed(args.seed) test_envs.seed(args.seed) - # model + + # model and algorithm net = Net(state_shape=args.state_shape, action_shape=args.hidden_sizes[0], device=args.device) actor = Actor( preprocess_net=net, @@ -83,9 +86,7 @@ def test_discrete_crr(args: argparse.Namespace = get_args()) -> None: last_size=action_dim, device=args.device, ) - actor_critic = ActorCritic(actor, critic) - optim = torch.optim.Adam(actor_critic.parameters(), lr=args.lr) - + optim = AdamOptimizerFactory(lr=args.lr) policy = DiscreteActorPolicy( actor=actor, action_space=env.action_space, @@ -97,6 +98,7 @@ def test_discrete_crr(args: argparse.Namespace = get_args()) -> None: discount_factor=args.gamma, target_update_freq=args.target_update_freq, ).to(args.device) + # buffer buffer: VectorReplayBuffer | PrioritizedVectorReplayBuffer if 
os.path.exists(args.load_buffer_name) and os.path.isfile(args.load_buffer_name): @@ -121,6 +123,7 @@ def save_best_fn(policy: Algorithm) -> None: def stop_fn(mean_rewards: float) -> bool: return mean_rewards >= args.reward_threshold + # train result = algorithm.run_training( OfflineTrainingConfig( buffer=buffer, diff --git a/test/offline/test_gail.py b/test/offline/test_gail.py index c83c22aaf..d9cab8723 100644 --- a/test/offline/test_gail.py +++ b/test/offline/test_gail.py @@ -13,6 +13,7 @@ from tianshou.env import DummyVectorEnv from tianshou.policy import GAIL, Algorithm from tianshou.policy.modelfree.pg import ActorPolicy +from tianshou.policy.optim import AdamOptimizerFactory from tianshou.trainer import OnPolicyTrainingConfig from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import ActorCritic, Net @@ -112,7 +113,7 @@ def test_gail(args: argparse.Namespace = get_args()) -> None: if isinstance(m, torch.nn.Linear): torch.nn.init.orthogonal_(m.weight) torch.nn.init.zeros_(m.bias) - optim = torch.optim.Adam(actor_critic.parameters(), lr=args.lr) + optim = AdamOptimizerFactory(lr=args.lr) # discriminator disc_net = Critic( Net( @@ -130,7 +131,7 @@ def test_gail(args: argparse.Namespace = get_args()) -> None: # orthogonal initialization torch.nn.init.orthogonal_(m.weight, gain=np.sqrt(2)) torch.nn.init.zeros_(m.bias) - disc_optim = torch.optim.Adam(disc_net.parameters(), lr=args.disc_lr) + disc_optim = AdamOptimizerFactory(lr=args.disc_lr) # replace DiagGuassian with Independent(Normal) which is equivalent # pass *logits to be consistent with policy.forward @@ -189,7 +190,7 @@ def save_checkpoint_fn(epoch: int, env_step: int, gradient_step: int) -> str: torch.save( { "model": algorithm.state_dict(), - "optim": optim.state_dict(), + "optim": algorithm.optim.state_dict(), }, ckpt_path, ) @@ -202,7 +203,7 @@ def save_checkpoint_fn(epoch: int, env_step: int, gradient_step: int) -> str: if os.path.exists(ckpt_path): checkpoint = 
torch.load(ckpt_path, map_location=args.device) algorithm.load_state_dict(checkpoint["model"]) - optim.load_state_dict(checkpoint["optim"]) + algorithm.optim.load_state_dict(checkpoint["optim"]) print("Successfully restore policy and optim.") else: print("Fail to restore policy and optim.") diff --git a/test/offline/test_td3_bc.py b/test/offline/test_td3_bc.py index 481809d2c..ce98a4c92 100644 --- a/test/offline/test_td3_bc.py +++ b/test/offline/test_td3_bc.py @@ -15,6 +15,7 @@ from tianshou.policy import TD3BC from tianshou.policy.base import Algorithm from tianshou.policy.modelfree.ddpg import DDPGPolicy +from tianshou.policy.optim import AdamOptimizerFactory from tianshou.trainer import OfflineTrainingConfig from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Net @@ -86,14 +87,13 @@ def test_td3_bc(args: argparse.Namespace = get_args()) -> None: args.state_dim = space_info.action_info.action_dim args.action_dim = space_info.observation_info.obs_dim - # test_envs = gym.make(args.task) test_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.test_num)]) + # seed np.random.seed(args.seed) torch.manual_seed(args.seed) test_envs.seed(args.seed) - # model # actor network net_a = Net( args.state_shape, @@ -106,9 +106,9 @@ def test_td3_bc(args: argparse.Namespace = get_args()) -> None: max_action=args.max_action, device=args.device, ).to(args.device) - actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr) + actor_optim = AdamOptimizerFactory(lr=args.actor_lr) - # critic network + # critic networks net_c1 = Net( state_shape=args.state_shape, action_shape=args.action_shape, @@ -124,13 +124,15 @@ def test_td3_bc(args: argparse.Namespace = get_args()) -> None: device=args.device, ) critic1 = Critic(net_c1, device=args.device).to(args.device) - critic1_optim = torch.optim.Adam(critic1.parameters(), lr=args.critic_lr) + critic1_optim = AdamOptimizerFactory(lr=args.critic_lr) critic2 = Critic(net_c2, 
device=args.device).to(args.device) - critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr) + critic2_optim = AdamOptimizerFactory(lr=args.critic_lr) + # policy and algorithm policy = DDPGPolicy( actor=actor, action_space=env.action_space, + exploration_noise=GaussianNoise(sigma=args.exploration_noise), ) algorithm: TD3BC = TD3BC( policy=policy, @@ -141,7 +143,6 @@ def test_td3_bc(args: argparse.Namespace = get_args()) -> None: critic2_optim=critic2_optim, tau=args.tau, gamma=args.gamma, - exploration_noise=GaussianNoise(sigma=args.exploration_noise), policy_noise=args.policy_noise, update_actor_freq=args.update_actor_freq, noise_clip=args.noise_clip, @@ -158,7 +159,8 @@ def test_td3_bc(args: argparse.Namespace = get_args()) -> None: # buffer has been gathered # train_collector = Collector[CollectStats](policy, train_envs, buffer, exploration_noise=True) test_collector = Collector[CollectStats](algorithm, test_envs) - # log + + # logger t0 = datetime.datetime.now().strftime("%m%d_%H%M%S") log_file = f'seed_{args.seed}_{t0}-{args.task.replace("-", "_")}_td3_bc' log_path = os.path.join(args.logdir, args.task, "td3_bc", log_file) diff --git a/test/pettingzoo/tic_tac_toe.py b/test/pettingzoo/tic_tac_toe.py index fbc5ca6ab..086af3cdb 100644 --- a/test/pettingzoo/tic_tac_toe.py +++ b/test/pettingzoo/tic_tac_toe.py @@ -20,6 +20,7 @@ MultiAgentOffPolicyAlgorithm, ) from tianshou.policy.modelfree.dqn import DQNPolicy +from tianshou.policy.optim import AdamOptimizerFactory, OptimizerFactory from tianshou.trainer import OffPolicyTrainingConfig from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Net @@ -101,7 +102,7 @@ def get_agents( args: argparse.Namespace = get_args(), agent_learn: Algorithm | None = None, agent_opponent: Algorithm | None = None, - optim: torch.optim.Optimizer | None = None, + optim: OptimizerFactory | None = None, ) -> tuple[MultiAgentOffPolicyAlgorithm, torch.optim.Optimizer | None, list]: env = 
get_env() observation_space = ( @@ -120,7 +121,7 @@ def get_agents( device=args.device, ).to(args.device) if optim is None: - optim = torch.optim.Adam(net.parameters(), lr=args.lr) + optim = AdamOptimizerFactory(lr=args.lr) algorithm = DQNPolicy( model=net, action_space=env.action_space, diff --git a/tianshou/policy/base.py b/tianshou/policy/base.py index b3bbba386..04952c245 100644 --- a/tianshou/policy/base.py +++ b/tianshou/policy/base.py @@ -3,7 +3,7 @@ from abc import ABC, abstractmethod from collections.abc import Callable from dataclasses import dataclass, field -from typing import TYPE_CHECKING, Any, Generic, Literal, TypeAlias, TypeVar, cast +from typing import TYPE_CHECKING, Any, Generic, Literal, TypeVar, cast import gymnasium as gym import numpy as np @@ -13,6 +13,7 @@ from numpy.typing import ArrayLike from overrides import override from torch import nn +from torch.optim.lr_scheduler import LRScheduler from tianshou.data import ReplayBuffer, SequenceSummaryStats, to_numpy, to_torch_as from tianshou.data.batch import Batch, BatchProtocol, TArr @@ -24,7 +25,7 @@ ObsBatchProtocol, RolloutBatchProtocol, ) -from tianshou.utils import MultipleLRSchedulers +from tianshou.policy.optim import OptimizerFactory from tianshou.utils.lagged_network import ( LaggedNetworkCollection, ) @@ -46,7 +47,6 @@ logger = logging.getLogger(__name__) -TLearningRateScheduler: TypeAlias = torch.optim.lr_scheduler.LRScheduler | MultipleLRSchedulers TArrOrActBatch = TypeVar("TArrOrActBatch", bound="np.ndarray | ActBatchProtocol") @@ -479,23 +479,20 @@ def __init__( self, *, policy: TPolicy, - lr_scheduler: TLearningRateScheduler | None = None, ) -> None: - """ - :param policy: the policy - :param lr_scheduler: if not None, will be called in `update()`. 
- """ + """:param policy: the policy""" super().__init__() self.policy: TPolicy = policy - self.lr_scheduler = lr_scheduler + self.lr_schedulers: list[LRScheduler] = [] self.updating = False - # TODO delete this - def __setstate__(self, state: dict[str, Any]) -> None: - # TODO Use setstate function once merged - if "is_within_training_step" not in state: - state["is_within_training_step"] = False - self.__dict__ = state + def _create_optimizer( + self, module: torch.nn.Module, factory: OptimizerFactory + ) -> torch.optim.Optimizer: + optimizer, lr_scheduler = factory.create_instances(module) + if lr_scheduler is not None: + self.lr_schedulers.append(lr_scheduler) + return optimizer def process_fn( self, @@ -578,8 +575,8 @@ def _update( with torch_train_mode(self): training_stat = update_with_batch_fn(batch) self.post_process_fn(batch, buffer, indices) - if self.lr_scheduler is not None: - self.lr_scheduler.step() + for lr_scheduler in self.lr_schedulers: + lr_scheduler.step() self.updating = False training_stat.train_time = time.time() - start_time return training_stat @@ -880,9 +877,8 @@ class OnPolicyWrapperAlgorithm( def __init__( self, wrapped_algorithm: OnPolicyAlgorithm[TPolicy, TWrappedAlgorthmTrainingStats], - lr_scheduler: TLearningRateScheduler | None = None, ): - super().__init__(policy=wrapped_algorithm.policy, lr_scheduler=lr_scheduler) + super().__init__(policy=wrapped_algorithm.policy) self.wrapped_algorithm = wrapped_algorithm def process_fn( @@ -920,9 +916,8 @@ class OffPolicyWrapperAlgorithm( def __init__( self, wrapped_algorithm: OffPolicyAlgorithm[TPolicy, TWrappedAlgorthmTrainingStats], - lr_scheduler: TLearningRateScheduler | None = None, ): - super().__init__(policy=wrapped_algorithm.policy, lr_scheduler=lr_scheduler) + super().__init__(policy=wrapped_algorithm.policy) self.wrapped_algorithm = wrapped_algorithm def process_fn( diff --git a/tianshou/policy/imitation/base.py b/tianshou/policy/imitation/base.py index 51ebd92e1..e385f2698 100644 
--- a/tianshou/policy/imitation/base.py +++ b/tianshou/policy/imitation/base.py @@ -16,9 +16,9 @@ from tianshou.policy.base import ( OffPolicyAlgorithm, Policy, - TLearningRateScheduler, TrainingStats, ) +from tianshou.policy.optim import OptimizerFactory # Dimension Naming Convention # B - Batch Size @@ -94,8 +94,7 @@ def __init__( self, *, policy: ImitationPolicy, - optim: torch.optim.Optimizer, - lr_scheduler: TLearningRateScheduler | None = None, + optim: OptimizerFactory, ) -> None: """ :param policy: the policy @@ -111,9 +110,8 @@ def __init__( """ super().__init__( policy=policy, - lr_scheduler=lr_scheduler, ) - self.optim = optim + self.optim = self._create_optimizer(self.policy, optim) def _update_with_batch( self, diff --git a/tianshou/policy/imitation/bcq.py b/tianshou/policy/imitation/bcq.py index 3f171c0c8..58a841ba8 100644 --- a/tianshou/policy/imitation/bcq.py +++ b/tianshou/policy/imitation/bcq.py @@ -14,11 +14,10 @@ LaggedNetworkPolyakUpdateAlgorithmMixin, OfflineAlgorithm, Policy, - TLearningRateScheduler, TrainingStats, ) +from tianshou.policy.optim import OptimizerFactory from tianshou.utils.net.continuous import VAE -from tianshou.utils.optim import clone_optimizer @dataclass(kw_only=True) @@ -107,50 +106,47 @@ def __init__( self, *, policy: BCQPolicy, - actor_perturbation_optim: torch.optim.Optimizer, - critic_optim: torch.optim.Optimizer, - vae_optim: torch.optim.Optimizer, + actor_perturbation_optim: OptimizerFactory, + critic_optim: OptimizerFactory, + vae_optim: OptimizerFactory, critic2: torch.nn.Module | None = None, - critic2_optim: torch.optim.Optimizer | None = None, + critic2_optim: OptimizerFactory | None = None, gamma: float = 0.99, tau: float = 0.005, lmbda: float = 0.75, num_sampled_action: int = 10, - lr_scheduler: TLearningRateScheduler | None = None, ) -> None: """ :param policy: the policy - :param actor_perturbation_optim: the optimizer for the policy's actor perturbation network. 
- :param critic_optim: the optimizer for the policy's critic network. + :param actor_perturbation_optim: the optimizer factory for the policy's actor perturbation network. + :param critic_optim: the optimizer factory for the policy's critic network. :param critic2: the second critic network; if None, clone the critic from the policy - :param critic2_optim: the optimizer for the second critic network; if None, clone optimizer of first critic + :param critic2_optim: the optimizer factory for the second critic network; if None, use the optimizer factory of the first critic :param vae_optim: the optimizer for the VAE network. :param gamma: discount factor, in [0, 1]. :param tau: param for soft update of the target network. :param lmbda: param for Clipped Double Q-learning. :param num_sampled_action: the number of sampled actions in calculating target Q. The algorithm samples several actions using VAE, and perturbs each action to get the target Q. - :param lr_scheduler: if not None, will be called in `policy.update()`. """ # actor is Perturbation!
super().__init__( policy=policy, - lr_scheduler=lr_scheduler, ) LaggedNetworkPolyakUpdateAlgorithmMixin.__init__(self, tau=tau) self.actor_perturbation_target = self._add_lagged_network(self.policy.actor_perturbation) - self.actor_perturbation_optim = actor_perturbation_optim + self.actor_perturbation_optim = self._create_optimizer( + self.policy.actor_perturbation, actor_perturbation_optim + ) self.critic_target = self._add_lagged_network(self.policy.critic) - self.critic_optim = critic_optim + self.critic_optim = self._create_optimizer(self.policy.critic, critic_optim) - critic2 = critic2 or copy.deepcopy(self.policy.critic) - critic2_optim = critic2_optim or clone_optimizer(critic_optim, critic2.parameters()) - self.critic2 = critic2 + self.critic2 = critic2 or copy.deepcopy(self.policy.critic) self.critic2_target = self._add_lagged_network(self.critic2) - self.critic2_optim = critic2_optim + self.critic2_optim = self._create_optimizer(self.critic2, critic2_optim or critic_optim) - self.vae_optim = vae_optim + self.vae_optim = self._create_optimizer(self.policy.vae, vae_optim) self.gamma = gamma self.lmbda = lmbda diff --git a/tianshou/policy/imitation/cql.py b/tianshou/policy/imitation/cql.py index 12cf8577e..376eaa032 100644 --- a/tianshou/policy/imitation/cql.py +++ b/tianshou/policy/imitation/cql.py @@ -14,11 +14,10 @@ from tianshou.policy.base import ( LaggedNetworkPolyakUpdateAlgorithmMixin, OfflineAlgorithm, - TLearningRateScheduler, ) from tianshou.policy.modelfree.sac import Alpha, SACPolicy, SACTrainingStats +from tianshou.policy.optim import OptimizerFactory from tianshou.utils.conversion import to_optional_float -from tianshou.utils.optim import clone_optimizer from tianshou.utils.torch_utils import torch_device @@ -41,11 +40,11 @@ def __init__( self, *, policy: SACPolicy, - policy_optim: torch.optim.Optimizer, + policy_optim: OptimizerFactory, critic: torch.nn.Module, - critic_optim: torch.optim.Optimizer, + critic_optim: OptimizerFactory, critic2: 
torch.nn.Module | None = None, - critic2_optim: torch.optim.Optimizer | None = None, + critic2_optim: OptimizerFactory | None = None, cql_alpha_lr: float = 1e-4, cql_weight: float = 1.0, tau: float = 0.005, @@ -61,8 +60,6 @@ def __init__( alpha_max: float = 1e6, clip_grad: float = 1.0, calibrated: bool = True, - estimation_step: int = 1, # TODO remove - lr_scheduler: TLearningRateScheduler | None = None, ) -> None: """ :param actor: the actor network following the rules in @@ -94,25 +91,19 @@ def __init__( :param calibrated: calibrate Q-values as in CalQL paper `arXiv:2303.05479`. Useful for offline pre-training followed by online training, and also was observed to achieve better results than vanilla cql. - :param estimation_step: Estimation steps. - :param lr_scheduler: a learning rate scheduler that adjusts the learning rate in - optimizer in each policy.update(). """ super().__init__( policy=policy, - lr_scheduler=lr_scheduler, ) LaggedNetworkPolyakUpdateAlgorithmMixin.__init__(self, tau=tau) device = torch_device(policy) - self.policy_optim = policy_optim + self.policy_optim = self._create_optimizer(self.policy, policy_optim) self.critic = critic - self.critic_optim = critic_optim + self.critic_optim = self._create_optimizer(self.critic, critic_optim) self.critic2 = critic2 or deepcopy(critic) - self.critic2_optim = critic2_optim or clone_optimizer( - critic_optim, self.critic2.parameters() - ) + self.critic2_optim = self._create_optimizer(self.critic2, critic2_optim or critic_optim) self.critic_old = self._add_lagged_network(self.critic) self.critic2_old = self._add_lagged_network(self.critic2) diff --git a/tianshou/policy/imitation/discrete_bcq.py b/tianshou/policy/imitation/discrete_bcq.py index d13eca89b..a5b47e225 100644 --- a/tianshou/policy/imitation/discrete_bcq.py +++ b/tianshou/policy/imitation/discrete_bcq.py @@ -18,9 +18,9 @@ LaggedNetworkFullUpdateAlgorithmMixin, OfflineAlgorithm, Policy, - TLearningRateScheduler, ) from 
tianshou.policy.modelfree.dqn import DQNTrainingStats +from tianshou.policy.optim import OptimizerFactory float_info = torch.finfo(torch.float32) INF = float_info.max @@ -108,7 +108,7 @@ def __init__( self, *, policy: DiscreteBCQPolicy, - optim: torch.optim.Optimizer, + optim: OptimizerFactory, discount_factor: float = 0.99, estimation_step: int = 1, target_update_freq: int = 8000, @@ -117,7 +117,6 @@ def __init__( reward_normalization: bool = False, is_double: bool = True, clip_loss_grad: bool = False, - lr_scheduler: TLearningRateScheduler | None = None, ) -> None: """ :param policy: the policy @@ -137,14 +136,12 @@ def __init__( :param clip_loss_grad: clip the gradient of the loss in accordance with nature14236; this amounts to using the Huber loss instead of the MSE loss. - :param lr_scheduler: if not None, will be called in `policy.update()`. """ super().__init__( policy=policy, - lr_scheduler=lr_scheduler, ) LaggedNetworkFullUpdateAlgorithmMixin.__init__(self) - self.optim = optim + self.optim = self._create_optimizer(self.policy, optim) assert ( 0.0 <= discount_factor <= 1.0 ), f"discount factor should be in [0, 1] but got: {discount_factor}" diff --git a/tianshou/policy/imitation/discrete_cql.py b/tianshou/policy/imitation/discrete_cql.py index c4703dcff..b2d62d7fc 100644 --- a/tianshou/policy/imitation/discrete_cql.py +++ b/tianshou/policy/imitation/discrete_cql.py @@ -8,8 +8,9 @@ from tianshou.data import to_torch from tianshou.data.types import RolloutBatchProtocol from tianshou.policy import QRDQN -from tianshou.policy.base import OfflineAlgorithm, TLearningRateScheduler +from tianshou.policy.base import OfflineAlgorithm from tianshou.policy.modelfree.qrdqn import QRDQNPolicy, QRDQNTrainingStats +from tianshou.policy.optim import OptimizerFactory @dataclass(kw_only=True) @@ -32,14 +33,13 @@ def __init__( self, *, policy: QRDQNPolicy, - optim: torch.optim.Optimizer, + optim: OptimizerFactory, min_q_weight: float = 10.0, discount_factor: float = 0.99, 
num_quantiles: int = 200, estimation_step: int = 1, target_update_freq: int = 0, reward_normalization: bool = False, - lr_scheduler: TLearningRateScheduler | None = None, ) -> None: """ :param policy: the policy @@ -53,7 +53,6 @@ def __init__( you do not use the target network). :param reward_normalization: normalize the **returns** to Normal(0, 1). TODO: rename to return_normalization? - :param lr_scheduler: if not None, will be called in `policy.update()`. """ QRDQN.__init__( self, @@ -64,7 +63,6 @@ def __init__( estimation_step=estimation_step, target_update_freq=target_update_freq, reward_normalization=reward_normalization, - lr_scheduler=lr_scheduler, ) self.min_q_weight = min_q_weight diff --git a/tianshou/policy/imitation/discrete_crr.py b/tianshou/policy/imitation/discrete_crr.py index e8da15491..e8c642ab6 100644 --- a/tianshou/policy/imitation/discrete_crr.py +++ b/tianshou/policy/imitation/discrete_crr.py @@ -5,19 +5,20 @@ import torch import torch.nn.functional as F from torch.distributions import Categorical +from torch.nn import ModuleList from tianshou.data import ReplayBuffer, to_torch, to_torch_as from tianshou.data.types import BatchWithReturnsProtocol, RolloutBatchProtocol from tianshou.policy.base import ( LaggedNetworkFullUpdateAlgorithmMixin, OfflineAlgorithm, - TLearningRateScheduler, ) from tianshou.policy.modelfree.pg import ( DiscountedReturnComputation, DiscreteActorPolicy, PGTrainingStats, ) +from tianshou.policy.optim import OptimizerFactory from tianshou.utils.net.discrete import Critic @@ -42,7 +43,7 @@ def __init__( *, policy: DiscreteActorPolicy, critic: torch.nn.Module | Critic, - optim: torch.optim.Optimizer, + optim: OptimizerFactory, discount_factor: float = 0.99, policy_improvement_mode: Literal["exp", "binary", "all"] = "exp", ratio_upper_bound: float = 20.0, @@ -50,7 +51,6 @@ def __init__( min_q_weight: float = 10.0, target_update_freq: int = 0, reward_normalization: bool = False, - lr_scheduler: TLearningRateScheduler | None 
= None, ) -> None: r""" :param policy: the policy @@ -70,19 +70,17 @@ def __init__( :param reward_normalization: if True, will normalize the *returns* by subtracting the running mean and dividing by the running standard deviation. Can be detrimental to performance! - :param lr_scheduler: if not None, will be called in `policy.update()`. """ super().__init__( policy=policy, - lr_scheduler=lr_scheduler, ) LaggedNetworkFullUpdateAlgorithmMixin.__init__(self) - self.optim = optim self.discounted_return_computation = DiscountedReturnComputation( discount_factor=discount_factor, reward_normalization=reward_normalization, ) self.critic = critic + self.optim = self._create_optimizer(ModuleList([self.policy, self.critic]), optim) self._target = target_update_freq > 0 self._freq = target_update_freq self._iter = 0 diff --git a/tianshou/policy/imitation/gail.py b/tianshou/policy/imitation/gail.py index 4b3d1829d..2d5a4b445 100644 --- a/tianshou/policy/imitation/gail.py +++ b/tianshou/policy/imitation/gail.py @@ -13,9 +13,9 @@ ) from tianshou.data.types import LogpOldProtocol, RolloutBatchProtocol from tianshou.policy import PPO -from tianshou.policy.base import TLearningRateScheduler from tianshou.policy.modelfree.pg import ActorPolicy from tianshou.policy.modelfree.ppo import PPOTrainingStats +from tianshou.policy.optim import OptimizerFactory from tianshou.utils.net.continuous import Critic from tianshou.utils.net.discrete import Critic as DiscreteCritic @@ -38,10 +38,10 @@ def __init__( *, policy: ActorPolicy, critic: torch.nn.Module | Critic | DiscreteCritic, - optim: torch.optim.Optimizer, + optim: OptimizerFactory, expert_buffer: ReplayBuffer, disc_net: torch.nn.Module, - disc_optim: torch.optim.Optimizer, + disc_optim: OptimizerFactory, disc_update_num: int = 4, eps_clip: float = 0.2, dual_clip: float | None = None, @@ -56,16 +56,15 @@ def __init__( discount_factor: float = 0.99, # TODO: rename to return_normalization? 
reward_normalization: bool = False, - lr_scheduler: TLearningRateScheduler | None = None, ) -> None: r""" :param policy: the policy. :param critic: the critic network. (s -> V(s)) - :param optim: the optimizer for actor and critic networks. + :param optim: the optimizer factory for the actor and critic networks. :param expert_buffer: the replay buffer containing expert experience. :param disc_net: the discriminator network with input dim equals state dim plus action dim and output dim equals 1. - :param disc_optim: the optimizer for the discriminator network. + :param disc_optim: the optimizer factory for the discriminator network. :param disc_update_num: the number of discriminator grad steps per model grad step. :param eps_clip: :math:`\epsilon` in :math:`L_{CLIP}` in the original paper. @@ -84,7 +83,6 @@ def __init__( :param max_batchsize: the maximum size of the batch when computing GAE. :param discount_factor: in [0, 1]. :param reward_normalization: normalize estimated values to have std close to 1. - :param lr_scheduler: if not None, will be called in `policy.update()`. """ super().__init__( policy=policy, @@ -102,10 +100,9 @@ def __init__( max_batchsize=max_batchsize, discount_factor=discount_factor, reward_normalization=reward_normalization, - lr_scheduler=lr_scheduler, ) self.disc_net = disc_net - self.disc_optim = disc_optim + self.disc_optim = self._create_optimizer(self.disc_net, disc_optim) self.disc_update_num = disc_update_num self.expert_buffer = expert_buffer # TODO: This violates the type requirement; nn.Module does not necessarily have output_dim! 
diff --git a/tianshou/policy/imitation/td3_bc.py b/tianshou/policy/imitation/td3_bc.py index 5877df467..dfddfa429 100644 --- a/tianshou/policy/imitation/td3_bc.py +++ b/tianshou/policy/imitation/td3_bc.py @@ -6,11 +6,11 @@ from tianshou.data import to_torch_as from tianshou.data.types import RolloutBatchProtocol -from tianshou.exploration import BaseNoise, GaussianNoise from tianshou.policy import TD3 -from tianshou.policy.base import OfflineAlgorithm, TLearningRateScheduler +from tianshou.policy.base import OfflineAlgorithm from tianshou.policy.modelfree.ddpg import DDPGPolicy from tianshou.policy.modelfree.td3 import TD3TrainingStats +from tianshou.policy.optim import OptimizerFactory @dataclass(kw_only=True) @@ -29,21 +29,19 @@ def __init__( self, *, policy: DDPGPolicy, - policy_optim: torch.optim.Optimizer, + policy_optim: OptimizerFactory, critic: torch.nn.Module, - critic_optim: torch.optim.Optimizer, + critic_optim: OptimizerFactory, critic2: torch.nn.Module | None = None, - critic2_optim: torch.optim.Optimizer | None = None, + critic2_optim: OptimizerFactory | None = None, tau: float = 0.005, gamma: float = 0.99, - exploration_noise: BaseNoise | None = GaussianNoise(sigma=0.1), policy_noise: float = 0.2, update_actor_freq: int = 2, noise_clip: float = 0.5, # TODO: same name as alpha in SAC and REDQ, which also inherit from DDPGPolicy. Rename? alpha: float = 2.5, estimation_step: int = 1, - lr_scheduler: TLearningRateScheduler | None = None, ) -> None: """ :param policy: the policy @@ -64,8 +62,6 @@ def __init__( :param noise_clip: the clipping range used in updating policy network. :param alpha: the value of alpha, which controls the weight for TD3 learning relative to behavior cloning. 
- :param lr_scheduler: a learning rate scheduler that adjusts the learning rate - in optimizer in each policy.update() """ TD3.__init__( self, @@ -81,7 +77,6 @@ def __init__( noise_clip=noise_clip, update_actor_freq=update_actor_freq, estimation_step=estimation_step, - lr_scheduler=lr_scheduler, ) self.alpha = alpha diff --git a/tianshou/policy/modelbased/icm.py b/tianshou/policy/modelbased/icm.py index a192d972b..2b87414b9 100644 --- a/tianshou/policy/modelbased/icm.py +++ b/tianshou/policy/modelbased/icm.py @@ -10,12 +10,12 @@ OffPolicyWrapperAlgorithm, OnPolicyAlgorithm, OnPolicyWrapperAlgorithm, - TLearningRateScheduler, TPolicy, TrainingStats, TrainingStatsWrapper, TTrainingStats, ) +from tianshou.policy.optim import OptimizerFactory from tianshou.utils.net.discrete import IntrinsicCuriosityModule @@ -105,29 +105,26 @@ def __init__( *, wrapped_algorithm: OffPolicyAlgorithm[TPolicy, TTrainingStats], model: IntrinsicCuriosityModule, - optim: torch.optim.Optimizer, + optim: OptimizerFactory, lr_scale: float, reward_scale: float, forward_loss_weight: float, - lr_scheduler: TLearningRateScheduler | None = None, ) -> None: """ :param wrapped_algorithm: the base algorithm to which we want to add the ICM. :param model: the ICM model. - :param optim: the optimizer for parameter `model`. + :param optim: the optimizer factory for the ICM model. :param lr_scale: the scaling factor for ICM learning. :param forward_loss_weight: the weight for forward model loss. - :param lr_scheduler: if not None, will be called in `policy.update()`. 
""" OffPolicyWrapperAlgorithm.__init__( self, wrapped_algorithm=wrapped_algorithm, - lr_scheduler=lr_scheduler, ) _ICMMixin.__init__( self, model=model, - optim=optim, + optim=self._create_optimizer(model, optim), lr_scale=lr_scale, reward_scale=reward_scale, forward_loss_weight=forward_loss_weight, @@ -169,29 +166,26 @@ def __init__( *, wrapped_algorithm: OnPolicyAlgorithm[TPolicy, TTrainingStats], model: IntrinsicCuriosityModule, - optim: torch.optim.Optimizer, + optim: OptimizerFactory, lr_scale: float, reward_scale: float, forward_loss_weight: float, - lr_scheduler: TLearningRateScheduler | None = None, ) -> None: """ :param wrapped_algorithm: the base algorithm to which we want to add the ICM. :param model: the ICM model. - :param optim: the optimizer for parameter `model`. + :param optim: the optimizer factory for the ICM model. :param lr_scale: the scaling factor for ICM learning. :param forward_loss_weight: the weight for forward model loss. - :param lr_scheduler: if not None, will be called in `policy.update()`. 
""" OnPolicyWrapperAlgorithm.__init__( self, wrapped_algorithm=wrapped_algorithm, - lr_scheduler=lr_scheduler, ) _ICMMixin.__init__( self, model=model, - optim=optim, + optim=self._create_optimizer(model, optim), lr_scale=lr_scale, reward_scale=reward_scale, forward_loss_weight=forward_loss_weight, diff --git a/tianshou/policy/modelbased/psrl.py b/tianshou/policy/modelbased/psrl.py index f1eae2102..5d11de79a 100644 --- a/tianshou/policy/modelbased/psrl.py +++ b/tianshou/policy/modelbased/psrl.py @@ -11,7 +11,6 @@ from tianshou.policy.base import ( OnPolicyAlgorithm, Policy, - TLearningRateScheduler, TrainingStats, ) @@ -222,7 +221,6 @@ def __init__( *, policy: PSRLPolicy, add_done_loop: bool = False, - lr_scheduler: TLearningRateScheduler | None = None, ) -> None: """ :param policy: the policy @@ -232,7 +230,6 @@ def __init__( """ super().__init__( policy=policy, - lr_scheduler=lr_scheduler, ) self._add_done_loop = add_done_loop diff --git a/tianshou/policy/modelfree/a2c.py b/tianshou/policy/modelfree/a2c.py index 6f46d862b..60e6d7487 100644 --- a/tianshou/policy/modelfree/a2c.py +++ b/tianshou/policy/modelfree/a2c.py @@ -11,10 +11,10 @@ from tianshou.data.types import BatchWithAdvantagesProtocol, RolloutBatchProtocol from tianshou.policy.base import ( OnPolicyAlgorithm, - TLearningRateScheduler, TrainingStats, ) from tianshou.policy.modelfree.pg import ActorPolicy, TPGTrainingStats +from tianshou.policy.optim import OptimizerFactory from tianshou.utils import RunningMeanStd from tianshou.utils.net.common import ActorCritic from tianshou.utils.net.continuous import Critic @@ -42,32 +42,35 @@ def __init__( *, policy: ActorPolicy, critic: torch.nn.Module | Critic | DiscreteCritic, - optim: torch.optim.Optimizer, + optim: OptimizerFactory, + optim_include_actor: bool, gae_lambda: float = 0.95, max_batchsize: int = 256, discount_factor: float = 0.99, reward_normalization: bool = False, - lr_scheduler: TLearningRateScheduler | None = None, ) -> None: """ :param critic: 
the critic network. (s -> V(s)) - :param optim: the optimizer for actor and critic network. + :param optim: the optimizer factory. + :param optim_include_actor: whether the optimizer shall include the actor network's parameters. + Pass False for algorithms that shall update only the critic via the optimizer. :param gae_lambda: in [0, 1], param for generalized advantage estimation (GAE). :param max_batchsize: the maximum size of the batch when computing GAE. :param discount_factor: in [0, 1]. :param reward_normalization: normalize estimated values to have std close to 1. - :param lr_scheduler: if not None, will be called in `policy.update()`. """ super().__init__( policy=policy, - lr_scheduler=lr_scheduler, ) self.critic = critic assert 0.0 <= gae_lambda <= 1.0, f"GAE lambda should be in [0, 1] but got: {gae_lambda}" self.gae_lambda = gae_lambda self.max_batchsize = max_batchsize self._actor_critic = ActorCritic(self.policy.actor, self.critic) - self.optim = optim + if optim_include_actor: + self.optim = self._create_optimizer(self._actor_critic, optim) + else: + self.optim = self._create_optimizer(self.critic, optim) self.gamma = discount_factor self.rew_norm = reward_normalization self.ret_rms = RunningMeanStd() @@ -122,7 +125,7 @@ def __init__( *, policy: ActorPolicy, critic: torch.nn.Module | Critic | DiscreteCritic, - optim: torch.optim.Optimizer, + optim: OptimizerFactory, vf_coef: float = 0.5, ent_coef: float = 0.01, max_grad_norm: float | None = None, @@ -131,11 +134,11 @@ def __init__( discount_factor: float = 0.99, # TODO: This algorithm does not seem to use the reward_normalization parameter. reward_normalization: bool = False, - lr_scheduler: TLearningRateScheduler | None = None, ) -> None: """ + :param policy: the policy containing the actor network. :param critic: the critic network. (s -> V(s)) - :param optim: the optimizer for actor and critic network. + :param optim: the optimizer factory for the actor and critic networks. 
:param vf_coef: weight for value loss. :param ent_coef: weight for entropy loss. :param max_grad_norm: clipping gradients in back propagation. @@ -143,17 +146,16 @@ def __init__( :param max_batchsize: the maximum size of the batch when computing GAE. :param discount_factor: in [0, 1]. :param reward_normalization: normalize estimated values to have std close to 1. - :param lr_scheduler: if not None, will be called in `policy.update()`. """ super().__init__( policy=policy, critic=critic, optim=optim, + optim_include_actor=True, gae_lambda=gae_lambda, max_batchsize=max_batchsize, discount_factor=discount_factor, reward_normalization=reward_normalization, - lr_scheduler=lr_scheduler, ) self.vf_coef = vf_coef self.ent_coef = ent_coef diff --git a/tianshou/policy/modelfree/bdqn.py b/tianshou/policy/modelfree/bdqn.py index f8818750d..9e8fcbfc3 100644 --- a/tianshou/policy/modelfree/bdqn.py +++ b/tianshou/policy/modelfree/bdqn.py @@ -15,12 +15,13 @@ ObsBatchProtocol, RolloutBatchProtocol, ) -from tianshou.policy.base import TArrOrActBatch, TLearningRateScheduler +from tianshou.policy.base import TArrOrActBatch from tianshou.policy.modelfree.dqn import ( DQNPolicy, DQNTrainingStats, QLearningOffPolicyAlgorithm, ) +from tianshou.policy.optim import OptimizerFactory from tianshou.utils.net.common import BranchingNet mark_used(ActBatchProtocol) @@ -97,13 +98,12 @@ def __init__( self, *, policy: BDQNPolicy, - optim: torch.optim.Optimizer, + optim: OptimizerFactory, discount_factor: float = 0.99, estimation_step: int = 1, target_update_freq: int = 0, reward_normalization: bool = False, is_double: bool = True, - lr_scheduler: TLearningRateScheduler | None = None, ) -> None: """ :param policy: policy @@ -115,7 +115,6 @@ def __init__( :param reward_normalization: normalize the **returns** to Normal(0, 1). TODO: rename to return_normalization? :param is_double: whether to use double DQN. - :param lr_scheduler: if not None, will be called in `policy.update()`. 
""" assert ( estimation_step == 1 @@ -127,7 +126,6 @@ def __init__( estimation_step=estimation_step, target_update_freq=target_update_freq, reward_normalization=reward_normalization, - lr_scheduler=lr_scheduler, ) self.is_double = is_double diff --git a/tianshou/policy/modelfree/c51.py b/tianshou/policy/modelfree/c51.py index 9999140d5..59f5b8301 100644 --- a/tianshou/policy/modelfree/c51.py +++ b/tianshou/policy/modelfree/c51.py @@ -7,12 +7,12 @@ from tianshou.data import Batch, ReplayBuffer from tianshou.data.types import RolloutBatchProtocol -from tianshou.policy.base import TLearningRateScheduler from tianshou.policy.modelfree.dqn import ( DQNPolicy, DQNTrainingStats, QLearningOffPolicyAlgorithm, ) +from tianshou.policy.optim import OptimizerFactory from tianshou.utils.net.common import Net @@ -68,12 +68,11 @@ def __init__( self, *, policy: C51Policy, - optim: torch.optim.Optimizer, + optim: OptimizerFactory, discount_factor: float = 0.99, estimation_step: int = 1, target_update_freq: int = 0, reward_normalization: bool = False, - lr_scheduler: TLearningRateScheduler | None = None, ) -> None: """ :param policy: a policy following the rules (s -> action_values_BA) @@ -84,11 +83,6 @@ def __init__( you do not use the target network). :param reward_normalization: normalize the **returns** to Normal(0, 1). TODO: rename to return_normalization? - :param is_double: use double dqn. - :param clip_loss_grad: clip the gradient of the loss in accordance - with nature14236; this amounts to using the Huber loss instead of - the MSE loss. - :param lr_scheduler: if not None, will be called in `policy.update()`. 
""" super().__init__( policy=policy, @@ -97,7 +91,6 @@ def __init__( estimation_step=estimation_step, target_update_freq=target_update_freq, reward_normalization=reward_normalization, - lr_scheduler=lr_scheduler, ) self.delta_z = (policy.v_max - policy.v_min) / (policy.num_atoms - 1) diff --git a/tianshou/policy/modelfree/ddpg.py b/tianshou/policy/modelfree/ddpg.py index cea89ce86..6b5344f7a 100644 --- a/tianshou/policy/modelfree/ddpg.py +++ b/tianshou/policy/modelfree/ddpg.py @@ -23,11 +23,11 @@ OffPolicyAlgorithm, Policy, TArrOrActBatch, - TLearningRateScheduler, TPolicy, TrainingStats, TTrainingStats, ) +from tianshou.policy.optim import OptimizerFactory from tianshou.utils.net.continuous import Actor, Critic mark_used(ActBatchProtocol) @@ -183,13 +183,12 @@ def __init__( self, *, policy: Any, - policy_optim: torch.optim.Optimizer, + policy_optim: OptimizerFactory, critic: torch.nn.Module, - critic_optim: torch.optim.Optimizer, + critic_optim: OptimizerFactory, tau: float = 0.005, gamma: float = 0.99, estimation_step: int = 1, - lr_scheduler: TLearningRateScheduler | None = None, ) -> None: """ :param policy: the policy @@ -209,13 +208,12 @@ def __init__( assert 0.0 <= gamma <= 1.0, f"gamma should be in [0, 1] but got: {gamma}" super().__init__( policy=policy, - lr_scheduler=lr_scheduler, ) LaggedNetworkPolyakUpdateAlgorithmMixin.__init__(self, tau=tau) - self.policy_optim = policy_optim + self.policy_optim = self._create_optimizer(policy, policy_optim) self.critic = critic self.critic_old = self._add_lagged_network(self.critic) - self.critic_optim = critic_optim + self.critic_optim = self._create_optimizer(self.critic, critic_optim) self.gamma = gamma self.estimation_step = estimation_step @@ -307,13 +305,12 @@ def __init__( self, *, policy: DDPGPolicy, - policy_optim: torch.optim.Optimizer, + policy_optim: OptimizerFactory, critic: torch.nn.Module | Critic, - critic_optim: torch.optim.Optimizer, + critic_optim: OptimizerFactory, tau: float = 0.005, gamma: 
float = 0.99, estimation_step: int = 1, - lr_scheduler: TLearningRateScheduler | None = None, ) -> None: """ :param policy: the policy @@ -328,7 +325,6 @@ def __init__( super().__init__( policy=policy, policy_optim=policy_optim, - lr_scheduler=lr_scheduler, critic=critic, critic_optim=critic_optim, tau=tau, diff --git a/tianshou/policy/modelfree/discrete_sac.py b/tianshou/policy/modelfree/discrete_sac.py index fbbd3cbe5..6ec6beb26 100644 --- a/tianshou/policy/modelfree/discrete_sac.py +++ b/tianshou/policy/modelfree/discrete_sac.py @@ -12,9 +12,10 @@ ObsBatchProtocol, RolloutBatchProtocol, ) -from tianshou.policy.base import Policy, TLearningRateScheduler +from tianshou.policy.base import Policy from tianshou.policy.modelfree.sac import Alpha, SACTrainingStats from tianshou.policy.modelfree.td3 import ActorDualCriticsOffPolicyAlgorithm +from tianshou.policy.optim import OptimizerFactory from tianshou.utils.net.discrete import Critic @@ -81,16 +82,15 @@ def __init__( self, *, policy: DiscreteSACPolicy, - policy_optim: torch.optim.Optimizer, + policy_optim: OptimizerFactory, critic: torch.nn.Module | Critic, - critic_optim: torch.optim.Optimizer, + critic_optim: OptimizerFactory, critic2: torch.nn.Module | Critic | None = None, - critic2_optim: torch.optim.Optimizer | None = None, + critic2_optim: OptimizerFactory | None = None, tau: float = 0.005, gamma: float = 0.99, alpha: float | Alpha = 0.2, estimation_step: int = 1, - lr_scheduler: TLearningRateScheduler | None = None, ) -> None: """ :param policy: the policy @@ -119,7 +119,6 @@ def __init__( tau=tau, gamma=gamma, estimation_step=estimation_step, - lr_scheduler=lr_scheduler, ) self.alpha = Alpha.from_float_or_instance(alpha) diff --git a/tianshou/policy/modelfree/dqn.py b/tianshou/policy/modelfree/dqn.py index 52efe6da9..ec60109c9 100644 --- a/tianshou/policy/modelfree/dqn.py +++ b/tianshou/policy/modelfree/dqn.py @@ -21,10 +21,10 @@ OffPolicyAlgorithm, Policy, TArrOrActBatch, - TLearningRateScheduler, 
TrainingStats, TTrainingStats, ) +from tianshou.policy.optim import OptimizerFactory from tianshou.utils.net.common import Net mark_used(ActBatchProtocol) @@ -52,7 +52,6 @@ def __init__( :param action_space: the environment's action space :param observation_space: the environment's observation space. """ - assert isinstance(action_space, gym.spaces.Discrete) super().__init__( action_space=action_space, observation_space=observation_space, @@ -159,12 +158,11 @@ def __init__( self, *, policy: TDQNPolicy, - optim: torch.optim.Optimizer, + optim: OptimizerFactory, discount_factor: float = 0.99, estimation_step: int = 1, target_update_freq: int = 0, reward_normalization: bool = False, - lr_scheduler: TLearningRateScheduler | None = None, ) -> None: """ :param policy: the policy @@ -175,13 +173,11 @@ def __init__( 0 if a target network shall not be used. :param reward_normalization: normalize the **returns** to Normal(0, 1). TODO: rename to return_normalization? - :param lr_scheduler: if not None, will be called in `policy.update()`. 
""" super().__init__( policy=policy, - lr_scheduler=lr_scheduler, ) - self.optim = optim + self.optim = self._create_policy_optimizer(optim) LaggedNetworkFullUpdateAlgorithmMixin.__init__(self) assert ( 0.0 <= discount_factor <= 1.0 @@ -199,6 +195,9 @@ def __init__( self._add_lagged_network(self.policy.model) if self.use_target_network else None ) + def _create_policy_optimizer(self, optim: OptimizerFactory) -> torch.optim.Optimizer: + return self._create_optimizer(self.policy, optim) + @property def use_target_network(self) -> bool: return self.target_update_freq > 0 @@ -255,14 +254,13 @@ def __init__( self, *, policy: TDQNPolicy, - optim: torch.optim.Optimizer, + optim: OptimizerFactory, discount_factor: float = 0.99, estimation_step: int = 1, target_update_freq: int = 0, reward_normalization: bool = False, is_double: bool = True, clip_loss_grad: bool = False, - lr_scheduler: TLearningRateScheduler | None = None, ) -> None: """ :param policy: the policy @@ -277,12 +275,10 @@ def __init__( :param clip_loss_grad: clip the gradient of the loss in accordance with nature14236; this amounts to using the Huber loss instead of the MSE loss. - :param lr_scheduler: if not None, will be called in `policy.update()`. 
""" super().__init__( policy=policy, optim=optim, - lr_scheduler=lr_scheduler, discount_factor=discount_factor, estimation_step=estimation_step, target_update_freq=target_update_freq, diff --git a/tianshou/policy/modelfree/fqf.py b/tianshou/policy/modelfree/fqf.py index b886c74b9..9c42e25ec 100644 --- a/tianshou/policy/modelfree/fqf.py +++ b/tianshou/policy/modelfree/fqf.py @@ -5,13 +5,14 @@ import numpy as np import torch import torch.nn.functional as F +from overrides import override from tianshou.data import Batch, ReplayBuffer, to_numpy from tianshou.data.types import FQFBatchProtocol, ObsBatchProtocol, RolloutBatchProtocol from tianshou.policy import QRDQN -from tianshou.policy.base import TLearningRateScheduler from tianshou.policy.modelfree.dqn import DQNPolicy from tianshou.policy.modelfree.qrdqn import QRDQNPolicy, QRDQNTrainingStats +from tianshou.policy.optim import OptimizerFactory from tianshou.utils.net.discrete import FractionProposalNetwork, FullQuantileFunction @@ -100,8 +101,8 @@ def __init__( self, *, policy: FQFPolicy, - optim: torch.optim.Optimizer, - fraction_optim: torch.optim.Optimizer, + optim: OptimizerFactory, + fraction_optim: OptimizerFactory, discount_factor: float = 0.99, # TODO: used as num_quantiles in QRDQNPolicy, but num_fractions in FQFPolicy. # Rename? Or at least explain what happens here. @@ -110,7 +111,6 @@ def __init__( estimation_step: int = 1, target_update_freq: int = 0, reward_normalization: bool = False, - lr_scheduler: TLearningRateScheduler | None = None, ) -> None: """ :param policy: the policy @@ -125,8 +125,6 @@ def __init__( you do not use the target network). :param reward_normalization: normalize the **returns** to Normal(0, 1). TODO: rename to return_normalization? - :param observation_space: Env's observation space. - :param lr_scheduler: if not None, will be called in `policy.update()`. 
""" super().__init__( policy=policy, @@ -136,10 +134,15 @@ def __init__( estimation_step=estimation_step, target_update_freq=target_update_freq, reward_normalization=reward_normalization, - lr_scheduler=lr_scheduler, ) self.ent_coef = ent_coef - self.fraction_optim = fraction_optim + self.fraction_optim = self._create_optimizer(self.policy.fraction_model, fraction_optim) + + @override + def _create_policy_optimizer(self, optim: OptimizerFactory) -> torch.optim.Optimizer: + # Override to leave out the fraction model (use main model only), as we want + # to use a separate optimizer for the fraction model + return self._create_optimizer(self.policy.model, optim) def _target_q(self, buffer: ReplayBuffer, indices: np.ndarray) -> torch.Tensor: obs_next_batch = Batch( diff --git a/tianshou/policy/modelfree/iqn.py b/tianshou/policy/modelfree/iqn.py index 160a5980b..3f7e4f7f1 100644 --- a/tianshou/policy/modelfree/iqn.py +++ b/tianshou/policy/modelfree/iqn.py @@ -14,8 +14,8 @@ RolloutBatchProtocol, ) from tianshou.policy import QRDQN -from tianshou.policy.base import TLearningRateScheduler from tianshou.policy.modelfree.qrdqn import QRDQNPolicy, QRDQNTrainingStats +from tianshou.policy.optim import OptimizerFactory @dataclass(kw_only=True) @@ -95,13 +95,12 @@ def __init__( self, *, policy: IQNPolicy, - optim: torch.optim.Optimizer, + optim: OptimizerFactory, discount_factor: float = 0.99, num_quantiles: int = 200, estimation_step: int = 1, target_update_freq: int = 0, reward_normalization: bool = False, - lr_scheduler: TLearningRateScheduler | None = None, ) -> None: """ :param policy: the policy @@ -114,7 +113,6 @@ def __init__( you do not use the target network). :param reward_normalization: normalize the **returns** to Normal(0, 1). TODO: rename to return_normalization? - :param lr_scheduler: if not None, will be called in `policy.update()`. 
""" super().__init__( policy=policy, @@ -124,7 +122,6 @@ def __init__( estimation_step=estimation_step, target_update_freq=target_update_freq, reward_normalization=reward_normalization, - lr_scheduler=lr_scheduler, ) def _update_with_batch( diff --git a/tianshou/policy/modelfree/npg.py b/tianshou/policy/modelfree/npg.py index f2aa4d1a2..506fc08c3 100644 --- a/tianshou/policy/modelfree/npg.py +++ b/tianshou/policy/modelfree/npg.py @@ -9,9 +9,10 @@ from tianshou.data import Batch, ReplayBuffer, SequenceSummaryStats, to_torch_as from tianshou.data.types import BatchWithAdvantagesProtocol, RolloutBatchProtocol -from tianshou.policy.base import TLearningRateScheduler, TrainingStats +from tianshou.policy.base import TrainingStats from tianshou.policy.modelfree.a2c import ActorCriticOnPolicyAlgorithm from tianshou.policy.modelfree.pg import ActorPolicy +from tianshou.policy.optim import OptimizerFactory from tianshou.utils.net.continuous import Critic from tianshou.utils.net.discrete import Critic as DiscreteCritic @@ -37,7 +38,7 @@ def __init__( *, policy: ActorPolicy, critic: torch.nn.Module | Critic | DiscreteCritic, - optim: torch.optim.Optimizer, + optim: OptimizerFactory, optim_critic_iters: int = 5, actor_step_size: float = 0.5, advantage_normalization: bool = True, @@ -46,12 +47,11 @@ def __init__( discount_factor: float = 0.99, # TODO: rename to return_normalization? reward_normalization: bool = False, - lr_scheduler: TLearningRateScheduler | None = None, ) -> None: """ - :param policy: the policy + :param policy: the policy containing the actor network. :param critic: the critic network. (s -> V(s)) - :param optim: the optimizer for actor and critic network. + :param optim: the optimizer factory for the critic network. :param optim_critic_iters: Number of times to optimize critic network per update. :param actor_step_size: step size for actor update in natural gradient direction. 
:param advantage_normalization: whether to do per mini-batch advantage @@ -60,17 +60,16 @@ def __init__( :param max_batchsize: the maximum size of the batch when computing GAE. :param discount_factor: in [0, 1]. :param reward_normalization: normalize estimated values to have std close to 1. - :param lr_scheduler: if not None, will be called in `policy.update()`. """ super().__init__( policy=policy, critic=critic, optim=optim, + optim_include_actor=False, gae_lambda=gae_lambda, max_batchsize=max_batchsize, discount_factor=discount_factor, reward_normalization=reward_normalization, - lr_scheduler=lr_scheduler, ) self.norm_adv = advantage_normalization self.optim_critic_iters = optim_critic_iters diff --git a/tianshou/policy/modelfree/pg.py b/tianshou/policy/modelfree/pg.py index 81f7796f5..e7381c8a8 100644 --- a/tianshou/policy/modelfree/pg.py +++ b/tianshou/policy/modelfree/pg.py @@ -25,9 +25,9 @@ from tianshou.policy.base import ( OnPolicyAlgorithm, Policy, - TLearningRateScheduler, TrainingStats, ) +from tianshou.policy.optim import OptimizerFactory from tianshou.utils import RunningMeanStd from tianshou.utils.net.continuous import ActorProb from tianshou.utils.net.discrete import Actor, dist_fn_categorical_from_logits @@ -255,8 +255,7 @@ def __init__( policy: TActorPolicy, discount_factor: float = 0.99, reward_normalization: bool = False, - optim: torch.optim.Optimizer, - lr_scheduler: TLearningRateScheduler | None = None, + optim: OptimizerFactory, ) -> None: """ :param policy: the policy @@ -265,17 +264,15 @@ def __init__( :param reward_normalization: if True, will normalize the *returns* by subtracting the running mean and dividing by the running standard deviation. Can be detrimental to performance! - :param lr_scheduler: if not None, will be called in `policy.update()`. 
""" super().__init__( policy=policy, - lr_scheduler=lr_scheduler, ) self.discounted_return_computation = DiscountedReturnComputation( discount_factor=discount_factor, reward_normalization=reward_normalization, ) - self.optim = optim + self.optim = self._create_optimizer(self.policy, optim) def process_fn( self, diff --git a/tianshou/policy/modelfree/ppo.py b/tianshou/policy/modelfree/ppo.py index 3d7ec4f4a..044f459f7 100644 --- a/tianshou/policy/modelfree/ppo.py +++ b/tianshou/policy/modelfree/ppo.py @@ -9,8 +9,9 @@ from tianshou.data import ReplayBuffer, SequenceSummaryStats, to_torch_as from tianshou.data.types import LogpOldProtocol, RolloutBatchProtocol from tianshou.policy import A2C -from tianshou.policy.base import TLearningRateScheduler, TrainingStats +from tianshou.policy.base import TrainingStats from tianshou.policy.modelfree.pg import ActorPolicy +from tianshou.policy.optim import OptimizerFactory from tianshou.utils.net.common import ActorCritic from tianshou.utils.net.continuous import Critic from tianshou.utils.net.discrete import Critic as DiscreteCritic @@ -61,7 +62,7 @@ def __init__( *, policy: ActorPolicy, critic: torch.nn.Module | Critic | DiscreteCritic, - optim: torch.optim.Optimizer, + optim: OptimizerFactory, eps_clip: float = 0.2, dual_clip: float | None = None, value_clip: bool = False, @@ -75,12 +76,11 @@ def __init__( discount_factor: float = 0.99, # TODO: rename to return_normalization? reward_normalization: bool = False, - lr_scheduler: TLearningRateScheduler | None = None, ) -> None: r""" - :param policy: the policy + :param policy: the policy containing the actor network. :param critic: the critic network. (s -> V(s)) - :param optim: the optimizer for actor and critic network. + :param optim: the optimizer factory for the actor and critic networks. :param eps_clip: :math:`\epsilon` in :math:`L_{CLIP}` in the original paper. :param dual_clip: a parameter c mentioned in arXiv:1912.09729 Equ. 
5, @@ -98,7 +98,6 @@ def __init__( :param max_batchsize: the maximum size of the batch when computing GAE. :param discount_factor: in [0, 1]. :param reward_normalization: normalize estimated values to have std close to 1. - :param lr_scheduler: if not None, will be called in `policy.update()`. """ assert ( dual_clip is None or dual_clip > 1.0 @@ -115,7 +114,6 @@ def __init__( max_batchsize=max_batchsize, discount_factor=discount_factor, reward_normalization=reward_normalization, - lr_scheduler=lr_scheduler, ) self.eps_clip = eps_clip self.dual_clip = dual_clip diff --git a/tianshou/policy/modelfree/qrdqn.py b/tianshou/policy/modelfree/qrdqn.py index 28bad5ac6..47179d033 100644 --- a/tianshou/policy/modelfree/qrdqn.py +++ b/tianshou/policy/modelfree/qrdqn.py @@ -8,12 +8,12 @@ from tianshou.data import Batch, ReplayBuffer from tianshou.data.types import RolloutBatchProtocol -from tianshou.policy.base import TLearningRateScheduler from tianshou.policy.modelfree.dqn import ( DQNPolicy, DQNTrainingStats, QLearningOffPolicyAlgorithm, ) +from tianshou.policy.optim import OptimizerFactory @dataclass(kw_only=True) @@ -42,13 +42,12 @@ def __init__( self, *, policy: TQRDQNPolicy, - optim: torch.optim.Optimizer, + optim: OptimizerFactory, discount_factor: float = 0.99, num_quantiles: int = 200, estimation_step: int = 1, target_update_freq: int = 0, reward_normalization: bool = False, - lr_scheduler: TLearningRateScheduler | None = None, ) -> None: """ :param policy: the policy @@ -61,7 +60,6 @@ def __init__( you do not use the target network). :param reward_normalization: normalize the **returns** to Normal(0, 1). TODO: rename to return_normalization? - :param lr_scheduler: if not None, will be called in `policy.update()`. 
""" assert num_quantiles > 1, f"num_quantiles should be greater than 1 but got: {num_quantiles}" super().__init__( @@ -71,7 +69,6 @@ def __init__( estimation_step=estimation_step, target_update_freq=target_update_freq, reward_normalization=reward_normalization, - lr_scheduler=lr_scheduler, ) self.num_quantiles = num_quantiles tau = torch.linspace(0, 1, self.num_quantiles + 1) diff --git a/tianshou/policy/modelfree/redq.py b/tianshou/policy/modelfree/redq.py index 07d247b4d..11c24f457 100644 --- a/tianshou/policy/modelfree/redq.py +++ b/tianshou/policy/modelfree/redq.py @@ -13,13 +13,13 @@ RolloutBatchProtocol, ) from tianshou.exploration import BaseNoise -from tianshou.policy.base import TLearningRateScheduler from tianshou.policy.modelfree.ddpg import ( ActorCriticOffPolicyAlgorithm, ContinuousPolicyWithExplorationNoise, DDPGTrainingStats, ) from tianshou.policy.modelfree.sac import Alpha +from tianshou.policy.optim import OptimizerFactory from tianshou.utils.net.continuous import ActorProb @@ -109,9 +109,9 @@ def __init__( self, *, policy: REDQPolicy, - policy_optim: torch.optim.Optimizer, + policy_optim: OptimizerFactory, critic: torch.nn.Module, - critic_optim: torch.optim.Optimizer, + critic_optim: OptimizerFactory, ensemble_size: int = 10, subset_size: int = 2, tau: float = 0.005, @@ -121,7 +121,6 @@ def __init__( actor_delay: int = 20, deterministic_eval: bool = True, target_mode: Literal["mean", "min"] = "min", - lr_scheduler: TLearningRateScheduler | None = None, ) -> None: """ :param policy: the policy @@ -134,11 +133,8 @@ def __init__( :param gamma: Discount factor, in [0, 1]. :param alpha: the entropy regularization coefficient alpha or an object which can be used to automatically tune it (e.g. an instance of `AutoAlpha`). - :param exploration_noise: The exploration noise, added to the action. Defaults - to ``GaussianNoise(sigma=0.1)``. :param estimation_step: The number of steps to look ahead. 
:param actor_delay: Number of critic updates before an actor update. - :param lr_scheduler: if not None, will be called in `policy.update()`. """ if target_mode not in ("min", "mean"): raise ValueError(f"Unsupported target_mode: {target_mode}") @@ -155,7 +151,6 @@ def __init__( tau=tau, gamma=gamma, estimation_step=estimation_step, - lr_scheduler=lr_scheduler, ) self.ensemble_size = ensemble_size self.subset_size = subset_size diff --git a/tianshou/policy/modelfree/sac.py b/tianshou/policy/modelfree/sac.py index 2e65cd9b1..f25cf6887 100644 --- a/tianshou/policy/modelfree/sac.py +++ b/tianshou/policy/modelfree/sac.py @@ -14,9 +14,10 @@ RolloutBatchProtocol, ) from tianshou.exploration import BaseNoise -from tianshou.policy.base import TLearningRateScheduler, TrainingStats +from tianshou.policy.base import TrainingStats from tianshou.policy.modelfree.ddpg import ContinuousPolicyWithExplorationNoise from tianshou.policy.modelfree.td3 import ActorDualCriticsOffPolicyAlgorithm +from tianshou.policy.optim import OptimizerFactory from tianshou.utils.conversion import to_optional_float from tianshou.utils.net.continuous import ActorProb @@ -210,17 +211,16 @@ def __init__( self, *, policy: SACPolicy, - policy_optim: torch.optim.Optimizer, + policy_optim: OptimizerFactory, critic: torch.nn.Module, - critic_optim: torch.optim.Optimizer, + critic_optim: OptimizerFactory, critic2: torch.nn.Module | None = None, - critic2_optim: torch.optim.Optimizer | None = None, + critic2_optim: OptimizerFactory | None = None, tau: float = 0.005, gamma: float = 0.99, alpha: float | Alpha = 0.2, estimation_step: int = 1, deterministic_eval: bool = True, - lr_scheduler: TLearningRateScheduler | None = None, ) -> None: """ :param policy: the policy @@ -236,8 +236,6 @@ def __init__( :param alpha: the entropy regularization coefficient alpha or an object which can be used to automatically tune it (e.g. an instance of `AutoAlpha`). :param estimation_step: The number of steps to look ahead. 
- :param lr_scheduler: a learning rate scheduler that adjusts the learning rate - in optimizer in each policy.update() """ super().__init__( policy=policy, @@ -249,7 +247,6 @@ def __init__( tau=tau, gamma=gamma, estimation_step=estimation_step, - lr_scheduler=lr_scheduler, ) self.deterministic_eval = deterministic_eval self.alpha = Alpha.from_float_or_instance(alpha) diff --git a/tianshou/policy/modelfree/td3.py b/tianshou/policy/modelfree/td3.py index 9998d89de..c928a9a35 100644 --- a/tianshou/policy/modelfree/td3.py +++ b/tianshou/policy/modelfree/td3.py @@ -11,7 +11,6 @@ RolloutBatchProtocol, ) from tianshou.policy.base import ( - TLearningRateScheduler, TPolicy, TrainingStats, TTrainingStats, @@ -21,7 +20,7 @@ DDPGPolicy, TActBatchProtocol, ) -from tianshou.utils.optim import clone_optimizer +from tianshou.policy.optim import OptimizerFactory @dataclass(kw_only=True) @@ -47,15 +46,14 @@ def __init__( self, *, policy: Any, - policy_optim: torch.optim.Optimizer, + policy_optim: OptimizerFactory, critic: torch.nn.Module, - critic_optim: torch.optim.Optimizer, - critic2: torch.nn.Module, - critic2_optim: torch.optim.Optimizer, + critic_optim: OptimizerFactory, + critic2: torch.nn.Module | None = None, + critic2_optim: OptimizerFactory | None = None, tau: float = 0.005, gamma: float = 0.99, estimation_step: int = 1, - lr_scheduler: TLearningRateScheduler | None = None, ) -> None: """ :param policy: the policy @@ -68,7 +66,7 @@ def __init__( :param critic2: the second critic network (analogous functionality to the first). If None, use the same network as the first critic (via deepcopy). :param critic2_optim: the optimizer for the second critic network. - If None, clone critic_optim to use for critic2.parameters(). + If None, use critic_optim. :param tau: param for soft update of the target network. :param gamma: discount factor, in [0, 1]. 
:param lr_scheduler: a learning rate scheduler that adjusts the learning rate @@ -77,20 +75,15 @@ def __init__( super().__init__( policy=policy, policy_optim=policy_optim, - lr_scheduler=lr_scheduler, critic=critic, critic_optim=critic_optim, tau=tau, gamma=gamma, estimation_step=estimation_step, ) - if critic2 and not critic2_optim: - raise ValueError("critic2_optim must be provided if critic2 is provided") - critic2 = critic2 or deepcopy(critic) - critic2_optim = critic2_optim or clone_optimizer(critic_optim, critic2.parameters()) - self.critic2 = critic2 + self.critic2 = critic2 or deepcopy(critic) self.critic2_old = self._add_lagged_network(self.critic2) - self.critic2_optim = critic2_optim + self.critic2_optim = self._create_optimizer(self.critic2, critic2_optim or critic_optim) def _target_q_compute_value( self, obs_batch: Batch, act_batch: TActBatchProtocol @@ -113,18 +106,17 @@ def __init__( self, *, policy: DDPGPolicy, - policy_optim: torch.optim.Optimizer, + policy_optim: OptimizerFactory, critic: torch.nn.Module, - critic_optim: torch.optim.Optimizer, + critic_optim: OptimizerFactory, critic2: torch.nn.Module | None = None, - critic2_optim: torch.optim.Optimizer | None = None, + critic2_optim: OptimizerFactory | None = None, tau: float = 0.005, gamma: float = 0.99, policy_noise: float = 0.2, update_actor_freq: int = 2, noise_clip: float = 0.5, estimation_step: int = 1, - lr_scheduler: TLearningRateScheduler | None = None, ) -> None: """ :param policy: the policy @@ -140,8 +132,6 @@ def __init__( :param policy_noise: the noise used in updating policy network. :param update_actor_freq: the update frequency of actor network. :param noise_clip: the clipping range used in updating policy network. 
- :param lr_scheduler: a learning rate scheduler that adjusts the learning rate - in optimizer in each policy.update() """ super().__init__( policy=policy, @@ -153,7 +143,6 @@ def __init__( tau=tau, gamma=gamma, estimation_step=estimation_step, - lr_scheduler=lr_scheduler, ) self.actor_old = self._add_lagged_network(self.policy.actor) self.policy_noise = policy_noise diff --git a/tianshou/policy/modelfree/trpo.py b/tianshou/policy/modelfree/trpo.py index cc68ace02..2dac104af 100644 --- a/tianshou/policy/modelfree/trpo.py +++ b/tianshou/policy/modelfree/trpo.py @@ -8,9 +8,9 @@ from tianshou.data import Batch, SequenceSummaryStats from tianshou.policy import NPG -from tianshou.policy.base import TLearningRateScheduler from tianshou.policy.modelfree.npg import NPGTrainingStats from tianshou.policy.modelfree.pg import ActorPolicy +from tianshou.policy.optim import OptimizerFactory from tianshou.utils.net.continuous import Critic from tianshou.utils.net.discrete import Critic as DiscreteCritic @@ -31,7 +31,7 @@ def __init__( *, policy: ActorPolicy, critic: torch.nn.Module | Critic | DiscreteCritic, - optim: torch.optim.Optimizer, + optim: OptimizerFactory, max_kl: float = 0.01, backtrack_coeff: float = 0.8, max_backtracks: int = 10, @@ -43,11 +43,10 @@ def __init__( discount_factor: float = 0.99, # TODO: rename to return_normalization? reward_normalization: bool = False, - lr_scheduler: TLearningRateScheduler | None = None, ) -> None: """ :param critic: the critic network. (s -> V(s)) - :param optim: the optimizer for actor and critic network. + :param optim: the optimizer factory for the critic network. :param max_kl: max kl-divergence used to constrain each actor network update. :param backtrack_coeff: Coefficient to be multiplied by step size when constraints are not met. @@ -60,7 +59,6 @@ def __init__( :param max_batchsize: the maximum size of the batch when computing GAE. :param discount_factor: in [0, 1]. 
:param reward_normalization: normalize estimated values to have std close to 1. - :param lr_scheduler: if not None, will be called in `policy.update()`. """ super().__init__( policy=policy, @@ -73,7 +71,6 @@ def __init__( max_batchsize=max_batchsize, discount_factor=discount_factor, reward_normalization=reward_normalization, - lr_scheduler=lr_scheduler, ) self.max_backtracks = max_backtracks self.max_kl = max_kl diff --git a/tianshou/policy/multiagent/mapolicy.py b/tianshou/policy/multiagent/mapolicy.py index f610d433d..9333b5e25 100644 --- a/tianshou/policy/multiagent/mapolicy.py +++ b/tianshou/policy/multiagent/mapolicy.py @@ -14,7 +14,6 @@ OffPolicyAlgorithm, OnPolicyAlgorithm, Policy, - TLearningRateScheduler, TrainingStats, ) @@ -277,7 +276,6 @@ def __init__( *, algorithms: list[OffPolicyAlgorithm], env: PettingZooEnv, - lr_scheduler: TLearningRateScheduler | None = None, ) -> None: """ :param algorithms: a list of off-policy algorithms. @@ -287,7 +285,6 @@ def __init__( self._dispatcher: MARLDispatcher[OffPolicyAlgorithm] = MARLDispatcher(algorithms, env) super().__init__( policy=self._dispatcher.create_policy(), - lr_scheduler=lr_scheduler, ) self._submodules = ModuleList(algorithms) @@ -323,17 +320,14 @@ def __init__( *, algorithms: list[OnPolicyAlgorithm], env: PettingZooEnv, - lr_scheduler: TLearningRateScheduler | None = None, ) -> None: """ :param algorithms: a list of off-policy algorithms. :param env: the multi-agent RL environment - :param lr_scheduler: if not None, will be called in `policy.update()`. 
""" self._dispatcher: MARLDispatcher[OnPolicyAlgorithm] = MARLDispatcher(algorithms, env) super().__init__( policy=self._dispatcher.create_policy(), - lr_scheduler=lr_scheduler, ) self._submodules = ModuleList(algorithms) diff --git a/tianshou/policy/optim.py b/tianshou/policy/optim.py new file mode 100644 index 000000000..ab1ae67ac --- /dev/null +++ b/tianshou/policy/optim.py @@ -0,0 +1,139 @@ +from abc import ABC, abstractmethod +from collections.abc import Callable, Iterable +from typing import Any, Self, TypeAlias + +import numpy as np +import torch +from sensai.util.string import ToStringMixin +from torch.optim import Adam, RMSprop +from torch.optim.lr_scheduler import LambdaLR, LRScheduler + +ParamsType: TypeAlias = Iterable[torch.Tensor] | Iterable[dict[str, Any]] + + +class LRSchedulerFactory(ToStringMixin, ABC): + """Factory for the creation of a learning rate scheduler.""" + + @abstractmethod + def create_scheduler(self, optim: torch.optim.Optimizer) -> LRScheduler: + pass + + +class LRSchedulerFactoryLinear(LRSchedulerFactory): + """ + Factory for a learning rate scheduler where the learning rate linearly decays towards + zero for the given trainer parameters. 
+ """ + + def __init__(self, num_epochs: int, step_per_epoch: int, step_per_collect: int): + self.num_epochs = num_epochs + self.step_per_epoch = step_per_epoch + self.step_per_collect = step_per_collect + + def create_scheduler(self, optim: torch.optim.Optimizer) -> LRScheduler: + return LambdaLR(optim, lr_lambda=self._LRLambda(self).compute) + + class _LRLambda: + def __init__(self, parent: "LRSchedulerFactoryLinear"): + self.max_update_num = ( + np.ceil(parent.step_per_epoch / parent.step_per_collect) * parent.num_epochs + ) + + def compute(self, epoch: int) -> float: + return 1.0 - epoch / self.max_update_num + + +class OptimizerFactory(ABC, ToStringMixin): + def __init__(self) -> None: + self.lr_scheduler_factory: LRSchedulerFactory | None = None + + def with_lr_scheduler_factory(self, lr_scheduler_factory: LRSchedulerFactory) -> Self: + self.lr_scheduler_factory = lr_scheduler_factory + return self + + def create_instances( + self, + module: torch.nn.Module, + ) -> tuple[torch.optim.Optimizer, LRScheduler | None]: + optimizer = self._create_optimizer_for_params(module.parameters()) + lr_scheduler = None + if self.lr_scheduler_factory is not None: + lr_scheduler = self.lr_scheduler_factory.create_scheduler(optimizer) + return optimizer, lr_scheduler + + @abstractmethod + def _create_optimizer_for_params(self, params: ParamsType) -> torch.optim.Optimizer: + pass + + +class TorchOptimizerFactory(OptimizerFactory): + """General factory for arbitrary torch optimizers.""" + + def __init__(self, optim_class: Callable[..., torch.optim.Optimizer], **kwargs: Any): + """ + + :param optim_class: the optimizer class (e.g. subclass of `torch.optim.Optimizer`), + which will be passed the module parameters, the learning rate as `lr` and the + kwargs provided. 
+ :param kwargs: keyword arguments to provide at optimizer construction + """ + super().__init__() + self.optim_class = optim_class + self.kwargs = kwargs + + def _create_optimizer_for_params(self, params: ParamsType) -> torch.optim.Optimizer: + return self.optim_class(params, **self.kwargs) + + +class AdamOptimizerFactory(OptimizerFactory): + def __init__( + self, + lr: float = 1e-3, + betas: tuple[float, float] = (0.9, 0.999), + eps: float = 1e-08, + weight_decay: float = 0, + ): + super().__init__() + self.lr = lr + self.weight_decay = weight_decay + self.eps = eps + self.betas = betas + + def _create_optimizer_for_params(self, params: ParamsType) -> torch.optim.Optimizer: + return Adam( + params, + lr=self.lr, + betas=self.betas, + eps=self.eps, + weight_decay=self.weight_decay, + ) + + +class RMSPropOptimizerFactory(OptimizerFactory): + def __init__( + self, + lr: float = 1e-2, + alpha: float = 0.99, + eps: float = 1e-08, + weight_decay: float = 0, + momentum: float = 0, + centered: bool = False, + ): + super().__init__() + self.lr = lr + self.alpha = alpha + self.momentum = momentum + self.centered = centered + self.weight_decay = weight_decay + self.eps = eps + + def _create_optimizer_for_params(self, params: ParamsType) -> torch.optim.Optimizer: + return RMSprop( + params, + lr=self.lr, + alpha=self.alpha, + eps=self.eps, + weight_decay=self.weight_decay, + momentum=self.momentum, + centered=self.centered, + ) diff --git a/tianshou/utils/lr_scheduler.py b/tianshou/utils/lr_scheduler.py index 59890c75c..66b313c7c 100644 --- a/tianshou/utils/lr_scheduler.py +++ b/tianshou/utils/lr_scheduler.py @@ -1,6 +1,7 @@ import torch +# TODO: We no longer need this class as Algorithm now uses an explicit list class MultipleLRSchedulers: """A wrapper for multiple learning rate schedulers. 
diff --git a/tianshou/utils/optim.py b/tianshou/utils/optim.py index c69ef71db..ce59edce1 100644 --- a/tianshou/utils/optim.py +++ b/tianshou/utils/optim.py @@ -1,6 +1,3 @@ -from collections.abc import Iterator -from typing import TypeVar - import torch from torch import nn @@ -28,42 +25,3 @@ def optim_step( ) nn.utils.clip_grad_norm_(module.parameters(), max_norm=max_grad_norm) optim.step() - - -_STANDARD_TORCH_OPTIMIZERS = [ - torch.optim.Adam, - torch.optim.SGD, - torch.optim.RMSprop, - torch.optim.Adadelta, - torch.optim.AdamW, - torch.optim.Adamax, - torch.optim.NAdam, - torch.optim.SparseAdam, - torch.optim.LBFGS, -] - -TOptim = TypeVar("TOptim", bound=torch.optim.Optimizer) - - -def clone_optimizer( - optim: TOptim, - new_params: nn.Parameter | Iterator[nn.Parameter], -) -> TOptim: - """Clone an optimizer to get a new optim instance with new parameters. - - **WARNING**: This is a temporary measure, and should not be used in downstream code! - Once tianshou interfaces have moved to optimizer factories instead of optimizers, - this will be removed. - - :param optim: the optimizer to clone - :param new_params: the new parameters to use - :return: a new optimizer with the same configuration as the old one - """ - optim_class = type(optim) - # custom optimizers may not behave as expected - if optim_class not in _STANDARD_TORCH_OPTIMIZERS: - raise ValueError( - f"Cannot clone optimizer {optim} of type {optim_class}" - f"Currently, only standard torch optimizers are supported.", - ) - return optim_class(new_params, **optim.defaults) From f20dbce1fefcf517d94d54428a5b0ddbbf4dfaf5 Mon Sep 17 00:00:00 2001 From: Dominik Jain Date: Sat, 15 Mar 2025 12:43:18 +0100 Subject: [PATCH 062/230] v2: Adapt high-level API (new optimizer handling, remaining adaptations) * Optimizers can now be specified in the algorithm-specific Params objects (via an OptimizerFactoryFactory), and the ExperimentBuilder allows to define the default via method with_optim_default. 
Previously, only a single optimizer factory was used for all optimizers. * Learning rate schedulers are now specified via parameters ending in "lr_scheduler" rather than "lr_scheduler_factory" and now are given as factory factories (LRSchedulerFactoryFactory) Positive side-effects: * The parameter transformation is greatly simplified, as the old learning rate scheduler handling was quite complex * Some abstractions (e.g. ModuleOpt) are no longer required and have been removed --- CHANGELOG.md | 8 +- examples/atari/atari_ppo_hl.py | 6 +- examples/mujoco/mujoco_a2c_hl.py | 10 +- examples/mujoco/mujoco_npg_hl.py | 6 +- examples/mujoco/mujoco_ppo_hl.py | 6 +- examples/mujoco/mujoco_ppo_hl_multi.py | 4 +- examples/mujoco/mujoco_reinforce_hl.py | 6 +- examples/mujoco/mujoco_trpo_hl.py | 6 +- test/discrete/test_fqf.py | 4 +- tianshou/highlevel/algorithm.py | 198 +++++++------- tianshou/highlevel/experiment.py | 50 +--- tianshou/highlevel/module/actor.py | 21 -- tianshou/highlevel/module/critic.py | 43 --- tianshou/highlevel/module/module_opt.py | 29 -- tianshou/highlevel/optim.py | 41 +-- tianshou/highlevel/params/alpha.py | 21 +- tianshou/highlevel/params/lr_scheduler.py | 31 +-- tianshou/highlevel/params/policy_params.py | 276 +++++++++----------- tianshou/highlevel/params/policy_wrapper.py | 11 +- tianshou/highlevel/trainer.py | 18 +- tianshou/policy/optim.py | 2 +- 21 files changed, 320 insertions(+), 477 deletions(-) delete mode 100644 tianshou/highlevel/module/module_opt.py diff --git a/CHANGELOG.md b/CHANGELOG.md index 450bd2957..b9d7a526c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -113,7 +113,13 @@ * `REDQ`: Inherit from `ActorCriticOffPolicyAlgorithm` instead of `DDPG` * `SAC`: Inherit from `ActorDualCriticsOffPolicyAlgorithm` instead of `DDPG` * `TD3`: Inherit from `ActorDualCriticsOffPolicyAlgorithm` instead of `DDPG` - +* High-Level API changes: + * Detailed optimizer configuration (analogous to the procedural API) is now possible: + * All optimizers can 
be configured in the respective algorithm-specific `Params` object by using + `OptimizerFactoryFactory` instances as parameter values (e.g. for `optim`, `actor_optim`, `critic_optim`, etc.). + * Learning rate schedulers remain separate parameters and now use `LRSchedulerFactoryFactory` + instances. The respective parameter names now use the suffix `lr_scheduler` instead of `lr_scheduler_factory` + (as the precise nature need not be reflected in the name; brevity is preferable). * Moved Atari helper modules `atari_network` and `atari_wrapper` to the library under `tianshou.env.atari`. ## Unreleased diff --git a/examples/atari/atari_ppo_hl.py b/examples/atari/atari_ppo_hl.py index 06d59d555..6f8f8a6ce 100644 --- a/examples/atari/atari_ppo_hl.py +++ b/examples/atari/atari_ppo_hl.py @@ -16,7 +16,7 @@ ExperimentConfig, PPOExperimentBuilder, ) -from tianshou.highlevel.params.lr_scheduler import LRSchedulerFactoryLinear +from tianshou.highlevel.params.lr_scheduler import LRSchedulerFactoryFactoryLinear from tianshou.highlevel.params.policy_params import PPOParams from tianshou.highlevel.params.policy_wrapper import ( AlgorithmWrapperFactoryIntrinsicCuriosity, @@ -95,9 +95,7 @@ def main( dual_clip=dual_clip, recompute_advantage=recompute_adv, lr=lr, - lr_scheduler_factory=LRSchedulerFactoryLinear(sampling_config) - if lr_decay - else None, + lr_scheduler=LRSchedulerFactoryFactoryLinear(sampling_config) if lr_decay else None, ), ) .with_actor_factory(ActorFactoryAtariDQN(scale_obs=scale_obs, features_only=True)) diff --git a/examples/mujoco/mujoco_a2c_hl.py b/examples/mujoco/mujoco_a2c_hl.py index c804d6c26..e3a30f8a4 100644 --- a/examples/mujoco/mujoco_a2c_hl.py +++ b/examples/mujoco/mujoco_a2c_hl.py @@ -14,8 +14,8 @@ A2CExperimentBuilder, ExperimentConfig, ) -from tianshou.highlevel.optim import OptimizerFactoryRMSprop -from tianshou.highlevel.params.lr_scheduler import LRSchedulerFactoryLinear +from tianshou.highlevel.optim import OptimizerFactoryFactoryRMSprop +from 
tianshou.highlevel.params.lr_scheduler import LRSchedulerFactoryFactoryLinear from tianshou.highlevel.params.policy_params import A2CParams @@ -72,13 +72,11 @@ def main( ent_coef=ent_coef, vf_coef=vf_coef, max_grad_norm=max_grad_norm, + optim=OptimizerFactoryFactoryRMSprop(eps=1e-5, alpha=0.99), lr=lr, - lr_scheduler_factory=LRSchedulerFactoryLinear(sampling_config) - if lr_decay - else None, + lr_scheduler=LRSchedulerFactoryFactoryLinear(sampling_config) if lr_decay else None, ), ) - .with_optim_factory(OptimizerFactoryRMSprop(eps=1e-5, alpha=0.99)) .with_actor_factory_default(hidden_sizes, nn.Tanh, continuous_unbounded=True) .with_critic_factory_default(hidden_sizes, nn.Tanh) .build() diff --git a/examples/mujoco/mujoco_npg_hl.py b/examples/mujoco/mujoco_npg_hl.py index 387f87c6e..fc1f5afb9 100644 --- a/examples/mujoco/mujoco_npg_hl.py +++ b/examples/mujoco/mujoco_npg_hl.py @@ -14,7 +14,7 @@ ExperimentConfig, NPGExperimentBuilder, ) -from tianshou.highlevel.params.lr_scheduler import LRSchedulerFactoryLinear +from tianshou.highlevel.params.lr_scheduler import LRSchedulerFactoryFactoryLinear from tianshou.highlevel.params.policy_params import NPGParams @@ -72,9 +72,7 @@ def main( optim_critic_iters=optim_critic_iters, actor_step_size=actor_step_size, lr=lr, - lr_scheduler_factory=LRSchedulerFactoryLinear(sampling_config) - if lr_decay - else None, + lr_scheduler=LRSchedulerFactoryFactoryLinear(sampling_config) if lr_decay else None, ), ) .with_actor_factory_default(hidden_sizes, torch.nn.Tanh, continuous_unbounded=True) diff --git a/examples/mujoco/mujoco_ppo_hl.py b/examples/mujoco/mujoco_ppo_hl.py index b10d4cf26..7334e6bfa 100644 --- a/examples/mujoco/mujoco_ppo_hl.py +++ b/examples/mujoco/mujoco_ppo_hl.py @@ -14,7 +14,7 @@ ExperimentConfig, PPOExperimentBuilder, ) -from tianshou.highlevel.params.lr_scheduler import LRSchedulerFactoryLinear +from tianshou.highlevel.params.lr_scheduler import LRSchedulerFactoryFactoryLinear from 
tianshou.highlevel.params.policy_params import PPOParams @@ -82,9 +82,7 @@ def main( dual_clip=dual_clip, recompute_advantage=recompute_adv, lr=lr, - lr_scheduler_factory=LRSchedulerFactoryLinear(sampling_config) - if lr_decay - else None, + lr_scheduler=LRSchedulerFactoryFactoryLinear(sampling_config) if lr_decay else None, ), ) .with_actor_factory_default(hidden_sizes, torch.nn.Tanh, continuous_unbounded=True) diff --git a/examples/mujoco/mujoco_ppo_hl_multi.py b/examples/mujoco/mujoco_ppo_hl_multi.py index 333870809..9e63fc59a 100644 --- a/examples/mujoco/mujoco_ppo_hl_multi.py +++ b/examples/mujoco/mujoco_ppo_hl_multi.py @@ -28,7 +28,7 @@ PPOExperimentBuilder, ) from tianshou.highlevel.logger import LoggerFactoryDefault -from tianshou.highlevel.params.lr_scheduler import LRSchedulerFactoryLinear +from tianshou.highlevel.params.lr_scheduler import LRSchedulerFactoryFactoryLinear from tianshou.highlevel.params.policy_params import PPOParams log = logging.getLogger(__name__) @@ -111,7 +111,7 @@ def main( dual_clip=None, recompute_advantage=True, lr=3e-4, - lr_scheduler_factory=LRSchedulerFactoryLinear(sampling_config), + lr_scheduler=LRSchedulerFactoryFactoryLinear(sampling_config), ), ) .with_actor_factory_default(hidden_sizes, torch.nn.Tanh, continuous_unbounded=True) diff --git a/examples/mujoco/mujoco_reinforce_hl.py b/examples/mujoco/mujoco_reinforce_hl.py index 59a600568..cac066c25 100644 --- a/examples/mujoco/mujoco_reinforce_hl.py +++ b/examples/mujoco/mujoco_reinforce_hl.py @@ -14,7 +14,7 @@ ExperimentConfig, PGExperimentBuilder, ) -from tianshou.highlevel.params.lr_scheduler import LRSchedulerFactoryLinear +from tianshou.highlevel.params.lr_scheduler import LRSchedulerFactoryFactoryLinear from tianshou.highlevel.params.policy_params import PGParams @@ -64,9 +64,7 @@ def main( action_bound_method=action_bound_method, reward_normalization=rew_norm, lr=lr, - lr_scheduler_factory=LRSchedulerFactoryLinear(sampling_config) - if lr_decay - else None, + 
lr_scheduler=LRSchedulerFactoryFactoryLinear(sampling_config) if lr_decay else None, ), ) .with_actor_factory_default(hidden_sizes, torch.nn.Tanh, continuous_unbounded=True) diff --git a/examples/mujoco/mujoco_trpo_hl.py b/examples/mujoco/mujoco_trpo_hl.py index 1ec26bad2..2aa2a24ca 100644 --- a/examples/mujoco/mujoco_trpo_hl.py +++ b/examples/mujoco/mujoco_trpo_hl.py @@ -14,7 +14,7 @@ ExperimentConfig, TRPOExperimentBuilder, ) -from tianshou.highlevel.params.lr_scheduler import LRSchedulerFactoryLinear +from tianshou.highlevel.params.lr_scheduler import LRSchedulerFactoryFactoryLinear from tianshou.highlevel.params.policy_params import TRPOParams @@ -76,9 +76,7 @@ def main( backtrack_coeff=backtrack_coeff, max_backtracks=max_backtracks, lr=lr, - lr_scheduler_factory=LRSchedulerFactoryLinear(sampling_config) - if lr_decay - else None, + lr_scheduler=LRSchedulerFactoryFactoryLinear(sampling_config) if lr_decay else None, ), ) .with_actor_factory_default(hidden_sizes, torch.nn.Tanh, continuous_unbounded=True) diff --git a/test/discrete/test_fqf.py b/test/discrete/test_fqf.py index 193d7d77c..4cb884e58 100644 --- a/test/discrete/test_fqf.py +++ b/test/discrete/test_fqf.py @@ -17,7 +17,7 @@ from tianshou.policy import FQF from tianshou.policy.base import Algorithm from tianshou.policy.modelfree.fqf import FQFPolicy -from tianshou.policy.optim import AdamOptimizerFactory, RMSPropOptimizerFactory +from tianshou.policy.optim import AdamOptimizerFactory, RMSpropOptimizerFactory from tianshou.trainer.base import OffPolicyTrainingConfig from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Net @@ -100,7 +100,7 @@ def test_fqf(args: argparse.Namespace = get_args()) -> None: ) optim = AdamOptimizerFactory(lr=args.lr) fraction_net = FractionProposalNetwork(args.num_fractions, net.input_dim) - fraction_optim = RMSPropOptimizerFactory(lr=args.fraction_lr) + fraction_optim = RMSpropOptimizerFactory(lr=args.fraction_lr) policy = FQFPolicy( model=net, 
fraction_model=fraction_net, diff --git a/tianshou/highlevel/algorithm.py b/tianshou/highlevel/algorithm.py index 1ebfb8653..958409cd3 100644 --- a/tianshou/highlevel/algorithm.py +++ b/tianshou/highlevel/algorithm.py @@ -19,10 +19,7 @@ TDevice, ) from tianshou.highlevel.module.critic import CriticEnsembleFactory, CriticFactory -from tianshou.highlevel.module.module_opt import ( - ActorCriticOpt, -) -from tianshou.highlevel.optim import OptimizerFactory +from tianshou.highlevel.optim import OptimizerFactoryFactory from tianshou.highlevel.params.policy_params import ( A2CParams, DDPGParams, @@ -32,7 +29,7 @@ NPGParams, Params, ParamsMixinActorAndDualCritics, - ParamsMixinLearningRateWithScheduler, + ParamsMixinSingleModel, ParamTransformerData, PGParams, PPOParams, @@ -66,11 +63,14 @@ Policy, ) from tianshou.policy.modelfree.ddpg import DDPGPolicy +from tianshou.policy.modelfree.discrete_sac import DiscreteSACPolicy from tianshou.policy.modelfree.dqn import DQNPolicy +from tianshou.policy.modelfree.iqn import IQNPolicy from tianshou.policy.modelfree.pg import ActorPolicy +from tianshou.policy.modelfree.redq import REDQPolicy +from tianshou.policy.modelfree.sac import SACPolicy from tianshou.trainer import OffPolicyTrainer, OnPolicyTrainer, Trainer from tianshou.trainer.base import OffPolicyTrainingConfig, OnPolicyTrainingConfig -from tianshou.utils.net.common import ActorCritic from tianshou.utils.net.discrete import Actor CHECKPOINT_DICT_KEY_MODEL = "model" @@ -78,7 +78,7 @@ TParams = TypeVar("TParams", bound=Params) TActorCriticParams = TypeVar( "TActorCriticParams", - bound=Params | ParamsMixinLearningRateWithScheduler, + bound=Params | ParamsMixinSingleModel, ) TActorDualCriticsParams = TypeVar( "TActorDualCriticsParams", @@ -86,7 +86,7 @@ ) TDiscreteCriticOnlyParams = TypeVar( "TDiscreteCriticOnlyParams", - bound=Params | ParamsMixinLearningRateWithScheduler, + bound=Params | ParamsMixinSingleModel, ) TAlgorithm = TypeVar("TAlgorithm", bound=Algorithm) TPolicy 
= TypeVar("TPolicy", bound=Policy) @@ -96,7 +96,7 @@ class AlgorithmFactory(ABC, ToStringMixin): """Factory for the creation of an :class:`Algorithm` instance, its policy, trainer as well as collectors.""" - def __init__(self, sampling_config: SamplingConfig, optim_factory: OptimizerFactory): + def __init__(self, sampling_config: SamplingConfig, optim_factory: OptimizerFactoryFactory): self.sampling_config = sampling_config self.optim_factory = optim_factory self.algorithm_wrapper_factory: AlgorithmWrapperFactory | None = None @@ -279,7 +279,7 @@ def __init__( params: PGParams, sampling_config: SamplingConfig, actor_factory: ActorFactory, - optim_factory: OptimizerFactory, + optim_factory: OptimizerFactoryFactory, ): super().__init__(sampling_config, optim_factory) self.params = params @@ -287,18 +287,12 @@ def __init__( self.optim_factory = optim_factory def _create_algorithm(self, envs: Environments, device: TDevice) -> Reinforce: - actor = self.actor_factory.create_module_opt( - envs, - device, - self.optim_factory, - self.params.lr, - ) + actor = self.actor_factory.create_module(envs, device) kwargs = self.params.create_kwargs( ParamTransformerData( envs=envs, device=device, - optim=actor.optim, - optim_factory=self.optim_factory, + optim_factory_default=self.optim_factory, ), ) dist_fn = self.actor_factory.create_dist_fn(envs) @@ -307,14 +301,13 @@ def _create_algorithm(self, envs: Environments, device: TDevice) -> Reinforce: ActorPolicy, kwargs, ["action_scaling", "action_bound_method", "deterministic_eval"], - actor=actor.module, + actor=actor, dist_fn=dist_fn, action_space=envs.get_action_space(), observation_space=envs.get_observation_space(), ) return Reinforce( policy=policy, - optim=actor.optim, **kwargs, ) @@ -330,7 +323,7 @@ def __init__( sampling_config: SamplingConfig, actor_factory: ActorFactory, critic_factory: CriticFactory, - optimizer_factory: OptimizerFactory, + optimizer_factory: OptimizerFactoryFactory, ): super().__init__(sampling_config, 
optim_factory=optimizer_factory) self.params = params @@ -343,32 +336,19 @@ def __init__( def _get_algorithm_class(self) -> type[TAlgorithm]: pass - def create_actor_critic_module_opt( - self, - envs: Environments, - device: TDevice, - lr: float, - ) -> ActorCriticOpt: - actor = self.actor_factory.create_module(envs, device) - critic = self.critic_factory.create_module(envs, device, use_action=self.critic_use_action) - actor_critic = ActorCritic(actor, critic) - optim = self.optim_factory.create_optimizer(actor_critic, lr) - return ActorCriticOpt(actor_critic, optim) - @typing.no_type_check def _create_kwargs(self, envs: Environments, device: TDevice) -> dict[str, Any]: - actor_critic = self.create_actor_critic_module_opt(envs, device, self.params.lr) + actor = self.actor_factory.create_module(envs, device) + critic = self.critic_factory.create_module(envs, device, use_action=self.critic_use_action) kwargs = self.params.create_kwargs( ParamTransformerData( envs=envs, device=device, - optim_factory=self.optim_factory, - optim=actor_critic.optim, + optim_factory_default=self.optim_factory, ), ) - kwargs["actor"] = actor_critic.actor - kwargs["critic"] = actor_critic.critic - kwargs["optim"] = actor_critic.optim + kwargs["actor"] = actor + kwargs["critic"] = critic kwargs["action_space"] = envs.get_action_space() kwargs["observation_space"] = envs.get_observation_space() kwargs["dist_fn"] = self.actor_factory.create_dist_fn(envs) @@ -422,7 +402,7 @@ def __init__( params: TDiscreteCriticOnlyParams, sampling_config: SamplingConfig, model_factory: ModuleFactory, - optim_factory: OptimizerFactory, + optim_factory: OptimizerFactoryFactory, ): super().__init__(sampling_config, optim_factory) self.params = params @@ -446,13 +426,11 @@ def _create_policy( @typing.no_type_check def _create_algorithm(self, envs: Environments, device: TDevice) -> TAlgorithm: model = self.model_factory.create_module(envs, device) - optim = self.optim_factory.create_optimizer(model, 
self.params.lr) params_dict = self.params.create_kwargs( ParamTransformerData( envs=envs, device=device, - optim=optim, - optim_factory=self.optim_factory, + optim_factory_default=self.optim_factory, ), ) envs.get_type().assert_discrete(self) @@ -461,12 +439,11 @@ def _create_algorithm(self, envs: Environments, device: TDevice) -> TAlgorithm: algorithm_class = self._get_algorithm_class() return algorithm_class( policy=policy, - optim=optim, **params_dict, ) -class DeepQLearningAlgorithmFactory(DiscreteCriticOnlyAlgorithmFactory[DQNParams, DQN]): +class DQNAlgorithmFactory(DiscreteCriticOnlyAlgorithmFactory[DQNParams, DQN]): def _create_policy( self, model: torch.nn.Module, @@ -488,6 +465,23 @@ def _get_algorithm_class(self) -> type[DQN]: class IQNAlgorithmFactory(DiscreteCriticOnlyAlgorithmFactory[IQNParams, IQN]): + def _create_policy( + self, + model: torch.nn.Module, + params: dict, + action_space: gymnasium.spaces.Discrete, + observation_space: gymnasium.spaces.Space, + ) -> TPolicy: + pass + return self._create_policy_from_args( + IQNPolicy, + params, + ["sample_size", "online_sample_size", "target_sample_size"], + model=model, + action_space=action_space, + observation_space=observation_space, + ) + def _get_algorithm_class(self) -> type[IQN]: return IQN @@ -499,7 +493,7 @@ def __init__( sampling_config: SamplingConfig, actor_factory: ActorFactory, critic_factory: CriticFactory, - optim_factory: OptimizerFactory, + optim_factory: OptimizerFactoryFactory, ): super().__init__(sampling_config, optim_factory) self.critic_factory = critic_factory @@ -508,41 +502,30 @@ def __init__( self.optim_factory = optim_factory def _create_algorithm(self, envs: Environments, device: TDevice) -> Algorithm: - actor = self.actor_factory.create_module_opt( - envs, - device, - self.optim_factory, - self.params.actor_lr, - ) - critic = self.critic_factory.create_module_opt( + actor = self.actor_factory.create_module(envs, device) + critic = self.critic_factory.create_module( envs, 
device, True, - self.optim_factory, - self.params.critic_lr, ) kwargs = self.params.create_kwargs( ParamTransformerData( envs=envs, device=device, - optim_factory=self.optim_factory, - actor=actor, - critic1=critic, + optim_factory_default=self.optim_factory, ), ) policy = self._create_policy_from_args( DDPGPolicy, kwargs, - ["action_scaling", "action_bound_method"], - actor=actor.module, + ["exploration_noise", "action_scaling", "action_bound_method"], + actor=actor, action_space=envs.get_action_space(), observation_space=envs.get_observation_space(), ) return DDPG( policy=policy, - policy_optim=actor.optim, - critic=critic.module, - critic_optim=critic.optim, + critic=critic, **kwargs, ) @@ -554,7 +537,7 @@ def __init__( sampling_config: SamplingConfig, actor_factory: ActorFactory, critic_ensemble_factory: CriticEnsembleFactory, - optim_factory: OptimizerFactory, + optim_factory: OptimizerFactoryFactory, ): super().__init__(sampling_config, optim_factory) self.critic_ensemble_factory = critic_ensemble_factory @@ -564,37 +547,35 @@ def __init__( def _create_algorithm(self, envs: Environments, device: TDevice) -> Algorithm: envs.get_type().assert_continuous(self) - actor = self.actor_factory.create_module_opt( + actor = self.actor_factory.create_module( envs, device, - self.optim_factory, - self.params.actor_lr, ) - critic_ensemble = self.critic_ensemble_factory.create_module_opt( + critic_ensemble = self.critic_ensemble_factory.create_module( envs, device, self.params.ensemble_size, True, - self.optim_factory, - self.params.critic_lr, ) kwargs = self.params.create_kwargs( ParamTransformerData( envs=envs, device=device, - optim_factory=self.optim_factory, - actor=actor, - critic1=critic_ensemble, + optim_factory_default=self.optim_factory, ), ) action_space = cast(gymnasium.spaces.Box, envs.get_action_space()) - return REDQ( - policy=actor.module, - policy_optim=actor.optim, - critic=critic_ensemble.module, - critic_optim=critic_ensemble.optim, + policy = 
self._create_policy_from_args( + REDQPolicy, + kwargs, + ["exploration_noise", "deterministic_eval", "action_scaling", "action_bound_method"], + actor=actor, action_space=action_space, observation_space=envs.get_observation_space(), + ) + return REDQ( + policy=policy, + critic=critic_ensemble, **kwargs, ) @@ -611,7 +592,7 @@ def __init__( actor_factory: ActorFactory, critic1_factory: CriticFactory, critic2_factory: CriticFactory, - optim_factory: OptimizerFactory, + optim_factory: OptimizerFactoryFactory, ): super().__init__(sampling_config, optim_factory) self.params = params @@ -639,54 +620,51 @@ def _create_policy( @typing.no_type_check def _create_algorithm(self, envs: Environments, device: TDevice) -> TAlgorithm: - actor = self.actor_factory.create_module_opt( - envs, - device, - self.optim_factory, - self.params.actor_lr, - ) + actor = self.actor_factory.create_module(envs, device) use_action_shape = self._get_discrete_last_size_use_action_shape() critic_use_action = self._get_critic_use_action(envs) - critic1 = self.critic1_factory.create_module_opt( + critic1 = self.critic1_factory.create_module( envs, device, - critic_use_action, - self.optim_factory, - self.params.critic1_lr, + use_action=critic_use_action, discrete_last_size_use_action_shape=use_action_shape, ) - critic2 = self.critic2_factory.create_module_opt( + critic2 = self.critic2_factory.create_module( envs, device, - critic_use_action, - self.optim_factory, - self.params.critic2_lr, + use_action=critic_use_action, discrete_last_size_use_action_shape=use_action_shape, ) kwargs = self.params.create_kwargs( ParamTransformerData( envs=envs, device=device, - optim_factory=self.optim_factory, - actor=actor, - critic1=critic1, - critic2=critic2, + optim_factory_default=self.optim_factory, ), ) - policy = self._create_policy(actor.module, envs, kwargs) + policy = self._create_policy(actor, envs, kwargs) algorithm_class = self._get_algorithm_class() return algorithm_class( policy=policy, - 
policy_optim=actor.optim, - critic=critic1.module, - critic_optim=critic1.optim, - critic2=critic2.module, - critic2_optim=critic2.optim, + critic=critic1, + critic2=critic2, **kwargs, ) class SACAlgorithmFactory(ActorDualCriticsAlgorithmFactory[SACParams, SAC, TPolicy]): + def _create_policy( + self, actor: torch.nn.Module | Actor, envs: Environments, params: dict + ) -> SACPolicy: + return self._create_policy_from_args( + SACPolicy, + params, + ["exploration_noise", "deterministic_eval", "action_scaling", "action_bound_method"], + actor=actor, + action_space=envs.get_action_space(), + observation_space=envs.get_observation_space(), + ) + def _get_algorithm_class(self) -> type[SAC]: return SAC @@ -694,6 +672,18 @@ def _get_algorithm_class(self) -> type[SAC]: class DiscreteSACAlgorithmFactory( ActorDualCriticsAlgorithmFactory[DiscreteSACParams, DiscreteSAC, TPolicy] ): + def _create_policy( + self, actor: torch.nn.Module | Actor, envs: Environments, params: dict + ) -> DiscreteSACPolicy: + return self._create_policy_from_args( + DiscreteSACPolicy, + params, + ["deterministic_eval"], + actor=actor, + action_space=envs.get_action_space(), + observation_space=envs.get_observation_space(), + ) + def _get_algorithm_class(self) -> type[DiscreteSAC]: return DiscreteSAC @@ -705,7 +695,7 @@ def _create_policy( return self._create_policy_from_args( DDPGPolicy, params, - ["action_scaling", "action_bound_method"], + ["exploration_noise", "action_scaling", "action_bound_method"], actor=actor, action_space=envs.get_action_space(), observation_space=envs.get_observation_space(), diff --git a/tianshou/highlevel/experiment.py b/tianshou/highlevel/experiment.py index 337c16583..e252e2142 100644 --- a/tianshou/highlevel/experiment.py +++ b/tianshou/highlevel/experiment.py @@ -42,12 +42,11 @@ A2CAlgorithmFactory, AlgorithmFactory, DDPGAlgorithmFactory, - DeepQLearningAlgorithmFactory, DiscreteSACAlgorithmFactory, + DQNAlgorithmFactory, IQNAlgorithmFactory, NPGAlgorithmFactory, 
PPOAlgorithmFactory, - RandomActionAlgorithmFactory, REDQAlgorithmFactory, ReinforceAlgorithmFactory, SACAlgorithmFactory, @@ -79,8 +78,8 @@ from tianshou.highlevel.module.intermediate import IntermediateModuleFactory from tianshou.highlevel.module.special import ImplicitQuantileNetworkFactory from tianshou.highlevel.optim import ( - OptimizerFactory, - OptimizerFactoryAdam, + OptimizerFactoryFactory, + OptimizerFactoryFactoryAdam, ) from tianshou.highlevel.params.policy_params import ( A2CParams, @@ -520,7 +519,7 @@ def __init__( self._env_factory = env_factory self._sampling_config = sampling_config self._logger_factory: LoggerFactory | None = None - self._optim_factory: OptimizerFactory | None = None + self._optim_factory: OptimizerFactoryFactory | None = None self._algorithm_wrapper_factory: AlgorithmWrapperFactory | None = None self._trainer_callbacks: TrainerCallbacks = TrainerCallbacks() self._name: str = self.__class__.__name__.replace("Builder", "") + "_" + datetime_tag() @@ -566,10 +565,13 @@ def with_algorithm_wrapper_factory( self._algorithm_wrapper_factory = algorithm_wrapper_factory return self - def with_optim_factory(self, optim_factory: OptimizerFactory) -> Self: - """Allows to customize the gradient-based optimizer to use. + def with_optim_default(self, optim_factory: OptimizerFactoryFactory) -> Self: + """Allows to customize the default optimizer to use. - By default, :class:`OptimizerFactoryAdam` will be used with default parameters. + The default optimizer applies when optimizer factory factories are set to None + in algorithm parameter objects. + + By default, :class:`OptimizerFactoryFactoryAdam` will be used with default parameters. 
:param optim_factory: the optimizer factory :return: the builder @@ -577,23 +579,6 @@ def with_optim_factory(self, optim_factory: OptimizerFactory) -> Self: self._optim_factory = optim_factory return self - def with_optim_factory_default( - self, - # Keep values in sync with default values in OptimizerFactoryAdam - betas: tuple[float, float] = (0.9, 0.999), - eps: float = 1e-08, - weight_decay: float = 0, - ) -> Self: - """Configures the use of the default optimizer, Adam, with the given parameters. - - :param betas: coefficients used for computing running averages of gradient and its square - :param eps: term added to the denominator to improve numerical stability - :param weight_decay: weight decay (L2 penalty) - :return: the builder - """ - self._optim_factory = OptimizerFactoryAdam(betas=betas, eps=eps, weight_decay=weight_decay) - return self - def with_epoch_train_callback(self, callback: EpochTrainCallback) -> Self: """Allows to define a callback function which is called at the beginning of every epoch during training. 
@@ -640,10 +625,9 @@ def with_name( def _create_algorithm_factory(self) -> AlgorithmFactory: pass - def _get_optim_factory(self) -> OptimizerFactory: + def _get_optim_factory(self) -> OptimizerFactoryFactory: if self._optim_factory is None: - # same mechanism as in `with_optim_factory_default` - return OptimizerFactoryAdam() + return OptimizerFactoryFactoryAdam() else: return self._optim_factory @@ -690,14 +674,6 @@ def build_seeded_collection(self, num_experiments: int) -> ExperimentCollection: return ExperimentCollection(seeded_experiments) -class RandomActionExperimentBuilder(ExperimentBuilder): - def _create_algorithm_factory(self) -> RandomActionAlgorithmFactory: - return RandomActionAlgorithmFactory( - sampling_config=self.sampling_config, - optim_factory=self._get_optim_factory(), - ) - - class _BuilderMixinActorFactory(ActorFutureProviderProtocol): def __init__(self, continuous_actor_type: ContinuousActorType): self._continuous_actor_type = continuous_actor_type @@ -1222,7 +1198,7 @@ def with_model_factory_default( return self def _create_algorithm_factory(self) -> AlgorithmFactory: - return DeepQLearningAlgorithmFactory( + return DQNAlgorithmFactory( self._params, self._sampling_config, self._model_factory, diff --git a/tianshou/highlevel/module/actor.py b/tianshou/highlevel/module/actor.py index ceb1262f7..ca73dc2a4 100644 --- a/tianshou/highlevel/module/actor.py +++ b/tianshou/highlevel/module/actor.py @@ -18,8 +18,6 @@ IntermediateModule, IntermediateModuleFactory, ) -from tianshou.highlevel.module.module_opt import ModuleOpt -from tianshou.highlevel.optim import OptimizerFactory from tianshou.highlevel.params.dist_fn import ( DistributionFunctionFactoryCategorical, DistributionFunctionFactoryIndependentGaussians, @@ -60,25 +58,6 @@ def create_dist_fn(self, envs: Environments) -> TDistFnDiscrOrCont | None: if the actor does not output distribution parameters """ - def create_module_opt( - self, - envs: Environments, - device: TDevice, - optim_factory: 
OptimizerFactory, - lr: float, - ) -> ModuleOpt: - """Creates the actor module along with its optimizer for the given learning rate. - - :param envs: the environments - :param device: the torch device - :param optim_factory: the optimizer factory - :param lr: the learning rate - :return: a container with the actor module and its optimizer - """ - module = self.create_module(envs, device) - optim = optim_factory.create_optimizer(module, lr) - return ModuleOpt(module, optim) - @staticmethod def _init_linear(actor: torch.nn.Module) -> None: """Initializes linear layers of an actor module using default mechanisms. diff --git a/tianshou/highlevel/module/critic.py b/tianshou/highlevel/module/critic.py index 0352fd132..35f5f9483 100644 --- a/tianshou/highlevel/module/critic.py +++ b/tianshou/highlevel/module/critic.py @@ -8,8 +8,6 @@ from tianshou.highlevel.env import Environments, EnvType from tianshou.highlevel.module.actor import ActorFuture from tianshou.highlevel.module.core import TDevice, init_linear_orthogonal -from tianshou.highlevel.module.module_opt import ModuleOpt -from tianshou.highlevel.optim import OptimizerFactory from tianshou.utils.net import continuous, discrete from tianshou.utils.net.common import BaseActor, EnsembleLinear, ModuleType, Net @@ -34,34 +32,6 @@ def create_module( :return: the module """ - def create_module_opt( - self, - envs: Environments, - device: TDevice, - use_action: bool, - optim_factory: OptimizerFactory, - lr: float, - discrete_last_size_use_action_shape: bool = False, - ) -> ModuleOpt: - """Creates the critic module along with its optimizer for the given learning rate. 
- - :param envs: the environments - :param device: the torch device - :param use_action: whether to expect the action as an additional input (in addition to the observations) - :param optim_factory: the optimizer factory - :param lr: the learning rate - :param discrete_last_size_use_action_shape: whether, for the discrete case, the output dimension shall use the action shape - :return: - """ - module = self.create_module( - envs, - device, - use_action, - discrete_last_size_use_action_shape=discrete_last_size_use_action_shape, - ) - opt = optim_factory.create_optimizer(module, lr) - return ModuleOpt(module, opt) - class CriticFactoryDefault(CriticFactory): """A critic factory which, depending on the type of environment, creates a suitable MLP-based critic.""" @@ -223,19 +193,6 @@ def create_module( ) -> nn.Module: pass - def create_module_opt( - self, - envs: Environments, - device: TDevice, - ensemble_size: int, - use_action: bool, - optim_factory: OptimizerFactory, - lr: float, - ) -> ModuleOpt: - module = self.create_module(envs, device, ensemble_size, use_action) - opt = optim_factory.create_optimizer(module, lr) - return ModuleOpt(module, opt) - class CriticEnsembleFactoryDefault(CriticEnsembleFactory): """A critic ensemble factory which, depending on the type of environment, creates a suitable MLP-based critic.""" diff --git a/tianshou/highlevel/module/module_opt.py b/tianshou/highlevel/module/module_opt.py deleted file mode 100644 index 558686aa9..000000000 --- a/tianshou/highlevel/module/module_opt.py +++ /dev/null @@ -1,29 +0,0 @@ -from dataclasses import dataclass - -import torch - -from tianshou.utils.net.common import ActorCritic - - -@dataclass -class ModuleOpt: - """Container for a torch module along with its optimizer.""" - - module: torch.nn.Module - optim: torch.optim.Optimizer - - -@dataclass -class ActorCriticOpt: - """Container for an :class:`ActorCritic` instance along with its optimizer.""" - - actor_critic_module: ActorCritic - optim: 
torch.optim.Optimizer - - @property - def actor(self) -> torch.nn.Module: - return self.actor_critic_module.actor - - @property - def critic(self) -> torch.nn.Module: - return self.actor_critic_module.critic diff --git a/tianshou/highlevel/optim.py b/tianshou/highlevel/optim.py index d480978fb..3a63cd5f0 100644 --- a/tianshou/highlevel/optim.py +++ b/tianshou/highlevel/optim.py @@ -4,7 +4,13 @@ import torch from sensai.util.string import ToStringMixin -from torch.optim import Adam, RMSprop + +from tianshou.policy.optim import ( + AdamOptimizerFactory, + OptimizerFactory, + RMSpropOptimizerFactory, + TorchOptimizerFactory, +) TParams: TypeAlias = Iterable[torch.Tensor] | Iterable[dict[str, Any]] @@ -14,20 +20,17 @@ def __call__(self, parameters: Any, lr: float, **kwargs: Any) -> torch.optim.Opt pass -class OptimizerFactory(ABC, ToStringMixin): - def create_optimizer( - self, - module: torch.nn.Module, - lr: float, - ) -> torch.optim.Optimizer: - return self.create_optimizer_for_params(module.parameters(), lr) +class OptimizerFactoryFactory(ABC, ToStringMixin): + @staticmethod + def default() -> "OptimizerFactoryFactory": + return OptimizerFactoryFactoryAdam() @abstractmethod - def create_optimizer_for_params(self, params: TParams, lr: float) -> torch.optim.Optimizer: + def create_optimizer_factory(self, lr: float) -> OptimizerFactory: pass -class OptimizerFactoryTorch(OptimizerFactory): +class OptimizerFactoryFactoryTorch(OptimizerFactoryFactory): def __init__(self, optim_class: OptimizerWithLearningRateProtocol, **kwargs: Any): """Factory for torch optimizers. 
@@ -39,11 +42,11 @@ def __init__(self, optim_class: OptimizerWithLearningRateProtocol, **kwargs: Any self.optim_class = optim_class self.kwargs = kwargs - def create_optimizer_for_params(self, params: TParams, lr: float) -> torch.optim.Optimizer: - return self.optim_class(params, lr=lr, **self.kwargs) + def create_optimizer_factory(self, lr: float) -> OptimizerFactory: + return TorchOptimizerFactory(optim_class=self.optim_class, lr=lr) -class OptimizerFactoryAdam(OptimizerFactory): +class OptimizerFactoryFactoryAdam(OptimizerFactoryFactory): # Note: currently used as default optimizer # values should be kept in sync with `ExperimentBuilder.with_optim_factory_default` def __init__( @@ -56,9 +59,8 @@ def __init__( self.eps = eps self.betas = betas - def create_optimizer_for_params(self, params: TParams, lr: float) -> torch.optim.Optimizer: - return Adam( - params, + def create_optimizer_factory(self, lr: float) -> AdamOptimizerFactory: + return AdamOptimizerFactory( lr=lr, betas=self.betas, eps=self.eps, @@ -66,7 +68,7 @@ def create_optimizer_for_params(self, params: TParams, lr: float) -> torch.optim ) -class OptimizerFactoryRMSprop(OptimizerFactory): +class OptimizerFactoryFactoryRMSprop(OptimizerFactoryFactory): def __init__( self, alpha: float = 0.99, @@ -81,9 +83,8 @@ def __init__( self.weight_decay = weight_decay self.eps = eps - def create_optimizer_for_params(self, params: TParams, lr: float) -> torch.optim.Optimizer: - return RMSprop( - params, + def create_optimizer_factory(self, lr: float) -> RMSpropOptimizerFactory: + return RMSpropOptimizerFactory( lr=lr, alpha=self.alpha, eps=self.eps, diff --git a/tianshou/highlevel/params/alpha.py b/tianshou/highlevel/params/alpha.py index 5f42d10e1..fc23baeb4 100644 --- a/tianshou/highlevel/params/alpha.py +++ b/tianshou/highlevel/params/alpha.py @@ -3,10 +3,11 @@ import numpy as np import torch from sensai.util.string import ToStringMixin +from torch.nn import ParameterList from tianshou.highlevel.env import 
Environments from tianshou.highlevel.module.core import TDevice -from tianshou.highlevel.optim import OptimizerFactory +from tianshou.highlevel.optim import OptimizerFactoryFactory from tianshou.policy.modelfree.sac import AutoAlpha @@ -15,14 +16,18 @@ class AutoAlphaFactory(ToStringMixin, ABC): def create_auto_alpha( self, envs: Environments, - optim_factory: OptimizerFactory, device: TDevice, ) -> AutoAlpha: pass class AutoAlphaFactoryDefault(AutoAlphaFactory): - def __init__(self, lr: float = 3e-4, target_entropy_coefficient: float = -1.0): + def __init__( + self, + lr: float = 3e-4, + target_entropy_coefficient: float = -1.0, + optimizer: OptimizerFactoryFactory | None = None, + ): """ :param lr: the learning rate for the optimizer of the alpha parameter :param target_entropy_coefficient: the coefficient with which to multiply the target entropy; @@ -31,14 +36,15 @@ def __init__(self, lr: float = 3e-4, target_entropy_coefficient: float = -1.0): spaces respectively, which gives a reasonable trade-off between exploration and exploitation. For decidedly stochastic exploration, you can use a positive value closer to 1 (e.g. 0.98); 1.0 would give full entropy exploration. 
+ :param optimizer: the optimizer factory to use; if None, use default """ self.lr = lr self.target_entropy_coefficient = target_entropy_coefficient + self.optimizer_factory_factory = optimizer or OptimizerFactoryFactory.default() def create_auto_alpha( self, envs: Environments, - optim_factory: OptimizerFactory, device: TDevice, ) -> AutoAlpha: action_dim = np.prod(envs.get_action_shape()) @@ -47,5 +53,8 @@ def create_auto_alpha( else: target_entropy = self.target_entropy_coefficient * np.log(action_dim) log_alpha = torch.zeros(1, requires_grad=True, device=device) - alpha_optim = optim_factory.create_optimizer_for_params([log_alpha], self.lr) - return AutoAlpha(target_entropy, log_alpha, alpha_optim) + optim_factory = self.optimizer_factory_factory.create_optimizer_factory(lr=self.lr) + optim, lr_scheduler = optim_factory.create_instances(ParameterList([log_alpha])) + if lr_scheduler is not None: + raise ValueError("Learning rate schedulers are not supported for AutoAlpha") + return AutoAlpha(target_entropy, log_alpha, optim) diff --git a/tianshou/highlevel/params/lr_scheduler.py b/tianshou/highlevel/params/lr_scheduler.py index 09c4c4261..bfd9cd76b 100644 --- a/tianshou/highlevel/params/lr_scheduler.py +++ b/tianshou/highlevel/params/lr_scheduler.py @@ -1,35 +1,26 @@ from abc import ABC, abstractmethod -import numpy as np -import torch from sensai.util.string import ToStringMixin -from torch.optim.lr_scheduler import LambdaLR, LRScheduler from tianshou.highlevel.config import SamplingConfig +from tianshou.policy.optim import LRSchedulerFactory, LRSchedulerFactoryLinear -class LRSchedulerFactory(ToStringMixin, ABC): - """Factory for the creation of a learning rate scheduler.""" +class LRSchedulerFactoryFactory(ToStringMixin, ABC): + """Factory for the creation of a learning rate scheduler factory.""" @abstractmethod - def create_scheduler(self, optim: torch.optim.Optimizer) -> LRScheduler: + def create_lr_scheduler_factory(self) -> LRSchedulerFactory: pass -class 
LRSchedulerFactoryLinear(LRSchedulerFactory): +class LRSchedulerFactoryFactoryLinear(LRSchedulerFactoryFactory): def __init__(self, sampling_config: SamplingConfig): self.sampling_config = sampling_config - def create_scheduler(self, optim: torch.optim.Optimizer) -> LRScheduler: - return LambdaLR(optim, lr_lambda=self._LRLambda(self.sampling_config).compute) - - class _LRLambda: - def __init__(self, sampling_config: SamplingConfig): - assert sampling_config.step_per_collect is not None - self.max_update_num = ( - np.ceil(sampling_config.step_per_epoch / sampling_config.step_per_collect) - * sampling_config.num_epochs - ) - - def compute(self, epoch: int) -> float: - return 1.0 - epoch / self.max_update_num + def create_lr_scheduler_factory(self) -> LRSchedulerFactory: + return LRSchedulerFactoryLinear( + num_epochs=self.sampling_config.num_epochs, + step_per_epoch=self.sampling_config.step_per_epoch, + step_per_collect=self.sampling_config.step_per_collect, + ) diff --git a/tianshou/highlevel/params/policy_params.py b/tianshou/highlevel/params/policy_params.py index d20bbe44b..0769537c7 100644 --- a/tianshou/highlevel/params/policy_params.py +++ b/tianshou/highlevel/params/policy_params.py @@ -1,23 +1,19 @@ from abc import ABC, abstractmethod -from collections.abc import Sequence +from collections.abc import Callable, Sequence from dataclasses import asdict, dataclass from typing import Any, Literal, Protocol -import torch from sensai.util.pickle import setstate from sensai.util.string import ToStringMixin -from torch.optim.lr_scheduler import LRScheduler from tianshou.exploration import BaseNoise from tianshou.highlevel.env import Environments from tianshou.highlevel.module.core import TDevice -from tianshou.highlevel.module.module_opt import ModuleOpt -from tianshou.highlevel.optim import OptimizerFactory +from tianshou.highlevel.optim import OptimizerFactoryFactory from tianshou.highlevel.params.alpha import AutoAlphaFactory from 
tianshou.highlevel.params.env_param import EnvValueFactory, FloatEnvValueFactory -from tianshou.highlevel.params.lr_scheduler import LRSchedulerFactory +from tianshou.highlevel.params.lr_scheduler import LRSchedulerFactoryFactory from tianshou.highlevel.params.noise import NoiseFactory -from tianshou.utils import MultipleLRSchedulers @dataclass(kw_only=True) @@ -30,12 +26,7 @@ class ParamTransformerData: envs: Environments device: TDevice - optim_factory: OptimizerFactory - optim: torch.optim.Optimizer | None = None - """the single optimizer for the case where there is just one""" - actor: ModuleOpt | None = None - critic1: ModuleOpt | None = None - critic2: ModuleOpt | None = None + optim_factory_default: OptimizerFactoryFactory class ParamTransformer(ABC): @@ -52,8 +43,18 @@ def transform(self, params: dict[str, Any], data: ParamTransformerData) -> None: pass @staticmethod - def get(d: dict[str, Any], key: str, drop: bool = False) -> Any: - value = d[key] + def get( + d: dict[str, Any], + key: str, + drop: bool = False, + default_factory: Callable[[], Any] | None = None, + ) -> Any: + try: + value = d[key] + except KeyError as e: + raise Exception(f"Key not found: '{key}'; available keys: {list(d.keys())}") from e + if value is None and default_factory is not None: + value = default_factory() if drop: del d[key] return value @@ -68,6 +69,17 @@ def transform(self, kwargs: dict[str, Any], data: ParamTransformerData) -> None: del kwargs[k] +class ParamTransformerRename(ParamTransformer): + def __init__(self, renamed_params: dict[str, str]): + self.renamed_params = renamed_params + + def transform(self, kwargs: dict[str, Any], data: ParamTransformerData) -> None: + for old_name, new_name in self.renamed_params.items(): + v = kwargs[old_name] + del kwargs[old_name] + kwargs[new_name] = v + + class ParamTransformerChangeValue(ParamTransformer): def __init__(self, key: str): self.key = key @@ -80,105 +92,42 @@ def change_value(self, value: Any, data: 
ParamTransformerData) -> Any: pass -class ParamTransformerLRScheduler(ParamTransformer): +class ParamTransformerOptimFactory(ParamTransformer): """Transformer for learning rate scheduler params. Transforms a key containing a learning rate scheduler factory (removed) into a key containing a learning rate scheduler (added) for the data member `optim`. """ - def __init__(self, key_scheduler_factory: str, key_scheduler: str): - self.key_scheduler_factory = key_scheduler_factory - self.key_scheduler = key_scheduler - - def transform(self, params: dict[str, Any], data: ParamTransformerData) -> None: - assert data.optim is not None - factory: LRSchedulerFactory | None = self.get(params, self.key_scheduler_factory, drop=True) - params[self.key_scheduler] = ( - factory.create_scheduler(data.optim) if factory is not None else None - ) - - -class ParamTransformerMultiLRScheduler(ParamTransformer): - def __init__(self, optim_key_list: list[tuple[torch.optim.Optimizer, str]], key_scheduler: str): - """Transforms several scheduler factories into a single scheduler. - - The result may be a `MultipleLRSchedulers` instance if more than one factory is indeed given. 
- - :param optim_key_list: a list of tuples (optimizer, key of learning rate factory) - :param key_scheduler: the key under which to store the resulting learning rate scheduler - """ - self.optim_key_list = optim_key_list - self.key_scheduler = key_scheduler - - def transform(self, params: dict[str, Any], data: ParamTransformerData) -> None: - lr_schedulers = [] - for optim, lr_scheduler_factory_key in self.optim_key_list: - lr_scheduler_factory: LRSchedulerFactory | None = self.get( - params, - lr_scheduler_factory_key, - drop=True, - ) - if lr_scheduler_factory is not None: - lr_schedulers.append(lr_scheduler_factory.create_scheduler(optim)) - lr_scheduler: LRScheduler | MultipleLRSchedulers | None - match len(lr_schedulers): - case 0: - lr_scheduler = None - case 1: - lr_scheduler = lr_schedulers[0] - case _: - lr_scheduler = MultipleLRSchedulers(*lr_schedulers) - params[self.key_scheduler] = lr_scheduler - - -class ParamTransformerActorAndCriticLRScheduler(ParamTransformer): def __init__( self, - key_scheduler_factory_actor: str, - key_scheduler_factory_critic: str, - key_scheduler: str, + key_optim_factory_factory, + key_lr: str, + key_lr_scheduler_factory_factory: str, + key_optim_output: str, ): - self.key_factory_actor = key_scheduler_factory_actor - self.key_factory_critic = key_scheduler_factory_critic - self.key_scheduler = key_scheduler + self.key_optim_factory_factory = key_optim_factory_factory + self.key_lr = key_lr + self.key_scheduler_factory = key_lr_scheduler_factory_factory + self.key_optim_output = key_optim_output def transform(self, params: dict[str, Any], data: ParamTransformerData) -> None: - assert data.actor is not None and data.critic1 is not None - transformer = ParamTransformerMultiLRScheduler( - [ - (data.actor.optim, self.key_factory_actor), - (data.critic1.optim, self.key_factory_critic), - ], - self.key_scheduler, + optim_factory_factory: OptimizerFactoryFactory = self.get( + params, + self.key_optim_factory_factory, + drop=True, + 
default_factory=lambda: data.optim_factory_default, ) - transformer.transform(params, data) - - -class ParamTransformerActorDualCriticsLRScheduler(ParamTransformer): - def __init__( - self, - key_scheduler_factory_actor: str, - key_scheduler_factory_critic1: str, - key_scheduler_factory_critic2: str, - key_scheduler: str, - ): - self.key_factory_actor = key_scheduler_factory_actor - self.key_factory_critic1 = key_scheduler_factory_critic1 - self.key_factory_critic2 = key_scheduler_factory_critic2 - self.key_scheduler = key_scheduler - - def transform(self, params: dict[str, Any], data: ParamTransformerData) -> None: - assert data.actor is not None and data.critic1 is not None and data.critic2 is not None - transformer = ParamTransformerMultiLRScheduler( - [ - (data.actor.optim, self.key_factory_actor), - (data.critic1.optim, self.key_factory_critic1), - (data.critic2.optim, self.key_factory_critic2), - ], - self.key_scheduler, + lr_scheduler_factory_factory: LRSchedulerFactoryFactory | None = self.get( + params, self.key_scheduler_factory, drop=True ) - transformer.transform(params, data) + lr: float = self.get(params, self.key_lr, drop=True) + optim_factory = optim_factory_factory.create_optimizer_factory(lr) + if lr_scheduler_factory_factory is not None: + optim_factory.with_lr_scheduler_factory( + lr_scheduler_factory_factory.create_lr_scheduler_factory() + ) + params[self.key_optim_output] = optim_factory class ParamTransformerAutoAlpha(ParamTransformer): @@ -188,7 +137,7 @@ def __init__(self, key: str): def transform(self, kwargs: dict[str, Any], data: ParamTransformerData) -> None: alpha = self.get(kwargs, self.key) if isinstance(alpha, AutoAlphaFactory): - kwargs[self.key] = alpha.create_auto_alpha(data.envs, data.optim_factory, data.device) + kwargs[self.key] = alpha.create_auto_alpha(data.envs, data.device) class ParamTransformerNoiseFactory(ParamTransformerChangeValue): @@ -218,7 +167,7 @@ def _get_param_transformers(self) -> list[ParamTransformer]: pass 
-@dataclass +@dataclass(kw_only=True) class Params(GetParamTransformersProtocol, ToStringMixin): def create_kwargs(self, data: ParamTransformerData) -> dict[str, Any]: params = asdict(self) @@ -230,43 +179,48 @@ def _get_param_transformers(self) -> list[ParamTransformer]: return [] -@dataclass -class ParamsMixinLearningRateWithScheduler(GetParamTransformersProtocol): +@dataclass(kw_only=True) +class ParamsMixinSingleModel(GetParamTransformersProtocol): + optim: OptimizerFactoryFactory | None = None + """the factory for the creation of the model's optimizer; if None, use default""" lr: float = 1e-3 """the learning rate to use in the gradient-based optimizer""" - lr_scheduler_factory: LRSchedulerFactory | None = None + lr_scheduler: LRSchedulerFactoryFactory | None = None """factory for the creation of a learning rate scheduler""" def _get_param_transformers(self) -> list[ParamTransformer]: return [ - ParamTransformerDrop("lr"), - ParamTransformerLRScheduler("lr_scheduler_factory", "lr_scheduler"), + ParamTransformerOptimFactory("optim", "lr", "lr_scheduler", "optim"), ] -@dataclass +@dataclass(kw_only=True) class ParamsMixinActorAndCritic(GetParamTransformersProtocol): + actor_optim: OptimizerFactoryFactory | None = None + """the factory for the creation of the actor's optimizer; if None, use default""" + critic_optim: OptimizerFactoryFactory | None = None + """the factory for the creation of the critic's optimizer; if None, use default""" actor_lr: float = 1e-3 """the learning rate to use for the actor network""" critic_lr: float = 1e-3 """the learning rate to use for the critic network""" - actor_lr_scheduler_factory: LRSchedulerFactory | None = None + actor_lr_scheduler: LRSchedulerFactoryFactory | None = None """factory for the creation of a learning rate scheduler to use for the actor network (if any)""" - critic_lr_scheduler_factory: LRSchedulerFactory | None = None + critic_lr_scheduler: LRSchedulerFactoryFactory | None = None """factory for the creation of a 
learning rate scheduler to use for the critic network (if any)""" def _get_param_transformers(self) -> list[ParamTransformer]: return [ - ParamTransformerDrop("actor_lr", "critic_lr"), - ParamTransformerActorAndCriticLRScheduler( - "actor_lr_scheduler_factory", - "critic_lr_scheduler_factory", - "lr_scheduler", + ParamTransformerOptimFactory( + "actor_optim", "actor_lr", "actor_lr_scheduler", "policy_optim" + ), + ParamTransformerOptimFactory( + "critic_optim", "critic_lr", "critic_lr_scheduler", "critic_optim" ), ] -@dataclass +@dataclass(kw_only=True) class ParamsMixinActionScaling(GetParamTransformersProtocol): action_scaling: bool | Literal["default"] = "default" """whether to apply action scaling; when set to "default", it will be enabled for continuous action spaces""" @@ -279,7 +233,7 @@ def _get_param_transformers(self) -> list[ParamTransformer]: return [ParamTransformerActionScaling("action_scaling")] -@dataclass +@dataclass(kw_only=True) class ParamsMixinExplorationNoise(GetParamTransformersProtocol): exploration_noise: BaseNoise | Literal["default"] | NoiseFactory | None = None """ @@ -293,8 +247,8 @@ def _get_param_transformers(self) -> list[ParamTransformer]: return [ParamTransformerNoiseFactory("exploration_noise")] -@dataclass -class PGParams(Params, ParamsMixinActionScaling, ParamsMixinLearningRateWithScheduler): +@dataclass(kw_only=True) +class PGParams(Params, ParamsMixinActionScaling, ParamsMixinSingleModel): discount_factor: float = 0.99 """ discount factor (gamma) for future rewards; must be in [0, 1] @@ -316,11 +270,11 @@ def __setstate__(self, state: dict[str, Any]) -> None: def _get_param_transformers(self) -> list[ParamTransformer]: transformers = super()._get_param_transformers() transformers.extend(ParamsMixinActionScaling._get_param_transformers(self)) - transformers.extend(ParamsMixinLearningRateWithScheduler._get_param_transformers(self)) + transformers.extend(ParamsMixinSingleModel._get_param_transformers(self)) return transformers 
-@dataclass +@dataclass(kw_only=True) class ParamsMixinGeneralAdvantageEstimation(GetParamTransformersProtocol): gae_lambda: float = 0.95 """ @@ -335,7 +289,7 @@ def _get_param_transformers(self) -> list[ParamTransformer]: return [] -@dataclass +@dataclass(kw_only=True) class A2CParams(PGParams, ParamsMixinGeneralAdvantageEstimation): vf_coef: float = 0.5 """weight (coefficient) of the value loss in the loss function""" @@ -350,7 +304,7 @@ def _get_param_transformers(self) -> list[ParamTransformer]: return transformers -@dataclass +@dataclass(kw_only=True) class PPOParams(A2CParams): eps_clip: float = 0.2 """ @@ -397,7 +351,7 @@ class PPOParams(A2CParams): """ -@dataclass +@dataclass(kw_only=True) class NPGParams(PGParams, ParamsMixinGeneralAdvantageEstimation): optim_critic_iters: int = 5 """number of times to optimize critic network per update.""" @@ -412,7 +366,7 @@ def _get_param_transformers(self) -> list[ParamTransformer]: return transformers -@dataclass +@dataclass(kw_only=True) class TRPOParams(NPGParams): max_kl: float = 0.01 """ @@ -426,34 +380,42 @@ class TRPOParams(NPGParams): """maximum number of times to backtrack in line search when the constraints are not met.""" -@dataclass +@dataclass(kw_only=True) class ParamsMixinActorAndDualCritics(GetParamTransformersProtocol): + actor_optim: OptimizerFactoryFactory | None = None + """the factory for the creation of the actor's optimizer; if None, use default""" + critic1_optim: OptimizerFactoryFactory | None = None + """the factory for the creation of the first critic's optimizer; if None, use default""" + critic2_optim: OptimizerFactoryFactory | None = None + """the factory for the creation of the second critic's optimizer; if None, use default""" actor_lr: float = 1e-3 """the learning rate to use for the actor network""" critic1_lr: float = 1e-3 """the learning rate to use for the first critic network""" critic2_lr: float = 1e-3 """the learning rate to use for the second critic network""" - 
actor_lr_scheduler_factory: LRSchedulerFactory | None = None + actor_lr_scheduler: LRSchedulerFactoryFactory | None = None """factory for the creation of a learning rate scheduler to use for the actor network (if any)""" - critic1_lr_scheduler_factory: LRSchedulerFactory | None = None + critic1_lr_scheduler: LRSchedulerFactoryFactory | None = None """factory for the creation of a learning rate scheduler to use for the first critic network (if any)""" - critic2_lr_scheduler_factory: LRSchedulerFactory | None = None + critic2_lr_scheduler: LRSchedulerFactoryFactory | None = None """factory for the creation of a learning rate scheduler to use for the second critic network (if any)""" def _get_param_transformers(self) -> list[ParamTransformer]: return [ - ParamTransformerDrop("actor_lr", "critic1_lr", "critic2_lr"), - ParamTransformerActorDualCriticsLRScheduler( - "actor_lr_scheduler_factory", - "critic1_lr_scheduler_factory", - "critic2_lr_scheduler_factory", - "lr_scheduler", + ParamTransformerOptimFactory( + "actor_optim", "actor_lr", "actor_lr_scheduler", "policy_optim" + ), + ParamTransformerOptimFactory( + "critic1_optim", "critic1_lr", "critic1_lr_scheduler", "critic_optim" + ), + ParamTransformerOptimFactory( + "critic2_optim", "critic2_lr", "critic2_lr_scheduler", "critic2_optim" ), ] -@dataclass +@dataclass(kw_only=True) class _SACParams(Params, ParamsMixinActorAndDualCritics): tau: float = 0.005 """controls the contribution of the entropy term in the overall optimization objective, @@ -481,12 +443,12 @@ def _get_param_transformers(self) -> list[ParamTransformer]: return transformers -@dataclass +@dataclass(kw_only=True) class SACParams(_SACParams, ParamsMixinExplorationNoise, ParamsMixinActionScaling): deterministic_eval: bool = True """ - whether to use deterministic action (mean of Gaussian policy) in evaluation mode instead of stochastic - action sampled by the policy. 
Does not affect training.""" + whether to use deterministic action (mode of Gaussian policy) in evaluation mode instead of stochastic + action sampled from the distribution. Does not affect training.""" def _get_param_transformers(self) -> list[ParamTransformer]: transformers = super()._get_param_transformers() @@ -495,13 +457,16 @@ def _get_param_transformers(self) -> list[ParamTransformer]: return transformers -@dataclass +@dataclass(kw_only=True) class DiscreteSACParams(_SACParams): - pass + deterministic_eval: bool = True + """ + whether to use deterministic action (most probably action) in evaluation mode instead of stochastic + action sampled from the distribution. Does not affect training.""" -@dataclass -class DQNParams(Params, ParamsMixinLearningRateWithScheduler): +@dataclass(kw_only=True) +class QLearningOffPolicyParams(Params, ParamsMixinSingleModel): discount_factor: float = 0.99 """ discount factor (gamma) for future rewards; must be in [0, 1] @@ -512,6 +477,15 @@ class DQNParams(Params, ParamsMixinLearningRateWithScheduler): """the target network update frequency (0 if no target network is to be used)""" reward_normalization: bool = False """whether to normalize the returns to Normal(0, 1)""" + + def _get_param_transformers(self) -> list[ParamTransformer]: + transformers = super()._get_param_transformers() + transformers.extend(ParamsMixinSingleModel._get_param_transformers(self)) + return transformers + + +@dataclass(kw_only=True) +class DQNParams(QLearningOffPolicyParams): is_double: bool = True """whether to use double Q learning""" clip_loss_grad: bool = False @@ -519,13 +493,11 @@ class DQNParams(Params, ParamsMixinLearningRateWithScheduler): loss instead of the MSE loss.""" def _get_param_transformers(self) -> list[ParamTransformer]: - transformers = super()._get_param_transformers() - transformers.extend(ParamsMixinLearningRateWithScheduler._get_param_transformers(self)) - return transformers + return super()._get_param_transformers() 
-@dataclass -class IQNParams(DQNParams): +@dataclass(kw_only=True) +class IQNParams(QLearningOffPolicyParams): sample_size: int = 32 """the number of samples for policy evaluation""" online_sample_size: int = 8 @@ -545,7 +517,7 @@ def _get_param_transformers(self) -> list[ParamTransformer]: return transformers -@dataclass +@dataclass(kw_only=True) class DDPGParams( Params, ParamsMixinActorAndCritic, @@ -571,7 +543,7 @@ def _get_param_transformers(self) -> list[ParamTransformer]: return transformers -@dataclass +@dataclass(kw_only=True) class REDQParams(DDPGParams): ensemble_size: int = 10 """the number of sub-networks in the critic ensemble""" @@ -602,7 +574,7 @@ def _get_param_transformers(self) -> list[ParamTransformer]: return transformers -@dataclass +@dataclass(kw_only=True) class TD3Params( Params, ParamsMixinActorAndDualCritics, diff --git a/tianshou/highlevel/params/policy_wrapper.py b/tianshou/highlevel/params/policy_wrapper.py index 51a4438fc..2d808eccd 100644 --- a/tianshou/highlevel/params/policy_wrapper.py +++ b/tianshou/highlevel/params/policy_wrapper.py @@ -7,7 +7,7 @@ from tianshou.highlevel.env import Environments from tianshou.highlevel.module.core import TDevice from tianshou.highlevel.module.intermediate import IntermediateModuleFactory -from tianshou.highlevel.optim import OptimizerFactory +from tianshou.highlevel.optim import OptimizerFactoryFactory from tianshou.policy import Algorithm, ICMOffPolicyWrapper from tianshou.policy.base import OffPolicyAlgorithm, OnPolicyAlgorithm from tianshou.policy.modelbased.icm import ICMOnPolicyWrapper @@ -22,7 +22,7 @@ def create_wrapped_algorithm( self, policy: Algorithm, envs: Environments, - optim_factory: OptimizerFactory, + optim_factory: OptimizerFactoryFactory, device: TDevice, ) -> TAlgorithmOut: pass @@ -40,6 +40,7 @@ def __init__( lr_scale: float, reward_scale: float, forward_loss_weight: float, + optim: OptimizerFactoryFactory | None = None, ): self.feature_net_factory = feature_net_factory 
self.hidden_sizes = hidden_sizes @@ -47,12 +48,13 @@ def __init__( self.lr_scale = lr_scale self.reward_scale = reward_scale self.forward_loss_weight = forward_loss_weight + self.optim_factory = optim def create_wrapped_algorithm( self, algorithm: Algorithm, envs: Environments, - optim_factory: OptimizerFactory, + optim_factory_default: OptimizerFactoryFactory, device: TDevice, ) -> ICMOffPolicyWrapper: feature_net = self.feature_net_factory.create_intermediate_module(envs, device) @@ -67,7 +69,8 @@ def create_wrapped_algorithm( hidden_sizes=self.hidden_sizes, device=device, ) - icm_optim = optim_factory.create_optimizer(icm_net, lr=self.lr) + optim_factory = self.optim_factory or optim_factory_default + icm_optim = optim_factory.create_optimizer_factory(lr=self.lr) cls: type[ICMOffPolicyWrapper] | type[ICMOnPolicyWrapper] if isinstance(algorithm, OffPolicyAlgorithm): cls = ICMOffPolicyWrapper diff --git a/tianshou/highlevel/trainer.py b/tianshou/highlevel/trainer.py index 4e1397e25..93dbe5c0b 100644 --- a/tianshou/highlevel/trainer.py +++ b/tianshou/highlevel/trainer.py @@ -10,13 +10,13 @@ from tianshou.highlevel.logger import TLogger from tianshou.policy import DQN, Algorithm -TPolicy = TypeVar("TPolicy", bound=Algorithm) +TAlgorithm = TypeVar("TAlgorithm", bound=Algorithm) log = logging.getLogger(__name__) class TrainingContext: - def __init__(self, policy: TPolicy, envs: Environments, logger: TLogger): - self.policy = policy + def __init__(self, algorithm: TAlgorithm, envs: Environments, logger: TLogger): + self.algorithm = algorithm self.envs = envs self.logger = logger @@ -90,8 +90,8 @@ def __init__(self, eps_test: float): self.eps_test = eps_test def callback(self, epoch: int, env_step: int, context: TrainingContext) -> None: - policy = cast(DQN, context.policy) - policy.set_eps(self.eps_test) + algorithm = cast(DQN, context.algorithm) + algorithm.policy.set_eps(self.eps_test) class EpochTrainCallbackDQNEpsLinearDecay(EpochTrainCallback): @@ -105,7 +105,7 @@ 
def __init__(self, eps_train: float, eps_train_final: float, decay_steps: int = self.decay_steps = decay_steps def callback(self, epoch: int, env_step: int, context: TrainingContext) -> None: - policy = cast(DQN, context.policy) + algorithm = cast(DQN, context.algorithm) logger = context.logger if env_step <= self.decay_steps: eps = self.eps_train - env_step / self.decay_steps * ( @@ -113,7 +113,7 @@ def callback(self, epoch: int, env_step: int, context: TrainingContext) -> None: ) else: eps = self.eps_train_final - policy.set_eps(eps) + algorithm.policy.set_eps(eps) logger.write("train/env_step", env_step, {"train/eps": eps}) @@ -126,8 +126,8 @@ def __init__(self, eps_test: float): self.eps_test = eps_test def callback(self, epoch: int, env_step: int | None, context: TrainingContext) -> None: - policy = cast(DQN, context.policy) - policy.set_eps(self.eps_test) + algorithm = cast(DQN, context.algorithm) + algorithm.policy.set_eps(self.eps_test) class EpochStopCallbackRewardThreshold(EpochStopCallback): diff --git a/tianshou/policy/optim.py b/tianshou/policy/optim.py index ab1ae67ac..a03c95871 100644 --- a/tianshou/policy/optim.py +++ b/tianshou/policy/optim.py @@ -109,7 +109,7 @@ def _create_optimizer_for_params(self, params: ParamsType) -> torch.optim.Optimi ) -class RMSPropOptimizerFactory(OptimizerFactory): +class RMSpropOptimizerFactory(OptimizerFactory): def __init__( self, lr: float = 1e-2, From cd777248ab72cbdc5fdf5e2ac047477a67cf2b3c Mon Sep 17 00:00:00 2001 From: Dominik Jain Date: Mon, 17 Mar 2025 12:44:00 +0100 Subject: [PATCH 063/230] v2: Adapt MuJoCo examples --- examples/mujoco/fetch_her_ddpg.py | 59 ++++++++++++----------- examples/mujoco/mujoco_a2c.py | 75 +++++++++++++++-------------- examples/mujoco/mujoco_ddpg.py | 58 ++++++++++++---------- examples/mujoco/mujoco_npg.py | 74 +++++++++++++++------------- examples/mujoco/mujoco_ppo.py | 74 +++++++++++++++------------- examples/mujoco/mujoco_redq.py | 58 ++++++++++++---------- 
examples/mujoco/mujoco_reinforce.py | 72 ++++++++++++++------------- examples/mujoco/mujoco_sac.py | 58 ++++++++++++---------- examples/mujoco/mujoco_td3.py | 60 ++++++++++++----------- examples/mujoco/mujoco_trpo.py | 68 ++++++++++++++------------ tianshou/policy/modelfree/redq.py | 2 +- 11 files changed, 358 insertions(+), 300 deletions(-) diff --git a/examples/mujoco/fetch_her_ddpg.py b/examples/mujoco/fetch_her_ddpg.py index 3f2c24e32..887adccec 100644 --- a/examples/mujoco/fetch_her_ddpg.py +++ b/examples/mujoco/fetch_her_ddpg.py @@ -10,7 +10,6 @@ import numpy as np import torch - from tianshou.data import ( Collector, CollectStats, @@ -19,15 +18,17 @@ ReplayBuffer, VectorReplayBuffer, ) -from tianshou.highlevel.logger import LoggerFactoryDefault from tianshou.env import ShmemVectorEnv, TruncatedAsTerminated +from tianshou.env.venvs import BaseVectorEnv from tianshou.exploration import GaussianNoise +from tianshou.highlevel.logger import LoggerFactoryDefault from tianshou.policy import DDPG from tianshou.policy.base import Algorithm -from tianshou.trainer import OffPolicyTrainer +from tianshou.policy.modelfree.ddpg import DDPGPolicy +from tianshou.policy.optim import AdamOptimizerFactory +from tianshou.trainer import OffPolicyTrainingConfig from tianshou.utils.net.common import Net, get_dict_state_decorator from tianshou.utils.net.continuous import Actor, Critic -from tianshou.env.venvs import BaseVectorEnv from tianshou.utils.space_info import ActionSpaceInfo @@ -159,7 +160,7 @@ def test_ddpg(args: argparse.Namespace = get_args()) -> None: max_action=args.max_action, device=args.device, ).to(args.device) - actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr) + actor_optim = AdamOptimizerFactory(lr=args.actor_lr) net_c = dict_state_dec(Net)( flat_state_shape, action_shape=args.action_shape, @@ -168,22 +169,25 @@ def test_ddpg(args: argparse.Namespace = get_args()) -> None: device=args.device, ) critic = dict_state_dec(Critic)(net_c, 
device=args.device).to(args.device) - critic_optim = torch.optim.Adam(critic.parameters(), lr=args.critic_lr) - policy: DDPG = DDPG( + critic_optim = AdamOptimizerFactory(lr=args.critic_lr) + policy = DDPGPolicy( actor=actor, + exploration_noise=GaussianNoise(sigma=args.exploration_noise), + action_space=env.action_space, + ) + algorithm = DDPG( + policy=policy, policy_optim=actor_optim, critic=critic, critic_optim=critic_optim, tau=args.tau, gamma=args.gamma, - exploration_noise=GaussianNoise(sigma=args.exploration_noise), estimation_step=args.n_step, - action_space=env.action_space, ) # load a previous policy if args.resume_path: - policy.load_state_dict(torch.load(args.resume_path, map_location=args.device)) + algorithm.load_state_dict(torch.load(args.resume_path, map_location=args.device)) print("Loaded agent from: ", args.resume_path) # collector @@ -212,8 +216,8 @@ def compute_reward_fn(ag: np.ndarray, g: np.ndarray) -> np.ndarray: horizon=args.her_horizon, future_k=args.her_future_k, ) - train_collector = Collector[CollectStats](policy, train_envs, buffer, exploration_noise=True) - test_collector = Collector[CollectStats](policy, test_envs) + train_collector = Collector[CollectStats](algorithm, train_envs, buffer, exploration_noise=True) + test_collector = Collector[CollectStats](algorithm, test_envs) train_collector.reset() train_collector.collect(n_step=args.start_timesteps, random=True) @@ -221,21 +225,22 @@ def save_best_fn(policy: Algorithm) -> None: torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth")) if not args.watch: - # trainer - result = OffPolicyTrainer( - policy=policy, - train_collector=train_collector, - test_collector=test_collector, - max_epoch=args.epoch, - step_per_epoch=args.step_per_epoch, - step_per_collect=args.step_per_collect, - episode_per_test=args.test_num, - batch_size=args.batch_size, - save_best_fn=save_best_fn, - logger=logger, - update_per_step=args.update_per_step, - test_in_train=False, - ).run() + # train 
+ result = algorithm.run_training( + OffPolicyTrainingConfig( + train_collector=train_collector, + test_collector=test_collector, + max_epoch=args.epoch, + step_per_epoch=args.step_per_epoch, + step_per_collect=args.step_per_collect, + episode_per_test=args.test_num, + batch_size=args.batch_size, + save_best_fn=save_best_fn, + logger=logger, + update_per_step=args.update_per_step, + test_in_train=False, + ) + ) pprint.pprint(result) # Let's watch its performance! diff --git a/examples/mujoco/mujoco_a2c.py b/examples/mujoco/mujoco_a2c.py index e5a18e1ac..96f8d173f 100755 --- a/examples/mujoco/mujoco_a2c.py +++ b/examples/mujoco/mujoco_a2c.py @@ -10,13 +10,14 @@ from mujoco_env import make_mujoco_env from torch import nn from torch.distributions import Distribution, Independent, Normal -from torch.optim.lr_scheduler import LambdaLR from tianshou.data import Collector, CollectStats, ReplayBuffer, VectorReplayBuffer from tianshou.highlevel.logger import LoggerFactoryDefault from tianshou.policy import A2C from tianshou.policy.base import Algorithm -from tianshou.trainer import OnPolicyTrainer +from tianshou.policy.modelfree.pg import ActorPolicy +from tianshou.policy.optim import LRSchedulerFactoryLinear, RMSpropOptimizerFactory +from tianshou.trainer import OnPolicyTrainingConfig from tianshou.utils.net.common import ActorCritic, Net from tianshou.utils.net.continuous import ActorProb, Critic @@ -70,7 +71,7 @@ def get_args() -> argparse.Namespace: return parser.parse_args() -def test_a2c(args: argparse.Namespace = get_args()) -> None: +def main(args: argparse.Namespace = get_args()) -> None: env, train_envs, test_envs = make_mujoco_env( args.task, args.seed, @@ -123,45 +124,48 @@ def test_a2c(args: argparse.Namespace = get_args()) -> None: torch.nn.init.zeros_(m.bias) m.weight.data.copy_(0.01 * m.weight.data) - optim = torch.optim.RMSprop( - actor_critic.parameters(), + optim = RMSpropOptimizerFactory( lr=args.lr, eps=1e-5, alpha=0.99, ) - lr_scheduler = None if 
args.lr_decay: - # decay learning rate to 0 linearly - max_update_num = np.ceil(args.step_per_epoch / args.step_per_collect) * args.epoch - - lr_scheduler = LambdaLR(optim, lr_lambda=lambda epoch: 1 - epoch / max_update_num) + optim.with_lr_scheduler_factory( + LRSchedulerFactoryLinear( + num_epochs=args.epoch, + step_per_epoch=args.step_per_epoch, + step_per_collect=args.step_per_collect, + ) + ) def dist(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution: loc, scale = loc_scale return Independent(Normal(loc, scale), 1) - policy: A2C = A2C( + policy = ActorPolicy( actor=actor, + dist_fn=dist, + action_scaling=True, + action_bound_method=args.bound_action_method, + action_space=env.action_space, + ) + algorithm: A2C = A2C( + policy=policy, critic=critic, optim=optim, - dist_fn=dist, discount_factor=args.gamma, gae_lambda=args.gae_lambda, max_grad_norm=args.max_grad_norm, vf_coef=args.vf_coef, ent_coef=args.ent_coef, reward_normalization=args.rew_norm, - action_scaling=True, - action_bound_method=args.bound_action_method, - lr_scheduler=lr_scheduler, - action_space=env.action_space, ) # load a previous policy if args.resume_path: ckpt = torch.load(args.resume_path, map_location=args.device) - policy.load_state_dict(ckpt["model"]) + algorithm.load_state_dict(ckpt["model"]) train_envs.set_obs_rms(ckpt["obs_rms"]) test_envs.set_obs_rms(ckpt["obs_rms"]) print("Loaded agent from: ", args.resume_path) @@ -172,8 +176,8 @@ def dist(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution: buffer = VectorReplayBuffer(args.buffer_size, len(train_envs)) else: buffer = ReplayBuffer(args.buffer_size) - train_collector = Collector[CollectStats](policy, train_envs, buffer, exploration_noise=True) - test_collector = Collector[CollectStats](policy, test_envs) + train_collector = Collector[CollectStats](algorithm, train_envs, buffer, exploration_noise=True) + test_collector = Collector[CollectStats](algorithm, test_envs) # log now = 
datetime.datetime.now().strftime("%y%m%d-%H%M%S") @@ -201,21 +205,22 @@ def save_best_fn(policy: Algorithm) -> None: torch.save(state, os.path.join(log_path, "policy.pth")) if not args.watch: - # trainer - result = OnPolicyTrainer( - policy=policy, - train_collector=train_collector, - test_collector=test_collector, - max_epoch=args.epoch, - step_per_epoch=args.step_per_epoch, - repeat_per_collect=args.repeat_per_collect, - episode_per_test=args.test_num, - batch_size=args.batch_size, - step_per_collect=args.step_per_collect, - save_best_fn=save_best_fn, - logger=logger, - test_in_train=False, - ).run() + # train + result = algorithm.run_training( + OnPolicyTrainingConfig( + train_collector=train_collector, + test_collector=test_collector, + max_epoch=args.epoch, + step_per_epoch=args.step_per_epoch, + repeat_per_collect=args.repeat_per_collect, + episode_per_test=args.test_num, + batch_size=args.batch_size, + step_per_collect=args.step_per_collect, + save_best_fn=save_best_fn, + logger=logger, + test_in_train=False, + ) + ) pprint.pprint(result) # Let's watch its performance! 
@@ -226,4 +231,4 @@ def save_best_fn(policy: Algorithm) -> None: if __name__ == "__main__": - test_a2c() + main() diff --git a/examples/mujoco/mujoco_ddpg.py b/examples/mujoco/mujoco_ddpg.py index 869267d91..b76607e2b 100755 --- a/examples/mujoco/mujoco_ddpg.py +++ b/examples/mujoco/mujoco_ddpg.py @@ -14,7 +14,9 @@ from tianshou.highlevel.logger import LoggerFactoryDefault from tianshou.policy import DDPG from tianshou.policy.base import Algorithm -from tianshou.trainer import OffPolicyTrainer +from tianshou.policy.modelfree.ddpg import DDPGPolicy +from tianshou.policy.optim import AdamOptimizerFactory +from tianshou.trainer import OffPolicyTrainingConfig from tianshou.utils.net.common import Net from tianshou.utils.net.continuous import Actor, Critic @@ -64,7 +66,7 @@ def get_args() -> argparse.Namespace: return parser.parse_args() -def test_ddpg(args: argparse.Namespace = get_args()) -> None: +def main(args: argparse.Namespace = get_args()) -> None: env, train_envs, test_envs = make_mujoco_env( args.task, args.seed, @@ -87,7 +89,7 @@ def test_ddpg(args: argparse.Namespace = get_args()) -> None: actor = Actor(net_a, args.action_shape, max_action=args.max_action, device=args.device).to( args.device, ) - actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr) + actor_optim = AdamOptimizerFactory(lr=args.actor_lr) net_c = Net( state_shape=args.state_shape, action_shape=args.action_shape, @@ -96,22 +98,25 @@ def test_ddpg(args: argparse.Namespace = get_args()) -> None: device=args.device, ) critic = Critic(net_c, device=args.device).to(args.device) - critic_optim = torch.optim.Adam(critic.parameters(), lr=args.critic_lr) - policy: DDPG = DDPG( + critic_optim = AdamOptimizerFactory(lr=args.critic_lr) + policy = DDPGPolicy( actor=actor, + exploration_noise=GaussianNoise(sigma=args.exploration_noise), + action_space=env.action_space, + ) + algorithm: DDPG = DDPG( + policy=policy, policy_optim=actor_optim, critic=critic, critic_optim=critic_optim, 
tau=args.tau, gamma=args.gamma, - exploration_noise=GaussianNoise(sigma=args.exploration_noise), estimation_step=args.n_step, - action_space=env.action_space, ) # load a previous policy if args.resume_path: - policy.load_state_dict(torch.load(args.resume_path, map_location=args.device)) + algorithm.load_state_dict(torch.load(args.resume_path, map_location=args.device)) print("Loaded agent from: ", args.resume_path) # collector @@ -120,8 +125,8 @@ def test_ddpg(args: argparse.Namespace = get_args()) -> None: buffer = VectorReplayBuffer(args.buffer_size, len(train_envs)) else: buffer = ReplayBuffer(args.buffer_size) - train_collector = Collector[CollectStats](policy, train_envs, buffer, exploration_noise=True) - test_collector = Collector[CollectStats](policy, test_envs) + train_collector = Collector[CollectStats](algorithm, train_envs, buffer, exploration_noise=True) + test_collector = Collector[CollectStats](algorithm, test_envs) train_collector.reset() train_collector.collect(n_step=args.start_timesteps, random=True) @@ -150,21 +155,22 @@ def save_best_fn(policy: Algorithm) -> None: torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth")) if not args.watch: - # trainer - result = OffPolicyTrainer( - policy=policy, - train_collector=train_collector, - test_collector=test_collector, - max_epoch=args.epoch, - step_per_epoch=args.step_per_epoch, - step_per_collect=args.step_per_collect, - episode_per_test=args.test_num, - batch_size=args.batch_size, - save_best_fn=save_best_fn, - logger=logger, - update_per_step=args.update_per_step, - test_in_train=False, - ).run() + # train + result = algorithm.run_training( + OffPolicyTrainingConfig( + train_collector=train_collector, + test_collector=test_collector, + max_epoch=args.epoch, + step_per_epoch=args.step_per_epoch, + step_per_collect=args.step_per_collect, + episode_per_test=args.test_num, + batch_size=args.batch_size, + save_best_fn=save_best_fn, + logger=logger, + update_per_step=args.update_per_step, + 
test_in_train=False, + ) + ) pprint.pprint(result) # Let's watch its performance! @@ -175,4 +181,4 @@ def save_best_fn(policy: Algorithm) -> None: if __name__ == "__main__": - test_ddpg() + main() diff --git a/examples/mujoco/mujoco_npg.py b/examples/mujoco/mujoco_npg.py index bcb55ba59..ce556e40a 100755 --- a/examples/mujoco/mujoco_npg.py +++ b/examples/mujoco/mujoco_npg.py @@ -10,13 +10,14 @@ from mujoco_env import make_mujoco_env from torch import nn from torch.distributions import Distribution, Independent, Normal -from torch.optim.lr_scheduler import LambdaLR from tianshou.data import Collector, CollectStats, ReplayBuffer, VectorReplayBuffer from tianshou.highlevel.logger import LoggerFactoryDefault from tianshou.policy import NPG from tianshou.policy.base import Algorithm -from tianshou.trainer import OnPolicyTrainer +from tianshou.policy.modelfree.pg import ActorPolicy +from tianshou.policy.optim import AdamOptimizerFactory, LRSchedulerFactoryLinear +from tianshou.trainer import OnPolicyTrainingConfig from tianshou.utils.net.common import Net from tianshou.utils.net.continuous import ActorProb, Critic @@ -75,7 +76,7 @@ def get_args() -> argparse.Namespace: return parser.parse_args() -def test_npg(args: argparse.Namespace = get_args()) -> None: +def main(args: argparse.Namespace = get_args()) -> None: env, train_envs, test_envs = make_mujoco_env( args.task, args.seed, @@ -126,30 +127,34 @@ def test_npg(args: argparse.Namespace = get_args()) -> None: torch.nn.init.zeros_(m.bias) m.weight.data.copy_(0.01 * m.weight.data) - optim = torch.optim.Adam(critic.parameters(), lr=args.lr) - lr_scheduler = None + optim = AdamOptimizerFactory(lr=args.lr) if args.lr_decay: - # decay learning rate to 0 linearly - max_update_num = np.ceil(args.step_per_epoch / args.step_per_collect) * args.epoch - - lr_scheduler = LambdaLR(optim, lr_lambda=lambda epoch: 1 - epoch / max_update_num) + optim.with_lr_scheduler_factory( + LRSchedulerFactoryLinear( + num_epochs=args.epoch, + 
step_per_epoch=args.step_per_epoch, + step_per_collect=args.step_per_collect, + ) + ) def dist(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution: loc, scale = loc_scale return Independent(Normal(loc, scale), 1) - policy: NPG = NPG( + policy = ActorPolicy( actor=actor, + dist_fn=dist, + action_scaling=True, + action_bound_method=args.bound_action_method, + action_space=env.action_space, + ) + algorithm: NPG = NPG( + policy=policy, critic=critic, optim=optim, - dist_fn=dist, discount_factor=args.gamma, gae_lambda=args.gae_lambda, reward_normalization=args.rew_norm, - action_scaling=True, - action_bound_method=args.bound_action_method, - lr_scheduler=lr_scheduler, - action_space=env.action_space, advantage_normalization=args.norm_adv, optim_critic_iters=args.optim_critic_iters, actor_step_size=args.actor_step_size, @@ -158,7 +163,7 @@ def dist(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution: # load a previous policy if args.resume_path: ckpt = torch.load(args.resume_path, map_location=args.device) - policy.load_state_dict(ckpt["model"]) + algorithm.load_state_dict(ckpt["model"]) train_envs.set_obs_rms(ckpt["obs_rms"]) test_envs.set_obs_rms(ckpt["obs_rms"]) print("Loaded agent from: ", args.resume_path) @@ -169,8 +174,8 @@ def dist(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution: buffer = VectorReplayBuffer(args.buffer_size, len(train_envs)) else: buffer = ReplayBuffer(args.buffer_size) - train_collector = Collector[CollectStats](policy, train_envs, buffer, exploration_noise=True) - test_collector = Collector[CollectStats](policy, test_envs) + train_collector = Collector[CollectStats](algorithm, train_envs, buffer, exploration_noise=True) + test_collector = Collector[CollectStats](algorithm, test_envs) # log now = datetime.datetime.now().strftime("%y%m%d-%H%M%S") @@ -198,21 +203,22 @@ def save_best_fn(policy: Algorithm) -> None: torch.save(state, os.path.join(log_path, "policy.pth")) if not args.watch: - # trainer - result = 
OnPolicyTrainer( - policy=policy, - train_collector=train_collector, - test_collector=test_collector, - max_epoch=args.epoch, - step_per_epoch=args.step_per_epoch, - repeat_per_collect=args.repeat_per_collect, - episode_per_test=args.test_num, - batch_size=args.batch_size, - step_per_collect=args.step_per_collect, - save_best_fn=save_best_fn, - logger=logger, - test_in_train=False, - ).run() + # train + result = algorithm.run_training( + OnPolicyTrainingConfig( + train_collector=train_collector, + test_collector=test_collector, + max_epoch=args.epoch, + step_per_epoch=args.step_per_epoch, + repeat_per_collect=args.repeat_per_collect, + episode_per_test=args.test_num, + batch_size=args.batch_size, + step_per_collect=args.step_per_collect, + save_best_fn=save_best_fn, + logger=logger, + test_in_train=False, + ) + ) pprint.pprint(result) # Let's watch its performance! @@ -223,4 +229,4 @@ def save_best_fn(policy: Algorithm) -> None: if __name__ == "__main__": - test_npg() + main() diff --git a/examples/mujoco/mujoco_ppo.py b/examples/mujoco/mujoco_ppo.py index 08a802470..13a333d9d 100755 --- a/examples/mujoco/mujoco_ppo.py +++ b/examples/mujoco/mujoco_ppo.py @@ -10,13 +10,14 @@ from mujoco_env import make_mujoco_env from torch import nn from torch.distributions import Distribution, Independent, Normal -from torch.optim.lr_scheduler import LambdaLR from tianshou.data import Collector, CollectStats, ReplayBuffer, VectorReplayBuffer from tianshou.highlevel.logger import LoggerFactoryDefault from tianshou.policy import PPO from tianshou.policy.base import Algorithm -from tianshou.trainer import OnPolicyTrainer +from tianshou.policy.modelfree.pg import ActorPolicy +from tianshou.policy.optim import AdamOptimizerFactory, LRSchedulerFactoryLinear +from tianshou.trainer import OnPolicyTrainingConfig from tianshou.utils.net.common import ActorCritic, Net from tianshou.utils.net.continuous import ActorProb, Critic @@ -75,7 +76,7 @@ def get_args() -> argparse.Namespace: return 
parser.parse_args() -def test_ppo(args: argparse.Namespace = get_args()) -> None: +def main(args: argparse.Namespace = get_args()) -> None: env, train_envs, test_envs = make_mujoco_env( args.task, args.seed, @@ -128,34 +129,38 @@ def test_ppo(args: argparse.Namespace = get_args()) -> None: torch.nn.init.zeros_(m.bias) m.weight.data.copy_(0.01 * m.weight.data) - optim = torch.optim.Adam(actor_critic.parameters(), lr=args.lr) + optim = AdamOptimizerFactory(lr=args.lr) - lr_scheduler = None if args.lr_decay: - # decay learning rate to 0 linearly - max_update_num = np.ceil(args.step_per_epoch / args.step_per_collect) * args.epoch - - lr_scheduler = LambdaLR(optim, lr_lambda=lambda epoch: 1 - epoch / max_update_num) + optim.with_lr_scheduler_factory( + LRSchedulerFactoryLinear( + num_epochs=args.epoch, + step_per_epoch=args.step_per_epoch, + step_per_collect=args.step_per_collect, + ) + ) def dist(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution: loc, scale = loc_scale return Independent(Normal(loc, scale), 1) - policy: PPO = PPO( + policy = ActorPolicy( actor=actor, + dist_fn=dist, + action_scaling=True, + action_bound_method=args.bound_action_method, + action_space=env.action_space, + ) + algorithm: PPO = PPO( + policy=policy, critic=critic, optim=optim, - dist_fn=dist, discount_factor=args.gamma, gae_lambda=args.gae_lambda, max_grad_norm=args.max_grad_norm, vf_coef=args.vf_coef, ent_coef=args.ent_coef, reward_normalization=args.rew_norm, - action_scaling=True, - action_bound_method=args.bound_action_method, - lr_scheduler=lr_scheduler, - action_space=env.action_space, eps_clip=args.eps_clip, value_clip=args.value_clip, dual_clip=args.dual_clip, @@ -166,7 +171,7 @@ def dist(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution: # load a previous policy if args.resume_path: ckpt = torch.load(args.resume_path, map_location=args.device) - policy.load_state_dict(ckpt["model"]) + algorithm.load_state_dict(ckpt["model"]) 
train_envs.set_obs_rms(ckpt["obs_rms"]) test_envs.set_obs_rms(ckpt["obs_rms"]) print("Loaded agent from: ", args.resume_path) @@ -177,8 +182,8 @@ def dist(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution: buffer = VectorReplayBuffer(args.buffer_size, len(train_envs)) else: buffer = ReplayBuffer(args.buffer_size) - train_collector = Collector[CollectStats](policy, train_envs, buffer, exploration_noise=True) - test_collector = Collector[CollectStats](policy, test_envs) + train_collector = Collector[CollectStats](algorithm, train_envs, buffer, exploration_noise=True) + test_collector = Collector[CollectStats](algorithm, test_envs) # log now = datetime.datetime.now().strftime("%y%m%d-%H%M%S") @@ -206,21 +211,22 @@ def save_best_fn(policy: Algorithm) -> None: torch.save(state, os.path.join(log_path, "policy.pth")) if not args.watch: - # trainer - result = OnPolicyTrainer( - policy=policy, - train_collector=train_collector, - test_collector=test_collector, - max_epoch=args.epoch, - step_per_epoch=args.step_per_epoch, - repeat_per_collect=args.repeat_per_collect, - episode_per_test=args.test_num, - batch_size=args.batch_size, - step_per_collect=args.step_per_collect, - save_best_fn=save_best_fn, - logger=logger, - test_in_train=False, - ).run() + # train + result = algorithm.run_training( + OnPolicyTrainingConfig( + train_collector=train_collector, + test_collector=test_collector, + max_epoch=args.epoch, + step_per_epoch=args.step_per_epoch, + repeat_per_collect=args.repeat_per_collect, + episode_per_test=args.test_num, + batch_size=args.batch_size, + step_per_collect=args.step_per_collect, + save_best_fn=save_best_fn, + logger=logger, + test_in_train=False, + ) + ) pprint.pprint(result) # Let's watch its performance! 
@@ -231,4 +237,4 @@ def save_best_fn(policy: Algorithm) -> None: if __name__ == "__main__": - test_ppo() + main() diff --git a/examples/mujoco/mujoco_redq.py b/examples/mujoco/mujoco_redq.py index 2391a81a5..e8ba562e6 100755 --- a/examples/mujoco/mujoco_redq.py +++ b/examples/mujoco/mujoco_redq.py @@ -13,7 +13,9 @@ from tianshou.highlevel.logger import LoggerFactoryDefault from tianshou.policy import REDQ from tianshou.policy.base import Algorithm -from tianshou.trainer import OffPolicyTrainer +from tianshou.policy.modelfree.redq import REDQPolicy +from tianshou.policy.optim import AdamOptimizerFactory +from tianshou.trainer import OffPolicyTrainingConfig from tianshou.utils.net.common import EnsembleLinear, Net from tianshou.utils.net.continuous import ActorProb, Critic @@ -68,7 +70,7 @@ def get_args() -> argparse.Namespace: return parser.parse_args() -def test_redq(args: argparse.Namespace = get_args()) -> None: +def main(args: argparse.Namespace = get_args()) -> None: env, train_envs, test_envs = make_mujoco_env( args.task, args.seed, @@ -94,7 +96,7 @@ def test_redq(args: argparse.Namespace = get_args()) -> None: unbounded=True, conditioned_sigma=True, ).to(args.device) - actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr) + actor_optim = AdamOptimizerFactory(lr=args.actor_lr) def linear(x: int, y: int) -> EnsembleLinear: return EnsembleLinear(args.ensemble_size, x, y) @@ -113,7 +115,7 @@ def linear(x: int, y: int) -> EnsembleLinear: linear_layer=linear, flatten_input=False, ).to(args.device) - critics_optim = torch.optim.Adam(critics.parameters(), lr=args.critic_lr) + critics_optim = AdamOptimizerFactory(lr=args.critic_lr) if args.auto_alpha: target_entropy = -np.prod(env.action_space.shape) @@ -121,8 +123,12 @@ def linear(x: int, y: int) -> EnsembleLinear: alpha_optim = torch.optim.Adam([log_alpha], lr=args.alpha_lr) args.alpha = (target_entropy, log_alpha, alpha_optim) - policy: REDQ = REDQ( - policy=actor, + policy = REDQPolicy( + 
actor=actor, + action_space=env.action_space, + ) + algorithm: REDQ = REDQ( + policy=policy, policy_optim=actor_optim, critic=critics, critic_optim=critics_optim, @@ -134,12 +140,11 @@ def linear(x: int, y: int) -> EnsembleLinear: estimation_step=args.n_step, actor_delay=args.update_per_step, target_mode=args.target_mode, - action_space=env.action_space, ) # load a previous policy if args.resume_path: - policy.load_state_dict(torch.load(args.resume_path, map_location=args.device)) + algorithm.load_state_dict(torch.load(args.resume_path, map_location=args.device)) print("Loaded agent from: ", args.resume_path) # collector @@ -148,8 +153,8 @@ def linear(x: int, y: int) -> EnsembleLinear: buffer = VectorReplayBuffer(args.buffer_size, len(train_envs)) else: buffer = ReplayBuffer(args.buffer_size) - train_collector = Collector[CollectStats](policy, train_envs, buffer, exploration_noise=True) - test_collector = Collector[CollectStats](policy, test_envs) + train_collector = Collector[CollectStats](algorithm, train_envs, buffer, exploration_noise=True) + test_collector = Collector[CollectStats](algorithm, test_envs) train_collector.reset() train_collector.collect(n_step=args.start_timesteps, random=True) @@ -178,21 +183,22 @@ def save_best_fn(policy: Algorithm) -> None: torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth")) if not args.watch: - # trainer - result = OffPolicyTrainer( - policy=policy, - train_collector=train_collector, - test_collector=test_collector, - max_epoch=args.epoch, - step_per_epoch=args.step_per_epoch, - step_per_collect=args.step_per_collect, - episode_per_test=args.test_num, - batch_size=args.batch_size, - save_best_fn=save_best_fn, - logger=logger, - update_per_step=args.update_per_step, - test_in_train=False, - ).run() + # train + result = algorithm.run_training( + OffPolicyTrainingConfig( + train_collector=train_collector, + test_collector=test_collector, + max_epoch=args.epoch, + step_per_epoch=args.step_per_epoch, + 
step_per_collect=args.step_per_collect, + episode_per_test=args.test_num, + batch_size=args.batch_size, + save_best_fn=save_best_fn, + logger=logger, + update_per_step=args.update_per_step, + test_in_train=False, + ) + ) pprint.pprint(result) # Let's watch its performance! @@ -203,4 +209,4 @@ def save_best_fn(policy: Algorithm) -> None: if __name__ == "__main__": - test_redq() + main() diff --git a/examples/mujoco/mujoco_reinforce.py b/examples/mujoco/mujoco_reinforce.py index c1b6f8e47..ee6aa6cfe 100755 --- a/examples/mujoco/mujoco_reinforce.py +++ b/examples/mujoco/mujoco_reinforce.py @@ -10,13 +10,14 @@ from mujoco_env import make_mujoco_env from torch import nn from torch.distributions import Distribution, Independent, Normal -from torch.optim.lr_scheduler import LambdaLR from tianshou.data import Collector, CollectStats, ReplayBuffer, VectorReplayBuffer from tianshou.highlevel.logger import LoggerFactoryDefault from tianshou.policy import Reinforce from tianshou.policy.base import Algorithm -from tianshou.trainer import OnPolicyTrainer +from tianshou.policy.modelfree.pg import ActorPolicy +from tianshou.policy.optim import AdamOptimizerFactory, LRSchedulerFactoryLinear +from tianshou.trainer import OnPolicyTrainingConfig from tianshou.utils.net.common import Net from tianshou.utils.net.continuous import ActorProb @@ -67,7 +68,7 @@ def get_args() -> argparse.Namespace: return parser.parse_args() -def test_reinforce(args: argparse.Namespace = get_args()) -> None: +def main(args: argparse.Namespace = get_args()) -> None: env, train_envs, test_envs = make_mujoco_env( args.task, args.seed, @@ -111,34 +112,38 @@ def test_reinforce(args: argparse.Namespace = get_args()) -> None: torch.nn.init.zeros_(m.bias) m.weight.data.copy_(0.01 * m.weight.data) - optim = torch.optim.Adam(actor.parameters(), lr=args.lr) - lr_scheduler = None + optim = AdamOptimizerFactory(lr=args.lr) if args.lr_decay: - # decay learning rate to 0 linearly - max_update_num = 
np.ceil(args.step_per_epoch / args.step_per_collect) * args.epoch - - lr_scheduler = LambdaLR(optim, lr_lambda=lambda epoch: 1 - epoch / max_update_num) + optim.with_lr_scheduler_factory( + LRSchedulerFactoryLinear( + num_epochs=args.epoch, + step_per_epoch=args.step_per_epoch, + step_per_collect=args.step_per_collect, + ) + ) def dist(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution: loc, scale = loc_scale return Independent(Normal(loc, scale), 1) - policy: Reinforce = Reinforce( + policy = ActorPolicy( actor=actor, - optim=optim, dist_fn=dist, action_space=env.action_space, - discount_factor=args.gamma, - reward_normalization=args.rew_norm, action_scaling=True, action_bound_method=args.action_bound_method, - lr_scheduler=lr_scheduler, + ) + algorithm: Reinforce = Reinforce( + policy=policy, + optim=optim, + discount_factor=args.gamma, + reward_normalization=args.rew_norm, ) # load a previous policy if args.resume_path: ckpt = torch.load(args.resume_path, map_location=args.device) - policy.load_state_dict(ckpt["model"]) + algorithm.load_state_dict(ckpt["model"]) train_envs.set_obs_rms(ckpt["obs_rms"]) test_envs.set_obs_rms(ckpt["obs_rms"]) print("Loaded agent from: ", args.resume_path) @@ -149,8 +154,8 @@ def dist(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution: buffer = VectorReplayBuffer(args.buffer_size, len(train_envs)) else: buffer = ReplayBuffer(args.buffer_size) - train_collector = Collector[CollectStats](policy, train_envs, buffer, exploration_noise=True) - test_collector = Collector[CollectStats](policy, test_envs) + train_collector = Collector[CollectStats](algorithm, train_envs, buffer, exploration_noise=True) + test_collector = Collector[CollectStats](algorithm, test_envs) # log now = datetime.datetime.now().strftime("%y%m%d-%H%M%S") @@ -178,21 +183,22 @@ def save_best_fn(policy: Algorithm) -> None: torch.save(state, os.path.join(log_path, "policy.pth")) if not args.watch: - # trainer - result = OnPolicyTrainer( - 
policy=policy, - train_collector=train_collector, - test_collector=test_collector, - max_epoch=args.epoch, - step_per_epoch=args.step_per_epoch, - repeat_per_collect=args.repeat_per_collect, - episode_per_test=args.test_num, - batch_size=args.batch_size, - step_per_collect=args.step_per_collect, - save_best_fn=save_best_fn, - logger=logger, - test_in_train=False, - ).run() + # train + result = algorithm.run_training( + OnPolicyTrainingConfig( + train_collector=train_collector, + test_collector=test_collector, + max_epoch=args.epoch, + step_per_epoch=args.step_per_epoch, + repeat_per_collect=args.repeat_per_collect, + episode_per_test=args.test_num, + batch_size=args.batch_size, + step_per_collect=args.step_per_collect, + save_best_fn=save_best_fn, + logger=logger, + test_in_train=False, + ) + ) pprint.pprint(result) # Let's watch its performance! @@ -203,4 +209,4 @@ def save_best_fn(policy: Algorithm) -> None: if __name__ == "__main__": - test_reinforce() + main() diff --git a/examples/mujoco/mujoco_sac.py b/examples/mujoco/mujoco_sac.py index baae832e9..7728a39a1 100755 --- a/examples/mujoco/mujoco_sac.py +++ b/examples/mujoco/mujoco_sac.py @@ -13,7 +13,9 @@ from tianshou.highlevel.logger import LoggerFactoryDefault from tianshou.policy import SAC from tianshou.policy.base import Algorithm -from tianshou.trainer import OffPolicyTrainer +from tianshou.policy.modelfree.sac import SACPolicy +from tianshou.policy.optim import AdamOptimizerFactory +from tianshou.trainer import OffPolicyTrainingConfig from tianshou.utils.net.common import Net from tianshou.utils.net.continuous import ActorProb, Critic @@ -65,7 +67,7 @@ def get_args() -> argparse.Namespace: return parser.parse_args() -def test_sac(args: argparse.Namespace = get_args()) -> None: +def main(args: argparse.Namespace = get_args()) -> None: env, train_envs, test_envs = make_mujoco_env( args.task, args.seed, @@ -91,7 +93,7 @@ def test_sac(args: argparse.Namespace = get_args()) -> None: unbounded=True, 
conditioned_sigma=True, ).to(args.device) - actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr) + actor_optim = AdamOptimizerFactory(lr=args.actor_lr) net_c1 = Net( state_shape=args.state_shape, action_shape=args.action_shape, @@ -107,9 +109,9 @@ def test_sac(args: argparse.Namespace = get_args()) -> None: device=args.device, ) critic1 = Critic(net_c1, device=args.device).to(args.device) - critic1_optim = torch.optim.Adam(critic1.parameters(), lr=args.critic_lr) + critic1_optim = AdamOptimizerFactory(lr=args.critic_lr) critic2 = Critic(net_c2, device=args.device).to(args.device) - critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr) + critic2_optim = AdamOptimizerFactory(lr=args.critic_lr) if args.auto_alpha: target_entropy = -np.prod(env.action_space.shape) @@ -117,8 +119,12 @@ def test_sac(args: argparse.Namespace = get_args()) -> None: alpha_optim = torch.optim.Adam([log_alpha], lr=args.alpha_lr) args.alpha = (target_entropy, log_alpha, alpha_optim) - policy: SAC = SAC( + policy = SACPolicy( actor=actor, + action_space=env.action_space, + ) + algorithm: SAC = SAC( + policy=policy, policy_optim=actor_optim, critic=critic1, critic_optim=critic1_optim, @@ -128,12 +134,11 @@ def test_sac(args: argparse.Namespace = get_args()) -> None: gamma=args.gamma, alpha=args.alpha, estimation_step=args.n_step, - action_space=env.action_space, ) # load a previous policy if args.resume_path: - policy.load_state_dict(torch.load(args.resume_path, map_location=args.device)) + algorithm.load_state_dict(torch.load(args.resume_path, map_location=args.device)) print("Loaded agent from: ", args.resume_path) # collector @@ -142,8 +147,8 @@ def test_sac(args: argparse.Namespace = get_args()) -> None: buffer = VectorReplayBuffer(args.buffer_size, len(train_envs)) else: buffer = ReplayBuffer(args.buffer_size) - train_collector = Collector[CollectStats](policy, train_envs, buffer, exploration_noise=True) - test_collector = Collector[CollectStats](policy, 
test_envs) + train_collector = Collector[CollectStats](algorithm, train_envs, buffer, exploration_noise=True) + test_collector = Collector[CollectStats](algorithm, test_envs) train_collector.reset() train_collector.collect(n_step=args.start_timesteps, random=True) @@ -172,21 +177,22 @@ def save_best_fn(policy: Algorithm) -> None: torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth")) if not args.watch: - # trainer - result = OffPolicyTrainer( - policy=policy, - train_collector=train_collector, - test_collector=test_collector, - max_epoch=args.epoch, - step_per_epoch=args.step_per_epoch, - step_per_collect=args.step_per_collect, - episode_per_test=args.test_num, - batch_size=args.batch_size, - save_best_fn=save_best_fn, - logger=logger, - update_per_step=args.update_per_step, - test_in_train=False, - ).run() + # train + result = algorithm.run_training( + OffPolicyTrainingConfig( + train_collector=train_collector, + test_collector=test_collector, + max_epoch=args.epoch, + step_per_epoch=args.step_per_epoch, + step_per_collect=args.step_per_collect, + episode_per_test=args.test_num, + batch_size=args.batch_size, + save_best_fn=save_best_fn, + logger=logger, + update_per_step=args.update_per_step, + test_in_train=False, + ) + ) pprint.pprint(result) # Let's watch its performance! 
@@ -197,4 +203,4 @@ def save_best_fn(policy: Algorithm) -> None: if __name__ == "__main__": - test_sac() + main() diff --git a/examples/mujoco/mujoco_td3.py b/examples/mujoco/mujoco_td3.py index 8bc814f93..b3c6e9679 100755 --- a/examples/mujoco/mujoco_td3.py +++ b/examples/mujoco/mujoco_td3.py @@ -14,7 +14,9 @@ from tianshou.highlevel.logger import LoggerFactoryDefault from tianshou.policy import TD3 from tianshou.policy.base import Algorithm -from tianshou.trainer import OffPolicyTrainer +from tianshou.policy.modelfree.ddpg import DDPGPolicy +from tianshou.policy.optim import AdamOptimizerFactory +from tianshou.trainer import OffPolicyTrainingConfig from tianshou.utils.net.common import Net from tianshou.utils.net.continuous import Actor, Critic @@ -67,7 +69,7 @@ def get_args() -> argparse.Namespace: return parser.parse_args() -def test_td3(args: argparse.Namespace = get_args()) -> None: +def main(args: argparse.Namespace = get_args()) -> None: env, train_envs, test_envs = make_mujoco_env( args.task, args.seed, @@ -92,7 +94,7 @@ def test_td3(args: argparse.Namespace = get_args()) -> None: actor = Actor(net_a, args.action_shape, max_action=args.max_action, device=args.device).to( args.device, ) - actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr) + actor_optim = AdamOptimizerFactory(lr=args.actor_lr) net_c1 = Net( state_shape=args.state_shape, action_shape=args.action_shape, @@ -108,12 +110,17 @@ def test_td3(args: argparse.Namespace = get_args()) -> None: device=args.device, ) critic1 = Critic(net_c1, device=args.device).to(args.device) - critic1_optim = torch.optim.Adam(critic1.parameters(), lr=args.critic_lr) + critic1_optim = AdamOptimizerFactory(lr=args.critic_lr) critic2 = Critic(net_c2, device=args.device).to(args.device) - critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr) + critic2_optim = AdamOptimizerFactory(lr=args.critic_lr) - policy: TD3 = TD3( + policy = DDPGPolicy( actor=actor, + 
exploration_noise=GaussianNoise(sigma=args.exploration_noise), + action_space=env.action_space, + ) + algorithm: TD3 = TD3( + policy=policy, policy_optim=actor_optim, critic=critic1, critic_optim=critic1_optim, @@ -121,17 +128,15 @@ def test_td3(args: argparse.Namespace = get_args()) -> None: critic2_optim=critic2_optim, tau=args.tau, gamma=args.gamma, - exploration_noise=GaussianNoise(sigma=args.exploration_noise), policy_noise=args.policy_noise, update_actor_freq=args.update_actor_freq, noise_clip=args.noise_clip, estimation_step=args.n_step, - action_space=env.action_space, ) # load a previous policy if args.resume_path: - policy.load_state_dict(torch.load(args.resume_path, map_location=args.device)) + algorithm.load_state_dict(torch.load(args.resume_path, map_location=args.device)) print("Loaded agent from: ", args.resume_path) # collector @@ -140,8 +145,8 @@ def test_td3(args: argparse.Namespace = get_args()) -> None: buffer = VectorReplayBuffer(args.buffer_size, len(train_envs)) else: buffer = ReplayBuffer(args.buffer_size) - train_collector = Collector[CollectStats](policy, train_envs, buffer, exploration_noise=True) - test_collector = Collector[CollectStats](policy, test_envs) + train_collector = Collector[CollectStats](algorithm, train_envs, buffer, exploration_noise=True) + test_collector = Collector[CollectStats](algorithm, test_envs) train_collector.reset() train_collector.collect(n_step=args.start_timesteps, random=True) @@ -170,21 +175,22 @@ def save_best_fn(policy: Algorithm) -> None: torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth")) if not args.watch: - # trainer - result = OffPolicyTrainer( - policy=policy, - train_collector=train_collector, - test_collector=test_collector, - max_epoch=args.epoch, - step_per_epoch=args.step_per_epoch, - step_per_collect=args.step_per_collect, - episode_per_test=args.test_num, - batch_size=args.batch_size, - save_best_fn=save_best_fn, - logger=logger, - update_per_step=args.update_per_step, - 
test_in_train=False, - ).run() + # train + result = algorithm.run_training( + OffPolicyTrainingConfig( + train_collector=train_collector, + test_collector=test_collector, + max_epoch=args.epoch, + step_per_epoch=args.step_per_epoch, + step_per_collect=args.step_per_collect, + episode_per_test=args.test_num, + batch_size=args.batch_size, + save_best_fn=save_best_fn, + logger=logger, + update_per_step=args.update_per_step, + test_in_train=False, + ) + ) pprint.pprint(result) # Let's watch its performance! @@ -195,4 +201,4 @@ def save_best_fn(policy: Algorithm) -> None: if __name__ == "__main__": - test_td3() + main() diff --git a/examples/mujoco/mujoco_trpo.py b/examples/mujoco/mujoco_trpo.py index 141924118..5bf3a1891 100755 --- a/examples/mujoco/mujoco_trpo.py +++ b/examples/mujoco/mujoco_trpo.py @@ -10,13 +10,14 @@ from mujoco_env import make_mujoco_env from torch import nn from torch.distributions import Distribution, Independent, Normal -from torch.optim.lr_scheduler import LambdaLR from tianshou.data import Collector, CollectStats, ReplayBuffer, VectorReplayBuffer from tianshou.highlevel.logger import LoggerFactoryDefault from tianshou.policy import TRPO from tianshou.policy.base import Algorithm -from tianshou.trainer import OnPolicyTrainer +from tianshou.policy.modelfree.pg import ActorPolicy +from tianshou.policy.optim import AdamOptimizerFactory, LRSchedulerFactoryLinear +from tianshou.trainer import OnPolicyTrainingConfig from tianshou.utils.net.common import Net from tianshou.utils.net.continuous import ActorProb, Critic @@ -129,30 +130,34 @@ def test_trpo(args: argparse.Namespace = get_args()) -> None: torch.nn.init.zeros_(m.bias) m.weight.data.copy_(0.01 * m.weight.data) - optim = torch.optim.Adam(critic.parameters(), lr=args.lr) - lr_scheduler = None + optim = AdamOptimizerFactory(lr=args.lr) if args.lr_decay: - # decay learning rate to 0 linearly - max_update_num = np.ceil(args.step_per_epoch / args.step_per_collect) * args.epoch - - lr_scheduler = 
LambdaLR(optim, lr_lambda=lambda epoch: 1 - epoch / max_update_num) + optim.with_lr_scheduler_factory( + LRSchedulerFactoryLinear( + num_epochs=args.epoch, + step_per_epoch=args.step_per_epoch, + step_per_collect=args.step_per_collect, + ) + ) def dist(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution: loc, scale = loc_scale return Independent(Normal(loc, scale), 1) - policy: TRPO = TRPO( + policy = ActorPolicy( actor=actor, + dist_fn=dist, + action_scaling=True, + action_bound_method=args.bound_action_method, + action_space=env.action_space, + ) + algorithm: TRPO = TRPO( + policy=policy, critic=critic, optim=optim, - dist_fn=dist, discount_factor=args.gamma, gae_lambda=args.gae_lambda, reward_normalization=args.rew_norm, - action_scaling=True, - action_bound_method=args.bound_action_method, - lr_scheduler=lr_scheduler, - action_space=env.action_space, advantage_normalization=args.norm_adv, optim_critic_iters=args.optim_critic_iters, max_kl=args.max_kl, @@ -163,7 +168,7 @@ def dist(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution: # load a previous policy if args.resume_path: ckpt = torch.load(args.resume_path, map_location=args.device) - policy.load_state_dict(ckpt["model"]) + algorithm.load_state_dict(ckpt["model"]) train_envs.set_obs_rms(ckpt["obs_rms"]) test_envs.set_obs_rms(ckpt["obs_rms"]) print("Loaded agent from: ", args.resume_path) @@ -174,8 +179,8 @@ def dist(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution: buffer = VectorReplayBuffer(args.buffer_size, len(train_envs)) else: buffer = ReplayBuffer(args.buffer_size) - train_collector = Collector[CollectStats](policy, train_envs, buffer, exploration_noise=True) - test_collector = Collector[CollectStats](policy, test_envs) + train_collector = Collector[CollectStats](algorithm, train_envs, buffer, exploration_noise=True) + test_collector = Collector[CollectStats](algorithm, test_envs) # log now = datetime.datetime.now().strftime("%y%m%d-%H%M%S") @@ -204,20 +209,21 @@ def 
save_best_fn(policy: Algorithm) -> None: if not args.watch: # trainer - result = OnPolicyTrainer( - policy=policy, - train_collector=train_collector, - test_collector=test_collector, - max_epoch=args.epoch, - step_per_epoch=args.step_per_epoch, - repeat_per_collect=args.repeat_per_collect, - episode_per_test=args.test_num, - batch_size=args.batch_size, - step_per_collect=args.step_per_collect, - save_best_fn=save_best_fn, - logger=logger, - test_in_train=False, - ).run() + result = algorithm.run_training( + OnPolicyTrainingConfig( + train_collector=train_collector, + test_collector=test_collector, + max_epoch=args.epoch, + step_per_epoch=args.step_per_epoch, + repeat_per_collect=args.repeat_per_collect, + episode_per_test=args.test_num, + batch_size=args.batch_size, + step_per_collect=args.step_per_collect, + save_best_fn=save_best_fn, + logger=logger, + test_in_train=False, + ) + ) pprint.pprint(result) # Let's watch its performance! diff --git a/tianshou/policy/modelfree/redq.py b/tianshou/policy/modelfree/redq.py index 11c24f457..f62b7c818 100644 --- a/tianshou/policy/modelfree/redq.py +++ b/tianshou/policy/modelfree/redq.py @@ -40,7 +40,7 @@ def __init__( *, actor: torch.nn.Module | ActorProb, exploration_noise: BaseNoise | Literal["default"] | None = None, - action_space: gym.spaces.Box, + action_space: gym.spaces.Space, deterministic_eval: bool = True, action_scaling: bool = True, action_bound_method: Literal["clip"] | None = "clip", From 652c7f34d6c02034aca55749375787ecbbeedded Mon Sep 17 00:00:00 2001 From: Dominik Jain Date: Mon, 17 Mar 2025 14:28:05 +0100 Subject: [PATCH 064/230] v2: Refactor vanilla imitation learning, separating off-policy from offline case * Introduce mixin ImitiationLearningAlgorithmMixin to factor out common functionality * ImitationLearning -> OffPolicyImitationLearning, OfflineImitationLearning --- CHANGELOG.md | 1 + test/continuous/test_sac_with_il.py | 4 +- test/discrete/test_a2c_with_il.py | 4 +- tianshou/policy/__init__.py | 2 
+- tianshou/policy/imitation/base.py | 90 ++++++++++++++++++++--------- 5 files changed, 68 insertions(+), 33 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index b9d7a526c..06df556bb 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -52,6 +52,7 @@ Migration information: The instantiation of a policy is replaced by the instantiation of an `Algorithm`, which is passed a `Policy`. In most cases, the former policy class name `Policy` is replaced by algorithm class ``; exceptions are noted below. + * `ImitationPolicy` -> `OffPolicyImitationLearning`, `OfflineImitationLearning` * `PGPolicy` -> `Reinforce` * `MultiAgentPolicyManager` -> `MultiAgentOnPolicyAlgorithm`, `MultiAgentOffPolicyAlgorithm` * `MARLRandomPolicy` -> `MARLRandomDiscreteMaskedOffPolicyAlgorithm` diff --git a/test/continuous/test_sac_with_il.py b/test/continuous/test_sac_with_il.py index 5870899a8..92efc19b4 100644 --- a/test/continuous/test_sac_with_il.py +++ b/test/continuous/test_sac_with_il.py @@ -8,7 +8,7 @@ from tianshou.data import Collector, CollectStats, VectorReplayBuffer from tianshou.env import DummyVectorEnv -from tianshou.policy import SAC, ImitationLearning +from tianshou.policy import SAC, OffPolicyImitationLearning from tianshou.policy.base import Algorithm from tianshou.policy.imitation.base import ImitationPolicy from tianshou.policy.modelfree.sac import AutoAlpha, SACPolicy @@ -187,7 +187,7 @@ def stop_fn(mean_rewards: float) -> bool: action_scaling=True, action_bound_method="clip", ) - il_algorithm: ImitationLearning = ImitationLearning( + il_algorithm: OffPolicyImitationLearning = OffPolicyImitationLearning( policy=il_policy, optim=optim, ) diff --git a/test/discrete/test_a2c_with_il.py b/test/discrete/test_a2c_with_il.py index 7d0f5a1a2..f2e75ab66 100644 --- a/test/discrete/test_a2c_with_il.py +++ b/test/discrete/test_a2c_with_il.py @@ -9,7 +9,7 @@ from tianshou.data import Collector, CollectStats, VectorReplayBuffer from tianshou.env import DummyVectorEnv -from 
tianshou.policy import A2C, ImitationLearning +from tianshou.policy import A2C, OffPolicyImitationLearning from tianshou.policy.base import Algorithm from tianshou.policy.imitation.base import ImitationPolicy from tianshou.policy.modelfree.pg import ActorPolicy @@ -163,7 +163,7 @@ def stop_fn(mean_rewards: float) -> bool: actor=actor, action_space=env.action_space, ) - il_algorithm: ImitationLearning = ImitationLearning( + il_algorithm: OffPolicyImitationLearning = OffPolicyImitationLearning( policy=il_policy, optim=optim, ) diff --git a/tianshou/policy/__init__.py b/tianshou/policy/__init__.py index a12b3fc50..cf1815273 100644 --- a/tianshou/policy/__init__.py +++ b/tianshou/policy/__init__.py @@ -21,7 +21,7 @@ from tianshou.policy.modelfree.sac import SAC from tianshou.policy.modelfree.redq import REDQ from tianshou.policy.modelfree.discrete_sac import DiscreteSAC -from tianshou.policy.imitation.base import ImitationLearning +from tianshou.policy.imitation.base import OffPolicyImitationLearning from tianshou.policy.imitation.bcq import BCQ from tianshou.policy.imitation.cql import CQL from tianshou.policy.imitation.td3_bc import TD3BC diff --git a/tianshou/policy/imitation/base.py b/tianshou/policy/imitation/base.py index e385f2698..ddf9231ed 100644 --- a/tianshou/policy/imitation/base.py +++ b/tianshou/policy/imitation/base.py @@ -14,6 +14,7 @@ RolloutBatchProtocol, ) from tianshou.policy.base import ( + OfflineAlgorithm, OffPolicyAlgorithm, Policy, TrainingStats, @@ -46,9 +47,7 @@ def __init__( action_bound_method: Literal["clip", "tanh"] | None = "clip", ): """ - :param actor: a model following the rules in - :class:`~tianshou.policy.BasePolicy`. (s -> a) - :param optim: for optimizing the model. + :param actor: a model following the rules (s -> a) :param action_space: Env's action_space. :param observation_space: Env's observation space. 
:param action_scaling: if True, scale the action from [-1, 1] to the range @@ -87,8 +86,36 @@ def forward( return cast(ModelOutputBatchProtocol, result) -class ImitationLearning(OffPolicyAlgorithm, Generic[TImitationTrainingStats]): - """Implementation of vanilla imitation learning.""" +class ImitationLearningAlgorithmMixin: + def _imitation_update( + self, + batch: RolloutBatchProtocol, + policy: ImitationPolicy, + optim: torch.optim.Optimizer, + ) -> ImitationTrainingStats: + optim.zero_grad() + if policy.action_type == "continuous": # regression + act = policy(batch).act + act_target = to_torch(batch.act, dtype=torch.float32, device=act.device) + loss = F.mse_loss(act, act_target) + elif policy.action_type == "discrete": # classification + act = F.log_softmax(policy(batch).logits, dim=-1) + act_target = to_torch(batch.act, dtype=torch.long, device=act.device) + loss = F.nll_loss(act, act_target) + else: + raise ValueError(policy.action_type) + loss.backward() + optim.step() + + return ImitationTrainingStats(loss=loss.item()) + + +class OffPolicyImitationLearning( + OffPolicyAlgorithm[ImitationPolicy, TImitationTrainingStats], + ImitationLearningAlgorithmMixin, + Generic[TImitationTrainingStats], +): + """Implementation of off-policy vanilla imitation learning.""" def __init__( self, @@ -98,15 +125,7 @@ def __init__( ) -> None: """ :param policy: the policy - :class:`~tianshou.policy.BasePolicy`. (s -> a) - :param optim: for optimizing the model. - :param action_space: Env's action_space. - :param observation_space: Env's observation space. - :param action_scaling: if True, scale the action from [-1, 1] to the range - of action_space. Only used if the action_space is continuous. - :param action_bound_method: method to bound action to range [-1, 1]. - Only used if the action_space is continuous. - :param lr_scheduler: if not None, will be called in `policy.update()`. 
+ :param optim: the optimizer factory """ super().__init__( policy=policy, @@ -117,18 +136,33 @@ def _update_with_batch( self, batch: RolloutBatchProtocol, ) -> TImitationTrainingStats: - self.optim.zero_grad() - if self.policy.action_type == "continuous": # regression - act = self.policy(batch).act - act_target = to_torch(batch.act, dtype=torch.float32, device=act.device) - loss = F.mse_loss(act, act_target) - elif self.policy.action_type == "discrete": # classification - act = F.log_softmax(self.policy(batch).logits, dim=-1) - act_target = to_torch(batch.act, dtype=torch.long, device=act.device) - loss = F.nll_loss(act, act_target) - else: - raise ValueError(self.policy.action_type) - loss.backward() - self.optim.step() + return self._imitation_update(batch, self.policy, self.optim) + + +class OfflineImitationLearning( + OfflineAlgorithm[ImitationPolicy, TImitationTrainingStats], + ImitationLearningAlgorithmMixin, + Generic[TImitationTrainingStats], +): + """Implementation of offline vanilla imitation learning.""" + + def __init__( + self, + *, + policy: ImitationPolicy, + optim: OptimizerFactory, + ) -> None: + """ + :param policy: the policy + :param optim: the optimizer factory + """ + super().__init__( + policy=policy, + ) + self.optim = self._create_optimizer(self.policy, optim) - return ImitationTrainingStats(loss=loss.item()) # type: ignore + def _update_with_batch( + self, + batch: RolloutBatchProtocol, + ) -> TImitationTrainingStats: + return self._imitation_update(batch, self.policy, self.optim) From 45bf412e71b6daa3ac73853816241abb130e3ce5 Mon Sep 17 00:00:00 2001 From: Dominik Jain Date: Mon, 17 Mar 2025 14:08:16 +0100 Subject: [PATCH 065/230] v2: Adapt offline examples --- examples/offline/atari_bcq.py | 56 ++++++++++++++++--------------- examples/offline/atari_cql.py | 50 ++++++++++++++++------------ examples/offline/atari_crr.py | 54 ++++++++++++++++-------------- examples/offline/atari_il.py | 46 +++++++++++++------------ 
examples/offline/d4rl_bcq.py | 59 ++++++++++++++++++--------------- examples/offline/d4rl_cql.py | 53 +++++++++++++++-------------- examples/offline/d4rl_il.py | 47 ++++++++++++++------------ examples/offline/d4rl_td3_bc.py | 52 ++++++++++++++++------------- 8 files changed, 228 insertions(+), 189 deletions(-) diff --git a/examples/offline/atari_bcq.py b/examples/offline/atari_bcq.py index 091b2f7c9..ebe582def 100644 --- a/examples/offline/atari_bcq.py +++ b/examples/offline/atari_bcq.py @@ -11,15 +11,16 @@ import torch from gymnasium.spaces import Discrete -from examples.atari.atari_network import DQNet -from examples.atari.atari_wrapper import make_atari_env from examples.offline.utils import load_buffer from tianshou.data import Collector, CollectStats, VectorReplayBuffer +from tianshou.env.atari.atari_network import DQNet +from tianshou.env.atari.atari_wrapper import make_atari_env from tianshou.highlevel.logger import LoggerFactoryDefault from tianshou.policy import DiscreteBCQ from tianshou.policy.base import Algorithm -from tianshou.trainer import OfflineTrainer -from tianshou.utils.net.common import ActorCritic +from tianshou.policy.imitation.discrete_bcq import DiscreteBCQPolicy +from tianshou.policy.optim import AdamOptimizerFactory +from tianshou.trainer import OfflineTrainingConfig from tianshou.utils.net.discrete import Actor @@ -73,7 +74,7 @@ def get_args() -> argparse.Namespace: return parser.parse_known_args()[0] -def test_discrete_bcq(args: argparse.Namespace = get_args()) -> None: +def main(args: argparse.Namespace = get_args()) -> None: # envs env, _, test_envs = make_atari_env( args.task, @@ -118,24 +119,26 @@ def test_discrete_bcq(args: argparse.Namespace = get_args()) -> None: hidden_sizes=args.hidden_sizes, softmax_output=False, ).to(args.device) - actor_critic = ActorCritic(policy_net, imitation_net) - optim = torch.optim.Adam(actor_critic.parameters(), lr=args.lr) - # define policy - policy: DiscreteBCQ = DiscreteBCQ( + optim = 
AdamOptimizerFactory(lr=args.lr) + # define policy and algorithm + policy = DiscreteBCQPolicy( model=policy_net, imitator=imitation_net, - optim=optim, action_space=env.action_space, + unlikely_action_threshold=args.unlikely_action_threshold, + ) + algorithm: DiscreteBCQ = DiscreteBCQ( + policy=policy, + optim=optim, discount_factor=args.gamma, estimation_step=args.n_step, target_update_freq=args.target_update_freq, eval_eps=args.eps_test, - unlikely_action_threshold=args.unlikely_action_threshold, imitation_logits_penalty=args.imitation_logits_penalty, ) # load a previous policy if args.resume_path: - policy.load_state_dict(torch.load(args.resume_path, map_location=args.device)) + algorithm.load_state_dict(torch.load(args.resume_path, map_location=args.device)) print("Loaded agent from: ", args.resume_path) # buffer if args.buffer_from_rl_unplugged: @@ -155,7 +158,7 @@ def test_discrete_bcq(args: argparse.Namespace = get_args()) -> None: print("Replay buffer size:", len(buffer), flush=True) # collector - test_collector = Collector[CollectStats](policy, test_envs, exploration_noise=True) + test_collector = Collector[CollectStats](algorithm, test_envs, exploration_noise=True) # log now = datetime.datetime.now().strftime("%y%m%d-%H%M%S") @@ -198,22 +201,23 @@ def watch() -> None: watch() sys.exit(0) - result = OfflineTrainer( - policy=policy, - buffer=buffer, - test_collector=test_collector, - max_epoch=args.epoch, - step_per_epoch=args.update_per_epoch, - episode_per_test=args.test_num, - batch_size=args.batch_size, - stop_fn=stop_fn, - save_best_fn=save_best_fn, - logger=logger, - ).run() + result = algorithm.run_training( + OfflineTrainingConfig( + buffer=buffer, + test_collector=test_collector, + max_epoch=args.epoch, + step_per_epoch=args.update_per_epoch, + episode_per_test=args.test_num, + batch_size=args.batch_size, + stop_fn=stop_fn, + save_best_fn=save_best_fn, + logger=logger, + ) + ) pprint.pprint(result) watch() if __name__ == "__main__": - 
test_discrete_bcq(get_args()) + main(get_args()) diff --git a/examples/offline/atari_cql.py b/examples/offline/atari_cql.py index 28e394b18..217fd997f 100644 --- a/examples/offline/atari_cql.py +++ b/examples/offline/atari_cql.py @@ -12,14 +12,16 @@ import torch from gymnasium.spaces import Discrete -from examples.atari.atari_network import QRDQNet -from examples.atari.atari_wrapper import make_atari_env from examples.offline.utils import load_buffer from tianshou.data import Collector, CollectStats, VectorReplayBuffer +from tianshou.env.atari.atari_network import QRDQNet +from tianshou.env.atari.atari_wrapper import make_atari_env from tianshou.highlevel.logger import LoggerFactoryDefault from tianshou.policy import DiscreteCQL from tianshou.policy.base import Algorithm -from tianshou.trainer import OfflineTrainer +from tianshou.policy.modelfree.qrdqn import QRDQNPolicy +from tianshou.policy.optim import AdamOptimizerFactory +from tianshou.trainer import OfflineTrainingConfig from tianshou.utils.space_info import SpaceInfo @@ -73,7 +75,7 @@ def get_args() -> argparse.Namespace: return parser.parse_known_args()[0] -def test_discrete_cql(args: argparse.Namespace = get_args()) -> None: +def main(args: argparse.Namespace = get_args()) -> None: # envs env, _, test_envs = make_atari_env( args.task, @@ -105,12 +107,15 @@ def test_discrete_cql(args: argparse.Namespace = get_args()) -> None: num_quantiles=args.num_quantiles, device=args.device, ) - optim = torch.optim.Adam(net.parameters(), lr=args.lr) + optim = AdamOptimizerFactory(lr=args.lr) # define policy - policy: DiscreteCQL = DiscreteCQL( + policy = QRDQNPolicy( model=net, - optim=optim, action_space=env.action_space, + ) + algorithm: DiscreteCQL = DiscreteCQL( + policy=policy, + optim=optim, discount_factor=args.gamma, num_quantiles=args.num_quantiles, estimation_step=args.n_step, @@ -119,7 +124,7 @@ def test_discrete_cql(args: argparse.Namespace = get_args()) -> None: ).to(args.device) # load a previous policy if 
args.resume_path: - policy.load_state_dict(torch.load(args.resume_path, map_location=args.device)) + algorithm.load_state_dict(torch.load(args.resume_path, map_location=args.device)) print("Loaded agent from: ", args.resume_path) # buffer if args.buffer_from_rl_unplugged: @@ -139,7 +144,7 @@ def test_discrete_cql(args: argparse.Namespace = get_args()) -> None: print("Replay buffer size:", len(buffer), flush=True) # collector - test_collector = Collector[CollectStats](policy, test_envs, exploration_noise=True) + test_collector = Collector[CollectStats](algorithm, test_envs, exploration_noise=True) # log now = datetime.datetime.now().strftime("%y%m%d-%H%M%S") @@ -182,22 +187,23 @@ def watch() -> None: watch() sys.exit(0) - result = OfflineTrainer( - policy=policy, - buffer=buffer, - test_collector=test_collector, - max_epoch=args.epoch, - step_per_epoch=args.update_per_epoch, - episode_per_test=args.test_num, - batch_size=args.batch_size, - stop_fn=stop_fn, - save_best_fn=save_best_fn, - logger=logger, - ).run() + result = algorithm.run_training( + OfflineTrainingConfig( + buffer=buffer, + test_collector=test_collector, + max_epoch=args.epoch, + step_per_epoch=args.update_per_epoch, + episode_per_test=args.test_num, + batch_size=args.batch_size, + stop_fn=stop_fn, + save_best_fn=save_best_fn, + logger=logger, + ) + ) pprint.pprint(result) watch() if __name__ == "__main__": - test_discrete_cql(get_args()) + main(get_args()) diff --git a/examples/offline/atari_crr.py b/examples/offline/atari_crr.py index 237ac2f44..c5d1558e3 100644 --- a/examples/offline/atari_crr.py +++ b/examples/offline/atari_crr.py @@ -11,15 +11,16 @@ import torch from gymnasium.spaces import Discrete -from examples.atari.atari_network import DQNet -from examples.atari.atari_wrapper import make_atari_env from examples.offline.utils import load_buffer from tianshou.data import Collector, CollectStats, VectorReplayBuffer +from tianshou.env.atari.atari_network import DQNet +from 
tianshou.env.atari.atari_wrapper import make_atari_env from tianshou.highlevel.logger import LoggerFactoryDefault from tianshou.policy import DiscreteCRR from tianshou.policy.base import Algorithm -from tianshou.trainer import OfflineTrainer -from tianshou.utils.net.common import ActorCritic +from tianshou.policy.modelfree.pg import DiscreteActorPolicy +from tianshou.policy.optim import AdamOptimizerFactory +from tianshou.trainer import OfflineTrainingConfig from tianshou.utils.net.discrete import Actor, Critic from tianshou.utils.space_info import SpaceInfo @@ -74,7 +75,7 @@ def get_args() -> argparse.Namespace: return parser.parse_known_args()[0] -def test_discrete_crr(args: argparse.Namespace = get_args()) -> None: +def main(args: argparse.Namespace = get_args()) -> None: # envs env, _, test_envs = make_atari_env( args.task, @@ -119,14 +120,16 @@ def test_discrete_crr(args: argparse.Namespace = get_args()) -> None: last_size=int(np.prod(args.action_shape)), device=args.device, ).to(args.device) - actor_critic = ActorCritic(actor, critic) - optim = torch.optim.Adam(actor_critic.parameters(), lr=args.lr) - # define policy - policy: DiscreteCRR = DiscreteCRR( + optim = AdamOptimizerFactory(lr=args.lr) + # define policy and algorithm + policy = DiscreteActorPolicy( actor=actor, + action_space=env.action_space, + ) + algorithm: DiscreteCRR = DiscreteCRR( + policy=policy, critic=critic, optim=optim, - action_space=env.action_space, discount_factor=args.gamma, policy_improvement_mode=args.policy_improvement_mode, ratio_upper_bound=args.ratio_upper_bound, @@ -136,7 +139,7 @@ def test_discrete_crr(args: argparse.Namespace = get_args()) -> None: ).to(args.device) # load a previous policy if args.resume_path: - policy.load_state_dict(torch.load(args.resume_path, map_location=args.device)) + algorithm.load_state_dict(torch.load(args.resume_path, map_location=args.device)) print("Loaded agent from: ", args.resume_path) # buffer if args.buffer_from_rl_unplugged: @@ -156,7 
+159,7 @@ def test_discrete_crr(args: argparse.Namespace = get_args()) -> None: print("Replay buffer size:", len(buffer), flush=True) # collector - test_collector = Collector[CollectStats](policy, test_envs, exploration_noise=True) + test_collector = Collector[CollectStats](algorithm, test_envs, exploration_noise=True) # log now = datetime.datetime.now().strftime("%y%m%d-%H%M%S") @@ -198,22 +201,23 @@ def watch() -> None: watch() sys.exit(0) - result = OfflineTrainer( - policy=policy, - buffer=buffer, - test_collector=test_collector, - max_epoch=args.epoch, - step_per_epoch=args.update_per_epoch, - episode_per_test=args.test_num, - batch_size=args.batch_size, - stop_fn=stop_fn, - save_best_fn=save_best_fn, - logger=logger, - ).run() + result = algorithm.run_training( + OfflineTrainingConfig( + buffer=buffer, + test_collector=test_collector, + max_epoch=args.epoch, + step_per_epoch=args.update_per_epoch, + episode_per_test=args.test_num, + batch_size=args.batch_size, + stop_fn=stop_fn, + save_best_fn=save_best_fn, + logger=logger, + ) + ) pprint.pprint(result) watch() if __name__ == "__main__": - test_discrete_crr(get_args()) + main(get_args()) diff --git a/examples/offline/atari_il.py b/examples/offline/atari_il.py index d69a55c89..a34704b0f 100644 --- a/examples/offline/atari_il.py +++ b/examples/offline/atari_il.py @@ -10,14 +10,15 @@ import numpy as np import torch -from examples.atari.atari_network import DQNet -from examples.atari.atari_wrapper import make_atari_env from examples.offline.utils import load_buffer from tianshou.data import Collector, CollectStats, VectorReplayBuffer +from tianshou.env.atari.atari_network import DQNet +from tianshou.env.atari.atari_wrapper import make_atari_env from tianshou.highlevel.logger import LoggerFactoryDefault -from tianshou.policy import ImitationLearning from tianshou.policy.base import Algorithm -from tianshou.trainer import OfflineTrainer +from tianshou.policy.imitation.base import ImitationPolicy, 
OfflineImitationLearning +from tianshou.policy.optim import AdamOptimizerFactory +from tianshou.trainer import OfflineTrainingConfig from tianshou.utils.space_info import SpaceInfo @@ -88,14 +89,16 @@ def test_il(args: argparse.Namespace = get_args()) -> None: torch.manual_seed(args.seed) # model net = DQNet(c, h, w, args.action_shape, device=args.device).to(args.device) - optim = torch.optim.Adam(net.parameters(), lr=args.lr) + optim = AdamOptimizerFactory(lr=args.lr) # define policy - policy: ImitationLearning = ImitationLearning( - actor=net, optim=optim, action_space=env.action_space + policy = ImitationPolicy(actor=net, action_space=env.action_space) + algorithm = OfflineImitationLearning( + policy=policy, + optim=optim, ) # load a previous policy if args.resume_path: - policy.load_state_dict(torch.load(args.resume_path, map_location=args.device)) + algorithm.load_state_dict(torch.load(args.resume_path, map_location=args.device)) print("Loaded agent from: ", args.resume_path) # buffer if args.buffer_from_rl_unplugged: @@ -115,7 +118,7 @@ def test_il(args: argparse.Namespace = get_args()) -> None: print("Replay buffer size:", len(buffer), flush=True) # collector - test_collector = Collector[CollectStats](policy, test_envs, exploration_noise=True) + test_collector = Collector[CollectStats](algorithm, test_envs, exploration_noise=True) # log now = datetime.datetime.now().strftime("%y%m%d-%H%M%S") @@ -157,18 +160,19 @@ def watch() -> None: watch() sys.exit(0) - result = OfflineTrainer( - policy=policy, - buffer=buffer, - test_collector=test_collector, - max_epoch=args.epoch, - step_per_epoch=args.update_per_epoch, - episode_per_test=args.test_num, - batch_size=args.batch_size, - stop_fn=stop_fn, - save_best_fn=save_best_fn, - logger=logger, - ).run() + result = algorithm.run_training( + OfflineTrainingConfig( + buffer=buffer, + test_collector=test_collector, + max_epoch=args.epoch, + step_per_epoch=args.update_per_epoch, + episode_per_test=args.test_num, + 
batch_size=args.batch_size, + stop_fn=stop_fn, + save_best_fn=save_best_fn, + logger=logger, + ) + ) pprint.pprint(result) watch() diff --git a/examples/offline/d4rl_bcq.py b/examples/offline/d4rl_bcq.py index 2da381f2e..67bb12eda 100644 --- a/examples/offline/d4rl_bcq.py +++ b/examples/offline/d4rl_bcq.py @@ -15,7 +15,9 @@ from tianshou.env import SubprocVectorEnv from tianshou.policy import BCQ from tianshou.policy.base import Algorithm -from tianshou.trainer import OfflineTrainer +from tianshou.policy.imitation.bcq import BCQPolicy +from tianshou.policy.optim import AdamOptimizerFactory +from tianshou.trainer import OfflineTrainingConfig from tianshou.utils import TensorboardLogger, WandbLogger from tianshou.utils.net.common import MLP, Net from tianshou.utils.net.continuous import VAE, Critic, Perturbation @@ -106,7 +108,7 @@ def test_bcq() -> None: actor = Perturbation(net_a, max_action=args.max_action, device=args.device, phi=args.phi).to( args.device, ) - actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr) + actor_optim = AdamOptimizerFactory(lr=args.actor_lr) net_c1 = Net( state_shape=args.state_shape, @@ -123,9 +125,9 @@ def test_bcq() -> None: device=args.device, ) critic1 = Critic(net_c1, device=args.device).to(args.device) - critic1_optim = torch.optim.Adam(critic1.parameters(), lr=args.critic_lr) + critic1_optim = AdamOptimizerFactory(lr=args.critic_lr) critic2 = Critic(net_c2, device=args.device).to(args.device) - critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr) + critic2_optim = AdamOptimizerFactory(lr=args.critic_lr) # vae # output_dim = 0, so the last Module in the encoder is ReLU @@ -150,31 +152,33 @@ def test_bcq() -> None: max_action=args.max_action, device=args.device, ).to(args.device) - vae_optim = torch.optim.Adam(vae.parameters()) + vae_optim = AdamOptimizerFactory() - policy: BCQ = BCQ( + policy = BCQPolicy( actor_perturbation=actor, - actor_perturbation_optim=actor_optim, + 
action_space=env.action_space, critic=critic1, + vae=vae, + ) + algorithm: BCQ = BCQ( + policy=policy, + actor_perturbation_optim=actor_optim, critic_optim=critic1_optim, - action_space=env.action_space, critic2=critic2, critic2_optim=critic2_optim, - vae=vae, vae_optim=vae_optim, - device=args.device, gamma=args.gamma, tau=args.tau, lmbda=args.lmbda, - ) + ).to(args.device) # load a previous policy if args.resume_path: - policy.load_state_dict(torch.load(args.resume_path, map_location=args.device)) + algorithm.load_state_dict(torch.load(args.resume_path, map_location=args.device)) print("Loaded agent from: ", args.resume_path) # collector - test_collector = Collector[CollectStats](policy, test_envs) + test_collector = Collector[CollectStats](algorithm, test_envs) # log now = datetime.datetime.now().strftime("%y%m%d-%H%M%S") @@ -205,24 +209,25 @@ def watch() -> None: if args.resume_path is None: args.resume_path = os.path.join(log_path, "policy.pth") - policy.load_state_dict(torch.load(args.resume_path, map_location=torch.device("cpu"))) - collector = Collector[CollectStats](policy, env) + algorithm.load_state_dict(torch.load(args.resume_path, map_location=torch.device("cpu"))) + collector = Collector[CollectStats](algorithm, env) collector.collect(n_episode=1, render=1 / 35) if not args.watch: replay_buffer = load_buffer_d4rl(args.expert_data_task) - # trainer - result = OfflineTrainer( - policy=policy, - buffer=replay_buffer, - test_collector=test_collector, - max_epoch=args.epoch, - step_per_epoch=args.step_per_epoch, - episode_per_test=args.test_num, - batch_size=args.batch_size, - save_best_fn=save_best_fn, - logger=logger, - ).run() + # train + result = algorithm.run_training( + OfflineTrainingConfig( + buffer=replay_buffer, + test_collector=test_collector, + max_epoch=args.epoch, + step_per_epoch=args.step_per_epoch, + episode_per_test=args.test_num, + batch_size=args.batch_size, + save_best_fn=save_best_fn, + logger=logger, + ) + ) pprint.pprint(result) 
else: watch() diff --git a/examples/offline/d4rl_cql.py b/examples/offline/d4rl_cql.py index a52962cba..3c1451f9a 100644 --- a/examples/offline/d4rl_cql.py +++ b/examples/offline/d4rl_cql.py @@ -15,7 +15,9 @@ from tianshou.env import SubprocVectorEnv from tianshou.policy import CQL from tianshou.policy.base import Algorithm -from tianshou.trainer import OfflineTrainer +from tianshou.policy.modelfree.sac import SACPolicy +from tianshou.policy.optim import AdamOptimizerFactory +from tianshou.trainer import OfflineTrainingConfig from tianshou.utils import TensorboardLogger, WandbLogger from tianshou.utils.net.common import Net from tianshou.utils.net.continuous import ActorProb, Critic @@ -256,7 +258,7 @@ def test_cql() -> None: unbounded=True, conditioned_sigma=True, ).to(args.device) - actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr) + actor_optim = AdamOptimizerFactory(lr=args.actor_lr) # critic network net_c1 = Net( @@ -274,9 +276,9 @@ def test_cql() -> None: device=args.device, ) critic = Critic(net_c1, device=args.device).to(args.device) - critic_optim = torch.optim.Adam(critic.parameters(), lr=args.critic_lr) + critic_optim = AdamOptimizerFactory(lr=args.critic_lr) critic2 = Critic(net_c2, device=args.device).to(args.device) - critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr) + critic2_optim = AdamOptimizerFactory(lr=args.critic_lr) if args.auto_alpha: target_entropy = -args.action_dim @@ -284,12 +286,15 @@ def test_cql() -> None: alpha_optim = torch.optim.Adam([log_alpha], lr=args.alpha_lr) args.alpha = (target_entropy, log_alpha, alpha_optim) - policy: CQL = CQL( + policy = SACPolicy( actor=actor, + action_space=env.action_space, + ) + algorithm = CQL( + policy=policy, policy_optim=actor_optim, critic=critic, critic_optim=critic_optim, - action_space=env.action_space, critic2=critic2, critic2_optim=critic2_optim, calibrated=args.calibrated, @@ -303,16 +308,15 @@ def test_cql() -> None: 
lagrange_threshold=args.lagrange_threshold, min_action=args.min_action, max_action=args.max_action, - device=args.device, - ) + ).to(args.device) # load a previous policy if args.resume_path: - policy.load_state_dict(torch.load(args.resume_path, map_location=args.device)) + algorithm.load_state_dict(torch.load(args.resume_path, map_location=args.device)) print("Loaded agent from: ", args.resume_path) # collector - test_collector = Collector[CollectStats](policy, test_envs) + test_collector = Collector[CollectStats](algorithm, test_envs) # log now = datetime.datetime.now().strftime("%y%m%d-%H%M%S") @@ -343,24 +347,25 @@ def watch() -> None: if args.resume_path is None: args.resume_path = os.path.join(log_path, "policy.pth") - policy.load_state_dict(torch.load(args.resume_path, map_location=torch.device("cpu"))) - collector = Collector[CollectStats](policy, env) + algorithm.load_state_dict(torch.load(args.resume_path, map_location=torch.device("cpu"))) + collector = Collector[CollectStats](algorithm, env) collector.collect(n_episode=1, render=1 / 35) if not args.watch: replay_buffer = load_buffer_d4rl(args.expert_data_task) - # trainer - result = OfflineTrainer( - policy=policy, - buffer=replay_buffer, - test_collector=test_collector, - max_epoch=args.epoch, - step_per_epoch=args.step_per_epoch, - episode_per_test=args.test_num, - batch_size=args.batch_size, - save_best_fn=save_best_fn, - logger=logger, - ).run() + # train + result = algorithm.run_training( + OfflineTrainingConfig( + buffer=replay_buffer, + test_collector=test_collector, + max_epoch=args.epoch, + step_per_epoch=args.step_per_epoch, + episode_per_test=args.test_num, + batch_size=args.batch_size, + save_best_fn=save_best_fn, + logger=logger, + ) + ) pprint.pprint(result) else: watch() diff --git a/examples/offline/d4rl_il.py b/examples/offline/d4rl_il.py index 59b5d996a..af9908fd9 100644 --- a/examples/offline/d4rl_il.py +++ b/examples/offline/d4rl_il.py @@ -13,9 +13,10 @@ from examples.offline.utils 
import load_buffer_d4rl from tianshou.data import Collector, CollectStats from tianshou.env import SubprocVectorEnv -from tianshou.policy import ImitationLearning from tianshou.policy.base import Algorithm -from tianshou.trainer import OfflineTrainer +from tianshou.policy.imitation.base import ImitationPolicy, OfflineImitationLearning +from tianshou.policy.optim import AdamOptimizerFactory +from tianshou.trainer import OfflineTrainingConfig from tianshou.utils import TensorboardLogger, WandbLogger from tianshou.utils.net.common import Net from tianshou.utils.net.continuous import Actor @@ -94,23 +95,26 @@ def test_il() -> None: max_action=args.max_action, device=args.device, ).to(args.device) - optim = torch.optim.Adam(actor.parameters(), lr=args.lr) + optim = AdamOptimizerFactory(lr=args.lr) - policy: ImitationLearning = ImitationLearning( + policy = ImitationPolicy( actor=actor, - optim=optim, action_space=env.action_space, action_scaling=True, action_bound_method="clip", ) + algorithm = OfflineImitationLearning( + policy=policy, + optim=optim, + ) # load a previous policy if args.resume_path: - policy.load_state_dict(torch.load(args.resume_path, map_location=args.device)) + algorithm.load_state_dict(torch.load(args.resume_path, map_location=args.device)) print("Loaded agent from: ", args.resume_path) # collector - test_collector = Collector[CollectStats](policy, test_envs) + test_collector = Collector[CollectStats](algorithm, test_envs) # log now = datetime.datetime.now().strftime("%y%m%d-%H%M%S") @@ -141,24 +145,25 @@ def watch() -> None: if args.resume_path is None: args.resume_path = os.path.join(log_path, "policy.pth") - policy.load_state_dict(torch.load(args.resume_path, map_location=torch.device("cpu"))) - collector = Collector[CollectStats](policy, env) + algorithm.load_state_dict(torch.load(args.resume_path, map_location=torch.device("cpu"))) + collector = Collector[CollectStats](algorithm, env) collector.collect(n_episode=1, render=1 / 35) if not 
args.watch: replay_buffer = load_buffer_d4rl(args.expert_data_task) - # trainer - result = OfflineTrainer( - policy=policy, - buffer=replay_buffer, - test_collector=test_collector, - max_epoch=args.epoch, - step_per_epoch=args.step_per_epoch, - episode_per_test=args.test_num, - batch_size=args.batch_size, - save_best_fn=save_best_fn, - logger=logger, - ).run() + # train + result = algorithm.run_training( + OfflineTrainingConfig( + buffer=replay_buffer, + test_collector=test_collector, + max_epoch=args.epoch, + step_per_epoch=args.step_per_epoch, + episode_per_test=args.test_num, + batch_size=args.batch_size, + save_best_fn=save_best_fn, + logger=logger, + ) + ) pprint.pprint(result) else: watch() diff --git a/examples/offline/d4rl_td3_bc.py b/examples/offline/d4rl_td3_bc.py index ed58fb179..07ed5b75d 100644 --- a/examples/offline/d4rl_td3_bc.py +++ b/examples/offline/d4rl_td3_bc.py @@ -16,7 +16,9 @@ from tianshou.exploration import GaussianNoise from tianshou.policy import TD3BC from tianshou.policy.base import Algorithm -from tianshou.trainer import OfflineTrainer +from tianshou.policy.modelfree.ddpg import DDPGPolicy +from tianshou.policy.optim import AdamOptimizerFactory +from tianshou.trainer import OfflineTrainingConfig from tianshou.utils import TensorboardLogger, WandbLogger from tianshou.utils.net.common import Net from tianshou.utils.net.continuous import Actor, Critic @@ -113,7 +115,7 @@ def test_td3_bc() -> None: max_action=args.max_action, device=args.device, ).to(args.device) - actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr) + actor_optim = AdamOptimizerFactory(lr=args.actor_lr) # critic network net_c1 = Net( @@ -131,12 +133,17 @@ def test_td3_bc() -> None: device=args.device, ) critic1 = Critic(net_c1, device=args.device).to(args.device) - critic1_optim = torch.optim.Adam(critic1.parameters(), lr=args.critic_lr) + critic1_optim = AdamOptimizerFactory(lr=args.critic_lr) critic2 = Critic(net_c2, device=args.device).to(args.device) - 
critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr) + critic2_optim = AdamOptimizerFactory(lr=args.critic_lr) - policy: TD3BC = TD3BC( + policy = DDPGPolicy( actor=actor, + exploration_noise=GaussianNoise(sigma=args.exploration_noise), + action_space=env.action_space, + ) + algorithm: TD3BC = TD3BC( + policy=policy, policy_optim=actor_optim, critic=critic1, critic_optim=critic1_optim, @@ -144,22 +151,20 @@ def test_td3_bc() -> None: critic2_optim=critic2_optim, tau=args.tau, gamma=args.gamma, - exploration_noise=GaussianNoise(sigma=args.exploration_noise), policy_noise=args.policy_noise, update_actor_freq=args.update_actor_freq, noise_clip=args.noise_clip, alpha=args.alpha, estimation_step=args.n_step, - action_space=env.action_space, ) # load a previous policy if args.resume_path: - policy.load_state_dict(torch.load(args.resume_path, map_location=args.device)) + algorithm.load_state_dict(torch.load(args.resume_path, map_location=args.device)) print("Loaded agent from: ", args.resume_path) # collector - test_collector = Collector[CollectStats](policy, test_envs) + test_collector = Collector[CollectStats](algorithm, test_envs) # log now = datetime.datetime.now().strftime("%y%m%d-%H%M%S") @@ -190,8 +195,8 @@ def watch() -> None: if args.resume_path is None: args.resume_path = os.path.join(log_path, "policy.pth") - policy.load_state_dict(torch.load(args.resume_path, map_location=torch.device("cpu"))) - collector = Collector[CollectStats](policy, env) + algorithm.load_state_dict(torch.load(args.resume_path, map_location=torch.device("cpu"))) + collector = Collector[CollectStats](algorithm, env) collector.collect(n_episode=1, render=1 / 35) if not args.watch: @@ -199,18 +204,19 @@ def watch() -> None: if args.norm_obs: replay_buffer, obs_rms = normalize_all_obs_in_replay_buffer(replay_buffer) test_envs.set_obs_rms(obs_rms) - # trainer - result = OfflineTrainer( - policy=policy, - buffer=replay_buffer, - test_collector=test_collector, - 
max_epoch=args.epoch, - step_per_epoch=args.step_per_epoch, - episode_per_test=args.test_num, - batch_size=args.batch_size, - save_best_fn=save_best_fn, - logger=logger, - ).run() + # train + result = algorithm.run_training( + OfflineTrainingConfig( + buffer=replay_buffer, + test_collector=test_collector, + max_epoch=args.epoch, + step_per_epoch=args.step_per_epoch, + episode_per_test=args.test_num, + batch_size=args.batch_size, + save_best_fn=save_best_fn, + logger=logger, + ) + ) pprint.pprint(result) else: watch() From 877023b9441a062172fe32f0f717af5314cec1ff Mon Sep 17 00:00:00 2001 From: Dominik Jain Date: Mon, 17 Mar 2025 16:12:20 +0100 Subject: [PATCH 066/230] v2: Adapt test_drqn --- test/discrete/test_drqn.py | 53 +++++++++++++++++++++----------------- 1 file changed, 29 insertions(+), 24 deletions(-) diff --git a/test/discrete/test_drqn.py b/test/discrete/test_drqn.py index 6e2395303..84f1d3001 100644 --- a/test/discrete/test_drqn.py +++ b/test/discrete/test_drqn.py @@ -10,7 +10,9 @@ from tianshou.env import DummyVectorEnv from tianshou.policy import DQN from tianshou.policy.base import Algorithm -from tianshou.trainer import OffPolicyTrainer +from tianshou.policy.modelfree.dqn import DQNPolicy +from tianshou.policy.optim import AdamOptimizerFactory +from tianshou.trainer import OffPolicyTrainingConfig from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Recurrent from tianshou.utils.space_info import SpaceInfo @@ -73,13 +75,16 @@ def test_drqn(args: argparse.Namespace = get_args()) -> None: net = Recurrent(args.layer_num, args.state_shape, args.action_shape, args.device).to( args.device, ) - optim = torch.optim.Adam(net.parameters(), lr=args.lr) - policy: DQN = DQN( + optim = AdamOptimizerFactory(lr=args.lr) + policy = DQNPolicy( model=net, + action_space=env.action_space, + ) + algorithm: DQN = DQN( + policy=policy, optim=optim, discount_factor=args.gamma, estimation_step=args.n_step, - action_space=env.action_space, 
target_update_freq=args.target_update_freq, ) # collector @@ -89,10 +94,9 @@ def test_drqn(args: argparse.Namespace = get_args()) -> None: stack_num=args.stack_num, ignore_obs_next=True, ) - train_collector = Collector[CollectStats](policy, train_envs, buffer, exploration_noise=True) + train_collector = Collector[CollectStats](algorithm, train_envs, buffer, exploration_noise=True) # the stack_num is for RNN training: sample framestack obs - test_collector = Collector[CollectStats](policy, test_envs, exploration_noise=True) - # policy.set_eps(1) + test_collector = Collector[CollectStats](algorithm, test_envs, exploration_noise=True) train_collector.reset() train_collector.collect(n_step=args.batch_size * args.training_num) # log @@ -112,21 +116,22 @@ def train_fn(epoch: int, env_step: int) -> None: def test_fn(epoch: int, env_step: int | None) -> None: policy.set_eps(args.eps_test) - # trainer - result = OffPolicyTrainer( - policy=policy, - train_collector=train_collector, - test_collector=test_collector, - max_epoch=args.epoch, - step_per_epoch=args.step_per_epoch, - step_per_collect=args.step_per_collect, - episode_per_test=args.test_num, - batch_size=args.batch_size, - update_per_step=args.update_per_step, - train_fn=train_fn, - test_fn=test_fn, - stop_fn=stop_fn, - save_best_fn=save_best_fn, - logger=logger, - ).run() + # train + result = algorithm.run_training( + OffPolicyTrainingConfig( + train_collector=train_collector, + test_collector=test_collector, + max_epoch=args.epoch, + step_per_epoch=args.step_per_epoch, + step_per_collect=args.step_per_collect, + episode_per_test=args.test_num, + batch_size=args.batch_size, + update_per_step=args.update_per_step, + train_fn=train_fn, + test_fn=test_fn, + stop_fn=stop_fn, + save_best_fn=save_best_fn, + logger=logger, + ) + ) assert stop_fn(result.best_reward) From b17e03d76b86a068bf6e3b1cf13fdfff2efde012 Mon Sep 17 00:00:00 2001 From: Dominik Jain Date: Mon, 17 Mar 2025 16:12:38 +0100 Subject: [PATCH 067/230] v2: 
Adapt remaining examples --- examples/box2d/acrobot_dualdqn.py | 56 +++++++++-------- examples/box2d/bipedal_bdq.py | 50 ++++++++------- examples/box2d/bipedal_hardcore_sac.py | 58 ++++++++++-------- examples/box2d/lunarlander_dqn.py | 52 +++++++++------- examples/box2d/mcc_sac.py | 56 +++++++++-------- examples/inverse/irl_gail.py | 74 ++++++++++++----------- examples/vizdoom/vizdoom_c51.py | 64 +++++++++++--------- examples/vizdoom/vizdoom_ppo.py | 84 ++++++++++++++------------ 8 files changed, 273 insertions(+), 221 deletions(-) diff --git a/examples/box2d/acrobot_dualdqn.py b/examples/box2d/acrobot_dualdqn.py index b5820d2d4..1ef780b58 100644 --- a/examples/box2d/acrobot_dualdqn.py +++ b/examples/box2d/acrobot_dualdqn.py @@ -11,7 +11,9 @@ from tianshou.env import DummyVectorEnv from tianshou.policy import DQN from tianshou.policy.base import Algorithm -from tianshou.trainer import OffPolicyTrainer +from tianshou.policy.modelfree.dqn import DQNPolicy +from tianshou.policy.optim import AdamOptimizerFactory +from tianshou.trainer import OffPolicyTrainingConfig from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Net from tianshou.utils.space_info import SpaceInfo @@ -73,24 +75,27 @@ def test_dqn(args: argparse.Namespace = get_args()) -> None: hidden_sizes=args.hidden_sizes, device=args.device, dueling_param=(Q_param, V_param), - ).to(args.device) - optim = torch.optim.Adam(net.parameters(), lr=args.lr) - policy: DQN = DQN( + ) + optim = AdamOptimizerFactory(lr=args.lr) + policy = DQNPolicy( model=net, - optim=optim, action_space=env.action_space, + ) + algorithm: DQN = DQN( + policy=policy, + optim=optim, discount_factor=args.gamma, estimation_step=args.n_step, target_update_freq=args.target_update_freq, - ) + ).to(args.device) # collector train_collector = Collector[CollectStats]( - policy, + algorithm, train_envs, VectorReplayBuffer(args.buffer_size, len(train_envs)), exploration_noise=True, ) - test_collector = 
Collector[CollectStats](policy, test_envs, exploration_noise=True) + test_collector = Collector[CollectStats](algorithm, test_envs, exploration_noise=True) # policy.set_eps(1) train_collector.reset() train_collector.collect(n_step=args.batch_size * args.training_num) @@ -122,23 +127,24 @@ def train_fn(epoch: int, env_step: int) -> None: def test_fn(epoch: int, env_step: int | None) -> None: policy.set_eps(args.eps_test) - # trainer - result = OffPolicyTrainer( - policy=policy, - train_collector=train_collector, - test_collector=test_collector, - max_epoch=args.epoch, - step_per_epoch=args.step_per_epoch, - step_per_collect=args.step_per_collect, - episode_per_test=args.test_num, - batch_size=args.batch_size, - update_per_step=args.update_per_step, - train_fn=train_fn, - test_fn=test_fn, - stop_fn=stop_fn, - save_best_fn=save_best_fn, - logger=logger, - ).run() + # train + result = algorithm.run_training( + OffPolicyTrainingConfig( + train_collector=train_collector, + test_collector=test_collector, + max_epoch=args.epoch, + step_per_epoch=args.step_per_epoch, + step_per_collect=args.step_per_collect, + episode_per_test=args.test_num, + batch_size=args.batch_size, + update_per_step=args.update_per_step, + train_fn=train_fn, + test_fn=test_fn, + stop_fn=stop_fn, + save_best_fn=save_best_fn, + logger=logger, + ) + ) assert stop_fn(result.best_reward) if __name__ == "__main__": diff --git a/examples/box2d/bipedal_bdq.py b/examples/box2d/bipedal_bdq.py index 97c4a4683..df454722d 100644 --- a/examples/box2d/bipedal_bdq.py +++ b/examples/box2d/bipedal_bdq.py @@ -12,7 +12,9 @@ from tianshou.env import ContinuousToDiscrete, SubprocVectorEnv from tianshou.policy import BDQN from tianshou.policy.base import Algorithm -from tianshou.trainer import OffPolicyTrainer +from tianshou.policy.modelfree.bdqn import BDQNPolicy +from tianshou.policy.optim import AdamOptimizerFactory +from tianshou.trainer import OffPolicyTrainingConfig from tianshou.utils import TensorboardLogger from 
tianshou.utils.net.common import BranchingNet @@ -100,22 +102,25 @@ def test_bdq(args: argparse.Namespace = get_args()) -> None: args.action_hidden_sizes, device=args.device, ).to(args.device) - optim = torch.optim.Adam(net.parameters(), lr=args.lr) - policy: BDQN = BDQN( + optim = AdamOptimizerFactory(lr=args.lr) + policy = BDQNPolicy( model=net, + action_space=env.action_space, # type: ignore[arg-type] # TODO: should `BranchingDQNPolicy` support also `MultiDiscrete` action spaces? + ) + algorithm: BDQN = BDQN( + policy=policy, optim=optim, discount_factor=args.gamma, - action_space=env.action_space, # type: ignore[arg-type] # TODO: should `BranchingDQNPolicy` support also `MultiDiscrete` action spaces? target_update_freq=args.target_update_freq, ) # collector train_collector = Collector[CollectStats]( - policy, + algorithm, train_envs, VectorReplayBuffer(args.buffer_size, len(train_envs)), exploration_noise=True, ) - test_collector = Collector[CollectStats](policy, test_envs, exploration_noise=False) + test_collector = Collector[CollectStats](algorithm, test_envs, exploration_noise=False) # policy.set_eps(1) train_collector.reset() train_collector.collect(n_step=args.batch_size * args.training_num) @@ -141,22 +146,23 @@ def test_fn(epoch: int, env_step: int | None) -> None: policy.set_eps(args.eps_test) # trainer - result = OffPolicyTrainer( - policy=policy, - train_collector=train_collector, - test_collector=test_collector, - max_epoch=args.epoch, - step_per_epoch=args.step_per_epoch, - step_per_collect=args.step_per_collect, - episode_per_test=args.test_num, - batch_size=args.batch_size, - update_per_step=args.update_per_step, - stop_fn=stop_fn, - train_fn=train_fn, - test_fn=test_fn, - save_best_fn=save_best_fn, - logger=logger, - ).run() + result = algorithm.run_training( + OffPolicyTrainingConfig( + train_collector=train_collector, + test_collector=test_collector, + max_epoch=args.epoch, + step_per_epoch=args.step_per_epoch, + 
step_per_collect=args.step_per_collect, + episode_per_test=args.test_num, + batch_size=args.batch_size, + update_per_step=args.update_per_step, + stop_fn=stop_fn, + train_fn=train_fn, + test_fn=test_fn, + save_best_fn=save_best_fn, + logger=logger, + ) + ) assert stop_fn(result.best_reward) if __name__ == "__main__": diff --git a/examples/box2d/bipedal_hardcore_sac.py b/examples/box2d/bipedal_hardcore_sac.py index d1a3d0475..9e88adba4 100644 --- a/examples/box2d/bipedal_hardcore_sac.py +++ b/examples/box2d/bipedal_hardcore_sac.py @@ -13,7 +13,9 @@ from tianshou.env import SubprocVectorEnv from tianshou.policy import SAC from tianshou.policy.base import Algorithm -from tianshou.trainer import OffPolicyTrainer +from tianshou.policy.modelfree.sac import AutoAlpha, SACPolicy +from tianshou.policy.optim import AdamOptimizerFactory +from tianshou.trainer import OffPolicyTrainingConfig from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Net from tianshou.utils.net.continuous import ActorProb, Critic @@ -115,7 +117,7 @@ def test_sac_bipedal(args: argparse.Namespace = get_args()) -> None: device=args.device, unbounded=True, ).to(args.device) - actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr) + actor_optim = AdamOptimizerFactory(lr=args.actor_lr) net_c1 = Net( state_shape=args.state_shape, @@ -125,7 +127,7 @@ def test_sac_bipedal(args: argparse.Namespace = get_args()) -> None: device=args.device, ) critic1 = Critic(net_c1, device=args.device).to(args.device) - critic1_optim = torch.optim.Adam(critic1.parameters(), lr=args.critic_lr) + critic1_optim = AdamOptimizerFactory(lr=args.critic_lr) net_c2 = Net( state_shape=args.state_shape, @@ -135,17 +137,21 @@ def test_sac_bipedal(args: argparse.Namespace = get_args()) -> None: device=args.device, ) critic2 = Critic(net_c2, device=args.device).to(args.device) - critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr) + critic2_optim = 
AdamOptimizerFactory(critic2.parameters(), lr=args.critic_lr) action_dim = space_info.action_info.action_dim if args.auto_alpha: target_entropy = -action_dim log_alpha = torch.zeros(1, requires_grad=True, device=args.device) alpha_optim = torch.optim.Adam([log_alpha], lr=args.alpha_lr) - args.alpha = (target_entropy, log_alpha, alpha_optim) + args.alpha = AutoAlpha(target_entropy, log_alpha, alpha_optim) - policy: SAC = SAC( + policy = SACPolicy( actor=actor, + action_space=env.action_space, + ) + algorithm: SAC = SAC( + policy=policy, policy_optim=actor_optim, critic=critic1, critic_optim=critic1_optim, @@ -155,21 +161,20 @@ def test_sac_bipedal(args: argparse.Namespace = get_args()) -> None: gamma=args.gamma, alpha=args.alpha, estimation_step=args.n_step, - action_space=env.action_space, ) # load a previous policy if args.resume_path: - policy.load_state_dict(torch.load(args.resume_path)) + algorithm.load_state_dict(torch.load(args.resume_path)) print("Loaded agent from: ", args.resume_path) # collector train_collector = Collector[CollectStats]( - policy, + algorithm, train_envs, VectorReplayBuffer(args.buffer_size, len(train_envs)), exploration_noise=True, ) - test_collector = Collector[CollectStats](policy, test_envs) + test_collector = Collector[CollectStats](algorithm, test_envs) # train_collector.collect(n_step=args.buffer_size) # log log_path = os.path.join(args.logdir, args.task, "sac") @@ -187,22 +192,23 @@ def stop_fn(mean_rewards: float) -> bool: return mean_rewards >= env.spec.reward_threshold return False - # trainer - result = OffPolicyTrainer( - policy=policy, - train_collector=train_collector, - test_collector=test_collector, - max_epoch=args.epoch, - step_per_epoch=args.step_per_epoch, - step_per_collect=args.step_per_collect, - episode_per_test=args.test_num, - batch_size=args.batch_size, - update_per_step=args.update_per_step, - test_in_train=False, - stop_fn=stop_fn, - save_best_fn=save_best_fn, - logger=logger, - ).run() + # train + result = 
algorithm.run_training( + OffPolicyTrainingConfig( + train_collector=train_collector, + test_collector=test_collector, + max_epoch=args.epoch, + step_per_epoch=args.step_per_epoch, + step_per_collect=args.step_per_collect, + episode_per_test=args.test_num, + batch_size=args.batch_size, + update_per_step=args.update_per_step, + test_in_train=False, + stop_fn=stop_fn, + save_best_fn=save_best_fn, + logger=logger, + ) + ) if __name__ == "__main__": pprint.pprint(result) diff --git a/examples/box2d/lunarlander_dqn.py b/examples/box2d/lunarlander_dqn.py index 332df7747..63e9b0a3a 100644 --- a/examples/box2d/lunarlander_dqn.py +++ b/examples/box2d/lunarlander_dqn.py @@ -11,7 +11,9 @@ from tianshou.env import DummyVectorEnv, SubprocVectorEnv from tianshou.policy import DQN from tianshou.policy.base import Algorithm -from tianshou.trainer import OffpolicyTrainer +from tianshou.policy.modelfree.dqn import DQNPolicy +from tianshou.policy.optim import AdamOptimizerFactory +from tianshou.trainer import OffPolicyTrainingConfig from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Net from tianshou.utils.space_info import SpaceInfo @@ -76,23 +78,26 @@ def test_dqn(args: argparse.Namespace = get_args()) -> None: device=args.device, dueling_param=(Q_param, V_param), ).to(args.device) - optim = torch.optim.Adam(net.parameters(), lr=args.lr) - policy: DQN = DQN( + optim = AdamOptimizerFactory(lr=args.lr) + policy = DQNPolicy( model=net, - optim=optim, action_space=env.action_space, + ) + algorithm: DQN = DQN( + policy=policy, + optim=optim, discount_factor=args.gamma, estimation_step=args.n_step, target_update_freq=args.target_update_freq, ) # collector train_collector = Collector[CollectStats]( - policy, + algorithm, train_envs, VectorReplayBuffer(args.buffer_size, len(train_envs)), exploration_noise=True, ) - test_collector = Collector[CollectStats](policy, test_envs, exploration_noise=True) + test_collector = Collector[CollectStats](algorithm, 
test_envs, exploration_noise=True) # policy.set_eps(1) train_collector.reset() train_collector.collect(n_step=args.batch_size * args.training_num) @@ -119,23 +124,24 @@ def train_fn(epoch: int, env_step: int) -> None: # exp decay def test_fn(epoch: int, env_step: int | None) -> None: policy.set_eps(args.eps_test) - # trainer - result = OffpolicyTrainer( - policy=policy, - train_collector=train_collector, - test_collector=test_collector, - max_epoch=args.epoch, - step_per_epoch=args.step_per_epoch, - step_per_collect=args.step_per_collect, - episode_per_test=args.test_num, - batch_size=args.batch_size, - update_per_step=args.update_per_step, - stop_fn=stop_fn, - train_fn=train_fn, - test_fn=test_fn, - save_best_fn=save_best_fn, - logger=logger, - ).run() + # train + result = algorithm.run_training( + OffPolicyTrainingConfig( + train_collector=train_collector, + test_collector=test_collector, + max_epoch=args.epoch, + step_per_epoch=args.step_per_epoch, + step_per_collect=args.step_per_collect, + episode_per_test=args.test_num, + batch_size=args.batch_size, + update_per_step=args.update_per_step, + stop_fn=stop_fn, + train_fn=train_fn, + test_fn=test_fn, + save_best_fn=save_best_fn, + logger=logger, + ) + ) assert stop_fn(result.best_reward) if __name__ == "__main__": diff --git a/examples/box2d/mcc_sac.py b/examples/box2d/mcc_sac.py index 21336ed3e..5e11e283c 100644 --- a/examples/box2d/mcc_sac.py +++ b/examples/box2d/mcc_sac.py @@ -12,7 +12,9 @@ from tianshou.exploration import OUNoise from tianshou.policy import SAC from tianshou.policy.base import Algorithm -from tianshou.trainer import OffPolicyTrainer +from tianshou.policy.modelfree.sac import AutoAlpha, SACPolicy +from tianshou.policy.optim import AdamOptimizerFactory +from tianshou.trainer import OffPolicyTrainingConfig from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Net from tianshou.utils.net.continuous import ActorProb, Critic @@ -68,7 +70,7 @@ def test_sac(args: 
argparse.Namespace = get_args()) -> None: # model net = Net(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device) actor = ActorProb(net, args.action_shape, device=args.device, unbounded=True).to(args.device) - actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr) + actor_optim = AdamOptimizerFactory(lr=args.actor_lr) net_c1 = Net( state_shape=args.state_shape, action_shape=args.action_shape, @@ -77,7 +79,7 @@ def test_sac(args: argparse.Namespace = get_args()) -> None: device=args.device, ) critic1 = Critic(net_c1, device=args.device).to(args.device) - critic1_optim = torch.optim.Adam(critic1.parameters(), lr=args.critic_lr) + critic1_optim = AdamOptimizerFactory(lr=args.critic_lr) net_c2 = Net( state_shape=args.state_shape, action_shape=args.action_shape, @@ -86,17 +88,22 @@ def test_sac(args: argparse.Namespace = get_args()) -> None: device=args.device, ) critic2 = Critic(net_c2, device=args.device).to(args.device) - critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr) + critic2_optim = AdamOptimizerFactory(lr=args.critic_lr) action_dim = space_info.action_info.action_dim if args.auto_alpha: target_entropy = -action_dim log_alpha = torch.zeros(1, requires_grad=True, device=args.device) alpha_optim = torch.optim.Adam([log_alpha], lr=args.alpha_lr) - args.alpha = (target_entropy, log_alpha, alpha_optim) + args.alpha = AutoAlpha(target_entropy, log_alpha, alpha_optim) - policy: SAC = SAC( + policy = SACPolicy( actor=actor, + exploration_noise=OUNoise(0.0, args.noise_std), + action_space=env.action_space, + ) + algorithm: SAC = SAC( + policy=policy, policy_optim=actor_optim, critic=critic1, critic_optim=critic1_optim, @@ -105,17 +112,15 @@ def test_sac(args: argparse.Namespace = get_args()) -> None: tau=args.tau, gamma=args.gamma, alpha=args.alpha, - exploration_noise=OUNoise(0.0, args.noise_std), - action_space=env.action_space, ) # collector train_collector = Collector[CollectStats]( - policy, + 
algorithm, train_envs, VectorReplayBuffer(args.buffer_size, len(train_envs)), exploration_noise=True, ) - test_collector = Collector[CollectStats](policy, test_envs) + test_collector = Collector[CollectStats](algorithm, test_envs) # train_collector.collect(n_step=args.buffer_size) # log log_path = os.path.join(args.logdir, args.task, "sac") @@ -133,21 +138,22 @@ def stop_fn(mean_rewards: float) -> bool: return mean_rewards >= env.spec.reward_threshold return False - # trainer - result = OffPolicyTrainer( - policy=policy, - train_collector=train_collector, - test_collector=test_collector, - max_epoch=args.epoch, - step_per_epoch=args.step_per_epoch, - step_per_collect=args.step_per_collect, - episode_per_test=args.test_num, - batch_size=args.batch_size, - update_per_step=args.update_per_step, - stop_fn=stop_fn, - save_best_fn=save_best_fn, - logger=logger, - ).run() + # train + result = algorithm.run_training( + OffPolicyTrainingConfig( + train_collector=train_collector, + test_collector=test_collector, + max_epoch=args.epoch, + step_per_epoch=args.step_per_epoch, + step_per_collect=args.step_per_collect, + episode_per_test=args.test_num, + batch_size=args.batch_size, + update_per_step=args.update_per_step, + stop_fn=stop_fn, + save_best_fn=save_best_fn, + logger=logger, + ) + ) assert stop_fn(result.best_reward) if __name__ == "__main__": diff --git a/examples/inverse/irl_gail.py b/examples/inverse/irl_gail.py index 366deee89..dd2dd04cd 100644 --- a/examples/inverse/irl_gail.py +++ b/examples/inverse/irl_gail.py @@ -12,7 +12,6 @@ import torch from torch import nn from torch.distributions import Distribution, Independent, Normal -from torch.optim.lr_scheduler import LambdaLR from torch.utils.tensorboard import SummaryWriter from tianshou.data import ( @@ -26,9 +25,11 @@ from tianshou.env import SubprocVectorEnv, VectorEnvNormObs from tianshou.policy import GAIL from tianshou.policy.base import Algorithm -from tianshou.trainer import OnPolicyTrainer +from 
tianshou.policy.modelfree.pg import ActorPolicy +from tianshou.policy.optim import AdamOptimizerFactory, LRSchedulerFactoryLinear +from tianshou.trainer import OnPolicyTrainingConfig from tianshou.utils import TensorboardLogger -from tianshou.utils.net.common import ActorCritic, Net +from tianshou.utils.net.common import Net from tianshou.utils.net.continuous import ActorProb, Critic from tianshou.utils.space_info import SpaceInfo @@ -154,7 +155,7 @@ def test_gail(args: argparse.Namespace = get_args()) -> None: torch.nn.init.zeros_(m.bias) m.weight.data.copy_(0.01 * m.weight.data) - optim = torch.optim.Adam(ActorCritic(actor, critic).parameters(), lr=args.lr) + optim = AdamOptimizerFactory(lr=args.lr) # discriminator net_d = Net( args.state_shape, @@ -170,14 +171,16 @@ def test_gail(args: argparse.Namespace = get_args()) -> None: # orthogonal initialization torch.nn.init.orthogonal_(m.weight, gain=np.sqrt(2)) torch.nn.init.zeros_(m.bias) - disc_optim = torch.optim.Adam(disc_net.parameters(), lr=args.disc_lr) + disc_optim = AdamOptimizerFactory(lr=args.disc_lr) - lr_scheduler = None if args.lr_decay: - # decay learning rate to 0 linearly - max_update_num = np.ceil(args.step_per_epoch / args.step_per_collect) * args.epoch - - lr_scheduler = LambdaLR(optim, lr_lambda=lambda epoch: 1 - epoch / max_update_num) + optim.with_lr_scheduler_factory( + LRSchedulerFactoryLinear( + num_epochs=args.epoch, + step_per_epoch=args.step_per_epoch, + step_per_collect=args.step_per_collect, + ) + ) def dist(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution: loc, scale = loc_scale @@ -205,11 +208,17 @@ def dist(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution: ) print("dataset loaded") - policy: GAIL = GAIL( + policy = ActorPolicy( actor=actor, + dist_fn=dist, + action_scaling=True, + action_bound_method=args.bound_action_method, + action_space=env.action_space, + ) + algorithm: GAIL = GAIL( + policy=policy, critic=critic, optim=optim, - dist_fn=dist, 
expert_buffer=expert_buffer, disc_net=disc_net, disc_optim=disc_optim, @@ -220,10 +229,6 @@ def dist(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution: vf_coef=args.vf_coef, ent_coef=args.ent_coef, reward_normalization=args.rew_norm, - action_scaling=True, - action_bound_method=args.bound_action_method, - lr_scheduler=lr_scheduler, - action_space=env.action_space, eps_clip=args.eps_clip, value_clip=args.value_clip, dual_clip=args.dual_clip, @@ -233,7 +238,7 @@ def dist(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution: # load a previous policy if args.resume_path: - policy.load_state_dict(torch.load(args.resume_path, map_location=args.device)) + algorithm.load_state_dict(torch.load(args.resume_path, map_location=args.device)) print("Loaded agent from: ", args.resume_path) # collector @@ -242,8 +247,8 @@ def dist(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution: buffer = VectorReplayBuffer(args.buffer_size, len(train_envs)) else: buffer = ReplayBuffer(args.buffer_size) - train_collector = Collector[CollectStats](policy, train_envs, buffer, exploration_noise=True) - test_collector = Collector[CollectStats](policy, test_envs) + train_collector = Collector[CollectStats](algorithm, train_envs, buffer, exploration_noise=True) + test_collector = Collector[CollectStats](algorithm, test_envs) # log t0 = datetime.datetime.now().strftime("%m%d_%H%M%S") log_file = f'seed_{args.seed}_{t0}-{args.task.replace("-", "_")}_gail' @@ -256,21 +261,22 @@ def save_best_fn(policy: Algorithm) -> None: torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth")) if not args.watch: - # trainer - result = OnPolicyTrainer( - policy=policy, - train_collector=train_collector, - test_collector=test_collector, - max_epoch=args.epoch, - step_per_epoch=args.step_per_epoch, - repeat_per_collect=args.repeat_per_collect, - episode_per_test=args.test_num, - batch_size=args.batch_size, - step_per_collect=args.step_per_collect, - save_best_fn=save_best_fn, - 
logger=logger, - test_in_train=False, - ).run() + # train + result = algorithm.run_training( + OnPolicyTrainingConfig( + train_collector=train_collector, + test_collector=test_collector, + max_epoch=args.epoch, + step_per_epoch=args.step_per_epoch, + repeat_per_collect=args.repeat_per_collect, + episode_per_test=args.test_num, + batch_size=args.batch_size, + step_per_collect=args.step_per_collect, + save_best_fn=save_best_fn, + logger=logger, + test_in_train=False, + ) + ) pprint.pprint(result) # Let's watch its performance! diff --git a/examples/vizdoom/vizdoom_c51.py b/examples/vizdoom/vizdoom_c51.py index c81cb4fdf..2d492392b 100644 --- a/examples/vizdoom/vizdoom_c51.py +++ b/examples/vizdoom/vizdoom_c51.py @@ -13,7 +13,9 @@ from tianshou.highlevel.logger import LoggerFactoryDefault from tianshou.policy import C51 from tianshou.policy.base import Algorithm -from tianshou.trainer import OffpolicyTrainer +from tianshou.policy.modelfree.c51 import C51Policy +from tianshou.policy.optim import AdamOptimizerFactory +from tianshou.trainer import OffPolicyTrainingConfig def get_args() -> argparse.Namespace: @@ -93,22 +95,25 @@ def test_c51(args: argparse.Namespace = get_args()) -> None: torch.manual_seed(args.seed) # define model net = C51Net(*args.state_shape, args.action_shape, args.num_atoms, args.device) - optim = torch.optim.Adam(net.parameters(), lr=args.lr) - # define policy - policy: C51 = C51( + optim = AdamOptimizerFactory(lr=args.lr) + # define policy and algorithm + policy = C51Policy( model=net, - optim=optim, - discount_factor=args.gamma, action_space=env.action_space, num_atoms=args.num_atoms, v_min=args.v_min, v_max=args.v_max, + ) + algorithm: C51 = C51( + policy=policy, + optim=optim, + discount_factor=args.gamma, estimation_step=args.n_step, target_update_freq=args.target_update_freq, ).to(args.device) # load a previous policy if args.resume_path: - policy.load_state_dict(torch.load(args.resume_path, map_location=args.device)) + 
algorithm.load_state_dict(torch.load(args.resume_path, map_location=args.device)) print("Loaded agent from: ", args.resume_path) # replay buffer: `save_last_obs` and `stack_num` can be removed together # when you have enough RAM @@ -120,8 +125,8 @@ def test_c51(args: argparse.Namespace = get_args()) -> None: stack_num=args.frames_stack, ) # collector - train_collector = Collector[CollectStats](policy, train_envs, buffer, exploration_noise=True) - test_collector = Collector[CollectStats](policy, test_envs, exploration_noise=True) + train_collector = Collector[CollectStats](algorithm, train_envs, buffer, exploration_noise=True) + test_collector = Collector[CollectStats](algorithm, test_envs, exploration_noise=True) # log now = datetime.datetime.now().strftime("%y%m%d-%H%M%S") @@ -179,7 +184,9 @@ def watch() -> None: save_only_last_obs=True, stack_num=args.frames_stack, ) - collector = Collector[CollectStats](policy, test_envs, buffer, exploration_noise=True) + collector = Collector[CollectStats]( + algorithm, test_envs, buffer, exploration_noise=True + ) result = collector.collect(n_step=args.buffer_size, reset_before_collect=True) print(f"Save buffer into {args.save_buffer_name}") # Unfortunately, pickle will cause oom with 1M buffer size @@ -197,24 +204,25 @@ def watch() -> None: # test train_collector and start filling replay buffer train_collector.reset() train_collector.collect(n_step=args.batch_size * args.training_num) - # trainer - result = OffpolicyTrainer( - policy=policy, - train_collector=train_collector, - test_collector=test_collector, - max_epoch=args.epoch, - step_per_epoch=args.step_per_epoch, - step_per_collect=args.step_per_collect, - episode_per_test=args.test_num, - batch_size=args.batch_size, - train_fn=train_fn, - test_fn=test_fn, - stop_fn=stop_fn, - save_best_fn=save_best_fn, - logger=logger, - update_per_step=args.update_per_step, - test_in_train=False, - ).run() + # train + result = algorithm.run_training( + OffPolicyTrainingConfig( + 
train_collector=train_collector, + test_collector=test_collector, + max_epoch=args.epoch, + step_per_epoch=args.step_per_epoch, + step_per_collect=args.step_per_collect, + episode_per_test=args.test_num, + batch_size=args.batch_size, + train_fn=train_fn, + test_fn=test_fn, + stop_fn=stop_fn, + save_best_fn=save_best_fn, + logger=logger, + update_per_step=args.update_per_step, + test_in_train=False, + ) + ) pprint.pprint(result) watch() diff --git a/examples/vizdoom/vizdoom_ppo.py b/examples/vizdoom/vizdoom_ppo.py index 3f76614d8..0154a14f5 100644 --- a/examples/vizdoom/vizdoom_ppo.py +++ b/examples/vizdoom/vizdoom_ppo.py @@ -8,7 +8,6 @@ import torch from env import make_vizdoom_env from torch.distributions import Categorical -from torch.optim.lr_scheduler import LambdaLR from tianshou.data import Collector, CollectStats, VectorReplayBuffer from tianshou.env.atari.atari_network import DQNet @@ -16,8 +15,9 @@ from tianshou.policy import PPO from tianshou.policy.base import Algorithm from tianshou.policy.modelbased.icm import ICMOnPolicyWrapper -from tianshou.trainer import OnPolicyTrainer -from tianshou.utils.net.common import ActorCritic +from tianshou.policy.modelfree.pg import ActorPolicy +from tianshou.policy.optim import AdamOptimizerFactory, LRSchedulerFactoryLinear +from tianshou.trainer import OnPolicyTrainingConfig from tianshou.utils.net.discrete import Actor, Critic, IntrinsicCuriosityModule @@ -128,33 +128,37 @@ def test_ppo(args: argparse.Namespace = get_args()) -> None: ) actor = Actor(net, args.action_shape, device=args.device, softmax_output=False) critic = Critic(net, device=args.device) - optim = torch.optim.Adam(ActorCritic(actor, critic).parameters(), lr=args.lr) + optim = AdamOptimizerFactory(lr=args.lr) - lr_scheduler = None if args.lr_decay: - # decay learning rate to 0 linearly - max_update_num = np.ceil(args.step_per_epoch / args.step_per_collect) * args.epoch - - lr_scheduler = LambdaLR(optim, lr_lambda=lambda epoch: 1 - epoch / 
max_update_num) + optim.with_lr_scheduler_factory( + LRSchedulerFactoryLinear( + num_epochs=args.epoch, + step_per_epoch=args.step_per_epoch, + step_per_collect=args.step_per_collect, + ) + ) - # define policy def dist(logits: torch.Tensor) -> Categorical: return Categorical(logits=logits) - policy: PPO = PPO( + # define policy and algorithm + policy = ActorPolicy( actor=actor, + dist_fn=dist, + action_scaling=False, + action_space=env.action_space, + ) + algorithm = PPO( + policy=policy, critic=critic, optim=optim, - dist_fn=dist, discount_factor=args.gamma, gae_lambda=args.gae_lambda, max_grad_norm=args.max_grad_norm, vf_coef=args.vf_coef, ent_coef=args.ent_coef, reward_normalization=args.rew_norm, - action_scaling=False, - lr_scheduler=lr_scheduler, - action_space=env.action_space, eps_clip=args.eps_clip, value_clip=args.value_clip, dual_clip=args.dual_clip, @@ -177,9 +181,9 @@ def dist(logits: torch.Tensor) -> Categorical: action_dim, device=args.device, ) - icm_optim = torch.optim.Adam(icm_net.parameters(), lr=args.lr) - policy = ICMOnPolicyWrapper( # type: ignore[no-redef] - wrapped_algorithm=policy, + icm_optim = AdamOptimizerFactory(lr=args.lr) + algorithm = ICMOnPolicyWrapper( + wrapped_algorithm=algorithm, model=icm_net, optim=icm_optim, lr_scale=args.icm_lr_scale, @@ -188,7 +192,7 @@ def dist(logits: torch.Tensor) -> Categorical: ).to(args.device) # load a previous policy if args.resume_path: - policy.load_state_dict(torch.load(args.resume_path, map_location=args.device)) + algorithm.load_state_dict(torch.load(args.resume_path, map_location=args.device)) print("Loaded agent from: ", args.resume_path) # replay buffer: `save_last_obs` and `stack_num` can be removed together # when you have enough RAM @@ -200,8 +204,8 @@ def dist(logits: torch.Tensor) -> Categorical: stack_num=args.frames_stack, ) # collector - train_collector = Collector[CollectStats](policy, train_envs, buffer, exploration_noise=True) - test_collector = Collector[CollectStats](policy, 
test_envs, exploration_noise=True) + train_collector = Collector[CollectStats](algorithm, train_envs, buffer, exploration_noise=True) + test_collector = Collector[CollectStats](algorithm, test_envs, exploration_noise=True) # log now = datetime.datetime.now().strftime("%y%m%d-%H%M%S") @@ -245,7 +249,9 @@ def watch() -> None: save_only_last_obs=True, stack_num=args.frames_stack, ) - collector = Collector[CollectStats](policy, test_envs, buffer, exploration_noise=True) + collector = Collector[CollectStats]( + algorithm, test_envs, buffer, exploration_noise=True + ) result = collector.collect(n_step=args.buffer_size, reset_before_collect=True) print(f"Save buffer into {args.save_buffer_name}") # Unfortunately, pickle will cause oom with 1M buffer size @@ -263,22 +269,24 @@ def watch() -> None: # test train_collector and start filling replay buffer train_collector.reset() train_collector.collect(n_step=args.batch_size * args.training_num) - # trainer - result = OnPolicyTrainer( - policy=policy, - train_collector=train_collector, - test_collector=test_collector, - max_epoch=args.epoch, - step_per_epoch=args.step_per_epoch, - repeat_per_collect=args.repeat_per_collect, - episode_per_test=args.test_num, - batch_size=args.batch_size, - step_per_collect=args.step_per_collect, - stop_fn=stop_fn, - save_best_fn=save_best_fn, - logger=logger, - test_in_train=False, - ).run() + + # train + result = algorithm.run_training( + OnPolicyTrainingConfig( + train_collector=train_collector, + test_collector=test_collector, + max_epoch=args.epoch, + step_per_epoch=args.step_per_epoch, + repeat_per_collect=args.repeat_per_collect, + episode_per_test=args.test_num, + batch_size=args.batch_size, + step_per_collect=args.step_per_collect, + stop_fn=stop_fn, + save_best_fn=save_best_fn, + logger=logger, + test_in_train=False, + ) + ) pprint.pprint(result) watch() From 5aa8b3a980fdfdee6aeded0b8f54e20642bbf449 Mon Sep 17 00:00:00 2001 From: Dominik Jain Date: Mon, 17 Mar 2025 16:12:51 +0100 
Subject: [PATCH 068/230] v2: Rename *TrainingConfig classes to *TrainerParams in order to free the name for the high-level API --- CHANGELOG.md | 2 +- examples/atari/atari_c51.py | 4 +- examples/atari/atari_dqn.py | 4 +- examples/atari/atari_fqf.py | 4 +- examples/atari/atari_iqn.py | 4 +- examples/atari/atari_ppo.py | 4 +- examples/atari/atari_qrdqn.py | 4 +- examples/atari/atari_rainbow.py | 4 +- examples/atari/atari_sac.py | 4 +- examples/box2d/acrobot_dualdqn.py | 4 +- examples/box2d/bipedal_bdq.py | 4 +- examples/box2d/bipedal_hardcore_sac.py | 4 +- examples/box2d/lunarlander_dqn.py | 4 +- examples/box2d/mcc_sac.py | 4 +- examples/discrete/discrete_dqn.py | 4 +- examples/inverse/irl_gail.py | 4 +- examples/mujoco/fetch_her_ddpg.py | 4 +- examples/mujoco/mujoco_a2c.py | 4 +- examples/mujoco/mujoco_ddpg.py | 4 +- examples/mujoco/mujoco_npg.py | 4 +- examples/mujoco/mujoco_ppo.py | 4 +- examples/mujoco/mujoco_redq.py | 4 +- examples/mujoco/mujoco_reinforce.py | 4 +- examples/mujoco/mujoco_sac.py | 4 +- examples/mujoco/mujoco_td3.py | 4 +- examples/mujoco/mujoco_trpo.py | 4 +- examples/offline/atari_bcq.py | 4 +- examples/offline/atari_cql.py | 4 +- examples/offline/atari_crr.py | 4 +- examples/offline/atari_il.py | 4 +- examples/offline/d4rl_bcq.py | 4 +- examples/offline/d4rl_cql.py | 4 +- examples/offline/d4rl_il.py | 4 +- examples/offline/d4rl_td3_bc.py | 4 +- examples/vizdoom/vizdoom_c51.py | 4 +- examples/vizdoom/vizdoom_ppo.py | 4 +- test/continuous/test_ddpg.py | 4 +- test/continuous/test_npg.py | 4 +- test/continuous/test_ppo.py | 4 +- test/continuous/test_redq.py | 4 +- test/continuous/test_sac_with_il.py | 6 +- test/continuous/test_td3.py | 4 +- test/continuous/test_trpo.py | 4 +- test/discrete/test_a2c_with_il.py | 6 +- test/discrete/test_bdqn.py | 4 +- test/discrete/test_c51.py | 4 +- test/discrete/test_discrete_sac.py | 4 +- test/discrete/test_dqn.py | 4 +- test/discrete/test_drqn.py | 4 +- test/discrete/test_fqf.py | 4 +- test/discrete/test_iqn.py | 
4 +- test/discrete/test_pg.py | 4 +- test/discrete/test_ppo.py | 4 +- test/discrete/test_qrdqn.py | 4 +- test/discrete/test_rainbow.py | 4 +- test/modelbased/test_dqn_icm.py | 4 +- test/modelbased/test_ppo_icm.py | 4 +- test/modelbased/test_psrl.py | 4 +- test/offline/gather_cartpole_data.py | 4 +- test/offline/gather_pendulum_data.py | 4 +- test/offline/test_bcq.py | 4 +- test/offline/test_cql.py | 4 +- test/offline/test_discrete_bcq.py | 4 +- test/offline/test_discrete_cql.py | 4 +- test/offline/test_discrete_crr.py | 4 +- test/offline/test_gail.py | 4 +- test/offline/test_td3_bc.py | 4 +- test/pettingzoo/pistonball.py | 4 +- test/pettingzoo/pistonball_continuous.py | 4 +- test/pettingzoo/tic_tac_toe.py | 4 +- tianshou/highlevel/algorithm.py | 6 +- tianshou/policy/base.py | 41 +++--- tianshou/trainer/__init__.py | 6 +- tianshou/trainer/base.py | 160 +++++++++++------------ tianshou/utils/logger/wandb.py | 2 - 75 files changed, 247 insertions(+), 250 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 06df556bb..6d5515a1c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -38,7 +38,7 @@ Members of `InfoStats` and parameters of `Logger` (and subclasses) were changed accordingly. * Migration information at a glance: * Training parameters are now passed via instances of configuration objects instead of directly as keyword arguments: - `OnPolicyTrainingConfig`, `OffPolicyTrainingConfig`, `OfflineTrainingConfig`. + `OnPolicyTrainerParams`, `OffPolicyTrainerParams`, `OfflineTrainerParams`. 
* Trainer classes have been renamed: * `OnpolicyTrainer` -> `OnPolicyTrainer` * `OffpolicyTrainer` -> `OffPolicyTrainer` diff --git a/examples/atari/atari_c51.py b/examples/atari/atari_c51.py index 4b8eb5364..cf4725b47 100644 --- a/examples/atari/atari_c51.py +++ b/examples/atari/atari_c51.py @@ -15,7 +15,7 @@ from tianshou.policy.base import Algorithm from tianshou.policy.modelfree.c51 import C51Policy from tianshou.policy.optim import AdamOptimizerFactory -from tianshou.trainer.base import OffPolicyTrainingConfig +from tianshou.trainer.base import OffPolicyTrainerParams def get_args() -> argparse.Namespace: @@ -204,7 +204,7 @@ def watch() -> None: train_collector.collect(n_step=args.batch_size * args.training_num) # trainer result = algorithm.run_training( - OffPolicyTrainingConfig( + OffPolicyTrainerParams( train_collector=train_collector, test_collector=test_collector, max_epoch=args.epoch, diff --git a/examples/atari/atari_dqn.py b/examples/atari/atari_dqn.py index 85b48bc2c..901bcf64f 100644 --- a/examples/atari/atari_dqn.py +++ b/examples/atari/atari_dqn.py @@ -16,7 +16,7 @@ from tianshou.policy.modelbased.icm import ICMOffPolicyWrapper from tianshou.policy.modelfree.dqn import DQNPolicy from tianshou.policy.optim import AdamOptimizerFactory -from tianshou.trainer import OffPolicyTrainingConfig +from tianshou.trainer import OffPolicyTrainerParams from tianshou.utils.net.discrete import IntrinsicCuriosityModule @@ -248,7 +248,7 @@ def watch() -> None: # train result = algorithm.run_training( - OffPolicyTrainingConfig( + OffPolicyTrainerParams( train_collector=train_collector, test_collector=test_collector, max_epoch=args.epoch, diff --git a/examples/atari/atari_fqf.py b/examples/atari/atari_fqf.py index 31014e599..b55e9804b 100644 --- a/examples/atari/atari_fqf.py +++ b/examples/atari/atari_fqf.py @@ -15,7 +15,7 @@ from tianshou.policy.base import Algorithm from tianshou.policy.modelfree.fqf import FQFPolicy from tianshou.policy.optim import 
AdamOptimizerFactory, RMSpropOptimizerFactory -from tianshou.trainer import OffPolicyTrainingConfig +from tianshou.trainer import OffPolicyTrainerParams from tianshou.utils.net.discrete import FractionProposalNetwork, FullQuantileFunction @@ -220,7 +220,7 @@ def watch() -> None: # train result = algorithm.run_training( - OffPolicyTrainingConfig( + OffPolicyTrainerParams( train_collector=train_collector, test_collector=test_collector, max_epoch=args.epoch, diff --git a/examples/atari/atari_iqn.py b/examples/atari/atari_iqn.py index c662ab146..3bcbd133a 100644 --- a/examples/atari/atari_iqn.py +++ b/examples/atari/atari_iqn.py @@ -15,7 +15,7 @@ from tianshou.policy.base import Algorithm from tianshou.policy.modelfree.iqn import IQNPolicy from tianshou.policy.optim import AdamOptimizerFactory -from tianshou.trainer import OffPolicyTrainingConfig +from tianshou.trainer import OffPolicyTrainerParams from tianshou.utils.net.discrete import ImplicitQuantileNetwork @@ -216,7 +216,7 @@ def watch() -> None: # train result = algorithm.run_training( - OffPolicyTrainingConfig( + OffPolicyTrainerParams( train_collector=train_collector, test_collector=test_collector, max_epoch=args.epoch, diff --git a/examples/atari/atari_ppo.py b/examples/atari/atari_ppo.py index cc82f589c..65e0d9d49 100644 --- a/examples/atari/atari_ppo.py +++ b/examples/atari/atari_ppo.py @@ -16,7 +16,7 @@ from tianshou.policy.modelbased.icm import ICMOnPolicyWrapper from tianshou.policy.modelfree.pg import DiscreteActorPolicy from tianshou.policy.optim import AdamOptimizerFactory, LRSchedulerFactoryLinear -from tianshou.trainer import OnPolicyTrainingConfig +from tianshou.trainer import OnPolicyTrainerParams from tianshou.utils.net.discrete import Actor, Critic, IntrinsicCuriosityModule @@ -263,7 +263,7 @@ def watch() -> None: # train result = algorithm.run_training( - OnPolicyTrainingConfig( + OnPolicyTrainerParams( train_collector=train_collector, test_collector=test_collector, max_epoch=args.epoch, diff 
--git a/examples/atari/atari_qrdqn.py b/examples/atari/atari_qrdqn.py index 579387b23..cd17a6ec8 100644 --- a/examples/atari/atari_qrdqn.py +++ b/examples/atari/atari_qrdqn.py @@ -15,7 +15,7 @@ from tianshou.policy.base import Algorithm from tianshou.policy.modelfree.qrdqn import QRDQNPolicy from tianshou.policy.optim import AdamOptimizerFactory -from tianshou.trainer import OffPolicyTrainingConfig +from tianshou.trainer import OffPolicyTrainerParams def get_args() -> argparse.Namespace: @@ -211,7 +211,7 @@ def watch() -> None: # train result = algorithm.run_training( - OffPolicyTrainingConfig( + OffPolicyTrainerParams( train_collector=train_collector, test_collector=test_collector, max_epoch=args.epoch, diff --git a/examples/atari/atari_rainbow.py b/examples/atari/atari_rainbow.py index 151b03d8f..a6ea5fe88 100644 --- a/examples/atari/atari_rainbow.py +++ b/examples/atari/atari_rainbow.py @@ -20,7 +20,7 @@ from tianshou.policy.base import Algorithm from tianshou.policy.modelfree.c51 import C51Policy from tianshou.policy.optim import AdamOptimizerFactory -from tianshou.trainer import OffPolicyTrainingConfig +from tianshou.trainer import OffPolicyTrainerParams def get_args() -> argparse.Namespace: @@ -251,7 +251,7 @@ def watch() -> None: # train result = algorithm.run_training( - OffPolicyTrainingConfig( + OffPolicyTrainerParams( train_collector=train_collector, test_collector=test_collector, max_epoch=args.epoch, diff --git a/examples/atari/atari_sac.py b/examples/atari/atari_sac.py index 646308393..363c95c73 100644 --- a/examples/atari/atari_sac.py +++ b/examples/atari/atari_sac.py @@ -16,7 +16,7 @@ from tianshou.policy.modelfree.discrete_sac import DiscreteSACPolicy from tianshou.policy.modelfree.sac import AutoAlpha from tianshou.policy.optim import AdamOptimizerFactory -from tianshou.trainer import OffPolicyTrainingConfig +from tianshou.trainer import OffPolicyTrainerParams from tianshou.utils.net.discrete import Actor, Critic, IntrinsicCuriosityModule @@ 
-257,7 +257,7 @@ def watch() -> None: # train result = algorithm.run_training( - OffPolicyTrainingConfig( + OffPolicyTrainerParams( train_collector=train_collector, test_collector=test_collector, max_epoch=args.epoch, diff --git a/examples/box2d/acrobot_dualdqn.py b/examples/box2d/acrobot_dualdqn.py index 1ef780b58..570346122 100644 --- a/examples/box2d/acrobot_dualdqn.py +++ b/examples/box2d/acrobot_dualdqn.py @@ -13,7 +13,7 @@ from tianshou.policy.base import Algorithm from tianshou.policy.modelfree.dqn import DQNPolicy from tianshou.policy.optim import AdamOptimizerFactory -from tianshou.trainer import OffPolicyTrainingConfig +from tianshou.trainer import OffPolicyTrainerParams from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Net from tianshou.utils.space_info import SpaceInfo @@ -129,7 +129,7 @@ def test_fn(epoch: int, env_step: int | None) -> None: # train result = algorithm.run_training( - OffPolicyTrainingConfig( + OffPolicyTrainerParams( train_collector=train_collector, test_collector=test_collector, max_epoch=args.epoch, diff --git a/examples/box2d/bipedal_bdq.py b/examples/box2d/bipedal_bdq.py index df454722d..dea269d4a 100644 --- a/examples/box2d/bipedal_bdq.py +++ b/examples/box2d/bipedal_bdq.py @@ -14,7 +14,7 @@ from tianshou.policy.base import Algorithm from tianshou.policy.modelfree.bdqn import BDQNPolicy from tianshou.policy.optim import AdamOptimizerFactory -from tianshou.trainer import OffPolicyTrainingConfig +from tianshou.trainer import OffPolicyTrainerParams from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import BranchingNet @@ -147,7 +147,7 @@ def test_fn(epoch: int, env_step: int | None) -> None: # trainer result = algorithm.run_training( - OffPolicyTrainingConfig( + OffPolicyTrainerParams( train_collector=train_collector, test_collector=test_collector, max_epoch=args.epoch, diff --git a/examples/box2d/bipedal_hardcore_sac.py b/examples/box2d/bipedal_hardcore_sac.py index 
9e88adba4..9ccc48690 100644 --- a/examples/box2d/bipedal_hardcore_sac.py +++ b/examples/box2d/bipedal_hardcore_sac.py @@ -15,7 +15,7 @@ from tianshou.policy.base import Algorithm from tianshou.policy.modelfree.sac import AutoAlpha, SACPolicy from tianshou.policy.optim import AdamOptimizerFactory -from tianshou.trainer import OffPolicyTrainingConfig +from tianshou.trainer import OffPolicyTrainerParams from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Net from tianshou.utils.net.continuous import ActorProb, Critic @@ -194,7 +194,7 @@ def stop_fn(mean_rewards: float) -> bool: # train result = algorithm.run_training( - OffPolicyTrainingConfig( + OffPolicyTrainerParams( train_collector=train_collector, test_collector=test_collector, max_epoch=args.epoch, diff --git a/examples/box2d/lunarlander_dqn.py b/examples/box2d/lunarlander_dqn.py index 63e9b0a3a..fb558cf70 100644 --- a/examples/box2d/lunarlander_dqn.py +++ b/examples/box2d/lunarlander_dqn.py @@ -13,7 +13,7 @@ from tianshou.policy.base import Algorithm from tianshou.policy.modelfree.dqn import DQNPolicy from tianshou.policy.optim import AdamOptimizerFactory -from tianshou.trainer import OffPolicyTrainingConfig +from tianshou.trainer import OffPolicyTrainerParams from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Net from tianshou.utils.space_info import SpaceInfo @@ -126,7 +126,7 @@ def test_fn(epoch: int, env_step: int | None) -> None: # train result = algorithm.run_training( - OffPolicyTrainingConfig( + OffPolicyTrainerParams( train_collector=train_collector, test_collector=test_collector, max_epoch=args.epoch, diff --git a/examples/box2d/mcc_sac.py b/examples/box2d/mcc_sac.py index 5e11e283c..068bef1e0 100644 --- a/examples/box2d/mcc_sac.py +++ b/examples/box2d/mcc_sac.py @@ -14,7 +14,7 @@ from tianshou.policy.base import Algorithm from tianshou.policy.modelfree.sac import AutoAlpha, SACPolicy from tianshou.policy.optim import AdamOptimizerFactory 
-from tianshou.trainer import OffPolicyTrainingConfig +from tianshou.trainer import OffPolicyTrainerParams from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Net from tianshou.utils.net.continuous import ActorProb, Critic @@ -140,7 +140,7 @@ def stop_fn(mean_rewards: float) -> bool: # train result = algorithm.run_training( - OffPolicyTrainingConfig( + OffPolicyTrainerParams( train_collector=train_collector, test_collector=test_collector, max_epoch=args.epoch, diff --git a/examples/discrete/discrete_dqn.py b/examples/discrete/discrete_dqn.py index a2f6596a8..ff679c542 100644 --- a/examples/discrete/discrete_dqn.py +++ b/examples/discrete/discrete_dqn.py @@ -5,7 +5,7 @@ import tianshou as ts from tianshou.data import CollectStats from tianshou.policy.modelfree.dqn import DQNPolicy -from tianshou.trainer.base import OffPolicyTrainingConfig +from tianshou.trainer.base import OffPolicyTrainerParams from tianshou.utils.space_info import SpaceInfo @@ -66,7 +66,7 @@ def stop_fn(mean_rewards: float) -> bool: return False result = algorithm.run_training( - OffPolicyTrainingConfig( + OffPolicyTrainerParams( train_collector=train_collector, test_collector=test_collector, max_epoch=epoch, diff --git a/examples/inverse/irl_gail.py b/examples/inverse/irl_gail.py index dd2dd04cd..4a4b652f1 100644 --- a/examples/inverse/irl_gail.py +++ b/examples/inverse/irl_gail.py @@ -27,7 +27,7 @@ from tianshou.policy.base import Algorithm from tianshou.policy.modelfree.pg import ActorPolicy from tianshou.policy.optim import AdamOptimizerFactory, LRSchedulerFactoryLinear -from tianshou.trainer import OnPolicyTrainingConfig +from tianshou.trainer import OnPolicyTrainerParams from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Net from tianshou.utils.net.continuous import ActorProb, Critic @@ -263,7 +263,7 @@ def save_best_fn(policy: Algorithm) -> None: if not args.watch: # train result = algorithm.run_training( - 
OnPolicyTrainingConfig( + OnPolicyTrainerParams( train_collector=train_collector, test_collector=test_collector, max_epoch=args.epoch, diff --git a/examples/mujoco/fetch_her_ddpg.py b/examples/mujoco/fetch_her_ddpg.py index 887adccec..e1a3bc823 100644 --- a/examples/mujoco/fetch_her_ddpg.py +++ b/examples/mujoco/fetch_her_ddpg.py @@ -26,7 +26,7 @@ from tianshou.policy.base import Algorithm from tianshou.policy.modelfree.ddpg import DDPGPolicy from tianshou.policy.optim import AdamOptimizerFactory -from tianshou.trainer import OffPolicyTrainingConfig +from tianshou.trainer import OffPolicyTrainerParams from tianshou.utils.net.common import Net, get_dict_state_decorator from tianshou.utils.net.continuous import Actor, Critic from tianshou.utils.space_info import ActionSpaceInfo @@ -227,7 +227,7 @@ def save_best_fn(policy: Algorithm) -> None: if not args.watch: # train result = algorithm.run_training( - OffPolicyTrainingConfig( + OffPolicyTrainerParams( train_collector=train_collector, test_collector=test_collector, max_epoch=args.epoch, diff --git a/examples/mujoco/mujoco_a2c.py b/examples/mujoco/mujoco_a2c.py index 96f8d173f..3efbb9f8b 100755 --- a/examples/mujoco/mujoco_a2c.py +++ b/examples/mujoco/mujoco_a2c.py @@ -17,7 +17,7 @@ from tianshou.policy.base import Algorithm from tianshou.policy.modelfree.pg import ActorPolicy from tianshou.policy.optim import LRSchedulerFactoryLinear, RMSpropOptimizerFactory -from tianshou.trainer import OnPolicyTrainingConfig +from tianshou.trainer import OnPolicyTrainerParams from tianshou.utils.net.common import ActorCritic, Net from tianshou.utils.net.continuous import ActorProb, Critic @@ -207,7 +207,7 @@ def save_best_fn(policy: Algorithm) -> None: if not args.watch: # train result = algorithm.run_training( - OnPolicyTrainingConfig( + OnPolicyTrainerParams( train_collector=train_collector, test_collector=test_collector, max_epoch=args.epoch, diff --git a/examples/mujoco/mujoco_ddpg.py b/examples/mujoco/mujoco_ddpg.py index 
b76607e2b..888b206a1 100755 --- a/examples/mujoco/mujoco_ddpg.py +++ b/examples/mujoco/mujoco_ddpg.py @@ -16,7 +16,7 @@ from tianshou.policy.base import Algorithm from tianshou.policy.modelfree.ddpg import DDPGPolicy from tianshou.policy.optim import AdamOptimizerFactory -from tianshou.trainer import OffPolicyTrainingConfig +from tianshou.trainer import OffPolicyTrainerParams from tianshou.utils.net.common import Net from tianshou.utils.net.continuous import Actor, Critic @@ -157,7 +157,7 @@ def save_best_fn(policy: Algorithm) -> None: if not args.watch: # train result = algorithm.run_training( - OffPolicyTrainingConfig( + OffPolicyTrainerParams( train_collector=train_collector, test_collector=test_collector, max_epoch=args.epoch, diff --git a/examples/mujoco/mujoco_npg.py b/examples/mujoco/mujoco_npg.py index ce556e40a..596bbc433 100755 --- a/examples/mujoco/mujoco_npg.py +++ b/examples/mujoco/mujoco_npg.py @@ -17,7 +17,7 @@ from tianshou.policy.base import Algorithm from tianshou.policy.modelfree.pg import ActorPolicy from tianshou.policy.optim import AdamOptimizerFactory, LRSchedulerFactoryLinear -from tianshou.trainer import OnPolicyTrainingConfig +from tianshou.trainer import OnPolicyTrainerParams from tianshou.utils.net.common import Net from tianshou.utils.net.continuous import ActorProb, Critic @@ -205,7 +205,7 @@ def save_best_fn(policy: Algorithm) -> None: if not args.watch: # train result = algorithm.run_training( - OnPolicyTrainingConfig( + OnPolicyTrainerParams( train_collector=train_collector, test_collector=test_collector, max_epoch=args.epoch, diff --git a/examples/mujoco/mujoco_ppo.py b/examples/mujoco/mujoco_ppo.py index 13a333d9d..a425815ed 100755 --- a/examples/mujoco/mujoco_ppo.py +++ b/examples/mujoco/mujoco_ppo.py @@ -17,7 +17,7 @@ from tianshou.policy.base import Algorithm from tianshou.policy.modelfree.pg import ActorPolicy from tianshou.policy.optim import AdamOptimizerFactory, LRSchedulerFactoryLinear -from tianshou.trainer import 
OnPolicyTrainingConfig +from tianshou.trainer import OnPolicyTrainerParams from tianshou.utils.net.common import ActorCritic, Net from tianshou.utils.net.continuous import ActorProb, Critic @@ -213,7 +213,7 @@ def save_best_fn(policy: Algorithm) -> None: if not args.watch: # train result = algorithm.run_training( - OnPolicyTrainingConfig( + OnPolicyTrainerParams( train_collector=train_collector, test_collector=test_collector, max_epoch=args.epoch, diff --git a/examples/mujoco/mujoco_redq.py b/examples/mujoco/mujoco_redq.py index e8ba562e6..31082ab6b 100755 --- a/examples/mujoco/mujoco_redq.py +++ b/examples/mujoco/mujoco_redq.py @@ -15,7 +15,7 @@ from tianshou.policy.base import Algorithm from tianshou.policy.modelfree.redq import REDQPolicy from tianshou.policy.optim import AdamOptimizerFactory -from tianshou.trainer import OffPolicyTrainingConfig +from tianshou.trainer import OffPolicyTrainerParams from tianshou.utils.net.common import EnsembleLinear, Net from tianshou.utils.net.continuous import ActorProb, Critic @@ -185,7 +185,7 @@ def save_best_fn(policy: Algorithm) -> None: if not args.watch: # train result = algorithm.run_training( - OffPolicyTrainingConfig( + OffPolicyTrainerParams( train_collector=train_collector, test_collector=test_collector, max_epoch=args.epoch, diff --git a/examples/mujoco/mujoco_reinforce.py b/examples/mujoco/mujoco_reinforce.py index ee6aa6cfe..1f946b9df 100755 --- a/examples/mujoco/mujoco_reinforce.py +++ b/examples/mujoco/mujoco_reinforce.py @@ -17,7 +17,7 @@ from tianshou.policy.base import Algorithm from tianshou.policy.modelfree.pg import ActorPolicy from tianshou.policy.optim import AdamOptimizerFactory, LRSchedulerFactoryLinear -from tianshou.trainer import OnPolicyTrainingConfig +from tianshou.trainer import OnPolicyTrainerParams from tianshou.utils.net.common import Net from tianshou.utils.net.continuous import ActorProb @@ -185,7 +185,7 @@ def save_best_fn(policy: Algorithm) -> None: if not args.watch: # train result = 
algorithm.run_training( - OnPolicyTrainingConfig( + OnPolicyTrainerParams( train_collector=train_collector, test_collector=test_collector, max_epoch=args.epoch, diff --git a/examples/mujoco/mujoco_sac.py b/examples/mujoco/mujoco_sac.py index 7728a39a1..85c34884c 100755 --- a/examples/mujoco/mujoco_sac.py +++ b/examples/mujoco/mujoco_sac.py @@ -15,7 +15,7 @@ from tianshou.policy.base import Algorithm from tianshou.policy.modelfree.sac import SACPolicy from tianshou.policy.optim import AdamOptimizerFactory -from tianshou.trainer import OffPolicyTrainingConfig +from tianshou.trainer import OffPolicyTrainerParams from tianshou.utils.net.common import Net from tianshou.utils.net.continuous import ActorProb, Critic @@ -179,7 +179,7 @@ def save_best_fn(policy: Algorithm) -> None: if not args.watch: # train result = algorithm.run_training( - OffPolicyTrainingConfig( + OffPolicyTrainerParams( train_collector=train_collector, test_collector=test_collector, max_epoch=args.epoch, diff --git a/examples/mujoco/mujoco_td3.py b/examples/mujoco/mujoco_td3.py index b3c6e9679..8e7b6ed5f 100755 --- a/examples/mujoco/mujoco_td3.py +++ b/examples/mujoco/mujoco_td3.py @@ -16,7 +16,7 @@ from tianshou.policy.base import Algorithm from tianshou.policy.modelfree.ddpg import DDPGPolicy from tianshou.policy.optim import AdamOptimizerFactory -from tianshou.trainer import OffPolicyTrainingConfig +from tianshou.trainer import OffPolicyTrainerParams from tianshou.utils.net.common import Net from tianshou.utils.net.continuous import Actor, Critic @@ -177,7 +177,7 @@ def save_best_fn(policy: Algorithm) -> None: if not args.watch: # train result = algorithm.run_training( - OffPolicyTrainingConfig( + OffPolicyTrainerParams( train_collector=train_collector, test_collector=test_collector, max_epoch=args.epoch, diff --git a/examples/mujoco/mujoco_trpo.py b/examples/mujoco/mujoco_trpo.py index 5bf3a1891..4620fa00b 100755 --- a/examples/mujoco/mujoco_trpo.py +++ b/examples/mujoco/mujoco_trpo.py @@ -17,7 
+17,7 @@ from tianshou.policy.base import Algorithm from tianshou.policy.modelfree.pg import ActorPolicy from tianshou.policy.optim import AdamOptimizerFactory, LRSchedulerFactoryLinear -from tianshou.trainer import OnPolicyTrainingConfig +from tianshou.trainer import OnPolicyTrainerParams from tianshou.utils.net.common import Net from tianshou.utils.net.continuous import ActorProb, Critic @@ -210,7 +210,7 @@ def save_best_fn(policy: Algorithm) -> None: if not args.watch: # trainer result = algorithm.run_training( - OnPolicyTrainingConfig( + OnPolicyTrainerParams( train_collector=train_collector, test_collector=test_collector, max_epoch=args.epoch, diff --git a/examples/offline/atari_bcq.py b/examples/offline/atari_bcq.py index ebe582def..ffa671b6e 100644 --- a/examples/offline/atari_bcq.py +++ b/examples/offline/atari_bcq.py @@ -20,7 +20,7 @@ from tianshou.policy.base import Algorithm from tianshou.policy.imitation.discrete_bcq import DiscreteBCQPolicy from tianshou.policy.optim import AdamOptimizerFactory -from tianshou.trainer import OfflineTrainingConfig +from tianshou.trainer import OfflineTrainerParams from tianshou.utils.net.discrete import Actor @@ -202,7 +202,7 @@ def watch() -> None: sys.exit(0) result = algorithm.run_training( - OfflineTrainingConfig( + OfflineTrainerParams( buffer=buffer, test_collector=test_collector, max_epoch=args.epoch, diff --git a/examples/offline/atari_cql.py b/examples/offline/atari_cql.py index 217fd997f..2f7849c1b 100644 --- a/examples/offline/atari_cql.py +++ b/examples/offline/atari_cql.py @@ -21,7 +21,7 @@ from tianshou.policy.base import Algorithm from tianshou.policy.modelfree.qrdqn import QRDQNPolicy from tianshou.policy.optim import AdamOptimizerFactory -from tianshou.trainer import OfflineTrainingConfig +from tianshou.trainer import OfflineTrainerParams from tianshou.utils.space_info import SpaceInfo @@ -188,7 +188,7 @@ def watch() -> None: sys.exit(0) result = algorithm.run_training( - OfflineTrainingConfig( + 
OfflineTrainerParams( buffer=buffer, test_collector=test_collector, max_epoch=args.epoch, diff --git a/examples/offline/atari_crr.py b/examples/offline/atari_crr.py index c5d1558e3..6972f5bb4 100644 --- a/examples/offline/atari_crr.py +++ b/examples/offline/atari_crr.py @@ -20,7 +20,7 @@ from tianshou.policy.base import Algorithm from tianshou.policy.modelfree.pg import DiscreteActorPolicy from tianshou.policy.optim import AdamOptimizerFactory -from tianshou.trainer import OfflineTrainingConfig +from tianshou.trainer import OfflineTrainerParams from tianshou.utils.net.discrete import Actor, Critic from tianshou.utils.space_info import SpaceInfo @@ -202,7 +202,7 @@ def watch() -> None: sys.exit(0) result = algorithm.run_training( - OfflineTrainingConfig( + OfflineTrainerParams( buffer=buffer, test_collector=test_collector, max_epoch=args.epoch, diff --git a/examples/offline/atari_il.py b/examples/offline/atari_il.py index a34704b0f..9a028f232 100644 --- a/examples/offline/atari_il.py +++ b/examples/offline/atari_il.py @@ -18,7 +18,7 @@ from tianshou.policy.base import Algorithm from tianshou.policy.imitation.base import ImitationPolicy, OfflineImitationLearning from tianshou.policy.optim import AdamOptimizerFactory -from tianshou.trainer import OfflineTrainingConfig +from tianshou.trainer import OfflineTrainerParams from tianshou.utils.space_info import SpaceInfo @@ -161,7 +161,7 @@ def watch() -> None: sys.exit(0) result = algorithm.run_training( - OfflineTrainingConfig( + OfflineTrainerParams( buffer=buffer, test_collector=test_collector, max_epoch=args.epoch, diff --git a/examples/offline/d4rl_bcq.py b/examples/offline/d4rl_bcq.py index 67bb12eda..05322bf3e 100644 --- a/examples/offline/d4rl_bcq.py +++ b/examples/offline/d4rl_bcq.py @@ -17,7 +17,7 @@ from tianshou.policy.base import Algorithm from tianshou.policy.imitation.bcq import BCQPolicy from tianshou.policy.optim import AdamOptimizerFactory -from tianshou.trainer import OfflineTrainingConfig +from 
tianshou.trainer import OfflineTrainerParams from tianshou.utils import TensorboardLogger, WandbLogger from tianshou.utils.net.common import MLP, Net from tianshou.utils.net.continuous import VAE, Critic, Perturbation @@ -217,7 +217,7 @@ def watch() -> None: replay_buffer = load_buffer_d4rl(args.expert_data_task) # train result = algorithm.run_training( - OfflineTrainingConfig( + OfflineTrainerParams( buffer=replay_buffer, test_collector=test_collector, max_epoch=args.epoch, diff --git a/examples/offline/d4rl_cql.py b/examples/offline/d4rl_cql.py index 3c1451f9a..d0121d131 100644 --- a/examples/offline/d4rl_cql.py +++ b/examples/offline/d4rl_cql.py @@ -17,7 +17,7 @@ from tianshou.policy.base import Algorithm from tianshou.policy.modelfree.sac import SACPolicy from tianshou.policy.optim import AdamOptimizerFactory -from tianshou.trainer import OfflineTrainingConfig +from tianshou.trainer import OfflineTrainerParams from tianshou.utils import TensorboardLogger, WandbLogger from tianshou.utils.net.common import Net from tianshou.utils.net.continuous import ActorProb, Critic @@ -355,7 +355,7 @@ def watch() -> None: replay_buffer = load_buffer_d4rl(args.expert_data_task) # train result = algorithm.run_training( - OfflineTrainingConfig( + OfflineTrainerParams( buffer=replay_buffer, test_collector=test_collector, max_epoch=args.epoch, diff --git a/examples/offline/d4rl_il.py b/examples/offline/d4rl_il.py index af9908fd9..22b52f212 100644 --- a/examples/offline/d4rl_il.py +++ b/examples/offline/d4rl_il.py @@ -16,7 +16,7 @@ from tianshou.policy.base import Algorithm from tianshou.policy.imitation.base import ImitationPolicy, OfflineImitationLearning from tianshou.policy.optim import AdamOptimizerFactory -from tianshou.trainer import OfflineTrainingConfig +from tianshou.trainer import OfflineTrainerParams from tianshou.utils import TensorboardLogger, WandbLogger from tianshou.utils.net.common import Net from tianshou.utils.net.continuous import Actor @@ -153,7 +153,7 @@ def 
watch() -> None: replay_buffer = load_buffer_d4rl(args.expert_data_task) # train result = algorithm.run_training( - OfflineTrainingConfig( + OfflineTrainerParams( buffer=replay_buffer, test_collector=test_collector, max_epoch=args.epoch, diff --git a/examples/offline/d4rl_td3_bc.py b/examples/offline/d4rl_td3_bc.py index 07ed5b75d..71817dc33 100644 --- a/examples/offline/d4rl_td3_bc.py +++ b/examples/offline/d4rl_td3_bc.py @@ -18,7 +18,7 @@ from tianshou.policy.base import Algorithm from tianshou.policy.modelfree.ddpg import DDPGPolicy from tianshou.policy.optim import AdamOptimizerFactory -from tianshou.trainer import OfflineTrainingConfig +from tianshou.trainer import OfflineTrainerParams from tianshou.utils import TensorboardLogger, WandbLogger from tianshou.utils.net.common import Net from tianshou.utils.net.continuous import Actor, Critic @@ -206,7 +206,7 @@ def watch() -> None: test_envs.set_obs_rms(obs_rms) # train result = algorithm.run_training( - OfflineTrainingConfig( + OfflineTrainerParams( buffer=replay_buffer, test_collector=test_collector, max_epoch=args.epoch, diff --git a/examples/vizdoom/vizdoom_c51.py b/examples/vizdoom/vizdoom_c51.py index 2d492392b..a8cfca26f 100644 --- a/examples/vizdoom/vizdoom_c51.py +++ b/examples/vizdoom/vizdoom_c51.py @@ -15,7 +15,7 @@ from tianshou.policy.base import Algorithm from tianshou.policy.modelfree.c51 import C51Policy from tianshou.policy.optim import AdamOptimizerFactory -from tianshou.trainer import OffPolicyTrainingConfig +from tianshou.trainer import OffPolicyTrainerParams def get_args() -> argparse.Namespace: @@ -206,7 +206,7 @@ def watch() -> None: train_collector.collect(n_step=args.batch_size * args.training_num) # train result = algorithm.run_training( - OffPolicyTrainingConfig( + OffPolicyTrainerParams( train_collector=train_collector, test_collector=test_collector, max_epoch=args.epoch, diff --git a/examples/vizdoom/vizdoom_ppo.py b/examples/vizdoom/vizdoom_ppo.py index 0154a14f5..551f8009e 100644 
--- a/examples/vizdoom/vizdoom_ppo.py +++ b/examples/vizdoom/vizdoom_ppo.py @@ -17,7 +17,7 @@ from tianshou.policy.modelbased.icm import ICMOnPolicyWrapper from tianshou.policy.modelfree.pg import ActorPolicy from tianshou.policy.optim import AdamOptimizerFactory, LRSchedulerFactoryLinear -from tianshou.trainer import OnPolicyTrainingConfig +from tianshou.trainer import OnPolicyTrainerParams from tianshou.utils.net.discrete import Actor, Critic, IntrinsicCuriosityModule @@ -272,7 +272,7 @@ def watch() -> None: # train result = algorithm.run_training( - OnPolicyTrainingConfig( + OnPolicyTrainerParams( train_collector=train_collector, test_collector=test_collector, max_epoch=args.epoch, diff --git a/test/continuous/test_ddpg.py b/test/continuous/test_ddpg.py index 797479448..3c3e94312 100644 --- a/test/continuous/test_ddpg.py +++ b/test/continuous/test_ddpg.py @@ -13,7 +13,7 @@ from tianshou.policy.base import Algorithm from tianshou.policy.modelfree.ddpg import DDPGPolicy from tianshou.policy.optim import AdamOptimizerFactory -from tianshou.trainer.base import OffPolicyTrainingConfig +from tianshou.trainer.base import OffPolicyTrainerParams from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Net from tianshou.utils.net.continuous import Actor, Critic @@ -123,7 +123,7 @@ def stop_fn(mean_rewards: float) -> bool: # trainer result = algorithm.run_training( - OffPolicyTrainingConfig( + OffPolicyTrainerParams( train_collector=train_collector, test_collector=test_collector, max_epoch=args.epoch, diff --git a/test/continuous/test_npg.py b/test/continuous/test_npg.py index 966c30a7d..11bba11e9 100644 --- a/test/continuous/test_npg.py +++ b/test/continuous/test_npg.py @@ -15,7 +15,7 @@ from tianshou.policy.modelfree.npg import NPGTrainingStats from tianshou.policy.modelfree.pg import ActorPolicy from tianshou.policy.optim import AdamOptimizerFactory -from tianshou.trainer.base import OnPolicyTrainingConfig +from tianshou.trainer.base import 
OnPolicyTrainerParams from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Net from tianshou.utils.net.continuous import ActorProb, Critic @@ -144,7 +144,7 @@ def stop_fn(mean_rewards: float) -> bool: # trainer result = algorithm.run_training( - OnPolicyTrainingConfig( + OnPolicyTrainerParams( train_collector=train_collector, test_collector=test_collector, max_epoch=args.epoch, diff --git a/test/continuous/test_ppo.py b/test/continuous/test_ppo.py index 601fcaaae..f475a045a 100644 --- a/test/continuous/test_ppo.py +++ b/test/continuous/test_ppo.py @@ -13,7 +13,7 @@ from tianshou.policy.base import Algorithm from tianshou.policy.modelfree.pg import ActorPolicy from tianshou.policy.optim import AdamOptimizerFactory -from tianshou.trainer.base import OnPolicyTrainingConfig +from tianshou.trainer.base import OnPolicyTrainerParams from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import ActorCritic, Net from tianshou.utils.net.continuous import ActorProb, Critic @@ -170,7 +170,7 @@ def save_checkpoint_fn(epoch: int, env_step: int, gradient_step: int) -> str: # train result = algorithm.run_training( - OnPolicyTrainingConfig( + OnPolicyTrainerParams( train_collector=train_collector, test_collector=test_collector, max_epoch=args.epoch, diff --git a/test/continuous/test_redq.py b/test/continuous/test_redq.py index 972a7bcad..d34034384 100644 --- a/test/continuous/test_redq.py +++ b/test/continuous/test_redq.py @@ -13,7 +13,7 @@ from tianshou.policy.base import Algorithm from tianshou.policy.modelfree.redq import REDQPolicy from tianshou.policy.optim import AdamOptimizerFactory -from tianshou.trainer.base import OffPolicyTrainingConfig +from tianshou.trainer.base import OffPolicyTrainerParams from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import EnsembleLinear, Net from tianshou.utils.net.continuous import ActorProb, Critic @@ -154,7 +154,7 @@ def stop_fn(mean_rewards: float) -> bool: # 
train result = algorithm.run_training( - OffPolicyTrainingConfig( + OffPolicyTrainerParams( train_collector=train_collector, test_collector=test_collector, max_epoch=args.epoch, diff --git a/test/continuous/test_sac_with_il.py b/test/continuous/test_sac_with_il.py index 92efc19b4..2d7c36f04 100644 --- a/test/continuous/test_sac_with_il.py +++ b/test/continuous/test_sac_with_il.py @@ -13,7 +13,7 @@ from tianshou.policy.imitation.base import ImitationPolicy from tianshou.policy.modelfree.sac import AutoAlpha, SACPolicy from tianshou.policy.optim import AdamOptimizerFactory -from tianshou.trainer.base import OffPolicyTrainingConfig +from tianshou.trainer.base import OffPolicyTrainerParams from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Net from tianshou.utils.net.continuous import Actor, ActorProb, Critic @@ -150,7 +150,7 @@ def stop_fn(mean_rewards: float) -> bool: # train result = algorithm.run_training( - OffPolicyTrainingConfig( + OffPolicyTrainerParams( train_collector=train_collector, test_collector=test_collector, max_epoch=args.epoch, @@ -200,7 +200,7 @@ def stop_fn(mean_rewards: float) -> bool: ) train_collector.reset() result = il_algorithm.run_training( - OffPolicyTrainingConfig( + OffPolicyTrainerParams( train_collector=train_collector, test_collector=il_test_collector, max_epoch=args.epoch, diff --git a/test/continuous/test_td3.py b/test/continuous/test_td3.py index 4e62caebe..48ac97df6 100644 --- a/test/continuous/test_td3.py +++ b/test/continuous/test_td3.py @@ -13,7 +13,7 @@ from tianshou.policy.base import Algorithm from tianshou.policy.modelfree.ddpg import DDPGPolicy from tianshou.policy.optim import AdamOptimizerFactory -from tianshou.trainer.base import OffPolicyTrainingConfig +from tianshou.trainer.base import OffPolicyTrainerParams from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Net from tianshou.utils.net.continuous import Actor, Critic @@ -140,7 +140,7 @@ def 
stop_fn(mean_rewards: float) -> bool: # train result = algorithm.run_training( - OffPolicyTrainingConfig( + OffPolicyTrainerParams( train_collector=train_collector, test_collector=test_collector, max_epoch=args.epoch, diff --git a/test/continuous/test_trpo.py b/test/continuous/test_trpo.py index cbcda10ed..e74c1e7cd 100644 --- a/test/continuous/test_trpo.py +++ b/test/continuous/test_trpo.py @@ -14,7 +14,7 @@ from tianshou.policy.base import Algorithm from tianshou.policy.modelfree.pg import ActorPolicy from tianshou.policy.optim import AdamOptimizerFactory -from tianshou.trainer.base import OnPolicyTrainingConfig +from tianshou.trainer.base import OnPolicyTrainerParams from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Net from tianshou.utils.net.continuous import ActorProb, Critic @@ -145,7 +145,7 @@ def stop_fn(mean_rewards: float) -> bool: # train result = algorithm.run_training( - OnPolicyTrainingConfig( + OnPolicyTrainerParams( train_collector=train_collector, test_collector=test_collector, max_epoch=args.epoch, diff --git a/test/discrete/test_a2c_with_il.py b/test/discrete/test_a2c_with_il.py index f2e75ab66..e07939698 100644 --- a/test/discrete/test_a2c_with_il.py +++ b/test/discrete/test_a2c_with_il.py @@ -14,7 +14,7 @@ from tianshou.policy.imitation.base import ImitationPolicy from tianshou.policy.modelfree.pg import ActorPolicy from tianshou.policy.optim import AdamOptimizerFactory -from tianshou.trainer.base import OffPolicyTrainingConfig, OnPolicyTrainingConfig +from tianshou.trainer.base import OffPolicyTrainerParams, OnPolicyTrainerParams from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Net from tianshou.utils.net.discrete import Actor, Critic @@ -136,7 +136,7 @@ def stop_fn(mean_rewards: float) -> bool: # trainer result = algorithm.run_training( - OnPolicyTrainingConfig( + OnPolicyTrainerParams( train_collector=train_collector, test_collector=test_collector, max_epoch=args.epoch, @@ 
-186,7 +186,7 @@ def stop_fn(mean_rewards: float) -> bool: ) train_collector.reset() result = il_algorithm.run_training( - OffPolicyTrainingConfig( + OffPolicyTrainerParams( train_collector=train_collector, test_collector=il_test_collector, max_epoch=args.epoch, diff --git a/test/discrete/test_bdqn.py b/test/discrete/test_bdqn.py index dbb5b846d..7f322c8c1 100644 --- a/test/discrete/test_bdqn.py +++ b/test/discrete/test_bdqn.py @@ -9,7 +9,7 @@ from tianshou.policy import BDQN from tianshou.policy.modelfree.bdqn import BDQNPolicy from tianshou.policy.optim import AdamOptimizerFactory -from tianshou.trainer.base import OffPolicyTrainingConfig +from tianshou.trainer.base import OffPolicyTrainerParams from tianshou.utils.net.common import BranchingNet @@ -134,7 +134,7 @@ def stop_fn(mean_rewards: float) -> bool: # train result = algorithm.run_training( - OffPolicyTrainingConfig( + OffPolicyTrainerParams( train_collector=train_collector, test_collector=test_collector, max_epoch=args.epoch, diff --git a/test/discrete/test_c51.py b/test/discrete/test_c51.py index e6890d971..435f1020a 100644 --- a/test/discrete/test_c51.py +++ b/test/discrete/test_c51.py @@ -19,7 +19,7 @@ from tianshou.policy.base import Algorithm from tianshou.policy.modelfree.c51 import C51Policy from tianshou.policy.optim import AdamOptimizerFactory -from tianshou.trainer.base import OffPolicyTrainingConfig +from tianshou.trainer.base import OffPolicyTrainerParams from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Net from tianshou.utils.space_info import SpaceInfo @@ -191,7 +191,7 @@ def save_checkpoint_fn(epoch: int, env_step: int, gradient_step: int) -> str: # train result = algorithm.run_training( - OffPolicyTrainingConfig( + OffPolicyTrainerParams( train_collector=train_collector, test_collector=test_collector, max_epoch=args.epoch, diff --git a/test/discrete/test_discrete_sac.py b/test/discrete/test_discrete_sac.py index 0efdf6f0e..503264fa6 100644 --- 
a/test/discrete/test_discrete_sac.py +++ b/test/discrete/test_discrete_sac.py @@ -15,7 +15,7 @@ DiscreteSACTrainingStats, ) from tianshou.policy.optim import AdamOptimizerFactory -from tianshou.trainer.base import OffPolicyTrainingConfig +from tianshou.trainer.base import OffPolicyTrainerParams from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Net from tianshou.utils.net.discrete import Actor, Critic @@ -133,7 +133,7 @@ def stop_fn(mean_rewards: float) -> bool: # train result = algorithm.run_training( - OffPolicyTrainingConfig( + OffPolicyTrainerParams( train_collector=train_collector, test_collector=test_collector, max_epoch=args.epoch, diff --git a/test/discrete/test_dqn.py b/test/discrete/test_dqn.py index ec0583b6a..0345ba9c6 100644 --- a/test/discrete/test_dqn.py +++ b/test/discrete/test_dqn.py @@ -18,7 +18,7 @@ from tianshou.policy.base import Algorithm from tianshou.policy.modelfree.dqn import DQNPolicy from tianshou.policy.optim import AdamOptimizerFactory -from tianshou.trainer.base import OffPolicyTrainingConfig +from tianshou.trainer.base import OffPolicyTrainerParams from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Net from tianshou.utils.space_info import SpaceInfo @@ -144,7 +144,7 @@ def test_fn(epoch: int, env_step: int | None) -> None: # train result = algorithm.run_training( - OffPolicyTrainingConfig( + OffPolicyTrainerParams( train_collector=train_collector, test_collector=test_collector, max_epoch=args.epoch, diff --git a/test/discrete/test_drqn.py b/test/discrete/test_drqn.py index 84f1d3001..fd636c5bc 100644 --- a/test/discrete/test_drqn.py +++ b/test/discrete/test_drqn.py @@ -12,7 +12,7 @@ from tianshou.policy.base import Algorithm from tianshou.policy.modelfree.dqn import DQNPolicy from tianshou.policy.optim import AdamOptimizerFactory -from tianshou.trainer import OffPolicyTrainingConfig +from tianshou.trainer import OffPolicyTrainerParams from tianshou.utils import 
TensorboardLogger from tianshou.utils.net.common import Recurrent from tianshou.utils.space_info import SpaceInfo @@ -118,7 +118,7 @@ def test_fn(epoch: int, env_step: int | None) -> None: # train result = algorithm.run_training( - OffPolicyTrainingConfig( + OffPolicyTrainerParams( train_collector=train_collector, test_collector=test_collector, max_epoch=args.epoch, diff --git a/test/discrete/test_fqf.py b/test/discrete/test_fqf.py index 4cb884e58..abe413e26 100644 --- a/test/discrete/test_fqf.py +++ b/test/discrete/test_fqf.py @@ -18,7 +18,7 @@ from tianshou.policy.base import Algorithm from tianshou.policy.modelfree.fqf import FQFPolicy from tianshou.policy.optim import AdamOptimizerFactory, RMSpropOptimizerFactory -from tianshou.trainer.base import OffPolicyTrainingConfig +from tianshou.trainer.base import OffPolicyTrainerParams from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Net from tianshou.utils.net.discrete import FractionProposalNetwork, FullQuantileFunction @@ -162,7 +162,7 @@ def test_fn(epoch: int, env_step: int | None) -> None: # train result = algorithm.run_training( - OffPolicyTrainingConfig( + OffPolicyTrainerParams( train_collector=train_collector, test_collector=test_collector, max_epoch=args.epoch, diff --git a/test/discrete/test_iqn.py b/test/discrete/test_iqn.py index 2e5dda9cd..a14ae05ac 100644 --- a/test/discrete/test_iqn.py +++ b/test/discrete/test_iqn.py @@ -18,7 +18,7 @@ from tianshou.policy.base import Algorithm from tianshou.policy.modelfree.iqn import IQNPolicy from tianshou.policy.optim import AdamOptimizerFactory -from tianshou.trainer.base import OffPolicyTrainingConfig +from tianshou.trainer.base import OffPolicyTrainerParams from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Net from tianshou.utils.net.discrete import ImplicitQuantileNetwork @@ -158,7 +158,7 @@ def test_fn(epoch: int, env_step: int | None) -> None: # train result = algorithm.run_training( - 
OffPolicyTrainingConfig( + OffPolicyTrainerParams( train_collector=train_collector, test_collector=test_collector, max_epoch=args.epoch, diff --git a/test/discrete/test_pg.py b/test/discrete/test_pg.py index a3edecb2a..bfac05f1b 100644 --- a/test/discrete/test_pg.py +++ b/test/discrete/test_pg.py @@ -13,7 +13,7 @@ from tianshou.policy.base import Algorithm from tianshou.policy.modelfree.pg import ActorPolicy from tianshou.policy.optim import AdamOptimizerFactory -from tianshou.trainer.base import OnPolicyTrainingConfig +from tianshou.trainer.base import OnPolicyTrainerParams from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Net from tianshou.utils.space_info import SpaceInfo @@ -114,7 +114,7 @@ def stop_fn(mean_rewards: float) -> bool: return mean_rewards >= args.reward_threshold # train - training_config = OnPolicyTrainingConfig( + training_config = OnPolicyTrainerParams( train_collector=train_collector, test_collector=test_collector, max_epoch=args.epoch, diff --git a/test/discrete/test_ppo.py b/test/discrete/test_ppo.py index c21045a5c..d4c0bbf8c 100644 --- a/test/discrete/test_ppo.py +++ b/test/discrete/test_ppo.py @@ -12,7 +12,7 @@ from tianshou.policy import PPO from tianshou.policy.base import Algorithm from tianshou.policy.modelfree.pg import DiscreteActorPolicy -from tianshou.trainer import OnPolicyTrainingConfig +from tianshou.trainer import OnPolicyTrainerParams from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import ActorCritic, DataParallelNet, Net from tianshou.utils.net.discrete import Actor, Critic @@ -137,7 +137,7 @@ def stop_fn(mean_rewards: float) -> bool: # trainer result = algorithm.run_training( - OnPolicyTrainingConfig( + OnPolicyTrainerParams( train_collector=train_collector, test_collector=test_collector, max_epoch=args.epoch, diff --git a/test/discrete/test_qrdqn.py b/test/discrete/test_qrdqn.py index 867c528a8..b14867b4d 100644 --- a/test/discrete/test_qrdqn.py +++ 
b/test/discrete/test_qrdqn.py @@ -17,7 +17,7 @@ from tianshou.policy.base import Algorithm from tianshou.policy.modelfree.qrdqn import QRDQNPolicy from tianshou.policy.optim import AdamOptimizerFactory -from tianshou.trainer.base import OffPolicyTrainingConfig +from tianshou.trainer.base import OffPolicyTrainerParams from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Net from tianshou.utils.space_info import SpaceInfo @@ -150,7 +150,7 @@ def test_fn(epoch: int, env_step: int | None) -> None: # train result = algorithm.run_training( - OffPolicyTrainingConfig( + OffPolicyTrainerParams( train_collector=train_collector, test_collector=test_collector, max_epoch=args.epoch, diff --git a/test/discrete/test_rainbow.py b/test/discrete/test_rainbow.py index 915ef85ab..7c6c551ed 100644 --- a/test/discrete/test_rainbow.py +++ b/test/discrete/test_rainbow.py @@ -18,7 +18,7 @@ from tianshou.policy.base import Algorithm from tianshou.policy.modelfree.c51 import C51Policy from tianshou.policy.optim import AdamOptimizerFactory -from tianshou.trainer.base import OffPolicyTrainingConfig +from tianshou.trainer.base import OffPolicyTrainerParams from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Net from tianshou.utils.net.discrete import NoisyLinear @@ -210,7 +210,7 @@ def save_checkpoint_fn(epoch: int, env_step: int, gradient_step: int) -> str: # train result = algorithm.run_training( - OffPolicyTrainingConfig( + OffPolicyTrainerParams( train_collector=train_collector, test_collector=test_collector, max_epoch=args.epoch, diff --git a/test/modelbased/test_dqn_icm.py b/test/modelbased/test_dqn_icm.py index bf6d0e317..0a8b2c9aa 100644 --- a/test/modelbased/test_dqn_icm.py +++ b/test/modelbased/test_dqn_icm.py @@ -16,7 +16,7 @@ from tianshou.policy import DQN, ICMOffPolicyWrapper from tianshou.policy.modelfree.dqn import DQNPolicy from tianshou.policy.optim import AdamOptimizerFactory -from tianshou.trainer import 
OffPolicyTrainingConfig +from tianshou.trainer import OffPolicyTrainerParams from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import MLP, Net from tianshou.utils.net.discrete import IntrinsicCuriosityModule @@ -192,7 +192,7 @@ def test_fn(epoch: int, env_step: int | None) -> None: # train result = icm_algorithm.run_training( - OffPolicyTrainingConfig( + OffPolicyTrainerParams( train_collector=train_collector, test_collector=test_collector, max_epoch=args.epoch, diff --git a/test/modelbased/test_ppo_icm.py b/test/modelbased/test_ppo_icm.py index cedcc662d..6e2bbfdbb 100644 --- a/test/modelbased/test_ppo_icm.py +++ b/test/modelbased/test_ppo_icm.py @@ -14,7 +14,7 @@ from tianshou.policy.modelbased.icm import ICMOnPolicyWrapper from tianshou.policy.modelfree.pg import ActorPolicy from tianshou.policy.optim import AdamOptimizerFactory -from tianshou.trainer import OnPolicyTrainingConfig +from tianshou.trainer import OnPolicyTrainerParams from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import MLP, ActorCritic, Net from tianshou.utils.net.discrete import Actor, Critic, IntrinsicCuriosityModule @@ -185,7 +185,7 @@ def stop_fn(mean_rewards: float) -> bool: # train result = icm_algorithm.run_training( - OnPolicyTrainingConfig( + OnPolicyTrainerParams( train_collector=train_collector, test_collector=test_collector, max_epoch=args.epoch, diff --git a/test/modelbased/test_psrl.py b/test/modelbased/test_psrl.py index ead82de46..086d12f76 100644 --- a/test/modelbased/test_psrl.py +++ b/test/modelbased/test_psrl.py @@ -9,7 +9,7 @@ from tianshou.data import Collector, CollectStats, VectorReplayBuffer from tianshou.policy import PSRL from tianshou.policy.modelbased.psrl import PSRLPolicy -from tianshou.trainer import OnPolicyTrainingConfig +from tianshou.trainer import OnPolicyTrainerParams from tianshou.utils import LazyLogger, TensorboardLogger, WandbLogger try: @@ -113,7 +113,7 @@ def stop_fn(mean_rewards: float) -> bool: # 
train result = algorithm.run_training( - OnPolicyTrainingConfig( + OnPolicyTrainerParams( train_collector=train_collector, test_collector=test_collector, max_epoch=args.epoch, diff --git a/test/offline/gather_cartpole_data.py b/test/offline/gather_cartpole_data.py index 8ec3c1dec..e84111f65 100644 --- a/test/offline/gather_cartpole_data.py +++ b/test/offline/gather_cartpole_data.py @@ -17,7 +17,7 @@ from tianshou.policy import QRDQN from tianshou.policy.base import Algorithm from tianshou.policy.modelfree.qrdqn import QRDQNPolicy -from tianshou.trainer import OffPolicyTrainingConfig +from tianshou.trainer import OffPolicyTrainerParams from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Net from tianshou.utils.space_info import SpaceInfo @@ -153,7 +153,7 @@ def test_fn(epoch: int, env_step: int | None) -> None: # train result = algorithm.run_training( - OffPolicyTrainingConfig( + OffPolicyTrainerParams( train_collector=train_collector, test_collector=test_collector, max_epoch=args.epoch, diff --git a/test/offline/gather_pendulum_data.py b/test/offline/gather_pendulum_data.py index 4aafb88be..03c63387d 100644 --- a/test/offline/gather_pendulum_data.py +++ b/test/offline/gather_pendulum_data.py @@ -12,7 +12,7 @@ from tianshou.policy import SAC from tianshou.policy.base import Algorithm from tianshou.policy.modelfree.sac import AutoAlpha, SACPolicy, SACTrainingStats -from tianshou.trainer.base import OffPolicyTrainingConfig +from tianshou.trainer.base import OffPolicyTrainerParams from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Net from tianshou.utils.net.continuous import ActorProb, Critic @@ -148,7 +148,7 @@ def stop_fn(mean_rewards: float) -> bool: # trainer algorithm.run_training( - OffPolicyTrainingConfig( + OffPolicyTrainerParams( train_collector=train_collector, test_collector=test_collector, max_epoch=args.epoch, diff --git a/test/offline/test_bcq.py b/test/offline/test_bcq.py index 
0112ae244..98349d318 100644 --- a/test/offline/test_bcq.py +++ b/test/offline/test_bcq.py @@ -14,7 +14,7 @@ from tianshou.policy import BCQ, Algorithm from tianshou.policy.imitation.bcq import BCQPolicy, BCQTrainingStats from tianshou.policy.optim import AdamOptimizerFactory -from tianshou.trainer.base import OfflineTrainingConfig +from tianshou.trainer.base import OfflineTrainerParams from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import MLP, Net from tianshou.utils.net.continuous import VAE, Critic, Perturbation @@ -193,7 +193,7 @@ def watch() -> None: # train result = algorithm.run_training( - OfflineTrainingConfig( + OfflineTrainerParams( buffer=buffer, test_collector=test_collector, max_epoch=args.epoch, diff --git a/test/offline/test_cql.py b/test/offline/test_cql.py index d49134df6..18786eab7 100644 --- a/test/offline/test_cql.py +++ b/test/offline/test_cql.py @@ -15,7 +15,7 @@ from tianshou.policy.imitation.cql import CQLTrainingStats from tianshou.policy.modelfree.sac import AutoAlpha, SACPolicy from tianshou.policy.optim import AdamOptimizerFactory -from tianshou.trainer import OfflineTrainingConfig +from tianshou.trainer import OfflineTrainerParams from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Net from tianshou.utils.net.continuous import ActorProb, Critic @@ -186,7 +186,7 @@ def stop_fn(mean_rewards: float) -> bool: # trainer result = algorithm.run_training( - OfflineTrainingConfig( + OfflineTrainerParams( buffer=buffer, test_collector=test_collector, max_epoch=args.epoch, diff --git a/test/offline/test_discrete_bcq.py b/test/offline/test_discrete_bcq.py index 773ecd1fe..17d05bb88 100644 --- a/test/offline/test_discrete_bcq.py +++ b/test/offline/test_discrete_bcq.py @@ -18,7 +18,7 @@ from tianshou.policy import Algorithm, DiscreteBCQ from tianshou.policy.imitation.discrete_bcq import DiscreteBCQPolicy from tianshou.policy.optim import AdamOptimizerFactory -from tianshou.trainer import 
OfflineTrainingConfig +from tianshou.trainer import OfflineTrainerParams from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Net from tianshou.utils.net.discrete import Actor @@ -159,7 +159,7 @@ def save_checkpoint_fn(epoch: int, env_step: int, gradient_step: int) -> str: # train result = algorithm.run_training( - OfflineTrainingConfig( + OfflineTrainerParams( buffer=buffer, test_collector=test_collector, max_epoch=args.epoch, diff --git a/test/offline/test_discrete_cql.py b/test/offline/test_discrete_cql.py index f071c9d30..a992773e8 100644 --- a/test/offline/test_discrete_cql.py +++ b/test/offline/test_discrete_cql.py @@ -18,7 +18,7 @@ from tianshou.policy import Algorithm, DiscreteCQL from tianshou.policy.modelfree.qrdqn import QRDQNPolicy from tianshou.policy.optim import AdamOptimizerFactory -from tianshou.trainer import OfflineTrainingConfig +from tianshou.trainer import OfflineTrainerParams from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Net from tianshou.utils.space_info import SpaceInfo @@ -123,7 +123,7 @@ def stop_fn(mean_rewards: float) -> bool: # train result = algorithm.run_training( - OfflineTrainingConfig( + OfflineTrainerParams( buffer=buffer, test_collector=test_collector, max_epoch=args.epoch, diff --git a/test/offline/test_discrete_crr.py b/test/offline/test_discrete_crr.py index 2fe4aeb72..c07af56a8 100644 --- a/test/offline/test_discrete_crr.py +++ b/test/offline/test_discrete_crr.py @@ -18,7 +18,7 @@ from tianshou.policy import Algorithm, DiscreteCRR from tianshou.policy.modelfree.pg import DiscreteActorPolicy from tianshou.policy.optim import AdamOptimizerFactory -from tianshou.trainer import OfflineTrainingConfig +from tianshou.trainer import OfflineTrainerParams from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Net from tianshou.utils.net.discrete import Actor, Critic @@ -125,7 +125,7 @@ def stop_fn(mean_rewards: float) -> bool: # train result 
= algorithm.run_training( - OfflineTrainingConfig( + OfflineTrainerParams( buffer=buffer, test_collector=test_collector, max_epoch=args.epoch, diff --git a/test/offline/test_gail.py b/test/offline/test_gail.py index d9cab8723..6d5d7c3c6 100644 --- a/test/offline/test_gail.py +++ b/test/offline/test_gail.py @@ -14,7 +14,7 @@ from tianshou.policy import GAIL, Algorithm from tianshou.policy.modelfree.pg import ActorPolicy from tianshou.policy.optim import AdamOptimizerFactory -from tianshou.trainer import OnPolicyTrainingConfig +from tianshou.trainer import OnPolicyTrainerParams from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import ActorCritic, Net from tianshou.utils.net.continuous import ActorProb, Critic @@ -210,7 +210,7 @@ def save_checkpoint_fn(epoch: int, env_step: int, gradient_step: int) -> str: # trainer result = algorithm.run_training( - OnPolicyTrainingConfig( + OnPolicyTrainerParams( train_collector=train_collector, test_collector=test_collector, max_epoch=args.epoch, diff --git a/test/offline/test_td3_bc.py b/test/offline/test_td3_bc.py index ce98a4c92..641203122 100644 --- a/test/offline/test_td3_bc.py +++ b/test/offline/test_td3_bc.py @@ -16,7 +16,7 @@ from tianshou.policy.base import Algorithm from tianshou.policy.modelfree.ddpg import DDPGPolicy from tianshou.policy.optim import AdamOptimizerFactory -from tianshou.trainer import OfflineTrainingConfig +from tianshou.trainer import OfflineTrainerParams from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Net from tianshou.utils.net.continuous import Actor, Critic @@ -176,7 +176,7 @@ def stop_fn(mean_rewards: float) -> bool: # train result = algorithm.run_training( - OfflineTrainingConfig( + OfflineTrainerParams( buffer=buffer, test_collector=test_collector, max_epoch=args.epoch, diff --git a/test/pettingzoo/pistonball.py b/test/pettingzoo/pistonball.py index ce6f743bf..fce2ddaa3 100644 --- a/test/pettingzoo/pistonball.py +++ 
b/test/pettingzoo/pistonball.py @@ -13,7 +13,7 @@ from tianshou.env.pettingzoo_env import PettingZooEnv from tianshou.policy import DQN, Algorithm, MultiAgentOffPolicyAlgorithm from tianshou.policy.modelfree.dqn import DQNPolicy -from tianshou.trainer import OffPolicyTrainingConfig +from tianshou.trainer import OffPolicyTrainerParams from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Net @@ -164,7 +164,7 @@ def reward_metric(rews: np.ndarray) -> np.ndarray: # trainer result = marl_algorithm.run_training( - OffPolicyTrainingConfig( + OffPolicyTrainerParams( train_collector=train_collector, test_collector=test_collector, max_epoch=args.epoch, diff --git a/test/pettingzoo/pistonball_continuous.py b/test/pettingzoo/pistonball_continuous.py index 471aa1671..f57d03255 100644 --- a/test/pettingzoo/pistonball_continuous.py +++ b/test/pettingzoo/pistonball_continuous.py @@ -18,7 +18,7 @@ from tianshou.policy import PPO, Algorithm from tianshou.policy.modelfree.pg import ActorPolicy from tianshou.policy.multiagent.mapolicy import MultiAgentOnPolicyAlgorithm -from tianshou.trainer import OnPolicyTrainingConfig +from tianshou.trainer import OnPolicyTrainerParams from tianshou.utils import TensorboardLogger from tianshou.utils.net.continuous import ActorProb, Critic @@ -264,7 +264,7 @@ def reward_metric(rews: np.ndarray) -> np.ndarray: # train result = marl_algorithm.run_training( - OnPolicyTrainingConfig( + OnPolicyTrainerParams( train_collector=train_collector, test_collector=test_collector, max_epoch=args.epoch, diff --git a/test/pettingzoo/tic_tac_toe.py b/test/pettingzoo/tic_tac_toe.py index 086af3cdb..670b6a722 100644 --- a/test/pettingzoo/tic_tac_toe.py +++ b/test/pettingzoo/tic_tac_toe.py @@ -21,7 +21,7 @@ ) from tianshou.policy.modelfree.dqn import DQNPolicy from tianshou.policy.optim import AdamOptimizerFactory, OptimizerFactory -from tianshou.trainer import OffPolicyTrainingConfig +from tianshou.trainer import OffPolicyTrainerParams 
from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Net @@ -213,7 +213,7 @@ def reward_metric(rews: np.ndarray) -> np.ndarray: # trainer result = marl_algorithm.run_training( - OffPolicyTrainingConfig( + OffPolicyTrainerParams( train_collector=train_collector, test_collector=test_collector, max_epoch=args.epoch, diff --git a/tianshou/highlevel/algorithm.py b/tianshou/highlevel/algorithm.py index 958409cd3..2f5824db0 100644 --- a/tianshou/highlevel/algorithm.py +++ b/tianshou/highlevel/algorithm.py @@ -70,7 +70,7 @@ from tianshou.policy.modelfree.redq import REDQPolicy from tianshou.policy.modelfree.sac import SACPolicy from tianshou.trainer import OffPolicyTrainer, OnPolicyTrainer, Trainer -from tianshou.trainer.base import OffPolicyTrainingConfig, OnPolicyTrainingConfig +from tianshou.trainer.base import OffPolicyTrainerParams, OnPolicyTrainerParams from tianshou.utils.net.discrete import Actor CHECKPOINT_DICT_KEY_MODEL = "model" @@ -206,7 +206,7 @@ def create_trainer( ) algorithm = cast(OnPolicyAlgorithm, world.policy) return algorithm.create_trainer( - OnPolicyTrainingConfig( + OnPolicyTrainerParams( train_collector=world.train_collector, test_collector=world.test_collector, max_epoch=sampling_config.num_epochs, @@ -253,7 +253,7 @@ def create_trainer( ) algorithm = cast(OffPolicyAlgorithm, world.policy) return algorithm.create_trainer( - OffPolicyTrainingConfig( + OffPolicyTrainerParams( train_collector=world.train_collector, test_collector=world.test_collector, max_epoch=sampling_config.num_epochs, diff --git a/tianshou/policy/base.py b/tianshou/policy/base.py index 04952c245..381138367 100644 --- a/tianshou/policy/base.py +++ b/tianshou/policy/base.py @@ -37,12 +37,13 @@ from tianshou.trainer.base import ( InfoStats, OfflineTrainer, - OfflineTrainingConfig, + OfflineTrainerParams, OffPolicyTrainer, - OffPolicyTrainingConfig, + OffPolicyTrainerParams, OnPolicyTrainer, - OnPolicyTrainingConfig, + OnPolicyTrainerParams, Trainer, + 
TrainerParams, ) logger = logging.getLogger(__name__) @@ -421,12 +422,10 @@ def _update_lagged_network_weights(self) -> None: TPolicy = TypeVar("TPolicy", bound=Policy) -TTrainingConfig = TypeVar( - "TTrainingConfig", -) # TODO Can't use bound=TrainingConfig because of circular import +TTrainerParams = TypeVar("TTrainerParams", bound="TrainerParams") -class Algorithm(torch.nn.Module, Generic[TPolicy, TTrainingConfig, TTrainingStats], ABC): +class Algorithm(torch.nn.Module, Generic[TPolicy, TTrainerParams, TTrainingStats], ABC): """ TODO fix docstring The base class for any RL policy. @@ -750,23 +749,23 @@ def compute_nstep_return( return cast(BatchWithReturnsProtocol, batch) @abstractmethod - def create_trainer(self, config: TTrainingConfig) -> "Trainer": + def create_trainer(self, params: TTrainerParams) -> "Trainer": pass - def run_training(self, config: TTrainingConfig) -> "InfoStats": - trainer = self.create_trainer(config) + def run_training(self, params: TTrainerParams) -> "InfoStats": + trainer = self.create_trainer(params) return trainer.run() class OnPolicyAlgorithm( - Algorithm[TPolicy, "OnPolicyTrainingConfig", TTrainingStats], + Algorithm[TPolicy, "OnPolicyTrainerParams", TTrainingStats], Generic[TPolicy, TTrainingStats], ABC, ): - def create_trainer(self, config: "OnPolicyTrainingConfig") -> "OnPolicyTrainer": + def create_trainer(self, params: "OnPolicyTrainerParams") -> "OnPolicyTrainer": from tianshou.trainer.base import OnPolicyTrainer - return OnPolicyTrainer(self, config) + return OnPolicyTrainer(self, params) @abstractmethod def _update_with_batch( @@ -789,14 +788,14 @@ def update( class OffPolicyAlgorithm( - Algorithm[TPolicy, "OffPolicyTrainingConfig", TTrainingStats], + Algorithm[TPolicy, "OffPolicyTrainerParams", TTrainingStats], Generic[TPolicy, TTrainingStats], ABC, ): - def create_trainer(self, config: "OffPolicyTrainingConfig") -> "OffPolicyTrainer": + def create_trainer(self, params: "OffPolicyTrainerParams") -> "OffPolicyTrainer": from 
tianshou.trainer.base import OffPolicyTrainer - return OffPolicyTrainer(self, config) + return OffPolicyTrainer(self, params) @abstractmethod def _update_with_batch( @@ -823,7 +822,7 @@ def update( class OfflineAlgorithm( - Algorithm[TPolicy, "OfflineTrainingConfig", TTrainingStats], + Algorithm[TPolicy, "OfflineTrainerParams", TTrainingStats], Generic[TPolicy, TTrainingStats], ABC, ): @@ -831,16 +830,16 @@ def process_buffer(self, buffer: TBuffer) -> TBuffer: """Pre-process the replay buffer to prepare for offline learning, e.g. to add new keys.""" return buffer - def run_training(self, config: "OfflineTrainingConfig") -> "InfoStats": + def run_training(self, params: "OfflineTrainerParams") -> "InfoStats": # NOTE: This override is required for correct typing when converting # an algorithm to an offline algorithm using diamond inheritance # (e.g. DiscreteCQL) in order to make it match first in the MRO - return super().run_training(config) + return super().run_training(params) - def create_trainer(self, config: "OfflineTrainingConfig") -> "OfflineTrainer": + def create_trainer(self, params: "OfflineTrainerParams") -> "OfflineTrainer": from tianshou.trainer.base import OfflineTrainer - return OfflineTrainer(self, config) + return OfflineTrainer(self, params) @abstractmethod def _update_with_batch( diff --git a/tianshou/trainer/__init__.py b/tianshou/trainer/__init__.py index 2f7d8fabc..f36ee3035 100644 --- a/tianshou/trainer/__init__.py +++ b/tianshou/trainer/__init__.py @@ -2,10 +2,10 @@ from .base import ( OfflineTrainer, - OfflineTrainingConfig, + OfflineTrainerParams, OffPolicyTrainer, - OffPolicyTrainingConfig, + OffPolicyTrainerParams, OnPolicyTrainer, - OnPolicyTrainingConfig, + OnPolicyTrainerParams, Trainer, ) diff --git a/tianshou/trainer/base.py b/tianshou/trainer/base.py index cebaee91c..4b015b04c 100644 --- a/tianshou/trainer/base.py +++ b/tianshou/trainer/base.py @@ -65,7 +65,7 @@ @dataclass(kw_only=True) -class TrainingConfig(ToStringMixin): +class 
TrainerParams(ToStringMixin): max_epoch: int = 100 """ the (maximum) number of epochs to run training for. An **epoch** is the outermost iteration level and each @@ -217,7 +217,7 @@ def __post_init__(self): @dataclass(kw_only=True) -class OnlineTrainingConfig(TrainingConfig): +class OnlineTrainerParams(TrainerParams): train_collector: BaseCollector """ the collector with which to gather new data for training in each training step @@ -269,7 +269,7 @@ def __post_init__(self): @dataclass(kw_only=True) -class OnPolicyTrainingConfig(OnlineTrainingConfig): +class OnPolicyTrainerParams(OnlineTrainerParams): batch_size: int | None = 64 """ Use mini-batches of this size for gradient updates (causing the gradient to be less accurate, @@ -288,7 +288,7 @@ class OnPolicyTrainingConfig(OnlineTrainingConfig): @dataclass(kw_only=True) -class OffPolicyTrainingConfig(OnlineTrainingConfig): +class OffPolicyTrainerParams(OnlineTrainerParams): batch_size: int = 64 """ the the number of environment steps/transitions to sample from the buffer for a gradient update. @@ -304,7 +304,7 @@ class OffPolicyTrainingConfig(OnlineTrainingConfig): @dataclass(kw_only=True) -class OfflineTrainingConfig(TrainingConfig): +class OfflineTrainerParams(TrainerParams): buffer: ReplayBuffer """ the replay buffer with environment steps to use as training data for offline learning. 
@@ -318,12 +318,12 @@ class OfflineTrainingConfig(TrainingConfig): """ -TTrainingConfig = TypeVar("TTrainingConfig", bound=TrainingConfig) -TOnlineTrainingConfig = TypeVar("TOnlineTrainingConfig", bound=OnlineTrainingConfig) +TTrainerParams = TypeVar("TTrainerParams", bound=TrainerParams) +TOnlineTrainerParams = TypeVar("TOnlineTrainerParams", bound=OnlineTrainerParams) TAlgorithm = TypeVar("TAlgorithm", bound=Algorithm) -class Trainer(Generic[TAlgorithm, TTrainingConfig], ABC): +class Trainer(Generic[TAlgorithm, TTrainerParams], ABC): """ Base class for trainers in Tianshou, which orchestrate the training process and call upon an RL algorithm's specific network updating logic to perform the actual gradient updates. @@ -334,13 +334,13 @@ class Trainer(Generic[TAlgorithm, TTrainingConfig], ABC): def __init__( self, - policy: TAlgorithm, - config: TTrainingConfig, + algorithm: TAlgorithm, + params: TTrainerParams, ): - self.algorithm = policy - self.config = config + self.algorithm = algorithm + self.params = params - self._logger = config.logger or LazyLogger() + self._logger = params.logger or LazyLogger() self._start_time = time.time() self._stat: defaultdict[str, MovAvg] = defaultdict(MovAvg) @@ -372,7 +372,7 @@ def __init__( self._policy_update_time = 0.0 self._compute_score_fn: Callable[[CollectStats], float] = ( - config.compute_score_fn or self._compute_score_fn_default + params.compute_score_fn or self._compute_score_fn_default ) self._stop_fn_flag = False @@ -395,12 +395,12 @@ def _pbar(self) -> Callable[..., tqdm.tqdm]: tqdm.tqdm, dynamic_ncols=True, ascii=True, - disable=not self.config.show_progress, + disable=not self.params.show_progress, ) def _reset_collectors(self, reset_buffer: bool = False) -> None: - if self.config.test_collector is not None: - self.config.test_collector.reset(reset_buffer=reset_buffer) + if self.params.test_collector is not None: + self.params.test_collector.reset(reset_buffer=reset_buffer) def reset(self, reset_collectors: bool 
= True, reset_collector_buffers: bool = False) -> None: """Initializes the training process. @@ -416,7 +416,7 @@ def reset(self, reset_collectors: bool = True, reset_collector_buffers: bool = F self._env_step = 0 self._current_update_step = 0 - if self.config.resume_from_log: + if self.params.resume_from_log: ( self._start_epoch, self._env_step, @@ -431,9 +431,9 @@ def reset(self, reset_collectors: bool = True, reset_collector_buffers: bool = F self._reset_collectors(reset_buffer=reset_collector_buffers) # make an initial test step to determine the initial best model - if self.config.test_collector is not None: - assert self.config.episode_per_test is not None - assert not isinstance(self.config.test_collector, AsyncCollector) # Issue 700 + if self.params.test_collector is not None: + assert self.params.episode_per_test is not None + assert not isinstance(self.params.test_collector, AsyncCollector) # Issue 700 self._test_step(force_update_best=True, log_msg_prefix="Initial test step") self._stop_fn_flag = False @@ -474,9 +474,9 @@ def _create_epoch_pbar_data_dict( def _create_info_stats( self, ) -> InfoStats: - test_collector = self.config.test_collector - if isinstance(self.config, OnlineTrainingConfig): - train_collector = self.config.train_collector + test_collector = self.params.test_collector + if isinstance(self.params, OnlineTrainerParams): + train_collector = self.params.train_collector else: train_collector = None @@ -519,9 +519,9 @@ def execute_epoch(self) -> EpochStats: steps_done_in_this_epoch = 0 train_collect_stats, training_stats = None, None with self._pbar( - total=self.config.step_per_epoch, desc=f"Epoch #{self._epoch}", position=1 + total=self.params.step_per_epoch, desc=f"Epoch #{self._epoch}", position=1 ) as t: - while steps_done_in_this_epoch < self.config.step_per_epoch and not self._stop_fn_flag: + while steps_done_in_this_epoch < self.params.step_per_epoch and not self._stop_fn_flag: # perform a training step and update progress 
self._current_update_step += 1 training_step_result = self._training_step() @@ -546,11 +546,11 @@ def execute_epoch(self) -> EpochStats: self._epoch, self._env_step, self._current_update_step, - self.config.save_checkpoint_fn, + self.params.save_checkpoint_fn, ) # test step - if self.config.test_collector is not None: + if self.params.test_collector is not None: test_collect_stats, self._stop_fn_flag = self._test_step() info_stats = self._create_info_stats() @@ -573,7 +573,7 @@ def _should_stop_training_early( based on the score achieved or the collection stats (from which the score could be computed). """ # If no stop criterion is defined, we can never stop training early - if self.config.stop_fn is None: + if self.params.stop_fn is None: return False if score is None: @@ -586,18 +586,18 @@ def _should_stop_training_early( score = self._compute_score_fn(collect_stats) - return self.config.stop_fn(score) + return self.params.stop_fn(score) def _collect_test_episodes( self, ) -> CollectStats: - collector = self.config.test_collector + collector = self.params.test_collector collector.reset(reset_stats=False) - if self.config.test_fn: - self.config.test_fn(self._epoch, self._env_step) - result = collector.collect(n_episode=self.config.episode_per_test) - if self.config.reward_metric: # TODO: move into collector - rew = self.config.reward_metric(result.returns) + if self.params.test_fn: + self.params.test_fn(self._epoch, self._env_step) + result = collector.collect(n_episode=self.params.episode_per_test) + if self.params.reward_metric: # TODO: move into collector + rew = self.params.reward_metric(result.returns) result.returns = rew result.returns_stat = SequenceSummaryStats.from_sequence(rew) if self._logger and self._env_step is not None: @@ -615,8 +615,8 @@ def _test_step( :param force_update_best: whether to force updating of the best model stats (best score, reward, etc.) 
and call the `save_best_fn` callback """ - assert self.config.episode_per_test is not None - assert self.config.test_collector is not None + assert self.params.episode_per_test is not None + assert self.params.test_collector is not None # collect test episodes test_stat = self._collect_test_episodes() @@ -631,8 +631,8 @@ def _test_step( self._best_epoch = self._epoch self._best_reward = float(rew) self._best_reward_std = rew_std - if self.config.save_best_fn: - self.config.save_best_fn(self.algorithm) + if self.params.save_best_fn: + self.params.save_best_fn(self.algorithm) # log results cur_info, best_info = "", "" @@ -646,7 +646,7 @@ def _test_step( f"{self._best_reward_std:.6f}{best_info} in #{self._best_epoch}" ) log.info(log_msg) - if self.config.verbose: + if self.params.verbose: print(log_msg, flush=True) # determine whether training shall be stopped early @@ -702,24 +702,24 @@ def run( reset_collectors=reset_collectors, reset_collector_buffers=reset_collector_buffers ) - while self._epoch < self.config.max_epoch and not self._stop_fn_flag: + while self._epoch < self.params.max_epoch and not self._stop_fn_flag: self.execute_epoch() return self._create_info_stats() -class OfflineTrainer(Trainer[OfflineAlgorithm, OfflineTrainingConfig]): +class OfflineTrainer(Trainer[OfflineAlgorithm, OfflineTrainerParams]): """An offline trainer, which samples mini-batches from a given buffer and passes them to the algorithm's update function. 
""" def __init__( self, - policy: "Algorithm", - config: OfflineTrainingConfig, + algorithm: "Algorithm", + params: OfflineTrainerParams, ): - super().__init__(policy, config) - self._buffer = policy.process_buffer(self.config.buffer) + super().__init__(algorithm, params) + self._buffer = algorithm.process_buffer(self.params.buffer) class _TrainingStepResult(Trainer._TrainingStepResult): def __init__(self, training_stats: TrainingStats, env_step_advancement: int): @@ -747,12 +747,12 @@ def _training_step(self) -> _TrainingStepResult: # exactly one gradient step. This is why we don't need to calculate the # number of gradient steps, like in the on-policy case. training_stats = self.algorithm.update( - sample_size=self.config.batch_size, buffer=self._buffer + sample_size=self.params.batch_size, buffer=self._buffer ) self._update_moving_avg_stats_and_log_update_data(training_stats) self._policy_update_time += training_stats.train_time return self._TrainingStepResult( - training_stats=training_stats, env_step_advancement=self.config.batch_size + training_stats=training_stats, env_step_advancement=self.params.batch_size ) def _create_epoch_pbar_data_dict( @@ -762,7 +762,7 @@ def _create_epoch_pbar_data_dict( class OnlineTrainer( - Trainer[TAlgorithm, TOnlineTrainingConfig], Generic[TAlgorithm, TOnlineTrainingConfig], ABC + Trainer[TAlgorithm, TOnlineTrainerParams], Generic[TAlgorithm, TOnlineTrainerParams], ABC ): """ An online trainer, which collects data from the environment in each training step and @@ -772,10 +772,10 @@ class OnlineTrainer( def __init__( self, - policy: "Algorithm", - config: OnlineTrainingConfig, + algorithm: "Algorithm", + params: OnlineTrainerParams, ): - super().__init__(policy, config) + super().__init__(algorithm, params) self._env_episode = 0 """ the total number of episodes collected in the environment @@ -783,7 +783,7 @@ def __init__( def _reset_collectors(self, reset_buffer: bool = False) -> None: 
super()._reset_collectors(reset_buffer=reset_buffer) - self.config.train_collector.reset(reset_buffer=reset_buffer) + self.params.train_collector.reset(reset_buffer=reset_buffer) def reset(self, reset_collectors: bool = True, reset_collector_buffers: bool = False) -> None: super().reset( @@ -791,8 +791,8 @@ def reset(self, reset_collectors: bool = True, reset_collector_buffers: bool = F ) if ( - self.config.test_in_train - and self.config.train_collector.policy is not self.algorithm.policy + self.params.test_in_train + and self.params.train_collector.policy is not self.algorithm.policy ): log.warning( "The training data collector's policy is not the same as the one being trained, " @@ -843,7 +843,7 @@ def _training_step(self) -> _TrainingStepResult: # determine whether we should stop training based on the data collected should_stop_training = False - if self.config.test_in_train: + if self.params.test_in_train: should_stop_training = self._test_in_train(collect_stats) # perform gradient update step (if not already done) @@ -862,18 +862,18 @@ def _collect_training_data(self) -> CollectStats: :return: the data collection stats """ - assert self.config.episode_per_test is not None - assert self.config.train_collector is not None + assert self.params.episode_per_test is not None + assert self.params.train_collector is not None - if self.config.train_fn: - self.config.train_fn(self._epoch, self._env_step) + if self.params.train_fn: + self.params.train_fn(self._epoch, self._env_step) - collect_stats = self.config.train_collector.collect( - n_step=self.config.step_per_collect, - n_episode=self.config.episode_per_collect, + collect_stats = self.params.train_collector.collect( + n_step=self.params.step_per_collect, + n_episode=self.params.episode_per_collect, ) - if self.config.train_collector.buffer.hasnull(): + if self.params.train_collector.buffer.hasnull(): from tianshou.data.collector import EpisodeRolloutHook from tianshou.env import DummyVectorEnv @@ -888,8 +888,8 @@ 
def _collect_training_data(self) -> CollectStats: if collect_stats.n_collected_episodes > 0: assert collect_stats.returns_stat is not None # for mypy assert collect_stats.lens_stat is not None # for mypy - if self.config.reward_metric: # TODO: move inside collector - rew = self.config.reward_metric(collect_stats.returns) + if self.params.reward_metric: # TODO: move inside collector + rew = self.params.reward_metric(collect_stats.returns) collect_stats.returns = rew collect_stats.returns_stat = SequenceSummaryStats.from_sequence(rew) @@ -960,7 +960,7 @@ def _create_epoch_pbar_data_dict( return result -class OffPolicyTrainer(OnlineTrainer[OffPolicyAlgorithm, OffPolicyTrainingConfig]): +class OffPolicyTrainer(OnlineTrainer[OffPolicyAlgorithm, OffPolicyTrainerParams]): """An off-policy trainer, which samples mini-batches from the buffer of collected data and passes them to algorithm's `update` function. @@ -978,13 +978,13 @@ def _update_step( :param collect_stats: the :class:`~TrainingStats` instance returned by the last gradient step. Some values in it will be replaced by their moving averages. 
""" - assert self.config.train_collector is not None + assert self.params.train_collector is not None n_collected_steps = collect_stats.n_collected_steps - n_gradient_steps = round(self.config.update_per_step * n_collected_steps) + n_gradient_steps = round(self.params.update_per_step * n_collected_steps) if n_gradient_steps == 0: raise ValueError( f"n_gradient_steps is 0, n_collected_steps={n_collected_steps}, " - f"update_per_step={self.config.update_per_step}", + f"update_per_step={self.params.update_per_step}", ) update_stat = None @@ -994,7 +994,7 @@ def _update_step( position=0, leave=False, ): - update_stat = self._sample_and_update(self.config.train_collector.buffer) + update_stat = self._sample_and_update(self.params.train_collector.buffer) self._policy_update_time += update_stat.train_time # TODO: only the last update_stat is returned, should be improved @@ -1005,12 +1005,12 @@ def _sample_and_update(self, buffer: ReplayBuffer) -> TrainingStats: # Note: since sample_size=batch_size, this will perform # exactly one gradient step. This is why we don't need to calculate the # number of gradient steps, like in the on-policy case. - update_stat = self.algorithm.update(sample_size=self.config.batch_size, buffer=buffer) + update_stat = self.algorithm.update(sample_size=self.params.batch_size, buffer=buffer) self._update_moving_avg_stats_and_log_update_data(update_stat) return update_stat -class OnPolicyTrainer(OnlineTrainer[OnPolicyAlgorithm, OnPolicyTrainingConfig]): +class OnPolicyTrainer(OnlineTrainer[OnPolicyAlgorithm, OnPolicyTrainerParams]): """An on-policy trainer, which passes the entire buffer to the algorithm's `update` methods and resets the buffer thereafter. 
@@ -1023,15 +1023,15 @@ def _update_step( result: CollectStatsBase | None = None, ) -> TrainingStats: """Perform one on-policy update by passing the entire buffer to the algorithm's update method.""" - assert self.config.train_collector is not None + assert self.params.train_collector is not None # TODO: add logging like in off-policy. Iteration over minibatches currently happens in the algorithms themselves. log.info( - f"Performing on-policy update on buffer of length {len(self.config.train_collector.buffer)}", + f"Performing on-policy update on buffer of length {len(self.params.train_collector.buffer)}", ) training_stat = self.algorithm.update( - buffer=self.config.train_collector.buffer, - batch_size=self.config.batch_size, - repeat=self.config.repeat_per_collect, + buffer=self.params.train_collector.buffer, + batch_size=self.params.batch_size, + repeat=self.params.repeat_per_collect, ) # just for logging, no functional role @@ -1046,7 +1046,7 @@ def _update_step( # _ep_rew and _ep_len. This means that such quantities can no longer be computed # from samples still contained in the buffer, which is also not clean # TODO: improve this situation - self.config.train_collector.reset_buffer(keep_statistics=True) + self.params.train_collector.reset_buffer(keep_statistics=True) # The step is the number of mini-batches used for the update, so essentially self._update_moving_avg_stats_and_log_update_data(training_stat) diff --git a/tianshou/utils/logger/wandb.py b/tianshou/utils/logger/wandb.py index 129e20668..fcdceb0de 100644 --- a/tianshou/utils/logger/wandb.py +++ b/tianshou/utils/logger/wandb.py @@ -26,8 +26,6 @@ class WandbLogger(BaseLogger): logger = WandbLogger() logger.load(SummaryWriter(log_path)) - result = OnpolicyTrainer(policy, train_collector, test_collector, - logger=logger).run() :param train_interval: the log interval in log_train_data(). Default to 1000. :param test_interval: the log interval in log_test_data(). Default to 1. 
From 9bfdb8971e0546184b2733cf6f77a628f6da525e Mon Sep 17 00:00:00 2001 From: Dominik Jain Date: Mon, 17 Mar 2025 16:40:29 +0100 Subject: [PATCH 069/230] v2: High-Level API: Replace SamplingConfig with *TrainingConfig * Properly differentiate off-policy/on-policy cases, no longer allowing inapplicable parameters to be set * Expose test_in_train (previously not available in HL API) --- CHANGELOG.md | 5 + README.md | 12 +- examples/atari/atari_dqn_hl.py | 12 +- examples/atari/atari_iqn_hl.py | 12 +- examples/atari/atari_ppo_hl.py | 12 +- examples/atari/atari_sac_hl.py | 11 +- examples/discrete/discrete_dqn_hl.py | 4 +- examples/mujoco/mujoco_a2c_hl.py | 12 +- examples/mujoco/mujoco_ddpg_hl.py | 11 +- examples/mujoco/mujoco_npg_hl.py | 12 +- examples/mujoco/mujoco_ppo_hl.py | 12 +- examples/mujoco/mujoco_ppo_hl_multi.py | 12 +- examples/mujoco/mujoco_redq_hl.py | 11 +- examples/mujoco/mujoco_reinforce_hl.py | 12 +- examples/mujoco/mujoco_sac_hl.py | 10 +- examples/mujoco/mujoco_td3_hl.py | 10 +- examples/mujoco/mujoco_trpo_hl.py | 12 +- test/highlevel/test_experiment_builder.py | 38 ++++- tianshou/highlevel/algorithm.py | 115 ++++++------- tianshou/highlevel/config.py | 137 ++++++++++----- tianshou/highlevel/experiment.py | 194 +++++++++++++--------- tianshou/highlevel/params/lr_scheduler.py | 17 +- 22 files changed, 404 insertions(+), 279 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 6d5515a1c..ad0be9bfa 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -121,6 +121,11 @@ * Learning rate schedulers remain separate parameters and now use `LRSchedulerFactoryFactory` instances. The respective parameter names now use the suffix `lr_scheduler` instead of `lr_scheduler_factory` (as the precise nature need not be reflected in the name; brevity is preferable). + * `SamplingConfig` is replaced by `TrainingConfig` and subclasses differentiating off-policy and on-policy cases + appropriately (`OnPolicyTrainingConfig`, `OffPolicyTrainingConfig`). 
+ * The `test_in_train` parameter is now exposed (default False). + * Inapplicable arguments can no longer be set in the respective subclass (e.g. `OffPolicyTrainingConfig` does not + contain parameter `repeat_per_collect`). * Moved Atari helper modules `atari_network` and `atari_wrapper` to the library under `tianshou.env.atari`. ## Unreleased diff --git a/README.md b/README.md index eb9067c36..61439a46d 100644 --- a/README.md +++ b/README.md @@ -210,17 +210,17 @@ We shall apply the deep Q network (DQN) learning algorithm using both APIs. To get started, we need some imports. ```python -from tianshou.highlevel.config import SamplingConfig +from tianshou.highlevel.config import TrainingConfig from tianshou.highlevel.env import ( - EnvFactoryRegistered, - VectorEnvType, + EnvFactoryRegistered, + VectorEnvType, ) from tianshou.highlevel.experiment import DQNExperimentBuilder, ExperimentConfig from tianshou.highlevel.params.policy_params import DQNParams from tianshou.highlevel.trainer import ( - EpochTestCallbackDQNSetEps, - EpochTrainCallbackDQNSetEps, - EpochStopCallbackRewardThreshold + EpochTestCallbackDQNSetEps, + EpochTrainCallbackDQNSetEps, + EpochStopCallbackRewardThreshold ) ``` diff --git a/examples/atari/atari_dqn_hl.py b/examples/atari/atari_dqn_hl.py index 1f3469d39..63f7b5afe 100644 --- a/examples/atari/atari_dqn_hl.py +++ b/examples/atari/atari_dqn_hl.py @@ -10,7 +10,7 @@ IntermediateModuleFactoryAtariDQNFeatures, ) from tianshou.env.atari.atari_wrapper import AtariEnvFactory, AtariEpochStopCallback -from tianshou.highlevel.config import SamplingConfig +from tianshou.highlevel.config import OffPolicyTrainingConfig from tianshou.highlevel.experiment import ( DQNExperimentBuilder, ExperimentConfig, @@ -45,14 +45,13 @@ def main( training_num: int = 10, test_num: int = 10, frames_stack: int = 4, - save_buffer_name: str | None = None, # TODO support? 
icm_lr_scale: float = 0.0, icm_reward_scale: float = 0.01, icm_forward_loss_weight: float = 0.2, ) -> None: log_name = os.path.join(task, "dqn", str(experiment_config.seed), datetime_tag()) - sampling_config = SamplingConfig( + training_config = OffPolicyTrainingConfig( num_epochs=epoch, step_per_epoch=step_per_epoch, batch_size=batch_size, @@ -61,7 +60,6 @@ def main( buffer_size=buffer_size, step_per_collect=step_per_collect, update_per_step=update_per_step, - repeat_per_collect=None, replay_buffer_stack_num=frames_stack, replay_buffer_ignore_obs_next=True, replay_buffer_save_only_last_obs=True, @@ -69,14 +67,14 @@ def main( env_factory = AtariEnvFactory( task, - sampling_config.train_seed, - sampling_config.test_seed, + training_config.train_seed, + training_config.test_seed, frames_stack, scale=scale_obs, ) builder = ( - DQNExperimentBuilder(env_factory, experiment_config, sampling_config) + DQNExperimentBuilder(env_factory, experiment_config, training_config) .with_dqn_params( DQNParams( discount_factor=gamma, diff --git a/examples/atari/atari_iqn_hl.py b/examples/atari/atari_iqn_hl.py index 0cdfa4d81..4425c2649 100644 --- a/examples/atari/atari_iqn_hl.py +++ b/examples/atari/atari_iqn_hl.py @@ -10,7 +10,7 @@ IntermediateModuleFactoryAtariDQN, ) from tianshou.env.atari.atari_wrapper import AtariEnvFactory, AtariEpochStopCallback -from tianshou.highlevel.config import SamplingConfig +from tianshou.highlevel.config import OffPolicyTrainingConfig from tianshou.highlevel.experiment import ( ExperimentConfig, IQNExperimentBuilder, @@ -47,11 +47,10 @@ def main( training_num: int = 10, test_num: int = 10, frames_stack: int = 4, - save_buffer_name: str | None = None, # TODO support? 
) -> None: log_name = os.path.join(task, "iqn", str(experiment_config.seed), datetime_tag()) - sampling_config = SamplingConfig( + training_config = OffPolicyTrainingConfig( num_epochs=epoch, step_per_epoch=step_per_epoch, batch_size=batch_size, @@ -60,7 +59,6 @@ def main( buffer_size=buffer_size, step_per_collect=step_per_collect, update_per_step=update_per_step, - repeat_per_collect=None, replay_buffer_stack_num=frames_stack, replay_buffer_ignore_obs_next=True, replay_buffer_save_only_last_obs=True, @@ -68,14 +66,14 @@ def main( env_factory = AtariEnvFactory( task, - sampling_config.train_seed, - sampling_config.test_seed, + training_config.train_seed, + training_config.test_seed, frames_stack, scale=scale_obs, ) experiment = ( - IQNExperimentBuilder(env_factory, experiment_config, sampling_config) + IQNExperimentBuilder(env_factory, experiment_config, training_config) .with_iqn_params( IQNParams( discount_factor=gamma, diff --git a/examples/atari/atari_ppo_hl.py b/examples/atari/atari_ppo_hl.py index 6f8f8a6ce..32673f937 100644 --- a/examples/atari/atari_ppo_hl.py +++ b/examples/atari/atari_ppo_hl.py @@ -11,7 +11,7 @@ IntermediateModuleFactoryAtariDQNFeatures, ) from tianshou.env.atari.atari_wrapper import AtariEnvFactory, AtariEpochStopCallback -from tianshou.highlevel.config import SamplingConfig +from tianshou.highlevel.config import OnPolicyTrainingConfig from tianshou.highlevel.experiment import ( ExperimentConfig, PPOExperimentBuilder, @@ -57,7 +57,7 @@ def main( ) -> None: log_name = os.path.join(task, "ppo", str(experiment_config.seed), datetime_tag()) - sampling_config = SamplingConfig( + training_config = OnPolicyTrainingConfig( num_epochs=epoch, step_per_epoch=step_per_epoch, batch_size=batch_size, @@ -73,14 +73,14 @@ def main( env_factory = AtariEnvFactory( task, - sampling_config.train_seed, - sampling_config.test_seed, + training_config.train_seed, + training_config.test_seed, frames_stack, scale=scale_obs, ) builder = ( - 
PPOExperimentBuilder(env_factory, experiment_config, sampling_config) + PPOExperimentBuilder(env_factory, experiment_config, training_config) .with_ppo_params( PPOParams( discount_factor=gamma, @@ -95,7 +95,7 @@ def main( dual_clip=dual_clip, recompute_advantage=recompute_adv, lr=lr, - lr_scheduler=LRSchedulerFactoryFactoryLinear(sampling_config) if lr_decay else None, + lr_scheduler=LRSchedulerFactoryFactoryLinear(training_config) if lr_decay else None, ), ) .with_actor_factory(ActorFactoryAtariDQN(scale_obs=scale_obs, features_only=True)) diff --git a/examples/atari/atari_sac_hl.py b/examples/atari/atari_sac_hl.py index 3ce60aa4d..2e5fe7c90 100644 --- a/examples/atari/atari_sac_hl.py +++ b/examples/atari/atari_sac_hl.py @@ -11,7 +11,7 @@ IntermediateModuleFactoryAtariDQNFeatures, ) from tianshou.env.atari.atari_wrapper import AtariEnvFactory, AtariEpochStopCallback -from tianshou.highlevel.config import SamplingConfig +from tianshou.highlevel.config import OffPolicyTrainingConfig from tianshou.highlevel.experiment import ( DiscreteSACExperimentBuilder, ExperimentConfig, @@ -51,7 +51,7 @@ def main( ) -> None: log_name = os.path.join(task, "sac", str(experiment_config.seed), datetime_tag()) - sampling_config = SamplingConfig( + training_config = OffPolicyTrainingConfig( num_epochs=epoch, step_per_epoch=step_per_epoch, update_per_step=update_per_step, @@ -60,7 +60,6 @@ def main( num_test_envs=test_num, buffer_size=buffer_size, step_per_collect=step_per_collect, - repeat_per_collect=None, replay_buffer_stack_num=frames_stack, replay_buffer_ignore_obs_next=True, replay_buffer_save_only_last_obs=True, @@ -68,14 +67,14 @@ def main( env_factory = AtariEnvFactory( task, - sampling_config.train_seed, - sampling_config.test_seed, + training_config.train_seed, + training_config.test_seed, frames_stack, scale=scale_obs, ) builder = ( - DiscreteSACExperimentBuilder(env_factory, experiment_config, sampling_config) + DiscreteSACExperimentBuilder(env_factory, experiment_config, 
training_config) .with_sac_params( DiscreteSACParams( actor_lr=actor_lr, diff --git a/examples/discrete/discrete_dqn_hl.py b/examples/discrete/discrete_dqn_hl.py index 464d06e71..c44db7c08 100644 --- a/examples/discrete/discrete_dqn_hl.py +++ b/examples/discrete/discrete_dqn_hl.py @@ -1,6 +1,6 @@ from sensai.util.logging import run_main -from tianshou.highlevel.config import SamplingConfig +from tianshou.highlevel.config import OffPolicyTrainingConfig from tianshou.highlevel.env import ( EnvFactoryRegistered, VectorEnvType, @@ -29,7 +29,7 @@ def main() -> None: watch_render=1 / 35, watch_num_episodes=100, ), - SamplingConfig( + OffPolicyTrainingConfig( num_epochs=10, step_per_epoch=10000, batch_size=64, diff --git a/examples/mujoco/mujoco_a2c_hl.py b/examples/mujoco/mujoco_a2c_hl.py index e3a30f8a4..6f2fb6fa4 100644 --- a/examples/mujoco/mujoco_a2c_hl.py +++ b/examples/mujoco/mujoco_a2c_hl.py @@ -9,7 +9,7 @@ from torch import nn from examples.mujoco.mujoco_env import MujocoEnvFactory -from tianshou.highlevel.config import SamplingConfig +from tianshou.highlevel.config import OnPolicyTrainingConfig from tianshou.highlevel.experiment import ( A2CExperimentBuilder, ExperimentConfig, @@ -43,7 +43,7 @@ def main( ) -> None: log_name = os.path.join(task, "a2c", str(experiment_config.seed), datetime_tag()) - sampling_config = SamplingConfig( + training_config = OnPolicyTrainingConfig( num_epochs=epoch, step_per_epoch=step_per_epoch, batch_size=batch_size, @@ -56,13 +56,13 @@ def main( env_factory = MujocoEnvFactory( task, - train_seed=sampling_config.train_seed, - test_seed=sampling_config.test_seed, + train_seed=training_config.train_seed, + test_seed=training_config.test_seed, obs_norm=True, ) experiment = ( - A2CExperimentBuilder(env_factory, experiment_config, sampling_config) + A2CExperimentBuilder(env_factory, experiment_config, training_config) .with_a2c_params( A2CParams( discount_factor=gamma, @@ -74,7 +74,7 @@ def main( max_grad_norm=max_grad_norm, 
optim=OptimizerFactoryFactoryRMSprop(eps=1e-5, alpha=0.99), lr=lr, - lr_scheduler=LRSchedulerFactoryFactoryLinear(sampling_config) if lr_decay else None, + lr_scheduler=LRSchedulerFactoryFactoryLinear(training_config) if lr_decay else None, ), ) .with_actor_factory_default(hidden_sizes, nn.Tanh, continuous_unbounded=True) diff --git a/examples/mujoco/mujoco_ddpg_hl.py b/examples/mujoco/mujoco_ddpg_hl.py index 27dbfc8d9..2acfa2f1f 100644 --- a/examples/mujoco/mujoco_ddpg_hl.py +++ b/examples/mujoco/mujoco_ddpg_hl.py @@ -7,7 +7,7 @@ from sensai.util.logging import datetime_tag from examples.mujoco.mujoco_env import MujocoEnvFactory -from tianshou.highlevel.config import SamplingConfig +from tianshou.highlevel.config import OffPolicyTrainingConfig from tianshou.highlevel.experiment import ( DDPGExperimentBuilder, ExperimentConfig, @@ -38,7 +38,7 @@ def main( ) -> None: log_name = os.path.join(task, "ddpg", str(experiment_config.seed), datetime_tag()) - sampling_config = SamplingConfig( + training_config = OffPolicyTrainingConfig( num_epochs=epoch, step_per_epoch=step_per_epoch, batch_size=batch_size, @@ -47,20 +47,19 @@ def main( buffer_size=buffer_size, step_per_collect=step_per_collect, update_per_step=update_per_step, - repeat_per_collect=None, start_timesteps=start_timesteps, start_timesteps_random=True, ) env_factory = MujocoEnvFactory( task, - train_seed=sampling_config.train_seed, - test_seed=sampling_config.test_seed, + train_seed=training_config.train_seed, + test_seed=training_config.test_seed, obs_norm=False, ) experiment = ( - DDPGExperimentBuilder(env_factory, experiment_config, sampling_config) + DDPGExperimentBuilder(env_factory, experiment_config, training_config) .with_ddpg_params( DDPGParams( actor_lr=actor_lr, diff --git a/examples/mujoco/mujoco_npg_hl.py b/examples/mujoco/mujoco_npg_hl.py index fc1f5afb9..63dfc7b89 100644 --- a/examples/mujoco/mujoco_npg_hl.py +++ b/examples/mujoco/mujoco_npg_hl.py @@ -9,7 +9,7 @@ from sensai.util.logging import 
datetime_tag from examples.mujoco.mujoco_env import MujocoEnvFactory -from tianshou.highlevel.config import SamplingConfig +from tianshou.highlevel.config import OnPolicyTrainingConfig from tianshou.highlevel.experiment import ( ExperimentConfig, NPGExperimentBuilder, @@ -42,7 +42,7 @@ def main( ) -> None: log_name = os.path.join(task, "npg", str(experiment_config.seed), datetime_tag()) - sampling_config = SamplingConfig( + training_config = OnPolicyTrainingConfig( num_epochs=epoch, step_per_epoch=step_per_epoch, batch_size=batch_size, @@ -55,13 +55,13 @@ def main( env_factory = MujocoEnvFactory( task, - train_seed=sampling_config.train_seed, - test_seed=sampling_config.test_seed, + train_seed=training_config.train_seed, + test_seed=training_config.test_seed, obs_norm=True, ) experiment = ( - NPGExperimentBuilder(env_factory, experiment_config, sampling_config) + NPGExperimentBuilder(env_factory, experiment_config, training_config) .with_npg_params( NPGParams( discount_factor=gamma, @@ -72,7 +72,7 @@ def main( optim_critic_iters=optim_critic_iters, actor_step_size=actor_step_size, lr=lr, - lr_scheduler=LRSchedulerFactoryFactoryLinear(sampling_config) if lr_decay else None, + lr_scheduler=LRSchedulerFactoryFactoryLinear(training_config) if lr_decay else None, ), ) .with_actor_factory_default(hidden_sizes, torch.nn.Tanh, continuous_unbounded=True) diff --git a/examples/mujoco/mujoco_ppo_hl.py b/examples/mujoco/mujoco_ppo_hl.py index 7334e6bfa..997173084 100644 --- a/examples/mujoco/mujoco_ppo_hl.py +++ b/examples/mujoco/mujoco_ppo_hl.py @@ -9,7 +9,7 @@ from sensai.util.logging import datetime_tag from examples.mujoco.mujoco_env import MujocoEnvFactory -from tianshou.highlevel.config import SamplingConfig +from tianshou.highlevel.config import OnPolicyTrainingConfig from tianshou.highlevel.experiment import ( ExperimentConfig, PPOExperimentBuilder, @@ -47,7 +47,7 @@ def main( ) -> None: log_name = os.path.join(task, "ppo", str(experiment_config.seed), datetime_tag()) 
- sampling_config = SamplingConfig( + training_config = OnPolicyTrainingConfig( num_epochs=epoch, step_per_epoch=step_per_epoch, batch_size=batch_size, @@ -60,13 +60,13 @@ def main( env_factory = MujocoEnvFactory( task, - train_seed=sampling_config.train_seed, - test_seed=sampling_config.test_seed, + train_seed=training_config.train_seed, + test_seed=training_config.test_seed, obs_norm=True, ) experiment = ( - PPOExperimentBuilder(env_factory, experiment_config, sampling_config) + PPOExperimentBuilder(env_factory, experiment_config, training_config) .with_ppo_params( PPOParams( discount_factor=gamma, @@ -82,7 +82,7 @@ def main( dual_clip=dual_clip, recompute_advantage=recompute_adv, lr=lr, - lr_scheduler=LRSchedulerFactoryFactoryLinear(sampling_config) if lr_decay else None, + lr_scheduler=LRSchedulerFactoryFactoryLinear(training_config) if lr_decay else None, ), ) .with_actor_factory_default(hidden_sizes, torch.nn.Tanh, continuous_unbounded=True) diff --git a/examples/mujoco/mujoco_ppo_hl_multi.py b/examples/mujoco/mujoco_ppo_hl_multi.py index 9e63fc59a..a8d3ae828 100644 --- a/examples/mujoco/mujoco_ppo_hl_multi.py +++ b/examples/mujoco/mujoco_ppo_hl_multi.py @@ -22,7 +22,7 @@ from examples.mujoco.mujoco_env import MujocoEnvFactory from tianshou.evaluation.launcher import RegisteredExpLauncher from tianshou.evaluation.rliable_evaluation_hl import RLiableExperimentResult -from tianshou.highlevel.config import SamplingConfig +from tianshou.highlevel.config import OnPolicyTrainingConfig from tianshou.highlevel.experiment import ( ExperimentConfig, PPOExperimentBuilder, @@ -58,7 +58,7 @@ def main( experiment_config = ExperimentConfig(persistence_base_dir=persistence_dir, watch=False) - sampling_config = SamplingConfig( + training_config = OnPolicyTrainingConfig( num_epochs=1, step_per_epoch=5000, batch_size=64, @@ -72,8 +72,8 @@ def main( env_factory = MujocoEnvFactory( task, - train_seed=sampling_config.train_seed, - test_seed=sampling_config.test_seed, + 
train_seed=training_config.train_seed, + test_seed=training_config.test_seed, obs_norm=True, ) @@ -95,7 +95,7 @@ def main( raise ValueError(f"Unknown logger type: {logger_type}") experiment_collection = ( - PPOExperimentBuilder(env_factory, experiment_config, sampling_config) + PPOExperimentBuilder(env_factory, experiment_config, training_config) .with_ppo_params( PPOParams( discount_factor=0.99, @@ -111,7 +111,7 @@ def main( dual_clip=None, recompute_advantage=True, lr=3e-4, - lr_scheduler=LRSchedulerFactoryFactoryLinear(sampling_config), + lr_scheduler=LRSchedulerFactoryFactoryLinear(training_config), ), ) .with_actor_factory_default(hidden_sizes, torch.nn.Tanh, continuous_unbounded=True) diff --git a/examples/mujoco/mujoco_redq_hl.py b/examples/mujoco/mujoco_redq_hl.py index c0c63279a..24ab3a073 100644 --- a/examples/mujoco/mujoco_redq_hl.py +++ b/examples/mujoco/mujoco_redq_hl.py @@ -8,7 +8,7 @@ from sensai.util.logging import datetime_tag from examples.mujoco.mujoco_env import MujocoEnvFactory -from tianshou.highlevel.config import SamplingConfig +from tianshou.highlevel.config import OffPolicyTrainingConfig from tianshou.highlevel.experiment import ( ExperimentConfig, REDQExperimentBuilder, @@ -44,7 +44,7 @@ def main( ) -> None: log_name = os.path.join(task, "redq", str(experiment_config.seed), datetime_tag()) - sampling_config = SamplingConfig( + training_config = OffPolicyTrainingConfig( num_epochs=epoch, step_per_epoch=step_per_epoch, batch_size=batch_size, @@ -53,20 +53,19 @@ def main( buffer_size=buffer_size, step_per_collect=step_per_collect, update_per_step=update_per_step, - repeat_per_collect=None, start_timesteps=start_timesteps, start_timesteps_random=True, ) env_factory = MujocoEnvFactory( task, - train_seed=sampling_config.train_seed, - test_seed=sampling_config.test_seed, + train_seed=training_config.train_seed, + test_seed=training_config.test_seed, obs_norm=False, ) experiment = ( - REDQExperimentBuilder(env_factory, experiment_config, 
sampling_config) + REDQExperimentBuilder(env_factory, experiment_config, training_config) .with_redq_params( REDQParams( actor_lr=actor_lr, diff --git a/examples/mujoco/mujoco_reinforce_hl.py b/examples/mujoco/mujoco_reinforce_hl.py index cac066c25..accf7cbed 100644 --- a/examples/mujoco/mujoco_reinforce_hl.py +++ b/examples/mujoco/mujoco_reinforce_hl.py @@ -9,7 +9,7 @@ from sensai.util.logging import datetime_tag from examples.mujoco.mujoco_env import MujocoEnvFactory -from tianshou.highlevel.config import SamplingConfig +from tianshou.highlevel.config import OnPolicyTrainingConfig from tianshou.highlevel.experiment import ( ExperimentConfig, PGExperimentBuilder, @@ -38,7 +38,7 @@ def main( ) -> None: log_name = os.path.join(task, "reinforce", str(experiment_config.seed), datetime_tag()) - sampling_config = SamplingConfig( + training_config = OnPolicyTrainingConfig( num_epochs=epoch, step_per_epoch=step_per_epoch, batch_size=batch_size, @@ -51,20 +51,20 @@ def main( env_factory = MujocoEnvFactory( task, - train_seed=sampling_config.train_seed, - test_seed=sampling_config.test_seed, + train_seed=training_config.train_seed, + test_seed=training_config.test_seed, obs_norm=True, ) experiment = ( - PGExperimentBuilder(env_factory, experiment_config, sampling_config) + PGExperimentBuilder(env_factory, experiment_config, training_config) .with_pg_params( PGParams( discount_factor=gamma, action_bound_method=action_bound_method, reward_normalization=rew_norm, lr=lr, - lr_scheduler=LRSchedulerFactoryFactoryLinear(sampling_config) if lr_decay else None, + lr_scheduler=LRSchedulerFactoryFactoryLinear(training_config) if lr_decay else None, ), ) .with_actor_factory_default(hidden_sizes, torch.nn.Tanh, continuous_unbounded=True) diff --git a/examples/mujoco/mujoco_sac_hl.py b/examples/mujoco/mujoco_sac_hl.py index a150f5571..6e4af1c91 100644 --- a/examples/mujoco/mujoco_sac_hl.py +++ b/examples/mujoco/mujoco_sac_hl.py @@ -7,7 +7,7 @@ from sensai.util.logging import datetime_tag 
from examples.mujoco.mujoco_env import MujocoEnvFactory -from tianshou.highlevel.config import SamplingConfig +from tianshou.highlevel.config import OffPolicyTrainingConfig from tianshou.highlevel.experiment import ( ExperimentConfig, SACExperimentBuilder, @@ -40,7 +40,7 @@ def main( ) -> None: log_name = os.path.join(task, "sac", str(experiment_config.seed), datetime_tag()) - sampling_config = SamplingConfig( + training_config = OffPolicyTrainingConfig( num_epochs=epoch, step_per_epoch=step_per_epoch, num_train_envs=training_num, @@ -55,13 +55,13 @@ def main( env_factory = MujocoEnvFactory( task, - train_seed=sampling_config.train_seed, - test_seed=sampling_config.test_seed, + train_seed=training_config.train_seed, + test_seed=training_config.test_seed, obs_norm=False, ) experiment = ( - SACExperimentBuilder(env_factory, experiment_config, sampling_config) + SACExperimentBuilder(env_factory, experiment_config, training_config) .with_sac_params( SACParams( tau=tau, diff --git a/examples/mujoco/mujoco_td3_hl.py b/examples/mujoco/mujoco_td3_hl.py index 5ec9cc17b..3f9afb237 100644 --- a/examples/mujoco/mujoco_td3_hl.py +++ b/examples/mujoco/mujoco_td3_hl.py @@ -8,7 +8,7 @@ from sensai.util.logging import datetime_tag from examples.mujoco.mujoco_env import MujocoEnvFactory -from tianshou.highlevel.config import SamplingConfig +from tianshou.highlevel.config import OffPolicyTrainingConfig from tianshou.highlevel.experiment import ( ExperimentConfig, TD3ExperimentBuilder, @@ -45,7 +45,7 @@ def main( ) -> None: log_name = os.path.join(task, "td3", str(experiment_config.seed), datetime_tag()) - sampling_config = SamplingConfig( + training_config = OffPolicyTrainingConfig( num_epochs=epoch, step_per_epoch=step_per_epoch, num_train_envs=training_num, @@ -60,13 +60,13 @@ def main( env_factory = MujocoEnvFactory( task, - train_seed=sampling_config.train_seed, - test_seed=sampling_config.test_seed, + train_seed=training_config.train_seed, + test_seed=training_config.test_seed, obs_norm=False, )
experiment = ( - TD3ExperimentBuilder(env_factory, experiment_config, sampling_config) + TD3ExperimentBuilder(env_factory, experiment_config, training_config) .with_td3_params( TD3Params( tau=tau, diff --git a/examples/mujoco/mujoco_trpo_hl.py b/examples/mujoco/mujoco_trpo_hl.py index 2aa2a24ca..c2518d07e 100644 --- a/examples/mujoco/mujoco_trpo_hl.py +++ b/examples/mujoco/mujoco_trpo_hl.py @@ -9,7 +9,7 @@ from sensai.util.logging import datetime_tag from examples.mujoco.mujoco_env import MujocoEnvFactory -from tianshou.highlevel.config import SamplingConfig +from tianshou.highlevel.config import OnPolicyTrainingConfig from tianshou.highlevel.experiment import ( ExperimentConfig, TRPOExperimentBuilder, @@ -44,7 +44,7 @@ def main( ) -> None: log_name = os.path.join(task, "trpo", str(experiment_config.seed), datetime_tag()) - sampling_config = SamplingConfig( + training_config = OnPolicyTrainingConfig( num_epochs=epoch, step_per_epoch=step_per_epoch, batch_size=batch_size, @@ -57,13 +57,13 @@ def main( env_factory = MujocoEnvFactory( task, - train_seed=sampling_config.train_seed, - test_seed=sampling_config.test_seed, + train_seed=training_config.train_seed, + test_seed=training_config.test_seed, obs_norm=True, ) experiment = ( - TRPOExperimentBuilder(env_factory, experiment_config, sampling_config) + TRPOExperimentBuilder(env_factory, experiment_config, training_config) .with_trpo_params( TRPOParams( discount_factor=gamma, @@ -76,7 +76,7 @@ def main( backtrack_coeff=backtrack_coeff, max_backtracks=max_backtracks, lr=lr, - lr_scheduler=LRSchedulerFactoryFactoryLinear(sampling_config) if lr_decay else None, + lr_scheduler=LRSchedulerFactoryFactoryLinear(training_config) if lr_decay else None, ), ) .with_actor_factory_default(hidden_sizes, torch.nn.Tanh, continuous_unbounded=True) diff --git a/test/highlevel/test_experiment_builder.py b/test/highlevel/test_experiment_builder.py index 5e61ac832..29eaa8074 100644 --- a/test/highlevel/test_experiment_builder.py +++ 
b/test/highlevel/test_experiment_builder.py @@ -2,7 +2,10 @@ import pytest -from tianshou.highlevel.config import SamplingConfig +from tianshou.highlevel.config import ( + OffPolicyTrainingConfig, + OnPolicyTrainingConfig, +) from tianshou.highlevel.experiment import ( A2CExperimentBuilder, DDPGExperimentBuilder, @@ -11,6 +14,8 @@ ExperimentBuilder, ExperimentConfig, IQNExperimentBuilder, + OffPolicyExperimentBuilder, + OnPolicyExperimentBuilder, PGExperimentBuilder, PPOExperimentBuilder, REDQExperimentBuilder, @@ -20,6 +25,27 @@ ) +def create_training_config( + builder_cls: type[ExperimentBuilder], + num_epochs: int = 1, + step_per_epoch: int = 100, + num_train_envs: int = 2, + num_test_envs: int = 2, +) -> OffPolicyTrainingConfig | OnPolicyTrainingConfig: + if issubclass(builder_cls, OffPolicyExperimentBuilder): + cfg_class = OffPolicyTrainingConfig + elif issubclass(builder_cls, OnPolicyExperimentBuilder): + cfg_class = OnPolicyTrainingConfig + else: + raise ValueError + return cfg_class( + num_epochs=num_epochs, + step_per_epoch=step_per_epoch, + num_train_envs=num_train_envs, + num_test_envs=num_test_envs, + ) + + @pytest.mark.parametrize( "builder_cls", [ @@ -36,7 +62,8 @@ ) def test_experiment_builder_continuous_default_params(builder_cls: type[ExperimentBuilder]) -> None: env_factory = ContinuousTestEnvFactory() - sampling_config = SamplingConfig( + training_config = create_training_config( + builder_cls, num_epochs=1, step_per_epoch=100, num_train_envs=2, @@ -46,7 +73,7 @@ def test_experiment_builder_continuous_default_params(builder_cls: type[Experime builder = builder_cls( experiment_config=experiment_config, env_factory=env_factory, - sampling_config=sampling_config, + training_config=training_config, ) experiment = builder.build() experiment.run(run_name="test") @@ -66,7 +93,8 @@ def test_experiment_builder_continuous_default_params(builder_cls: type[Experime ) def test_experiment_builder_discrete_default_params(builder_cls: type[ExperimentBuilder]) -> 
None: env_factory = DiscreteTestEnvFactory() - sampling_config = SamplingConfig( + training_config = create_training_config( + builder_cls, num_epochs=1, step_per_epoch=100, num_train_envs=2, @@ -75,7 +103,7 @@ def test_experiment_builder_discrete_default_params(builder_cls: type[Experiment builder = builder_cls( experiment_config=ExperimentConfig(persistence_enabled=False), env_factory=env_factory, - sampling_config=sampling_config, + training_config=training_config, ) experiment = builder.build() experiment.run(run_name="test") diff --git a/tianshou/highlevel/algorithm.py b/tianshou/highlevel/algorithm.py index 2f5824db0..1ab4d597f 100644 --- a/tianshou/highlevel/algorithm.py +++ b/tianshou/highlevel/algorithm.py @@ -9,7 +9,11 @@ from tianshou.data import Collector, ReplayBuffer, VectorReplayBuffer from tianshou.data.collector import BaseCollector, CollectStats -from tianshou.highlevel.config import SamplingConfig +from tianshou.highlevel.config import ( + OffPolicyTrainingConfig, + OnPolicyTrainingConfig, + TrainingConfig, +) from tianshou.highlevel.env import Environments from tianshou.highlevel.module.actor import ( ActorFactory, @@ -90,14 +94,15 @@ ) TAlgorithm = TypeVar("TAlgorithm", bound=Algorithm) TPolicy = TypeVar("TPolicy", bound=Policy) +TTrainingConfig = TypeVar("TTrainingConfig", bound=TrainingConfig) log = logging.getLogger(__name__) -class AlgorithmFactory(ABC, ToStringMixin): +class AlgorithmFactory(ABC, ToStringMixin, Generic[TTrainingConfig]): """Factory for the creation of an :class:`Algorithm` instance, its policy, trainer as well as collectors.""" - def __init__(self, sampling_config: SamplingConfig, optim_factory: OptimizerFactoryFactory): - self.sampling_config = sampling_config + def __init__(self, training_config: TTrainingConfig, optim_factory: OptimizerFactoryFactory): + self.training_config = training_config self.optim_factory = optim_factory self.algorithm_wrapper_factory: AlgorithmWrapperFactory | None = None self.trainer_callbacks: 
TrainerCallbacks = TrainerCallbacks() @@ -114,23 +119,23 @@ def create_train_test_collector( Setting to True means that the envs will be reset as well. :return: """ - buffer_size = self.sampling_config.buffer_size + buffer_size = self.training_config.buffer_size train_envs = envs.train_envs buffer: ReplayBuffer if len(train_envs) > 1: buffer = VectorReplayBuffer( buffer_size, len(train_envs), - stack_num=self.sampling_config.replay_buffer_stack_num, - save_only_last_obs=self.sampling_config.replay_buffer_save_only_last_obs, - ignore_obs_next=self.sampling_config.replay_buffer_ignore_obs_next, + stack_num=self.training_config.replay_buffer_stack_num, + save_only_last_obs=self.training_config.replay_buffer_save_only_last_obs, + ignore_obs_next=self.training_config.replay_buffer_ignore_obs_next, ) else: buffer = ReplayBuffer( buffer_size, - stack_num=self.sampling_config.replay_buffer_stack_num, - save_only_last_obs=self.sampling_config.replay_buffer_save_only_last_obs, - ignore_obs_next=self.sampling_config.replay_buffer_ignore_obs_next, + stack_num=self.training_config.replay_buffer_stack_num, + save_only_last_obs=self.training_config.replay_buffer_save_only_last_obs, + ignore_obs_next=self.training_config.replay_buffer_ignore_obs_next, ) train_collector = Collector[CollectStats]( policy, @@ -180,13 +185,13 @@ def create_trainer(self, world: World, policy_persistence: PolicyPersistence) -> pass -class OnPolicyAlgorithmFactory(AlgorithmFactory, ABC): +class OnPolicyAlgorithmFactory(AlgorithmFactory[OnPolicyTrainingConfig], ABC): def create_trainer( self, world: World, policy_persistence: PolicyPersistence, ) -> OnPolicyTrainer: - sampling_config = self.sampling_config + training_config = self.training_config callbacks = self.trainer_callbacks context = TrainingContext(world.policy, world.envs, world.logger) train_fn = ( @@ -209,16 +214,16 @@ def create_trainer( OnPolicyTrainerParams( train_collector=world.train_collector, test_collector=world.test_collector, - 
max_epoch=sampling_config.num_epochs, - step_per_epoch=sampling_config.step_per_epoch, - repeat_per_collect=sampling_config.repeat_per_collect, - episode_per_test=sampling_config.num_test_episodes, - batch_size=sampling_config.batch_size, - step_per_collect=sampling_config.step_per_collect, + max_epoch=training_config.num_epochs, + step_per_epoch=training_config.step_per_epoch, + repeat_per_collect=training_config.repeat_per_collect, + episode_per_test=training_config.num_test_episodes, + batch_size=training_config.batch_size, + step_per_collect=training_config.step_per_collect, save_best_fn=policy_persistence.get_save_best_fn(world), save_checkpoint_fn=policy_persistence.get_save_checkpoint_fn(world), logger=world.logger, - test_in_train=False, + test_in_train=training_config.test_in_train, train_fn=train_fn, test_fn=test_fn, stop_fn=stop_fn, @@ -227,13 +232,13 @@ def create_trainer( ) -class OffPolicyAlgorithmFactory(AlgorithmFactory, ABC): +class OffPolicyAlgorithmFactory(AlgorithmFactory[OffPolicyTrainingConfig], ABC): def create_trainer( self, world: World, policy_persistence: PolicyPersistence, ) -> OffPolicyTrainer: - sampling_config = self.sampling_config + training_config = self.training_config callbacks = self.trainer_callbacks context = TrainingContext(world.policy, world.envs, world.logger) train_fn = ( @@ -256,15 +261,15 @@ def create_trainer( OffPolicyTrainerParams( train_collector=world.train_collector, test_collector=world.test_collector, - max_epoch=sampling_config.num_epochs, - step_per_epoch=sampling_config.step_per_epoch, - step_per_collect=sampling_config.step_per_collect, - episode_per_test=sampling_config.num_test_episodes, - batch_size=sampling_config.batch_size, + max_epoch=training_config.num_epochs, + step_per_epoch=training_config.step_per_epoch, + step_per_collect=training_config.step_per_collect, + episode_per_test=training_config.num_test_episodes, + batch_size=training_config.batch_size, 
save_best_fn=policy_persistence.get_save_best_fn(world), logger=world.logger, - update_per_step=sampling_config.update_per_step, - test_in_train=False, + update_per_step=training_config.update_per_step, + test_in_train=training_config.test_in_train, train_fn=train_fn, test_fn=test_fn, stop_fn=stop_fn, @@ -277,11 +282,11 @@ class ReinforceAlgorithmFactory(OnPolicyAlgorithmFactory): def __init__( self, params: PGParams, - sampling_config: SamplingConfig, + training_config: OnPolicyTrainingConfig, actor_factory: ActorFactory, optim_factory: OptimizerFactoryFactory, ): - super().__init__(sampling_config, optim_factory) + super().__init__(training_config, optim_factory) self.params = params self.actor_factory = actor_factory self.optim_factory = optim_factory @@ -312,20 +317,19 @@ def _create_algorithm(self, envs: Environments, device: TDevice) -> Reinforce: ) -class ActorCriticAlgorithmFactory( - Generic[TActorCriticParams, TAlgorithm], +class ActorCriticOnPolicyAlgorithmFactory( OnPolicyAlgorithmFactory, - ABC, + Generic[TActorCriticParams, TAlgorithm], ): def __init__( self, params: TActorCriticParams, - sampling_config: SamplingConfig, + training_config: OnPolicyTrainingConfig, actor_factory: ActorFactory, critic_factory: CriticFactory, optimizer_factory: OptimizerFactoryFactory, ): - super().__init__(sampling_config, optim_factory=optimizer_factory) + super().__init__(training_config, optim_factory=optimizer_factory) self.params = params self.actor_factory = actor_factory self.critic_factory = critic_factory @@ -373,38 +377,38 @@ def _create_algorithm(self, envs: Environments, device: TDevice) -> TAlgorithm: return algorithm_class(policy=policy, **params) -class A2CAlgorithmFactory(ActorCriticAlgorithmFactory[A2CParams, A2C]): +class A2CAlgorithmFactory(ActorCriticOnPolicyAlgorithmFactory[A2CParams, A2C]): def _get_algorithm_class(self) -> type[A2C]: return A2C -class PPOAlgorithmFactory(ActorCriticAlgorithmFactory[PPOParams, PPO]): +class 
PPOAlgorithmFactory(ActorCriticOnPolicyAlgorithmFactory[PPOParams, PPO]): def _get_algorithm_class(self) -> type[PPO]: return PPO -class NPGAlgorithmFactory(ActorCriticAlgorithmFactory[NPGParams, NPG]): +class NPGAlgorithmFactory(ActorCriticOnPolicyAlgorithmFactory[NPGParams, NPG]): def _get_algorithm_class(self) -> type[NPG]: return NPG -class TRPOAlgorithmFactory(ActorCriticAlgorithmFactory[TRPOParams, TRPO]): +class TRPOAlgorithmFactory(ActorCriticOnPolicyAlgorithmFactory[TRPOParams, TRPO]): def _get_algorithm_class(self) -> type[TRPO]: return TRPO -class DiscreteCriticOnlyAlgorithmFactory( +class DiscreteCriticOnlyOffPolicyAlgorithmFactory( OffPolicyAlgorithmFactory, Generic[TDiscreteCriticOnlyParams, TAlgorithm], ): def __init__( self, params: TDiscreteCriticOnlyParams, - sampling_config: SamplingConfig, + training_config: OffPolicyTrainingConfig, model_factory: ModuleFactory, optim_factory: OptimizerFactoryFactory, ): - super().__init__(sampling_config, optim_factory) + super().__init__(training_config, optim_factory) self.params = params self.model_factory = model_factory self.optim_factory = optim_factory @@ -443,7 +447,7 @@ def _create_algorithm(self, envs: Environments, device: TDevice) -> TAlgorithm: ) -class DQNAlgorithmFactory(DiscreteCriticOnlyAlgorithmFactory[DQNParams, DQN]): +class DQNAlgorithmFactory(DiscreteCriticOnlyOffPolicyAlgorithmFactory[DQNParams, DQN]): def _create_policy( self, model: torch.nn.Module, @@ -464,7 +468,7 @@ def _get_algorithm_class(self) -> type[DQN]: return DQN -class IQNAlgorithmFactory(DiscreteCriticOnlyAlgorithmFactory[IQNParams, IQN]): +class IQNAlgorithmFactory(DiscreteCriticOnlyOffPolicyAlgorithmFactory[IQNParams, IQN]): def _create_policy( self, model: torch.nn.Module, @@ -490,12 +494,12 @@ class DDPGAlgorithmFactory(OffPolicyAlgorithmFactory): def __init__( self, params: DDPGParams, - sampling_config: SamplingConfig, + training_config: OffPolicyTrainingConfig, actor_factory: ActorFactory, critic_factory: 
CriticFactory, optim_factory: OptimizerFactoryFactory, ): - super().__init__(sampling_config, optim_factory) + super().__init__(training_config, optim_factory) self.critic_factory = critic_factory self.actor_factory = actor_factory self.params = params @@ -534,12 +538,12 @@ class REDQAlgorithmFactory(OffPolicyAlgorithmFactory): def __init__( self, params: REDQParams, - sampling_config: SamplingConfig, + training_config: OffPolicyTrainingConfig, actor_factory: ActorFactory, critic_ensemble_factory: CriticEnsembleFactory, optim_factory: OptimizerFactoryFactory, ): - super().__init__(sampling_config, optim_factory) + super().__init__(training_config, optim_factory) self.critic_ensemble_factory = critic_ensemble_factory self.actor_factory = actor_factory self.params = params @@ -580,21 +584,20 @@ def _create_algorithm(self, envs: Environments, device: TDevice) -> Algorithm: ) -class ActorDualCriticsAlgorithmFactory( +class ActorDualCriticsOffPolicyAlgorithmFactory( OffPolicyAlgorithmFactory, Generic[TActorDualCriticsParams, TAlgorithm, TPolicy], - ABC, ): def __init__( self, params: TActorDualCriticsParams, - sampling_config: SamplingConfig, + training_config: OffPolicyTrainingConfig, actor_factory: ActorFactory, critic1_factory: CriticFactory, critic2_factory: CriticFactory, optim_factory: OptimizerFactoryFactory, ): - super().__init__(sampling_config, optim_factory) + super().__init__(training_config, optim_factory) self.params = params self.actor_factory = actor_factory self.critic1_factory = critic1_factory @@ -652,7 +655,7 @@ def _create_algorithm(self, envs: Environments, device: TDevice) -> TAlgorithm: ) -class SACAlgorithmFactory(ActorDualCriticsAlgorithmFactory[SACParams, SAC, TPolicy]): +class SACAlgorithmFactory(ActorDualCriticsOffPolicyAlgorithmFactory[SACParams, SAC, TPolicy]): def _create_policy( self, actor: torch.nn.Module | Actor, envs: Environments, params: dict ) -> SACPolicy: @@ -670,7 +673,7 @@ def _get_algorithm_class(self) -> type[SAC]: class 
DiscreteSACAlgorithmFactory( - ActorDualCriticsAlgorithmFactory[DiscreteSACParams, DiscreteSAC, TPolicy] + ActorDualCriticsOffPolicyAlgorithmFactory[DiscreteSACParams, DiscreteSAC, TPolicy] ): def _create_policy( self, actor: torch.nn.Module | Actor, envs: Environments, params: dict @@ -688,7 +691,7 @@ def _get_algorithm_class(self) -> type[DiscreteSAC]: return DiscreteSAC -class TD3AlgorithmFactory(ActorDualCriticsAlgorithmFactory[TD3Params, TD3, DDPGPolicy]): +class TD3AlgorithmFactory(ActorDualCriticsOffPolicyAlgorithmFactory[TD3Params, TD3, DDPGPolicy]): def _create_policy( self, actor: torch.nn.Module | Actor, envs: Environments, params: dict ) -> DDPGPolicy: diff --git a/tianshou/highlevel/config.py b/tianshou/highlevel/config.py index ac27cba1a..485071658 100644 --- a/tianshou/highlevel/config.py +++ b/tianshou/highlevel/config.py @@ -7,31 +7,35 @@ log = logging.getLogger(__name__) -@dataclass -class SamplingConfig(ToStringMixin): - """Configuration of sampling, epochs, parallelization, buffers, collectors, and batching.""" +@dataclass(kw_only=True) +class TrainingConfig(ToStringMixin): + """Training configuration.""" num_epochs: int = 100 """ - the number of epochs to run training for. An epoch is the outermost iteration level and each - epoch consists of a number of training steps and a test step, where each training step + the (maximum) number of epochs to run training for. 
An **epoch** is the outermost iteration level and each + epoch consists of a number of training steps and one test step, where each training step - * collects environment steps/transitions (collection step), adding them to the (replay) - buffer (see :attr:`step_per_collect`) - * performs one or more gradient updates (see :attr:`update_per_step`), + * [for the online case] collects environment steps/transitions (**collection step**), + adding them to the (replay) buffer (see :attr:`step_per_collect` and :attr:`episode_per_collect`) + * performs an **update step** via the RL algorithm being used, which can involve + one or more actual gradient updates, depending on the algorithm and the test step collects :attr:`num_episodes_per_test` test episodes in order to evaluate agent performance. - The number of training steps in each epoch is indirectly determined by + Training may be stopped early if the stop criterion is met (see :attr:`stop_fn`). + + For online training, the number of training steps in each epoch is indirectly determined by :attr:`step_per_epoch`: As many training steps will be performed as are required in order to reach :attr:`step_per_epoch` total steps in the training environments. Specifically, if the number of transitions collected per step is `c` (see :attr:`step_per_collect`) and :attr:`step_per_epoch` is set to `s`, then the number of training steps per epoch is `ceil(s / c)`. - Therefore, if `num_epochs = e`, the total number of environment steps taken during training can be computed as `e * ceil(s / c) * c`. + + For offline training, the number of training steps per epoch is equal to :attr:`step_per_epoch`. """ step_per_epoch: int = 30000 @@ -40,19 +44,6 @@ class SamplingConfig(ToStringMixin): an explanation of epoch semantics. 
""" - batch_size: int | None = 64 - """for off-policy algorithms, this is the number of environment steps/transitions to sample - from the buffer for a gradient update; for on-policy algorithms, its use is algorithm-specific. - On-policy algorithms use the full buffer that was collected in the preceding collection step - but they may use this parameter to perform the gradient update using mini-batches of this size - (causing the gradient to be less accurate, a form of regularization). - - ``batch_size=None`` means that the full buffer is used for the gradient update. This doesn't - make much sense for off-policy algorithms and is not recommended then. For on-policy or offline algorithms, - this means that the full buffer is used for the gradient update (no mini-batching), and - may make sense in some cases. - """ - num_train_envs: int = -1 """the number of training environments to use. If set to -1, use number of CPUs/threads.""" @@ -96,29 +87,6 @@ class SamplingConfig(ToStringMixin): This is mutually exclusive with :attr:`step_per_collect`, and one of the two must be set. """ - repeat_per_collect: int | None = 1 - """ - controls, within one gradient update step of an on-policy algorithm, the number of times an - actual gradient update is applied using the full collected dataset, i.e. if the parameter is - 5, then the collected data shall be used five times to update the policy within the same - training step. - - The parameter is ignored and may be set to None for off-policy and offline algorithms. - """ - - update_per_step: float = 1.0 - """ - for off-policy algorithms only: the number of gradient steps to perform per sample - collected (see :attr:`step_per_collect`). - Specifically, if this is set to `u` and the number of samples collected in the preceding - collection step is `n`, then `round(u * n)` gradient steps will be performed. 
- - Note that for on-policy algorithms, only a single gradient update is usually performed, - because thereafter, the samples no longer reflect the behavior of the updated policy. - To change the number of gradient updates for an on-policy algorithm, use parameter - :attr:`repeat_per_collect` instead. - """ - start_timesteps: int = 0 """ the number of environment steps to collect before the actual training loop begins @@ -190,3 +158,80 @@ def __post_init__(self) -> None: assert ( sum([self.step_per_collect is not None, self.episode_per_collect is not None]) == 1 ), ("Only one of `step_per_collect` and `episode_per_collect` can be set.",) + + +@dataclass(kw_only=True) +class OnlineTrainingConfig(TrainingConfig): + step_per_collect: int | None = 2048 + """ + the number of environment steps/transitions to collect in each collection step before the + network update within each training step. + + This is mutually exclusive with :attr:`episode_per_collect`, and one of the two must be set. + + Note that the exact number can be reached only if this is a multiple of the number of + training environments being used, as each training environment will produce the same + (non-zero) number of transitions. + Specifically, if this is set to `n` and `m` training environments are used, then the total + number of transitions collected per collection step is `ceil(n / m) * m =: c`. + + See :attr:`num_epochs` for information on the total number of environment steps being + collected during training. + """ + + episode_per_collect: int | None = None + """ + the number of episodes to collect in each collection step before the network update within + each training step. If this is set, the number of environment steps collected in each + collection step is the sum of the lengths of the episodes collected. + + This is mutually exclusive with :attr:`step_per_collect`, and one of the two must be set. 
+ """ + + test_in_train: bool = False + """ + Whether to apply a test step within a training step depending on the early stopping criterion + (see :meth:`~tianshou.highlevel.Experiment.with_epoch_stop_callback`) being satisfied based + on the data collected within the training step. + Specifically, after each collect step, we check whether the early stopping criterion + would be satisfied by data we collected (provided that at least one episode was indeed completed, such + that we can evaluate returns, etc.). If the criterion is satisfied, we perform a full test step + (collecting :attr:`num_test_episodes` episodes in order to evaluate performance), and if the early + stopping criterion is also satisfied based on the test data, we stop training early. + """ + + +@dataclass(kw_only=True) +class OnPolicyTrainingConfig(OnlineTrainingConfig): + batch_size: int | None = 64 + """ + Use mini-batches of this size for gradient updates (causing the gradient to be less accurate, + a form of regularization). + Set ``batch_size=None`` for the full buffer that was collected within the training step to be + used for the gradient update (no mini-batching). + """ + + repeat_per_collect: int = 1 + """ + controls, within one update step of an on-policy algorithm, the number of times + the full collected data is applied for gradient updates, i.e. if the parameter is + 5, then the collected data shall be used five times to update the policy within the same + update step. + """ + + +@dataclass(kw_only=True) +class OffPolicyTrainingConfig(OnlineTrainingConfig): + batch_size: int = 64 + """ + the number of environment steps/transitions to sample from the buffer for a gradient update. + """ + + # TODO: Given our glossary, this is confusingly named. Should definitely contain the word "gradient"; + # also in corresponding TrainerParams object + update_per_step: float = 1.0 + """ + the number of gradient steps to perform per sample collected (see :attr:`step_per_collect`).
+ Specifically, if this is set to `u` and the number of samples collected in the preceding + collection step is `n`, then `round(u * n)` gradient steps will be performed. + """ diff --git a/tianshou/highlevel/experiment.py b/tianshou/highlevel/experiment.py index e252e2142..629159005 100644 --- a/tianshou/highlevel/experiment.py +++ b/tianshou/highlevel/experiment.py @@ -25,7 +25,7 @@ from copy import deepcopy from dataclasses import asdict, dataclass from pprint import pformat -from typing import TYPE_CHECKING, Any, Self, Union, cast +from typing import TYPE_CHECKING, Any, Generic, Self, Union, cast if TYPE_CHECKING: from tianshou.evaluation.launcher import ExpLauncher, RegisteredExpLauncher @@ -52,8 +52,13 @@ SACAlgorithmFactory, TD3AlgorithmFactory, TRPOAlgorithmFactory, + TTrainingConfig, +) +from tianshou.highlevel.config import ( + OffPolicyTrainingConfig, + OnPolicyTrainingConfig, + TrainingConfig, ) -from tianshou.highlevel.config import SamplingConfig from tianshou.highlevel.env import EnvFactory from tianshou.highlevel.logger import LoggerFactory, LoggerFactoryDefault, TLogger from tianshou.highlevel.module.actor import ( @@ -185,14 +190,14 @@ def __init__( config: ExperimentConfig, env_factory: EnvFactory, algorithm_factory: AlgorithmFactory, - sampling_config: SamplingConfig, + training_config: TrainingConfig, name: str, logger_factory: LoggerFactory | None = None, ): if logger_factory is None: logger_factory = LoggerFactoryDefault() self.config = config - self.sampling_config = sampling_config + self.training_config = training_config self.env_factory = env_factory self.algorithm_factory = algorithm_factory self.logger_factory = logger_factory @@ -221,8 +226,8 @@ def get_seeding_info_as_str(self) -> str: return "_".join( [ f"exp_seed={self.config.seed}", - f"train_seed={self.sampling_config.train_seed}", - f"test_seed={self.sampling_config.test_seed}", + f"train_seed={self.training_config.train_seed}", + f"test_seed={self.training_config.test_seed}", ], 
) @@ -294,8 +299,8 @@ def create_experiment_world( # create environments envs = self.env_factory.create_envs( - self.sampling_config.num_train_envs, - self.sampling_config.num_test_envs, + self.training_config.num_train_envs, + self.training_config.num_test_envs, create_watch_env=self.config.watch, ) log.info(f"Created {envs}") @@ -315,7 +320,7 @@ def create_experiment_world( full_config = self._build_config_dict() full_config.update(envs.info()) full_config["experiment_config"] = asdict(self.config) - full_config["sampling_config"] = asdict(self.sampling_config) + full_config["training_config"] = asdict(self.training_config) with suppress(AttributeError): full_config["policy_params"] = asdict(self.algorithm_factory.params) @@ -426,14 +431,14 @@ def run( assert world.test_collector is not None # prefilling buffers with either random or current agent's actions - if self.sampling_config.start_timesteps > 0: + if self.training_config.start_timesteps > 0: log.info( - f"Collecting {self.sampling_config.start_timesteps} initial environment " - f"steps before training (random={self.sampling_config.start_timesteps_random})", + f"Collecting {self.training_config.start_timesteps} initial environment " + f"steps before training (random={self.training_config.start_timesteps_random})", ) world.train_collector.collect( - n_step=self.sampling_config.start_timesteps, - random=self.sampling_config.start_timesteps_random, + n_step=self.training_config.start_timesteps, + random=self.training_config.start_timesteps_random, ) log.info("Starting training") @@ -489,7 +494,7 @@ def run( return launcher.launch(experiments=self.experiments) -class ExperimentBuilder(ABC): +class ExperimentBuilder(ABC, Generic[TTrainingConfig]): """A helper class (following the builder pattern) for creating experiments.
It contains a lot of defaults for the setup which can be adjusted using the @@ -502,28 +507,31 @@ def __init__( self, env_factory: EnvFactory, experiment_config: ExperimentConfig | None = None, - sampling_config: SamplingConfig | None = None, + training_config: TTrainingConfig | None = None, ): """:param env_factory: controls how environments are to be created. :param experiment_config: the configuration for the experiment. If None, will use the default values of `ExperimentConfig`. - :param sampling_config: the sampling configuration to use. If None, will use the default values - of `SamplingConfig`. + :param training_config: the training configuration to use. If None, use default values (not recommended). """ if experiment_config is None: experiment_config = ExperimentConfig() - if sampling_config is None: - sampling_config = SamplingConfig() + if training_config is None: + training_config = self._create_training_config() self._config = experiment_config self._env_factory = env_factory - self._sampling_config = sampling_config + self._training_config = training_config self._logger_factory: LoggerFactory | None = None self._optim_factory: OptimizerFactoryFactory | None = None self._algorithm_wrapper_factory: AlgorithmWrapperFactory | None = None self._trainer_callbacks: TrainerCallbacks = TrainerCallbacks() self._name: str = self.__class__.__name__.replace("Builder", "") + "_" + datetime_tag() + @abstractmethod + def _create_training_config(self) -> TTrainingConfig: + pass + def copy(self) -> Self: return deepcopy(self) @@ -536,12 +544,12 @@ def experiment_config(self, experiment_config: ExperimentConfig) -> None: self._config = experiment_config @property - def sampling_config(self) -> SamplingConfig: - return self._sampling_config + def sampling_config(self) -> TrainingConfig: + return self._training_config @sampling_config.setter - def sampling_config(self, sampling_config: SamplingConfig) -> None: - self._sampling_config = sampling_config + def 
sampling_config(self, sampling_config: TrainingConfig) -> None: + self._training_config = sampling_config def with_logger_factory(self, logger_factory: LoggerFactory) -> Self: """Allows to customize the logger factory to use. @@ -644,7 +652,7 @@ def build(self) -> Experiment: config=self._config, env_factory=self._env_factory, algorithm_factory=algorithm_factory, - sampling_config=self._sampling_config, + training_config=self._training_config, name=self._name, logger_factory=self._logger_factory, ) @@ -674,6 +682,44 @@ def build_seeded_collection(self, num_experiments: int) -> ExperimentCollection: return ExperimentCollection(seeded_experiments) +class OnPolicyExperimentBuilder(ExperimentBuilder[OnPolicyTrainingConfig], ABC): + def __init__( + self, + env_factory: EnvFactory, + experiment_config: ExperimentConfig | None = None, + training_config: OnPolicyTrainingConfig | None = None, + ): + """ + :param env_factory: controls how environments are to be created. + :param experiment_config: the configuration for the experiment. If None, will use the default values + of :class:`ExperimentConfig`. + :param training_config: the training configuration to use. If None, use default values (not recommended). + """ + super().__init__(env_factory, experiment_config, training_config) + + def _create_training_config(self) -> OnPolicyTrainingConfig: + return OnPolicyTrainingConfig() + + +class OffPolicyExperimentBuilder(ExperimentBuilder[OffPolicyTrainingConfig], ABC): + def __init__( + self, + env_factory: EnvFactory, + experiment_config: ExperimentConfig | None = None, + training_config: OffPolicyTrainingConfig | None = None, + ): + """ + :param env_factory: controls how environments are to be created. + :param experiment_config: the configuration for the experiment. If None, will use the default values + of :class:`ExperimentConfig`. + :param training_config: the training configuration to use. If None, use default values (not recommended). 
+ """ + super().__init__(env_factory, experiment_config, training_config) + + def _create_training_config(self) -> OffPolicyTrainingConfig: + return OffPolicyTrainingConfig() + + class _BuilderMixinActorFactory(ActorFutureProviderProtocol): def __init__(self, continuous_actor_type: ContinuousActorType): self._continuous_actor_type = continuous_actor_type @@ -1001,16 +1047,16 @@ def _get_critic_ensemble_factory(self) -> CriticEnsembleFactory: class PGExperimentBuilder( - ExperimentBuilder, + OnPolicyExperimentBuilder, _BuilderMixinActorFactory_ContinuousGaussian, ): def __init__( self, env_factory: EnvFactory, experiment_config: ExperimentConfig | None = None, - sampling_config: SamplingConfig | None = None, + training_config: OnPolicyTrainingConfig | None = None, ): - super().__init__(env_factory, experiment_config, sampling_config) + super().__init__(env_factory, experiment_config, training_config) _BuilderMixinActorFactory_ContinuousGaussian.__init__(self) self._params: PGParams = PGParams() self._env_config = None @@ -1022,14 +1068,14 @@ def with_pg_params(self, params: PGParams) -> Self: def _create_algorithm_factory(self) -> AlgorithmFactory: return ReinforceAlgorithmFactory( self._params, - self._sampling_config, + self._training_config, self._get_actor_factory(), self._get_optim_factory(), ) class A2CExperimentBuilder( - ExperimentBuilder, + OnPolicyExperimentBuilder, _BuilderMixinActorFactory_ContinuousGaussian, _BuilderMixinSingleCriticCanUseActorFactory, ): @@ -1037,9 +1083,9 @@ def __init__( self, env_factory: EnvFactory, experiment_config: ExperimentConfig | None = None, - sampling_config: SamplingConfig | None = None, + training_config: OnPolicyTrainingConfig | None = None, ): - super().__init__(env_factory, experiment_config, sampling_config) + super().__init__(env_factory, experiment_config, training_config) _BuilderMixinActorFactory_ContinuousGaussian.__init__(self) _BuilderMixinSingleCriticCanUseActorFactory.__init__(self, self) self._params: 
A2CParams = A2CParams() @@ -1052,7 +1098,7 @@ def with_a2c_params(self, params: A2CParams) -> Self: def _create_algorithm_factory(self) -> AlgorithmFactory: return A2CAlgorithmFactory( self._params, - self._sampling_config, + self._training_config, self._get_actor_factory(), self._get_critic_factory(0), self._get_optim_factory(), @@ -1060,7 +1106,7 @@ def _create_algorithm_factory(self) -> AlgorithmFactory: class PPOExperimentBuilder( - ExperimentBuilder, + OnPolicyExperimentBuilder, _BuilderMixinActorFactory_ContinuousGaussian, _BuilderMixinSingleCriticCanUseActorFactory, ): @@ -1068,9 +1114,9 @@ def __init__( self, env_factory: EnvFactory, experiment_config: ExperimentConfig | None = None, - sampling_config: SamplingConfig | None = None, + training_config: OnPolicyTrainingConfig | None = None, ): - super().__init__(env_factory, experiment_config, sampling_config) + super().__init__(env_factory, experiment_config, training_config) _BuilderMixinActorFactory_ContinuousGaussian.__init__(self) _BuilderMixinSingleCriticCanUseActorFactory.__init__(self, self) self._params: PPOParams = PPOParams() @@ -1082,7 +1128,7 @@ def with_ppo_params(self, params: PPOParams) -> Self: def _create_algorithm_factory(self) -> AlgorithmFactory: return PPOAlgorithmFactory( self._params, - self._sampling_config, + self._training_config, self._get_actor_factory(), self._get_critic_factory(0), self._get_optim_factory(), @@ -1090,7 +1136,7 @@ def _create_algorithm_factory(self) -> AlgorithmFactory: class NPGExperimentBuilder( - ExperimentBuilder, + OnPolicyExperimentBuilder, _BuilderMixinActorFactory_ContinuousGaussian, _BuilderMixinSingleCriticCanUseActorFactory, ): @@ -1098,9 +1144,9 @@ def __init__( self, env_factory: EnvFactory, experiment_config: ExperimentConfig | None = None, - sampling_config: SamplingConfig | None = None, + training_config: OnPolicyTrainingConfig | None = None, ): - super().__init__(env_factory, experiment_config, sampling_config) + super().__init__(env_factory, 
experiment_config, training_config) _BuilderMixinActorFactory_ContinuousGaussian.__init__(self) _BuilderMixinSingleCriticCanUseActorFactory.__init__(self, self) self._params: NPGParams = NPGParams() @@ -1112,7 +1158,7 @@ def with_npg_params(self, params: NPGParams) -> Self: def _create_algorithm_factory(self) -> AlgorithmFactory: return NPGAlgorithmFactory( self._params, - self._sampling_config, + self._training_config, self._get_actor_factory(), self._get_critic_factory(0), self._get_optim_factory(), @@ -1120,7 +1166,7 @@ def _create_algorithm_factory(self) -> AlgorithmFactory: class TRPOExperimentBuilder( - ExperimentBuilder, + OnPolicyExperimentBuilder, _BuilderMixinActorFactory_ContinuousGaussian, _BuilderMixinSingleCriticCanUseActorFactory, ): @@ -1128,9 +1174,9 @@ def __init__( self, env_factory: EnvFactory, experiment_config: ExperimentConfig | None = None, - sampling_config: SamplingConfig | None = None, + training_config: OnPolicyTrainingConfig | None = None, ): - super().__init__(env_factory, experiment_config, sampling_config) + super().__init__(env_factory, experiment_config, training_config) _BuilderMixinActorFactory_ContinuousGaussian.__init__(self) _BuilderMixinSingleCriticCanUseActorFactory.__init__(self, self) self._params: TRPOParams = TRPOParams() @@ -1142,7 +1188,7 @@ def with_trpo_params(self, params: TRPOParams) -> Self: def _create_algorithm_factory(self) -> AlgorithmFactory: return TRPOAlgorithmFactory( self._params, - self._sampling_config, + self._training_config, self._get_actor_factory(), self._get_critic_factory(0), self._get_optim_factory(), @@ -1150,15 +1196,15 @@ def _create_algorithm_factory(self) -> AlgorithmFactory: class DQNExperimentBuilder( - ExperimentBuilder, + OffPolicyExperimentBuilder, ): def __init__( self, env_factory: EnvFactory, experiment_config: ExperimentConfig | None = None, - sampling_config: SamplingConfig | None = None, + training_config: OffPolicyTrainingConfig | None = None, ): - super().__init__(env_factory, 
experiment_config, sampling_config) + super().__init__(env_factory, experiment_config, training_config) self._params: DQNParams = DQNParams() self._model_factory: IntermediateModuleFactory = IntermediateModuleFactoryFromActorFactory( ActorFactoryDefault(ContinuousActorType.UNSUPPORTED, discrete_softmax=False), @@ -1200,20 +1246,20 @@ def with_model_factory_default( def _create_algorithm_factory(self) -> AlgorithmFactory: return DQNAlgorithmFactory( self._params, - self._sampling_config, + self._training_config, self._model_factory, self._get_optim_factory(), ) -class IQNExperimentBuilder(ExperimentBuilder): +class IQNExperimentBuilder(OffPolicyExperimentBuilder): def __init__( self, env_factory: EnvFactory, experiment_config: ExperimentConfig | None = None, - sampling_config: SamplingConfig | None = None, + training_config: OffPolicyTrainingConfig | None = None, ): - super().__init__(env_factory, experiment_config, sampling_config) + super().__init__(env_factory, experiment_config, training_config) self._params: IQNParams = IQNParams() self._preprocess_network_factory: IntermediateModuleFactory = ( IntermediateModuleFactoryFromActorFactory( @@ -1237,14 +1283,14 @@ def _create_algorithm_factory(self) -> AlgorithmFactory: ) return IQNAlgorithmFactory( self._params, - self._sampling_config, + self._training_config, model_factory, self._get_optim_factory(), ) class DDPGExperimentBuilder( - ExperimentBuilder, + OffPolicyExperimentBuilder, _BuilderMixinActorFactory_ContinuousDeterministic, _BuilderMixinSingleCriticCanUseActorFactory, ): @@ -1252,9 +1298,9 @@ def __init__( self, env_factory: EnvFactory, experiment_config: ExperimentConfig | None = None, - sampling_config: SamplingConfig | None = None, + training_config: OffPolicyTrainingConfig | None = None, ): - super().__init__(env_factory, experiment_config, sampling_config) + super().__init__(env_factory, experiment_config, training_config) _BuilderMixinActorFactory_ContinuousDeterministic.__init__(self) 
_BuilderMixinSingleCriticCanUseActorFactory.__init__(self, self) self._params: DDPGParams = DDPGParams() @@ -1266,7 +1312,7 @@ def with_ddpg_params(self, params: DDPGParams) -> Self: def _create_algorithm_factory(self) -> AlgorithmFactory: return DDPGAlgorithmFactory( self._params, - self._sampling_config, + self._training_config, self._get_actor_factory(), self._get_critic_factory(0), self._get_optim_factory(), @@ -1274,7 +1320,7 @@ def _create_algorithm_factory(self) -> AlgorithmFactory: class REDQExperimentBuilder( - ExperimentBuilder, + OffPolicyExperimentBuilder, _BuilderMixinActorFactory_ContinuousGaussian, _BuilderMixinCriticEnsembleFactory, ): @@ -1282,9 +1328,9 @@ def __init__( self, env_factory: EnvFactory, experiment_config: ExperimentConfig | None = None, - sampling_config: SamplingConfig | None = None, + training_config: OffPolicyTrainingConfig | None = None, ): - super().__init__(env_factory, experiment_config, sampling_config) + super().__init__(env_factory, experiment_config, training_config) _BuilderMixinActorFactory_ContinuousGaussian.__init__(self) _BuilderMixinCriticEnsembleFactory.__init__(self) self._params: REDQParams = REDQParams() @@ -1296,7 +1342,7 @@ def with_redq_params(self, params: REDQParams) -> Self: def _create_algorithm_factory(self) -> AlgorithmFactory: return REDQAlgorithmFactory( self._params, - self._sampling_config, + self._training_config, self._get_actor_factory(), self._get_critic_ensemble_factory(), self._get_optim_factory(), @@ -1304,7 +1350,7 @@ def _create_algorithm_factory(self) -> AlgorithmFactory: class SACExperimentBuilder( - ExperimentBuilder, + OffPolicyExperimentBuilder, _BuilderMixinActorFactory_ContinuousGaussian, _BuilderMixinDualCriticFactory, ): @@ -1312,9 +1358,9 @@ def __init__( self, env_factory: EnvFactory, experiment_config: ExperimentConfig | None = None, - sampling_config: SamplingConfig | None = None, + training_config: OffPolicyTrainingConfig | None = None, ): - super().__init__(env_factory, 
experiment_config, sampling_config) + super().__init__(env_factory, experiment_config, training_config) _BuilderMixinActorFactory_ContinuousGaussian.__init__(self) _BuilderMixinDualCriticFactory.__init__(self, self) self._params: SACParams = SACParams() @@ -1326,7 +1372,7 @@ def with_sac_params(self, params: SACParams) -> Self: def _create_algorithm_factory(self) -> AlgorithmFactory: return SACAlgorithmFactory( self._params, - self._sampling_config, + self._training_config, self._get_actor_factory(), self._get_critic_factory(0), self._get_critic_factory(1), @@ -1335,7 +1381,7 @@ def _create_algorithm_factory(self) -> AlgorithmFactory: class DiscreteSACExperimentBuilder( - ExperimentBuilder, + OffPolicyExperimentBuilder, _BuilderMixinActorFactory_DiscreteOnly, _BuilderMixinDualCriticFactory, ): @@ -1343,9 +1389,9 @@ def __init__( self, env_factory: EnvFactory, experiment_config: ExperimentConfig | None = None, - sampling_config: SamplingConfig | None = None, + training_config: OffPolicyTrainingConfig | None = None, ): - super().__init__(env_factory, experiment_config, sampling_config) + super().__init__(env_factory, experiment_config, training_config) _BuilderMixinActorFactory_DiscreteOnly.__init__(self) _BuilderMixinDualCriticFactory.__init__(self, self) self._params: DiscreteSACParams = DiscreteSACParams() @@ -1357,7 +1403,7 @@ def with_sac_params(self, params: DiscreteSACParams) -> Self: def _create_algorithm_factory(self) -> AlgorithmFactory: return DiscreteSACAlgorithmFactory( self._params, - self._sampling_config, + self._training_config, self._get_actor_factory(), self._get_critic_factory(0), self._get_critic_factory(1), @@ -1366,7 +1412,7 @@ def _create_algorithm_factory(self) -> AlgorithmFactory: class TD3ExperimentBuilder( - ExperimentBuilder, + OffPolicyExperimentBuilder, _BuilderMixinActorFactory_ContinuousDeterministic, _BuilderMixinDualCriticFactory, ): @@ -1374,9 +1420,9 @@ def __init__( self, env_factory: EnvFactory, experiment_config: 
ExperimentConfig | None = None, - sampling_config: SamplingConfig | None = None, + training_config: OffPolicyTrainingConfig | None = None, ): - super().__init__(env_factory, experiment_config, sampling_config) + super().__init__(env_factory, experiment_config, training_config) _BuilderMixinActorFactory_ContinuousDeterministic.__init__(self) _BuilderMixinDualCriticFactory.__init__(self, self) self._params: TD3Params = TD3Params() @@ -1388,7 +1434,7 @@ def with_td3_params(self, params: TD3Params) -> Self: def _create_algorithm_factory(self) -> AlgorithmFactory: return TD3AlgorithmFactory( self._params, - self._sampling_config, + self._training_config, self._get_actor_factory(), self._get_critic_factory(0), self._get_critic_factory(1), diff --git a/tianshou/highlevel/params/lr_scheduler.py b/tianshou/highlevel/params/lr_scheduler.py index bfd9cd76b..6d50b88c6 100644 --- a/tianshou/highlevel/params/lr_scheduler.py +++ b/tianshou/highlevel/params/lr_scheduler.py @@ -2,7 +2,7 @@ from sensai.util.string import ToStringMixin -from tianshou.highlevel.config import SamplingConfig +from tianshou.highlevel.config import TrainingConfig from tianshou.policy.optim import LRSchedulerFactory, LRSchedulerFactoryLinear @@ -15,12 +15,17 @@ def create_lr_scheduler_factory(self) -> LRSchedulerFactory: class LRSchedulerFactoryFactoryLinear(LRSchedulerFactoryFactory): - def __init__(self, sampling_config: SamplingConfig): - self.sampling_config = sampling_config + def __init__(self, training_config: TrainingConfig): + self.training_config = training_config def create_lr_scheduler_factory(self) -> LRSchedulerFactory: + if self.training_config.step_per_epoch is None: + raise ValueError( + f"{self.__class__.__name__} requires step_per_epoch to be set " + f"in order for the total number of update steps to be computable" + ) return LRSchedulerFactoryLinear( - num_epochs=self.sampling_config.num_epochs, - step_per_epoch=self.sampling_config.step_per_epoch, - 
step_per_collect=self.sampling_config.step_per_collect, + num_epochs=self.training_config.num_epochs, + step_per_epoch=self.training_config.step_per_epoch, + step_per_collect=self.training_config.step_per_collect, ) From ed23c1170c7e86bd543a1ffb9fa4c8a4eac10fa6 Mon Sep 17 00:00:00 2001 From: Dominik Jain Date: Mon, 17 Mar 2025 18:24:19 +0100 Subject: [PATCH 070/230] v2: Change default of test_in_train from True to False, explicitly setting the value for all usages False is the more reasonable default, as it does not make assumptions about returns/score values computed for the data from a collection step being at all meaningful for early stopping. --- CHANGELOG.md | 4 ++++ examples/box2d/acrobot_dualdqn.py | 1 + examples/box2d/bipedal_bdq.py | 1 + examples/box2d/lunarlander_dqn.py | 1 + examples/box2d/mcc_sac.py | 1 + examples/discrete/discrete_dqn.py | 1 + test/continuous/test_ddpg.py | 1 + test/continuous/test_npg.py | 1 + test/continuous/test_ppo.py | 1 + test/continuous/test_redq.py | 1 + test/continuous/test_sac_with_il.py | 2 ++ test/continuous/test_td3.py | 1 + test/continuous/test_trpo.py | 1 + test/discrete/test_a2c_with_il.py | 2 ++ test/discrete/test_bdqn.py | 1 + test/discrete/test_c51.py | 1 + test/discrete/test_dqn.py | 1 + test/discrete/test_drqn.py | 1 + test/discrete/test_fqf.py | 1 + test/discrete/test_iqn.py | 1 + test/discrete/test_pg.py | 1 + test/discrete/test_ppo.py | 1 + test/discrete/test_qrdqn.py | 1 + test/discrete/test_rainbow.py | 1 + test/modelbased/test_dqn_icm.py | 1 + test/modelbased/test_ppo_icm.py | 1 + test/offline/gather_cartpole_data.py | 1 + test/offline/gather_pendulum_data.py | 1 + test/offline/test_gail.py | 1 + test/pettingzoo/pistonball_continuous.py | 1 + tianshou/trainer/base.py | 2 +- 31 files changed, 36 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index ad0be9bfa..ff1a66262 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -17,6 +17,9 @@ * We no longer report outdated statistics (e.g. 
on rewards/returns when a training step does not collect any full episodes) * See also "Issues resolved" below (as issue resolution can result in usage changes) + * The default value for `test_in_train` was changed from True to False (updating all usage sites to explicitly + set the parameter), because False is the more natural default, which does not make assumptions about + returns/score values computed for the data from a collection step being at all meaningful for early stopping * Further internal changes unlikely to affect usage: * Module `trainer.utils` was removed and the functions therein where moved to class `Trainer` * The two places that collected and evaluated test episodes (`_test_in_train` and `_reset`) in addition to @@ -39,6 +42,7 @@ * Migration information at a glance: * Training parameters are now passed via instances of configuration objects instead of directly as keyword arguments: `OnPolicyTrainerParams`, `OffPolicyTrainerParams`, `OfflineTrainerParams`. + * Changed parameter default: Default for `test_in_train` was changed from True to False. 
* Trainer classes have been renamed: * `OnpolicyTrainer` -> `OnPolicyTrainer` * `OffpolicyTrainer` -> `OffPolicyTrainer` diff --git a/examples/box2d/acrobot_dualdqn.py b/examples/box2d/acrobot_dualdqn.py index 570346122..37d9bb9df 100644 --- a/examples/box2d/acrobot_dualdqn.py +++ b/examples/box2d/acrobot_dualdqn.py @@ -143,6 +143,7 @@ def test_fn(epoch: int, env_step: int | None) -> None: stop_fn=stop_fn, save_best_fn=save_best_fn, logger=logger, + test_in_train=True, ) ) diff --git a/examples/box2d/bipedal_bdq.py b/examples/box2d/bipedal_bdq.py index dea269d4a..5a3eabab4 100644 --- a/examples/box2d/bipedal_bdq.py +++ b/examples/box2d/bipedal_bdq.py @@ -161,6 +161,7 @@ def test_fn(epoch: int, env_step: int | None) -> None: test_fn=test_fn, save_best_fn=save_best_fn, logger=logger, + test_in_train=True, ) ) diff --git a/examples/box2d/lunarlander_dqn.py b/examples/box2d/lunarlander_dqn.py index fb558cf70..eb55fb1ce 100644 --- a/examples/box2d/lunarlander_dqn.py +++ b/examples/box2d/lunarlander_dqn.py @@ -140,6 +140,7 @@ def test_fn(epoch: int, env_step: int | None) -> None: test_fn=test_fn, save_best_fn=save_best_fn, logger=logger, + test_in_train=True, ) ) diff --git a/examples/box2d/mcc_sac.py b/examples/box2d/mcc_sac.py index 068bef1e0..766e32ad9 100644 --- a/examples/box2d/mcc_sac.py +++ b/examples/box2d/mcc_sac.py @@ -152,6 +152,7 @@ def stop_fn(mean_rewards: float) -> bool: stop_fn=stop_fn, save_best_fn=save_best_fn, logger=logger, + test_in_train=True, ) ) diff --git a/examples/discrete/discrete_dqn.py b/examples/discrete/discrete_dqn.py index ff679c542..9a71d2364 100644 --- a/examples/discrete/discrete_dqn.py +++ b/examples/discrete/discrete_dqn.py @@ -79,6 +79,7 @@ def stop_fn(mean_rewards: float) -> bool: test_fn=lambda epoch, env_step: policy.set_eps(eps_test), stop_fn=stop_fn, logger=logger, + test_in_train=True, ) ) print(f"Finished training in {result.timing.total_time} seconds") diff --git a/test/continuous/test_ddpg.py b/test/continuous/test_ddpg.py 
index 3c3e94312..bdba823b6 100644 --- a/test/continuous/test_ddpg.py +++ b/test/continuous/test_ddpg.py @@ -135,6 +135,7 @@ def stop_fn(mean_rewards: float) -> bool: stop_fn=stop_fn, save_best_fn=save_best_fn, logger=logger, + test_in_train=True, ) ) assert stop_fn(result.best_reward) diff --git a/test/continuous/test_npg.py b/test/continuous/test_npg.py index 11bba11e9..b373a03ec 100644 --- a/test/continuous/test_npg.py +++ b/test/continuous/test_npg.py @@ -156,6 +156,7 @@ def stop_fn(mean_rewards: float) -> bool: stop_fn=stop_fn, save_best_fn=save_best_fn, logger=logger, + test_in_train=True, ) ) assert stop_fn(result.best_reward) diff --git a/test/continuous/test_ppo.py b/test/continuous/test_ppo.py index f475a045a..85d697d21 100644 --- a/test/continuous/test_ppo.py +++ b/test/continuous/test_ppo.py @@ -185,6 +185,7 @@ def save_checkpoint_fn(epoch: int, env_step: int, gradient_step: int) -> str: logger=logger, resume_from_log=args.resume, save_checkpoint_fn=save_checkpoint_fn, + test_in_train=True, ) ) diff --git a/test/continuous/test_redq.py b/test/continuous/test_redq.py index d34034384..f97430ae0 100644 --- a/test/continuous/test_redq.py +++ b/test/continuous/test_redq.py @@ -166,6 +166,7 @@ def stop_fn(mean_rewards: float) -> bool: stop_fn=stop_fn, save_best_fn=save_best_fn, logger=logger, + test_in_train=True, ) ) assert stop_fn(result.best_reward) diff --git a/test/continuous/test_sac_with_il.py b/test/continuous/test_sac_with_il.py index 2d7c36f04..bc9b39d03 100644 --- a/test/continuous/test_sac_with_il.py +++ b/test/continuous/test_sac_with_il.py @@ -162,6 +162,7 @@ def stop_fn(mean_rewards: float) -> bool: stop_fn=stop_fn, save_best_fn=save_best_fn, logger=logger, + test_in_train=True, ) ) assert stop_fn(result.best_reward) @@ -211,6 +212,7 @@ def stop_fn(mean_rewards: float) -> bool: stop_fn=stop_fn, save_best_fn=save_best_fn, logger=logger, + test_in_train=True, ) ) assert stop_fn(result.best_reward) diff --git a/test/continuous/test_td3.py 
b/test/continuous/test_td3.py index 48ac97df6..b52d488c0 100644 --- a/test/continuous/test_td3.py +++ b/test/continuous/test_td3.py @@ -152,6 +152,7 @@ def stop_fn(mean_rewards: float) -> bool: stop_fn=stop_fn, save_best_fn=save_best_fn, logger=logger, + test_in_train=True, ) ) diff --git a/test/continuous/test_trpo.py b/test/continuous/test_trpo.py index e74c1e7cd..37f34febf 100644 --- a/test/continuous/test_trpo.py +++ b/test/continuous/test_trpo.py @@ -157,6 +157,7 @@ def stop_fn(mean_rewards: float) -> bool: stop_fn=stop_fn, save_best_fn=save_best_fn, logger=logger, + test_in_train=True, ) ) assert stop_fn(result.best_reward) diff --git a/test/discrete/test_a2c_with_il.py b/test/discrete/test_a2c_with_il.py index e07939698..4909814c3 100644 --- a/test/discrete/test_a2c_with_il.py +++ b/test/discrete/test_a2c_with_il.py @@ -149,6 +149,7 @@ def stop_fn(mean_rewards: float) -> bool: stop_fn=stop_fn, save_best_fn=save_best_fn, logger=logger, + test_in_train=True, ) ) assert stop_fn(result.best_reward) @@ -197,6 +198,7 @@ def stop_fn(mean_rewards: float) -> bool: stop_fn=stop_fn, save_best_fn=save_best_fn, logger=logger, + test_in_train=True, ) ) assert stop_fn(result.best_reward) diff --git a/test/discrete/test_bdqn.py b/test/discrete/test_bdqn.py index 7f322c8c1..a292db23f 100644 --- a/test/discrete/test_bdqn.py +++ b/test/discrete/test_bdqn.py @@ -146,6 +146,7 @@ def stop_fn(mean_rewards: float) -> bool: train_fn=train_fn, test_fn=test_fn, stop_fn=stop_fn, + test_in_train=True, ) ) assert stop_fn(result.best_reward) diff --git a/test/discrete/test_c51.py b/test/discrete/test_c51.py index 435f1020a..de0b69396 100644 --- a/test/discrete/test_c51.py +++ b/test/discrete/test_c51.py @@ -207,6 +207,7 @@ def save_checkpoint_fn(epoch: int, env_step: int, gradient_step: int) -> str: logger=logger, resume_from_log=args.resume, save_checkpoint_fn=save_checkpoint_fn, + test_in_train=True, ) ) assert stop_fn(result.best_reward) diff --git a/test/discrete/test_dqn.py 
b/test/discrete/test_dqn.py index 0345ba9c6..b33096fda 100644 --- a/test/discrete/test_dqn.py +++ b/test/discrete/test_dqn.py @@ -158,6 +158,7 @@ def test_fn(epoch: int, env_step: int | None) -> None: stop_fn=stop_fn, save_best_fn=save_best_fn, logger=logger, + test_in_train=True, ) ) assert stop_fn(result.best_reward) diff --git a/test/discrete/test_drqn.py b/test/discrete/test_drqn.py index fd636c5bc..48112add8 100644 --- a/test/discrete/test_drqn.py +++ b/test/discrete/test_drqn.py @@ -132,6 +132,7 @@ def test_fn(epoch: int, env_step: int | None) -> None: stop_fn=stop_fn, save_best_fn=save_best_fn, logger=logger, + test_in_train=True, ) ) assert stop_fn(result.best_reward) diff --git a/test/discrete/test_fqf.py b/test/discrete/test_fqf.py index abe413e26..bd3f76328 100644 --- a/test/discrete/test_fqf.py +++ b/test/discrete/test_fqf.py @@ -176,6 +176,7 @@ def test_fn(epoch: int, env_step: int | None) -> None: save_best_fn=save_best_fn, logger=logger, update_per_step=args.update_per_step, + test_in_train=True, ) ) assert stop_fn(result.best_reward) diff --git a/test/discrete/test_iqn.py b/test/discrete/test_iqn.py index a14ae05ac..eaebe99db 100644 --- a/test/discrete/test_iqn.py +++ b/test/discrete/test_iqn.py @@ -172,6 +172,7 @@ def test_fn(epoch: int, env_step: int | None) -> None: save_best_fn=save_best_fn, logger=logger, update_per_step=args.update_per_step, + test_in_train=True, ) ) assert stop_fn(result.best_reward) diff --git a/test/discrete/test_pg.py b/test/discrete/test_pg.py index bfac05f1b..465dc6f75 100644 --- a/test/discrete/test_pg.py +++ b/test/discrete/test_pg.py @@ -127,6 +127,7 @@ def stop_fn(mean_rewards: float) -> bool: stop_fn=stop_fn, save_best_fn=save_best_fn, logger=logger, + test_in_train=True, ) result = algorithm.run_training(training_config) diff --git a/test/discrete/test_ppo.py b/test/discrete/test_ppo.py index d4c0bbf8c..5c7c723bb 100644 --- a/test/discrete/test_ppo.py +++ b/test/discrete/test_ppo.py @@ -149,6 +149,7 @@ def 
stop_fn(mean_rewards: float) -> bool: stop_fn=stop_fn, save_best_fn=save_best_fn, logger=logger, + test_in_train=True, ) ) assert stop_fn(result.best_reward) diff --git a/test/discrete/test_qrdqn.py b/test/discrete/test_qrdqn.py index b14867b4d..51f721ed1 100644 --- a/test/discrete/test_qrdqn.py +++ b/test/discrete/test_qrdqn.py @@ -164,6 +164,7 @@ def test_fn(epoch: int, env_step: int | None) -> None: save_best_fn=save_best_fn, logger=logger, update_per_step=args.update_per_step, + test_in_train=True, ) ) assert stop_fn(result.best_reward) diff --git a/test/discrete/test_rainbow.py b/test/discrete/test_rainbow.py index 7c6c551ed..dfaa86bd2 100644 --- a/test/discrete/test_rainbow.py +++ b/test/discrete/test_rainbow.py @@ -226,6 +226,7 @@ def save_checkpoint_fn(epoch: int, env_step: int, gradient_step: int) -> str: logger=logger, resume_from_log=args.resume, save_checkpoint_fn=save_checkpoint_fn, + test_in_train=True, ) ) assert stop_fn(result.best_reward) diff --git a/test/modelbased/test_dqn_icm.py b/test/modelbased/test_dqn_icm.py index 0a8b2c9aa..0bd21a211 100644 --- a/test/modelbased/test_dqn_icm.py +++ b/test/modelbased/test_dqn_icm.py @@ -206,6 +206,7 @@ def test_fn(epoch: int, env_step: int | None) -> None: stop_fn=stop_fn, save_best_fn=save_best_fn, logger=logger, + test_in_train=True, ) ) assert stop_fn(result.best_reward) diff --git a/test/modelbased/test_ppo_icm.py b/test/modelbased/test_ppo_icm.py index 6e2bbfdbb..2691178fa 100644 --- a/test/modelbased/test_ppo_icm.py +++ b/test/modelbased/test_ppo_icm.py @@ -197,6 +197,7 @@ def stop_fn(mean_rewards: float) -> bool: stop_fn=stop_fn, save_best_fn=save_best_fn, logger=logger, + test_in_train=True, ) ) assert stop_fn(result.best_reward) diff --git a/test/offline/gather_cartpole_data.py b/test/offline/gather_cartpole_data.py index e84111f65..e72ffc5a5 100644 --- a/test/offline/gather_cartpole_data.py +++ b/test/offline/gather_cartpole_data.py @@ -167,6 +167,7 @@ def test_fn(epoch: int, env_step: int | None) 
-> None: save_best_fn=save_best_fn, logger=logger, update_per_step=args.update_per_step, + test_in_train=True, ) ) assert stop_fn(result.best_reward) diff --git a/test/offline/gather_pendulum_data.py b/test/offline/gather_pendulum_data.py index 03c63387d..34eae0829 100644 --- a/test/offline/gather_pendulum_data.py +++ b/test/offline/gather_pendulum_data.py @@ -160,6 +160,7 @@ def stop_fn(mean_rewards: float) -> bool: save_best_fn=save_best_fn, stop_fn=stop_fn, logger=logger, + test_in_train=True, ) ) train_collector.reset() diff --git a/test/offline/test_gail.py b/test/offline/test_gail.py index 6d5d7c3c6..24ddf9faa 100644 --- a/test/offline/test_gail.py +++ b/test/offline/test_gail.py @@ -225,6 +225,7 @@ def save_checkpoint_fn(epoch: int, env_step: int, gradient_step: int) -> str: logger=logger, resume_from_log=args.resume, save_checkpoint_fn=save_checkpoint_fn, + test_in_train=True, ) ) assert stop_fn(result.best_reward) diff --git a/test/pettingzoo/pistonball_continuous.py b/test/pettingzoo/pistonball_continuous.py index f57d03255..807736c4a 100644 --- a/test/pettingzoo/pistonball_continuous.py +++ b/test/pettingzoo/pistonball_continuous.py @@ -278,6 +278,7 @@ def reward_metric(rews: np.ndarray) -> np.ndarray: save_best_fn=save_best_fn, logger=logger, resume_from_log=args.resume, + test_in_train=True, ) ) diff --git a/tianshou/trainer/base.py b/tianshou/trainer/base.py index 4b015b04c..cf9ec8471 100644 --- a/tianshou/trainer/base.py +++ b/tianshou/trainer/base.py @@ -249,7 +249,7 @@ class OnlineTrainerParams(TrainerParams): This is mutually exclusive with :attr:`step_per_collect`, and one of the two must be set. """ - test_in_train: bool = True + test_in_train: bool = False """ Whether to apply a test step within a training step depending on the early stopping criterion (given by :attr:`stop_fn`) being satisfied based on the data collected within the training step. 
From f0c160a43237e0596cfed7c6198bf0d5158b6d1d Mon Sep 17 00:00:00 2001 From: Dominik Jain Date: Mon, 17 Mar 2025 18:55:17 +0100 Subject: [PATCH 071/230] v2: Remove obsolete module utils.lr_scheduler with now unused class MultipleLRSchedulers --- test/base/test_utils.py | 35 +---------------------------- tianshou/utils/__init__.py | 2 -- tianshou/utils/lr_scheduler.py | 41 ---------------------------------- 3 files changed, 1 insertion(+), 77 deletions(-) delete mode 100644 tianshou/utils/lr_scheduler.py diff --git a/test/base/test_utils.py b/test/base/test_utils.py index 8e44ad57b..faa93ee84 100644 --- a/test/base/test_utils.py +++ b/test/base/test_utils.py @@ -8,7 +8,7 @@ from torch import nn from tianshou.exploration import GaussianNoise, OUNoise -from tianshou.utils import MovAvg, MultipleLRSchedulers, RunningMeanStd +from tianshou.utils import MovAvg, RunningMeanStd from tianshou.utils.net.common import MLP, Net from tianshou.utils.net.continuous import RecurrentActorProb, RecurrentCritic from tianshou.utils.torch_utils import create_uniform_action_dist, torch_train_mode @@ -106,39 +106,6 @@ def test_net() -> None: assert list(net(data, act).shape) == [bsz, 1] -def test_lr_schedulers() -> None: - initial_lr_1 = 10.0 - step_size_1 = 1 - gamma_1 = 0.5 - net_1 = torch.nn.Linear(2, 3) - optim_1 = torch.optim.Adam(net_1.parameters(), lr=initial_lr_1) - sched_1 = torch.optim.lr_scheduler.StepLR(optim_1, step_size=step_size_1, gamma=gamma_1) - - initial_lr_2 = 5.0 - step_size_2 = 2 - gamma_2 = 0.3 - net_2 = torch.nn.Linear(3, 2) - optim_2 = torch.optim.Adam(net_2.parameters(), lr=initial_lr_2) - sched_2 = torch.optim.lr_scheduler.StepLR(optim_2, step_size=step_size_2, gamma=gamma_2) - schedulers = MultipleLRSchedulers(sched_1, sched_2) - for _ in range(10): - loss_1 = (torch.ones((1, 3)) - net_1(torch.ones((1, 2)))).sum() - optim_1.zero_grad() - loss_1.backward() - optim_1.step() - loss_2 = (torch.ones((1, 2)) - net_2(torch.ones((1, 3)))).sum() - optim_2.zero_grad() - 
loss_2.backward() - optim_2.step() - schedulers.step() - assert optim_1.state_dict()["param_groups"][0]["lr"] == ( - initial_lr_1 * gamma_1 ** (10 // step_size_1) - ) - assert optim_2.state_dict()["param_groups"][0]["lr"] == ( - initial_lr_2 * gamma_2 ** (10 // step_size_2) - ) - - def test_in_eval_mode() -> None: module = nn.Linear(3, 4) module.train() diff --git a/tianshou/utils/__init__.py b/tianshou/utils/__init__.py index 47a3c4497..42e7152b2 100644 --- a/tianshou/utils/__init__.py +++ b/tianshou/utils/__init__.py @@ -3,7 +3,6 @@ from tianshou.utils.logger.base import BaseLogger, LazyLogger from tianshou.utils.logger.tensorboard import TensorboardLogger from tianshou.utils.logger.wandb import WandbLogger -from tianshou.utils.lr_scheduler import MultipleLRSchedulers from tianshou.utils.progress_bar import DummyTqdm, tqdm_config from tianshou.utils.statistics import MovAvg, RunningMeanStd from tianshou.utils.warning import deprecation @@ -18,5 +17,4 @@ "TensorboardLogger", "LazyLogger", "WandbLogger", - "MultipleLRSchedulers", ] diff --git a/tianshou/utils/lr_scheduler.py b/tianshou/utils/lr_scheduler.py deleted file mode 100644 index 66b313c7c..000000000 --- a/tianshou/utils/lr_scheduler.py +++ /dev/null @@ -1,41 +0,0 @@ -import torch - - -# TODO: We no longer need this class as Algorithm now uses an explicit list -class MultipleLRSchedulers: - """A wrapper for multiple learning rate schedulers. - - Every time :meth:`~tianshou.utils.MultipleLRSchedulers.step` is called, - it calls the step() method of each of the schedulers that it contains. 
- Example usage: - :: - - scheduler1 = ConstantLR(opt1, factor=0.1, total_iters=2) - scheduler2 = ExponentialLR(opt2, gamma=0.9) - scheduler = MultipleLRSchedulers(scheduler1, scheduler2) - policy = PPOPolicy(..., lr_scheduler=scheduler) - """ - - def __init__(self, *args: torch.optim.lr_scheduler.LRScheduler): - self.schedulers = args - - def step(self) -> None: - """Take a step in each of the learning rate schedulers.""" - for scheduler in self.schedulers: - scheduler.step() - - def state_dict(self) -> list[dict]: - """Get state_dict for each of the learning rate schedulers. - - :return: A list of state_dict of learning rate schedulers. - """ - return [s.state_dict() for s in self.schedulers] - - def load_state_dict(self, state_dict: list[dict]) -> None: - """Load states from state_dict. - - :param state_dict: A list of learning rate scheduler - state_dict, in the same order as the schedulers. - """ - for s, sd in zip(self.schedulers, state_dict, strict=True): - s.__dict__.update(sd) From 41049fda75ad85e57438e8528e0c165f7e1b8fea Mon Sep 17 00:00:00 2001 From: Dominik Jain Date: Tue, 18 Mar 2025 01:01:19 +0100 Subject: [PATCH 072/230] v2: Fix typing issues in *WrapperAlgorithm --- tianshou/policy/base.py | 27 ++++++++++++++++++++++----- tianshou/policy/modelbased/icm.py | 17 ++++++++++------- 2 files changed, 32 insertions(+), 12 deletions(-) diff --git a/tianshou/policy/base.py b/tianshou/policy/base.py index 381138367..cb348a1b5 100644 --- a/tianshou/policy/base.py +++ b/tianshou/policy/base.py @@ -901,10 +901,21 @@ def post_process_fn( def _update_with_batch( self, batch: RolloutBatchProtocol, batch_size: int | None, repeat: int ) -> TTrainingStats: - """Performs the update as defined by the wrapped algorithm.""" - return self.wrapped_algorithm._update_with_batch( + """Performs the update as defined by the wrapped algorithm, followed by the wrapper's update .""" + original_stats = self.wrapped_algorithm._update_with_batch( batch, batch_size=batch_size, 
repeat=repeat ) + return self._wrapper_update_with_batch(batch, batch_size, repeat, original_stats) + + @abstractmethod + def _wrapper_update_with_batch( + self, + batch: RolloutBatchProtocol, + batch_size: int | None, + repeat: int, + original_stats: TWrappedAlgorthmTrainingStats, + ) -> TTrainingStats: + pass class OffPolicyWrapperAlgorithm( @@ -937,13 +948,19 @@ def post_process_fn( """Performs the batch post-processing as defined by the wrapped algorithm.""" self.wrapped_algorithm.post_process_fn(batch, buffer, indices) - @abstractmethod def _update_with_batch( self, batch: RolloutBatchProtocol, ) -> TTrainingStats: - """Performs the update as defined by the wrapped algorithm.""" - return self.wrapped_algorithm._update_with_batch(batch) + """Performs the update as defined by the wrapped algorithm, followed by the wrapper's update .""" + original_stats = self.wrapped_algorithm._update_with_batch(batch) + return self._wrapper_update_with_batch(batch, original_stats) + + @abstractmethod + def _wrapper_update_with_batch( + self, batch: RolloutBatchProtocol, original_stats: TWrappedAlgorthmTrainingStats + ) -> TTrainingStats: + pass class RandomActionPolicy(Policy): diff --git a/tianshou/policy/modelbased/icm.py b/tianshou/policy/modelbased/icm.py index 2b87414b9..6a223d844 100644 --- a/tianshou/policy/modelbased/icm.py +++ b/tianshou/policy/modelbased/icm.py @@ -148,12 +148,12 @@ def post_process_fn( super().post_process_fn(batch, buffer, indices) self._icm_postprocess_batch(batch) - def _update_with_batch( + def _wrapper_update_with_batch( self, batch: RolloutBatchProtocol, + original_stats: TTrainingStats, ) -> ICMTrainingStats: - wrapped_stats = super()._update_with_batch(batch) - return self._icm_update(batch, wrapped_stats) + return self._icm_update(batch, original_stats) class ICMOnPolicyWrapper( @@ -209,8 +209,11 @@ def post_process_fn( super().post_process_fn(batch, buffer, indices) self._icm_postprocess_batch(batch) - def _update_with_batch( - self, batch: 
RolloutBatchProtocol, batch_size: int | None, repeat: int + def _wrapper_update_with_batch( + self, + batch: RolloutBatchProtocol, + batch_size: int | None, + repeat: int, + original_stats: TTrainingStats, ) -> ICMTrainingStats: - wrapped_stats = super()._update_with_batch(batch, batch_size=batch_size, repeat=repeat) - return self._icm_update(batch, wrapped_stats) + return self._icm_update(batch, original_stats) From 33414aa0300e9fe480e00d992f5d6173a3cd50fc Mon Sep 17 00:00:00 2001 From: Dominik Jain Date: Tue, 18 Mar 2025 01:09:43 +0100 Subject: [PATCH 073/230] v2: Remove unused type-ignores --- tianshou/policy/modelfree/trpo.py | 2 +- tianshou/policy/multiagent/mapolicy.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/tianshou/policy/modelfree/trpo.py b/tianshou/policy/modelfree/trpo.py index 2dac104af..4bb975f63 100644 --- a/tianshou/policy/modelfree/trpo.py +++ b/tianshou/policy/modelfree/trpo.py @@ -148,7 +148,7 @@ def _update_with_batch( # type: ignore # optimize critic # TODO: remove type-ignore once the top-level type-ignore is removed - for _ in range(self.optim_critic_iters): # type: ignore + for _ in range(self.optim_critic_iters): value = self.critic(minibatch.obs).flatten() vf_loss = F.mse_loss(minibatch.returns, value) self.optim.zero_grad() diff --git a/tianshou/policy/multiagent/mapolicy.py b/tianshou/policy/multiagent/mapolicy.py index 9333b5e25..d2f2a2b11 100644 --- a/tianshou/policy/multiagent/mapolicy.py +++ b/tianshou/policy/multiagent/mapolicy.py @@ -205,7 +205,7 @@ def __init__(self, algorithms: list[TAlgorithm], env: PettingZooEnv): def create_policy(self) -> MultiAgentPolicy: return MultiAgentPolicy({agent_id: a.policy for agent_id, a in self.algorithms.items()}) - def dispatch_process_fn( # type: ignore + def dispatch_process_fn( self, batch: MAPRolloutBatchProtocol, buffer: ReplayBuffer, @@ -248,7 +248,7 @@ def dispatch_process_fn( # type: ignore buffer._meta.rew = save_rew return cast(MAPRolloutBatchProtocol, 
Batch(results)) - def dispatch_update_with_batch( # type: ignore + def dispatch_update_with_batch( self, batch: MAPRolloutBatchProtocol, algorithm_update_with_batch_fn: Callable[[TAlgorithm, RolloutBatchProtocol], TrainingStats], From a2958c717186634cba8dc02b91205d86b1293084 Mon Sep 17 00:00:00 2001 From: Dominik Jain Date: Tue, 18 Mar 2025 01:32:17 +0100 Subject: [PATCH 074/230] v2: Fix typing issues in Trainer --- tianshou/policy/base.py | 2 +- tianshou/trainer/base.py | 27 ++++++++++++++++----------- 2 files changed, 17 insertions(+), 12 deletions(-) diff --git a/tianshou/policy/base.py b/tianshou/policy/base.py index cb348a1b5..45724eb65 100644 --- a/tianshou/policy/base.py +++ b/tianshou/policy/base.py @@ -776,7 +776,7 @@ def _update_with_batch( def update( self, buffer: ReplayBuffer, - batch_size: int, + batch_size: int | None, repeat: int, ) -> TTrainingStats: update_with_batch_fn = lambda batch: self._update_with_batch( diff --git a/tianshou/trainer/base.py b/tianshou/trainer/base.py index cf9ec8471..baa1be014 100644 --- a/tianshou/trainer/base.py +++ b/tianshou/trainer/base.py @@ -192,7 +192,7 @@ class TrainerParams(ToStringMixin): whether to display a progress bars during training. """ - def __post_init__(self): + def __post_init__(self) -> None: if self.resume_from_log and self.logger is None: raise ValueError("Cannot resume from log without a logger being provided") if self.test_collector is None: @@ -260,7 +260,7 @@ class OnlineTrainerParams(TrainerParams): stopping criterion is also satisfied based on the test data, we stop training early. 
""" - def __post_init__(self): + def __post_init__(self) -> None: super().__post_init__() if count_none(self.step_per_collect, self.episode_per_collect) != 1: raise ValueError("Exactly one of {step_per_collect, episode_per_collect} must be set") @@ -440,7 +440,7 @@ def reset(self, reset_collectors: bool = True, reset_collector_buffers: bool = F class _TrainingStepResult(ABC): @abstractmethod - def get_steps_in_epoch_advancement(self): + def get_steps_in_epoch_advancement(self) -> int: """ :return: the number of steps that were done within the epoch, where the concrete semantics of what a step is depend on the type of algorith. See docstring of `TrainingConfig.step_per_epoch`. @@ -455,7 +455,7 @@ def get_training_stats(self) -> TrainingStats | None: pass @abstractmethod - def is_training_done(self): + def is_training_done(self) -> bool: """:return: whether the early stopping criterion is satisfied and training shall stop.""" @abstractmethod @@ -591,6 +591,7 @@ def _should_stop_training_early( def _collect_test_episodes( self, ) -> CollectStats: + assert self.params.test_collector is not None collector = self.params.test_collector collector.reset(reset_stats=False) if self.params.test_fn: @@ -715,7 +716,7 @@ class OfflineTrainer(Trainer[OfflineAlgorithm, OfflineTrainerParams]): def __init__( self, - algorithm: "Algorithm", + algorithm: OfflineAlgorithm, params: OfflineTrainerParams, ): super().__init__(algorithm, params) @@ -726,7 +727,7 @@ def __init__(self, training_stats: TrainingStats, env_step_advancement: int): self._training_stats = training_stats self._env_step_advancement = env_step_advancement - def get_steps_in_epoch_advancement(self): + def get_steps_in_epoch_advancement(self) -> int: return 1 def get_collect_stats(self) -> None: @@ -756,7 +757,7 @@ def _training_step(self) -> _TrainingStepResult: ) def _create_epoch_pbar_data_dict( - self, training_step_result: _TrainingStepResult + self, training_step_result: Trainer._TrainingStepResult ) -> dict[str, 
str]: return {} @@ -772,7 +773,7 @@ class OnlineTrainer( def __init__( self, - algorithm: "Algorithm", + algorithm: TAlgorithm, params: OnlineTrainerParams, ): super().__init__(algorithm, params) @@ -812,7 +813,7 @@ def __init__( self._training_stats = training_stats self._is_training_done = is_training_done - def get_steps_in_epoch_advancement(self): + def get_steps_in_epoch_advancement(self) -> int: return self.get_env_step_advancement() def get_collect_stats(self) -> CollectStats: @@ -821,7 +822,7 @@ def get_collect_stats(self) -> CollectStats: def get_training_stats(self) -> TrainingStats | None: return self._training_stats - def is_training_done(self): + def is_training_done(self) -> bool: return self._is_training_done def get_env_step_advancement(self) -> int: @@ -940,9 +941,10 @@ def _update_step( """ def _create_epoch_pbar_data_dict( - self, training_step_result: _TrainingStepResult + self, training_step_result: Trainer._TrainingStepResult ) -> dict[str, str]: collect_stats = training_step_result.get_collect_stats() + assert collect_stats is not None result = { "env_step": str(self._env_step), "env_episode": str(self._env_episode), @@ -951,6 +953,8 @@ def _create_epoch_pbar_data_dict( } # return and episode length info is only available if at least one episode was completed if collect_stats.n_collected_episodes > 0: + assert collect_stats.returns_stat is not None + assert collect_stats.lens_stat is not None result.update( { "rew": f"{collect_stats.returns_stat.mean:.2f}", @@ -998,6 +1002,7 @@ def _update_step( self._policy_update_time += update_stat.train_time # TODO: only the last update_stat is returned, should be improved + assert update_stat is not None return update_stat def _sample_and_update(self, buffer: ReplayBuffer) -> TrainingStats: From b855ab4b4b198b7d3d61d57aef37514178e4c93a Mon Sep 17 00:00:00 2001 From: Dominik Jain Date: Tue, 18 Mar 2025 01:42:30 +0100 Subject: [PATCH 075/230] v2: Adapt test_policy and test_collector --- 
test/base/test_collector.py | 13 ++----------- test/base/test_policy.py | 29 +++++++++++++++++------------ tianshou/data/collector.py | 2 +- 3 files changed, 20 insertions(+), 24 deletions(-) diff --git a/test/base/test_collector.py b/test/base/test_collector.py index ed648075d..0d47c456f 100644 --- a/test/base/test_collector.py +++ b/test/base/test_collector.py @@ -25,8 +25,7 @@ ) from tianshou.data.types import ObsBatchProtocol, RolloutBatchProtocol from tianshou.env import DummyVectorEnv, SubprocVectorEnv -from tianshou.policy import BasePolicy, TrainingStats -from tianshou.policy.base import episode_mc_return_to_go +from tianshou.policy.base import Policy, episode_mc_return_to_go try: import envpool @@ -34,7 +33,7 @@ envpool = None -class MaxActionPolicy(BasePolicy): +class MaxActionPolicy(Policy): def __init__( self, action_space: gym.spaces.Space | None = None, @@ -80,14 +79,6 @@ def forward( action_shape = self.action_shape if self.action_shape else len(batch.obs) return Batch(act=np.ones(action_shape), state=state) - def _update_with_batch( - self, - batch: RolloutBatchProtocol, - *args: Any, - **kwargs: Any, - ) -> TrainingStats: - raise NotImplementedError - @pytest.fixture() def collector_with_single_env() -> Collector[CollectStats]: diff --git a/test/base/test_policy.py b/test/base/test_policy.py index 618958825..a861a68f1 100644 --- a/test/base/test_policy.py +++ b/test/base/test_policy.py @@ -5,9 +5,11 @@ from torch.distributions import Categorical, Distribution, Independent, Normal from tianshou.data import Batch -from tianshou.policy import PPO, Algorithm +from tianshou.policy import PPO from tianshou.policy.base import RandomActionPolicy, episode_mc_return_to_go -from tianshou.utils.net.common import ActorCritic, Net +from tianshou.policy.modelfree.pg import ActorPolicy +from tianshou.policy.optim import AdamOptimizerFactory +from tianshou.utils.net.common import Net from tianshou.utils.net.continuous import ActorProb, Critic from 
tianshou.utils.net.discrete import Actor @@ -26,7 +28,7 @@ def test_calculate_discounted_returns() -> None: @pytest.fixture(params=["continuous", "discrete"]) -def policy(request: pytest.FixtureRequest) -> PPO: +def algorithm(request: pytest.FixtureRequest) -> PPO: action_type = request.param action_space: gym.spaces.Box | gym.spaces.Discrete actor: Actor | ActorProb @@ -55,24 +57,27 @@ def dist_fn(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution: Net(obs_shape, hidden_sizes=[64, 64]), ) - actor_critic = ActorCritic(actor, critic) - optim = torch.optim.Adam(actor_critic.parameters(), lr=1e-3) + optim = AdamOptimizerFactory(lr=1e-3) - policy: Algorithm - policy = PPO( + algorithm: PPO + policy = ActorPolicy( actor=actor, - critic=critic, dist_fn=dist_fn, - optim=optim, action_space=action_space, action_scaling=False, ) - policy.eval() - return policy + algorithm = PPO( + policy=policy, + critic=critic, + optim=optim, + ) + algorithm.eval() + return algorithm class TestPolicyBasics: - def test_get_action(self, policy: PPO) -> None: + def test_get_action(self, algorithm: PPO) -> None: + policy = algorithm.policy policy.is_within_training_step = False sample_obs = torch.randn(obs_shape) policy.deterministic_eval = False diff --git a/tianshou/data/collector.py b/tianshou/data/collector.py index 0f3fbaf31..20fa89721 100644 --- a/tianshou/data/collector.py +++ b/tianshou/data/collector.py @@ -1084,7 +1084,7 @@ class AsyncCollector(Collector[CollectStats]): def __init__( self, - policy: Algorithm, + policy: Policy | Algorithm, env: BaseVectorEnv, buffer: ReplayBuffer | None = None, exploration_noise: bool = False, From cd73407558ee8a28504e2e49e674cf091faba1aa Mon Sep 17 00:00:00 2001 From: Dominik Jain Date: Tue, 18 Mar 2025 02:00:51 +0100 Subject: [PATCH 076/230] v2: Fix some incorrect types, make mypy happier --- examples/atari/atari_c51.py | 2 +- examples/box2d/bipedal_hardcore_sac.py | 2 +- examples/discrete/discrete_dqn.py | 4 ++-- 
examples/mujoco/fetch_her_ddpg.py | 2 +- examples/offline/atari_il.py | 2 +- examples/offline/d4rl_cql.py | 2 +- examples/offline/d4rl_il.py | 2 +- test/continuous/test_ppo.py | 2 +- test/continuous/test_sac_with_il.py | 2 +- test/discrete/test_a2c_with_il.py | 2 +- test/discrete/test_ppo.py | 5 +++-- test/discrete/test_qrdqn.py | 2 +- test/discrete/test_rainbow.py | 2 +- test/modelbased/test_dqn_icm.py | 2 +- test/modelbased/test_ppo_icm.py | 2 +- test/offline/gather_cartpole_data.py | 5 +++-- test/offline/gather_pendulum_data.py | 4 ++-- test/offline/test_cql.py | 2 +- test/offline/test_gail.py | 2 +- test/pettingzoo/pistonball.py | 7 ++++--- test/pettingzoo/pistonball_continuous.py | 7 ++++--- test/pettingzoo/tic_tac_toe.py | 6 +++--- tianshou/highlevel/algorithm.py | 14 ++++++++------ tianshou/policy/imitation/discrete_cql.py | 2 +- tianshou/policy/imitation/td3_bc.py | 2 +- 25 files changed, 46 insertions(+), 40 deletions(-) diff --git a/examples/atari/atari_c51.py b/examples/atari/atari_c51.py index cf4725b47..4c627dbbb 100644 --- a/examples/atari/atari_c51.py +++ b/examples/atari/atari_c51.py @@ -98,7 +98,7 @@ def main(args: argparse.Namespace = get_args()) -> None: v_min=args.v_min, v_max=args.v_max, ) - algorithm = C51( + algorithm: C51 = C51( policy=policy, optim=optim, discount_factor=args.gamma, diff --git a/examples/box2d/bipedal_hardcore_sac.py b/examples/box2d/bipedal_hardcore_sac.py index 9ccc48690..7e626d0e7 100644 --- a/examples/box2d/bipedal_hardcore_sac.py +++ b/examples/box2d/bipedal_hardcore_sac.py @@ -137,7 +137,7 @@ def test_sac_bipedal(args: argparse.Namespace = get_args()) -> None: device=args.device, ) critic2 = Critic(net_c2, device=args.device).to(args.device) - critic2_optim = AdamOptimizerFactory(critic2.parameters(), lr=args.critic_lr) + critic2_optim = AdamOptimizerFactory(lr=args.critic_lr) action_dim = space_info.action_info.action_dim if args.auto_alpha: diff --git a/examples/discrete/discrete_dqn.py 
b/examples/discrete/discrete_dqn.py index 9a71d2364..aa958286c 100644 --- a/examples/discrete/discrete_dqn.py +++ b/examples/discrete/discrete_dqn.py @@ -1,10 +1,10 @@ import gymnasium as gym -import torch from torch.utils.tensorboard import SummaryWriter import tianshou as ts from tianshou.data import CollectStats from tianshou.policy.modelfree.dqn import DQNPolicy +from tianshou.policy.optim import AdamOptimizerFactory from tianshou.trainer.base import OffPolicyTrainerParams from tianshou.utils.space_info import SpaceInfo @@ -35,7 +35,7 @@ def main() -> None: state_shape = space_info.observation_info.obs_shape action_shape = space_info.action_info.action_shape net = Net(state_shape=state_shape, action_shape=action_shape, hidden_sizes=[128, 128, 128]) - optim = torch.optim.Adam(net.parameters(), lr=lr) + optim = AdamOptimizerFactory(lr=lr) policy = DQNPolicy(model=net, action_space=env.action_space) algorithm: ts.policy.DQN = ts.policy.DQN( diff --git a/examples/mujoco/fetch_her_ddpg.py b/examples/mujoco/fetch_her_ddpg.py index e1a3bc823..bbdf7e1ec 100644 --- a/examples/mujoco/fetch_her_ddpg.py +++ b/examples/mujoco/fetch_her_ddpg.py @@ -175,7 +175,7 @@ def test_ddpg(args: argparse.Namespace = get_args()) -> None: exploration_noise=GaussianNoise(sigma=args.exploration_noise), action_space=env.action_space, ) - algorithm = DDPG( + algorithm: DDPG = DDPG( policy=policy, policy_optim=actor_optim, critic=critic, diff --git a/examples/offline/atari_il.py b/examples/offline/atari_il.py index 9a028f232..52f61679a 100644 --- a/examples/offline/atari_il.py +++ b/examples/offline/atari_il.py @@ -92,7 +92,7 @@ def test_il(args: argparse.Namespace = get_args()) -> None: optim = AdamOptimizerFactory(lr=args.lr) # define policy policy = ImitationPolicy(actor=net, action_space=env.action_space) - algorithm = OfflineImitationLearning( + algorithm: OfflineImitationLearning = OfflineImitationLearning( policy=policy, optim=optim, ) diff --git a/examples/offline/d4rl_cql.py 
b/examples/offline/d4rl_cql.py index d0121d131..2fda53a9a 100644 --- a/examples/offline/d4rl_cql.py +++ b/examples/offline/d4rl_cql.py @@ -290,7 +290,7 @@ def test_cql() -> None: actor=actor, action_space=env.action_space, ) - algorithm = CQL( + algorithm: CQL = CQL( policy=policy, policy_optim=actor_optim, critic=critic, diff --git a/examples/offline/d4rl_il.py b/examples/offline/d4rl_il.py index 22b52f212..e94ee3765 100644 --- a/examples/offline/d4rl_il.py +++ b/examples/offline/d4rl_il.py @@ -103,7 +103,7 @@ def test_il() -> None: action_scaling=True, action_bound_method="clip", ) - algorithm = OfflineImitationLearning( + algorithm: OfflineImitationLearning = OfflineImitationLearning( policy=policy, optim=optim, ) diff --git a/test/continuous/test_ppo.py b/test/continuous/test_ppo.py index 85d697d21..795d0b46d 100644 --- a/test/continuous/test_ppo.py +++ b/test/continuous/test_ppo.py @@ -108,7 +108,7 @@ def dist(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution: dist_fn=dist, action_space=env.action_space, ) - algorithm = PPO( + algorithm: PPO = PPO( policy=policy, critic=critic, optim=optim, diff --git a/test/continuous/test_sac_with_il.py b/test/continuous/test_sac_with_il.py index bc9b39d03..acf74973b 100644 --- a/test/continuous/test_sac_with_il.py +++ b/test/continuous/test_sac_with_il.py @@ -116,7 +116,7 @@ def test_sac_with_il(args: argparse.Namespace = get_args()) -> None: actor=actor, action_space=env.action_space, ) - algorithm = SAC( + algorithm: SAC = SAC( policy=policy, policy_optim=actor_optim, critic=critic1, diff --git a/test/discrete/test_a2c_with_il.py b/test/discrete/test_a2c_with_il.py index 4909814c3..c3f252857 100644 --- a/test/discrete/test_a2c_with_il.py +++ b/test/discrete/test_a2c_with_il.py @@ -103,7 +103,7 @@ def test_a2c_with_il(args: argparse.Namespace = get_args()) -> None: action_scaling=isinstance(env.action_space, Box), action_space=env.action_space, ) - algorithm = A2C( + algorithm: A2C = A2C( policy=policy, 
critic=critic, optim=optim, diff --git a/test/discrete/test_ppo.py b/test/discrete/test_ppo.py index 5c7c723bb..bb157755c 100644 --- a/test/discrete/test_ppo.py +++ b/test/discrete/test_ppo.py @@ -12,6 +12,7 @@ from tianshou.policy import PPO from tianshou.policy.base import Algorithm from tianshou.policy.modelfree.pg import DiscreteActorPolicy +from tianshou.policy.optim import AdamOptimizerFactory from tianshou.trainer import OnPolicyTrainerParams from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import ActorCritic, DataParallelNet, Net @@ -93,7 +94,7 @@ def test_ppo(args: argparse.Namespace = get_args()) -> None: if isinstance(m, torch.nn.Linear): torch.nn.init.orthogonal_(m.weight) torch.nn.init.zeros_(m.bias) - optim = torch.optim.Adam(actor_critic.parameters(), lr=args.lr) + optim = AdamOptimizerFactory(lr=args.lr) dist = torch.distributions.Categorical policy = DiscreteActorPolicy( actor=actor, @@ -101,7 +102,7 @@ def test_ppo(args: argparse.Namespace = get_args()) -> None: action_space=env.action_space, deterministic_eval=True, ) - algorithm = PPO( + algorithm: PPO = PPO( policy=policy, critic=critic, optim=optim, diff --git a/test/discrete/test_qrdqn.py b/test/discrete/test_qrdqn.py index 51f721ed1..4113282d2 100644 --- a/test/discrete/test_qrdqn.py +++ b/test/discrete/test_qrdqn.py @@ -96,7 +96,7 @@ def test_qrdqn(args: argparse.Namespace = get_args()) -> None: policy = QRDQNPolicy( model=net, action_space=env.action_space, observation_space=env.observation_space ) - algorithm = QRDQN( + algorithm: QRDQN = QRDQN( policy=policy, optim=optim, discount_factor=args.gamma, diff --git a/test/discrete/test_rainbow.py b/test/discrete/test_rainbow.py index dfaa86bd2..569b9e3eb 100644 --- a/test/discrete/test_rainbow.py +++ b/test/discrete/test_rainbow.py @@ -111,7 +111,7 @@ def noisy_linear(x: int, y: int) -> NoisyLinear: v_min=args.v_min, v_max=args.v_max, ) - algorithm = RainbowDQN( + algorithm: RainbowDQN = RainbowDQN( policy=policy, 
optim=optim, discount_factor=args.gamma, diff --git a/test/modelbased/test_dqn_icm.py b/test/modelbased/test_dqn_icm.py index 0bd21a211..0f70ddd17 100644 --- a/test/modelbased/test_dqn_icm.py +++ b/test/modelbased/test_dqn_icm.py @@ -111,7 +111,7 @@ def test_dqn_icm(args: argparse.Namespace = get_args()) -> None: model=net, action_space=env.action_space, ) - algorithm = DQN( + algorithm: DQN = DQN( policy=policy, optim=optim, discount_factor=args.gamma, diff --git a/test/modelbased/test_ppo_icm.py b/test/modelbased/test_ppo_icm.py index 2691178fa..b3a68d1a4 100644 --- a/test/modelbased/test_ppo_icm.py +++ b/test/modelbased/test_ppo_icm.py @@ -121,7 +121,7 @@ def test_ppo(args: argparse.Namespace = get_args()) -> None: action_space=env.action_space, deterministic_eval=True, ) - algorithm = PPO( + algorithm: PPO = PPO( policy=policy, critic=critic, optim=optim, diff --git a/test/offline/gather_cartpole_data.py b/test/offline/gather_cartpole_data.py index e72ffc5a5..e53c2c518 100644 --- a/test/offline/gather_cartpole_data.py +++ b/test/offline/gather_cartpole_data.py @@ -17,6 +17,7 @@ from tianshou.policy import QRDQN from tianshou.policy.base import Algorithm from tianshou.policy.modelfree.qrdqn import QRDQNPolicy +from tianshou.policy.optim import AdamOptimizerFactory from tianshou.trainer import OffPolicyTrainerParams from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Net @@ -96,12 +97,12 @@ def gather_data() -> VectorReplayBuffer | PrioritizedVectorReplayBuffer: softmax=False, num_atoms=args.num_quantiles, ) - optim = torch.optim.Adam(net.parameters(), lr=args.lr) + optim = AdamOptimizerFactory(lr=args.lr) policy = QRDQNPolicy( model=net, action_space=env.action_space, ) - algorithm = QRDQN( + algorithm: QRDQN = QRDQN( policy=policy, optim=optim, discount_factor=args.gamma, diff --git a/test/offline/gather_pendulum_data.py b/test/offline/gather_pendulum_data.py index 34eae0829..996cb818a 100644 --- 
a/test/offline/gather_pendulum_data.py +++ b/test/offline/gather_pendulum_data.py @@ -98,7 +98,7 @@ def gather_data() -> VectorReplayBuffer: device=args.device, unbounded=True, ).to(args.device) - actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr) + actor_optim = AdamOptimizerFactory(lr=args.actor_lr) net_c = Net( state_shape=args.state_shape, action_shape=args.action_shape, @@ -107,7 +107,7 @@ def gather_data() -> VectorReplayBuffer: device=args.device, ) critic = Critic(net_c, device=args.device).to(args.device) - critic_optim = torch.optim.Adam(critic.parameters(), lr=args.critic_lr) + critic_optim = AdamOptimizerFactory(lr=args.critic_lr) action_dim = space_info.action_info.action_dim if args.auto_alpha: diff --git a/test/offline/test_cql.py b/test/offline/test_cql.py index 18786eab7..7af161300 100644 --- a/test/offline/test_cql.py +++ b/test/offline/test_cql.py @@ -132,7 +132,7 @@ def test_cql(args: argparse.Namespace = get_args()) -> None: critic_optim = AdamOptimizerFactory(lr=args.critic_lr) if args.auto_alpha: - target_entropy = -np.prod(args.action_shape) + target_entropy = float(-np.prod(args.action_shape)) log_alpha = torch.zeros(1, requires_grad=True, device=args.device) alpha_optim = torch.optim.Adam([log_alpha], lr=args.alpha_lr) args.alpha = AutoAlpha(target_entropy, log_alpha, alpha_optim) diff --git a/test/offline/test_gail.py b/test/offline/test_gail.py index 24ddf9faa..7a6940778 100644 --- a/test/offline/test_gail.py +++ b/test/offline/test_gail.py @@ -144,7 +144,7 @@ def dist(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution: dist_fn=dist, action_space=env.action_space, ) - algorithm = GAIL( + algorithm: GAIL = GAIL( policy=policy, critic=critic, optim=optim, diff --git a/test/pettingzoo/pistonball.py b/test/pettingzoo/pistonball.py index fce2ddaa3..48defd9ac 100644 --- a/test/pettingzoo/pistonball.py +++ b/test/pettingzoo/pistonball.py @@ -13,6 +13,7 @@ from tianshou.env.pettingzoo_env import PettingZooEnv from 
tianshou.policy import DQN, Algorithm, MultiAgentOffPolicyAlgorithm from tianshou.policy.modelfree.dqn import DQNPolicy +from tianshou.policy.optim import AdamOptimizerFactory from tianshou.trainer import OffPolicyTrainerParams from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Net @@ -97,7 +98,7 @@ def get_agents( hidden_sizes=args.hidden_sizes, device=args.device, ).to(args.device) - optim = torch.optim.Adam(net.parameters(), lr=args.lr) + optim = AdamOptimizerFactory(lr=args.lr) policy = DQNPolicy( model=net, action_space=env.action_space, @@ -112,8 +113,8 @@ def get_agents( agents.append(agent) optims.append(optim) - policy = MultiAgentOffPolicyAlgorithm(algorithms=agents, env=env) - return policy, optims, env.agents + ma_algorithm = MultiAgentOffPolicyAlgorithm(algorithms=agents, env=env) + return ma_algorithm, optims, env.agents def train_agent( diff --git a/test/pettingzoo/pistonball_continuous.py b/test/pettingzoo/pistonball_continuous.py index 807736c4a..dc2ffdce8 100644 --- a/test/pettingzoo/pistonball_continuous.py +++ b/test/pettingzoo/pistonball_continuous.py @@ -18,6 +18,7 @@ from tianshou.policy import PPO, Algorithm from tianshou.policy.modelfree.pg import ActorPolicy from tianshou.policy.multiagent.mapolicy import MultiAgentOnPolicyAlgorithm +from tianshou.policy.optim import AdamOptimizerFactory from tianshou.trainer import OnPolicyTrainerParams from tianshou.utils import TensorboardLogger from tianshou.utils.net.continuous import ActorProb, Critic @@ -182,7 +183,7 @@ def get_agents( if isinstance(m, torch.nn.Linear): torch.nn.init.orthogonal_(m.weight) torch.nn.init.zeros_(m.bias) - optim = torch.optim.Adam(set(actor.parameters()).union(critic.parameters()), lr=args.lr) + optim = AdamOptimizerFactory(lr=args.lr) def dist(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution: loc, scale = loc_scale @@ -216,11 +217,11 @@ def dist(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution: 
agents.append(agent) optims.append(optim) - policy = MultiAgentOnPolicyAlgorithm( + ma_algorithm = MultiAgentOnPolicyAlgorithm( algorithms=agents, env=env, ) - return policy, optims, env.agents + return ma_algorithm, optims, env.agents def train_agent( diff --git a/test/pettingzoo/tic_tac_toe.py b/test/pettingzoo/tic_tac_toe.py index 670b6a722..f1b15bf12 100644 --- a/test/pettingzoo/tic_tac_toe.py +++ b/test/pettingzoo/tic_tac_toe.py @@ -149,15 +149,15 @@ def get_agents( agents = [agent_learn, agent_opponent] else: agents = [agent_opponent, agent_learn] - algorithm = MultiAgentOffPolicyAlgorithm(algorithms=agents, env=env) - return algorithm, optim, env.agents + ma_algorithm = MultiAgentOffPolicyAlgorithm(algorithms=agents, env=env) + return ma_algorithm, optim, env.agents def train_agent( args: argparse.Namespace = get_args(), agent_learn: Algorithm | None = None, agent_opponent: Algorithm | None = None, - optim: torch.optim.Optimizer | None = None, + optim: OptimizerFactory | None = None, ) -> tuple[InfoStats, Algorithm]: train_envs = DummyVectorEnv([get_env for _ in range(args.training_num)]) test_envs = DummyVectorEnv([get_env for _ in range(args.test_num)]) diff --git a/tianshou/highlevel/algorithm.py b/tianshou/highlevel/algorithm.py index 1ab4d597f..1dafb0ad4 100644 --- a/tianshou/highlevel/algorithm.py +++ b/tianshou/highlevel/algorithm.py @@ -160,7 +160,7 @@ def set_trainer_callbacks(self, callbacks: TrainerCallbacks) -> None: @staticmethod def _create_policy_from_args( - constructor: type[TPolicy], params_dict: dict, policy_params: list[str], **kwargs + constructor: type[TPolicy], params_dict: dict, policy_params: list[str], **kwargs: Any ) -> TPolicy: params = {p: params_dict.pop(p) for p in policy_params} return constructor(**params, **kwargs) @@ -210,6 +210,7 @@ def create_trainer( else None ) algorithm = cast(OnPolicyAlgorithm, world.policy) + assert world.train_collector is not None return algorithm.create_trainer( OnPolicyTrainerParams( 
train_collector=world.train_collector, @@ -257,6 +258,7 @@ def create_trainer( else None ) algorithm = cast(OffPolicyAlgorithm, world.policy) + assert world.train_collector is not None return algorithm.create_trainer( OffPolicyTrainerParams( train_collector=world.train_collector, @@ -424,7 +426,7 @@ def _create_policy( params: dict, action_space: gymnasium.spaces.Discrete, observation_space: gymnasium.spaces.Space, - ) -> TPolicy: + ) -> Policy: pass @typing.no_type_check @@ -454,7 +456,7 @@ def _create_policy( params: dict, action_space: gymnasium.spaces.Discrete, observation_space: gymnasium.spaces.Space, - ) -> TPolicy: + ) -> Policy: return self._create_policy_from_args( constructor=DQNPolicy, params_dict=params, @@ -475,7 +477,7 @@ def _create_policy( params: dict, action_space: gymnasium.spaces.Discrete, observation_space: gymnasium.spaces.Space, - ) -> TPolicy: + ) -> Policy: pass return self._create_policy_from_args( IQNPolicy, @@ -655,7 +657,7 @@ def _create_algorithm(self, envs: Environments, device: TDevice) -> TAlgorithm: ) -class SACAlgorithmFactory(ActorDualCriticsOffPolicyAlgorithmFactory[SACParams, SAC, TPolicy]): +class SACAlgorithmFactory(ActorDualCriticsOffPolicyAlgorithmFactory[SACParams, SAC, SACPolicy]): def _create_policy( self, actor: torch.nn.Module | Actor, envs: Environments, params: dict ) -> SACPolicy: @@ -673,7 +675,7 @@ def _get_algorithm_class(self) -> type[SAC]: class DiscreteSACAlgorithmFactory( - ActorDualCriticsOffPolicyAlgorithmFactory[DiscreteSACParams, DiscreteSAC, TPolicy] + ActorDualCriticsOffPolicyAlgorithmFactory[DiscreteSACParams, DiscreteSAC, DiscreteSACPolicy] ): def _create_policy( self, actor: torch.nn.Module | Actor, envs: Environments, params: dict diff --git a/tianshou/policy/imitation/discrete_cql.py b/tianshou/policy/imitation/discrete_cql.py index b2d62d7fc..6d6de8f00 100644 --- a/tianshou/policy/imitation/discrete_cql.py +++ b/tianshou/policy/imitation/discrete_cql.py @@ -23,7 +23,7 @@ class 
DiscreteCQLTrainingStats(QRDQNTrainingStats): # NOTE: This uses diamond inheritance to convert from off-policy to offline -class DiscreteCQL( +class DiscreteCQL( # type: ignore OfflineAlgorithm[QRDQNPolicy, TDiscreteCQLTrainingStats], QRDQN[QRDQNPolicy, TDiscreteCQLTrainingStats], ): diff --git a/tianshou/policy/imitation/td3_bc.py b/tianshou/policy/imitation/td3_bc.py index dfddfa429..211a160c9 100644 --- a/tianshou/policy/imitation/td3_bc.py +++ b/tianshou/policy/imitation/td3_bc.py @@ -22,7 +22,7 @@ class TD3BCTrainingStats(TD3TrainingStats): # NOTE: This uses diamond inheritance to convert from off-policy to offline -class TD3BC(OfflineAlgorithm[DDPGPolicy, TTD3BCTrainingStats], TD3[TTD3BCTrainingStats]): +class TD3BC(OfflineAlgorithm[DDPGPolicy, TTD3BCTrainingStats], TD3[TTD3BCTrainingStats]): # type: ignore """Implementation of TD3+BC. arXiv:2106.06860.""" def __init__( From 0c26b338d8d408870e2dd03c248a6c5492d3565d Mon Sep 17 00:00:00 2001 From: Dominik Jain Date: Tue, 18 Mar 2025 12:50:20 +0100 Subject: [PATCH 077/230] v2: AutoAlpha: Use optimizer factory and create the tensor internally --- examples/atari/atari_sac.py | 4 ++-- examples/box2d/bipedal_hardcore_sac.py | 6 +++--- examples/box2d/mcc_sac.py | 6 +++--- examples/mujoco/mujoco_redq.py | 7 ++++--- examples/mujoco/mujoco_sac.py | 8 +++---- examples/offline/d4rl_cql.py | 8 +++---- test/continuous/test_redq.py | 7 ++++--- test/continuous/test_sac_with_il.py | 4 ++-- test/discrete/test_discrete_sac.py | 7 ++++--- test/offline/gather_pendulum_data.py | 7 ++++--- test/offline/test_cql.py | 4 ++-- tianshou/highlevel/params/alpha.py | 21 ++++++++----------- tianshou/highlevel/params/policy_wrapper.py | 2 +- tianshou/policy/modelfree/sac.py | 23 ++++++++------------- 14 files changed, 55 insertions(+), 59 deletions(-) diff --git a/examples/atari/atari_sac.py b/examples/atari/atari_sac.py index 363c95c73..9c52de85d 100644 --- a/examples/atari/atari_sac.py +++ b/examples/atari/atari_sac.py @@ -126,8 +126,8 @@ 
def test_discrete_sac(args: argparse.Namespace = get_args()) -> None: # define policy and algorithm if args.auto_alpha: target_entropy = 0.98 * np.log(np.prod(args.action_shape)) - log_alpha = torch.zeros(1, requires_grad=True, device=args.device) - alpha_optim = torch.optim.Adam([log_alpha], lr=args.alpha_lr) + log_alpha = 0.0 + alpha_optim = AdamOptimizerFactory(lr=args.alpha_lr) args.alpha = AutoAlpha(target_entropy, log_alpha, alpha_optim) algorithm: DiscreteSAC | ICMOffPolicyWrapper policy = DiscreteSACPolicy( diff --git a/examples/box2d/bipedal_hardcore_sac.py b/examples/box2d/bipedal_hardcore_sac.py index 7e626d0e7..b7c7a3f12 100644 --- a/examples/box2d/bipedal_hardcore_sac.py +++ b/examples/box2d/bipedal_hardcore_sac.py @@ -142,9 +142,9 @@ def test_sac_bipedal(args: argparse.Namespace = get_args()) -> None: action_dim = space_info.action_info.action_dim if args.auto_alpha: target_entropy = -action_dim - log_alpha = torch.zeros(1, requires_grad=True, device=args.device) - alpha_optim = torch.optim.Adam([log_alpha], lr=args.alpha_lr) - args.alpha = AutoAlpha(target_entropy, log_alpha, alpha_optim) + log_alpha = 0.0 + alpha_optim = AdamOptimizerFactory(lr=args.alpha_lr) + args.alpha = AutoAlpha(target_entropy, log_alpha, alpha_optim).to(args.device) policy = SACPolicy( actor=actor, diff --git a/examples/box2d/mcc_sac.py b/examples/box2d/mcc_sac.py index 766e32ad9..76c5611da 100644 --- a/examples/box2d/mcc_sac.py +++ b/examples/box2d/mcc_sac.py @@ -93,9 +93,9 @@ def test_sac(args: argparse.Namespace = get_args()) -> None: action_dim = space_info.action_info.action_dim if args.auto_alpha: target_entropy = -action_dim - log_alpha = torch.zeros(1, requires_grad=True, device=args.device) - alpha_optim = torch.optim.Adam([log_alpha], lr=args.alpha_lr) - args.alpha = AutoAlpha(target_entropy, log_alpha, alpha_optim) + log_alpha = 0.0 + alpha_optim = AdamOptimizerFactory(lr=args.alpha_lr) + args.alpha = AutoAlpha(target_entropy, log_alpha, alpha_optim).to(args.device) 
policy = SACPolicy( actor=actor, diff --git a/examples/mujoco/mujoco_redq.py b/examples/mujoco/mujoco_redq.py index 31082ab6b..eb6bb2dae 100755 --- a/examples/mujoco/mujoco_redq.py +++ b/examples/mujoco/mujoco_redq.py @@ -14,6 +14,7 @@ from tianshou.policy import REDQ from tianshou.policy.base import Algorithm from tianshou.policy.modelfree.redq import REDQPolicy +from tianshou.policy.modelfree.sac import AutoAlpha from tianshou.policy.optim import AdamOptimizerFactory from tianshou.trainer import OffPolicyTrainerParams from tianshou.utils.net.common import EnsembleLinear, Net @@ -119,9 +120,9 @@ def linear(x: int, y: int) -> EnsembleLinear: if args.auto_alpha: target_entropy = -np.prod(env.action_space.shape) - log_alpha = torch.zeros(1, requires_grad=True, device=args.device) - alpha_optim = torch.optim.Adam([log_alpha], lr=args.alpha_lr) - args.alpha = (target_entropy, log_alpha, alpha_optim) + log_alpha = 0.0 + alpha_optim = AdamOptimizerFactory(lr=args.alpha_lr) + args.alpha = AutoAlpha(target_entropy, log_alpha, alpha_optim).to(args.device) policy = REDQPolicy( actor=actor, diff --git a/examples/mujoco/mujoco_sac.py b/examples/mujoco/mujoco_sac.py index 85c34884c..718aef429 100755 --- a/examples/mujoco/mujoco_sac.py +++ b/examples/mujoco/mujoco_sac.py @@ -13,7 +13,7 @@ from tianshou.highlevel.logger import LoggerFactoryDefault from tianshou.policy import SAC from tianshou.policy.base import Algorithm -from tianshou.policy.modelfree.sac import SACPolicy +from tianshou.policy.modelfree.sac import AutoAlpha, SACPolicy from tianshou.policy.optim import AdamOptimizerFactory from tianshou.trainer import OffPolicyTrainerParams from tianshou.utils.net.common import Net @@ -115,9 +115,9 @@ def main(args: argparse.Namespace = get_args()) -> None: if args.auto_alpha: target_entropy = -np.prod(env.action_space.shape) - log_alpha = torch.zeros(1, requires_grad=True, device=args.device) - alpha_optim = torch.optim.Adam([log_alpha], lr=args.alpha_lr) - args.alpha = 
(target_entropy, log_alpha, alpha_optim) + log_alpha = 0.0 + alpha_optim = AdamOptimizerFactory(lr=args.alpha_lr) + args.alpha = AutoAlpha(target_entropy, log_alpha, alpha_optim).to(args.device) policy = SACPolicy( actor=actor, diff --git a/examples/offline/d4rl_cql.py b/examples/offline/d4rl_cql.py index 2fda53a9a..1b2580626 100644 --- a/examples/offline/d4rl_cql.py +++ b/examples/offline/d4rl_cql.py @@ -15,7 +15,7 @@ from tianshou.env import SubprocVectorEnv from tianshou.policy import CQL from tianshou.policy.base import Algorithm -from tianshou.policy.modelfree.sac import SACPolicy +from tianshou.policy.modelfree.sac import AutoAlpha, SACPolicy from tianshou.policy.optim import AdamOptimizerFactory from tianshou.trainer import OfflineTrainerParams from tianshou.utils import TensorboardLogger, WandbLogger @@ -282,9 +282,9 @@ def test_cql() -> None: if args.auto_alpha: target_entropy = -args.action_dim - log_alpha = torch.zeros(1, requires_grad=True, device=args.device) - alpha_optim = torch.optim.Adam([log_alpha], lr=args.alpha_lr) - args.alpha = (target_entropy, log_alpha, alpha_optim) + log_alpha = 0.0 + alpha_optim = AdamOptimizerFactory(lr=args.alpha_lr) + args.alpha = AutoAlpha(target_entropy, log_alpha, alpha_optim).to(args.device) policy = SACPolicy( actor=actor, diff --git a/test/continuous/test_redq.py b/test/continuous/test_redq.py index f97430ae0..6517be5c4 100644 --- a/test/continuous/test_redq.py +++ b/test/continuous/test_redq.py @@ -12,6 +12,7 @@ from tianshou.policy import REDQ from tianshou.policy.base import Algorithm from tianshou.policy.modelfree.redq import REDQPolicy +from tianshou.policy.modelfree.sac import AutoAlpha from tianshou.policy.optim import AdamOptimizerFactory from tianshou.trainer.base import OffPolicyTrainerParams from tianshou.utils import TensorboardLogger @@ -109,9 +110,9 @@ def linear(x: int, y: int) -> nn.Module: action_dim = space_info.action_info.action_dim if args.auto_alpha: target_entropy = -action_dim - log_alpha = 
torch.zeros(1, requires_grad=True, device=args.device) - alpha_optim = torch.optim.Adam([log_alpha], lr=args.alpha_lr) - args.alpha = (target_entropy, log_alpha, alpha_optim) + log_alpha = 0.0 + alpha_optim = AdamOptimizerFactory(lr=args.alpha_lr) + args.alpha = AutoAlpha(target_entropy, log_alpha, alpha_optim).to(args.device) policy = REDQPolicy( actor=actor, diff --git a/test/continuous/test_sac_with_il.py b/test/continuous/test_sac_with_il.py index acf74973b..d6930722f 100644 --- a/test/continuous/test_sac_with_il.py +++ b/test/continuous/test_sac_with_il.py @@ -109,8 +109,8 @@ def test_sac_with_il(args: argparse.Namespace = get_args()) -> None: action_dim = space_info.action_info.action_dim if args.auto_alpha: target_entropy = -action_dim - log_alpha = torch.zeros(1, requires_grad=True, device=args.device) - alpha_optim = torch.optim.Adam([log_alpha], lr=args.alpha_lr) + log_alpha = 0.0 + alpha_optim = AdamOptimizerFactory(lr=args.alpha_lr) args.alpha = AutoAlpha(target_entropy, log_alpha, alpha_optim) policy = SACPolicy( actor=actor, diff --git a/test/discrete/test_discrete_sac.py b/test/discrete/test_discrete_sac.py index 503264fa6..0a2ef27d4 100644 --- a/test/discrete/test_discrete_sac.py +++ b/test/discrete/test_discrete_sac.py @@ -14,6 +14,7 @@ DiscreteSACPolicy, DiscreteSACTrainingStats, ) +from tianshou.policy.modelfree.sac import AutoAlpha from tianshou.policy.optim import AdamOptimizerFactory from tianshou.trainer.base import OffPolicyTrainerParams from tianshou.utils import TensorboardLogger @@ -92,9 +93,9 @@ def test_discrete_sac(args: argparse.Namespace = get_args()) -> None: # better not to use auto alpha in CartPole if args.auto_alpha: target_entropy = 0.98 * np.log(action_dim) - log_alpha = torch.zeros(1, requires_grad=True, device=args.device) - alpha_optim = torch.optim.Adam([log_alpha], lr=args.alpha_lr) - args.alpha = (target_entropy, log_alpha, alpha_optim) + log_alpha = 0.0 + alpha_optim = AdamOptimizerFactory(lr=args.alpha_lr) + args.alpha 
= AutoAlpha(target_entropy, log_alpha, alpha_optim).to(args.device) policy = DiscreteSACPolicy( actor=actor, diff --git a/test/offline/gather_pendulum_data.py b/test/offline/gather_pendulum_data.py index 996cb818a..c1ca8cb33 100644 --- a/test/offline/gather_pendulum_data.py +++ b/test/offline/gather_pendulum_data.py @@ -12,6 +12,7 @@ from tianshou.policy import SAC from tianshou.policy.base import Algorithm from tianshou.policy.modelfree.sac import AutoAlpha, SACPolicy, SACTrainingStats +from tianshou.policy.optim import AdamOptimizerFactory from tianshou.trainer.base import OffPolicyTrainerParams from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Net @@ -112,9 +113,9 @@ def gather_data() -> VectorReplayBuffer: action_dim = space_info.action_info.action_dim if args.auto_alpha: target_entropy = -action_dim - log_alpha = torch.zeros(1, requires_grad=True, device=args.device) - alpha_optim = torch.optim.Adam([log_alpha], lr=args.alpha_lr) - args.alpha = AutoAlpha(target_entropy, log_alpha, alpha_optim) + log_alpha = 0.0 + alpha_optim = AdamOptimizerFactory(lr=args.alpha_lr) + args.alpha = AutoAlpha(target_entropy, log_alpha, alpha_optim).to(args.device) policy = SACPolicy( actor=actor, diff --git a/test/offline/test_cql.py b/test/offline/test_cql.py index 7af161300..ee734c560 100644 --- a/test/offline/test_cql.py +++ b/test/offline/test_cql.py @@ -133,8 +133,8 @@ def test_cql(args: argparse.Namespace = get_args()) -> None: if args.auto_alpha: target_entropy = float(-np.prod(args.action_shape)) - log_alpha = torch.zeros(1, requires_grad=True, device=args.device) - alpha_optim = torch.optim.Adam([log_alpha], lr=args.alpha_lr) + log_alpha = 0.0 + alpha_optim = AdamOptimizerFactory(lr=args.alpha_lr) args.alpha = AutoAlpha(target_entropy, log_alpha, alpha_optim) policy = SACPolicy( diff --git a/tianshou/highlevel/params/alpha.py b/tianshou/highlevel/params/alpha.py index fc23baeb4..55787b7cb 100644 --- a/tianshou/highlevel/params/alpha.py 
+++ b/tianshou/highlevel/params/alpha.py @@ -1,14 +1,12 @@ from abc import ABC, abstractmethod import numpy as np -import torch from sensai.util.string import ToStringMixin -from torch.nn import ParameterList from tianshou.highlevel.env import Environments from tianshou.highlevel.module.core import TDevice from tianshou.highlevel.optim import OptimizerFactoryFactory -from tianshou.policy.modelfree.sac import AutoAlpha +from tianshou.policy.modelfree.sac import Alpha, AutoAlpha class AutoAlphaFactory(ToStringMixin, ABC): @@ -17,7 +15,7 @@ def create_auto_alpha( self, envs: Environments, device: TDevice, - ) -> AutoAlpha: + ) -> Alpha: pass @@ -26,7 +24,8 @@ def __init__( self, lr: float = 3e-4, target_entropy_coefficient: float = -1.0, - optimizer: OptimizerFactoryFactory | None = None, + log_alpha=0.0, + optim: OptimizerFactoryFactory | None = None, ): """ :param lr: the learning rate for the optimizer of the alpha parameter @@ -36,11 +35,13 @@ def __init__( spaces respectively, which gives a reasonable trade-off between exploration and exploitation. For decidedly stochastic exploration, you can use a positive value closer to 1 (e.g. 0.98); 1.0 would give full entropy exploration. - :param optimizer: the optimizer factory to use; if None, use default + :param log_alpha: the (initial) value of the log of the entropy regularization coefficient alpha. 
+ :param optim: the optimizer factory to use; if None, use default """ self.lr = lr self.target_entropy_coefficient = target_entropy_coefficient - self.optimizer_factory_factory = optimizer or OptimizerFactoryFactory.default() + self.log_alpha = log_alpha + self.optimizer_factory_factory = optim or OptimizerFactoryFactory.default() def create_auto_alpha( self, @@ -52,9 +53,5 @@ def create_auto_alpha( target_entropy = self.target_entropy_coefficient * float(action_dim) else: target_entropy = self.target_entropy_coefficient * np.log(action_dim) - log_alpha = torch.zeros(1, requires_grad=True, device=device) optim_factory = self.optimizer_factory_factory.create_optimizer_factory(lr=self.lr) - optim, lr_scheduler = optim_factory.create_instances(ParameterList([log_alpha])) - if lr_scheduler is not None: - raise ValueError("Learning rate schedulers are not supported for AutoAlpha") - return AutoAlpha(target_entropy, log_alpha, optim) + return AutoAlpha(target_entropy, self.log_alpha, optim_factory) diff --git a/tianshou/highlevel/params/policy_wrapper.py b/tianshou/highlevel/params/policy_wrapper.py index 2d808eccd..f4fb74b58 100644 --- a/tianshou/highlevel/params/policy_wrapper.py +++ b/tianshou/highlevel/params/policy_wrapper.py @@ -56,7 +56,7 @@ def create_wrapped_algorithm( envs: Environments, optim_factory_default: OptimizerFactoryFactory, device: TDevice, - ) -> ICMOffPolicyWrapper: + ) -> ICMOffPolicyWrapper | ICMOnPolicyWrapper: feature_net = self.feature_net_factory.create_intermediate_module(envs, device) action_dim = envs.get_action_shape() if not isinstance(action_dim, int): diff --git a/tianshou/policy/modelfree/sac.py b/tianshou/policy/modelfree/sac.py index f25cf6887..58bb40f21 100644 --- a/tianshou/policy/modelfree/sac.py +++ b/tianshou/policy/modelfree/sac.py @@ -6,6 +6,7 @@ import numpy as np import torch from torch.distributions import Independent, Normal +from torch.nn import ParameterList from tianshou.data import Batch from tianshou.data.types 
import ( @@ -162,9 +163,7 @@ def update(self, entropy: torch.Tensor) -> float | None: class AutoAlpha(torch.nn.Module, Alpha): """Represents an entropy regularization coefficient alpha that is automatically tuned.""" - def __init__( - self, target_entropy: float, log_alpha: torch.Tensor, optim: torch.optim.Optimizer - ): + def __init__(self, target_entropy: float, log_alpha: float, optim: OptimizerFactory): """ :param target_entropy: the target entropy value. For discrete action spaces, it is usually -log(|A|) for a balance between stochasticity @@ -172,21 +171,17 @@ def __init__( lambda*log(|A|), e.g. with lambda close to 1 (e.g. 0.98) for pronounced stochasticity. For continuous action spaces, it is usually -dim(A) for a balance between stochasticity and determinism, with similar generalizations as for discrete action spaces. - :param log_alpha: the (initial) log of the entropy regularization coefficient alpha. - This must be a scalar tensor with requires_grad=True. - :param optim: the optimizer for `log_alpha`. + :param log_alpha: the (initial) value of the log of the entropy regularization coefficient alpha. + :param optim: the factory with which to create the optimizer for `log_alpha`. 
""" super().__init__() - if not log_alpha.requires_grad: - raise ValueError("Expected log_alpha to require gradient, but it doesn't.") - if log_alpha.shape != torch.Size([1]): + self._target_entropy = target_entropy + self._log_alpha = torch.tensor(log_alpha, requires_grad=True) + self._optim, lr_scheduler = optim.create_instances(ParameterList([self._log_alpha])) + if lr_scheduler is not None: raise ValueError( - f"Expected log_alpha to have shape torch.Size([1]), " - f"but got {log_alpha.shape} instead.", + f"Learning rate schedulers are not supported by {self.__class__.__name__}" ) - self._target_entropy = target_entropy - self._log_alpha = log_alpha - self._optim = optim @property def value(self) -> float: From c9a8e2f540368f2f02b9fb02d65e71b992bd1da0 Mon Sep 17 00:00:00 2001 From: Dominik Jain Date: Tue, 18 Mar 2025 13:23:39 +0100 Subject: [PATCH 078/230] v2: Improve Algorithm (formerly BasePolicy) method names * process_fn -> preprocess_batch * post_process_fn -> postprocess_batch --- CHANGELOG.md | 4 ++++ tianshou/policy/base.py | 24 +++++++++++------------ tianshou/policy/imitation/discrete_bcq.py | 2 +- tianshou/policy/imitation/discrete_crr.py | 2 +- tianshou/policy/imitation/gail.py | 4 ++-- tianshou/policy/modelbased/icm.py | 16 +++++++-------- tianshou/policy/modelfree/a2c.py | 2 +- tianshou/policy/modelfree/bdqn.py | 2 +- tianshou/policy/modelfree/ddpg.py | 2 +- tianshou/policy/modelfree/dqn.py | 2 +- tianshou/policy/modelfree/npg.py | 2 +- tianshou/policy/modelfree/pg.py | 2 +- tianshou/policy/modelfree/ppo.py | 2 +- tianshou/policy/multiagent/mapolicy.py | 6 +++--- 14 files changed, 38 insertions(+), 34 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index ff1a66262..1136f029c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -62,6 +62,10 @@ * `MARLRandomPolicy` -> `MARLRandomDiscreteMaskedOffPolicyAlgorithm` For the respective subtype of `Policy` to use, see the respective algorithm class' constructor. 
* Interface changes/improvements: + * Core methods have been renamed: + * `process_fn` -> `preprocess_batch` + * `post_process_fn` -> `postprocess_batch` + * `learn` -> `_update_with_batch` (no longer in public interface) * The updating interface has been cleaned up: * Functions `update` and `_update_with_batch` (formerly `learn`) no longer have `*args` and `**kwargs`. * Instead, the interfaces for the offline, off-policy and on-policy cases are properly differentiated. diff --git a/tianshou/policy/base.py b/tianshou/policy/base.py index 45724eb65..3a56ec05a 100644 --- a/tianshou/policy/base.py +++ b/tianshou/policy/base.py @@ -493,7 +493,7 @@ def _create_optimizer( self.lr_schedulers.append(lr_scheduler) return optimizer - def process_fn( + def preprocess_batch( self, batch: RolloutBatchProtocol, buffer: ReplayBuffer, @@ -510,7 +510,7 @@ def process_fn( """ return batch - def post_process_fn( + def postprocess_batch( self, batch: BatchProtocol, buffer: ReplayBuffer, @@ -570,10 +570,10 @@ def _update( start_time = time.time() batch, indices = buffer.sample(sample_size) self.updating = True - batch = self.process_fn(batch, buffer, indices) + batch = self.preprocess_batch(batch, buffer, indices) with torch_train_mode(self): training_stat = update_with_batch_fn(batch) - self.post_process_fn(batch, buffer, indices) + self.postprocess_batch(batch, buffer, indices) for lr_scheduler in self.lr_schedulers: lr_scheduler.step() self.updating = False @@ -880,23 +880,23 @@ def __init__( super().__init__(policy=wrapped_algorithm.policy) self.wrapped_algorithm = wrapped_algorithm - def process_fn( + def preprocess_batch( self, batch: RolloutBatchProtocol, buffer: ReplayBuffer, indices: np.ndarray, ) -> RolloutBatchProtocol: """Performs the pre-processing as defined by the wrapped algorithm.""" - return self.wrapped_algorithm.process_fn(batch, buffer, indices) + return self.wrapped_algorithm.preprocess_batch(batch, buffer, indices) - def post_process_fn( + def postprocess_batch( 
self, batch: BatchProtocol, buffer: ReplayBuffer, indices: np.ndarray, ) -> None: """Performs the batch post-processing as defined by the wrapped algorithm.""" - self.wrapped_algorithm.post_process_fn(batch, buffer, indices) + self.wrapped_algorithm.postprocess_batch(batch, buffer, indices) def _update_with_batch( self, batch: RolloutBatchProtocol, batch_size: int | None, repeat: int @@ -930,23 +930,23 @@ def __init__( super().__init__(policy=wrapped_algorithm.policy) self.wrapped_algorithm = wrapped_algorithm - def process_fn( + def preprocess_batch( self, batch: RolloutBatchProtocol, buffer: ReplayBuffer, indices: np.ndarray, ) -> RolloutBatchProtocol: """Performs the pre-processing as defined by the wrapped algorithm.""" - return self.wrapped_algorithm.process_fn(batch, buffer, indices) + return self.wrapped_algorithm.preprocess_batch(batch, buffer, indices) - def post_process_fn( + def postprocess_batch( self, batch: BatchProtocol, buffer: ReplayBuffer, indices: np.ndarray, ) -> None: """Performs the batch post-processing as defined by the wrapped algorithm.""" - self.wrapped_algorithm.post_process_fn(batch, buffer, indices) + self.wrapped_algorithm.postprocess_batch(batch, buffer, indices) def _update_with_batch( self, diff --git a/tianshou/policy/imitation/discrete_bcq.py b/tianshou/policy/imitation/discrete_bcq.py index a5b47e225..b0ebf94dc 100644 --- a/tianshou/policy/imitation/discrete_bcq.py +++ b/tianshou/policy/imitation/discrete_bcq.py @@ -162,7 +162,7 @@ def __init__( self.eps = eval_eps self._weight_reg = imitation_logits_penalty - def process_fn( + def preprocess_batch( self, batch: RolloutBatchProtocol, buffer: ReplayBuffer, diff --git a/tianshou/policy/imitation/discrete_crr.py b/tianshou/policy/imitation/discrete_crr.py index e8c642ab6..282eba63b 100644 --- a/tianshou/policy/imitation/discrete_crr.py +++ b/tianshou/policy/imitation/discrete_crr.py @@ -95,7 +95,7 @@ def __init__( self._beta = beta self._min_q_weight = min_q_weight - def 
process_fn( + def preprocess_batch( self, batch: RolloutBatchProtocol, buffer: ReplayBuffer, diff --git a/tianshou/policy/imitation/gail.py b/tianshou/policy/imitation/gail.py index 2d5a4b445..a12c31f25 100644 --- a/tianshou/policy/imitation/gail.py +++ b/tianshou/policy/imitation/gail.py @@ -110,7 +110,7 @@ def __init__( # only the output dimension? self.action_dim = self.policy.actor.output_dim - def process_fn( + def preprocess_batch( self, batch: RolloutBatchProtocol, buffer: ReplayBuffer, @@ -123,7 +123,7 @@ def process_fn( # update reward with torch.no_grad(): batch.rew = to_numpy(-F.logsigmoid(-self.disc(batch)).flatten()) - return super().process_fn(batch, buffer, indices) + return super().preprocess_batch(batch, buffer, indices) def disc(self, batch: RolloutBatchProtocol) -> torch.Tensor: obs = to_torch(batch.obs, device=self.disc_net.device) diff --git a/tianshou/policy/modelbased/icm.py b/tianshou/policy/modelbased/icm.py index 6a223d844..5da69d05f 100644 --- a/tianshou/policy/modelbased/icm.py +++ b/tianshou/policy/modelbased/icm.py @@ -130,22 +130,22 @@ def __init__( forward_loss_weight=forward_loss_weight, ) - def process_fn( + def preprocess_batch( self, batch: RolloutBatchProtocol, buffer: ReplayBuffer, indices: np.ndarray, ) -> RolloutBatchProtocol: self._icm_preprocess_batch(batch) - return super().process_fn(batch, buffer, indices) + return super().preprocess_batch(batch, buffer, indices) - def post_process_fn( + def postprocess_batch( self, batch: BatchProtocol, buffer: ReplayBuffer, indices: np.ndarray, ) -> None: - super().post_process_fn(batch, buffer, indices) + super().postprocess_batch(batch, buffer, indices) self._icm_postprocess_batch(batch) def _wrapper_update_with_batch( @@ -191,22 +191,22 @@ def __init__( forward_loss_weight=forward_loss_weight, ) - def process_fn( + def preprocess_batch( self, batch: RolloutBatchProtocol, buffer: ReplayBuffer, indices: np.ndarray, ) -> RolloutBatchProtocol: self._icm_preprocess_batch(batch) - return 
super().process_fn(batch, buffer, indices) + return super().preprocess_batch(batch, buffer, indices) - def post_process_fn( + def postprocess_batch( self, batch: BatchProtocol, buffer: ReplayBuffer, indices: np.ndarray, ) -> None: - super().post_process_fn(batch, buffer, indices) + super().postprocess_batch(batch, buffer, indices) self._icm_postprocess_batch(batch) def _wrapper_update_with_batch( diff --git a/tianshou/policy/modelfree/a2c.py b/tianshou/policy/modelfree/a2c.py index 60e6d7487..c7a4d9bd9 100644 --- a/tianshou/policy/modelfree/a2c.py +++ b/tianshou/policy/modelfree/a2c.py @@ -161,7 +161,7 @@ def __init__( self.ent_coef = ent_coef self.max_grad_norm = max_grad_norm - def process_fn( + def preprocess_batch( self, batch: RolloutBatchProtocol, buffer: ReplayBuffer, diff --git a/tianshou/policy/modelfree/bdqn.py b/tianshou/policy/modelfree/bdqn.py index 9e8fcbfc3..313d201c0 100644 --- a/tianshou/policy/modelfree/bdqn.py +++ b/tianshou/policy/modelfree/bdqn.py @@ -171,7 +171,7 @@ def _compute_return( batch.weight = to_torch_as(batch.weight, target_q_torch) return cast(BatchWithReturnsProtocol, batch) - def process_fn( + def preprocess_batch( self, batch: RolloutBatchProtocol, buffer: ReplayBuffer, diff --git a/tianshou/policy/modelfree/ddpg.py b/tianshou/policy/modelfree/ddpg.py index 6b5344f7a..b24a159bf 100644 --- a/tianshou/policy/modelfree/ddpg.py +++ b/tianshou/policy/modelfree/ddpg.py @@ -240,7 +240,7 @@ def _minimize_critic_squared_loss( optimizer.step() return td, critic_loss - def process_fn( + def preprocess_batch( self, batch: RolloutBatchProtocol, buffer: ReplayBuffer, diff --git a/tianshou/policy/modelfree/dqn.py b/tianshou/policy/modelfree/dqn.py index ec60109c9..54ce79c97 100644 --- a/tianshou/policy/modelfree/dqn.py +++ b/tianshou/policy/modelfree/dqn.py @@ -206,7 +206,7 @@ def use_target_network(self) -> bool: def _target_q(self, buffer: ReplayBuffer, indices: np.ndarray) -> torch.Tensor: pass - def process_fn( + def preprocess_batch( self, 
batch: RolloutBatchProtocol, buffer: ReplayBuffer, diff --git a/tianshou/policy/modelfree/npg.py b/tianshou/policy/modelfree/npg.py index 506fc08c3..2c8ad9b12 100644 --- a/tianshou/policy/modelfree/npg.py +++ b/tianshou/policy/modelfree/npg.py @@ -77,7 +77,7 @@ def __init__( # adjusts Hessian-vector product calculation for numerical stability self._damping = 0.1 - def process_fn( + def preprocess_batch( self, batch: RolloutBatchProtocol, buffer: ReplayBuffer, diff --git a/tianshou/policy/modelfree/pg.py b/tianshou/policy/modelfree/pg.py index e7381c8a8..9c2f18d59 100644 --- a/tianshou/policy/modelfree/pg.py +++ b/tianshou/policy/modelfree/pg.py @@ -274,7 +274,7 @@ def __init__( ) self.optim = self._create_optimizer(self.policy, optim) - def process_fn( + def preprocess_batch( self, batch: RolloutBatchProtocol, buffer: ReplayBuffer, diff --git a/tianshou/policy/modelfree/ppo.py b/tianshou/policy/modelfree/ppo.py index 044f459f7..42a7f1f8b 100644 --- a/tianshou/policy/modelfree/ppo.py +++ b/tianshou/policy/modelfree/ppo.py @@ -122,7 +122,7 @@ def __init__( self.recompute_adv = recompute_advantage self._actor_critic: ActorCritic - def process_fn( + def preprocess_batch( self, batch: RolloutBatchProtocol, buffer: ReplayBuffer, diff --git a/tianshou/policy/multiagent/mapolicy.py b/tianshou/policy/multiagent/mapolicy.py index d2f2a2b11..52592fdc2 100644 --- a/tianshou/policy/multiagent/mapolicy.py +++ b/tianshou/policy/multiagent/mapolicy.py @@ -243,7 +243,7 @@ def dispatch_process_fn( tmp_batch.obs = tmp_batch.obs.obs if hasattr(tmp_batch.obs_next, "obs"): tmp_batch.obs_next = tmp_batch.obs_next.obs - results[agent] = algorithm.process_fn(tmp_batch, buffer, tmp_indice) + results[agent] = algorithm.preprocess_batch(tmp_batch, buffer, tmp_indice) if has_rew: # restore from save_rew buffer._meta.rew = save_rew return cast(MAPRolloutBatchProtocol, Batch(results)) @@ -291,7 +291,7 @@ def __init__( def get_algorithm(self, agent_id: str | int) -> OffPolicyAlgorithm: return 
self._dispatcher.algorithms[agent_id] - def process_fn( + def preprocess_batch( self, batch: RolloutBatchProtocol, buffer: ReplayBuffer, @@ -334,7 +334,7 @@ def __init__( def get_algorithm(self, agent_id: str | int) -> OnPolicyAlgorithm: return self._dispatcher.algorithms[agent_id] - def process_fn( + def preprocess_batch( self, batch: RolloutBatchProtocol, buffer: ReplayBuffer, From d0d1f14072f50456a77681d2a4f270f162d67e44 Mon Sep 17 00:00:00 2001 From: Dominik Jain Date: Tue, 18 Mar 2025 13:36:33 +0100 Subject: [PATCH 079/230] v2: Rename Actor classes to improve clarity * BaseActor -> Actor * continuous.ActorProb -> ContinuousActorProb * coninuous.Actor -> ContinuousActorDeterministic * discrete.Actor -> DiscreteActor --- CHANGELOG.md | 8 +++++++- examples/atari/atari_ppo.py | 4 ++-- examples/atari/atari_sac.py | 4 ++-- examples/box2d/bipedal_hardcore_sac.py | 4 ++-- examples/box2d/mcc_sac.py | 6 ++++-- examples/inverse/irl_gail.py | 4 ++-- examples/mujoco/fetch_her_ddpg.py | 4 ++-- examples/mujoco/mujoco_a2c.py | 4 ++-- examples/mujoco/mujoco_ddpg.py | 6 ++++-- examples/mujoco/mujoco_npg.py | 4 ++-- examples/mujoco/mujoco_ppo.py | 4 ++-- examples/mujoco/mujoco_redq.py | 4 ++-- examples/mujoco/mujoco_reinforce.py | 4 ++-- examples/mujoco/mujoco_sac.py | 4 ++-- examples/mujoco/mujoco_td3.py | 6 ++++-- examples/mujoco/mujoco_trpo.py | 4 ++-- examples/offline/atari_bcq.py | 6 +++--- examples/offline/atari_crr.py | 4 ++-- examples/offline/d4rl_cql.py | 4 ++-- examples/offline/d4rl_il.py | 4 ++-- examples/offline/d4rl_td3_bc.py | 4 ++-- examples/vizdoom/vizdoom_ppo.py | 4 ++-- test/base/test_policy.py | 10 +++++----- test/continuous/test_ddpg.py | 6 ++++-- test/continuous/test_npg.py | 6 ++++-- test/continuous/test_ppo.py | 6 ++++-- test/continuous/test_redq.py | 4 ++-- test/continuous/test_sac_with_il.py | 12 +++++++++--- test/continuous/test_td3.py | 6 ++++-- test/continuous/test_trpo.py | 6 ++++-- test/discrete/test_a2c_with_il.py | 6 +++--- 
test/discrete/test_discrete_sac.py | 6 ++++-- test/discrete/test_ppo.py | 8 +++++--- test/modelbased/test_ppo_icm.py | 4 ++-- test/offline/gather_pendulum_data.py | 4 ++-- test/offline/test_cql.py | 4 ++-- test/offline/test_discrete_bcq.py | 6 +++--- test/offline/test_discrete_crr.py | 4 ++-- test/offline/test_gail.py | 4 ++-- test/offline/test_td3_bc.py | 4 ++-- test/pettingzoo/pistonball_continuous.py | 4 ++-- tianshou/env/atari/atari_network.py | 6 +++--- tianshou/highlevel/algorithm.py | 10 +++++----- tianshou/highlevel/module/actor.py | 24 ++++++++++++------------ tianshou/highlevel/module/critic.py | 6 +++--- tianshou/policy/modelfree/ddpg.py | 4 ++-- tianshou/policy/modelfree/pg.py | 8 ++++---- tianshou/policy/modelfree/redq.py | 4 ++-- tianshou/policy/modelfree/sac.py | 4 ++-- tianshou/utils/net/common.py | 4 ++-- tianshou/utils/net/continuous.py | 6 +++--- tianshou/utils/net/discrete.py | 5 ++--- 52 files changed, 161 insertions(+), 130 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 1136f029c..da546864b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -134,7 +134,13 @@ * The `test_in_train` parameter is now exposed (default False). * Inapplicable arguments can no longer be set in the respective subclass (e.g. `OffPolicyTrainingConfig` does not contain parameter `repeat_per_collect`). -* Moved Atari helper modules `atari_network` and `atari_wrapper` to the library under `tianshou.env.atari`. +* Peripheral changes: + * The `Actor` classes have been renamed for clarity: + * `BaseActor` -> `Actor` + * `continuous.ActorProb` -> `ContinuousActorProb` + * `coninuous.Actor` -> `ContinuousActorDeterministic` + * `discrete.Actor` -> `DiscreteActor` + * Moved Atari helper modules `atari_network` and `atari_wrapper` to the library under `tianshou.env.atari`. 
## Unreleased diff --git a/examples/atari/atari_ppo.py b/examples/atari/atari_ppo.py index 65e0d9d49..6214d76a4 100644 --- a/examples/atari/atari_ppo.py +++ b/examples/atari/atari_ppo.py @@ -17,7 +17,7 @@ from tianshou.policy.modelfree.pg import DiscreteActorPolicy from tianshou.policy.optim import AdamOptimizerFactory, LRSchedulerFactoryLinear from tianshou.trainer import OnPolicyTrainerParams -from tianshou.utils.net.discrete import Actor, Critic, IntrinsicCuriosityModule +from tianshou.utils.net.discrete import Critic, DiscreteActor, IntrinsicCuriosityModule def get_args() -> argparse.Namespace: @@ -120,7 +120,7 @@ def main(args: argparse.Namespace = get_args()) -> None: ) if args.scale_obs: net = scale_obs(net) - actor = Actor(net, args.action_shape, device=args.device, softmax_output=False) + actor = DiscreteActor(net, args.action_shape, device=args.device, softmax_output=False) critic = Critic(net, device=args.device) optim = AdamOptimizerFactory(lr=args.lr, eps=1e-5) diff --git a/examples/atari/atari_sac.py b/examples/atari/atari_sac.py index 9c52de85d..8de0c0870 100644 --- a/examples/atari/atari_sac.py +++ b/examples/atari/atari_sac.py @@ -17,7 +17,7 @@ from tianshou.policy.modelfree.sac import AutoAlpha from tianshou.policy.optim import AdamOptimizerFactory from tianshou.trainer import OffPolicyTrainerParams -from tianshou.utils.net.discrete import Actor, Critic, IntrinsicCuriosityModule +from tianshou.utils.net.discrete import Critic, DiscreteActor, IntrinsicCuriosityModule def get_args() -> argparse.Namespace: @@ -116,7 +116,7 @@ def test_discrete_sac(args: argparse.Namespace = get_args()) -> None: features_only=True, output_dim_added_layer=args.hidden_size, ) - actor = Actor(net, args.action_shape, device=args.device, softmax_output=False) + actor = DiscreteActor(net, args.action_shape, device=args.device, softmax_output=False) actor_optim = AdamOptimizerFactory(lr=args.actor_lr) critic1 = Critic(net, last_size=args.action_shape, device=args.device) 
critic1_optim = AdamOptimizerFactory(lr=args.critic_lr) diff --git a/examples/box2d/bipedal_hardcore_sac.py b/examples/box2d/bipedal_hardcore_sac.py index b7c7a3f12..44a6d36dd 100644 --- a/examples/box2d/bipedal_hardcore_sac.py +++ b/examples/box2d/bipedal_hardcore_sac.py @@ -18,7 +18,7 @@ from tianshou.trainer import OffPolicyTrainerParams from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Net -from tianshou.utils.net.continuous import ActorProb, Critic +from tianshou.utils.net.continuous import ContinuousActorProb, Critic from tianshou.utils.space_info import SpaceInfo @@ -111,7 +111,7 @@ def test_sac_bipedal(args: argparse.Namespace = get_args()) -> None: # model net_a = Net(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device) - actor = ActorProb( + actor = ContinuousActorProb( preprocess_net=net_a, action_shape=args.action_shape, device=args.device, diff --git a/examples/box2d/mcc_sac.py b/examples/box2d/mcc_sac.py index 76c5611da..13d833035 100644 --- a/examples/box2d/mcc_sac.py +++ b/examples/box2d/mcc_sac.py @@ -17,7 +17,7 @@ from tianshou.trainer import OffPolicyTrainerParams from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Net -from tianshou.utils.net.continuous import ActorProb, Critic +from tianshou.utils.net.continuous import ContinuousActorProb, Critic from tianshou.utils.space_info import SpaceInfo @@ -69,7 +69,9 @@ def test_sac(args: argparse.Namespace = get_args()) -> None: test_envs.seed(args.seed) # model net = Net(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device) - actor = ActorProb(net, args.action_shape, device=args.device, unbounded=True).to(args.device) + actor = ContinuousActorProb(net, args.action_shape, device=args.device, unbounded=True).to( + args.device + ) actor_optim = AdamOptimizerFactory(lr=args.actor_lr) net_c1 = Net( state_shape=args.state_shape, diff --git a/examples/inverse/irl_gail.py 
b/examples/inverse/irl_gail.py index 4a4b652f1..f41fa070c 100644 --- a/examples/inverse/irl_gail.py +++ b/examples/inverse/irl_gail.py @@ -30,7 +30,7 @@ from tianshou.trainer import OnPolicyTrainerParams from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Net -from tianshou.utils.net.continuous import ActorProb, Critic +from tianshou.utils.net.continuous import ContinuousActorProb, Critic from tianshou.utils.space_info import SpaceInfo @@ -128,7 +128,7 @@ def test_gail(args: argparse.Namespace = get_args()) -> None: activation=nn.Tanh, device=args.device, ) - actor = ActorProb( + actor = ContinuousActorProb( preprocess_net=net_a, action_shape=args.action_shape, unbounded=True, diff --git a/examples/mujoco/fetch_her_ddpg.py b/examples/mujoco/fetch_her_ddpg.py index bbdf7e1ec..1e0e31c50 100644 --- a/examples/mujoco/fetch_her_ddpg.py +++ b/examples/mujoco/fetch_her_ddpg.py @@ -28,7 +28,7 @@ from tianshou.policy.optim import AdamOptimizerFactory from tianshou.trainer import OffPolicyTrainerParams from tianshou.utils.net.common import Net, get_dict_state_decorator -from tianshou.utils.net.continuous import Actor, Critic +from tianshou.utils.net.continuous import ContinuousActorDeterministic, Critic from tianshou.utils.space_info import ActionSpaceInfo @@ -154,7 +154,7 @@ def test_ddpg(args: argparse.Namespace = get_args()) -> None: hidden_sizes=args.hidden_sizes, device=args.device, ) - actor = dict_state_dec(Actor)( + actor = dict_state_dec(ContinuousActorDeterministic)( net_a, args.action_shape, max_action=args.max_action, diff --git a/examples/mujoco/mujoco_a2c.py b/examples/mujoco/mujoco_a2c.py index 3efbb9f8b..27952fef8 100755 --- a/examples/mujoco/mujoco_a2c.py +++ b/examples/mujoco/mujoco_a2c.py @@ -19,7 +19,7 @@ from tianshou.policy.optim import LRSchedulerFactoryLinear, RMSpropOptimizerFactory from tianshou.trainer import OnPolicyTrainerParams from tianshou.utils.net.common import ActorCritic, Net -from 
tianshou.utils.net.continuous import ActorProb, Critic +from tianshou.utils.net.continuous import ContinuousActorProb, Critic def get_args() -> argparse.Namespace: @@ -95,7 +95,7 @@ def main(args: argparse.Namespace = get_args()) -> None: activation=nn.Tanh, device=args.device, ) - actor = ActorProb( + actor = ContinuousActorProb( net_a, args.action_shape, unbounded=True, diff --git a/examples/mujoco/mujoco_ddpg.py b/examples/mujoco/mujoco_ddpg.py index 888b206a1..283a9fc9c 100755 --- a/examples/mujoco/mujoco_ddpg.py +++ b/examples/mujoco/mujoco_ddpg.py @@ -18,7 +18,7 @@ from tianshou.policy.optim import AdamOptimizerFactory from tianshou.trainer import OffPolicyTrainerParams from tianshou.utils.net.common import Net -from tianshou.utils.net.continuous import Actor, Critic +from tianshou.utils.net.continuous import ContinuousActorDeterministic, Critic def get_args() -> argparse.Namespace: @@ -86,7 +86,9 @@ def main(args: argparse.Namespace = get_args()) -> None: torch.manual_seed(args.seed) # model net_a = Net(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device) - actor = Actor(net_a, args.action_shape, max_action=args.max_action, device=args.device).to( + actor = ContinuousActorDeterministic( + net_a, args.action_shape, max_action=args.max_action, device=args.device + ).to( args.device, ) actor_optim = AdamOptimizerFactory(lr=args.actor_lr) diff --git a/examples/mujoco/mujoco_npg.py b/examples/mujoco/mujoco_npg.py index 596bbc433..0933bbdaa 100755 --- a/examples/mujoco/mujoco_npg.py +++ b/examples/mujoco/mujoco_npg.py @@ -19,7 +19,7 @@ from tianshou.policy.optim import AdamOptimizerFactory, LRSchedulerFactoryLinear from tianshou.trainer import OnPolicyTrainerParams from tianshou.utils.net.common import Net -from tianshou.utils.net.continuous import ActorProb, Critic +from tianshou.utils.net.continuous import ContinuousActorProb, Critic def get_args() -> argparse.Namespace: @@ -100,7 +100,7 @@ def main(args: argparse.Namespace = 
get_args()) -> None: activation=nn.Tanh, device=args.device, ) - actor = ActorProb( + actor = ContinuousActorProb( net_a, args.action_shape, unbounded=True, diff --git a/examples/mujoco/mujoco_ppo.py b/examples/mujoco/mujoco_ppo.py index a425815ed..1c3d0f997 100755 --- a/examples/mujoco/mujoco_ppo.py +++ b/examples/mujoco/mujoco_ppo.py @@ -19,7 +19,7 @@ from tianshou.policy.optim import AdamOptimizerFactory, LRSchedulerFactoryLinear from tianshou.trainer import OnPolicyTrainerParams from tianshou.utils.net.common import ActorCritic, Net -from tianshou.utils.net.continuous import ActorProb, Critic +from tianshou.utils.net.continuous import ContinuousActorProb, Critic def get_args() -> argparse.Namespace: @@ -100,7 +100,7 @@ def main(args: argparse.Namespace = get_args()) -> None: activation=nn.Tanh, device=args.device, ) - actor = ActorProb( + actor = ContinuousActorProb( net_a, args.action_shape, unbounded=True, diff --git a/examples/mujoco/mujoco_redq.py b/examples/mujoco/mujoco_redq.py index eb6bb2dae..06a85ed45 100755 --- a/examples/mujoco/mujoco_redq.py +++ b/examples/mujoco/mujoco_redq.py @@ -18,7 +18,7 @@ from tianshou.policy.optim import AdamOptimizerFactory from tianshou.trainer import OffPolicyTrainerParams from tianshou.utils.net.common import EnsembleLinear, Net -from tianshou.utils.net.continuous import ActorProb, Critic +from tianshou.utils.net.continuous import ContinuousActorProb, Critic def get_args() -> argparse.Namespace: @@ -90,7 +90,7 @@ def main(args: argparse.Namespace = get_args()) -> None: torch.manual_seed(args.seed) # model net_a = Net(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device) - actor = ActorProb( + actor = ContinuousActorProb( net_a, args.action_shape, device=args.device, diff --git a/examples/mujoco/mujoco_reinforce.py b/examples/mujoco/mujoco_reinforce.py index 1f946b9df..14ae0b84d 100755 --- a/examples/mujoco/mujoco_reinforce.py +++ b/examples/mujoco/mujoco_reinforce.py @@ -19,7 +19,7 @@ from 
tianshou.policy.optim import AdamOptimizerFactory, LRSchedulerFactoryLinear from tianshou.trainer import OnPolicyTrainerParams from tianshou.utils.net.common import Net -from tianshou.utils.net.continuous import ActorProb +from tianshou.utils.net.continuous import ContinuousActorProb def get_args() -> argparse.Namespace: @@ -92,7 +92,7 @@ def main(args: argparse.Namespace = get_args()) -> None: activation=nn.Tanh, device=args.device, ) - actor = ActorProb( + actor = ContinuousActorProb( net_a, args.action_shape, unbounded=True, diff --git a/examples/mujoco/mujoco_sac.py b/examples/mujoco/mujoco_sac.py index 718aef429..bfc70d9b8 100755 --- a/examples/mujoco/mujoco_sac.py +++ b/examples/mujoco/mujoco_sac.py @@ -17,7 +17,7 @@ from tianshou.policy.optim import AdamOptimizerFactory from tianshou.trainer import OffPolicyTrainerParams from tianshou.utils.net.common import Net -from tianshou.utils.net.continuous import ActorProb, Critic +from tianshou.utils.net.continuous import ContinuousActorProb, Critic def get_args() -> argparse.Namespace: @@ -86,7 +86,7 @@ def main(args: argparse.Namespace = get_args()) -> None: torch.manual_seed(args.seed) # model net_a = Net(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device) - actor = ActorProb( + actor = ContinuousActorProb( net_a, args.action_shape, device=args.device, diff --git a/examples/mujoco/mujoco_td3.py b/examples/mujoco/mujoco_td3.py index 8e7b6ed5f..0f1273918 100755 --- a/examples/mujoco/mujoco_td3.py +++ b/examples/mujoco/mujoco_td3.py @@ -18,7 +18,7 @@ from tianshou.policy.optim import AdamOptimizerFactory from tianshou.trainer import OffPolicyTrainerParams from tianshou.utils.net.common import Net -from tianshou.utils.net.continuous import Actor, Critic +from tianshou.utils.net.continuous import ContinuousActorDeterministic, Critic def get_args() -> argparse.Namespace: @@ -91,7 +91,9 @@ def main(args: argparse.Namespace = get_args()) -> None: torch.manual_seed(args.seed) # model net_a = 
Net(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device) - actor = Actor(net_a, args.action_shape, max_action=args.max_action, device=args.device).to( + actor = ContinuousActorDeterministic( + net_a, args.action_shape, max_action=args.max_action, device=args.device + ).to( args.device, ) actor_optim = AdamOptimizerFactory(lr=args.actor_lr) diff --git a/examples/mujoco/mujoco_trpo.py b/examples/mujoco/mujoco_trpo.py index 4620fa00b..cfe0dfdda 100755 --- a/examples/mujoco/mujoco_trpo.py +++ b/examples/mujoco/mujoco_trpo.py @@ -19,7 +19,7 @@ from tianshou.policy.optim import AdamOptimizerFactory, LRSchedulerFactoryLinear from tianshou.trainer import OnPolicyTrainerParams from tianshou.utils.net.common import Net -from tianshou.utils.net.continuous import ActorProb, Critic +from tianshou.utils.net.continuous import ContinuousActorProb, Critic def get_args() -> argparse.Namespace: @@ -103,7 +103,7 @@ def test_trpo(args: argparse.Namespace = get_args()) -> None: activation=nn.Tanh, device=args.device, ) - actor = ActorProb( + actor = ContinuousActorProb( net_a, args.action_shape, unbounded=True, diff --git a/examples/offline/atari_bcq.py b/examples/offline/atari_bcq.py index ffa671b6e..9e6fb1ed9 100644 --- a/examples/offline/atari_bcq.py +++ b/examples/offline/atari_bcq.py @@ -21,7 +21,7 @@ from tianshou.policy.imitation.discrete_bcq import DiscreteBCQPolicy from tianshou.policy.optim import AdamOptimizerFactory from tianshou.trainer import OfflineTrainerParams -from tianshou.utils.net.discrete import Actor +from tianshou.utils.net.discrete import DiscreteActor def get_args() -> argparse.Namespace: @@ -105,14 +105,14 @@ def main(args: argparse.Namespace = get_args()) -> None: device=args.device, features_only=True, ).to(args.device) - policy_net = Actor( + policy_net = DiscreteActor( feature_net, args.action_shape, device=args.device, hidden_sizes=args.hidden_sizes, softmax_output=False, ).to(args.device) - imitation_net = Actor( + 
imitation_net = DiscreteActor( feature_net, args.action_shape, device=args.device, diff --git a/examples/offline/atari_crr.py b/examples/offline/atari_crr.py index 6972f5bb4..76accb371 100644 --- a/examples/offline/atari_crr.py +++ b/examples/offline/atari_crr.py @@ -21,7 +21,7 @@ from tianshou.policy.modelfree.pg import DiscreteActorPolicy from tianshou.policy.optim import AdamOptimizerFactory from tianshou.trainer import OfflineTrainerParams -from tianshou.utils.net.discrete import Actor, Critic +from tianshou.utils.net.discrete import Critic, DiscreteActor from tianshou.utils.space_info import SpaceInfo @@ -107,7 +107,7 @@ def main(args: argparse.Namespace = get_args()) -> None: device=args.device, features_only=True, ).to(args.device) - actor = Actor( + actor = DiscreteActor( feature_net, args.action_shape, hidden_sizes=args.hidden_sizes, diff --git a/examples/offline/d4rl_cql.py b/examples/offline/d4rl_cql.py index 1b2580626..65d987ef0 100644 --- a/examples/offline/d4rl_cql.py +++ b/examples/offline/d4rl_cql.py @@ -20,7 +20,7 @@ from tianshou.trainer import OfflineTrainerParams from tianshou.utils import TensorboardLogger, WandbLogger from tianshou.utils.net.common import Net -from tianshou.utils.net.continuous import ActorProb, Critic +from tianshou.utils.net.continuous import ContinuousActorProb, Critic from tianshou.utils.space_info import SpaceInfo @@ -251,7 +251,7 @@ def test_cql() -> None: hidden_sizes=args.hidden_sizes, device=args.device, ) - actor = ActorProb( + actor = ContinuousActorProb( net_a, action_shape=args.action_shape, device=args.device, diff --git a/examples/offline/d4rl_il.py b/examples/offline/d4rl_il.py index e94ee3765..396557d83 100644 --- a/examples/offline/d4rl_il.py +++ b/examples/offline/d4rl_il.py @@ -19,7 +19,7 @@ from tianshou.trainer import OfflineTrainerParams from tianshou.utils import TensorboardLogger, WandbLogger from tianshou.utils.net.common import Net -from tianshou.utils.net.continuous import Actor +from 
tianshou.utils.net.continuous import ContinuousActorDeterministic from tianshou.utils.space_info import SpaceInfo @@ -89,7 +89,7 @@ def test_il() -> None: hidden_sizes=args.hidden_sizes, device=args.device, ) - actor = Actor( + actor = ContinuousActorDeterministic( net, action_shape=args.action_shape, max_action=args.max_action, diff --git a/examples/offline/d4rl_td3_bc.py b/examples/offline/d4rl_td3_bc.py index 71817dc33..9d2c69cff 100644 --- a/examples/offline/d4rl_td3_bc.py +++ b/examples/offline/d4rl_td3_bc.py @@ -21,7 +21,7 @@ from tianshou.trainer import OfflineTrainerParams from tianshou.utils import TensorboardLogger, WandbLogger from tianshou.utils.net.common import Net -from tianshou.utils.net.continuous import Actor, Critic +from tianshou.utils.net.continuous import ContinuousActorDeterministic, Critic from tianshou.utils.space_info import SpaceInfo @@ -109,7 +109,7 @@ def test_td3_bc() -> None: hidden_sizes=args.hidden_sizes, device=args.device, ) - actor = Actor( + actor = ContinuousActorDeterministic( net_a, action_shape=args.action_shape, max_action=args.max_action, diff --git a/examples/vizdoom/vizdoom_ppo.py b/examples/vizdoom/vizdoom_ppo.py index 551f8009e..ea69c8c00 100644 --- a/examples/vizdoom/vizdoom_ppo.py +++ b/examples/vizdoom/vizdoom_ppo.py @@ -18,7 +18,7 @@ from tianshou.policy.modelfree.pg import ActorPolicy from tianshou.policy.optim import AdamOptimizerFactory, LRSchedulerFactoryLinear from tianshou.trainer import OnPolicyTrainerParams -from tianshou.utils.net.discrete import Actor, Critic, IntrinsicCuriosityModule +from tianshou.utils.net.discrete import Critic, DiscreteActor, IntrinsicCuriosityModule def get_args() -> argparse.Namespace: @@ -126,7 +126,7 @@ def test_ppo(args: argparse.Namespace = get_args()) -> None: features_only=True, output_dim=args.hidden_size, ) - actor = Actor(net, args.action_shape, device=args.device, softmax_output=False) + actor = DiscreteActor(net, args.action_shape, device=args.device, 
softmax_output=False) critic = Critic(net, device=args.device) optim = AdamOptimizerFactory(lr=args.lr) diff --git a/test/base/test_policy.py b/test/base/test_policy.py index a861a68f1..84286b946 100644 --- a/test/base/test_policy.py +++ b/test/base/test_policy.py @@ -10,8 +10,8 @@ from tianshou.policy.modelfree.pg import ActorPolicy from tianshou.policy.optim import AdamOptimizerFactory from tianshou.utils.net.common import Net -from tianshou.utils.net.continuous import ActorProb, Critic -from tianshou.utils.net.discrete import Actor +from tianshou.utils.net.continuous import ContinuousActorProb, Critic +from tianshou.utils.net.discrete import DiscreteActor obs_shape = (5,) @@ -31,10 +31,10 @@ def test_calculate_discounted_returns() -> None: def algorithm(request: pytest.FixtureRequest) -> PPO: action_type = request.param action_space: gym.spaces.Box | gym.spaces.Discrete - actor: Actor | ActorProb + actor: DiscreteActor | ContinuousActorProb if action_type == "continuous": action_space = gym.spaces.Box(low=-1, high=1, shape=(3,)) - actor = ActorProb( + actor = ContinuousActorProb( Net(state_shape=obs_shape, hidden_sizes=[64, 64], action_shape=action_space.shape), action_shape=action_space.shape, ) @@ -45,7 +45,7 @@ def dist_fn(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution: elif action_type == "discrete": action_space = gym.spaces.Discrete(3) - actor = Actor( + actor = DiscreteActor( Net(state_shape=obs_shape, hidden_sizes=[64, 64], action_shape=action_space.n), action_shape=action_space.n, ) diff --git a/test/continuous/test_ddpg.py b/test/continuous/test_ddpg.py index bdba823b6..3f706c21b 100644 --- a/test/continuous/test_ddpg.py +++ b/test/continuous/test_ddpg.py @@ -16,7 +16,7 @@ from tianshou.trainer.base import OffPolicyTrainerParams from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Net -from tianshou.utils.net.continuous import Actor, Critic +from tianshou.utils.net.continuous import 
ContinuousActorDeterministic, Critic from tianshou.utils.space_info import SpaceInfo @@ -74,7 +74,9 @@ def test_ddpg(args: argparse.Namespace = get_args()) -> None: # model net = Net(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device) - actor = Actor(net, args.action_shape, max_action=args.max_action, device=args.device).to( + actor = ContinuousActorDeterministic( + net, args.action_shape, max_action=args.max_action, device=args.device + ).to( args.device, ) net = Net( diff --git a/test/continuous/test_npg.py b/test/continuous/test_npg.py index b373a03ec..a42111bf2 100644 --- a/test/continuous/test_npg.py +++ b/test/continuous/test_npg.py @@ -18,7 +18,7 @@ from tianshou.trainer.base import OnPolicyTrainerParams from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Net -from tianshou.utils.net.continuous import ActorProb, Critic +from tianshou.utils.net.continuous import ContinuousActorProb, Critic from tianshou.utils.space_info import SpaceInfo @@ -84,7 +84,9 @@ def test_npg(args: argparse.Namespace = get_args()) -> None: activation=nn.Tanh, device=args.device, ) - actor = ActorProb(net, args.action_shape, unbounded=True, device=args.device).to(args.device) + actor = ContinuousActorProb(net, args.action_shape, unbounded=True, device=args.device).to( + args.device + ) critic = Critic( Net( args.state_shape, diff --git a/test/continuous/test_ppo.py b/test/continuous/test_ppo.py index 795d0b46d..46872f948 100644 --- a/test/continuous/test_ppo.py +++ b/test/continuous/test_ppo.py @@ -16,7 +16,7 @@ from tianshou.trainer.base import OnPolicyTrainerParams from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import ActorCritic, Net -from tianshou.utils.net.continuous import ActorProb, Critic +from tianshou.utils.net.continuous import ContinuousActorProb, Critic from tianshou.utils.space_info import SpaceInfo @@ -84,7 +84,9 @@ def test_ppo(args: argparse.Namespace = get_args()) -> None: # 
model net = Net(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device) - actor = ActorProb(net, args.action_shape, unbounded=True, device=args.device).to(args.device) + actor = ContinuousActorProb(net, args.action_shape, unbounded=True, device=args.device).to( + args.device + ) critic = Critic( Net(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device), device=args.device, diff --git a/test/continuous/test_redq.py b/test/continuous/test_redq.py index 6517be5c4..eaf7e2461 100644 --- a/test/continuous/test_redq.py +++ b/test/continuous/test_redq.py @@ -17,7 +17,7 @@ from tianshou.trainer.base import OffPolicyTrainerParams from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import EnsembleLinear, Net -from tianshou.utils.net.continuous import ActorProb, Critic +from tianshou.utils.net.continuous import ContinuousActorProb, Critic from tianshou.utils.space_info import SpaceInfo @@ -82,7 +82,7 @@ def test_redq(args: argparse.Namespace = get_args()) -> None: test_envs.seed(args.seed) # model net = Net(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device) - actor = ActorProb( + actor = ContinuousActorProb( net, args.action_shape, device=args.device, diff --git a/test/continuous/test_sac_with_il.py b/test/continuous/test_sac_with_il.py index d6930722f..7d8cae7fd 100644 --- a/test/continuous/test_sac_with_il.py +++ b/test/continuous/test_sac_with_il.py @@ -16,7 +16,11 @@ from tianshou.trainer.base import OffPolicyTrainerParams from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Net -from tianshou.utils.net.continuous import Actor, ActorProb, Critic +from tianshou.utils.net.continuous import ( + ContinuousActorDeterministic, + ContinuousActorProb, + Critic, +) from tianshou.utils.space_info import SpaceInfo try: @@ -86,7 +90,9 @@ def test_sac_with_il(args: argparse.Namespace = get_args()) -> None: # model net = 
Net(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device) - actor = ActorProb(net, args.action_shape, device=args.device, unbounded=True).to(args.device) + actor = ContinuousActorProb(net, args.action_shape, device=args.device, unbounded=True).to( + args.device + ) actor_optim = AdamOptimizerFactory(lr=args.actor_lr) net_c1 = Net( state_shape=args.state_shape, @@ -175,7 +181,7 @@ def stop_fn(mean_rewards: float) -> bool: hidden_sizes=args.imitation_hidden_sizes, device=args.device, ) - il_actor = Actor( + il_actor = ContinuousActorDeterministic( il_net, args.action_shape, max_action=args.max_action, diff --git a/test/continuous/test_td3.py b/test/continuous/test_td3.py index b52d488c0..0fae5a7ab 100644 --- a/test/continuous/test_td3.py +++ b/test/continuous/test_td3.py @@ -16,7 +16,7 @@ from tianshou.trainer.base import OffPolicyTrainerParams from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Net -from tianshou.utils.net.continuous import Actor, Critic +from tianshou.utils.net.continuous import ContinuousActorDeterministic, Critic from tianshou.utils.space_info import SpaceInfo @@ -77,7 +77,9 @@ def test_td3(args: argparse.Namespace = get_args()) -> None: test_envs.seed(args.seed) # model net = Net(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device) - actor = Actor(net, args.action_shape, max_action=args.max_action, device=args.device).to( + actor = ContinuousActorDeterministic( + net, args.action_shape, max_action=args.max_action, device=args.device + ).to( args.device, ) actor_optim = AdamOptimizerFactory(lr=args.actor_lr) diff --git a/test/continuous/test_trpo.py b/test/continuous/test_trpo.py index 37f34febf..eda1112e0 100644 --- a/test/continuous/test_trpo.py +++ b/test/continuous/test_trpo.py @@ -17,7 +17,7 @@ from tianshou.trainer.base import OnPolicyTrainerParams from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Net -from 
tianshou.utils.net.continuous import ActorProb, Critic +from tianshou.utils.net.continuous import ContinuousActorProb, Critic from tianshou.utils.space_info import SpaceInfo @@ -84,7 +84,9 @@ def test_trpo(args: argparse.Namespace = get_args()) -> None: activation=nn.Tanh, device=args.device, ) - actor = ActorProb(net, args.action_shape, unbounded=True, device=args.device).to(args.device) + actor = ContinuousActorProb(net, args.action_shape, unbounded=True, device=args.device).to( + args.device + ) critic = Critic( Net( args.state_shape, diff --git a/test/discrete/test_a2c_with_il.py b/test/discrete/test_a2c_with_il.py index c3f252857..3bdeb9196 100644 --- a/test/discrete/test_a2c_with_il.py +++ b/test/discrete/test_a2c_with_il.py @@ -17,7 +17,7 @@ from tianshou.trainer.base import OffPolicyTrainerParams, OnPolicyTrainerParams from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Net -from tianshou.utils.net.discrete import Actor, Critic +from tianshou.utils.net.discrete import Critic, DiscreteActor try: import envpool @@ -93,7 +93,7 @@ def test_a2c_with_il(args: argparse.Namespace = get_args()) -> None: args.reward_threshold = default_reward_threshold.get(args.task, env.spec.reward_threshold) # model net = Net(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device) - actor = Actor(net, args.action_shape, device=args.device).to(args.device) + actor = DiscreteActor(net, args.action_shape, device=args.device).to(args.device) critic = Critic(net, device=args.device).to(args.device) optim = AdamOptimizerFactory(lr=args.lr) dist = torch.distributions.Categorical @@ -158,7 +158,7 @@ def stop_fn(mean_rewards: float) -> bool: # if args.task == 'CartPole-v1': # env.spec.reward_threshold = 190 # lower the goal net = Net(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device) - actor = Actor(net, args.action_shape, device=args.device).to(args.device) + actor = DiscreteActor(net, 
args.action_shape, device=args.device).to(args.device) optim = AdamOptimizerFactory(lr=args.il_lr) il_policy = ImitationPolicy( actor=actor, diff --git a/test/discrete/test_discrete_sac.py b/test/discrete/test_discrete_sac.py index 0a2ef27d4..e7172b87a 100644 --- a/test/discrete/test_discrete_sac.py +++ b/test/discrete/test_discrete_sac.py @@ -19,7 +19,7 @@ from tianshou.trainer.base import OffPolicyTrainerParams from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Net -from tianshou.utils.net.discrete import Actor, Critic +from tianshou.utils.net.discrete import Critic, DiscreteActor from tianshou.utils.space_info import SpaceInfo @@ -81,7 +81,9 @@ def test_discrete_sac(args: argparse.Namespace = get_args()) -> None: obs_dim = space_info.observation_info.obs_dim action_dim = space_info.action_info.action_dim net = Net(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device) - actor = Actor(net, args.action_shape, softmax_output=False, device=args.device).to(args.device) + actor = DiscreteActor(net, args.action_shape, softmax_output=False, device=args.device).to( + args.device + ) actor_optim = AdamOptimizerFactory(lr=args.actor_lr) net_c1 = Net(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device) critic1 = Critic(net_c1, last_size=action_dim, device=args.device).to(args.device) diff --git a/test/discrete/test_ppo.py b/test/discrete/test_ppo.py index bb157755c..1109c9986 100644 --- a/test/discrete/test_ppo.py +++ b/test/discrete/test_ppo.py @@ -16,7 +16,7 @@ from tianshou.trainer import OnPolicyTrainerParams from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import ActorCritic, DataParallelNet, Net -from tianshou.utils.net.discrete import Actor, Critic +from tianshou.utils.net.discrete import Critic, DiscreteActor from tianshou.utils.space_info import SpaceInfo @@ -83,10 +83,12 @@ def test_ppo(args: argparse.Namespace = get_args()) -> None: actor: nn.Module 
critic: nn.Module if torch.cuda.is_available(): - actor = DataParallelNet(Actor(net, args.action_shape, device=args.device).to(args.device)) + actor = DataParallelNet( + DiscreteActor(net, args.action_shape, device=args.device).to(args.device) + ) critic = DataParallelNet(Critic(net, device=args.device).to(args.device)) else: - actor = Actor(net, args.action_shape, device=args.device).to(args.device) + actor = DiscreteActor(net, args.action_shape, device=args.device).to(args.device) critic = Critic(net, device=args.device).to(args.device) actor_critic = ActorCritic(actor, critic) # orthogonal initialization diff --git a/test/modelbased/test_ppo_icm.py b/test/modelbased/test_ppo_icm.py index b3a68d1a4..e15a0b90e 100644 --- a/test/modelbased/test_ppo_icm.py +++ b/test/modelbased/test_ppo_icm.py @@ -17,7 +17,7 @@ from tianshou.trainer import OnPolicyTrainerParams from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import MLP, ActorCritic, Net -from tianshou.utils.net.discrete import Actor, Critic, IntrinsicCuriosityModule +from tianshou.utils.net.discrete import Critic, DiscreteActor, IntrinsicCuriosityModule from tianshou.utils.space_info import SpaceInfo @@ -101,7 +101,7 @@ def test_ppo(args: argparse.Namespace = get_args()) -> None: # model net = Net(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device) - actor = Actor(net, args.action_shape, device=args.device).to(args.device) + actor = DiscreteActor(net, args.action_shape, device=args.device).to(args.device) critic = Critic(net, device=args.device).to(args.device) actor_critic = ActorCritic(actor, critic) diff --git a/test/offline/gather_pendulum_data.py b/test/offline/gather_pendulum_data.py index c1ca8cb33..c2e0bae55 100644 --- a/test/offline/gather_pendulum_data.py +++ b/test/offline/gather_pendulum_data.py @@ -16,7 +16,7 @@ from tianshou.trainer.base import OffPolicyTrainerParams from tianshou.utils import TensorboardLogger from tianshou.utils.net.common 
import Net -from tianshou.utils.net.continuous import ActorProb, Critic +from tianshou.utils.net.continuous import ContinuousActorProb, Critic from tianshou.utils.space_info import SpaceInfo @@ -93,7 +93,7 @@ def gather_data() -> VectorReplayBuffer: test_envs.seed(args.seed) # model net = Net(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device) - actor = ActorProb( + actor = ContinuousActorProb( net, args.action_shape, device=args.device, diff --git a/test/offline/test_cql.py b/test/offline/test_cql.py index ee734c560..92a4632d9 100644 --- a/test/offline/test_cql.py +++ b/test/offline/test_cql.py @@ -18,7 +18,7 @@ from tianshou.trainer import OfflineTrainerParams from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Net -from tianshou.utils.net.continuous import ActorProb, Critic +from tianshou.utils.net.continuous import ContinuousActorProb, Critic from tianshou.utils.space_info import SpaceInfo @@ -111,7 +111,7 @@ def test_cql(args: argparse.Namespace = get_args()) -> None: hidden_sizes=args.hidden_sizes, device=args.device, ) - actor = ActorProb( + actor = ContinuousActorProb( net_a, action_shape=args.action_shape, device=args.device, diff --git a/test/offline/test_discrete_bcq.py b/test/offline/test_discrete_bcq.py index 17d05bb88..937b62b2b 100644 --- a/test/offline/test_discrete_bcq.py +++ b/test/offline/test_discrete_bcq.py @@ -21,7 +21,7 @@ from tianshou.trainer import OfflineTrainerParams from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Net -from tianshou.utils.net.discrete import Actor +from tianshou.utils.net.discrete import DiscreteActor from tianshou.utils.space_info import SpaceInfo @@ -77,13 +77,13 @@ def test_discrete_bcq(args: argparse.Namespace = get_args()) -> None: # model net = Net(state_shape=args.state_shape, action_shape=args.hidden_sizes[0], device=args.device) - policy_net = Actor( + policy_net = DiscreteActor( net, args.action_shape, 
hidden_sizes=args.hidden_sizes, device=args.device, ).to(args.device) - imitation_net = Actor( + imitation_net = DiscreteActor( net, args.action_shape, hidden_sizes=args.hidden_sizes, diff --git a/test/offline/test_discrete_crr.py b/test/offline/test_discrete_crr.py index c07af56a8..38aa2fadf 100644 --- a/test/offline/test_discrete_crr.py +++ b/test/offline/test_discrete_crr.py @@ -21,7 +21,7 @@ from tianshou.trainer import OfflineTrainerParams from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Net -from tianshou.utils.net.discrete import Actor, Critic +from tianshou.utils.net.discrete import Critic, DiscreteActor from tianshou.utils.space_info import SpaceInfo @@ -72,7 +72,7 @@ def test_discrete_crr(args: argparse.Namespace = get_args()) -> None: # model and algorithm net = Net(state_shape=args.state_shape, action_shape=args.hidden_sizes[0], device=args.device) - actor = Actor( + actor = DiscreteActor( preprocess_net=net, action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, diff --git a/test/offline/test_gail.py b/test/offline/test_gail.py index 7a6940778..a94d869e7 100644 --- a/test/offline/test_gail.py +++ b/test/offline/test_gail.py @@ -17,7 +17,7 @@ from tianshou.trainer import OnPolicyTrainerParams from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import ActorCritic, Net -from tianshou.utils.net.continuous import ActorProb, Critic +from tianshou.utils.net.continuous import ContinuousActorProb, Critic from tianshou.utils.space_info import SpaceInfo @@ -95,7 +95,7 @@ def test_gail(args: argparse.Namespace = get_args()) -> None: test_envs.seed(args.seed) # model net = Net(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device) - actor = ActorProb( + actor = ContinuousActorProb( preprocess_net=net, action_shape=args.action_shape, max_action=args.max_action, diff --git a/test/offline/test_td3_bc.py b/test/offline/test_td3_bc.py index 641203122..baa101805 100644 --- 
a/test/offline/test_td3_bc.py +++ b/test/offline/test_td3_bc.py @@ -19,7 +19,7 @@ from tianshou.trainer import OfflineTrainerParams from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Net -from tianshou.utils.net.continuous import Actor, Critic +from tianshou.utils.net.continuous import ContinuousActorDeterministic, Critic from tianshou.utils.space_info import SpaceInfo @@ -100,7 +100,7 @@ def test_td3_bc(args: argparse.Namespace = get_args()) -> None: hidden_sizes=args.hidden_sizes, device=args.device, ) - actor = Actor( + actor = ContinuousActorDeterministic( net_a, action_shape=args.action_shape, max_action=args.max_action, diff --git a/test/pettingzoo/pistonball_continuous.py b/test/pettingzoo/pistonball_continuous.py index dc2ffdce8..21a9943a0 100644 --- a/test/pettingzoo/pistonball_continuous.py +++ b/test/pettingzoo/pistonball_continuous.py @@ -21,7 +21,7 @@ from tianshou.policy.optim import AdamOptimizerFactory from tianshou.trainer import OnPolicyTrainerParams from tianshou.utils import TensorboardLogger -from tianshou.utils.net.continuous import ActorProb, Critic +from tianshou.utils.net.continuous import ContinuousActorProb, Critic class DQNet(nn.Module): @@ -166,7 +166,7 @@ def get_agents( device=args.device, ).to(args.device) - actor = ActorProb( + actor = ContinuousActorProb( net, args.action_shape, max_action=args.max_action, diff --git a/tianshou/env/atari/atari_network.py b/tianshou/env/atari/atari_network.py index 3b1d90d64..20965af7e 100644 --- a/tianshou/env/atari/atari_network.py +++ b/tianshou/env/atari/atari_network.py @@ -17,7 +17,7 @@ from tianshou.highlevel.params.dist_fn import DistributionFunctionFactoryCategorical from tianshou.policy.modelfree.pg import TDistFnDiscrOrCont from tianshou.utils.net.common import NetBase -from tianshou.utils.net.discrete import Actor, NoisyLinear +from tianshou.utils.net.discrete import DiscreteActor, NoisyLinear def layer_init(layer: nn.Module, std: float = np.sqrt(2), 
bias_const: float = 0.0) -> nn.Module: @@ -262,7 +262,7 @@ def __init__( self.scale_obs = scale_obs self.features_only = features_only - def create_module(self, envs: Environments, device: TDevice) -> Actor: + def create_module(self, envs: Environments, device: TDevice) -> DiscreteActor: c, h, w = envs.get_observation_shape() # type: ignore # only right shape is a sequence of length 3 action_shape = envs.get_action_shape() if isinstance(action_shape, np.int64): @@ -280,7 +280,7 @@ def create_module(self, envs: Environments, device: TDevice) -> Actor: ) if self.scale_obs: net = scale_obs(net) - return Actor( + return DiscreteActor( net, envs.get_action_shape(), device=device, diff --git a/tianshou/highlevel/algorithm.py b/tianshou/highlevel/algorithm.py index 1dafb0ad4..8d785b7a8 100644 --- a/tianshou/highlevel/algorithm.py +++ b/tianshou/highlevel/algorithm.py @@ -75,7 +75,7 @@ from tianshou.policy.modelfree.sac import SACPolicy from tianshou.trainer import OffPolicyTrainer, OnPolicyTrainer, Trainer from tianshou.trainer.base import OffPolicyTrainerParams, OnPolicyTrainerParams -from tianshou.utils.net.discrete import Actor +from tianshou.utils.net.discrete import DiscreteActor CHECKPOINT_DICT_KEY_MODEL = "model" CHECKPOINT_DICT_KEY_OBS_RMS = "obs_rms" @@ -619,7 +619,7 @@ def _get_critic_use_action(envs: Environments) -> bool: @abstractmethod def _create_policy( - self, actor: torch.nn.Module | Actor, envs: Environments, params: dict + self, actor: torch.nn.Module | DiscreteActor, envs: Environments, params: dict ) -> TPolicy: pass @@ -659,7 +659,7 @@ def _create_algorithm(self, envs: Environments, device: TDevice) -> TAlgorithm: class SACAlgorithmFactory(ActorDualCriticsOffPolicyAlgorithmFactory[SACParams, SAC, SACPolicy]): def _create_policy( - self, actor: torch.nn.Module | Actor, envs: Environments, params: dict + self, actor: torch.nn.Module | DiscreteActor, envs: Environments, params: dict ) -> SACPolicy: return self._create_policy_from_args( SACPolicy, @@ 
-678,7 +678,7 @@ class DiscreteSACAlgorithmFactory( ActorDualCriticsOffPolicyAlgorithmFactory[DiscreteSACParams, DiscreteSAC, DiscreteSACPolicy] ): def _create_policy( - self, actor: torch.nn.Module | Actor, envs: Environments, params: dict + self, actor: torch.nn.Module | DiscreteActor, envs: Environments, params: dict ) -> DiscreteSACPolicy: return self._create_policy_from_args( DiscreteSACPolicy, @@ -695,7 +695,7 @@ def _get_algorithm_class(self) -> type[DiscreteSAC]: class TD3AlgorithmFactory(ActorDualCriticsOffPolicyAlgorithmFactory[TD3Params, TD3, DDPGPolicy]): def _create_policy( - self, actor: torch.nn.Module | Actor, envs: Environments, params: dict + self, actor: torch.nn.Module | DiscreteActor, envs: Environments, params: dict ) -> DDPGPolicy: return self._create_policy_from_args( DDPGPolicy, diff --git a/tianshou/highlevel/module/actor.py b/tianshou/highlevel/module/actor.py index ca73dc2a4..1b8de1150 100644 --- a/tianshou/highlevel/module/actor.py +++ b/tianshou/highlevel/module/actor.py @@ -24,7 +24,7 @@ ) from tianshou.policy.modelfree.pg import TDistFnDiscrOrCont from tianshou.utils.net import continuous, discrete -from tianshou.utils.net.common import BaseActor, ModuleType, Net +from tianshou.utils.net.common import Actor, ModuleType, Net class ContinuousActorType(Enum): @@ -37,7 +37,7 @@ class ContinuousActorType(Enum): class ActorFuture: """Container, which, in the future, will hold an actor instance.""" - actor: BaseActor | nn.Module | None = None + actor: Actor | nn.Module | None = None class ActorFutureProviderProtocol(Protocol): @@ -47,7 +47,7 @@ def get_actor_future(self) -> ActorFuture: class ActorFactory(ModuleFactory, ToStringMixin, ABC): @abstractmethod - def create_module(self, envs: Environments, device: TDevice) -> BaseActor | nn.Module: + def create_module(self, envs: Environments, device: TDevice) -> Actor | nn.Module: pass @abstractmethod @@ -127,7 +127,7 @@ def _create_factory(self, envs: Environments) -> ActorFactory: raise 
ValueError(f"{env_type} not supported") return factory - def create_module(self, envs: Environments, device: TDevice) -> BaseActor | nn.Module: + def create_module(self, envs: Environments, device: TDevice) -> Actor | nn.Module: factory = self._create_factory(envs) return factory.create_module(envs, device) @@ -145,14 +145,14 @@ def __init__(self, hidden_sizes: Sequence[int], activation: ModuleType = nn.ReLU self.hidden_sizes = hidden_sizes self.activation = activation - def create_module(self, envs: Environments, device: TDevice) -> BaseActor: + def create_module(self, envs: Environments, device: TDevice) -> Actor: net_a = Net( state_shape=envs.get_observation_shape(), hidden_sizes=self.hidden_sizes, activation=self.activation, device=device, ) - return continuous.Actor( + return continuous.ContinuousActorDeterministic( preprocess_net=net_a, action_shape=envs.get_action_shape(), hidden_sizes=(), @@ -183,14 +183,14 @@ def __init__( self.conditioned_sigma = conditioned_sigma self.activation = activation - def create_module(self, envs: Environments, device: TDevice) -> BaseActor: + def create_module(self, envs: Environments, device: TDevice) -> Actor: net_a = Net( state_shape=envs.get_observation_shape(), hidden_sizes=self.hidden_sizes, activation=self.activation, device=device, ) - actor = continuous.ActorProb( + actor = continuous.ContinuousActorProb( preprocess_net=net_a, action_shape=envs.get_action_shape(), unbounded=self.unbounded, @@ -220,14 +220,14 @@ def __init__( self.softmax_output = softmax_output self.activation = activation - def create_module(self, envs: Environments, device: TDevice) -> BaseActor: + def create_module(self, envs: Environments, device: TDevice) -> Actor: net_a = Net( state_shape=envs.get_observation_shape(), hidden_sizes=self.hidden_sizes, activation=self.activation, device=device, ) - return discrete.Actor( + return discrete.DiscreteActor( net_a, envs.get_action_shape(), hidden_sizes=(), @@ -260,7 +260,7 @@ def __setstate__(self, 
state: dict) -> None: def _tostring_excludes(self) -> list[str]: return [*super()._tostring_excludes(), "_actor_future"] - def create_module(self, envs: Environments, device: TDevice) -> BaseActor | nn.Module: + def create_module(self, envs: Environments, device: TDevice) -> Actor | nn.Module: module = self.actor_factory.create_module(envs, device) self._actor_future.actor = module return module @@ -275,5 +275,5 @@ def __init__(self, actor_factory: ActorFactory): def create_intermediate_module(self, envs: Environments, device: TDevice) -> IntermediateModule: actor = self.actor_factory.create_module(envs, device) - assert isinstance(actor, BaseActor) + assert isinstance(actor, Actor) return IntermediateModule(actor, actor.get_output_dim()) diff --git a/tianshou/highlevel/module/critic.py b/tianshou/highlevel/module/critic.py index 35f5f9483..c5f6e3438 100644 --- a/tianshou/highlevel/module/critic.py +++ b/tianshou/highlevel/module/critic.py @@ -9,7 +9,7 @@ from tianshou.highlevel.module.actor import ActorFuture from tianshou.highlevel.module.core import TDevice, init_linear_orthogonal from tianshou.utils.net import continuous, discrete -from tianshou.utils.net.common import BaseActor, EnsembleLinear, ModuleType, Net +from tianshou.utils.net.common import Actor, EnsembleLinear, ModuleType, Net class CriticFactory(ToStringMixin, ABC): @@ -158,9 +158,9 @@ def create_module( discrete_last_size_use_action_shape: bool = False, ) -> nn.Module: actor = self.actor_future.actor - if not isinstance(actor, BaseActor): + if not isinstance(actor, Actor): raise ValueError( - f"Option critic_use_action can only be used if actor is of type {BaseActor.__class__.__name__}", + f"Option critic_use_action can only be used if actor is of type {Actor.__class__.__name__}", ) if envs.get_type().is_discrete(): # TODO get rid of this prod pattern here and elsewhere diff --git a/tianshou/policy/modelfree/ddpg.py b/tianshou/policy/modelfree/ddpg.py index b24a159bf..30a293a3f 100644 --- 
a/tianshou/policy/modelfree/ddpg.py +++ b/tianshou/policy/modelfree/ddpg.py @@ -28,7 +28,7 @@ TTrainingStats, ) from tianshou.policy.optim import OptimizerFactory -from tianshou.utils.net.continuous import Actor, Critic +from tianshou.utils.net.continuous import ContinuousActorDeterministic, Critic mark_used(ActBatchProtocol) @@ -93,7 +93,7 @@ class DDPGPolicy(ContinuousPolicyWithExplorationNoise): def __init__( self, *, - actor: torch.nn.Module | Actor, + actor: torch.nn.Module | ContinuousActorDeterministic, exploration_noise: BaseNoise | Literal["default"] | None = None, action_space: gym.Space, observation_space: gym.Space | None = None, diff --git a/tianshou/policy/modelfree/pg.py b/tianshou/policy/modelfree/pg.py index 9c2f18d59..2c2a5f7e7 100644 --- a/tianshou/policy/modelfree/pg.py +++ b/tianshou/policy/modelfree/pg.py @@ -29,8 +29,8 @@ ) from tianshou.policy.optim import OptimizerFactory from tianshou.utils import RunningMeanStd -from tianshou.utils.net.continuous import ActorProb -from tianshou.utils.net.discrete import Actor, dist_fn_categorical_from_logits +from tianshou.utils.net.continuous import ContinuousActorProb +from tianshou.utils.net.discrete import DiscreteActor, dist_fn_categorical_from_logits # Dimension Naming Convention # B - Batch Size @@ -59,7 +59,7 @@ class ActorPolicy(Policy): def __init__( self, *, - actor: torch.nn.Module | ActorProb | Actor, + actor: torch.nn.Module | ContinuousActorProb | DiscreteActor, dist_fn: TDistFnDiscrOrCont, deterministic_eval: bool = False, action_space: gym.Space, @@ -145,7 +145,7 @@ class DiscreteActorPolicy(ActorPolicy): def __init__( self, *, - actor: torch.nn.Module | Actor, + actor: torch.nn.Module | DiscreteActor, dist_fn: TDistFnDiscrete = dist_fn_categorical_from_logits, deterministic_eval: bool = False, action_space: gym.Space, diff --git a/tianshou/policy/modelfree/redq.py b/tianshou/policy/modelfree/redq.py index f62b7c818..401026e41 100644 --- a/tianshou/policy/modelfree/redq.py +++ 
b/tianshou/policy/modelfree/redq.py @@ -20,7 +20,7 @@ ) from tianshou.policy.modelfree.sac import Alpha from tianshou.policy.optim import OptimizerFactory -from tianshou.utils.net.continuous import ActorProb +from tianshou.utils.net.continuous import ContinuousActorProb @dataclass @@ -38,7 +38,7 @@ class REDQPolicy(ContinuousPolicyWithExplorationNoise): def __init__( self, *, - actor: torch.nn.Module | ActorProb, + actor: torch.nn.Module | ContinuousActorProb, exploration_noise: BaseNoise | Literal["default"] | None = None, action_space: gym.spaces.Space, deterministic_eval: bool = True, diff --git a/tianshou/policy/modelfree/sac.py b/tianshou/policy/modelfree/sac.py index 58bb40f21..5d95529de 100644 --- a/tianshou/policy/modelfree/sac.py +++ b/tianshou/policy/modelfree/sac.py @@ -20,7 +20,7 @@ from tianshou.policy.modelfree.td3 import ActorDualCriticsOffPolicyAlgorithm from tianshou.policy.optim import OptimizerFactory from tianshou.utils.conversion import to_optional_float -from tianshou.utils.net.continuous import ActorProb +from tianshou.utils.net.continuous import ContinuousActorProb def correct_log_prob_gaussian_tanh( @@ -56,7 +56,7 @@ class SACPolicy(ContinuousPolicyWithExplorationNoise): def __init__( self, *, - actor: torch.nn.Module | ActorProb, + actor: torch.nn.Module | ContinuousActorProb, exploration_noise: BaseNoise | Literal["default"] | None = None, deterministic_eval: bool = True, action_scaling: bool = True, diff --git a/tianshou/utils/net/common.py b/tianshou/utils/net/common.py index 243a04093..9bf7125f0 100644 --- a/tianshou/utils/net/common.py +++ b/tianshou/utils/net/common.py @@ -613,7 +613,7 @@ def forward(self, obs: np.ndarray | torch.Tensor, *args, **kwargs) -> Any: return decorator_fn, new_state_shape -class BaseActor(nn.Module, ABC): +class Actor(nn.Module, ABC): @abstractmethod def get_preprocess_net(self) -> nn.Module: pass @@ -634,7 +634,7 @@ def forward( pass -class RandomActor(BaseActor): +class RandomActor(Actor): """An actor 
that returns random actions. For continuous action spaces, forward returns a batch of random actions sampled from the action space. diff --git a/tianshou/utils/net/continuous.py b/tianshou/utils/net/continuous.py index a0b85ede2..60eb4a1ac 100644 --- a/tianshou/utils/net/continuous.py +++ b/tianshou/utils/net/continuous.py @@ -10,7 +10,7 @@ from tianshou.utils.net.common import ( MLP, - BaseActor, + Actor, Net, TActionShape, TLinearLayer, @@ -21,7 +21,7 @@ SIGMA_MAX = 2 -class Actor(BaseActor): +class ContinuousActorDeterministic(Actor): """Simple actor network that directly outputs actions for continuous action space. Used primarily in DDPG and its variants. For probabilistic policies, see :class:`~ActorProb`. @@ -178,7 +178,7 @@ def forward( return self.last(obs) -class ActorProb(BaseActor): +class ContinuousActorProb(Actor): """Simple actor network that outputs `mu` and `sigma` to be used as input for a `dist_fn` (typically, a Gaussian). Used primarily in SAC, PPO and variants thereof. For deterministic policies, see :class:`~Actor`. diff --git a/tianshou/utils/net/discrete.py b/tianshou/utils/net/discrete.py index e8b596f65..81b70a53c 100644 --- a/tianshou/utils/net/discrete.py +++ b/tianshou/utils/net/discrete.py @@ -7,7 +7,7 @@ from torch import nn from tianshou.data import Batch, to_torch -from tianshou.utils.net.common import MLP, BaseActor, Net, TActionShape, get_output_dim +from tianshou.utils.net.common import MLP, Actor, Net, TActionShape, get_output_dim def dist_fn_categorical_from_logits(logits: torch.Tensor) -> torch.distributions.Categorical: @@ -15,8 +15,7 @@ def dist_fn_categorical_from_logits(logits: torch.Tensor) -> torch.distributions return torch.distributions.Categorical(logits=logits) -# TODO rename to DiscreteActor? -class Actor(BaseActor): +class DiscreteActor(Actor): """Simple actor network for discrete action spaces. :param preprocess_net: a self-defined preprocess_net. 
Typically, an instance of From dbd472f87b4d231369e4c22421287cdc1014ac97 Mon Sep 17 00:00:00 2001 From: Dominik Jain Date: Tue, 18 Mar 2025 13:44:59 +0100 Subject: [PATCH 080/230] v2: Rename Critic classes to improve clarity * continuous.Critic -> ContinuousCritic * discrete.Critic -> DiscreteCritic --- CHANGELOG.md | 3 +++ examples/atari/atari_ppo.py | 8 ++++++-- examples/atari/atari_sac.py | 10 +++++++--- examples/box2d/bipedal_hardcore_sac.py | 6 +++--- examples/box2d/mcc_sac.py | 6 +++--- examples/inverse/irl_gail.py | 6 +++--- examples/mujoco/fetch_her_ddpg.py | 4 ++-- examples/mujoco/mujoco_a2c.py | 4 ++-- examples/mujoco/mujoco_ddpg.py | 4 ++-- examples/mujoco/mujoco_npg.py | 4 ++-- examples/mujoco/mujoco_ppo.py | 4 ++-- examples/mujoco/mujoco_redq.py | 4 ++-- examples/mujoco/mujoco_sac.py | 6 +++--- examples/mujoco/mujoco_td3.py | 6 +++--- examples/mujoco/mujoco_trpo.py | 4 ++-- examples/offline/atari_crr.py | 4 ++-- examples/offline/d4rl_bcq.py | 6 +++--- examples/offline/d4rl_cql.py | 6 +++--- examples/offline/d4rl_td3_bc.py | 6 +++--- examples/vizdoom/vizdoom_ppo.py | 8 ++++++-- test/base/test_policy.py | 4 ++-- test/continuous/test_ddpg.py | 4 ++-- test/continuous/test_npg.py | 4 ++-- test/continuous/test_ppo.py | 4 ++-- test/continuous/test_redq.py | 6 ++++-- test/continuous/test_sac_with_il.py | 6 +++--- test/continuous/test_td3.py | 6 +++--- test/continuous/test_trpo.py | 4 ++-- test/discrete/test_a2c_with_il.py | 4 ++-- test/discrete/test_discrete_sac.py | 6 +++--- test/discrete/test_ppo.py | 6 +++--- test/modelbased/test_ppo_icm.py | 8 ++++++-- test/offline/gather_pendulum_data.py | 4 ++-- test/offline/test_bcq.py | 4 ++-- test/offline/test_cql.py | 4 ++-- test/offline/test_discrete_crr.py | 4 ++-- test/offline/test_gail.py | 6 +++--- test/offline/test_td3_bc.py | 6 +++--- test/pettingzoo/pistonball_continuous.py | 4 ++-- tianshou/highlevel/module/critic.py | 10 +++++----- tianshou/policy/imitation/discrete_crr.py | 4 ++-- 
tianshou/policy/imitation/gail.py | 6 +++--- tianshou/policy/modelfree/a2c.py | 8 ++++---- tianshou/policy/modelfree/ddpg.py | 4 ++-- tianshou/policy/modelfree/discrete_sac.py | 6 +++--- tianshou/policy/modelfree/npg.py | 6 +++--- tianshou/policy/modelfree/ppo.py | 6 +++--- tianshou/policy/modelfree/trpo.py | 6 +++--- tianshou/utils/net/continuous.py | 6 +++--- tianshou/utils/net/discrete.py | 4 ++-- 50 files changed, 145 insertions(+), 124 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index da546864b..35ebbc681 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -140,6 +140,9 @@ * `continuous.ActorProb` -> `ContinuousActorProb` * `coninuous.Actor` -> `ContinuousActorDeterministic` * `discrete.Actor` -> `DiscreteActor` + * The `Critic` classes have been renamed for clarity: + * `continuous.Critic` -> `ContinuousCritic` + * `discrete.Critic` -> `DiscreteCritic` * Moved Atari helper modules `atari_network` and `atari_wrapper` to the library under `tianshou.env.atari`. ## Unreleased diff --git a/examples/atari/atari_ppo.py b/examples/atari/atari_ppo.py index 6214d76a4..9c65ad382 100644 --- a/examples/atari/atari_ppo.py +++ b/examples/atari/atari_ppo.py @@ -17,7 +17,11 @@ from tianshou.policy.modelfree.pg import DiscreteActorPolicy from tianshou.policy.optim import AdamOptimizerFactory, LRSchedulerFactoryLinear from tianshou.trainer import OnPolicyTrainerParams -from tianshou.utils.net.discrete import Critic, DiscreteActor, IntrinsicCuriosityModule +from tianshou.utils.net.discrete import ( + DiscreteActor, + DiscreteCritic, + IntrinsicCuriosityModule, +) def get_args() -> argparse.Namespace: @@ -121,7 +125,7 @@ def main(args: argparse.Namespace = get_args()) -> None: if args.scale_obs: net = scale_obs(net) actor = DiscreteActor(net, args.action_shape, device=args.device, softmax_output=False) - critic = Critic(net, device=args.device) + critic = DiscreteCritic(net, device=args.device) optim = AdamOptimizerFactory(lr=args.lr, eps=1e-5) if args.lr_decay: diff --git 
a/examples/atari/atari_sac.py b/examples/atari/atari_sac.py index 8de0c0870..0ca879d81 100644 --- a/examples/atari/atari_sac.py +++ b/examples/atari/atari_sac.py @@ -17,7 +17,11 @@ from tianshou.policy.modelfree.sac import AutoAlpha from tianshou.policy.optim import AdamOptimizerFactory from tianshou.trainer import OffPolicyTrainerParams -from tianshou.utils.net.discrete import Critic, DiscreteActor, IntrinsicCuriosityModule +from tianshou.utils.net.discrete import ( + DiscreteActor, + DiscreteCritic, + IntrinsicCuriosityModule, +) def get_args() -> argparse.Namespace: @@ -118,9 +122,9 @@ def test_discrete_sac(args: argparse.Namespace = get_args()) -> None: ) actor = DiscreteActor(net, args.action_shape, device=args.device, softmax_output=False) actor_optim = AdamOptimizerFactory(lr=args.actor_lr) - critic1 = Critic(net, last_size=args.action_shape, device=args.device) + critic1 = DiscreteCritic(net, last_size=args.action_shape, device=args.device) critic1_optim = AdamOptimizerFactory(lr=args.critic_lr) - critic2 = Critic(net, last_size=args.action_shape, device=args.device) + critic2 = DiscreteCritic(net, last_size=args.action_shape, device=args.device) critic2_optim = AdamOptimizerFactory(lr=args.critic_lr) # define policy and algorithm diff --git a/examples/box2d/bipedal_hardcore_sac.py b/examples/box2d/bipedal_hardcore_sac.py index 44a6d36dd..cbb154d6b 100644 --- a/examples/box2d/bipedal_hardcore_sac.py +++ b/examples/box2d/bipedal_hardcore_sac.py @@ -18,7 +18,7 @@ from tianshou.trainer import OffPolicyTrainerParams from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Net -from tianshou.utils.net.continuous import ContinuousActorProb, Critic +from tianshou.utils.net.continuous import ContinuousActorProb, ContinuousCritic from tianshou.utils.space_info import SpaceInfo @@ -126,7 +126,7 @@ def test_sac_bipedal(args: argparse.Namespace = get_args()) -> None: concat=True, device=args.device, ) - critic1 = Critic(net_c1, 
device=args.device).to(args.device) + critic1 = ContinuousCritic(net_c1, device=args.device).to(args.device) critic1_optim = AdamOptimizerFactory(lr=args.critic_lr) net_c2 = Net( @@ -136,7 +136,7 @@ def test_sac_bipedal(args: argparse.Namespace = get_args()) -> None: concat=True, device=args.device, ) - critic2 = Critic(net_c2, device=args.device).to(args.device) + critic2 = ContinuousCritic(net_c2, device=args.device).to(args.device) critic2_optim = AdamOptimizerFactory(lr=args.critic_lr) action_dim = space_info.action_info.action_dim diff --git a/examples/box2d/mcc_sac.py b/examples/box2d/mcc_sac.py index 13d833035..0ecfa4ed8 100644 --- a/examples/box2d/mcc_sac.py +++ b/examples/box2d/mcc_sac.py @@ -17,7 +17,7 @@ from tianshou.trainer import OffPolicyTrainerParams from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Net -from tianshou.utils.net.continuous import ContinuousActorProb, Critic +from tianshou.utils.net.continuous import ContinuousActorProb, ContinuousCritic from tianshou.utils.space_info import SpaceInfo @@ -80,7 +80,7 @@ def test_sac(args: argparse.Namespace = get_args()) -> None: concat=True, device=args.device, ) - critic1 = Critic(net_c1, device=args.device).to(args.device) + critic1 = ContinuousCritic(net_c1, device=args.device).to(args.device) critic1_optim = AdamOptimizerFactory(lr=args.critic_lr) net_c2 = Net( state_shape=args.state_shape, @@ -89,7 +89,7 @@ def test_sac(args: argparse.Namespace = get_args()) -> None: concat=True, device=args.device, ) - critic2 = Critic(net_c2, device=args.device).to(args.device) + critic2 = ContinuousCritic(net_c2, device=args.device).to(args.device) critic2_optim = AdamOptimizerFactory(lr=args.critic_lr) action_dim = space_info.action_info.action_dim diff --git a/examples/inverse/irl_gail.py b/examples/inverse/irl_gail.py index f41fa070c..feec53c44 100644 --- a/examples/inverse/irl_gail.py +++ b/examples/inverse/irl_gail.py @@ -30,7 +30,7 @@ from tianshou.trainer import 
OnPolicyTrainerParams from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Net -from tianshou.utils.net.continuous import ContinuousActorProb, Critic +from tianshou.utils.net.continuous import ContinuousActorProb, ContinuousCritic from tianshou.utils.space_info import SpaceInfo @@ -140,7 +140,7 @@ def test_gail(args: argparse.Namespace = get_args()) -> None: activation=nn.Tanh, device=args.device, ) - critic = Critic(net_c, device=args.device).to(args.device) + critic = ContinuousCritic(net_c, device=args.device).to(args.device) torch.nn.init.constant_(actor.sigma_param, -0.5) for m in list(actor.modules()) + list(critic.modules()): if isinstance(m, torch.nn.Linear): @@ -165,7 +165,7 @@ def test_gail(args: argparse.Namespace = get_args()) -> None: device=args.device, concat=True, ) - disc_net = Critic(net_d, device=args.device).to(args.device) + disc_net = ContinuousCritic(net_d, device=args.device).to(args.device) for m in disc_net.modules(): if isinstance(m, torch.nn.Linear): # orthogonal initialization diff --git a/examples/mujoco/fetch_her_ddpg.py b/examples/mujoco/fetch_her_ddpg.py index 1e0e31c50..fa0c04713 100644 --- a/examples/mujoco/fetch_her_ddpg.py +++ b/examples/mujoco/fetch_her_ddpg.py @@ -28,7 +28,7 @@ from tianshou.policy.optim import AdamOptimizerFactory from tianshou.trainer import OffPolicyTrainerParams from tianshou.utils.net.common import Net, get_dict_state_decorator -from tianshou.utils.net.continuous import ContinuousActorDeterministic, Critic +from tianshou.utils.net.continuous import ContinuousActorDeterministic, ContinuousCritic from tianshou.utils.space_info import ActionSpaceInfo @@ -168,7 +168,7 @@ def test_ddpg(args: argparse.Namespace = get_args()) -> None: concat=True, device=args.device, ) - critic = dict_state_dec(Critic)(net_c, device=args.device).to(args.device) + critic = dict_state_dec(ContinuousCritic)(net_c, device=args.device).to(args.device) critic_optim = 
AdamOptimizerFactory(lr=args.critic_lr) policy = DDPGPolicy( actor=actor, diff --git a/examples/mujoco/mujoco_a2c.py b/examples/mujoco/mujoco_a2c.py index 27952fef8..ee5f4f5d0 100755 --- a/examples/mujoco/mujoco_a2c.py +++ b/examples/mujoco/mujoco_a2c.py @@ -19,7 +19,7 @@ from tianshou.policy.optim import LRSchedulerFactoryLinear, RMSpropOptimizerFactory from tianshou.trainer import OnPolicyTrainerParams from tianshou.utils.net.common import ActorCritic, Net -from tianshou.utils.net.continuous import ContinuousActorProb, Critic +from tianshou.utils.net.continuous import ContinuousActorProb, ContinuousCritic def get_args() -> argparse.Namespace: @@ -107,7 +107,7 @@ def main(args: argparse.Namespace = get_args()) -> None: activation=nn.Tanh, device=args.device, ) - critic = Critic(net_c, device=args.device).to(args.device) + critic = ContinuousCritic(net_c, device=args.device).to(args.device) actor_critic = ActorCritic(actor, critic) torch.nn.init.constant_(actor.sigma_param, -0.5) diff --git a/examples/mujoco/mujoco_ddpg.py b/examples/mujoco/mujoco_ddpg.py index 283a9fc9c..e201c435f 100755 --- a/examples/mujoco/mujoco_ddpg.py +++ b/examples/mujoco/mujoco_ddpg.py @@ -18,7 +18,7 @@ from tianshou.policy.optim import AdamOptimizerFactory from tianshou.trainer import OffPolicyTrainerParams from tianshou.utils.net.common import Net -from tianshou.utils.net.continuous import ContinuousActorDeterministic, Critic +from tianshou.utils.net.continuous import ContinuousActorDeterministic, ContinuousCritic def get_args() -> argparse.Namespace: @@ -99,7 +99,7 @@ def main(args: argparse.Namespace = get_args()) -> None: concat=True, device=args.device, ) - critic = Critic(net_c, device=args.device).to(args.device) + critic = ContinuousCritic(net_c, device=args.device).to(args.device) critic_optim = AdamOptimizerFactory(lr=args.critic_lr) policy = DDPGPolicy( actor=actor, diff --git a/examples/mujoco/mujoco_npg.py b/examples/mujoco/mujoco_npg.py index 0933bbdaa..85e74a1d8 100755 --- 
a/examples/mujoco/mujoco_npg.py +++ b/examples/mujoco/mujoco_npg.py @@ -19,7 +19,7 @@ from tianshou.policy.optim import AdamOptimizerFactory, LRSchedulerFactoryLinear from tianshou.trainer import OnPolicyTrainerParams from tianshou.utils.net.common import Net -from tianshou.utils.net.continuous import ContinuousActorProb, Critic +from tianshou.utils.net.continuous import ContinuousActorProb, ContinuousCritic def get_args() -> argparse.Namespace: @@ -112,7 +112,7 @@ def main(args: argparse.Namespace = get_args()) -> None: activation=nn.Tanh, device=args.device, ) - critic = Critic(net_c, device=args.device).to(args.device) + critic = ContinuousCritic(net_c, device=args.device).to(args.device) torch.nn.init.constant_(actor.sigma_param, -0.5) for m in list(actor.modules()) + list(critic.modules()): if isinstance(m, torch.nn.Linear): diff --git a/examples/mujoco/mujoco_ppo.py b/examples/mujoco/mujoco_ppo.py index 1c3d0f997..2b2fb550b 100755 --- a/examples/mujoco/mujoco_ppo.py +++ b/examples/mujoco/mujoco_ppo.py @@ -19,7 +19,7 @@ from tianshou.policy.optim import AdamOptimizerFactory, LRSchedulerFactoryLinear from tianshou.trainer import OnPolicyTrainerParams from tianshou.utils.net.common import ActorCritic, Net -from tianshou.utils.net.continuous import ContinuousActorProb, Critic +from tianshou.utils.net.continuous import ContinuousActorProb, ContinuousCritic def get_args() -> argparse.Namespace: @@ -112,7 +112,7 @@ def main(args: argparse.Namespace = get_args()) -> None: activation=nn.Tanh, device=args.device, ) - critic = Critic(net_c, device=args.device).to(args.device) + critic = ContinuousCritic(net_c, device=args.device).to(args.device) actor_critic = ActorCritic(actor, critic) torch.nn.init.constant_(actor.sigma_param, -0.5) diff --git a/examples/mujoco/mujoco_redq.py b/examples/mujoco/mujoco_redq.py index 06a85ed45..b3965de47 100755 --- a/examples/mujoco/mujoco_redq.py +++ b/examples/mujoco/mujoco_redq.py @@ -18,7 +18,7 @@ from tianshou.policy.optim import 
AdamOptimizerFactory from tianshou.trainer import OffPolicyTrainerParams from tianshou.utils.net.common import EnsembleLinear, Net -from tianshou.utils.net.continuous import ContinuousActorProb, Critic +from tianshou.utils.net.continuous import ContinuousActorProb, ContinuousCritic def get_args() -> argparse.Namespace: @@ -110,7 +110,7 @@ def linear(x: int, y: int) -> EnsembleLinear: device=args.device, linear_layer=linear, ) - critics = Critic( + critics = ContinuousCritic( net_c, device=args.device, linear_layer=linear, diff --git a/examples/mujoco/mujoco_sac.py b/examples/mujoco/mujoco_sac.py index bfc70d9b8..ef881b4ee 100755 --- a/examples/mujoco/mujoco_sac.py +++ b/examples/mujoco/mujoco_sac.py @@ -17,7 +17,7 @@ from tianshou.policy.optim import AdamOptimizerFactory from tianshou.trainer import OffPolicyTrainerParams from tianshou.utils.net.common import Net -from tianshou.utils.net.continuous import ContinuousActorProb, Critic +from tianshou.utils.net.continuous import ContinuousActorProb, ContinuousCritic def get_args() -> argparse.Namespace: @@ -108,9 +108,9 @@ def main(args: argparse.Namespace = get_args()) -> None: concat=True, device=args.device, ) - critic1 = Critic(net_c1, device=args.device).to(args.device) + critic1 = ContinuousCritic(net_c1, device=args.device).to(args.device) critic1_optim = AdamOptimizerFactory(lr=args.critic_lr) - critic2 = Critic(net_c2, device=args.device).to(args.device) + critic2 = ContinuousCritic(net_c2, device=args.device).to(args.device) critic2_optim = AdamOptimizerFactory(lr=args.critic_lr) if args.auto_alpha: diff --git a/examples/mujoco/mujoco_td3.py b/examples/mujoco/mujoco_td3.py index 0f1273918..18babdc4f 100755 --- a/examples/mujoco/mujoco_td3.py +++ b/examples/mujoco/mujoco_td3.py @@ -18,7 +18,7 @@ from tianshou.policy.optim import AdamOptimizerFactory from tianshou.trainer import OffPolicyTrainerParams from tianshou.utils.net.common import Net -from tianshou.utils.net.continuous import 
ContinuousActorDeterministic, Critic +from tianshou.utils.net.continuous import ContinuousActorDeterministic, ContinuousCritic def get_args() -> argparse.Namespace: @@ -111,9 +111,9 @@ def main(args: argparse.Namespace = get_args()) -> None: concat=True, device=args.device, ) - critic1 = Critic(net_c1, device=args.device).to(args.device) + critic1 = ContinuousCritic(net_c1, device=args.device).to(args.device) critic1_optim = AdamOptimizerFactory(lr=args.critic_lr) - critic2 = Critic(net_c2, device=args.device).to(args.device) + critic2 = ContinuousCritic(net_c2, device=args.device).to(args.device) critic2_optim = AdamOptimizerFactory(lr=args.critic_lr) policy = DDPGPolicy( diff --git a/examples/mujoco/mujoco_trpo.py b/examples/mujoco/mujoco_trpo.py index cfe0dfdda..284b05747 100755 --- a/examples/mujoco/mujoco_trpo.py +++ b/examples/mujoco/mujoco_trpo.py @@ -19,7 +19,7 @@ from tianshou.policy.optim import AdamOptimizerFactory, LRSchedulerFactoryLinear from tianshou.trainer import OnPolicyTrainerParams from tianshou.utils.net.common import Net -from tianshou.utils.net.continuous import ContinuousActorProb, Critic +from tianshou.utils.net.continuous import ContinuousActorProb, ContinuousCritic def get_args() -> argparse.Namespace: @@ -115,7 +115,7 @@ def test_trpo(args: argparse.Namespace = get_args()) -> None: activation=nn.Tanh, device=args.device, ) - critic = Critic(net_c, device=args.device).to(args.device) + critic = ContinuousCritic(net_c, device=args.device).to(args.device) torch.nn.init.constant_(actor.sigma_param, -0.5) for m in list(actor.modules()) + list(critic.modules()): if isinstance(m, torch.nn.Linear): diff --git a/examples/offline/atari_crr.py b/examples/offline/atari_crr.py index 76accb371..8c98631b0 100644 --- a/examples/offline/atari_crr.py +++ b/examples/offline/atari_crr.py @@ -21,7 +21,7 @@ from tianshou.policy.modelfree.pg import DiscreteActorPolicy from tianshou.policy.optim import AdamOptimizerFactory from tianshou.trainer import 
OfflineTrainerParams -from tianshou.utils.net.discrete import Critic, DiscreteActor +from tianshou.utils.net.discrete import DiscreteActor, DiscreteCritic from tianshou.utils.space_info import SpaceInfo @@ -114,7 +114,7 @@ def main(args: argparse.Namespace = get_args()) -> None: device=args.device, softmax_output=False, ).to(args.device) - critic = Critic( + critic = DiscreteCritic( feature_net, hidden_sizes=args.hidden_sizes, last_size=int(np.prod(args.action_shape)), diff --git a/examples/offline/d4rl_bcq.py b/examples/offline/d4rl_bcq.py index 05322bf3e..6e6513599 100644 --- a/examples/offline/d4rl_bcq.py +++ b/examples/offline/d4rl_bcq.py @@ -20,7 +20,7 @@ from tianshou.trainer import OfflineTrainerParams from tianshou.utils import TensorboardLogger, WandbLogger from tianshou.utils.net.common import MLP, Net -from tianshou.utils.net.continuous import VAE, Critic, Perturbation +from tianshou.utils.net.continuous import VAE, ContinuousCritic, Perturbation from tianshou.utils.space_info import SpaceInfo @@ -124,9 +124,9 @@ def test_bcq() -> None: concat=True, device=args.device, ) - critic1 = Critic(net_c1, device=args.device).to(args.device) + critic1 = ContinuousCritic(net_c1, device=args.device).to(args.device) critic1_optim = AdamOptimizerFactory(lr=args.critic_lr) - critic2 = Critic(net_c2, device=args.device).to(args.device) + critic2 = ContinuousCritic(net_c2, device=args.device).to(args.device) critic2_optim = AdamOptimizerFactory(lr=args.critic_lr) # vae diff --git a/examples/offline/d4rl_cql.py b/examples/offline/d4rl_cql.py index 65d987ef0..bd607b06b 100644 --- a/examples/offline/d4rl_cql.py +++ b/examples/offline/d4rl_cql.py @@ -20,7 +20,7 @@ from tianshou.trainer import OfflineTrainerParams from tianshou.utils import TensorboardLogger, WandbLogger from tianshou.utils.net.common import Net -from tianshou.utils.net.continuous import ContinuousActorProb, Critic +from tianshou.utils.net.continuous import ContinuousActorProb, ContinuousCritic from 
tianshou.utils.space_info import SpaceInfo @@ -275,9 +275,9 @@ def test_cql() -> None: concat=True, device=args.device, ) - critic = Critic(net_c1, device=args.device).to(args.device) + critic = ContinuousCritic(net_c1, device=args.device).to(args.device) critic_optim = AdamOptimizerFactory(lr=args.critic_lr) - critic2 = Critic(net_c2, device=args.device).to(args.device) + critic2 = ContinuousCritic(net_c2, device=args.device).to(args.device) critic2_optim = AdamOptimizerFactory(lr=args.critic_lr) if args.auto_alpha: diff --git a/examples/offline/d4rl_td3_bc.py b/examples/offline/d4rl_td3_bc.py index 9d2c69cff..7b46fe326 100644 --- a/examples/offline/d4rl_td3_bc.py +++ b/examples/offline/d4rl_td3_bc.py @@ -21,7 +21,7 @@ from tianshou.trainer import OfflineTrainerParams from tianshou.utils import TensorboardLogger, WandbLogger from tianshou.utils.net.common import Net -from tianshou.utils.net.continuous import ContinuousActorDeterministic, Critic +from tianshou.utils.net.continuous import ContinuousActorDeterministic, ContinuousCritic from tianshou.utils.space_info import SpaceInfo @@ -132,9 +132,9 @@ def test_td3_bc() -> None: concat=True, device=args.device, ) - critic1 = Critic(net_c1, device=args.device).to(args.device) + critic1 = ContinuousCritic(net_c1, device=args.device).to(args.device) critic1_optim = AdamOptimizerFactory(lr=args.critic_lr) - critic2 = Critic(net_c2, device=args.device).to(args.device) + critic2 = ContinuousCritic(net_c2, device=args.device).to(args.device) critic2_optim = AdamOptimizerFactory(lr=args.critic_lr) policy = DDPGPolicy( diff --git a/examples/vizdoom/vizdoom_ppo.py b/examples/vizdoom/vizdoom_ppo.py index ea69c8c00..a79bbaec1 100644 --- a/examples/vizdoom/vizdoom_ppo.py +++ b/examples/vizdoom/vizdoom_ppo.py @@ -18,7 +18,11 @@ from tianshou.policy.modelfree.pg import ActorPolicy from tianshou.policy.optim import AdamOptimizerFactory, LRSchedulerFactoryLinear from tianshou.trainer import OnPolicyTrainerParams -from 
tianshou.utils.net.discrete import Critic, DiscreteActor, IntrinsicCuriosityModule +from tianshou.utils.net.discrete import ( + DiscreteActor, + DiscreteCritic, + IntrinsicCuriosityModule, +) def get_args() -> argparse.Namespace: @@ -127,7 +131,7 @@ def test_ppo(args: argparse.Namespace = get_args()) -> None: output_dim=args.hidden_size, ) actor = DiscreteActor(net, args.action_shape, device=args.device, softmax_output=False) - critic = Critic(net, device=args.device) + critic = DiscreteCritic(net, device=args.device) optim = AdamOptimizerFactory(lr=args.lr) if args.lr_decay: diff --git a/test/base/test_policy.py b/test/base/test_policy.py index 84286b946..57859e303 100644 --- a/test/base/test_policy.py +++ b/test/base/test_policy.py @@ -10,7 +10,7 @@ from tianshou.policy.modelfree.pg import ActorPolicy from tianshou.policy.optim import AdamOptimizerFactory from tianshou.utils.net.common import Net -from tianshou.utils.net.continuous import ContinuousActorProb, Critic +from tianshou.utils.net.continuous import ContinuousActorProb, ContinuousCritic from tianshou.utils.net.discrete import DiscreteActor obs_shape = (5,) @@ -53,7 +53,7 @@ def dist_fn(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution: else: raise ValueError(f"Unknown action type: {action_type}") - critic = Critic( + critic = ContinuousCritic( Net(obs_shape, hidden_sizes=[64, 64]), ) diff --git a/test/continuous/test_ddpg.py b/test/continuous/test_ddpg.py index 3f706c21b..08b0ff729 100644 --- a/test/continuous/test_ddpg.py +++ b/test/continuous/test_ddpg.py @@ -16,7 +16,7 @@ from tianshou.trainer.base import OffPolicyTrainerParams from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Net -from tianshou.utils.net.continuous import ContinuousActorDeterministic, Critic +from tianshou.utils.net.continuous import ContinuousActorDeterministic, ContinuousCritic from tianshou.utils.space_info import SpaceInfo @@ -86,7 +86,7 @@ def test_ddpg(args: argparse.Namespace = 
get_args()) -> None: concat=True, device=args.device, ) - critic = Critic(net, device=args.device).to(args.device) + critic = ContinuousCritic(net, device=args.device).to(args.device) critic_optim = AdamOptimizerFactory(lr=args.critic_lr) policy = DDPGPolicy( actor=actor, diff --git a/test/continuous/test_npg.py b/test/continuous/test_npg.py index a42111bf2..586022cc0 100644 --- a/test/continuous/test_npg.py +++ b/test/continuous/test_npg.py @@ -18,7 +18,7 @@ from tianshou.trainer.base import OnPolicyTrainerParams from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Net -from tianshou.utils.net.continuous import ContinuousActorProb, Critic +from tianshou.utils.net.continuous import ContinuousActorProb, ContinuousCritic from tianshou.utils.space_info import SpaceInfo @@ -87,7 +87,7 @@ def test_npg(args: argparse.Namespace = get_args()) -> None: actor = ContinuousActorProb(net, args.action_shape, unbounded=True, device=args.device).to( args.device ) - critic = Critic( + critic = ContinuousCritic( Net( args.state_shape, hidden_sizes=args.hidden_sizes, diff --git a/test/continuous/test_ppo.py b/test/continuous/test_ppo.py index 46872f948..7775aa338 100644 --- a/test/continuous/test_ppo.py +++ b/test/continuous/test_ppo.py @@ -16,7 +16,7 @@ from tianshou.trainer.base import OnPolicyTrainerParams from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import ActorCritic, Net -from tianshou.utils.net.continuous import ContinuousActorProb, Critic +from tianshou.utils.net.continuous import ContinuousActorProb, ContinuousCritic from tianshou.utils.space_info import SpaceInfo @@ -87,7 +87,7 @@ def test_ppo(args: argparse.Namespace = get_args()) -> None: actor = ContinuousActorProb(net, args.action_shape, unbounded=True, device=args.device).to( args.device ) - critic = Critic( + critic = ContinuousCritic( Net(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device), device=args.device, ).to(args.device) 
diff --git a/test/continuous/test_redq.py b/test/continuous/test_redq.py index eaf7e2461..2d16a646c 100644 --- a/test/continuous/test_redq.py +++ b/test/continuous/test_redq.py @@ -17,7 +17,7 @@ from tianshou.trainer.base import OffPolicyTrainerParams from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import EnsembleLinear, Net -from tianshou.utils.net.continuous import ContinuousActorProb, Critic +from tianshou.utils.net.continuous import ContinuousActorProb, ContinuousCritic from tianshou.utils.space_info import SpaceInfo @@ -102,7 +102,9 @@ def linear(x: int, y: int) -> nn.Module: device=args.device, linear_layer=linear, ) - critic = Critic(net_c, device=args.device, linear_layer=linear, flatten_input=False).to( + critic = ContinuousCritic( + net_c, device=args.device, linear_layer=linear, flatten_input=False + ).to( args.device, ) critic_optim = AdamOptimizerFactory(lr=args.critic_lr) diff --git a/test/continuous/test_sac_with_il.py b/test/continuous/test_sac_with_il.py index 7d8cae7fd..44d622914 100644 --- a/test/continuous/test_sac_with_il.py +++ b/test/continuous/test_sac_with_il.py @@ -19,7 +19,7 @@ from tianshou.utils.net.continuous import ( ContinuousActorDeterministic, ContinuousActorProb, - Critic, + ContinuousCritic, ) from tianshou.utils.space_info import SpaceInfo @@ -101,7 +101,7 @@ def test_sac_with_il(args: argparse.Namespace = get_args()) -> None: concat=True, device=args.device, ) - critic1 = Critic(net_c1, device=args.device).to(args.device) + critic1 = ContinuousCritic(net_c1, device=args.device).to(args.device) critic1_optim = AdamOptimizerFactory(lr=args.critic_lr) net_c2 = Net( state_shape=args.state_shape, @@ -110,7 +110,7 @@ def test_sac_with_il(args: argparse.Namespace = get_args()) -> None: concat=True, device=args.device, ) - critic2 = Critic(net_c2, device=args.device).to(args.device) + critic2 = ContinuousCritic(net_c2, device=args.device).to(args.device) critic2_optim = 
AdamOptimizerFactory(lr=args.critic_lr) action_dim = space_info.action_info.action_dim if args.auto_alpha: diff --git a/test/continuous/test_td3.py b/test/continuous/test_td3.py index 0fae5a7ab..5c884cd77 100644 --- a/test/continuous/test_td3.py +++ b/test/continuous/test_td3.py @@ -16,7 +16,7 @@ from tianshou.trainer.base import OffPolicyTrainerParams from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Net -from tianshou.utils.net.continuous import ContinuousActorDeterministic, Critic +from tianshou.utils.net.continuous import ContinuousActorDeterministic, ContinuousCritic from tianshou.utils.space_info import SpaceInfo @@ -90,7 +90,7 @@ def test_td3(args: argparse.Namespace = get_args()) -> None: concat=True, device=args.device, ) - critic1 = Critic(net_c1, device=args.device).to(args.device) + critic1 = ContinuousCritic(net_c1, device=args.device).to(args.device) critic1_optim = AdamOptimizerFactory(lr=args.critic_lr) net_c2 = Net( state_shape=args.state_shape, @@ -99,7 +99,7 @@ def test_td3(args: argparse.Namespace = get_args()) -> None: concat=True, device=args.device, ) - critic2 = Critic(net_c2, device=args.device).to(args.device) + critic2 = ContinuousCritic(net_c2, device=args.device).to(args.device) critic2_optim = AdamOptimizerFactory(lr=args.critic_lr) policy = DDPGPolicy( actor=actor, diff --git a/test/continuous/test_trpo.py b/test/continuous/test_trpo.py index eda1112e0..3284d8468 100644 --- a/test/continuous/test_trpo.py +++ b/test/continuous/test_trpo.py @@ -17,7 +17,7 @@ from tianshou.trainer.base import OnPolicyTrainerParams from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Net -from tianshou.utils.net.continuous import ContinuousActorProb, Critic +from tianshou.utils.net.continuous import ContinuousActorProb, ContinuousCritic from tianshou.utils.space_info import SpaceInfo @@ -87,7 +87,7 @@ def test_trpo(args: argparse.Namespace = get_args()) -> None: actor = 
ContinuousActorProb(net, args.action_shape, unbounded=True, device=args.device).to( args.device ) - critic = Critic( + critic = ContinuousCritic( Net( args.state_shape, hidden_sizes=args.hidden_sizes, diff --git a/test/discrete/test_a2c_with_il.py b/test/discrete/test_a2c_with_il.py index 3bdeb9196..318c19939 100644 --- a/test/discrete/test_a2c_with_il.py +++ b/test/discrete/test_a2c_with_il.py @@ -17,7 +17,7 @@ from tianshou.trainer.base import OffPolicyTrainerParams, OnPolicyTrainerParams from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Net -from tianshou.utils.net.discrete import Critic, DiscreteActor +from tianshou.utils.net.discrete import DiscreteActor, DiscreteCritic try: import envpool @@ -94,7 +94,7 @@ def test_a2c_with_il(args: argparse.Namespace = get_args()) -> None: # model net = Net(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device) actor = DiscreteActor(net, args.action_shape, device=args.device).to(args.device) - critic = Critic(net, device=args.device).to(args.device) + critic = DiscreteCritic(net, device=args.device).to(args.device) optim = AdamOptimizerFactory(lr=args.lr) dist = torch.distributions.Categorical policy = ActorPolicy( diff --git a/test/discrete/test_discrete_sac.py b/test/discrete/test_discrete_sac.py index e7172b87a..63c9e1723 100644 --- a/test/discrete/test_discrete_sac.py +++ b/test/discrete/test_discrete_sac.py @@ -19,7 +19,7 @@ from tianshou.trainer.base import OffPolicyTrainerParams from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Net -from tianshou.utils.net.discrete import Critic, DiscreteActor +from tianshou.utils.net.discrete import DiscreteActor, DiscreteCritic from tianshou.utils.space_info import SpaceInfo @@ -86,10 +86,10 @@ def test_discrete_sac(args: argparse.Namespace = get_args()) -> None: ) actor_optim = AdamOptimizerFactory(lr=args.actor_lr) net_c1 = Net(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes, 
device=args.device) - critic1 = Critic(net_c1, last_size=action_dim, device=args.device).to(args.device) + critic1 = DiscreteCritic(net_c1, last_size=action_dim, device=args.device).to(args.device) critic1_optim = AdamOptimizerFactory(lr=args.critic_lr) net_c2 = Net(obs_dim, hidden_sizes=args.hidden_sizes, device=args.device) - critic2 = Critic(net_c2, last_size=action_dim, device=args.device).to(args.device) + critic2 = DiscreteCritic(net_c2, last_size=action_dim, device=args.device).to(args.device) critic2_optim = AdamOptimizerFactory(lr=args.critic_lr) # better not to use auto alpha in CartPole diff --git a/test/discrete/test_ppo.py b/test/discrete/test_ppo.py index 1109c9986..598966b80 100644 --- a/test/discrete/test_ppo.py +++ b/test/discrete/test_ppo.py @@ -16,7 +16,7 @@ from tianshou.trainer import OnPolicyTrainerParams from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import ActorCritic, DataParallelNet, Net -from tianshou.utils.net.discrete import Critic, DiscreteActor +from tianshou.utils.net.discrete import DiscreteActor, DiscreteCritic from tianshou.utils.space_info import SpaceInfo @@ -86,10 +86,10 @@ def test_ppo(args: argparse.Namespace = get_args()) -> None: actor = DataParallelNet( DiscreteActor(net, args.action_shape, device=args.device).to(args.device) ) - critic = DataParallelNet(Critic(net, device=args.device).to(args.device)) + critic = DataParallelNet(DiscreteCritic(net, device=args.device).to(args.device)) else: actor = DiscreteActor(net, args.action_shape, device=args.device).to(args.device) - critic = Critic(net, device=args.device).to(args.device) + critic = DiscreteCritic(net, device=args.device).to(args.device) actor_critic = ActorCritic(actor, critic) # orthogonal initialization for m in actor_critic.modules(): diff --git a/test/modelbased/test_ppo_icm.py b/test/modelbased/test_ppo_icm.py index e15a0b90e..9e499dcc5 100644 --- a/test/modelbased/test_ppo_icm.py +++ b/test/modelbased/test_ppo_icm.py @@ -17,7 
+17,11 @@ from tianshou.trainer import OnPolicyTrainerParams from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import MLP, ActorCritic, Net -from tianshou.utils.net.discrete import Critic, DiscreteActor, IntrinsicCuriosityModule +from tianshou.utils.net.discrete import ( + DiscreteActor, + DiscreteCritic, + IntrinsicCuriosityModule, +) from tianshou.utils.space_info import SpaceInfo @@ -102,7 +106,7 @@ def test_ppo(args: argparse.Namespace = get_args()) -> None: # model net = Net(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device) actor = DiscreteActor(net, args.action_shape, device=args.device).to(args.device) - critic = Critic(net, device=args.device).to(args.device) + critic = DiscreteCritic(net, device=args.device).to(args.device) actor_critic = ActorCritic(actor, critic) # orthogonal initialization diff --git a/test/offline/gather_pendulum_data.py b/test/offline/gather_pendulum_data.py index c2e0bae55..59b103fb8 100644 --- a/test/offline/gather_pendulum_data.py +++ b/test/offline/gather_pendulum_data.py @@ -16,7 +16,7 @@ from tianshou.trainer.base import OffPolicyTrainerParams from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Net -from tianshou.utils.net.continuous import ContinuousActorProb, Critic +from tianshou.utils.net.continuous import ContinuousActorProb, ContinuousCritic from tianshou.utils.space_info import SpaceInfo @@ -107,7 +107,7 @@ def gather_data() -> VectorReplayBuffer: concat=True, device=args.device, ) - critic = Critic(net_c, device=args.device).to(args.device) + critic = ContinuousCritic(net_c, device=args.device).to(args.device) critic_optim = AdamOptimizerFactory(lr=args.critic_lr) action_dim = space_info.action_info.action_dim diff --git a/test/offline/test_bcq.py b/test/offline/test_bcq.py index 98349d318..65634621b 100644 --- a/test/offline/test_bcq.py +++ b/test/offline/test_bcq.py @@ -17,7 +17,7 @@ from tianshou.trainer.base import 
OfflineTrainerParams from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import MLP, Net -from tianshou.utils.net.continuous import VAE, Critic, Perturbation +from tianshou.utils.net.continuous import VAE, ContinuousCritic, Perturbation from tianshou.utils.space_info import SpaceInfo @@ -116,7 +116,7 @@ def test_bcq(args: argparse.Namespace = get_args()) -> None: concat=True, device=args.device, ) - critic = Critic(net_c, device=args.device).to(args.device) + critic = ContinuousCritic(net_c, device=args.device).to(args.device) critic_optim = AdamOptimizerFactory(lr=args.critic_lr) # vae diff --git a/test/offline/test_cql.py b/test/offline/test_cql.py index 92a4632d9..62aee3eed 100644 --- a/test/offline/test_cql.py +++ b/test/offline/test_cql.py @@ -18,7 +18,7 @@ from tianshou.trainer import OfflineTrainerParams from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Net -from tianshou.utils.net.continuous import ContinuousActorProb, Critic +from tianshou.utils.net.continuous import ContinuousActorProb, ContinuousCritic from tianshou.utils.space_info import SpaceInfo @@ -128,7 +128,7 @@ def test_cql(args: argparse.Namespace = get_args()) -> None: concat=True, device=args.device, ) - critic = Critic(net_c, device=args.device).to(args.device) + critic = ContinuousCritic(net_c, device=args.device).to(args.device) critic_optim = AdamOptimizerFactory(lr=args.critic_lr) if args.auto_alpha: diff --git a/test/offline/test_discrete_crr.py b/test/offline/test_discrete_crr.py index 38aa2fadf..f6112203a 100644 --- a/test/offline/test_discrete_crr.py +++ b/test/offline/test_discrete_crr.py @@ -21,7 +21,7 @@ from tianshou.trainer import OfflineTrainerParams from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Net -from tianshou.utils.net.discrete import Critic, DiscreteActor +from tianshou.utils.net.discrete import DiscreteActor, DiscreteCritic from tianshou.utils.space_info import SpaceInfo @@ -80,7 
+80,7 @@ def test_discrete_crr(args: argparse.Namespace = get_args()) -> None: softmax_output=False, ) action_dim = space_info.action_info.action_dim - critic = Critic( + critic = DiscreteCritic( net, hidden_sizes=args.hidden_sizes, last_size=action_dim, diff --git a/test/offline/test_gail.py b/test/offline/test_gail.py index a94d869e7..7ffd5742e 100644 --- a/test/offline/test_gail.py +++ b/test/offline/test_gail.py @@ -17,7 +17,7 @@ from tianshou.trainer import OnPolicyTrainerParams from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import ActorCritic, Net -from tianshou.utils.net.continuous import ContinuousActorProb, Critic +from tianshou.utils.net.continuous import ContinuousActorProb, ContinuousCritic from tianshou.utils.space_info import SpaceInfo @@ -103,7 +103,7 @@ def test_gail(args: argparse.Namespace = get_args()) -> None: ).to( args.device, ) - critic = Critic( + critic = ContinuousCritic( Net(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device), device=args.device, ).to(args.device) @@ -115,7 +115,7 @@ def test_gail(args: argparse.Namespace = get_args()) -> None: torch.nn.init.zeros_(m.bias) optim = AdamOptimizerFactory(lr=args.lr) # discriminator - disc_net = Critic( + disc_net = ContinuousCritic( Net( state_shape=args.state_shape, action_shape=args.action_shape, diff --git a/test/offline/test_td3_bc.py b/test/offline/test_td3_bc.py index baa101805..bba849e83 100644 --- a/test/offline/test_td3_bc.py +++ b/test/offline/test_td3_bc.py @@ -19,7 +19,7 @@ from tianshou.trainer import OfflineTrainerParams from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Net -from tianshou.utils.net.continuous import ContinuousActorDeterministic, Critic +from tianshou.utils.net.continuous import ContinuousActorDeterministic, ContinuousCritic from tianshou.utils.space_info import SpaceInfo @@ -123,9 +123,9 @@ def test_td3_bc(args: argparse.Namespace = get_args()) -> None: concat=True, 
device=args.device, ) - critic1 = Critic(net_c1, device=args.device).to(args.device) + critic1 = ContinuousCritic(net_c1, device=args.device).to(args.device) critic1_optim = AdamOptimizerFactory(lr=args.critic_lr) - critic2 = Critic(net_c2, device=args.device).to(args.device) + critic2 = ContinuousCritic(net_c2, device=args.device).to(args.device) critic2_optim = AdamOptimizerFactory(lr=args.critic_lr) # policy and algorithm diff --git a/test/pettingzoo/pistonball_continuous.py b/test/pettingzoo/pistonball_continuous.py index 21a9943a0..68e316717 100644 --- a/test/pettingzoo/pistonball_continuous.py +++ b/test/pettingzoo/pistonball_continuous.py @@ -21,7 +21,7 @@ from tianshou.policy.optim import AdamOptimizerFactory from tianshou.trainer import OnPolicyTrainerParams from tianshou.utils import TensorboardLogger -from tianshou.utils.net.continuous import ContinuousActorProb, Critic +from tianshou.utils.net.continuous import ContinuousActorProb, ContinuousCritic class DQNet(nn.Module): @@ -178,7 +178,7 @@ def get_agents( observation_space.shape[0], device=args.device, ).to(args.device) - critic = Critic(net2, device=args.device).to(args.device) + critic = ContinuousCritic(net2, device=args.device).to(args.device) for m in set(actor.modules()).union(critic.modules()): if isinstance(m, torch.nn.Linear): torch.nn.init.orthogonal_(m.weight) diff --git a/tianshou/highlevel/module/critic.py b/tianshou/highlevel/module/critic.py index c5f6e3438..992c53046 100644 --- a/tianshou/highlevel/module/critic.py +++ b/tianshou/highlevel/module/critic.py @@ -97,7 +97,7 @@ def create_module( activation=self.activation, device=device, ) - critic = continuous.Critic(net_c, device=device).to(device) + critic = continuous.ContinuousCritic(net_c, device=device).to(device) init_linear_orthogonal(critic) return critic @@ -126,7 +126,7 @@ def create_module( last_size = ( int(np.prod(envs.get_action_shape())) if discrete_last_size_use_action_shape else 1 ) - critic = discrete.Critic(net_c, 
device=device, last_size=last_size).to(device) + critic = discrete.DiscreteCritic(net_c, device=device, last_size=last_size).to(device) init_linear_orthogonal(critic) return critic @@ -167,13 +167,13 @@ def create_module( last_size = ( int(np.prod(envs.get_action_shape())) if discrete_last_size_use_action_shape else 1 ) - return discrete.Critic( + return discrete.DiscreteCritic( actor.get_preprocess_net(), device=device, last_size=last_size, ).to(device) elif envs.get_type().is_continuous(): - return continuous.Critic( + return continuous.ContinuousCritic( actor.get_preprocess_net(), device=device, apply_preprocess_net_to_obs_only=True, @@ -250,7 +250,7 @@ def linear_layer(x: int, y: int) -> EnsembleLinear: device=device, linear_layer=linear_layer, ) - critic = continuous.Critic( + critic = continuous.ContinuousCritic( net_c, device=device, linear_layer=linear_layer, diff --git a/tianshou/policy/imitation/discrete_crr.py b/tianshou/policy/imitation/discrete_crr.py index 282eba63b..7eae67d67 100644 --- a/tianshou/policy/imitation/discrete_crr.py +++ b/tianshou/policy/imitation/discrete_crr.py @@ -19,7 +19,7 @@ PGTrainingStats, ) from tianshou.policy.optim import OptimizerFactory -from tianshou.utils.net.discrete import Critic +from tianshou.utils.net.discrete import DiscreteCritic @dataclass @@ -42,7 +42,7 @@ def __init__( self, *, policy: DiscreteActorPolicy, - critic: torch.nn.Module | Critic, + critic: torch.nn.Module | DiscreteCritic, optim: OptimizerFactory, discount_factor: float = 0.99, policy_improvement_mode: Literal["exp", "binary", "all"] = "exp", diff --git a/tianshou/policy/imitation/gail.py b/tianshou/policy/imitation/gail.py index a12c31f25..0c625052a 100644 --- a/tianshou/policy/imitation/gail.py +++ b/tianshou/policy/imitation/gail.py @@ -16,8 +16,8 @@ from tianshou.policy.modelfree.pg import ActorPolicy from tianshou.policy.modelfree.ppo import PPOTrainingStats from tianshou.policy.optim import OptimizerFactory -from tianshou.utils.net.continuous 
import Critic -from tianshou.utils.net.discrete import Critic as DiscreteCritic +from tianshou.utils.net.continuous import ContinuousCritic +from tianshou.utils.net.discrete import DiscreteCritic @dataclass(kw_only=True) @@ -37,7 +37,7 @@ def __init__( self, *, policy: ActorPolicy, - critic: torch.nn.Module | Critic | DiscreteCritic, + critic: torch.nn.Module | ContinuousCritic | DiscreteCritic, optim: OptimizerFactory, expert_buffer: ReplayBuffer, disc_net: torch.nn.Module, diff --git a/tianshou/policy/modelfree/a2c.py b/tianshou/policy/modelfree/a2c.py index c7a4d9bd9..aaa5e9dd9 100644 --- a/tianshou/policy/modelfree/a2c.py +++ b/tianshou/policy/modelfree/a2c.py @@ -17,8 +17,8 @@ from tianshou.policy.optim import OptimizerFactory from tianshou.utils import RunningMeanStd from tianshou.utils.net.common import ActorCritic -from tianshou.utils.net.continuous import Critic -from tianshou.utils.net.discrete import Critic as DiscreteCritic +from tianshou.utils.net.continuous import ContinuousCritic +from tianshou.utils.net.discrete import DiscreteCritic @dataclass(kw_only=True) @@ -41,7 +41,7 @@ def __init__( self, *, policy: ActorPolicy, - critic: torch.nn.Module | Critic | DiscreteCritic, + critic: torch.nn.Module | ContinuousCritic | DiscreteCritic, optim: OptimizerFactory, optim_include_actor: bool, gae_lambda: float = 0.95, @@ -124,7 +124,7 @@ def __init__( self, *, policy: ActorPolicy, - critic: torch.nn.Module | Critic | DiscreteCritic, + critic: torch.nn.Module | ContinuousCritic | DiscreteCritic, optim: OptimizerFactory, vf_coef: float = 0.5, ent_coef: float = 0.01, diff --git a/tianshou/policy/modelfree/ddpg.py b/tianshou/policy/modelfree/ddpg.py index 30a293a3f..eb1bb7294 100644 --- a/tianshou/policy/modelfree/ddpg.py +++ b/tianshou/policy/modelfree/ddpg.py @@ -28,7 +28,7 @@ TTrainingStats, ) from tianshou.policy.optim import OptimizerFactory -from tianshou.utils.net.continuous import ContinuousActorDeterministic, Critic +from tianshou.utils.net.continuous 
import ContinuousActorDeterministic, ContinuousCritic mark_used(ActBatchProtocol) @@ -306,7 +306,7 @@ def __init__( *, policy: DDPGPolicy, policy_optim: OptimizerFactory, - critic: torch.nn.Module | Critic, + critic: torch.nn.Module | ContinuousCritic, critic_optim: OptimizerFactory, tau: float = 0.005, gamma: float = 0.99, diff --git a/tianshou/policy/modelfree/discrete_sac.py b/tianshou/policy/modelfree/discrete_sac.py index 6ec6beb26..06c1fd860 100644 --- a/tianshou/policy/modelfree/discrete_sac.py +++ b/tianshou/policy/modelfree/discrete_sac.py @@ -16,7 +16,7 @@ from tianshou.policy.modelfree.sac import Alpha, SACTrainingStats from tianshou.policy.modelfree.td3 import ActorDualCriticsOffPolicyAlgorithm from tianshou.policy.optim import OptimizerFactory -from tianshou.utils.net.discrete import Critic +from tianshou.utils.net.discrete import DiscreteCritic @dataclass @@ -83,9 +83,9 @@ def __init__( *, policy: DiscreteSACPolicy, policy_optim: OptimizerFactory, - critic: torch.nn.Module | Critic, + critic: torch.nn.Module | DiscreteCritic, critic_optim: OptimizerFactory, - critic2: torch.nn.Module | Critic | None = None, + critic2: torch.nn.Module | DiscreteCritic | None = None, critic2_optim: OptimizerFactory | None = None, tau: float = 0.005, gamma: float = 0.99, diff --git a/tianshou/policy/modelfree/npg.py b/tianshou/policy/modelfree/npg.py index 2c8ad9b12..05d1dc532 100644 --- a/tianshou/policy/modelfree/npg.py +++ b/tianshou/policy/modelfree/npg.py @@ -13,8 +13,8 @@ from tianshou.policy.modelfree.a2c import ActorCriticOnPolicyAlgorithm from tianshou.policy.modelfree.pg import ActorPolicy from tianshou.policy.optim import OptimizerFactory -from tianshou.utils.net.continuous import Critic -from tianshou.utils.net.discrete import Critic as DiscreteCritic +from tianshou.utils.net.continuous import ContinuousCritic +from tianshou.utils.net.discrete import DiscreteCritic @dataclass(kw_only=True) @@ -37,7 +37,7 @@ def __init__( self, *, policy: ActorPolicy, - 
critic: torch.nn.Module | Critic | DiscreteCritic, + critic: torch.nn.Module | ContinuousCritic | DiscreteCritic, optim: OptimizerFactory, optim_critic_iters: int = 5, actor_step_size: float = 0.5, diff --git a/tianshou/policy/modelfree/ppo.py b/tianshou/policy/modelfree/ppo.py index 42a7f1f8b..53a699c9c 100644 --- a/tianshou/policy/modelfree/ppo.py +++ b/tianshou/policy/modelfree/ppo.py @@ -13,8 +13,8 @@ from tianshou.policy.modelfree.pg import ActorPolicy from tianshou.policy.optim import OptimizerFactory from tianshou.utils.net.common import ActorCritic -from tianshou.utils.net.continuous import Critic -from tianshou.utils.net.discrete import Critic as DiscreteCritic +from tianshou.utils.net.continuous import ContinuousCritic +from tianshou.utils.net.discrete import DiscreteCritic @dataclass(kw_only=True) @@ -61,7 +61,7 @@ def __init__( self, *, policy: ActorPolicy, - critic: torch.nn.Module | Critic | DiscreteCritic, + critic: torch.nn.Module | ContinuousCritic | DiscreteCritic, optim: OptimizerFactory, eps_clip: float = 0.2, dual_clip: float | None = None, diff --git a/tianshou/policy/modelfree/trpo.py b/tianshou/policy/modelfree/trpo.py index 4bb975f63..c3e51ea65 100644 --- a/tianshou/policy/modelfree/trpo.py +++ b/tianshou/policy/modelfree/trpo.py @@ -11,8 +11,8 @@ from tianshou.policy.modelfree.npg import NPGTrainingStats from tianshou.policy.modelfree.pg import ActorPolicy from tianshou.policy.optim import OptimizerFactory -from tianshou.utils.net.continuous import Critic -from tianshou.utils.net.discrete import Critic as DiscreteCritic +from tianshou.utils.net.continuous import ContinuousCritic +from tianshou.utils.net.discrete import DiscreteCritic @dataclass(kw_only=True) @@ -30,7 +30,7 @@ def __init__( self, *, policy: ActorPolicy, - critic: torch.nn.Module | Critic | DiscreteCritic, + critic: torch.nn.Module | ContinuousCritic | DiscreteCritic, optim: OptimizerFactory, max_kl: float = 0.01, backtrack_coeff: float = 0.8, diff --git 
a/tianshou/utils/net/continuous.py b/tianshou/utils/net/continuous.py index 60eb4a1ac..cab7da2cb 100644 --- a/tianshou/utils/net/continuous.py +++ b/tianshou/utils/net/continuous.py @@ -86,7 +86,7 @@ def forward( return action_BA, hidden_BH -class CriticBase(nn.Module, ABC): +class AbstractContinuousCritic(nn.Module, ABC): @abstractmethod def forward( self, @@ -97,7 +97,7 @@ def forward( """Mapping: (s_B, a_B) -> Q(s, a)_B.""" -class Critic(CriticBase): +class ContinuousCritic(AbstractContinuousCritic): """Simple critic network. It will create an actor operated in continuous action space with structure of preprocess_net ---> 1(q value). @@ -145,7 +145,7 @@ def __init__( def __setstate__(self, state: dict) -> None: setstate( - Critic, + ContinuousCritic, self, state, new_default_properties={"apply_preprocess_net_to_obs_only": False}, diff --git a/tianshou/utils/net/discrete.py b/tianshou/utils/net/discrete.py index 81b70a53c..a1baf4cf7 100644 --- a/tianshou/utils/net/discrete.py +++ b/tianshou/utils/net/discrete.py @@ -87,7 +87,7 @@ def forward( return output_BA, hidden_BH -class Critic(nn.Module): +class DiscreteCritic(nn.Module): """Simple critic network for discrete action spaces. :param preprocess_net: a self-defined preprocess_net. Typically, an instance of @@ -163,7 +163,7 @@ def forward(self, taus: torch.Tensor) -> torch.Tensor: return self.net(cosines).view(batch_size, N, self.embedding_dim) -class ImplicitQuantileNetwork(Critic): +class ImplicitQuantileNetwork(DiscreteCritic): """Implicit Quantile Network. 
:param preprocess_net: a self-defined preprocess_net which output a From 864d17e60c375425a045a094ae1872c48c898beb Mon Sep 17 00:00:00 2001 From: Dominik Jain Date: Wed, 19 Mar 2025 15:02:46 +0100 Subject: [PATCH 081/230] v2: Fix device assignment issue #810 * Remove 'device' member (and constructor arg) from the following classes: * BranchingNet * C51Net * ContinuousActorDeterministic * ContinuousActorProb * ContinuousCritic * DiscreteActor * DiscreteCritic * DQNet * FullQuantileFunction * ImplicitQuantileNetwork * IntrinsicCuriosityModule * Net * MLP * Perturbation * QRDQNet * Rainbow * Recurrent * RecurrentActorProb * RecurrentCritic * VAE * Peripheral change: Force use of kwargs for all of the above --- CHANGELOG.md | 7 ++ examples/atari/atari_c51.py | 3 +- examples/atari/atari_dqn.py | 7 +- examples/atari/atari_fqf.py | 12 +-- examples/atari/atari_iqn.py | 10 +-- examples/atari/atari_ppo.py | 22 +++--- examples/atari/atari_qrdqn.py | 1 - examples/atari/atari_rainbow.py | 12 +-- examples/atari/atari_sac.py | 16 ++-- examples/box2d/acrobot_dualdqn.py | 1 - examples/box2d/bipedal_bdq.py | 13 ++- examples/box2d/bipedal_hardcore_sac.py | 9 +-- examples/box2d/lunarlander_dqn.py | 1 - examples/box2d/mcc_sac.py | 14 ++-- examples/inverse/irl_gail.py | 14 ++-- examples/mujoco/mujoco_a2c.py | 13 ++- examples/mujoco/mujoco_ddpg.py | 7 +- examples/mujoco/mujoco_npg.py | 13 ++- examples/mujoco/mujoco_ppo.py | 13 ++- examples/mujoco/mujoco_redq.py | 11 +-- examples/mujoco/mujoco_reinforce.py | 8 +- examples/mujoco/mujoco_sac.py | 13 ++- examples/mujoco/mujoco_td3.py | 10 +-- examples/mujoco/mujoco_trpo.py | 13 ++- examples/offline/atari_bcq.py | 19 ++--- examples/offline/atari_cql.py | 1 - examples/offline/atari_crr.py | 17 ++-- examples/offline/atari_il.py | 2 +- examples/offline/d4rl_bcq.py | 16 ++-- examples/offline/d4rl_cql.py | 10 +-- examples/offline/d4rl_il.py | 4 +- examples/offline/d4rl_td3_bc.py | 12 +-- examples/vizdoom/vizdoom_c51.py | 3 +- 
examples/vizdoom/vizdoom_ppo.py | 27 ++++--- test/base/test_policy.py | 10 ++- test/base/test_utils.py | 22 +++--- test/continuous/test_ddpg.py | 7 +- test/continuous/test_npg.py | 15 ++-- test/continuous/test_ppo.py | 11 ++- test/continuous/test_redq.py | 12 +-- test/continuous/test_sac_with_il.py | 22 +++--- test/continuous/test_td3.py | 10 +-- test/continuous/test_trpo.py | 15 ++-- test/discrete/test_a2c_with_il.py | 10 +-- test/discrete/test_bdqn.py | 13 ++- test/discrete/test_c51.py | 1 - test/discrete/test_discrete_sac.py | 16 ++-- test/discrete/test_dqn.py | 1 - test/discrete/test_drqn.py | 4 +- test/discrete/test_fqf.py | 12 ++- test/discrete/test_iqn.py | 10 +-- test/discrete/test_pg.py | 1 - test/discrete/test_ppo.py | 10 +-- test/discrete/test_qrdqn.py | 1 - test/discrete/test_rainbow.py | 1 - test/modelbased/test_dqn_icm.py | 11 +-- test/modelbased/test_ppo_icm.py | 16 ++-- test/offline/gather_cartpole_data.py | 1 - test/offline/gather_pendulum_data.py | 10 +-- test/offline/test_bcq.py | 13 +-- test/offline/test_cql.py | 7 +- test/offline/test_discrete_bcq.py | 12 ++- test/offline/test_discrete_cql.py | 1 - test/offline/test_discrete_crr.py | 6 +- test/offline/test_gail.py | 10 +-- test/offline/test_td3_bc.py | 12 +-- test/pettingzoo/pistonball.py | 1 - test/pettingzoo/pistonball_continuous.py | 7 +- test/pettingzoo/tic_tac_toe.py | 1 - tianshou/env/atari/atari_network.py | 24 +++--- tianshou/highlevel/module/actor.py | 10 +-- tianshou/highlevel/module/critic.py | 24 +++--- tianshou/highlevel/module/special.py | 1 - tianshou/highlevel/params/policy_wrapper.py | 7 +- tianshou/policy/imitation/gail.py | 6 +- tianshou/utils/net/common.py | 87 +++++++++------------ tianshou/utils/net/continuous.py | 63 +++++++-------- tianshou/utils/net/discrete.py | 61 +++++++-------- 78 files changed, 408 insertions(+), 541 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 35ebbc681..be4eee4cb 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -144,6 +144,13 @@ * 
`continuous.Critic` -> `ContinuousCritic` * `discrete.Critic` -> `DiscreteCritic` * Moved Atari helper modules `atari_network` and `atari_wrapper` to the library under `tianshou.env.atari`. + * Fix issues pertaining to the torch device assignment of network components (#810): + * Remove 'device' member (and the corresponding constructor argument) from the following classes: + `BranchingNet`, `C51Net`, `ContinuousActorDeterministic`, `ContinuousActorProb`, `ContinuousCritic`, + `DiscreteActor`, `DiscreteCritic`, `DQNet`, `FullQuantileFunction`, `ImplicitQuantileNetwork`, + `IntrinsicCuriosityModule`, `Net`, `MLP`, `Perturbation`, `QRDQNet`, `Rainbow`, `Recurrent`, + `RecurrentActorProb`, `RecurrentCritic`, `VAE` + * (Peripheral change:) Require the use of keyword arguments for the constructors of all of these classes ## Unreleased diff --git a/examples/atari/atari_c51.py b/examples/atari/atari_c51.py index 4c627dbbb..fbc485299 100644 --- a/examples/atari/atari_c51.py +++ b/examples/atari/atari_c51.py @@ -87,7 +87,8 @@ def main(args: argparse.Namespace = get_args()) -> None: torch.manual_seed(args.seed) # define model - net = C51Net(*args.state_shape, args.action_shape, args.num_atoms, args.device) + c, h, w = args.state_shape + net = C51Net(c=c, h=h, w=w, action_shape=args.action_shape, num_atoms=args.num_atoms) # define policy and algorithm optim = AdamOptimizerFactory(lr=args.lr) diff --git a/examples/atari/atari_dqn.py b/examples/atari/atari_dqn.py index 901bcf64f..96f23a574 100644 --- a/examples/atari/atari_dqn.py +++ b/examples/atari/atari_dqn.py @@ -106,7 +106,8 @@ def main(args: argparse.Namespace = get_args()) -> None: torch.manual_seed(args.seed) # define model - net = DQNet(*args.state_shape, args.action_shape, args.device).to(args.device) + c, h, w = args.state_shape + net = DQNet(c=c, h=h, w=w, action_shape=args.action_shape).to(args.device) optim = AdamOptimizerFactory(lr=args.lr) # define policy and algorithm @@ -123,7 +124,8 @@ def main(args: 
argparse.Namespace = get_args()) -> None: target_update_freq=args.target_update_freq, ) if args.icm_lr_scale > 0: - feature_net = DQNet(*args.state_shape, args.action_shape, args.device, features_only=True) + c, h, w = args.state_shape + feature_net = DQNet(c=c, h=h, w=w, action_shape=args.action_shape, features_only=True) action_dim = np.prod(args.action_shape) feature_dim = feature_net.output_dim icm_net = IntrinsicCuriosityModule( @@ -131,7 +133,6 @@ def main(args: argparse.Namespace = get_args()) -> None: feature_dim=feature_dim, action_dim=action_dim, hidden_sizes=[512], - device=args.device, ) icm_optim = AdamOptimizerFactory(lr=args.lr) algorithm = ICMOffPolicyWrapper( diff --git a/examples/atari/atari_fqf.py b/examples/atari/atari_fqf.py index b55e9804b..f8901583c 100644 --- a/examples/atari/atari_fqf.py +++ b/examples/atari/atari_fqf.py @@ -92,13 +92,13 @@ def main(args: argparse.Namespace = get_args()) -> None: torch.manual_seed(args.seed) # define model - feature_net = DQNet(*args.state_shape, args.action_shape, args.device, features_only=True) + c, h, w = args.state_shape + feature_net = DQNet(c=c, h=h, w=w, action_shape=args.action_shape, features_only=True) net = FullQuantileFunction( - feature_net, - args.action_shape, - args.hidden_sizes, - args.num_cosines, - device=args.device, + preprocess_net=feature_net, + action_shape=args.action_shape, + hidden_sizes=args.hidden_sizes, + num_cosines=args.num_cosines, ).to(args.device) optim = AdamOptimizerFactory(lr=args.lr) fraction_net = FractionProposalNetwork(args.num_fractions, net.input_dim) diff --git a/examples/atari/atari_iqn.py b/examples/atari/atari_iqn.py index 3bcbd133a..398031be8 100644 --- a/examples/atari/atari_iqn.py +++ b/examples/atari/atari_iqn.py @@ -91,13 +91,13 @@ def main(args: argparse.Namespace = get_args()) -> None: torch.manual_seed(args.seed) # define model - feature_net = DQNet(*args.state_shape, args.action_shape, args.device, features_only=True) + c, h, w = args.state_shape + 
feature_net = DQNet(c=c, h=h, w=w, action_shape=args.action_shape, features_only=True) net = ImplicitQuantileNetwork( - feature_net, - args.action_shape, - args.hidden_sizes, + preprocess_net=feature_net, + action_shape=args.action_shape, + hidden_sizes=args.hidden_sizes, num_cosines=args.num_cosines, - device=args.device, ).to(args.device) optim = AdamOptimizerFactory(lr=args.lr) diff --git a/examples/atari/atari_ppo.py b/examples/atari/atari_ppo.py index 9c65ad382..3c3a7b258 100644 --- a/examples/atari/atari_ppo.py +++ b/examples/atari/atari_ppo.py @@ -114,18 +114,20 @@ def main(args: argparse.Namespace = get_args()) -> None: np.random.seed(args.seed) torch.manual_seed(args.seed) # define model + c, h, w = args.state_shape net = DQNet( - *args.state_shape, - args.action_shape, - device=args.device, + c=c, + h=h, + w=w, + action_shape=args.action_shape, features_only=True, output_dim_added_layer=args.hidden_size, layer_init=layer_init, ) if args.scale_obs: net = scale_obs(net) - actor = DiscreteActor(net, args.action_shape, device=args.device, softmax_output=False) - critic = DiscreteCritic(net, device=args.device) + actor = DiscreteActor(preprocess_net=net, action_shape=args.action_shape, softmax_output=False) + critic = DiscreteCritic(preprocess_net=net) optim = AdamOptimizerFactory(lr=args.lr, eps=1e-5) if args.lr_decay: @@ -159,15 +161,15 @@ def main(args: argparse.Namespace = get_args()) -> None: recompute_advantage=args.recompute_adv, ).to(args.device) if args.icm_lr_scale > 0: - feature_net = DQNet(*args.state_shape, args.action_shape, args.device, features_only=True) + c, h, w = args.state_shape + feature_net = DQNet(c=c, h=h, w=w, action_shape=args.action_shape, features_only=True) action_dim = np.prod(args.action_shape) feature_dim = feature_net.output_dim icm_net = IntrinsicCuriosityModule( - feature_net.net, - feature_dim, - action_dim, + feature_net=feature_net.net, + feature_dim=feature_dim, + action_dim=action_dim, hidden_sizes=[args.hidden_size], - 
device=args.device, ) icm_optim = AdamOptimizerFactory(lr=args.lr) algorithm = ICMOnPolicyWrapper( # type: ignore[no-redef] diff --git a/examples/atari/atari_qrdqn.py b/examples/atari/atari_qrdqn.py index cd17a6ec8..05cff3103 100644 --- a/examples/atari/atari_qrdqn.py +++ b/examples/atari/atari_qrdqn.py @@ -94,7 +94,6 @@ def main(args: argparse.Namespace = get_args()) -> None: w=w, action_shape=args.action_shape, num_quantiles=args.num_quantiles, - device=args.device, ) # define policy and algorithm diff --git a/examples/atari/atari_rainbow.py b/examples/atari/atari_rainbow.py index a6ea5fe88..22a034394 100644 --- a/examples/atari/atari_rainbow.py +++ b/examples/atari/atari_rainbow.py @@ -102,12 +102,14 @@ def test_rainbow(args: argparse.Namespace = get_args()) -> None: torch.manual_seed(args.seed) # define model + c, h, w = args.state_shape net = Rainbow( - *args.state_shape, - args.action_shape, - args.num_atoms, - args.noisy_std, - args.device, + c=c, + h=h, + w=w, + action_shape=args.action_shape, + num_atoms=args.num_atoms, + noisy_std=args.noisy_std, is_dueling=not args.no_dueling, is_noisy=not args.no_noisy, ) diff --git a/examples/atari/atari_sac.py b/examples/atari/atari_sac.py index 0ca879d81..eebe3ca3b 100644 --- a/examples/atari/atari_sac.py +++ b/examples/atari/atari_sac.py @@ -120,11 +120,11 @@ def test_discrete_sac(args: argparse.Namespace = get_args()) -> None: features_only=True, output_dim_added_layer=args.hidden_size, ) - actor = DiscreteActor(net, args.action_shape, device=args.device, softmax_output=False) + actor = DiscreteActor(preprocess_net=net, action_shape=args.action_shape, softmax_output=False) actor_optim = AdamOptimizerFactory(lr=args.actor_lr) - critic1 = DiscreteCritic(net, last_size=args.action_shape, device=args.device) + critic1 = DiscreteCritic(preprocess_net=net, last_size=args.action_shape) critic1_optim = AdamOptimizerFactory(lr=args.critic_lr) - critic2 = DiscreteCritic(net, last_size=args.action_shape, device=args.device) + 
critic2 = DiscreteCritic(preprocess_net=net, last_size=args.action_shape) critic2_optim = AdamOptimizerFactory(lr=args.critic_lr) # define policy and algorithm @@ -151,15 +151,15 @@ def test_discrete_sac(args: argparse.Namespace = get_args()) -> None: estimation_step=args.n_step, ).to(args.device) if args.icm_lr_scale > 0: - feature_net = DQNet(*args.state_shape, args.action_shape, args.device, features_only=True) + c, h, w = args.state_shape + feature_net = DQNet(c=c, h=h, w=w, action_shape=args.action_shape, features_only=True) action_dim = np.prod(args.action_shape) feature_dim = feature_net.output_dim icm_net = IntrinsicCuriosityModule( - feature_net.net, - feature_dim, - action_dim, + feature_net=feature_net.net, + feature_dim=feature_dim, + action_dim=action_dim, hidden_sizes=[args.hidden_size], - device=args.device, ) icm_optim = AdamOptimizerFactory(lr=args.actor_lr) algorithm = ICMOffPolicyWrapper( diff --git a/examples/box2d/acrobot_dualdqn.py b/examples/box2d/acrobot_dualdqn.py index 37d9bb9df..07d592bed 100644 --- a/examples/box2d/acrobot_dualdqn.py +++ b/examples/box2d/acrobot_dualdqn.py @@ -73,7 +73,6 @@ def test_dqn(args: argparse.Namespace = get_args()) -> None: state_shape=args.state_shape, action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, - device=args.device, dueling_param=(Q_param, V_param), ) optim = AdamOptimizerFactory(lr=args.lr) diff --git a/examples/box2d/bipedal_bdq.py b/examples/box2d/bipedal_bdq.py index 5a3eabab4..9364d1282 100644 --- a/examples/box2d/bipedal_bdq.py +++ b/examples/box2d/bipedal_bdq.py @@ -94,13 +94,12 @@ def test_bdq(args: argparse.Namespace = get_args()) -> None: test_envs.seed(args.seed) # model net = BranchingNet( - args.state_shape, - args.num_branches, - args.action_per_branch, - args.common_hidden_sizes, - args.value_hidden_sizes, - args.action_hidden_sizes, - device=args.device, + state_shape=args.state_shape, + num_branches=args.num_branches, + action_per_branch=args.action_per_branch, + 
common_hidden_sizes=args.common_hidden_sizes, + value_hidden_sizes=args.value_hidden_sizes, + action_hidden_sizes=args.action_hidden_sizes, ).to(args.device) optim = AdamOptimizerFactory(lr=args.lr) policy = BDQNPolicy( diff --git a/examples/box2d/bipedal_hardcore_sac.py b/examples/box2d/bipedal_hardcore_sac.py index cbb154d6b..ce494ff90 100644 --- a/examples/box2d/bipedal_hardcore_sac.py +++ b/examples/box2d/bipedal_hardcore_sac.py @@ -110,11 +110,10 @@ def test_sac_bipedal(args: argparse.Namespace = get_args()) -> None: test_envs.seed(args.seed) # model - net_a = Net(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device) + net_a = Net(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes) actor = ContinuousActorProb( preprocess_net=net_a, action_shape=args.action_shape, - device=args.device, unbounded=True, ).to(args.device) actor_optim = AdamOptimizerFactory(lr=args.actor_lr) @@ -124,9 +123,8 @@ def test_sac_bipedal(args: argparse.Namespace = get_args()) -> None: action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, concat=True, - device=args.device, ) - critic1 = ContinuousCritic(net_c1, device=args.device).to(args.device) + critic1 = ContinuousCritic(preprocess_net=net_c1).to(args.device) critic1_optim = AdamOptimizerFactory(lr=args.critic_lr) net_c2 = Net( @@ -134,9 +132,8 @@ def test_sac_bipedal(args: argparse.Namespace = get_args()) -> None: action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, concat=True, - device=args.device, ) - critic2 = ContinuousCritic(net_c2, device=args.device).to(args.device) + critic2 = ContinuousCritic(preprocess_net=net_c2).to(args.device) critic2_optim = AdamOptimizerFactory(lr=args.critic_lr) action_dim = space_info.action_info.action_dim diff --git a/examples/box2d/lunarlander_dqn.py b/examples/box2d/lunarlander_dqn.py index eb55fb1ce..d1edb3336 100644 --- a/examples/box2d/lunarlander_dqn.py +++ b/examples/box2d/lunarlander_dqn.py @@ -75,7 +75,6 @@ def test_dqn(args: 
argparse.Namespace = get_args()) -> None: state_shape=args.state_shape, action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, - device=args.device, dueling_param=(Q_param, V_param), ).to(args.device) optim = AdamOptimizerFactory(lr=args.lr) diff --git a/examples/box2d/mcc_sac.py b/examples/box2d/mcc_sac.py index 0ecfa4ed8..9fa0310a8 100644 --- a/examples/box2d/mcc_sac.py +++ b/examples/box2d/mcc_sac.py @@ -68,28 +68,26 @@ def test_sac(args: argparse.Namespace = get_args()) -> None: train_envs.seed(args.seed) test_envs.seed(args.seed) # model - net = Net(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device) - actor = ContinuousActorProb(net, args.action_shape, device=args.device, unbounded=True).to( - args.device - ) + net = Net(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes) + actor = ContinuousActorProb( + preprocess_net=net, action_shape=args.action_shape, unbounded=True + ).to(args.device) actor_optim = AdamOptimizerFactory(lr=args.actor_lr) net_c1 = Net( state_shape=args.state_shape, action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, concat=True, - device=args.device, ) - critic1 = ContinuousCritic(net_c1, device=args.device).to(args.device) + critic1 = ContinuousCritic(preprocess_net=net_c1).to(args.device) critic1_optim = AdamOptimizerFactory(lr=args.critic_lr) net_c2 = Net( state_shape=args.state_shape, action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, concat=True, - device=args.device, ) - critic2 = ContinuousCritic(net_c2, device=args.device).to(args.device) + critic2 = ContinuousCritic(preprocess_net=net_c2).to(args.device) critic2_optim = AdamOptimizerFactory(lr=args.critic_lr) action_dim = space_info.action_info.action_dim diff --git a/examples/inverse/irl_gail.py b/examples/inverse/irl_gail.py index feec53c44..d56d770ee 100644 --- a/examples/inverse/irl_gail.py +++ b/examples/inverse/irl_gail.py @@ -123,24 +123,21 @@ def test_gail(args: argparse.Namespace = get_args()) -> 
None: test_envs.seed(args.seed) # model net_a = Net( - args.state_shape, + state_shape=args.state_shape, hidden_sizes=args.hidden_sizes, activation=nn.Tanh, - device=args.device, ) actor = ContinuousActorProb( preprocess_net=net_a, action_shape=args.action_shape, unbounded=True, - device=args.device, ).to(args.device) net_c = Net( - args.state_shape, + state_shape=args.state_shape, hidden_sizes=args.hidden_sizes, activation=nn.Tanh, - device=args.device, ) - critic = ContinuousCritic(net_c, device=args.device).to(args.device) + critic = ContinuousCritic(preprocess_net=net_c).to(args.device) torch.nn.init.constant_(actor.sigma_param, -0.5) for m in list(actor.modules()) + list(critic.modules()): if isinstance(m, torch.nn.Linear): @@ -158,14 +155,13 @@ def test_gail(args: argparse.Namespace = get_args()) -> None: optim = AdamOptimizerFactory(lr=args.lr) # discriminator net_d = Net( - args.state_shape, + state_shape=args.state_shape, action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, activation=nn.Tanh, - device=args.device, concat=True, ) - disc_net = ContinuousCritic(net_d, device=args.device).to(args.device) + disc_net = ContinuousCritic(preprocess_net=net_d).to(args.device) for m in disc_net.modules(): if isinstance(m, torch.nn.Linear): # orthogonal initialization diff --git a/examples/mujoco/mujoco_a2c.py b/examples/mujoco/mujoco_a2c.py index ee5f4f5d0..f09469d7d 100755 --- a/examples/mujoco/mujoco_a2c.py +++ b/examples/mujoco/mujoco_a2c.py @@ -90,24 +90,21 @@ def main(args: argparse.Namespace = get_args()) -> None: torch.manual_seed(args.seed) # model net_a = Net( - args.state_shape, + state_shape=args.state_shape, hidden_sizes=args.hidden_sizes, activation=nn.Tanh, - device=args.device, ) actor = ContinuousActorProb( - net_a, - args.action_shape, + preprocess_net=net_a, + action_shape=args.action_shape, unbounded=True, - device=args.device, ).to(args.device) net_c = Net( - args.state_shape, + state_shape=args.state_shape, 
hidden_sizes=args.hidden_sizes, activation=nn.Tanh, - device=args.device, ) - critic = ContinuousCritic(net_c, device=args.device).to(args.device) + critic = ContinuousCritic(preprocess_net=net_c).to(args.device) actor_critic = ActorCritic(actor, critic) torch.nn.init.constant_(actor.sigma_param, -0.5) diff --git a/examples/mujoco/mujoco_ddpg.py b/examples/mujoco/mujoco_ddpg.py index e201c435f..2e7bfa2b6 100755 --- a/examples/mujoco/mujoco_ddpg.py +++ b/examples/mujoco/mujoco_ddpg.py @@ -85,9 +85,9 @@ def main(args: argparse.Namespace = get_args()) -> None: np.random.seed(args.seed) torch.manual_seed(args.seed) # model - net_a = Net(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device) + net_a = Net(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes) actor = ContinuousActorDeterministic( - net_a, args.action_shape, max_action=args.max_action, device=args.device + preprocess_net=net_a, action_shape=args.action_shape, max_action=args.max_action ).to( args.device, ) @@ -97,9 +97,8 @@ def main(args: argparse.Namespace = get_args()) -> None: action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, concat=True, - device=args.device, ) - critic = ContinuousCritic(net_c, device=args.device).to(args.device) + critic = ContinuousCritic(preprocess_net=net_c).to(args.device) critic_optim = AdamOptimizerFactory(lr=args.critic_lr) policy = DDPGPolicy( actor=actor, diff --git a/examples/mujoco/mujoco_npg.py b/examples/mujoco/mujoco_npg.py index 85e74a1d8..702df3c78 100755 --- a/examples/mujoco/mujoco_npg.py +++ b/examples/mujoco/mujoco_npg.py @@ -95,24 +95,21 @@ def main(args: argparse.Namespace = get_args()) -> None: torch.manual_seed(args.seed) # model net_a = Net( - args.state_shape, + state_shape=args.state_shape, hidden_sizes=args.hidden_sizes, activation=nn.Tanh, - device=args.device, ) actor = ContinuousActorProb( - net_a, - args.action_shape, + preprocess_net=net_a, + action_shape=args.action_shape, unbounded=True, - 
device=args.device, ).to(args.device) net_c = Net( - args.state_shape, + state_shape=args.state_shape, hidden_sizes=args.hidden_sizes, activation=nn.Tanh, - device=args.device, ) - critic = ContinuousCritic(net_c, device=args.device).to(args.device) + critic = ContinuousCritic(preprocess_net=net_c).to(args.device) torch.nn.init.constant_(actor.sigma_param, -0.5) for m in list(actor.modules()) + list(critic.modules()): if isinstance(m, torch.nn.Linear): diff --git a/examples/mujoco/mujoco_ppo.py b/examples/mujoco/mujoco_ppo.py index 2b2fb550b..c41821ef1 100755 --- a/examples/mujoco/mujoco_ppo.py +++ b/examples/mujoco/mujoco_ppo.py @@ -95,24 +95,21 @@ def main(args: argparse.Namespace = get_args()) -> None: torch.manual_seed(args.seed) # model net_a = Net( - args.state_shape, + state_shape=args.state_shape, hidden_sizes=args.hidden_sizes, activation=nn.Tanh, - device=args.device, ) actor = ContinuousActorProb( - net_a, - args.action_shape, + preprocess_net=net_a, + action_shape=args.action_shape, unbounded=True, - device=args.device, ).to(args.device) net_c = Net( - args.state_shape, + state_shape=args.state_shape, hidden_sizes=args.hidden_sizes, activation=nn.Tanh, - device=args.device, ) - critic = ContinuousCritic(net_c, device=args.device).to(args.device) + critic = ContinuousCritic(preprocess_net=net_c).to(args.device) actor_critic = ActorCritic(actor, critic) torch.nn.init.constant_(actor.sigma_param, -0.5) diff --git a/examples/mujoco/mujoco_redq.py b/examples/mujoco/mujoco_redq.py index b3965de47..6d8ae40d7 100755 --- a/examples/mujoco/mujoco_redq.py +++ b/examples/mujoco/mujoco_redq.py @@ -89,11 +89,10 @@ def main(args: argparse.Namespace = get_args()) -> None: np.random.seed(args.seed) torch.manual_seed(args.seed) # model - net_a = Net(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device) + net_a = Net(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes) actor = ContinuousActorProb( - net_a, - args.action_shape, - 
device=args.device, + preprocess_net=net_a, + action_shape=args.action_shape, unbounded=True, conditioned_sigma=True, ).to(args.device) @@ -107,12 +106,10 @@ def linear(x: int, y: int) -> EnsembleLinear: action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, concat=True, - device=args.device, linear_layer=linear, ) critics = ContinuousCritic( - net_c, - device=args.device, + preprocess_net=net_c, linear_layer=linear, flatten_input=False, ).to(args.device) diff --git a/examples/mujoco/mujoco_reinforce.py b/examples/mujoco/mujoco_reinforce.py index 14ae0b84d..2932eeef8 100755 --- a/examples/mujoco/mujoco_reinforce.py +++ b/examples/mujoco/mujoco_reinforce.py @@ -87,16 +87,14 @@ def main(args: argparse.Namespace = get_args()) -> None: torch.manual_seed(args.seed) # model net_a = Net( - args.state_shape, + state_shape=args.state_shape, hidden_sizes=args.hidden_sizes, activation=nn.Tanh, - device=args.device, ) actor = ContinuousActorProb( - net_a, - args.action_shape, + preprocess_net=net_a, + action_shape=args.action_shape, unbounded=True, - device=args.device, ).to(args.device) torch.nn.init.constant_(actor.sigma_param, -0.5) for m in actor.modules(): diff --git a/examples/mujoco/mujoco_sac.py b/examples/mujoco/mujoco_sac.py index ef881b4ee..1c57aa955 100755 --- a/examples/mujoco/mujoco_sac.py +++ b/examples/mujoco/mujoco_sac.py @@ -85,11 +85,10 @@ def main(args: argparse.Namespace = get_args()) -> None: np.random.seed(args.seed) torch.manual_seed(args.seed) # model - net_a = Net(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device) + net_a = Net(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes) actor = ContinuousActorProb( - net_a, - args.action_shape, - device=args.device, + preprocess_net=net_a, + action_shape=args.action_shape, unbounded=True, conditioned_sigma=True, ).to(args.device) @@ -99,18 +98,16 @@ def main(args: argparse.Namespace = get_args()) -> None: action_shape=args.action_shape, 
hidden_sizes=args.hidden_sizes, concat=True, - device=args.device, ) net_c2 = Net( state_shape=args.state_shape, action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, concat=True, - device=args.device, ) - critic1 = ContinuousCritic(net_c1, device=args.device).to(args.device) + critic1 = ContinuousCritic(preprocess_net=net_c1).to(args.device) critic1_optim = AdamOptimizerFactory(lr=args.critic_lr) - critic2 = ContinuousCritic(net_c2, device=args.device).to(args.device) + critic2 = ContinuousCritic(preprocess_net=net_c2).to(args.device) critic2_optim = AdamOptimizerFactory(lr=args.critic_lr) if args.auto_alpha: diff --git a/examples/mujoco/mujoco_td3.py b/examples/mujoco/mujoco_td3.py index 18babdc4f..219147288 100755 --- a/examples/mujoco/mujoco_td3.py +++ b/examples/mujoco/mujoco_td3.py @@ -90,9 +90,9 @@ def main(args: argparse.Namespace = get_args()) -> None: np.random.seed(args.seed) torch.manual_seed(args.seed) # model - net_a = Net(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device) + net_a = Net(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes) actor = ContinuousActorDeterministic( - net_a, args.action_shape, max_action=args.max_action, device=args.device + preprocess_net=net_a, action_shape=args.action_shape, max_action=args.max_action ).to( args.device, ) @@ -102,18 +102,16 @@ def main(args: argparse.Namespace = get_args()) -> None: action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, concat=True, - device=args.device, ) net_c2 = Net( state_shape=args.state_shape, action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, concat=True, - device=args.device, ) - critic1 = ContinuousCritic(net_c1, device=args.device).to(args.device) + critic1 = ContinuousCritic(preprocess_net=net_c1).to(args.device) critic1_optim = AdamOptimizerFactory(lr=args.critic_lr) - critic2 = ContinuousCritic(net_c2, device=args.device).to(args.device) + critic2 = ContinuousCritic(preprocess_net=net_c2).to(args.device) 
critic2_optim = AdamOptimizerFactory(lr=args.critic_lr) policy = DDPGPolicy( diff --git a/examples/mujoco/mujoco_trpo.py b/examples/mujoco/mujoco_trpo.py index 284b05747..430fbd56a 100755 --- a/examples/mujoco/mujoco_trpo.py +++ b/examples/mujoco/mujoco_trpo.py @@ -98,24 +98,21 @@ def test_trpo(args: argparse.Namespace = get_args()) -> None: torch.manual_seed(args.seed) # model net_a = Net( - args.state_shape, + state_shape=args.state_shape, hidden_sizes=args.hidden_sizes, activation=nn.Tanh, - device=args.device, ) actor = ContinuousActorProb( - net_a, - args.action_shape, + preprocess_net=net_a, + action_shape=args.action_shape, unbounded=True, - device=args.device, ).to(args.device) net_c = Net( - args.state_shape, + state_shape=args.state_shape, hidden_sizes=args.hidden_sizes, activation=nn.Tanh, - device=args.device, ) - critic = ContinuousCritic(net_c, device=args.device).to(args.device) + critic = ContinuousCritic(preprocess_net=net_c).to(args.device) torch.nn.init.constant_(actor.sigma_param, -0.5) for m in list(actor.modules()) + list(critic.modules()): if isinstance(m, torch.nn.Linear): diff --git a/examples/offline/atari_bcq.py b/examples/offline/atari_bcq.py index 9e6fb1ed9..b962d25aa 100644 --- a/examples/offline/atari_bcq.py +++ b/examples/offline/atari_bcq.py @@ -98,24 +98,21 @@ def main(args: argparse.Namespace = get_args()) -> None: assert len(args.state_shape) == 3 c, h, w = args.state_shape feature_net = DQNet( - c, - h, - w, - args.action_shape, - device=args.device, + c=c, + h=h, + w=w, + action_shape=args.action_shape, features_only=True, ).to(args.device) policy_net = DiscreteActor( - feature_net, - args.action_shape, - device=args.device, + preprocess_net=feature_net, + action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, softmax_output=False, ).to(args.device) imitation_net = DiscreteActor( - feature_net, - args.action_shape, - device=args.device, + preprocess_net=feature_net, + action_shape=args.action_shape, 
hidden_sizes=args.hidden_sizes, softmax_output=False, ).to(args.device) diff --git a/examples/offline/atari_cql.py b/examples/offline/atari_cql.py index 2f7849c1b..dae6c665b 100644 --- a/examples/offline/atari_cql.py +++ b/examples/offline/atari_cql.py @@ -105,7 +105,6 @@ def main(args: argparse.Namespace = get_args()) -> None: w=w, action_shape=args.action_shape, num_quantiles=args.num_quantiles, - device=args.device, ) optim = AdamOptimizerFactory(lr=args.lr) # define policy diff --git a/examples/offline/atari_crr.py b/examples/offline/atari_crr.py index 8c98631b0..28a188514 100644 --- a/examples/offline/atari_crr.py +++ b/examples/offline/atari_crr.py @@ -100,25 +100,22 @@ def main(args: argparse.Namespace = get_args()) -> None: assert len(args.state_shape) == 3 c, h, w = args.state_shape feature_net = DQNet( - c, - h, - w, - args.action_shape, - device=args.device, + c=c, + h=h, + w=w, + action_shape=args.action_shape, features_only=True, ).to(args.device) actor = DiscreteActor( - feature_net, - args.action_shape, + preprocess_net=feature_net, + action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, - device=args.device, softmax_output=False, ).to(args.device) critic = DiscreteCritic( - feature_net, + preprocess_net=feature_net, hidden_sizes=args.hidden_sizes, last_size=int(np.prod(args.action_shape)), - device=args.device, ).to(args.device) optim = AdamOptimizerFactory(lr=args.lr) # define policy and algorithm diff --git a/examples/offline/atari_il.py b/examples/offline/atari_il.py index 52f61679a..885267efa 100644 --- a/examples/offline/atari_il.py +++ b/examples/offline/atari_il.py @@ -88,7 +88,7 @@ def test_il(args: argparse.Namespace = get_args()) -> None: np.random.seed(args.seed) torch.manual_seed(args.seed) # model - net = DQNet(c, h, w, args.action_shape, device=args.device).to(args.device) + net = DQNet(c=c, h=h, w=w, action_shape=args.action_shape).to(args.device) optim = AdamOptimizerFactory(lr=args.lr) # define policy policy = 
ImitationPolicy(actor=net, action_space=env.action_space) diff --git a/examples/offline/d4rl_bcq.py b/examples/offline/d4rl_bcq.py index 6e6513599..252d915e8 100644 --- a/examples/offline/d4rl_bcq.py +++ b/examples/offline/d4rl_bcq.py @@ -103,9 +103,8 @@ def test_bcq() -> None: input_dim=args.state_dim + args.action_dim, output_dim=args.action_dim, hidden_sizes=args.hidden_sizes, - device=args.device, ) - actor = Perturbation(net_a, max_action=args.max_action, device=args.device, phi=args.phi).to( + actor = Perturbation(preprocess_net=net_a, max_action=args.max_action, phi=args.phi).to( args.device, ) actor_optim = AdamOptimizerFactory(lr=args.actor_lr) @@ -115,18 +114,16 @@ def test_bcq() -> None: action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, concat=True, - device=args.device, ) net_c2 = Net( state_shape=args.state_shape, action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, concat=True, - device=args.device, ) - critic1 = ContinuousCritic(net_c1, device=args.device).to(args.device) + critic1 = ContinuousCritic(preprocess_net=net_c1).to(args.device) critic1_optim = AdamOptimizerFactory(lr=args.critic_lr) - critic2 = ContinuousCritic(net_c2, device=args.device).to(args.device) + critic2 = ContinuousCritic(preprocess_net=net_c2).to(args.device) critic2_optim = AdamOptimizerFactory(lr=args.critic_lr) # vae @@ -134,7 +131,6 @@ def test_bcq() -> None: vae_encoder = MLP( input_dim=args.state_dim + args.action_dim, hidden_sizes=args.vae_hidden_sizes, - device=args.device, ) if not args.latent_dim: args.latent_dim = args.action_dim * 2 @@ -142,15 +138,13 @@ def test_bcq() -> None: input_dim=args.state_dim + args.latent_dim, output_dim=args.action_dim, hidden_sizes=args.vae_hidden_sizes, - device=args.device, ) vae = VAE( - vae_encoder, - vae_decoder, + encoder=vae_encoder, + decoder=vae_decoder, hidden_dim=args.vae_hidden_sizes[-1], latent_dim=args.latent_dim, max_action=args.max_action, - device=args.device, ).to(args.device) vae_optim = 
AdamOptimizerFactory() diff --git a/examples/offline/d4rl_cql.py b/examples/offline/d4rl_cql.py index bd607b06b..64eccfd2e 100644 --- a/examples/offline/d4rl_cql.py +++ b/examples/offline/d4rl_cql.py @@ -249,12 +249,10 @@ def test_cql() -> None: state_shape=args.state_shape, action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, - device=args.device, ) actor = ContinuousActorProb( - net_a, + preprocess_net=net_a, action_shape=args.action_shape, - device=args.device, unbounded=True, conditioned_sigma=True, ).to(args.device) @@ -266,18 +264,16 @@ def test_cql() -> None: action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, concat=True, - device=args.device, ) net_c2 = Net( state_shape=args.state_shape, action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, concat=True, - device=args.device, ) - critic = ContinuousCritic(net_c1, device=args.device).to(args.device) + critic = ContinuousCritic(preprocess_net=net_c1).to(args.device) critic_optim = AdamOptimizerFactory(lr=args.critic_lr) - critic2 = ContinuousCritic(net_c2, device=args.device).to(args.device) + critic2 = ContinuousCritic(preprocess_net=net_c2).to(args.device) critic2_optim = AdamOptimizerFactory(lr=args.critic_lr) if args.auto_alpha: diff --git a/examples/offline/d4rl_il.py b/examples/offline/d4rl_il.py index 396557d83..8a4a9c520 100644 --- a/examples/offline/d4rl_il.py +++ b/examples/offline/d4rl_il.py @@ -87,13 +87,11 @@ def test_il() -> None: state_shape=args.state_shape, action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, - device=args.device, ) actor = ContinuousActorDeterministic( - net, + preprocess_net=net, action_shape=args.action_shape, max_action=args.max_action, - device=args.device, ).to(args.device) optim = AdamOptimizerFactory(lr=args.lr) diff --git a/examples/offline/d4rl_td3_bc.py b/examples/offline/d4rl_td3_bc.py index 7b46fe326..772a90f0d 100644 --- a/examples/offline/d4rl_td3_bc.py +++ b/examples/offline/d4rl_td3_bc.py @@ -105,15 +105,13 @@ def 
test_td3_bc() -> None: # model # actor network net_a = Net( - args.state_shape, + state_shape=args.state_shape, hidden_sizes=args.hidden_sizes, - device=args.device, ) actor = ContinuousActorDeterministic( - net_a, + preprocess_net=net_a, action_shape=args.action_shape, max_action=args.max_action, - device=args.device, ).to(args.device) actor_optim = AdamOptimizerFactory(lr=args.actor_lr) @@ -123,18 +121,16 @@ def test_td3_bc() -> None: action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, concat=True, - device=args.device, ) net_c2 = Net( state_shape=args.state_shape, action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, concat=True, - device=args.device, ) - critic1 = ContinuousCritic(net_c1, device=args.device).to(args.device) + critic1 = ContinuousCritic(preprocess_net=net_c1).to(args.device) critic1_optim = AdamOptimizerFactory(lr=args.critic_lr) - critic2 = ContinuousCritic(net_c2, device=args.device).to(args.device) + critic2 = ContinuousCritic(preprocess_net=net_c2).to(args.device) critic2_optim = AdamOptimizerFactory(lr=args.critic_lr) policy = DDPGPolicy( diff --git a/examples/vizdoom/vizdoom_c51.py b/examples/vizdoom/vizdoom_c51.py index a8cfca26f..dcc17e2a3 100644 --- a/examples/vizdoom/vizdoom_c51.py +++ b/examples/vizdoom/vizdoom_c51.py @@ -94,7 +94,8 @@ def test_c51(args: argparse.Namespace = get_args()) -> None: np.random.seed(args.seed) torch.manual_seed(args.seed) # define model - net = C51Net(*args.state_shape, args.action_shape, args.num_atoms, args.device) + c, h, w = args.state_shape + net = C51Net(c=c, h=h, w=w, action_shape=args.action_shape, num_atoms=args.num_atoms) optim = AdamOptimizerFactory(lr=args.lr) # define policy and algorithm policy = C51Policy( diff --git a/examples/vizdoom/vizdoom_ppo.py b/examples/vizdoom/vizdoom_ppo.py index a79bbaec1..b55202861 100644 --- a/examples/vizdoom/vizdoom_ppo.py +++ b/examples/vizdoom/vizdoom_ppo.py @@ -123,15 +123,17 @@ def test_ppo(args: argparse.Namespace = get_args()) -> 
None: np.random.seed(args.seed) torch.manual_seed(args.seed) # define model + c, h, w = args.state_shape net = DQNet( - *args.state_shape, - args.action_shape, - device=args.device, + c=c, + h=h, + w=w, + action_shape=args.action_shape, features_only=True, output_dim=args.hidden_size, ) - actor = DiscreteActor(net, args.action_shape, device=args.device, softmax_output=False) - critic = DiscreteCritic(net, device=args.device) + actor = DiscreteActor(preprocess_net=net, action_shape=args.action_shape, softmax_output=False) + critic = DiscreteCritic(preprocess_net=net) optim = AdamOptimizerFactory(lr=args.lr) if args.lr_decay: @@ -170,20 +172,21 @@ def dist(logits: torch.Tensor) -> Categorical: recompute_advantage=args.recompute_adv, ).to(args.device) if args.icm_lr_scale > 0: + c, h, w = args.state_shape feature_net = DQNet( - *args.state_shape, - args.action_shape, - device=args.device, + c=c, + h=h, + w=w, + action_shape=args.action_shape, features_only=True, output_dim=args.hidden_size, ) action_dim = np.prod(args.action_shape) feature_dim = feature_net.output_dim icm_net = IntrinsicCuriosityModule( - feature_net.net, - feature_dim, - action_dim, - device=args.device, + feature_net=feature_net.net, + feature_dim=feature_dim, + action_dim=action_dim, ) icm_optim = AdamOptimizerFactory(lr=args.lr) algorithm = ICMOnPolicyWrapper( diff --git a/test/base/test_policy.py b/test/base/test_policy.py index 57859e303..bd94e4ecf 100644 --- a/test/base/test_policy.py +++ b/test/base/test_policy.py @@ -35,7 +35,9 @@ def algorithm(request: pytest.FixtureRequest) -> PPO: if action_type == "continuous": action_space = gym.spaces.Box(low=-1, high=1, shape=(3,)) actor = ContinuousActorProb( - Net(state_shape=obs_shape, hidden_sizes=[64, 64], action_shape=action_space.shape), + preprocess_net=Net( + state_shape=obs_shape, hidden_sizes=[64, 64], action_shape=action_space.shape + ), action_shape=action_space.shape, ) @@ -46,7 +48,9 @@ def dist_fn(loc_scale: tuple[torch.Tensor, 
torch.Tensor]) -> Distribution: elif action_type == "discrete": action_space = gym.spaces.Discrete(3) actor = DiscreteActor( - Net(state_shape=obs_shape, hidden_sizes=[64, 64], action_shape=action_space.n), + preprocess_net=Net( + state_shape=obs_shape, hidden_sizes=[64, 64], action_shape=action_space.n + ), action_shape=action_space.n, ) dist_fn = Categorical @@ -54,7 +58,7 @@ def dist_fn(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution: raise ValueError(f"Unknown action type: {action_type}") critic = ContinuousCritic( - Net(obs_shape, hidden_sizes=[64, 64]), + preprocess_net=Net(state_shape=obs_shape, hidden_sizes=[64, 64]), ) optim = AdamOptimizerFactory(lr=1e-3) diff --git a/test/base/test_utils.py b/test/base/test_utils.py index faa93ee84..6c992a165 100644 --- a/test/base/test_utils.py +++ b/test/base/test_utils.py @@ -52,10 +52,10 @@ def test_net() -> None: bsz = 64 # MLP data = torch.rand([bsz, 3]) - mlp = MLP(3, 6, hidden_sizes=[128]) + mlp = MLP(input_dim=3, output_dim=6, hidden_sizes=[128]) assert list(mlp(data).shape) == [bsz, 6] # output == 0 and len(hidden_sizes) == 0 means identity model - mlp = MLP(6, 0) + mlp = MLP(input_dim=6, output_dim=0) assert data.shape == mlp(data).shape # common net state_shape = (10, 2) @@ -63,8 +63,8 @@ def test_net() -> None: data = torch.rand([bsz, *state_shape]) expect_output_shape = [bsz, *action_shape] net = Net( - state_shape, - action_shape, + state_shape=state_shape, + action_shape=action_shape, hidden_sizes=[128, 128], norm_layer=torch.nn.LayerNorm, activation=None, @@ -74,20 +74,20 @@ def test_net() -> None: assert str(net).count("ReLU") == 0 Q_param = V_param = {"hidden_sizes": [128, 128]} net = Net( - state_shape, - action_shape, + state_shape=state_shape, + action_shape=action_shape, hidden_sizes=[128, 128], dueling_param=(Q_param, V_param), ) assert list(net(data)[0].shape) == expect_output_shape # concat - net = Net(state_shape, action_shape, hidden_sizes=[128], concat=True) + net = 
Net(state_shape=state_shape, action_shape=action_shape, hidden_sizes=[128], concat=True) data = torch.rand([bsz, int(np.prod(state_shape)) + int(np.prod(action_shape))]) expect_output_shape = [bsz, 128] assert list(net(data)[0].shape) == expect_output_shape net = Net( - state_shape, - action_shape, + state_shape=state_shape, + action_shape=action_shape, hidden_sizes=[128], concat=True, dueling_param=(Q_param, V_param), @@ -96,11 +96,11 @@ def test_net() -> None: # recurrent actor/critic data = torch.rand([bsz, *state_shape]).flatten(1) expect_output_shape = [bsz, *action_shape] - net = RecurrentActorProb(3, state_shape, action_shape) + net = RecurrentActorProb(layer_num=3, state_shape=state_shape, action_shape=action_shape) mu, sigma = net(data)[0] assert mu.shape == sigma.shape assert list(mu.shape) == [bsz, 5] - net = RecurrentCritic(3, state_shape, action_shape) + net = RecurrentCritic(layer_num=3, state_shape=state_shape, action_shape=action_shape) data = torch.rand([bsz, 8, int(np.prod(state_shape))]) act = torch.rand(expect_output_shape) assert list(net(data, act).shape) == [bsz, 1] diff --git a/test/continuous/test_ddpg.py b/test/continuous/test_ddpg.py index 08b0ff729..698fbaeca 100644 --- a/test/continuous/test_ddpg.py +++ b/test/continuous/test_ddpg.py @@ -73,9 +73,9 @@ def test_ddpg(args: argparse.Namespace = get_args()) -> None: test_envs.seed(args.seed) # model - net = Net(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device) + net = Net(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes) actor = ContinuousActorDeterministic( - net, args.action_shape, max_action=args.max_action, device=args.device + preprocess_net=net, action_shape=args.action_shape, max_action=args.max_action ).to( args.device, ) @@ -84,9 +84,8 @@ def test_ddpg(args: argparse.Namespace = get_args()) -> None: action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, concat=True, - device=args.device, ) - critic = ContinuousCritic(net, 
device=args.device).to(args.device) + critic = ContinuousCritic(preprocess_net=net).to(args.device) critic_optim = AdamOptimizerFactory(lr=args.critic_lr) policy = DDPGPolicy( actor=actor, diff --git a/test/continuous/test_npg.py b/test/continuous/test_npg.py index 586022cc0..0c9222b6c 100644 --- a/test/continuous/test_npg.py +++ b/test/continuous/test_npg.py @@ -79,22 +79,19 @@ def test_npg(args: argparse.Namespace = get_args()) -> None: # model net = Net( - args.state_shape, + state_shape=args.state_shape, hidden_sizes=args.hidden_sizes, activation=nn.Tanh, - device=args.device, - ) - actor = ContinuousActorProb(net, args.action_shape, unbounded=True, device=args.device).to( - args.device ) + actor = ContinuousActorProb( + preprocess_net=net, action_shape=args.action_shape, unbounded=True + ).to(args.device) critic = ContinuousCritic( - Net( - args.state_shape, + preprocess_net=Net( + state_shape=args.state_shape, hidden_sizes=args.hidden_sizes, - device=args.device, activation=nn.Tanh, ), - device=args.device, ).to(args.device) # orthogonal initialization diff --git a/test/continuous/test_ppo.py b/test/continuous/test_ppo.py index 7775aa338..3c6edf6c3 100644 --- a/test/continuous/test_ppo.py +++ b/test/continuous/test_ppo.py @@ -83,13 +83,12 @@ def test_ppo(args: argparse.Namespace = get_args()) -> None: test_envs.seed(args.seed) # model - net = Net(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device) - actor = ContinuousActorProb(net, args.action_shape, unbounded=True, device=args.device).to( - args.device - ) + net = Net(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes) + actor = ContinuousActorProb( + preprocess_net=net, action_shape=args.action_shape, unbounded=True + ).to(args.device) critic = ContinuousCritic( - Net(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device), - device=args.device, + preprocess_net=Net(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes), 
).to(args.device) actor_critic = ActorCritic(actor, critic) # orthogonal initialization diff --git a/test/continuous/test_redq.py b/test/continuous/test_redq.py index 2d16a646c..976103419 100644 --- a/test/continuous/test_redq.py +++ b/test/continuous/test_redq.py @@ -81,11 +81,10 @@ def test_redq(args: argparse.Namespace = get_args()) -> None: train_envs.seed(args.seed) test_envs.seed(args.seed) # model - net = Net(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device) + net = Net(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes) actor = ContinuousActorProb( - net, - args.action_shape, - device=args.device, + preprocess_net=net, + action_shape=args.action_shape, unbounded=True, conditioned_sigma=True, ).to(args.device) @@ -99,12 +98,9 @@ def linear(x: int, y: int) -> nn.Module: action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, concat=True, - device=args.device, linear_layer=linear, ) - critic = ContinuousCritic( - net_c, device=args.device, linear_layer=linear, flatten_input=False - ).to( + critic = ContinuousCritic(preprocess_net=net_c, linear_layer=linear, flatten_input=False).to( args.device, ) critic_optim = AdamOptimizerFactory(lr=args.critic_lr) diff --git a/test/continuous/test_sac_with_il.py b/test/continuous/test_sac_with_il.py index 44d622914..b8320d42f 100644 --- a/test/continuous/test_sac_with_il.py +++ b/test/continuous/test_sac_with_il.py @@ -89,28 +89,26 @@ def test_sac_with_il(args: argparse.Namespace = get_args()) -> None: test_envs.seed(args.seed + args.training_num) # model - net = Net(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device) - actor = ContinuousActorProb(net, args.action_shape, device=args.device, unbounded=True).to( - args.device - ) + net = Net(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes) + actor = ContinuousActorProb( + preprocess_net=net, action_shape=args.action_shape, unbounded=True + ).to(args.device) actor_optim = 
AdamOptimizerFactory(lr=args.actor_lr) net_c1 = Net( state_shape=args.state_shape, action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, concat=True, - device=args.device, ) - critic1 = ContinuousCritic(net_c1, device=args.device).to(args.device) + critic1 = ContinuousCritic(preprocess_net=net_c1).to(args.device) critic1_optim = AdamOptimizerFactory(lr=args.critic_lr) net_c2 = Net( state_shape=args.state_shape, action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, concat=True, - device=args.device, ) - critic2 = ContinuousCritic(net_c2, device=args.device).to(args.device) + critic2 = ContinuousCritic(preprocess_net=net_c2).to(args.device) critic2_optim = AdamOptimizerFactory(lr=args.critic_lr) action_dim = space_info.action_info.action_dim if args.auto_alpha: @@ -177,15 +175,13 @@ def stop_fn(mean_rewards: float) -> bool: if args.task.startswith("Pendulum"): args.reward_threshold -= 50 # lower the goal il_net = Net( - args.state_shape, + state_shape=args.state_shape, hidden_sizes=args.imitation_hidden_sizes, - device=args.device, ) il_actor = ContinuousActorDeterministic( - il_net, - args.action_shape, + preprocess_net=il_net, + action_shape=args.action_shape, max_action=args.max_action, - device=args.device, ).to(args.device) optim = AdamOptimizerFactory(lr=args.il_lr) il_policy = ImitationPolicy( diff --git a/test/continuous/test_td3.py b/test/continuous/test_td3.py index 5c884cd77..e718f99eb 100644 --- a/test/continuous/test_td3.py +++ b/test/continuous/test_td3.py @@ -76,9 +76,9 @@ def test_td3(args: argparse.Namespace = get_args()) -> None: train_envs.seed(args.seed) test_envs.seed(args.seed) # model - net = Net(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device) + net = Net(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes) actor = ContinuousActorDeterministic( - net, args.action_shape, max_action=args.max_action, device=args.device + preprocess_net=net, action_shape=args.action_shape, 
max_action=args.max_action ).to( args.device, ) @@ -88,18 +88,16 @@ def test_td3(args: argparse.Namespace = get_args()) -> None: action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, concat=True, - device=args.device, ) - critic1 = ContinuousCritic(net_c1, device=args.device).to(args.device) + critic1 = ContinuousCritic(preprocess_net=net_c1).to(args.device) critic1_optim = AdamOptimizerFactory(lr=args.critic_lr) net_c2 = Net( state_shape=args.state_shape, action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, concat=True, - device=args.device, ) - critic2 = ContinuousCritic(net_c2, device=args.device).to(args.device) + critic2 = ContinuousCritic(preprocess_net=net_c2).to(args.device) critic2_optim = AdamOptimizerFactory(lr=args.critic_lr) policy = DDPGPolicy( actor=actor, diff --git a/test/continuous/test_trpo.py b/test/continuous/test_trpo.py index 3284d8468..518a42ddd 100644 --- a/test/continuous/test_trpo.py +++ b/test/continuous/test_trpo.py @@ -79,22 +79,19 @@ def test_trpo(args: argparse.Namespace = get_args()) -> None: test_envs.seed(args.seed) # model net = Net( - args.state_shape, + state_shape=args.state_shape, hidden_sizes=args.hidden_sizes, activation=nn.Tanh, - device=args.device, - ) - actor = ContinuousActorProb(net, args.action_shape, unbounded=True, device=args.device).to( - args.device ) + actor = ContinuousActorProb( + preprocess_net=net, action_shape=args.action_shape, unbounded=True + ).to(args.device) critic = ContinuousCritic( - Net( - args.state_shape, + preprocess_net=Net( + state_shape=args.state_shape, hidden_sizes=args.hidden_sizes, - device=args.device, activation=nn.Tanh, ), - device=args.device, ).to(args.device) # orthogonal initialization for m in list(actor.modules()) + list(critic.modules()): diff --git a/test/discrete/test_a2c_with_il.py b/test/discrete/test_a2c_with_il.py index 318c19939..194a4b4f1 100644 --- a/test/discrete/test_a2c_with_il.py +++ b/test/discrete/test_a2c_with_il.py @@ -92,9 +92,9 @@ def 
test_a2c_with_il(args: argparse.Namespace = get_args()) -> None: default_reward_threshold = {"CartPole-v1": 195} args.reward_threshold = default_reward_threshold.get(args.task, env.spec.reward_threshold) # model - net = Net(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device) - actor = DiscreteActor(net, args.action_shape, device=args.device).to(args.device) - critic = DiscreteCritic(net, device=args.device).to(args.device) + net = Net(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes) + actor = DiscreteActor(preprocess_net=net, action_shape=args.action_shape).to(args.device) + critic = DiscreteCritic(preprocess_net=net).to(args.device) optim = AdamOptimizerFactory(lr=args.lr) dist = torch.distributions.Categorical policy = ActorPolicy( @@ -157,8 +157,8 @@ def stop_fn(mean_rewards: float) -> bool: # here we define an imitation collector with a trivial policy # if args.task == 'CartPole-v1': # env.spec.reward_threshold = 190 # lower the goal - net = Net(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device) - actor = DiscreteActor(net, args.action_shape, device=args.device).to(args.device) + net = Net(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes) + actor = DiscreteActor(preprocess_net=net, action_shape=args.action_shape).to(args.device) optim = AdamOptimizerFactory(lr=args.il_lr) il_policy = ImitationPolicy( actor=actor, diff --git a/test/discrete/test_bdqn.py b/test/discrete/test_bdqn.py index a292db23f..eca1e1df9 100644 --- a/test/discrete/test_bdqn.py +++ b/test/discrete/test_bdqn.py @@ -91,13 +91,12 @@ def test_bdq(args: argparse.Namespace = get_args()) -> None: test_envs.seed(args.seed) # model net = BranchingNet( - args.state_shape, - args.num_branches, - args.action_per_branch, - args.common_hidden_sizes, - args.value_hidden_sizes, - args.action_hidden_sizes, - device=args.device, + state_shape=args.state_shape, + num_branches=args.num_branches, + 
action_per_branch=args.action_per_branch, + common_hidden_sizes=args.common_hidden_sizes, + value_hidden_sizes=args.value_hidden_sizes, + action_hidden_sizes=args.action_hidden_sizes, ).to(args.device) optim = AdamOptimizerFactory(lr=args.lr) policy = BDQNPolicy( diff --git a/test/discrete/test_c51.py b/test/discrete/test_c51.py index de0b69396..240650c41 100644 --- a/test/discrete/test_c51.py +++ b/test/discrete/test_c51.py @@ -89,7 +89,6 @@ def test_c51(args: argparse.Namespace = get_args()) -> None: state_shape=args.state_shape, action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, - device=args.device, softmax=True, num_atoms=args.num_atoms, ) diff --git a/test/discrete/test_discrete_sac.py b/test/discrete/test_discrete_sac.py index 63c9e1723..2744d5e02 100644 --- a/test/discrete/test_discrete_sac.py +++ b/test/discrete/test_discrete_sac.py @@ -80,16 +80,16 @@ def test_discrete_sac(args: argparse.Namespace = get_args()) -> None: # model obs_dim = space_info.observation_info.obs_dim action_dim = space_info.action_info.action_dim - net = Net(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device) - actor = DiscreteActor(net, args.action_shape, softmax_output=False, device=args.device).to( - args.device - ) + net = Net(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes) + actor = DiscreteActor( + preprocess_net=net, action_shape=args.action_shape, softmax_output=False + ).to(args.device) actor_optim = AdamOptimizerFactory(lr=args.actor_lr) - net_c1 = Net(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device) - critic1 = DiscreteCritic(net_c1, last_size=action_dim, device=args.device).to(args.device) + net_c1 = Net(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes) + critic1 = DiscreteCritic(preprocess_net=net_c1, last_size=action_dim).to(args.device) critic1_optim = AdamOptimizerFactory(lr=args.critic_lr) - net_c2 = Net(obs_dim, hidden_sizes=args.hidden_sizes, device=args.device) - 
critic2 = DiscreteCritic(net_c2, last_size=action_dim, device=args.device).to(args.device) + net_c2 = Net(state_shape=obs_dim, hidden_sizes=args.hidden_sizes) + critic2 = DiscreteCritic(preprocess_net=net_c2, last_size=action_dim).to(args.device) critic2_optim = AdamOptimizerFactory(lr=args.critic_lr) # better not to use auto alpha in CartPole diff --git a/test/discrete/test_dqn.py b/test/discrete/test_dqn.py index b33096fda..8d88d3e61 100644 --- a/test/discrete/test_dqn.py +++ b/test/discrete/test_dqn.py @@ -84,7 +84,6 @@ def test_dqn(args: argparse.Namespace = get_args()) -> None: state_shape=args.state_shape, action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, - device=args.device, # dueling=(Q_param, V_param), ).to(args.device) optim = AdamOptimizerFactory(lr=args.lr) diff --git a/test/discrete/test_drqn.py b/test/discrete/test_drqn.py index 48112add8..29b525e75 100644 --- a/test/discrete/test_drqn.py +++ b/test/discrete/test_drqn.py @@ -72,7 +72,9 @@ def test_drqn(args: argparse.Namespace = get_args()) -> None: train_envs.seed(args.seed) test_envs.seed(args.seed) # model - net = Recurrent(args.layer_num, args.state_shape, args.action_shape, args.device).to( + net = Recurrent( + layer_num=args.layer_num, state_shape=args.state_shape, action_shape=args.action_shape + ).to( args.device, ) optim = AdamOptimizerFactory(lr=args.lr) diff --git a/test/discrete/test_fqf.py b/test/discrete/test_fqf.py index bd3f76328..1d52e078b 100644 --- a/test/discrete/test_fqf.py +++ b/test/discrete/test_fqf.py @@ -85,18 +85,16 @@ def test_fqf(args: argparse.Namespace = get_args()) -> None: # model feature_net = Net( - args.state_shape, - args.hidden_sizes[-1], + state_shape=args.state_shape, + action_shape=args.hidden_sizes[-1], hidden_sizes=args.hidden_sizes[:-1], - device=args.device, softmax=False, ) net = FullQuantileFunction( - feature_net, - args.action_shape, - args.hidden_sizes, + preprocess_net=feature_net, + action_shape=args.action_shape, + 
hidden_sizes=args.hidden_sizes, num_cosines=args.num_cosines, - device=args.device, ) optim = AdamOptimizerFactory(lr=args.lr) fraction_net = FractionProposalNetwork(args.num_fractions, net.input_dim) diff --git a/test/discrete/test_iqn.py b/test/discrete/test_iqn.py index eaebe99db..1381ebd06 100644 --- a/test/discrete/test_iqn.py +++ b/test/discrete/test_iqn.py @@ -85,17 +85,15 @@ def test_iqn(args: argparse.Namespace = get_args()) -> None: # model feature_net = Net( - args.state_shape, - args.hidden_sizes[-1], + state_shape=args.state_shape, + action_shape=args.hidden_sizes[-1], hidden_sizes=args.hidden_sizes[:-1], - device=args.device, softmax=False, ) net = ImplicitQuantileNetwork( - feature_net, - args.action_shape, + preprocess_net=feature_net, + action_shape=args.action_shape, num_cosines=args.num_cosines, - device=args.device, ) optim = AdamOptimizerFactory(lr=args.lr) policy = IQNPolicy( diff --git a/test/discrete/test_pg.py b/test/discrete/test_pg.py index 465dc6f75..273c4c121 100644 --- a/test/discrete/test_pg.py +++ b/test/discrete/test_pg.py @@ -71,7 +71,6 @@ def test_pg(args: argparse.Namespace = get_args()) -> None: state_shape=args.state_shape, action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, - device=args.device, softmax=True, ).to(args.device) optim = AdamOptimizerFactory(lr=args.lr) diff --git a/test/discrete/test_ppo.py b/test/discrete/test_ppo.py index 598966b80..5e25459df 100644 --- a/test/discrete/test_ppo.py +++ b/test/discrete/test_ppo.py @@ -79,17 +79,17 @@ def test_ppo(args: argparse.Namespace = get_args()) -> None: train_envs.seed(args.seed) test_envs.seed(args.seed) # model - net = Net(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device) + net = Net(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes) actor: nn.Module critic: nn.Module if torch.cuda.is_available(): actor = DataParallelNet( - DiscreteActor(net, args.action_shape, device=args.device).to(args.device) + 
DiscreteActor(preprocess_net=net, action_shape=args.action_shape).to(args.device) ) - critic = DataParallelNet(DiscreteCritic(net, device=args.device).to(args.device)) + critic = DataParallelNet(DiscreteCritic(preprocess_net=net).to(args.device)) else: - actor = DiscreteActor(net, args.action_shape, device=args.device).to(args.device) - critic = DiscreteCritic(net, device=args.device).to(args.device) + actor = DiscreteActor(preprocess_net=net, action_shape=args.action_shape).to(args.device) + critic = DiscreteCritic(preprocess_net=net).to(args.device) actor_critic = ActorCritic(actor, critic) # orthogonal initialization for m in actor_critic.modules(): diff --git a/test/discrete/test_qrdqn.py b/test/discrete/test_qrdqn.py index 4113282d2..b7fa509da 100644 --- a/test/discrete/test_qrdqn.py +++ b/test/discrete/test_qrdqn.py @@ -88,7 +88,6 @@ def test_qrdqn(args: argparse.Namespace = get_args()) -> None: state_shape=args.state_shape, action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, - device=args.device, softmax=False, num_atoms=args.num_quantiles, ) diff --git a/test/discrete/test_rainbow.py b/test/discrete/test_rainbow.py index 569b9e3eb..2dd342bf4 100644 --- a/test/discrete/test_rainbow.py +++ b/test/discrete/test_rainbow.py @@ -98,7 +98,6 @@ def noisy_linear(x: int, y: int) -> NoisyLinear: state_shape=args.state_shape, action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, - device=args.device, softmax=True, num_atoms=args.num_atoms, dueling_param=({"linear_layer": noisy_linear}, {"linear_layer": noisy_linear}), diff --git a/test/modelbased/test_dqn_icm.py b/test/modelbased/test_dqn_icm.py index 0f70ddd17..cedf3bd16 100644 --- a/test/modelbased/test_dqn_icm.py +++ b/test/modelbased/test_dqn_icm.py @@ -103,7 +103,6 @@ def test_dqn_icm(args: argparse.Namespace = get_args()) -> None: state_shape=args.state_shape, action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, - device=args.device, # dueling=(Q_param, V_param), ).to(args.device) 
optim = AdamOptimizerFactory(lr=args.lr) @@ -123,18 +122,16 @@ def test_dqn_icm(args: argparse.Namespace = get_args()) -> None: feature_dim = args.hidden_sizes[-1] obs_dim = space_info.observation_info.obs_dim feature_net = MLP( - obs_dim, + input_dim=obs_dim, output_dim=feature_dim, hidden_sizes=args.hidden_sizes[:-1], - device=args.device, ) action_dim = space_info.action_info.action_dim icm_net = IntrinsicCuriosityModule( - feature_net, - feature_dim, - action_dim, + feature_net=feature_net, + feature_dim=feature_dim, + action_dim=action_dim, hidden_sizes=args.hidden_sizes[-1:], - device=args.device, ).to(args.device) icm_optim = AdamOptimizerFactory(lr=args.lr) icm_algorithm = ICMOffPolicyWrapper( diff --git a/test/modelbased/test_ppo_icm.py b/test/modelbased/test_ppo_icm.py index 9e499dcc5..efcf250ee 100644 --- a/test/modelbased/test_ppo_icm.py +++ b/test/modelbased/test_ppo_icm.py @@ -104,9 +104,9 @@ def test_ppo(args: argparse.Namespace = get_args()) -> None: test_envs.seed(args.seed) # model - net = Net(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device) - actor = DiscreteActor(net, args.action_shape, device=args.device).to(args.device) - critic = DiscreteCritic(net, device=args.device).to(args.device) + net = Net(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes) + actor = DiscreteActor(preprocess_net=net, action_shape=args.action_shape).to(args.device) + critic = DiscreteCritic(preprocess_net=net).to(args.device) actor_critic = ActorCritic(actor, critic) # orthogonal initialization @@ -145,18 +145,16 @@ def test_ppo(args: argparse.Namespace = get_args()) -> None: # ICM wrapper feature_dim = args.hidden_sizes[-1] feature_net = MLP( - space_info.observation_info.obs_dim, + input_dim=space_info.observation_info.obs_dim, output_dim=feature_dim, hidden_sizes=args.hidden_sizes[:-1], - device=args.device, ) action_dim = space_info.action_info.action_dim icm_net = IntrinsicCuriosityModule( - feature_net, - feature_dim, - 
action_dim, + feature_net=feature_net, + feature_dim=feature_dim, + action_dim=action_dim, hidden_sizes=args.hidden_sizes[-1:], - device=args.device, ).to(args.device) icm_optim = AdamOptimizerFactory(lr=args.lr) icm_algorithm = ICMOnPolicyWrapper( diff --git a/test/offline/gather_cartpole_data.py b/test/offline/gather_cartpole_data.py index e53c2c518..21fa9930c 100644 --- a/test/offline/gather_cartpole_data.py +++ b/test/offline/gather_cartpole_data.py @@ -93,7 +93,6 @@ def gather_data() -> VectorReplayBuffer | PrioritizedVectorReplayBuffer: state_shape=args.state_shape, action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, - device=args.device, softmax=False, num_atoms=args.num_quantiles, ) diff --git a/test/offline/gather_pendulum_data.py b/test/offline/gather_pendulum_data.py index 59b103fb8..cca234f91 100644 --- a/test/offline/gather_pendulum_data.py +++ b/test/offline/gather_pendulum_data.py @@ -92,11 +92,10 @@ def gather_data() -> VectorReplayBuffer: train_envs.seed(args.seed) test_envs.seed(args.seed) # model - net = Net(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device) + net = Net(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes) actor = ContinuousActorProb( - net, - args.action_shape, - device=args.device, + preprocess_net=net, + action_shape=args.action_shape, unbounded=True, ).to(args.device) actor_optim = AdamOptimizerFactory(lr=args.actor_lr) @@ -105,9 +104,8 @@ def gather_data() -> VectorReplayBuffer: action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, concat=True, - device=args.device, ) - critic = ContinuousCritic(net_c, device=args.device).to(args.device) + critic = ContinuousCritic(preprocess_net=net_c).to(args.device) critic_optim = AdamOptimizerFactory(lr=args.critic_lr) action_dim = space_info.action_info.action_dim diff --git a/test/offline/test_bcq.py b/test/offline/test_bcq.py index 65634621b..0516dc039 100644 --- a/test/offline/test_bcq.py +++ b/test/offline/test_bcq.py @@ 
-102,9 +102,8 @@ def test_bcq(args: argparse.Namespace = get_args()) -> None: input_dim=args.state_dim + args.action_dim, output_dim=args.action_dim, hidden_sizes=args.hidden_sizes, - device=args.device, ) - actor = Perturbation(net_a, max_action=args.max_action, device=args.device, phi=args.phi).to( + actor = Perturbation(preprocess_net=net_a, max_action=args.max_action, phi=args.phi).to( args.device, ) actor_optim = AdamOptimizerFactory(lr=args.actor_lr) @@ -114,9 +113,8 @@ def test_bcq(args: argparse.Namespace = get_args()) -> None: action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, concat=True, - device=args.device, ) - critic = ContinuousCritic(net_c, device=args.device).to(args.device) + critic = ContinuousCritic(preprocess_net=net_c).to(args.device) critic_optim = AdamOptimizerFactory(lr=args.critic_lr) # vae @@ -124,7 +122,6 @@ def test_bcq(args: argparse.Namespace = get_args()) -> None: vae_encoder = MLP( input_dim=args.state_dim + args.action_dim, hidden_sizes=args.vae_hidden_sizes, - device=args.device, ) if not args.latent_dim: args.latent_dim = args.action_dim * 2 @@ -132,15 +129,13 @@ def test_bcq(args: argparse.Namespace = get_args()) -> None: input_dim=args.state_dim + args.latent_dim, output_dim=args.action_dim, hidden_sizes=args.vae_hidden_sizes, - device=args.device, ) vae = VAE( - vae_encoder, - vae_decoder, + encoder=vae_encoder, + decoder=vae_decoder, hidden_dim=args.vae_hidden_sizes[-1], latent_dim=args.latent_dim, max_action=args.max_action, - device=args.device, ).to(args.device) vae_optim = AdamOptimizerFactory() diff --git a/test/offline/test_cql.py b/test/offline/test_cql.py index 62aee3eed..1f31515b9 100644 --- a/test/offline/test_cql.py +++ b/test/offline/test_cql.py @@ -109,12 +109,10 @@ def test_cql(args: argparse.Namespace = get_args()) -> None: state_shape=args.state_shape, action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, - device=args.device, ) actor = ContinuousActorProb( - net_a, + 
preprocess_net=net_a, action_shape=args.action_shape, - device=args.device, unbounded=True, conditioned_sigma=True, ).to(args.device) @@ -126,9 +124,8 @@ def test_cql(args: argparse.Namespace = get_args()) -> None: action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, concat=True, - device=args.device, ) - critic = ContinuousCritic(net_c, device=args.device).to(args.device) + critic = ContinuousCritic(preprocess_net=net_c).to(args.device) critic_optim = AdamOptimizerFactory(lr=args.critic_lr) if args.auto_alpha: diff --git a/test/offline/test_discrete_bcq.py b/test/offline/test_discrete_bcq.py index 937b62b2b..2a94f0ed8 100644 --- a/test/offline/test_discrete_bcq.py +++ b/test/offline/test_discrete_bcq.py @@ -76,18 +76,16 @@ def test_discrete_bcq(args: argparse.Namespace = get_args()) -> None: test_envs.seed(args.seed) # model - net = Net(state_shape=args.state_shape, action_shape=args.hidden_sizes[0], device=args.device) + net = Net(state_shape=args.state_shape, action_shape=args.hidden_sizes[0]) policy_net = DiscreteActor( - net, - args.action_shape, + preprocess_net=net, + action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, - device=args.device, ).to(args.device) imitation_net = DiscreteActor( - net, - args.action_shape, + preprocess_net=net, + action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, - device=args.device, ).to(args.device) optim = AdamOptimizerFactory(lr=args.lr) policy = DiscreteBCQPolicy( diff --git a/test/offline/test_discrete_cql.py b/test/offline/test_discrete_cql.py index a992773e8..5706dbe71 100644 --- a/test/offline/test_discrete_cql.py +++ b/test/offline/test_discrete_cql.py @@ -77,7 +77,6 @@ def test_discrete_cql(args: argparse.Namespace = get_args()) -> None: state_shape=args.state_shape, action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, - device=args.device, softmax=False, num_atoms=args.num_quantiles, ) diff --git a/test/offline/test_discrete_crr.py b/test/offline/test_discrete_crr.py index 
f6112203a..5c06a56b7 100644 --- a/test/offline/test_discrete_crr.py +++ b/test/offline/test_discrete_crr.py @@ -71,20 +71,18 @@ def test_discrete_crr(args: argparse.Namespace = get_args()) -> None: test_envs.seed(args.seed) # model and algorithm - net = Net(state_shape=args.state_shape, action_shape=args.hidden_sizes[0], device=args.device) + net = Net(state_shape=args.state_shape, action_shape=args.hidden_sizes[0]) actor = DiscreteActor( preprocess_net=net, action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, - device=args.device, softmax_output=False, ) action_dim = space_info.action_info.action_dim critic = DiscreteCritic( - net, + preprocess_net=net, hidden_sizes=args.hidden_sizes, last_size=action_dim, - device=args.device, ) optim = AdamOptimizerFactory(lr=args.lr) policy = DiscreteActorPolicy( diff --git a/test/offline/test_gail.py b/test/offline/test_gail.py index 7ffd5742e..dd975284f 100644 --- a/test/offline/test_gail.py +++ b/test/offline/test_gail.py @@ -94,18 +94,16 @@ def test_gail(args: argparse.Namespace = get_args()) -> None: train_envs.seed(args.seed) test_envs.seed(args.seed) # model - net = Net(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device) + net = Net(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes) actor = ContinuousActorProb( preprocess_net=net, action_shape=args.action_shape, max_action=args.max_action, - device=args.device, ).to( args.device, ) critic = ContinuousCritic( - Net(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device), - device=args.device, + preprocess_net=Net(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes), ).to(args.device) actor_critic = ActorCritic(actor, critic) # orthogonal initialization @@ -116,15 +114,13 @@ def test_gail(args: argparse.Namespace = get_args()) -> None: optim = AdamOptimizerFactory(lr=args.lr) # discriminator disc_net = ContinuousCritic( - Net( + preprocess_net=Net( state_shape=args.state_shape, 
action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, activation=torch.nn.Tanh, - device=args.device, concat=True, ), - device=args.device, ).to(args.device) for m in disc_net.modules(): if isinstance(m, torch.nn.Linear): diff --git a/test/offline/test_td3_bc.py b/test/offline/test_td3_bc.py index bba849e83..5cfe864ff 100644 --- a/test/offline/test_td3_bc.py +++ b/test/offline/test_td3_bc.py @@ -96,15 +96,13 @@ def test_td3_bc(args: argparse.Namespace = get_args()) -> None: # actor network net_a = Net( - args.state_shape, + state_shape=args.state_shape, hidden_sizes=args.hidden_sizes, - device=args.device, ) actor = ContinuousActorDeterministic( - net_a, + preprocess_net=net_a, action_shape=args.action_shape, max_action=args.max_action, - device=args.device, ).to(args.device) actor_optim = AdamOptimizerFactory(lr=args.actor_lr) @@ -114,18 +112,16 @@ def test_td3_bc(args: argparse.Namespace = get_args()) -> None: action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, concat=True, - device=args.device, ) net_c2 = Net( state_shape=args.state_shape, action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, concat=True, - device=args.device, ) - critic1 = ContinuousCritic(net_c1, device=args.device).to(args.device) + critic1 = ContinuousCritic(preprocess_net=net_c1).to(args.device) critic1_optim = AdamOptimizerFactory(lr=args.critic_lr) - critic2 = ContinuousCritic(net_c2, device=args.device).to(args.device) + critic2 = ContinuousCritic(preprocess_net=net_c2).to(args.device) critic2_optim = AdamOptimizerFactory(lr=args.critic_lr) # policy and algorithm diff --git a/test/pettingzoo/pistonball.py b/test/pettingzoo/pistonball.py index 48defd9ac..cf505d6a1 100644 --- a/test/pettingzoo/pistonball.py +++ b/test/pettingzoo/pistonball.py @@ -96,7 +96,6 @@ def get_agents( state_shape=args.state_shape, action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, - device=args.device, ).to(args.device) optim = AdamOptimizerFactory(lr=args.lr) policy = 
DQNPolicy( diff --git a/test/pettingzoo/pistonball_continuous.py b/test/pettingzoo/pistonball_continuous.py index 68e316717..477549067 100644 --- a/test/pettingzoo/pistonball_continuous.py +++ b/test/pettingzoo/pistonball_continuous.py @@ -167,10 +167,9 @@ def get_agents( ).to(args.device) actor = ContinuousActorProb( - net, - args.action_shape, + preprocess_net=net, + action_shape=args.action_shape, max_action=args.max_action, - device=args.device, ).to(args.device) net2 = DQNet( observation_space.shape[2], @@ -178,7 +177,7 @@ def get_agents( observation_space.shape[0], device=args.device, ).to(args.device) - critic = ContinuousCritic(net2, device=args.device).to(args.device) + critic = ContinuousCritic(preprocess_net=net2).to(args.device) for m in set(actor.modules()).union(critic.modules()): if isinstance(m, torch.nn.Linear): torch.nn.init.orthogonal_(m.weight) diff --git a/test/pettingzoo/tic_tac_toe.py b/test/pettingzoo/tic_tac_toe.py index f1b15bf12..7ae46c68e 100644 --- a/test/pettingzoo/tic_tac_toe.py +++ b/test/pettingzoo/tic_tac_toe.py @@ -118,7 +118,6 @@ def get_agents( state_shape=args.state_shape, action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, - device=args.device, ).to(args.device) if optim is None: optim = AdamOptimizerFactory(lr=args.lr) diff --git a/tianshou/env/atari/atari_network.py b/tianshou/env/atari/atari_network.py index 20965af7e..338b72d19 100644 --- a/tianshou/env/atari/atari_network.py +++ b/tianshou/env/atari/atari_network.py @@ -18,6 +18,7 @@ from tianshou.policy.modelfree.pg import TDistFnDiscrOrCont from tianshou.utils.net.common import NetBase from tianshou.utils.net.discrete import DiscreteActor, NoisyLinear +from tianshou.utils.torch_utils import torch_device def layer_init(layer: nn.Module, std: float = np.sqrt(2), bias_const: float = 0.0) -> nn.Module: @@ -64,7 +65,6 @@ def __init__( h: int, w: int, action_shape: Sequence[int] | int, - device: str | int | torch.device = "cpu", features_only: bool = False, 
output_dim_added_layer: int | None = None, layer_init: Callable[[nn.Module], nn.Module] = lambda x: x, @@ -75,7 +75,6 @@ def __init__( "Should not provide explicit output dimension using `output_dim_added_layer` when `features_only` is true.", ) super().__init__() - self.device = device self.net = nn.Sequential( layer_init(nn.Conv2d(c, 32, kernel_size=8, stride=4)), nn.ReLU(inplace=True), @@ -114,7 +113,8 @@ def forward( **kwargs: Any, ) -> tuple[torch.Tensor, Any]: r"""Mapping: s -> Q(s, \*).""" - obs = torch.as_tensor(obs, device=self.device, dtype=torch.float32) + device = torch_device(self) + obs = torch.as_tensor(obs, device=device, dtype=torch.float32) return self.net(obs), state @@ -127,15 +127,15 @@ class C51Net(DQNet): def __init__( self, + *, c: int, h: int, w: int, action_shape: Sequence[int], num_atoms: int = 51, - device: str | int | torch.device = "cpu", ) -> None: self.action_num = int(np.prod(action_shape)) - super().__init__(c, h, w, [self.action_num * num_atoms], device) + super().__init__(c=c, h=h, w=w, action_shape=[self.action_num * num_atoms]) self.num_atoms = num_atoms def forward( @@ -161,17 +161,17 @@ class Rainbow(DQNet): def __init__( self, + *, c: int, h: int, w: int, action_shape: Sequence[int], num_atoms: int = 51, noisy_std: float = 0.5, - device: str | int | torch.device = "cpu", is_dueling: bool = True, is_noisy: bool = True, ) -> None: - super().__init__(c, h, w, action_shape, device, features_only=True) + super().__init__(c=c, h=h, w=w, action_shape=action_shape, features_only=True) self.action_num = int(np.prod(action_shape)) self.num_atoms = num_atoms @@ -230,10 +230,9 @@ def __init__( w: int, action_shape: Sequence[int] | int, num_quantiles: int = 200, - device: str | int | torch.device = "cpu", ) -> None: self.action_num = int(np.prod(action_shape)) - super().__init__(c, h, w, [self.action_num * num_quantiles], device) + super().__init__(c=c, h=h, w=w, action_shape=[self.action_num * num_quantiles]) self.num_quantiles = 
num_quantiles def forward( @@ -273,7 +272,6 @@ def create_module(self, envs: Environments, device: TDevice) -> DiscreteActor: h=h, w=w, action_shape=action_shape, - device=device, features_only=self.features_only, output_dim_added_layer=self.output_dim_added_layer, layer_init=layer_init, @@ -281,9 +279,8 @@ def create_module(self, envs: Environments, device: TDevice) -> DiscreteActor: if self.scale_obs: net = scale_obs(net) return DiscreteActor( - net, - envs.get_action_shape(), - device=device, + preprocess_net=net, + action_shape=envs.get_action_shape(), softmax_output=self.USE_SOFTMAX_OUTPUT, ).to(device) @@ -312,7 +309,6 @@ def create_intermediate_module(self, envs: Environments, device: TDevice) -> Int h=h, w=w, action_shape=action_shape, - device=device, features_only=self.features_only, ).to(device) module = dqn.net if self.net_only else dqn diff --git a/tianshou/highlevel/module/actor.py b/tianshou/highlevel/module/actor.py index 1b8de1150..e849d6f9c 100644 --- a/tianshou/highlevel/module/actor.py +++ b/tianshou/highlevel/module/actor.py @@ -150,13 +150,11 @@ def create_module(self, envs: Environments, device: TDevice) -> Actor: state_shape=envs.get_observation_shape(), hidden_sizes=self.hidden_sizes, activation=self.activation, - device=device, ) return continuous.ContinuousActorDeterministic( preprocess_net=net_a, action_shape=envs.get_action_shape(), hidden_sizes=(), - device=device, ).to(device) def create_dist_fn(self, envs: Environments) -> TDistFnDiscrOrCont | None: @@ -188,13 +186,11 @@ def create_module(self, envs: Environments, device: TDevice) -> Actor: state_shape=envs.get_observation_shape(), hidden_sizes=self.hidden_sizes, activation=self.activation, - device=device, ) actor = continuous.ContinuousActorProb( preprocess_net=net_a, action_shape=envs.get_action_shape(), unbounded=self.unbounded, - device=device, conditioned_sigma=self.conditioned_sigma, ).to(device) @@ -225,13 +221,11 @@ def create_module(self, envs: Environments, device: 
TDevice) -> Actor: state_shape=envs.get_observation_shape(), hidden_sizes=self.hidden_sizes, activation=self.activation, - device=device, ) return discrete.DiscreteActor( - net_a, - envs.get_action_shape(), + preprocess_net=net_a, + action_shape=envs.get_action_shape(), hidden_sizes=(), - device=device, softmax_output=self.softmax_output, ).to(device) diff --git a/tianshou/highlevel/module/critic.py b/tianshou/highlevel/module/critic.py index 992c53046..54596be12 100644 --- a/tianshou/highlevel/module/critic.py +++ b/tianshou/highlevel/module/critic.py @@ -8,8 +8,10 @@ from tianshou.highlevel.env import Environments, EnvType from tianshou.highlevel.module.actor import ActorFuture from tianshou.highlevel.module.core import TDevice, init_linear_orthogonal -from tianshou.utils.net import continuous, discrete +from tianshou.utils.net import continuous from tianshou.utils.net.common import Actor, EnsembleLinear, ModuleType, Net +from tianshou.utils.net.continuous import ContinuousCritic +from tianshou.utils.net.discrete import DiscreteCritic class CriticFactory(ToStringMixin, ABC): @@ -95,9 +97,8 @@ def create_module( hidden_sizes=self.hidden_sizes, concat=use_action, activation=self.activation, - device=device, ) - critic = continuous.ContinuousCritic(net_c, device=device).to(device) + critic = continuous.ContinuousCritic(preprocess_net=net_c).to(device) init_linear_orthogonal(critic) return critic @@ -121,12 +122,11 @@ def create_module( hidden_sizes=self.hidden_sizes, concat=use_action, activation=self.activation, - device=device, ) last_size = ( int(np.prod(envs.get_action_shape())) if discrete_last_size_use_action_shape else 1 ) - critic = discrete.DiscreteCritic(net_c, device=device, last_size=last_size).to(device) + critic = DiscreteCritic(preprocess_net=net_c, last_size=last_size).to(device) init_linear_orthogonal(critic) return critic @@ -167,15 +167,13 @@ def create_module( last_size = ( int(np.prod(envs.get_action_shape())) if 
discrete_last_size_use_action_shape else 1 ) - return discrete.DiscreteCritic( - actor.get_preprocess_net(), - device=device, + return DiscreteCritic( + preprocess_net=actor.get_preprocess_net(), last_size=last_size, ).to(device) elif envs.get_type().is_continuous(): - return continuous.ContinuousCritic( - actor.get_preprocess_net(), - device=device, + return ContinuousCritic( + preprocess_net=actor.get_preprocess_net(), apply_preprocess_net_to_obs_only=True, ).to(device) else: @@ -247,12 +245,10 @@ def linear_layer(x: int, y: int) -> EnsembleLinear: hidden_sizes=self.hidden_sizes, concat=use_action, activation=nn.Tanh, - device=device, linear_layer=linear_layer, ) critic = continuous.ContinuousCritic( - net_c, - device=device, + preprocess_net=net_c, linear_layer=linear_layer, flatten_input=False, ).to(device) diff --git a/tianshou/highlevel/module/special.py b/tianshou/highlevel/module/special.py index de572d7a1..6d119d739 100644 --- a/tianshou/highlevel/module/special.py +++ b/tianshou/highlevel/module/special.py @@ -27,5 +27,4 @@ def create_module(self, envs: Environments, device: TDevice) -> ImplicitQuantile hidden_sizes=self.hidden_sizes, num_cosines=self.num_cosines, preprocess_net_output_dim=preprocess_net.output_dim, - device=device, ).to(device) diff --git a/tianshou/highlevel/params/policy_wrapper.py b/tianshou/highlevel/params/policy_wrapper.py index f4fb74b58..2b49fcc08 100644 --- a/tianshou/highlevel/params/policy_wrapper.py +++ b/tianshou/highlevel/params/policy_wrapper.py @@ -63,11 +63,10 @@ def create_wrapped_algorithm( raise ValueError(f"Environment action shape must be an integer, got {action_dim}") feature_dim = feature_net.output_dim icm_net = IntrinsicCuriosityModule( - feature_net.module, - feature_dim, - action_dim, + feature_net=feature_net.module, + feature_dim=feature_dim, + action_dim=action_dim, hidden_sizes=self.hidden_sizes, - device=device, ) optim_factory = self.optim_factory or optim_factory_default icm_optim = 
optim_factory.create_optimizer_factory(lr=self.lr) diff --git a/tianshou/policy/imitation/gail.py b/tianshou/policy/imitation/gail.py index 0c625052a..b6341a3a7 100644 --- a/tianshou/policy/imitation/gail.py +++ b/tianshou/policy/imitation/gail.py @@ -18,6 +18,7 @@ from tianshou.policy.optim import OptimizerFactory from tianshou.utils.net.continuous import ContinuousCritic from tianshou.utils.net.discrete import DiscreteCritic +from tianshou.utils.torch_utils import torch_device @dataclass(kw_only=True) @@ -126,8 +127,9 @@ def preprocess_batch( return super().preprocess_batch(batch, buffer, indices) def disc(self, batch: RolloutBatchProtocol) -> torch.Tensor: - obs = to_torch(batch.obs, device=self.disc_net.device) - act = to_torch(batch.act, device=self.disc_net.device) + device = torch_device(self.disc_net) + obs = to_torch(batch.obs, device=device) + act = to_torch(batch.act, device=device) return self.disc_net(torch.cat([obs, act], dim=1)) def _update_with_batch( # type: ignore diff --git a/tianshou/utils/net/common.py b/tianshou/utils/net/common.py index 9bf7125f0..f77771d6c 100644 --- a/tianshou/utils/net/common.py +++ b/tianshou/utils/net/common.py @@ -10,6 +10,7 @@ from tianshou.data.batch import Batch, BatchProtocol from tianshou.data.types import RecurrentStateBatch from tianshou.utils.space_info import ActionSpaceInfo +from tianshou.utils.torch_utils import torch_device ModuleType = type[nn.Module] ArgsType = tuple[Any, ...] | dict[Any, Any] | Sequence[tuple[Any, ...]] | Sequence[dict[Any, Any]] @@ -66,13 +67,13 @@ class MLP(nn.Module): the same activation for all layers if passed in nn.Module, or different activation for different Modules if passed in a list. Default to nn.ReLU. - :param device: which device to create this model on. Default to None. :param linear_layer: use this module as linear layer. Default to nn.Linear. :param flatten_input: whether to flatten input data. Default to True. 
""" def __init__( self, + *, input_dim: int, output_dim: int = 0, hidden_sizes: Sequence[int] = (), @@ -80,12 +81,10 @@ def __init__( norm_args: ArgsType | None = None, activation: ModuleType | Sequence[ModuleType] | None = nn.ReLU, act_args: ArgsType | None = None, - device: str | int | torch.device | None = None, linear_layer: TLinearLayer = nn.Linear, flatten_input: bool = True, ) -> None: super().__init__() - self.device = device if norm_layer: if isinstance(norm_layer, list): assert len(norm_layer) == len(hidden_sizes) @@ -136,7 +135,8 @@ def __init__( @no_type_check def forward(self, obs: np.ndarray | torch.Tensor) -> torch.Tensor: - obs = torch.as_tensor(obs, device=self.device, dtype=torch.float32) + device = torch_device(self) + obs = torch.as_tensor(obs, device=device, dtype=torch.float32) if self.flatten_input: obs = obs.flatten(1) return self.model(obs) @@ -176,8 +176,6 @@ class Net(NetBase[Any]): the same activation for all layers if passed in nn.Module, or different activation for different Modules if passed in a list. Default to nn.ReLU. - :param device: specify the device when the network actually runs. Default - to "cpu". :param softmax: whether to apply a softmax layer over the last layer's output. 
:param concat: whether the input shape is concatenated by state_shape @@ -205,6 +203,7 @@ class Net(NetBase[Any]): def __init__( self, + *, state_shape: int | Sequence[int], action_shape: TActionShape = 0, hidden_sizes: Sequence[int] = (), @@ -212,7 +211,6 @@ def __init__( norm_args: ArgsType | None = None, activation: ModuleType | Sequence[ModuleType] | None = nn.ReLU, act_args: ArgsType | None = None, - device: str | int | torch.device = "cpu", softmax: bool = False, concat: bool = False, num_atoms: int = 1, @@ -220,7 +218,6 @@ def __init__( linear_layer: TLinearLayer = nn.Linear, ) -> None: super().__init__() - self.device = device self.softmax = softmax self.num_atoms = num_atoms self.Q: MLP | None = None @@ -233,21 +230,19 @@ def __init__( self.use_dueling = dueling_param is not None output_dim = action_dim if not self.use_dueling and not concat else 0 self.model = MLP( - input_dim, - output_dim, - hidden_sizes, - norm_layer, - norm_args, - activation, - act_args, - device, - linear_layer, + input_dim=input_dim, + output_dim=output_dim, + hidden_sizes=hidden_sizes, + norm_layer=norm_layer, + norm_args=norm_args, + activation=activation, + act_args=act_args, + linear_layer=linear_layer, ) if self.use_dueling: # dueling DQN assert dueling_param is not None kwargs_update = { "input_dim": self.model.output_dim, - "device": self.device, } # Important: don't change the original dict (e.g., don't use .update()) q_kwargs = {**dueling_param[0], **kwargs_update} @@ -298,14 +293,13 @@ class Recurrent(NetBase[RecurrentStateBatch]): def __init__( self, + *, layer_num: int, state_shape: int | Sequence[int], action_shape: TActionShape, - device: str | int | torch.device = "cpu", hidden_layer_size: int = 128, ) -> None: super().__init__() - self.device = device self.nn = nn.LSTM( input_size=hidden_layer_size, hidden_size=hidden_layer_size, @@ -340,7 +334,8 @@ def forward( f"Expected to find keys 'hidden' and 'cell' but instead found {state.keys()}", ) - obs = 
torch.as_tensor(obs, device=self.device, dtype=torch.float32) + device = torch_device(self) + obs = torch.as_tensor(obs, device=device, dtype=torch.float32) # obs [bsz, len, dim] (training) or [bsz, dim] (evaluation) # In short, the tensor's shape in training phase is longer than which # in evaluation phase. @@ -474,14 +469,13 @@ class BranchingNet(NetBase[Any]): the same activation for all layers if passed in nn.Module, or different activation for different Modules if passed in a list. Default to nn.ReLU. - :param device: specify the device when the network actually runs. Default - to "cpu". :param softmax: whether to apply a softmax layer over the last layer's output. """ def __init__( self, + *, state_shape: int | Sequence[int], num_branches: int = 0, action_per_branch: int = 2, @@ -492,41 +486,37 @@ def __init__( norm_args: ArgsType | None = None, activation: ModuleType | None = nn.ReLU, act_args: ArgsType | None = None, - device: str | int | torch.device = "cpu", ) -> None: super().__init__() common_hidden_sizes = common_hidden_sizes or [] value_hidden_sizes = value_hidden_sizes or [] action_hidden_sizes = action_hidden_sizes or [] - self.device = device self.num_branches = num_branches self.action_per_branch = action_per_branch # common network common_input_dim = int(np.prod(state_shape)) common_output_dim = 0 self.common = MLP( - common_input_dim, - common_output_dim, - common_hidden_sizes, - norm_layer, - norm_args, - activation, - act_args, - device, + input_dim=common_input_dim, + output_dim=common_output_dim, + hidden_sizes=common_hidden_sizes, + norm_layer=norm_layer, + norm_args=norm_args, + activation=activation, + act_args=act_args, ) # value network value_input_dim = common_hidden_sizes[-1] value_output_dim = 1 self.value = MLP( - value_input_dim, - value_output_dim, - value_hidden_sizes, - norm_layer, - norm_args, - activation, - act_args, - device, + input_dim=value_input_dim, + output_dim=value_output_dim, + hidden_sizes=value_hidden_sizes, + 
norm_layer=norm_layer, + norm_args=norm_args, + activation=activation, + act_args=act_args, ) # action branching network action_input_dim = common_hidden_sizes[-1] @@ -534,14 +524,13 @@ def __init__( self.branches = nn.ModuleList( [ MLP( - action_input_dim, - action_output_dim, - action_hidden_sizes, - norm_layer, - norm_args, - activation, - act_args, - device, + input_dim=action_input_dim, + output_dim=action_output_dim, + hidden_sizes=action_hidden_sizes, + norm_layer=norm_layer, + norm_args=norm_args, + activation=activation, + act_args=act_args, ) for _ in range(self.num_branches) ], diff --git a/tianshou/utils/net/continuous.py b/tianshou/utils/net/continuous.py index cab7da2cb..d1e279bee 100644 --- a/tianshou/utils/net/continuous.py +++ b/tianshou/utils/net/continuous.py @@ -16,6 +16,7 @@ TLinearLayer, get_output_dim, ) +from tianshou.utils.torch_utils import torch_device SIGMA_MIN = -20 SIGMA_MAX = 2 @@ -42,23 +43,21 @@ class ContinuousActorDeterministic(Actor): def __init__( self, + *, preprocess_net: nn.Module | Net, action_shape: TActionShape, hidden_sizes: Sequence[int] = (), max_action: float = 1.0, - device: str | int | torch.device = "cpu", preprocess_net_output_dim: int | None = None, ) -> None: super().__init__() - self.device = device self.preprocess = preprocess_net self.output_dim = int(np.prod(action_shape)) input_dim = get_output_dim(preprocess_net, preprocess_net_output_dim) self.last = MLP( - input_dim, - self.output_dim, - hidden_sizes, - device=self.device, + input_dim=input_dim, + output_dim=self.output_dim, + hidden_sizes=hidden_sizes, ) self.max_action = max_action @@ -120,25 +119,23 @@ class ContinuousCritic(AbstractContinuousCritic): def __init__( self, + *, preprocess_net: nn.Module | Net, hidden_sizes: Sequence[int] = (), - device: str | int | torch.device = "cpu", preprocess_net_output_dim: int | None = None, linear_layer: TLinearLayer = nn.Linear, flatten_input: bool = True, apply_preprocess_net_to_obs_only: bool = False, ) -> 
None: super().__init__() - self.device = device self.preprocess = preprocess_net self.output_dim = 1 self.apply_preprocess_net_to_obs_only = apply_preprocess_net_to_obs_only input_dim = get_output_dim(preprocess_net, preprocess_net_output_dim) self.last = MLP( - input_dim, - 1, - hidden_sizes, - device=self.device, + input_dim=input_dim, + output_dim=1, + hidden_sizes=hidden_sizes, linear_layer=linear_layer, flatten_input=flatten_input, ) @@ -158,9 +155,10 @@ def forward( info: dict[str, Any] | None = None, ) -> torch.Tensor: """Mapping: (s_B, a_B) -> Q(s, a)_B.""" + device = torch_device(self) obs = torch.as_tensor( obs, - device=self.device, + device=device, dtype=torch.float32, ) if self.apply_preprocess_net_to_obs_only: @@ -169,7 +167,7 @@ def forward( if act is not None: act = torch.as_tensor( act, - device=self.device, + device=device, dtype=torch.float32, ).flatten(1) obs = torch.cat([obs, act], dim=1) @@ -199,14 +197,13 @@ class ContinuousActorProb(Actor): :ref:`build_the_network`. 
""" - # TODO: force kwargs, adjust downstream code def __init__( self, + *, preprocess_net: nn.Module | Net, action_shape: TActionShape, hidden_sizes: Sequence[int] = (), max_action: float = 1.0, - device: str | int | torch.device = "cpu", unbounded: bool = False, conditioned_sigma: bool = False, preprocess_net_output_dim: int | None = None, @@ -216,17 +213,15 @@ def __init__( warnings.warn("Note that max_action input will be discarded when unbounded is True.") max_action = 1.0 self.preprocess = preprocess_net - self.device = device self.output_dim = int(np.prod(action_shape)) input_dim = get_output_dim(preprocess_net, preprocess_net_output_dim) - self.mu = MLP(input_dim, self.output_dim, hidden_sizes, device=self.device) + self.mu = MLP(input_dim=input_dim, output_dim=self.output_dim, hidden_sizes=hidden_sizes) self._c_sigma = conditioned_sigma if conditioned_sigma: self.sigma = MLP( - input_dim, - self.output_dim, - hidden_sizes, - device=self.device, + input_dim=input_dim, + output_dim=self.output_dim, + hidden_sizes=hidden_sizes, ) else: self.sigma_param = nn.Parameter(torch.zeros(self.output_dim, 1)) @@ -270,12 +265,12 @@ class RecurrentActorProb(nn.Module): def __init__( self, + *, layer_num: int, state_shape: Sequence[int], action_shape: Sequence[int], hidden_layer_size: int = 128, max_action: float = 1.0, - device: str | int | torch.device = "cpu", unbounded: bool = False, conditioned_sigma: bool = False, ) -> None: @@ -283,7 +278,6 @@ def __init__( if unbounded and not np.isclose(max_action, 1.0): warnings.warn("Note that max_action input will be discarded when unbounded is True.") max_action = 1.0 - self.device = device self.nn = nn.LSTM( input_size=int(np.prod(state_shape)), hidden_size=hidden_layer_size, @@ -309,9 +303,10 @@ def forward( """Almost the same as :class:`~tianshou.utils.net.common.Recurrent`.""" if info is None: info = {} + device = torch_device(self) obs = torch.as_tensor( obs, - device=self.device, + device=device, dtype=torch.float32, ) 
# obs [bsz, len, dim] (training) or [bsz, dim] (evaluation) @@ -360,14 +355,12 @@ def __init__( self, layer_num: int, state_shape: Sequence[int], - action_shape: Sequence[int] = [0], - device: str | int | torch.device = "cpu", + action_shape: Sequence[int] = (0,), hidden_layer_size: int = 128, ) -> None: super().__init__() self.state_shape = state_shape self.action_shape = action_shape - self.device = device self.nn = nn.LSTM( input_size=int(np.prod(state_shape)), hidden_size=hidden_layer_size, @@ -385,9 +378,10 @@ def forward( """Almost the same as :class:`~tianshou.utils.net.common.Recurrent`.""" if info is None: info = {} + device = torch_device(self) obs = torch.as_tensor( obs, - device=self.device, + device=device, dtype=torch.float32, ) # obs [bsz, len, dim] (training) or [bsz, dim] (evaluation) @@ -400,7 +394,7 @@ def forward( if act is not None: act = torch.as_tensor( act, - device=self.device, + device=device, dtype=torch.float32, ) obs = torch.cat([obs, act], dim=1) @@ -428,15 +422,14 @@ class Perturbation(nn.Module): def __init__( self, + *, preprocess_net: nn.Module, max_action: float, - device: str | int | torch.device = "cpu", phi: float = 0.05, ): # preprocess_net: input_dim=state_dim+action_dim, output_dim=action_dim super().__init__() self.preprocess_net = preprocess_net - self.device = device self.max_action = max_action self.phi = phi @@ -473,12 +466,12 @@ class VAE(nn.Module): def __init__( self, + *, encoder: nn.Module, decoder: nn.Module, hidden_dim: int, latent_dim: int, max_action: float, - device: str | torch.device = "cpu", ): super().__init__() self.encoder = encoder @@ -490,7 +483,6 @@ def __init__( self.max_action = max_action self.latent_dim = latent_dim - self.device = device def forward( self, @@ -521,8 +513,9 @@ def decode( if latent_z is None: # state.shape[0] may be batch_size # latent vector clipped to [-0.5, 0.5] + device = torch_device(self) latent_z = ( - torch.randn(state.shape[:-1] + 
(self.latent_dim,)).to(self.device).clamp(-0.5, 0.5) + torch.randn(state.shape[:-1] + (self.latent_dim,)).to(device).clamp(-0.5, 0.5) ) # decode z with state! diff --git a/tianshou/utils/net/discrete.py b/tianshou/utils/net/discrete.py index a1baf4cf7..fdbca00d8 100644 --- a/tianshou/utils/net/discrete.py +++ b/tianshou/utils/net/discrete.py @@ -8,6 +8,7 @@ from tianshou.data import Batch, to_torch from tianshou.utils.net.common import MLP, Actor, Net, TActionShape, get_output_dim +from tianshou.utils.torch_utils import torch_device def dist_fn_categorical_from_logits(logits: torch.Tensor) -> torch.distributions.Categorical: @@ -35,25 +36,21 @@ class DiscreteActor(Actor): def __init__( self, + *, preprocess_net: nn.Module | Net, action_shape: TActionShape, hidden_sizes: Sequence[int] = (), softmax_output: bool = True, preprocess_net_output_dim: int | None = None, - device: str | int | torch.device = "cpu", ) -> None: super().__init__() - # TODO: reduce duplication with continuous.py. Probably introducing - # base classes is a good idea. 
- self.device = device self.preprocess = preprocess_net self.output_dim = int(np.prod(action_shape)) input_dim = get_output_dim(preprocess_net, preprocess_net_output_dim) self.last = MLP( - input_dim, - self.output_dim, - hidden_sizes, - device=self.device, + input_dim=input_dim, + output_dim=self.output_dim, + hidden_sizes=hidden_sizes, ) self.softmax_output = softmax_output @@ -105,18 +102,17 @@ class DiscreteCritic(nn.Module): def __init__( self, + *, preprocess_net: nn.Module | Net, hidden_sizes: Sequence[int] = (), last_size: int = 1, preprocess_net_output_dim: int | None = None, - device: str | int | torch.device = "cpu", ) -> None: super().__init__() - self.device = device self.preprocess = preprocess_net self.output_dim = last_size input_dim = get_output_dim(preprocess_net, preprocess_net_output_dim) - self.last = MLP(input_dim, last_size, hidden_sizes, device=self.device) + self.last = MLP(input_dim=input_dim, output_dim=last_size, hidden_sizes=hidden_sizes) # TODO: make a proper interface! 
def forward(self, obs: np.ndarray | torch.Tensor, **kwargs: Any) -> torch.Tensor: @@ -187,19 +183,22 @@ class ImplicitQuantileNetwork(DiscreteCritic): def __init__( self, + *, preprocess_net: nn.Module, action_shape: TActionShape, hidden_sizes: Sequence[int] = (), num_cosines: int = 64, preprocess_net_output_dim: int | None = None, - device: str | int | torch.device = "cpu", ) -> None: last_size = int(np.prod(action_shape)) - super().__init__(preprocess_net, hidden_sizes, last_size, preprocess_net_output_dim, device) - self.input_dim = get_output_dim(preprocess_net, preprocess_net_output_dim) - self.embed_model = CosineEmbeddingNetwork(num_cosines, self.input_dim).to( - device, + super().__init__( + preprocess_net=preprocess_net, + hidden_sizes=hidden_sizes, + last_size=last_size, + preprocess_net_output_dim=preprocess_net_output_dim, ) + self.input_dim = get_output_dim(preprocess_net, preprocess_net_output_dim) + self.embed_model = CosineEmbeddingNetwork(num_cosines, self.input_dim) def forward( # type: ignore self, @@ -278,20 +277,19 @@ class FullQuantileFunction(ImplicitQuantileNetwork): def __init__( self, + *, preprocess_net: nn.Module, action_shape: TActionShape, hidden_sizes: Sequence[int] = (), num_cosines: int = 64, preprocess_net_output_dim: int | None = None, - device: str | int | torch.device = "cpu", ) -> None: super().__init__( - preprocess_net, - action_shape, - hidden_sizes, - num_cosines, - preprocess_net_output_dim, - device, + preprocess_net=preprocess_net, + action_shape=action_shape, + hidden_sizes=hidden_sizes, + num_cosines=num_cosines, + preprocess_net_output_dim=preprocess_net_output_dim, ) def _compute_quantiles(self, obs: torch.Tensor, taus: torch.Tensor) -> torch.Tensor: @@ -391,34 +389,30 @@ class IntrinsicCuriosityModule(nn.Module): :param feature_dim: input dimension of the feature net. :param action_dim: dimension of the action space. :param hidden_sizes: hidden layer sizes for forward and inverse models. 
- :param device: device for the module. """ def __init__( self, + *, feature_net: nn.Module, feature_dim: int, action_dim: int, hidden_sizes: Sequence[int] = (), - device: str | torch.device = "cpu", ) -> None: super().__init__() self.feature_net = feature_net self.forward_model = MLP( - feature_dim + action_dim, + input_dim=feature_dim + action_dim, output_dim=feature_dim, hidden_sizes=hidden_sizes, - device=device, ) self.inverse_model = MLP( - feature_dim * 2, + input_dim=feature_dim * 2, output_dim=action_dim, hidden_sizes=hidden_sizes, - device=device, ) self.feature_dim = feature_dim self.action_dim = action_dim - self.device = device def forward( self, @@ -428,10 +422,11 @@ def forward( **kwargs: Any, ) -> tuple[torch.Tensor, torch.Tensor]: r"""Mapping: s1, act, s2 -> mse_loss, act_hat.""" - s1 = to_torch(s1, dtype=torch.float32, device=self.device) - s2 = to_torch(s2, dtype=torch.float32, device=self.device) + device = torch_device(self) + s1 = to_torch(s1, dtype=torch.float32, device=device) + s2 = to_torch(s2, dtype=torch.float32, device=device) phi1, phi2 = self.feature_net(s1), self.feature_net(s2) - act = to_torch(act, dtype=torch.long, device=self.device) + act = to_torch(act, dtype=torch.long, device=device) phi2_hat = self.forward_model( torch.cat([phi1, F.one_hot(act, num_classes=self.action_dim)], dim=1), ) From df3c11618e84677c620859a44537d56e1e8034a7 Mon Sep 17 00:00:00 2001 From: Dominik Jain Date: Thu, 20 Mar 2025 14:02:56 +0100 Subject: [PATCH 082/230] v2: Remove updating flag of Algorithm --- CHANGELOG.md | 1 + tianshou/policy/base.py | 21 +++++++++------------ 2 files changed, 10 insertions(+), 12 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index be4eee4cb..5ef061ef7 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -76,6 +76,7 @@ for the optimizers created (via method `with_lr_scheduler_factory` and accompanying factory abstraction `LRSchedulerFactory`). 
The parameter `lr_scheduler` has thus been removed from all algorithm constructors. + * The flag `updating` has been removed (no internal usage, general usefulness questionable). * Internal design improvements: * Introduced an abstraction for the alpha parameter (coefficient of the entropy term) in `SAC`, `DiscreteSAC` and other algorithms. diff --git a/tianshou/policy/base.py b/tianshou/policy/base.py index 3a56ec05a..a00f6d594 100644 --- a/tianshou/policy/base.py +++ b/tianshou/policy/base.py @@ -483,7 +483,6 @@ def __init__( super().__init__() self.policy: TPolicy = policy self.lr_schedulers: list[LRScheduler] = [] - self.updating = False def _create_optimizer( self, module: torch.nn.Module, factory: OptimizerFactory @@ -540,14 +539,15 @@ def _update( buffer: ReplayBuffer | None, update_with_batch_fn: Callable[[RolloutBatchProtocol], TTrainingStats], ) -> TTrainingStats: - """Update the policy network and replay buffer. + """Performs an update step. - It includes 3 function steps: process_fn, learn, and post_process_fn. In - addition, this function will change the value of ``self.updating``: it will be - False before this function and will be True when executing :meth:`update`. - Please refer to :ref:`policy_state` for more detailed explanation. The return - value of learn is augmented with the training time within update, while smoothed - loss values are computed in the trainer. + An update involves three algorithm-specific sub-steps: + * pre-processing of the batch, + * performing the actual network update with the batch, and + * post-processing of the batch. + + The return value is that of the network update call, augmented with the + training time within update. :param sample_size: 0 means it will extract all the data from the buffer, otherwise it will sample a batch with given sample_size. None also @@ -555,8 +555,7 @@ def _update( first. :param buffer: the corresponding replay buffer. 
- :return: A dataclass object containing the data needed to be logged (e.g., loss) from - ``policy.learn()``. + :return: A dataclass object containing the data needed to be logged (e.g., loss) """ if not self.policy.is_within_training_step: raise RuntimeError( @@ -569,14 +568,12 @@ def _update( return TrainingStats() # type: ignore[return-value] start_time = time.time() batch, indices = buffer.sample(sample_size) - self.updating = True batch = self.preprocess_batch(batch, buffer, indices) with torch_train_mode(self): training_stat = update_with_batch_fn(batch) self.postprocess_batch(batch, buffer, indices) for lr_scheduler in self.lr_schedulers: lr_scheduler.step() - self.updating = False training_stat.train_time = time.time() - start_time return training_stat From e3ba271b66b797ab4c367a2e4d5b9313bf418494 Mon Sep 17 00:00:00 2001 From: Dominik Jain Date: Thu, 20 Mar 2025 17:24:10 +0100 Subject: [PATCH 083/230] v2: Algorithms now internally use a wrapper (Algorithm.Optimizer) around the torch optimizers * This facilitates backpropagation steps with gradient clipping * The implementation ensures that the optimizers' state is handled alongside the model parameters when calling `state_dict` or `load_state_dict` on the `Algorithm` instance. Special handling of optimizer state restoration in examples/tests was thus removed. 
--- CHANGELOG.md | 9 ++- test/continuous/test_ppo.py | 8 +-- test/discrete/test_c51.py | 8 +-- test/discrete/test_rainbow.py | 8 +-- test/modelbased/test_dqn_icm.py | 2 +- test/offline/test_discrete_bcq.py | 8 +-- test/offline/test_gail.py | 8 +-- tianshou/policy/base.py | 82 +++++++++++++++++++++-- tianshou/policy/imitation/base.py | 7 +- tianshou/policy/imitation/bcq.py | 17 ++--- tianshou/policy/imitation/cql.py | 30 +++------ tianshou/policy/imitation/discrete_bcq.py | 4 +- tianshou/policy/imitation/discrete_cql.py | 4 +- tianshou/policy/imitation/discrete_crr.py | 4 +- tianshou/policy/imitation/gail.py | 4 +- tianshou/policy/imitation/td3_bc.py | 4 +- tianshou/policy/modelbased/icm.py | 7 +- tianshou/policy/modelfree/a2c.py | 21 +++--- tianshou/policy/modelfree/bdqn.py | 4 +- tianshou/policy/modelfree/c51.py | 4 +- tianshou/policy/modelfree/ddpg.py | 11 ++- tianshou/policy/modelfree/discrete_sac.py | 13 +--- tianshou/policy/modelfree/dqn.py | 4 +- tianshou/policy/modelfree/fqf.py | 12 ++-- tianshou/policy/modelfree/iqn.py | 4 +- tianshou/policy/modelfree/npg.py | 4 +- tianshou/policy/modelfree/pg.py | 4 +- tianshou/policy/modelfree/ppo.py | 12 +--- tianshou/policy/modelfree/qrdqn.py | 4 +- tianshou/policy/modelfree/redq.py | 8 +-- tianshou/policy/modelfree/sac.py | 4 +- tianshou/policy/modelfree/td3.py | 4 +- tianshou/policy/modelfree/trpo.py | 4 +- 33 files changed, 156 insertions(+), 175 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 5ef061ef7..c896a23eb 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -91,7 +91,14 @@ for continuous action spaces (e.g. relevant to `DDPG`, `SAC` and `REDQ`). * Multi-agent RL methods are now differentiated by the type of the sub-algorithms being employed (`MultiAgentOnPolicyAlgorithm`, `MultiAgentOffPolicyAlgorithm`), which renders all interfaces clean. - Helper class `MARLDispatcher` has been factored out to manage the dispatching of data to the respective agents. 
+ Helper class `MARLDispatcher` has been factored out to manage the dispatching of data to the respective agents. + * Algorithms now internally use a wrapper (`Algorithm.Optimizer`) around the optimizers; creation is handled + by method `_create_optimizer`. + * This facilitates backpropagation steps with gradient clipping. + * The optimizers of an Algorithm instance are now centrally tracked, such that we can ensure that the + optimizers' states are handled alongside the model parameters when calling `state_dict` or `load_state_dict` + on the `Algorithm` instance. + Special handling of the restoration of optimizers' state dicts was thus removed from examples and tests. * Fixed issues in the class hierarchy (particularly critical violations of the Liskov substitution principle): * Introduced base classes (to retain factorization without abusive inheritance): * `ActorCriticOnPolicyAlgorithm` diff --git a/test/continuous/test_ppo.py b/test/continuous/test_ppo.py index 3c6edf6c3..6d94f13a2 100644 --- a/test/continuous/test_ppo.py +++ b/test/continuous/test_ppo.py @@ -149,10 +149,7 @@ def save_checkpoint_fn(epoch: int, env_step: int, gradient_step: int) -> str: # Example: saving by epoch num # ckpt_path = os.path.join(log_path, f"checkpoint_{epoch}.pth") torch.save( - { - "model": algorithm.state_dict(), - "optim": optim.state_dict(), - }, + algorithm.state_dict(), ckpt_path, ) return ckpt_path @@ -163,8 +160,7 @@ def save_checkpoint_fn(epoch: int, env_step: int, gradient_step: int) -> str: ckpt_path = os.path.join(log_path, "checkpoint.pth") if os.path.exists(ckpt_path): checkpoint = torch.load(ckpt_path, map_location=args.device) - algorithm.load_state_dict(checkpoint["model"]) - optim.load_state_dict(checkpoint["optim"]) + algorithm.load_state_dict(checkpoint) print("Successfully restore policy and optim.") else: print("Fail to restore policy and optim.") diff --git a/test/discrete/test_c51.py b/test/discrete/test_c51.py index 240650c41..56ca3ceb4 100644 --- 
a/test/discrete/test_c51.py +++ b/test/discrete/test_c51.py @@ -158,10 +158,7 @@ def save_checkpoint_fn(epoch: int, env_step: int, gradient_step: int) -> str: # Example: saving by epoch num # ckpt_path = os.path.join(log_path, f"checkpoint_{epoch}.pth") torch.save( - { - "model": algorithm.state_dict(), - "optim": optim.state_dict(), - }, + algorithm.state_dict(), ckpt_path, ) buffer_path = os.path.join(log_path, "train_buffer.pkl") @@ -175,8 +172,7 @@ def save_checkpoint_fn(epoch: int, env_step: int, gradient_step: int) -> str: ckpt_path = os.path.join(log_path, "checkpoint.pth") if os.path.exists(ckpt_path): checkpoint = torch.load(ckpt_path, map_location=args.device) - algorithm.load_state_dict(checkpoint["model"]) - algorithm.optim.load_state_dict(checkpoint["optim"]) + algorithm.load_state_dict(checkpoint) print("Successfully restore policy and optim.") else: print("Fail to restore policy and optim.") diff --git a/test/discrete/test_rainbow.py b/test/discrete/test_rainbow.py index 2dd342bf4..4877ef392 100644 --- a/test/discrete/test_rainbow.py +++ b/test/discrete/test_rainbow.py @@ -177,10 +177,7 @@ def save_checkpoint_fn(epoch: int, env_step: int, gradient_step: int) -> str: # Example: saving by epoch num # ckpt_path = os.path.join(log_path, f"checkpoint_{epoch}.pth") torch.save( - { - "model": algorithm.state_dict(), - "optim": algorithm.optim.state_dict(), - }, + algorithm.state_dict(), ckpt_path, ) buffer_path = os.path.join(log_path, "train_buffer.pkl") @@ -194,8 +191,7 @@ def save_checkpoint_fn(epoch: int, env_step: int, gradient_step: int) -> str: ckpt_path = os.path.join(log_path, "checkpoint.pth") if os.path.exists(ckpt_path): checkpoint = torch.load(ckpt_path, map_location=args.device) - algorithm.load_state_dict(checkpoint["model"]) - algorithm.optim.load_state_dict(checkpoint["optim"]) + algorithm.load_state_dict(checkpoint) print("Successfully restore policy and optim.") else: print("Fail to restore policy and optim.") diff --git 
a/test/modelbased/test_dqn_icm.py b/test/modelbased/test_dqn_icm.py index cedf3bd16..d1e8c5839 100644 --- a/test/modelbased/test_dqn_icm.py +++ b/test/modelbased/test_dqn_icm.py @@ -168,7 +168,7 @@ def test_dqn_icm(args: argparse.Namespace = get_args()) -> None: writer = SummaryWriter(log_path) logger = TensorboardLogger(writer) - def save_best_fn(policy: icm_algorithm) -> None: + def save_best_fn(policy: ICMOffPolicyWrapper) -> None: torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth")) def stop_fn(mean_rewards: float) -> bool: diff --git a/test/offline/test_discrete_bcq.py b/test/offline/test_discrete_bcq.py index 2a94f0ed8..27f012712 100644 --- a/test/offline/test_discrete_bcq.py +++ b/test/offline/test_discrete_bcq.py @@ -135,10 +135,7 @@ def save_checkpoint_fn(epoch: int, env_step: int, gradient_step: int) -> str: # Example: saving by epoch num # ckpt_path = os.path.join(log_path, f"checkpoint_{epoch}.pth") torch.save( - { - "model": algorithm.state_dict(), - "optim": optim.state_dict(), - }, + algorithm.state_dict(), ckpt_path, ) return ckpt_path @@ -149,8 +146,7 @@ def save_checkpoint_fn(epoch: int, env_step: int, gradient_step: int) -> str: ckpt_path = os.path.join(log_path, "checkpoint.pth") if os.path.exists(ckpt_path): checkpoint = torch.load(ckpt_path, map_location=args.device) - algorithm.load_state_dict(checkpoint["model"]) - optim.load_state_dict(checkpoint["optim"]) + algorithm.load_state_dict(checkpoint) print("Successfully restore policy and optim.") else: print("Fail to restore policy and optim.") diff --git a/test/offline/test_gail.py b/test/offline/test_gail.py index dd975284f..8d0b0dc59 100644 --- a/test/offline/test_gail.py +++ b/test/offline/test_gail.py @@ -184,10 +184,7 @@ def save_checkpoint_fn(epoch: int, env_step: int, gradient_step: int) -> str: # Example: saving by epoch num # ckpt_path = os.path.join(log_path, f"checkpoint_{epoch}.pth") torch.save( - { - "model": algorithm.state_dict(), - "optim": 
algorithm.optim.state_dict(), - }, + algorithm.state_dict(), ckpt_path, ) return ckpt_path @@ -198,8 +195,7 @@ def save_checkpoint_fn(epoch: int, env_step: int, gradient_step: int) -> str: ckpt_path = os.path.join(log_path, "checkpoint.pth") if os.path.exists(ckpt_path): checkpoint = torch.load(ckpt_path, map_location=args.device) - algorithm.load_state_dict(checkpoint["model"]) - algorithm.optim.load_state_dict(checkpoint["optim"]) + algorithm.load_state_dict(checkpoint) print("Successfully restore policy and optim.") else: print("Fail to restore policy and optim.") diff --git a/tianshou/policy/base.py b/tianshou/policy/base.py index a00f6d594..edebafac7 100644 --- a/tianshou/policy/base.py +++ b/tianshou/policy/base.py @@ -1,7 +1,7 @@ import logging import time from abc import ABC, abstractmethod -from collections.abc import Callable +from collections.abc import Callable, Mapping from dataclasses import dataclass, field from typing import TYPE_CHECKING, Any, Generic, Literal, TypeVar, cast @@ -474,6 +474,8 @@ class Algorithm(torch.nn.Module, Generic[TPolicy, TTrainerParams, TTrainingStats policy.load_state_dict(torch.load("policy.pth")) """ + _STATE_DICT_KEY_OPTIMIZERS = "_optimizers" + def __init__( self, *, @@ -483,14 +485,86 @@ def __init__( super().__init__() self.policy: TPolicy = policy self.lr_schedulers: list[LRScheduler] = [] + self._optimizers: list["Algorithm.Optimizer"] = [] + """ + list of optimizers associated with the algorithm (created via `_create_optimizer`), + whose states will be returned when calling `state_dict` and which will be restored + when calling `load_state_dict` accordingly + """ + + class Optimizer: + """Wrapper for a torch optimizer that optionally performs gradient clipping.""" + + def __init__( + self, + optim: torch.optim.Optimizer, + module: torch.nn.Module, + max_grad_norm: float | None = None, + ) -> None: + """ + :param optim: the optimizer + :param module: the module whose parameters are being affected by `optim` + :param 
max_grad_norm: the maximum gradient norm for gradient clipping; if None, do not apply gradient clipping + """ + super().__init__() + self._optim = optim + self._module = module + self._max_grad_norm = max_grad_norm + + def step( + self, loss: torch.Tensor, retain_graph: bool | None = None, create_graph: bool = False + ) -> None: + """Performs an optimizer step, optionally applying gradient clipping (if configured at construction). + + :param loss: the loss to backpropagate + :param retain_graph: passed on to `backward` + :param create_graph: passed on to `backward` + """ + self._optim.zero_grad() + loss.backward(retain_graph=retain_graph, create_graph=create_graph) + if self._max_grad_norm is not None: + nn.utils.clip_grad_norm_(self._module.parameters(), max_norm=self._max_grad_norm) + self._optim.step() + + def state_dict(self) -> dict: + """Returns the `state_dict` of the wrapped optimizer.""" + return self._optim.state_dict() + + def load_state_dict(self, state_dict: dict) -> None: + """Loads the given `state_dict` into the wrapped optimizer.""" + self._optim.load_state_dict(state_dict) def _create_optimizer( - self, module: torch.nn.Module, factory: OptimizerFactory - ) -> torch.optim.Optimizer: + self, + module: torch.nn.Module, + factory: OptimizerFactory, + max_grad_norm: float | None = None, + ) -> Optimizer: optimizer, lr_scheduler = factory.create_instances(module) if lr_scheduler is not None: self.lr_schedulers.append(lr_scheduler) - return optimizer + optim = self.Optimizer(optimizer, module, max_grad_norm=max_grad_norm) + self._optimizers.append(optim) + return optim + + def state_dict(self, *args, destination=None, prefix="", keep_vars=False): # type: ignore + d = super().state_dict(*args, destination=destination, prefix=prefix, keep_vars=keep_vars) + + # add optimizer states + assert self._STATE_DICT_KEY_OPTIMIZERS not in d + d[self._STATE_DICT_KEY_OPTIMIZERS] = [o.state_dict() for o in self._optimizers] + + return d + + def load_state_dict( + self, 
state_dict: Mapping[str, Any], strict: bool = True, assign: bool = False + ) -> None: + # restore optimizer states + optimizers_state_dict = state_dict.pop(self._STATE_DICT_KEY_OPTIMIZERS) + for optim, optim_state in zip(self._optimizers, optimizers_state_dict, strict=True): + optim.load_state_dict(optim_state) + + super().load_state_dict(state_dict, strict=strict, assign=assign) def preprocess_batch( self, diff --git a/tianshou/policy/imitation/base.py b/tianshou/policy/imitation/base.py index ddf9231ed..a76806174 100644 --- a/tianshou/policy/imitation/base.py +++ b/tianshou/policy/imitation/base.py @@ -13,6 +13,7 @@ ObsBatchProtocol, RolloutBatchProtocol, ) +from tianshou.policy import Algorithm from tianshou.policy.base import ( OfflineAlgorithm, OffPolicyAlgorithm, @@ -91,9 +92,8 @@ def _imitation_update( self, batch: RolloutBatchProtocol, policy: ImitationPolicy, - optim: torch.optim.Optimizer, + optim: Algorithm.Optimizer, ) -> ImitationTrainingStats: - optim.zero_grad() if policy.action_type == "continuous": # regression act = policy(batch).act act_target = to_torch(batch.act, dtype=torch.float32, device=act.device) @@ -104,8 +104,7 @@ def _imitation_update( loss = F.nll_loss(act, act_target) else: raise ValueError(policy.action_type) - loss.backward() - optim.step() + optim.step(loss) return ImitationTrainingStats(loss=loss.item()) diff --git a/tianshou/policy/imitation/bcq.py b/tianshou/policy/imitation/bcq.py index 58a841ba8..d1b2300fb 100644 --- a/tianshou/policy/imitation/bcq.py +++ b/tianshou/policy/imitation/bcq.py @@ -172,9 +172,7 @@ def _update_with_batch( KL_loss = (-torch.log(std) + (std.pow(2) + mean.pow(2) - 1) / 2).mean() vae_loss = recon_loss + KL_loss / 2 - self.vae_optim.zero_grad() - vae_loss.backward() - self.vae_optim.step() + self.vae_optim.step(vae_loss) # critic training: with torch.no_grad(): @@ -210,13 +208,8 @@ def _update_with_batch( critic1_loss = F.mse_loss(current_Q1, target_Q) critic2_loss = F.mse_loss(current_Q2, target_Q) - - 
self.critic_optim.zero_grad() - self.critic2_optim.zero_grad() - critic1_loss.backward() - critic2_loss.backward() - self.critic_optim.step() - self.critic2_optim.step() + self.critic_optim.step(critic1_loss) + self.critic2_optim.step(critic2_loss) sampled_act = self.policy.vae.decode(obs) perturbed_act = self.policy.actor_perturbation(obs, sampled_act) @@ -224,9 +217,7 @@ def _update_with_batch( # max actor_loss = -self.policy.critic(obs, perturbed_act).mean() - self.actor_perturbation_optim.zero_grad() - actor_loss.backward() - self.actor_perturbation_optim.step() + self.actor_perturbation_optim.step(actor_loss) # update target networks self._update_lagged_network_weights() diff --git a/tianshou/policy/imitation/cql.py b/tianshou/policy/imitation/cql.py index 376eaa032..eef500c13 100644 --- a/tianshou/policy/imitation/cql.py +++ b/tianshou/policy/imitation/cql.py @@ -6,7 +6,6 @@ import torch import torch.nn.functional as F from overrides import override -from torch.nn.utils import clip_grad_norm_ from tianshou.data import Batch, ReplayBuffer, to_torch from tianshou.data.buffer.base import TBuffer @@ -101,9 +100,13 @@ def __init__( self.policy_optim = self._create_optimizer(self.policy, policy_optim) self.critic = critic - self.critic_optim = self._create_optimizer(self.critic, critic_optim) + self.critic_optim = self._create_optimizer( + self.critic, critic_optim, max_grad_norm=clip_grad + ) self.critic2 = critic2 or deepcopy(critic) - self.critic2_optim = self._create_optimizer(self.critic2, critic2_optim or critic_optim) + self.critic2_optim = self._create_optimizer( + self.critic2, critic2_optim or critic_optim, max_grad_norm=clip_grad + ) self.critic_old = self._add_lagged_network(self.critic) self.critic2_old = self._add_lagged_network(self.critic2) @@ -117,6 +120,7 @@ def __init__( self.cql_weight = cql_weight self.cql_log_alpha = torch.tensor([0.0], requires_grad=True) + # TODO: Use an OptimizerFactory? 
self.cql_alpha_optim = torch.optim.Adam([self.cql_log_alpha], lr=cql_alpha_lr) self.cql_log_alpha = self.cql_log_alpha.to(device) @@ -127,7 +131,6 @@ def __init__( self.alpha_min = alpha_min self.alpha_max = alpha_max - self.clip_grad = clip_grad self.calibrated = calibrated @@ -204,9 +207,7 @@ def _update_with_batch(self, batch: RolloutBatchProtocol) -> TCQLTrainingStats: # compute actor loss and update actor actor_loss, log_pi = self._calc_policy_loss(obs) - self.policy_optim.zero_grad() - actor_loss.backward() - self.policy_optim.step() + self.policy_optim.step(actor_loss) entropy = -log_pi.detach() alpha_loss = self.alpha.update(entropy) @@ -316,18 +317,9 @@ def _update_with_batch(self, batch: RolloutBatchProtocol) -> TCQLTrainingStats: critic1_loss = critic1_loss + cql1_scaled_loss critic2_loss = critic2_loss + cql2_scaled_loss - # update critic - self.critic_optim.zero_grad() - critic1_loss.backward(retain_graph=True) - # clip grad, prevent the vanishing gradient problem - # It doesn't seem necessary - clip_grad_norm_(self.critic.parameters(), self.clip_grad) - self.critic_optim.step() - - self.critic2_optim.zero_grad() - critic2_loss.backward() - clip_grad_norm_(self.critic2.parameters(), self.clip_grad) - self.critic2_optim.step() + # update critics + self.critic_optim.step(critic1_loss, retain_graph=True) + self.critic2_optim.step(critic2_loss) self._update_lagged_network_weights() diff --git a/tianshou/policy/imitation/discrete_bcq.py b/tianshou/policy/imitation/discrete_bcq.py index b0ebf94dc..6f38f067f 100644 --- a/tianshou/policy/imitation/discrete_bcq.py +++ b/tianshou/policy/imitation/discrete_bcq.py @@ -204,9 +204,7 @@ def _update_with_batch( reg_loss = imitation_logits.pow(2).mean() loss = q_loss + i_loss + self._weight_reg * reg_loss - self.optim.zero_grad() - loss.backward() - self.optim.step() + self.optim.step(loss) return DiscreteBCQTrainingStats( # type: ignore[return-value] loss=loss.item(), diff --git 
a/tianshou/policy/imitation/discrete_cql.py b/tianshou/policy/imitation/discrete_cql.py index 6d6de8f00..695ddced1 100644 --- a/tianshou/policy/imitation/discrete_cql.py +++ b/tianshou/policy/imitation/discrete_cql.py @@ -71,7 +71,6 @@ def _update_with_batch( batch: RolloutBatchProtocol, ) -> TDiscreteCQLTrainingStats: self._periodically_update_lagged_network_weights() - self.optim.zero_grad() weight = batch.pop("weight", 1.0) all_dist = self.policy(batch).logits act = to_torch(batch.act, dtype=torch.long, device=all_dist.device) @@ -94,8 +93,7 @@ def _update_with_batch( negative_sampling = q.logsumexp(1).mean() min_q_loss = negative_sampling - dataset_expec loss = qr_loss + min_q_loss * self.min_q_weight - loss.backward() - self.optim.step() + self.optim.step(loss) return DiscreteCQLTrainingStats( # type: ignore[return-value] loss=loss.item(), diff --git a/tianshou/policy/imitation/discrete_crr.py b/tianshou/policy/imitation/discrete_crr.py index 7eae67d67..a0626e4c8 100644 --- a/tianshou/policy/imitation/discrete_crr.py +++ b/tianshou/policy/imitation/discrete_crr.py @@ -113,7 +113,6 @@ def _update_with_batch( # type: ignore ) -> TDiscreteCRRTrainingStats: if self._target and self._iter % self._freq == 0: self._update_lagged_network_weights() - self.optim.zero_grad() q_t = self.critic(batch.obs) act = to_torch(batch.act, dtype=torch.long, device=q_t.device) qa_t = q_t.gather(1, act.unsqueeze(1)) @@ -142,8 +141,7 @@ def _update_with_batch( # type: ignore # CQL loss/regularizer min_q_loss = (q_t.logsumexp(1) - qa_t).mean() loss = actor_loss + critic_loss + self._min_q_weight * min_q_loss - loss.backward() - self.optim.step() + self.optim.step(loss) self._iter += 1 return DiscreteCRRTrainingStats( # type: ignore[return-value] diff --git a/tianshou/policy/imitation/gail.py b/tianshou/policy/imitation/gail.py index b6341a3a7..e6a27a5fd 100644 --- a/tianshou/policy/imitation/gail.py +++ b/tianshou/policy/imitation/gail.py @@ -150,9 +150,7 @@ def _update_with_batch( # 
type: ignore loss_pi = -F.logsigmoid(-logits_pi).mean() loss_exp = -F.logsigmoid(logits_exp).mean() loss_disc = loss_pi + loss_exp - self.disc_optim.zero_grad() - loss_disc.backward() - self.disc_optim.step() + self.disc_optim.step(loss_disc) losses.append(loss_disc.item()) acc_pis.append((logits_pi < 0).float().mean().item()) acc_exps.append((logits_exp > 0).float().mean().item()) diff --git a/tianshou/policy/imitation/td3_bc.py b/tianshou/policy/imitation/td3_bc.py index 211a160c9..9f9f0c446 100644 --- a/tianshou/policy/imitation/td3_bc.py +++ b/tianshou/policy/imitation/td3_bc.py @@ -96,10 +96,8 @@ def _update_with_batch(self, batch: RolloutBatchProtocol) -> TTD3BCTrainingStats q_value = self.critic(batch.obs, act) lmbda = self.alpha / q_value.abs().mean().detach() actor_loss = -lmbda * q_value.mean() + F.mse_loss(act, to_torch_as(batch.act, act)) - self.policy_optim.zero_grad() - actor_loss.backward() self._last = actor_loss.item() - self.policy_optim.step() + self.policy_optim.step(actor_loss) self._update_lagged_network_weights() self._cnt += 1 diff --git a/tianshou/policy/modelbased/icm.py b/tianshou/policy/modelbased/icm.py index 5da69d05f..41ae3d4ff 100644 --- a/tianshou/policy/modelbased/icm.py +++ b/tianshou/policy/modelbased/icm.py @@ -5,6 +5,7 @@ from tianshou.data import Batch, ReplayBuffer, to_numpy, to_torch from tianshou.data.batch import BatchProtocol from tianshou.data.types import RolloutBatchProtocol +from tianshou.policy import Algorithm from tianshou.policy.base import ( OffPolicyAlgorithm, OffPolicyWrapperAlgorithm, @@ -41,7 +42,7 @@ def __init__( self, *, model: IntrinsicCuriosityModule, - optim: torch.optim.Optimizer, + optim: Algorithm.Optimizer, lr_scale: float, reward_scale: float, forward_loss_weight: float, @@ -76,7 +77,6 @@ def _icm_update( batch: RolloutBatchProtocol, original_stats: TrainingStats, ) -> ICMTrainingStats: - self.optim.zero_grad() act_hat = batch.policy.act_hat act = to_torch(batch.act, dtype=torch.long, 
device=act_hat.device) inverse_loss = F.cross_entropy(act_hat, act).mean() @@ -84,8 +84,7 @@ def _icm_update( loss = ( (1 - self.forward_loss_weight) * inverse_loss + self.forward_loss_weight * forward_loss ) * self.lr_scale - loss.backward() - self.optim.step() + self.optim.step(loss) return ICMTrainingStats( original_stats, diff --git a/tianshou/policy/modelfree/a2c.py b/tianshou/policy/modelfree/a2c.py index aaa5e9dd9..ddf2d3f02 100644 --- a/tianshou/policy/modelfree/a2c.py +++ b/tianshou/policy/modelfree/a2c.py @@ -5,7 +5,6 @@ import numpy as np import torch import torch.nn.functional as F -from torch import nn from tianshou.data import ReplayBuffer, SequenceSummaryStats, to_torch_as from tianshou.data.types import BatchWithAdvantagesProtocol, RolloutBatchProtocol @@ -44,6 +43,7 @@ def __init__( critic: torch.nn.Module | ContinuousCritic | DiscreteCritic, optim: OptimizerFactory, optim_include_actor: bool, + max_grad_norm: float | None = None, gae_lambda: float = 0.95, max_batchsize: int = 256, discount_factor: float = 0.99, @@ -54,6 +54,8 @@ def __init__( :param optim: the optimizer factory. :param optim_include_actor: whether the optimizer shall include the actor network's parameters. Pass False for algorithms that shall update only the critic via the optimizer. + :param max_grad_norm: the maximum gradient norm for gradient clipping; if None, gradient clipping + is not applied :param gae_lambda: in [0, 1], param for generalized advantage estimation (GAE). :param max_batchsize: the maximum size of the batch when computing GAE. :param discount_factor: in [0, 1]. 
@@ -66,11 +68,12 @@ def __init__( assert 0.0 <= gae_lambda <= 1.0, f"GAE lambda should be in [0, 1] but got: {gae_lambda}" self.gae_lambda = gae_lambda self.max_batchsize = max_batchsize - self._actor_critic = ActorCritic(self.policy.actor, self.critic) if optim_include_actor: - self.optim = self._create_optimizer(self._actor_critic, optim) + self.optim = self._create_optimizer( + ActorCritic(self.policy.actor, self.critic), optim, max_grad_norm=max_grad_norm + ) else: - self.optim = self._create_optimizer(self.critic, optim) + self.optim = self._create_optimizer(self.critic, optim, max_grad_norm=max_grad_norm) self.gamma = discount_factor self.rew_norm = reward_normalization self.ret_rms = RunningMeanStd() @@ -152,6 +155,7 @@ def __init__( critic=critic, optim=optim, optim_include_actor=True, + max_grad_norm=max_grad_norm, gae_lambda=gae_lambda, max_batchsize=max_batchsize, discount_factor=discount_factor, @@ -192,14 +196,7 @@ def _update_with_batch( # calculate regularization and overall loss ent_loss = dist.entropy().mean() loss = actor_loss + self.vf_coef * vf_loss - self.ent_coef * ent_loss - self.optim.zero_grad() - loss.backward() - if self.max_grad_norm: # clip large gradient - nn.utils.clip_grad_norm_( - self._actor_critic.parameters(), - max_norm=self.max_grad_norm, - ) - self.optim.step() + self.optim.step(loss) actor_losses.append(actor_loss.item()) vf_losses.append(vf_loss.item()) ent_losses.append(ent_loss.item()) diff --git a/tianshou/policy/modelfree/bdqn.py b/tianshou/policy/modelfree/bdqn.py index 313d201c0..eecdaf88a 100644 --- a/tianshou/policy/modelfree/bdqn.py +++ b/tianshou/policy/modelfree/bdqn.py @@ -185,7 +185,6 @@ def _update_with_batch( batch: RolloutBatchProtocol, ) -> TBDQNTrainingStats: self._periodically_update_lagged_network_weights() - self.optim.zero_grad() weight = batch.pop("weight", 1.0) act = to_torch(batch.act, dtype=torch.long, device=batch.returns.device) q = self.policy(batch).logits @@ -197,7 +196,6 @@ def 
_update_with_batch( td_error = returns - act_q loss = (td_error.pow(2).sum(-1).mean(-1) * weight).mean() batch.weight = td_error.sum(-1).sum(-1) # prio-buffer - loss.backward() - self.optim.step() + self.optim.step(loss) return BDQNTrainingStats(loss=loss.item()) # type: ignore[return-value] diff --git a/tianshou/policy/modelfree/c51.py b/tianshou/policy/modelfree/c51.py index 59f5b8301..87810d9b2 100644 --- a/tianshou/policy/modelfree/c51.py +++ b/tianshou/policy/modelfree/c51.py @@ -122,7 +122,6 @@ def _update_with_batch( batch: RolloutBatchProtocol, ) -> TC51TrainingStats: self._periodically_update_lagged_network_weights() - self.optim.zero_grad() with torch.no_grad(): target_dist = self._target_dist(batch) weight = batch.pop("weight", 1.0) @@ -133,7 +132,6 @@ def _update_with_batch( loss = (cross_entropy * weight).mean() # ref: https://github.com/Kaixhin/Rainbow/blob/master/agent.py L94-100 batch.weight = cross_entropy.detach() # prio-buffer - loss.backward() - self.optim.step() + self.optim.step(loss) return C51TrainingStats(loss=loss.item()) # type: ignore[return-value] diff --git a/tianshou/policy/modelfree/ddpg.py b/tianshou/policy/modelfree/ddpg.py index eb1bb7294..7b74152ac 100644 --- a/tianshou/policy/modelfree/ddpg.py +++ b/tianshou/policy/modelfree/ddpg.py @@ -18,6 +18,7 @@ RolloutBatchProtocol, ) from tianshou.exploration import BaseNoise, GaussianNoise +from tianshou.policy import Algorithm from tianshou.policy.base import ( LaggedNetworkPolyakUpdateAlgorithmMixin, OffPolicyAlgorithm, @@ -221,7 +222,7 @@ def __init__( def _minimize_critic_squared_loss( batch: RolloutBatchProtocol, critic: torch.nn.Module, - optimizer: torch.optim.Optimizer, + optimizer: Algorithm.Optimizer, ) -> tuple[torch.Tensor, torch.Tensor]: """Takes an optimizer step to minimize the squared loss of the critic given a batch of data. 
@@ -235,9 +236,7 @@ def _minimize_critic_squared_loss( target_q = batch.returns.flatten() td = current_q - target_q critic_loss = (td.pow(2) * weight).mean() - optimizer.zero_grad() - critic_loss.backward() - optimizer.step() + optimizer.step(critic_loss) return td, critic_loss def preprocess_batch( @@ -343,9 +342,7 @@ def _update_with_batch(self, batch: RolloutBatchProtocol) -> TDDPGTrainingStats: batch.weight = td # prio-buffer # actor actor_loss = -self.critic(batch.obs, self.policy(batch).act).mean() - self.policy_optim.zero_grad() - actor_loss.backward() - self.policy_optim.step() + self.policy_optim.step(actor_loss) self._update_lagged_network_weights() return DDPGTrainingStats(actor_loss=actor_loss.item(), critic_loss=critic_loss.item()) # type: ignore[return-value] diff --git a/tianshou/policy/modelfree/discrete_sac.py b/tianshou/policy/modelfree/discrete_sac.py index 06c1fd860..d3a0d39d8 100644 --- a/tianshou/policy/modelfree/discrete_sac.py +++ b/tianshou/policy/modelfree/discrete_sac.py @@ -141,19 +141,14 @@ def _update_with_batch(self, batch: RolloutBatchProtocol) -> TDiscreteSACTrainin current_q1 = self.critic(batch.obs).gather(1, act).flatten() td1 = current_q1 - target_q critic1_loss = (td1.pow(2) * weight).mean() - - self.critic_optim.zero_grad() - critic1_loss.backward() - self.critic_optim.step() + self.critic_optim.step(critic1_loss) # critic 2 current_q2 = self.critic2(batch.obs).gather(1, act).flatten() td2 = current_q2 - target_q critic2_loss = (td2.pow(2) * weight).mean() + self.critic2_optim.step(critic2_loss) - self.critic2_optim.zero_grad() - critic2_loss.backward() - self.critic2_optim.step() batch.weight = (td1 + td2) / 2.0 # prio-buffer # actor @@ -164,9 +159,7 @@ def _update_with_batch(self, batch: RolloutBatchProtocol) -> TDiscreteSACTrainin current_q2a = self.critic2(batch.obs) q = torch.min(current_q1a, current_q2a) actor_loss = -(self.alpha.value * entropy + (dist.probs * q).sum(dim=-1)).mean() - self.policy_optim.zero_grad() - 
actor_loss.backward() - self.policy_optim.step() + self.policy_optim.step(actor_loss) alpha_loss = self.alpha.update(entropy.detach()) diff --git a/tianshou/policy/modelfree/dqn.py b/tianshou/policy/modelfree/dqn.py index 54ce79c97..be61674ad 100644 --- a/tianshou/policy/modelfree/dqn.py +++ b/tianshou/policy/modelfree/dqn.py @@ -308,7 +308,6 @@ def _update_with_batch( batch: RolloutBatchProtocol, ) -> TDQNTrainingStats: self._periodically_update_lagged_network_weights() - self.optim.zero_grad() weight = batch.pop("weight", 1.0) q = self.policy(batch).logits q = q[np.arange(len(q)), batch.act] @@ -323,7 +322,6 @@ def _update_with_batch( loss = (td_error.pow(2) * weight).mean() batch.weight = td_error # prio-buffer - loss.backward() - self.optim.step() + self.optim.step(loss) return DQNTrainingStats(loss=loss.item()) # type: ignore[return-value] diff --git a/tianshou/policy/modelfree/fqf.py b/tianshou/policy/modelfree/fqf.py index 9c42e25ec..0ef47f93e 100644 --- a/tianshou/policy/modelfree/fqf.py +++ b/tianshou/policy/modelfree/fqf.py @@ -9,7 +9,7 @@ from tianshou.data import Batch, ReplayBuffer, to_numpy from tianshou.data.types import FQFBatchProtocol, ObsBatchProtocol, RolloutBatchProtocol -from tianshou.policy import QRDQN +from tianshou.policy import QRDQN, Algorithm from tianshou.policy.modelfree.dqn import DQNPolicy from tianshou.policy.modelfree.qrdqn import QRDQNPolicy, QRDQNTrainingStats from tianshou.policy.optim import OptimizerFactory @@ -139,7 +139,7 @@ def __init__( self.fraction_optim = self._create_optimizer(self.policy.fraction_model, fraction_optim) @override - def _create_policy_optimizer(self, optim: OptimizerFactory) -> torch.optim.Optimizer: + def _create_policy_optimizer(self, optim: OptimizerFactory) -> Algorithm.Optimizer: # Override to leave out the fraction model (use main model only), as we want # to use a separate optimizer for the fraction model return self._create_optimizer(self.policy.model, optim) @@ -214,12 +214,8 @@ def 
_update_with_batch( # calculate entropy loss entropy_loss = out.fractions.entropies.mean() fraction_entropy_loss = fraction_loss - self.ent_coef * entropy_loss - self.fraction_optim.zero_grad() - fraction_entropy_loss.backward(retain_graph=True) - self.fraction_optim.step() - self.optim.zero_grad() - quantile_loss.backward() - self.optim.step() + self.fraction_optim.step(fraction_entropy_loss, retain_graph=True) + self.optim.step(quantile_loss) return FQFTrainingStats( # type: ignore[return-value] loss=quantile_loss.item() + fraction_entropy_loss.item(), diff --git a/tianshou/policy/modelfree/iqn.py b/tianshou/policy/modelfree/iqn.py index 3f7e4f7f1..652a2d422 100644 --- a/tianshou/policy/modelfree/iqn.py +++ b/tianshou/policy/modelfree/iqn.py @@ -129,7 +129,6 @@ def _update_with_batch( batch: RolloutBatchProtocol, ) -> TIQNTrainingStats: self._periodically_update_lagged_network_weights() - self.optim.zero_grad() weight = batch.pop("weight", 1.0) action_batch = self.policy(batch) curr_dist, taus = action_batch.logits, action_batch.taus @@ -150,7 +149,6 @@ def _update_with_batch( # ref: https://github.com/ku2482/fqf-iqn-qrdqn.pytorch/ # blob/master/fqf_iqn_qrdqn/agent/qrdqn_agent.py L130 batch.weight = dist_diff.detach().abs().sum(-1).mean(1) # prio-buffer - loss.backward() - self.optim.step() + self.optim.step(loss) return IQNTrainingStats(loss=loss.item()) # type: ignore[return-value] diff --git a/tianshou/policy/modelfree/npg.py b/tianshou/policy/modelfree/npg.py index 05d1dc532..9077f407f 100644 --- a/tianshou/policy/modelfree/npg.py +++ b/tianshou/policy/modelfree/npg.py @@ -137,9 +137,7 @@ def _update_with_batch( # type: ignore for _ in range(self.optim_critic_iters): value = self.critic(minibatch.obs).flatten() vf_loss = F.mse_loss(minibatch.returns, value) - self.optim.zero_grad() - vf_loss.backward() - self.optim.step() + self.optim.step(vf_loss) actor_losses.append(actor_loss.item()) vf_losses.append(vf_loss.item()) diff --git 
a/tianshou/policy/modelfree/pg.py b/tianshou/policy/modelfree/pg.py index 2c2a5f7e7..6ba320046 100644 --- a/tianshou/policy/modelfree/pg.py +++ b/tianshou/policy/modelfree/pg.py @@ -297,15 +297,13 @@ def _update_with_batch( # type: ignore split_batch_size = batch_size or -1 for _ in range(repeat): for minibatch in batch.split(split_batch_size, merge_last=True): - self.optim.zero_grad() result = self.policy(minibatch) dist = result.dist act = to_torch_as(minibatch.act, result.act) ret = to_torch(minibatch.returns, torch.float, result.act.device) log_prob = dist.log_prob(act).reshape(len(ret), -1).transpose(0, 1) loss = -(log_prob * ret).mean() - loss.backward() - self.optim.step() + self.optim.step(loss) losses.append(loss.item()) loss_summary_stat = SequenceSummaryStats.from_sequence(losses) diff --git a/tianshou/policy/modelfree/ppo.py b/tianshou/policy/modelfree/ppo.py index 53a699c9c..5f637ce42 100644 --- a/tianshou/policy/modelfree/ppo.py +++ b/tianshou/policy/modelfree/ppo.py @@ -4,7 +4,6 @@ import numpy as np import torch -from torch import nn from tianshou.data import ReplayBuffer, SequenceSummaryStats, to_torch_as from tianshou.data.types import LogpOldProtocol, RolloutBatchProtocol @@ -12,7 +11,6 @@ from tianshou.policy.base import TrainingStats from tianshou.policy.modelfree.pg import ActorPolicy from tianshou.policy.optim import OptimizerFactory -from tianshou.utils.net.common import ActorCritic from tianshou.utils.net.continuous import ContinuousCritic from tianshou.utils.net.discrete import DiscreteCritic @@ -120,7 +118,6 @@ def __init__( self.value_clip = value_clip self.norm_adv = advantage_normalization self.recompute_adv = recompute_advantage - self._actor_critic: ActorCritic def preprocess_batch( self, @@ -185,14 +182,7 @@ def _update_with_batch( # calculate regularization and overall loss ent_loss = dist.entropy().mean() loss = clip_loss + self.vf_coef * vf_loss - self.ent_coef * ent_loss - self.optim.zero_grad() - loss.backward() - if 
self.max_grad_norm: # clip large gradient - nn.utils.clip_grad_norm_( - self._actor_critic.parameters(), - max_norm=self.max_grad_norm, - ) - self.optim.step() + self.optim.step(loss) clip_losses.append(clip_loss.item()) vf_losses.append(vf_loss.item()) ent_losses.append(ent_loss.item()) diff --git a/tianshou/policy/modelfree/qrdqn.py b/tianshou/policy/modelfree/qrdqn.py index 47179d033..460f6b812 100644 --- a/tianshou/policy/modelfree/qrdqn.py +++ b/tianshou/policy/modelfree/qrdqn.py @@ -97,7 +97,6 @@ def _update_with_batch( batch: RolloutBatchProtocol, ) -> TQRDQNTrainingStats: self._periodically_update_lagged_network_weights() - self.optim.zero_grad() weight = batch.pop("weight", 1.0) curr_dist = self.policy(batch).logits act = batch.act @@ -114,7 +113,6 @@ def _update_with_batch( # ref: https://github.com/ku2482/fqf-iqn-qrdqn.pytorch/ # blob/master/fqf_iqn_qrdqn/agent/qrdqn_agent.py L130 batch.weight = dist_diff.detach().abs().sum(-1).mean(1) # prio-buffer - loss.backward() - self.optim.step() + self.optim.step(loss) return QRDQNTrainingStats(loss=loss.item()) # type: ignore[return-value] diff --git a/tianshou/policy/modelfree/redq.py b/tianshou/policy/modelfree/redq.py index 401026e41..d0faa5603 100644 --- a/tianshou/policy/modelfree/redq.py +++ b/tianshou/policy/modelfree/redq.py @@ -189,9 +189,7 @@ def _update_with_batch(self, batch: RolloutBatchProtocol) -> TREDQTrainingStats: target_q = batch.returns.flatten() td = current_qs - target_q critic_loss = (td.pow(2) * weight).mean() - self.critic_optim.zero_grad() - critic_loss.backward() - self.critic_optim.step() + self.critic_optim.step(critic_loss) batch.weight = torch.mean(td, dim=0) # prio-buffer self.critic_gradient_step += 1 @@ -202,9 +200,7 @@ def _update_with_batch(self, batch: RolloutBatchProtocol) -> TREDQTrainingStats: a = obs_result.act current_qa = self.critic(batch.obs, a).mean(dim=0).flatten() actor_loss = (self.alpha.value * obs_result.log_prob.flatten() - current_qa).mean() - 
self.policy_optim.zero_grad() - actor_loss.backward() - self.policy_optim.step() + self.policy_optim.step(actor_loss) # The entropy of a Gaussian policy can be expressed as -log_prob + a constant (which we ignore) entropy = -obs_result.log_prob.detach() diff --git a/tianshou/policy/modelfree/sac.py b/tianshou/policy/modelfree/sac.py index 5d95529de..208a065f1 100644 --- a/tianshou/policy/modelfree/sac.py +++ b/tianshou/policy/modelfree/sac.py @@ -278,9 +278,7 @@ def _update_with_batch(self, batch: RolloutBatchProtocol) -> TSACTrainingStats: actor_loss = ( self.alpha.value * obs_result.log_prob.flatten() - torch.min(current_q1a, current_q2a) ).mean() - self.policy_optim.zero_grad() - actor_loss.backward() - self.policy_optim.step() + self.policy_optim.step(actor_loss) # The entropy of a Gaussian policy can be expressed as -log_prob + a constant (which we ignore) entropy = -obs_result.log_prob.detach() diff --git a/tianshou/policy/modelfree/td3.py b/tianshou/policy/modelfree/td3.py index c928a9a35..902ca8ced 100644 --- a/tianshou/policy/modelfree/td3.py +++ b/tianshou/policy/modelfree/td3.py @@ -178,10 +178,8 @@ def _update_with_batch(self, batch: RolloutBatchProtocol) -> TTD3TrainingStats: # actor if self._cnt % self.update_actor_freq == 0: actor_loss = -self.critic(batch.obs, self.policy(batch, eps=0.0).act).mean() - self.policy_optim.zero_grad() - actor_loss.backward() self._last = actor_loss.item() - self.policy_optim.step() + self.policy_optim.step(actor_loss) self._update_lagged_network_weights() self._cnt += 1 diff --git a/tianshou/policy/modelfree/trpo.py b/tianshou/policy/modelfree/trpo.py index c3e51ea65..1231a83e0 100644 --- a/tianshou/policy/modelfree/trpo.py +++ b/tianshou/policy/modelfree/trpo.py @@ -151,9 +151,7 @@ def _update_with_batch( # type: ignore for _ in range(self.optim_critic_iters): value = self.critic(minibatch.obs).flatten() vf_loss = F.mse_loss(minibatch.returns, value) - self.optim.zero_grad() - vf_loss.backward() - self.optim.step() + 
self.optim.step(vf_loss) actor_losses.append(actor_loss.item()) vf_losses.append(vf_loss.item()) From c5455b9c2421e49cf975c6e9b6b15aa0089611dc Mon Sep 17 00:00:00 2001 From: Dominik Jain Date: Thu, 20 Mar 2025 22:15:30 +0100 Subject: [PATCH 084/230] v2: Improve handling of epsilon-greedy exploration for discrete Q-learning algorithms * All respective Policy implementations (e.g. `DQNPolicy`, `C51Policy`, etc.) now accept two parameters `eps_training` and `eps_inference`, which allows the training and test collection cases to be sufficiently differentiated and makes the use of callback functions (`train_fn`, `test_fn`) unnecessary if only constants are to be set. * The setter method `set_eps` has been replaced with `set_eps_training` and `set_eps_inference` accordingly. --- CHANGELOG.md | 6 +++ README.md | 28 ++++++------ examples/atari/atari_c51.py | 9 ++-- examples/atari/atari_dqn.py | 9 ++-- examples/atari/atari_fqf.py | 6 +-- examples/atari/atari_iqn.py | 9 ++-- examples/atari/atari_iqn_hl.py | 4 +- examples/atari/atari_qrdqn.py | 9 ++-- examples/atari/atari_rainbow.py | 9 ++-- examples/box2d/acrobot_dualdqn.py | 14 +++--- examples/box2d/bipedal_bdq.py | 11 ++--- examples/box2d/lunarlander_dqn.py | 10 ++--- examples/discrete/discrete_dqn.py | 7 ++- examples/discrete/discrete_dqn_hl.py | 6 +-- examples/offline/atari_cql.py | 2 - examples/vizdoom/vizdoom_c51.py | 9 ++-- test/discrete/test_bdqn.py | 9 ++-- test/discrete/test_c51.py | 13 +++--- test/discrete/test_dqn.py | 17 ++++--- test/discrete/test_drqn.py | 10 +---- test/discrete/test_fqf.py | 13 +++--- test/discrete/test_iqn.py | 13 +++--- test/discrete/test_qrdqn.py | 17 ++++--- test/discrete/test_rainbow.py | 13 +++--- test/modelbased/test_dqn_icm.py | 12 +++-- test/offline/gather_cartpole_data.py | 15 +++---- test/pettingzoo/pistonball.py | 11 +---- test/pettingzoo/tic_tac_toe.py | 11 +---- tianshou/highlevel/algorithm.py | 10 ++++- tianshou/highlevel/params/policy_params.py | 17 +++++++ 
tianshou/highlevel/trainer.py | 18 +++++--- tianshou/policy/modelfree/bdqn.py | 20 ++++++++- tianshou/policy/modelfree/c51.py | 19 +++++++- tianshou/policy/modelfree/dqn.py | 52 ++++++++++++++++++---- tianshou/policy/modelfree/fqf.py | 15 +++++++ tianshou/policy/modelfree/iqn.py | 23 ++++++++++ 36 files changed, 270 insertions(+), 206 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index c896a23eb..5af6ae6e5 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -20,6 +20,12 @@ * The default value for `test_in_train` was changed from True to False (updating all usage sites to explicitly set the parameter), because False is the more natural default, which does not make assumptions about returns/score values computed for the data from a collection step being at all meaningful for early stopping + * The management of epsilon-greedy exploration for discrete Q-learning algorithms has been simplified: + * All respective Policy implementations (e.g. `DQNPolicy`, `C51Policy`, etc.) now accept two parameters + `eps_training` and `eps_inference`, which allows the training and test collection cases to be sufficiently + differentiated and makes the use of callback functions (`train_fn`, `test_fn`) unnecessary if only + constants are to be set. + * The setter method `set_eps` has been replaced with `set_eps_training` and `set_eps_inference` accordingly.
* Further internal changes unlikely to affect usage: * Module `trainer.utils` was removed and the functions therein where moved to class `Trainer` * The two places that collected and evaluated test episodes (`_test_in_train` and `_reset`) in addition to diff --git a/README.md b/README.md index 61439a46d..8db202d57 100644 --- a/README.md +++ b/README.md @@ -388,19 +388,19 @@ Let's train it: ```python result = ts.trainer.OffPolicyTrainer( - policy=policy, - train_collector=train_collector, - test_collector=test_collector, - max_epoch=epoch, - step_per_epoch=step_per_epoch, - step_per_collect=step_per_collect, - episode_per_test=test_num, - batch_size=batch_size, - update_per_step=1 / step_per_collect, - train_fn=lambda epoch, env_step: policy.set_eps(eps_train), - test_fn=lambda epoch, env_step: policy.set_eps(eps_test), - stop_fn=lambda mean_rewards: mean_rewards >= env.spec.reward_threshold, - logger=logger, + policy=policy, + train_collector=train_collector, + test_collector=test_collector, + max_epoch=epoch, + step_per_epoch=step_per_epoch, + step_per_collect=step_per_collect, + episode_per_test=test_num, + batch_size=batch_size, + update_per_step=1 / step_per_collect, + train_fn=lambda epoch, env_step: policy.set_eps_training(eps_train), + test_fn=lambda epoch, env_step: policy.set_eps_inference(eps_test), + stop_fn=lambda mean_rewards: mean_rewards >= env.spec.reward_threshold, + logger=logger, ).run() print(f"Finished training in {result.timing.total_time} seconds") ``` @@ -416,7 +416,7 @@ Watch the agent with 35 FPS: ```python policy.eval() -policy.set_eps(eps_test) +policy.set_eps_inference(eps_test) collector = ts.data.Collector(policy, env, exploration_noise=True) collector.collect(n_episode=1, render=1 / 35) ``` diff --git a/examples/atari/atari_c51.py b/examples/atari/atari_c51.py index fbc485299..1f6301af2 100644 --- a/examples/atari/atari_c51.py +++ b/examples/atari/atari_c51.py @@ -98,6 +98,8 @@ def main(args: argparse.Namespace = get_args()) -> None:
num_atoms=args.num_atoms, v_min=args.v_min, v_max=args.v_max, + eps_training=args.eps_train, + eps_inference=args.eps_test, ) algorithm: C51 = C51( policy=policy, @@ -163,16 +165,12 @@ def train_fn(epoch: int, env_step: int) -> None: eps = args.eps_train - env_step / 1e6 * (args.eps_train - args.eps_train_final) else: eps = args.eps_train_final - policy.set_eps(eps) + policy.set_eps_training(eps) if env_step % 1000 == 0: logger.write("train/env_step", env_step, {"train/eps": eps}) - def test_fn(epoch: int, env_step: int | None) -> None: - policy.set_eps(args.eps_test) - def watch() -> None: print("Setup test envs ...") - policy.set_eps(args.eps_test) test_envs.seed(args.seed) if args.save_buffer_name: print(f"Generate buffer with size {args.buffer_size}") @@ -214,7 +212,6 @@ def watch() -> None: episode_per_test=args.test_num, batch_size=args.batch_size, train_fn=train_fn, - test_fn=test_fn, stop_fn=stop_fn, save_best_fn=save_best_fn, logger=logger, diff --git a/examples/atari/atari_dqn.py b/examples/atari/atari_dqn.py index 96f23a574..251ad4af5 100644 --- a/examples/atari/atari_dqn.py +++ b/examples/atari/atari_dqn.py @@ -114,6 +114,8 @@ def main(args: argparse.Namespace = get_args()) -> None: policy = DQNPolicy( model=net, action_space=env.action_space, + eps_training=args.eps_train, + eps_inference=args.eps_test, ) algorithm: DQN | ICMOffPolicyWrapper algorithm = DQN( @@ -200,13 +202,10 @@ def train_fn(epoch: int, env_step: int) -> None: eps = args.eps_train - env_step / 1e6 * (args.eps_train - args.eps_train_final) else: eps = args.eps_train_final - policy.set_eps(eps) + policy.set_eps_training(eps) if env_step % 1000 == 0: logger.write("train/env_step", env_step, {"train/eps": eps}) - def test_fn(epoch: int, env_step: int | None) -> None: - policy.set_eps(args.eps_test) - def save_checkpoint_fn(epoch: int, env_step: int, gradient_step: int) -> str: # see also: https://pytorch.org/tutorials/beginner/saving_loading_models.html ckpt_path = os.path.join(log_path, 
f"checkpoint_{epoch}.pth") @@ -215,7 +214,6 @@ def save_checkpoint_fn(epoch: int, env_step: int, gradient_step: int) -> str: def watch() -> None: print("Setup test envs ...") - policy.set_eps(args.eps_test) test_envs.seed(args.seed) if args.save_buffer_name: print(f"Generate buffer with size {args.buffer_size}") @@ -258,7 +256,6 @@ def watch() -> None: episode_per_test=args.test_num, batch_size=args.batch_size, train_fn=train_fn, - test_fn=test_fn, stop_fn=stop_fn, save_best_fn=save_best_fn, logger=logger, diff --git a/examples/atari/atari_fqf.py b/examples/atari/atari_fqf.py index f8901583c..578ae89ee 100644 --- a/examples/atari/atari_fqf.py +++ b/examples/atari/atari_fqf.py @@ -177,16 +177,16 @@ def train_fn(epoch: int, env_step: int) -> None: eps = args.eps_train - env_step / 1e6 * (args.eps_train - args.eps_train_final) else: eps = args.eps_train_final - policy.set_eps(eps) + policy.set_eps_training(eps) if env_step % 1000 == 0: logger.write("train/env_step", env_step, {"train/eps": eps}) def test_fn(epoch: int, env_step: int | None) -> None: - policy.set_eps(args.eps_test) + policy.set_eps_training(args.eps_test) def watch() -> None: print("Setup test envs ...") - policy.set_eps(args.eps_test) + policy.set_eps_training(args.eps_test) test_envs.seed(args.seed) if args.save_buffer_name: print(f"Generate buffer with size {args.buffer_size}") diff --git a/examples/atari/atari_iqn.py b/examples/atari/atari_iqn.py index 398031be8..888b61342 100644 --- a/examples/atari/atari_iqn.py +++ b/examples/atari/atari_iqn.py @@ -108,6 +108,8 @@ def main(args: argparse.Namespace = get_args()) -> None: sample_size=args.sample_size, online_sample_size=args.online_sample_size, target_sample_size=args.target_sample_size, + eps_training=args.eps_train, + eps_inference=args.eps_test, ) algorithm: IQN = IQN( policy=policy, @@ -172,17 +174,13 @@ def train_fn(epoch: int, env_step: int) -> None: eps = args.eps_train - env_step / 1e6 * (args.eps_train - args.eps_train_final) else: eps = 
args.eps_train_final - policy.set_eps(eps) + policy.set_eps_training(eps) if env_step % 1000 == 0: logger.write("train/env_step", env_step, {"train/eps": eps}) - def test_fn(epoch: int, env_step: int | None) -> None: - policy.set_eps(args.eps_test) - # watch agent's performance def watch() -> None: print("Setup test envs ...") - policy.set_eps(args.eps_test) test_envs.seed(args.seed) if args.save_buffer_name: print(f"Generate buffer with size {args.buffer_size}") @@ -225,7 +223,6 @@ def watch() -> None: episode_per_test=args.test_num, batch_size=args.batch_size, train_fn=train_fn, - test_fn=test_fn, stop_fn=stop_fn, save_best_fn=save_best_fn, logger=logger, diff --git a/examples/atari/atari_iqn_hl.py b/examples/atari/atari_iqn_hl.py index 4425c2649..e12267010 100644 --- a/examples/atari/atari_iqn_hl.py +++ b/examples/atari/atari_iqn_hl.py @@ -17,7 +17,6 @@ ) from tianshou.highlevel.params.policy_params import IQNParams from tianshou.highlevel.trainer import ( - EpochTestCallbackDQNSetEps, EpochTrainCallbackDQNEpsLinearDecay, ) @@ -85,13 +84,14 @@ def main( target_sample_size=target_sample_size, hidden_sizes=hidden_sizes, num_cosines=num_cosines, + eps_training=eps_train, + eps_inference=eps_test, ), ) .with_preprocess_network_factory(IntermediateModuleFactoryAtariDQN(features_only=True)) .with_epoch_train_callback( EpochTrainCallbackDQNEpsLinearDecay(eps_train, eps_train_final), ) - .with_epoch_test_callback(EpochTestCallbackDQNSetEps(eps_test)) .with_epoch_stop_callback(AtariEpochStopCallback(task)) .build() ) diff --git a/examples/atari/atari_qrdqn.py b/examples/atari/atari_qrdqn.py index 05cff3103..d4364511f 100644 --- a/examples/atari/atari_qrdqn.py +++ b/examples/atari/atari_qrdqn.py @@ -101,6 +101,8 @@ def main(args: argparse.Namespace = get_args()) -> None: policy = QRDQNPolicy( model=net, action_space=env.action_space, + eps_training=args.eps_train, + eps_inference=args.eps_test, ) algorithm: QRDQN = QRDQN( policy=policy, @@ -166,17 +168,13 @@ def 
train_fn(epoch: int, env_step: int) -> None: eps = args.eps_train - env_step / 1e6 * (args.eps_train - args.eps_train_final) else: eps = args.eps_train_final - policy.set_eps(eps) + policy.set_eps_training(eps) if env_step % 1000 == 0: logger.write("train/env_step", env_step, {"train/eps": eps}) - def test_fn(epoch: int, env_step: int | None) -> None: - policy.set_eps(args.eps_test) - # watch agent's performance def watch() -> None: print("Setup test envs ...") - policy.set_eps(args.eps_test) test_envs.seed(args.seed) if args.save_buffer_name: print(f"Generate buffer with size {args.buffer_size}") @@ -219,7 +217,6 @@ def watch() -> None: episode_per_test=args.test_num, batch_size=args.batch_size, train_fn=train_fn, - test_fn=test_fn, stop_fn=stop_fn, save_best_fn=save_best_fn, logger=logger, diff --git a/examples/atari/atari_rainbow.py b/examples/atari/atari_rainbow.py index 22a034394..64757c922 100644 --- a/examples/atari/atari_rainbow.py +++ b/examples/atari/atari_rainbow.py @@ -121,6 +121,8 @@ def test_rainbow(args: argparse.Namespace = get_args()) -> None: num_atoms=args.num_atoms, v_min=args.v_min, v_max=args.v_max, + eps_training=args.eps_train, + eps_inference=args.eps_test, ) optim = AdamOptimizerFactory(lr=args.lr) algorithm: C51 = RainbowDQN( @@ -200,7 +202,7 @@ def train_fn(epoch: int, env_step: int) -> None: eps = args.eps_train - env_step / 1e6 * (args.eps_train - args.eps_train_final) else: eps = args.eps_train_final - policy.set_eps(eps) + policy.set_eps_training(eps) if env_step % 1000 == 0: logger.write("train/env_step", env_step, {"train/eps": eps}) if not args.no_priority: @@ -212,12 +214,8 @@ def train_fn(epoch: int, env_step: int) -> None: if env_step % 1000 == 0: logger.write("train/env_step", env_step, {"train/beta": beta}) - def test_fn(epoch: int, env_step: int | None) -> None: - policy.set_eps(args.eps_test) - def watch() -> None: print("Setup test envs ...") - policy.set_eps(args.eps_test) test_envs.seed(args.seed) if 
args.save_buffer_name: print(f"Generate buffer with size {args.buffer_size}") @@ -262,7 +260,6 @@ def watch() -> None: episode_per_test=args.test_num, batch_size=args.batch_size, train_fn=train_fn, - test_fn=test_fn, stop_fn=stop_fn, save_best_fn=save_best_fn, logger=logger, diff --git a/examples/box2d/acrobot_dualdqn.py b/examples/box2d/acrobot_dualdqn.py index 07d592bed..6ca40ad72 100644 --- a/examples/box2d/acrobot_dualdqn.py +++ b/examples/box2d/acrobot_dualdqn.py @@ -79,6 +79,8 @@ def test_dqn(args: argparse.Namespace = get_args()) -> None: policy = DQNPolicy( model=net, action_space=env.action_space, + eps_training=args.eps_train, + eps_inference=args.eps_test, ) algorithm: DQN = DQN( policy=policy, @@ -95,7 +97,6 @@ def test_dqn(args: argparse.Namespace = get_args()) -> None: exploration_noise=True, ) test_collector = Collector[CollectStats](algorithm, test_envs, exploration_noise=True) - # policy.set_eps(1) train_collector.reset() train_collector.collect(n_step=args.batch_size * args.training_num) # log @@ -116,15 +117,12 @@ def stop_fn(mean_rewards: float) -> bool: def train_fn(epoch: int, env_step: int) -> None: if env_step <= 100000: - policy.set_eps(args.eps_train) + policy.set_eps_training(args.eps_train) elif env_step <= 500000: eps = args.eps_train - (env_step - 100000) / 400000 * (0.5 * args.eps_train) - policy.set_eps(eps) + policy.set_eps_training(eps) else: - policy.set_eps(0.5 * args.eps_train) - - def test_fn(epoch: int, env_step: int | None) -> None: - policy.set_eps(args.eps_test) + policy.set_eps_training(0.5 * args.eps_train) # train result = algorithm.run_training( @@ -138,7 +136,6 @@ def test_fn(epoch: int, env_step: int | None) -> None: batch_size=args.batch_size, update_per_step=args.update_per_step, train_fn=train_fn, - test_fn=test_fn, stop_fn=stop_fn, save_best_fn=save_best_fn, logger=logger, @@ -150,7 +147,6 @@ def test_fn(epoch: int, env_step: int | None) -> None: if __name__ == "__main__": pprint.pprint(result) # Let's watch its 
performance! - policy.set_eps(args.eps_test) test_envs.seed(args.seed) test_collector.reset() collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render) diff --git a/examples/box2d/bipedal_bdq.py b/examples/box2d/bipedal_bdq.py index 9364d1282..98ef08bd7 100644 --- a/examples/box2d/bipedal_bdq.py +++ b/examples/box2d/bipedal_bdq.py @@ -105,6 +105,8 @@ def test_bdq(args: argparse.Namespace = get_args()) -> None: policy = BDQNPolicy( model=net, action_space=env.action_space, # type: ignore[arg-type] # TODO: should `BranchingDQNPolicy` support also `MultiDiscrete` action spaces? + eps_training=args.eps_train, + eps_inference=args.eps_test, ) algorithm: BDQN = BDQN( policy=policy, @@ -120,7 +122,6 @@ def test_bdq(args: argparse.Namespace = get_args()) -> None: exploration_noise=True, ) test_collector = Collector[CollectStats](algorithm, test_envs, exploration_noise=False) - # policy.set_eps(1) train_collector.reset() train_collector.collect(n_step=args.batch_size * args.training_num) # log @@ -139,10 +140,7 @@ def stop_fn(mean_rewards: float) -> bool: def train_fn(epoch: int, env_step: int) -> None: # exp decay eps = max(args.eps_train * (1 - args.eps_decay) ** env_step, args.eps_test) - policy.set_eps(eps) - - def test_fn(epoch: int, env_step: int | None) -> None: - policy.set_eps(args.eps_test) + policy.set_eps_training(eps) # trainer result = algorithm.run_training( @@ -157,7 +155,6 @@ def test_fn(epoch: int, env_step: int | None) -> None: update_per_step=args.update_per_step, stop_fn=stop_fn, train_fn=train_fn, - test_fn=test_fn, save_best_fn=save_best_fn, logger=logger, test_in_train=True, @@ -168,7 +165,7 @@ def test_fn(epoch: int, env_step: int | None) -> None: if __name__ == "__main__": pprint.pprint(result) # Let's watch its performance! 
- policy.set_eps(args.eps_test) + policy.set_eps_training(args.eps_test) test_envs.seed(args.seed) test_collector.reset() collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render) diff --git a/examples/box2d/lunarlander_dqn.py b/examples/box2d/lunarlander_dqn.py index d1edb3336..7c8912e96 100644 --- a/examples/box2d/lunarlander_dqn.py +++ b/examples/box2d/lunarlander_dqn.py @@ -81,6 +81,8 @@ def test_dqn(args: argparse.Namespace = get_args()) -> None: policy = DQNPolicy( model=net, action_space=env.action_space, + eps_training=args.eps_train, + eps_inference=args.eps_test, ) algorithm: DQN = DQN( policy=policy, @@ -97,7 +99,6 @@ def test_dqn(args: argparse.Namespace = get_args()) -> None: exploration_noise=True, ) test_collector = Collector[CollectStats](algorithm, test_envs, exploration_noise=True) - # policy.set_eps(1) train_collector.reset() train_collector.collect(n_step=args.batch_size * args.training_num) # log @@ -118,10 +119,7 @@ def stop_fn(mean_rewards: float) -> bool: def train_fn(epoch: int, env_step: int) -> None: # exp decay eps = max(args.eps_train * (1 - 5e-6) ** env_step, args.eps_test) - policy.set_eps(eps) - - def test_fn(epoch: int, env_step: int | None) -> None: - policy.set_eps(args.eps_test) + policy.set_eps_training(eps) # train result = algorithm.run_training( @@ -136,7 +134,6 @@ def test_fn(epoch: int, env_step: int | None) -> None: update_per_step=args.update_per_step, stop_fn=stop_fn, train_fn=train_fn, - test_fn=test_fn, save_best_fn=save_best_fn, logger=logger, test_in_train=True, @@ -147,7 +144,6 @@ def test_fn(epoch: int, env_step: int | None) -> None: if __name__ == "__main__": pprint.pprint(result) # Let's watch its performance! 
- policy.set_eps(args.eps_test) test_envs.seed(args.seed) test_collector.reset() collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render) diff --git a/examples/discrete/discrete_dqn.py b/examples/discrete/discrete_dqn.py index aa958286c..3e95af9be 100644 --- a/examples/discrete/discrete_dqn.py +++ b/examples/discrete/discrete_dqn.py @@ -37,7 +37,9 @@ def main() -> None: net = Net(state_shape=state_shape, action_shape=action_shape, hidden_sizes=[128, 128, 128]) optim = AdamOptimizerFactory(lr=lr) - policy = DQNPolicy(model=net, action_space=env.action_space) + policy = DQNPolicy( + model=net, action_space=env.action_space, eps_training=eps_train, eps_inference=eps_test + ) algorithm: ts.policy.DQN = ts.policy.DQN( policy=policy, optim=optim, @@ -75,8 +77,6 @@ def stop_fn(mean_rewards: float) -> bool: episode_per_test=test_num, batch_size=batch_size, update_per_step=1 / step_per_collect, - train_fn=lambda epoch, env_step: policy.set_eps(eps_train), - test_fn=lambda epoch, env_step: policy.set_eps(eps_test), stop_fn=stop_fn, logger=logger, test_in_train=True, @@ -85,7 +85,6 @@ def stop_fn(mean_rewards: float) -> bool: print(f"Finished training in {result.timing.total_time} seconds") # watch performance - policy.set_eps(eps_test) collector = ts.data.Collector[CollectStats](algorithm, env, exploration_noise=True) collector.collect(n_episode=100, render=1 / 35) diff --git a/examples/discrete/discrete_dqn_hl.py b/examples/discrete/discrete_dqn_hl.py index c44db7c08..5f2e3ec97 100644 --- a/examples/discrete/discrete_dqn_hl.py +++ b/examples/discrete/discrete_dqn_hl.py @@ -9,8 +9,6 @@ from tianshou.highlevel.params.policy_params import DQNParams from tianshou.highlevel.trainer import ( EpochStopCallbackRewardThreshold, - EpochTestCallbackDQNSetEps, - EpochTrainCallbackDQNSetEps, ) @@ -46,11 +44,11 @@ def main() -> None: discount_factor=0.9, estimation_step=3, target_update_freq=320, + eps_training=0.3, + eps_inference=0.0, ), ) 
.with_model_factory_default(hidden_sizes=(64, 64)) - .with_epoch_train_callback(EpochTrainCallbackDQNSetEps(0.3)) - .with_epoch_test_callback(EpochTestCallbackDQNSetEps(0.0)) .with_epoch_stop_callback(EpochStopCallbackRewardThreshold(195)) .build() ) diff --git a/examples/offline/atari_cql.py b/examples/offline/atari_cql.py index dae6c665b..dd64f82e3 100644 --- a/examples/offline/atari_cql.py +++ b/examples/offline/atari_cql.py @@ -29,7 +29,6 @@ def get_args() -> argparse.Namespace: parser = argparse.ArgumentParser() parser.add_argument("--task", type=str, default="PongNoFrameskip-v4") parser.add_argument("--seed", type=int, default=1626) - parser.add_argument("--eps-test", type=float, default=0.001) parser.add_argument("--lr", type=float, default=0.0001) parser.add_argument("--gamma", type=float, default=0.99) parser.add_argument("--num-quantiles", type=int, default=200) @@ -175,7 +174,6 @@ def stop_fn(mean_rewards: float) -> bool: # watch agent's performance def watch() -> None: print("Setup test envs ...") - policy.set_eps(args.eps_test) test_envs.seed(args.seed) print("Testing agent ...") test_collector.reset() diff --git a/examples/vizdoom/vizdoom_c51.py b/examples/vizdoom/vizdoom_c51.py index dcc17e2a3..b5654395c 100644 --- a/examples/vizdoom/vizdoom_c51.py +++ b/examples/vizdoom/vizdoom_c51.py @@ -104,6 +104,8 @@ def test_c51(args: argparse.Namespace = get_args()) -> None: num_atoms=args.num_atoms, v_min=args.v_min, v_max=args.v_max, + eps_training=args.eps_train, + eps_inference=args.eps_test, ) algorithm: C51 = C51( policy=policy, @@ -164,17 +166,13 @@ def train_fn(epoch: int, env_step: int) -> None: eps = args.eps_train - env_step / 1e6 * (args.eps_train - args.eps_train_final) else: eps = args.eps_train_final - policy.set_eps(eps) + policy.set_eps_training(eps) if env_step % 1000 == 0: logger.write("train/env_step", env_step, {"train/eps": eps}) - def test_fn(epoch: int, env_step: int | None) -> None: - policy.set_eps(args.eps_test) - # watch agent's 
performance def watch() -> None: print("Setup test envs ...") - policy.set_eps(args.eps_test) test_envs.seed(args.seed) if args.save_buffer_name: print(f"Generate buffer with size {args.buffer_size}") @@ -216,7 +214,6 @@ def watch() -> None: episode_per_test=args.test_num, batch_size=args.batch_size, train_fn=train_fn, - test_fn=test_fn, stop_fn=stop_fn, save_best_fn=save_best_fn, logger=logger, diff --git a/test/discrete/test_bdqn.py b/test/discrete/test_bdqn.py index eca1e1df9..5fa986846 100644 --- a/test/discrete/test_bdqn.py +++ b/test/discrete/test_bdqn.py @@ -102,6 +102,8 @@ def test_bdq(args: argparse.Namespace = get_args()) -> None: policy = BDQNPolicy( model=net, action_space=env.action_space, # type: ignore[arg-type] # TODO: should `BranchingDQNPolicy` support also `MultiDiscrete` action spaces? + eps_training=args.eps_train, + eps_inference=args.eps_test, ) algorithm: BDQN = BDQN( policy=policy, @@ -117,16 +119,12 @@ def test_bdq(args: argparse.Namespace = get_args()) -> None: exploration_noise=True, ) test_collector = Collector[CollectStats](algorithm, test_envs, exploration_noise=False) - # policy.set_eps(1) train_collector.reset() train_collector.collect(n_step=args.batch_size * args.training_num) def train_fn(epoch: int, env_step: int) -> None: # exp decay eps = max(args.eps_train * (1 - args.eps_decay) ** env_step, args.eps_test) - policy.set_eps(eps) - - def test_fn(epoch: int, env_step: int | None) -> None: - policy.set_eps(args.eps_test) + policy.set_eps_training(eps) def stop_fn(mean_rewards: float) -> bool: return mean_rewards >= args.reward_threshold @@ -143,7 +141,6 @@ def stop_fn(mean_rewards: float) -> bool: batch_size=args.batch_size, update_per_step=args.update_per_step, train_fn=train_fn, - test_fn=test_fn, stop_fn=stop_fn, test_in_train=True, ) diff --git a/test/discrete/test_c51.py b/test/discrete/test_c51.py index 56ca3ceb4..64795b800 100644 --- a/test/discrete/test_c51.py +++ b/test/discrete/test_c51.py @@ -100,6 +100,8 @@ def 
test_c51(args: argparse.Namespace = get_args()) -> None: num_atoms=args.num_atoms, v_min=args.v_min, v_max=args.v_max, + eps_training=args.eps_train, + eps_inference=args.eps_test, ) algorithm: C51 = C51( policy=policy, @@ -124,7 +126,6 @@ def test_c51(args: argparse.Namespace = get_args()) -> None: # collectors train_collector = Collector[CollectStats](algorithm, train_envs, buf, exploration_noise=True) test_collector = Collector[CollectStats](algorithm, test_envs, exploration_noise=True) - # policy.set_eps(1) train_collector.reset() train_collector.collect(n_step=args.batch_size * args.training_num) @@ -142,15 +143,12 @@ def stop_fn(mean_rewards: float) -> bool: def train_fn(epoch: int, env_step: int) -> None: # eps annnealing, just a demo if env_step <= 10000: - policy.set_eps(args.eps_train) + policy.set_eps_training(args.eps_train) elif env_step <= 50000: eps = args.eps_train - (env_step - 10000) / 40000 * (0.9 * args.eps_train) - policy.set_eps(eps) + policy.set_eps_training(eps) else: - policy.set_eps(0.1 * args.eps_train) - - def test_fn(epoch: int, env_step: int | None) -> None: - policy.set_eps(args.eps_test) + policy.set_eps_training(0.1 * args.eps_train) def save_checkpoint_fn(epoch: int, env_step: int, gradient_step: int) -> str: # see also: https://pytorch.org/tutorials/beginner/saving_loading_models.html @@ -196,7 +194,6 @@ def save_checkpoint_fn(epoch: int, env_step: int, gradient_step: int) -> str: batch_size=args.batch_size, update_per_step=args.update_per_step, train_fn=train_fn, - test_fn=test_fn, stop_fn=stop_fn, save_best_fn=save_best_fn, logger=logger, diff --git a/test/discrete/test_dqn.py b/test/discrete/test_dqn.py index 8d88d3e61..a68794413 100644 --- a/test/discrete/test_dqn.py +++ b/test/discrete/test_dqn.py @@ -88,7 +88,11 @@ def test_dqn(args: argparse.Namespace = get_args()) -> None: ).to(args.device) optim = AdamOptimizerFactory(lr=args.lr) policy = DQNPolicy( - model=net, action_space=env.action_space, 
observation_space=env.observation_space + model=net, + action_space=env.action_space, + observation_space=env.observation_space, + eps_training=args.eps_train, + eps_inference=args.eps_test, ) algorithm: DQN = DQN( policy=policy, @@ -113,7 +117,6 @@ def test_dqn(args: argparse.Namespace = get_args()) -> None: # collectors train_collector = Collector[CollectStats](algorithm, train_envs, buf, exploration_noise=True) test_collector = Collector[CollectStats](algorithm, test_envs, exploration_noise=True) - # policy.set_eps(1) train_collector.reset() train_collector.collect(n_step=args.batch_size * args.training_num) @@ -131,15 +134,12 @@ def stop_fn(mean_rewards: float) -> bool: def train_fn(epoch: int, env_step: int) -> None: # eps annnealing, just a demo if env_step <= 10000: - policy.set_eps(args.eps_train) + policy.set_eps_training(args.eps_train) elif env_step <= 50000: eps = args.eps_train - (env_step - 10000) / 40000 * (0.9 * args.eps_train) - policy.set_eps(eps) + policy.set_eps_training(eps) else: - policy.set_eps(0.1 * args.eps_train) - - def test_fn(epoch: int, env_step: int | None) -> None: - policy.set_eps(args.eps_test) + policy.set_eps_training(0.1 * args.eps_train) # train result = algorithm.run_training( @@ -153,7 +153,6 @@ def test_fn(epoch: int, env_step: int | None) -> None: batch_size=args.batch_size, update_per_step=args.update_per_step, train_fn=train_fn, - test_fn=test_fn, stop_fn=stop_fn, save_best_fn=save_best_fn, logger=logger, diff --git a/test/discrete/test_drqn.py b/test/discrete/test_drqn.py index 29b525e75..ce7393a3b 100644 --- a/test/discrete/test_drqn.py +++ b/test/discrete/test_drqn.py @@ -81,6 +81,8 @@ def test_drqn(args: argparse.Namespace = get_args()) -> None: policy = DQNPolicy( model=net, action_space=env.action_space, + eps_training=args.eps_train, + eps_inference=args.eps_test, ) algorithm: DQN = DQN( policy=policy, @@ -112,12 +114,6 @@ def save_best_fn(policy: Algorithm) -> None: def stop_fn(mean_rewards: float) -> bool: 
return mean_rewards >= args.reward_threshold - def train_fn(epoch: int, env_step: int) -> None: - policy.set_eps(args.eps_train) - - def test_fn(epoch: int, env_step: int | None) -> None: - policy.set_eps(args.eps_test) - # train result = algorithm.run_training( OffPolicyTrainerParams( @@ -129,8 +125,6 @@ def test_fn(epoch: int, env_step: int | None) -> None: episode_per_test=args.test_num, batch_size=args.batch_size, update_per_step=args.update_per_step, - train_fn=train_fn, - test_fn=test_fn, stop_fn=stop_fn, save_best_fn=save_best_fn, logger=logger, diff --git a/test/discrete/test_fqf.py b/test/discrete/test_fqf.py index 1d52e078b..67288b7db 100644 --- a/test/discrete/test_fqf.py +++ b/test/discrete/test_fqf.py @@ -103,6 +103,8 @@ def test_fqf(args: argparse.Namespace = get_args()) -> None: model=net, fraction_model=fraction_net, action_space=env.action_space, + eps_training=args.eps_train, + eps_inference=args.eps_test, ) algorithm: FQF = FQF( policy=policy, @@ -130,7 +132,6 @@ def test_fqf(args: argparse.Namespace = get_args()) -> None: # collectors train_collector = Collector[CollectStats](algorithm, train_envs, buf, exploration_noise=True) test_collector = Collector[CollectStats](algorithm, test_envs, exploration_noise=True) - # policy.set_eps(1) train_collector.reset() train_collector.collect(n_step=args.batch_size * args.training_num) @@ -148,15 +149,12 @@ def stop_fn(mean_rewards: float) -> bool: def train_fn(epoch: int, env_step: int) -> None: # eps annnealing, just a demo if env_step <= 10000: - policy.set_eps(args.eps_train) + policy.set_eps_training(args.eps_train) elif env_step <= 50000: eps = args.eps_train - (env_step - 10000) / 40000 * (0.9 * args.eps_train) - policy.set_eps(eps) + policy.set_eps_training(eps) else: - policy.set_eps(0.1 * args.eps_train) - - def test_fn(epoch: int, env_step: int | None) -> None: - policy.set_eps(args.eps_test) + policy.set_eps_training(0.1 * args.eps_train) # train result = algorithm.run_training( @@ -169,7 +167,6 
@@ def test_fn(epoch: int, env_step: int | None) -> None: episode_per_test=args.test_num, batch_size=args.batch_size, train_fn=train_fn, - test_fn=test_fn, stop_fn=stop_fn, save_best_fn=save_best_fn, logger=logger, diff --git a/test/discrete/test_iqn.py b/test/discrete/test_iqn.py index 1381ebd06..cdb767da1 100644 --- a/test/discrete/test_iqn.py +++ b/test/discrete/test_iqn.py @@ -102,6 +102,8 @@ def test_iqn(args: argparse.Namespace = get_args()) -> None: sample_size=args.sample_size, online_sample_size=args.online_sample_size, target_sample_size=args.target_sample_size, + eps_training=args.eps_train, + eps_inference=args.eps_test, ) algorithm: IQN = IQN( policy=policy, @@ -126,7 +128,6 @@ def test_iqn(args: argparse.Namespace = get_args()) -> None: # collectors train_collector = Collector[CollectStats](algorithm, train_envs, buf, exploration_noise=True) test_collector = Collector[CollectStats](algorithm, test_envs, exploration_noise=True) - # policy.set_eps(1) train_collector.reset() train_collector.collect(n_step=args.batch_size * args.training_num) @@ -144,15 +145,12 @@ def stop_fn(mean_rewards: float) -> bool: def train_fn(epoch: int, env_step: int) -> None: # eps annnealing, just a demo if env_step <= 10000: - policy.set_eps(args.eps_train) + policy.set_eps_training(args.eps_train) elif env_step <= 50000: eps = args.eps_train - (env_step - 10000) / 40000 * (0.9 * args.eps_train) - policy.set_eps(eps) + policy.set_eps_training(eps) else: - policy.set_eps(0.1 * args.eps_train) - - def test_fn(epoch: int, env_step: int | None) -> None: - policy.set_eps(args.eps_test) + policy.set_eps_training(0.1 * args.eps_train) # train result = algorithm.run_training( @@ -165,7 +163,6 @@ def test_fn(epoch: int, env_step: int | None) -> None: episode_per_test=args.test_num, batch_size=args.batch_size, train_fn=train_fn, - test_fn=test_fn, stop_fn=stop_fn, save_best_fn=save_best_fn, logger=logger, diff --git a/test/discrete/test_qrdqn.py b/test/discrete/test_qrdqn.py index 
b7fa509da..f21e7a442 100644 --- a/test/discrete/test_qrdqn.py +++ b/test/discrete/test_qrdqn.py @@ -93,7 +93,11 @@ def test_qrdqn(args: argparse.Namespace = get_args()) -> None: ) optim = AdamOptimizerFactory(lr=args.lr) policy = QRDQNPolicy( - model=net, action_space=env.action_space, observation_space=env.observation_space + model=net, + action_space=env.action_space, + observation_space=env.observation_space, + eps_training=args.eps_train, + eps_inference=args.eps_test, ) algorithm: QRDQN = QRDQN( policy=policy, @@ -119,7 +123,6 @@ def test_qrdqn(args: argparse.Namespace = get_args()) -> None: # collectors train_collector = Collector[CollectStats](algorithm, train_envs, buf, exploration_noise=True) test_collector = Collector[CollectStats](algorithm, test_envs, exploration_noise=True) - # policy.set_eps(1) train_collector.reset() train_collector.collect(n_step=args.batch_size * args.training_num) @@ -137,15 +140,12 @@ def stop_fn(mean_rewards: float) -> bool: def train_fn(epoch: int, env_step: int) -> None: # eps annnealing, just a demo if env_step <= 10000: - policy.set_eps(args.eps_train) + policy.set_eps_training(args.eps_train) elif env_step <= 50000: eps = args.eps_train - (env_step - 10000) / 40000 * (0.9 * args.eps_train) - policy.set_eps(eps) + policy.set_eps_training(eps) else: - policy.set_eps(0.1 * args.eps_train) - - def test_fn(epoch: int, env_step: int | None) -> None: - policy.set_eps(args.eps_test) + policy.set_eps_training(0.1 * args.eps_train) # train result = algorithm.run_training( @@ -158,7 +158,6 @@ def test_fn(epoch: int, env_step: int | None) -> None: episode_per_test=args.test_num, batch_size=args.batch_size, train_fn=train_fn, - test_fn=test_fn, stop_fn=stop_fn, save_best_fn=save_best_fn, logger=logger, diff --git a/test/discrete/test_rainbow.py b/test/discrete/test_rainbow.py index 4877ef392..793a163f0 100644 --- a/test/discrete/test_rainbow.py +++ b/test/discrete/test_rainbow.py @@ -109,6 +109,8 @@ def noisy_linear(x: int, y: int) -> 
NoisyLinear: num_atoms=args.num_atoms, v_min=args.v_min, v_max=args.v_max, + eps_training=args.eps_train, + eps_inference=args.eps_test, ) algorithm: RainbowDQN = RainbowDQN( policy=policy, @@ -134,7 +136,6 @@ def noisy_linear(x: int, y: int) -> NoisyLinear: # collectors train_collector = Collector[CollectStats](algorithm, train_envs, buf, exploration_noise=True) test_collector = Collector[CollectStats](algorithm, test_envs, exploration_noise=True) - # policy.set_eps(1) train_collector.reset() train_collector.collect(n_step=args.batch_size * args.training_num) @@ -152,12 +153,12 @@ def stop_fn(mean_rewards: float) -> bool: def train_fn(epoch: int, env_step: int) -> None: # eps annealing, just a demo if env_step <= 10000: - policy.set_eps(args.eps_train) + policy.set_eps_training(args.eps_train) elif env_step <= 50000: eps = args.eps_train - (env_step - 10000) / 40000 * (0.9 * args.eps_train) - policy.set_eps(eps) + policy.set_eps_training(eps) else: - policy.set_eps(0.1 * args.eps_train) + policy.set_eps_training(0.1 * args.eps_train) # beta annealing, just a demo if args.prioritized_replay: if env_step <= 10000: @@ -168,9 +169,6 @@ def train_fn(epoch: int, env_step: int) -> None: beta = args.beta_final buf.set_beta(beta) - def test_fn(epoch: int, env_step: int | None) -> None: - policy.set_eps(args.eps_test) - def save_checkpoint_fn(epoch: int, env_step: int, gradient_step: int) -> str: # see also: https://pytorch.org/tutorials/beginner/saving_loading_models.html ckpt_path = os.path.join(log_path, "checkpoint.pth") @@ -215,7 +213,6 @@ def save_checkpoint_fn(epoch: int, env_step: int, gradient_step: int) -> str: batch_size=args.batch_size, update_per_step=args.update_per_step, train_fn=train_fn, - test_fn=test_fn, stop_fn=stop_fn, save_best_fn=save_best_fn, logger=logger, diff --git a/test/modelbased/test_dqn_icm.py b/test/modelbased/test_dqn_icm.py index d1e8c5839..cc2be5b97 100644 --- a/test/modelbased/test_dqn_icm.py +++ b/test/modelbased/test_dqn_icm.py @@ 
-109,6 +109,8 @@ def test_dqn_icm(args: argparse.Namespace = get_args()) -> None: policy = DQNPolicy( model=net, action_space=env.action_space, + eps_training=args.eps_train, + eps_inference=args.eps_test, ) algorithm: DQN = DQN( policy=policy, @@ -177,15 +179,12 @@ def stop_fn(mean_rewards: float) -> bool: def train_fn(epoch: int, env_step: int) -> None: # eps annnealing, just a demo if env_step <= 10000: - policy.set_eps(args.eps_train) + policy.set_eps_training(args.eps_train) elif env_step <= 50000: eps = args.eps_train - (env_step - 10000) / 40000 * (0.9 * args.eps_train) - policy.set_eps(eps) + policy.set_eps_training(eps) else: - policy.set_eps(0.1 * args.eps_train) - - def test_fn(epoch: int, env_step: int | None) -> None: - policy.set_eps(args.eps_test) + policy.set_eps_training(0.1 * args.eps_train) # train result = icm_algorithm.run_training( @@ -199,7 +198,6 @@ def test_fn(epoch: int, env_step: int | None) -> None: batch_size=args.batch_size, update_per_step=args.update_per_step, train_fn=train_fn, - test_fn=test_fn, stop_fn=stop_fn, save_best_fn=save_best_fn, logger=logger, diff --git a/test/offline/gather_cartpole_data.py b/test/offline/gather_cartpole_data.py index 21fa9930c..86e052968 100644 --- a/test/offline/gather_cartpole_data.py +++ b/test/offline/gather_cartpole_data.py @@ -100,6 +100,8 @@ def gather_data() -> VectorReplayBuffer | PrioritizedVectorReplayBuffer: policy = QRDQNPolicy( model=net, action_space=env.action_space, + eps_training=args.eps_train, + eps_inference=args.eps_test, ) algorithm: QRDQN = QRDQN( policy=policy, @@ -125,7 +127,6 @@ def gather_data() -> VectorReplayBuffer | PrioritizedVectorReplayBuffer: train_collector.reset() test_collector = Collector[CollectStats](algorithm, test_envs, exploration_noise=True) test_collector.reset() - # policy.set_eps(1) train_collector.collect(n_step=args.batch_size * args.training_num) # log log_path = os.path.join(args.logdir, args.task, "qrdqn") @@ -141,15 +142,12 @@ def 
stop_fn(mean_rewards: float) -> bool: def train_fn(epoch: int, env_step: int) -> None: # eps annnealing, just a demo if env_step <= 10000: - policy.set_eps(args.eps_train) + policy.set_eps_training(args.eps_train) elif env_step <= 50000: eps = args.eps_train - (env_step - 10000) / 40000 * (0.9 * args.eps_train) - policy.set_eps(eps) + policy.set_eps_training(eps) else: - policy.set_eps(0.1 * args.eps_train) - - def test_fn(epoch: int, env_step: int | None) -> None: - policy.set_eps(args.eps_test) + policy.set_eps_training(0.1 * args.eps_train) # train result = algorithm.run_training( @@ -162,7 +160,6 @@ def test_fn(epoch: int, env_step: int | None) -> None: episode_per_test=args.test_num, batch_size=args.batch_size, train_fn=train_fn, - test_fn=test_fn, stop_fn=stop_fn, save_best_fn=save_best_fn, logger=logger, @@ -174,7 +171,7 @@ def test_fn(epoch: int, env_step: int | None) -> None: # save buffer in pickle format, for imitation learning unittest buf = VectorReplayBuffer(args.buffer_size, buffer_num=len(test_envs)) - policy.set_eps(0.2) + policy.set_eps_inference(0.2) collector = Collector[CollectStats](algorithm, test_envs, buf, exploration_noise=True) collector.reset() collector_stats = collector.collect(n_step=args.buffer_size) diff --git a/test/pettingzoo/pistonball.py b/test/pettingzoo/pistonball.py index cf505d6a1..e29114974 100644 --- a/test/pettingzoo/pistonball.py +++ b/test/pettingzoo/pistonball.py @@ -101,6 +101,8 @@ def get_agents( policy = DQNPolicy( model=net, action_space=env.action_space, + eps_training=args.eps_train, + eps_inference=args.eps_test, ) agent: DQN = DQN( policy=policy, @@ -153,12 +155,6 @@ def save_best_fn(policy: Algorithm) -> None: def stop_fn(mean_rewards: float) -> bool: return False - def train_fn(epoch: int, env_step: int) -> None: - [agent.set_eps(args.eps_train) for agent in marl_algorithm.policy.policies.values()] - - def test_fn(epoch: int, env_step: int | None) -> None: - [agent.set_eps(args.eps_test) for agent in 
marl_algorithm.policy.policies.values()] - def reward_metric(rews: np.ndarray) -> np.ndarray: return rews[:, 0] @@ -172,8 +168,6 @@ def reward_metric(rews: np.ndarray) -> np.ndarray: step_per_collect=args.step_per_collect, episode_per_test=args.test_num, batch_size=args.batch_size, - train_fn=train_fn, - test_fn=test_fn, stop_fn=stop_fn, save_best_fn=save_best_fn, update_per_step=args.update_per_step, @@ -192,7 +186,6 @@ def watch(args: argparse.Namespace = get_args(), policy: Algorithm | None = None "watching random agents, as loading pre-trained policies is currently not supported", ) policy, _, _ = get_agents(args) - [agent.set_eps(args.eps_test) for agent in policy.policies.values()] collector = Collector[CollectStats](policy, env, exploration_noise=True) result = collector.collect(n_episode=1, render=args.render) result.pprint_asdict() diff --git a/test/pettingzoo/tic_tac_toe.py b/test/pettingzoo/tic_tac_toe.py index 7ae46c68e..b5308fd1c 100644 --- a/test/pettingzoo/tic_tac_toe.py +++ b/test/pettingzoo/tic_tac_toe.py @@ -124,6 +124,8 @@ def get_agents( algorithm = DQNPolicy( model=net, action_space=env.action_space, + eps_training=args.eps_train, + eps_inference=args.eps_test, ) agent_learn = DQN( policy=algorithm, @@ -201,12 +203,6 @@ def save_best_fn(policy: Algorithm) -> None: def stop_fn(mean_rewards: float) -> bool: return mean_rewards >= args.win_rate - def train_fn(epoch: int, env_step: int) -> None: - marl_algorithm.get_algorithm(player_agent_id).policy.set_eps(args.eps_train) - - def test_fn(epoch: int, env_step: int | None) -> None: - marl_algorithm.get_algorithm(player_agent_id).policy.set_eps(args.eps_test) - def reward_metric(rews: np.ndarray) -> np.ndarray: return rews[:, args.agent_id - 1] @@ -220,8 +216,6 @@ def reward_metric(rews: np.ndarray) -> np.ndarray: step_per_collect=args.step_per_collect, episode_per_test=args.test_num, batch_size=args.batch_size, - train_fn=train_fn, - test_fn=test_fn, stop_fn=stop_fn, save_best_fn=save_best_fn, 
update_per_step=args.update_per_step, @@ -241,7 +235,6 @@ def watch( ) -> None: env = DummyVectorEnv([partial(get_env, render_mode="human")]) policy, optim, agents = get_agents(args, agent_learn=agent_learn, agent_opponent=agent_opponent) - policy.algorithms[agents[args.agent_id - 1]].policy.set_eps(args.eps_test) collector = Collector[CollectStats](policy, env, exploration_noise=True) result = collector.collect(n_episode=1, render=args.render, reset_before_collect=True) result.pprint_asdict() diff --git a/tianshou/highlevel/algorithm.py b/tianshou/highlevel/algorithm.py index 8d785b7a8..2a35049ef 100644 --- a/tianshou/highlevel/algorithm.py +++ b/tianshou/highlevel/algorithm.py @@ -460,7 +460,7 @@ def _create_policy( return self._create_policy_from_args( constructor=DQNPolicy, params_dict=params, - policy_params=[], + policy_params=["eps_training", "eps_inference"], model=model, action_space=action_space, observation_space=observation_space, @@ -482,7 +482,13 @@ def _create_policy( return self._create_policy_from_args( IQNPolicy, params, - ["sample_size", "online_sample_size", "target_sample_size"], + [ + "sample_size", + "online_sample_size", + "target_sample_size", + "eps_training", + "eps_inference", + ], model=model, action_space=action_space, observation_space=observation_space, diff --git a/tianshou/highlevel/params/policy_params.py b/tianshou/highlevel/params/policy_params.py index 0769537c7..1c030665a 100644 --- a/tianshou/highlevel/params/policy_params.py +++ b/tianshou/highlevel/params/policy_params.py @@ -477,6 +477,23 @@ class QLearningOffPolicyParams(Params, ParamsMixinSingleModel): """the target network update frequency (0 if no target network is to be used)""" reward_normalization: bool = False """whether to normalize the returns to Normal(0, 1)""" + eps_training: float = 0.0 + """ + the epsilon value for epsilon-greedy exploration during training. 
+ When collecting data for training, this is the probability of choosing a random action + instead of the action chosen by the policy. + A value of 0.0 means no exploration (fully greedy) and a value of 1.0 means full + exploration (fully random). + """ + eps_inference: float = 0.0 + """ + the epsilon value for epsilon-greedy exploration during inference, + i.e. non-training cases (such as evaluation during test steps). + The epsilon value is the probability of choosing a random action instead of the action + chosen by the policy. + A value of 0.0 means no exploration (fully greedy) and a value of 1.0 means full + exploration (fully random). + """ def _get_param_transformers(self) -> list[ParamTransformer]: transformers = super()._get_param_transformers() diff --git a/tianshou/highlevel/trainer.py b/tianshou/highlevel/trainer.py index 93dbe5c0b..b462f6fd4 100644 --- a/tianshou/highlevel/trainer.py +++ b/tianshou/highlevel/trainer.py @@ -9,6 +9,7 @@ from tianshou.highlevel.env import Environments from tianshou.highlevel.logger import TLogger from tianshou.policy import DQN, Algorithm +from tianshou.policy.modelfree.dqn import DQNPolicy TAlgorithm = TypeVar("TAlgorithm", bound=Algorithm) log = logging.getLogger(__name__) @@ -86,12 +87,13 @@ class EpochTrainCallbackDQNSetEps(EpochTrainCallback): stage in each epoch. 
""" - def __init__(self, eps_test: float): - self.eps_test = eps_test + def __init__(self, eps: float): + self.eps = eps def callback(self, epoch: int, env_step: int, context: TrainingContext) -> None: algorithm = cast(DQN, context.algorithm) - algorithm.policy.set_eps(self.eps_test) + policy: DQNPolicy = algorithm.policy + policy.set_eps_training(self.eps) class EpochTrainCallbackDQNEpsLinearDecay(EpochTrainCallback): @@ -106,6 +108,7 @@ def __init__(self, eps_train: float, eps_train_final: float, decay_steps: int = def callback(self, epoch: int, env_step: int, context: TrainingContext) -> None: algorithm = cast(DQN, context.algorithm) + policy: DQNPolicy = algorithm.policy logger = context.logger if env_step <= self.decay_steps: eps = self.eps_train - env_step / self.decay_steps * ( @@ -113,7 +116,7 @@ def callback(self, epoch: int, env_step: int, context: TrainingContext) -> None: ) else: eps = self.eps_train_final - algorithm.policy.set_eps(eps) + policy.set_eps_training(eps) logger.write("train/env_step", env_step, {"train/eps": eps}) @@ -122,12 +125,13 @@ class EpochTestCallbackDQNSetEps(EpochTestCallback): stage in each epoch. 
""" - def __init__(self, eps_test: float): - self.eps_test = eps_test + def __init__(self, eps: float): + self.eps = eps def callback(self, epoch: int, env_step: int | None, context: TrainingContext) -> None: algorithm = cast(DQN, context.algorithm) - algorithm.policy.set_eps(self.eps_test) + policy: DQNPolicy = algorithm.policy + policy.set_eps_inference(self.eps) class EpochStopCallbackRewardThreshold(EpochStopCallback): diff --git a/tianshou/policy/modelfree/bdqn.py b/tianshou/policy/modelfree/bdqn.py index eecdaf88a..b65a21d14 100644 --- a/tianshou/policy/modelfree/bdqn.py +++ b/tianshou/policy/modelfree/bdqn.py @@ -42,16 +42,31 @@ def __init__( model: BranchingNet, action_space: gym.spaces.Discrete, observation_space: gym.Space | None = None, + eps_training: float = 0.0, + eps_inference: float = 0.0, ): """ :param model: BranchingNet mapping (obs, state, info) -> action_values_BA. :param action_space: the environment's action space :param observation_space: the environment's observation space. + :param eps_training: the epsilon value for epsilon-greedy exploration during training. + When collecting data for training, this is the probability of choosing a random action + instead of the action chosen by the policy. + A value of 0.0 means no exploration (fully greedy) and a value of 1.0 means full + exploration (fully random). + :param eps_inference: the epsilon value for epsilon-greedy exploration during inference, + i.e. non-training cases (such as evaluation during test steps). + The epsilon value is the probability of choosing a random action instead of the action + chosen by the policy. + A value of 0.0 means no exploration (fully greedy) and a value of 1.0 means full + exploration (fully random). 
""" super().__init__( model=model, action_space=action_space, observation_space=observation_space, + eps_training=eps_training, + eps_inference=eps_inference, ) def forward( @@ -76,10 +91,11 @@ def add_exploration_noise( act: TArrOrActBatch, batch: ObsBatchProtocol, ) -> TArrOrActBatch: + eps = self.eps_training if self.is_within_training_step else self.eps_inference # TODO: This looks problematic; the non-array case is silently ignored - if isinstance(act, np.ndarray) and not np.isclose(self.eps, 0.0): + if isinstance(act, np.ndarray) and not np.isclose(eps, 0.0): bsz = len(act) - rand_mask = np.random.rand(bsz) < self.eps + rand_mask = np.random.rand(bsz) < eps rand_act = np.random.randint( low=0, high=self.model.action_per_branch, diff --git a/tianshou/policy/modelfree/c51.py b/tianshou/policy/modelfree/c51.py index 87810d9b2..badddf57a 100644 --- a/tianshou/policy/modelfree/c51.py +++ b/tianshou/policy/modelfree/c51.py @@ -33,6 +33,8 @@ def __init__( num_atoms: int = 51, v_min: float = -10.0, v_max: float = 10.0, + eps_training: float = 0.0, + eps_inference: float = 0.0, ): """ :param model: a model following the rules (s_B -> action_values_BA) @@ -42,10 +44,25 @@ def __init__( Default to -10.0. :param v_max: the value of the largest atom in the support set. Default to 10.0. + :param eps_training: the epsilon value for epsilon-greedy exploration during training. + When collecting data for training, this is the probability of choosing a random action + instead of the action chosen by the policy. + A value of 0.0 means no exploration (fully greedy) and a value of 1.0 means full + exploration (fully random). + :param eps_inference: the epsilon value for epsilon-greedy exploration during inference, + i.e. non-training cases (such as evaluation during test steps). + The epsilon value is the probability of choosing a random action instead of the action + chosen by the policy. 
+ A value of 0.0 means no exploration (fully greedy) and a value of 1.0 means full + exploration (fully random). """ assert isinstance(action_space, gym.spaces.Discrete) super().__init__( - model=model, action_space=action_space, observation_space=observation_space + model=model, + action_space=action_space, + observation_space=observation_space, + eps_training=eps_training, + eps_inference=eps_inference, ) assert num_atoms > 1, f"num_atoms should be greater than 1 but got: {num_atoms}" assert v_min < v_max, f"v_max should be larger than v_min, but got {v_min=} and {v_max=}" diff --git a/tianshou/policy/modelfree/dqn.py b/tianshou/policy/modelfree/dqn.py index be61674ad..a8f67b934 100644 --- a/tianshou/policy/modelfree/dqn.py +++ b/tianshou/policy/modelfree/dqn.py @@ -46,11 +46,24 @@ def __init__( model: TModel, action_space: gym.spaces.Space, observation_space: gym.Space | None = None, + eps_training: float = 0.0, + eps_inference: float = 0.0, ) -> None: """ :param model: a model mapping (obs, state, info) to action_values_BA. :param action_space: the environment's action space :param observation_space: the environment's observation space. + :param eps_training: the epsilon value for epsilon-greedy exploration during training. + When collecting data for training, this is the probability of choosing a random action + instead of the action chosen by the policy. + A value of 0.0 means no exploration (fully greedy) and a value of 1.0 means full + exploration (fully random). + :param eps_inference: the epsilon value for epsilon-greedy exploration during inference, + i.e. non-training cases (such as evaluation during test steps). + The epsilon value is the probability of choosing a random action instead of the action + chosen by the policy. + A value of 0.0 means no exploration (fully greedy) and a value of 1.0 means full + exploration (fully random). 
""" super().__init__( action_space=action_space, @@ -60,11 +73,33 @@ def __init__( ) self.model = model self.max_action_num: int | None = None - self.eps = 0.0 + self.eps_training = eps_training + self.eps_inference = eps_inference - def set_eps(self, eps: float) -> None: - """Set the eps for epsilon-greedy exploration.""" - self.eps = eps + def set_eps_training(self, eps: float) -> None: + """ + Sets the epsilon value for epsilon-greedy exploration during training. + + :param eps: the epsilon value for epsilon-greedy exploration during training. + When collecting data for training, this is the probability of choosing a random action + instead of the action chosen by the policy. + A value of 0.0 means no exploration (fully greedy) and a value of 1.0 means full + exploration (fully random). + """ + self.eps_training = eps + + def set_eps_inference(self, eps: float) -> None: + """ + Sets the epsilon value for epsilon-greedy exploration during inference. + + :param eps: the epsilon value for epsilon-greedy exploration during inference, + i.e. non-training cases (such as evaluation during test steps). + The epsilon value is the probability of choosing a random action instead of the action + chosen by the policy. + A value of 0.0 means no exploration (fully greedy) and a value of 1.0 means full + exploration (fully random). 
+ """ + self.eps_inference = eps def forward( self, @@ -126,14 +161,15 @@ def add_exploration_noise( act: TArrOrActBatch, batch: ObsBatchProtocol, ) -> TArrOrActBatch: + eps = self.eps_training if self.is_within_training_step else self.eps_inference # TODO: This looks problematic; the non-array case is silently ignored - if isinstance(act, np.ndarray) and not np.isclose(self.eps, 0.0): - bsz = len(act) - rand_mask = np.random.rand(bsz) < self.eps + if isinstance(act, np.ndarray) and not np.isclose(eps, 0.0): + batch_size = len(act) + rand_mask = np.random.rand(batch_size) < eps assert ( self.max_action_num is not None ), "Can't call this method before max_action_num was set in first forward" - q = np.random.rand(bsz, self.max_action_num) # [0, 1] + q = np.random.rand(batch_size, self.max_action_num) # [0, 1] if hasattr(batch.obs, "mask"): q += batch.obs.mask rand_act = q.argmax(axis=1) diff --git a/tianshou/policy/modelfree/fqf.py b/tianshou/policy/modelfree/fqf.py index 0ef47f93e..92bc33e6a 100644 --- a/tianshou/policy/modelfree/fqf.py +++ b/tianshou/policy/modelfree/fqf.py @@ -34,6 +34,8 @@ def __init__( fraction_model: FractionProposalNetwork, action_space: gym.spaces.Space, observation_space: gym.Space | None = None, + eps_training: float = 0.0, + eps_inference: float = 0.0, ): """ :param model: a model following the rules (s_B -> action_values_BA) @@ -41,12 +43,25 @@ def __init__( proposing fractions/quantiles given state. :param action_space: the environment's action space :param observation_space: the environment's observation space. + :param eps_training: the epsilon value for epsilon-greedy exploration during training. + When collecting data for training, this is the probability of choosing a random action + instead of the action chosen by the policy. + A value of 0.0 means no exploration (fully greedy) and a value of 1.0 means full + exploration (fully random). 
+ :param eps_inference: the epsilon value for epsilon-greedy exploration during inference, + i.e. non-training cases (such as evaluation during test steps). + The epsilon value is the probability of choosing a random action instead of the action + chosen by the policy. + A value of 0.0 means no exploration (fully greedy) and a value of 1.0 means full + exploration (fully random). """ assert isinstance(action_space, gym.spaces.Discrete) super().__init__( model=model, action_space=action_space, observation_space=observation_space, + eps_training=eps_training, + eps_inference=eps_inference, ) self.fraction_model = fraction_model diff --git a/tianshou/policy/modelfree/iqn.py b/tianshou/policy/modelfree/iqn.py index 652a2d422..a112491a9 100644 --- a/tianshou/policy/modelfree/iqn.py +++ b/tianshou/policy/modelfree/iqn.py @@ -36,7 +36,28 @@ def __init__( online_sample_size: int = 8, target_sample_size: int = 8, observation_space: gym.Space | None = None, + eps_training: float = 0.0, + eps_inference: float = 0.0, ) -> None: + """ + :param model: + :param action_space: the environment's action space + :param sample_size: + :param online_sample_size: + :param target_sample_size: + :param observation_space: the environment's observation space + :param eps_training: the epsilon value for epsilon-greedy exploration during training. + When collecting data for training, this is the probability of choosing a random action + instead of the action chosen by the policy. + A value of 0.0 means no exploration (fully greedy) and a value of 1.0 means full + exploration (fully random). + :param eps_inference: the epsilon value for epsilon-greedy exploration during inference, + i.e. non-training cases (such as evaluation during test steps). + The epsilon value is the probability of choosing a random action instead of the action + chosen by the policy. + A value of 0.0 means no exploration (fully greedy) and a value of 1.0 means full + exploration (fully random). 
+ """ assert isinstance(action_space, gym.spaces.Discrete) assert sample_size > 1, f"sample_size should be greater than 1 but got: {sample_size}" assert ( @@ -49,6 +70,8 @@ def __init__( model=model, action_space=action_space, observation_space=observation_space, + eps_training=eps_training, + eps_inference=eps_inference, ) self.sample_size = sample_size self.online_sample_size = online_sample_size From 9f5e055de17b7f8aea74d17e8825555cb7939329 Mon Sep 17 00:00:00 2001 From: Dominik Jain Date: Tue, 22 Apr 2025 23:36:29 +0200 Subject: [PATCH 085/230] Fix return type --- tianshou/policy/modelfree/dqn.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tianshou/policy/modelfree/dqn.py b/tianshou/policy/modelfree/dqn.py index a8f67b934..3bcacb7a6 100644 --- a/tianshou/policy/modelfree/dqn.py +++ b/tianshou/policy/modelfree/dqn.py @@ -16,6 +16,7 @@ ObsBatchProtocol, RolloutBatchProtocol, ) +from tianshou.policy import Algorithm from tianshou.policy.base import ( LaggedNetworkFullUpdateAlgorithmMixin, OffPolicyAlgorithm, @@ -231,7 +232,7 @@ def __init__( self._add_lagged_network(self.policy.model) if self.use_target_network else None ) - def _create_policy_optimizer(self, optim: OptimizerFactory) -> torch.optim.Optimizer: + def _create_policy_optimizer(self, optim: OptimizerFactory) -> Algorithm.Optimizer: return self._create_optimizer(self.policy, optim) @property From de1e8b7a7d742773982521b90af57c1f7469d977 Mon Sep 17 00:00:00 2001 From: Dominik Jain Date: Sat, 3 May 2025 01:39:25 +0200 Subject: [PATCH 086/230] v2: Fix mypy/typing issues --- tianshou/highlevel/params/lr_scheduler.py | 9 ++++++--- tianshou/policy/modelfree/a2c.py | 4 ++-- tianshou/policy/modelfree/rainbow.py | 2 +- 3 files changed, 9 insertions(+), 6 deletions(-) diff --git a/tianshou/highlevel/params/lr_scheduler.py b/tianshou/highlevel/params/lr_scheduler.py index 6d50b88c6..47ea04897 100644 --- a/tianshou/highlevel/params/lr_scheduler.py +++ 
b/tianshou/highlevel/params/lr_scheduler.py @@ -19,10 +19,13 @@ def __init__(self, training_config: TrainingConfig): self.training_config = training_config def create_lr_scheduler_factory(self) -> LRSchedulerFactory: - if self.training_config.step_per_epoch is None: + if ( + self.training_config.step_per_epoch is None + or self.training_config.step_per_collect is None + ): raise ValueError( - f"{self.__class__.__name__} requires step_per_epoch to be set " - f"in order for the total number of update steps to be computable" + f"{self.__class__.__name__} requires step_per_epoch and step_per_collect to be set " + f"in order for the scheduling to be well-defined." ) return LRSchedulerFactoryLinear( num_epochs=self.training_config.num_epochs, diff --git a/tianshou/policy/modelfree/a2c.py b/tianshou/policy/modelfree/a2c.py index ddf2d3f02..1ff2151fc 100644 --- a/tianshou/policy/modelfree/a2c.py +++ b/tianshou/policy/modelfree/a2c.py @@ -10,7 +10,7 @@ from tianshou.data.types import BatchWithAdvantagesProtocol, RolloutBatchProtocol from tianshou.policy.base import ( OnPolicyAlgorithm, - TrainingStats, + TrainingStats, TTrainingStats, ) from tianshou.policy.modelfree.pg import ActorPolicy, TPGTrainingStats from tianshou.policy.optim import OptimizerFactory @@ -32,7 +32,7 @@ class A2CTrainingStats(TrainingStats): class ActorCriticOnPolicyAlgorithm( - OnPolicyAlgorithm[ActorPolicy, TPGTrainingStats], Generic[TPGTrainingStats], ABC + OnPolicyAlgorithm[ActorPolicy, TTrainingStats], Generic[TTrainingStats], ABC ): """Abstract base class for actor-critic algorithms that use generalized advantage estimation (GAE).""" diff --git a/tianshou/policy/modelfree/rainbow.py b/tianshou/policy/modelfree/rainbow.py index 3e95bcf97..83b6605ab 100644 --- a/tianshou/policy/modelfree/rainbow.py +++ b/tianshou/policy/modelfree/rainbow.py @@ -48,6 +48,6 @@ def _update_with_batch( batch: RolloutBatchProtocol, ) -> TRainbowTrainingStats: self._sample_noise(self.policy.model) - if 
self.use_target_network and self._sample_noise(self.model_old): + if self.use_target_network and self._sample_noise(self.model_old): # type: ignore self.model_old.train() # so that NoisyLinear takes effect return super()._update_with_batch(batch) From 7a72da948feb1efc553b5d08f5550a413d0c0c79 Mon Sep 17 00:00:00 2001 From: Dominik Jain Date: Fri, 2 May 2025 23:07:45 +0200 Subject: [PATCH 087/230] v2: Clean up handling of modules that define attribute `output_dim` introducing the explicit base class `ModuleWithVectorOutput` Interfaces where one could specify either a module with `output_dim` or additionally provide the output dimension as an argument were changed to use `ModuleWithVectorOutput`. The high-level API class `IntermediateModule` can now provide a `ModuleWithVectorOutput` instance (via adaptation if necessary). --- CHANGELOG.md | 6 + examples/box2d/bipedal_bdq.py | 4 +- test/offline/test_gail.py | 3 - test/pettingzoo/pistonball_continuous.py | 18 +- tianshou/data/collector.py | 1 - tianshou/env/atari/atari_network.py | 27 ++- tianshou/highlevel/module/actor.py | 6 +- tianshou/highlevel/module/intermediate.py | 7 + tianshou/highlevel/module/special.py | 3 +- tianshou/policy/imitation/gail.py | 12 +- tianshou/policy/modelfree/a2c.py | 5 +- tianshou/utils/net/common.py | 236 ++++++++++++---------- tianshou/utils/net/continuous.py | 51 ++--- tianshou/utils/net/discrete.py | 76 +++---- 14 files changed, 231 insertions(+), 224 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 5c8101cb9..b28d843e8 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -165,6 +165,12 @@ `IntrinsicCuriosityModule`, `Net`, `MLP`, `Perturbation`, `QRDQNet`, `Rainbow`, `Recurrent`, `RecurrentActorProb`, `RecurrentCritic`, `VAE` * (Peripheral change:) Require the use of keyword arguments for the constructors of all of these classes + * Clean up handling of modules that define attribute `output_dim`, introducing the explicit base class + `ModuleWithVectorOutput` + * Interfaces where 
one could specify either a module with `output_dim` or additionally provide the output + dimension as an argument were changed to use `ModuleWithVectorOutput`. + * The high-level API class `IntermediateModule` can now provide a `ModuleWithVectorOutput` instance + (via adaptation if necessary). ## Unreleased diff --git a/examples/box2d/bipedal_bdq.py b/examples/box2d/bipedal_bdq.py index 98ef08bd7..e87d0f9d7 100644 --- a/examples/box2d/bipedal_bdq.py +++ b/examples/box2d/bipedal_bdq.py @@ -55,7 +55,7 @@ def get_args() -> argparse.Namespace: return parser.parse_args() -def test_bdq(args: argparse.Namespace = get_args()) -> None: +def run_bdq(args: argparse.Namespace = get_args()) -> None: env = gym.make(args.task) env = ContinuousToDiscrete(env, args.action_per_branch) @@ -173,4 +173,4 @@ def train_fn(epoch: int, env_step: int) -> None: # exp decay if __name__ == "__main__": - test_bdq(get_args()) + run_bdq(get_args()) diff --git a/test/offline/test_gail.py b/test/offline/test_gail.py index 8d0b0dc59..25c8f424b 100644 --- a/test/offline/test_gail.py +++ b/test/offline/test_gail.py @@ -83,10 +83,7 @@ def test_gail(args: argparse.Namespace = get_args()) -> None: args.state_shape = space_info.observation_info.obs_shape args.action_shape = space_info.action_info.action_shape args.max_action = space_info.action_info.max_action - # you can also use tianshou.env.SubprocVectorEnv - # train_envs = gym.make(args.task) train_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.training_num)]) - # test_envs = gym.make(args.task) test_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.test_num)]) # seed np.random.seed(args.seed) diff --git a/test/pettingzoo/pistonball_continuous.py b/test/pettingzoo/pistonball_continuous.py index 477549067..e6832d2b6 100644 --- a/test/pettingzoo/pistonball_continuous.py +++ b/test/pettingzoo/pistonball_continuous.py @@ -21,10 +21,11 @@ from tianshou.policy.optim import AdamOptimizerFactory from 
tianshou.trainer import OnPolicyTrainerParams from tianshou.utils import TensorboardLogger +from tianshou.utils.net.common import ModuleWithVectorOutput from tianshou.utils.net.continuous import ContinuousActorProb, ContinuousCritic -class DQNet(nn.Module): +class DQNet(ModuleWithVectorOutput): """Reference: Human-level control through deep reinforcement learning. For advanced usage (how to customize the network), please refer to @@ -38,12 +39,7 @@ def __init__( w: int, device: str | int | torch.device = "cpu", ) -> None: - super().__init__() - self.device = device - self.c = c - self.h = h - self.w = w - self.net = nn.Sequential( + net = nn.Sequential( nn.Conv2d(c, 32, kernel_size=8, stride=4), nn.ReLU(inplace=True), nn.Conv2d(32, 64, kernel_size=4, stride=2), @@ -53,7 +49,13 @@ def __init__( nn.Flatten(), ) with torch.no_grad(): - self.output_dim = np.prod(self.net(torch.zeros(1, c, h, w)).shape[1:]) + output_dim = np.prod(self.net(torch.zeros(1, c, h, w)).shape[1:]) + super().__init__(int(output_dim)) + self.device = device + self.c = c + self.h = h + self.w = w + self.net = net def forward( self, diff --git a/tianshou/data/collector.py b/tianshou/data/collector.py index 5c8a14537..9eaa5487b 100644 --- a/tianshou/data/collector.py +++ b/tianshou/data/collector.py @@ -32,7 +32,6 @@ from tianshou.env import BaseVectorEnv, DummyVectorEnv from tianshou.policy import Algorithm from tianshou.policy.base import Policy, episode_mc_return_to_go -from tianshou.policy.base import episode_mc_return_to_go from tianshou.utils.determinism import TraceLogger from tianshou.utils.print import DataclassPPrintMixin from tianshou.utils.torch_utils import torch_train_mode diff --git a/tianshou/env/atari/atari_network.py b/tianshou/env/atari/atari_network.py index 338b72d19..f14d73670 100644 --- a/tianshou/env/atari/atari_network.py +++ b/tianshou/env/atari/atari_network.py @@ -28,13 +28,11 @@ def layer_init(layer: nn.Module, std: float = np.sqrt(2), bias_const: float = 0. 
return layer -class ScaledObsInputModule(torch.nn.Module): +class ScaledObsInputModule(NetBase): def __init__(self, module: NetBase, denom: float = 255.0) -> None: - super().__init__() + super().__init__(module.get_output_dim()) self.module = module self.denom = denom - # This is required such that the value can be retrieved by downstream modules (see usages of get_output_dim) - self.output_dim = module.output_dim def forward( self, @@ -74,8 +72,7 @@ def __init__( raise ValueError( "Should not provide explicit output dimension using `output_dim_added_layer` when `features_only` is true.", ) - super().__init__() - self.net = nn.Sequential( + net = nn.Sequential( layer_init(nn.Conv2d(c, 32, kernel_size=8, stride=4)), nn.ReLU(inplace=True), layer_init(nn.Conv2d(32, 64, kernel_size=4, stride=2)), @@ -85,25 +82,27 @@ def __init__( nn.Flatten(), ) with torch.no_grad(): - base_cnn_output_dim = int(np.prod(self.net(torch.zeros(1, c, h, w)).shape[1:])) + base_cnn_output_dim = int(np.prod(net(torch.zeros(1, c, h, w)).shape[1:])) if not features_only: action_dim = int(np.prod(action_shape)) - self.net = nn.Sequential( - self.net, + net = nn.Sequential( + net, layer_init(nn.Linear(base_cnn_output_dim, 512)), nn.ReLU(inplace=True), layer_init(nn.Linear(512, action_dim)), ) - self.output_dim = action_dim + output_dim = action_dim elif output_dim_added_layer is not None: - self.net = nn.Sequential( - self.net, + net = nn.Sequential( + net, layer_init(nn.Linear(base_cnn_output_dim, output_dim_added_layer)), nn.ReLU(inplace=True), ) - self.output_dim = output_dim_added_layer + output_dim = output_dim_added_layer else: - self.output_dim = base_cnn_output_dim + output_dim = base_cnn_output_dim + super().__init__(output_dim) + self.net = net def forward( self, diff --git a/tianshou/highlevel/module/actor.py b/tianshou/highlevel/module/actor.py index e849d6f9c..297109829 100644 --- a/tianshou/highlevel/module/actor.py +++ b/tianshou/highlevel/module/actor.py @@ -24,7 +24,7 @@ ) from 
tianshou.policy.modelfree.pg import TDistFnDiscrOrCont from tianshou.utils.net import continuous, discrete -from tianshou.utils.net.common import Actor, ModuleType, Net +from tianshou.utils.net.common import Actor, ModuleType, ModuleWithVectorOutput, Net class ContinuousActorType(Enum): @@ -269,5 +269,7 @@ def __init__(self, actor_factory: ActorFactory): def create_intermediate_module(self, envs: Environments, device: TDevice) -> IntermediateModule: actor = self.actor_factory.create_module(envs, device) - assert isinstance(actor, Actor) + assert isinstance( + actor, ModuleWithVectorOutput + ), "Actor factory must produce an actor with known vector output dimension" return IntermediateModule(actor, actor.get_output_dim()) diff --git a/tianshou/highlevel/module/intermediate.py b/tianshou/highlevel/module/intermediate.py index 62bf3843f..08b32641a 100644 --- a/tianshou/highlevel/module/intermediate.py +++ b/tianshou/highlevel/module/intermediate.py @@ -6,6 +6,7 @@ from tianshou.highlevel.env import Environments from tianshou.highlevel.module.core import ModuleFactory, TDevice +from tianshou.utils.net.common import ModuleWithVectorOutput @dataclass @@ -15,6 +16,12 @@ class IntermediateModule: module: torch.nn.Module output_dim: int + def get_module_with_vector_output(self) -> ModuleWithVectorOutput: + if isinstance(self.module, ModuleWithVectorOutput): + return self.module + else: + return ModuleWithVectorOutput.from_module(self.module, self.output_dim) + class IntermediateModuleFactory(ToStringMixin, ModuleFactory, ABC): """Factory for the generation of a module which computes an intermediate representation.""" diff --git a/tianshou/highlevel/module/special.py b/tianshou/highlevel/module/special.py index 6d119d739..b36b26d9e 100644 --- a/tianshou/highlevel/module/special.py +++ b/tianshou/highlevel/module/special.py @@ -22,9 +22,8 @@ def __init__( def create_module(self, envs: Environments, device: TDevice) -> ImplicitQuantileNetwork: preprocess_net = 
self.preprocess_net_factory.create_intermediate_module(envs, device) return ImplicitQuantileNetwork( - preprocess_net=preprocess_net.module, + preprocess_net=preprocess_net.get_module_with_vector_output(), action_shape=envs.get_action_shape(), hidden_sizes=self.hidden_sizes, num_cosines=self.num_cosines, - preprocess_net_output_dim=preprocess_net.output_dim, ).to(device) diff --git a/tianshou/policy/imitation/gail.py b/tianshou/policy/imitation/gail.py index e6a27a5fd..890ed192c 100644 --- a/tianshou/policy/imitation/gail.py +++ b/tianshou/policy/imitation/gail.py @@ -16,6 +16,7 @@ from tianshou.policy.modelfree.pg import ActorPolicy from tianshou.policy.modelfree.ppo import PPOTrainingStats from tianshou.policy.optim import OptimizerFactory +from tianshou.utils.net.common import ModuleWithVectorOutput from tianshou.utils.net.continuous import ContinuousCritic from tianshou.utils.net.discrete import DiscreteCritic from tianshou.utils.torch_utils import torch_device @@ -59,7 +60,8 @@ def __init__( reward_normalization: bool = False, ) -> None: r""" - :param policy: the policy. + :param policy: the policy (which must use an actor with known output dimension, i.e. + any Tianshou `Actor` implementation or other subclass of `ModuleWithVectorOutput`). :param critic: the critic network. (s -> V(s)) :param optim: the optimizer factory for the actor and critic networks. :param expert_buffer: the replay buffer containing expert experience. @@ -106,10 +108,10 @@ def __init__( self.disc_optim = self._create_optimizer(self.disc_net, disc_optim) self.disc_update_num = disc_update_num self.expert_buffer = expert_buffer - # TODO: This violates the type requirement; nn.Module does not necessarily have output_dim! - # Use IntermediateModule or perhaps a class more general than BaseActor which defines - # only the output dimension? 
- self.action_dim = self.policy.actor.output_dim + actor = self.policy.actor + if not isinstance(actor, ModuleWithVectorOutput): + raise TypeError("GAIL requires the policy to use an actor with known output dimension.") + self.action_dim = actor.get_output_dim() def preprocess_batch( self, diff --git a/tianshou/policy/modelfree/a2c.py b/tianshou/policy/modelfree/a2c.py index 1ff2151fc..2162115c4 100644 --- a/tianshou/policy/modelfree/a2c.py +++ b/tianshou/policy/modelfree/a2c.py @@ -10,9 +10,10 @@ from tianshou.data.types import BatchWithAdvantagesProtocol, RolloutBatchProtocol from tianshou.policy.base import ( OnPolicyAlgorithm, - TrainingStats, TTrainingStats, + TrainingStats, + TTrainingStats, ) -from tianshou.policy.modelfree.pg import ActorPolicy, TPGTrainingStats +from tianshou.policy.modelfree.pg import ActorPolicy from tianshou.policy.optim import OptimizerFactory from tianshou.utils import RunningMeanStd from tianshou.utils.net.common import ActorCritic diff --git a/tianshou/utils/net/common.py b/tianshou/utils/net/common.py index f77771d6c..afe403a4d 100644 --- a/tianshou/utils/net/common.py +++ b/tianshou/utils/net/common.py @@ -47,30 +47,49 @@ def miniblock( return layers -class MLP(nn.Module): - """Simple MLP backbone. - - Create a MLP of size input_dim * hidden_sizes[0] * hidden_sizes[1] * ... - * hidden_sizes[-1] * output_dim +class ModuleWithVectorOutput(nn.Module): + """ + A module that outputs a vector of a known size. - :param input_dim: dimension of the input vector. - :param output_dim: dimension of the output vector. If set to 0, there - is no final linear layer. - :param hidden_sizes: shape of MLP passed in as a list, not including - input_dim and output_dim. - :param norm_layer: use which normalization before activation, e.g., - ``nn.LayerNorm`` and ``nn.BatchNorm1d``. Default to no normalization. 
- You can also pass a list of normalization modules with the same length - of hidden_sizes, to use different normalization module in different - layers. Default to no normalization. - :param activation: which activation to use after each layer, can be both - the same activation for all layers if passed in nn.Module, or different - activation for different Modules if passed in a list. Default to - nn.ReLU. - :param linear_layer: use this module as linear layer. Default to nn.Linear. - :param flatten_input: whether to flatten input data. Default to True. + Use `from_module` to adapt a module to this interface. """ + def __init__(self, output_dim: int) -> None: + """:param output_dim: the dimension of the output vector.""" + super().__init__() + self.output_dim = output_dim + + @staticmethod + def from_module(module: nn.Module, output_dim: int) -> "ModuleWithVectorOutput": + """ + :param module: the module to adapt. + :param output_dim: dimension of the output vector produced by the module. + """ + return ModuleWithVectorOutputAdapter(module, output_dim) + + def get_output_dim(self) -> int: + """:return: the dimension of the output vector.""" + return self.output_dim + + +class ModuleWithVectorOutputAdapter(ModuleWithVectorOutput): + """Adapts a module with vector output to provide the :class:`ModuleWithVectorOutput` interface.""" + + def __init__(self, module: nn.Module, output_dim: int) -> None: + """ + :param module: the module to adapt. + :param output_dim: the dimension of the output vector produced by the module. + """ + super().__init__(output_dim) + self.module = module + + def forward(self, *args: Any, **kwargs: Any) -> Any: + return self.module(*args, **kwargs) + + +class MLP(ModuleWithVectorOutput): + """Simple MLP backbone.""" + def __init__( self, *, @@ -84,7 +103,24 @@ def __init__( linear_layer: TLinearLayer = nn.Linear, flatten_input: bool = True, ) -> None: - super().__init__() + """ + :param input_dim: dimension of the input vector. 
+ :param output_dim: dimension of the output vector. If set to 0, there + is no explicit final linear layer and the output dimension is the last hidden layer's dimension. + :param hidden_sizes: shape of MLP passed in as a list, not including + input_dim and output_dim. + :param norm_layer: use which normalization before activation, e.g., + ``nn.LayerNorm`` and ``nn.BatchNorm1d``. Default to no normalization. + You can also pass a list of normalization modules with the same length + of hidden_sizes, to use different normalization module in different + layers. Default to no normalization. + :param activation: which activation to use after each layer, can be both + the same activation for all layers if passed in nn.Module, or different + activation for different Modules if passed in a list. Default to + nn.ReLU. + :param linear_layer: use this module as linear layer. Default to nn.Linear. + :param flatten_input: whether to flatten input data. Default to True. + """ if norm_layer: if isinstance(norm_layer, list): assert len(norm_layer) == len(hidden_sizes) @@ -129,7 +165,7 @@ def __init__( model += miniblock(in_dim, out_dim, norm, norm_args, activ, act_args, linear_layer) if output_dim > 0: model += [linear_layer(hidden_sizes[-1], output_dim)] - self.output_dim = output_dim or hidden_sizes[-1] + super().__init__(output_dim or hidden_sizes[-1]) self.model = nn.Sequential(*model) self.flatten_input = flatten_input @@ -145,8 +181,8 @@ def forward(self, obs: np.ndarray | torch.Tensor) -> torch.Tensor: TRecurrentState = TypeVar("TRecurrentState", bound=Any) -class NetBase(nn.Module, Generic[TRecurrentState], ABC): - """Interface for NNs used in policies.""" +class PolicyForwardInterface(Generic[TRecurrentState], ABC): + """Defines the `forward` interface for neural networks used in policies.""" @abstractmethod def forward( @@ -158,6 +194,10 @@ def forward( pass +class NetBase(ModuleWithVectorOutput, PolicyForwardInterface[TRecurrentState], ABC): + """Base class for NNs used 
in policies which produce vector outputs.""" + + class Net(NetBase[Any]): """Wrapper of MLP to support more specific DRL usage. @@ -217,21 +257,14 @@ def __init__( dueling_param: tuple[dict[str, Any], dict[str, Any]] | None = None, linear_layer: TLinearLayer = nn.Linear, ) -> None: - super().__init__() - self.softmax = softmax - self.num_atoms = num_atoms - self.Q: MLP | None = None - self.V: MLP | None = None - input_dim = int(np.prod(state_shape)) action_dim = int(np.prod(action_shape)) * num_atoms if concat: input_dim += action_dim - self.use_dueling = dueling_param is not None - output_dim = action_dim if not self.use_dueling and not concat else 0 - self.model = MLP( + use_dueling = dueling_param is not None + model = MLP( input_dim=input_dim, - output_dim=output_dim, + output_dim=action_dim if not use_dueling and not concat else 0, hidden_sizes=hidden_sizes, norm_layer=norm_layer, norm_args=norm_args, @@ -239,7 +272,9 @@ def __init__( act_args=act_args, linear_layer=linear_layer, ) - if self.use_dueling: # dueling DQN + Q: MLP | None = None + V: MLP | None = None + if use_dueling: # dueling DQN assert dueling_param is not None kwargs_update = { "input_dim": self.model.output_dim, @@ -250,10 +285,18 @@ def __init__( q_kwargs["output_dim"] = 0 if concat else action_dim v_kwargs["output_dim"] = 0 if concat else num_atoms - self.Q, self.V = MLP(**q_kwargs), MLP(**v_kwargs) - self.output_dim = self.Q.output_dim + Q, V = MLP(**q_kwargs), MLP(**v_kwargs) + output_dim = Q.output_dim else: - self.output_dim = self.model.output_dim + output_dim = model.output_dim + + super().__init__(output_dim) + self.use_dueling = use_dueling + self.softmax = softmax + self.num_atoms = num_atoms + self.model = model + self.Q = Q + self.V = V def forward( self, @@ -299,7 +342,8 @@ def __init__( action_shape: TActionShape, hidden_layer_size: int = 128, ) -> None: - super().__init__() + output_dim = int(np.prod(action_shape)) + super().__init__(output_dim) self.nn = nn.LSTM( 
input_size=hidden_layer_size, hidden_size=hidden_layer_size, @@ -307,7 +351,7 @@ def __init__( batch_first=True, ) self.fc1 = nn.Linear(int(np.prod(state_shape)), hidden_layer_size) - self.fc2 = nn.Linear(hidden_layer_size, int(np.prod(action_shape))) + self.fc2 = nn.Linear(hidden_layer_size, output_dim) def forward( self, @@ -445,32 +489,25 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return x -# TODO: fix docstring -class BranchingNet(NetBase[Any]): +class BranchingNet(nn.Module, PolicyForwardInterface[Any]): """Branching dual Q network. Network for the BranchingDQNPolicy, it uses a common network module, a value module - and action "branches" one for each dimension.It allows for a linear scaling + and action "branches" one for each dimension. It allows for a linear scaling of Q-value the output w.r.t. the number of dimensions in the action space. - For more info please refer to: arXiv:1711.08946. - :param state_shape: int or a sequence of int of the shape of state. - :param action_shape: int or a sequence of int of the shape of action. - :param action_peer_branch: int or a sequence of int of the number of actions in - each dimension. - :param common_hidden_sizes: shape of the common MLP network passed in as a list. - :param value_hidden_sizes: shape of the value MLP network passed in as a list. - :param action_hidden_sizes: shape of the action MLP network passed in as a list. - :param norm_layer: use which normalization before activation, e.g., - ``nn.LayerNorm`` and ``nn.BatchNorm1d``. Default to no normalization. - You can also pass a list of normalization modules with the same length - of hidden_sizes, to use different normalization module in different - layers. Default to no normalization. - :param activation: which activation to use after each layer, can be both - the same activation for all layers if passed in nn.Module, or different - activation for different Modules if passed in a list. Default to - nn.ReLU. 
- :param softmax: whether to apply a softmax layer over the last layer's - output. + + This network architecture efficiently handles environments with multiple independent + action dimensions by using a branching structure. Instead of representing all action + combinations (which grows exponentially), it represents each action dimension separately + (linear scaling). + For example, if there are 4 action dimensions with 3 possible values each, then we would normally + need to consider 3^4 = 81 unique actions, whereas with this architecture, we can instead + use 4 branches with 3 actions per branch, resulting in 4 * 3 = 12 values to be considered. + + Common use cases include multi-joint robotic control tasks, where each joint can be controlled + independently. + + For more information, please refer to: arXiv:1711.08946. """ def __init__( @@ -487,7 +524,30 @@ def __init__( activation: ModuleType | None = nn.ReLU, act_args: ArgsType | None = None, ) -> None: - super().__init__() + """ + :param state_shape: int or a sequence of int of the shape of state. + :param num_branches: number of action dimensions in the environment. + Each branch represents one independent action dimension. + For example, in a robot with 7 joints, you would set this to 7. + :param action_per_branch: Number of possible discrete values for each action dimension. + For example, if each joint can have 3 positions (left, center, right), + you would set this to 3. + :param common_hidden_sizes: shape of the common MLP network passed in as a list. + :param value_hidden_sizes: shape of the value MLP network passed in as a list. + :param action_hidden_sizes: shape of the action MLP network passed in as a list. + :param norm_layer: use which normalization before activation, e.g., + ``nn.LayerNorm`` and ``nn.BatchNorm1d``. Default to no normalization. + You can also pass a list of normalization modules with the same length + of hidden_sizes, to use different normalization module in different + layers.
Default to no normalization. + :param activation: which activation to use after each layer, can be both + the same activation for all layers if passed in nn.Module, or different + activation for different Modules if passed in a list. Default to + nn.ReLU. + :param softmax: whether to apply a softmax layer over the last layer's + output. + """ + super().__init__(output_dim=10) common_hidden_sizes = common_hidden_sizes or [] value_hidden_sizes = value_hidden_sizes or [] action_hidden_sizes = action_hidden_sizes or [] @@ -602,15 +662,11 @@ def forward(self, obs: np.ndarray | torch.Tensor, *args, **kwargs) -> Any: return decorator_fn, new_state_shape -class Actor(nn.Module, ABC): +class Actor(ModuleWithVectorOutput, ABC): @abstractmethod def get_preprocess_net(self) -> nn.Module: pass - @abstractmethod - def get_output_dim(self) -> int: - pass - @abstractmethod def forward( self, @@ -632,7 +688,11 @@ class RandomActor(Actor): """ def __init__(self, action_space: spaces.Box | spaces.Discrete) -> None: - super().__init__() + if isinstance(action_space, spaces.Discrete): + output_dim = action_space.n + else: + output_dim = np.prod(action_space.shape) + super().__init__(int(output_dim)) self._action_space = action_space self._space_info = ActionSpaceInfo.from_space(action_space) @@ -675,39 +735,3 @@ def compute_action_batch(self, obs: np.ndarray | torch.Tensor | BatchProtocol) - return np.random.randint(low=0, high=self.action_space.n, size=len(obs)) else: return self.forward(obs)[0] - - -def getattr_with_matching_alt_value(obj: Any, attr_name: str, alt_value: T | None) -> T: - """Gets the given attribute from the given object or takes the alternative value if it is not present. - If both are present, they are required to match. 
- - :param obj: the object from which to obtain the attribute value - :param attr_name: the attribute name - :param alt_value: the alternative value for the case where the attribute is not present, which cannot be None - if the attribute is not present - :return: the value - """ - v = getattr(obj, attr_name) - if v is not None: - if alt_value is not None and v != alt_value: - raise ValueError( - f"Attribute '{attr_name}' of {obj} is defined ({v}) but does not match alt. value ({alt_value})", - ) - return v - else: - if alt_value is None: - raise ValueError( - f"Attribute '{attr_name}' of {obj} is not defined and no fallback given", - ) - return alt_value - - -def get_output_dim(module: nn.Module, alt_value: int | None) -> int: - """Retrieves value the `output_dim` attribute of the given module or uses the given alternative value if the attribute is not present. - If both are present, they must match. - - :param module: the module - :param alt_value: the alternative value - :return: the value - """ - return getattr_with_matching_alt_value(module, "output_dim", alt_value) diff --git a/tianshou/utils/net/continuous.py b/tianshou/utils/net/continuous.py index d1e279bee..e5240e8d8 100644 --- a/tianshou/utils/net/continuous.py +++ b/tianshou/utils/net/continuous.py @@ -11,10 +11,9 @@ from tianshou.utils.net.common import ( MLP, Actor, - Net, + ModuleWithVectorOutput, TActionShape, TLinearLayer, - get_output_dim, ) from tianshou.utils.torch_utils import torch_device @@ -34,8 +33,6 @@ class ContinuousActorDeterministic(Actor): :param hidden_sizes: a sequence of int for constructing the MLP after preprocess_net. :param max_action: the scale for the final action. - :param preprocess_net_output_dim: the output dimension of - `preprocess_net`. Only used when `preprocess_net` does not have the attribute `output_dim`. For advanced usage (how to customize the network), please refer to :ref:`build_the_network`. 
@@ -44,16 +41,15 @@ class ContinuousActorDeterministic(Actor): def __init__( self, *, - preprocess_net: nn.Module | Net, + preprocess_net: ModuleWithVectorOutput, action_shape: TActionShape, hidden_sizes: Sequence[int] = (), max_action: float = 1.0, - preprocess_net_output_dim: int | None = None, ) -> None: - super().__init__() + output_dim = int(np.prod(action_shape)) + super().__init__(output_dim) self.preprocess = preprocess_net - self.output_dim = int(np.prod(action_shape)) - input_dim = get_output_dim(preprocess_net, preprocess_net_output_dim) + input_dim = preprocess_net.get_output_dim() self.last = MLP( input_dim=input_dim, output_dim=self.output_dim, @@ -85,7 +81,7 @@ def forward( return action_BA, hidden_BH -class AbstractContinuousCritic(nn.Module, ABC): +class AbstractContinuousCritic(ModuleWithVectorOutput, ABC): @abstractmethod def forward( self, @@ -101,12 +97,10 @@ class ContinuousCritic(AbstractContinuousCritic): It will create an actor operated in continuous action space with structure of preprocess_net ---> 1(q value). - :param preprocess_net: a self-defined preprocess_net, see usage. + :param preprocess_net: the pre-processing network, which returns a vector of a known dimension. Typically, an instance of :class:`~tianshou.utils.net.common.Net`. :param hidden_sizes: a sequence of int for constructing the MLP after preprocess_net. - :param preprocess_net_output_dim: the output dimension of - `preprocess_net`. Only used when `preprocess_net` does not have the attribute `output_dim`. :param linear_layer: use this module as linear layer. :param flatten_input: whether to flatten input data for the last layer. 
:param apply_preprocess_net_to_obs_only: whether to apply `preprocess_net` to the observations only (before @@ -120,18 +114,16 @@ class ContinuousCritic(AbstractContinuousCritic): def __init__( self, *, - preprocess_net: nn.Module | Net, + preprocess_net: ModuleWithVectorOutput, hidden_sizes: Sequence[int] = (), - preprocess_net_output_dim: int | None = None, linear_layer: TLinearLayer = nn.Linear, flatten_input: bool = True, apply_preprocess_net_to_obs_only: bool = False, ) -> None: - super().__init__() + super().__init__(output_dim=1) self.preprocess = preprocess_net - self.output_dim = 1 self.apply_preprocess_net_to_obs_only = apply_preprocess_net_to_obs_only - input_dim = get_output_dim(preprocess_net, preprocess_net_output_dim) + input_dim = preprocess_net.get_output_dim() self.last = MLP( input_dim=input_dim, output_dim=1, @@ -181,8 +173,7 @@ class ContinuousActorProb(Actor): Used primarily in SAC, PPO and variants thereof. For deterministic policies, see :class:`~Actor`. - :param preprocess_net: a self-defined preprocess_net, see usage. - Typically, an instance of :class:`~tianshou.utils.net.common.Net`. + :param preprocess_net: the pre-processing network, which returns a vector of a known dimension. :param action_shape: a sequence of int for the shape of action. :param hidden_sizes: a sequence of int for constructing the MLP after preprocess_net. @@ -190,8 +181,6 @@ class ContinuousActorProb(Actor): :param unbounded: whether to apply tanh activation on final logits. :param conditioned_sigma: True when sigma is calculated from the input, False when sigma is an independent parameter. - :param preprocess_net_output_dim: the output dimension of - `preprocess_net`. Only used when `preprocess_net` does not have the attribute `output_dim`. For advanced usage (how to customize the network), please refer to :ref:`build_the_network`. 
@@ -200,40 +189,36 @@ class ContinuousActorProb(Actor): def __init__( self, *, - preprocess_net: nn.Module | Net, + preprocess_net: ModuleWithVectorOutput, action_shape: TActionShape, hidden_sizes: Sequence[int] = (), max_action: float = 1.0, unbounded: bool = False, conditioned_sigma: bool = False, - preprocess_net_output_dim: int | None = None, ) -> None: - super().__init__() + output_dim = int(np.prod(action_shape)) + super().__init__(output_dim) if unbounded and not np.isclose(max_action, 1.0): warnings.warn("Note that max_action input will be discarded when unbounded is True.") max_action = 1.0 self.preprocess = preprocess_net - self.output_dim = int(np.prod(action_shape)) - input_dim = get_output_dim(preprocess_net, preprocess_net_output_dim) - self.mu = MLP(input_dim=input_dim, output_dim=self.output_dim, hidden_sizes=hidden_sizes) + input_dim = preprocess_net.get_output_dim() + self.mu = MLP(input_dim=input_dim, output_dim=output_dim, hidden_sizes=hidden_sizes) self._c_sigma = conditioned_sigma if conditioned_sigma: self.sigma = MLP( input_dim=input_dim, - output_dim=self.output_dim, + output_dim=output_dim, hidden_sizes=hidden_sizes, ) else: - self.sigma_param = nn.Parameter(torch.zeros(self.output_dim, 1)) + self.sigma_param = nn.Parameter(torch.zeros(output_dim, 1)) self.max_action = max_action self._unbounded = unbounded def get_preprocess_net(self) -> nn.Module: return self.preprocess - def get_output_dim(self) -> int: - return self.output_dim - def forward( self, obs: np.ndarray | torch.Tensor, diff --git a/tianshou/utils/net/discrete.py b/tianshou/utils/net/discrete.py index fdbca00d8..7464c0d03 100644 --- a/tianshou/utils/net/discrete.py +++ b/tianshou/utils/net/discrete.py @@ -7,7 +7,12 @@ from torch import nn from tianshou.data import Batch, to_torch -from tianshou.utils.net.common import MLP, Actor, Net, TActionShape, get_output_dim +from tianshou.utils.net.common import ( + MLP, + Actor, + ModuleWithVectorOutput, + TActionShape, +) from 
tianshou.utils.torch_utils import torch_device @@ -17,36 +22,30 @@ def dist_fn_categorical_from_logits(logits: torch.Tensor) -> torch.distributions class DiscreteActor(Actor): - """Simple actor network for discrete action spaces. - - :param preprocess_net: a self-defined preprocess_net. Typically, an instance of - :class:`~tianshou.utils.net.common.Net`. - :param action_shape: a sequence of int for the shape of action. - :param hidden_sizes: a sequence of int for constructing the MLP after - preprocess_net. Default to empty sequence (where the MLP now contains - only a single linear layer). - :param softmax_output: whether to apply a softmax layer over the last - layer's output. - :param preprocess_net_output_dim: the output dimension of - `preprocess_net`. Only used when `preprocess_net` does not have the attribute `output_dim`. - - For advanced usage (how to customize the network), please refer to - :ref:`build_the_network`. - """ + """Simple actor network for discrete action spaces.""" def __init__( self, *, - preprocess_net: nn.Module | Net, + preprocess_net: ModuleWithVectorOutput, action_shape: TActionShape, hidden_sizes: Sequence[int] = (), softmax_output: bool = True, - preprocess_net_output_dim: int | None = None, ) -> None: - super().__init__() + """ + :param preprocess_net: the preprocessing network, which outputs a vector of a known dimension; + typically an instance of :class:`~tianshou.utils.net.common.Net`. + :param action_shape: a sequence of int for the shape of action. + :param hidden_sizes: a sequence of int for constructing the MLP after + preprocess_net. Default to empty sequence (where the MLP now contains + only a single linear layer). + :param softmax_output: whether to apply a softmax layer over the last + layer's output. 
+ """ + output_dim = int(np.prod(action_shape)) + super().__init__(output_dim) self.preprocess = preprocess_net - self.output_dim = int(np.prod(action_shape)) - input_dim = get_output_dim(preprocess_net, preprocess_net_output_dim) + input_dim = preprocess_net.get_output_dim() self.last = MLP( input_dim=input_dim, output_dim=self.output_dim, @@ -57,9 +56,6 @@ def __init__( def get_preprocess_net(self) -> nn.Module: return self.preprocess - def get_output_dim(self) -> int: - return self.output_dim - def forward( self, obs: np.ndarray | torch.Tensor, @@ -84,17 +80,15 @@ def forward( return output_BA, hidden_BH -class DiscreteCritic(nn.Module): +class DiscreteCritic(ModuleWithVectorOutput): """Simple critic network for discrete action spaces. - :param preprocess_net: a self-defined preprocess_net. Typically, an instance of - :class:`~tianshou.utils.net.common.Net`. + :param preprocess_net: the preprocessing network, which outputs a vector of a known dimension; + typically an instance of :class:`~tianshou.utils.net.common.Net`. :param hidden_sizes: a sequence of int for constructing the MLP after preprocess_net. Default to empty sequence (where the MLP now contains only a single linear layer). :param last_size: the output dimension of Critic network. Default to 1. - :param preprocess_net_output_dim: the output dimension of - `preprocess_net`. Only used when `preprocess_net` does not have the attribute `output_dim`. For advanced usage (how to customize the network), please refer to :ref:`build_the_network`.. 
@@ -103,15 +97,13 @@ class DiscreteCritic(nn.Module): def __init__( self, *, - preprocess_net: nn.Module | Net, + preprocess_net: ModuleWithVectorOutput, hidden_sizes: Sequence[int] = (), last_size: int = 1, - preprocess_net_output_dim: int | None = None, ) -> None: - super().__init__() + super().__init__(output_dim=last_size) self.preprocess = preprocess_net - self.output_dim = last_size - input_dim = get_output_dim(preprocess_net, preprocess_net_output_dim) + input_dim = preprocess_net.get_output_dim() self.last = MLP(input_dim=input_dim, output_dim=last_size, hidden_sizes=hidden_sizes) # TODO: make a proper interface! @@ -170,8 +162,6 @@ class ImplicitQuantileNetwork(DiscreteCritic): only a single linear layer). :param num_cosines: the number of cosines to use for cosine embedding. Default to 64. - :param preprocess_net_output_dim: the output dimension of - preprocess_net. .. note:: @@ -184,20 +174,18 @@ class ImplicitQuantileNetwork(DiscreteCritic): def __init__( self, *, - preprocess_net: nn.Module, + preprocess_net: ModuleWithVectorOutput, action_shape: TActionShape, hidden_sizes: Sequence[int] = (), num_cosines: int = 64, - preprocess_net_output_dim: int | None = None, ) -> None: last_size = int(np.prod(action_shape)) super().__init__( preprocess_net=preprocess_net, hidden_sizes=hidden_sizes, last_size=last_size, - preprocess_net_output_dim=preprocess_net_output_dim, ) - self.input_dim = get_output_dim(preprocess_net, preprocess_net_output_dim) + self.input_dim = preprocess_net.get_output_dim() self.embed_model = CosineEmbeddingNetwork(num_cosines, self.input_dim) def forward( # type: ignore @@ -266,8 +254,6 @@ class FullQuantileFunction(ImplicitQuantileNetwork): only a single linear layer). :param num_cosines: the number of cosines to use for cosine embedding. Default to 64. - :param preprocess_net_output_dim: the output dimension of - preprocess_net. .. 
note:: @@ -278,18 +264,16 @@ class FullQuantileFunction(ImplicitQuantileNetwork): def __init__( self, *, - preprocess_net: nn.Module, + preprocess_net: ModuleWithVectorOutput, action_shape: TActionShape, hidden_sizes: Sequence[int] = (), num_cosines: int = 64, - preprocess_net_output_dim: int | None = None, ) -> None: super().__init__( preprocess_net=preprocess_net, action_shape=action_shape, hidden_sizes=hidden_sizes, num_cosines=num_cosines, - preprocess_net_output_dim=preprocess_net_output_dim, ) def _compute_quantiles(self, obs: torch.Tensor, taus: torch.Tensor) -> torch.Tensor: From 24d7a4a35d3e2df79bc903abfe6544a0a6b330e5 Mon Sep 17 00:00:00 2001 From: Dominik Jain Date: Sat, 3 May 2025 13:53:21 +0200 Subject: [PATCH 088/230] v2: Update README: Algorithm abstraction, high-level example --- README.md | 69 +++++++++++++++++++++++++++++-------------------------- 1 file changed, 37 insertions(+), 32 deletions(-) diff --git a/README.md b/README.md index 8db202d57..5b6309527 100644 --- a/README.md +++ b/README.md @@ -180,20 +180,21 @@ Check out the [GitHub Actions](https://github.com/thu-ml/tianshou/actions) page Atari and MuJoCo benchmark results can be found in the [examples/atari/](examples/atari/) and [examples/mujoco/](examples/mujoco/) folders respectively. 
**Our MuJoCo results reach or exceed the level of performance of most existing benchmarks.** -### Policy Interface +### Algorithm Abstraction -All algorithms implement the following, highly general API: +Reinforcement learning algorithms are built on abstractions for + * on-policy algorithms (`OnPolicyAlgorithm`), + * off-policy algorithms (`OffPolicyAlgorithm`), and + * offline algorithms (`OfflineAlgorithm`), -- `__init__`: initialize the policy; -- `forward`: compute actions based on given observations; -- `process_buffer`: process initial buffer, which is useful for some offline learning algorithms -- `process_fn`: preprocess data from the replay buffer (since we have reformulated _all_ algorithms to replay buffer-based algorithms); -- `learn`: learn from a given batch of data; -- `post_process_fn`: update the replay buffer from the learning process (e.g., prioritized replay buffer needs to update the weight); -- `update`: the main interface for training, i.e., `process_fn -> learn -> post_process_fn`. +all of which clearly separate the core algorithm from the training process and the respective environment interactions. -The implementation of this API suffices for a new algorithm to be applicable within Tianshou, -making experimenation with new approaches particularly straightforward. +In each case, the implementation of an algorithm necessarily involves only the implementation of methods for + * pre-processing a batch of data, augmenting it with necessary information/sufficient statistics for learning (`preprocess_batch`), + * updating model parameters based on an augmented batch of data (`_update_with_batch`). + +The implementation of these methods suffices for a new algorithm to be applicable within Tianshou, +making experimentation with new approaches particularly straightforward.
## Quick Start @@ -203,14 +204,19 @@ Tianshou provides two API levels: - the procedural interface, which provides a maximum of control, especially for very advanced users and developers of reinforcement learning algorithms. In the following, let us consider an example application using the _CartPole_ gymnasium environment. -We shall apply the deep Q network (DQN) learning algorithm using both APIs. +We shall apply the deep Q-network (DQN) learning algorithm using both APIs. ### High-Level API -To get started, we need some imports. +In the high-level API, the basis for an RL experiment is an `ExperimentBuilder` +with which we can build the experiment we then seek to run. +Since we want to use DQN, we use the specialization `DQNExperimentBuilder`. + +As imports, we need only the experiment builder itself, the environment factory +and some configuration classes: ```python -from tianshou.highlevel.config import TrainingConfig +from tianshou.highlevel.config import OffPolicyTrainingConfig from tianshou.highlevel.env import ( EnvFactoryRegistered, VectorEnvType, @@ -218,17 +224,10 @@ from tianshou.highlevel.env import ( from tianshou.highlevel.experiment import DQNExperimentBuilder, ExperimentConfig from tianshou.highlevel.params.policy_params import DQNParams from tianshou.highlevel.trainer import ( - EpochTestCallbackDQNSetEps, - EpochTrainCallbackDQNSetEps, - EpochStopCallbackRewardThreshold + EpochStopCallbackRewardThreshold, ) ``` -In the high-level API, the basis for an RL experiment is an `ExperimentBuilder` -with which we can build the experiment we then seek to run. -Since we want to use DQN, we use the specialization `DQNExperimentBuilder`. -The other imports serve to provide configuration options for our experiment. - The high-level API provides largely declarative semantics, i.e. the code is almost exclusively concerned with configuration that controls what to do (rather than how to do it). 
@@ -236,14 +235,19 @@ almost exclusively concerned with configuration that controls what to do ```python experiment = ( DQNExperimentBuilder( - EnvFactoryRegistered(task="CartPole-v1", train_seed=0, test_seed=0, venv_type=VectorEnvType.DUMMY), + EnvFactoryRegistered( + task="CartPole-v1", + venv_type=VectorEnvType.DUMMY, + train_seed=0, + test_seed=10, + ), ExperimentConfig( persistence_enabled=False, watch=True, watch_render=1 / 35, watch_num_episodes=100, ), - SamplingConfig( + OffPolicyTrainingConfig( num_epochs=10, step_per_epoch=10000, batch_size=64, @@ -260,11 +264,11 @@ experiment = ( discount_factor=0.9, estimation_step=3, target_update_freq=320, + eps_training=0.3, + eps_inference=0.0, ), ) .with_model_factory_default(hidden_sizes=(64, 64)) - .with_epoch_train_callback(EpochTrainCallbackDQNSetEps(0.3)) - .with_epoch_test_callback(EpochTestCallbackDQNSetEps(0.0)) .with_epoch_stop_callback(EpochStopCallbackRewardThreshold(195)) .build() ) @@ -281,7 +285,7 @@ The experiment builder takes three arguments: episodes (`watch_num_episodes=100`). We have disabled persistence, because we do not want to save training logs, the agent or its configuration for future use. -- the sampling configuration, which controls fundamental training parameters, +- the training configuration, which controls fundamental training parameters, such as the total number of epochs we run the experiment for (`num_epochs=10`) and the number of environment steps each epoch shall consist of (`step_per_epoch=10000`). @@ -291,14 +295,15 @@ The experiment builder takes three arguments: collected in each collection step and after each collection step, we perform a training step, applying a gradient-based update based on a sample of data (`batch_size=64`) taken from the buffer of data that has been - collected. For further details, see the documentation of `SamplingConfig`. + collected. For further details, see the documentation of the configuration class.
We then proceed to configure some of the parameters of the DQN algorithm itself and of the neural network model we want to use. -A DQN-specific detail is the use of callbacks to configure the algorithm's -epsilon parameter for exploration. We want to use random exploration during rollouts -(train callback), but we don't when evaluating the agent's performance in the test -environments (test callback). +A DQN-specific detail is the way in which we control the epsilon parameter for +exploration. +We want to use random exploration during rollouts for training (`eps_training`), +but we don't when evaluating the agent's performance in the test environments +(`eps_inference`). Find the script in [examples/discrete/discrete_dqn_hl.py](examples/discrete/discrete_dqn_hl.py). Here's a run (with the training time cut short): From 66e9e2432233250942c9b2ba4a912bc059779b57 Mon Sep 17 00:00:00 2001 From: Dominik Jain Date: Sun, 4 May 2025 13:40:09 +0200 Subject: [PATCH 089/230] v2: Improve description of 'estimation_step' Some additional docstring/type improvements --- tianshou/highlevel/params/policy_params.py | 40 +++++++++++++++------- tianshou/policy/base.py | 3 +- tianshou/policy/imitation/discrete_bcq.py | 16 +++++++-- tianshou/policy/imitation/discrete_cql.py | 8 ++++- tianshou/policy/modelfree/bdqn.py | 8 ++++- tianshou/policy/modelfree/c51.py | 8 ++++- tianshou/policy/modelfree/ddpg.py | 10 ++++-- tianshou/policy/modelfree/discrete_sac.py | 8 ++++- tianshou/policy/modelfree/dqn.py | 16 +++++++-- tianshou/policy/modelfree/fqf.py | 8 ++++- tianshou/policy/modelfree/iqn.py | 8 ++++- tianshou/policy/modelfree/qrdqn.py | 8 ++++- tianshou/policy/modelfree/redq.py | 8 ++++- tianshou/policy/modelfree/sac.py | 8 ++++- tianshou/policy/modelfree/trpo.py | 1 + 15 files changed, 129 insertions(+), 29 deletions(-) diff --git a/tianshou/highlevel/params/policy_params.py b/tianshou/highlevel/params/policy_params.py index 1c030665a..eebdfd63a 100644 --- 
a/tianshou/highlevel/params/policy_params.py +++ b/tianshou/highlevel/params/policy_params.py @@ -247,6 +247,20 @@ def _get_param_transformers(self) -> list[ParamTransformer]: return [ParamTransformerNoiseFactory("exploration_noise")] +@dataclass(kw_only=True) +class ParamsMixinEstimationStep: + estimation_step: int = 1 + """ + the number of future steps (> 0) to consider when computing temporal difference (TD) targets. + Controls the balance between TD learning and Monte Carlo methods: + Higher values reduce bias (by relying less on potentially inaccurate value estimates) + but increase variance (by incorporating more environmental stochasticity and reducing + the averaging effect). + A value of 1 corresponds to standard TD learning with immediate bootstrapping, while very + large values approach Monte Carlo-like estimation that uses complete episode returns. + """ + + @dataclass(kw_only=True) class PGParams(Params, ParamsMixinActionScaling, ParamsMixinSingleModel): discount_factor: float = 0.99 @@ -416,7 +430,7 @@ def _get_param_transformers(self) -> list[ParamTransformer]: @dataclass(kw_only=True) -class _SACParams(Params, ParamsMixinActorAndDualCritics): +class _SACParams(Params, ParamsMixinActorAndDualCritics, ParamsMixinEstimationStep): tau: float = 0.005 """controls the contribution of the entropy term in the overall optimization objective, i.e. the desired amount of randomness in the optimal policy. @@ -433,8 +447,6 @@ class _SACParams(Params, ParamsMixinActorAndDualCritics): use :class:`tianshou.highlevel.params.alpha.AutoAlphaFactoryDefault` for the standard auto-adjusted alpha. 
""" - estimation_step: int = 1 - """the number of steps to look ahead""" def _get_param_transformers(self) -> list[ParamTransformer]: transformers = super()._get_param_transformers() @@ -466,13 +478,11 @@ class DiscreteSACParams(_SACParams): @dataclass(kw_only=True) -class QLearningOffPolicyParams(Params, ParamsMixinSingleModel): +class QLearningOffPolicyParams(Params, ParamsMixinSingleModel, ParamsMixinEstimationStep): discount_factor: float = 0.99 """ discount factor (gamma) for future rewards; must be in [0, 1] """ - estimation_step: int = 1 - """the number of steps to look ahead""" target_update_freq: int = 0 """the target network update frequency (0 if no target network is to be used)""" reward_normalization: bool = False @@ -540,6 +550,7 @@ class DDPGParams( ParamsMixinActorAndCritic, ParamsMixinExplorationNoise, ParamsMixinActionScaling, + ParamsMixinEstimationStep, ): tau: float = 0.005 """ @@ -548,9 +559,15 @@ class DDPGParams( Smaller tau means slower tracking and more stable learning. """ gamma: float = 0.99 - """discount factor (gamma) for future rewards; must be in [0, 1]""" - estimation_step: int = 1 - """the number of steps to look ahead.""" + """ + the discount factor in [0, 1] for future rewards. + This determines how much future rewards are valued compared to immediate ones. + Lower values (closer to 0) make the agent focus on immediate rewards, creating "myopic" + behavior. Higher values (closer to 1) make the agent value long-term rewards more, + potentially improving performance in tasks where delayed rewards are important but + increasing training variance by incorporating more environmental stochasticity. 
+ Typically set between 0.9 and 0.99 for most reinforcement learning tasks + """ def _get_param_transformers(self) -> list[ParamTransformer]: transformers = super()._get_param_transformers() @@ -574,8 +591,6 @@ class REDQParams(DDPGParams): use :class:`tianshou.highlevel.params.alpha.AutoAlphaFactoryDefault` for the standard auto-adjusted alpha. """ - estimation_step: int = 1 - """the number of steps to look ahead""" actor_delay: int = 20 """the number of critic updates before an actor update""" deterministic_eval: bool = True @@ -597,6 +612,7 @@ class TD3Params( ParamsMixinActorAndDualCritics, ParamsMixinExplorationNoise, ParamsMixinActionScaling, + ParamsMixinEstimationStep, ): tau: float = 0.005 """ @@ -612,8 +628,6 @@ class TD3Params( """determines the clipping range of the noise used in updating the policy network as [-noise_clip, noise_clip]""" update_actor_freq: int = 2 """the update frequency of actor network""" - estimation_step: int = 1 - """the number of steps to look ahead.""" def _get_param_transformers(self) -> list[ParamTransformer]: transformers = super()._get_param_transformers() diff --git a/tianshou/policy/base.py b/tianshou/policy/base.py index edebafac7..6faa5facf 100644 --- a/tianshou/policy/base.py +++ b/tianshou/policy/base.py @@ -734,7 +734,8 @@ def compute_nstep_return( n_step: int = 1, rew_norm: bool = False, ) -> BatchWithReturnsProtocol: - r"""Compute n-step return for Q-learning targets. + r""" + Computes the n-step return for Q-learning targets, adds it to the batch and returns the resulting batch. .. math:: G_t = \sum_{i = t}^{t + n - 1} \gamma^{i - t}(1 - d_i)r_i + diff --git a/tianshou/policy/imitation/discrete_bcq.py b/tianshou/policy/imitation/discrete_bcq.py index 6f38f067f..ab34ebc34 100644 --- a/tianshou/policy/imitation/discrete_bcq.py +++ b/tianshou/policy/imitation/discrete_bcq.py @@ -122,12 +122,24 @@ def __init__( :param policy: the policy :param optim: a torch.optim for optimizing the model. 
:param discount_factor: in [0, 1]. - :param estimation_step: the number of steps to look ahead + :param estimation_step: the number of future steps (> 0) to consider when computing temporal + difference (TD) targets. Controls the balance between TD learning and Monte Carlo methods: + higher values reduce bias (by relying less on potentially inaccurate value estimates) + but increase variance (by incorporating more environmental stochasticity and reducing + the averaging effect). A value of 1 corresponds to standard TD learning with immediate + bootstrapping, while very large values approach Monte Carlo-like estimation that uses + complete episode returns. :param target_update_freq: the target network update frequency. :param eval_eps: the epsilon-greedy noise added in evaluation. :param imitation_logits_penalty: regularization weight for imitation logits. - :param estimation_step: the number of steps to look ahead. + :param estimation_step: the number of future steps (> 0) to consider when computing temporal + difference (TD) targets. Controls the balance between TD learning and Monte Carlo methods: + higher values reduce bias (by relying less on potentially inaccurate value estimates) + but increase variance (by incorporating more environmental stochasticity and reducing + the averaging effect). A value of 1 corresponds to standard TD learning with immediate + bootstrapping, while very large values approach Monte Carlo-like estimation that uses + complete episode returns. :param target_update_freq: the target network update frequency (0 if you do not use the target network). :param reward_normalization: normalize the **returns** to Normal(0, 1). diff --git a/tianshou/policy/imitation/discrete_cql.py b/tianshou/policy/imitation/discrete_cql.py index 695ddced1..5d6f70a7b 100644 --- a/tianshou/policy/imitation/discrete_cql.py +++ b/tianshou/policy/imitation/discrete_cql.py @@ -48,7 +48,13 @@ def __init__( :param discount_factor: in [0, 1]. 
:param num_quantiles: the number of quantile midpoints in the inverse cumulative distribution function of the value. - :param estimation_step: the number of steps to look ahead. + :param estimation_step: the number of future steps (> 0) to consider when computing temporal + difference (TD) targets. Controls the balance between TD learning and Monte Carlo methods: + higher values reduce bias (by relying less on potentially inaccurate value estimates) + but increase variance (by incorporating more environmental stochasticity and reducing + the averaging effect). A value of 1 corresponds to standard TD learning with immediate + bootstrapping, while very large values approach Monte Carlo-like estimation that uses + complete episode returns. :param target_update_freq: the target network update frequency (0 if you do not use the target network). :param reward_normalization: normalize the **returns** to Normal(0, 1). diff --git a/tianshou/policy/modelfree/bdqn.py b/tianshou/policy/modelfree/bdqn.py index b65a21d14..eb578ba52 100644 --- a/tianshou/policy/modelfree/bdqn.py +++ b/tianshou/policy/modelfree/bdqn.py @@ -125,7 +125,13 @@ def __init__( :param policy: policy :param optim: the optimizer for the policy :param discount_factor: in [0, 1]. - :param estimation_step: the number of steps to look ahead. + :param estimation_step: the number of future steps (> 0) to consider when computing temporal + difference (TD) targets. Controls the balance between TD learning and Monte Carlo methods: + higher values reduce bias (by relying less on potentially inaccurate value estimates) + but increase variance (by incorporating more environmental stochasticity and reducing + the averaging effect). A value of 1 corresponds to standard TD learning with immediate + bootstrapping, while very large values approach Monte Carlo-like estimation that uses + complete episode returns. :param target_update_freq: the target network update frequency (0 if you do not use the target network). 
:param reward_normalization: normalize the **returns** to Normal(0, 1). diff --git a/tianshou/policy/modelfree/c51.py b/tianshou/policy/modelfree/c51.py index badddf57a..d08b2dd0d 100644 --- a/tianshou/policy/modelfree/c51.py +++ b/tianshou/policy/modelfree/c51.py @@ -95,7 +95,13 @@ def __init__( :param policy: a policy following the rules (s -> action_values_BA) :param optim: a torch.optim for optimizing the policy. :param discount_factor: in [0, 1]. - :param estimation_step: the number of steps to look ahead. + :param estimation_step: the number of future steps (> 0) to consider when computing temporal + difference (TD) targets. Controls the balance between TD learning and Monte Carlo methods: + higher values reduce bias (by relying less on potentially inaccurate value estimates) + but increase variance (by incorporating more environmental stochasticity and reducing + the averaging effect). A value of 1 corresponds to standard TD learning with immediate + bootstrapping, while very large values approach Monte Carlo-like estimation that uses + complete episode returns. :param target_update_freq: the target network update frequency (0 if you do not use the target network). :param reward_normalization: normalize the **returns** to Normal(0, 1). diff --git a/tianshou/policy/modelfree/ddpg.py b/tianshou/policy/modelfree/ddpg.py index 7b74152ac..e1fe1f53a 100644 --- a/tianshou/policy/modelfree/ddpg.py +++ b/tianshou/policy/modelfree/ddpg.py @@ -183,7 +183,7 @@ class ActorCriticOffPolicyAlgorithm( def __init__( self, *, - policy: Any, + policy: TPolicy, policy_optim: OptimizerFactory, critic: torch.nn.Module, critic_optim: OptimizerFactory, @@ -318,7 +318,13 @@ def __init__( :param critic_optim: The optimizer for critic network. :param tau: Param for soft update of the target network. :param gamma: Discount factor, in [0, 1]. - :param estimation_step: The number of steps to look ahead. 
+ :param estimation_step: the number of future steps (> 0) to consider when computing temporal + difference (TD) targets. Controls the balance between TD learning and Monte Carlo methods: + higher values reduce bias (by relying less on potentially inaccurate value estimates) + but increase variance (by incorporating more environmental stochasticity and reducing + the averaging effect). A value of 1 corresponds to standard TD learning with immediate + bootstrapping, while very large values approach Monte Carlo-like estimation that uses + complete episode returns. :param lr_scheduler: if not None, will be called in `policy.update()`. """ super().__init__( diff --git a/tianshou/policy/modelfree/discrete_sac.py b/tianshou/policy/modelfree/discrete_sac.py index d3a0d39d8..65b132801 100644 --- a/tianshou/policy/modelfree/discrete_sac.py +++ b/tianshou/policy/modelfree/discrete_sac.py @@ -105,7 +105,13 @@ def __init__( :param gamma: discount factor, in [0, 1]. :param alpha: the entropy regularization coefficient alpha or an object which can be used to automatically tune it (e.g. an instance of `AutoAlpha`). - :param estimation_step: the number of steps to look ahead for calculating + :param estimation_step: the number of future steps (> 0) to consider when computing temporal + difference (TD) targets. Controls the balance between TD learning and Monte Carlo methods: + higher values reduce bias (by relying less on potentially inaccurate value estimates) + but increase variance (by incorporating more environmental stochasticity and reducing + the averaging effect). A value of 1 corresponds to standard TD learning with immediate + bootstrapping, while very large values approach Monte Carlo-like estimation that uses + complete episode returns. 
:param lr_scheduler: a learning rate scheduler that adjusts the learning rate in optimizer in each policy.update() """ diff --git a/tianshou/policy/modelfree/dqn.py b/tianshou/policy/modelfree/dqn.py index 3bcacb7a6..ac76f154c 100644 --- a/tianshou/policy/modelfree/dqn.py +++ b/tianshou/policy/modelfree/dqn.py @@ -205,7 +205,13 @@ def __init__( :param policy: the policy :param optim: the optimizer for the policy :param discount_factor: in [0, 1]. - :param estimation_step: the number of steps to look ahead. + :param estimation_step: the number of future steps (> 0) to consider when computing temporal + difference (TD) targets. Controls the balance between TD learning and Monte Carlo methods: + higher values reduce bias (by relying less on potentially inaccurate value estimates) + but increase variance (by incorporating more environmental stochasticity and reducing + the averaging effect). A value of 1 corresponds to standard TD learning with immediate + bootstrapping, while very large values approach Monte Carlo-like estimation that uses + complete episode returns. :param target_update_freq: the frequency with which to update the weights of the target network; 0 if a target network shall not be used. :param reward_normalization: normalize the **returns** to Normal(0, 1). @@ -303,7 +309,13 @@ def __init__( :param policy: the policy :param optim: the optimizer for the policy :param discount_factor: in [0, 1]. - :param estimation_step: the number of steps to look ahead. + :param estimation_step: the number of future steps (> 0) to consider when computing temporal + difference (TD) targets. Controls the balance between TD learning and Monte Carlo methods: + higher values reduce bias (by relying less on potentially inaccurate value estimates) + but increase variance (by incorporating more environmental stochasticity and reducing + the averaging effect). 
A value of 1 corresponds to standard TD learning with immediate + bootstrapping, while very large values approach Monte Carlo-like estimation that uses + complete episode returns. :param target_update_freq: the frequency with which to update the weights of the target network; 0 if a target network shall not be used. :param reward_normalization: normalize the **returns** to Normal(0, 1). diff --git a/tianshou/policy/modelfree/fqf.py b/tianshou/policy/modelfree/fqf.py index 92bc33e6a..9ac61922f 100644 --- a/tianshou/policy/modelfree/fqf.py +++ b/tianshou/policy/modelfree/fqf.py @@ -135,7 +135,13 @@ def __init__( :param discount_factor: in [0, 1]. :param num_fractions: the number of fractions to use. :param ent_coef: the coefficient for entropy loss. - :param estimation_step: the number of steps to look ahead. + :param estimation_step: the number of future steps (> 0) to consider when computing temporal + difference (TD) targets. Controls the balance between TD learning and Monte Carlo methods: + higher values reduce bias (by relying less on potentially inaccurate value estimates) + but increase variance (by incorporating more environmental stochasticity and reducing + the averaging effect). A value of 1 corresponds to standard TD learning with immediate + bootstrapping, while very large values approach Monte Carlo-like estimation that uses + complete episode returns. :param target_update_freq: the target network update frequency (0 if you do not use the target network). :param reward_normalization: normalize the **returns** to Normal(0, 1). diff --git a/tianshou/policy/modelfree/iqn.py b/tianshou/policy/modelfree/iqn.py index a112491a9..3783ae8c4 100644 --- a/tianshou/policy/modelfree/iqn.py +++ b/tianshou/policy/modelfree/iqn.py @@ -131,7 +131,13 @@ def __init__( :param discount_factor: in [0, 1]. :param num_quantiles: the number of quantile midpoints in the inverse cumulative distribution function of the value. 
- :param estimation_step: the number of steps to look ahead. + :param estimation_step: the number of future steps (> 0) to consider when computing temporal + difference (TD) targets. Controls the balance between TD learning and Monte Carlo methods: + higher values reduce bias (by relying less on potentially inaccurate value estimates) + but increase variance (by incorporating more environmental stochasticity and reducing + the averaging effect). A value of 1 corresponds to standard TD learning with immediate + bootstrapping, while very large values approach Monte Carlo-like estimation that uses + complete episode returns. :param target_update_freq: the target network update frequency (0 if you do not use the target network). :param reward_normalization: normalize the **returns** to Normal(0, 1). diff --git a/tianshou/policy/modelfree/qrdqn.py b/tianshou/policy/modelfree/qrdqn.py index 460f6b812..17326deed 100644 --- a/tianshou/policy/modelfree/qrdqn.py +++ b/tianshou/policy/modelfree/qrdqn.py @@ -55,7 +55,13 @@ def __init__( :param discount_factor: in [0, 1]. :param num_quantiles: the number of quantile midpoints in the inverse cumulative distribution function of the value. - :param estimation_step: the number of steps to look ahead. + :param estimation_step: the number of future steps (> 0) to consider when computing temporal + difference (TD) targets. Controls the balance between TD learning and Monte Carlo methods: + higher values reduce bias (by relying less on potentially inaccurate value estimates) + but increase variance (by incorporating more environmental stochasticity and reducing + the averaging effect). A value of 1 corresponds to standard TD learning with immediate + bootstrapping, while very large values approach Monte Carlo-like estimation that uses + complete episode returns. :param target_update_freq: the target network update frequency (0 if you do not use the target network). :param reward_normalization: normalize the **returns** to Normal(0, 1). 
diff --git a/tianshou/policy/modelfree/redq.py b/tianshou/policy/modelfree/redq.py index d0faa5603..38925280a 100644 --- a/tianshou/policy/modelfree/redq.py +++ b/tianshou/policy/modelfree/redq.py @@ -133,7 +133,13 @@ def __init__( :param gamma: Discount factor, in [0, 1]. :param alpha: the entropy regularization coefficient alpha or an object which can be used to automatically tune it (e.g. an instance of `AutoAlpha`). - :param estimation_step: The number of steps to look ahead. + :param estimation_step: the number of future steps (> 0) to consider when computing temporal + difference (TD) targets. Controls the balance between TD learning and Monte Carlo methods: + higher values reduce bias (by relying less on potentially inaccurate value estimates) + but increase variance (by incorporating more environmental stochasticity and reducing + the averaging effect). A value of 1 corresponds to standard TD learning with immediate + bootstrapping, while very large values approach Monte Carlo-like estimation that uses + complete episode returns. :param actor_delay: Number of critic updates before an actor update. """ if target_mode not in ("min", "mean"): diff --git a/tianshou/policy/modelfree/sac.py b/tianshou/policy/modelfree/sac.py index 208a065f1..af5b45528 100644 --- a/tianshou/policy/modelfree/sac.py +++ b/tianshou/policy/modelfree/sac.py @@ -230,7 +230,13 @@ def __init__( :param gamma: discount factor, in [0, 1]. :param alpha: the entropy regularization coefficient alpha or an object which can be used to automatically tune it (e.g. an instance of `AutoAlpha`). - :param estimation_step: The number of steps to look ahead. + :param estimation_step: the number of future steps (> 0) to consider when computing temporal + difference (TD) targets. 
Controls the balance between TD learning and Monte Carlo methods: + higher values reduce bias (by relying less on potentially inaccurate value estimates) + but increase variance (by incorporating more environmental stochasticity and reducing + the averaging effect). A value of 1 corresponds to standard TD learning with immediate + bootstrapping, while very large values approach Monte Carlo-like estimation that uses + complete episode returns. """ super().__init__( policy=policy, diff --git a/tianshou/policy/modelfree/trpo.py b/tianshou/policy/modelfree/trpo.py index 1231a83e0..95021d158 100644 --- a/tianshou/policy/modelfree/trpo.py +++ b/tianshou/policy/modelfree/trpo.py @@ -45,6 +45,7 @@ def __init__( reward_normalization: bool = False, ) -> None: """ + :param policy: the policy :param critic: the critic network. (s -> V(s)) :param optim: the optimizer factory for the critic network. :param max_kl: max kl-divergence used to constrain each actor network update. From 1c08dd8b3863f307fa39d79d6ea58a7cc02ced2e Mon Sep 17 00:00:00 2001 From: Dominik Jain Date: Sun, 4 May 2025 13:53:54 +0200 Subject: [PATCH 090/230] v2: Consistently use `gamma` instead of `discount_factor` and improve parameter description Most algorithms/components already used the attribute `gamma` internally but used `discount_factor` as a parameter name --- CHANGELOG.md | 2 + examples/atari/atari_c51.py | 2 +- examples/atari/atari_dqn.py | 2 +- examples/atari/atari_fqf.py | 2 +- examples/atari/atari_iqn.py | 2 +- examples/atari/atari_ppo.py | 2 +- examples/atari/atari_qrdqn.py | 2 +- examples/atari/atari_rainbow.py | 2 +- examples/box2d/acrobot_dualdqn.py | 2 +- examples/box2d/bipedal_bdq.py | 2 +- examples/box2d/lunarlander_dqn.py | 2 +- examples/discrete/discrete_dqn.py | 2 +- examples/inverse/irl_gail.py | 2 +- examples/mujoco/mujoco_a2c.py | 2 +- examples/mujoco/mujoco_npg.py | 2 +- examples/mujoco/mujoco_ppo.py | 2 +- examples/mujoco/mujoco_reinforce.py | 2 +- examples/mujoco/mujoco_trpo.py | 2 
+- examples/offline/atari_bcq.py | 2 +- examples/offline/atari_cql.py | 2 +- examples/offline/atari_crr.py | 2 +- examples/vizdoom/vizdoom_c51.py | 2 +- examples/vizdoom/vizdoom_ppo.py | 2 +- test/continuous/test_npg.py | 2 +- test/continuous/test_ppo.py | 2 +- test/continuous/test_trpo.py | 2 +- test/discrete/test_a2c_with_il.py | 2 +- test/discrete/test_bdqn.py | 2 +- test/discrete/test_c51.py | 2 +- test/discrete/test_dqn.py | 2 +- test/discrete/test_drqn.py | 2 +- test/discrete/test_fqf.py | 2 +- test/discrete/test_iqn.py | 2 +- test/discrete/test_pg.py | 2 +- test/discrete/test_ppo2.py | 2 +- test/discrete/test_qrdqn.py | 2 +- test/discrete/test_rainbow.py | 2 +- test/modelbased/test_dqn_icm.py | 2 +- test/modelbased/test_ppo_icm.py | 2 +- test/offline/gather_cartpole_data.py | 2 +- test/offline/test_discrete_bcq.py | 2 +- test/offline/test_discrete_cql.py | 2 +- test/offline/test_discrete_crr.py | 2 +- test/offline/test_gail.py | 2 +- test/pettingzoo/pistonball.py | 2 +- test/pettingzoo/pistonball_continuous.py | 2 +- test/pettingzoo/tic_tac_toe.py | 2 +- tianshou/highlevel/params/policy_params.py | 44 +++++++++++----------- tianshou/policy/base.py | 18 +++++++-- tianshou/policy/imitation/bcq.py | 8 +++- tianshou/policy/imitation/cql.py | 8 +++- tianshou/policy/imitation/discrete_bcq.py | 16 +++++--- tianshou/policy/imitation/discrete_cql.py | 12 ++++-- tianshou/policy/imitation/discrete_crr.py | 12 ++++-- tianshou/policy/imitation/gail.py | 12 ++++-- tianshou/policy/imitation/td3_bc.py | 8 +++- tianshou/policy/modelbased/psrl.py | 28 ++++++++++---- tianshou/policy/modelfree/a2c.py | 24 +++++++++--- tianshou/policy/modelfree/bdqn.py | 12 ++++-- tianshou/policy/modelfree/c51.py | 12 ++++-- tianshou/policy/modelfree/ddpg.py | 16 +++++++- tianshou/policy/modelfree/discrete_sac.py | 8 +++- tianshou/policy/modelfree/dqn.py | 28 +++++++++----- tianshou/policy/modelfree/fqf.py | 12 ++++-- tianshou/policy/modelfree/iqn.py | 12 ++++-- tianshou/policy/modelfree/npg.py 
| 12 ++++-- tianshou/policy/modelfree/pg.py | 26 +++++++++---- tianshou/policy/modelfree/ppo.py | 12 ++++-- tianshou/policy/modelfree/qrdqn.py | 12 ++++-- tianshou/policy/modelfree/redq.py | 8 +++- tianshou/policy/modelfree/sac.py | 8 +++- tianshou/policy/modelfree/td3.py | 16 +++++++- tianshou/policy/modelfree/trpo.py | 12 ++++-- 73 files changed, 339 insertions(+), 151 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index b28d843e8..4a17c00da 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -83,6 +83,8 @@ `LRSchedulerFactory`). The parameter `lr_scheduler` has thus been removed from all algorithm constructors. * The flag `updating` has been removed (no internal usage, general usefulness questionable). + * Parameter name changes: + * `discount_factor` -> `gamma` (was already used internally almost everywhere) * Internal design improvements: * Introduced an abstraction for the alpha parameter (coefficient of the entropy term) in `SAC`, `DiscreteSAC` and other algorithms. diff --git a/examples/atari/atari_c51.py b/examples/atari/atari_c51.py index 1f6301af2..29cf69dc8 100644 --- a/examples/atari/atari_c51.py +++ b/examples/atari/atari_c51.py @@ -104,7 +104,7 @@ def main(args: argparse.Namespace = get_args()) -> None: algorithm: C51 = C51( policy=policy, optim=optim, - discount_factor=args.gamma, + gamma=args.gamma, estimation_step=args.n_step, target_update_freq=args.target_update_freq, ).to(args.device) diff --git a/examples/atari/atari_dqn.py b/examples/atari/atari_dqn.py index 251ad4af5..96bcf3234 100644 --- a/examples/atari/atari_dqn.py +++ b/examples/atari/atari_dqn.py @@ -121,7 +121,7 @@ def main(args: argparse.Namespace = get_args()) -> None: algorithm = DQN( policy=policy, optim=optim, - discount_factor=args.gamma, + gamma=args.gamma, estimation_step=args.n_step, target_update_freq=args.target_update_freq, ) diff --git a/examples/atari/atari_fqf.py b/examples/atari/atari_fqf.py index 578ae89ee..8731cb47d 100644 --- a/examples/atari/atari_fqf.py +++ 
b/examples/atari/atari_fqf.py @@ -114,7 +114,7 @@ def main(args: argparse.Namespace = get_args()) -> None: policy=policy, optim=optim, fraction_optim=fraction_optim, - discount_factor=args.gamma, + gamma=args.gamma, num_fractions=args.num_fractions, ent_coef=args.ent_coef, estimation_step=args.n_step, diff --git a/examples/atari/atari_iqn.py b/examples/atari/atari_iqn.py index 888b61342..e209be5c1 100644 --- a/examples/atari/atari_iqn.py +++ b/examples/atari/atari_iqn.py @@ -114,7 +114,7 @@ def main(args: argparse.Namespace = get_args()) -> None: algorithm: IQN = IQN( policy=policy, optim=optim, - discount_factor=args.gamma, + gamma=args.gamma, estimation_step=args.n_step, target_update_freq=args.target_update_freq, ).to(args.device) diff --git a/examples/atari/atari_ppo.py b/examples/atari/atari_ppo.py index 3c3a7b258..3b48d3b83 100644 --- a/examples/atari/atari_ppo.py +++ b/examples/atari/atari_ppo.py @@ -148,7 +148,7 @@ def main(args: argparse.Namespace = get_args()) -> None: policy=policy, critic=critic, optim=optim, - discount_factor=args.gamma, + gamma=args.gamma, gae_lambda=args.gae_lambda, max_grad_norm=args.max_grad_norm, vf_coef=args.vf_coef, diff --git a/examples/atari/atari_qrdqn.py b/examples/atari/atari_qrdqn.py index d4364511f..fbac8604d 100644 --- a/examples/atari/atari_qrdqn.py +++ b/examples/atari/atari_qrdqn.py @@ -107,7 +107,7 @@ def main(args: argparse.Namespace = get_args()) -> None: algorithm: QRDQN = QRDQN( policy=policy, optim=optim, - discount_factor=args.gamma, + gamma=args.gamma, num_quantiles=args.num_quantiles, estimation_step=args.n_step, target_update_freq=args.target_update_freq, diff --git a/examples/atari/atari_rainbow.py b/examples/atari/atari_rainbow.py index 64757c922..9717c6769 100644 --- a/examples/atari/atari_rainbow.py +++ b/examples/atari/atari_rainbow.py @@ -128,7 +128,7 @@ def test_rainbow(args: argparse.Namespace = get_args()) -> None: algorithm: C51 = RainbowDQN( policy=policy, optim=optim, - 
discount_factor=args.gamma, + gamma=args.gamma, estimation_step=args.n_step, target_update_freq=args.target_update_freq, ).to(args.device) diff --git a/examples/box2d/acrobot_dualdqn.py b/examples/box2d/acrobot_dualdqn.py index 6ca40ad72..9fa3d0242 100644 --- a/examples/box2d/acrobot_dualdqn.py +++ b/examples/box2d/acrobot_dualdqn.py @@ -85,7 +85,7 @@ def test_dqn(args: argparse.Namespace = get_args()) -> None: algorithm: DQN = DQN( policy=policy, optim=optim, - discount_factor=args.gamma, + gamma=args.gamma, estimation_step=args.n_step, target_update_freq=args.target_update_freq, ).to(args.device) diff --git a/examples/box2d/bipedal_bdq.py b/examples/box2d/bipedal_bdq.py index e87d0f9d7..d53999193 100644 --- a/examples/box2d/bipedal_bdq.py +++ b/examples/box2d/bipedal_bdq.py @@ -111,7 +111,7 @@ def run_bdq(args: argparse.Namespace = get_args()) -> None: algorithm: BDQN = BDQN( policy=policy, optim=optim, - discount_factor=args.gamma, + gamma=args.gamma, target_update_freq=args.target_update_freq, ) # collector diff --git a/examples/box2d/lunarlander_dqn.py b/examples/box2d/lunarlander_dqn.py index 7c8912e96..f5d8c217b 100644 --- a/examples/box2d/lunarlander_dqn.py +++ b/examples/box2d/lunarlander_dqn.py @@ -87,7 +87,7 @@ def test_dqn(args: argparse.Namespace = get_args()) -> None: algorithm: DQN = DQN( policy=policy, optim=optim, - discount_factor=args.gamma, + gamma=args.gamma, estimation_step=args.n_step, target_update_freq=args.target_update_freq, ) diff --git a/examples/discrete/discrete_dqn.py b/examples/discrete/discrete_dqn.py index 3e95af9be..2d23c1242 100644 --- a/examples/discrete/discrete_dqn.py +++ b/examples/discrete/discrete_dqn.py @@ -43,7 +43,7 @@ def main() -> None: algorithm: ts.policy.DQN = ts.policy.DQN( policy=policy, optim=optim, - discount_factor=gamma, + gamma=gamma, estimation_step=n_step, target_update_freq=target_freq, ) diff --git a/examples/inverse/irl_gail.py b/examples/inverse/irl_gail.py index d56d770ee..13af8a387 100644 --- 
a/examples/inverse/irl_gail.py +++ b/examples/inverse/irl_gail.py @@ -219,7 +219,7 @@ def dist(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution: disc_net=disc_net, disc_optim=disc_optim, disc_update_num=args.disc_update_num, - discount_factor=args.gamma, + gamma=args.gamma, gae_lambda=args.gae_lambda, max_grad_norm=args.max_grad_norm, vf_coef=args.vf_coef, diff --git a/examples/mujoco/mujoco_a2c.py b/examples/mujoco/mujoco_a2c.py index f09469d7d..6f0fbc212 100755 --- a/examples/mujoco/mujoco_a2c.py +++ b/examples/mujoco/mujoco_a2c.py @@ -151,7 +151,7 @@ def dist(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution: policy=policy, critic=critic, optim=optim, - discount_factor=args.gamma, + gamma=args.gamma, gae_lambda=args.gae_lambda, max_grad_norm=args.max_grad_norm, vf_coef=args.vf_coef, diff --git a/examples/mujoco/mujoco_npg.py b/examples/mujoco/mujoco_npg.py index 702df3c78..50152e437 100755 --- a/examples/mujoco/mujoco_npg.py +++ b/examples/mujoco/mujoco_npg.py @@ -149,7 +149,7 @@ def dist(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution: policy=policy, critic=critic, optim=optim, - discount_factor=args.gamma, + gamma=args.gamma, gae_lambda=args.gae_lambda, reward_normalization=args.rew_norm, advantage_normalization=args.norm_adv, diff --git a/examples/mujoco/mujoco_ppo.py b/examples/mujoco/mujoco_ppo.py index c41821ef1..c2f775934 100755 --- a/examples/mujoco/mujoco_ppo.py +++ b/examples/mujoco/mujoco_ppo.py @@ -152,7 +152,7 @@ def dist(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution: policy=policy, critic=critic, optim=optim, - discount_factor=args.gamma, + gamma=args.gamma, gae_lambda=args.gae_lambda, max_grad_norm=args.max_grad_norm, vf_coef=args.vf_coef, diff --git a/examples/mujoco/mujoco_reinforce.py b/examples/mujoco/mujoco_reinforce.py index 2932eeef8..d8739a30b 100755 --- a/examples/mujoco/mujoco_reinforce.py +++ b/examples/mujoco/mujoco_reinforce.py @@ -134,7 +134,7 @@ def dist(loc_scale: 
tuple[torch.Tensor, torch.Tensor]) -> Distribution: algorithm: Reinforce = Reinforce( policy=policy, optim=optim, - discount_factor=args.gamma, + gamma=args.gamma, reward_normalization=args.rew_norm, ) diff --git a/examples/mujoco/mujoco_trpo.py b/examples/mujoco/mujoco_trpo.py index 430fbd56a..15006bff8 100755 --- a/examples/mujoco/mujoco_trpo.py +++ b/examples/mujoco/mujoco_trpo.py @@ -152,7 +152,7 @@ def dist(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution: policy=policy, critic=critic, optim=optim, - discount_factor=args.gamma, + gamma=args.gamma, gae_lambda=args.gae_lambda, reward_normalization=args.rew_norm, advantage_normalization=args.norm_adv, diff --git a/examples/offline/atari_bcq.py b/examples/offline/atari_bcq.py index b962d25aa..d4cb49fed 100644 --- a/examples/offline/atari_bcq.py +++ b/examples/offline/atari_bcq.py @@ -127,7 +127,7 @@ def main(args: argparse.Namespace = get_args()) -> None: algorithm: DiscreteBCQ = DiscreteBCQ( policy=policy, optim=optim, - discount_factor=args.gamma, + gamma=args.gamma, estimation_step=args.n_step, target_update_freq=args.target_update_freq, eval_eps=args.eps_test, diff --git a/examples/offline/atari_cql.py b/examples/offline/atari_cql.py index dd64f82e3..937e3d46e 100644 --- a/examples/offline/atari_cql.py +++ b/examples/offline/atari_cql.py @@ -114,7 +114,7 @@ def main(args: argparse.Namespace = get_args()) -> None: algorithm: DiscreteCQL = DiscreteCQL( policy=policy, optim=optim, - discount_factor=args.gamma, + gamma=args.gamma, num_quantiles=args.num_quantiles, estimation_step=args.n_step, target_update_freq=args.target_update_freq, diff --git a/examples/offline/atari_crr.py b/examples/offline/atari_crr.py index 28a188514..9447e90dc 100644 --- a/examples/offline/atari_crr.py +++ b/examples/offline/atari_crr.py @@ -127,7 +127,7 @@ def main(args: argparse.Namespace = get_args()) -> None: policy=policy, critic=critic, optim=optim, - discount_factor=args.gamma, + gamma=args.gamma, 
policy_improvement_mode=args.policy_improvement_mode, ratio_upper_bound=args.ratio_upper_bound, beta=args.beta, diff --git a/examples/vizdoom/vizdoom_c51.py b/examples/vizdoom/vizdoom_c51.py index b5654395c..4a867c5e8 100644 --- a/examples/vizdoom/vizdoom_c51.py +++ b/examples/vizdoom/vizdoom_c51.py @@ -110,7 +110,7 @@ def test_c51(args: argparse.Namespace = get_args()) -> None: algorithm: C51 = C51( policy=policy, optim=optim, - discount_factor=args.gamma, + gamma=args.gamma, estimation_step=args.n_step, target_update_freq=args.target_update_freq, ).to(args.device) diff --git a/examples/vizdoom/vizdoom_ppo.py b/examples/vizdoom/vizdoom_ppo.py index b55202861..bd91b9f96 100644 --- a/examples/vizdoom/vizdoom_ppo.py +++ b/examples/vizdoom/vizdoom_ppo.py @@ -159,7 +159,7 @@ def dist(logits: torch.Tensor) -> Categorical: policy=policy, critic=critic, optim=optim, - discount_factor=args.gamma, + gamma=args.gamma, gae_lambda=args.gae_lambda, max_grad_norm=args.max_grad_norm, vf_coef=args.vf_coef, diff --git a/test/continuous/test_npg.py b/test/continuous/test_npg.py index 0c9222b6c..ca2df5a22 100644 --- a/test/continuous/test_npg.py +++ b/test/continuous/test_npg.py @@ -116,7 +116,7 @@ def dist(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution: policy=policy, critic=critic, optim=AdamOptimizerFactory(lr=args.lr), - discount_factor=args.gamma, + gamma=args.gamma, reward_normalization=args.rew_norm, advantage_normalization=args.norm_adv, gae_lambda=args.gae_lambda, diff --git a/test/continuous/test_ppo.py b/test/continuous/test_ppo.py index 6d94f13a2..54ffd523a 100644 --- a/test/continuous/test_ppo.py +++ b/test/continuous/test_ppo.py @@ -113,7 +113,7 @@ def dist(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution: policy=policy, critic=critic, optim=optim, - discount_factor=args.gamma, + gamma=args.gamma, max_grad_norm=args.max_grad_norm, eps_clip=args.eps_clip, vf_coef=args.vf_coef, diff --git a/test/continuous/test_trpo.py 
b/test/continuous/test_trpo.py index 518a42ddd..72e713d42 100644 --- a/test/continuous/test_trpo.py +++ b/test/continuous/test_trpo.py @@ -115,7 +115,7 @@ def dist(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution: policy=policy, critic=critic, optim=optim, - discount_factor=args.gamma, + gamma=args.gamma, reward_normalization=args.rew_norm, advantage_normalization=args.norm_adv, gae_lambda=args.gae_lambda, diff --git a/test/discrete/test_a2c_with_il.py b/test/discrete/test_a2c_with_il.py index 194a4b4f1..a03e76497 100644 --- a/test/discrete/test_a2c_with_il.py +++ b/test/discrete/test_a2c_with_il.py @@ -107,7 +107,7 @@ def test_a2c_with_il(args: argparse.Namespace = get_args()) -> None: policy=policy, critic=critic, optim=optim, - discount_factor=args.gamma, + gamma=args.gamma, gae_lambda=args.gae_lambda, vf_coef=args.vf_coef, ent_coef=args.ent_coef, diff --git a/test/discrete/test_bdqn.py b/test/discrete/test_bdqn.py index 5fa986846..01a8c84c4 100644 --- a/test/discrete/test_bdqn.py +++ b/test/discrete/test_bdqn.py @@ -108,7 +108,7 @@ def test_bdq(args: argparse.Namespace = get_args()) -> None: algorithm: BDQN = BDQN( policy=policy, optim=optim, - discount_factor=args.gamma, + gamma=args.gamma, target_update_freq=args.target_update_freq, ) # collector diff --git a/test/discrete/test_c51.py b/test/discrete/test_c51.py index 64795b800..ae8921b09 100644 --- a/test/discrete/test_c51.py +++ b/test/discrete/test_c51.py @@ -106,7 +106,7 @@ def test_c51(args: argparse.Namespace = get_args()) -> None: algorithm: C51 = C51( policy=policy, optim=optim, - discount_factor=args.gamma, + gamma=args.gamma, estimation_step=args.n_step, target_update_freq=args.target_update_freq, ).to(args.device) diff --git a/test/discrete/test_dqn.py b/test/discrete/test_dqn.py index add75c11b..f9d27bbfa 100644 --- a/test/discrete/test_dqn.py +++ b/test/discrete/test_dqn.py @@ -98,7 +98,7 @@ def test_dqn(args: argparse.Namespace = get_args(), enable_assertions: bool = Tr algorithm: 
DQN = DQN( policy=policy, optim=optim, - discount_factor=args.gamma, + gamma=args.gamma, estimation_step=args.n_step, target_update_freq=args.target_update_freq, ) diff --git a/test/discrete/test_drqn.py b/test/discrete/test_drqn.py index ce7393a3b..a9df8f1ec 100644 --- a/test/discrete/test_drqn.py +++ b/test/discrete/test_drqn.py @@ -87,7 +87,7 @@ def test_drqn(args: argparse.Namespace = get_args()) -> None: algorithm: DQN = DQN( policy=policy, optim=optim, - discount_factor=args.gamma, + gamma=args.gamma, estimation_step=args.n_step, target_update_freq=args.target_update_freq, ) diff --git a/test/discrete/test_fqf.py b/test/discrete/test_fqf.py index 67288b7db..595d25b8e 100644 --- a/test/discrete/test_fqf.py +++ b/test/discrete/test_fqf.py @@ -110,7 +110,7 @@ def test_fqf(args: argparse.Namespace = get_args()) -> None: policy=policy, optim=optim, fraction_optim=fraction_optim, - discount_factor=args.gamma, + gamma=args.gamma, num_fractions=args.num_fractions, ent_coef=args.ent_coef, estimation_step=args.n_step, diff --git a/test/discrete/test_iqn.py b/test/discrete/test_iqn.py index cdb767da1..ced065c2d 100644 --- a/test/discrete/test_iqn.py +++ b/test/discrete/test_iqn.py @@ -108,7 +108,7 @@ def test_iqn(args: argparse.Namespace = get_args()) -> None: algorithm: IQN = IQN( policy=policy, optim=optim, - discount_factor=args.gamma, + gamma=args.gamma, estimation_step=args.n_step, target_update_freq=args.target_update_freq, ).to(args.device) diff --git a/test/discrete/test_pg.py b/test/discrete/test_pg.py index cb9ac442c..bf3fbae4c 100644 --- a/test/discrete/test_pg.py +++ b/test/discrete/test_pg.py @@ -85,7 +85,7 @@ def test_pg(args: argparse.Namespace = get_args(), enable_assertions: bool = Tru algorithm: Reinforce = Reinforce( policy=policy, optim=optim, - discount_factor=args.gamma, + gamma=args.gamma, reward_normalization=args.rew_norm, ) for m in net.modules(): diff --git a/test/discrete/test_ppo2.py b/test/discrete/test_ppo2.py index 7a1393ede..5fe7e7dac 
100644 --- a/test/discrete/test_ppo2.py +++ b/test/discrete/test_ppo2.py @@ -109,7 +109,7 @@ def test_ppo(args: argparse.Namespace = get_args(), enable_assertions: bool = Tr policy=policy, critic=critic, optim=optim, - discount_factor=args.gamma, + gamma=args.gamma, max_grad_norm=args.max_grad_norm, eps_clip=args.eps_clip, vf_coef=args.vf_coef, diff --git a/test/discrete/test_qrdqn.py b/test/discrete/test_qrdqn.py index f21e7a442..9bc4bb13a 100644 --- a/test/discrete/test_qrdqn.py +++ b/test/discrete/test_qrdqn.py @@ -102,7 +102,7 @@ def test_qrdqn(args: argparse.Namespace = get_args()) -> None: algorithm: QRDQN = QRDQN( policy=policy, optim=optim, - discount_factor=args.gamma, + gamma=args.gamma, num_quantiles=args.num_quantiles, estimation_step=args.n_step, target_update_freq=args.target_update_freq, diff --git a/test/discrete/test_rainbow.py b/test/discrete/test_rainbow.py index 793a163f0..2f70952a6 100644 --- a/test/discrete/test_rainbow.py +++ b/test/discrete/test_rainbow.py @@ -115,7 +115,7 @@ def noisy_linear(x: int, y: int) -> NoisyLinear: algorithm: RainbowDQN = RainbowDQN( policy=policy, optim=optim, - discount_factor=args.gamma, + gamma=args.gamma, estimation_step=args.n_step, target_update_freq=args.target_update_freq, ).to(args.device) diff --git a/test/modelbased/test_dqn_icm.py b/test/modelbased/test_dqn_icm.py index cc2be5b97..583ddb64b 100644 --- a/test/modelbased/test_dqn_icm.py +++ b/test/modelbased/test_dqn_icm.py @@ -115,7 +115,7 @@ def test_dqn_icm(args: argparse.Namespace = get_args()) -> None: algorithm: DQN = DQN( policy=policy, optim=optim, - discount_factor=args.gamma, + gamma=args.gamma, estimation_step=args.n_step, target_update_freq=args.target_update_freq, ) diff --git a/test/modelbased/test_ppo_icm.py b/test/modelbased/test_ppo_icm.py index efcf250ee..794ff3ac4 100644 --- a/test/modelbased/test_ppo_icm.py +++ b/test/modelbased/test_ppo_icm.py @@ -129,7 +129,7 @@ def test_ppo(args: argparse.Namespace = get_args()) -> None: 
policy=policy, critic=critic, optim=optim, - discount_factor=args.gamma, + gamma=args.gamma, max_grad_norm=args.max_grad_norm, eps_clip=args.eps_clip, vf_coef=args.vf_coef, diff --git a/test/offline/gather_cartpole_data.py b/test/offline/gather_cartpole_data.py index 86e052968..9ed763027 100644 --- a/test/offline/gather_cartpole_data.py +++ b/test/offline/gather_cartpole_data.py @@ -106,7 +106,7 @@ def gather_data() -> VectorReplayBuffer | PrioritizedVectorReplayBuffer: algorithm: QRDQN = QRDQN( policy=policy, optim=optim, - discount_factor=args.gamma, + gamma=args.gamma, num_quantiles=args.num_quantiles, estimation_step=args.n_step, target_update_freq=args.target_update_freq, diff --git a/test/offline/test_discrete_bcq.py b/test/offline/test_discrete_bcq.py index 27f012712..cfe70c36f 100644 --- a/test/offline/test_discrete_bcq.py +++ b/test/offline/test_discrete_bcq.py @@ -97,7 +97,7 @@ def test_discrete_bcq(args: argparse.Namespace = get_args()) -> None: algorithm: DiscreteBCQ = DiscreteBCQ( policy=policy, optim=optim, - discount_factor=args.gamma, + gamma=args.gamma, estimation_step=args.n_step, target_update_freq=args.target_update_freq, eval_eps=args.eps_test, diff --git a/test/offline/test_discrete_cql.py b/test/offline/test_discrete_cql.py index 5706dbe71..05cae9409 100644 --- a/test/offline/test_discrete_cql.py +++ b/test/offline/test_discrete_cql.py @@ -89,7 +89,7 @@ def test_discrete_cql(args: argparse.Namespace = get_args()) -> None: algorithm: DiscreteCQL = DiscreteCQL( policy=policy, optim=optim, - discount_factor=args.gamma, + gamma=args.gamma, num_quantiles=args.num_quantiles, estimation_step=args.n_step, target_update_freq=args.target_update_freq, diff --git a/test/offline/test_discrete_crr.py b/test/offline/test_discrete_crr.py index 5c06a56b7..8ce3734e6 100644 --- a/test/offline/test_discrete_crr.py +++ b/test/offline/test_discrete_crr.py @@ -93,7 +93,7 @@ def test_discrete_crr(args: argparse.Namespace = get_args()) -> None: policy=policy, 
critic=critic, optim=optim, - discount_factor=args.gamma, + gamma=args.gamma, target_update_freq=args.target_update_freq, ).to(args.device) diff --git a/test/offline/test_gail.py b/test/offline/test_gail.py index 25c8f424b..e069fbf2a 100644 --- a/test/offline/test_gail.py +++ b/test/offline/test_gail.py @@ -145,7 +145,7 @@ def dist(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution: disc_net=disc_net, disc_optim=disc_optim, disc_update_num=args.disc_update_num, - discount_factor=args.gamma, + gamma=args.gamma, max_grad_norm=args.max_grad_norm, eps_clip=args.eps_clip, vf_coef=args.vf_coef, diff --git a/test/pettingzoo/pistonball.py b/test/pettingzoo/pistonball.py index e29114974..a8702fb49 100644 --- a/test/pettingzoo/pistonball.py +++ b/test/pettingzoo/pistonball.py @@ -107,7 +107,7 @@ def get_agents( agent: DQN = DQN( policy=policy, optim=optim, - discount_factor=args.gamma, + gamma=args.gamma, estimation_step=args.n_step, target_update_freq=args.target_update_freq, ) diff --git a/test/pettingzoo/pistonball_continuous.py b/test/pettingzoo/pistonball_continuous.py index e6832d2b6..f53094310 100644 --- a/test/pettingzoo/pistonball_continuous.py +++ b/test/pettingzoo/pistonball_continuous.py @@ -201,7 +201,7 @@ def dist(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution: policy=policy, critic=critic, optim=optim, - discount_factor=args.gamma, + gamma=args.gamma, max_grad_norm=args.max_grad_norm, eps_clip=args.eps_clip, vf_coef=args.vf_coef, diff --git a/test/pettingzoo/tic_tac_toe.py b/test/pettingzoo/tic_tac_toe.py index b5308fd1c..0142f2d26 100644 --- a/test/pettingzoo/tic_tac_toe.py +++ b/test/pettingzoo/tic_tac_toe.py @@ -131,7 +131,7 @@ def get_agents( policy=algorithm, optim=optim, estimation_step=args.n_step, - discount_factor=args.gamma, + gamma=args.gamma, target_update_freq=args.target_update_freq, ) if args.resume_path: diff --git a/tianshou/highlevel/params/policy_params.py b/tianshou/highlevel/params/policy_params.py index 
eebdfd63a..4b805dd4b 100644 --- a/tianshou/highlevel/params/policy_params.py +++ b/tianshou/highlevel/params/policy_params.py @@ -262,11 +262,21 @@ class ParamsMixinEstimationStep: @dataclass(kw_only=True) -class PGParams(Params, ParamsMixinActionScaling, ParamsMixinSingleModel): - discount_factor: float = 0.99 +class ParamsMixinGamma: + gamma: float = 0.99 """ - discount factor (gamma) for future rewards; must be in [0, 1] + the discount factor in [0, 1] for future rewards. + This determines how much future rewards are valued compared to immediate ones. + Lower values (closer to 0) make the agent focus on immediate rewards, creating "myopic" + behavior. Higher values (closer to 1) make the agent value long-term rewards more, + potentially improving performance in tasks where delayed rewards are important but + increasing training variance by incorporating more environmental stochasticity. + Typically set between 0.9 and 0.99 for most reinforcement learning tasks """ + + +@dataclass(kw_only=True) +class PGParams(Params, ParamsMixinGamma, ParamsMixinActionScaling, ParamsMixinSingleModel): reward_normalization: bool = False """ if True, will normalize the returns by subtracting the running mean and dividing by the running @@ -430,15 +440,15 @@ def _get_param_transformers(self) -> list[ParamTransformer]: @dataclass(kw_only=True) -class _SACParams(Params, ParamsMixinActorAndDualCritics, ParamsMixinEstimationStep): +class _SACParams( + Params, ParamsMixinGamma, ParamsMixinActorAndDualCritics, ParamsMixinEstimationStep +): tau: float = 0.005 """controls the contribution of the entropy term in the overall optimization objective, i.e. the desired amount of randomness in the optimal policy. Higher values mean greater target entropy and therefore more randomness in the policy. Lower values mean lower target entropy and therefore a more deterministic policy. 
""" - gamma: float = 0.99 - """discount factor (gamma) for future rewards; must be in [0, 1]""" alpha: float | AutoAlphaFactory = 0.2 """ controls the relative importance (coefficient) of the entropy term in the loss function. @@ -478,11 +488,9 @@ class DiscreteSACParams(_SACParams): @dataclass(kw_only=True) -class QLearningOffPolicyParams(Params, ParamsMixinSingleModel, ParamsMixinEstimationStep): - discount_factor: float = 0.99 - """ - discount factor (gamma) for future rewards; must be in [0, 1] - """ +class QLearningOffPolicyParams( + Params, ParamsMixinGamma, ParamsMixinSingleModel, ParamsMixinEstimationStep +): target_update_freq: int = 0 """the target network update frequency (0 if no target network is to be used)""" reward_normalization: bool = False @@ -547,6 +555,7 @@ def _get_param_transformers(self) -> list[ParamTransformer]: @dataclass(kw_only=True) class DDPGParams( Params, + ParamsMixinGamma, ParamsMixinActorAndCritic, ParamsMixinExplorationNoise, ParamsMixinActionScaling, @@ -558,16 +567,6 @@ class DDPGParams( It determines how slowly the target networks track the main networks. Smaller tau means slower tracking and more stable learning. """ - gamma: float = 0.99 - """ - the discount factor in [0, 1] for future rewards. - This determines how much future rewards are valued compared to immediate ones. - Lower values (closer to 0) make the agent focus on immediate rewards, creating "myopic" - behavior. Higher values (closer to 1) make the agent value long-term rewards more, - potentially improving performance in tasks where delayed rewards are important but - increasing training variance by incorporating more environmental stochasticity. 
- Typically set between 0.9 and 0.99 for most reinforcement learning tasks - """ def _get_param_transformers(self) -> list[ParamTransformer]: transformers = super()._get_param_transformers() @@ -609,6 +608,7 @@ def _get_param_transformers(self) -> list[ParamTransformer]: @dataclass(kw_only=True) class TD3Params( Params, + ParamsMixinGamma, ParamsMixinActorAndDualCritics, ParamsMixinExplorationNoise, ParamsMixinActionScaling, @@ -620,8 +620,6 @@ class TD3Params( It determines how slowly the target networks track the main networks. Smaller tau means slower tracking and more stable learning. """ - gamma: float = 0.99 - """discount factor (gamma) for future rewards; must be in [0, 1]""" policy_noise: float | FloatEnvValueFactory = 0.2 """the scale of the the noise used in updating policy network""" noise_clip: float | FloatEnvValueFactory = 0.5 diff --git a/tianshou/policy/base.py b/tianshou/policy/base.py index 6faa5facf..38574b2b7 100644 --- a/tianshou/policy/base.py +++ b/tianshou/policy/base.py @@ -702,7 +702,13 @@ def compute_episodic_return( If None, it will be set to an array of 0. :param v_s: the value function of all current states :math:`V(s)`. If None, it is set based upon `v_s_` rolled by 1. - :param gamma: the discount factor, should be in [0, 1]. + :param gamma: the discount factor in [0, 1] for future rewards. + This determines how much future rewards are valued compared to immediate ones. + Lower values (closer to 0) make the agent focus on immediate rewards, creating "myopic" + behavior. Higher values (closer to 1) make the agent value long-term rewards more, + potentially improving performance in tasks where delayed rewards are important but + increasing training variance by incorporating more environmental stochasticity. + Typically set between 0.9 and 0.99 for most reinforcement learning tasks :param gae_lambda: the parameter for Generalized Advantage Estimation, should be in [0, 1]. 
@@ -749,7 +755,13 @@ def compute_nstep_return( :param indices: tell batch's location in buffer :param target_q_fn: a function which computes the target Q value of "obs_next" given data buffer and wanted indices (`n_step` steps ahead). - :param gamma: the discount factor, should be in [0, 1]. + :param gamma: the discount factor in [0, 1] for future rewards. + This determines how much future rewards are valued compared to immediate ones. + Lower values (closer to 0) make the agent focus on immediate rewards, creating "myopic" + behavior. Higher values (closer to 1) make the agent value long-term rewards more, + potentially improving performance in tasks where delayed rewards are important but + increasing training variance by incorporating more environmental stochasticity. + Typically set between 0.9 and 0.99 for most reinforcement learning tasks :param n_step: the number of estimation step, should be an int greater than 0. :param rew_norm: normalize the reward to Normal(0, 1). @@ -1093,7 +1105,7 @@ def _gae_return( $V_{t+1}$ :param rew: rewards in an episode, i.e. $r_t$ :param end_flag: boolean array indicating whether the episode is done - :param gamma: discount factor + :param gamma: the discount factor in [0, 1] for future rewards. :param gae_lambda: lambda parameter for GAE, controlling the bias-variance tradeoff :return: """ diff --git a/tianshou/policy/imitation/bcq.py b/tianshou/policy/imitation/bcq.py index d1b2300fb..e57f44bbd 100644 --- a/tianshou/policy/imitation/bcq.py +++ b/tianshou/policy/imitation/bcq.py @@ -123,7 +123,13 @@ def __init__( :param critic2: the second critic network; if None, clone the critic from the policy :param critic2_optim: the optimizer for the second critic network; if None, use optimizer factory of first critic :param vae_optim: the optimizer for the VAE network. - :param gamma: discount factor, in [0, 1]. + :param gamma: the discount factor in [0, 1] for future rewards. 
+ This determines how much future rewards are valued compared to immediate ones. + Lower values (closer to 0) make the agent focus on immediate rewards, creating "myopic" + behavior. Higher values (closer to 1) make the agent value long-term rewards more, + potentially improving performance in tasks where delayed rewards are important but + increasing training variance by incorporating more environmental stochasticity. + Typically set between 0.9 and 0.99 for most reinforcement learning tasks :param tau: param for soft update of the target network. :param lmbda: param for Clipped Double Q-learning. :param num_sampled_action: the number of sampled actions in calculating target Q. diff --git a/tianshou/policy/imitation/cql.py b/tianshou/policy/imitation/cql.py index eef500c13..74bf7a0d1 100644 --- a/tianshou/policy/imitation/cql.py +++ b/tianshou/policy/imitation/cql.py @@ -74,7 +74,13 @@ def __init__( :param cql_alpha_lr: The learning rate of cql_log_alpha. :param cql_weight: :param tau: Parameter for soft update of the target network. - :param gamma: Discount factor, in [0, 1]. + :param gamma: the discount factor in [0, 1] for future rewards. + This determines how much future rewards are valued compared to immediate ones. + Lower values (closer to 0) make the agent focus on immediate rewards, creating "myopic" + behavior. Higher values (closer to 1) make the agent value long-term rewards more, + potentially improving performance in tasks where delayed rewards are important but + increasing training variance by incorporating more environmental stochasticity. + Typically set between 0.9 and 0.99 for most reinforcement learning tasks :param alpha: the entropy regularization coefficient alpha or an object which can be used to automatically tune it (e.g. an instance of `AutoAlpha`). 
:param temperature: diff --git a/tianshou/policy/imitation/discrete_bcq.py b/tianshou/policy/imitation/discrete_bcq.py index ab34ebc34..419bae6a2 100644 --- a/tianshou/policy/imitation/discrete_bcq.py +++ b/tianshou/policy/imitation/discrete_bcq.py @@ -109,7 +109,7 @@ def __init__( *, policy: DiscreteBCQPolicy, optim: OptimizerFactory, - discount_factor: float = 0.99, + gamma: float = 0.99, estimation_step: int = 1, target_update_freq: int = 8000, eval_eps: float = 1e-3, @@ -121,7 +121,13 @@ def __init__( """ :param policy: the policy :param optim: a torch.optim for optimizing the model. - :param discount_factor: in [0, 1]. + :param gamma: the discount factor in [0, 1] for future rewards. + This determines how much future rewards are valued compared to immediate ones. + Lower values (closer to 0) make the agent focus on immediate rewards, creating "myopic" + behavior. Higher values (closer to 1) make the agent value long-term rewards more, + potentially improving performance in tasks where delayed rewards are important but + increasing training variance by incorporating more environmental stochasticity. + Typically set between 0.9 and 0.99 for most reinforcement learning tasks :param estimation_step: the number of future steps (> 0) to consider when computing temporal difference (TD) targets. 
Controls the balance between TD learning and Monte Carlo methods: higher values reduce bias (by relying less on potentially inaccurate value estimates) @@ -154,10 +160,8 @@ def __init__( ) LaggedNetworkFullUpdateAlgorithmMixin.__init__(self) self.optim = self._create_optimizer(self.policy, optim) - assert ( - 0.0 <= discount_factor <= 1.0 - ), f"discount factor should be in [0, 1] but got: {discount_factor}" - self.gamma = discount_factor + assert 0.0 <= gamma <= 1.0, f"discount factor should be in [0, 1] but got: {gamma}" + self.gamma = gamma assert ( estimation_step > 0 ), f"estimation_step should be greater than 0 but got: {estimation_step}" diff --git a/tianshou/policy/imitation/discrete_cql.py b/tianshou/policy/imitation/discrete_cql.py index 5d6f70a7b..179a79e42 100644 --- a/tianshou/policy/imitation/discrete_cql.py +++ b/tianshou/policy/imitation/discrete_cql.py @@ -35,7 +35,7 @@ def __init__( policy: QRDQNPolicy, optim: OptimizerFactory, min_q_weight: float = 10.0, - discount_factor: float = 0.99, + gamma: float = 0.99, num_quantiles: int = 200, estimation_step: int = 1, target_update_freq: int = 0, @@ -45,7 +45,13 @@ def __init__( :param policy: the policy :param optim: a torch.optim for optimizing the model. :param min_q_weight: the weight for the cql loss. - :param discount_factor: in [0, 1]. + :param gamma: the discount factor in [0, 1] for future rewards. + This determines how much future rewards are valued compared to immediate ones. + Lower values (closer to 0) make the agent focus on immediate rewards, creating "myopic" + behavior. Higher values (closer to 1) make the agent value long-term rewards more, + potentially improving performance in tasks where delayed rewards are important but + increasing training variance by incorporating more environmental stochasticity. 
+ Typically set between 0.9 and 0.99 for most reinforcement learning tasks :param num_quantiles: the number of quantile midpoints in the inverse cumulative distribution function of the value. :param estimation_step: the number of future steps (> 0) to consider when computing temporal @@ -64,7 +70,7 @@ def __init__( self, policy=policy, optim=optim, - discount_factor=discount_factor, + gamma=gamma, num_quantiles=num_quantiles, estimation_step=estimation_step, target_update_freq=target_update_freq, diff --git a/tianshou/policy/imitation/discrete_crr.py b/tianshou/policy/imitation/discrete_crr.py index a0626e4c8..a79d123ed 100644 --- a/tianshou/policy/imitation/discrete_crr.py +++ b/tianshou/policy/imitation/discrete_crr.py @@ -44,7 +44,7 @@ def __init__( policy: DiscreteActorPolicy, critic: torch.nn.Module | DiscreteCritic, optim: OptimizerFactory, - discount_factor: float = 0.99, + gamma: float = 0.99, policy_improvement_mode: Literal["exp", "binary", "all"] = "exp", ratio_upper_bound: float = 20.0, beta: float = 1.0, @@ -57,7 +57,13 @@ def __init__( :param critic: the action-value critic (i.e., Q function) network. (s -> Q(s, \*)) :param optim: the optimizer for the policy's actor and the critic networks. - :param discount_factor: in [0, 1]. + :param gamma: the discount factor in [0, 1] for future rewards. + This determines how much future rewards are valued compared to immediate ones. + Lower values (closer to 0) make the agent focus on immediate rewards, creating "myopic" + behavior. Higher values (closer to 1) make the agent value long-term rewards more, + potentially improving performance in tasks where delayed rewards are important but + increasing training variance by incorporating more environmental stochasticity. + Typically set between 0.9 and 0.99 for most reinforcement learning tasks :param str policy_improvement_mode: type of the weight function f. Possible values: "binary"/"exp"/"all". 
:param ratio_upper_bound: when policy_improvement_mode is "exp", the value @@ -76,7 +82,7 @@ def __init__( ) LaggedNetworkFullUpdateAlgorithmMixin.__init__(self) self.discounted_return_computation = DiscountedReturnComputation( - discount_factor=discount_factor, + gamma=gamma, reward_normalization=reward_normalization, ) self.critic = critic diff --git a/tianshou/policy/imitation/gail.py b/tianshou/policy/imitation/gail.py index 890ed192c..1c5be16c6 100644 --- a/tianshou/policy/imitation/gail.py +++ b/tianshou/policy/imitation/gail.py @@ -55,7 +55,7 @@ def __init__( max_grad_norm: float | None = None, gae_lambda: float = 0.95, max_batchsize: int = 256, - discount_factor: float = 0.99, + gamma: float = 0.99, # TODO: rename to return_normalization? reward_normalization: bool = False, ) -> None: @@ -84,7 +84,13 @@ def __init__( :param max_grad_norm: clipping gradients in back propagation. :param gae_lambda: in [0, 1], param for Generalized Advantage Estimation. :param max_batchsize: the maximum size of the batch when computing GAE. - :param discount_factor: in [0, 1]. + :param gamma: the discount factor in [0, 1] for future rewards. + This determines how much future rewards are valued compared to immediate ones. + Lower values (closer to 0) make the agent focus on immediate rewards, creating "myopic" + behavior. Higher values (closer to 1) make the agent value long-term rewards more, + potentially improving performance in tasks where delayed rewards are important but + increasing training variance by incorporating more environmental stochasticity. + Typically set between 0.9 and 0.99 for most reinforcement learning tasks :param reward_normalization: normalize estimated values to have std close to 1. 
""" super().__init__( @@ -101,7 +107,7 @@ def __init__( max_grad_norm=max_grad_norm, gae_lambda=gae_lambda, max_batchsize=max_batchsize, - discount_factor=discount_factor, + gamma=gamma, reward_normalization=reward_normalization, ) self.disc_net = disc_net diff --git a/tianshou/policy/imitation/td3_bc.py b/tianshou/policy/imitation/td3_bc.py index 9f9f0c446..028f92ef3 100644 --- a/tianshou/policy/imitation/td3_bc.py +++ b/tianshou/policy/imitation/td3_bc.py @@ -53,7 +53,13 @@ def __init__( :param critic2_optim: the optimizer for the second critic network. If None, clone critic_optim to use for critic2.parameters(). :param tau: param for soft update of the target network. - :param gamma: discount factor, in [0, 1]. + :param gamma: the discount factor in [0, 1] for future rewards. + This determines how much future rewards are valued compared to immediate ones. + Lower values (closer to 0) make the agent focus on immediate rewards, creating "myopic" + behavior. Higher values (closer to 1) make the agent value long-term rewards more, + potentially improving performance in tasks where delayed rewards are important but + increasing training variance by incorporating more environmental stochasticity. + Typically set between 0.9 and 0.99 for most reinforcement learning tasks :param exploration_noise: add noise to action for exploration. This is useful when solving "hard exploration" problems. "default" is equivalent to GaussianNoise(sigma=0.1). diff --git a/tianshou/policy/modelbased/psrl.py b/tianshou/policy/modelbased/psrl.py index 5d11de79a..fa8beb6f7 100644 --- a/tianshou/policy/modelbased/psrl.py +++ b/tianshou/policy/modelbased/psrl.py @@ -32,7 +32,7 @@ def __init__( trans_count_prior: np.ndarray, rew_mean_prior: np.ndarray, rew_std_prior: np.ndarray, - discount_factor: float, + gamma: float, epsilon: float, ) -> None: """ @@ -42,7 +42,13 @@ def __init__( with shape (n_state, n_action). 
:param rew_std_prior: standard deviations of the normal priors of rewards, with shape (n_state, n_action). - :param discount_factor: in [0, 1]. + :param gamma: the discount factor in [0, 1] for future rewards. + This determines how much future rewards are valued compared to immediate ones. + Lower values (closer to 0) make the agent focus on immediate rewards, creating "myopic" + behavior. Higher values (closer to 1) make the agent value long-term rewards more, + potentially improving performance in tasks where delayed rewards are important but + increasing training variance by incorporating more environmental stochasticity. + Typically set between 0.9 and 0.99 for most reinforcement learning tasks :param epsilon: for precision control in value iteration. """ self.trans_count = trans_count_prior @@ -51,7 +57,7 @@ def __init__( self.rew_std = rew_std_prior self.rew_square_sum = np.zeros_like(rew_mean_prior) self.rew_std_prior = rew_std_prior - self.discount_factor = discount_factor + self.gamma = gamma self.rew_count = np.full(rew_mean_prior.shape, epsilon) # no weight self.eps = epsilon self.policy: np.ndarray @@ -105,7 +111,7 @@ def solve_policy(self) -> None: self.policy, self.value = self.value_iteration( self.sample_trans_prob(), self.sample_reward(), - self.discount_factor, + self.gamma, self.eps, self.value, ) @@ -114,7 +120,7 @@ def solve_policy(self) -> None: def value_iteration( trans_prob: np.ndarray, rew: np.ndarray, - discount_factor: float, + gamma: float, eps: float, value: np.ndarray, ) -> tuple[np.ndarray, np.ndarray]: @@ -124,17 +130,23 @@ def value_iteration( (n_state, n_action, n_state). :param rew: rewards, with shape (n_state, n_action). :param eps: for precision control. - :param discount_factor: in [0, 1]. + :param gamma: the discount factor in [0, 1] for future rewards. + This determines how much future rewards are valued compared to immediate ones. 
+ Lower values (closer to 0) make the agent focus on immediate rewards, creating "myopic" + behavior. Higher values (closer to 1) make the agent value long-term rewards more, + potentially improving performance in tasks where delayed rewards are important but + increasing training variance by incorporating more environmental stochasticity. + Typically set between 0.9 and 0.99 for most reinforcement learning tasks :param value: the initialize value of value array, with shape (n_state, ). :return: the optimal policy with shape (n_state, ). """ - Q = rew + discount_factor * trans_prob.dot(value) + Q = rew + gamma * trans_prob.dot(value) new_value = Q.max(axis=1) while not np.allclose(new_value, value, eps): value = new_value - Q = rew + discount_factor * trans_prob.dot(value) + Q = rew + gamma * trans_prob.dot(value) new_value = Q.max(axis=1) # this is to make sure if Q(s, a1) == Q(s, a2) -> choose a1/a2 randomly Q += eps * np.random.randn(*Q.shape) diff --git a/tianshou/policy/modelfree/a2c.py b/tianshou/policy/modelfree/a2c.py index 2162115c4..7d18ba1e8 100644 --- a/tianshou/policy/modelfree/a2c.py +++ b/tianshou/policy/modelfree/a2c.py @@ -47,7 +47,7 @@ def __init__( max_grad_norm: float | None = None, gae_lambda: float = 0.95, max_batchsize: int = 256, - discount_factor: float = 0.99, + gamma: float = 0.99, reward_normalization: bool = False, ) -> None: """ @@ -59,7 +59,13 @@ def __init__( is not applied :param gae_lambda: in [0, 1], param for generalized advantage estimation (GAE). :param max_batchsize: the maximum size of the batch when computing GAE. - :param discount_factor: in [0, 1]. + :param gamma: the discount factor in [0, 1] for future rewards. + This determines how much future rewards are valued compared to immediate ones. + Lower values (closer to 0) make the agent focus on immediate rewards, creating "myopic" + behavior. 
Higher values (closer to 1) make the agent value long-term rewards more, + potentially improving performance in tasks where delayed rewards are important but + increasing training variance by incorporating more environmental stochasticity. + Typically set between 0.9 and 0.99 for most reinforcement learning tasks :param reward_normalization: normalize estimated values to have std close to 1. """ super().__init__( @@ -75,7 +81,7 @@ def __init__( ) else: self.optim = self._create_optimizer(self.critic, optim, max_grad_norm=max_grad_norm) - self.gamma = discount_factor + self.gamma = gamma self.rew_norm = reward_normalization self.ret_rms = RunningMeanStd() self._eps = 1e-8 @@ -135,7 +141,7 @@ def __init__( max_grad_norm: float | None = None, gae_lambda: float = 0.95, max_batchsize: int = 256, - discount_factor: float = 0.99, + gamma: float = 0.99, # TODO: This algorithm does not seem to use the reward_normalization parameter. reward_normalization: bool = False, ) -> None: @@ -148,7 +154,13 @@ def __init__( :param max_grad_norm: clipping gradients in back propagation. :param gae_lambda: in [0, 1], param for Generalized Advantage Estimation. :param max_batchsize: the maximum size of the batch when computing GAE. - :param discount_factor: in [0, 1]. + :param gamma: the discount factor in [0, 1] for future rewards. + This determines how much future rewards are valued compared to immediate ones. + Lower values (closer to 0) make the agent focus on immediate rewards, creating "myopic" + behavior. Higher values (closer to 1) make the agent value long-term rewards more, + potentially improving performance in tasks where delayed rewards are important but + increasing training variance by incorporating more environmental stochasticity. + Typically set between 0.9 and 0.99 for most reinforcement learning tasks :param reward_normalization: normalize estimated values to have std close to 1. 
""" super().__init__( @@ -159,7 +171,7 @@ def __init__( max_grad_norm=max_grad_norm, gae_lambda=gae_lambda, max_batchsize=max_batchsize, - discount_factor=discount_factor, + gamma=gamma, reward_normalization=reward_normalization, ) self.vf_coef = vf_coef diff --git a/tianshou/policy/modelfree/bdqn.py b/tianshou/policy/modelfree/bdqn.py index eb578ba52..50dc2eb63 100644 --- a/tianshou/policy/modelfree/bdqn.py +++ b/tianshou/policy/modelfree/bdqn.py @@ -115,7 +115,7 @@ def __init__( *, policy: BDQNPolicy, optim: OptimizerFactory, - discount_factor: float = 0.99, + gamma: float = 0.99, estimation_step: int = 1, target_update_freq: int = 0, reward_normalization: bool = False, @@ -124,7 +124,13 @@ def __init__( """ :param policy: policy :param optim: the optimizer for the policy - :param discount_factor: in [0, 1]. + :param gamma: the discount factor in [0, 1] for future rewards. + This determines how much future rewards are valued compared to immediate ones. + Lower values (closer to 0) make the agent focus on immediate rewards, creating "myopic" + behavior. Higher values (closer to 1) make the agent value long-term rewards more, + potentially improving performance in tasks where delayed rewards are important but + increasing training variance by incorporating more environmental stochasticity. + Typically set between 0.9 and 0.99 for most reinforcement learning tasks :param estimation_step: the number of future steps (> 0) to consider when computing temporal difference (TD) targets. 
Controls the balance between TD learning and Monte Carlo methods: higher values reduce bias (by relying less on potentially inaccurate value estimates) @@ -144,7 +150,7 @@ def __init__( super().__init__( policy=policy, optim=optim, - discount_factor=discount_factor, + gamma=gamma, estimation_step=estimation_step, target_update_freq=target_update_freq, reward_normalization=reward_normalization, diff --git a/tianshou/policy/modelfree/c51.py b/tianshou/policy/modelfree/c51.py index d08b2dd0d..0fc3b915e 100644 --- a/tianshou/policy/modelfree/c51.py +++ b/tianshou/policy/modelfree/c51.py @@ -86,7 +86,7 @@ def __init__( *, policy: C51Policy, optim: OptimizerFactory, - discount_factor: float = 0.99, + gamma: float = 0.99, estimation_step: int = 1, target_update_freq: int = 0, reward_normalization: bool = False, @@ -94,7 +94,13 @@ def __init__( """ :param policy: a policy following the rules (s -> action_values_BA) :param optim: a torch.optim for optimizing the policy. - :param discount_factor: in [0, 1]. + :param gamma: the discount factor in [0, 1] for future rewards. + This determines how much future rewards are valued compared to immediate ones. + Lower values (closer to 0) make the agent focus on immediate rewards, creating "myopic" + behavior. Higher values (closer to 1) make the agent value long-term rewards more, + potentially improving performance in tasks where delayed rewards are important but + increasing training variance by incorporating more environmental stochasticity. + Typically set between 0.9 and 0.99 for most reinforcement learning tasks :param estimation_step: the number of future steps (> 0) to consider when computing temporal difference (TD) targets. 
Controls the balance between TD learning and Monte Carlo methods: higher values reduce bias (by relying less on potentially inaccurate value estimates) @@ -110,7 +116,7 @@ def __init__( super().__init__( policy=policy, optim=optim, - discount_factor=discount_factor, + gamma=gamma, estimation_step=estimation_step, target_update_freq=target_update_freq, reward_normalization=reward_normalization, diff --git a/tianshou/policy/modelfree/ddpg.py b/tianshou/policy/modelfree/ddpg.py index e1fe1f53a..2a5c497f2 100644 --- a/tianshou/policy/modelfree/ddpg.py +++ b/tianshou/policy/modelfree/ddpg.py @@ -201,7 +201,13 @@ def __init__( a continuous action space; override this method if using discrete actions. :param critic_optim: the optimizer for the critic network. :param tau: param for soft update of the target network. - :param gamma: discount factor, in [0, 1]. + :param gamma: the discount factor in [0, 1] for future rewards. + This determines how much future rewards are valued compared to immediate ones. + Lower values (closer to 0) make the agent focus on immediate rewards, creating "myopic" + behavior. Higher values (closer to 1) make the agent value long-term rewards more, + potentially improving performance in tasks where delayed rewards are important but + increasing training variance by incorporating more environmental stochasticity. + Typically set between 0.9 and 0.99 for most reinforcement learning tasks :param lr_scheduler: a learning rate scheduler that adjusts the learning rate in optimizer in each policy.update() """ @@ -317,7 +323,13 @@ def __init__( :param critic: The critic network. (s, a -> Q(s, a)) :param critic_optim: The optimizer for critic network. :param tau: Param for soft update of the target network. - :param gamma: Discount factor, in [0, 1]. + :param gamma: the discount factor in [0, 1] for future rewards. + This determines how much future rewards are valued compared to immediate ones. 
+ Lower values (closer to 0) make the agent focus on immediate rewards, creating "myopic" + behavior. Higher values (closer to 1) make the agent value long-term rewards more, + potentially improving performance in tasks where delayed rewards are important but + increasing training variance by incorporating more environmental stochasticity. + Typically set between 0.9 and 0.99 for most reinforcement learning tasks :param estimation_step: the number of future steps (> 0) to consider when computing temporal difference (TD) targets. Controls the balance between TD learning and Monte Carlo methods: higher values reduce bias (by relying less on potentially inaccurate value estimates) diff --git a/tianshou/policy/modelfree/discrete_sac.py b/tianshou/policy/modelfree/discrete_sac.py index 65b132801..135bd3961 100644 --- a/tianshou/policy/modelfree/discrete_sac.py +++ b/tianshou/policy/modelfree/discrete_sac.py @@ -102,7 +102,13 @@ def __init__( :param critic2_optim: the optimizer for the second critic network. If None, clone critic_optim to use for critic2.parameters(). :param tau: param for soft update of the target network. - :param gamma: discount factor, in [0, 1]. + :param gamma: the discount factor in [0, 1] for future rewards. + This determines how much future rewards are valued compared to immediate ones. + Lower values (closer to 0) make the agent focus on immediate rewards, creating "myopic" + behavior. Higher values (closer to 1) make the agent value long-term rewards more, + potentially improving performance in tasks where delayed rewards are important but + increasing training variance by incorporating more environmental stochasticity. + Typically set between 0.9 and 0.99 for most reinforcement learning tasks :param alpha: the entropy regularization coefficient alpha or an object which can be used to automatically tune it (e.g. an instance of `AutoAlpha`). 
:param estimation_step: the number of future steps (> 0) to consider when computing temporal diff --git a/tianshou/policy/modelfree/dqn.py b/tianshou/policy/modelfree/dqn.py index ac76f154c..2d313ad7d 100644 --- a/tianshou/policy/modelfree/dqn.py +++ b/tianshou/policy/modelfree/dqn.py @@ -196,7 +196,7 @@ def __init__( *, policy: TDQNPolicy, optim: OptimizerFactory, - discount_factor: float = 0.99, + gamma: float = 0.99, estimation_step: int = 1, target_update_freq: int = 0, reward_normalization: bool = False, @@ -204,7 +204,13 @@ def __init__( """ :param policy: the policy :param optim: the optimizer for the policy - :param discount_factor: in [0, 1]. + :param gamma: the discount factor in [0, 1] for future rewards. + This determines how much future rewards are valued compared to immediate ones. + Lower values (closer to 0) make the agent focus on immediate rewards, creating "myopic" + behavior. Higher values (closer to 1) make the agent value long-term rewards more, + potentially improving performance in tasks where delayed rewards are important but + increasing training variance by incorporating more environmental stochasticity. + Typically set between 0.9 and 0.99 for most reinforcement learning tasks :param estimation_step: the number of future steps (> 0) to consider when computing temporal difference (TD) targets. 
Controls the balance between TD learning and Monte Carlo methods: higher values reduce bias (by relying less on potentially inaccurate value estimates) @@ -222,10 +228,8 @@ def __init__( ) self.optim = self._create_policy_optimizer(optim) LaggedNetworkFullUpdateAlgorithmMixin.__init__(self) - assert ( - 0.0 <= discount_factor <= 1.0 - ), f"discount factor should be in [0, 1] but got: {discount_factor}" - self.gamma = discount_factor + assert 0.0 <= gamma <= 1.0, f"discount factor should be in [0, 1] but got: {gamma}" + self.gamma = gamma assert ( estimation_step > 0 ), f"estimation_step should be greater than 0 but got: {estimation_step}" @@ -298,7 +302,7 @@ def __init__( *, policy: TDQNPolicy, optim: OptimizerFactory, - discount_factor: float = 0.99, + gamma: float = 0.99, estimation_step: int = 1, target_update_freq: int = 0, reward_normalization: bool = False, @@ -308,7 +312,13 @@ def __init__( """ :param policy: the policy :param optim: the optimizer for the policy - :param discount_factor: in [0, 1]. + :param gamma: the discount factor in [0, 1] for future rewards. + This determines how much future rewards are valued compared to immediate ones. + Lower values (closer to 0) make the agent focus on immediate rewards, creating "myopic" + behavior. Higher values (closer to 1) make the agent value long-term rewards more, + potentially improving performance in tasks where delayed rewards are important but + increasing training variance by incorporating more environmental stochasticity. + Typically set between 0.9 and 0.99 for most reinforcement learning tasks :param estimation_step: the number of future steps (> 0) to consider when computing temporal difference (TD) targets. 
Controls the balance between TD learning and Monte Carlo methods: higher values reduce bias (by relying less on potentially inaccurate value estimates) @@ -328,7 +338,7 @@ def __init__( super().__init__( policy=policy, optim=optim, - discount_factor=discount_factor, + gamma=gamma, estimation_step=estimation_step, target_update_freq=target_update_freq, reward_normalization=reward_normalization, diff --git a/tianshou/policy/modelfree/fqf.py b/tianshou/policy/modelfree/fqf.py index 9ac61922f..3e24f2a5e 100644 --- a/tianshou/policy/modelfree/fqf.py +++ b/tianshou/policy/modelfree/fqf.py @@ -118,7 +118,7 @@ def __init__( policy: FQFPolicy, optim: OptimizerFactory, fraction_optim: OptimizerFactory, - discount_factor: float = 0.99, + gamma: float = 0.99, # TODO: used as num_quantiles in QRDQNPolicy, but num_fractions in FQFPolicy. # Rename? Or at least explain what happens here. num_fractions: int = 32, @@ -132,7 +132,13 @@ def __init__( :param optim: the optimizer for the policy's main Q-function model :param fraction_optim: the optimizer for the policy's fraction model :param action_space: Env's action space. - :param discount_factor: in [0, 1]. + :param gamma: the discount factor in [0, 1] for future rewards. + This determines how much future rewards are valued compared to immediate ones. + Lower values (closer to 0) make the agent focus on immediate rewards, creating "myopic" + behavior. Higher values (closer to 1) make the agent value long-term rewards more, + potentially improving performance in tasks where delayed rewards are important but + increasing training variance by incorporating more environmental stochasticity. + Typically set between 0.9 and 0.99 for most reinforcement learning tasks :param num_fractions: the number of fractions to use. :param ent_coef: the coefficient for entropy loss. 
:param estimation_step: the number of future steps (> 0) to consider when computing temporal @@ -150,7 +156,7 @@ def __init__( super().__init__( policy=policy, optim=optim, - discount_factor=discount_factor, + gamma=gamma, num_quantiles=num_fractions, estimation_step=estimation_step, target_update_freq=target_update_freq, diff --git a/tianshou/policy/modelfree/iqn.py b/tianshou/policy/modelfree/iqn.py index 3783ae8c4..2290cf954 100644 --- a/tianshou/policy/modelfree/iqn.py +++ b/tianshou/policy/modelfree/iqn.py @@ -119,7 +119,7 @@ def __init__( *, policy: IQNPolicy, optim: OptimizerFactory, - discount_factor: float = 0.99, + gamma: float = 0.99, num_quantiles: int = 200, estimation_step: int = 1, target_update_freq: int = 0, @@ -128,7 +128,13 @@ def __init__( """ :param policy: the policy :param optim: the optimizer for the policy's model - :param discount_factor: in [0, 1]. + :param gamma: the discount factor in [0, 1] for future rewards. + This determines how much future rewards are valued compared to immediate ones. + Lower values (closer to 0) make the agent focus on immediate rewards, creating "myopic" + behavior. Higher values (closer to 1) make the agent value long-term rewards more, + potentially improving performance in tasks where delayed rewards are important but + increasing training variance by incorporating more environmental stochasticity. + Typically set between 0.9 and 0.99 for most reinforcement learning tasks :param num_quantiles: the number of quantile midpoints in the inverse cumulative distribution function of the value. 
:param estimation_step: the number of future steps (> 0) to consider when computing temporal @@ -146,7 +152,7 @@ def __init__( super().__init__( policy=policy, optim=optim, - discount_factor=discount_factor, + gamma=gamma, num_quantiles=num_quantiles, estimation_step=estimation_step, target_update_freq=target_update_freq, diff --git a/tianshou/policy/modelfree/npg.py b/tianshou/policy/modelfree/npg.py index 9077f407f..d83410e2d 100644 --- a/tianshou/policy/modelfree/npg.py +++ b/tianshou/policy/modelfree/npg.py @@ -44,7 +44,7 @@ def __init__( advantage_normalization: bool = True, gae_lambda: float = 0.95, max_batchsize: int = 256, - discount_factor: float = 0.99, + gamma: float = 0.99, # TODO: rename to return_normalization? reward_normalization: bool = False, ) -> None: @@ -58,7 +58,13 @@ def __init__( normalization. :param gae_lambda: in [0, 1], param for Generalized Advantage Estimation. :param max_batchsize: the maximum size of the batch when computing GAE. - :param discount_factor: in [0, 1]. + :param gamma: the discount factor in [0, 1] for future rewards. + This determines how much future rewards are valued compared to immediate ones. + Lower values (closer to 0) make the agent focus on immediate rewards, creating "myopic" + behavior. Higher values (closer to 1) make the agent value long-term rewards more, + potentially improving performance in tasks where delayed rewards are important but + increasing training variance by incorporating more environmental stochasticity. + Typically set between 0.9 and 0.99 for most reinforcement learning tasks :param reward_normalization: normalize estimated values to have std close to 1. 
""" super().__init__( @@ -68,7 +74,7 @@ def __init__( optim_include_actor=False, gae_lambda=gae_lambda, max_batchsize=max_batchsize, - discount_factor=discount_factor, + gamma=gamma, reward_normalization=reward_normalization, ) self.norm_adv = advantage_normalization diff --git a/tianshou/policy/modelfree/pg.py b/tianshou/policy/modelfree/pg.py index b430d6302..16b5e30ae 100644 --- a/tianshou/policy/modelfree/pg.py +++ b/tianshou/policy/modelfree/pg.py @@ -184,17 +184,23 @@ def __init__( class DiscountedReturnComputation: def __init__( self, - discount_factor: float = 0.99, + gamma: float = 0.99, reward_normalization: bool = False, ): """ - :param discount_factor: the future reward discount factor gamma in [0, 1]. + :param gamma: the discount factor in [0, 1] for future rewards. + This determines how much future rewards are valued compared to immediate ones. + Lower values (closer to 0) make the agent focus on immediate rewards, creating "myopic" + behavior. Higher values (closer to 1) make the agent value long-term rewards more, + potentially improving performance in tasks where delayed rewards are important but + increasing training variance by incorporating more environmental stochasticity. + Typically set between 0.9 and 0.99 for most reinforcement learning tasks :param reward_normalization: if True, will normalize the *returns* by subtracting the running mean and dividing by the running standard deviation. Can be detrimental to performance! 
""" - assert 0.0 <= discount_factor <= 1.0, "discount factor should be in [0, 1]" - self.gamma = discount_factor + assert 0.0 <= gamma <= 1.0, "discount factor gamma should be in [0, 1]" + self.gamma = gamma self.rew_norm = reward_normalization self.ret_rms = RunningMeanStd() self.eps = 1e-8 @@ -257,14 +263,20 @@ def __init__( self, *, policy: TActorPolicy, - discount_factor: float = 0.99, + gamma: float = 0.99, reward_normalization: bool = False, optim: OptimizerFactory, ) -> None: """ :param policy: the policy :param optim: optimizer for the policy's actor network. - :param discount_factor: in [0, 1]. + :param gamma: the discount factor in [0, 1] for future rewards. + This determines how much future rewards are valued compared to immediate ones. + Lower values (closer to 0) make the agent focus on immediate rewards, creating "myopic" + behavior. Higher values (closer to 1) make the agent value long-term rewards more, + potentially improving performance in tasks where delayed rewards are important but + increasing training variance by incorporating more environmental stochasticity. + Typically set between 0.9 and 0.99 for most reinforcement learning tasks :param reward_normalization: if True, will normalize the *returns* by subtracting the running mean and dividing by the running standard deviation. Can be detrimental to performance! 
@@ -273,7 +285,7 @@ def __init__( policy=policy, ) self.discounted_return_computation = DiscountedReturnComputation( - discount_factor=discount_factor, + gamma=gamma, reward_normalization=reward_normalization, ) self.optim = self._create_optimizer(self.policy, optim) diff --git a/tianshou/policy/modelfree/ppo.py b/tianshou/policy/modelfree/ppo.py index 5f637ce42..50cff3596 100644 --- a/tianshou/policy/modelfree/ppo.py +++ b/tianshou/policy/modelfree/ppo.py @@ -71,7 +71,7 @@ def __init__( max_grad_norm: float | None = None, gae_lambda: float = 0.95, max_batchsize: int = 256, - discount_factor: float = 0.99, + gamma: float = 0.99, # TODO: rename to return_normalization? reward_normalization: bool = False, ) -> None: @@ -94,7 +94,13 @@ def __init__( :param max_grad_norm: clipping gradients in back propagation. :param gae_lambda: in [0, 1], param for Generalized Advantage Estimation. :param max_batchsize: the maximum size of the batch when computing GAE. - :param discount_factor: in [0, 1]. + :param gamma: the discount factor in [0, 1] for future rewards. + This determines how much future rewards are valued compared to immediate ones. + Lower values (closer to 0) make the agent focus on immediate rewards, creating "myopic" + behavior. Higher values (closer to 1) make the agent value long-term rewards more, + potentially improving performance in tasks where delayed rewards are important but + increasing training variance by incorporating more environmental stochasticity. + Typically set between 0.9 and 0.99 for most reinforcement learning tasks :param reward_normalization: normalize estimated values to have std close to 1. 
""" assert ( @@ -110,7 +116,7 @@ def __init__( max_grad_norm=max_grad_norm, gae_lambda=gae_lambda, max_batchsize=max_batchsize, - discount_factor=discount_factor, + gamma=gamma, reward_normalization=reward_normalization, ) self.eps_clip = eps_clip diff --git a/tianshou/policy/modelfree/qrdqn.py b/tianshou/policy/modelfree/qrdqn.py index 17326deed..a3d7ae540 100644 --- a/tianshou/policy/modelfree/qrdqn.py +++ b/tianshou/policy/modelfree/qrdqn.py @@ -43,7 +43,7 @@ def __init__( *, policy: TQRDQNPolicy, optim: OptimizerFactory, - discount_factor: float = 0.99, + gamma: float = 0.99, num_quantiles: int = 200, estimation_step: int = 1, target_update_freq: int = 0, @@ -52,7 +52,13 @@ def __init__( """ :param policy: the policy :param optim: the optimizer for the policy - :param discount_factor: in [0, 1]. + :param gamma: the discount factor in [0, 1] for future rewards. + This determines how much future rewards are valued compared to immediate ones. + Lower values (closer to 0) make the agent focus on immediate rewards, creating "myopic" + behavior. Higher values (closer to 1) make the agent value long-term rewards more, + potentially improving performance in tasks where delayed rewards are important but + increasing training variance by incorporating more environmental stochasticity. + Typically set between 0.9 and 0.99 for most reinforcement learning tasks :param num_quantiles: the number of quantile midpoints in the inverse cumulative distribution function of the value. 
:param estimation_step: the number of future steps (> 0) to consider when computing temporal @@ -71,7 +77,7 @@ def __init__( super().__init__( policy=policy, optim=optim, - discount_factor=discount_factor, + gamma=gamma, estimation_step=estimation_step, target_update_freq=target_update_freq, reward_normalization=reward_normalization, diff --git a/tianshou/policy/modelfree/redq.py b/tianshou/policy/modelfree/redq.py index 38925280a..f18d10b0e 100644 --- a/tianshou/policy/modelfree/redq.py +++ b/tianshou/policy/modelfree/redq.py @@ -130,7 +130,13 @@ def __init__( :param ensemble_size: Number of sub-networks in the critic ensemble. :param subset_size: Number of networks in the subset. :param tau: Param for soft update of the target network. - :param gamma: Discount factor, in [0, 1]. + :param gamma: the discount factor in [0, 1] for future rewards. + This determines how much future rewards are valued compared to immediate ones. + Lower values (closer to 0) make the agent focus on immediate rewards, creating "myopic" + behavior. Higher values (closer to 1) make the agent value long-term rewards more, + potentially improving performance in tasks where delayed rewards are important but + increasing training variance by incorporating more environmental stochasticity. + Typically set between 0.9 and 0.99 for most reinforcement learning tasks :param alpha: the entropy regularization coefficient alpha or an object which can be used to automatically tune it (e.g. an instance of `AutoAlpha`). :param estimation_step: the number of future steps (> 0) to consider when computing temporal diff --git a/tianshou/policy/modelfree/sac.py b/tianshou/policy/modelfree/sac.py index af5b45528..3ad6e6524 100644 --- a/tianshou/policy/modelfree/sac.py +++ b/tianshou/policy/modelfree/sac.py @@ -227,7 +227,13 @@ def __init__( :param critic2_optim: the optimizer for the second critic network. If None, clone critic_optim to use for critic2.parameters(). 
:param tau: param for soft update of the target network. - :param gamma: discount factor, in [0, 1]. + :param gamma: the discount factor in [0, 1] for future rewards. + This determines how much future rewards are valued compared to immediate ones. + Lower values (closer to 0) make the agent focus on immediate rewards, creating "myopic" + behavior. Higher values (closer to 1) make the agent value long-term rewards more, + potentially improving performance in tasks where delayed rewards are important but + increasing training variance by incorporating more environmental stochasticity. + Typically set between 0.9 and 0.99 for most reinforcement learning tasks :param alpha: the entropy regularization coefficient alpha or an object which can be used to automatically tune it (e.g. an instance of `AutoAlpha`). :param estimation_step: the number of future steps (> 0) to consider when computing temporal diff --git a/tianshou/policy/modelfree/td3.py b/tianshou/policy/modelfree/td3.py index 902ca8ced..af5522268 100644 --- a/tianshou/policy/modelfree/td3.py +++ b/tianshou/policy/modelfree/td3.py @@ -68,7 +68,13 @@ def __init__( :param critic2_optim: the optimizer for the second critic network. If None, use critic_optim. :param tau: param for soft update of the target network. - :param gamma: discount factor, in [0, 1]. + :param gamma: the discount factor in [0, 1] for future rewards. + This determines how much future rewards are valued compared to immediate ones. + Lower values (closer to 0) make the agent focus on immediate rewards, creating "myopic" + behavior. Higher values (closer to 1) make the agent value long-term rewards more, + potentially improving performance in tasks where delayed rewards are important but + increasing training variance by incorporating more environmental stochasticity. 
+ Typically set between 0.9 and 0.99 for most reinforcement learning tasks :param lr_scheduler: a learning rate scheduler that adjusts the learning rate in optimizer in each policy.update() """ @@ -128,7 +134,13 @@ def __init__( :param critic2_optim: the optimizer for the second critic network. If None, clone critic_optim to use for critic2.parameters(). :param tau: param for soft update of the target network. - :param gamma: discount factor, in [0, 1]. + :param gamma: the discount factor in [0, 1] for future rewards. + This determines how much future rewards are valued compared to immediate ones. + Lower values (closer to 0) make the agent focus on immediate rewards, creating "myopic" + behavior. Higher values (closer to 1) make the agent value long-term rewards more, + potentially improving performance in tasks where delayed rewards are important but + increasing training variance by incorporating more environmental stochasticity. + Typically set between 0.9 and 0.99 for most reinforcement learning tasks :param policy_noise: the noise used in updating policy network. :param update_actor_freq: the update frequency of actor network. :param noise_clip: the clipping range used in updating policy network. diff --git a/tianshou/policy/modelfree/trpo.py b/tianshou/policy/modelfree/trpo.py index 95021d158..e3667857f 100644 --- a/tianshou/policy/modelfree/trpo.py +++ b/tianshou/policy/modelfree/trpo.py @@ -40,7 +40,7 @@ def __init__( advantage_normalization: bool = True, gae_lambda: float = 0.95, max_batchsize: int = 256, - discount_factor: float = 0.99, + gamma: float = 0.99, # TODO: rename to return_normalization? reward_normalization: bool = False, ) -> None: @@ -58,7 +58,13 @@ def __init__( normalization. :param gae_lambda: in [0, 1], param for Generalized Advantage Estimation. :param max_batchsize: the maximum size of the batch when computing GAE. - :param discount_factor: in [0, 1]. + :param gamma: the discount factor in [0, 1] for future rewards. 
+ This determines how much future rewards are valued compared to immediate ones. + Lower values (closer to 0) make the agent focus on immediate rewards, creating "myopic" + behavior. Higher values (closer to 1) make the agent value long-term rewards more, + potentially improving performance in tasks where delayed rewards are important but + increasing training variance by incorporating more environmental stochasticity. + Typically set between 0.9 and 0.99 for most reinforcement learning tasks :param reward_normalization: normalize estimated values to have std close to 1. """ super().__init__( @@ -70,7 +76,7 @@ def __init__( advantage_normalization=advantage_normalization, gae_lambda=gae_lambda, max_batchsize=max_batchsize, - discount_factor=discount_factor, + gamma=gamma, reward_normalization=reward_normalization, ) self.max_backtracks = max_backtracks From 99e3c5a3cf2e5687347d8ac7406a62e8922f349a Mon Sep 17 00:00:00 2001 From: Dominik Jain Date: Sun, 4 May 2025 14:45:52 +0200 Subject: [PATCH 091/230] v2: Improve docstrings for optimizer factories (and related/neighbouring parameters) --- tianshou/policy/imitation/bcq.py | 4 ++-- tianshou/policy/imitation/cql.py | 12 ++++++------ tianshou/policy/imitation/discrete_bcq.py | 2 +- tianshou/policy/imitation/discrete_cql.py | 2 +- tianshou/policy/imitation/discrete_crr.py | 2 +- tianshou/policy/imitation/td3_bc.py | 10 +++++----- tianshou/policy/modelbased/icm.py | 2 +- tianshou/policy/modelfree/a2c.py | 2 +- tianshou/policy/modelfree/bdqn.py | 2 +- tianshou/policy/modelfree/c51.py | 2 +- tianshou/policy/modelfree/ddpg.py | 10 +++++----- tianshou/policy/modelfree/discrete_sac.py | 10 +++++----- tianshou/policy/modelfree/dqn.py | 4 ++-- tianshou/policy/modelfree/fqf.py | 6 +++--- tianshou/policy/modelfree/iqn.py | 2 +- tianshou/policy/modelfree/pg.py | 2 +- tianshou/policy/modelfree/ppo.py | 2 +- tianshou/policy/modelfree/qrdqn.py | 2 +- tianshou/policy/modelfree/redq.py | 10 +++++----- tianshou/policy/modelfree/sac.py | 
10 +++++----- tianshou/policy/modelfree/td3.py | 20 ++++++++++---------- 21 files changed, 59 insertions(+), 59 deletions(-) diff --git a/tianshou/policy/imitation/bcq.py b/tianshou/policy/imitation/bcq.py index e57f44bbd..3983deabf 100644 --- a/tianshou/policy/imitation/bcq.py +++ b/tianshou/policy/imitation/bcq.py @@ -121,8 +121,8 @@ def __init__( :param actor_perturbation_optim: the optimizer factory for the policy's actor perturbation network. :param critic_optim: the optimizer factory for the policy's critic network. :param critic2: the second critic network; if None, clone the critic from the policy - :param critic2_optim: the optimizer for the second critic network; if None, use optimizer factory of first critic - :param vae_optim: the optimizer for the VAE network. + :param critic2_optim: the optimizer factory for the second critic network; if None, use optimizer factory of first critic + :param vae_optim: the optimizer factory for the VAE network. :param gamma: the discount factor in [0, 1] for future rewards. This determines how much future rewards are valued compared to immediate ones. Lower values (closer to 0) make the agent focus on immediate rewards, creating "myopic" diff --git a/tianshou/policy/imitation/cql.py b/tianshou/policy/imitation/cql.py index 74bf7a0d1..1de24d1fe 100644 --- a/tianshou/policy/imitation/cql.py +++ b/tianshou/policy/imitation/cql.py @@ -63,14 +63,14 @@ def __init__( """ :param actor: the actor network following the rules in :class:`~tianshou.policy.BasePolicy`. (s -> a) - :param policy_optim: The optimizer for actor network. - :param critic: The first critic network. - :param critic_optim: The optimizer for the first critic network. - :param action_space: Env's action space. + :param policy_optim: the optimizer factory for the policy/its actor network. + :param critic: the first critic network. + :param critic_optim: the optimizer factory for the first critic network. + :param action_space: the environment's action space. 
:param critic2: the second critic network. (s, a -> Q(s, a)). If None, use the same network as critic (via deepcopy). - :param critic2_optim: the optimizer for the second critic network. - If None, clone critic_optim to use for critic2.parameters(). + :param critic2_optim: the optimizer factory for the second critic network. + If None, clone the first critic's optimizer factory. :param cql_alpha_lr: The learning rate of cql_log_alpha. :param cql_weight: :param tau: Parameter for soft update of the target network. diff --git a/tianshou/policy/imitation/discrete_bcq.py b/tianshou/policy/imitation/discrete_bcq.py index 419bae6a2..e77ffbbed 100644 --- a/tianshou/policy/imitation/discrete_bcq.py +++ b/tianshou/policy/imitation/discrete_bcq.py @@ -120,7 +120,7 @@ def __init__( ) -> None: """ :param policy: the policy - :param optim: a torch.optim for optimizing the model. + :param optim: the optimizer factory for the policy's model. :param gamma: the discount factor in [0, 1] for future rewards. This determines how much future rewards are valued compared to immediate ones. Lower values (closer to 0) make the agent focus on immediate rewards, creating "myopic" diff --git a/tianshou/policy/imitation/discrete_cql.py b/tianshou/policy/imitation/discrete_cql.py index 179a79e42..cffaa8e7f 100644 --- a/tianshou/policy/imitation/discrete_cql.py +++ b/tianshou/policy/imitation/discrete_cql.py @@ -43,7 +43,7 @@ def __init__( ) -> None: """ :param policy: the policy - :param optim: a torch.optim for optimizing the model. + :param optim: the optimizer factory for the policy's model. :param min_q_weight: the weight for the cql loss. :param gamma: the discount factor in [0, 1] for future rewards. This determines how much future rewards are valued compared to immediate ones. 
diff --git a/tianshou/policy/imitation/discrete_crr.py b/tianshou/policy/imitation/discrete_crr.py index a79d123ed..a72ed53d9 100644 --- a/tianshou/policy/imitation/discrete_crr.py +++ b/tianshou/policy/imitation/discrete_crr.py @@ -56,7 +56,7 @@ def __init__( :param policy: the policy :param critic: the action-value critic (i.e., Q function) network. (s -> Q(s, \*)) - :param optim: the optimizer for the policy's actor and the critic networks. + :param optim: the optimizer factory for the policy's actor network and the critic networks. :param gamma: the discount factor in [0, 1] for future rewards. This determines how much future rewards are valued compared to immediate ones. Lower values (closer to 0) make the agent focus on immediate rewards, creating "myopic" diff --git a/tianshou/policy/imitation/td3_bc.py b/tianshou/policy/imitation/td3_bc.py index 028f92ef3..19ac386d1 100644 --- a/tianshou/policy/imitation/td3_bc.py +++ b/tianshou/policy/imitation/td3_bc.py @@ -45,13 +45,13 @@ def __init__( ) -> None: """ :param policy: the policy - :param policy_optim: the optimizer for policy. + :param policy_optim: the optimizer factory for the policy's model. :param critic: the first critic network. (s, a -> Q(s, a)) - :param critic_optim: the optimizer for the first critic network. + :param critic_optim: the optimizer factory for the first critic network. :param critic2: the second critic network. (s, a -> Q(s, a)). - If None, use the same network as critic (via deepcopy). - :param critic2_optim: the optimizer for the second critic network. - If None, clone critic_optim to use for critic2.parameters(). + If None, copy the first critic (via deepcopy). + :param critic2_optim: the optimizer factory for the second critic network. + If None, use the first critic's factory. :param tau: param for soft update of the target network. :param gamma: the discount factor in [0, 1] for future rewards. This determines how much future rewards are valued compared to immediate ones. 
diff --git a/tianshou/policy/modelbased/icm.py b/tianshou/policy/modelbased/icm.py index 41ae3d4ff..99d8e6a4b 100644 --- a/tianshou/policy/modelbased/icm.py +++ b/tianshou/policy/modelbased/icm.py @@ -49,7 +49,7 @@ def __init__( ) -> None: """ :param model: the ICM model. - :param optim: the optimizer for parameter `model`. + :param optim: the optimizer factory. :param lr_scale: the scaling factor for ICM learning. :param forward_loss_weight: the weight for forward model loss. """ diff --git a/tianshou/policy/modelfree/a2c.py b/tianshou/policy/modelfree/a2c.py index 7d18ba1e8..2c12e3e6f 100644 --- a/tianshou/policy/modelfree/a2c.py +++ b/tianshou/policy/modelfree/a2c.py @@ -148,7 +148,7 @@ def __init__( """ :param policy: the policy containing the actor network. :param critic: the critic network. (s -> V(s)) - :param optim: the optimizer factory for the actor and critic networks. + :param optim: the optimizer factory. :param vf_coef: weight for value loss. :param ent_coef: weight for entropy loss. :param max_grad_norm: clipping gradients in back propagation. diff --git a/tianshou/policy/modelfree/bdqn.py b/tianshou/policy/modelfree/bdqn.py index 50dc2eb63..9eafa8840 100644 --- a/tianshou/policy/modelfree/bdqn.py +++ b/tianshou/policy/modelfree/bdqn.py @@ -123,7 +123,7 @@ def __init__( ) -> None: """ :param policy: policy - :param optim: the optimizer for the policy + :param optim: the optimizer factory for the policy's model. :param gamma: the discount factor in [0, 1] for future rewards. This determines how much future rewards are valued compared to immediate ones. 
Lower values (closer to 0) make the agent focus on immediate rewards, creating "myopic" diff --git a/tianshou/policy/modelfree/c51.py b/tianshou/policy/modelfree/c51.py index 0fc3b915e..35c68a3f5 100644 --- a/tianshou/policy/modelfree/c51.py +++ b/tianshou/policy/modelfree/c51.py @@ -93,7 +93,7 @@ def __init__( ) -> None: """ :param policy: a policy following the rules (s -> action_values_BA) - :param optim: a torch.optim for optimizing the policy. + :param optim: the optimizer factory for the policy's model. :param gamma: the discount factor in [0, 1] for future rewards. This determines how much future rewards are valued compared to immediate ones. Lower values (closer to 0) make the agent focus on immediate rewards, creating "myopic" diff --git a/tianshou/policy/modelfree/ddpg.py b/tianshou/policy/modelfree/ddpg.py index 2a5c497f2..f9d4a8a02 100644 --- a/tianshou/policy/modelfree/ddpg.py +++ b/tianshou/policy/modelfree/ddpg.py @@ -193,13 +193,13 @@ def __init__( ) -> None: """ :param policy: the policy - :param policy_optim: the optimizer for actor network. + :param policy_optim: the optimizer factory for the policy's model. :param critic: the critic network. For continuous action spaces: (s, a -> Q(s, a)). For discrete action spaces: (s -> ). NOTE: The default implementation of `_target_q_compute_value` assumes a continuous action space; override this method if using discrete actions. - :param critic_optim: the optimizer for the critic network. + :param critic_optim: the optimizer factory for the critic network. :param tau: param for soft update of the target network. :param gamma: the discount factor in [0, 1] for future rewards. This determines how much future rewards are valued compared to immediate ones. @@ -319,9 +319,9 @@ def __init__( ) -> None: """ :param policy: the policy - :param policy_optim: The optimizer for actor network. - :param critic: The critic network. (s, a -> Q(s, a)) - :param critic_optim: The optimizer for critic network. 
+ :param policy_optim: the optimizer factory for the policy's model. + :param critic: the critic network. (s, a -> Q(s, a)) + :param critic_optim: the optimizer factory for the critic network. :param tau: Param for soft update of the target network. :param gamma: the discount factor in [0, 1] for future rewards. This determines how much future rewards are valued compared to immediate ones. diff --git a/tianshou/policy/modelfree/discrete_sac.py b/tianshou/policy/modelfree/discrete_sac.py index 135bd3961..edeccaf94 100644 --- a/tianshou/policy/modelfree/discrete_sac.py +++ b/tianshou/policy/modelfree/discrete_sac.py @@ -94,13 +94,13 @@ def __init__( ) -> None: """ :param policy: the policy - :param policy_optim: the optimizer for actor network. + :param policy_optim: the optimizer factory for the policy's model. :param critic: the first critic network. (s -> ). - :param critic_optim: the optimizer for the first critic network. + :param critic_optim: the optimizer factory for the first critic network. :param critic2: the second critic network. (s -> ). - If None, use the same network as critic (via deepcopy). - :param critic2_optim: the optimizer for the second critic network. - If None, clone critic_optim to use for critic2.parameters(). + If None, copy the first critic (via deepcopy). + :param critic2_optim: the optimizer factory for the second critic network. + If None, use the first critic's factory. :param tau: param for soft update of the target network. :param gamma: the discount factor in [0, 1] for future rewards. This determines how much future rewards are valued compared to immediate ones. diff --git a/tianshou/policy/modelfree/dqn.py b/tianshou/policy/modelfree/dqn.py index 2d313ad7d..0805e6834 100644 --- a/tianshou/policy/modelfree/dqn.py +++ b/tianshou/policy/modelfree/dqn.py @@ -203,7 +203,7 @@ def __init__( ) -> None: """ :param policy: the policy - :param optim: the optimizer for the policy + :param optim: the optimizer factory for the policy's model. 
:param gamma: the discount factor in [0, 1] for future rewards. This determines how much future rewards are valued compared to immediate ones. Lower values (closer to 0) make the agent focus on immediate rewards, creating "myopic" @@ -311,7 +311,7 @@ def __init__( ) -> None: """ :param policy: the policy - :param optim: the optimizer for the policy + :param optim: the optimizer factory for the policy's model. :param gamma: the discount factor in [0, 1] for future rewards. This determines how much future rewards are valued compared to immediate ones. Lower values (closer to 0) make the agent focus on immediate rewards, creating "myopic" diff --git a/tianshou/policy/modelfree/fqf.py b/tianshou/policy/modelfree/fqf.py index 3e24f2a5e..b03ee3ae2 100644 --- a/tianshou/policy/modelfree/fqf.py +++ b/tianshou/policy/modelfree/fqf.py @@ -129,9 +129,9 @@ def __init__( ) -> None: """ :param policy: the policy - :param optim: the optimizer for the policy's main Q-function model - :param fraction_optim: the optimizer for the policy's fraction model - :param action_space: Env's action space. + :param optim: the optimizer factory for the policy's main Q-function model + :param fraction_optim: the optimizer factory for the policy's fraction model + :param action_space: the environment's action space. :param gamma: the discount factor in [0, 1] for future rewards. This determines how much future rewards are valued compared to immediate ones. Lower values (closer to 0) make the agent focus on immediate rewards, creating "myopic" diff --git a/tianshou/policy/modelfree/iqn.py b/tianshou/policy/modelfree/iqn.py index 2290cf954..54f60c2fc 100644 --- a/tianshou/policy/modelfree/iqn.py +++ b/tianshou/policy/modelfree/iqn.py @@ -127,7 +127,7 @@ def __init__( ) -> None: """ :param policy: the policy - :param optim: the optimizer for the policy's model + :param optim: the optimizer factory for the policy's model :param gamma: the discount factor in [0, 1] for future rewards. 
This determines how much future rewards are valued compared to immediate ones. Lower values (closer to 0) make the agent focus on immediate rewards, creating "myopic" diff --git a/tianshou/policy/modelfree/pg.py b/tianshou/policy/modelfree/pg.py index 16b5e30ae..b26d387a2 100644 --- a/tianshou/policy/modelfree/pg.py +++ b/tianshou/policy/modelfree/pg.py @@ -269,7 +269,7 @@ def __init__( ) -> None: """ :param policy: the policy - :param optim: optimizer for the policy's actor network. + :param optim: the optimizer factory for the policy's model. :param gamma: the discount factor in [0, 1] for future rewards. This determines how much future rewards are valued compared to immediate ones. Lower values (closer to 0) make the agent focus on immediate rewards, creating "myopic" diff --git a/tianshou/policy/modelfree/ppo.py b/tianshou/policy/modelfree/ppo.py index 50cff3596..dc602bd56 100644 --- a/tianshou/policy/modelfree/ppo.py +++ b/tianshou/policy/modelfree/ppo.py @@ -78,7 +78,7 @@ def __init__( r""" :param policy: the policy containing the actor network. :param critic: the critic network. (s -> V(s)) - :param optim: the optimizer factory for the actor and critic networks. + :param optim: the optimizer factory for the policy's actor network and the critic networks. :param eps_clip: :math:`\epsilon` in :math:`L_{CLIP}` in the original paper. :param dual_clip: a parameter c mentioned in arXiv:1912.09729 Equ. 5, diff --git a/tianshou/policy/modelfree/qrdqn.py b/tianshou/policy/modelfree/qrdqn.py index a3d7ae540..8790c4523 100644 --- a/tianshou/policy/modelfree/qrdqn.py +++ b/tianshou/policy/modelfree/qrdqn.py @@ -51,7 +51,7 @@ def __init__( ) -> None: """ :param policy: the policy - :param optim: the optimizer for the policy + :param optim: the optimizer factory for the policy's model. :param gamma: the discount factor in [0, 1] for future rewards. This determines how much future rewards are valued compared to immediate ones. 
Lower values (closer to 0) make the agent focus on immediate rewards, creating "myopic" diff --git a/tianshou/policy/modelfree/redq.py b/tianshou/policy/modelfree/redq.py index f18d10b0e..61f6b4b95 100644 --- a/tianshou/policy/modelfree/redq.py +++ b/tianshou/policy/modelfree/redq.py @@ -124,11 +124,11 @@ def __init__( ) -> None: """ :param policy: the policy - :param policy_optim: The optimizer for actor network. - :param critic: The critic network. (s, a -> Q(s, a)) - :param critic_optim: The optimizer for critic network. - :param ensemble_size: Number of sub-networks in the critic ensemble. - :param subset_size: Number of networks in the subset. + :param policy_optim: the optimizer factory for the policy's model. + :param critic: the critic network. (s, a -> Q(s, a)) + :param critic_optim: the optimizer factory for the critic network. + :param ensemble_size: the number of sub-networks in the critic ensemble. + :param subset_size: the number of networks in the subset. :param tau: Param for soft update of the target network. :param gamma: the discount factor in [0, 1] for future rewards. This determines how much future rewards are valued compared to immediate ones. diff --git a/tianshou/policy/modelfree/sac.py b/tianshou/policy/modelfree/sac.py index 3ad6e6524..49dc944b3 100644 --- a/tianshou/policy/modelfree/sac.py +++ b/tianshou/policy/modelfree/sac.py @@ -219,13 +219,13 @@ def __init__( ) -> None: """ :param policy: the policy - :param policy_optim: the optimizer for actor network. + :param policy_optim: the optimizer factory for the policy's model. :param critic: the first critic network. (s, a -> Q(s, a)) - :param critic_optim: the optimizer for the first critic network. + :param critic_optim: the optimizer factory for the first critic network. :param critic2: the second critic network. (s, a -> Q(s, a)). - If None, use the same network as critic (via deepcopy). - :param critic2_optim: the optimizer for the second critic network. 
- If None, clone critic_optim to use for critic2.parameters(). + If None, copy the first critic (via deepcopy). + :param critic2_optim: the optimizer factory for the second critic network. + If None, use the first critic's factory. :param tau: param for soft update of the target network. :param gamma: the discount factor in [0, 1] for future rewards. This determines how much future rewards are valued compared to immediate ones. diff --git a/tianshou/policy/modelfree/td3.py b/tianshou/policy/modelfree/td3.py index af5522268..566eef0be 100644 --- a/tianshou/policy/modelfree/td3.py +++ b/tianshou/policy/modelfree/td3.py @@ -57,16 +57,16 @@ def __init__( ) -> None: """ :param policy: the policy - :param policy_optim: the optimizer for actor network. + :param policy_optim: the optimizer factory for the policy's model. :param critic: the first critic network. For continuous action spaces: (s, a -> Q(s, a)). NOTE: The default implementation of `_target_q_compute_value` assumes a continuous action space; override this method if using discrete actions. - :param critic_optim: the optimizer for the first critic network. + :param critic_optim: the optimizer factory for the first critic network. :param critic2: the second critic network (analogous functionality to the first). - If None, use the same network as the first critic (via deepcopy). - :param critic2_optim: the optimizer for the second critic network. - If None, use critic_optim. + If None, copy the first critic (via deepcopy). + :param critic2_optim: the optimizer factory for the second critic network. + If None, use the first critic's factory. :param tau: param for soft update of the target network. :param gamma: the discount factor in [0, 1] for future rewards. This determines how much future rewards are valued compared to immediate ones. @@ -126,13 +126,13 @@ def __init__( ) -> None: """ :param policy: the policy - :param policy_optim: the optimizer for actor network. 
+ :param policy_optim: the optimizer factory for the policy's model. :param critic: the first critic network. (s, a -> Q(s, a)) - :param critic_optim: the optimizer for the first critic network. + :param critic_optim: the optimizer factory for the first critic network. :param critic2: the second critic network. (s, a -> Q(s, a)). - If None, use the same network as critic (via deepcopy). - :param critic2_optim: the optimizer for the second critic network. - If None, clone critic_optim to use for critic2.parameters(). + If None, copy the first critic (via deepcopy). + :param critic2_optim: the optimizer factory for the second critic network. + If None, use the first critic's factory. :param tau: param for soft update of the target network. :param gamma: the discount factor in [0, 1] for future rewards. This determines how much future rewards are valued compared to immediate ones. From 41c4563177d1bd5b5415255768dc51b003083551 Mon Sep 17 00:00:00 2001 From: Dominik Jain Date: Sun, 4 May 2025 15:00:31 +0200 Subject: [PATCH 092/230] v2: Improve description of parameter 'tau' --- tianshou/highlevel/params/policy_params.py | 42 +++++++++++----------- tianshou/policy/imitation/bcq.py | 9 ++++- tianshou/policy/imitation/cql.py | 9 ++++- tianshou/policy/imitation/td3_bc.py | 9 ++++- tianshou/policy/modelfree/ddpg.py | 32 +++++++++++++---- tianshou/policy/modelfree/discrete_sac.py | 9 ++++- tianshou/policy/modelfree/redq.py | 9 ++++- tianshou/policy/modelfree/sac.py | 9 ++++- tianshou/policy/modelfree/td3.py | 18 ++++++++-- 9 files changed, 111 insertions(+), 35 deletions(-) diff --git a/tianshou/highlevel/params/policy_params.py b/tianshou/highlevel/params/policy_params.py index 4b805dd4b..41172bbd7 100644 --- a/tianshou/highlevel/params/policy_params.py +++ b/tianshou/highlevel/params/policy_params.py @@ -275,6 +275,21 @@ class ParamsMixinGamma: """ +@dataclass(kw_only=True) +class ParamsMixinTau: + tau: float = 0.005 + """ + the soft update coefficient for target networks, 
controlling the rate at which + target networks track the learned networks. + When the parameters of the target network are updated with the current (source) network's + parameters, a weighted average is used: target = tau * source + (1 - tau) * target. + Smaller values (closer to 0) create more stable but slower learning as target networks + change more gradually. Higher values (closer to 1) allow faster learning but may reduce + stability. + Typically set to a small value (0.001 to 0.01) for most reinforcement learning tasks. + """ + + @dataclass(kw_only=True) class PGParams(Params, ParamsMixinGamma, ParamsMixinActionScaling, ParamsMixinSingleModel): reward_normalization: bool = False @@ -441,14 +456,12 @@ def _get_param_transformers(self) -> list[ParamTransformer]: @dataclass(kw_only=True) class _SACParams( - Params, ParamsMixinGamma, ParamsMixinActorAndDualCritics, ParamsMixinEstimationStep + Params, + ParamsMixinGamma, + ParamsMixinActorAndDualCritics, + ParamsMixinEstimationStep, + ParamsMixinTau, ): - tau: float = 0.005 - """controls the contribution of the entropy term in the overall optimization objective, - i.e. the desired amount of randomness in the optimal policy. - Higher values mean greater target entropy and therefore more randomness in the policy. - Lower values mean lower target entropy and therefore a more deterministic policy. - """ alpha: float | AutoAlphaFactory = 0.2 """ controls the relative importance (coefficient) of the entropy term in the loss function. @@ -560,14 +573,8 @@ class DDPGParams( ParamsMixinExplorationNoise, ParamsMixinActionScaling, ParamsMixinEstimationStep, + ParamsMixinTau, ): - tau: float = 0.005 - """ - controls the soft update of the target network. - It determines how slowly the target networks track the main networks. - Smaller tau means slower tracking and more stable learning. 
- """ - def _get_param_transformers(self) -> list[ParamTransformer]: transformers = super()._get_param_transformers() transformers.extend(ParamsMixinActorAndCritic._get_param_transformers(self)) @@ -613,13 +620,8 @@ class TD3Params( ParamsMixinExplorationNoise, ParamsMixinActionScaling, ParamsMixinEstimationStep, + ParamsMixinTau, ): - tau: float = 0.005 - """ - controls the soft update of the target network. - It determines how slowly the target networks track the main networks. - Smaller tau means slower tracking and more stable learning. - """ policy_noise: float | FloatEnvValueFactory = 0.2 """the scale of the the noise used in updating policy network""" noise_clip: float | FloatEnvValueFactory = 0.5 diff --git a/tianshou/policy/imitation/bcq.py b/tianshou/policy/imitation/bcq.py index 3983deabf..b5a67be48 100644 --- a/tianshou/policy/imitation/bcq.py +++ b/tianshou/policy/imitation/bcq.py @@ -130,7 +130,14 @@ def __init__( potentially improving performance in tasks where delayed rewards are important but increasing training variance by incorporating more environmental stochasticity. Typically set between 0.9 and 0.99 for most reinforcement learning tasks - :param tau: param for soft update of the target network. + :param tau: the soft update coefficient for target networks, controlling the rate at which + target networks track the learned networks. + When the parameters of the target network are updated with the current (source) network's + parameters, a weighted average is used: target = tau * source + (1 - tau) * target. + Smaller values (closer to 0) create more stable but slower learning as target networks + change more gradually. Higher values (closer to 1) allow faster learning but may reduce + stability. + Typically set to a small value (0.001 to 0.01) for most reinforcement learning tasks. :param lmbda: param for Clipped Double Q-learning. :param num_sampled_action: the number of sampled actions in calculating target Q. 
The algorithm samples several actions using VAE, and perturbs each action to get the target Q. diff --git a/tianshou/policy/imitation/cql.py b/tianshou/policy/imitation/cql.py index 1de24d1fe..aba5b8632 100644 --- a/tianshou/policy/imitation/cql.py +++ b/tianshou/policy/imitation/cql.py @@ -73,7 +73,14 @@ def __init__( If None, clone the first critic's optimizer factory. :param cql_alpha_lr: The learning rate of cql_log_alpha. :param cql_weight: - :param tau: Parameter for soft update of the target network. + :param tau: the soft update coefficient for target networks, controlling the rate at which + target networks track the learned networks. + When the parameters of the target network are updated with the current (source) network's + parameters, a weighted average is used: target = tau * source + (1 - tau) * target. + Smaller values (closer to 0) create more stable but slower learning as target networks + change more gradually. Higher values (closer to 1) allow faster learning but may reduce + stability. + Typically set to a small value (0.001 to 0.01) for most reinforcement learning tasks. :param gamma: the discount factor in [0, 1] for future rewards. This determines how much future rewards are valued compared to immediate ones. Lower values (closer to 0) make the agent focus on immediate rewards, creating "myopic" diff --git a/tianshou/policy/imitation/td3_bc.py b/tianshou/policy/imitation/td3_bc.py index 19ac386d1..048a4b409 100644 --- a/tianshou/policy/imitation/td3_bc.py +++ b/tianshou/policy/imitation/td3_bc.py @@ -52,7 +52,14 @@ def __init__( If None, copy the first critic (via deepcopy). :param critic2_optim: the optimizer factory for the second critic network. If None, use the first critic's factory. - :param tau: param for soft update of the target network. + :param tau: the soft update coefficient for target networks, controlling the rate at which + target networks track the learned networks. 
+ When the parameters of the target network are updated with the current (source) network's + parameters, a weighted average is used: target = tau * source + (1 - tau) * target. + Smaller values (closer to 0) create more stable but slower learning as target networks + change more gradually. Higher values (closer to 1) allow faster learning but may reduce + stability. + Typically set to a small value (0.001 to 0.01) for most reinforcement learning tasks. :param gamma: the discount factor in [0, 1] for future rewards. This determines how much future rewards are valued compared to immediate ones. Lower values (closer to 0) make the agent focus on immediate rewards, creating "myopic" diff --git a/tianshou/policy/modelfree/ddpg.py b/tianshou/policy/modelfree/ddpg.py index f9d4a8a02..7a1102c1f 100644 --- a/tianshou/policy/modelfree/ddpg.py +++ b/tianshou/policy/modelfree/ddpg.py @@ -108,8 +108,15 @@ def __init__( This is useful when solving "hard exploration" problems. "default" is equivalent to GaussianNoise(sigma=0.1). :param action_space: Env's action space. - :param tau: Param for soft update of the target network. - :param observation_space: Env's observation space. + :param tau: the soft update coefficient for target networks, controlling the rate at which + target networks track the learned networks. + When the parameters of the target network are updated with the current (source) network's + parameters, a weighted average is used: target = tau * source + (1 - tau) * target. + Smaller values (closer to 0) create more stable but slower learning as target networks + change more gradually. Higher values (closer to 1) allow faster learning but may reduce + stability. + Typically set to a small value (0.001 to 0.01) for most reinforcement learning tasks. + :param observation_space: the environment's observation space. :param action_scaling: if True, scale the action from [-1, 1] to the range of action_space. Only used if the action_space is continuous. 
:param action_bound_method: method to bound action to range [-1, 1]. @@ -200,7 +207,14 @@ def __init__( NOTE: The default implementation of `_target_q_compute_value` assumes a continuous action space; override this method if using discrete actions. :param critic_optim: the optimizer factory for the critic network. - :param tau: param for soft update of the target network. + :param tau: the soft update coefficient for target networks, controlling the rate at which + target networks track the learned networks. + When the parameters of the target network are updated with the current (source) network's + parameters, a weighted average is used: target = tau * source + (1 - tau) * target. + Smaller values (closer to 0) create more stable but slower learning as target networks + change more gradually. Higher values (closer to 1) allow faster learning but may reduce + stability. + Typically set to a small value (0.001 to 0.01) for most reinforcement learning tasks. :param gamma: the discount factor in [0, 1] for future rewards. This determines how much future rewards are valued compared to immediate ones. Lower values (closer to 0) make the agent focus on immediate rewards, creating "myopic" @@ -208,8 +222,6 @@ def __init__( potentially improving performance in tasks where delayed rewards are important but increasing training variance by incorporating more environmental stochasticity. Typically set between 0.9 and 0.99 for most reinforcement learning tasks - :param lr_scheduler: a learning rate scheduler that adjusts the learning rate - in optimizer in each policy.update() """ assert 0.0 <= tau <= 1.0, f"tau should be in [0, 1] but got: {tau}" assert 0.0 <= gamma <= 1.0, f"gamma should be in [0, 1] but got: {gamma}" @@ -322,7 +334,14 @@ def __init__( :param policy_optim: the optimizer factory for the policy's model. :param critic: the critic network. (s, a -> Q(s, a)) :param critic_optim: the optimizer factory for the critic network. 
- :param tau: Param for soft update of the target network. + :param tau: the soft update coefficient for target networks, controlling the rate at which + target networks track the learned networks. + When the parameters of the target network are updated with the current (source) network's + parameters, a weighted average is used: target = tau * source + (1 - tau) * target. + Smaller values (closer to 0) create more stable but slower learning as target networks + change more gradually. Higher values (closer to 1) allow faster learning but may reduce + stability. + Typically set to a small value (0.001 to 0.01) for most reinforcement learning tasks. :param gamma: the discount factor in [0, 1] for future rewards. This determines how much future rewards are valued compared to immediate ones. Lower values (closer to 0) make the agent focus on immediate rewards, creating "myopic" @@ -337,7 +356,6 @@ def __init__( the averaging effect). A value of 1 corresponds to standard TD learning with immediate bootstrapping, while very large values approach Monte Carlo-like estimation that uses complete episode returns. - :param lr_scheduler: if not None, will be called in `policy.update()`. """ super().__init__( policy=policy, diff --git a/tianshou/policy/modelfree/discrete_sac.py b/tianshou/policy/modelfree/discrete_sac.py index edeccaf94..887bbbdae 100644 --- a/tianshou/policy/modelfree/discrete_sac.py +++ b/tianshou/policy/modelfree/discrete_sac.py @@ -101,7 +101,14 @@ def __init__( If None, copy the first critic (via deepcopy). :param critic2_optim: the optimizer factory for the second critic network. If None, use the first critic's factory. - :param tau: param for soft update of the target network. + :param tau: the soft update coefficient for target networks, controlling the rate at which + target networks track the learned networks. 
+ When the parameters of the target network are updated with the current (source) network's + parameters, a weighted average is used: target = tau * source + (1 - tau) * target. + Smaller values (closer to 0) create more stable but slower learning as target networks + change more gradually. Higher values (closer to 1) allow faster learning but may reduce + stability. + Typically set to a small value (0.001 to 0.01) for most reinforcement learning tasks. :param gamma: the discount factor in [0, 1] for future rewards. This determines how much future rewards are valued compared to immediate ones. Lower values (closer to 0) make the agent focus on immediate rewards, creating "myopic" diff --git a/tianshou/policy/modelfree/redq.py b/tianshou/policy/modelfree/redq.py index 61f6b4b95..72a62eff1 100644 --- a/tianshou/policy/modelfree/redq.py +++ b/tianshou/policy/modelfree/redq.py @@ -129,7 +129,14 @@ def __init__( :param critic_optim: the optimizer factory for the critic network. :param ensemble_size: the number of sub-networks in the critic ensemble. :param subset_size: the number of networks in the subset. - :param tau: Param for soft update of the target network. + :param tau: the soft update coefficient for target networks, controlling the rate at which + target networks track the learned networks. + When the parameters of the target network are updated with the current (source) network's + parameters, a weighted average is used: target = tau * source + (1 - tau) * target. + Smaller values (closer to 0) create more stable but slower learning as target networks + change more gradually. Higher values (closer to 1) allow faster learning but may reduce + stability. + Typically set to a small value (0.001 to 0.01) for most reinforcement learning tasks. :param gamma: the discount factor in [0, 1] for future rewards. This determines how much future rewards are valued compared to immediate ones. 
Lower values (closer to 0) make the agent focus on immediate rewards, creating "myopic" diff --git a/tianshou/policy/modelfree/sac.py b/tianshou/policy/modelfree/sac.py index 49dc944b3..289b35931 100644 --- a/tianshou/policy/modelfree/sac.py +++ b/tianshou/policy/modelfree/sac.py @@ -226,7 +226,14 @@ def __init__( If None, copy the first critic (via deepcopy). :param critic2_optim: the optimizer factory for the second critic network. If None, use the first critic's factory. - :param tau: param for soft update of the target network. + :param tau: the soft update coefficient for target networks, controlling the rate at which + target networks track the learned networks. + When the parameters of the target network are updated with the current (source) network's + parameters, a weighted average is used: target = tau * source + (1 - tau) * target. + Smaller values (closer to 0) create more stable but slower learning as target networks + change more gradually. Higher values (closer to 1) allow faster learning but may reduce + stability. + Typically set to a small value (0.001 to 0.01) for most reinforcement learning tasks. :param gamma: the discount factor in [0, 1] for future rewards. This determines how much future rewards are valued compared to immediate ones. Lower values (closer to 0) make the agent focus on immediate rewards, creating "myopic" diff --git a/tianshou/policy/modelfree/td3.py b/tianshou/policy/modelfree/td3.py index 566eef0be..274bdf110 100644 --- a/tianshou/policy/modelfree/td3.py +++ b/tianshou/policy/modelfree/td3.py @@ -67,7 +67,14 @@ def __init__( If None, copy the first critic (via deepcopy). :param critic2_optim: the optimizer factory for the second critic network. If None, use the first critic's factory. - :param tau: param for soft update of the target network. + :param tau: the soft update coefficient for target networks, controlling the rate at which + target networks track the learned networks. 
+ When the parameters of the target network are updated with the current (source) network's + parameters, a weighted average is used: target = tau * source + (1 - tau) * target. + Smaller values (closer to 0) create more stable but slower learning as target networks + change more gradually. Higher values (closer to 1) allow faster learning but may reduce + stability. + Typically set to a small value (0.001 to 0.01) for most reinforcement learning tasks. :param gamma: the discount factor in [0, 1] for future rewards. This determines how much future rewards are valued compared to immediate ones. Lower values (closer to 0) make the agent focus on immediate rewards, creating "myopic" @@ -133,7 +140,14 @@ def __init__( If None, copy the first critic (via deepcopy). :param critic2_optim: the optimizer factory for the second critic network. If None, use the first critic's factory. - :param tau: param for soft update of the target network. + :param tau: the soft update coefficient for target networks, controlling the rate at which + target networks track the learned networks. + When the parameters of the target network are updated with the current (source) network's + parameters, a weighted average is used: target = tau * source + (1 - tau) * target. + Smaller values (closer to 0) create more stable but slower learning as target networks + change more gradually. Higher values (closer to 1) allow faster learning but may reduce + stability. + Typically set to a small value (0.001 to 0.01) for most reinforcement learning tasks. :param gamma: the discount factor in [0, 1] for future rewards. This determines how much future rewards are valued compared to immediate ones. 
Lower values (closer to 0) make the agent focus on immediate rewards, creating "myopic" From 7e4d696509b443573916033943b1dc5fdc7f3f95 Mon Sep 17 00:00:00 2001 From: Dominik Jain Date: Sun, 4 May 2025 15:58:33 +0200 Subject: [PATCH 093/230] v2: Update references to parameters --- tianshou/trainer/base.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tianshou/trainer/base.py b/tianshou/trainer/base.py index d042ad15f..6c995e645 100644 --- a/tianshou/trainer/base.py +++ b/tianshou/trainer/base.py @@ -4,13 +4,13 @@ Training is structured as follows (hierarchical glossary): - **epoch**: The outermost iteration level of the training loop. Each epoch consists of a number of training steps - and one test step (see :attr:`TrainingConfig.max_epoch` for a detailed explanation): + and one test step (see :attr:`TrainerParams.max_epoch` for a detailed explanation): - **training step**: A training step performs the steps necessary in order to apply a single update of the neural network components as defined by the underlying RL algorithm (:class:`Algorithm`). This involves the following sub-steps: - for online learning algorithms: - **collection step**: collecting environment steps/transitions to be used for training. - (potentially) a test step (see below) if the early stopping criterion is satisfied based on - the data collected (see :attr:`OnlineTrainingConfig.test_in_train`). + the data collected (see :attr:`OnlineTrainerParams.test_in_train`). - **update step**: applying the actual gradient updates using the RL algorithm. The update is based on either ... - data from only the preceding collection step (on-policy learning), @@ -19,7 +19,7 @@ For offline learning algorithms, a training step is thus equivalent to an update step. - **test step**: Collects test episodes from dedicated test environments which are used to evaluate the performance of the policy. 
Optionally, the performance result can be used to determine whether training shall stop early - (see :attr:`TrainingConfig.stop_fn`). + (see :attr:`TrainerParams.stop_fn`). """ import logging import time @@ -473,7 +473,7 @@ class _TrainingStepResult(ABC): def get_steps_in_epoch_advancement(self) -> int: """ :return: the number of steps that were done within the epoch, where the concrete semantics - of what a step is depend on the type of algorith. See docstring of `TrainingConfig.step_per_epoch`. + of what a step is depend on the type of algorithm. See docstring of `TrainerParams.step_per_epoch`. """ @abstractmethod From 186178f3c22a68b2d44144294ef15ae52ff75ff7 Mon Sep 17 00:00:00 2001 From: Dominik Jain Date: Sun, 4 May 2025 15:58:52 +0200 Subject: [PATCH 094/230] v2: Add/improve docstrings of algorithm base classes --- tianshou/policy/base.py | 110 ++++++++++++++++++++++------------------ 1 file changed, 61 insertions(+), 49 deletions(-) diff --git a/tianshou/policy/base.py b/tianshou/policy/base.py index 38574b2b7..993d591a6 100644 --- a/tianshou/policy/base.py +++ b/tianshou/policy/base.py @@ -151,6 +151,8 @@ def __setattr__(self, name: str, value: Any) -> None: class Policy(nn.Module, ABC): + """Represents a policy, which provides the fundamental mapping from observations to actions.""" + def __init__( self, action_space: gym.Space, @@ -387,6 +389,11 @@ def add_exploration_noise( class LaggedNetworkAlgorithmMixin(ABC): + """ + Base class for an algorithm mixin which adds support for lagged networks (target networks) whose weights + are updated periodically. 
+ """ + def __init__(self) -> None: self._lagged_networks = LaggedNetworkCollection() @@ -408,12 +415,27 @@ def _update_lagged_network_weights(self) -> None: class LaggedNetworkFullUpdateAlgorithmMixin(LaggedNetworkAlgorithmMixin): + """ + Algorithm mixin which adds support for lagged networks (target networks) where weights + are updated by fully copying the weights of the source network to the target network. + """ + def _update_lagged_network_weights(self) -> None: self._lagged_networks.full_parameter_update() class LaggedNetworkPolyakUpdateAlgorithmMixin(LaggedNetworkAlgorithmMixin): + """ + Algorithm mixin which adds support for lagged networks (target networks) where weights + are updated via Polyak averaging (soft update using a convex combination of the parameters + of the source and target networks with weight `tau` and `1-tau` respectively). + """ + def __init__(self, tau: float) -> None: + """ + :param tau: the fraction with which to use the source network's parameters, the complement `1-tau` being + the fraction with which to retain the target network's parameters. + """ super().__init__() self.tau = tau @@ -427,51 +449,11 @@ def _update_lagged_network_weights(self) -> None: class Algorithm(torch.nn.Module, Generic[TPolicy, TTrainerParams, TTrainingStats], ABC): """ - TODO fix docstring - The base class for any RL policy. - - Tianshou aims to modularize RL algorithms. It comes into several classes of - policies in Tianshou. All policy classes must inherit from - :class:`~tianshou.policy.BasePolicy`.
- - A policy class typically has the following parts: - - * :meth:`~tianshou.policy.BasePolicy.__init__`: initialize the policy, including \ - coping the target network and so on; - * :meth:`~tianshou.policy.BasePolicy.forward`: compute action with given \ - observation; - * :meth:`~tianshou.policy.BasePolicy.process_fn`: pre-process data from the \ - replay buffer (this function can interact with replay buffer); - * :meth:`~tianshou.policy.BasePolicy.learn`: update policy with a given batch of \ - data. - * :meth:`~tianshou.policy.BasePolicy.post_process_fn`: update the replay buffer \ - from the learning process (e.g., prioritized replay buffer needs to update \ - the weight); - * :meth:`~tianshou.policy.BasePolicy.update`: the main interface for training, \ - i.e., `process_fn -> learn -> post_process_fn`. - - Most of the policy needs a neural network to predict the action and an - optimizer to optimize the policy. The rules of self-defined networks are: - - 1. Input: observation "obs" (may be a ``numpy.ndarray``, a ``torch.Tensor``, a \ - dict or any others), hidden state "state" (for RNN usage), and other information \ - "info" provided by the environment. - 2. Output: some "logits", the next hidden state "state", and the intermediate \ - result during policy forwarding procedure "policy". The "logits" could be a tuple \ - instead of a ``torch.Tensor``. It depends on how the policy process the network \ - output. For example, in PPO, the return of the network might be \ - ``(mu, sigma), state`` for Gaussian policy. The "policy" can be a Batch of \ - torch.Tensor or other things, which will be stored in the replay buffer, and can \ - be accessed in the policy update process (e.g. in "policy.learn()", the \ - "batch.policy" is what you need). 
- - Since :class:`~tianshou.policy.BasePolicy` inherits ``torch.nn.Module``, you can - use :class:`~tianshou.policy.BasePolicy` almost the same as ``torch.nn.Module``, - for instance, loading and saving the model: - :: - - torch.save(policy.state_dict(), "policy.pth") - policy.load_state_dict(torch.load("policy.pth")) + The base class for reinforcement learning algorithms in Tianshou. + + An algorithm critically defines how to update the parameters of neural networks + based on a batch of data, optionally applying pre-processing and post-processing to the data. + The actual update step is highly algorithm-specific and thus is defined in subclasses. """ _STATE_DICT_KEY_OPTIMIZERS = "_optimizers" @@ -613,7 +595,7 @@ def _update( buffer: ReplayBuffer | None, update_with_batch_fn: Callable[[RolloutBatchProtocol], TTrainingStats], ) -> TTrainingStats: - """Performs an update step. + """Orchestrates an update step. An update involves three algorithm-specific sub-steps: * pre-processing of the batch, @@ -628,8 +610,10 @@ def _update( means it will extract all the data from the buffer, but it will be shuffled first. :param buffer: the corresponding replay buffer. + :param update_with_batch_fn: the function to call for the actual update step, + which is algorithm-specific and thus provided by the subclass.
+ - :return: A dataclass object containing the data needed to be logged (e.g., loss) + :return: A dataclass object containing data to be logged (e.g., loss) """ if not self.policy.is_within_training_step: raise RuntimeError( @@ -846,6 +830,8 @@ class OnPolicyAlgorithm( Generic[TPolicy, TTrainingStats], ABC, ): + """Base class for on-policy RL algorithms.""" + def create_trainer(self, params: "OnPolicyTrainerParams") -> "OnPolicyTrainer": from tianshou.trainer.base import OnPolicyTrainer @@ -855,7 +841,15 @@ def create_trainer(self, params: "OnPolicyTrainerParams") -> "OnPolicyTrainer": def _update_with_batch( self, batch: RolloutBatchProtocol, batch_size: int | None, repeat: int ) -> TTrainingStats: - pass + """Performs an update step based on the given batch of data, updating the network + parameters. + + :param batch: the batch of data + :param batch_size: the minibatch size for gradient updates + :param repeat: the number of times to repeat the update over the whole batch + :return: a dataclass object containing statistics on the learning process, including + the data needed to be logged (e.g. loss values). + """ def update( self, @@ -876,6 +870,8 @@ class OffPolicyAlgorithm( Generic[TPolicy, TTrainingStats], ABC, ): + """Base class for off-policy RL algorithms.""" + def create_trainer(self, params: "OffPolicyTrainerParams") -> "OffPolicyTrainer": from tianshou.trainer.base import OffPolicyTrainer @@ -910,6 +906,8 @@ class OfflineAlgorithm( Generic[TPolicy, TTrainingStats], ABC, ): + """Base class for offline RL algorithms.""" + def process_buffer(self, buffer: TBuffer) -> TBuffer: """Pre-process the replay buffer to prepare for offline learning, e.g. to add new keys.""" return buffer @@ -957,6 +955,13 @@ class OnPolicyWrapperAlgorithm( Generic[TPolicy, TTrainingStats, TWrappedAlgorthmTrainingStats], ABC, ): + """ + Base class for an on-policy algorithm that is a wrapper around another algorithm.
+ + It applies the wrapped algorithm's pre-processing and post-processing methods + and chains the update method of the wrapped algorithm with the wrapper's own update method. + """ + def __init__( self, wrapped_algorithm: OnPolicyAlgorithm[TPolicy, TWrappedAlgorthmTrainingStats], @@ -985,7 +990,7 @@ def postprocess_batch( def _update_with_batch( self, batch: RolloutBatchProtocol, batch_size: int | None, repeat: int ) -> TTrainingStats: - """Performs the update as defined by the wrapped algorithm, followed by the wrapper's update .""" + """Performs the update as defined by the wrapped algorithm, followed by the wrapper's update.""" original_stats = self.wrapped_algorithm._update_with_batch( batch, batch_size=batch_size, repeat=repeat ) @@ -1007,6 +1012,13 @@ class OffPolicyWrapperAlgorithm( Generic[TPolicy, TTrainingStats, TWrappedAlgorthmTrainingStats], ABC, ): + """ + Base class for an off-policy algorithm that is a wrapper around another algorithm. + + It applies the wrapped algorithm's pre-processing and post-processing methods + and chains the update method of the wrapped algorithm with the wrapper's own update method. 
+ """ + def __init__( self, wrapped_algorithm: OffPolicyAlgorithm[TPolicy, TWrappedAlgorthmTrainingStats], From 793297a0782b64d401da192b2d18ff82dd227a09 Mon Sep 17 00:00:00 2001 From: Dominik Jain Date: Sun, 4 May 2025 18:46:21 +0200 Subject: [PATCH 095/230] v2: Improve description of parameter 'gae_lambda' --- tianshou/highlevel/params/policy_params.py | 14 +++++++++--- tianshou/policy/base.py | 25 +++++++++++++++++++--- tianshou/policy/imitation/gail.py | 12 ++++++++++- tianshou/policy/modelfree/a2c.py | 24 +++++++++++++++++++-- tianshou/policy/modelfree/npg.py | 12 ++++++++++- tianshou/policy/modelfree/ppo.py | 12 ++++++++++- tianshou/policy/modelfree/trpo.py | 12 ++++++++++- 7 files changed, 99 insertions(+), 12 deletions(-) diff --git a/tianshou/highlevel/params/policy_params.py b/tianshou/highlevel/params/policy_params.py index 41172bbd7..057fedab2 100644 --- a/tianshou/highlevel/params/policy_params.py +++ b/tianshou/highlevel/params/policy_params.py @@ -317,9 +317,17 @@ def _get_param_transformers(self) -> list[ParamTransformer]: class ParamsMixinGeneralAdvantageEstimation(GetParamTransformersProtocol): gae_lambda: float = 0.95 """ - determines the blend between Monte Carlo and one-step temporal difference (TD) estimates of the advantage - function in general advantage estimation (GAE). - A value of 0 gives a fully TD-based estimate; lambda=1 gives a fully Monte Carlo estimate. + the lambda parameter in [0, 1] for generalized advantage estimation (GAE). + Controls the bias-variance tradeoff in advantage estimates, acting as a + weighting factor for combining different n-step advantage estimators. Higher values + (closer to 1) reduce bias but increase variance by giving more weight to longer + trajectories, while lower values (closer to 0) reduce variance but increase bias + by relying more on the immediate TD error and value function estimates. 
At λ=0, + GAE becomes equivalent to the one-step TD error (high bias, low variance); at λ=1, + it becomes equivalent to Monte Carlo advantage estimation (low bias, high variance). + Intermediate values create a weighted average of n-step returns, with exponentially + decaying weights for longer-horizon returns. Typically set between 0.9 and 0.99 for + most policy gradient methods. """ max_batchsize: int = 256 """the maximum size of the batch when computing general advantage estimation (GAE)""" diff --git a/tianshou/policy/base.py b/tianshou/policy/base.py index 993d591a6..5e93580bf 100644 --- a/tianshou/policy/base.py +++ b/tianshou/policy/base.py @@ -693,8 +693,17 @@ def compute_episodic_return( potentially improving performance in tasks where delayed rewards are important but increasing training variance by incorporating more environmental stochasticity. Typically set between 0.9 and 0.99 for most reinforcement learning tasks - :param gae_lambda: the parameter for Generalized Advantage Estimation, - should be in [0, 1]. + :param gae_lambda: the lambda parameter in [0, 1] for generalized advantage estimation (GAE). + Controls the bias-variance tradeoff in advantage estimates, acting as a + weighting factor for combining different n-step advantage estimators. Higher values + (closer to 1) reduce bias but increase variance by giving more weight to longer + trajectories, while lower values (closer to 0) reduce variance but increase bias + by relying more on the immediate TD error and value function estimates. At λ=0, + GAE becomes equivalent to the one-step TD error (high bias, low variance); at λ=1, + it becomes equivalent to Monte Carlo advantage estimation (low bias, high variance). + Intermediate values create a weighted average of n-step returns, with exponentially + decaying weights for longer-horizon returns. Typically set between 0.9 and 0.99 for + most policy gradient methods. :return: two numpy arrays (returns, advantage) with each shape (bsz, ). 
""" @@ -1118,7 +1127,17 @@ def _gae_return( :param rew: rewards in an episode, i.e. $r_t$ :param end_flag: boolean array indicating whether the episode is done :param gamma: the discount factor in [0, 1] for future rewards. - :param gae_lambda: lambda parameter for GAE, controlling the bias-variance tradeoff + :param gae_lambda: the lambda parameter in [0, 1] for generalized advantage estimation (GAE). + Controls the bias-variance tradeoff in advantage estimates, acting as a + weighting factor for combining different n-step advantage estimators. Higher values + (closer to 1) reduce bias but increase variance by giving more weight to longer + trajectories, while lower values (closer to 0) reduce variance but increase bias + by relying more on the immediate TD error and value function estimates. At λ=0, + GAE becomes equivalent to the one-step TD error (high bias, low variance); at λ=1, + it becomes equivalent to Monte Carlo advantage estimation (low bias, high variance). + Intermediate values create a weighted average of n-step returns, with exponentially + decaying weights for longer-horizon returns. Typically set between 0.9 and 0.99 for + most policy gradient methods. :return: """ returns = np.zeros(rew.shape) diff --git a/tianshou/policy/imitation/gail.py b/tianshou/policy/imitation/gail.py index 1c5be16c6..75a75f16b 100644 --- a/tianshou/policy/imitation/gail.py +++ b/tianshou/policy/imitation/gail.py @@ -82,7 +82,17 @@ def __init__( :param vf_coef: weight for value loss. :param ent_coef: weight for entropy loss. :param max_grad_norm: clipping gradients in back propagation. - :param gae_lambda: in [0, 1], param for Generalized Advantage Estimation. + :param gae_lambda: the lambda parameter in [0, 1] for generalized advantage estimation (GAE). + Controls the bias-variance tradeoff in advantage estimates, acting as a + weighting factor for combining different n-step advantage estimators. 
Higher values + (closer to 1) reduce bias but increase variance by giving more weight to longer + trajectories, while lower values (closer to 0) reduce variance but increase bias + by relying more on the immediate TD error and value function estimates. At λ=0, + GAE becomes equivalent to the one-step TD error (high bias, low variance); at λ=1, + it becomes equivalent to Monte Carlo advantage estimation (low bias, high variance). + Intermediate values create a weighted average of n-step returns, with exponentially + decaying weights for longer-horizon returns. Typically set between 0.9 and 0.99 for + most policy gradient methods. :param max_batchsize: the maximum size of the batch when computing GAE. :param gamma: the discount factor in [0, 1] for future rewards. This determines how much future rewards are valued compared to immediate ones. diff --git a/tianshou/policy/modelfree/a2c.py b/tianshou/policy/modelfree/a2c.py index 2c12e3e6f..560c5f6e1 100644 --- a/tianshou/policy/modelfree/a2c.py +++ b/tianshou/policy/modelfree/a2c.py @@ -57,7 +57,17 @@ def __init__( Pass False for algorithms that shall update only the critic via the optimizer. :param max_grad_norm: the maximum gradient norm for gradient clipping; if None, gradient clipping is not applied - :param gae_lambda: in [0, 1], param for generalized advantage estimation (GAE). + :param gae_lambda: the lambda parameter in [0, 1] for generalized advantage estimation (GAE). + Controls the bias-variance tradeoff in advantage estimates, acting as a + weighting factor for combining different n-step advantage estimators. Higher values + (closer to 1) reduce bias but increase variance by giving more weight to longer + trajectories, while lower values (closer to 0) reduce variance but increase bias + by relying more on the immediate TD error and value function estimates. 
At λ=0, + GAE becomes equivalent to the one-step TD error (high bias, low variance); at λ=1, + it becomes equivalent to Monte Carlo advantage estimation (low bias, high variance). + Intermediate values create a weighted average of n-step returns, with exponentially + decaying weights for longer-horizon returns. Typically set between 0.9 and 0.99 for + most policy gradient methods. :param max_batchsize: the maximum size of the batch when computing GAE. :param gamma: the discount factor in [0, 1] for future rewards. This determines how much future rewards are valued compared to immediate ones. @@ -152,7 +162,17 @@ def __init__( :param vf_coef: weight for value loss. :param ent_coef: weight for entropy loss. :param max_grad_norm: clipping gradients in back propagation. - :param gae_lambda: in [0, 1], param for Generalized Advantage Estimation. + :param gae_lambda: the lambda parameter in [0, 1] for generalized advantage estimation (GAE). + Controls the bias-variance tradeoff in advantage estimates, acting as a + weighting factor for combining different n-step advantage estimators. Higher values + (closer to 1) reduce bias but increase variance by giving more weight to longer + trajectories, while lower values (closer to 0) reduce variance but increase bias + by relying more on the immediate TD error and value function estimates. At λ=0, + GAE becomes equivalent to the one-step TD error (high bias, low variance); at λ=1, + it becomes equivalent to Monte Carlo advantage estimation (low bias, high variance). + Intermediate values create a weighted average of n-step returns, with exponentially + decaying weights for longer-horizon returns. Typically set between 0.9 and 0.99 for + most policy gradient methods. :param max_batchsize: the maximum size of the batch when computing GAE. :param gamma: the discount factor in [0, 1] for future rewards. This determines how much future rewards are valued compared to immediate ones. 
diff --git a/tianshou/policy/modelfree/npg.py b/tianshou/policy/modelfree/npg.py index d83410e2d..335698953 100644 --- a/tianshou/policy/modelfree/npg.py +++ b/tianshou/policy/modelfree/npg.py @@ -56,7 +56,17 @@ def __init__( :param actor_step_size: step size for actor update in natural gradient direction. :param advantage_normalization: whether to do per mini-batch advantage normalization. - :param gae_lambda: in [0, 1], param for Generalized Advantage Estimation. + :param gae_lambda: the lambda parameter in [0, 1] for generalized advantage estimation (GAE). + Controls the bias-variance tradeoff in advantage estimates, acting as a + weighting factor for combining different n-step advantage estimators. Higher values + (closer to 1) reduce bias but increase variance by giving more weight to longer + trajectories, while lower values (closer to 0) reduce variance but increase bias + by relying more on the immediate TD error and value function estimates. At λ=0, + GAE becomes equivalent to the one-step TD error (high bias, low variance); at λ=1, + it becomes equivalent to Monte Carlo advantage estimation (low bias, high variance). + Intermediate values create a weighted average of n-step returns, with exponentially + decaying weights for longer-horizon returns. Typically set between 0.9 and 0.99 for + most policy gradient methods. :param max_batchsize: the maximum size of the batch when computing GAE. :param gamma: the discount factor in [0, 1] for future rewards. This determines how much future rewards are valued compared to immediate ones. diff --git a/tianshou/policy/modelfree/ppo.py b/tianshou/policy/modelfree/ppo.py index dc602bd56..b83f9215a 100644 --- a/tianshou/policy/modelfree/ppo.py +++ b/tianshou/policy/modelfree/ppo.py @@ -92,7 +92,17 @@ def __init__( :param vf_coef: weight for value loss. :param ent_coef: weight for entropy loss. :param max_grad_norm: clipping gradients in back propagation. 
- :param gae_lambda: in [0, 1], param for Generalized Advantage Estimation. + :param gae_lambda: the lambda parameter in [0, 1] for generalized advantage estimation (GAE). + Controls the bias-variance tradeoff in advantage estimates, acting as a + weighting factor for combining different n-step advantage estimators. Higher values + (closer to 1) reduce bias but increase variance by giving more weight to longer + trajectories, while lower values (closer to 0) reduce variance but increase bias + by relying more on the immediate TD error and value function estimates. At λ=0, + GAE becomes equivalent to the one-step TD error (high bias, low variance); at λ=1, + it becomes equivalent to Monte Carlo advantage estimation (low bias, high variance). + Intermediate values create a weighted average of n-step returns, with exponentially + decaying weights for longer-horizon returns. Typically set between 0.9 and 0.99 for + most policy gradient methods. :param max_batchsize: the maximum size of the batch when computing GAE. :param gamma: the discount factor in [0, 1] for future rewards. This determines how much future rewards are valued compared to immediate ones. diff --git a/tianshou/policy/modelfree/trpo.py b/tianshou/policy/modelfree/trpo.py index e3667857f..e3f83e2d5 100644 --- a/tianshou/policy/modelfree/trpo.py +++ b/tianshou/policy/modelfree/trpo.py @@ -56,7 +56,17 @@ def __init__( :param actor_step_size: step size for actor update in natural gradient direction. :param advantage_normalization: whether to do per mini-batch advantage normalization. - :param gae_lambda: in [0, 1], param for Generalized Advantage Estimation. + :param gae_lambda: the lambda parameter in [0, 1] for generalized advantage estimation (GAE). + Controls the bias-variance tradeoff in advantage estimates, acting as a + weighting factor for combining different n-step advantage estimators. 
Higher values + (closer to 1) reduce bias but increase variance by giving more weight to longer + trajectories, while lower values (closer to 0) reduce variance but increase bias + by relying more on the immediate TD error and value function estimates. At λ=0, + GAE becomes equivalent to the one-step TD error (high bias, low variance); at λ=1, + it becomes equivalent to Monte Carlo advantage estimation (low bias, high variance). + Intermediate values create a weighted average of n-step returns, with exponentially + decaying weights for longer-horizon returns. Typically set between 0.9 and 0.99 for + most policy gradient methods. :param max_batchsize: the maximum size of the batch when computing GAE. :param gamma: the discount factor in [0, 1] for future rewards. This determines how much future rewards are valued compared to immediate ones. From 2796c5b2b3a13ce5c7365c8bd30d537ef4f4bc57 Mon Sep 17 00:00:00 2001 From: Dominik Jain Date: Sun, 4 May 2025 18:53:27 +0200 Subject: [PATCH 096/230] v2: Improve description of parameter 'actor_step_size' --- tianshou/highlevel/params/policy_params.py | 10 +++++++++- tianshou/policy/modelfree/npg.py | 8 +++++++- tianshou/policy/modelfree/trpo.py | 8 +++++++- 3 files changed, 23 insertions(+), 3 deletions(-) diff --git a/tianshou/highlevel/params/policy_params.py b/tianshou/highlevel/params/policy_params.py index 057fedab2..d90f88587 100644 --- a/tianshou/highlevel/params/policy_params.py +++ b/tianshou/highlevel/params/policy_params.py @@ -403,7 +403,15 @@ class NPGParams(PGParams, ParamsMixinGeneralAdvantageEstimation): optim_critic_iters: int = 5 """number of times to optimize critic network per update.""" actor_step_size: float = 0.5 - """step size for actor update in natural gradient direction""" + """ + the scalar multiplier for policy updates in the natural gradient direction. + Controls how far the policy parameters move in the calculated direction + during each update. 
Higher values allow for faster learning but may cause instability + or policy deterioration; lower values provide more stable but slower learning. Unlike + regular policy gradients, natural gradients already account for the local geometry of + the parameter space, making this step size more robust to different parameterizations. + Typically set between 0.1 and 1.0 for most reinforcement learning tasks. + """ advantage_normalization: bool = True """whether to do per mini-batch advantage normalization.""" diff --git a/tianshou/policy/modelfree/npg.py b/tianshou/policy/modelfree/npg.py index 335698953..7f2d298a1 100644 --- a/tianshou/policy/modelfree/npg.py +++ b/tianshou/policy/modelfree/npg.py @@ -53,7 +53,13 @@ def __init__( :param critic: the critic network. (s -> V(s)) :param optim: the optimizer factory for the critic network. :param optim_critic_iters: Number of times to optimize critic network per update. - :param actor_step_size: step size for actor update in natural gradient direction. + :param actor_step_size: the scalar multiplier for policy updates in the natural gradient direction. + Controls how far the policy parameters move in the calculated direction + during each update. Higher values allow for faster learning but may cause instability + or policy deterioration; lower values provide more stable but slower learning. Unlike + regular policy gradients, natural gradients already account for the local geometry of + the parameter space, making this step size more robust to different parameterizations. + Typically set between 0.1 and 1.0 for most reinforcement learning tasks. :param advantage_normalization: whether to do per mini-batch advantage normalization. :param gae_lambda: the lambda parameter in [0, 1] for generalized advantage estimation (GAE). 
diff --git a/tianshou/policy/modelfree/trpo.py b/tianshou/policy/modelfree/trpo.py index e3f83e2d5..faf1ca36d 100644 --- a/tianshou/policy/modelfree/trpo.py +++ b/tianshou/policy/modelfree/trpo.py @@ -53,7 +53,13 @@ def __init__( constraints are not met. :param max_backtracks: Max number of backtracking times in linesearch. :param optim_critic_iters: Number of times to optimize critic network per update. - :param actor_step_size: step size for actor update in natural gradient direction. + :param actor_step_size: the scalar multiplier for policy updates in the natural gradient direction. + Controls how far the policy parameters move in the calculated direction + during each update. Higher values allow for faster learning but may cause instability + or policy deterioration; lower values provide more stable but slower learning. Unlike + regular policy gradients, natural gradients already account for the local geometry of + the parameter space, making this step size more robust to different parameterizations. + Typically set between 0.1 and 1.0 for most reinforcement learning tasks. :param advantage_normalization: whether to do per mini-batch advantage normalization. :param gae_lambda: the lambda parameter in [0, 1] for generalized advantage estimation (GAE). 
From 6a3646baa2cd75c3396711828aac1fb67d978860 Mon Sep 17 00:00:00 2001 From: Dominik Jain Date: Sun, 4 May 2025 19:01:15 +0200 Subject: [PATCH 097/230] v2: Improve description of parameter 'max_batchsize' --- tianshou/highlevel/params/policy_params.py | 7 ++++++- tianshou/policy/imitation/gail.py | 7 ++++++- tianshou/policy/modelfree/a2c.py | 7 ++++++- tianshou/policy/modelfree/npg.py | 9 +++++++-- tianshou/policy/modelfree/trpo.py | 7 ++++++- 5 files changed, 31 insertions(+), 6 deletions(-) diff --git a/tianshou/highlevel/params/policy_params.py b/tianshou/highlevel/params/policy_params.py index d90f88587..989b10007 100644 --- a/tianshou/highlevel/params/policy_params.py +++ b/tianshou/highlevel/params/policy_params.py @@ -330,7 +330,12 @@ class ParamsMixinGeneralAdvantageEstimation(GetParamTransformersProtocol): most policy gradient methods. """ max_batchsize: int = 256 - """the maximum size of the batch when computing general advantage estimation (GAE)""" + """the maximum number of samples to process at once when computing + generalized advantage estimation (GAE) and value function predictions. + Controls memory usage by breaking large batches into smaller chunks processed sequentially. + Higher values may increase speed but require more GPU/CPU memory; lower values + reduce memory requirements but may increase computation time. Should be adjusted + based on available hardware resources and total batch size of your training data.""" def _get_param_transformers(self) -> list[ParamTransformer]: return [] diff --git a/tianshou/policy/imitation/gail.py b/tianshou/policy/imitation/gail.py index 75a75f16b..b5b088d0b 100644 --- a/tianshou/policy/imitation/gail.py +++ b/tianshou/policy/imitation/gail.py @@ -93,7 +93,12 @@ def __init__( Intermediate values create a weighted average of n-step returns, with exponentially decaying weights for longer-horizon returns. Typically set between 0.9 and 0.99 for most policy gradient methods. 
- :param max_batchsize: the maximum size of the batch when computing GAE. + :param max_batchsize: the maximum number of samples to process at once when computing + generalized advantage estimation (GAE) and value function predictions. + Controls memory usage by breaking large batches into smaller chunks processed sequentially. + Higher values may increase speed but require more GPU/CPU memory; lower values + reduce memory requirements but may increase computation time. Should be adjusted + based on available hardware resources and total batch size of your training data. :param gamma: the discount factor in [0, 1] for future rewards. This determines how much future rewards are valued compared to immediate ones. Lower values (closer to 0) make the agent focus on immediate rewards, creating "myopic" diff --git a/tianshou/policy/modelfree/a2c.py b/tianshou/policy/modelfree/a2c.py index 560c5f6e1..6ee5a6994 100644 --- a/tianshou/policy/modelfree/a2c.py +++ b/tianshou/policy/modelfree/a2c.py @@ -68,7 +68,12 @@ def __init__( Intermediate values create a weighted average of n-step returns, with exponentially decaying weights for longer-horizon returns. Typically set between 0.9 and 0.99 for most policy gradient methods. - :param max_batchsize: the maximum size of the batch when computing GAE. + :param max_batchsize: the maximum number of samples to process at once when computing + generalized advantage estimation (GAE) and value function predictions. + Controls memory usage by breaking large batches into smaller chunks processed sequentially. + Higher values may increase speed but require more GPU/CPU memory; lower values + reduce memory requirements but may increase computation time. Should be adjusted + based on available hardware resources and total batch size of your training data. :param gamma: the discount factor in [0, 1] for future rewards. This determines how much future rewards are valued compared to immediate ones. 
Lower values (closer to 0) make the agent focus on immediate rewards, creating "myopic" diff --git a/tianshou/policy/modelfree/npg.py b/tianshou/policy/modelfree/npg.py index 7f2d298a1..c79814733 100644 --- a/tianshou/policy/modelfree/npg.py +++ b/tianshou/policy/modelfree/npg.py @@ -52,7 +52,7 @@ def __init__( :param policy: the policy containing the actor network. :param critic: the critic network. (s -> V(s)) :param optim: the optimizer factory for the critic network. - :param optim_critic_iters: Number of times to optimize critic network per update. + :param optim_critic_iters: the number of times to optimize critic network per update. :param actor_step_size: the scalar multiplier for policy updates in the natural gradient direction. Controls how far the policy parameters move in the calculated direction during each update. Higher values allow for faster learning but may cause instability @@ -73,7 +73,12 @@ def __init__( Intermediate values create a weighted average of n-step returns, with exponentially decaying weights for longer-horizon returns. Typically set between 0.9 and 0.99 for most policy gradient methods. - :param max_batchsize: the maximum size of the batch when computing GAE. + :param max_batchsize: the maximum number of samples to process at once when computing + generalized advantage estimation (GAE) and value function predictions. + Controls memory usage by breaking large batches into smaller chunks processed sequentially. + Higher values may increase speed but require more GPU/CPU memory; lower values + reduce memory requirements but may increase computation time. Should be adjusted + based on available hardware resources and total batch size of your training data. :param gamma: the discount factor in [0, 1] for future rewards. This determines how much future rewards are valued compared to immediate ones. 
Lower values (closer to 0) make the agent focus on immediate rewards, creating "myopic" diff --git a/tianshou/policy/modelfree/trpo.py b/tianshou/policy/modelfree/trpo.py index faf1ca36d..44eb828d4 100644 --- a/tianshou/policy/modelfree/trpo.py +++ b/tianshou/policy/modelfree/trpo.py @@ -73,7 +73,12 @@ def __init__( Intermediate values create a weighted average of n-step returns, with exponentially decaying weights for longer-horizon returns. Typically set between 0.9 and 0.99 for most policy gradient methods. - :param max_batchsize: the maximum size of the batch when computing GAE. + :param max_batchsize: the maximum number of samples to process at once when computing + generalized advantage estimation (GAE) and value function predictions. + Controls memory usage by breaking large batches into smaller chunks processed sequentially. + Higher values may increase speed but require more GPU/CPU memory; lower values + reduce memory requirements but may increase computation time. Should be adjusted + based on available hardware resources and total batch size of your training data. :param gamma: the discount factor in [0, 1] for future rewards. This determines how much future rewards are valued compared to immediate ones. 
Lower values (closer to 0) make the agent focus on immediate rewards, creating "myopic" From 3fc22e9eaa6bdecaa0f09b89585246b93c3d964b Mon Sep 17 00:00:00 2001 From: Dominik Jain Date: Sun, 4 May 2025 19:04:38 +0200 Subject: [PATCH 098/230] v2: Improve description of parameter 'optim_critic_iters' --- tianshou/highlevel/params/policy_params.py | 12 ++++++++++-- tianshou/policy/modelfree/npg.py | 9 ++++++++- tianshou/policy/modelfree/trpo.py | 9 ++++++++- 3 files changed, 26 insertions(+), 4 deletions(-) diff --git a/tianshou/highlevel/params/policy_params.py b/tianshou/highlevel/params/policy_params.py index 989b10007..11c17de6d 100644 --- a/tianshou/highlevel/params/policy_params.py +++ b/tianshou/highlevel/params/policy_params.py @@ -331,7 +331,7 @@ class ParamsMixinGeneralAdvantageEstimation(GetParamTransformersProtocol): """ max_batchsize: int = 256 """the maximum number of samples to process at once when computing - generalized advantage estimation (GAE) and value function predictions. + generalized advantage estimation (GAE) and value function predictions. Controls memory usage by breaking large batches into smaller chunks processed sequentially. Higher values may increase speed but require more GPU/CPU memory; lower values reduce memory requirements but may increase computation time. Should be adjusted @@ -406,7 +406,15 @@ class PPOParams(A2CParams): @dataclass(kw_only=True) class NPGParams(PGParams, ParamsMixinGeneralAdvantageEstimation): optim_critic_iters: int = 5 - """number of times to optimize critic network per update.""" + """ + the number of optimization steps performed on the critic network for each policy (actor) update. + Controls the learning rate balance between critic and actor. + Higher values prioritize critic accuracy by training the value function more + extensively before each policy update, which can improve stability but slow down + training. 
Lower values maintain a more even learning pace between policy and value + function but may lead to less reliable advantage estimates. + Typically set between 1 and 10, depending on the complexity of the value function. + """ actor_step_size: float = 0.5 """ the scalar multiplier for policy updates in the natural gradient direction. diff --git a/tianshou/policy/modelfree/npg.py b/tianshou/policy/modelfree/npg.py index c79814733..75d85e44b 100644 --- a/tianshou/policy/modelfree/npg.py +++ b/tianshou/policy/modelfree/npg.py @@ -52,7 +52,14 @@ def __init__( :param policy: the policy containing the actor network. :param critic: the critic network. (s -> V(s)) :param optim: the optimizer factory for the critic network. - :param optim_critic_iters: the number of times to optimize critic network per update. + :param optim_critic_iters: the number of optimization steps performed on the critic network + for each policy (actor) update. + Controls the learning rate balance between critic and actor. + Higher values prioritize critic accuracy by training the value function more + extensively before each policy update, which can improve stability but slow down + training. Lower values maintain a more even learning pace between policy and value + function but may lead to less reliable advantage estimates. + Typically set between 1 and 10, depending on the complexity of the value function. :param actor_step_size: the scalar multiplier for policy updates in the natural gradient direction. Controls how far the policy parameters move in the calculated direction during each update. Higher values allow for faster learning but may cause instability diff --git a/tianshou/policy/modelfree/trpo.py b/tianshou/policy/modelfree/trpo.py index 44eb828d4..f2e13abfc 100644 --- a/tianshou/policy/modelfree/trpo.py +++ b/tianshou/policy/modelfree/trpo.py @@ -52,7 +52,14 @@ def __init__( :param backtrack_coeff: Coefficient to be multiplied by step size when constraints are not met. 
:param max_backtracks: Max number of backtracking times in linesearch. - :param optim_critic_iters: Number of times to optimize critic network per update. + :param optim_critic_iters: the number of optimization steps performed on the critic network + for each policy (actor) update. + Controls the learning rate balance between critic and actor. + Higher values prioritize critic accuracy by training the value function more + extensively before each policy update, which can improve stability but slow down + training. Lower values maintain a more even learning pace between policy and value + function but may lead to less reliable advantage estimates. + Typically set between 1 and 10, depending on the complexity of the value function. :param actor_step_size: the scalar multiplier for policy updates in the natural gradient direction. Controls how far the policy parameters move in the calculated direction during each update. Higher values allow for faster learning but may cause instability From d1b2e32e144677fdc3e4582785aa7d1bfc0bd9d9 Mon Sep 17 00:00:00 2001 From: Dominik Jain Date: Sun, 4 May 2025 19:55:01 +0200 Subject: [PATCH 099/230] v2: Improve description of parameter 'dist_fn' --- tianshou/policy/modelfree/pg.py | 42 ++++++++++++++++++++------------- 1 file changed, 25 insertions(+), 17 deletions(-) diff --git a/tianshou/policy/modelfree/pg.py b/tianshou/policy/modelfree/pg.py index b26d387a2..5bd46869f 100644 --- a/tianshou/policy/modelfree/pg.py +++ b/tianshou/policy/modelfree/pg.py @@ -46,7 +46,7 @@ [tuple[torch.Tensor, torch.Tensor]], torch.distributions.Distribution, ] -TDistFnDiscrete = Callable[[torch.Tensor], torch.distributions.Categorical] +TDistFnDiscrete = Callable[[torch.Tensor], torch.distributions.Distribution] TDistFnDiscrOrCont = TDistFnContinuous | TDistFnDiscrete @@ -74,14 +74,20 @@ def __init__( ) -> None: """ :param actor: the actor network following the rules: - If `self.action_type == "discrete"`: (`s_B` ->`action_values_BA`). 
+ If `self.action_type == "discrete"`: (`s_B` -> `action_values_BA`). If `self.action_type == "continuous"`: (`s_B` -> `dist_input_BD`). - :param dist_fn: distribution class for computing the action. - Maps model_output -> distribution. Typically, a Gaussian distribution - taking `model_output=mean,std` as input for continuous action spaces, - or a categorical distribution taking `model_output=logits` - for discrete action spaces. Note that as user, you are responsible - for ensuring that the distribution is compatible with the action space. + :param dist_fn: the function/type which creates a distribution from the actor output, + i.e. it maps the tensor(s) generated by the actor to a torch distribution. + For continuous action spaces, the output is typically a pair of tensors + (mean, std) and the distribution is a Gaussian distribution. + For discrete action spaces, the output is typically a tensor of unnormalized + log probabilities ("logits" in PyTorch terminology) or a tensor of probabilities + which can serve as the parameters of a Categorical distribution. + Note that if the actor uses softmax activation in its final layer, it will produce + probabilities, whereas if it uses no activation, it can be considered as producing + "logits". + As a user, you are responsible for ensuring that the distribution + is compatible with the output of the actor model and the action space. :param deterministic_eval: if True, will use deterministic action (the dist's mode) instead of stochastic one during evaluation. Does not affect training. :param action_space: env's action space. @@ -157,9 +163,16 @@ def __init__( ) -> None: """ :param actor: the actor network following the rules: (`s_B` -> `dist_input_BD`). - :param dist_fn: distribution class for computing the action. - Maps model_output -> distribution, typically a categorical distribution - taking `model_output=logits`. + :param dist_fn: the function/type which creates a distribution from the actor output, + i.e. 
it maps the tensor(s) generated by the actor to a torch distribution. + For discrete action spaces, the output is typically a tensor of unnormalized + log probabilities ("logits" in PyTorch terminology) or a tensor of probabilities + which serve as the parameters of a Categorical distribution. + Note that if the actor uses softmax activation in its final layer, it will produce + probabilities, whereas if it uses no activation, it can be considered as producing + "logits". + As a user, you are responsible for ensuring that the distribution + is compatible with the output of the actor model and the action space. :param deterministic_eval: if True, will use deterministic action (the dist's mode) instead of stochastic one during evaluation. Does not affect training. :param action_space: the environment's (discrete) action space. @@ -252,12 +265,7 @@ def add_discounted_returns( class Reinforce(OnPolicyAlgorithm[ActorPolicy, TPGTrainingStats], Generic[TPGTrainingStats]): - """Implementation of the REINFORCE (a.k.a. vanilla policy gradient) algorithm. - - .. seealso:: - - Please refer to :class:`~tianshou.policy.BasePolicy` for more detailed explanation. - """ + """Implementation of the REINFORCE (a.k.a. 
vanilla policy gradient) algorithm.""" def __init__( self, From 2e7a3b98040111717ddf56aa9e9b38b5070abf60 Mon Sep 17 00:00:00 2001 From: Dominik Jain Date: Sun, 4 May 2025 20:03:09 +0200 Subject: [PATCH 100/230] v2: Standardise descriptions for 'action_space' and 'observation_space' --- tianshou/policy/imitation/base.py | 4 ++-- tianshou/policy/imitation/bcq.py | 2 +- tianshou/policy/modelbased/psrl.py | 4 ++-- tianshou/policy/modelfree/ddpg.py | 4 ++-- tianshou/policy/modelfree/discrete_sac.py | 4 ++-- tianshou/policy/modelfree/pg.py | 4 ++-- tianshou/policy/modelfree/redq.py | 4 ++-- tianshou/policy/modelfree/sac.py | 4 ++-- tianshou/utils/torch_utils.py | 2 +- 9 files changed, 16 insertions(+), 16 deletions(-) diff --git a/tianshou/policy/imitation/base.py b/tianshou/policy/imitation/base.py index a76806174..2862305ea 100644 --- a/tianshou/policy/imitation/base.py +++ b/tianshou/policy/imitation/base.py @@ -49,8 +49,8 @@ def __init__( ): """ :param actor: a model following the rules (s -> a) - :param action_space: Env's action_space. - :param observation_space: Env's observation space. + :param action_space: the environment's action_space. + :param observation_space: the environment's observation space :param action_scaling: if True, scale the action from [-1, 1] to the range of action_space. Only used if the action_space is continuous. :param action_bound_method: method to bound action to range [-1, 1]. diff --git a/tianshou/policy/imitation/bcq.py b/tianshou/policy/imitation/bcq.py index b5a67be48..dfee050a4 100644 --- a/tianshou/policy/imitation/bcq.py +++ b/tianshou/policy/imitation/bcq.py @@ -50,7 +50,7 @@ def __init__( :param vae: the VAE network, generating actions similar to those in batch. :param forward_sampled_times: the number of sampled actions in forward function. The policy samples many actions and takes the action with the max value. - :param observation_space: Env's observation space. 
+ :param observation_space: the environment's observation space :param action_scaling: if True, scale the action from [-1, 1] to the range of action_space. Only used if the action_space is continuous. :param action_bound_method: method to bound action to range [-1, 1]. diff --git a/tianshou/policy/modelbased/psrl.py b/tianshou/policy/modelbased/psrl.py index fa8beb6f7..ada87c8c6 100644 --- a/tianshou/policy/modelbased/psrl.py +++ b/tianshou/policy/modelbased/psrl.py @@ -182,9 +182,9 @@ def __init__( with shape (n_state, n_action). :param rew_std_prior: standard deviations of the normal priors of rewards, with shape (n_state, n_action). - :param action_space: Env's action_space. + :param action_space: the environment's action_space. :param epsilon: for precision control in value iteration. - :param observation_space: Env's observation space. + :param observation_space: the environment's observation space """ super().__init__( action_space=action_space, diff --git a/tianshou/policy/modelfree/ddpg.py b/tianshou/policy/modelfree/ddpg.py index 7a1102c1f..a038116fe 100644 --- a/tianshou/policy/modelfree/ddpg.py +++ b/tianshou/policy/modelfree/ddpg.py @@ -57,8 +57,8 @@ def __init__( :param exploration_noise: noise model for adding noise to continuous actions for exploration. This is useful when solving "hard exploration" problems. "default" is equivalent to GaussianNoise(sigma=0.1). - :param action_space: Env's action space. - :param observation_space: Env's observation space. + :param action_space: the environment's action_space. + :param observation_space: the environment's observation space :param action_scaling: if True, scale the action from [-1, 1] to the range of action_space. Only used if the action_space is continuous. :param action_bound_method: method to bound action to range [-1, 1]. 
diff --git a/tianshou/policy/modelfree/discrete_sac.py b/tianshou/policy/modelfree/discrete_sac.py index 887bbbdae..36cc51412 100644 --- a/tianshou/policy/modelfree/discrete_sac.py +++ b/tianshou/policy/modelfree/discrete_sac.py @@ -44,8 +44,8 @@ def __init__( use the most probable action instead of sampling an action from the categorical distribution. This setting does not affect data collection for training, where actions are always sampled. - :param action_space: the action space of the environment - :param observation_space: the observation space of the environment + :param action_space: the environment's action_space. + :param observation_space: the environment's observation space """ assert isinstance(action_space, gym.spaces.Discrete) super().__init__( diff --git a/tianshou/policy/modelfree/pg.py b/tianshou/policy/modelfree/pg.py index 5bd46869f..9d6a264ff 100644 --- a/tianshou/policy/modelfree/pg.py +++ b/tianshou/policy/modelfree/pg.py @@ -90,8 +90,8 @@ def __init__( is compatible with the output of the actor model and the action space. :param deterministic_eval: if True, will use deterministic action (the dist's mode) instead of stochastic one during evaluation. Does not affect training. - :param action_space: env's action space. - :param observation_space: Env's observation space. + :param action_space: the environment's action space. + :param observation_space: the environment's observation space. :param action_scaling: if True, scale the action from [-1, 1] to the range of action_space. Only used if the action_space is continuous. :param action_bound_method: method to bound action to range [-1, 1]. diff --git a/tianshou/policy/modelfree/redq.py b/tianshou/policy/modelfree/redq.py index 72a62eff1..4399eeb8e 100644 --- a/tianshou/policy/modelfree/redq.py +++ b/tianshou/policy/modelfree/redq.py @@ -49,12 +49,12 @@ def __init__( """ :param actor: The actor network following the rules in :class:`~tianshou.policy.BasePolicy`. 
(s -> model_output) - :param action_space: Env's action space. + :param action_space: the environment's action_space. :param deterministic_eval: whether, in evaluation/inference mode, to use always use the most probable action instead of sampling an action from the categorical distribution. This setting does not affect data collection for training, where actions are always sampled. - :param observation_space: Env's observation space. + :param observation_space: the environment's observation space :param action_scaling: if True, scale the action from [-1, 1] to the range of action_space. Only used if the action_space is continuous. :param action_bound_method: method to bound action to range [-1, 1]. diff --git a/tianshou/policy/modelfree/sac.py b/tianshou/policy/modelfree/sac.py index 289b35931..e681af3ca 100644 --- a/tianshou/policy/modelfree/sac.py +++ b/tianshou/policy/modelfree/sac.py @@ -79,8 +79,8 @@ def __init__( or empty string for no bounding. Only used if the action_space is continuous. This parameter is ignored in SAC, which used tanh squashing after sampling unbounded from the gaussian policy (as in (arXiv 1801.01290): Equation 21.). - :param action_space: the action space of the environment - :param observation_space: the observation space of the environment + :param action_space: the environment's action_space. + :param observation_space: the environment's observation space """ super().__init__( exploration_noise=exploration_noise, diff --git a/tianshou/utils/torch_utils.py b/tianshou/utils/torch_utils.py index 1c001a544..37df7fe46 100644 --- a/tianshou/utils/torch_utils.py +++ b/tianshou/utils/torch_utils.py @@ -61,7 +61,7 @@ def create_uniform_action_dist( ) -> dist.Uniform | dist.Categorical: """Create a Distribution such that sampling from it is equivalent to sampling a batch with `action_space.sample()`. - :param action_space: The action space of the environment. + :param action_space: the environment's action_space. 
:param batch_size: The number of environments or batch size for sampling. :return: A PyTorch distribution for sampling actions. """ From 785524ca6fc3310d1686ef82993d95a94f299b9c Mon Sep 17 00:00:00 2001 From: Dominik Jain Date: Sun, 4 May 2025 20:24:51 +0200 Subject: [PATCH 101/230] v2: Improve descriptions of parameters 'action_scaling' and 'action_bound_method' --- tianshou/highlevel/params/policy_params.py | 29 +++++++++++++++-- tianshou/policy/base.py | 29 ++++++++++++++--- tianshou/policy/imitation/base.py | 29 ++++++++++++++--- tianshou/policy/imitation/bcq.py | 29 ++++++++++++++--- tianshou/policy/modelfree/ddpg.py | 28 ++++++++++++++-- tianshou/policy/modelfree/pg.py | 37 +++++++++++++++++----- tianshou/policy/modelfree/redq.py | 29 ++++++++++++++--- tianshou/policy/modelfree/sac.py | 30 +++++++++++++++--- 8 files changed, 206 insertions(+), 34 deletions(-) diff --git a/tianshou/highlevel/params/policy_params.py b/tianshou/highlevel/params/policy_params.py index 11c17de6d..ce4ec902b 100644 --- a/tianshou/highlevel/params/policy_params.py +++ b/tianshou/highlevel/params/policy_params.py @@ -223,10 +223,35 @@ def _get_param_transformers(self) -> list[ParamTransformer]: @dataclass(kw_only=True) class ParamsMixinActionScaling(GetParamTransformersProtocol): action_scaling: bool | Literal["default"] = "default" - """whether to apply action scaling; when set to "default", it will be enabled for continuous action spaces""" + """ + flag indicating whether, for continuous action spaces, actions + should be scaled from the standard neural network output range [-1, 1] to the + environment's action space range [action_space.low, action_space.high]. + This applies to continuous action spaces only (gym.spaces.Box) and has no effect + for discrete spaces. + When enabled, policy outputs are expected to be in the normalized range [-1, 1] + (after bounding), and are then linearly transformed to the actual required range. 
+ This improves neural network training stability, allows the same algorithm to work + across environments with different action ranges, and standardizes exploration + strategies. + Should be disabled if the actor model already produces outputs in the correct range. + """ action_bound_method: Literal["clip", "tanh"] | None = "clip" """ - method to bound action to range [-1, 1]. Only used if the action_space is continuous. + the method used for bounding actions in continuous action spaces + to the range [-1, 1] before scaling them to the environment's action space (provided + that `action_scaling` is enabled). + This applies to continuous action spaces only (`gym.spaces.Box`) and should be set to None + for discrete spaces. + When set to "clip", actions exceeding the [-1, 1] range are simply clipped to this + range. When set to "tanh", a hyperbolic tangent function is applied, which smoothly + constrains outputs to [-1, 1] while preserving gradients. + The choice of bounding method affects both training dynamics and exploration behavior. + Clipping provides hard boundaries but may create plateau regions in the gradient + landscape, while tanh provides smoother transitions but can compress sensitivity + near the boundaries. + Should be set to None if the actor model inherently produces bounded outputs. + Typically used together with `action_scaling=True`. """ def _get_param_transformers(self) -> list[ParamTransformer]: diff --git a/tianshou/policy/base.py b/tianshou/policy/base.py index 5e93580bf..feca3d7f5 100644 --- a/tianshou/policy/base.py +++ b/tianshou/policy/base.py @@ -164,10 +164,31 @@ def __init__( """ :param action_space: the environment's action_space. :param observation_space: the environment's observation space - :param action_scaling: if True, scale the action from [-1, 1] to the range - of action_space. Only used if the action_space is continuous. - :param action_bound_method: method to bound action to range [-1, 1]. 
- Only used if the action_space is continuous. + :param action_scaling: flag indicating whether, for continuous action spaces, actions + should be scaled from the standard neural network output range [-1, 1] to the + environment's action space range [action_space.low, action_space.high]. + This applies to continuous action spaces only (gym.spaces.Box) and has no effect + for discrete spaces. + When enabled, policy outputs are expected to be in the normalized range [-1, 1] + (after bounding), and are then linearly transformed to the actual required range. + This improves neural network training stability, allows the same algorithm to work + across environments with different action ranges, and standardizes exploration + strategies. + Should be disabled if the actor model already produces outputs in the correct range. + :param action_bound_method: the method used for bounding actions in continuous action spaces + to the range [-1, 1] before scaling them to the environment's action space (provided + that `action_scaling` is enabled). + This applies to continuous action spaces only (`gym.spaces.Box`) and should be set to None + for discrete spaces. + When set to "clip", actions exceeding the [-1, 1] range are simply clipped to this + range. When set to "tanh", a hyperbolic tangent function is applied, which smoothly + constrains outputs to [-1, 1] while preserving gradients. + The choice of bounding method affects both training dynamics and exploration behavior. + Clipping provides hard boundaries but may create plateau regions in the gradient + landscape, while tanh provides smoother transitions but can compress sensitivity + near the boundaries. + Should be set to None if the actor model inherently produces bounded outputs. + Typically used together with `action_scaling=True`. 
""" allowed_action_bound_methods = ("clip", "tanh") if ( diff --git a/tianshou/policy/imitation/base.py b/tianshou/policy/imitation/base.py index 2862305ea..09ee55c26 100644 --- a/tianshou/policy/imitation/base.py +++ b/tianshou/policy/imitation/base.py @@ -51,10 +51,31 @@ def __init__( :param actor: a model following the rules (s -> a) :param action_space: the environment's action_space. :param observation_space: the environment's observation space - :param action_scaling: if True, scale the action from [-1, 1] to the range - of action_space. Only used if the action_space is continuous. - :param action_bound_method: method to bound action to range [-1, 1]. - Only used if the action_space is continuous. + :param action_scaling: flag indicating whether, for continuous action spaces, actions + should be scaled from the standard neural network output range [-1, 1] to the + environment's action space range [action_space.low, action_space.high]. + This applies to continuous action spaces only (gym.spaces.Box) and has no effect + for discrete spaces. + When enabled, policy outputs are expected to be in the normalized range [-1, 1] + (after bounding), and are then linearly transformed to the actual required range. + This improves neural network training stability, allows the same algorithm to work + across environments with different action ranges, and standardizes exploration + strategies. + Should be disabled if the actor model already produces outputs in the correct range. + :param action_bound_method: the method used for bounding actions in continuous action spaces + to the range [-1, 1] before scaling them to the environment's action space (provided + that `action_scaling` is enabled). + This applies to continuous action spaces only (`gym.spaces.Box`) and should be set to None + for discrete spaces. + When set to "clip", actions exceeding the [-1, 1] range are simply clipped to this + range. 
When set to "tanh", a hyperbolic tangent function is applied, which smoothly + constrains outputs to [-1, 1] while preserving gradients. + The choice of bounding method affects both training dynamics and exploration behavior. + Clipping provides hard boundaries but may create plateau regions in the gradient + landscape, while tanh provides smoother transitions but can compress sensitivity + near the boundaries. + Should be set to None if the actor model inherently produces bounded outputs. + Typically used together with `action_scaling=True`. """ super().__init__( action_space=action_space, diff --git a/tianshou/policy/imitation/bcq.py b/tianshou/policy/imitation/bcq.py index dfee050a4..dcb8051f0 100644 --- a/tianshou/policy/imitation/bcq.py +++ b/tianshou/policy/imitation/bcq.py @@ -51,10 +51,31 @@ def __init__( :param forward_sampled_times: the number of sampled actions in forward function. The policy samples many actions and takes the action with the max value. :param observation_space: the environment's observation space - :param action_scaling: if True, scale the action from [-1, 1] to the range - of action_space. Only used if the action_space is continuous. - :param action_bound_method: method to bound action to range [-1, 1]. - Only used if the action_space is continuous. + :param action_scaling: flag indicating whether, for continuous action spaces, actions + should be scaled from the standard neural network output range [-1, 1] to the + environment's action space range [action_space.low, action_space.high]. + This applies to continuous action spaces only (gym.spaces.Box) and has no effect + for discrete spaces. + When enabled, policy outputs are expected to be in the normalized range [-1, 1] + (after bounding), and are then linearly transformed to the actual required range. + This improves neural network training stability, allows the same algorithm to work + across environments with different action ranges, and standardizes exploration + strategies. 
+ Should be disabled if the actor model already produces outputs in the correct range. + :param action_bound_method: the method used for bounding actions in continuous action spaces + to the range [-1, 1] before scaling them to the environment's action space (provided + that `action_scaling` is enabled). + This applies to continuous action spaces only (`gym.spaces.Box`) and should be set to None + for discrete spaces. + When set to "clip", actions exceeding the [-1, 1] range are simply clipped to this + range. When set to "tanh", a hyperbolic tangent function is applied, which smoothly + constrains outputs to [-1, 1] while preserving gradients. + The choice of bounding method affects both training dynamics and exploration behavior. + Clipping provides hard boundaries but may create plateau regions in the gradient + landscape, while tanh provides smoother transitions but can compress sensitivity + near the boundaries. + Should be set to None if the actor model inherently produces bounded outputs. + Typically used together with `action_scaling=True`. """ super().__init__( action_space=action_space, diff --git a/tianshou/policy/modelfree/ddpg.py b/tianshou/policy/modelfree/ddpg.py index a038116fe..d0fba91aa 100644 --- a/tianshou/policy/modelfree/ddpg.py +++ b/tianshou/policy/modelfree/ddpg.py @@ -59,9 +59,31 @@ def __init__( "default" is equivalent to GaussianNoise(sigma=0.1). :param action_space: the environment's action_space. :param observation_space: the environment's observation space - :param action_scaling: if True, scale the action from [-1, 1] to the range - of action_space. Only used if the action_space is continuous. - :param action_bound_method: method to bound action to range [-1, 1]. + :param action_scaling: flag indicating whether, for continuous action spaces, actions + should be scaled from the standard neural network output range [-1, 1] to the + environment's action space range [action_space.low, action_space.high]. 
+ This applies to continuous action spaces only (gym.spaces.Box) and has no effect + for discrete spaces. + When enabled, policy outputs are expected to be in the normalized range [-1, 1] + (after bounding), and are then linearly transformed to the actual required range. + This improves neural network training stability, allows the same algorithm to work + across environments with different action ranges, and standardizes exploration + strategies. + Should be disabled if the actor model already produces outputs in the correct range. + :param action_bound_method: the method used for bounding actions in continuous action spaces + to the range [-1, 1] before scaling them to the environment's action space (provided + that `action_scaling` is enabled). + This applies to continuous action spaces only (`gym.spaces.Box`) and should be set to None + for discrete spaces. + When set to "clip", actions exceeding the [-1, 1] range are simply clipped to this + range. When set to "tanh", a hyperbolic tangent function is applied, which smoothly + constrains outputs to [-1, 1] while preserving gradients. + The choice of bounding method affects both training dynamics and exploration behavior. + Clipping provides hard boundaries but may create plateau regions in the gradient + landscape, while tanh provides smoother transitions but can compress sensitivity + near the boundaries. + Should be set to None if the actor model inherently produces bounded outputs. + Typically used together with `action_scaling=True`. """ super().__init__( action_space=action_space, diff --git a/tianshou/policy/modelfree/pg.py b/tianshou/policy/modelfree/pg.py index 9d6a264ff..87a36797d 100644 --- a/tianshou/policy/modelfree/pg.py +++ b/tianshou/policy/modelfree/pg.py @@ -92,10 +92,31 @@ def __init__( instead of stochastic one during evaluation. Does not affect training. :param action_space: the environment's action space. :param observation_space: the environment's observation space. 
- :param action_scaling: if True, scale the action from [-1, 1] to the range - of action_space. Only used if the action_space is continuous. - :param action_bound_method: method to bound action to range [-1, 1]. - Only used if the action_space is continuous. + :param action_scaling: flag indicating whether, for continuous action spaces, actions + should be scaled from the standard neural network output range [-1, 1] to the + environment's action space range [action_space.low, action_space.high]. + This applies to continuous action spaces only (gym.spaces.Box) and has no effect + for discrete spaces. + When enabled, policy outputs are expected to be in the normalized range [-1, 1] + (after bounding), and are then linearly transformed to the actual required range. + This improves neural network training stability, allows the same algorithm to work + across environments with different action ranges, and standardizes exploration + strategies. + Should be disabled if the actor model already produces outputs in the correct range. + :param action_bound_method: the method used for bounding actions in continuous action spaces + to the range [-1, 1] before scaling them to the environment's action space (provided + that `action_scaling` is enabled). + This applies to continuous action spaces only (`gym.spaces.Box`) and should be set to None + for discrete spaces. + When set to "clip", actions exceeding the [-1, 1] range are simply clipped to this + range. When set to "tanh", a hyperbolic tangent function is applied, which smoothly + constrains outputs to [-1, 1] while preserving gradients. + The choice of bounding method affects both training dynamics and exploration behavior. + Clipping provides hard boundaries but may create plateau regions in the gradient + landscape, while tanh provides smoother transitions but can compress sensitivity + near the boundaries. + Should be set to None if the actor model inherently produces bounded outputs. 
+ Typically used together with `action_scaling=True`. """ super().__init__( action_space=action_space, @@ -105,10 +126,10 @@ def __init__( ) if action_scaling and not np.isclose(actor.max_action, 1.0): warnings.warn( - "action_scaling and action_bound_method are only intended" - "to deal with unbounded model action space, but find actor model" - f"bound action space with max_action={actor.max_action}." - "Consider using unbounded=True option of the actor model," + "action_scaling and action_bound_method are only intended " + "to deal with unbounded model action space, but find actor model " + f"bound action space with max_action={actor.max_action}. " + "Consider using unbounded=True option of the actor model, " "or set action_scaling to False and action_bound_method to None.", ) self.actor = actor diff --git a/tianshou/policy/modelfree/redq.py b/tianshou/policy/modelfree/redq.py index 4399eeb8e..a821b0cef 100644 --- a/tianshou/policy/modelfree/redq.py +++ b/tianshou/policy/modelfree/redq.py @@ -55,10 +55,31 @@ def __init__( categorical distribution. This setting does not affect data collection for training, where actions are always sampled. :param observation_space: the environment's observation space - :param action_scaling: if True, scale the action from [-1, 1] to the range - of action_space. Only used if the action_space is continuous. - :param action_bound_method: method to bound action to range [-1, 1]. - Only used if the action_space is continuous. + :param action_scaling: flag indicating whether, for continuous action spaces, actions + should be scaled from the standard neural network output range [-1, 1] to the + environment's action space range [action_space.low, action_space.high]. + This applies to continuous action spaces only (gym.spaces.Box) and has no effect + for discrete spaces. + When enabled, policy outputs are expected to be in the normalized range [-1, 1] + (after bounding), and are then linearly transformed to the actual required range. 
+ This improves neural network training stability, allows the same algorithm to work + across environments with different action ranges, and standardizes exploration + strategies. + Should be disabled if the actor model already produces outputs in the correct range. + :param action_bound_method: the method used for bounding actions in continuous action spaces + to the range [-1, 1] before scaling them to the environment's action space (provided + that `action_scaling` is enabled). + This applies to continuous action spaces only (`gym.spaces.Box`) and should be set to None + for discrete spaces. + When set to "clip", actions exceeding the [-1, 1] range are simply clipped to this + range. When set to "tanh", a hyperbolic tangent function is applied, which smoothly + constrains outputs to [-1, 1] while preserving gradients. + The choice of bounding method affects both training dynamics and exploration behavior. + Clipping provides hard boundaries but may create plateau regions in the gradient + landscape, while tanh provides smoother transitions but can compress sensitivity + near the boundaries. + Should be set to None if the actor model inherently produces bounded outputs. + Typically used together with `action_scaling=True`. """ super().__init__( exploration_noise=exploration_noise, diff --git a/tianshou/policy/modelfree/sac.py b/tianshou/policy/modelfree/sac.py index e681af3ca..e55738ed0 100644 --- a/tianshou/policy/modelfree/sac.py +++ b/tianshou/policy/modelfree/sac.py @@ -72,11 +72,31 @@ def __init__( :param deterministic_eval: whether to use deterministic action (mode of Gaussian policy) in evaluation mode instead of stochastic action sampled by the policy. Does not affect training. - :param action_scaling: whether to map actions from range [-1, 1] - to range[action_spaces.low, action_spaces.high]. - :param action_bound_method: method to bound action to range [-1, 1], - can be either "clip" (for simply clipping the action) - or empty string for no bounding. 
Only used if the action_space is continuous. + :param action_scaling: flag indicating whether, for continuous action spaces, actions + should be scaled from the standard neural network output range [-1, 1] to the + environment's action space range [action_space.low, action_space.high]. + This applies to continuous action spaces only (gym.spaces.Box) and has no effect + for discrete spaces. + When enabled, policy outputs are expected to be in the normalized range [-1, 1] + (after bounding), and are then linearly transformed to the actual required range. + This improves neural network training stability, allows the same algorithm to work + across environments with different action ranges, and standardizes exploration + strategies. + Should be disabled if the actor model already produces outputs in the correct range. + :param action_bound_method: the method used for bounding actions in continuous action spaces + to the range [-1, 1] before scaling them to the environment's action space (provided + that `action_scaling` is enabled). + This applies to continuous action spaces only (`gym.spaces.Box`) and should be set to None + for discrete spaces. + When set to "clip", actions exceeding the [-1, 1] range are simply clipped to this + range. When set to "tanh", a hyperbolic tangent function is applied, which smoothly + constrains outputs to [-1, 1] while preserving gradients. + The choice of bounding method affects both training dynamics and exploration behavior. + Clipping provides hard boundaries but may create plateau regions in the gradient + landscape, while tanh provides smoother transitions but can compress sensitivity + near the boundaries. + Should be set to None if the actor model inherently produces bounded outputs. + Typically used together with `action_scaling=True`. This parameter is ignored in SAC, which used tanh squashing after sampling unbounded from the gaussian policy (as in (arXiv 1801.01290): Equation 21.). 
:param action_space: the environment's action_space. From 796022d1f3a654702834c46486becc6149393ab3 Mon Sep 17 00:00:00 2001 From: Dominik Jain Date: Sun, 4 May 2025 20:33:14 +0200 Subject: [PATCH 102/230] v2: Improve description of parameter 'deterministic_eval' --- tianshou/highlevel/params/policy_params.py | 51 +++++++++++++--------- tianshou/policy/modelfree/discrete_sac.py | 17 ++++++-- tianshou/policy/modelfree/pg.py | 30 +++++++++++-- tianshou/policy/modelfree/redq.py | 17 ++++++-- tianshou/policy/modelfree/sac.py | 16 +++++-- 5 files changed, 95 insertions(+), 36 deletions(-) diff --git a/tianshou/highlevel/params/policy_params.py b/tianshou/highlevel/params/policy_params.py index ce4ec902b..777e41b67 100644 --- a/tianshou/highlevel/params/policy_params.py +++ b/tianshou/highlevel/params/policy_params.py @@ -316,17 +316,38 @@ class ParamsMixinTau: @dataclass(kw_only=True) -class PGParams(Params, ParamsMixinGamma, ParamsMixinActionScaling, ParamsMixinSingleModel): +class ParamsMixinDeterministicEval: + deterministic_eval: bool = False + """ + flag indicating whether the policy should use deterministic + actions (using the mode of the action distribution) instead of stochastic ones + (using random sampling) during evaluation. + When enabled, the policy will always select the most probable action according to + the learned distribution during evaluation phases, while still using stochastic + sampling during training. This creates a clear distinction between exploration + (training) and exploitation (evaluation) behaviors. + Deterministic actions are generally preferred for final deployment and reproducible + evaluation as they provide consistent behavior, reduce variance in performance + metrics, and are more interpretable for human observers. + Note that this parameter only affects behavior when the policy is not within a + training step. 
When collecting rollouts for training, actions remain stochastic + regardless of this setting to maintain proper exploration behaviour. + """ + + +@dataclass(kw_only=True) +class PGParams( + Params, + ParamsMixinGamma, + ParamsMixinActionScaling, + ParamsMixinSingleModel, + ParamsMixinDeterministicEval, +): reward_normalization: bool = False """ if True, will normalize the returns by subtracting the running mean and dividing by the running standard deviation. """ - deterministic_eval: bool = False - """ - whether to use deterministic action (the dist's mode) instead of stochastic one during evaluation. - Does not affect training. - """ def __setstate__(self, state: dict[str, Any]) -> None: setstate(PGParams, self, state, removed_properties=["dist_fn"]) @@ -515,6 +536,7 @@ class _SACParams( ParamsMixinActorAndDualCritics, ParamsMixinEstimationStep, ParamsMixinTau, + ParamsMixinDeterministicEval, ): alpha: float | AutoAlphaFactory = 0.2 """ @@ -534,11 +556,6 @@ def _get_param_transformers(self) -> list[ParamTransformer]: @dataclass(kw_only=True) class SACParams(_SACParams, ParamsMixinExplorationNoise, ParamsMixinActionScaling): - deterministic_eval: bool = True - """ - whether to use deterministic action (mode of Gaussian policy) in evaluation mode instead of stochastic - action sampled from the distribution. Does not affect training.""" - def _get_param_transformers(self) -> list[ParamTransformer]: transformers = super()._get_param_transformers() transformers.extend(ParamsMixinExplorationNoise._get_param_transformers(self)) @@ -548,10 +565,7 @@ def _get_param_transformers(self) -> list[ParamTransformer]: @dataclass(kw_only=True) class DiscreteSACParams(_SACParams): - deterministic_eval: bool = True - """ - whether to use deterministic action (most probably action) in evaluation mode instead of stochastic - action sampled from the distribution. 
Does not affect training.""" + pass @dataclass(kw_only=True) @@ -638,7 +652,7 @@ def _get_param_transformers(self) -> list[ParamTransformer]: @dataclass(kw_only=True) -class REDQParams(DDPGParams): +class REDQParams(DDPGParams, ParamsMixinDeterministicEval): ensemble_size: int = 10 """the number of sub-networks in the critic ensemble""" subset_size: int = 2 @@ -653,11 +667,6 @@ class REDQParams(DDPGParams): """ actor_delay: int = 20 """the number of critic updates before an actor update""" - deterministic_eval: bool = True - """ - whether to use deterministic action (the dist's mode) instead of stochastic one during evaluation. - Does not affect training. - """ target_mode: Literal["mean", "min"] = "min" def _get_param_transformers(self) -> list[ParamTransformer]: diff --git a/tianshou/policy/modelfree/discrete_sac.py b/tianshou/policy/modelfree/discrete_sac.py index 36cc51412..e60801588 100644 --- a/tianshou/policy/modelfree/discrete_sac.py +++ b/tianshou/policy/modelfree/discrete_sac.py @@ -40,10 +40,19 @@ def __init__( """ :param actor: the actor network following the rules (s -> dist_input_BD), where the distribution input is for a `Categorical` distribution. - :param deterministic_eval: whether, in evaluation/inference mode, to use always - use the most probable action instead of sampling an action from the - categorical distribution. This setting does not affect data collection - for training, where actions are always sampled. + :param deterministic_eval: flag indicating whether the policy should use deterministic + actions (using the mode of the action distribution) instead of stochastic ones + (using random sampling) during evaluation. + When enabled, the policy will always select the most probable action according to + the learned distribution during evaluation phases, while still using stochastic + sampling during training. This creates a clear distinction between exploration + (training) and exploitation (evaluation) behaviors. 
+ Deterministic actions are generally preferred for final deployment and reproducible + evaluation as they provide consistent behavior, reduce variance in performance + metrics, and are more interpretable for human observers. + Note that this parameter only affects behavior when the policy is not within a + training step. When collecting rollouts for training, actions remain stochastic + regardless of this setting to maintain proper exploration behaviour. :param action_space: the environment's action_space. :param observation_space: the environment's observation space """ diff --git a/tianshou/policy/modelfree/pg.py b/tianshou/policy/modelfree/pg.py index 87a36797d..9b95205ad 100644 --- a/tianshou/policy/modelfree/pg.py +++ b/tianshou/policy/modelfree/pg.py @@ -88,8 +88,19 @@ def __init__( "logits". As a user, you are responsible for ensuring that the distribution is compatible with the output of the actor model and the action space. - :param deterministic_eval: if True, will use deterministic action (the dist's mode) - instead of stochastic one during evaluation. Does not affect training. + :param deterministic_eval: flag indicating whether the policy should use deterministic + actions (using the mode of the action distribution) instead of stochastic ones + (using random sampling) during evaluation. + When enabled, the policy will always select the most probable action according to + the learned distribution during evaluation phases, while still using stochastic + sampling during training. This creates a clear distinction between exploration + (training) and exploitation (evaluation) behaviors. + Deterministic actions are generally preferred for final deployment and reproducible + evaluation as they provide consistent behavior, reduce variance in performance + metrics, and are more interpretable for human observers. + Note that this parameter only affects behavior when the policy is not within a + training step. 
When collecting rollouts for training, actions remain stochastic + regardless of this setting to maintain proper exploration behaviour. :param action_space: the environment's action space. :param observation_space: the environment's observation space. :param action_scaling: flag indicating whether, for continuous action spaces, actions @@ -194,8 +205,19 @@ def __init__( "logits". As a user, you are responsible for ensuring that the distribution is compatible with the output of the actor model and the action space. - :param deterministic_eval: if True, will use deterministic action (the dist's mode) - instead of stochastic one during evaluation. Does not affect training. + :param deterministic_eval: flag indicating whether the policy should use deterministic + actions (using the mode of the action distribution) instead of stochastic ones + (using random sampling) during evaluation. + When enabled, the policy will always select the most probable action according to + the learned distribution during evaluation phases, while still using stochastic + sampling during training. This creates a clear distinction between exploration + (training) and exploitation (evaluation) behaviors. + Deterministic actions are generally preferred for final deployment and reproducible + evaluation as they provide consistent behavior, reduce variance in performance + metrics, and are more interpretable for human observers. + Note that this parameter only affects behavior when the policy is not within a + training step. When collecting rollouts for training, actions remain stochastic + regardless of this setting to maintain proper exploration behaviour. :param action_space: the environment's (discrete) action space. :param observation_space: the environment's observation space. 
""" diff --git a/tianshou/policy/modelfree/redq.py b/tianshou/policy/modelfree/redq.py index a821b0cef..b739053bd 100644 --- a/tianshou/policy/modelfree/redq.py +++ b/tianshou/policy/modelfree/redq.py @@ -50,10 +50,19 @@ def __init__( :param actor: The actor network following the rules in :class:`~tianshou.policy.BasePolicy`. (s -> model_output) :param action_space: the environment's action_space. - :param deterministic_eval: whether, in evaluation/inference mode, to use always - use the most probable action instead of sampling an action from the - categorical distribution. This setting does not affect data collection - for training, where actions are always sampled. + :param deterministic_eval: flag indicating whether the policy should use deterministic + actions (using the mode of the action distribution) instead of stochastic ones + (using random sampling) during evaluation. + When enabled, the policy will always select the most probable action according to + the learned distribution during evaluation phases, while still using stochastic + sampling during training. This creates a clear distinction between exploration + (training) and exploitation (evaluation) behaviors. + Deterministic actions are generally preferred for final deployment and reproducible + evaluation as they provide consistent behavior, reduce variance in performance + metrics, and are more interpretable for human observers. + Note that this parameter only affects behavior when the policy is not within a + training step. When collecting rollouts for training, actions remain stochastic + regardless of this setting to maintain proper exploration behaviour. 
:param observation_space: the environment's observation space :param action_scaling: flag indicating whether, for continuous action spaces, actions should be scaled from the standard neural network output range [-1, 1] to the diff --git a/tianshou/policy/modelfree/sac.py b/tianshou/policy/modelfree/sac.py index e55738ed0..49ead8732 100644 --- a/tianshou/policy/modelfree/sac.py +++ b/tianshou/policy/modelfree/sac.py @@ -69,9 +69,19 @@ def __init__( :param exploration_noise: add noise to action for exploration. This is useful when solving "hard exploration" problems. "default" is equivalent to GaussianNoise(sigma=0.1). - :param deterministic_eval: whether to use deterministic action - (mode of Gaussian policy) in evaluation mode instead of stochastic - action sampled by the policy. Does not affect training. + :param deterministic_eval: flag indicating whether the policy should use deterministic + actions (using the mode of the action distribution) instead of stochastic ones + (using random sampling) during evaluation. + When enabled, the policy will always select the most probable action according to + the learned distribution during evaluation phases, while still using stochastic + sampling during training. This creates a clear distinction between exploration + (training) and exploitation (evaluation) behaviors. + Deterministic actions are generally preferred for final deployment and reproducible + evaluation as they provide consistent behavior, reduce variance in performance + metrics, and are more interpretable for human observers. + Note that this parameter only affects behavior when the policy is not within a + training step. When collecting rollouts for training, actions remain stochastic + regardless of this setting to maintain proper exploration behaviour. 
:param action_scaling: flag indicating whether, for continuous action spaces, actions should be scaled from the standard neural network output range [-1, 1] to the environment's action space range [action_space.low, action_space.high]. From a59c3b517cdd57694864af9e50d18f1a32134885 Mon Sep 17 00:00:00 2001 From: Dominik Jain Date: Sun, 4 May 2025 20:56:06 +0200 Subject: [PATCH 103/230] v2: Rename PGParams, PGExperimentBuilder -> Reinforce* (high-level API) --- examples/mujoco/mujoco_reinforce_hl.py | 10 +++++----- test/highlevel/test_experiment_builder.py | 6 +++--- tianshou/highlevel/algorithm.py | 4 ++-- tianshou/highlevel/experiment.py | 8 ++++---- tianshou/highlevel/params/policy_params.py | 10 +++------- 5 files changed, 17 insertions(+), 21 deletions(-) diff --git a/examples/mujoco/mujoco_reinforce_hl.py b/examples/mujoco/mujoco_reinforce_hl.py index accf7cbed..27af61efd 100644 --- a/examples/mujoco/mujoco_reinforce_hl.py +++ b/examples/mujoco/mujoco_reinforce_hl.py @@ -12,10 +12,10 @@ from tianshou.highlevel.config import OnPolicyTrainingConfig from tianshou.highlevel.experiment import ( ExperimentConfig, - PGExperimentBuilder, + ReinforceExperimentBuilder, ) from tianshou.highlevel.params.lr_scheduler import LRSchedulerFactoryFactoryLinear -from tianshou.highlevel.params.policy_params import PGParams +from tianshou.highlevel.params.policy_params import ReinforceParams def main( @@ -57,9 +57,9 @@ def main( ) experiment = ( - PGExperimentBuilder(env_factory, experiment_config, training_config) - .with_pg_params( - PGParams( + ReinforceExperimentBuilder(env_factory, experiment_config, training_config) + .with_reinforce_params( + ReinforceParams( discount_factor=gamma, action_bound_method=action_bound_method, reward_normalization=rew_norm, diff --git a/test/highlevel/test_experiment_builder.py b/test/highlevel/test_experiment_builder.py index 29eaa8074..86c298abe 100644 --- a/test/highlevel/test_experiment_builder.py +++ 
b/test/highlevel/test_experiment_builder.py @@ -16,9 +16,9 @@ IQNExperimentBuilder, OffPolicyExperimentBuilder, OnPolicyExperimentBuilder, - PGExperimentBuilder, PPOExperimentBuilder, REDQExperimentBuilder, + ReinforceExperimentBuilder, SACExperimentBuilder, TD3ExperimentBuilder, TRPOExperimentBuilder, @@ -57,7 +57,7 @@ def create_training_config( # NPGExperimentBuilder, # TODO test fails non-deterministically REDQExperimentBuilder, TRPOExperimentBuilder, - PGExperimentBuilder, + ReinforceExperimentBuilder, ], ) def test_experiment_builder_continuous_default_params(builder_cls: type[ExperimentBuilder]) -> None: @@ -83,7 +83,7 @@ def test_experiment_builder_continuous_default_params(builder_cls: type[Experime @pytest.mark.parametrize( "builder_cls", [ - PGExperimentBuilder, + ReinforceExperimentBuilder, PPOExperimentBuilder, A2CExperimentBuilder, DQNExperimentBuilder, diff --git a/tianshou/highlevel/algorithm.py b/tianshou/highlevel/algorithm.py index 2a35049ef..5c5b53dd0 100644 --- a/tianshou/highlevel/algorithm.py +++ b/tianshou/highlevel/algorithm.py @@ -35,9 +35,9 @@ ParamsMixinActorAndDualCritics, ParamsMixinSingleModel, ParamTransformerData, - PGParams, PPOParams, REDQParams, + ReinforceParams, SACParams, TD3Params, TRPOParams, @@ -283,7 +283,7 @@ def create_trainer( class ReinforceAlgorithmFactory(OnPolicyAlgorithmFactory): def __init__( self, - params: PGParams, + params: ReinforceParams, training_config: OnPolicyTrainingConfig, actor_factory: ActorFactory, optim_factory: OptimizerFactoryFactory, diff --git a/tianshou/highlevel/experiment.py b/tianshou/highlevel/experiment.py index 629159005..3da0344bd 100644 --- a/tianshou/highlevel/experiment.py +++ b/tianshou/highlevel/experiment.py @@ -93,9 +93,9 @@ DQNParams, IQNParams, NPGParams, - PGParams, PPOParams, REDQParams, + ReinforceParams, SACParams, TD3Params, TRPOParams, @@ -1046,7 +1046,7 @@ def _get_critic_ensemble_factory(self) -> CriticEnsembleFactory: return self.critic_ensemble_factory -class 
PGExperimentBuilder( +class ReinforceExperimentBuilder( OnPolicyExperimentBuilder, _BuilderMixinActorFactory_ContinuousGaussian, ): @@ -1058,10 +1058,10 @@ def __init__( ): super().__init__(env_factory, experiment_config, training_config) _BuilderMixinActorFactory_ContinuousGaussian.__init__(self) - self._params: PGParams = PGParams() + self._params: ReinforceParams = ReinforceParams() self._env_config = None - def with_pg_params(self, params: PGParams) -> Self: + def with_reinforce_params(self, params: ReinforceParams) -> Self: self._params = params return self diff --git a/tianshou/highlevel/params/policy_params.py b/tianshou/highlevel/params/policy_params.py index 777e41b67..f1fab5ee1 100644 --- a/tianshou/highlevel/params/policy_params.py +++ b/tianshou/highlevel/params/policy_params.py @@ -3,7 +3,6 @@ from dataclasses import asdict, dataclass from typing import Any, Literal, Protocol -from sensai.util.pickle import setstate from sensai.util.string import ToStringMixin from tianshou.exploration import BaseNoise @@ -336,7 +335,7 @@ class ParamsMixinDeterministicEval: @dataclass(kw_only=True) -class PGParams( +class ReinforceParams( Params, ParamsMixinGamma, ParamsMixinActionScaling, @@ -349,9 +348,6 @@ class PGParams( standard deviation. 
""" - def __setstate__(self, state: dict[str, Any]) -> None: - setstate(PGParams, self, state, removed_properties=["dist_fn"]) - def _get_param_transformers(self) -> list[ParamTransformer]: transformers = super()._get_param_transformers() transformers.extend(ParamsMixinActionScaling._get_param_transformers(self)) @@ -388,7 +384,7 @@ def _get_param_transformers(self) -> list[ParamTransformer]: @dataclass(kw_only=True) -class A2CParams(PGParams, ParamsMixinGeneralAdvantageEstimation): +class A2CParams(ReinforceParams, ParamsMixinGeneralAdvantageEstimation): vf_coef: float = 0.5 """weight (coefficient) of the value loss in the loss function""" ent_coef: float = 0.01 @@ -450,7 +446,7 @@ class PPOParams(A2CParams): @dataclass(kw_only=True) -class NPGParams(PGParams, ParamsMixinGeneralAdvantageEstimation): +class NPGParams(ReinforceParams, ParamsMixinGeneralAdvantageEstimation): optim_critic_iters: int = 5 """ the number of optimization steps performed on the critic network for each policy (actor) update. 
From 039a04aef4b56a0b4e2f47983ba9b5481bdc165c Mon Sep 17 00:00:00 2001 From: Dominik Jain Date: Sun, 4 May 2025 21:50:48 +0200 Subject: [PATCH 104/230] v2: Remove obsolete comments/docstrings --- tianshou/highlevel/optim.py | 2 -- tianshou/policy/modelfree/td3.py | 2 -- 2 files changed, 4 deletions(-) diff --git a/tianshou/highlevel/optim.py b/tianshou/highlevel/optim.py index 3a63cd5f0..66e6d154c 100644 --- a/tianshou/highlevel/optim.py +++ b/tianshou/highlevel/optim.py @@ -47,8 +47,6 @@ def create_optimizer_factory(self, lr: float) -> OptimizerFactory: class OptimizerFactoryFactoryAdam(OptimizerFactoryFactory): - # Note: currently used as default optimizer - # values should be kept in sync with `ExperimentBuilder.with_optim_factory_default` def __init__( self, betas: tuple[float, float] = (0.9, 0.999), diff --git a/tianshou/policy/modelfree/td3.py b/tianshou/policy/modelfree/td3.py index 274bdf110..3ccfefb0c 100644 --- a/tianshou/policy/modelfree/td3.py +++ b/tianshou/policy/modelfree/td3.py @@ -82,8 +82,6 @@ def __init__( potentially improving performance in tasks where delayed rewards are important but increasing training variance by incorporating more environmental stochasticity. 
Typically set between 0.9 and 0.99 for most reinforcement learning tasks - :param lr_scheduler: a learning rate scheduler that adjusts the learning rate - in optimizer in each policy.update() """ super().__init__( policy=policy, From 91760fa88c3fd7f7a0a5f7c00c8dcf27b769c91c Mon Sep 17 00:00:00 2001 From: Michael Panchenko Date: Sun, 4 May 2025 21:41:55 +0200 Subject: [PATCH 105/230] v2: policy_wrapper: duplicated call to instantiation, enabling more explicit type checking --- tianshou/highlevel/params/policy_wrapper.py | 27 ++++++++++++--------- 1 file changed, 16 insertions(+), 11 deletions(-) diff --git a/tianshou/highlevel/params/policy_wrapper.py b/tianshou/highlevel/params/policy_wrapper.py index 2b49fcc08..41469cf7c 100644 --- a/tianshou/highlevel/params/policy_wrapper.py +++ b/tianshou/highlevel/params/policy_wrapper.py @@ -70,18 +70,23 @@ def create_wrapped_algorithm( ) optim_factory = self.optim_factory or optim_factory_default icm_optim = optim_factory.create_optimizer_factory(lr=self.lr) - cls: type[ICMOffPolicyWrapper] | type[ICMOnPolicyWrapper] if isinstance(algorithm, OffPolicyAlgorithm): - cls = ICMOffPolicyWrapper + return ICMOffPolicyWrapper( + wrapped_algorithm=algorithm, + model=icm_net, + optim=icm_optim, + lr_scale=self.lr_scale, + reward_scale=self.reward_scale, + forward_loss_weight=self.forward_loss_weight, + ).to(device) elif isinstance(algorithm, OnPolicyAlgorithm): - cls = ICMOnPolicyWrapper + return ICMOnPolicyWrapper( + wrapped_algorithm=algorithm, + model=icm_net, + optim=icm_optim, + lr_scale=self.lr_scale, + reward_scale=self.reward_scale, + forward_loss_weight=self.forward_loss_weight, + ).to(device) else: raise ValueError(f"{algorithm} is not supported by ICM") - return cls( - wrapped_algorithm=algorithm, - model=icm_net, - optim=icm_optim, - lr_scale=self.lr_scale, - reward_scale=self.reward_scale, - forward_loss_weight=self.forward_loss_weight, - ).to(device) From fcdb5ad7b7db81881cff691a016f0d80839e482e Mon Sep 17 00:00:00 2001 
From: Michael Panchenko Date: Sun, 4 May 2025 21:43:14 +0200 Subject: [PATCH 106/230] v2: Actor: make get_preprocess_net always return ModuleWithVectorOutput --- tianshou/utils/net/common.py | 8 ++++---- tianshou/utils/net/continuous.py | 4 ++-- tianshou/utils/net/discrete.py | 2 +- 3 files changed, 7 insertions(+), 7 deletions(-) diff --git a/tianshou/utils/net/common.py b/tianshou/utils/net/common.py index afe403a4d..d0922ae47 100644 --- a/tianshou/utils/net/common.py +++ b/tianshou/utils/net/common.py @@ -277,7 +277,7 @@ def __init__( if use_dueling: # dueling DQN assert dueling_param is not None kwargs_update = { - "input_dim": self.model.output_dim, + "input_dim": model.output_dim, } # Important: don't change the original dict (e.g., don't use .update()) q_kwargs = {**dueling_param[0], **kwargs_update} @@ -664,7 +664,7 @@ def forward(self, obs: np.ndarray | torch.Tensor, *args, **kwargs) -> Any: class Actor(ModuleWithVectorOutput, ABC): @abstractmethod - def get_preprocess_net(self) -> nn.Module: + def get_preprocess_net(self) -> ModuleWithVectorOutput: pass @abstractmethod @@ -704,8 +704,8 @@ def action_space(self) -> spaces.Box | spaces.Discrete: def space_info(self) -> ActionSpaceInfo: return self._space_info - def get_preprocess_net(self) -> nn.Module: - return nn.Identity() + def get_preprocess_net(self) -> ModuleWithVectorOutput: + return ModuleWithVectorOutput.from_module(nn.Identity(), self.output_dim) def get_output_dim(self) -> int: return self.space_info.action_dim diff --git a/tianshou/utils/net/continuous.py b/tianshou/utils/net/continuous.py index e5240e8d8..87fcba69e 100644 --- a/tianshou/utils/net/continuous.py +++ b/tianshou/utils/net/continuous.py @@ -57,7 +57,7 @@ def __init__( ) self.max_action = max_action - def get_preprocess_net(self) -> nn.Module: + def get_preprocess_net(self) -> ModuleWithVectorOutput: return self.preprocess def get_output_dim(self) -> int: @@ -216,7 +216,7 @@ def __init__( self.max_action = max_action self._unbounded 
= unbounded - def get_preprocess_net(self) -> nn.Module: + def get_preprocess_net(self) -> ModuleWithVectorOutput: return self.preprocess def forward( diff --git a/tianshou/utils/net/discrete.py b/tianshou/utils/net/discrete.py index 7464c0d03..8da022c7e 100644 --- a/tianshou/utils/net/discrete.py +++ b/tianshou/utils/net/discrete.py @@ -53,7 +53,7 @@ def __init__( ) self.softmax_output = softmax_output - def get_preprocess_net(self) -> nn.Module: + def get_preprocess_net(self) -> ModuleWithVectorOutput: return self.preprocess def forward( From 4025b8210dbf1503e4443b684a8e5996c97f9b4a Mon Sep 17 00:00:00 2001 From: Michael Panchenko Date: Sun, 4 May 2025 21:44:13 +0200 Subject: [PATCH 107/230] v2: Imitation: removed unnecessary (and incorrect) generic TImitationTrainingStats The ImitationLearningAlgorithmMixin fixed the type of the training stats --- tianshou/policy/imitation/base.py | 15 +++++---------- 1 file changed, 5 insertions(+), 10 deletions(-) diff --git a/tianshou/policy/imitation/base.py b/tianshou/policy/imitation/base.py index a76806174..ec80a8504 100644 --- a/tianshou/policy/imitation/base.py +++ b/tianshou/policy/imitation/base.py @@ -1,5 +1,5 @@ from dataclasses import dataclass -from typing import Any, Generic, Literal, TypeVar, cast +from typing import Any, Literal, cast import gymnasium as gym import numpy as np @@ -34,9 +34,6 @@ class ImitationTrainingStats(TrainingStats): loss: float = 0.0 -TImitationTrainingStats = TypeVar("TImitationTrainingStats", bound=ImitationTrainingStats) - - class ImitationPolicy(Policy): def __init__( self, @@ -110,9 +107,8 @@ def _imitation_update( class OffPolicyImitationLearning( - OffPolicyAlgorithm[ImitationPolicy, TImitationTrainingStats], + OffPolicyAlgorithm[ImitationPolicy, ImitationTrainingStats], ImitationLearningAlgorithmMixin, - Generic[TImitationTrainingStats], ): """Implementation of off-policy vanilla imitation learning.""" @@ -134,14 +130,13 @@ def __init__( def _update_with_batch( self, batch: 
RolloutBatchProtocol, - ) -> TImitationTrainingStats: + ) -> ImitationTrainingStats: return self._imitation_update(batch, self.policy, self.optim) class OfflineImitationLearning( - OfflineAlgorithm[ImitationPolicy, TImitationTrainingStats], + OfflineAlgorithm[ImitationPolicy, ImitationTrainingStats], ImitationLearningAlgorithmMixin, - Generic[TImitationTrainingStats], ): """Implementation of offline vanilla imitation learning.""" @@ -163,5 +158,5 @@ def __init__( def _update_with_batch( self, batch: RolloutBatchProtocol, - ) -> TImitationTrainingStats: + ) -> ImitationTrainingStats: return self._imitation_update(batch, self.policy, self.optim) From f7f2dcd7f17cd060beccbc0c0b22f9c93cd1fc40 Mon Sep 17 00:00:00 2001 From: Michael Panchenko Date: Sun, 4 May 2025 21:45:08 +0200 Subject: [PATCH 108/230] v2: A bunch of small fixes in typing, param names and attribute references --- examples/atari/atari_c51.py | 8 ++++---- examples/atari/atari_dqn.py | 11 +++++------ examples/atari/atari_fqf.py | 8 ++++---- examples/atari/atari_iqn.py | 8 ++++---- examples/atari/atari_ppo.py | 22 +++++++++++++++------- examples/atari/atari_qrdqn.py | 8 ++++---- examples/atari/atari_rainbow.py | 8 ++++---- examples/atari/atari_sac.py | 17 +++++++++-------- examples/vizdoom/vizdoom_ppo.py | 7 ++++--- test/modelbased/test_dqn_icm.py | 6 +++--- test/pettingzoo/pistonball.py | 16 ++++++++++------ test/pettingzoo/pistonball_continuous.py | 19 +++++++++++-------- test/pettingzoo/tic_tac_toe.py | 15 ++++++++------- tianshou/__init__.py | 4 ++-- tianshou/highlevel/params/alpha.py | 4 ++-- tianshou/highlevel/params/policy_params.py | 2 +- tianshou/policy/base.py | 2 ++ tianshou/policy/modelfree/discrete_sac.py | 3 ++- tianshou/policy/modelfree/rainbow.py | 1 + tianshou/policy/modelfree/td3.py | 2 +- 20 files changed, 96 insertions(+), 75 deletions(-) diff --git a/examples/atari/atari_c51.py b/examples/atari/atari_c51.py index 29cf69dc8..c9106570c 100644 --- a/examples/atari/atari_c51.py +++ 
b/examples/atari/atari_c51.py @@ -77,8 +77,8 @@ def main(args: argparse.Namespace = get_args()) -> None: scale=args.scale_obs, frame_stack=args.frames_stack, ) - args.state_shape = env.observation_space.shape or env.observation_space.n - args.action_shape = env.action_space.shape or env.action_space.n + args.state_shape = env.observation_space.shape or env.observation_space.n # type: ignore + args.action_shape = env.action_space.shape or env.action_space.n # type: ignore # should be N_FRAMES x H x W print("Observations shape:", args.state_shape) print("Actions shape:", args.action_shape) @@ -153,8 +153,8 @@ def save_best_fn(policy: Algorithm) -> None: torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth")) def stop_fn(mean_rewards: float) -> bool: - if env.spec.reward_threshold: - return mean_rewards >= env.spec.reward_threshold + if env.spec.reward_threshold: # type: ignore + return mean_rewards >= env.spec.reward_threshold # type: ignore if "Pong" in args.task: return mean_rewards >= 20 return False diff --git a/examples/atari/atari_dqn.py b/examples/atari/atari_dqn.py index 96bcf3234..482f92ef9 100644 --- a/examples/atari/atari_dqn.py +++ b/examples/atari/atari_dqn.py @@ -94,9 +94,8 @@ def main(args: argparse.Namespace = get_args()) -> None: scale=args.scale_obs, frame_stack=args.frames_stack, ) - args.state_shape = env.observation_space.shape or env.observation_space.n - args.action_shape = env.action_space.shape or env.action_space.n - + args.state_shape = env.observation_space.shape or env.observation_space.n # type: ignore + args.action_shape = env.action_space.shape or env.action_space.n # type: ignore # should be N_FRAMES x H x W print("Observations shape:", args.state_shape) print("Actions shape:", args.action_shape) @@ -128,7 +127,7 @@ def main(args: argparse.Namespace = get_args()) -> None: if args.icm_lr_scale > 0: c, h, w = args.state_shape feature_net = DQNet(c=c, h=h, w=w, action_shape=args.action_shape, features_only=True) - 
action_dim = np.prod(args.action_shape) + action_dim = int(np.prod(args.action_shape)) feature_dim = feature_net.output_dim icm_net = IntrinsicCuriosityModule( feature_net=feature_net.net, @@ -190,8 +189,8 @@ def save_best_fn(policy: Algorithm) -> None: torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth")) def stop_fn(mean_rewards: float) -> bool: - if env.spec.reward_threshold: - return mean_rewards >= env.spec.reward_threshold + if env.spec.reward_threshold: # type: ignore + return mean_rewards >= env.spec.reward_threshold # type: ignore if "Pong" in args.task: return mean_rewards >= 20 return False diff --git a/examples/atari/atari_fqf.py b/examples/atari/atari_fqf.py index 8731cb47d..1a466c651 100644 --- a/examples/atari/atari_fqf.py +++ b/examples/atari/atari_fqf.py @@ -80,8 +80,8 @@ def main(args: argparse.Namespace = get_args()) -> None: scale=args.scale_obs, frame_stack=args.frames_stack, ) - args.state_shape = env.observation_space.shape or env.observation_space.n - args.action_shape = env.action_space.shape or env.action_space.n + args.state_shape = env.observation_space.shape or env.observation_space.n # type: ignore + args.action_shape = env.action_space.shape or env.action_space.n # type: ignore # should be N_FRAMES x H x W print("Observations shape:", args.state_shape) @@ -165,8 +165,8 @@ def save_best_fn(policy: Algorithm) -> None: torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth")) def stop_fn(mean_rewards: float) -> bool: - if env.spec.reward_threshold: - return mean_rewards >= env.spec.reward_threshold + if env.spec.reward_threshold: # type: ignore + return mean_rewards >= env.spec.reward_threshold # type: ignore if "Pong" in args.task: return mean_rewards >= 20 return False diff --git a/examples/atari/atari_iqn.py b/examples/atari/atari_iqn.py index e209be5c1..d4e0fe005 100644 --- a/examples/atari/atari_iqn.py +++ b/examples/atari/atari_iqn.py @@ -80,8 +80,8 @@ def main(args: argparse.Namespace = get_args()) -> 
None: scale=args.scale_obs, frame_stack=args.frames_stack, ) - args.state_shape = env.observation_space.shape or env.observation_space.n - args.action_shape = env.action_space.shape or env.action_space.n + args.state_shape = env.observation_space.shape or env.observation_space.n # type: ignore + args.action_shape = env.action_space.shape or env.action_space.n # type: ignore # should be N_FRAMES x H x W print("Observations shape:", args.state_shape) print("Actions shape:", args.action_shape) @@ -162,8 +162,8 @@ def save_best_fn(policy: Algorithm) -> None: torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth")) def stop_fn(mean_rewards: float) -> bool: - if env.spec.reward_threshold: - return mean_rewards >= env.spec.reward_threshold + if env.spec.reward_threshold: # type: ignore + return mean_rewards >= env.spec.reward_threshold # type: ignore if "Pong" in args.task: return mean_rewards >= 20 return False diff --git a/examples/atari/atari_ppo.py b/examples/atari/atari_ppo.py index 3b48d3b83..6c1295e0e 100644 --- a/examples/atari/atari_ppo.py +++ b/examples/atari/atari_ppo.py @@ -3,12 +3,19 @@ import os import pprint import sys +from collections.abc import Sequence +from typing import cast import numpy as np import torch from tianshou.data import Collector, CollectStats, VectorReplayBuffer -from tianshou.env.atari.atari_network import DQNet, layer_init, scale_obs +from tianshou.env.atari.atari_network import ( + DQNet, + ScaledObsInputModule, + layer_init, + scale_obs, +) from tianshou.env.atari.atari_wrapper import make_atari_env from tianshou.highlevel.logger import LoggerFactoryDefault from tianshou.policy import PPO @@ -105,8 +112,8 @@ def main(args: argparse.Namespace = get_args()) -> None: scale=0, frame_stack=args.frames_stack, ) - args.state_shape = env.observation_space.shape or env.observation_space.n - args.action_shape = env.action_space.shape or env.action_space.n + args.state_shape = cast(tuple[int, ...], env.observation_space.shape) + 
args.action_shape = cast(Sequence[int] | int, env.action_space.shape or env.action_space.n) # type: ignore # should be N_FRAMES x H x W print("Observations shape:", args.state_shape) print("Actions shape:", args.action_shape) @@ -115,6 +122,7 @@ def main(args: argparse.Namespace = get_args()) -> None: torch.manual_seed(args.seed) # define model c, h, w = args.state_shape + net: ScaledObsInputModule | DQNet net = DQNet( c=c, h=h, @@ -163,7 +171,7 @@ def main(args: argparse.Namespace = get_args()) -> None: if args.icm_lr_scale > 0: c, h, w = args.state_shape feature_net = DQNet(c=c, h=h, w=w, action_shape=args.action_shape, features_only=True) - action_dim = np.prod(args.action_shape) + action_dim = int(np.prod(args.action_shape)) feature_dim = feature_net.output_dim icm_net = IntrinsicCuriosityModule( feature_net=feature_net.net, @@ -172,7 +180,7 @@ def main(args: argparse.Namespace = get_args()) -> None: hidden_sizes=[args.hidden_size], ) icm_optim = AdamOptimizerFactory(lr=args.lr) - algorithm = ICMOnPolicyWrapper( # type: ignore[no-redef] + algorithm = ICMOnPolicyWrapper( # type: ignore[assignment] wrapped_algorithm=algorithm, model=icm_net, optim=icm_optim, @@ -222,8 +230,8 @@ def save_best_fn(policy: Algorithm) -> None: torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth")) def stop_fn(mean_rewards: float) -> bool: - if env.spec.reward_threshold: - return mean_rewards >= env.spec.reward_threshold + if env.spec.reward_threshold: # type: ignore + return mean_rewards >= env.spec.reward_threshold # type: ignore if "Pong" in args.task: return mean_rewards >= 20 return False diff --git a/examples/atari/atari_qrdqn.py b/examples/atari/atari_qrdqn.py index fbac8604d..7b0c61e7c 100644 --- a/examples/atari/atari_qrdqn.py +++ b/examples/atari/atari_qrdqn.py @@ -75,8 +75,8 @@ def main(args: argparse.Namespace = get_args()) -> None: scale=args.scale_obs, frame_stack=args.frames_stack, ) - args.state_shape = env.observation_space.shape or 
env.observation_space.n - args.action_shape = env.action_space.shape or env.action_space.n + args.state_shape = env.observation_space.shape or env.observation_space.n # type: ignore + args.action_shape = env.action_space.shape or env.action_space.n # type: ignore # should be N_FRAMES x H x W print("Observations shape:", args.state_shape) @@ -156,8 +156,8 @@ def save_best_fn(policy: Algorithm) -> None: torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth")) def stop_fn(mean_rewards: float) -> bool: - if env.spec.reward_threshold: - return mean_rewards >= env.spec.reward_threshold + if env.spec.reward_threshold: # type: ignore + return mean_rewards >= env.spec.reward_threshold # type: ignore if "Pong" in args.task: return mean_rewards >= 20 return False diff --git a/examples/atari/atari_rainbow.py b/examples/atari/atari_rainbow.py index 9717c6769..4c885f0cf 100644 --- a/examples/atari/atari_rainbow.py +++ b/examples/atari/atari_rainbow.py @@ -91,8 +91,8 @@ def test_rainbow(args: argparse.Namespace = get_args()) -> None: scale=args.scale_obs, frame_stack=args.frames_stack, ) - args.state_shape = env.observation_space.shape or env.observation_space.n - args.action_shape = env.action_space.shape or env.action_space.n + args.state_shape = env.observation_space.shape or env.observation_space.n # type: ignore + args.action_shape = env.action_space.shape or env.action_space.n # type: ignore # should be N_FRAMES x H x W print("Observations shape:", args.state_shape) print("Actions shape:", args.action_shape) @@ -190,8 +190,8 @@ def save_best_fn(policy: Algorithm) -> None: torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth")) def stop_fn(mean_rewards: float) -> bool: - if env.spec.reward_threshold: - return mean_rewards >= env.spec.reward_threshold + if env.spec.reward_threshold: # type: ignore + return mean_rewards >= env.spec.reward_threshold # type: ignore if "Pong" in args.task: return mean_rewards >= 20 return False diff --git 
a/examples/atari/atari_sac.py b/examples/atari/atari_sac.py index eebe3ca3b..7c44c7ad6 100644 --- a/examples/atari/atari_sac.py +++ b/examples/atari/atari_sac.py @@ -101,8 +101,8 @@ def test_discrete_sac(args: argparse.Namespace = get_args()) -> None: scale=args.scale_obs, frame_stack=args.frames_stack, ) - args.state_shape = env.observation_space.shape or env.observation_space.n - args.action_shape = env.action_space.shape or env.action_space.n + c, h, w = env.observation_space.shape # type: ignore + args.action_shape = env.action_space.n # type: ignore # should be N_FRAMES x H x W print("Observations shape:", args.state_shape) @@ -114,9 +114,10 @@ def test_discrete_sac(args: argparse.Namespace = get_args()) -> None: # define model net = DQNet( - *args.state_shape, - args.action_shape, - device=args.device, + c, + h, + w, + action_shape=args.action_shape, features_only=True, output_dim_added_layer=args.hidden_size, ) @@ -158,7 +159,7 @@ def test_discrete_sac(args: argparse.Namespace = get_args()) -> None: icm_net = IntrinsicCuriosityModule( feature_net=feature_net.net, feature_dim=feature_dim, - action_dim=action_dim, + action_dim=int(action_dim), hidden_sizes=[args.hidden_size], ) icm_optim = AdamOptimizerFactory(lr=args.actor_lr) @@ -214,8 +215,8 @@ def save_best_fn(policy: Algorithm) -> None: torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth")) def stop_fn(mean_rewards: float) -> bool: - if env.spec.reward_threshold: - return mean_rewards >= env.spec.reward_threshold + if env.spec.reward_threshold: # type: ignore + return mean_rewards >= env.spec.reward_threshold # type: ignore if "Pong" in args.task: return mean_rewards >= 20 return False diff --git a/examples/vizdoom/vizdoom_ppo.py b/examples/vizdoom/vizdoom_ppo.py index bd91b9f96..7a9a48974 100644 --- a/examples/vizdoom/vizdoom_ppo.py +++ b/examples/vizdoom/vizdoom_ppo.py @@ -115,7 +115,7 @@ def test_ppo(args: argparse.Namespace = get_args()) -> None: args.test_num, ) args.state_shape = 
env.observation_space.shape - args.action_shape = env.action_space.shape or env.action_space.n + args.action_shape = env.action_space.n # should be N_FRAMES x H x W print("Observations shape:", args.state_shape) print("Actions shape:", args.action_shape) @@ -130,7 +130,7 @@ def test_ppo(args: argparse.Namespace = get_args()) -> None: w=w, action_shape=args.action_shape, features_only=True, - output_dim=args.hidden_size, + output_dim_added_layer=args.hidden_size, ) actor = DiscreteActor(preprocess_net=net, action_shape=args.action_shape, softmax_output=False) critic = DiscreteCritic(preprocess_net=net) @@ -155,6 +155,7 @@ def dist(logits: torch.Tensor) -> Categorical: action_scaling=False, action_space=env.action_space, ) + algorithm: PPO | ICMOnPolicyWrapper algorithm = PPO( policy=policy, critic=critic, @@ -179,7 +180,7 @@ def dist(logits: torch.Tensor) -> Categorical: w=w, action_shape=args.action_shape, features_only=True, - output_dim=args.hidden_size, + output_dim_added_layer=args.hidden_size, ) action_dim = np.prod(args.action_shape) feature_dim = feature_net.output_dim diff --git a/test/modelbased/test_dqn_icm.py b/test/modelbased/test_dqn_icm.py index 583ddb64b..9813d5004 100644 --- a/test/modelbased/test_dqn_icm.py +++ b/test/modelbased/test_dqn_icm.py @@ -13,7 +13,7 @@ VectorReplayBuffer, ) from tianshou.env import DummyVectorEnv -from tianshou.policy import DQN, ICMOffPolicyWrapper +from tianshou.policy import DQN, Algorithm, ICMOffPolicyWrapper from tianshou.policy.modelfree.dqn import DQNPolicy from tianshou.policy.optim import AdamOptimizerFactory from tianshou.trainer import OffPolicyTrainerParams @@ -166,11 +166,11 @@ def test_dqn_icm(args: argparse.Namespace = get_args()) -> None: train_collector.collect(n_step=args.batch_size * args.training_num) # log - log_path = os.path.join(args.logdir, args.task, "dqn_icm") + log_path = str(os.path.join(args.logdir, args.task, "dqn_icm")) writer = SummaryWriter(log_path) logger = TensorboardLogger(writer) - 
def save_best_fn(policy: ICMOffPolicyWrapper) -> None: + def save_best_fn(policy: Algorithm) -> None: torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth")) def stop_fn(mean_rewards: float) -> bool: diff --git a/test/pettingzoo/pistonball.py b/test/pettingzoo/pistonball.py index a8702fb49..c958d3fbd 100644 --- a/test/pettingzoo/pistonball.py +++ b/test/pettingzoo/pistonball.py @@ -12,6 +12,7 @@ from tianshou.env import DummyVectorEnv from tianshou.env.pettingzoo_env import PettingZooEnv from tianshou.policy import DQN, Algorithm, MultiAgentOffPolicyAlgorithm +from tianshou.policy.base import OffPolicyAlgorithm from tianshou.policy.modelfree.dqn import DQNPolicy from tianshou.policy.optim import AdamOptimizerFactory from tianshou.trainer import OffPolicyTrainerParams @@ -76,7 +77,7 @@ def get_env(args: argparse.Namespace = get_args()) -> PettingZooEnv: def get_agents( args: argparse.Namespace = get_args(), - agents: list[Algorithm] | None = None, + agents: list[OffPolicyAlgorithm] | None = None, optims: list[torch.optim.Optimizer] | None = None, ) -> tuple[Algorithm, list[torch.optim.Optimizer] | None, list]: env = get_env() @@ -87,8 +88,11 @@ def get_agents( ) args.state_shape = observation_space.shape or int(observation_space.n) args.action_shape = env.action_space.shape or int(env.action_space.n) - if agents is None: - agents = [] + + if agents is not None: + algorithms = agents + else: + algorithms = [] optims = [] for _ in range(args.n_pistons): # model @@ -111,16 +115,16 @@ def get_agents( estimation_step=args.n_step, target_update_freq=args.target_update_freq, ) - agents.append(agent) + algorithms.append(agent) optims.append(optim) - ma_algorithm = MultiAgentOffPolicyAlgorithm(algorithms=agents, env=env) + ma_algorithm = MultiAgentOffPolicyAlgorithm(algorithms=algorithms, env=env) return ma_algorithm, optims, env.agents def train_agent( args: argparse.Namespace = get_args(), - agents: list[Algorithm] | None = None, + agents: 
list[OffPolicyAlgorithm] | None = None, optims: list[torch.optim.Optimizer] | None = None, ) -> tuple[InfoStats, Algorithm]: train_envs = DummyVectorEnv([get_env for _ in range(args.training_num)]) diff --git a/test/pettingzoo/pistonball_continuous.py b/test/pettingzoo/pistonball_continuous.py index f53094310..6a4d60a67 100644 --- a/test/pettingzoo/pistonball_continuous.py +++ b/test/pettingzoo/pistonball_continuous.py @@ -16,6 +16,7 @@ from tianshou.env import DummyVectorEnv from tianshou.env.pettingzoo_env import PettingZooEnv from tianshou.policy import PPO, Algorithm +from tianshou.policy.base import OnPolicyAlgorithm from tianshou.policy.modelfree.pg import ActorPolicy from tianshou.policy.multiagent.mapolicy import MultiAgentOnPolicyAlgorithm from tianshou.policy.optim import AdamOptimizerFactory @@ -49,7 +50,7 @@ def __init__( nn.Flatten(), ) with torch.no_grad(): - output_dim = np.prod(self.net(torch.zeros(1, c, h, w)).shape[1:]) + output_dim = np.prod(net(torch.zeros(1, c, h, w)).shape[1:]) super().__init__(int(output_dim)) self.device = device self.c = c @@ -143,7 +144,7 @@ def get_env(args: argparse.Namespace = get_args()) -> PettingZooEnv: def get_agents( args: argparse.Namespace = get_args(), - agents: list[Algorithm] | None = None, + agents: list[OnPolicyAlgorithm] | None = None, optims: list[torch.optim.Optimizer] | None = None, ) -> tuple[Algorithm, list[torch.optim.Optimizer] | None, list]: env = get_env() @@ -156,8 +157,10 @@ def get_agents( args.action_shape = env.action_space.shape or env.action_space.n args.max_action = env.action_space.high[0] - if agents is None: - agents = [] + if agents is not None: + algorithms = agents + else: + algorithms = [] optims = [] for _ in range(args.n_pistons): # model @@ -197,7 +200,7 @@ def dist(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution: action_scaling=True, action_bound_method="clip", ) - agent: PPO = PPO( + algorithm: PPO = PPO( policy=policy, critic=critic, optim=optim, @@ -215,11 
+218,11 @@ def dist(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution: gae_lambda=args.gae_lambda, ) - agents.append(agent) + algorithms.append(algorithm) optims.append(optim) ma_algorithm = MultiAgentOnPolicyAlgorithm( - algorithms=agents, + algorithms=algorithms, env=env, ) return ma_algorithm, optims, env.agents @@ -227,7 +230,7 @@ def dist(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution: def train_agent( args: argparse.Namespace = get_args(), - agents: list[Algorithm] | None = None, + agents: list[OnPolicyAlgorithm] | None = None, optims: list[torch.optim.Optimizer] | None = None, ) -> tuple[InfoStats, Algorithm]: train_envs = DummyVectorEnv([get_env for _ in range(args.training_num)]) diff --git a/test/pettingzoo/tic_tac_toe.py b/test/pettingzoo/tic_tac_toe.py index 0142f2d26..ce7aca006 100644 --- a/test/pettingzoo/tic_tac_toe.py +++ b/test/pettingzoo/tic_tac_toe.py @@ -19,6 +19,7 @@ MARLRandomDiscreteMaskedOffPolicyAlgorithm, MultiAgentOffPolicyAlgorithm, ) +from tianshou.policy.base import OffPolicyAlgorithm from tianshou.policy.modelfree.dqn import DQNPolicy from tianshou.policy.optim import AdamOptimizerFactory, OptimizerFactory from tianshou.trainer import OffPolicyTrainerParams @@ -100,8 +101,8 @@ def get_args() -> argparse.Namespace: def get_agents( args: argparse.Namespace = get_args(), - agent_learn: Algorithm | None = None, - agent_opponent: Algorithm | None = None, + agent_learn: OffPolicyAlgorithm | None = None, + agent_opponent: OffPolicyAlgorithm | None = None, optim: OptimizerFactory | None = None, ) -> tuple[MultiAgentOffPolicyAlgorithm, torch.optim.Optimizer | None, list]: env = get_env() @@ -156,10 +157,10 @@ def get_agents( def train_agent( args: argparse.Namespace = get_args(), - agent_learn: Algorithm | None = None, - agent_opponent: Algorithm | None = None, + agent_learn: OffPolicyAlgorithm | None = None, + agent_opponent: OffPolicyAlgorithm | None = None, optim: OptimizerFactory | None = None, -) -> 
tuple[InfoStats, Algorithm]: +) -> tuple[InfoStats, OffPolicyAlgorithm]: train_envs = DummyVectorEnv([get_env for _ in range(args.training_num)]) test_envs = DummyVectorEnv([get_env for _ in range(args.test_num)]) # seed @@ -230,8 +231,8 @@ def reward_metric(rews: np.ndarray) -> np.ndarray: def watch( args: argparse.Namespace = get_args(), - agent_learn: Algorithm | None = None, - agent_opponent: Algorithm | None = None, + agent_learn: OffPolicyAlgorithm | None = None, + agent_opponent: OffPolicyAlgorithm | None = None, ) -> None: env = DummyVectorEnv([partial(get_env, render_mode="human")]) policy, optim, agents = get_agents(args, agent_learn=agent_learn, agent_opponent=agent_opponent) diff --git a/tianshou/__init__.py b/tianshou/__init__.py index cfa162a43..73f74aa6f 100644 --- a/tianshou/__init__.py +++ b/tianshou/__init__.py @@ -3,10 +3,10 @@ __version__ = "1.2.0-dev" -def _register_log_config_callback(): +def _register_log_config_callback() -> None: from sensai.util import logging - def configure(): + def configure() -> None: logging.getLogger("numba").setLevel(logging.INFO) logging.set_configure_callback(configure) diff --git a/tianshou/highlevel/params/alpha.py b/tianshou/highlevel/params/alpha.py index 55787b7cb..39413d965 100644 --- a/tianshou/highlevel/params/alpha.py +++ b/tianshou/highlevel/params/alpha.py @@ -24,9 +24,9 @@ def __init__( self, lr: float = 3e-4, target_entropy_coefficient: float = -1.0, - log_alpha=0.0, + log_alpha: float = 0.0, optim: OptimizerFactoryFactory | None = None, - ): + ) -> None: """ :param lr: the learning rate for the optimizer of the alpha parameter :param target_entropy_coefficient: the coefficient with which to multiply the target entropy; diff --git a/tianshou/highlevel/params/policy_params.py b/tianshou/highlevel/params/policy_params.py index 11c17de6d..e4e209e01 100644 --- a/tianshou/highlevel/params/policy_params.py +++ b/tianshou/highlevel/params/policy_params.py @@ -101,7 +101,7 @@ class 
ParamTransformerOptimFactory(ParamTransformer): def __init__( self, - key_optim_factory_factory, + key_optim_factory_factory: str, key_lr: str, key_lr_scheduler_factory_factory: str, key_optim_output: str, diff --git a/tianshou/policy/base.py b/tianshou/policy/base.py index 5e93580bf..63dd05783 100644 --- a/tianshou/policy/base.py +++ b/tianshou/policy/base.py @@ -541,6 +541,8 @@ def state_dict(self, *args, destination=None, prefix="", keep_vars=False): # ty def load_state_dict( self, state_dict: Mapping[str, Any], strict: bool = True, assign: bool = False ) -> None: + # don't override type in annotation since it's is declared as Mapping in nn.Module + state_dict = cast(dict[str, Any], state_dict) # restore optimizer states optimizers_state_dict = state_dict.pop(self._STATE_DICT_KEY_OPTIMIZERS) for optim, optim_state in zip(self._optimizers, optimizers_state_dict, strict=True): diff --git a/tianshou/policy/modelfree/discrete_sac.py b/tianshou/policy/modelfree/discrete_sac.py index 887bbbdae..f40fd9f2b 100644 --- a/tianshou/policy/modelfree/discrete_sac.py +++ b/tianshou/policy/modelfree/discrete_sac.py @@ -7,6 +7,7 @@ from torch.distributions import Categorical from tianshou.data import Batch, to_torch +from tianshou.data.batch import BatchProtocol from tianshou.data.types import ( DistBatchProtocol, ObsBatchProtocol, @@ -58,7 +59,7 @@ def __init__( def forward( self, batch: ObsBatchProtocol, - state: dict | Batch | np.ndarray | None = None, + state: dict | BatchProtocol | np.ndarray | None = None, **kwargs: Any, ) -> Batch: logits_BA, hidden_BH = self.actor(batch.obs, state=state, info=batch.info) diff --git a/tianshou/policy/modelfree/rainbow.py b/tianshou/policy/modelfree/rainbow.py index 83b6605ab..c792225ae 100644 --- a/tianshou/policy/modelfree/rainbow.py +++ b/tianshou/policy/modelfree/rainbow.py @@ -49,5 +49,6 @@ def _update_with_batch( ) -> TRainbowTrainingStats: self._sample_noise(self.policy.model) if self.use_target_network and 
self._sample_noise(self.model_old): # type: ignore + assert self.model_old is not None self.model_old.train() # so that NoisyLinear takes effect return super()._update_with_batch(batch) diff --git a/tianshou/policy/modelfree/td3.py b/tianshou/policy/modelfree/td3.py index 274bdf110..3ece6b1c9 100644 --- a/tianshou/policy/modelfree/td3.py +++ b/tianshou/policy/modelfree/td3.py @@ -170,7 +170,7 @@ def __init__( gamma=gamma, estimation_step=estimation_step, ) - self.actor_old = self._add_lagged_network(self.policy.actor) + self.actor_old = self._add_lagged_network(self.policy.actor) # type: ignore[has-type] self.policy_noise = policy_noise self.update_actor_freq = update_actor_freq self.noise_clip = noise_clip From 84fa2bcf2326dfabadadb95a8c6b8a9d4624c6c3 Mon Sep 17 00:00:00 2001 From: Dominik Jain Date: Mon, 5 May 2025 13:54:07 +0200 Subject: [PATCH 109/230] v2: Improve description of parameter eps_clip --- tianshou/highlevel/params/policy_params.py | 8 +++++--- tianshou/policy/imitation/gail.py | 10 ++++++++-- tianshou/policy/modelfree/ppo.py | 10 ++++++++-- 3 files changed, 21 insertions(+), 7 deletions(-) diff --git a/tianshou/highlevel/params/policy_params.py b/tianshou/highlevel/params/policy_params.py index 574172977..29ffc20e1 100644 --- a/tianshou/highlevel/params/policy_params.py +++ b/tianshou/highlevel/params/policy_params.py @@ -403,11 +403,13 @@ class PPOParams(A2CParams): eps_clip: float = 0.2 """ determines the range of allowed change in the policy during a policy update: - The ratio between the probabilities indicated by the new and old policy is + The ratio of action probabilities indicated by the new and old policy is constrained to stay in the interval [1 - eps_clip, 1 + eps_clip]. Small values thus force the new policy to stay close to the old policy. - Typical values range between 0.1 and 0.3. - The optimal epsilon depends on the environment; more stochastic environments may need larger epsilons. 
+ Typical values range between 0.1 and 0.3, the value of 0.2 is recommended + in the original PPO paper. + The optimal value depends on the environment; more stochastic environments may + need larger values. """ dual_clip: float | None = None """ diff --git a/tianshou/policy/imitation/gail.py b/tianshou/policy/imitation/gail.py index b5b088d0b..1e7005160 100644 --- a/tianshou/policy/imitation/gail.py +++ b/tianshou/policy/imitation/gail.py @@ -69,8 +69,14 @@ def __init__( state dim plus action dim and output dim equals 1. :param disc_optim: the optimizer factory for the discriminator network. :param disc_update_num: the number of discriminator grad steps per model grad step. - :param eps_clip: :math:`\epsilon` in :math:`L_{CLIP}` in the original - paper. + :param eps_clip: determines the range of allowed change in the policy during a policy update: + The ratio of action probabilities indicated by the new and old policy is + constrained to stay in the interval [1 - eps_clip, 1 + eps_clip]. + Small values thus force the new policy to stay close to the old policy. + Typical values range between 0.1 and 0.3, the value of 0.2 is recommended + in the original PPO paper. + The optimal value depends on the environment; more stochastic environments may + need larger values. :param dual_clip: a parameter c mentioned in arXiv:1912.09729 Equ. 5, where c > 1 is a constant indicating the lower bound. Set to None to disable dual-clip PPO. diff --git a/tianshou/policy/modelfree/ppo.py b/tianshou/policy/modelfree/ppo.py index b83f9215a..04557b146 100644 --- a/tianshou/policy/modelfree/ppo.py +++ b/tianshou/policy/modelfree/ppo.py @@ -79,8 +79,14 @@ def __init__( :param policy: the policy containing the actor network. :param critic: the critic network. (s -> V(s)) :param optim: the optimizer factory for the policy's actor network and the critic networks. - :param eps_clip: :math:`\epsilon` in :math:`L_{CLIP}` in the original - paper. 
+ :param eps_clip: determines the range of allowed change in the policy during a policy update: + The ratio of action probabilities indicated by the new and old policy is + constrained to stay in the interval [1 - eps_clip, 1 + eps_clip]. + Small values thus force the new policy to stay close to the old policy. + Typical values range between 0.1 and 0.3, the value of 0.2 is recommended + in the original PPO paper. + The optimal value depends on the environment; more stochastic environments may + need larger values. :param dual_clip: a parameter c mentioned in arXiv:1912.09729 Equ. 5, where c > 1 is a constant indicating the lower bound. Set to None to disable dual-clip PPO. From 4006a0250623ab00b31131b59e572a22b4ba8eef Mon Sep 17 00:00:00 2001 From: Dominik Jain Date: Mon, 5 May 2025 16:01:47 +0200 Subject: [PATCH 110/230] Ignore .serena --- .gitignore | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/.gitignore b/.gitignore index e63e24b00..b81b171de 100644 --- a/.gitignore +++ b/.gitignore @@ -158,4 +158,7 @@ docs/conf.py # temporary scripts (for ad-hoc testing), temp folder /temp -/temp*.py \ No newline at end of file +/temp*.py + +# Serena +/.serena From 5f50547097e9c4950dcbdc5ab7a0b6445c966e4f Mon Sep 17 00:00:00 2001 From: Dominik Jain Date: Mon, 5 May 2025 16:02:56 +0200 Subject: [PATCH 111/230] v2: Improve description of parameter 'dual_clip' --- tianshou/highlevel/params/policy_params.py | 21 +++++++++++++-------- tianshou/policy/imitation/gail.py | 16 +++++++++++++--- tianshou/policy/modelfree/ppo.py | 16 +++++++++++++--- 3 files changed, 39 insertions(+), 14 deletions(-) diff --git a/tianshou/highlevel/params/policy_params.py b/tianshou/highlevel/params/policy_params.py index 29ffc20e1..e9add5441 100644 --- a/tianshou/highlevel/params/policy_params.py +++ b/tianshou/highlevel/params/policy_params.py @@ -413,14 +413,19 @@ class PPOParams(A2CParams): """ dual_clip: float | None = None """ - determines the lower bound clipping for the 
probability ratio - (corresponds to parameter c in arXiv:1912.09729, Equation 5). - If set to None, dual clipping is not used and the bounds described in parameter eps_clip apply. - If set to a float value c, the lower bound is changed from 1 - eps_clip to c, - where c < 1 - eps_clip. - Setting c > 0 reduces policy oscillation and further stabilizes training. - Typical values are between 0 and 0.5. Smaller values provide more stability. - Setting c = 0 yields PPO with only the upper bound. + a clipping parameter (denoted as c in the literature) that prevents + excessive pessimism in policy updates for negative-advantage actions. + Excessive pessimism occurs when the policy update too strongly reduces the probability + of selecting actions that led to negative advantages, potentially eliminating useful + actions based on limited negative experiences. + When enabled (c > 1), the objective for negative advantages becomes: + max(min(r(θ)A, clip(r(θ), 1-ε, 1+ε)A), c*A), where min(r(θ)A, clip(r(θ), 1-ε, 1+ε)A) + is the original single-clipping objective determined by `eps_clip`. + This creates a floor on negative policy gradients, maintaining some probability + of exploring actions despite initial negative outcomes. + Larger values (e.g., 2.0 to 5.0) maintain more exploration, while values closer + to 1.0 provide less protection against pessimistic updates. + Set to None to disable dual clipping. """ value_clip: bool = False """ diff --git a/tianshou/policy/imitation/gail.py b/tianshou/policy/imitation/gail.py index 1e7005160..1455e086e 100644 --- a/tianshou/policy/imitation/gail.py +++ b/tianshou/policy/imitation/gail.py @@ -77,9 +77,19 @@ def __init__( in the original PPO paper. The optimal value depends on the environment; more stochastic environments may need larger values. - :param dual_clip: a parameter c mentioned in arXiv:1912.09729 Equ. 5, - where c > 1 is a constant indicating the lower bound. Set to None - to disable dual-clip PPO. 
+ :param dual_clip: a clipping parameter (denoted as c in the literature) that prevents + excessive pessimism in policy updates for negative-advantage actions. + Excessive pessimism occurs when the policy update too strongly reduces the probability + of selecting actions that led to negative advantages, potentially eliminating useful + actions based on limited negative experiences. + When enabled (c > 1), the objective for negative advantages becomes: + max(min(r(θ)A, clip(r(θ), 1-ε, 1+ε)A), c*A), where min(r(θ)A, clip(r(θ), 1-ε, 1+ε)A) + is the original single-clipping objective determined by `eps_clip`. + This creates a floor on negative policy gradients, maintaining some probability + of exploring actions despite initial negative outcomes. + Larger values (e.g., 2.0 to 5.0) maintain more exploration, while values closer + to 1.0 provide less protection against pessimistic updates. + Set to None to disable dual clipping. :param value_clip: a parameter mentioned in arXiv:1811.02553v3 Sec. 4.1. :param advantage_normalization: whether to do per mini-batch advantage normalization. diff --git a/tianshou/policy/modelfree/ppo.py b/tianshou/policy/modelfree/ppo.py index 04557b146..205f75f09 100644 --- a/tianshou/policy/modelfree/ppo.py +++ b/tianshou/policy/modelfree/ppo.py @@ -87,9 +87,19 @@ def __init__( in the original PPO paper. The optimal value depends on the environment; more stochastic environments may need larger values. - :param dual_clip: a parameter c mentioned in arXiv:1912.09729 Equ. 5, - where c > 1 is a constant indicating the lower bound. Set to None - to disable dual-clip PPO. + :param dual_clip: a clipping parameter (denoted as c in the literature) that prevents + excessive pessimism in policy updates for negative-advantage actions. 
+ Excessive pessimism occurs when the policy update too strongly reduces the probability + of selecting actions that led to negative advantages, potentially eliminating useful + actions based on limited negative experiences. + When enabled (c > 1), the objective for negative advantages becomes: + max(min(r(θ)A, clip(r(θ), 1-ε, 1+ε)A), c*A), where min(r(θ)A, clip(r(θ), 1-ε, 1+ε)A) + is the original single-clipping objective determined by `eps_clip`. + This creates a floor on negative policy gradients, maintaining some probability + of exploring actions despite initial negative outcomes. + Larger values (e.g., 2.0 to 5.0) maintain more exploration, while values closer + to 1.0 provide less protection against pessimistic updates. + Set to None to disable dual clipping. :param value_clip: a parameter mentioned in arXiv:1811.02553v3 Sec. 4.1. :param advantage_normalization: whether to do per mini-batch advantage normalization. From 86c0a3acdc77c99b45124902b84f4e53abe38a95 Mon Sep 17 00:00:00 2001 From: Dominik Jain Date: Mon, 5 May 2025 16:17:47 +0200 Subject: [PATCH 112/230] v2: Improve description of parameter 'value_clip' --- tianshou/highlevel/params/policy_params.py | 14 ++++++++++---- tianshou/policy/imitation/gail.py | 11 ++++++++++- tianshou/policy/modelfree/ppo.py | 11 ++++++++++- 3 files changed, 30 insertions(+), 6 deletions(-) diff --git a/tianshou/highlevel/params/policy_params.py b/tianshou/highlevel/params/policy_params.py index e9add5441..f371e84fb 100644 --- a/tianshou/highlevel/params/policy_params.py +++ b/tianshou/highlevel/params/policy_params.py @@ -429,10 +429,16 @@ class PPOParams(A2CParams): """ value_clip: bool = False """ - whether to apply clipping of the predicted value function during policy learning. - Value clipping discourages large changes in value predictions between updates. - Inaccurate value predictions can lead to bad policy updates, which can cause training instability. 
- Clipping values prevents sporadic large errors from skewing policy updates too much. + flag indicating whether to enable clipping for value function updates. + When enabled, restricts how much the value function estimate can change from its + previous prediction, using the same clipping range as the policy updates (eps_clip). + This stabilizes training by preventing large fluctuations in value estimates, + particularly useful in environments with high reward variance. + The clipped value loss uses a pessimistic approach, taking the maximum of the + original and clipped value errors: + max((returns - value)², (returns - v_clipped)²) + Setting to True often improves training stability but may slow convergence. + Implementation follows the approach mentioned in arXiv:1811.02553v3 Sec. 4.1. """ advantage_normalization: bool = True """whether to apply per mini-batch advantage normalization.""" diff --git a/tianshou/policy/imitation/gail.py b/tianshou/policy/imitation/gail.py index 1455e086e..92ffb7f72 100644 --- a/tianshou/policy/imitation/gail.py +++ b/tianshou/policy/imitation/gail.py @@ -90,7 +90,16 @@ def __init__( Larger values (e.g., 2.0 to 5.0) maintain more exploration, while values closer to 1.0 provide less protection against pessimistic updates. Set to None to disable dual clipping. - :param value_clip: a parameter mentioned in arXiv:1811.02553v3 Sec. 4.1. + :param value_clip: flag indicating whether to enable clipping for value function updates. + When enabled, restricts how much the value function estimate can change from its + previous prediction, using the same clipping range as the policy updates (eps_clip). + This stabilizes training by preventing large fluctuations in value estimates, + particularly useful in environments with high reward variance. 
+ The clipped value loss uses a pessimistic approach, taking the maximum of the + original and clipped value errors: + max((returns - value)², (returns - v_clipped)²) + Setting to True often improves training stability but may slow convergence. + Implementation follows the approach mentioned in arXiv:1811.02553v3 Sec. 4.1. :param advantage_normalization: whether to do per mini-batch advantage normalization. :param recompute_advantage: whether to recompute advantage every update diff --git a/tianshou/policy/modelfree/ppo.py b/tianshou/policy/modelfree/ppo.py index 205f75f09..d3ad840d0 100644 --- a/tianshou/policy/modelfree/ppo.py +++ b/tianshou/policy/modelfree/ppo.py @@ -100,7 +100,16 @@ def __init__( Larger values (e.g., 2.0 to 5.0) maintain more exploration, while values closer to 1.0 provide less protection against pessimistic updates. Set to None to disable dual clipping. - :param value_clip: a parameter mentioned in arXiv:1811.02553v3 Sec. 4.1. + :param value_clip: flag indicating whether to enable clipping for value function updates. + When enabled, restricts how much the value function estimate can change from its + previous prediction, using the same clipping range as the policy updates (eps_clip). + This stabilizes training by preventing large fluctuations in value estimates, + particularly useful in environments with high reward variance. + The clipped value loss uses a pessimistic approach, taking the maximum of the + original and clipped value errors: + max((returns - value)², (returns - v_clipped)²) + Setting to True often improves training stability but may slow convergence. + Implementation follows the approach mentioned in arXiv:1811.02553v3 Sec. 4.1. :param advantage_normalization: whether to do per mini-batch advantage normalization. 
:param recompute_advantage: whether to recompute advantage every update From 0f226ee03e00090740632ce54b8814bfc41beb4b Mon Sep 17 00:00:00 2001 From: Dominik Jain Date: Mon, 5 May 2025 17:02:54 +0200 Subject: [PATCH 113/230] v2: Improve description of parameter 'alpha' --- tianshou/highlevel/params/policy_params.py | 44 ++++++++++++---------- tianshou/policy/modelfree/redq.py | 13 ++++++- tianshou/policy/modelfree/sac.py | 13 ++++++- 3 files changed, 46 insertions(+), 24 deletions(-) diff --git a/tianshou/highlevel/params/policy_params.py b/tianshou/highlevel/params/policy_params.py index f371e84fb..b79a8424c 100644 --- a/tianshou/highlevel/params/policy_params.py +++ b/tianshou/highlevel/params/policy_params.py @@ -538,6 +538,26 @@ def _get_param_transformers(self) -> list[ParamTransformer]: ] +class ParamsMixinAlpha(GetParamTransformersProtocol): + alpha: float | AutoAlphaFactory = 0.2 + """ + the entropy regularization coefficient, which balances exploration and exploitation. + This coefficient controls how much the agent values randomness in its policy versus + pursuing higher rewards. + Higher values (e.g., 0.5-1.0) strongly encourage exploration by rewarding the agent + for maintaining diverse action choices, even if this means selecting some lower-value actions. + Lower values (e.g., 0.01-0.1) prioritize exploitation, allowing the policy to become + more focused on the highest-value actions. + A value of 0 would completely remove entropy regularization, potentially leading to + premature convergence to suboptimal deterministic policies. + Can be provided as a fixed float (0.2 is a reasonable default) or via a factory + to support automatic tuning during training. 
+ """ + + def _get_param_transformers(self) -> list[ParamTransformer]: + return [ParamTransformerAutoAlpha("alpha")] + + @dataclass(kw_only=True) class _SACParams( Params, @@ -546,20 +566,12 @@ class _SACParams( ParamsMixinEstimationStep, ParamsMixinTau, ParamsMixinDeterministicEval, + ParamsMixinAlpha, ): - alpha: float | AutoAlphaFactory = 0.2 - """ - controls the relative importance (coefficient) of the entropy term in the loss function. - This can be a constant or a factory for the creation of a representation that allows the - parameter to be automatically tuned; - use :class:`tianshou.highlevel.params.alpha.AutoAlphaFactoryDefault` for the standard - auto-adjusted alpha. - """ - def _get_param_transformers(self) -> list[ParamTransformer]: transformers = super()._get_param_transformers() transformers.extend(ParamsMixinActorAndDualCritics._get_param_transformers(self)) - transformers.append(ParamTransformerAutoAlpha("alpha")) + transformers.extend(ParamsMixinAlpha._get_param_transformers(self)) return transformers @@ -661,26 +673,18 @@ def _get_param_transformers(self) -> list[ParamTransformer]: @dataclass(kw_only=True) -class REDQParams(DDPGParams, ParamsMixinDeterministicEval): +class REDQParams(DDPGParams, ParamsMixinDeterministicEval, ParamsMixinAlpha): ensemble_size: int = 10 """the number of sub-networks in the critic ensemble""" subset_size: int = 2 """the number of networks in the subset""" - alpha: float | AutoAlphaFactory = 0.2 - """ - controls the relative importance (coefficient) of the entropy term in the loss function. - This can be a constant or a factory for the creation of a representation that allows the - parameter to be automatically tuned; - use :class:`tianshou.highlevel.params.alpha.AutoAlphaFactoryDefault` for the standard - auto-adjusted alpha. 
- """ actor_delay: int = 20 """the number of critic updates before an actor update""" target_mode: Literal["mean", "min"] = "min" def _get_param_transformers(self) -> list[ParamTransformer]: transformers = super()._get_param_transformers() - transformers.append(ParamTransformerAutoAlpha("alpha")) + transformers.extend(ParamsMixinAlpha._get_param_transformers(self)) return transformers diff --git a/tianshou/policy/modelfree/redq.py b/tianshou/policy/modelfree/redq.py index b739053bd..23701120e 100644 --- a/tianshou/policy/modelfree/redq.py +++ b/tianshou/policy/modelfree/redq.py @@ -174,8 +174,17 @@ def __init__( potentially improving performance in tasks where delayed rewards are important but increasing training variance by incorporating more environmental stochasticity. Typically set between 0.9 and 0.99 for most reinforcement learning tasks - :param alpha: the entropy regularization coefficient alpha or an object - which can be used to automatically tune it (e.g. an instance of `AutoAlpha`). + :param alpha: the entropy regularization coefficient, which balances exploration and exploitation. + This coefficient controls how much the agent values randomness in its policy versus + pursuing higher rewards. + Higher values (e.g., 0.5-1.0) strongly encourage exploration by rewarding the agent + for maintaining diverse action choices, even if this means selecting some lower-value actions. + Lower values (e.g., 0.01-0.1) prioritize exploitation, allowing the policy to become + more focused on the highest-value actions. + A value of 0 would completely remove entropy regularization, potentially leading to + premature convergence to suboptimal deterministic policies. + Can be provided as a fixed float (0.2 is a reasonable default) or as an instance of, + in particular, class `AutoAlpha` for automatic tuning during training. :param estimation_step: the number of future steps (> 0) to consider when computing temporal difference (TD) targets. 
Controls the balance between TD learning and Monte Carlo methods: higher values reduce bias (by relying less on potentially inaccurate value estimates) diff --git a/tianshou/policy/modelfree/sac.py b/tianshou/policy/modelfree/sac.py index 49ead8732..d666f8e9d 100644 --- a/tianshou/policy/modelfree/sac.py +++ b/tianshou/policy/modelfree/sac.py @@ -271,8 +271,17 @@ def __init__( potentially improving performance in tasks where delayed rewards are important but increasing training variance by incorporating more environmental stochasticity. Typically set between 0.9 and 0.99 for most reinforcement learning tasks - :param alpha: the entropy regularization coefficient alpha or an object - which can be used to automatically tune it (e.g. an instance of `AutoAlpha`). + :param alpha: the entropy regularization coefficient, which balances exploration and exploitation. + This coefficient controls how much the agent values randomness in its policy versus + pursuing higher rewards. + Higher values (e.g., 0.5-1.0) strongly encourage exploration by rewarding the agent + for maintaining diverse action choices, even if this means selecting some lower-value actions. + Lower values (e.g., 0.01-0.1) prioritize exploitation, allowing the policy to become + more focused on the highest-value actions. + A value of 0 would completely remove entropy regularization, potentially leading to + premature convergence to suboptimal deterministic policies. + Can be provided as a fixed float (0.2 is a reasonable default) or as an instance of, + in particular, class `AutoAlpha` for automatic tuning during training. :param estimation_step: the number of future steps (> 0) to consider when computing temporal difference (TD) targets. 
Controls the balance between TD learning and Monte Carlo methods: higher values reduce bias (by relying less on potentially inaccurate value estimates) From beb9ef8f5b4b9ca78397b53209a101b0147ba757 Mon Sep 17 00:00:00 2001 From: Michael Panchenko Date: Mon, 5 May 2025 17:40:21 +0200 Subject: [PATCH 114/230] v2: handle case of empty module parameters in torch_device --- tianshou/utils/torch_utils.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/tianshou/utils/torch_utils.py b/tianshou/utils/torch_utils.py index 37df7fe46..723d19fad 100644 --- a/tianshou/utils/torch_utils.py +++ b/tianshou/utils/torch_utils.py @@ -78,5 +78,11 @@ def create_uniform_action_dist( def torch_device(module: torch.nn.Module) -> torch.device: - """Gets the device of a torch module.""" - return next(module.parameters()).device + """Gets the device of a torch module by retrieving the device of the parameters. + + If parameters are empty, it returns the CPU device as a fallback. + """ + try: + return next(module.parameters()).device + except StopIteration: + return torch.device("cpu") From 0293e0680c961ab25b35678e98cf822add82139a Mon Sep 17 00:00:00 2001 From: Michael Panchenko Date: Mon, 5 May 2025 17:43:28 +0200 Subject: [PATCH 115/230] v2: remove the generic-specification pattern for TrainingStats Instead, the particular TrainingStats subclass is just directly annotated. 
Within Tianshou there was no interface that required knowing the specification of TrainingStats of an Algorithm implementation --- test/continuous/test_npg.py | 3 +- test/discrete/test_discrete_sac.py | 3 +- test/offline/test_bcq.py | 4 +- test/offline/test_cql.py | 3 +- tianshou/data/stats.py | 4 ++ tianshou/policy/base.py | 69 +++++++++++------------ tianshou/policy/imitation/base.py | 4 +- tianshou/policy/imitation/bcq.py | 9 ++- tianshou/policy/imitation/cql.py | 11 ++-- tianshou/policy/imitation/discrete_bcq.py | 16 ++---- tianshou/policy/imitation/discrete_cql.py | 18 ++---- tianshou/policy/imitation/discrete_crr.py | 16 ++---- tianshou/policy/imitation/gail.py | 18 +++--- tianshou/policy/imitation/td3_bc.py | 17 +----- tianshou/policy/modelbased/icm.py | 17 ++---- tianshou/policy/modelbased/psrl.py | 11 ++-- tianshou/policy/modelfree/a2c.py | 21 ++++--- tianshou/policy/modelfree/bdqn.py | 19 ++----- tianshou/policy/modelfree/c51.py | 19 ++----- tianshou/policy/modelfree/ddpg.py | 15 ++--- tianshou/policy/modelfree/discrete_sac.py | 6 +- tianshou/policy/modelfree/dqn.py | 23 +++----- tianshou/policy/modelfree/fqf.py | 16 +++--- tianshou/policy/modelfree/iqn.py | 20 ++----- tianshou/policy/modelfree/npg.py | 27 ++++----- tianshou/policy/modelfree/pg.py | 20 ++++--- tianshou/policy/modelfree/ppo.py | 51 +++-------------- tianshou/policy/modelfree/qrdqn.py | 19 ++----- tianshou/policy/modelfree/rainbow.py | 19 +++---- tianshou/policy/modelfree/redq.py | 6 +- tianshou/policy/modelfree/sac.py | 2 +- tianshou/policy/modelfree/td3.py | 16 ++---- tianshou/policy/modelfree/trpo.py | 17 +++--- tianshou/policy/multiagent/mapolicy.py | 4 +- tianshou/policy/random.py | 9 +-- tianshou/utils/net/common.py | 20 +++---- 36 files changed, 209 insertions(+), 363 deletions(-) diff --git a/test/continuous/test_npg.py b/test/continuous/test_npg.py index ca2df5a22..aaa931e77 100644 --- a/test/continuous/test_npg.py +++ b/test/continuous/test_npg.py @@ -12,7 +12,6 @@ from 
tianshou.env import DummyVectorEnv from tianshou.policy import NPG from tianshou.policy.base import Algorithm -from tianshou.policy.modelfree.npg import NPGTrainingStats from tianshou.policy.modelfree.pg import ActorPolicy from tianshou.policy.optim import AdamOptimizerFactory from tianshou.trainer.base import OnPolicyTrainerParams @@ -112,7 +111,7 @@ def dist(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution: action_space=env.action_space, deterministic_eval=True, ) - algorithm: NPG[NPGTrainingStats] = NPG( + algorithm: NPG = NPG( policy=policy, critic=critic, optim=AdamOptimizerFactory(lr=args.lr), diff --git a/test/discrete/test_discrete_sac.py b/test/discrete/test_discrete_sac.py index 2744d5e02..ccbeffde1 100644 --- a/test/discrete/test_discrete_sac.py +++ b/test/discrete/test_discrete_sac.py @@ -12,7 +12,6 @@ from tianshou.policy.base import Algorithm from tianshou.policy.modelfree.discrete_sac import ( DiscreteSACPolicy, - DiscreteSACTrainingStats, ) from tianshou.policy.modelfree.sac import AutoAlpha from tianshou.policy.optim import AdamOptimizerFactory @@ -103,7 +102,7 @@ def test_discrete_sac(args: argparse.Namespace = get_args()) -> None: actor=actor, action_space=env.action_space, ) - algorithm: DiscreteSAC[DiscreteSACTrainingStats] = DiscreteSAC( + algorithm = DiscreteSAC( policy=policy, policy_optim=actor_optim, critic=critic1, diff --git a/test/offline/test_bcq.py b/test/offline/test_bcq.py index 0516dc039..e8ea22e63 100644 --- a/test/offline/test_bcq.py +++ b/test/offline/test_bcq.py @@ -12,7 +12,7 @@ from tianshou.data import Collector, CollectStats, VectorReplayBuffer from tianshou.env import DummyVectorEnv from tianshou.policy import BCQ, Algorithm -from tianshou.policy.imitation.bcq import BCQPolicy, BCQTrainingStats +from tianshou.policy.imitation.bcq import BCQPolicy from tianshou.policy.optim import AdamOptimizerFactory from tianshou.trainer.base import OfflineTrainerParams from tianshou.utils import TensorboardLogger @@ -145,7 
+145,7 @@ def test_bcq(args: argparse.Namespace = get_args()) -> None: vae=vae, action_space=env.action_space, ) - algorithm: BCQ[BCQTrainingStats] = BCQ( + algorithm = BCQ( policy=policy, actor_perturbation_optim=actor_optim, critic_optim=critic_optim, diff --git a/test/offline/test_cql.py b/test/offline/test_cql.py index 1f31515b9..437723858 100644 --- a/test/offline/test_cql.py +++ b/test/offline/test_cql.py @@ -12,7 +12,6 @@ from tianshou.data import Collector, CollectStats, VectorReplayBuffer from tianshou.env import DummyVectorEnv from tianshou.policy import CQL, Algorithm -from tianshou.policy.imitation.cql import CQLTrainingStats from tianshou.policy.modelfree.sac import AutoAlpha, SACPolicy from tianshou.policy.optim import AdamOptimizerFactory from tianshou.trainer import OfflineTrainerParams @@ -141,7 +140,7 @@ def test_cql(args: argparse.Namespace = get_args()) -> None: action_scaling=False, action_space=env.action_space, ) - algorithm: CQL[CQLTrainingStats] = CQL( + algorithm = CQL( policy=policy, policy_optim=actor_optim, critic=critic, diff --git a/tianshou/data/stats.py b/tianshou/data/stats.py index 964479113..51fba5c2d 100644 --- a/tianshou/data/stats.py +++ b/tianshou/data/stats.py @@ -42,6 +42,10 @@ def from_sequence(cls, sequence: Sequence[float | int] | np.ndarray) -> "Sequenc min=float(np.min(sequence)), ) + @classmethod + def from_single_value(cls, value: float | int) -> "SequenceSummaryStats": + return cls(mean=value, std=0.0, max=value, min=value) + def compute_dim_to_summary_stats( arr: Sequence[Sequence[float]] | np.ndarray, diff --git a/tianshou/policy/base.py b/tianshou/policy/base.py index 308a4b382..e31984209 100644 --- a/tianshou/policy/base.py +++ b/tianshou/policy/base.py @@ -13,6 +13,9 @@ from numpy.typing import ArrayLike from overrides import override from torch import nn +from torch.nn.modules.module import ( + _IncompatibleKeys, # we have to do this since we override load_state_dict +) from torch.optim.lr_scheduler import 
LRScheduler from tianshou.data import ReplayBuffer, SequenceSummaryStats, to_numpy, to_torch_as @@ -147,9 +150,6 @@ def __setattr__(self, name: str, value: Any) -> None: super().__setattr__(name, value) -TTrainingStats = TypeVar("TTrainingStats", bound=TrainingStats) - - class Policy(nn.Module, ABC): """Represents a policy, which provides the fundamental mapping from observations to actions.""" @@ -468,7 +468,7 @@ def _update_lagged_network_weights(self) -> None: TTrainerParams = TypeVar("TTrainerParams", bound="TrainerParams") -class Algorithm(torch.nn.Module, Generic[TPolicy, TTrainerParams, TTrainingStats], ABC): +class Algorithm(torch.nn.Module, Generic[TPolicy, TTrainerParams], ABC): """ The base class for reinforcement learning algorithms in Tianshou. @@ -561,7 +561,7 @@ def state_dict(self, *args, destination=None, prefix="", keep_vars=False): # ty def load_state_dict( self, state_dict: Mapping[str, Any], strict: bool = True, assign: bool = False - ) -> None: + ) -> _IncompatibleKeys: # don't override type in annotation since it's is declared as Mapping in nn.Module state_dict = cast(dict[str, Any], state_dict) # restore optimizer states @@ -569,7 +569,7 @@ def load_state_dict( for optim, optim_state in zip(self._optimizers, optimizers_state_dict, strict=True): optim.load_state_dict(optim_state) - super().load_state_dict(state_dict, strict=strict, assign=assign) + return super().load_state_dict(state_dict, strict=strict, assign=assign) def preprocess_batch( self, @@ -616,8 +616,8 @@ def _update( self, sample_size: int | None, buffer: ReplayBuffer | None, - update_with_batch_fn: Callable[[RolloutBatchProtocol], TTrainingStats], - ) -> TTrainingStats: + update_with_batch_fn: Callable[[RolloutBatchProtocol], TrainingStats], + ) -> TrainingStats: """Orchestrates an update step. 
An update involves three algorithm-specific sub-steps: @@ -646,7 +646,7 @@ def _update( ) if buffer is None: - return TrainingStats() # type: ignore[return-value] + return TrainingStats() start_time = time.time() batch, indices = buffer.sample(sample_size) batch = self.preprocess_batch(batch, buffer, indices) @@ -858,8 +858,8 @@ def run_training(self, params: TTrainerParams) -> "InfoStats": class OnPolicyAlgorithm( - Algorithm[TPolicy, "OnPolicyTrainerParams", TTrainingStats], - Generic[TPolicy, TTrainingStats], + Algorithm[TPolicy, "OnPolicyTrainerParams"], + Generic[TPolicy], ABC, ): """Base class for on-policy RL algorithms.""" @@ -872,7 +872,7 @@ def create_trainer(self, params: "OnPolicyTrainerParams") -> "OnPolicyTrainer": @abstractmethod def _update_with_batch( self, batch: RolloutBatchProtocol, batch_size: int | None, repeat: int - ) -> TTrainingStats: + ) -> TrainingStats: """Performs an update step based on the given batch of data, updating the network parameters. @@ -888,7 +888,7 @@ def update( buffer: ReplayBuffer, batch_size: int | None, repeat: int, - ) -> TTrainingStats: + ) -> TrainingStats: update_with_batch_fn = lambda batch: self._update_with_batch( batch=batch, batch_size=batch_size, repeat=repeat ) @@ -898,8 +898,8 @@ def update( class OffPolicyAlgorithm( - Algorithm[TPolicy, "OffPolicyTrainerParams", TTrainingStats], - Generic[TPolicy, TTrainingStats], + Algorithm[TPolicy, "OffPolicyTrainerParams"], + Generic[TPolicy], ABC, ): """Base class for off-policy RL algorithms.""" @@ -913,7 +913,7 @@ def create_trainer(self, params: "OffPolicyTrainerParams") -> "OffPolicyTrainer" def _update_with_batch( self, batch: RolloutBatchProtocol, - ) -> TTrainingStats: + ) -> TrainingStats: """Performs an update step based on the given batch of data, updating the network parameters. 
@@ -926,7 +926,7 @@ def update( self, buffer: ReplayBuffer, sample_size: int | None, - ) -> TTrainingStats: + ) -> TrainingStats: update_with_batch_fn = lambda batch: self._update_with_batch(batch) return super()._update( sample_size=sample_size, buffer=buffer, update_with_batch_fn=update_with_batch_fn @@ -934,8 +934,8 @@ def update( class OfflineAlgorithm( - Algorithm[TPolicy, "OfflineTrainerParams", TTrainingStats], - Generic[TPolicy, TTrainingStats], + Algorithm[TPolicy, "OfflineTrainerParams"], + Generic[TPolicy], ABC, ): """Base class for offline RL algorithms.""" @@ -959,7 +959,7 @@ def create_trainer(self, params: "OfflineTrainerParams") -> "OfflineTrainer": def _update_with_batch( self, batch: RolloutBatchProtocol, - ) -> TTrainingStats: + ) -> TrainingStats: """Performs an update step based on the given batch of data, updating the network parameters. @@ -972,19 +972,16 @@ def update( self, buffer: ReplayBuffer, sample_size: int | None, - ) -> TTrainingStats: + ) -> TrainingStats: update_with_batch_fn = lambda batch: self._update_with_batch(batch) return super()._update( sample_size=sample_size, buffer=buffer, update_with_batch_fn=update_with_batch_fn ) -TWrappedAlgorthmTrainingStats = TypeVar("TWrappedAlgorthmTrainingStats", bound=TrainingStats) - - class OnPolicyWrapperAlgorithm( - OnPolicyAlgorithm[TPolicy, TTrainingStats], - Generic[TPolicy, TTrainingStats, TWrappedAlgorthmTrainingStats], + OnPolicyAlgorithm[TPolicy], + Generic[TPolicy], ABC, ): """ @@ -996,7 +993,7 @@ class OnPolicyWrapperAlgorithm( def __init__( self, - wrapped_algorithm: OnPolicyAlgorithm[TPolicy, TWrappedAlgorthmTrainingStats], + wrapped_algorithm: OnPolicyAlgorithm[TPolicy], ): super().__init__(policy=wrapped_algorithm.policy) self.wrapped_algorithm = wrapped_algorithm @@ -1021,7 +1018,7 @@ def postprocess_batch( def _update_with_batch( self, batch: RolloutBatchProtocol, batch_size: int | None, repeat: int - ) -> TTrainingStats: + ) -> TrainingStats: """Performs the update as 
defined by the wrapped algorithm, followed by the wrapper's update.""" original_stats = self.wrapped_algorithm._update_with_batch( batch, batch_size=batch_size, repeat=repeat @@ -1034,14 +1031,14 @@ def _wrapper_update_with_batch( batch: RolloutBatchProtocol, batch_size: int | None, repeat: int, - original_stats: TWrappedAlgorthmTrainingStats, - ) -> TTrainingStats: + original_stats: TrainingStats, + ) -> TrainingStats: pass class OffPolicyWrapperAlgorithm( - OffPolicyAlgorithm[TPolicy, TTrainingStats], - Generic[TPolicy, TTrainingStats, TWrappedAlgorthmTrainingStats], + OffPolicyAlgorithm[TPolicy], + Generic[TPolicy], ABC, ): """ @@ -1053,7 +1050,7 @@ class OffPolicyWrapperAlgorithm( def __init__( self, - wrapped_algorithm: OffPolicyAlgorithm[TPolicy, TWrappedAlgorthmTrainingStats], + wrapped_algorithm: OffPolicyAlgorithm[TPolicy], ): super().__init__(policy=wrapped_algorithm.policy) self.wrapped_algorithm = wrapped_algorithm @@ -1079,15 +1076,15 @@ def postprocess_batch( def _update_with_batch( self, batch: RolloutBatchProtocol, - ) -> TTrainingStats: + ) -> TrainingStats: """Performs the update as defined by the wrapped algorithm, followed by the wrapper's update .""" original_stats = self.wrapped_algorithm._update_with_batch(batch) return self._wrapper_update_with_batch(batch, original_stats) @abstractmethod def _wrapper_update_with_batch( - self, batch: RolloutBatchProtocol, original_stats: TWrappedAlgorthmTrainingStats - ) -> TTrainingStats: + self, batch: RolloutBatchProtocol, original_stats: TrainingStats + ) -> TrainingStats: pass diff --git a/tianshou/policy/imitation/base.py b/tianshou/policy/imitation/base.py index 3ee2f4a35..a13d50809 100644 --- a/tianshou/policy/imitation/base.py +++ b/tianshou/policy/imitation/base.py @@ -128,7 +128,7 @@ def _imitation_update( class OffPolicyImitationLearning( - OffPolicyAlgorithm[ImitationPolicy, ImitationTrainingStats], + OffPolicyAlgorithm[ImitationPolicy], ImitationLearningAlgorithmMixin, ): """Implementation of 
off-policy vanilla imitation learning.""" @@ -156,7 +156,7 @@ def _update_with_batch( class OfflineImitationLearning( - OfflineAlgorithm[ImitationPolicy, ImitationTrainingStats], + OfflineAlgorithm[ImitationPolicy], ImitationLearningAlgorithmMixin, ): """Implementation of offline vanilla imitation learning.""" diff --git a/tianshou/policy/imitation/bcq.py b/tianshou/policy/imitation/bcq.py index dcb8051f0..91005a6a3 100644 --- a/tianshou/policy/imitation/bcq.py +++ b/tianshou/policy/imitation/bcq.py @@ -1,6 +1,6 @@ import copy from dataclasses import dataclass -from typing import Any, Generic, Literal, TypeVar, cast +from typing import Any, Literal, TypeVar, cast import gymnasium as gym import numpy as np @@ -117,9 +117,8 @@ def forward( class BCQ( - OfflineAlgorithm[BCQPolicy, TBCQTrainingStats], + OfflineAlgorithm[BCQPolicy], LaggedNetworkPolyakUpdateAlgorithmMixin, - Generic[TBCQTrainingStats], ): """Implementation of Batch-Constrained Deep Q-learning (BCQ) algorithm. arXiv:1812.02900.""" @@ -189,7 +188,7 @@ def __init__( def _update_with_batch( self, batch: RolloutBatchProtocol, - ) -> TBCQTrainingStats: + ) -> BCQTrainingStats: # batch: obs, act, rew, done, obs_next. 
(numpy array) # (batch_size, state_dim) # TODO: This does not use policy.forward but computes things directly, which seems odd @@ -256,7 +255,7 @@ def _update_with_batch( # update target networks self._update_lagged_network_weights() - return BCQTrainingStats( # type: ignore + return BCQTrainingStats( actor_loss=actor_loss.item(), critic1_loss=critic1_loss.item(), critic2_loss=critic2_loss.item(), diff --git a/tianshou/policy/imitation/cql.py b/tianshou/policy/imitation/cql.py index aba5b8632..7be0e3c32 100644 --- a/tianshou/policy/imitation/cql.py +++ b/tianshou/policy/imitation/cql.py @@ -1,6 +1,6 @@ from copy import deepcopy from dataclasses import dataclass -from typing import TypeVar, cast +from typing import cast import numpy as np import torch @@ -28,11 +28,8 @@ class CQLTrainingStats(SACTrainingStats): cql_alpha_loss: float | None = None -TCQLTrainingStats = TypeVar("TCQLTrainingStats", bound=CQLTrainingStats) - - # TODO: Perhaps SACPolicy should get a more generic name -class CQL(OfflineAlgorithm[SACPolicy, TCQLTrainingStats], LaggedNetworkPolyakUpdateAlgorithmMixin): +class CQL(OfflineAlgorithm[SACPolicy], LaggedNetworkPolyakUpdateAlgorithmMixin): """Implementation of the conservative Q-learning (CQL) algorithm. 
arXiv:2006.04779.""" def __init__( @@ -212,7 +209,7 @@ def process_buffer(self, buffer: TBuffer) -> TBuffer: ) return buffer - def _update_with_batch(self, batch: RolloutBatchProtocol) -> TCQLTrainingStats: # type: ignore + def _update_with_batch(self, batch: RolloutBatchProtocol) -> CQLTrainingStats: device = torch_device(self.policy) batch: Batch = to_torch(batch, dtype=torch.float, device=device) obs, act, rew, obs_next = batch.obs, batch.act, batch.rew, batch.obs_next @@ -336,7 +333,7 @@ def _update_with_batch(self, batch: RolloutBatchProtocol) -> TCQLTrainingStats: self._update_lagged_network_weights() - return CQLTrainingStats( # type: ignore[return-value] + return CQLTrainingStats( actor_loss=to_optional_float(actor_loss), critic1_loss=to_optional_float(critic1_loss), critic2_loss=to_optional_float(critic2_loss), diff --git a/tianshou/policy/imitation/discrete_bcq.py b/tianshou/policy/imitation/discrete_bcq.py index e77ffbbed..57f818aea 100644 --- a/tianshou/policy/imitation/discrete_bcq.py +++ b/tianshou/policy/imitation/discrete_bcq.py @@ -1,6 +1,6 @@ import math from dataclasses import dataclass -from typing import Any, Generic, TypeVar, cast +from typing import Any, cast import gymnasium as gym import numpy as np @@ -19,7 +19,7 @@ OfflineAlgorithm, Policy, ) -from tianshou.policy.modelfree.dqn import DQNTrainingStats +from tianshou.policy.modelfree.pg import SimpleLossTrainingStats from tianshou.policy.optim import OptimizerFactory float_info = torch.finfo(torch.float32) @@ -27,15 +27,12 @@ @dataclass(kw_only=True) -class DiscreteBCQTrainingStats(DQNTrainingStats): +class DiscreteBCQTrainingStats(SimpleLossTrainingStats): q_loss: float i_loss: float reg_loss: float -TDiscreteBCQTrainingStats = TypeVar("TDiscreteBCQTrainingStats", bound=DiscreteBCQTrainingStats) - - class DiscreteBCQPolicy(Policy): def __init__( self, @@ -98,9 +95,8 @@ def forward( # type: ignore class DiscreteBCQ( - OfflineAlgorithm[DiscreteBCQPolicy, TDiscreteBCQTrainingStats], + 
OfflineAlgorithm[DiscreteBCQPolicy], LaggedNetworkFullUpdateAlgorithmMixin, - Generic[TDiscreteBCQTrainingStats], ): """Implementation of the discrete batch-constrained deep Q-learning (BCQ) algorithm. arXiv:1910.01708.""" @@ -205,7 +201,7 @@ def _target_q(self, buffer: ReplayBuffer, indices: np.ndarray) -> torch.Tensor: def _update_with_batch( self, batch: RolloutBatchProtocol, - ) -> TDiscreteBCQTrainingStats: + ) -> DiscreteBCQTrainingStats: if self._iter % self.freq == 0: self._update_lagged_network_weights() self._iter += 1 @@ -222,7 +218,7 @@ def _update_with_batch( self.optim.step(loss) - return DiscreteBCQTrainingStats( # type: ignore[return-value] + return DiscreteBCQTrainingStats( loss=loss.item(), q_loss=q_loss.item(), i_loss=i_loss.item(), diff --git a/tianshou/policy/imitation/discrete_cql.py b/tianshou/policy/imitation/discrete_cql.py index cffaa8e7f..4b5167485 100644 --- a/tianshou/policy/imitation/discrete_cql.py +++ b/tianshou/policy/imitation/discrete_cql.py @@ -1,5 +1,4 @@ from dataclasses import dataclass -from typing import TypeVar import numpy as np import torch @@ -9,24 +8,19 @@ from tianshou.data.types import RolloutBatchProtocol from tianshou.policy import QRDQN from tianshou.policy.base import OfflineAlgorithm -from tianshou.policy.modelfree.qrdqn import QRDQNPolicy, QRDQNTrainingStats +from tianshou.policy.modelfree.pg import SimpleLossTrainingStats +from tianshou.policy.modelfree.qrdqn import QRDQNPolicy from tianshou.policy.optim import OptimizerFactory @dataclass(kw_only=True) -class DiscreteCQLTrainingStats(QRDQNTrainingStats): +class DiscreteCQLTrainingStats(SimpleLossTrainingStats): cql_loss: float qr_loss: float -TDiscreteCQLTrainingStats = TypeVar("TDiscreteCQLTrainingStats", bound=DiscreteCQLTrainingStats) - - # NOTE: This uses diamond inheritance to convert from off-policy to offline -class DiscreteCQL( # type: ignore - OfflineAlgorithm[QRDQNPolicy, TDiscreteCQLTrainingStats], - QRDQN[QRDQNPolicy, TDiscreteCQLTrainingStats], -): 
+class DiscreteCQL(OfflineAlgorithm[QRDQNPolicy], QRDQN[QRDQNPolicy]): # type: ignore[misc] """Implementation of discrete Conservative Q-Learning algorithm. arXiv:2006.04779.""" def __init__( @@ -81,7 +75,7 @@ def __init__( def _update_with_batch( self, batch: RolloutBatchProtocol, - ) -> TDiscreteCQLTrainingStats: + ) -> DiscreteCQLTrainingStats: self._periodically_update_lagged_network_weights() weight = batch.pop("weight", 1.0) all_dist = self.policy(batch).logits @@ -107,7 +101,7 @@ def _update_with_batch( loss = qr_loss + min_q_loss * self.min_q_weight self.optim.step(loss) - return DiscreteCQLTrainingStats( # type: ignore[return-value] + return DiscreteCQLTrainingStats( loss=loss.item(), qr_loss=qr_loss.item(), cql_loss=min_q_loss.item(), diff --git a/tianshou/policy/imitation/discrete_crr.py b/tianshou/policy/imitation/discrete_crr.py index a72ed53d9..821e365c8 100644 --- a/tianshou/policy/imitation/discrete_crr.py +++ b/tianshou/policy/imitation/discrete_crr.py @@ -1,5 +1,5 @@ from dataclasses import dataclass -from typing import Literal, TypeVar +from typing import Literal import numpy as np import torch @@ -16,24 +16,21 @@ from tianshou.policy.modelfree.pg import ( DiscountedReturnComputation, DiscreteActorPolicy, - PGTrainingStats, + SimpleLossTrainingStats, ) from tianshou.policy.optim import OptimizerFactory from tianshou.utils.net.discrete import DiscreteCritic @dataclass -class DiscreteCRRTrainingStats(PGTrainingStats): +class DiscreteCRRTrainingStats(SimpleLossTrainingStats): actor_loss: float critic_loss: float cql_loss: float -TDiscreteCRRTrainingStats = TypeVar("TDiscreteCRRTrainingStats", bound=DiscreteCRRTrainingStats) - - class DiscreteCRR( - OfflineAlgorithm[DiscreteActorPolicy, TDiscreteCRRTrainingStats], + OfflineAlgorithm[DiscreteActorPolicy], LaggedNetworkFullUpdateAlgorithmMixin, ): r"""Implementation of discrete Critic Regularized Regression. 
arXiv:2006.15134.""" @@ -116,7 +113,7 @@ def preprocess_batch( def _update_with_batch( # type: ignore self, batch: RolloutBatchProtocol, - ) -> TDiscreteCRRTrainingStats: + ) -> DiscreteCRRTrainingStats: if self._target and self._iter % self._freq == 0: self._update_lagged_network_weights() q_t = self.critic(batch.obs) @@ -150,8 +147,7 @@ def _update_with_batch( # type: ignore self.optim.step(loss) self._iter += 1 - return DiscreteCRRTrainingStats( # type: ignore[return-value] - # TODO: Type is wrong + return DiscreteCRRTrainingStats( loss=loss.item(), actor_loss=actor_loss.item(), critic_loss=critic_loss.item(), diff --git a/tianshou/policy/imitation/gail.py b/tianshou/policy/imitation/gail.py index 92ffb7f72..2929a99a2 100644 --- a/tianshou/policy/imitation/gail.py +++ b/tianshou/policy/imitation/gail.py @@ -1,5 +1,4 @@ from dataclasses import dataclass -from typing import TypeVar import numpy as np import torch @@ -12,9 +11,9 @@ to_torch, ) from tianshou.data.types import LogpOldProtocol, RolloutBatchProtocol -from tianshou.policy import PPO +from tianshou.policy.modelfree.a2c import A2CTrainingStats from tianshou.policy.modelfree.pg import ActorPolicy -from tianshou.policy.modelfree.ppo import PPOTrainingStats +from tianshou.policy.modelfree.ppo import PPO from tianshou.policy.optim import OptimizerFactory from tianshou.utils.net.common import ModuleWithVectorOutput from tianshou.utils.net.continuous import ContinuousCritic @@ -23,16 +22,13 @@ @dataclass(kw_only=True) -class GailTrainingStats(PPOTrainingStats): +class GailTrainingStats(A2CTrainingStats): disc_loss: SequenceSummaryStats acc_pi: SequenceSummaryStats acc_exp: SequenceSummaryStats -TGailTrainingStats = TypeVar("TGailTrainingStats", bound=GailTrainingStats) - - -class GAIL(PPO[TGailTrainingStats]): +class GAIL(PPO): r"""Implementation of Generative Adversarial Imitation Learning. 
arXiv:1606.03476.""" def __init__( @@ -180,12 +176,12 @@ def disc(self, batch: RolloutBatchProtocol) -> torch.Tensor: act = to_torch(batch.act, device=device) return self.disc_net(torch.cat([obs, act], dim=1)) - def _update_with_batch( # type: ignore + def _update_with_batch( self, batch: RolloutBatchProtocol, batch_size: int | None, repeat: int, - ) -> TGailTrainingStats: + ) -> GailTrainingStats: # update discriminator losses = [] acc_pis = [] @@ -209,7 +205,7 @@ def _update_with_batch( # type: ignore acc_pi_summary = SequenceSummaryStats.from_sequence(acc_pis) acc_exps_summary = SequenceSummaryStats.from_sequence(acc_exps) - return GailTrainingStats( # type: ignore[return-value] + return GailTrainingStats( **ppo_loss_stat.__dict__, disc_loss=disc_losses_summary, acc_pi=acc_pi_summary, diff --git a/tianshou/policy/imitation/td3_bc.py b/tianshou/policy/imitation/td3_bc.py index 048a4b409..68efd1f75 100644 --- a/tianshou/policy/imitation/td3_bc.py +++ b/tianshou/policy/imitation/td3_bc.py @@ -1,6 +1,3 @@ -from dataclasses import dataclass -from typing import TypeVar - import torch import torch.nn.functional as F @@ -13,16 +10,8 @@ from tianshou.policy.optim import OptimizerFactory -@dataclass(kw_only=True) -class TD3BCTrainingStats(TD3TrainingStats): - pass - - -TTD3BCTrainingStats = TypeVar("TTD3BCTrainingStats", bound=TD3BCTrainingStats) - - # NOTE: This uses diamond inheritance to convert from off-policy to offline -class TD3BC(OfflineAlgorithm[DDPGPolicy, TTD3BCTrainingStats], TD3[TTD3BCTrainingStats]): # type: ignore +class TD3BC(OfflineAlgorithm[DDPGPolicy], TD3): # type: ignore """Implementation of TD3+BC. 
arXiv:2106.06860.""" def __init__( @@ -93,7 +82,7 @@ def __init__( ) self.alpha = alpha - def _update_with_batch(self, batch: RolloutBatchProtocol) -> TTD3BCTrainingStats: # type: ignore + def _update_with_batch(self, batch: RolloutBatchProtocol) -> TD3TrainingStats: # critic 1&2 td1, critic1_loss = self._minimize_critic_squared_loss( batch, self.critic, self.critic_optim @@ -114,7 +103,7 @@ def _update_with_batch(self, batch: RolloutBatchProtocol) -> TTD3BCTrainingStats self._update_lagged_network_weights() self._cnt += 1 - return TD3BCTrainingStats( # type: ignore[return-value] + return TD3TrainingStats( actor_loss=self._last, critic1_loss=critic1_loss.item(), critic2_loss=critic2_loss.item(), diff --git a/tianshou/policy/modelbased/icm.py b/tianshou/policy/modelbased/icm.py index 99d8e6a4b..b77666056 100644 --- a/tianshou/policy/modelbased/icm.py +++ b/tianshou/policy/modelbased/icm.py @@ -14,7 +14,6 @@ TPolicy, TrainingStats, TrainingStatsWrapper, - TTrainingStats, ) from tianshou.policy.optim import OptimizerFactory from tianshou.utils.net.discrete import IntrinsicCuriosityModule @@ -94,15 +93,13 @@ def _icm_update( ) -class ICMOffPolicyWrapper( - OffPolicyWrapperAlgorithm[TPolicy, ICMTrainingStats, TTrainingStats], _ICMMixin -): +class ICMOffPolicyWrapper(OffPolicyWrapperAlgorithm[TPolicy], _ICMMixin): """Implementation of the Intrinsic Curiosity Module (ICM) algorithm for off-policy learning. 
arXiv:1705.05363.""" def __init__( self, *, - wrapped_algorithm: OffPolicyAlgorithm[TPolicy, TTrainingStats], + wrapped_algorithm: OffPolicyAlgorithm[TPolicy], model: IntrinsicCuriosityModule, optim: OptimizerFactory, lr_scale: float, @@ -150,20 +147,18 @@ def postprocess_batch( def _wrapper_update_with_batch( self, batch: RolloutBatchProtocol, - original_stats: TTrainingStats, + original_stats: TrainingStats, ) -> ICMTrainingStats: return self._icm_update(batch, original_stats) -class ICMOnPolicyWrapper( - OnPolicyWrapperAlgorithm[TPolicy, ICMTrainingStats, TTrainingStats], _ICMMixin -): +class ICMOnPolicyWrapper(OnPolicyWrapperAlgorithm[TPolicy], _ICMMixin): """Implementation of the Intrinsic Curiosity Module (ICM) algorithm for on-policy learning. arXiv:1705.05363.""" def __init__( self, *, - wrapped_algorithm: OnPolicyAlgorithm[TPolicy, TTrainingStats], + wrapped_algorithm: OnPolicyAlgorithm[TPolicy], model: IntrinsicCuriosityModule, optim: OptimizerFactory, lr_scale: float, @@ -213,6 +208,6 @@ def _wrapper_update_with_batch( batch: RolloutBatchProtocol, batch_size: int | None, repeat: int, - original_stats: TTrainingStats, + original_stats: TrainingStats, ) -> ICMTrainingStats: return self._icm_update(batch, original_stats) diff --git a/tianshou/policy/modelbased/psrl.py b/tianshou/policy/modelbased/psrl.py index ada87c8c6..9d0124b0e 100644 --- a/tianshou/policy/modelbased/psrl.py +++ b/tianshou/policy/modelbased/psrl.py @@ -1,5 +1,5 @@ from dataclasses import dataclass -from typing import Any, TypeVar, cast +from typing import Any, cast import gymnasium as gym import numpy as np @@ -21,9 +21,6 @@ class PSRLTrainingStats(TrainingStats): psrl_rew_std: float = 0.0 -TPSRLTrainingStats = TypeVar("TPSRLTrainingStats", bound=PSRLTrainingStats) - - class PSRLModel: """Implementation of Posterior Sampling Reinforcement Learning Model.""" @@ -222,7 +219,7 @@ def forward( return cast(ActBatchProtocol, Batch(act=act)) -class PSRL(OnPolicyAlgorithm[PSRLPolicy, 
TPSRLTrainingStats]): +class PSRL(OnPolicyAlgorithm[PSRLPolicy]): """Implementation of Posterior Sampling Reinforcement Learning (PSRL). Reference: Strens M., A Bayesian Framework for Reinforcement Learning, ICML, 2000. @@ -247,7 +244,7 @@ def __init__( def _update_with_batch( self, batch: RolloutBatchProtocol, batch_size: int | None, repeat: int - ) -> TPSRLTrainingStats: + ) -> PSRLTrainingStats: # NOTE: In contrast to other on-policy algorithms, this algorithm ignores # the batch_size and repeat arguments. # PSRL, being a Bayesian approach, updates its posterior distribution of @@ -272,7 +269,7 @@ def _update_with_batch( rew_count[obs_next, :] += 1 self.policy.model.observe(trans_count, rew_sum, rew_square_sum, rew_count) - return PSRLTrainingStats( # type: ignore[return-value] + return PSRLTrainingStats( psrl_rew_mean=float(self.policy.model.rew_mean.mean()), psrl_rew_std=float(self.policy.model.rew_std.mean()), ) diff --git a/tianshou/policy/modelfree/a2c.py b/tianshou/policy/modelfree/a2c.py index 6ee5a6994..cfade460b 100644 --- a/tianshou/policy/modelfree/a2c.py +++ b/tianshou/policy/modelfree/a2c.py @@ -1,6 +1,6 @@ from abc import ABC from dataclasses import dataclass -from typing import Generic, TypeVar, cast +from typing import cast import numpy as np import torch @@ -11,7 +11,6 @@ from tianshou.policy.base import ( OnPolicyAlgorithm, TrainingStats, - TTrainingStats, ) from tianshou.policy.modelfree.pg import ActorPolicy from tianshou.policy.optim import OptimizerFactory @@ -27,14 +26,10 @@ class A2CTrainingStats(TrainingStats): actor_loss: SequenceSummaryStats vf_loss: SequenceSummaryStats ent_loss: SequenceSummaryStats + gradient_steps: int -TA2CTrainingStats = TypeVar("TA2CTrainingStats", bound=A2CTrainingStats) - - -class ActorCriticOnPolicyAlgorithm( - OnPolicyAlgorithm[ActorPolicy, TTrainingStats], Generic[TTrainingStats], ABC -): +class ActorCriticOnPolicyAlgorithm(OnPolicyAlgorithm[ActorPolicy], ABC): """Abstract base class for actor-critic 
algorithms that use generalized advantage estimation (GAE).""" def __init__( @@ -142,7 +137,7 @@ def _add_returns_and_advantages( return cast(BatchWithAdvantagesProtocol, batch) -class A2C(ActorCriticOnPolicyAlgorithm[TA2CTrainingStats], Generic[TA2CTrainingStats]): +class A2C(ActorCriticOnPolicyAlgorithm): """Implementation of (synchronous) Advantage Actor-Critic (A2C). arXiv:1602.01783.""" def __init__( @@ -218,11 +213,14 @@ def _update_with_batch( batch: RolloutBatchProtocol, batch_size: int | None, repeat: int, - ) -> TA2CTrainingStats: + ) -> A2CTrainingStats: losses, actor_losses, vf_losses, ent_losses = [], [], [], [] split_batch_size = batch_size or -1 + gradient_steps = 0 for _ in range(repeat): for minibatch in batch.split(split_batch_size, merge_last=True): + gradient_steps = 0 + # calculate loss for actor dist = self.policy(minibatch).dist log_prob = dist.log_prob(minibatch.act) @@ -245,9 +243,10 @@ def _update_with_batch( vf_loss_summary_stat = SequenceSummaryStats.from_sequence(vf_losses) ent_loss_summary_stat = SequenceSummaryStats.from_sequence(ent_losses) - return A2CTrainingStats( # type: ignore[return-value] + return A2CTrainingStats( loss=loss_summary_stat, actor_loss=actor_loss_summary_stat, vf_loss=vf_loss_summary_stat, ent_loss=ent_loss_summary_stat, + gradient_steps=gradient_steps, ) diff --git a/tianshou/policy/modelfree/bdqn.py b/tianshou/policy/modelfree/bdqn.py index 9eafa8840..299c6799e 100644 --- a/tianshou/policy/modelfree/bdqn.py +++ b/tianshou/policy/modelfree/bdqn.py @@ -1,5 +1,4 @@ -from dataclasses import dataclass -from typing import Any, TypeVar, cast +from typing import Any, cast import gymnasium as gym import numpy as np @@ -18,23 +17,15 @@ from tianshou.policy.base import TArrOrActBatch from tianshou.policy.modelfree.dqn import ( DQNPolicy, - DQNTrainingStats, QLearningOffPolicyAlgorithm, ) +from tianshou.policy.modelfree.pg import SimpleLossTrainingStats from tianshou.policy.optim import OptimizerFactory from 
tianshou.utils.net.common import BranchingNet mark_used(ActBatchProtocol) -@dataclass(kw_only=True) -class BDQNTrainingStats(DQNTrainingStats): - pass - - -TBDQNTrainingStats = TypeVar("TBDQNTrainingStats", bound=BDQNTrainingStats) - - class BDQNPolicy(DQNPolicy[BranchingNet]): def __init__( self, @@ -107,7 +98,7 @@ def add_exploration_noise( return act -class BDQN(QLearningOffPolicyAlgorithm[BDQNPolicy, TBDQNTrainingStats]): +class BDQN(QLearningOffPolicyAlgorithm[BDQNPolicy]): """Implementation of the Branching Dueling Q-Network (BDQN) algorithm arXiv:1711.08946.""" def __init__( @@ -211,7 +202,7 @@ def preprocess_batch( def _update_with_batch( self, batch: RolloutBatchProtocol, - ) -> TBDQNTrainingStats: + ) -> SimpleLossTrainingStats: self._periodically_update_lagged_network_weights() weight = batch.pop("weight", 1.0) act = to_torch(batch.act, dtype=torch.long, device=batch.returns.device) @@ -226,4 +217,4 @@ def _update_with_batch( batch.weight = td_error.sum(-1).sum(-1) # prio-buffer self.optim.step(loss) - return BDQNTrainingStats(loss=loss.item()) # type: ignore[return-value] + return SimpleLossTrainingStats(loss=loss.item()) diff --git a/tianshou/policy/modelfree/c51.py b/tianshou/policy/modelfree/c51.py index 35c68a3f5..448fa9b3c 100644 --- a/tianshou/policy/modelfree/c51.py +++ b/tianshou/policy/modelfree/c51.py @@ -1,6 +1,3 @@ -from dataclasses import dataclass -from typing import Generic, TypeVar - import gymnasium as gym import numpy as np import torch @@ -9,21 +6,13 @@ from tianshou.data.types import RolloutBatchProtocol from tianshou.policy.modelfree.dqn import ( DQNPolicy, - DQNTrainingStats, QLearningOffPolicyAlgorithm, ) +from tianshou.policy.modelfree.pg import LossSequenceTrainingStats from tianshou.policy.optim import OptimizerFactory from tianshou.utils.net.common import Net -@dataclass(kw_only=True) -class C51TrainingStats(DQNTrainingStats): - pass - - -TC51TrainingStats = TypeVar("TC51TrainingStats", bound=C51TrainingStats) - - class 
C51Policy(DQNPolicy): def __init__( self, @@ -78,7 +67,7 @@ def compute_q_value(self, logits: torch.Tensor, mask: np.ndarray | None) -> torc return super().compute_q_value((logits * self.support).sum(2), mask) -class C51(QLearningOffPolicyAlgorithm[C51Policy, TC51TrainingStats], Generic[TC51TrainingStats]): +class C51(QLearningOffPolicyAlgorithm[C51Policy]): """Implementation of Categorical Deep Q-Network. arXiv:1707.06887.""" def __init__( @@ -149,7 +138,7 @@ def _target_dist(self, batch: RolloutBatchProtocol) -> torch.Tensor: def _update_with_batch( self, batch: RolloutBatchProtocol, - ) -> TC51TrainingStats: + ) -> LossSequenceTrainingStats: self._periodically_update_lagged_network_weights() with torch.no_grad(): target_dist = self._target_dist(batch) @@ -163,4 +152,4 @@ def _update_with_batch( batch.weight = cross_entropy.detach() # prio-buffer self.optim.step(loss) - return C51TrainingStats(loss=loss.item()) # type: ignore[return-value] + return LossSequenceTrainingStats(loss=loss.item()) diff --git a/tianshou/policy/modelfree/ddpg.py b/tianshou/policy/modelfree/ddpg.py index d0fba91aa..d85c50f72 100644 --- a/tianshou/policy/modelfree/ddpg.py +++ b/tianshou/policy/modelfree/ddpg.py @@ -26,7 +26,6 @@ TArrOrActBatch, TPolicy, TrainingStats, - TTrainingStats, ) from tianshou.policy.optim import OptimizerFactory from tianshou.utils.net.continuous import ContinuousActorDeterministic, ContinuousCritic @@ -40,9 +39,6 @@ class DDPGTrainingStats(TrainingStats): critic_loss: float -TDDPGTrainingStats = TypeVar("TDDPGTrainingStats", bound=DDPGTrainingStats) - - class ContinuousPolicyWithExplorationNoise(Policy, ABC): def __init__( self, @@ -189,9 +185,9 @@ def forward( class ActorCriticOffPolicyAlgorithm( - OffPolicyAlgorithm[TPolicy, TTrainingStats], + OffPolicyAlgorithm[TPolicy], LaggedNetworkPolyakUpdateAlgorithmMixin, - Generic[TPolicy, TTrainingStats, TActBatchProtocol], + Generic[TPolicy, TActBatchProtocol], ABC, ): """Base class for actor-critic off-policy 
algorithms that use a lagged critic @@ -335,8 +331,7 @@ def _target_q(self, buffer: ReplayBuffer, indices: np.ndarray) -> torch.Tensor: class DDPG( - ActorCriticOffPolicyAlgorithm[DDPGPolicy, TDDPGTrainingStats, ActBatchProtocol], - Generic[TDDPGTrainingStats], + ActorCriticOffPolicyAlgorithm[DDPGPolicy, ActBatchProtocol], ): """Implementation of Deep Deterministic Policy Gradient. arXiv:1509.02971.""" @@ -394,7 +389,7 @@ def _target_q_compute_action(self, obs_batch: Batch) -> ActBatchProtocol: # compute the action using the lagged actor network return self.policy(obs_batch, model=self.actor_old) - def _update_with_batch(self, batch: RolloutBatchProtocol) -> TDDPGTrainingStats: # type: ignore + def _update_with_batch(self, batch: RolloutBatchProtocol) -> DDPGTrainingStats: # critic td, critic_loss = self._minimize_critic_squared_loss(batch, self.critic, self.critic_optim) batch.weight = td # prio-buffer @@ -403,4 +398,4 @@ def _update_with_batch(self, batch: RolloutBatchProtocol) -> TDDPGTrainingStats: self.policy_optim.step(actor_loss) self._update_lagged_network_weights() - return DDPGTrainingStats(actor_loss=actor_loss.item(), critic_loss=critic_loss.item()) # type: ignore[return-value] + return DDPGTrainingStats(actor_loss=actor_loss.item(), critic_loss=critic_loss.item()) diff --git a/tianshou/policy/modelfree/discrete_sac.py b/tianshou/policy/modelfree/discrete_sac.py index 13bdf937a..e7e217e0c 100644 --- a/tianshou/policy/modelfree/discrete_sac.py +++ b/tianshou/policy/modelfree/discrete_sac.py @@ -81,11 +81,7 @@ def forward( return Batch(logits=logits_BA, act=act_B, state=hidden_BH, dist=dist) -class DiscreteSAC( - ActorDualCriticsOffPolicyAlgorithm[ - DiscreteSACPolicy, TDiscreteSACTrainingStats, DistBatchProtocol - ] -): +class DiscreteSAC(ActorDualCriticsOffPolicyAlgorithm[DiscreteSACPolicy, DistBatchProtocol]): """Implementation of SAC for Discrete Action Settings. 
arXiv:1910.07207.""" def __init__( diff --git a/tianshou/policy/modelfree/dqn.py b/tianshou/policy/modelfree/dqn.py index 0805e6834..c878a3afb 100644 --- a/tianshou/policy/modelfree/dqn.py +++ b/tianshou/policy/modelfree/dqn.py @@ -1,5 +1,4 @@ from abc import ABC, abstractmethod -from dataclasses import dataclass from typing import Any, Generic, TypeVar, cast import gymnasium as gym @@ -22,21 +21,15 @@ OffPolicyAlgorithm, Policy, TArrOrActBatch, - TrainingStats, - TTrainingStats, +) +from tianshou.policy.modelfree.pg import ( + SimpleLossTrainingStats, ) from tianshou.policy.optim import OptimizerFactory from tianshou.utils.net.common import Net mark_used(ActBatchProtocol) - -@dataclass(kw_only=True) -class DQNTrainingStats(TrainingStats): - loss: float - - -TDQNTrainingStats = TypeVar("TDQNTrainingStats", bound=DQNTrainingStats) TModel = TypeVar("TModel", bound=torch.nn.Module | Net) @@ -182,7 +175,7 @@ def add_exploration_noise( class QLearningOffPolicyAlgorithm( - OffPolicyAlgorithm[TDQNPolicy, TTrainingStats], LaggedNetworkFullUpdateAlgorithmMixin, ABC + OffPolicyAlgorithm[TDQNPolicy], LaggedNetworkFullUpdateAlgorithmMixin, ABC ): """ Base class for Q-learning off-policy algorithms that use a Q-function to compute the @@ -286,8 +279,8 @@ def _periodically_update_lagged_network_weights(self) -> None: class DQN( - QLearningOffPolicyAlgorithm[TDQNPolicy, TDQNTrainingStats], - Generic[TDQNPolicy, TDQNTrainingStats], + QLearningOffPolicyAlgorithm[TDQNPolicy], + Generic[TDQNPolicy], ): """Implementation of Deep Q Network. arXiv:1312.5602. 
@@ -365,7 +358,7 @@ def _target_q(self, buffer: ReplayBuffer, indices: np.ndarray) -> torch.Tensor: def _update_with_batch( self, batch: RolloutBatchProtocol, - ) -> TDQNTrainingStats: + ) -> SimpleLossTrainingStats: self._periodically_update_lagged_network_weights() weight = batch.pop("weight", 1.0) q = self.policy(batch).logits @@ -383,4 +376,4 @@ def _update_with_batch( batch.weight = td_error # prio-buffer self.optim.step(loss) - return DQNTrainingStats(loss=loss.item()) # type: ignore[return-value] + return SimpleLossTrainingStats(loss=loss.item()) diff --git a/tianshou/policy/modelfree/fqf.py b/tianshou/policy/modelfree/fqf.py index b03ee3ae2..c2a347561 100644 --- a/tianshou/policy/modelfree/fqf.py +++ b/tianshou/policy/modelfree/fqf.py @@ -1,5 +1,5 @@ from dataclasses import dataclass -from typing import Any, TypeVar, cast +from typing import Any, cast import gymnasium as gym import numpy as np @@ -11,21 +11,19 @@ from tianshou.data.types import FQFBatchProtocol, ObsBatchProtocol, RolloutBatchProtocol from tianshou.policy import QRDQN, Algorithm from tianshou.policy.modelfree.dqn import DQNPolicy -from tianshou.policy.modelfree.qrdqn import QRDQNPolicy, QRDQNTrainingStats +from tianshou.policy.modelfree.pg import SimpleLossTrainingStats +from tianshou.policy.modelfree.qrdqn import QRDQNPolicy from tianshou.policy.optim import OptimizerFactory from tianshou.utils.net.discrete import FractionProposalNetwork, FullQuantileFunction @dataclass(kw_only=True) -class FQFTrainingStats(QRDQNTrainingStats): +class FQFTrainingStats(SimpleLossTrainingStats): quantile_loss: float fraction_loss: float entropy_loss: float -TFQFTrainingStats = TypeVar("TFQFTrainingStats", bound=FQFTrainingStats) - - class FQFPolicy(QRDQNPolicy): def __init__( self, @@ -109,7 +107,7 @@ def forward( # type: ignore return cast(FQFBatchProtocol, result) -class FQF(QRDQN[FQFPolicy, TFQFTrainingStats]): +class FQF(QRDQN[FQFPolicy]): """Implementation of Fully Parameterized Quantile Function for 
Distributional Reinforcement Learning. arXiv:1911.02140.""" def __init__( @@ -191,7 +189,7 @@ def _target_q(self, buffer: ReplayBuffer, indices: np.ndarray) -> torch.Tensor: def _update_with_batch( self, batch: RolloutBatchProtocol, - ) -> TFQFTrainingStats: + ) -> FQFTrainingStats: self._periodically_update_lagged_network_weights() weight = batch.pop("weight", 1.0) out = self.policy(batch) @@ -244,7 +242,7 @@ def _update_with_batch( self.fraction_optim.step(fraction_entropy_loss, retain_graph=True) self.optim.step(quantile_loss) - return FQFTrainingStats( # type: ignore[return-value] + return FQFTrainingStats( loss=quantile_loss.item() + fraction_entropy_loss.item(), quantile_loss=quantile_loss.item(), fraction_loss=fraction_loss.item(), diff --git a/tianshou/policy/modelfree/iqn.py b/tianshou/policy/modelfree/iqn.py index 54f60c2fc..dad67d621 100644 --- a/tianshou/policy/modelfree/iqn.py +++ b/tianshou/policy/modelfree/iqn.py @@ -1,5 +1,4 @@ -from dataclasses import dataclass -from typing import Any, TypeVar, cast +from typing import Any, cast import gymnasium as gym import numpy as np @@ -14,18 +13,11 @@ RolloutBatchProtocol, ) from tianshou.policy import QRDQN -from tianshou.policy.modelfree.qrdqn import QRDQNPolicy, QRDQNTrainingStats +from tianshou.policy.modelfree.pg import SimpleLossTrainingStats +from tianshou.policy.modelfree.qrdqn import QRDQNPolicy from tianshou.policy.optim import OptimizerFactory -@dataclass(kw_only=True) -class IQNTrainingStats(QRDQNTrainingStats): - pass - - -TIQNTrainingStats = TypeVar("TIQNTrainingStats", bound=IQNTrainingStats) - - class IQNPolicy(QRDQNPolicy): def __init__( self, @@ -111,7 +103,7 @@ def forward( return cast(QuantileRegressionBatchProtocol, result) -class IQN(QRDQN[IQNPolicy, TIQNTrainingStats]): +class IQN(QRDQN[IQNPolicy]): """Implementation of Implicit Quantile Network. 
arXiv:1806.06923.""" def __init__( @@ -162,7 +154,7 @@ def __init__( def _update_with_batch( self, batch: RolloutBatchProtocol, - ) -> TIQNTrainingStats: + ) -> SimpleLossTrainingStats: self._periodically_update_lagged_network_weights() weight = batch.pop("weight", 1.0) action_batch = self.policy(batch) @@ -186,4 +178,4 @@ def _update_with_batch( batch.weight = dist_diff.detach().abs().sum(-1).mean(1) # prio-buffer self.optim.step(loss) - return IQNTrainingStats(loss=loss.item()) # type: ignore[return-value] + return SimpleLossTrainingStats(loss=loss.item()) diff --git a/tianshou/policy/modelfree/npg.py b/tianshou/policy/modelfree/npg.py index 75d85e44b..dbb15ec87 100644 --- a/tianshou/policy/modelfree/npg.py +++ b/tianshou/policy/modelfree/npg.py @@ -1,5 +1,5 @@ from dataclasses import dataclass -from typing import Any, Generic, TypeVar +from typing import Any import numpy as np import torch @@ -7,7 +7,7 @@ from torch import nn from torch.distributions import kl_divergence -from tianshou.data import Batch, ReplayBuffer, SequenceSummaryStats, to_torch_as +from tianshou.data import ReplayBuffer, SequenceSummaryStats, to_torch_as from tianshou.data.types import BatchWithAdvantagesProtocol, RolloutBatchProtocol from tianshou.policy.base import TrainingStats from tianshou.policy.modelfree.a2c import ActorCriticOnPolicyAlgorithm @@ -24,10 +24,7 @@ class NPGTrainingStats(TrainingStats): kl: SequenceSummaryStats -TNPGTrainingStats = TypeVar("TNPGTrainingStats", bound=NPGTrainingStats) - - -class NPG(ActorCriticOnPolicyAlgorithm[TNPGTrainingStats], Generic[TNPGTrainingStats]): +class NPG(ActorCriticOnPolicyAlgorithm): """Implementation of Natural Policy Gradient. 
https://proceedings.neurips.cc/paper/2001/file/4b86abe48d358ecf194c56c69108433e-Paper.pdf @@ -128,12 +125,12 @@ def preprocess_batch( batch.adv = (batch.adv - batch.adv.mean()) / batch.adv.std() return batch - def _update_with_batch( # type: ignore + def _update_with_batch( self, - batch: Batch, + batch: RolloutBatchProtocol, batch_size: int | None, repeat: int, - ) -> TNPGTrainingStats: + ) -> NPGTrainingStats: actor_losses, vf_losses, kls = [], [], [] split_batch_size = batch_size or -1 for _ in range(repeat): @@ -177,14 +174,10 @@ def _update_with_batch( # type: ignore vf_losses.append(vf_loss.item()) kls.append(kl.item()) - actor_loss_summary_stat = SequenceSummaryStats.from_sequence(actor_losses) - vf_loss_summary_stat = SequenceSummaryStats.from_sequence(vf_losses) - kl_summary_stat = SequenceSummaryStats.from_sequence(kls) - - return NPGTrainingStats( # type: ignore[return-value] - actor_loss=actor_loss_summary_stat, - vf_loss=vf_loss_summary_stat, - kl=kl_summary_stat, + return NPGTrainingStats( + actor_loss=SequenceSummaryStats.from_sequence(actor_losses), + vf_loss=SequenceSummaryStats.from_sequence(vf_losses), + kl=SequenceSummaryStats.from_sequence(kls), ) def _MVP(self, v: torch.Tensor, flat_kl_grad: torch.Tensor) -> torch.Tensor: diff --git a/tianshou/policy/modelfree/pg.py b/tianshou/policy/modelfree/pg.py index 9b95205ad..fe0c2557d 100644 --- a/tianshou/policy/modelfree/pg.py +++ b/tianshou/policy/modelfree/pg.py @@ -2,7 +2,7 @@ import warnings from collections.abc import Callable from dataclasses import dataclass -from typing import Any, Generic, Literal, TypeVar, cast +from typing import Any, Literal, TypeVar, cast import gymnasium as gym import numpy as np @@ -52,11 +52,13 @@ @dataclass(kw_only=True) -class PGTrainingStats(TrainingStats): +class LossSequenceTrainingStats(TrainingStats): loss: SequenceSummaryStats -TPGTrainingStats = TypeVar("TPGTrainingStats", bound=PGTrainingStats) +@dataclass(kw_only=True) +class 
SimpleLossTrainingStats(TrainingStats): + loss: float class ActorPolicy(Policy): @@ -307,7 +309,7 @@ def add_discounted_returns( return batch -class Reinforce(OnPolicyAlgorithm[ActorPolicy, TPGTrainingStats], Generic[TPGTrainingStats]): +class Reinforce(OnPolicyAlgorithm[ActorPolicy]): """Implementation of the REINFORCE (a.k.a. vanilla policy gradient) algorithm.""" def __init__( @@ -353,13 +355,14 @@ def preprocess_batch( indices, ) - # TODO: why does mypy complain? - def _update_with_batch( # type: ignore + # Needs BatchWithReturnsProtocol, which violates the substitution principle. But not a problem since it's a private method and + # the remainder of the class was adjusted to provide the correct batch + def _update_with_batch( # type: ignore[override] self, batch: BatchWithReturnsProtocol, batch_size: int | None, repeat: int, - ) -> TPGTrainingStats: + ) -> LossSequenceTrainingStats: losses = [] split_batch_size = batch_size or -1 for _ in range(repeat): @@ -373,5 +376,4 @@ def _update_with_batch( # type: ignore self.optim.step(loss) losses.append(loss.item()) - loss_summary_stat = SequenceSummaryStats.from_sequence(losses) - return PGTrainingStats(loss=loss_summary_stat) # type: ignore[return-value] + return LossSequenceTrainingStats(loss=SequenceSummaryStats.from_sequence(losses)) diff --git a/tianshou/policy/modelfree/ppo.py b/tianshou/policy/modelfree/ppo.py index d3ad840d0..5f869cddd 100644 --- a/tianshou/policy/modelfree/ppo.py +++ b/tianshou/policy/modelfree/ppo.py @@ -1,6 +1,4 @@ -from collections.abc import Sequence -from dataclasses import dataclass -from typing import Generic, Self, TypeVar, cast +from typing import cast import numpy as np import torch @@ -8,45 +6,14 @@ from tianshou.data import ReplayBuffer, SequenceSummaryStats, to_torch_as from tianshou.data.types import LogpOldProtocol, RolloutBatchProtocol from tianshou.policy import A2C -from tianshou.policy.base import TrainingStats +from tianshou.policy.modelfree.a2c import A2CTrainingStats 
from tianshou.policy.modelfree.pg import ActorPolicy from tianshou.policy.optim import OptimizerFactory from tianshou.utils.net.continuous import ContinuousCritic from tianshou.utils.net.discrete import DiscreteCritic -@dataclass(kw_only=True) -class PPOTrainingStats(TrainingStats): - loss: SequenceSummaryStats - clip_loss: SequenceSummaryStats - vf_loss: SequenceSummaryStats - ent_loss: SequenceSummaryStats - gradient_steps: int = 0 - - @classmethod - def from_sequences( - cls, - *, - losses: Sequence[float], - clip_losses: Sequence[float], - vf_losses: Sequence[float], - ent_losses: Sequence[float], - gradient_steps: int = 0, - ) -> Self: - return cls( - loss=SequenceSummaryStats.from_sequence(losses), - clip_loss=SequenceSummaryStats.from_sequence(clip_losses), - vf_loss=SequenceSummaryStats.from_sequence(vf_losses), - ent_loss=SequenceSummaryStats.from_sequence(ent_losses), - gradient_steps=gradient_steps, - ) - - -TPPOTrainingStats = TypeVar("TPPOTrainingStats", bound=PPOTrainingStats) - - -# TODO: the type ignore here is needed b/c the hierarchy is actually broken! Should reconsider the inheritance structure. -class PPO(A2C[TPPOTrainingStats], Generic[TPPOTrainingStats]): # type: ignore[type-var] +class PPO(A2C): r"""Implementation of Proximal Policy Optimization. arXiv:1707.06347. .. 
seealso:: @@ -183,7 +150,7 @@ def _update_with_batch( batch: RolloutBatchProtocol, batch_size: int | None, repeat: int, - ) -> TPPOTrainingStats: + ) -> A2CTrainingStats: losses, clip_losses, vf_losses, ent_losses = [], [], [], [] gradient_steps = 0 split_batch_size = batch_size or -1 @@ -229,10 +196,10 @@ def _update_with_batch( ent_losses.append(ent_loss.item()) losses.append(loss.item()) - return PPOTrainingStats.from_sequences( # type: ignore[return-value] - losses=losses, - clip_losses=clip_losses, - vf_losses=vf_losses, - ent_losses=ent_losses, + return A2CTrainingStats( + loss=SequenceSummaryStats.from_sequence(losses), + actor_loss=SequenceSummaryStats.from_sequence(clip_losses), + vf_loss=SequenceSummaryStats.from_sequence(vf_losses), + ent_loss=SequenceSummaryStats.from_sequence(ent_losses), gradient_steps=gradient_steps, ) diff --git a/tianshou/policy/modelfree/qrdqn.py b/tianshou/policy/modelfree/qrdqn.py index 8790c4523..6f7a67a89 100644 --- a/tianshou/policy/modelfree/qrdqn.py +++ b/tianshou/policy/modelfree/qrdqn.py @@ -1,5 +1,4 @@ import warnings -from dataclasses import dataclass from typing import Generic, TypeVar import numpy as np @@ -10,20 +9,12 @@ from tianshou.data.types import RolloutBatchProtocol from tianshou.policy.modelfree.dqn import ( DQNPolicy, - DQNTrainingStats, QLearningOffPolicyAlgorithm, ) +from tianshou.policy.modelfree.pg import SimpleLossTrainingStats from tianshou.policy.optim import OptimizerFactory -@dataclass(kw_only=True) -class QRDQNTrainingStats(DQNTrainingStats): - pass - - -TQRDQNTrainingStats = TypeVar("TQRDQNTrainingStats", bound=QRDQNTrainingStats) - - class QRDQNPolicy(DQNPolicy): def compute_q_value(self, logits: torch.Tensor, mask: np.ndarray | None) -> torch.Tensor: return super().compute_q_value(logits.mean(2), mask) @@ -33,8 +24,8 @@ def compute_q_value(self, logits: torch.Tensor, mask: np.ndarray | None) -> torc class QRDQN( - QLearningOffPolicyAlgorithm[TQRDQNPolicy, TQRDQNTrainingStats], - 
Generic[TQRDQNPolicy, TQRDQNTrainingStats], + QLearningOffPolicyAlgorithm[TQRDQNPolicy], + Generic[TQRDQNPolicy], ): """Implementation of Quantile Regression Deep Q-Network. arXiv:1710.10044.""" @@ -107,7 +98,7 @@ def _target_q(self, buffer: ReplayBuffer, indices: np.ndarray) -> torch.Tensor: def _update_with_batch( self, batch: RolloutBatchProtocol, - ) -> TQRDQNTrainingStats: + ) -> SimpleLossTrainingStats: self._periodically_update_lagged_network_weights() weight = batch.pop("weight", 1.0) curr_dist = self.policy(batch).logits @@ -127,4 +118,4 @@ def _update_with_batch( batch.weight = dist_diff.detach().abs().sum(-1).mean(1) # prio-buffer self.optim.step(loss) - return QRDQNTrainingStats(loss=loss.item()) # type: ignore[return-value] + return SimpleLossTrainingStats(loss=loss.item()) diff --git a/tianshou/policy/modelfree/rainbow.py b/tianshou/policy/modelfree/rainbow.py index c792225ae..88d572d47 100644 --- a/tianshou/policy/modelfree/rainbow.py +++ b/tianshou/policy/modelfree/rainbow.py @@ -1,23 +1,19 @@ from dataclasses import dataclass -from typing import TypeVar from torch import nn from tianshou.data.types import RolloutBatchProtocol -from tianshou.policy import C51 -from tianshou.policy.modelfree.c51 import C51TrainingStats +from tianshou.policy.modelfree.c51 import C51 +from tianshou.policy.modelfree.pg import LossSequenceTrainingStats from tianshou.utils.net.discrete import NoisyLinear @dataclass(kw_only=True) -class RainbowTrainingStats(C51TrainingStats): +class RainbowTrainingStats: loss: float -TRainbowTrainingStats = TypeVar("TRainbowTrainingStats", bound=RainbowTrainingStats) - - -class RainbowDQN(C51[TRainbowTrainingStats]): +class RainbowDQN(C51): """Implementation of Rainbow DQN. arXiv:1710.02298. .. 
seealso:: @@ -46,9 +42,10 @@ def _sample_noise(model: nn.Module) -> bool: def _update_with_batch( self, batch: RolloutBatchProtocol, - ) -> TRainbowTrainingStats: + ) -> LossSequenceTrainingStats: self._sample_noise(self.policy.model) - if self.use_target_network and self._sample_noise(self.model_old): # type: ignore + if self.use_target_network: assert self.model_old is not None - self.model_old.train() # so that NoisyLinear takes effect + if self._sample_noise(self.model_old): + self.model_old.train() # so that NoisyLinear takes effect return super()._update_with_batch(batch) diff --git a/tianshou/policy/modelfree/redq.py b/tianshou/policy/modelfree/redq.py index 23701120e..ba76a334a 100644 --- a/tianshou/policy/modelfree/redq.py +++ b/tianshou/policy/modelfree/redq.py @@ -132,7 +132,7 @@ def forward( # type: ignore return cast(DistLogProbBatchProtocol, result) -class REDQ(ActorCriticOffPolicyAlgorithm[REDQPolicy, TREDQTrainingStats, DistLogProbBatchProtocol]): +class REDQ(ActorCriticOffPolicyAlgorithm[REDQPolicy, DistLogProbBatchProtocol]): """Implementation of REDQ. 
arXiv:2101.05982.""" def __init__( @@ -240,7 +240,7 @@ def _target_q_compute_value( return target_q - def _update_with_batch(self, batch: RolloutBatchProtocol) -> TREDQTrainingStats: # type: ignore + def _update_with_batch(self, batch: RolloutBatchProtocol) -> REDQTrainingStats: # type: ignore # critic ensemble weight = getattr(batch, "weight", 1.0) current_qs = self.critic(batch.obs, batch.act).flatten(1) @@ -268,7 +268,7 @@ def _update_with_batch(self, batch: RolloutBatchProtocol) -> TREDQTrainingStats: self._update_lagged_network_weights() - return REDQTrainingStats( # type: ignore[return-value] + return REDQTrainingStats( actor_loss=self._last_actor_loss, critic_loss=critic_loss.item(), alpha=self.alpha.value, diff --git a/tianshou/policy/modelfree/sac.py b/tianshou/policy/modelfree/sac.py index d666f8e9d..c315189cc 100644 --- a/tianshou/policy/modelfree/sac.py +++ b/tianshou/policy/modelfree/sac.py @@ -227,7 +227,7 @@ def update(self, entropy: torch.Tensor) -> float: class SAC( - ActorDualCriticsOffPolicyAlgorithm[SACPolicy, TSACTrainingStats, DistLogProbBatchProtocol], + ActorDualCriticsOffPolicyAlgorithm[SACPolicy, DistLogProbBatchProtocol], Generic[TSACTrainingStats], ): """Implementation of Soft Actor-Critic. 
arXiv:1812.05905.""" diff --git a/tianshou/policy/modelfree/td3.py b/tianshou/policy/modelfree/td3.py index fe86bde12..e28913713 100644 --- a/tianshou/policy/modelfree/td3.py +++ b/tianshou/policy/modelfree/td3.py @@ -1,7 +1,7 @@ from abc import ABC from copy import deepcopy from dataclasses import dataclass -from typing import Any, Generic, TypeVar +from typing import Any import torch @@ -13,7 +13,6 @@ from tianshou.policy.base import ( TPolicy, TrainingStats, - TTrainingStats, ) from tianshou.policy.modelfree.ddpg import ( ActorCriticOffPolicyAlgorithm, @@ -30,12 +29,8 @@ class TD3TrainingStats(TrainingStats): critic2_loss: float -TTD3TrainingStats = TypeVar("TTD3TrainingStats", bound=TD3TrainingStats) - - class ActorDualCriticsOffPolicyAlgorithm( - ActorCriticOffPolicyAlgorithm[TPolicy, TTrainingStats, TActBatchProtocol], - Generic[TPolicy, TTrainingStats, TActBatchProtocol], + ActorCriticOffPolicyAlgorithm[TPolicy, TActBatchProtocol], ABC, ): """A base class for off-policy algorithms with two critics, where the target Q-value is computed as the minimum @@ -108,8 +103,7 @@ def _target_q_compute_value( class TD3( - ActorDualCriticsOffPolicyAlgorithm[DDPGPolicy, TTD3TrainingStats, ActStateBatchProtocol], - Generic[TTD3TrainingStats], + ActorDualCriticsOffPolicyAlgorithm[DDPGPolicy, ActStateBatchProtocol], ): """Implementation of TD3, arXiv:1802.09477.""" @@ -189,7 +183,7 @@ def _target_q_compute_action(self, obs_batch: Batch) -> ActStateBatchProtocol: act_batch.act = act_ return act_batch - def _update_with_batch(self, batch: RolloutBatchProtocol) -> TTD3TrainingStats: # type: ignore + def _update_with_batch(self, batch: RolloutBatchProtocol) -> TD3TrainingStats: # critic 1&2 td1, critic1_loss = self._minimize_critic_squared_loss( batch, self.critic, self.critic_optim @@ -207,7 +201,7 @@ def _update_with_batch(self, batch: RolloutBatchProtocol) -> TTD3TrainingStats: self._update_lagged_network_weights() self._cnt += 1 - return TD3TrainingStats( # type: 
ignore[return-value] + return TD3TrainingStats( actor_loss=self._last, critic1_loss=critic1_loss.item(), critic2_loss=critic2_loss.item(), diff --git a/tianshou/policy/modelfree/trpo.py b/tianshou/policy/modelfree/trpo.py index f2e13abfc..d5a0b8e34 100644 --- a/tianshou/policy/modelfree/trpo.py +++ b/tianshou/policy/modelfree/trpo.py @@ -1,12 +1,12 @@ import warnings from dataclasses import dataclass -from typing import TypeVar import torch import torch.nn.functional as F from torch.distributions import kl_divergence -from tianshou.data import Batch, SequenceSummaryStats +from tianshou.data import SequenceSummaryStats +from tianshou.data.types import RolloutBatchProtocol from tianshou.policy import NPG from tianshou.policy.modelfree.npg import NPGTrainingStats from tianshou.policy.modelfree.pg import ActorPolicy @@ -20,10 +20,7 @@ class TRPOTrainingStats(NPGTrainingStats): step_size: SequenceSummaryStats -TTRPOTrainingStats = TypeVar("TTRPOTrainingStats", bound=TRPOTrainingStats) - - -class TRPO(NPG[TTRPOTrainingStats]): +class TRPO(NPG): """Implementation of Trust Region Policy Optimization. 
arXiv:1502.05477.""" def __init__( @@ -111,12 +108,12 @@ def __init__( self.max_kl = max_kl self.backtrack_coeff = backtrack_coeff - def _update_with_batch( # type: ignore + def _update_with_batch( self, - batch: Batch, + batch: RolloutBatchProtocol, batch_size: int | None, repeat: int, - ) -> TTRPOTrainingStats: + ) -> TRPOTrainingStats: actor_losses, vf_losses, step_sizes, kls = [], [], [], [] split_batch_size = batch_size or -1 for _ in range(repeat): @@ -198,7 +195,7 @@ def _update_with_batch( # type: ignore kl_summary_stat = SequenceSummaryStats.from_sequence(kls) step_size_stat = SequenceSummaryStats.from_sequence(step_sizes) - return TRPOTrainingStats( # type: ignore[return-value] + return TRPOTrainingStats( actor_loss=actor_loss_summary_stat, vf_loss=vf_loss_summary_stat, kl=kl_summary_stat, diff --git a/tianshou/policy/multiagent/mapolicy.py b/tianshou/policy/multiagent/mapolicy.py index 52592fdc2..5a3df68ea 100644 --- a/tianshou/policy/multiagent/mapolicy.py +++ b/tianshou/policy/multiagent/mapolicy.py @@ -268,7 +268,7 @@ def dispatch_update_with_batch( return MapTrainingStats(agent_id_to_stats) -class MultiAgentOffPolicyAlgorithm(OffPolicyAlgorithm[MultiAgentPolicy, MapTrainingStats]): +class MultiAgentOffPolicyAlgorithm(OffPolicyAlgorithm[MultiAgentPolicy]): """Multi-agent reinforcement learning where each agent uses off-policy learning.""" def __init__( @@ -312,7 +312,7 @@ def update(algorithm: OffPolicyAlgorithm, data: RolloutBatchProtocol) -> Trainin return self._dispatcher.dispatch_update_with_batch(batch, update) -class MultiAgentOnPolicyAlgorithm(OnPolicyAlgorithm[MultiAgentPolicy, MapTrainingStats]): +class MultiAgentOnPolicyAlgorithm(OnPolicyAlgorithm[MultiAgentPolicy]): """Multi-agent reinforcement learning where each agent uses on-policy learning.""" def __init__( diff --git a/tianshou/policy/random.py b/tianshou/policy/random.py index 675c570d3..db0ad27a1 100644 --- a/tianshou/policy/random.py +++ b/tianshou/policy/random.py @@ -1,4 +1,4 @@ 
-from typing import TypeVar, cast +from typing import cast import gymnasium as gym import numpy as np @@ -14,9 +14,6 @@ class MARLRandomTrainingStats(TrainingStats): pass -TMARLRandomTrainingStats = TypeVar("TMARLRandomTrainingStats", bound=MARLRandomTrainingStats) - - class MARLRandomDiscreteMaskedOffPolicyAlgorithm(OffPolicyAlgorithm): """A random agent used in multi-agent learning. @@ -58,6 +55,6 @@ def __init__(self, action_space: gym.spaces.Space) -> None: """:param action_space: the environment's action space.""" super().__init__(policy=self.Policy(action_space)) - def _update_with_batch(self, batch: RolloutBatchProtocol) -> TMARLRandomTrainingStats: # type: ignore + def _update_with_batch(self, batch: RolloutBatchProtocol) -> MARLRandomTrainingStats: # type: ignore """Since a random agent learns nothing, it returns an empty dict.""" - return MARLRandomTrainingStats() # type: ignore[return-value] + return MARLRandomTrainingStats() diff --git a/tianshou/utils/net/common.py b/tianshou/utils/net/common.py index d0922ae47..2c4a3834d 100644 --- a/tianshou/utils/net/common.py +++ b/tianshou/utils/net/common.py @@ -489,7 +489,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return x -class BranchingNet(nn.Module, PolicyForwardInterface[Any]): +class BranchingNet(nn.Module, PolicyForwardInterface): """Branching dual Q network. Network for the BranchingDQNPolicy, it uses a common network module, a value module @@ -536,18 +536,16 @@ def __init__( :param value_hidden_sizes: shape of the value MLP network passed in as a list. :param action_hidden_sizes: shape of the action MLP network passed in as a list. :param norm_layer: use which normalization before activation, e.g., - ``nn.LayerNorm`` and ``nn.BatchNorm1d``. Default to no normalization. - You can also pass a list of normalization modules with the same length - of hidden_sizes, to use different normalization module in different - layers. Default to no normalization. + ``nn.LayerNorm`` and ``nn.BatchNorm1d``. 
Default to no normalization. + You can also pass a list of normalization modules with the same length + of hidden_sizes, to use different normalization module in different + layers. Default to no normalization. :param activation: which activation to use after each layer, can be both - the same activation for all layers if passed in nn.Module, or different - activation for different Modules if passed in a list. Default to - nn.ReLU. - :param softmax: whether to apply a softmax layer over the last layer's - output. + the same activation for all layers if passed in nn.Module, or different + activation for different Modules if passed in a list. Default to + nn.ReLU. """ - super().__init__(output_dim=10) + super().__init__() common_hidden_sizes = common_hidden_sizes or [] value_hidden_sizes = value_hidden_sizes or [] action_hidden_sizes = action_hidden_sizes or [] From 3c94128ff62830cce92152f325bcab1537603ebb Mon Sep 17 00:00:00 2001 From: Michael Panchenko Date: Mon, 5 May 2025 17:46:20 +0200 Subject: [PATCH 116/230] v2: removed mujoco-py from dependencies --- poetry.lock | 474 +++++++++++++++++++++++++++++++++++++++++-------- pyproject.toml | 4 +- 2 files changed, 401 insertions(+), 77 deletions(-) diff --git a/poetry.lock b/poetry.lock index 583b90756..675a1a525 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand. +# This file is automatically @generated by Poetry 2.1.2 and should not be changed by hand. [[package]] name = "absl-py" @@ -6,6 +6,7 @@ version = "2.0.0" description = "Abseil Python Common Libraries, see https://github.com/abseil/abseil-py." 
optional = false python-versions = ">=3.7" +groups = ["main"] files = [ {file = "absl-py-2.0.0.tar.gz", hash = "sha256:d9690211c5fcfefcdd1a45470ac2b5c5acd45241c3af71eed96bc5441746c0d5"}, {file = "absl_py-2.0.0-py3-none-any.whl", hash = "sha256:9a28abb62774ae4e8edbe2dd4c49ffcd45a6a848952a5eccc6a49f3f0fc1e2f3"}, @@ -17,6 +18,7 @@ version = "0.0.4" description = "A collection of accessible pygments styles" optional = false python-versions = "*" +groups = ["dev"] files = [ {file = "accessible-pygments-0.0.4.tar.gz", hash = "sha256:e7b57a9b15958e9601c7e9eb07a440c813283545a20973f2574a5f453d0e953e"}, {file = "accessible_pygments-0.0.4-py2.py3-none-any.whl", hash = "sha256:416c6d8c1ea1c5ad8701903a20fcedf953c6e720d64f33dc47bfb2d3f2fa4e8d"}, @@ -31,6 +33,8 @@ version = "1.3.1" description = "aiosignal: a list of registered asynchronous callbacks" optional = false python-versions = ">=3.7" +groups = ["dev"] +markers = "sys_platform != \"win32\"" files = [ {file = "aiosignal-1.3.1-py3-none-any.whl", hash = "sha256:f8376fb07dd1e86a584e4fcdec80b36b7f81aac666ebc724e2c090300dd83b17"}, {file = "aiosignal-1.3.1.tar.gz", hash = "sha256:54cd96e15e1649b75d6c87526a6ff0b6c1b0dd3459f43d9ca11d48c339b68cfc"}, @@ -45,6 +49,7 @@ version = "0.7.13" description = "A configurable sidebar-enabled Sphinx theme" optional = false python-versions = ">=3.6" +groups = ["dev"] files = [ {file = "alabaster-0.7.13-py3-none-any.whl", hash = "sha256:1ee19aca801bbabb5ba3f5f258e4422dfa86f82f3e9cefb0859b283cdd7f62a3"}, {file = "alabaster-0.7.13.tar.gz", hash = "sha256:a27a4a084d5e690e16e01e03ad2b2e552c61a65469419b907243193de1a84ae2"}, @@ -56,6 +61,8 @@ version = "0.8.1" description = "The Arcade Learning Environment (ALE) - a platform for AI research." 
optional = true python-versions = ">=3.7" +groups = ["main"] +markers = "extra == \"atari\"" files = [ {file = "ale_py-0.8.1-cp310-cp310-macosx_10_15_x86_64.whl", hash = "sha256:b2aa2f69a4169742800615970efe6914fa856e33eaf7fa9133c0e06a617a80e2"}, {file = "ale_py-0.8.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:6f2f6b92c8fd6189654979bbf0b305dbe0ecf82176c47f244d8c1cbc36286b89"}, @@ -91,6 +98,7 @@ version = "4.0.0" description = "High level compatibility layer for multiple asynchronous event loop implementations" optional = false python-versions = ">=3.8" +groups = ["dev"] files = [ {file = "anyio-4.0.0-py3-none-any.whl", hash = "sha256:cfdb2b588b9fc25ede96d8db56ed50848b0b649dca3dd1df0b11f683bb9e0b5f"}, {file = "anyio-4.0.0.tar.gz", hash = "sha256:f7ed51751b2c2add651e5747c891b47e26d2a21be5d32d9311dfe9692f3e5d7a"}, @@ -102,7 +110,7 @@ sniffio = ">=1.1" [package.extras] doc = ["Sphinx (>=7)", "packaging", "sphinx-autodoc-typehints (>=1.2.0)"] -test = ["anyio[trio]", "coverage[toml] (>=7)", "hypothesis (>=4.0)", "psutil (>=5.9)", "pytest (>=7.0)", "pytest-mock (>=3.6.1)", "trustme", "uvloop (>=0.17)"] +test = ["anyio[trio]", "coverage[toml] (>=7)", "hypothesis (>=4.0)", "psutil (>=5.9)", "pytest (>=7.0)", "pytest-mock (>=3.6.1)", "trustme", "uvloop (>=0.17) ; python_version < \"3.12\" and platform_python_implementation == \"CPython\" and platform_system != \"Windows\""] trio = ["trio (>=0.22)"] [[package]] @@ -111,6 +119,7 @@ version = "1.4.1" description = "Handy tools for working with URLs and APIs." optional = false python-versions = ">=3.6.1" +groups = ["dev"] files = [ {file = "apeye-1.4.1-py3-none-any.whl", hash = "sha256:44e58a9104ec189bf42e76b3a7fe91e2b2879d96d48e9a77e5e32ff699c9204e"}, {file = "apeye-1.4.1.tar.gz", hash = "sha256:14ea542fad689e3bfdbda2189a354a4908e90aee4bf84c15ab75d68453d76a36"}, @@ -132,6 +141,7 @@ version = "1.1.4" description = "Core (offline) functionality for the apeye library." 
optional = false python-versions = ">=3.6.1" +groups = ["dev"] files = [ {file = "apeye_core-1.1.4-py3-none-any.whl", hash = "sha256:084bc696448d3ac428fece41c1f2eb08fa9d9ce1d1b2f4d43187e3def4528a60"}, {file = "apeye_core-1.1.4.tar.gz", hash = "sha256:72bb89fed3baa647cb81aa28e1d851787edcbf9573853b5d2b5f87c02f50eaf5"}, @@ -147,6 +157,8 @@ version = "0.1.3" description = "Disable App Nap on macOS >= 10.9" optional = false python-versions = "*" +groups = ["dev"] +markers = "platform_system == \"Darwin\" or sys_platform == \"darwin\"" files = [ {file = "appnope-0.1.3-py2.py3-none-any.whl", hash = "sha256:265a455292d0bd8a72453494fa24df5a11eb18373a60c7c0430889f22548605e"}, {file = "appnope-0.1.3.tar.gz", hash = "sha256:02bd91c4de869fbb1e1c50aafc4098827a7a54ab2f39d9dcba6c9547ed920e24"}, @@ -158,6 +170,8 @@ version = "5.3.1" description = "ARCH for Python" optional = true python-versions = ">=3.7" +groups = ["main"] +markers = "extra == \"eval\"" files = [ {file = "arch-5.3.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:75fa6f9386ecc2df81bcbf5d055a290a697482ca51e0b3459dab183d288993cb"}, {file = "arch-5.3.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:f9c9220d331618322517e0f2b3b3529f9c51f5e5a891441da4a107fd2d6d7fce"}, @@ -197,6 +211,7 @@ version = "23.1.0" description = "Argon2 for Python" optional = false python-versions = ">=3.7" +groups = ["dev"] files = [ {file = "argon2_cffi-23.1.0-py3-none-any.whl", hash = "sha256:c670642b78ba29641818ab2e68bd4e6a78ba53b7eff7b4c3815ae16abf91c7ea"}, {file = "argon2_cffi-23.1.0.tar.gz", hash = "sha256:879c3e79a2729ce768ebb7d36d4609e3a78a4ca2ec3a9f12286ca057e3d0db08"}, @@ -217,6 +232,7 @@ version = "21.2.0" description = "Low-level CFFI bindings for Argon2" optional = false python-versions = ">=3.6" +groups = ["dev"] files = [ {file = "argon2-cffi-bindings-21.2.0.tar.gz", hash = "sha256:bb89ceffa6c791807d1305ceb77dbfacc5aa499891d2c55661c6459651fc39e3"}, {file = "argon2_cffi_bindings-21.2.0-cp36-abi3-macosx_10_9_x86_64.whl", 
hash = "sha256:ccb949252cb2ab3a08c02024acb77cfb179492d5701c7cbdbfd776124d4d2367"}, @@ -254,6 +270,7 @@ version = "1.3.0" description = "Better dates & times for Python" optional = false python-versions = ">=3.8" +groups = ["dev"] files = [ {file = "arrow-1.3.0-py3-none-any.whl", hash = "sha256:c728b120ebc00eb84e01882a6f5e7927a53960aa990ce7dd2b10f39005a67f80"}, {file = "arrow-1.3.0.tar.gz", hash = "sha256:d4540617648cb5f895730f1ad8c82a65f2dad0166f57b75f3ca54759c4d67a85"}, @@ -273,6 +290,7 @@ version = "2.4.1" description = "Annotate AST trees with source code positions" optional = false python-versions = "*" +groups = ["dev"] files = [ {file = "asttokens-2.4.1-py2.py3-none-any.whl", hash = "sha256:051ed49c3dcae8913ea7cd08e46a606dba30b79993209636c4875bc1d637bc24"}, {file = "asttokens-2.4.1.tar.gz", hash = "sha256:b03869718ba9a6eb027e134bfdf69f38a236d681c83c160d510768af11254ba0"}, @@ -282,8 +300,8 @@ files = [ six = ">=1.12.0" [package.extras] -astroid = ["astroid (>=1,<2)", "astroid (>=2,<4)"] -test = ["astroid (>=1,<2)", "astroid (>=2,<4)", "pytest"] +astroid = ["astroid (>=1,<2) ; python_version < \"3\"", "astroid (>=2,<4) ; python_version >= \"3\""] +test = ["astroid (>=1,<2) ; python_version < \"3\"", "astroid (>=2,<4) ; python_version >= \"3\"", "pytest"] [[package]] name = "async-lru" @@ -291,6 +309,7 @@ version = "2.0.4" description = "Simple LRU cache for asyncio" optional = false python-versions = ">=3.8" +groups = ["dev"] files = [ {file = "async-lru-2.0.4.tar.gz", hash = "sha256:b8a59a5df60805ff63220b2a0c5b5393da5521b113cd5465a44eb037d81a5627"}, {file = "async_lru-2.0.4-py3-none-any.whl", hash = "sha256:ff02944ce3c288c5be660c42dbcca0742b32c3b279d6dceda655190240b99224"}, @@ -302,6 +321,7 @@ version = "23.1.0" description = "Classes Without Boilerplate" optional = false python-versions = ">=3.7" +groups = ["dev"] files = [ {file = "attrs-23.1.0-py3-none-any.whl", hash = "sha256:1f28b4522cdc2fb4256ac1a020c78acf9cba2c6b461ccd2c126f3aa8e8335d04"}, {file = 
"attrs-23.1.0.tar.gz", hash = "sha256:6279836d581513a26f1bf235f9acd333bc9115683f14f7e8fae46c98fc50e015"}, @@ -312,7 +332,7 @@ cov = ["attrs[tests]", "coverage[toml] (>=5.3)"] dev = ["attrs[docs,tests]", "pre-commit"] docs = ["furo", "myst-parser", "sphinx", "sphinx-notfound-page", "sphinxcontrib-towncrier", "towncrier", "zope-interface"] tests = ["attrs[tests-no-zope]", "zope-interface"] -tests-no-zope = ["cloudpickle", "hypothesis", "mypy (>=1.1.1)", "pympler", "pytest (>=4.3.0)", "pytest-mypy-plugins", "pytest-xdist[psutil]"] +tests-no-zope = ["cloudpickle ; platform_python_implementation == \"CPython\"", "hypothesis", "mypy (>=1.1.1) ; platform_python_implementation == \"CPython\"", "pympler", "pytest (>=4.3.0)", "pytest-mypy-plugins ; platform_python_implementation == \"CPython\" and python_version < \"3.11\"", "pytest-xdist[psutil]"] [[package]] name = "autodocsumm" @@ -320,6 +340,7 @@ version = "0.2.11" description = "Extended sphinx autodoc including automatic autosummaries" optional = false python-versions = ">=3.7" +groups = ["dev"] files = [ {file = "autodocsumm-0.2.11-py3-none-any.whl", hash = "sha256:f1d0a623bf1ad64d979a9e23fd360d1fb1b8f869beaf3197f711552cddc174e2"}, {file = "autodocsumm-0.2.11.tar.gz", hash = "sha256:183212bd9e9f3b58a96bb21b7958ee4e06224107aa45b2fd894b61b83581b9a9"}, @@ -334,6 +355,7 @@ version = "2.0.4" description = "A tool that automatically formats Python code to conform to the PEP 8 style guide" optional = false python-versions = ">=3.6" +groups = ["dev"] files = [ {file = "autopep8-2.0.4-py2.py3-none-any.whl", hash = "sha256:067959ca4a07b24dbd5345efa8325f5f58da4298dab0dde0443d5ed765de80cb"}, {file = "autopep8-2.0.4.tar.gz", hash = "sha256:2913064abd97b3419d1cc83ea71f042cb821f87e45b9c88cad5ad3c4ea87fe0c"}, @@ -348,6 +370,8 @@ version = "0.4.2" description = "Automated installation of Atari ROMs for Gym/ALE-Py" optional = true python-versions = ">=3.6" +groups = ["main"] +markers = "extra == \"atari\"" files = [ {file = 
"AutoROM-0.4.2-py3-none-any.whl", hash = "sha256:719c9d363ef08391fdb7003d70df235b68f36de628d289a946c4a59a3adefa13"}, {file = "AutoROM-0.4.2.tar.gz", hash = "sha256:b426a39bc0ee3781c7791f28963a9b2e4385b6421eeaf2f368edc00c761d428a"}, @@ -368,6 +392,8 @@ version = "0.6.1" description = "Automated installation of Atari ROMs for Gym/ALE-Py" optional = true python-versions = ">=3.7" +groups = ["main"] +markers = "extra == \"atari\"" files = [ {file = "AutoROM.accept-rom-license-0.6.1.tar.gz", hash = "sha256:0c905a708d634a076f686802f672817d3585259ce3be0bde8713a4fb59e3159e"}, ] @@ -385,6 +411,7 @@ version = "2.13.1" description = "Internationalization utilities" optional = false python-versions = ">=3.7" +groups = ["dev"] files = [ {file = "Babel-2.13.1-py3-none-any.whl", hash = "sha256:7077a4984b02b6727ac10f1f7294484f737443d7e2e66c5e4380e41a3ae0b4ed"}, {file = "Babel-2.13.1.tar.gz", hash = "sha256:33e0952d7dd6374af8dbf6768cc4ddf3ccfefc244f9986d4074704f2fbd18900"}, @@ -402,6 +429,7 @@ version = "4.12.2" description = "Screen-scraping library" optional = false python-versions = ">=3.6.0" +groups = ["dev"] files = [ {file = "beautifulsoup4-4.12.2-py3-none-any.whl", hash = "sha256:bd2520ca0d9d7d12694a53d44ac482d181b4ec1888909b035a3dbf40d0f57d4a"}, {file = "beautifulsoup4-4.12.2.tar.gz", hash = "sha256:492bbc69dca35d12daac71c4db1bfff0c876c00ef4a2ffacce226d4638eb72da"}, @@ -420,6 +448,7 @@ version = "23.11.0" description = "The uncompromising code formatter." optional = false python-versions = ">=3.8" +groups = ["dev"] files = [ {file = "black-23.11.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:dbea0bb8575c6b6303cc65017b46351dc5953eea5c0a59d7b7e3a2d2f433a911"}, {file = "black-23.11.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:412f56bab20ac85927f3a959230331de5614aecda1ede14b373083f62ec24e6f"}, @@ -462,6 +491,7 @@ version = "6.1.0" description = "An easy safelist-based HTML-sanitizing tool." 
optional = false python-versions = ">=3.8" +groups = ["dev"] files = [ {file = "bleach-6.1.0-py3-none-any.whl", hash = "sha256:3225f354cfc436b9789c66c4ee030194bee0568fbf9cbdad3bc8b5c26c5f12b6"}, {file = "bleach-6.1.0.tar.gz", hash = "sha256:0a31f1837963c41d46bbf1331b8778e1308ea0791db03cc4e7357b97cf42a8fe"}, @@ -480,6 +510,8 @@ version = "2.3.5" description = "Python Box2D" optional = true python-versions = "*" +groups = ["main"] +markers = "extra == \"box2d\"" files = [ {file = "box2d-py-2.3.5.tar.gz", hash = "sha256:b37dc38844bcd7def48a97111d2b082e4f81cca3cece7460feb3eacda0da2207"}, {file = "box2d_py-2.3.5-cp35-cp35m-manylinux1_x86_64.whl", hash = "sha256:287aa54005c0644b47bf7ad72966e4068d66e56bcf8458f5b4a653ffe42a2618"}, @@ -494,6 +526,7 @@ version = "0.13.1" description = "httplib2 caching for requests" optional = false python-versions = ">=3.7" +groups = ["dev"] files = [ {file = "cachecontrol-0.13.1-py3-none-any.whl", hash = "sha256:95dedbec849f46dda3137866dc28b9d133fc9af55f5b805ab1291833e4457aa4"}, {file = "cachecontrol-0.13.1.tar.gz", hash = "sha256:f012366b79d2243a6118309ce73151bf52a38d4a5dac8ea57f09bd29087e506b"}, @@ -515,6 +548,7 @@ version = "5.3.2" description = "Extensible memoizing collections and decorators" optional = false python-versions = ">=3.7" +groups = ["main"] files = [ {file = "cachetools-5.3.2-py3-none-any.whl", hash = "sha256:861f35a13a451f94e301ce2bec7cac63e881232ccce7ed67fab9b5df4d3beaa1"}, {file = "cachetools-5.3.2.tar.gz", hash = "sha256:086ee420196f7b2ab9ca2db2520aca326318b68fe5ba8bc4d49cca91add450f2"}, @@ -526,6 +560,7 @@ version = "2024.7.4" description = "Python package for providing Mozilla's CA Bundle." 
optional = false python-versions = ">=3.6" +groups = ["main", "dev"] files = [ {file = "certifi-2024.7.4-py3-none-any.whl", hash = "sha256:c198e21b1289c2ab85ee4e67bb4b4ef3ead0892059901a8d5b622f24a1101e90"}, {file = "certifi-2024.7.4.tar.gz", hash = "sha256:5a1e7645bc0ec61a09e26c36f6106dd4cf40c6db3a1fb6352b0244e7fb057c7b"}, @@ -537,6 +572,7 @@ version = "1.16.0" description = "Foreign Function Interface for Python calling C code." optional = false python-versions = ">=3.8" +groups = ["dev"] files = [ {file = "cffi-1.16.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:6b3d6606d369fc1da4fd8c357d026317fbb9c9b75d36dc16e90e84c26854b088"}, {file = "cffi-1.16.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:ac0f5edd2360eea2f1daa9e26a41db02dd4b0451b48f7c318e217ee092a213e9"}, @@ -601,6 +637,7 @@ version = "3.4.0" description = "Validate configuration and produce human readable error messages." optional = false python-versions = ">=3.8" +groups = ["dev"] files = [ {file = "cfgv-3.4.0-py2.py3-none-any.whl", hash = "sha256:b7265b1f29fd3316bfcd2b330d63d024f2bfd8bcb8b0272f8e19a504856c48f9"}, {file = "cfgv-3.4.0.tar.gz", hash = "sha256:e52591d4c5f5dead8e0f673fb16db7949d2cfb3f7da4582893288f0ded8fe560"}, @@ -612,6 +649,7 @@ version = "3.3.2" description = "The Real First Universal Charset Detector. Open, modern and actively maintained alternative to Chardet." 
optional = false python-versions = ">=3.7.0" +groups = ["main", "dev"] files = [ {file = "charset-normalizer-3.3.2.tar.gz", hash = "sha256:f30c3cb33b24454a82faecaf01b19c18562b1e89558fb6c56de4d9118a032fd5"}, {file = "charset_normalizer-3.3.2-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:25baf083bf6f6b341f4121c2f3c548875ee6f5339300e08be3f2b2ba1721cdd3"}, @@ -711,10 +749,12 @@ version = "8.1.7" description = "Composable command line interface toolkit" optional = false python-versions = ">=3.7" +groups = ["main", "dev"] files = [ {file = "click-8.1.7-py3-none-any.whl", hash = "sha256:ae74fb96c20a0277a1d615f1e4d73c8414f5a98db8b799a7931d1582f3390c28"}, {file = "click-8.1.7.tar.gz", hash = "sha256:ca9853ad459e787e2192211578cc907e7594e294c7ccc834310722b41b9ca6de"}, ] +markers = {main = "extra == \"atari\""} [package.dependencies] colorama = {version = "*", markers = "platform_system == \"Windows\""} @@ -725,6 +765,7 @@ version = "3.0.0" description = "Pickler class to extend the standard pickle.Pickler functionality" optional = false python-versions = ">=3.8" +groups = ["main"] files = [ {file = "cloudpickle-3.0.0-py3-none-any.whl", hash = "sha256:246ee7d0c295602a036e86369c77fecda4ab17b506496730f2f576d9016fd9c7"}, {file = "cloudpickle-3.0.0.tar.gz", hash = "sha256:996d9a482c6fb4f33c1a35335cf8afd065d2a56e973270364840712d9131a882"}, @@ -736,10 +777,12 @@ version = "0.4.6" description = "Cross-platform colored terminal text." 
optional = false python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,!=3.6.*,>=2.7" +groups = ["main", "dev"] files = [ {file = "colorama-0.4.6-py2.py3-none-any.whl", hash = "sha256:4f1d9991f5acc0ca119f9d443620b77f9d6b33703e51011c16baf57afb285fc6"}, {file = "colorama-0.4.6.tar.gz", hash = "sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44"}, ] +markers = {main = "platform_system == \"Windows\"", dev = "platform_system == \"Windows\" or sys_platform == \"win32\""} [[package]] name = "comm" @@ -747,6 +790,7 @@ version = "0.2.0" description = "Jupyter Python Comm implementation, for usage in ipykernel, xeus-python etc." optional = false python-versions = ">=3.8" +groups = ["dev"] files = [ {file = "comm-0.2.0-py3-none-any.whl", hash = "sha256:2da8d9ebb8dd7bfc247adaff99f24dce705638a8042b85cb995066793e391001"}, {file = "comm-0.2.0.tar.gz", hash = "sha256:a517ea2ca28931c7007a7a99c562a0fa5883cfb48963140cf642c41c948498be"}, @@ -764,6 +808,7 @@ version = "1.2.1" description = "Python library for calculating contours of 2D quadrilateral grids" optional = false python-versions = ">=3.9" +groups = ["main"] files = [ {file = "contourpy-1.2.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:bd7c23df857d488f418439686d3b10ae2fbf9bc256cd045b37a8c16575ea1040"}, {file = "contourpy-1.2.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:5b9eb0ca724a241683c9685a484da9d35c872fd42756574a7cfbf58af26677fd"}, @@ -827,6 +872,7 @@ version = "7.3.2" description = "Code coverage measurement for Python" optional = false python-versions = ">=3.8" +groups = ["dev"] files = [ {file = "coverage-7.3.2-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:d872145f3a3231a5f20fd48500274d7df222e291d90baa2026cc5152b7ce86bf"}, {file = "coverage-7.3.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:310b3bb9c91ea66d59c53fa4989f57d2436e08f18fb2f421a1b0b6b8cc7fffda"}, @@ -883,7 +929,7 @@ files = [ ] [package.extras] -toml = ["tomli"] +toml = ["tomli ; 
python_full_version <= \"3.11.0a6\""] [[package]] name = "cssutils" @@ -891,6 +937,7 @@ version = "2.9.0" description = "A CSS Cascading Style Sheets library for Python" optional = false python-versions = ">=3.8" +groups = ["dev"] files = [ {file = "cssutils-2.9.0-py3-none-any.whl", hash = "sha256:f8b013169e281c0c6083207366c5005f5dd4549055f7aba840384fb06a78745c"}, {file = "cssutils-2.9.0.tar.gz", hash = "sha256:89477b3d17d790e97b9fb4def708767061055795aae6f7c82ae32e967c9be4cd"}, @@ -898,7 +945,7 @@ files = [ [package.extras] docs = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "rst.linker (>=1.9)", "sphinx (<7.2.5)", "sphinx (>=3.5)", "sphinx-lint"] -testing = ["cssselect", "importlib-resources", "jaraco.test (>=5.1)", "lxml", "pytest (>=6)", "pytest-black (>=0.3.7)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-mypy (>=0.9.1)", "pytest-ruff"] +testing = ["cssselect", "importlib-resources ; python_version < \"3.9\"", "jaraco.test (>=5.1)", "lxml ; python_version < \"3.11\"", "pytest (>=6)", "pytest-black (>=0.3.7) ; platform_python_implementation != \"PyPy\"", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-mypy (>=0.9.1) ; platform_python_implementation != \"PyPy\"", "pytest-ruff"] [[package]] name = "cycler" @@ -906,6 +953,7 @@ version = "0.12.1" description = "Composable style cycles" optional = false python-versions = ">=3.8" +groups = ["main"] files = [ {file = "cycler-0.12.1-py3-none-any.whl", hash = "sha256:85cef7cff222d8644161529808465972e51340599459b8ac3ccbac5a854e0d30"}, {file = "cycler-0.12.1.tar.gz", hash = "sha256:88bb128f02ba341da8ef447245a9e138fae777f6a23943da4540077d3601eb1c"}, @@ -921,6 +969,7 @@ version = "3.0.8" description = "The Cython compiler for writing C extensions in the Python language." 
optional = true python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" +groups = ["main"] files = [ {file = "Cython-3.0.8-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:a846e0a38e2b24e9a5c5dc74b0e54c6e29420d88d1dafabc99e0fc0f3e338636"}, {file = "Cython-3.0.8-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:45523fdc2b78d79b32834cc1cc12dc2ca8967af87e22a3ee1bff20e77c7f5520"}, @@ -988,6 +1037,7 @@ version = "1.8.0" description = "An implementation of the Debug Adapter Protocol for Python" optional = false python-versions = ">=3.8" +groups = ["dev"] files = [ {file = "debugpy-1.8.0-cp310-cp310-macosx_11_0_x86_64.whl", hash = "sha256:7fb95ca78f7ac43393cd0e0f2b6deda438ec7c5e47fa5d38553340897d2fbdfb"}, {file = "debugpy-1.8.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ef9ab7df0b9a42ed9c878afd3eaaff471fce3fa73df96022e1f5c9f8f8c87ada"}, @@ -1015,6 +1065,7 @@ version = "5.1.1" description = "Decorators for Humans" optional = false python-versions = ">=3.5" +groups = ["dev"] files = [ {file = "decorator-5.1.1-py3-none-any.whl", hash = "sha256:b8c3f85900b9dc423225913c5aace94729fe1fa9763b38939a95226f02d37186"}, {file = "decorator-5.1.1.tar.gz", hash = "sha256:637996211036b6385ef91435e4fae22989472f9d571faba8927ba8253acbc330"}, @@ -1026,6 +1077,7 @@ version = "7.0.1" description = "Deep Difference and Search of any Python object/data. Recreate objects by adding adding deltas to each other." 
optional = false python-versions = ">=3.8" +groups = ["main"] files = [ {file = "deepdiff-7.0.1-py3-none-any.whl", hash = "sha256:447760081918216aa4fd4ca78a4b6a848b81307b2ea94c810255334b759e1dc3"}, {file = "deepdiff-7.0.1.tar.gz", hash = "sha256:260c16f052d4badbf60351b4f77e8390bee03a0b516246f6839bc813fb429ddf"}, @@ -1044,6 +1096,7 @@ version = "0.7.1" description = "XML bomb protection for Python stdlib modules" optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*" +groups = ["dev"] files = [ {file = "defusedxml-0.7.1-py2.py3-none-any.whl", hash = "sha256:a352e7e428770286cc899e2542b6cdaedb2b4953ff269a210103ec58f6198a61"}, {file = "defusedxml-0.7.1.tar.gz", hash = "sha256:1bb3032db185915b62d7c6209c5a8792be6a32ab2fedacc84e01b52c51aa3e69"}, @@ -1055,6 +1108,7 @@ version = "0.3.0.post1" description = "A μ-library for constructing cascading style sheets from Python dictionaries." optional = false python-versions = ">=3.6" +groups = ["dev"] files = [ {file = "dict2css-0.3.0.post1-py3-none-any.whl", hash = "sha256:f006a6b774c3e31869015122ae82c491fd25e7de4a75607a62aa3e798f837e0d"}, {file = "dict2css-0.3.0.post1.tar.gz", hash = "sha256:89c544c21c4ca7472c3fffb9d37d3d926f606329afdb751dc1de67a411b70719"}, @@ -1070,6 +1124,7 @@ version = "0.3.7" description = "Distribution utilities" optional = false python-versions = "*" +groups = ["main", "dev"] files = [ {file = "distlib-0.3.7-py2.py3-none-any.whl", hash = "sha256:2e24928bc811348f0feb63014e97aaae3037f2cf48712d51ae61df7fd6075057"}, {file = "distlib-0.3.7.tar.gz", hash = "sha256:9dafe54b34a028eafd95039d5e5d4851a13734540f1331060d31c9916e7147a8"}, @@ -1081,6 +1136,8 @@ version = "1.6" description = "A Python interface for Reinforcement Learning environments." 
optional = true python-versions = ">=3.7" +groups = ["main"] +markers = "sys_platform != \"darwin\" and extra == \"envpool\"" files = [ {file = "dm-env-1.6.tar.gz", hash = "sha256:a436eb1c654c39e0c986a516cee218bea7140b510fceff63f97eb4fcff3d93de"}, {file = "dm_env-1.6-py3-none-any.whl", hash = "sha256:0eabb6759dd453b625e041032f7ae0c1e87d4eb61b6a96b9ca586483837abf29"}, @@ -1097,6 +1154,8 @@ version = "0.1.8" description = "Tree is a library for working with nested data structures." optional = true python-versions = "*" +groups = ["main"] +markers = "sys_platform != \"darwin\" and extra == \"envpool\"" files = [ {file = "dm-tree-0.1.8.tar.gz", hash = "sha256:0fcaabbb14e7980377439e7140bd05552739ca5e515ecb3119f234acee4b9430"}, {file = "dm_tree-0.1.8-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:35cc164a79336bfcfafb47e5f297898359123bbd3330c1967f0c4994f9cf9f60"}, @@ -1152,6 +1211,7 @@ version = "0.4.0" description = "Python bindings for the docker credentials store API" optional = false python-versions = "*" +groups = ["dev"] files = [ {file = "docker-pycreds-0.4.0.tar.gz", hash = "sha256:6ce3270bcaf404cc4c3e27e4b6c70d3521deae82fb508767870fdbf772d584d4"}, {file = "docker_pycreds-0.4.0-py2.py3-none-any.whl", hash = "sha256:7266112468627868005106ec19cd0d722702d2b7d5912a28e19b826c3d37af49"}, @@ -1166,6 +1226,8 @@ version = "0.15" description = "Parse Python docstrings in reST, Google and Numpydoc format" optional = true python-versions = ">=3.6,<4.0" +groups = ["main"] +markers = "extra == \"argparse\" or extra == \"eval\"" files = [ {file = "docstring_parser-0.15-py3-none-any.whl", hash = "sha256:d1679b86250d269d06a99670924d6bce45adc00b08069dae8c47d98e89b667a9"}, {file = "docstring_parser-0.15.tar.gz", hash = "sha256:48ddc093e8b1865899956fcc03b03e66bb7240c310fac5af81814580c55bf682"}, @@ -1177,6 +1239,7 @@ version = "0.20.1" description = "Docutils -- Python Documentation Utilities" optional = false python-versions = ">=3.7" +groups = ["dev"] files = [ {file = 
"docutils-0.20.1-py3-none-any.whl", hash = "sha256:96f387a2c5562db4476f09f13bbab2192e764cac08ebbf3a34a95d9b1e4a59d6"}, {file = "docutils-0.20.1.tar.gz", hash = "sha256:f08a4e276c3a1583a86dce3e34aba3fe04d02bba2dd51ed16106244e8a923e3b"}, @@ -1188,6 +1251,7 @@ version = "3.7.0" description = "Helpful functions for Python 🐍 🛠️" optional = false python-versions = ">=3.6" +groups = ["dev"] files = [ {file = "domdf_python_tools-3.7.0-py3-none-any.whl", hash = "sha256:7b4d1c3bdb7402b872d43953824bf921ae2e52f893adbe5c0052a21a6efa2fe4"}, {file = "domdf_python_tools-3.7.0.tar.gz", hash = "sha256:df1af9a91649af0fb2a4e7b3a4b0a0936e4f78389dd7280dd6fd2f53a339ca71"}, @@ -1207,6 +1271,8 @@ version = "0.8.4" description = "\"C++-based high-performance parallel environment execution engine (vectorized env) for general RL environments.\"" optional = true python-versions = ">=3.7" +groups = ["main"] +markers = "sys_platform != \"darwin\" and extra == \"envpool\"" files = [ {file = "envpool-0.8.4-cp310-cp310-manylinux_2_24_x86_64.whl", hash = "sha256:9c6a1af66c8a18d798b3069e8eee4cde2e5942af22b25d058189714f2630b024"}, {file = "envpool-0.8.4-cp311-cp311-manylinux_2_24_x86_64.whl", hash = "sha256:2407294307a3e20c18787bb836a94cc0649e708b04d8a8200be674f5fc46f3b4"}, @@ -1231,13 +1297,14 @@ version = "2.0.1" description = "Get the currently executing AST node of a frame, and other information" optional = false python-versions = ">=3.5" +groups = ["dev"] files = [ {file = "executing-2.0.1-py2.py3-none-any.whl", hash = "sha256:eac49ca94516ccc753f9fb5ce82603156e590b27525a8bc32cce8ae302eb61bc"}, {file = "executing-2.0.1.tar.gz", hash = "sha256:35afe2ce3affba8ee97f2d69927fa823b08b472b7b994e36a52a964b93d16147"}, ] [package.extras] -tests = ["asttokens (>=2.1.0)", "coverage", "coverage-enable-subprocess", "ipython", "littleutils", "pytest", "rich"] +tests = ["asttokens (>=2.1.0)", "coverage", "coverage-enable-subprocess", "ipython", "littleutils", "pytest", "rich ; python_version >= \"3.11\""] 
[[package]] name = "farama-notifications" @@ -1245,28 +1312,19 @@ version = "0.0.4" description = "Notifications for all Farama Foundation maintained libraries." optional = false python-versions = "*" +groups = ["main"] files = [ {file = "Farama-Notifications-0.0.4.tar.gz", hash = "sha256:13fceff2d14314cf80703c8266462ebf3733c7d165336eee998fc58e545efd18"}, {file = "Farama_Notifications-0.0.4-py3-none-any.whl", hash = "sha256:14de931035a41961f7c056361dc7f980762a143d05791ef5794a751a2caf05ae"}, ] -[[package]] -name = "fasteners" -version = "0.19" -description = "A python package that provides useful locks" -optional = true -python-versions = ">=3.6" -files = [ - {file = "fasteners-0.19-py3-none-any.whl", hash = "sha256:758819cb5d94cdedf4e836988b74de396ceacb8e2794d21f82d131fd9ee77237"}, - {file = "fasteners-0.19.tar.gz", hash = "sha256:b4f37c3ac52d8a445af3a66bce57b33b5e90b97c696b7b984f530cf8f0ded09c"}, -] - [[package]] name = "fastjsonschema" version = "2.19.0" description = "Fastest Python implementation of JSON schema" optional = false python-versions = "*" +groups = ["dev"] files = [ {file = "fastjsonschema-2.19.0-py3-none-any.whl", hash = "sha256:b9fd1a2dd6971dbc7fee280a95bd199ae0dd9ce22beb91cc75e9c1c528a5170e"}, {file = "fastjsonschema-2.19.0.tar.gz", hash = "sha256:e25df6647e1bc4a26070b700897b07b542ec898dd4f1f6ea013e7f6a88417225"}, @@ -1281,6 +1339,7 @@ version = "3.13.1" description = "A platform independent file lock." 
optional = false python-versions = ">=3.8" +groups = ["main", "dev"] files = [ {file = "filelock-3.13.1-py3-none-any.whl", hash = "sha256:57dbda9b35157b05fb3e58ee91448612eb674172fab98ee235ccb0b5bee19a1c"}, {file = "filelock-3.13.1.tar.gz", hash = "sha256:521f5f56c50f8426f5e03ad3b281b490a87ef15bc6c526f168290f0c7148d44e"}, @@ -1289,7 +1348,7 @@ files = [ [package.extras] docs = ["furo (>=2023.9.10)", "sphinx (>=7.2.6)", "sphinx-autodoc-typehints (>=1.24)"] testing = ["covdefaults (>=2.3)", "coverage (>=7.3.2)", "diff-cover (>=8)", "pytest (>=7.4.3)", "pytest-cov (>=4.1)", "pytest-mock (>=3.12)", "pytest-timeout (>=2.2)"] -typing = ["typing-extensions (>=4.8)"] +typing = ["typing-extensions (>=4.8) ; python_version < \"3.11\""] [[package]] name = "fonttools" @@ -1297,6 +1356,7 @@ version = "4.51.0" description = "Tools to manipulate font files" optional = false python-versions = ">=3.8" +groups = ["main"] files = [ {file = "fonttools-4.51.0-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:84d7751f4468dd8cdd03ddada18b8b0857a5beec80bce9f435742abc9a851a74"}, {file = "fonttools-4.51.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:8b4850fa2ef2cfbc1d1f689bc159ef0f45d8d83298c1425838095bf53ef46308"}, @@ -1343,18 +1403,18 @@ files = [ ] [package.extras] -all = ["brotli (>=1.0.1)", "brotlicffi (>=0.8.0)", "fs (>=2.2.0,<3)", "lxml (>=4.0)", "lz4 (>=1.7.4.2)", "matplotlib", "munkres", "pycairo", "scipy", "skia-pathops (>=0.5.0)", "sympy", "uharfbuzz (>=0.23.0)", "unicodedata2 (>=15.1.0)", "xattr", "zopfli (>=0.1.4)"] +all = ["brotli (>=1.0.1) ; platform_python_implementation == \"CPython\"", "brotlicffi (>=0.8.0) ; platform_python_implementation != \"CPython\"", "fs (>=2.2.0,<3)", "lxml (>=4.0)", "lz4 (>=1.7.4.2)", "matplotlib", "munkres ; platform_python_implementation == \"PyPy\"", "pycairo", "scipy ; platform_python_implementation != \"PyPy\"", "skia-pathops (>=0.5.0)", "sympy", "uharfbuzz (>=0.23.0)", "unicodedata2 (>=15.1.0) ; python_version <= \"3.12\"", "xattr 
; sys_platform == \"darwin\"", "zopfli (>=0.1.4)"] graphite = ["lz4 (>=1.7.4.2)"] -interpolatable = ["munkres", "pycairo", "scipy"] +interpolatable = ["munkres ; platform_python_implementation == \"PyPy\"", "pycairo", "scipy ; platform_python_implementation != \"PyPy\""] lxml = ["lxml (>=4.0)"] pathops = ["skia-pathops (>=0.5.0)"] plot = ["matplotlib"] repacker = ["uharfbuzz (>=0.23.0)"] symfont = ["sympy"] -type1 = ["xattr"] +type1 = ["xattr ; sys_platform == \"darwin\""] ufo = ["fs (>=2.2.0,<3)"] -unicode = ["unicodedata2 (>=15.1.0)"] -woff = ["brotli (>=1.0.1)", "brotlicffi (>=0.8.0)", "zopfli (>=0.1.4)"] +unicode = ["unicodedata2 (>=15.1.0) ; python_version <= \"3.12\""] +woff = ["brotli (>=1.0.1) ; platform_python_implementation == \"CPython\"", "brotlicffi (>=0.8.0) ; platform_python_implementation != \"CPython\"", "zopfli (>=0.1.4)"] [[package]] name = "fqdn" @@ -1362,6 +1422,7 @@ version = "1.5.1" description = "Validates fully-qualified domain names against RFC 1123, so that they are acceptable to modern bowsers" optional = false python-versions = ">=2.7, !=3.0, !=3.1, !=3.2, !=3.3, !=3.4, <4" +groups = ["dev"] files = [ {file = "fqdn-1.5.1-py3-none-any.whl", hash = "sha256:3a179af3761e4df6eb2e026ff9e1a3033d3587bf980a0b1b2e1e5d08d7358014"}, {file = "fqdn-1.5.1.tar.gz", hash = "sha256:105ed3677e767fb5ca086a0c1f4bb66ebc3c100be518f0e0d755d9eae164d89f"}, @@ -1373,6 +1434,8 @@ version = "1.4.0" description = "A list-like structure which implements collections.abc.MutableSequence" optional = false python-versions = ">=3.8" +groups = ["dev"] +markers = "sys_platform != \"win32\"" files = [ {file = "frozenlist-1.4.0-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:764226ceef3125e53ea2cb275000e309c0aa5464d43bd72abd661e27fffc26ab"}, {file = "frozenlist-1.4.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:d6484756b12f40003c6128bfcc3fa9f0d49a687e171186c2d85ec82e3758c559"}, @@ -1443,6 +1506,7 @@ version = "2023.10.0" description = "File-system 
specification" optional = false python-versions = ">=3.8" +groups = ["main"] files = [ {file = "fsspec-2023.10.0-py3-none-any.whl", hash = "sha256:346a8f024efeb749d2a5fca7ba8854474b1ff9af7c3faaf636a4548781136529"}, {file = "fsspec-2023.10.0.tar.gz", hash = "sha256:330c66757591df346ad3091a53bd907e15348c2ba17d63fd54f5c39c4457d2a5"}, @@ -1478,6 +1542,7 @@ version = "4.0.11" description = "Git Object Database" optional = false python-versions = ">=3.7" +groups = ["dev"] files = [ {file = "gitdb-4.0.11-py3-none-any.whl", hash = "sha256:81a3407ddd2ee8df444cbacea00e2d038e40150acfa3001696fe0dcf1d3adfa4"}, {file = "gitdb-4.0.11.tar.gz", hash = "sha256:bf5421126136d6d0af55bc1e7c1af1c397a34f5b7bd79e776cd3e89785c2b04b"}, @@ -1492,6 +1557,7 @@ version = "3.1.41" description = "GitPython is a Python library used to interact with Git repositories" optional = false python-versions = ">=3.7" +groups = ["dev"] files = [ {file = "GitPython-3.1.41-py3-none-any.whl", hash = "sha256:c36b6634d069b3f719610175020a9aed919421c87552185b085e04fbbdb10b7c"}, {file = "GitPython-3.1.41.tar.gz", hash = "sha256:ed66e624884f76df22c8e16066d567aaa5a37d5b5fa19db2c6df6f7156db9048"}, @@ -1501,7 +1567,7 @@ files = [ gitdb = ">=4.0.1,<5" [package.extras] -test = ["black", "coverage[toml]", "ddt (>=1.1.1,!=1.4.3)", "mock", "mypy", "pre-commit", "pytest (>=7.3.1)", "pytest-cov", "pytest-instafail", "pytest-mock", "pytest-sugar", "sumtypes"] +test = ["black", "coverage[toml]", "ddt (>=1.1.1,!=1.4.3)", "mock ; python_version < \"3.8\"", "mypy", "pre-commit", "pytest (>=7.3.1)", "pytest-cov", "pytest-instafail", "pytest-mock", "pytest-sugar", "sumtypes"] [[package]] name = "glfw" @@ -1509,6 +1575,8 @@ version = "2.6.5" description = "A ctypes-based wrapper for GLFW3." 
optional = true python-versions = "*" +groups = ["main"] +markers = "extra == \"robotics\" or extra == \"mujoco\"" files = [ {file = "glfw-2.6.5-py2.py27.py3.py30.py31.py32.py33.py34.py35.py36.py37.py38-none-macosx_10_6_intel.whl", hash = "sha256:57d00367f8dc31b898a47ab22849bab9f87dff4b4c7a56d16d9a7158cda96c19"}, {file = "glfw-2.6.5-py2.py27.py3.py30.py31.py32.py33.py34.py35.py36.py37.py38-none-macosx_11_0_arm64.whl", hash = "sha256:a1a132e7d6f78ae7f32957b56de2fd996d2a416f9520adb40345cc9cf744d277"}, @@ -1530,6 +1598,7 @@ version = "2.23.4" description = "Google Authentication Library" optional = false python-versions = ">=3.7" +groups = ["main"] files = [ {file = "google-auth-2.23.4.tar.gz", hash = "sha256:79905d6b1652187def79d491d6e23d0cbb3a21d3c7ba0dbaa9c8a01906b13ff3"}, {file = "google_auth-2.23.4-py2.py3-none-any.whl", hash = "sha256:d4bbc92fe4b8bfd2f3e8d88e5ba7085935da208ee38a134fc280e7ce682a05f2"}, @@ -1553,6 +1622,7 @@ version = "1.1.0" description = "Google Authentication Library" optional = false python-versions = ">=3.6" +groups = ["main"] files = [ {file = "google-auth-oauthlib-1.1.0.tar.gz", hash = "sha256:83ea8c3b0881e453790baff4448e8a6112ac8778d1de9da0b68010b843937afb"}, {file = "google_auth_oauthlib-1.1.0-py2.py3-none-any.whl", hash = "sha256:089c6e587d36f4803ac7e0720c045c6a8b1fd1790088b8424975b90d0ee61c12"}, @@ -1571,6 +1641,8 @@ version = "3.0.1" description = "Lightweight in-process concurrent programming" optional = false python-versions = ">=3.7" +groups = ["dev"] +markers = "platform_machine == \"aarch64\" or platform_machine == \"ppc64le\" or platform_machine == \"x86_64\" or platform_machine == \"amd64\" or platform_machine == \"AMD64\" or platform_machine == \"win32\" or platform_machine == \"WIN32\"" files = [ {file = "greenlet-3.0.1-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:f89e21afe925fcfa655965ca8ea10f24773a1791400989ff32f467badfe4a064"}, {file = "greenlet-3.0.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", 
hash = "sha256:28e89e232c7593d33cac35425b58950789962011cc274aa43ef8865f2e11f46d"}, @@ -1641,6 +1713,7 @@ version = "1.59.3" description = "HTTP/2-based RPC framework" optional = false python-versions = ">=3.7" +groups = ["main"] files = [ {file = "grpcio-1.59.3-cp310-cp310-linux_armv7l.whl", hash = "sha256:aca028a6c7806e5b61e5f9f4232432c52856f7fcb98e330b20b6bc95d657bdcc"}, {file = "grpcio-1.59.3-cp310-cp310-macosx_12_0_universal2.whl", hash = "sha256:19ad26a7967f7999c8960d2b9fe382dae74c55b0c508c613a6c2ba21cddf2354"}, @@ -1707,6 +1780,8 @@ version = "0.26.2" description = "Gym: A universal API for reinforcement learning environments" optional = true python-versions = ">=3.6" +groups = ["main"] +markers = "sys_platform != \"darwin\" and extra == \"envpool\"" files = [ {file = "gym-0.26.2.tar.gz", hash = "sha256:e0d882f4b54f0c65f203104c24ab8a38b039f1289986803c7d02cdbe214fbcc4"}, ] @@ -1734,6 +1809,8 @@ version = "0.0.8" description = "Notices for gym" optional = true python-versions = "*" +groups = ["main"] +markers = "sys_platform != \"darwin\" and extra == \"envpool\"" files = [ {file = "gym-notices-0.0.8.tar.gz", hash = "sha256:ad25e200487cafa369728625fe064e88ada1346618526102659b4640f2b4b911"}, {file = "gym_notices-0.0.8-py3-none-any.whl", hash = "sha256:e5f82e00823a166747b4c2a07de63b6560b1acb880638547e0cabf825a01e463"}, @@ -1745,6 +1822,7 @@ version = "0.28.1" description = "A standard API for reinforcement learning and a diverse set of reference environments (formerly Gym)." optional = false python-versions = ">=3.7" +groups = ["main"] files = [ {file = "gymnasium-0.28.1-py3-none-any.whl", hash = "sha256:7bc9a5bce1022f997d1dbc152fc91d1ac977bad9cc7794cdc25437010867cabf"}, {file = "gymnasium-0.28.1.tar.gz", hash = "sha256:4c2c745808792c8f45c6e88ad0a5504774394e0c126f6e3db555e720d3da6f24"}, @@ -1776,6 +1854,8 @@ version = "1.2.3" description = "Robotics environments for the Gymnasium repo." 
optional = true python-versions = ">=3.8" +groups = ["main"] +markers = "extra == \"robotics\"" files = [ {file = "gymnasium-robotics-1.2.3.tar.gz", hash = "sha256:b01eb9df74c0041e559e1251442ba1a59174bfc71a1c58519724d76df803c0b6"}, {file = "gymnasium_robotics-1.2.3-py3-none-any.whl", hash = "sha256:9c3cd7bcc7ac7a0efca03d5685a01686661c7fa678e34adfe4e15044580e7180"}, @@ -1799,6 +1879,7 @@ version = "0.14.0" description = "A pure-Python, bring-your-own-I/O implementation of HTTP/1.1" optional = false python-versions = ">=3.7" +groups = ["dev"] files = [ {file = "h11-0.14.0-py3-none-any.whl", hash = "sha256:e3fe4ac4b851c468cc8363d500db52c2ead036020723024a109d37346efaa761"}, {file = "h11-0.14.0.tar.gz", hash = "sha256:8f19fbbe99e72420ff35c00b27a34cb9937e902a8b810e2c88300c6f0a3b699d"}, @@ -1810,6 +1891,7 @@ version = "3.10.0" description = "Read and write HDF5 files from Python" optional = false python-versions = ">=3.8" +groups = ["main"] files = [ {file = "h5py-3.10.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:b963fb772964fc1d1563c57e4e2e874022ce11f75ddc6df1a626f42bd49ab99f"}, {file = "h5py-3.10.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:012ab448590e3c4f5a8dd0f3533255bc57f80629bf7c5054cf4c87b30085063c"}, @@ -1847,6 +1929,7 @@ version = "1.1" description = "HTML parser based on the WHATWG HTML specification" optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*" +groups = ["dev"] files = [ {file = "html5lib-1.1-py2.py3-none-any.whl", hash = "sha256:0d78f8fde1c230e99fe37986a60526d7049ed4bf8a9fadbad5f00e22e58e041d"}, {file = "html5lib-1.1.tar.gz", hash = "sha256:b2e5b40261e20f354d198eae92afc10d750afb487ed5e50f9c4eaf07c184146f"}, @@ -1857,10 +1940,10 @@ six = ">=1.9" webencodings = "*" [package.extras] -all = ["chardet (>=2.2)", "genshi", "lxml"] +all = ["chardet (>=2.2)", "genshi", "lxml ; platform_python_implementation == \"CPython\""] chardet = ["chardet (>=2.2)"] genshi = ["genshi"] -lxml = ["lxml"] +lxml = ["lxml ; 
platform_python_implementation == \"CPython\""] [[package]] name = "httpcore" @@ -1868,6 +1951,7 @@ version = "1.0.5" description = "A minimal low-level HTTP client." optional = false python-versions = ">=3.8" +groups = ["dev"] files = [ {file = "httpcore-1.0.5-py3-none-any.whl", hash = "sha256:421f18bac248b25d310f3cacd198d55b8e6125c107797b609ff9b7a6ba7991b5"}, {file = "httpcore-1.0.5.tar.gz", hash = "sha256:34a38e2f9291467ee3b44e89dd52615370e152954ba21721378a87b2960f7a61"}, @@ -1889,6 +1973,7 @@ version = "0.27.2" description = "The next generation HTTP client." optional = false python-versions = ">=3.8" +groups = ["dev"] files = [ {file = "httpx-0.27.2-py3-none-any.whl", hash = "sha256:7bb2708e112d8fdd7829cd4243970f0c223274051cb35ee80c03301ee29a3df0"}, {file = "httpx-0.27.2.tar.gz", hash = "sha256:f7c2be1d2f3c3c3160d441802406b206c2b76f5947b11115e6df10c6c65e66c2"}, @@ -1902,7 +1987,7 @@ idna = "*" sniffio = "*" [package.extras] -brotli = ["brotli", "brotlicffi"] +brotli = ["brotli ; platform_python_implementation == \"CPython\"", "brotlicffi ; platform_python_implementation != \"CPython\""] cli = ["click (==8.*)", "pygments (==2.*)", "rich (>=10,<14)"] http2 = ["h2 (>=3,<5)"] socks = ["socksio (==1.*)"] @@ -1914,6 +1999,7 @@ version = "2.5.32" description = "File identification library for Python" optional = false python-versions = ">=3.8" +groups = ["dev"] files = [ {file = "identify-2.5.32-py2.py3-none-any.whl", hash = "sha256:0b7656ef6cba81664b783352c73f8c24b39cf82f926f78f4550eda928e5e0545"}, {file = "identify-2.5.32.tar.gz", hash = "sha256:5d9979348ec1a21c768ae07e0a652924538e8bce67313a73cb0f681cf08ba407"}, @@ -1928,6 +2014,7 @@ version = "3.4" description = "Internationalized Domain Names in Applications (IDNA)" optional = false python-versions = ">=3.5" +groups = ["main", "dev"] files = [ {file = "idna-3.4-py3-none-any.whl", hash = "sha256:90b77e79eaa3eba6de819a0c442c0b4ceefc341a7a2ab77d7562bf49f425c5c2"}, {file = "idna-3.4.tar.gz", hash = 
"sha256:814f528e8dead7d329833b91c5faa87d60bf71824cd12a7530b5526063d02cb4"}, @@ -1939,6 +2026,8 @@ version = "2.33.1" description = "Library for reading and writing a wide range of image, video, scientific, and volumetric data formats." optional = true python-versions = ">=3.8" +groups = ["main"] +markers = "extra == \"robotics\" or extra == \"mujoco\"" files = [ {file = "imageio-2.33.1-py3-none-any.whl", hash = "sha256:c5094c48ccf6b2e6da8b4061cd95e1209380afafcbeae4a4e280938cce227e1d"}, {file = "imageio-2.33.1.tar.gz", hash = "sha256:78722d40b137bd98f5ec7312119f8aea9ad2049f76f434748eb306b6937cc1ce"}, @@ -1971,6 +2060,7 @@ version = "1.4.1" description = "Getting image size from png/jpeg/jpeg2000/gif file" optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" +groups = ["dev"] files = [ {file = "imagesize-1.4.1-py2.py3-none-any.whl", hash = "sha256:0d8d18d08f840c19d0ee7ca1fd82490fdc3729b7ac93f49870406ddde8ef8d8b"}, {file = "imagesize-1.4.1.tar.gz", hash = "sha256:69150444affb9cb0d5cc5a92b3676f0b2fb7cd9ae39e947a5e11a36b4497cd4a"}, @@ -1982,6 +2072,7 @@ version = "6.8.0" description = "Read metadata from Python packages" optional = false python-versions = ">=3.8" +groups = ["dev"] files = [ {file = "importlib_metadata-6.8.0-py3-none-any.whl", hash = "sha256:3ebb78df84a805d7698245025b975d9d67053cd94c79245ba4b3eb694abe68bb"}, {file = "importlib_metadata-6.8.0.tar.gz", hash = "sha256:dbace7892d8c0c4ac1ad096662232f831d4e64f4c4545bd53016a3e9d4654743"}, @@ -1993,7 +2084,7 @@ zipp = ">=0.5" [package.extras] docs = ["furo", "jaraco.packaging (>=9)", "jaraco.tidelift (>=1.4)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-lint"] perf = ["ipython"] -testing = ["flufl.flake8", "importlib-resources (>=1.3)", "packaging", "pyfakefs", "pytest (>=6)", "pytest-black (>=0.3.7)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-mypy (>=0.9.1)", "pytest-perf (>=0.9.2)", "pytest-ruff"] +testing = ["flufl.flake8", "importlib-resources 
(>=1.3) ; python_version < \"3.9\"", "packaging", "pyfakefs", "pytest (>=6)", "pytest-black (>=0.3.7) ; platform_python_implementation != \"PyPy\"", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-mypy (>=0.9.1) ; platform_python_implementation != \"PyPy\"", "pytest-perf (>=0.9.2)", "pytest-ruff"] [[package]] name = "importlib-resources" @@ -2001,6 +2092,8 @@ version = "6.1.1" description = "Read resources from Python packages" optional = true python-versions = ">=3.8" +groups = ["main"] +markers = "extra == \"atari\"" files = [ {file = "importlib_resources-6.1.1-py3-none-any.whl", hash = "sha256:e8bf90d8213b486f428c9c39714b920041cb02c184686a3dee24905aaa8105d6"}, {file = "importlib_resources-6.1.1.tar.gz", hash = "sha256:3893a00122eafde6894c59914446a512f728a0c1a45f9bb9b63721b6bacf0b4a"}, @@ -2008,7 +2101,7 @@ files = [ [package.extras] docs = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "rst.linker (>=1.9)", "sphinx (<7.2.5)", "sphinx (>=3.5)", "sphinx-lint"] -testing = ["pytest (>=6)", "pytest-black (>=0.3.7)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-mypy (>=0.9.1)", "pytest-ruff", "zipp (>=3.17)"] +testing = ["pytest (>=6)", "pytest-black (>=0.3.7) ; platform_python_implementation != \"PyPy\"", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-mypy (>=0.9.1) ; platform_python_implementation != \"PyPy\"", "pytest-ruff", "zipp (>=3.17)"] [[package]] name = "iniconfig" @@ -2016,6 +2109,7 @@ version = "2.0.0" description = "brain-dead simple config-ini parsing" optional = false python-versions = ">=3.7" +groups = ["dev"] files = [ {file = "iniconfig-2.0.0-py3-none-any.whl", hash = "sha256:b6a85871a79d2e3b22d2d1b94ac2824226a63c6b741c88f7ae975f18b6778374"}, {file = "iniconfig-2.0.0.tar.gz", hash = "sha256:2d91e135bf72d31a410b17c16da610a82cb55f6b0477d1a902134b24a455b8b3"}, @@ -2027,6 +2121,7 @@ version = "6.26.0" description = "IPython Kernel for Jupyter" 
optional = false python-versions = ">=3.8" +groups = ["dev"] files = [ {file = "ipykernel-6.26.0-py3-none-any.whl", hash = "sha256:3ba3dc97424b87b31bb46586b5167b3161b32d7820b9201a9e698c71e271602c"}, {file = "ipykernel-6.26.0.tar.gz", hash = "sha256:553856658eb8430bbe9653ea041a41bff63e9606fc4628873fc92a6cf3abd404"}, @@ -2060,6 +2155,7 @@ version = "8.17.2" description = "IPython: Productive Interactive Computing" optional = false python-versions = ">=3.9" +groups = ["dev"] files = [ {file = "ipython-8.17.2-py3-none-any.whl", hash = "sha256:1e4d1d666a023e3c93585ba0d8e962867f7a111af322efff6b9c58062b3e5444"}, {file = "ipython-8.17.2.tar.gz", hash = "sha256:126bb57e1895594bb0d91ea3090bbd39384f6fe87c3d57fd558d0670f50339bb"}, @@ -2096,6 +2192,7 @@ version = "8.1.1" description = "Jupyter interactive widgets" optional = false python-versions = ">=3.7" +groups = ["dev"] files = [ {file = "ipywidgets-8.1.1-py3-none-any.whl", hash = "sha256:2b88d728656aea3bbfd05d32c747cfd0078f9d7e159cf982433b58ad717eed7f"}, {file = "ipywidgets-8.1.1.tar.gz", hash = "sha256:40211efb556adec6fa450ccc2a77d59ca44a060f4f9f136833df59c9f538e6e8"}, @@ -2117,6 +2214,7 @@ version = "20.11.0" description = "Operations with ISO 8601 durations" optional = false python-versions = ">=3.7" +groups = ["dev"] files = [ {file = "isoduration-20.11.0-py3-none-any.whl", hash = "sha256:b2904c2a4228c3d44f409c8ae8e2370eb21a26f7ac2ec5446df141dde3452042"}, {file = "isoduration-20.11.0.tar.gz", hash = "sha256:ac2f9015137935279eac671f94f89eb00584f940f5dc49462a0c4ee692ba1bd9"}, @@ -2131,6 +2229,7 @@ version = "1.0.0" description = "Common backend for Jax or Numpy." 
optional = false python-versions = ">=3.7" +groups = ["main"] files = [ {file = "jax-jumpy-1.0.0.tar.gz", hash = "sha256:195fb955cc4c2b7f0b1453e3cb1fb1c414a51a407ffac7a51e69a73cb30d59ad"}, {file = "jax_jumpy-1.0.0-py3-none-any.whl", hash = "sha256:ab7e01454bba462de3c4d098e3e585c302a8f06bc36d9182ab4e7e4aa7067c5e"}, @@ -2149,6 +2248,7 @@ version = "0.19.1" description = "An autocompletion tool for Python that can be used for text editors." optional = false python-versions = ">=3.6" +groups = ["dev"] files = [ {file = "jedi-0.19.1-py2.py3-none-any.whl", hash = "sha256:e983c654fe5c02867aef4cdfce5a2fbb4a50adc0af145f70504238f18ef5e7e0"}, {file = "jedi-0.19.1.tar.gz", hash = "sha256:cf0496f3651bc65d7174ac1b7d043eff454892c708a87d1b683e57b569927ffd"}, @@ -2168,6 +2268,7 @@ version = "3.1.4" description = "A very fast and expressive template engine." optional = false python-versions = ">=3.7" +groups = ["main", "dev"] files = [ {file = "jinja2-3.1.4-py3-none-any.whl", hash = "sha256:bc5dd2abb727a5319567b7a813e6a2e7318c39f4f487cfe6c89c6f9c7d25197d"}, {file = "jinja2-3.1.4.tar.gz", hash = "sha256:4a3aee7acbbe7303aede8e9648d13b8bf88a429282aa6122a993f0ac800cb369"}, @@ -2185,6 +2286,8 @@ version = "1.4.0" description = "Lightweight pipelining with Python functions" optional = true python-versions = ">=3.8" +groups = ["main"] +markers = "extra == \"eval\"" files = [ {file = "joblib-1.4.0-py3-none-any.whl", hash = "sha256:42942470d4062537be4d54c83511186da1fc14ba354961a2114da91efa9a4ed7"}, {file = "joblib-1.4.0.tar.gz", hash = "sha256:1eb0dc091919cd384490de890cb5dfd538410a6d4b3b54eef09fb8c50b409b1c"}, @@ -2196,6 +2299,7 @@ version = "0.9.14" description = "A Python implementation of the JSON5 data format." 
optional = false python-versions = "*" +groups = ["dev"] files = [ {file = "json5-0.9.14-py2.py3-none-any.whl", hash = "sha256:740c7f1b9e584a468dbb2939d8d458db3427f2c93ae2139d05f47e453eae964f"}, {file = "json5-0.9.14.tar.gz", hash = "sha256:9ed66c3a6ca3510a976a9ef9b8c0787de24802724ab1860bc0153c7fdd589b02"}, @@ -2210,6 +2314,8 @@ version = "4.27.0" description = "Implement minimal boilerplate CLIs derived from type hints and parse from command line, config files and environment variables." optional = true python-versions = ">=3.7" +groups = ["main"] +markers = "extra == \"argparse\" or extra == \"eval\"" files = [ {file = "jsonargparse-4.27.0-py3-none-any.whl", hash = "sha256:a6378bc8b7bbe38b708f090b10ea8431216e71f8b2eea1f9a4f095ae4abd0f2e"}, {file = "jsonargparse-4.27.0.tar.gz", hash = "sha256:6ac791cd7913cff34ad2dbd3ed0431f9e327af0be926332ac060bd5b13d353f2"}, @@ -2225,7 +2331,7 @@ coverage = ["jsonargparse[test-no-urls]", "pytest-cov (>=4.0.0)"] dev = ["build (>=0.10.0)", "jsonargparse[coverage]", "jsonargparse[doc]", "jsonargparse[mypy]", "jsonargparse[test]", "pre-commit (>=2.19.0)", "tox (>=3.25.0)"] doc = ["Sphinx (>=1.7.9)", "autodocsumm (>=0.1.10)", "sphinx-autodoc-typehints (>=1.19.5)", "sphinx-rtd-theme (>=1.2.2)"] fsspec = ["fsspec (>=0.8.4)"] -jsonnet = ["jsonnet (>=0.13.0)", "jsonnet-binary (>=0.17.0)"] +jsonnet = ["jsonnet (>=0.13.0) ; os_name == \"posix\"", "jsonnet-binary (>=0.17.0) ; os_name != \"posix\""] jsonschema = ["jsonschema (>=3.2.0)"] maintainer = ["bump2version (>=0.5.11)", "twine (>=4.0.2)"] omegaconf = ["omegaconf (>=2.1.1)"] @@ -2234,7 +2340,7 @@ ruyaml = ["ruyaml (>=0.20.0)"] signatures = ["docstring-parser (>=0.15)", "jsonargparse[typing-extensions]", "typeshed-client (>=2.1.0)"] test = ["attrs (>=22.2.0)", "jsonargparse[test-no-urls]", "pydantic (>=2.3.0)", "responses (>=0.12.0)", "types-PyYAML (>=6.0.11)", "types-requests (>=2.28.9)"] test-no-urls = ["pytest (>=6.2.5)", "pytest-subtests (>=0.8.0)"] -typing-extensions = 
["typing-extensions (>=3.10.0.0)"] +typing-extensions = ["typing-extensions (>=3.10.0.0) ; python_version < \"3.10\""] urls = ["requests (>=2.18.4)"] [[package]] @@ -2243,6 +2349,7 @@ version = "2.4" description = "Identify specific nodes in a JSON document (RFC 6901)" optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*, !=3.5.*, !=3.6.*" +groups = ["dev"] files = [ {file = "jsonpointer-2.4-py2.py3-none-any.whl", hash = "sha256:15d51bba20eea3165644553647711d150376234112651b4f1811022aecad7d7a"}, {file = "jsonpointer-2.4.tar.gz", hash = "sha256:585cee82b70211fa9e6043b7bb89db6e1aa49524340dde8ad6b63206ea689d88"}, @@ -2254,6 +2361,7 @@ version = "4.20.0" description = "An implementation of JSON Schema validation for Python" optional = false python-versions = ">=3.8" +groups = ["dev"] files = [ {file = "jsonschema-4.20.0-py3-none-any.whl", hash = "sha256:ed6231f0429ecf966f5bc8dfef245998220549cbbcf140f913b7464c52c3b6b3"}, {file = "jsonschema-4.20.0.tar.gz", hash = "sha256:4f614fd46d8d61258610998997743ec5492a648b33cf478c1ddc23ed4598a5fa"}, @@ -2283,6 +2391,7 @@ version = "2023.11.1" description = "The JSON Schema meta-schemas and vocabularies, exposed as a Registry" optional = false python-versions = ">=3.8" +groups = ["dev"] files = [ {file = "jsonschema_specifications-2023.11.1-py3-none-any.whl", hash = "sha256:f596778ab612b3fd29f72ea0d990393d0540a5aab18bf0407a46632eab540779"}, {file = "jsonschema_specifications-2023.11.1.tar.gz", hash = "sha256:c9b234904ffe02f079bf91b14d79987faa685fd4b39c377a0996954c0090b9ca"}, @@ -2297,6 +2406,7 @@ version = "1.0.0" description = "Jupyter metapackage. Install all the Jupyter components in one go." 
optional = false python-versions = "*" +groups = ["dev"] files = [ {file = "jupyter-1.0.0-py2.py3-none-any.whl", hash = "sha256:5b290f93b98ffbc21c0c7e749f054b3267782166d72fa5e3ed1ed4eaf34a2b78"}, {file = "jupyter-1.0.0.tar.gz", hash = "sha256:d9dc4b3318f310e34c82951ea5d6683f67bed7def4b259fafbfe4f1beb1d8e5f"}, @@ -2317,6 +2427,7 @@ version = "1.0.0" description = "Build a book with Jupyter Notebooks and Sphinx." optional = false python-versions = ">=3.9" +groups = ["dev"] files = [ {file = "jupyter_book-1.0.0-py3-none-any.whl", hash = "sha256:18238f1e7e1d425731e60ab509a7da878dd6db88b7d77bcfab4690361b72e1be"}, {file = "jupyter_book-1.0.0.tar.gz", hash = "sha256:539c5d0493546200d9de27bd4b5f77eaea03115f8937f825d4ff82b3801a987e"}, @@ -2354,6 +2465,7 @@ version = "0.6.1" description = "A defined interface for working with a cache of jupyter notebooks." optional = false python-versions = "~=3.8" +groups = ["dev"] files = [ {file = "jupyter-cache-0.6.1.tar.gz", hash = "sha256:26f83901143edf4af2f3ff5a91e2d2ad298e46e2cee03c8071d37a23a63ccbfc"}, {file = "jupyter_cache-0.6.1-py3-none-any.whl", hash = "sha256:2fce7d4975805c77f75bdfc1bc2e82bc538b8e5b1af27f2f5e06d55b9f996a82"}, @@ -2381,6 +2493,7 @@ version = "8.6.0" description = "Jupyter protocol implementation and client libraries" optional = false python-versions = ">=3.8" +groups = ["dev"] files = [ {file = "jupyter_client-8.6.0-py3-none-any.whl", hash = "sha256:909c474dbe62582ae62b758bca86d6518c85234bdee2d908c778db6d72f39d99"}, {file = "jupyter_client-8.6.0.tar.gz", hash = "sha256:0642244bb83b4764ae60d07e010e15f0e2d275ec4e918a8f7b80fbbef3ca60c7"}, @@ -2395,7 +2508,7 @@ traitlets = ">=5.3" [package.extras] docs = ["ipykernel", "myst-parser", "pydata-sphinx-theme", "sphinx (>=4)", "sphinx-autodoc-typehints", "sphinxcontrib-github-alt", "sphinxcontrib-spelling"] -test = ["coverage", "ipykernel (>=6.14)", "mypy", "paramiko", "pre-commit", "pytest", "pytest-cov", "pytest-jupyter[client] (>=0.4.1)", "pytest-timeout"] +test = 
["coverage", "ipykernel (>=6.14)", "mypy", "paramiko ; sys_platform == \"win32\"", "pre-commit", "pytest", "pytest-cov", "pytest-jupyter[client] (>=0.4.1)", "pytest-timeout"] [[package]] name = "jupyter-console" @@ -2403,6 +2516,7 @@ version = "6.6.3" description = "Jupyter terminal console" optional = false python-versions = ">=3.7" +groups = ["dev"] files = [ {file = "jupyter_console-6.6.3-py3-none-any.whl", hash = "sha256:309d33409fcc92ffdad25f0bcdf9a4a9daa61b6f341177570fdac03de5352485"}, {file = "jupyter_console-6.6.3.tar.gz", hash = "sha256:566a4bf31c87adbfadf22cdf846e3069b59a71ed5da71d6ba4d8aaad14a53539"}, @@ -2427,6 +2541,7 @@ version = "5.5.0" description = "Jupyter core package. A base package on which Jupyter projects rely." optional = false python-versions = ">=3.8" +groups = ["dev"] files = [ {file = "jupyter_core-5.5.0-py3-none-any.whl", hash = "sha256:e11e02cd8ae0a9de5c6c44abf5727df9f2581055afe00b22183f621ba3585805"}, {file = "jupyter_core-5.5.0.tar.gz", hash = "sha256:880b86053bf298a8724994f95e99b99130659022a4f7f45f563084b6223861d3"}, @@ -2447,6 +2562,7 @@ version = "0.9.0" description = "Jupyter Event System library" optional = false python-versions = ">=3.8" +groups = ["dev"] files = [ {file = "jupyter_events-0.9.0-py3-none-any.whl", hash = "sha256:d853b3c10273ff9bc8bb8b30076d65e2c9685579db736873de6c2232dde148bf"}, {file = "jupyter_events-0.9.0.tar.gz", hash = "sha256:81ad2e4bc710881ec274d31c6c50669d71bbaa5dd9d01e600b56faa85700d399"}, @@ -2472,6 +2588,7 @@ version = "2.2.2" description = "Multi-Language Server WebSocket proxy for Jupyter Notebook/Lab server" optional = false python-versions = ">=3.8" +groups = ["dev"] files = [ {file = "jupyter-lsp-2.2.2.tar.gz", hash = "sha256:256d24620542ae4bba04a50fc1f6ffe208093a07d8e697fea0a8d1b8ca1b7e5b"}, {file = "jupyter_lsp-2.2.2-py3-none-any.whl", hash = "sha256:3b95229e4168355a8c91928057c1621ac3510ba98b2a925e82ebd77f078b1aa5"}, @@ -2486,6 +2603,7 @@ version = "2.11.2" description = "The backend—i.e. 
core services, APIs, and REST endpoints—to Jupyter web applications." optional = false python-versions = ">=3.8" +groups = ["dev"] files = [ {file = "jupyter_server-2.11.2-py3-none-any.whl", hash = "sha256:0c548151b54bcb516ca466ec628f7f021545be137d01b5467877e87f6fff4374"}, {file = "jupyter_server-2.11.2.tar.gz", hash = "sha256:0c99f9367b0f24141e527544522430176613f9249849be80504c6d2b955004bb"}, @@ -2522,6 +2640,7 @@ version = "0.4.4" description = "A Jupyter Server Extension Providing Terminals." optional = false python-versions = ">=3.8" +groups = ["dev"] files = [ {file = "jupyter_server_terminals-0.4.4-py3-none-any.whl", hash = "sha256:75779164661cec02a8758a5311e18bb8eb70c4e86c6b699403100f1585a12a36"}, {file = "jupyter_server_terminals-0.4.4.tar.gz", hash = "sha256:57ab779797c25a7ba68e97bcfb5d7740f2b5e8a83b5e8102b10438041a7eac5d"}, @@ -2541,6 +2660,7 @@ version = "4.2.5" description = "JupyterLab computational environment" optional = false python-versions = ">=3.8" +groups = ["dev"] files = [ {file = "jupyterlab-4.2.5-py3-none-any.whl", hash = "sha256:73b6e0775d41a9fee7ee756c80f58a6bed4040869ccc21411dc559818874d321"}, {file = "jupyterlab-4.2.5.tar.gz", hash = "sha256:ae7f3a1b8cb88b4f55009ce79fa7c06f99d70cd63601ee4aa91815d054f46f75"}, @@ -2574,6 +2694,7 @@ version = "0.2.2" description = "Pygments theme using JupyterLab CSS variables" optional = false python-versions = ">=3.7" +groups = ["dev"] files = [ {file = "jupyterlab_pygments-0.2.2-py2.py3-none-any.whl", hash = "sha256:2405800db07c9f770863bcf8049a529c3dd4d3e28536638bd7c1c01d2748309f"}, {file = "jupyterlab_pygments-0.2.2.tar.gz", hash = "sha256:7405d7fde60819d905a9fa8ce89e4cd830e318cdad22a0030f7a901da705585d"}, @@ -2585,6 +2706,7 @@ version = "2.27.3" description = "A set of server components for JupyterLab and JupyterLab like applications." 
optional = false python-versions = ">=3.8" +groups = ["dev"] files = [ {file = "jupyterlab_server-2.27.3-py3-none-any.whl", hash = "sha256:e697488f66c3db49df675158a77b3b017520d772c6e1548c7d9bcc5df7944ee4"}, {file = "jupyterlab_server-2.27.3.tar.gz", hash = "sha256:eb36caca59e74471988f0ae25c77945610b887f777255aa21f8065def9e51ed4"}, @@ -2610,6 +2732,7 @@ version = "3.0.9" description = "Jupyter interactive widgets for JupyterLab" optional = false python-versions = ">=3.7" +groups = ["dev"] files = [ {file = "jupyterlab_widgets-3.0.9-py3-none-any.whl", hash = "sha256:3cf5bdf5b897bf3bccf1c11873aa4afd776d7430200f765e0686bd352487b58d"}, {file = "jupyterlab_widgets-3.0.9.tar.gz", hash = "sha256:6005a4e974c7beee84060fdfba341a3218495046de8ae3ec64888e5fe19fdb4c"}, @@ -2621,6 +2744,7 @@ version = "1.4.5" description = "A fast implementation of the Cassowary constraint solver" optional = false python-versions = ">=3.7" +groups = ["main"] files = [ {file = "kiwisolver-1.4.5-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:05703cf211d585109fcd72207a31bb170a0f22144d68298dc5e61b3c946518af"}, {file = "kiwisolver-1.4.5-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:146d14bebb7f1dc4d5fbf74f8a6cb15ac42baadee8912eb84ac0b3b2a3dc6ac3"}, @@ -2734,6 +2858,7 @@ version = "2.0.1" description = "A lexer and codec to work with LaTeX code in Python." optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" +groups = ["dev"] files = [ {file = "latexcodec-2.0.1-py2.py3-none-any.whl", hash = "sha256:c277a193638dc7683c4c30f6684e3db728a06efb0dc9cf346db8bd0aa6c5d271"}, {file = "latexcodec-2.0.1.tar.gz", hash = "sha256:2aa2551c373261cefe2ad3a8953a6d6533e68238d180eb4bb91d7964adb3fe9a"}, @@ -2748,6 +2873,7 @@ version = "2.0.2" description = "Links recognition library with FULL unicode support." 
optional = false python-versions = ">=3.7" +groups = ["dev"] files = [ {file = "linkify-it-py-2.0.2.tar.gz", hash = "sha256:19f3060727842c254c808e99d465c80c49d2c7306788140987a1a7a29b0d6ad2"}, {file = "linkify_it_py-2.0.2-py3-none-any.whl", hash = "sha256:a3a24428f6c96f27370d7fe61d2ac0be09017be5190d68d8658233171f1b6541"}, @@ -2768,6 +2894,7 @@ version = "0.43.0" description = "lightweight wrapper around basic LLVM functionality" optional = false python-versions = ">=3.9" +groups = ["main"] files = [ {file = "llvmlite-0.43.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:a289af9a1687c6cf463478f0fa8e8aa3b6fb813317b0d70bf1ed0759eab6f761"}, {file = "llvmlite-0.43.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:6d4fd101f571a31acb1559ae1af30f30b1dc4b3186669f92ad780e17c81e91bc"}, @@ -2798,6 +2925,7 @@ version = "3.5.1" description = "Python implementation of John Gruber's Markdown." optional = false python-versions = ">=3.8" +groups = ["main"] files = [ {file = "Markdown-3.5.1-py3-none-any.whl", hash = "sha256:5874b47d4ee3f0b14d764324d2c94c03ea66bee56f2d929da9f2508d65e722dc"}, {file = "Markdown-3.5.1.tar.gz", hash = "sha256:b65d7beb248dc22f2e8a31fb706d93798093c308dc1aba295aedeb9d41a813bd"}, @@ -2813,6 +2941,7 @@ version = "3.0.0" description = "Python port of markdown-it. Markdown parsing, done right!" optional = false python-versions = ">=3.8" +groups = ["dev"] files = [ {file = "markdown-it-py-3.0.0.tar.gz", hash = "sha256:e3f60a94fa066dc52ec76661e37c851cb232d92f9886b15cb560aaada2df8feb"}, {file = "markdown_it_py-3.0.0-py3-none-any.whl", hash = "sha256:355216845c60bd96232cd8d8c40e8f9765cc86f46880e43a8fd22dc1a1a8cab1"}, @@ -2837,6 +2966,7 @@ version = "2.1.3" description = "Safely add untrusted strings to HTML/XML markup." 
optional = false python-versions = ">=3.7" +groups = ["main", "dev"] files = [ {file = "MarkupSafe-2.1.3-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:cd0f502fe016460680cd20aaa5a76d241d6f35a1c3350c474bac1273803893fa"}, {file = "MarkupSafe-2.1.3-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:e09031c87a1e51556fdcb46e5bd4f59dfb743061cf93c4d6831bf894f125eb57"}, @@ -2906,6 +3036,7 @@ version = "3.8.4" description = "Python plotting package" optional = false python-versions = ">=3.9" +groups = ["main"] files = [ {file = "matplotlib-3.8.4-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:abc9d838f93583650c35eca41cfcec65b2e7cb50fd486da6f0c49b5e1ed23014"}, {file = "matplotlib-3.8.4-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:8f65c9f002d281a6e904976007b2d46a1ee2bcea3a68a8c12dda24709ddc9106"}, @@ -2954,6 +3085,7 @@ version = "0.1.6" description = "Inline Matplotlib backend for Jupyter" optional = false python-versions = ">=3.5" +groups = ["dev"] files = [ {file = "matplotlib-inline-0.1.6.tar.gz", hash = "sha256:f887e5f10ba98e8d2b150ddcf4702c1e5f8b3a20005eb0f74bfdbd360ee6f304"}, {file = "matplotlib_inline-0.1.6-py3-none-any.whl", hash = "sha256:f1f41aab5328aa5aaea9b16d083b128102f8712542f819fe7e6a420ff581b311"}, @@ -2968,6 +3100,7 @@ version = "0.4.0" description = "Collection of plugins for markdown-it-py" optional = false python-versions = ">=3.8" +groups = ["dev"] files = [ {file = "mdit_py_plugins-0.4.0-py3-none-any.whl", hash = "sha256:b51b3bb70691f57f974e257e367107857a93b36f322a9e6d44ca5bf28ec2def9"}, {file = "mdit_py_plugins-0.4.0.tar.gz", hash = "sha256:d8ab27e9aed6c38aa716819fedfde15ca275715955f8a185a8e1cf90fb1d2c1b"}, @@ -2987,6 +3120,7 @@ version = "0.1.2" description = "Markdown URL utilities" optional = false python-versions = ">=3.7" +groups = ["dev"] files = [ {file = "mdurl-0.1.2-py3-none-any.whl", hash = "sha256:84008a41e51615a49fc9966191ff91509e3c40b939176e643fd50a5c2196b8f8"}, {file = "mdurl-0.1.2.tar.gz", hash = 
"sha256:bb413d29f5eea38f31dd4754dd7377d4465116fb207585f97bf925588687c1ba"}, @@ -2998,6 +3132,7 @@ version = "3.0.2" description = "A sane and fast Markdown parser with useful plugins and renderers" optional = false python-versions = ">=3.7" +groups = ["dev"] files = [ {file = "mistune-3.0.2-py3-none-any.whl", hash = "sha256:71481854c30fdbc938963d3605b72501f5c10a9320ecd412c121c163a1c7d205"}, {file = "mistune-3.0.2.tar.gz", hash = "sha256:fc7f93ded930c92394ef2cb6f04a8aabab4117a91449e72dcc8dfa646a508be8"}, @@ -3009,6 +3144,7 @@ version = "1.3.0" description = "Python library for arbitrary-precision floating-point arithmetic" optional = false python-versions = "*" +groups = ["main"] files = [ {file = "mpmath-1.3.0-py3-none-any.whl", hash = "sha256:a0b2b9fe80bbcd81a6647ff13108738cfb482d481d826cc0e02f5b35e5c88d2c"}, {file = "mpmath-1.3.0.tar.gz", hash = "sha256:7a28eb2a9774d00c7bc92411c19a89209d5da7c4c9a9e227be8330a23a25b91f"}, @@ -3017,7 +3153,7 @@ files = [ [package.extras] develop = ["codecov", "pycodestyle", "pytest (>=4.6)", "pytest-cov", "wheel"] docs = ["sphinx"] -gmpy = ["gmpy2 (>=2.1.0a4)"] +gmpy = ["gmpy2 (>=2.1.0a4) ; platform_python_implementation != \"PyPy\""] tests = ["pytest (>=4.6)"] [[package]] @@ -3026,6 +3162,7 @@ version = "1.0.7" description = "MessagePack serializer" optional = false python-versions = ">=3.8" +groups = ["dev"] files = [ {file = "msgpack-1.0.7-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:04ad6069c86e531682f9e1e71b71c1c3937d6014a7c3e9edd2aa81ad58842862"}, {file = "msgpack-1.0.7-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:cca1b62fe70d761a282496b96a5e51c44c213e410a964bdffe0928e611368329"}, @@ -3091,6 +3228,8 @@ version = "2.3.7" description = "MuJoCo Physics Simulator" optional = true python-versions = ">=3.8" +groups = ["main"] +markers = "extra == \"robotics\" or extra == \"mujoco\"" files = [ {file = "mujoco-2.3.7-cp310-cp310-macosx_10_16_x86_64.whl", hash = 
"sha256:e8714a5ff6a1561b364b7b4648d4c0c8d13e751874cf7401c309b9d23fa9598b"}, {file = "mujoco-2.3.7-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:a934315f858a4e0c4b90a682fde519471cfdd7baa64435179da8cd20d4ae3f99"}, @@ -3125,31 +3264,13 @@ glfw = "*" numpy = "*" pyopengl = "*" -[[package]] -name = "mujoco-py" -version = "2.1.2.14" -description = "" -optional = true -python-versions = ">=3.6" -files = [ - {file = "mujoco-py-2.1.2.14.tar.gz", hash = "sha256:eb5b14485acf80a3cf8c15f4b080c6a28a9f79e68869aa696d16cbd51ea7706f"}, - {file = "mujoco_py-2.1.2.14-py3-none-any.whl", hash = "sha256:37c0b41bc0153a8a0eb3663103a67c60f65467753f74e4ff6e68b879f3e3a71f"}, -] - -[package.dependencies] -cffi = ">=1.10" -Cython = ">=0.27.2" -fasteners = ">=0.15,<1.0" -glfw = ">=1.4.0" -imageio = ">=2.1.2" -numpy = ">=1.11" - [[package]] name = "mypy" version = "1.7.0" description = "Optional static typing for Python" optional = false python-versions = ">=3.8" +groups = ["dev"] files = [ {file = "mypy-1.7.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:5da84d7bf257fd8f66b4f759a904fd2c5a765f70d8b52dde62b521972a0a2357"}, {file = "mypy-1.7.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:a3637c03f4025f6405737570d6cbfa4f1400eb3c649317634d273687a09ffc2f"}, @@ -3196,6 +3317,7 @@ version = "1.0.0" description = "Type system extensions for programs checked with the mypy type checker." optional = false python-versions = ">=3.5" +groups = ["dev"] files = [ {file = "mypy_extensions-1.0.0-py3-none-any.whl", hash = "sha256:4392f6c0eb8a5668a69e23d168ffa70f0be9ccfd32b5cc2d26a34ae5b844552d"}, {file = "mypy_extensions-1.0.0.tar.gz", hash = "sha256:75dbf8955dc00442a438fc4d0666508a9a97b6bd41aa2f0ffe9d2f2725af0782"}, @@ -3207,6 +3329,7 @@ version = "1.0.0" description = "A Jupyter Notebook Sphinx reader built on top of the MyST markdown parser." 
optional = false python-versions = ">=3.9" +groups = ["dev"] files = [ {file = "myst_nb-1.0.0-py3-none-any.whl", hash = "sha256:ee8febc6dd7d9e32bede0c66a9b962b2e2fdab697428ee9fbfd4919d82380911"}, {file = "myst_nb-1.0.0.tar.gz", hash = "sha256:9077e42a1c6b441ea55078506f83555dda5d6c816ef4930841d71d239e3e0c5e"}, @@ -3235,6 +3358,7 @@ version = "2.0.0" description = "An extended [CommonMark](https://spec.commonmark.org/) compliant parser," optional = false python-versions = ">=3.8" +groups = ["dev"] files = [ {file = "myst_parser-2.0.0-py3-none-any.whl", hash = "sha256:7c36344ae39c8e740dad7fdabf5aa6fc4897a813083c6cc9990044eb93656b14"}, {file = "myst_parser-2.0.0.tar.gz", hash = "sha256:ea929a67a6a0b1683cdbe19b8d2e724cd7643f8aa3e7bb18dd65beac3483bead"}, @@ -3261,6 +3385,7 @@ version = "8.4.0" description = "Simple yet flexible natural sorting in Python." optional = false python-versions = ">=3.7" +groups = ["dev"] files = [ {file = "natsort-8.4.0-py3-none-any.whl", hash = "sha256:4732914fb471f56b5cce04d7bae6f164a592c7712e1c85f9ef585e197299521c"}, {file = "natsort-8.4.0.tar.gz", hash = "sha256:45312c4a0e5507593da193dedd04abb1469253b601ecaf63445ad80f0a1ea581"}, @@ -3276,6 +3401,7 @@ version = "0.7.4" description = "A client library for executing notebooks. Formerly nbconvert's ExecutePreprocessor." 
optional = false python-versions = ">=3.7.0" +groups = ["dev"] files = [ {file = "nbclient-0.7.4-py3-none-any.whl", hash = "sha256:c817c0768c5ff0d60e468e017613e6eae27b6fa31e43f905addd2d24df60c125"}, {file = "nbclient-0.7.4.tar.gz", hash = "sha256:d447f0e5a4cfe79d462459aec1b3dc5c2e9152597262be8ee27f7d4c02566a0d"}, @@ -3298,6 +3424,7 @@ version = "7.11.0" description = "Converting Jupyter Notebooks" optional = false python-versions = ">=3.8" +groups = ["dev"] files = [ {file = "nbconvert-7.11.0-py3-none-any.whl", hash = "sha256:d1d417b7f34a4e38887f8da5bdfd12372adf3b80f995d57556cb0972c68909fe"}, {file = "nbconvert-7.11.0.tar.gz", hash = "sha256:abedc01cf543177ffde0bfc2a69726d5a478f6af10a332fc1bf29fcb4f0cf000"}, @@ -3335,6 +3462,7 @@ version = "5.9.2" description = "The Jupyter Notebook format" optional = false python-versions = ">=3.8" +groups = ["dev"] files = [ {file = "nbformat-5.9.2-py3-none-any.whl", hash = "sha256:1c5172d786a41b82bcfd0c23f9e6b6f072e8fb49c39250219e4acfff1efe89e9"}, {file = "nbformat-5.9.2.tar.gz", hash = "sha256:5f98b5ba1997dff175e77e0c17d5c10a96eaed2cbd1de3533d1fc35d5e111192"}, @@ -3356,6 +3484,7 @@ version = "1.7.1" description = "Run any standard Python code quality tool on a Jupyter Notebook" optional = false python-versions = ">=3.8.0" +groups = ["dev"] files = [ {file = "nbqa-1.7.1-py3-none-any.whl", hash = "sha256:77cdff622bfcf527bf260004449984edfb3624f6e065ac6bb35d64cddcdad483"}, {file = "nbqa-1.7.1.tar.gz", hash = "sha256:44f5b5000d6df438c4f1cba339e3ad80acc405e61f4500ac951fa36a177133f4"}, @@ -3376,6 +3505,7 @@ version = "0.6.1" description = "Strips outputs from Jupyter and IPython notebooks" optional = false python-versions = ">=3.6" +groups = ["dev"] files = [ {file = "nbstripout-0.6.1-py2.py3-none-any.whl", hash = "sha256:5ff6eb0debbcd656c4a64db8e082a24fabcfc753a9e8c9f6d786971e8f29e110"}, {file = "nbstripout-0.6.1.tar.gz", hash = "sha256:9065bcdd1488b386e4f3c081ffc1d48f4513a2f8d8bf4d0d9a28208c5dafe9d3"}, @@ -3390,6 +3520,7 @@ version 
= "1.5.8" description = "Patch asyncio to allow nested event loops" optional = false python-versions = ">=3.5" +groups = ["dev"] files = [ {file = "nest_asyncio-1.5.8-py3-none-any.whl", hash = "sha256:accda7a339a70599cb08f9dd09a67e0c2ef8d8d6f4c07f96ab203f2ae254e48d"}, {file = "nest_asyncio-1.5.8.tar.gz", hash = "sha256:25aa2ca0d2a5b5531956b9e273b45cf664cae2b145101d73b86b199978d48fdb"}, @@ -3401,6 +3532,7 @@ version = "3.2.1" description = "Python package for creating and manipulating graphs and networks" optional = false python-versions = ">=3.9" +groups = ["main", "dev"] files = [ {file = "networkx-3.2.1-py3-none-any.whl", hash = "sha256:f18c69adc97877c42332c170849c96cefa91881c99a7cb3e95b7c659ebdc1ec2"}, {file = "networkx-3.2.1.tar.gz", hash = "sha256:9f1bb5cf3409bf324e0a722c20bdb4c20ee39bf1c30ce8ae499c8502b0b5e0c6"}, @@ -3419,6 +3551,7 @@ version = "1.8.0" description = "Node.js virtual environment builder" optional = false python-versions = ">=2.7,!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,!=3.6.*" +groups = ["dev"] files = [ {file = "nodeenv-1.8.0-py2.py3-none-any.whl", hash = "sha256:df865724bb3c3adc86b3876fa209771517b0cfe596beff01a92700e0e8be4cec"}, {file = "nodeenv-1.8.0.tar.gz", hash = "sha256:d51e0c37e64fbf47d017feac3145cdbb58836d7eee8c6f6d3b6880c5456227d2"}, @@ -3433,6 +3566,7 @@ version = "7.2.2" description = "Jupyter Notebook - A web-based notebook environment for interactive computing" optional = false python-versions = ">=3.8" +groups = ["dev"] files = [ {file = "notebook-7.2.2-py3-none-any.whl", hash = "sha256:c89264081f671bc02eec0ed470a627ed791b9156cad9285226b31611d3e9fe1c"}, {file = "notebook-7.2.2.tar.gz", hash = "sha256:2ef07d4220421623ad3fe88118d687bc0450055570cdd160814a59cf3a1c516e"}, @@ -3448,7 +3582,7 @@ tornado = ">=6.2.0" [package.extras] dev = ["hatch", "pre-commit"] docs = ["myst-parser", "nbsphinx", "pydata-sphinx-theme", "sphinx (>=1.3.6)", "sphinxcontrib-github-alt", "sphinxcontrib-spelling"] -test = ["importlib-resources 
(>=5.0)", "ipykernel", "jupyter-server[test] (>=2.4.0,<3)", "jupyterlab-server[test] (>=2.27.1,<3)", "nbval", "pytest (>=7.0)", "pytest-console-scripts", "pytest-timeout", "pytest-tornasync", "requests"] +test = ["importlib-resources (>=5.0) ; python_version < \"3.10\"", "ipykernel", "jupyter-server[test] (>=2.4.0,<3)", "jupyterlab-server[test] (>=2.27.1,<3)", "nbval", "pytest (>=7.0)", "pytest-console-scripts", "pytest-timeout", "pytest-tornasync", "requests"] [[package]] name = "notebook-shim" @@ -3456,6 +3590,7 @@ version = "0.2.3" description = "A shim layer for notebook traits and config" optional = false python-versions = ">=3.7" +groups = ["dev"] files = [ {file = "notebook_shim-0.2.3-py3-none-any.whl", hash = "sha256:a83496a43341c1674b093bfcebf0fe8e74cbe7eda5fd2bbc56f8e39e1486c0c7"}, {file = "notebook_shim-0.2.3.tar.gz", hash = "sha256:f69388ac283ae008cd506dda10d0288b09a017d822d5e8c7129a152cbd3ce7e9"}, @@ -3473,6 +3608,7 @@ version = "0.60.0" description = "compiling Python code using LLVM" optional = false python-versions = ">=3.9" +groups = ["main"] files = [ {file = "numba-0.60.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:5d761de835cd38fb400d2c26bb103a2726f548dc30368853121d66201672e651"}, {file = "numba-0.60.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:159e618ef213fba758837f9837fb402bbe65326e60ba0633dbe6c7f274d42c1b"}, @@ -3507,6 +3643,7 @@ version = "1.24.4" description = "Fundamental package for array computing in Python" optional = false python-versions = ">=3.8" +groups = ["main", "dev"] files = [ {file = "numpy-1.24.4-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:c0bfb52d2169d58c1cdb8cc1f16989101639b34c7d3ce60ed70b19c63eba0b64"}, {file = "numpy-1.24.4-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:ed094d4f0c177b1b8e7aa9cba7d6ceed51c0e569a5318ac0ca9a090680a6a1b1"}, @@ -3544,6 +3681,8 @@ version = "12.1.3.1" description = "CUBLAS native runtime libraries" optional = false python-versions = ">=3" +groups = ["main"] +markers = 
"platform_system == \"Linux\" and platform_machine == \"x86_64\"" files = [ {file = "nvidia_cublas_cu12-12.1.3.1-py3-none-manylinux1_x86_64.whl", hash = "sha256:ee53ccca76a6fc08fb9701aa95b6ceb242cdaab118c3bb152af4e579af792728"}, {file = "nvidia_cublas_cu12-12.1.3.1-py3-none-win_amd64.whl", hash = "sha256:2b964d60e8cf11b5e1073d179d85fa340c120e99b3067558f3cf98dd69d02906"}, @@ -3555,6 +3694,8 @@ version = "12.1.105" description = "CUDA profiling tools runtime libs." optional = false python-versions = ">=3" +groups = ["main"] +markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\"" files = [ {file = "nvidia_cuda_cupti_cu12-12.1.105-py3-none-manylinux1_x86_64.whl", hash = "sha256:e54fde3983165c624cb79254ae9818a456eb6e87a7fd4d56a2352c24ee542d7e"}, {file = "nvidia_cuda_cupti_cu12-12.1.105-py3-none-win_amd64.whl", hash = "sha256:bea8236d13a0ac7190bd2919c3e8e6ce1e402104276e6f9694479e48bb0eb2a4"}, @@ -3566,6 +3707,8 @@ version = "12.1.105" description = "NVRTC native runtime libraries" optional = false python-versions = ">=3" +groups = ["main"] +markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\"" files = [ {file = "nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-manylinux1_x86_64.whl", hash = "sha256:339b385f50c309763ca65456ec75e17bbefcbbf2893f462cb8b90584cd27a1c2"}, {file = "nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-win_amd64.whl", hash = "sha256:0a98a522d9ff138b96c010a65e145dc1b4850e9ecb75a0172371793752fd46ed"}, @@ -3577,6 +3720,8 @@ version = "12.1.105" description = "CUDA Runtime native Libraries" optional = false python-versions = ">=3" +groups = ["main"] +markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\"" files = [ {file = "nvidia_cuda_runtime_cu12-12.1.105-py3-none-manylinux1_x86_64.whl", hash = "sha256:6e258468ddf5796e25f1dc591a31029fa317d97a0a94ed93468fc86301d61e40"}, {file = "nvidia_cuda_runtime_cu12-12.1.105-py3-none-win_amd64.whl", hash = 
"sha256:dfb46ef84d73fababab44cf03e3b83f80700d27ca300e537f85f636fac474344"}, @@ -3588,6 +3733,8 @@ version = "8.9.2.26" description = "cuDNN runtime libraries" optional = false python-versions = ">=3" +groups = ["main"] +markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\"" files = [ {file = "nvidia_cudnn_cu12-8.9.2.26-py3-none-manylinux1_x86_64.whl", hash = "sha256:5ccb288774fdfb07a7e7025ffec286971c06d8d7b4fb162525334616d7629ff9"}, ] @@ -3601,6 +3748,8 @@ version = "11.0.2.54" description = "CUFFT native runtime libraries" optional = false python-versions = ">=3" +groups = ["main"] +markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\"" files = [ {file = "nvidia_cufft_cu12-11.0.2.54-py3-none-manylinux1_x86_64.whl", hash = "sha256:794e3948a1aa71fd817c3775866943936774d1c14e7628c74f6f7417224cdf56"}, {file = "nvidia_cufft_cu12-11.0.2.54-py3-none-win_amd64.whl", hash = "sha256:d9ac353f78ff89951da4af698f80870b1534ed69993f10a4cf1d96f21357e253"}, @@ -3612,6 +3761,8 @@ version = "10.3.2.106" description = "CURAND native runtime libraries" optional = false python-versions = ">=3" +groups = ["main"] +markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\"" files = [ {file = "nvidia_curand_cu12-10.3.2.106-py3-none-manylinux1_x86_64.whl", hash = "sha256:9d264c5036dde4e64f1de8c50ae753237c12e0b1348738169cd0f8a536c0e1e0"}, {file = "nvidia_curand_cu12-10.3.2.106-py3-none-win_amd64.whl", hash = "sha256:75b6b0c574c0037839121317e17fd01f8a69fd2ef8e25853d826fec30bdba74a"}, @@ -3623,6 +3774,8 @@ version = "11.4.5.107" description = "CUDA solver native runtime libraries" optional = false python-versions = ">=3" +groups = ["main"] +markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\"" files = [ {file = "nvidia_cusolver_cu12-11.4.5.107-py3-none-manylinux1_x86_64.whl", hash = "sha256:8a7ec542f0412294b15072fa7dab71d31334014a69f953004ea7a118206fe0dd"}, {file = 
"nvidia_cusolver_cu12-11.4.5.107-py3-none-win_amd64.whl", hash = "sha256:74e0c3a24c78612192a74fcd90dd117f1cf21dea4822e66d89e8ea80e3cd2da5"}, @@ -3639,6 +3792,8 @@ version = "12.1.0.106" description = "CUSPARSE native runtime libraries" optional = false python-versions = ">=3" +groups = ["main"] +markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\"" files = [ {file = "nvidia_cusparse_cu12-12.1.0.106-py3-none-manylinux1_x86_64.whl", hash = "sha256:f3b50f42cf363f86ab21f720998517a659a48131e8d538dc02f8768237bd884c"}, {file = "nvidia_cusparse_cu12-12.1.0.106-py3-none-win_amd64.whl", hash = "sha256:b798237e81b9719373e8fae8d4f091b70a0cf09d9d85c95a557e11df2d8e9a5a"}, @@ -3653,6 +3808,8 @@ version = "2.18.1" description = "NVIDIA Collective Communication Library (NCCL) Runtime" optional = false python-versions = ">=3" +groups = ["main"] +markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\"" files = [ {file = "nvidia_nccl_cu12-2.18.1-py3-none-manylinux1_x86_64.whl", hash = "sha256:1a6c4acefcbebfa6de320f412bf7866de856e786e0462326ba1bac40de0b5e71"}, ] @@ -3663,6 +3820,8 @@ version = "12.3.101" description = "Nvidia JIT LTO Library" optional = false python-versions = ">=3" +groups = ["main"] +markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\"" files = [ {file = "nvidia_nvjitlink_cu12-12.3.101-py3-none-manylinux1_x86_64.whl", hash = "sha256:64335a8088e2b9d196ae8665430bc6a2b7e6ef2eb877a9c735c804bd4ff6467c"}, {file = "nvidia_nvjitlink_cu12-12.3.101-py3-none-manylinux2014_aarch64.whl", hash = "sha256:211a63e7b30a9d62f1a853e19928fbb1a750e3f17a13a3d1f98ff0ced19478dd"}, @@ -3675,6 +3834,8 @@ version = "12.1.105" description = "NVIDIA Tools Extension" optional = false python-versions = ">=3" +groups = ["main"] +markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\"" files = [ {file = "nvidia_nvtx_cu12-12.1.105-py3-none-manylinux1_x86_64.whl", hash = 
"sha256:dc21cf308ca5691e7c04d962e213f8a4aa9bbfa23d95412f452254c2caeb09e5"}, {file = "nvidia_nvtx_cu12-12.1.105-py3-none-win_amd64.whl", hash = "sha256:65f4d98982b31b60026e0e6de73fbdfc09d08a96f4656dd3665ca616a11e1e82"}, @@ -3686,6 +3847,7 @@ version = "3.2.2" description = "A generic, spec-compliant, thorough implementation of the OAuth request-signing logic" optional = false python-versions = ">=3.6" +groups = ["main"] files = [ {file = "oauthlib-3.2.2-py3-none-any.whl", hash = "sha256:8139f29aac13e25d502680e9e19963e83f16838d48a0d71c287fe40e7067fbca"}, {file = "oauthlib-3.2.2.tar.gz", hash = "sha256:9859c40929662bec5d64f34d01c99e093149682a3f38915dc0655d5a633dd918"}, @@ -3702,6 +3864,8 @@ version = "4.8.1.78" description = "Wrapper package for OpenCV python bindings." optional = true python-versions = ">=3.6" +groups = ["main"] +markers = "extra == \"atari\"" files = [ {file = "opencv-python-4.8.1.78.tar.gz", hash = "sha256:cc7adbbcd1112877a39274106cb2752e04984bc01a031162952e97450d6117f6"}, {file = "opencv_python-4.8.1.78-cp37-abi3-macosx_10_16_x86_64.whl", hash = "sha256:91d5f6f5209dc2635d496f6b8ca6573ecdad051a09e6b5de4c399b8e673c60da"}, @@ -3721,6 +3885,8 @@ version = "0.10.0" description = "Optimized PyTree Utilities." 
optional = true python-versions = ">=3.7" +groups = ["main"] +markers = "sys_platform != \"darwin\" and extra == \"envpool\"" files = [ {file = "optree-0.10.0-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:ac2c0fa383f504f03887a0c0ffcb6a4187c43c8c99c32f52ff14e7eae2c8c69b"}, {file = "optree-0.10.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:8fa16b16203938b7a9caa4603998d0968b408f7f3a1a9f7f84763802daf1cff0"}, @@ -3781,6 +3947,7 @@ version = "4.1.0" description = "An OrderedSet is a custom MutableSet that remembers its order, so that every" optional = false python-versions = ">=3.7" +groups = ["main"] files = [ {file = "ordered-set-4.1.0.tar.gz", hash = "sha256:694a8e44c87657c59292ede72891eb91d34131f6531463aab3009191c77364a8"}, {file = "ordered_set-4.1.0-py3-none-any.whl", hash = "sha256:046e1132c71fcf3330438a539928932caf51ddbc582496833e23de611de14562"}, @@ -3795,6 +3962,7 @@ version = "7.4.0" description = "A decorator to automatically detect mismatch when overriding a method." 
optional = false python-versions = ">=3.6" +groups = ["main", "dev"] files = [ {file = "overrides-7.4.0-py3-none-any.whl", hash = "sha256:3ad24583f86d6d7a49049695efe9933e67ba62f0c7625d53c59fa832ce4b8b7d"}, {file = "overrides-7.4.0.tar.gz", hash = "sha256:9502a3cca51f4fac40b5feca985b6703a5c1f6ad815588a7ca9e285b9dca6757"}, @@ -3806,6 +3974,7 @@ version = "23.2" description = "Core utilities for Python packages" optional = false python-versions = ">=3.7" +groups = ["main", "dev"] files = [ {file = "packaging-23.2-py3-none-any.whl", hash = "sha256:8c491190033a9af7e1d931d0b5dacc2ef47509b34dd0de67ed209b5203fc88c7"}, {file = "packaging-23.2.tar.gz", hash = "sha256:048fb0e9405036518eaaf48a55953c750c11e1a1b68e0dd1a9d62ed0c092cfc5"}, @@ -3817,6 +3986,7 @@ version = "2.1.0" description = "Powerful data structures for data analysis, time series, and statistics" optional = false python-versions = ">=3.9" +groups = ["main"] files = [ {file = "pandas-2.1.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:40dd20439ff94f1b2ed55b393ecee9cb6f3b08104c2c40b0cb7186a2f0046242"}, {file = "pandas-2.1.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:d4f38e4fedeba580285eaac7ede4f686c6701a9e618d8a857b138a126d067f2f"}, @@ -3875,6 +4045,7 @@ version = "1.5.0" description = "Utilities for writing pandoc filters in python" optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" +groups = ["dev"] files = [ {file = "pandocfilters-1.5.0-py2.py3-none-any.whl", hash = "sha256:33aae3f25fd1a026079f5d27bdd52496f0e0803b3469282162bafdcbdf6ef14f"}, {file = "pandocfilters-1.5.0.tar.gz", hash = "sha256:0b679503337d233b4339a817bfc8c50064e2eff681314376a47cb582305a7a38"}, @@ -3886,6 +4057,7 @@ version = "0.8.3" description = "A Python Parser" optional = false python-versions = ">=3.6" +groups = ["dev"] files = [ {file = "parso-0.8.3-py2.py3-none-any.whl", hash = "sha256:c001d4636cd3aecdaf33cbb40aebb59b094be2a74c556778ef5576c175e19e75"}, {file = "parso-0.8.3.tar.gz", hash = 
"sha256:8c07be290bb59f03588915921e29e8a50002acaf2cdc5fa0e0114f91709fafa0"}, @@ -3901,6 +4073,7 @@ version = "0.2.1" description = "Bring colors to your terminal." optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" +groups = ["dev"] files = [ {file = "pastel-0.2.1-py2.py3-none-any.whl", hash = "sha256:4349225fcdf6c2bb34d483e523475de5bb04a5c10ef711263452cb37d7dd4364"}, {file = "pastel-0.2.1.tar.gz", hash = "sha256:e6581ac04e973cac858828c6202c1e1e81fee1dc7de7683f3e1ffe0bfd8a573d"}, @@ -3912,6 +4085,7 @@ version = "0.11.2" description = "Utility library for gitignore style pattern matching of file paths." optional = false python-versions = ">=3.7" +groups = ["dev"] files = [ {file = "pathspec-0.11.2-py3-none-any.whl", hash = "sha256:1d6ed233af05e679efb96b1851550ea95bbb64b7c490b0f5aa52996c11e92a20"}, {file = "pathspec-0.11.2.tar.gz", hash = "sha256:e0d8d0ac2f12da61956eb2306b69f9469b42f4deb0f3cb6ed47b9cce9996ced3"}, @@ -3923,6 +4097,7 @@ version = "0.1.2" description = "File system general utilities" optional = false python-versions = "*" +groups = ["dev"] files = [ {file = "pathtools-0.1.2.tar.gz", hash = "sha256:7c35c5421a39bb82e58018febd90e3b6e5db34c5443aaaf742b3f33d4655f1c0"}, ] @@ -3933,6 +4108,8 @@ version = "0.5.6" description = "A Python package for describing statistical models and for building design matrices." optional = true python-versions = "*" +groups = ["main"] +markers = "extra == \"eval\"" files = [ {file = "patsy-0.5.6-py2.py3-none-any.whl", hash = "sha256:19056886fd8fa71863fa32f0eb090267f21fb74be00f19f5c70b2e9d76c883c6"}, {file = "patsy-0.5.6.tar.gz", hash = "sha256:95c6d47a7222535f84bff7f63d7303f2e297747a598db89cf5c67f0c0c7d2cdb"}, @@ -3951,6 +4128,7 @@ version = "1.24.2" description = "Gymnasium for multi-agent reinforcement learning." 
optional = false python-versions = ">=3.8" +groups = ["main"] files = [ {file = "pettingzoo-1.24.2-py3-none-any.whl", hash = "sha256:00268cf990d243654c2bbbbf8c88322c12b041dc0a879b74747f14ee8aa93dd6"}, {file = "pettingzoo-1.24.2.tar.gz", hash = "sha256:0a5856d47de78ab20feddfdac4940959dc892f6becc92107247b1c3a210c0984"}, @@ -3976,6 +4154,8 @@ version = "4.8.0" description = "Pexpect allows easy control of interactive console applications." optional = false python-versions = "*" +groups = ["dev"] +markers = "sys_platform != \"win32\"" files = [ {file = "pexpect-4.8.0-py2.py3-none-any.whl", hash = "sha256:0b48a55dcb3c05f3329815901ea4fc1537514d6ba867a152b581d69ae3710937"}, {file = "pexpect-4.8.0.tar.gz", hash = "sha256:fc65a43959d153d0114afe13997d439c22823a27cefceb5ff35c2178c6784c0c"}, @@ -3990,6 +4170,7 @@ version = "10.2.0" description = "Python Imaging Library (Fork)" optional = false python-versions = ">=3.8" +groups = ["main"] files = [ {file = "pillow-10.2.0-cp310-cp310-macosx_10_10_x86_64.whl", hash = "sha256:7823bdd049099efa16e4246bdf15e5a13dbb18a51b68fa06d6c1d4d8b99a796e"}, {file = "pillow-10.2.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:83b2021f2ade7d1ed556bc50a399127d7fb245e725aa0113ebd05cfe88aaf588"}, @@ -4066,7 +4247,7 @@ docs = ["furo", "olefile", "sphinx (>=2.4)", "sphinx-copybutton", "sphinx-inline fpx = ["olefile"] mic = ["olefile"] tests = ["check-manifest", "coverage", "defusedxml", "markdown2", "olefile", "packaging", "pyroma", "pytest", "pytest-cov", "pytest-timeout"] -typing = ["typing-extensions"] +typing = ["typing-extensions ; python_version < \"3.10\""] xmp = ["defusedxml"] [[package]] @@ -4075,6 +4256,8 @@ version = "2.6.2" description = "A small Python package for determining appropriate platform-specific dirs, e.g. a \"user data dir\"." 
optional = false python-versions = ">=3.7" +groups = ["main", "dev"] +markers = "sys_platform == \"win32\"" files = [ {file = "platformdirs-2.6.2-py3-none-any.whl", hash = "sha256:83c8f6d04389165de7c9b6f0c682439697887bca0aa2f1c87ef1826be3584490"}, {file = "platformdirs-2.6.2.tar.gz", hash = "sha256:e1fea1fe471b9ff8332e229df3cb7de4f53eeea4998d3b6bfff542115e998bd2"}, @@ -4090,6 +4273,8 @@ version = "3.11.0" description = "A small Python package for determining appropriate platform-specific dirs, e.g. a \"user data dir\"." optional = false python-versions = ">=3.7" +groups = ["main", "dev"] +markers = "sys_platform != \"win32\"" files = [ {file = "platformdirs-3.11.0-py3-none-any.whl", hash = "sha256:e9d171d00af68be50e9202731309c4e658fd8bc76f55c11c7dd760d023bda68e"}, {file = "platformdirs-3.11.0.tar.gz", hash = "sha256:cf8ee52a3afdb965072dcc652433e0c7e3e40cf5ea1477cd4b3b1d2eb75495b3"}, @@ -4105,6 +4290,7 @@ version = "1.3.0" description = "plugin and hook calling mechanisms for python" optional = false python-versions = ">=3.8" +groups = ["dev"] files = [ {file = "pluggy-1.3.0-py3-none-any.whl", hash = "sha256:d89c696a773f8bd377d18e5ecda92b7a3793cbe66c87060a6fb58c7b6e1061f7"}, {file = "pluggy-1.3.0.tar.gz", hash = "sha256:cf61ae8f126ac6f7c451172cf30e3e43d3ca77615509771b3a984a0730651e12"}, @@ -4120,6 +4306,7 @@ version = "0.20.0" description = "A task runner that works well with poetry." optional = false python-versions = ">=3.8" +groups = ["dev"] files = [ {file = "poethepoet-0.20.0-py3-none-any.whl", hash = "sha256:cb37be15f3895ccc65ddf188c2e3d8fb79e26cc9d469a6098cb1c6f994659f6f"}, {file = "poethepoet-0.20.0.tar.gz", hash = "sha256:ca5a2a955f52dfb0a53fad3c989ef0b69ce3d5ec0f6bfa9b1da1f9e32d262e20"}, @@ -4138,6 +4325,7 @@ version = "3.5.0" description = "A framework for managing and maintaining multi-language pre-commit hooks." 
optional = false python-versions = ">=3.8" +groups = ["dev"] files = [ {file = "pre_commit-3.5.0-py2.py3-none-any.whl", hash = "sha256:841dc9aef25daba9a0238cd27984041fa0467b4199fc4852e27950664919f660"}, {file = "pre_commit-3.5.0.tar.gz", hash = "sha256:5804465c675b659b0862f07907f96295d490822a450c4c40e747d0b1c6ebcb32"}, @@ -4156,6 +4344,7 @@ version = "0.18.0" description = "Python client for the Prometheus monitoring system." optional = false python-versions = ">=3.8" +groups = ["dev"] files = [ {file = "prometheus_client-0.18.0-py3-none-any.whl", hash = "sha256:8de3ae2755f890826f4b6479e5571d4f74ac17a81345fe69a6778fdb92579184"}, {file = "prometheus_client-0.18.0.tar.gz", hash = "sha256:35f7a8c22139e2bb7ca5a698e92d38145bc8dc74c1c0bf56f25cca886a764e17"}, @@ -4170,6 +4359,7 @@ version = "2.3" description = "Promises/A+ implementation for Python" optional = false python-versions = "*" +groups = ["dev"] files = [ {file = "promise-2.3.tar.gz", hash = "sha256:dfd18337c523ba4b6a58801c164c1904a9d4d1b1747c7d5dbf45b693a49d93d0"}, ] @@ -4186,6 +4376,7 @@ version = "3.0.41" description = "Library for building powerful interactive command lines in Python" optional = false python-versions = ">=3.7.0" +groups = ["dev"] files = [ {file = "prompt_toolkit-3.0.41-py3-none-any.whl", hash = "sha256:f36fe301fafb7470e86aaf90f036eef600a3210be4decf461a5b1ca8403d3cb2"}, {file = "prompt_toolkit-3.0.41.tar.gz", hash = "sha256:941367d97fc815548822aa26c2a269fdc4eb21e9ec05fc5d447cf09bad5d75f0"}, @@ -4200,6 +4391,8 @@ version = "1.6.4" description = "A decorator for caching properties in classes (forked from cached-property)." 
optional = true python-versions = ">= 3.5" +groups = ["main"] +markers = "extra == \"eval\"" files = [ {file = "property-cached-1.6.4.zip", hash = "sha256:3e9c4ef1ed3653909147510481d7df62a3cfb483461a6986a6f1dcd09b2ebb73"}, {file = "property_cached-1.6.4-py2.py3-none-any.whl", hash = "sha256:135fc059ec969c1646424a0db15e7fbe1b5f8c36c0006d0b3c91ba568c11e7d8"}, @@ -4211,6 +4404,7 @@ version = "3.20.3" description = "Protocol Buffers" optional = false python-versions = ">=3.7" +groups = ["main", "dev"] files = [ {file = "protobuf-3.20.3-cp310-cp310-manylinux2014_aarch64.whl", hash = "sha256:f4bd856d702e5b0d96a00ec6b307b0f51c1982c2bf9c0052cf9019e9a544ba99"}, {file = "protobuf-3.20.3-cp310-cp310-manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:9aae4406ea63d825636cc11ffb34ad3379335803216ee3a856787bcf5ccc751e"}, @@ -4242,6 +4436,7 @@ version = "5.9.6" description = "Cross-platform lib for process and system monitoring in Python." optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*, !=3.5.*" +groups = ["dev"] files = [ {file = "psutil-5.9.6-cp27-cp27m-macosx_10_9_x86_64.whl", hash = "sha256:fb8a697f11b0f5994550555fcfe3e69799e5b060c8ecf9e2f75c69302cc35c0d"}, {file = "psutil-5.9.6-cp27-cp27m-manylinux2010_i686.whl", hash = "sha256:91ecd2d9c00db9817a4b4192107cf6954addb5d9d67a969a4f436dbc9200f88c"}, @@ -4262,7 +4457,7 @@ files = [ ] [package.extras] -test = ["enum34", "ipaddress", "mock", "pywin32", "wmi"] +test = ["enum34 ; python_version <= \"3.4\"", "ipaddress ; python_version < \"3.0\"", "mock ; python_version < \"3.0\"", "pywin32 ; sys_platform == \"win32\"", "wmi ; sys_platform == \"win32\""] [[package]] name = "ptyprocess" @@ -4270,6 +4465,8 @@ version = "0.7.0" description = "Run a subprocess in a pseudo terminal" optional = false python-versions = "*" +groups = ["dev"] +markers = "os_name != \"nt\" or sys_platform != \"win32\"" files = [ {file = "ptyprocess-0.7.0-py2.py3-none-any.whl", hash = 
"sha256:4b41f3967fce3af57cc7e94b888626c18bf37a083e3651ca8feeb66d492fef35"}, {file = "ptyprocess-0.7.0.tar.gz", hash = "sha256:5c5d0a3b48ceee0b48485e0c26037c0acd7d29765ca3fbb5cb3831d347423220"}, @@ -4281,6 +4478,7 @@ version = "0.2.2" description = "Safely evaluate AST nodes without side effects" optional = false python-versions = "*" +groups = ["dev"] files = [ {file = "pure_eval-0.2.2-py3-none-any.whl", hash = "sha256:01eaab343580944bc56080ebe0a674b39ec44a945e6d09ba7db3cb8cec289350"}, {file = "pure_eval-0.2.2.tar.gz", hash = "sha256:2b45320af6dfaa1750f543d714b6d1c520a1688dec6fd24d339063ce0aaa9ac3"}, @@ -4295,6 +4493,7 @@ version = "0.5.0" description = "Pure-Python implementation of ASN.1 types and DER/BER/CER codecs (X.208)" optional = false python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,>=2.7" +groups = ["main"] files = [ {file = "pyasn1-0.5.0-py2.py3-none-any.whl", hash = "sha256:87a2121042a1ac9358cabcaf1d07680ff97ee6404333bacca15f76aa8ad01a57"}, {file = "pyasn1-0.5.0.tar.gz", hash = "sha256:97b7290ca68e62a832558ec3976f15cbf911bf5d7c7039d8b861c2a0ece69fde"}, @@ -4306,6 +4505,7 @@ version = "0.3.0" description = "A collection of ASN.1-based protocols modules" optional = false python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,>=2.7" +groups = ["main"] files = [ {file = "pyasn1_modules-0.3.0-py2.py3-none-any.whl", hash = "sha256:d3ccd6ed470d9ffbc716be08bd90efbd44d0734bc9303818f7336070984a162d"}, {file = "pyasn1_modules-0.3.0.tar.gz", hash = "sha256:5bd01446b736eb9d31512a30d46c1ac3395d676c6f3cafa4c03eb54b9925631c"}, @@ -4320,6 +4520,7 @@ version = "0.24.0" description = "A BibTeX-compatible bibliography processor in Python" optional = false python-versions = ">=2.7,!=3.0.*,!=3.1.*,!=3.2.*" +groups = ["dev"] files = [ {file = "pybtex-0.24.0-py2.py3-none-any.whl", hash = "sha256:e1e0c8c69998452fea90e9179aa2a98ab103f3eed894405b7264e517cc2fcc0f"}, {file = "pybtex-0.24.0.tar.gz", hash = 
"sha256:818eae35b61733e5c007c3fcd2cfb75ed1bc8b4173c1f70b56cc4c0802d34755"}, @@ -4339,6 +4540,7 @@ version = "1.0.3" description = "A docutils backend for pybtex." optional = false python-versions = ">=3.7" +groups = ["dev"] files = [ {file = "pybtex-docutils-1.0.3.tar.gz", hash = "sha256:3a7ebdf92b593e00e8c1c538aa9a20bca5d92d84231124715acc964d51d93c6b"}, {file = "pybtex_docutils-1.0.3-py3-none-any.whl", hash = "sha256:8fd290d2ae48e32fcb54d86b0efb8d573198653c7e2447d5bec5847095f430b9"}, @@ -4354,6 +4556,8 @@ version = "3.2.5" description = "Official Python Interface for the Bullet Physics SDK specialized for Robotics Simulation and Reinforcement Learning" optional = true python-versions = "*" +groups = ["main"] +markers = "extra == \"pybullet\"" files = [ {file = "pybullet-3.2.5-cp27-cp27m-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:4970aec0dd968924f6b1820655a20f80650da2f85ba38b641937c9701a8a2b14"}, {file = "pybullet-3.2.5-cp27-cp27mu-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:b64e4523a11d03729035e0a5baa0ce4d2ca58de8d0a242c0b91e8253781b24c4"}, @@ -4371,6 +4575,7 @@ version = "2.11.1" description = "Python style guide checker" optional = false python-versions = ">=3.8" +groups = ["dev"] files = [ {file = "pycodestyle-2.11.1-py2.py3-none-any.whl", hash = "sha256:44fe31000b2d866f2e41841b18528a505fbd7fef9017b04eff4e2648a0fadc67"}, {file = "pycodestyle-2.11.1.tar.gz", hash = "sha256:41ba0e7afc9752dfb53ced5489e89f8186be00e599e712660695b7a75ff2663f"}, @@ -4382,6 +4587,7 @@ version = "2.21" description = "C parser in Python" optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" +groups = ["dev"] files = [ {file = "pycparser-2.21-py2.py3-none-any.whl", hash = "sha256:8ee45429555515e1f6b185e78100aea234072576aa43ab53aefcae078162fca9"}, {file = "pycparser-2.21.tar.gz", hash = "sha256:e644fdec12f7872f86c58ff790da456218b10f863970249516d60a5eaca77206"}, @@ -4393,6 +4599,7 @@ version = "0.14.3" description = "Bootstrap-based 
Sphinx theme from the PyData community" optional = false python-versions = ">=3.8" +groups = ["dev"] files = [ {file = "pydata_sphinx_theme-0.14.3-py3-none-any.whl", hash = "sha256:b7e40cd75a20449adfe2d7525be379b9fe92f6d31e5233e449fa34ddcd4398d9"}, {file = "pydata_sphinx_theme-0.14.3.tar.gz", hash = "sha256:bd474f347895f3fc5b6ce87390af64330ee54f11ebf9660d5bc3f87d532d4e5c"}, @@ -4420,6 +4627,7 @@ version = "3.2.2" description = "Python bindings for the Enchant spellchecking system" optional = false python-versions = ">=3.5" +groups = ["dev"] files = [ {file = "pyenchant-3.2.2-py3-none-any.whl", hash = "sha256:5facc821ece957208a81423af7d6ec7810dad29697cb0d77aae81e4e11c8e5a6"}, {file = "pyenchant-3.2.2-py3-none-win32.whl", hash = "sha256:5a636832987eaf26efe971968f4d1b78e81f62bca2bde0a9da210c7de43c3bce"}, @@ -4433,6 +4641,7 @@ version = "2.5.2" description = "Python Game Development" optional = false python-versions = ">=3.6" +groups = ["main", "dev"] files = [ {file = "pygame-2.5.2-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:a0769eb628c818761755eb0a0ca8216b95270ea8cbcbc82227e39ac9644643da"}, {file = "pygame-2.5.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:ed9a3d98adafa0805ccbaaff5d2996a2b5795381285d8437a4a5d248dbd12b4a"}, @@ -4499,13 +4708,14 @@ version = "2.17.1" description = "Pygments is a syntax highlighting package written in Python." 
optional = false python-versions = ">=3.7" +groups = ["dev"] files = [ {file = "pygments-2.17.1-py3-none-any.whl", hash = "sha256:1b37f1b1e1bff2af52ecaf28cc601e2ef7077000b227a0675da25aef85784bc4"}, {file = "pygments-2.17.1.tar.gz", hash = "sha256:e45a0e74bf9c530f564ca81b8952343be986a29f6afe7f5ad95c5f06b7bdf5e8"}, ] [package.extras] -plugins = ["importlib-metadata"] +plugins = ["importlib-metadata ; python_version < \"3.8\""] windows-terminal = ["colorama (>=0.4.6)"] [[package]] @@ -4514,6 +4724,7 @@ version = "6.6.0" description = "Pymunk is a easy-to-use pythonic 2d physics library" optional = false python-versions = ">=3.7" +groups = ["dev"] files = [ {file = "pymunk-6.6.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:6da50dd97683337a290110d594fad07a75153d2d837b570ef972478d739c33f8"}, {file = "pymunk-6.6.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:bcd7d16a2b4d51d45d6780a701f65c8d5b36fdf545c3f4738910da41e2a9c4ee"}, @@ -4585,6 +4796,8 @@ version = "3.1.7" description = "Standard OpenGL bindings for Python" optional = true python-versions = "*" +groups = ["main"] +markers = "extra == \"robotics\" or extra == \"mujoco\"" files = [ {file = "PyOpenGL-3.1.7-py3-none-any.whl", hash = "sha256:a6ab19cf290df6101aaf7470843a9c46207789855746399d0af92521a0a92b7a"}, {file = "PyOpenGL-3.1.7.tar.gz", hash = "sha256:eef31a3888e6984fd4d8e6c9961b184c9813ca82604d37fe3da80eb000a76c86"}, @@ -4596,6 +4809,7 @@ version = "3.1.2" description = "pyparsing module - Classes and methods to define and execute parsing grammars" optional = false python-versions = ">=3.6.8" +groups = ["main"] files = [ {file = "pyparsing-3.1.2-py3-none-any.whl", hash = "sha256:f9db75911801ed778fe61bb643079ff86601aca99fcae6345aa67292038fb742"}, {file = "pyparsing-3.1.2.tar.gz", hash = "sha256:a1bac0ce561155ecc3ed78ca94d3c9378656ad4c94c1270de543f621420f94ad"}, @@ -4610,6 +4824,7 @@ version = "7.4.3" description = "pytest: simple powerful testing with Python" optional = false python-versions = ">=3.7" 
+groups = ["dev"] files = [ {file = "pytest-7.4.3-py3-none-any.whl", hash = "sha256:0d009c083ea859a71b76adf7c1d502e4bc170b80a8ef002da5806527b9591fac"}, {file = "pytest-7.4.3.tar.gz", hash = "sha256:d989d136982de4e3b29dabcc838ad581c64e8ed52c11fbe86ddebd9da0818cd5"}, @@ -4630,6 +4845,7 @@ version = "4.1.0" description = "Pytest plugin for measuring coverage." optional = false python-versions = ">=3.7" +groups = ["dev"] files = [ {file = "pytest-cov-4.1.0.tar.gz", hash = "sha256:3904b13dfbfec47f003b8e77fd5b589cd11904a21ddf1ab38a64f204d6a10ef6"}, {file = "pytest_cov-4.1.0-py3-none-any.whl", hash = "sha256:6ba70b9e97e69fcc3fb45bfeab2d0a138fb65c4d0d6a41ef33983ad114be8c3a"}, @@ -4648,6 +4864,7 @@ version = "2.8.2" description = "Extensions to the standard Python datetime module" optional = false python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,>=2.7" +groups = ["main", "dev"] files = [ {file = "python-dateutil-2.8.2.tar.gz", hash = "sha256:0123cacc1627ae19ddf3c27a5de5bd67ee4586fbdd6440d9748f8abb483d3e86"}, {file = "python_dateutil-2.8.2-py2.py3-none-any.whl", hash = "sha256:961d03dc3453ebbc59dbdea9e4e11c5651520a876d0f4db161e8674aae935da9"}, @@ -4662,6 +4879,7 @@ version = "2.0.7" description = "A python library adding a json log formatter" optional = false python-versions = ">=3.6" +groups = ["dev"] files = [ {file = "python-json-logger-2.0.7.tar.gz", hash = "sha256:23e7ec02d34237c5aa1e29a070193a4ea87583bb4e7f8fd06d3de8264c4b2e1c"}, {file = "python_json_logger-2.0.7-py3-none-any.whl", hash = "sha256:f380b826a991ebbe3de4d897aeec42760035ac760345e57b812938dc8b35e2bd"}, @@ -4673,6 +4891,7 @@ version = "2024.1" description = "World timezone definitions, modern and historical" optional = false python-versions = "*" +groups = ["main"] files = [ {file = "pytz-2024.1-py2.py3-none-any.whl", hash = "sha256:328171f4e3623139da4983451950b28e95ac706e13f3f2630a879749e7a8b319"}, {file = "pytz-2024.1.tar.gz", hash = "sha256:2a29735ea9c18baf14b448846bde5a48030ed267578472d8955cd0e7443a9812"}, @@ 
-4684,6 +4903,8 @@ version = "306" description = "Python for Window Extensions" optional = false python-versions = "*" +groups = ["dev"] +markers = "sys_platform == \"win32\" and platform_python_implementation != \"PyPy\"" files = [ {file = "pywin32-306-cp310-cp310-win32.whl", hash = "sha256:06d3420a5155ba65f0b72f2699b5bacf3109f36acbe8923765c22938a69dfc8d"}, {file = "pywin32-306-cp310-cp310-win_amd64.whl", hash = "sha256:84f4471dbca1887ea3803d8848a1616429ac94a4a8d05f4bc9c5dcfd42ca99c8"}, @@ -4707,6 +4928,8 @@ version = "2.0.12" description = "Pseudo terminal support for Windows from Python." optional = false python-versions = ">=3.8" +groups = ["dev"] +markers = "os_name == \"nt\"" files = [ {file = "pywinpty-2.0.12-cp310-none-win_amd64.whl", hash = "sha256:21319cd1d7c8844fb2c970fb3a55a3db5543f112ff9cfcd623746b9c47501575"}, {file = "pywinpty-2.0.12-cp311-none-win_amd64.whl", hash = "sha256:853985a8f48f4731a716653170cd735da36ffbdc79dcb4c7b7140bce11d8c722"}, @@ -4722,6 +4945,7 @@ version = "6.0.1" description = "YAML parser and emitter for Python" optional = false python-versions = ">=3.6" +groups = ["main", "dev"] files = [ {file = "PyYAML-6.0.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:d858aa552c999bc8a8d57426ed01e40bef403cd8ccdd0fc5f6f04a00414cac2a"}, {file = "PyYAML-6.0.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:fd66fc5d0da6d9815ba2cebeb4205f95818ff4b79c3ebe268e75d961704af52f"}, @@ -4775,6 +4999,7 @@ files = [ {file = "PyYAML-6.0.1-cp39-cp39-win_amd64.whl", hash = "sha256:510c9deebc5c0225e8c96813043e62b680ba2f9c50a08d3724c7f28a747d1486"}, {file = "PyYAML-6.0.1.tar.gz", hash = "sha256:bfdf460b1736c775f2ba9f6a92bca30bc2095067b8a9d77876d1fad6cc3b4a43"}, ] +markers = {main = "extra == \"argparse\" or extra == \"eval\""} [[package]] name = "pyzmq" @@ -4782,6 +5007,7 @@ version = "25.1.1" description = "Python bindings for 0MQ" optional = false python-versions = ">=3.6" +groups = ["dev"] files = [ {file = 
"pyzmq-25.1.1-cp310-cp310-macosx_10_15_universal2.whl", hash = "sha256:381469297409c5adf9a0e884c5eb5186ed33137badcbbb0560b86e910a2f1e76"}, {file = "pyzmq-25.1.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:955215ed0604dac5b01907424dfa28b40f2b2292d6493445dd34d0dfa72586a8"}, @@ -4887,6 +5113,7 @@ version = "5.5.1" description = "Jupyter Qt console" optional = false python-versions = ">= 3.8" +groups = ["dev"] files = [ {file = "qtconsole-5.5.1-py3-none-any.whl", hash = "sha256:8c75fa3e9b4ed884880ff7cea90a1b67451219279ec33deaee1d59e3df1a5d2b"}, {file = "qtconsole-5.5.1.tar.gz", hash = "sha256:a0e806c6951db9490628e4df80caec9669b65149c7ba40f9bf033c025a5b56bc"}, @@ -4912,6 +5139,7 @@ version = "2.4.1" description = "Provides an abstraction layer on top of the various Qt bindings (PyQt5/6 and PySide2/6)." optional = false python-versions = ">=3.7" +groups = ["dev"] files = [ {file = "QtPy-2.4.1-py3-none-any.whl", hash = "sha256:1c1d8c4fa2c884ae742b069151b0abe15b3f70491f3972698c683b8e38de839b"}, {file = "QtPy-2.4.1.tar.gz", hash = "sha256:a5a15ffd519550a1361bdc56ffc07fda56a6af7292f17c7b395d4083af632987"}, @@ -4929,6 +5157,8 @@ version = "2.8.0" description = "Ray provides a simple, universal API for building distributed applications." 
optional = false python-versions = "*" +groups = ["dev"] +markers = "sys_platform != \"win32\"" files = [ {file = "ray-2.8.0-cp310-cp310-macosx_10_15_x86_64.whl", hash = "sha256:34e0676a0dfa277efa688bccd83ecb7a799bc03078e5b1f1aa747fe9263175a8"}, {file = "ray-2.8.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:72c696c1b784c55f0ad107d55bb58ecef5d368176765cf44fed87e714538d708"}, @@ -4966,16 +5196,16 @@ pyyaml = "*" requests = "*" [package.extras] -air = ["aiohttp (>=3.7)", "aiohttp-cors", "aiorwlock", "colorful", "fastapi", "fsspec", "gpustat (>=1.0.0)", "grpcio (>=1.32.0)", "grpcio (>=1.42.0)", "numpy (>=1.20)", "opencensus", "pandas", "pandas (>=1.3)", "prometheus-client (>=0.7.1)", "py-spy (>=0.2.0)", "pyarrow (>=6.0.1)", "pydantic (<2)", "requests", "smart-open", "starlette", "tensorboardX (>=1.9)", "uvicorn[standard]", "virtualenv (>=20.0.24,<20.21.1)", "watchfiles"] -all = ["aiohttp (>=3.7)", "aiohttp-cors", "aiorwlock", "colorful", "dm-tree", "fastapi", "fsspec", "gpustat (>=1.0.0)", "grpcio (!=1.56.0)", "grpcio (>=1.32.0)", "grpcio (>=1.42.0)", "gymnasium (==0.28.1)", "lz4", "numpy (>=1.20)", "opencensus", "opentelemetry-api", "opentelemetry-exporter-otlp", "opentelemetry-sdk", "pandas", "pandas (>=1.3)", "prometheus-client (>=0.7.1)", "py-spy (>=0.2.0)", "pyarrow (>=6.0.1)", "pydantic (<2)", "pyyaml", "ray-cpp (==2.8.0)", "requests", "rich", "scikit-image", "scipy", "smart-open", "starlette", "tensorboardX (>=1.9)", "typer", "uvicorn[standard]", "virtualenv (>=20.0.24,<20.21.1)", "watchfiles"] +air = ["aiohttp (>=3.7)", "aiohttp-cors", "aiorwlock", "colorful", "fastapi", "fsspec", "gpustat (>=1.0.0)", "grpcio (>=1.32.0) ; python_version < \"3.10\"", "grpcio (>=1.42.0) ; python_version >= \"3.10\"", "numpy (>=1.20)", "opencensus", "pandas", "pandas (>=1.3)", "prometheus-client (>=0.7.1)", "py-spy (>=0.2.0)", "pyarrow (>=6.0.1)", "pydantic (<2)", "requests", "smart-open", "starlette", "tensorboardX (>=1.9)", "uvicorn[standard]", "virtualenv 
(>=20.0.24,<20.21.1)", "watchfiles"] +all = ["aiohttp (>=3.7)", "aiohttp-cors", "aiorwlock", "colorful", "dm-tree", "fastapi", "fsspec", "gpustat (>=1.0.0)", "grpcio (!=1.56.0)", "grpcio (>=1.32.0) ; python_version < \"3.10\"", "grpcio (>=1.42.0) ; python_version >= \"3.10\"", "gymnasium (==0.28.1)", "lz4", "numpy (>=1.20)", "opencensus", "opentelemetry-api", "opentelemetry-exporter-otlp", "opentelemetry-sdk", "pandas", "pandas (>=1.3)", "prometheus-client (>=0.7.1)", "py-spy (>=0.2.0)", "pyarrow (>=6.0.1)", "pydantic (<2)", "pyyaml", "ray-cpp (==2.8.0)", "requests", "rich", "scikit-image", "scipy", "smart-open", "starlette", "tensorboardX (>=1.9)", "typer", "uvicorn[standard]", "virtualenv (>=20.0.24,<20.21.1)", "watchfiles"] client = ["grpcio (!=1.56.0)"] cpp = ["ray-cpp (==2.8.0)"] data = ["fsspec", "numpy (>=1.20)", "pandas (>=1.3)", "pyarrow (>=6.0.1)"] -default = ["aiohttp (>=3.7)", "aiohttp-cors", "colorful", "gpustat (>=1.0.0)", "grpcio (>=1.32.0)", "grpcio (>=1.42.0)", "opencensus", "prometheus-client (>=0.7.1)", "py-spy (>=0.2.0)", "pydantic (<2)", "requests", "smart-open", "virtualenv (>=20.0.24,<20.21.1)"] +default = ["aiohttp (>=3.7)", "aiohttp-cors", "colorful", "gpustat (>=1.0.0)", "grpcio (>=1.32.0) ; python_version < \"3.10\"", "grpcio (>=1.42.0) ; python_version >= \"3.10\"", "opencensus", "prometheus-client (>=0.7.1)", "py-spy (>=0.2.0)", "pydantic (<2)", "requests", "smart-open", "virtualenv (>=20.0.24,<20.21.1)"] observability = ["opentelemetry-api", "opentelemetry-exporter-otlp", "opentelemetry-sdk"] rllib = ["dm-tree", "fsspec", "gymnasium (==0.28.1)", "lz4", "pandas", "pyarrow (>=6.0.1)", "pyyaml", "requests", "rich", "scikit-image", "scipy", "tensorboardX (>=1.9)", "typer"] -serve = ["aiohttp (>=3.7)", "aiohttp-cors", "aiorwlock", "colorful", "fastapi", "gpustat (>=1.0.0)", "grpcio (>=1.32.0)", "grpcio (>=1.42.0)", "opencensus", "prometheus-client (>=0.7.1)", "py-spy (>=0.2.0)", "pydantic (<2)", "requests", "smart-open", "starlette", 
"uvicorn[standard]", "virtualenv (>=20.0.24,<20.21.1)", "watchfiles"] -serve-grpc = ["aiohttp (>=3.7)", "aiohttp-cors", "aiorwlock", "colorful", "fastapi", "gpustat (>=1.0.0)", "grpcio (>=1.32.0)", "grpcio (>=1.42.0)", "opencensus", "prometheus-client (>=0.7.1)", "py-spy (>=0.2.0)", "pydantic (<2)", "requests", "smart-open", "starlette", "uvicorn[standard]", "virtualenv (>=20.0.24,<20.21.1)", "watchfiles"] +serve = ["aiohttp (>=3.7)", "aiohttp-cors", "aiorwlock", "colorful", "fastapi", "gpustat (>=1.0.0)", "grpcio (>=1.32.0) ; python_version < \"3.10\"", "grpcio (>=1.42.0) ; python_version >= \"3.10\"", "opencensus", "prometheus-client (>=0.7.1)", "py-spy (>=0.2.0)", "pydantic (<2)", "requests", "smart-open", "starlette", "uvicorn[standard]", "virtualenv (>=20.0.24,<20.21.1)", "watchfiles"] +serve-grpc = ["aiohttp (>=3.7)", "aiohttp-cors", "aiorwlock", "colorful", "fastapi", "gpustat (>=1.0.0)", "grpcio (>=1.32.0) ; python_version < \"3.10\"", "grpcio (>=1.42.0) ; python_version >= \"3.10\"", "opencensus", "prometheus-client (>=0.7.1)", "py-spy (>=0.2.0)", "pydantic (<2)", "requests", "smart-open", "starlette", "uvicorn[standard]", "virtualenv (>=20.0.24,<20.21.1)", "watchfiles"] train = ["fsspec", "pandas", "pyarrow (>=6.0.1)", "requests", "tensorboardX (>=1.9)"] tune = ["fsspec", "pandas", "pyarrow (>=6.0.1)", "requests", "tensorboardX (>=1.9)"] @@ -4985,6 +5215,7 @@ version = "0.31.0" description = "JSON Referencing + Python" optional = false python-versions = ">=3.8" +groups = ["dev"] files = [ {file = "referencing-0.31.0-py3-none-any.whl", hash = "sha256:381b11e53dd93babb55696c71cf42aef2d36b8a150c49bf0bc301e36d536c882"}, {file = "referencing-0.31.0.tar.gz", hash = "sha256:cc28f2c88fbe7b961a7817a0abc034c09a1e36358f82fedb4ffdf29a25398863"}, @@ -5000,6 +5231,7 @@ version = "2.32.0" description = "Python HTTP for Humans." 
optional = false python-versions = ">=3.8" +groups = ["main", "dev"] files = [ {file = "requests-2.32.0-py3-none-any.whl", hash = "sha256:f2c3881dddb70d056c5bd7600a4fae312b2a300e39be6a118d30b90bd27262b5"}, {file = "requests-2.32.0.tar.gz", hash = "sha256:fa5490319474c82ef1d2c9bc459d3652e3ae4ef4c4ebdd18a21145a47ca4b6b8"}, @@ -5021,6 +5253,7 @@ version = "1.3.1" description = "OAuthlib authentication support for Requests." optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" +groups = ["main"] files = [ {file = "requests-oauthlib-1.3.1.tar.gz", hash = "sha256:75beac4a47881eeb94d5ea5d6ad31ef88856affe2332b9aafb52c6452ccf0d7a"}, {file = "requests_oauthlib-1.3.1-py2.py3-none-any.whl", hash = "sha256:2577c501a2fb8d05a304c09d090d6e47c306fef15809d102b327cf8364bddab5"}, @@ -5039,6 +5272,7 @@ version = "0.1.4" description = "A pure python RFC3339 validator" optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*" +groups = ["dev"] files = [ {file = "rfc3339_validator-0.1.4-py2.py3-none-any.whl", hash = "sha256:24f6ec1eda14ef823da9e36ec7113124b39c04d50a4d3d3a3c2859577e7791fa"}, {file = "rfc3339_validator-0.1.4.tar.gz", hash = "sha256:138a2abdf93304ad60530167e51d2dfb9549521a836871b88d7f4695d0022f6b"}, @@ -5053,6 +5287,7 @@ version = "0.1.1" description = "Pure python rfc3986 validator" optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*" +groups = ["dev"] files = [ {file = "rfc3986_validator-0.1.1-py2.py3-none-any.whl", hash = "sha256:2f235c432ef459970b4306369336b9d5dbdda31b510ca1e327636e01f528bfa9"}, {file = "rfc3986_validator-0.1.1.tar.gz", hash = "sha256:3d44bde7921b3b9ec3ae4e3adca370438eccebc676456449b145d533b240d055"}, @@ -5064,6 +5299,8 @@ version = "1.2.0" description = "rliable: Reliable evaluation on reinforcement learning and machine learning benchmarks." 
optional = true python-versions = "*" +groups = ["main"] +markers = "extra == \"eval\"" files = [ {file = "rliable-1.2.0.tar.gz", hash = "sha256:72789d9147d7c56e6efa812f9dffedcef44993a866ec08d75506ac7c1fe69cd5"}, ] @@ -5082,6 +5319,7 @@ version = "0.13.0" description = "Python bindings to Rust's persistent data structures (rpds)" optional = false python-versions = ">=3.8" +groups = ["dev"] files = [ {file = "rpds_py-0.13.0-cp310-cp310-macosx_10_7_x86_64.whl", hash = "sha256:1758197cc8d7ff383c07405f188253535b4aa7fa745cbc54d221ae84b18e0702"}, {file = "rpds_py-0.13.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:715df74cbcef4387d623c917f295352127f4b3e0388038d68fa577b4e4c6e540"}, @@ -5190,6 +5428,7 @@ version = "4.9" description = "Pure-Python RSA implementation" optional = false python-versions = ">=3.6,<4" +groups = ["main"] files = [ {file = "rsa-4.9-py3-none-any.whl", hash = "sha256:90260d9058e514786967344d0ef75fa8727eed8a7d2e43ce9f4bcf1b536174f7"}, {file = "rsa-4.9.tar.gz", hash = "sha256:e38464a49c6c85d7f1351b0126661487a7e0a14a50f1675ec50eb34d4f20ef21"}, @@ -5204,6 +5443,7 @@ version = "0.18.5" description = "ruamel.yaml is a YAML parser/emitter that supports roundtrip preservation of comments, seq/map flow style, and map key order" optional = false python-versions = ">=3.7" +groups = ["dev"] files = [ {file = "ruamel.yaml-0.18.5-py3-none-any.whl", hash = "sha256:a013ac02f99a69cdd6277d9664689eb1acba07069f912823177c5eced21a6ada"}, {file = "ruamel.yaml-0.18.5.tar.gz", hash = "sha256:61917e3a35a569c1133a8f772e1226961bf5a1198bea7e23f06a0841dea1ab0e"}, @@ -5222,6 +5462,8 @@ version = "0.2.8" description = "C version of reader, parser and emitter for ruamel.yaml derived from libyaml" optional = false python-versions = ">=3.6" +groups = ["dev"] +markers = "platform_python_implementation == \"CPython\" and python_version < \"3.13\"" files = [ {file = "ruamel.yaml.clib-0.2.8-cp310-cp310-macosx_10_9_universal2.whl", hash = 
"sha256:b42169467c42b692c19cf539c38d4602069d8c1505e97b86387fcf7afb766e1d"}, {file = "ruamel.yaml.clib-0.2.8-cp310-cp310-macosx_13_0_arm64.whl", hash = "sha256:07238db9cbdf8fc1e9de2489a4f68474e70dffcb32232db7c08fa61ca0c7c462"}, @@ -5281,6 +5523,7 @@ version = "0.0.285" description = "An extremely fast Python linter, written in Rust." optional = false python-versions = ">=3.7" +groups = ["dev"] files = [ {file = "ruff-0.0.285-py3-none-macosx_10_7_x86_64.whl", hash = "sha256:72a3a0936369b986b0e959f9090206ed3c18f9e5e439ea5b8e6867c6707aded5"}, {file = "ruff-0.0.285-py3-none-macosx_10_9_x86_64.macosx_11_0_arm64.macosx_10_9_universal2.whl", hash = "sha256:0d9ab6ad16742eb78919e0fba09f914f042409df40ad63423c34bb20d350162a"}, @@ -5307,6 +5550,7 @@ version = "1.11.4" description = "Fundamental algorithms for scientific computing in Python" optional = false python-versions = ">=3.9" +groups = ["main", "dev"] files = [ {file = "scipy-1.11.4-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:bc9a714581f561af0848e6b69947fda0614915f072dfd14142ed1bfe1b806710"}, {file = "scipy-1.11.4-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:cf00bd2b1b0211888d4dc75656c0412213a8b25e80d73898083f402b50f47e41"}, @@ -5349,6 +5593,8 @@ version = "0.13.2" description = "Statistical data visualization" optional = true python-versions = ">=3.8" +groups = ["main"] +markers = "extra == \"eval\"" files = [ {file = "seaborn-0.13.2-py3-none-any.whl", hash = "sha256:636f8336facf092165e27924f223d3c62ca560b1f2bb5dff7ab7fad265361987"}, {file = "seaborn-0.13.2.tar.gz", hash = "sha256:93e60a40988f4d65e9f4885df477e2fdaff6b73a9ded434c1ab356dd57eefff7"}, @@ -5370,15 +5616,16 @@ version = "1.8.2" description = "Send file to trash natively under Mac OS X, Windows and Linux" optional = false python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,>=2.7" +groups = ["dev"] files = [ {file = "Send2Trash-1.8.2-py3-none-any.whl", hash = "sha256:a384719d99c07ce1eefd6905d2decb6f8b7ed054025bb0e618919f945de4f679"}, {file 
= "Send2Trash-1.8.2.tar.gz", hash = "sha256:c132d59fa44b9ca2b1699af5c86f57ce9f4c5eb56629d5d55fbb7a35f84e2312"}, ] [package.extras] -nativelib = ["pyobjc-framework-Cocoa", "pywin32"] -objc = ["pyobjc-framework-Cocoa"] -win32 = ["pywin32"] +nativelib = ["pyobjc-framework-Cocoa ; sys_platform == \"darwin\"", "pywin32 ; sys_platform == \"win32\""] +objc = ["pyobjc-framework-Cocoa ; sys_platform == \"darwin\""] +win32 = ["pywin32 ; sys_platform == \"win32\""] [[package]] name = "sensai-utils" @@ -5386,6 +5633,7 @@ version = "1.2.1" description = "Utilities from sensAI, the Python library for sensible AI" optional = false python-versions = "*" +groups = ["main"] files = [ {file = "sensai_utils-1.2.1-py3-none-any.whl", hash = "sha256:222e60d9f9d371c9d62ffcd1e6def1186f0d5243588b0b5af57e983beecc95bb"}, {file = "sensai_utils-1.2.1.tar.gz", hash = "sha256:4d8ca94179931798cef5f920fb042cbf9e7d806c0026b02afb58d0f72211bf27"}, @@ -5400,6 +5648,7 @@ version = "2.8.0" description = "Python client for Sentry (https://sentry.io)" optional = false python-versions = ">=3.6" +groups = ["dev"] files = [ {file = "sentry_sdk-2.8.0-py2.py3-none-any.whl", hash = "sha256:6051562d2cfa8087bb8b4b8b79dc44690f8a054762a29c07e22588b1f619bfb5"}, {file = "sentry_sdk-2.8.0.tar.gz", hash = "sha256:aa4314f877d9cd9add5a0c9ba18e3f27f99f7de835ce36bd150e48a41c7c646f"}, @@ -5450,6 +5699,7 @@ version = "1.3.3" description = "A Python module to customize the process title" optional = false python-versions = ">=3.7" +groups = ["dev"] files = [ {file = "setproctitle-1.3.3-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:897a73208da48db41e687225f355ce993167079eda1260ba5e13c4e53be7f754"}, {file = "setproctitle-1.3.3-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:8c331e91a14ba4076f88c29c777ad6b58639530ed5b24b5564b5ed2fd7a95452"}, @@ -5550,6 +5800,7 @@ version = "68.2.2" description = "Easily download, build, install, upgrade, and uninstall Python packages" optional = false python-versions = ">=3.8" 
+groups = ["main", "dev"] files = [ {file = "setuptools-68.2.2-py3-none-any.whl", hash = "sha256:b454a35605876da60632df1a60f736524eb73cc47bbc9f3f1ef1b644de74fd2a"}, {file = "setuptools-68.2.2.tar.gz", hash = "sha256:4ac1475276d2f1c48684874089fefcd83bd7162ddaafb81fac866ba0db282a87"}, @@ -5557,7 +5808,7 @@ files = [ [package.extras] docs = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "pygments-github-lexers (==0.0.5)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-favicon", "sphinx-hoverxref (<2)", "sphinx-inline-tabs", "sphinx-lint", "sphinx-notfound-page (>=1,<2)", "sphinx-reredirects", "sphinxcontrib-towncrier"] -testing = ["build[virtualenv]", "filelock (>=3.4.0)", "flake8-2020", "ini2toml[lite] (>=0.9)", "jaraco.develop (>=7.21)", "jaraco.envs (>=2.2)", "jaraco.path (>=3.2.0)", "pip (>=19.1)", "pytest (>=6)", "pytest-black (>=0.3.7)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-mypy (>=0.9.1)", "pytest-perf", "pytest-ruff", "pytest-timeout", "pytest-xdist", "tomli-w (>=1.0.0)", "virtualenv (>=13.0.0)", "wheel"] +testing = ["build[virtualenv]", "filelock (>=3.4.0)", "flake8-2020", "ini2toml[lite] (>=0.9)", "jaraco.develop (>=7.21) ; python_version >= \"3.9\" and sys_platform != \"cygwin\"", "jaraco.envs (>=2.2)", "jaraco.path (>=3.2.0)", "pip (>=19.1)", "pytest (>=6)", "pytest-black (>=0.3.7) ; platform_python_implementation != \"PyPy\"", "pytest-checkdocs (>=2.4)", "pytest-cov ; platform_python_implementation != \"PyPy\"", "pytest-enabler (>=2.2)", "pytest-mypy (>=0.9.1) ; platform_python_implementation != \"PyPy\"", "pytest-perf ; sys_platform != \"cygwin\"", "pytest-ruff ; sys_platform != \"cygwin\"", "pytest-timeout", "pytest-xdist", "tomli-w (>=1.0.0)", "virtualenv (>=13.0.0)", "wheel"] testing-integration = ["build[virtualenv] (>=1.0.3)", "filelock (>=3.4.0)", "jaraco.envs (>=2.2)", "jaraco.path (>=3.2.0)", "packaging (>=23.1)", "pytest", "pytest-enabler", "pytest-xdist", "tomli", "virtualenv (>=13.0.0)", 
"wheel"] [[package]] @@ -5566,6 +5817,8 @@ version = "0.2.1" description = "API for converting popular non-gymnasium environments to a gymnasium compatible environment." optional = true python-versions = ">=3.7" +groups = ["main"] +markers = "extra == \"atari\"" files = [ {file = "Shimmy-0.2.1-py3-none-any.whl", hash = "sha256:2d7d21c4ca679a64bb452e6a4232c6b0f5dba7589f5420454ddc1f0634334334"}, {file = "Shimmy-0.2.1.tar.gz", hash = "sha256:7b96915445ee5488dcb19ccf52ce5581d6f00cc5cf0e0dff36b16cd65bffcb75"}, @@ -5590,6 +5843,7 @@ version = "1.0.11" description = "A generator library for concise, unambiguous and URL-safe UUIDs." optional = false python-versions = ">=3.5" +groups = ["dev"] files = [ {file = "shortuuid-1.0.11-py3-none-any.whl", hash = "sha256:27ea8f28b1bd0bf8f15057a3ece57275d2059d2b0bb02854f02189962c13b6aa"}, {file = "shortuuid-1.0.11.tar.gz", hash = "sha256:fc75f2615914815a8e4cb1501b3a513745cb66ef0fd5fc6fb9f8c3fa3481f789"}, @@ -5601,6 +5855,7 @@ version = "1.16.0" description = "Python 2 and 3 compatibility utilities" optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*" +groups = ["main", "dev"] files = [ {file = "six-1.16.0-py2.py3-none-any.whl", hash = "sha256:8abb2f1d86890a2dfb989f9a77cfcfd3e47c2a354b01111771326f8aa26e0254"}, {file = "six-1.16.0.tar.gz", hash = "sha256:1e61c37477a1626458e36f7b1d82aa5c9b094fa4802892072e49de9c60c4c926"}, @@ -5612,6 +5867,7 @@ version = "5.0.1" description = "A pure Python implementation of a sliding window memory map manager" optional = false python-versions = ">=3.7" +groups = ["dev"] files = [ {file = "smmap-5.0.1-py3-none-any.whl", hash = "sha256:e6d8668fa5f93e706934a62d7b4db19c8d9eb8cf2adbb75ef1b675aa332b69da"}, {file = "smmap-5.0.1.tar.gz", hash = "sha256:dceeb6c0028fdb6734471eb07c0cd2aae706ccaecab45965ee83f11c8d3b1f62"}, @@ -5623,6 +5879,7 @@ version = "1.3.0" description = "Sniff out which async library your code is running under" optional = false python-versions = ">=3.7" +groups = ["dev"] 
files = [ {file = "sniffio-1.3.0-py3-none-any.whl", hash = "sha256:eecefdce1e5bbfb7ad2eeaabf7c1eeb404d7757c379bd1f7e5cce9d8bf425384"}, {file = "sniffio-1.3.0.tar.gz", hash = "sha256:e60305c5e5d314f5389259b7f22aaa33d8f7dee49763119234af3755c55b9101"}, @@ -5634,6 +5891,7 @@ version = "2.2.0" description = "This package provides 29 stemmers for 28 languages generated from Snowball algorithms." optional = false python-versions = "*" +groups = ["dev"] files = [ {file = "snowballstemmer-2.2.0-py2.py3-none-any.whl", hash = "sha256:c8e1716e83cc398ae16824e5572ae04e0d9fc2c6b985fb0f900f5f0c96ecba1a"}, {file = "snowballstemmer-2.2.0.tar.gz", hash = "sha256:09b16deb8547d3412ad7b590689584cd0fe25ec8db3be37788be3810cbf19cb1"}, @@ -5645,6 +5903,7 @@ version = "2.5" description = "A modern CSS selector implementation for Beautiful Soup." optional = false python-versions = ">=3.8" +groups = ["dev"] files = [ {file = "soupsieve-2.5-py3-none-any.whl", hash = "sha256:eaa337ff55a1579b6549dc679565eac1e3d000563bcb1c8ab0d0fefbc0c2cdc7"}, {file = "soupsieve-2.5.tar.gz", hash = "sha256:5663d5a7b3bfaeee0bc4372e7fc48f9cff4940b3eec54a6451cc5299f1097690"}, @@ -5656,6 +5915,7 @@ version = "7.2.6" description = "Python documentation generator" optional = false python-versions = ">=3.9" +groups = ["dev"] files = [ {file = "sphinx-7.2.6-py3-none-any.whl", hash = "sha256:1e09160a40b956dc623c910118fa636da93bd3ca0b9876a7b3df90f07d691560"}, {file = "sphinx-7.2.6.tar.gz", hash = "sha256:9a5160e1ea90688d5963ba09a2dcd8bdd526620edbb65c328728f1b2228d5ab5"}, @@ -5690,6 +5950,7 @@ version = "1.19.1" description = "Type hints (PEP 484) support for the Sphinx autodoc extension" optional = false python-versions = ">=3.7" +groups = ["dev"] files = [ {file = "sphinx_autodoc_typehints-1.19.1-py3-none-any.whl", hash = "sha256:9be46aeeb1b315eb5df1f3a7cb262149895d16c7d7dcd77b92513c3c3a1e85e6"}, {file = "sphinx_autodoc_typehints-1.19.1.tar.gz", hash = 
"sha256:6c841db55e0e9be0483ff3962a2152b60e79306f4288d8c4e7e86ac84486a5ea"}, @@ -5700,7 +5961,7 @@ Sphinx = ">=4.5" [package.extras] testing = ["covdefaults (>=2.2)", "coverage (>=6.3)", "diff-cover (>=6.4)", "nptyping (>=2.1.2)", "pytest (>=7.1)", "pytest-cov (>=3)", "sphobjinv (>=2)", "typing-extensions (>=4.1)"] -type-comments = ["typed-ast (>=1.5.2)"] +type-comments = ["typed-ast (>=1.5.2) ; python_version < \"3.8\""] [[package]] name = "sphinx-book-theme" @@ -5708,6 +5969,7 @@ version = "1.1.0" description = "A clean book theme for scientific explanations and documentation with Sphinx" optional = false python-versions = ">=3.9" +groups = ["dev"] files = [ {file = "sphinx_book_theme-1.1.0-py3-none-any.whl", hash = "sha256:088bc69d65fab8446adb8691ed61687f71bf7504c9740af68bc78cf936a26112"}, {file = "sphinx_book_theme-1.1.0.tar.gz", hash = "sha256:ad4f92998e53e24751ecd0978d3eb79fdaa59692f005b1b286ecdd6146ebc9c1"}, @@ -5728,6 +5990,7 @@ version = "0.0.3" description = "Add comments and annotation to your documentation." optional = false python-versions = "*" +groups = ["dev"] files = [ {file = "sphinx-comments-0.0.3.tar.gz", hash = "sha256:00170afff27019fad08e421da1ae49c681831fb2759786f07c826e89ac94cf21"}, {file = "sphinx_comments-0.0.3-py3-none-any.whl", hash = "sha256:1e879b4e9bfa641467f83e3441ac4629225fc57c29995177d043252530c21d00"}, @@ -5747,6 +6010,7 @@ version = "0.5.2" description = "Add a copy button to each of your code cells." optional = false python-versions = ">=3.7" +groups = ["dev"] files = [ {file = "sphinx-copybutton-0.5.2.tar.gz", hash = "sha256:4cf17c82fb9646d1bc9ca92ac280813a3b605d8c421225fd9913154103ee1fbd"}, {file = "sphinx_copybutton-0.5.2-py3-none-any.whl", hash = "sha256:fb543fd386d917746c9a2c50360c7905b605726b9355cd26e9974857afeae06e"}, @@ -5765,6 +6029,7 @@ version = "0.5.0" description = "A sphinx extension for designing beautiful, view size responsive web components." 
optional = false python-versions = ">=3.8" +groups = ["dev"] files = [ {file = "sphinx_design-0.5.0-py3-none-any.whl", hash = "sha256:1af1267b4cea2eedd6724614f19dcc88fe2e15aff65d06b2f6252cee9c4f4c1e"}, {file = "sphinx_design-0.5.0.tar.gz", hash = "sha256:e8e513acea6f92d15c6de3b34e954458f245b8e761b45b63950f65373352ab00"}, @@ -5788,6 +6053,7 @@ version = "1.0.1" description = "A sphinx extension that allows the site-map to be defined in a single YAML file." optional = false python-versions = ">=3.9" +groups = ["dev"] files = [ {file = "sphinx_external_toc-1.0.1-py3-none-any.whl", hash = "sha256:d9e02d50731dee9697c1887e4f8b361e7b86d38241f0e66bd5a9f4096779646f"}, {file = "sphinx_external_toc-1.0.1.tar.gz", hash = "sha256:a7d2c63cc47ec688546443b28bc4ef466121827ef3dc7bb509de354bad4ea2e0"}, @@ -5809,6 +6075,7 @@ version = "0.2.0.post1" description = "Patches Jinja2 v3 to restore compatibility with earlier Sphinx versions." optional = false python-versions = ">=3.6" +groups = ["dev"] files = [ {file = "sphinx_jinja2_compat-0.2.0.post1-py3-none-any.whl", hash = "sha256:f9d329174bdde8db19dc12c62528367196eb2f6b46c91754eca604acd0c0f6ad"}, {file = "sphinx_jinja2_compat-0.2.0.post1.tar.gz", hash = "sha256:974289a12a9f402108dead621e9c15f7004e945d5cfcaea8d6419e94d3fa95a3"}, @@ -5824,6 +6091,7 @@ version = "1.0.0" description = "Latex specific features for jupyter book" optional = false python-versions = ">=3.9" +groups = ["dev"] files = [ {file = "sphinx_jupyterbook_latex-1.0.0-py3-none-any.whl", hash = "sha256:e0cd3e9e1c5af69136434e21a533343fdf013475c410a414d5b7b4922b4f3891"}, {file = "sphinx_jupyterbook_latex-1.0.0.tar.gz", hash = "sha256:f54c6674c13f1616f9a93443e98b9b5353f9fdda8e39b6ec552ccf0b3e5ffb62"}, @@ -5845,6 +6113,7 @@ version = "0.1.3" description = "Supporting continuous HTML section numbering" optional = false python-versions = "*" +groups = ["dev"] files = [ {file = "sphinx-multitoc-numbering-0.1.3.tar.gz", hash = 
"sha256:c9607671ac511236fa5d61a7491c1031e700e8d498c9d2418e6c61d1251209ae"}, {file = "sphinx_multitoc_numbering-0.1.3-py3-none-any.whl", hash = "sha256:33d2e707a9b2b8ad636b3d4302e658a008025106fe0474046c651144c26d8514"}, @@ -5864,6 +6133,7 @@ version = "1.5.0" description = "Sphinx directive to add unselectable prompt" optional = false python-versions = "*" +groups = ["dev"] files = [ {file = "sphinx_prompt-1.5.0-py3-none-any.whl", hash = "sha256:fa4e90d8088b5a996c76087d701fc7e31175f8b9dc4aab03a507e45051067162"}, ] @@ -5878,6 +6148,7 @@ version = "3.4.5" description = "Tabbed views for Sphinx" optional = false python-versions = "~=3.7" +groups = ["dev"] files = [ {file = "sphinx-tabs-3.4.5.tar.gz", hash = "sha256:ba9d0c1e3e37aaadd4b5678449eb08176770e0fc227e769b6ce747df3ceea531"}, {file = "sphinx_tabs-3.4.5-py3-none-any.whl", hash = "sha256:92cc9473e2ecf1828ca3f6617d0efc0aa8acb06b08c56ba29d1413f2f0f6cf09"}, @@ -5898,6 +6169,7 @@ version = "0.3.1" description = "Integrate interactive code blocks into your documentation with Thebe and Binder." optional = false python-versions = ">=3.8" +groups = ["dev"] files = [ {file = "sphinx_thebe-0.3.1-py3-none-any.whl", hash = "sha256:e7e7edee9f0d601c76bc70156c471e114939484b111dd8e74fe47ac88baffc52"}, {file = "sphinx_thebe-0.3.1.tar.gz", hash = "sha256:576047f45560e82f64aa5f15200b1eb094dcfe1c5b8f531a8a65bd208e25a493"}, @@ -5917,6 +6189,7 @@ version = "0.3.2" description = "Toggle page content and collapse admonitions in Sphinx." 
optional = false python-versions = "*" +groups = ["dev"] files = [ {file = "sphinx-togglebutton-0.3.2.tar.gz", hash = "sha256:ab0c8b366427b01e4c89802d5d078472c427fa6e9d12d521c34fa0442559dc7a"}, {file = "sphinx_togglebutton-0.3.2-py3-none-any.whl", hash = "sha256:9647ba7874b7d1e2d43413d8497153a85edc6ac95a3fea9a75ef9c1e08aaae2b"}, @@ -5937,6 +6210,7 @@ version = "3.5.0" description = "Box of handy tools for Sphinx 🧰 📔" optional = false python-versions = ">=3.7" +groups = ["dev"] files = [ {file = "sphinx_toolbox-3.5.0-py3-none-any.whl", hash = "sha256:20dfd3566717db6f2da7a400a54dc4b946f064fb31250fa44802d54cfb9b8a03"}, {file = "sphinx_toolbox-3.5.0.tar.gz", hash = "sha256:e5b5a7153f1997572d71a06aaf6cec225483492ec2c60097a84f15aad6df18b7"}, @@ -5971,6 +6245,7 @@ version = "1.0.7" description = "sphinxcontrib-applehelp is a Sphinx extension which outputs Apple help books" optional = false python-versions = ">=3.9" +groups = ["dev"] files = [ {file = "sphinxcontrib_applehelp-1.0.7-py3-none-any.whl", hash = "sha256:094c4d56209d1734e7d252f6e0b3ccc090bd52ee56807a5d9315b19c122ab15d"}, {file = "sphinxcontrib_applehelp-1.0.7.tar.gz", hash = "sha256:39fdc8d762d33b01a7d8f026a3b7d71563ea3b72787d5f00ad8465bd9d6dfbfa"}, @@ -5989,6 +6264,7 @@ version = "2.5.0" description = "Sphinx extension for BibTeX style citations." 
optional = false python-versions = ">=3.6" +groups = ["dev"] files = [ {file = "sphinxcontrib-bibtex-2.5.0.tar.gz", hash = "sha256:71b42e5db0e2e284f243875326bf9936aa9a763282277d75048826fef5b00eaa"}, {file = "sphinxcontrib_bibtex-2.5.0-py3-none-any.whl", hash = "sha256:748f726eaca6efff7731012103417ef130ecdcc09501b4d0c54283bf5f059f76"}, @@ -6006,6 +6282,7 @@ version = "1.0.5" description = "sphinxcontrib-devhelp is a sphinx extension which outputs Devhelp documents" optional = false python-versions = ">=3.9" +groups = ["dev"] files = [ {file = "sphinxcontrib_devhelp-1.0.5-py3-none-any.whl", hash = "sha256:fe8009aed765188f08fcaadbb3ea0d90ce8ae2d76710b7e29ea7d047177dae2f"}, {file = "sphinxcontrib_devhelp-1.0.5.tar.gz", hash = "sha256:63b41e0d38207ca40ebbeabcf4d8e51f76c03e78cd61abe118cf4435c73d4212"}, @@ -6024,6 +6301,7 @@ version = "2.0.4" description = "sphinxcontrib-htmlhelp is a sphinx extension which renders HTML help files" optional = false python-versions = ">=3.9" +groups = ["dev"] files = [ {file = "sphinxcontrib_htmlhelp-2.0.4-py3-none-any.whl", hash = "sha256:8001661c077a73c29beaf4a79968d0726103c5605e27db92b9ebed8bab1359e9"}, {file = "sphinxcontrib_htmlhelp-2.0.4.tar.gz", hash = "sha256:6c26a118a05b76000738429b724a0568dbde5b72391a688577da08f11891092a"}, @@ -6042,6 +6320,7 @@ version = "1.0.1" description = "A sphinx extension which renders display math in HTML via JavaScript" optional = false python-versions = ">=3.5" +groups = ["dev"] files = [ {file = "sphinxcontrib-jsmath-1.0.1.tar.gz", hash = "sha256:a9925e4a4587247ed2191a22df5f6970656cb8ca2bd6284309578f2153e0c4b8"}, {file = "sphinxcontrib_jsmath-1.0.1-py2.py3-none-any.whl", hash = "sha256:2ec2eaebfb78f3f2078e73666b1415417a116cc848b72e5172e596c871103178"}, @@ -6056,6 +6335,7 @@ version = "1.0.6" description = "sphinxcontrib-qthelp is a sphinx extension which outputs QtHelp documents" optional = false python-versions = ">=3.9" +groups = ["dev"] files = [ {file = 
"sphinxcontrib_qthelp-1.0.6-py3-none-any.whl", hash = "sha256:bf76886ee7470b934e363da7a954ea2825650013d367728588732c7350f49ea4"}, {file = "sphinxcontrib_qthelp-1.0.6.tar.gz", hash = "sha256:62b9d1a186ab7f5ee3356d906f648cacb7a6bdb94d201ee7adf26db55092982d"}, @@ -6074,6 +6354,7 @@ version = "1.1.9" description = "sphinxcontrib-serializinghtml is a sphinx extension which outputs \"serialized\" HTML files (json and pickle)" optional = false python-versions = ">=3.9" +groups = ["dev"] files = [ {file = "sphinxcontrib_serializinghtml-1.1.9-py3-none-any.whl", hash = "sha256:9b36e503703ff04f20e9675771df105e58aa029cfcbc23b8ed716019b7416ae1"}, {file = "sphinxcontrib_serializinghtml-1.1.9.tar.gz", hash = "sha256:0c64ff898339e1fac29abd2bf5f11078f3ec413cfe9c046d3120d7ca65530b54"}, @@ -6092,6 +6373,7 @@ version = "8.0.0" description = "Sphinx spelling extension" optional = false python-versions = ">=3.7" +groups = ["dev"] files = [ {file = "sphinxcontrib-spelling-8.0.0.tar.gz", hash = "sha256:199d0a16902ad80c387c2966dc9eb10f565b1fb15ccce17210402db7c2443e5c"}, {file = "sphinxcontrib_spelling-8.0.0-py3-none-any.whl", hash = "sha256:b27e0a16aef00bcfc888a6490dc3f16651f901dc475446c6882834278c8dc7b3"}, @@ -6110,6 +6392,7 @@ version = "2.0.23" description = "Database Abstraction Library" optional = false python-versions = ">=3.7" +groups = ["dev"] files = [ {file = "SQLAlchemy-2.0.23-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:638c2c0b6b4661a4fd264f6fb804eccd392745c5887f9317feb64bb7cb03b3ea"}, {file = "SQLAlchemy-2.0.23-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:e3b5036aa326dc2df50cba3c958e29b291a80f604b1afa4c8ce73e78e1c9f01d"}, @@ -6197,6 +6480,7 @@ version = "0.6.3" description = "Extract data from python stack frames and tracebacks for informative displays" optional = false python-versions = "*" +groups = ["dev"] files = [ {file = "stack_data-0.6.3-py3-none-any.whl", hash = "sha256:d5558e0c25a4cb0853cddad3d77da9891a08cb85dd9f9f91b9f8cd66e511e695"}, {file = 
"stack_data-0.6.3.tar.gz", hash = "sha256:836a778de4fec4dcd1dcd89ed8abff8a221f58308462e1c4aa2a3cf30148f0b9"}, @@ -6216,6 +6500,8 @@ version = "0.14.0" description = "Statistical computations and models for Python" optional = true python-versions = ">=3.8" +groups = ["main"] +markers = "extra == \"eval\"" files = [ {file = "statsmodels-0.14.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:16bfe0c96a53b20fa19067e3b6bd2f1d39e30d4891ea0d7bc20734a0ae95942d"}, {file = "statsmodels-0.14.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:5a6a0a1a06ff79be8aa89c8494b33903442859add133f0dda1daf37c3c71682e"}, @@ -6257,7 +6543,7 @@ scipy = ">=1.4,<1.9.2 || >1.9.2" [package.extras] build = ["cython (>=0.29.26)"] -develop = ["colorama", "cython (>=0.29.26)", "cython (>=0.29.28,<3.0.0)", "flake8", "isort", "joblib", "matplotlib (>=3)", "oldest-supported-numpy (>=2022.4.18)", "pytest (>=7.0.1,<7.1.0)", "pytest-randomly", "pytest-xdist", "pywinpty", "setuptools-scm[toml] (>=7.0.0,<7.1.0)"] +develop = ["colorama", "cython (>=0.29.26)", "cython (>=0.29.28,<3.0.0)", "flake8", "isort", "joblib", "matplotlib (>=3)", "oldest-supported-numpy (>=2022.4.18)", "pytest (>=7.0.1,<7.1.0)", "pytest-randomly", "pytest-xdist", "pywinpty ; os_name == \"nt\"", "setuptools-scm[toml] (>=7.0.0,<7.1.0)"] docs = ["ipykernel", "jupyter-client", "matplotlib", "nbconvert", "nbformat", "numpydoc", "pandas-datareader", "sphinx"] [[package]] @@ -6266,6 +6552,8 @@ version = "4.2.0" description = "SWIG is a software development tool that connects programs written in C and C++ with a variety of high-level programming languages." 
optional = true python-versions = "*" +groups = ["main"] +markers = "extra == \"box2d\"" files = [ {file = "swig-4.2.0-py2.py3-none-macosx_10_9_universal2.whl", hash = "sha256:71bf282fb30aa179b870e29c8f4fe16b3404e8562377061f85d57a2ec1571d7c"}, {file = "swig-4.2.0-py2.py3-none-macosx_10_9_x86_64.whl", hash = "sha256:071c7a3af61c2c69d1e911c5428479a4536a8103623276847d8e55350da8cf05"}, @@ -6291,6 +6579,7 @@ version = "1.12" description = "Computer algebra system (CAS) in Python" optional = false python-versions = ">=3.8" +groups = ["main"] files = [ {file = "sympy-1.12-py3-none-any.whl", hash = "sha256:c3588cd4295d0c0f603d0f2ae780587e64e2efeedb3521e46b9bb1d08d184fa5"}, {file = "sympy-1.12.tar.gz", hash = "sha256:ebf595c8dac3e0fdc4152c51878b498396ec7f30e7a914d6071e674d49420fb8"}, @@ -6305,6 +6594,7 @@ version = "0.9.0" description = "Pretty-print tabular data" optional = false python-versions = ">=3.7" +groups = ["dev"] files = [ {file = "tabulate-0.9.0-py3-none-any.whl", hash = "sha256:024ca478df22e9340661486f85298cff5f6dcdba14f3813e8830015b9ed1948f"}, {file = "tabulate-0.9.0.tar.gz", hash = "sha256:0095b12bf5966de529c0feb1fa08671671b3368eec77d7ef7ab114be2c068b3c"}, @@ -6319,6 +6609,7 @@ version = "2.15.1" description = "TensorBoard lets you watch Tensors Flow" optional = false python-versions = ">=3.9" +groups = ["main"] files = [ {file = "tensorboard-2.15.1-py3-none-any.whl", hash = "sha256:c46c1d1cf13a458c429868a78b2531d8ff5f682058d69ec0840b0bc7a38f1c0f"}, ] @@ -6343,6 +6634,7 @@ version = "0.7.2" description = "Fast data loading for TensorBoard" optional = false python-versions = ">=3.7" +groups = ["main"] files = [ {file = "tensorboard_data_server-0.7.2-py3-none-any.whl", hash = "sha256:7e0610d205889588983836ec05dc098e80f97b7e7bbff7e994ebb78f578d0ddb"}, {file = "tensorboard_data_server-0.7.2-py3-none-macosx_10_9_x86_64.whl", hash = "sha256:9fe5d24221b29625dbc7328b0436ca7fc1c23de4acf4d272f1180856e32f9f60"}, @@ -6355,6 +6647,7 @@ version = "0.18.0" description = 
"Tornado websocket backend for the Xterm.js Javascript terminal emulator library." optional = false python-versions = ">=3.8" +groups = ["dev"] files = [ {file = "terminado-0.18.0-py3-none-any.whl", hash = "sha256:87b0d96642d0fe5f5abd7783857b9cab167f221a39ff98e3b9619a788a3c0f2e"}, {file = "terminado-0.18.0.tar.gz", hash = "sha256:1ea08a89b835dd1b8c0c900d92848147cef2537243361b2e3f4dc15df9b6fded"}, @@ -6376,6 +6669,7 @@ version = "1.2.1" description = "A tiny CSS parser" optional = false python-versions = ">=3.7" +groups = ["dev"] files = [ {file = "tinycss2-1.2.1-py3-none-any.whl", hash = "sha256:2b80a96d41e7c3914b8cda8bc7f705a4d9c49275616e886103dd839dfc847847"}, {file = "tinycss2-1.2.1.tar.gz", hash = "sha256:8cff3a8f066c2ec677c06dbc7b45619804a6938478d9d73c284b29d14ecb0627"}, @@ -6394,6 +6688,7 @@ version = "5.2.0" description = "A wrapper around the stdlib `tokenize` which roundtrips." optional = false python-versions = ">=3.8" +groups = ["dev"] files = [ {file = "tokenize_rt-5.2.0-py2.py3-none-any.whl", hash = "sha256:b79d41a65cfec71285433511b50271b05da3584a1da144a0752e9c621a285289"}, {file = "tokenize_rt-5.2.0.tar.gz", hash = "sha256:9fe80f8a5c1edad2d3ede0f37481cc0cc1538a2f442c9c2f9e4feacd2792d054"}, @@ -6405,6 +6700,7 @@ version = "2.0.1" description = "A lil' TOML parser" optional = false python-versions = ">=3.7" +groups = ["dev"] files = [ {file = "tomli-2.0.1-py3-none-any.whl", hash = "sha256:939de3e7a6161af0c887ef91b7d41a53e7c5a1ca976325f429cb46ea9bc30ecc"}, {file = "tomli-2.0.1.tar.gz", hash = "sha256:de526c12914f0c550d15924c62d72abc48d6fe7364aa87328337a31007fe8a4f"}, @@ -6416,6 +6712,7 @@ version = "2.1.1" description = "Tensors and Dynamic neural networks in Python with strong GPU acceleration" optional = false python-versions = ">=3.8.0" +groups = ["main"] files = [ {file = "torch-2.1.1-cp310-cp310-manylinux1_x86_64.whl", hash = "sha256:5ebc43f5355a9b7be813392b3fb0133991f0380f6f0fcc8218d5468dc45d1071"}, {file = 
"torch-2.1.1-cp310-cp310-manylinux2014_aarch64.whl", hash = "sha256:84fefd63356416c0cd20578637ccdbb82164993400ed17b57c951dd6376dcee8"}, @@ -6469,6 +6766,7 @@ version = "6.4.2" description = "Tornado is a Python web framework and asynchronous networking library, originally developed at FriendFeed." optional = false python-versions = ">=3.8" +groups = ["dev"] files = [ {file = "tornado-6.4.2-cp38-abi3-macosx_10_9_universal2.whl", hash = "sha256:e828cce1123e9e44ae2a50a9de3055497ab1d0aeb440c5ac23064d9e44880da1"}, {file = "tornado-6.4.2-cp38-abi3-macosx_10_9_x86_64.whl", hash = "sha256:072ce12ada169c5b00b7d92a99ba089447ccc993ea2143c9ede887e0937aa803"}, @@ -6489,6 +6787,7 @@ version = "4.66.3" description = "Fast, Extensible Progress Meter" optional = false python-versions = ">=3.7" +groups = ["main"] files = [ {file = "tqdm-4.66.3-py3-none-any.whl", hash = "sha256:4f41d54107ff9a223dca80b53efe4fb654c67efaba7f47bada3ee9d50e05bd53"}, {file = "tqdm-4.66.3.tar.gz", hash = "sha256:23097a41eba115ba99ecae40d06444c15d1c0c698d527a01c6c8bd1c5d0647e5"}, @@ -6509,6 +6808,7 @@ version = "5.13.0" description = "Traitlets Python configuration system" optional = false python-versions = ">=3.8" +groups = ["dev"] files = [ {file = "traitlets-5.13.0-py3-none-any.whl", hash = "sha256:baf991e61542da48fe8aef8b779a9ea0aa38d8a54166ee250d5af5ecf4486619"}, {file = "traitlets-5.13.0.tar.gz", hash = "sha256:9b232b9430c8f57288c1024b34a8f0251ddcc47268927367a0dd3eeaca40deb5"}, @@ -6524,6 +6824,8 @@ version = "2.1.0" description = "A language and compiler for custom Deep Learning operations" optional = false python-versions = "*" +groups = ["main"] +markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\"" files = [ {file = "triton-2.1.0-0-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:66439923a30d5d48399b08a9eae10370f6c261a5ec864a64983bae63152d39d7"}, {file = "triton-2.1.0-0-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = 
"sha256:919b06453f0033ea52c13eaf7833de0e57db3178d23d4e04f9fc71c4f2c32bf8"}, @@ -6549,6 +6851,8 @@ version = "4.24.0.4" description = "Typing stubs for protobuf" optional = true python-versions = ">=3.7" +groups = ["main"] +markers = "sys_platform != \"darwin\" and extra == \"envpool\"" files = [ {file = "types-protobuf-4.24.0.4.tar.gz", hash = "sha256:57ab42cb171dfdba2c74bb5b50c250478538cc3c5ed95b8b368929ad0c9f90a5"}, {file = "types_protobuf-4.24.0.4-py3-none-any.whl", hash = "sha256:131ab7d0cbc9e444bc89c994141327dcce7bcaeded72b1acb72a94827eb9c7af"}, @@ -6560,6 +6864,7 @@ version = "2.8.19.14" description = "Typing stubs for python-dateutil" optional = false python-versions = "*" +groups = ["dev"] files = [ {file = "types-python-dateutil-2.8.19.14.tar.gz", hash = "sha256:1f4f10ac98bb8b16ade9dbee3518d9ace017821d94b057a425b069f834737f4b"}, {file = "types_python_dateutil-2.8.19.14-py3-none-any.whl", hash = "sha256:f977b8de27787639986b4e28963263fd0e5158942b3ecef91b9335c130cb1ce9"}, @@ -6571,6 +6876,7 @@ version = "2.31.0.20240311" description = "Typing stubs for requests" optional = false python-versions = ">=3.8" +groups = ["dev"] files = [ {file = "types-requests-2.31.0.20240311.tar.gz", hash = "sha256:b1c1b66abfb7fa79aae09097a811c4aa97130eb8831c60e47aee4ca344731ca5"}, {file = "types_requests-2.31.0.20240311-py3-none-any.whl", hash = "sha256:47872893d65a38e282ee9f277a4ee50d1b28bd592040df7d1fdaffdf3779937d"}, @@ -6585,6 +6891,7 @@ version = "0.9.0.20240106" description = "Typing stubs for tabulate" optional = false python-versions = ">=3.8" +groups = ["dev"] files = [ {file = "types-tabulate-0.9.0.20240106.tar.gz", hash = "sha256:c9b6db10dd7fcf55bd1712dd3537f86ddce72a08fd62bb1af4338c7096ce947e"}, {file = "types_tabulate-0.9.0.20240106-py3-none-any.whl", hash = "sha256:0378b7b6fe0ccb4986299496d027a6d4c218298ecad67199bbd0e2d7e9d335a1"}, @@ -6596,6 +6903,7 @@ version = "4.8.0" description = "Backported and Experimental Type Hints for Python 3.8+" optional = false 
python-versions = ">=3.8" +groups = ["main", "dev"] files = [ {file = "typing_extensions-4.8.0-py3-none-any.whl", hash = "sha256:8f92fc8806f9a6b641eaa5318da32b44d401efaac0f6678c9bc448ba3605faa0"}, {file = "typing_extensions-4.8.0.tar.gz", hash = "sha256:df8e4339e9cb77357558cbdbceca33c303714cf861d1eef15e1070055ae8b7ef"}, @@ -6607,6 +6915,7 @@ version = "2024.1" description = "Provider of IANA time zone data" optional = false python-versions = ">=2" +groups = ["main"] files = [ {file = "tzdata-2024.1-py2.py3-none-any.whl", hash = "sha256:9068bc196136463f5245e51efda838afa15aaeca9903f49050dfa2679db4d252"}, {file = "tzdata-2024.1.tar.gz", hash = "sha256:2674120f8d891909751c38abcdfd386ac0a5a1127954fbc332af6b5ceae07efd"}, @@ -6618,6 +6927,7 @@ version = "1.0.2" description = "Micro subset of unicode data files for linkify-it-py projects." optional = false python-versions = ">=3.7" +groups = ["dev"] files = [ {file = "uc-micro-py-1.0.2.tar.gz", hash = "sha256:30ae2ac9c49f39ac6dce743bd187fcd2b574b16ca095fa74cd9396795c954c54"}, {file = "uc_micro_py-1.0.2-py3-none-any.whl", hash = "sha256:8c9110c309db9d9e87302e2f4ad2c3152770930d88ab385cd544e7a7e75f3de0"}, @@ -6632,6 +6942,7 @@ version = "1.3.0" description = "RFC 6570 URI Template Processor" optional = false python-versions = ">=3.7" +groups = ["dev"] files = [ {file = "uri-template-1.3.0.tar.gz", hash = "sha256:0e00f8eb65e18c7de20d595a14336e9f337ead580c70934141624b6d1ffdacc7"}, {file = "uri_template-1.3.0-py3-none-any.whl", hash = "sha256:a44a133ea12d44a0c0f06d7d42a52d71282e77e2f937d8abd5655b8d56fc1363"}, @@ -6646,13 +6957,14 @@ version = "2.2.2" description = "HTTP library with thread-safe connection pooling, file post, and more." 
optional = false python-versions = ">=3.8" +groups = ["main", "dev"] files = [ {file = "urllib3-2.2.2-py3-none-any.whl", hash = "sha256:a448b2f64d686155468037e1ace9f2d2199776e17f0a46610480d311f73e3472"}, {file = "urllib3-2.2.2.tar.gz", hash = "sha256:dd505485549a7a552833da5e6063639d0d177c04f23bc3864e41e5dc5f612168"}, ] [package.extras] -brotli = ["brotli (>=1.0.9)", "brotlicffi (>=0.8.0)"] +brotli = ["brotli (>=1.0.9) ; platform_python_implementation == \"CPython\"", "brotlicffi (>=0.8.0) ; platform_python_implementation != \"CPython\""] h2 = ["h2 (>=4,<5)"] socks = ["pysocks (>=1.5.6,!=1.5.7,<2.0)"] zstd = ["zstandard (>=0.18.0)"] @@ -6663,6 +6975,8 @@ version = "20.16.3" description = "Virtual Python Environment builder" optional = false python-versions = ">=3.6" +groups = ["main", "dev"] +markers = "sys_platform == \"win32\"" files = [ {file = "virtualenv-20.16.3-py2.py3-none-any.whl", hash = "sha256:4193b7bc8a6cd23e4eb251ac64f29b4398ab2c233531e66e40b19a6b7b0d30c1"}, {file = "virtualenv-20.16.3.tar.gz", hash = "sha256:d86ea0bb50e06252d79e6c241507cb904fcd66090c3271381372d6221a3970f9"}, @@ -6683,6 +6997,8 @@ version = "20.24.6" description = "Virtual Python Environment builder" optional = false python-versions = ">=3.7" +groups = ["main", "dev"] +markers = "sys_platform != \"win32\"" files = [ {file = "virtualenv-20.24.6-py3-none-any.whl", hash = "sha256:520d056652454c5098a00c0f073611ccbea4c79089331f60bf9d7ba247bb7381"}, {file = "virtualenv-20.24.6.tar.gz", hash = "sha256:02ece4f56fbf939dbbc33c0715159951d6bf14aaf5457b092e4548e1382455af"}, @@ -6695,7 +7011,7 @@ platformdirs = ">=3.9.1,<4" [package.extras] docs = ["furo (>=2023.7.26)", "proselint (>=0.13)", "sphinx (>=7.1.2)", "sphinx-argparse (>=0.4)", "sphinxcontrib-towncrier (>=0.2.1a0)", "towncrier (>=23.6)"] -test = ["covdefaults (>=2.3)", "coverage (>=7.2.7)", "coverage-enable-subprocess (>=1)", "flaky (>=3.7)", "packaging (>=23.1)", "pytest (>=7.4)", "pytest-env (>=0.8.2)", "pytest-freezer (>=0.4.8)", 
"pytest-mock (>=3.11.1)", "pytest-randomly (>=3.12)", "pytest-timeout (>=2.1)", "setuptools (>=68)", "time-machine (>=2.10)"] +test = ["covdefaults (>=2.3)", "coverage (>=7.2.7)", "coverage-enable-subprocess (>=1)", "flaky (>=3.7)", "packaging (>=23.1)", "pytest (>=7.4)", "pytest-env (>=0.8.2)", "pytest-freezer (>=0.4.8) ; platform_python_implementation == \"PyPy\"", "pytest-mock (>=3.11.1)", "pytest-randomly (>=3.12)", "pytest-timeout (>=2.1)", "setuptools (>=68)", "time-machine (>=2.10) ; platform_python_implementation == \"CPython\""] [[package]] name = "vizdoom" @@ -6703,6 +7019,8 @@ version = "1.2.2" description = "ViZDoom is Doom-based AI Research Platform for Reinforcement Learning from Raw Visual Information." optional = true python-versions = "*" +groups = ["main"] +markers = "extra == \"vizdoom\"" files = [ {file = "vizdoom-1.2.2-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:3e2f478e1728702f17b828de0e7ee6bf0e2809c1786ce21f69ce00e4a4da82e0"}, {file = "vizdoom-1.2.2-cp310-cp310-manylinux_2_28_aarch64.whl", hash = "sha256:49180ed13d30109bcd99b38e6b923c5bd74e6bb364add8d46beb5cdf7405fe10"}, @@ -6738,6 +7056,7 @@ version = "0.12.21" description = "A CLI and library for interacting with the Weights and Biases API." 
optional = false python-versions = ">=3.6" +groups = ["dev"] files = [ {file = "wandb-0.12.21-py2.py3-none-any.whl", hash = "sha256:150842447d355d90dc7f368b824951a625e5b2d1be355a00e99b11b73728bc1f"}, {file = "wandb-0.12.21.tar.gz", hash = "sha256:1975ff88c5024923c3321c93cfefb8d9b871543c0b009f34001bf0f31e444b04"}, @@ -6776,6 +7095,7 @@ version = "0.2.10" description = "Measures the displayed width of unicode strings in a terminal" optional = false python-versions = "*" +groups = ["dev"] files = [ {file = "wcwidth-0.2.10-py2.py3-none-any.whl", hash = "sha256:aec5179002dd0f0d40c456026e74a729661c9d468e1ed64405e3a6c2176ca36f"}, {file = "wcwidth-0.2.10.tar.gz", hash = "sha256:390c7454101092a6a5e43baad8f83de615463af459201709556b6e4b1c861f97"}, @@ -6787,6 +7107,7 @@ version = "1.13" description = "A library for working with the color formats defined by HTML and CSS." optional = false python-versions = ">=3.7" +groups = ["dev"] files = [ {file = "webcolors-1.13-py3-none-any.whl", hash = "sha256:29bc7e8752c0a1bd4a1f03c14d6e6a72e93d82193738fa860cbff59d0fcc11bf"}, {file = "webcolors-1.13.tar.gz", hash = "sha256:c225b674c83fa923be93d235330ce0300373d02885cef23238813b0d5668304a"}, @@ -6802,6 +7123,7 @@ version = "0.5.1" description = "Character encoding aliases for legacy web content" optional = false python-versions = "*" +groups = ["dev"] files = [ {file = "webencodings-0.5.1-py2.py3-none-any.whl", hash = "sha256:a0af1213f3c2226497a97e2b3aa01a7e4bee4f403f95be16fc9acd2947514a78"}, {file = "webencodings-0.5.1.tar.gz", hash = "sha256:b36a1c245f2d304965eb4e0a82848379241dc04b865afcc4aab16748587e1923"}, @@ -6813,6 +7135,7 @@ version = "1.6.4" description = "WebSocket client for Python with low level API options" optional = false python-versions = ">=3.8" +groups = ["dev"] files = [ {file = "websocket-client-1.6.4.tar.gz", hash = "sha256:b3324019b3c28572086c4a319f91d1dcd44e6e11cd340232978c684a7650d0df"}, {file = "websocket_client-1.6.4-py3-none-any.whl", hash = 
"sha256:084072e0a7f5f347ef2ac3d8698a5e0b4ffbfcab607628cadabc650fc9a83a24"}, @@ -6829,6 +7152,7 @@ version = "3.0.6" description = "The comprehensive WSGI web application library." optional = false python-versions = ">=3.8" +groups = ["main"] files = [ {file = "werkzeug-3.0.6-py3-none-any.whl", hash = "sha256:1bc0c2310d2fbb07b1dd1105eba2f7af72f322e1e455f2f93c993bee8c8a5f17"}, {file = "werkzeug-3.0.6.tar.gz", hash = "sha256:a8dd59d4de28ca70471a34cba79bed5f7ef2e036a76b3ab0835474246eb41f8d"}, @@ -6846,6 +7170,7 @@ version = "0.41.3" description = "A built-package format for Python" optional = false python-versions = ">=3.7" +groups = ["dev"] files = [ {file = "wheel-0.41.3-py3-none-any.whl", hash = "sha256:488609bc63a29322326e05560731bf7bfea8e48ad646e1f5e40d366607de0942"}, {file = "wheel-0.41.3.tar.gz", hash = "sha256:4d4987ce51a49370ea65c0bfd2234e8ce80a12780820d9dc462597a6e60d0841"}, @@ -6860,6 +7185,7 @@ version = "4.0.9" description = "Jupyter interactive widgets for Jupyter Notebook" optional = false python-versions = ">=3.7" +groups = ["dev"] files = [ {file = "widgetsnbextension-4.0.9-py3-none-any.whl", hash = "sha256:91452ca8445beb805792f206e560c1769284267a30ceb1cec9f5bcc887d15175"}, {file = "widgetsnbextension-4.0.9.tar.gz", hash = "sha256:3c1f5e46dc1166dfd40a42d685e6a51396fd34ff878742a3e47c6f0cc4a2a385"}, @@ -6871,6 +7197,7 @@ version = "3.19.1" description = "Backport of pathlib-compatible object wrapper for zip files" optional = false python-versions = ">=3.8" +groups = ["dev"] files = [ {file = "zipp-3.19.1-py3-none-any.whl", hash = "sha256:2828e64edb5386ea6a52e7ba7cdb17bb30a73a858f5eb6eb93d8d36f5ea26091"}, {file = "zipp-3.19.1.tar.gz", hash = "sha256:35427f6d5594f4acf82d25541438348c26736fa9b3afa2754bcd63cdb99d8e8f"}, @@ -6888,12 +7215,11 @@ classic-control = ["pygame"] envpool = ["envpool"] eval = ["docstring-parser", "joblib", "jsonargparse", "rliable", "scipy"] mujoco = ["imageio", "mujoco"] -mujoco-py = ["cython", "mujoco-py"] pybullet = ["pybullet"] 
robotics = ["gymnasium-robotics"] vizdoom = ["vizdoom"] [metadata] -lock-version = "2.0" +lock-version = "2.1" python-versions = "^3.11" -content-hash = "1ea1b72b90269fd86b81b1443785085618248ccf5b62506a166b879115749171" +content-hash = "34619479ac6375d1680cb9498deb7995c12402369ddd7afa559215748fad5fca" diff --git a/pyproject.toml b/pyproject.toml index 6d0f35536..1914e8979 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -64,7 +64,6 @@ joblib = { version = "*", optional = true } jsonargparse = {version = "^4.24.1", optional = true} # we need <3 b/c of https://github.com/Farama-Foundation/Gymnasium/issues/749 mujoco = { version = ">=2.1.5, <3", optional = true } -mujoco-py = { version = ">=2.1,<2.2", optional = true } opencv_python = { version = "*", optional = true } pybullet = { version = "*", optional = true } pygame = { version = ">=2.1.3", optional = true } @@ -79,7 +78,6 @@ atari = ["ale-py", "autorom", "opencv-python", "shimmy"] box2d = ["box2d-py", "pygame", "swig"] classic_control = ["pygame"] mujoco = ["mujoco", "imageio"] -mujoco_py = ["mujoco-py", "cython"] pybullet = ["pybullet"] envpool = ["envpool"] robotics = ["gymnasium-robotics"] @@ -216,7 +214,7 @@ move-optionals-to-bottom = true PYDEVD_DISABLE_FILE_VALIDATION="1" # keep relevant parts in sync with pre-commit [tool.poe.tasks] # https://github.com/nat-n/poethepoet -test = "pytest test --cov=tianshou --cov-report=xml --cov-report=term-missing --durations=0 -v --color=yes" +test = "pytest test" test-reduced = "pytest test/base test/continuous --cov=tianshou --durations=0 -v --color=yes" _black_check = "black --check ." _ruff_check = "ruff check ." 
From 5088ce4a16daddafc047d718b1e0a73f7544c202 Mon Sep 17 00:00:00 2001 From: Michael Panchenko Date: Mon, 5 May 2025 19:03:18 +0200 Subject: [PATCH 117/230] v2: processing functions in Algorithm are now private and have stricter types Added several # type: ignore[override] since the specifications of _update_with_batch is a violation of the substitution principle. However, we accept this violation since it is minor and only in Algorithm internals --- tianshou/data/types.py | 2 +- tianshou/highlevel/params/policy_params.py | 44 +++++++++++----------- tianshou/policy/base.py | 30 +++++++-------- tianshou/policy/imitation/discrete_bcq.py | 6 +-- tianshou/policy/imitation/discrete_crr.py | 6 +-- tianshou/policy/imitation/gail.py | 8 ++-- tianshou/policy/modelbased/icm.py | 20 +++++----- tianshou/policy/modelfree/a2c.py | 6 +-- tianshou/policy/modelfree/bdqn.py | 6 +-- tianshou/policy/modelfree/ddpg.py | 2 +- tianshou/policy/modelfree/dqn.py | 2 +- tianshou/policy/modelfree/npg.py | 6 +-- tianshou/policy/modelfree/pg.py | 2 +- tianshou/policy/modelfree/ppo.py | 11 ++++-- tianshou/policy/modelfree/trpo.py | 6 +-- tianshou/policy/multiagent/mapolicy.py | 6 +-- 16 files changed, 83 insertions(+), 80 deletions(-) diff --git a/tianshou/data/types.py b/tianshou/data/types.py index fd2f6d287..35ec917f6 100644 --- a/tianshou/data/types.py +++ b/tianshou/data/types.py @@ -32,7 +32,7 @@ class RolloutBatchProtocol(ObsBatchProtocol, Protocol): class BatchWithReturnsProtocol(RolloutBatchProtocol, Protocol): """With added returns, usually computed with GAE.""" - returns: TArr + returns: torch.Tensor class PrioBatchProtocol(RolloutBatchProtocol, Protocol): diff --git a/tianshou/highlevel/params/policy_params.py b/tianshou/highlevel/params/policy_params.py index b79a8424c..bb45d56b4 100644 --- a/tianshou/highlevel/params/policy_params.py +++ b/tianshou/highlevel/params/policy_params.py @@ -413,28 +413,28 @@ class PPOParams(A2CParams): """ dual_clip: float | None = None """ - a clipping 
parameter (denoted as c in the literature) that prevents - excessive pessimism in policy updates for negative-advantage actions. - Excessive pessimism occurs when the policy update too strongly reduces the probability - of selecting actions that led to negative advantages, potentially eliminating useful - actions based on limited negative experiences. + a clipping parameter (denoted as c in the literature) that prevents + excessive pessimism in policy updates for negative-advantage actions. + Excessive pessimism occurs when the policy update too strongly reduces the probability + of selecting actions that led to negative advantages, potentially eliminating useful + actions based on limited negative experiences. When enabled (c > 1), the objective for negative advantages becomes: max(min(r(θ)A, clip(r(θ), 1-ε, 1+ε)A), c*A), where min(r(θ)A, clip(r(θ), 1-ε, 1+ε)A) is the original single-clipping objective determined by `eps_clip`. - This creates a floor on negative policy gradients, maintaining some probability - of exploring actions despite initial negative outcomes. - Larger values (e.g., 2.0 to 5.0) maintain more exploration, while values closer - to 1.0 provide less protection against pessimistic updates. + This creates a floor on negative policy gradients, maintaining some probability + of exploring actions despite initial negative outcomes. + Larger values (e.g., 2.0 to 5.0) maintain more exploration, while values closer + to 1.0 provide less protection against pessimistic updates. Set to None to disable dual clipping. """ value_clip: bool = False """ flag indicating whether to enable clipping for value function updates. - When enabled, restricts how much the value function estimate can change from its + When enabled, restricts how much the value function estimate can change from its previous prediction, using the same clipping range as the policy updates (eps_clip). 
- This stabilizes training by preventing large fluctuations in value estimates, + This stabilizes training by preventing large fluctuations in value estimates, particularly useful in environments with high reward variance. - The clipped value loss uses a pessimistic approach, taking the maximum of the + The clipped value loss uses a pessimistic approach, taking the maximum of the original and clipped value errors: max((returns - value)², (returns - v_clipped)²) Setting to True often improves training stability but may slow convergence. @@ -541,16 +541,16 @@ def _get_param_transformers(self) -> list[ParamTransformer]: class ParamsMixinAlpha(GetParamTransformersProtocol): alpha: float | AutoAlphaFactory = 0.2 """ - the entropy regularization coefficient, which balances exploration and exploitation. - This coefficient controls how much the agent values randomness in its policy versus - pursuing higher rewards. - Higher values (e.g., 0.5-1.0) strongly encourage exploration by rewarding the agent - for maintaining diverse action choices, even if this means selecting some lower-value actions. - Lower values (e.g., 0.01-0.1) prioritize exploitation, allowing the policy to become - more focused on the highest-value actions. - A value of 0 would completely remove entropy regularization, potentially leading to - premature convergence to suboptimal deterministic policies. - Can be provided as a fixed float (0.2 is a reasonable default) or via a factory + the entropy regularization coefficient, which balances exploration and exploitation. + This coefficient controls how much the agent values randomness in its policy versus + pursuing higher rewards. + Higher values (e.g., 0.5-1.0) strongly encourage exploration by rewarding the agent + for maintaining diverse action choices, even if this means selecting some lower-value actions. + Lower values (e.g., 0.01-0.1) prioritize exploitation, allowing the policy to become + more focused on the highest-value actions. 
+ A value of 0 would completely remove entropy regularization, potentially leading to + premature convergence to suboptimal deterministic policies. + Can be provided as a fixed float (0.2 is a reasonable default) or via a factory to support automatic tuning during training. """ diff --git a/tianshou/policy/base.py b/tianshou/policy/base.py index e31984209..142c9bd17 100644 --- a/tianshou/policy/base.py +++ b/tianshou/policy/base.py @@ -571,7 +571,7 @@ def load_state_dict( return super().load_state_dict(state_dict, strict=strict, assign=assign) - def preprocess_batch( + def _preprocess_batch( self, batch: RolloutBatchProtocol, buffer: ReplayBuffer, @@ -588,9 +588,9 @@ def preprocess_batch( """ return batch - def postprocess_batch( + def _postprocess_batch( self, - batch: BatchProtocol, + batch: RolloutBatchProtocol, buffer: ReplayBuffer, indices: np.ndarray, ) -> None: @@ -649,10 +649,10 @@ def _update( return TrainingStats() start_time = time.time() batch, indices = buffer.sample(sample_size) - batch = self.preprocess_batch(batch, buffer, indices) + batch = self._preprocess_batch(batch, buffer, indices) with torch_train_mode(self): training_stat = update_with_batch_fn(batch) - self.postprocess_batch(batch, buffer, indices) + self._postprocess_batch(batch, buffer, indices) for lr_scheduler in self.lr_schedulers: lr_scheduler.step() training_stat.train_time = time.time() - start_time @@ -998,23 +998,23 @@ def __init__( super().__init__(policy=wrapped_algorithm.policy) self.wrapped_algorithm = wrapped_algorithm - def preprocess_batch( + def _preprocess_batch( self, batch: RolloutBatchProtocol, buffer: ReplayBuffer, indices: np.ndarray, ) -> RolloutBatchProtocol: """Performs the pre-processing as defined by the wrapped algorithm.""" - return self.wrapped_algorithm.preprocess_batch(batch, buffer, indices) + return self.wrapped_algorithm._preprocess_batch(batch, buffer, indices) - def postprocess_batch( + def _postprocess_batch( self, - batch: BatchProtocol, + batch: 
RolloutBatchProtocol, buffer: ReplayBuffer, indices: np.ndarray, ) -> None: """Performs the batch post-processing as defined by the wrapped algorithm.""" - self.wrapped_algorithm.postprocess_batch(batch, buffer, indices) + self.wrapped_algorithm._postprocess_batch(batch, buffer, indices) def _update_with_batch( self, batch: RolloutBatchProtocol, batch_size: int | None, repeat: int @@ -1055,23 +1055,23 @@ def __init__( super().__init__(policy=wrapped_algorithm.policy) self.wrapped_algorithm = wrapped_algorithm - def preprocess_batch( + def _preprocess_batch( self, batch: RolloutBatchProtocol, buffer: ReplayBuffer, indices: np.ndarray, ) -> RolloutBatchProtocol: """Performs the pre-processing as defined by the wrapped algorithm.""" - return self.wrapped_algorithm.preprocess_batch(batch, buffer, indices) + return self.wrapped_algorithm._preprocess_batch(batch, buffer, indices) - def postprocess_batch( + def _postprocess_batch( self, - batch: BatchProtocol, + batch: RolloutBatchProtocol, buffer: ReplayBuffer, indices: np.ndarray, ) -> None: """Performs the batch post-processing as defined by the wrapped algorithm.""" - self.wrapped_algorithm.postprocess_batch(batch, buffer, indices) + self.wrapped_algorithm._postprocess_batch(batch, buffer, indices) def _update_with_batch( self, diff --git a/tianshou/policy/imitation/discrete_bcq.py b/tianshou/policy/imitation/discrete_bcq.py index 57f818aea..ddbffaa66 100644 --- a/tianshou/policy/imitation/discrete_bcq.py +++ b/tianshou/policy/imitation/discrete_bcq.py @@ -174,7 +174,7 @@ def __init__( self.eps = eval_eps self._weight_reg = imitation_logits_penalty - def preprocess_batch( + def _preprocess_batch( self, batch: RolloutBatchProtocol, buffer: ReplayBuffer, @@ -198,9 +198,9 @@ def _target_q(self, buffer: ReplayBuffer, indices: np.ndarray) -> torch.Tensor: target_q, _ = self.model_old(batch.obs_next) return target_q[np.arange(len(act)), act] - def _update_with_batch( + def _update_with_batch( # type: ignore[override] self, 
- batch: RolloutBatchProtocol, + batch: BatchWithReturnsProtocol, ) -> DiscreteBCQTrainingStats: if self._iter % self.freq == 0: self._update_lagged_network_weights() diff --git a/tianshou/policy/imitation/discrete_crr.py b/tianshou/policy/imitation/discrete_crr.py index 821e365c8..af03e436b 100644 --- a/tianshou/policy/imitation/discrete_crr.py +++ b/tianshou/policy/imitation/discrete_crr.py @@ -98,7 +98,7 @@ def __init__( self._beta = beta self._min_q_weight = min_q_weight - def preprocess_batch( + def _preprocess_batch( self, batch: RolloutBatchProtocol, buffer: ReplayBuffer, @@ -110,9 +110,9 @@ def preprocess_batch( indices, ) - def _update_with_batch( # type: ignore + def _update_with_batch( # type: ignore[override] self, - batch: RolloutBatchProtocol, + batch: BatchWithReturnsProtocol, ) -> DiscreteCRRTrainingStats: if self._target and self._iter % self._freq == 0: self._update_lagged_network_weights() diff --git a/tianshou/policy/imitation/gail.py b/tianshou/policy/imitation/gail.py index 2929a99a2..19b17bac3 100644 --- a/tianshou/policy/imitation/gail.py +++ b/tianshou/policy/imitation/gail.py @@ -155,7 +155,7 @@ def __init__( raise TypeError("GAIL requires the policy to use an actor with known output dimension.") self.action_dim = actor.get_output_dim() - def preprocess_batch( + def _preprocess_batch( self, batch: RolloutBatchProtocol, buffer: ReplayBuffer, @@ -168,7 +168,7 @@ def preprocess_batch( # update reward with torch.no_grad(): batch.rew = to_numpy(-F.logsigmoid(-self.disc(batch)).flatten()) - return super().preprocess_batch(batch, buffer, indices) + return super()._preprocess_batch(batch, buffer, indices) def disc(self, batch: RolloutBatchProtocol) -> torch.Tensor: device = torch_device(self.disc_net) @@ -176,9 +176,9 @@ def disc(self, batch: RolloutBatchProtocol) -> torch.Tensor: act = to_torch(batch.act, device=device) return self.disc_net(torch.cat([obs, act], dim=1)) - def _update_with_batch( + def _update_with_batch( # type: ignore[override] 
self, - batch: RolloutBatchProtocol, + batch: LogpOldProtocol, batch_size: int | None, repeat: int, ) -> GailTrainingStats: diff --git a/tianshou/policy/modelbased/icm.py b/tianshou/policy/modelbased/icm.py index b77666056..7bb7cd55b 100644 --- a/tianshou/policy/modelbased/icm.py +++ b/tianshou/policy/modelbased/icm.py @@ -126,22 +126,22 @@ def __init__( forward_loss_weight=forward_loss_weight, ) - def preprocess_batch( + def _preprocess_batch( self, batch: RolloutBatchProtocol, buffer: ReplayBuffer, indices: np.ndarray, ) -> RolloutBatchProtocol: self._icm_preprocess_batch(batch) - return super().preprocess_batch(batch, buffer, indices) + return super()._preprocess_batch(batch, buffer, indices) - def postprocess_batch( + def _postprocess_batch( self, - batch: BatchProtocol, + batch: RolloutBatchProtocol, buffer: ReplayBuffer, indices: np.ndarray, ) -> None: - super().postprocess_batch(batch, buffer, indices) + super()._postprocess_batch(batch, buffer, indices) self._icm_postprocess_batch(batch) def _wrapper_update_with_batch( @@ -185,22 +185,22 @@ def __init__( forward_loss_weight=forward_loss_weight, ) - def preprocess_batch( + def _preprocess_batch( self, batch: RolloutBatchProtocol, buffer: ReplayBuffer, indices: np.ndarray, ) -> RolloutBatchProtocol: self._icm_preprocess_batch(batch) - return super().preprocess_batch(batch, buffer, indices) + return super()._preprocess_batch(batch, buffer, indices) - def postprocess_batch( + def _postprocess_batch( self, - batch: BatchProtocol, + batch: RolloutBatchProtocol, buffer: ReplayBuffer, indices: np.ndarray, ) -> None: - super().postprocess_batch(batch, buffer, indices) + super()._postprocess_batch(batch, buffer, indices) self._icm_postprocess_batch(batch) def _wrapper_update_with_batch( diff --git a/tianshou/policy/modelfree/a2c.py b/tianshou/policy/modelfree/a2c.py index cfade460b..1236e4f34 100644 --- a/tianshou/policy/modelfree/a2c.py +++ b/tianshou/policy/modelfree/a2c.py @@ -198,7 +198,7 @@ def __init__( 
self.ent_coef = ent_coef self.max_grad_norm = max_grad_norm - def preprocess_batch( + def _preprocess_batch( self, batch: RolloutBatchProtocol, buffer: ReplayBuffer, @@ -208,9 +208,9 @@ def preprocess_batch( batch.act = to_torch_as(batch.act, batch.v_s) return batch - def _update_with_batch( + def _update_with_batch( # type: ignore[override] self, - batch: RolloutBatchProtocol, + batch: BatchWithAdvantagesProtocol, batch_size: int | None, repeat: int, ) -> A2CTrainingStats: diff --git a/tianshou/policy/modelfree/bdqn.py b/tianshou/policy/modelfree/bdqn.py index 299c6799e..63f3f8abf 100644 --- a/tianshou/policy/modelfree/bdqn.py +++ b/tianshou/policy/modelfree/bdqn.py @@ -190,7 +190,7 @@ def _compute_return( batch.weight = to_torch_as(batch.weight, target_q_torch) return cast(BatchWithReturnsProtocol, batch) - def preprocess_batch( + def _preprocess_batch( self, batch: RolloutBatchProtocol, buffer: ReplayBuffer, @@ -199,9 +199,9 @@ def preprocess_batch( """Compute the 1-step return for BDQ targets.""" return self._compute_return(batch, buffer, indices) - def _update_with_batch( + def _update_with_batch( # type: ignore[override] self, - batch: RolloutBatchProtocol, + batch: BatchWithReturnsProtocol, ) -> SimpleLossTrainingStats: self._periodically_update_lagged_network_weights() weight = batch.pop("weight", 1.0) diff --git a/tianshou/policy/modelfree/ddpg.py b/tianshou/policy/modelfree/ddpg.py index d85c50f72..a67212c1c 100644 --- a/tianshou/policy/modelfree/ddpg.py +++ b/tianshou/policy/modelfree/ddpg.py @@ -275,7 +275,7 @@ def _minimize_critic_squared_loss( optimizer.step(critic_loss) return td, critic_loss - def preprocess_batch( + def _preprocess_batch( self, batch: RolloutBatchProtocol, buffer: ReplayBuffer, diff --git a/tianshou/policy/modelfree/dqn.py b/tianshou/policy/modelfree/dqn.py index c878a3afb..48aa7500b 100644 --- a/tianshou/policy/modelfree/dqn.py +++ b/tianshou/policy/modelfree/dqn.py @@ -246,7 +246,7 @@ def use_target_network(self) -> bool: def 
_target_q(self, buffer: ReplayBuffer, indices: np.ndarray) -> torch.Tensor: pass - def preprocess_batch( + def _preprocess_batch( self, batch: RolloutBatchProtocol, buffer: ReplayBuffer, diff --git a/tianshou/policy/modelfree/npg.py b/tianshou/policy/modelfree/npg.py index dbb15ec87..df682f23e 100644 --- a/tianshou/policy/modelfree/npg.py +++ b/tianshou/policy/modelfree/npg.py @@ -108,7 +108,7 @@ def __init__( # adjusts Hessian-vector product calculation for numerical stability self._damping = 0.1 - def preprocess_batch( + def _preprocess_batch( self, batch: RolloutBatchProtocol, buffer: ReplayBuffer, @@ -125,9 +125,9 @@ def preprocess_batch( batch.adv = (batch.adv - batch.adv.mean()) / batch.adv.std() return batch - def _update_with_batch( + def _update_with_batch( # type: ignore[override] self, - batch: RolloutBatchProtocol, + batch: BatchWithAdvantagesProtocol, batch_size: int | None, repeat: int, ) -> NPGTrainingStats: diff --git a/tianshou/policy/modelfree/pg.py b/tianshou/policy/modelfree/pg.py index fe0c2557d..34a98c0f6 100644 --- a/tianshou/policy/modelfree/pg.py +++ b/tianshou/policy/modelfree/pg.py @@ -343,7 +343,7 @@ def __init__( ) self.optim = self._create_optimizer(self.policy, optim) - def preprocess_batch( + def _preprocess_batch( self, batch: RolloutBatchProtocol, buffer: ReplayBuffer, diff --git a/tianshou/policy/modelfree/ppo.py b/tianshou/policy/modelfree/ppo.py index 5f869cddd..4fb3e8909 100644 --- a/tianshou/policy/modelfree/ppo.py +++ b/tianshou/policy/modelfree/ppo.py @@ -127,7 +127,7 @@ def __init__( self.norm_adv = advantage_normalization self.recompute_adv = recompute_advantage - def preprocess_batch( + def _preprocess_batch( self, batch: RolloutBatchProtocol, buffer: ReplayBuffer, @@ -145,9 +145,9 @@ def preprocess_batch( batch.logp_old = torch.cat(logp_old, dim=0).flatten() return cast(LogpOldProtocol, batch) - def _update_with_batch( + def _update_with_batch( # type: ignore[override] self, - batch: RolloutBatchProtocol, + batch: 
LogpOldProtocol, batch_size: int | None, repeat: int, ) -> A2CTrainingStats: @@ -156,7 +156,10 @@ def _update_with_batch( split_batch_size = batch_size or -1 for step in range(repeat): if self.recompute_adv and step > 0: - batch = self._add_returns_and_advantages(batch, self._buffer, self._indices) + batch = cast( + LogpOldProtocol, + self._add_returns_and_advantages(batch, self._buffer, self._indices), + ) for minibatch in batch.split(split_batch_size, merge_last=True): gradient_steps += 1 # calculate loss for actor diff --git a/tianshou/policy/modelfree/trpo.py b/tianshou/policy/modelfree/trpo.py index d5a0b8e34..c73bc06b0 100644 --- a/tianshou/policy/modelfree/trpo.py +++ b/tianshou/policy/modelfree/trpo.py @@ -6,7 +6,7 @@ from torch.distributions import kl_divergence from tianshou.data import SequenceSummaryStats -from tianshou.data.types import RolloutBatchProtocol +from tianshou.data.types import BatchWithAdvantagesProtocol from tianshou.policy import NPG from tianshou.policy.modelfree.npg import NPGTrainingStats from tianshou.policy.modelfree.pg import ActorPolicy @@ -108,9 +108,9 @@ def __init__( self.max_kl = max_kl self.backtrack_coeff = backtrack_coeff - def _update_with_batch( + def _update_with_batch( # type: ignore[override] self, - batch: RolloutBatchProtocol, + batch: BatchWithAdvantagesProtocol, batch_size: int | None, repeat: int, ) -> TRPOTrainingStats: diff --git a/tianshou/policy/multiagent/mapolicy.py b/tianshou/policy/multiagent/mapolicy.py index 5a3df68ea..0b4a2fbcc 100644 --- a/tianshou/policy/multiagent/mapolicy.py +++ b/tianshou/policy/multiagent/mapolicy.py @@ -243,7 +243,7 @@ def dispatch_process_fn( tmp_batch.obs = tmp_batch.obs.obs if hasattr(tmp_batch.obs_next, "obs"): tmp_batch.obs_next = tmp_batch.obs_next.obs - results[agent] = algorithm.preprocess_batch(tmp_batch, buffer, tmp_indice) + results[agent] = algorithm._preprocess_batch(tmp_batch, buffer, tmp_indice) if has_rew: # restore from save_rew buffer._meta.rew = save_rew return 
cast(MAPRolloutBatchProtocol, Batch(results)) @@ -291,7 +291,7 @@ def __init__( def get_algorithm(self, agent_id: str | int) -> OffPolicyAlgorithm: return self._dispatcher.algorithms[agent_id] - def preprocess_batch( + def _preprocess_batch( self, batch: RolloutBatchProtocol, buffer: ReplayBuffer, @@ -334,7 +334,7 @@ def __init__( def get_algorithm(self, agent_id: str | int) -> OnPolicyAlgorithm: return self._dispatcher.algorithms[agent_id] - def preprocess_batch( + def _preprocess_batch( self, batch: RolloutBatchProtocol, buffer: ReplayBuffer, From 3f8948b6aca653e1380abf4aba8afda2e9b736d0 Mon Sep 17 00:00:00 2001 From: Dominik Jain Date: Mon, 5 May 2025 21:13:34 +0200 Subject: [PATCH 118/230] v2: Clean up 'reward_normalization' parameter * Rename to `return_standardization` in `Reinforce` and `DiscreteCRR` (as it applies standardization of returns) * Rename to `return_scaling` in actor-critic on-policy algorithms (A2C, PPO, GAIL, NPG, TRPO), where it only applies scaling with the standard deviation * Removed from Q-learning algorithms, where it was actually unsupported (DQN, C51, etc.)
--- CHANGELOG.md | 6 ++- examples/atari/atari_ppo.py | 2 +- examples/inverse/irl_gail.py | 2 +- examples/mujoco/mujoco_a2c.py | 2 +- examples/mujoco/mujoco_npg.py | 2 +- examples/mujoco/mujoco_npg_hl.py | 2 +- examples/mujoco/mujoco_ppo.py | 2 +- examples/mujoco/mujoco_reinforce.py | 2 +- examples/mujoco/mujoco_reinforce_hl.py | 2 +- examples/mujoco/mujoco_trpo.py | 2 +- examples/mujoco/mujoco_trpo_hl.py | 2 +- examples/vizdoom/vizdoom_ppo.py | 2 +- test/continuous/test_npg.py | 2 +- test/continuous/test_ppo.py | 2 +- test/continuous/test_trpo.py | 2 +- test/discrete/test_a2c_with_il.py | 2 +- test/discrete/test_pg.py | 2 +- test/discrete/test_ppo2.py | 2 +- test/modelbased/test_ppo_icm.py | 2 +- test/offline/test_gail.py | 2 +- test/pettingzoo/pistonball_continuous.py | 2 +- tianshou/highlevel/params/policy_params.py | 57 +++++++++++++++++----- tianshou/policy/base.py | 9 ++-- tianshou/policy/imitation/discrete_bcq.py | 5 -- tianshou/policy/imitation/discrete_cql.py | 9 ++-- tianshou/policy/imitation/discrete_crr.py | 8 +-- tianshou/policy/imitation/gail.py | 22 ++++++--- tianshou/policy/modelfree/a2c.py | 39 +++++++++++---- tianshou/policy/modelfree/bdqn.py | 8 +-- tianshou/policy/modelfree/c51.py | 8 +-- tianshou/policy/modelfree/dqn.py | 13 ++--- tianshou/policy/modelfree/fqf.py | 9 ++-- tianshou/policy/modelfree/iqn.py | 9 ++-- tianshou/policy/modelfree/npg.py | 18 +++++-- tianshou/policy/modelfree/pg.py | 19 +++----- tianshou/policy/modelfree/ppo.py | 26 +++++----- tianshou/policy/modelfree/qrdqn.py | 8 +-- tianshou/policy/modelfree/trpo.py | 18 +++++-- 38 files changed, 207 insertions(+), 124 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 4a17c00da..72f44dae7 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -83,8 +83,12 @@ `LRSchedulerFactory`). The parameter `lr_scheduler` has thus been removed from all algorithm constructors. * The flag `updating` has been removed (no internal usage, general usefulness questionable). 
- * Parameter name changes: + * Parameter changes: * `discount_factor` -> `gamma` (was already used internally almost everywhere) + * `reward_normalization` -> `return_standardization` or `return_scaling` (more precise naming) or removed (was actually unsupported by Q-learning algorithms) + * `return_standardization` in `Reinforce` and `DiscreteCRR` (as it applies standardization of returns) + * `return_scaling` in actor-critic on-policy algorithms (A2C, PPO, GAIL, NPG, TRPO) + * removed from Q-learning algorithms, where it was actually unsupported (DQN, C561, etc.) * Internal design improvements: * Introduced an abstraction for the alpha parameter (coefficient of the entropy term) in `SAC`, `DiscreteSAC` and other algorithms. diff --git a/examples/atari/atari_ppo.py b/examples/atari/atari_ppo.py index 6c1295e0e..4060f55e7 100644 --- a/examples/atari/atari_ppo.py +++ b/examples/atari/atari_ppo.py @@ -161,7 +161,7 @@ def main(args: argparse.Namespace = get_args()) -> None: max_grad_norm=args.max_grad_norm, vf_coef=args.vf_coef, ent_coef=args.ent_coef, - reward_normalization=args.rew_norm, + return_scaling=args.rew_norm, eps_clip=args.eps_clip, value_clip=args.value_clip, dual_clip=args.dual_clip, diff --git a/examples/inverse/irl_gail.py b/examples/inverse/irl_gail.py index 13af8a387..cbb160451 100644 --- a/examples/inverse/irl_gail.py +++ b/examples/inverse/irl_gail.py @@ -224,7 +224,7 @@ def dist(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution: max_grad_norm=args.max_grad_norm, vf_coef=args.vf_coef, ent_coef=args.ent_coef, - reward_normalization=args.rew_norm, + return_scaling=args.rew_norm, eps_clip=args.eps_clip, value_clip=args.value_clip, dual_clip=args.dual_clip, diff --git a/examples/mujoco/mujoco_a2c.py b/examples/mujoco/mujoco_a2c.py index 6f0fbc212..e0041b99b 100755 --- a/examples/mujoco/mujoco_a2c.py +++ b/examples/mujoco/mujoco_a2c.py @@ -156,7 +156,7 @@ def dist(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution: 
max_grad_norm=args.max_grad_norm, vf_coef=args.vf_coef, ent_coef=args.ent_coef, - reward_normalization=args.rew_norm, + return_scaling=args.rew_norm, ) # load a previous policy diff --git a/examples/mujoco/mujoco_npg.py b/examples/mujoco/mujoco_npg.py index 50152e437..ca178557f 100755 --- a/examples/mujoco/mujoco_npg.py +++ b/examples/mujoco/mujoco_npg.py @@ -151,7 +151,7 @@ def dist(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution: optim=optim, gamma=args.gamma, gae_lambda=args.gae_lambda, - reward_normalization=args.rew_norm, + return_scaling=args.rew_norm, advantage_normalization=args.norm_adv, optim_critic_iters=args.optim_critic_iters, actor_step_size=args.actor_step_size, diff --git a/examples/mujoco/mujoco_npg_hl.py b/examples/mujoco/mujoco_npg_hl.py index 63dfc7b89..dcebfbe93 100644 --- a/examples/mujoco/mujoco_npg_hl.py +++ b/examples/mujoco/mujoco_npg_hl.py @@ -67,7 +67,7 @@ def main( discount_factor=gamma, gae_lambda=gae_lambda, action_bound_method=bound_action_method, - reward_normalization=rew_norm, + return_standardization=rew_norm, advantage_normalization=norm_adv, optim_critic_iters=optim_critic_iters, actor_step_size=actor_step_size, diff --git a/examples/mujoco/mujoco_ppo.py b/examples/mujoco/mujoco_ppo.py index c2f775934..c65b302f4 100755 --- a/examples/mujoco/mujoco_ppo.py +++ b/examples/mujoco/mujoco_ppo.py @@ -157,7 +157,7 @@ def dist(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution: max_grad_norm=args.max_grad_norm, vf_coef=args.vf_coef, ent_coef=args.ent_coef, - reward_normalization=args.rew_norm, + return_scaling=args.rew_norm, eps_clip=args.eps_clip, value_clip=args.value_clip, dual_clip=args.dual_clip, diff --git a/examples/mujoco/mujoco_reinforce.py b/examples/mujoco/mujoco_reinforce.py index d8739a30b..42546f2cf 100755 --- a/examples/mujoco/mujoco_reinforce.py +++ b/examples/mujoco/mujoco_reinforce.py @@ -135,7 +135,7 @@ def dist(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution: policy=policy, 
optim=optim, gamma=args.gamma, - reward_normalization=args.rew_norm, + return_standardization=args.rew_norm, ) # load a previous policy diff --git a/examples/mujoco/mujoco_reinforce_hl.py b/examples/mujoco/mujoco_reinforce_hl.py index 27af61efd..4bdd2918e 100644 --- a/examples/mujoco/mujoco_reinforce_hl.py +++ b/examples/mujoco/mujoco_reinforce_hl.py @@ -62,7 +62,7 @@ def main( ReinforceParams( discount_factor=gamma, action_bound_method=action_bound_method, - reward_normalization=rew_norm, + return_standardization=rew_norm, lr=lr, lr_scheduler=LRSchedulerFactoryFactoryLinear(training_config) if lr_decay else None, ), diff --git a/examples/mujoco/mujoco_trpo.py b/examples/mujoco/mujoco_trpo.py index 15006bff8..1620e7ac3 100755 --- a/examples/mujoco/mujoco_trpo.py +++ b/examples/mujoco/mujoco_trpo.py @@ -154,7 +154,7 @@ def dist(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution: optim=optim, gamma=args.gamma, gae_lambda=args.gae_lambda, - reward_normalization=args.rew_norm, + return_scaling=args.rew_norm, advantage_normalization=args.norm_adv, optim_critic_iters=args.optim_critic_iters, max_kl=args.max_kl, diff --git a/examples/mujoco/mujoco_trpo_hl.py b/examples/mujoco/mujoco_trpo_hl.py index c2518d07e..f3531c1e5 100644 --- a/examples/mujoco/mujoco_trpo_hl.py +++ b/examples/mujoco/mujoco_trpo_hl.py @@ -69,7 +69,7 @@ def main( discount_factor=gamma, gae_lambda=gae_lambda, action_bound_method=bound_action_method, - reward_normalization=rew_norm, + return_standardization=rew_norm, advantage_normalization=norm_adv, optim_critic_iters=optim_critic_iters, max_kl=max_kl, diff --git a/examples/vizdoom/vizdoom_ppo.py b/examples/vizdoom/vizdoom_ppo.py index 7a9a48974..bcf94f70b 100644 --- a/examples/vizdoom/vizdoom_ppo.py +++ b/examples/vizdoom/vizdoom_ppo.py @@ -165,7 +165,7 @@ def dist(logits: torch.Tensor) -> Categorical: max_grad_norm=args.max_grad_norm, vf_coef=args.vf_coef, ent_coef=args.ent_coef, - reward_normalization=args.rew_norm, + 
return_scaling=args.rew_norm, eps_clip=args.eps_clip, value_clip=args.value_clip, dual_clip=args.dual_clip, diff --git a/test/continuous/test_npg.py b/test/continuous/test_npg.py index aaa931e77..028d5d13f 100644 --- a/test/continuous/test_npg.py +++ b/test/continuous/test_npg.py @@ -116,7 +116,7 @@ def dist(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution: critic=critic, optim=AdamOptimizerFactory(lr=args.lr), gamma=args.gamma, - reward_normalization=args.rew_norm, + return_scaling=args.rew_norm, advantage_normalization=args.norm_adv, gae_lambda=args.gae_lambda, optim_critic_iters=args.optim_critic_iters, diff --git a/test/continuous/test_ppo.py b/test/continuous/test_ppo.py index 54ffd523a..7af8b914b 100644 --- a/test/continuous/test_ppo.py +++ b/test/continuous/test_ppo.py @@ -118,7 +118,7 @@ def dist(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution: eps_clip=args.eps_clip, vf_coef=args.vf_coef, ent_coef=args.ent_coef, - reward_normalization=args.rew_norm, + return_scaling=args.rew_norm, advantage_normalization=args.norm_adv, recompute_advantage=args.recompute_adv, dual_clip=args.dual_clip, diff --git a/test/continuous/test_trpo.py b/test/continuous/test_trpo.py index 72e713d42..dbb68c854 100644 --- a/test/continuous/test_trpo.py +++ b/test/continuous/test_trpo.py @@ -116,7 +116,7 @@ def dist(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution: critic=critic, optim=optim, gamma=args.gamma, - reward_normalization=args.rew_norm, + return_scaling=args.rew_norm, advantage_normalization=args.norm_adv, gae_lambda=args.gae_lambda, optim_critic_iters=args.optim_critic_iters, diff --git a/test/discrete/test_a2c_with_il.py b/test/discrete/test_a2c_with_il.py index a03e76497..3b2b5f145 100644 --- a/test/discrete/test_a2c_with_il.py +++ b/test/discrete/test_a2c_with_il.py @@ -112,7 +112,7 @@ def test_a2c_with_il(args: argparse.Namespace = get_args()) -> None: vf_coef=args.vf_coef, ent_coef=args.ent_coef, max_grad_norm=args.max_grad_norm, 
- reward_normalization=args.rew_norm, + return_scaling=args.rew_norm, ) # collector train_collector = Collector[CollectStats]( diff --git a/test/discrete/test_pg.py b/test/discrete/test_pg.py index bf3fbae4c..fa38a9068 100644 --- a/test/discrete/test_pg.py +++ b/test/discrete/test_pg.py @@ -86,7 +86,7 @@ def test_pg(args: argparse.Namespace = get_args(), enable_assertions: bool = Tru policy=policy, optim=optim, gamma=args.gamma, - reward_normalization=args.rew_norm, + return_standardization=args.rew_norm, ) for m in net.modules(): if isinstance(m, torch.nn.Linear): diff --git a/test/discrete/test_ppo2.py b/test/discrete/test_ppo2.py index 5fe7e7dac..90e7dd930 100644 --- a/test/discrete/test_ppo2.py +++ b/test/discrete/test_ppo2.py @@ -115,7 +115,7 @@ def test_ppo(args: argparse.Namespace = get_args(), enable_assertions: bool = Tr vf_coef=args.vf_coef, ent_coef=args.ent_coef, gae_lambda=args.gae_lambda, - reward_normalization=args.rew_norm, + return_scaling=args.rew_norm, dual_clip=args.dual_clip, value_clip=args.value_clip, advantage_normalization=args.norm_adv, diff --git a/test/modelbased/test_ppo_icm.py b/test/modelbased/test_ppo_icm.py index 794ff3ac4..566ea4746 100644 --- a/test/modelbased/test_ppo_icm.py +++ b/test/modelbased/test_ppo_icm.py @@ -135,7 +135,7 @@ def test_ppo(args: argparse.Namespace = get_args()) -> None: vf_coef=args.vf_coef, ent_coef=args.ent_coef, gae_lambda=args.gae_lambda, - reward_normalization=args.rew_norm, + return_scaling=args.rew_norm, dual_clip=args.dual_clip, value_clip=args.value_clip, advantage_normalization=args.norm_adv, diff --git a/test/offline/test_gail.py b/test/offline/test_gail.py index e069fbf2a..6f787c022 100644 --- a/test/offline/test_gail.py +++ b/test/offline/test_gail.py @@ -150,7 +150,7 @@ def dist(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution: eps_clip=args.eps_clip, vf_coef=args.vf_coef, ent_coef=args.ent_coef, - reward_normalization=args.rew_norm, + return_scaling=args.rew_norm, 
advantage_normalization=args.norm_adv, recompute_advantage=args.recompute_adv, dual_clip=args.dual_clip, diff --git a/test/pettingzoo/pistonball_continuous.py b/test/pettingzoo/pistonball_continuous.py index 6a4d60a67..b6df7a8a2 100644 --- a/test/pettingzoo/pistonball_continuous.py +++ b/test/pettingzoo/pistonball_continuous.py @@ -209,7 +209,7 @@ def dist(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution: eps_clip=args.eps_clip, vf_coef=args.vf_coef, ent_coef=args.ent_coef, - reward_normalization=args.rew_norm, + return_scaling=args.rew_norm, advantage_normalization=args.norm_adv, recompute_advantage=args.recompute_adv, # dual_clip=args.dual_clip, diff --git a/tianshou/highlevel/params/policy_params.py b/tianshou/highlevel/params/policy_params.py index bb45d56b4..57e826d5a 100644 --- a/tianshou/highlevel/params/policy_params.py +++ b/tianshou/highlevel/params/policy_params.py @@ -334,20 +334,13 @@ class ParamsMixinDeterministicEval: """ -@dataclass(kw_only=True) -class ReinforceParams( +class OnPolicyAlgorithmParams( Params, ParamsMixinGamma, ParamsMixinActionScaling, ParamsMixinSingleModel, ParamsMixinDeterministicEval, ): - reward_normalization: bool = False - """ - if True, will normalize the returns by subtracting the running mean and dividing by the running - standard deviation. - """ - def _get_param_transformers(self) -> list[ParamTransformer]: transformers = super()._get_param_transformers() transformers.extend(ParamsMixinActionScaling._get_param_transformers(self)) @@ -355,6 +348,16 @@ def _get_param_transformers(self) -> list[ParamTransformer]: return transformers +@dataclass(kw_only=True) +class ReinforceParams(OnPolicyAlgorithmParams): + return_standardization: bool = False + """ + whether to standardize episode returns by subtracting the running mean and + dividing by the running standard deviation. + Note that this is known to be detrimental to performance in many cases! 
+ """ + + @dataclass(kw_only=True) class ParamsMixinGeneralAdvantageEstimation(GetParamTransformersProtocol): gae_lambda: float = 0.95 @@ -384,7 +387,26 @@ def _get_param_transformers(self) -> list[ParamTransformer]: @dataclass(kw_only=True) -class A2CParams(ReinforceParams, ParamsMixinGeneralAdvantageEstimation): +class ActorCriticOnPolicyParams(OnPolicyAlgorithmParams): + return_scaling: bool = False + """ + flag indicating whether to enable scaling of estimated returns by + dividing them by their running standard deviation without centering the mean. + This reduces the magnitude variation of advantages across different episodes while + preserving their signs and relative ordering. + The use of running statistics (rather than batch-specific scaling) means that early + training experiences may be scaled differently than later ones as the statistics evolve. + When enabled, this improves training stability in environments with highly variable + reward scales and makes the algorithm less sensitive to learning rate settings. + However, it may reduce the algorithm's ability to distinguish between episodes with + different absolute return magnitudes. + Best used in environments where the relative ordering of actions is more important + than the absolute scale of returns. + """ + + +@dataclass(kw_only=True) +class A2CParams(ActorCriticOnPolicyParams, ParamsMixinGeneralAdvantageEstimation): vf_coef: float = 0.5 """weight (coefficient) of the value loss in the loss function""" ent_coef: float = 0.01 @@ -595,8 +617,21 @@ class QLearningOffPolicyParams( ): target_update_freq: int = 0 """the target network update frequency (0 if no target network is to be used)""" - reward_normalization: bool = False - """whether to normalize the returns to Normal(0, 1)""" + return_scaling: bool = False + """ + flag indicating whether to enable scaling of estimated returns by + dividing them by their running standard deviation without centering the mean. 
+ This reduces the magnitude variation of advantages across different episodes while + preserving their signs and relative ordering. + The use of running statistics (rather than batch-specific scaling) means that early + training experiences may be scaled differently than later ones as the statistics evolve. + When enabled, this improves training stability in environments with highly variable + reward scales and makes the algorithm less sensitive to learning rate settings. + However, it may reduce the algorithm's ability to distinguish between episodes with + different absolute return magnitudes. + Best used in environments where the relative ordering of actions is more important + than the absolute scale of returns. + """ eps_training: float = 0.0 """ the epsilon value for epsilon-greedy exploration during training. diff --git a/tianshou/policy/base.py b/tianshou/policy/base.py index 142c9bd17..748208dbc 100644 --- a/tianshou/policy/base.py +++ b/tianshou/policy/base.py @@ -754,7 +754,7 @@ def compute_nstep_return( target_q_fn: Callable[[ReplayBuffer, np.ndarray], torch.Tensor], gamma: float = 0.99, n_step: int = 1, - rew_norm: bool = False, + return_scaling: bool = False, ) -> BatchWithReturnsProtocol: r""" Computes the n-step return for Q-learning targets, adds it to the batch and returns the resulting batch. @@ -780,12 +780,11 @@ def compute_nstep_return( Typically set between 0.9 and 0.99 for most reinforcement learning tasks :param n_step: the number of estimation step, should be an int greater than 0. - :param rew_norm: normalize the reward to Normal(0, 1). - TODO: passing True is not supported and will cause an error! - :return: a Batch. The result will be stored in batch.returns as a + :param return_scaling: whether to standardise returns to Normal(0, 1); + supported is currently suspended! + :return: a Batch. The result will be stored in `batch.returns` as a torch.Tensor with the same shape as target_q_fn's return tensor. 
""" - assert not rew_norm, "Reward normalization in computing n-step returns is unsupported now." if len(indices) != len(batch): raise ValueError(f"Batch size {len(batch)} and indices size {len(indices)} mismatch.") diff --git a/tianshou/policy/imitation/discrete_bcq.py b/tianshou/policy/imitation/discrete_bcq.py index ddbffaa66..c53f6aa1b 100644 --- a/tianshou/policy/imitation/discrete_bcq.py +++ b/tianshou/policy/imitation/discrete_bcq.py @@ -110,7 +110,6 @@ def __init__( target_update_freq: int = 8000, eval_eps: float = 1e-3, imitation_logits_penalty: float = 1e-2, - reward_normalization: bool = False, is_double: bool = True, clip_loss_grad: bool = False, ) -> None: @@ -144,8 +143,6 @@ def __init__( complete episode returns. :param target_update_freq: the target network update frequency (0 if you do not use the target network). - :param reward_normalization: normalize the **returns** to Normal(0, 1). - TODO: rename to return_normalization? :param is_double: use double dqn. :param clip_loss_grad: clip the gradient of the loss in accordance with nature14236; this amounts to using the Huber loss instead of @@ -167,7 +164,6 @@ def __init__( self._iter = 0 if self._target: self.model_old = self._add_lagged_network(self.policy.model) - self.rew_norm = reward_normalization self.is_double = is_double self.clip_loss_grad = clip_loss_grad assert 0.0 <= eval_eps < 1.0 @@ -187,7 +183,6 @@ def _preprocess_batch( target_q_fn=self._target_q, gamma=self.gamma, n_step=self.n_step, - rew_norm=self.rew_norm, ) def _target_q(self, buffer: ReplayBuffer, indices: np.ndarray) -> torch.Tensor: diff --git a/tianshou/policy/imitation/discrete_cql.py b/tianshou/policy/imitation/discrete_cql.py index 4b5167485..884310044 100644 --- a/tianshou/policy/imitation/discrete_cql.py +++ b/tianshou/policy/imitation/discrete_cql.py @@ -33,7 +33,7 @@ def __init__( num_quantiles: int = 200, estimation_step: int = 1, target_update_freq: int = 0, - reward_normalization: bool = False, + return_scaling: 
bool = False, ) -> None: """ :param policy: the policy @@ -57,8 +57,9 @@ def __init__( complete episode returns. :param target_update_freq: the target network update frequency (0 if you do not use the target network). - :param reward_normalization: normalize the **returns** to Normal(0, 1). - TODO: rename to return_normalization? + :param return_scaling: flag indicating whether to scale/standardise returns to Normal(0, 1) based + on running mean and standard deviation. + Support for this is currently suspended and therefore the flag should not be enabled. """ QRDQN.__init__( self, @@ -68,7 +69,7 @@ def __init__( num_quantiles=num_quantiles, estimation_step=estimation_step, target_update_freq=target_update_freq, - reward_normalization=reward_normalization, + return_scaling=return_scaling, ) self.min_q_weight = min_q_weight diff --git a/tianshou/policy/imitation/discrete_crr.py b/tianshou/policy/imitation/discrete_crr.py index af03e436b..6c1e36d0d 100644 --- a/tianshou/policy/imitation/discrete_crr.py +++ b/tianshou/policy/imitation/discrete_crr.py @@ -47,7 +47,7 @@ def __init__( beta: float = 1.0, min_q_weight: float = 10.0, target_update_freq: int = 0, - reward_normalization: bool = False, + return_standardization: bool = False, ) -> None: r""" :param policy: the policy @@ -70,9 +70,9 @@ def __init__( :param min_q_weight: weight for CQL loss/regularizer. Default to 10. :param target_update_freq: the target network update frequency (0 if you do not use the target network). - :param reward_normalization: if True, will normalize the *returns* + :param return_standardization: whether to standardize episode returns by subtracting the running mean and dividing by the running standard deviation. - Can be detrimental to performance! + Note that this is known to be detrimental to performance in many cases! 
""" super().__init__( policy=policy, @@ -80,7 +80,7 @@ def __init__( LaggedNetworkFullUpdateAlgorithmMixin.__init__(self) self.discounted_return_computation = DiscountedReturnComputation( gamma=gamma, - reward_normalization=reward_normalization, + return_standardization=return_standardization, ) self.critic = critic self.optim = self._create_optimizer(ModuleList([self.policy, self.critic]), optim) diff --git a/tianshou/policy/imitation/gail.py b/tianshou/policy/imitation/gail.py index 19b17bac3..a655d7640 100644 --- a/tianshou/policy/imitation/gail.py +++ b/tianshou/policy/imitation/gail.py @@ -29,7 +29,7 @@ class GailTrainingStats(A2CTrainingStats): class GAIL(PPO): - r"""Implementation of Generative Adversarial Imitation Learning. arXiv:1606.03476.""" + """Implementation of Generative Adversarial Imitation Learning. arXiv:1606.03476.""" def __init__( self, @@ -52,10 +52,9 @@ def __init__( gae_lambda: float = 0.95, max_batchsize: int = 256, gamma: float = 0.99, - # TODO: rename to return_normalization? - reward_normalization: bool = False, + return_scaling: bool = False, ) -> None: - r""" + """ :param policy: the policy (which must use an actor with known output dimension, i.e. any Tianshou `Actor` implementation or other subclass of `ModuleWithVectorOutput`). :param critic: the critic network. (s -> V(s)) @@ -127,7 +126,18 @@ def __init__( potentially improving performance in tasks where delayed rewards are important but increasing training variance by incorporating more environmental stochasticity. Typically set between 0.9 and 0.99 for most reinforcement learning tasks - :param reward_normalization: normalize estimated values to have std close to 1. + :param return_scaling: flag indicating whether to enable scaling of estimated returns by + dividing them by their running standard deviation without centering the mean. + This reduces the magnitude variation of advantages across different episodes while + preserving their signs and relative ordering. 
+ The use of running statistics (rather than batch-specific scaling) means that early + training experiences may be scaled differently than later ones as the statistics evolve. + When enabled, this improves training stability in environments with highly variable + reward scales and makes the algorithm less sensitive to learning rate settings. + However, it may reduce the algorithm's ability to distinguish between episodes with + different absolute return magnitudes. + Best used in environments where the relative ordering of actions is more important + than the absolute scale of returns. """ super().__init__( policy=policy, @@ -144,7 +154,7 @@ def __init__( gae_lambda=gae_lambda, max_batchsize=max_batchsize, gamma=gamma, - reward_normalization=reward_normalization, + return_scaling=return_scaling, ) self.disc_net = disc_net self.disc_optim = self._create_optimizer(self.disc_net, disc_optim) diff --git a/tianshou/policy/modelfree/a2c.py b/tianshou/policy/modelfree/a2c.py index 1236e4f34..0eabea5a5 100644 --- a/tianshou/policy/modelfree/a2c.py +++ b/tianshou/policy/modelfree/a2c.py @@ -43,7 +43,7 @@ def __init__( gae_lambda: float = 0.95, max_batchsize: int = 256, gamma: float = 0.99, - reward_normalization: bool = False, + return_scaling: bool = False, ) -> None: """ :param critic: the critic network. (s -> V(s)) @@ -76,7 +76,18 @@ def __init__( potentially improving performance in tasks where delayed rewards are important but increasing training variance by incorporating more environmental stochasticity. Typically set between 0.9 and 0.99 for most reinforcement learning tasks - :param reward_normalization: normalize estimated values to have std close to 1. + :param return_scaling: flag indicating whether to enable scaling of estimated returns by + dividing them by their running standard deviation without centering the mean. + This reduces the magnitude variation of advantages across different episodes while + preserving their signs and relative ordering. 
+ The use of running statistics (rather than batch-specific scaling) means that early + training experiences may be scaled differently than later ones as the statistics evolve. + When enabled, this improves training stability in environments with highly variable + reward scales and makes the algorithm less sensitive to learning rate settings. + However, it may reduce the algorithm's ability to distinguish between episodes with + different absolute return magnitudes. + Best used in environments where the relative ordering of actions is more important + than the absolute scale of returns. """ super().__init__( policy=policy, @@ -92,7 +103,7 @@ def __init__( else: self.optim = self._create_optimizer(self.critic, optim, max_grad_norm=max_grad_norm) self.gamma = gamma - self.rew_norm = reward_normalization + self.return_scaling = return_scaling self.ret_rms = RunningMeanStd() self._eps = 1e-8 @@ -115,7 +126,7 @@ def _add_returns_and_advantages( # consistent with OPENAI baselines' value normalization pipeline. Empirical # study also shows that "minus mean" will harm performances a tiny little bit # due to unknown reasons (on Mujoco envs, not confident, though). - if self.rew_norm: # unnormalize v_s & v_s_ + if self.return_scaling: # unnormalize v_s & v_s_ v_s = v_s * np.sqrt(self.ret_rms.var + self._eps) v_s_ = v_s_ * np.sqrt(self.ret_rms.var + self._eps) unnormalized_returns, advantages = self.compute_episodic_return( @@ -127,7 +138,7 @@ def _add_returns_and_advantages( gamma=self.gamma, gae_lambda=self.gae_lambda, ) - if self.rew_norm: + if self.return_scaling: batch.returns = unnormalized_returns / np.sqrt(self.ret_rms.var + self._eps) self.ret_rms.update(unnormalized_returns) else: @@ -152,8 +163,7 @@ def __init__( gae_lambda: float = 0.95, max_batchsize: int = 256, gamma: float = 0.99, - # TODO: This algorithm does not seem to use the reward_normalization parameter. 
- reward_normalization: bool = False, + return_scaling: bool = False, ) -> None: """ :param policy: the policy containing the actor network. @@ -181,7 +191,18 @@ def __init__( potentially improving performance in tasks where delayed rewards are important but increasing training variance by incorporating more environmental stochasticity. Typically set between 0.9 and 0.99 for most reinforcement learning tasks - :param reward_normalization: normalize estimated values to have std close to 1. + :param return_scaling: flag indicating whether to enable scaling of estimated returns by + dividing them by their running standard deviation without centering the mean. + This reduces the magnitude variation of advantages across different episodes while + preserving their signs and relative ordering. + The use of running statistics (rather than batch-specific scaling) means that early + training experiences may be scaled differently than later ones as the statistics evolve. + When enabled, this improves training stability in environments with highly variable + reward scales and makes the algorithm less sensitive to learning rate settings. + However, it may reduce the algorithm's ability to distinguish between episodes with + different absolute return magnitudes. + Best used in environments where the relative ordering of actions is more important + than the absolute scale of returns. 
""" super().__init__( policy=policy, @@ -192,7 +213,7 @@ def __init__( gae_lambda=gae_lambda, max_batchsize=max_batchsize, gamma=gamma, - reward_normalization=reward_normalization, + return_scaling=return_scaling, ) self.vf_coef = vf_coef self.ent_coef = ent_coef diff --git a/tianshou/policy/modelfree/bdqn.py b/tianshou/policy/modelfree/bdqn.py index 63f3f8abf..ef1655341 100644 --- a/tianshou/policy/modelfree/bdqn.py +++ b/tianshou/policy/modelfree/bdqn.py @@ -109,7 +109,7 @@ def __init__( gamma: float = 0.99, estimation_step: int = 1, target_update_freq: int = 0, - reward_normalization: bool = False, + return_scaling: bool = False, is_double: bool = True, ) -> None: """ @@ -131,8 +131,9 @@ def __init__( complete episode returns. :param target_update_freq: the target network update frequency (0 if you do not use the target network). - :param reward_normalization: normalize the **returns** to Normal(0, 1). - TODO: rename to return_normalization? + :param return_scaling: flag indicating whether to scale/standardise returns to Normal(0, 1) based + on running mean and standard deviation. + Support for this is currently suspended and therefore the flag should not be enabled. :param is_double: whether to use double DQN. """ assert ( @@ -144,7 +145,6 @@ def __init__( gamma=gamma, estimation_step=estimation_step, target_update_freq=target_update_freq, - reward_normalization=reward_normalization, ) self.is_double = is_double diff --git a/tianshou/policy/modelfree/c51.py b/tianshou/policy/modelfree/c51.py index 448fa9b3c..c7414cd50 100644 --- a/tianshou/policy/modelfree/c51.py +++ b/tianshou/policy/modelfree/c51.py @@ -78,7 +78,7 @@ def __init__( gamma: float = 0.99, estimation_step: int = 1, target_update_freq: int = 0, - reward_normalization: bool = False, + return_scaling: bool = False, ) -> None: """ :param policy: a policy following the rules (s -> action_values_BA) @@ -99,8 +99,9 @@ def __init__( complete episode returns. 
:param target_update_freq: the target network update frequency (0 if you do not use the target network). - :param reward_normalization: normalize the **returns** to Normal(0, 1). - TODO: rename to return_normalization? + :param return_scaling: flag indicating whether to scale/standardise returns to Normal(0, 1) based + on running mean and standard deviation. + Support for this is currently suspended and therefore the flag should not be enabled. """ super().__init__( policy=policy, @@ -108,7 +109,6 @@ def __init__( gamma=gamma, estimation_step=estimation_step, target_update_freq=target_update_freq, - reward_normalization=reward_normalization, ) self.delta_z = (policy.v_max - policy.v_min) / (policy.num_atoms - 1) diff --git a/tianshou/policy/modelfree/dqn.py b/tianshou/policy/modelfree/dqn.py index 48aa7500b..41d77d970 100644 --- a/tianshou/policy/modelfree/dqn.py +++ b/tianshou/policy/modelfree/dqn.py @@ -192,7 +192,6 @@ def __init__( gamma: float = 0.99, estimation_step: int = 1, target_update_freq: int = 0, - reward_normalization: bool = False, ) -> None: """ :param policy: the policy @@ -213,8 +212,6 @@ def __init__( complete episode returns. :param target_update_freq: the frequency with which to update the weights of the target network; 0 if a target network shall not be used. - :param reward_normalization: normalize the **returns** to Normal(0, 1). - TODO: rename to return_normalization? 
""" super().__init__( policy=policy, @@ -227,7 +224,6 @@ def __init__( estimation_step > 0 ), f"estimation_step should be greater than 0 but got: {estimation_step}" self.n_step = estimation_step - self.rew_norm = reward_normalization self.target_update_freq = target_update_freq # TODO: 1 would be a more reasonable initialization given how it is incremented self._iter = 0 @@ -264,7 +260,6 @@ def _preprocess_batch( target_q_fn=self._target_q, gamma=self.gamma, n_step=self.n_step, - rew_norm=self.rew_norm, ) def _periodically_update_lagged_network_weights(self) -> None: @@ -298,7 +293,7 @@ def __init__( gamma: float = 0.99, estimation_step: int = 1, target_update_freq: int = 0, - reward_normalization: bool = False, + return_scaling: bool = False, is_double: bool = True, clip_loss_grad: bool = False, ) -> None: @@ -321,8 +316,9 @@ def __init__( complete episode returns. :param target_update_freq: the frequency with which to update the weights of the target network; 0 if a target network shall not be used. - :param reward_normalization: normalize the **returns** to Normal(0, 1). - TODO: rename to return_normalization? + :param return_scaling: flag indicating whether to scale/standardise returns to Normal(0, 1) based + on running mean and standard deviation. + Support for this is currently suspended and therefore the flag should not be enabled. :param is_double: use double dqn. 
:param clip_loss_grad: clip the gradient of the loss in accordance with nature14236; this amounts to using the Huber loss instead of @@ -334,7 +330,6 @@ def __init__( gamma=gamma, estimation_step=estimation_step, target_update_freq=target_update_freq, - reward_normalization=reward_normalization, ) self.is_double = is_double self.clip_loss_grad = clip_loss_grad diff --git a/tianshou/policy/modelfree/fqf.py b/tianshou/policy/modelfree/fqf.py index c2a347561..9ac4fd6cd 100644 --- a/tianshou/policy/modelfree/fqf.py +++ b/tianshou/policy/modelfree/fqf.py @@ -123,7 +123,7 @@ def __init__( ent_coef: float = 0.0, estimation_step: int = 1, target_update_freq: int = 0, - reward_normalization: bool = False, + return_scaling: bool = False, ) -> None: """ :param policy: the policy @@ -148,8 +148,9 @@ def __init__( complete episode returns. :param target_update_freq: the target network update frequency (0 if you do not use the target network). - :param reward_normalization: normalize the **returns** to Normal(0, 1). - TODO: rename to return_normalization? + :param return_scaling: flag indicating whether to scale/standardise returns to Normal(0, 1) based + on running mean and standard deviation. + Support for this is currently suspended and therefore the flag should not be enabled. 
""" super().__init__( policy=policy, @@ -158,7 +159,7 @@ def __init__( num_quantiles=num_fractions, estimation_step=estimation_step, target_update_freq=target_update_freq, - reward_normalization=reward_normalization, + return_scaling=return_scaling, ) self.ent_coef = ent_coef self.fraction_optim = self._create_optimizer(self.policy.fraction_model, fraction_optim) diff --git a/tianshou/policy/modelfree/iqn.py b/tianshou/policy/modelfree/iqn.py index dad67d621..54363b949 100644 --- a/tianshou/policy/modelfree/iqn.py +++ b/tianshou/policy/modelfree/iqn.py @@ -115,7 +115,7 @@ def __init__( num_quantiles: int = 200, estimation_step: int = 1, target_update_freq: int = 0, - reward_normalization: bool = False, + return_scaling: bool = False, ) -> None: """ :param policy: the policy @@ -138,8 +138,9 @@ def __init__( complete episode returns. :param target_update_freq: the target network update frequency (0 if you do not use the target network). - :param reward_normalization: normalize the **returns** to Normal(0, 1). - TODO: rename to return_normalization? + :param return_scaling: flag indicating whether to scale/standardise returns to Normal(0, 1) based + on running mean and standard deviation. + Support for this is currently suspended and therefore the flag should not be enabled. """ super().__init__( policy=policy, @@ -148,7 +149,7 @@ def __init__( num_quantiles=num_quantiles, estimation_step=estimation_step, target_update_freq=target_update_freq, - reward_normalization=reward_normalization, + return_scaling=return_scaling, ) def _update_with_batch( diff --git a/tianshou/policy/modelfree/npg.py b/tianshou/policy/modelfree/npg.py index df682f23e..0e7674d22 100644 --- a/tianshou/policy/modelfree/npg.py +++ b/tianshou/policy/modelfree/npg.py @@ -42,8 +42,7 @@ def __init__( gae_lambda: float = 0.95, max_batchsize: int = 256, gamma: float = 0.99, - # TODO: rename to return_normalization? 
- reward_normalization: bool = False, + return_scaling: bool = False, ) -> None: """ :param policy: the policy containing the actor network. @@ -90,7 +89,18 @@ def __init__( potentially improving performance in tasks where delayed rewards are important but increasing training variance by incorporating more environmental stochasticity. Typically set between 0.9 and 0.99 for most reinforcement learning tasks - :param reward_normalization: normalize estimated values to have std close to 1. + :param return_scaling: flag indicating whether to enable scaling of estimated returns by + dividing them by their running standard deviation without centering the mean. + This reduces the magnitude variation of advantages across different episodes while + preserving their signs and relative ordering. + The use of running statistics (rather than batch-specific scaling) means that early + training experiences may be scaled differently than later ones as the statistics evolve. + When enabled, this improves training stability in environments with highly variable + reward scales and makes the algorithm less sensitive to learning rate settings. + However, it may reduce the algorithm's ability to distinguish between episodes with + different absolute return magnitudes. + Best used in environments where the relative ordering of actions is more important + than the absolute scale of returns. 
""" super().__init__( policy=policy, @@ -100,7 +110,7 @@ def __init__( gae_lambda=gae_lambda, max_batchsize=max_batchsize, gamma=gamma, - reward_normalization=reward_normalization, + return_scaling=return_scaling, ) self.norm_adv = advantage_normalization self.optim_critic_iters = optim_critic_iters diff --git a/tianshou/policy/modelfree/pg.py b/tianshou/policy/modelfree/pg.py index 34a98c0f6..019ec74b4 100644 --- a/tianshou/policy/modelfree/pg.py +++ b/tianshou/policy/modelfree/pg.py @@ -243,7 +243,7 @@ class DiscountedReturnComputation: def __init__( self, gamma: float = 0.99, - reward_normalization: bool = False, + return_standardization: bool = False, ): """ :param gamma: the discount factor in [0, 1] for future rewards. @@ -253,13 +253,13 @@ def __init__( potentially improving performance in tasks where delayed rewards are important but increasing training variance by incorporating more environmental stochasticity. Typically set between 0.9 and 0.99 for most reinforcement learning tasks - :param reward_normalization: if True, will normalize the *returns* + :param return_standardization: whether to standardize episode returns by subtracting the running mean and dividing by the running standard deviation. - Can be detrimental to performance! + Note that this is known to be detrimental to performance in many cases! """ assert 0.0 <= gamma <= 1.0, "discount factor gamma should be in [0, 1]" self.gamma = gamma - self.rew_norm = reward_normalization + self.return_standardization = return_standardization self.ret_rms = RunningMeanStd() self.eps = 1e-8 @@ -295,10 +295,7 @@ def add_discounted_returns( gamma=self.gamma, gae_lambda=1.0, ) - # TODO: overridden in A2C, where mean is not subtracted. Subtracting mean - # can be very detrimental! It also has no theoretical grounding. - # This should be addressed soon! 
- if self.rew_norm: + if self.return_standardization: batch.returns = (unnormalized_returns - self.ret_rms.mean) / np.sqrt( self.ret_rms.var + self.eps, ) @@ -317,7 +314,7 @@ def __init__( *, policy: TActorPolicy, gamma: float = 0.99, - reward_normalization: bool = False, + return_standardization: bool = False, optim: OptimizerFactory, ) -> None: """ @@ -330,7 +327,7 @@ def __init__( potentially improving performance in tasks where delayed rewards are important but increasing training variance by incorporating more environmental stochasticity. Typically set between 0.9 and 0.99 for most reinforcement learning tasks - :param reward_normalization: if True, will normalize the *returns* + :param return_standardization: if True, will scale/standardize returns by subtracting the running mean and dividing by the running standard deviation. Can be detrimental to performance! """ @@ -339,7 +336,7 @@ def __init__( ) self.discounted_return_computation = DiscountedReturnComputation( gamma=gamma, - reward_normalization=reward_normalization, + return_standardization=return_standardization, ) self.optim = self._create_optimizer(self.policy, optim) diff --git a/tianshou/policy/modelfree/ppo.py b/tianshou/policy/modelfree/ppo.py index 4fb3e8909..c9fdc6874 100644 --- a/tianshou/policy/modelfree/ppo.py +++ b/tianshou/policy/modelfree/ppo.py @@ -14,13 +14,7 @@ class PPO(A2C): - r"""Implementation of Proximal Policy Optimization. arXiv:1707.06347. - - .. seealso:: - - Please refer to :class:`~tianshou.policy.BasePolicy` for more detailed - explanation. - """ + """Implementation of Proximal Policy Optimization. arXiv:1707.06347.""" def __init__( self, @@ -39,8 +33,7 @@ def __init__( gae_lambda: float = 0.95, max_batchsize: int = 256, gamma: float = 0.99, - # TODO: rename to return_normalization? - reward_normalization: bool = False, + return_scaling: bool = False, ) -> None: r""" :param policy: the policy containing the actor network. 
@@ -103,7 +96,18 @@ def __init__( potentially improving performance in tasks where delayed rewards are important but increasing training variance by incorporating more environmental stochasticity. Typically set between 0.9 and 0.99 for most reinforcement learning tasks - :param reward_normalization: normalize estimated values to have std close to 1. + :param return_scaling: flag indicating whether to enable scaling of estimated returns by + dividing them by their running standard deviation without centering the mean. + This reduces the magnitude variation of advantages across different episodes while + preserving their signs and relative ordering. + The use of running statistics (rather than batch-specific scaling) means that early + training experiences may be scaled differently than later ones as the statistics evolve. + When enabled, this improves training stability in environments with highly variable + reward scales and makes the algorithm less sensitive to learning rate settings. + However, it may reduce the algorithm's ability to distinguish between episodes with + different absolute return magnitudes. + Best used in environments where the relative ordering of actions is more important + than the absolute scale of returns. 
""" assert ( dual_clip is None or dual_clip > 1.0 @@ -119,7 +123,7 @@ def __init__( gae_lambda=gae_lambda, max_batchsize=max_batchsize, gamma=gamma, - reward_normalization=reward_normalization, + return_scaling=return_scaling, ) self.eps_clip = eps_clip self.dual_clip = dual_clip diff --git a/tianshou/policy/modelfree/qrdqn.py b/tianshou/policy/modelfree/qrdqn.py index 6f7a67a89..56be073d3 100644 --- a/tianshou/policy/modelfree/qrdqn.py +++ b/tianshou/policy/modelfree/qrdqn.py @@ -38,7 +38,7 @@ def __init__( num_quantiles: int = 200, estimation_step: int = 1, target_update_freq: int = 0, - reward_normalization: bool = False, + return_scaling: bool = False, ) -> None: """ :param policy: the policy @@ -61,8 +61,9 @@ def __init__( complete episode returns. :param target_update_freq: the target network update frequency (0 if you do not use the target network). - :param reward_normalization: normalize the **returns** to Normal(0, 1). - TODO: rename to return_normalization? + :param return_scaling: flag indicating whether to scale/standardise returns to Normal(0, 1) based + on running mean and standard deviation. + Support for this is currently suspended and therefore the flag should not be enabled. """ assert num_quantiles > 1, f"num_quantiles should be greater than 1 but got: {num_quantiles}" super().__init__( @@ -71,7 +72,6 @@ def __init__( gamma=gamma, estimation_step=estimation_step, target_update_freq=target_update_freq, - reward_normalization=reward_normalization, ) self.num_quantiles = num_quantiles tau = torch.linspace(0, 1, self.num_quantiles + 1) diff --git a/tianshou/policy/modelfree/trpo.py b/tianshou/policy/modelfree/trpo.py index c73bc06b0..2bbe43e76 100644 --- a/tianshou/policy/modelfree/trpo.py +++ b/tianshou/policy/modelfree/trpo.py @@ -38,8 +38,7 @@ def __init__( gae_lambda: float = 0.95, max_batchsize: int = 256, gamma: float = 0.99, - # TODO: rename to return_normalization? 
- reward_normalization: bool = False, + return_scaling: bool = False, ) -> None: """ :param policy: the policy @@ -90,7 +89,18 @@ def __init__( potentially improving performance in tasks where delayed rewards are important but increasing training variance by incorporating more environmental stochasticity. Typically set between 0.9 and 0.99 for most reinforcement learning tasks - :param reward_normalization: normalize estimated values to have std close to 1. + :param return_scaling: flag indicating whether to enable scaling of estimated returns by + dividing them by their running standard deviation without centering the mean. + This reduces the magnitude variation of advantages across different episodes while + preserving their signs and relative ordering. + The use of running statistics (rather than batch-specific scaling) means that early + training experiences may be scaled differently than later ones as the statistics evolve. + When enabled, this improves training stability in environments with highly variable + reward scales and makes the algorithm less sensitive to learning rate settings. + However, it may reduce the algorithm's ability to distinguish between episodes with + different absolute return magnitudes. + Best used in environments where the relative ordering of actions is more important + than the absolute scale of returns. 
""" super().__init__( policy=policy, @@ -102,7 +112,7 @@ def __init__( gae_lambda=gae_lambda, max_batchsize=max_batchsize, gamma=gamma, - reward_normalization=reward_normalization, + return_scaling=return_scaling, ) self.max_backtracks = max_backtracks self.max_kl = max_kl From ab603d54331d063add56396d55cb34f4a72cd5ad Mon Sep 17 00:00:00 2001 From: Dominik Jain Date: Mon, 5 May 2025 21:16:14 +0200 Subject: [PATCH 119/230] v2: Update parameter references (discount_factor -> gamma) --- examples/atari/atari_dqn_hl.py | 2 +- examples/atari/atari_iqn_hl.py | 2 +- examples/atari/atari_ppo_hl.py | 2 +- examples/discrete/discrete_dqn_hl.py | 2 +- examples/mujoco/mujoco_a2c_hl.py | 2 +- examples/mujoco/mujoco_npg_hl.py | 2 +- examples/mujoco/mujoco_ppo_hl.py | 2 +- examples/mujoco/mujoco_ppo_hl_multi.py | 2 +- examples/mujoco/mujoco_reinforce_hl.py | 2 +- examples/mujoco/mujoco_trpo_hl.py | 2 +- 10 files changed, 10 insertions(+), 10 deletions(-) diff --git a/examples/atari/atari_dqn_hl.py b/examples/atari/atari_dqn_hl.py index 63f7b5afe..6d76b5dde 100644 --- a/examples/atari/atari_dqn_hl.py +++ b/examples/atari/atari_dqn_hl.py @@ -77,7 +77,7 @@ def main( DQNExperimentBuilder(env_factory, experiment_config, training_config) .with_dqn_params( DQNParams( - discount_factor=gamma, + gamma=gamma, estimation_step=n_step, lr=lr, target_update_freq=target_update_freq, diff --git a/examples/atari/atari_iqn_hl.py b/examples/atari/atari_iqn_hl.py index e12267010..5a18dd82f 100644 --- a/examples/atari/atari_iqn_hl.py +++ b/examples/atari/atari_iqn_hl.py @@ -75,7 +75,7 @@ def main( IQNExperimentBuilder(env_factory, experiment_config, training_config) .with_iqn_params( IQNParams( - discount_factor=gamma, + gamma=gamma, estimation_step=n_step, lr=lr, sample_size=sample_size, diff --git a/examples/atari/atari_ppo_hl.py b/examples/atari/atari_ppo_hl.py index 32673f937..92304946d 100644 --- a/examples/atari/atari_ppo_hl.py +++ b/examples/atari/atari_ppo_hl.py @@ -83,7 +83,7 @@ def main( 
PPOExperimentBuilder(env_factory, experiment_config, training_config) .with_ppo_params( PPOParams( - discount_factor=gamma, + gamma=gamma, gae_lambda=gae_lambda, reward_normalization=rew_norm, ent_coef=ent_coef, diff --git a/examples/discrete/discrete_dqn_hl.py b/examples/discrete/discrete_dqn_hl.py index 5f2e3ec97..5f85b2b51 100644 --- a/examples/discrete/discrete_dqn_hl.py +++ b/examples/discrete/discrete_dqn_hl.py @@ -41,7 +41,7 @@ def main() -> None: .with_dqn_params( DQNParams( lr=1e-3, - discount_factor=0.9, + gamma=0.9, estimation_step=3, target_update_freq=320, eps_training=0.3, diff --git a/examples/mujoco/mujoco_a2c_hl.py b/examples/mujoco/mujoco_a2c_hl.py index 6f2fb6fa4..788dc1e18 100644 --- a/examples/mujoco/mujoco_a2c_hl.py +++ b/examples/mujoco/mujoco_a2c_hl.py @@ -65,7 +65,7 @@ def main( A2CExperimentBuilder(env_factory, experiment_config, training_config) .with_a2c_params( A2CParams( - discount_factor=gamma, + gamma=gamma, gae_lambda=gae_lambda, action_bound_method=bound_action_method, reward_normalization=rew_norm, diff --git a/examples/mujoco/mujoco_npg_hl.py b/examples/mujoco/mujoco_npg_hl.py index dcebfbe93..3b2630c5d 100644 --- a/examples/mujoco/mujoco_npg_hl.py +++ b/examples/mujoco/mujoco_npg_hl.py @@ -64,7 +64,7 @@ def main( NPGExperimentBuilder(env_factory, experiment_config, training_config) .with_npg_params( NPGParams( - discount_factor=gamma, + gamma=gamma, gae_lambda=gae_lambda, action_bound_method=bound_action_method, return_standardization=rew_norm, diff --git a/examples/mujoco/mujoco_ppo_hl.py b/examples/mujoco/mujoco_ppo_hl.py index 997173084..b0c021e1f 100644 --- a/examples/mujoco/mujoco_ppo_hl.py +++ b/examples/mujoco/mujoco_ppo_hl.py @@ -69,7 +69,7 @@ def main( PPOExperimentBuilder(env_factory, experiment_config, training_config) .with_ppo_params( PPOParams( - discount_factor=gamma, + gamma=gamma, gae_lambda=gae_lambda, action_bound_method=bound_action_method, reward_normalization=rew_norm, diff --git 
a/examples/mujoco/mujoco_ppo_hl_multi.py b/examples/mujoco/mujoco_ppo_hl_multi.py index a8d3ae828..cc2ccedc9 100644 --- a/examples/mujoco/mujoco_ppo_hl_multi.py +++ b/examples/mujoco/mujoco_ppo_hl_multi.py @@ -98,7 +98,7 @@ def main( PPOExperimentBuilder(env_factory, experiment_config, training_config) .with_ppo_params( PPOParams( - discount_factor=0.99, + gamma=0.99, gae_lambda=0.95, action_bound_method="clip", reward_normalization=True, diff --git a/examples/mujoco/mujoco_reinforce_hl.py b/examples/mujoco/mujoco_reinforce_hl.py index 4bdd2918e..9fd13e462 100644 --- a/examples/mujoco/mujoco_reinforce_hl.py +++ b/examples/mujoco/mujoco_reinforce_hl.py @@ -60,7 +60,7 @@ def main( ReinforceExperimentBuilder(env_factory, experiment_config, training_config) .with_reinforce_params( ReinforceParams( - discount_factor=gamma, + gamma=gamma, action_bound_method=action_bound_method, return_standardization=rew_norm, lr=lr, diff --git a/examples/mujoco/mujoco_trpo_hl.py b/examples/mujoco/mujoco_trpo_hl.py index f3531c1e5..310d2a39c 100644 --- a/examples/mujoco/mujoco_trpo_hl.py +++ b/examples/mujoco/mujoco_trpo_hl.py @@ -66,7 +66,7 @@ def main( TRPOExperimentBuilder(env_factory, experiment_config, training_config) .with_trpo_params( TRPOParams( - discount_factor=gamma, + gamma=gamma, gae_lambda=gae_lambda, action_bound_method=bound_action_method, return_standardization=rew_norm, From e2f1ca2b97680cabc033429e646fa879f7aee7a6 Mon Sep 17 00:00:00 2001 From: Dominik Jain Date: Mon, 5 May 2025 21:47:54 +0200 Subject: [PATCH 120/230] Fix message assignment --- tianshou/utils/determinism.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tianshou/utils/determinism.py b/tianshou/utils/determinism.py index a4d79e73b..c29c72898 100644 --- a/tianshou/utils/determinism.py +++ b/tianshou/utils/determinism.py @@ -351,11 +351,11 @@ def check( if core_messages_changed_only: status_message = ( "The behaviour log has changed, but the core messages are still the same (so 
this " - f"probably isn't an issue). {main_message}", + f"probably isn't an issue). {main_message}" ) else: status_message = ( - f"The behaviour log has changed; even the core messages are different. {main_message}", + f"The behaviour log has changed; even the core messages are different. {main_message}" ) # write log message From 78d42fd985cd599b00270b642063c37f391555c1 Mon Sep 17 00:00:00 2001 From: Dominik Jain Date: Mon, 5 May 2025 23:32:31 +0200 Subject: [PATCH 121/230] v2: Fix handling of eps in DiscreteBCQ (moved to policy, inheriting from DQNPolicy) Changed default value from 1e-3 to 0.0 (no randomness by default) --- examples/offline/atari_bcq.py | 2 +- test/offline/test_discrete_bcq.py | 2 +- tianshou/policy/imitation/discrete_bcq.py | 19 ++++++++++++------- 3 files changed, 14 insertions(+), 9 deletions(-) diff --git a/examples/offline/atari_bcq.py b/examples/offline/atari_bcq.py index d4cb49fed..5ef5aa59f 100644 --- a/examples/offline/atari_bcq.py +++ b/examples/offline/atari_bcq.py @@ -123,6 +123,7 @@ def main(args: argparse.Namespace = get_args()) -> None: imitator=imitation_net, action_space=env.action_space, unlikely_action_threshold=args.unlikely_action_threshold, + eps_inference=args.eps_test, ) algorithm: DiscreteBCQ = DiscreteBCQ( policy=policy, @@ -130,7 +131,6 @@ def main(args: argparse.Namespace = get_args()) -> None: gamma=args.gamma, estimation_step=args.n_step, target_update_freq=args.target_update_freq, - eval_eps=args.eps_test, imitation_logits_penalty=args.imitation_logits_penalty, ) # load a previous policy diff --git a/test/offline/test_discrete_bcq.py b/test/offline/test_discrete_bcq.py index 6f5635237..b998d2cf6 100644 --- a/test/offline/test_discrete_bcq.py +++ b/test/offline/test_discrete_bcq.py @@ -97,6 +97,7 @@ def test_discrete_bcq( imitator=imitation_net, action_space=env.action_space, unlikely_action_threshold=args.unlikely_action_threshold, + eps_inference=args.eps_test, ) algorithm: DiscreteBCQ = DiscreteBCQ( policy=policy, 
@@ -104,7 +105,6 @@ def test_discrete_bcq( gamma=args.gamma, estimation_step=args.n_step, target_update_freq=args.target_update_freq, - eval_eps=args.eps_test, imitation_logits_penalty=args.imitation_logits_penalty, ) diff --git a/tianshou/policy/imitation/discrete_bcq.py b/tianshou/policy/imitation/discrete_bcq.py index c53f6aa1b..7022c00f2 100644 --- a/tianshou/policy/imitation/discrete_bcq.py +++ b/tianshou/policy/imitation/discrete_bcq.py @@ -17,8 +17,8 @@ from tianshou.policy.base import ( LaggedNetworkFullUpdateAlgorithmMixin, OfflineAlgorithm, - Policy, ) +from tianshou.policy.modelfree.dqn import DQNPolicy from tianshou.policy.modelfree.pg import SimpleLossTrainingStats from tianshou.policy.optim import OptimizerFactory @@ -33,7 +33,7 @@ class DiscreteBCQTrainingStats(SimpleLossTrainingStats): reg_loss: float -class DiscreteBCQPolicy(Policy): +class DiscreteBCQPolicy(DQNPolicy): def __init__( self, *, @@ -43,6 +43,7 @@ def __init__( unlikely_action_threshold: float = 0.3, action_space: gym.spaces.Discrete, observation_space: gym.Space | None = None, + eps_inference: float = 0.0, ) -> None: """ :param model: a model following the rules (s_B -> action_values_BA) @@ -55,12 +56,20 @@ def __init__( you do not use the target network). :param action_space: the environment's action space. :param observation_space: the environment's observation space. + :param eps_inference: the epsilon value for epsilon-greedy exploration during inference, + i.e. non-training cases (such as evaluation during test steps). + The epsilon value is the probability of choosing a random action instead of the action + chosen by the policy. + A value of 0.0 means no exploration (fully greedy) and a value of 1.0 means full + exploration (fully random). 
""" super().__init__( + model=model, action_space=action_space, observation_space=observation_space, + eps_training=0.0, # no training data collection (offline) + eps_inference=eps_inference, ) - self.model = model self.imitator = imitator assert ( target_update_freq > 0 @@ -108,7 +117,6 @@ def __init__( gamma: float = 0.99, estimation_step: int = 1, target_update_freq: int = 8000, - eval_eps: float = 1e-3, imitation_logits_penalty: float = 1e-2, is_double: bool = True, clip_loss_grad: bool = False, @@ -131,7 +139,6 @@ def __init__( bootstrapping, while very large values approach Monte Carlo-like estimation that uses complete episode returns. :param target_update_freq: the target network update frequency. - :param eval_eps: the epsilon-greedy noise added in evaluation. :param imitation_logits_penalty: regularization weight for imitation logits. :param estimation_step: the number of future steps (> 0) to consider when computing temporal @@ -166,8 +173,6 @@ def __init__( self.model_old = self._add_lagged_network(self.policy.model) self.is_double = is_double self.clip_loss_grad = clip_loss_grad - assert 0.0 <= eval_eps < 1.0 - self.eps = eval_eps self._weight_reg = imitation_logits_penalty def _preprocess_batch( From 521123ac8111e941a2da7a6d3143960a5539d55b Mon Sep 17 00:00:00 2001 From: Dominik Jain Date: Mon, 5 May 2025 23:37:03 +0200 Subject: [PATCH 122/230] v2: Disable determinism tests by default (only on demand) --- test/determinism_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/determinism_test.py b/test/determinism_test.py index cedc08007..7dfdb2dfb 100644 --- a/test/determinism_test.py +++ b/test/determinism_test.py @@ -38,7 +38,7 @@ class AlgorithmDeterminismTest: 3. Inspect determinism_tests.log """ - ENABLED = True + ENABLED = False """ whether determinism tests are enabled. 
""" From c7fd6612dee8da9f05afb773ad4c4bcc7ac0edd0 Mon Sep 17 00:00:00 2001 From: Dominik Jain Date: Tue, 6 May 2025 12:17:22 +0200 Subject: [PATCH 123/230] v2: Update Algorithm method name references --- CHANGELOG.md | 8 ++++---- README.md | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 72f44dae7..928689012 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -68,10 +68,10 @@ * `MARLRandomPolicy` -> `MARLRandomDiscreteMaskedOffPolicyAlgorithm` For the respective subtype of `Policy` to use, see the respective algorithm class' constructor. * Interface changes/improvements: - * Core methods have been renamed: - * `process_fn` -> `preprocess_batch` - * `post_process_fn` -> `postprocess_batch` - * `learn` -> `_update_with_batch` (no longer in public interface) + * Core methods have been renamed (and removed from the public interface): + * `process_fn` -> `_preprocess_batch` + * `post_process_fn` -> `_postprocess_batch` + * `learn` -> `_update_with_batch` * The updating interface has been cleaned up: * Functions `update` and `_update_with_batch` (formerly `learn`) no longer have `*args` and `**kwargs`. * Instead, the interfaces for the offline, off-policy and on-policy cases are properly differentiated. diff --git a/README.md b/README.md index 5b6309527..5dbb66226 100644 --- a/README.md +++ b/README.md @@ -190,7 +190,7 @@ Reinforcement learning algorithms are build on abstractions for all of which clearly separate the core algorithm from the training process and the respective environment interactions. 
In each case, the implementation of an algorithm necessarily involves only the implementation of methods for - * pre-processing a batch of data, augmenting it with necessary information/sufficient statistics for learning (`preprocess_batch`), + * pre-processing a batch of data, augmenting it with necessary information/sufficient statistics for learning (`_preprocess_batch`), * updating model parameters based on an augmented batch of data (`_update_with_batch`). The implementation of these methods suffices for a new algorithm to be applicable within Tianshou, From 3107bd39685c387c3be9faf6ff4a9e45ecab04f5 Mon Sep 17 00:00:00 2001 From: Dominik Jain Date: Tue, 6 May 2025 12:19:17 +0200 Subject: [PATCH 124/230] v2: Remove obsolete module utils.optim (superseded by Algorithm.Optimizer) --- tianshou/utils/optim.py | 27 --------------------------- 1 file changed, 27 deletions(-) delete mode 100644 tianshou/utils/optim.py diff --git a/tianshou/utils/optim.py b/tianshou/utils/optim.py deleted file mode 100644 index ce59edce1..000000000 --- a/tianshou/utils/optim.py +++ /dev/null @@ -1,27 +0,0 @@ -import torch -from torch import nn - - -def optim_step( - loss: torch.Tensor, - optim: torch.optim.Optimizer, - module: nn.Module | None = None, - max_grad_norm: float | None = None, -) -> None: - """Perform a single optimization step: zero_grad -> backward (-> clip_grad_norm) -> step. - - :param loss: - :param optim: - :param module: the module to optimize, required if max_grad_norm is passed - :param max_grad_norm: if passed, will clip gradients using this - """ - optim.zero_grad() - loss.backward() - if max_grad_norm: - if not module: - raise ValueError( - "module must be passed if max_grad_norm is passed. 
" - "Note: often the module will be the policy, i.e.`self`", - ) - nn.utils.clip_grad_norm_(module.parameters(), max_norm=max_grad_norm) - optim.step() From d2c75eb0501b3654f31efa74f2ba0ba79ebe4d39 Mon Sep 17 00:00:00 2001 From: Dominik Jain Date: Tue, 6 May 2025 12:33:05 +0200 Subject: [PATCH 125/230] v2: Improve description of parameter 'max_grad_norm' --- tianshou/highlevel/params/policy_params.py | 8 +++++++- tianshou/policy/base.py | 6 +++++- tianshou/policy/imitation/gail.py | 6 +++++- tianshou/policy/modelfree/a2c.py | 13 ++++++++++--- tianshou/policy/modelfree/ppo.py | 6 +++++- 5 files changed, 32 insertions(+), 7 deletions(-) diff --git a/tianshou/highlevel/params/policy_params.py b/tianshou/highlevel/params/policy_params.py index 57e826d5a..9934ac98a 100644 --- a/tianshou/highlevel/params/policy_params.py +++ b/tianshou/highlevel/params/policy_params.py @@ -412,7 +412,13 @@ class A2CParams(ActorCriticOnPolicyParams, ParamsMixinGeneralAdvantageEstimation ent_coef: float = 0.01 """weight (coefficient) of the entropy loss in the loss function""" max_grad_norm: float | None = None - """maximum norm for clipping gradients in backpropagation""" + """ + the maximum L2 norm threshold for gradient clipping. + When not None, gradients will be rescaled using to ensure their L2 norm does not + exceed this value. This prevents exploding gradients and stabilizes training by + limiting the magnitude of parameter updates. + Set to None to disable gradient clipping. 
+ """ def _get_param_transformers(self) -> list[ParamTransformer]: transformers = super()._get_param_transformers() diff --git a/tianshou/policy/base.py b/tianshou/policy/base.py index 0c550b7b9..3064f70ca 100644 --- a/tianshou/policy/base.py +++ b/tianshou/policy/base.py @@ -509,7 +509,11 @@ def __init__( """ :param optim: the optimizer :param module: the module whose parameters are being affected by `optim` - :param max_grad_norm: the maximum gradient norm for gradient clipping; if None, do not apply gradient clipping + :param max_grad_norm: the maximum L2 norm threshold for gradient clipping. + When not None, gradients will be rescaled using to ensure their L2 norm does not + exceed this value. This prevents exploding gradients and stabilizes training by + limiting the magnitude of parameter updates. + Set to None to disable gradient clipping. """ super().__init__() self._optim = optim diff --git a/tianshou/policy/imitation/gail.py b/tianshou/policy/imitation/gail.py index a655d7640..a9161b48a 100644 --- a/tianshou/policy/imitation/gail.py +++ b/tianshou/policy/imitation/gail.py @@ -101,7 +101,11 @@ def __init__( repeat according to https://arxiv.org/pdf/2006.05990.pdf Sec. 3.5. :param vf_coef: weight for value loss. :param ent_coef: weight for entropy loss. - :param max_grad_norm: clipping gradients in back propagation. + :param max_grad_norm: the maximum L2 norm threshold for gradient clipping. + When not None, gradients will be rescaled using to ensure their L2 norm does not + exceed this value. This prevents exploding gradients and stabilizes training by + limiting the size of parameter updates. + Set to None to disable gradient clipping. :param gae_lambda: the lambda parameter in [0, 1] for generalized advantage estimation (GAE). Controls the bias-variance tradeoff in advantage estimates, acting as a weighting factor for combining different n-step advantage estimators. 
Higher values diff --git a/tianshou/policy/modelfree/a2c.py b/tianshou/policy/modelfree/a2c.py index 0eabea5a5..fcf789212 100644 --- a/tianshou/policy/modelfree/a2c.py +++ b/tianshou/policy/modelfree/a2c.py @@ -50,8 +50,11 @@ def __init__( :param optim: the optimizer factory. :param optim_include_actor: whether the optimizer shall include the actor network's parameters. Pass False for algorithms that shall update only the critic via the optimizer. - :param max_grad_norm: the maximum gradient norm for gradient clipping; if None, gradient clipping - is not applied + :param max_grad_norm: the maximum L2 norm threshold for gradient clipping. + When not None, gradients will be rescaled using to ensure their L2 norm does not + exceed this value. This prevents exploding gradients and stabilizes training by + limiting the magnitude of parameter updates. + Set to None to disable gradient clipping. :param gae_lambda: the lambda parameter in [0, 1] for generalized advantage estimation (GAE). Controls the bias-variance tradeoff in advantage estimates, acting as a weighting factor for combining different n-step advantage estimators. Higher values @@ -171,7 +174,11 @@ def __init__( :param optim: the optimizer factory. :param vf_coef: weight for value loss. :param ent_coef: weight for entropy loss. - :param max_grad_norm: clipping gradients in back propagation. + :param max_grad_norm: the maximum L2 norm threshold for gradient clipping. + When not None, gradients will be rescaled using to ensure their L2 norm does not + exceed this value. This prevents exploding gradients and stabilizes training by + limiting the magnitude of parameter updates. + Set to None to disable gradient clipping. :param gae_lambda: the lambda parameter in [0, 1] for generalized advantage estimation (GAE). Controls the bias-variance tradeoff in advantage estimates, acting as a weighting factor for combining different n-step advantage estimators. 
Higher values diff --git a/tianshou/policy/modelfree/ppo.py b/tianshou/policy/modelfree/ppo.py index c9fdc6874..27738838b 100644 --- a/tianshou/policy/modelfree/ppo.py +++ b/tianshou/policy/modelfree/ppo.py @@ -76,7 +76,11 @@ def __init__( repeat according to https://arxiv.org/pdf/2006.05990.pdf Sec. 3.5. :param vf_coef: weight for value loss. :param ent_coef: weight for entropy loss. - :param max_grad_norm: clipping gradients in back propagation. + :param max_grad_norm: the maximum L2 norm threshold for gradient clipping. + When not None, gradients will be rescaled using to ensure their L2 norm does not + exceed this value. This prevents exploding gradients and stabilizes training by + limiting the magnitude of parameter updates. + Set to None to disable gradient clipping. :param gae_lambda: the lambda parameter in [0, 1] for generalized advantage estimation (GAE). Controls the bias-variance tradeoff in advantage estimates, acting as a weighting factor for combining different n-step advantage estimators. 
Higher values From 26c588bf92e671c7f8d42df6d525ab6a53a5a85a Mon Sep 17 00:00:00 2001 From: Dominik Jain Date: Tue, 6 May 2025 12:47:08 +0200 Subject: [PATCH 126/230] v2: Improve descriptions of parameters 'vf_coef' and 'ent_coef' --- tianshou/highlevel/params/policy_params.py | 16 ++++++++++++++-- tianshou/policy/imitation/gail.py | 12 ++++++++++-- tianshou/policy/modelfree/a2c.py | 12 ++++++++++-- tianshou/policy/modelfree/fqf.py | 6 +++++- tianshou/policy/modelfree/ppo.py | 12 ++++++++++-- 5 files changed, 49 insertions(+), 9 deletions(-) diff --git a/tianshou/highlevel/params/policy_params.py b/tianshou/highlevel/params/policy_params.py index 9934ac98a..ec4478ccd 100644 --- a/tianshou/highlevel/params/policy_params.py +++ b/tianshou/highlevel/params/policy_params.py @@ -408,9 +408,21 @@ class ActorCriticOnPolicyParams(OnPolicyAlgorithmParams): @dataclass(kw_only=True) class A2CParams(ActorCriticOnPolicyParams, ParamsMixinGeneralAdvantageEstimation): vf_coef: float = 0.5 - """weight (coefficient) of the value loss in the loss function""" + """ + coefficient that weights the value loss relative to the actor loss in the overall + loss function. + Higher values prioritize accurate value function estimation over policy improvement. + Controls the trade-off between policy optimization and value function fitting. + Typically set between 0.5 and 1.0 for most actor-critic implementations. + """ ent_coef: float = 0.01 - """weight (coefficient) of the entropy loss in the loss function""" + """ + coefficient that weights the entropy bonus relative to the actor loss. + Controls the exploration-exploitation trade-off by encouraging policy entropy. + Higher values promote more exploration by encouraging a more uniform action distribution. + Lower values focus more on exploitation of the current policy's knowledge. + Typically set between 0.01 and 0.05 for most actor-critic implementations. 
+ """ max_grad_norm: float | None = None """ the maximum L2 norm threshold for gradient clipping. diff --git a/tianshou/policy/imitation/gail.py b/tianshou/policy/imitation/gail.py index a9161b48a..c8ca46af6 100644 --- a/tianshou/policy/imitation/gail.py +++ b/tianshou/policy/imitation/gail.py @@ -99,8 +99,16 @@ def __init__( normalization. :param recompute_advantage: whether to recompute advantage every update repeat according to https://arxiv.org/pdf/2006.05990.pdf Sec. 3.5. - :param vf_coef: weight for value loss. - :param ent_coef: weight for entropy loss. + :param vf_coef: coefficient that weights the value loss relative to the actor loss in + the overall loss function. + Higher values prioritize accurate value function estimation over policy improvement. + Controls the trade-off between policy optimization and value function fitting. + Typically set between 0.5 and 1.0 for most actor-critic implementations. + :param ent_coef: coefficient that weights the entropy bonus relative to the actor loss. + Controls the exploration-exploitation trade-off by encouraging policy entropy. + Higher values promote more exploration by encouraging a more uniform action distribution. + Lower values focus more on exploitation of the current policy's knowledge. + Typically set between 0.01 and 0.05 for most actor-critic implementations. :param max_grad_norm: the maximum L2 norm threshold for gradient clipping. When not None, gradients will be rescaled using to ensure their L2 norm does not exceed this value. This prevents exploding gradients and stabilizes training by diff --git a/tianshou/policy/modelfree/a2c.py b/tianshou/policy/modelfree/a2c.py index fcf789212..5cd09d7fd 100644 --- a/tianshou/policy/modelfree/a2c.py +++ b/tianshou/policy/modelfree/a2c.py @@ -172,8 +172,16 @@ def __init__( :param policy: the policy containing the actor network. :param critic: the critic network. (s -> V(s)) :param optim: the optimizer factory. - :param vf_coef: weight for value loss. 
- :param ent_coef: weight for entropy loss. + :param vf_coef: coefficient that weights the value loss relative to the actor loss in + the overall loss function. + Higher values prioritize accurate value function estimation over policy improvement. + Controls the trade-off between policy optimization and value function fitting. + Typically set between 0.5 and 1.0 for most actor-critic implementations. + :param ent_coef: coefficient that weights the entropy bonus relative to the actor loss. + Controls the exploration-exploitation trade-off by encouraging policy entropy. + Higher values promote more exploration by encouraging a more uniform action distribution. + Lower values focus more on exploitation of the current policy's knowledge. + Typically set between 0.01 and 0.05 for most actor-critic implementations. :param max_grad_norm: the maximum L2 norm threshold for gradient clipping. When not None, gradients will be rescaled using to ensure their L2 norm does not exceed this value. This prevents exploding gradients and stabilizes training by diff --git a/tianshou/policy/modelfree/fqf.py b/tianshou/policy/modelfree/fqf.py index 9ac4fd6cd..6c101b9d6 100644 --- a/tianshou/policy/modelfree/fqf.py +++ b/tianshou/policy/modelfree/fqf.py @@ -138,7 +138,11 @@ def __init__( increasing training variance by incorporating more environmental stochasticity. Typically set between 0.9 and 0.99 for most reinforcement learning tasks :param num_fractions: the number of fractions to use. - :param ent_coef: the coefficient for entropy loss. + :param ent_coef: coefficient that weights the entropy bonus relative to the actor loss. + Controls the exploration-exploitation trade-off by encouraging policy entropy. + Higher values promote more exploration by encouraging a more uniform action distribution. + Lower values focus more on exploitation of the current policy's knowledge. + Typically set between 0.01 and 0.05 for most actor-critic implementations. 
:param estimation_step: the number of future steps (> 0) to consider when computing temporal difference (TD) targets. Controls the balance between TD learning and Monte Carlo methods: higher values reduce bias (by relying less on potentially inaccurate value estimates) diff --git a/tianshou/policy/modelfree/ppo.py b/tianshou/policy/modelfree/ppo.py index 27738838b..f580b7485 100644 --- a/tianshou/policy/modelfree/ppo.py +++ b/tianshou/policy/modelfree/ppo.py @@ -74,8 +74,16 @@ def __init__( normalization. :param recompute_advantage: whether to recompute advantage every update repeat according to https://arxiv.org/pdf/2006.05990.pdf Sec. 3.5. - :param vf_coef: weight for value loss. - :param ent_coef: weight for entropy loss. + :param vf_coef: coefficient that weights the value loss relative to the actor loss in + the overall loss function. + Higher values prioritize accurate value function estimation over policy improvement. + Controls the trade-off between policy optimization and value function fitting. + Typically set between 0.5 and 1.0 for most actor-critic implementations. + :param ent_coef: coefficient that weights the entropy bonus relative to the actor loss. + Controls the exploration-exploitation trade-off by encouraging policy entropy. + Higher values promote more exploration by encouraging a more uniform action distribution. + Lower values focus more on exploitation of the current policy's knowledge. + Typically set between 0.01 and 0.05 for most actor-critic implementations. :param max_grad_norm: the maximum L2 norm threshold for gradient clipping. When not None, gradients will be rescaled using to ensure their L2 norm does not exceed this value. 
This prevents exploding gradients and stabilizes training by From e607b004a7507a212b7a6c67ab37fb787c7df110 Mon Sep 17 00:00:00 2001 From: Dominik Jain Date: Tue, 6 May 2025 13:08:22 +0200 Subject: [PATCH 127/230] v2: Improve descriptions of ICM parameters --- tianshou/policy/modelbased/icm.py | 60 +++++++++++++++++++++++++++---- 1 file changed, 54 insertions(+), 6 deletions(-) diff --git a/tianshou/policy/modelbased/icm.py b/tianshou/policy/modelbased/icm.py index 7bb7cd55b..887b59ac5 100644 --- a/tianshou/policy/modelbased/icm.py +++ b/tianshou/policy/modelbased/icm.py @@ -49,8 +49,24 @@ def __init__( """ :param model: the ICM model. :param optim: the optimizer factory. - :param lr_scale: the scaling factor for ICM learning. - :param forward_loss_weight: the weight for forward model loss. + :param lr_scale: a multiplier that effectively scales the learning rate for the ICM model updates. + Higher values increase the step size during optimization of the intrinsic curiosity module. + Lower values decrease the step size, leading to more gradual learning of the curiosity mechanism. + This parameter offers an alternative to directly adjusting the base learning rate in the optimizer. + :param reward_scale: a multiplier that controls the magnitude of intrinsic rewards (curiosity-driven + rewards generated by the agent itself) relative to extrinsic rewards (task-specific rewards provided + by the environment). + Scales the prediction error (curiosity signal) before adding it to the environment rewards. + Higher values increase the agent's motivation to explore novel states. + Lower values decrease the influence of curiosity relative to task-specific rewards. + Setting to zero effectively disables intrinsic motivation while still learning the ICM model. + :param forward_loss_weight: relative importance in [0, 1] of the forward model loss in relation to + the inverse model loss. + Controls the trade-off between state prediction and action prediction in the ICM algorithm. 
+ Higher values (> 0.5) prioritize learning to predict next states given current states and actions. + Lower values (< 0.5) prioritize learning to predict actions given current and next states. + The total loss combines both components: + (1-forward_loss_weight)*inverse_loss + forward_loss_weight*forward_loss. """ self.model = model self.optim = optim @@ -110,8 +126,24 @@ def __init__( :param wrapped_algorithm: the base algorithm to which we want to add the ICM. :param model: the ICM model. :param optim: the optimizer factory for the ICM model. - :param lr_scale: the scaling factor for ICM learning. - :param forward_loss_weight: the weight for forward model loss. + :param lr_scale: a multiplier that effectively scales the learning rate for the ICM model updates. + Higher values increase the step size during optimization of the intrinsic curiosity module. + Lower values decrease the step size, leading to more gradual learning of the curiosity mechanism. + This parameter offers an alternative to directly adjusting the base learning rate in the optimizer. + :param reward_scale: a multiplier that controls the magnitude of intrinsic rewards (curiosity-driven + rewards generated by the agent itself) relative to extrinsic rewards (task-specific rewards provided + by the environment). + Scales the prediction error (curiosity signal) before adding it to the environment rewards. + Higher values increase the agent's motivation to explore novel states. + Lower values decrease the influence of curiosity relative to task-specific rewards. + Setting to zero effectively disables intrinsic motivation while still learning the ICM model. + :param forward_loss_weight: relative importance in [0, 1] of the forward model loss in relation to + the inverse model loss. + Controls the trade-off between state prediction and action prediction in the ICM algorithm. + Higher values (> 0.5) prioritize learning to predict next states given current states and actions. 
+ Lower values (< 0.5) prioritize learning to predict actions given current and next states. + The total loss combines both components: + (1-forward_loss_weight)*inverse_loss + forward_loss_weight*forward_loss. """ OffPolicyWrapperAlgorithm.__init__( self, @@ -169,8 +201,24 @@ def __init__( :param wrapped_algorithm: the base algorithm to which we want to add the ICM. :param model: the ICM model. :param optim: the optimizer factory for the ICM model. - :param lr_scale: the scaling factor for ICM learning. - :param forward_loss_weight: the weight for forward model loss. + :param lr_scale: a multiplier that effectively scales the learning rate for the ICM model updates. + Higher values increase the step size during optimization of the intrinsic curiosity module. + Lower values decrease the step size, leading to more gradual learning of the curiosity mechanism. + This parameter offers an alternative to directly adjusting the base learning rate in the optimizer. + :param reward_scale: a multiplier that controls the magnitude of intrinsic rewards (curiosity-driven + rewards generated by the agent itself) relative to extrinsic rewards (task-specific rewards provided + by the environment). + Scales the prediction error (curiosity signal) before adding it to the environment rewards. + Higher values increase the agent's motivation to explore novel states. + Lower values decrease the influence of curiosity relative to task-specific rewards. + Setting to zero effectively disables intrinsic motivation while still learning the ICM model. + :param forward_loss_weight: relative importance in [0, 1] of the forward model loss in relation to + the inverse model loss. + Controls the trade-off between state prediction and action prediction in the ICM algorithm. + Higher values (> 0.5) prioritize learning to predict next states given current states and actions. + Lower values (< 0.5) prioritize learning to predict actions given current and next states. 
+ The total loss combines both components: + (1-forward_loss_weight)*inverse_loss + forward_loss_weight*forward_loss. """ OnPolicyWrapperAlgorithm.__init__( self, From 225012f8ee4d79c5fad1088ac462fad825913c88 Mon Sep 17 00:00:00 2001 From: Dominik Jain Date: Tue, 6 May 2025 13:14:15 +0200 Subject: [PATCH 128/230] v2: Complete removal of reward_normalization/return_scaling parameter from Q-learning algorithm (fixup of 3f8948b6aca653e1380abf4aba8afda2e9b736d0) --- tianshou/policy/imitation/discrete_cql.py | 5 ----- tianshou/policy/modelfree/bdqn.py | 4 ---- tianshou/policy/modelfree/c51.py | 4 ---- tianshou/policy/modelfree/dqn.py | 4 ---- tianshou/policy/modelfree/fqf.py | 5 ----- tianshou/policy/modelfree/iqn.py | 5 ----- tianshou/policy/modelfree/qrdqn.py | 4 ---- 7 files changed, 31 deletions(-) diff --git a/tianshou/policy/imitation/discrete_cql.py b/tianshou/policy/imitation/discrete_cql.py index 884310044..38c9e589c 100644 --- a/tianshou/policy/imitation/discrete_cql.py +++ b/tianshou/policy/imitation/discrete_cql.py @@ -33,7 +33,6 @@ def __init__( num_quantiles: int = 200, estimation_step: int = 1, target_update_freq: int = 0, - return_scaling: bool = False, ) -> None: """ :param policy: the policy @@ -57,9 +56,6 @@ def __init__( complete episode returns. :param target_update_freq: the target network update frequency (0 if you do not use the target network). - :param return_scaling: flag indicating whether to scale/standardise returns to Normal(0, 1) based - on running mean and standard deviation. - Support for this is currently suspended and therefore the flag should not be enabled. 
""" QRDQN.__init__( self, @@ -69,7 +65,6 @@ def __init__( num_quantiles=num_quantiles, estimation_step=estimation_step, target_update_freq=target_update_freq, - return_scaling=return_scaling, ) self.min_q_weight = min_q_weight diff --git a/tianshou/policy/modelfree/bdqn.py b/tianshou/policy/modelfree/bdqn.py index ef1655341..35606db67 100644 --- a/tianshou/policy/modelfree/bdqn.py +++ b/tianshou/policy/modelfree/bdqn.py @@ -109,7 +109,6 @@ def __init__( gamma: float = 0.99, estimation_step: int = 1, target_update_freq: int = 0, - return_scaling: bool = False, is_double: bool = True, ) -> None: """ @@ -131,9 +130,6 @@ def __init__( complete episode returns. :param target_update_freq: the target network update frequency (0 if you do not use the target network). - :param return_scaling: flag indicating whether to scale/standardise returns to Normal(0, 1) based - on running mean and standard deviation. - Support for this is currently suspended and therefore the flag should not be enabled. :param is_double: whether to use double DQN. """ assert ( diff --git a/tianshou/policy/modelfree/c51.py b/tianshou/policy/modelfree/c51.py index c7414cd50..e20cbdfb9 100644 --- a/tianshou/policy/modelfree/c51.py +++ b/tianshou/policy/modelfree/c51.py @@ -78,7 +78,6 @@ def __init__( gamma: float = 0.99, estimation_step: int = 1, target_update_freq: int = 0, - return_scaling: bool = False, ) -> None: """ :param policy: a policy following the rules (s -> action_values_BA) @@ -99,9 +98,6 @@ def __init__( complete episode returns. :param target_update_freq: the target network update frequency (0 if you do not use the target network). - :param return_scaling: flag indicating whether to scale/standardise returns to Normal(0, 1) based - on running mean and standard deviation. - Support for this is currently suspended and therefore the flag should not be enabled. 
""" super().__init__( policy=policy, diff --git a/tianshou/policy/modelfree/dqn.py b/tianshou/policy/modelfree/dqn.py index 41d77d970..962926c74 100644 --- a/tianshou/policy/modelfree/dqn.py +++ b/tianshou/policy/modelfree/dqn.py @@ -293,7 +293,6 @@ def __init__( gamma: float = 0.99, estimation_step: int = 1, target_update_freq: int = 0, - return_scaling: bool = False, is_double: bool = True, clip_loss_grad: bool = False, ) -> None: @@ -316,9 +315,6 @@ def __init__( complete episode returns. :param target_update_freq: the frequency with which to update the weights of the target network; 0 if a target network shall not be used. - :param return_scaling: flag indicating whether to scale/standardise returns to Normal(0, 1) based - on running mean and standard deviation. - Support for this is currently suspended and therefore the flag should not be enabled. :param is_double: use double dqn. :param clip_loss_grad: clip the gradient of the loss in accordance with nature14236; this amounts to using the Huber loss instead of diff --git a/tianshou/policy/modelfree/fqf.py b/tianshou/policy/modelfree/fqf.py index 6c101b9d6..af9acfbfa 100644 --- a/tianshou/policy/modelfree/fqf.py +++ b/tianshou/policy/modelfree/fqf.py @@ -123,7 +123,6 @@ def __init__( ent_coef: float = 0.0, estimation_step: int = 1, target_update_freq: int = 0, - return_scaling: bool = False, ) -> None: """ :param policy: the policy @@ -152,9 +151,6 @@ def __init__( complete episode returns. :param target_update_freq: the target network update frequency (0 if you do not use the target network). - :param return_scaling: flag indicating whether to scale/standardise returns to Normal(0, 1) based - on running mean and standard deviation. - Support for this is currently suspended and therefore the flag should not be enabled. 
""" super().__init__( policy=policy, @@ -163,7 +159,6 @@ def __init__( num_quantiles=num_fractions, estimation_step=estimation_step, target_update_freq=target_update_freq, - return_scaling=return_scaling, ) self.ent_coef = ent_coef self.fraction_optim = self._create_optimizer(self.policy.fraction_model, fraction_optim) diff --git a/tianshou/policy/modelfree/iqn.py b/tianshou/policy/modelfree/iqn.py index 54363b949..d7c0704ec 100644 --- a/tianshou/policy/modelfree/iqn.py +++ b/tianshou/policy/modelfree/iqn.py @@ -115,7 +115,6 @@ def __init__( num_quantiles: int = 200, estimation_step: int = 1, target_update_freq: int = 0, - return_scaling: bool = False, ) -> None: """ :param policy: the policy @@ -138,9 +137,6 @@ def __init__( complete episode returns. :param target_update_freq: the target network update frequency (0 if you do not use the target network). - :param return_scaling: flag indicating whether to scale/standardise returns to Normal(0, 1) based - on running mean and standard deviation. - Support for this is currently suspended and therefore the flag should not be enabled. """ super().__init__( policy=policy, @@ -149,7 +145,6 @@ def __init__( num_quantiles=num_quantiles, estimation_step=estimation_step, target_update_freq=target_update_freq, - return_scaling=return_scaling, ) def _update_with_batch( diff --git a/tianshou/policy/modelfree/qrdqn.py b/tianshou/policy/modelfree/qrdqn.py index 56be073d3..819bbffdc 100644 --- a/tianshou/policy/modelfree/qrdqn.py +++ b/tianshou/policy/modelfree/qrdqn.py @@ -38,7 +38,6 @@ def __init__( num_quantiles: int = 200, estimation_step: int = 1, target_update_freq: int = 0, - return_scaling: bool = False, ) -> None: """ :param policy: the policy @@ -61,9 +60,6 @@ def __init__( complete episode returns. :param target_update_freq: the target network update frequency (0 if you do not use the target network). 
- :param return_scaling: flag indicating whether to scale/standardise returns to Normal(0, 1) based - on running mean and standard deviation. - Support for this is currently suspended and therefore the flag should not be enabled. """ assert num_quantiles > 1, f"num_quantiles should be greater than 1 but got: {num_quantiles}" super().__init__( From 05e0d0dbff6acc6eb1c85c8c02c9e183a0a0b086 Mon Sep 17 00:00:00 2001 From: Dominik Jain Date: Tue, 6 May 2025 13:51:34 +0200 Subject: [PATCH 129/230] v2: Improve description of parameter 'is_double' --- tianshou/highlevel/params/policy_params.py | 10 +++++++++- tianshou/policy/modelfree/bdqn.py | 9 ++++++--- tianshou/policy/modelfree/dqn.py | 10 +++++++++- 3 files changed, 24 insertions(+), 5 deletions(-) diff --git a/tianshou/highlevel/params/policy_params.py b/tianshou/highlevel/params/policy_params.py index ec4478ccd..11432f098 100644 --- a/tianshou/highlevel/params/policy_params.py +++ b/tianshou/highlevel/params/policy_params.py @@ -677,7 +677,15 @@ def _get_param_transformers(self) -> list[ParamTransformer]: @dataclass(kw_only=True) class DQNParams(QLearningOffPolicyParams): is_double: bool = True - """whether to use double Q learning""" + """ + flag indicating whether to use the Double DQN algorithm for target value computation. + If True, the algorithm uses the online network to select actions and the target network to + evaluate their Q-values. This approach helps reduce the overestimation bias in Q-learning + by decoupling action selection from action evaluation. + If False, the algorithm follows the vanilla DQN method that directly takes the maximum Q-value + from the target network. + Note: Double Q-learning will only be effective when a target network is used (target_update_freq > 0). 
+ """ clip_loss_grad: bool = False """whether to clip the gradient of the loss in accordance with nature14236; this amounts to using the Huber loss instead of the MSE loss.""" diff --git a/tianshou/policy/modelfree/bdqn.py b/tianshou/policy/modelfree/bdqn.py index 35606db67..069c7f2d5 100644 --- a/tianshou/policy/modelfree/bdqn.py +++ b/tianshou/policy/modelfree/bdqn.py @@ -128,9 +128,12 @@ def __init__( the averaging effect). A value of 1 corresponds to standard TD learning with immediate bootstrapping, while very large values approach Monte Carlo-like estimation that uses complete episode returns. - :param target_update_freq: the target network update frequency (0 if - you do not use the target network). - :param is_double: whether to use double DQN. + :param target_update_freq: the target network update frequency (0 if a target network shall not be used). + :param is_double: flag indicating whether to use Double Q-learning for target value calculation. + If True, the algorithm uses the online network to select actions and the target network to evaluate their Q-values. + This decoupling helps reduce the overestimation bias that standard Q-learning is prone to. + If False, the algorithm selects actions by directly taking the maximum Q-value from the target network. + Note: This parameter is most effective when used with a target network (target_update_freq > 0). """ assert ( estimation_step == 1 diff --git a/tianshou/policy/modelfree/dqn.py b/tianshou/policy/modelfree/dqn.py index 962926c74..21ac176fa 100644 --- a/tianshou/policy/modelfree/dqn.py +++ b/tianshou/policy/modelfree/dqn.py @@ -1,3 +1,4 @@ +import logging from abc import ABC, abstractmethod from typing import Any, Generic, TypeVar, cast @@ -31,6 +32,7 @@ mark_used(ActBatchProtocol) TModel = TypeVar("TModel", bound=torch.nn.Module | Net) +log = logging.getLogger(__name__) class DQNPolicy(Policy, Generic[TModel]): @@ -315,7 +317,13 @@ def __init__( complete episode returns. 
:param target_update_freq: the frequency with which to update the weights of the target network; 0 if a target network shall not be used. - :param is_double: use double dqn. + :param is_double: flag indicating whether to use the Double DQN algorithm for target value computation. + If True, the algorithm uses the online network to select actions and the target network to + evaluate their Q-values. This approach helps reduce the overestimation bias in Q-learning + by decoupling action selection from action evaluation. + If False, the algorithm follows the vanilla DQN method that directly takes the maximum Q-value + from the target network. + Note: Double Q-learning will only be effective when a target network is used (target_update_freq > 0). :param clip_loss_grad: clip the gradient of the loss in accordance with nature14236; this amounts to using the Huber loss instead of the MSE loss. From 02b3f01697bdca769a202c03c228546864aef607 Mon Sep 17 00:00:00 2001 From: Dominik Jain Date: Tue, 6 May 2025 14:03:32 +0200 Subject: [PATCH 130/230] v2: DiscreteBCQ: Remove parameters 'is_double' and 'clip_loss_grad' (actually unused, were only passed on to former base class) --- CHANGELOG.md | 4 ++++ tianshou/policy/imitation/discrete_bcq.py | 8 -------- 2 files changed, 4 insertions(+), 8 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 928689012..cf49e3619 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -128,6 +128,10 @@ * Inherit directly from `OfflineAlgorithm` instead of `SAC` (off-policy). * Remove parameter `estimation_step`, which was not actually used (it was only passed it on to its superclass). + * `DiscreteBCQ`: + * Inherit directly from `OfflineAlgorithm` instead of `DQN` + * Remove unused parameters `clip_loss_grad` and `is_double`, which were only passed on to + former the base class but actually unused. * `DiscreteCQL`: Remove unused parameters `clip_loss_grad` and `is_double`, which were only passed on to base class `QRDQN` (and unused by it). 
* `DiscreteCRR`: Inherit directly from `OfflineAlgorithm` instead of `Reinforce` (on-policy) diff --git a/tianshou/policy/imitation/discrete_bcq.py b/tianshou/policy/imitation/discrete_bcq.py index 7022c00f2..81fedfbe5 100644 --- a/tianshou/policy/imitation/discrete_bcq.py +++ b/tianshou/policy/imitation/discrete_bcq.py @@ -118,8 +118,6 @@ def __init__( estimation_step: int = 1, target_update_freq: int = 8000, imitation_logits_penalty: float = 1e-2, - is_double: bool = True, - clip_loss_grad: bool = False, ) -> None: """ :param policy: the policy @@ -150,10 +148,6 @@ def __init__( complete episode returns. :param target_update_freq: the target network update frequency (0 if you do not use the target network). - :param is_double: use double dqn. - :param clip_loss_grad: clip the gradient of the loss in accordance - with nature14236; this amounts to using the Huber loss instead of - the MSE loss. """ super().__init__( policy=policy, @@ -171,8 +165,6 @@ def __init__( self._iter = 0 if self._target: self.model_old = self._add_lagged_network(self.policy.model) - self.is_double = is_double - self.clip_loss_grad = clip_loss_grad self._weight_reg = imitation_logits_penalty def _preprocess_batch( From 3ed750cd1d7da0462943fcb0b30b54ca701ead17 Mon Sep 17 00:00:00 2001 From: Dominik Jain Date: Tue, 6 May 2025 14:14:56 +0200 Subject: [PATCH 131/230] v2: Improve description of parameter 'clip_loss_grad' --- tianshou/highlevel/params/policy_params.py | 9 +++++++-- tianshou/policy/modelfree/dqn.py | 8 +++++--- 2 files changed, 12 insertions(+), 5 deletions(-) diff --git a/tianshou/highlevel/params/policy_params.py b/tianshou/highlevel/params/policy_params.py index 11432f098..5353b5815 100644 --- a/tianshou/highlevel/params/policy_params.py +++ b/tianshou/highlevel/params/policy_params.py @@ -687,8 +687,13 @@ class DQNParams(QLearningOffPolicyParams): Note: Double Q-learning will only be effective when a target network is used (target_update_freq > 0). 
""" clip_loss_grad: bool = False - """whether to clip the gradient of the loss in accordance with nature14236; this amounts to using the Huber - loss instead of the MSE loss.""" + """ + flag indicating whether to use the Huber loss instead of the MSE loss for the TD error. + If True, uses the Huber loss as described in the Nature DQN paper (nature14236), which limits the influence + of outliers. Unlike the MSE loss where the gradients grow linearly with the error magnitude, the Huber + loss causes the gradients to plateau at a constant value for large errors, providing more stable training. + If False, uses the standard MSE loss where the gradient magnitude continues to scale with the error size. + """ def _get_param_transformers(self) -> list[ParamTransformer]: return super()._get_param_transformers() diff --git a/tianshou/policy/modelfree/dqn.py b/tianshou/policy/modelfree/dqn.py index 21ac176fa..c8a5a888a 100644 --- a/tianshou/policy/modelfree/dqn.py +++ b/tianshou/policy/modelfree/dqn.py @@ -324,9 +324,11 @@ def __init__( If False, the algorithm follows the vanilla DQN method that directly takes the maximum Q-value from the target network. Note: Double Q-learning will only be effective when a target network is used (target_update_freq > 0). - :param clip_loss_grad: clip the gradient of the loss in accordance - with nature14236; this amounts to using the Huber loss instead of - the MSE loss. + :param clip_loss_grad: flag indicating whether to use the Huber loss instead of the MSE loss for the TD error. + If True, uses the Huber loss as described in the Nature DQN paper (nature14236), which limits the influence + of outliers. Unlike the MSE loss where the gradients grow linearly with the error magnitude, the Huber + loss causes the gradients to plateau at a constant value for large errors, providing more stable training. + If False, uses the standard MSE loss where the gradient magnitude continues to scale with the error size. 
""" super().__init__( policy=policy, From 1cf3d47243fbb4631064d14d782eb0b576551b3d Mon Sep 17 00:00:00 2001 From: Dominik Jain Date: Tue, 6 May 2025 14:32:58 +0200 Subject: [PATCH 132/230] v2: Update DDPG parameter docstrings [addendum] --- tianshou/policy/modelfree/ddpg.py | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/tianshou/policy/modelfree/ddpg.py b/tianshou/policy/modelfree/ddpg.py index a67212c1c..17f8c5606 100644 --- a/tianshou/policy/modelfree/ddpg.py +++ b/tianshou/policy/modelfree/ddpg.py @@ -125,7 +125,7 @@ def __init__( set to None for discrete action spaces. This is useful when solving "hard exploration" problems. "default" is equivalent to GaussianNoise(sigma=0.1). - :param action_space: Env's action space. + :param action_space: the environment's action space. :param tau: the soft update coefficient for target networks, controlling the rate at which target networks track the learned networks. When the parameters of the target network are updated with the current (source) network's @@ -135,8 +135,17 @@ def __init__( stability. Typically set to a small value (0.001 to 0.01) for most reinforcement learning tasks. :param observation_space: the environment's observation space. - :param action_scaling: if True, scale the action from [-1, 1] to the range - of action_space. Only used if the action_space is continuous. + :param action_scaling: flag indicating whether, for continuous action spaces, actions + should be scaled from the standard neural network output range [-1, 1] to the + environment's action space range [action_space.low, action_space.high]. + This applies to continuous action spaces only (gym.spaces.Box) and has no effect + for discrete spaces. + When enabled, policy outputs are expected to be in the normalized range [-1, 1] + (after bounding), and are then linearly transformed to the actual required range. 
+ This improves neural network training stability, allows the same algorithm to work + across environments with different action ranges, and standardizes exploration + strategies. + Should be disabled if the actor model already produces outputs in the correct range. :param action_bound_method: method to bound action to range [-1, 1]. """ if action_scaling and not np.isclose(actor.max_action, 1.0): From eea6ba721472bf76016f00976db146a3ee6a47b0 Mon Sep 17 00:00:00 2001 From: Dominik Jain Date: Tue, 6 May 2025 18:15:54 +0200 Subject: [PATCH 133/230] v2: Improve description of parameter 'policy_noise' --- tianshou/highlevel/params/policy_params.py | 9 ++++++++- tianshou/policy/imitation/td3_bc.py | 7 ++++++- tianshou/policy/modelfree/td3.py | 7 ++++++- 3 files changed, 20 insertions(+), 3 deletions(-) diff --git a/tianshou/highlevel/params/policy_params.py b/tianshou/highlevel/params/policy_params.py index 5353b5815..55e3a88de 100644 --- a/tianshou/highlevel/params/policy_params.py +++ b/tianshou/highlevel/params/policy_params.py @@ -765,7 +765,14 @@ class TD3Params( ParamsMixinTau, ): policy_noise: float | FloatEnvValueFactory = 0.2 - """the scale of the the noise used in updating policy network""" + """ + scaling factor for the Gaussian noise added to target policy actions. + This parameter implements target policy smoothing, a regularization technique described in the TD3 paper. + The noise is sampled from a normal distribution and multiplied by this value before being added to actions. + Higher values increase exploration in the target policy, helping to address function approximation error. + The added noise is optionally clipped to a range determined by the noise_clip parameter. + Typically set between 0.1 and 0.5 relative to the action scale of the environment. 
+ """ noise_clip: float | FloatEnvValueFactory = 0.5 """determines the clipping range of the noise used in updating the policy network as [-noise_clip, noise_clip]""" update_actor_freq: int = 2 diff --git a/tianshou/policy/imitation/td3_bc.py b/tianshou/policy/imitation/td3_bc.py index 68efd1f75..755dd58a9 100644 --- a/tianshou/policy/imitation/td3_bc.py +++ b/tianshou/policy/imitation/td3_bc.py @@ -59,7 +59,12 @@ def __init__( :param exploration_noise: add noise to action for exploration. This is useful when solving "hard exploration" problems. "default" is equivalent to GaussianNoise(sigma=0.1). - :param policy_noise: the noise used in updating policy network. + :param policy_noise: scaling factor for the Gaussian noise added to target policy actions. + This parameter implements target policy smoothing, a regularization technique described in the TD3 paper. + The noise is sampled from a normal distribution and multiplied by this value before being added to actions. + Higher values increase exploration in the target policy, helping to address function approximation error. + The added noise is optionally clipped to a range determined by the noise_clip parameter. + Typically set between 0.1 and 0.5 relative to the action scale of the environment. :param update_actor_freq: the update frequency of actor network. :param noise_clip: the clipping range used in updating policy network. :param alpha: the value of alpha, which controls the weight for TD3 learning diff --git a/tianshou/policy/modelfree/td3.py b/tianshou/policy/modelfree/td3.py index e28913713..01534dcd2 100644 --- a/tianshou/policy/modelfree/td3.py +++ b/tianshou/policy/modelfree/td3.py @@ -147,7 +147,12 @@ def __init__( potentially improving performance in tasks where delayed rewards are important but increasing training variance by incorporating more environmental stochasticity. 
Typically set between 0.9 and 0.99 for most reinforcement learning tasks - :param policy_noise: the noise used in updating policy network. + :param policy_noise: scaling factor for the Gaussian noise added to target policy actions. + This parameter implements target policy smoothing, a regularization technique described in the TD3 paper. + The noise is sampled from a normal distribution and multiplied by this value before being added to actions. + Higher values increase exploration in the target policy, helping to address function approximation error. + The added noise is optionally clipped to a range determined by the noise_clip parameter. + Typically set between 0.1 and 0.5 relative to the action scale of the environment. :param update_actor_freq: the update frequency of actor network. :param noise_clip: the clipping range used in updating policy network. """ From 650ae77b2b9696b0b2f544f5617bdf55267b0564 Mon Sep 17 00:00:00 2001 From: Dominik Jain Date: Tue, 6 May 2025 18:20:03 +0200 Subject: [PATCH 134/230] v2: Improve description of parameter 'update_actor_freq' --- tianshou/highlevel/params/policy_params.py | 11 ++++++++++- tianshou/policy/imitation/td3_bc.py | 9 ++++++++- tianshou/policy/modelfree/td3.py | 9 ++++++++- 3 files changed, 26 insertions(+), 3 deletions(-) diff --git a/tianshou/highlevel/params/policy_params.py b/tianshou/highlevel/params/policy_params.py index 55e3a88de..f501d8196 100644 --- a/tianshou/highlevel/params/policy_params.py +++ b/tianshou/highlevel/params/policy_params.py @@ -776,7 +776,16 @@ class TD3Params( noise_clip: float | FloatEnvValueFactory = 0.5 """determines the clipping range of the noise used in updating the policy network as [-noise_clip, noise_clip]""" update_actor_freq: int = 2 - """the update frequency of actor network""" + """ + the frequency of actor network updates relative to critic network updates + (the actor network is only updated once for every `update_actor_freq` critic updates). 
+ This implements the "delayed" policy updates from the TD3 algorithm, where the actor is + updated less frequently than the critics. + Higher values (e.g., 2-5) help stabilize training by allowing the critic to become more + accurate before updating the policy. + The default value of 2 follows the original TD3 paper's recommendation of updating the + policy at half the rate of the Q-functions. + """ def _get_param_transformers(self) -> list[ParamTransformer]: transformers = super()._get_param_transformers() diff --git a/tianshou/policy/imitation/td3_bc.py b/tianshou/policy/imitation/td3_bc.py index 755dd58a9..78ce6bd6b 100644 --- a/tianshou/policy/imitation/td3_bc.py +++ b/tianshou/policy/imitation/td3_bc.py @@ -65,7 +65,14 @@ def __init__( Higher values increase exploration in the target policy, helping to address function approximation error. The added noise is optionally clipped to a range determined by the noise_clip parameter. Typically set between 0.1 and 0.5 relative to the action scale of the environment. - :param update_actor_freq: the update frequency of actor network. + :param update_actor_freq: the frequency of actor network updates relative to critic network updates + (the actor network is only updated once for every `update_actor_freq` critic updates). + This implements the "delayed" policy updates from the TD3 algorithm, where the actor is + updated less frequently than the critics. + Higher values (e.g., 2-5) help stabilize training by allowing the critic to become more + accurate before updating the policy. + The default value of 2 follows the original TD3 paper's recommendation of updating the + policy at half the rate of the Q-functions. :param noise_clip: the clipping range used in updating policy network. :param alpha: the value of alpha, which controls the weight for TD3 learning relative to behavior cloning. 
diff --git a/tianshou/policy/modelfree/td3.py b/tianshou/policy/modelfree/td3.py index 01534dcd2..21d8f3941 100644 --- a/tianshou/policy/modelfree/td3.py +++ b/tianshou/policy/modelfree/td3.py @@ -153,7 +153,14 @@ def __init__( Higher values increase exploration in the target policy, helping to address function approximation error. The added noise is optionally clipped to a range determined by the noise_clip parameter. Typically set between 0.1 and 0.5 relative to the action scale of the environment. - :param update_actor_freq: the update frequency of actor network. + :param update_actor_freq: the frequency of actor network updates relative to critic network updates + (the actor network is only updated once for every `update_actor_freq` critic updates). + This implements the "delayed" policy updates from the TD3 algorithm, where the actor is + updated less frequently than the critics. + Higher values (e.g., 2-5) help stabilize training by allowing the critic to become more + accurate before updating the policy. + The default value of 2 follows the original TD3 paper's recommendation of updating the + policy at half the rate of the Q-functions. :param noise_clip: the clipping range used in updating policy network. 
""" super().__init__( From 11b2320902a7c110e292684e82466c5aa98ea9d7 Mon Sep 17 00:00:00 2001 From: Dominik Jain Date: Tue, 6 May 2025 18:26:46 +0200 Subject: [PATCH 135/230] v2: Improve description of parameter 'noise_clip' --- tianshou/highlevel/params/policy_params.py | 10 +++++++++- tianshou/policy/imitation/td3_bc.py | 9 +++++++-- tianshou/policy/modelfree/td3.py | 8 +++++++- 3 files changed, 23 insertions(+), 4 deletions(-) diff --git a/tianshou/highlevel/params/policy_params.py b/tianshou/highlevel/params/policy_params.py index f501d8196..1b252ff01 100644 --- a/tianshou/highlevel/params/policy_params.py +++ b/tianshou/highlevel/params/policy_params.py @@ -774,7 +774,15 @@ class TD3Params( Typically set between 0.1 and 0.5 relative to the action scale of the environment. """ noise_clip: float | FloatEnvValueFactory = 0.5 - """determines the clipping range of the noise used in updating the policy network as [-noise_clip, noise_clip]""" + """ + defines the maximum absolute value of the noise added to target policy actions, i.e. noise values + are clipped to the range [-noise_clip, noise_clip] (after generating and scaling the noise + via `policy_noise`). + This parameter implements bounded target policy smoothing as described in the TD3 paper. + It prevents extreme noise values from causing unrealistic target values during training. + Setting it 0.0 (or a negative value) disables clipping entirely. + It is typically set to about twice the `policy_noise` value (e.g. 0.5 when `policy_noise` is 0.2). 
+ """ update_actor_freq: int = 2 """ the frequency of actor network updates relative to critic network updates diff --git a/tianshou/policy/imitation/td3_bc.py b/tianshou/policy/imitation/td3_bc.py index 78ce6bd6b..d4d8c5fb1 100644 --- a/tianshou/policy/imitation/td3_bc.py +++ b/tianshou/policy/imitation/td3_bc.py @@ -28,7 +28,6 @@ def __init__( policy_noise: float = 0.2, update_actor_freq: int = 2, noise_clip: float = 0.5, - # TODO: same name as alpha in SAC and REDQ, which also inherit from DDPGPolicy. Rename? alpha: float = 2.5, estimation_step: int = 1, ) -> None: @@ -73,7 +72,13 @@ def __init__( accurate before updating the policy. The default value of 2 follows the original TD3 paper's recommendation of updating the policy at half the rate of the Q-functions. - :param noise_clip: the clipping range used in updating policy network. + :param noise_clip: defines the maximum absolute value of the noise added to target policy actions, i.e. noise values + are clipped to the range [-noise_clip, noise_clip] (after generating and scaling the noise + via `policy_noise`). + This parameter implements bounded target policy smoothing as described in the TD3 paper. + It prevents extreme noise values from causing unrealistic target values during training. + Setting it 0.0 (or a negative value) disables clipping entirely. + It is typically set to about twice the `policy_noise` value (e.g. 0.5 when `policy_noise` is 0.2). :param alpha: the value of alpha, which controls the weight for TD3 learning relative to behavior cloning. """ diff --git a/tianshou/policy/modelfree/td3.py b/tianshou/policy/modelfree/td3.py index 21d8f3941..f0f8b7080 100644 --- a/tianshou/policy/modelfree/td3.py +++ b/tianshou/policy/modelfree/td3.py @@ -161,7 +161,13 @@ def __init__( accurate before updating the policy. The default value of 2 follows the original TD3 paper's recommendation of updating the policy at half the rate of the Q-functions. 
- :param noise_clip: the clipping range used in updating policy network. + :param noise_clip: defines the maximum absolute value of the noise added to target policy actions, i.e. noise values + are clipped to the range [-noise_clip, noise_clip] (after generating and scaling the noise + via `policy_noise`). + This parameter implements bounded target policy smoothing as described in the TD3 paper. + It prevents extreme noise values from causing unrealistic target values during training. + Setting it 0.0 (or a negative value) disables clipping entirely. + It is typically set to about twice the `policy_noise` value (e.g. 0.5 when `policy_noise` is 0.2). """ super().__init__( policy=policy, From aff12e64ee74feb26ed3852e42b9c4407e5e81ad Mon Sep 17 00:00:00 2001 From: Dominik Jain Date: Wed, 7 May 2025 17:22:40 +0200 Subject: [PATCH 136/230] v2: Improve descriptions of REDQ parameters ensemble_size, subset_size, actor_delay, target_mode --- tianshou/highlevel/params/policy_params.py | 40 ++++++++++++++++++++-- tianshou/policy/modelfree/redq.py | 35 +++++++++++++++++-- 2 files changed, 69 insertions(+), 6 deletions(-) diff --git a/tianshou/highlevel/params/policy_params.py b/tianshou/highlevel/params/policy_params.py index 1b252ff01..499b018de 100644 --- a/tianshou/highlevel/params/policy_params.py +++ b/tianshou/highlevel/params/policy_params.py @@ -741,12 +741,46 @@ def _get_param_transformers(self) -> list[ParamTransformer]: @dataclass(kw_only=True) class REDQParams(DDPGParams, ParamsMixinDeterministicEval, ParamsMixinAlpha): ensemble_size: int = 10 - """the number of sub-networks in the critic ensemble""" + """ + the total number of critic networks in the ensemble. + This parameter implements the randomized ensemble approach described in REDQ. + The algorithm maintains `ensemble_size` different critic networks that all share the same architecture. + During target value computation, a random subset of these networks (determined by `subset_size`) is used. 
+ Larger values increase the diversity of the ensemble but require more memory and computation. + The original paper recommends a value of 10 for most tasks, balancing performance and computational efficiency. + """ subset_size: int = 2 - """the number of networks in the subset""" + """ + the number of critic networks randomly selected from the ensemble for computing target Q-values. + During each update, the algorithm samples `subset_size` networks from the ensemble of + `ensemble_size` networks without replacement. + The target Q-value is then calculated as either the minimum or mean (based on target_mode) + of the predictions from this subset. + Smaller values increase randomization and sample efficiency but may introduce more variance. + Larger values provide more stable estimates but reduce the benefits of randomization. + The REDQ paper recommends a value of 2 for optimal sample efficiency. + Must satisfy 0 < subset_size <= ensemble_size. + """ actor_delay: int = 20 - """the number of critic updates before an actor update""" + """ + the number of critic updates performed before each actor update. + The actor network is only updated once for every actor_delay critic updates, implementing + a delayed policy update strategy similar to TD3. + Larger values stabilize training by allowing critics to become more accurate before policy updates. + Smaller values allow the policy to adapt more quickly but may lead to less stable learning. + The REDQ paper recommends a value of 20 for most tasks. + """ target_mode: Literal["mean", "min"] = "min" + """ + the method used to aggregate Q-values from the subset of critic networks. + Can be either "min" or "mean". + If "min", uses the minimum Q-value across the selected subset of critics for each state-action pair. + If "mean", uses the average Q-value across the selected subset of critics. + Using "min" helps prevent overestimation bias but may lead to more conservative value estimates. 
+ Using "mean" provides more optimistic value estimates but may suffer from overestimation bias. + Default is "min" following the conservative value estimation approach common in recent Q-learning + algorithms. + """ def _get_param_transformers(self) -> list[ParamTransformer]: transformers = super()._get_param_transformers() diff --git a/tianshou/policy/modelfree/redq.py b/tianshou/policy/modelfree/redq.py index ba76a334a..6cd209669 100644 --- a/tianshou/policy/modelfree/redq.py +++ b/tianshou/policy/modelfree/redq.py @@ -157,8 +157,24 @@ def __init__( :param policy_optim: the optimizer factory for the policy's model. :param critic: the critic network. (s, a -> Q(s, a)) :param critic_optim: the optimizer factory for the critic network. - :param ensemble_size: the number of sub-networks in the critic ensemble. - :param subset_size: the number of networks in the subset. + :param ensemble_size: the total number of critic networks in the ensemble. + This parameter implements the randomized ensemble approach described in REDQ. + The algorithm maintains `ensemble_size` different critic networks that all share the same + architecture. During target value computation, a random subset of these networks (determined + by `subset_size`) is used. + Larger values increase the diversity of the ensemble but require more memory and computation. + The original paper recommends a value of 10 for most tasks, balancing performance and + computational efficiency. + :param subset_size: the number of critic networks randomly selected from the ensemble for + computing target Q-values. + During each update, the algorithm samples `subset_size` networks from the ensemble of + `ensemble_size` networks without replacement. + The target Q-value is then calculated as either the minimum or mean (based on `target_mode`) + of the predictions from this subset. + Smaller values increase randomization and sample efficiency but may introduce more variance. 
+ Larger values provide more stable estimates but reduce the benefits of randomization. + The REDQ paper recommends a value of 2 for optimal sample efficiency. + Must satisfy 0 < subset_size <= ensemble_size. :param tau: the soft update coefficient for target networks, controlling the rate at which target networks track the learned networks. When the parameters of the target network are updated with the current (source) network's @@ -192,7 +208,20 @@ def __init__( the averaging effect). A value of 1 corresponds to standard TD learning with immediate bootstrapping, while very large values approach Monte Carlo-like estimation that uses complete episode returns. - :param actor_delay: Number of critic updates before an actor update. + :param actor_delay: the number of critic updates performed before each actor update. + The actor network is only updated once for every actor_delay critic updates, implementing + a delayed policy update strategy similar to TD3. + Larger values stabilize training by allowing critics to become more accurate before policy updates. + Smaller values allow the policy to adapt more quickly but may lead to less stable learning. + The REDQ paper recommends a value of 20 for most tasks. + :param target_mode: the method used to aggregate Q-values from the subset of critic networks. + Can be either "min" or "mean". + If "min", uses the minimum Q-value across the selected subset of critics for each state-action pair. + If "mean", uses the average Q-value across the selected subset of critics. + Using "min" helps prevent overestimation bias but may lead to more conservative value estimates. + Using "mean" provides more optimistic value estimates but may suffer from overestimation bias. + Default is "min" following the conservative value estimation approach common in recent Q-learning + algorithms. 
""" if target_mode not in ("min", "mean"): raise ValueError(f"Unsupported target_mode: {target_mode}") From ab49383fb01b2ef7502a161e4bdc567e339d02be Mon Sep 17 00:00:00 2001 From: Dominik Jain Date: Wed, 7 May 2025 17:34:15 +0200 Subject: [PATCH 137/230] v2: Improve descriptions of CQL parameters Renamed parameter 'clip_grad' to 'max_grad_norm' for consistency --- CHANGELOG.md | 1 + tianshou/policy/imitation/cql.py | 93 +++++++++++++++++++++++++------- 2 files changed, 76 insertions(+), 18 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index cf49e3619..9ce8c432e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -89,6 +89,7 @@ * `return_standardization` in `Reinforce` and `DiscreteCRR` (as it applies standardization of returns) * `return_scaling` in actor-critic on-policy algorithms (A2C, PPO, GAIL, NPG, TRPO) * removed from Q-learning algorithms, where it was actually unsupported (DQN, C561, etc.) + * `clip_grad` -> `max_grad_norm` (for consistency) * Internal design improvements: * Introduced an abstraction for the alpha parameter (coefficient of the entropy term) in `SAC`, `DiscreteSAC` and other algorithms. diff --git a/tianshou/policy/imitation/cql.py b/tianshou/policy/imitation/cql.py index 7be0e3c32..897c85984 100644 --- a/tianshou/policy/imitation/cql.py +++ b/tianshou/policy/imitation/cql.py @@ -54,7 +54,7 @@ def __init__( num_repeat_actions: int = 10, alpha_min: float = 0.0, alpha_max: float = 1e6, - clip_grad: float = 1.0, + max_grad_norm: float = 1.0, calibrated: bool = True, ) -> None: """ @@ -68,8 +68,17 @@ def __init__( If None, use the same network as critic (via deepcopy). :param critic2_optim: the optimizer factory for the second critic network. If None, clone the first critic's optimizer factory. - :param cql_alpha_lr: The learning rate of cql_log_alpha. - :param cql_weight: + :param cql_alpha_lr: the learning rate for the Lagrange multiplier optimization. 
+ Controls how quickly the CQL regularization coefficient (alpha) adapts during training. + Higher values allow faster adaptation but may cause instability in the training process. + Lower values provide more stable but slower adaptation of the regularization strength. + Only relevant when with_lagrange=True. + :param cql_weight: the coefficient that scales the conservative regularization term in the Q-function loss. + Controls the strength of the conservative Q-learning component relative to standard TD learning. + Higher values enforce more conservative value estimates by penalizing overestimation more strongly. + Lower values allow the algorithm to behave more like standard Q-learning. + Increasing this weight typically improves performance in purely offline settings where + overestimation bias can lead to poor policy extraction. :param tau: the soft update coefficient for target networks, controlling the rate at which target networks track the learned networks. When the parameters of the target network are updated with the current (source) network's @@ -87,19 +96,67 @@ def __init__( Typically set between 0.9 and 0.99 for most reinforcement learning tasks :param alpha: the entropy regularization coefficient alpha or an object which can be used to automatically tune it (e.g. an instance of `AutoAlpha`). - :param temperature: - :param with_lagrange: Whether to use Lagrange. - TODO: extend documentation - what does this mean? - :param lagrange_threshold: The value of tau in CQL(Lagrange). - :param min_action: The minimum value of each dimension of action. - :param max_action: The maximum value of each dimension of action. - :param num_repeat_actions: The number of times the action is repeated when calculating log-sum-exp. - :param alpha_min: Lower bound for clipping cql_alpha. - :param alpha_max: Upper bound for clipping cql_alpha. - :param clip_grad: Clip_grad for updating critic network. 
- :param calibrated: calibrate Q-values as in CalQL paper `arXiv:2303.05479`. - Useful for offline pre-training followed by online training, - and also was observed to achieve better results than vanilla cql. + :param temperature: the temperature parameter used in the LogSumExp calculation of the CQL loss. + Controls the sharpness of the softmax distribution when computing the expected Q-values. + Lower values make the LogSumExp operation more selective, focusing on the highest Q-values. + Higher values make the operation closer to an average, giving more weight to all Q-values. + The temperature affects how conservatively the algorithm penalizes out-of-distribution actions. + :param with_lagrange: a flag indicating whether to automatically tune the CQL regularization strength. + If True, uses Lagrangian dual gradient descent to dynamically adjust the CQL alpha parameter. + This formulation maintains the CQL regularization loss near the lagrange_threshold value. + Adaptive tuning helps balance conservative learning against excessive pessimism. + If False, the conservative loss is scaled by a fixed cql_weight throughout training. + The original CQL paper recommends setting this to True for most offline RL tasks. + :param lagrange_threshold: the target value for the CQL regularization loss when using Lagrangian optimization. + When with_lagrange=True, the algorithm dynamically adjusts the CQL alpha parameter to maintain + the regularization loss close to this threshold. + Lower values result in more conservative behavior by enforcing stronger penalties on + out-of-distribution actions. + Higher values allow more optimistic Q-value estimates similar to standard Q-learning. + This threshold effectively controls the level of conservatism in CQL's value estimation. + :param min_action: the lower bound for each dimension of the action space. + Used when sampling random actions for the CQL regularization term. + Should match the environment's action space minimum values. 
+ These random actions help penalize Q-values for out-of-distribution actions. + Typically set to -1.0 for normalized continuous action spaces. + :param max_action: the upper bound for each dimension of the action space. + Used when sampling random actions for the CQL regularization term. + Should match the environment's action space maximum values. + These random actions help penalize Q-values for out-of-distribution actions. + Typically set to 1.0 for normalized continuous action spaces. + :param num_repeat_actions: the number of action samples generated per state when computing + the CQL regularization term. + Controls how many random and policy actions are sampled for each state in the batch when + estimating expected Q-values. + Higher values provide more accurate approximation of the expected Q-values but increase + computational cost. + Lower values reduce computation but may provide less stable or less accurate regularization. + The original CQL paper typically uses values around 10. + :param alpha_min: the minimum value allowed for the adaptive CQL regularization coefficient. + When using Lagrangian optimization (with_lagrange=True), constrains the automatically tuned + cql_alpha parameter to be at least this value. + Prevents the regularization strength from becoming too small during training. + Setting a positive value ensures the algorithm maintains at least some degree of conservatism. + Only relevant when with_lagrange=True. + :param alpha_max: the maximum value allowed for the adaptive CQL regularization coefficient. + When using Lagrangian optimization (with_lagrange=True), constrains the automatically tuned + cql_alpha parameter to be at most this value. + Prevents the regularization strength from becoming too large during training. + Setting an appropriate upper limit helps avoid overly conservative behavior that might hinder + learning useful value functions. + Only relevant when with_lagrange=True. 
+ :param max_grad_norm: the maximum L2 norm threshold for gradient clipping when updating critic networks. + Gradients with norm exceeding this value will be rescaled to have norm equal to this value. + Helps stabilize training by preventing excessively large parameter updates from outlier samples. + Higher values allow larger updates but may lead to training instability. + Lower values enforce more conservative updates but may slow down learning. + Setting to a large value effectively disables gradient clipping. + :param calibrated: a flag indicating whether to use the calibrated version of CQL (CalQL). + If True, calibrates Q-values by taking the maximum of computed Q-values and Monte Carlo returns. + This modification helps address the excessive pessimism problem in standard CQL. + Particularly useful for offline pre-training followed by online fine-tuning scenarios. + Experimental results suggest this approach often achieves better performance than vanilla CQL. + Based on techniques from the CalQL paper (arXiv:2303.05479). 
""" super().__init__( policy=policy, @@ -111,11 +168,11 @@ def __init__( self.policy_optim = self._create_optimizer(self.policy, policy_optim) self.critic = critic self.critic_optim = self._create_optimizer( - self.critic, critic_optim, max_grad_norm=clip_grad + self.critic, critic_optim, max_grad_norm=max_grad_norm ) self.critic2 = critic2 or deepcopy(critic) self.critic2_optim = self._create_optimizer( - self.critic2, critic2_optim or critic_optim, max_grad_norm=clip_grad + self.critic2, critic2_optim or critic_optim, max_grad_norm=max_grad_norm ) self.critic_old = self._add_lagged_network(self.critic) self.critic2_old = self._add_lagged_network(self.critic2) From ab15bfc0af189c49e889a22b773a0c40187bc148 Mon Sep 17 00:00:00 2001 From: Dominik Jain Date: Wed, 7 May 2025 17:44:14 +0200 Subject: [PATCH 138/230] v2: Improve descriptions of GAIL parameters disc_* --- tianshou/policy/imitation/gail.py | 19 ++++++++++++++++--- 1 file changed, 16 insertions(+), 3 deletions(-) diff --git a/tianshou/policy/imitation/gail.py b/tianshou/policy/imitation/gail.py index c8ca46af6..c3bb18894 100644 --- a/tianshou/policy/imitation/gail.py +++ b/tianshou/policy/imitation/gail.py @@ -60,10 +60,23 @@ def __init__( :param critic: the critic network. (s -> V(s)) :param optim: the optimizer factory for the actor and critic networks. :param expert_buffer: the replay buffer containing expert experience. - :param disc_net: the discriminator network with input dim equals - state dim plus action dim and output dim equals 1. + :param disc_net: the discriminator neural network that distinguishes between expert and policy behaviors. + Takes concatenated state-action pairs [obs, act] as input and outputs an unbounded logit value. + The raw output is transformed in the algorithm using sigmoid functions: σ(output) for expert + probability and -log(1-σ(output)), i.e. -log σ(-output), for policy rewards. + Positive output values indicate the discriminator believes the behavior is from an expert.
+ Negative output values indicate the discriminator believes the behavior is from the policy. + The network architecture should end with a linear layer of output size 1 without any + activation function, as sigmoid operations are applied separately. :param disc_optim: the optimizer factory for the discriminator network. - :param disc_update_num: the number of discriminator grad steps per model grad step. + :param disc_update_num: the number of discriminator update steps performed for each policy update step. + Controls the learning dynamics between the policy and the discriminator. + Higher values strengthen the discriminator relative to the policy, potentially improving + the quality of the reward signal but slowing down training. + Lower values allow faster policy updates but may result in a weaker discriminator that fails + to properly distinguish between expert and policy behaviors. + Typical values range from 1 to 10, with the original GAIL paper using multiple discriminator + updates per policy update. :param eps_clip: determines the range of allowed change in the policy during a policy update: The ratio of action probabilities indicated by the new and old policy is constrained to stay in the interval [1 - eps_clip, 1 + eps_clip]. From 2e42c85fa0abd7216dae178bbdd12ee25e49b330 Mon Sep 17 00:00:00 2001 From: Dominik Jain Date: Wed, 7 May 2025 17:46:37 +0200 Subject: [PATCH 139/230] v2: Improve description of parameter 'num_quantiles' --- tianshou/policy/modelfree/qrdqn.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/tianshou/policy/modelfree/qrdqn.py b/tianshou/policy/modelfree/qrdqn.py index 819bbffdc..923d98587 100644 --- a/tianshou/policy/modelfree/qrdqn.py +++ b/tianshou/policy/modelfree/qrdqn.py @@ -49,8 +49,13 @@ def __init__( potentially improving performance in tasks where delayed rewards are important but increasing training variance by incorporating more environmental stochasticity. 
Typically set between 0.9 and 0.99 for most reinforcement learning tasks - :param num_quantiles: the number of quantile midpoints in the inverse - cumulative distribution function of the value. + :param num_quantiles: the number of quantiles used to represent the return distribution for each action. + Determines the granularity of the approximated distribution function. + Higher values provide a more fine-grained approximation of the true return distribution but + increase computational and memory requirements. + Lower values reduce computational cost but may not capture the distribution accurately enough. + The original QRDQN paper used 200 quantiles for Atari environments. + Must be greater than 1, as at least two quantiles are needed to represent a distribution. :param estimation_step: the number of future steps (> 0) to consider when computing temporal difference (TD) targets. Controls the balance between TD learning and Monte Carlo methods: higher values reduce bias (by relying less on potentially inaccurate value estimates) From 02c03189b93a7db5f5416d0a05c63edbbd84db99 Mon Sep 17 00:00:00 2001 From: Dominik Jain Date: Wed, 7 May 2025 17:53:09 +0200 Subject: [PATCH 140/230] v2: Improve description of parameter 'target_update_freq' --- tianshou/highlevel/params/policy_params.py | 13 +++++- tianshou/policy/imitation/discrete_bcq.py | 50 +++++++++++++++++++--- tianshou/policy/imitation/discrete_cql.py | 13 +++++- tianshou/policy/imitation/discrete_crr.py | 13 +++++- tianshou/policy/modelfree/bdqn.py | 12 +++++- tianshou/policy/modelfree/c51.py | 13 +++++- tianshou/policy/modelfree/dqn.py | 26 +++++++++-- tianshou/policy/modelfree/fqf.py | 13 +++++- tianshou/policy/modelfree/iqn.py | 13 +++++- tianshou/policy/modelfree/qrdqn.py | 13 +++++- 10 files changed, 155 insertions(+), 24 deletions(-) diff --git a/tianshou/highlevel/params/policy_params.py b/tianshou/highlevel/params/policy_params.py index 499b018de..347ea6d74 100644 --- 
a/tianshou/highlevel/params/policy_params.py +++ b/tianshou/highlevel/params/policy_params.py @@ -634,7 +634,18 @@ class QLearningOffPolicyParams( Params, ParamsMixinGamma, ParamsMixinSingleModel, ParamsMixinEstimationStep ): target_update_freq: int = 0 - """the target network update frequency (0 if no target network is to be used)""" + """ + the number of training iterations between each complete update of the target network. + Controls how frequently the target Q-network parameters are updated with the current + Q-network values. + A value of 0 disables the target network entirely, using only a single network for both + action selection and bootstrap targets. + Higher values provide more stable learning targets but slow down the propagation of new + value estimates. Lower positive values allow faster learning but may lead to instability + due to rapidly changing targets. + Typically set between 100-10000 for DQN variants, with exact values depending on + environment complexity. + """ return_scaling: bool = False """ flag indicating whether to enable scaling of estimated returns by diff --git a/tianshou/policy/imitation/discrete_bcq.py b/tianshou/policy/imitation/discrete_bcq.py index 81fedfbe5..774e48f06 100644 --- a/tianshou/policy/imitation/discrete_bcq.py +++ b/tianshou/policy/imitation/discrete_bcq.py @@ -49,11 +49,30 @@ def __init__( :param model: a model following the rules (s_B -> action_values_BA) :param imitator: a model following the rules in :class:`~tianshou.policy.BasePolicy`. (s -> imitation_logits) - :param target_update_freq: the target network update frequency. + :param target_update_freq: the number of training iterations between each complete update of + the target network. + Controls how frequently the target Q-network parameters are updated with the current + Q-network values. + A value of 0 disables the target network entirely, using only a single network for both + action selection and bootstrap targets. 
+ Higher values provide more stable learning targets but slow down the propagation of new + value estimates. Lower positive values allow faster learning but may lead to instability + due to rapidly changing targets. + Typically set between 100-10000 for DQN variants, with exact values depending on environment + complexity. :param unlikely_action_threshold: the threshold (tau) for unlikely actions, as shown in Equ. (17) in the paper. - :param target_update_freq: the target network update frequency (0 if - you do not use the target network). + :param target_update_freq: the number of training iterations between each complete update of + the target network. + Controls how frequently the target Q-network parameters are updated with the current + Q-network values. + A value of 0 disables the target network entirely, using only a single network for both + action selection and bootstrap targets. + Higher values provide more stable learning targets but slow down the propagation of new + value estimates. Lower positive values allow faster learning but may lead to instability + due to rapidly changing targets. + Typically set between 100-10000 for DQN variants, with exact values depending on environment + complexity. :param action_space: the environment's action space. :param observation_space: the environment's observation space. :param eps_inference: the epsilon value for epsilon-greedy exploration during inference, @@ -136,7 +155,17 @@ def __init__( the averaging effect). A value of 1 corresponds to standard TD learning with immediate bootstrapping, while very large values approach Monte Carlo-like estimation that uses complete episode returns. - :param target_update_freq: the target network update frequency. + :param target_update_freq: the number of training iterations between each complete update of + the target network. + Controls how frequently the target Q-network parameters are updated with the current + Q-network values. 
+ A value of 0 disables the target network entirely, using only a single network for both + action selection and bootstrap targets. + Higher values provide more stable learning targets but slow down the propagation of new + value estimates. Lower positive values allow faster learning but may lead to instability + due to rapidly changing targets. + Typically set between 100-10000 for DQN variants, with exact values depending on environment + complexity. :param imitation_logits_penalty: regularization weight for imitation logits. :param estimation_step: the number of future steps (> 0) to consider when computing temporal @@ -146,8 +175,17 @@ def __init__( the averaging effect). A value of 1 corresponds to standard TD learning with immediate bootstrapping, while very large values approach Monte Carlo-like estimation that uses complete episode returns. - :param target_update_freq: the target network update frequency (0 if - you do not use the target network). + :param target_update_freq: the number of training iterations between each complete update of + the target network. + Controls how frequently the target Q-network parameters are updated with the current + Q-network values. + A value of 0 disables the target network entirely, using only a single network for both + action selection and bootstrap targets. + Higher values provide more stable learning targets but slow down the propagation of new + value estimates. Lower positive values allow faster learning but may lead to instability + due to rapidly changing targets. + Typically set between 100-10000 for DQN variants, with exact values depending on environment + complexity. """ super().__init__( policy=policy, diff --git a/tianshou/policy/imitation/discrete_cql.py b/tianshou/policy/imitation/discrete_cql.py index 38c9e589c..56632850a 100644 --- a/tianshou/policy/imitation/discrete_cql.py +++ b/tianshou/policy/imitation/discrete_cql.py @@ -54,8 +54,17 @@ def __init__( the averaging effect). 
A value of 1 corresponds to standard TD learning with immediate bootstrapping, while very large values approach Monte Carlo-like estimation that uses complete episode returns. - :param target_update_freq: the target network update frequency (0 if - you do not use the target network). + :param target_update_freq: the number of training iterations between each complete update of + the target network. + Controls how frequently the target Q-network parameters are updated with the current + Q-network values. + A value of 0 disables the target network entirely, using only a single network for both + action selection and bootstrap targets. + Higher values provide more stable learning targets but slow down the propagation of new + value estimates. Lower positive values allow faster learning but may lead to instability + due to rapidly changing targets. + Typically set between 100-10000 for DQN variants, with exact values depending on environment + complexity. """ QRDQN.__init__( self, diff --git a/tianshou/policy/imitation/discrete_crr.py b/tianshou/policy/imitation/discrete_crr.py index 6c1e36d0d..759702e99 100644 --- a/tianshou/policy/imitation/discrete_crr.py +++ b/tianshou/policy/imitation/discrete_crr.py @@ -68,8 +68,17 @@ def __init__( :param beta: when policy_improvement_mode is "exp", this is the denominator of the exp function. :param min_q_weight: weight for CQL loss/regularizer. Default to 10. - :param target_update_freq: the target network update frequency (0 if - you do not use the target network). + :param target_update_freq: the number of training iterations between each complete update of + the target network. + Controls how frequently the target Q-network parameters are updated with the current + Q-network values. + A value of 0 disables the target network entirely, using only a single network for both + action selection and bootstrap targets. + Higher values provide more stable learning targets but slow down the propagation of new + value estimates. 
Lower positive values allow faster learning but may lead to instability + due to rapidly changing targets. + Typically set between 100-10000 for DQN variants, with exact values depending on environment + complexity. :param return_standardization: whether to standardize episode returns by subtracting the running mean and dividing by the running standard deviation. Note that this is known to be detrimental to performance in many cases! diff --git a/tianshou/policy/modelfree/bdqn.py b/tianshou/policy/modelfree/bdqn.py index 069c7f2d5..100438096 100644 --- a/tianshou/policy/modelfree/bdqn.py +++ b/tianshou/policy/modelfree/bdqn.py @@ -128,7 +128,17 @@ def __init__( the averaging effect). A value of 1 corresponds to standard TD learning with immediate bootstrapping, while very large values approach Monte Carlo-like estimation that uses complete episode returns. - :param target_update_freq: the target network update frequency (0 if a target network shall not be used). + :param target_update_freq: the number of training iterations between each complete update of + the target network. + Controls how frequently the target Q-network parameters are updated with the current + Q-network values. + A value of 0 disables the target network entirely, using only a single network for both + action selection and bootstrap targets. + Higher values provide more stable learning targets but slow down the propagation of new + value estimates. Lower positive values allow faster learning but may lead to instability + due to rapidly changing targets. + Typically set between 100-10000 for DQN variants, with exact values depending on environment + complexity. :param is_double: flag indicating whether to use Double Q-learning for target value calculation. If True, the algorithm uses the online network to select actions and the target network to evaluate their Q-values. This decoupling helps reduce the overestimation bias that standard Q-learning is prone to. 
diff --git a/tianshou/policy/modelfree/c51.py b/tianshou/policy/modelfree/c51.py index e20cbdfb9..b89b7d940 100644 --- a/tianshou/policy/modelfree/c51.py +++ b/tianshou/policy/modelfree/c51.py @@ -96,8 +96,17 @@ def __init__( the averaging effect). A value of 1 corresponds to standard TD learning with immediate bootstrapping, while very large values approach Monte Carlo-like estimation that uses complete episode returns. - :param target_update_freq: the target network update frequency (0 if - you do not use the target network). + :param target_update_freq: the number of training iterations between each complete update of + the target network. + Controls how frequently the target Q-network parameters are updated with the current + Q-network values. + A value of 0 disables the target network entirely, using only a single network for both + action selection and bootstrap targets. + Higher values provide more stable learning targets but slow down the propagation of new + value estimates. Lower positive values allow faster learning but may lead to instability + due to rapidly changing targets. + Typically set between 100-10000 for DQN variants, with exact values depending on environment + complexity. """ super().__init__( policy=policy, diff --git a/tianshou/policy/modelfree/dqn.py b/tianshou/policy/modelfree/dqn.py index c8a5a888a..85a8ec951 100644 --- a/tianshou/policy/modelfree/dqn.py +++ b/tianshou/policy/modelfree/dqn.py @@ -212,8 +212,17 @@ def __init__( the averaging effect). A value of 1 corresponds to standard TD learning with immediate bootstrapping, while very large values approach Monte Carlo-like estimation that uses complete episode returns. - :param target_update_freq: the frequency with which to update the weights of the target network; - 0 if a target network shall not be used. + :param target_update_freq: the number of training iterations between each complete update of + the target network. 
+ Controls how frequently the target Q-network parameters are updated with the current + Q-network values. + A value of 0 disables the target network entirely, using only a single network for both + action selection and bootstrap targets. + Higher values provide more stable learning targets but slow down the propagation of new + value estimates. Lower positive values allow faster learning but may lead to instability + due to rapidly changing targets. + Typically set between 100-10000 for DQN variants, with exact values depending on environment + complexity. """ super().__init__( policy=policy, @@ -315,8 +324,17 @@ def __init__( the averaging effect). A value of 1 corresponds to standard TD learning with immediate bootstrapping, while very large values approach Monte Carlo-like estimation that uses complete episode returns. - :param target_update_freq: the frequency with which to update the weights of the target network; - 0 if a target network shall not be used. + :param target_update_freq: the number of training iterations between each complete update of + the target network. + Controls how frequently the target Q-network parameters are updated with the current + Q-network values. + A value of 0 disables the target network entirely, using only a single network for both + action selection and bootstrap targets. + Higher values provide more stable learning targets but slow down the propagation of new + value estimates. Lower positive values allow faster learning but may lead to instability + due to rapidly changing targets. + Typically set between 100-10000 for DQN variants, with exact values depending on environment + complexity. :param is_double: flag indicating whether to use the Double DQN algorithm for target value computation. If True, the algorithm uses the online network to select actions and the target network to evaluate their Q-values. 
This approach helps reduce the overestimation bias in Q-learning diff --git a/tianshou/policy/modelfree/fqf.py b/tianshou/policy/modelfree/fqf.py index af9acfbfa..e251473dc 100644 --- a/tianshou/policy/modelfree/fqf.py +++ b/tianshou/policy/modelfree/fqf.py @@ -149,8 +149,17 @@ def __init__( the averaging effect). A value of 1 corresponds to standard TD learning with immediate bootstrapping, while very large values approach Monte Carlo-like estimation that uses complete episode returns. - :param target_update_freq: the target network update frequency (0 if - you do not use the target network). + :param target_update_freq: the number of training iterations between each complete update of + the target network. + Controls how frequently the target Q-network parameters are updated with the current + Q-network values. + A value of 0 disables the target network entirely, using only a single network for both + action selection and bootstrap targets. + Higher values provide more stable learning targets but slow down the propagation of new + value estimates. Lower positive values allow faster learning but may lead to instability + due to rapidly changing targets. + Typically set between 100-10000 for DQN variants, with exact values depending on environment + complexity. """ super().__init__( policy=policy, diff --git a/tianshou/policy/modelfree/iqn.py b/tianshou/policy/modelfree/iqn.py index d7c0704ec..dec4d87a0 100644 --- a/tianshou/policy/modelfree/iqn.py +++ b/tianshou/policy/modelfree/iqn.py @@ -135,8 +135,17 @@ def __init__( the averaging effect). A value of 1 corresponds to standard TD learning with immediate bootstrapping, while very large values approach Monte Carlo-like estimation that uses complete episode returns. - :param target_update_freq: the target network update frequency (0 if - you do not use the target network). + :param target_update_freq: the number of training iterations between each complete update of + the target network. 
+ Controls how frequently the target Q-network parameters are updated with the current + Q-network values. + A value of 0 disables the target network entirely, using only a single network for both + action selection and bootstrap targets. + Higher values provide more stable learning targets but slow down the propagation of new + value estimates. Lower positive values allow faster learning but may lead to instability + due to rapidly changing targets. + Typically set between 100-10000 for DQN variants, with exact values depending on environment + complexity. """ super().__init__( policy=policy, diff --git a/tianshou/policy/modelfree/qrdqn.py b/tianshou/policy/modelfree/qrdqn.py index 923d98587..c5c6da588 100644 --- a/tianshou/policy/modelfree/qrdqn.py +++ b/tianshou/policy/modelfree/qrdqn.py @@ -63,8 +63,17 @@ def __init__( the averaging effect). A value of 1 corresponds to standard TD learning with immediate bootstrapping, while very large values approach Monte Carlo-like estimation that uses complete episode returns. - :param target_update_freq: the target network update frequency (0 if - you do not use the target network). + :param target_update_freq: the number of training iterations between each complete update of + the target network. + Controls how frequently the target Q-network parameters are updated with the current + Q-network values. + A value of 0 disables the target network entirely, using only a single network for both + action selection and bootstrap targets. + Higher values provide more stable learning targets but slow down the propagation of new + value estimates. Lower positive values allow faster learning but may lead to instability + due to rapidly changing targets. + Typically set between 100-10000 for DQN variants, with exact values depending on environment + complexity. 
""" assert num_quantiles > 1, f"num_quantiles should be greater than 1 but got: {num_quantiles}" super().__init__( From a66396bc9f45eb2ab720c7711a6112f6b59b516c Mon Sep 17 00:00:00 2001 From: Dominik Jain Date: Wed, 7 May 2025 17:54:54 +0200 Subject: [PATCH 141/230] v2: Improve description of parameter 'add_done_loop' --- tianshou/policy/modelbased/psrl.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/tianshou/policy/modelbased/psrl.py b/tianshou/policy/modelbased/psrl.py index 9d0124b0e..f52632141 100644 --- a/tianshou/policy/modelbased/psrl.py +++ b/tianshou/policy/modelbased/psrl.py @@ -233,9 +233,14 @@ def __init__( ) -> None: """ :param policy: the policy - :param add_done_loop: whether to add an extra self-loop for the - terminal state in MDP. Default to False. - :param lr_scheduler: if not None, will be called in `policy.update()`. + :param add_done_loop: a flag indicating whether to add a self-loop transition for terminal states + in the MDP. + If True, whenever an episode terminates, an artificial transition from the terminal state + back to itself is added to the transition counts for all actions. + This modification can help stabilize learning for terminal states that have limited samples. + Setting to True can be beneficial in environments where episodes frequently terminate, + ensuring that terminal states receive sufficient updates to their value estimates. + Default is False, which preserves the standard MDP formulation without artificial self-loops. 
""" super().__init__( policy=policy, From f459556c26ae7a50974ab63f76be27081ae2b3cc Mon Sep 17 00:00:00 2001 From: Dominik Jain Date: Wed, 7 May 2025 17:55:34 +0200 Subject: [PATCH 142/230] v2: Remove docstrings for removed parameters --- tianshou/policy/modelfree/discrete_sac.py | 2 -- tianshou/policy/multiagent/mapolicy.py | 1 - 2 files changed, 3 deletions(-) diff --git a/tianshou/policy/modelfree/discrete_sac.py b/tianshou/policy/modelfree/discrete_sac.py index e7e217e0c..a3875e35b 100644 --- a/tianshou/policy/modelfree/discrete_sac.py +++ b/tianshou/policy/modelfree/discrete_sac.py @@ -131,8 +131,6 @@ def __init__( the averaging effect). A value of 1 corresponds to standard TD learning with immediate bootstrapping, while very large values approach Monte Carlo-like estimation that uses complete episode returns. - :param lr_scheduler: a learning rate scheduler that adjusts the learning rate - in optimizer in each policy.update() """ super().__init__( policy=policy, diff --git a/tianshou/policy/multiagent/mapolicy.py b/tianshou/policy/multiagent/mapolicy.py index 0b4a2fbcc..588ad4a50 100644 --- a/tianshou/policy/multiagent/mapolicy.py +++ b/tianshou/policy/multiagent/mapolicy.py @@ -280,7 +280,6 @@ def __init__( """ :param algorithms: a list of off-policy algorithms. :param env: the multi-agent RL environment - :param lr_scheduler: if not None, will be called in `policy.update()`. 
""" self._dispatcher: MARLDispatcher[OffPolicyAlgorithm] = MARLDispatcher(algorithms, env) super().__init__( From 0404b7960dbdce124146db9658ef468390413309 Mon Sep 17 00:00:00 2001 From: Dominik Jain Date: Wed, 7 May 2025 18:01:52 +0200 Subject: [PATCH 143/230] v2: Fix/remove references to BasePolicy --- tianshou/data/collector.py | 3 +-- tianshou/policy/base.py | 2 +- tianshou/policy/imitation/cql.py | 3 +-- tianshou/policy/imitation/discrete_bcq.py | 3 +-- tianshou/policy/modelbased/psrl.py | 5 ----- tianshou/policy/modelfree/ddpg.py | 5 ----- tianshou/policy/modelfree/dqn.py | 5 ----- tianshou/policy/modelfree/pg.py | 5 ----- tianshou/policy/modelfree/redq.py | 3 +-- tianshou/trainer/base.py | 2 +- 10 files changed, 6 insertions(+), 30 deletions(-) diff --git a/tianshou/data/collector.py b/tianshou/data/collector.py index 9eaa5487b..f84709de6 100644 --- a/tianshou/data/collector.py +++ b/tianshou/data/collector.py @@ -559,8 +559,7 @@ def __init__( collect_stats_class: type[TCollectStats] = CollectStats, # type: ignore[assignment] ) -> None: """ - :param policy: a tianshou policy, each :class:`BasePolicy` is capable of computing a batch - of actions from a batch of observations. + :param policy: a tianshou policy or algorithm :param env: a ``gymnasium.Env`` environment or a vectorized instance of the :class:`~tianshou.env.BaseVectorEnv` class. The latter is strongly recommended, as with a gymnasium env the collection will not happen in parallel (a `DummyVectorEnv` diff --git a/tianshou/policy/base.py b/tianshou/policy/base.py index 3064f70ca..36a1d3260 100644 --- a/tianshou/policy/base.py +++ b/tianshou/policy/base.py @@ -329,7 +329,7 @@ def map_action_inverse( self, act: TArr, ) -> np.ndarray: - """Inverse operation to :meth:`~tianshou.policy.BasePolicy.map_action`. + """Inverse operation to :meth:`map_action`. This function is called in :meth:`~tianshou.data.Collector.collect` for random initial steps. 
It scales [action_space.low, action_space.high] to diff --git a/tianshou/policy/imitation/cql.py b/tianshou/policy/imitation/cql.py index 897c85984..c9365b56e 100644 --- a/tianshou/policy/imitation/cql.py +++ b/tianshou/policy/imitation/cql.py @@ -58,8 +58,7 @@ def __init__( calibrated: bool = True, ) -> None: """ - :param actor: the actor network following the rules in - :class:`~tianshou.policy.BasePolicy`. (s -> a) + :param actor: the actor network following the rules (s -> a) :param policy_optim: the optimizer factory for the policy/its actor network. :param critic: the first critic network. :param critic_optim: the optimizer factory for the first critic network. diff --git a/tianshou/policy/imitation/discrete_bcq.py b/tianshou/policy/imitation/discrete_bcq.py index 774e48f06..4751fe187 100644 --- a/tianshou/policy/imitation/discrete_bcq.py +++ b/tianshou/policy/imitation/discrete_bcq.py @@ -47,8 +47,7 @@ def __init__( ) -> None: """ :param model: a model following the rules (s_B -> action_values_BA) - :param imitator: a model following the rules in - :class:`~tianshou.policy.BasePolicy`. (s -> imitation_logits) + :param imitator: a model following the rules (s -> imitation_logits) :param target_update_freq: the number of training iterations between each complete update of the target network. Controls how frequently the target Q-network parameters are updated with the current diff --git a/tianshou/policy/modelbased/psrl.py b/tianshou/policy/modelbased/psrl.py index f52632141..b7c93ee72 100644 --- a/tianshou/policy/modelbased/psrl.py +++ b/tianshou/policy/modelbased/psrl.py @@ -207,11 +207,6 @@ def forward( :return: A :class:`~tianshou.data.Batch` with "act" key containing the action. - - .. seealso:: - - Please refer to :meth:`~tianshou.policy.BasePolicy.forward` for - more detailed explanation. """ assert isinstance(batch.obs, np.ndarray), "only support np.ndarray observation" # TODO: shouldn't the model output a state as well if state is passed (i.e. 
RNNs are involved)? diff --git a/tianshou/policy/modelfree/ddpg.py b/tianshou/policy/modelfree/ddpg.py index 17f8c5606..1c890b9e4 100644 --- a/tianshou/policy/modelfree/ddpg.py +++ b/tianshou/policy/modelfree/ddpg.py @@ -178,11 +178,6 @@ def forward( * ``act`` the action. * ``state`` the hidden state. - - .. seealso:: - - Please refer to :meth:`~tianshou.policy.BasePolicy.forward` for - more detailed explanation. """ if model is None: model = self.actor diff --git a/tianshou/policy/modelfree/dqn.py b/tianshou/policy/modelfree/dqn.py index 85a8ec951..4e6956223 100644 --- a/tianshou/policy/modelfree/dqn.py +++ b/tianshou/policy/modelfree/dqn.py @@ -125,11 +125,6 @@ def forward( * ``act`` the action. * ``logits`` the network's raw output. * ``state`` the hidden state. - - .. seealso:: - - Please refer to :meth:`~tianshou.policy.BasePolicy.forward` for - more detailed explanation. """ if model is None: model = self.model diff --git a/tianshou/policy/modelfree/pg.py b/tianshou/policy/modelfree/pg.py index 019ec74b4..4b0c217e0 100644 --- a/tianshou/policy/modelfree/pg.py +++ b/tianshou/policy/modelfree/pg.py @@ -161,11 +161,6 @@ def forward( Will sample from the dist_fn, if appropriate. Returns a new object representing the processed batch data (contrary to other methods that modify the input batch inplace). - - .. seealso:: - - Please refer to :meth:`~tianshou.policy.BasePolicy.forward` for - more detailed explanation. """ # TODO - ALGO: marked for algorithm refactoring action_dist_input_BD, hidden_BH = self.actor(batch.obs, state=state, info=batch.info) diff --git a/tianshou/policy/modelfree/redq.py b/tianshou/policy/modelfree/redq.py index 6cd209669..1da3b84af 100644 --- a/tianshou/policy/modelfree/redq.py +++ b/tianshou/policy/modelfree/redq.py @@ -47,8 +47,7 @@ def __init__( observation_space: gym.Space | None = None, ): """ - :param actor: The actor network following the rules in - :class:`~tianshou.policy.BasePolicy`. 
(s -> model_output) + :param actor: The actor network following the rules (s -> model_output) :param action_space: the environment's action_space. :param deterministic_eval: flag indicating whether the policy should use deterministic actions (using the mode of the action distribution) instead of stochastic ones diff --git a/tianshou/trainer/base.py b/tianshou/trainer/base.py index 519bb218b..e1e9d8f13 100644 --- a/tianshou/trainer/base.py +++ b/tianshou/trainer/base.py @@ -147,7 +147,7 @@ class TrainerParams(ToStringMixin): save_best_fn: Callable[["Algorithm"], None] | None = None """ the callback function to call in order to save the best model whenever a new best score (see :attr:`compute_score_fn`) - is achieved in a test step. It should have the signature ``f(policy: BasePolicy) -> None``. + is achieved in a test step. It should have the signature ``f(algorithm: Algorithm) -> None``. """ save_checkpoint_fn: Callable[[int, int, int], str] | None = None From dbfb462dfe7d4c02334476941fdf3ae91ab37f29 Mon Sep 17 00:00:00 2001 From: Dominik Jain Date: Mon, 12 May 2025 17:38:52 +0200 Subject: [PATCH 144/230] v2: Rename DDPGPolicy -> ContinuousDeterministicPolicy --- examples/mujoco/fetch_her_ddpg.py | 4 ++-- examples/mujoco/mujoco_ddpg.py | 4 ++-- examples/mujoco/mujoco_td3.py | 4 ++-- examples/offline/d4rl_td3_bc.py | 4 ++-- test/continuous/test_ddpg.py | 4 ++-- test/continuous/test_td3.py | 4 ++-- test/offline/test_td3_bc.py | 4 ++-- tianshou/highlevel/algorithm.py | 12 +++++++----- tianshou/policy/imitation/td3_bc.py | 6 +++--- tianshou/policy/modelfree/ddpg.py | 8 +++++--- tianshou/policy/modelfree/td3.py | 6 +++--- 11 files changed, 32 insertions(+), 28 deletions(-) diff --git a/examples/mujoco/fetch_her_ddpg.py b/examples/mujoco/fetch_her_ddpg.py index fa0c04713..061edab8f 100644 --- a/examples/mujoco/fetch_her_ddpg.py +++ b/examples/mujoco/fetch_her_ddpg.py @@ -24,7 +24,7 @@ from tianshou.highlevel.logger import LoggerFactoryDefault from tianshou.policy import 
DDPG from tianshou.policy.base import Algorithm -from tianshou.policy.modelfree.ddpg import DDPGPolicy +from tianshou.policy.modelfree.ddpg import ContinuousDeterministicPolicy from tianshou.policy.optim import AdamOptimizerFactory from tianshou.trainer import OffPolicyTrainerParams from tianshou.utils.net.common import Net, get_dict_state_decorator @@ -170,7 +170,7 @@ def test_ddpg(args: argparse.Namespace = get_args()) -> None: ) critic = dict_state_dec(ContinuousCritic)(net_c, device=args.device).to(args.device) critic_optim = AdamOptimizerFactory(lr=args.critic_lr) - policy = DDPGPolicy( + policy = ContinuousDeterministicPolicy( actor=actor, exploration_noise=GaussianNoise(sigma=args.exploration_noise), action_space=env.action_space, diff --git a/examples/mujoco/mujoco_ddpg.py b/examples/mujoco/mujoco_ddpg.py index 2e7bfa2b6..56f1a233f 100755 --- a/examples/mujoco/mujoco_ddpg.py +++ b/examples/mujoco/mujoco_ddpg.py @@ -14,7 +14,7 @@ from tianshou.highlevel.logger import LoggerFactoryDefault from tianshou.policy import DDPG from tianshou.policy.base import Algorithm -from tianshou.policy.modelfree.ddpg import DDPGPolicy +from tianshou.policy.modelfree.ddpg import ContinuousDeterministicPolicy from tianshou.policy.optim import AdamOptimizerFactory from tianshou.trainer import OffPolicyTrainerParams from tianshou.utils.net.common import Net @@ -100,7 +100,7 @@ def main(args: argparse.Namespace = get_args()) -> None: ) critic = ContinuousCritic(preprocess_net=net_c).to(args.device) critic_optim = AdamOptimizerFactory(lr=args.critic_lr) - policy = DDPGPolicy( + policy = ContinuousDeterministicPolicy( actor=actor, exploration_noise=GaussianNoise(sigma=args.exploration_noise), action_space=env.action_space, diff --git a/examples/mujoco/mujoco_td3.py b/examples/mujoco/mujoco_td3.py index 219147288..06d974535 100755 --- a/examples/mujoco/mujoco_td3.py +++ b/examples/mujoco/mujoco_td3.py @@ -14,7 +14,7 @@ from tianshou.highlevel.logger import LoggerFactoryDefault from 
tianshou.policy import TD3 from tianshou.policy.base import Algorithm -from tianshou.policy.modelfree.ddpg import DDPGPolicy +from tianshou.policy.modelfree.ddpg import ContinuousDeterministicPolicy from tianshou.policy.optim import AdamOptimizerFactory from tianshou.trainer import OffPolicyTrainerParams from tianshou.utils.net.common import Net @@ -114,7 +114,7 @@ def main(args: argparse.Namespace = get_args()) -> None: critic2 = ContinuousCritic(preprocess_net=net_c2).to(args.device) critic2_optim = AdamOptimizerFactory(lr=args.critic_lr) - policy = DDPGPolicy( + policy = ContinuousDeterministicPolicy( actor=actor, exploration_noise=GaussianNoise(sigma=args.exploration_noise), action_space=env.action_space, diff --git a/examples/offline/d4rl_td3_bc.py b/examples/offline/d4rl_td3_bc.py index 772a90f0d..7a1224c2a 100644 --- a/examples/offline/d4rl_td3_bc.py +++ b/examples/offline/d4rl_td3_bc.py @@ -16,7 +16,7 @@ from tianshou.exploration import GaussianNoise from tianshou.policy import TD3BC from tianshou.policy.base import Algorithm -from tianshou.policy.modelfree.ddpg import DDPGPolicy +from tianshou.policy.modelfree.ddpg import ContinuousDeterministicPolicy from tianshou.policy.optim import AdamOptimizerFactory from tianshou.trainer import OfflineTrainerParams from tianshou.utils import TensorboardLogger, WandbLogger @@ -133,7 +133,7 @@ def test_td3_bc() -> None: critic2 = ContinuousCritic(preprocess_net=net_c2).to(args.device) critic2_optim = AdamOptimizerFactory(lr=args.critic_lr) - policy = DDPGPolicy( + policy = ContinuousDeterministicPolicy( actor=actor, exploration_noise=GaussianNoise(sigma=args.exploration_noise), action_space=env.action_space, diff --git a/test/continuous/test_ddpg.py b/test/continuous/test_ddpg.py index 698fbaeca..827596f7f 100644 --- a/test/continuous/test_ddpg.py +++ b/test/continuous/test_ddpg.py @@ -11,7 +11,7 @@ from tianshou.exploration import GaussianNoise from tianshou.policy import DDPG from tianshou.policy.base import 
Algorithm -from tianshou.policy.modelfree.ddpg import DDPGPolicy +from tianshou.policy.modelfree.ddpg import ContinuousDeterministicPolicy from tianshou.policy.optim import AdamOptimizerFactory from tianshou.trainer.base import OffPolicyTrainerParams from tianshou.utils import TensorboardLogger @@ -87,7 +87,7 @@ def test_ddpg(args: argparse.Namespace = get_args()) -> None: ) critic = ContinuousCritic(preprocess_net=net).to(args.device) critic_optim = AdamOptimizerFactory(lr=args.critic_lr) - policy = DDPGPolicy( + policy = ContinuousDeterministicPolicy( actor=actor, exploration_noise=GaussianNoise(sigma=args.exploration_noise), action_space=env.action_space, diff --git a/test/continuous/test_td3.py b/test/continuous/test_td3.py index e718f99eb..0a1e55c51 100644 --- a/test/continuous/test_td3.py +++ b/test/continuous/test_td3.py @@ -11,7 +11,7 @@ from tianshou.exploration import GaussianNoise from tianshou.policy import TD3 from tianshou.policy.base import Algorithm -from tianshou.policy.modelfree.ddpg import DDPGPolicy +from tianshou.policy.modelfree.ddpg import ContinuousDeterministicPolicy from tianshou.policy.optim import AdamOptimizerFactory from tianshou.trainer.base import OffPolicyTrainerParams from tianshou.utils import TensorboardLogger @@ -99,7 +99,7 @@ def test_td3(args: argparse.Namespace = get_args()) -> None: ) critic2 = ContinuousCritic(preprocess_net=net_c2).to(args.device) critic2_optim = AdamOptimizerFactory(lr=args.critic_lr) - policy = DDPGPolicy( + policy = ContinuousDeterministicPolicy( actor=actor, action_space=env.action_space, exploration_noise=GaussianNoise(sigma=args.exploration_noise), diff --git a/test/offline/test_td3_bc.py b/test/offline/test_td3_bc.py index 5cfe864ff..b0675b673 100644 --- a/test/offline/test_td3_bc.py +++ b/test/offline/test_td3_bc.py @@ -14,7 +14,7 @@ from tianshou.exploration import GaussianNoise from tianshou.policy import TD3BC from tianshou.policy.base import Algorithm -from tianshou.policy.modelfree.ddpg import 
DDPGPolicy +from tianshou.policy.modelfree.ddpg import ContinuousDeterministicPolicy from tianshou.policy.optim import AdamOptimizerFactory from tianshou.trainer import OfflineTrainerParams from tianshou.utils import TensorboardLogger @@ -125,7 +125,7 @@ def test_td3_bc(args: argparse.Namespace = get_args()) -> None: critic2_optim = AdamOptimizerFactory(lr=args.critic_lr) # policy and algorithm - policy = DDPGPolicy( + policy = ContinuousDeterministicPolicy( actor=actor, action_space=env.action_space, exploration_noise=GaussianNoise(sigma=args.exploration_noise), diff --git a/tianshou/highlevel/algorithm.py b/tianshou/highlevel/algorithm.py index 5c5b53dd0..4f62abf3c 100644 --- a/tianshou/highlevel/algorithm.py +++ b/tianshou/highlevel/algorithm.py @@ -66,7 +66,7 @@ OnPolicyAlgorithm, Policy, ) -from tianshou.policy.modelfree.ddpg import DDPGPolicy +from tianshou.policy.modelfree.ddpg import ContinuousDeterministicPolicy from tianshou.policy.modelfree.discrete_sac import DiscreteSACPolicy from tianshou.policy.modelfree.dqn import DQNPolicy from tianshou.policy.modelfree.iqn import IQNPolicy @@ -528,7 +528,7 @@ def _create_algorithm(self, envs: Environments, device: TDevice) -> Algorithm: ), ) policy = self._create_policy_from_args( - DDPGPolicy, + ContinuousDeterministicPolicy, kwargs, ["exploration_noise", "action_scaling", "action_bound_method"], actor=actor, @@ -699,12 +699,14 @@ def _get_algorithm_class(self) -> type[DiscreteSAC]: return DiscreteSAC -class TD3AlgorithmFactory(ActorDualCriticsOffPolicyAlgorithmFactory[TD3Params, TD3, DDPGPolicy]): +class TD3AlgorithmFactory( + ActorDualCriticsOffPolicyAlgorithmFactory[TD3Params, TD3, ContinuousDeterministicPolicy] +): def _create_policy( self, actor: torch.nn.Module | DiscreteActor, envs: Environments, params: dict - ) -> DDPGPolicy: + ) -> ContinuousDeterministicPolicy: return self._create_policy_from_args( - DDPGPolicy, + ContinuousDeterministicPolicy, params, ["exploration_noise", "action_scaling", 
"action_bound_method"], actor=actor, diff --git a/tianshou/policy/imitation/td3_bc.py b/tianshou/policy/imitation/td3_bc.py index d4d8c5fb1..b6ed83086 100644 --- a/tianshou/policy/imitation/td3_bc.py +++ b/tianshou/policy/imitation/td3_bc.py @@ -5,19 +5,19 @@ from tianshou.data.types import RolloutBatchProtocol from tianshou.policy import TD3 from tianshou.policy.base import OfflineAlgorithm -from tianshou.policy.modelfree.ddpg import DDPGPolicy +from tianshou.policy.modelfree.ddpg import ContinuousDeterministicPolicy from tianshou.policy.modelfree.td3 import TD3TrainingStats from tianshou.policy.optim import OptimizerFactory # NOTE: This uses diamond inheritance to convert from off-policy to offline -class TD3BC(OfflineAlgorithm[DDPGPolicy], TD3): # type: ignore +class TD3BC(OfflineAlgorithm[ContinuousDeterministicPolicy], TD3): # type: ignore """Implementation of TD3+BC. arXiv:2106.06860.""" def __init__( self, *, - policy: DDPGPolicy, + policy: ContinuousDeterministicPolicy, policy_optim: OptimizerFactory, critic: torch.nn.Module, critic_optim: OptimizerFactory, diff --git a/tianshou/policy/modelfree/ddpg.py b/tianshou/policy/modelfree/ddpg.py index 1c890b9e4..2335b787e 100644 --- a/tianshou/policy/modelfree/ddpg.py +++ b/tianshou/policy/modelfree/ddpg.py @@ -108,7 +108,9 @@ def add_exploration_noise( return act -class DDPGPolicy(ContinuousPolicyWithExplorationNoise): +class ContinuousDeterministicPolicy(ContinuousPolicyWithExplorationNoise): + """A policy for continuous action spaces that uses an actor which directly maps states to actions.""" + def __init__( self, *, @@ -335,14 +337,14 @@ def _target_q(self, buffer: ReplayBuffer, indices: np.ndarray) -> torch.Tensor: class DDPG( - ActorCriticOffPolicyAlgorithm[DDPGPolicy, ActBatchProtocol], + ActorCriticOffPolicyAlgorithm[ContinuousDeterministicPolicy, ActBatchProtocol], ): """Implementation of Deep Deterministic Policy Gradient. 
arXiv:1509.02971.""" def __init__( self, *, - policy: DDPGPolicy, + policy: ContinuousDeterministicPolicy, policy_optim: OptimizerFactory, critic: torch.nn.Module | ContinuousCritic, critic_optim: OptimizerFactory, diff --git a/tianshou/policy/modelfree/td3.py b/tianshou/policy/modelfree/td3.py index f0f8b7080..cca3e4fdf 100644 --- a/tianshou/policy/modelfree/td3.py +++ b/tianshou/policy/modelfree/td3.py @@ -16,7 +16,7 @@ ) from tianshou.policy.modelfree.ddpg import ( ActorCriticOffPolicyAlgorithm, - DDPGPolicy, + ContinuousDeterministicPolicy, TActBatchProtocol, ) from tianshou.policy.optim import OptimizerFactory @@ -103,14 +103,14 @@ def _target_q_compute_value( class TD3( - ActorDualCriticsOffPolicyAlgorithm[DDPGPolicy, ActStateBatchProtocol], + ActorDualCriticsOffPolicyAlgorithm[ContinuousDeterministicPolicy, ActStateBatchProtocol], ): """Implementation of TD3, arXiv:1802.09477.""" def __init__( self, *, - policy: DDPGPolicy, + policy: ContinuousDeterministicPolicy, policy_optim: OptimizerFactory, critic: torch.nn.Module, critic_optim: OptimizerFactory, From d5a5289033cc18ccbfd2661b2825083f97850171 Mon Sep 17 00:00:00 2001 From: Dominik Jain Date: Mon, 12 May 2025 17:55:59 +0200 Subject: [PATCH 145/230] v2: Rename DQNPolicy -> DiscreteQLearningPolicy --- examples/atari/atari_dqn.py | 4 ++-- examples/box2d/acrobot_dualdqn.py | 4 ++-- examples/box2d/lunarlander_dqn.py | 4 ++-- examples/discrete/discrete_dqn.py | 4 ++-- test/discrete/test_dqn.py | 4 ++-- test/discrete/test_drqn.py | 4 ++-- test/modelbased/test_dqn_icm.py | 4 ++-- test/pettingzoo/pistonball.py | 4 ++-- test/pettingzoo/tic_tac_toe.py | 4 ++-- tianshou/highlevel/algorithm.py | 4 ++-- tianshou/highlevel/trainer.py | 8 ++++---- tianshou/policy/imitation/discrete_bcq.py | 4 ++-- tianshou/policy/modelfree/bdqn.py | 4 ++-- tianshou/policy/modelfree/c51.py | 4 ++-- tianshou/policy/modelfree/dqn.py | 4 ++-- tianshou/policy/modelfree/fqf.py | 6 ++++-- tianshou/policy/modelfree/qrdqn.py | 4 ++-- 17 files 
changed, 38 insertions(+), 36 deletions(-) diff --git a/examples/atari/atari_dqn.py b/examples/atari/atari_dqn.py index 482f92ef9..97a6d1299 100644 --- a/examples/atari/atari_dqn.py +++ b/examples/atari/atari_dqn.py @@ -14,7 +14,7 @@ from tianshou.policy import DQN from tianshou.policy.base import Algorithm from tianshou.policy.modelbased.icm import ICMOffPolicyWrapper -from tianshou.policy.modelfree.dqn import DQNPolicy +from tianshou.policy.modelfree.dqn import DiscreteQLearningPolicy from tianshou.policy.optim import AdamOptimizerFactory from tianshou.trainer import OffPolicyTrainerParams from tianshou.utils.net.discrete import IntrinsicCuriosityModule @@ -110,7 +110,7 @@ def main(args: argparse.Namespace = get_args()) -> None: optim = AdamOptimizerFactory(lr=args.lr) # define policy and algorithm - policy = DQNPolicy( + policy = DiscreteQLearningPolicy( model=net, action_space=env.action_space, eps_training=args.eps_train, diff --git a/examples/box2d/acrobot_dualdqn.py b/examples/box2d/acrobot_dualdqn.py index 9fa3d0242..1f1f16c1c 100644 --- a/examples/box2d/acrobot_dualdqn.py +++ b/examples/box2d/acrobot_dualdqn.py @@ -11,7 +11,7 @@ from tianshou.env import DummyVectorEnv from tianshou.policy import DQN from tianshou.policy.base import Algorithm -from tianshou.policy.modelfree.dqn import DQNPolicy +from tianshou.policy.modelfree.dqn import DiscreteQLearningPolicy from tianshou.policy.optim import AdamOptimizerFactory from tianshou.trainer import OffPolicyTrainerParams from tianshou.utils import TensorboardLogger @@ -76,7 +76,7 @@ def test_dqn(args: argparse.Namespace = get_args()) -> None: dueling_param=(Q_param, V_param), ) optim = AdamOptimizerFactory(lr=args.lr) - policy = DQNPolicy( + policy = DiscreteQLearningPolicy( model=net, action_space=env.action_space, eps_training=args.eps_train, diff --git a/examples/box2d/lunarlander_dqn.py b/examples/box2d/lunarlander_dqn.py index f5d8c217b..5744451d1 100644 --- a/examples/box2d/lunarlander_dqn.py +++ 
b/examples/box2d/lunarlander_dqn.py @@ -11,7 +11,7 @@ from tianshou.env import DummyVectorEnv, SubprocVectorEnv from tianshou.policy import DQN from tianshou.policy.base import Algorithm -from tianshou.policy.modelfree.dqn import DQNPolicy +from tianshou.policy.modelfree.dqn import DiscreteQLearningPolicy from tianshou.policy.optim import AdamOptimizerFactory from tianshou.trainer import OffPolicyTrainerParams from tianshou.utils import TensorboardLogger @@ -78,7 +78,7 @@ def test_dqn(args: argparse.Namespace = get_args()) -> None: dueling_param=(Q_param, V_param), ).to(args.device) optim = AdamOptimizerFactory(lr=args.lr) - policy = DQNPolicy( + policy = DiscreteQLearningPolicy( model=net, action_space=env.action_space, eps_training=args.eps_train, diff --git a/examples/discrete/discrete_dqn.py b/examples/discrete/discrete_dqn.py index 2d23c1242..3fbfb7801 100644 --- a/examples/discrete/discrete_dqn.py +++ b/examples/discrete/discrete_dqn.py @@ -3,7 +3,7 @@ import tianshou as ts from tianshou.data import CollectStats -from tianshou.policy.modelfree.dqn import DQNPolicy +from tianshou.policy.modelfree.dqn import DiscreteQLearningPolicy from tianshou.policy.optim import AdamOptimizerFactory from tianshou.trainer.base import OffPolicyTrainerParams from tianshou.utils.space_info import SpaceInfo @@ -37,7 +37,7 @@ def main() -> None: net = Net(state_shape=state_shape, action_shape=action_shape, hidden_sizes=[128, 128, 128]) optim = AdamOptimizerFactory(lr=lr) - policy = DQNPolicy( + policy = DiscreteQLearningPolicy( model=net, action_space=env.action_space, eps_training=eps_train, eps_inference=eps_test ) algorithm: ts.policy.DQN = ts.policy.DQN( diff --git a/test/discrete/test_dqn.py b/test/discrete/test_dqn.py index 9a457d358..936680f75 100644 --- a/test/discrete/test_dqn.py +++ b/test/discrete/test_dqn.py @@ -17,7 +17,7 @@ from tianshou.env import DummyVectorEnv from tianshou.policy import DQN from tianshou.policy.base import Algorithm -from 
tianshou.policy.modelfree.dqn import DQNPolicy +from tianshou.policy.modelfree.dqn import DiscreteQLearningPolicy from tianshou.policy.optim import AdamOptimizerFactory from tianshou.trainer.base import OffPolicyTrainerParams from tianshou.utils import TensorboardLogger @@ -88,7 +88,7 @@ def test_dqn(args: argparse.Namespace = get_args(), enable_assertions: bool = Tr # dueling=(Q_param, V_param), ).to(args.device) optim = AdamOptimizerFactory(lr=args.lr) - policy = DQNPolicy( + policy = DiscreteQLearningPolicy( model=net, action_space=env.action_space, observation_space=env.observation_space, diff --git a/test/discrete/test_drqn.py b/test/discrete/test_drqn.py index a9df8f1ec..88cddbb62 100644 --- a/test/discrete/test_drqn.py +++ b/test/discrete/test_drqn.py @@ -10,7 +10,7 @@ from tianshou.env import DummyVectorEnv from tianshou.policy import DQN from tianshou.policy.base import Algorithm -from tianshou.policy.modelfree.dqn import DQNPolicy +from tianshou.policy.modelfree.dqn import DiscreteQLearningPolicy from tianshou.policy.optim import AdamOptimizerFactory from tianshou.trainer import OffPolicyTrainerParams from tianshou.utils import TensorboardLogger @@ -78,7 +78,7 @@ def test_drqn(args: argparse.Namespace = get_args()) -> None: args.device, ) optim = AdamOptimizerFactory(lr=args.lr) - policy = DQNPolicy( + policy = DiscreteQLearningPolicy( model=net, action_space=env.action_space, eps_training=args.eps_train, diff --git a/test/modelbased/test_dqn_icm.py b/test/modelbased/test_dqn_icm.py index 9813d5004..d314ec2ab 100644 --- a/test/modelbased/test_dqn_icm.py +++ b/test/modelbased/test_dqn_icm.py @@ -14,7 +14,7 @@ ) from tianshou.env import DummyVectorEnv from tianshou.policy import DQN, Algorithm, ICMOffPolicyWrapper -from tianshou.policy.modelfree.dqn import DQNPolicy +from tianshou.policy.modelfree.dqn import DiscreteQLearningPolicy from tianshou.policy.optim import AdamOptimizerFactory from tianshou.trainer import OffPolicyTrainerParams from tianshou.utils 
import TensorboardLogger @@ -106,7 +106,7 @@ def test_dqn_icm(args: argparse.Namespace = get_args()) -> None: # dueling=(Q_param, V_param), ).to(args.device) optim = AdamOptimizerFactory(lr=args.lr) - policy = DQNPolicy( + policy = DiscreteQLearningPolicy( model=net, action_space=env.action_space, eps_training=args.eps_train, diff --git a/test/pettingzoo/pistonball.py b/test/pettingzoo/pistonball.py index c958d3fbd..1634581fe 100644 --- a/test/pettingzoo/pistonball.py +++ b/test/pettingzoo/pistonball.py @@ -13,7 +13,7 @@ from tianshou.env.pettingzoo_env import PettingZooEnv from tianshou.policy import DQN, Algorithm, MultiAgentOffPolicyAlgorithm from tianshou.policy.base import OffPolicyAlgorithm -from tianshou.policy.modelfree.dqn import DQNPolicy +from tianshou.policy.modelfree.dqn import DiscreteQLearningPolicy from tianshou.policy.optim import AdamOptimizerFactory from tianshou.trainer import OffPolicyTrainerParams from tianshou.utils import TensorboardLogger @@ -102,7 +102,7 @@ def get_agents( hidden_sizes=args.hidden_sizes, ).to(args.device) optim = AdamOptimizerFactory(lr=args.lr) - policy = DQNPolicy( + policy = DiscreteQLearningPolicy( model=net, action_space=env.action_space, eps_training=args.eps_train, diff --git a/test/pettingzoo/tic_tac_toe.py b/test/pettingzoo/tic_tac_toe.py index ce7aca006..e4ea9ba2d 100644 --- a/test/pettingzoo/tic_tac_toe.py +++ b/test/pettingzoo/tic_tac_toe.py @@ -20,7 +20,7 @@ MultiAgentOffPolicyAlgorithm, ) from tianshou.policy.base import OffPolicyAlgorithm -from tianshou.policy.modelfree.dqn import DQNPolicy +from tianshou.policy.modelfree.dqn import DiscreteQLearningPolicy from tianshou.policy.optim import AdamOptimizerFactory, OptimizerFactory from tianshou.trainer import OffPolicyTrainerParams from tianshou.utils import TensorboardLogger @@ -122,7 +122,7 @@ def get_agents( ).to(args.device) if optim is None: optim = AdamOptimizerFactory(lr=args.lr) - algorithm = DQNPolicy( + algorithm = DiscreteQLearningPolicy( model=net, 
action_space=env.action_space, eps_training=args.eps_train, diff --git a/tianshou/highlevel/algorithm.py b/tianshou/highlevel/algorithm.py index 4f62abf3c..cb20b0c28 100644 --- a/tianshou/highlevel/algorithm.py +++ b/tianshou/highlevel/algorithm.py @@ -68,7 +68,7 @@ ) from tianshou.policy.modelfree.ddpg import ContinuousDeterministicPolicy from tianshou.policy.modelfree.discrete_sac import DiscreteSACPolicy -from tianshou.policy.modelfree.dqn import DQNPolicy +from tianshou.policy.modelfree.dqn import DiscreteQLearningPolicy from tianshou.policy.modelfree.iqn import IQNPolicy from tianshou.policy.modelfree.pg import ActorPolicy from tianshou.policy.modelfree.redq import REDQPolicy @@ -458,7 +458,7 @@ def _create_policy( observation_space: gymnasium.spaces.Space, ) -> Policy: return self._create_policy_from_args( - constructor=DQNPolicy, + constructor=DiscreteQLearningPolicy, params_dict=params, policy_params=["eps_training", "eps_inference"], model=model, diff --git a/tianshou/highlevel/trainer.py b/tianshou/highlevel/trainer.py index b462f6fd4..8ae3f3b7c 100644 --- a/tianshou/highlevel/trainer.py +++ b/tianshou/highlevel/trainer.py @@ -9,7 +9,7 @@ from tianshou.highlevel.env import Environments from tianshou.highlevel.logger import TLogger from tianshou.policy import DQN, Algorithm -from tianshou.policy.modelfree.dqn import DQNPolicy +from tianshou.policy.modelfree.dqn import DiscreteQLearningPolicy TAlgorithm = TypeVar("TAlgorithm", bound=Algorithm) log = logging.getLogger(__name__) @@ -92,7 +92,7 @@ def __init__(self, eps: float): def callback(self, epoch: int, env_step: int, context: TrainingContext) -> None: algorithm = cast(DQN, context.algorithm) - policy: DQNPolicy = algorithm.policy + policy: DiscreteQLearningPolicy = algorithm.policy policy.set_eps_training(self.eps) @@ -108,7 +108,7 @@ def __init__(self, eps_train: float, eps_train_final: float, decay_steps: int = def callback(self, epoch: int, env_step: int, context: TrainingContext) -> None: algorithm 
= cast(DQN, context.algorithm) - policy: DQNPolicy = algorithm.policy + policy: DiscreteQLearningPolicy = algorithm.policy logger = context.logger if env_step <= self.decay_steps: eps = self.eps_train - env_step / self.decay_steps * ( @@ -130,7 +130,7 @@ def __init__(self, eps: float): def callback(self, epoch: int, env_step: int | None, context: TrainingContext) -> None: algorithm = cast(DQN, context.algorithm) - policy: DQNPolicy = algorithm.policy + policy: DiscreteQLearningPolicy = algorithm.policy policy.set_eps_inference(self.eps) diff --git a/tianshou/policy/imitation/discrete_bcq.py b/tianshou/policy/imitation/discrete_bcq.py index 4751fe187..6a66f986c 100644 --- a/tianshou/policy/imitation/discrete_bcq.py +++ b/tianshou/policy/imitation/discrete_bcq.py @@ -18,7 +18,7 @@ LaggedNetworkFullUpdateAlgorithmMixin, OfflineAlgorithm, ) -from tianshou.policy.modelfree.dqn import DQNPolicy +from tianshou.policy.modelfree.dqn import DiscreteQLearningPolicy from tianshou.policy.modelfree.pg import SimpleLossTrainingStats from tianshou.policy.optim import OptimizerFactory @@ -33,7 +33,7 @@ class DiscreteBCQTrainingStats(SimpleLossTrainingStats): reg_loss: float -class DiscreteBCQPolicy(DQNPolicy): +class DiscreteBCQPolicy(DiscreteQLearningPolicy): def __init__( self, *, diff --git a/tianshou/policy/modelfree/bdqn.py b/tianshou/policy/modelfree/bdqn.py index 100438096..195a22666 100644 --- a/tianshou/policy/modelfree/bdqn.py +++ b/tianshou/policy/modelfree/bdqn.py @@ -16,7 +16,7 @@ ) from tianshou.policy.base import TArrOrActBatch from tianshou.policy.modelfree.dqn import ( - DQNPolicy, + DiscreteQLearningPolicy, QLearningOffPolicyAlgorithm, ) from tianshou.policy.modelfree.pg import SimpleLossTrainingStats @@ -26,7 +26,7 @@ mark_used(ActBatchProtocol) -class BDQNPolicy(DQNPolicy[BranchingNet]): +class BDQNPolicy(DiscreteQLearningPolicy[BranchingNet]): def __init__( self, *, diff --git a/tianshou/policy/modelfree/c51.py b/tianshou/policy/modelfree/c51.py index 
b89b7d940..01dbbbeab 100644 --- a/tianshou/policy/modelfree/c51.py +++ b/tianshou/policy/modelfree/c51.py @@ -5,7 +5,7 @@ from tianshou.data import Batch, ReplayBuffer from tianshou.data.types import RolloutBatchProtocol from tianshou.policy.modelfree.dqn import ( - DQNPolicy, + DiscreteQLearningPolicy, QLearningOffPolicyAlgorithm, ) from tianshou.policy.modelfree.pg import LossSequenceTrainingStats @@ -13,7 +13,7 @@ from tianshou.utils.net.common import Net -class C51Policy(DQNPolicy): +class C51Policy(DiscreteQLearningPolicy): def __init__( self, model: torch.nn.Module | Net, diff --git a/tianshou/policy/modelfree/dqn.py b/tianshou/policy/modelfree/dqn.py index 4e6956223..504ec5b80 100644 --- a/tianshou/policy/modelfree/dqn.py +++ b/tianshou/policy/modelfree/dqn.py @@ -35,7 +35,7 @@ log = logging.getLogger(__name__) -class DQNPolicy(Policy, Generic[TModel]): +class DiscreteQLearningPolicy(Policy, Generic[TModel]): def __init__( self, *, @@ -168,7 +168,7 @@ def add_exploration_noise( return act -TDQNPolicy = TypeVar("TDQNPolicy", bound=DQNPolicy) +TDQNPolicy = TypeVar("TDQNPolicy", bound=DiscreteQLearningPolicy) class QLearningOffPolicyAlgorithm( diff --git a/tianshou/policy/modelfree/fqf.py b/tianshou/policy/modelfree/fqf.py index e251473dc..c709056b9 100644 --- a/tianshou/policy/modelfree/fqf.py +++ b/tianshou/policy/modelfree/fqf.py @@ -10,7 +10,7 @@ from tianshou.data import Batch, ReplayBuffer, to_numpy from tianshou.data.types import FQFBatchProtocol, ObsBatchProtocol, RolloutBatchProtocol from tianshou.policy import QRDQN, Algorithm -from tianshou.policy.modelfree.dqn import DQNPolicy +from tianshou.policy.modelfree.dqn import DiscreteQLearningPolicy from tianshou.policy.modelfree.pg import SimpleLossTrainingStats from tianshou.policy.modelfree.qrdqn import QRDQNPolicy from tianshou.policy.optim import OptimizerFactory @@ -92,7 +92,9 @@ def forward( # type: ignore info=batch.info, ) weighted_logits = (fractions.taus[:, 1:] - fractions.taus[:, 
:-1]).unsqueeze(1) * logits - q = DQNPolicy.compute_q_value(self, weighted_logits.sum(2), getattr(obs, "mask", None)) + q = DiscreteQLearningPolicy.compute_q_value( + self, weighted_logits.sum(2), getattr(obs, "mask", None) + ) if self.max_action_num is None: # type: ignore # TODO: see same thing in DQNPolicy! Also reduce code duplication. self.max_action_num = q.shape[1] diff --git a/tianshou/policy/modelfree/qrdqn.py b/tianshou/policy/modelfree/qrdqn.py index c5c6da588..8f9daa58c 100644 --- a/tianshou/policy/modelfree/qrdqn.py +++ b/tianshou/policy/modelfree/qrdqn.py @@ -8,14 +8,14 @@ from tianshou.data import Batch, ReplayBuffer from tianshou.data.types import RolloutBatchProtocol from tianshou.policy.modelfree.dqn import ( - DQNPolicy, + DiscreteQLearningPolicy, QLearningOffPolicyAlgorithm, ) from tianshou.policy.modelfree.pg import SimpleLossTrainingStats from tianshou.policy.optim import OptimizerFactory -class QRDQNPolicy(DQNPolicy): +class QRDQNPolicy(DiscreteQLearningPolicy): def compute_q_value(self, logits: torch.Tensor, mask: np.ndarray | None) -> torch.Tensor: return super().compute_q_value(logits.mean(2), mask) From 24532bab1faf45eab36c81a30b37f8f68c0cf38e Mon Sep 17 00:00:00 2001 From: Dominik Jain Date: Mon, 12 May 2025 18:19:19 +0200 Subject: [PATCH 146/230] v2: Fix references to SamplingConfig/sampling_config --- tianshou/highlevel/experiment.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/tianshou/highlevel/experiment.py b/tianshou/highlevel/experiment.py index 3da0344bd..ed275ac83 100644 --- a/tianshou/highlevel/experiment.py +++ b/tianshou/highlevel/experiment.py @@ -544,12 +544,12 @@ def experiment_config(self, experiment_config: ExperimentConfig) -> None: self._config = experiment_config @property - def sampling_config(self) -> TrainingConfig: + def training_config(self) -> TrainingConfig: return self._training_config - @sampling_config.setter - def sampling_config(self, sampling_config: TrainingConfig) -> 
None: - self._training_config = sampling_config + @training_config.setter + def training_config(self, config: TrainingConfig) -> None: + self._training_config = config def with_logger_factory(self, logger_factory: LoggerFactory) -> Self: """Allows to customize the logger factory to use. @@ -669,13 +669,13 @@ def build_seeded_collection(self, num_experiments: int) -> ExperimentCollection: Each experiment in the collection will have a unique name created from the original experiment name and the seeds used. """ - num_train_envs = self.sampling_config.num_train_envs + num_train_envs = self.training_config.num_train_envs seeded_experiments = [] for i in range(num_experiments): builder = self.copy() builder.experiment_config.seed += i - builder.sampling_config.train_seed += i * num_train_envs + builder.training_config.train_seed += i * num_train_envs experiment = builder.build() experiment.name += f"_{experiment.get_seeding_info_as_str()}" seeded_experiments.append(experiment) From c347f40acf5d2fb97c962d120ac0a06a3f264e16 Mon Sep 17 00:00:00 2001 From: Dominik Jain Date: Mon, 12 May 2025 18:22:33 +0200 Subject: [PATCH 147/230] v2: Update references to reward_normalization parameter in high-level examples --- examples/atari/atari_ppo_hl.py | 2 +- examples/mujoco/mujoco_a2c_hl.py | 2 +- examples/mujoco/mujoco_ppo_hl.py | 2 +- examples/mujoco/mujoco_ppo_hl_multi.py | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/examples/atari/atari_ppo_hl.py b/examples/atari/atari_ppo_hl.py index 92304946d..ab40dd342 100644 --- a/examples/atari/atari_ppo_hl.py +++ b/examples/atari/atari_ppo_hl.py @@ -85,7 +85,7 @@ def main( PPOParams( gamma=gamma, gae_lambda=gae_lambda, - reward_normalization=rew_norm, + return_scaling=rew_norm, ent_coef=ent_coef, vf_coef=vf_coef, max_grad_norm=max_grad_norm, diff --git a/examples/mujoco/mujoco_a2c_hl.py b/examples/mujoco/mujoco_a2c_hl.py index 788dc1e18..8170f46d0 100644 --- a/examples/mujoco/mujoco_a2c_hl.py +++ 
b/examples/mujoco/mujoco_a2c_hl.py @@ -68,7 +68,7 @@ def main( gamma=gamma, gae_lambda=gae_lambda, action_bound_method=bound_action_method, - reward_normalization=rew_norm, + return_scaling=rew_norm, ent_coef=ent_coef, vf_coef=vf_coef, max_grad_norm=max_grad_norm, diff --git a/examples/mujoco/mujoco_ppo_hl.py b/examples/mujoco/mujoco_ppo_hl.py index b0c021e1f..12b9365c2 100644 --- a/examples/mujoco/mujoco_ppo_hl.py +++ b/examples/mujoco/mujoco_ppo_hl.py @@ -72,7 +72,7 @@ def main( gamma=gamma, gae_lambda=gae_lambda, action_bound_method=bound_action_method, - reward_normalization=rew_norm, + return_scaling=rew_norm, ent_coef=ent_coef, vf_coef=vf_coef, max_grad_norm=max_grad_norm, diff --git a/examples/mujoco/mujoco_ppo_hl_multi.py b/examples/mujoco/mujoco_ppo_hl_multi.py index cc2ccedc9..fa12df494 100644 --- a/examples/mujoco/mujoco_ppo_hl_multi.py +++ b/examples/mujoco/mujoco_ppo_hl_multi.py @@ -101,7 +101,7 @@ def main( gamma=0.99, gae_lambda=0.95, action_bound_method="clip", - reward_normalization=True, + return_scaling=True, ent_coef=0.0, vf_coef=0.25, max_grad_norm=0.5, From 6639d9d2c4fabe2c1e84d4bab9c42477d185f5e6 Mon Sep 17 00:00:00 2001 From: Dominik Jain Date: Mon, 12 May 2025 18:39:51 +0200 Subject: [PATCH 148/230] v2: Improve description of 'action_boun_method' in SACPolicy (largely no effect) --- tianshou/policy/modelfree/sac.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tianshou/policy/modelfree/sac.py b/tianshou/policy/modelfree/sac.py index c315189cc..1ed2c7536 100644 --- a/tianshou/policy/modelfree/sac.py +++ b/tianshou/policy/modelfree/sac.py @@ -107,8 +107,8 @@ def __init__( near the boundaries. Should be set to None if the actor model inherently produces bounded outputs. Typically used together with `action_scaling=True`. - This parameter is ignored in SAC, which used tanh squashing after sampling - unbounded from the gaussian policy (as in (arXiv 1801.01290): Equation 21.). 
+ NOTE: This parameter has negligible effect since actions are already bounded by tanh + squashing in the forward method (as in arXiv 1801.01290, Equation 21). :param action_space: the environment's action_space. :param observation_space: the environment's observation space """ From 65a16e776ea7f120d84892ac84ce78fdaef4d008 Mon Sep 17 00:00:00 2001 From: Dominik Jain Date: Tue, 13 May 2025 01:00:52 +0200 Subject: [PATCH 149/230] Fix determinism test name --- test/discrete/test_iqn.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/discrete/test_iqn.py b/test/discrete/test_iqn.py index 64c7ea885..a1c5fdbec 100644 --- a/test/discrete/test_iqn.py +++ b/test/discrete/test_iqn.py @@ -184,4 +184,4 @@ def test_piqn(args: argparse.Namespace = get_args()) -> None: def test_iqm_determinism() -> None: main_fn = lambda args: test_iqn(args, enable_assertions=False) - AlgorithmDeterminismTest("discrete_iqm", main_fn, get_args()).run() + AlgorithmDeterminismTest("discrete_iqn", main_fn, get_args()).run() From 8e593e53b60f108e451ac3bf36c4fe8084ece88d Mon Sep 17 00:00:00 2001 From: Dominik Jain Date: Tue, 13 May 2025 02:30:29 +0200 Subject: [PATCH 150/230] v2: Fix parameter initialization in AutoAlpha --- tianshou/policy/modelfree/sac.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/tianshou/policy/modelfree/sac.py b/tianshou/policy/modelfree/sac.py index 1ed2c7536..285d22119 100644 --- a/tianshou/policy/modelfree/sac.py +++ b/tianshou/policy/modelfree/sac.py @@ -6,7 +6,6 @@ import numpy as np import torch from torch.distributions import Independent, Normal -from torch.nn import ParameterList from tianshou.data import Batch from tianshou.data.types import ( @@ -206,8 +205,8 @@ def __init__(self, target_entropy: float, log_alpha: float, optim: OptimizerFact """ super().__init__() self._target_entropy = target_entropy - self._log_alpha = torch.tensor(log_alpha, requires_grad=True) - self._optim, lr_scheduler = 
optim.create_instances(ParameterList([self._log_alpha])) + self._log_alpha = torch.nn.Parameter(torch.tensor(log_alpha)) + self._optim, lr_scheduler = optim.create_instances(self) if lr_scheduler is not None: raise ValueError( f"Learning rate schedulers are not supported by {self.__class__.__name__}" From 829be23c6428a5688764b9a2203dff9b2927e34f Mon Sep 17 00:00:00 2001 From: Dominik Jain Date: Tue, 13 May 2025 02:32:24 +0200 Subject: [PATCH 151/230] Fix test name --- test/discrete/test_discrete_sac.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/discrete/test_discrete_sac.py b/test/discrete/test_discrete_sac.py index 7b27d19de..ed1df42f2 100644 --- a/test/discrete/test_discrete_sac.py +++ b/test/discrete/test_discrete_sac.py @@ -159,6 +159,6 @@ def stop_fn(mean_rewards: float) -> bool: assert stop_fn(result.best_reward) -def test_ppo_determinism() -> None: +def test_discrete_sac_determinism() -> None: main_fn = lambda args: test_discrete_sac(args, enable_assertions=False) AlgorithmDeterminismTest("discrete_sac", main_fn, get_args()).run() From 9b403218006c332f82daac60fcf9239b9dadee2a Mon Sep 17 00:00:00 2001 From: Dominik Jain Date: Tue, 13 May 2025 11:30:35 +0200 Subject: [PATCH 152/230] v2: Rename test file --- test/discrete/{test_ppo2.py => test_ppo_discrete.py} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename test/discrete/{test_ppo2.py => test_ppo_discrete.py} (100%) diff --git a/test/discrete/test_ppo2.py b/test/discrete/test_ppo_discrete.py similarity index 100% rename from test/discrete/test_ppo2.py rename to test/discrete/test_ppo_discrete.py From ed723141e59689f63a09cab98fb3af6739c80503 Mon Sep 17 00:00:00 2001 From: Dominik Jain Date: Tue, 13 May 2025 21:06:54 +0200 Subject: [PATCH 153/230] v2: Update docstring --- tianshou/policy/base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tianshou/policy/base.py b/tianshou/policy/base.py index 36a1d3260..69777646e 100644 --- a/tianshou/policy/base.py 
+++ b/tianshou/policy/base.py @@ -232,7 +232,7 @@ def __init__( This flag should normally remain False and should be set to True only by the algorithm which performs training steps. This is done automatically by the Trainer classes. If a policy is used outside of a Trainer, - the user should ensure that this flag is set correctly before calling update or learn. + the user should ensure that this flag is set correctly. """ self._compile() From 583305138f4da5865a2e80469732110895452ffa Mon Sep 17 00:00:00 2001 From: Dominik Jain Date: Wed, 14 May 2025 12:31:22 +0200 Subject: [PATCH 154/230] v2: Fix docstring --- tianshou/utils/lagged_network.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tianshou/utils/lagged_network.py b/tianshou/utils/lagged_network.py index 0a1114b3d..61b3536ec 100644 --- a/tianshou/utils/lagged_network.py +++ b/tianshou/utils/lagged_network.py @@ -23,7 +23,7 @@ class EvalModeModuleWrapper(torch.nn.Module): A wrapper around a torch.nn.Module that forces the module to eval mode. The wrapped module supports only the forward method, attribute access is not supported. - NOTE: It is recommended to support attribute/method access beyond this via `__getattr__`, + NOTE: It is *not* recommended to support attribute/method access beyond this via `__getattr__`, because torch.nn.Module already heavily relies on `__getattr__` to provides its own attribute access. Overriding it naively will cause problems! But it's also not necessary for our use cases; forward is enough. 
From 98789816b0396b2bd55aa14b3d992be8cae3a7ff Mon Sep 17 00:00:00 2001 From: Dominik Jain Date: Wed, 14 May 2025 12:30:01 +0200 Subject: [PATCH 155/230] v2: Rainbow: Do not wrap model_old in EvalModeModuleWrapper to allow NoisyLinear to take effect --- tianshou/policy/modelfree/rainbow.py | 65 ++++++++++++++++++++++++---- 1 file changed, 56 insertions(+), 9 deletions(-) diff --git a/tianshou/policy/modelfree/rainbow.py b/tianshou/policy/modelfree/rainbow.py index 88d572d47..b7e35cb33 100644 --- a/tianshou/policy/modelfree/rainbow.py +++ b/tianshou/policy/modelfree/rainbow.py @@ -3,8 +3,10 @@ from torch import nn from tianshou.data.types import RolloutBatchProtocol -from tianshou.policy.modelfree.c51 import C51 +from tianshou.policy.modelfree.c51 import C51, C51Policy from tianshou.policy.modelfree.pg import LossSequenceTrainingStats +from tianshou.policy.optim import OptimizerFactory +from tianshou.utils.lagged_network import EvalModeModuleWrapper from tianshou.utils.net.discrete import NoisyLinear @@ -14,13 +16,60 @@ class RainbowTrainingStats: class RainbowDQN(C51): - """Implementation of Rainbow DQN. arXiv:1710.02298. + """Implementation of Rainbow DQN. arXiv:1710.02298.""" - .. seealso:: + def __init__( + self, + *, + policy: C51Policy, + optim: OptimizerFactory, + gamma: float = 0.99, + estimation_step: int = 1, + target_update_freq: int = 0, + ) -> None: + """ + :param policy: a policy following the rules (s -> action_values_BA) + :param optim: the optimizer factory for the policy's model. + :param gamma: the discount factor in [0, 1] for future rewards. + This determines how much future rewards are valued compared to immediate ones. + Lower values (closer to 0) make the agent focus on immediate rewards, creating "myopic" + behavior. 
Higher values (closer to 1) make the agent value long-term rewards more, + potentially improving performance in tasks where delayed rewards are important but + increasing training variance by incorporating more environmental stochasticity. + Typically set between 0.9 and 0.99 for most reinforcement learning tasks + :param estimation_step: the number of future steps (> 0) to consider when computing temporal + difference (TD) targets. Controls the balance between TD learning and Monte Carlo methods: + higher values reduce bias (by relying less on potentially inaccurate value estimates) + but increase variance (by incorporating more environmental stochasticity and reducing + the averaging effect). A value of 1 corresponds to standard TD learning with immediate + bootstrapping, while very large values approach Monte Carlo-like estimation that uses + complete episode returns. + :param target_update_freq: the number of training iterations between each complete update of + the target network. + Controls how frequently the target Q-network parameters are updated with the current + Q-network values. + A value of 0 disables the target network entirely, using only a single network for both + action selection and bootstrap targets. + Higher values provide more stable learning targets but slow down the propagation of new + value estimates. Lower positive values allow faster learning but may lead to instability + due to rapidly changing targets. + Typically set between 100-10000 for DQN variants, with exact values depending on environment + complexity. + """ + super().__init__( + policy=policy, + optim=optim, + gamma=gamma, + estimation_step=estimation_step, + target_update_freq=target_update_freq, + ) - Please refer to :class:`~tianshou.policy.C51Policy` for more detailed - explanation. - """ + # Remove the wrapper that forces eval mode for the target network, + # because Rainbow requires it to be set to train mode for sampling noise + # in NoisyLinear layers to take effect. 
+ if self.use_target_network: + assert isinstance(self.model_old, EvalModeModuleWrapper) + self.model_old = self.model_old.module @staticmethod def _sample_noise(model: nn.Module) -> bool: @@ -45,7 +94,5 @@ def _update_with_batch( ) -> LossSequenceTrainingStats: self._sample_noise(self.policy.model) if self.use_target_network: - assert self.model_old is not None - if self._sample_noise(self.model_old): - self.model_old.train() # so that NoisyLinear takes effect + self._sample_noise(self.model_old) return super()._update_with_batch(batch) From 66c0c866ed649067cd98eeedca7bc9fc84406f8f Mon Sep 17 00:00:00 2001 From: Dominik Jain Date: Wed, 14 May 2025 21:39:42 +0200 Subject: [PATCH 156/230] v2: Fix LaggedNetworkCollection.full_parameter_update forcing target network to eval mode Achieve this by copying only the parameters (similar to Polyak update case) instead of loading the other model's state dict, which is not strictly what the parameter update was supposed to achieve. Note: While forcing a target network to eval mode was appropriate in most cases, RainbowDQN is a notable exception where a target network is required to be in training mode. 
--- tianshou/utils/lagged_network.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/tianshou/utils/lagged_network.py b/tianshou/utils/lagged_network.py index 61b3536ec..2565d981d 100644 --- a/tianshou/utils/lagged_network.py +++ b/tianshou/utils/lagged_network.py @@ -81,5 +81,7 @@ def polyak_parameter_update(self, tau: float) -> None: def full_parameter_update(self) -> None: """Fully updates the target networks with the source networks' parameters (exact copy).""" for pair in self._lagged_network_pairs: - pair.target.load_state_dict(pair.source.state_dict()) - pair.target.eval() + for tgt_param, src_param in zip( + pair.target.parameters(), pair.source.parameters(), strict=True + ): + tgt_param.data.copy_(src_param.data) From 40381ebd2d8aa0af1d703c17579a0eca759d3b07 Mon Sep 17 00:00:00 2001 From: Dominik Jain Date: Wed, 14 May 2025 21:41:06 +0200 Subject: [PATCH 157/230] v2: NoisyLinear: Treat the noise parameters as parameters instead of buffers such that they will be copied with a lagged network parameter update (target network handling) This is a relevant change in order for RainbowDQN to produce exactly the same behaviour as in v1. --- tianshou/utils/net/discrete.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tianshou/utils/net/discrete.py b/tianshou/utils/net/discrete.py index 8da022c7e..2d3da153d 100644 --- a/tianshou/utils/net/discrete.py +++ b/tianshou/utils/net/discrete.py @@ -328,8 +328,8 @@ def __init__(self, in_features: int, out_features: int, noisy_std: float = 0.5) self.sigma_bias = nn.Parameter(torch.FloatTensor(out_features)) # Factorized noise parameters. 
- self.register_buffer("eps_p", torch.FloatTensor(in_features)) - self.register_buffer("eps_q", torch.FloatTensor(out_features)) + self.eps_p = nn.Parameter(torch.FloatTensor(in_features), requires_grad=False) + self.eps_q = nn.Parameter(torch.FloatTensor(out_features), requires_grad=False) self.in_features = in_features self.out_features = out_features From aa23efb56f49a402f8ce214d968072dcf944504e Mon Sep 17 00:00:00 2001 From: Michael Panchenko Date: Wed, 14 May 2025 23:09:32 +0200 Subject: [PATCH 158/230] v2: typing, use EvalModeModuleWrapper in annotations --- tianshou/policy/base.py | 3 ++- tianshou/policy/imitation/discrete_crr.py | 3 +++ tianshou/policy/modelfree/dqn.py | 4 +++- tianshou/policy/modelfree/rainbow.py | 2 ++ tianshou/utils/lagged_network.py | 2 +- 5 files changed, 11 insertions(+), 3 deletions(-) diff --git a/tianshou/policy/base.py b/tianshou/policy/base.py index 69777646e..0786df5c6 100644 --- a/tianshou/policy/base.py +++ b/tianshou/policy/base.py @@ -32,6 +32,7 @@ from tianshou.policy.optim import OptimizerFactory from tianshou.utils.determinism import TraceLogger from tianshou.utils.lagged_network import ( + EvalModeModuleWrapper, LaggedNetworkCollection, ) from tianshou.utils.net.common import RandomActor @@ -420,7 +421,7 @@ class LaggedNetworkAlgorithmMixin(ABC): def __init__(self) -> None: self._lagged_networks = LaggedNetworkCollection() - def _add_lagged_network(self, src: torch.nn.Module) -> torch.nn.Module: + def _add_lagged_network(self, src: torch.nn.Module) -> EvalModeModuleWrapper: """ Adds a lagged network to the collection, returning the target network, which is forced to eval mode. 
The target network is a copy of the source network, diff --git a/tianshou/policy/imitation/discrete_crr.py b/tianshou/policy/imitation/discrete_crr.py index 759702e99..f80f5d70c 100644 --- a/tianshou/policy/imitation/discrete_crr.py +++ b/tianshou/policy/imitation/discrete_crr.py @@ -19,6 +19,7 @@ SimpleLossTrainingStats, ) from tianshou.policy.optim import OptimizerFactory +from tianshou.utils.lagged_network import EvalModeModuleWrapper from tianshou.utils.net.discrete import DiscreteCritic @@ -96,6 +97,8 @@ def __init__( self._target = target_update_freq > 0 self._freq = target_update_freq self._iter = 0 + self.actor_old: torch.nn.Module | torch.Tensor | EvalModeModuleWrapper + self.critic_old: torch.nn.Module | EvalModeModuleWrapper if self._target: self.actor_old = self._add_lagged_network(self.policy.actor) self.critic_old = self._add_lagged_network(self.critic) diff --git a/tianshou/policy/modelfree/dqn.py b/tianshou/policy/modelfree/dqn.py index 504ec5b80..ca763a604 100644 --- a/tianshou/policy/modelfree/dqn.py +++ b/tianshou/policy/modelfree/dqn.py @@ -6,6 +6,7 @@ import numpy as np import torch from sensai.util.helper import mark_used +from torch import nn from tianshou.data import Batch, ReplayBuffer, to_numpy, to_torch_as from tianshou.data.batch import BatchProtocol @@ -27,6 +28,7 @@ SimpleLossTrainingStats, ) from tianshou.policy.optim import OptimizerFactory +from tianshou.utils.lagged_network import EvalModeModuleWrapper from tianshou.utils.net.common import Net mark_used(ActBatchProtocol) @@ -233,7 +235,7 @@ def __init__( self.target_update_freq = target_update_freq # TODO: 1 would be a more reasonable initialization given how it is incremented self._iter = 0 - self.model_old = ( + self.model_old: EvalModeModuleWrapper | nn.Module | None = ( self._add_lagged_network(self.policy.model) if self.use_target_network else None ) diff --git a/tianshou/policy/modelfree/rainbow.py b/tianshou/policy/modelfree/rainbow.py index b7e35cb33..c1e1d9039 100644 --- 
a/tianshou/policy/modelfree/rainbow.py +++ b/tianshou/policy/modelfree/rainbow.py @@ -64,6 +64,7 @@ def __init__( target_update_freq=target_update_freq, ) + self.model_old: nn.Module | None # tighten type, see below # Remove the wrapper that forces eval mode for the target network, # because Rainbow requires it to be set to train mode for sampling noise # in NoisyLinear layers to take effect. @@ -94,5 +95,6 @@ def _update_with_batch( ) -> LossSequenceTrainingStats: self._sample_noise(self.policy.model) if self.use_target_network: + assert self.model_old is not None self._sample_noise(self.model_old) return super()._update_with_batch(batch) diff --git a/tianshou/utils/lagged_network.py b/tianshou/utils/lagged_network.py index 2565d981d..3f5146580 100644 --- a/tianshou/utils/lagged_network.py +++ b/tianshou/utils/lagged_network.py @@ -54,7 +54,7 @@ class LaggedNetworkCollection: def __init__(self) -> None: self._lagged_network_pairs: list[LaggedNetworkPair] = [] - def add_lagged_network(self, source: torch.nn.Module) -> torch.nn.Module: + def add_lagged_network(self, source: torch.nn.Module) -> EvalModeModuleWrapper: """ Adds a lagged network to the collection, returning the target network, which is forced to eval mode. 
The target network is a copy of the source network, From 4f9600af4392e9586ee6d66f8ec95c9a2d597b1a Mon Sep 17 00:00:00 2001 From: Michael Panchenko Date: Thu, 15 May 2025 00:05:27 +0200 Subject: [PATCH 159/230] v2: minor, typing --- tianshou/policy/modelfree/dqn.py | 3 +-- tianshou/policy/modelfree/rainbow.py | 3 ++- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/tianshou/policy/modelfree/dqn.py b/tianshou/policy/modelfree/dqn.py index ca763a604..c8fd70907 100644 --- a/tianshou/policy/modelfree/dqn.py +++ b/tianshou/policy/modelfree/dqn.py @@ -6,7 +6,6 @@ import numpy as np import torch from sensai.util.helper import mark_used -from torch import nn from tianshou.data import Batch, ReplayBuffer, to_numpy, to_torch_as from tianshou.data.batch import BatchProtocol @@ -235,7 +234,7 @@ def __init__( self.target_update_freq = target_update_freq # TODO: 1 would be a more reasonable initialization given how it is incremented self._iter = 0 - self.model_old: EvalModeModuleWrapper | nn.Module | None = ( + self.model_old: EvalModeModuleWrapper | None = ( self._add_lagged_network(self.policy.model) if self.use_target_network else None ) diff --git a/tianshou/policy/modelfree/rainbow.py b/tianshou/policy/modelfree/rainbow.py index c1e1d9039..501330e45 100644 --- a/tianshou/policy/modelfree/rainbow.py +++ b/tianshou/policy/modelfree/rainbow.py @@ -64,10 +64,11 @@ def __init__( target_update_freq=target_update_freq, ) - self.model_old: nn.Module | None # tighten type, see below + self.model_old: nn.Module | None # type: ignore[assignment] # Remove the wrapper that forces eval mode for the target network, # because Rainbow requires it to be set to train mode for sampling noise # in NoisyLinear layers to take effect. 
+ # (minor violation of Liskov Substitution Principle) if self.use_target_network: assert isinstance(self.model_old, EvalModeModuleWrapper) self.model_old = self.model_old.module From 7e2ab505c4a64c9b6f04e6cb2231e33684c6a545 Mon Sep 17 00:00:00 2001 From: Dominik Jain Date: Thu, 15 May 2025 12:11:37 +0200 Subject: [PATCH 160/230] v2: A2C: Fix gradient step counter not being incremented --- tianshou/policy/modelfree/a2c.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tianshou/policy/modelfree/a2c.py b/tianshou/policy/modelfree/a2c.py index 5cd09d7fd..c34d240bf 100644 --- a/tianshou/policy/modelfree/a2c.py +++ b/tianshou/policy/modelfree/a2c.py @@ -255,7 +255,7 @@ def _update_with_batch( # type: ignore[override] gradient_steps = 0 for _ in range(repeat): for minibatch in batch.split(split_batch_size, merge_last=True): - gradient_steps = 0 + gradient_steps += 1 # calculate loss for actor dist = self.policy(minibatch).dist From 8b3002882578580c2322df04b1a8c4a7694aa6a5 Mon Sep 17 00:00:00 2001 From: Dominik Jain Date: Thu, 15 May 2025 12:11:46 +0200 Subject: [PATCH 161/230] v2: Fix docstring --- tianshou/policy/base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tianshou/policy/base.py b/tianshou/policy/base.py index 5975bbf93..44e27c07a 100644 --- a/tianshou/policy/base.py +++ b/tianshou/policy/base.py @@ -428,7 +428,7 @@ def _add_lagged_network(self, src: torch.nn.Module) -> EvalModeModuleWrapper: which, however, supports only the forward method (hence the type torch.nn.Module); attribute access is not supported. 
- :param source: the source network whose parameters are to be copied to the target network + :param src: the source network whose parameters are to be copied to the target network :return: the target network, which supports only the forward method and is forced to eval mode """ return self._lagged_networks.add_lagged_network(src) From 87c7fb5f74b2e0687ed51f56adaf8daba17cbf2d Mon Sep 17 00:00:00 2001 From: Dominik Jain Date: Thu, 15 May 2025 12:35:29 +0200 Subject: [PATCH 162/230] v2: Improve change log --- CHANGELOG.md | 390 +++++++++++++++++++++++++++------------------------ 1 file changed, 206 insertions(+), 184 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index bcffae429..129f75773 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,187 +1,209 @@ -# Changelog - -## Release 2.0.0 - -* `Trainer` abstraction (formerly `BaseTrainer`): - * The trainer logic and configuration is now properly separated between the three cases of on-policy, off-policy - and offline learning: The base class is no longer a "God" class which does it all; logic and functionality has moved - to the respective subclasses (`OnPolicyTrainer`, `OffPolicyTrainer` and `OfflineTrainer`, with `OnlineTrainer` - being introduced as a base class for the two former specialisations). - * The trainers now use configuration objects with central documentation (which has been greatly improved to enhance - clarity and usability in general); every type of trainer now has a dedicated configuration class which provides - precisely the options that are applicable. - * The interface has been streamlined with improved naming of functions/parameters and limiting the public interface to purely - the methods and attributes a user should reasonably access. - * Further changes potentially affecting usage: - * We dropped the iterator semantics: Method `__next__` has been replaced by `execute_epoch`. - * We no longer report outdated statistics (e.g. 
on rewards/returns when a training step does not collect any full - episodes) - * See also "Issues resolved" below (as issue resolution can result in usage changes) - * The default value for `test_in_train` was changed from True to False (updating all usage sites to explicitly - set the parameter), because False is the more natural default, which does not make assumptions about - returns/score values computed for the data from a collection step being at all meaningful for early stopping - * The management of episolon-greedy exploration for discrete Q-learning algorithms has been simplified: - * All respective Policy implementations (e.g. `DQNPolicy`, `C51Policy`, etc.) now accept two parameters - `eps_training` and `eps_inference`, which allows the training and test collection cases to be sufficiently - differentiated and makes the use of callback functions (`train_fn`, `test_fn`) unnecessary if only - constants are to be set. - * The setter method `set_eps` has been replaced with `set_eps_training` and `set_eps_inference` accordingly. - * Further internal changes unlikely to affect usage: - * Module `trainer.utils` was removed and the functions therein where moved to class `Trainer` - * The two places that collected and evaluated test episodes (`_test_in_train` and `_reset`) in addition to - `_test_step` were unified to use `_test_step` (with some minor parametrisation) and now log the results - of the test step accordingly. - * Issues resolved: - * Methods `run` and `reset`: Parameter `reset_prior_to_run` of `run` was never respected if it was set to `False`, - because the implementation of `__iter__` (now removed) would call `reset` regardless - and calling `reset` - is indeed necessary, because it initializes the training. The parameter was removed and replaced by - `reset_collectors` (such that `run` now replicates the parameters of `reset`). 
- * Inconsistent configuration options now raise exceptions rather than silently ignoring the issue in the - hope that default behaviour will achieve what the user intended. - One condition where `test_in_train` was silently set to `False` was removed and replaced by a warning. - * The stop criterion `stop_fn` did not consider scores as computed by `compute_score_fn` but instead always used - mean returns (i.e. it was assumed that the default implementation of `compute_score_fn` applies). - This is an inconsistency which has been resolved. - * The `gradient_step` counter was flawed (as it made assumptions about the underlying algorithms, which were - not valid). It has been replaced with an update step counter. - Members of `InfoStats` and parameters of `Logger` (and subclasses) were changed accordingly. - * Migration information at a glance: - * Training parameters are now passed via instances of configuration objects instead of directly as keyword arguments: - `OnPolicyTrainerParams`, `OffPolicyTrainerParams`, `OfflineTrainerParams`. - * Changed parameter default: Default for `test_in_train` was changed from True to False. - * Trainer classes have been renamed: - * `OnpolicyTrainer` -> `OnPolicyTrainer` - * `OffpolicyTrainer` -> `OffPolicyTrainer` - * Method `run`: The parameter `reset_prior_to_run` was removed and replaced by `reset_collectors` (see above). - * Methods `run` and `reset`: The parameter `reset_buffer` was renamed to `reset_collector_buffers` for clarity - * Trainers are no longer iterators; manual usage (not using `run`) should simply call `reset` followed by - calls of `execute_epoch`. -* `Policy` and `Algorithm` abstractions (formerly unified in `BasePolicy`): - * We now conceptually differentiate between the learning algorithm and the policy being optimised: - * The abstraction `BasePolicy` is thus replaced by `Algorithm` and `Policy`. 
- Migration information: The instantiation of a policy is replaced by the instantiation of an `Algorithm`, - which is passed a `Policy`. In most cases, the former policy class name `Policy` is replaced by algorithm - class ``; exceptions are noted below. - * `ImitationPolicy` -> `OffPolicyImitationLearning`, `OfflineImitationLearning` - * `PGPolicy` -> `Reinforce` - * `MultiAgentPolicyManager` -> `MultiAgentOnPolicyAlgorithm`, `MultiAgentOffPolicyAlgorithm` - * `MARLRandomPolicy` -> `MARLRandomDiscreteMaskedOffPolicyAlgorithm` - For the respective subtype of `Policy` to use, see the respective algorithm class' constructor. - * Interface changes/improvements: - * Core methods have been renamed (and removed from the public interface): - * `process_fn` -> `_preprocess_batch` - * `post_process_fn` -> `_postprocess_batch` - * `learn` -> `_update_with_batch` - * The updating interface has been cleaned up: - * Functions `update` and `_update_with_batch` (formerly `learn`) no longer have `*args` and `**kwargs`. - * Instead, the interfaces for the offline, off-policy and on-policy cases are properly differentiated. - * New method `run_training`: The `Algorithm` abstraction can now directly initiate the learning process via this method. - * `Algorithms` no longer require `torch.optim.Optimizer` instances and instead require `OptimizerFactory` - instances, which create the actual optimizers internally. - The new `OptimizerFactory` abstraction simultaneously handles the creation of learning rate schedulers - for the optimizers created (via method `with_lr_scheduler_factory` and accompanying factory abstraction - `LRSchedulerFactory`). - The parameter `lr_scheduler` has thus been removed from all algorithm constructors. - * The flag `updating` has been removed (no internal usage, general usefulness questionable). 
- * Parameter changes: - * `discount_factor` -> `gamma` (was already used internally almost everywhere) - * `reward_normalization` -> `return_standardization` or `return_scaling` (more precise naming) or removed (was actually unsupported by Q-learning algorithms) - * `return_standardization` in `Reinforce` and `DiscreteCRR` (as it applies standardization of returns) - * `return_scaling` in actor-critic on-policy algorithms (A2C, PPO, GAIL, NPG, TRPO) - * removed from Q-learning algorithms, where it was actually unsupported (DQN, C561, etc.) - * `clip_grad` -> `max_grad_norm` (for consistency) - * Internal design improvements: - * Introduced an abstraction for the alpha parameter (coefficient of the entropy term) - in `SAC`, `DiscreteSAC` and other algorithms. - * Class hierarchy: - * Abstract base class `Alpha` base class with value property and update method - * `FixedAlpha` for constant entropy coefficients - * `AutoAlpha` for automatic entropy tuning (replaces the old tuple-based representation) - * The (auto-)updating logic is now completely encapsulated, reducing the complexity of the algorithms. - * Implementations for continuous and discrete cases now share the same abstraction, - making the codebase more consistent while preserving the original functionality. - * Introduced a policy base class `ContinuousPolicyWithExplorationNoise` which encapsulates noise generation - for continuous action spaces (e.g. relevant to `DDPG`, `SAC` and `REDQ`). - * Multi-agent RL methods are now differentiated by the type of the sub-algorithms being employed - (`MultiAgentOnPolicyAlgorithm`, `MultiAgentOffPolicyAlgorithm`), which renders all interfaces clean. - Helper class `MARLDispatcher` has been factored out to manage the dispatching of data to the respective agents. - * Algorithms now internally use a wrapper (`Algorithm.Optimizer`) around the optimizers; creation is handled - by method `_create_optimizer`. 
- * This facilitates backpropagation steps with gradient clipping. - * The optimizers of an Algorithm instance are now centrally tracked, such that we can ensure that the - optimizers' states are handled alongside the model parameters when calling `state_dict` or `load_state_dict` - on the `Algorithm` instance. - Special handling of the restoration of optimizers' state dicts was thus removed from examples and tests. - * Fixed issues in the class hierarchy (particularly critical violations of the Liskov substitution principle): - * Introduced base classes (to retain factorization without abusive inheritance): - * `ActorCriticOnPolicyAlgorithm` - * `ActorCriticOffPolicyAlgorithm` - * `ActorDualCriticsOffPolicyAlgorithm` (extends `ActorCriticOffPolicyAlgorithm`) - * `QLearningOffPolicyAlgorithm` - * `A2C`: Inherit from `ActorCriticOnPolicyAlgorithm` instead of `Reinforce` - * `BDQN`: - * Inherit from `QLearningOffPolicyAlgorithm` instead of `DQN` - * Remove parameter `clip_loss_grad` (unused; only passed on to former base class) - * `C51`: - * Inherit from `QLearningOffPolicyAlgorithm` instead of `DQN` - * Remove parameters `clip_loss_grad` and `is_double` (unused; only passed on to former base class) - * `CQL`: - * Inherit directly from `OfflineAlgorithm` instead of `SAC` (off-policy). - * Remove parameter `estimation_step`, which was not actually used (it was only passed it on to its - superclass). - * `DiscreteBCQ`: - * Inherit directly from `OfflineAlgorithm` instead of `DQN` - * Remove unused parameters `clip_loss_grad` and `is_double`, which were only passed on to - former the base class but actually unused. - * `DiscreteCQL`: Remove unused parameters `clip_loss_grad` and `is_double`, which were only passed on to - base class `QRDQN` (and unused by it). 
- * `DiscreteCRR`: Inherit directly from `OfflineAlgorithm` instead of `Reinforce` (on-policy) - * `FQF`: Remove unused parameters `clip_loss_grad` and `is_double`, which were only passed on to - base class `QRDQN` (and unused by it). - * `IQN`: Remove unused parameters `clip_loss_grad` and `is_double`, which were only passed on to - base class `QRDQN` (and unused by it). - * `NPG`: Inherit from `ActorCriticOnPolicyAlgorithm` instead of `A2C` - * `QRDQN`: - * Inherit from `QLearningOffPolicyAlgorithm` instead of `DQN` - * Remove parameters `clip_loss_grad` and `is_double` (unused; only passed on to former base class) - * `REDQ`: Inherit from `ActorCriticOffPolicyAlgorithm` instead of `DDPG` - * `SAC`: Inherit from `ActorDualCriticsOffPolicyAlgorithm` instead of `DDPG` - * `TD3`: Inherit from `ActorDualCriticsOffPolicyAlgorithm` instead of `DDPG` -* High-Level API changes: - * Detailed optimizer configuration (analogous to the procedural API) is now possible: - * All optimizers can be configured in the respective algorithm-specific `Params` object by using - `OptimizerFactoryFactory` instances as parameter values (e.g. for `optim`, `actor_optim`, `critic_optim`, etc.). - * Learning rate schedulers remain separate parameters and now use `LRSchedulerFactoryFactory` - instances. The respective parameter names now use the suffix `lr_scheduler` instead of `lr_scheduler_factory` - (as the precise nature need not be reflected in the name; brevity is preferable). - * `SamplingConfig` is replaced by `TrainingConfig` and subclasses differentiating off-policy and on-policy cases - appropriately (`OnPolicyTrainingConfig`, `OffPolicyTrainingConfig`). - * The `test_in_train` parameter is now exposed (default False). - * Inapplicable arguments can no longer be set in the respective subclass (e.g. `OffPolicyTrainingConfig` does not - contain parameter `repeat_per_collect`). 
-* Peripheral changes: - * The `Actor` classes have been renamed for clarity: - * `BaseActor` -> `Actor` - * `continuous.ActorProb` -> `ContinuousActorProb` - * `coninuous.Actor` -> `ContinuousActorDeterministic` - * `discrete.Actor` -> `DiscreteActor` - * The `Critic` classes have been renamed for clarity: - * `continuous.Critic` -> `ContinuousCritic` - * `discrete.Critic` -> `DiscreteCritic` - * Moved Atari helper modules `atari_network` and `atari_wrapper` to the library under `tianshou.env.atari`. - * Fix issues pertaining to the torch device assignment of network components (#810): - * Remove 'device' member (and the corresponding constructor argument) from the following classes: - `BranchingNet`, `C51Net`, `ContinuousActorDeterministic`, `ContinuousActorProb`, `ContinuousCritic`, - `DiscreteActor`, `DiscreteCritic`, `DQNet`, `FullQuantileFunction`, `ImplicitQuantileNetwork`, - `IntrinsicCuriosityModule`, `Net`, `MLP`, `Perturbation`, `QRDQNet`, `Rainbow`, `Recurrent`, - `RecurrentActorProb`, `RecurrentCritic`, `VAE` - * (Peripheral change:) Require the use of keyword arguments for the constructors of all of these classes - * Clean up handling of modules that define attribute `output_dim`, introducing the explicit base class - `ModuleWithVectorOutput` - * Interfaces where one could specify either a module with `output_dim` or additionally provide the output - dimension as an argument were changed to use `ModuleWithVectorOutput`. - * The high-level API class `IntermediateModule` can now provide a `ModuleWithVectorOutput` instance - (via adaptation if necessary). +# Change Log + +## Upcoming Release 2.0.0 + +This major release of Tianshou is a big step towards cleaner design and improved usability. + +Given the large extent of the changes, it was not possible to maintain compatibility with the previous version. + * Persisted agents that were created with earlier versions cannot be loaded in v2. 
+ * Source code from v1 can, however, be migrated to v2 with minimal effort. + See migration information below. For concrete examples, you may use git to diff individual + example scripts with the corresponding ones in `v1.2.0`. + +This release is brought to you by [Applied AI Institute gGmbH](https://www.appliedai-institute.de). + +Developers: + * Dr. Dominik Jain (@opcode81) + * Michael Panchenko (@MischaPanch) + +### Trainer Abstraction + +* The trainer logic and configuration is now properly separated between the three cases of on-policy, off-policy + and offline learning: The base class is no longer a "God" class (formerly `BaseTrainer`) which does it all; logic and functionality has moved + to the respective subclasses (`OnPolicyTrainer`, `OffPolicyTrainer` and `OfflineTrainer`, with `OnlineTrainer` + being introduced as a base class for the two former specialisations). +* The trainers now use configuration objects with central documentation (which has been greatly improved to enhance + clarity and usability in general); every type of trainer now has a dedicated configuration class which provides + precisely the options that are applicable. +* The interface has been streamlined with improved naming of functions/parameters and limiting the public interface to purely + the methods and attributes a user should reasonably access. +* Further changes potentially affecting usage: + * We dropped the iterator semantics: Method `__next__` has been replaced by `execute_epoch`. + * We no longer report outdated statistics (e.g. 
on rewards/returns when a training step does not collect any full + episodes) + * See also "Issues resolved" below (as issue resolution can result in usage changes) + * The default value for `test_in_train` was changed from True to False (updating all usage sites to explicitly + set the parameter), because False is the more natural default, which does not make assumptions about + returns/score values computed for the data from a collection step being at all meaningful for early stopping + * The management of epsilon-greedy exploration for discrete Q-learning algorithms has been simplified: + * All respective Policy implementations (e.g. `DQNPolicy`, `C51Policy`, etc.) now accept two parameters + `eps_training` and `eps_inference`, which allows the training and test collection cases to be sufficiently + differentiated and makes the use of callback functions (`train_fn`, `test_fn`) unnecessary if only + constants are to be set. + * The setter method `set_eps` has been replaced with `set_eps_training` and `set_eps_inference` accordingly. +* Further internal changes unlikely to affect usage: + * Module `trainer.utils` was removed and the functions therein were moved to class `Trainer` + * The two places that collected and evaluated test episodes (`_test_in_train` and `_reset`) in addition to + `_test_step` were unified to use `_test_step` (with some minor parametrisation) and now log the results + of the test step accordingly. +* Issues resolved: + * Methods `run` and `reset`: Parameter `reset_prior_to_run` of `run` was never respected if it was set to `False`, + because the implementation of `__iter__` (now removed) would call `reset` regardless - and calling `reset` + is indeed necessary, because it initializes the training. The parameter was removed and replaced by + `reset_collectors` (such that `run` now replicates the parameters of `reset`).
+ * Inconsistent configuration options now raise exceptions rather than silently ignoring the issue in the + hope that default behaviour will achieve what the user intended. + One condition where `test_in_train` was silently set to `False` was removed and replaced by a warning. + * The stop criterion `stop_fn` did not consider scores as computed by `compute_score_fn` but instead always used + mean returns (i.e. it was assumed that the default implementation of `compute_score_fn` applies). + This is an inconsistency which has been resolved. + * The `gradient_step` counter was flawed (as it made assumptions about the underlying algorithms, which were + not valid). It has been replaced with an update step counter. + Members of `InfoStats` and parameters of `Logger` (and subclasses) were changed accordingly. +* Migration information at a glance: + * Training parameters are now passed via instances of configuration objects instead of directly as keyword arguments: + `OnPolicyTrainerParams`, `OffPolicyTrainerParams`, `OfflineTrainerParams`. + * Changed parameter default: Default for `test_in_train` was changed from True to False. + * Trainer classes have been renamed: + * `OnpolicyTrainer` -> `OnPolicyTrainer` + * `OffpolicyTrainer` -> `OffPolicyTrainer` + * Method `run`: The parameter `reset_prior_to_run` was removed and replaced by `reset_collectors` (see above). + * Methods `run` and `reset`: The parameter `reset_buffer` was renamed to `reset_collector_buffers` for clarity + * Trainers are no longer iterators; manual usage (not using `run`) should simply call `reset` followed by + calls of `execute_epoch`. + +### Algorithms and Policies + +* We now conceptually differentiate between the learning algorithm and the policy being optimised: + * The abstraction `BasePolicy` is thus replaced by `Algorithm` and `Policy`. + * Migration information: The instantiation of a policy is replaced by the instantiation of an `Algorithm`, + which is passed a `Policy`. 
In most cases, the former policy class name `<Name>Policy` is replaced by algorithm + class `<Name>`; exceptions are noted below. + * `ImitationPolicy` -> `OffPolicyImitationLearning`, `OfflineImitationLearning` + * `PGPolicy` -> `Reinforce` + * `MultiAgentPolicyManager` -> `MultiAgentOnPolicyAlgorithm`, `MultiAgentOffPolicyAlgorithm` + * `MARLRandomPolicy` -> `MARLRandomDiscreteMaskedOffPolicyAlgorithm` + For the respective subtype of `Policy` to use, see the respective algorithm class' constructor. +* Interface changes/improvements: + * Core methods have been renamed (and removed from the public interface): + * `process_fn` -> `_preprocess_batch` + * `post_process_fn` -> `_postprocess_batch` + * `learn` -> `_update_with_batch` + * The updating interface has been cleaned up: + * Functions `update` and `_update_with_batch` (formerly `learn`) no longer have `*args` and `**kwargs`. + * Instead, the interfaces for the offline, off-policy and on-policy cases are properly differentiated. + * New method `run_training`: The `Algorithm` abstraction can now directly initiate the learning process via this method. + * `Algorithms` no longer require `torch.optim.Optimizer` instances and instead require `OptimizerFactory` + instances, which create the actual optimizers internally. + The new `OptimizerFactory` abstraction simultaneously handles the creation of learning rate schedulers + for the optimizers created (via method `with_lr_scheduler_factory` and accompanying factory abstraction + `LRSchedulerFactory`). + The parameter `lr_scheduler` has thus been removed from all algorithm constructors. + * The flag `updating` has been removed (no internal usage, general usefulness questionable).
+ * Parameter changes: + * `discount_factor` -> `gamma` (was already used internally almost everywhere) + * `reward_normalization` -> `return_standardization` or `return_scaling` (more precise naming) or removed (was actually unsupported by Q-learning algorithms) + * `return_standardization` in `Reinforce` and `DiscreteCRR` (as it applies standardization of returns) + * `return_scaling` in actor-critic on-policy algorithms (A2C, PPO, GAIL, NPG, TRPO) + * removed from Q-learning algorithms, where it was actually unsupported (DQN, C51, etc.) + * `clip_grad` -> `max_grad_norm` (for consistency) +* Internal design improvements: + * Introduced an abstraction for the alpha parameter (coefficient of the entropy term) + in `SAC`, `DiscreteSAC` and other algorithms. + * Class hierarchy: + * Abstract base class `Alpha` with value property and update method + * `FixedAlpha` for constant entropy coefficients + * `AutoAlpha` for automatic entropy tuning (replaces the old tuple-based representation) + * The (auto-)updating logic is now completely encapsulated, reducing the complexity of the algorithms. + * Implementations for continuous and discrete cases now share the same abstraction, + making the codebase more consistent while preserving the original functionality. + * Introduced a policy base class `ContinuousPolicyWithExplorationNoise` which encapsulates noise generation + for continuous action spaces (e.g. relevant to `DDPG`, `SAC` and `REDQ`). + * Multi-agent RL methods are now differentiated by the type of the sub-algorithms being employed + (`MultiAgentOnPolicyAlgorithm`, `MultiAgentOffPolicyAlgorithm`), which renders all interfaces clean. + Helper class `MARLDispatcher` has been factored out to manage the dispatching of data to the respective agents. + * Algorithms now internally use a wrapper (`Algorithm.Optimizer`) around the optimizers; creation is handled + by method `_create_optimizer`. + * This facilitates backpropagation steps with gradient clipping.
+ * The optimizers of an Algorithm instance are now centrally tracked, such that we can ensure that the + optimizers' states are handled alongside the model parameters when calling `state_dict` or `load_state_dict` + on the `Algorithm` instance. + Special handling of the restoration of optimizers' state dicts was thus removed from examples and tests. +* Fixed issues in the class hierarchy (particularly critical violations of the Liskov substitution principle): + * Introduced base classes (to retain factorization without abusive inheritance): + * `ActorCriticOnPolicyAlgorithm` + * `ActorCriticOffPolicyAlgorithm` + * `ActorDualCriticsOffPolicyAlgorithm` (extends `ActorCriticOffPolicyAlgorithm`) + * `QLearningOffPolicyAlgorithm` + * `A2C`: Inherit from `ActorCriticOnPolicyAlgorithm` instead of `Reinforce` + * `BDQN`: + * Inherit from `QLearningOffPolicyAlgorithm` instead of `DQN` + * Remove parameter `clip_loss_grad` (unused; only passed on to former base class) + * `C51`: + * Inherit from `QLearningOffPolicyAlgorithm` instead of `DQN` + * Remove parameters `clip_loss_grad` and `is_double` (unused; only passed on to former base class) + * `CQL`: + * Inherit directly from `OfflineAlgorithm` instead of `SAC` (off-policy). + * Remove parameter `estimation_step`, which was not actually used (it was only passed on to its + superclass). + * `DiscreteBCQ`: + * Inherit directly from `OfflineAlgorithm` instead of `DQN` + * Remove unused parameters `clip_loss_grad` and `is_double`, which were only passed on to + the former base class, which did not use them. + * `DiscreteCQL`: Remove unused parameters `clip_loss_grad` and `is_double`, which were only passed on to + base class `QRDQN` (and unused by it). + * `DiscreteCRR`: Inherit directly from `OfflineAlgorithm` instead of `Reinforce` (on-policy) + * `FQF`: Remove unused parameters `clip_loss_grad` and `is_double`, which were only passed on to + base class `QRDQN` (and unused by it).
+ * `IQN`: Remove unused parameters `clip_loss_grad` and `is_double`, which were only passed on to + base class `QRDQN` (and unused by it). + * `NPG`: Inherit from `ActorCriticOnPolicyAlgorithm` instead of `A2C` + * `QRDQN`: + * Inherit from `QLearningOffPolicyAlgorithm` instead of `DQN` + * Remove parameters `clip_loss_grad` and `is_double` (unused; only passed on to former base class) + * `REDQ`: Inherit from `ActorCriticOffPolicyAlgorithm` instead of `DDPG` + * `SAC`: Inherit from `ActorDualCriticsOffPolicyAlgorithm` instead of `DDPG` + * `TD3`: Inherit from `ActorDualCriticsOffPolicyAlgorithm` instead of `DDPG` + +### High-Level API + +* Detailed optimizer configuration (analogous to the procedural API) is now possible: + * All optimizers can be configured in the respective algorithm-specific `Params` object by using + `OptimizerFactoryFactory` instances as parameter values (e.g. for `optim`, `actor_optim`, `critic_optim`, etc.). + * Learning rate schedulers remain separate parameters and now use `LRSchedulerFactoryFactory` + instances. The respective parameter names now use the suffix `lr_scheduler` instead of `lr_scheduler_factory` + (as the precise nature need not be reflected in the name; brevity is preferable). +* `SamplingConfig` is replaced by `TrainingConfig` and subclasses differentiating off-policy and on-policy cases + appropriately (`OnPolicyTrainingConfig`, `OffPolicyTrainingConfig`). + * The `test_in_train` parameter is now exposed (default False). + * Inapplicable arguments can no longer be set in the respective subclass (e.g. `OffPolicyTrainingConfig` does not + contain parameter `repeat_per_collect`). 
+ +### Peripheral Changes + +* The `Actor` classes have been renamed for clarity: + * `BaseActor` -> `Actor` + * `continuous.ActorProb` -> `ContinuousActorProb` + * `coninuous.Actor` -> `ContinuousActorDeterministic` + * `discrete.Actor` -> `DiscreteActor` +* The `Critic` classes have been renamed for clarity: + * `continuous.Critic` -> `ContinuousCritic` + * `discrete.Critic` -> `DiscreteCritic` +* Moved Atari helper modules `atari_network` and `atari_wrapper` to the library under `tianshou.env.atari`. +* Fix issues pertaining to the torch device assignment of network components (#810): + * Remove 'device' member (and the corresponding constructor argument) from the following classes: + `BranchingNet`, `C51Net`, `ContinuousActorDeterministic`, `ContinuousActorProb`, `ContinuousCritic`, + `DiscreteActor`, `DiscreteCritic`, `DQNet`, `FullQuantileFunction`, `ImplicitQuantileNetwork`, + `IntrinsicCuriosityModule`, `Net`, `MLP`, `Perturbation`, `QRDQNet`, `Rainbow`, `Recurrent`, + `RecurrentActorProb`, `RecurrentCritic`, `VAE` + * (Peripheral change:) Require the use of keyword arguments for the constructors of all of these classes +* Clean up handling of modules that define attribute `output_dim`, introducing the explicit base class + `ModuleWithVectorOutput` + * Interfaces where one could specify either a module with `output_dim` or additionally provide the output + dimension as an argument were changed to use `ModuleWithVectorOutput`. + * The high-level API class `IntermediateModule` can now provide a `ModuleWithVectorOutput` instance + (via adaptation if necessary). 
+ ## Unreleased From a724abb15866b28fe154028fc6e72c12d2732c88 Mon Sep 17 00:00:00 2001 From: Michael Panchenko Date: Thu, 15 May 2025 15:26:59 +0200 Subject: [PATCH 163/230] v1: minor rename --- tianshou/policy/base.py | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/tianshou/policy/base.py b/tianshou/policy/base.py index 44e27c07a..dda5ebadd 100644 --- a/tianshou/policy/base.py +++ b/tianshou/policy/base.py @@ -385,8 +385,8 @@ def _compile() -> None: f32 = np.array([0, 1], dtype=np.float32) b = np.array([False, True], dtype=np.bool_) i64 = np.array([[0, 1]], dtype=np.int64) - _gae_return(f64, f64, f64, b, 0.1, 0.1) - _gae_return(f32, f32, f64, b, 0.1, 0.1) + _gae(f64, f64, f64, b, 0.1, 0.1) + _gae(f32, f32, f64, b, 0.1, 0.1) _nstep_return(f64, b, f32.reshape(-1, 1), i64, 0.1, 1) _TArrOrActBatch = TypeVar("_TArrOrActBatch", bound="np.ndarray | ActBatchProtocol") @@ -749,7 +749,7 @@ def compute_episodic_return( end_flag = np.logical_or(batch.terminated, batch.truncated) end_flag[np.isin(indices, buffer.unfinished_index())] = True - advantage = _gae_return(v_s, v_s_, rew, end_flag, gamma, gae_lambda) + advantage = _gae(v_s, v_s_, rew, end_flag, gamma, gae_lambda) returns = advantage + v_s # normalization varies from each policy, so we don't do it here return returns, advantage @@ -1117,9 +1117,8 @@ def forward( return cast(ActStateBatchProtocol, Batch(act=act, state=next_state)) -# TODO: rename? See docstring @njit -def _gae_return( +def _gae( v_s: np.ndarray, v_s_: np.ndarray, rew: np.ndarray, @@ -1129,8 +1128,7 @@ def _gae_return( ) -> np.ndarray: r"""Computes advantages with GAE. - Note: doesn't compute returns but rather advantages. The return - is given by the output of this + v_s. Note that the advantages plus v_s + The return is given by the output of this + v_s. 
Note that the advantages plus v_s is exactly the same as the TD-lambda target, which is computed by the recursive formula: From f5c5fcee2d1b1a9804f2d42a55733155c339835c Mon Sep 17 00:00:00 2001 From: Michael Panchenko Date: Thu, 15 May 2025 15:26:26 +0200 Subject: [PATCH 164/230] v2: Add base classes/marker interfaces for actors to more clearly define the interfaces of associated policies --- CHANGELOG.md | 4 +-- examples/box2d/bipedal_hardcore_sac.py | 4 +-- examples/box2d/mcc_sac.py | 4 +-- examples/inverse/irl_gail.py | 8 ++--- examples/mujoco/mujoco_a2c.py | 8 ++--- examples/mujoco/mujoco_npg.py | 8 ++--- examples/mujoco/mujoco_ppo.py | 8 ++--- examples/mujoco/mujoco_redq.py | 4 +-- examples/mujoco/mujoco_reinforce.py | 8 ++--- examples/mujoco/mujoco_sac.py | 4 +-- examples/mujoco/mujoco_trpo.py | 8 ++--- examples/offline/d4rl_cql.py | 4 +-- examples/vizdoom/vizdoom_ppo.py | 4 +-- test/base/test_policy.py | 10 +++--- test/continuous/test_npg.py | 8 ++--- test/continuous/test_ppo.py | 8 ++--- test/continuous/test_redq.py | 4 +-- test/continuous/test_sac_with_il.py | 4 +-- test/continuous/test_trpo.py | 8 ++--- test/discrete/test_a2c_with_il.py | 4 +-- test/discrete/test_pg.py | 4 +-- test/modelbased/test_ppo_icm.py | 4 +-- test/offline/gather_pendulum_data.py | 4 +-- test/offline/test_cql.py | 4 +-- test/offline/test_gail.py | 8 ++--- test/pettingzoo/pistonball_continuous.py | 8 ++--- tianshou/highlevel/algorithm.py | 6 ++-- tianshou/highlevel/module/actor.py | 2 +- tianshou/policy/base.py | 5 ++- tianshou/policy/imitation/discrete_crr.py | 4 +-- tianshou/policy/imitation/gail.py | 4 +-- tianshou/policy/modelfree/a2c.py | 8 ++--- tianshou/policy/modelfree/ddpg.py | 7 +++-- tianshou/policy/modelfree/npg.py | 4 +-- tianshou/policy/modelfree/pg.py | 24 ++++++++++----- tianshou/policy/modelfree/ppo.py | 4 +-- tianshou/policy/modelfree/redq.py | 4 +-- tianshou/policy/modelfree/sac.py | 4 +-- tianshou/policy/modelfree/td3.py | 2 +- tianshou/policy/modelfree/trpo.py | 
4 +-- tianshou/trainer/base.py | 5 +-- tianshou/utils/net/common.py | 37 ++++++++++++++++++----- tianshou/utils/net/continuous.py | 28 ++++++++++------- tianshou/utils/net/discrete.py | 27 ++++++++++++----- 44 files changed, 192 insertions(+), 141 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index eca734dab..bd389c68b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -183,7 +183,7 @@ Developers: * The `Actor` classes have been renamed for clarity: * `BaseActor` -> `Actor` - * `continuous.ActorProb` -> `ContinuousActorProb` + * `continuous.ActorProb` -> `ContinuousActorProbabilistic` * `coninuous.Actor` -> `ContinuousActorDeterministic` * `discrete.Actor` -> `DiscreteActor` * The `Critic` classes have been renamed for clarity: @@ -192,7 +192,7 @@ Developers: * Moved Atari helper modules `atari_network` and `atari_wrapper` to the library under `tianshou.env.atari`. * Fix issues pertaining to the torch device assignment of network components (#810): * Remove 'device' member (and the corresponding constructor argument) from the following classes: - `BranchingNet`, `C51Net`, `ContinuousActorDeterministic`, `ContinuousActorProb`, `ContinuousCritic`, + `BranchingNet`, `C51Net`, `ContinuousActorDeterministic`, `ContinuousActorProbabilistic`, `ContinuousCritic`, `DiscreteActor`, `DiscreteCritic`, `DQNet`, `FullQuantileFunction`, `ImplicitQuantileNetwork`, `IntrinsicCuriosityModule`, `Net`, `MLP`, `Perturbation`, `QRDQNet`, `Rainbow`, `Recurrent`, `RecurrentActorProb`, `RecurrentCritic`, `VAE` diff --git a/examples/box2d/bipedal_hardcore_sac.py b/examples/box2d/bipedal_hardcore_sac.py index ce494ff90..72aae0b34 100644 --- a/examples/box2d/bipedal_hardcore_sac.py +++ b/examples/box2d/bipedal_hardcore_sac.py @@ -18,7 +18,7 @@ from tianshou.trainer import OffPolicyTrainerParams from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Net -from tianshou.utils.net.continuous import ContinuousActorProb, ContinuousCritic +from 
tianshou.utils.net.continuous import ContinuousActorProbabilistic, ContinuousCritic from tianshou.utils.space_info import SpaceInfo @@ -111,7 +111,7 @@ def test_sac_bipedal(args: argparse.Namespace = get_args()) -> None: # model net_a = Net(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes) - actor = ContinuousActorProb( + actor = ContinuousActorProbabilistic( preprocess_net=net_a, action_shape=args.action_shape, unbounded=True, diff --git a/examples/box2d/mcc_sac.py b/examples/box2d/mcc_sac.py index 9fa0310a8..03eeec085 100644 --- a/examples/box2d/mcc_sac.py +++ b/examples/box2d/mcc_sac.py @@ -17,7 +17,7 @@ from tianshou.trainer import OffPolicyTrainerParams from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Net -from tianshou.utils.net.continuous import ContinuousActorProb, ContinuousCritic +from tianshou.utils.net.continuous import ContinuousActorProbabilistic, ContinuousCritic from tianshou.utils.space_info import SpaceInfo @@ -69,7 +69,7 @@ def test_sac(args: argparse.Namespace = get_args()) -> None: test_envs.seed(args.seed) # model net = Net(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes) - actor = ContinuousActorProb( + actor = ContinuousActorProbabilistic( preprocess_net=net, action_shape=args.action_shape, unbounded=True ).to(args.device) actor_optim = AdamOptimizerFactory(lr=args.actor_lr) diff --git a/examples/inverse/irl_gail.py b/examples/inverse/irl_gail.py index cbb160451..56ed5b445 100644 --- a/examples/inverse/irl_gail.py +++ b/examples/inverse/irl_gail.py @@ -25,12 +25,12 @@ from tianshou.env import SubprocVectorEnv, VectorEnvNormObs from tianshou.policy import GAIL from tianshou.policy.base import Algorithm -from tianshou.policy.modelfree.pg import ActorPolicy +from tianshou.policy.modelfree.pg import ActorPolicyProbabilistic from tianshou.policy.optim import AdamOptimizerFactory, LRSchedulerFactoryLinear from tianshou.trainer import OnPolicyTrainerParams from tianshou.utils import 
TensorboardLogger from tianshou.utils.net.common import Net -from tianshou.utils.net.continuous import ContinuousActorProb, ContinuousCritic +from tianshou.utils.net.continuous import ContinuousActorProbabilistic, ContinuousCritic from tianshou.utils.space_info import SpaceInfo @@ -127,7 +127,7 @@ def test_gail(args: argparse.Namespace = get_args()) -> None: hidden_sizes=args.hidden_sizes, activation=nn.Tanh, ) - actor = ContinuousActorProb( + actor = ContinuousActorProbabilistic( preprocess_net=net_a, action_shape=args.action_shape, unbounded=True, @@ -204,7 +204,7 @@ def dist(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution: ) print("dataset loaded") - policy = ActorPolicy( + policy = ActorPolicyProbabilistic( actor=actor, dist_fn=dist, action_scaling=True, diff --git a/examples/mujoco/mujoco_a2c.py b/examples/mujoco/mujoco_a2c.py index e0041b99b..a7f2a6eac 100755 --- a/examples/mujoco/mujoco_a2c.py +++ b/examples/mujoco/mujoco_a2c.py @@ -15,11 +15,11 @@ from tianshou.highlevel.logger import LoggerFactoryDefault from tianshou.policy import A2C from tianshou.policy.base import Algorithm -from tianshou.policy.modelfree.pg import ActorPolicy +from tianshou.policy.modelfree.pg import ActorPolicyProbabilistic from tianshou.policy.optim import LRSchedulerFactoryLinear, RMSpropOptimizerFactory from tianshou.trainer import OnPolicyTrainerParams from tianshou.utils.net.common import ActorCritic, Net -from tianshou.utils.net.continuous import ContinuousActorProb, ContinuousCritic +from tianshou.utils.net.continuous import ContinuousActorProbabilistic, ContinuousCritic def get_args() -> argparse.Namespace: @@ -94,7 +94,7 @@ def main(args: argparse.Namespace = get_args()) -> None: hidden_sizes=args.hidden_sizes, activation=nn.Tanh, ) - actor = ContinuousActorProb( + actor = ContinuousActorProbabilistic( preprocess_net=net_a, action_shape=args.action_shape, unbounded=True, @@ -140,7 +140,7 @@ def dist(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution: 
loc, scale = loc_scale return Independent(Normal(loc, scale), 1) - policy = ActorPolicy( + policy = ActorPolicyProbabilistic( actor=actor, dist_fn=dist, action_scaling=True, diff --git a/examples/mujoco/mujoco_npg.py b/examples/mujoco/mujoco_npg.py index ca178557f..86f1d30df 100755 --- a/examples/mujoco/mujoco_npg.py +++ b/examples/mujoco/mujoco_npg.py @@ -15,11 +15,11 @@ from tianshou.highlevel.logger import LoggerFactoryDefault from tianshou.policy import NPG from tianshou.policy.base import Algorithm -from tianshou.policy.modelfree.pg import ActorPolicy +from tianshou.policy.modelfree.pg import ActorPolicyProbabilistic from tianshou.policy.optim import AdamOptimizerFactory, LRSchedulerFactoryLinear from tianshou.trainer import OnPolicyTrainerParams from tianshou.utils.net.common import Net -from tianshou.utils.net.continuous import ContinuousActorProb, ContinuousCritic +from tianshou.utils.net.continuous import ContinuousActorProbabilistic, ContinuousCritic def get_args() -> argparse.Namespace: @@ -99,7 +99,7 @@ def main(args: argparse.Namespace = get_args()) -> None: hidden_sizes=args.hidden_sizes, activation=nn.Tanh, ) - actor = ContinuousActorProb( + actor = ContinuousActorProbabilistic( preprocess_net=net_a, action_shape=args.action_shape, unbounded=True, @@ -138,7 +138,7 @@ def dist(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution: loc, scale = loc_scale return Independent(Normal(loc, scale), 1) - policy = ActorPolicy( + policy = ActorPolicyProbabilistic( actor=actor, dist_fn=dist, action_scaling=True, diff --git a/examples/mujoco/mujoco_ppo.py b/examples/mujoco/mujoco_ppo.py index c65b302f4..ffbaa1d3a 100755 --- a/examples/mujoco/mujoco_ppo.py +++ b/examples/mujoco/mujoco_ppo.py @@ -15,11 +15,11 @@ from tianshou.highlevel.logger import LoggerFactoryDefault from tianshou.policy import PPO from tianshou.policy.base import Algorithm -from tianshou.policy.modelfree.pg import ActorPolicy +from tianshou.policy.modelfree.pg import 
ActorPolicyProbabilistic from tianshou.policy.optim import AdamOptimizerFactory, LRSchedulerFactoryLinear from tianshou.trainer import OnPolicyTrainerParams from tianshou.utils.net.common import ActorCritic, Net -from tianshou.utils.net.continuous import ContinuousActorProb, ContinuousCritic +from tianshou.utils.net.continuous import ContinuousActorProbabilistic, ContinuousCritic def get_args() -> argparse.Namespace: @@ -99,7 +99,7 @@ def main(args: argparse.Namespace = get_args()) -> None: hidden_sizes=args.hidden_sizes, activation=nn.Tanh, ) - actor = ContinuousActorProb( + actor = ContinuousActorProbabilistic( preprocess_net=net_a, action_shape=args.action_shape, unbounded=True, @@ -141,7 +141,7 @@ def dist(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution: loc, scale = loc_scale return Independent(Normal(loc, scale), 1) - policy = ActorPolicy( + policy = ActorPolicyProbabilistic( actor=actor, dist_fn=dist, action_scaling=True, diff --git a/examples/mujoco/mujoco_redq.py b/examples/mujoco/mujoco_redq.py index 6d8ae40d7..09fdadad3 100755 --- a/examples/mujoco/mujoco_redq.py +++ b/examples/mujoco/mujoco_redq.py @@ -18,7 +18,7 @@ from tianshou.policy.optim import AdamOptimizerFactory from tianshou.trainer import OffPolicyTrainerParams from tianshou.utils.net.common import EnsembleLinear, Net -from tianshou.utils.net.continuous import ContinuousActorProb, ContinuousCritic +from tianshou.utils.net.continuous import ContinuousActorProbabilistic, ContinuousCritic def get_args() -> argparse.Namespace: @@ -90,7 +90,7 @@ def main(args: argparse.Namespace = get_args()) -> None: torch.manual_seed(args.seed) # model net_a = Net(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes) - actor = ContinuousActorProb( + actor = ContinuousActorProbabilistic( preprocess_net=net_a, action_shape=args.action_shape, unbounded=True, diff --git a/examples/mujoco/mujoco_reinforce.py b/examples/mujoco/mujoco_reinforce.py index 42546f2cf..5cbb77ef9 100755 --- 
a/examples/mujoco/mujoco_reinforce.py +++ b/examples/mujoco/mujoco_reinforce.py @@ -15,11 +15,11 @@ from tianshou.highlevel.logger import LoggerFactoryDefault from tianshou.policy import Reinforce from tianshou.policy.base import Algorithm -from tianshou.policy.modelfree.pg import ActorPolicy +from tianshou.policy.modelfree.pg import ActorPolicyProbabilistic from tianshou.policy.optim import AdamOptimizerFactory, LRSchedulerFactoryLinear from tianshou.trainer import OnPolicyTrainerParams from tianshou.utils.net.common import Net -from tianshou.utils.net.continuous import ContinuousActorProb +from tianshou.utils.net.continuous import ContinuousActorProbabilistic def get_args() -> argparse.Namespace: @@ -91,7 +91,7 @@ def main(args: argparse.Namespace = get_args()) -> None: hidden_sizes=args.hidden_sizes, activation=nn.Tanh, ) - actor = ContinuousActorProb( + actor = ContinuousActorProbabilistic( preprocess_net=net_a, action_shape=args.action_shape, unbounded=True, @@ -124,7 +124,7 @@ def dist(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution: loc, scale = loc_scale return Independent(Normal(loc, scale), 1) - policy = ActorPolicy( + policy = ActorPolicyProbabilistic( actor=actor, dist_fn=dist, action_space=env.action_space, diff --git a/examples/mujoco/mujoco_sac.py b/examples/mujoco/mujoco_sac.py index 1c57aa955..685276af6 100755 --- a/examples/mujoco/mujoco_sac.py +++ b/examples/mujoco/mujoco_sac.py @@ -17,7 +17,7 @@ from tianshou.policy.optim import AdamOptimizerFactory from tianshou.trainer import OffPolicyTrainerParams from tianshou.utils.net.common import Net -from tianshou.utils.net.continuous import ContinuousActorProb, ContinuousCritic +from tianshou.utils.net.continuous import ContinuousActorProbabilistic, ContinuousCritic def get_args() -> argparse.Namespace: @@ -86,7 +86,7 @@ def main(args: argparse.Namespace = get_args()) -> None: torch.manual_seed(args.seed) # model net_a = Net(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes) - 
actor = ContinuousActorProb( + actor = ContinuousActorProbabilistic( preprocess_net=net_a, action_shape=args.action_shape, unbounded=True, diff --git a/examples/mujoco/mujoco_trpo.py b/examples/mujoco/mujoco_trpo.py index 1620e7ac3..6f291cf47 100755 --- a/examples/mujoco/mujoco_trpo.py +++ b/examples/mujoco/mujoco_trpo.py @@ -15,11 +15,11 @@ from tianshou.highlevel.logger import LoggerFactoryDefault from tianshou.policy import TRPO from tianshou.policy.base import Algorithm -from tianshou.policy.modelfree.pg import ActorPolicy +from tianshou.policy.modelfree.pg import ActorPolicyProbabilistic from tianshou.policy.optim import AdamOptimizerFactory, LRSchedulerFactoryLinear from tianshou.trainer import OnPolicyTrainerParams from tianshou.utils.net.common import Net -from tianshou.utils.net.continuous import ContinuousActorProb, ContinuousCritic +from tianshou.utils.net.continuous import ContinuousActorProbabilistic, ContinuousCritic def get_args() -> argparse.Namespace: @@ -102,7 +102,7 @@ def test_trpo(args: argparse.Namespace = get_args()) -> None: hidden_sizes=args.hidden_sizes, activation=nn.Tanh, ) - actor = ContinuousActorProb( + actor = ContinuousActorProbabilistic( preprocess_net=net_a, action_shape=args.action_shape, unbounded=True, @@ -141,7 +141,7 @@ def dist(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution: loc, scale = loc_scale return Independent(Normal(loc, scale), 1) - policy = ActorPolicy( + policy = ActorPolicyProbabilistic( actor=actor, dist_fn=dist, action_scaling=True, diff --git a/examples/offline/d4rl_cql.py b/examples/offline/d4rl_cql.py index 64eccfd2e..b9c3d17b9 100644 --- a/examples/offline/d4rl_cql.py +++ b/examples/offline/d4rl_cql.py @@ -20,7 +20,7 @@ from tianshou.trainer import OfflineTrainerParams from tianshou.utils import TensorboardLogger, WandbLogger from tianshou.utils.net.common import Net -from tianshou.utils.net.continuous import ContinuousActorProb, ContinuousCritic +from tianshou.utils.net.continuous import 
ContinuousActorProbabilistic, ContinuousCritic from tianshou.utils.space_info import SpaceInfo @@ -250,7 +250,7 @@ def test_cql() -> None: action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, ) - actor = ContinuousActorProb( + actor = ContinuousActorProbabilistic( preprocess_net=net_a, action_shape=args.action_shape, unbounded=True, diff --git a/examples/vizdoom/vizdoom_ppo.py b/examples/vizdoom/vizdoom_ppo.py index bcf94f70b..f6ff24e93 100644 --- a/examples/vizdoom/vizdoom_ppo.py +++ b/examples/vizdoom/vizdoom_ppo.py @@ -15,7 +15,7 @@ from tianshou.policy import PPO from tianshou.policy.base import Algorithm from tianshou.policy.modelbased.icm import ICMOnPolicyWrapper -from tianshou.policy.modelfree.pg import ActorPolicy +from tianshou.policy.modelfree.pg import ActorPolicyProbabilistic from tianshou.policy.optim import AdamOptimizerFactory, LRSchedulerFactoryLinear from tianshou.trainer import OnPolicyTrainerParams from tianshou.utils.net.discrete import ( @@ -149,7 +149,7 @@ def dist(logits: torch.Tensor) -> Categorical: return Categorical(logits=logits) # define policy and algorithm - policy = ActorPolicy( + policy = ActorPolicyProbabilistic( actor=actor, dist_fn=dist, action_scaling=False, diff --git a/test/base/test_policy.py b/test/base/test_policy.py index bd94e4ecf..559c685e2 100644 --- a/test/base/test_policy.py +++ b/test/base/test_policy.py @@ -7,10 +7,10 @@ from tianshou.data import Batch from tianshou.policy import PPO from tianshou.policy.base import RandomActionPolicy, episode_mc_return_to_go -from tianshou.policy.modelfree.pg import ActorPolicy +from tianshou.policy.modelfree.pg import ActorPolicyProbabilistic from tianshou.policy.optim import AdamOptimizerFactory from tianshou.utils.net.common import Net -from tianshou.utils.net.continuous import ContinuousActorProb, ContinuousCritic +from tianshou.utils.net.continuous import ContinuousActorProbabilistic, ContinuousCritic from tianshou.utils.net.discrete import DiscreteActor obs_shape 
= (5,) @@ -31,10 +31,10 @@ def test_calculate_discounted_returns() -> None: def algorithm(request: pytest.FixtureRequest) -> PPO: action_type = request.param action_space: gym.spaces.Box | gym.spaces.Discrete - actor: DiscreteActor | ContinuousActorProb + actor: DiscreteActor | ContinuousActorProbabilistic if action_type == "continuous": action_space = gym.spaces.Box(low=-1, high=1, shape=(3,)) - actor = ContinuousActorProb( + actor = ContinuousActorProbabilistic( preprocess_net=Net( state_shape=obs_shape, hidden_sizes=[64, 64], action_shape=action_space.shape ), @@ -64,7 +64,7 @@ def dist_fn(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution: optim = AdamOptimizerFactory(lr=1e-3) algorithm: PPO - policy = ActorPolicy( + policy = ActorPolicyProbabilistic( actor=actor, dist_fn=dist_fn, action_space=action_space, diff --git a/test/continuous/test_npg.py b/test/continuous/test_npg.py index cd563b248..35f6b58f9 100644 --- a/test/continuous/test_npg.py +++ b/test/continuous/test_npg.py @@ -13,12 +13,12 @@ from tianshou.env import DummyVectorEnv from tianshou.policy import NPG from tianshou.policy.base import Algorithm -from tianshou.policy.modelfree.pg import ActorPolicy +from tianshou.policy.modelfree.pg import ActorPolicyProbabilistic from tianshou.policy.optim import AdamOptimizerFactory from tianshou.trainer.base import OnPolicyTrainerParams from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Net -from tianshou.utils.net.continuous import ContinuousActorProb, ContinuousCritic +from tianshou.utils.net.continuous import ContinuousActorProbabilistic, ContinuousCritic from tianshou.utils.space_info import SpaceInfo @@ -83,7 +83,7 @@ def test_npg(args: argparse.Namespace = get_args(), enable_assertions: bool = Tr hidden_sizes=args.hidden_sizes, activation=nn.Tanh, ) - actor = ContinuousActorProb( + actor = ContinuousActorProbabilistic( preprocess_net=net, action_shape=args.action_shape, unbounded=True ).to(args.device) critic = 
ContinuousCritic( @@ -106,7 +106,7 @@ def dist(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution: loc, scale = loc_scale return Independent(Normal(loc, scale), 1) - policy = ActorPolicy( + policy = ActorPolicyProbabilistic( actor=actor, dist_fn=dist, action_space=env.action_space, diff --git a/test/continuous/test_ppo.py b/test/continuous/test_ppo.py index aba64787e..c265d2955 100644 --- a/test/continuous/test_ppo.py +++ b/test/continuous/test_ppo.py @@ -12,12 +12,12 @@ from tianshou.env import DummyVectorEnv from tianshou.policy import PPO from tianshou.policy.base import Algorithm -from tianshou.policy.modelfree.pg import ActorPolicy +from tianshou.policy.modelfree.pg import ActorPolicyProbabilistic from tianshou.policy.optim import AdamOptimizerFactory from tianshou.trainer.base import OnPolicyTrainerParams from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import ActorCritic, Net -from tianshou.utils.net.continuous import ContinuousActorProb, ContinuousCritic +from tianshou.utils.net.continuous import ContinuousActorProbabilistic, ContinuousCritic from tianshou.utils.space_info import SpaceInfo @@ -85,7 +85,7 @@ def test_ppo(args: argparse.Namespace = get_args(), enable_assertions: bool = Tr # model net = Net(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes) - actor = ContinuousActorProb( + actor = ContinuousActorProbabilistic( preprocess_net=net, action_shape=args.action_shape, unbounded=True ).to(args.device) critic = ContinuousCritic( @@ -105,7 +105,7 @@ def dist(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution: loc, scale = loc_scale return Independent(Normal(loc, scale), 1) - policy = ActorPolicy( + policy = ActorPolicyProbabilistic( actor=actor, dist_fn=dist, action_space=env.action_space, diff --git a/test/continuous/test_redq.py b/test/continuous/test_redq.py index 8d401b869..1662f2abe 100644 --- a/test/continuous/test_redq.py +++ b/test/continuous/test_redq.py @@ -18,7 +18,7 @@ from 
tianshou.trainer.base import OffPolicyTrainerParams from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import EnsembleLinear, Net -from tianshou.utils.net.continuous import ContinuousActorProb, ContinuousCritic +from tianshou.utils.net.continuous import ContinuousActorProbabilistic, ContinuousCritic from tianshou.utils.space_info import SpaceInfo @@ -83,7 +83,7 @@ def test_redq(args: argparse.Namespace = get_args(), enable_assertions: bool = T test_envs.seed(args.seed) # model net = Net(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes) - actor = ContinuousActorProb( + actor = ContinuousActorProbabilistic( preprocess_net=net, action_shape=args.action_shape, unbounded=True, diff --git a/test/continuous/test_sac_with_il.py b/test/continuous/test_sac_with_il.py index 3013c6e06..b2d9e7f77 100644 --- a/test/continuous/test_sac_with_il.py +++ b/test/continuous/test_sac_with_il.py @@ -19,7 +19,7 @@ from tianshou.utils.net.common import Net from tianshou.utils.net.continuous import ( ContinuousActorDeterministic, - ContinuousActorProb, + ContinuousActorProbabilistic, ContinuousCritic, ) from tianshou.utils.space_info import SpaceInfo @@ -95,7 +95,7 @@ def test_sac_with_il( # model net = Net(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes) - actor = ContinuousActorProb( + actor = ContinuousActorProbabilistic( preprocess_net=net, action_shape=args.action_shape, unbounded=True ).to(args.device) actor_optim = AdamOptimizerFactory(lr=args.actor_lr) diff --git a/test/continuous/test_trpo.py b/test/continuous/test_trpo.py index ac91cf8fe..658df9efe 100644 --- a/test/continuous/test_trpo.py +++ b/test/continuous/test_trpo.py @@ -13,12 +13,12 @@ from tianshou.env import DummyVectorEnv from tianshou.policy import TRPO from tianshou.policy.base import Algorithm -from tianshou.policy.modelfree.pg import ActorPolicy +from tianshou.policy.modelfree.pg import ActorPolicyProbabilistic from tianshou.policy.optim import AdamOptimizerFactory 
from tianshou.trainer.base import OnPolicyTrainerParams from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Net -from tianshou.utils.net.continuous import ContinuousActorProb, ContinuousCritic +from tianshou.utils.net.continuous import ContinuousActorProbabilistic, ContinuousCritic from tianshou.utils.space_info import SpaceInfo @@ -84,7 +84,7 @@ def test_trpo(args: argparse.Namespace = get_args(), enable_assertions: bool = T hidden_sizes=args.hidden_sizes, activation=nn.Tanh, ) - actor = ContinuousActorProb( + actor = ContinuousActorProbabilistic( preprocess_net=net, action_shape=args.action_shape, unbounded=True ).to(args.device) critic = ContinuousCritic( @@ -107,7 +107,7 @@ def dist(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution: loc, scale = loc_scale return Independent(Normal(loc, scale), 1) - policy = ActorPolicy( + policy = ActorPolicyProbabilistic( actor=actor, dist_fn=dist, action_space=env.action_space, diff --git a/test/discrete/test_a2c_with_il.py b/test/discrete/test_a2c_with_il.py index 18577252a..862f10450 100644 --- a/test/discrete/test_a2c_with_il.py +++ b/test/discrete/test_a2c_with_il.py @@ -13,7 +13,7 @@ from tianshou.policy import A2C, OffPolicyImitationLearning from tianshou.policy.base import Algorithm from tianshou.policy.imitation.base import ImitationPolicy -from tianshou.policy.modelfree.pg import ActorPolicy +from tianshou.policy.modelfree.pg import ActorPolicyProbabilistic from tianshou.policy.optim import AdamOptimizerFactory from tianshou.trainer.base import OffPolicyTrainerParams, OnPolicyTrainerParams from tianshou.utils import TensorboardLogger @@ -102,7 +102,7 @@ def test_a2c_with_il( critic = DiscreteCritic(preprocess_net=net).to(args.device) optim = AdamOptimizerFactory(lr=args.lr) dist = torch.distributions.Categorical - policy = ActorPolicy( + policy = ActorPolicyProbabilistic( actor=actor, dist_fn=dist, action_scaling=isinstance(env.action_space, Box), diff --git 
a/test/discrete/test_pg.py b/test/discrete/test_pg.py index 395224090..9893f0dc5 100644 --- a/test/discrete/test_pg.py +++ b/test/discrete/test_pg.py @@ -12,7 +12,7 @@ from tianshou.env import DummyVectorEnv from tianshou.policy import Reinforce from tianshou.policy.base import Algorithm -from tianshou.policy.modelfree.pg import ActorPolicy +from tianshou.policy.modelfree.pg import ActorPolicyProbabilistic from tianshou.policy.optim import AdamOptimizerFactory from tianshou.trainer.base import OnPolicyTrainerParams from tianshou.utils import TensorboardLogger @@ -76,7 +76,7 @@ def test_pg(args: argparse.Namespace = get_args(), enable_assertions: bool = Tru ).to(args.device) optim = AdamOptimizerFactory(lr=args.lr) dist_fn = torch.distributions.Categorical - policy = ActorPolicy( + policy = ActorPolicyProbabilistic( actor=net, dist_fn=dist_fn, action_space=env.action_space, diff --git a/test/modelbased/test_ppo_icm.py b/test/modelbased/test_ppo_icm.py index 566ea4746..9045e29c8 100644 --- a/test/modelbased/test_ppo_icm.py +++ b/test/modelbased/test_ppo_icm.py @@ -12,7 +12,7 @@ from tianshou.policy import PPO from tianshou.policy.base import Algorithm from tianshou.policy.modelbased.icm import ICMOnPolicyWrapper -from tianshou.policy.modelfree.pg import ActorPolicy +from tianshou.policy.modelfree.pg import ActorPolicyProbabilistic from tianshou.policy.optim import AdamOptimizerFactory from tianshou.trainer import OnPolicyTrainerParams from tianshou.utils import TensorboardLogger @@ -118,7 +118,7 @@ def test_ppo(args: argparse.Namespace = get_args()) -> None: # base algorithm: PPO optim = AdamOptimizerFactory(lr=args.lr) dist = torch.distributions.Categorical - policy = ActorPolicy( + policy = ActorPolicyProbabilistic( actor=actor, dist_fn=dist, action_scaling=isinstance(env.action_space, Box), diff --git a/test/offline/gather_pendulum_data.py b/test/offline/gather_pendulum_data.py index cca234f91..b5291f21c 100644 --- a/test/offline/gather_pendulum_data.py +++ 
b/test/offline/gather_pendulum_data.py @@ -16,7 +16,7 @@ from tianshou.trainer.base import OffPolicyTrainerParams from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Net -from tianshou.utils.net.continuous import ContinuousActorProb, ContinuousCritic +from tianshou.utils.net.continuous import ContinuousActorProbabilistic, ContinuousCritic from tianshou.utils.space_info import SpaceInfo @@ -93,7 +93,7 @@ def gather_data() -> VectorReplayBuffer: test_envs.seed(args.seed) # model net = Net(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes) - actor = ContinuousActorProb( + actor = ContinuousActorProbabilistic( preprocess_net=net, action_shape=args.action_shape, unbounded=True, diff --git a/test/offline/test_cql.py b/test/offline/test_cql.py index 24bf8cef5..acc60461c 100644 --- a/test/offline/test_cql.py +++ b/test/offline/test_cql.py @@ -18,7 +18,7 @@ from tianshou.trainer import OfflineTrainerParams from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Net -from tianshou.utils.net.continuous import ContinuousActorProb, ContinuousCritic +from tianshou.utils.net.continuous import ContinuousActorProbabilistic, ContinuousCritic from tianshou.utils.space_info import SpaceInfo @@ -110,7 +110,7 @@ def test_cql(args: argparse.Namespace = get_args(), enable_assertions: bool = Tr action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, ) - actor = ContinuousActorProb( + actor = ContinuousActorProbabilistic( preprocess_net=net_a, action_shape=args.action_shape, unbounded=True, diff --git a/test/offline/test_gail.py b/test/offline/test_gail.py index b4253de9b..81a25a613 100644 --- a/test/offline/test_gail.py +++ b/test/offline/test_gail.py @@ -13,12 +13,12 @@ from tianshou.data import Collector, CollectStats, VectorReplayBuffer from tianshou.env import DummyVectorEnv from tianshou.policy import GAIL, Algorithm -from tianshou.policy.modelfree.pg import ActorPolicy +from tianshou.policy.modelfree.pg 
import ActorPolicyProbabilistic from tianshou.policy.optim import AdamOptimizerFactory from tianshou.trainer import OnPolicyTrainerParams from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import ActorCritic, Net -from tianshou.utils.net.continuous import ContinuousActorProb, ContinuousCritic +from tianshou.utils.net.continuous import ContinuousActorProbabilistic, ContinuousCritic from tianshou.utils.space_info import SpaceInfo @@ -93,7 +93,7 @@ def test_gail(args: argparse.Namespace = get_args(), enable_assertions: bool = T test_envs.seed(args.seed) # model net = Net(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes) - actor = ContinuousActorProb( + actor = ContinuousActorProbabilistic( preprocess_net=net, action_shape=args.action_shape, max_action=args.max_action, @@ -133,7 +133,7 @@ def dist(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution: loc, scale = loc_scale return Independent(Normal(loc, scale), 1) - policy = ActorPolicy( + policy = ActorPolicyProbabilistic( actor=actor, dist_fn=dist, action_space=env.action_space, diff --git a/test/pettingzoo/pistonball_continuous.py b/test/pettingzoo/pistonball_continuous.py index b6df7a8a2..490eae6c3 100644 --- a/test/pettingzoo/pistonball_continuous.py +++ b/test/pettingzoo/pistonball_continuous.py @@ -17,13 +17,13 @@ from tianshou.env.pettingzoo_env import PettingZooEnv from tianshou.policy import PPO, Algorithm from tianshou.policy.base import OnPolicyAlgorithm -from tianshou.policy.modelfree.pg import ActorPolicy +from tianshou.policy.modelfree.pg import ActorPolicyProbabilistic from tianshou.policy.multiagent.mapolicy import MultiAgentOnPolicyAlgorithm from tianshou.policy.optim import AdamOptimizerFactory from tianshou.trainer import OnPolicyTrainerParams from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import ModuleWithVectorOutput -from tianshou.utils.net.continuous import ContinuousActorProb, ContinuousCritic +from 
tianshou.utils.net.continuous import ContinuousActorProbabilistic, ContinuousCritic class DQNet(ModuleWithVectorOutput): @@ -171,7 +171,7 @@ def get_agents( device=args.device, ).to(args.device) - actor = ContinuousActorProb( + actor = ContinuousActorProbabilistic( preprocess_net=net, action_shape=args.action_shape, max_action=args.max_action, @@ -193,7 +193,7 @@ def dist(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution: loc, scale = loc_scale return Independent(Normal(loc, scale), 1) - policy = ActorPolicy( + policy = ActorPolicyProbabilistic( actor=actor, dist_fn=dist, action_space=env.action_space, diff --git a/tianshou/highlevel/algorithm.py b/tianshou/highlevel/algorithm.py index cb20b0c28..fa7027630 100644 --- a/tianshou/highlevel/algorithm.py +++ b/tianshou/highlevel/algorithm.py @@ -70,7 +70,7 @@ from tianshou.policy.modelfree.discrete_sac import DiscreteSACPolicy from tianshou.policy.modelfree.dqn import DiscreteQLearningPolicy from tianshou.policy.modelfree.iqn import IQNPolicy -from tianshou.policy.modelfree.pg import ActorPolicy +from tianshou.policy.modelfree.pg import ActorPolicyProbabilistic from tianshou.policy.modelfree.redq import REDQPolicy from tianshou.policy.modelfree.sac import SACPolicy from tianshou.trainer import OffPolicyTrainer, OnPolicyTrainer, Trainer @@ -305,7 +305,7 @@ def _create_algorithm(self, envs: Environments, device: TDevice) -> Reinforce: dist_fn = self.actor_factory.create_dist_fn(envs) assert dist_fn is not None policy = self._create_policy_from_args( - ActorPolicy, + ActorPolicyProbabilistic, kwargs, ["action_scaling", "action_bound_method", "deterministic_eval"], actor=actor, @@ -363,7 +363,7 @@ def _create_kwargs(self, envs: Environments, device: TDevice) -> dict[str, Any]: def _create_algorithm(self, envs: Environments, device: TDevice) -> TAlgorithm: params = self._create_kwargs(envs, device) policy = self._create_policy_from_args( - ActorPolicy, + ActorPolicyProbabilistic, params, [ "actor", diff --git 
a/tianshou/highlevel/module/actor.py b/tianshou/highlevel/module/actor.py index 297109829..e3807569b 100644 --- a/tianshou/highlevel/module/actor.py +++ b/tianshou/highlevel/module/actor.py @@ -187,7 +187,7 @@ def create_module(self, envs: Environments, device: TDevice) -> Actor: hidden_sizes=self.hidden_sizes, activation=self.activation, ) - actor = continuous.ContinuousActorProb( + actor = continuous.ContinuousActorProbabilistic( preprocess_net=net_a, action_shape=envs.get_action_shape(), unbounded=self.unbounded, diff --git a/tianshou/policy/base.py b/tianshou/policy/base.py index dda5ebadd..55957ecb2 100644 --- a/tianshou/policy/base.py +++ b/tianshou/policy/base.py @@ -159,14 +159,13 @@ class Policy(nn.Module, ABC): def __init__( self, action_space: gym.Space, - # TODO: does the policy actually need the observation space? observation_space: gym.Space | None = None, action_scaling: bool = False, action_bound_method: Literal["clip", "tanh"] | None = "clip", ): """ :param action_space: the environment's action_space. - :param observation_space: the environment's observation space + :param observation_space: the environment's observation space. :param action_scaling: flag indicating whether, for continuous action spaces, actions should be scaled from the standard neural network output range [-1, 1] to the environment's action space range [action_space.low, action_space.high]. @@ -849,7 +848,7 @@ def compute_nstep_return( batch.returns = to_torch_as(n_step_return_IA, target_q_torch_IA) - # TODO: this is simply casting to a certain type. Why is this necessary, and why is it happening here? + # TODO: this is simply converting to a certain type. Why is this necessary, and why is it happening here? 
if hasattr(batch, "weight"): batch.weight = to_torch_as(batch.weight, target_q_torch_IA) diff --git a/tianshou/policy/imitation/discrete_crr.py b/tianshou/policy/imitation/discrete_crr.py index f80f5d70c..c97258815 100644 --- a/tianshou/policy/imitation/discrete_crr.py +++ b/tianshou/policy/imitation/discrete_crr.py @@ -97,13 +97,13 @@ def __init__( self._target = target_update_freq > 0 self._freq = target_update_freq self._iter = 0 - self.actor_old: torch.nn.Module | torch.Tensor | EvalModeModuleWrapper + self.actor_old: torch.nn.Module | EvalModeModuleWrapper self.critic_old: torch.nn.Module | EvalModeModuleWrapper if self._target: self.actor_old = self._add_lagged_network(self.policy.actor) self.critic_old = self._add_lagged_network(self.critic) else: - self.actor_old = self.actor + self.actor_old = self.policy.actor self.critic_old = self.critic self._policy_improvement_mode = policy_improvement_mode self._ratio_upper_bound = ratio_upper_bound diff --git a/tianshou/policy/imitation/gail.py b/tianshou/policy/imitation/gail.py index c3bb18894..43901fda5 100644 --- a/tianshou/policy/imitation/gail.py +++ b/tianshou/policy/imitation/gail.py @@ -12,7 +12,7 @@ ) from tianshou.data.types import LogpOldProtocol, RolloutBatchProtocol from tianshou.policy.modelfree.a2c import A2CTrainingStats -from tianshou.policy.modelfree.pg import ActorPolicy +from tianshou.policy.modelfree.pg import ActorPolicyProbabilistic from tianshou.policy.modelfree.ppo import PPO from tianshou.policy.optim import OptimizerFactory from tianshou.utils.net.common import ModuleWithVectorOutput @@ -34,7 +34,7 @@ class GAIL(PPO): def __init__( self, *, - policy: ActorPolicy, + policy: ActorPolicyProbabilistic, critic: torch.nn.Module | ContinuousCritic | DiscreteCritic, optim: OptimizerFactory, expert_buffer: ReplayBuffer, diff --git a/tianshou/policy/modelfree/a2c.py b/tianshou/policy/modelfree/a2c.py index c34d240bf..52bcf45c0 100644 --- a/tianshou/policy/modelfree/a2c.py +++ 
b/tianshou/policy/modelfree/a2c.py @@ -12,7 +12,7 @@ OnPolicyAlgorithm, TrainingStats, ) -from tianshou.policy.modelfree.pg import ActorPolicy +from tianshou.policy.modelfree.pg import ActorPolicyProbabilistic from tianshou.policy.optim import OptimizerFactory from tianshou.utils import RunningMeanStd from tianshou.utils.net.common import ActorCritic @@ -29,13 +29,13 @@ class A2CTrainingStats(TrainingStats): gradient_steps: int -class ActorCriticOnPolicyAlgorithm(OnPolicyAlgorithm[ActorPolicy], ABC): +class ActorCriticOnPolicyAlgorithm(OnPolicyAlgorithm[ActorPolicyProbabilistic], ABC): """Abstract base class for actor-critic algorithms that use generalized advantage estimation (GAE).""" def __init__( self, *, - policy: ActorPolicy, + policy: ActorPolicyProbabilistic, critic: torch.nn.Module | ContinuousCritic | DiscreteCritic, optim: OptimizerFactory, optim_include_actor: bool, @@ -157,7 +157,7 @@ class A2C(ActorCriticOnPolicyAlgorithm): def __init__( self, *, - policy: ActorPolicy, + policy: ActorPolicyProbabilistic, critic: torch.nn.Module | ContinuousCritic | DiscreteCritic, optim: OptimizerFactory, vf_coef: float = 0.5, diff --git a/tianshou/policy/modelfree/ddpg.py b/tianshou/policy/modelfree/ddpg.py index 2335b787e..235cf1861 100644 --- a/tianshou/policy/modelfree/ddpg.py +++ b/tianshou/policy/modelfree/ddpg.py @@ -28,7 +28,10 @@ TrainingStats, ) from tianshou.policy.optim import OptimizerFactory -from tianshou.utils.net.continuous import ContinuousActorDeterministic, ContinuousCritic +from tianshou.utils.net.continuous import ( + ContinuousActorDeterministicInterface, + ContinuousCritic, +) mark_used(ActBatchProtocol) @@ -114,7 +117,7 @@ class ContinuousDeterministicPolicy(ContinuousPolicyWithExplorationNoise): def __init__( self, *, - actor: torch.nn.Module | ContinuousActorDeterministic, + actor: ContinuousActorDeterministicInterface, exploration_noise: BaseNoise | Literal["default"] | None = None, action_space: gym.Space, observation_space: gym.Space | 
None = None, diff --git a/tianshou/policy/modelfree/npg.py b/tianshou/policy/modelfree/npg.py index 0e7674d22..b47145257 100644 --- a/tianshou/policy/modelfree/npg.py +++ b/tianshou/policy/modelfree/npg.py @@ -11,7 +11,7 @@ from tianshou.data.types import BatchWithAdvantagesProtocol, RolloutBatchProtocol from tianshou.policy.base import TrainingStats from tianshou.policy.modelfree.a2c import ActorCriticOnPolicyAlgorithm -from tianshou.policy.modelfree.pg import ActorPolicy +from tianshou.policy.modelfree.pg import ActorPolicyProbabilistic from tianshou.policy.optim import OptimizerFactory from tianshou.utils.net.continuous import ContinuousCritic from tianshou.utils.net.discrete import DiscreteCritic @@ -33,7 +33,7 @@ class NPG(ActorCriticOnPolicyAlgorithm): def __init__( self, *, - policy: ActorPolicy, + policy: ActorPolicyProbabilistic, critic: torch.nn.Module | ContinuousCritic | DiscreteCritic, optim: OptimizerFactory, optim_critic_iters: int = 5, diff --git a/tianshou/policy/modelfree/pg.py b/tianshou/policy/modelfree/pg.py index 4b0c217e0..0dd4b9685 100644 --- a/tianshou/policy/modelfree/pg.py +++ b/tianshou/policy/modelfree/pg.py @@ -30,8 +30,11 @@ ) from tianshou.policy.optim import OptimizerFactory from tianshou.utils import RunningMeanStd -from tianshou.utils.net.continuous import ContinuousActorProb -from tianshou.utils.net.discrete import DiscreteActor, dist_fn_categorical_from_logits +from tianshou.utils.net.common import ( + ContinuousActorProbabilisticInterface, + DiscreteActorInterface, +) +from tianshou.utils.net.discrete import dist_fn_categorical_from_logits log = logging.getLogger(__name__) @@ -61,11 +64,16 @@ class SimpleLossTrainingStats(TrainingStats): loss: float -class ActorPolicy(Policy): +class ActorPolicyProbabilistic(Policy): + """ + A policy that outputs (representations of) probability distributions from which + actions can be sampled. 
+ """ + def __init__( self, *, - actor: torch.nn.Module | ContinuousActorProb | DiscreteActor, + actor: ContinuousActorProbabilisticInterface | DiscreteActorInterface, dist_fn: TDistFnDiscrOrCont, deterministic_eval: bool = False, action_space: gym.Space, @@ -180,11 +188,11 @@ def forward( return cast(DistBatchProtocol, result) -class DiscreteActorPolicy(ActorPolicy): +class DiscreteActorPolicy(ActorPolicyProbabilistic): def __init__( self, *, - actor: torch.nn.Module | DiscreteActor, + actor: DiscreteActorInterface, dist_fn: TDistFnDiscrete = dist_fn_categorical_from_logits, deterministic_eval: bool = False, action_space: gym.Space, @@ -231,7 +239,7 @@ def __init__( ) -TActorPolicy = TypeVar("TActorPolicy", bound=ActorPolicy) +TActorPolicy = TypeVar("TActorPolicy", bound=ActorPolicyProbabilistic) class DiscountedReturnComputation: @@ -301,7 +309,7 @@ def add_discounted_returns( return batch -class Reinforce(OnPolicyAlgorithm[ActorPolicy]): +class Reinforce(OnPolicyAlgorithm[ActorPolicyProbabilistic]): """Implementation of the REINFORCE (a.k.a. 
vanilla policy gradient) algorithm.""" def __init__( diff --git a/tianshou/policy/modelfree/ppo.py b/tianshou/policy/modelfree/ppo.py index f580b7485..a1d2d73cf 100644 --- a/tianshou/policy/modelfree/ppo.py +++ b/tianshou/policy/modelfree/ppo.py @@ -7,7 +7,7 @@ from tianshou.data.types import LogpOldProtocol, RolloutBatchProtocol from tianshou.policy import A2C from tianshou.policy.modelfree.a2c import A2CTrainingStats -from tianshou.policy.modelfree.pg import ActorPolicy +from tianshou.policy.modelfree.pg import ActorPolicyProbabilistic from tianshou.policy.optim import OptimizerFactory from tianshou.utils.net.continuous import ContinuousCritic from tianshou.utils.net.discrete import DiscreteCritic @@ -19,7 +19,7 @@ class PPO(A2C): def __init__( self, *, - policy: ActorPolicy, + policy: ActorPolicyProbabilistic, critic: torch.nn.Module | ContinuousCritic | DiscreteCritic, optim: OptimizerFactory, eps_clip: float = 0.2, diff --git a/tianshou/policy/modelfree/redq.py b/tianshou/policy/modelfree/redq.py index 1da3b84af..2eda560c9 100644 --- a/tianshou/policy/modelfree/redq.py +++ b/tianshou/policy/modelfree/redq.py @@ -20,7 +20,7 @@ ) from tianshou.policy.modelfree.sac import Alpha from tianshou.policy.optim import OptimizerFactory -from tianshou.utils.net.continuous import ContinuousActorProb +from tianshou.utils.net.continuous import ContinuousActorProbabilistic @dataclass @@ -38,7 +38,7 @@ class REDQPolicy(ContinuousPolicyWithExplorationNoise): def __init__( self, *, - actor: torch.nn.Module | ContinuousActorProb, + actor: torch.nn.Module | ContinuousActorProbabilistic, exploration_noise: BaseNoise | Literal["default"] | None = None, action_space: gym.spaces.Space, deterministic_eval: bool = True, diff --git a/tianshou/policy/modelfree/sac.py b/tianshou/policy/modelfree/sac.py index 285d22119..fa96dce67 100644 --- a/tianshou/policy/modelfree/sac.py +++ b/tianshou/policy/modelfree/sac.py @@ -19,7 +19,7 @@ from tianshou.policy.modelfree.td3 import 
ActorDualCriticsOffPolicyAlgorithm from tianshou.policy.optim import OptimizerFactory from tianshou.utils.conversion import to_optional_float -from tianshou.utils.net.continuous import ContinuousActorProb +from tianshou.utils.net.continuous import ContinuousActorProbabilistic def correct_log_prob_gaussian_tanh( @@ -55,7 +55,7 @@ class SACPolicy(ContinuousPolicyWithExplorationNoise): def __init__( self, *, - actor: torch.nn.Module | ContinuousActorProb, + actor: torch.nn.Module | ContinuousActorProbabilistic, exploration_noise: BaseNoise | Literal["default"] | None = None, deterministic_eval: bool = True, action_scaling: bool = True, diff --git a/tianshou/policy/modelfree/td3.py b/tianshou/policy/modelfree/td3.py index cca3e4fdf..9100a37e8 100644 --- a/tianshou/policy/modelfree/td3.py +++ b/tianshou/policy/modelfree/td3.py @@ -180,7 +180,7 @@ def __init__( gamma=gamma, estimation_step=estimation_step, ) - self.actor_old = self._add_lagged_network(self.policy.actor) # type: ignore[has-type] + self.actor_old = self._add_lagged_network(self.policy.actor) self.policy_noise = policy_noise self.update_actor_freq = update_actor_freq self.noise_clip = noise_clip diff --git a/tianshou/policy/modelfree/trpo.py b/tianshou/policy/modelfree/trpo.py index 2bbe43e76..953da5b0b 100644 --- a/tianshou/policy/modelfree/trpo.py +++ b/tianshou/policy/modelfree/trpo.py @@ -9,7 +9,7 @@ from tianshou.data.types import BatchWithAdvantagesProtocol from tianshou.policy import NPG from tianshou.policy.modelfree.npg import NPGTrainingStats -from tianshou.policy.modelfree.pg import ActorPolicy +from tianshou.policy.modelfree.pg import ActorPolicyProbabilistic from tianshou.policy.optim import OptimizerFactory from tianshou.utils.net.continuous import ContinuousCritic from tianshou.utils.net.discrete import DiscreteCritic @@ -26,7 +26,7 @@ class TRPO(NPG): def __init__( self, *, - policy: ActorPolicy, + policy: ActorPolicyProbabilistic, critic: torch.nn.Module | ContinuousCritic | DiscreteCritic, 
optim: OptimizerFactory, max_kl: float = 0.01, diff --git a/tianshou/trainer/base.py b/tianshou/trainer/base.py index de5c35d8d..4ef5ab93c 100644 --- a/tianshou/trainer/base.py +++ b/tianshou/trainer/base.py @@ -562,16 +562,17 @@ def execute_epoch(self) -> EpochStats: t.update(training_step_result.get_steps_in_epoch_advancement()) self._stop_fn_flag = training_step_result.is_training_done() self._env_step += training_step_result.get_env_step_advancement() + training_stats = training_step_result.get_training_stats() + assert training_stats is not None TraceLogger.log( log, - lambda: f"Training step complete: stats={training_step_result.get_training_stats().get_loss_stats_dict()}", + lambda: f"Training step complete: stats={training_stats.get_loss_stats_dict()}", ) self._log_params(self.algorithm) collect_stats = training_step_result.get_collect_stats() if collect_stats is not None: self._logger.log_train_data(asdict(collect_stats), self._env_step) - training_stats = training_step_result.get_training_stats() pbar_data_dict = self._create_epoch_pbar_data_dict(training_step_result) pbar_data_dict = set_numerical_fields_to_precision(pbar_data_dict) diff --git a/tianshou/utils/net/common.py b/tianshou/utils/net/common.py index 2c4a3834d..1516bc936 100644 --- a/tianshou/utils/net/common.py +++ b/tianshou/utils/net/common.py @@ -663,21 +663,44 @@ def forward(self, obs: np.ndarray | torch.Tensor, *args, **kwargs) -> Any: class Actor(ModuleWithVectorOutput, ABC): @abstractmethod def get_preprocess_net(self) -> ModuleWithVectorOutput: - pass + """Typically a first part of the network that preprocesses the input into a latent representation. + E.g., a CNN (often used in atari examples). We need this method to be able to + share latent representation with other networks (e.g., critic) within an Algorithm. + Networks that don't have this can use nn.Identity() as a preprocess net (see :class:`RandomActor`). 
+ """ @abstractmethod def forward( self, obs: np.ndarray | torch.Tensor, - state: Any = None, + state: T | None = None, info: dict[str, Any] | None = None, - ) -> tuple[Any, Any]: - # TODO: ALGO-REFACTORING. Marked to be addressed as part of Algorithm abstraction. - # Return type needs to be more specific - pass + ) -> tuple[np.ndarray | torch.Tensor | Sequence[torch.Tensor], T | None]: + """ + The main method for tianshou to compute actions from env observations. + Implementations will always make use of the preprocess_net as the first processing step. + + :param obs: the observation from the environment + :param state: the hidden state of the RNN, if applicable + :param info: the info object from the environment step + :return: a tuple (action_repr, hidden_state), where action_repr is either an actual action for the environment or + a representation from which it can be retrieved/sampled (e.g., mean and std for a Gaussian distribution), + and hidden_state is the new hidden state of the RNN, if applicable. + """ + + +class ContinuousActorProbabilisticInterface(Actor, ABC): + """Marker interface for probabilistic actors defined by users (outside of Tianshou code).""" + + +class DiscreteActorInterface(Actor, ABC): + """Marker interface for discrete actors defined by users (outside of Tianshou code). + + See docstring of :class:`DiscreteActor` + """ -class RandomActor(Actor): +class RandomActor(ContinuousActorProbabilisticInterface, DiscreteActorInterface): """An actor that returns random actions. For continuous action spaces, forward returns a batch of random actions sampled from the action space. 
diff --git a/tianshou/utils/net/continuous.py b/tianshou/utils/net/continuous.py index 87fcba69e..fc06bc455 100644 --- a/tianshou/utils/net/continuous.py +++ b/tianshou/utils/net/continuous.py @@ -1,7 +1,7 @@ import warnings from abc import ABC, abstractmethod from collections.abc import Sequence -from typing import Any +from typing import Any, TypeVar import numpy as np import torch @@ -11,6 +11,7 @@ from tianshou.utils.net.common import ( MLP, Actor, + ContinuousActorProbabilisticInterface, ModuleWithVectorOutput, TActionShape, TLinearLayer, @@ -20,15 +21,20 @@ SIGMA_MIN = -20 SIGMA_MAX = 2 +T = TypeVar("T") -class ContinuousActorDeterministic(Actor): - """Simple actor network that directly outputs actions for continuous action space. - Used primarily in DDPG and its variants. For probabilistic policies, see :class:`~ActorProb`. + +class ContinuousActorDeterministicInterface(Actor, ABC): + """Marker interface for continuous deterministic actors (DDPG like).""" + + +class ContinuousActorDeterministic(ContinuousActorDeterministicInterface): + """Actor network that directly outputs actions for continuous action space. + Used primarily in DDPG and its variants. It will create an actor operated in continuous action space with structure of preprocess_net ---> action_shape. - :param preprocess_net: a self-defined preprocess_net, see usage. - Typically, an instance of :class:`~tianshou.utils.net.common.Net`. + :param preprocess_net: first part of input processing. :param action_shape: a sequence of int for the shape of action. :param hidden_sizes: a sequence of int for constructing the MLP after preprocess_net. @@ -66,9 +72,9 @@ def get_output_dim(self) -> int: def forward( self, obs: np.ndarray | torch.Tensor, - state: Any = None, + state: T | None = None, info: dict[str, Any] | None = None, - ) -> tuple[torch.Tensor, Any]: + ) -> tuple[torch.Tensor, T | None]: """Mapping: s_B -> action_values_BA, hidden_state_BH | None. 
Returns a tensor representing the actions directly, i.e, of shape @@ -168,7 +174,7 @@ def forward( return self.last(obs) -class ContinuousActorProb(Actor): +class ContinuousActorProbabilistic(ContinuousActorProbabilisticInterface): """Simple actor network that outputs `mu` and `sigma` to be used as input for a `dist_fn` (typically, a Gaussian). Used primarily in SAC, PPO and variants thereof. For deterministic policies, see :class:`~Actor`. @@ -222,9 +228,9 @@ def get_preprocess_net(self) -> ModuleWithVectorOutput: def forward( self, obs: np.ndarray | torch.Tensor, - state: Any = None, + state: T | None = None, info: dict[str, Any] | None = None, - ) -> tuple[tuple[torch.Tensor, torch.Tensor], Any]: + ) -> tuple[tuple[torch.Tensor, torch.Tensor], T | None]: """Mapping: obs -> logits -> (mu, sigma).""" if info is None: info = {} diff --git a/tianshou/utils/net/discrete.py b/tianshou/utils/net/discrete.py index 2d3da153d..e1405be1f 100644 --- a/tianshou/utils/net/discrete.py +++ b/tianshou/utils/net/discrete.py @@ -1,5 +1,5 @@ from collections.abc import Sequence -from typing import Any +from typing import Any, TypeVar import numpy as np import torch @@ -9,20 +9,30 @@ from tianshou.data import Batch, to_torch from tianshou.utils.net.common import ( MLP, - Actor, + DiscreteActorInterface, ModuleWithVectorOutput, TActionShape, ) from tianshou.utils.torch_utils import torch_device +T = TypeVar("T") + def dist_fn_categorical_from_logits(logits: torch.Tensor) -> torch.distributions.Categorical: """Default distribution function for categorical actors.""" return torch.distributions.Categorical(logits=logits) -class DiscreteActor(Actor): - """Simple actor network for discrete action spaces.""" +class DiscreteActor(DiscreteActorInterface): + """For on-policy algos like Reinforce, this usually directly outputs unnormalized log + probabilities. 
+ + In Tianshou, discrete actors are also used for computing action distributions within + Q-learning type algorithms, discrete actors + typically the values of the Q function for each action (as tensor), + which are then later re-interpreted as unnormalized log-probabilities for sampling + discrete actions. So such an actor is essentially a critic. + """ def __init__( self, @@ -59,13 +69,14 @@ def get_preprocess_net(self) -> ModuleWithVectorOutput: def forward( self, obs: np.ndarray | torch.Tensor, - state: Any = None, + state: T | None = None, info: dict[str, Any] | None = None, - ) -> tuple[torch.Tensor, torch.Tensor | None]: - r"""Mapping: s_B -> action_values_BA, hidden_state_BH | None. + ) -> tuple[torch.Tensor, T | None]: + r"""Mapping: (s_B, ...) -> action_values_BA, hidden_state_BH | None. + Returns a tensor representing the values of each action, i.e, of shape - `(n_actions, )`, and + `(n_actions, )` (see class docstring for more info on the meaning of that), and a hidden state (which may be None). If `self.softmax_output` is True, they are the probabilities for taking each action. Otherwise, they will be action values. 
The hidden state is only From 16eda46cadf9a775c898af6c027251d817cd4fb0 Mon Sep 17 00:00:00 2001 From: Michael Panchenko Date: Thu, 15 May 2025 19:13:19 +0200 Subject: [PATCH 165/230] v2: remove irrelevant action_bound param in SACPolicy --- tianshou/policy/modelfree/sac.py | 20 ++------------------ 1 file changed, 2 insertions(+), 18 deletions(-) diff --git a/tianshou/policy/modelfree/sac.py b/tianshou/policy/modelfree/sac.py index fa96dce67..bb5489f0e 100644 --- a/tianshou/policy/modelfree/sac.py +++ b/tianshou/policy/modelfree/sac.py @@ -59,7 +59,6 @@ def __init__( exploration_noise: BaseNoise | Literal["default"] | None = None, deterministic_eval: bool = True, action_scaling: bool = True, - action_bound_method: Literal["clip"] | None = "clip", action_space: gym.Space, observation_space: gym.Space | None = None, ): @@ -92,22 +91,6 @@ def __init__( across environments with different action ranges, and standardizes exploration strategies. Should be disabled if the actor model already produces outputs in the correct range. - :param action_bound_method: the method used for bounding actions in continuous action spaces - to the range [-1, 1] before scaling them to the environment's action space (provided - that `action_scaling` is enabled). - This applies to continuous action spaces only (`gym.spaces.Box`) and should be set to None - for discrete spaces. - When set to "clip", actions exceeding the [-1, 1] range are simply clipped to this - range. When set to "tanh", a hyperbolic tangent function is applied, which smoothly - constrains outputs to [-1, 1] while preserving gradients. - The choice of bounding method affects both training dynamics and exploration behavior. - Clipping provides hard boundaries but may create plateau regions in the gradient - landscape, while tanh provides smoother transitions but can compress sensitivity - near the boundaries. - Should be set to None if the actor model inherently produces bounded outputs. 
- Typically used together with `action_scaling=True`. - NOTE: This parameter has negligible effect since actions are already bounded by tanh - squashing in the forward method (as in arXiv 1801.01290, Equation 21). :param action_space: the environment's action_space. :param observation_space: the environment's observation space """ @@ -116,7 +99,8 @@ def __init__( action_space=action_space, observation_space=observation_space, action_scaling=action_scaling, - action_bound_method=action_bound_method, + # actions already squashed by tanh + action_bound_method=None, ) self.actor = actor self.deterministic_eval = deterministic_eval From e4e7b7586dfa63fa2640496911ff46ebac320fe9 Mon Sep 17 00:00:00 2001 From: Dominik Jain Date: Thu, 15 May 2025 19:44:16 +0200 Subject: [PATCH 166/230] v2: Change parameter clip_loss_grad to huber_loss_delta (allowing to control not only the use of the Huber loss but also its essential parameter) --- CHANGELOG.md | 1 + tianshou/policy/modelfree/dqn.py | 21 +++++++++++++-------- 2 files changed, 14 insertions(+), 8 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index bd389c68b..82edf3c8c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -107,6 +107,7 @@ Developers: * `return_scaling` in actor-critic on-policy algorithms (A2C, PPO, GAIL, NPG, TRPO) * removed from Q-learning algorithms, where it was actually unsupported (DQN, C561, etc.) * `clip_grad` -> `max_grad_norm` (for consistency) + * `clip_loss_grad` -> `huber_loss_delta` (allowing to control not only the use of the Huber loss but also its essential parameter) * Internal design improvements: * Introduced an abstraction for the alpha parameter (coefficient of the entropy term) in `SAC`, `DiscreteSAC` and other algorithms. 
diff --git a/tianshou/policy/modelfree/dqn.py b/tianshou/policy/modelfree/dqn.py index c8fd70907..209ab8269 100644 --- a/tianshou/policy/modelfree/dqn.py +++ b/tianshou/policy/modelfree/dqn.py @@ -301,7 +301,7 @@ def __init__( estimation_step: int = 1, target_update_freq: int = 0, is_double: bool = True, - clip_loss_grad: bool = False, + huber_loss_delta: float | None = None, ) -> None: """ :param policy: the policy @@ -338,11 +338,14 @@ def __init__( If False, the algorithm follows the vanilla DQN method that directly takes the maximum Q-value from the target network. Note: Double Q-learning will only be effective when a target network is used (target_update_freq > 0). - :param clip_loss_grad: flag indicating whether to use the Huber loss instead of the MSE loss for the TD error. - If True, uses the Huber loss as described in the Nature DQN paper (nature14236), which limits the influence - of outliers. Unlike the MSE loss where the gradients grow linearly with the error magnitude, the Huber + :param huber_loss_delta: controls whether to use the Huber loss instead of the MSE loss for the TD error + and the threshold for the Huber loss. + If None, the MSE loss is used. + If not None, uses the Huber loss as described in the Nature DQN paper (nature14236) with the given delta, + which limits the influence of outliers. + Unlike the MSE loss where the gradients grow linearly with the error magnitude, the Huber loss causes the gradients to plateau at a constant value for large errors, providing more stable training. - If False, uses the standard MSE loss where the gradient magnitude continues to scale with the error size. + NOTE: The magnitude of delta should depend on the scale of the returns obtained in the environment. 
""" super().__init__( policy=policy, @@ -352,7 +355,7 @@ def __init__( target_update_freq=target_update_freq, ) self.is_double = is_double - self.clip_loss_grad = clip_loss_grad + self.huber_loss_delta = huber_loss_delta def _target_q(self, buffer: ReplayBuffer, indices: np.ndarray) -> torch.Tensor: obs_next_batch = Batch( @@ -381,10 +384,12 @@ def _update_with_batch( returns = to_torch_as(batch.returns.flatten(), q) td_error = returns - q - if self.clip_loss_grad: + if self.huber_loss_delta is not None: y = q.reshape(-1, 1) t = returns.reshape(-1, 1) - loss = torch.nn.functional.huber_loss(y, t, reduction="mean") + loss = torch.nn.functional.huber_loss( + y, t, delta=self.huber_loss_delta, reduction="mean" + ) else: loss = (td_error.pow(2) * weight).mean() From 99d320468cbe0eabe607cd4016da66eadbdbb4eb Mon Sep 17 00:00:00 2001 From: Dominik Jain Date: Thu, 15 May 2025 22:33:52 +0200 Subject: [PATCH 167/230] v2: Rename package policy -> algorithm --- docs/01_tutorials/01_concepts.rst | 22 +- docs/01_tutorials/04_tictactoe.rst | 12 +- docs/01_tutorials/07_cheatsheet.rst | 2 +- docs/02_notebooks/L0_overview.ipynb | 2 +- docs/02_notebooks/L4_Policy.ipynb | 4 +- docs/02_notebooks/L5_Collector.ipynb | 2 +- docs/02_notebooks/L6_Trainer.ipynb | 2 +- docs/02_notebooks/L7_Experiment.ipynb | 2 +- docs/index.rst | 62 ++-- examples/atari/atari_c51.py | 8 +- examples/atari/atari_dqn.py | 10 +- examples/atari/atari_fqf.py | 8 +- examples/atari/atari_iqn.py | 8 +- examples/atari/atari_ppo.py | 10 +- examples/atari/atari_qrdqn.py | 8 +- examples/atari/atari_rainbow.py | 8 +- examples/atari/atari_sac.py | 10 +- examples/box2d/acrobot_dualdqn.py | 8 +- examples/box2d/bipedal_bdq.py | 8 +- examples/box2d/bipedal_hardcore_sac.py | 8 +- examples/box2d/lunarlander_dqn.py | 8 +- examples/box2d/mcc_sac.py | 8 +- examples/discrete/discrete_dqn.py | 4 +- examples/inverse/irl_gail.py | 8 +- examples/mujoco/fetch_her_ddpg.py | 8 +- examples/mujoco/mujoco_a2c.py | 8 +- 
examples/mujoco/mujoco_ddpg.py | 8 +- examples/mujoco/mujoco_npg.py | 8 +- examples/mujoco/mujoco_ppo.py | 8 +- examples/mujoco/mujoco_redq.py | 10 +- examples/mujoco/mujoco_reinforce.py | 8 +- examples/mujoco/mujoco_sac.py | 8 +- examples/mujoco/mujoco_td3.py | 8 +- examples/mujoco/mujoco_trpo.py | 8 +- examples/offline/atari_bcq.py | 8 +- examples/offline/atari_cql.py | 8 +- examples/offline/atari_crr.py | 8 +- examples/offline/atari_il.py | 6 +- examples/offline/d4rl_bcq.py | 8 +- examples/offline/d4rl_cql.py | 8 +- examples/offline/d4rl_il.py | 6 +- examples/offline/d4rl_td3_bc.py | 8 +- examples/vizdoom/vizdoom_c51.py | 8 +- examples/vizdoom/vizdoom_ppo.py | 10 +- test/base/test_collector.py | 2 +- test/base/test_env_finite.py | 2 +- test/base/test_policy.py | 8 +- test/base/test_returns.py | 2 +- test/base/test_stats.py | 2 +- test/continuous/test_ddpg.py | 8 +- test/continuous/test_npg.py | 8 +- test/continuous/test_ppo.py | 8 +- test/continuous/test_redq.py | 10 +- test/continuous/test_sac_with_il.py | 10 +- test/continuous/test_td3.py | 8 +- test/continuous/test_trpo.py | 8 +- test/discrete/test_a2c_with_il.py | 10 +- test/discrete/test_bdqn.py | 6 +- test/discrete/test_c51.py | 8 +- test/discrete/test_discrete_sac.py | 10 +- test/discrete/test_dqn.py | 8 +- test/discrete/test_drqn.py | 8 +- test/discrete/test_fqf.py | 8 +- test/discrete/test_iqn.py | 8 +- test/discrete/test_pg.py | 8 +- test/discrete/test_ppo_discrete.py | 8 +- test/discrete/test_qrdqn.py | 8 +- test/discrete/test_rainbow.py | 8 +- test/modelbased/test_dqn_icm.py | 6 +- test/modelbased/test_ppo_icm.py | 10 +- test/modelbased/test_psrl.py | 4 +- test/offline/gather_cartpole_data.py | 8 +- test/offline/gather_pendulum_data.py | 8 +- test/offline/test_bcq.py | 6 +- test/offline/test_cql.py | 6 +- test/offline/test_discrete_bcq.py | 6 +- test/offline/test_discrete_cql.py | 6 +- test/offline/test_discrete_crr.py | 6 +- test/offline/test_gail.py | 6 +- test/offline/test_td3_bc.py | 8 +- 
test/pettingzoo/pistonball.py | 8 +- test/pettingzoo/pistonball_continuous.py | 10 +- test/pettingzoo/tic_tac_toe.py | 8 +- tianshou/__init__.py | 4 +- tianshou/algorithm/__init__.py | 35 +++ tianshou/{policy => algorithm}/base.py | 2 +- .../imitation/__init__.py | 0 .../{policy => algorithm}/imitation/base.py | 6 +- .../{policy => algorithm}/imitation/bcq.py | 4 +- .../{policy => algorithm}/imitation/cql.py | 6 +- .../imitation/discrete_bcq.py | 8 +- .../imitation/discrete_cql.py | 10 +- .../imitation/discrete_crr.py | 6 +- .../{policy => algorithm}/imitation/gail.py | 8 +- .../{policy => algorithm}/imitation/td3_bc.py | 10 +- .../modelbased/__init__.py | 0 .../{policy => algorithm}/modelbased/icm.py | 6 +- .../{policy => algorithm}/modelbased/psrl.py | 2 +- .../modelfree/__init__.py | 0 .../{policy => algorithm}/modelfree/a2c.py | 6 +- .../{policy => algorithm}/modelfree/bdqn.py | 8 +- .../{policy => algorithm}/modelfree/c51.py | 6 +- .../{policy => algorithm}/modelfree/ddpg.py | 6 +- .../modelfree/discrete_sac.py | 8 +- .../{policy => algorithm}/modelfree/dqn.py | 8 +- .../{policy => algorithm}/modelfree/fqf.py | 10 +- .../{policy => algorithm}/modelfree/iqn.py | 8 +- .../{policy => algorithm}/modelfree/npg.py | 8 +- .../{policy => algorithm}/modelfree/pg.py | 6 +- .../{policy => algorithm}/modelfree/ppo.py | 8 +- .../{policy => algorithm}/modelfree/qrdqn.py | 6 +- .../modelfree/rainbow.py | 6 +- .../{policy => algorithm}/modelfree/redq.py | 6 +- .../{policy => algorithm}/modelfree/sac.py | 8 +- .../{policy => algorithm}/modelfree/td3.py | 6 +- .../{policy => algorithm}/modelfree/trpo.py | 8 +- .../multiagent/__init__.py | 0 .../multiagent/mapolicy.py | 4 +- tianshou/{policy => algorithm}/optim.py | 278 +++++++++--------- tianshou/{policy => algorithm}/random.py | 4 +- tianshou/data/collector.py | 4 +- tianshou/data/stats.py | 2 +- tianshou/env/atari/atari_network.py | 2 +- tianshou/highlevel/algorithm.py | 18 +- tianshou/highlevel/experiment.py | 2 +- 
tianshou/highlevel/module/actor.py | 2 +- tianshou/highlevel/optim.py | 2 +- tianshou/highlevel/params/alpha.py | 2 +- tianshou/highlevel/params/dist_fn.py | 2 +- tianshou/highlevel/params/lr_scheduler.py | 2 +- tianshou/highlevel/params/policy_wrapper.py | 6 +- tianshou/highlevel/trainer.py | 4 +- tianshou/highlevel/world.py | 2 +- tianshou/policy/__init__.py | 35 --- tianshou/trainer/base.py | 2 +- tianshou/utils/torch_utils.py | 2 +- 136 files changed, 638 insertions(+), 638 deletions(-) create mode 100644 tianshou/algorithm/__init__.py rename tianshou/{policy => algorithm}/base.py (99%) rename tianshou/{policy => algorithm}/imitation/__init__.py (100%) rename tianshou/{policy => algorithm}/imitation/base.py (98%) rename tianshou/{policy => algorithm}/imitation/bcq.py (99%) rename tianshou/{policy => algorithm}/imitation/cql.py (99%) rename tianshou/{policy => algorithm}/imitation/discrete_bcq.py (98%) rename tianshou/{policy => algorithm}/imitation/discrete_cql.py (94%) rename tianshou/{policy => algorithm}/imitation/discrete_crr.py (98%) rename tianshou/{policy => algorithm}/imitation/gail.py (98%) rename tianshou/{policy => algorithm}/imitation/td3_bc.py (95%) rename tianshou/{policy => algorithm}/modelbased/__init__.py (100%) rename tianshou/{policy => algorithm}/modelbased/icm.py (98%) rename tianshou/{policy => algorithm}/modelbased/psrl.py (99%) rename tianshou/{policy => algorithm}/modelfree/__init__.py (100%) rename tianshou/{policy => algorithm}/modelfree/a2c.py (98%) rename tianshou/{policy => algorithm}/modelfree/bdqn.py (97%) rename tianshou/{policy => algorithm}/modelfree/c51.py (97%) rename tianshou/{policy => algorithm}/modelfree/ddpg.py (99%) rename tianshou/{policy => algorithm}/modelfree/discrete_sac.py (97%) rename tianshou/{policy => algorithm}/modelfree/dqn.py (98%) rename tianshou/{policy => algorithm}/modelfree/fqf.py (97%) rename tianshou/{policy => algorithm}/modelfree/iqn.py (97%) rename tianshou/{policy => algorithm}/modelfree/npg.py 
(97%) rename tianshou/{policy => algorithm}/modelfree/pg.py (99%) rename tianshou/{policy => algorithm}/modelfree/ppo.py (98%) rename tianshou/{policy => algorithm}/modelfree/qrdqn.py (97%) rename tianshou/{policy => algorithm}/modelfree/rainbow.py (96%) rename tianshou/{policy => algorithm}/modelfree/redq.py (99%) rename tianshou/{policy => algorithm}/modelfree/sac.py (98%) rename tianshou/{policy => algorithm}/modelfree/td3.py (98%) rename tianshou/{policy => algorithm}/modelfree/trpo.py (98%) rename tianshou/{policy => algorithm}/multiagent/__init__.py (100%) rename tianshou/{policy => algorithm}/multiagent/mapolicy.py (99%) rename tianshou/{policy => algorithm}/optim.py (97%) rename tianshou/{policy => algorithm}/random.py (95%) delete mode 100644 tianshou/policy/__init__.py diff --git a/docs/01_tutorials/01_concepts.rst b/docs/01_tutorials/01_concepts.rst index 5107bd690..aa244b137 100644 --- a/docs/01_tutorials/01_concepts.rst +++ b/docs/01_tutorials/01_concepts.rst @@ -223,16 +223,16 @@ Tianshou provides other type of data buffer such as :class:`~tianshou.data.Prior Policy ------ -Tianshou aims to modularize RL algorithms. It comes into several classes of policies in Tianshou. All of the policy classes must inherit :class:`~tianshou.policy.BasePolicy`. +Tianshou aims to modularize RL algorithms. It comes into several classes of policies in Tianshou. All of the policy classes must inherit :class:`~tianshou.algorithm.BasePolicy`. A policy class typically has the following parts: -* :meth:`~tianshou.policy.BasePolicy.__init__`: initialize the policy, including copying the target network and so on; -* :meth:`~tianshou.policy.BasePolicy.forward`: compute action with given observation; -* :meth:`~tianshou.policy.BasePolicy.process_fn`: pre-process data from the replay buffer; -* :meth:`~tianshou.policy.BasePolicy.learn`: update policy with a given batch of data. -* :meth:`~tianshou.policy.BasePolicy.post_process_fn`: update the buffer with a given batch of data. 
-* :meth:`~tianshou.policy.BasePolicy.update`: the main interface for training. This function samples data from buffer, pre-process data (such as computing n-step return), learn with the data, and finally post-process the data (such as updating prioritized replay buffer); in short, ``process_fn -> learn -> post_process_fn``. +* :meth:`~tianshou.algorithm.BasePolicy.__init__`: initialize the policy, including copying the target network and so on; +* :meth:`~tianshou.algorithm.BasePolicy.forward`: compute action with given observation; +* :meth:`~tianshou.algorithm.BasePolicy.process_fn`: pre-process data from the replay buffer; +* :meth:`~tianshou.algorithm.BasePolicy.learn`: update policy with a given batch of data. +* :meth:`~tianshou.algorithm.BasePolicy.post_process_fn`: update the buffer with a given batch of data. +* :meth:`~tianshou.algorithm.BasePolicy.update`: the main interface for training. This function samples data from buffer, pre-process data (such as computing n-step return), learn with the data, and finally post-process the data (such as updating prioritized replay buffer); in short, ``process_fn -> learn -> post_process_fn``. .. _policy_state: @@ -245,7 +245,7 @@ During the training process, the policy has two main states: training state and The meaning of training and testing state is obvious: the agent interacts with environment, collects training data and performs update, that's training state; the testing state is to evaluate the performance of the current policy during training process. As for the collecting state, it is defined as interacting with environments and collecting training data into the buffer; -we define the updating state as performing a model update by :meth:`~tianshou.policy.BasePolicy.update` during training process. +we define the updating state as performing a model update by :meth:`~tianshou.algorithm.BasePolicy.update` during training process. 
In order to distinguish these states, you can check the policy state by ``policy.training`` and ``policy.updating``. The state setting is as follows: @@ -270,7 +270,7 @@ The ``forward`` function computes the action over given observations. The input The input batch is the environment data (e.g., observation, reward, done flag and info). It comes from either :meth:`~tianshou.data.Collector.collect` or :meth:`~tianshou.data.ReplayBuffer.sample`. The first dimension of all variables in the input ``batch`` should be equal to the batch-size. -The output is also a ``Batch`` which must contain "act" (action) and may contain "state" (hidden state of policy), "policy" (the intermediate result of policy which needs to save into the buffer, see :meth:`~tianshou.policy.BasePolicy.forward`), and some other algorithm-specific keys. +The output is also a ``Batch`` which must contain "act" (action) and may contain "state" (hidden state of policy), "policy" (the intermediate result of policy which needs to save into the buffer, see :meth:`~tianshou.algorithm.BasePolicy.forward`), and some other algorithm-specific keys. For example, if you try to use your policy to evaluate one episode (and don't want to use :meth:`~tianshou.data.Collector.collect`), use the following code-snippet: :: @@ -317,7 +317,7 @@ where :math:`\gamma` is the discount factor, :math:`\gamma \in [0, 1]`. Here is # update DQN policy agent.update(b_s, b_a, b_s_, b_r, b_d, b_ret) -Thus, we need a time-related interface for calculating the 2-step return. :meth:`~tianshou.policy.BasePolicy.process_fn` finishes this work by providing the replay buffer, the sample index, and the sample batch data. Since we store all the data in the order of time, you can simply compute the 2-step return as: +Thus, we need a time-related interface for calculating the 2-step return. :meth:`~tianshou.algorithm.BasePolicy.process_fn` finishes this work by providing the replay buffer, the sample index, and the sample batch data. 
Since we store all the data in the order of time, you can simply compute the 2-step return as: :: class DQN_2step(BasePolicy): @@ -337,7 +337,7 @@ Thus, we need a time-related interface for calculating the 2-step return. :meth: + self._gamma ** 2 * maxQ return batch -This code does not consider the done flag, so it may not work very well. It shows two ways to get :math:`s_{t + 2}` from the replay buffer easily in :meth:`~tianshou.policy.BasePolicy.process_fn`. +This code does not consider the done flag, so it may not work very well. It shows two ways to get :math:`s_{t + 2}` from the replay buffer easily in :meth:`~tianshou.algorithm.BasePolicy.process_fn`. For other method, you can check out :doc:`/03_api/policy/index`. We give the usage of policy class a high-level explanation in :ref:`pseudocode`. diff --git a/docs/01_tutorials/04_tictactoe.rst b/docs/01_tutorials/04_tictactoe.rst index ff7918e1f..b370c7728 100644 --- a/docs/01_tutorials/04_tictactoe.rst +++ b/docs/01_tutorials/04_tictactoe.rst @@ -122,13 +122,13 @@ Two Random Agents .. Figure:: ../_static/images/marl.png -Tianshou already provides some builtin classes for multi-agent learning. You can check out the API documentation for details. Here we use :class:`~tianshou.policy.MARLRandomPolicy` and :class:`~tianshou.policy.MultiAgentPolicyManager`. The figure on the right gives an intuitive explanation. +Tianshou already provides some builtin classes for multi-agent learning. You can check out the API documentation for details. Here we use :class:`~tianshou.algorithm.MARLRandomPolicy` and :class:`~tianshou.algorithm.MultiAgentPolicyManager`. The figure on the right gives an intuitive explanation. 
:: >>> from tianshou.data import Collector >>> from tianshou.env import DummyVectorEnv - >>> from tianshou.policy import RandomPolicy, MultiAgentPolicyManager + >>> from tianshou.algorithm import RandomPolicy, MultiAgentPolicyManager >>> >>> # agents should be wrapped into one policy, >>> # which is responsible for calling the acting agent correctly @@ -198,7 +198,7 @@ So let's start to train our Tic-Tac-Toe agent! First, import some required modul from tianshou.data import Collector, VectorReplayBuffer from tianshou.env import DummyVectorEnv from tianshou.env.pettingzoo_env import PettingZooEnv - from tianshou.policy import ( + from tianshou.algorithm import ( BasePolicy, DQNPolicy, MultiAgentPolicyManager, @@ -285,10 +285,10 @@ The explanation of each Tianshou class/function will be deferred to their first The following ``get_agents`` function returns agents and their optimizers from either constructing a new policy, or loading from disk, or using the pass-in arguments. For the models: - The action model we use is an instance of :class:`~tianshou.utils.net.common.Net`, essentially a multi-layer perceptron with the ReLU activation function; -- The network model is passed to a :class:`~tianshou.policy.DQNPolicy`, where actions are selected according to both the action mask and their Q-values; -- The opponent can be either a random agent :class:`~tianshou.policy.MARLRandomPolicy` that randomly chooses an action from legal actions, or it can be a pre-trained :class:`~tianshou.policy.DQNPolicy` allowing learned agents to play with themselves. +- The network model is passed to a :class:`~tianshou.algorithm.DQNPolicy`, where actions are selected according to both the action mask and their Q-values; +- The opponent can be either a random agent :class:`~tianshou.algorithm.MARLRandomPolicy` that randomly chooses an action from legal actions, or it can be a pre-trained :class:`~tianshou.algorithm.DQNPolicy` allowing learned agents to play with themselves. 
-Both agents are passed to :class:`~tianshou.policy.MultiAgentPolicyManager`, which is responsible to call the correct agent according to the ``agent_id`` in the observation. :class:`~tianshou.policy.MultiAgentPolicyManager` also dispatches data to each agent according to ``agent_id``, so that each agent seems to play with a virtual single-agent environment. +Both agents are passed to :class:`~tianshou.algorithm.MultiAgentPolicyManager`, which is responsible to call the correct agent according to the ``agent_id`` in the observation. :class:`~tianshou.algorithm.MultiAgentPolicyManager` also dispatches data to each agent according to ``agent_id``, so that each agent seems to play with a virtual single-agent environment. Here it is: :: diff --git a/docs/01_tutorials/07_cheatsheet.rst b/docs/01_tutorials/07_cheatsheet.rst index 51fece131..79391747d 100644 --- a/docs/01_tutorials/07_cheatsheet.rst +++ b/docs/01_tutorials/07_cheatsheet.rst @@ -23,7 +23,7 @@ See :ref:`build_the_network`. Build New Policy ---------------- -See :class:`~tianshou.policy.BasePolicy`. +See :class:`~tianshou.algorithm.BasePolicy`. .. 
_eval_policy: diff --git a/docs/02_notebooks/L0_overview.ipynb b/docs/02_notebooks/L0_overview.ipynb index a9bf617bc..b38ead76b 100644 --- a/docs/02_notebooks/L0_overview.ipynb +++ b/docs/02_notebooks/L0_overview.ipynb @@ -60,7 +60,7 @@ "\n", "from tianshou.data import Collector, CollectStats, VectorReplayBuffer\n", "from tianshou.env import DummyVectorEnv\n", - "from tianshou.policy import PPOPolicy\n", + "from tianshou.algorithm import PPOPolicy\n", "from tianshou.trainer import OnpolicyTrainer\n", "from tianshou.utils.net.common import ActorCritic, Net\n", "from tianshou.utils.net.discrete import Actor, Critic\n", diff --git a/docs/02_notebooks/L4_Policy.ipynb b/docs/02_notebooks/L4_Policy.ipynb index eed8ea344..2707e2869 100644 --- a/docs/02_notebooks/L4_Policy.ipynb +++ b/docs/02_notebooks/L4_Policy.ipynb @@ -66,8 +66,8 @@ " ObsBatchProtocol,\n", " RolloutBatchProtocol,\n", ")\n", - "from tianshou.policy import BasePolicy\n", - "from tianshou.policy.modelfree.pg import (\n", + "from tianshou.algorithm import BasePolicy\n", + "from tianshou.algorithm.modelfree.pg import (\n", " PGTrainingStats,\n", " TDistFnDiscrOrCont,\n", " TPGTrainingStats,\n", diff --git a/docs/02_notebooks/L5_Collector.ipynb b/docs/02_notebooks/L5_Collector.ipynb index d10df1666..04f259722 100644 --- a/docs/02_notebooks/L5_Collector.ipynb +++ b/docs/02_notebooks/L5_Collector.ipynb @@ -60,7 +60,7 @@ "\n", "from tianshou.data import Collector, CollectStats, VectorReplayBuffer\n", "from tianshou.env import DummyVectorEnv\n", - "from tianshou.policy import PGPolicy\n", + "from tianshou.algorithm import PGPolicy\n", "from tianshou.utils.net.common import Net\n", "from tianshou.utils.net.discrete import Actor" ] diff --git a/docs/02_notebooks/L6_Trainer.ipynb b/docs/02_notebooks/L6_Trainer.ipynb index c10cfcbe2..58fa3d40e 100644 --- a/docs/02_notebooks/L6_Trainer.ipynb +++ b/docs/02_notebooks/L6_Trainer.ipynb @@ -75,7 +75,7 @@ "\n", "from tianshou.data import Collector, CollectStats, 
VectorReplayBuffer\n", "from tianshou.env import DummyVectorEnv\n", - "from tianshou.policy import PGPolicy\n", + "from tianshou.algorithm import PGPolicy\n", "from tianshou.trainer import OnpolicyTrainer\n", "from tianshou.utils.net.common import Net\n", "from tianshou.utils.net.discrete import Actor\n", diff --git a/docs/02_notebooks/L7_Experiment.ipynb b/docs/02_notebooks/L7_Experiment.ipynb index 47e4cb0c9..c30fa08f7 100644 --- a/docs/02_notebooks/L7_Experiment.ipynb +++ b/docs/02_notebooks/L7_Experiment.ipynb @@ -73,7 +73,7 @@ "\n", "from tianshou.data import Collector, CollectStats, VectorReplayBuffer\n", "from tianshou.env import DummyVectorEnv\n", - "from tianshou.policy import PPOPolicy\n", + "from tianshou.algorithm import PPOPolicy\n", "from tianshou.trainer import OnpolicyTrainer\n", "from tianshou.utils.net.common import ActorCritic, Net\n", "from tianshou.utils.net.discrete import Actor, Critic\n", diff --git a/docs/index.rst b/docs/index.rst index c7c217759..4bfbbdd17 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -9,37 +9,37 @@ Welcome to Tianshou! **Tianshou** (`天授 `_) is a reinforcement learning platform based on pure PyTorch. Unlike existing reinforcement learning libraries, which are mainly based on TensorFlow, have many nested classes, unfriendly API, or slow-speed, Tianshou provides a fast-speed framework and pythonic API for building the deep reinforcement learning agent. 
The supported interface algorithms include: -* :class:`~tianshou.policy.DQNPolicy` `Deep Q-Network `_ -* :class:`~tianshou.policy.DQNPolicy` `Double DQN `_ -* :class:`~tianshou.policy.DQNPolicy` `Dueling DQN `_ -* :class:`~tianshou.policy.BranchingDQNPolicy` `Branching DQN `_ -* :class:`~tianshou.policy.C51Policy` `Categorical DQN `_ -* :class:`~tianshou.policy.RainbowPolicy` `Rainbow DQN `_ -* :class:`~tianshou.policy.QRDQNPolicy` `Quantile Regression DQN `_ -* :class:`~tianshou.policy.IQNPolicy` `Implicit Quantile Network `_ -* :class:`~tianshou.policy.FQFPolicy` `Fully-parameterized Quantile Function `_ -* :class:`~tianshou.policy.PGPolicy` `Policy Gradient `_ -* :class:`~tianshou.policy.NPGPolicy` `Natural Policy Gradient `_ -* :class:`~tianshou.policy.A2CPolicy` `Advantage Actor-Critic `_ -* :class:`~tianshou.policy.TRPOPolicy` `Trust Region Policy Optimization `_ -* :class:`~tianshou.policy.PPOPolicy` `Proximal Policy Optimization `_ -* :class:`~tianshou.policy.DDPGPolicy` `Deep Deterministic Policy Gradient `_ -* :class:`~tianshou.policy.TD3Policy` `Twin Delayed DDPG `_ -* :class:`~tianshou.policy.SACPolicy` `Soft Actor-Critic `_ -* :class:`~tianshou.policy.REDQPolicy` `Randomized Ensembled Double Q-Learning `_ -* :class:`~tianshou.policy.DiscreteSACPolicy` `Discrete Soft Actor-Critic `_ -* :class:`~tianshou.policy.ImitationPolicy` Imitation Learning -* :class:`~tianshou.policy.BCQPolicy` `Batch-Constrained deep Q-Learning `_ -* :class:`~tianshou.policy.CQLPolicy` `Conservative Q-Learning `_ -* :class:`~tianshou.policy.TD3BCPolicy` `Twin Delayed DDPG with Behavior Cloning `_ -* :class:`~tianshou.policy.DiscreteBCQPolicy` `Discrete Batch-Constrained deep Q-Learning `_ -* :class:`~tianshou.policy.DiscreteCQLPolicy` `Discrete Conservative Q-Learning `_ -* :class:`~tianshou.policy.DiscreteCRRPolicy` `Critic Regularized Regression `_ -* :class:`~tianshou.policy.GAILPolicy` `Generative Adversarial Imitation Learning `_ -* :class:`~tianshou.policy.PSRLPolicy` 
`Posterior Sampling Reinforcement Learning `_ -* :class:`~tianshou.policy.ICMPolicy` `Intrinsic Curiosity Module `_ +* :class:`~tianshou.algorithm.DQNPolicy` `Deep Q-Network `_ +* :class:`~tianshou.algorithm.DQNPolicy` `Double DQN `_ +* :class:`~tianshou.algorithm.DQNPolicy` `Dueling DQN `_ +* :class:`~tianshou.algorithm.BranchingDQNPolicy` `Branching DQN `_ +* :class:`~tianshou.algorithm.C51Policy` `Categorical DQN `_ +* :class:`~tianshou.algorithm.RainbowPolicy` `Rainbow DQN `_ +* :class:`~tianshou.algorithm.QRDQNPolicy` `Quantile Regression DQN `_ +* :class:`~tianshou.algorithm.IQNPolicy` `Implicit Quantile Network `_ +* :class:`~tianshou.algorithm.FQFPolicy` `Fully-parameterized Quantile Function `_ +* :class:`~tianshou.algorithm.PGPolicy` `Policy Gradient `_ +* :class:`~tianshou.algorithm.NPGPolicy` `Natural Policy Gradient `_ +* :class:`~tianshou.algorithm.A2CPolicy` `Advantage Actor-Critic `_ +* :class:`~tianshou.algorithm.TRPOPolicy` `Trust Region Policy Optimization `_ +* :class:`~tianshou.algorithm.PPOPolicy` `Proximal Policy Optimization `_ +* :class:`~tianshou.algorithm.DDPGPolicy` `Deep Deterministic Policy Gradient `_ +* :class:`~tianshou.algorithm.TD3Policy` `Twin Delayed DDPG `_ +* :class:`~tianshou.algorithm.SACPolicy` `Soft Actor-Critic `_ +* :class:`~tianshou.algorithm.REDQPolicy` `Randomized Ensembled Double Q-Learning `_ +* :class:`~tianshou.algorithm.DiscreteSACPolicy` `Discrete Soft Actor-Critic `_ +* :class:`~tianshou.algorithm.ImitationPolicy` Imitation Learning +* :class:`~tianshou.algorithm.BCQPolicy` `Batch-Constrained deep Q-Learning `_ +* :class:`~tianshou.algorithm.CQLPolicy` `Conservative Q-Learning `_ +* :class:`~tianshou.algorithm.TD3BCPolicy` `Twin Delayed DDPG with Behavior Cloning `_ +* :class:`~tianshou.algorithm.DiscreteBCQPolicy` `Discrete Batch-Constrained deep Q-Learning `_ +* :class:`~tianshou.algorithm.DiscreteCQLPolicy` `Discrete Conservative Q-Learning `_ +* :class:`~tianshou.algorithm.DiscreteCRRPolicy` `Critic 
Regularized Regression `_ +* :class:`~tianshou.algorithm.GAILPolicy` `Generative Adversarial Imitation Learning `_ +* :class:`~tianshou.algorithm.PSRLPolicy` `Posterior Sampling Reinforcement Learning `_ +* :class:`~tianshou.algorithm.ICMPolicy` `Intrinsic Curiosity Module `_ * :class:`~tianshou.data.PrioritizedReplayBuffer` `Prioritized Experience Replay `_ -* :meth:`~tianshou.policy.BasePolicy.compute_episodic_return` `Generalized Advantage Estimator `_ +* :meth:`~tianshou.algorithm.BasePolicy.compute_episodic_return` `Generalized Advantage Estimator `_ * :class:`~tianshou.data.HERReplayBuffer` `Hindsight Experience Replay `_ Here is Tianshou's other features: @@ -51,7 +51,7 @@ Here is Tianshou's other features: * Support recurrent state representation in actor network and critic network (RNN-style training for POMDP): :ref:`rnn_training` * Support any type of environment state/action (e.g. a dict, a self-defined class, ...): :ref:`self_defined_env` * Support :ref:`customize_training` -* Support n-step returns estimation :meth:`~tianshou.policy.BasePolicy.compute_nstep_return` and prioritized experience replay :class:`~tianshou.data.PrioritizedReplayBuffer` for all Q-learning based algorithms; GAE, nstep and PER are very fast thanks to numba jit function and vectorized numpy operation +* Support n-step returns estimation :meth:`~tianshou.algorithm.BasePolicy.compute_nstep_return` and prioritized experience replay :class:`~tianshou.data.PrioritizedReplayBuffer` for all Q-learning based algorithms; GAE, nstep and PER are very fast thanks to numba jit function and vectorized numpy operation * Support :doc:`/01_tutorials/04_tictactoe` * Support both `TensorBoard `_ and `W&B `_ log tools * Support multi-GPU training :ref:`multi_gpu` diff --git a/examples/atari/atari_c51.py b/examples/atari/atari_c51.py index c9106570c..763d7b107 100644 --- a/examples/atari/atari_c51.py +++ b/examples/atari/atari_c51.py @@ -11,10 +11,10 @@ from tianshou.env.atari.atari_network import 
C51Net from tianshou.env.atari.atari_wrapper import make_atari_env from tianshou.highlevel.logger import LoggerFactoryDefault -from tianshou.policy import C51 -from tianshou.policy.base import Algorithm -from tianshou.policy.modelfree.c51 import C51Policy -from tianshou.policy.optim import AdamOptimizerFactory +from tianshou.algorithm import C51 +from tianshou.algorithm.base import Algorithm +from tianshou.algorithm.modelfree.c51 import C51Policy +from tianshou.algorithm.optim import AdamOptimizerFactory from tianshou.trainer.base import OffPolicyTrainerParams diff --git a/examples/atari/atari_dqn.py b/examples/atari/atari_dqn.py index 97a6d1299..5957e4ee5 100644 --- a/examples/atari/atari_dqn.py +++ b/examples/atari/atari_dqn.py @@ -11,11 +11,11 @@ from tianshou.env.atari.atari_network import DQNet from tianshou.env.atari.atari_wrapper import make_atari_env from tianshou.highlevel.logger import LoggerFactoryDefault -from tianshou.policy import DQN -from tianshou.policy.base import Algorithm -from tianshou.policy.modelbased.icm import ICMOffPolicyWrapper -from tianshou.policy.modelfree.dqn import DiscreteQLearningPolicy -from tianshou.policy.optim import AdamOptimizerFactory +from tianshou.algorithm import DQN +from tianshou.algorithm.base import Algorithm +from tianshou.algorithm.modelbased.icm import ICMOffPolicyWrapper +from tianshou.algorithm.modelfree.dqn import DiscreteQLearningPolicy +from tianshou.algorithm.optim import AdamOptimizerFactory from tianshou.trainer import OffPolicyTrainerParams from tianshou.utils.net.discrete import IntrinsicCuriosityModule diff --git a/examples/atari/atari_fqf.py b/examples/atari/atari_fqf.py index 1a466c651..327f7a51c 100644 --- a/examples/atari/atari_fqf.py +++ b/examples/atari/atari_fqf.py @@ -11,10 +11,10 @@ from tianshou.env.atari.atari_network import DQNet from tianshou.env.atari.atari_wrapper import make_atari_env from tianshou.highlevel.logger import LoggerFactoryDefault -from tianshou.policy import FQF -from 
tianshou.policy.base import Algorithm -from tianshou.policy.modelfree.fqf import FQFPolicy -from tianshou.policy.optim import AdamOptimizerFactory, RMSpropOptimizerFactory +from tianshou.algorithm import FQF +from tianshou.algorithm.base import Algorithm +from tianshou.algorithm.modelfree.fqf import FQFPolicy +from tianshou.algorithm.optim import AdamOptimizerFactory, RMSpropOptimizerFactory from tianshou.trainer import OffPolicyTrainerParams from tianshou.utils.net.discrete import FractionProposalNetwork, FullQuantileFunction diff --git a/examples/atari/atari_iqn.py b/examples/atari/atari_iqn.py index d4e0fe005..f5486ccd9 100644 --- a/examples/atari/atari_iqn.py +++ b/examples/atari/atari_iqn.py @@ -11,10 +11,10 @@ from tianshou.env.atari.atari_network import DQNet from tianshou.env.atari.atari_wrapper import make_atari_env from tianshou.highlevel.logger import LoggerFactoryDefault -from tianshou.policy import IQN -from tianshou.policy.base import Algorithm -from tianshou.policy.modelfree.iqn import IQNPolicy -from tianshou.policy.optim import AdamOptimizerFactory +from tianshou.algorithm import IQN +from tianshou.algorithm.base import Algorithm +from tianshou.algorithm.modelfree.iqn import IQNPolicy +from tianshou.algorithm.optim import AdamOptimizerFactory from tianshou.trainer import OffPolicyTrainerParams from tianshou.utils.net.discrete import ImplicitQuantileNetwork diff --git a/examples/atari/atari_ppo.py b/examples/atari/atari_ppo.py index 4060f55e7..213ad31b2 100644 --- a/examples/atari/atari_ppo.py +++ b/examples/atari/atari_ppo.py @@ -18,11 +18,11 @@ ) from tianshou.env.atari.atari_wrapper import make_atari_env from tianshou.highlevel.logger import LoggerFactoryDefault -from tianshou.policy import PPO -from tianshou.policy.base import Algorithm -from tianshou.policy.modelbased.icm import ICMOnPolicyWrapper -from tianshou.policy.modelfree.pg import DiscreteActorPolicy -from tianshou.policy.optim import AdamOptimizerFactory, LRSchedulerFactoryLinear +from 
tianshou.algorithm import PPO +from tianshou.algorithm.base import Algorithm +from tianshou.algorithm.modelbased.icm import ICMOnPolicyWrapper +from tianshou.algorithm.modelfree.pg import DiscreteActorPolicy +from tianshou.algorithm.optim import AdamOptimizerFactory, LRSchedulerFactoryLinear from tianshou.trainer import OnPolicyTrainerParams from tianshou.utils.net.discrete import ( DiscreteActor, diff --git a/examples/atari/atari_qrdqn.py b/examples/atari/atari_qrdqn.py index 7b0c61e7c..049fe42c5 100644 --- a/examples/atari/atari_qrdqn.py +++ b/examples/atari/atari_qrdqn.py @@ -11,10 +11,10 @@ from tianshou.env.atari.atari_network import QRDQNet from tianshou.env.atari.atari_wrapper import make_atari_env from tianshou.highlevel.logger import LoggerFactoryDefault -from tianshou.policy import QRDQN -from tianshou.policy.base import Algorithm -from tianshou.policy.modelfree.qrdqn import QRDQNPolicy -from tianshou.policy.optim import AdamOptimizerFactory +from tianshou.algorithm import QRDQN +from tianshou.algorithm.base import Algorithm +from tianshou.algorithm.modelfree.qrdqn import QRDQNPolicy +from tianshou.algorithm.optim import AdamOptimizerFactory from tianshou.trainer import OffPolicyTrainerParams diff --git a/examples/atari/atari_rainbow.py b/examples/atari/atari_rainbow.py index 4c885f0cf..5692d81ea 100644 --- a/examples/atari/atari_rainbow.py +++ b/examples/atari/atari_rainbow.py @@ -16,10 +16,10 @@ from tianshou.env.atari.atari_network import Rainbow from tianshou.env.atari.atari_wrapper import make_atari_env from tianshou.highlevel.logger import LoggerFactoryDefault -from tianshou.policy import C51, RainbowDQN -from tianshou.policy.base import Algorithm -from tianshou.policy.modelfree.c51 import C51Policy -from tianshou.policy.optim import AdamOptimizerFactory +from tianshou.algorithm import C51, RainbowDQN +from tianshou.algorithm.base import Algorithm +from tianshou.algorithm.modelfree.c51 import C51Policy +from tianshou.algorithm.optim import 
AdamOptimizerFactory from tianshou.trainer import OffPolicyTrainerParams diff --git a/examples/atari/atari_sac.py b/examples/atari/atari_sac.py index 7c44c7ad6..249e31c02 100644 --- a/examples/atari/atari_sac.py +++ b/examples/atari/atari_sac.py @@ -11,11 +11,11 @@ from tianshou.env.atari.atari_network import DQNet from tianshou.env.atari.atari_wrapper import make_atari_env from tianshou.highlevel.logger import LoggerFactoryDefault -from tianshou.policy import DiscreteSAC, ICMOffPolicyWrapper -from tianshou.policy.base import Algorithm -from tianshou.policy.modelfree.discrete_sac import DiscreteSACPolicy -from tianshou.policy.modelfree.sac import AutoAlpha -from tianshou.policy.optim import AdamOptimizerFactory +from tianshou.algorithm import DiscreteSAC, ICMOffPolicyWrapper +from tianshou.algorithm.base import Algorithm +from tianshou.algorithm.modelfree.discrete_sac import DiscreteSACPolicy +from tianshou.algorithm.modelfree.sac import AutoAlpha +from tianshou.algorithm.optim import AdamOptimizerFactory from tianshou.trainer import OffPolicyTrainerParams from tianshou.utils.net.discrete import ( DiscreteActor, diff --git a/examples/box2d/acrobot_dualdqn.py b/examples/box2d/acrobot_dualdqn.py index 1f1f16c1c..f70d01735 100644 --- a/examples/box2d/acrobot_dualdqn.py +++ b/examples/box2d/acrobot_dualdqn.py @@ -9,10 +9,10 @@ from tianshou.data import Collector, CollectStats, VectorReplayBuffer from tianshou.env import DummyVectorEnv -from tianshou.policy import DQN -from tianshou.policy.base import Algorithm -from tianshou.policy.modelfree.dqn import DiscreteQLearningPolicy -from tianshou.policy.optim import AdamOptimizerFactory +from tianshou.algorithm import DQN +from tianshou.algorithm.base import Algorithm +from tianshou.algorithm.modelfree.dqn import DiscreteQLearningPolicy +from tianshou.algorithm.optim import AdamOptimizerFactory from tianshou.trainer import OffPolicyTrainerParams from tianshou.utils import TensorboardLogger from tianshou.utils.net.common 
import Net diff --git a/examples/box2d/bipedal_bdq.py b/examples/box2d/bipedal_bdq.py index d53999193..e9230bc3e 100644 --- a/examples/box2d/bipedal_bdq.py +++ b/examples/box2d/bipedal_bdq.py @@ -10,10 +10,10 @@ from tianshou.data import Collector, CollectStats, VectorReplayBuffer from tianshou.env import ContinuousToDiscrete, SubprocVectorEnv -from tianshou.policy import BDQN -from tianshou.policy.base import Algorithm -from tianshou.policy.modelfree.bdqn import BDQNPolicy -from tianshou.policy.optim import AdamOptimizerFactory +from tianshou.algorithm import BDQN +from tianshou.algorithm.base import Algorithm +from tianshou.algorithm.modelfree.bdqn import BDQNPolicy +from tianshou.algorithm.optim import AdamOptimizerFactory from tianshou.trainer import OffPolicyTrainerParams from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import BranchingNet diff --git a/examples/box2d/bipedal_hardcore_sac.py b/examples/box2d/bipedal_hardcore_sac.py index 72aae0b34..213f1b08f 100644 --- a/examples/box2d/bipedal_hardcore_sac.py +++ b/examples/box2d/bipedal_hardcore_sac.py @@ -11,10 +11,10 @@ from tianshou.data import Collector, CollectStats, VectorReplayBuffer from tianshou.env import SubprocVectorEnv -from tianshou.policy import SAC -from tianshou.policy.base import Algorithm -from tianshou.policy.modelfree.sac import AutoAlpha, SACPolicy -from tianshou.policy.optim import AdamOptimizerFactory +from tianshou.algorithm import SAC +from tianshou.algorithm.base import Algorithm +from tianshou.algorithm.modelfree.sac import AutoAlpha, SACPolicy +from tianshou.algorithm.optim import AdamOptimizerFactory from tianshou.trainer import OffPolicyTrainerParams from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Net diff --git a/examples/box2d/lunarlander_dqn.py b/examples/box2d/lunarlander_dqn.py index 5744451d1..bbfded999 100644 --- a/examples/box2d/lunarlander_dqn.py +++ b/examples/box2d/lunarlander_dqn.py @@ -9,10 +9,10 @@ from 
tianshou.data import Collector, CollectStats, VectorReplayBuffer from tianshou.env import DummyVectorEnv, SubprocVectorEnv -from tianshou.policy import DQN -from tianshou.policy.base import Algorithm -from tianshou.policy.modelfree.dqn import DiscreteQLearningPolicy -from tianshou.policy.optim import AdamOptimizerFactory +from tianshou.algorithm import DQN +from tianshou.algorithm.base import Algorithm +from tianshou.algorithm.modelfree.dqn import DiscreteQLearningPolicy +from tianshou.algorithm.optim import AdamOptimizerFactory from tianshou.trainer import OffPolicyTrainerParams from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Net diff --git a/examples/box2d/mcc_sac.py b/examples/box2d/mcc_sac.py index 03eeec085..5f242bc79 100644 --- a/examples/box2d/mcc_sac.py +++ b/examples/box2d/mcc_sac.py @@ -10,10 +10,10 @@ from tianshou.data import Collector, CollectStats, VectorReplayBuffer from tianshou.env import DummyVectorEnv from tianshou.exploration import OUNoise -from tianshou.policy import SAC -from tianshou.policy.base import Algorithm -from tianshou.policy.modelfree.sac import AutoAlpha, SACPolicy -from tianshou.policy.optim import AdamOptimizerFactory +from tianshou.algorithm import SAC +from tianshou.algorithm.base import Algorithm +from tianshou.algorithm.modelfree.sac import AutoAlpha, SACPolicy +from tianshou.algorithm.optim import AdamOptimizerFactory from tianshou.trainer import OffPolicyTrainerParams from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Net diff --git a/examples/discrete/discrete_dqn.py b/examples/discrete/discrete_dqn.py index 3fbfb7801..f98520143 100644 --- a/examples/discrete/discrete_dqn.py +++ b/examples/discrete/discrete_dqn.py @@ -3,8 +3,8 @@ import tianshou as ts from tianshou.data import CollectStats -from tianshou.policy.modelfree.dqn import DiscreteQLearningPolicy -from tianshou.policy.optim import AdamOptimizerFactory +from tianshou.algorithm.modelfree.dqn import 
DiscreteQLearningPolicy +from tianshou.algorithm.optim import AdamOptimizerFactory from tianshou.trainer.base import OffPolicyTrainerParams from tianshou.utils.space_info import SpaceInfo diff --git a/examples/inverse/irl_gail.py b/examples/inverse/irl_gail.py index 56ed5b445..9049a691b 100644 --- a/examples/inverse/irl_gail.py +++ b/examples/inverse/irl_gail.py @@ -23,10 +23,10 @@ ) from tianshou.data.types import RolloutBatchProtocol from tianshou.env import SubprocVectorEnv, VectorEnvNormObs -from tianshou.policy import GAIL -from tianshou.policy.base import Algorithm -from tianshou.policy.modelfree.pg import ActorPolicyProbabilistic -from tianshou.policy.optim import AdamOptimizerFactory, LRSchedulerFactoryLinear +from tianshou.algorithm import GAIL +from tianshou.algorithm.base import Algorithm +from tianshou.algorithm.modelfree.pg import ActorPolicyProbabilistic +from tianshou.algorithm.optim import AdamOptimizerFactory, LRSchedulerFactoryLinear from tianshou.trainer import OnPolicyTrainerParams from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Net diff --git a/examples/mujoco/fetch_her_ddpg.py b/examples/mujoco/fetch_her_ddpg.py index 061edab8f..acd652662 100644 --- a/examples/mujoco/fetch_her_ddpg.py +++ b/examples/mujoco/fetch_her_ddpg.py @@ -22,10 +22,10 @@ from tianshou.env.venvs import BaseVectorEnv from tianshou.exploration import GaussianNoise from tianshou.highlevel.logger import LoggerFactoryDefault -from tianshou.policy import DDPG -from tianshou.policy.base import Algorithm -from tianshou.policy.modelfree.ddpg import ContinuousDeterministicPolicy -from tianshou.policy.optim import AdamOptimizerFactory +from tianshou.algorithm import DDPG +from tianshou.algorithm.base import Algorithm +from tianshou.algorithm.modelfree.ddpg import ContinuousDeterministicPolicy +from tianshou.algorithm.optim import AdamOptimizerFactory from tianshou.trainer import OffPolicyTrainerParams from tianshou.utils.net.common import Net, 
get_dict_state_decorator from tianshou.utils.net.continuous import ContinuousActorDeterministic, ContinuousCritic diff --git a/examples/mujoco/mujoco_a2c.py b/examples/mujoco/mujoco_a2c.py index a7f2a6eac..f4acfb1d9 100755 --- a/examples/mujoco/mujoco_a2c.py +++ b/examples/mujoco/mujoco_a2c.py @@ -13,10 +13,10 @@ from tianshou.data import Collector, CollectStats, ReplayBuffer, VectorReplayBuffer from tianshou.highlevel.logger import LoggerFactoryDefault -from tianshou.policy import A2C -from tianshou.policy.base import Algorithm -from tianshou.policy.modelfree.pg import ActorPolicyProbabilistic -from tianshou.policy.optim import LRSchedulerFactoryLinear, RMSpropOptimizerFactory +from tianshou.algorithm import A2C +from tianshou.algorithm.base import Algorithm +from tianshou.algorithm.modelfree.pg import ActorPolicyProbabilistic +from tianshou.algorithm.optim import LRSchedulerFactoryLinear, RMSpropOptimizerFactory from tianshou.trainer import OnPolicyTrainerParams from tianshou.utils.net.common import ActorCritic, Net from tianshou.utils.net.continuous import ContinuousActorProbabilistic, ContinuousCritic diff --git a/examples/mujoco/mujoco_ddpg.py b/examples/mujoco/mujoco_ddpg.py index 56f1a233f..471616174 100755 --- a/examples/mujoco/mujoco_ddpg.py +++ b/examples/mujoco/mujoco_ddpg.py @@ -12,10 +12,10 @@ from tianshou.data import Collector, CollectStats, ReplayBuffer, VectorReplayBuffer from tianshou.exploration import GaussianNoise from tianshou.highlevel.logger import LoggerFactoryDefault -from tianshou.policy import DDPG -from tianshou.policy.base import Algorithm -from tianshou.policy.modelfree.ddpg import ContinuousDeterministicPolicy -from tianshou.policy.optim import AdamOptimizerFactory +from tianshou.algorithm import DDPG +from tianshou.algorithm.base import Algorithm +from tianshou.algorithm.modelfree.ddpg import ContinuousDeterministicPolicy +from tianshou.algorithm.optim import AdamOptimizerFactory from tianshou.trainer import OffPolicyTrainerParams 
from tianshou.utils.net.common import Net from tianshou.utils.net.continuous import ContinuousActorDeterministic, ContinuousCritic diff --git a/examples/mujoco/mujoco_npg.py b/examples/mujoco/mujoco_npg.py index 86f1d30df..a64891a3f 100755 --- a/examples/mujoco/mujoco_npg.py +++ b/examples/mujoco/mujoco_npg.py @@ -13,10 +13,10 @@ from tianshou.data import Collector, CollectStats, ReplayBuffer, VectorReplayBuffer from tianshou.highlevel.logger import LoggerFactoryDefault -from tianshou.policy import NPG -from tianshou.policy.base import Algorithm -from tianshou.policy.modelfree.pg import ActorPolicyProbabilistic -from tianshou.policy.optim import AdamOptimizerFactory, LRSchedulerFactoryLinear +from tianshou.algorithm import NPG +from tianshou.algorithm.base import Algorithm +from tianshou.algorithm.modelfree.pg import ActorPolicyProbabilistic +from tianshou.algorithm.optim import AdamOptimizerFactory, LRSchedulerFactoryLinear from tianshou.trainer import OnPolicyTrainerParams from tianshou.utils.net.common import Net from tianshou.utils.net.continuous import ContinuousActorProbabilistic, ContinuousCritic diff --git a/examples/mujoco/mujoco_ppo.py b/examples/mujoco/mujoco_ppo.py index ffbaa1d3a..57bb14011 100755 --- a/examples/mujoco/mujoco_ppo.py +++ b/examples/mujoco/mujoco_ppo.py @@ -13,10 +13,10 @@ from tianshou.data import Collector, CollectStats, ReplayBuffer, VectorReplayBuffer from tianshou.highlevel.logger import LoggerFactoryDefault -from tianshou.policy import PPO -from tianshou.policy.base import Algorithm -from tianshou.policy.modelfree.pg import ActorPolicyProbabilistic -from tianshou.policy.optim import AdamOptimizerFactory, LRSchedulerFactoryLinear +from tianshou.algorithm import PPO +from tianshou.algorithm.base import Algorithm +from tianshou.algorithm.modelfree.pg import ActorPolicyProbabilistic +from tianshou.algorithm.optim import AdamOptimizerFactory, LRSchedulerFactoryLinear from tianshou.trainer import OnPolicyTrainerParams from 
tianshou.utils.net.common import ActorCritic, Net from tianshou.utils.net.continuous import ContinuousActorProbabilistic, ContinuousCritic diff --git a/examples/mujoco/mujoco_redq.py b/examples/mujoco/mujoco_redq.py index 09fdadad3..88337eeb4 100755 --- a/examples/mujoco/mujoco_redq.py +++ b/examples/mujoco/mujoco_redq.py @@ -11,11 +11,11 @@ from tianshou.data import Collector, CollectStats, ReplayBuffer, VectorReplayBuffer from tianshou.highlevel.logger import LoggerFactoryDefault -from tianshou.policy import REDQ -from tianshou.policy.base import Algorithm -from tianshou.policy.modelfree.redq import REDQPolicy -from tianshou.policy.modelfree.sac import AutoAlpha -from tianshou.policy.optim import AdamOptimizerFactory +from tianshou.algorithm import REDQ +from tianshou.algorithm.base import Algorithm +from tianshou.algorithm.modelfree.redq import REDQPolicy +from tianshou.algorithm.modelfree.sac import AutoAlpha +from tianshou.algorithm.optim import AdamOptimizerFactory from tianshou.trainer import OffPolicyTrainerParams from tianshou.utils.net.common import EnsembleLinear, Net from tianshou.utils.net.continuous import ContinuousActorProbabilistic, ContinuousCritic diff --git a/examples/mujoco/mujoco_reinforce.py b/examples/mujoco/mujoco_reinforce.py index 5cbb77ef9..0718a6002 100755 --- a/examples/mujoco/mujoco_reinforce.py +++ b/examples/mujoco/mujoco_reinforce.py @@ -13,10 +13,10 @@ from tianshou.data import Collector, CollectStats, ReplayBuffer, VectorReplayBuffer from tianshou.highlevel.logger import LoggerFactoryDefault -from tianshou.policy import Reinforce -from tianshou.policy.base import Algorithm -from tianshou.policy.modelfree.pg import ActorPolicyProbabilistic -from tianshou.policy.optim import AdamOptimizerFactory, LRSchedulerFactoryLinear +from tianshou.algorithm import Reinforce +from tianshou.algorithm.base import Algorithm +from tianshou.algorithm.modelfree.pg import ActorPolicyProbabilistic +from tianshou.algorithm.optim import 
AdamOptimizerFactory, LRSchedulerFactoryLinear from tianshou.trainer import OnPolicyTrainerParams from tianshou.utils.net.common import Net from tianshou.utils.net.continuous import ContinuousActorProbabilistic diff --git a/examples/mujoco/mujoco_sac.py b/examples/mujoco/mujoco_sac.py index 685276af6..d301cc3e3 100755 --- a/examples/mujoco/mujoco_sac.py +++ b/examples/mujoco/mujoco_sac.py @@ -11,10 +11,10 @@ from tianshou.data import Collector, CollectStats, ReplayBuffer, VectorReplayBuffer from tianshou.highlevel.logger import LoggerFactoryDefault -from tianshou.policy import SAC -from tianshou.policy.base import Algorithm -from tianshou.policy.modelfree.sac import AutoAlpha, SACPolicy -from tianshou.policy.optim import AdamOptimizerFactory +from tianshou.algorithm import SAC +from tianshou.algorithm.base import Algorithm +from tianshou.algorithm.modelfree.sac import AutoAlpha, SACPolicy +from tianshou.algorithm.optim import AdamOptimizerFactory from tianshou.trainer import OffPolicyTrainerParams from tianshou.utils.net.common import Net from tianshou.utils.net.continuous import ContinuousActorProbabilistic, ContinuousCritic diff --git a/examples/mujoco/mujoco_td3.py b/examples/mujoco/mujoco_td3.py index 06d974535..379bcc34d 100755 --- a/examples/mujoco/mujoco_td3.py +++ b/examples/mujoco/mujoco_td3.py @@ -12,10 +12,10 @@ from tianshou.data import Collector, CollectStats, ReplayBuffer, VectorReplayBuffer from tianshou.exploration import GaussianNoise from tianshou.highlevel.logger import LoggerFactoryDefault -from tianshou.policy import TD3 -from tianshou.policy.base import Algorithm -from tianshou.policy.modelfree.ddpg import ContinuousDeterministicPolicy -from tianshou.policy.optim import AdamOptimizerFactory +from tianshou.algorithm import TD3 +from tianshou.algorithm.base import Algorithm +from tianshou.algorithm.modelfree.ddpg import ContinuousDeterministicPolicy +from tianshou.algorithm.optim import AdamOptimizerFactory from tianshou.trainer import 
OffPolicyTrainerParams from tianshou.utils.net.common import Net from tianshou.utils.net.continuous import ContinuousActorDeterministic, ContinuousCritic diff --git a/examples/mujoco/mujoco_trpo.py b/examples/mujoco/mujoco_trpo.py index 6f291cf47..63958a761 100755 --- a/examples/mujoco/mujoco_trpo.py +++ b/examples/mujoco/mujoco_trpo.py @@ -13,10 +13,10 @@ from tianshou.data import Collector, CollectStats, ReplayBuffer, VectorReplayBuffer from tianshou.highlevel.logger import LoggerFactoryDefault -from tianshou.policy import TRPO -from tianshou.policy.base import Algorithm -from tianshou.policy.modelfree.pg import ActorPolicyProbabilistic -from tianshou.policy.optim import AdamOptimizerFactory, LRSchedulerFactoryLinear +from tianshou.algorithm import TRPO +from tianshou.algorithm.base import Algorithm +from tianshou.algorithm.modelfree.pg import ActorPolicyProbabilistic +from tianshou.algorithm.optim import AdamOptimizerFactory, LRSchedulerFactoryLinear from tianshou.trainer import OnPolicyTrainerParams from tianshou.utils.net.common import Net from tianshou.utils.net.continuous import ContinuousActorProbabilistic, ContinuousCritic diff --git a/examples/offline/atari_bcq.py b/examples/offline/atari_bcq.py index 5ef5aa59f..0f71ca265 100644 --- a/examples/offline/atari_bcq.py +++ b/examples/offline/atari_bcq.py @@ -16,10 +16,10 @@ from tianshou.env.atari.atari_network import DQNet from tianshou.env.atari.atari_wrapper import make_atari_env from tianshou.highlevel.logger import LoggerFactoryDefault -from tianshou.policy import DiscreteBCQ -from tianshou.policy.base import Algorithm -from tianshou.policy.imitation.discrete_bcq import DiscreteBCQPolicy -from tianshou.policy.optim import AdamOptimizerFactory +from tianshou.algorithm import DiscreteBCQ +from tianshou.algorithm.base import Algorithm +from tianshou.algorithm.imitation.discrete_bcq import DiscreteBCQPolicy +from tianshou.algorithm.optim import AdamOptimizerFactory from tianshou.trainer import 
OfflineTrainerParams from tianshou.utils.net.discrete import DiscreteActor diff --git a/examples/offline/atari_cql.py b/examples/offline/atari_cql.py index 937e3d46e..9cdaef5bc 100644 --- a/examples/offline/atari_cql.py +++ b/examples/offline/atari_cql.py @@ -17,10 +17,10 @@ from tianshou.env.atari.atari_network import QRDQNet from tianshou.env.atari.atari_wrapper import make_atari_env from tianshou.highlevel.logger import LoggerFactoryDefault -from tianshou.policy import DiscreteCQL -from tianshou.policy.base import Algorithm -from tianshou.policy.modelfree.qrdqn import QRDQNPolicy -from tianshou.policy.optim import AdamOptimizerFactory +from tianshou.algorithm import DiscreteCQL +from tianshou.algorithm.base import Algorithm +from tianshou.algorithm.modelfree.qrdqn import QRDQNPolicy +from tianshou.algorithm.optim import AdamOptimizerFactory from tianshou.trainer import OfflineTrainerParams from tianshou.utils.space_info import SpaceInfo diff --git a/examples/offline/atari_crr.py b/examples/offline/atari_crr.py index 9447e90dc..2e303b632 100644 --- a/examples/offline/atari_crr.py +++ b/examples/offline/atari_crr.py @@ -16,10 +16,10 @@ from tianshou.env.atari.atari_network import DQNet from tianshou.env.atari.atari_wrapper import make_atari_env from tianshou.highlevel.logger import LoggerFactoryDefault -from tianshou.policy import DiscreteCRR -from tianshou.policy.base import Algorithm -from tianshou.policy.modelfree.pg import DiscreteActorPolicy -from tianshou.policy.optim import AdamOptimizerFactory +from tianshou.algorithm import DiscreteCRR +from tianshou.algorithm.base import Algorithm +from tianshou.algorithm.modelfree.pg import DiscreteActorPolicy +from tianshou.algorithm.optim import AdamOptimizerFactory from tianshou.trainer import OfflineTrainerParams from tianshou.utils.net.discrete import DiscreteActor, DiscreteCritic from tianshou.utils.space_info import SpaceInfo diff --git a/examples/offline/atari_il.py b/examples/offline/atari_il.py index 
885267efa..4f0fd3326 100644 --- a/examples/offline/atari_il.py +++ b/examples/offline/atari_il.py @@ -15,9 +15,9 @@ from tianshou.env.atari.atari_network import DQNet from tianshou.env.atari.atari_wrapper import make_atari_env from tianshou.highlevel.logger import LoggerFactoryDefault -from tianshou.policy.base import Algorithm -from tianshou.policy.imitation.base import ImitationPolicy, OfflineImitationLearning -from tianshou.policy.optim import AdamOptimizerFactory +from tianshou.algorithm.base import Algorithm +from tianshou.algorithm.imitation.base import ImitationPolicy, OfflineImitationLearning +from tianshou.algorithm.optim import AdamOptimizerFactory from tianshou.trainer import OfflineTrainerParams from tianshou.utils.space_info import SpaceInfo diff --git a/examples/offline/d4rl_bcq.py b/examples/offline/d4rl_bcq.py index 252d915e8..9b151eb74 100644 --- a/examples/offline/d4rl_bcq.py +++ b/examples/offline/d4rl_bcq.py @@ -13,10 +13,10 @@ from examples.offline.utils import load_buffer_d4rl from tianshou.data import Collector, CollectStats from tianshou.env import SubprocVectorEnv -from tianshou.policy import BCQ -from tianshou.policy.base import Algorithm -from tianshou.policy.imitation.bcq import BCQPolicy -from tianshou.policy.optim import AdamOptimizerFactory +from tianshou.algorithm import BCQ +from tianshou.algorithm.base import Algorithm +from tianshou.algorithm.imitation.bcq import BCQPolicy +from tianshou.algorithm.optim import AdamOptimizerFactory from tianshou.trainer import OfflineTrainerParams from tianshou.utils import TensorboardLogger, WandbLogger from tianshou.utils.net.common import MLP, Net diff --git a/examples/offline/d4rl_cql.py b/examples/offline/d4rl_cql.py index b9c3d17b9..bbb57f5f7 100644 --- a/examples/offline/d4rl_cql.py +++ b/examples/offline/d4rl_cql.py @@ -13,10 +13,10 @@ from examples.offline.utils import load_buffer_d4rl from tianshou.data import Collector, CollectStats from tianshou.env import SubprocVectorEnv -from 
tianshou.policy import CQL -from tianshou.policy.base import Algorithm -from tianshou.policy.modelfree.sac import AutoAlpha, SACPolicy -from tianshou.policy.optim import AdamOptimizerFactory +from tianshou.algorithm import CQL +from tianshou.algorithm.base import Algorithm +from tianshou.algorithm.modelfree.sac import AutoAlpha, SACPolicy +from tianshou.algorithm.optim import AdamOptimizerFactory from tianshou.trainer import OfflineTrainerParams from tianshou.utils import TensorboardLogger, WandbLogger from tianshou.utils.net.common import Net diff --git a/examples/offline/d4rl_il.py b/examples/offline/d4rl_il.py index 8a4a9c520..b6ec5cf0f 100644 --- a/examples/offline/d4rl_il.py +++ b/examples/offline/d4rl_il.py @@ -13,9 +13,9 @@ from examples.offline.utils import load_buffer_d4rl from tianshou.data import Collector, CollectStats from tianshou.env import SubprocVectorEnv -from tianshou.policy.base import Algorithm -from tianshou.policy.imitation.base import ImitationPolicy, OfflineImitationLearning -from tianshou.policy.optim import AdamOptimizerFactory +from tianshou.algorithm.base import Algorithm +from tianshou.algorithm.imitation.base import ImitationPolicy, OfflineImitationLearning +from tianshou.algorithm.optim import AdamOptimizerFactory from tianshou.trainer import OfflineTrainerParams from tianshou.utils import TensorboardLogger, WandbLogger from tianshou.utils.net.common import Net diff --git a/examples/offline/d4rl_td3_bc.py b/examples/offline/d4rl_td3_bc.py index 7a1224c2a..1d4ac1036 100644 --- a/examples/offline/d4rl_td3_bc.py +++ b/examples/offline/d4rl_td3_bc.py @@ -14,10 +14,10 @@ from tianshou.data import Collector, CollectStats from tianshou.env import BaseVectorEnv, SubprocVectorEnv, VectorEnvNormObs from tianshou.exploration import GaussianNoise -from tianshou.policy import TD3BC -from tianshou.policy.base import Algorithm -from tianshou.policy.modelfree.ddpg import ContinuousDeterministicPolicy -from tianshou.policy.optim import 
AdamOptimizerFactory +from tianshou.algorithm import TD3BC +from tianshou.algorithm.base import Algorithm +from tianshou.algorithm.modelfree.ddpg import ContinuousDeterministicPolicy +from tianshou.algorithm.optim import AdamOptimizerFactory from tianshou.trainer import OfflineTrainerParams from tianshou.utils import TensorboardLogger, WandbLogger from tianshou.utils.net.common import Net diff --git a/examples/vizdoom/vizdoom_c51.py b/examples/vizdoom/vizdoom_c51.py index 4a867c5e8..019151be2 100644 --- a/examples/vizdoom/vizdoom_c51.py +++ b/examples/vizdoom/vizdoom_c51.py @@ -11,10 +11,10 @@ from tianshou.data import Collector, CollectStats, VectorReplayBuffer from tianshou.env.atari.atari_network import C51Net from tianshou.highlevel.logger import LoggerFactoryDefault -from tianshou.policy import C51 -from tianshou.policy.base import Algorithm -from tianshou.policy.modelfree.c51 import C51Policy -from tianshou.policy.optim import AdamOptimizerFactory +from tianshou.algorithm import C51 +from tianshou.algorithm.base import Algorithm +from tianshou.algorithm.modelfree.c51 import C51Policy +from tianshou.algorithm.optim import AdamOptimizerFactory from tianshou.trainer import OffPolicyTrainerParams diff --git a/examples/vizdoom/vizdoom_ppo.py b/examples/vizdoom/vizdoom_ppo.py index f6ff24e93..6ce17c17c 100644 --- a/examples/vizdoom/vizdoom_ppo.py +++ b/examples/vizdoom/vizdoom_ppo.py @@ -12,11 +12,11 @@ from tianshou.data import Collector, CollectStats, VectorReplayBuffer from tianshou.env.atari.atari_network import DQNet from tianshou.highlevel.logger import LoggerFactoryDefault -from tianshou.policy import PPO -from tianshou.policy.base import Algorithm -from tianshou.policy.modelbased.icm import ICMOnPolicyWrapper -from tianshou.policy.modelfree.pg import ActorPolicyProbabilistic -from tianshou.policy.optim import AdamOptimizerFactory, LRSchedulerFactoryLinear +from tianshou.algorithm import PPO +from tianshou.algorithm.base import Algorithm +from 
tianshou.algorithm.modelbased.icm import ICMOnPolicyWrapper +from tianshou.algorithm.modelfree.pg import ActorPolicyProbabilistic +from tianshou.algorithm.optim import AdamOptimizerFactory, LRSchedulerFactoryLinear from tianshou.trainer import OnPolicyTrainerParams from tianshou.utils.net.discrete import ( DiscreteActor, diff --git a/test/base/test_collector.py b/test/base/test_collector.py index 0d47c456f..5636e73de 100644 --- a/test/base/test_collector.py +++ b/test/base/test_collector.py @@ -25,7 +25,7 @@ ) from tianshou.data.types import ObsBatchProtocol, RolloutBatchProtocol from tianshou.env import DummyVectorEnv, SubprocVectorEnv -from tianshou.policy.base import Policy, episode_mc_return_to_go +from tianshou.algorithm.base import Policy, episode_mc_return_to_go try: import envpool diff --git a/test/base/test_env_finite.py b/test/base/test_env_finite.py index 7e6065ada..a6daf7010 100644 --- a/test/base/test_env_finite.py +++ b/test/base/test_env_finite.py @@ -19,7 +19,7 @@ ) from tianshou.env import BaseVectorEnv, DummyVectorEnv, SubprocVectorEnv from tianshou.env.utils import ENV_TYPE, gym_new_venv_step_type -from tianshou.policy.base import Policy +from tianshou.algorithm.base import Policy class DummyDataset(Dataset): diff --git a/test/base/test_policy.py b/test/base/test_policy.py index 559c685e2..ed9b3a79c 100644 --- a/test/base/test_policy.py +++ b/test/base/test_policy.py @@ -5,10 +5,10 @@ from torch.distributions import Categorical, Distribution, Independent, Normal from tianshou.data import Batch -from tianshou.policy import PPO -from tianshou.policy.base import RandomActionPolicy, episode_mc_return_to_go -from tianshou.policy.modelfree.pg import ActorPolicyProbabilistic -from tianshou.policy.optim import AdamOptimizerFactory +from tianshou.algorithm import PPO +from tianshou.algorithm.base import RandomActionPolicy, episode_mc_return_to_go +from tianshou.algorithm.modelfree.pg import ActorPolicyProbabilistic +from tianshou.algorithm.optim import 
AdamOptimizerFactory from tianshou.utils.net.common import Net from tianshou.utils.net.continuous import ContinuousActorProbabilistic, ContinuousCritic from tianshou.utils.net.discrete import DiscreteActor diff --git a/test/base/test_returns.py b/test/base/test_returns.py index 8b05319d0..b1a409068 100644 --- a/test/base/test_returns.py +++ b/test/base/test_returns.py @@ -5,7 +5,7 @@ from tianshou.data import Batch, ReplayBuffer, to_numpy from tianshou.data.types import RolloutBatchProtocol -from tianshou.policy import Algorithm +from tianshou.algorithm import Algorithm def compute_episodic_return_base(batch: Batch, gamma: float) -> Batch: diff --git a/test/base/test_stats.py b/test/base/test_stats.py index 821152e83..0efa284a7 100644 --- a/test/base/test_stats.py +++ b/test/base/test_stats.py @@ -7,7 +7,7 @@ from tianshou.data import Batch, CollectStats from tianshou.data.collector import CollectStepBatchProtocol, get_stddev_from_dist -from tianshou.policy.base import TrainingStats, TrainingStatsWrapper +from tianshou.algorithm.base import TrainingStats, TrainingStatsWrapper class DummyTrainingStatsWrapper(TrainingStatsWrapper): diff --git a/test/continuous/test_ddpg.py b/test/continuous/test_ddpg.py index 518d53ef5..108480095 100644 --- a/test/continuous/test_ddpg.py +++ b/test/continuous/test_ddpg.py @@ -10,10 +10,10 @@ from tianshou.data import Collector, CollectStats, VectorReplayBuffer from tianshou.env import DummyVectorEnv from tianshou.exploration import GaussianNoise -from tianshou.policy import DDPG -from tianshou.policy.base import Algorithm -from tianshou.policy.modelfree.ddpg import ContinuousDeterministicPolicy -from tianshou.policy.optim import AdamOptimizerFactory +from tianshou.algorithm import DDPG +from tianshou.algorithm.base import Algorithm +from tianshou.algorithm.modelfree.ddpg import ContinuousDeterministicPolicy +from tianshou.algorithm.optim import AdamOptimizerFactory from tianshou.trainer.base import OffPolicyTrainerParams from 
tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Net diff --git a/test/continuous/test_npg.py b/test/continuous/test_npg.py index 35f6b58f9..f13e05c1a 100644 --- a/test/continuous/test_npg.py +++ b/test/continuous/test_npg.py @@ -11,10 +11,10 @@ from tianshou.data import Collector, CollectStats, VectorReplayBuffer from tianshou.env import DummyVectorEnv -from tianshou.policy import NPG -from tianshou.policy.base import Algorithm -from tianshou.policy.modelfree.pg import ActorPolicyProbabilistic -from tianshou.policy.optim import AdamOptimizerFactory +from tianshou.algorithm import NPG +from tianshou.algorithm.base import Algorithm +from tianshou.algorithm.modelfree.pg import ActorPolicyProbabilistic +from tianshou.algorithm.optim import AdamOptimizerFactory from tianshou.trainer.base import OnPolicyTrainerParams from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Net diff --git a/test/continuous/test_ppo.py b/test/continuous/test_ppo.py index c265d2955..ed1b91292 100644 --- a/test/continuous/test_ppo.py +++ b/test/continuous/test_ppo.py @@ -10,10 +10,10 @@ from tianshou.data import Collector, CollectStats, VectorReplayBuffer from tianshou.env import DummyVectorEnv -from tianshou.policy import PPO -from tianshou.policy.base import Algorithm -from tianshou.policy.modelfree.pg import ActorPolicyProbabilistic -from tianshou.policy.optim import AdamOptimizerFactory +from tianshou.algorithm import PPO +from tianshou.algorithm.base import Algorithm +from tianshou.algorithm.modelfree.pg import ActorPolicyProbabilistic +from tianshou.algorithm.optim import AdamOptimizerFactory from tianshou.trainer.base import OnPolicyTrainerParams from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import ActorCritic, Net diff --git a/test/continuous/test_redq.py b/test/continuous/test_redq.py index 1662f2abe..0c9fc3c65 100644 --- a/test/continuous/test_redq.py +++ b/test/continuous/test_redq.py @@ -10,11 
+10,11 @@ from tianshou.data import Collector, CollectStats, VectorReplayBuffer from tianshou.env import DummyVectorEnv -from tianshou.policy import REDQ -from tianshou.policy.base import Algorithm -from tianshou.policy.modelfree.redq import REDQPolicy -from tianshou.policy.modelfree.sac import AutoAlpha -from tianshou.policy.optim import AdamOptimizerFactory +from tianshou.algorithm import REDQ +from tianshou.algorithm.base import Algorithm +from tianshou.algorithm.modelfree.redq import REDQPolicy +from tianshou.algorithm.modelfree.sac import AutoAlpha +from tianshou.algorithm.optim import AdamOptimizerFactory from tianshou.trainer.base import OffPolicyTrainerParams from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import EnsembleLinear, Net diff --git a/test/continuous/test_sac_with_il.py b/test/continuous/test_sac_with_il.py index b2d9e7f77..ab0928491 100644 --- a/test/continuous/test_sac_with_il.py +++ b/test/continuous/test_sac_with_il.py @@ -9,11 +9,11 @@ from tianshou.data import Collector, CollectStats, VectorReplayBuffer from tianshou.env import DummyVectorEnv -from tianshou.policy import SAC, OffPolicyImitationLearning -from tianshou.policy.base import Algorithm -from tianshou.policy.imitation.base import ImitationPolicy -from tianshou.policy.modelfree.sac import AutoAlpha, SACPolicy -from tianshou.policy.optim import AdamOptimizerFactory +from tianshou.algorithm import SAC, OffPolicyImitationLearning +from tianshou.algorithm.base import Algorithm +from tianshou.algorithm.imitation.base import ImitationPolicy +from tianshou.algorithm.modelfree.sac import AutoAlpha, SACPolicy +from tianshou.algorithm.optim import AdamOptimizerFactory from tianshou.trainer.base import OffPolicyTrainerParams from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Net diff --git a/test/continuous/test_td3.py b/test/continuous/test_td3.py index 3567f7668..8e4e4e021 100644 --- a/test/continuous/test_td3.py +++ 
b/test/continuous/test_td3.py @@ -10,10 +10,10 @@ from tianshou.data import Collector, CollectStats, VectorReplayBuffer from tianshou.env import DummyVectorEnv from tianshou.exploration import GaussianNoise -from tianshou.policy import TD3 -from tianshou.policy.base import Algorithm -from tianshou.policy.modelfree.ddpg import ContinuousDeterministicPolicy -from tianshou.policy.optim import AdamOptimizerFactory +from tianshou.algorithm import TD3 +from tianshou.algorithm.base import Algorithm +from tianshou.algorithm.modelfree.ddpg import ContinuousDeterministicPolicy +from tianshou.algorithm.optim import AdamOptimizerFactory from tianshou.trainer.base import OffPolicyTrainerParams from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Net diff --git a/test/continuous/test_trpo.py b/test/continuous/test_trpo.py index 658df9efe..6ec5c3986 100644 --- a/test/continuous/test_trpo.py +++ b/test/continuous/test_trpo.py @@ -11,10 +11,10 @@ from tianshou.data import Collector, CollectStats, VectorReplayBuffer from tianshou.env import DummyVectorEnv -from tianshou.policy import TRPO -from tianshou.policy.base import Algorithm -from tianshou.policy.modelfree.pg import ActorPolicyProbabilistic -from tianshou.policy.optim import AdamOptimizerFactory +from tianshou.algorithm import TRPO +from tianshou.algorithm.base import Algorithm +from tianshou.algorithm.modelfree.pg import ActorPolicyProbabilistic +from tianshou.algorithm.optim import AdamOptimizerFactory from tianshou.trainer.base import OnPolicyTrainerParams from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Net diff --git a/test/discrete/test_a2c_with_il.py b/test/discrete/test_a2c_with_il.py index 862f10450..14e626469 100644 --- a/test/discrete/test_a2c_with_il.py +++ b/test/discrete/test_a2c_with_il.py @@ -10,11 +10,11 @@ from tianshou.data import Collector, CollectStats, VectorReplayBuffer from tianshou.env import DummyVectorEnv -from tianshou.policy import 
A2C, OffPolicyImitationLearning -from tianshou.policy.base import Algorithm -from tianshou.policy.imitation.base import ImitationPolicy -from tianshou.policy.modelfree.pg import ActorPolicyProbabilistic -from tianshou.policy.optim import AdamOptimizerFactory +from tianshou.algorithm import A2C, OffPolicyImitationLearning +from tianshou.algorithm.base import Algorithm +from tianshou.algorithm.imitation.base import ImitationPolicy +from tianshou.algorithm.modelfree.pg import ActorPolicyProbabilistic +from tianshou.algorithm.optim import AdamOptimizerFactory from tianshou.trainer.base import OffPolicyTrainerParams, OnPolicyTrainerParams from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Net diff --git a/test/discrete/test_bdqn.py b/test/discrete/test_bdqn.py index 33b3c4cd6..86ba3dd38 100644 --- a/test/discrete/test_bdqn.py +++ b/test/discrete/test_bdqn.py @@ -7,9 +7,9 @@ from tianshou.data import Collector, CollectStats, VectorReplayBuffer from tianshou.env import ContinuousToDiscrete, DummyVectorEnv -from tianshou.policy import BDQN -from tianshou.policy.modelfree.bdqn import BDQNPolicy -from tianshou.policy.optim import AdamOptimizerFactory +from tianshou.algorithm import BDQN +from tianshou.algorithm.modelfree.bdqn import BDQNPolicy +from tianshou.algorithm.optim import AdamOptimizerFactory from tianshou.trainer.base import OffPolicyTrainerParams from tianshou.utils.net.common import BranchingNet from tianshou.utils.torch_utils import policy_within_training_step diff --git a/test/discrete/test_c51.py b/test/discrete/test_c51.py index ec291da18..c88e675a5 100644 --- a/test/discrete/test_c51.py +++ b/test/discrete/test_c51.py @@ -16,10 +16,10 @@ VectorReplayBuffer, ) from tianshou.env import DummyVectorEnv -from tianshou.policy import C51 -from tianshou.policy.base import Algorithm -from tianshou.policy.modelfree.c51 import C51Policy -from tianshou.policy.optim import AdamOptimizerFactory +from tianshou.algorithm import C51 +from 
tianshou.algorithm.base import Algorithm +from tianshou.algorithm.modelfree.c51 import C51Policy +from tianshou.algorithm.optim import AdamOptimizerFactory from tianshou.trainer.base import OffPolicyTrainerParams from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Net diff --git a/test/discrete/test_discrete_sac.py b/test/discrete/test_discrete_sac.py index ed1df42f2..6a1b76fb1 100644 --- a/test/discrete/test_discrete_sac.py +++ b/test/discrete/test_discrete_sac.py @@ -9,13 +9,13 @@ from tianshou.data import Collector, CollectStats, VectorReplayBuffer from tianshou.env import DummyVectorEnv -from tianshou.policy import DiscreteSAC -from tianshou.policy.base import Algorithm -from tianshou.policy.modelfree.discrete_sac import ( +from tianshou.algorithm import DiscreteSAC +from tianshou.algorithm.base import Algorithm +from tianshou.algorithm.modelfree.discrete_sac import ( DiscreteSACPolicy, ) -from tianshou.policy.modelfree.sac import AutoAlpha -from tianshou.policy.optim import AdamOptimizerFactory +from tianshou.algorithm.modelfree.sac import AutoAlpha +from tianshou.algorithm.optim import AdamOptimizerFactory from tianshou.trainer.base import OffPolicyTrainerParams from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Net diff --git a/test/discrete/test_dqn.py b/test/discrete/test_dqn.py index cf002bbd2..f848a4bd1 100644 --- a/test/discrete/test_dqn.py +++ b/test/discrete/test_dqn.py @@ -15,10 +15,10 @@ VectorReplayBuffer, ) from tianshou.env import DummyVectorEnv -from tianshou.policy import DQN -from tianshou.policy.base import Algorithm -from tianshou.policy.modelfree.dqn import DiscreteQLearningPolicy -from tianshou.policy.optim import AdamOptimizerFactory +from tianshou.algorithm import DQN +from tianshou.algorithm.base import Algorithm +from tianshou.algorithm.modelfree.dqn import DiscreteQLearningPolicy +from tianshou.algorithm.optim import AdamOptimizerFactory from tianshou.trainer.base import 
OffPolicyTrainerParams from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Net diff --git a/test/discrete/test_drqn.py b/test/discrete/test_drqn.py index 267320b0c..bb779361e 100644 --- a/test/discrete/test_drqn.py +++ b/test/discrete/test_drqn.py @@ -9,10 +9,10 @@ from tianshou.data import Collector, CollectStats, VectorReplayBuffer from tianshou.env import DummyVectorEnv -from tianshou.policy import DQN -from tianshou.policy.base import Algorithm -from tianshou.policy.modelfree.dqn import DiscreteQLearningPolicy -from tianshou.policy.optim import AdamOptimizerFactory +from tianshou.algorithm import DQN +from tianshou.algorithm.base import Algorithm +from tianshou.algorithm.modelfree.dqn import DiscreteQLearningPolicy +from tianshou.algorithm.optim import AdamOptimizerFactory from tianshou.trainer import OffPolicyTrainerParams from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Recurrent diff --git a/test/discrete/test_fqf.py b/test/discrete/test_fqf.py index 5b43866aa..8b32c0aed 100644 --- a/test/discrete/test_fqf.py +++ b/test/discrete/test_fqf.py @@ -15,10 +15,10 @@ VectorReplayBuffer, ) from tianshou.env import DummyVectorEnv -from tianshou.policy import FQF -from tianshou.policy.base import Algorithm -from tianshou.policy.modelfree.fqf import FQFPolicy -from tianshou.policy.optim import AdamOptimizerFactory, RMSpropOptimizerFactory +from tianshou.algorithm import FQF +from tianshou.algorithm.base import Algorithm +from tianshou.algorithm.modelfree.fqf import FQFPolicy +from tianshou.algorithm.optim import AdamOptimizerFactory, RMSpropOptimizerFactory from tianshou.trainer.base import OffPolicyTrainerParams from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Net diff --git a/test/discrete/test_iqn.py b/test/discrete/test_iqn.py index a6932309b..30c5bee2a 100644 --- a/test/discrete/test_iqn.py +++ b/test/discrete/test_iqn.py @@ -15,10 +15,10 @@ VectorReplayBuffer, ) 
from tianshou.env import DummyVectorEnv -from tianshou.policy import IQN -from tianshou.policy.base import Algorithm -from tianshou.policy.modelfree.iqn import IQNPolicy -from tianshou.policy.optim import AdamOptimizerFactory +from tianshou.algorithm import IQN +from tianshou.algorithm.base import Algorithm +from tianshou.algorithm.modelfree.iqn import IQNPolicy +from tianshou.algorithm.optim import AdamOptimizerFactory from tianshou.trainer.base import OffPolicyTrainerParams from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Net diff --git a/test/discrete/test_pg.py b/test/discrete/test_pg.py index 9893f0dc5..8197727b3 100644 --- a/test/discrete/test_pg.py +++ b/test/discrete/test_pg.py @@ -10,10 +10,10 @@ from tianshou.data import Collector, CollectStats, VectorReplayBuffer from tianshou.env import DummyVectorEnv -from tianshou.policy import Reinforce -from tianshou.policy.base import Algorithm -from tianshou.policy.modelfree.pg import ActorPolicyProbabilistic -from tianshou.policy.optim import AdamOptimizerFactory +from tianshou.algorithm import Reinforce +from tianshou.algorithm.base import Algorithm +from tianshou.algorithm.modelfree.pg import ActorPolicyProbabilistic +from tianshou.algorithm.optim import AdamOptimizerFactory from tianshou.trainer.base import OnPolicyTrainerParams from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Net diff --git a/test/discrete/test_ppo_discrete.py b/test/discrete/test_ppo_discrete.py index dbfda23f2..6eead6b0e 100644 --- a/test/discrete/test_ppo_discrete.py +++ b/test/discrete/test_ppo_discrete.py @@ -10,10 +10,10 @@ from tianshou.data import Collector, CollectStats, VectorReplayBuffer from tianshou.env import DummyVectorEnv -from tianshou.policy import PPO -from tianshou.policy.base import Algorithm -from tianshou.policy.modelfree.pg import DiscreteActorPolicy -from tianshou.policy.optim import AdamOptimizerFactory +from tianshou.algorithm import PPO +from 
tianshou.algorithm.base import Algorithm +from tianshou.algorithm.modelfree.pg import DiscreteActorPolicy +from tianshou.algorithm.optim import AdamOptimizerFactory from tianshou.trainer import OnPolicyTrainerParams from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import ActorCritic, DataParallelNet, Net diff --git a/test/discrete/test_qrdqn.py b/test/discrete/test_qrdqn.py index bc1c17102..03cada4cc 100644 --- a/test/discrete/test_qrdqn.py +++ b/test/discrete/test_qrdqn.py @@ -14,10 +14,10 @@ VectorReplayBuffer, ) from tianshou.env import DummyVectorEnv -from tianshou.policy import QRDQN -from tianshou.policy.base import Algorithm -from tianshou.policy.modelfree.qrdqn import QRDQNPolicy -from tianshou.policy.optim import AdamOptimizerFactory +from tianshou.algorithm import QRDQN +from tianshou.algorithm.base import Algorithm +from tianshou.algorithm.modelfree.qrdqn import QRDQNPolicy +from tianshou.algorithm.optim import AdamOptimizerFactory from tianshou.trainer.base import OffPolicyTrainerParams from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Net diff --git a/test/discrete/test_rainbow.py b/test/discrete/test_rainbow.py index 07d59a2a7..4cae649b0 100644 --- a/test/discrete/test_rainbow.py +++ b/test/discrete/test_rainbow.py @@ -15,10 +15,10 @@ VectorReplayBuffer, ) from tianshou.env import DummyVectorEnv -from tianshou.policy import RainbowDQN -from tianshou.policy.base import Algorithm -from tianshou.policy.modelfree.c51 import C51Policy -from tianshou.policy.optim import AdamOptimizerFactory +from tianshou.algorithm import RainbowDQN +from tianshou.algorithm.base import Algorithm +from tianshou.algorithm.modelfree.c51 import C51Policy +from tianshou.algorithm.optim import AdamOptimizerFactory from tianshou.trainer.base import OffPolicyTrainerParams from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Net diff --git a/test/modelbased/test_dqn_icm.py 
b/test/modelbased/test_dqn_icm.py index d314ec2ab..a3286b36c 100644 --- a/test/modelbased/test_dqn_icm.py +++ b/test/modelbased/test_dqn_icm.py @@ -13,9 +13,9 @@ VectorReplayBuffer, ) from tianshou.env import DummyVectorEnv -from tianshou.policy import DQN, Algorithm, ICMOffPolicyWrapper -from tianshou.policy.modelfree.dqn import DiscreteQLearningPolicy -from tianshou.policy.optim import AdamOptimizerFactory +from tianshou.algorithm import DQN, Algorithm, ICMOffPolicyWrapper +from tianshou.algorithm.modelfree.dqn import DiscreteQLearningPolicy +from tianshou.algorithm.optim import AdamOptimizerFactory from tianshou.trainer import OffPolicyTrainerParams from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import MLP, Net diff --git a/test/modelbased/test_ppo_icm.py b/test/modelbased/test_ppo_icm.py index 9045e29c8..88c0718c2 100644 --- a/test/modelbased/test_ppo_icm.py +++ b/test/modelbased/test_ppo_icm.py @@ -9,11 +9,11 @@ from tianshou.data import Collector, CollectStats, VectorReplayBuffer from tianshou.env import DummyVectorEnv -from tianshou.policy import PPO -from tianshou.policy.base import Algorithm -from tianshou.policy.modelbased.icm import ICMOnPolicyWrapper -from tianshou.policy.modelfree.pg import ActorPolicyProbabilistic -from tianshou.policy.optim import AdamOptimizerFactory +from tianshou.algorithm import PPO +from tianshou.algorithm.base import Algorithm +from tianshou.algorithm.modelbased.icm import ICMOnPolicyWrapper +from tianshou.algorithm.modelfree.pg import ActorPolicyProbabilistic +from tianshou.algorithm.optim import AdamOptimizerFactory from tianshou.trainer import OnPolicyTrainerParams from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import MLP, ActorCritic, Net diff --git a/test/modelbased/test_psrl.py b/test/modelbased/test_psrl.py index 086d12f76..51e81e62c 100644 --- a/test/modelbased/test_psrl.py +++ b/test/modelbased/test_psrl.py @@ -7,8 +7,8 @@ from torch.utils.tensorboard import 
SummaryWriter from tianshou.data import Collector, CollectStats, VectorReplayBuffer -from tianshou.policy import PSRL -from tianshou.policy.modelbased.psrl import PSRLPolicy +from tianshou.algorithm import PSRL +from tianshou.algorithm.modelbased.psrl import PSRLPolicy from tianshou.trainer import OnPolicyTrainerParams from tianshou.utils import LazyLogger, TensorboardLogger, WandbLogger diff --git a/test/offline/gather_cartpole_data.py b/test/offline/gather_cartpole_data.py index 9ed763027..dd77c783f 100644 --- a/test/offline/gather_cartpole_data.py +++ b/test/offline/gather_cartpole_data.py @@ -14,10 +14,10 @@ VectorReplayBuffer, ) from tianshou.env import DummyVectorEnv -from tianshou.policy import QRDQN -from tianshou.policy.base import Algorithm -from tianshou.policy.modelfree.qrdqn import QRDQNPolicy -from tianshou.policy.optim import AdamOptimizerFactory +from tianshou.algorithm import QRDQN +from tianshou.algorithm.base import Algorithm +from tianshou.algorithm.modelfree.qrdqn import QRDQNPolicy +from tianshou.algorithm.optim import AdamOptimizerFactory from tianshou.trainer import OffPolicyTrainerParams from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Net diff --git a/test/offline/gather_pendulum_data.py b/test/offline/gather_pendulum_data.py index b5291f21c..b7ea3d820 100644 --- a/test/offline/gather_pendulum_data.py +++ b/test/offline/gather_pendulum_data.py @@ -9,10 +9,10 @@ from tianshou.data import Collector, CollectStats, VectorReplayBuffer from tianshou.env import DummyVectorEnv -from tianshou.policy import SAC -from tianshou.policy.base import Algorithm -from tianshou.policy.modelfree.sac import AutoAlpha, SACPolicy, SACTrainingStats -from tianshou.policy.optim import AdamOptimizerFactory +from tianshou.algorithm import SAC +from tianshou.algorithm.base import Algorithm +from tianshou.algorithm.modelfree.sac import AutoAlpha, SACPolicy, SACTrainingStats +from tianshou.algorithm.optim import AdamOptimizerFactory 
from tianshou.trainer.base import OffPolicyTrainerParams from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Net diff --git a/test/offline/test_bcq.py b/test/offline/test_bcq.py index 6839847f3..83615fd93 100644 --- a/test/offline/test_bcq.py +++ b/test/offline/test_bcq.py @@ -12,9 +12,9 @@ from tianshou.data import Collector, CollectStats, VectorReplayBuffer from tianshou.env import DummyVectorEnv -from tianshou.policy import BCQ, Algorithm -from tianshou.policy.imitation.bcq import BCQPolicy -from tianshou.policy.optim import AdamOptimizerFactory +from tianshou.algorithm import BCQ, Algorithm +from tianshou.algorithm.imitation.bcq import BCQPolicy +from tianshou.algorithm.optim import AdamOptimizerFactory from tianshou.trainer.base import OfflineTrainerParams from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import MLP, Net diff --git a/test/offline/test_cql.py b/test/offline/test_cql.py index acc60461c..8cda37a7a 100644 --- a/test/offline/test_cql.py +++ b/test/offline/test_cql.py @@ -12,9 +12,9 @@ from tianshou.data import Collector, CollectStats, VectorReplayBuffer from tianshou.env import DummyVectorEnv -from tianshou.policy import CQL, Algorithm -from tianshou.policy.modelfree.sac import AutoAlpha, SACPolicy -from tianshou.policy.optim import AdamOptimizerFactory +from tianshou.algorithm import CQL, Algorithm +from tianshou.algorithm.modelfree.sac import AutoAlpha, SACPolicy +from tianshou.algorithm.optim import AdamOptimizerFactory from tianshou.trainer import OfflineTrainerParams from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Net diff --git a/test/offline/test_discrete_bcq.py b/test/offline/test_discrete_bcq.py index 4baeaa128..12faddb05 100644 --- a/test/offline/test_discrete_bcq.py +++ b/test/offline/test_discrete_bcq.py @@ -16,9 +16,9 @@ VectorReplayBuffer, ) from tianshou.env import DummyVectorEnv -from tianshou.policy import Algorithm, DiscreteBCQ -from 
tianshou.policy.imitation.discrete_bcq import DiscreteBCQPolicy -from tianshou.policy.optim import AdamOptimizerFactory +from tianshou.algorithm import Algorithm, DiscreteBCQ +from tianshou.algorithm.imitation.discrete_bcq import DiscreteBCQPolicy +from tianshou.algorithm.optim import AdamOptimizerFactory from tianshou.trainer import OfflineTrainerParams from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Net diff --git a/test/offline/test_discrete_cql.py b/test/offline/test_discrete_cql.py index 2f1efbec6..1bdde5b14 100644 --- a/test/offline/test_discrete_cql.py +++ b/test/offline/test_discrete_cql.py @@ -16,9 +16,9 @@ VectorReplayBuffer, ) from tianshou.env import DummyVectorEnv -from tianshou.policy import Algorithm, DiscreteCQL -from tianshou.policy.modelfree.qrdqn import QRDQNPolicy -from tianshou.policy.optim import AdamOptimizerFactory +from tianshou.algorithm import Algorithm, DiscreteCQL +from tianshou.algorithm.modelfree.qrdqn import QRDQNPolicy +from tianshou.algorithm.optim import AdamOptimizerFactory from tianshou.trainer import OfflineTrainerParams from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Net diff --git a/test/offline/test_discrete_crr.py b/test/offline/test_discrete_crr.py index 278268ff5..6878fdfe3 100644 --- a/test/offline/test_discrete_crr.py +++ b/test/offline/test_discrete_crr.py @@ -16,9 +16,9 @@ VectorReplayBuffer, ) from tianshou.env import DummyVectorEnv -from tianshou.policy import Algorithm, DiscreteCRR -from tianshou.policy.modelfree.pg import DiscreteActorPolicy -from tianshou.policy.optim import AdamOptimizerFactory +from tianshou.algorithm import Algorithm, DiscreteCRR +from tianshou.algorithm.modelfree.pg import DiscreteActorPolicy +from tianshou.algorithm.optim import AdamOptimizerFactory from tianshou.trainer import OfflineTrainerParams from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Net diff --git a/test/offline/test_gail.py 
b/test/offline/test_gail.py index 81a25a613..86a84da35 100644 --- a/test/offline/test_gail.py +++ b/test/offline/test_gail.py @@ -12,9 +12,9 @@ from tianshou.data import Collector, CollectStats, VectorReplayBuffer from tianshou.env import DummyVectorEnv -from tianshou.policy import GAIL, Algorithm -from tianshou.policy.modelfree.pg import ActorPolicyProbabilistic -from tianshou.policy.optim import AdamOptimizerFactory +from tianshou.algorithm import GAIL, Algorithm +from tianshou.algorithm.modelfree.pg import ActorPolicyProbabilistic +from tianshou.algorithm.optim import AdamOptimizerFactory from tianshou.trainer import OnPolicyTrainerParams from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import ActorCritic, Net diff --git a/test/offline/test_td3_bc.py b/test/offline/test_td3_bc.py index e62472224..8a3eb2ee0 100644 --- a/test/offline/test_td3_bc.py +++ b/test/offline/test_td3_bc.py @@ -13,10 +13,10 @@ from tianshou.data import Collector, CollectStats, VectorReplayBuffer from tianshou.env import DummyVectorEnv from tianshou.exploration import GaussianNoise -from tianshou.policy import TD3BC -from tianshou.policy.base import Algorithm -from tianshou.policy.modelfree.ddpg import ContinuousDeterministicPolicy -from tianshou.policy.optim import AdamOptimizerFactory +from tianshou.algorithm import TD3BC +from tianshou.algorithm.base import Algorithm +from tianshou.algorithm.modelfree.ddpg import ContinuousDeterministicPolicy +from tianshou.algorithm.optim import AdamOptimizerFactory from tianshou.trainer import OfflineTrainerParams from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Net diff --git a/test/pettingzoo/pistonball.py b/test/pettingzoo/pistonball.py index 1634581fe..9dfcccb41 100644 --- a/test/pettingzoo/pistonball.py +++ b/test/pettingzoo/pistonball.py @@ -11,10 +11,10 @@ from tianshou.data import Collector, CollectStats, InfoStats, VectorReplayBuffer from tianshou.env import DummyVectorEnv from 
tianshou.env.pettingzoo_env import PettingZooEnv -from tianshou.policy import DQN, Algorithm, MultiAgentOffPolicyAlgorithm -from tianshou.policy.base import OffPolicyAlgorithm -from tianshou.policy.modelfree.dqn import DiscreteQLearningPolicy -from tianshou.policy.optim import AdamOptimizerFactory +from tianshou.algorithm import DQN, Algorithm, MultiAgentOffPolicyAlgorithm +from tianshou.algorithm.base import OffPolicyAlgorithm +from tianshou.algorithm.modelfree.dqn import DiscreteQLearningPolicy +from tianshou.algorithm.optim import AdamOptimizerFactory from tianshou.trainer import OffPolicyTrainerParams from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Net diff --git a/test/pettingzoo/pistonball_continuous.py b/test/pettingzoo/pistonball_continuous.py index 490eae6c3..8736b8a9c 100644 --- a/test/pettingzoo/pistonball_continuous.py +++ b/test/pettingzoo/pistonball_continuous.py @@ -15,11 +15,11 @@ from tianshou.data.stats import InfoStats from tianshou.env import DummyVectorEnv from tianshou.env.pettingzoo_env import PettingZooEnv -from tianshou.policy import PPO, Algorithm -from tianshou.policy.base import OnPolicyAlgorithm -from tianshou.policy.modelfree.pg import ActorPolicyProbabilistic -from tianshou.policy.multiagent.mapolicy import MultiAgentOnPolicyAlgorithm -from tianshou.policy.optim import AdamOptimizerFactory +from tianshou.algorithm import PPO, Algorithm +from tianshou.algorithm.base import OnPolicyAlgorithm +from tianshou.algorithm.modelfree.pg import ActorPolicyProbabilistic +from tianshou.algorithm.multiagent.mapolicy import MultiAgentOnPolicyAlgorithm +from tianshou.algorithm.optim import AdamOptimizerFactory from tianshou.trainer import OnPolicyTrainerParams from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import ModuleWithVectorOutput diff --git a/test/pettingzoo/tic_tac_toe.py b/test/pettingzoo/tic_tac_toe.py index e4ea9ba2d..9b4de1ca3 100644 --- a/test/pettingzoo/tic_tac_toe.py +++ 
b/test/pettingzoo/tic_tac_toe.py @@ -13,15 +13,15 @@ from tianshou.data.stats import InfoStats from tianshou.env import DummyVectorEnv from tianshou.env.pettingzoo_env import PettingZooEnv -from tianshou.policy import ( +from tianshou.algorithm import ( DQN, Algorithm, MARLRandomDiscreteMaskedOffPolicyAlgorithm, MultiAgentOffPolicyAlgorithm, ) -from tianshou.policy.base import OffPolicyAlgorithm -from tianshou.policy.modelfree.dqn import DiscreteQLearningPolicy -from tianshou.policy.optim import AdamOptimizerFactory, OptimizerFactory +from tianshou.algorithm.base import OffPolicyAlgorithm +from tianshou.algorithm.modelfree.dqn import DiscreteQLearningPolicy +from tianshou.algorithm.optim import AdamOptimizerFactory, OptimizerFactory from tianshou.trainer import OffPolicyTrainerParams from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Net diff --git a/tianshou/__init__.py b/tianshou/__init__.py index 73f74aa6f..0b87ad0f9 100644 --- a/tianshou/__init__.py +++ b/tianshou/__init__.py @@ -1,4 +1,4 @@ -from tianshou import data, env, exploration, policy, trainer, utils +from tianshou import data, env, exploration, algorithm, trainer, utils __version__ = "1.2.0-dev" @@ -19,7 +19,7 @@ def configure() -> None: "env", "data", "utils", - "policy", + "algorithm", "trainer", "exploration", ] diff --git a/tianshou/algorithm/__init__.py b/tianshou/algorithm/__init__.py new file mode 100644 index 000000000..9a5a7203a --- /dev/null +++ b/tianshou/algorithm/__init__.py @@ -0,0 +1,35 @@ +"""Algorithm package.""" +# isort:skip_file + +from tianshou.algorithm.base import Algorithm, TrainingStats +from tianshou.algorithm.modelfree.pg import Reinforce +from tianshou.algorithm.modelfree.dqn import DQN +from tianshou.algorithm.modelfree.ddpg import DDPG + +from tianshou.algorithm.random import MARLRandomDiscreteMaskedOffPolicyAlgorithm +from tianshou.algorithm.modelfree.bdqn import BDQN +from tianshou.algorithm.modelfree.c51 import C51 +from 
tianshou.algorithm.modelfree.rainbow import RainbowDQN +from tianshou.algorithm.modelfree.qrdqn import QRDQN +from tianshou.algorithm.modelfree.iqn import IQN +from tianshou.algorithm.modelfree.fqf import FQF +from tianshou.algorithm.modelfree.a2c import A2C +from tianshou.algorithm.modelfree.npg import NPG +from tianshou.algorithm.modelfree.ppo import PPO +from tianshou.algorithm.modelfree.trpo import TRPO +from tianshou.algorithm.modelfree.td3 import TD3 +from tianshou.algorithm.modelfree.sac import SAC +from tianshou.algorithm.modelfree.redq import REDQ +from tianshou.algorithm.modelfree.discrete_sac import DiscreteSAC +from tianshou.algorithm.imitation.base import OffPolicyImitationLearning +from tianshou.algorithm.imitation.bcq import BCQ +from tianshou.algorithm.imitation.cql import CQL +from tianshou.algorithm.imitation.td3_bc import TD3BC +from tianshou.algorithm.imitation.discrete_bcq import DiscreteBCQ +from tianshou.algorithm.imitation.discrete_cql import DiscreteCQL +from tianshou.algorithm.imitation.discrete_crr import DiscreteCRR +from tianshou.algorithm.imitation.gail import GAIL +from tianshou.algorithm.modelbased.psrl import PSRL +from tianshou.algorithm.modelbased.icm import ICMOffPolicyWrapper +from tianshou.algorithm.modelbased.icm import ICMOnPolicyWrapper +from tianshou.algorithm.multiagent.mapolicy import MultiAgentOffPolicyAlgorithm diff --git a/tianshou/policy/base.py b/tianshou/algorithm/base.py similarity index 99% rename from tianshou/policy/base.py rename to tianshou/algorithm/base.py index 55957ecb2..67bfe8094 100644 --- a/tianshou/policy/base.py +++ b/tianshou/algorithm/base.py @@ -29,7 +29,7 @@ ObsBatchProtocol, RolloutBatchProtocol, ) -from tianshou.policy.optim import OptimizerFactory +from tianshou.algorithm.optim import OptimizerFactory from tianshou.utils.determinism import TraceLogger from tianshou.utils.lagged_network import ( EvalModeModuleWrapper, diff --git a/tianshou/policy/imitation/__init__.py 
b/tianshou/algorithm/imitation/__init__.py similarity index 100% rename from tianshou/policy/imitation/__init__.py rename to tianshou/algorithm/imitation/__init__.py diff --git a/tianshou/policy/imitation/base.py b/tianshou/algorithm/imitation/base.py similarity index 98% rename from tianshou/policy/imitation/base.py rename to tianshou/algorithm/imitation/base.py index a13d50809..127cf625c 100644 --- a/tianshou/policy/imitation/base.py +++ b/tianshou/algorithm/imitation/base.py @@ -13,14 +13,14 @@ ObsBatchProtocol, RolloutBatchProtocol, ) -from tianshou.policy import Algorithm -from tianshou.policy.base import ( +from tianshou.algorithm import Algorithm +from tianshou.algorithm.base import ( OfflineAlgorithm, OffPolicyAlgorithm, Policy, TrainingStats, ) -from tianshou.policy.optim import OptimizerFactory +from tianshou.algorithm.optim import OptimizerFactory # Dimension Naming Convention # B - Batch Size diff --git a/tianshou/policy/imitation/bcq.py b/tianshou/algorithm/imitation/bcq.py similarity index 99% rename from tianshou/policy/imitation/bcq.py rename to tianshou/algorithm/imitation/bcq.py index 91005a6a3..4cb6afca9 100644 --- a/tianshou/policy/imitation/bcq.py +++ b/tianshou/algorithm/imitation/bcq.py @@ -10,13 +10,13 @@ from tianshou.data import Batch, to_torch from tianshou.data.batch import BatchProtocol from tianshou.data.types import ActBatchProtocol, ObsBatchProtocol, RolloutBatchProtocol -from tianshou.policy.base import ( +from tianshou.algorithm.base import ( LaggedNetworkPolyakUpdateAlgorithmMixin, OfflineAlgorithm, Policy, TrainingStats, ) -from tianshou.policy.optim import OptimizerFactory +from tianshou.algorithm.optim import OptimizerFactory from tianshou.utils.net.continuous import VAE diff --git a/tianshou/policy/imitation/cql.py b/tianshou/algorithm/imitation/cql.py similarity index 99% rename from tianshou/policy/imitation/cql.py rename to tianshou/algorithm/imitation/cql.py index c9365b56e..a6fd81044 100644 --- 
a/tianshou/policy/imitation/cql.py +++ b/tianshou/algorithm/imitation/cql.py @@ -10,12 +10,12 @@ from tianshou.data import Batch, ReplayBuffer, to_torch from tianshou.data.buffer.base import TBuffer from tianshou.data.types import RolloutBatchProtocol -from tianshou.policy.base import ( +from tianshou.algorithm.base import ( LaggedNetworkPolyakUpdateAlgorithmMixin, OfflineAlgorithm, ) -from tianshou.policy.modelfree.sac import Alpha, SACPolicy, SACTrainingStats -from tianshou.policy.optim import OptimizerFactory +from tianshou.algorithm.modelfree.sac import Alpha, SACPolicy, SACTrainingStats +from tianshou.algorithm.optim import OptimizerFactory from tianshou.utils.conversion import to_optional_float from tianshou.utils.torch_utils import torch_device diff --git a/tianshou/policy/imitation/discrete_bcq.py b/tianshou/algorithm/imitation/discrete_bcq.py similarity index 98% rename from tianshou/policy/imitation/discrete_bcq.py rename to tianshou/algorithm/imitation/discrete_bcq.py index 6a66f986c..d8dbb3044 100644 --- a/tianshou/policy/imitation/discrete_bcq.py +++ b/tianshou/algorithm/imitation/discrete_bcq.py @@ -14,13 +14,13 @@ ObsBatchProtocol, RolloutBatchProtocol, ) -from tianshou.policy.base import ( +from tianshou.algorithm.base import ( LaggedNetworkFullUpdateAlgorithmMixin, OfflineAlgorithm, ) -from tianshou.policy.modelfree.dqn import DiscreteQLearningPolicy -from tianshou.policy.modelfree.pg import SimpleLossTrainingStats -from tianshou.policy.optim import OptimizerFactory +from tianshou.algorithm.modelfree.dqn import DiscreteQLearningPolicy +from tianshou.algorithm.modelfree.pg import SimpleLossTrainingStats +from tianshou.algorithm.optim import OptimizerFactory float_info = torch.finfo(torch.float32) INF = float_info.max diff --git a/tianshou/policy/imitation/discrete_cql.py b/tianshou/algorithm/imitation/discrete_cql.py similarity index 94% rename from tianshou/policy/imitation/discrete_cql.py rename to tianshou/algorithm/imitation/discrete_cql.py 
index 56632850a..b8d1b8e58 100644 --- a/tianshou/policy/imitation/discrete_cql.py +++ b/tianshou/algorithm/imitation/discrete_cql.py @@ -6,11 +6,11 @@ from tianshou.data import to_torch from tianshou.data.types import RolloutBatchProtocol -from tianshou.policy import QRDQN -from tianshou.policy.base import OfflineAlgorithm -from tianshou.policy.modelfree.pg import SimpleLossTrainingStats -from tianshou.policy.modelfree.qrdqn import QRDQNPolicy -from tianshou.policy.optim import OptimizerFactory +from tianshou.algorithm import QRDQN +from tianshou.algorithm.base import OfflineAlgorithm +from tianshou.algorithm.modelfree.pg import SimpleLossTrainingStats +from tianshou.algorithm.modelfree.qrdqn import QRDQNPolicy +from tianshou.algorithm.optim import OptimizerFactory @dataclass(kw_only=True) diff --git a/tianshou/policy/imitation/discrete_crr.py b/tianshou/algorithm/imitation/discrete_crr.py similarity index 98% rename from tianshou/policy/imitation/discrete_crr.py rename to tianshou/algorithm/imitation/discrete_crr.py index c97258815..ea7452144 100644 --- a/tianshou/policy/imitation/discrete_crr.py +++ b/tianshou/algorithm/imitation/discrete_crr.py @@ -9,16 +9,16 @@ from tianshou.data import ReplayBuffer, to_torch, to_torch_as from tianshou.data.types import BatchWithReturnsProtocol, RolloutBatchProtocol -from tianshou.policy.base import ( +from tianshou.algorithm.base import ( LaggedNetworkFullUpdateAlgorithmMixin, OfflineAlgorithm, ) -from tianshou.policy.modelfree.pg import ( +from tianshou.algorithm.modelfree.pg import ( DiscountedReturnComputation, DiscreteActorPolicy, SimpleLossTrainingStats, ) -from tianshou.policy.optim import OptimizerFactory +from tianshou.algorithm.optim import OptimizerFactory from tianshou.utils.lagged_network import EvalModeModuleWrapper from tianshou.utils.net.discrete import DiscreteCritic diff --git a/tianshou/policy/imitation/gail.py b/tianshou/algorithm/imitation/gail.py similarity index 98% rename from 
tianshou/policy/imitation/gail.py rename to tianshou/algorithm/imitation/gail.py index 43901fda5..af83b12b8 100644 --- a/tianshou/policy/imitation/gail.py +++ b/tianshou/algorithm/imitation/gail.py @@ -11,10 +11,10 @@ to_torch, ) from tianshou.data.types import LogpOldProtocol, RolloutBatchProtocol -from tianshou.policy.modelfree.a2c import A2CTrainingStats -from tianshou.policy.modelfree.pg import ActorPolicyProbabilistic -from tianshou.policy.modelfree.ppo import PPO -from tianshou.policy.optim import OptimizerFactory +from tianshou.algorithm.modelfree.a2c import A2CTrainingStats +from tianshou.algorithm.modelfree.pg import ActorPolicyProbabilistic +from tianshou.algorithm.modelfree.ppo import PPO +from tianshou.algorithm.optim import OptimizerFactory from tianshou.utils.net.common import ModuleWithVectorOutput from tianshou.utils.net.continuous import ContinuousCritic from tianshou.utils.net.discrete import DiscreteCritic diff --git a/tianshou/policy/imitation/td3_bc.py b/tianshou/algorithm/imitation/td3_bc.py similarity index 95% rename from tianshou/policy/imitation/td3_bc.py rename to tianshou/algorithm/imitation/td3_bc.py index b6ed83086..c237694b9 100644 --- a/tianshou/policy/imitation/td3_bc.py +++ b/tianshou/algorithm/imitation/td3_bc.py @@ -3,11 +3,11 @@ from tianshou.data import to_torch_as from tianshou.data.types import RolloutBatchProtocol -from tianshou.policy import TD3 -from tianshou.policy.base import OfflineAlgorithm -from tianshou.policy.modelfree.ddpg import ContinuousDeterministicPolicy -from tianshou.policy.modelfree.td3 import TD3TrainingStats -from tianshou.policy.optim import OptimizerFactory +from tianshou.algorithm import TD3 +from tianshou.algorithm.base import OfflineAlgorithm +from tianshou.algorithm.modelfree.ddpg import ContinuousDeterministicPolicy +from tianshou.algorithm.modelfree.td3 import TD3TrainingStats +from tianshou.algorithm.optim import OptimizerFactory # NOTE: This uses diamond inheritance to convert from off-policy to 
offline diff --git a/tianshou/policy/modelbased/__init__.py b/tianshou/algorithm/modelbased/__init__.py similarity index 100% rename from tianshou/policy/modelbased/__init__.py rename to tianshou/algorithm/modelbased/__init__.py diff --git a/tianshou/policy/modelbased/icm.py b/tianshou/algorithm/modelbased/icm.py similarity index 98% rename from tianshou/policy/modelbased/icm.py rename to tianshou/algorithm/modelbased/icm.py index 887b59ac5..3f8313d87 100644 --- a/tianshou/policy/modelbased/icm.py +++ b/tianshou/algorithm/modelbased/icm.py @@ -5,8 +5,8 @@ from tianshou.data import Batch, ReplayBuffer, to_numpy, to_torch from tianshou.data.batch import BatchProtocol from tianshou.data.types import RolloutBatchProtocol -from tianshou.policy import Algorithm -from tianshou.policy.base import ( +from tianshou.algorithm import Algorithm +from tianshou.algorithm.base import ( OffPolicyAlgorithm, OffPolicyWrapperAlgorithm, OnPolicyAlgorithm, @@ -15,7 +15,7 @@ TrainingStats, TrainingStatsWrapper, ) -from tianshou.policy.optim import OptimizerFactory +from tianshou.algorithm.optim import OptimizerFactory from tianshou.utils.net.discrete import IntrinsicCuriosityModule diff --git a/tianshou/policy/modelbased/psrl.py b/tianshou/algorithm/modelbased/psrl.py similarity index 99% rename from tianshou/policy/modelbased/psrl.py rename to tianshou/algorithm/modelbased/psrl.py index b7c93ee72..d8603b051 100644 --- a/tianshou/policy/modelbased/psrl.py +++ b/tianshou/algorithm/modelbased/psrl.py @@ -8,7 +8,7 @@ from tianshou.data import Batch from tianshou.data.batch import BatchProtocol from tianshou.data.types import ActBatchProtocol, ObsBatchProtocol, RolloutBatchProtocol -from tianshou.policy.base import ( +from tianshou.algorithm.base import ( OnPolicyAlgorithm, Policy, TrainingStats, diff --git a/tianshou/policy/modelfree/__init__.py b/tianshou/algorithm/modelfree/__init__.py similarity index 100% rename from tianshou/policy/modelfree/__init__.py rename to 
tianshou/algorithm/modelfree/__init__.py diff --git a/tianshou/policy/modelfree/a2c.py b/tianshou/algorithm/modelfree/a2c.py similarity index 98% rename from tianshou/policy/modelfree/a2c.py rename to tianshou/algorithm/modelfree/a2c.py index 52bcf45c0..43b25af50 100644 --- a/tianshou/policy/modelfree/a2c.py +++ b/tianshou/algorithm/modelfree/a2c.py @@ -8,12 +8,12 @@ from tianshou.data import ReplayBuffer, SequenceSummaryStats, to_torch_as from tianshou.data.types import BatchWithAdvantagesProtocol, RolloutBatchProtocol -from tianshou.policy.base import ( +from tianshou.algorithm.base import ( OnPolicyAlgorithm, TrainingStats, ) -from tianshou.policy.modelfree.pg import ActorPolicyProbabilistic -from tianshou.policy.optim import OptimizerFactory +from tianshou.algorithm.modelfree.pg import ActorPolicyProbabilistic +from tianshou.algorithm.optim import OptimizerFactory from tianshou.utils import RunningMeanStd from tianshou.utils.net.common import ActorCritic from tianshou.utils.net.continuous import ContinuousCritic diff --git a/tianshou/policy/modelfree/bdqn.py b/tianshou/algorithm/modelfree/bdqn.py similarity index 97% rename from tianshou/policy/modelfree/bdqn.py rename to tianshou/algorithm/modelfree/bdqn.py index 195a22666..cd83b1df0 100644 --- a/tianshou/policy/modelfree/bdqn.py +++ b/tianshou/algorithm/modelfree/bdqn.py @@ -14,13 +14,13 @@ ObsBatchProtocol, RolloutBatchProtocol, ) -from tianshou.policy.base import TArrOrActBatch -from tianshou.policy.modelfree.dqn import ( +from tianshou.algorithm.base import TArrOrActBatch +from tianshou.algorithm.modelfree.dqn import ( DiscreteQLearningPolicy, QLearningOffPolicyAlgorithm, ) -from tianshou.policy.modelfree.pg import SimpleLossTrainingStats -from tianshou.policy.optim import OptimizerFactory +from tianshou.algorithm.modelfree.pg import SimpleLossTrainingStats +from tianshou.algorithm.optim import OptimizerFactory from tianshou.utils.net.common import BranchingNet mark_used(ActBatchProtocol) diff --git 
a/tianshou/policy/modelfree/c51.py b/tianshou/algorithm/modelfree/c51.py similarity index 97% rename from tianshou/policy/modelfree/c51.py rename to tianshou/algorithm/modelfree/c51.py index 01dbbbeab..65b89a914 100644 --- a/tianshou/policy/modelfree/c51.py +++ b/tianshou/algorithm/modelfree/c51.py @@ -4,12 +4,12 @@ from tianshou.data import Batch, ReplayBuffer from tianshou.data.types import RolloutBatchProtocol -from tianshou.policy.modelfree.dqn import ( +from tianshou.algorithm.modelfree.dqn import ( DiscreteQLearningPolicy, QLearningOffPolicyAlgorithm, ) -from tianshou.policy.modelfree.pg import LossSequenceTrainingStats -from tianshou.policy.optim import OptimizerFactory +from tianshou.algorithm.modelfree.pg import LossSequenceTrainingStats +from tianshou.algorithm.optim import OptimizerFactory from tianshou.utils.net.common import Net diff --git a/tianshou/policy/modelfree/ddpg.py b/tianshou/algorithm/modelfree/ddpg.py similarity index 99% rename from tianshou/policy/modelfree/ddpg.py rename to tianshou/algorithm/modelfree/ddpg.py index 235cf1861..97865e2ab 100644 --- a/tianshou/policy/modelfree/ddpg.py +++ b/tianshou/algorithm/modelfree/ddpg.py @@ -18,8 +18,8 @@ RolloutBatchProtocol, ) from tianshou.exploration import BaseNoise, GaussianNoise -from tianshou.policy import Algorithm -from tianshou.policy.base import ( +from tianshou.algorithm import Algorithm +from tianshou.algorithm.base import ( LaggedNetworkPolyakUpdateAlgorithmMixin, OffPolicyAlgorithm, Policy, @@ -27,7 +27,7 @@ TPolicy, TrainingStats, ) -from tianshou.policy.optim import OptimizerFactory +from tianshou.algorithm.optim import OptimizerFactory from tianshou.utils.net.continuous import ( ContinuousActorDeterministicInterface, ContinuousCritic, diff --git a/tianshou/policy/modelfree/discrete_sac.py b/tianshou/algorithm/modelfree/discrete_sac.py similarity index 97% rename from tianshou/policy/modelfree/discrete_sac.py rename to tianshou/algorithm/modelfree/discrete_sac.py index 
a3875e35b..a547d86d4 100644 --- a/tianshou/policy/modelfree/discrete_sac.py +++ b/tianshou/algorithm/modelfree/discrete_sac.py @@ -13,10 +13,10 @@ ObsBatchProtocol, RolloutBatchProtocol, ) -from tianshou.policy.base import Policy -from tianshou.policy.modelfree.sac import Alpha, SACTrainingStats -from tianshou.policy.modelfree.td3 import ActorDualCriticsOffPolicyAlgorithm -from tianshou.policy.optim import OptimizerFactory +from tianshou.algorithm.base import Policy +from tianshou.algorithm.modelfree.sac import Alpha, SACTrainingStats +from tianshou.algorithm.modelfree.td3 import ActorDualCriticsOffPolicyAlgorithm +from tianshou.algorithm.optim import OptimizerFactory from tianshou.utils.net.discrete import DiscreteCritic diff --git a/tianshou/policy/modelfree/dqn.py b/tianshou/algorithm/modelfree/dqn.py similarity index 98% rename from tianshou/policy/modelfree/dqn.py rename to tianshou/algorithm/modelfree/dqn.py index 209ab8269..d0099c0cb 100644 --- a/tianshou/policy/modelfree/dqn.py +++ b/tianshou/algorithm/modelfree/dqn.py @@ -16,17 +16,17 @@ ObsBatchProtocol, RolloutBatchProtocol, ) -from tianshou.policy import Algorithm -from tianshou.policy.base import ( +from tianshou.algorithm import Algorithm +from tianshou.algorithm.base import ( LaggedNetworkFullUpdateAlgorithmMixin, OffPolicyAlgorithm, Policy, TArrOrActBatch, ) -from tianshou.policy.modelfree.pg import ( +from tianshou.algorithm.modelfree.pg import ( SimpleLossTrainingStats, ) -from tianshou.policy.optim import OptimizerFactory +from tianshou.algorithm.optim import OptimizerFactory from tianshou.utils.lagged_network import EvalModeModuleWrapper from tianshou.utils.net.common import Net diff --git a/tianshou/policy/modelfree/fqf.py b/tianshou/algorithm/modelfree/fqf.py similarity index 97% rename from tianshou/policy/modelfree/fqf.py rename to tianshou/algorithm/modelfree/fqf.py index c709056b9..b2110f2eb 100644 --- a/tianshou/policy/modelfree/fqf.py +++ b/tianshou/algorithm/modelfree/fqf.py @@ -9,11 
+9,11 @@ from tianshou.data import Batch, ReplayBuffer, to_numpy from tianshou.data.types import FQFBatchProtocol, ObsBatchProtocol, RolloutBatchProtocol -from tianshou.policy import QRDQN, Algorithm -from tianshou.policy.modelfree.dqn import DiscreteQLearningPolicy -from tianshou.policy.modelfree.pg import SimpleLossTrainingStats -from tianshou.policy.modelfree.qrdqn import QRDQNPolicy -from tianshou.policy.optim import OptimizerFactory +from tianshou.algorithm import QRDQN, Algorithm +from tianshou.algorithm.modelfree.dqn import DiscreteQLearningPolicy +from tianshou.algorithm.modelfree.pg import SimpleLossTrainingStats +from tianshou.algorithm.modelfree.qrdqn import QRDQNPolicy +from tianshou.algorithm.optim import OptimizerFactory from tianshou.utils.net.discrete import FractionProposalNetwork, FullQuantileFunction diff --git a/tianshou/policy/modelfree/iqn.py b/tianshou/algorithm/modelfree/iqn.py similarity index 97% rename from tianshou/policy/modelfree/iqn.py rename to tianshou/algorithm/modelfree/iqn.py index dec4d87a0..977e64797 100644 --- a/tianshou/policy/modelfree/iqn.py +++ b/tianshou/algorithm/modelfree/iqn.py @@ -12,10 +12,10 @@ QuantileRegressionBatchProtocol, RolloutBatchProtocol, ) -from tianshou.policy import QRDQN -from tianshou.policy.modelfree.pg import SimpleLossTrainingStats -from tianshou.policy.modelfree.qrdqn import QRDQNPolicy -from tianshou.policy.optim import OptimizerFactory +from tianshou.algorithm import QRDQN +from tianshou.algorithm.modelfree.pg import SimpleLossTrainingStats +from tianshou.algorithm.modelfree.qrdqn import QRDQNPolicy +from tianshou.algorithm.optim import OptimizerFactory class IQNPolicy(QRDQNPolicy): diff --git a/tianshou/policy/modelfree/npg.py b/tianshou/algorithm/modelfree/npg.py similarity index 97% rename from tianshou/policy/modelfree/npg.py rename to tianshou/algorithm/modelfree/npg.py index b47145257..7d7949572 100644 --- a/tianshou/policy/modelfree/npg.py +++ b/tianshou/algorithm/modelfree/npg.py @@ 
-9,10 +9,10 @@ from tianshou.data import ReplayBuffer, SequenceSummaryStats, to_torch_as from tianshou.data.types import BatchWithAdvantagesProtocol, RolloutBatchProtocol -from tianshou.policy.base import TrainingStats -from tianshou.policy.modelfree.a2c import ActorCriticOnPolicyAlgorithm -from tianshou.policy.modelfree.pg import ActorPolicyProbabilistic -from tianshou.policy.optim import OptimizerFactory +from tianshou.algorithm.base import TrainingStats +from tianshou.algorithm.modelfree.a2c import ActorCriticOnPolicyAlgorithm +from tianshou.algorithm.modelfree.pg import ActorPolicyProbabilistic +from tianshou.algorithm.optim import OptimizerFactory from tianshou.utils.net.continuous import ContinuousCritic from tianshou.utils.net.discrete import DiscreteCritic diff --git a/tianshou/policy/modelfree/pg.py b/tianshou/algorithm/modelfree/pg.py similarity index 99% rename from tianshou/policy/modelfree/pg.py rename to tianshou/algorithm/modelfree/pg.py index 0dd4b9685..689e21d06 100644 --- a/tianshou/policy/modelfree/pg.py +++ b/tianshou/algorithm/modelfree/pg.py @@ -22,13 +22,13 @@ ObsBatchProtocol, RolloutBatchProtocol, ) -from tianshou.policy import Algorithm -from tianshou.policy.base import ( +from tianshou.algorithm import Algorithm +from tianshou.algorithm.base import ( OnPolicyAlgorithm, Policy, TrainingStats, ) -from tianshou.policy.optim import OptimizerFactory +from tianshou.algorithm.optim import OptimizerFactory from tianshou.utils import RunningMeanStd from tianshou.utils.net.common import ( ContinuousActorProbabilisticInterface, diff --git a/tianshou/policy/modelfree/ppo.py b/tianshou/algorithm/modelfree/ppo.py similarity index 98% rename from tianshou/policy/modelfree/ppo.py rename to tianshou/algorithm/modelfree/ppo.py index a1d2d73cf..b47d9be3c 100644 --- a/tianshou/policy/modelfree/ppo.py +++ b/tianshou/algorithm/modelfree/ppo.py @@ -5,10 +5,10 @@ from tianshou.data import ReplayBuffer, SequenceSummaryStats, to_torch_as from tianshou.data.types 
import LogpOldProtocol, RolloutBatchProtocol -from tianshou.policy import A2C -from tianshou.policy.modelfree.a2c import A2CTrainingStats -from tianshou.policy.modelfree.pg import ActorPolicyProbabilistic -from tianshou.policy.optim import OptimizerFactory +from tianshou.algorithm import A2C +from tianshou.algorithm.modelfree.a2c import A2CTrainingStats +from tianshou.algorithm.modelfree.pg import ActorPolicyProbabilistic +from tianshou.algorithm.optim import OptimizerFactory from tianshou.utils.net.continuous import ContinuousCritic from tianshou.utils.net.discrete import DiscreteCritic diff --git a/tianshou/policy/modelfree/qrdqn.py b/tianshou/algorithm/modelfree/qrdqn.py similarity index 97% rename from tianshou/policy/modelfree/qrdqn.py rename to tianshou/algorithm/modelfree/qrdqn.py index 8f9daa58c..5e0e9ebb7 100644 --- a/tianshou/policy/modelfree/qrdqn.py +++ b/tianshou/algorithm/modelfree/qrdqn.py @@ -7,12 +7,12 @@ from tianshou.data import Batch, ReplayBuffer from tianshou.data.types import RolloutBatchProtocol -from tianshou.policy.modelfree.dqn import ( +from tianshou.algorithm.modelfree.dqn import ( DiscreteQLearningPolicy, QLearningOffPolicyAlgorithm, ) -from tianshou.policy.modelfree.pg import SimpleLossTrainingStats -from tianshou.policy.optim import OptimizerFactory +from tianshou.algorithm.modelfree.pg import SimpleLossTrainingStats +from tianshou.algorithm.optim import OptimizerFactory class QRDQNPolicy(DiscreteQLearningPolicy): diff --git a/tianshou/policy/modelfree/rainbow.py b/tianshou/algorithm/modelfree/rainbow.py similarity index 96% rename from tianshou/policy/modelfree/rainbow.py rename to tianshou/algorithm/modelfree/rainbow.py index 501330e45..ba9466a7e 100644 --- a/tianshou/policy/modelfree/rainbow.py +++ b/tianshou/algorithm/modelfree/rainbow.py @@ -3,9 +3,9 @@ from torch import nn from tianshou.data.types import RolloutBatchProtocol -from tianshou.policy.modelfree.c51 import C51, C51Policy -from tianshou.policy.modelfree.pg import 
LossSequenceTrainingStats -from tianshou.policy.optim import OptimizerFactory +from tianshou.algorithm.modelfree.c51 import C51, C51Policy +from tianshou.algorithm.modelfree.pg import LossSequenceTrainingStats +from tianshou.algorithm.optim import OptimizerFactory from tianshou.utils.lagged_network import EvalModeModuleWrapper from tianshou.utils.net.discrete import NoisyLinear diff --git a/tianshou/policy/modelfree/redq.py b/tianshou/algorithm/modelfree/redq.py similarity index 99% rename from tianshou/policy/modelfree/redq.py rename to tianshou/algorithm/modelfree/redq.py index 2eda560c9..fa1ceb5a4 100644 --- a/tianshou/policy/modelfree/redq.py +++ b/tianshou/algorithm/modelfree/redq.py @@ -13,13 +13,13 @@ RolloutBatchProtocol, ) from tianshou.exploration import BaseNoise -from tianshou.policy.modelfree.ddpg import ( +from tianshou.algorithm.modelfree.ddpg import ( ActorCriticOffPolicyAlgorithm, ContinuousPolicyWithExplorationNoise, DDPGTrainingStats, ) -from tianshou.policy.modelfree.sac import Alpha -from tianshou.policy.optim import OptimizerFactory +from tianshou.algorithm.modelfree.sac import Alpha +from tianshou.algorithm.optim import OptimizerFactory from tianshou.utils.net.continuous import ContinuousActorProbabilistic diff --git a/tianshou/policy/modelfree/sac.py b/tianshou/algorithm/modelfree/sac.py similarity index 98% rename from tianshou/policy/modelfree/sac.py rename to tianshou/algorithm/modelfree/sac.py index bb5489f0e..fbfaeac84 100644 --- a/tianshou/policy/modelfree/sac.py +++ b/tianshou/algorithm/modelfree/sac.py @@ -14,10 +14,10 @@ RolloutBatchProtocol, ) from tianshou.exploration import BaseNoise -from tianshou.policy.base import TrainingStats -from tianshou.policy.modelfree.ddpg import ContinuousPolicyWithExplorationNoise -from tianshou.policy.modelfree.td3 import ActorDualCriticsOffPolicyAlgorithm -from tianshou.policy.optim import OptimizerFactory +from tianshou.algorithm.base import TrainingStats +from tianshou.algorithm.modelfree.ddpg 
import ContinuousPolicyWithExplorationNoise +from tianshou.algorithm.modelfree.td3 import ActorDualCriticsOffPolicyAlgorithm +from tianshou.algorithm.optim import OptimizerFactory from tianshou.utils.conversion import to_optional_float from tianshou.utils.net.continuous import ContinuousActorProbabilistic diff --git a/tianshou/policy/modelfree/td3.py b/tianshou/algorithm/modelfree/td3.py similarity index 98% rename from tianshou/policy/modelfree/td3.py rename to tianshou/algorithm/modelfree/td3.py index 9100a37e8..37704bf14 100644 --- a/tianshou/policy/modelfree/td3.py +++ b/tianshou/algorithm/modelfree/td3.py @@ -10,16 +10,16 @@ ActStateBatchProtocol, RolloutBatchProtocol, ) -from tianshou.policy.base import ( +from tianshou.algorithm.base import ( TPolicy, TrainingStats, ) -from tianshou.policy.modelfree.ddpg import ( +from tianshou.algorithm.modelfree.ddpg import ( ActorCriticOffPolicyAlgorithm, ContinuousDeterministicPolicy, TActBatchProtocol, ) -from tianshou.policy.optim import OptimizerFactory +from tianshou.algorithm.optim import OptimizerFactory @dataclass(kw_only=True) diff --git a/tianshou/policy/modelfree/trpo.py b/tianshou/algorithm/modelfree/trpo.py similarity index 98% rename from tianshou/policy/modelfree/trpo.py rename to tianshou/algorithm/modelfree/trpo.py index 953da5b0b..71e5ba9d2 100644 --- a/tianshou/policy/modelfree/trpo.py +++ b/tianshou/algorithm/modelfree/trpo.py @@ -7,10 +7,10 @@ from tianshou.data import SequenceSummaryStats from tianshou.data.types import BatchWithAdvantagesProtocol -from tianshou.policy import NPG -from tianshou.policy.modelfree.npg import NPGTrainingStats -from tianshou.policy.modelfree.pg import ActorPolicyProbabilistic -from tianshou.policy.optim import OptimizerFactory +from tianshou.algorithm import NPG +from tianshou.algorithm.modelfree.npg import NPGTrainingStats +from tianshou.algorithm.modelfree.pg import ActorPolicyProbabilistic +from tianshou.algorithm.optim import OptimizerFactory from 
tianshou.utils.net.continuous import ContinuousCritic from tianshou.utils.net.discrete import DiscreteCritic diff --git a/tianshou/policy/multiagent/__init__.py b/tianshou/algorithm/multiagent/__init__.py similarity index 100% rename from tianshou/policy/multiagent/__init__.py rename to tianshou/algorithm/multiagent/__init__.py diff --git a/tianshou/policy/multiagent/mapolicy.py b/tianshou/algorithm/multiagent/mapolicy.py similarity index 99% rename from tianshou/policy/multiagent/mapolicy.py rename to tianshou/algorithm/multiagent/mapolicy.py index 588ad4a50..825e0b331 100644 --- a/tianshou/policy/multiagent/mapolicy.py +++ b/tianshou/algorithm/multiagent/mapolicy.py @@ -9,8 +9,8 @@ from tianshou.data import Batch, ReplayBuffer from tianshou.data.batch import BatchProtocol, IndexType from tianshou.data.types import ActBatchProtocol, ObsBatchProtocol, RolloutBatchProtocol -from tianshou.policy import Algorithm -from tianshou.policy.base import ( +from tianshou.algorithm import Algorithm +from tianshou.algorithm.base import ( OffPolicyAlgorithm, OnPolicyAlgorithm, Policy, diff --git a/tianshou/policy/optim.py b/tianshou/algorithm/optim.py similarity index 97% rename from tianshou/policy/optim.py rename to tianshou/algorithm/optim.py index a03c95871..b949a96f2 100644 --- a/tianshou/policy/optim.py +++ b/tianshou/algorithm/optim.py @@ -1,139 +1,139 @@ -from abc import ABC, abstractmethod -from collections.abc import Callable, Iterable -from typing import Any, Self, TypeAlias - -import numpy as np -import torch -from sensai.util.string import ToStringMixin -from torch.optim import Adam, RMSprop -from torch.optim.lr_scheduler import LambdaLR, LRScheduler - -ParamsType: TypeAlias = Iterable[torch.Tensor] | Iterable[dict[str, Any]] - - -class LRSchedulerFactory(ToStringMixin, ABC): - """Factory for the creation of a learning rate scheduler.""" - - @abstractmethod - def create_scheduler(self, optim: torch.optim.Optimizer) -> LRScheduler: - pass - - -class 
LRSchedulerFactoryLinear(LRSchedulerFactory): - """ - Factory for a learning rate scheduler where the learning rate linearly decays towards - zero for the given trainer parameters. - """ - - def __init__(self, num_epochs: int, step_per_epoch: int, step_per_collect: int): - self.num_epochs = num_epochs - self.step_per_epoch = step_per_epoch - self.step_per_collect = step_per_collect - - def create_scheduler(self, optim: torch.optim.Optimizer) -> LRScheduler: - return LambdaLR(optim, lr_lambda=self._LRLambda(self).compute) - - class _LRLambda: - def __init__(self, parent: "LRSchedulerFactoryLinear"): - self.max_update_num = ( - np.ceil(parent.step_per_epoch / parent.step_per_collect) * parent.num_epochs - ) - - def compute(self, epoch: int) -> float: - return 1.0 - epoch / self.max_update_num - - -class OptimizerFactory(ABC, ToStringMixin): - def __init__(self) -> None: - self.lr_scheduler_factory: LRSchedulerFactory | None = None - - def with_lr_scheduler_factory(self, lr_scheduler_factory: LRSchedulerFactory) -> Self: - self.lr_scheduler_factory = lr_scheduler_factory - return self - - def create_instances( - self, - module: torch.nn.Module, - ) -> tuple[torch.optim.Optimizer, LRScheduler | None]: - optimizer = self._create_optimizer_for_params(module.parameters()) - lr_scheduler = None - if self.lr_scheduler_factory is not None: - lr_scheduler = self.lr_scheduler_factory.create_scheduler(optimizer) - return optimizer, lr_scheduler - - @abstractmethod - def _create_optimizer_for_params(self, params: ParamsType) -> torch.optim.Optimizer: - pass - - -class TorchOptimizerFactory(OptimizerFactory): - """General factory for arbitrary torch optimizers.""" - - def __init__(self, optim_class: Callable[..., torch.optim.Optimizer], **kwargs: Any): - """ - - :param optim_class: the optimizer class (e.g. subclass of `torch.optim.Optimizer`), - which will be passed the module parameters, the learning rate as `lr` and the - kwargs provided. 
- :param kwargs: keyword arguments to provide at optimizer construction - """ - super().__init__() - self.optim_class = optim_class - self.kwargs = kwargs - - def _create_optimizer_for_params(self, params: ParamsType) -> torch.optim.Optimizer: - return self.optim_class(params, **self.kwargs) - - -class AdamOptimizerFactory(OptimizerFactory): - def __init__( - self, - lr: float = 1e-3, - betas: tuple[float, float] = (0.9, 0.999), - eps: float = 1e-08, - weight_decay: float = 0, - ): - super().__init__() - self.lr = lr - self.weight_decay = weight_decay - self.eps = eps - self.betas = betas - - def _create_optimizer_for_params(self, params: ParamsType) -> torch.optim.Optimizer: - return Adam( - params, - lr=self.lr, - betas=self.betas, - eps=self.eps, - weight_decay=self.weight_decay, - ) - - -class RMSpropOptimizerFactory(OptimizerFactory): - def __init__( - self, - lr: float = 1e-2, - alpha: float = 0.99, - eps: float = 1e-08, - weight_decay: float = 0, - momentum: float = 0, - centered: bool = False, - ): - super().__init__() - self.lr = lr - self.alpha = alpha - self.momentum = momentum - self.centered = centered - self.weight_decay = weight_decay - self.eps = eps - - def _create_optimizer_for_params(self, params: ParamsType) -> torch.optim.Optimizer: - return RMSprop( - params, - lr=self.lr, - alpha=self.alpha, - eps=self.eps, - weight_decay=self.weight_decay, - momentum=self.momentum, - centered=self.centered, - ) +from abc import ABC, abstractmethod +from collections.abc import Callable, Iterable +from typing import Any, Self, TypeAlias + +import numpy as np +import torch +from sensai.util.string import ToStringMixin +from torch.optim import Adam, RMSprop +from torch.optim.lr_scheduler import LambdaLR, LRScheduler + +ParamsType: TypeAlias = Iterable[torch.Tensor] | Iterable[dict[str, Any]] + + +class LRSchedulerFactory(ToStringMixin, ABC): + """Factory for the creation of a learning rate scheduler.""" + + @abstractmethod + def create_scheduler(self, optim: 
torch.optim.Optimizer) -> LRScheduler: + pass + + +class LRSchedulerFactoryLinear(LRSchedulerFactory): + """ + Factory for a learning rate scheduler where the learning rate linearly decays towards + zero for the given trainer parameters. + """ + + def __init__(self, num_epochs: int, step_per_epoch: int, step_per_collect: int): + self.num_epochs = num_epochs + self.step_per_epoch = step_per_epoch + self.step_per_collect = step_per_collect + + def create_scheduler(self, optim: torch.optim.Optimizer) -> LRScheduler: + return LambdaLR(optim, lr_lambda=self._LRLambda(self).compute) + + class _LRLambda: + def __init__(self, parent: "LRSchedulerFactoryLinear"): + self.max_update_num = ( + np.ceil(parent.step_per_epoch / parent.step_per_collect) * parent.num_epochs + ) + + def compute(self, epoch: int) -> float: + return 1.0 - epoch / self.max_update_num + + +class OptimizerFactory(ABC, ToStringMixin): + def __init__(self) -> None: + self.lr_scheduler_factory: LRSchedulerFactory | None = None + + def with_lr_scheduler_factory(self, lr_scheduler_factory: LRSchedulerFactory) -> Self: + self.lr_scheduler_factory = lr_scheduler_factory + return self + + def create_instances( + self, + module: torch.nn.Module, + ) -> tuple[torch.optim.Optimizer, LRScheduler | None]: + optimizer = self._create_optimizer_for_params(module.parameters()) + lr_scheduler = None + if self.lr_scheduler_factory is not None: + lr_scheduler = self.lr_scheduler_factory.create_scheduler(optimizer) + return optimizer, lr_scheduler + + @abstractmethod + def _create_optimizer_for_params(self, params: ParamsType) -> torch.optim.Optimizer: + pass + + +class TorchOptimizerFactory(OptimizerFactory): + """General factory for arbitrary torch optimizers.""" + + def __init__(self, optim_class: Callable[..., torch.optim.Optimizer], **kwargs: Any): + """ + + :param optim_class: the optimizer class (e.g. 
subclass of `torch.optim.Optimizer`), + which will be passed the module parameters, the learning rate as `lr` and the + kwargs provided. + :param kwargs: keyword arguments to provide at optimizer construction + """ + super().__init__() + self.optim_class = optim_class + self.kwargs = kwargs + + def _create_optimizer_for_params(self, params: ParamsType) -> torch.optim.Optimizer: + return self.optim_class(params, **self.kwargs) + + +class AdamOptimizerFactory(OptimizerFactory): + def __init__( + self, + lr: float = 1e-3, + betas: tuple[float, float] = (0.9, 0.999), + eps: float = 1e-08, + weight_decay: float = 0, + ): + super().__init__() + self.lr = lr + self.weight_decay = weight_decay + self.eps = eps + self.betas = betas + + def _create_optimizer_for_params(self, params: ParamsType) -> torch.optim.Optimizer: + return Adam( + params, + lr=self.lr, + betas=self.betas, + eps=self.eps, + weight_decay=self.weight_decay, + ) + + +class RMSpropOptimizerFactory(OptimizerFactory): + def __init__( + self, + lr: float = 1e-2, + alpha: float = 0.99, + eps: float = 1e-08, + weight_decay: float = 0, + momentum: float = 0, + centered: bool = False, + ): + super().__init__() + self.lr = lr + self.alpha = alpha + self.momentum = momentum + self.centered = centered + self.weight_decay = weight_decay + self.eps = eps + + def _create_optimizer_for_params(self, params: ParamsType) -> torch.optim.Optimizer: + return RMSprop( + params, + lr=self.lr, + alpha=self.alpha, + eps=self.eps, + weight_decay=self.weight_decay, + momentum=self.momentum, + centered=self.centered, + ) diff --git a/tianshou/policy/random.py b/tianshou/algorithm/random.py similarity index 95% rename from tianshou/policy/random.py rename to tianshou/algorithm/random.py index db0ad27a1..fb862e657 100644 --- a/tianshou/policy/random.py +++ b/tianshou/algorithm/random.py @@ -6,8 +6,8 @@ from tianshou.data import Batch from tianshou.data.batch import BatchProtocol from tianshou.data.types import ActBatchProtocol, 
ObsBatchProtocol, RolloutBatchProtocol -from tianshou.policy import base -from tianshou.policy.base import OffPolicyAlgorithm, TrainingStats +from tianshou.algorithm import base +from tianshou.algorithm.base import OffPolicyAlgorithm, TrainingStats class MARLRandomTrainingStats(TrainingStats): diff --git a/tianshou/data/collector.py b/tianshou/data/collector.py index f84709de6..c5826e5cf 100644 --- a/tianshou/data/collector.py +++ b/tianshou/data/collector.py @@ -30,8 +30,8 @@ RolloutBatchProtocol, ) from tianshou.env import BaseVectorEnv, DummyVectorEnv -from tianshou.policy import Algorithm -from tianshou.policy.base import Policy, episode_mc_return_to_go +from tianshou.algorithm import Algorithm +from tianshou.algorithm.base import Policy, episode_mc_return_to_go from tianshou.utils.determinism import TraceLogger from tianshou.utils.print import DataclassPPrintMixin from tianshou.utils.torch_utils import torch_train_mode diff --git a/tianshou/data/stats.py b/tianshou/data/stats.py index 51fba5c2d..828dffe64 100644 --- a/tianshou/data/stats.py +++ b/tianshou/data/stats.py @@ -9,7 +9,7 @@ if TYPE_CHECKING: from tianshou.data import CollectStats, CollectStatsBase - from tianshou.policy.base import TrainingStats + from tianshou.algorithm.base import TrainingStats log = logging.getLogger(__name__) diff --git a/tianshou/env/atari/atari_network.py b/tianshou/env/atari/atari_network.py index f14d73670..6b2848a29 100644 --- a/tianshou/env/atari/atari_network.py +++ b/tianshou/env/atari/atari_network.py @@ -15,7 +15,7 @@ IntermediateModuleFactory, ) from tianshou.highlevel.params.dist_fn import DistributionFunctionFactoryCategorical -from tianshou.policy.modelfree.pg import TDistFnDiscrOrCont +from tianshou.algorithm.modelfree.pg import TDistFnDiscrOrCont from tianshou.utils.net.common import NetBase from tianshou.utils.net.discrete import DiscreteActor, NoisyLinear from tianshou.utils.torch_utils import torch_device diff --git a/tianshou/highlevel/algorithm.py 
b/tianshou/highlevel/algorithm.py index fa7027630..ad62af203 100644 --- a/tianshou/highlevel/algorithm.py +++ b/tianshou/highlevel/algorithm.py @@ -46,7 +46,7 @@ from tianshou.highlevel.persistence import PolicyPersistence from tianshou.highlevel.trainer import TrainerCallbacks, TrainingContext from tianshou.highlevel.world import World -from tianshou.policy import ( +from tianshou.algorithm import ( A2C, DDPG, DQN, @@ -61,18 +61,18 @@ DiscreteSAC, Reinforce, ) -from tianshou.policy.base import ( +from tianshou.algorithm.base import ( OffPolicyAlgorithm, OnPolicyAlgorithm, Policy, ) -from tianshou.policy.modelfree.ddpg import ContinuousDeterministicPolicy -from tianshou.policy.modelfree.discrete_sac import DiscreteSACPolicy -from tianshou.policy.modelfree.dqn import DiscreteQLearningPolicy -from tianshou.policy.modelfree.iqn import IQNPolicy -from tianshou.policy.modelfree.pg import ActorPolicyProbabilistic -from tianshou.policy.modelfree.redq import REDQPolicy -from tianshou.policy.modelfree.sac import SACPolicy +from tianshou.algorithm.modelfree.ddpg import ContinuousDeterministicPolicy +from tianshou.algorithm.modelfree.discrete_sac import DiscreteSACPolicy +from tianshou.algorithm.modelfree.dqn import DiscreteQLearningPolicy +from tianshou.algorithm.modelfree.iqn import IQNPolicy +from tianshou.algorithm.modelfree.pg import ActorPolicyProbabilistic +from tianshou.algorithm.modelfree.redq import REDQPolicy +from tianshou.algorithm.modelfree.sac import SACPolicy from tianshou.trainer import OffPolicyTrainer, OnPolicyTrainer, Trainer from tianshou.trainer.base import OffPolicyTrainerParams, OnPolicyTrainerParams from tianshou.utils.net.discrete import DiscreteActor diff --git a/tianshou/highlevel/experiment.py b/tianshou/highlevel/experiment.py index ed275ac83..107376297 100644 --- a/tianshou/highlevel/experiment.py +++ b/tianshou/highlevel/experiment.py @@ -112,7 +112,7 @@ TrainerCallbacks, ) from tianshou.highlevel.world import World -from tianshou.policy import 
Algorithm +from tianshou.algorithm import Algorithm from tianshou.utils import LazyLogger from tianshou.utils.net.common import ModuleType from tianshou.utils.print import DataclassPPrintMixin diff --git a/tianshou/highlevel/module/actor.py b/tianshou/highlevel/module/actor.py index e3807569b..0a6055a5f 100644 --- a/tianshou/highlevel/module/actor.py +++ b/tianshou/highlevel/module/actor.py @@ -22,7 +22,7 @@ DistributionFunctionFactoryCategorical, DistributionFunctionFactoryIndependentGaussians, ) -from tianshou.policy.modelfree.pg import TDistFnDiscrOrCont +from tianshou.algorithm.modelfree.pg import TDistFnDiscrOrCont from tianshou.utils.net import continuous, discrete from tianshou.utils.net.common import Actor, ModuleType, ModuleWithVectorOutput, Net diff --git a/tianshou/highlevel/optim.py b/tianshou/highlevel/optim.py index 66e6d154c..c80a93331 100644 --- a/tianshou/highlevel/optim.py +++ b/tianshou/highlevel/optim.py @@ -5,7 +5,7 @@ import torch from sensai.util.string import ToStringMixin -from tianshou.policy.optim import ( +from tianshou.algorithm.optim import ( AdamOptimizerFactory, OptimizerFactory, RMSpropOptimizerFactory, diff --git a/tianshou/highlevel/params/alpha.py b/tianshou/highlevel/params/alpha.py index 39413d965..7bd1a3baf 100644 --- a/tianshou/highlevel/params/alpha.py +++ b/tianshou/highlevel/params/alpha.py @@ -6,7 +6,7 @@ from tianshou.highlevel.env import Environments from tianshou.highlevel.module.core import TDevice from tianshou.highlevel.optim import OptimizerFactoryFactory -from tianshou.policy.modelfree.sac import Alpha, AutoAlpha +from tianshou.algorithm.modelfree.sac import Alpha, AutoAlpha class AutoAlphaFactory(ToStringMixin, ABC): diff --git a/tianshou/highlevel/params/dist_fn.py b/tianshou/highlevel/params/dist_fn.py index 6cb436185..11e37af2a 100644 --- a/tianshou/highlevel/params/dist_fn.py +++ b/tianshou/highlevel/params/dist_fn.py @@ -6,7 +6,7 @@ from sensai.util.string import ToStringMixin from tianshou.highlevel.env 
import Environments -from tianshou.policy.modelfree.pg import TDistFnDiscrete, TDistFnDiscrOrCont +from tianshou.algorithm.modelfree.pg import TDistFnDiscrete, TDistFnDiscrOrCont class DistributionFunctionFactory(ToStringMixin, ABC): diff --git a/tianshou/highlevel/params/lr_scheduler.py b/tianshou/highlevel/params/lr_scheduler.py index 47ea04897..fb9652e5f 100644 --- a/tianshou/highlevel/params/lr_scheduler.py +++ b/tianshou/highlevel/params/lr_scheduler.py @@ -3,7 +3,7 @@ from sensai.util.string import ToStringMixin from tianshou.highlevel.config import TrainingConfig -from tianshou.policy.optim import LRSchedulerFactory, LRSchedulerFactoryLinear +from tianshou.algorithm.optim import LRSchedulerFactory, LRSchedulerFactoryLinear class LRSchedulerFactoryFactory(ToStringMixin, ABC): diff --git a/tianshou/highlevel/params/policy_wrapper.py b/tianshou/highlevel/params/policy_wrapper.py index 41469cf7c..e85c7e37d 100644 --- a/tianshou/highlevel/params/policy_wrapper.py +++ b/tianshou/highlevel/params/policy_wrapper.py @@ -8,9 +8,9 @@ from tianshou.highlevel.module.core import TDevice from tianshou.highlevel.module.intermediate import IntermediateModuleFactory from tianshou.highlevel.optim import OptimizerFactoryFactory -from tianshou.policy import Algorithm, ICMOffPolicyWrapper -from tianshou.policy.base import OffPolicyAlgorithm, OnPolicyAlgorithm -from tianshou.policy.modelbased.icm import ICMOnPolicyWrapper +from tianshou.algorithm import Algorithm, ICMOffPolicyWrapper +from tianshou.algorithm.base import OffPolicyAlgorithm, OnPolicyAlgorithm +from tianshou.algorithm.modelbased.icm import ICMOnPolicyWrapper from tianshou.utils.net.discrete import IntrinsicCuriosityModule TAlgorithmOut = TypeVar("TAlgorithmOut", bound=Algorithm) diff --git a/tianshou/highlevel/trainer.py b/tianshou/highlevel/trainer.py index 8ae3f3b7c..bc4c8cf62 100644 --- a/tianshou/highlevel/trainer.py +++ b/tianshou/highlevel/trainer.py @@ -8,8 +8,8 @@ from tianshou.highlevel.env import 
Environments from tianshou.highlevel.logger import TLogger -from tianshou.policy import DQN, Algorithm -from tianshou.policy.modelfree.dqn import DiscreteQLearningPolicy +from tianshou.algorithm import DQN, Algorithm +from tianshou.algorithm.modelfree.dqn import DiscreteQLearningPolicy TAlgorithm = TypeVar("TAlgorithm", bound=Algorithm) log = logging.getLogger(__name__) diff --git a/tianshou/highlevel/world.py b/tianshou/highlevel/world.py index 2a68d3ff8..bd92178bf 100644 --- a/tianshou/highlevel/world.py +++ b/tianshou/highlevel/world.py @@ -6,7 +6,7 @@ from tianshou.data import BaseCollector from tianshou.highlevel.env import Environments from tianshou.highlevel.logger import TLogger - from tianshou.policy import Algorithm + from tianshou.algorithm import Algorithm from tianshou.trainer import Trainer diff --git a/tianshou/policy/__init__.py b/tianshou/policy/__init__.py deleted file mode 100644 index cf1815273..000000000 --- a/tianshou/policy/__init__.py +++ /dev/null @@ -1,35 +0,0 @@ -"""Policy package.""" -# isort:skip_file - -from tianshou.policy.base import Algorithm, TrainingStats -from tianshou.policy.modelfree.pg import Reinforce -from tianshou.policy.modelfree.dqn import DQN -from tianshou.policy.modelfree.ddpg import DDPG - -from tianshou.policy.random import MARLRandomDiscreteMaskedOffPolicyAlgorithm -from tianshou.policy.modelfree.bdqn import BDQN -from tianshou.policy.modelfree.c51 import C51 -from tianshou.policy.modelfree.rainbow import RainbowDQN -from tianshou.policy.modelfree.qrdqn import QRDQN -from tianshou.policy.modelfree.iqn import IQN -from tianshou.policy.modelfree.fqf import FQF -from tianshou.policy.modelfree.a2c import A2C -from tianshou.policy.modelfree.npg import NPG -from tianshou.policy.modelfree.ppo import PPO -from tianshou.policy.modelfree.trpo import TRPO -from tianshou.policy.modelfree.td3 import TD3 -from tianshou.policy.modelfree.sac import SAC -from tianshou.policy.modelfree.redq import REDQ -from 
tianshou.policy.modelfree.discrete_sac import DiscreteSAC -from tianshou.policy.imitation.base import OffPolicyImitationLearning -from tianshou.policy.imitation.bcq import BCQ -from tianshou.policy.imitation.cql import CQL -from tianshou.policy.imitation.td3_bc import TD3BC -from tianshou.policy.imitation.discrete_bcq import DiscreteBCQ -from tianshou.policy.imitation.discrete_cql import DiscreteCQL -from tianshou.policy.imitation.discrete_crr import DiscreteCRR -from tianshou.policy.imitation.gail import GAIL -from tianshou.policy.modelbased.psrl import PSRL -from tianshou.policy.modelbased.icm import ICMOffPolicyWrapper -from tianshou.policy.modelbased.icm import ICMOnPolicyWrapper -from tianshou.policy.multiagent.mapolicy import MultiAgentOffPolicyAlgorithm diff --git a/tianshou/trainer/base.py b/tianshou/trainer/base.py index 4ef5ab93c..726bad1a1 100644 --- a/tianshou/trainer/base.py +++ b/tianshou/trainer/base.py @@ -47,7 +47,7 @@ ) from tianshou.data.buffer.base import MalformedBufferError from tianshou.data.collector import BaseCollector, CollectStatsBase -from tianshou.policy.base import ( +from tianshou.algorithm.base import ( Algorithm, OfflineAlgorithm, OffPolicyAlgorithm, diff --git a/tianshou/utils/torch_utils.py b/tianshou/utils/torch_utils.py index 723d19fad..ebfa262db 100644 --- a/tianshou/utils/torch_utils.py +++ b/tianshou/utils/torch_utils.py @@ -8,7 +8,7 @@ from torch import nn if TYPE_CHECKING: - from tianshou.policy.base import Policy + from tianshou.algorithm.base import Policy @contextmanager From d88eff59b6b6e25a163e9a15156d211d182bdcbe Mon Sep 17 00:00:00 2001 From: Dominik Jain Date: Thu, 15 May 2025 23:23:00 +0200 Subject: [PATCH 168/230] v2: Rename module algorithm.base -> algorithm.algorithm_base --- examples/atari/atari_c51.py | 2 +- examples/atari/atari_dqn.py | 2 +- examples/atari/atari_fqf.py | 2 +- examples/atari/atari_iqn.py | 2 +- examples/atari/atari_ppo.py | 2 +- examples/atari/atari_qrdqn.py | 2 +- 
examples/atari/atari_rainbow.py | 2 +- examples/atari/atari_sac.py | 2 +- examples/box2d/acrobot_dualdqn.py | 2 +- examples/box2d/bipedal_bdq.py | 2 +- examples/box2d/bipedal_hardcore_sac.py | 2 +- examples/box2d/lunarlander_dqn.py | 2 +- examples/box2d/mcc_sac.py | 2 +- examples/inverse/irl_gail.py | 2 +- examples/mujoco/fetch_her_ddpg.py | 2 +- examples/mujoco/mujoco_a2c.py | 2 +- examples/mujoco/mujoco_ddpg.py | 2 +- examples/mujoco/mujoco_npg.py | 2 +- examples/mujoco/mujoco_ppo.py | 2 +- examples/mujoco/mujoco_redq.py | 2 +- examples/mujoco/mujoco_reinforce.py | 2 +- examples/mujoco/mujoco_sac.py | 2 +- examples/mujoco/mujoco_td3.py | 2 +- examples/mujoco/mujoco_trpo.py | 2 +- examples/offline/atari_bcq.py | 2 +- examples/offline/atari_cql.py | 2 +- examples/offline/atari_crr.py | 2 +- examples/offline/atari_il.py | 2 +- examples/offline/d4rl_bcq.py | 2 +- examples/offline/d4rl_cql.py | 2 +- examples/offline/d4rl_il.py | 2 +- examples/offline/d4rl_td3_bc.py | 2 +- examples/vizdoom/vizdoom_c51.py | 2 +- examples/vizdoom/vizdoom_ppo.py | 2 +- test/base/test_collector.py | 2 +- test/base/test_env_finite.py | 2 +- test/base/test_policy.py | 2 +- test/base/test_stats.py | 2 +- test/continuous/test_ddpg.py | 2 +- test/continuous/test_npg.py | 2 +- test/continuous/test_ppo.py | 2 +- test/continuous/test_redq.py | 2 +- test/continuous/test_sac_with_il.py | 2 +- test/continuous/test_td3.py | 2 +- test/continuous/test_trpo.py | 2 +- test/discrete/test_a2c_with_il.py | 2 +- test/discrete/test_c51.py | 2 +- test/discrete/test_discrete_sac.py | 2 +- test/discrete/test_dqn.py | 2 +- test/discrete/test_drqn.py | 2 +- test/discrete/test_fqf.py | 2 +- test/discrete/test_iqn.py | 2 +- test/discrete/test_pg.py | 2 +- test/discrete/test_ppo_discrete.py | 2 +- test/discrete/test_qrdqn.py | 2 +- test/discrete/test_rainbow.py | 2 +- test/modelbased/test_ppo_icm.py | 2 +- test/offline/gather_cartpole_data.py | 2 +- test/offline/gather_pendulum_data.py | 2 +- 
test/offline/test_td3_bc.py | 2 +- test/pettingzoo/pistonball.py | 2 +- test/pettingzoo/pistonball_continuous.py | 2 +- test/pettingzoo/tic_tac_toe.py | 2 +- tianshou/algorithm/__init__.py | 2 +- tianshou/algorithm/{base.py => algorithm_base.py} | 0 tianshou/algorithm/imitation/base.py | 2 +- tianshou/algorithm/imitation/bcq.py | 2 +- tianshou/algorithm/imitation/cql.py | 2 +- tianshou/algorithm/imitation/discrete_bcq.py | 2 +- tianshou/algorithm/imitation/discrete_cql.py | 2 +- tianshou/algorithm/imitation/discrete_crr.py | 2 +- tianshou/algorithm/imitation/td3_bc.py | 2 +- tianshou/algorithm/modelbased/icm.py | 2 +- tianshou/algorithm/modelbased/psrl.py | 2 +- tianshou/algorithm/modelfree/a2c.py | 2 +- tianshou/algorithm/modelfree/bdqn.py | 2 +- tianshou/algorithm/modelfree/ddpg.py | 2 +- tianshou/algorithm/modelfree/discrete_sac.py | 2 +- tianshou/algorithm/modelfree/dqn.py | 2 +- tianshou/algorithm/modelfree/npg.py | 2 +- tianshou/algorithm/modelfree/pg.py | 2 +- tianshou/algorithm/modelfree/sac.py | 2 +- tianshou/algorithm/modelfree/td3.py | 2 +- tianshou/algorithm/multiagent/mapolicy.py | 2 +- tianshou/algorithm/random.py | 6 +++--- tianshou/data/collector.py | 2 +- tianshou/data/stats.py | 2 +- tianshou/highlevel/algorithm.py | 2 +- tianshou/highlevel/params/policy_wrapper.py | 2 +- tianshou/trainer/base.py | 2 +- tianshou/utils/torch_utils.py | 2 +- 91 files changed, 92 insertions(+), 92 deletions(-) rename tianshou/algorithm/{base.py => algorithm_base.py} (100%) diff --git a/examples/atari/atari_c51.py b/examples/atari/atari_c51.py index 763d7b107..37936f5aa 100644 --- a/examples/atari/atari_c51.py +++ b/examples/atari/atari_c51.py @@ -12,7 +12,7 @@ from tianshou.env.atari.atari_wrapper import make_atari_env from tianshou.highlevel.logger import LoggerFactoryDefault from tianshou.algorithm import C51 -from tianshou.algorithm.base import Algorithm +from tianshou.algorithm.algorithm_base import Algorithm from tianshou.algorithm.modelfree.c51 import C51Policy 
from tianshou.algorithm.optim import AdamOptimizerFactory from tianshou.trainer.base import OffPolicyTrainerParams diff --git a/examples/atari/atari_dqn.py b/examples/atari/atari_dqn.py index 5957e4ee5..8c57d076f 100644 --- a/examples/atari/atari_dqn.py +++ b/examples/atari/atari_dqn.py @@ -12,7 +12,7 @@ from tianshou.env.atari.atari_wrapper import make_atari_env from tianshou.highlevel.logger import LoggerFactoryDefault from tianshou.algorithm import DQN -from tianshou.algorithm.base import Algorithm +from tianshou.algorithm.algorithm_base import Algorithm from tianshou.algorithm.modelbased.icm import ICMOffPolicyWrapper from tianshou.algorithm.modelfree.dqn import DiscreteQLearningPolicy from tianshou.algorithm.optim import AdamOptimizerFactory diff --git a/examples/atari/atari_fqf.py b/examples/atari/atari_fqf.py index 327f7a51c..6cc0aa0ca 100644 --- a/examples/atari/atari_fqf.py +++ b/examples/atari/atari_fqf.py @@ -12,7 +12,7 @@ from tianshou.env.atari.atari_wrapper import make_atari_env from tianshou.highlevel.logger import LoggerFactoryDefault from tianshou.algorithm import FQF -from tianshou.algorithm.base import Algorithm +from tianshou.algorithm.algorithm_base import Algorithm from tianshou.algorithm.modelfree.fqf import FQFPolicy from tianshou.algorithm.optim import AdamOptimizerFactory, RMSpropOptimizerFactory from tianshou.trainer import OffPolicyTrainerParams diff --git a/examples/atari/atari_iqn.py b/examples/atari/atari_iqn.py index f5486ccd9..056da4288 100644 --- a/examples/atari/atari_iqn.py +++ b/examples/atari/atari_iqn.py @@ -12,7 +12,7 @@ from tianshou.env.atari.atari_wrapper import make_atari_env from tianshou.highlevel.logger import LoggerFactoryDefault from tianshou.algorithm import IQN -from tianshou.algorithm.base import Algorithm +from tianshou.algorithm.algorithm_base import Algorithm from tianshou.algorithm.modelfree.iqn import IQNPolicy from tianshou.algorithm.optim import AdamOptimizerFactory from tianshou.trainer import 
OffPolicyTrainerParams diff --git a/examples/atari/atari_ppo.py b/examples/atari/atari_ppo.py index 213ad31b2..bb76758e4 100644 --- a/examples/atari/atari_ppo.py +++ b/examples/atari/atari_ppo.py @@ -19,7 +19,7 @@ from tianshou.env.atari.atari_wrapper import make_atari_env from tianshou.highlevel.logger import LoggerFactoryDefault from tianshou.algorithm import PPO -from tianshou.algorithm.base import Algorithm +from tianshou.algorithm.algorithm_base import Algorithm from tianshou.algorithm.modelbased.icm import ICMOnPolicyWrapper from tianshou.algorithm.modelfree.pg import DiscreteActorPolicy from tianshou.algorithm.optim import AdamOptimizerFactory, LRSchedulerFactoryLinear diff --git a/examples/atari/atari_qrdqn.py b/examples/atari/atari_qrdqn.py index 049fe42c5..0d0a4095d 100644 --- a/examples/atari/atari_qrdqn.py +++ b/examples/atari/atari_qrdqn.py @@ -12,7 +12,7 @@ from tianshou.env.atari.atari_wrapper import make_atari_env from tianshou.highlevel.logger import LoggerFactoryDefault from tianshou.algorithm import QRDQN -from tianshou.algorithm.base import Algorithm +from tianshou.algorithm.algorithm_base import Algorithm from tianshou.algorithm.modelfree.qrdqn import QRDQNPolicy from tianshou.algorithm.optim import AdamOptimizerFactory from tianshou.trainer import OffPolicyTrainerParams diff --git a/examples/atari/atari_rainbow.py b/examples/atari/atari_rainbow.py index 5692d81ea..ac1a1de04 100644 --- a/examples/atari/atari_rainbow.py +++ b/examples/atari/atari_rainbow.py @@ -17,7 +17,7 @@ from tianshou.env.atari.atari_wrapper import make_atari_env from tianshou.highlevel.logger import LoggerFactoryDefault from tianshou.algorithm import C51, RainbowDQN -from tianshou.algorithm.base import Algorithm +from tianshou.algorithm.algorithm_base import Algorithm from tianshou.algorithm.modelfree.c51 import C51Policy from tianshou.algorithm.optim import AdamOptimizerFactory from tianshou.trainer import OffPolicyTrainerParams diff --git a/examples/atari/atari_sac.py 
b/examples/atari/atari_sac.py index 249e31c02..c88dd27ee 100644 --- a/examples/atari/atari_sac.py +++ b/examples/atari/atari_sac.py @@ -12,7 +12,7 @@ from tianshou.env.atari.atari_wrapper import make_atari_env from tianshou.highlevel.logger import LoggerFactoryDefault from tianshou.algorithm import DiscreteSAC, ICMOffPolicyWrapper -from tianshou.algorithm.base import Algorithm +from tianshou.algorithm.algorithm_base import Algorithm from tianshou.algorithm.modelfree.discrete_sac import DiscreteSACPolicy from tianshou.algorithm.modelfree.sac import AutoAlpha from tianshou.algorithm.optim import AdamOptimizerFactory diff --git a/examples/box2d/acrobot_dualdqn.py b/examples/box2d/acrobot_dualdqn.py index f70d01735..de5c5c739 100644 --- a/examples/box2d/acrobot_dualdqn.py +++ b/examples/box2d/acrobot_dualdqn.py @@ -10,7 +10,7 @@ from tianshou.data import Collector, CollectStats, VectorReplayBuffer from tianshou.env import DummyVectorEnv from tianshou.algorithm import DQN -from tianshou.algorithm.base import Algorithm +from tianshou.algorithm.algorithm_base import Algorithm from tianshou.algorithm.modelfree.dqn import DiscreteQLearningPolicy from tianshou.algorithm.optim import AdamOptimizerFactory from tianshou.trainer import OffPolicyTrainerParams diff --git a/examples/box2d/bipedal_bdq.py b/examples/box2d/bipedal_bdq.py index e9230bc3e..31f7af4d4 100644 --- a/examples/box2d/bipedal_bdq.py +++ b/examples/box2d/bipedal_bdq.py @@ -11,7 +11,7 @@ from tianshou.data import Collector, CollectStats, VectorReplayBuffer from tianshou.env import ContinuousToDiscrete, SubprocVectorEnv from tianshou.algorithm import BDQN -from tianshou.algorithm.base import Algorithm +from tianshou.algorithm.algorithm_base import Algorithm from tianshou.algorithm.modelfree.bdqn import BDQNPolicy from tianshou.algorithm.optim import AdamOptimizerFactory from tianshou.trainer import OffPolicyTrainerParams diff --git a/examples/box2d/bipedal_hardcore_sac.py b/examples/box2d/bipedal_hardcore_sac.py 
index 213f1b08f..4729490ff 100644 --- a/examples/box2d/bipedal_hardcore_sac.py +++ b/examples/box2d/bipedal_hardcore_sac.py @@ -12,7 +12,7 @@ from tianshou.data import Collector, CollectStats, VectorReplayBuffer from tianshou.env import SubprocVectorEnv from tianshou.algorithm import SAC -from tianshou.algorithm.base import Algorithm +from tianshou.algorithm.algorithm_base import Algorithm from tianshou.algorithm.modelfree.sac import AutoAlpha, SACPolicy from tianshou.algorithm.optim import AdamOptimizerFactory from tianshou.trainer import OffPolicyTrainerParams diff --git a/examples/box2d/lunarlander_dqn.py b/examples/box2d/lunarlander_dqn.py index bbfded999..e2e4ca628 100644 --- a/examples/box2d/lunarlander_dqn.py +++ b/examples/box2d/lunarlander_dqn.py @@ -10,7 +10,7 @@ from tianshou.data import Collector, CollectStats, VectorReplayBuffer from tianshou.env import DummyVectorEnv, SubprocVectorEnv from tianshou.algorithm import DQN -from tianshou.algorithm.base import Algorithm +from tianshou.algorithm.algorithm_base import Algorithm from tianshou.algorithm.modelfree.dqn import DiscreteQLearningPolicy from tianshou.algorithm.optim import AdamOptimizerFactory from tianshou.trainer import OffPolicyTrainerParams diff --git a/examples/box2d/mcc_sac.py b/examples/box2d/mcc_sac.py index 5f242bc79..6aa140f9e 100644 --- a/examples/box2d/mcc_sac.py +++ b/examples/box2d/mcc_sac.py @@ -11,7 +11,7 @@ from tianshou.env import DummyVectorEnv from tianshou.exploration import OUNoise from tianshou.algorithm import SAC -from tianshou.algorithm.base import Algorithm +from tianshou.algorithm.algorithm_base import Algorithm from tianshou.algorithm.modelfree.sac import AutoAlpha, SACPolicy from tianshou.algorithm.optim import AdamOptimizerFactory from tianshou.trainer import OffPolicyTrainerParams diff --git a/examples/inverse/irl_gail.py b/examples/inverse/irl_gail.py index 9049a691b..f1d2f6fb6 100644 --- a/examples/inverse/irl_gail.py +++ b/examples/inverse/irl_gail.py @@ -24,7 
+24,7 @@ from tianshou.data.types import RolloutBatchProtocol from tianshou.env import SubprocVectorEnv, VectorEnvNormObs from tianshou.algorithm import GAIL -from tianshou.algorithm.base import Algorithm +from tianshou.algorithm.algorithm_base import Algorithm from tianshou.algorithm.modelfree.pg import ActorPolicyProbabilistic from tianshou.algorithm.optim import AdamOptimizerFactory, LRSchedulerFactoryLinear from tianshou.trainer import OnPolicyTrainerParams diff --git a/examples/mujoco/fetch_her_ddpg.py b/examples/mujoco/fetch_her_ddpg.py index acd652662..8517e3153 100644 --- a/examples/mujoco/fetch_her_ddpg.py +++ b/examples/mujoco/fetch_her_ddpg.py @@ -23,7 +23,7 @@ from tianshou.exploration import GaussianNoise from tianshou.highlevel.logger import LoggerFactoryDefault from tianshou.algorithm import DDPG -from tianshou.algorithm.base import Algorithm +from tianshou.algorithm.algorithm_base import Algorithm from tianshou.algorithm.modelfree.ddpg import ContinuousDeterministicPolicy from tianshou.algorithm.optim import AdamOptimizerFactory from tianshou.trainer import OffPolicyTrainerParams diff --git a/examples/mujoco/mujoco_a2c.py b/examples/mujoco/mujoco_a2c.py index f4acfb1d9..717a194a6 100755 --- a/examples/mujoco/mujoco_a2c.py +++ b/examples/mujoco/mujoco_a2c.py @@ -14,7 +14,7 @@ from tianshou.data import Collector, CollectStats, ReplayBuffer, VectorReplayBuffer from tianshou.highlevel.logger import LoggerFactoryDefault from tianshou.algorithm import A2C -from tianshou.algorithm.base import Algorithm +from tianshou.algorithm.algorithm_base import Algorithm from tianshou.algorithm.modelfree.pg import ActorPolicyProbabilistic from tianshou.algorithm.optim import LRSchedulerFactoryLinear, RMSpropOptimizerFactory from tianshou.trainer import OnPolicyTrainerParams diff --git a/examples/mujoco/mujoco_ddpg.py b/examples/mujoco/mujoco_ddpg.py index 471616174..b9849dca0 100755 --- a/examples/mujoco/mujoco_ddpg.py +++ b/examples/mujoco/mujoco_ddpg.py @@ -13,7 
+13,7 @@ from tianshou.exploration import GaussianNoise from tianshou.highlevel.logger import LoggerFactoryDefault from tianshou.algorithm import DDPG -from tianshou.algorithm.base import Algorithm +from tianshou.algorithm.algorithm_base import Algorithm from tianshou.algorithm.modelfree.ddpg import ContinuousDeterministicPolicy from tianshou.algorithm.optim import AdamOptimizerFactory from tianshou.trainer import OffPolicyTrainerParams diff --git a/examples/mujoco/mujoco_npg.py b/examples/mujoco/mujoco_npg.py index a64891a3f..10116ac8a 100755 --- a/examples/mujoco/mujoco_npg.py +++ b/examples/mujoco/mujoco_npg.py @@ -14,7 +14,7 @@ from tianshou.data import Collector, CollectStats, ReplayBuffer, VectorReplayBuffer from tianshou.highlevel.logger import LoggerFactoryDefault from tianshou.algorithm import NPG -from tianshou.algorithm.base import Algorithm +from tianshou.algorithm.algorithm_base import Algorithm from tianshou.algorithm.modelfree.pg import ActorPolicyProbabilistic from tianshou.algorithm.optim import AdamOptimizerFactory, LRSchedulerFactoryLinear from tianshou.trainer import OnPolicyTrainerParams diff --git a/examples/mujoco/mujoco_ppo.py b/examples/mujoco/mujoco_ppo.py index 57bb14011..d504c8b41 100755 --- a/examples/mujoco/mujoco_ppo.py +++ b/examples/mujoco/mujoco_ppo.py @@ -14,7 +14,7 @@ from tianshou.data import Collector, CollectStats, ReplayBuffer, VectorReplayBuffer from tianshou.highlevel.logger import LoggerFactoryDefault from tianshou.algorithm import PPO -from tianshou.algorithm.base import Algorithm +from tianshou.algorithm.algorithm_base import Algorithm from tianshou.algorithm.modelfree.pg import ActorPolicyProbabilistic from tianshou.algorithm.optim import AdamOptimizerFactory, LRSchedulerFactoryLinear from tianshou.trainer import OnPolicyTrainerParams diff --git a/examples/mujoco/mujoco_redq.py b/examples/mujoco/mujoco_redq.py index 88337eeb4..d8e313385 100755 --- a/examples/mujoco/mujoco_redq.py +++ b/examples/mujoco/mujoco_redq.py @@ 
-12,7 +12,7 @@ from tianshou.data import Collector, CollectStats, ReplayBuffer, VectorReplayBuffer from tianshou.highlevel.logger import LoggerFactoryDefault from tianshou.algorithm import REDQ -from tianshou.algorithm.base import Algorithm +from tianshou.algorithm.algorithm_base import Algorithm from tianshou.algorithm.modelfree.redq import REDQPolicy from tianshou.algorithm.modelfree.sac import AutoAlpha from tianshou.algorithm.optim import AdamOptimizerFactory diff --git a/examples/mujoco/mujoco_reinforce.py b/examples/mujoco/mujoco_reinforce.py index 0718a6002..7c7c92f34 100755 --- a/examples/mujoco/mujoco_reinforce.py +++ b/examples/mujoco/mujoco_reinforce.py @@ -14,7 +14,7 @@ from tianshou.data import Collector, CollectStats, ReplayBuffer, VectorReplayBuffer from tianshou.highlevel.logger import LoggerFactoryDefault from tianshou.algorithm import Reinforce -from tianshou.algorithm.base import Algorithm +from tianshou.algorithm.algorithm_base import Algorithm from tianshou.algorithm.modelfree.pg import ActorPolicyProbabilistic from tianshou.algorithm.optim import AdamOptimizerFactory, LRSchedulerFactoryLinear from tianshou.trainer import OnPolicyTrainerParams diff --git a/examples/mujoco/mujoco_sac.py b/examples/mujoco/mujoco_sac.py index d301cc3e3..67547e0a0 100755 --- a/examples/mujoco/mujoco_sac.py +++ b/examples/mujoco/mujoco_sac.py @@ -12,7 +12,7 @@ from tianshou.data import Collector, CollectStats, ReplayBuffer, VectorReplayBuffer from tianshou.highlevel.logger import LoggerFactoryDefault from tianshou.algorithm import SAC -from tianshou.algorithm.base import Algorithm +from tianshou.algorithm.algorithm_base import Algorithm from tianshou.algorithm.modelfree.sac import AutoAlpha, SACPolicy from tianshou.algorithm.optim import AdamOptimizerFactory from tianshou.trainer import OffPolicyTrainerParams diff --git a/examples/mujoco/mujoco_td3.py b/examples/mujoco/mujoco_td3.py index 379bcc34d..8367b0a07 100755 --- a/examples/mujoco/mujoco_td3.py +++ 
b/examples/mujoco/mujoco_td3.py @@ -13,7 +13,7 @@ from tianshou.exploration import GaussianNoise from tianshou.highlevel.logger import LoggerFactoryDefault from tianshou.algorithm import TD3 -from tianshou.algorithm.base import Algorithm +from tianshou.algorithm.algorithm_base import Algorithm from tianshou.algorithm.modelfree.ddpg import ContinuousDeterministicPolicy from tianshou.algorithm.optim import AdamOptimizerFactory from tianshou.trainer import OffPolicyTrainerParams diff --git a/examples/mujoco/mujoco_trpo.py b/examples/mujoco/mujoco_trpo.py index 63958a761..e5e6e40f7 100755 --- a/examples/mujoco/mujoco_trpo.py +++ b/examples/mujoco/mujoco_trpo.py @@ -14,7 +14,7 @@ from tianshou.data import Collector, CollectStats, ReplayBuffer, VectorReplayBuffer from tianshou.highlevel.logger import LoggerFactoryDefault from tianshou.algorithm import TRPO -from tianshou.algorithm.base import Algorithm +from tianshou.algorithm.algorithm_base import Algorithm from tianshou.algorithm.modelfree.pg import ActorPolicyProbabilistic from tianshou.algorithm.optim import AdamOptimizerFactory, LRSchedulerFactoryLinear from tianshou.trainer import OnPolicyTrainerParams diff --git a/examples/offline/atari_bcq.py b/examples/offline/atari_bcq.py index 0f71ca265..4b239ea45 100644 --- a/examples/offline/atari_bcq.py +++ b/examples/offline/atari_bcq.py @@ -17,7 +17,7 @@ from tianshou.env.atari.atari_wrapper import make_atari_env from tianshou.highlevel.logger import LoggerFactoryDefault from tianshou.algorithm import DiscreteBCQ -from tianshou.algorithm.base import Algorithm +from tianshou.algorithm.algorithm_base import Algorithm from tianshou.algorithm.imitation.discrete_bcq import DiscreteBCQPolicy from tianshou.algorithm.optim import AdamOptimizerFactory from tianshou.trainer import OfflineTrainerParams diff --git a/examples/offline/atari_cql.py b/examples/offline/atari_cql.py index 9cdaef5bc..591a44b80 100644 --- a/examples/offline/atari_cql.py +++ b/examples/offline/atari_cql.py @@ 
-18,7 +18,7 @@ from tianshou.env.atari.atari_wrapper import make_atari_env from tianshou.highlevel.logger import LoggerFactoryDefault from tianshou.algorithm import DiscreteCQL -from tianshou.algorithm.base import Algorithm +from tianshou.algorithm.algorithm_base import Algorithm from tianshou.algorithm.modelfree.qrdqn import QRDQNPolicy from tianshou.algorithm.optim import AdamOptimizerFactory from tianshou.trainer import OfflineTrainerParams diff --git a/examples/offline/atari_crr.py b/examples/offline/atari_crr.py index 2e303b632..e94492283 100644 --- a/examples/offline/atari_crr.py +++ b/examples/offline/atari_crr.py @@ -17,7 +17,7 @@ from tianshou.env.atari.atari_wrapper import make_atari_env from tianshou.highlevel.logger import LoggerFactoryDefault from tianshou.algorithm import DiscreteCRR -from tianshou.algorithm.base import Algorithm +from tianshou.algorithm.algorithm_base import Algorithm from tianshou.algorithm.modelfree.pg import DiscreteActorPolicy from tianshou.algorithm.optim import AdamOptimizerFactory from tianshou.trainer import OfflineTrainerParams diff --git a/examples/offline/atari_il.py b/examples/offline/atari_il.py index 4f0fd3326..3d08a5ed7 100644 --- a/examples/offline/atari_il.py +++ b/examples/offline/atari_il.py @@ -15,7 +15,7 @@ from tianshou.env.atari.atari_network import DQNet from tianshou.env.atari.atari_wrapper import make_atari_env from tianshou.highlevel.logger import LoggerFactoryDefault -from tianshou.algorithm.base import Algorithm +from tianshou.algorithm.algorithm_base import Algorithm from tianshou.algorithm.imitation.base import ImitationPolicy, OfflineImitationLearning from tianshou.algorithm.optim import AdamOptimizerFactory from tianshou.trainer import OfflineTrainerParams diff --git a/examples/offline/d4rl_bcq.py b/examples/offline/d4rl_bcq.py index 9b151eb74..91e51690e 100644 --- a/examples/offline/d4rl_bcq.py +++ b/examples/offline/d4rl_bcq.py @@ -14,7 +14,7 @@ from tianshou.data import Collector, CollectStats from 
tianshou.env import SubprocVectorEnv from tianshou.algorithm import BCQ -from tianshou.algorithm.base import Algorithm +from tianshou.algorithm.algorithm_base import Algorithm from tianshou.algorithm.imitation.bcq import BCQPolicy from tianshou.algorithm.optim import AdamOptimizerFactory from tianshou.trainer import OfflineTrainerParams diff --git a/examples/offline/d4rl_cql.py b/examples/offline/d4rl_cql.py index bbb57f5f7..04ba3ce0f 100644 --- a/examples/offline/d4rl_cql.py +++ b/examples/offline/d4rl_cql.py @@ -14,7 +14,7 @@ from tianshou.data import Collector, CollectStats from tianshou.env import SubprocVectorEnv from tianshou.algorithm import CQL -from tianshou.algorithm.base import Algorithm +from tianshou.algorithm.algorithm_base import Algorithm from tianshou.algorithm.modelfree.sac import AutoAlpha, SACPolicy from tianshou.algorithm.optim import AdamOptimizerFactory from tianshou.trainer import OfflineTrainerParams diff --git a/examples/offline/d4rl_il.py b/examples/offline/d4rl_il.py index b6ec5cf0f..dc741e0c3 100644 --- a/examples/offline/d4rl_il.py +++ b/examples/offline/d4rl_il.py @@ -13,7 +13,7 @@ from examples.offline.utils import load_buffer_d4rl from tianshou.data import Collector, CollectStats from tianshou.env import SubprocVectorEnv -from tianshou.algorithm.base import Algorithm +from tianshou.algorithm.algorithm_base import Algorithm from tianshou.algorithm.imitation.base import ImitationPolicy, OfflineImitationLearning from tianshou.algorithm.optim import AdamOptimizerFactory from tianshou.trainer import OfflineTrainerParams diff --git a/examples/offline/d4rl_td3_bc.py b/examples/offline/d4rl_td3_bc.py index 1d4ac1036..b75f3eedb 100644 --- a/examples/offline/d4rl_td3_bc.py +++ b/examples/offline/d4rl_td3_bc.py @@ -15,7 +15,7 @@ from tianshou.env import BaseVectorEnv, SubprocVectorEnv, VectorEnvNormObs from tianshou.exploration import GaussianNoise from tianshou.algorithm import TD3BC -from tianshou.algorithm.base import Algorithm +from 
tianshou.algorithm.algorithm_base import Algorithm from tianshou.algorithm.modelfree.ddpg import ContinuousDeterministicPolicy from tianshou.algorithm.optim import AdamOptimizerFactory from tianshou.trainer import OfflineTrainerParams diff --git a/examples/vizdoom/vizdoom_c51.py b/examples/vizdoom/vizdoom_c51.py index 019151be2..b6fd383f7 100644 --- a/examples/vizdoom/vizdoom_c51.py +++ b/examples/vizdoom/vizdoom_c51.py @@ -12,7 +12,7 @@ from tianshou.env.atari.atari_network import C51Net from tianshou.highlevel.logger import LoggerFactoryDefault from tianshou.algorithm import C51 -from tianshou.algorithm.base import Algorithm +from tianshou.algorithm.algorithm_base import Algorithm from tianshou.algorithm.modelfree.c51 import C51Policy from tianshou.algorithm.optim import AdamOptimizerFactory from tianshou.trainer import OffPolicyTrainerParams diff --git a/examples/vizdoom/vizdoom_ppo.py b/examples/vizdoom/vizdoom_ppo.py index 6ce17c17c..e23af08f3 100644 --- a/examples/vizdoom/vizdoom_ppo.py +++ b/examples/vizdoom/vizdoom_ppo.py @@ -13,7 +13,7 @@ from tianshou.env.atari.atari_network import DQNet from tianshou.highlevel.logger import LoggerFactoryDefault from tianshou.algorithm import PPO -from tianshou.algorithm.base import Algorithm +from tianshou.algorithm.algorithm_base import Algorithm from tianshou.algorithm.modelbased.icm import ICMOnPolicyWrapper from tianshou.algorithm.modelfree.pg import ActorPolicyProbabilistic from tianshou.algorithm.optim import AdamOptimizerFactory, LRSchedulerFactoryLinear diff --git a/test/base/test_collector.py b/test/base/test_collector.py index 5636e73de..e3a42f3c9 100644 --- a/test/base/test_collector.py +++ b/test/base/test_collector.py @@ -25,7 +25,7 @@ ) from tianshou.data.types import ObsBatchProtocol, RolloutBatchProtocol from tianshou.env import DummyVectorEnv, SubprocVectorEnv -from tianshou.algorithm.base import Policy, episode_mc_return_to_go +from tianshou.algorithm.algorithm_base import Policy, 
episode_mc_return_to_go try: import envpool diff --git a/test/base/test_env_finite.py b/test/base/test_env_finite.py index a6daf7010..3d7a02e31 100644 --- a/test/base/test_env_finite.py +++ b/test/base/test_env_finite.py @@ -19,7 +19,7 @@ ) from tianshou.env import BaseVectorEnv, DummyVectorEnv, SubprocVectorEnv from tianshou.env.utils import ENV_TYPE, gym_new_venv_step_type -from tianshou.algorithm.base import Policy +from tianshou.algorithm.algorithm_base import Policy class DummyDataset(Dataset): diff --git a/test/base/test_policy.py b/test/base/test_policy.py index ed9b3a79c..f0048d931 100644 --- a/test/base/test_policy.py +++ b/test/base/test_policy.py @@ -6,7 +6,7 @@ from tianshou.data import Batch from tianshou.algorithm import PPO -from tianshou.algorithm.base import RandomActionPolicy, episode_mc_return_to_go +from tianshou.algorithm.algorithm_base import RandomActionPolicy, episode_mc_return_to_go from tianshou.algorithm.modelfree.pg import ActorPolicyProbabilistic from tianshou.algorithm.optim import AdamOptimizerFactory from tianshou.utils.net.common import Net diff --git a/test/base/test_stats.py b/test/base/test_stats.py index 0efa284a7..59ae349da 100644 --- a/test/base/test_stats.py +++ b/test/base/test_stats.py @@ -7,7 +7,7 @@ from tianshou.data import Batch, CollectStats from tianshou.data.collector import CollectStepBatchProtocol, get_stddev_from_dist -from tianshou.algorithm.base import TrainingStats, TrainingStatsWrapper +from tianshou.algorithm.algorithm_base import TrainingStats, TrainingStatsWrapper class DummyTrainingStatsWrapper(TrainingStatsWrapper): diff --git a/test/continuous/test_ddpg.py b/test/continuous/test_ddpg.py index 108480095..36eecc702 100644 --- a/test/continuous/test_ddpg.py +++ b/test/continuous/test_ddpg.py @@ -11,7 +11,7 @@ from tianshou.env import DummyVectorEnv from tianshou.exploration import GaussianNoise from tianshou.algorithm import DDPG -from tianshou.algorithm.base import Algorithm +from 
tianshou.algorithm.algorithm_base import Algorithm from tianshou.algorithm.modelfree.ddpg import ContinuousDeterministicPolicy from tianshou.algorithm.optim import AdamOptimizerFactory from tianshou.trainer.base import OffPolicyTrainerParams diff --git a/test/continuous/test_npg.py b/test/continuous/test_npg.py index f13e05c1a..e157be305 100644 --- a/test/continuous/test_npg.py +++ b/test/continuous/test_npg.py @@ -12,7 +12,7 @@ from tianshou.data import Collector, CollectStats, VectorReplayBuffer from tianshou.env import DummyVectorEnv from tianshou.algorithm import NPG -from tianshou.algorithm.base import Algorithm +from tianshou.algorithm.algorithm_base import Algorithm from tianshou.algorithm.modelfree.pg import ActorPolicyProbabilistic from tianshou.algorithm.optim import AdamOptimizerFactory from tianshou.trainer.base import OnPolicyTrainerParams diff --git a/test/continuous/test_ppo.py b/test/continuous/test_ppo.py index ed1b91292..6fd9a48cd 100644 --- a/test/continuous/test_ppo.py +++ b/test/continuous/test_ppo.py @@ -11,7 +11,7 @@ from tianshou.data import Collector, CollectStats, VectorReplayBuffer from tianshou.env import DummyVectorEnv from tianshou.algorithm import PPO -from tianshou.algorithm.base import Algorithm +from tianshou.algorithm.algorithm_base import Algorithm from tianshou.algorithm.modelfree.pg import ActorPolicyProbabilistic from tianshou.algorithm.optim import AdamOptimizerFactory from tianshou.trainer.base import OnPolicyTrainerParams diff --git a/test/continuous/test_redq.py b/test/continuous/test_redq.py index 0c9fc3c65..da8f5bf36 100644 --- a/test/continuous/test_redq.py +++ b/test/continuous/test_redq.py @@ -11,7 +11,7 @@ from tianshou.data import Collector, CollectStats, VectorReplayBuffer from tianshou.env import DummyVectorEnv from tianshou.algorithm import REDQ -from tianshou.algorithm.base import Algorithm +from tianshou.algorithm.algorithm_base import Algorithm from tianshou.algorithm.modelfree.redq import REDQPolicy from 
tianshou.algorithm.modelfree.sac import AutoAlpha from tianshou.algorithm.optim import AdamOptimizerFactory diff --git a/test/continuous/test_sac_with_il.py b/test/continuous/test_sac_with_il.py index ab0928491..955bb05bb 100644 --- a/test/continuous/test_sac_with_il.py +++ b/test/continuous/test_sac_with_il.py @@ -10,7 +10,7 @@ from tianshou.data import Collector, CollectStats, VectorReplayBuffer from tianshou.env import DummyVectorEnv from tianshou.algorithm import SAC, OffPolicyImitationLearning -from tianshou.algorithm.base import Algorithm +from tianshou.algorithm.algorithm_base import Algorithm from tianshou.algorithm.imitation.base import ImitationPolicy from tianshou.algorithm.modelfree.sac import AutoAlpha, SACPolicy from tianshou.algorithm.optim import AdamOptimizerFactory diff --git a/test/continuous/test_td3.py b/test/continuous/test_td3.py index 8e4e4e021..524e17513 100644 --- a/test/continuous/test_td3.py +++ b/test/continuous/test_td3.py @@ -11,7 +11,7 @@ from tianshou.env import DummyVectorEnv from tianshou.exploration import GaussianNoise from tianshou.algorithm import TD3 -from tianshou.algorithm.base import Algorithm +from tianshou.algorithm.algorithm_base import Algorithm from tianshou.algorithm.modelfree.ddpg import ContinuousDeterministicPolicy from tianshou.algorithm.optim import AdamOptimizerFactory from tianshou.trainer.base import OffPolicyTrainerParams diff --git a/test/continuous/test_trpo.py b/test/continuous/test_trpo.py index 6ec5c3986..918eeeef0 100644 --- a/test/continuous/test_trpo.py +++ b/test/continuous/test_trpo.py @@ -12,7 +12,7 @@ from tianshou.data import Collector, CollectStats, VectorReplayBuffer from tianshou.env import DummyVectorEnv from tianshou.algorithm import TRPO -from tianshou.algorithm.base import Algorithm +from tianshou.algorithm.algorithm_base import Algorithm from tianshou.algorithm.modelfree.pg import ActorPolicyProbabilistic from tianshou.algorithm.optim import AdamOptimizerFactory from 
tianshou.trainer.base import OnPolicyTrainerParams diff --git a/test/discrete/test_a2c_with_il.py b/test/discrete/test_a2c_with_il.py index 14e626469..bd9472845 100644 --- a/test/discrete/test_a2c_with_il.py +++ b/test/discrete/test_a2c_with_il.py @@ -11,7 +11,7 @@ from tianshou.data import Collector, CollectStats, VectorReplayBuffer from tianshou.env import DummyVectorEnv from tianshou.algorithm import A2C, OffPolicyImitationLearning -from tianshou.algorithm.base import Algorithm +from tianshou.algorithm.algorithm_base import Algorithm from tianshou.algorithm.imitation.base import ImitationPolicy from tianshou.algorithm.modelfree.pg import ActorPolicyProbabilistic from tianshou.algorithm.optim import AdamOptimizerFactory diff --git a/test/discrete/test_c51.py b/test/discrete/test_c51.py index c88e675a5..6d7e22599 100644 --- a/test/discrete/test_c51.py +++ b/test/discrete/test_c51.py @@ -17,7 +17,7 @@ ) from tianshou.env import DummyVectorEnv from tianshou.algorithm import C51 -from tianshou.algorithm.base import Algorithm +from tianshou.algorithm.algorithm_base import Algorithm from tianshou.algorithm.modelfree.c51 import C51Policy from tianshou.algorithm.optim import AdamOptimizerFactory from tianshou.trainer.base import OffPolicyTrainerParams diff --git a/test/discrete/test_discrete_sac.py b/test/discrete/test_discrete_sac.py index 6a1b76fb1..9acf43612 100644 --- a/test/discrete/test_discrete_sac.py +++ b/test/discrete/test_discrete_sac.py @@ -10,7 +10,7 @@ from tianshou.data import Collector, CollectStats, VectorReplayBuffer from tianshou.env import DummyVectorEnv from tianshou.algorithm import DiscreteSAC -from tianshou.algorithm.base import Algorithm +from tianshou.algorithm.algorithm_base import Algorithm from tianshou.algorithm.modelfree.discrete_sac import ( DiscreteSACPolicy, ) diff --git a/test/discrete/test_dqn.py b/test/discrete/test_dqn.py index f848a4bd1..6ba1347e7 100644 --- a/test/discrete/test_dqn.py +++ b/test/discrete/test_dqn.py @@ -16,7 +16,7 
@@ ) from tianshou.env import DummyVectorEnv from tianshou.algorithm import DQN -from tianshou.algorithm.base import Algorithm +from tianshou.algorithm.algorithm_base import Algorithm from tianshou.algorithm.modelfree.dqn import DiscreteQLearningPolicy from tianshou.algorithm.optim import AdamOptimizerFactory from tianshou.trainer.base import OffPolicyTrainerParams diff --git a/test/discrete/test_drqn.py b/test/discrete/test_drqn.py index bb779361e..912f6f7fd 100644 --- a/test/discrete/test_drqn.py +++ b/test/discrete/test_drqn.py @@ -10,7 +10,7 @@ from tianshou.data import Collector, CollectStats, VectorReplayBuffer from tianshou.env import DummyVectorEnv from tianshou.algorithm import DQN -from tianshou.algorithm.base import Algorithm +from tianshou.algorithm.algorithm_base import Algorithm from tianshou.algorithm.modelfree.dqn import DiscreteQLearningPolicy from tianshou.algorithm.optim import AdamOptimizerFactory from tianshou.trainer import OffPolicyTrainerParams diff --git a/test/discrete/test_fqf.py b/test/discrete/test_fqf.py index 8b32c0aed..2ccad1927 100644 --- a/test/discrete/test_fqf.py +++ b/test/discrete/test_fqf.py @@ -16,7 +16,7 @@ ) from tianshou.env import DummyVectorEnv from tianshou.algorithm import FQF -from tianshou.algorithm.base import Algorithm +from tianshou.algorithm.algorithm_base import Algorithm from tianshou.algorithm.modelfree.fqf import FQFPolicy from tianshou.algorithm.optim import AdamOptimizerFactory, RMSpropOptimizerFactory from tianshou.trainer.base import OffPolicyTrainerParams diff --git a/test/discrete/test_iqn.py b/test/discrete/test_iqn.py index 30c5bee2a..091d9e4a1 100644 --- a/test/discrete/test_iqn.py +++ b/test/discrete/test_iqn.py @@ -16,7 +16,7 @@ ) from tianshou.env import DummyVectorEnv from tianshou.algorithm import IQN -from tianshou.algorithm.base import Algorithm +from tianshou.algorithm.algorithm_base import Algorithm from tianshou.algorithm.modelfree.iqn import IQNPolicy from tianshou.algorithm.optim import 
AdamOptimizerFactory from tianshou.trainer.base import OffPolicyTrainerParams diff --git a/test/discrete/test_pg.py b/test/discrete/test_pg.py index 8197727b3..497befbbc 100644 --- a/test/discrete/test_pg.py +++ b/test/discrete/test_pg.py @@ -11,7 +11,7 @@ from tianshou.data import Collector, CollectStats, VectorReplayBuffer from tianshou.env import DummyVectorEnv from tianshou.algorithm import Reinforce -from tianshou.algorithm.base import Algorithm +from tianshou.algorithm.algorithm_base import Algorithm from tianshou.algorithm.modelfree.pg import ActorPolicyProbabilistic from tianshou.algorithm.optim import AdamOptimizerFactory from tianshou.trainer.base import OnPolicyTrainerParams diff --git a/test/discrete/test_ppo_discrete.py b/test/discrete/test_ppo_discrete.py index 6eead6b0e..21f10f378 100644 --- a/test/discrete/test_ppo_discrete.py +++ b/test/discrete/test_ppo_discrete.py @@ -11,7 +11,7 @@ from tianshou.data import Collector, CollectStats, VectorReplayBuffer from tianshou.env import DummyVectorEnv from tianshou.algorithm import PPO -from tianshou.algorithm.base import Algorithm +from tianshou.algorithm.algorithm_base import Algorithm from tianshou.algorithm.modelfree.pg import DiscreteActorPolicy from tianshou.algorithm.optim import AdamOptimizerFactory from tianshou.trainer import OnPolicyTrainerParams diff --git a/test/discrete/test_qrdqn.py b/test/discrete/test_qrdqn.py index 03cada4cc..e9463c314 100644 --- a/test/discrete/test_qrdqn.py +++ b/test/discrete/test_qrdqn.py @@ -15,7 +15,7 @@ ) from tianshou.env import DummyVectorEnv from tianshou.algorithm import QRDQN -from tianshou.algorithm.base import Algorithm +from tianshou.algorithm.algorithm_base import Algorithm from tianshou.algorithm.modelfree.qrdqn import QRDQNPolicy from tianshou.algorithm.optim import AdamOptimizerFactory from tianshou.trainer.base import OffPolicyTrainerParams diff --git a/test/discrete/test_rainbow.py b/test/discrete/test_rainbow.py index 4cae649b0..8d570b8b1 100644 --- 
a/test/discrete/test_rainbow.py +++ b/test/discrete/test_rainbow.py @@ -16,7 +16,7 @@ ) from tianshou.env import DummyVectorEnv from tianshou.algorithm import RainbowDQN -from tianshou.algorithm.base import Algorithm +from tianshou.algorithm.algorithm_base import Algorithm from tianshou.algorithm.modelfree.c51 import C51Policy from tianshou.algorithm.optim import AdamOptimizerFactory from tianshou.trainer.base import OffPolicyTrainerParams diff --git a/test/modelbased/test_ppo_icm.py b/test/modelbased/test_ppo_icm.py index 88c0718c2..8511f9b4d 100644 --- a/test/modelbased/test_ppo_icm.py +++ b/test/modelbased/test_ppo_icm.py @@ -10,7 +10,7 @@ from tianshou.data import Collector, CollectStats, VectorReplayBuffer from tianshou.env import DummyVectorEnv from tianshou.algorithm import PPO -from tianshou.algorithm.base import Algorithm +from tianshou.algorithm.algorithm_base import Algorithm from tianshou.algorithm.modelbased.icm import ICMOnPolicyWrapper from tianshou.algorithm.modelfree.pg import ActorPolicyProbabilistic from tianshou.algorithm.optim import AdamOptimizerFactory diff --git a/test/offline/gather_cartpole_data.py b/test/offline/gather_cartpole_data.py index dd77c783f..cba273ffb 100644 --- a/test/offline/gather_cartpole_data.py +++ b/test/offline/gather_cartpole_data.py @@ -15,7 +15,7 @@ ) from tianshou.env import DummyVectorEnv from tianshou.algorithm import QRDQN -from tianshou.algorithm.base import Algorithm +from tianshou.algorithm.algorithm_base import Algorithm from tianshou.algorithm.modelfree.qrdqn import QRDQNPolicy from tianshou.algorithm.optim import AdamOptimizerFactory from tianshou.trainer import OffPolicyTrainerParams diff --git a/test/offline/gather_pendulum_data.py b/test/offline/gather_pendulum_data.py index b7ea3d820..f1655fca8 100644 --- a/test/offline/gather_pendulum_data.py +++ b/test/offline/gather_pendulum_data.py @@ -10,7 +10,7 @@ from tianshou.data import Collector, CollectStats, VectorReplayBuffer from tianshou.env import 
DummyVectorEnv from tianshou.algorithm import SAC -from tianshou.algorithm.base import Algorithm +from tianshou.algorithm.algorithm_base import Algorithm from tianshou.algorithm.modelfree.sac import AutoAlpha, SACPolicy, SACTrainingStats from tianshou.algorithm.optim import AdamOptimizerFactory from tianshou.trainer.base import OffPolicyTrainerParams diff --git a/test/offline/test_td3_bc.py b/test/offline/test_td3_bc.py index 8a3eb2ee0..d0f6bcfc4 100644 --- a/test/offline/test_td3_bc.py +++ b/test/offline/test_td3_bc.py @@ -14,7 +14,7 @@ from tianshou.env import DummyVectorEnv from tianshou.exploration import GaussianNoise from tianshou.algorithm import TD3BC -from tianshou.algorithm.base import Algorithm +from tianshou.algorithm.algorithm_base import Algorithm from tianshou.algorithm.modelfree.ddpg import ContinuousDeterministicPolicy from tianshou.algorithm.optim import AdamOptimizerFactory from tianshou.trainer import OfflineTrainerParams diff --git a/test/pettingzoo/pistonball.py b/test/pettingzoo/pistonball.py index 9dfcccb41..84d4c920a 100644 --- a/test/pettingzoo/pistonball.py +++ b/test/pettingzoo/pistonball.py @@ -12,7 +12,7 @@ from tianshou.env import DummyVectorEnv from tianshou.env.pettingzoo_env import PettingZooEnv from tianshou.algorithm import DQN, Algorithm, MultiAgentOffPolicyAlgorithm -from tianshou.algorithm.base import OffPolicyAlgorithm +from tianshou.algorithm.algorithm_base import OffPolicyAlgorithm from tianshou.algorithm.modelfree.dqn import DiscreteQLearningPolicy from tianshou.algorithm.optim import AdamOptimizerFactory from tianshou.trainer import OffPolicyTrainerParams diff --git a/test/pettingzoo/pistonball_continuous.py b/test/pettingzoo/pistonball_continuous.py index 8736b8a9c..bfb09b38f 100644 --- a/test/pettingzoo/pistonball_continuous.py +++ b/test/pettingzoo/pistonball_continuous.py @@ -16,7 +16,7 @@ from tianshou.env import DummyVectorEnv from tianshou.env.pettingzoo_env import PettingZooEnv from tianshou.algorithm import PPO, 
Algorithm -from tianshou.algorithm.base import OnPolicyAlgorithm +from tianshou.algorithm.algorithm_base import OnPolicyAlgorithm from tianshou.algorithm.modelfree.pg import ActorPolicyProbabilistic from tianshou.algorithm.multiagent.mapolicy import MultiAgentOnPolicyAlgorithm from tianshou.algorithm.optim import AdamOptimizerFactory diff --git a/test/pettingzoo/tic_tac_toe.py b/test/pettingzoo/tic_tac_toe.py index 9b4de1ca3..706e27787 100644 --- a/test/pettingzoo/tic_tac_toe.py +++ b/test/pettingzoo/tic_tac_toe.py @@ -19,7 +19,7 @@ MARLRandomDiscreteMaskedOffPolicyAlgorithm, MultiAgentOffPolicyAlgorithm, ) -from tianshou.algorithm.base import OffPolicyAlgorithm +from tianshou.algorithm.algorithm_base import OffPolicyAlgorithm from tianshou.algorithm.modelfree.dqn import DiscreteQLearningPolicy from tianshou.algorithm.optim import AdamOptimizerFactory, OptimizerFactory from tianshou.trainer import OffPolicyTrainerParams diff --git a/tianshou/algorithm/__init__.py b/tianshou/algorithm/__init__.py index 9a5a7203a..44de57573 100644 --- a/tianshou/algorithm/__init__.py +++ b/tianshou/algorithm/__init__.py @@ -1,7 +1,7 @@ """Algorithm package.""" # isort:skip_file -from tianshou.algorithm.base import Algorithm, TrainingStats +from tianshou.algorithm.algorithm_base import Algorithm, TrainingStats from tianshou.algorithm.modelfree.pg import Reinforce from tianshou.algorithm.modelfree.dqn import DQN from tianshou.algorithm.modelfree.ddpg import DDPG diff --git a/tianshou/algorithm/base.py b/tianshou/algorithm/algorithm_base.py similarity index 100% rename from tianshou/algorithm/base.py rename to tianshou/algorithm/algorithm_base.py diff --git a/tianshou/algorithm/imitation/base.py b/tianshou/algorithm/imitation/base.py index 127cf625c..61fbd3a6b 100644 --- a/tianshou/algorithm/imitation/base.py +++ b/tianshou/algorithm/imitation/base.py @@ -14,7 +14,7 @@ RolloutBatchProtocol, ) from tianshou.algorithm import Algorithm -from tianshou.algorithm.base import ( +from 
tianshou.algorithm.algorithm_base import ( OfflineAlgorithm, OffPolicyAlgorithm, Policy, diff --git a/tianshou/algorithm/imitation/bcq.py b/tianshou/algorithm/imitation/bcq.py index 4cb6afca9..7c7129104 100644 --- a/tianshou/algorithm/imitation/bcq.py +++ b/tianshou/algorithm/imitation/bcq.py @@ -10,7 +10,7 @@ from tianshou.data import Batch, to_torch from tianshou.data.batch import BatchProtocol from tianshou.data.types import ActBatchProtocol, ObsBatchProtocol, RolloutBatchProtocol -from tianshou.algorithm.base import ( +from tianshou.algorithm.algorithm_base import ( LaggedNetworkPolyakUpdateAlgorithmMixin, OfflineAlgorithm, Policy, diff --git a/tianshou/algorithm/imitation/cql.py b/tianshou/algorithm/imitation/cql.py index a6fd81044..4d0f33ac3 100644 --- a/tianshou/algorithm/imitation/cql.py +++ b/tianshou/algorithm/imitation/cql.py @@ -10,7 +10,7 @@ from tianshou.data import Batch, ReplayBuffer, to_torch from tianshou.data.buffer.base import TBuffer from tianshou.data.types import RolloutBatchProtocol -from tianshou.algorithm.base import ( +from tianshou.algorithm.algorithm_base import ( LaggedNetworkPolyakUpdateAlgorithmMixin, OfflineAlgorithm, ) diff --git a/tianshou/algorithm/imitation/discrete_bcq.py b/tianshou/algorithm/imitation/discrete_bcq.py index d8dbb3044..6fabc0bb3 100644 --- a/tianshou/algorithm/imitation/discrete_bcq.py +++ b/tianshou/algorithm/imitation/discrete_bcq.py @@ -14,7 +14,7 @@ ObsBatchProtocol, RolloutBatchProtocol, ) -from tianshou.algorithm.base import ( +from tianshou.algorithm.algorithm_base import ( LaggedNetworkFullUpdateAlgorithmMixin, OfflineAlgorithm, ) diff --git a/tianshou/algorithm/imitation/discrete_cql.py b/tianshou/algorithm/imitation/discrete_cql.py index b8d1b8e58..3f2985b39 100644 --- a/tianshou/algorithm/imitation/discrete_cql.py +++ b/tianshou/algorithm/imitation/discrete_cql.py @@ -7,7 +7,7 @@ from tianshou.data import to_torch from tianshou.data.types import RolloutBatchProtocol from tianshou.algorithm import 
QRDQN -from tianshou.algorithm.base import OfflineAlgorithm +from tianshou.algorithm.algorithm_base import OfflineAlgorithm from tianshou.algorithm.modelfree.pg import SimpleLossTrainingStats from tianshou.algorithm.modelfree.qrdqn import QRDQNPolicy from tianshou.algorithm.optim import OptimizerFactory diff --git a/tianshou/algorithm/imitation/discrete_crr.py b/tianshou/algorithm/imitation/discrete_crr.py index ea7452144..69feb7c97 100644 --- a/tianshou/algorithm/imitation/discrete_crr.py +++ b/tianshou/algorithm/imitation/discrete_crr.py @@ -9,7 +9,7 @@ from tianshou.data import ReplayBuffer, to_torch, to_torch_as from tianshou.data.types import BatchWithReturnsProtocol, RolloutBatchProtocol -from tianshou.algorithm.base import ( +from tianshou.algorithm.algorithm_base import ( LaggedNetworkFullUpdateAlgorithmMixin, OfflineAlgorithm, ) diff --git a/tianshou/algorithm/imitation/td3_bc.py b/tianshou/algorithm/imitation/td3_bc.py index c237694b9..f091bba0a 100644 --- a/tianshou/algorithm/imitation/td3_bc.py +++ b/tianshou/algorithm/imitation/td3_bc.py @@ -4,7 +4,7 @@ from tianshou.data import to_torch_as from tianshou.data.types import RolloutBatchProtocol from tianshou.algorithm import TD3 -from tianshou.algorithm.base import OfflineAlgorithm +from tianshou.algorithm.algorithm_base import OfflineAlgorithm from tianshou.algorithm.modelfree.ddpg import ContinuousDeterministicPolicy from tianshou.algorithm.modelfree.td3 import TD3TrainingStats from tianshou.algorithm.optim import OptimizerFactory diff --git a/tianshou/algorithm/modelbased/icm.py b/tianshou/algorithm/modelbased/icm.py index 3f8313d87..f40e17116 100644 --- a/tianshou/algorithm/modelbased/icm.py +++ b/tianshou/algorithm/modelbased/icm.py @@ -6,7 +6,7 @@ from tianshou.data.batch import BatchProtocol from tianshou.data.types import RolloutBatchProtocol from tianshou.algorithm import Algorithm -from tianshou.algorithm.base import ( +from tianshou.algorithm.algorithm_base import ( OffPolicyAlgorithm, 
OffPolicyWrapperAlgorithm, OnPolicyAlgorithm, diff --git a/tianshou/algorithm/modelbased/psrl.py b/tianshou/algorithm/modelbased/psrl.py index d8603b051..c9c5296cf 100644 --- a/tianshou/algorithm/modelbased/psrl.py +++ b/tianshou/algorithm/modelbased/psrl.py @@ -8,7 +8,7 @@ from tianshou.data import Batch from tianshou.data.batch import BatchProtocol from tianshou.data.types import ActBatchProtocol, ObsBatchProtocol, RolloutBatchProtocol -from tianshou.algorithm.base import ( +from tianshou.algorithm.algorithm_base import ( OnPolicyAlgorithm, Policy, TrainingStats, diff --git a/tianshou/algorithm/modelfree/a2c.py b/tianshou/algorithm/modelfree/a2c.py index 43b25af50..0bff04489 100644 --- a/tianshou/algorithm/modelfree/a2c.py +++ b/tianshou/algorithm/modelfree/a2c.py @@ -8,7 +8,7 @@ from tianshou.data import ReplayBuffer, SequenceSummaryStats, to_torch_as from tianshou.data.types import BatchWithAdvantagesProtocol, RolloutBatchProtocol -from tianshou.algorithm.base import ( +from tianshou.algorithm.algorithm_base import ( OnPolicyAlgorithm, TrainingStats, ) diff --git a/tianshou/algorithm/modelfree/bdqn.py b/tianshou/algorithm/modelfree/bdqn.py index cd83b1df0..01707de58 100644 --- a/tianshou/algorithm/modelfree/bdqn.py +++ b/tianshou/algorithm/modelfree/bdqn.py @@ -14,7 +14,7 @@ ObsBatchProtocol, RolloutBatchProtocol, ) -from tianshou.algorithm.base import TArrOrActBatch +from tianshou.algorithm.algorithm_base import TArrOrActBatch from tianshou.algorithm.modelfree.dqn import ( DiscreteQLearningPolicy, QLearningOffPolicyAlgorithm, diff --git a/tianshou/algorithm/modelfree/ddpg.py b/tianshou/algorithm/modelfree/ddpg.py index 97865e2ab..62868a2f4 100644 --- a/tianshou/algorithm/modelfree/ddpg.py +++ b/tianshou/algorithm/modelfree/ddpg.py @@ -19,7 +19,7 @@ ) from tianshou.exploration import BaseNoise, GaussianNoise from tianshou.algorithm import Algorithm -from tianshou.algorithm.base import ( +from tianshou.algorithm.algorithm_base import ( 
LaggedNetworkPolyakUpdateAlgorithmMixin, OffPolicyAlgorithm, Policy, diff --git a/tianshou/algorithm/modelfree/discrete_sac.py b/tianshou/algorithm/modelfree/discrete_sac.py index a547d86d4..16d211754 100644 --- a/tianshou/algorithm/modelfree/discrete_sac.py +++ b/tianshou/algorithm/modelfree/discrete_sac.py @@ -13,7 +13,7 @@ ObsBatchProtocol, RolloutBatchProtocol, ) -from tianshou.algorithm.base import Policy +from tianshou.algorithm.algorithm_base import Policy from tianshou.algorithm.modelfree.sac import Alpha, SACTrainingStats from tianshou.algorithm.modelfree.td3 import ActorDualCriticsOffPolicyAlgorithm from tianshou.algorithm.optim import OptimizerFactory diff --git a/tianshou/algorithm/modelfree/dqn.py b/tianshou/algorithm/modelfree/dqn.py index d0099c0cb..8e2042c21 100644 --- a/tianshou/algorithm/modelfree/dqn.py +++ b/tianshou/algorithm/modelfree/dqn.py @@ -17,7 +17,7 @@ RolloutBatchProtocol, ) from tianshou.algorithm import Algorithm -from tianshou.algorithm.base import ( +from tianshou.algorithm.algorithm_base import ( LaggedNetworkFullUpdateAlgorithmMixin, OffPolicyAlgorithm, Policy, diff --git a/tianshou/algorithm/modelfree/npg.py b/tianshou/algorithm/modelfree/npg.py index 7d7949572..cc257466c 100644 --- a/tianshou/algorithm/modelfree/npg.py +++ b/tianshou/algorithm/modelfree/npg.py @@ -9,7 +9,7 @@ from tianshou.data import ReplayBuffer, SequenceSummaryStats, to_torch_as from tianshou.data.types import BatchWithAdvantagesProtocol, RolloutBatchProtocol -from tianshou.algorithm.base import TrainingStats +from tianshou.algorithm.algorithm_base import TrainingStats from tianshou.algorithm.modelfree.a2c import ActorCriticOnPolicyAlgorithm from tianshou.algorithm.modelfree.pg import ActorPolicyProbabilistic from tianshou.algorithm.optim import OptimizerFactory diff --git a/tianshou/algorithm/modelfree/pg.py b/tianshou/algorithm/modelfree/pg.py index 689e21d06..e82485b0c 100644 --- a/tianshou/algorithm/modelfree/pg.py +++ 
b/tianshou/algorithm/modelfree/pg.py @@ -23,7 +23,7 @@ RolloutBatchProtocol, ) from tianshou.algorithm import Algorithm -from tianshou.algorithm.base import ( +from tianshou.algorithm.algorithm_base import ( OnPolicyAlgorithm, Policy, TrainingStats, diff --git a/tianshou/algorithm/modelfree/sac.py b/tianshou/algorithm/modelfree/sac.py index fbfaeac84..62eaa824f 100644 --- a/tianshou/algorithm/modelfree/sac.py +++ b/tianshou/algorithm/modelfree/sac.py @@ -14,7 +14,7 @@ RolloutBatchProtocol, ) from tianshou.exploration import BaseNoise -from tianshou.algorithm.base import TrainingStats +from tianshou.algorithm.algorithm_base import TrainingStats from tianshou.algorithm.modelfree.ddpg import ContinuousPolicyWithExplorationNoise from tianshou.algorithm.modelfree.td3 import ActorDualCriticsOffPolicyAlgorithm from tianshou.algorithm.optim import OptimizerFactory diff --git a/tianshou/algorithm/modelfree/td3.py b/tianshou/algorithm/modelfree/td3.py index 37704bf14..faf7bf36b 100644 --- a/tianshou/algorithm/modelfree/td3.py +++ b/tianshou/algorithm/modelfree/td3.py @@ -10,7 +10,7 @@ ActStateBatchProtocol, RolloutBatchProtocol, ) -from tianshou.algorithm.base import ( +from tianshou.algorithm.algorithm_base import ( TPolicy, TrainingStats, ) diff --git a/tianshou/algorithm/multiagent/mapolicy.py b/tianshou/algorithm/multiagent/mapolicy.py index 825e0b331..29be11905 100644 --- a/tianshou/algorithm/multiagent/mapolicy.py +++ b/tianshou/algorithm/multiagent/mapolicy.py @@ -10,7 +10,7 @@ from tianshou.data.batch import BatchProtocol, IndexType from tianshou.data.types import ActBatchProtocol, ObsBatchProtocol, RolloutBatchProtocol from tianshou.algorithm import Algorithm -from tianshou.algorithm.base import ( +from tianshou.algorithm.algorithm_base import ( OffPolicyAlgorithm, OnPolicyAlgorithm, Policy, diff --git a/tianshou/algorithm/random.py b/tianshou/algorithm/random.py index fb862e657..92da3fad6 100644 --- a/tianshou/algorithm/random.py +++ b/tianshou/algorithm/random.py 
@@ -6,8 +6,8 @@ from tianshou.data import Batch from tianshou.data.batch import BatchProtocol from tianshou.data.types import ActBatchProtocol, ObsBatchProtocol, RolloutBatchProtocol -from tianshou.algorithm import base -from tianshou.algorithm.base import OffPolicyAlgorithm, TrainingStats +from tianshou.algorithm import algorithm_base +from tianshou.algorithm.algorithm_base import OffPolicyAlgorithm, TrainingStats, Policy class MARLRandomTrainingStats(TrainingStats): @@ -20,7 +20,7 @@ class MARLRandomDiscreteMaskedOffPolicyAlgorithm(OffPolicyAlgorithm): It randomly chooses an action from the legal actions (according to the given mask). """ - class Policy(base.Policy): + class Policy(Policy): """A random agent used in multi-agent learning. It randomly chooses an action from the legal actions. diff --git a/tianshou/data/collector.py b/tianshou/data/collector.py index c5826e5cf..37fea3d49 100644 --- a/tianshou/data/collector.py +++ b/tianshou/data/collector.py @@ -31,7 +31,7 @@ ) from tianshou.env import BaseVectorEnv, DummyVectorEnv from tianshou.algorithm import Algorithm -from tianshou.algorithm.base import Policy, episode_mc_return_to_go +from tianshou.algorithm.algorithm_base import Policy, episode_mc_return_to_go from tianshou.utils.determinism import TraceLogger from tianshou.utils.print import DataclassPPrintMixin from tianshou.utils.torch_utils import torch_train_mode diff --git a/tianshou/data/stats.py b/tianshou/data/stats.py index 828dffe64..c56e1c075 100644 --- a/tianshou/data/stats.py +++ b/tianshou/data/stats.py @@ -9,7 +9,7 @@ if TYPE_CHECKING: from tianshou.data import CollectStats, CollectStatsBase - from tianshou.algorithm.base import TrainingStats + from tianshou.algorithm.algorithm_base import TrainingStats log = logging.getLogger(__name__) diff --git a/tianshou/highlevel/algorithm.py b/tianshou/highlevel/algorithm.py index ad62af203..5acb56e92 100644 --- a/tianshou/highlevel/algorithm.py +++ b/tianshou/highlevel/algorithm.py @@ -61,7 +61,7 @@ 
DiscreteSAC, Reinforce, ) -from tianshou.algorithm.base import ( +from tianshou.algorithm.algorithm_base import ( OffPolicyAlgorithm, OnPolicyAlgorithm, Policy, diff --git a/tianshou/highlevel/params/policy_wrapper.py b/tianshou/highlevel/params/policy_wrapper.py index e85c7e37d..21457f10b 100644 --- a/tianshou/highlevel/params/policy_wrapper.py +++ b/tianshou/highlevel/params/policy_wrapper.py @@ -9,7 +9,7 @@ from tianshou.highlevel.module.intermediate import IntermediateModuleFactory from tianshou.highlevel.optim import OptimizerFactoryFactory from tianshou.algorithm import Algorithm, ICMOffPolicyWrapper -from tianshou.algorithm.base import OffPolicyAlgorithm, OnPolicyAlgorithm +from tianshou.algorithm.algorithm_base import OffPolicyAlgorithm, OnPolicyAlgorithm from tianshou.algorithm.modelbased.icm import ICMOnPolicyWrapper from tianshou.utils.net.discrete import IntrinsicCuriosityModule diff --git a/tianshou/trainer/base.py b/tianshou/trainer/base.py index 726bad1a1..40373617a 100644 --- a/tianshou/trainer/base.py +++ b/tianshou/trainer/base.py @@ -47,7 +47,7 @@ ) from tianshou.data.buffer.base import MalformedBufferError from tianshou.data.collector import BaseCollector, CollectStatsBase -from tianshou.algorithm.base import ( +from tianshou.algorithm.algorithm_base import ( Algorithm, OfflineAlgorithm, OffPolicyAlgorithm, diff --git a/tianshou/utils/torch_utils.py b/tianshou/utils/torch_utils.py index ebfa262db..f675502d3 100644 --- a/tianshou/utils/torch_utils.py +++ b/tianshou/utils/torch_utils.py @@ -8,7 +8,7 @@ from torch import nn if TYPE_CHECKING: - from tianshou.algorithm.base import Policy + from tianshou.algorithm.algorithm_base import Policy @contextmanager From de18c80686bafee6496b4e7ce4790212ec44b26c Mon Sep 17 00:00:00 2001 From: Dominik Jain Date: Thu, 15 May 2025 23:32:10 +0200 Subject: [PATCH 169/230] v2: Rename module trainer.base -> trainer.trainer --- examples/atari/atari_c51.py | 2 +- examples/discrete/discrete_dqn.py | 2 +- 
test/continuous/test_ddpg.py | 2 +- test/continuous/test_npg.py | 2 +- test/continuous/test_ppo.py | 2 +- test/continuous/test_redq.py | 2 +- test/continuous/test_sac_with_il.py | 2 +- test/continuous/test_td3.py | 2 +- test/continuous/test_trpo.py | 2 +- test/discrete/test_a2c_with_il.py | 4 ++-- test/discrete/test_bdqn.py | 2 +- test/discrete/test_c51.py | 2 +- test/discrete/test_discrete_sac.py | 2 +- test/discrete/test_dqn.py | 2 +- test/discrete/test_fqf.py | 2 +- test/discrete/test_iqn.py | 2 +- test/discrete/test_pg.py | 2 +- test/discrete/test_qrdqn.py | 2 +- test/discrete/test_rainbow.py | 2 +- test/offline/gather_pendulum_data.py | 2 +- test/offline/test_bcq.py | 2 +- tianshou/algorithm/algorithm_base.py | 14 ++++++++------ tianshou/highlevel/algorithm.py | 2 +- tianshou/trainer/__init__.py | 3 ++- tianshou/trainer/{base.py => trainer.py} | 0 25 files changed, 33 insertions(+), 30 deletions(-) rename tianshou/trainer/{base.py => trainer.py} (100%) diff --git a/examples/atari/atari_c51.py b/examples/atari/atari_c51.py index 37936f5aa..ade97544c 100644 --- a/examples/atari/atari_c51.py +++ b/examples/atari/atari_c51.py @@ -15,7 +15,7 @@ from tianshou.algorithm.algorithm_base import Algorithm from tianshou.algorithm.modelfree.c51 import C51Policy from tianshou.algorithm.optim import AdamOptimizerFactory -from tianshou.trainer.base import OffPolicyTrainerParams +from tianshou.trainer import OffPolicyTrainerParams def get_args() -> argparse.Namespace: diff --git a/examples/discrete/discrete_dqn.py b/examples/discrete/discrete_dqn.py index f98520143..39ffeb4ea 100644 --- a/examples/discrete/discrete_dqn.py +++ b/examples/discrete/discrete_dqn.py @@ -5,7 +5,7 @@ from tianshou.data import CollectStats from tianshou.algorithm.modelfree.dqn import DiscreteQLearningPolicy from tianshou.algorithm.optim import AdamOptimizerFactory -from tianshou.trainer.base import OffPolicyTrainerParams +from tianshou.trainer import OffPolicyTrainerParams from 
tianshou.utils.space_info import SpaceInfo diff --git a/test/continuous/test_ddpg.py b/test/continuous/test_ddpg.py index 36eecc702..aca34fa27 100644 --- a/test/continuous/test_ddpg.py +++ b/test/continuous/test_ddpg.py @@ -14,7 +14,7 @@ from tianshou.algorithm.algorithm_base import Algorithm from tianshou.algorithm.modelfree.ddpg import ContinuousDeterministicPolicy from tianshou.algorithm.optim import AdamOptimizerFactory -from tianshou.trainer.base import OffPolicyTrainerParams +from tianshou.trainer import OffPolicyTrainerParams from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Net from tianshou.utils.net.continuous import ContinuousActorDeterministic, ContinuousCritic diff --git a/test/continuous/test_npg.py b/test/continuous/test_npg.py index e157be305..638332701 100644 --- a/test/continuous/test_npg.py +++ b/test/continuous/test_npg.py @@ -15,7 +15,7 @@ from tianshou.algorithm.algorithm_base import Algorithm from tianshou.algorithm.modelfree.pg import ActorPolicyProbabilistic from tianshou.algorithm.optim import AdamOptimizerFactory -from tianshou.trainer.base import OnPolicyTrainerParams +from tianshou.trainer import OnPolicyTrainerParams from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Net from tianshou.utils.net.continuous import ContinuousActorProbabilistic, ContinuousCritic diff --git a/test/continuous/test_ppo.py b/test/continuous/test_ppo.py index 6fd9a48cd..c10a1fe8f 100644 --- a/test/continuous/test_ppo.py +++ b/test/continuous/test_ppo.py @@ -14,7 +14,7 @@ from tianshou.algorithm.algorithm_base import Algorithm from tianshou.algorithm.modelfree.pg import ActorPolicyProbabilistic from tianshou.algorithm.optim import AdamOptimizerFactory -from tianshou.trainer.base import OnPolicyTrainerParams +from tianshou.trainer import OnPolicyTrainerParams from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import ActorCritic, Net from tianshou.utils.net.continuous import 
ContinuousActorProbabilistic, ContinuousCritic diff --git a/test/continuous/test_redq.py b/test/continuous/test_redq.py index da8f5bf36..9769b946d 100644 --- a/test/continuous/test_redq.py +++ b/test/continuous/test_redq.py @@ -15,7 +15,7 @@ from tianshou.algorithm.modelfree.redq import REDQPolicy from tianshou.algorithm.modelfree.sac import AutoAlpha from tianshou.algorithm.optim import AdamOptimizerFactory -from tianshou.trainer.base import OffPolicyTrainerParams +from tianshou.trainer import OffPolicyTrainerParams from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import EnsembleLinear, Net from tianshou.utils.net.continuous import ContinuousActorProbabilistic, ContinuousCritic diff --git a/test/continuous/test_sac_with_il.py b/test/continuous/test_sac_with_il.py index 955bb05bb..566de2822 100644 --- a/test/continuous/test_sac_with_il.py +++ b/test/continuous/test_sac_with_il.py @@ -14,7 +14,7 @@ from tianshou.algorithm.imitation.base import ImitationPolicy from tianshou.algorithm.modelfree.sac import AutoAlpha, SACPolicy from tianshou.algorithm.optim import AdamOptimizerFactory -from tianshou.trainer.base import OffPolicyTrainerParams +from tianshou.trainer import OffPolicyTrainerParams from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Net from tianshou.utils.net.continuous import ( diff --git a/test/continuous/test_td3.py b/test/continuous/test_td3.py index 524e17513..e9c31f4e5 100644 --- a/test/continuous/test_td3.py +++ b/test/continuous/test_td3.py @@ -14,7 +14,7 @@ from tianshou.algorithm.algorithm_base import Algorithm from tianshou.algorithm.modelfree.ddpg import ContinuousDeterministicPolicy from tianshou.algorithm.optim import AdamOptimizerFactory -from tianshou.trainer.base import OffPolicyTrainerParams +from tianshou.trainer import OffPolicyTrainerParams from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Net from tianshou.utils.net.continuous import 
ContinuousActorDeterministic, ContinuousCritic diff --git a/test/continuous/test_trpo.py b/test/continuous/test_trpo.py index 918eeeef0..8865c47a0 100644 --- a/test/continuous/test_trpo.py +++ b/test/continuous/test_trpo.py @@ -15,7 +15,7 @@ from tianshou.algorithm.algorithm_base import Algorithm from tianshou.algorithm.modelfree.pg import ActorPolicyProbabilistic from tianshou.algorithm.optim import AdamOptimizerFactory -from tianshou.trainer.base import OnPolicyTrainerParams +from tianshou.trainer import OnPolicyTrainerParams from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Net from tianshou.utils.net.continuous import ContinuousActorProbabilistic, ContinuousCritic diff --git a/test/discrete/test_a2c_with_il.py b/test/discrete/test_a2c_with_il.py index bd9472845..90b1891d5 100644 --- a/test/discrete/test_a2c_with_il.py +++ b/test/discrete/test_a2c_with_il.py @@ -11,11 +11,11 @@ from tianshou.data import Collector, CollectStats, VectorReplayBuffer from tianshou.env import DummyVectorEnv from tianshou.algorithm import A2C, OffPolicyImitationLearning -from tianshou.algorithm.algorithm_base import Algorithm +from tianshou.algorithm import Algorithm from tianshou.algorithm.imitation.base import ImitationPolicy from tianshou.algorithm.modelfree.pg import ActorPolicyProbabilistic from tianshou.algorithm.optim import AdamOptimizerFactory -from tianshou.trainer.base import OffPolicyTrainerParams, OnPolicyTrainerParams +from tianshou.trainer import OffPolicyTrainerParams, OnPolicyTrainerParams from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Net from tianshou.utils.net.discrete import DiscreteActor, DiscreteCritic diff --git a/test/discrete/test_bdqn.py b/test/discrete/test_bdqn.py index 86ba3dd38..2a5a797da 100644 --- a/test/discrete/test_bdqn.py +++ b/test/discrete/test_bdqn.py @@ -10,7 +10,7 @@ from tianshou.algorithm import BDQN from tianshou.algorithm.modelfree.bdqn import BDQNPolicy from 
tianshou.algorithm.optim import AdamOptimizerFactory -from tianshou.trainer.base import OffPolicyTrainerParams +from tianshou.trainer import OffPolicyTrainerParams from tianshou.utils.net.common import BranchingNet from tianshou.utils.torch_utils import policy_within_training_step diff --git a/test/discrete/test_c51.py b/test/discrete/test_c51.py index 6d7e22599..eaefe8d32 100644 --- a/test/discrete/test_c51.py +++ b/test/discrete/test_c51.py @@ -20,7 +20,7 @@ from tianshou.algorithm.algorithm_base import Algorithm from tianshou.algorithm.modelfree.c51 import C51Policy from tianshou.algorithm.optim import AdamOptimizerFactory -from tianshou.trainer.base import OffPolicyTrainerParams +from tianshou.trainer import OffPolicyTrainerParams from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Net from tianshou.utils.space_info import SpaceInfo diff --git a/test/discrete/test_discrete_sac.py b/test/discrete/test_discrete_sac.py index 9acf43612..879115f0f 100644 --- a/test/discrete/test_discrete_sac.py +++ b/test/discrete/test_discrete_sac.py @@ -16,7 +16,7 @@ ) from tianshou.algorithm.modelfree.sac import AutoAlpha from tianshou.algorithm.optim import AdamOptimizerFactory -from tianshou.trainer.base import OffPolicyTrainerParams +from tianshou.trainer import OffPolicyTrainerParams from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Net from tianshou.utils.net.discrete import DiscreteActor, DiscreteCritic diff --git a/test/discrete/test_dqn.py b/test/discrete/test_dqn.py index 6ba1347e7..008593b53 100644 --- a/test/discrete/test_dqn.py +++ b/test/discrete/test_dqn.py @@ -19,7 +19,7 @@ from tianshou.algorithm.algorithm_base import Algorithm from tianshou.algorithm.modelfree.dqn import DiscreteQLearningPolicy from tianshou.algorithm.optim import AdamOptimizerFactory -from tianshou.trainer.base import OffPolicyTrainerParams +from tianshou.trainer import OffPolicyTrainerParams from tianshou.utils import 
TensorboardLogger from tianshou.utils.net.common import Net from tianshou.utils.space_info import SpaceInfo diff --git a/test/discrete/test_fqf.py b/test/discrete/test_fqf.py index 2ccad1927..3a6ba65f1 100644 --- a/test/discrete/test_fqf.py +++ b/test/discrete/test_fqf.py @@ -19,7 +19,7 @@ from tianshou.algorithm.algorithm_base import Algorithm from tianshou.algorithm.modelfree.fqf import FQFPolicy from tianshou.algorithm.optim import AdamOptimizerFactory, RMSpropOptimizerFactory -from tianshou.trainer.base import OffPolicyTrainerParams +from tianshou.trainer import OffPolicyTrainerParams from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Net from tianshou.utils.net.discrete import FractionProposalNetwork, FullQuantileFunction diff --git a/test/discrete/test_iqn.py b/test/discrete/test_iqn.py index 091d9e4a1..dd93df6ea 100644 --- a/test/discrete/test_iqn.py +++ b/test/discrete/test_iqn.py @@ -19,7 +19,7 @@ from tianshou.algorithm.algorithm_base import Algorithm from tianshou.algorithm.modelfree.iqn import IQNPolicy from tianshou.algorithm.optim import AdamOptimizerFactory -from tianshou.trainer.base import OffPolicyTrainerParams +from tianshou.trainer import OffPolicyTrainerParams from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Net from tianshou.utils.net.discrete import ImplicitQuantileNetwork diff --git a/test/discrete/test_pg.py b/test/discrete/test_pg.py index 497befbbc..461131e42 100644 --- a/test/discrete/test_pg.py +++ b/test/discrete/test_pg.py @@ -14,7 +14,7 @@ from tianshou.algorithm.algorithm_base import Algorithm from tianshou.algorithm.modelfree.pg import ActorPolicyProbabilistic from tianshou.algorithm.optim import AdamOptimizerFactory -from tianshou.trainer.base import OnPolicyTrainerParams +from tianshou.trainer import OnPolicyTrainerParams from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Net from tianshou.utils.space_info import SpaceInfo diff --git 
a/test/discrete/test_qrdqn.py b/test/discrete/test_qrdqn.py index e9463c314..c99d02563 100644 --- a/test/discrete/test_qrdqn.py +++ b/test/discrete/test_qrdqn.py @@ -18,7 +18,7 @@ from tianshou.algorithm.algorithm_base import Algorithm from tianshou.algorithm.modelfree.qrdqn import QRDQNPolicy from tianshou.algorithm.optim import AdamOptimizerFactory -from tianshou.trainer.base import OffPolicyTrainerParams +from tianshou.trainer import OffPolicyTrainerParams from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Net from tianshou.utils.space_info import SpaceInfo diff --git a/test/discrete/test_rainbow.py b/test/discrete/test_rainbow.py index 8d570b8b1..4d74976c8 100644 --- a/test/discrete/test_rainbow.py +++ b/test/discrete/test_rainbow.py @@ -19,7 +19,7 @@ from tianshou.algorithm.algorithm_base import Algorithm from tianshou.algorithm.modelfree.c51 import C51Policy from tianshou.algorithm.optim import AdamOptimizerFactory -from tianshou.trainer.base import OffPolicyTrainerParams +from tianshou.trainer import OffPolicyTrainerParams from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Net from tianshou.utils.net.discrete import NoisyLinear diff --git a/test/offline/gather_pendulum_data.py b/test/offline/gather_pendulum_data.py index f1655fca8..be7d5c774 100644 --- a/test/offline/gather_pendulum_data.py +++ b/test/offline/gather_pendulum_data.py @@ -13,7 +13,7 @@ from tianshou.algorithm.algorithm_base import Algorithm from tianshou.algorithm.modelfree.sac import AutoAlpha, SACPolicy, SACTrainingStats from tianshou.algorithm.optim import AdamOptimizerFactory -from tianshou.trainer.base import OffPolicyTrainerParams +from tianshou.trainer import OffPolicyTrainerParams from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Net from tianshou.utils.net.continuous import ContinuousActorProbabilistic, ContinuousCritic diff --git a/test/offline/test_bcq.py b/test/offline/test_bcq.py index 
83615fd93..715347716 100644 --- a/test/offline/test_bcq.py +++ b/test/offline/test_bcq.py @@ -15,7 +15,7 @@ from tianshou.algorithm import BCQ, Algorithm from tianshou.algorithm.imitation.bcq import BCQPolicy from tianshou.algorithm.optim import AdamOptimizerFactory -from tianshou.trainer.base import OfflineTrainerParams +from tianshou.trainer import OfflineTrainerParams from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import MLP, Net from tianshou.utils.net.continuous import VAE, ContinuousCritic, Perturbation diff --git a/tianshou/algorithm/algorithm_base.py b/tianshou/algorithm/algorithm_base.py index 67bfe8094..f5bd3fdbb 100644 --- a/tianshou/algorithm/algorithm_base.py +++ b/tianshou/algorithm/algorithm_base.py @@ -13,12 +13,14 @@ from numpy.typing import ArrayLike from overrides import override from sensai.util.hash import pickle_hash +from sensai.util.helper import mark_used from torch import nn from torch.nn.modules.module import ( _IncompatibleKeys, # we have to do this since we override load_state_dict ) from torch.optim.lr_scheduler import LRScheduler +from tianshou.algorithm.optim import OptimizerFactory from tianshou.data import ReplayBuffer, SequenceSummaryStats, to_numpy, to_torch_as from tianshou.data.batch import Batch, BatchProtocol, TArr from tianshou.data.buffer.base import TBuffer @@ -29,7 +31,6 @@ ObsBatchProtocol, RolloutBatchProtocol, ) -from tianshou.algorithm.optim import OptimizerFactory from tianshou.utils.determinism import TraceLogger from tianshou.utils.lagged_network import ( EvalModeModuleWrapper, @@ -40,8 +41,7 @@ from tianshou.utils.torch_utils import policy_within_training_step, torch_train_mode if TYPE_CHECKING: - from tianshou.trainer.base import ( - InfoStats, + from tianshou.trainer import ( OfflineTrainer, OfflineTrainerParams, OffPolicyTrainer, @@ -51,6 +51,8 @@ Trainer, TrainerParams, ) + from tianshou.data.stats import InfoStats + mark_used(TrainerParams) logger = logging.getLogger(__name__) @@ 
-871,7 +873,7 @@ class OnPolicyAlgorithm( """Base class for on-policy RL algorithms.""" def create_trainer(self, params: "OnPolicyTrainerParams") -> "OnPolicyTrainer": - from tianshou.trainer.base import OnPolicyTrainer + from tianshou.trainer import OnPolicyTrainer return OnPolicyTrainer(self, params) @@ -911,7 +913,7 @@ class OffPolicyAlgorithm( """Base class for off-policy RL algorithms.""" def create_trainer(self, params: "OffPolicyTrainerParams") -> "OffPolicyTrainer": - from tianshou.trainer.base import OffPolicyTrainer + from tianshou.trainer import OffPolicyTrainer return OffPolicyTrainer(self, params) @@ -957,7 +959,7 @@ def run_training(self, params: "OfflineTrainerParams") -> "InfoStats": return super().run_training(params) def create_trainer(self, params: "OfflineTrainerParams") -> "OfflineTrainer": - from tianshou.trainer.base import OfflineTrainer + from tianshou.trainer import OfflineTrainer return OfflineTrainer(self, params) diff --git a/tianshou/highlevel/algorithm.py b/tianshou/highlevel/algorithm.py index 5acb56e92..2c96534b5 100644 --- a/tianshou/highlevel/algorithm.py +++ b/tianshou/highlevel/algorithm.py @@ -74,7 +74,7 @@ from tianshou.algorithm.modelfree.redq import REDQPolicy from tianshou.algorithm.modelfree.sac import SACPolicy from tianshou.trainer import OffPolicyTrainer, OnPolicyTrainer, Trainer -from tianshou.trainer.base import OffPolicyTrainerParams, OnPolicyTrainerParams +from tianshou.trainer import OffPolicyTrainerParams, OnPolicyTrainerParams from tianshou.utils.net.discrete import DiscreteActor CHECKPOINT_DICT_KEY_MODEL = "model" diff --git a/tianshou/trainer/__init__.py b/tianshou/trainer/__init__.py index f36ee3035..702ce6bf4 100644 --- a/tianshou/trainer/__init__.py +++ b/tianshou/trainer/__init__.py @@ -1,6 +1,6 @@ """Trainer package.""" -from .base import ( +from .trainer import ( OfflineTrainer, OfflineTrainerParams, OffPolicyTrainer, @@ -8,4 +8,5 @@ OnPolicyTrainer, OnPolicyTrainerParams, Trainer, + TrainerParams, ) diff 
--git a/tianshou/trainer/base.py b/tianshou/trainer/trainer.py similarity index 100% rename from tianshou/trainer/base.py rename to tianshou/trainer/trainer.py From 33a8235167871644f3f083e153b0cc9ecc532491 Mon Sep 17 00:00:00 2001 From: Dominik Jain Date: Thu, 15 May 2025 23:39:21 +0200 Subject: [PATCH 170/230] Rename module mapolicy -> marl --- test/pettingzoo/pistonball_continuous.py | 2 +- tianshou/algorithm/__init__.py | 2 +- tianshou/algorithm/multiagent/{mapolicy.py => marl.py} | 0 3 files changed, 2 insertions(+), 2 deletions(-) rename tianshou/algorithm/multiagent/{mapolicy.py => marl.py} (100%) diff --git a/test/pettingzoo/pistonball_continuous.py b/test/pettingzoo/pistonball_continuous.py index bfb09b38f..c4cf23106 100644 --- a/test/pettingzoo/pistonball_continuous.py +++ b/test/pettingzoo/pistonball_continuous.py @@ -18,7 +18,7 @@ from tianshou.algorithm import PPO, Algorithm from tianshou.algorithm.algorithm_base import OnPolicyAlgorithm from tianshou.algorithm.modelfree.pg import ActorPolicyProbabilistic -from tianshou.algorithm.multiagent.mapolicy import MultiAgentOnPolicyAlgorithm +from tianshou.algorithm.multiagent.marl import MultiAgentOnPolicyAlgorithm from tianshou.algorithm.optim import AdamOptimizerFactory from tianshou.trainer import OnPolicyTrainerParams from tianshou.utils import TensorboardLogger diff --git a/tianshou/algorithm/__init__.py b/tianshou/algorithm/__init__.py index 44de57573..07e3f855e 100644 --- a/tianshou/algorithm/__init__.py +++ b/tianshou/algorithm/__init__.py @@ -32,4 +32,4 @@ from tianshou.algorithm.modelbased.psrl import PSRL from tianshou.algorithm.modelbased.icm import ICMOffPolicyWrapper from tianshou.algorithm.modelbased.icm import ICMOnPolicyWrapper -from tianshou.algorithm.multiagent.mapolicy import MultiAgentOffPolicyAlgorithm +from tianshou.algorithm.multiagent.marl import MultiAgentOffPolicyAlgorithm diff --git a/tianshou/algorithm/multiagent/mapolicy.py b/tianshou/algorithm/multiagent/marl.py similarity index 
100% rename from tianshou/algorithm/multiagent/mapolicy.py rename to tianshou/algorithm/multiagent/marl.py From 43be06ba635aa191e6f2b3a2b97d650ff71107c4 Mon Sep 17 00:00:00 2001 From: Dominik Jain Date: Thu, 15 May 2025 23:40:13 +0200 Subject: [PATCH 171/230] Rename module policy_params -> algorithm_params --- README.md | 8 ++++---- examples/atari/atari_dqn_hl.py | 2 +- examples/atari/atari_iqn_hl.py | 2 +- examples/atari/atari_ppo_hl.py | 2 +- examples/atari/atari_sac_hl.py | 2 +- examples/discrete/discrete_dqn_hl.py | 2 +- examples/mujoco/mujoco_a2c_hl.py | 2 +- examples/mujoco/mujoco_ddpg_hl.py | 2 +- examples/mujoco/mujoco_npg_hl.py | 2 +- examples/mujoco/mujoco_ppo_hl.py | 2 +- examples/mujoco/mujoco_ppo_hl_multi.py | 2 +- examples/mujoco/mujoco_redq_hl.py | 2 +- examples/mujoco/mujoco_reinforce_hl.py | 2 +- examples/mujoco/mujoco_sac_hl.py | 2 +- examples/mujoco/mujoco_td3_hl.py | 2 +- examples/mujoco/mujoco_trpo_hl.py | 2 +- tianshou/highlevel/algorithm.py | 2 +- tianshou/highlevel/experiment.py | 2 +- .../params/{policy_params.py => algorithm_params.py} | 0 19 files changed, 21 insertions(+), 21 deletions(-) rename tianshou/highlevel/params/{policy_params.py => algorithm_params.py} (100%) diff --git a/README.md b/README.md index 5dbb66226..0dc181d94 100644 --- a/README.md +++ b/README.md @@ -218,13 +218,13 @@ and some configuration classes: ```python from tianshou.highlevel.config import OffPolicyTrainingConfig from tianshou.highlevel.env import ( - EnvFactoryRegistered, - VectorEnvType, + EnvFactoryRegistered, + VectorEnvType, ) from tianshou.highlevel.experiment import DQNExperimentBuilder, ExperimentConfig -from tianshou.highlevel.params.policy_params import DQNParams +from tianshou.highlevel.params.algorithm_params import DQNParams from tianshou.highlevel.trainer import ( - EpochStopCallbackRewardThreshold, + EpochStopCallbackRewardThreshold, ) ``` diff --git a/examples/atari/atari_dqn_hl.py b/examples/atari/atari_dqn_hl.py index 6d76b5dde..77bae646b 
100644 --- a/examples/atari/atari_dqn_hl.py +++ b/examples/atari/atari_dqn_hl.py @@ -15,7 +15,7 @@ DQNExperimentBuilder, ExperimentConfig, ) -from tianshou.highlevel.params.policy_params import DQNParams +from tianshou.highlevel.params.algorithm_params import DQNParams from tianshou.highlevel.params.policy_wrapper import ( AlgorithmWrapperFactoryIntrinsicCuriosity, ) diff --git a/examples/atari/atari_iqn_hl.py b/examples/atari/atari_iqn_hl.py index 5a18dd82f..141bdf215 100644 --- a/examples/atari/atari_iqn_hl.py +++ b/examples/atari/atari_iqn_hl.py @@ -15,7 +15,7 @@ ExperimentConfig, IQNExperimentBuilder, ) -from tianshou.highlevel.params.policy_params import IQNParams +from tianshou.highlevel.params.algorithm_params import IQNParams from tianshou.highlevel.trainer import ( EpochTrainCallbackDQNEpsLinearDecay, ) diff --git a/examples/atari/atari_ppo_hl.py b/examples/atari/atari_ppo_hl.py index ab40dd342..e99ebc224 100644 --- a/examples/atari/atari_ppo_hl.py +++ b/examples/atari/atari_ppo_hl.py @@ -17,7 +17,7 @@ PPOExperimentBuilder, ) from tianshou.highlevel.params.lr_scheduler import LRSchedulerFactoryFactoryLinear -from tianshou.highlevel.params.policy_params import PPOParams +from tianshou.highlevel.params.algorithm_params import PPOParams from tianshou.highlevel.params.policy_wrapper import ( AlgorithmWrapperFactoryIntrinsicCuriosity, ) diff --git a/examples/atari/atari_sac_hl.py b/examples/atari/atari_sac_hl.py index 2e5fe7c90..8e8d743d1 100644 --- a/examples/atari/atari_sac_hl.py +++ b/examples/atari/atari_sac_hl.py @@ -17,7 +17,7 @@ ExperimentConfig, ) from tianshou.highlevel.params.alpha import AutoAlphaFactoryDefault -from tianshou.highlevel.params.policy_params import DiscreteSACParams +from tianshou.highlevel.params.algorithm_params import DiscreteSACParams from tianshou.highlevel.params.policy_wrapper import ( AlgorithmWrapperFactoryIntrinsicCuriosity, ) diff --git a/examples/discrete/discrete_dqn_hl.py b/examples/discrete/discrete_dqn_hl.py index 
5f85b2b51..7f8ce777e 100644 --- a/examples/discrete/discrete_dqn_hl.py +++ b/examples/discrete/discrete_dqn_hl.py @@ -6,7 +6,7 @@ VectorEnvType, ) from tianshou.highlevel.experiment import DQNExperimentBuilder, ExperimentConfig -from tianshou.highlevel.params.policy_params import DQNParams +from tianshou.highlevel.params.algorithm_params import DQNParams from tianshou.highlevel.trainer import ( EpochStopCallbackRewardThreshold, ) diff --git a/examples/mujoco/mujoco_a2c_hl.py b/examples/mujoco/mujoco_a2c_hl.py index 8170f46d0..2f21cd56a 100644 --- a/examples/mujoco/mujoco_a2c_hl.py +++ b/examples/mujoco/mujoco_a2c_hl.py @@ -16,7 +16,7 @@ ) from tianshou.highlevel.optim import OptimizerFactoryFactoryRMSprop from tianshou.highlevel.params.lr_scheduler import LRSchedulerFactoryFactoryLinear -from tianshou.highlevel.params.policy_params import A2CParams +from tianshou.highlevel.params.algorithm_params import A2CParams def main( diff --git a/examples/mujoco/mujoco_ddpg_hl.py b/examples/mujoco/mujoco_ddpg_hl.py index 2acfa2f1f..deb44b485 100644 --- a/examples/mujoco/mujoco_ddpg_hl.py +++ b/examples/mujoco/mujoco_ddpg_hl.py @@ -13,7 +13,7 @@ ExperimentConfig, ) from tianshou.highlevel.params.noise import MaxActionScaledGaussian -from tianshou.highlevel.params.policy_params import DDPGParams +from tianshou.highlevel.params.algorithm_params import DDPGParams def main( diff --git a/examples/mujoco/mujoco_npg_hl.py b/examples/mujoco/mujoco_npg_hl.py index 3b2630c5d..bb29a92d9 100644 --- a/examples/mujoco/mujoco_npg_hl.py +++ b/examples/mujoco/mujoco_npg_hl.py @@ -15,7 +15,7 @@ NPGExperimentBuilder, ) from tianshou.highlevel.params.lr_scheduler import LRSchedulerFactoryFactoryLinear -from tianshou.highlevel.params.policy_params import NPGParams +from tianshou.highlevel.params.algorithm_params import NPGParams def main( diff --git a/examples/mujoco/mujoco_ppo_hl.py b/examples/mujoco/mujoco_ppo_hl.py index 12b9365c2..213bf8218 100644 --- a/examples/mujoco/mujoco_ppo_hl.py +++ 
b/examples/mujoco/mujoco_ppo_hl.py @@ -15,7 +15,7 @@ PPOExperimentBuilder, ) from tianshou.highlevel.params.lr_scheduler import LRSchedulerFactoryFactoryLinear -from tianshou.highlevel.params.policy_params import PPOParams +from tianshou.highlevel.params.algorithm_params import PPOParams def main( diff --git a/examples/mujoco/mujoco_ppo_hl_multi.py b/examples/mujoco/mujoco_ppo_hl_multi.py index fa12df494..36a6ae239 100644 --- a/examples/mujoco/mujoco_ppo_hl_multi.py +++ b/examples/mujoco/mujoco_ppo_hl_multi.py @@ -29,7 +29,7 @@ ) from tianshou.highlevel.logger import LoggerFactoryDefault from tianshou.highlevel.params.lr_scheduler import LRSchedulerFactoryFactoryLinear -from tianshou.highlevel.params.policy_params import PPOParams +from tianshou.highlevel.params.algorithm_params import PPOParams log = logging.getLogger(__name__) diff --git a/examples/mujoco/mujoco_redq_hl.py b/examples/mujoco/mujoco_redq_hl.py index 24ab3a073..633f34466 100644 --- a/examples/mujoco/mujoco_redq_hl.py +++ b/examples/mujoco/mujoco_redq_hl.py @@ -14,7 +14,7 @@ REDQExperimentBuilder, ) from tianshou.highlevel.params.alpha import AutoAlphaFactoryDefault -from tianshou.highlevel.params.policy_params import REDQParams +from tianshou.highlevel.params.algorithm_params import REDQParams def main( diff --git a/examples/mujoco/mujoco_reinforce_hl.py b/examples/mujoco/mujoco_reinforce_hl.py index 9fd13e462..15a1b66c2 100644 --- a/examples/mujoco/mujoco_reinforce_hl.py +++ b/examples/mujoco/mujoco_reinforce_hl.py @@ -15,7 +15,7 @@ ReinforceExperimentBuilder, ) from tianshou.highlevel.params.lr_scheduler import LRSchedulerFactoryFactoryLinear -from tianshou.highlevel.params.policy_params import ReinforceParams +from tianshou.highlevel.params.algorithm_params import ReinforceParams def main( diff --git a/examples/mujoco/mujoco_sac_hl.py b/examples/mujoco/mujoco_sac_hl.py index 6e4af1c91..cf94953e8 100644 --- a/examples/mujoco/mujoco_sac_hl.py +++ b/examples/mujoco/mujoco_sac_hl.py @@ -13,7 +13,7 @@ 
SACExperimentBuilder, ) from tianshou.highlevel.params.alpha import AutoAlphaFactoryDefault -from tianshou.highlevel.params.policy_params import SACParams +from tianshou.highlevel.params.algorithm_params import SACParams def main( diff --git a/examples/mujoco/mujoco_td3_hl.py b/examples/mujoco/mujoco_td3_hl.py index 3f9afb237..a04f03d7a 100644 --- a/examples/mujoco/mujoco_td3_hl.py +++ b/examples/mujoco/mujoco_td3_hl.py @@ -17,7 +17,7 @@ from tianshou.highlevel.params.noise import ( MaxActionScaledGaussian, ) -from tianshou.highlevel.params.policy_params import TD3Params +from tianshou.highlevel.params.algorithm_params import TD3Params def main( diff --git a/examples/mujoco/mujoco_trpo_hl.py b/examples/mujoco/mujoco_trpo_hl.py index 310d2a39c..a70ba692c 100644 --- a/examples/mujoco/mujoco_trpo_hl.py +++ b/examples/mujoco/mujoco_trpo_hl.py @@ -15,7 +15,7 @@ TRPOExperimentBuilder, ) from tianshou.highlevel.params.lr_scheduler import LRSchedulerFactoryFactoryLinear -from tianshou.highlevel.params.policy_params import TRPOParams +from tianshou.highlevel.params.algorithm_params import TRPOParams def main( diff --git a/tianshou/highlevel/algorithm.py b/tianshou/highlevel/algorithm.py index 2c96534b5..1ccdbed22 100644 --- a/tianshou/highlevel/algorithm.py +++ b/tianshou/highlevel/algorithm.py @@ -24,7 +24,7 @@ ) from tianshou.highlevel.module.critic import CriticEnsembleFactory, CriticFactory from tianshou.highlevel.optim import OptimizerFactoryFactory -from tianshou.highlevel.params.policy_params import ( +from tianshou.highlevel.params.algorithm_params import ( A2CParams, DDPGParams, DiscreteSACParams, diff --git a/tianshou/highlevel/experiment.py b/tianshou/highlevel/experiment.py index 107376297..3126fa7c4 100644 --- a/tianshou/highlevel/experiment.py +++ b/tianshou/highlevel/experiment.py @@ -86,7 +86,7 @@ OptimizerFactoryFactory, OptimizerFactoryFactoryAdam, ) -from tianshou.highlevel.params.policy_params import ( +from tianshou.highlevel.params.algorithm_params 
import ( A2CParams, DDPGParams, DiscreteSACParams, diff --git a/tianshou/highlevel/params/policy_params.py b/tianshou/highlevel/params/algorithm_params.py similarity index 100% rename from tianshou/highlevel/params/policy_params.py rename to tianshou/highlevel/params/algorithm_params.py From 1d03c1cf3621af8f2d3e5cd8e292babc60c35a94 Mon Sep 17 00:00:00 2001 From: Dominik Jain Date: Thu, 15 May 2025 23:41:48 +0200 Subject: [PATCH 172/230] Rename module logger.base -> logger.logger_base --- tianshou/evaluation/rliable_evaluation_hl.py | 2 +- tianshou/utils/__init__.py | 2 +- tianshou/utils/logger/{base.py => logger_base.py} | 0 tianshou/utils/logger/tensorboard.py | 2 +- tianshou/utils/logger/wandb.py | 2 +- 5 files changed, 4 insertions(+), 4 deletions(-) rename tianshou/utils/logger/{base.py => logger_base.py} (100%) diff --git a/tianshou/evaluation/rliable_evaluation_hl.py b/tianshou/evaluation/rliable_evaluation_hl.py index 2b8ff5131..cc18d2840 100644 --- a/tianshou/evaluation/rliable_evaluation_hl.py +++ b/tianshou/evaluation/rliable_evaluation_hl.py @@ -15,7 +15,7 @@ from tianshou.highlevel.experiment import Experiment from tianshou.utils import TensorboardLogger -from tianshou.utils.logger.base import DataScope +from tianshou.utils.logger.logger_base import DataScope log = logging.getLogger(__name__) diff --git a/tianshou/utils/__init__.py b/tianshou/utils/__init__.py index 42e7152b2..a23841b36 100644 --- a/tianshou/utils/__init__.py +++ b/tianshou/utils/__init__.py @@ -1,6 +1,6 @@ """Utils package.""" -from tianshou.utils.logger.base import BaseLogger, LazyLogger +from tianshou.utils.logger.logger_base import BaseLogger, LazyLogger from tianshou.utils.logger.tensorboard import TensorboardLogger from tianshou.utils.logger.wandb import WandbLogger from tianshou.utils.progress_bar import DummyTqdm, tqdm_config diff --git a/tianshou/utils/logger/base.py b/tianshou/utils/logger/logger_base.py similarity index 100% rename from tianshou/utils/logger/base.py rename to 
tianshou/utils/logger/logger_base.py diff --git a/tianshou/utils/logger/tensorboard.py b/tianshou/utils/logger/tensorboard.py index 1400d8a52..dba11b555 100644 --- a/tianshou/utils/logger/tensorboard.py +++ b/tianshou/utils/logger/tensorboard.py @@ -6,7 +6,7 @@ from tensorboard.backend.event_processing import event_accumulator from torch.utils.tensorboard import SummaryWriter -from tianshou.utils.logger.base import ( +from tianshou.utils.logger.logger_base import ( VALID_LOG_VALS, VALID_LOG_VALS_TYPE, BaseLogger, diff --git a/tianshou/utils/logger/wandb.py b/tianshou/utils/logger/wandb.py index fcdceb0de..f92c3fd2c 100644 --- a/tianshou/utils/logger/wandb.py +++ b/tianshou/utils/logger/wandb.py @@ -7,7 +7,7 @@ from torch.utils.tensorboard import SummaryWriter from tianshou.utils import BaseLogger, TensorboardLogger -from tianshou.utils.logger.base import VALID_LOG_VALS_TYPE, TRestoredData +from tianshou.utils.logger.logger_base import VALID_LOG_VALS_TYPE, TRestoredData with contextlib.suppress(ImportError): import wandb From 3061e5b9e03ad90130cb55f02ff96528a02bb18a Mon Sep 17 00:00:00 2001 From: Dominik Jain Date: Thu, 15 May 2025 23:42:21 +0200 Subject: [PATCH 173/230] Rename module buffer.base -> buffer.buffer_base --- tianshou/algorithm/algorithm_base.py | 2 +- tianshou/algorithm/imitation/cql.py | 2 +- tianshou/data/__init__.py | 2 +- tianshou/data/buffer/{base.py => buffer_base.py} | 0 tianshou/data/collector.py | 2 +- tianshou/trainer/trainer.py | 2 +- 6 files changed, 5 insertions(+), 5 deletions(-) rename tianshou/data/buffer/{base.py => buffer_base.py} (100%) diff --git a/tianshou/algorithm/algorithm_base.py b/tianshou/algorithm/algorithm_base.py index f5bd3fdbb..d404fc392 100644 --- a/tianshou/algorithm/algorithm_base.py +++ b/tianshou/algorithm/algorithm_base.py @@ -23,7 +23,7 @@ from tianshou.algorithm.optim import OptimizerFactory from tianshou.data import ReplayBuffer, SequenceSummaryStats, to_numpy, to_torch_as from tianshou.data.batch import Batch, 
BatchProtocol, TArr -from tianshou.data.buffer.base import TBuffer +from tianshou.data.buffer.buffer_base import TBuffer from tianshou.data.types import ( ActBatchProtocol, ActStateBatchProtocol, diff --git a/tianshou/algorithm/imitation/cql.py b/tianshou/algorithm/imitation/cql.py index 4d0f33ac3..1013a146a 100644 --- a/tianshou/algorithm/imitation/cql.py +++ b/tianshou/algorithm/imitation/cql.py @@ -8,7 +8,7 @@ from overrides import override from tianshou.data import Batch, ReplayBuffer, to_torch -from tianshou.data.buffer.base import TBuffer +from tianshou.data.buffer.buffer_base import TBuffer from tianshou.data.types import RolloutBatchProtocol from tianshou.algorithm.algorithm_base import ( LaggedNetworkPolyakUpdateAlgorithmMixin, diff --git a/tianshou/data/__init__.py b/tianshou/data/__init__.py index c84c2ec7d..7e1d5298d 100644 --- a/tianshou/data/__init__.py +++ b/tianshou/data/__init__.py @@ -4,7 +4,7 @@ from tianshou.data.batch import Batch from tianshou.data.utils.converter import to_numpy, to_torch, to_torch_as from tianshou.data.utils.segtree import SegmentTree -from tianshou.data.buffer.base import ReplayBuffer +from tianshou.data.buffer.buffer_base import ReplayBuffer from tianshou.data.buffer.prio import PrioritizedReplayBuffer from tianshou.data.buffer.her import HERReplayBuffer from tianshou.data.buffer.manager import ( diff --git a/tianshou/data/buffer/base.py b/tianshou/data/buffer/buffer_base.py similarity index 100% rename from tianshou/data/buffer/base.py rename to tianshou/data/buffer/buffer_base.py diff --git a/tianshou/data/collector.py b/tianshou/data/collector.py index 37fea3d49..60200a307 100644 --- a/tianshou/data/collector.py +++ b/tianshou/data/collector.py @@ -21,7 +21,7 @@ VectorReplayBuffer, to_numpy, ) -from tianshou.data.buffer.base import MalformedBufferError +from tianshou.data.buffer.buffer_base import MalformedBufferError from tianshou.data.stats import compute_dim_to_summary_stats from tianshou.data.types import ( 
ActBatchProtocol, diff --git a/tianshou/trainer/trainer.py b/tianshou/trainer/trainer.py index 40373617a..f8c651d82 100644 --- a/tianshou/trainer/trainer.py +++ b/tianshou/trainer/trainer.py @@ -45,7 +45,7 @@ SequenceSummaryStats, TimingStats, ) -from tianshou.data.buffer.base import MalformedBufferError +from tianshou.data.buffer.buffer_base import MalformedBufferError from tianshou.data.collector import BaseCollector, CollectStatsBase from tianshou.algorithm.algorithm_base import ( Algorithm, From d058d924269cd283811ca5b11288013b98bfc979 Mon Sep 17 00:00:00 2001 From: Dominik Jain Date: Thu, 15 May 2025 23:43:13 +0200 Subject: [PATCH 174/230] Rename module imitation.base -> imitation.imitation_base --- examples/offline/d4rl_il.py | 2 +- test/continuous/test_sac_with_il.py | 2 +- test/discrete/test_a2c_with_il.py | 2 +- tianshou/algorithm/imitation/{base.py => imitation_base.py} | 0 4 files changed, 3 insertions(+), 3 deletions(-) rename tianshou/algorithm/imitation/{base.py => imitation_base.py} (100%) diff --git a/examples/offline/d4rl_il.py b/examples/offline/d4rl_il.py index dc741e0c3..3e6c09150 100644 --- a/examples/offline/d4rl_il.py +++ b/examples/offline/d4rl_il.py @@ -14,7 +14,7 @@ from tianshou.data import Collector, CollectStats from tianshou.env import SubprocVectorEnv from tianshou.algorithm.algorithm_base import Algorithm -from tianshou.algorithm.imitation.base import ImitationPolicy, OfflineImitationLearning +from tianshou.algorithm.imitation.imitation_base import ImitationPolicy, OfflineImitationLearning from tianshou.algorithm.optim import AdamOptimizerFactory from tianshou.trainer import OfflineTrainerParams from tianshou.utils import TensorboardLogger, WandbLogger diff --git a/test/continuous/test_sac_with_il.py b/test/continuous/test_sac_with_il.py index 566de2822..7d4d43a9a 100644 --- a/test/continuous/test_sac_with_il.py +++ b/test/continuous/test_sac_with_il.py @@ -11,7 +11,7 @@ from tianshou.env import DummyVectorEnv from tianshou.algorithm 
import SAC, OffPolicyImitationLearning from tianshou.algorithm.algorithm_base import Algorithm -from tianshou.algorithm.imitation.base import ImitationPolicy +from tianshou.algorithm.imitation.imitation_base import ImitationPolicy from tianshou.algorithm.modelfree.sac import AutoAlpha, SACPolicy from tianshou.algorithm.optim import AdamOptimizerFactory from tianshou.trainer import OffPolicyTrainerParams diff --git a/test/discrete/test_a2c_with_il.py b/test/discrete/test_a2c_with_il.py index 90b1891d5..f5c123bab 100644 --- a/test/discrete/test_a2c_with_il.py +++ b/test/discrete/test_a2c_with_il.py @@ -12,7 +12,7 @@ from tianshou.env import DummyVectorEnv from tianshou.algorithm import A2C, OffPolicyImitationLearning from tianshou.algorithm import Algorithm -from tianshou.algorithm.imitation.base import ImitationPolicy +from tianshou.algorithm.imitation.imitation_base import ImitationPolicy from tianshou.algorithm.modelfree.pg import ActorPolicyProbabilistic from tianshou.algorithm.optim import AdamOptimizerFactory from tianshou.trainer import OffPolicyTrainerParams, OnPolicyTrainerParams diff --git a/tianshou/algorithm/imitation/base.py b/tianshou/algorithm/imitation/imitation_base.py similarity index 100% rename from tianshou/algorithm/imitation/base.py rename to tianshou/algorithm/imitation/imitation_base.py From 37e7f8d6835b1c0fea5a28bb71d87e8a36eee55a Mon Sep 17 00:00:00 2001 From: Dominik Jain Date: Thu, 15 May 2025 23:44:20 +0200 Subject: [PATCH 175/230] Rename module env.worker.base -> env.worker.worker_base --- examples/offline/atari_il.py | 2 +- tianshou/algorithm/__init__.py | 2 +- tianshou/env/worker/__init__.py | 2 +- tianshou/env/worker/{base.py => worker_base.py} | 0 4 files changed, 3 insertions(+), 3 deletions(-) rename tianshou/env/worker/{base.py => worker_base.py} (100%) diff --git a/examples/offline/atari_il.py b/examples/offline/atari_il.py index 3d08a5ed7..2a403f363 100644 --- a/examples/offline/atari_il.py +++ b/examples/offline/atari_il.py @@ 
-16,7 +16,7 @@ from tianshou.env.atari.atari_wrapper import make_atari_env from tianshou.highlevel.logger import LoggerFactoryDefault from tianshou.algorithm.algorithm_base import Algorithm -from tianshou.algorithm.imitation.base import ImitationPolicy, OfflineImitationLearning +from tianshou.algorithm.imitation.imitation_base import ImitationPolicy, OfflineImitationLearning from tianshou.algorithm.optim import AdamOptimizerFactory from tianshou.trainer import OfflineTrainerParams from tianshou.utils.space_info import SpaceInfo diff --git a/tianshou/algorithm/__init__.py b/tianshou/algorithm/__init__.py index 07e3f855e..93606f2f5 100644 --- a/tianshou/algorithm/__init__.py +++ b/tianshou/algorithm/__init__.py @@ -21,7 +21,7 @@ from tianshou.algorithm.modelfree.sac import SAC from tianshou.algorithm.modelfree.redq import REDQ from tianshou.algorithm.modelfree.discrete_sac import DiscreteSAC -from tianshou.algorithm.imitation.base import OffPolicyImitationLearning +from tianshou.algorithm.imitation.imitation_base import OffPolicyImitationLearning from tianshou.algorithm.imitation.bcq import BCQ from tianshou.algorithm.imitation.cql import CQL from tianshou.algorithm.imitation.td3_bc import TD3BC diff --git a/tianshou/env/worker/__init__.py b/tianshou/env/worker/__init__.py index 1b1f37510..43a7066b1 100644 --- a/tianshou/env/worker/__init__.py +++ b/tianshou/env/worker/__init__.py @@ -1,4 +1,4 @@ -from tianshou.env.worker.base import EnvWorker +from tianshou.env.worker.worker_base import EnvWorker from tianshou.env.worker.dummy import DummyEnvWorker from tianshou.env.worker.ray import RayEnvWorker from tianshou.env.worker.subproc import SubprocEnvWorker diff --git a/tianshou/env/worker/base.py b/tianshou/env/worker/worker_base.py similarity index 100% rename from tianshou/env/worker/base.py rename to tianshou/env/worker/worker_base.py From b5938061c1fcd4c4d4a48c54b961c69fb1afb155 Mon Sep 17 00:00:00 2001 From: Dominik Jain Date: Fri, 16 May 2025 00:18:00 +0200 
Subject: [PATCH 176/230] v2: Clean imports/apply formatter after renamings --- docs/02_notebooks/L0_overview.ipynb | 2 +- docs/02_notebooks/L4_Policy.ipynb | 12 ++-- docs/02_notebooks/L5_Collector.ipynb | 2 +- docs/02_notebooks/L6_Trainer.ipynb | 2 +- docs/02_notebooks/L7_Experiment.ipynb | 2 +- examples/atari/atari_c51.py | 8 +-- examples/atari/atari_dqn.py | 8 +-- examples/atari/atari_fqf.py | 8 +-- examples/atari/atari_iqn.py | 8 +-- examples/atari/atari_ppo.py | 10 +-- examples/atari/atari_ppo_hl.py | 2 +- examples/atari/atari_qrdqn.py | 8 +-- examples/atari/atari_rainbow.py | 8 +-- examples/atari/atari_sac.py | 8 +-- examples/atari/atari_sac_hl.py | 2 +- examples/box2d/acrobot_dualdqn.py | 4 +- examples/box2d/bipedal_bdq.py | 4 +- examples/box2d/bipedal_hardcore_sac.py | 4 +- examples/box2d/lunarlander_dqn.py | 4 +- examples/box2d/mcc_sac.py | 6 +- examples/discrete/discrete_dqn.py | 2 +- examples/inverse/irl_gail.py | 8 +-- examples/mujoco/mujoco_a2c.py | 4 +- examples/mujoco/mujoco_a2c_hl.py | 2 +- examples/mujoco/mujoco_ddpg.py | 6 +- examples/mujoco/mujoco_ddpg_hl.py | 2 +- examples/mujoco/mujoco_npg.py | 4 +- examples/mujoco/mujoco_npg_hl.py | 2 +- examples/mujoco/mujoco_ppo.py | 4 +- examples/mujoco/mujoco_ppo_hl.py | 2 +- examples/mujoco/mujoco_ppo_hl_multi.py | 2 +- examples/mujoco/mujoco_redq.py | 4 +- examples/mujoco/mujoco_redq_hl.py | 2 +- examples/mujoco/mujoco_reinforce.py | 4 +- examples/mujoco/mujoco_reinforce_hl.py | 2 +- examples/mujoco/mujoco_sac.py | 4 +- examples/mujoco/mujoco_sac_hl.py | 2 +- examples/mujoco/mujoco_td3.py | 6 +- examples/mujoco/mujoco_td3_hl.py | 2 +- examples/mujoco/mujoco_trpo.py | 4 +- examples/mujoco/mujoco_trpo_hl.py | 2 +- examples/offline/atari_bcq.py | 8 +-- examples/offline/atari_cql.py | 8 +-- examples/offline/atari_crr.py | 8 +-- examples/offline/atari_il.py | 9 ++- examples/offline/d4rl_bcq.py | 4 +- examples/offline/d4rl_cql.py | 4 +- examples/offline/d4rl_il.py | 9 ++- examples/offline/d4rl_td3_bc.py | 6 +- 
examples/vizdoom/vizdoom_c51.py | 6 +- examples/vizdoom/vizdoom_ppo.py | 6 +- test/base/test_collector.py | 2 +- test/base/test_env_finite.py | 2 +- test/base/test_policy.py | 7 ++- test/base/test_returns.py | 2 +- test/base/test_stats.py | 2 +- test/continuous/test_ddpg.py | 6 +- test/continuous/test_npg.py | 4 +- test/continuous/test_ppo.py | 4 +- test/continuous/test_redq.py | 4 +- test/continuous/test_sac_with_il.py | 4 +- test/continuous/test_td3.py | 6 +- test/continuous/test_trpo.py | 4 +- test/discrete/test_a2c_with_il.py | 7 +-- test/discrete/test_bdqn.py | 4 +- test/discrete/test_c51.py | 8 +-- test/discrete/test_discrete_sac.py | 4 +- test/discrete/test_dqn.py | 8 +-- test/discrete/test_drqn.py | 4 +- test/discrete/test_fqf.py | 8 +-- test/discrete/test_iqn.py | 8 +-- test/discrete/test_pg.py | 4 +- test/discrete/test_ppo_discrete.py | 4 +- test/discrete/test_qrdqn.py | 8 +-- test/discrete/test_rainbow.py | 8 +-- test/modelbased/test_dqn_icm.py | 6 +- test/modelbased/test_ppo_icm.py | 4 +- test/modelbased/test_psrl.py | 2 +- test/offline/gather_cartpole_data.py | 8 +-- test/offline/gather_pendulum_data.py | 4 +- test/offline/test_bcq.py | 4 +- test/offline/test_cql.py | 4 +- test/offline/test_discrete_bcq.py | 6 +- test/offline/test_discrete_cql.py | 6 +- test/offline/test_discrete_crr.py | 6 +- test/offline/test_gail.py | 4 +- test/offline/test_td3_bc.py | 6 +- test/pettingzoo/pistonball.py | 6 +- test/pettingzoo/pistonball_continuous.py | 8 +-- test/pettingzoo/tic_tac_toe.py | 8 +-- tianshou/__init__.py | 2 + tianshou/algorithm/algorithm_base.py | 3 +- tianshou/algorithm/imitation/bcq.py | 6 +- tianshou/algorithm/imitation/cql.py | 6 +- tianshou/algorithm/imitation/discrete_bcq.py | 14 ++--- tianshou/algorithm/imitation/discrete_cql.py | 4 +- tianshou/algorithm/imitation/discrete_crr.py | 4 +- tianshou/algorithm/imitation/gail.py | 8 +-- .../algorithm/imitation/imitation_base.py | 14 ++--- tianshou/algorithm/imitation/td3_bc.py | 4 +- 
tianshou/algorithm/modelbased/icm.py | 6 +- tianshou/algorithm/modelbased/psrl.py | 6 +- tianshou/algorithm/modelfree/a2c.py | 4 +- tianshou/algorithm/modelfree/bdqn.py | 14 ++--- tianshou/algorithm/modelfree/c51.py | 4 +- tianshou/algorithm/modelfree/ddpg.py | 20 +++--- tianshou/algorithm/modelfree/discrete_sac.py | 8 +-- tianshou/algorithm/modelfree/dqn.py | 18 +++--- tianshou/algorithm/modelfree/fqf.py | 4 +- tianshou/algorithm/modelfree/iqn.py | 8 +-- tianshou/algorithm/modelfree/npg.py | 4 +- tianshou/algorithm/modelfree/pg.py | 14 ++--- tianshou/algorithm/modelfree/ppo.py | 4 +- tianshou/algorithm/modelfree/qrdqn.py | 4 +- tianshou/algorithm/modelfree/rainbow.py | 2 +- tianshou/algorithm/modelfree/redq.py | 14 ++--- tianshou/algorithm/modelfree/sac.py | 8 +-- tianshou/algorithm/modelfree/td3.py | 10 +-- tianshou/algorithm/modelfree/trpo.py | 4 +- tianshou/algorithm/multiagent/marl.py | 6 +- tianshou/algorithm/random.py | 3 +- tianshou/data/collector.py | 4 +- tianshou/data/stats.py | 2 +- tianshou/env/atari/atari_network.py | 2 +- tianshou/env/worker/__init__.py | 2 + tianshou/highlevel/algorithm.py | 61 ++++++++++--------- tianshou/highlevel/experiment.py | 2 +- tianshou/highlevel/module/actor.py | 2 +- tianshou/highlevel/params/alpha.py | 2 +- tianshou/highlevel/params/dist_fn.py | 2 +- tianshou/highlevel/params/lr_scheduler.py | 2 +- tianshou/highlevel/params/policy_wrapper.py | 6 +- tianshou/highlevel/trainer.py | 4 +- tianshou/highlevel/world.py | 2 +- tianshou/trainer/trainer.py | 14 ++--- 135 files changed, 404 insertions(+), 387 deletions(-) diff --git a/docs/02_notebooks/L0_overview.ipynb b/docs/02_notebooks/L0_overview.ipynb index b38ead76b..3572ddf05 100644 --- a/docs/02_notebooks/L0_overview.ipynb +++ b/docs/02_notebooks/L0_overview.ipynb @@ -58,9 +58,9 @@ "import gymnasium as gym\n", "import torch\n", "\n", + "from tianshou.algorithm import PPOPolicy\n", "from tianshou.data import Collector, CollectStats, VectorReplayBuffer\n", "from tianshou.env 
import DummyVectorEnv\n", - "from tianshou.algorithm import PPOPolicy\n", "from tianshou.trainer import OnpolicyTrainer\n", "from tianshou.utils.net.common import ActorCritic, Net\n", "from tianshou.utils.net.discrete import Actor, Critic\n", diff --git a/docs/02_notebooks/L4_Policy.ipynb b/docs/02_notebooks/L4_Policy.ipynb index 2707e2869..fec0c5cfe 100644 --- a/docs/02_notebooks/L4_Policy.ipynb +++ b/docs/02_notebooks/L4_Policy.ipynb @@ -52,6 +52,12 @@ "import numpy as np\n", "import torch\n", "\n", + "from tianshou.algorithm import BasePolicy\n", + "from tianshou.algorithm.modelfree.pg import (\n", + " PGTrainingStats,\n", + " TDistFnDiscrOrCont,\n", + " TPGTrainingStats,\n", + ")\n", "from tianshou.data import (\n", " Batch,\n", " ReplayBuffer,\n", @@ -66,12 +72,6 @@ " ObsBatchProtocol,\n", " RolloutBatchProtocol,\n", ")\n", - "from tianshou.algorithm import BasePolicy\n", - "from tianshou.algorithm.modelfree.pg import (\n", - " PGTrainingStats,\n", - " TDistFnDiscrOrCont,\n", - " TPGTrainingStats,\n", - ")\n", "from tianshou.utils import RunningMeanStd\n", "from tianshou.utils.net.common import Net\n", "from tianshou.utils.net.discrete import Actor\n", diff --git a/docs/02_notebooks/L5_Collector.ipynb b/docs/02_notebooks/L5_Collector.ipynb index 04f259722..d7aaa9fb3 100644 --- a/docs/02_notebooks/L5_Collector.ipynb +++ b/docs/02_notebooks/L5_Collector.ipynb @@ -58,9 +58,9 @@ "import gymnasium as gym\n", "import torch\n", "\n", + "from tianshou.algorithm import PGPolicy\n", "from tianshou.data import Collector, CollectStats, VectorReplayBuffer\n", "from tianshou.env import DummyVectorEnv\n", - "from tianshou.algorithm import PGPolicy\n", "from tianshou.utils.net.common import Net\n", "from tianshou.utils.net.discrete import Actor" ] diff --git a/docs/02_notebooks/L6_Trainer.ipynb b/docs/02_notebooks/L6_Trainer.ipynb index 58fa3d40e..ffa18168b 100644 --- a/docs/02_notebooks/L6_Trainer.ipynb +++ b/docs/02_notebooks/L6_Trainer.ipynb @@ -73,9 +73,9 @@ "import 
gymnasium as gym\n", "import torch\n", "\n", + "from tianshou.algorithm import PGPolicy\n", "from tianshou.data import Collector, CollectStats, VectorReplayBuffer\n", "from tianshou.env import DummyVectorEnv\n", - "from tianshou.algorithm import PGPolicy\n", "from tianshou.trainer import OnpolicyTrainer\n", "from tianshou.utils.net.common import Net\n", "from tianshou.utils.net.discrete import Actor\n", diff --git a/docs/02_notebooks/L7_Experiment.ipynb b/docs/02_notebooks/L7_Experiment.ipynb index c30fa08f7..8cfc25c1d 100644 --- a/docs/02_notebooks/L7_Experiment.ipynb +++ b/docs/02_notebooks/L7_Experiment.ipynb @@ -71,9 +71,9 @@ "import gymnasium as gym\n", "import torch\n", "\n", + "from tianshou.algorithm import PPOPolicy\n", "from tianshou.data import Collector, CollectStats, VectorReplayBuffer\n", "from tianshou.env import DummyVectorEnv\n", - "from tianshou.algorithm import PPOPolicy\n", "from tianshou.trainer import OnpolicyTrainer\n", "from tianshou.utils.net.common import ActorCritic, Net\n", "from tianshou.utils.net.discrete import Actor, Critic\n", diff --git a/examples/atari/atari_c51.py b/examples/atari/atari_c51.py index ade97544c..9b653e909 100644 --- a/examples/atari/atari_c51.py +++ b/examples/atari/atari_c51.py @@ -7,14 +7,14 @@ import numpy as np import torch -from tianshou.data import Collector, CollectStats, VectorReplayBuffer -from tianshou.env.atari.atari_network import C51Net -from tianshou.env.atari.atari_wrapper import make_atari_env -from tianshou.highlevel.logger import LoggerFactoryDefault from tianshou.algorithm import C51 from tianshou.algorithm.algorithm_base import Algorithm from tianshou.algorithm.modelfree.c51 import C51Policy from tianshou.algorithm.optim import AdamOptimizerFactory +from tianshou.data import Collector, CollectStats, VectorReplayBuffer +from tianshou.env.atari.atari_network import C51Net +from tianshou.env.atari.atari_wrapper import make_atari_env +from tianshou.highlevel.logger import LoggerFactoryDefault from 
tianshou.trainer import OffPolicyTrainerParams diff --git a/examples/atari/atari_dqn.py b/examples/atari/atari_dqn.py index 8c57d076f..d798f469c 100644 --- a/examples/atari/atari_dqn.py +++ b/examples/atari/atari_dqn.py @@ -7,15 +7,15 @@ import numpy as np import torch -from tianshou.data import Collector, CollectStats, VectorReplayBuffer -from tianshou.env.atari.atari_network import DQNet -from tianshou.env.atari.atari_wrapper import make_atari_env -from tianshou.highlevel.logger import LoggerFactoryDefault from tianshou.algorithm import DQN from tianshou.algorithm.algorithm_base import Algorithm from tianshou.algorithm.modelbased.icm import ICMOffPolicyWrapper from tianshou.algorithm.modelfree.dqn import DiscreteQLearningPolicy from tianshou.algorithm.optim import AdamOptimizerFactory +from tianshou.data import Collector, CollectStats, VectorReplayBuffer +from tianshou.env.atari.atari_network import DQNet +from tianshou.env.atari.atari_wrapper import make_atari_env +from tianshou.highlevel.logger import LoggerFactoryDefault from tianshou.trainer import OffPolicyTrainerParams from tianshou.utils.net.discrete import IntrinsicCuriosityModule diff --git a/examples/atari/atari_fqf.py b/examples/atari/atari_fqf.py index 6cc0aa0ca..6aabc1171 100644 --- a/examples/atari/atari_fqf.py +++ b/examples/atari/atari_fqf.py @@ -7,14 +7,14 @@ import numpy as np import torch -from tianshou.data import Collector, CollectStats, VectorReplayBuffer -from tianshou.env.atari.atari_network import DQNet -from tianshou.env.atari.atari_wrapper import make_atari_env -from tianshou.highlevel.logger import LoggerFactoryDefault from tianshou.algorithm import FQF from tianshou.algorithm.algorithm_base import Algorithm from tianshou.algorithm.modelfree.fqf import FQFPolicy from tianshou.algorithm.optim import AdamOptimizerFactory, RMSpropOptimizerFactory +from tianshou.data import Collector, CollectStats, VectorReplayBuffer +from tianshou.env.atari.atari_network import DQNet +from 
tianshou.env.atari.atari_wrapper import make_atari_env +from tianshou.highlevel.logger import LoggerFactoryDefault from tianshou.trainer import OffPolicyTrainerParams from tianshou.utils.net.discrete import FractionProposalNetwork, FullQuantileFunction diff --git a/examples/atari/atari_iqn.py b/examples/atari/atari_iqn.py index 056da4288..47e8af2e5 100644 --- a/examples/atari/atari_iqn.py +++ b/examples/atari/atari_iqn.py @@ -7,14 +7,14 @@ import numpy as np import torch -from tianshou.data import Collector, CollectStats, VectorReplayBuffer -from tianshou.env.atari.atari_network import DQNet -from tianshou.env.atari.atari_wrapper import make_atari_env -from tianshou.highlevel.logger import LoggerFactoryDefault from tianshou.algorithm import IQN from tianshou.algorithm.algorithm_base import Algorithm from tianshou.algorithm.modelfree.iqn import IQNPolicy from tianshou.algorithm.optim import AdamOptimizerFactory +from tianshou.data import Collector, CollectStats, VectorReplayBuffer +from tianshou.env.atari.atari_network import DQNet +from tianshou.env.atari.atari_wrapper import make_atari_env +from tianshou.highlevel.logger import LoggerFactoryDefault from tianshou.trainer import OffPolicyTrainerParams from tianshou.utils.net.discrete import ImplicitQuantileNetwork diff --git a/examples/atari/atari_ppo.py b/examples/atari/atari_ppo.py index bb76758e4..22399fb27 100644 --- a/examples/atari/atari_ppo.py +++ b/examples/atari/atari_ppo.py @@ -9,6 +9,11 @@ import numpy as np import torch +from tianshou.algorithm import PPO +from tianshou.algorithm.algorithm_base import Algorithm +from tianshou.algorithm.modelbased.icm import ICMOnPolicyWrapper +from tianshou.algorithm.modelfree.pg import DiscreteActorPolicy +from tianshou.algorithm.optim import AdamOptimizerFactory, LRSchedulerFactoryLinear from tianshou.data import Collector, CollectStats, VectorReplayBuffer from tianshou.env.atari.atari_network import ( DQNet, @@ -18,11 +23,6 @@ ) from tianshou.env.atari.atari_wrapper 
import make_atari_env from tianshou.highlevel.logger import LoggerFactoryDefault -from tianshou.algorithm import PPO -from tianshou.algorithm.algorithm_base import Algorithm -from tianshou.algorithm.modelbased.icm import ICMOnPolicyWrapper -from tianshou.algorithm.modelfree.pg import DiscreteActorPolicy -from tianshou.algorithm.optim import AdamOptimizerFactory, LRSchedulerFactoryLinear from tianshou.trainer import OnPolicyTrainerParams from tianshou.utils.net.discrete import ( DiscreteActor, diff --git a/examples/atari/atari_ppo_hl.py b/examples/atari/atari_ppo_hl.py index e99ebc224..7f4186dbd 100644 --- a/examples/atari/atari_ppo_hl.py +++ b/examples/atari/atari_ppo_hl.py @@ -16,8 +16,8 @@ ExperimentConfig, PPOExperimentBuilder, ) -from tianshou.highlevel.params.lr_scheduler import LRSchedulerFactoryFactoryLinear from tianshou.highlevel.params.algorithm_params import PPOParams +from tianshou.highlevel.params.lr_scheduler import LRSchedulerFactoryFactoryLinear from tianshou.highlevel.params.policy_wrapper import ( AlgorithmWrapperFactoryIntrinsicCuriosity, ) diff --git a/examples/atari/atari_qrdqn.py b/examples/atari/atari_qrdqn.py index 0d0a4095d..28679b9e8 100644 --- a/examples/atari/atari_qrdqn.py +++ b/examples/atari/atari_qrdqn.py @@ -7,14 +7,14 @@ import numpy as np import torch -from tianshou.data import Collector, CollectStats, VectorReplayBuffer -from tianshou.env.atari.atari_network import QRDQNet -from tianshou.env.atari.atari_wrapper import make_atari_env -from tianshou.highlevel.logger import LoggerFactoryDefault from tianshou.algorithm import QRDQN from tianshou.algorithm.algorithm_base import Algorithm from tianshou.algorithm.modelfree.qrdqn import QRDQNPolicy from tianshou.algorithm.optim import AdamOptimizerFactory +from tianshou.data import Collector, CollectStats, VectorReplayBuffer +from tianshou.env.atari.atari_network import QRDQNet +from tianshou.env.atari.atari_wrapper import make_atari_env +from tianshou.highlevel.logger import 
LoggerFactoryDefault from tianshou.trainer import OffPolicyTrainerParams diff --git a/examples/atari/atari_rainbow.py b/examples/atari/atari_rainbow.py index ac1a1de04..f62123d1f 100644 --- a/examples/atari/atari_rainbow.py +++ b/examples/atari/atari_rainbow.py @@ -7,6 +7,10 @@ import numpy as np import torch +from tianshou.algorithm import C51, RainbowDQN +from tianshou.algorithm.algorithm_base import Algorithm +from tianshou.algorithm.modelfree.c51 import C51Policy +from tianshou.algorithm.optim import AdamOptimizerFactory from tianshou.data import ( Collector, CollectStats, @@ -16,10 +20,6 @@ from tianshou.env.atari.atari_network import Rainbow from tianshou.env.atari.atari_wrapper import make_atari_env from tianshou.highlevel.logger import LoggerFactoryDefault -from tianshou.algorithm import C51, RainbowDQN -from tianshou.algorithm.algorithm_base import Algorithm -from tianshou.algorithm.modelfree.c51 import C51Policy -from tianshou.algorithm.optim import AdamOptimizerFactory from tianshou.trainer import OffPolicyTrainerParams diff --git a/examples/atari/atari_sac.py b/examples/atari/atari_sac.py index c88dd27ee..79b51ea65 100644 --- a/examples/atari/atari_sac.py +++ b/examples/atari/atari_sac.py @@ -7,15 +7,15 @@ import numpy as np import torch -from tianshou.data import Collector, CollectStats, VectorReplayBuffer -from tianshou.env.atari.atari_network import DQNet -from tianshou.env.atari.atari_wrapper import make_atari_env -from tianshou.highlevel.logger import LoggerFactoryDefault from tianshou.algorithm import DiscreteSAC, ICMOffPolicyWrapper from tianshou.algorithm.algorithm_base import Algorithm from tianshou.algorithm.modelfree.discrete_sac import DiscreteSACPolicy from tianshou.algorithm.modelfree.sac import AutoAlpha from tianshou.algorithm.optim import AdamOptimizerFactory +from tianshou.data import Collector, CollectStats, VectorReplayBuffer +from tianshou.env.atari.atari_network import DQNet +from tianshou.env.atari.atari_wrapper import 
make_atari_env +from tianshou.highlevel.logger import LoggerFactoryDefault from tianshou.trainer import OffPolicyTrainerParams from tianshou.utils.net.discrete import ( DiscreteActor, diff --git a/examples/atari/atari_sac_hl.py b/examples/atari/atari_sac_hl.py index 8e8d743d1..75da41456 100644 --- a/examples/atari/atari_sac_hl.py +++ b/examples/atari/atari_sac_hl.py @@ -16,8 +16,8 @@ DiscreteSACExperimentBuilder, ExperimentConfig, ) -from tianshou.highlevel.params.alpha import AutoAlphaFactoryDefault from tianshou.highlevel.params.algorithm_params import DiscreteSACParams +from tianshou.highlevel.params.alpha import AutoAlphaFactoryDefault from tianshou.highlevel.params.policy_wrapper import ( AlgorithmWrapperFactoryIntrinsicCuriosity, ) diff --git a/examples/box2d/acrobot_dualdqn.py b/examples/box2d/acrobot_dualdqn.py index de5c5c739..d9c0464c4 100644 --- a/examples/box2d/acrobot_dualdqn.py +++ b/examples/box2d/acrobot_dualdqn.py @@ -7,12 +7,12 @@ import torch from torch.utils.tensorboard import SummaryWriter -from tianshou.data import Collector, CollectStats, VectorReplayBuffer -from tianshou.env import DummyVectorEnv from tianshou.algorithm import DQN from tianshou.algorithm.algorithm_base import Algorithm from tianshou.algorithm.modelfree.dqn import DiscreteQLearningPolicy from tianshou.algorithm.optim import AdamOptimizerFactory +from tianshou.data import Collector, CollectStats, VectorReplayBuffer +from tianshou.env import DummyVectorEnv from tianshou.trainer import OffPolicyTrainerParams from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Net diff --git a/examples/box2d/bipedal_bdq.py b/examples/box2d/bipedal_bdq.py index 31f7af4d4..19d9ba467 100644 --- a/examples/box2d/bipedal_bdq.py +++ b/examples/box2d/bipedal_bdq.py @@ -8,12 +8,12 @@ import torch from torch.utils.tensorboard import SummaryWriter -from tianshou.data import Collector, CollectStats, VectorReplayBuffer -from tianshou.env import ContinuousToDiscrete, 
SubprocVectorEnv from tianshou.algorithm import BDQN from tianshou.algorithm.algorithm_base import Algorithm from tianshou.algorithm.modelfree.bdqn import BDQNPolicy from tianshou.algorithm.optim import AdamOptimizerFactory +from tianshou.data import Collector, CollectStats, VectorReplayBuffer +from tianshou.env import ContinuousToDiscrete, SubprocVectorEnv from tianshou.trainer import OffPolicyTrainerParams from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import BranchingNet diff --git a/examples/box2d/bipedal_hardcore_sac.py b/examples/box2d/bipedal_hardcore_sac.py index 4729490ff..1fa5be9d1 100644 --- a/examples/box2d/bipedal_hardcore_sac.py +++ b/examples/box2d/bipedal_hardcore_sac.py @@ -9,12 +9,12 @@ from gymnasium.core import WrapperActType, WrapperObsType from torch.utils.tensorboard import SummaryWriter -from tianshou.data import Collector, CollectStats, VectorReplayBuffer -from tianshou.env import SubprocVectorEnv from tianshou.algorithm import SAC from tianshou.algorithm.algorithm_base import Algorithm from tianshou.algorithm.modelfree.sac import AutoAlpha, SACPolicy from tianshou.algorithm.optim import AdamOptimizerFactory +from tianshou.data import Collector, CollectStats, VectorReplayBuffer +from tianshou.env import SubprocVectorEnv from tianshou.trainer import OffPolicyTrainerParams from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Net diff --git a/examples/box2d/lunarlander_dqn.py b/examples/box2d/lunarlander_dqn.py index e2e4ca628..010a48400 100644 --- a/examples/box2d/lunarlander_dqn.py +++ b/examples/box2d/lunarlander_dqn.py @@ -7,12 +7,12 @@ import torch from torch.utils.tensorboard import SummaryWriter -from tianshou.data import Collector, CollectStats, VectorReplayBuffer -from tianshou.env import DummyVectorEnv, SubprocVectorEnv from tianshou.algorithm import DQN from tianshou.algorithm.algorithm_base import Algorithm from tianshou.algorithm.modelfree.dqn import 
DiscreteQLearningPolicy from tianshou.algorithm.optim import AdamOptimizerFactory +from tianshou.data import Collector, CollectStats, VectorReplayBuffer +from tianshou.env import DummyVectorEnv, SubprocVectorEnv from tianshou.trainer import OffPolicyTrainerParams from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Net diff --git a/examples/box2d/mcc_sac.py b/examples/box2d/mcc_sac.py index 6aa140f9e..44c84dc25 100644 --- a/examples/box2d/mcc_sac.py +++ b/examples/box2d/mcc_sac.py @@ -7,13 +7,13 @@ import torch from torch.utils.tensorboard import SummaryWriter -from tianshou.data import Collector, CollectStats, VectorReplayBuffer -from tianshou.env import DummyVectorEnv -from tianshou.exploration import OUNoise from tianshou.algorithm import SAC from tianshou.algorithm.algorithm_base import Algorithm from tianshou.algorithm.modelfree.sac import AutoAlpha, SACPolicy from tianshou.algorithm.optim import AdamOptimizerFactory +from tianshou.data import Collector, CollectStats, VectorReplayBuffer +from tianshou.env import DummyVectorEnv +from tianshou.exploration import OUNoise from tianshou.trainer import OffPolicyTrainerParams from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Net diff --git a/examples/discrete/discrete_dqn.py b/examples/discrete/discrete_dqn.py index 39ffeb4ea..5194124a4 100644 --- a/examples/discrete/discrete_dqn.py +++ b/examples/discrete/discrete_dqn.py @@ -2,9 +2,9 @@ from torch.utils.tensorboard import SummaryWriter import tianshou as ts -from tianshou.data import CollectStats from tianshou.algorithm.modelfree.dqn import DiscreteQLearningPolicy from tianshou.algorithm.optim import AdamOptimizerFactory +from tianshou.data import CollectStats from tianshou.trainer import OffPolicyTrainerParams from tianshou.utils.space_info import SpaceInfo diff --git a/examples/inverse/irl_gail.py b/examples/inverse/irl_gail.py index f1d2f6fb6..8a691184c 100644 --- a/examples/inverse/irl_gail.py +++ 
b/examples/inverse/irl_gail.py @@ -14,6 +14,10 @@ from torch.distributions import Distribution, Independent, Normal from torch.utils.tensorboard import SummaryWriter +from tianshou.algorithm import GAIL +from tianshou.algorithm.algorithm_base import Algorithm +from tianshou.algorithm.modelfree.pg import ActorPolicyProbabilistic +from tianshou.algorithm.optim import AdamOptimizerFactory, LRSchedulerFactoryLinear from tianshou.data import ( Batch, Collector, @@ -23,10 +27,6 @@ ) from tianshou.data.types import RolloutBatchProtocol from tianshou.env import SubprocVectorEnv, VectorEnvNormObs -from tianshou.algorithm import GAIL -from tianshou.algorithm.algorithm_base import Algorithm -from tianshou.algorithm.modelfree.pg import ActorPolicyProbabilistic -from tianshou.algorithm.optim import AdamOptimizerFactory, LRSchedulerFactoryLinear from tianshou.trainer import OnPolicyTrainerParams from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Net diff --git a/examples/mujoco/mujoco_a2c.py b/examples/mujoco/mujoco_a2c.py index 717a194a6..f6b4eae34 100755 --- a/examples/mujoco/mujoco_a2c.py +++ b/examples/mujoco/mujoco_a2c.py @@ -11,12 +11,12 @@ from torch import nn from torch.distributions import Distribution, Independent, Normal -from tianshou.data import Collector, CollectStats, ReplayBuffer, VectorReplayBuffer -from tianshou.highlevel.logger import LoggerFactoryDefault from tianshou.algorithm import A2C from tianshou.algorithm.algorithm_base import Algorithm from tianshou.algorithm.modelfree.pg import ActorPolicyProbabilistic from tianshou.algorithm.optim import LRSchedulerFactoryLinear, RMSpropOptimizerFactory +from tianshou.data import Collector, CollectStats, ReplayBuffer, VectorReplayBuffer +from tianshou.highlevel.logger import LoggerFactoryDefault from tianshou.trainer import OnPolicyTrainerParams from tianshou.utils.net.common import ActorCritic, Net from tianshou.utils.net.continuous import ContinuousActorProbabilistic, 
ContinuousCritic diff --git a/examples/mujoco/mujoco_a2c_hl.py b/examples/mujoco/mujoco_a2c_hl.py index 2f21cd56a..224d55fcb 100644 --- a/examples/mujoco/mujoco_a2c_hl.py +++ b/examples/mujoco/mujoco_a2c_hl.py @@ -15,8 +15,8 @@ ExperimentConfig, ) from tianshou.highlevel.optim import OptimizerFactoryFactoryRMSprop -from tianshou.highlevel.params.lr_scheduler import LRSchedulerFactoryFactoryLinear from tianshou.highlevel.params.algorithm_params import A2CParams +from tianshou.highlevel.params.lr_scheduler import LRSchedulerFactoryFactoryLinear def main( diff --git a/examples/mujoco/mujoco_ddpg.py b/examples/mujoco/mujoco_ddpg.py index b9849dca0..3cadc2034 100755 --- a/examples/mujoco/mujoco_ddpg.py +++ b/examples/mujoco/mujoco_ddpg.py @@ -9,13 +9,13 @@ import torch from mujoco_env import make_mujoco_env -from tianshou.data import Collector, CollectStats, ReplayBuffer, VectorReplayBuffer -from tianshou.exploration import GaussianNoise -from tianshou.highlevel.logger import LoggerFactoryDefault from tianshou.algorithm import DDPG from tianshou.algorithm.algorithm_base import Algorithm from tianshou.algorithm.modelfree.ddpg import ContinuousDeterministicPolicy from tianshou.algorithm.optim import AdamOptimizerFactory +from tianshou.data import Collector, CollectStats, ReplayBuffer, VectorReplayBuffer +from tianshou.exploration import GaussianNoise +from tianshou.highlevel.logger import LoggerFactoryDefault from tianshou.trainer import OffPolicyTrainerParams from tianshou.utils.net.common import Net from tianshou.utils.net.continuous import ContinuousActorDeterministic, ContinuousCritic diff --git a/examples/mujoco/mujoco_ddpg_hl.py b/examples/mujoco/mujoco_ddpg_hl.py index deb44b485..4ed64177f 100644 --- a/examples/mujoco/mujoco_ddpg_hl.py +++ b/examples/mujoco/mujoco_ddpg_hl.py @@ -12,8 +12,8 @@ DDPGExperimentBuilder, ExperimentConfig, ) -from tianshou.highlevel.params.noise import MaxActionScaledGaussian from tianshou.highlevel.params.algorithm_params import 
DDPGParams +from tianshou.highlevel.params.noise import MaxActionScaledGaussian def main( diff --git a/examples/mujoco/mujoco_npg.py b/examples/mujoco/mujoco_npg.py index 10116ac8a..b5004997e 100755 --- a/examples/mujoco/mujoco_npg.py +++ b/examples/mujoco/mujoco_npg.py @@ -11,12 +11,12 @@ from torch import nn from torch.distributions import Distribution, Independent, Normal -from tianshou.data import Collector, CollectStats, ReplayBuffer, VectorReplayBuffer -from tianshou.highlevel.logger import LoggerFactoryDefault from tianshou.algorithm import NPG from tianshou.algorithm.algorithm_base import Algorithm from tianshou.algorithm.modelfree.pg import ActorPolicyProbabilistic from tianshou.algorithm.optim import AdamOptimizerFactory, LRSchedulerFactoryLinear +from tianshou.data import Collector, CollectStats, ReplayBuffer, VectorReplayBuffer +from tianshou.highlevel.logger import LoggerFactoryDefault from tianshou.trainer import OnPolicyTrainerParams from tianshou.utils.net.common import Net from tianshou.utils.net.continuous import ContinuousActorProbabilistic, ContinuousCritic diff --git a/examples/mujoco/mujoco_npg_hl.py b/examples/mujoco/mujoco_npg_hl.py index bb29a92d9..2a6e372b4 100644 --- a/examples/mujoco/mujoco_npg_hl.py +++ b/examples/mujoco/mujoco_npg_hl.py @@ -14,8 +14,8 @@ ExperimentConfig, NPGExperimentBuilder, ) -from tianshou.highlevel.params.lr_scheduler import LRSchedulerFactoryFactoryLinear from tianshou.highlevel.params.algorithm_params import NPGParams +from tianshou.highlevel.params.lr_scheduler import LRSchedulerFactoryFactoryLinear def main( diff --git a/examples/mujoco/mujoco_ppo.py b/examples/mujoco/mujoco_ppo.py index d504c8b41..a12c00ca5 100755 --- a/examples/mujoco/mujoco_ppo.py +++ b/examples/mujoco/mujoco_ppo.py @@ -11,12 +11,12 @@ from torch import nn from torch.distributions import Distribution, Independent, Normal -from tianshou.data import Collector, CollectStats, ReplayBuffer, VectorReplayBuffer -from tianshou.highlevel.logger 
import LoggerFactoryDefault from tianshou.algorithm import PPO from tianshou.algorithm.algorithm_base import Algorithm from tianshou.algorithm.modelfree.pg import ActorPolicyProbabilistic from tianshou.algorithm.optim import AdamOptimizerFactory, LRSchedulerFactoryLinear +from tianshou.data import Collector, CollectStats, ReplayBuffer, VectorReplayBuffer +from tianshou.highlevel.logger import LoggerFactoryDefault from tianshou.trainer import OnPolicyTrainerParams from tianshou.utils.net.common import ActorCritic, Net from tianshou.utils.net.continuous import ContinuousActorProbabilistic, ContinuousCritic diff --git a/examples/mujoco/mujoco_ppo_hl.py b/examples/mujoco/mujoco_ppo_hl.py index 213bf8218..73a2cb711 100644 --- a/examples/mujoco/mujoco_ppo_hl.py +++ b/examples/mujoco/mujoco_ppo_hl.py @@ -14,8 +14,8 @@ ExperimentConfig, PPOExperimentBuilder, ) -from tianshou.highlevel.params.lr_scheduler import LRSchedulerFactoryFactoryLinear from tianshou.highlevel.params.algorithm_params import PPOParams +from tianshou.highlevel.params.lr_scheduler import LRSchedulerFactoryFactoryLinear def main( diff --git a/examples/mujoco/mujoco_ppo_hl_multi.py b/examples/mujoco/mujoco_ppo_hl_multi.py index 36a6ae239..1f2b0aae1 100644 --- a/examples/mujoco/mujoco_ppo_hl_multi.py +++ b/examples/mujoco/mujoco_ppo_hl_multi.py @@ -28,8 +28,8 @@ PPOExperimentBuilder, ) from tianshou.highlevel.logger import LoggerFactoryDefault -from tianshou.highlevel.params.lr_scheduler import LRSchedulerFactoryFactoryLinear from tianshou.highlevel.params.algorithm_params import PPOParams +from tianshou.highlevel.params.lr_scheduler import LRSchedulerFactoryFactoryLinear log = logging.getLogger(__name__) diff --git a/examples/mujoco/mujoco_redq.py b/examples/mujoco/mujoco_redq.py index d8e313385..58bf51fc4 100755 --- a/examples/mujoco/mujoco_redq.py +++ b/examples/mujoco/mujoco_redq.py @@ -9,13 +9,13 @@ import torch from mujoco_env import make_mujoco_env -from tianshou.data import Collector, CollectStats, 
ReplayBuffer, VectorReplayBuffer -from tianshou.highlevel.logger import LoggerFactoryDefault from tianshou.algorithm import REDQ from tianshou.algorithm.algorithm_base import Algorithm from tianshou.algorithm.modelfree.redq import REDQPolicy from tianshou.algorithm.modelfree.sac import AutoAlpha from tianshou.algorithm.optim import AdamOptimizerFactory +from tianshou.data import Collector, CollectStats, ReplayBuffer, VectorReplayBuffer +from tianshou.highlevel.logger import LoggerFactoryDefault from tianshou.trainer import OffPolicyTrainerParams from tianshou.utils.net.common import EnsembleLinear, Net from tianshou.utils.net.continuous import ContinuousActorProbabilistic, ContinuousCritic diff --git a/examples/mujoco/mujoco_redq_hl.py b/examples/mujoco/mujoco_redq_hl.py index 633f34466..dafcadacc 100644 --- a/examples/mujoco/mujoco_redq_hl.py +++ b/examples/mujoco/mujoco_redq_hl.py @@ -13,8 +13,8 @@ ExperimentConfig, REDQExperimentBuilder, ) -from tianshou.highlevel.params.alpha import AutoAlphaFactoryDefault from tianshou.highlevel.params.algorithm_params import REDQParams +from tianshou.highlevel.params.alpha import AutoAlphaFactoryDefault def main( diff --git a/examples/mujoco/mujoco_reinforce.py b/examples/mujoco/mujoco_reinforce.py index 7c7c92f34..9d9bae48e 100755 --- a/examples/mujoco/mujoco_reinforce.py +++ b/examples/mujoco/mujoco_reinforce.py @@ -11,12 +11,12 @@ from torch import nn from torch.distributions import Distribution, Independent, Normal -from tianshou.data import Collector, CollectStats, ReplayBuffer, VectorReplayBuffer -from tianshou.highlevel.logger import LoggerFactoryDefault from tianshou.algorithm import Reinforce from tianshou.algorithm.algorithm_base import Algorithm from tianshou.algorithm.modelfree.pg import ActorPolicyProbabilistic from tianshou.algorithm.optim import AdamOptimizerFactory, LRSchedulerFactoryLinear +from tianshou.data import Collector, CollectStats, ReplayBuffer, VectorReplayBuffer +from tianshou.highlevel.logger 
import LoggerFactoryDefault from tianshou.trainer import OnPolicyTrainerParams from tianshou.utils.net.common import Net from tianshou.utils.net.continuous import ContinuousActorProbabilistic diff --git a/examples/mujoco/mujoco_reinforce_hl.py b/examples/mujoco/mujoco_reinforce_hl.py index 15a1b66c2..156c05fff 100644 --- a/examples/mujoco/mujoco_reinforce_hl.py +++ b/examples/mujoco/mujoco_reinforce_hl.py @@ -14,8 +14,8 @@ ExperimentConfig, ReinforceExperimentBuilder, ) -from tianshou.highlevel.params.lr_scheduler import LRSchedulerFactoryFactoryLinear from tianshou.highlevel.params.algorithm_params import ReinforceParams +from tianshou.highlevel.params.lr_scheduler import LRSchedulerFactoryFactoryLinear def main( diff --git a/examples/mujoco/mujoco_sac.py b/examples/mujoco/mujoco_sac.py index 67547e0a0..bc45e8358 100755 --- a/examples/mujoco/mujoco_sac.py +++ b/examples/mujoco/mujoco_sac.py @@ -9,12 +9,12 @@ import torch from mujoco_env import make_mujoco_env -from tianshou.data import Collector, CollectStats, ReplayBuffer, VectorReplayBuffer -from tianshou.highlevel.logger import LoggerFactoryDefault from tianshou.algorithm import SAC from tianshou.algorithm.algorithm_base import Algorithm from tianshou.algorithm.modelfree.sac import AutoAlpha, SACPolicy from tianshou.algorithm.optim import AdamOptimizerFactory +from tianshou.data import Collector, CollectStats, ReplayBuffer, VectorReplayBuffer +from tianshou.highlevel.logger import LoggerFactoryDefault from tianshou.trainer import OffPolicyTrainerParams from tianshou.utils.net.common import Net from tianshou.utils.net.continuous import ContinuousActorProbabilistic, ContinuousCritic diff --git a/examples/mujoco/mujoco_sac_hl.py b/examples/mujoco/mujoco_sac_hl.py index cf94953e8..da5f338da 100644 --- a/examples/mujoco/mujoco_sac_hl.py +++ b/examples/mujoco/mujoco_sac_hl.py @@ -12,8 +12,8 @@ ExperimentConfig, SACExperimentBuilder, ) -from tianshou.highlevel.params.alpha import AutoAlphaFactoryDefault from 
tianshou.highlevel.params.algorithm_params import SACParams +from tianshou.highlevel.params.alpha import AutoAlphaFactoryDefault def main( diff --git a/examples/mujoco/mujoco_td3.py b/examples/mujoco/mujoco_td3.py index 8367b0a07..524087006 100755 --- a/examples/mujoco/mujoco_td3.py +++ b/examples/mujoco/mujoco_td3.py @@ -9,13 +9,13 @@ import torch from mujoco_env import make_mujoco_env -from tianshou.data import Collector, CollectStats, ReplayBuffer, VectorReplayBuffer -from tianshou.exploration import GaussianNoise -from tianshou.highlevel.logger import LoggerFactoryDefault from tianshou.algorithm import TD3 from tianshou.algorithm.algorithm_base import Algorithm from tianshou.algorithm.modelfree.ddpg import ContinuousDeterministicPolicy from tianshou.algorithm.optim import AdamOptimizerFactory +from tianshou.data import Collector, CollectStats, ReplayBuffer, VectorReplayBuffer +from tianshou.exploration import GaussianNoise +from tianshou.highlevel.logger import LoggerFactoryDefault from tianshou.trainer import OffPolicyTrainerParams from tianshou.utils.net.common import Net from tianshou.utils.net.continuous import ContinuousActorDeterministic, ContinuousCritic diff --git a/examples/mujoco/mujoco_td3_hl.py b/examples/mujoco/mujoco_td3_hl.py index a04f03d7a..717e1d8d5 100644 --- a/examples/mujoco/mujoco_td3_hl.py +++ b/examples/mujoco/mujoco_td3_hl.py @@ -13,11 +13,11 @@ ExperimentConfig, TD3ExperimentBuilder, ) +from tianshou.highlevel.params.algorithm_params import TD3Params from tianshou.highlevel.params.env_param import MaxActionScaled from tianshou.highlevel.params.noise import ( MaxActionScaledGaussian, ) -from tianshou.highlevel.params.algorithm_params import TD3Params def main( diff --git a/examples/mujoco/mujoco_trpo.py b/examples/mujoco/mujoco_trpo.py index e5e6e40f7..b02af4cd6 100755 --- a/examples/mujoco/mujoco_trpo.py +++ b/examples/mujoco/mujoco_trpo.py @@ -11,12 +11,12 @@ from torch import nn from torch.distributions import Distribution, 
Independent, Normal -from tianshou.data import Collector, CollectStats, ReplayBuffer, VectorReplayBuffer -from tianshou.highlevel.logger import LoggerFactoryDefault from tianshou.algorithm import TRPO from tianshou.algorithm.algorithm_base import Algorithm from tianshou.algorithm.modelfree.pg import ActorPolicyProbabilistic from tianshou.algorithm.optim import AdamOptimizerFactory, LRSchedulerFactoryLinear +from tianshou.data import Collector, CollectStats, ReplayBuffer, VectorReplayBuffer +from tianshou.highlevel.logger import LoggerFactoryDefault from tianshou.trainer import OnPolicyTrainerParams from tianshou.utils.net.common import Net from tianshou.utils.net.continuous import ContinuousActorProbabilistic, ContinuousCritic diff --git a/examples/mujoco/mujoco_trpo_hl.py b/examples/mujoco/mujoco_trpo_hl.py index a70ba692c..0399fea8d 100644 --- a/examples/mujoco/mujoco_trpo_hl.py +++ b/examples/mujoco/mujoco_trpo_hl.py @@ -14,8 +14,8 @@ ExperimentConfig, TRPOExperimentBuilder, ) -from tianshou.highlevel.params.lr_scheduler import LRSchedulerFactoryFactoryLinear from tianshou.highlevel.params.algorithm_params import TRPOParams +from tianshou.highlevel.params.lr_scheduler import LRSchedulerFactoryFactoryLinear def main( diff --git a/examples/offline/atari_bcq.py b/examples/offline/atari_bcq.py index 4b239ea45..415cf2df9 100644 --- a/examples/offline/atari_bcq.py +++ b/examples/offline/atari_bcq.py @@ -12,14 +12,14 @@ from gymnasium.spaces import Discrete from examples.offline.utils import load_buffer -from tianshou.data import Collector, CollectStats, VectorReplayBuffer -from tianshou.env.atari.atari_network import DQNet -from tianshou.env.atari.atari_wrapper import make_atari_env -from tianshou.highlevel.logger import LoggerFactoryDefault from tianshou.algorithm import DiscreteBCQ from tianshou.algorithm.algorithm_base import Algorithm from tianshou.algorithm.imitation.discrete_bcq import DiscreteBCQPolicy from tianshou.algorithm.optim import AdamOptimizerFactory 
+from tianshou.data import Collector, CollectStats, VectorReplayBuffer +from tianshou.env.atari.atari_network import DQNet +from tianshou.env.atari.atari_wrapper import make_atari_env +from tianshou.highlevel.logger import LoggerFactoryDefault from tianshou.trainer import OfflineTrainerParams from tianshou.utils.net.discrete import DiscreteActor diff --git a/examples/offline/atari_cql.py b/examples/offline/atari_cql.py index 591a44b80..edad3a9cc 100644 --- a/examples/offline/atari_cql.py +++ b/examples/offline/atari_cql.py @@ -13,14 +13,14 @@ from gymnasium.spaces import Discrete from examples.offline.utils import load_buffer -from tianshou.data import Collector, CollectStats, VectorReplayBuffer -from tianshou.env.atari.atari_network import QRDQNet -from tianshou.env.atari.atari_wrapper import make_atari_env -from tianshou.highlevel.logger import LoggerFactoryDefault from tianshou.algorithm import DiscreteCQL from tianshou.algorithm.algorithm_base import Algorithm from tianshou.algorithm.modelfree.qrdqn import QRDQNPolicy from tianshou.algorithm.optim import AdamOptimizerFactory +from tianshou.data import Collector, CollectStats, VectorReplayBuffer +from tianshou.env.atari.atari_network import QRDQNet +from tianshou.env.atari.atari_wrapper import make_atari_env +from tianshou.highlevel.logger import LoggerFactoryDefault from tianshou.trainer import OfflineTrainerParams from tianshou.utils.space_info import SpaceInfo diff --git a/examples/offline/atari_crr.py b/examples/offline/atari_crr.py index e94492283..ea5c4fc9c 100644 --- a/examples/offline/atari_crr.py +++ b/examples/offline/atari_crr.py @@ -12,14 +12,14 @@ from gymnasium.spaces import Discrete from examples.offline.utils import load_buffer -from tianshou.data import Collector, CollectStats, VectorReplayBuffer -from tianshou.env.atari.atari_network import DQNet -from tianshou.env.atari.atari_wrapper import make_atari_env -from tianshou.highlevel.logger import LoggerFactoryDefault from tianshou.algorithm 
import DiscreteCRR from tianshou.algorithm.algorithm_base import Algorithm from tianshou.algorithm.modelfree.pg import DiscreteActorPolicy from tianshou.algorithm.optim import AdamOptimizerFactory +from tianshou.data import Collector, CollectStats, VectorReplayBuffer +from tianshou.env.atari.atari_network import DQNet +from tianshou.env.atari.atari_wrapper import make_atari_env +from tianshou.highlevel.logger import LoggerFactoryDefault from tianshou.trainer import OfflineTrainerParams from tianshou.utils.net.discrete import DiscreteActor, DiscreteCritic from tianshou.utils.space_info import SpaceInfo diff --git a/examples/offline/atari_il.py b/examples/offline/atari_il.py index 2a403f363..10f9d4159 100644 --- a/examples/offline/atari_il.py +++ b/examples/offline/atari_il.py @@ -11,13 +11,16 @@ import torch from examples.offline.utils import load_buffer +from tianshou.algorithm.algorithm_base import Algorithm +from tianshou.algorithm.imitation.imitation_base import ( + ImitationPolicy, + OfflineImitationLearning, +) +from tianshou.algorithm.optim import AdamOptimizerFactory from tianshou.data import Collector, CollectStats, VectorReplayBuffer from tianshou.env.atari.atari_network import DQNet from tianshou.env.atari.atari_wrapper import make_atari_env from tianshou.highlevel.logger import LoggerFactoryDefault -from tianshou.algorithm.algorithm_base import Algorithm -from tianshou.algorithm.imitation.imitation_base import ImitationPolicy, OfflineImitationLearning -from tianshou.algorithm.optim import AdamOptimizerFactory from tianshou.trainer import OfflineTrainerParams from tianshou.utils.space_info import SpaceInfo diff --git a/examples/offline/d4rl_bcq.py b/examples/offline/d4rl_bcq.py index 91e51690e..997c02098 100644 --- a/examples/offline/d4rl_bcq.py +++ b/examples/offline/d4rl_bcq.py @@ -11,12 +11,12 @@ from torch.utils.tensorboard import SummaryWriter from examples.offline.utils import load_buffer_d4rl -from tianshou.data import Collector, CollectStats -from 
tianshou.env import SubprocVectorEnv from tianshou.algorithm import BCQ from tianshou.algorithm.algorithm_base import Algorithm from tianshou.algorithm.imitation.bcq import BCQPolicy from tianshou.algorithm.optim import AdamOptimizerFactory +from tianshou.data import Collector, CollectStats +from tianshou.env import SubprocVectorEnv from tianshou.trainer import OfflineTrainerParams from tianshou.utils import TensorboardLogger, WandbLogger from tianshou.utils.net.common import MLP, Net diff --git a/examples/offline/d4rl_cql.py b/examples/offline/d4rl_cql.py index 04ba3ce0f..c27b1ffed 100644 --- a/examples/offline/d4rl_cql.py +++ b/examples/offline/d4rl_cql.py @@ -11,12 +11,12 @@ from torch.utils.tensorboard import SummaryWriter from examples.offline.utils import load_buffer_d4rl -from tianshou.data import Collector, CollectStats -from tianshou.env import SubprocVectorEnv from tianshou.algorithm import CQL from tianshou.algorithm.algorithm_base import Algorithm from tianshou.algorithm.modelfree.sac import AutoAlpha, SACPolicy from tianshou.algorithm.optim import AdamOptimizerFactory +from tianshou.data import Collector, CollectStats +from tianshou.env import SubprocVectorEnv from tianshou.trainer import OfflineTrainerParams from tianshou.utils import TensorboardLogger, WandbLogger from tianshou.utils.net.common import Net diff --git a/examples/offline/d4rl_il.py b/examples/offline/d4rl_il.py index 3e6c09150..0bb5cb3a4 100644 --- a/examples/offline/d4rl_il.py +++ b/examples/offline/d4rl_il.py @@ -11,11 +11,14 @@ from torch.utils.tensorboard import SummaryWriter from examples.offline.utils import load_buffer_d4rl -from tianshou.data import Collector, CollectStats -from tianshou.env import SubprocVectorEnv from tianshou.algorithm.algorithm_base import Algorithm -from tianshou.algorithm.imitation.imitation_base import ImitationPolicy, OfflineImitationLearning +from tianshou.algorithm.imitation.imitation_base import ( + ImitationPolicy, + OfflineImitationLearning, +) from 
tianshou.algorithm.optim import AdamOptimizerFactory +from tianshou.data import Collector, CollectStats +from tianshou.env import SubprocVectorEnv from tianshou.trainer import OfflineTrainerParams from tianshou.utils import TensorboardLogger, WandbLogger from tianshou.utils.net.common import Net diff --git a/examples/offline/d4rl_td3_bc.py b/examples/offline/d4rl_td3_bc.py index b75f3eedb..e7d6609ae 100644 --- a/examples/offline/d4rl_td3_bc.py +++ b/examples/offline/d4rl_td3_bc.py @@ -11,13 +11,13 @@ from torch.utils.tensorboard import SummaryWriter from examples.offline.utils import load_buffer_d4rl, normalize_all_obs_in_replay_buffer -from tianshou.data import Collector, CollectStats -from tianshou.env import BaseVectorEnv, SubprocVectorEnv, VectorEnvNormObs -from tianshou.exploration import GaussianNoise from tianshou.algorithm import TD3BC from tianshou.algorithm.algorithm_base import Algorithm from tianshou.algorithm.modelfree.ddpg import ContinuousDeterministicPolicy from tianshou.algorithm.optim import AdamOptimizerFactory +from tianshou.data import Collector, CollectStats +from tianshou.env import BaseVectorEnv, SubprocVectorEnv, VectorEnvNormObs +from tianshou.exploration import GaussianNoise from tianshou.trainer import OfflineTrainerParams from tianshou.utils import TensorboardLogger, WandbLogger from tianshou.utils.net.common import Net diff --git a/examples/vizdoom/vizdoom_c51.py b/examples/vizdoom/vizdoom_c51.py index b6fd383f7..9f1e07c5d 100644 --- a/examples/vizdoom/vizdoom_c51.py +++ b/examples/vizdoom/vizdoom_c51.py @@ -8,13 +8,13 @@ import torch from env import make_vizdoom_env -from tianshou.data import Collector, CollectStats, VectorReplayBuffer -from tianshou.env.atari.atari_network import C51Net -from tianshou.highlevel.logger import LoggerFactoryDefault from tianshou.algorithm import C51 from tianshou.algorithm.algorithm_base import Algorithm from tianshou.algorithm.modelfree.c51 import C51Policy from tianshou.algorithm.optim import 
AdamOptimizerFactory +from tianshou.data import Collector, CollectStats, VectorReplayBuffer +from tianshou.env.atari.atari_network import C51Net +from tianshou.highlevel.logger import LoggerFactoryDefault from tianshou.trainer import OffPolicyTrainerParams diff --git a/examples/vizdoom/vizdoom_ppo.py b/examples/vizdoom/vizdoom_ppo.py index e23af08f3..4004fdc7e 100644 --- a/examples/vizdoom/vizdoom_ppo.py +++ b/examples/vizdoom/vizdoom_ppo.py @@ -9,14 +9,14 @@ from env import make_vizdoom_env from torch.distributions import Categorical -from tianshou.data import Collector, CollectStats, VectorReplayBuffer -from tianshou.env.atari.atari_network import DQNet -from tianshou.highlevel.logger import LoggerFactoryDefault from tianshou.algorithm import PPO from tianshou.algorithm.algorithm_base import Algorithm from tianshou.algorithm.modelbased.icm import ICMOnPolicyWrapper from tianshou.algorithm.modelfree.pg import ActorPolicyProbabilistic from tianshou.algorithm.optim import AdamOptimizerFactory, LRSchedulerFactoryLinear +from tianshou.data import Collector, CollectStats, VectorReplayBuffer +from tianshou.env.atari.atari_network import DQNet +from tianshou.highlevel.logger import LoggerFactoryDefault from tianshou.trainer import OnPolicyTrainerParams from tianshou.utils.net.discrete import ( DiscreteActor, diff --git a/test/base/test_collector.py b/test/base/test_collector.py index e3a42f3c9..ef1ec1e14 100644 --- a/test/base/test_collector.py +++ b/test/base/test_collector.py @@ -7,6 +7,7 @@ import pytest import tqdm +from tianshou.algorithm.algorithm_base import Policy, episode_mc_return_to_go from tianshou.data import ( AsyncCollector, Batch, @@ -25,7 +26,6 @@ ) from tianshou.data.types import ObsBatchProtocol, RolloutBatchProtocol from tianshou.env import DummyVectorEnv, SubprocVectorEnv -from tianshou.algorithm.algorithm_base import Policy, episode_mc_return_to_go try: import envpool diff --git a/test/base/test_env_finite.py b/test/base/test_env_finite.py index 
3d7a02e31..32289bfec 100644 --- a/test/base/test_env_finite.py +++ b/test/base/test_env_finite.py @@ -11,6 +11,7 @@ from gymnasium.spaces import Box from torch.utils.data import DataLoader, Dataset, DistributedSampler +from tianshou.algorithm.algorithm_base import Policy from tianshou.data import Batch, Collector, CollectStats from tianshou.data.types import ( ActBatchProtocol, @@ -19,7 +20,6 @@ ) from tianshou.env import BaseVectorEnv, DummyVectorEnv, SubprocVectorEnv from tianshou.env.utils import ENV_TYPE, gym_new_venv_step_type -from tianshou.algorithm.algorithm_base import Policy class DummyDataset(Dataset): diff --git a/test/base/test_policy.py b/test/base/test_policy.py index f0048d931..c357b7e87 100644 --- a/test/base/test_policy.py +++ b/test/base/test_policy.py @@ -4,11 +4,14 @@ import torch from torch.distributions import Categorical, Distribution, Independent, Normal -from tianshou.data import Batch from tianshou.algorithm import PPO -from tianshou.algorithm.algorithm_base import RandomActionPolicy, episode_mc_return_to_go +from tianshou.algorithm.algorithm_base import ( + RandomActionPolicy, + episode_mc_return_to_go, +) from tianshou.algorithm.modelfree.pg import ActorPolicyProbabilistic from tianshou.algorithm.optim import AdamOptimizerFactory +from tianshou.data import Batch from tianshou.utils.net.common import Net from tianshou.utils.net.continuous import ContinuousActorProbabilistic, ContinuousCritic from tianshou.utils.net.discrete import DiscreteActor diff --git a/test/base/test_returns.py b/test/base/test_returns.py index b1a409068..1e2b00dd2 100644 --- a/test/base/test_returns.py +++ b/test/base/test_returns.py @@ -3,9 +3,9 @@ import numpy as np import torch +from tianshou.algorithm import Algorithm from tianshou.data import Batch, ReplayBuffer, to_numpy from tianshou.data.types import RolloutBatchProtocol -from tianshou.algorithm import Algorithm def compute_episodic_return_base(batch: Batch, gamma: float) -> Batch: diff --git 
a/test/base/test_stats.py b/test/base/test_stats.py index 59ae349da..2b5630465 100644 --- a/test/base/test_stats.py +++ b/test/base/test_stats.py @@ -5,9 +5,9 @@ import torch from torch.distributions import Categorical, Normal +from tianshou.algorithm.algorithm_base import TrainingStats, TrainingStatsWrapper from tianshou.data import Batch, CollectStats from tianshou.data.collector import CollectStepBatchProtocol, get_stddev_from_dist -from tianshou.algorithm.algorithm_base import TrainingStats, TrainingStatsWrapper class DummyTrainingStatsWrapper(TrainingStatsWrapper): diff --git a/test/continuous/test_ddpg.py b/test/continuous/test_ddpg.py index aca34fa27..4de6c7dc2 100644 --- a/test/continuous/test_ddpg.py +++ b/test/continuous/test_ddpg.py @@ -7,13 +7,13 @@ import torch from torch.utils.tensorboard import SummaryWriter -from tianshou.data import Collector, CollectStats, VectorReplayBuffer -from tianshou.env import DummyVectorEnv -from tianshou.exploration import GaussianNoise from tianshou.algorithm import DDPG from tianshou.algorithm.algorithm_base import Algorithm from tianshou.algorithm.modelfree.ddpg import ContinuousDeterministicPolicy from tianshou.algorithm.optim import AdamOptimizerFactory +from tianshou.data import Collector, CollectStats, VectorReplayBuffer +from tianshou.env import DummyVectorEnv +from tianshou.exploration import GaussianNoise from tianshou.trainer import OffPolicyTrainerParams from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Net diff --git a/test/continuous/test_npg.py b/test/continuous/test_npg.py index 638332701..3f94d6f03 100644 --- a/test/continuous/test_npg.py +++ b/test/continuous/test_npg.py @@ -9,12 +9,12 @@ from torch.distributions import Distribution, Independent, Normal from torch.utils.tensorboard import SummaryWriter -from tianshou.data import Collector, CollectStats, VectorReplayBuffer -from tianshou.env import DummyVectorEnv from tianshou.algorithm import NPG from 
tianshou.algorithm.algorithm_base import Algorithm from tianshou.algorithm.modelfree.pg import ActorPolicyProbabilistic from tianshou.algorithm.optim import AdamOptimizerFactory +from tianshou.data import Collector, CollectStats, VectorReplayBuffer +from tianshou.env import DummyVectorEnv from tianshou.trainer import OnPolicyTrainerParams from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Net diff --git a/test/continuous/test_ppo.py b/test/continuous/test_ppo.py index c10a1fe8f..bad8edbc8 100644 --- a/test/continuous/test_ppo.py +++ b/test/continuous/test_ppo.py @@ -8,12 +8,12 @@ from torch.distributions import Distribution, Independent, Normal from torch.utils.tensorboard import SummaryWriter -from tianshou.data import Collector, CollectStats, VectorReplayBuffer -from tianshou.env import DummyVectorEnv from tianshou.algorithm import PPO from tianshou.algorithm.algorithm_base import Algorithm from tianshou.algorithm.modelfree.pg import ActorPolicyProbabilistic from tianshou.algorithm.optim import AdamOptimizerFactory +from tianshou.data import Collector, CollectStats, VectorReplayBuffer +from tianshou.env import DummyVectorEnv from tianshou.trainer import OnPolicyTrainerParams from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import ActorCritic, Net diff --git a/test/continuous/test_redq.py b/test/continuous/test_redq.py index 9769b946d..d47c7a559 100644 --- a/test/continuous/test_redq.py +++ b/test/continuous/test_redq.py @@ -8,13 +8,13 @@ import torch.nn as nn from torch.utils.tensorboard import SummaryWriter -from tianshou.data import Collector, CollectStats, VectorReplayBuffer -from tianshou.env import DummyVectorEnv from tianshou.algorithm import REDQ from tianshou.algorithm.algorithm_base import Algorithm from tianshou.algorithm.modelfree.redq import REDQPolicy from tianshou.algorithm.modelfree.sac import AutoAlpha from tianshou.algorithm.optim import AdamOptimizerFactory +from tianshou.data import 
Collector, CollectStats, VectorReplayBuffer +from tianshou.env import DummyVectorEnv from tianshou.trainer import OffPolicyTrainerParams from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import EnsembleLinear, Net diff --git a/test/continuous/test_sac_with_il.py b/test/continuous/test_sac_with_il.py index 7d4d43a9a..a1a0f60cc 100644 --- a/test/continuous/test_sac_with_il.py +++ b/test/continuous/test_sac_with_il.py @@ -7,13 +7,13 @@ import torch from torch.utils.tensorboard import SummaryWriter -from tianshou.data import Collector, CollectStats, VectorReplayBuffer -from tianshou.env import DummyVectorEnv from tianshou.algorithm import SAC, OffPolicyImitationLearning from tianshou.algorithm.algorithm_base import Algorithm from tianshou.algorithm.imitation.imitation_base import ImitationPolicy from tianshou.algorithm.modelfree.sac import AutoAlpha, SACPolicy from tianshou.algorithm.optim import AdamOptimizerFactory +from tianshou.data import Collector, CollectStats, VectorReplayBuffer +from tianshou.env import DummyVectorEnv from tianshou.trainer import OffPolicyTrainerParams from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Net diff --git a/test/continuous/test_td3.py b/test/continuous/test_td3.py index e9c31f4e5..fabae724d 100644 --- a/test/continuous/test_td3.py +++ b/test/continuous/test_td3.py @@ -7,13 +7,13 @@ import torch from torch.utils.tensorboard import SummaryWriter -from tianshou.data import Collector, CollectStats, VectorReplayBuffer -from tianshou.env import DummyVectorEnv -from tianshou.exploration import GaussianNoise from tianshou.algorithm import TD3 from tianshou.algorithm.algorithm_base import Algorithm from tianshou.algorithm.modelfree.ddpg import ContinuousDeterministicPolicy from tianshou.algorithm.optim import AdamOptimizerFactory +from tianshou.data import Collector, CollectStats, VectorReplayBuffer +from tianshou.env import DummyVectorEnv +from tianshou.exploration import 
GaussianNoise from tianshou.trainer import OffPolicyTrainerParams from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Net diff --git a/test/continuous/test_trpo.py b/test/continuous/test_trpo.py index 8865c47a0..dec82d7f5 100644 --- a/test/continuous/test_trpo.py +++ b/test/continuous/test_trpo.py @@ -9,12 +9,12 @@ from torch.distributions import Distribution, Independent, Normal from torch.utils.tensorboard import SummaryWriter -from tianshou.data import Collector, CollectStats, VectorReplayBuffer -from tianshou.env import DummyVectorEnv from tianshou.algorithm import TRPO from tianshou.algorithm.algorithm_base import Algorithm from tianshou.algorithm.modelfree.pg import ActorPolicyProbabilistic from tianshou.algorithm.optim import AdamOptimizerFactory +from tianshou.data import Collector, CollectStats, VectorReplayBuffer +from tianshou.env import DummyVectorEnv from tianshou.trainer import OnPolicyTrainerParams from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Net diff --git a/test/discrete/test_a2c_with_il.py b/test/discrete/test_a2c_with_il.py index f5c123bab..906695879 100644 --- a/test/discrete/test_a2c_with_il.py +++ b/test/discrete/test_a2c_with_il.py @@ -8,13 +8,12 @@ from gymnasium.spaces import Box from torch.utils.tensorboard import SummaryWriter -from tianshou.data import Collector, CollectStats, VectorReplayBuffer -from tianshou.env import DummyVectorEnv -from tianshou.algorithm import A2C, OffPolicyImitationLearning -from tianshou.algorithm import Algorithm +from tianshou.algorithm import A2C, Algorithm, OffPolicyImitationLearning from tianshou.algorithm.imitation.imitation_base import ImitationPolicy from tianshou.algorithm.modelfree.pg import ActorPolicyProbabilistic from tianshou.algorithm.optim import AdamOptimizerFactory +from tianshou.data import Collector, CollectStats, VectorReplayBuffer +from tianshou.env import DummyVectorEnv from tianshou.trainer import OffPolicyTrainerParams, 
OnPolicyTrainerParams from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Net diff --git a/test/discrete/test_bdqn.py b/test/discrete/test_bdqn.py index 2a5a797da..38ca7b64b 100644 --- a/test/discrete/test_bdqn.py +++ b/test/discrete/test_bdqn.py @@ -5,11 +5,11 @@ import numpy as np import torch -from tianshou.data import Collector, CollectStats, VectorReplayBuffer -from tianshou.env import ContinuousToDiscrete, DummyVectorEnv from tianshou.algorithm import BDQN from tianshou.algorithm.modelfree.bdqn import BDQNPolicy from tianshou.algorithm.optim import AdamOptimizerFactory +from tianshou.data import Collector, CollectStats, VectorReplayBuffer +from tianshou.env import ContinuousToDiscrete, DummyVectorEnv from tianshou.trainer import OffPolicyTrainerParams from tianshou.utils.net.common import BranchingNet from tianshou.utils.torch_utils import policy_within_training_step diff --git a/test/discrete/test_c51.py b/test/discrete/test_c51.py index eaefe8d32..74e289e0d 100644 --- a/test/discrete/test_c51.py +++ b/test/discrete/test_c51.py @@ -8,6 +8,10 @@ import torch from torch.utils.tensorboard import SummaryWriter +from tianshou.algorithm import C51 +from tianshou.algorithm.algorithm_base import Algorithm +from tianshou.algorithm.modelfree.c51 import C51Policy +from tianshou.algorithm.optim import AdamOptimizerFactory from tianshou.data import ( Collector, CollectStats, @@ -16,10 +20,6 @@ VectorReplayBuffer, ) from tianshou.env import DummyVectorEnv -from tianshou.algorithm import C51 -from tianshou.algorithm.algorithm_base import Algorithm -from tianshou.algorithm.modelfree.c51 import C51Policy -from tianshou.algorithm.optim import AdamOptimizerFactory from tianshou.trainer import OffPolicyTrainerParams from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Net diff --git a/test/discrete/test_discrete_sac.py b/test/discrete/test_discrete_sac.py index 879115f0f..108b2edb1 100644 --- 
a/test/discrete/test_discrete_sac.py +++ b/test/discrete/test_discrete_sac.py @@ -7,8 +7,6 @@ import torch from torch.utils.tensorboard import SummaryWriter -from tianshou.data import Collector, CollectStats, VectorReplayBuffer -from tianshou.env import DummyVectorEnv from tianshou.algorithm import DiscreteSAC from tianshou.algorithm.algorithm_base import Algorithm from tianshou.algorithm.modelfree.discrete_sac import ( @@ -16,6 +14,8 @@ ) from tianshou.algorithm.modelfree.sac import AutoAlpha from tianshou.algorithm.optim import AdamOptimizerFactory +from tianshou.data import Collector, CollectStats, VectorReplayBuffer +from tianshou.env import DummyVectorEnv from tianshou.trainer import OffPolicyTrainerParams from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Net diff --git a/test/discrete/test_dqn.py b/test/discrete/test_dqn.py index 008593b53..99f580b83 100644 --- a/test/discrete/test_dqn.py +++ b/test/discrete/test_dqn.py @@ -7,6 +7,10 @@ import torch from torch.utils.tensorboard import SummaryWriter +from tianshou.algorithm import DQN +from tianshou.algorithm.algorithm_base import Algorithm +from tianshou.algorithm.modelfree.dqn import DiscreteQLearningPolicy +from tianshou.algorithm.optim import AdamOptimizerFactory from tianshou.data import ( Collector, CollectStats, @@ -15,10 +19,6 @@ VectorReplayBuffer, ) from tianshou.env import DummyVectorEnv -from tianshou.algorithm import DQN -from tianshou.algorithm.algorithm_base import Algorithm -from tianshou.algorithm.modelfree.dqn import DiscreteQLearningPolicy -from tianshou.algorithm.optim import AdamOptimizerFactory from tianshou.trainer import OffPolicyTrainerParams from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Net diff --git a/test/discrete/test_drqn.py b/test/discrete/test_drqn.py index 912f6f7fd..d2e0792a1 100644 --- a/test/discrete/test_drqn.py +++ b/test/discrete/test_drqn.py @@ -7,12 +7,12 @@ import torch from torch.utils.tensorboard 
import SummaryWriter -from tianshou.data import Collector, CollectStats, VectorReplayBuffer -from tianshou.env import DummyVectorEnv from tianshou.algorithm import DQN from tianshou.algorithm.algorithm_base import Algorithm from tianshou.algorithm.modelfree.dqn import DiscreteQLearningPolicy from tianshou.algorithm.optim import AdamOptimizerFactory +from tianshou.data import Collector, CollectStats, VectorReplayBuffer +from tianshou.env import DummyVectorEnv from tianshou.trainer import OffPolicyTrainerParams from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Recurrent diff --git a/test/discrete/test_fqf.py b/test/discrete/test_fqf.py index 3a6ba65f1..4a49f8079 100644 --- a/test/discrete/test_fqf.py +++ b/test/discrete/test_fqf.py @@ -7,6 +7,10 @@ import torch from torch.utils.tensorboard import SummaryWriter +from tianshou.algorithm import FQF +from tianshou.algorithm.algorithm_base import Algorithm +from tianshou.algorithm.modelfree.fqf import FQFPolicy +from tianshou.algorithm.optim import AdamOptimizerFactory, RMSpropOptimizerFactory from tianshou.data import ( Collector, CollectStats, @@ -15,10 +19,6 @@ VectorReplayBuffer, ) from tianshou.env import DummyVectorEnv -from tianshou.algorithm import FQF -from tianshou.algorithm.algorithm_base import Algorithm -from tianshou.algorithm.modelfree.fqf import FQFPolicy -from tianshou.algorithm.optim import AdamOptimizerFactory, RMSpropOptimizerFactory from tianshou.trainer import OffPolicyTrainerParams from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Net diff --git a/test/discrete/test_iqn.py b/test/discrete/test_iqn.py index dd93df6ea..95108929b 100644 --- a/test/discrete/test_iqn.py +++ b/test/discrete/test_iqn.py @@ -7,6 +7,10 @@ import torch from torch.utils.tensorboard import SummaryWriter +from tianshou.algorithm import IQN +from tianshou.algorithm.algorithm_base import Algorithm +from tianshou.algorithm.modelfree.iqn import IQNPolicy +from 
tianshou.algorithm.optim import AdamOptimizerFactory from tianshou.data import ( Collector, CollectStats, @@ -15,10 +19,6 @@ VectorReplayBuffer, ) from tianshou.env import DummyVectorEnv -from tianshou.algorithm import IQN -from tianshou.algorithm.algorithm_base import Algorithm -from tianshou.algorithm.modelfree.iqn import IQNPolicy -from tianshou.algorithm.optim import AdamOptimizerFactory from tianshou.trainer import OffPolicyTrainerParams from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Net diff --git a/test/discrete/test_pg.py b/test/discrete/test_pg.py index 461131e42..218b40ec8 100644 --- a/test/discrete/test_pg.py +++ b/test/discrete/test_pg.py @@ -8,12 +8,12 @@ from gymnasium.spaces import Box from torch.utils.tensorboard import SummaryWriter -from tianshou.data import Collector, CollectStats, VectorReplayBuffer -from tianshou.env import DummyVectorEnv from tianshou.algorithm import Reinforce from tianshou.algorithm.algorithm_base import Algorithm from tianshou.algorithm.modelfree.pg import ActorPolicyProbabilistic from tianshou.algorithm.optim import AdamOptimizerFactory +from tianshou.data import Collector, CollectStats, VectorReplayBuffer +from tianshou.env import DummyVectorEnv from tianshou.trainer import OnPolicyTrainerParams from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Net diff --git a/test/discrete/test_ppo_discrete.py b/test/discrete/test_ppo_discrete.py index 21f10f378..ffdc8822f 100644 --- a/test/discrete/test_ppo_discrete.py +++ b/test/discrete/test_ppo_discrete.py @@ -8,12 +8,12 @@ import torch.nn as nn from torch.utils.tensorboard import SummaryWriter -from tianshou.data import Collector, CollectStats, VectorReplayBuffer -from tianshou.env import DummyVectorEnv from tianshou.algorithm import PPO from tianshou.algorithm.algorithm_base import Algorithm from tianshou.algorithm.modelfree.pg import DiscreteActorPolicy from tianshou.algorithm.optim import AdamOptimizerFactory 
+from tianshou.data import Collector, CollectStats, VectorReplayBuffer +from tianshou.env import DummyVectorEnv from tianshou.trainer import OnPolicyTrainerParams from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import ActorCritic, DataParallelNet, Net diff --git a/test/discrete/test_qrdqn.py b/test/discrete/test_qrdqn.py index c99d02563..f51ca1703 100644 --- a/test/discrete/test_qrdqn.py +++ b/test/discrete/test_qrdqn.py @@ -7,6 +7,10 @@ import torch from torch.utils.tensorboard import SummaryWriter +from tianshou.algorithm import QRDQN +from tianshou.algorithm.algorithm_base import Algorithm +from tianshou.algorithm.modelfree.qrdqn import QRDQNPolicy +from tianshou.algorithm.optim import AdamOptimizerFactory from tianshou.data import ( Collector, CollectStats, @@ -14,10 +18,6 @@ VectorReplayBuffer, ) from tianshou.env import DummyVectorEnv -from tianshou.algorithm import QRDQN -from tianshou.algorithm.algorithm_base import Algorithm -from tianshou.algorithm.modelfree.qrdqn import QRDQNPolicy -from tianshou.algorithm.optim import AdamOptimizerFactory from tianshou.trainer import OffPolicyTrainerParams from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Net diff --git a/test/discrete/test_rainbow.py b/test/discrete/test_rainbow.py index 4d74976c8..ffd53f758 100644 --- a/test/discrete/test_rainbow.py +++ b/test/discrete/test_rainbow.py @@ -8,6 +8,10 @@ import torch from torch.utils.tensorboard import SummaryWriter +from tianshou.algorithm import RainbowDQN +from tianshou.algorithm.algorithm_base import Algorithm +from tianshou.algorithm.modelfree.c51 import C51Policy +from tianshou.algorithm.optim import AdamOptimizerFactory from tianshou.data import ( Collector, CollectStats, @@ -15,10 +19,6 @@ VectorReplayBuffer, ) from tianshou.env import DummyVectorEnv -from tianshou.algorithm import RainbowDQN -from tianshou.algorithm.algorithm_base import Algorithm -from tianshou.algorithm.modelfree.c51 import C51Policy 
-from tianshou.algorithm.optim import AdamOptimizerFactory from tianshou.trainer import OffPolicyTrainerParams from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Net diff --git a/test/modelbased/test_dqn_icm.py b/test/modelbased/test_dqn_icm.py index a3286b36c..f5a800b2c 100644 --- a/test/modelbased/test_dqn_icm.py +++ b/test/modelbased/test_dqn_icm.py @@ -6,6 +6,9 @@ import torch from torch.utils.tensorboard import SummaryWriter +from tianshou.algorithm import DQN, Algorithm, ICMOffPolicyWrapper +from tianshou.algorithm.modelfree.dqn import DiscreteQLearningPolicy +from tianshou.algorithm.optim import AdamOptimizerFactory from tianshou.data import ( Collector, CollectStats, @@ -13,9 +16,6 @@ VectorReplayBuffer, ) from tianshou.env import DummyVectorEnv -from tianshou.algorithm import DQN, Algorithm, ICMOffPolicyWrapper -from tianshou.algorithm.modelfree.dqn import DiscreteQLearningPolicy -from tianshou.algorithm.optim import AdamOptimizerFactory from tianshou.trainer import OffPolicyTrainerParams from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import MLP, Net diff --git a/test/modelbased/test_ppo_icm.py b/test/modelbased/test_ppo_icm.py index 8511f9b4d..a3a22487e 100644 --- a/test/modelbased/test_ppo_icm.py +++ b/test/modelbased/test_ppo_icm.py @@ -7,13 +7,13 @@ from gymnasium.spaces import Box from torch.utils.tensorboard import SummaryWriter -from tianshou.data import Collector, CollectStats, VectorReplayBuffer -from tianshou.env import DummyVectorEnv from tianshou.algorithm import PPO from tianshou.algorithm.algorithm_base import Algorithm from tianshou.algorithm.modelbased.icm import ICMOnPolicyWrapper from tianshou.algorithm.modelfree.pg import ActorPolicyProbabilistic from tianshou.algorithm.optim import AdamOptimizerFactory +from tianshou.data import Collector, CollectStats, VectorReplayBuffer +from tianshou.env import DummyVectorEnv from tianshou.trainer import OnPolicyTrainerParams from 
tianshou.utils import TensorboardLogger from tianshou.utils.net.common import MLP, ActorCritic, Net diff --git a/test/modelbased/test_psrl.py b/test/modelbased/test_psrl.py index 51e81e62c..c9b0dffbc 100644 --- a/test/modelbased/test_psrl.py +++ b/test/modelbased/test_psrl.py @@ -6,9 +6,9 @@ import torch from torch.utils.tensorboard import SummaryWriter -from tianshou.data import Collector, CollectStats, VectorReplayBuffer from tianshou.algorithm import PSRL from tianshou.algorithm.modelbased.psrl import PSRLPolicy +from tianshou.data import Collector, CollectStats, VectorReplayBuffer from tianshou.trainer import OnPolicyTrainerParams from tianshou.utils import LazyLogger, TensorboardLogger, WandbLogger diff --git a/test/offline/gather_cartpole_data.py b/test/offline/gather_cartpole_data.py index cba273ffb..3c88b77c5 100644 --- a/test/offline/gather_cartpole_data.py +++ b/test/offline/gather_cartpole_data.py @@ -7,6 +7,10 @@ import torch from torch.utils.tensorboard import SummaryWriter +from tianshou.algorithm import QRDQN +from tianshou.algorithm.algorithm_base import Algorithm +from tianshou.algorithm.modelfree.qrdqn import QRDQNPolicy +from tianshou.algorithm.optim import AdamOptimizerFactory from tianshou.data import ( Collector, CollectStats, @@ -14,10 +18,6 @@ VectorReplayBuffer, ) from tianshou.env import DummyVectorEnv -from tianshou.algorithm import QRDQN -from tianshou.algorithm.algorithm_base import Algorithm -from tianshou.algorithm.modelfree.qrdqn import QRDQNPolicy -from tianshou.algorithm.optim import AdamOptimizerFactory from tianshou.trainer import OffPolicyTrainerParams from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Net diff --git a/test/offline/gather_pendulum_data.py b/test/offline/gather_pendulum_data.py index be7d5c774..68776dbbb 100644 --- a/test/offline/gather_pendulum_data.py +++ b/test/offline/gather_pendulum_data.py @@ -7,12 +7,12 @@ import torch from torch.utils.tensorboard import SummaryWriter -from 
tianshou.data import Collector, CollectStats, VectorReplayBuffer -from tianshou.env import DummyVectorEnv from tianshou.algorithm import SAC from tianshou.algorithm.algorithm_base import Algorithm from tianshou.algorithm.modelfree.sac import AutoAlpha, SACPolicy, SACTrainingStats from tianshou.algorithm.optim import AdamOptimizerFactory +from tianshou.data import Collector, CollectStats, VectorReplayBuffer +from tianshou.env import DummyVectorEnv from tianshou.trainer import OffPolicyTrainerParams from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Net diff --git a/test/offline/test_bcq.py b/test/offline/test_bcq.py index 715347716..3d02c1ee0 100644 --- a/test/offline/test_bcq.py +++ b/test/offline/test_bcq.py @@ -10,11 +10,11 @@ import torch from torch.utils.tensorboard import SummaryWriter -from tianshou.data import Collector, CollectStats, VectorReplayBuffer -from tianshou.env import DummyVectorEnv from tianshou.algorithm import BCQ, Algorithm from tianshou.algorithm.imitation.bcq import BCQPolicy from tianshou.algorithm.optim import AdamOptimizerFactory +from tianshou.data import Collector, CollectStats, VectorReplayBuffer +from tianshou.env import DummyVectorEnv from tianshou.trainer import OfflineTrainerParams from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import MLP, Net diff --git a/test/offline/test_cql.py b/test/offline/test_cql.py index 8cda37a7a..69204a976 100644 --- a/test/offline/test_cql.py +++ b/test/offline/test_cql.py @@ -10,11 +10,11 @@ import torch from torch.utils.tensorboard import SummaryWriter -from tianshou.data import Collector, CollectStats, VectorReplayBuffer -from tianshou.env import DummyVectorEnv from tianshou.algorithm import CQL, Algorithm from tianshou.algorithm.modelfree.sac import AutoAlpha, SACPolicy from tianshou.algorithm.optim import AdamOptimizerFactory +from tianshou.data import Collector, CollectStats, VectorReplayBuffer +from tianshou.env import DummyVectorEnv 
from tianshou.trainer import OfflineTrainerParams from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Net diff --git a/test/offline/test_discrete_bcq.py b/test/offline/test_discrete_bcq.py index 12faddb05..91f651bf3 100644 --- a/test/offline/test_discrete_bcq.py +++ b/test/offline/test_discrete_bcq.py @@ -9,6 +9,9 @@ import torch from torch.utils.tensorboard import SummaryWriter +from tianshou.algorithm import Algorithm, DiscreteBCQ +from tianshou.algorithm.imitation.discrete_bcq import DiscreteBCQPolicy +from tianshou.algorithm.optim import AdamOptimizerFactory from tianshou.data import ( Collector, CollectStats, @@ -16,9 +19,6 @@ VectorReplayBuffer, ) from tianshou.env import DummyVectorEnv -from tianshou.algorithm import Algorithm, DiscreteBCQ -from tianshou.algorithm.imitation.discrete_bcq import DiscreteBCQPolicy -from tianshou.algorithm.optim import AdamOptimizerFactory from tianshou.trainer import OfflineTrainerParams from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Net diff --git a/test/offline/test_discrete_cql.py b/test/offline/test_discrete_cql.py index 1bdde5b14..2ce0f73a0 100644 --- a/test/offline/test_discrete_cql.py +++ b/test/offline/test_discrete_cql.py @@ -9,6 +9,9 @@ import torch from torch.utils.tensorboard import SummaryWriter +from tianshou.algorithm import Algorithm, DiscreteCQL +from tianshou.algorithm.modelfree.qrdqn import QRDQNPolicy +from tianshou.algorithm.optim import AdamOptimizerFactory from tianshou.data import ( Collector, CollectStats, @@ -16,9 +19,6 @@ VectorReplayBuffer, ) from tianshou.env import DummyVectorEnv -from tianshou.algorithm import Algorithm, DiscreteCQL -from tianshou.algorithm.modelfree.qrdqn import QRDQNPolicy -from tianshou.algorithm.optim import AdamOptimizerFactory from tianshou.trainer import OfflineTrainerParams from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Net diff --git a/test/offline/test_discrete_crr.py 
b/test/offline/test_discrete_crr.py index 6878fdfe3..df3f03a75 100644 --- a/test/offline/test_discrete_crr.py +++ b/test/offline/test_discrete_crr.py @@ -9,6 +9,9 @@ import torch from torch.utils.tensorboard import SummaryWriter +from tianshou.algorithm import Algorithm, DiscreteCRR +from tianshou.algorithm.modelfree.pg import DiscreteActorPolicy +from tianshou.algorithm.optim import AdamOptimizerFactory from tianshou.data import ( Collector, CollectStats, @@ -16,9 +19,6 @@ VectorReplayBuffer, ) from tianshou.env import DummyVectorEnv -from tianshou.algorithm import Algorithm, DiscreteCRR -from tianshou.algorithm.modelfree.pg import DiscreteActorPolicy -from tianshou.algorithm.optim import AdamOptimizerFactory from tianshou.trainer import OfflineTrainerParams from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Net diff --git a/test/offline/test_gail.py b/test/offline/test_gail.py index 86a84da35..f4f5236c1 100644 --- a/test/offline/test_gail.py +++ b/test/offline/test_gail.py @@ -10,11 +10,11 @@ from torch.distributions import Distribution, Independent, Normal from torch.utils.tensorboard import SummaryWriter -from tianshou.data import Collector, CollectStats, VectorReplayBuffer -from tianshou.env import DummyVectorEnv from tianshou.algorithm import GAIL, Algorithm from tianshou.algorithm.modelfree.pg import ActorPolicyProbabilistic from tianshou.algorithm.optim import AdamOptimizerFactory +from tianshou.data import Collector, CollectStats, VectorReplayBuffer +from tianshou.env import DummyVectorEnv from tianshou.trainer import OnPolicyTrainerParams from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import ActorCritic, Net diff --git a/test/offline/test_td3_bc.py b/test/offline/test_td3_bc.py index d0f6bcfc4..8dff6a0ca 100644 --- a/test/offline/test_td3_bc.py +++ b/test/offline/test_td3_bc.py @@ -10,13 +10,13 @@ import torch from torch.utils.tensorboard import SummaryWriter -from tianshou.data import 
Collector, CollectStats, VectorReplayBuffer -from tianshou.env import DummyVectorEnv -from tianshou.exploration import GaussianNoise from tianshou.algorithm import TD3BC from tianshou.algorithm.algorithm_base import Algorithm from tianshou.algorithm.modelfree.ddpg import ContinuousDeterministicPolicy from tianshou.algorithm.optim import AdamOptimizerFactory +from tianshou.data import Collector, CollectStats, VectorReplayBuffer +from tianshou.env import DummyVectorEnv +from tianshou.exploration import GaussianNoise from tianshou.trainer import OfflineTrainerParams from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Net diff --git a/test/pettingzoo/pistonball.py b/test/pettingzoo/pistonball.py index 84d4c920a..7a02234a6 100644 --- a/test/pettingzoo/pistonball.py +++ b/test/pettingzoo/pistonball.py @@ -8,13 +8,13 @@ from pettingzoo.butterfly import pistonball_v6 from torch.utils.tensorboard import SummaryWriter -from tianshou.data import Collector, CollectStats, InfoStats, VectorReplayBuffer -from tianshou.env import DummyVectorEnv -from tianshou.env.pettingzoo_env import PettingZooEnv from tianshou.algorithm import DQN, Algorithm, MultiAgentOffPolicyAlgorithm from tianshou.algorithm.algorithm_base import OffPolicyAlgorithm from tianshou.algorithm.modelfree.dqn import DiscreteQLearningPolicy from tianshou.algorithm.optim import AdamOptimizerFactory +from tianshou.data import Collector, CollectStats, InfoStats, VectorReplayBuffer +from tianshou.env import DummyVectorEnv +from tianshou.env.pettingzoo_env import PettingZooEnv from tianshou.trainer import OffPolicyTrainerParams from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Net diff --git a/test/pettingzoo/pistonball_continuous.py b/test/pettingzoo/pistonball_continuous.py index c4cf23106..98c5f768d 100644 --- a/test/pettingzoo/pistonball_continuous.py +++ b/test/pettingzoo/pistonball_continuous.py @@ -11,15 +11,15 @@ from torch.distributions import 
Distribution, Independent, Normal from torch.utils.tensorboard import SummaryWriter -from tianshou.data import Collector, CollectStats, VectorReplayBuffer -from tianshou.data.stats import InfoStats -from tianshou.env import DummyVectorEnv -from tianshou.env.pettingzoo_env import PettingZooEnv from tianshou.algorithm import PPO, Algorithm from tianshou.algorithm.algorithm_base import OnPolicyAlgorithm from tianshou.algorithm.modelfree.pg import ActorPolicyProbabilistic from tianshou.algorithm.multiagent.marl import MultiAgentOnPolicyAlgorithm from tianshou.algorithm.optim import AdamOptimizerFactory +from tianshou.data import Collector, CollectStats, VectorReplayBuffer +from tianshou.data.stats import InfoStats +from tianshou.env import DummyVectorEnv +from tianshou.env.pettingzoo_env import PettingZooEnv from tianshou.trainer import OnPolicyTrainerParams from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import ModuleWithVectorOutput diff --git a/test/pettingzoo/tic_tac_toe.py b/test/pettingzoo/tic_tac_toe.py index 706e27787..ff621ddc8 100644 --- a/test/pettingzoo/tic_tac_toe.py +++ b/test/pettingzoo/tic_tac_toe.py @@ -9,10 +9,6 @@ from pettingzoo.classic import tictactoe_v3 from torch.utils.tensorboard import SummaryWriter -from tianshou.data import Collector, CollectStats, VectorReplayBuffer -from tianshou.data.stats import InfoStats -from tianshou.env import DummyVectorEnv -from tianshou.env.pettingzoo_env import PettingZooEnv from tianshou.algorithm import ( DQN, Algorithm, @@ -22,6 +18,10 @@ from tianshou.algorithm.algorithm_base import OffPolicyAlgorithm from tianshou.algorithm.modelfree.dqn import DiscreteQLearningPolicy from tianshou.algorithm.optim import AdamOptimizerFactory, OptimizerFactory +from tianshou.data import Collector, CollectStats, VectorReplayBuffer +from tianshou.data.stats import InfoStats +from tianshou.env import DummyVectorEnv +from tianshou.env.pettingzoo_env import PettingZooEnv from tianshou.trainer import 
OffPolicyTrainerParams from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Net diff --git a/tianshou/__init__.py b/tianshou/__init__.py index 0b87ad0f9..13df75b46 100644 --- a/tianshou/__init__.py +++ b/tianshou/__init__.py @@ -1,3 +1,5 @@ +# isort: skip_file +# NOTE: Import order is important to avoid circular import errors! from tianshou import data, env, exploration, algorithm, trainer, utils __version__ = "1.2.0-dev" diff --git a/tianshou/algorithm/algorithm_base.py b/tianshou/algorithm/algorithm_base.py index d404fc392..ea960ea68 100644 --- a/tianshou/algorithm/algorithm_base.py +++ b/tianshou/algorithm/algorithm_base.py @@ -41,6 +41,7 @@ from tianshou.utils.torch_utils import policy_within_training_step, torch_train_mode if TYPE_CHECKING: + from tianshou.data.stats import InfoStats from tianshou.trainer import ( OfflineTrainer, OfflineTrainerParams, @@ -51,7 +52,7 @@ Trainer, TrainerParams, ) - from tianshou.data.stats import InfoStats + mark_used(TrainerParams) logger = logging.getLogger(__name__) diff --git a/tianshou/algorithm/imitation/bcq.py b/tianshou/algorithm/imitation/bcq.py index 7c7129104..609621ae4 100644 --- a/tianshou/algorithm/imitation/bcq.py +++ b/tianshou/algorithm/imitation/bcq.py @@ -7,9 +7,6 @@ import torch import torch.nn.functional as F -from tianshou.data import Batch, to_torch -from tianshou.data.batch import BatchProtocol -from tianshou.data.types import ActBatchProtocol, ObsBatchProtocol, RolloutBatchProtocol from tianshou.algorithm.algorithm_base import ( LaggedNetworkPolyakUpdateAlgorithmMixin, OfflineAlgorithm, @@ -17,6 +14,9 @@ TrainingStats, ) from tianshou.algorithm.optim import OptimizerFactory +from tianshou.data import Batch, to_torch +from tianshou.data.batch import BatchProtocol +from tianshou.data.types import ActBatchProtocol, ObsBatchProtocol, RolloutBatchProtocol from tianshou.utils.net.continuous import VAE diff --git a/tianshou/algorithm/imitation/cql.py 
b/tianshou/algorithm/imitation/cql.py index 1013a146a..f37b03c2d 100644 --- a/tianshou/algorithm/imitation/cql.py +++ b/tianshou/algorithm/imitation/cql.py @@ -7,15 +7,15 @@ import torch.nn.functional as F from overrides import override -from tianshou.data import Batch, ReplayBuffer, to_torch -from tianshou.data.buffer.buffer_base import TBuffer -from tianshou.data.types import RolloutBatchProtocol from tianshou.algorithm.algorithm_base import ( LaggedNetworkPolyakUpdateAlgorithmMixin, OfflineAlgorithm, ) from tianshou.algorithm.modelfree.sac import Alpha, SACPolicy, SACTrainingStats from tianshou.algorithm.optim import OptimizerFactory +from tianshou.data import Batch, ReplayBuffer, to_torch +from tianshou.data.buffer.buffer_base import TBuffer +from tianshou.data.types import RolloutBatchProtocol from tianshou.utils.conversion import to_optional_float from tianshou.utils.torch_utils import torch_device diff --git a/tianshou/algorithm/imitation/discrete_bcq.py b/tianshou/algorithm/imitation/discrete_bcq.py index 6fabc0bb3..495b6ffa1 100644 --- a/tianshou/algorithm/imitation/discrete_bcq.py +++ b/tianshou/algorithm/imitation/discrete_bcq.py @@ -7,13 +7,6 @@ import torch import torch.nn.functional as F -from tianshou.data import Batch, ReplayBuffer, to_torch -from tianshou.data.types import ( - BatchWithReturnsProtocol, - ImitationBatchProtocol, - ObsBatchProtocol, - RolloutBatchProtocol, -) from tianshou.algorithm.algorithm_base import ( LaggedNetworkFullUpdateAlgorithmMixin, OfflineAlgorithm, @@ -21,6 +14,13 @@ from tianshou.algorithm.modelfree.dqn import DiscreteQLearningPolicy from tianshou.algorithm.modelfree.pg import SimpleLossTrainingStats from tianshou.algorithm.optim import OptimizerFactory +from tianshou.data import Batch, ReplayBuffer, to_torch +from tianshou.data.types import ( + BatchWithReturnsProtocol, + ImitationBatchProtocol, + ObsBatchProtocol, + RolloutBatchProtocol, +) float_info = torch.finfo(torch.float32) INF = float_info.max diff --git 
a/tianshou/algorithm/imitation/discrete_cql.py b/tianshou/algorithm/imitation/discrete_cql.py index 3f2985b39..a3f832902 100644 --- a/tianshou/algorithm/imitation/discrete_cql.py +++ b/tianshou/algorithm/imitation/discrete_cql.py @@ -4,13 +4,13 @@ import torch import torch.nn.functional as F -from tianshou.data import to_torch -from tianshou.data.types import RolloutBatchProtocol from tianshou.algorithm import QRDQN from tianshou.algorithm.algorithm_base import OfflineAlgorithm from tianshou.algorithm.modelfree.pg import SimpleLossTrainingStats from tianshou.algorithm.modelfree.qrdqn import QRDQNPolicy from tianshou.algorithm.optim import OptimizerFactory +from tianshou.data import to_torch +from tianshou.data.types import RolloutBatchProtocol @dataclass(kw_only=True) diff --git a/tianshou/algorithm/imitation/discrete_crr.py b/tianshou/algorithm/imitation/discrete_crr.py index 69feb7c97..902b0e317 100644 --- a/tianshou/algorithm/imitation/discrete_crr.py +++ b/tianshou/algorithm/imitation/discrete_crr.py @@ -7,8 +7,6 @@ from torch.distributions import Categorical from torch.nn import ModuleList -from tianshou.data import ReplayBuffer, to_torch, to_torch_as -from tianshou.data.types import BatchWithReturnsProtocol, RolloutBatchProtocol from tianshou.algorithm.algorithm_base import ( LaggedNetworkFullUpdateAlgorithmMixin, OfflineAlgorithm, @@ -19,6 +17,8 @@ SimpleLossTrainingStats, ) from tianshou.algorithm.optim import OptimizerFactory +from tianshou.data import ReplayBuffer, to_torch, to_torch_as +from tianshou.data.types import BatchWithReturnsProtocol, RolloutBatchProtocol from tianshou.utils.lagged_network import EvalModeModuleWrapper from tianshou.utils.net.discrete import DiscreteCritic diff --git a/tianshou/algorithm/imitation/gail.py b/tianshou/algorithm/imitation/gail.py index af83b12b8..dbeddd1c3 100644 --- a/tianshou/algorithm/imitation/gail.py +++ b/tianshou/algorithm/imitation/gail.py @@ -4,6 +4,10 @@ import torch import torch.nn.functional as F +from 
tianshou.algorithm.modelfree.a2c import A2CTrainingStats +from tianshou.algorithm.modelfree.pg import ActorPolicyProbabilistic +from tianshou.algorithm.modelfree.ppo import PPO +from tianshou.algorithm.optim import OptimizerFactory from tianshou.data import ( ReplayBuffer, SequenceSummaryStats, @@ -11,10 +15,6 @@ to_torch, ) from tianshou.data.types import LogpOldProtocol, RolloutBatchProtocol -from tianshou.algorithm.modelfree.a2c import A2CTrainingStats -from tianshou.algorithm.modelfree.pg import ActorPolicyProbabilistic -from tianshou.algorithm.modelfree.ppo import PPO -from tianshou.algorithm.optim import OptimizerFactory from tianshou.utils.net.common import ModuleWithVectorOutput from tianshou.utils.net.continuous import ContinuousCritic from tianshou.utils.net.discrete import DiscreteCritic diff --git a/tianshou/algorithm/imitation/imitation_base.py b/tianshou/algorithm/imitation/imitation_base.py index 61fbd3a6b..b21bd3132 100644 --- a/tianshou/algorithm/imitation/imitation_base.py +++ b/tianshou/algorithm/imitation/imitation_base.py @@ -6,13 +6,6 @@ import torch import torch.nn.functional as F -from tianshou.data import Batch, to_torch -from tianshou.data.batch import BatchProtocol -from tianshou.data.types import ( - ModelOutputBatchProtocol, - ObsBatchProtocol, - RolloutBatchProtocol, -) from tianshou.algorithm import Algorithm from tianshou.algorithm.algorithm_base import ( OfflineAlgorithm, @@ -21,6 +14,13 @@ TrainingStats, ) from tianshou.algorithm.optim import OptimizerFactory +from tianshou.data import Batch, to_torch +from tianshou.data.batch import BatchProtocol +from tianshou.data.types import ( + ModelOutputBatchProtocol, + ObsBatchProtocol, + RolloutBatchProtocol, +) # Dimension Naming Convention # B - Batch Size diff --git a/tianshou/algorithm/imitation/td3_bc.py b/tianshou/algorithm/imitation/td3_bc.py index f091bba0a..50f339673 100644 --- a/tianshou/algorithm/imitation/td3_bc.py +++ b/tianshou/algorithm/imitation/td3_bc.py @@ -1,13 +1,13 @@ 
import torch import torch.nn.functional as F -from tianshou.data import to_torch_as -from tianshou.data.types import RolloutBatchProtocol from tianshou.algorithm import TD3 from tianshou.algorithm.algorithm_base import OfflineAlgorithm from tianshou.algorithm.modelfree.ddpg import ContinuousDeterministicPolicy from tianshou.algorithm.modelfree.td3 import TD3TrainingStats from tianshou.algorithm.optim import OptimizerFactory +from tianshou.data import to_torch_as +from tianshou.data.types import RolloutBatchProtocol # NOTE: This uses diamond inheritance to convert from off-policy to offline diff --git a/tianshou/algorithm/modelbased/icm.py b/tianshou/algorithm/modelbased/icm.py index f40e17116..72a3beedf 100644 --- a/tianshou/algorithm/modelbased/icm.py +++ b/tianshou/algorithm/modelbased/icm.py @@ -2,9 +2,6 @@ import torch import torch.nn.functional as F -from tianshou.data import Batch, ReplayBuffer, to_numpy, to_torch -from tianshou.data.batch import BatchProtocol -from tianshou.data.types import RolloutBatchProtocol from tianshou.algorithm import Algorithm from tianshou.algorithm.algorithm_base import ( OffPolicyAlgorithm, @@ -16,6 +13,9 @@ TrainingStatsWrapper, ) from tianshou.algorithm.optim import OptimizerFactory +from tianshou.data import Batch, ReplayBuffer, to_numpy, to_torch +from tianshou.data.batch import BatchProtocol +from tianshou.data.types import RolloutBatchProtocol from tianshou.utils.net.discrete import IntrinsicCuriosityModule diff --git a/tianshou/algorithm/modelbased/psrl.py b/tianshou/algorithm/modelbased/psrl.py index c9c5296cf..bd13deb7a 100644 --- a/tianshou/algorithm/modelbased/psrl.py +++ b/tianshou/algorithm/modelbased/psrl.py @@ -5,14 +5,14 @@ import numpy as np import torch -from tianshou.data import Batch -from tianshou.data.batch import BatchProtocol -from tianshou.data.types import ActBatchProtocol, ObsBatchProtocol, RolloutBatchProtocol from tianshou.algorithm.algorithm_base import ( OnPolicyAlgorithm, Policy, TrainingStats, ) 
+from tianshou.data import Batch +from tianshou.data.batch import BatchProtocol +from tianshou.data.types import ActBatchProtocol, ObsBatchProtocol, RolloutBatchProtocol @dataclass(kw_only=True) diff --git a/tianshou/algorithm/modelfree/a2c.py b/tianshou/algorithm/modelfree/a2c.py index 0bff04489..f4f1d4b1e 100644 --- a/tianshou/algorithm/modelfree/a2c.py +++ b/tianshou/algorithm/modelfree/a2c.py @@ -6,14 +6,14 @@ import torch import torch.nn.functional as F -from tianshou.data import ReplayBuffer, SequenceSummaryStats, to_torch_as -from tianshou.data.types import BatchWithAdvantagesProtocol, RolloutBatchProtocol from tianshou.algorithm.algorithm_base import ( OnPolicyAlgorithm, TrainingStats, ) from tianshou.algorithm.modelfree.pg import ActorPolicyProbabilistic from tianshou.algorithm.optim import OptimizerFactory +from tianshou.data import ReplayBuffer, SequenceSummaryStats, to_torch_as +from tianshou.data.types import BatchWithAdvantagesProtocol, RolloutBatchProtocol from tianshou.utils import RunningMeanStd from tianshou.utils.net.common import ActorCritic from tianshou.utils.net.continuous import ContinuousCritic diff --git a/tianshou/algorithm/modelfree/bdqn.py b/tianshou/algorithm/modelfree/bdqn.py index 01707de58..dd88b7600 100644 --- a/tianshou/algorithm/modelfree/bdqn.py +++ b/tianshou/algorithm/modelfree/bdqn.py @@ -5,6 +5,13 @@ import torch from sensai.util.helper import mark_used +from tianshou.algorithm.algorithm_base import TArrOrActBatch +from tianshou.algorithm.modelfree.dqn import ( + DiscreteQLearningPolicy, + QLearningOffPolicyAlgorithm, +) +from tianshou.algorithm.modelfree.pg import SimpleLossTrainingStats +from tianshou.algorithm.optim import OptimizerFactory from tianshou.data import Batch, ReplayBuffer, to_numpy, to_torch, to_torch_as from tianshou.data.batch import BatchProtocol from tianshou.data.types import ( @@ -14,13 +21,6 @@ ObsBatchProtocol, RolloutBatchProtocol, ) -from tianshou.algorithm.algorithm_base import TArrOrActBatch -from 
tianshou.algorithm.modelfree.dqn import ( - DiscreteQLearningPolicy, - QLearningOffPolicyAlgorithm, -) -from tianshou.algorithm.modelfree.pg import SimpleLossTrainingStats -from tianshou.algorithm.optim import OptimizerFactory from tianshou.utils.net.common import BranchingNet mark_used(ActBatchProtocol) diff --git a/tianshou/algorithm/modelfree/c51.py b/tianshou/algorithm/modelfree/c51.py index 65b89a914..97d86386f 100644 --- a/tianshou/algorithm/modelfree/c51.py +++ b/tianshou/algorithm/modelfree/c51.py @@ -2,14 +2,14 @@ import numpy as np import torch -from tianshou.data import Batch, ReplayBuffer -from tianshou.data.types import RolloutBatchProtocol from tianshou.algorithm.modelfree.dqn import ( DiscreteQLearningPolicy, QLearningOffPolicyAlgorithm, ) from tianshou.algorithm.modelfree.pg import LossSequenceTrainingStats from tianshou.algorithm.optim import OptimizerFactory +from tianshou.data import Batch, ReplayBuffer +from tianshou.data.types import RolloutBatchProtocol from tianshou.utils.net.common import Net diff --git a/tianshou/algorithm/modelfree/ddpg.py b/tianshou/algorithm/modelfree/ddpg.py index 62868a2f4..127b515bd 100644 --- a/tianshou/algorithm/modelfree/ddpg.py +++ b/tianshou/algorithm/modelfree/ddpg.py @@ -8,16 +8,6 @@ import torch from sensai.util.helper import mark_used -from tianshou.data import Batch, ReplayBuffer -from tianshou.data.batch import BatchProtocol -from tianshou.data.types import ( - ActBatchProtocol, - ActStateBatchProtocol, - BatchWithReturnsProtocol, - ObsBatchProtocol, - RolloutBatchProtocol, -) -from tianshou.exploration import BaseNoise, GaussianNoise from tianshou.algorithm import Algorithm from tianshou.algorithm.algorithm_base import ( LaggedNetworkPolyakUpdateAlgorithmMixin, @@ -28,6 +18,16 @@ TrainingStats, ) from tianshou.algorithm.optim import OptimizerFactory +from tianshou.data import Batch, ReplayBuffer +from tianshou.data.batch import BatchProtocol +from tianshou.data.types import ( + ActBatchProtocol, + 
ActStateBatchProtocol, + BatchWithReturnsProtocol, + ObsBatchProtocol, + RolloutBatchProtocol, +) +from tianshou.exploration import BaseNoise, GaussianNoise from tianshou.utils.net.continuous import ( ContinuousActorDeterministicInterface, ContinuousCritic, diff --git a/tianshou/algorithm/modelfree/discrete_sac.py b/tianshou/algorithm/modelfree/discrete_sac.py index 16d211754..65094a191 100644 --- a/tianshou/algorithm/modelfree/discrete_sac.py +++ b/tianshou/algorithm/modelfree/discrete_sac.py @@ -6,6 +6,10 @@ import torch from torch.distributions import Categorical +from tianshou.algorithm.algorithm_base import Policy +from tianshou.algorithm.modelfree.sac import Alpha, SACTrainingStats +from tianshou.algorithm.modelfree.td3 import ActorDualCriticsOffPolicyAlgorithm +from tianshou.algorithm.optim import OptimizerFactory from tianshou.data import Batch, to_torch from tianshou.data.batch import BatchProtocol from tianshou.data.types import ( @@ -13,10 +17,6 @@ ObsBatchProtocol, RolloutBatchProtocol, ) -from tianshou.algorithm.algorithm_base import Policy -from tianshou.algorithm.modelfree.sac import Alpha, SACTrainingStats -from tianshou.algorithm.modelfree.td3 import ActorDualCriticsOffPolicyAlgorithm -from tianshou.algorithm.optim import OptimizerFactory from tianshou.utils.net.discrete import DiscreteCritic diff --git a/tianshou/algorithm/modelfree/dqn.py b/tianshou/algorithm/modelfree/dqn.py index 8e2042c21..570c1c1ed 100644 --- a/tianshou/algorithm/modelfree/dqn.py +++ b/tianshou/algorithm/modelfree/dqn.py @@ -7,15 +7,6 @@ import torch from sensai.util.helper import mark_used -from tianshou.data import Batch, ReplayBuffer, to_numpy, to_torch_as -from tianshou.data.batch import BatchProtocol -from tianshou.data.types import ( - ActBatchProtocol, - BatchWithReturnsProtocol, - ModelOutputBatchProtocol, - ObsBatchProtocol, - RolloutBatchProtocol, -) from tianshou.algorithm import Algorithm from tianshou.algorithm.algorithm_base import ( 
LaggedNetworkFullUpdateAlgorithmMixin, @@ -27,6 +18,15 @@ SimpleLossTrainingStats, ) from tianshou.algorithm.optim import OptimizerFactory +from tianshou.data import Batch, ReplayBuffer, to_numpy, to_torch_as +from tianshou.data.batch import BatchProtocol +from tianshou.data.types import ( + ActBatchProtocol, + BatchWithReturnsProtocol, + ModelOutputBatchProtocol, + ObsBatchProtocol, + RolloutBatchProtocol, +) from tianshou.utils.lagged_network import EvalModeModuleWrapper from tianshou.utils.net.common import Net diff --git a/tianshou/algorithm/modelfree/fqf.py b/tianshou/algorithm/modelfree/fqf.py index b2110f2eb..eed8d7c5a 100644 --- a/tianshou/algorithm/modelfree/fqf.py +++ b/tianshou/algorithm/modelfree/fqf.py @@ -7,13 +7,13 @@ import torch.nn.functional as F from overrides import override -from tianshou.data import Batch, ReplayBuffer, to_numpy -from tianshou.data.types import FQFBatchProtocol, ObsBatchProtocol, RolloutBatchProtocol from tianshou.algorithm import QRDQN, Algorithm from tianshou.algorithm.modelfree.dqn import DiscreteQLearningPolicy from tianshou.algorithm.modelfree.pg import SimpleLossTrainingStats from tianshou.algorithm.modelfree.qrdqn import QRDQNPolicy from tianshou.algorithm.optim import OptimizerFactory +from tianshou.data import Batch, ReplayBuffer, to_numpy +from tianshou.data.types import FQFBatchProtocol, ObsBatchProtocol, RolloutBatchProtocol from tianshou.utils.net.discrete import FractionProposalNetwork, FullQuantileFunction diff --git a/tianshou/algorithm/modelfree/iqn.py b/tianshou/algorithm/modelfree/iqn.py index 977e64797..cb6881996 100644 --- a/tianshou/algorithm/modelfree/iqn.py +++ b/tianshou/algorithm/modelfree/iqn.py @@ -5,6 +5,10 @@ import torch import torch.nn.functional as F +from tianshou.algorithm import QRDQN +from tianshou.algorithm.modelfree.pg import SimpleLossTrainingStats +from tianshou.algorithm.modelfree.qrdqn import QRDQNPolicy +from tianshou.algorithm.optim import OptimizerFactory from tianshou.data import 
Batch, to_numpy from tianshou.data.batch import BatchProtocol from tianshou.data.types import ( @@ -12,10 +16,6 @@ QuantileRegressionBatchProtocol, RolloutBatchProtocol, ) -from tianshou.algorithm import QRDQN -from tianshou.algorithm.modelfree.pg import SimpleLossTrainingStats -from tianshou.algorithm.modelfree.qrdqn import QRDQNPolicy -from tianshou.algorithm.optim import OptimizerFactory class IQNPolicy(QRDQNPolicy): diff --git a/tianshou/algorithm/modelfree/npg.py b/tianshou/algorithm/modelfree/npg.py index cc257466c..21c200ea5 100644 --- a/tianshou/algorithm/modelfree/npg.py +++ b/tianshou/algorithm/modelfree/npg.py @@ -7,12 +7,12 @@ from torch import nn from torch.distributions import kl_divergence -from tianshou.data import ReplayBuffer, SequenceSummaryStats, to_torch_as -from tianshou.data.types import BatchWithAdvantagesProtocol, RolloutBatchProtocol from tianshou.algorithm.algorithm_base import TrainingStats from tianshou.algorithm.modelfree.a2c import ActorCriticOnPolicyAlgorithm from tianshou.algorithm.modelfree.pg import ActorPolicyProbabilistic from tianshou.algorithm.optim import OptimizerFactory +from tianshou.data import ReplayBuffer, SequenceSummaryStats, to_torch_as +from tianshou.data.types import BatchWithAdvantagesProtocol, RolloutBatchProtocol from tianshou.utils.net.continuous import ContinuousCritic from tianshou.utils.net.discrete import DiscreteCritic diff --git a/tianshou/algorithm/modelfree/pg.py b/tianshou/algorithm/modelfree/pg.py index e82485b0c..0603524e3 100644 --- a/tianshou/algorithm/modelfree/pg.py +++ b/tianshou/algorithm/modelfree/pg.py @@ -8,6 +8,13 @@ import numpy as np import torch +from tianshou.algorithm import Algorithm +from tianshou.algorithm.algorithm_base import ( + OnPolicyAlgorithm, + Policy, + TrainingStats, +) +from tianshou.algorithm.optim import OptimizerFactory from tianshou.data import ( Batch, ReplayBuffer, @@ -22,13 +29,6 @@ ObsBatchProtocol, RolloutBatchProtocol, ) -from tianshou.algorithm import Algorithm 
-from tianshou.algorithm.algorithm_base import ( - OnPolicyAlgorithm, - Policy, - TrainingStats, -) -from tianshou.algorithm.optim import OptimizerFactory from tianshou.utils import RunningMeanStd from tianshou.utils.net.common import ( ContinuousActorProbabilisticInterface, diff --git a/tianshou/algorithm/modelfree/ppo.py b/tianshou/algorithm/modelfree/ppo.py index b47d9be3c..a81bf4ffc 100644 --- a/tianshou/algorithm/modelfree/ppo.py +++ b/tianshou/algorithm/modelfree/ppo.py @@ -3,12 +3,12 @@ import numpy as np import torch -from tianshou.data import ReplayBuffer, SequenceSummaryStats, to_torch_as -from tianshou.data.types import LogpOldProtocol, RolloutBatchProtocol from tianshou.algorithm import A2C from tianshou.algorithm.modelfree.a2c import A2CTrainingStats from tianshou.algorithm.modelfree.pg import ActorPolicyProbabilistic from tianshou.algorithm.optim import OptimizerFactory +from tianshou.data import ReplayBuffer, SequenceSummaryStats, to_torch_as +from tianshou.data.types import LogpOldProtocol, RolloutBatchProtocol from tianshou.utils.net.continuous import ContinuousCritic from tianshou.utils.net.discrete import DiscreteCritic diff --git a/tianshou/algorithm/modelfree/qrdqn.py b/tianshou/algorithm/modelfree/qrdqn.py index 5e0e9ebb7..883086d6a 100644 --- a/tianshou/algorithm/modelfree/qrdqn.py +++ b/tianshou/algorithm/modelfree/qrdqn.py @@ -5,14 +5,14 @@ import torch import torch.nn.functional as F -from tianshou.data import Batch, ReplayBuffer -from tianshou.data.types import RolloutBatchProtocol from tianshou.algorithm.modelfree.dqn import ( DiscreteQLearningPolicy, QLearningOffPolicyAlgorithm, ) from tianshou.algorithm.modelfree.pg import SimpleLossTrainingStats from tianshou.algorithm.optim import OptimizerFactory +from tianshou.data import Batch, ReplayBuffer +from tianshou.data.types import RolloutBatchProtocol class QRDQNPolicy(DiscreteQLearningPolicy): diff --git a/tianshou/algorithm/modelfree/rainbow.py b/tianshou/algorithm/modelfree/rainbow.py 
index ba9466a7e..6efce699e 100644 --- a/tianshou/algorithm/modelfree/rainbow.py +++ b/tianshou/algorithm/modelfree/rainbow.py @@ -2,10 +2,10 @@ from torch import nn -from tianshou.data.types import RolloutBatchProtocol from tianshou.algorithm.modelfree.c51 import C51, C51Policy from tianshou.algorithm.modelfree.pg import LossSequenceTrainingStats from tianshou.algorithm.optim import OptimizerFactory +from tianshou.data.types import RolloutBatchProtocol from tianshou.utils.lagged_network import EvalModeModuleWrapper from tianshou.utils.net.discrete import NoisyLinear diff --git a/tianshou/algorithm/modelfree/redq.py b/tianshou/algorithm/modelfree/redq.py index fa1ceb5a4..99f7deadc 100644 --- a/tianshou/algorithm/modelfree/redq.py +++ b/tianshou/algorithm/modelfree/redq.py @@ -6,13 +6,6 @@ import torch from torch.distributions import Independent, Normal -from tianshou.data import Batch -from tianshou.data.types import ( - DistLogProbBatchProtocol, - ObsBatchProtocol, - RolloutBatchProtocol, -) -from tianshou.exploration import BaseNoise from tianshou.algorithm.modelfree.ddpg import ( ActorCriticOffPolicyAlgorithm, ContinuousPolicyWithExplorationNoise, @@ -20,6 +13,13 @@ ) from tianshou.algorithm.modelfree.sac import Alpha from tianshou.algorithm.optim import OptimizerFactory +from tianshou.data import Batch +from tianshou.data.types import ( + DistLogProbBatchProtocol, + ObsBatchProtocol, + RolloutBatchProtocol, +) +from tianshou.exploration import BaseNoise from tianshou.utils.net.continuous import ContinuousActorProbabilistic diff --git a/tianshou/algorithm/modelfree/sac.py b/tianshou/algorithm/modelfree/sac.py index 62eaa824f..695e773ee 100644 --- a/tianshou/algorithm/modelfree/sac.py +++ b/tianshou/algorithm/modelfree/sac.py @@ -7,6 +7,10 @@ import torch from torch.distributions import Independent, Normal +from tianshou.algorithm.algorithm_base import TrainingStats +from tianshou.algorithm.modelfree.ddpg import ContinuousPolicyWithExplorationNoise +from 
tianshou.algorithm.modelfree.td3 import ActorDualCriticsOffPolicyAlgorithm +from tianshou.algorithm.optim import OptimizerFactory from tianshou.data import Batch from tianshou.data.types import ( DistLogProbBatchProtocol, @@ -14,10 +18,6 @@ RolloutBatchProtocol, ) from tianshou.exploration import BaseNoise -from tianshou.algorithm.algorithm_base import TrainingStats -from tianshou.algorithm.modelfree.ddpg import ContinuousPolicyWithExplorationNoise -from tianshou.algorithm.modelfree.td3 import ActorDualCriticsOffPolicyAlgorithm -from tianshou.algorithm.optim import OptimizerFactory from tianshou.utils.conversion import to_optional_float from tianshou.utils.net.continuous import ContinuousActorProbabilistic diff --git a/tianshou/algorithm/modelfree/td3.py b/tianshou/algorithm/modelfree/td3.py index faf7bf36b..0616c5d64 100644 --- a/tianshou/algorithm/modelfree/td3.py +++ b/tianshou/algorithm/modelfree/td3.py @@ -5,11 +5,6 @@ import torch -from tianshou.data import Batch -from tianshou.data.types import ( - ActStateBatchProtocol, - RolloutBatchProtocol, -) from tianshou.algorithm.algorithm_base import ( TPolicy, TrainingStats, @@ -20,6 +15,11 @@ TActBatchProtocol, ) from tianshou.algorithm.optim import OptimizerFactory +from tianshou.data import Batch +from tianshou.data.types import ( + ActStateBatchProtocol, + RolloutBatchProtocol, +) @dataclass(kw_only=True) diff --git a/tianshou/algorithm/modelfree/trpo.py b/tianshou/algorithm/modelfree/trpo.py index 71e5ba9d2..f30ef7fba 100644 --- a/tianshou/algorithm/modelfree/trpo.py +++ b/tianshou/algorithm/modelfree/trpo.py @@ -5,12 +5,12 @@ import torch.nn.functional as F from torch.distributions import kl_divergence -from tianshou.data import SequenceSummaryStats -from tianshou.data.types import BatchWithAdvantagesProtocol from tianshou.algorithm import NPG from tianshou.algorithm.modelfree.npg import NPGTrainingStats from tianshou.algorithm.modelfree.pg import ActorPolicyProbabilistic from tianshou.algorithm.optim import 
OptimizerFactory +from tianshou.data import SequenceSummaryStats +from tianshou.data.types import BatchWithAdvantagesProtocol from tianshou.utils.net.continuous import ContinuousCritic from tianshou.utils.net.discrete import DiscreteCritic diff --git a/tianshou/algorithm/multiagent/marl.py b/tianshou/algorithm/multiagent/marl.py index 29be11905..1c30a1cbc 100644 --- a/tianshou/algorithm/multiagent/marl.py +++ b/tianshou/algorithm/multiagent/marl.py @@ -6,9 +6,6 @@ from sensai.util.helper import mark_used from torch.nn import ModuleList -from tianshou.data import Batch, ReplayBuffer -from tianshou.data.batch import BatchProtocol, IndexType -from tianshou.data.types import ActBatchProtocol, ObsBatchProtocol, RolloutBatchProtocol from tianshou.algorithm import Algorithm from tianshou.algorithm.algorithm_base import ( OffPolicyAlgorithm, @@ -16,6 +13,9 @@ Policy, TrainingStats, ) +from tianshou.data import Batch, ReplayBuffer +from tianshou.data.batch import BatchProtocol, IndexType +from tianshou.data.types import ActBatchProtocol, ObsBatchProtocol, RolloutBatchProtocol try: from tianshou.env.pettingzoo_env import PettingZooEnv diff --git a/tianshou/algorithm/random.py b/tianshou/algorithm/random.py index 92da3fad6..b374ef301 100644 --- a/tianshou/algorithm/random.py +++ b/tianshou/algorithm/random.py @@ -3,11 +3,10 @@ import gymnasium as gym import numpy as np +from tianshou.algorithm.algorithm_base import OffPolicyAlgorithm, Policy, TrainingStats from tianshou.data import Batch from tianshou.data.batch import BatchProtocol from tianshou.data.types import ActBatchProtocol, ObsBatchProtocol, RolloutBatchProtocol -from tianshou.algorithm import algorithm_base -from tianshou.algorithm.algorithm_base import OffPolicyAlgorithm, TrainingStats, Policy class MARLRandomTrainingStats(TrainingStats): diff --git a/tianshou/data/collector.py b/tianshou/data/collector.py index 60200a307..bbf8d69d7 100644 --- a/tianshou/data/collector.py +++ b/tianshou/data/collector.py @@ -12,6 
+12,8 @@ from overrides import override from torch.distributions import Categorical, Distribution +from tianshou.algorithm import Algorithm +from tianshou.algorithm.algorithm_base import Policy, episode_mc_return_to_go from tianshou.data import ( Batch, CachedReplayBuffer, @@ -30,8 +32,6 @@ RolloutBatchProtocol, ) from tianshou.env import BaseVectorEnv, DummyVectorEnv -from tianshou.algorithm import Algorithm -from tianshou.algorithm.algorithm_base import Policy, episode_mc_return_to_go from tianshou.utils.determinism import TraceLogger from tianshou.utils.print import DataclassPPrintMixin from tianshou.utils.torch_utils import torch_train_mode diff --git a/tianshou/data/stats.py b/tianshou/data/stats.py index c56e1c075..4deda628f 100644 --- a/tianshou/data/stats.py +++ b/tianshou/data/stats.py @@ -8,8 +8,8 @@ from tianshou.utils.print import DataclassPPrintMixin if TYPE_CHECKING: - from tianshou.data import CollectStats, CollectStatsBase from tianshou.algorithm.algorithm_base import TrainingStats + from tianshou.data import CollectStats, CollectStatsBase log = logging.getLogger(__name__) diff --git a/tianshou/env/atari/atari_network.py b/tianshou/env/atari/atari_network.py index 6b2848a29..1d83bac6d 100644 --- a/tianshou/env/atari/atari_network.py +++ b/tianshou/env/atari/atari_network.py @@ -5,6 +5,7 @@ import torch from torch import nn +from tianshou.algorithm.modelfree.pg import TDistFnDiscrOrCont from tianshou.highlevel.env import Environments from tianshou.highlevel.module.actor import ActorFactory from tianshou.highlevel.module.core import ( @@ -15,7 +16,6 @@ IntermediateModuleFactory, ) from tianshou.highlevel.params.dist_fn import DistributionFunctionFactoryCategorical -from tianshou.algorithm.modelfree.pg import TDistFnDiscrOrCont from tianshou.utils.net.common import NetBase from tianshou.utils.net.discrete import DiscreteActor, NoisyLinear from tianshou.utils.torch_utils import torch_device diff --git a/tianshou/env/worker/__init__.py 
b/tianshou/env/worker/__init__.py index 43a7066b1..5e3d21235 100644 --- a/tianshou/env/worker/__init__.py +++ b/tianshou/env/worker/__init__.py @@ -1,3 +1,5 @@ +# isort:skip_file +# NOTE: Import order is important to avoid circular import errors! from tianshou.env.worker.worker_base import EnvWorker from tianshou.env.worker.dummy import DummyEnvWorker from tianshou.env.worker.ray import RayEnvWorker diff --git a/tianshou/highlevel/algorithm.py b/tianshou/highlevel/algorithm.py index 1ccdbed22..8a8ecce71 100644 --- a/tianshou/highlevel/algorithm.py +++ b/tianshou/highlevel/algorithm.py @@ -7,6 +7,33 @@ import torch from sensai.util.string import ToStringMixin +from tianshou.algorithm import ( + A2C, + DDPG, + DQN, + IQN, + NPG, + PPO, + REDQ, + SAC, + TD3, + TRPO, + Algorithm, + DiscreteSAC, + Reinforce, +) +from tianshou.algorithm.algorithm_base import ( + OffPolicyAlgorithm, + OnPolicyAlgorithm, + Policy, +) +from tianshou.algorithm.modelfree.ddpg import ContinuousDeterministicPolicy +from tianshou.algorithm.modelfree.discrete_sac import DiscreteSACPolicy +from tianshou.algorithm.modelfree.dqn import DiscreteQLearningPolicy +from tianshou.algorithm.modelfree.iqn import IQNPolicy +from tianshou.algorithm.modelfree.pg import ActorPolicyProbabilistic +from tianshou.algorithm.modelfree.redq import REDQPolicy +from tianshou.algorithm.modelfree.sac import SACPolicy from tianshou.data import Collector, ReplayBuffer, VectorReplayBuffer from tianshou.data.collector import BaseCollector, CollectStats from tianshou.highlevel.config import ( @@ -46,35 +73,13 @@ from tianshou.highlevel.persistence import PolicyPersistence from tianshou.highlevel.trainer import TrainerCallbacks, TrainingContext from tianshou.highlevel.world import World -from tianshou.algorithm import ( - A2C, - DDPG, - DQN, - IQN, - NPG, - PPO, - REDQ, - SAC, - TD3, - TRPO, - Algorithm, - DiscreteSAC, - Reinforce, +from tianshou.trainer import ( + OffPolicyTrainer, + OffPolicyTrainerParams, + OnPolicyTrainer, 
+ OnPolicyTrainerParams, + Trainer, ) -from tianshou.algorithm.algorithm_base import ( - OffPolicyAlgorithm, - OnPolicyAlgorithm, - Policy, -) -from tianshou.algorithm.modelfree.ddpg import ContinuousDeterministicPolicy -from tianshou.algorithm.modelfree.discrete_sac import DiscreteSACPolicy -from tianshou.algorithm.modelfree.dqn import DiscreteQLearningPolicy -from tianshou.algorithm.modelfree.iqn import IQNPolicy -from tianshou.algorithm.modelfree.pg import ActorPolicyProbabilistic -from tianshou.algorithm.modelfree.redq import REDQPolicy -from tianshou.algorithm.modelfree.sac import SACPolicy -from tianshou.trainer import OffPolicyTrainer, OnPolicyTrainer, Trainer -from tianshou.trainer import OffPolicyTrainerParams, OnPolicyTrainerParams from tianshou.utils.net.discrete import DiscreteActor CHECKPOINT_DICT_KEY_MODEL = "model" diff --git a/tianshou/highlevel/experiment.py b/tianshou/highlevel/experiment.py index 3126fa7c4..d5debd788 100644 --- a/tianshou/highlevel/experiment.py +++ b/tianshou/highlevel/experiment.py @@ -36,6 +36,7 @@ from sensai.util.logging import datetime_tag from sensai.util.string import ToStringMixin +from tianshou.algorithm import Algorithm from tianshou.data import BaseCollector, Collector, CollectStats, InfoStats from tianshou.env import BaseVectorEnv from tianshou.highlevel.algorithm import ( @@ -112,7 +113,6 @@ TrainerCallbacks, ) from tianshou.highlevel.world import World -from tianshou.algorithm import Algorithm from tianshou.utils import LazyLogger from tianshou.utils.net.common import ModuleType from tianshou.utils.print import DataclassPPrintMixin diff --git a/tianshou/highlevel/module/actor.py b/tianshou/highlevel/module/actor.py index 0a6055a5f..5373c7a4f 100644 --- a/tianshou/highlevel/module/actor.py +++ b/tianshou/highlevel/module/actor.py @@ -8,6 +8,7 @@ from sensai.util.string import ToStringMixin from torch import nn +from tianshou.algorithm.modelfree.pg import TDistFnDiscrOrCont from tianshou.highlevel.env import 
Environments, EnvType from tianshou.highlevel.module.core import ( ModuleFactory, @@ -22,7 +23,6 @@ DistributionFunctionFactoryCategorical, DistributionFunctionFactoryIndependentGaussians, ) -from tianshou.algorithm.modelfree.pg import TDistFnDiscrOrCont from tianshou.utils.net import continuous, discrete from tianshou.utils.net.common import Actor, ModuleType, ModuleWithVectorOutput, Net diff --git a/tianshou/highlevel/params/alpha.py b/tianshou/highlevel/params/alpha.py index 7bd1a3baf..28f13511e 100644 --- a/tianshou/highlevel/params/alpha.py +++ b/tianshou/highlevel/params/alpha.py @@ -3,10 +3,10 @@ import numpy as np from sensai.util.string import ToStringMixin +from tianshou.algorithm.modelfree.sac import Alpha, AutoAlpha from tianshou.highlevel.env import Environments from tianshou.highlevel.module.core import TDevice from tianshou.highlevel.optim import OptimizerFactoryFactory -from tianshou.algorithm.modelfree.sac import Alpha, AutoAlpha class AutoAlphaFactory(ToStringMixin, ABC): diff --git a/tianshou/highlevel/params/dist_fn.py b/tianshou/highlevel/params/dist_fn.py index 11e37af2a..b55f33b30 100644 --- a/tianshou/highlevel/params/dist_fn.py +++ b/tianshou/highlevel/params/dist_fn.py @@ -5,8 +5,8 @@ import torch from sensai.util.string import ToStringMixin -from tianshou.highlevel.env import Environments from tianshou.algorithm.modelfree.pg import TDistFnDiscrete, TDistFnDiscrOrCont +from tianshou.highlevel.env import Environments class DistributionFunctionFactory(ToStringMixin, ABC): diff --git a/tianshou/highlevel/params/lr_scheduler.py b/tianshou/highlevel/params/lr_scheduler.py index fb9652e5f..1d99eca68 100644 --- a/tianshou/highlevel/params/lr_scheduler.py +++ b/tianshou/highlevel/params/lr_scheduler.py @@ -2,8 +2,8 @@ from sensai.util.string import ToStringMixin -from tianshou.highlevel.config import TrainingConfig from tianshou.algorithm.optim import LRSchedulerFactory, LRSchedulerFactoryLinear +from tianshou.highlevel.config import 
TrainingConfig class LRSchedulerFactoryFactory(ToStringMixin, ABC): diff --git a/tianshou/highlevel/params/policy_wrapper.py b/tianshou/highlevel/params/policy_wrapper.py index 21457f10b..9ff957d49 100644 --- a/tianshou/highlevel/params/policy_wrapper.py +++ b/tianshou/highlevel/params/policy_wrapper.py @@ -4,13 +4,13 @@ from sensai.util.string import ToStringMixin +from tianshou.algorithm import Algorithm, ICMOffPolicyWrapper +from tianshou.algorithm.algorithm_base import OffPolicyAlgorithm, OnPolicyAlgorithm +from tianshou.algorithm.modelbased.icm import ICMOnPolicyWrapper from tianshou.highlevel.env import Environments from tianshou.highlevel.module.core import TDevice from tianshou.highlevel.module.intermediate import IntermediateModuleFactory from tianshou.highlevel.optim import OptimizerFactoryFactory -from tianshou.algorithm import Algorithm, ICMOffPolicyWrapper -from tianshou.algorithm.algorithm_base import OffPolicyAlgorithm, OnPolicyAlgorithm -from tianshou.algorithm.modelbased.icm import ICMOnPolicyWrapper from tianshou.utils.net.discrete import IntrinsicCuriosityModule TAlgorithmOut = TypeVar("TAlgorithmOut", bound=Algorithm) diff --git a/tianshou/highlevel/trainer.py b/tianshou/highlevel/trainer.py index bc4c8cf62..452232d08 100644 --- a/tianshou/highlevel/trainer.py +++ b/tianshou/highlevel/trainer.py @@ -6,10 +6,10 @@ from sensai.util.string import ToStringMixin -from tianshou.highlevel.env import Environments -from tianshou.highlevel.logger import TLogger from tianshou.algorithm import DQN, Algorithm from tianshou.algorithm.modelfree.dqn import DiscreteQLearningPolicy +from tianshou.highlevel.env import Environments +from tianshou.highlevel.logger import TLogger TAlgorithm = TypeVar("TAlgorithm", bound=Algorithm) log = logging.getLogger(__name__) diff --git a/tianshou/highlevel/world.py b/tianshou/highlevel/world.py index bd92178bf..7f3521773 100644 --- a/tianshou/highlevel/world.py +++ b/tianshou/highlevel/world.py @@ -3,10 +3,10 @@ from typing 
import TYPE_CHECKING, Optional if TYPE_CHECKING: + from tianshou.algorithm import Algorithm from tianshou.data import BaseCollector from tianshou.highlevel.env import Environments from tianshou.highlevel.logger import TLogger - from tianshou.algorithm import Algorithm from tianshou.trainer import Trainer diff --git a/tianshou/trainer/trainer.py b/tianshou/trainer/trainer.py index f8c651d82..43a29ffa7 100644 --- a/tianshou/trainer/trainer.py +++ b/tianshou/trainer/trainer.py @@ -36,6 +36,13 @@ from sensai.util.helper import count_none from sensai.util.string import ToStringMixin +from tianshou.algorithm.algorithm_base import ( + Algorithm, + OfflineAlgorithm, + OffPolicyAlgorithm, + OnPolicyAlgorithm, + TrainingStats, +) from tianshou.data import ( AsyncCollector, CollectStats, @@ -47,13 +54,6 @@ ) from tianshou.data.buffer.buffer_base import MalformedBufferError from tianshou.data.collector import BaseCollector, CollectStatsBase -from tianshou.algorithm.algorithm_base import ( - Algorithm, - OfflineAlgorithm, - OffPolicyAlgorithm, - OnPolicyAlgorithm, - TrainingStats, -) from tianshou.utils import ( BaseLogger, LazyLogger, From 8f8ea0bd15b49b1a275616913f76208bf6e853d3 Mon Sep 17 00:00:00 2001 From: Dominik Jain Date: Fri, 16 May 2025 00:31:34 +0200 Subject: [PATCH 177/230] v2: Mention renamed packages and modules in change log --- CHANGELOG.md | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 82edf3c8c..f2803cbaa 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -75,7 +75,8 @@ Developers: ### Algorithms and Policies * We now conceptually differentiate between the learning algorithm and the policy being optimised: - * The abstraction `BasePolicy` is thus replaced by `Algorithm` and `Policy`. + * The abstraction `BasePolicy` is thus replaced by `Algorithm` and `Policy`, and the package was renamed + from `tianshou.policy` to `tianshou.algorithm`. 
* Migration information: The instantiation of a policy is replaced by the instantiation of an `Algorithm`, which is passed a `Policy`. In most cases, the former policy class name `Policy` is replaced by algorithm class ``; exceptions are noted below. @@ -204,7 +205,8 @@ Developers: dimension as an argument were changed to use `ModuleWithVectorOutput`. * The high-level API class `IntermediateModule` can now provide a `ModuleWithVectorOutput` instance (via adaptation if necessary). - +* All modules containing base classes were renamed from `base` to a more descriptive name, rendering + file names unique. ## Upcoming Release 1.2.0 From 29887a72376399ab3f9f535c202890c8e2d560d4 Mon Sep 17 00:00:00 2001 From: Dominik Jain Date: Fri, 16 May 2025 00:54:21 +0200 Subject: [PATCH 178/230] v2: Rename parameter estimation_step -> n_step_return_horizon (more precise naming) --- CHANGELOG.md | 3 ++- README.md | 2 +- examples/atari/atari_c51.py | 2 +- examples/atari/atari_dqn.py | 2 +- examples/atari/atari_dqn_hl.py | 2 +- examples/atari/atari_fqf.py | 2 +- examples/atari/atari_iqn.py | 2 +- examples/atari/atari_iqn_hl.py | 2 +- examples/atari/atari_qrdqn.py | 2 +- examples/atari/atari_rainbow.py | 2 +- examples/atari/atari_sac.py | 2 +- examples/atari/atari_sac_hl.py | 2 +- examples/box2d/acrobot_dualdqn.py | 2 +- examples/box2d/bipedal_hardcore_sac.py | 2 +- examples/box2d/lunarlander_dqn.py | 2 +- examples/discrete/discrete_dqn.py | 2 +- examples/discrete/discrete_dqn_hl.py | 2 +- examples/mujoco/fetch_her_ddpg.py | 2 +- examples/mujoco/mujoco_ddpg.py | 2 +- examples/mujoco/mujoco_ddpg_hl.py | 2 +- examples/mujoco/mujoco_redq.py | 2 +- examples/mujoco/mujoco_redq_hl.py | 2 +- examples/mujoco/mujoco_sac.py | 2 +- examples/mujoco/mujoco_sac_hl.py | 2 +- examples/mujoco/mujoco_td3.py | 2 +- examples/mujoco/mujoco_td3_hl.py | 2 +- examples/offline/atari_bcq.py | 2 +- examples/offline/atari_cql.py | 2 +- examples/offline/d4rl_td3_bc.py | 2 +- examples/vizdoom/vizdoom_c51.py | 2 +- 
test/continuous/test_ddpg.py | 2 +- test/continuous/test_redq.py | 2 +- test/continuous/test_sac_with_il.py | 2 +- test/continuous/test_td3.py | 2 +- test/discrete/test_c51.py | 2 +- test/discrete/test_discrete_sac.py | 2 +- test/discrete/test_dqn.py | 2 +- test/discrete/test_drqn.py | 2 +- test/discrete/test_fqf.py | 2 +- test/discrete/test_iqn.py | 2 +- test/discrete/test_qrdqn.py | 2 +- test/discrete/test_rainbow.py | 2 +- test/modelbased/test_dqn_icm.py | 2 +- test/offline/gather_cartpole_data.py | 2 +- test/offline/gather_pendulum_data.py | 2 +- test/offline/test_discrete_bcq.py | 2 +- test/offline/test_discrete_cql.py | 2 +- test/offline/test_td3_bc.py | 2 +- test/pettingzoo/pistonball.py | 2 +- test/pettingzoo/tic_tac_toe.py | 2 +- tianshou/algorithm/imitation/discrete_bcq.py | 12 ++++++------ tianshou/algorithm/imitation/discrete_cql.py | 6 +++--- tianshou/algorithm/imitation/td3_bc.py | 4 ++-- tianshou/algorithm/modelfree/bdqn.py | 10 +++++----- tianshou/algorithm/modelfree/c51.py | 6 +++--- tianshou/algorithm/modelfree/ddpg.py | 12 ++++++------ tianshou/algorithm/modelfree/discrete_sac.py | 6 +++--- tianshou/algorithm/modelfree/dqn.py | 16 ++++++++-------- tianshou/algorithm/modelfree/fqf.py | 6 +++--- tianshou/algorithm/modelfree/iqn.py | 6 +++--- tianshou/algorithm/modelfree/qrdqn.py | 6 +++--- tianshou/algorithm/modelfree/rainbow.py | 6 +++--- tianshou/algorithm/modelfree/redq.py | 6 +++--- tianshou/algorithm/modelfree/sac.py | 6 +++--- tianshou/algorithm/modelfree/td3.py | 8 ++++---- tianshou/highlevel/params/algorithm_params.py | 13 +++++++------ 66 files changed, 116 insertions(+), 114 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index f2803cbaa..aa594619e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -109,6 +109,7 @@ Developers: * removed from Q-learning algorithms, where it was actually unsupported (DQN, C561, etc.) 
* `clip_grad` -> `max_grad_norm` (for consistency) * `clip_loss_grad` -> `huber_loss_delta` (allowing to control not only the use of the Huber loss but also its essential parameter) + * `estimation_step` -> `n_step_return_horizon` (more precise naming) * Internal design improvements: * Introduced an abstraction for the alpha parameter (coefficient of the entropy term) in `SAC`, `DiscreteSAC` and other algorithms. @@ -146,7 +147,7 @@ Developers: * Remove parameters `clip_loss_grad` and `is_double` (unused; only passed on to former base class) * `CQL`: * Inherit directly from `OfflineAlgorithm` instead of `SAC` (off-policy). - * Remove parameter `estimation_step`, which was not actually used (it was only passed it on to its + * Remove parameter `estimation_step` (now `n_step_return_horizon`), which was not actually used (it was only passed it on to its superclass). * `DiscreteBCQ`: * Inherit directly from `OfflineAlgorithm` instead of `DQN` diff --git a/README.md b/README.md index 0dc181d94..7c6275fdc 100644 --- a/README.md +++ b/README.md @@ -262,7 +262,7 @@ experiment = ( DQNParams( lr=1e-3, discount_factor=0.9, - estimation_step=3, + n_step_return_horizon=3, target_update_freq=320, eps_training=0.3, eps_inference=0.0, diff --git a/examples/atari/atari_c51.py b/examples/atari/atari_c51.py index 9b653e909..990073349 100644 --- a/examples/atari/atari_c51.py +++ b/examples/atari/atari_c51.py @@ -105,7 +105,7 @@ def main(args: argparse.Namespace = get_args()) -> None: policy=policy, optim=optim, gamma=args.gamma, - estimation_step=args.n_step, + n_step_return_horizon=args.n_step, target_update_freq=args.target_update_freq, ).to(args.device) diff --git a/examples/atari/atari_dqn.py b/examples/atari/atari_dqn.py index d798f469c..b8242163a 100644 --- a/examples/atari/atari_dqn.py +++ b/examples/atari/atari_dqn.py @@ -121,7 +121,7 @@ def main(args: argparse.Namespace = get_args()) -> None: policy=policy, optim=optim, gamma=args.gamma, - estimation_step=args.n_step, + 
n_step_return_horizon=args.n_step, target_update_freq=args.target_update_freq, ) if args.icm_lr_scale > 0: diff --git a/examples/atari/atari_dqn_hl.py b/examples/atari/atari_dqn_hl.py index 77bae646b..c9fcffa52 100644 --- a/examples/atari/atari_dqn_hl.py +++ b/examples/atari/atari_dqn_hl.py @@ -78,7 +78,7 @@ def main( .with_dqn_params( DQNParams( gamma=gamma, - estimation_step=n_step, + n_step_return_horizon=n_step, lr=lr, target_update_freq=target_update_freq, ), diff --git a/examples/atari/atari_fqf.py b/examples/atari/atari_fqf.py index 6aabc1171..86c44ac3e 100644 --- a/examples/atari/atari_fqf.py +++ b/examples/atari/atari_fqf.py @@ -117,7 +117,7 @@ def main(args: argparse.Namespace = get_args()) -> None: gamma=args.gamma, num_fractions=args.num_fractions, ent_coef=args.ent_coef, - estimation_step=args.n_step, + n_step_return_horizon=args.n_step, target_update_freq=args.target_update_freq, ).to(args.device) diff --git a/examples/atari/atari_iqn.py b/examples/atari/atari_iqn.py index 47e8af2e5..6fa661296 100644 --- a/examples/atari/atari_iqn.py +++ b/examples/atari/atari_iqn.py @@ -115,7 +115,7 @@ def main(args: argparse.Namespace = get_args()) -> None: policy=policy, optim=optim, gamma=args.gamma, - estimation_step=args.n_step, + n_step_return_horizon=args.n_step, target_update_freq=args.target_update_freq, ).to(args.device) diff --git a/examples/atari/atari_iqn_hl.py b/examples/atari/atari_iqn_hl.py index 141bdf215..8aa335cea 100644 --- a/examples/atari/atari_iqn_hl.py +++ b/examples/atari/atari_iqn_hl.py @@ -76,7 +76,7 @@ def main( .with_iqn_params( IQNParams( gamma=gamma, - estimation_step=n_step, + n_step_return_horizon=n_step, lr=lr, sample_size=sample_size, online_sample_size=online_sample_size, diff --git a/examples/atari/atari_qrdqn.py b/examples/atari/atari_qrdqn.py index 28679b9e8..bd0c426ff 100644 --- a/examples/atari/atari_qrdqn.py +++ b/examples/atari/atari_qrdqn.py @@ -109,7 +109,7 @@ def main(args: argparse.Namespace = get_args()) -> None: 
optim=optim, gamma=args.gamma, num_quantiles=args.num_quantiles, - estimation_step=args.n_step, + n_step_return_horizon=args.n_step, target_update_freq=args.target_update_freq, ).to(args.device) diff --git a/examples/atari/atari_rainbow.py b/examples/atari/atari_rainbow.py index f62123d1f..64d5f9856 100644 --- a/examples/atari/atari_rainbow.py +++ b/examples/atari/atari_rainbow.py @@ -129,7 +129,7 @@ def test_rainbow(args: argparse.Namespace = get_args()) -> None: policy=policy, optim=optim, gamma=args.gamma, - estimation_step=args.n_step, + n_step_return_horizon=args.n_step, target_update_freq=args.target_update_freq, ).to(args.device) diff --git a/examples/atari/atari_sac.py b/examples/atari/atari_sac.py index 79b51ea65..d57608d18 100644 --- a/examples/atari/atari_sac.py +++ b/examples/atari/atari_sac.py @@ -149,7 +149,7 @@ def test_discrete_sac(args: argparse.Namespace = get_args()) -> None: tau=args.tau, gamma=args.gamma, alpha=args.alpha, - estimation_step=args.n_step, + n_step_return_horizon=args.n_step, ).to(args.device) if args.icm_lr_scale > 0: c, h, w = args.state_shape diff --git a/examples/atari/atari_sac_hl.py b/examples/atari/atari_sac_hl.py index 75da41456..429f5360d 100644 --- a/examples/atari/atari_sac_hl.py +++ b/examples/atari/atari_sac_hl.py @@ -85,7 +85,7 @@ def main( alpha=AutoAlphaFactoryDefault(lr=alpha_lr, target_entropy_coefficient=0.98) if auto_alpha else alpha, - estimation_step=n_step, + n_step_return_horizon=n_step, ), ) .with_actor_factory(ActorFactoryAtariDQN(scale_obs=False, features_only=True)) diff --git a/examples/box2d/acrobot_dualdqn.py b/examples/box2d/acrobot_dualdqn.py index d9c0464c4..18f833259 100644 --- a/examples/box2d/acrobot_dualdqn.py +++ b/examples/box2d/acrobot_dualdqn.py @@ -86,7 +86,7 @@ def test_dqn(args: argparse.Namespace = get_args()) -> None: policy=policy, optim=optim, gamma=args.gamma, - estimation_step=args.n_step, + n_step_return_horizon=args.n_step, target_update_freq=args.target_update_freq, 
).to(args.device) # collector diff --git a/examples/box2d/bipedal_hardcore_sac.py b/examples/box2d/bipedal_hardcore_sac.py index 1fa5be9d1..382ca01dd 100644 --- a/examples/box2d/bipedal_hardcore_sac.py +++ b/examples/box2d/bipedal_hardcore_sac.py @@ -157,7 +157,7 @@ def test_sac_bipedal(args: argparse.Namespace = get_args()) -> None: tau=args.tau, gamma=args.gamma, alpha=args.alpha, - estimation_step=args.n_step, + n_step_return_horizon=args.n_step, ) # load a previous policy if args.resume_path: diff --git a/examples/box2d/lunarlander_dqn.py b/examples/box2d/lunarlander_dqn.py index 010a48400..8d142eb2e 100644 --- a/examples/box2d/lunarlander_dqn.py +++ b/examples/box2d/lunarlander_dqn.py @@ -88,7 +88,7 @@ def test_dqn(args: argparse.Namespace = get_args()) -> None: policy=policy, optim=optim, gamma=args.gamma, - estimation_step=args.n_step, + n_step_return_horizon=args.n_step, target_update_freq=args.target_update_freq, ) # collector diff --git a/examples/discrete/discrete_dqn.py b/examples/discrete/discrete_dqn.py index 5194124a4..38f42007f 100644 --- a/examples/discrete/discrete_dqn.py +++ b/examples/discrete/discrete_dqn.py @@ -44,7 +44,7 @@ def main() -> None: policy=policy, optim=optim, gamma=gamma, - estimation_step=n_step, + n_step_return_horizon=n_step, target_update_freq=target_freq, ) train_collector = ts.data.Collector[CollectStats]( diff --git a/examples/discrete/discrete_dqn_hl.py b/examples/discrete/discrete_dqn_hl.py index 7f8ce777e..a7c05a6f6 100644 --- a/examples/discrete/discrete_dqn_hl.py +++ b/examples/discrete/discrete_dqn_hl.py @@ -42,7 +42,7 @@ def main() -> None: DQNParams( lr=1e-3, gamma=0.9, - estimation_step=3, + n_step_return_horizon=3, target_update_freq=320, eps_training=0.3, eps_inference=0.0, diff --git a/examples/mujoco/fetch_her_ddpg.py b/examples/mujoco/fetch_her_ddpg.py index 8517e3153..9a3812e96 100644 --- a/examples/mujoco/fetch_her_ddpg.py +++ b/examples/mujoco/fetch_her_ddpg.py @@ -182,7 +182,7 @@ def test_ddpg(args: 
argparse.Namespace = get_args()) -> None: critic_optim=critic_optim, tau=args.tau, gamma=args.gamma, - estimation_step=args.n_step, + n_step_return_horizon=args.n_step, ) # load a previous policy diff --git a/examples/mujoco/mujoco_ddpg.py b/examples/mujoco/mujoco_ddpg.py index 3cadc2034..066824d7b 100755 --- a/examples/mujoco/mujoco_ddpg.py +++ b/examples/mujoco/mujoco_ddpg.py @@ -112,7 +112,7 @@ def main(args: argparse.Namespace = get_args()) -> None: critic_optim=critic_optim, tau=args.tau, gamma=args.gamma, - estimation_step=args.n_step, + n_step_return_horizon=args.n_step, ) # load a previous policy diff --git a/examples/mujoco/mujoco_ddpg_hl.py b/examples/mujoco/mujoco_ddpg_hl.py index 4ed64177f..06c0588a6 100644 --- a/examples/mujoco/mujoco_ddpg_hl.py +++ b/examples/mujoco/mujoco_ddpg_hl.py @@ -67,7 +67,7 @@ def main( gamma=gamma, tau=tau, exploration_noise=MaxActionScaledGaussian(exploration_noise), - estimation_step=n_step, + n_step_return_horizon=n_step, ), ) .with_actor_factory_default(hidden_sizes) diff --git a/examples/mujoco/mujoco_redq.py b/examples/mujoco/mujoco_redq.py index 58bf51fc4..9d905081e 100755 --- a/examples/mujoco/mujoco_redq.py +++ b/examples/mujoco/mujoco_redq.py @@ -135,7 +135,7 @@ def linear(x: int, y: int) -> EnsembleLinear: tau=args.tau, gamma=args.gamma, alpha=args.alpha, - estimation_step=args.n_step, + n_step_return_horizon=args.n_step, actor_delay=args.update_per_step, target_mode=args.target_mode, ) diff --git a/examples/mujoco/mujoco_redq_hl.py b/examples/mujoco/mujoco_redq_hl.py index dafcadacc..c35e0abb0 100644 --- a/examples/mujoco/mujoco_redq_hl.py +++ b/examples/mujoco/mujoco_redq_hl.py @@ -73,7 +73,7 @@ def main( gamma=gamma, tau=tau, alpha=AutoAlphaFactoryDefault(lr=alpha_lr) if auto_alpha else alpha, - estimation_step=n_step, + n_step_return_horizon=n_step, target_mode=target_mode, subset_size=subset_size, ensemble_size=ensemble_size, diff --git a/examples/mujoco/mujoco_sac.py b/examples/mujoco/mujoco_sac.py index 
bc45e8358..64fe6a60b 100755 --- a/examples/mujoco/mujoco_sac.py +++ b/examples/mujoco/mujoco_sac.py @@ -130,7 +130,7 @@ def main(args: argparse.Namespace = get_args()) -> None: tau=args.tau, gamma=args.gamma, alpha=args.alpha, - estimation_step=args.n_step, + n_step_return_horizon=args.n_step, ) # load a previous policy diff --git a/examples/mujoco/mujoco_sac_hl.py b/examples/mujoco/mujoco_sac_hl.py index da5f338da..6c3c8352d 100644 --- a/examples/mujoco/mujoco_sac_hl.py +++ b/examples/mujoco/mujoco_sac_hl.py @@ -67,7 +67,7 @@ def main( tau=tau, gamma=gamma, alpha=AutoAlphaFactoryDefault(lr=alpha_lr) if auto_alpha else alpha, - estimation_step=n_step, + n_step_return_horizon=n_step, actor_lr=actor_lr, critic1_lr=critic_lr, critic2_lr=critic_lr, diff --git a/examples/mujoco/mujoco_td3.py b/examples/mujoco/mujoco_td3.py index 524087006..5e838a9ba 100755 --- a/examples/mujoco/mujoco_td3.py +++ b/examples/mujoco/mujoco_td3.py @@ -131,7 +131,7 @@ def main(args: argparse.Namespace = get_args()) -> None: policy_noise=args.policy_noise, update_actor_freq=args.update_actor_freq, noise_clip=args.noise_clip, - estimation_step=args.n_step, + n_step_return_horizon=args.n_step, ) # load a previous policy diff --git a/examples/mujoco/mujoco_td3_hl.py b/examples/mujoco/mujoco_td3_hl.py index 717e1d8d5..2fd3e731c 100644 --- a/examples/mujoco/mujoco_td3_hl.py +++ b/examples/mujoco/mujoco_td3_hl.py @@ -71,7 +71,7 @@ def main( TD3Params( tau=tau, gamma=gamma, - estimation_step=n_step, + n_step_return_horizon=n_step, update_actor_freq=update_actor_freq, noise_clip=MaxActionScaled(noise_clip), policy_noise=MaxActionScaled(policy_noise), diff --git a/examples/offline/atari_bcq.py b/examples/offline/atari_bcq.py index 415cf2df9..9c8921071 100644 --- a/examples/offline/atari_bcq.py +++ b/examples/offline/atari_bcq.py @@ -129,7 +129,7 @@ def main(args: argparse.Namespace = get_args()) -> None: policy=policy, optim=optim, gamma=args.gamma, - estimation_step=args.n_step, + 
n_step_return_horizon=args.n_step, target_update_freq=args.target_update_freq, imitation_logits_penalty=args.imitation_logits_penalty, ) diff --git a/examples/offline/atari_cql.py b/examples/offline/atari_cql.py index edad3a9cc..e0cfbeee1 100644 --- a/examples/offline/atari_cql.py +++ b/examples/offline/atari_cql.py @@ -116,7 +116,7 @@ def main(args: argparse.Namespace = get_args()) -> None: optim=optim, gamma=args.gamma, num_quantiles=args.num_quantiles, - estimation_step=args.n_step, + n_step_return_horizon=args.n_step, target_update_freq=args.target_update_freq, min_q_weight=args.min_q_weight, ).to(args.device) diff --git a/examples/offline/d4rl_td3_bc.py b/examples/offline/d4rl_td3_bc.py index e7d6609ae..3242b780b 100644 --- a/examples/offline/d4rl_td3_bc.py +++ b/examples/offline/d4rl_td3_bc.py @@ -151,7 +151,7 @@ def test_td3_bc() -> None: update_actor_freq=args.update_actor_freq, noise_clip=args.noise_clip, alpha=args.alpha, - estimation_step=args.n_step, + n_step_return_horizon=args.n_step, ) # load a previous policy diff --git a/examples/vizdoom/vizdoom_c51.py b/examples/vizdoom/vizdoom_c51.py index 9f1e07c5d..185739c67 100644 --- a/examples/vizdoom/vizdoom_c51.py +++ b/examples/vizdoom/vizdoom_c51.py @@ -111,7 +111,7 @@ def test_c51(args: argparse.Namespace = get_args()) -> None: policy=policy, optim=optim, gamma=args.gamma, - estimation_step=args.n_step, + n_step_return_horizon=args.n_step, target_update_freq=args.target_update_freq, ).to(args.device) # load a previous policy diff --git a/test/continuous/test_ddpg.py b/test/continuous/test_ddpg.py index 4de6c7dc2..b6ec0425c 100644 --- a/test/continuous/test_ddpg.py +++ b/test/continuous/test_ddpg.py @@ -101,7 +101,7 @@ def test_ddpg(args: argparse.Namespace = get_args(), enable_assertions: bool = T critic_optim=critic_optim, tau=args.tau, gamma=args.gamma, - estimation_step=args.n_step, + n_step_return_horizon=args.n_step, ) # collector diff --git a/test/continuous/test_redq.py 
b/test/continuous/test_redq.py index d47c7a559..f8b2b4407 100644 --- a/test/continuous/test_redq.py +++ b/test/continuous/test_redq.py @@ -127,7 +127,7 @@ def linear(x: int, y: int) -> nn.Module: tau=args.tau, gamma=args.gamma, alpha=args.alpha, - estimation_step=args.n_step, + n_step_return_horizon=args.n_step, actor_delay=args.update_per_step, target_mode=args.target_mode, ) diff --git a/test/continuous/test_sac_with_il.py b/test/continuous/test_sac_with_il.py index a1a0f60cc..b9df78c5a 100644 --- a/test/continuous/test_sac_with_il.py +++ b/test/continuous/test_sac_with_il.py @@ -135,7 +135,7 @@ def test_sac_with_il( tau=args.tau, gamma=args.gamma, alpha=args.alpha, - estimation_step=args.n_step, + n_step_return_horizon=args.n_step, ) # collector train_collector = Collector[CollectStats]( diff --git a/test/continuous/test_td3.py b/test/continuous/test_td3.py index fabae724d..8aaf83797 100644 --- a/test/continuous/test_td3.py +++ b/test/continuous/test_td3.py @@ -117,7 +117,7 @@ def test_td3(args: argparse.Namespace = get_args(), enable_assertions: bool = Tr policy_noise=args.policy_noise, update_actor_freq=args.update_actor_freq, noise_clip=args.noise_clip, - estimation_step=args.n_step, + n_step_return_horizon=args.n_step, ) # collector train_collector = Collector[CollectStats]( diff --git a/test/discrete/test_c51.py b/test/discrete/test_c51.py index 74e289e0d..1eef06ea2 100644 --- a/test/discrete/test_c51.py +++ b/test/discrete/test_c51.py @@ -109,7 +109,7 @@ def test_c51(args: argparse.Namespace = get_args(), enable_assertions: bool = Tr policy=policy, optim=optim, gamma=args.gamma, - estimation_step=args.n_step, + n_step_return_horizon=args.n_step, target_update_freq=args.target_update_freq, ).to(args.device) diff --git a/test/discrete/test_discrete_sac.py b/test/discrete/test_discrete_sac.py index 108b2edb1..711068072 100644 --- a/test/discrete/test_discrete_sac.py +++ b/test/discrete/test_discrete_sac.py @@ -116,7 +116,7 @@ def test_discrete_sac( 
tau=args.tau, gamma=args.gamma, alpha=args.alpha, - estimation_step=args.n_step, + n_step_return_horizon=args.n_step, ) # collector train_collector = Collector[CollectStats]( diff --git a/test/discrete/test_dqn.py b/test/discrete/test_dqn.py index 99f580b83..475c57caf 100644 --- a/test/discrete/test_dqn.py +++ b/test/discrete/test_dqn.py @@ -100,7 +100,7 @@ def test_dqn(args: argparse.Namespace = get_args(), enable_assertions: bool = Tr policy=policy, optim=optim, gamma=args.gamma, - estimation_step=args.n_step, + n_step_return_horizon=args.n_step, target_update_freq=args.target_update_freq, ) diff --git a/test/discrete/test_drqn.py b/test/discrete/test_drqn.py index d2e0792a1..2ab6c6a66 100644 --- a/test/discrete/test_drqn.py +++ b/test/discrete/test_drqn.py @@ -89,7 +89,7 @@ def test_drqn(args: argparse.Namespace = get_args(), enable_assertions: bool = T policy=policy, optim=optim, gamma=args.gamma, - estimation_step=args.n_step, + n_step_return_horizon=args.n_step, target_update_freq=args.target_update_freq, ) # collector diff --git a/test/discrete/test_fqf.py b/test/discrete/test_fqf.py index 4a49f8079..fd27b28f1 100644 --- a/test/discrete/test_fqf.py +++ b/test/discrete/test_fqf.py @@ -115,7 +115,7 @@ def test_fqf(args: argparse.Namespace = get_args(), enable_assertions: bool = Tr gamma=args.gamma, num_fractions=args.num_fractions, ent_coef=args.ent_coef, - estimation_step=args.n_step, + n_step_return_horizon=args.n_step, target_update_freq=args.target_update_freq, ).to(args.device) diff --git a/test/discrete/test_iqn.py b/test/discrete/test_iqn.py index 95108929b..7e0ac5d57 100644 --- a/test/discrete/test_iqn.py +++ b/test/discrete/test_iqn.py @@ -111,7 +111,7 @@ def test_iqn(args: argparse.Namespace = get_args(), enable_assertions: bool = Tr policy=policy, optim=optim, gamma=args.gamma, - estimation_step=args.n_step, + n_step_return_horizon=args.n_step, target_update_freq=args.target_update_freq, ).to(args.device) diff --git a/test/discrete/test_qrdqn.py 
b/test/discrete/test_qrdqn.py index f51ca1703..ae08523f3 100644 --- a/test/discrete/test_qrdqn.py +++ b/test/discrete/test_qrdqn.py @@ -106,7 +106,7 @@ def test_qrdqn(args: argparse.Namespace = get_args(), enable_assertions: bool = optim=optim, gamma=args.gamma, num_quantiles=args.num_quantiles, - estimation_step=args.n_step, + n_step_return_horizon=args.n_step, target_update_freq=args.target_update_freq, ).to(args.device) diff --git a/test/discrete/test_rainbow.py b/test/discrete/test_rainbow.py index ffd53f758..910c9f829 100644 --- a/test/discrete/test_rainbow.py +++ b/test/discrete/test_rainbow.py @@ -118,7 +118,7 @@ def noisy_linear(x: int, y: int) -> NoisyLinear: policy=policy, optim=optim, gamma=args.gamma, - estimation_step=args.n_step, + n_step_return_horizon=args.n_step, target_update_freq=args.target_update_freq, ).to(args.device) diff --git a/test/modelbased/test_dqn_icm.py b/test/modelbased/test_dqn_icm.py index f5a800b2c..ada719005 100644 --- a/test/modelbased/test_dqn_icm.py +++ b/test/modelbased/test_dqn_icm.py @@ -116,7 +116,7 @@ def test_dqn_icm(args: argparse.Namespace = get_args()) -> None: policy=policy, optim=optim, gamma=args.gamma, - estimation_step=args.n_step, + n_step_return_horizon=args.n_step, target_update_freq=args.target_update_freq, ) diff --git a/test/offline/gather_cartpole_data.py b/test/offline/gather_cartpole_data.py index 3c88b77c5..43928563a 100644 --- a/test/offline/gather_cartpole_data.py +++ b/test/offline/gather_cartpole_data.py @@ -108,7 +108,7 @@ def gather_data() -> VectorReplayBuffer | PrioritizedVectorReplayBuffer: optim=optim, gamma=args.gamma, num_quantiles=args.num_quantiles, - estimation_step=args.n_step, + n_step_return_horizon=args.n_step, target_update_freq=args.target_update_freq, ).to(args.device) # buffer diff --git a/test/offline/gather_pendulum_data.py b/test/offline/gather_pendulum_data.py index 68776dbbb..1dc70486f 100644 --- a/test/offline/gather_pendulum_data.py +++ 
b/test/offline/gather_pendulum_data.py @@ -127,7 +127,7 @@ def gather_data() -> VectorReplayBuffer: tau=args.tau, gamma=args.gamma, alpha=args.alpha, - estimation_step=args.n_step, + n_step_return_horizon=args.n_step, ) # collector buffer = VectorReplayBuffer(args.buffer_size, len(train_envs)) diff --git a/test/offline/test_discrete_bcq.py b/test/offline/test_discrete_bcq.py index 91f651bf3..a346be9b2 100644 --- a/test/offline/test_discrete_bcq.py +++ b/test/offline/test_discrete_bcq.py @@ -103,7 +103,7 @@ def test_discrete_bcq( policy=policy, optim=optim, gamma=args.gamma, - estimation_step=args.n_step, + n_step_return_horizon=args.n_step, target_update_freq=args.target_update_freq, imitation_logits_penalty=args.imitation_logits_penalty, ) diff --git a/test/offline/test_discrete_cql.py b/test/offline/test_discrete_cql.py index 2ce0f73a0..c0232ff02 100644 --- a/test/offline/test_discrete_cql.py +++ b/test/offline/test_discrete_cql.py @@ -95,7 +95,7 @@ def test_discrete_cql( optim=optim, gamma=args.gamma, num_quantiles=args.num_quantiles, - estimation_step=args.n_step, + n_step_return_horizon=args.n_step, target_update_freq=args.target_update_freq, min_q_weight=args.min_q_weight, ).to(args.device) diff --git a/test/offline/test_td3_bc.py b/test/offline/test_td3_bc.py index 8dff6a0ca..07c8e7b42 100644 --- a/test/offline/test_td3_bc.py +++ b/test/offline/test_td3_bc.py @@ -144,7 +144,7 @@ def test_td3_bc(args: argparse.Namespace = get_args(), enable_assertions: bool = update_actor_freq=args.update_actor_freq, noise_clip=args.noise_clip, alpha=args.alpha, - estimation_step=args.n_step, + n_step_return_horizon=args.n_step, ) # load a previous policy diff --git a/test/pettingzoo/pistonball.py b/test/pettingzoo/pistonball.py index 7a02234a6..907eef5f0 100644 --- a/test/pettingzoo/pistonball.py +++ b/test/pettingzoo/pistonball.py @@ -112,7 +112,7 @@ def get_agents( policy=policy, optim=optim, gamma=args.gamma, - estimation_step=args.n_step, + 
n_step_return_horizon=args.n_step, target_update_freq=args.target_update_freq, ) algorithms.append(agent) diff --git a/test/pettingzoo/tic_tac_toe.py b/test/pettingzoo/tic_tac_toe.py index ff621ddc8..934bf1464 100644 --- a/test/pettingzoo/tic_tac_toe.py +++ b/test/pettingzoo/tic_tac_toe.py @@ -131,7 +131,7 @@ def get_agents( agent_learn = DQN( policy=algorithm, optim=optim, - estimation_step=args.n_step, + n_step_return_horizon=args.n_step, gamma=args.gamma, target_update_freq=args.target_update_freq, ) diff --git a/tianshou/algorithm/imitation/discrete_bcq.py b/tianshou/algorithm/imitation/discrete_bcq.py index 495b6ffa1..4de2f4455 100644 --- a/tianshou/algorithm/imitation/discrete_bcq.py +++ b/tianshou/algorithm/imitation/discrete_bcq.py @@ -133,7 +133,7 @@ def __init__( policy: DiscreteBCQPolicy, optim: OptimizerFactory, gamma: float = 0.99, - estimation_step: int = 1, + n_step_return_horizon: int = 1, target_update_freq: int = 8000, imitation_logits_penalty: float = 1e-2, ) -> None: @@ -147,7 +147,7 @@ def __init__( potentially improving performance in tasks where delayed rewards are important but increasing training variance by incorporating more environmental stochasticity. Typically set between 0.9 and 0.99 for most reinforcement learning tasks - :param estimation_step: the number of future steps (> 0) to consider when computing temporal + :param n_step_return_horizon: the number of future steps (> 0) to consider when computing temporal difference (TD) targets. Controls the balance between TD learning and Monte Carlo methods: higher values reduce bias (by relying less on potentially inaccurate value estimates) but increase variance (by incorporating more environmental stochasticity and reducing @@ -167,7 +167,7 @@ def __init__( complexity. :param imitation_logits_penalty: regularization weight for imitation logits. 
- :param estimation_step: the number of future steps (> 0) to consider when computing temporal + :param n_step_return_horizon: the number of future steps (> 0) to consider when computing temporal difference (TD) targets. Controls the balance between TD learning and Monte Carlo methods: higher values reduce bias (by relying less on potentially inaccurate value estimates) but increase variance (by incorporating more environmental stochasticity and reducing @@ -194,9 +194,9 @@ def __init__( assert 0.0 <= gamma <= 1.0, f"discount factor should be in [0, 1] but got: {gamma}" self.gamma = gamma assert ( - estimation_step > 0 - ), f"estimation_step should be greater than 0 but got: {estimation_step}" - self.n_step = estimation_step + n_step_return_horizon > 0 + ), f"n_step_return_horizon should be greater than 0 but got: {n_step_return_horizon}" + self.n_step = n_step_return_horizon self._target = target_update_freq > 0 self.freq = target_update_freq self._iter = 0 diff --git a/tianshou/algorithm/imitation/discrete_cql.py b/tianshou/algorithm/imitation/discrete_cql.py index a3f832902..e273b9e41 100644 --- a/tianshou/algorithm/imitation/discrete_cql.py +++ b/tianshou/algorithm/imitation/discrete_cql.py @@ -31,7 +31,7 @@ def __init__( min_q_weight: float = 10.0, gamma: float = 0.99, num_quantiles: int = 200, - estimation_step: int = 1, + n_step_return_horizon: int = 1, target_update_freq: int = 0, ) -> None: """ @@ -47,7 +47,7 @@ def __init__( Typically set between 0.9 and 0.99 for most reinforcement learning tasks :param num_quantiles: the number of quantile midpoints in the inverse cumulative distribution function of the value. - :param estimation_step: the number of future steps (> 0) to consider when computing temporal + :param n_step_return_horizon: the number of future steps (> 0) to consider when computing temporal difference (TD) targets. 
Controls the balance between TD learning and Monte Carlo methods: higher values reduce bias (by relying less on potentially inaccurate value estimates) but increase variance (by incorporating more environmental stochasticity and reducing @@ -72,7 +72,7 @@ def __init__( optim=optim, gamma=gamma, num_quantiles=num_quantiles, - estimation_step=estimation_step, + n_step_return_horizon=n_step_return_horizon, target_update_freq=target_update_freq, ) self.min_q_weight = min_q_weight diff --git a/tianshou/algorithm/imitation/td3_bc.py b/tianshou/algorithm/imitation/td3_bc.py index 50f339673..5ccbbe0fb 100644 --- a/tianshou/algorithm/imitation/td3_bc.py +++ b/tianshou/algorithm/imitation/td3_bc.py @@ -29,7 +29,7 @@ def __init__( update_actor_freq: int = 2, noise_clip: float = 0.5, alpha: float = 2.5, - estimation_step: int = 1, + n_step_return_horizon: int = 1, ) -> None: """ :param policy: the policy @@ -95,7 +95,7 @@ def __init__( policy_noise=policy_noise, noise_clip=noise_clip, update_actor_freq=update_actor_freq, - estimation_step=estimation_step, + n_step_return_horizon=n_step_return_horizon, ) self.alpha = alpha diff --git a/tianshou/algorithm/modelfree/bdqn.py b/tianshou/algorithm/modelfree/bdqn.py index dd88b7600..3e317532f 100644 --- a/tianshou/algorithm/modelfree/bdqn.py +++ b/tianshou/algorithm/modelfree/bdqn.py @@ -107,7 +107,7 @@ def __init__( policy: BDQNPolicy, optim: OptimizerFactory, gamma: float = 0.99, - estimation_step: int = 1, + n_step_return_horizon: int = 1, target_update_freq: int = 0, is_double: bool = True, ) -> None: @@ -121,7 +121,7 @@ def __init__( potentially improving performance in tasks where delayed rewards are important but increasing training variance by incorporating more environmental stochasticity. 
Typically set between 0.9 and 0.99 for most reinforcement learning tasks - :param estimation_step: the number of future steps (> 0) to consider when computing temporal + :param n_step_return_horizon: the number of future steps (> 0) to consider when computing temporal difference (TD) targets. Controls the balance between TD learning and Monte Carlo methods: higher values reduce bias (by relying less on potentially inaccurate value estimates) but increase variance (by incorporating more environmental stochasticity and reducing @@ -146,13 +146,13 @@ def __init__( Note: This parameter is most effective when used with a target network (target_update_freq > 0). """ assert ( - estimation_step == 1 - ), f"N-step bigger than one is not supported by BDQ but got: {estimation_step}" + n_step_return_horizon == 1 + ), f"N-step bigger than one is not supported by BDQ but got: {n_step_return_horizon}" super().__init__( policy=policy, optim=optim, gamma=gamma, - estimation_step=estimation_step, + n_step_return_horizon=n_step_return_horizon, target_update_freq=target_update_freq, ) self.is_double = is_double diff --git a/tianshou/algorithm/modelfree/c51.py b/tianshou/algorithm/modelfree/c51.py index 97d86386f..4f2917bd0 100644 --- a/tianshou/algorithm/modelfree/c51.py +++ b/tianshou/algorithm/modelfree/c51.py @@ -76,7 +76,7 @@ def __init__( policy: C51Policy, optim: OptimizerFactory, gamma: float = 0.99, - estimation_step: int = 1, + n_step_return_horizon: int = 1, target_update_freq: int = 0, ) -> None: """ @@ -89,7 +89,7 @@ def __init__( potentially improving performance in tasks where delayed rewards are important but increasing training variance by incorporating more environmental stochasticity. 
Typically set between 0.9 and 0.99 for most reinforcement learning tasks - :param estimation_step: the number of future steps (> 0) to consider when computing temporal + :param n_step_return_horizon: the number of future steps (> 0) to consider when computing temporal difference (TD) targets. Controls the balance between TD learning and Monte Carlo methods: higher values reduce bias (by relying less on potentially inaccurate value estimates) but increase variance (by incorporating more environmental stochasticity and reducing @@ -112,7 +112,7 @@ def __init__( policy=policy, optim=optim, gamma=gamma, - estimation_step=estimation_step, + n_step_return_horizon=n_step_return_horizon, target_update_freq=target_update_freq, ) self.delta_z = (policy.v_max - policy.v_min) / (policy.num_atoms - 1) diff --git a/tianshou/algorithm/modelfree/ddpg.py b/tianshou/algorithm/modelfree/ddpg.py index 127b515bd..038d647b5 100644 --- a/tianshou/algorithm/modelfree/ddpg.py +++ b/tianshou/algorithm/modelfree/ddpg.py @@ -223,7 +223,7 @@ def __init__( critic_optim: OptimizerFactory, tau: float = 0.005, gamma: float = 0.99, - estimation_step: int = 1, + n_step_return_horizon: int = 1, ) -> None: """ :param policy: the policy @@ -261,7 +261,7 @@ def __init__( self.critic_old = self._add_lagged_network(self.critic) self.critic_optim = self._create_optimizer(self.critic, critic_optim) self.gamma = gamma - self.estimation_step = estimation_step + self.n_step_return_horizon = n_step_return_horizon @staticmethod def _minimize_critic_squared_loss( @@ -298,7 +298,7 @@ def _preprocess_batch( indices=indices, target_q_fn=self._target_q, gamma=self.gamma, - n_step=self.estimation_step, + n_step=self.n_step_return_horizon, ) def _target_q_compute_action(self, obs_batch: Batch) -> TActBatchProtocol: @@ -353,7 +353,7 @@ def __init__( critic_optim: OptimizerFactory, tau: float = 0.005, gamma: float = 0.99, - estimation_step: int = 1, + n_step_return_horizon: int = 1, ) -> None: """ :param policy: the 
policy @@ -375,7 +375,7 @@ def __init__( potentially improving performance in tasks where delayed rewards are important but increasing training variance by incorporating more environmental stochasticity. Typically set between 0.9 and 0.99 for most reinforcement learning tasks - :param estimation_step: the number of future steps (> 0) to consider when computing temporal + :param n_step_return_horizon: the number of future steps (> 0) to consider when computing temporal difference (TD) targets. Controls the balance between TD learning and Monte Carlo methods: higher values reduce bias (by relying less on potentially inaccurate value estimates) but increase variance (by incorporating more environmental stochasticity and reducing @@ -390,7 +390,7 @@ def __init__( critic_optim=critic_optim, tau=tau, gamma=gamma, - estimation_step=estimation_step, + n_step_return_horizon=n_step_return_horizon, ) self.actor_old = self._add_lagged_network(self.policy.actor) diff --git a/tianshou/algorithm/modelfree/discrete_sac.py b/tianshou/algorithm/modelfree/discrete_sac.py index 65094a191..ae1d1318f 100644 --- a/tianshou/algorithm/modelfree/discrete_sac.py +++ b/tianshou/algorithm/modelfree/discrete_sac.py @@ -96,7 +96,7 @@ def __init__( tau: float = 0.005, gamma: float = 0.99, alpha: float | Alpha = 0.2, - estimation_step: int = 1, + n_step_return_horizon: int = 1, ) -> None: """ :param policy: the policy @@ -124,7 +124,7 @@ def __init__( Typically set between 0.9 and 0.99 for most reinforcement learning tasks :param alpha: the entropy regularization coefficient alpha or an object which can be used to automatically tune it (e.g. an instance of `AutoAlpha`). - :param estimation_step: the number of future steps (> 0) to consider when computing temporal + :param n_step_return_horizon: the number of future steps (> 0) to consider when computing temporal difference (TD) targets. 
Controls the balance between TD learning and Monte Carlo methods: higher values reduce bias (by relying less on potentially inaccurate value estimates) but increase variance (by incorporating more environmental stochasticity and reducing @@ -141,7 +141,7 @@ def __init__( critic2_optim=critic2_optim, tau=tau, gamma=gamma, - estimation_step=estimation_step, + n_step_return_horizon=n_step_return_horizon, ) self.alpha = Alpha.from_float_or_instance(alpha) diff --git a/tianshou/algorithm/modelfree/dqn.py b/tianshou/algorithm/modelfree/dqn.py index 570c1c1ed..f1286120f 100644 --- a/tianshou/algorithm/modelfree/dqn.py +++ b/tianshou/algorithm/modelfree/dqn.py @@ -188,7 +188,7 @@ def __init__( policy: TDQNPolicy, optim: OptimizerFactory, gamma: float = 0.99, - estimation_step: int = 1, + n_step_return_horizon: int = 1, target_update_freq: int = 0, ) -> None: """ @@ -201,7 +201,7 @@ def __init__( potentially improving performance in tasks where delayed rewards are important but increasing training variance by incorporating more environmental stochasticity. Typically set between 0.9 and 0.99 for most reinforcement learning tasks - :param estimation_step: the number of future steps (> 0) to consider when computing temporal + :param n_step_return_horizon: the number of future steps (> 0) to consider when computing temporal difference (TD) targets. 
Controls the balance between TD learning and Monte Carlo methods: higher values reduce bias (by relying less on potentially inaccurate value estimates) but increase variance (by incorporating more environmental stochasticity and reducing @@ -228,9 +228,9 @@ def __init__( assert 0.0 <= gamma <= 1.0, f"discount factor should be in [0, 1] but got: {gamma}" self.gamma = gamma assert ( - estimation_step > 0 - ), f"estimation_step should be greater than 0 but got: {estimation_step}" - self.n_step = estimation_step + n_step_return_horizon > 0 + ), f"n_step_return_horizon should be greater than 0 but got: {n_step_return_horizon}" + self.n_step = n_step_return_horizon self.target_update_freq = target_update_freq # TODO: 1 would be a more reasonable initialization given how it is incremented self._iter = 0 @@ -298,7 +298,7 @@ def __init__( policy: TDQNPolicy, optim: OptimizerFactory, gamma: float = 0.99, - estimation_step: int = 1, + n_step_return_horizon: int = 1, target_update_freq: int = 0, is_double: bool = True, huber_loss_delta: float | None = None, @@ -313,7 +313,7 @@ def __init__( potentially improving performance in tasks where delayed rewards are important but increasing training variance by incorporating more environmental stochasticity. Typically set between 0.9 and 0.99 for most reinforcement learning tasks - :param estimation_step: the number of future steps (> 0) to consider when computing temporal + :param n_step_return_horizon: the number of future steps (> 0) to consider when computing temporal difference (TD) targets. 
Controls the balance between TD learning and Monte Carlo methods: higher values reduce bias (by relying less on potentially inaccurate value estimates) but increase variance (by incorporating more environmental stochasticity and reducing @@ -351,7 +351,7 @@ def __init__( policy=policy, optim=optim, gamma=gamma, - estimation_step=estimation_step, + n_step_return_horizon=n_step_return_horizon, target_update_freq=target_update_freq, ) self.is_double = is_double diff --git a/tianshou/algorithm/modelfree/fqf.py b/tianshou/algorithm/modelfree/fqf.py index eed8d7c5a..25e7a1047 100644 --- a/tianshou/algorithm/modelfree/fqf.py +++ b/tianshou/algorithm/modelfree/fqf.py @@ -123,7 +123,7 @@ def __init__( # Rename? Or at least explain what happens here. num_fractions: int = 32, ent_coef: float = 0.0, - estimation_step: int = 1, + n_step_return_horizon: int = 1, target_update_freq: int = 0, ) -> None: """ @@ -144,7 +144,7 @@ def __init__( Higher values promote more exploration by encouraging a more uniform action distribution. Lower values focus more on exploitation of the current policy's knowledge. Typically set between 0.01 and 0.05 for most actor-critic implementations. - :param estimation_step: the number of future steps (> 0) to consider when computing temporal + :param n_step_return_horizon: the number of future steps (> 0) to consider when computing temporal difference (TD) targets. 
Controls the balance between TD learning and Monte Carlo methods: higher values reduce bias (by relying less on potentially inaccurate value estimates) but increase variance (by incorporating more environmental stochasticity and reducing @@ -168,7 +168,7 @@ def __init__( optim=optim, gamma=gamma, num_quantiles=num_fractions, - estimation_step=estimation_step, + n_step_return_horizon=n_step_return_horizon, target_update_freq=target_update_freq, ) self.ent_coef = ent_coef diff --git a/tianshou/algorithm/modelfree/iqn.py b/tianshou/algorithm/modelfree/iqn.py index cb6881996..f9610732f 100644 --- a/tianshou/algorithm/modelfree/iqn.py +++ b/tianshou/algorithm/modelfree/iqn.py @@ -113,7 +113,7 @@ def __init__( optim: OptimizerFactory, gamma: float = 0.99, num_quantiles: int = 200, - estimation_step: int = 1, + n_step_return_horizon: int = 1, target_update_freq: int = 0, ) -> None: """ @@ -128,7 +128,7 @@ def __init__( Typically set between 0.9 and 0.99 for most reinforcement learning tasks :param num_quantiles: the number of quantile midpoints in the inverse cumulative distribution function of the value. - :param estimation_step: the number of future steps (> 0) to consider when computing temporal + :param n_step_return_horizon: the number of future steps (> 0) to consider when computing temporal difference (TD) targets. 
Controls the balance between TD learning and Monte Carlo methods: higher values reduce bias (by relying less on potentially inaccurate value estimates) but increase variance (by incorporating more environmental stochasticity and reducing @@ -152,7 +152,7 @@ def __init__( optim=optim, gamma=gamma, num_quantiles=num_quantiles, - estimation_step=estimation_step, + n_step_return_horizon=n_step_return_horizon, target_update_freq=target_update_freq, ) diff --git a/tianshou/algorithm/modelfree/qrdqn.py b/tianshou/algorithm/modelfree/qrdqn.py index 883086d6a..7602a1887 100644 --- a/tianshou/algorithm/modelfree/qrdqn.py +++ b/tianshou/algorithm/modelfree/qrdqn.py @@ -36,7 +36,7 @@ def __init__( optim: OptimizerFactory, gamma: float = 0.99, num_quantiles: int = 200, - estimation_step: int = 1, + n_step_return_horizon: int = 1, target_update_freq: int = 0, ) -> None: """ @@ -56,7 +56,7 @@ def __init__( Lower values reduce computational cost but may not capture the distribution accurately enough. The original QRDQN paper used 200 quantiles for Atari environments. Must be greater than 1, as at least two quantiles are needed to represent a distribution. - :param estimation_step: the number of future steps (> 0) to consider when computing temporal + :param n_step_return_horizon: the number of future steps (> 0) to consider when computing temporal difference (TD) targets. 
Controls the balance between TD learning and Monte Carlo methods: higher values reduce bias (by relying less on potentially inaccurate value estimates) but increase variance (by incorporating more environmental stochasticity and reducing @@ -80,7 +80,7 @@ def __init__( policy=policy, optim=optim, gamma=gamma, - estimation_step=estimation_step, + n_step_return_horizon=n_step_return_horizon, target_update_freq=target_update_freq, ) self.num_quantiles = num_quantiles diff --git a/tianshou/algorithm/modelfree/rainbow.py b/tianshou/algorithm/modelfree/rainbow.py index 6efce699e..4ea266d2d 100644 --- a/tianshou/algorithm/modelfree/rainbow.py +++ b/tianshou/algorithm/modelfree/rainbow.py @@ -24,7 +24,7 @@ def __init__( policy: C51Policy, optim: OptimizerFactory, gamma: float = 0.99, - estimation_step: int = 1, + n_step_return_horizon: int = 1, target_update_freq: int = 0, ) -> None: """ @@ -37,7 +37,7 @@ def __init__( potentially improving performance in tasks where delayed rewards are important but increasing training variance by incorporating more environmental stochasticity. Typically set between 0.9 and 0.99 for most reinforcement learning tasks - :param estimation_step: the number of future steps (> 0) to consider when computing temporal + :param n_step_return_horizon: the number of future steps (> 0) to consider when computing temporal difference (TD) targets. 
Controls the balance between TD learning and Monte Carlo methods: higher values reduce bias (by relying less on potentially inaccurate value estimates) but increase variance (by incorporating more environmental stochasticity and reducing @@ -60,7 +60,7 @@ def __init__( policy=policy, optim=optim, gamma=gamma, - estimation_step=estimation_step, + n_step_return_horizon=n_step_return_horizon, target_update_freq=target_update_freq, ) diff --git a/tianshou/algorithm/modelfree/redq.py b/tianshou/algorithm/modelfree/redq.py index 99f7deadc..786db84d9 100644 --- a/tianshou/algorithm/modelfree/redq.py +++ b/tianshou/algorithm/modelfree/redq.py @@ -146,7 +146,7 @@ def __init__( tau: float = 0.005, gamma: float = 0.99, alpha: float | Alpha = 0.2, - estimation_step: int = 1, + n_step_return_horizon: int = 1, actor_delay: int = 20, deterministic_eval: bool = True, target_mode: Literal["mean", "min"] = "min", @@ -200,7 +200,7 @@ def __init__( premature convergence to suboptimal deterministic policies. Can be provided as a fixed float (0.2 is a reasonable default) or as an instance of, in particular, class `AutoAlpha` for automatic tuning during training. - :param estimation_step: the number of future steps (> 0) to consider when computing temporal + :param n_step_return_horizon: the number of future steps (> 0) to consider when computing temporal difference (TD) targets. 
Controls the balance between TD learning and Monte Carlo methods: higher values reduce bias (by relying less on potentially inaccurate value estimates) but increase variance (by incorporating more environmental stochasticity and reducing @@ -236,7 +236,7 @@ def __init__( critic_optim=critic_optim, tau=tau, gamma=gamma, - estimation_step=estimation_step, + n_step_return_horizon=n_step_return_horizon, ) self.ensemble_size = ensemble_size self.subset_size = subset_size diff --git a/tianshou/algorithm/modelfree/sac.py b/tianshou/algorithm/modelfree/sac.py index 695e773ee..4ad7fd195 100644 --- a/tianshou/algorithm/modelfree/sac.py +++ b/tianshou/algorithm/modelfree/sac.py @@ -227,7 +227,7 @@ def __init__( tau: float = 0.005, gamma: float = 0.99, alpha: float | Alpha = 0.2, - estimation_step: int = 1, + n_step_return_horizon: int = 1, deterministic_eval: bool = True, ) -> None: """ @@ -265,7 +265,7 @@ def __init__( premature convergence to suboptimal deterministic policies. Can be provided as a fixed float (0.2 is a reasonable default) or as an instance of, in particular, class `AutoAlpha` for automatic tuning during training. - :param estimation_step: the number of future steps (> 0) to consider when computing temporal + :param n_step_return_horizon: the number of future steps (> 0) to consider when computing temporal difference (TD) targets. 
Controls the balance between TD learning and Monte Carlo methods: higher values reduce bias (by relying less on potentially inaccurate value estimates) but increase variance (by incorporating more environmental stochasticity and reducing @@ -282,7 +282,7 @@ def __init__( critic2_optim=critic2_optim, tau=tau, gamma=gamma, - estimation_step=estimation_step, + n_step_return_horizon=n_step_return_horizon, ) self.deterministic_eval = deterministic_eval self.alpha = Alpha.from_float_or_instance(alpha) diff --git a/tianshou/algorithm/modelfree/td3.py b/tianshou/algorithm/modelfree/td3.py index 0616c5d64..a6e857c0b 100644 --- a/tianshou/algorithm/modelfree/td3.py +++ b/tianshou/algorithm/modelfree/td3.py @@ -48,7 +48,7 @@ def __init__( critic2_optim: OptimizerFactory | None = None, tau: float = 0.005, gamma: float = 0.99, - estimation_step: int = 1, + n_step_return_horizon: int = 1, ) -> None: """ :param policy: the policy @@ -85,7 +85,7 @@ def __init__( critic_optim=critic_optim, tau=tau, gamma=gamma, - estimation_step=estimation_step, + n_step_return_horizon=n_step_return_horizon, ) self.critic2 = critic2 or deepcopy(critic) self.critic2_old = self._add_lagged_network(self.critic2) @@ -121,7 +121,7 @@ def __init__( policy_noise: float = 0.2, update_actor_freq: int = 2, noise_clip: float = 0.5, - estimation_step: int = 1, + n_step_return_horizon: int = 1, ) -> None: """ :param policy: the policy @@ -178,7 +178,7 @@ def __init__( critic2_optim=critic2_optim, tau=tau, gamma=gamma, - estimation_step=estimation_step, + n_step_return_horizon=n_step_return_horizon, ) self.actor_old = self._add_lagged_network(self.policy.actor) self.policy_noise = policy_noise diff --git a/tianshou/highlevel/params/algorithm_params.py b/tianshou/highlevel/params/algorithm_params.py index 347ea6d74..e6b1e45e7 100644 --- a/tianshou/highlevel/params/algorithm_params.py +++ b/tianshou/highlevel/params/algorithm_params.py @@ -272,8 +272,8 @@ def _get_param_transformers(self) -> 
list[ParamTransformer]: @dataclass(kw_only=True) -class ParamsMixinEstimationStep: - estimation_step: int = 1 +class ParamsMixinNStepReturnHorizon: + n_step_return_horizon: int = 1 """ the number of future steps (> 0) to consider when computing temporal difference (TD) targets. Controls the balance between TD learning and Monte Carlo methods: @@ -578,6 +578,7 @@ def _get_param_transformers(self) -> list[ParamTransformer]: ] +@dataclass(kw_only=True) class ParamsMixinAlpha(GetParamTransformersProtocol): alpha: float | AutoAlphaFactory = 0.2 """ @@ -603,7 +604,7 @@ class _SACParams( Params, ParamsMixinGamma, ParamsMixinActorAndDualCritics, - ParamsMixinEstimationStep, + ParamsMixinNStepReturnHorizon, ParamsMixinTau, ParamsMixinDeterministicEval, ParamsMixinAlpha, @@ -631,7 +632,7 @@ class DiscreteSACParams(_SACParams): @dataclass(kw_only=True) class QLearningOffPolicyParams( - Params, ParamsMixinGamma, ParamsMixinSingleModel, ParamsMixinEstimationStep + Params, ParamsMixinGamma, ParamsMixinSingleModel, ParamsMixinNStepReturnHorizon ): target_update_freq: int = 0 """ @@ -738,7 +739,7 @@ class DDPGParams( ParamsMixinActorAndCritic, ParamsMixinExplorationNoise, ParamsMixinActionScaling, - ParamsMixinEstimationStep, + ParamsMixinNStepReturnHorizon, ParamsMixinTau, ): def _get_param_transformers(self) -> list[ParamTransformer]: @@ -806,7 +807,7 @@ class TD3Params( ParamsMixinActorAndDualCritics, ParamsMixinExplorationNoise, ParamsMixinActionScaling, - ParamsMixinEstimationStep, + ParamsMixinNStepReturnHorizon, ParamsMixinTau, ): policy_noise: float | FloatEnvValueFactory = 0.2 From 59d5916bfd79a3dcb0e33e58a61f9be9f21ca7cc Mon Sep 17 00:00:00 2001 From: Dominik Jain Date: Fri, 16 May 2025 01:04:15 +0200 Subject: [PATCH 179/230] v2: BDQN: Remove parameter 'n_step_return_horizon' (formerly 'estimation_step') as the algorithm only computes 1-step returns --- CHANGELOG.md | 1 + tianshou/algorithm/modelfree/bdqn.py | 14 ++------------ 2 files changed, 3 insertions(+), 12 deletions(-)
diff --git a/CHANGELOG.md b/CHANGELOG.md index aa594619e..b5a60bbe9 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -142,6 +142,7 @@ Developers: * `BDQN`: * Inherit from `QLearningOffPolicyAlgorithm` instead of `DQN` * Remove parameter `clip_loss_grad` (unused; only passed on to former base class) + * Remove parameter `estimation_step`, for which only one option was valid * `C51`: * Inherit from `QLearningOffPolicyAlgorithm` instead of `DQN` * Remove parameters `clip_loss_grad` and `is_double` (unused; only passed on to former base class) diff --git a/tianshou/algorithm/modelfree/bdqn.py b/tianshou/algorithm/modelfree/bdqn.py index 3e317532f..fc965d1eb 100644 --- a/tianshou/algorithm/modelfree/bdqn.py +++ b/tianshou/algorithm/modelfree/bdqn.py @@ -107,7 +107,6 @@ def __init__( policy: BDQNPolicy, optim: OptimizerFactory, gamma: float = 0.99, - n_step_return_horizon: int = 1, target_update_freq: int = 0, is_double: bool = True, ) -> None: @@ -121,13 +120,6 @@ def __init__( potentially improving performance in tasks where delayed rewards are important but increasing training variance by incorporating more environmental stochasticity. Typically set between 0.9 and 0.99 for most reinforcement learning tasks - :param n_step_return_horizon: the number of future steps (> 0) to consider when computing temporal - difference (TD) targets. Controls the balance between TD learning and Monte Carlo methods: - higher values reduce bias (by relying less on potentially inaccurate value estimates) - but increase variance (by incorporating more environmental stochasticity and reducing - the averaging effect). A value of 1 corresponds to standard TD learning with immediate - bootstrapping, while very large values approach Monte Carlo-like estimation that uses - complete episode returns. :param target_update_freq: the number of training iterations between each complete update of the target network. 
Controls how frequently the target Q-network parameters are updated with the current @@ -145,14 +137,12 @@ def __init__( If False, the algorithm selects actions by directly taking the maximum Q-value from the target network. Note: This parameter is most effective when used with a target network (target_update_freq > 0). """ - assert ( - n_step_return_horizon == 1 - ), f"N-step bigger than one is not supported by BDQ but got: {n_step_return_horizon}" super().__init__( policy=policy, optim=optim, gamma=gamma, - n_step_return_horizon=n_step_return_horizon, + # BDQN implements its own returns computation (below), which supports only 1-step returns + n_step_return_horizon=1, target_update_freq=target_update_freq, ) self.is_double = is_double From 272e70f676619b744195221c3b5793e013d4161e Mon Sep 17 00:00:00 2001 From: Dominik Jain Date: Fri, 16 May 2025 13:17:10 +0200 Subject: [PATCH 180/230] v2: Rename 'pg' module and associated scripts to 'reinforce' --- examples/atari/atari_ppo.py | 2 +- examples/inverse/irl_gail.py | 2 +- examples/mujoco/mujoco_a2c.py | 2 +- examples/mujoco/mujoco_npg.py | 2 +- examples/mujoco/mujoco_ppo.py | 2 +- examples/mujoco/mujoco_reinforce.py | 2 +- examples/mujoco/mujoco_trpo.py | 2 +- examples/offline/atari_crr.py | 2 +- examples/vizdoom/vizdoom_ppo.py | 2 +- test/base/test_policy.py | 2 +- test/continuous/test_npg.py | 2 +- test/continuous/test_ppo.py | 2 +- test/continuous/test_trpo.py | 2 +- test/discrete/test_a2c_with_il.py | 2 +- test/discrete/test_ppo_discrete.py | 2 +- test/discrete/{test_pg.py => test_reinforce.py} | 4 ++-- test/modelbased/test_ppo_icm.py | 2 +- test/offline/test_discrete_crr.py | 2 +- test/offline/test_gail.py | 2 +- test/pettingzoo/pistonball_continuous.py | 2 +- tianshou/algorithm/__init__.py | 2 +- tianshou/algorithm/imitation/discrete_bcq.py | 2 +- tianshou/algorithm/imitation/discrete_cql.py | 2 +- tianshou/algorithm/imitation/discrete_crr.py | 2 +- tianshou/algorithm/imitation/gail.py | 2 +- 
tianshou/algorithm/modelfree/a2c.py | 2 +- tianshou/algorithm/modelfree/bdqn.py | 2 +- tianshou/algorithm/modelfree/c51.py | 2 +- tianshou/algorithm/modelfree/dqn.py | 2 +- tianshou/algorithm/modelfree/fqf.py | 2 +- tianshou/algorithm/modelfree/iqn.py | 2 +- tianshou/algorithm/modelfree/npg.py | 2 +- tianshou/algorithm/modelfree/ppo.py | 2 +- tianshou/algorithm/modelfree/qrdqn.py | 2 +- tianshou/algorithm/modelfree/{pg.py => reinforce.py} | 0 tianshou/algorithm/modelfree/trpo.py | 2 +- tianshou/env/atari/atari_network.py | 2 +- tianshou/highlevel/algorithm.py | 2 +- tianshou/highlevel/module/actor.py | 2 +- tianshou/highlevel/params/dist_fn.py | 2 +- 40 files changed, 40 insertions(+), 40 deletions(-) rename test/discrete/{test_pg.py => test_reinforce.py} (97%) rename tianshou/algorithm/modelfree/{pg.py => reinforce.py} (100%) diff --git a/examples/atari/atari_ppo.py b/examples/atari/atari_ppo.py index 22399fb27..7b1ea4c68 100644 --- a/examples/atari/atari_ppo.py +++ b/examples/atari/atari_ppo.py @@ -12,7 +12,7 @@ from tianshou.algorithm import PPO from tianshou.algorithm.algorithm_base import Algorithm from tianshou.algorithm.modelbased.icm import ICMOnPolicyWrapper -from tianshou.algorithm.modelfree.pg import DiscreteActorPolicy +from tianshou.algorithm.modelfree.reinforce import DiscreteActorPolicy from tianshou.algorithm.optim import AdamOptimizerFactory, LRSchedulerFactoryLinear from tianshou.data import Collector, CollectStats, VectorReplayBuffer from tianshou.env.atari.atari_network import ( diff --git a/examples/inverse/irl_gail.py b/examples/inverse/irl_gail.py index 8a691184c..02040e221 100644 --- a/examples/inverse/irl_gail.py +++ b/examples/inverse/irl_gail.py @@ -16,7 +16,7 @@ from tianshou.algorithm import GAIL from tianshou.algorithm.algorithm_base import Algorithm -from tianshou.algorithm.modelfree.pg import ActorPolicyProbabilistic +from tianshou.algorithm.modelfree.reinforce import ActorPolicyProbabilistic from tianshou.algorithm.optim import 
AdamOptimizerFactory, LRSchedulerFactoryLinear from tianshou.data import ( Batch, diff --git a/examples/mujoco/mujoco_a2c.py b/examples/mujoco/mujoco_a2c.py index f6b4eae34..a7820f489 100755 --- a/examples/mujoco/mujoco_a2c.py +++ b/examples/mujoco/mujoco_a2c.py @@ -13,7 +13,7 @@ from tianshou.algorithm import A2C from tianshou.algorithm.algorithm_base import Algorithm -from tianshou.algorithm.modelfree.pg import ActorPolicyProbabilistic +from tianshou.algorithm.modelfree.reinforce import ActorPolicyProbabilistic from tianshou.algorithm.optim import LRSchedulerFactoryLinear, RMSpropOptimizerFactory from tianshou.data import Collector, CollectStats, ReplayBuffer, VectorReplayBuffer from tianshou.highlevel.logger import LoggerFactoryDefault diff --git a/examples/mujoco/mujoco_npg.py b/examples/mujoco/mujoco_npg.py index b5004997e..5a5f08cc6 100755 --- a/examples/mujoco/mujoco_npg.py +++ b/examples/mujoco/mujoco_npg.py @@ -13,7 +13,7 @@ from tianshou.algorithm import NPG from tianshou.algorithm.algorithm_base import Algorithm -from tianshou.algorithm.modelfree.pg import ActorPolicyProbabilistic +from tianshou.algorithm.modelfree.reinforce import ActorPolicyProbabilistic from tianshou.algorithm.optim import AdamOptimizerFactory, LRSchedulerFactoryLinear from tianshou.data import Collector, CollectStats, ReplayBuffer, VectorReplayBuffer from tianshou.highlevel.logger import LoggerFactoryDefault diff --git a/examples/mujoco/mujoco_ppo.py b/examples/mujoco/mujoco_ppo.py index a12c00ca5..061036d14 100755 --- a/examples/mujoco/mujoco_ppo.py +++ b/examples/mujoco/mujoco_ppo.py @@ -13,7 +13,7 @@ from tianshou.algorithm import PPO from tianshou.algorithm.algorithm_base import Algorithm -from tianshou.algorithm.modelfree.pg import ActorPolicyProbabilistic +from tianshou.algorithm.modelfree.reinforce import ActorPolicyProbabilistic from tianshou.algorithm.optim import AdamOptimizerFactory, LRSchedulerFactoryLinear from tianshou.data import Collector, CollectStats, ReplayBuffer, 
VectorReplayBuffer from tianshou.highlevel.logger import LoggerFactoryDefault diff --git a/examples/mujoco/mujoco_reinforce.py b/examples/mujoco/mujoco_reinforce.py index 9d9bae48e..2808dd9f1 100755 --- a/examples/mujoco/mujoco_reinforce.py +++ b/examples/mujoco/mujoco_reinforce.py @@ -13,7 +13,7 @@ from tianshou.algorithm import Reinforce from tianshou.algorithm.algorithm_base import Algorithm -from tianshou.algorithm.modelfree.pg import ActorPolicyProbabilistic +from tianshou.algorithm.modelfree.reinforce import ActorPolicyProbabilistic from tianshou.algorithm.optim import AdamOptimizerFactory, LRSchedulerFactoryLinear from tianshou.data import Collector, CollectStats, ReplayBuffer, VectorReplayBuffer from tianshou.highlevel.logger import LoggerFactoryDefault diff --git a/examples/mujoco/mujoco_trpo.py b/examples/mujoco/mujoco_trpo.py index b02af4cd6..77863cc7a 100755 --- a/examples/mujoco/mujoco_trpo.py +++ b/examples/mujoco/mujoco_trpo.py @@ -13,7 +13,7 @@ from tianshou.algorithm import TRPO from tianshou.algorithm.algorithm_base import Algorithm -from tianshou.algorithm.modelfree.pg import ActorPolicyProbabilistic +from tianshou.algorithm.modelfree.reinforce import ActorPolicyProbabilistic from tianshou.algorithm.optim import AdamOptimizerFactory, LRSchedulerFactoryLinear from tianshou.data import Collector, CollectStats, ReplayBuffer, VectorReplayBuffer from tianshou.highlevel.logger import LoggerFactoryDefault diff --git a/examples/offline/atari_crr.py b/examples/offline/atari_crr.py index ea5c4fc9c..79e3b67cb 100644 --- a/examples/offline/atari_crr.py +++ b/examples/offline/atari_crr.py @@ -14,7 +14,7 @@ from examples.offline.utils import load_buffer from tianshou.algorithm import DiscreteCRR from tianshou.algorithm.algorithm_base import Algorithm -from tianshou.algorithm.modelfree.pg import DiscreteActorPolicy +from tianshou.algorithm.modelfree.reinforce import DiscreteActorPolicy from tianshou.algorithm.optim import AdamOptimizerFactory from tianshou.data 
import Collector, CollectStats, VectorReplayBuffer from tianshou.env.atari.atari_network import DQNet diff --git a/examples/vizdoom/vizdoom_ppo.py b/examples/vizdoom/vizdoom_ppo.py index 4004fdc7e..28bfaa979 100644 --- a/examples/vizdoom/vizdoom_ppo.py +++ b/examples/vizdoom/vizdoom_ppo.py @@ -12,7 +12,7 @@ from tianshou.algorithm import PPO from tianshou.algorithm.algorithm_base import Algorithm from tianshou.algorithm.modelbased.icm import ICMOnPolicyWrapper -from tianshou.algorithm.modelfree.pg import ActorPolicyProbabilistic +from tianshou.algorithm.modelfree.reinforce import ActorPolicyProbabilistic from tianshou.algorithm.optim import AdamOptimizerFactory, LRSchedulerFactoryLinear from tianshou.data import Collector, CollectStats, VectorReplayBuffer from tianshou.env.atari.atari_network import DQNet diff --git a/test/base/test_policy.py b/test/base/test_policy.py index c357b7e87..2fcc6d522 100644 --- a/test/base/test_policy.py +++ b/test/base/test_policy.py @@ -9,7 +9,7 @@ RandomActionPolicy, episode_mc_return_to_go, ) -from tianshou.algorithm.modelfree.pg import ActorPolicyProbabilistic +from tianshou.algorithm.modelfree.reinforce import ActorPolicyProbabilistic from tianshou.algorithm.optim import AdamOptimizerFactory from tianshou.data import Batch from tianshou.utils.net.common import Net diff --git a/test/continuous/test_npg.py b/test/continuous/test_npg.py index 3f94d6f03..8d2578707 100644 --- a/test/continuous/test_npg.py +++ b/test/continuous/test_npg.py @@ -11,7 +11,7 @@ from tianshou.algorithm import NPG from tianshou.algorithm.algorithm_base import Algorithm -from tianshou.algorithm.modelfree.pg import ActorPolicyProbabilistic +from tianshou.algorithm.modelfree.reinforce import ActorPolicyProbabilistic from tianshou.algorithm.optim import AdamOptimizerFactory from tianshou.data import Collector, CollectStats, VectorReplayBuffer from tianshou.env import DummyVectorEnv diff --git a/test/continuous/test_ppo.py b/test/continuous/test_ppo.py index 
bad8edbc8..762615962 100644 --- a/test/continuous/test_ppo.py +++ b/test/continuous/test_ppo.py @@ -10,7 +10,7 @@ from tianshou.algorithm import PPO from tianshou.algorithm.algorithm_base import Algorithm -from tianshou.algorithm.modelfree.pg import ActorPolicyProbabilistic +from tianshou.algorithm.modelfree.reinforce import ActorPolicyProbabilistic from tianshou.algorithm.optim import AdamOptimizerFactory from tianshou.data import Collector, CollectStats, VectorReplayBuffer from tianshou.env import DummyVectorEnv diff --git a/test/continuous/test_trpo.py b/test/continuous/test_trpo.py index dec82d7f5..e29108f26 100644 --- a/test/continuous/test_trpo.py +++ b/test/continuous/test_trpo.py @@ -11,7 +11,7 @@ from tianshou.algorithm import TRPO from tianshou.algorithm.algorithm_base import Algorithm -from tianshou.algorithm.modelfree.pg import ActorPolicyProbabilistic +from tianshou.algorithm.modelfree.reinforce import ActorPolicyProbabilistic from tianshou.algorithm.optim import AdamOptimizerFactory from tianshou.data import Collector, CollectStats, VectorReplayBuffer from tianshou.env import DummyVectorEnv diff --git a/test/discrete/test_a2c_with_il.py b/test/discrete/test_a2c_with_il.py index 906695879..5801654f3 100644 --- a/test/discrete/test_a2c_with_il.py +++ b/test/discrete/test_a2c_with_il.py @@ -10,7 +10,7 @@ from tianshou.algorithm import A2C, Algorithm, OffPolicyImitationLearning from tianshou.algorithm.imitation.imitation_base import ImitationPolicy -from tianshou.algorithm.modelfree.pg import ActorPolicyProbabilistic +from tianshou.algorithm.modelfree.reinforce import ActorPolicyProbabilistic from tianshou.algorithm.optim import AdamOptimizerFactory from tianshou.data import Collector, CollectStats, VectorReplayBuffer from tianshou.env import DummyVectorEnv diff --git a/test/discrete/test_ppo_discrete.py b/test/discrete/test_ppo_discrete.py index ffdc8822f..0de6299a4 100644 --- a/test/discrete/test_ppo_discrete.py +++ b/test/discrete/test_ppo_discrete.py 
@@ -10,7 +10,7 @@ from tianshou.algorithm import PPO from tianshou.algorithm.algorithm_base import Algorithm -from tianshou.algorithm.modelfree.pg import DiscreteActorPolicy +from tianshou.algorithm.modelfree.reinforce import DiscreteActorPolicy from tianshou.algorithm.optim import AdamOptimizerFactory from tianshou.data import Collector, CollectStats, VectorReplayBuffer from tianshou.env import DummyVectorEnv diff --git a/test/discrete/test_pg.py b/test/discrete/test_reinforce.py similarity index 97% rename from test/discrete/test_pg.py rename to test/discrete/test_reinforce.py index 218b40ec8..ab222ae26 100644 --- a/test/discrete/test_pg.py +++ b/test/discrete/test_reinforce.py @@ -10,7 +10,7 @@ from tianshou.algorithm import Reinforce from tianshou.algorithm.algorithm_base import Algorithm -from tianshou.algorithm.modelfree.pg import ActorPolicyProbabilistic +from tianshou.algorithm.modelfree.reinforce import ActorPolicyProbabilistic from tianshou.algorithm.optim import AdamOptimizerFactory from tianshou.data import Collector, CollectStats, VectorReplayBuffer from tianshou.env import DummyVectorEnv @@ -137,4 +137,4 @@ def stop_fn(mean_rewards: float) -> bool: def test_pg_determinism() -> None: main_fn = lambda args: test_pg(args, enable_assertions=False) - AlgorithmDeterminismTest("discrete_pg", main_fn, get_args()).run() + AlgorithmDeterminismTest("discrete_reinforce", main_fn, get_args()).run() diff --git a/test/modelbased/test_ppo_icm.py b/test/modelbased/test_ppo_icm.py index a3a22487e..acf9daa4c 100644 --- a/test/modelbased/test_ppo_icm.py +++ b/test/modelbased/test_ppo_icm.py @@ -10,7 +10,7 @@ from tianshou.algorithm import PPO from tianshou.algorithm.algorithm_base import Algorithm from tianshou.algorithm.modelbased.icm import ICMOnPolicyWrapper -from tianshou.algorithm.modelfree.pg import ActorPolicyProbabilistic +from tianshou.algorithm.modelfree.reinforce import ActorPolicyProbabilistic from tianshou.algorithm.optim import AdamOptimizerFactory from 
tianshou.data import Collector, CollectStats, VectorReplayBuffer from tianshou.env import DummyVectorEnv diff --git a/test/offline/test_discrete_crr.py b/test/offline/test_discrete_crr.py index df3f03a75..d1a1fafca 100644 --- a/test/offline/test_discrete_crr.py +++ b/test/offline/test_discrete_crr.py @@ -10,7 +10,7 @@ from torch.utils.tensorboard import SummaryWriter from tianshou.algorithm import Algorithm, DiscreteCRR -from tianshou.algorithm.modelfree.pg import DiscreteActorPolicy +from tianshou.algorithm.modelfree.reinforce import DiscreteActorPolicy from tianshou.algorithm.optim import AdamOptimizerFactory from tianshou.data import ( Collector, diff --git a/test/offline/test_gail.py b/test/offline/test_gail.py index f4f5236c1..e6549bf5b 100644 --- a/test/offline/test_gail.py +++ b/test/offline/test_gail.py @@ -11,7 +11,7 @@ from torch.utils.tensorboard import SummaryWriter from tianshou.algorithm import GAIL, Algorithm -from tianshou.algorithm.modelfree.pg import ActorPolicyProbabilistic +from tianshou.algorithm.modelfree.reinforce import ActorPolicyProbabilistic from tianshou.algorithm.optim import AdamOptimizerFactory from tianshou.data import Collector, CollectStats, VectorReplayBuffer from tianshou.env import DummyVectorEnv diff --git a/test/pettingzoo/pistonball_continuous.py b/test/pettingzoo/pistonball_continuous.py index 98c5f768d..14da09f63 100644 --- a/test/pettingzoo/pistonball_continuous.py +++ b/test/pettingzoo/pistonball_continuous.py @@ -13,7 +13,7 @@ from tianshou.algorithm import PPO, Algorithm from tianshou.algorithm.algorithm_base import OnPolicyAlgorithm -from tianshou.algorithm.modelfree.pg import ActorPolicyProbabilistic +from tianshou.algorithm.modelfree.reinforce import ActorPolicyProbabilistic from tianshou.algorithm.multiagent.marl import MultiAgentOnPolicyAlgorithm from tianshou.algorithm.optim import AdamOptimizerFactory from tianshou.data import Collector, CollectStats, VectorReplayBuffer diff --git a/tianshou/algorithm/__init__.py 
b/tianshou/algorithm/__init__.py index 93606f2f5..6ce54f60b 100644 --- a/tianshou/algorithm/__init__.py +++ b/tianshou/algorithm/__init__.py @@ -2,7 +2,7 @@ # isort:skip_file from tianshou.algorithm.algorithm_base import Algorithm, TrainingStats -from tianshou.algorithm.modelfree.pg import Reinforce +from tianshou.algorithm.modelfree.reinforce import Reinforce from tianshou.algorithm.modelfree.dqn import DQN from tianshou.algorithm.modelfree.ddpg import DDPG diff --git a/tianshou/algorithm/imitation/discrete_bcq.py b/tianshou/algorithm/imitation/discrete_bcq.py index 4de2f4455..cd49ac1aa 100644 --- a/tianshou/algorithm/imitation/discrete_bcq.py +++ b/tianshou/algorithm/imitation/discrete_bcq.py @@ -12,7 +12,7 @@ OfflineAlgorithm, ) from tianshou.algorithm.modelfree.dqn import DiscreteQLearningPolicy -from tianshou.algorithm.modelfree.pg import SimpleLossTrainingStats +from tianshou.algorithm.modelfree.reinforce import SimpleLossTrainingStats from tianshou.algorithm.optim import OptimizerFactory from tianshou.data import Batch, ReplayBuffer, to_torch from tianshou.data.types import ( diff --git a/tianshou/algorithm/imitation/discrete_cql.py b/tianshou/algorithm/imitation/discrete_cql.py index e273b9e41..8ca39440b 100644 --- a/tianshou/algorithm/imitation/discrete_cql.py +++ b/tianshou/algorithm/imitation/discrete_cql.py @@ -6,8 +6,8 @@ from tianshou.algorithm import QRDQN from tianshou.algorithm.algorithm_base import OfflineAlgorithm -from tianshou.algorithm.modelfree.pg import SimpleLossTrainingStats from tianshou.algorithm.modelfree.qrdqn import QRDQNPolicy +from tianshou.algorithm.modelfree.reinforce import SimpleLossTrainingStats from tianshou.algorithm.optim import OptimizerFactory from tianshou.data import to_torch from tianshou.data.types import RolloutBatchProtocol diff --git a/tianshou/algorithm/imitation/discrete_crr.py b/tianshou/algorithm/imitation/discrete_crr.py index 902b0e317..1a344a65b 100644 --- a/tianshou/algorithm/imitation/discrete_crr.py +++ 
b/tianshou/algorithm/imitation/discrete_crr.py @@ -11,7 +11,7 @@ LaggedNetworkFullUpdateAlgorithmMixin, OfflineAlgorithm, ) -from tianshou.algorithm.modelfree.pg import ( +from tianshou.algorithm.modelfree.reinforce import ( DiscountedReturnComputation, DiscreteActorPolicy, SimpleLossTrainingStats, diff --git a/tianshou/algorithm/imitation/gail.py b/tianshou/algorithm/imitation/gail.py index dbeddd1c3..3d93ae8e9 100644 --- a/tianshou/algorithm/imitation/gail.py +++ b/tianshou/algorithm/imitation/gail.py @@ -5,8 +5,8 @@ import torch.nn.functional as F from tianshou.algorithm.modelfree.a2c import A2CTrainingStats -from tianshou.algorithm.modelfree.pg import ActorPolicyProbabilistic from tianshou.algorithm.modelfree.ppo import PPO +from tianshou.algorithm.modelfree.reinforce import ActorPolicyProbabilistic from tianshou.algorithm.optim import OptimizerFactory from tianshou.data import ( ReplayBuffer, diff --git a/tianshou/algorithm/modelfree/a2c.py b/tianshou/algorithm/modelfree/a2c.py index f4f1d4b1e..91d7cbe9a 100644 --- a/tianshou/algorithm/modelfree/a2c.py +++ b/tianshou/algorithm/modelfree/a2c.py @@ -10,7 +10,7 @@ OnPolicyAlgorithm, TrainingStats, ) -from tianshou.algorithm.modelfree.pg import ActorPolicyProbabilistic +from tianshou.algorithm.modelfree.reinforce import ActorPolicyProbabilistic from tianshou.algorithm.optim import OptimizerFactory from tianshou.data import ReplayBuffer, SequenceSummaryStats, to_torch_as from tianshou.data.types import BatchWithAdvantagesProtocol, RolloutBatchProtocol diff --git a/tianshou/algorithm/modelfree/bdqn.py b/tianshou/algorithm/modelfree/bdqn.py index fc965d1eb..9eb5f0c79 100644 --- a/tianshou/algorithm/modelfree/bdqn.py +++ b/tianshou/algorithm/modelfree/bdqn.py @@ -10,7 +10,7 @@ DiscreteQLearningPolicy, QLearningOffPolicyAlgorithm, ) -from tianshou.algorithm.modelfree.pg import SimpleLossTrainingStats +from tianshou.algorithm.modelfree.reinforce import SimpleLossTrainingStats from tianshou.algorithm.optim import 
OptimizerFactory from tianshou.data import Batch, ReplayBuffer, to_numpy, to_torch, to_torch_as from tianshou.data.batch import BatchProtocol diff --git a/tianshou/algorithm/modelfree/c51.py b/tianshou/algorithm/modelfree/c51.py index 4f2917bd0..8ca11f37d 100644 --- a/tianshou/algorithm/modelfree/c51.py +++ b/tianshou/algorithm/modelfree/c51.py @@ -6,7 +6,7 @@ DiscreteQLearningPolicy, QLearningOffPolicyAlgorithm, ) -from tianshou.algorithm.modelfree.pg import LossSequenceTrainingStats +from tianshou.algorithm.modelfree.reinforce import LossSequenceTrainingStats from tianshou.algorithm.optim import OptimizerFactory from tianshou.data import Batch, ReplayBuffer from tianshou.data.types import RolloutBatchProtocol diff --git a/tianshou/algorithm/modelfree/dqn.py b/tianshou/algorithm/modelfree/dqn.py index f1286120f..7166bfe15 100644 --- a/tianshou/algorithm/modelfree/dqn.py +++ b/tianshou/algorithm/modelfree/dqn.py @@ -14,7 +14,7 @@ Policy, TArrOrActBatch, ) -from tianshou.algorithm.modelfree.pg import ( +from tianshou.algorithm.modelfree.reinforce import ( SimpleLossTrainingStats, ) from tianshou.algorithm.optim import OptimizerFactory diff --git a/tianshou/algorithm/modelfree/fqf.py b/tianshou/algorithm/modelfree/fqf.py index 25e7a1047..1f93c1e8b 100644 --- a/tianshou/algorithm/modelfree/fqf.py +++ b/tianshou/algorithm/modelfree/fqf.py @@ -9,8 +9,8 @@ from tianshou.algorithm import QRDQN, Algorithm from tianshou.algorithm.modelfree.dqn import DiscreteQLearningPolicy -from tianshou.algorithm.modelfree.pg import SimpleLossTrainingStats from tianshou.algorithm.modelfree.qrdqn import QRDQNPolicy +from tianshou.algorithm.modelfree.reinforce import SimpleLossTrainingStats from tianshou.algorithm.optim import OptimizerFactory from tianshou.data import Batch, ReplayBuffer, to_numpy from tianshou.data.types import FQFBatchProtocol, ObsBatchProtocol, RolloutBatchProtocol diff --git a/tianshou/algorithm/modelfree/iqn.py b/tianshou/algorithm/modelfree/iqn.py index 
f9610732f..dd20e12c4 100644 --- a/tianshou/algorithm/modelfree/iqn.py +++ b/tianshou/algorithm/modelfree/iqn.py @@ -6,8 +6,8 @@ import torch.nn.functional as F from tianshou.algorithm import QRDQN -from tianshou.algorithm.modelfree.pg import SimpleLossTrainingStats from tianshou.algorithm.modelfree.qrdqn import QRDQNPolicy +from tianshou.algorithm.modelfree.reinforce import SimpleLossTrainingStats from tianshou.algorithm.optim import OptimizerFactory from tianshou.data import Batch, to_numpy from tianshou.data.batch import BatchProtocol diff --git a/tianshou/algorithm/modelfree/npg.py b/tianshou/algorithm/modelfree/npg.py index 21c200ea5..93f327e0f 100644 --- a/tianshou/algorithm/modelfree/npg.py +++ b/tianshou/algorithm/modelfree/npg.py @@ -9,7 +9,7 @@ from tianshou.algorithm.algorithm_base import TrainingStats from tianshou.algorithm.modelfree.a2c import ActorCriticOnPolicyAlgorithm -from tianshou.algorithm.modelfree.pg import ActorPolicyProbabilistic +from tianshou.algorithm.modelfree.reinforce import ActorPolicyProbabilistic from tianshou.algorithm.optim import OptimizerFactory from tianshou.data import ReplayBuffer, SequenceSummaryStats, to_torch_as from tianshou.data.types import BatchWithAdvantagesProtocol, RolloutBatchProtocol diff --git a/tianshou/algorithm/modelfree/ppo.py b/tianshou/algorithm/modelfree/ppo.py index a81bf4ffc..a1395eddc 100644 --- a/tianshou/algorithm/modelfree/ppo.py +++ b/tianshou/algorithm/modelfree/ppo.py @@ -5,7 +5,7 @@ from tianshou.algorithm import A2C from tianshou.algorithm.modelfree.a2c import A2CTrainingStats -from tianshou.algorithm.modelfree.pg import ActorPolicyProbabilistic +from tianshou.algorithm.modelfree.reinforce import ActorPolicyProbabilistic from tianshou.algorithm.optim import OptimizerFactory from tianshou.data import ReplayBuffer, SequenceSummaryStats, to_torch_as from tianshou.data.types import LogpOldProtocol, RolloutBatchProtocol diff --git a/tianshou/algorithm/modelfree/qrdqn.py 
b/tianshou/algorithm/modelfree/qrdqn.py index 7602a1887..af1cb416a 100644 --- a/tianshou/algorithm/modelfree/qrdqn.py +++ b/tianshou/algorithm/modelfree/qrdqn.py @@ -9,7 +9,7 @@ DiscreteQLearningPolicy, QLearningOffPolicyAlgorithm, ) -from tianshou.algorithm.modelfree.pg import SimpleLossTrainingStats +from tianshou.algorithm.modelfree.reinforce import SimpleLossTrainingStats from tianshou.algorithm.optim import OptimizerFactory from tianshou.data import Batch, ReplayBuffer from tianshou.data.types import RolloutBatchProtocol diff --git a/tianshou/algorithm/modelfree/pg.py b/tianshou/algorithm/modelfree/reinforce.py similarity index 100% rename from tianshou/algorithm/modelfree/pg.py rename to tianshou/algorithm/modelfree/reinforce.py diff --git a/tianshou/algorithm/modelfree/trpo.py b/tianshou/algorithm/modelfree/trpo.py index f30ef7fba..5b7372143 100644 --- a/tianshou/algorithm/modelfree/trpo.py +++ b/tianshou/algorithm/modelfree/trpo.py @@ -7,7 +7,7 @@ from tianshou.algorithm import NPG from tianshou.algorithm.modelfree.npg import NPGTrainingStats -from tianshou.algorithm.modelfree.pg import ActorPolicyProbabilistic +from tianshou.algorithm.modelfree.reinforce import ActorPolicyProbabilistic from tianshou.algorithm.optim import OptimizerFactory from tianshou.data import SequenceSummaryStats from tianshou.data.types import BatchWithAdvantagesProtocol diff --git a/tianshou/env/atari/atari_network.py b/tianshou/env/atari/atari_network.py index 1d83bac6d..9cfc27a24 100644 --- a/tianshou/env/atari/atari_network.py +++ b/tianshou/env/atari/atari_network.py @@ -5,7 +5,7 @@ import torch from torch import nn -from tianshou.algorithm.modelfree.pg import TDistFnDiscrOrCont +from tianshou.algorithm.modelfree.reinforce import TDistFnDiscrOrCont from tianshou.highlevel.env import Environments from tianshou.highlevel.module.actor import ActorFactory from tianshou.highlevel.module.core import ( diff --git a/tianshou/highlevel/algorithm.py b/tianshou/highlevel/algorithm.py index 
8a8ecce71..b46aa6e74 100644 --- a/tianshou/highlevel/algorithm.py +++ b/tianshou/highlevel/algorithm.py @@ -31,8 +31,8 @@ from tianshou.algorithm.modelfree.discrete_sac import DiscreteSACPolicy from tianshou.algorithm.modelfree.dqn import DiscreteQLearningPolicy from tianshou.algorithm.modelfree.iqn import IQNPolicy -from tianshou.algorithm.modelfree.pg import ActorPolicyProbabilistic from tianshou.algorithm.modelfree.redq import REDQPolicy +from tianshou.algorithm.modelfree.reinforce import ActorPolicyProbabilistic from tianshou.algorithm.modelfree.sac import SACPolicy from tianshou.data import Collector, ReplayBuffer, VectorReplayBuffer from tianshou.data.collector import BaseCollector, CollectStats diff --git a/tianshou/highlevel/module/actor.py b/tianshou/highlevel/module/actor.py index 5373c7a4f..f46203126 100644 --- a/tianshou/highlevel/module/actor.py +++ b/tianshou/highlevel/module/actor.py @@ -8,7 +8,7 @@ from sensai.util.string import ToStringMixin from torch import nn -from tianshou.algorithm.modelfree.pg import TDistFnDiscrOrCont +from tianshou.algorithm.modelfree.reinforce import TDistFnDiscrOrCont from tianshou.highlevel.env import Environments, EnvType from tianshou.highlevel.module.core import ( ModuleFactory, diff --git a/tianshou/highlevel/params/dist_fn.py b/tianshou/highlevel/params/dist_fn.py index b55f33b30..d75096bc3 100644 --- a/tianshou/highlevel/params/dist_fn.py +++ b/tianshou/highlevel/params/dist_fn.py @@ -5,7 +5,7 @@ import torch from sensai.util.string import ToStringMixin -from tianshou.algorithm.modelfree.pg import TDistFnDiscrete, TDistFnDiscrOrCont +from tianshou.algorithm.modelfree.reinforce import TDistFnDiscrete, TDistFnDiscrOrCont from tianshou.highlevel.env import Environments From d0c72e6ba3d43619f410802050408d842e875a3d Mon Sep 17 00:00:00 2001 From: Dominik Jain Date: Fri, 16 May 2025 13:31:28 +0200 Subject: [PATCH 181/230] v2: Improve docstrings of actors --- tianshou/utils/net/common.py | 19 +++++++++++++++---- 
tianshou/utils/net/discrete.py | 13 +++++-------- 2 files changed, 20 insertions(+), 12 deletions(-) diff --git a/tianshou/utils/net/common.py b/tianshou/utils/net/common.py index 1516bc936..284c7f952 100644 --- a/tianshou/utils/net/common.py +++ b/tianshou/utils/net/common.py @@ -690,13 +690,24 @@ def forward( class ContinuousActorProbabilisticInterface(Actor, ABC): - """Marker interface for probabilistic actors defined by users (outside of Tianshou code).""" + """Type bound for probabilistic actors which output distribution parameters for continuous action spaces.""" class DiscreteActorInterface(Actor, ABC): - """Marker interface for discrete actors defined by users (outside of Tianshou code). - - See docstring of :class:`DiscreteActor` + """ + Type bound for discrete actors. + + For on-policy algos like Reinforce, this typically directly outputs unnormalized log + probabilities, which can be interpreted as "logits" in conjunction with a + `torch.distributions.Categorical` instance. + + In Tianshou, discrete actors are also used for computing action distributions within + Q-learning type algorithms (e.g., DQN). In this case, the observations are mapped + to a vector of Q-values (one for each action). In other words, the component is actually + a critic, not an actor in the traditional sense. + Note that when sampling actions, the Q-values can be interpreted as inputs for + a `torch.distributions.Categorical` instance, similar to the on-policy case mentioned + above. """ diff --git a/tianshou/utils/net/discrete.py b/tianshou/utils/net/discrete.py index e1405be1f..4ea673155 100644 --- a/tianshou/utils/net/discrete.py +++ b/tianshou/utils/net/discrete.py @@ -24,14 +24,11 @@ def dist_fn_categorical_from_logits(logits: torch.Tensor) -> torch.distributions class DiscreteActor(DiscreteActorInterface): - """For on-policy algos like Reinforce, this usually directly outputs unnormalized log - probabilities. 
- - In Tianshou, discrete actors are also used for computing action distributions within - Q-learning type algorithms, discrete actors - typically the values of the Q function for each action (as tensor), - which are then later re-interpreted as unnormalized log-probabilities for sampling - discrete actions. So such an actor is essentially a critic. + """ + Generic discrete actor which uses a preprocessing network to generate a latent representation + which is subsequently passed to an MLP to compute the output. + + For common output semantics, see :class:`DiscreteActorInterface`. """ def __init__( From ed195c181db78e35c41c71fd4e0de4311b7306e3 Mon Sep 17 00:00:00 2001 From: Dominik Jain Date: Fri, 16 May 2025 17:32:39 +0200 Subject: [PATCH 182/230] v2: Fix import --- tianshou/algorithm/modelfree/rainbow.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tianshou/algorithm/modelfree/rainbow.py b/tianshou/algorithm/modelfree/rainbow.py index 4ea266d2d..e04c6876e 100644 --- a/tianshou/algorithm/modelfree/rainbow.py +++ b/tianshou/algorithm/modelfree/rainbow.py @@ -3,7 +3,7 @@ from torch import nn from tianshou.algorithm.modelfree.c51 import C51, C51Policy -from tianshou.algorithm.modelfree.pg import LossSequenceTrainingStats +from tianshou.algorithm.modelfree.reinforce import LossSequenceTrainingStats from tianshou.algorithm.optim import OptimizerFactory from tianshou.data.types import RolloutBatchProtocol from tianshou.utils.lagged_network import EvalModeModuleWrapper From 9728bc05fc854c769a06374ac6fe3ebc3d835b21 Mon Sep 17 00:00:00 2001 From: Dominik Jain Date: Fri, 16 May 2025 17:45:20 +0200 Subject: [PATCH 183/230] v2: Fix assertion (stats can be None) --- tianshou/trainer/trainer.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tianshou/trainer/trainer.py b/tianshou/trainer/trainer.py index 43a29ffa7..212eb2c8e 100644 --- a/tianshou/trainer/trainer.py +++ b/tianshou/trainer/trainer.py @@ -563,10 +563,9 @@ def 
execute_epoch(self) -> EpochStats: self._stop_fn_flag = training_step_result.is_training_done() self._env_step += training_step_result.get_env_step_advancement() training_stats = training_step_result.get_training_stats() - assert training_stats is not None TraceLogger.log( log, - lambda: f"Training step complete: stats={training_stats.get_loss_stats_dict()}", + lambda: f"Training step complete: stats={training_stats.get_loss_stats_dict() if training_stats is not None else None}", ) self._log_params(self.algorithm) From 9ba76b74dd4d85a583fb68c3ad3a36621e74433c Mon Sep 17 00:00:00 2001 From: Dominik Jain Date: Fri, 16 May 2025 17:29:34 +0200 Subject: [PATCH 184/230] v2: Rename trainer parameters: * max_epoch | num_epochs -> max_epochs * step_per_epoch -> epoch_num_steps * episode_per_test | num_test_episode -> test_step_num_episodes * step_per_collect -> collection_step_num_env_steps * episode_per_collect -> collection_step_num_episodes * update_per_step -> update_step_num_gradient_steps_per_sample * repeat_per_collect -> update_step_num_repetitions --- CHANGELOG.md | 11 +++- examples/atari/atari_c51.py | 10 +-- examples/atari/atari_dqn.py | 10 +-- examples/atari/atari_dqn_hl.py | 6 +- examples/atari/atari_fqf.py | 10 +-- examples/atari/atari_iqn.py | 10 +-- examples/atari/atari_iqn_hl.py | 6 +- examples/atari/atari_ppo.py | 16 ++--- examples/atari/atari_ppo_hl.py | 6 +- examples/atari/atari_qrdqn.py | 10 +-- examples/atari/atari_rainbow.py | 10 +-- examples/atari/atari_sac.py | 10 +-- examples/atari/atari_sac_hl.py | 6 +- examples/box2d/acrobot_dualdqn.py | 10 +-- examples/box2d/bipedal_bdq.py | 10 +-- examples/box2d/bipedal_hardcore_sac.py | 10 +-- examples/box2d/lunarlander_dqn.py | 10 +-- examples/box2d/mcc_sac.py | 10 +-- examples/discrete/discrete_dqn.py | 10 +-- examples/discrete/discrete_dqn_hl.py | 6 +- examples/inverse/irl_gail.py | 16 ++--- examples/mujoco/fetch_her_ddpg.py | 10 +-- examples/mujoco/mujoco_a2c.py | 16 ++--- examples/mujoco/mujoco_a2c_hl.py 
| 6 +- examples/mujoco/mujoco_ddpg.py | 10 +-- examples/mujoco/mujoco_ddpg_hl.py | 6 +- examples/mujoco/mujoco_npg.py | 16 ++--- examples/mujoco/mujoco_npg_hl.py | 6 +- examples/mujoco/mujoco_ppo.py | 16 ++--- examples/mujoco/mujoco_ppo_hl.py | 6 +- examples/mujoco/mujoco_ppo_hl_multi.py | 8 +-- examples/mujoco/mujoco_redq.py | 10 +-- examples/mujoco/mujoco_redq_hl.py | 6 +- examples/mujoco/mujoco_reinforce.py | 16 ++--- examples/mujoco/mujoco_reinforce_hl.py | 6 +- examples/mujoco/mujoco_sac.py | 10 +-- examples/mujoco/mujoco_sac_hl.py | 6 +- examples/mujoco/mujoco_td3.py | 10 +-- examples/mujoco/mujoco_td3_hl.py | 6 +- examples/mujoco/mujoco_trpo.py | 16 ++--- examples/mujoco/mujoco_trpo_hl.py | 6 +- examples/offline/atari_bcq.py | 6 +- examples/offline/atari_cql.py | 6 +- examples/offline/atari_crr.py | 6 +- examples/offline/atari_il.py | 6 +- examples/offline/d4rl_bcq.py | 6 +- examples/offline/d4rl_cql.py | 6 +- examples/offline/d4rl_il.py | 6 +- examples/offline/d4rl_td3_bc.py | 6 +- examples/vizdoom/vizdoom_c51.py | 10 +-- examples/vizdoom/vizdoom_ppo.py | 16 ++--- test/continuous/test_ddpg.py | 10 +-- test/continuous/test_npg.py | 10 +-- test/continuous/test_ppo.py | 12 ++-- test/continuous/test_redq.py | 10 +-- test/continuous/test_sac_with_il.py | 18 +++--- test/continuous/test_td3.py | 10 +-- test/continuous/test_trpo.py | 10 +-- test/discrete/test_a2c_with_il.py | 20 +++--- test/discrete/test_bdqn.py | 10 +-- test/discrete/test_c51.py | 10 +-- test/discrete/test_discrete_sac.py | 10 +-- test/discrete/test_dqn.py | 10 +-- test/discrete/test_drqn.py | 10 +-- test/discrete/test_fqf.py | 10 +-- test/discrete/test_iqn.py | 10 +-- test/discrete/test_ppo_discrete.py | 10 +-- test/discrete/test_qrdqn.py | 10 +-- test/discrete/test_rainbow.py | 10 +-- test/discrete/test_reinforce.py | 12 ++-- test/highlevel/test_experiment_builder.py | 20 +++--- test/modelbased/test_dqn_icm.py | 10 +-- test/modelbased/test_ppo_icm.py | 10 +-- test/modelbased/test_psrl.py | 12 
++-- test/offline/gather_cartpole_data.py | 10 +-- test/offline/gather_pendulum_data.py | 10 +-- test/offline/test_bcq.py | 6 +- test/offline/test_cql.py | 6 +- test/offline/test_discrete_bcq.py | 6 +- test/offline/test_discrete_cql.py | 6 +- test/offline/test_discrete_crr.py | 6 +- test/offline/test_gail.py | 12 ++-- test/offline/test_td3_bc.py | 6 +- test/pettingzoo/pistonball.py | 10 +-- test/pettingzoo/pistonball_continuous.py | 12 ++-- test/pettingzoo/tic_tac_toe.py | 10 +-- tianshou/algorithm/optim.py | 11 ++-- tianshou/highlevel/algorithm.py | 20 +++--- tianshou/highlevel/config.py | 68 ++++++++++++--------- tianshou/highlevel/params/lr_scheduler.py | 12 ++-- tianshou/trainer/trainer.py | 74 ++++++++++++----------- 91 files changed, 522 insertions(+), 496 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index b5a60bbe9..d17cf1b65 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -63,7 +63,15 @@ Developers: * Migration information at a glance: * Training parameters are now passed via instances of configuration objects instead of directly as keyword arguments: `OnPolicyTrainerParams`, `OffPolicyTrainerParams`, `OfflineTrainerParams`. - * Changed parameter default: Default for `test_in_train` was changed from True to False. + * Changed parameter default: Default for `test_in_train` was changed from True to False. 
+ * Changed parameter names to improve clarity: + * `max_epoch` (`num_epochs` in high-level API) -> `max_epochs` + * `step_per_epoch` -> `epoch_num_steps` + * `episode_per_test` (`num_test_episodes` in high-level API) -> `test_step_num_episodes` + * `step_per_collect` -> `collection_step_num_env_steps` + * `episode_per_collect` -> `collection_step_num_episodes` + * `update_per_step` -> `update_step_num_gradient_steps_per_sample` + * `repeat_per_collect` -> `update_step_num_repetitions` * Trainer classes have been renamed: * `OnpolicyTrainer` -> `OnPolicyTrainer` * `OffpolicyTrainer` -> `OffPolicyTrainer` @@ -182,6 +190,7 @@ Developers: * The `test_in_train` parameter is now exposed (default False). * Inapplicable arguments can no longer be set in the respective subclass (e.g. `OffPolicyTrainingConfig` does not contain parameter `repeat_per_collect`). + * All parameter names have been aligned with the new names used by `TrainerParams` (see above). ### Peripheral Changes diff --git a/examples/atari/atari_c51.py b/examples/atari/atari_c51.py index 990073349..b24ca107b 100644 --- a/examples/atari/atari_c51.py +++ b/examples/atari/atari_c51.py @@ -206,16 +206,16 @@ def watch() -> None: OffPolicyTrainerParams( train_collector=train_collector, test_collector=test_collector, - max_epoch=args.epoch, - step_per_epoch=args.step_per_epoch, - step_per_collect=args.step_per_collect, - episode_per_test=args.test_num, + max_epochs=args.epoch, + epoch_num_steps=args.step_per_epoch, + collection_step_num_env_steps=args.step_per_collect, + test_step_num_episodes=args.test_num, batch_size=args.batch_size, train_fn=train_fn, stop_fn=stop_fn, save_best_fn=save_best_fn, logger=logger, - update_per_step=args.update_per_step, + update_step_num_gradient_steps_per_sample=args.update_per_step, test_in_train=False, ) ) diff --git a/examples/atari/atari_dqn.py b/examples/atari/atari_dqn.py index b8242163a..df046018b 100644 --- a/examples/atari/atari_dqn.py +++ b/examples/atari/atari_dqn.py @@
-249,16 +249,16 @@ def watch() -> None: OffPolicyTrainerParams( train_collector=train_collector, test_collector=test_collector, - max_epoch=args.epoch, - step_per_epoch=args.step_per_epoch, - step_per_collect=args.step_per_collect, - episode_per_test=args.test_num, + max_epochs=args.epoch, + epoch_num_steps=args.step_per_epoch, + collection_step_num_env_steps=args.step_per_collect, + test_step_num_episodes=args.test_num, batch_size=args.batch_size, train_fn=train_fn, stop_fn=stop_fn, save_best_fn=save_best_fn, logger=logger, - update_per_step=args.update_per_step, + update_step_num_gradient_steps_per_sample=args.update_per_step, test_in_train=False, resume_from_log=args.resume_id is not None, save_checkpoint_fn=save_checkpoint_fn, diff --git a/examples/atari/atari_dqn_hl.py b/examples/atari/atari_dqn_hl.py index c9fcffa52..25dc84c72 100644 --- a/examples/atari/atari_dqn_hl.py +++ b/examples/atari/atari_dqn_hl.py @@ -52,14 +52,14 @@ def main( log_name = os.path.join(task, "dqn", str(experiment_config.seed), datetime_tag()) training_config = OffPolicyTrainingConfig( - num_epochs=epoch, - step_per_epoch=step_per_epoch, + max_epochs=epoch, + epoch_num_steps=step_per_epoch, batch_size=batch_size, num_train_envs=training_num, num_test_envs=test_num, buffer_size=buffer_size, step_per_collect=step_per_collect, - update_per_step=update_per_step, + update_step_num_gradient_steps_per_sample=update_per_step, replay_buffer_stack_num=frames_stack, replay_buffer_ignore_obs_next=True, replay_buffer_save_only_last_obs=True, diff --git a/examples/atari/atari_fqf.py b/examples/atari/atari_fqf.py index 86c44ac3e..49fe491b4 100644 --- a/examples/atari/atari_fqf.py +++ b/examples/atari/atari_fqf.py @@ -223,17 +223,17 @@ def watch() -> None: OffPolicyTrainerParams( train_collector=train_collector, test_collector=test_collector, - max_epoch=args.epoch, - step_per_epoch=args.step_per_epoch, - step_per_collect=args.step_per_collect, - episode_per_test=args.test_num, + max_epochs=args.epoch, 
+ epoch_num_steps=args.step_per_epoch, + collection_step_num_env_steps=args.step_per_collect, + test_step_num_episodes=args.test_num, batch_size=args.batch_size, train_fn=train_fn, test_fn=test_fn, stop_fn=stop_fn, save_best_fn=save_best_fn, logger=logger, - update_per_step=args.update_per_step, + update_step_num_gradient_steps_per_sample=args.update_per_step, test_in_train=False, ) ) diff --git a/examples/atari/atari_iqn.py b/examples/atari/atari_iqn.py index 6fa661296..591484f8f 100644 --- a/examples/atari/atari_iqn.py +++ b/examples/atari/atari_iqn.py @@ -217,16 +217,16 @@ def watch() -> None: OffPolicyTrainerParams( train_collector=train_collector, test_collector=test_collector, - max_epoch=args.epoch, - step_per_epoch=args.step_per_epoch, - step_per_collect=args.step_per_collect, - episode_per_test=args.test_num, + max_epochs=args.epoch, + epoch_num_steps=args.step_per_epoch, + collection_step_num_env_steps=args.step_per_collect, + test_step_num_episodes=args.test_num, batch_size=args.batch_size, train_fn=train_fn, stop_fn=stop_fn, save_best_fn=save_best_fn, logger=logger, - update_per_step=args.update_per_step, + update_step_num_gradient_steps_per_sample=args.update_per_step, test_in_train=False, ) ) diff --git a/examples/atari/atari_iqn_hl.py b/examples/atari/atari_iqn_hl.py index 8aa335cea..fc1a23156 100644 --- a/examples/atari/atari_iqn_hl.py +++ b/examples/atari/atari_iqn_hl.py @@ -50,14 +50,14 @@ def main( log_name = os.path.join(task, "iqn", str(experiment_config.seed), datetime_tag()) training_config = OffPolicyTrainingConfig( - num_epochs=epoch, - step_per_epoch=step_per_epoch, + max_epochs=epoch, + epoch_num_steps=step_per_epoch, batch_size=batch_size, num_train_envs=training_num, num_test_envs=test_num, buffer_size=buffer_size, step_per_collect=step_per_collect, - update_per_step=update_per_step, + update_step_num_gradient_steps_per_sample=update_per_step, replay_buffer_stack_num=frames_stack, replay_buffer_ignore_obs_next=True, 
replay_buffer_save_only_last_obs=True, diff --git a/examples/atari/atari_ppo.py b/examples/atari/atari_ppo.py index 7b1ea4c68..8752cee0c 100644 --- a/examples/atari/atari_ppo.py +++ b/examples/atari/atari_ppo.py @@ -141,9 +141,9 @@ def main(args: argparse.Namespace = get_args()) -> None: if args.lr_decay: optim.with_lr_scheduler_factory( LRSchedulerFactoryLinear( - num_epochs=args.epoch, - step_per_epoch=args.step_per_epoch, - step_per_collect=args.step_per_collect, + max_epochs=args.epoch, + epoch_num_steps=args.step_per_epoch, + collection_step_num_env_steps=args.step_per_collect, ) ) @@ -280,12 +280,12 @@ def watch() -> None: OnPolicyTrainerParams( train_collector=train_collector, test_collector=test_collector, - max_epoch=args.epoch, - step_per_epoch=args.step_per_epoch, - repeat_per_collect=args.repeat_per_collect, - episode_per_test=args.test_num, + max_epochs=args.epoch, + epoch_num_steps=args.step_per_epoch, + update_step_num_repetitions=args.repeat_per_collect, + test_step_num_episodes=args.test_num, batch_size=args.batch_size, - step_per_collect=args.step_per_collect, + collection_step_num_env_steps=args.step_per_collect, stop_fn=stop_fn, save_best_fn=save_best_fn, logger=logger, diff --git a/examples/atari/atari_ppo_hl.py b/examples/atari/atari_ppo_hl.py index 7f4186dbd..2f1fdcb7a 100644 --- a/examples/atari/atari_ppo_hl.py +++ b/examples/atari/atari_ppo_hl.py @@ -58,14 +58,14 @@ def main( log_name = os.path.join(task, "ppo", str(experiment_config.seed), datetime_tag()) training_config = OnPolicyTrainingConfig( - num_epochs=epoch, - step_per_epoch=step_per_epoch, + max_epochs=epoch, + epoch_num_steps=step_per_epoch, batch_size=batch_size, num_train_envs=training_num, num_test_envs=test_num, buffer_size=buffer_size, step_per_collect=step_per_collect, - repeat_per_collect=repeat_per_collect, + update_step_num_repetitions=repeat_per_collect, replay_buffer_stack_num=frames_stack, replay_buffer_ignore_obs_next=True, replay_buffer_save_only_last_obs=True, diff 
--git a/examples/atari/atari_qrdqn.py b/examples/atari/atari_qrdqn.py index bd0c426ff..fc1e1a125 100644 --- a/examples/atari/atari_qrdqn.py +++ b/examples/atari/atari_qrdqn.py @@ -211,16 +211,16 @@ def watch() -> None: OffPolicyTrainerParams( train_collector=train_collector, test_collector=test_collector, - max_epoch=args.epoch, - step_per_epoch=args.step_per_epoch, - step_per_collect=args.step_per_collect, - episode_per_test=args.test_num, + max_epochs=args.epoch, + epoch_num_steps=args.step_per_epoch, + collection_step_num_env_steps=args.step_per_collect, + test_step_num_episodes=args.test_num, batch_size=args.batch_size, train_fn=train_fn, stop_fn=stop_fn, save_best_fn=save_best_fn, logger=logger, - update_per_step=args.update_per_step, + update_step_num_gradient_steps_per_sample=args.update_per_step, test_in_train=False, ) ) diff --git a/examples/atari/atari_rainbow.py b/examples/atari/atari_rainbow.py index 64d5f9856..c86555ecf 100644 --- a/examples/atari/atari_rainbow.py +++ b/examples/atari/atari_rainbow.py @@ -254,16 +254,16 @@ def watch() -> None: OffPolicyTrainerParams( train_collector=train_collector, test_collector=test_collector, - max_epoch=args.epoch, - step_per_epoch=args.step_per_epoch, - step_per_collect=args.step_per_collect, - episode_per_test=args.test_num, + max_epochs=args.epoch, + epoch_num_steps=args.step_per_epoch, + collection_step_num_env_steps=args.step_per_collect, + test_step_num_episodes=args.test_num, batch_size=args.batch_size, train_fn=train_fn, stop_fn=stop_fn, save_best_fn=save_best_fn, logger=logger, - update_per_step=args.update_per_step, + update_step_num_gradient_steps_per_sample=args.update_per_step, test_in_train=False, ) ) diff --git a/examples/atari/atari_sac.py b/examples/atari/atari_sac.py index d57608d18..7e1774424 100644 --- a/examples/atari/atari_sac.py +++ b/examples/atari/atari_sac.py @@ -265,15 +265,15 @@ def watch() -> None: OffPolicyTrainerParams( train_collector=train_collector, test_collector=test_collector, 
- max_epoch=args.epoch, - step_per_epoch=args.step_per_epoch, - step_per_collect=args.step_per_collect, - episode_per_test=args.test_num, + max_epochs=args.epoch, + epoch_num_steps=args.step_per_epoch, + collection_step_num_env_steps=args.step_per_collect, + test_step_num_episodes=args.test_num, batch_size=args.batch_size, stop_fn=stop_fn, save_best_fn=save_best_fn, logger=logger, - update_per_step=args.update_per_step, + update_step_num_gradient_steps_per_sample=args.update_per_step, test_in_train=False, resume_from_log=args.resume_id is not None, save_checkpoint_fn=save_checkpoint_fn, diff --git a/examples/atari/atari_sac_hl.py b/examples/atari/atari_sac_hl.py index 429f5360d..ddbe98840 100644 --- a/examples/atari/atari_sac_hl.py +++ b/examples/atari/atari_sac_hl.py @@ -52,9 +52,9 @@ def main( log_name = os.path.join(task, "sac", str(experiment_config.seed), datetime_tag()) training_config = OffPolicyTrainingConfig( - num_epochs=epoch, - step_per_epoch=step_per_epoch, - update_per_step=update_per_step, + max_epochs=epoch, + epoch_num_steps=step_per_epoch, + update_step_num_gradient_steps_per_sample=update_per_step, batch_size=batch_size, num_train_envs=training_num, num_test_envs=test_num, diff --git a/examples/box2d/acrobot_dualdqn.py b/examples/box2d/acrobot_dualdqn.py index 18f833259..0a8f3a9ea 100644 --- a/examples/box2d/acrobot_dualdqn.py +++ b/examples/box2d/acrobot_dualdqn.py @@ -129,12 +129,12 @@ def train_fn(epoch: int, env_step: int) -> None: OffPolicyTrainerParams( train_collector=train_collector, test_collector=test_collector, - max_epoch=args.epoch, - step_per_epoch=args.step_per_epoch, - step_per_collect=args.step_per_collect, - episode_per_test=args.test_num, + max_epochs=args.epoch, + epoch_num_steps=args.step_per_epoch, + collection_step_num_env_steps=args.step_per_collect, + test_step_num_episodes=args.test_num, batch_size=args.batch_size, - update_per_step=args.update_per_step, + update_step_num_gradient_steps_per_sample=args.update_per_step, 
train_fn=train_fn, stop_fn=stop_fn, save_best_fn=save_best_fn, diff --git a/examples/box2d/bipedal_bdq.py b/examples/box2d/bipedal_bdq.py index 19d9ba467..9d40d12ae 100644 --- a/examples/box2d/bipedal_bdq.py +++ b/examples/box2d/bipedal_bdq.py @@ -147,12 +147,12 @@ def train_fn(epoch: int, env_step: int) -> None: # exp decay OffPolicyTrainerParams( train_collector=train_collector, test_collector=test_collector, - max_epoch=args.epoch, - step_per_epoch=args.step_per_epoch, - step_per_collect=args.step_per_collect, - episode_per_test=args.test_num, + max_epochs=args.epoch, + epoch_num_steps=args.step_per_epoch, + collection_step_num_env_steps=args.step_per_collect, + test_step_num_episodes=args.test_num, batch_size=args.batch_size, - update_per_step=args.update_per_step, + update_step_num_gradient_steps_per_sample=args.update_per_step, stop_fn=stop_fn, train_fn=train_fn, save_best_fn=save_best_fn, diff --git a/examples/box2d/bipedal_hardcore_sac.py b/examples/box2d/bipedal_hardcore_sac.py index 382ca01dd..d1ce67760 100644 --- a/examples/box2d/bipedal_hardcore_sac.py +++ b/examples/box2d/bipedal_hardcore_sac.py @@ -194,12 +194,12 @@ def stop_fn(mean_rewards: float) -> bool: OffPolicyTrainerParams( train_collector=train_collector, test_collector=test_collector, - max_epoch=args.epoch, - step_per_epoch=args.step_per_epoch, - step_per_collect=args.step_per_collect, - episode_per_test=args.test_num, + max_epochs=args.epoch, + epoch_num_steps=args.step_per_epoch, + collection_step_num_env_steps=args.step_per_collect, + test_step_num_episodes=args.test_num, batch_size=args.batch_size, - update_per_step=args.update_per_step, + update_step_num_gradient_steps_per_sample=args.update_per_step, test_in_train=False, stop_fn=stop_fn, save_best_fn=save_best_fn, diff --git a/examples/box2d/lunarlander_dqn.py b/examples/box2d/lunarlander_dqn.py index 8d142eb2e..9b6cbc4a6 100644 --- a/examples/box2d/lunarlander_dqn.py +++ b/examples/box2d/lunarlander_dqn.py @@ -126,12 +126,12 @@ def 
train_fn(epoch: int, env_step: int) -> None: # exp decay OffPolicyTrainerParams( train_collector=train_collector, test_collector=test_collector, - max_epoch=args.epoch, - step_per_epoch=args.step_per_epoch, - step_per_collect=args.step_per_collect, - episode_per_test=args.test_num, + max_epochs=args.epoch, + epoch_num_steps=args.step_per_epoch, + collection_step_num_env_steps=args.step_per_collect, + test_step_num_episodes=args.test_num, batch_size=args.batch_size, - update_per_step=args.update_per_step, + update_step_num_gradient_steps_per_sample=args.update_per_step, stop_fn=stop_fn, train_fn=train_fn, save_best_fn=save_best_fn, diff --git a/examples/box2d/mcc_sac.py b/examples/box2d/mcc_sac.py index 44c84dc25..626c80ec0 100644 --- a/examples/box2d/mcc_sac.py +++ b/examples/box2d/mcc_sac.py @@ -143,12 +143,12 @@ def stop_fn(mean_rewards: float) -> bool: OffPolicyTrainerParams( train_collector=train_collector, test_collector=test_collector, - max_epoch=args.epoch, - step_per_epoch=args.step_per_epoch, - step_per_collect=args.step_per_collect, - episode_per_test=args.test_num, + max_epochs=args.epoch, + epoch_num_steps=args.step_per_epoch, + collection_step_num_env_steps=args.step_per_collect, + test_step_num_episodes=args.test_num, batch_size=args.batch_size, - update_per_step=args.update_per_step, + update_step_num_gradient_steps_per_sample=args.update_per_step, stop_fn=stop_fn, save_best_fn=save_best_fn, logger=logger, diff --git a/examples/discrete/discrete_dqn.py b/examples/discrete/discrete_dqn.py index 38f42007f..a04a29774 100644 --- a/examples/discrete/discrete_dqn.py +++ b/examples/discrete/discrete_dqn.py @@ -71,12 +71,12 @@ def stop_fn(mean_rewards: float) -> bool: OffPolicyTrainerParams( train_collector=train_collector, test_collector=test_collector, - max_epoch=epoch, - step_per_epoch=step_per_epoch, - step_per_collect=step_per_collect, - episode_per_test=test_num, + max_epochs=epoch, + epoch_num_steps=step_per_epoch, + 
collection_step_num_env_steps=step_per_collect, + test_step_num_episodes=test_num, batch_size=batch_size, - update_per_step=1 / step_per_collect, + update_step_num_gradient_steps_per_sample=1 / step_per_collect, stop_fn=stop_fn, logger=logger, test_in_train=True, diff --git a/examples/discrete/discrete_dqn_hl.py b/examples/discrete/discrete_dqn_hl.py index a7c05a6f6..59ead1e26 100644 --- a/examples/discrete/discrete_dqn_hl.py +++ b/examples/discrete/discrete_dqn_hl.py @@ -28,14 +28,14 @@ def main() -> None: watch_num_episodes=100, ), OffPolicyTrainingConfig( - num_epochs=10, - step_per_epoch=10000, + max_epochs=10, + epoch_num_steps=10000, batch_size=64, num_train_envs=10, num_test_envs=100, buffer_size=20000, step_per_collect=10, - update_per_step=1 / 10, + update_step_num_gradient_steps_per_sample=1 / 10, ), ) .with_dqn_params( diff --git a/examples/inverse/irl_gail.py b/examples/inverse/irl_gail.py index 02040e221..3e7f54ffe 100644 --- a/examples/inverse/irl_gail.py +++ b/examples/inverse/irl_gail.py @@ -172,9 +172,9 @@ def test_gail(args: argparse.Namespace = get_args()) -> None: if args.lr_decay: optim.with_lr_scheduler_factory( LRSchedulerFactoryLinear( - num_epochs=args.epoch, - step_per_epoch=args.step_per_epoch, - step_per_collect=args.step_per_collect, + max_epochs=args.epoch, + epoch_num_steps=args.step_per_epoch, + collection_step_num_env_steps=args.step_per_collect, ) ) @@ -262,12 +262,12 @@ def save_best_fn(policy: Algorithm) -> None: OnPolicyTrainerParams( train_collector=train_collector, test_collector=test_collector, - max_epoch=args.epoch, - step_per_epoch=args.step_per_epoch, - repeat_per_collect=args.repeat_per_collect, - episode_per_test=args.test_num, + max_epochs=args.epoch, + epoch_num_steps=args.step_per_epoch, + update_step_num_repetitions=args.repeat_per_collect, + test_step_num_episodes=args.test_num, batch_size=args.batch_size, - step_per_collect=args.step_per_collect, + collection_step_num_env_steps=args.step_per_collect, 
save_best_fn=save_best_fn, logger=logger, test_in_train=False, diff --git a/examples/mujoco/fetch_her_ddpg.py b/examples/mujoco/fetch_her_ddpg.py index 9a3812e96..45c1524ca 100644 --- a/examples/mujoco/fetch_her_ddpg.py +++ b/examples/mujoco/fetch_her_ddpg.py @@ -230,14 +230,14 @@ def save_best_fn(policy: Algorithm) -> None: OffPolicyTrainerParams( train_collector=train_collector, test_collector=test_collector, - max_epoch=args.epoch, - step_per_epoch=args.step_per_epoch, - step_per_collect=args.step_per_collect, - episode_per_test=args.test_num, + max_epochs=args.epoch, + epoch_num_steps=args.step_per_epoch, + collection_step_num_env_steps=args.step_per_collect, + test_step_num_episodes=args.test_num, batch_size=args.batch_size, save_best_fn=save_best_fn, logger=logger, - update_per_step=args.update_per_step, + update_step_num_gradient_steps_per_sample=args.update_per_step, test_in_train=False, ) ) diff --git a/examples/mujoco/mujoco_a2c.py b/examples/mujoco/mujoco_a2c.py index a7820f489..8e0911b94 100755 --- a/examples/mujoco/mujoco_a2c.py +++ b/examples/mujoco/mujoco_a2c.py @@ -130,9 +130,9 @@ def main(args: argparse.Namespace = get_args()) -> None: if args.lr_decay: optim.with_lr_scheduler_factory( LRSchedulerFactoryLinear( - num_epochs=args.epoch, - step_per_epoch=args.step_per_epoch, - step_per_collect=args.step_per_collect, + max_epochs=args.epoch, + epoch_num_steps=args.step_per_epoch, + collection_step_num_env_steps=args.step_per_collect, ) ) @@ -207,12 +207,12 @@ def save_best_fn(policy: Algorithm) -> None: OnPolicyTrainerParams( train_collector=train_collector, test_collector=test_collector, - max_epoch=args.epoch, - step_per_epoch=args.step_per_epoch, - repeat_per_collect=args.repeat_per_collect, - episode_per_test=args.test_num, + max_epochs=args.epoch, + epoch_num_steps=args.step_per_epoch, + update_step_num_repetitions=args.repeat_per_collect, + test_step_num_episodes=args.test_num, batch_size=args.batch_size, - 
step_per_collect=args.step_per_collect, + collection_step_num_env_steps=args.step_per_collect, save_best_fn=save_best_fn, logger=logger, test_in_train=False, diff --git a/examples/mujoco/mujoco_a2c_hl.py b/examples/mujoco/mujoco_a2c_hl.py index 224d55fcb..1a85c52fc 100644 --- a/examples/mujoco/mujoco_a2c_hl.py +++ b/examples/mujoco/mujoco_a2c_hl.py @@ -44,14 +44,14 @@ def main( log_name = os.path.join(task, "a2c", str(experiment_config.seed), datetime_tag()) training_config = OnPolicyTrainingConfig( - num_epochs=epoch, - step_per_epoch=step_per_epoch, + max_epochs=epoch, + epoch_num_steps=step_per_epoch, batch_size=batch_size, num_train_envs=training_num, num_test_envs=test_num, buffer_size=buffer_size, step_per_collect=step_per_collect, - repeat_per_collect=repeat_per_collect, + update_step_num_repetitions=repeat_per_collect, ) env_factory = MujocoEnvFactory( diff --git a/examples/mujoco/mujoco_ddpg.py b/examples/mujoco/mujoco_ddpg.py index 066824d7b..908ae3cf3 100755 --- a/examples/mujoco/mujoco_ddpg.py +++ b/examples/mujoco/mujoco_ddpg.py @@ -161,14 +161,14 @@ def save_best_fn(policy: Algorithm) -> None: OffPolicyTrainerParams( train_collector=train_collector, test_collector=test_collector, - max_epoch=args.epoch, - step_per_epoch=args.step_per_epoch, - step_per_collect=args.step_per_collect, - episode_per_test=args.test_num, + max_epochs=args.epoch, + epoch_num_steps=args.step_per_epoch, + collection_step_num_env_steps=args.step_per_collect, + test_step_num_episodes=args.test_num, batch_size=args.batch_size, save_best_fn=save_best_fn, logger=logger, - update_per_step=args.update_per_step, + update_step_num_gradient_steps_per_sample=args.update_per_step, test_in_train=False, ) ) diff --git a/examples/mujoco/mujoco_ddpg_hl.py b/examples/mujoco/mujoco_ddpg_hl.py index 06c0588a6..7795347f8 100644 --- a/examples/mujoco/mujoco_ddpg_hl.py +++ b/examples/mujoco/mujoco_ddpg_hl.py @@ -39,14 +39,14 @@ def main( log_name = os.path.join(task, "ddpg", 
str(experiment_config.seed), datetime_tag()) training_config = OffPolicyTrainingConfig( - num_epochs=epoch, - step_per_epoch=step_per_epoch, + max_epochs=epoch, + epoch_num_steps=step_per_epoch, batch_size=batch_size, num_train_envs=training_num, num_test_envs=test_num, buffer_size=buffer_size, step_per_collect=step_per_collect, - update_per_step=update_per_step, + update_step_num_gradient_steps_per_sample=update_per_step, start_timesteps=start_timesteps, start_timesteps_random=True, ) diff --git a/examples/mujoco/mujoco_npg.py b/examples/mujoco/mujoco_npg.py index 5a5f08cc6..fe6a1357c 100755 --- a/examples/mujoco/mujoco_npg.py +++ b/examples/mujoco/mujoco_npg.py @@ -128,9 +128,9 @@ def main(args: argparse.Namespace = get_args()) -> None: if args.lr_decay: optim.with_lr_scheduler_factory( LRSchedulerFactoryLinear( - num_epochs=args.epoch, - step_per_epoch=args.step_per_epoch, - step_per_collect=args.step_per_collect, + max_epochs=args.epoch, + epoch_num_steps=args.step_per_epoch, + collection_step_num_env_steps=args.step_per_collect, ) ) @@ -205,12 +205,12 @@ def save_best_fn(policy: Algorithm) -> None: OnPolicyTrainerParams( train_collector=train_collector, test_collector=test_collector, - max_epoch=args.epoch, - step_per_epoch=args.step_per_epoch, - repeat_per_collect=args.repeat_per_collect, - episode_per_test=args.test_num, + max_epochs=args.epoch, + epoch_num_steps=args.step_per_epoch, + update_step_num_repetitions=args.repeat_per_collect, + test_step_num_episodes=args.test_num, batch_size=args.batch_size, - step_per_collect=args.step_per_collect, + collection_step_num_env_steps=args.step_per_collect, save_best_fn=save_best_fn, logger=logger, test_in_train=False, diff --git a/examples/mujoco/mujoco_npg_hl.py b/examples/mujoco/mujoco_npg_hl.py index 2a6e372b4..7099fcbe0 100644 --- a/examples/mujoco/mujoco_npg_hl.py +++ b/examples/mujoco/mujoco_npg_hl.py @@ -43,14 +43,14 @@ def main( log_name = os.path.join(task, "npg", str(experiment_config.seed), 
datetime_tag()) training_config = OnPolicyTrainingConfig( - num_epochs=epoch, - step_per_epoch=step_per_epoch, + max_epochs=epoch, + epoch_num_steps=step_per_epoch, batch_size=batch_size, num_train_envs=training_num, num_test_envs=test_num, buffer_size=buffer_size, step_per_collect=step_per_collect, - repeat_per_collect=repeat_per_collect, + update_step_num_repetitions=repeat_per_collect, ) env_factory = MujocoEnvFactory( diff --git a/examples/mujoco/mujoco_ppo.py b/examples/mujoco/mujoco_ppo.py index 061036d14..ea4b39d96 100755 --- a/examples/mujoco/mujoco_ppo.py +++ b/examples/mujoco/mujoco_ppo.py @@ -131,9 +131,9 @@ def main(args: argparse.Namespace = get_args()) -> None: if args.lr_decay: optim.with_lr_scheduler_factory( LRSchedulerFactoryLinear( - num_epochs=args.epoch, - step_per_epoch=args.step_per_epoch, - step_per_collect=args.step_per_collect, + max_epochs=args.epoch, + epoch_num_steps=args.step_per_epoch, + collection_step_num_env_steps=args.step_per_collect, ) ) @@ -213,12 +213,12 @@ def save_best_fn(policy: Algorithm) -> None: OnPolicyTrainerParams( train_collector=train_collector, test_collector=test_collector, - max_epoch=args.epoch, - step_per_epoch=args.step_per_epoch, - repeat_per_collect=args.repeat_per_collect, - episode_per_test=args.test_num, + max_epochs=args.epoch, + epoch_num_steps=args.step_per_epoch, + update_step_num_repetitions=args.repeat_per_collect, + test_step_num_episodes=args.test_num, batch_size=args.batch_size, - step_per_collect=args.step_per_collect, + collection_step_num_env_steps=args.step_per_collect, save_best_fn=save_best_fn, logger=logger, test_in_train=False, diff --git a/examples/mujoco/mujoco_ppo_hl.py b/examples/mujoco/mujoco_ppo_hl.py index 73a2cb711..b146012bd 100644 --- a/examples/mujoco/mujoco_ppo_hl.py +++ b/examples/mujoco/mujoco_ppo_hl.py @@ -48,14 +48,14 @@ def main( log_name = os.path.join(task, "ppo", str(experiment_config.seed), datetime_tag()) training_config = OnPolicyTrainingConfig( - num_epochs=epoch, 
- step_per_epoch=step_per_epoch, + max_epochs=epoch, + epoch_num_steps=step_per_epoch, batch_size=batch_size, num_train_envs=training_num, num_test_envs=test_num, buffer_size=buffer_size, step_per_collect=step_per_collect, - repeat_per_collect=repeat_per_collect, + update_step_num_repetitions=repeat_per_collect, ) env_factory = MujocoEnvFactory( diff --git a/examples/mujoco/mujoco_ppo_hl_multi.py b/examples/mujoco/mujoco_ppo_hl_multi.py index 1f2b0aae1..da0c050a0 100644 --- a/examples/mujoco/mujoco_ppo_hl_multi.py +++ b/examples/mujoco/mujoco_ppo_hl_multi.py @@ -59,15 +59,15 @@ def main( experiment_config = ExperimentConfig(persistence_base_dir=persistence_dir, watch=False) training_config = OnPolicyTrainingConfig( - num_epochs=1, - step_per_epoch=5000, + max_epochs=1, + epoch_num_steps=5000, batch_size=64, num_train_envs=5, num_test_envs=5, - num_test_episodes=5, + test_step_num_episodes=5, buffer_size=4096, step_per_collect=2048, - repeat_per_collect=1, + update_step_num_repetitions=1, ) env_factory = MujocoEnvFactory( diff --git a/examples/mujoco/mujoco_redq.py b/examples/mujoco/mujoco_redq.py index 9d905081e..df22a2a03 100755 --- a/examples/mujoco/mujoco_redq.py +++ b/examples/mujoco/mujoco_redq.py @@ -186,14 +186,14 @@ def save_best_fn(policy: Algorithm) -> None: OffPolicyTrainerParams( train_collector=train_collector, test_collector=test_collector, - max_epoch=args.epoch, - step_per_epoch=args.step_per_epoch, - step_per_collect=args.step_per_collect, - episode_per_test=args.test_num, + max_epochs=args.epoch, + epoch_num_steps=args.step_per_epoch, + collection_step_num_env_steps=args.step_per_collect, + test_step_num_episodes=args.test_num, batch_size=args.batch_size, save_best_fn=save_best_fn, logger=logger, - update_per_step=args.update_per_step, + update_step_num_gradient_steps_per_sample=args.update_per_step, test_in_train=False, ) ) diff --git a/examples/mujoco/mujoco_redq_hl.py b/examples/mujoco/mujoco_redq_hl.py index c35e0abb0..f0c6e2c04 100644 --- 
a/examples/mujoco/mujoco_redq_hl.py +++ b/examples/mujoco/mujoco_redq_hl.py @@ -45,14 +45,14 @@ def main( log_name = os.path.join(task, "redq", str(experiment_config.seed), datetime_tag()) training_config = OffPolicyTrainingConfig( - num_epochs=epoch, - step_per_epoch=step_per_epoch, + max_epochs=epoch, + epoch_num_steps=step_per_epoch, batch_size=batch_size, num_train_envs=training_num, num_test_envs=test_num, buffer_size=buffer_size, step_per_collect=step_per_collect, - update_per_step=update_per_step, + update_step_num_gradient_steps_per_sample=update_per_step, start_timesteps=start_timesteps, start_timesteps_random=True, ) diff --git a/examples/mujoco/mujoco_reinforce.py b/examples/mujoco/mujoco_reinforce.py index 2808dd9f1..53695a389 100755 --- a/examples/mujoco/mujoco_reinforce.py +++ b/examples/mujoco/mujoco_reinforce.py @@ -114,9 +114,9 @@ def main(args: argparse.Namespace = get_args()) -> None: if args.lr_decay: optim.with_lr_scheduler_factory( LRSchedulerFactoryLinear( - num_epochs=args.epoch, - step_per_epoch=args.step_per_epoch, - step_per_collect=args.step_per_collect, + max_epochs=args.epoch, + epoch_num_steps=args.step_per_epoch, + collection_step_num_env_steps=args.step_per_collect, ) ) @@ -186,12 +186,12 @@ def save_best_fn(policy: Algorithm) -> None: OnPolicyTrainerParams( train_collector=train_collector, test_collector=test_collector, - max_epoch=args.epoch, - step_per_epoch=args.step_per_epoch, - repeat_per_collect=args.repeat_per_collect, - episode_per_test=args.test_num, + max_epochs=args.epoch, + epoch_num_steps=args.step_per_epoch, + update_step_num_repetitions=args.repeat_per_collect, + test_step_num_episodes=args.test_num, batch_size=args.batch_size, - step_per_collect=args.step_per_collect, + collection_step_num_env_steps=args.step_per_collect, save_best_fn=save_best_fn, logger=logger, test_in_train=False, diff --git a/examples/mujoco/mujoco_reinforce_hl.py b/examples/mujoco/mujoco_reinforce_hl.py index 156c05fff..21508e2e1 100644 --- 
a/examples/mujoco/mujoco_reinforce_hl.py +++ b/examples/mujoco/mujoco_reinforce_hl.py @@ -39,14 +39,14 @@ def main( log_name = os.path.join(task, "reinforce", str(experiment_config.seed), datetime_tag()) training_config = OnPolicyTrainingConfig( - num_epochs=epoch, - step_per_epoch=step_per_epoch, + max_epochs=epoch, + epoch_num_steps=step_per_epoch, batch_size=batch_size, num_train_envs=training_num, num_test_envs=test_num, buffer_size=buffer_size, step_per_collect=step_per_collect, - repeat_per_collect=repeat_per_collect, + update_step_num_repetitions=repeat_per_collect, ) env_factory = MujocoEnvFactory( diff --git a/examples/mujoco/mujoco_sac.py b/examples/mujoco/mujoco_sac.py index 64fe6a60b..2f1bc6e93 100755 --- a/examples/mujoco/mujoco_sac.py +++ b/examples/mujoco/mujoco_sac.py @@ -179,14 +179,14 @@ def save_best_fn(policy: Algorithm) -> None: OffPolicyTrainerParams( train_collector=train_collector, test_collector=test_collector, - max_epoch=args.epoch, - step_per_epoch=args.step_per_epoch, - step_per_collect=args.step_per_collect, - episode_per_test=args.test_num, + max_epochs=args.epoch, + epoch_num_steps=args.step_per_epoch, + collection_step_num_env_steps=args.step_per_collect, + test_step_num_episodes=args.test_num, batch_size=args.batch_size, save_best_fn=save_best_fn, logger=logger, - update_per_step=args.update_per_step, + update_step_num_gradient_steps_per_sample=args.update_per_step, test_in_train=False, ) ) diff --git a/examples/mujoco/mujoco_sac_hl.py b/examples/mujoco/mujoco_sac_hl.py index 6c3c8352d..3a7d0dd78 100644 --- a/examples/mujoco/mujoco_sac_hl.py +++ b/examples/mujoco/mujoco_sac_hl.py @@ -41,14 +41,14 @@ def main( log_name = os.path.join(task, "sac", str(experiment_config.seed), datetime_tag()) training_config = OffPolicyTrainingConfig( - num_epochs=epoch, - step_per_epoch=step_per_epoch, + max_epochs=epoch, + epoch_num_steps=step_per_epoch, num_train_envs=training_num, num_test_envs=test_num, buffer_size=buffer_size, 
batch_size=batch_size, step_per_collect=step_per_collect, - update_per_step=update_per_step, + update_step_num_gradient_steps_per_sample=update_per_step, start_timesteps=start_timesteps, start_timesteps_random=True, ) diff --git a/examples/mujoco/mujoco_td3.py b/examples/mujoco/mujoco_td3.py index 5e838a9ba..c0626a0d3 100755 --- a/examples/mujoco/mujoco_td3.py +++ b/examples/mujoco/mujoco_td3.py @@ -180,14 +180,14 @@ def save_best_fn(policy: Algorithm) -> None: OffPolicyTrainerParams( train_collector=train_collector, test_collector=test_collector, - max_epoch=args.epoch, - step_per_epoch=args.step_per_epoch, - step_per_collect=args.step_per_collect, - episode_per_test=args.test_num, + max_epochs=args.epoch, + epoch_num_steps=args.step_per_epoch, + collection_step_num_env_steps=args.step_per_collect, + test_step_num_episodes=args.test_num, batch_size=args.batch_size, save_best_fn=save_best_fn, logger=logger, - update_per_step=args.update_per_step, + update_step_num_gradient_steps_per_sample=args.update_per_step, test_in_train=False, ) ) diff --git a/examples/mujoco/mujoco_td3_hl.py b/examples/mujoco/mujoco_td3_hl.py index 2fd3e731c..8fcbe8168 100644 --- a/examples/mujoco/mujoco_td3_hl.py +++ b/examples/mujoco/mujoco_td3_hl.py @@ -46,13 +46,13 @@ def main( log_name = os.path.join(task, "td3", str(experiment_config.seed), datetime_tag()) training_config = TrainingConfig( - num_epochs=epoch, - step_per_epoch=step_per_epoch, + max_epochs=epoch, + epoch_num_steps=step_per_epoch, num_train_envs=training_num, num_test_envs=test_num, buffer_size=buffer_size, batch_size=batch_size, - step_per_collect=step_per_collect, + collection_step_num_env_steps=step_per_collect, update_per_step=update_per_step, start_timesteps=start_timesteps, start_timesteps_random=True, diff --git a/examples/mujoco/mujoco_trpo.py b/examples/mujoco/mujoco_trpo.py index 77863cc7a..0d81046bf 100755 --- a/examples/mujoco/mujoco_trpo.py +++ b/examples/mujoco/mujoco_trpo.py @@ -131,9 +131,9 @@ def 
test_trpo(args: argparse.Namespace = get_args()) -> None: if args.lr_decay: optim.with_lr_scheduler_factory( LRSchedulerFactoryLinear( - num_epochs=args.epoch, - step_per_epoch=args.step_per_epoch, - step_per_collect=args.step_per_collect, + max_epochs=args.epoch, + epoch_num_steps=args.step_per_epoch, + collection_step_num_env_steps=args.step_per_collect, ) ) @@ -210,12 +210,12 @@ def save_best_fn(policy: Algorithm) -> None: OnPolicyTrainerParams( train_collector=train_collector, test_collector=test_collector, - max_epoch=args.epoch, - step_per_epoch=args.step_per_epoch, - repeat_per_collect=args.repeat_per_collect, - episode_per_test=args.test_num, + max_epochs=args.epoch, + epoch_num_steps=args.step_per_epoch, + update_step_num_repetitions=args.repeat_per_collect, + test_step_num_episodes=args.test_num, batch_size=args.batch_size, - step_per_collect=args.step_per_collect, + collection_step_num_env_steps=args.step_per_collect, save_best_fn=save_best_fn, logger=logger, test_in_train=False, diff --git a/examples/mujoco/mujoco_trpo_hl.py b/examples/mujoco/mujoco_trpo_hl.py index 0399fea8d..e1a9bb4cd 100644 --- a/examples/mujoco/mujoco_trpo_hl.py +++ b/examples/mujoco/mujoco_trpo_hl.py @@ -45,14 +45,14 @@ def main( log_name = os.path.join(task, "trpo", str(experiment_config.seed), datetime_tag()) training_config = OnPolicyTrainingConfig( - num_epochs=epoch, - step_per_epoch=step_per_epoch, + max_epochs=epoch, + epoch_num_steps=step_per_epoch, batch_size=batch_size, num_train_envs=training_num, num_test_envs=test_num, buffer_size=buffer_size, step_per_collect=step_per_collect, - repeat_per_collect=repeat_per_collect, + update_step_num_repetitions=repeat_per_collect, ) env_factory = MujocoEnvFactory( diff --git a/examples/offline/atari_bcq.py b/examples/offline/atari_bcq.py index 9c8921071..a9f478dc3 100644 --- a/examples/offline/atari_bcq.py +++ b/examples/offline/atari_bcq.py @@ -202,9 +202,9 @@ def watch() -> None: OfflineTrainerParams( buffer=buffer, 
test_collector=test_collector, - max_epoch=args.epoch, - step_per_epoch=args.update_per_epoch, - episode_per_test=args.test_num, + max_epochs=args.epoch, + epoch_num_steps=args.update_per_epoch, + test_step_num_episodes=args.test_num, batch_size=args.batch_size, stop_fn=stop_fn, save_best_fn=save_best_fn, diff --git a/examples/offline/atari_cql.py b/examples/offline/atari_cql.py index e0cfbeee1..73fd01d58 100644 --- a/examples/offline/atari_cql.py +++ b/examples/offline/atari_cql.py @@ -188,9 +188,9 @@ def watch() -> None: OfflineTrainerParams( buffer=buffer, test_collector=test_collector, - max_epoch=args.epoch, - step_per_epoch=args.update_per_epoch, - episode_per_test=args.test_num, + max_epochs=args.epoch, + epoch_num_steps=args.update_per_epoch, + test_step_num_episodes=args.test_num, batch_size=args.batch_size, stop_fn=stop_fn, save_best_fn=save_best_fn, diff --git a/examples/offline/atari_crr.py b/examples/offline/atari_crr.py index 79e3b67cb..4c95e5c86 100644 --- a/examples/offline/atari_crr.py +++ b/examples/offline/atari_crr.py @@ -202,9 +202,9 @@ def watch() -> None: OfflineTrainerParams( buffer=buffer, test_collector=test_collector, - max_epoch=args.epoch, - step_per_epoch=args.update_per_epoch, - episode_per_test=args.test_num, + max_epochs=args.epoch, + epoch_num_steps=args.update_per_epoch, + test_step_num_episodes=args.test_num, batch_size=args.batch_size, stop_fn=stop_fn, save_best_fn=save_best_fn, diff --git a/examples/offline/atari_il.py b/examples/offline/atari_il.py index 10f9d4159..5481b61f3 100644 --- a/examples/offline/atari_il.py +++ b/examples/offline/atari_il.py @@ -167,9 +167,9 @@ def watch() -> None: OfflineTrainerParams( buffer=buffer, test_collector=test_collector, - max_epoch=args.epoch, - step_per_epoch=args.update_per_epoch, - episode_per_test=args.test_num, + max_epochs=args.epoch, + epoch_num_steps=args.update_per_epoch, + test_step_num_episodes=args.test_num, batch_size=args.batch_size, stop_fn=stop_fn, 
save_best_fn=save_best_fn, diff --git a/examples/offline/d4rl_bcq.py b/examples/offline/d4rl_bcq.py index 997c02098..7d96a3ab1 100644 --- a/examples/offline/d4rl_bcq.py +++ b/examples/offline/d4rl_bcq.py @@ -214,9 +214,9 @@ def watch() -> None: OfflineTrainerParams( buffer=replay_buffer, test_collector=test_collector, - max_epoch=args.epoch, - step_per_epoch=args.step_per_epoch, - episode_per_test=args.test_num, + max_epochs=args.epoch, + epoch_num_steps=args.step_per_epoch, + test_step_num_episodes=args.test_num, batch_size=args.batch_size, save_best_fn=save_best_fn, logger=logger, diff --git a/examples/offline/d4rl_cql.py b/examples/offline/d4rl_cql.py index c27b1ffed..43996909c 100644 --- a/examples/offline/d4rl_cql.py +++ b/examples/offline/d4rl_cql.py @@ -354,9 +354,9 @@ def watch() -> None: OfflineTrainerParams( buffer=replay_buffer, test_collector=test_collector, - max_epoch=args.epoch, - step_per_epoch=args.step_per_epoch, - episode_per_test=args.test_num, + max_epochs=args.epoch, + epoch_num_steps=args.step_per_epoch, + test_step_num_episodes=args.test_num, batch_size=args.batch_size, save_best_fn=save_best_fn, logger=logger, diff --git a/examples/offline/d4rl_il.py b/examples/offline/d4rl_il.py index 0bb5cb3a4..23089462c 100644 --- a/examples/offline/d4rl_il.py +++ b/examples/offline/d4rl_il.py @@ -157,9 +157,9 @@ def watch() -> None: OfflineTrainerParams( buffer=replay_buffer, test_collector=test_collector, - max_epoch=args.epoch, - step_per_epoch=args.step_per_epoch, - episode_per_test=args.test_num, + max_epochs=args.epoch, + epoch_num_steps=args.step_per_epoch, + test_step_num_episodes=args.test_num, batch_size=args.batch_size, save_best_fn=save_best_fn, logger=logger, diff --git a/examples/offline/d4rl_td3_bc.py b/examples/offline/d4rl_td3_bc.py index 3242b780b..261f96458 100644 --- a/examples/offline/d4rl_td3_bc.py +++ b/examples/offline/d4rl_td3_bc.py @@ -205,9 +205,9 @@ def watch() -> None: OfflineTrainerParams( buffer=replay_buffer, 
test_collector=test_collector, - max_epoch=args.epoch, - step_per_epoch=args.step_per_epoch, - episode_per_test=args.test_num, + max_epochs=args.epoch, + epoch_num_steps=args.step_per_epoch, + test_step_num_episodes=args.test_num, batch_size=args.batch_size, save_best_fn=save_best_fn, logger=logger, diff --git a/examples/vizdoom/vizdoom_c51.py b/examples/vizdoom/vizdoom_c51.py index 185739c67..676b467e0 100644 --- a/examples/vizdoom/vizdoom_c51.py +++ b/examples/vizdoom/vizdoom_c51.py @@ -208,16 +208,16 @@ def watch() -> None: OffPolicyTrainerParams( train_collector=train_collector, test_collector=test_collector, - max_epoch=args.epoch, - step_per_epoch=args.step_per_epoch, - step_per_collect=args.step_per_collect, - episode_per_test=args.test_num, + max_epochs=args.epoch, + epoch_num_steps=args.step_per_epoch, + collection_step_num_env_steps=args.step_per_collect, + test_step_num_episodes=args.test_num, batch_size=args.batch_size, train_fn=train_fn, stop_fn=stop_fn, save_best_fn=save_best_fn, logger=logger, - update_per_step=args.update_per_step, + update_step_num_gradient_steps_per_sample=args.update_per_step, test_in_train=False, ) ) diff --git a/examples/vizdoom/vizdoom_ppo.py b/examples/vizdoom/vizdoom_ppo.py index 28bfaa979..26ac04a9f 100644 --- a/examples/vizdoom/vizdoom_ppo.py +++ b/examples/vizdoom/vizdoom_ppo.py @@ -139,9 +139,9 @@ def test_ppo(args: argparse.Namespace = get_args()) -> None: if args.lr_decay: optim.with_lr_scheduler_factory( LRSchedulerFactoryLinear( - num_epochs=args.epoch, - step_per_epoch=args.step_per_epoch, - step_per_collect=args.step_per_collect, + max_epochs=args.epoch, + epoch_num_steps=args.step_per_epoch, + collection_step_num_env_steps=args.step_per_collect, ) ) @@ -283,12 +283,12 @@ def watch() -> None: OnPolicyTrainerParams( train_collector=train_collector, test_collector=test_collector, - max_epoch=args.epoch, - step_per_epoch=args.step_per_epoch, - repeat_per_collect=args.repeat_per_collect, - 
episode_per_test=args.test_num, + max_epochs=args.epoch, + epoch_num_steps=args.step_per_epoch, + update_step_num_repetitions=args.repeat_per_collect, + test_step_num_episodes=args.test_num, batch_size=args.batch_size, - step_per_collect=args.step_per_collect, + collection_step_num_env_steps=args.step_per_collect, stop_fn=stop_fn, save_best_fn=save_best_fn, logger=logger, diff --git a/test/continuous/test_ddpg.py b/test/continuous/test_ddpg.py index b6ec0425c..12918f77c 100644 --- a/test/continuous/test_ddpg.py +++ b/test/continuous/test_ddpg.py @@ -128,12 +128,12 @@ def stop_fn(mean_rewards: float) -> bool: OffPolicyTrainerParams( train_collector=train_collector, test_collector=test_collector, - max_epoch=args.epoch, - step_per_epoch=args.step_per_epoch, - step_per_collect=args.step_per_collect, - episode_per_test=args.test_num, + max_epochs=args.epoch, + epoch_num_steps=args.step_per_epoch, + collection_step_num_env_steps=args.step_per_collect, + test_step_num_episodes=args.test_num, batch_size=args.batch_size, - update_per_step=args.update_per_step, + update_step_num_gradient_steps_per_sample=args.update_per_step, stop_fn=stop_fn, save_best_fn=save_best_fn, logger=logger, diff --git a/test/continuous/test_npg.py b/test/continuous/test_npg.py index 8d2578707..938f5fd7a 100644 --- a/test/continuous/test_npg.py +++ b/test/continuous/test_npg.py @@ -146,12 +146,12 @@ def stop_fn(mean_rewards: float) -> bool: OnPolicyTrainerParams( train_collector=train_collector, test_collector=test_collector, - max_epoch=args.epoch, - step_per_epoch=args.step_per_epoch, - repeat_per_collect=args.repeat_per_collect, - episode_per_test=args.test_num, + max_epochs=args.epoch, + epoch_num_steps=args.step_per_epoch, + update_step_num_repetitions=args.repeat_per_collect, + test_step_num_episodes=args.test_num, batch_size=args.batch_size, - step_per_collect=args.step_per_collect, + collection_step_num_env_steps=args.step_per_collect, stop_fn=stop_fn, save_best_fn=save_best_fn, 
logger=logger, diff --git a/test/continuous/test_ppo.py b/test/continuous/test_ppo.py index 762615962..d87c2585a 100644 --- a/test/continuous/test_ppo.py +++ b/test/continuous/test_ppo.py @@ -171,13 +171,13 @@ def save_checkpoint_fn(epoch: int, env_step: int, gradient_step: int) -> str: OnPolicyTrainerParams( train_collector=train_collector, test_collector=test_collector, - max_epoch=args.epoch, - step_per_epoch=args.step_per_epoch, - repeat_per_collect=args.repeat_per_collect, - episode_per_test=args.test_num, + max_epochs=args.epoch, + epoch_num_steps=args.step_per_epoch, + update_step_num_repetitions=args.repeat_per_collect, + test_step_num_episodes=args.test_num, batch_size=args.batch_size, - episode_per_collect=args.episode_per_collect, - step_per_collect=None, + collection_step_num_episodes=args.episode_per_collect, + collection_step_num_env_steps=None, stop_fn=stop_fn, save_best_fn=save_best_fn, logger=logger, diff --git a/test/continuous/test_redq.py b/test/continuous/test_redq.py index f8b2b4407..761bdaa9c 100644 --- a/test/continuous/test_redq.py +++ b/test/continuous/test_redq.py @@ -157,12 +157,12 @@ def stop_fn(mean_rewards: float) -> bool: OffPolicyTrainerParams( train_collector=train_collector, test_collector=test_collector, - max_epoch=args.epoch, - step_per_epoch=args.step_per_epoch, - step_per_collect=args.step_per_collect, - episode_per_test=args.test_num, + max_epochs=args.epoch, + epoch_num_steps=args.step_per_epoch, + collection_step_num_env_steps=args.step_per_collect, + test_step_num_episodes=args.test_num, batch_size=args.batch_size, - update_per_step=args.update_per_step, + update_step_num_gradient_steps_per_sample=args.update_per_step, stop_fn=stop_fn, save_best_fn=save_best_fn, logger=logger, diff --git a/test/continuous/test_sac_with_il.py b/test/continuous/test_sac_with_il.py index b9df78c5a..cf4a04593 100644 --- a/test/continuous/test_sac_with_il.py +++ b/test/continuous/test_sac_with_il.py @@ -162,12 +162,12 @@ def 
stop_fn(mean_rewards: float) -> bool: OffPolicyTrainerParams( train_collector=train_collector, test_collector=test_collector, - max_epoch=args.epoch, - step_per_epoch=args.step_per_epoch, - step_per_collect=args.step_per_collect, - episode_per_test=args.test_num, + max_epochs=args.epoch, + epoch_num_steps=args.step_per_epoch, + collection_step_num_env_steps=args.step_per_collect, + test_step_num_episodes=args.test_num, batch_size=args.batch_size, - update_per_step=args.update_per_step, + update_step_num_gradient_steps_per_sample=args.update_per_step, stop_fn=stop_fn, save_best_fn=save_best_fn, logger=logger, @@ -216,10 +216,10 @@ def stop_fn(mean_rewards: float) -> bool: OffPolicyTrainerParams( train_collector=train_collector, test_collector=il_test_collector, - max_epoch=args.epoch, - step_per_epoch=args.il_step_per_epoch, - step_per_collect=args.step_per_collect, - episode_per_test=args.test_num, + max_epochs=args.epoch, + epoch_num_steps=args.il_step_per_epoch, + collection_step_num_env_steps=args.step_per_collect, + test_step_num_episodes=args.test_num, batch_size=args.batch_size, stop_fn=stop_fn, save_best_fn=save_best_fn, diff --git a/test/continuous/test_td3.py b/test/continuous/test_td3.py index 8aaf83797..f9e57d7af 100644 --- a/test/continuous/test_td3.py +++ b/test/continuous/test_td3.py @@ -144,12 +144,12 @@ def stop_fn(mean_rewards: float) -> bool: OffPolicyTrainerParams( train_collector=train_collector, test_collector=test_collector, - max_epoch=args.epoch, - step_per_epoch=args.step_per_epoch, - step_per_collect=args.step_per_collect, - episode_per_test=args.test_num, + max_epochs=args.epoch, + epoch_num_steps=args.step_per_epoch, + collection_step_num_env_steps=args.step_per_collect, + test_step_num_episodes=args.test_num, batch_size=args.batch_size, - update_per_step=args.update_per_step, + update_step_num_gradient_steps_per_sample=args.update_per_step, stop_fn=stop_fn, save_best_fn=save_best_fn, logger=logger, diff --git 
a/test/continuous/test_trpo.py b/test/continuous/test_trpo.py index e29108f26..741cf325f 100644 --- a/test/continuous/test_trpo.py +++ b/test/continuous/test_trpo.py @@ -148,12 +148,12 @@ def stop_fn(mean_rewards: float) -> bool: OnPolicyTrainerParams( train_collector=train_collector, test_collector=test_collector, - max_epoch=args.epoch, - step_per_epoch=args.step_per_epoch, - repeat_per_collect=args.repeat_per_collect, - episode_per_test=args.test_num, + max_epochs=args.epoch, + epoch_num_steps=args.step_per_epoch, + update_step_num_repetitions=args.repeat_per_collect, + test_step_num_episodes=args.test_num, batch_size=args.batch_size, - step_per_collect=args.step_per_collect, + collection_step_num_env_steps=args.step_per_collect, stop_fn=stop_fn, save_best_fn=save_best_fn, logger=logger, diff --git a/test/discrete/test_a2c_with_il.py b/test/discrete/test_a2c_with_il.py index 5801654f3..7f4515c75 100644 --- a/test/discrete/test_a2c_with_il.py +++ b/test/discrete/test_a2c_with_il.py @@ -143,13 +143,13 @@ def stop_fn(mean_rewards: float) -> bool: OnPolicyTrainerParams( train_collector=train_collector, test_collector=test_collector, - max_epoch=args.epoch, - step_per_epoch=args.step_per_epoch, - repeat_per_collect=args.repeat_per_collect, - episode_per_test=args.test_num, + max_epochs=args.epoch, + epoch_num_steps=args.step_per_epoch, + update_step_num_repetitions=args.repeat_per_collect, + test_step_num_episodes=args.test_num, batch_size=args.batch_size, - episode_per_collect=args.episode_per_collect, - step_per_collect=None, + collection_step_num_episodes=args.episode_per_collect, + collection_step_num_env_steps=None, stop_fn=stop_fn, save_best_fn=save_best_fn, logger=logger, @@ -199,10 +199,10 @@ def stop_fn(mean_rewards: float) -> bool: OffPolicyTrainerParams( train_collector=train_collector, test_collector=il_test_collector, - max_epoch=args.epoch, - step_per_epoch=args.il_step_per_epoch, - step_per_collect=args.step_per_collect, - 
episode_per_test=args.test_num, + max_epochs=args.epoch, + epoch_num_steps=args.il_step_per_epoch, + collection_step_num_env_steps=args.step_per_collect, + test_step_num_episodes=args.test_num, batch_size=args.batch_size, stop_fn=stop_fn, save_best_fn=save_best_fn, diff --git a/test/discrete/test_bdqn.py b/test/discrete/test_bdqn.py index 38ca7b64b..f3db73902 100644 --- a/test/discrete/test_bdqn.py +++ b/test/discrete/test_bdqn.py @@ -139,12 +139,12 @@ def stop_fn(mean_rewards: float) -> bool: OffPolicyTrainerParams( train_collector=train_collector, test_collector=test_collector, - max_epoch=args.epoch, - step_per_epoch=args.step_per_epoch, - step_per_collect=args.step_per_collect, - episode_per_test=args.test_num, + max_epochs=args.epoch, + epoch_num_steps=args.step_per_epoch, + collection_step_num_env_steps=args.step_per_collect, + test_step_num_episodes=args.test_num, batch_size=args.batch_size, - update_per_step=args.update_per_step, + update_step_num_gradient_steps_per_sample=args.update_per_step, train_fn=train_fn, stop_fn=stop_fn, test_in_train=True, diff --git a/test/discrete/test_c51.py b/test/discrete/test_c51.py index 1eef06ea2..5b9183aba 100644 --- a/test/discrete/test_c51.py +++ b/test/discrete/test_c51.py @@ -192,12 +192,12 @@ def save_checkpoint_fn(epoch: int, env_step: int, gradient_step: int) -> str: OffPolicyTrainerParams( train_collector=train_collector, test_collector=test_collector, - max_epoch=args.epoch, - step_per_epoch=args.step_per_epoch, - step_per_collect=args.step_per_collect, - episode_per_test=args.test_num, + max_epochs=args.epoch, + epoch_num_steps=args.step_per_epoch, + collection_step_num_env_steps=args.step_per_collect, + test_step_num_episodes=args.test_num, batch_size=args.batch_size, - update_per_step=args.update_per_step, + update_step_num_gradient_steps_per_sample=args.update_per_step, train_fn=train_fn, stop_fn=stop_fn, save_best_fn=save_best_fn, diff --git a/test/discrete/test_discrete_sac.py 
b/test/discrete/test_discrete_sac.py index 711068072..73c7c4afe 100644 --- a/test/discrete/test_discrete_sac.py +++ b/test/discrete/test_discrete_sac.py @@ -142,15 +142,15 @@ def stop_fn(mean_rewards: float) -> bool: OffPolicyTrainerParams( train_collector=train_collector, test_collector=test_collector, - max_epoch=args.epoch, - step_per_epoch=args.step_per_epoch, - step_per_collect=args.step_per_collect, - episode_per_test=args.test_num, + max_epochs=args.epoch, + epoch_num_steps=args.step_per_epoch, + collection_step_num_env_steps=args.step_per_collect, + test_step_num_episodes=args.test_num, batch_size=args.batch_size, stop_fn=stop_fn, save_best_fn=save_best_fn, logger=logger, - update_per_step=args.update_per_step, + update_step_num_gradient_steps_per_sample=args.update_per_step, test_in_train=False, ) ) diff --git a/test/discrete/test_dqn.py b/test/discrete/test_dqn.py index 475c57caf..a1b02f4a3 100644 --- a/test/discrete/test_dqn.py +++ b/test/discrete/test_dqn.py @@ -151,12 +151,12 @@ def train_fn(epoch: int, env_step: int) -> None: OffPolicyTrainerParams( train_collector=train_collector, test_collector=test_collector, - max_epoch=args.epoch, - step_per_epoch=args.step_per_epoch, - step_per_collect=args.step_per_collect, - episode_per_test=args.test_num, + max_epochs=args.epoch, + epoch_num_steps=args.step_per_epoch, + collection_step_num_env_steps=args.step_per_collect, + test_step_num_episodes=args.test_num, batch_size=args.batch_size, - update_per_step=args.update_per_step, + update_step_num_gradient_steps_per_sample=args.update_per_step, train_fn=train_fn, stop_fn=stop_fn, save_best_fn=save_best_fn, diff --git a/test/discrete/test_drqn.py b/test/discrete/test_drqn.py index 2ab6c6a66..0d289d476 100644 --- a/test/discrete/test_drqn.py +++ b/test/discrete/test_drqn.py @@ -120,12 +120,12 @@ def stop_fn(mean_rewards: float) -> bool: OffPolicyTrainerParams( train_collector=train_collector, test_collector=test_collector, - max_epoch=args.epoch, - 
step_per_epoch=args.step_per_epoch, - step_per_collect=args.step_per_collect, - episode_per_test=args.test_num, + max_epochs=args.epoch, + epoch_num_steps=args.step_per_epoch, + collection_step_num_env_steps=args.step_per_collect, + test_step_num_episodes=args.test_num, batch_size=args.batch_size, - update_per_step=args.update_per_step, + update_step_num_gradient_steps_per_sample=args.update_per_step, stop_fn=stop_fn, save_best_fn=save_best_fn, logger=logger, diff --git a/test/discrete/test_fqf.py b/test/discrete/test_fqf.py index fd27b28f1..21faadb77 100644 --- a/test/discrete/test_fqf.py +++ b/test/discrete/test_fqf.py @@ -166,16 +166,16 @@ def train_fn(epoch: int, env_step: int) -> None: OffPolicyTrainerParams( train_collector=train_collector, test_collector=test_collector, - max_epoch=args.epoch, - step_per_epoch=args.step_per_epoch, - step_per_collect=args.step_per_collect, - episode_per_test=args.test_num, + max_epochs=args.epoch, + epoch_num_steps=args.step_per_epoch, + collection_step_num_env_steps=args.step_per_collect, + test_step_num_episodes=args.test_num, batch_size=args.batch_size, train_fn=train_fn, stop_fn=stop_fn, save_best_fn=save_best_fn, logger=logger, - update_per_step=args.update_per_step, + update_step_num_gradient_steps_per_sample=args.update_per_step, test_in_train=True, ) ) diff --git a/test/discrete/test_iqn.py b/test/discrete/test_iqn.py index 7e0ac5d57..ded071df1 100644 --- a/test/discrete/test_iqn.py +++ b/test/discrete/test_iqn.py @@ -162,16 +162,16 @@ def train_fn(epoch: int, env_step: int) -> None: OffPolicyTrainerParams( train_collector=train_collector, test_collector=test_collector, - max_epoch=args.epoch, - step_per_epoch=args.step_per_epoch, - step_per_collect=args.step_per_collect, - episode_per_test=args.test_num, + max_epochs=args.epoch, + epoch_num_steps=args.step_per_epoch, + collection_step_num_env_steps=args.step_per_collect, + test_step_num_episodes=args.test_num, batch_size=args.batch_size, train_fn=train_fn, 
stop_fn=stop_fn, save_best_fn=save_best_fn, logger=logger, - update_per_step=args.update_per_step, + update_step_num_gradient_steps_per_sample=args.update_per_step, test_in_train=True, ) ) diff --git a/test/discrete/test_ppo_discrete.py b/test/discrete/test_ppo_discrete.py index 0de6299a4..51809a98a 100644 --- a/test/discrete/test_ppo_discrete.py +++ b/test/discrete/test_ppo_discrete.py @@ -144,12 +144,12 @@ def stop_fn(mean_rewards: float) -> bool: OnPolicyTrainerParams( train_collector=train_collector, test_collector=test_collector, - max_epoch=args.epoch, - step_per_epoch=args.step_per_epoch, - repeat_per_collect=args.repeat_per_collect, - episode_per_test=args.test_num, + max_epochs=args.epoch, + epoch_num_steps=args.step_per_epoch, + update_step_num_repetitions=args.repeat_per_collect, + test_step_num_episodes=args.test_num, batch_size=args.batch_size, - step_per_collect=args.step_per_collect, + collection_step_num_env_steps=args.step_per_collect, stop_fn=stop_fn, save_best_fn=save_best_fn, logger=logger, diff --git a/test/discrete/test_qrdqn.py b/test/discrete/test_qrdqn.py index ae08523f3..6c8d385a3 100644 --- a/test/discrete/test_qrdqn.py +++ b/test/discrete/test_qrdqn.py @@ -157,16 +157,16 @@ def train_fn(epoch: int, env_step: int) -> None: OffPolicyTrainerParams( train_collector=train_collector, test_collector=test_collector, - max_epoch=args.epoch, - step_per_epoch=args.step_per_epoch, - step_per_collect=args.step_per_collect, - episode_per_test=args.test_num, + max_epochs=args.epoch, + epoch_num_steps=args.step_per_epoch, + collection_step_num_env_steps=args.step_per_collect, + test_step_num_episodes=args.test_num, batch_size=args.batch_size, train_fn=train_fn, stop_fn=stop_fn, save_best_fn=save_best_fn, logger=logger, - update_per_step=args.update_per_step, + update_step_num_gradient_steps_per_sample=args.update_per_step, test_in_train=True, ) ) diff --git a/test/discrete/test_rainbow.py b/test/discrete/test_rainbow.py index 910c9f829..228fdb619 100644 
--- a/test/discrete/test_rainbow.py +++ b/test/discrete/test_rainbow.py @@ -211,12 +211,12 @@ def save_checkpoint_fn(epoch: int, env_step: int, gradient_step: int) -> str: OffPolicyTrainerParams( train_collector=train_collector, test_collector=test_collector, - max_epoch=args.epoch, - step_per_epoch=args.step_per_epoch, - step_per_collect=args.step_per_collect, - episode_per_test=args.test_num, + max_epochs=args.epoch, + epoch_num_steps=args.step_per_epoch, + collection_step_num_env_steps=args.step_per_collect, + test_step_num_episodes=args.test_num, batch_size=args.batch_size, - update_per_step=args.update_per_step, + update_step_num_gradient_steps_per_sample=args.update_per_step, train_fn=train_fn, stop_fn=stop_fn, save_best_fn=save_best_fn, diff --git a/test/discrete/test_reinforce.py b/test/discrete/test_reinforce.py index ab222ae26..9c17c9164 100644 --- a/test/discrete/test_reinforce.py +++ b/test/discrete/test_reinforce.py @@ -117,13 +117,13 @@ def stop_fn(mean_rewards: float) -> bool: training_config = OnPolicyTrainerParams( train_collector=train_collector, test_collector=test_collector, - max_epoch=args.epoch, - step_per_epoch=args.step_per_epoch, - repeat_per_collect=args.repeat_per_collect, - episode_per_test=args.test_num, + max_epochs=args.epoch, + epoch_num_steps=args.step_per_epoch, + update_step_num_repetitions=args.repeat_per_collect, + test_step_num_episodes=args.test_num, batch_size=args.batch_size, - episode_per_collect=args.episode_per_collect, - step_per_collect=None, + collection_step_num_episodes=args.episode_per_collect, + collection_step_num_env_steps=None, stop_fn=stop_fn, save_best_fn=save_best_fn, logger=logger, diff --git a/test/highlevel/test_experiment_builder.py b/test/highlevel/test_experiment_builder.py index 86c298abe..a0ec22f2a 100644 --- a/test/highlevel/test_experiment_builder.py +++ b/test/highlevel/test_experiment_builder.py @@ -33,17 +33,21 @@ def create_training_config( num_test_envs: int = 2, ) -> OffPolicyTrainingConfig | 
OnPolicyTrainingConfig: if issubclass(builder_cls, OffPolicyExperimentBuilder): - cfg_class = OffPolicyTrainingConfig + return OffPolicyTrainingConfig( + max_epochs=num_epochs, + epoch_num_steps=step_per_epoch, + num_train_envs=num_train_envs, + num_test_envs=num_test_envs, + ) elif issubclass(builder_cls, OnPolicyExperimentBuilder): - cfg_class = OnPolicyTrainingConfig + return OnPolicyTrainingConfig( + max_epochs=num_epochs, + epoch_num_steps=step_per_epoch, + num_train_envs=num_train_envs, + num_test_envs=num_test_envs, + ) else: raise ValueError - return cfg_class( - num_epochs=num_epochs, - step_per_epoch=step_per_epoch, - num_train_envs=num_train_envs, - num_test_envs=num_test_envs, - ) @pytest.mark.parametrize( diff --git a/test/modelbased/test_dqn_icm.py b/test/modelbased/test_dqn_icm.py index ada719005..90af3e10b 100644 --- a/test/modelbased/test_dqn_icm.py +++ b/test/modelbased/test_dqn_icm.py @@ -191,12 +191,12 @@ def train_fn(epoch: int, env_step: int) -> None: OffPolicyTrainerParams( train_collector=train_collector, test_collector=test_collector, - max_epoch=args.epoch, - step_per_epoch=args.step_per_epoch, - step_per_collect=args.step_per_collect, - episode_per_test=args.test_num, + max_epochs=args.epoch, + epoch_num_steps=args.step_per_epoch, + collection_step_num_env_steps=args.step_per_collect, + test_step_num_episodes=args.test_num, batch_size=args.batch_size, - update_per_step=args.update_per_step, + update_step_num_gradient_steps_per_sample=args.update_per_step, train_fn=train_fn, stop_fn=stop_fn, save_best_fn=save_best_fn, diff --git a/test/modelbased/test_ppo_icm.py b/test/modelbased/test_ppo_icm.py index acf9daa4c..1b86b5fbb 100644 --- a/test/modelbased/test_ppo_icm.py +++ b/test/modelbased/test_ppo_icm.py @@ -190,12 +190,12 @@ def stop_fn(mean_rewards: float) -> bool: OnPolicyTrainerParams( train_collector=train_collector, test_collector=test_collector, - max_epoch=args.epoch, - step_per_epoch=args.step_per_epoch, - 
repeat_per_collect=args.repeat_per_collect, - episode_per_test=args.test_num, + max_epochs=args.epoch, + epoch_num_steps=args.step_per_epoch, + update_step_num_repetitions=args.repeat_per_collect, + test_step_num_episodes=args.test_num, batch_size=args.batch_size, - step_per_collect=args.step_per_collect, + collection_step_num_env_steps=args.step_per_collect, stop_fn=stop_fn, save_best_fn=save_best_fn, logger=logger, diff --git a/test/modelbased/test_psrl.py b/test/modelbased/test_psrl.py index c9b0dffbc..1dc55d81a 100644 --- a/test/modelbased/test_psrl.py +++ b/test/modelbased/test_psrl.py @@ -116,13 +116,13 @@ def stop_fn(mean_rewards: float) -> bool: OnPolicyTrainerParams( train_collector=train_collector, test_collector=test_collector, - max_epoch=args.epoch, - step_per_epoch=args.step_per_epoch, - repeat_per_collect=1, - episode_per_test=args.test_num, + max_epochs=args.epoch, + epoch_num_steps=args.step_per_epoch, + update_step_num_repetitions=1, + test_step_num_episodes=args.test_num, batch_size=0, - episode_per_collect=args.episode_per_collect, - step_per_collect=None, + collection_step_num_episodes=args.episode_per_collect, + collection_step_num_env_steps=None, stop_fn=stop_fn, logger=logger, test_in_train=False, diff --git a/test/offline/gather_cartpole_data.py b/test/offline/gather_cartpole_data.py index 43928563a..c8ed0b770 100644 --- a/test/offline/gather_cartpole_data.py +++ b/test/offline/gather_cartpole_data.py @@ -154,16 +154,16 @@ def train_fn(epoch: int, env_step: int) -> None: OffPolicyTrainerParams( train_collector=train_collector, test_collector=test_collector, - max_epoch=args.epoch, - step_per_epoch=args.step_per_epoch, - step_per_collect=args.step_per_collect, - episode_per_test=args.test_num, + max_epochs=args.epoch, + epoch_num_steps=args.step_per_epoch, + collection_step_num_env_steps=args.step_per_collect, + test_step_num_episodes=args.test_num, batch_size=args.batch_size, train_fn=train_fn, stop_fn=stop_fn, save_best_fn=save_best_fn, 
logger=logger, - update_per_step=args.update_per_step, + update_step_num_gradient_steps_per_sample=args.update_per_step, test_in_train=True, ) ) diff --git a/test/offline/gather_pendulum_data.py b/test/offline/gather_pendulum_data.py index 1dc70486f..e8e9cb625 100644 --- a/test/offline/gather_pendulum_data.py +++ b/test/offline/gather_pendulum_data.py @@ -150,12 +150,12 @@ def stop_fn(mean_rewards: float) -> bool: OffPolicyTrainerParams( train_collector=train_collector, test_collector=test_collector, - max_epoch=args.epoch, - step_per_epoch=args.step_per_epoch, - step_per_collect=args.step_per_collect, - episode_per_test=args.test_num, + max_epochs=args.epoch, + epoch_num_steps=args.step_per_epoch, + collection_step_num_env_steps=args.step_per_collect, + test_step_num_episodes=args.test_num, batch_size=args.batch_size, - update_per_step=args.update_per_step, + update_step_num_gradient_steps_per_sample=args.update_per_step, save_best_fn=save_best_fn, stop_fn=stop_fn, logger=logger, diff --git a/test/offline/test_bcq.py b/test/offline/test_bcq.py index 3d02c1ee0..67215ec05 100644 --- a/test/offline/test_bcq.py +++ b/test/offline/test_bcq.py @@ -192,9 +192,9 @@ def watch() -> None: OfflineTrainerParams( buffer=buffer, test_collector=test_collector, - max_epoch=args.epoch, - step_per_epoch=args.step_per_epoch, - episode_per_test=args.test_num, + max_epochs=args.epoch, + epoch_num_steps=args.step_per_epoch, + test_step_num_episodes=args.test_num, batch_size=args.batch_size, save_best_fn=save_best_fn, stop_fn=stop_fn, diff --git a/test/offline/test_cql.py b/test/offline/test_cql.py index 69204a976..b0f0e603d 100644 --- a/test/offline/test_cql.py +++ b/test/offline/test_cql.py @@ -186,9 +186,9 @@ def stop_fn(mean_rewards: float) -> bool: OfflineTrainerParams( buffer=buffer, test_collector=test_collector, - max_epoch=args.epoch, - step_per_epoch=args.step_per_epoch, - episode_per_test=args.test_num, + max_epochs=args.epoch, + epoch_num_steps=args.step_per_epoch, + 
test_step_num_episodes=args.test_num, batch_size=args.batch_size, save_best_fn=save_best_fn, stop_fn=stop_fn, diff --git a/test/offline/test_discrete_bcq.py b/test/offline/test_discrete_bcq.py index a346be9b2..5641b2c2e 100644 --- a/test/offline/test_discrete_bcq.py +++ b/test/offline/test_discrete_bcq.py @@ -160,9 +160,9 @@ def save_checkpoint_fn(epoch: int, env_step: int, gradient_step: int) -> str: OfflineTrainerParams( buffer=buffer, test_collector=test_collector, - max_epoch=args.epoch, - step_per_epoch=args.step_per_epoch, - episode_per_test=args.test_num, + max_epochs=args.epoch, + epoch_num_steps=args.step_per_epoch, + test_step_num_episodes=args.test_num, batch_size=args.batch_size, stop_fn=stop_fn, save_best_fn=save_best_fn, diff --git a/test/offline/test_discrete_cql.py b/test/offline/test_discrete_cql.py index c0232ff02..95dde08f6 100644 --- a/test/offline/test_discrete_cql.py +++ b/test/offline/test_discrete_cql.py @@ -129,9 +129,9 @@ def stop_fn(mean_rewards: float) -> bool: OfflineTrainerParams( buffer=buffer, test_collector=test_collector, - max_epoch=args.epoch, - step_per_epoch=args.step_per_epoch, - episode_per_test=args.test_num, + max_epochs=args.epoch, + epoch_num_steps=args.step_per_epoch, + test_step_num_episodes=args.test_num, batch_size=args.batch_size, stop_fn=stop_fn, save_best_fn=save_best_fn, diff --git a/test/offline/test_discrete_crr.py b/test/offline/test_discrete_crr.py index d1a1fafca..eb83aa4cd 100644 --- a/test/offline/test_discrete_crr.py +++ b/test/offline/test_discrete_crr.py @@ -130,9 +130,9 @@ def stop_fn(mean_rewards: float) -> bool: OfflineTrainerParams( buffer=buffer, test_collector=test_collector, - max_epoch=args.epoch, - step_per_epoch=args.step_per_epoch, - episode_per_test=args.test_num, + max_epochs=args.epoch, + epoch_num_steps=args.step_per_epoch, + test_step_num_episodes=args.test_num, batch_size=args.batch_size, stop_fn=stop_fn, save_best_fn=save_best_fn, diff --git a/test/offline/test_gail.py 
b/test/offline/test_gail.py index e6549bf5b..8bd23608e 100644 --- a/test/offline/test_gail.py +++ b/test/offline/test_gail.py @@ -203,13 +203,13 @@ def save_checkpoint_fn(epoch: int, env_step: int, gradient_step: int) -> str: OnPolicyTrainerParams( train_collector=train_collector, test_collector=test_collector, - max_epoch=args.epoch, - step_per_epoch=args.step_per_epoch, - repeat_per_collect=args.repeat_per_collect, - episode_per_test=args.test_num, + max_epochs=args.epoch, + epoch_num_steps=args.step_per_epoch, + update_step_num_repetitions=args.repeat_per_collect, + test_step_num_episodes=args.test_num, batch_size=args.batch_size, - episode_per_collect=args.episode_per_collect, - step_per_collect=None, + collection_step_num_episodes=args.episode_per_collect, + collection_step_num_env_steps=None, stop_fn=stop_fn, save_best_fn=save_best_fn, logger=logger, diff --git a/test/offline/test_td3_bc.py b/test/offline/test_td3_bc.py index 07c8e7b42..cea928be0 100644 --- a/test/offline/test_td3_bc.py +++ b/test/offline/test_td3_bc.py @@ -176,9 +176,9 @@ def stop_fn(mean_rewards: float) -> bool: OfflineTrainerParams( buffer=buffer, test_collector=test_collector, - max_epoch=args.epoch, - step_per_epoch=args.step_per_epoch, - episode_per_test=args.test_num, + max_epochs=args.epoch, + epoch_num_steps=args.step_per_epoch, + test_step_num_episodes=args.test_num, batch_size=args.batch_size, save_best_fn=save_best_fn, stop_fn=stop_fn, diff --git a/test/pettingzoo/pistonball.py b/test/pettingzoo/pistonball.py index 907eef5f0..e9406dd5b 100644 --- a/test/pettingzoo/pistonball.py +++ b/test/pettingzoo/pistonball.py @@ -167,14 +167,14 @@ def reward_metric(rews: np.ndarray) -> np.ndarray: OffPolicyTrainerParams( train_collector=train_collector, test_collector=test_collector, - max_epoch=args.epoch, - step_per_epoch=args.step_per_epoch, - step_per_collect=args.step_per_collect, - episode_per_test=args.test_num, + max_epochs=args.epoch, + epoch_num_steps=args.step_per_epoch, + 
collection_step_num_env_steps=args.step_per_collect, + test_step_num_episodes=args.test_num, batch_size=args.batch_size, stop_fn=stop_fn, save_best_fn=save_best_fn, - update_per_step=args.update_per_step, + update_step_num_gradient_steps_per_sample=args.update_per_step, logger=logger, test_in_train=False, reward_metric=reward_metric, diff --git a/test/pettingzoo/pistonball_continuous.py b/test/pettingzoo/pistonball_continuous.py index 14da09f63..def27cf14 100644 --- a/test/pettingzoo/pistonball_continuous.py +++ b/test/pettingzoo/pistonball_continuous.py @@ -272,13 +272,13 @@ def reward_metric(rews: np.ndarray) -> np.ndarray: OnPolicyTrainerParams( train_collector=train_collector, test_collector=test_collector, - max_epoch=args.epoch, - step_per_epoch=args.step_per_epoch, - repeat_per_collect=args.repeat_per_collect, - episode_per_test=args.test_num, + max_epochs=args.epoch, + epoch_num_steps=args.step_per_epoch, + update_step_num_repetitions=args.repeat_per_collect, + test_step_num_episodes=args.test_num, batch_size=args.batch_size, - episode_per_collect=args.episode_per_collect, - step_per_collect=None, + collection_step_num_episodes=args.episode_per_collect, + collection_step_num_env_steps=None, stop_fn=stop_fn, save_best_fn=save_best_fn, logger=logger, diff --git a/test/pettingzoo/tic_tac_toe.py b/test/pettingzoo/tic_tac_toe.py index 934bf1464..334219e77 100644 --- a/test/pettingzoo/tic_tac_toe.py +++ b/test/pettingzoo/tic_tac_toe.py @@ -212,14 +212,14 @@ def reward_metric(rews: np.ndarray) -> np.ndarray: OffPolicyTrainerParams( train_collector=train_collector, test_collector=test_collector, - max_epoch=args.epoch, - step_per_epoch=args.step_per_epoch, - step_per_collect=args.step_per_collect, - episode_per_test=args.test_num, + max_epochs=args.epoch, + epoch_num_steps=args.step_per_epoch, + collection_step_num_env_steps=args.step_per_collect, + test_step_num_episodes=args.test_num, batch_size=args.batch_size, stop_fn=stop_fn, save_best_fn=save_best_fn, - 
update_per_step=args.update_per_step, + update_step_num_gradient_steps_per_sample=args.update_per_step, logger=logger, test_in_train=False, reward_metric=reward_metric, diff --git a/tianshou/algorithm/optim.py b/tianshou/algorithm/optim.py index b949a96f2..c802f95d4 100644 --- a/tianshou/algorithm/optim.py +++ b/tianshou/algorithm/optim.py @@ -25,10 +25,10 @@ class LRSchedulerFactoryLinear(LRSchedulerFactory): zero for the given trainer parameters. """ - def __init__(self, num_epochs: int, step_per_epoch: int, step_per_collect: int): - self.num_epochs = num_epochs - self.step_per_epoch = step_per_epoch - self.step_per_collect = step_per_collect + def __init__(self, max_epochs: int, epoch_num_steps: int, collection_step_num_env_steps: int): + self.num_epochs = max_epochs + self.epoch_num_steps = epoch_num_steps + self.collection_step_num_env_steps = collection_step_num_env_steps def create_scheduler(self, optim: torch.optim.Optimizer) -> LRScheduler: return LambdaLR(optim, lr_lambda=self._LRLambda(self).compute) @@ -36,7 +36,8 @@ def create_scheduler(self, optim: torch.optim.Optimizer) -> LRScheduler: class _LRLambda: def __init__(self, parent: "LRSchedulerFactoryLinear"): self.max_update_num = ( - np.ceil(parent.step_per_epoch / parent.step_per_collect) * parent.num_epochs + np.ceil(parent.epoch_num_steps / parent.collection_step_num_env_steps) + * parent.num_epochs ) def compute(self, epoch: int) -> float: diff --git a/tianshou/highlevel/algorithm.py b/tianshou/highlevel/algorithm.py index b46aa6e74..6f4dc5220 100644 --- a/tianshou/highlevel/algorithm.py +++ b/tianshou/highlevel/algorithm.py @@ -220,12 +220,12 @@ def create_trainer( OnPolicyTrainerParams( train_collector=world.train_collector, test_collector=world.test_collector, - max_epoch=training_config.num_epochs, - step_per_epoch=training_config.step_per_epoch, - repeat_per_collect=training_config.repeat_per_collect, - episode_per_test=training_config.num_test_episodes, + 
max_epochs=training_config.max_epochs, + epoch_num_steps=training_config.epoch_num_steps, + update_step_num_repetitions=training_config.update_step_num_repetitions, + test_step_num_episodes=training_config.test_step_num_episodes, batch_size=training_config.batch_size, - step_per_collect=training_config.step_per_collect, + collection_step_num_env_steps=training_config.collection_step_num_env_steps, save_best_fn=policy_persistence.get_save_best_fn(world), save_checkpoint_fn=policy_persistence.get_save_checkpoint_fn(world), logger=world.logger, @@ -268,14 +268,14 @@ def create_trainer( OffPolicyTrainerParams( train_collector=world.train_collector, test_collector=world.test_collector, - max_epoch=training_config.num_epochs, - step_per_epoch=training_config.step_per_epoch, - step_per_collect=training_config.step_per_collect, - episode_per_test=training_config.num_test_episodes, + max_epochs=training_config.max_epochs, + epoch_num_steps=training_config.epoch_num_steps, + collection_step_num_env_steps=training_config.collection_step_num_env_steps, + test_step_num_episodes=training_config.test_step_num_episodes, batch_size=training_config.batch_size, save_best_fn=policy_persistence.get_save_best_fn(world), logger=world.logger, - update_per_step=training_config.update_per_step, + update_step_num_gradient_steps_per_sample=training_config.update_step_num_gradient_steps_per_sample, test_in_train=training_config.test_in_train, train_fn=train_fn, test_fn=test_fn, diff --git a/tianshou/highlevel/config.py b/tianshou/highlevel/config.py index 485071658..223dfe725 100644 --- a/tianshou/highlevel/config.py +++ b/tianshou/highlevel/config.py @@ -11,13 +11,13 @@ class TrainingConfig(ToStringMixin): """Training configuration.""" - num_epochs: int = 100 + max_epochs: int = 100 """ the (maximum) number of epochs to run training for. 
An **epoch** is the outermost iteration level and each epoch consists of a number of training steps and one test step, where each training step * [for the online case] collects environment steps/transitions (**collection step**), - adding them to the (replay) buffer (see :attr:`step_per_collect` and :attr:`episode_per_collect`) + adding them to the (replay) buffer (see :attr:`collection_step_num_env_steps` and :attr:`episode_per_collect`) * performs an **update step** via the RL algorithm being used, which can involve one or more actual gradient updates, depending on the algorithm @@ -27,21 +27,22 @@ class TrainingConfig(ToStringMixin): Training may be stopped early if the stop criterion is met (see :attr:`stop_fn`). For online training, the number of training steps in each epoch is indirectly determined by - :attr:`step_per_epoch`: As many training steps will be performed as are required in - order to reach :attr:`step_per_epoch` total steps in the training environments. + :attr:`epoch_num_steps`: As many training steps will be performed as are required in + order to reach :attr:`epoch_num_steps` total steps in the training environments. Specifically, if the number of transitions collected per step is `c` (see - :attr:`step_per_collect`) and :attr:`step_per_epoch` is set to `s`, then the number + :attr:`collection_step_num_env_steps`) and :attr:`epoch_num_steps` is set to `s`, then the number of training steps per epoch is `ceil(s / c)`. - Therefore, if `num_epochs = e`, the total number of environment steps taken during training + Therefore, if `max_epochs = e`, the total number of environment steps taken during training can be computed as `e * ceil(s / c) * c`. - For offline training, the number of training steps per epoch is equal to :attr:`step_per_epoch`. + For offline training, the number of training steps per epoch is equal to :attr:`epoch_num_steps`. 
""" - step_per_epoch: int = 30000 + epoch_num_steps: int = 30000 """ - the total number of environment steps to be made per epoch. See :attr:`num_epochs` for - an explanation of epoch semantics. + For an online algorithm, this is the total number of environment steps to be collected per epoch, and, + for an offline algorithm, it is the total number of training steps to take per epoch. + See :attr:`max_epochs` for an explanation of epoch semantics. """ num_train_envs: int = -1 @@ -53,7 +54,7 @@ class TrainingConfig(ToStringMixin): num_test_envs: int = 1 """the number of test environments to use""" - num_test_episodes: int = 1 + test_step_num_episodes: int = 1 """the total number of episodes to collect in each test step (across all test environments). """ @@ -61,7 +62,7 @@ class TrainingConfig(ToStringMixin): """the total size of the sample/replay buffer, in which environment steps (transitions) are stored""" - step_per_collect: int | None = 2048 + collection_step_num_env_steps: int | None = 2048 """ the number of environment steps/transitions to collect in each collection step before the network update within each training step. @@ -74,17 +75,17 @@ class TrainingConfig(ToStringMixin): Specifically, if this is set to `n` and `m` training environments are used, then the total number of transitions collected per collection step is `ceil(n / m) * m =: c`. - See :attr:`num_epochs` for information on the total number of environment steps being + See :attr:`max_epochs` for information on the total number of environment steps being collected during training. """ - episode_per_collect: int | None = None + collection_step_num_episodes: int | None = None """ the number of episodes to collect in each collection step before the network update within each training step. If this is set, the number of environment steps collected in each collection step is the sum of the lengths of the episodes collected. 
- This is mutually exclusive with :attr:`step_per_collect`, and one of the two must be set. + This is mutually exclusive with :attr:`collection_step_num_env_steps`, and one of the two must be set. """ start_timesteps: int = 0 @@ -141,28 +142,37 @@ def __post_init__(self) -> None: if self.num_train_envs == -1: self.num_train_envs = multiprocessing.cpu_count() - if self.num_test_episodes == 0 and self.num_test_envs != 0: + if self.test_step_num_episodes == 0 and self.num_test_envs != 0: log.warning( f"Number of test episodes is set to 0, " f"but number of test environments is ({self.num_test_envs}). " f"This can cause unnecessary memory usage.", ) - if self.num_test_episodes != 0 and self.num_test_episodes % self.num_test_envs != 0: + if ( + self.test_step_num_episodes != 0 + and self.test_step_num_episodes % self.num_test_envs != 0 + ): log.warning( - f"Number of test episodes ({self.num_test_episodes} " + f"Number of test episodes ({self.test_step_num_episodes} " f"is not divisible by the number of test environments ({self.num_test_envs}). " f"This can cause unnecessary memory usage, it is recommended to adjust this.", ) assert ( - sum([self.step_per_collect is not None, self.episode_per_collect is not None]) == 1 - ), ("Only one of `step_per_collect` and `episode_per_collect` can be set.",) + sum( + [ + self.collection_step_num_env_steps is not None, + self.collection_step_num_episodes is not None, + ] + ) + == 1 + ), ("Only one of `collection_step_num_env_steps` and `episode_per_collect` can be set.",) @dataclass(kw_only=True) class OnlineTrainingConfig(TrainingConfig): - step_per_collect: int | None = 2048 + collection_step_num_env_steps: int | None = 2048 """ the number of environment steps/transitions to collect in each collection step before the network update within each training step. 
@@ -175,17 +185,17 @@ class OnlineTrainingConfig(TrainingConfig): Specifically, if this is set to `n` and `m` training environments are used, then the total number of transitions collected per collection step is `ceil(n / m) * m =: c`. - See :attr:`num_epochs` for information on the total number of environment steps being + See :attr:`max_epochs` for information on the total number of environment steps being collected during training. """ - episode_per_collect: int | None = None + collection_step_num_episodes: int | None = None """ the number of episodes to collect in each collection step before the network update within each training step. If this is set, the number of environment steps collected in each collection step is the sum of the lengths of the episodes collected. - This is mutually exclusive with :attr:`step_per_collect`, and one of the two must be set. + This is mutually exclusive with :attr:`collection_step_num_env_steps`, and one of the two must be set. """ test_in_train: bool = False @@ -196,7 +206,7 @@ class OnlineTrainingConfig(TrainingConfig): Specifically, after each collect step, we check whether the early stopping criterion would be satisfied by data we collected (provided that at least one episode was indeed completed, such that we can evaluate returns, etc.). If the criterion is satisfied, we perform a full test step - (collecting :attr:`episode_per_test` episodes in order to evaluate performance), and if the early + (collecting :attr:`test_step_num_episodes` episodes in order to evaluate performance), and if the early stopping criterion is also satisfied based on the test data, we stop training early. """ @@ -211,7 +221,7 @@ class OnPolicyTrainingConfig(OnlineTrainingConfig): used for the gradient update (no mini-batching). """ - repeat_per_collect: int = 1 + update_step_num_repetitions: int = 1 """ controls, within one update step of an on-policy algorithm, the number of times the full collected data is applied for gradient updates, i.e. 
if the parameter is @@ -227,11 +237,9 @@ class OffPolicyTrainingConfig(OnlineTrainingConfig): the the number of environment steps/transitions to sample from the buffer for a gradient update. """ - # TODO: Given our glossary, this is confusingly named. Should definitely contain the word "gradient"; - # also in corresponding TrainerParams object - update_per_step: float = 1.0 + update_step_num_gradient_steps_per_sample: float = 1.0 """ - the number of gradient steps to perform per sample collected (see :attr:`step_per_collect`). + the number of gradient steps to perform per sample collected (see :attr:`collection_step_num_env_steps`). Specifically, if this is set to `u` and the number of samples collected in the preceding collection step is `n`, then `round(u * n)` gradient steps will be performed. """ diff --git a/tianshou/highlevel/params/lr_scheduler.py b/tianshou/highlevel/params/lr_scheduler.py index 1d99eca68..98f90c26c 100644 --- a/tianshou/highlevel/params/lr_scheduler.py +++ b/tianshou/highlevel/params/lr_scheduler.py @@ -20,15 +20,15 @@ def __init__(self, training_config: TrainingConfig): def create_lr_scheduler_factory(self) -> LRSchedulerFactory: if ( - self.training_config.step_per_epoch is None - or self.training_config.step_per_collect is None + self.training_config.epoch_num_steps is None + or self.training_config.collection_step_num_env_steps is None ): raise ValueError( - f"{self.__class__.__name__} requires step_per_epoch and step_per_collect to be set " + f"{self.__class__.__name__} requires epoch_num_steps and collection_step_num_env_steps to be set " f"in order for the scheduling to be well-defined." 
) return LRSchedulerFactoryLinear( - num_epochs=self.training_config.num_epochs, - step_per_epoch=self.training_config.step_per_epoch, - step_per_collect=self.training_config.step_per_collect, + max_epochs=self.training_config.max_epochs, + epoch_num_steps=self.training_config.epoch_num_steps, + collection_step_num_env_steps=self.training_config.collection_step_num_env_steps, ) diff --git a/tianshou/trainer/trainer.py b/tianshou/trainer/trainer.py index 212eb2c8e..0eabc8543 100644 --- a/tianshou/trainer/trainer.py +++ b/tianshou/trainer/trainer.py @@ -68,13 +68,13 @@ @dataclass(kw_only=True) class TrainerParams(ToStringMixin): - max_epoch: int = 100 + max_epochs: int = 100 """ the (maximum) number of epochs to run training for. An **epoch** is the outermost iteration level and each epoch consists of a number of training steps and one test step, where each training step * [for the online case] collects environment steps/transitions (**collection step**), - adding them to the (replay) buffer (see :attr:`step_per_collect` and :attr:`episode_per_collect`) + adding them to the (replay) buffer (see :attr:`collection_step_num_env_steps` and :attr:`collection_step_num_episodes`) * performs an **update step** via the RL algorithm being used, which can involve one or more actual gradient updates, depending on the algorithm @@ -87,19 +87,19 @@ class TrainerParams(ToStringMixin): :attr:`step_per_epoch`: As many training steps will be performed as are required in order to reach :attr:`step_per_epoch` total steps in the training environments. Specifically, if the number of transitions collected per step is `c` (see - :attr:`step_per_collect`) and :attr:`step_per_epoch` is set to `s`, then the number + :attr:`collection_step_num_env_steps`) and :attr:`step_per_epoch` is set to `s`, then the number of training steps per epoch is `ceil(s / c)`. 
- Therefore, if `num_epochs = e`, the total number of environment steps taken during training + Therefore, if `max_epochs = e`, the total number of environment steps taken during training can be computed as `e * ceil(s / c) * c`. For offline training, the number of training steps per epoch is equal to :attr:`step_per_epoch`. """ - step_per_epoch: int = 30000 + epoch_num_steps: int = 30000 """ - for an online algorithm, this is the total number of environment steps to be collected per epoch, and, + For an online algorithm, this is the total number of environment steps to be collected per epoch, and, for an offline algorithm, it is the total number of training steps to take per epoch. - See :attr:`num_epochs` for an explanation of epoch semantics. + See :attr:`max_epochs` for an explanation of epoch semantics. """ test_collector: BaseCollector | None = None @@ -107,7 +107,7 @@ class TrainerParams(ToStringMixin): the collector to use for test episode collection (test steps); if None, perform no test steps. """ - episode_per_test: int = 1 + test_step_num_episodes: int = 1 """the number of episodes to collect in each test step. """ @@ -211,9 +211,9 @@ def __post_init__(self) -> None: "save_best_fn is set while test steps are disabled (test_collector is None)" ) else: - if self.episode_per_test < 1: + if self.test_step_num_episodes < 1: raise ValueError( - "episode_per_test must be positive if test steps are enabled " + "test_step_num_episodes must be positive if test steps are enabled " "(test_collector not None)" ) @@ -225,12 +225,12 @@ class OnlineTrainerParams(TrainerParams): the collector with which to gather new data for training in each training step """ - step_per_collect: int | None = 2048 + collection_step_num_env_steps: int | None = 2048 """ the number of environment steps/transitions to collect in each collection step before the network update within each training step. 
- This is mutually exclusive with :attr:`episode_per_collect`, and one of the two must be set. + This is mutually exclusive with :attr:`collection_step_num_episodes`, and one of the two must be set. Note that the exact number can be reached only if this is a multiple of the number of training environments being used, as each training environment will produce the same @@ -238,17 +238,17 @@ class OnlineTrainerParams(TrainerParams): Specifically, if this is set to `n` and `m` training environments are used, then the total number of transitions collected per collection step is `ceil(n / m) * m =: c`. - See :attr:`num_epochs` for information on the total number of environment steps being + See :attr:`max_epochs` for information on the total number of environment steps being collected during training. """ - episode_per_collect: int | None = None + collection_step_num_episodes: int | None = None """ the number of episodes to collect in each collection step before the network update within each training step. If this is set, the number of environment steps collected in each collection step is the sum of the lengths of the episodes collected. - This is mutually exclusive with :attr:`step_per_collect`, and one of the two must be set. + This is mutually exclusive with :attr:`collection_step_num_env_steps`, and one of the two must be set. """ test_in_train: bool = False @@ -258,14 +258,16 @@ class OnlineTrainerParams(TrainerParams): Specifically, after each collect step, we check whether the early stopping criterion (:attr:`stop_fn`) would be satisfied by data we collected (provided that at least one episode was indeed completed, such that we can evaluate returns, etc.). 
If the criterion is satisfied, we perform a full test step - (collecting :attr:`episode_per_test` episodes in order to evaluate performance), and if the early + (collecting :attr:`test_step_num_episodes` episodes in order to evaluate performance), and if the early stopping criterion is also satisfied based on the test data, we stop training early. """ def __post_init__(self) -> None: super().__post_init__() - if count_none(self.step_per_collect, self.episode_per_collect) != 1: - raise ValueError("Exactly one of {step_per_collect, episode_per_collect} must be set") + if count_none(self.collection_step_num_env_steps, self.collection_step_num_episodes) != 1: + raise ValueError( + "Exactly one of {collection_step_num_env_steps, collection_step_num_episodes} must be set" + ) if self.test_in_train and (self.test_collector is None or self.stop_fn is None): raise ValueError("test_in_train requires test_collector and stop_fn to be set") @@ -280,7 +282,7 @@ class OnPolicyTrainerParams(OnlineTrainerParams): used for the gradient update (no mini-batching). """ - repeat_per_collect: int = 1 + update_step_num_repetitions: int = 1 """ controls, within one update step of an on-policy algorithm, the number of times the full collected data is applied for gradient updates, i.e. if the parameter is @@ -296,10 +298,9 @@ class OffPolicyTrainerParams(OnlineTrainerParams): the the number of environment steps/transitions to sample from the buffer for a gradient update. """ - # TODO: Given our glossary, this is confusingly named. Should definitely contain the word "gradient" - update_per_step: float = 1.0 + update_step_num_gradient_steps_per_sample: float = 1.0 """ - the number of gradient steps to perform per sample collected (see :attr:`step_per_collect`). + the number of gradient steps to perform per sample collected (see :attr:`collection_step_num_env_steps`). 
Specifically, if this is set to `u` and the number of samples collected in the preceding collection step is `n`, then `round(u * n)` gradient steps will be performed. """ @@ -435,7 +436,7 @@ def reset(self, reset_collectors: bool = True, reset_collector_buffers: bool = F # make an initial test step to determine the initial best model if self.params.test_collector is not None: - assert self.params.episode_per_test is not None + assert self.params.test_step_num_episodes is not None assert not isinstance(self.params.test_collector, AsyncCollector) # Issue 700 self._test_step(force_update_best=True, log_msg_prefix="Initial test step") @@ -551,9 +552,9 @@ def execute_epoch(self) -> EpochStats: steps_done_in_this_epoch = 0 train_collect_stats, training_stats = None, None with self._pbar( - total=self.params.step_per_epoch, desc=f"Epoch #{self._epoch}", position=1 + total=self.params.epoch_num_steps, desc=f"Epoch #{self._epoch}", position=1 ) as t: - while steps_done_in_this_epoch < self.params.step_per_epoch and not self._stop_fn_flag: + while steps_done_in_this_epoch < self.params.epoch_num_steps and not self._stop_fn_flag: # perform a training step and update progress TraceLogger.log(log, lambda: "Training step") self._current_update_step += 1 @@ -634,7 +635,7 @@ def _collect_test_episodes( collector.reset(reset_stats=False) if self.params.test_fn: self.params.test_fn(self._epoch, self._env_step) - result = collector.collect(n_episode=self.params.episode_per_test) + result = collector.collect(n_episode=self.params.test_step_num_episodes) if self.params.reward_metric: # TODO: move into collector rew = self.params.reward_metric(result.returns) result.returns = rew @@ -654,7 +655,7 @@ def _test_step( :param force_update_best: whether to force updating of the best model stats (best score, reward, etc.) 
and call the `save_best_fn` callback """ - assert self.params.episode_per_test is not None + assert self.params.test_step_num_episodes is not None assert self.params.test_collector is not None # collect test episodes @@ -741,7 +742,7 @@ def run( reset_collectors=reset_collectors, reset_collector_buffers=reset_collector_buffers ) - while self._epoch < self.params.max_epoch and not self._stop_fn_flag: + while self._epoch < self.params.max_epochs and not self._stop_fn_flag: self.execute_epoch() return self._create_info_stats() @@ -901,15 +902,15 @@ def _collect_training_data(self) -> CollectStats: :return: the data collection stats """ - assert self.params.episode_per_test is not None + assert self.params.test_step_num_episodes is not None assert self.params.train_collector is not None if self.params.train_fn: self.params.train_fn(self._epoch, self._env_step) collect_stats = self.params.train_collector.collect( - n_step=self.params.step_per_collect, - n_episode=self.params.episode_per_collect, + n_step=self.params.collection_step_num_env_steps, + n_episode=self.params.collection_step_num_episodes, ) TraceLogger.log( log, @@ -1019,18 +1020,21 @@ def _update_step( # TODO: this is the only implementation where collect_stats is actually needed. Maybe change interface? collect_stats: CollectStatsBase, ) -> TrainingStats: - """Perform `update_per_step * n_collected_steps` gradient steps by sampling mini-batches from the buffer. + """Perform `update_step_num_gradient_steps_per_sample * n_collected_steps` gradient steps by sampling + mini-batches from the buffer. :param collect_stats: the :class:`~TrainingStats` instance returned by the last gradient step. Some values in it will be replaced by their moving averages. 
""" assert self.params.train_collector is not None n_collected_steps = collect_stats.n_collected_steps - n_gradient_steps = round(self.params.update_per_step * n_collected_steps) + n_gradient_steps = round( + self.params.update_step_num_gradient_steps_per_sample * n_collected_steps + ) if n_gradient_steps == 0: raise ValueError( f"n_gradient_steps is 0, n_collected_steps={n_collected_steps}, " - f"update_per_step={self.params.update_per_step}", + f"update_step_num_gradient_steps_per_sample={self.params.update_step_num_gradient_steps_per_sample}", ) update_stat = None @@ -1078,7 +1082,7 @@ def _update_step( training_stat = self.algorithm.update( buffer=self.params.train_collector.buffer, batch_size=self.params.batch_size, - repeat=self.params.repeat_per_collect, + repeat=self.params.update_step_num_repetitions, ) # just for logging, no functional role From d2800cf593d231857b8b94562c15e7093a7449d7 Mon Sep 17 00:00:00 2001 From: Dominik Jain Date: Fri, 16 May 2025 17:52:37 +0200 Subject: [PATCH 185/230] v2: Update test names for Reinforce --- test/discrete/test_reinforce.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/test/discrete/test_reinforce.py b/test/discrete/test_reinforce.py index 9c17c9164..53200439e 100644 --- a/test/discrete/test_reinforce.py +++ b/test/discrete/test_reinforce.py @@ -47,7 +47,7 @@ def get_args() -> argparse.Namespace: return parser.parse_known_args()[0] -def test_pg(args: argparse.Namespace = get_args(), enable_assertions: bool = True) -> None: +def test_reinforce(args: argparse.Namespace = get_args(), enable_assertions: bool = True) -> None: env = gym.make(args.task) space_info = SpaceInfo.from_env(env) args.state_shape = space_info.observation_info.obs_shape @@ -135,6 +135,6 @@ def stop_fn(mean_rewards: float) -> bool: assert stop_fn(result.best_reward) -def test_pg_determinism() -> None: - main_fn = lambda args: test_pg(args, enable_assertions=False) +def test_reinforce_determinism() -> None: + main_fn = lambda 
args: test_reinforce(args, enable_assertions=False) AlgorithmDeterminismTest("discrete_reinforce", main_fn, get_args()).run() From 839a929357ee060aa49b88c14605da521066369d Mon Sep 17 00:00:00 2001 From: Michael Panchenko Date: Fri, 16 May 2025 18:18:10 +0200 Subject: [PATCH 186/230] v2: major refactoring - all Actors now follow a proper forward interface Also some fixes in Policy interfaces --- CHANGELOG.md | 6 +- README.md | 7 +- docs/01_tutorials/00_dqn.rst | 4 +- docs/01_tutorials/04_tictactoe.rst | 6 +- docs/01_tutorials/07_cheatsheet.rst | 4 +- docs/02_notebooks/L0_overview.ipynb | 4 +- docs/02_notebooks/L7_Experiment.ipynb | 4 +- examples/atari/atari_ppo.py | 3 +- examples/box2d/acrobot_dualdqn.py | 4 +- examples/box2d/bipedal_bdq.py | 4 +- examples/box2d/bipedal_hardcore_sac.py | 8 +- examples/box2d/lunarlander_dqn.py | 4 +- examples/box2d/mcc_sac.py | 8 +- examples/discrete/discrete_dqn.py | 4 +- examples/inverse/irl_gail.py | 8 +- examples/mujoco/fetch_her_ddpg.py | 6 +- examples/mujoco/mujoco_a2c.py | 6 +- examples/mujoco/mujoco_ddpg.py | 6 +- examples/mujoco/mujoco_npg.py | 6 +- examples/mujoco/mujoco_ppo.py | 6 +- examples/mujoco/mujoco_redq.py | 6 +- examples/mujoco/mujoco_reinforce.py | 4 +- examples/mujoco/mujoco_sac.py | 8 +- examples/mujoco/mujoco_td3.py | 8 +- examples/mujoco/mujoco_trpo.py | 6 +- examples/offline/d4rl_bcq.py | 6 +- examples/offline/d4rl_cql.py | 8 +- examples/offline/d4rl_il.py | 4 +- examples/offline/d4rl_td3_bc.py | 8 +- poetry.lock | 441 +++++++++++++++++++--- test/base/test_policy.py | 8 +- test/base/test_utils.py | 12 +- test/continuous/test_ddpg.py | 6 +- test/continuous/test_npg.py | 6 +- test/continuous/test_ppo.py | 6 +- test/continuous/test_redq.py | 6 +- test/continuous/test_sac_with_il.py | 10 +- test/continuous/test_td3.py | 8 +- test/continuous/test_trpo.py | 6 +- test/determinism_test.py | 2 +- test/discrete/test_a2c_with_il.py | 6 +- test/discrete/test_bdqn.py | 4 +- test/discrete/test_c51.py | 4 +- 
test/discrete/test_discrete_sac.py | 8 +- test/discrete/test_dqn.py | 4 +- test/discrete/test_fqf.py | 4 +- test/discrete/test_iqn.py | 4 +- test/discrete/test_pg.py | 4 +- test/discrete/test_ppo_discrete.py | 17 +- test/discrete/test_qrdqn.py | 4 +- test/discrete/test_rainbow.py | 4 +- test/modelbased/test_dqn_icm.py | 4 +- test/modelbased/test_ppo_icm.py | 4 +- test/offline/gather_cartpole_data.py | 4 +- test/offline/gather_pendulum_data.py | 6 +- test/offline/test_bcq.py | 4 +- test/offline/test_cql.py | 6 +- test/offline/test_discrete_bcq.py | 4 +- test/offline/test_discrete_cql.py | 4 +- test/offline/test_discrete_crr.py | 4 +- test/offline/test_gail.py | 8 +- test/offline/test_td3_bc.py | 8 +- test/pettingzoo/pistonball.py | 4 +- test/pettingzoo/tic_tac_toe.py | 4 +- tianshou/data/batch.py | 1 + tianshou/data/types.py | 18 +- tianshou/env/atari/atari_network.py | 82 ++-- tianshou/highlevel/module/actor.py | 13 +- tianshou/highlevel/module/critic.py | 8 +- tianshou/policy/base.py | 41 -- tianshou/policy/imitation/discrete_bcq.py | 22 +- tianshou/policy/modelfree/bdqn.py | 7 +- tianshou/policy/modelfree/c51.py | 4 +- tianshou/policy/modelfree/dqn.py | 31 +- tianshou/policy/modelfree/fqf.py | 3 - tianshou/policy/modelfree/iqn.py | 3 - tianshou/policy/modelfree/pg.py | 11 +- tianshou/utils/net/common.py | 140 ++++--- tianshou/utils/net/continuous.py | 6 +- tianshou/utils/net/discrete.py | 10 +- 80 files changed, 790 insertions(+), 414 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index bd389c68b..9c520927d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -192,9 +192,9 @@ Developers: * Moved Atari helper modules `atari_network` and `atari_wrapper` to the library under `tianshou.env.atari`. 
* Fix issues pertaining to the torch device assignment of network components (#810): * Remove 'device' member (and the corresponding constructor argument) from the following classes: - `BranchingNet`, `C51Net`, `ContinuousActorDeterministic`, `ContinuousActorProbabilistic`, `ContinuousCritic`, + `BranchingActor`, `C51Net`, `ContinuousActorDeterministic`, `ContinuousActorProbabilistic`, `ContinuousCritic`, `DiscreteActor`, `DiscreteCritic`, `DQNet`, `FullQuantileFunction`, `ImplicitQuantileNetwork`, - `IntrinsicCuriosityModule`, `Net`, `MLP`, `Perturbation`, `QRDQNet`, `Rainbow`, `Recurrent`, + `IntrinsicCuriosityModule`, `MLPActor`, `MLP`, `Perturbation`, `QRDQNet`, `Rainbow`, `Recurrent`, `RecurrentActorProb`, `RecurrentCritic`, `VAE` * (Peripheral change:) Require the use of keyword arguments for the constructors of all of these classes * Clean up handling of modules that define attribute `output_dim`, introducing the explicit base class @@ -415,7 +415,7 @@ A detailed list of changes can be found below. distribution type. #1032 - Exception no longer raised on `len` of empty `Batch`. #1084 - tests and examples are covered by `mypy`. #1077 -- `NetBase` is more used, stricter typing by making it generic. #1077 +- `Actor` is more used, stricter typing by making it generic. #1077 - Use explicit multiprocessing context for creating `Pipe` in `subproc.py`. #1102 diff --git a/README.md b/README.md index 5dbb66226..7fd901651 100644 --- a/README.md +++ b/README.md @@ -361,14 +361,17 @@ test_envs = ts.env.DummyVectorEnv([lambda: gym.make(task) for _ in range(test_nu Create the network as well as its optimizer: ```python -from tianshou.utils.net.common import Net +from tianshou.utils.net.common import MLPActor # Note: You can easily define other networks. 
# See https://tianshou.readthedocs.io/en/master/01_tutorials/00_dqn.html#build-the-network env = gym.make(task, render_mode="human") state_shape = env.observation_space.shape or env.observation_space.n action_shape = env.action_space.shape or env.action_space.n -net = Net(state_shape=state_shape, action_shape=action_shape, hidden_sizes=[128, 128, 128]) +net = MLPActor( + state_shape=state_shape, action_shape=action_shape, + hidden_sizes=[128, 128, 128] +) optim = torch.optim.Adam(net.parameters(), lr=lr) ``` diff --git a/docs/01_tutorials/00_dqn.rst b/docs/01_tutorials/00_dqn.rst index 263ee3709..bb73d4c52 100644 --- a/docs/01_tutorials/00_dqn.rst +++ b/docs/01_tutorials/00_dqn.rst @@ -112,7 +112,7 @@ Tianshou supports any user-defined PyTorch networks and optimizers. Yet, of cour import torch, numpy as np from torch import nn - class Net(nn.Module): + class MLPActor(nn.Module): def __init__(self, state_shape, action_shape): super().__init__() self.model = nn.Sequential( @@ -131,7 +131,7 @@ Tianshou supports any user-defined PyTorch networks and optimizers. Yet, of cour state_shape = env.observation_space.shape or env.observation_space.n action_shape = env.action_space.shape or env.action_space.n - net = Net(state_shape, action_shape) + net = MLPActor(state_shape, action_shape) optim = torch.optim.Adam(net.parameters(), lr=1e-3) You can also use pre-defined MLP networks in :mod:`~tianshou.utils.net.common`, :mod:`~tianshou.utils.net.discrete`, and :mod:`~tianshou.utils.net.continuous`. The rules of self-defined networks are: diff --git a/docs/01_tutorials/04_tictactoe.rst b/docs/01_tutorials/04_tictactoe.rst index ff7918e1f..372227069 100644 --- a/docs/01_tutorials/04_tictactoe.rst +++ b/docs/01_tutorials/04_tictactoe.rst @@ -206,7 +206,7 @@ So let's start to train our Tic-Tac-Toe agent! 
First, import some required modul ) from tianshou.trainer import OffpolicyTrainer from tianshou.utils import TensorboardLogger - from tianshou.utils.net.common import Net + from tianshou.utils.net.common import MLPActor The explanation of each Tianshou class/function will be deferred to their first usages. Here we define some arguments and hyperparameters of the experiment. The meaning of arguments is clear by just looking at their names. :: @@ -284,7 +284,7 @@ The explanation of each Tianshou class/function will be deferred to their first The following ``get_agents`` function returns agents and their optimizers from either constructing a new policy, or loading from disk, or using the pass-in arguments. For the models: -- The action model we use is an instance of :class:`~tianshou.utils.net.common.Net`, essentially a multi-layer perceptron with the ReLU activation function; +- The action model we use is an instance of :class:`~tianshou.utils.net.common.MLPActor`, essentially a multi-layer perceptron with the ReLU activation function; - The network model is passed to a :class:`~tianshou.policy.DQNPolicy`, where actions are selected according to both the action mask and their Q-values; - The opponent can be either a random agent :class:`~tianshou.policy.MARLRandomPolicy` that randomly chooses an action from legal actions, or it can be a pre-trained :class:`~tianshou.policy.DQNPolicy` allowing learned agents to play with themselves. 
@@ -307,7 +307,7 @@ Here it is: args.action_shape = env.action_space.shape or env.action_space.n if agent_learn is None: # model - net = Net( + net = MLPActor( args.state_shape, args.action_shape, hidden_sizes=args.hidden_sizes, diff --git a/docs/01_tutorials/07_cheatsheet.rst b/docs/01_tutorials/07_cheatsheet.rst index 51fece131..4631e7cde 100644 --- a/docs/01_tutorials/07_cheatsheet.rst +++ b/docs/01_tutorials/07_cheatsheet.rst @@ -283,12 +283,12 @@ Multi-GPU Training To enable training an RL agent with multiple GPUs for a standard environment (i.e., without nested observation) with default networks provided by Tianshou: 1. Import :class:`~tianshou.utils.net.common.DataParallelNet` from ``tianshou.utils.net.common``; -2. Change the ``device`` argument to ``None`` in the existing networks such as ``Net``, ``Actor``, ``Critic``, ``ActorProb`` +2. Change the ``device`` argument to ``None`` in the existing networks such as ``MLPActor``, ``Actor``, ``Critic``, ``ActorProb`` 3. Apply ``DataParallelNet`` wrapper to these networks. 
:: - from tianshou.utils.net.common import Net, DataParallelNet + from tianshou.utils.net.common import MLPActor, DataParallelNet from tianshou.utils.net.discrete import Actor, Critic actor = DataParallelNet(Actor(net, args.action_shape, device=None).to(args.device)) diff --git a/docs/02_notebooks/L0_overview.ipynb b/docs/02_notebooks/L0_overview.ipynb index a9bf617bc..f82c67bc3 100644 --- a/docs/02_notebooks/L0_overview.ipynb +++ b/docs/02_notebooks/L0_overview.ipynb @@ -62,7 +62,7 @@ "from tianshou.env import DummyVectorEnv\n", "from tianshou.policy import PPOPolicy\n", "from tianshou.trainer import OnpolicyTrainer\n", - "from tianshou.utils.net.common import ActorCritic, Net\n", + "from tianshou.utils.net.common import ActorCritic, MLPActor\n", "from tianshou.utils.net.discrete import Actor, Critic\n", "\n", "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"" @@ -85,7 +85,7 @@ "\n", "# model & optimizer\n", "assert env.observation_space.shape is not None # for mypy\n", - "net = Net(state_shape=env.observation_space.shape, hidden_sizes=[64, 64], device=device)\n", + "net = MLPActor(state_shape=env.observation_space.shape, hidden_sizes=[64, 64], device=device)\n", "\n", "assert isinstance(env.action_space, gym.spaces.Discrete) # for mypy\n", "actor = Actor(preprocess_net=net, action_shape=env.action_space.n, device=device).to(device)\n", diff --git a/docs/02_notebooks/L7_Experiment.ipynb b/docs/02_notebooks/L7_Experiment.ipynb index 47e4cb0c9..9ffd3107a 100644 --- a/docs/02_notebooks/L7_Experiment.ipynb +++ b/docs/02_notebooks/L7_Experiment.ipynb @@ -75,7 +75,7 @@ "from tianshou.env import DummyVectorEnv\n", "from tianshou.policy import PPOPolicy\n", "from tianshou.trainer import OnpolicyTrainer\n", - "from tianshou.utils.net.common import ActorCritic, Net\n", + "from tianshou.utils.net.common import ActorCritic, MLPActor\n", "from tianshou.utils.net.discrete import Actor, Critic\n", "\n", "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"" @@ 
-137,7 +137,7 @@ "# net is the shared head of the actor and the critic\n", "assert env.observation_space.shape is not None # for mypy\n", "assert isinstance(env.action_space, gym.spaces.Discrete) # for mypy\n", - "net = Net(state_shape=env.observation_space.shape, hidden_sizes=[64, 64], device=device)\n", + "net = MLPActor(state_shape=env.observation_space.shape, hidden_sizes=[64, 64], device=device)\n", "actor = Actor(preprocess_net=net, action_shape=env.action_space.n, device=device).to(device)\n", "critic = Critic(preprocess_net=net, device=device).to(device)\n", "actor_critic = ActorCritic(actor=actor, critic=critic)\n", diff --git a/examples/atari/atari_ppo.py b/examples/atari/atari_ppo.py index 4060f55e7..15de98de7 100644 --- a/examples/atari/atari_ppo.py +++ b/examples/atari/atari_ppo.py @@ -14,7 +14,6 @@ DQNet, ScaledObsInputModule, layer_init, - scale_obs, ) from tianshou.env.atari.atari_wrapper import make_atari_env from tianshou.highlevel.logger import LoggerFactoryDefault @@ -133,7 +132,7 @@ def main(args: argparse.Namespace = get_args()) -> None: layer_init=layer_init, ) if args.scale_obs: - net = scale_obs(net) + net = ScaledObsInputModule(net) actor = DiscreteActor(preprocess_net=net, action_shape=args.action_shape, softmax_output=False) critic = DiscreteCritic(preprocess_net=net) optim = AdamOptimizerFactory(lr=args.lr, eps=1e-5) diff --git a/examples/box2d/acrobot_dualdqn.py b/examples/box2d/acrobot_dualdqn.py index 1f1f16c1c..638c90998 100644 --- a/examples/box2d/acrobot_dualdqn.py +++ b/examples/box2d/acrobot_dualdqn.py @@ -15,7 +15,7 @@ from tianshou.policy.optim import AdamOptimizerFactory from tianshou.trainer import OffPolicyTrainerParams from tianshou.utils import TensorboardLogger -from tianshou.utils.net.common import Net +from tianshou.utils.net.common import MLPActor from tianshou.utils.space_info import SpaceInfo @@ -69,7 +69,7 @@ def test_dqn(args: argparse.Namespace = get_args()) -> None: # model Q_param = {"hidden_sizes": 
args.dueling_q_hidden_sizes} V_param = {"hidden_sizes": args.dueling_v_hidden_sizes} - net = Net( + net = MLPActor( state_shape=args.state_shape, action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, diff --git a/examples/box2d/bipedal_bdq.py b/examples/box2d/bipedal_bdq.py index d53999193..84aaf2bdb 100644 --- a/examples/box2d/bipedal_bdq.py +++ b/examples/box2d/bipedal_bdq.py @@ -16,7 +16,7 @@ from tianshou.policy.optim import AdamOptimizerFactory from tianshou.trainer import OffPolicyTrainerParams from tianshou.utils import TensorboardLogger -from tianshou.utils.net.common import BranchingNet +from tianshou.utils.net.common import BranchingActor def get_args() -> argparse.Namespace: @@ -93,7 +93,7 @@ def run_bdq(args: argparse.Namespace = get_args()) -> None: train_envs.seed(args.seed) test_envs.seed(args.seed) # model - net = BranchingNet( + net = BranchingActor( state_shape=args.state_shape, num_branches=args.num_branches, action_per_branch=args.action_per_branch, diff --git a/examples/box2d/bipedal_hardcore_sac.py b/examples/box2d/bipedal_hardcore_sac.py index 72aae0b34..13effb2ef 100644 --- a/examples/box2d/bipedal_hardcore_sac.py +++ b/examples/box2d/bipedal_hardcore_sac.py @@ -17,7 +17,7 @@ from tianshou.policy.optim import AdamOptimizerFactory from tianshou.trainer import OffPolicyTrainerParams from tianshou.utils import TensorboardLogger -from tianshou.utils.net.common import Net +from tianshou.utils.net.common import MLPActor from tianshou.utils.net.continuous import ContinuousActorProbabilistic, ContinuousCritic from tianshou.utils.space_info import SpaceInfo @@ -110,7 +110,7 @@ def test_sac_bipedal(args: argparse.Namespace = get_args()) -> None: test_envs.seed(args.seed) # model - net_a = Net(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes) + net_a = MLPActor(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes) actor = ContinuousActorProbabilistic( preprocess_net=net_a, action_shape=args.action_shape, @@ -118,7 +118,7 
@@ def test_sac_bipedal(args: argparse.Namespace = get_args()) -> None: ).to(args.device) actor_optim = AdamOptimizerFactory(lr=args.actor_lr) - net_c1 = Net( + net_c1 = MLPActor( state_shape=args.state_shape, action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, @@ -127,7 +127,7 @@ def test_sac_bipedal(args: argparse.Namespace = get_args()) -> None: critic1 = ContinuousCritic(preprocess_net=net_c1).to(args.device) critic1_optim = AdamOptimizerFactory(lr=args.critic_lr) - net_c2 = Net( + net_c2 = MLPActor( state_shape=args.state_shape, action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, diff --git a/examples/box2d/lunarlander_dqn.py b/examples/box2d/lunarlander_dqn.py index 5744451d1..25fdd819f 100644 --- a/examples/box2d/lunarlander_dqn.py +++ b/examples/box2d/lunarlander_dqn.py @@ -15,7 +15,7 @@ from tianshou.policy.optim import AdamOptimizerFactory from tianshou.trainer import OffPolicyTrainerParams from tianshou.utils import TensorboardLogger -from tianshou.utils.net.common import Net +from tianshou.utils.net.common import MLPActor from tianshou.utils.space_info import SpaceInfo @@ -71,7 +71,7 @@ def test_dqn(args: argparse.Namespace = get_args()) -> None: # model Q_param = {"hidden_sizes": args.dueling_q_hidden_sizes} V_param = {"hidden_sizes": args.dueling_v_hidden_sizes} - net = Net( + net = MLPActor( state_shape=args.state_shape, action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, diff --git a/examples/box2d/mcc_sac.py b/examples/box2d/mcc_sac.py index 03eeec085..f981e0d74 100644 --- a/examples/box2d/mcc_sac.py +++ b/examples/box2d/mcc_sac.py @@ -16,7 +16,7 @@ from tianshou.policy.optim import AdamOptimizerFactory from tianshou.trainer import OffPolicyTrainerParams from tianshou.utils import TensorboardLogger -from tianshou.utils.net.common import Net +from tianshou.utils.net.common import MLPActor from tianshou.utils.net.continuous import ContinuousActorProbabilistic, ContinuousCritic from tianshou.utils.space_info import 
SpaceInfo @@ -68,12 +68,12 @@ def test_sac(args: argparse.Namespace = get_args()) -> None: train_envs.seed(args.seed) test_envs.seed(args.seed) # model - net = Net(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes) + net = MLPActor(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes) actor = ContinuousActorProbabilistic( preprocess_net=net, action_shape=args.action_shape, unbounded=True ).to(args.device) actor_optim = AdamOptimizerFactory(lr=args.actor_lr) - net_c1 = Net( + net_c1 = MLPActor( state_shape=args.state_shape, action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, @@ -81,7 +81,7 @@ def test_sac(args: argparse.Namespace = get_args()) -> None: ) critic1 = ContinuousCritic(preprocess_net=net_c1).to(args.device) critic1_optim = AdamOptimizerFactory(lr=args.critic_lr) - net_c2 = Net( + net_c2 = MLPActor( state_shape=args.state_shape, action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, diff --git a/examples/discrete/discrete_dqn.py b/examples/discrete/discrete_dqn.py index 3fbfb7801..1a2e607a0 100644 --- a/examples/discrete/discrete_dqn.py +++ b/examples/discrete/discrete_dqn.py @@ -25,7 +25,7 @@ def main() -> None: train_envs = ts.env.DummyVectorEnv([lambda: gym.make(task) for _ in range(train_num)]) test_envs = ts.env.DummyVectorEnv([lambda: gym.make(task) for _ in range(test_num)]) - from tianshou.utils.net.common import Net + from tianshou.utils.net.common import MLPActor # Note: You can easily define other networks. 
# See https://tianshou.readthedocs.io/en/master/01_tutorials/00_dqn.html#build-the-network @@ -34,7 +34,7 @@ def main() -> None: space_info = SpaceInfo.from_env(env) state_shape = space_info.observation_info.obs_shape action_shape = space_info.action_info.action_shape - net = Net(state_shape=state_shape, action_shape=action_shape, hidden_sizes=[128, 128, 128]) + net = MLPActor(state_shape=state_shape, action_shape=action_shape, hidden_sizes=[128, 128, 128]) optim = AdamOptimizerFactory(lr=lr) policy = DiscreteQLearningPolicy( diff --git a/examples/inverse/irl_gail.py b/examples/inverse/irl_gail.py index 56ed5b445..afbecec0c 100644 --- a/examples/inverse/irl_gail.py +++ b/examples/inverse/irl_gail.py @@ -29,7 +29,7 @@ from tianshou.policy.optim import AdamOptimizerFactory, LRSchedulerFactoryLinear from tianshou.trainer import OnPolicyTrainerParams from tianshou.utils import TensorboardLogger -from tianshou.utils.net.common import Net +from tianshou.utils.net.common import MLPActor from tianshou.utils.net.continuous import ContinuousActorProbabilistic, ContinuousCritic from tianshou.utils.space_info import SpaceInfo @@ -122,7 +122,7 @@ def test_gail(args: argparse.Namespace = get_args()) -> None: train_envs.seed(args.seed) test_envs.seed(args.seed) # model - net_a = Net( + net_a = MLPActor( state_shape=args.state_shape, hidden_sizes=args.hidden_sizes, activation=nn.Tanh, @@ -132,7 +132,7 @@ def test_gail(args: argparse.Namespace = get_args()) -> None: action_shape=args.action_shape, unbounded=True, ).to(args.device) - net_c = Net( + net_c = MLPActor( state_shape=args.state_shape, hidden_sizes=args.hidden_sizes, activation=nn.Tanh, @@ -154,7 +154,7 @@ def test_gail(args: argparse.Namespace = get_args()) -> None: optim = AdamOptimizerFactory(lr=args.lr) # discriminator - net_d = Net( + net_d = MLPActor( state_shape=args.state_shape, action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, diff --git a/examples/mujoco/fetch_her_ddpg.py 
b/examples/mujoco/fetch_her_ddpg.py index 061edab8f..4a15b5703 100644 --- a/examples/mujoco/fetch_her_ddpg.py +++ b/examples/mujoco/fetch_her_ddpg.py @@ -27,7 +27,7 @@ from tianshou.policy.modelfree.ddpg import ContinuousDeterministicPolicy from tianshou.policy.optim import AdamOptimizerFactory from tianshou.trainer import OffPolicyTrainerParams -from tianshou.utils.net.common import Net, get_dict_state_decorator +from tianshou.utils.net.common import MLPActor, get_dict_state_decorator from tianshou.utils.net.continuous import ContinuousActorDeterministic, ContinuousCritic from tianshou.utils.space_info import ActionSpaceInfo @@ -149,7 +149,7 @@ def test_ddpg(args: argparse.Namespace = get_args()) -> None: state_shape=args.state_shape, keys=["observation", "achieved_goal", "desired_goal"], ) - net_a = dict_state_dec(Net)( + net_a = dict_state_dec(MLPActor)( flat_state_shape, hidden_sizes=args.hidden_sizes, device=args.device, @@ -161,7 +161,7 @@ def test_ddpg(args: argparse.Namespace = get_args()) -> None: device=args.device, ).to(args.device) actor_optim = AdamOptimizerFactory(lr=args.actor_lr) - net_c = dict_state_dec(Net)( + net_c = dict_state_dec(MLPActor)( flat_state_shape, action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, diff --git a/examples/mujoco/mujoco_a2c.py b/examples/mujoco/mujoco_a2c.py index a7f2a6eac..797d817d4 100755 --- a/examples/mujoco/mujoco_a2c.py +++ b/examples/mujoco/mujoco_a2c.py @@ -18,7 +18,7 @@ from tianshou.policy.modelfree.pg import ActorPolicyProbabilistic from tianshou.policy.optim import LRSchedulerFactoryLinear, RMSpropOptimizerFactory from tianshou.trainer import OnPolicyTrainerParams -from tianshou.utils.net.common import ActorCritic, Net +from tianshou.utils.net.common import ActorCritic, MLPActor from tianshou.utils.net.continuous import ContinuousActorProbabilistic, ContinuousCritic @@ -89,7 +89,7 @@ def main(args: argparse.Namespace = get_args()) -> None: np.random.seed(args.seed) torch.manual_seed(args.seed) # 
model - net_a = Net( + net_a = MLPActor( state_shape=args.state_shape, hidden_sizes=args.hidden_sizes, activation=nn.Tanh, @@ -99,7 +99,7 @@ def main(args: argparse.Namespace = get_args()) -> None: action_shape=args.action_shape, unbounded=True, ).to(args.device) - net_c = Net( + net_c = MLPActor( state_shape=args.state_shape, hidden_sizes=args.hidden_sizes, activation=nn.Tanh, diff --git a/examples/mujoco/mujoco_ddpg.py b/examples/mujoco/mujoco_ddpg.py index 56f1a233f..0b3e8036a 100755 --- a/examples/mujoco/mujoco_ddpg.py +++ b/examples/mujoco/mujoco_ddpg.py @@ -17,7 +17,7 @@ from tianshou.policy.modelfree.ddpg import ContinuousDeterministicPolicy from tianshou.policy.optim import AdamOptimizerFactory from tianshou.trainer import OffPolicyTrainerParams -from tianshou.utils.net.common import Net +from tianshou.utils.net.common import MLPActor from tianshou.utils.net.continuous import ContinuousActorDeterministic, ContinuousCritic @@ -85,14 +85,14 @@ def main(args: argparse.Namespace = get_args()) -> None: np.random.seed(args.seed) torch.manual_seed(args.seed) # model - net_a = Net(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes) + net_a = MLPActor(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes) actor = ContinuousActorDeterministic( preprocess_net=net_a, action_shape=args.action_shape, max_action=args.max_action ).to( args.device, ) actor_optim = AdamOptimizerFactory(lr=args.actor_lr) - net_c = Net( + net_c = MLPActor( state_shape=args.state_shape, action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, diff --git a/examples/mujoco/mujoco_npg.py b/examples/mujoco/mujoco_npg.py index 86f1d30df..e6a380022 100755 --- a/examples/mujoco/mujoco_npg.py +++ b/examples/mujoco/mujoco_npg.py @@ -18,7 +18,7 @@ from tianshou.policy.modelfree.pg import ActorPolicyProbabilistic from tianshou.policy.optim import AdamOptimizerFactory, LRSchedulerFactoryLinear from tianshou.trainer import OnPolicyTrainerParams -from tianshou.utils.net.common import 
Net +from tianshou.utils.net.common import MLPActor from tianshou.utils.net.continuous import ContinuousActorProbabilistic, ContinuousCritic @@ -94,7 +94,7 @@ def main(args: argparse.Namespace = get_args()) -> None: np.random.seed(args.seed) torch.manual_seed(args.seed) # model - net_a = Net( + net_a = MLPActor( state_shape=args.state_shape, hidden_sizes=args.hidden_sizes, activation=nn.Tanh, @@ -104,7 +104,7 @@ def main(args: argparse.Namespace = get_args()) -> None: action_shape=args.action_shape, unbounded=True, ).to(args.device) - net_c = Net( + net_c = MLPActor( state_shape=args.state_shape, hidden_sizes=args.hidden_sizes, activation=nn.Tanh, diff --git a/examples/mujoco/mujoco_ppo.py b/examples/mujoco/mujoco_ppo.py index ffbaa1d3a..af2e87548 100755 --- a/examples/mujoco/mujoco_ppo.py +++ b/examples/mujoco/mujoco_ppo.py @@ -18,7 +18,7 @@ from tianshou.policy.modelfree.pg import ActorPolicyProbabilistic from tianshou.policy.optim import AdamOptimizerFactory, LRSchedulerFactoryLinear from tianshou.trainer import OnPolicyTrainerParams -from tianshou.utils.net.common import ActorCritic, Net +from tianshou.utils.net.common import ActorCritic, MLPActor from tianshou.utils.net.continuous import ContinuousActorProbabilistic, ContinuousCritic @@ -94,7 +94,7 @@ def main(args: argparse.Namespace = get_args()) -> None: np.random.seed(args.seed) torch.manual_seed(args.seed) # model - net_a = Net( + net_a = MLPActor( state_shape=args.state_shape, hidden_sizes=args.hidden_sizes, activation=nn.Tanh, @@ -104,7 +104,7 @@ def main(args: argparse.Namespace = get_args()) -> None: action_shape=args.action_shape, unbounded=True, ).to(args.device) - net_c = Net( + net_c = MLPActor( state_shape=args.state_shape, hidden_sizes=args.hidden_sizes, activation=nn.Tanh, diff --git a/examples/mujoco/mujoco_redq.py b/examples/mujoco/mujoco_redq.py index 09fdadad3..13f7cc34c 100755 --- a/examples/mujoco/mujoco_redq.py +++ b/examples/mujoco/mujoco_redq.py @@ -17,7 +17,7 @@ from 
tianshou.policy.modelfree.sac import AutoAlpha from tianshou.policy.optim import AdamOptimizerFactory from tianshou.trainer import OffPolicyTrainerParams -from tianshou.utils.net.common import EnsembleLinear, Net +from tianshou.utils.net.common import EnsembleLinear, MLPActor from tianshou.utils.net.continuous import ContinuousActorProbabilistic, ContinuousCritic @@ -89,7 +89,7 @@ def main(args: argparse.Namespace = get_args()) -> None: np.random.seed(args.seed) torch.manual_seed(args.seed) # model - net_a = Net(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes) + net_a = MLPActor(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes) actor = ContinuousActorProbabilistic( preprocess_net=net_a, action_shape=args.action_shape, @@ -101,7 +101,7 @@ def main(args: argparse.Namespace = get_args()) -> None: def linear(x: int, y: int) -> EnsembleLinear: return EnsembleLinear(args.ensemble_size, x, y) - net_c = Net( + net_c = MLPActor( state_shape=args.state_shape, action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, diff --git a/examples/mujoco/mujoco_reinforce.py b/examples/mujoco/mujoco_reinforce.py index 5cbb77ef9..1874479ff 100755 --- a/examples/mujoco/mujoco_reinforce.py +++ b/examples/mujoco/mujoco_reinforce.py @@ -18,7 +18,7 @@ from tianshou.policy.modelfree.pg import ActorPolicyProbabilistic from tianshou.policy.optim import AdamOptimizerFactory, LRSchedulerFactoryLinear from tianshou.trainer import OnPolicyTrainerParams -from tianshou.utils.net.common import Net +from tianshou.utils.net.common import MLPActor from tianshou.utils.net.continuous import ContinuousActorProbabilistic @@ -86,7 +86,7 @@ def main(args: argparse.Namespace = get_args()) -> None: np.random.seed(args.seed) torch.manual_seed(args.seed) # model - net_a = Net( + net_a = MLPActor( state_shape=args.state_shape, hidden_sizes=args.hidden_sizes, activation=nn.Tanh, diff --git a/examples/mujoco/mujoco_sac.py b/examples/mujoco/mujoco_sac.py index 685276af6..9c284c081 100755 
--- a/examples/mujoco/mujoco_sac.py +++ b/examples/mujoco/mujoco_sac.py @@ -16,7 +16,7 @@ from tianshou.policy.modelfree.sac import AutoAlpha, SACPolicy from tianshou.policy.optim import AdamOptimizerFactory from tianshou.trainer import OffPolicyTrainerParams -from tianshou.utils.net.common import Net +from tianshou.utils.net.common import MLPActor from tianshou.utils.net.continuous import ContinuousActorProbabilistic, ContinuousCritic @@ -85,7 +85,7 @@ def main(args: argparse.Namespace = get_args()) -> None: np.random.seed(args.seed) torch.manual_seed(args.seed) # model - net_a = Net(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes) + net_a = MLPActor(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes) actor = ContinuousActorProbabilistic( preprocess_net=net_a, action_shape=args.action_shape, @@ -93,13 +93,13 @@ def main(args: argparse.Namespace = get_args()) -> None: conditioned_sigma=True, ).to(args.device) actor_optim = AdamOptimizerFactory(lr=args.actor_lr) - net_c1 = Net( + net_c1 = MLPActor( state_shape=args.state_shape, action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, concat=True, ) - net_c2 = Net( + net_c2 = MLPActor( state_shape=args.state_shape, action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, diff --git a/examples/mujoco/mujoco_td3.py b/examples/mujoco/mujoco_td3.py index 06d974535..5f1fb2127 100755 --- a/examples/mujoco/mujoco_td3.py +++ b/examples/mujoco/mujoco_td3.py @@ -17,7 +17,7 @@ from tianshou.policy.modelfree.ddpg import ContinuousDeterministicPolicy from tianshou.policy.optim import AdamOptimizerFactory from tianshou.trainer import OffPolicyTrainerParams -from tianshou.utils.net.common import Net +from tianshou.utils.net.common import MLPActor from tianshou.utils.net.continuous import ContinuousActorDeterministic, ContinuousCritic @@ -90,20 +90,20 @@ def main(args: argparse.Namespace = get_args()) -> None: np.random.seed(args.seed) torch.manual_seed(args.seed) # model - net_a = 
Net(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes) + net_a = MLPActor(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes) actor = ContinuousActorDeterministic( preprocess_net=net_a, action_shape=args.action_shape, max_action=args.max_action ).to( args.device, ) actor_optim = AdamOptimizerFactory(lr=args.actor_lr) - net_c1 = Net( + net_c1 = MLPActor( state_shape=args.state_shape, action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, concat=True, ) - net_c2 = Net( + net_c2 = MLPActor( state_shape=args.state_shape, action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, diff --git a/examples/mujoco/mujoco_trpo.py b/examples/mujoco/mujoco_trpo.py index 6f291cf47..2ba5148c8 100755 --- a/examples/mujoco/mujoco_trpo.py +++ b/examples/mujoco/mujoco_trpo.py @@ -18,7 +18,7 @@ from tianshou.policy.modelfree.pg import ActorPolicyProbabilistic from tianshou.policy.optim import AdamOptimizerFactory, LRSchedulerFactoryLinear from tianshou.trainer import OnPolicyTrainerParams -from tianshou.utils.net.common import Net +from tianshou.utils.net.common import MLPActor from tianshou.utils.net.continuous import ContinuousActorProbabilistic, ContinuousCritic @@ -97,7 +97,7 @@ def test_trpo(args: argparse.Namespace = get_args()) -> None: np.random.seed(args.seed) torch.manual_seed(args.seed) # model - net_a = Net( + net_a = MLPActor( state_shape=args.state_shape, hidden_sizes=args.hidden_sizes, activation=nn.Tanh, @@ -107,7 +107,7 @@ def test_trpo(args: argparse.Namespace = get_args()) -> None: action_shape=args.action_shape, unbounded=True, ).to(args.device) - net_c = Net( + net_c = MLPActor( state_shape=args.state_shape, hidden_sizes=args.hidden_sizes, activation=nn.Tanh, diff --git a/examples/offline/d4rl_bcq.py b/examples/offline/d4rl_bcq.py index 252d915e8..6715dac7e 100644 --- a/examples/offline/d4rl_bcq.py +++ b/examples/offline/d4rl_bcq.py @@ -19,7 +19,7 @@ from tianshou.policy.optim import AdamOptimizerFactory from tianshou.trainer 
import OfflineTrainerParams from tianshou.utils import TensorboardLogger, WandbLogger -from tianshou.utils.net.common import MLP, Net +from tianshou.utils.net.common import MLP, MLPActor from tianshou.utils.net.continuous import VAE, ContinuousCritic, Perturbation from tianshou.utils.space_info import SpaceInfo @@ -109,13 +109,13 @@ def test_bcq() -> None: ) actor_optim = AdamOptimizerFactory(lr=args.actor_lr) - net_c1 = Net( + net_c1 = MLPActor( state_shape=args.state_shape, action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, concat=True, ) - net_c2 = Net( + net_c2 = MLPActor( state_shape=args.state_shape, action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, diff --git a/examples/offline/d4rl_cql.py b/examples/offline/d4rl_cql.py index b9c3d17b9..c88b69911 100644 --- a/examples/offline/d4rl_cql.py +++ b/examples/offline/d4rl_cql.py @@ -19,7 +19,7 @@ from tianshou.policy.optim import AdamOptimizerFactory from tianshou.trainer import OfflineTrainerParams from tianshou.utils import TensorboardLogger, WandbLogger -from tianshou.utils.net.common import Net +from tianshou.utils.net.common import MLPActor from tianshou.utils.net.continuous import ContinuousActorProbabilistic, ContinuousCritic from tianshou.utils.space_info import SpaceInfo @@ -245,7 +245,7 @@ def test_cql() -> None: # model # actor network - net_a = Net( + net_a = MLPActor( state_shape=args.state_shape, action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, @@ -259,13 +259,13 @@ def test_cql() -> None: actor_optim = AdamOptimizerFactory(lr=args.actor_lr) # critic network - net_c1 = Net( + net_c1 = MLPActor( state_shape=args.state_shape, action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, concat=True, ) - net_c2 = Net( + net_c2 = MLPActor( state_shape=args.state_shape, action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, diff --git a/examples/offline/d4rl_il.py b/examples/offline/d4rl_il.py index 8a4a9c520..6a0823707 100644 --- 
a/examples/offline/d4rl_il.py +++ b/examples/offline/d4rl_il.py @@ -18,7 +18,7 @@ from tianshou.policy.optim import AdamOptimizerFactory from tianshou.trainer import OfflineTrainerParams from tianshou.utils import TensorboardLogger, WandbLogger -from tianshou.utils.net.common import Net +from tianshou.utils.net.common import MLPActor from tianshou.utils.net.continuous import ContinuousActorDeterministic from tianshou.utils.space_info import SpaceInfo @@ -83,7 +83,7 @@ def test_il() -> None: test_envs.seed(args.seed) # model - net = Net( + net = MLPActor( state_shape=args.state_shape, action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, diff --git a/examples/offline/d4rl_td3_bc.py b/examples/offline/d4rl_td3_bc.py index 7a1224c2a..9a14001ed 100644 --- a/examples/offline/d4rl_td3_bc.py +++ b/examples/offline/d4rl_td3_bc.py @@ -20,7 +20,7 @@ from tianshou.policy.optim import AdamOptimizerFactory from tianshou.trainer import OfflineTrainerParams from tianshou.utils import TensorboardLogger, WandbLogger -from tianshou.utils.net.common import Net +from tianshou.utils.net.common import MLPActor from tianshou.utils.net.continuous import ContinuousActorDeterministic, ContinuousCritic from tianshou.utils.space_info import SpaceInfo @@ -104,7 +104,7 @@ def test_td3_bc() -> None: # model # actor network - net_a = Net( + net_a = MLPActor( state_shape=args.state_shape, hidden_sizes=args.hidden_sizes, ) @@ -116,13 +116,13 @@ def test_td3_bc() -> None: actor_optim = AdamOptimizerFactory(lr=args.actor_lr) # critic network - net_c1 = Net( + net_c1 = MLPActor( state_shape=args.state_shape, action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, concat=True, ) - net_c2 = Net( + net_c2 = MLPActor( state_shape=args.state_shape, action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, diff --git a/poetry.lock b/poetry.lock index 6fca0635d..c009dd5ed 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 
1.8.3 and should not be changed by hand. +# This file is automatically @generated by Poetry 2.1.2 and should not be changed by hand. [[package]] name = "absl-py" @@ -6,6 +6,7 @@ version = "2.0.0" description = "Abseil Python Common Libraries, see https://github.com/abseil/abseil-py." optional = false python-versions = ">=3.7" +groups = ["main"] files = [ {file = "absl-py-2.0.0.tar.gz", hash = "sha256:d9690211c5fcfefcdd1a45470ac2b5c5acd45241c3af71eed96bc5441746c0d5"}, {file = "absl_py-2.0.0-py3-none-any.whl", hash = "sha256:9a28abb62774ae4e8edbe2dd4c49ffcd45a6a848952a5eccc6a49f3f0fc1e2f3"}, @@ -17,6 +18,7 @@ version = "0.0.4" description = "A collection of accessible pygments styles" optional = false python-versions = "*" +groups = ["dev"] files = [ {file = "accessible-pygments-0.0.4.tar.gz", hash = "sha256:e7b57a9b15958e9601c7e9eb07a440c813283545a20973f2574a5f453d0e953e"}, {file = "accessible_pygments-0.0.4-py2.py3-none-any.whl", hash = "sha256:416c6d8c1ea1c5ad8701903a20fcedf953c6e720d64f33dc47bfb2d3f2fa4e8d"}, @@ -31,6 +33,8 @@ version = "1.3.1" description = "aiosignal: a list of registered asynchronous callbacks" optional = false python-versions = ">=3.7" +groups = ["dev"] +markers = "sys_platform != \"win32\"" files = [ {file = "aiosignal-1.3.1-py3-none-any.whl", hash = "sha256:f8376fb07dd1e86a584e4fcdec80b36b7f81aac666ebc724e2c090300dd83b17"}, {file = "aiosignal-1.3.1.tar.gz", hash = "sha256:54cd96e15e1649b75d6c87526a6ff0b6c1b0dd3459f43d9ca11d48c339b68cfc"}, @@ -45,6 +49,7 @@ version = "0.7.13" description = "A configurable sidebar-enabled Sphinx theme" optional = false python-versions = ">=3.6" +groups = ["dev"] files = [ {file = "alabaster-0.7.13-py3-none-any.whl", hash = "sha256:1ee19aca801bbabb5ba3f5f258e4422dfa86f82f3e9cefb0859b283cdd7f62a3"}, {file = "alabaster-0.7.13.tar.gz", hash = "sha256:a27a4a084d5e690e16e01e03ad2b2e552c61a65469419b907243193de1a84ae2"}, @@ -56,6 +61,8 @@ version = "0.8.1" description = "The Arcade Learning Environment (ALE) - a 
platform for AI research." optional = true python-versions = ">=3.7" +groups = ["main"] +markers = "extra == \"atari\"" files = [ {file = "ale_py-0.8.1-cp310-cp310-macosx_10_15_x86_64.whl", hash = "sha256:b2aa2f69a4169742800615970efe6914fa856e33eaf7fa9133c0e06a617a80e2"}, {file = "ale_py-0.8.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:6f2f6b92c8fd6189654979bbf0b305dbe0ecf82176c47f244d8c1cbc36286b89"}, @@ -91,6 +98,7 @@ version = "4.0.0" description = "High level compatibility layer for multiple asynchronous event loop implementations" optional = false python-versions = ">=3.8" +groups = ["dev"] files = [ {file = "anyio-4.0.0-py3-none-any.whl", hash = "sha256:cfdb2b588b9fc25ede96d8db56ed50848b0b649dca3dd1df0b11f683bb9e0b5f"}, {file = "anyio-4.0.0.tar.gz", hash = "sha256:f7ed51751b2c2add651e5747c891b47e26d2a21be5d32d9311dfe9692f3e5d7a"}, @@ -102,7 +110,7 @@ sniffio = ">=1.1" [package.extras] doc = ["Sphinx (>=7)", "packaging", "sphinx-autodoc-typehints (>=1.2.0)"] -test = ["anyio[trio]", "coverage[toml] (>=7)", "hypothesis (>=4.0)", "psutil (>=5.9)", "pytest (>=7.0)", "pytest-mock (>=3.6.1)", "trustme", "uvloop (>=0.17)"] +test = ["anyio[trio]", "coverage[toml] (>=7)", "hypothesis (>=4.0)", "psutil (>=5.9)", "pytest (>=7.0)", "pytest-mock (>=3.6.1)", "trustme", "uvloop (>=0.17) ; python_version < \"3.12\" and platform_python_implementation == \"CPython\" and platform_system != \"Windows\""] trio = ["trio (>=0.22)"] [[package]] @@ -111,6 +119,7 @@ version = "1.4.1" description = "Handy tools for working with URLs and APIs." optional = false python-versions = ">=3.6.1" +groups = ["dev"] files = [ {file = "apeye-1.4.1-py3-none-any.whl", hash = "sha256:44e58a9104ec189bf42e76b3a7fe91e2b2879d96d48e9a77e5e32ff699c9204e"}, {file = "apeye-1.4.1.tar.gz", hash = "sha256:14ea542fad689e3bfdbda2189a354a4908e90aee4bf84c15ab75d68453d76a36"}, @@ -132,6 +141,7 @@ version = "1.1.4" description = "Core (offline) functionality for the apeye library." 
optional = false python-versions = ">=3.6.1" +groups = ["dev"] files = [ {file = "apeye_core-1.1.4-py3-none-any.whl", hash = "sha256:084bc696448d3ac428fece41c1f2eb08fa9d9ce1d1b2f4d43187e3def4528a60"}, {file = "apeye_core-1.1.4.tar.gz", hash = "sha256:72bb89fed3baa647cb81aa28e1d851787edcbf9573853b5d2b5f87c02f50eaf5"}, @@ -147,6 +157,8 @@ version = "0.1.3" description = "Disable App Nap on macOS >= 10.9" optional = false python-versions = "*" +groups = ["dev"] +markers = "platform_system == \"Darwin\" or sys_platform == \"darwin\"" files = [ {file = "appnope-0.1.3-py2.py3-none-any.whl", hash = "sha256:265a455292d0bd8a72453494fa24df5a11eb18373a60c7c0430889f22548605e"}, {file = "appnope-0.1.3.tar.gz", hash = "sha256:02bd91c4de869fbb1e1c50aafc4098827a7a54ab2f39d9dcba6c9547ed920e24"}, @@ -158,6 +170,8 @@ version = "5.3.1" description = "ARCH for Python" optional = true python-versions = ">=3.7" +groups = ["main"] +markers = "extra == \"eval\"" files = [ {file = "arch-5.3.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:75fa6f9386ecc2df81bcbf5d055a290a697482ca51e0b3459dab183d288993cb"}, {file = "arch-5.3.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:f9c9220d331618322517e0f2b3b3529f9c51f5e5a891441da4a107fd2d6d7fce"}, @@ -197,6 +211,7 @@ version = "23.1.0" description = "Argon2 for Python" optional = false python-versions = ">=3.7" +groups = ["dev"] files = [ {file = "argon2_cffi-23.1.0-py3-none-any.whl", hash = "sha256:c670642b78ba29641818ab2e68bd4e6a78ba53b7eff7b4c3815ae16abf91c7ea"}, {file = "argon2_cffi-23.1.0.tar.gz", hash = "sha256:879c3e79a2729ce768ebb7d36d4609e3a78a4ca2ec3a9f12286ca057e3d0db08"}, @@ -217,6 +232,7 @@ version = "21.2.0" description = "Low-level CFFI bindings for Argon2" optional = false python-versions = ">=3.6" +groups = ["dev"] files = [ {file = "argon2-cffi-bindings-21.2.0.tar.gz", hash = "sha256:bb89ceffa6c791807d1305ceb77dbfacc5aa499891d2c55661c6459651fc39e3"}, {file = "argon2_cffi_bindings-21.2.0-cp36-abi3-macosx_10_9_x86_64.whl", 
hash = "sha256:ccb949252cb2ab3a08c02024acb77cfb179492d5701c7cbdbfd776124d4d2367"}, @@ -254,6 +270,7 @@ version = "1.3.0" description = "Better dates & times for Python" optional = false python-versions = ">=3.8" +groups = ["dev"] files = [ {file = "arrow-1.3.0-py3-none-any.whl", hash = "sha256:c728b120ebc00eb84e01882a6f5e7927a53960aa990ce7dd2b10f39005a67f80"}, {file = "arrow-1.3.0.tar.gz", hash = "sha256:d4540617648cb5f895730f1ad8c82a65f2dad0166f57b75f3ca54759c4d67a85"}, @@ -273,6 +290,7 @@ version = "2.4.1" description = "Annotate AST trees with source code positions" optional = false python-versions = "*" +groups = ["dev"] files = [ {file = "asttokens-2.4.1-py2.py3-none-any.whl", hash = "sha256:051ed49c3dcae8913ea7cd08e46a606dba30b79993209636c4875bc1d637bc24"}, {file = "asttokens-2.4.1.tar.gz", hash = "sha256:b03869718ba9a6eb027e134bfdf69f38a236d681c83c160d510768af11254ba0"}, @@ -282,8 +300,8 @@ files = [ six = ">=1.12.0" [package.extras] -astroid = ["astroid (>=1,<2)", "astroid (>=2,<4)"] -test = ["astroid (>=1,<2)", "astroid (>=2,<4)", "pytest"] +astroid = ["astroid (>=1,<2) ; python_version < \"3\"", "astroid (>=2,<4) ; python_version >= \"3\""] +test = ["astroid (>=1,<2) ; python_version < \"3\"", "astroid (>=2,<4) ; python_version >= \"3\"", "pytest"] [[package]] name = "async-lru" @@ -291,6 +309,7 @@ version = "2.0.4" description = "Simple LRU cache for asyncio" optional = false python-versions = ">=3.8" +groups = ["dev"] files = [ {file = "async-lru-2.0.4.tar.gz", hash = "sha256:b8a59a5df60805ff63220b2a0c5b5393da5521b113cd5465a44eb037d81a5627"}, {file = "async_lru-2.0.4-py3-none-any.whl", hash = "sha256:ff02944ce3c288c5be660c42dbcca0742b32c3b279d6dceda655190240b99224"}, @@ -302,6 +321,7 @@ version = "23.1.0" description = "Classes Without Boilerplate" optional = false python-versions = ">=3.7" +groups = ["dev"] files = [ {file = "attrs-23.1.0-py3-none-any.whl", hash = "sha256:1f28b4522cdc2fb4256ac1a020c78acf9cba2c6b461ccd2c126f3aa8e8335d04"}, {file = 
"attrs-23.1.0.tar.gz", hash = "sha256:6279836d581513a26f1bf235f9acd333bc9115683f14f7e8fae46c98fc50e015"}, @@ -312,7 +332,7 @@ cov = ["attrs[tests]", "coverage[toml] (>=5.3)"] dev = ["attrs[docs,tests]", "pre-commit"] docs = ["furo", "myst-parser", "sphinx", "sphinx-notfound-page", "sphinxcontrib-towncrier", "towncrier", "zope-interface"] tests = ["attrs[tests-no-zope]", "zope-interface"] -tests-no-zope = ["cloudpickle", "hypothesis", "mypy (>=1.1.1)", "pympler", "pytest (>=4.3.0)", "pytest-mypy-plugins", "pytest-xdist[psutil]"] +tests-no-zope = ["cloudpickle ; platform_python_implementation == \"CPython\"", "hypothesis", "mypy (>=1.1.1) ; platform_python_implementation == \"CPython\"", "pympler", "pytest (>=4.3.0)", "pytest-mypy-plugins ; platform_python_implementation == \"CPython\" and python_version < \"3.11\"", "pytest-xdist[psutil]"] [[package]] name = "autodocsumm" @@ -320,6 +340,7 @@ version = "0.2.11" description = "Extended sphinx autodoc including automatic autosummaries" optional = false python-versions = ">=3.7" +groups = ["dev"] files = [ {file = "autodocsumm-0.2.11-py3-none-any.whl", hash = "sha256:f1d0a623bf1ad64d979a9e23fd360d1fb1b8f869beaf3197f711552cddc174e2"}, {file = "autodocsumm-0.2.11.tar.gz", hash = "sha256:183212bd9e9f3b58a96bb21b7958ee4e06224107aa45b2fd894b61b83581b9a9"}, @@ -334,6 +355,7 @@ version = "2.0.4" description = "A tool that automatically formats Python code to conform to the PEP 8 style guide" optional = false python-versions = ">=3.6" +groups = ["dev"] files = [ {file = "autopep8-2.0.4-py2.py3-none-any.whl", hash = "sha256:067959ca4a07b24dbd5345efa8325f5f58da4298dab0dde0443d5ed765de80cb"}, {file = "autopep8-2.0.4.tar.gz", hash = "sha256:2913064abd97b3419d1cc83ea71f042cb821f87e45b9c88cad5ad3c4ea87fe0c"}, @@ -348,6 +370,8 @@ version = "0.4.2" description = "Automated installation of Atari ROMs for Gym/ALE-Py" optional = true python-versions = ">=3.6" +groups = ["main"] +markers = "extra == \"atari\"" files = [ {file = 
"AutoROM-0.4.2-py3-none-any.whl", hash = "sha256:719c9d363ef08391fdb7003d70df235b68f36de628d289a946c4a59a3adefa13"}, {file = "AutoROM-0.4.2.tar.gz", hash = "sha256:b426a39bc0ee3781c7791f28963a9b2e4385b6421eeaf2f368edc00c761d428a"}, @@ -368,6 +392,8 @@ version = "0.6.1" description = "Automated installation of Atari ROMs for Gym/ALE-Py" optional = true python-versions = ">=3.7" +groups = ["main"] +markers = "extra == \"atari\"" files = [ {file = "AutoROM.accept-rom-license-0.6.1.tar.gz", hash = "sha256:0c905a708d634a076f686802f672817d3585259ce3be0bde8713a4fb59e3159e"}, ] @@ -385,6 +411,7 @@ version = "2.13.1" description = "Internationalization utilities" optional = false python-versions = ">=3.7" +groups = ["dev"] files = [ {file = "Babel-2.13.1-py3-none-any.whl", hash = "sha256:7077a4984b02b6727ac10f1f7294484f737443d7e2e66c5e4380e41a3ae0b4ed"}, {file = "Babel-2.13.1.tar.gz", hash = "sha256:33e0952d7dd6374af8dbf6768cc4ddf3ccfefc244f9986d4074704f2fbd18900"}, @@ -402,6 +429,7 @@ version = "4.12.2" description = "Screen-scraping library" optional = false python-versions = ">=3.6.0" +groups = ["dev"] files = [ {file = "beautifulsoup4-4.12.2-py3-none-any.whl", hash = "sha256:bd2520ca0d9d7d12694a53d44ac482d181b4ec1888909b035a3dbf40d0f57d4a"}, {file = "beautifulsoup4-4.12.2.tar.gz", hash = "sha256:492bbc69dca35d12daac71c4db1bfff0c876c00ef4a2ffacce226d4638eb72da"}, @@ -420,6 +448,7 @@ version = "23.11.0" description = "The uncompromising code formatter." optional = false python-versions = ">=3.8" +groups = ["dev"] files = [ {file = "black-23.11.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:dbea0bb8575c6b6303cc65017b46351dc5953eea5c0a59d7b7e3a2d2f433a911"}, {file = "black-23.11.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:412f56bab20ac85927f3a959230331de5614aecda1ede14b373083f62ec24e6f"}, @@ -462,6 +491,7 @@ version = "6.1.0" description = "An easy safelist-based HTML-sanitizing tool." 
optional = false python-versions = ">=3.8" +groups = ["dev"] files = [ {file = "bleach-6.1.0-py3-none-any.whl", hash = "sha256:3225f354cfc436b9789c66c4ee030194bee0568fbf9cbdad3bc8b5c26c5f12b6"}, {file = "bleach-6.1.0.tar.gz", hash = "sha256:0a31f1837963c41d46bbf1331b8778e1308ea0791db03cc4e7357b97cf42a8fe"}, @@ -480,6 +510,8 @@ version = "2.3.5" description = "Python Box2D" optional = true python-versions = "*" +groups = ["main"] +markers = "extra == \"box2d\"" files = [ {file = "box2d-py-2.3.5.tar.gz", hash = "sha256:b37dc38844bcd7def48a97111d2b082e4f81cca3cece7460feb3eacda0da2207"}, {file = "box2d_py-2.3.5-cp35-cp35m-manylinux1_x86_64.whl", hash = "sha256:287aa54005c0644b47bf7ad72966e4068d66e56bcf8458f5b4a653ffe42a2618"}, @@ -494,6 +526,7 @@ version = "0.13.1" description = "httplib2 caching for requests" optional = false python-versions = ">=3.7" +groups = ["dev"] files = [ {file = "cachecontrol-0.13.1-py3-none-any.whl", hash = "sha256:95dedbec849f46dda3137866dc28b9d133fc9af55f5b805ab1291833e4457aa4"}, {file = "cachecontrol-0.13.1.tar.gz", hash = "sha256:f012366b79d2243a6118309ce73151bf52a38d4a5dac8ea57f09bd29087e506b"}, @@ -515,6 +548,7 @@ version = "5.3.2" description = "Extensible memoizing collections and decorators" optional = false python-versions = ">=3.7" +groups = ["main"] files = [ {file = "cachetools-5.3.2-py3-none-any.whl", hash = "sha256:861f35a13a451f94e301ce2bec7cac63e881232ccce7ed67fab9b5df4d3beaa1"}, {file = "cachetools-5.3.2.tar.gz", hash = "sha256:086ee420196f7b2ab9ca2db2520aca326318b68fe5ba8bc4d49cca91add450f2"}, @@ -526,6 +560,7 @@ version = "2024.7.4" description = "Python package for providing Mozilla's CA Bundle." 
optional = false python-versions = ">=3.6" +groups = ["main", "dev"] files = [ {file = "certifi-2024.7.4-py3-none-any.whl", hash = "sha256:c198e21b1289c2ab85ee4e67bb4b4ef3ead0892059901a8d5b622f24a1101e90"}, {file = "certifi-2024.7.4.tar.gz", hash = "sha256:5a1e7645bc0ec61a09e26c36f6106dd4cf40c6db3a1fb6352b0244e7fb057c7b"}, @@ -537,6 +572,7 @@ version = "1.16.0" description = "Foreign Function Interface for Python calling C code." optional = false python-versions = ">=3.8" +groups = ["dev"] files = [ {file = "cffi-1.16.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:6b3d6606d369fc1da4fd8c357d026317fbb9c9b75d36dc16e90e84c26854b088"}, {file = "cffi-1.16.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:ac0f5edd2360eea2f1daa9e26a41db02dd4b0451b48f7c318e217ee092a213e9"}, @@ -601,6 +637,7 @@ version = "3.4.0" description = "Validate configuration and produce human readable error messages." optional = false python-versions = ">=3.8" +groups = ["dev"] files = [ {file = "cfgv-3.4.0-py2.py3-none-any.whl", hash = "sha256:b7265b1f29fd3316bfcd2b330d63d024f2bfd8bcb8b0272f8e19a504856c48f9"}, {file = "cfgv-3.4.0.tar.gz", hash = "sha256:e52591d4c5f5dead8e0f673fb16db7949d2cfb3f7da4582893288f0ded8fe560"}, @@ -612,6 +649,7 @@ version = "3.3.2" description = "The Real First Universal Charset Detector. Open, modern and actively maintained alternative to Chardet." 
optional = false python-versions = ">=3.7.0" +groups = ["main", "dev"] files = [ {file = "charset-normalizer-3.3.2.tar.gz", hash = "sha256:f30c3cb33b24454a82faecaf01b19c18562b1e89558fb6c56de4d9118a032fd5"}, {file = "charset_normalizer-3.3.2-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:25baf083bf6f6b341f4121c2f3c548875ee6f5339300e08be3f2b2ba1721cdd3"}, @@ -711,10 +749,12 @@ version = "8.1.7" description = "Composable command line interface toolkit" optional = false python-versions = ">=3.7" +groups = ["main", "dev"] files = [ {file = "click-8.1.7-py3-none-any.whl", hash = "sha256:ae74fb96c20a0277a1d615f1e4d73c8414f5a98db8b799a7931d1582f3390c28"}, {file = "click-8.1.7.tar.gz", hash = "sha256:ca9853ad459e787e2192211578cc907e7594e294c7ccc834310722b41b9ca6de"}, ] +markers = {main = "extra == \"atari\""} [package.dependencies] colorama = {version = "*", markers = "platform_system == \"Windows\""} @@ -725,6 +765,7 @@ version = "3.0.0" description = "Pickler class to extend the standard pickle.Pickler functionality" optional = false python-versions = ">=3.8" +groups = ["main"] files = [ {file = "cloudpickle-3.0.0-py3-none-any.whl", hash = "sha256:246ee7d0c295602a036e86369c77fecda4ab17b506496730f2f576d9016fd9c7"}, {file = "cloudpickle-3.0.0.tar.gz", hash = "sha256:996d9a482c6fb4f33c1a35335cf8afd065d2a56e973270364840712d9131a882"}, @@ -736,10 +777,12 @@ version = "0.4.6" description = "Cross-platform colored terminal text." 
optional = false python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,!=3.6.*,>=2.7" +groups = ["main", "dev"] files = [ {file = "colorama-0.4.6-py2.py3-none-any.whl", hash = "sha256:4f1d9991f5acc0ca119f9d443620b77f9d6b33703e51011c16baf57afb285fc6"}, {file = "colorama-0.4.6.tar.gz", hash = "sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44"}, ] +markers = {main = "platform_system == \"Windows\"", dev = "platform_system == \"Windows\" or sys_platform == \"win32\""} [[package]] name = "comm" @@ -747,6 +790,7 @@ version = "0.2.0" description = "Jupyter Python Comm implementation, for usage in ipykernel, xeus-python etc." optional = false python-versions = ">=3.8" +groups = ["dev"] files = [ {file = "comm-0.2.0-py3-none-any.whl", hash = "sha256:2da8d9ebb8dd7bfc247adaff99f24dce705638a8042b85cb995066793e391001"}, {file = "comm-0.2.0.tar.gz", hash = "sha256:a517ea2ca28931c7007a7a99c562a0fa5883cfb48963140cf642c41c948498be"}, @@ -764,6 +808,7 @@ version = "1.2.1" description = "Python library for calculating contours of 2D quadrilateral grids" optional = false python-versions = ">=3.9" +groups = ["main"] files = [ {file = "contourpy-1.2.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:bd7c23df857d488f418439686d3b10ae2fbf9bc256cd045b37a8c16575ea1040"}, {file = "contourpy-1.2.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:5b9eb0ca724a241683c9685a484da9d35c872fd42756574a7cfbf58af26677fd"}, @@ -827,6 +872,7 @@ version = "7.3.2" description = "Code coverage measurement for Python" optional = false python-versions = ">=3.8" +groups = ["dev"] files = [ {file = "coverage-7.3.2-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:d872145f3a3231a5f20fd48500274d7df222e291d90baa2026cc5152b7ce86bf"}, {file = "coverage-7.3.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:310b3bb9c91ea66d59c53fa4989f57d2436e08f18fb2f421a1b0b6b8cc7fffda"}, @@ -883,7 +929,7 @@ files = [ ] [package.extras] -toml = ["tomli"] +toml = ["tomli ; 
python_full_version <= \"3.11.0a6\""] [[package]] name = "cssutils" @@ -891,6 +937,7 @@ version = "2.9.0" description = "A CSS Cascading Style Sheets library for Python" optional = false python-versions = ">=3.8" +groups = ["dev"] files = [ {file = "cssutils-2.9.0-py3-none-any.whl", hash = "sha256:f8b013169e281c0c6083207366c5005f5dd4549055f7aba840384fb06a78745c"}, {file = "cssutils-2.9.0.tar.gz", hash = "sha256:89477b3d17d790e97b9fb4def708767061055795aae6f7c82ae32e967c9be4cd"}, @@ -898,7 +945,7 @@ files = [ [package.extras] docs = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "rst.linker (>=1.9)", "sphinx (<7.2.5)", "sphinx (>=3.5)", "sphinx-lint"] -testing = ["cssselect", "importlib-resources", "jaraco.test (>=5.1)", "lxml", "pytest (>=6)", "pytest-black (>=0.3.7)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-mypy (>=0.9.1)", "pytest-ruff"] +testing = ["cssselect", "importlib-resources ; python_version < \"3.9\"", "jaraco.test (>=5.1)", "lxml ; python_version < \"3.11\"", "pytest (>=6)", "pytest-black (>=0.3.7) ; platform_python_implementation != \"PyPy\"", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-mypy (>=0.9.1) ; platform_python_implementation != \"PyPy\"", "pytest-ruff"] [[package]] name = "cycler" @@ -906,6 +953,7 @@ version = "0.12.1" description = "Composable style cycles" optional = false python-versions = ">=3.8" +groups = ["main"] files = [ {file = "cycler-0.12.1-py3-none-any.whl", hash = "sha256:85cef7cff222d8644161529808465972e51340599459b8ac3ccbac5a854e0d30"}, {file = "cycler-0.12.1.tar.gz", hash = "sha256:88bb128f02ba341da8ef447245a9e138fae777f6a23943da4540077d3601eb1c"}, @@ -921,6 +969,7 @@ version = "3.0.8" description = "The Cython compiler for writing C extensions in the Python language." 
optional = true python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" +groups = ["main"] files = [ {file = "Cython-3.0.8-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:a846e0a38e2b24e9a5c5dc74b0e54c6e29420d88d1dafabc99e0fc0f3e338636"}, {file = "Cython-3.0.8-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:45523fdc2b78d79b32834cc1cc12dc2ca8967af87e22a3ee1bff20e77c7f5520"}, @@ -988,6 +1037,7 @@ version = "1.8.0" description = "An implementation of the Debug Adapter Protocol for Python" optional = false python-versions = ">=3.8" +groups = ["dev"] files = [ {file = "debugpy-1.8.0-cp310-cp310-macosx_11_0_x86_64.whl", hash = "sha256:7fb95ca78f7ac43393cd0e0f2b6deda438ec7c5e47fa5d38553340897d2fbdfb"}, {file = "debugpy-1.8.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ef9ab7df0b9a42ed9c878afd3eaaff471fce3fa73df96022e1f5c9f8f8c87ada"}, @@ -1015,6 +1065,7 @@ version = "5.1.1" description = "Decorators for Humans" optional = false python-versions = ">=3.5" +groups = ["dev"] files = [ {file = "decorator-5.1.1-py3-none-any.whl", hash = "sha256:b8c3f85900b9dc423225913c5aace94729fe1fa9763b38939a95226f02d37186"}, {file = "decorator-5.1.1.tar.gz", hash = "sha256:637996211036b6385ef91435e4fae22989472f9d571faba8927ba8253acbc330"}, @@ -1026,6 +1077,7 @@ version = "7.0.1" description = "Deep Difference and Search of any Python object/data. Recreate objects by adding adding deltas to each other." 
optional = false python-versions = ">=3.8" +groups = ["main"] files = [ {file = "deepdiff-7.0.1-py3-none-any.whl", hash = "sha256:447760081918216aa4fd4ca78a4b6a848b81307b2ea94c810255334b759e1dc3"}, {file = "deepdiff-7.0.1.tar.gz", hash = "sha256:260c16f052d4badbf60351b4f77e8390bee03a0b516246f6839bc813fb429ddf"}, @@ -1044,6 +1096,7 @@ version = "0.7.1" description = "XML bomb protection for Python stdlib modules" optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*" +groups = ["dev"] files = [ {file = "defusedxml-0.7.1-py2.py3-none-any.whl", hash = "sha256:a352e7e428770286cc899e2542b6cdaedb2b4953ff269a210103ec58f6198a61"}, {file = "defusedxml-0.7.1.tar.gz", hash = "sha256:1bb3032db185915b62d7c6209c5a8792be6a32ab2fedacc84e01b52c51aa3e69"}, @@ -1055,6 +1108,7 @@ version = "0.3.0.post1" description = "A μ-library for constructing cascading style sheets from Python dictionaries." optional = false python-versions = ">=3.6" +groups = ["dev"] files = [ {file = "dict2css-0.3.0.post1-py3-none-any.whl", hash = "sha256:f006a6b774c3e31869015122ae82c491fd25e7de4a75607a62aa3e798f837e0d"}, {file = "dict2css-0.3.0.post1.tar.gz", hash = "sha256:89c544c21c4ca7472c3fffb9d37d3d926f606329afdb751dc1de67a411b70719"}, @@ -1070,6 +1124,7 @@ version = "0.3.7" description = "Distribution utilities" optional = false python-versions = "*" +groups = ["main", "dev"] files = [ {file = "distlib-0.3.7-py2.py3-none-any.whl", hash = "sha256:2e24928bc811348f0feb63014e97aaae3037f2cf48712d51ae61df7fd6075057"}, {file = "distlib-0.3.7.tar.gz", hash = "sha256:9dafe54b34a028eafd95039d5e5d4851a13734540f1331060d31c9916e7147a8"}, @@ -1081,6 +1136,8 @@ version = "1.6" description = "A Python interface for Reinforcement Learning environments." 
optional = true python-versions = ">=3.7" +groups = ["main"] +markers = "sys_platform != \"darwin\" and extra == \"envpool\"" files = [ {file = "dm-env-1.6.tar.gz", hash = "sha256:a436eb1c654c39e0c986a516cee218bea7140b510fceff63f97eb4fcff3d93de"}, {file = "dm_env-1.6-py3-none-any.whl", hash = "sha256:0eabb6759dd453b625e041032f7ae0c1e87d4eb61b6a96b9ca586483837abf29"}, @@ -1097,6 +1154,8 @@ version = "0.1.8" description = "Tree is a library for working with nested data structures." optional = true python-versions = "*" +groups = ["main"] +markers = "sys_platform != \"darwin\" and extra == \"envpool\"" files = [ {file = "dm-tree-0.1.8.tar.gz", hash = "sha256:0fcaabbb14e7980377439e7140bd05552739ca5e515ecb3119f234acee4b9430"}, {file = "dm_tree-0.1.8-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:35cc164a79336bfcfafb47e5f297898359123bbd3330c1967f0c4994f9cf9f60"}, @@ -1152,6 +1211,7 @@ version = "0.4.0" description = "Python bindings for the docker credentials store API" optional = false python-versions = "*" +groups = ["dev"] files = [ {file = "docker-pycreds-0.4.0.tar.gz", hash = "sha256:6ce3270bcaf404cc4c3e27e4b6c70d3521deae82fb508767870fdbf772d584d4"}, {file = "docker_pycreds-0.4.0-py2.py3-none-any.whl", hash = "sha256:7266112468627868005106ec19cd0d722702d2b7d5912a28e19b826c3d37af49"}, @@ -1166,6 +1226,8 @@ version = "0.15" description = "Parse Python docstrings in reST, Google and Numpydoc format" optional = true python-versions = ">=3.6,<4.0" +groups = ["main"] +markers = "extra == \"argparse\" or extra == \"eval\"" files = [ {file = "docstring_parser-0.15-py3-none-any.whl", hash = "sha256:d1679b86250d269d06a99670924d6bce45adc00b08069dae8c47d98e89b667a9"}, {file = "docstring_parser-0.15.tar.gz", hash = "sha256:48ddc093e8b1865899956fcc03b03e66bb7240c310fac5af81814580c55bf682"}, @@ -1177,6 +1239,7 @@ version = "0.20.1" description = "Docutils -- Python Documentation Utilities" optional = false python-versions = ">=3.7" +groups = ["dev"] files = [ {file = 
"docutils-0.20.1-py3-none-any.whl", hash = "sha256:96f387a2c5562db4476f09f13bbab2192e764cac08ebbf3a34a95d9b1e4a59d6"}, {file = "docutils-0.20.1.tar.gz", hash = "sha256:f08a4e276c3a1583a86dce3e34aba3fe04d02bba2dd51ed16106244e8a923e3b"}, @@ -1188,6 +1251,7 @@ version = "3.7.0" description = "Helpful functions for Python 🐍 🛠️" optional = false python-versions = ">=3.6" +groups = ["dev"] files = [ {file = "domdf_python_tools-3.7.0-py3-none-any.whl", hash = "sha256:7b4d1c3bdb7402b872d43953824bf921ae2e52f893adbe5c0052a21a6efa2fe4"}, {file = "domdf_python_tools-3.7.0.tar.gz", hash = "sha256:df1af9a91649af0fb2a4e7b3a4b0a0936e4f78389dd7280dd6fd2f53a339ca71"}, @@ -1207,6 +1271,8 @@ version = "0.8.4" description = "\"C++-based high-performance parallel environment execution engine (vectorized env) for general RL environments.\"" optional = true python-versions = ">=3.7" +groups = ["main"] +markers = "sys_platform != \"darwin\" and extra == \"envpool\"" files = [ {file = "envpool-0.8.4-cp310-cp310-manylinux_2_24_x86_64.whl", hash = "sha256:9c6a1af66c8a18d798b3069e8eee4cde2e5942af22b25d058189714f2630b024"}, {file = "envpool-0.8.4-cp311-cp311-manylinux_2_24_x86_64.whl", hash = "sha256:2407294307a3e20c18787bb836a94cc0649e708b04d8a8200be674f5fc46f3b4"}, @@ -1231,13 +1297,14 @@ version = "2.0.1" description = "Get the currently executing AST node of a frame, and other information" optional = false python-versions = ">=3.5" +groups = ["dev"] files = [ {file = "executing-2.0.1-py2.py3-none-any.whl", hash = "sha256:eac49ca94516ccc753f9fb5ce82603156e590b27525a8bc32cce8ae302eb61bc"}, {file = "executing-2.0.1.tar.gz", hash = "sha256:35afe2ce3affba8ee97f2d69927fa823b08b472b7b994e36a52a964b93d16147"}, ] [package.extras] -tests = ["asttokens (>=2.1.0)", "coverage", "coverage-enable-subprocess", "ipython", "littleutils", "pytest", "rich"] +tests = ["asttokens (>=2.1.0)", "coverage", "coverage-enable-subprocess", "ipython", "littleutils", "pytest", "rich ; python_version >= \"3.11\""] 
[[package]] name = "farama-notifications" @@ -1245,6 +1312,7 @@ version = "0.0.4" description = "Notifications for all Farama Foundation maintained libraries." optional = false python-versions = "*" +groups = ["main"] files = [ {file = "Farama-Notifications-0.0.4.tar.gz", hash = "sha256:13fceff2d14314cf80703c8266462ebf3733c7d165336eee998fc58e545efd18"}, {file = "Farama_Notifications-0.0.4-py3-none-any.whl", hash = "sha256:14de931035a41961f7c056361dc7f980762a143d05791ef5794a751a2caf05ae"}, @@ -1256,6 +1324,7 @@ version = "2.19.0" description = "Fastest Python implementation of JSON schema" optional = false python-versions = "*" +groups = ["dev"] files = [ {file = "fastjsonschema-2.19.0-py3-none-any.whl", hash = "sha256:b9fd1a2dd6971dbc7fee280a95bd199ae0dd9ce22beb91cc75e9c1c528a5170e"}, {file = "fastjsonschema-2.19.0.tar.gz", hash = "sha256:e25df6647e1bc4a26070b700897b07b542ec898dd4f1f6ea013e7f6a88417225"}, @@ -1270,6 +1339,7 @@ version = "3.13.1" description = "A platform independent file lock." 
optional = false python-versions = ">=3.8" +groups = ["main", "dev"] files = [ {file = "filelock-3.13.1-py3-none-any.whl", hash = "sha256:57dbda9b35157b05fb3e58ee91448612eb674172fab98ee235ccb0b5bee19a1c"}, {file = "filelock-3.13.1.tar.gz", hash = "sha256:521f5f56c50f8426f5e03ad3b281b490a87ef15bc6c526f168290f0c7148d44e"}, @@ -1278,7 +1348,7 @@ files = [ [package.extras] docs = ["furo (>=2023.9.10)", "sphinx (>=7.2.6)", "sphinx-autodoc-typehints (>=1.24)"] testing = ["covdefaults (>=2.3)", "coverage (>=7.3.2)", "diff-cover (>=8)", "pytest (>=7.4.3)", "pytest-cov (>=4.1)", "pytest-mock (>=3.12)", "pytest-timeout (>=2.2)"] -typing = ["typing-extensions (>=4.8)"] +typing = ["typing-extensions (>=4.8) ; python_version < \"3.11\""] [[package]] name = "fonttools" @@ -1286,6 +1356,7 @@ version = "4.51.0" description = "Tools to manipulate font files" optional = false python-versions = ">=3.8" +groups = ["main"] files = [ {file = "fonttools-4.51.0-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:84d7751f4468dd8cdd03ddada18b8b0857a5beec80bce9f435742abc9a851a74"}, {file = "fonttools-4.51.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:8b4850fa2ef2cfbc1d1f689bc159ef0f45d8d83298c1425838095bf53ef46308"}, @@ -1332,18 +1403,18 @@ files = [ ] [package.extras] -all = ["brotli (>=1.0.1)", "brotlicffi (>=0.8.0)", "fs (>=2.2.0,<3)", "lxml (>=4.0)", "lz4 (>=1.7.4.2)", "matplotlib", "munkres", "pycairo", "scipy", "skia-pathops (>=0.5.0)", "sympy", "uharfbuzz (>=0.23.0)", "unicodedata2 (>=15.1.0)", "xattr", "zopfli (>=0.1.4)"] +all = ["brotli (>=1.0.1) ; platform_python_implementation == \"CPython\"", "brotlicffi (>=0.8.0) ; platform_python_implementation != \"CPython\"", "fs (>=2.2.0,<3)", "lxml (>=4.0)", "lz4 (>=1.7.4.2)", "matplotlib", "munkres ; platform_python_implementation == \"PyPy\"", "pycairo", "scipy ; platform_python_implementation != \"PyPy\"", "skia-pathops (>=0.5.0)", "sympy", "uharfbuzz (>=0.23.0)", "unicodedata2 (>=15.1.0) ; python_version <= \"3.12\"", "xattr 
; sys_platform == \"darwin\"", "zopfli (>=0.1.4)"] graphite = ["lz4 (>=1.7.4.2)"] -interpolatable = ["munkres", "pycairo", "scipy"] +interpolatable = ["munkres ; platform_python_implementation == \"PyPy\"", "pycairo", "scipy ; platform_python_implementation != \"PyPy\""] lxml = ["lxml (>=4.0)"] pathops = ["skia-pathops (>=0.5.0)"] plot = ["matplotlib"] repacker = ["uharfbuzz (>=0.23.0)"] symfont = ["sympy"] -type1 = ["xattr"] +type1 = ["xattr ; sys_platform == \"darwin\""] ufo = ["fs (>=2.2.0,<3)"] -unicode = ["unicodedata2 (>=15.1.0)"] -woff = ["brotli (>=1.0.1)", "brotlicffi (>=0.8.0)", "zopfli (>=0.1.4)"] +unicode = ["unicodedata2 (>=15.1.0) ; python_version <= \"3.12\""] +woff = ["brotli (>=1.0.1) ; platform_python_implementation == \"CPython\"", "brotlicffi (>=0.8.0) ; platform_python_implementation != \"CPython\"", "zopfli (>=0.1.4)"] [[package]] name = "fqdn" @@ -1351,6 +1422,7 @@ version = "1.5.1" description = "Validates fully-qualified domain names against RFC 1123, so that they are acceptable to modern bowsers" optional = false python-versions = ">=2.7, !=3.0, !=3.1, !=3.2, !=3.3, !=3.4, <4" +groups = ["dev"] files = [ {file = "fqdn-1.5.1-py3-none-any.whl", hash = "sha256:3a179af3761e4df6eb2e026ff9e1a3033d3587bf980a0b1b2e1e5d08d7358014"}, {file = "fqdn-1.5.1.tar.gz", hash = "sha256:105ed3677e767fb5ca086a0c1f4bb66ebc3c100be518f0e0d755d9eae164d89f"}, @@ -1362,6 +1434,8 @@ version = "1.4.0" description = "A list-like structure which implements collections.abc.MutableSequence" optional = false python-versions = ">=3.8" +groups = ["dev"] +markers = "sys_platform != \"win32\"" files = [ {file = "frozenlist-1.4.0-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:764226ceef3125e53ea2cb275000e309c0aa5464d43bd72abd661e27fffc26ab"}, {file = "frozenlist-1.4.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:d6484756b12f40003c6128bfcc3fa9f0d49a687e171186c2d85ec82e3758c559"}, @@ -1432,6 +1506,7 @@ version = "2023.10.0" description = "File-system 
specification" optional = false python-versions = ">=3.8" +groups = ["main"] files = [ {file = "fsspec-2023.10.0-py3-none-any.whl", hash = "sha256:346a8f024efeb749d2a5fca7ba8854474b1ff9af7c3faaf636a4548781136529"}, {file = "fsspec-2023.10.0.tar.gz", hash = "sha256:330c66757591df346ad3091a53bd907e15348c2ba17d63fd54f5c39c4457d2a5"}, @@ -1467,6 +1542,7 @@ version = "4.0.11" description = "Git Object Database" optional = false python-versions = ">=3.7" +groups = ["dev"] files = [ {file = "gitdb-4.0.11-py3-none-any.whl", hash = "sha256:81a3407ddd2ee8df444cbacea00e2d038e40150acfa3001696fe0dcf1d3adfa4"}, {file = "gitdb-4.0.11.tar.gz", hash = "sha256:bf5421126136d6d0af55bc1e7c1af1c397a34f5b7bd79e776cd3e89785c2b04b"}, @@ -1481,6 +1557,7 @@ version = "3.1.41" description = "GitPython is a Python library used to interact with Git repositories" optional = false python-versions = ">=3.7" +groups = ["dev"] files = [ {file = "GitPython-3.1.41-py3-none-any.whl", hash = "sha256:c36b6634d069b3f719610175020a9aed919421c87552185b085e04fbbdb10b7c"}, {file = "GitPython-3.1.41.tar.gz", hash = "sha256:ed66e624884f76df22c8e16066d567aaa5a37d5b5fa19db2c6df6f7156db9048"}, @@ -1490,7 +1567,7 @@ files = [ gitdb = ">=4.0.1,<5" [package.extras] -test = ["black", "coverage[toml]", "ddt (>=1.1.1,!=1.4.3)", "mock", "mypy", "pre-commit", "pytest (>=7.3.1)", "pytest-cov", "pytest-instafail", "pytest-mock", "pytest-sugar", "sumtypes"] +test = ["black", "coverage[toml]", "ddt (>=1.1.1,!=1.4.3)", "mock ; python_version < \"3.8\"", "mypy", "pre-commit", "pytest (>=7.3.1)", "pytest-cov", "pytest-instafail", "pytest-mock", "pytest-sugar", "sumtypes"] [[package]] name = "glfw" @@ -1498,6 +1575,8 @@ version = "2.6.5" description = "A ctypes-based wrapper for GLFW3." 
optional = true python-versions = "*" +groups = ["main"] +markers = "extra == \"robotics\" or extra == \"mujoco\"" files = [ {file = "glfw-2.6.5-py2.py27.py3.py30.py31.py32.py33.py34.py35.py36.py37.py38-none-macosx_10_6_intel.whl", hash = "sha256:57d00367f8dc31b898a47ab22849bab9f87dff4b4c7a56d16d9a7158cda96c19"}, {file = "glfw-2.6.5-py2.py27.py3.py30.py31.py32.py33.py34.py35.py36.py37.py38-none-macosx_11_0_arm64.whl", hash = "sha256:a1a132e7d6f78ae7f32957b56de2fd996d2a416f9520adb40345cc9cf744d277"}, @@ -1519,6 +1598,7 @@ version = "2.23.4" description = "Google Authentication Library" optional = false python-versions = ">=3.7" +groups = ["main"] files = [ {file = "google-auth-2.23.4.tar.gz", hash = "sha256:79905d6b1652187def79d491d6e23d0cbb3a21d3c7ba0dbaa9c8a01906b13ff3"}, {file = "google_auth-2.23.4-py2.py3-none-any.whl", hash = "sha256:d4bbc92fe4b8bfd2f3e8d88e5ba7085935da208ee38a134fc280e7ce682a05f2"}, @@ -1542,6 +1622,7 @@ version = "1.1.0" description = "Google Authentication Library" optional = false python-versions = ">=3.6" +groups = ["main"] files = [ {file = "google-auth-oauthlib-1.1.0.tar.gz", hash = "sha256:83ea8c3b0881e453790baff4448e8a6112ac8778d1de9da0b68010b843937afb"}, {file = "google_auth_oauthlib-1.1.0-py2.py3-none-any.whl", hash = "sha256:089c6e587d36f4803ac7e0720c045c6a8b1fd1790088b8424975b90d0ee61c12"}, @@ -1560,6 +1641,8 @@ version = "3.0.1" description = "Lightweight in-process concurrent programming" optional = false python-versions = ">=3.7" +groups = ["dev"] +markers = "platform_machine == \"aarch64\" or platform_machine == \"ppc64le\" or platform_machine == \"x86_64\" or platform_machine == \"amd64\" or platform_machine == \"AMD64\" or platform_machine == \"win32\" or platform_machine == \"WIN32\"" files = [ {file = "greenlet-3.0.1-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:f89e21afe925fcfa655965ca8ea10f24773a1791400989ff32f467badfe4a064"}, {file = "greenlet-3.0.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", 
hash = "sha256:28e89e232c7593d33cac35425b58950789962011cc274aa43ef8865f2e11f46d"}, @@ -1630,6 +1713,7 @@ version = "1.59.3" description = "HTTP/2-based RPC framework" optional = false python-versions = ">=3.7" +groups = ["main"] files = [ {file = "grpcio-1.59.3-cp310-cp310-linux_armv7l.whl", hash = "sha256:aca028a6c7806e5b61e5f9f4232432c52856f7fcb98e330b20b6bc95d657bdcc"}, {file = "grpcio-1.59.3-cp310-cp310-macosx_12_0_universal2.whl", hash = "sha256:19ad26a7967f7999c8960d2b9fe382dae74c55b0c508c613a6c2ba21cddf2354"}, @@ -1696,6 +1780,8 @@ version = "0.26.2" description = "Gym: A universal API for reinforcement learning environments" optional = true python-versions = ">=3.6" +groups = ["main"] +markers = "sys_platform != \"darwin\" and extra == \"envpool\"" files = [ {file = "gym-0.26.2.tar.gz", hash = "sha256:e0d882f4b54f0c65f203104c24ab8a38b039f1289986803c7d02cdbe214fbcc4"}, ] @@ -1723,6 +1809,8 @@ version = "0.0.8" description = "Notices for gym" optional = true python-versions = "*" +groups = ["main"] +markers = "sys_platform != \"darwin\" and extra == \"envpool\"" files = [ {file = "gym-notices-0.0.8.tar.gz", hash = "sha256:ad25e200487cafa369728625fe064e88ada1346618526102659b4640f2b4b911"}, {file = "gym_notices-0.0.8-py3-none-any.whl", hash = "sha256:e5f82e00823a166747b4c2a07de63b6560b1acb880638547e0cabf825a01e463"}, @@ -1734,6 +1822,7 @@ version = "0.28.1" description = "A standard API for reinforcement learning and a diverse set of reference environments (formerly Gym)." optional = false python-versions = ">=3.7" +groups = ["main"] files = [ {file = "gymnasium-0.28.1-py3-none-any.whl", hash = "sha256:7bc9a5bce1022f997d1dbc152fc91d1ac977bad9cc7794cdc25437010867cabf"}, {file = "gymnasium-0.28.1.tar.gz", hash = "sha256:4c2c745808792c8f45c6e88ad0a5504774394e0c126f6e3db555e720d3da6f24"}, @@ -1765,6 +1854,8 @@ version = "1.2.3" description = "Robotics environments for the Gymnasium repo." 
optional = true python-versions = ">=3.8" +groups = ["main"] +markers = "extra == \"robotics\"" files = [ {file = "gymnasium-robotics-1.2.3.tar.gz", hash = "sha256:b01eb9df74c0041e559e1251442ba1a59174bfc71a1c58519724d76df803c0b6"}, {file = "gymnasium_robotics-1.2.3-py3-none-any.whl", hash = "sha256:9c3cd7bcc7ac7a0efca03d5685a01686661c7fa678e34adfe4e15044580e7180"}, @@ -1788,6 +1879,7 @@ version = "0.14.0" description = "A pure-Python, bring-your-own-I/O implementation of HTTP/1.1" optional = false python-versions = ">=3.7" +groups = ["dev"] files = [ {file = "h11-0.14.0-py3-none-any.whl", hash = "sha256:e3fe4ac4b851c468cc8363d500db52c2ead036020723024a109d37346efaa761"}, {file = "h11-0.14.0.tar.gz", hash = "sha256:8f19fbbe99e72420ff35c00b27a34cb9937e902a8b810e2c88300c6f0a3b699d"}, @@ -1799,6 +1891,7 @@ version = "3.10.0" description = "Read and write HDF5 files from Python" optional = false python-versions = ">=3.8" +groups = ["main"] files = [ {file = "h5py-3.10.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:b963fb772964fc1d1563c57e4e2e874022ce11f75ddc6df1a626f42bd49ab99f"}, {file = "h5py-3.10.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:012ab448590e3c4f5a8dd0f3533255bc57f80629bf7c5054cf4c87b30085063c"}, @@ -1836,6 +1929,7 @@ version = "1.1" description = "HTML parser based on the WHATWG HTML specification" optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*" +groups = ["dev"] files = [ {file = "html5lib-1.1-py2.py3-none-any.whl", hash = "sha256:0d78f8fde1c230e99fe37986a60526d7049ed4bf8a9fadbad5f00e22e58e041d"}, {file = "html5lib-1.1.tar.gz", hash = "sha256:b2e5b40261e20f354d198eae92afc10d750afb487ed5e50f9c4eaf07c184146f"}, @@ -1846,10 +1940,10 @@ six = ">=1.9" webencodings = "*" [package.extras] -all = ["chardet (>=2.2)", "genshi", "lxml"] +all = ["chardet (>=2.2)", "genshi", "lxml ; platform_python_implementation == \"CPython\""] chardet = ["chardet (>=2.2)"] genshi = ["genshi"] -lxml = ["lxml"] +lxml = ["lxml ; 
platform_python_implementation == \"CPython\""] [[package]] name = "httpcore" @@ -1857,6 +1951,7 @@ version = "1.0.5" description = "A minimal low-level HTTP client." optional = false python-versions = ">=3.8" +groups = ["dev"] files = [ {file = "httpcore-1.0.5-py3-none-any.whl", hash = "sha256:421f18bac248b25d310f3cacd198d55b8e6125c107797b609ff9b7a6ba7991b5"}, {file = "httpcore-1.0.5.tar.gz", hash = "sha256:34a38e2f9291467ee3b44e89dd52615370e152954ba21721378a87b2960f7a61"}, @@ -1878,6 +1973,7 @@ version = "0.27.2" description = "The next generation HTTP client." optional = false python-versions = ">=3.8" +groups = ["dev"] files = [ {file = "httpx-0.27.2-py3-none-any.whl", hash = "sha256:7bb2708e112d8fdd7829cd4243970f0c223274051cb35ee80c03301ee29a3df0"}, {file = "httpx-0.27.2.tar.gz", hash = "sha256:f7c2be1d2f3c3c3160d441802406b206c2b76f5947b11115e6df10c6c65e66c2"}, @@ -1891,7 +1987,7 @@ idna = "*" sniffio = "*" [package.extras] -brotli = ["brotli", "brotlicffi"] +brotli = ["brotli ; platform_python_implementation == \"CPython\"", "brotlicffi ; platform_python_implementation != \"CPython\""] cli = ["click (==8.*)", "pygments (==2.*)", "rich (>=10,<14)"] http2 = ["h2 (>=3,<5)"] socks = ["socksio (==1.*)"] @@ -1903,6 +1999,7 @@ version = "2.5.32" description = "File identification library for Python" optional = false python-versions = ">=3.8" +groups = ["dev"] files = [ {file = "identify-2.5.32-py2.py3-none-any.whl", hash = "sha256:0b7656ef6cba81664b783352c73f8c24b39cf82f926f78f4550eda928e5e0545"}, {file = "identify-2.5.32.tar.gz", hash = "sha256:5d9979348ec1a21c768ae07e0a652924538e8bce67313a73cb0f681cf08ba407"}, @@ -1917,6 +2014,7 @@ version = "3.4" description = "Internationalized Domain Names in Applications (IDNA)" optional = false python-versions = ">=3.5" +groups = ["main", "dev"] files = [ {file = "idna-3.4-py3-none-any.whl", hash = "sha256:90b77e79eaa3eba6de819a0c442c0b4ceefc341a7a2ab77d7562bf49f425c5c2"}, {file = "idna-3.4.tar.gz", hash = 
"sha256:814f528e8dead7d329833b91c5faa87d60bf71824cd12a7530b5526063d02cb4"}, @@ -1928,6 +2026,8 @@ version = "2.33.1" description = "Library for reading and writing a wide range of image, video, scientific, and volumetric data formats." optional = true python-versions = ">=3.8" +groups = ["main"] +markers = "extra == \"robotics\" or extra == \"mujoco\"" files = [ {file = "imageio-2.33.1-py3-none-any.whl", hash = "sha256:c5094c48ccf6b2e6da8b4061cd95e1209380afafcbeae4a4e280938cce227e1d"}, {file = "imageio-2.33.1.tar.gz", hash = "sha256:78722d40b137bd98f5ec7312119f8aea9ad2049f76f434748eb306b6937cc1ce"}, @@ -1960,6 +2060,7 @@ version = "1.4.1" description = "Getting image size from png/jpeg/jpeg2000/gif file" optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" +groups = ["dev"] files = [ {file = "imagesize-1.4.1-py2.py3-none-any.whl", hash = "sha256:0d8d18d08f840c19d0ee7ca1fd82490fdc3729b7ac93f49870406ddde8ef8d8b"}, {file = "imagesize-1.4.1.tar.gz", hash = "sha256:69150444affb9cb0d5cc5a92b3676f0b2fb7cd9ae39e947a5e11a36b4497cd4a"}, @@ -1971,6 +2072,7 @@ version = "6.8.0" description = "Read metadata from Python packages" optional = false python-versions = ">=3.8" +groups = ["dev"] files = [ {file = "importlib_metadata-6.8.0-py3-none-any.whl", hash = "sha256:3ebb78df84a805d7698245025b975d9d67053cd94c79245ba4b3eb694abe68bb"}, {file = "importlib_metadata-6.8.0.tar.gz", hash = "sha256:dbace7892d8c0c4ac1ad096662232f831d4e64f4c4545bd53016a3e9d4654743"}, @@ -1982,7 +2084,7 @@ zipp = ">=0.5" [package.extras] docs = ["furo", "jaraco.packaging (>=9)", "jaraco.tidelift (>=1.4)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-lint"] perf = ["ipython"] -testing = ["flufl.flake8", "importlib-resources (>=1.3)", "packaging", "pyfakefs", "pytest (>=6)", "pytest-black (>=0.3.7)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-mypy (>=0.9.1)", "pytest-perf (>=0.9.2)", "pytest-ruff"] +testing = ["flufl.flake8", "importlib-resources 
(>=1.3) ; python_version < \"3.9\"", "packaging", "pyfakefs", "pytest (>=6)", "pytest-black (>=0.3.7) ; platform_python_implementation != \"PyPy\"", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-mypy (>=0.9.1) ; platform_python_implementation != \"PyPy\"", "pytest-perf (>=0.9.2)", "pytest-ruff"] [[package]] name = "importlib-resources" @@ -1990,6 +2092,8 @@ version = "6.1.1" description = "Read resources from Python packages" optional = true python-versions = ">=3.8" +groups = ["main"] +markers = "extra == \"atari\"" files = [ {file = "importlib_resources-6.1.1-py3-none-any.whl", hash = "sha256:e8bf90d8213b486f428c9c39714b920041cb02c184686a3dee24905aaa8105d6"}, {file = "importlib_resources-6.1.1.tar.gz", hash = "sha256:3893a00122eafde6894c59914446a512f728a0c1a45f9bb9b63721b6bacf0b4a"}, @@ -1997,7 +2101,7 @@ files = [ [package.extras] docs = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "rst.linker (>=1.9)", "sphinx (<7.2.5)", "sphinx (>=3.5)", "sphinx-lint"] -testing = ["pytest (>=6)", "pytest-black (>=0.3.7)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-mypy (>=0.9.1)", "pytest-ruff", "zipp (>=3.17)"] +testing = ["pytest (>=6)", "pytest-black (>=0.3.7) ; platform_python_implementation != \"PyPy\"", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-mypy (>=0.9.1) ; platform_python_implementation != \"PyPy\"", "pytest-ruff", "zipp (>=3.17)"] [[package]] name = "iniconfig" @@ -2005,6 +2109,7 @@ version = "2.0.0" description = "brain-dead simple config-ini parsing" optional = false python-versions = ">=3.7" +groups = ["dev"] files = [ {file = "iniconfig-2.0.0-py3-none-any.whl", hash = "sha256:b6a85871a79d2e3b22d2d1b94ac2824226a63c6b741c88f7ae975f18b6778374"}, {file = "iniconfig-2.0.0.tar.gz", hash = "sha256:2d91e135bf72d31a410b17c16da610a82cb55f6b0477d1a902134b24a455b8b3"}, @@ -2016,6 +2121,7 @@ version = "6.26.0" description = "IPython Kernel for Jupyter" 
optional = false python-versions = ">=3.8" +groups = ["dev"] files = [ {file = "ipykernel-6.26.0-py3-none-any.whl", hash = "sha256:3ba3dc97424b87b31bb46586b5167b3161b32d7820b9201a9e698c71e271602c"}, {file = "ipykernel-6.26.0.tar.gz", hash = "sha256:553856658eb8430bbe9653ea041a41bff63e9606fc4628873fc92a6cf3abd404"}, @@ -2049,6 +2155,7 @@ version = "8.17.2" description = "IPython: Productive Interactive Computing" optional = false python-versions = ">=3.9" +groups = ["dev"] files = [ {file = "ipython-8.17.2-py3-none-any.whl", hash = "sha256:1e4d1d666a023e3c93585ba0d8e962867f7a111af322efff6b9c58062b3e5444"}, {file = "ipython-8.17.2.tar.gz", hash = "sha256:126bb57e1895594bb0d91ea3090bbd39384f6fe87c3d57fd558d0670f50339bb"}, @@ -2085,6 +2192,7 @@ version = "8.1.1" description = "Jupyter interactive widgets" optional = false python-versions = ">=3.7" +groups = ["dev"] files = [ {file = "ipywidgets-8.1.1-py3-none-any.whl", hash = "sha256:2b88d728656aea3bbfd05d32c747cfd0078f9d7e159cf982433b58ad717eed7f"}, {file = "ipywidgets-8.1.1.tar.gz", hash = "sha256:40211efb556adec6fa450ccc2a77d59ca44a060f4f9f136833df59c9f538e6e8"}, @@ -2106,6 +2214,7 @@ version = "20.11.0" description = "Operations with ISO 8601 durations" optional = false python-versions = ">=3.7" +groups = ["dev"] files = [ {file = "isoduration-20.11.0-py3-none-any.whl", hash = "sha256:b2904c2a4228c3d44f409c8ae8e2370eb21a26f7ac2ec5446df141dde3452042"}, {file = "isoduration-20.11.0.tar.gz", hash = "sha256:ac2f9015137935279eac671f94f89eb00584f940f5dc49462a0c4ee692ba1bd9"}, @@ -2120,6 +2229,7 @@ version = "1.0.0" description = "Common backend for Jax or Numpy." 
optional = false python-versions = ">=3.7" +groups = ["main"] files = [ {file = "jax-jumpy-1.0.0.tar.gz", hash = "sha256:195fb955cc4c2b7f0b1453e3cb1fb1c414a51a407ffac7a51e69a73cb30d59ad"}, {file = "jax_jumpy-1.0.0-py3-none-any.whl", hash = "sha256:ab7e01454bba462de3c4d098e3e585c302a8f06bc36d9182ab4e7e4aa7067c5e"}, @@ -2138,6 +2248,7 @@ version = "0.19.1" description = "An autocompletion tool for Python that can be used for text editors." optional = false python-versions = ">=3.6" +groups = ["dev"] files = [ {file = "jedi-0.19.1-py2.py3-none-any.whl", hash = "sha256:e983c654fe5c02867aef4cdfce5a2fbb4a50adc0af145f70504238f18ef5e7e0"}, {file = "jedi-0.19.1.tar.gz", hash = "sha256:cf0496f3651bc65d7174ac1b7d043eff454892c708a87d1b683e57b569927ffd"}, @@ -2157,6 +2268,7 @@ version = "3.1.4" description = "A very fast and expressive template engine." optional = false python-versions = ">=3.7" +groups = ["main", "dev"] files = [ {file = "jinja2-3.1.4-py3-none-any.whl", hash = "sha256:bc5dd2abb727a5319567b7a813e6a2e7318c39f4f487cfe6c89c6f9c7d25197d"}, {file = "jinja2-3.1.4.tar.gz", hash = "sha256:4a3aee7acbbe7303aede8e9648d13b8bf88a429282aa6122a993f0ac800cb369"}, @@ -2174,6 +2286,8 @@ version = "1.4.0" description = "Lightweight pipelining with Python functions" optional = true python-versions = ">=3.8" +groups = ["main"] +markers = "extra == \"eval\"" files = [ {file = "joblib-1.4.0-py3-none-any.whl", hash = "sha256:42942470d4062537be4d54c83511186da1fc14ba354961a2114da91efa9a4ed7"}, {file = "joblib-1.4.0.tar.gz", hash = "sha256:1eb0dc091919cd384490de890cb5dfd538410a6d4b3b54eef09fb8c50b409b1c"}, @@ -2185,6 +2299,7 @@ version = "0.9.14" description = "A Python implementation of the JSON5 data format." 
optional = false python-versions = "*" +groups = ["dev"] files = [ {file = "json5-0.9.14-py2.py3-none-any.whl", hash = "sha256:740c7f1b9e584a468dbb2939d8d458db3427f2c93ae2139d05f47e453eae964f"}, {file = "json5-0.9.14.tar.gz", hash = "sha256:9ed66c3a6ca3510a976a9ef9b8c0787de24802724ab1860bc0153c7fdd589b02"}, @@ -2199,6 +2314,8 @@ version = "4.27.0" description = "Implement minimal boilerplate CLIs derived from type hints and parse from command line, config files and environment variables." optional = true python-versions = ">=3.7" +groups = ["main"] +markers = "extra == \"argparse\" or extra == \"eval\"" files = [ {file = "jsonargparse-4.27.0-py3-none-any.whl", hash = "sha256:a6378bc8b7bbe38b708f090b10ea8431216e71f8b2eea1f9a4f095ae4abd0f2e"}, {file = "jsonargparse-4.27.0.tar.gz", hash = "sha256:6ac791cd7913cff34ad2dbd3ed0431f9e327af0be926332ac060bd5b13d353f2"}, @@ -2214,7 +2331,7 @@ coverage = ["jsonargparse[test-no-urls]", "pytest-cov (>=4.0.0)"] dev = ["build (>=0.10.0)", "jsonargparse[coverage]", "jsonargparse[doc]", "jsonargparse[mypy]", "jsonargparse[test]", "pre-commit (>=2.19.0)", "tox (>=3.25.0)"] doc = ["Sphinx (>=1.7.9)", "autodocsumm (>=0.1.10)", "sphinx-autodoc-typehints (>=1.19.5)", "sphinx-rtd-theme (>=1.2.2)"] fsspec = ["fsspec (>=0.8.4)"] -jsonnet = ["jsonnet (>=0.13.0)", "jsonnet-binary (>=0.17.0)"] +jsonnet = ["jsonnet (>=0.13.0) ; os_name == \"posix\"", "jsonnet-binary (>=0.17.0) ; os_name != \"posix\""] jsonschema = ["jsonschema (>=3.2.0)"] maintainer = ["bump2version (>=0.5.11)", "twine (>=4.0.2)"] omegaconf = ["omegaconf (>=2.1.1)"] @@ -2223,7 +2340,7 @@ ruyaml = ["ruyaml (>=0.20.0)"] signatures = ["docstring-parser (>=0.15)", "jsonargparse[typing-extensions]", "typeshed-client (>=2.1.0)"] test = ["attrs (>=22.2.0)", "jsonargparse[test-no-urls]", "pydantic (>=2.3.0)", "responses (>=0.12.0)", "types-PyYAML (>=6.0.11)", "types-requests (>=2.28.9)"] test-no-urls = ["pytest (>=6.2.5)", "pytest-subtests (>=0.8.0)"] -typing-extensions = 
["typing-extensions (>=3.10.0.0)"] +typing-extensions = ["typing-extensions (>=3.10.0.0) ; python_version < \"3.10\""] urls = ["requests (>=2.18.4)"] [[package]] @@ -2232,6 +2349,7 @@ version = "2.4" description = "Identify specific nodes in a JSON document (RFC 6901)" optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*, !=3.5.*, !=3.6.*" +groups = ["dev"] files = [ {file = "jsonpointer-2.4-py2.py3-none-any.whl", hash = "sha256:15d51bba20eea3165644553647711d150376234112651b4f1811022aecad7d7a"}, {file = "jsonpointer-2.4.tar.gz", hash = "sha256:585cee82b70211fa9e6043b7bb89db6e1aa49524340dde8ad6b63206ea689d88"}, @@ -2243,6 +2361,7 @@ version = "4.20.0" description = "An implementation of JSON Schema validation for Python" optional = false python-versions = ">=3.8" +groups = ["dev"] files = [ {file = "jsonschema-4.20.0-py3-none-any.whl", hash = "sha256:ed6231f0429ecf966f5bc8dfef245998220549cbbcf140f913b7464c52c3b6b3"}, {file = "jsonschema-4.20.0.tar.gz", hash = "sha256:4f614fd46d8d61258610998997743ec5492a648b33cf478c1ddc23ed4598a5fa"}, @@ -2272,6 +2391,7 @@ version = "2023.11.1" description = "The JSON Schema meta-schemas and vocabularies, exposed as a Registry" optional = false python-versions = ">=3.8" +groups = ["dev"] files = [ {file = "jsonschema_specifications-2023.11.1-py3-none-any.whl", hash = "sha256:f596778ab612b3fd29f72ea0d990393d0540a5aab18bf0407a46632eab540779"}, {file = "jsonschema_specifications-2023.11.1.tar.gz", hash = "sha256:c9b234904ffe02f079bf91b14d79987faa685fd4b39c377a0996954c0090b9ca"}, @@ -2286,6 +2406,7 @@ version = "1.0.0" description = "Jupyter metapackage. Install all the Jupyter components in one go." 
optional = false python-versions = "*" +groups = ["dev"] files = [ {file = "jupyter-1.0.0-py2.py3-none-any.whl", hash = "sha256:5b290f93b98ffbc21c0c7e749f054b3267782166d72fa5e3ed1ed4eaf34a2b78"}, {file = "jupyter-1.0.0.tar.gz", hash = "sha256:d9dc4b3318f310e34c82951ea5d6683f67bed7def4b259fafbfe4f1beb1d8e5f"}, @@ -2306,6 +2427,7 @@ version = "1.0.0" description = "Build a book with Jupyter Notebooks and Sphinx." optional = false python-versions = ">=3.9" +groups = ["dev"] files = [ {file = "jupyter_book-1.0.0-py3-none-any.whl", hash = "sha256:18238f1e7e1d425731e60ab509a7da878dd6db88b7d77bcfab4690361b72e1be"}, {file = "jupyter_book-1.0.0.tar.gz", hash = "sha256:539c5d0493546200d9de27bd4b5f77eaea03115f8937f825d4ff82b3801a987e"}, @@ -2343,6 +2465,7 @@ version = "0.6.1" description = "A defined interface for working with a cache of jupyter notebooks." optional = false python-versions = "~=3.8" +groups = ["dev"] files = [ {file = "jupyter-cache-0.6.1.tar.gz", hash = "sha256:26f83901143edf4af2f3ff5a91e2d2ad298e46e2cee03c8071d37a23a63ccbfc"}, {file = "jupyter_cache-0.6.1-py3-none-any.whl", hash = "sha256:2fce7d4975805c77f75bdfc1bc2e82bc538b8e5b1af27f2f5e06d55b9f996a82"}, @@ -2370,6 +2493,7 @@ version = "8.6.0" description = "Jupyter protocol implementation and client libraries" optional = false python-versions = ">=3.8" +groups = ["dev"] files = [ {file = "jupyter_client-8.6.0-py3-none-any.whl", hash = "sha256:909c474dbe62582ae62b758bca86d6518c85234bdee2d908c778db6d72f39d99"}, {file = "jupyter_client-8.6.0.tar.gz", hash = "sha256:0642244bb83b4764ae60d07e010e15f0e2d275ec4e918a8f7b80fbbef3ca60c7"}, @@ -2384,7 +2508,7 @@ traitlets = ">=5.3" [package.extras] docs = ["ipykernel", "myst-parser", "pydata-sphinx-theme", "sphinx (>=4)", "sphinx-autodoc-typehints", "sphinxcontrib-github-alt", "sphinxcontrib-spelling"] -test = ["coverage", "ipykernel (>=6.14)", "mypy", "paramiko", "pre-commit", "pytest", "pytest-cov", "pytest-jupyter[client] (>=0.4.1)", "pytest-timeout"] +test = 
["coverage", "ipykernel (>=6.14)", "mypy", "paramiko ; sys_platform == \"win32\"", "pre-commit", "pytest", "pytest-cov", "pytest-jupyter[client] (>=0.4.1)", "pytest-timeout"] [[package]] name = "jupyter-console" @@ -2392,6 +2516,7 @@ version = "6.6.3" description = "Jupyter terminal console" optional = false python-versions = ">=3.7" +groups = ["dev"] files = [ {file = "jupyter_console-6.6.3-py3-none-any.whl", hash = "sha256:309d33409fcc92ffdad25f0bcdf9a4a9daa61b6f341177570fdac03de5352485"}, {file = "jupyter_console-6.6.3.tar.gz", hash = "sha256:566a4bf31c87adbfadf22cdf846e3069b59a71ed5da71d6ba4d8aaad14a53539"}, @@ -2416,6 +2541,7 @@ version = "5.5.0" description = "Jupyter core package. A base package on which Jupyter projects rely." optional = false python-versions = ">=3.8" +groups = ["dev"] files = [ {file = "jupyter_core-5.5.0-py3-none-any.whl", hash = "sha256:e11e02cd8ae0a9de5c6c44abf5727df9f2581055afe00b22183f621ba3585805"}, {file = "jupyter_core-5.5.0.tar.gz", hash = "sha256:880b86053bf298a8724994f95e99b99130659022a4f7f45f563084b6223861d3"}, @@ -2436,6 +2562,7 @@ version = "0.9.0" description = "Jupyter Event System library" optional = false python-versions = ">=3.8" +groups = ["dev"] files = [ {file = "jupyter_events-0.9.0-py3-none-any.whl", hash = "sha256:d853b3c10273ff9bc8bb8b30076d65e2c9685579db736873de6c2232dde148bf"}, {file = "jupyter_events-0.9.0.tar.gz", hash = "sha256:81ad2e4bc710881ec274d31c6c50669d71bbaa5dd9d01e600b56faa85700d399"}, @@ -2461,6 +2588,7 @@ version = "2.2.2" description = "Multi-Language Server WebSocket proxy for Jupyter Notebook/Lab server" optional = false python-versions = ">=3.8" +groups = ["dev"] files = [ {file = "jupyter-lsp-2.2.2.tar.gz", hash = "sha256:256d24620542ae4bba04a50fc1f6ffe208093a07d8e697fea0a8d1b8ca1b7e5b"}, {file = "jupyter_lsp-2.2.2-py3-none-any.whl", hash = "sha256:3b95229e4168355a8c91928057c1621ac3510ba98b2a925e82ebd77f078b1aa5"}, @@ -2475,6 +2603,7 @@ version = "2.11.2" description = "The backend—i.e. 
core services, APIs, and REST endpoints—to Jupyter web applications." optional = false python-versions = ">=3.8" +groups = ["dev"] files = [ {file = "jupyter_server-2.11.2-py3-none-any.whl", hash = "sha256:0c548151b54bcb516ca466ec628f7f021545be137d01b5467877e87f6fff4374"}, {file = "jupyter_server-2.11.2.tar.gz", hash = "sha256:0c99f9367b0f24141e527544522430176613f9249849be80504c6d2b955004bb"}, @@ -2511,6 +2640,7 @@ version = "0.4.4" description = "A Jupyter Server Extension Providing Terminals." optional = false python-versions = ">=3.8" +groups = ["dev"] files = [ {file = "jupyter_server_terminals-0.4.4-py3-none-any.whl", hash = "sha256:75779164661cec02a8758a5311e18bb8eb70c4e86c6b699403100f1585a12a36"}, {file = "jupyter_server_terminals-0.4.4.tar.gz", hash = "sha256:57ab779797c25a7ba68e97bcfb5d7740f2b5e8a83b5e8102b10438041a7eac5d"}, @@ -2530,6 +2660,7 @@ version = "4.2.5" description = "JupyterLab computational environment" optional = false python-versions = ">=3.8" +groups = ["dev"] files = [ {file = "jupyterlab-4.2.5-py3-none-any.whl", hash = "sha256:73b6e0775d41a9fee7ee756c80f58a6bed4040869ccc21411dc559818874d321"}, {file = "jupyterlab-4.2.5.tar.gz", hash = "sha256:ae7f3a1b8cb88b4f55009ce79fa7c06f99d70cd63601ee4aa91815d054f46f75"}, @@ -2563,6 +2694,7 @@ version = "0.2.2" description = "Pygments theme using JupyterLab CSS variables" optional = false python-versions = ">=3.7" +groups = ["dev"] files = [ {file = "jupyterlab_pygments-0.2.2-py2.py3-none-any.whl", hash = "sha256:2405800db07c9f770863bcf8049a529c3dd4d3e28536638bd7c1c01d2748309f"}, {file = "jupyterlab_pygments-0.2.2.tar.gz", hash = "sha256:7405d7fde60819d905a9fa8ce89e4cd830e318cdad22a0030f7a901da705585d"}, @@ -2574,6 +2706,7 @@ version = "2.27.3" description = "A set of server components for JupyterLab and JupyterLab like applications." 
optional = false python-versions = ">=3.8" +groups = ["dev"] files = [ {file = "jupyterlab_server-2.27.3-py3-none-any.whl", hash = "sha256:e697488f66c3db49df675158a77b3b017520d772c6e1548c7d9bcc5df7944ee4"}, {file = "jupyterlab_server-2.27.3.tar.gz", hash = "sha256:eb36caca59e74471988f0ae25c77945610b887f777255aa21f8065def9e51ed4"}, @@ -2599,6 +2732,7 @@ version = "3.0.9" description = "Jupyter interactive widgets for JupyterLab" optional = false python-versions = ">=3.7" +groups = ["dev"] files = [ {file = "jupyterlab_widgets-3.0.9-py3-none-any.whl", hash = "sha256:3cf5bdf5b897bf3bccf1c11873aa4afd776d7430200f765e0686bd352487b58d"}, {file = "jupyterlab_widgets-3.0.9.tar.gz", hash = "sha256:6005a4e974c7beee84060fdfba341a3218495046de8ae3ec64888e5fe19fdb4c"}, @@ -2610,6 +2744,7 @@ version = "1.4.5" description = "A fast implementation of the Cassowary constraint solver" optional = false python-versions = ">=3.7" +groups = ["main"] files = [ {file = "kiwisolver-1.4.5-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:05703cf211d585109fcd72207a31bb170a0f22144d68298dc5e61b3c946518af"}, {file = "kiwisolver-1.4.5-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:146d14bebb7f1dc4d5fbf74f8a6cb15ac42baadee8912eb84ac0b3b2a3dc6ac3"}, @@ -2723,6 +2858,7 @@ version = "2.0.1" description = "A lexer and codec to work with LaTeX code in Python." optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" +groups = ["dev"] files = [ {file = "latexcodec-2.0.1-py2.py3-none-any.whl", hash = "sha256:c277a193638dc7683c4c30f6684e3db728a06efb0dc9cf346db8bd0aa6c5d271"}, {file = "latexcodec-2.0.1.tar.gz", hash = "sha256:2aa2551c373261cefe2ad3a8953a6d6533e68238d180eb4bb91d7964adb3fe9a"}, @@ -2737,6 +2873,7 @@ version = "2.0.2" description = "Links recognition library with FULL unicode support." 
optional = false python-versions = ">=3.7" +groups = ["dev"] files = [ {file = "linkify-it-py-2.0.2.tar.gz", hash = "sha256:19f3060727842c254c808e99d465c80c49d2c7306788140987a1a7a29b0d6ad2"}, {file = "linkify_it_py-2.0.2-py3-none-any.whl", hash = "sha256:a3a24428f6c96f27370d7fe61d2ac0be09017be5190d68d8658233171f1b6541"}, @@ -2757,6 +2894,7 @@ version = "0.43.0" description = "lightweight wrapper around basic LLVM functionality" optional = false python-versions = ">=3.9" +groups = ["main"] files = [ {file = "llvmlite-0.43.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:a289af9a1687c6cf463478f0fa8e8aa3b6fb813317b0d70bf1ed0759eab6f761"}, {file = "llvmlite-0.43.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:6d4fd101f571a31acb1559ae1af30f30b1dc4b3186669f92ad780e17c81e91bc"}, @@ -2787,6 +2925,7 @@ version = "3.5.1" description = "Python implementation of John Gruber's Markdown." optional = false python-versions = ">=3.8" +groups = ["main"] files = [ {file = "Markdown-3.5.1-py3-none-any.whl", hash = "sha256:5874b47d4ee3f0b14d764324d2c94c03ea66bee56f2d929da9f2508d65e722dc"}, {file = "Markdown-3.5.1.tar.gz", hash = "sha256:b65d7beb248dc22f2e8a31fb706d93798093c308dc1aba295aedeb9d41a813bd"}, @@ -2802,6 +2941,7 @@ version = "3.0.0" description = "Python port of markdown-it. Markdown parsing, done right!" optional = false python-versions = ">=3.8" +groups = ["dev"] files = [ {file = "markdown-it-py-3.0.0.tar.gz", hash = "sha256:e3f60a94fa066dc52ec76661e37c851cb232d92f9886b15cb560aaada2df8feb"}, {file = "markdown_it_py-3.0.0-py3-none-any.whl", hash = "sha256:355216845c60bd96232cd8d8c40e8f9765cc86f46880e43a8fd22dc1a1a8cab1"}, @@ -2826,6 +2966,7 @@ version = "2.1.3" description = "Safely add untrusted strings to HTML/XML markup." 
optional = false python-versions = ">=3.7" +groups = ["main", "dev"] files = [ {file = "MarkupSafe-2.1.3-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:cd0f502fe016460680cd20aaa5a76d241d6f35a1c3350c474bac1273803893fa"}, {file = "MarkupSafe-2.1.3-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:e09031c87a1e51556fdcb46e5bd4f59dfb743061cf93c4d6831bf894f125eb57"}, @@ -2895,6 +3036,7 @@ version = "3.8.4" description = "Python plotting package" optional = false python-versions = ">=3.9" +groups = ["main"] files = [ {file = "matplotlib-3.8.4-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:abc9d838f93583650c35eca41cfcec65b2e7cb50fd486da6f0c49b5e1ed23014"}, {file = "matplotlib-3.8.4-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:8f65c9f002d281a6e904976007b2d46a1ee2bcea3a68a8c12dda24709ddc9106"}, @@ -2943,6 +3085,7 @@ version = "0.1.6" description = "Inline Matplotlib backend for Jupyter" optional = false python-versions = ">=3.5" +groups = ["dev"] files = [ {file = "matplotlib-inline-0.1.6.tar.gz", hash = "sha256:f887e5f10ba98e8d2b150ddcf4702c1e5f8b3a20005eb0f74bfdbd360ee6f304"}, {file = "matplotlib_inline-0.1.6-py3-none-any.whl", hash = "sha256:f1f41aab5328aa5aaea9b16d083b128102f8712542f819fe7e6a420ff581b311"}, @@ -2957,6 +3100,7 @@ version = "0.4.0" description = "Collection of plugins for markdown-it-py" optional = false python-versions = ">=3.8" +groups = ["dev"] files = [ {file = "mdit_py_plugins-0.4.0-py3-none-any.whl", hash = "sha256:b51b3bb70691f57f974e257e367107857a93b36f322a9e6d44ca5bf28ec2def9"}, {file = "mdit_py_plugins-0.4.0.tar.gz", hash = "sha256:d8ab27e9aed6c38aa716819fedfde15ca275715955f8a185a8e1cf90fb1d2c1b"}, @@ -2976,6 +3120,7 @@ version = "0.1.2" description = "Markdown URL utilities" optional = false python-versions = ">=3.7" +groups = ["dev"] files = [ {file = "mdurl-0.1.2-py3-none-any.whl", hash = "sha256:84008a41e51615a49fc9966191ff91509e3c40b939176e643fd50a5c2196b8f8"}, {file = "mdurl-0.1.2.tar.gz", hash = 
"sha256:bb413d29f5eea38f31dd4754dd7377d4465116fb207585f97bf925588687c1ba"}, @@ -2987,6 +3132,7 @@ version = "3.0.2" description = "A sane and fast Markdown parser with useful plugins and renderers" optional = false python-versions = ">=3.7" +groups = ["dev"] files = [ {file = "mistune-3.0.2-py3-none-any.whl", hash = "sha256:71481854c30fdbc938963d3605b72501f5c10a9320ecd412c121c163a1c7d205"}, {file = "mistune-3.0.2.tar.gz", hash = "sha256:fc7f93ded930c92394ef2cb6f04a8aabab4117a91449e72dcc8dfa646a508be8"}, @@ -2998,6 +3144,7 @@ version = "1.3.0" description = "Python library for arbitrary-precision floating-point arithmetic" optional = false python-versions = "*" +groups = ["main"] files = [ {file = "mpmath-1.3.0-py3-none-any.whl", hash = "sha256:a0b2b9fe80bbcd81a6647ff13108738cfb482d481d826cc0e02f5b35e5c88d2c"}, {file = "mpmath-1.3.0.tar.gz", hash = "sha256:7a28eb2a9774d00c7bc92411c19a89209d5da7c4c9a9e227be8330a23a25b91f"}, @@ -3006,7 +3153,7 @@ files = [ [package.extras] develop = ["codecov", "pycodestyle", "pytest (>=4.6)", "pytest-cov", "wheel"] docs = ["sphinx"] -gmpy = ["gmpy2 (>=2.1.0a4)"] +gmpy = ["gmpy2 (>=2.1.0a4) ; platform_python_implementation != \"PyPy\""] tests = ["pytest (>=4.6)"] [[package]] @@ -3015,6 +3162,7 @@ version = "1.0.7" description = "MessagePack serializer" optional = false python-versions = ">=3.8" +groups = ["dev"] files = [ {file = "msgpack-1.0.7-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:04ad6069c86e531682f9e1e71b71c1c3937d6014a7c3e9edd2aa81ad58842862"}, {file = "msgpack-1.0.7-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:cca1b62fe70d761a282496b96a5e51c44c213e410a964bdffe0928e611368329"}, @@ -3080,6 +3228,8 @@ version = "2.3.7" description = "MuJoCo Physics Simulator" optional = true python-versions = ">=3.8" +groups = ["main"] +markers = "extra == \"robotics\" or extra == \"mujoco\"" files = [ {file = "mujoco-2.3.7-cp310-cp310-macosx_10_16_x86_64.whl", hash = 
"sha256:e8714a5ff6a1561b364b7b4648d4c0c8d13e751874cf7401c309b9d23fa9598b"}, {file = "mujoco-2.3.7-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:a934315f858a4e0c4b90a682fde519471cfdd7baa64435179da8cd20d4ae3f99"}, @@ -3120,6 +3270,7 @@ version = "1.7.0" description = "Optional static typing for Python" optional = false python-versions = ">=3.8" +groups = ["dev"] files = [ {file = "mypy-1.7.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:5da84d7bf257fd8f66b4f759a904fd2c5a765f70d8b52dde62b521972a0a2357"}, {file = "mypy-1.7.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:a3637c03f4025f6405737570d6cbfa4f1400eb3c649317634d273687a09ffc2f"}, @@ -3166,6 +3317,7 @@ version = "1.0.0" description = "Type system extensions for programs checked with the mypy type checker." optional = false python-versions = ">=3.5" +groups = ["dev"] files = [ {file = "mypy_extensions-1.0.0-py3-none-any.whl", hash = "sha256:4392f6c0eb8a5668a69e23d168ffa70f0be9ccfd32b5cc2d26a34ae5b844552d"}, {file = "mypy_extensions-1.0.0.tar.gz", hash = "sha256:75dbf8955dc00442a438fc4d0666508a9a97b6bd41aa2f0ffe9d2f2725af0782"}, @@ -3177,6 +3329,7 @@ version = "1.0.0" description = "A Jupyter Notebook Sphinx reader built on top of the MyST markdown parser." 
optional = false python-versions = ">=3.9" +groups = ["dev"] files = [ {file = "myst_nb-1.0.0-py3-none-any.whl", hash = "sha256:ee8febc6dd7d9e32bede0c66a9b962b2e2fdab697428ee9fbfd4919d82380911"}, {file = "myst_nb-1.0.0.tar.gz", hash = "sha256:9077e42a1c6b441ea55078506f83555dda5d6c816ef4930841d71d239e3e0c5e"}, @@ -3205,6 +3358,7 @@ version = "2.0.0" description = "An extended [CommonMark](https://spec.commonmark.org/) compliant parser," optional = false python-versions = ">=3.8" +groups = ["dev"] files = [ {file = "myst_parser-2.0.0-py3-none-any.whl", hash = "sha256:7c36344ae39c8e740dad7fdabf5aa6fc4897a813083c6cc9990044eb93656b14"}, {file = "myst_parser-2.0.0.tar.gz", hash = "sha256:ea929a67a6a0b1683cdbe19b8d2e724cd7643f8aa3e7bb18dd65beac3483bead"}, @@ -3231,6 +3385,7 @@ version = "8.4.0" description = "Simple yet flexible natural sorting in Python." optional = false python-versions = ">=3.7" +groups = ["dev"] files = [ {file = "natsort-8.4.0-py3-none-any.whl", hash = "sha256:4732914fb471f56b5cce04d7bae6f164a592c7712e1c85f9ef585e197299521c"}, {file = "natsort-8.4.0.tar.gz", hash = "sha256:45312c4a0e5507593da193dedd04abb1469253b601ecaf63445ad80f0a1ea581"}, @@ -3246,6 +3401,7 @@ version = "0.7.4" description = "A client library for executing notebooks. Formerly nbconvert's ExecutePreprocessor." 
optional = false python-versions = ">=3.7.0" +groups = ["dev"] files = [ {file = "nbclient-0.7.4-py3-none-any.whl", hash = "sha256:c817c0768c5ff0d60e468e017613e6eae27b6fa31e43f905addd2d24df60c125"}, {file = "nbclient-0.7.4.tar.gz", hash = "sha256:d447f0e5a4cfe79d462459aec1b3dc5c2e9152597262be8ee27f7d4c02566a0d"}, @@ -3268,6 +3424,7 @@ version = "7.11.0" description = "Converting Jupyter Notebooks" optional = false python-versions = ">=3.8" +groups = ["dev"] files = [ {file = "nbconvert-7.11.0-py3-none-any.whl", hash = "sha256:d1d417b7f34a4e38887f8da5bdfd12372adf3b80f995d57556cb0972c68909fe"}, {file = "nbconvert-7.11.0.tar.gz", hash = "sha256:abedc01cf543177ffde0bfc2a69726d5a478f6af10a332fc1bf29fcb4f0cf000"}, @@ -3305,6 +3462,7 @@ version = "5.9.2" description = "The Jupyter Notebook format" optional = false python-versions = ">=3.8" +groups = ["dev"] files = [ {file = "nbformat-5.9.2-py3-none-any.whl", hash = "sha256:1c5172d786a41b82bcfd0c23f9e6b6f072e8fb49c39250219e4acfff1efe89e9"}, {file = "nbformat-5.9.2.tar.gz", hash = "sha256:5f98b5ba1997dff175e77e0c17d5c10a96eaed2cbd1de3533d1fc35d5e111192"}, @@ -3326,6 +3484,7 @@ version = "1.7.1" description = "Run any standard Python code quality tool on a Jupyter Notebook" optional = false python-versions = ">=3.8.0" +groups = ["dev"] files = [ {file = "nbqa-1.7.1-py3-none-any.whl", hash = "sha256:77cdff622bfcf527bf260004449984edfb3624f6e065ac6bb35d64cddcdad483"}, {file = "nbqa-1.7.1.tar.gz", hash = "sha256:44f5b5000d6df438c4f1cba339e3ad80acc405e61f4500ac951fa36a177133f4"}, @@ -3346,6 +3505,7 @@ version = "0.6.1" description = "Strips outputs from Jupyter and IPython notebooks" optional = false python-versions = ">=3.6" +groups = ["dev"] files = [ {file = "nbstripout-0.6.1-py2.py3-none-any.whl", hash = "sha256:5ff6eb0debbcd656c4a64db8e082a24fabcfc753a9e8c9f6d786971e8f29e110"}, {file = "nbstripout-0.6.1.tar.gz", hash = "sha256:9065bcdd1488b386e4f3c081ffc1d48f4513a2f8d8bf4d0d9a28208c5dafe9d3"}, @@ -3360,6 +3520,7 @@ version 
= "1.5.8" description = "Patch asyncio to allow nested event loops" optional = false python-versions = ">=3.5" +groups = ["dev"] files = [ {file = "nest_asyncio-1.5.8-py3-none-any.whl", hash = "sha256:accda7a339a70599cb08f9dd09a67e0c2ef8d8d6f4c07f96ab203f2ae254e48d"}, {file = "nest_asyncio-1.5.8.tar.gz", hash = "sha256:25aa2ca0d2a5b5531956b9e273b45cf664cae2b145101d73b86b199978d48fdb"}, @@ -3371,6 +3532,7 @@ version = "3.2.1" description = "Python package for creating and manipulating graphs and networks" optional = false python-versions = ">=3.9" +groups = ["main", "dev"] files = [ {file = "networkx-3.2.1-py3-none-any.whl", hash = "sha256:f18c69adc97877c42332c170849c96cefa91881c99a7cb3e95b7c659ebdc1ec2"}, {file = "networkx-3.2.1.tar.gz", hash = "sha256:9f1bb5cf3409bf324e0a722c20bdb4c20ee39bf1c30ce8ae499c8502b0b5e0c6"}, @@ -3389,6 +3551,7 @@ version = "1.8.0" description = "Node.js virtual environment builder" optional = false python-versions = ">=2.7,!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,!=3.6.*" +groups = ["dev"] files = [ {file = "nodeenv-1.8.0-py2.py3-none-any.whl", hash = "sha256:df865724bb3c3adc86b3876fa209771517b0cfe596beff01a92700e0e8be4cec"}, {file = "nodeenv-1.8.0.tar.gz", hash = "sha256:d51e0c37e64fbf47d017feac3145cdbb58836d7eee8c6f6d3b6880c5456227d2"}, @@ -3403,6 +3566,7 @@ version = "7.2.2" description = "Jupyter Notebook - A web-based notebook environment for interactive computing" optional = false python-versions = ">=3.8" +groups = ["dev"] files = [ {file = "notebook-7.2.2-py3-none-any.whl", hash = "sha256:c89264081f671bc02eec0ed470a627ed791b9156cad9285226b31611d3e9fe1c"}, {file = "notebook-7.2.2.tar.gz", hash = "sha256:2ef07d4220421623ad3fe88118d687bc0450055570cdd160814a59cf3a1c516e"}, @@ -3418,7 +3582,7 @@ tornado = ">=6.2.0" [package.extras] dev = ["hatch", "pre-commit"] docs = ["myst-parser", "nbsphinx", "pydata-sphinx-theme", "sphinx (>=1.3.6)", "sphinxcontrib-github-alt", "sphinxcontrib-spelling"] -test = ["importlib-resources 
(>=5.0)", "ipykernel", "jupyter-server[test] (>=2.4.0,<3)", "jupyterlab-server[test] (>=2.27.1,<3)", "nbval", "pytest (>=7.0)", "pytest-console-scripts", "pytest-timeout", "pytest-tornasync", "requests"] +test = ["importlib-resources (>=5.0) ; python_version < \"3.10\"", "ipykernel", "jupyter-server[test] (>=2.4.0,<3)", "jupyterlab-server[test] (>=2.27.1,<3)", "nbval", "pytest (>=7.0)", "pytest-console-scripts", "pytest-timeout", "pytest-tornasync", "requests"] [[package]] name = "notebook-shim" @@ -3426,6 +3590,7 @@ version = "0.2.3" description = "A shim layer for notebook traits and config" optional = false python-versions = ">=3.7" +groups = ["dev"] files = [ {file = "notebook_shim-0.2.3-py3-none-any.whl", hash = "sha256:a83496a43341c1674b093bfcebf0fe8e74cbe7eda5fd2bbc56f8e39e1486c0c7"}, {file = "notebook_shim-0.2.3.tar.gz", hash = "sha256:f69388ac283ae008cd506dda10d0288b09a017d822d5e8c7129a152cbd3ce7e9"}, @@ -3443,6 +3608,7 @@ version = "0.60.0" description = "compiling Python code using LLVM" optional = false python-versions = ">=3.9" +groups = ["main"] files = [ {file = "numba-0.60.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:5d761de835cd38fb400d2c26bb103a2726f548dc30368853121d66201672e651"}, {file = "numba-0.60.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:159e618ef213fba758837f9837fb402bbe65326e60ba0633dbe6c7f274d42c1b"}, @@ -3477,6 +3643,7 @@ version = "1.24.4" description = "Fundamental package for array computing in Python" optional = false python-versions = ">=3.8" +groups = ["main", "dev"] files = [ {file = "numpy-1.24.4-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:c0bfb52d2169d58c1cdb8cc1f16989101639b34c7d3ce60ed70b19c63eba0b64"}, {file = "numpy-1.24.4-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:ed094d4f0c177b1b8e7aa9cba7d6ceed51c0e569a5318ac0ca9a090680a6a1b1"}, @@ -3514,6 +3681,8 @@ version = "12.1.3.1" description = "CUBLAS native runtime libraries" optional = false python-versions = ">=3" +groups = ["main"] +markers = 
"platform_system == \"Linux\" and platform_machine == \"x86_64\"" files = [ {file = "nvidia_cublas_cu12-12.1.3.1-py3-none-manylinux1_x86_64.whl", hash = "sha256:ee53ccca76a6fc08fb9701aa95b6ceb242cdaab118c3bb152af4e579af792728"}, {file = "nvidia_cublas_cu12-12.1.3.1-py3-none-win_amd64.whl", hash = "sha256:2b964d60e8cf11b5e1073d179d85fa340c120e99b3067558f3cf98dd69d02906"}, @@ -3525,6 +3694,8 @@ version = "12.1.105" description = "CUDA profiling tools runtime libs." optional = false python-versions = ">=3" +groups = ["main"] +markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\"" files = [ {file = "nvidia_cuda_cupti_cu12-12.1.105-py3-none-manylinux1_x86_64.whl", hash = "sha256:e54fde3983165c624cb79254ae9818a456eb6e87a7fd4d56a2352c24ee542d7e"}, {file = "nvidia_cuda_cupti_cu12-12.1.105-py3-none-win_amd64.whl", hash = "sha256:bea8236d13a0ac7190bd2919c3e8e6ce1e402104276e6f9694479e48bb0eb2a4"}, @@ -3536,6 +3707,8 @@ version = "12.1.105" description = "NVRTC native runtime libraries" optional = false python-versions = ">=3" +groups = ["main"] +markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\"" files = [ {file = "nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-manylinux1_x86_64.whl", hash = "sha256:339b385f50c309763ca65456ec75e17bbefcbbf2893f462cb8b90584cd27a1c2"}, {file = "nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-win_amd64.whl", hash = "sha256:0a98a522d9ff138b96c010a65e145dc1b4850e9ecb75a0172371793752fd46ed"}, @@ -3547,6 +3720,8 @@ version = "12.1.105" description = "CUDA Runtime native Libraries" optional = false python-versions = ">=3" +groups = ["main"] +markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\"" files = [ {file = "nvidia_cuda_runtime_cu12-12.1.105-py3-none-manylinux1_x86_64.whl", hash = "sha256:6e258468ddf5796e25f1dc591a31029fa317d97a0a94ed93468fc86301d61e40"}, {file = "nvidia_cuda_runtime_cu12-12.1.105-py3-none-win_amd64.whl", hash = 
"sha256:dfb46ef84d73fababab44cf03e3b83f80700d27ca300e537f85f636fac474344"}, @@ -3558,6 +3733,8 @@ version = "8.9.2.26" description = "cuDNN runtime libraries" optional = false python-versions = ">=3" +groups = ["main"] +markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\"" files = [ {file = "nvidia_cudnn_cu12-8.9.2.26-py3-none-manylinux1_x86_64.whl", hash = "sha256:5ccb288774fdfb07a7e7025ffec286971c06d8d7b4fb162525334616d7629ff9"}, ] @@ -3571,6 +3748,8 @@ version = "11.0.2.54" description = "CUFFT native runtime libraries" optional = false python-versions = ">=3" +groups = ["main"] +markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\"" files = [ {file = "nvidia_cufft_cu12-11.0.2.54-py3-none-manylinux1_x86_64.whl", hash = "sha256:794e3948a1aa71fd817c3775866943936774d1c14e7628c74f6f7417224cdf56"}, {file = "nvidia_cufft_cu12-11.0.2.54-py3-none-win_amd64.whl", hash = "sha256:d9ac353f78ff89951da4af698f80870b1534ed69993f10a4cf1d96f21357e253"}, @@ -3582,6 +3761,8 @@ version = "10.3.2.106" description = "CURAND native runtime libraries" optional = false python-versions = ">=3" +groups = ["main"] +markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\"" files = [ {file = "nvidia_curand_cu12-10.3.2.106-py3-none-manylinux1_x86_64.whl", hash = "sha256:9d264c5036dde4e64f1de8c50ae753237c12e0b1348738169cd0f8a536c0e1e0"}, {file = "nvidia_curand_cu12-10.3.2.106-py3-none-win_amd64.whl", hash = "sha256:75b6b0c574c0037839121317e17fd01f8a69fd2ef8e25853d826fec30bdba74a"}, @@ -3593,6 +3774,8 @@ version = "11.4.5.107" description = "CUDA solver native runtime libraries" optional = false python-versions = ">=3" +groups = ["main"] +markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\"" files = [ {file = "nvidia_cusolver_cu12-11.4.5.107-py3-none-manylinux1_x86_64.whl", hash = "sha256:8a7ec542f0412294b15072fa7dab71d31334014a69f953004ea7a118206fe0dd"}, {file = 
"nvidia_cusolver_cu12-11.4.5.107-py3-none-win_amd64.whl", hash = "sha256:74e0c3a24c78612192a74fcd90dd117f1cf21dea4822e66d89e8ea80e3cd2da5"}, @@ -3609,6 +3792,8 @@ version = "12.1.0.106" description = "CUSPARSE native runtime libraries" optional = false python-versions = ">=3" +groups = ["main"] +markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\"" files = [ {file = "nvidia_cusparse_cu12-12.1.0.106-py3-none-manylinux1_x86_64.whl", hash = "sha256:f3b50f42cf363f86ab21f720998517a659a48131e8d538dc02f8768237bd884c"}, {file = "nvidia_cusparse_cu12-12.1.0.106-py3-none-win_amd64.whl", hash = "sha256:b798237e81b9719373e8fae8d4f091b70a0cf09d9d85c95a557e11df2d8e9a5a"}, @@ -3623,6 +3808,8 @@ version = "2.18.1" description = "NVIDIA Collective Communication Library (NCCL) Runtime" optional = false python-versions = ">=3" +groups = ["main"] +markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\"" files = [ {file = "nvidia_nccl_cu12-2.18.1-py3-none-manylinux1_x86_64.whl", hash = "sha256:1a6c4acefcbebfa6de320f412bf7866de856e786e0462326ba1bac40de0b5e71"}, ] @@ -3633,6 +3820,8 @@ version = "12.3.101" description = "Nvidia JIT LTO Library" optional = false python-versions = ">=3" +groups = ["main"] +markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\"" files = [ {file = "nvidia_nvjitlink_cu12-12.3.101-py3-none-manylinux1_x86_64.whl", hash = "sha256:64335a8088e2b9d196ae8665430bc6a2b7e6ef2eb877a9c735c804bd4ff6467c"}, {file = "nvidia_nvjitlink_cu12-12.3.101-py3-none-manylinux2014_aarch64.whl", hash = "sha256:211a63e7b30a9d62f1a853e19928fbb1a750e3f17a13a3d1f98ff0ced19478dd"}, @@ -3645,6 +3834,8 @@ version = "12.1.105" description = "NVIDIA Tools Extension" optional = false python-versions = ">=3" +groups = ["main"] +markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\"" files = [ {file = "nvidia_nvtx_cu12-12.1.105-py3-none-manylinux1_x86_64.whl", hash = 
"sha256:dc21cf308ca5691e7c04d962e213f8a4aa9bbfa23d95412f452254c2caeb09e5"}, {file = "nvidia_nvtx_cu12-12.1.105-py3-none-win_amd64.whl", hash = "sha256:65f4d98982b31b60026e0e6de73fbdfc09d08a96f4656dd3665ca616a11e1e82"}, @@ -3656,6 +3847,7 @@ version = "3.2.2" description = "A generic, spec-compliant, thorough implementation of the OAuth request-signing logic" optional = false python-versions = ">=3.6" +groups = ["main"] files = [ {file = "oauthlib-3.2.2-py3-none-any.whl", hash = "sha256:8139f29aac13e25d502680e9e19963e83f16838d48a0d71c287fe40e7067fbca"}, {file = "oauthlib-3.2.2.tar.gz", hash = "sha256:9859c40929662bec5d64f34d01c99e093149682a3f38915dc0655d5a633dd918"}, @@ -3672,6 +3864,8 @@ version = "4.8.1.78" description = "Wrapper package for OpenCV python bindings." optional = true python-versions = ">=3.6" +groups = ["main"] +markers = "extra == \"atari\"" files = [ {file = "opencv-python-4.8.1.78.tar.gz", hash = "sha256:cc7adbbcd1112877a39274106cb2752e04984bc01a031162952e97450d6117f6"}, {file = "opencv_python-4.8.1.78-cp37-abi3-macosx_10_16_x86_64.whl", hash = "sha256:91d5f6f5209dc2635d496f6b8ca6573ecdad051a09e6b5de4c399b8e673c60da"}, @@ -3691,6 +3885,8 @@ version = "0.10.0" description = "Optimized PyTree Utilities." 
optional = true python-versions = ">=3.7" +groups = ["main"] +markers = "sys_platform != \"darwin\" and extra == \"envpool\"" files = [ {file = "optree-0.10.0-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:ac2c0fa383f504f03887a0c0ffcb6a4187c43c8c99c32f52ff14e7eae2c8c69b"}, {file = "optree-0.10.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:8fa16b16203938b7a9caa4603998d0968b408f7f3a1a9f7f84763802daf1cff0"}, @@ -3751,6 +3947,7 @@ version = "4.1.0" description = "An OrderedSet is a custom MutableSet that remembers its order, so that every" optional = false python-versions = ">=3.7" +groups = ["main"] files = [ {file = "ordered-set-4.1.0.tar.gz", hash = "sha256:694a8e44c87657c59292ede72891eb91d34131f6531463aab3009191c77364a8"}, {file = "ordered_set-4.1.0-py3-none-any.whl", hash = "sha256:046e1132c71fcf3330438a539928932caf51ddbc582496833e23de611de14562"}, @@ -3765,6 +3962,7 @@ version = "7.4.0" description = "A decorator to automatically detect mismatch when overriding a method." 
optional = false python-versions = ">=3.6" +groups = ["main", "dev"] files = [ {file = "overrides-7.4.0-py3-none-any.whl", hash = "sha256:3ad24583f86d6d7a49049695efe9933e67ba62f0c7625d53c59fa832ce4b8b7d"}, {file = "overrides-7.4.0.tar.gz", hash = "sha256:9502a3cca51f4fac40b5feca985b6703a5c1f6ad815588a7ca9e285b9dca6757"}, @@ -3776,6 +3974,7 @@ version = "23.2" description = "Core utilities for Python packages" optional = false python-versions = ">=3.7" +groups = ["main", "dev"] files = [ {file = "packaging-23.2-py3-none-any.whl", hash = "sha256:8c491190033a9af7e1d931d0b5dacc2ef47509b34dd0de67ed209b5203fc88c7"}, {file = "packaging-23.2.tar.gz", hash = "sha256:048fb0e9405036518eaaf48a55953c750c11e1a1b68e0dd1a9d62ed0c092cfc5"}, @@ -3787,6 +3986,7 @@ version = "2.1.0" description = "Powerful data structures for data analysis, time series, and statistics" optional = false python-versions = ">=3.9" +groups = ["main"] files = [ {file = "pandas-2.1.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:40dd20439ff94f1b2ed55b393ecee9cb6f3b08104c2c40b0cb7186a2f0046242"}, {file = "pandas-2.1.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:d4f38e4fedeba580285eaac7ede4f686c6701a9e618d8a857b138a126d067f2f"}, @@ -3845,6 +4045,7 @@ version = "1.5.0" description = "Utilities for writing pandoc filters in python" optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" +groups = ["dev"] files = [ {file = "pandocfilters-1.5.0-py2.py3-none-any.whl", hash = "sha256:33aae3f25fd1a026079f5d27bdd52496f0e0803b3469282162bafdcbdf6ef14f"}, {file = "pandocfilters-1.5.0.tar.gz", hash = "sha256:0b679503337d233b4339a817bfc8c50064e2eff681314376a47cb582305a7a38"}, @@ -3856,6 +4057,7 @@ version = "0.8.3" description = "A Python Parser" optional = false python-versions = ">=3.6" +groups = ["dev"] files = [ {file = "parso-0.8.3-py2.py3-none-any.whl", hash = "sha256:c001d4636cd3aecdaf33cbb40aebb59b094be2a74c556778ef5576c175e19e75"}, {file = "parso-0.8.3.tar.gz", hash = 
"sha256:8c07be290bb59f03588915921e29e8a50002acaf2cdc5fa0e0114f91709fafa0"}, @@ -3871,6 +4073,7 @@ version = "0.2.1" description = "Bring colors to your terminal." optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" +groups = ["dev"] files = [ {file = "pastel-0.2.1-py2.py3-none-any.whl", hash = "sha256:4349225fcdf6c2bb34d483e523475de5bb04a5c10ef711263452cb37d7dd4364"}, {file = "pastel-0.2.1.tar.gz", hash = "sha256:e6581ac04e973cac858828c6202c1e1e81fee1dc7de7683f3e1ffe0bfd8a573d"}, @@ -3882,6 +4085,7 @@ version = "0.11.2" description = "Utility library for gitignore style pattern matching of file paths." optional = false python-versions = ">=3.7" +groups = ["dev"] files = [ {file = "pathspec-0.11.2-py3-none-any.whl", hash = "sha256:1d6ed233af05e679efb96b1851550ea95bbb64b7c490b0f5aa52996c11e92a20"}, {file = "pathspec-0.11.2.tar.gz", hash = "sha256:e0d8d0ac2f12da61956eb2306b69f9469b42f4deb0f3cb6ed47b9cce9996ced3"}, @@ -3893,6 +4097,7 @@ version = "0.1.2" description = "File system general utilities" optional = false python-versions = "*" +groups = ["dev"] files = [ {file = "pathtools-0.1.2.tar.gz", hash = "sha256:7c35c5421a39bb82e58018febd90e3b6e5db34c5443aaaf742b3f33d4655f1c0"}, ] @@ -3903,6 +4108,8 @@ version = "0.5.6" description = "A Python package for describing statistical models and for building design matrices." optional = true python-versions = "*" +groups = ["main"] +markers = "extra == \"eval\"" files = [ {file = "patsy-0.5.6-py2.py3-none-any.whl", hash = "sha256:19056886fd8fa71863fa32f0eb090267f21fb74be00f19f5c70b2e9d76c883c6"}, {file = "patsy-0.5.6.tar.gz", hash = "sha256:95c6d47a7222535f84bff7f63d7303f2e297747a598db89cf5c67f0c0c7d2cdb"}, @@ -3921,6 +4128,7 @@ version = "1.24.2" description = "Gymnasium for multi-agent reinforcement learning." 
optional = false python-versions = ">=3.8" +groups = ["main"] files = [ {file = "pettingzoo-1.24.2-py3-none-any.whl", hash = "sha256:00268cf990d243654c2bbbbf8c88322c12b041dc0a879b74747f14ee8aa93dd6"}, {file = "pettingzoo-1.24.2.tar.gz", hash = "sha256:0a5856d47de78ab20feddfdac4940959dc892f6becc92107247b1c3a210c0984"}, @@ -3946,6 +4154,8 @@ version = "4.8.0" description = "Pexpect allows easy control of interactive console applications." optional = false python-versions = "*" +groups = ["dev"] +markers = "sys_platform != \"win32\"" files = [ {file = "pexpect-4.8.0-py2.py3-none-any.whl", hash = "sha256:0b48a55dcb3c05f3329815901ea4fc1537514d6ba867a152b581d69ae3710937"}, {file = "pexpect-4.8.0.tar.gz", hash = "sha256:fc65a43959d153d0114afe13997d439c22823a27cefceb5ff35c2178c6784c0c"}, @@ -3960,6 +4170,7 @@ version = "10.2.0" description = "Python Imaging Library (Fork)" optional = false python-versions = ">=3.8" +groups = ["main"] files = [ {file = "pillow-10.2.0-cp310-cp310-macosx_10_10_x86_64.whl", hash = "sha256:7823bdd049099efa16e4246bdf15e5a13dbb18a51b68fa06d6c1d4d8b99a796e"}, {file = "pillow-10.2.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:83b2021f2ade7d1ed556bc50a399127d7fb245e725aa0113ebd05cfe88aaf588"}, @@ -4036,7 +4247,7 @@ docs = ["furo", "olefile", "sphinx (>=2.4)", "sphinx-copybutton", "sphinx-inline fpx = ["olefile"] mic = ["olefile"] tests = ["check-manifest", "coverage", "defusedxml", "markdown2", "olefile", "packaging", "pyroma", "pytest", "pytest-cov", "pytest-timeout"] -typing = ["typing-extensions"] +typing = ["typing-extensions ; python_version < \"3.10\""] xmp = ["defusedxml"] [[package]] @@ -4045,6 +4256,8 @@ version = "2.6.2" description = "A small Python package for determining appropriate platform-specific dirs, e.g. a \"user data dir\"." 
optional = false python-versions = ">=3.7" +groups = ["main", "dev"] +markers = "sys_platform == \"win32\"" files = [ {file = "platformdirs-2.6.2-py3-none-any.whl", hash = "sha256:83c8f6d04389165de7c9b6f0c682439697887bca0aa2f1c87ef1826be3584490"}, {file = "platformdirs-2.6.2.tar.gz", hash = "sha256:e1fea1fe471b9ff8332e229df3cb7de4f53eeea4998d3b6bfff542115e998bd2"}, @@ -4060,6 +4273,8 @@ version = "3.11.0" description = "A small Python package for determining appropriate platform-specific dirs, e.g. a \"user data dir\"." optional = false python-versions = ">=3.7" +groups = ["main", "dev"] +markers = "sys_platform != \"win32\"" files = [ {file = "platformdirs-3.11.0-py3-none-any.whl", hash = "sha256:e9d171d00af68be50e9202731309c4e658fd8bc76f55c11c7dd760d023bda68e"}, {file = "platformdirs-3.11.0.tar.gz", hash = "sha256:cf8ee52a3afdb965072dcc652433e0c7e3e40cf5ea1477cd4b3b1d2eb75495b3"}, @@ -4075,6 +4290,7 @@ version = "1.3.0" description = "plugin and hook calling mechanisms for python" optional = false python-versions = ">=3.8" +groups = ["dev"] files = [ {file = "pluggy-1.3.0-py3-none-any.whl", hash = "sha256:d89c696a773f8bd377d18e5ecda92b7a3793cbe66c87060a6fb58c7b6e1061f7"}, {file = "pluggy-1.3.0.tar.gz", hash = "sha256:cf61ae8f126ac6f7c451172cf30e3e43d3ca77615509771b3a984a0730651e12"}, @@ -4090,6 +4306,7 @@ version = "0.20.0" description = "A task runner that works well with poetry." optional = false python-versions = ">=3.8" +groups = ["dev"] files = [ {file = "poethepoet-0.20.0-py3-none-any.whl", hash = "sha256:cb37be15f3895ccc65ddf188c2e3d8fb79e26cc9d469a6098cb1c6f994659f6f"}, {file = "poethepoet-0.20.0.tar.gz", hash = "sha256:ca5a2a955f52dfb0a53fad3c989ef0b69ce3d5ec0f6bfa9b1da1f9e32d262e20"}, @@ -4108,6 +4325,7 @@ version = "3.5.0" description = "A framework for managing and maintaining multi-language pre-commit hooks." 
optional = false python-versions = ">=3.8" +groups = ["dev"] files = [ {file = "pre_commit-3.5.0-py2.py3-none-any.whl", hash = "sha256:841dc9aef25daba9a0238cd27984041fa0467b4199fc4852e27950664919f660"}, {file = "pre_commit-3.5.0.tar.gz", hash = "sha256:5804465c675b659b0862f07907f96295d490822a450c4c40e747d0b1c6ebcb32"}, @@ -4126,6 +4344,7 @@ version = "0.18.0" description = "Python client for the Prometheus monitoring system." optional = false python-versions = ">=3.8" +groups = ["dev"] files = [ {file = "prometheus_client-0.18.0-py3-none-any.whl", hash = "sha256:8de3ae2755f890826f4b6479e5571d4f74ac17a81345fe69a6778fdb92579184"}, {file = "prometheus_client-0.18.0.tar.gz", hash = "sha256:35f7a8c22139e2bb7ca5a698e92d38145bc8dc74c1c0bf56f25cca886a764e17"}, @@ -4140,6 +4359,7 @@ version = "2.3" description = "Promises/A+ implementation for Python" optional = false python-versions = "*" +groups = ["dev"] files = [ {file = "promise-2.3.tar.gz", hash = "sha256:dfd18337c523ba4b6a58801c164c1904a9d4d1b1747c7d5dbf45b693a49d93d0"}, ] @@ -4156,6 +4376,7 @@ version = "3.0.41" description = "Library for building powerful interactive command lines in Python" optional = false python-versions = ">=3.7.0" +groups = ["dev"] files = [ {file = "prompt_toolkit-3.0.41-py3-none-any.whl", hash = "sha256:f36fe301fafb7470e86aaf90f036eef600a3210be4decf461a5b1ca8403d3cb2"}, {file = "prompt_toolkit-3.0.41.tar.gz", hash = "sha256:941367d97fc815548822aa26c2a269fdc4eb21e9ec05fc5d447cf09bad5d75f0"}, @@ -4170,6 +4391,8 @@ version = "1.6.4" description = "A decorator for caching properties in classes (forked from cached-property)." 
optional = true python-versions = ">= 3.5" +groups = ["main"] +markers = "extra == \"eval\"" files = [ {file = "property-cached-1.6.4.zip", hash = "sha256:3e9c4ef1ed3653909147510481d7df62a3cfb483461a6986a6f1dcd09b2ebb73"}, {file = "property_cached-1.6.4-py2.py3-none-any.whl", hash = "sha256:135fc059ec969c1646424a0db15e7fbe1b5f8c36c0006d0b3c91ba568c11e7d8"}, @@ -4181,6 +4404,7 @@ version = "3.20.3" description = "Protocol Buffers" optional = false python-versions = ">=3.7" +groups = ["main", "dev"] files = [ {file = "protobuf-3.20.3-cp310-cp310-manylinux2014_aarch64.whl", hash = "sha256:f4bd856d702e5b0d96a00ec6b307b0f51c1982c2bf9c0052cf9019e9a544ba99"}, {file = "protobuf-3.20.3-cp310-cp310-manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:9aae4406ea63d825636cc11ffb34ad3379335803216ee3a856787bcf5ccc751e"}, @@ -4212,6 +4436,7 @@ version = "5.9.6" description = "Cross-platform lib for process and system monitoring in Python." optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*, !=3.5.*" +groups = ["dev"] files = [ {file = "psutil-5.9.6-cp27-cp27m-macosx_10_9_x86_64.whl", hash = "sha256:fb8a697f11b0f5994550555fcfe3e69799e5b060c8ecf9e2f75c69302cc35c0d"}, {file = "psutil-5.9.6-cp27-cp27m-manylinux2010_i686.whl", hash = "sha256:91ecd2d9c00db9817a4b4192107cf6954addb5d9d67a969a4f436dbc9200f88c"}, @@ -4232,7 +4457,7 @@ files = [ ] [package.extras] -test = ["enum34", "ipaddress", "mock", "pywin32", "wmi"] +test = ["enum34 ; python_version <= \"3.4\"", "ipaddress ; python_version < \"3.0\"", "mock ; python_version < \"3.0\"", "pywin32 ; sys_platform == \"win32\"", "wmi ; sys_platform == \"win32\""] [[package]] name = "ptyprocess" @@ -4240,6 +4465,8 @@ version = "0.7.0" description = "Run a subprocess in a pseudo terminal" optional = false python-versions = "*" +groups = ["dev"] +markers = "os_name != \"nt\" or sys_platform != \"win32\"" files = [ {file = "ptyprocess-0.7.0-py2.py3-none-any.whl", hash = 
"sha256:4b41f3967fce3af57cc7e94b888626c18bf37a083e3651ca8feeb66d492fef35"}, {file = "ptyprocess-0.7.0.tar.gz", hash = "sha256:5c5d0a3b48ceee0b48485e0c26037c0acd7d29765ca3fbb5cb3831d347423220"}, @@ -4251,6 +4478,7 @@ version = "0.2.2" description = "Safely evaluate AST nodes without side effects" optional = false python-versions = "*" +groups = ["dev"] files = [ {file = "pure_eval-0.2.2-py3-none-any.whl", hash = "sha256:01eaab343580944bc56080ebe0a674b39ec44a945e6d09ba7db3cb8cec289350"}, {file = "pure_eval-0.2.2.tar.gz", hash = "sha256:2b45320af6dfaa1750f543d714b6d1c520a1688dec6fd24d339063ce0aaa9ac3"}, @@ -4265,6 +4493,7 @@ version = "0.5.0" description = "Pure-Python implementation of ASN.1 types and DER/BER/CER codecs (X.208)" optional = false python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,>=2.7" +groups = ["main"] files = [ {file = "pyasn1-0.5.0-py2.py3-none-any.whl", hash = "sha256:87a2121042a1ac9358cabcaf1d07680ff97ee6404333bacca15f76aa8ad01a57"}, {file = "pyasn1-0.5.0.tar.gz", hash = "sha256:97b7290ca68e62a832558ec3976f15cbf911bf5d7c7039d8b861c2a0ece69fde"}, @@ -4276,6 +4505,7 @@ version = "0.3.0" description = "A collection of ASN.1-based protocols modules" optional = false python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,>=2.7" +groups = ["main"] files = [ {file = "pyasn1_modules-0.3.0-py2.py3-none-any.whl", hash = "sha256:d3ccd6ed470d9ffbc716be08bd90efbd44d0734bc9303818f7336070984a162d"}, {file = "pyasn1_modules-0.3.0.tar.gz", hash = "sha256:5bd01446b736eb9d31512a30d46c1ac3395d676c6f3cafa4c03eb54b9925631c"}, @@ -4290,6 +4520,7 @@ version = "0.24.0" description = "A BibTeX-compatible bibliography processor in Python" optional = false python-versions = ">=2.7,!=3.0.*,!=3.1.*,!=3.2.*" +groups = ["dev"] files = [ {file = "pybtex-0.24.0-py2.py3-none-any.whl", hash = "sha256:e1e0c8c69998452fea90e9179aa2a98ab103f3eed894405b7264e517cc2fcc0f"}, {file = "pybtex-0.24.0.tar.gz", hash = 
"sha256:818eae35b61733e5c007c3fcd2cfb75ed1bc8b4173c1f70b56cc4c0802d34755"}, @@ -4309,6 +4540,7 @@ version = "1.0.3" description = "A docutils backend for pybtex." optional = false python-versions = ">=3.7" +groups = ["dev"] files = [ {file = "pybtex-docutils-1.0.3.tar.gz", hash = "sha256:3a7ebdf92b593e00e8c1c538aa9a20bca5d92d84231124715acc964d51d93c6b"}, {file = "pybtex_docutils-1.0.3-py3-none-any.whl", hash = "sha256:8fd290d2ae48e32fcb54d86b0efb8d573198653c7e2447d5bec5847095f430b9"}, @@ -4324,6 +4556,8 @@ version = "3.2.5" description = "Official Python Interface for the Bullet Physics SDK specialized for Robotics Simulation and Reinforcement Learning" optional = true python-versions = "*" +groups = ["main"] +markers = "extra == \"pybullet\"" files = [ {file = "pybullet-3.2.5-cp27-cp27m-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:4970aec0dd968924f6b1820655a20f80650da2f85ba38b641937c9701a8a2b14"}, {file = "pybullet-3.2.5-cp27-cp27mu-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:b64e4523a11d03729035e0a5baa0ce4d2ca58de8d0a242c0b91e8253781b24c4"}, @@ -4341,6 +4575,7 @@ version = "2.11.1" description = "Python style guide checker" optional = false python-versions = ">=3.8" +groups = ["dev"] files = [ {file = "pycodestyle-2.11.1-py2.py3-none-any.whl", hash = "sha256:44fe31000b2d866f2e41841b18528a505fbd7fef9017b04eff4e2648a0fadc67"}, {file = "pycodestyle-2.11.1.tar.gz", hash = "sha256:41ba0e7afc9752dfb53ced5489e89f8186be00e599e712660695b7a75ff2663f"}, @@ -4352,6 +4587,7 @@ version = "2.21" description = "C parser in Python" optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" +groups = ["dev"] files = [ {file = "pycparser-2.21-py2.py3-none-any.whl", hash = "sha256:8ee45429555515e1f6b185e78100aea234072576aa43ab53aefcae078162fca9"}, {file = "pycparser-2.21.tar.gz", hash = "sha256:e644fdec12f7872f86c58ff790da456218b10f863970249516d60a5eaca77206"}, @@ -4363,6 +4599,7 @@ version = "0.14.3" description = "Bootstrap-based 
Sphinx theme from the PyData community" optional = false python-versions = ">=3.8" +groups = ["dev"] files = [ {file = "pydata_sphinx_theme-0.14.3-py3-none-any.whl", hash = "sha256:b7e40cd75a20449adfe2d7525be379b9fe92f6d31e5233e449fa34ddcd4398d9"}, {file = "pydata_sphinx_theme-0.14.3.tar.gz", hash = "sha256:bd474f347895f3fc5b6ce87390af64330ee54f11ebf9660d5bc3f87d532d4e5c"}, @@ -4390,6 +4627,7 @@ version = "3.2.2" description = "Python bindings for the Enchant spellchecking system" optional = false python-versions = ">=3.5" +groups = ["dev"] files = [ {file = "pyenchant-3.2.2-py3-none-any.whl", hash = "sha256:5facc821ece957208a81423af7d6ec7810dad29697cb0d77aae81e4e11c8e5a6"}, {file = "pyenchant-3.2.2-py3-none-win32.whl", hash = "sha256:5a636832987eaf26efe971968f4d1b78e81f62bca2bde0a9da210c7de43c3bce"}, @@ -4403,6 +4641,7 @@ version = "2.5.2" description = "Python Game Development" optional = false python-versions = ">=3.6" +groups = ["main", "dev"] files = [ {file = "pygame-2.5.2-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:a0769eb628c818761755eb0a0ca8216b95270ea8cbcbc82227e39ac9644643da"}, {file = "pygame-2.5.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:ed9a3d98adafa0805ccbaaff5d2996a2b5795381285d8437a4a5d248dbd12b4a"}, @@ -4469,13 +4708,14 @@ version = "2.17.1" description = "Pygments is a syntax highlighting package written in Python." 
optional = false python-versions = ">=3.7" +groups = ["dev"] files = [ {file = "pygments-2.17.1-py3-none-any.whl", hash = "sha256:1b37f1b1e1bff2af52ecaf28cc601e2ef7077000b227a0675da25aef85784bc4"}, {file = "pygments-2.17.1.tar.gz", hash = "sha256:e45a0e74bf9c530f564ca81b8952343be986a29f6afe7f5ad95c5f06b7bdf5e8"}, ] [package.extras] -plugins = ["importlib-metadata"] +plugins = ["importlib-metadata ; python_version < \"3.8\""] windows-terminal = ["colorama (>=0.4.6)"] [[package]] @@ -4484,6 +4724,7 @@ version = "6.6.0" description = "Pymunk is a easy-to-use pythonic 2d physics library" optional = false python-versions = ">=3.7" +groups = ["dev"] files = [ {file = "pymunk-6.6.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:6da50dd97683337a290110d594fad07a75153d2d837b570ef972478d739c33f8"}, {file = "pymunk-6.6.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:bcd7d16a2b4d51d45d6780a701f65c8d5b36fdf545c3f4738910da41e2a9c4ee"}, @@ -4555,6 +4796,8 @@ version = "3.1.7" description = "Standard OpenGL bindings for Python" optional = true python-versions = "*" +groups = ["main"] +markers = "extra == \"robotics\" or extra == \"mujoco\"" files = [ {file = "PyOpenGL-3.1.7-py3-none-any.whl", hash = "sha256:a6ab19cf290df6101aaf7470843a9c46207789855746399d0af92521a0a92b7a"}, {file = "PyOpenGL-3.1.7.tar.gz", hash = "sha256:eef31a3888e6984fd4d8e6c9961b184c9813ca82604d37fe3da80eb000a76c86"}, @@ -4566,6 +4809,7 @@ version = "3.1.2" description = "pyparsing module - Classes and methods to define and execute parsing grammars" optional = false python-versions = ">=3.6.8" +groups = ["main"] files = [ {file = "pyparsing-3.1.2-py3-none-any.whl", hash = "sha256:f9db75911801ed778fe61bb643079ff86601aca99fcae6345aa67292038fb742"}, {file = "pyparsing-3.1.2.tar.gz", hash = "sha256:a1bac0ce561155ecc3ed78ca94d3c9378656ad4c94c1270de543f621420f94ad"}, @@ -4580,6 +4824,7 @@ version = "7.4.3" description = "pytest: simple powerful testing with Python" optional = false python-versions = ">=3.7" 
+groups = ["dev"] files = [ {file = "pytest-7.4.3-py3-none-any.whl", hash = "sha256:0d009c083ea859a71b76adf7c1d502e4bc170b80a8ef002da5806527b9591fac"}, {file = "pytest-7.4.3.tar.gz", hash = "sha256:d989d136982de4e3b29dabcc838ad581c64e8ed52c11fbe86ddebd9da0818cd5"}, @@ -4600,6 +4845,7 @@ version = "4.1.0" description = "Pytest plugin for measuring coverage." optional = false python-versions = ">=3.7" +groups = ["dev"] files = [ {file = "pytest-cov-4.1.0.tar.gz", hash = "sha256:3904b13dfbfec47f003b8e77fd5b589cd11904a21ddf1ab38a64f204d6a10ef6"}, {file = "pytest_cov-4.1.0-py3-none-any.whl", hash = "sha256:6ba70b9e97e69fcc3fb45bfeab2d0a138fb65c4d0d6a41ef33983ad114be8c3a"}, @@ -4618,6 +4864,7 @@ version = "2.8.2" description = "Extensions to the standard Python datetime module" optional = false python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,>=2.7" +groups = ["main", "dev"] files = [ {file = "python-dateutil-2.8.2.tar.gz", hash = "sha256:0123cacc1627ae19ddf3c27a5de5bd67ee4586fbdd6440d9748f8abb483d3e86"}, {file = "python_dateutil-2.8.2-py2.py3-none-any.whl", hash = "sha256:961d03dc3453ebbc59dbdea9e4e11c5651520a876d0f4db161e8674aae935da9"}, @@ -4632,6 +4879,7 @@ version = "2.0.7" description = "A python library adding a json log formatter" optional = false python-versions = ">=3.6" +groups = ["dev"] files = [ {file = "python-json-logger-2.0.7.tar.gz", hash = "sha256:23e7ec02d34237c5aa1e29a070193a4ea87583bb4e7f8fd06d3de8264c4b2e1c"}, {file = "python_json_logger-2.0.7-py3-none-any.whl", hash = "sha256:f380b826a991ebbe3de4d897aeec42760035ac760345e57b812938dc8b35e2bd"}, @@ -4643,6 +4891,7 @@ version = "2024.1" description = "World timezone definitions, modern and historical" optional = false python-versions = "*" +groups = ["main"] files = [ {file = "pytz-2024.1-py2.py3-none-any.whl", hash = "sha256:328171f4e3623139da4983451950b28e95ac706e13f3f2630a879749e7a8b319"}, {file = "pytz-2024.1.tar.gz", hash = "sha256:2a29735ea9c18baf14b448846bde5a48030ed267578472d8955cd0e7443a9812"}, @@ 
-4654,6 +4903,8 @@ version = "306" description = "Python for Window Extensions" optional = false python-versions = "*" +groups = ["dev"] +markers = "sys_platform == \"win32\" and platform_python_implementation != \"PyPy\"" files = [ {file = "pywin32-306-cp310-cp310-win32.whl", hash = "sha256:06d3420a5155ba65f0b72f2699b5bacf3109f36acbe8923765c22938a69dfc8d"}, {file = "pywin32-306-cp310-cp310-win_amd64.whl", hash = "sha256:84f4471dbca1887ea3803d8848a1616429ac94a4a8d05f4bc9c5dcfd42ca99c8"}, @@ -4677,6 +4928,8 @@ version = "2.0.12" description = "Pseudo terminal support for Windows from Python." optional = false python-versions = ">=3.8" +groups = ["dev"] +markers = "os_name == \"nt\"" files = [ {file = "pywinpty-2.0.12-cp310-none-win_amd64.whl", hash = "sha256:21319cd1d7c8844fb2c970fb3a55a3db5543f112ff9cfcd623746b9c47501575"}, {file = "pywinpty-2.0.12-cp311-none-win_amd64.whl", hash = "sha256:853985a8f48f4731a716653170cd735da36ffbdc79dcb4c7b7140bce11d8c722"}, @@ -4692,6 +4945,7 @@ version = "6.0.1" description = "YAML parser and emitter for Python" optional = false python-versions = ">=3.6" +groups = ["main", "dev"] files = [ {file = "PyYAML-6.0.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:d858aa552c999bc8a8d57426ed01e40bef403cd8ccdd0fc5f6f04a00414cac2a"}, {file = "PyYAML-6.0.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:fd66fc5d0da6d9815ba2cebeb4205f95818ff4b79c3ebe268e75d961704af52f"}, @@ -4745,6 +4999,7 @@ files = [ {file = "PyYAML-6.0.1-cp39-cp39-win_amd64.whl", hash = "sha256:510c9deebc5c0225e8c96813043e62b680ba2f9c50a08d3724c7f28a747d1486"}, {file = "PyYAML-6.0.1.tar.gz", hash = "sha256:bfdf460b1736c775f2ba9f6a92bca30bc2095067b8a9d77876d1fad6cc3b4a43"}, ] +markers = {main = "extra == \"argparse\" or extra == \"eval\""} [[package]] name = "pyzmq" @@ -4752,6 +5007,7 @@ version = "25.1.1" description = "Python bindings for 0MQ" optional = false python-versions = ">=3.6" +groups = ["dev"] files = [ {file = 
"pyzmq-25.1.1-cp310-cp310-macosx_10_15_universal2.whl", hash = "sha256:381469297409c5adf9a0e884c5eb5186ed33137badcbbb0560b86e910a2f1e76"}, {file = "pyzmq-25.1.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:955215ed0604dac5b01907424dfa28b40f2b2292d6493445dd34d0dfa72586a8"}, @@ -4857,6 +5113,7 @@ version = "5.5.1" description = "Jupyter Qt console" optional = false python-versions = ">= 3.8" +groups = ["dev"] files = [ {file = "qtconsole-5.5.1-py3-none-any.whl", hash = "sha256:8c75fa3e9b4ed884880ff7cea90a1b67451219279ec33deaee1d59e3df1a5d2b"}, {file = "qtconsole-5.5.1.tar.gz", hash = "sha256:a0e806c6951db9490628e4df80caec9669b65149c7ba40f9bf033c025a5b56bc"}, @@ -4882,6 +5139,7 @@ version = "2.4.1" description = "Provides an abstraction layer on top of the various Qt bindings (PyQt5/6 and PySide2/6)." optional = false python-versions = ">=3.7" +groups = ["dev"] files = [ {file = "QtPy-2.4.1-py3-none-any.whl", hash = "sha256:1c1d8c4fa2c884ae742b069151b0abe15b3f70491f3972698c683b8e38de839b"}, {file = "QtPy-2.4.1.tar.gz", hash = "sha256:a5a15ffd519550a1361bdc56ffc07fda56a6af7292f17c7b395d4083af632987"}, @@ -4899,6 +5157,8 @@ version = "2.8.0" description = "Ray provides a simple, universal API for building distributed applications." 
optional = false python-versions = "*" +groups = ["dev"] +markers = "sys_platform != \"win32\"" files = [ {file = "ray-2.8.0-cp310-cp310-macosx_10_15_x86_64.whl", hash = "sha256:34e0676a0dfa277efa688bccd83ecb7a799bc03078e5b1f1aa747fe9263175a8"}, {file = "ray-2.8.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:72c696c1b784c55f0ad107d55bb58ecef5d368176765cf44fed87e714538d708"}, @@ -4936,16 +5196,16 @@ pyyaml = "*" requests = "*" [package.extras] -air = ["aiohttp (>=3.7)", "aiohttp-cors", "aiorwlock", "colorful", "fastapi", "fsspec", "gpustat (>=1.0.0)", "grpcio (>=1.32.0)", "grpcio (>=1.42.0)", "numpy (>=1.20)", "opencensus", "pandas", "pandas (>=1.3)", "prometheus-client (>=0.7.1)", "py-spy (>=0.2.0)", "pyarrow (>=6.0.1)", "pydantic (<2)", "requests", "smart-open", "starlette", "tensorboardX (>=1.9)", "uvicorn[standard]", "virtualenv (>=20.0.24,<20.21.1)", "watchfiles"] -all = ["aiohttp (>=3.7)", "aiohttp-cors", "aiorwlock", "colorful", "dm-tree", "fastapi", "fsspec", "gpustat (>=1.0.0)", "grpcio (!=1.56.0)", "grpcio (>=1.32.0)", "grpcio (>=1.42.0)", "gymnasium (==0.28.1)", "lz4", "numpy (>=1.20)", "opencensus", "opentelemetry-api", "opentelemetry-exporter-otlp", "opentelemetry-sdk", "pandas", "pandas (>=1.3)", "prometheus-client (>=0.7.1)", "py-spy (>=0.2.0)", "pyarrow (>=6.0.1)", "pydantic (<2)", "pyyaml", "ray-cpp (==2.8.0)", "requests", "rich", "scikit-image", "scipy", "smart-open", "starlette", "tensorboardX (>=1.9)", "typer", "uvicorn[standard]", "virtualenv (>=20.0.24,<20.21.1)", "watchfiles"] +air = ["aiohttp (>=3.7)", "aiohttp-cors", "aiorwlock", "colorful", "fastapi", "fsspec", "gpustat (>=1.0.0)", "grpcio (>=1.32.0) ; python_version < \"3.10\"", "grpcio (>=1.42.0) ; python_version >= \"3.10\"", "numpy (>=1.20)", "opencensus", "pandas", "pandas (>=1.3)", "prometheus-client (>=0.7.1)", "py-spy (>=0.2.0)", "pyarrow (>=6.0.1)", "pydantic (<2)", "requests", "smart-open", "starlette", "tensorboardX (>=1.9)", "uvicorn[standard]", "virtualenv 
(>=20.0.24,<20.21.1)", "watchfiles"] +all = ["aiohttp (>=3.7)", "aiohttp-cors", "aiorwlock", "colorful", "dm-tree", "fastapi", "fsspec", "gpustat (>=1.0.0)", "grpcio (!=1.56.0)", "grpcio (>=1.32.0) ; python_version < \"3.10\"", "grpcio (>=1.42.0) ; python_version >= \"3.10\"", "gymnasium (==0.28.1)", "lz4", "numpy (>=1.20)", "opencensus", "opentelemetry-api", "opentelemetry-exporter-otlp", "opentelemetry-sdk", "pandas", "pandas (>=1.3)", "prometheus-client (>=0.7.1)", "py-spy (>=0.2.0)", "pyarrow (>=6.0.1)", "pydantic (<2)", "pyyaml", "ray-cpp (==2.8.0)", "requests", "rich", "scikit-image", "scipy", "smart-open", "starlette", "tensorboardX (>=1.9)", "typer", "uvicorn[standard]", "virtualenv (>=20.0.24,<20.21.1)", "watchfiles"] client = ["grpcio (!=1.56.0)"] cpp = ["ray-cpp (==2.8.0)"] data = ["fsspec", "numpy (>=1.20)", "pandas (>=1.3)", "pyarrow (>=6.0.1)"] -default = ["aiohttp (>=3.7)", "aiohttp-cors", "colorful", "gpustat (>=1.0.0)", "grpcio (>=1.32.0)", "grpcio (>=1.42.0)", "opencensus", "prometheus-client (>=0.7.1)", "py-spy (>=0.2.0)", "pydantic (<2)", "requests", "smart-open", "virtualenv (>=20.0.24,<20.21.1)"] +default = ["aiohttp (>=3.7)", "aiohttp-cors", "colorful", "gpustat (>=1.0.0)", "grpcio (>=1.32.0) ; python_version < \"3.10\"", "grpcio (>=1.42.0) ; python_version >= \"3.10\"", "opencensus", "prometheus-client (>=0.7.1)", "py-spy (>=0.2.0)", "pydantic (<2)", "requests", "smart-open", "virtualenv (>=20.0.24,<20.21.1)"] observability = ["opentelemetry-api", "opentelemetry-exporter-otlp", "opentelemetry-sdk"] rllib = ["dm-tree", "fsspec", "gymnasium (==0.28.1)", "lz4", "pandas", "pyarrow (>=6.0.1)", "pyyaml", "requests", "rich", "scikit-image", "scipy", "tensorboardX (>=1.9)", "typer"] -serve = ["aiohttp (>=3.7)", "aiohttp-cors", "aiorwlock", "colorful", "fastapi", "gpustat (>=1.0.0)", "grpcio (>=1.32.0)", "grpcio (>=1.42.0)", "opencensus", "prometheus-client (>=0.7.1)", "py-spy (>=0.2.0)", "pydantic (<2)", "requests", "smart-open", "starlette", 
"uvicorn[standard]", "virtualenv (>=20.0.24,<20.21.1)", "watchfiles"] -serve-grpc = ["aiohttp (>=3.7)", "aiohttp-cors", "aiorwlock", "colorful", "fastapi", "gpustat (>=1.0.0)", "grpcio (>=1.32.0)", "grpcio (>=1.42.0)", "opencensus", "prometheus-client (>=0.7.1)", "py-spy (>=0.2.0)", "pydantic (<2)", "requests", "smart-open", "starlette", "uvicorn[standard]", "virtualenv (>=20.0.24,<20.21.1)", "watchfiles"] +serve = ["aiohttp (>=3.7)", "aiohttp-cors", "aiorwlock", "colorful", "fastapi", "gpustat (>=1.0.0)", "grpcio (>=1.32.0) ; python_version < \"3.10\"", "grpcio (>=1.42.0) ; python_version >= \"3.10\"", "opencensus", "prometheus-client (>=0.7.1)", "py-spy (>=0.2.0)", "pydantic (<2)", "requests", "smart-open", "starlette", "uvicorn[standard]", "virtualenv (>=20.0.24,<20.21.1)", "watchfiles"] +serve-grpc = ["aiohttp (>=3.7)", "aiohttp-cors", "aiorwlock", "colorful", "fastapi", "gpustat (>=1.0.0)", "grpcio (>=1.32.0) ; python_version < \"3.10\"", "grpcio (>=1.42.0) ; python_version >= \"3.10\"", "opencensus", "prometheus-client (>=0.7.1)", "py-spy (>=0.2.0)", "pydantic (<2)", "requests", "smart-open", "starlette", "uvicorn[standard]", "virtualenv (>=20.0.24,<20.21.1)", "watchfiles"] train = ["fsspec", "pandas", "pyarrow (>=6.0.1)", "requests", "tensorboardX (>=1.9)"] tune = ["fsspec", "pandas", "pyarrow (>=6.0.1)", "requests", "tensorboardX (>=1.9)"] @@ -4955,6 +5215,7 @@ version = "0.31.0" description = "JSON Referencing + Python" optional = false python-versions = ">=3.8" +groups = ["dev"] files = [ {file = "referencing-0.31.0-py3-none-any.whl", hash = "sha256:381b11e53dd93babb55696c71cf42aef2d36b8a150c49bf0bc301e36d536c882"}, {file = "referencing-0.31.0.tar.gz", hash = "sha256:cc28f2c88fbe7b961a7817a0abc034c09a1e36358f82fedb4ffdf29a25398863"}, @@ -4970,6 +5231,7 @@ version = "2.32.0" description = "Python HTTP for Humans." 
optional = false python-versions = ">=3.8" +groups = ["main", "dev"] files = [ {file = "requests-2.32.0-py3-none-any.whl", hash = "sha256:f2c3881dddb70d056c5bd7600a4fae312b2a300e39be6a118d30b90bd27262b5"}, {file = "requests-2.32.0.tar.gz", hash = "sha256:fa5490319474c82ef1d2c9bc459d3652e3ae4ef4c4ebdd18a21145a47ca4b6b8"}, @@ -4991,6 +5253,7 @@ version = "1.3.1" description = "OAuthlib authentication support for Requests." optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" +groups = ["main"] files = [ {file = "requests-oauthlib-1.3.1.tar.gz", hash = "sha256:75beac4a47881eeb94d5ea5d6ad31ef88856affe2332b9aafb52c6452ccf0d7a"}, {file = "requests_oauthlib-1.3.1-py2.py3-none-any.whl", hash = "sha256:2577c501a2fb8d05a304c09d090d6e47c306fef15809d102b327cf8364bddab5"}, @@ -5009,6 +5272,7 @@ version = "0.1.4" description = "A pure python RFC3339 validator" optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*" +groups = ["dev"] files = [ {file = "rfc3339_validator-0.1.4-py2.py3-none-any.whl", hash = "sha256:24f6ec1eda14ef823da9e36ec7113124b39c04d50a4d3d3a3c2859577e7791fa"}, {file = "rfc3339_validator-0.1.4.tar.gz", hash = "sha256:138a2abdf93304ad60530167e51d2dfb9549521a836871b88d7f4695d0022f6b"}, @@ -5023,6 +5287,7 @@ version = "0.1.1" description = "Pure python rfc3986 validator" optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*" +groups = ["dev"] files = [ {file = "rfc3986_validator-0.1.1-py2.py3-none-any.whl", hash = "sha256:2f235c432ef459970b4306369336b9d5dbdda31b510ca1e327636e01f528bfa9"}, {file = "rfc3986_validator-0.1.1.tar.gz", hash = "sha256:3d44bde7921b3b9ec3ae4e3adca370438eccebc676456449b145d533b240d055"}, @@ -5034,6 +5299,8 @@ version = "1.2.0" description = "rliable: Reliable evaluation on reinforcement learning and machine learning benchmarks." 
optional = true python-versions = "*" +groups = ["main"] +markers = "extra == \"eval\"" files = [ {file = "rliable-1.2.0.tar.gz", hash = "sha256:72789d9147d7c56e6efa812f9dffedcef44993a866ec08d75506ac7c1fe69cd5"}, ] @@ -5052,6 +5319,7 @@ version = "0.13.0" description = "Python bindings to Rust's persistent data structures (rpds)" optional = false python-versions = ">=3.8" +groups = ["dev"] files = [ {file = "rpds_py-0.13.0-cp310-cp310-macosx_10_7_x86_64.whl", hash = "sha256:1758197cc8d7ff383c07405f188253535b4aa7fa745cbc54d221ae84b18e0702"}, {file = "rpds_py-0.13.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:715df74cbcef4387d623c917f295352127f4b3e0388038d68fa577b4e4c6e540"}, @@ -5160,6 +5428,7 @@ version = "4.9" description = "Pure-Python RSA implementation" optional = false python-versions = ">=3.6,<4" +groups = ["main"] files = [ {file = "rsa-4.9-py3-none-any.whl", hash = "sha256:90260d9058e514786967344d0ef75fa8727eed8a7d2e43ce9f4bcf1b536174f7"}, {file = "rsa-4.9.tar.gz", hash = "sha256:e38464a49c6c85d7f1351b0126661487a7e0a14a50f1675ec50eb34d4f20ef21"}, @@ -5174,6 +5443,7 @@ version = "0.18.5" description = "ruamel.yaml is a YAML parser/emitter that supports roundtrip preservation of comments, seq/map flow style, and map key order" optional = false python-versions = ">=3.7" +groups = ["dev"] files = [ {file = "ruamel.yaml-0.18.5-py3-none-any.whl", hash = "sha256:a013ac02f99a69cdd6277d9664689eb1acba07069f912823177c5eced21a6ada"}, {file = "ruamel.yaml-0.18.5.tar.gz", hash = "sha256:61917e3a35a569c1133a8f772e1226961bf5a1198bea7e23f06a0841dea1ab0e"}, @@ -5192,6 +5462,8 @@ version = "0.2.8" description = "C version of reader, parser and emitter for ruamel.yaml derived from libyaml" optional = false python-versions = ">=3.6" +groups = ["dev"] +markers = "platform_python_implementation == \"CPython\" and python_version < \"3.13\"" files = [ {file = "ruamel.yaml.clib-0.2.8-cp310-cp310-macosx_10_9_universal2.whl", hash = 
"sha256:b42169467c42b692c19cf539c38d4602069d8c1505e97b86387fcf7afb766e1d"}, {file = "ruamel.yaml.clib-0.2.8-cp310-cp310-macosx_13_0_arm64.whl", hash = "sha256:07238db9cbdf8fc1e9de2489a4f68474e70dffcb32232db7c08fa61ca0c7c462"}, @@ -5251,6 +5523,7 @@ version = "0.0.285" description = "An extremely fast Python linter, written in Rust." optional = false python-versions = ">=3.7" +groups = ["dev"] files = [ {file = "ruff-0.0.285-py3-none-macosx_10_7_x86_64.whl", hash = "sha256:72a3a0936369b986b0e959f9090206ed3c18f9e5e439ea5b8e6867c6707aded5"}, {file = "ruff-0.0.285-py3-none-macosx_10_9_x86_64.macosx_11_0_arm64.macosx_10_9_universal2.whl", hash = "sha256:0d9ab6ad16742eb78919e0fba09f914f042409df40ad63423c34bb20d350162a"}, @@ -5277,6 +5550,7 @@ version = "1.11.4" description = "Fundamental algorithms for scientific computing in Python" optional = false python-versions = ">=3.9" +groups = ["main", "dev"] files = [ {file = "scipy-1.11.4-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:bc9a714581f561af0848e6b69947fda0614915f072dfd14142ed1bfe1b806710"}, {file = "scipy-1.11.4-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:cf00bd2b1b0211888d4dc75656c0412213a8b25e80d73898083f402b50f47e41"}, @@ -5319,6 +5593,8 @@ version = "0.13.2" description = "Statistical data visualization" optional = true python-versions = ">=3.8" +groups = ["main"] +markers = "extra == \"eval\"" files = [ {file = "seaborn-0.13.2-py3-none-any.whl", hash = "sha256:636f8336facf092165e27924f223d3c62ca560b1f2bb5dff7ab7fad265361987"}, {file = "seaborn-0.13.2.tar.gz", hash = "sha256:93e60a40988f4d65e9f4885df477e2fdaff6b73a9ded434c1ab356dd57eefff7"}, @@ -5340,15 +5616,16 @@ version = "1.8.2" description = "Send file to trash natively under Mac OS X, Windows and Linux" optional = false python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,>=2.7" +groups = ["dev"] files = [ {file = "Send2Trash-1.8.2-py3-none-any.whl", hash = "sha256:a384719d99c07ce1eefd6905d2decb6f8b7ed054025bb0e618919f945de4f679"}, {file 
= "Send2Trash-1.8.2.tar.gz", hash = "sha256:c132d59fa44b9ca2b1699af5c86f57ce9f4c5eb56629d5d55fbb7a35f84e2312"}, ] [package.extras] -nativelib = ["pyobjc-framework-Cocoa", "pywin32"] -objc = ["pyobjc-framework-Cocoa"] -win32 = ["pywin32"] +nativelib = ["pyobjc-framework-Cocoa ; sys_platform == \"darwin\"", "pywin32 ; sys_platform == \"win32\""] +objc = ["pyobjc-framework-Cocoa ; sys_platform == \"darwin\""] +win32 = ["pywin32 ; sys_platform == \"win32\""] [[package]] name = "sensai-utils" @@ -5356,6 +5633,7 @@ version = "1.4.0" description = "Utilities from sensAI, the Python library for sensible AI" optional = false python-versions = "*" +groups = ["main"] files = [ {file = "sensai_utils-1.4.0-py3-none-any.whl", hash = "sha256:ed6fc57552620e43b33cf364ea0bc0fd7df39391069dd7b621b113ef55547507"}, {file = "sensai_utils-1.4.0.tar.gz", hash = "sha256:2d32bdcc91fd1428c5cae0181e98623142d2d5f7e115e23d585a842dd9dc59ba"}, @@ -5370,6 +5648,7 @@ version = "2.8.0" description = "Python client for Sentry (https://sentry.io)" optional = false python-versions = ">=3.6" +groups = ["dev"] files = [ {file = "sentry_sdk-2.8.0-py2.py3-none-any.whl", hash = "sha256:6051562d2cfa8087bb8b4b8b79dc44690f8a054762a29c07e22588b1f619bfb5"}, {file = "sentry_sdk-2.8.0.tar.gz", hash = "sha256:aa4314f877d9cd9add5a0c9ba18e3f27f99f7de835ce36bd150e48a41c7c646f"}, @@ -5420,6 +5699,7 @@ version = "1.3.3" description = "A Python module to customize the process title" optional = false python-versions = ">=3.7" +groups = ["dev"] files = [ {file = "setproctitle-1.3.3-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:897a73208da48db41e687225f355ce993167079eda1260ba5e13c4e53be7f754"}, {file = "setproctitle-1.3.3-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:8c331e91a14ba4076f88c29c777ad6b58639530ed5b24b5564b5ed2fd7a95452"}, @@ -5520,6 +5800,7 @@ version = "68.2.2" description = "Easily download, build, install, upgrade, and uninstall Python packages" optional = false python-versions = ">=3.8" 
+groups = ["main", "dev"] files = [ {file = "setuptools-68.2.2-py3-none-any.whl", hash = "sha256:b454a35605876da60632df1a60f736524eb73cc47bbc9f3f1ef1b644de74fd2a"}, {file = "setuptools-68.2.2.tar.gz", hash = "sha256:4ac1475276d2f1c48684874089fefcd83bd7162ddaafb81fac866ba0db282a87"}, @@ -5527,7 +5808,7 @@ files = [ [package.extras] docs = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "pygments-github-lexers (==0.0.5)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-favicon", "sphinx-hoverxref (<2)", "sphinx-inline-tabs", "sphinx-lint", "sphinx-notfound-page (>=1,<2)", "sphinx-reredirects", "sphinxcontrib-towncrier"] -testing = ["build[virtualenv]", "filelock (>=3.4.0)", "flake8-2020", "ini2toml[lite] (>=0.9)", "jaraco.develop (>=7.21)", "jaraco.envs (>=2.2)", "jaraco.path (>=3.2.0)", "pip (>=19.1)", "pytest (>=6)", "pytest-black (>=0.3.7)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-mypy (>=0.9.1)", "pytest-perf", "pytest-ruff", "pytest-timeout", "pytest-xdist", "tomli-w (>=1.0.0)", "virtualenv (>=13.0.0)", "wheel"] +testing = ["build[virtualenv]", "filelock (>=3.4.0)", "flake8-2020", "ini2toml[lite] (>=0.9)", "jaraco.develop (>=7.21) ; python_version >= \"3.9\" and sys_platform != \"cygwin\"", "jaraco.envs (>=2.2)", "jaraco.path (>=3.2.0)", "pip (>=19.1)", "pytest (>=6)", "pytest-black (>=0.3.7) ; platform_python_implementation != \"PyPy\"", "pytest-checkdocs (>=2.4)", "pytest-cov ; platform_python_implementation != \"PyPy\"", "pytest-enabler (>=2.2)", "pytest-mypy (>=0.9.1) ; platform_python_implementation != \"PyPy\"", "pytest-perf ; sys_platform != \"cygwin\"", "pytest-ruff ; sys_platform != \"cygwin\"", "pytest-timeout", "pytest-xdist", "tomli-w (>=1.0.0)", "virtualenv (>=13.0.0)", "wheel"] testing-integration = ["build[virtualenv] (>=1.0.3)", "filelock (>=3.4.0)", "jaraco.envs (>=2.2)", "jaraco.path (>=3.2.0)", "packaging (>=23.1)", "pytest", "pytest-enabler", "pytest-xdist", "tomli", "virtualenv (>=13.0.0)", 
"wheel"] [[package]] @@ -5536,6 +5817,8 @@ version = "0.2.1" description = "API for converting popular non-gymnasium environments to a gymnasium compatible environment." optional = true python-versions = ">=3.7" +groups = ["main"] +markers = "extra == \"atari\"" files = [ {file = "Shimmy-0.2.1-py3-none-any.whl", hash = "sha256:2d7d21c4ca679a64bb452e6a4232c6b0f5dba7589f5420454ddc1f0634334334"}, {file = "Shimmy-0.2.1.tar.gz", hash = "sha256:7b96915445ee5488dcb19ccf52ce5581d6f00cc5cf0e0dff36b16cd65bffcb75"}, @@ -5560,6 +5843,7 @@ version = "1.0.11" description = "A generator library for concise, unambiguous and URL-safe UUIDs." optional = false python-versions = ">=3.5" +groups = ["dev"] files = [ {file = "shortuuid-1.0.11-py3-none-any.whl", hash = "sha256:27ea8f28b1bd0bf8f15057a3ece57275d2059d2b0bb02854f02189962c13b6aa"}, {file = "shortuuid-1.0.11.tar.gz", hash = "sha256:fc75f2615914815a8e4cb1501b3a513745cb66ef0fd5fc6fb9f8c3fa3481f789"}, @@ -5571,6 +5855,7 @@ version = "1.16.0" description = "Python 2 and 3 compatibility utilities" optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*" +groups = ["main", "dev"] files = [ {file = "six-1.16.0-py2.py3-none-any.whl", hash = "sha256:8abb2f1d86890a2dfb989f9a77cfcfd3e47c2a354b01111771326f8aa26e0254"}, {file = "six-1.16.0.tar.gz", hash = "sha256:1e61c37477a1626458e36f7b1d82aa5c9b094fa4802892072e49de9c60c4c926"}, @@ -5582,6 +5867,7 @@ version = "5.0.1" description = "A pure Python implementation of a sliding window memory map manager" optional = false python-versions = ">=3.7" +groups = ["dev"] files = [ {file = "smmap-5.0.1-py3-none-any.whl", hash = "sha256:e6d8668fa5f93e706934a62d7b4db19c8d9eb8cf2adbb75ef1b675aa332b69da"}, {file = "smmap-5.0.1.tar.gz", hash = "sha256:dceeb6c0028fdb6734471eb07c0cd2aae706ccaecab45965ee83f11c8d3b1f62"}, @@ -5593,6 +5879,7 @@ version = "1.3.0" description = "Sniff out which async library your code is running under" optional = false python-versions = ">=3.7" +groups = ["dev"] 
files = [ {file = "sniffio-1.3.0-py3-none-any.whl", hash = "sha256:eecefdce1e5bbfb7ad2eeaabf7c1eeb404d7757c379bd1f7e5cce9d8bf425384"}, {file = "sniffio-1.3.0.tar.gz", hash = "sha256:e60305c5e5d314f5389259b7f22aaa33d8f7dee49763119234af3755c55b9101"}, @@ -5604,6 +5891,7 @@ version = "2.2.0" description = "This package provides 29 stemmers for 28 languages generated from Snowball algorithms." optional = false python-versions = "*" +groups = ["dev"] files = [ {file = "snowballstemmer-2.2.0-py2.py3-none-any.whl", hash = "sha256:c8e1716e83cc398ae16824e5572ae04e0d9fc2c6b985fb0f900f5f0c96ecba1a"}, {file = "snowballstemmer-2.2.0.tar.gz", hash = "sha256:09b16deb8547d3412ad7b590689584cd0fe25ec8db3be37788be3810cbf19cb1"}, @@ -5615,6 +5903,7 @@ version = "2.5" description = "A modern CSS selector implementation for Beautiful Soup." optional = false python-versions = ">=3.8" +groups = ["dev"] files = [ {file = "soupsieve-2.5-py3-none-any.whl", hash = "sha256:eaa337ff55a1579b6549dc679565eac1e3d000563bcb1c8ab0d0fefbc0c2cdc7"}, {file = "soupsieve-2.5.tar.gz", hash = "sha256:5663d5a7b3bfaeee0bc4372e7fc48f9cff4940b3eec54a6451cc5299f1097690"}, @@ -5626,6 +5915,7 @@ version = "7.2.6" description = "Python documentation generator" optional = false python-versions = ">=3.9" +groups = ["dev"] files = [ {file = "sphinx-7.2.6-py3-none-any.whl", hash = "sha256:1e09160a40b956dc623c910118fa636da93bd3ca0b9876a7b3df90f07d691560"}, {file = "sphinx-7.2.6.tar.gz", hash = "sha256:9a5160e1ea90688d5963ba09a2dcd8bdd526620edbb65c328728f1b2228d5ab5"}, @@ -5660,6 +5950,7 @@ version = "1.19.1" description = "Type hints (PEP 484) support for the Sphinx autodoc extension" optional = false python-versions = ">=3.7" +groups = ["dev"] files = [ {file = "sphinx_autodoc_typehints-1.19.1-py3-none-any.whl", hash = "sha256:9be46aeeb1b315eb5df1f3a7cb262149895d16c7d7dcd77b92513c3c3a1e85e6"}, {file = "sphinx_autodoc_typehints-1.19.1.tar.gz", hash = 
"sha256:6c841db55e0e9be0483ff3962a2152b60e79306f4288d8c4e7e86ac84486a5ea"}, @@ -5670,7 +5961,7 @@ Sphinx = ">=4.5" [package.extras] testing = ["covdefaults (>=2.2)", "coverage (>=6.3)", "diff-cover (>=6.4)", "nptyping (>=2.1.2)", "pytest (>=7.1)", "pytest-cov (>=3)", "sphobjinv (>=2)", "typing-extensions (>=4.1)"] -type-comments = ["typed-ast (>=1.5.2)"] +type-comments = ["typed-ast (>=1.5.2) ; python_version < \"3.8\""] [[package]] name = "sphinx-book-theme" @@ -5678,6 +5969,7 @@ version = "1.1.0" description = "A clean book theme for scientific explanations and documentation with Sphinx" optional = false python-versions = ">=3.9" +groups = ["dev"] files = [ {file = "sphinx_book_theme-1.1.0-py3-none-any.whl", hash = "sha256:088bc69d65fab8446adb8691ed61687f71bf7504c9740af68bc78cf936a26112"}, {file = "sphinx_book_theme-1.1.0.tar.gz", hash = "sha256:ad4f92998e53e24751ecd0978d3eb79fdaa59692f005b1b286ecdd6146ebc9c1"}, @@ -5698,6 +5990,7 @@ version = "0.0.3" description = "Add comments and annotation to your documentation." optional = false python-versions = "*" +groups = ["dev"] files = [ {file = "sphinx-comments-0.0.3.tar.gz", hash = "sha256:00170afff27019fad08e421da1ae49c681831fb2759786f07c826e89ac94cf21"}, {file = "sphinx_comments-0.0.3-py3-none-any.whl", hash = "sha256:1e879b4e9bfa641467f83e3441ac4629225fc57c29995177d043252530c21d00"}, @@ -5717,6 +6010,7 @@ version = "0.5.2" description = "Add a copy button to each of your code cells." optional = false python-versions = ">=3.7" +groups = ["dev"] files = [ {file = "sphinx-copybutton-0.5.2.tar.gz", hash = "sha256:4cf17c82fb9646d1bc9ca92ac280813a3b605d8c421225fd9913154103ee1fbd"}, {file = "sphinx_copybutton-0.5.2-py3-none-any.whl", hash = "sha256:fb543fd386d917746c9a2c50360c7905b605726b9355cd26e9974857afeae06e"}, @@ -5735,6 +6029,7 @@ version = "0.5.0" description = "A sphinx extension for designing beautiful, view size responsive web components." 
optional = false python-versions = ">=3.8" +groups = ["dev"] files = [ {file = "sphinx_design-0.5.0-py3-none-any.whl", hash = "sha256:1af1267b4cea2eedd6724614f19dcc88fe2e15aff65d06b2f6252cee9c4f4c1e"}, {file = "sphinx_design-0.5.0.tar.gz", hash = "sha256:e8e513acea6f92d15c6de3b34e954458f245b8e761b45b63950f65373352ab00"}, @@ -5758,6 +6053,7 @@ version = "1.0.1" description = "A sphinx extension that allows the site-map to be defined in a single YAML file." optional = false python-versions = ">=3.9" +groups = ["dev"] files = [ {file = "sphinx_external_toc-1.0.1-py3-none-any.whl", hash = "sha256:d9e02d50731dee9697c1887e4f8b361e7b86d38241f0e66bd5a9f4096779646f"}, {file = "sphinx_external_toc-1.0.1.tar.gz", hash = "sha256:a7d2c63cc47ec688546443b28bc4ef466121827ef3dc7bb509de354bad4ea2e0"}, @@ -5779,6 +6075,7 @@ version = "0.2.0.post1" description = "Patches Jinja2 v3 to restore compatibility with earlier Sphinx versions." optional = false python-versions = ">=3.6" +groups = ["dev"] files = [ {file = "sphinx_jinja2_compat-0.2.0.post1-py3-none-any.whl", hash = "sha256:f9d329174bdde8db19dc12c62528367196eb2f6b46c91754eca604acd0c0f6ad"}, {file = "sphinx_jinja2_compat-0.2.0.post1.tar.gz", hash = "sha256:974289a12a9f402108dead621e9c15f7004e945d5cfcaea8d6419e94d3fa95a3"}, @@ -5794,6 +6091,7 @@ version = "1.0.0" description = "Latex specific features for jupyter book" optional = false python-versions = ">=3.9" +groups = ["dev"] files = [ {file = "sphinx_jupyterbook_latex-1.0.0-py3-none-any.whl", hash = "sha256:e0cd3e9e1c5af69136434e21a533343fdf013475c410a414d5b7b4922b4f3891"}, {file = "sphinx_jupyterbook_latex-1.0.0.tar.gz", hash = "sha256:f54c6674c13f1616f9a93443e98b9b5353f9fdda8e39b6ec552ccf0b3e5ffb62"}, @@ -5815,6 +6113,7 @@ version = "0.1.3" description = "Supporting continuous HTML section numbering" optional = false python-versions = "*" +groups = ["dev"] files = [ {file = "sphinx-multitoc-numbering-0.1.3.tar.gz", hash = 
"sha256:c9607671ac511236fa5d61a7491c1031e700e8d498c9d2418e6c61d1251209ae"}, {file = "sphinx_multitoc_numbering-0.1.3-py3-none-any.whl", hash = "sha256:33d2e707a9b2b8ad636b3d4302e658a008025106fe0474046c651144c26d8514"}, @@ -5834,6 +6133,7 @@ version = "1.5.0" description = "Sphinx directive to add unselectable prompt" optional = false python-versions = "*" +groups = ["dev"] files = [ {file = "sphinx_prompt-1.5.0-py3-none-any.whl", hash = "sha256:fa4e90d8088b5a996c76087d701fc7e31175f8b9dc4aab03a507e45051067162"}, ] @@ -5848,6 +6148,7 @@ version = "3.4.5" description = "Tabbed views for Sphinx" optional = false python-versions = "~=3.7" +groups = ["dev"] files = [ {file = "sphinx-tabs-3.4.5.tar.gz", hash = "sha256:ba9d0c1e3e37aaadd4b5678449eb08176770e0fc227e769b6ce747df3ceea531"}, {file = "sphinx_tabs-3.4.5-py3-none-any.whl", hash = "sha256:92cc9473e2ecf1828ca3f6617d0efc0aa8acb06b08c56ba29d1413f2f0f6cf09"}, @@ -5868,6 +6169,7 @@ version = "0.3.1" description = "Integrate interactive code blocks into your documentation with Thebe and Binder." optional = false python-versions = ">=3.8" +groups = ["dev"] files = [ {file = "sphinx_thebe-0.3.1-py3-none-any.whl", hash = "sha256:e7e7edee9f0d601c76bc70156c471e114939484b111dd8e74fe47ac88baffc52"}, {file = "sphinx_thebe-0.3.1.tar.gz", hash = "sha256:576047f45560e82f64aa5f15200b1eb094dcfe1c5b8f531a8a65bd208e25a493"}, @@ -5887,6 +6189,7 @@ version = "0.3.2" description = "Toggle page content and collapse admonitions in Sphinx." 
optional = false python-versions = "*" +groups = ["dev"] files = [ {file = "sphinx-togglebutton-0.3.2.tar.gz", hash = "sha256:ab0c8b366427b01e4c89802d5d078472c427fa6e9d12d521c34fa0442559dc7a"}, {file = "sphinx_togglebutton-0.3.2-py3-none-any.whl", hash = "sha256:9647ba7874b7d1e2d43413d8497153a85edc6ac95a3fea9a75ef9c1e08aaae2b"}, @@ -5907,6 +6210,7 @@ version = "3.5.0" description = "Box of handy tools for Sphinx 🧰 📔" optional = false python-versions = ">=3.7" +groups = ["dev"] files = [ {file = "sphinx_toolbox-3.5.0-py3-none-any.whl", hash = "sha256:20dfd3566717db6f2da7a400a54dc4b946f064fb31250fa44802d54cfb9b8a03"}, {file = "sphinx_toolbox-3.5.0.tar.gz", hash = "sha256:e5b5a7153f1997572d71a06aaf6cec225483492ec2c60097a84f15aad6df18b7"}, @@ -5941,6 +6245,7 @@ version = "1.0.7" description = "sphinxcontrib-applehelp is a Sphinx extension which outputs Apple help books" optional = false python-versions = ">=3.9" +groups = ["dev"] files = [ {file = "sphinxcontrib_applehelp-1.0.7-py3-none-any.whl", hash = "sha256:094c4d56209d1734e7d252f6e0b3ccc090bd52ee56807a5d9315b19c122ab15d"}, {file = "sphinxcontrib_applehelp-1.0.7.tar.gz", hash = "sha256:39fdc8d762d33b01a7d8f026a3b7d71563ea3b72787d5f00ad8465bd9d6dfbfa"}, @@ -5959,6 +6264,7 @@ version = "2.5.0" description = "Sphinx extension for BibTeX style citations." 
optional = false python-versions = ">=3.6" +groups = ["dev"] files = [ {file = "sphinxcontrib-bibtex-2.5.0.tar.gz", hash = "sha256:71b42e5db0e2e284f243875326bf9936aa9a763282277d75048826fef5b00eaa"}, {file = "sphinxcontrib_bibtex-2.5.0-py3-none-any.whl", hash = "sha256:748f726eaca6efff7731012103417ef130ecdcc09501b4d0c54283bf5f059f76"}, @@ -5976,6 +6282,7 @@ version = "1.0.5" description = "sphinxcontrib-devhelp is a sphinx extension which outputs Devhelp documents" optional = false python-versions = ">=3.9" +groups = ["dev"] files = [ {file = "sphinxcontrib_devhelp-1.0.5-py3-none-any.whl", hash = "sha256:fe8009aed765188f08fcaadbb3ea0d90ce8ae2d76710b7e29ea7d047177dae2f"}, {file = "sphinxcontrib_devhelp-1.0.5.tar.gz", hash = "sha256:63b41e0d38207ca40ebbeabcf4d8e51f76c03e78cd61abe118cf4435c73d4212"}, @@ -5994,6 +6301,7 @@ version = "2.0.4" description = "sphinxcontrib-htmlhelp is a sphinx extension which renders HTML help files" optional = false python-versions = ">=3.9" +groups = ["dev"] files = [ {file = "sphinxcontrib_htmlhelp-2.0.4-py3-none-any.whl", hash = "sha256:8001661c077a73c29beaf4a79968d0726103c5605e27db92b9ebed8bab1359e9"}, {file = "sphinxcontrib_htmlhelp-2.0.4.tar.gz", hash = "sha256:6c26a118a05b76000738429b724a0568dbde5b72391a688577da08f11891092a"}, @@ -6012,6 +6320,7 @@ version = "1.0.1" description = "A sphinx extension which renders display math in HTML via JavaScript" optional = false python-versions = ">=3.5" +groups = ["dev"] files = [ {file = "sphinxcontrib-jsmath-1.0.1.tar.gz", hash = "sha256:a9925e4a4587247ed2191a22df5f6970656cb8ca2bd6284309578f2153e0c4b8"}, {file = "sphinxcontrib_jsmath-1.0.1-py2.py3-none-any.whl", hash = "sha256:2ec2eaebfb78f3f2078e73666b1415417a116cc848b72e5172e596c871103178"}, @@ -6026,6 +6335,7 @@ version = "1.0.6" description = "sphinxcontrib-qthelp is a sphinx extension which outputs QtHelp documents" optional = false python-versions = ">=3.9" +groups = ["dev"] files = [ {file = 
"sphinxcontrib_qthelp-1.0.6-py3-none-any.whl", hash = "sha256:bf76886ee7470b934e363da7a954ea2825650013d367728588732c7350f49ea4"}, {file = "sphinxcontrib_qthelp-1.0.6.tar.gz", hash = "sha256:62b9d1a186ab7f5ee3356d906f648cacb7a6bdb94d201ee7adf26db55092982d"}, @@ -6044,6 +6354,7 @@ version = "1.1.9" description = "sphinxcontrib-serializinghtml is a sphinx extension which outputs \"serialized\" HTML files (json and pickle)" optional = false python-versions = ">=3.9" +groups = ["dev"] files = [ {file = "sphinxcontrib_serializinghtml-1.1.9-py3-none-any.whl", hash = "sha256:9b36e503703ff04f20e9675771df105e58aa029cfcbc23b8ed716019b7416ae1"}, {file = "sphinxcontrib_serializinghtml-1.1.9.tar.gz", hash = "sha256:0c64ff898339e1fac29abd2bf5f11078f3ec413cfe9c046d3120d7ca65530b54"}, @@ -6062,6 +6373,7 @@ version = "8.0.0" description = "Sphinx spelling extension" optional = false python-versions = ">=3.7" +groups = ["dev"] files = [ {file = "sphinxcontrib-spelling-8.0.0.tar.gz", hash = "sha256:199d0a16902ad80c387c2966dc9eb10f565b1fb15ccce17210402db7c2443e5c"}, {file = "sphinxcontrib_spelling-8.0.0-py3-none-any.whl", hash = "sha256:b27e0a16aef00bcfc888a6490dc3f16651f901dc475446c6882834278c8dc7b3"}, @@ -6080,6 +6392,7 @@ version = "2.0.23" description = "Database Abstraction Library" optional = false python-versions = ">=3.7" +groups = ["dev"] files = [ {file = "SQLAlchemy-2.0.23-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:638c2c0b6b4661a4fd264f6fb804eccd392745c5887f9317feb64bb7cb03b3ea"}, {file = "SQLAlchemy-2.0.23-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:e3b5036aa326dc2df50cba3c958e29b291a80f604b1afa4c8ce73e78e1c9f01d"}, @@ -6167,6 +6480,7 @@ version = "0.6.3" description = "Extract data from python stack frames and tracebacks for informative displays" optional = false python-versions = "*" +groups = ["dev"] files = [ {file = "stack_data-0.6.3-py3-none-any.whl", hash = "sha256:d5558e0c25a4cb0853cddad3d77da9891a08cb85dd9f9f91b9f8cd66e511e695"}, {file = 
"stack_data-0.6.3.tar.gz", hash = "sha256:836a778de4fec4dcd1dcd89ed8abff8a221f58308462e1c4aa2a3cf30148f0b9"}, @@ -6186,6 +6500,8 @@ version = "0.14.0" description = "Statistical computations and models for Python" optional = true python-versions = ">=3.8" +groups = ["main"] +markers = "extra == \"eval\"" files = [ {file = "statsmodels-0.14.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:16bfe0c96a53b20fa19067e3b6bd2f1d39e30d4891ea0d7bc20734a0ae95942d"}, {file = "statsmodels-0.14.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:5a6a0a1a06ff79be8aa89c8494b33903442859add133f0dda1daf37c3c71682e"}, @@ -6227,7 +6543,7 @@ scipy = ">=1.4,<1.9.2 || >1.9.2" [package.extras] build = ["cython (>=0.29.26)"] -develop = ["colorama", "cython (>=0.29.26)", "cython (>=0.29.28,<3.0.0)", "flake8", "isort", "joblib", "matplotlib (>=3)", "oldest-supported-numpy (>=2022.4.18)", "pytest (>=7.0.1,<7.1.0)", "pytest-randomly", "pytest-xdist", "pywinpty", "setuptools-scm[toml] (>=7.0.0,<7.1.0)"] +develop = ["colorama", "cython (>=0.29.26)", "cython (>=0.29.28,<3.0.0)", "flake8", "isort", "joblib", "matplotlib (>=3)", "oldest-supported-numpy (>=2022.4.18)", "pytest (>=7.0.1,<7.1.0)", "pytest-randomly", "pytest-xdist", "pywinpty ; os_name == \"nt\"", "setuptools-scm[toml] (>=7.0.0,<7.1.0)"] docs = ["ipykernel", "jupyter-client", "matplotlib", "nbconvert", "nbformat", "numpydoc", "pandas-datareader", "sphinx"] [[package]] @@ -6236,6 +6552,8 @@ version = "4.2.0" description = "SWIG is a software development tool that connects programs written in C and C++ with a variety of high-level programming languages." 
optional = true python-versions = "*" +groups = ["main"] +markers = "extra == \"box2d\"" files = [ {file = "swig-4.2.0-py2.py3-none-macosx_10_9_universal2.whl", hash = "sha256:71bf282fb30aa179b870e29c8f4fe16b3404e8562377061f85d57a2ec1571d7c"}, {file = "swig-4.2.0-py2.py3-none-macosx_10_9_x86_64.whl", hash = "sha256:071c7a3af61c2c69d1e911c5428479a4536a8103623276847d8e55350da8cf05"}, @@ -6261,6 +6579,7 @@ version = "1.12" description = "Computer algebra system (CAS) in Python" optional = false python-versions = ">=3.8" +groups = ["main"] files = [ {file = "sympy-1.12-py3-none-any.whl", hash = "sha256:c3588cd4295d0c0f603d0f2ae780587e64e2efeedb3521e46b9bb1d08d184fa5"}, {file = "sympy-1.12.tar.gz", hash = "sha256:ebf595c8dac3e0fdc4152c51878b498396ec7f30e7a914d6071e674d49420fb8"}, @@ -6275,6 +6594,7 @@ version = "0.9.0" description = "Pretty-print tabular data" optional = false python-versions = ">=3.7" +groups = ["dev"] files = [ {file = "tabulate-0.9.0-py3-none-any.whl", hash = "sha256:024ca478df22e9340661486f85298cff5f6dcdba14f3813e8830015b9ed1948f"}, {file = "tabulate-0.9.0.tar.gz", hash = "sha256:0095b12bf5966de529c0feb1fa08671671b3368eec77d7ef7ab114be2c068b3c"}, @@ -6289,6 +6609,7 @@ version = "2.15.1" description = "TensorBoard lets you watch Tensors Flow" optional = false python-versions = ">=3.9" +groups = ["main"] files = [ {file = "tensorboard-2.15.1-py3-none-any.whl", hash = "sha256:c46c1d1cf13a458c429868a78b2531d8ff5f682058d69ec0840b0bc7a38f1c0f"}, ] @@ -6313,6 +6634,7 @@ version = "0.7.2" description = "Fast data loading for TensorBoard" optional = false python-versions = ">=3.7" +groups = ["main"] files = [ {file = "tensorboard_data_server-0.7.2-py3-none-any.whl", hash = "sha256:7e0610d205889588983836ec05dc098e80f97b7e7bbff7e994ebb78f578d0ddb"}, {file = "tensorboard_data_server-0.7.2-py3-none-macosx_10_9_x86_64.whl", hash = "sha256:9fe5d24221b29625dbc7328b0436ca7fc1c23de4acf4d272f1180856e32f9f60"}, @@ -6325,6 +6647,7 @@ version = "0.18.0" description = 
"Tornado websocket backend for the Xterm.js Javascript terminal emulator library." optional = false python-versions = ">=3.8" +groups = ["dev"] files = [ {file = "terminado-0.18.0-py3-none-any.whl", hash = "sha256:87b0d96642d0fe5f5abd7783857b9cab167f221a39ff98e3b9619a788a3c0f2e"}, {file = "terminado-0.18.0.tar.gz", hash = "sha256:1ea08a89b835dd1b8c0c900d92848147cef2537243361b2e3f4dc15df9b6fded"}, @@ -6346,6 +6669,7 @@ version = "1.2.1" description = "A tiny CSS parser" optional = false python-versions = ">=3.7" +groups = ["dev"] files = [ {file = "tinycss2-1.2.1-py3-none-any.whl", hash = "sha256:2b80a96d41e7c3914b8cda8bc7f705a4d9c49275616e886103dd839dfc847847"}, {file = "tinycss2-1.2.1.tar.gz", hash = "sha256:8cff3a8f066c2ec677c06dbc7b45619804a6938478d9d73c284b29d14ecb0627"}, @@ -6364,6 +6688,7 @@ version = "5.2.0" description = "A wrapper around the stdlib `tokenize` which roundtrips." optional = false python-versions = ">=3.8" +groups = ["dev"] files = [ {file = "tokenize_rt-5.2.0-py2.py3-none-any.whl", hash = "sha256:b79d41a65cfec71285433511b50271b05da3584a1da144a0752e9c621a285289"}, {file = "tokenize_rt-5.2.0.tar.gz", hash = "sha256:9fe80f8a5c1edad2d3ede0f37481cc0cc1538a2f442c9c2f9e4feacd2792d054"}, @@ -6375,6 +6700,7 @@ version = "2.0.1" description = "A lil' TOML parser" optional = false python-versions = ">=3.7" +groups = ["dev"] files = [ {file = "tomli-2.0.1-py3-none-any.whl", hash = "sha256:939de3e7a6161af0c887ef91b7d41a53e7c5a1ca976325f429cb46ea9bc30ecc"}, {file = "tomli-2.0.1.tar.gz", hash = "sha256:de526c12914f0c550d15924c62d72abc48d6fe7364aa87328337a31007fe8a4f"}, @@ -6386,6 +6712,7 @@ version = "2.1.1" description = "Tensors and Dynamic neural networks in Python with strong GPU acceleration" optional = false python-versions = ">=3.8.0" +groups = ["main"] files = [ {file = "torch-2.1.1-cp310-cp310-manylinux1_x86_64.whl", hash = "sha256:5ebc43f5355a9b7be813392b3fb0133991f0380f6f0fcc8218d5468dc45d1071"}, {file = 
"torch-2.1.1-cp310-cp310-manylinux2014_aarch64.whl", hash = "sha256:84fefd63356416c0cd20578637ccdbb82164993400ed17b57c951dd6376dcee8"}, @@ -6439,6 +6766,7 @@ version = "6.4.2" description = "Tornado is a Python web framework and asynchronous networking library, originally developed at FriendFeed." optional = false python-versions = ">=3.8" +groups = ["dev"] files = [ {file = "tornado-6.4.2-cp38-abi3-macosx_10_9_universal2.whl", hash = "sha256:e828cce1123e9e44ae2a50a9de3055497ab1d0aeb440c5ac23064d9e44880da1"}, {file = "tornado-6.4.2-cp38-abi3-macosx_10_9_x86_64.whl", hash = "sha256:072ce12ada169c5b00b7d92a99ba089447ccc993ea2143c9ede887e0937aa803"}, @@ -6459,6 +6787,7 @@ version = "4.66.3" description = "Fast, Extensible Progress Meter" optional = false python-versions = ">=3.7" +groups = ["main"] files = [ {file = "tqdm-4.66.3-py3-none-any.whl", hash = "sha256:4f41d54107ff9a223dca80b53efe4fb654c67efaba7f47bada3ee9d50e05bd53"}, {file = "tqdm-4.66.3.tar.gz", hash = "sha256:23097a41eba115ba99ecae40d06444c15d1c0c698d527a01c6c8bd1c5d0647e5"}, @@ -6479,6 +6808,7 @@ version = "5.13.0" description = "Traitlets Python configuration system" optional = false python-versions = ">=3.8" +groups = ["dev"] files = [ {file = "traitlets-5.13.0-py3-none-any.whl", hash = "sha256:baf991e61542da48fe8aef8b779a9ea0aa38d8a54166ee250d5af5ecf4486619"}, {file = "traitlets-5.13.0.tar.gz", hash = "sha256:9b232b9430c8f57288c1024b34a8f0251ddcc47268927367a0dd3eeaca40deb5"}, @@ -6494,6 +6824,8 @@ version = "2.1.0" description = "A language and compiler for custom Deep Learning operations" optional = false python-versions = "*" +groups = ["main"] +markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\"" files = [ {file = "triton-2.1.0-0-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:66439923a30d5d48399b08a9eae10370f6c261a5ec864a64983bae63152d39d7"}, {file = "triton-2.1.0-0-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = 
"sha256:919b06453f0033ea52c13eaf7833de0e57db3178d23d4e04f9fc71c4f2c32bf8"}, @@ -6519,6 +6851,8 @@ version = "4.24.0.4" description = "Typing stubs for protobuf" optional = true python-versions = ">=3.7" +groups = ["main"] +markers = "sys_platform != \"darwin\" and extra == \"envpool\"" files = [ {file = "types-protobuf-4.24.0.4.tar.gz", hash = "sha256:57ab42cb171dfdba2c74bb5b50c250478538cc3c5ed95b8b368929ad0c9f90a5"}, {file = "types_protobuf-4.24.0.4-py3-none-any.whl", hash = "sha256:131ab7d0cbc9e444bc89c994141327dcce7bcaeded72b1acb72a94827eb9c7af"}, @@ -6530,6 +6864,7 @@ version = "2.8.19.14" description = "Typing stubs for python-dateutil" optional = false python-versions = "*" +groups = ["dev"] files = [ {file = "types-python-dateutil-2.8.19.14.tar.gz", hash = "sha256:1f4f10ac98bb8b16ade9dbee3518d9ace017821d94b057a425b069f834737f4b"}, {file = "types_python_dateutil-2.8.19.14-py3-none-any.whl", hash = "sha256:f977b8de27787639986b4e28963263fd0e5158942b3ecef91b9335c130cb1ce9"}, @@ -6541,6 +6876,7 @@ version = "2.31.0.20240311" description = "Typing stubs for requests" optional = false python-versions = ">=3.8" +groups = ["dev"] files = [ {file = "types-requests-2.31.0.20240311.tar.gz", hash = "sha256:b1c1b66abfb7fa79aae09097a811c4aa97130eb8831c60e47aee4ca344731ca5"}, {file = "types_requests-2.31.0.20240311-py3-none-any.whl", hash = "sha256:47872893d65a38e282ee9f277a4ee50d1b28bd592040df7d1fdaffdf3779937d"}, @@ -6555,6 +6891,7 @@ version = "0.9.0.20240106" description = "Typing stubs for tabulate" optional = false python-versions = ">=3.8" +groups = ["dev"] files = [ {file = "types-tabulate-0.9.0.20240106.tar.gz", hash = "sha256:c9b6db10dd7fcf55bd1712dd3537f86ddce72a08fd62bb1af4338c7096ce947e"}, {file = "types_tabulate-0.9.0.20240106-py3-none-any.whl", hash = "sha256:0378b7b6fe0ccb4986299496d027a6d4c218298ecad67199bbd0e2d7e9d335a1"}, @@ -6566,6 +6903,7 @@ version = "4.8.0" description = "Backported and Experimental Type Hints for Python 3.8+" optional = false 
python-versions = ">=3.8" +groups = ["main", "dev"] files = [ {file = "typing_extensions-4.8.0-py3-none-any.whl", hash = "sha256:8f92fc8806f9a6b641eaa5318da32b44d401efaac0f6678c9bc448ba3605faa0"}, {file = "typing_extensions-4.8.0.tar.gz", hash = "sha256:df8e4339e9cb77357558cbdbceca33c303714cf861d1eef15e1070055ae8b7ef"}, @@ -6577,6 +6915,7 @@ version = "2024.1" description = "Provider of IANA time zone data" optional = false python-versions = ">=2" +groups = ["main"] files = [ {file = "tzdata-2024.1-py2.py3-none-any.whl", hash = "sha256:9068bc196136463f5245e51efda838afa15aaeca9903f49050dfa2679db4d252"}, {file = "tzdata-2024.1.tar.gz", hash = "sha256:2674120f8d891909751c38abcdfd386ac0a5a1127954fbc332af6b5ceae07efd"}, @@ -6588,6 +6927,7 @@ version = "1.0.2" description = "Micro subset of unicode data files for linkify-it-py projects." optional = false python-versions = ">=3.7" +groups = ["dev"] files = [ {file = "uc-micro-py-1.0.2.tar.gz", hash = "sha256:30ae2ac9c49f39ac6dce743bd187fcd2b574b16ca095fa74cd9396795c954c54"}, {file = "uc_micro_py-1.0.2-py3-none-any.whl", hash = "sha256:8c9110c309db9d9e87302e2f4ad2c3152770930d88ab385cd544e7a7e75f3de0"}, @@ -6602,6 +6942,7 @@ version = "1.3.0" description = "RFC 6570 URI Template Processor" optional = false python-versions = ">=3.7" +groups = ["dev"] files = [ {file = "uri-template-1.3.0.tar.gz", hash = "sha256:0e00f8eb65e18c7de20d595a14336e9f337ead580c70934141624b6d1ffdacc7"}, {file = "uri_template-1.3.0-py3-none-any.whl", hash = "sha256:a44a133ea12d44a0c0f06d7d42a52d71282e77e2f937d8abd5655b8d56fc1363"}, @@ -6616,13 +6957,14 @@ version = "2.2.2" description = "HTTP library with thread-safe connection pooling, file post, and more." 
optional = false python-versions = ">=3.8" +groups = ["main", "dev"] files = [ {file = "urllib3-2.2.2-py3-none-any.whl", hash = "sha256:a448b2f64d686155468037e1ace9f2d2199776e17f0a46610480d311f73e3472"}, {file = "urllib3-2.2.2.tar.gz", hash = "sha256:dd505485549a7a552833da5e6063639d0d177c04f23bc3864e41e5dc5f612168"}, ] [package.extras] -brotli = ["brotli (>=1.0.9)", "brotlicffi (>=0.8.0)"] +brotli = ["brotli (>=1.0.9) ; platform_python_implementation == \"CPython\"", "brotlicffi (>=0.8.0) ; platform_python_implementation != \"CPython\""] h2 = ["h2 (>=4,<5)"] socks = ["pysocks (>=1.5.6,!=1.5.7,<2.0)"] zstd = ["zstandard (>=0.18.0)"] @@ -6633,6 +6975,8 @@ version = "20.16.3" description = "Virtual Python Environment builder" optional = false python-versions = ">=3.6" +groups = ["main", "dev"] +markers = "sys_platform == \"win32\"" files = [ {file = "virtualenv-20.16.3-py2.py3-none-any.whl", hash = "sha256:4193b7bc8a6cd23e4eb251ac64f29b4398ab2c233531e66e40b19a6b7b0d30c1"}, {file = "virtualenv-20.16.3.tar.gz", hash = "sha256:d86ea0bb50e06252d79e6c241507cb904fcd66090c3271381372d6221a3970f9"}, @@ -6653,6 +6997,8 @@ version = "20.24.6" description = "Virtual Python Environment builder" optional = false python-versions = ">=3.7" +groups = ["main", "dev"] +markers = "sys_platform != \"win32\"" files = [ {file = "virtualenv-20.24.6-py3-none-any.whl", hash = "sha256:520d056652454c5098a00c0f073611ccbea4c79089331f60bf9d7ba247bb7381"}, {file = "virtualenv-20.24.6.tar.gz", hash = "sha256:02ece4f56fbf939dbbc33c0715159951d6bf14aaf5457b092e4548e1382455af"}, @@ -6665,7 +7011,7 @@ platformdirs = ">=3.9.1,<4" [package.extras] docs = ["furo (>=2023.7.26)", "proselint (>=0.13)", "sphinx (>=7.1.2)", "sphinx-argparse (>=0.4)", "sphinxcontrib-towncrier (>=0.2.1a0)", "towncrier (>=23.6)"] -test = ["covdefaults (>=2.3)", "coverage (>=7.2.7)", "coverage-enable-subprocess (>=1)", "flaky (>=3.7)", "packaging (>=23.1)", "pytest (>=7.4)", "pytest-env (>=0.8.2)", "pytest-freezer (>=0.4.8)", 
"pytest-mock (>=3.11.1)", "pytest-randomly (>=3.12)", "pytest-timeout (>=2.1)", "setuptools (>=68)", "time-machine (>=2.10)"] +test = ["covdefaults (>=2.3)", "coverage (>=7.2.7)", "coverage-enable-subprocess (>=1)", "flaky (>=3.7)", "packaging (>=23.1)", "pytest (>=7.4)", "pytest-env (>=0.8.2)", "pytest-freezer (>=0.4.8) ; platform_python_implementation == \"PyPy\"", "pytest-mock (>=3.11.1)", "pytest-randomly (>=3.12)", "pytest-timeout (>=2.1)", "setuptools (>=68)", "time-machine (>=2.10) ; platform_python_implementation == \"CPython\""] [[package]] name = "vizdoom" @@ -6673,6 +7019,8 @@ version = "1.2.2" description = "ViZDoom is Doom-based AI Research Platform for Reinforcement Learning from Raw Visual Information." optional = true python-versions = "*" +groups = ["main"] +markers = "extra == \"vizdoom\"" files = [ {file = "vizdoom-1.2.2-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:3e2f478e1728702f17b828de0e7ee6bf0e2809c1786ce21f69ce00e4a4da82e0"}, {file = "vizdoom-1.2.2-cp310-cp310-manylinux_2_28_aarch64.whl", hash = "sha256:49180ed13d30109bcd99b38e6b923c5bd74e6bb364add8d46beb5cdf7405fe10"}, @@ -6708,6 +7056,7 @@ version = "0.12.21" description = "A CLI and library for interacting with the Weights and Biases API." 
optional = false python-versions = ">=3.6" +groups = ["dev"] files = [ {file = "wandb-0.12.21-py2.py3-none-any.whl", hash = "sha256:150842447d355d90dc7f368b824951a625e5b2d1be355a00e99b11b73728bc1f"}, {file = "wandb-0.12.21.tar.gz", hash = "sha256:1975ff88c5024923c3321c93cfefb8d9b871543c0b009f34001bf0f31e444b04"}, @@ -6746,6 +7095,7 @@ version = "0.2.10" description = "Measures the displayed width of unicode strings in a terminal" optional = false python-versions = "*" +groups = ["dev"] files = [ {file = "wcwidth-0.2.10-py2.py3-none-any.whl", hash = "sha256:aec5179002dd0f0d40c456026e74a729661c9d468e1ed64405e3a6c2176ca36f"}, {file = "wcwidth-0.2.10.tar.gz", hash = "sha256:390c7454101092a6a5e43baad8f83de615463af459201709556b6e4b1c861f97"}, @@ -6757,6 +7107,7 @@ version = "1.13" description = "A library for working with the color formats defined by HTML and CSS." optional = false python-versions = ">=3.7" +groups = ["dev"] files = [ {file = "webcolors-1.13-py3-none-any.whl", hash = "sha256:29bc7e8752c0a1bd4a1f03c14d6e6a72e93d82193738fa860cbff59d0fcc11bf"}, {file = "webcolors-1.13.tar.gz", hash = "sha256:c225b674c83fa923be93d235330ce0300373d02885cef23238813b0d5668304a"}, @@ -6772,6 +7123,7 @@ version = "0.5.1" description = "Character encoding aliases for legacy web content" optional = false python-versions = "*" +groups = ["dev"] files = [ {file = "webencodings-0.5.1-py2.py3-none-any.whl", hash = "sha256:a0af1213f3c2226497a97e2b3aa01a7e4bee4f403f95be16fc9acd2947514a78"}, {file = "webencodings-0.5.1.tar.gz", hash = "sha256:b36a1c245f2d304965eb4e0a82848379241dc04b865afcc4aab16748587e1923"}, @@ -6783,6 +7135,7 @@ version = "1.6.4" description = "WebSocket client for Python with low level API options" optional = false python-versions = ">=3.8" +groups = ["dev"] files = [ {file = "websocket-client-1.6.4.tar.gz", hash = "sha256:b3324019b3c28572086c4a319f91d1dcd44e6e11cd340232978c684a7650d0df"}, {file = "websocket_client-1.6.4-py3-none-any.whl", hash = 
"sha256:084072e0a7f5f347ef2ac3d8698a5e0b4ffbfcab607628cadabc650fc9a83a24"}, @@ -6799,6 +7152,7 @@ version = "3.0.6" description = "The comprehensive WSGI web application library." optional = false python-versions = ">=3.8" +groups = ["main"] files = [ {file = "werkzeug-3.0.6-py3-none-any.whl", hash = "sha256:1bc0c2310d2fbb07b1dd1105eba2f7af72f322e1e455f2f93c993bee8c8a5f17"}, {file = "werkzeug-3.0.6.tar.gz", hash = "sha256:a8dd59d4de28ca70471a34cba79bed5f7ef2e036a76b3ab0835474246eb41f8d"}, @@ -6816,6 +7170,7 @@ version = "0.41.3" description = "A built-package format for Python" optional = false python-versions = ">=3.7" +groups = ["dev"] files = [ {file = "wheel-0.41.3-py3-none-any.whl", hash = "sha256:488609bc63a29322326e05560731bf7bfea8e48ad646e1f5e40d366607de0942"}, {file = "wheel-0.41.3.tar.gz", hash = "sha256:4d4987ce51a49370ea65c0bfd2234e8ce80a12780820d9dc462597a6e60d0841"}, @@ -6830,6 +7185,7 @@ version = "4.0.9" description = "Jupyter interactive widgets for Jupyter Notebook" optional = false python-versions = ">=3.7" +groups = ["dev"] files = [ {file = "widgetsnbextension-4.0.9-py3-none-any.whl", hash = "sha256:91452ca8445beb805792f206e560c1769284267a30ceb1cec9f5bcc887d15175"}, {file = "widgetsnbextension-4.0.9.tar.gz", hash = "sha256:3c1f5e46dc1166dfd40a42d685e6a51396fd34ff878742a3e47c6f0cc4a2a385"}, @@ -6841,6 +7197,7 @@ version = "3.19.1" description = "Backport of pathlib-compatible object wrapper for zip files" optional = false python-versions = ">=3.8" +groups = ["dev"] files = [ {file = "zipp-3.19.1-py3-none-any.whl", hash = "sha256:2828e64edb5386ea6a52e7ba7cdb17bb30a73a858f5eb6eb93d8d36f5ea26091"}, {file = "zipp-3.19.1.tar.gz", hash = "sha256:35427f6d5594f4acf82d25541438348c26736fa9b3afa2754bcd63cdb99d8e8f"}, @@ -6863,6 +7220,6 @@ robotics = ["gymnasium-robotics"] vizdoom = ["vizdoom"] [metadata] -lock-version = "2.0" +lock-version = "2.1" python-versions = "^3.11" content-hash = "575f58bac92d215908d074f946b8593cbefaf83f965beed396253d8d3f38eea7" 
diff --git a/test/base/test_policy.py b/test/base/test_policy.py index 559c685e2..aba2c7f1e 100644 --- a/test/base/test_policy.py +++ b/test/base/test_policy.py @@ -9,7 +9,7 @@ from tianshou.policy.base import RandomActionPolicy, episode_mc_return_to_go from tianshou.policy.modelfree.pg import ActorPolicyProbabilistic from tianshou.policy.optim import AdamOptimizerFactory -from tianshou.utils.net.common import Net +from tianshou.utils.net.common import MLPActor from tianshou.utils.net.continuous import ContinuousActorProbabilistic, ContinuousCritic from tianshou.utils.net.discrete import DiscreteActor @@ -35,7 +35,7 @@ def algorithm(request: pytest.FixtureRequest) -> PPO: if action_type == "continuous": action_space = gym.spaces.Box(low=-1, high=1, shape=(3,)) actor = ContinuousActorProbabilistic( - preprocess_net=Net( + preprocess_net=MLPActor( state_shape=obs_shape, hidden_sizes=[64, 64], action_shape=action_space.shape ), action_shape=action_space.shape, @@ -48,7 +48,7 @@ def dist_fn(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution: elif action_type == "discrete": action_space = gym.spaces.Discrete(3) actor = DiscreteActor( - preprocess_net=Net( + preprocess_net=MLPActor( state_shape=obs_shape, hidden_sizes=[64, 64], action_shape=action_space.n ), action_shape=action_space.n, @@ -58,7 +58,7 @@ def dist_fn(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution: raise ValueError(f"Unknown action type: {action_type}") critic = ContinuousCritic( - preprocess_net=Net(state_shape=obs_shape, hidden_sizes=[64, 64]), + preprocess_net=MLPActor(state_shape=obs_shape, hidden_sizes=[64, 64]), ) optim = AdamOptimizerFactory(lr=1e-3) diff --git a/test/base/test_utils.py b/test/base/test_utils.py index 6c992a165..cc448e5da 100644 --- a/test/base/test_utils.py +++ b/test/base/test_utils.py @@ -9,7 +9,7 @@ from tianshou.exploration import GaussianNoise, OUNoise from tianshou.utils import MovAvg, RunningMeanStd -from tianshou.utils.net.common import MLP, Net 
+from tianshou.utils.net.common import MLP, MLPActor from tianshou.utils.net.continuous import RecurrentActorProb, RecurrentCritic from tianshou.utils.torch_utils import create_uniform_action_dist, torch_train_mode @@ -62,7 +62,7 @@ def test_net() -> None: action_shape = (5,) data = torch.rand([bsz, *state_shape]) expect_output_shape = [bsz, *action_shape] - net = Net( + net = MLPActor( state_shape=state_shape, action_shape=action_shape, hidden_sizes=[128, 128], @@ -73,7 +73,7 @@ def test_net() -> None: assert str(net).count("LayerNorm") == 2 assert str(net).count("ReLU") == 0 Q_param = V_param = {"hidden_sizes": [128, 128]} - net = Net( + net = MLPActor( state_shape=state_shape, action_shape=action_shape, hidden_sizes=[128, 128], @@ -81,11 +81,13 @@ def test_net() -> None: ) assert list(net(data)[0].shape) == expect_output_shape # concat - net = Net(state_shape=state_shape, action_shape=action_shape, hidden_sizes=[128], concat=True) + net = MLPActor( + state_shape=state_shape, action_shape=action_shape, hidden_sizes=[128], concat=True + ) data = torch.rand([bsz, int(np.prod(state_shape)) + int(np.prod(action_shape))]) expect_output_shape = [bsz, 128] assert list(net(data)[0].shape) == expect_output_shape - net = Net( + net = MLPActor( state_shape=state_shape, action_shape=action_shape, hidden_sizes=[128], diff --git a/test/continuous/test_ddpg.py b/test/continuous/test_ddpg.py index 518d53ef5..4c8fc1393 100644 --- a/test/continuous/test_ddpg.py +++ b/test/continuous/test_ddpg.py @@ -16,7 +16,7 @@ from tianshou.policy.optim import AdamOptimizerFactory from tianshou.trainer.base import OffPolicyTrainerParams from tianshou.utils import TensorboardLogger -from tianshou.utils.net.common import Net +from tianshou.utils.net.common import MLPActor from tianshou.utils.net.continuous import ContinuousActorDeterministic, ContinuousCritic from tianshou.utils.space_info import SpaceInfo @@ -74,13 +74,13 @@ def test_ddpg(args: argparse.Namespace = get_args(), enable_assertions: 
bool = T test_envs.seed(args.seed) # model - net = Net(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes) + net = MLPActor(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes) actor = ContinuousActorDeterministic( preprocess_net=net, action_shape=args.action_shape, max_action=args.max_action ).to( args.device, ) - net = Net( + net = MLPActor( state_shape=args.state_shape, action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, diff --git a/test/continuous/test_npg.py b/test/continuous/test_npg.py index 35f6b58f9..29c2962db 100644 --- a/test/continuous/test_npg.py +++ b/test/continuous/test_npg.py @@ -17,7 +17,7 @@ from tianshou.policy.optim import AdamOptimizerFactory from tianshou.trainer.base import OnPolicyTrainerParams from tianshou.utils import TensorboardLogger -from tianshou.utils.net.common import Net +from tianshou.utils.net.common import MLPActor from tianshou.utils.net.continuous import ContinuousActorProbabilistic, ContinuousCritic from tianshou.utils.space_info import SpaceInfo @@ -78,7 +78,7 @@ def test_npg(args: argparse.Namespace = get_args(), enable_assertions: bool = Tr test_envs.seed(args.seed) # model - net = Net( + net = MLPActor( state_shape=args.state_shape, hidden_sizes=args.hidden_sizes, activation=nn.Tanh, @@ -87,7 +87,7 @@ def test_npg(args: argparse.Namespace = get_args(), enable_assertions: bool = Tr preprocess_net=net, action_shape=args.action_shape, unbounded=True ).to(args.device) critic = ContinuousCritic( - preprocess_net=Net( + preprocess_net=MLPActor( state_shape=args.state_shape, hidden_sizes=args.hidden_sizes, activation=nn.Tanh, diff --git a/test/continuous/test_ppo.py b/test/continuous/test_ppo.py index c265d2955..c9701609f 100644 --- a/test/continuous/test_ppo.py +++ b/test/continuous/test_ppo.py @@ -16,7 +16,7 @@ from tianshou.policy.optim import AdamOptimizerFactory from tianshou.trainer.base import OnPolicyTrainerParams from tianshou.utils import TensorboardLogger -from 
tianshou.utils.net.common import ActorCritic, Net +from tianshou.utils.net.common import ActorCritic, MLPActor from tianshou.utils.net.continuous import ContinuousActorProbabilistic, ContinuousCritic from tianshou.utils.space_info import SpaceInfo @@ -84,12 +84,12 @@ def test_ppo(args: argparse.Namespace = get_args(), enable_assertions: bool = Tr test_envs.seed(args.seed) # model - net = Net(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes) + net = MLPActor(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes) actor = ContinuousActorProbabilistic( preprocess_net=net, action_shape=args.action_shape, unbounded=True ).to(args.device) critic = ContinuousCritic( - preprocess_net=Net(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes), + preprocess_net=MLPActor(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes), ).to(args.device) actor_critic = ActorCritic(actor, critic) # orthogonal initialization diff --git a/test/continuous/test_redq.py b/test/continuous/test_redq.py index 1662f2abe..e0e932d90 100644 --- a/test/continuous/test_redq.py +++ b/test/continuous/test_redq.py @@ -17,7 +17,7 @@ from tianshou.policy.optim import AdamOptimizerFactory from tianshou.trainer.base import OffPolicyTrainerParams from tianshou.utils import TensorboardLogger -from tianshou.utils.net.common import EnsembleLinear, Net +from tianshou.utils.net.common import EnsembleLinear, MLPActor from tianshou.utils.net.continuous import ContinuousActorProbabilistic, ContinuousCritic from tianshou.utils.space_info import SpaceInfo @@ -82,7 +82,7 @@ def test_redq(args: argparse.Namespace = get_args(), enable_assertions: bool = T train_envs.seed(args.seed) test_envs.seed(args.seed) # model - net = Net(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes) + net = MLPActor(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes) actor = ContinuousActorProbabilistic( preprocess_net=net, action_shape=args.action_shape, @@ -94,7 +94,7 @@ def test_redq(args: 
argparse.Namespace = get_args(), enable_assertions: bool = T def linear(x: int, y: int) -> nn.Module: return EnsembleLinear(args.ensemble_size, x, y) - net_c = Net( + net_c = MLPActor( state_shape=args.state_shape, action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, diff --git a/test/continuous/test_sac_with_il.py b/test/continuous/test_sac_with_il.py index b2d9e7f77..c698668aa 100644 --- a/test/continuous/test_sac_with_il.py +++ b/test/continuous/test_sac_with_il.py @@ -16,7 +16,7 @@ from tianshou.policy.optim import AdamOptimizerFactory from tianshou.trainer.base import OffPolicyTrainerParams from tianshou.utils import TensorboardLogger -from tianshou.utils.net.common import Net +from tianshou.utils.net.common import MLPActor from tianshou.utils.net.continuous import ( ContinuousActorDeterministic, ContinuousActorProbabilistic, @@ -94,12 +94,12 @@ def test_sac_with_il( test_envs.seed(args.seed + args.training_num) # model - net = Net(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes) + net = MLPActor(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes) actor = ContinuousActorProbabilistic( preprocess_net=net, action_shape=args.action_shape, unbounded=True ).to(args.device) actor_optim = AdamOptimizerFactory(lr=args.actor_lr) - net_c1 = Net( + net_c1 = MLPActor( state_shape=args.state_shape, action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, @@ -107,7 +107,7 @@ def test_sac_with_il( ) critic1 = ContinuousCritic(preprocess_net=net_c1).to(args.device) critic1_optim = AdamOptimizerFactory(lr=args.critic_lr) - net_c2 = Net( + net_c2 = MLPActor( state_shape=args.state_shape, action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, @@ -184,7 +184,7 @@ def stop_fn(mean_rewards: float) -> bool: # here we define an imitation collector with a trivial policy if args.task.startswith("Pendulum"): args.reward_threshold -= 50 # lower the goal - il_net = Net( + il_net = MLPActor( state_shape=args.state_shape, 
hidden_sizes=args.imitation_hidden_sizes, ) diff --git a/test/continuous/test_td3.py b/test/continuous/test_td3.py index 3567f7668..a34478eac 100644 --- a/test/continuous/test_td3.py +++ b/test/continuous/test_td3.py @@ -16,7 +16,7 @@ from tianshou.policy.optim import AdamOptimizerFactory from tianshou.trainer.base import OffPolicyTrainerParams from tianshou.utils import TensorboardLogger -from tianshou.utils.net.common import Net +from tianshou.utils.net.common import MLPActor from tianshou.utils.net.continuous import ContinuousActorDeterministic, ContinuousCritic from tianshou.utils.space_info import SpaceInfo @@ -77,14 +77,14 @@ def test_td3(args: argparse.Namespace = get_args(), enable_assertions: bool = Tr train_envs.seed(args.seed) test_envs.seed(args.seed) # model - net = Net(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes) + net = MLPActor(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes) actor = ContinuousActorDeterministic( preprocess_net=net, action_shape=args.action_shape, max_action=args.max_action ).to( args.device, ) actor_optim = AdamOptimizerFactory(lr=args.actor_lr) - net_c1 = Net( + net_c1 = MLPActor( state_shape=args.state_shape, action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, @@ -92,7 +92,7 @@ def test_td3(args: argparse.Namespace = get_args(), enable_assertions: bool = Tr ) critic1 = ContinuousCritic(preprocess_net=net_c1).to(args.device) critic1_optim = AdamOptimizerFactory(lr=args.critic_lr) - net_c2 = Net( + net_c2 = MLPActor( state_shape=args.state_shape, action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, diff --git a/test/continuous/test_trpo.py b/test/continuous/test_trpo.py index 658df9efe..ae938d3c9 100644 --- a/test/continuous/test_trpo.py +++ b/test/continuous/test_trpo.py @@ -17,7 +17,7 @@ from tianshou.policy.optim import AdamOptimizerFactory from tianshou.trainer.base import OnPolicyTrainerParams from tianshou.utils import TensorboardLogger -from tianshou.utils.net.common 
import Net +from tianshou.utils.net.common import MLPActor from tianshou.utils.net.continuous import ContinuousActorProbabilistic, ContinuousCritic from tianshou.utils.space_info import SpaceInfo @@ -79,7 +79,7 @@ def test_trpo(args: argparse.Namespace = get_args(), enable_assertions: bool = T train_envs.seed(args.seed) test_envs.seed(args.seed) # model - net = Net( + net = MLPActor( state_shape=args.state_shape, hidden_sizes=args.hidden_sizes, activation=nn.Tanh, @@ -88,7 +88,7 @@ def test_trpo(args: argparse.Namespace = get_args(), enable_assertions: bool = T preprocess_net=net, action_shape=args.action_shape, unbounded=True ).to(args.device) critic = ContinuousCritic( - preprocess_net=Net( + preprocess_net=MLPActor( state_shape=args.state_shape, hidden_sizes=args.hidden_sizes, activation=nn.Tanh, diff --git a/test/determinism_test.py b/test/determinism_test.py index 5cf3b8773..fa6e9babe 100644 --- a/test/determinism_test.py +++ b/test/determinism_test.py @@ -39,7 +39,7 @@ class AlgorithmDeterminismTest: 3. Inspect determinism_tests.log """ - ENABLED = False + ENABLED = True """ whether determinism tests are enabled. 
""" diff --git a/test/discrete/test_a2c_with_il.py b/test/discrete/test_a2c_with_il.py index 862f10450..cf12c008b 100644 --- a/test/discrete/test_a2c_with_il.py +++ b/test/discrete/test_a2c_with_il.py @@ -17,7 +17,7 @@ from tianshou.policy.optim import AdamOptimizerFactory from tianshou.trainer.base import OffPolicyTrainerParams, OnPolicyTrainerParams from tianshou.utils import TensorboardLogger -from tianshou.utils.net.common import Net +from tianshou.utils.net.common import MLPActor from tianshou.utils.net.discrete import DiscreteActor, DiscreteCritic try: @@ -97,7 +97,7 @@ def test_a2c_with_il( default_reward_threshold = {"CartPole-v1": 195} args.reward_threshold = default_reward_threshold.get(args.task, env.spec.reward_threshold) # model - net = Net(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes) + net = MLPActor(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes) actor = DiscreteActor(preprocess_net=net, action_shape=args.action_shape).to(args.device) critic = DiscreteCritic(preprocess_net=net).to(args.device) optim = AdamOptimizerFactory(lr=args.lr) @@ -167,7 +167,7 @@ def stop_fn(mean_rewards: float) -> bool: # here we define an imitation collector with a trivial policy # if args.task == 'CartPole-v1': # env.spec.reward_threshold = 190 # lower the goal - net = Net(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes) + net = MLPActor(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes) actor = DiscreteActor(preprocess_net=net, action_shape=args.action_shape).to(args.device) optim = AdamOptimizerFactory(lr=args.il_lr) il_policy = ImitationPolicy( diff --git a/test/discrete/test_bdqn.py b/test/discrete/test_bdqn.py index 33b3c4cd6..1486aec13 100644 --- a/test/discrete/test_bdqn.py +++ b/test/discrete/test_bdqn.py @@ -11,7 +11,7 @@ from tianshou.policy.modelfree.bdqn import BDQNPolicy from tianshou.policy.optim import AdamOptimizerFactory from tianshou.trainer.base import OffPolicyTrainerParams -from 
tianshou.utils.net.common import BranchingNet +from tianshou.utils.net.common import BranchingActor from tianshou.utils.torch_utils import policy_within_training_step @@ -92,7 +92,7 @@ def test_bdq(args: argparse.Namespace = get_args(), enable_assertions: bool = Tr train_envs.seed(args.seed) test_envs.seed(args.seed) # model - net = BranchingNet( + net = BranchingActor( state_shape=args.state_shape, num_branches=args.num_branches, action_per_branch=args.action_per_branch, diff --git a/test/discrete/test_c51.py b/test/discrete/test_c51.py index ec291da18..cd865e8aa 100644 --- a/test/discrete/test_c51.py +++ b/test/discrete/test_c51.py @@ -22,7 +22,7 @@ from tianshou.policy.optim import AdamOptimizerFactory from tianshou.trainer.base import OffPolicyTrainerParams from tianshou.utils import TensorboardLogger -from tianshou.utils.net.common import Net +from tianshou.utils.net.common import MLPActor from tianshou.utils.space_info import SpaceInfo from tianshou.utils.torch_utils import policy_within_training_step @@ -87,7 +87,7 @@ def test_c51(args: argparse.Namespace = get_args(), enable_assertions: bool = Tr test_envs.seed(args.seed) # model - net = Net( + net = MLPActor( state_shape=args.state_shape, action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, diff --git a/test/discrete/test_discrete_sac.py b/test/discrete/test_discrete_sac.py index ed1df42f2..b4d16422f 100644 --- a/test/discrete/test_discrete_sac.py +++ b/test/discrete/test_discrete_sac.py @@ -18,7 +18,7 @@ from tianshou.policy.optim import AdamOptimizerFactory from tianshou.trainer.base import OffPolicyTrainerParams from tianshou.utils import TensorboardLogger -from tianshou.utils.net.common import Net +from tianshou.utils.net.common import MLPActor from tianshou.utils.net.discrete import DiscreteActor, DiscreteCritic from tianshou.utils.space_info import SpaceInfo @@ -83,15 +83,15 @@ def test_discrete_sac( # model obs_dim = space_info.observation_info.obs_dim action_dim = 
space_info.action_info.action_dim - net = Net(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes) + net = MLPActor(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes) actor = DiscreteActor( preprocess_net=net, action_shape=args.action_shape, softmax_output=False ).to(args.device) actor_optim = AdamOptimizerFactory(lr=args.actor_lr) - net_c1 = Net(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes) + net_c1 = MLPActor(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes) critic1 = DiscreteCritic(preprocess_net=net_c1, last_size=action_dim).to(args.device) critic1_optim = AdamOptimizerFactory(lr=args.critic_lr) - net_c2 = Net(state_shape=obs_dim, hidden_sizes=args.hidden_sizes) + net_c2 = MLPActor(state_shape=obs_dim, hidden_sizes=args.hidden_sizes) critic2 = DiscreteCritic(preprocess_net=net_c2, last_size=action_dim).to(args.device) critic2_optim = AdamOptimizerFactory(lr=args.critic_lr) diff --git a/test/discrete/test_dqn.py b/test/discrete/test_dqn.py index cf002bbd2..7eab90569 100644 --- a/test/discrete/test_dqn.py +++ b/test/discrete/test_dqn.py @@ -21,7 +21,7 @@ from tianshou.policy.optim import AdamOptimizerFactory from tianshou.trainer.base import OffPolicyTrainerParams from tianshou.utils import TensorboardLogger -from tianshou.utils.net.common import Net +from tianshou.utils.net.common import MLPActor from tianshou.utils.space_info import SpaceInfo from tianshou.utils.torch_utils import policy_within_training_step @@ -82,7 +82,7 @@ def test_dqn(args: argparse.Namespace = get_args(), enable_assertions: bool = Tr # Q_param = V_param = {"hidden_sizes": [128]} # model - net = Net( + net = MLPActor( state_shape=args.state_shape, action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, diff --git a/test/discrete/test_fqf.py b/test/discrete/test_fqf.py index 5b43866aa..98a1d8ab0 100644 --- a/test/discrete/test_fqf.py +++ b/test/discrete/test_fqf.py @@ -21,7 +21,7 @@ from tianshou.policy.optim import 
AdamOptimizerFactory, RMSpropOptimizerFactory from tianshou.trainer.base import OffPolicyTrainerParams from tianshou.utils import TensorboardLogger -from tianshou.utils.net.common import Net +from tianshou.utils.net.common import MLPActor from tianshou.utils.net.discrete import FractionProposalNetwork, FullQuantileFunction from tianshou.utils.space_info import SpaceInfo from tianshou.utils.torch_utils import policy_within_training_step @@ -86,7 +86,7 @@ def test_fqf(args: argparse.Namespace = get_args(), enable_assertions: bool = Tr test_envs.seed(args.seed) # model - feature_net = Net( + feature_net = MLPActor( state_shape=args.state_shape, action_shape=args.hidden_sizes[-1], hidden_sizes=args.hidden_sizes[:-1], diff --git a/test/discrete/test_iqn.py b/test/discrete/test_iqn.py index a6932309b..70b7abd80 100644 --- a/test/discrete/test_iqn.py +++ b/test/discrete/test_iqn.py @@ -21,7 +21,7 @@ from tianshou.policy.optim import AdamOptimizerFactory from tianshou.trainer.base import OffPolicyTrainerParams from tianshou.utils import TensorboardLogger -from tianshou.utils.net.common import Net +from tianshou.utils.net.common import MLPActor from tianshou.utils.net.discrete import ImplicitQuantileNetwork from tianshou.utils.space_info import SpaceInfo from tianshou.utils.torch_utils import policy_within_training_step @@ -86,7 +86,7 @@ def test_iqn(args: argparse.Namespace = get_args(), enable_assertions: bool = Tr test_envs.seed(args.seed) # model - feature_net = Net( + feature_net = MLPActor( state_shape=args.state_shape, action_shape=args.hidden_sizes[-1], hidden_sizes=args.hidden_sizes[:-1], diff --git a/test/discrete/test_pg.py b/test/discrete/test_pg.py index 9893f0dc5..01aa6d4a3 100644 --- a/test/discrete/test_pg.py +++ b/test/discrete/test_pg.py @@ -16,7 +16,7 @@ from tianshou.policy.optim import AdamOptimizerFactory from tianshou.trainer.base import OnPolicyTrainerParams from tianshou.utils import TensorboardLogger -from tianshou.utils.net.common import Net +from 
tianshou.utils.net.common import MLPActor from tianshou.utils.space_info import SpaceInfo @@ -68,7 +68,7 @@ def test_pg(args: argparse.Namespace = get_args(), enable_assertions: bool = Tru test_envs.seed(args.seed) # model - net = Net( + net = MLPActor( state_shape=args.state_shape, action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, diff --git a/test/discrete/test_ppo_discrete.py b/test/discrete/test_ppo_discrete.py index dbfda23f2..144707d92 100644 --- a/test/discrete/test_ppo_discrete.py +++ b/test/discrete/test_ppo_discrete.py @@ -5,7 +5,6 @@ import gymnasium as gym import numpy as np import torch -import torch.nn as nn from torch.utils.tensorboard import SummaryWriter from tianshou.data import Collector, CollectStats, VectorReplayBuffer @@ -16,7 +15,13 @@ from tianshou.policy.optim import AdamOptimizerFactory from tianshou.trainer import OnPolicyTrainerParams from tianshou.utils import TensorboardLogger -from tianshou.utils.net.common import ActorCritic, DataParallelNet, Net +from tianshou.utils.net.common import ( + ActorCritic, + ActorForwardInterface, + DataParallelNet, + MLPActor, + PolicyForwardDataParallelWrapper, +) from tianshou.utils.net.discrete import DiscreteActor, DiscreteCritic from tianshou.utils.space_info import SpaceInfo @@ -80,11 +85,11 @@ def test_ppo(args: argparse.Namespace = get_args(), enable_assertions: bool = Tr train_envs.seed(args.seed) test_envs.seed(args.seed) # model - net = Net(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes) - actor: nn.Module - critic: nn.Module + net = MLPActor(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes) + critic: DiscreteCritic | DataParallelNet + actor: ActorForwardInterface if torch.cuda.is_available(): - actor = DataParallelNet( + actor = PolicyForwardDataParallelWrapper( DiscreteActor(preprocess_net=net, action_shape=args.action_shape).to(args.device) ) critic = DataParallelNet(DiscreteCritic(preprocess_net=net).to(args.device)) diff --git 
a/test/discrete/test_qrdqn.py b/test/discrete/test_qrdqn.py index bc1c17102..f64cea5ca 100644 --- a/test/discrete/test_qrdqn.py +++ b/test/discrete/test_qrdqn.py @@ -20,7 +20,7 @@ from tianshou.policy.optim import AdamOptimizerFactory from tianshou.trainer.base import OffPolicyTrainerParams from tianshou.utils import TensorboardLogger -from tianshou.utils.net.common import Net +from tianshou.utils.net.common import MLPActor from tianshou.utils.space_info import SpaceInfo from tianshou.utils.torch_utils import policy_within_training_step @@ -86,7 +86,7 @@ def test_qrdqn(args: argparse.Namespace = get_args(), enable_assertions: bool = train_envs.seed(args.seed) test_envs.seed(args.seed) # model - net = Net( + net = MLPActor( state_shape=args.state_shape, action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, diff --git a/test/discrete/test_rainbow.py b/test/discrete/test_rainbow.py index 07d59a2a7..6aff5642a 100644 --- a/test/discrete/test_rainbow.py +++ b/test/discrete/test_rainbow.py @@ -21,7 +21,7 @@ from tianshou.policy.optim import AdamOptimizerFactory from tianshou.trainer.base import OffPolicyTrainerParams from tianshou.utils import TensorboardLogger -from tianshou.utils.net.common import Net +from tianshou.utils.net.common import MLPActor from tianshou.utils.net.discrete import NoisyLinear from tianshou.utils.space_info import SpaceInfo from tianshou.utils.torch_utils import policy_within_training_step @@ -96,7 +96,7 @@ def noisy_linear(x: int, y: int) -> NoisyLinear: return NoisyLinear(x, y, args.noisy_std) # model - net = Net( + net = MLPActor( state_shape=args.state_shape, action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, diff --git a/test/modelbased/test_dqn_icm.py b/test/modelbased/test_dqn_icm.py index d314ec2ab..ed2afaad8 100644 --- a/test/modelbased/test_dqn_icm.py +++ b/test/modelbased/test_dqn_icm.py @@ -18,7 +18,7 @@ from tianshou.policy.optim import AdamOptimizerFactory from tianshou.trainer import OffPolicyTrainerParams 
from tianshou.utils import TensorboardLogger -from tianshou.utils.net.common import MLP, Net +from tianshou.utils.net.common import MLP, MLPActor from tianshou.utils.net.discrete import IntrinsicCuriosityModule from tianshou.utils.space_info import SpaceInfo @@ -99,7 +99,7 @@ def test_dqn_icm(args: argparse.Namespace = get_args()) -> None: # Q_param = V_param = {"hidden_sizes": [128]} # model - net = Net( + net = MLPActor( state_shape=args.state_shape, action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, diff --git a/test/modelbased/test_ppo_icm.py b/test/modelbased/test_ppo_icm.py index 9045e29c8..a5c75d591 100644 --- a/test/modelbased/test_ppo_icm.py +++ b/test/modelbased/test_ppo_icm.py @@ -16,7 +16,7 @@ from tianshou.policy.optim import AdamOptimizerFactory from tianshou.trainer import OnPolicyTrainerParams from tianshou.utils import TensorboardLogger -from tianshou.utils.net.common import MLP, ActorCritic, Net +from tianshou.utils.net.common import MLP, ActorCritic, MLPActor from tianshou.utils.net.discrete import ( DiscreteActor, DiscreteCritic, @@ -104,7 +104,7 @@ def test_ppo(args: argparse.Namespace = get_args()) -> None: test_envs.seed(args.seed) # model - net = Net(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes) + net = MLPActor(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes) actor = DiscreteActor(preprocess_net=net, action_shape=args.action_shape).to(args.device) critic = DiscreteCritic(preprocess_net=net).to(args.device) actor_critic = ActorCritic(actor, critic) diff --git a/test/offline/gather_cartpole_data.py b/test/offline/gather_cartpole_data.py index 9ed763027..2b0d42be3 100644 --- a/test/offline/gather_cartpole_data.py +++ b/test/offline/gather_cartpole_data.py @@ -20,7 +20,7 @@ from tianshou.policy.optim import AdamOptimizerFactory from tianshou.trainer import OffPolicyTrainerParams from tianshou.utils import TensorboardLogger -from tianshou.utils.net.common import Net +from tianshou.utils.net.common import 
MLPActor from tianshou.utils.space_info import SpaceInfo @@ -89,7 +89,7 @@ def gather_data() -> VectorReplayBuffer | PrioritizedVectorReplayBuffer: train_envs.seed(args.seed) test_envs.seed(args.seed) # model - net = Net( + net = MLPActor( state_shape=args.state_shape, action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, diff --git a/test/offline/gather_pendulum_data.py b/test/offline/gather_pendulum_data.py index b5291f21c..ac019469e 100644 --- a/test/offline/gather_pendulum_data.py +++ b/test/offline/gather_pendulum_data.py @@ -15,7 +15,7 @@ from tianshou.policy.optim import AdamOptimizerFactory from tianshou.trainer.base import OffPolicyTrainerParams from tianshou.utils import TensorboardLogger -from tianshou.utils.net.common import Net +from tianshou.utils.net.common import MLPActor from tianshou.utils.net.continuous import ContinuousActorProbabilistic, ContinuousCritic from tianshou.utils.space_info import SpaceInfo @@ -92,14 +92,14 @@ def gather_data() -> VectorReplayBuffer: train_envs.seed(args.seed) test_envs.seed(args.seed) # model - net = Net(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes) + net = MLPActor(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes) actor = ContinuousActorProbabilistic( preprocess_net=net, action_shape=args.action_shape, unbounded=True, ).to(args.device) actor_optim = AdamOptimizerFactory(lr=args.actor_lr) - net_c = Net( + net_c = MLPActor( state_shape=args.state_shape, action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, diff --git a/test/offline/test_bcq.py b/test/offline/test_bcq.py index 6839847f3..ca17f39cf 100644 --- a/test/offline/test_bcq.py +++ b/test/offline/test_bcq.py @@ -17,7 +17,7 @@ from tianshou.policy.optim import AdamOptimizerFactory from tianshou.trainer.base import OfflineTrainerParams from tianshou.utils import TensorboardLogger -from tianshou.utils.net.common import MLP, Net +from tianshou.utils.net.common import MLP, MLPActor from tianshou.utils.net.continuous 
import VAE, ContinuousCritic, Perturbation from tianshou.utils.space_info import SpaceInfo @@ -109,7 +109,7 @@ def test_bcq(args: argparse.Namespace = get_args(), enable_assertions: bool = Tr ) actor_optim = AdamOptimizerFactory(lr=args.actor_lr) - net_c = Net( + net_c = MLPActor( state_shape=args.state_shape, action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, diff --git a/test/offline/test_cql.py b/test/offline/test_cql.py index acc60461c..2c1c2936d 100644 --- a/test/offline/test_cql.py +++ b/test/offline/test_cql.py @@ -17,7 +17,7 @@ from tianshou.policy.optim import AdamOptimizerFactory from tianshou.trainer import OfflineTrainerParams from tianshou.utils import TensorboardLogger -from tianshou.utils.net.common import Net +from tianshou.utils.net.common import MLPActor from tianshou.utils.net.continuous import ContinuousActorProbabilistic, ContinuousCritic from tianshou.utils.space_info import SpaceInfo @@ -105,7 +105,7 @@ def test_cql(args: argparse.Namespace = get_args(), enable_assertions: bool = Tr # model # actor network - net_a = Net( + net_a = MLPActor( state_shape=args.state_shape, action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, @@ -119,7 +119,7 @@ def test_cql(args: argparse.Namespace = get_args(), enable_assertions: bool = Tr actor_optim = AdamOptimizerFactory(lr=args.actor_lr) # critic network - net_c = Net( + net_c = MLPActor( state_shape=args.state_shape, action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, diff --git a/test/offline/test_discrete_bcq.py b/test/offline/test_discrete_bcq.py index 4baeaa128..fedc743c4 100644 --- a/test/offline/test_discrete_bcq.py +++ b/test/offline/test_discrete_bcq.py @@ -21,7 +21,7 @@ from tianshou.policy.optim import AdamOptimizerFactory from tianshou.trainer import OfflineTrainerParams from tianshou.utils import TensorboardLogger -from tianshou.utils.net.common import Net +from tianshou.utils.net.common import MLPActor from tianshou.utils.net.discrete import DiscreteActor 
from tianshou.utils.space_info import SpaceInfo @@ -80,7 +80,7 @@ def test_discrete_bcq( test_envs.seed(args.seed) # model - net = Net(state_shape=args.state_shape, action_shape=args.hidden_sizes[0]) + net = MLPActor(state_shape=args.state_shape, action_shape=args.hidden_sizes[0]) policy_net = DiscreteActor( preprocess_net=net, action_shape=args.action_shape, diff --git a/test/offline/test_discrete_cql.py b/test/offline/test_discrete_cql.py index 2f1efbec6..be56effb5 100644 --- a/test/offline/test_discrete_cql.py +++ b/test/offline/test_discrete_cql.py @@ -21,7 +21,7 @@ from tianshou.policy.optim import AdamOptimizerFactory from tianshou.trainer import OfflineTrainerParams from tianshou.utils import TensorboardLogger -from tianshou.utils.net.common import Net +from tianshou.utils.net.common import MLPActor from tianshou.utils.space_info import SpaceInfo @@ -77,7 +77,7 @@ def test_discrete_cql( test_envs.seed(args.seed) # model - net = Net( + net = MLPActor( state_shape=args.state_shape, action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, diff --git a/test/offline/test_discrete_crr.py b/test/offline/test_discrete_crr.py index 278268ff5..37a831322 100644 --- a/test/offline/test_discrete_crr.py +++ b/test/offline/test_discrete_crr.py @@ -21,7 +21,7 @@ from tianshou.policy.optim import AdamOptimizerFactory from tianshou.trainer import OfflineTrainerParams from tianshou.utils import TensorboardLogger -from tianshou.utils.net.common import Net +from tianshou.utils.net.common import MLPActor from tianshou.utils.net.discrete import DiscreteActor, DiscreteCritic from tianshou.utils.space_info import SpaceInfo @@ -75,7 +75,7 @@ def test_discrete_crr( test_envs.seed(args.seed) # model and algorithm - net = Net(state_shape=args.state_shape, action_shape=args.hidden_sizes[0]) + net = MLPActor(state_shape=args.state_shape, action_shape=args.hidden_sizes[0]) actor = DiscreteActor( preprocess_net=net, action_shape=args.action_shape, diff --git 
a/test/offline/test_gail.py b/test/offline/test_gail.py index 81a25a613..b9493b787 100644 --- a/test/offline/test_gail.py +++ b/test/offline/test_gail.py @@ -17,7 +17,7 @@ from tianshou.policy.optim import AdamOptimizerFactory from tianshou.trainer import OnPolicyTrainerParams from tianshou.utils import TensorboardLogger -from tianshou.utils.net.common import ActorCritic, Net +from tianshou.utils.net.common import ActorCritic, MLPActor from tianshou.utils.net.continuous import ContinuousActorProbabilistic, ContinuousCritic from tianshou.utils.space_info import SpaceInfo @@ -92,7 +92,7 @@ def test_gail(args: argparse.Namespace = get_args(), enable_assertions: bool = T train_envs.seed(args.seed) test_envs.seed(args.seed) # model - net = Net(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes) + net = MLPActor(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes) actor = ContinuousActorProbabilistic( preprocess_net=net, action_shape=args.action_shape, @@ -101,7 +101,7 @@ def test_gail(args: argparse.Namespace = get_args(), enable_assertions: bool = T args.device, ) critic = ContinuousCritic( - preprocess_net=Net(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes), + preprocess_net=MLPActor(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes), ).to(args.device) actor_critic = ActorCritic(actor, critic) # orthogonal initialization @@ -112,7 +112,7 @@ def test_gail(args: argparse.Namespace = get_args(), enable_assertions: bool = T optim = AdamOptimizerFactory(lr=args.lr) # discriminator disc_net = ContinuousCritic( - preprocess_net=Net( + preprocess_net=MLPActor( state_shape=args.state_shape, action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, diff --git a/test/offline/test_td3_bc.py b/test/offline/test_td3_bc.py index e62472224..651317728 100644 --- a/test/offline/test_td3_bc.py +++ b/test/offline/test_td3_bc.py @@ -19,7 +19,7 @@ from tianshou.policy.optim import AdamOptimizerFactory from tianshou.trainer import 
OfflineTrainerParams from tianshou.utils import TensorboardLogger -from tianshou.utils.net.common import Net +from tianshou.utils.net.common import MLPActor from tianshou.utils.net.continuous import ContinuousActorDeterministic, ContinuousCritic from tianshou.utils.space_info import SpaceInfo @@ -96,7 +96,7 @@ def test_td3_bc(args: argparse.Namespace = get_args(), enable_assertions: bool = test_envs.seed(args.seed) # actor network - net_a = Net( + net_a = MLPActor( state_shape=args.state_shape, hidden_sizes=args.hidden_sizes, ) @@ -108,13 +108,13 @@ def test_td3_bc(args: argparse.Namespace = get_args(), enable_assertions: bool = actor_optim = AdamOptimizerFactory(lr=args.actor_lr) # critic networks - net_c1 = Net( + net_c1 = MLPActor( state_shape=args.state_shape, action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, concat=True, ) - net_c2 = Net( + net_c2 = MLPActor( state_shape=args.state_shape, action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, diff --git a/test/pettingzoo/pistonball.py b/test/pettingzoo/pistonball.py index 1634581fe..cfb7de7da 100644 --- a/test/pettingzoo/pistonball.py +++ b/test/pettingzoo/pistonball.py @@ -17,7 +17,7 @@ from tianshou.policy.optim import AdamOptimizerFactory from tianshou.trainer import OffPolicyTrainerParams from tianshou.utils import TensorboardLogger -from tianshou.utils.net.common import Net +from tianshou.utils.net.common import MLPActor def get_parser() -> argparse.ArgumentParser: @@ -96,7 +96,7 @@ def get_agents( optims = [] for _ in range(args.n_pistons): # model - net = Net( + net = MLPActor( state_shape=args.state_shape, action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, diff --git a/test/pettingzoo/tic_tac_toe.py b/test/pettingzoo/tic_tac_toe.py index e4ea9ba2d..0360eab1d 100644 --- a/test/pettingzoo/tic_tac_toe.py +++ b/test/pettingzoo/tic_tac_toe.py @@ -24,7 +24,7 @@ from tianshou.policy.optim import AdamOptimizerFactory, OptimizerFactory from tianshou.trainer import 
OffPolicyTrainerParams from tianshou.utils import TensorboardLogger -from tianshou.utils.net.common import Net +from tianshou.utils.net.common import MLPActor def get_env(render_mode: str | None = None) -> PettingZooEnv: @@ -115,7 +115,7 @@ def get_agents( args.action_shape = env.action_space.shape or int(env.action_space.n) if agent_learn is None: # model - net = Net( + net = MLPActor( state_shape=args.state_shape, action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, diff --git a/tianshou/data/batch.py b/tianshou/data/batch.py index 70478a87d..3a380fcbd 100644 --- a/tianshou/data/batch.py +++ b/tianshou/data/batch.py @@ -73,6 +73,7 @@ TDistribution = TypeVar("TDistribution", bound=Distribution) T = TypeVar("T") TArr = torch.Tensor | np.ndarray +TObsArr = torch.Tensor | np.ndarray log = logging.getLogger(__name__) diff --git a/tianshou/data/types.py b/tianshou/data/types.py index 35ec917f6..b87984ea2 100644 --- a/tianshou/data/types.py +++ b/tianshou/data/types.py @@ -4,7 +4,9 @@ import torch from tianshou.data import Batch -from tianshou.data.batch import BatchProtocol, TArr +from tianshou.data.batch import BatchProtocol, TArr, TObsArr + +TObs = TObsArr | BatchProtocol TNestedDictValue = np.ndarray | dict[str, "TNestedDictValue"] @@ -15,14 +17,19 @@ class ObsBatchProtocol(BatchProtocol, Protocol): Typically used inside a policy's forward """ - obs: TArr | BatchProtocol - info: TArr | BatchProtocol + obs: TObs + """the observations as generated by the environment in `step`. + If it is a dict env, this will be an instance of Batch, otherwise it will be an array (or tensor if your env returns tensors)""" + info: TArr + """array of info dicts generated by the environment in `step`""" class RolloutBatchProtocol(ObsBatchProtocol, Protocol): """Typically, the outcome of sampling from a replay buffer.""" - obs_next: TArr | BatchProtocol + obs_next: TObs + """the observations after obs as generated by the environment in `step`. 
+ If it is a dict env, this will be an instance of Batch, otherwise it will be an array (or tensor if your env returns tensors)""" act: TArr rew: np.ndarray terminated: TArr @@ -39,6 +46,7 @@ class PrioBatchProtocol(RolloutBatchProtocol, Protocol): """Contains weights that can be used for prioritized replay.""" weight: np.ndarray | torch.Tensor + """can be used for prioritized replay.""" class RecurrentStateBatch(BatchProtocol, Protocol): @@ -118,7 +126,7 @@ class QuantileRegressionBatchProtocol(ModelOutputBatchProtocol, Protocol): taus: torch.Tensor -class ImitationBatchProtocol(ActBatchProtocol, Protocol): +class ImitationBatchProtocol(ModelOutputBatchProtocol, Protocol): """Similar to other batches, but contains `imitation_logits` and `q_value` fields.""" state: dict | Batch | np.ndarray | None diff --git a/tianshou/env/atari/atari_network.py b/tianshou/env/atari/atari_network.py index f14d73670..971855b78 100644 --- a/tianshou/env/atari/atari_network.py +++ b/tianshou/env/atari/atari_network.py @@ -1,10 +1,12 @@ from collections.abc import Callable, Sequence -from typing import Any +from typing import Any, TypeVar import numpy as np import torch from torch import nn +from tianshou.data import Batch +from tianshou.data.types import TObs from tianshou.highlevel.env import Environments from tianshou.highlevel.module.actor import ActorFactory from tianshou.highlevel.module.core import ( @@ -16,7 +18,7 @@ ) from tianshou.highlevel.params.dist_fn import DistributionFunctionFactoryCategorical from tianshou.policy.modelfree.pg import TDistFnDiscrOrCont -from tianshou.utils.net.common import NetBase +from tianshou.utils.net.common import Actor, ModuleWithVectorOutput from tianshou.utils.net.discrete import DiscreteActor, NoisyLinear from tianshou.utils.torch_utils import torch_device @@ -28,29 +30,35 @@ def layer_init(layer: nn.Module, std: float = np.sqrt(2), bias_const: float = 0. 
return layer -class ScaledObsInputModule(NetBase): - def __init__(self, module: NetBase, denom: float = 255.0) -> None: +T = TypeVar("T") + + +class ScaledObsInputModule(Actor): + def __init__(self, module: Actor, denom: float = 255.0) -> None: super().__init__(module.get_output_dim()) self.module = module self.denom = denom + def get_preprocess_net(self) -> ModuleWithVectorOutput: + return self.module.get_preprocess_net() + def forward( self, - obs: np.ndarray | torch.Tensor, - state: Any | None = None, - info: dict[str, Any] | None = None, - ) -> tuple[torch.Tensor, Any]: + obs: TObs, + state: T | None = None, + info: dict[str, T] | None = None, + ) -> tuple[torch.Tensor | Sequence[torch.Tensor], T | None]: if info is None: info = {} - return self.module.forward(obs / self.denom, state, info) - - -def scale_obs(module: NetBase, denom: float = 255.0) -> ScaledObsInputModule: - """TODO.""" - return ScaledObsInputModule(module, denom=denom) + scaler = lambda arr: arr / self.denom + if isinstance(obs, Batch): + scaled_obs = obs.apply_values_transform(scaler) + else: + scaled_obs = scaler(obs) + return self.module.forward(scaled_obs, state, info) -class DQNet(NetBase[Any]): +class DQNet(Actor[Any]): """Reference: Human-level control through deep reinforcement learning. For advanced usage (how to customize the network), please refer to @@ -104,14 +112,19 @@ def __init__( super().__init__(output_dim) self.net = net + def get_preprocess_net(self) -> ModuleWithVectorOutput: + return ModuleWithVectorOutput.from_module(nn.Identity(), self.output_dim) + def forward( self, - obs: np.ndarray | torch.Tensor, - state: Any | None = None, - info: dict[str, Any] | None = None, - **kwargs: Any, - ) -> tuple[torch.Tensor, Any]: - r"""Mapping: s -> Q(s, \*).""" + obs: TObs, + state: T | None = None, + info: dict[str, T] | None = None, + ) -> tuple[torch.Tensor, T | None]: + r"""Mapping: s -> Q(s, \*). + + For more info, see docstring of parent. 
+ """ device = torch_device(self) obs = torch.as_tensor(obs, device=device, dtype=torch.float32) return self.net(obs), state @@ -139,11 +152,10 @@ def __init__( def forward( self, - obs: np.ndarray | torch.Tensor, - state: Any | None = None, - info: dict[str, Any] | None = None, - **kwargs: Any, - ) -> tuple[torch.Tensor, Any]: + obs: TObs, + state: T | None = None, + info: dict[str, T] | None = None, + ) -> tuple[torch.Tensor, T | None]: r"""Mapping: x -> Z(x, \*).""" obs, state = super().forward(obs) obs = obs.view(-1, self.num_atoms).softmax(dim=-1) @@ -195,12 +207,10 @@ def linear(x: int, y: int) -> NoisyLinear | nn.Linear: def forward( self, - obs: np.ndarray | torch.Tensor, - state: Any | None = None, + obs: TObs, + state: T | None = None, info: dict[str, Any] | None = None, - **kwargs: Any, - ) -> tuple[torch.Tensor, Any]: - r"""Mapping: x -> Z(x, \*).""" + ) -> tuple[torch.Tensor, T | None]: obs, state = super().forward(obs) q = self.Q(obs) q = q.view(-1, self.action_num, self.num_atoms) @@ -236,12 +246,10 @@ def __init__( def forward( self, - obs: np.ndarray | torch.Tensor, - state: Any | None = None, + obs: TObs, + state: T | None = None, info: dict[str, Any] | None = None, - **kwargs: Any, - ) -> tuple[torch.Tensor, Any]: - r"""Mapping: x -> Z(x, \*).""" + ) -> tuple[torch.Tensor, T | None]: obs, state = super().forward(obs) obs = obs.view(-1, self.action_num, self.num_quantiles) return obs, state @@ -276,7 +284,7 @@ def create_module(self, envs: Environments, device: TDevice) -> DiscreteActor: layer_init=layer_init, ) if self.scale_obs: - net = scale_obs(net) + net = ScaledObsInputModule(net) return DiscreteActor( preprocess_net=net, action_shape=envs.get_action_shape(), diff --git a/tianshou/highlevel/module/actor.py b/tianshou/highlevel/module/actor.py index e3807569b..667a6c62f 100644 --- a/tianshou/highlevel/module/actor.py +++ b/tianshou/highlevel/module/actor.py @@ -24,7 +24,12 @@ ) from tianshou.policy.modelfree.pg import TDistFnDiscrOrCont from 
tianshou.utils.net import continuous, discrete -from tianshou.utils.net.common import Actor, ModuleType, ModuleWithVectorOutput, Net +from tianshou.utils.net.common import ( + Actor, + MLPActor, + ModuleType, + ModuleWithVectorOutput, +) class ContinuousActorType(Enum): @@ -146,7 +151,7 @@ def __init__(self, hidden_sizes: Sequence[int], activation: ModuleType = nn.ReLU self.activation = activation def create_module(self, envs: Environments, device: TDevice) -> Actor: - net_a = Net( + net_a = MLPActor( state_shape=envs.get_observation_shape(), hidden_sizes=self.hidden_sizes, activation=self.activation, @@ -182,7 +187,7 @@ def __init__( self.activation = activation def create_module(self, envs: Environments, device: TDevice) -> Actor: - net_a = Net( + net_a = MLPActor( state_shape=envs.get_observation_shape(), hidden_sizes=self.hidden_sizes, activation=self.activation, @@ -217,7 +222,7 @@ def __init__( self.activation = activation def create_module(self, envs: Environments, device: TDevice) -> Actor: - net_a = Net( + net_a = MLPActor( state_shape=envs.get_observation_shape(), hidden_sizes=self.hidden_sizes, activation=self.activation, diff --git a/tianshou/highlevel/module/critic.py b/tianshou/highlevel/module/critic.py index 54596be12..f76ab6be5 100644 --- a/tianshou/highlevel/module/critic.py +++ b/tianshou/highlevel/module/critic.py @@ -9,7 +9,7 @@ from tianshou.highlevel.module.actor import ActorFuture from tianshou.highlevel.module.core import TDevice, init_linear_orthogonal from tianshou.utils.net import continuous -from tianshou.utils.net.common import Actor, EnsembleLinear, ModuleType, Net +from tianshou.utils.net.common import Actor, EnsembleLinear, MLPActor, ModuleType from tianshou.utils.net.continuous import ContinuousCritic from tianshou.utils.net.discrete import DiscreteCritic @@ -91,7 +91,7 @@ def create_module( discrete_last_size_use_action_shape: bool = False, ) -> nn.Module: action_shape = envs.get_action_shape() if use_action else 0 - net_c = Net( 
+ net_c = MLPActor( state_shape=envs.get_observation_shape(), action_shape=action_shape, hidden_sizes=self.hidden_sizes, @@ -116,7 +116,7 @@ def create_module( discrete_last_size_use_action_shape: bool = False, ) -> nn.Module: action_shape = envs.get_action_shape() if use_action else 0 - net_c = Net( + net_c = MLPActor( state_shape=envs.get_observation_shape(), action_shape=action_shape, hidden_sizes=self.hidden_sizes, @@ -239,7 +239,7 @@ def linear_layer(x: int, y: int) -> EnsembleLinear: return EnsembleLinear(ensemble_size, x, y) action_shape = envs.get_action_shape() if use_action else 0 - net_c = Net( + net_c = MLPActor( state_shape=envs.get_observation_shape(), action_shape=action_shape, hidden_sizes=self.hidden_sizes, diff --git a/tianshou/policy/base.py b/tianshou/policy/base.py index 55957ecb2..846a708bb 100644 --- a/tianshou/policy/base.py +++ b/tianshou/policy/base.py @@ -240,47 +240,6 @@ def __init__( def action_type(self) -> Literal["discrete", "continuous"]: return self._action_type - @abstractmethod - def forward( - self, - batch: ObsBatchProtocol, - state: dict | BatchProtocol | np.ndarray | None = None, - **kwargs: Any, - ) -> ActBatchProtocol | ActStateBatchProtocol: # TODO: make consistent typing - """Compute action over the given batch data. - - :return: A :class:`~tianshou.data.Batch` which MUST have the following keys: - - * ``act`` a numpy.ndarray or a torch.Tensor, the action over \ - given batch data. - * ``state`` a dict, a numpy.ndarray or a torch.Tensor, the \ - internal state of the policy, ``None`` as default. - - Other keys are user-defined. It depends on the algorithm. For example, - :: - - # some code - return Batch(logits=..., act=..., state=None, dist=...) - - The keyword ``policy`` is reserved and the corresponding data will be - stored into the replay buffer. 
For instance, - :: - - # some code - return Batch(..., policy=Batch(log_prob=dist.log_prob(act))) - # and in the sampled data batch, you can directly use - # batch.policy.log_prob to get your data. - - .. note:: - - In continuous action space, you should do another step "map_action" to get - the real action: - :: - - act = policy(batch).act # doesn't map to the target action range - act = policy.map_action(act, batch) - """ - @staticmethod def _action_to_numpy(act: TArr) -> np.ndarray: act = to_numpy(act) # NOTE: to_numpy could confusingly also return a Batch diff --git a/tianshou/policy/imitation/discrete_bcq.py b/tianshou/policy/imitation/discrete_bcq.py index 6a66f986c..056aa0cb3 100644 --- a/tianshou/policy/imitation/discrete_bcq.py +++ b/tianshou/policy/imitation/discrete_bcq.py @@ -6,6 +6,7 @@ import numpy as np import torch import torch.nn.functional as F +from torch import nn from tianshou.data import Batch, ReplayBuffer, to_torch from tianshou.data.types import ( @@ -99,17 +100,16 @@ def __init__( self._log_tau = math.log(unlikely_action_threshold) else: self._log_tau = -np.inf - self.max_action_num: int | None = None - def forward( # type: ignore + def forward( self, batch: ObsBatchProtocol, - state: dict | Batch | np.ndarray | None = None, - **kwargs: Any, + state: Any | None = None, + model: nn.Module | None = None, ) -> ImitationBatchProtocol: - q_value, state = self.model(batch.obs, state=state, info=batch.info) - if self.max_action_num is None: - self.max_action_num = q_value.shape[1] + if model is None: + model = self.model + q_value, state = model(batch.obs, state=state, info=batch.info) imitation_logits, _ = self.imitator(batch.obs, state=state, info=batch.info) # mask actions for argmax @@ -117,7 +117,13 @@ def forward( # type: ignore mask = (ratio < self._log_tau).float() act = (q_value - INF * mask).argmax(dim=-1) - result = Batch(act=act, state=state, q_value=q_value, imitation_logits=imitation_logits) + result = Batch( + act=act, + 
state=state, + q_value=q_value, + imitation_logits=imitation_logits, + logits=imitation_logits, + ) return cast(ImitationBatchProtocol, result) diff --git a/tianshou/policy/modelfree/bdqn.py b/tianshou/policy/modelfree/bdqn.py index 195a22666..e8196004a 100644 --- a/tianshou/policy/modelfree/bdqn.py +++ b/tianshou/policy/modelfree/bdqn.py @@ -21,16 +21,16 @@ ) from tianshou.policy.modelfree.pg import SimpleLossTrainingStats from tianshou.policy.optim import OptimizerFactory -from tianshou.utils.net.common import BranchingNet +from tianshou.utils.net.common import BranchingActor mark_used(ActBatchProtocol) -class BDQNPolicy(DiscreteQLearningPolicy[BranchingNet]): +class BDQNPolicy(DiscreteQLearningPolicy[BranchingActor]): def __init__( self, *, - model: BranchingNet, + model: BranchingActor, action_space: gym.spaces.Discrete, observation_space: gym.Space | None = None, eps_training: float = 0.0, @@ -69,6 +69,7 @@ def forward( ) -> ModelOutputBatchProtocol: if model is None: model = self.model + assert model is not None obs = batch.obs # TODO: this is very contrived, see also iqn.py obs_next_BO = obs.obs if hasattr(obs, "obs") else obs diff --git a/tianshou/policy/modelfree/c51.py b/tianshou/policy/modelfree/c51.py index 01dbbbeab..a8ced0fdc 100644 --- a/tianshou/policy/modelfree/c51.py +++ b/tianshou/policy/modelfree/c51.py @@ -10,13 +10,13 @@ ) from tianshou.policy.modelfree.pg import LossSequenceTrainingStats from tianshou.policy.optim import OptimizerFactory -from tianshou.utils.net.common import Net +from tianshou.utils.net.common import MLPActor class C51Policy(DiscreteQLearningPolicy): def __init__( self, - model: torch.nn.Module | Net, + model: torch.nn.Module | MLPActor, action_space: gym.spaces.Space, observation_space: gym.Space | None = None, num_atoms: int = 51, diff --git a/tianshou/policy/modelfree/dqn.py b/tianshou/policy/modelfree/dqn.py index c8fd70907..37420b806 100644 --- a/tianshou/policy/modelfree/dqn.py +++ b/tianshou/policy/modelfree/dqn.py @@ 
-5,10 +5,10 @@ import gymnasium as gym import numpy as np import torch +from gymnasium.spaces.discrete import Discrete from sensai.util.helper import mark_used from tianshou.data import Batch, ReplayBuffer, to_numpy, to_torch_as -from tianshou.data.batch import BatchProtocol from tianshou.data.types import ( ActBatchProtocol, BatchWithReturnsProtocol, @@ -28,11 +28,11 @@ ) from tianshou.policy.optim import OptimizerFactory from tianshou.utils.lagged_network import EvalModeModuleWrapper -from tianshou.utils.net.common import Net +from tianshou.utils.net.common import MLPActor mark_used(ActBatchProtocol) -TModel = TypeVar("TModel", bound=torch.nn.Module | Net) +TModel = TypeVar("TModel", bound=torch.nn.Module | MLPActor) log = logging.getLogger(__name__) @@ -68,8 +68,8 @@ def __init__( action_scaling=False, action_bound_method=None, ) + self.action_space = cast(Discrete, self.action_space) self.model = model - self.max_action_num: int | None = None self.eps_training = eps_training self.eps_inference = eps_inference @@ -101,9 +101,8 @@ def set_eps_inference(self, eps: float) -> None: def forward( self, batch: ObsBatchProtocol, - state: dict | BatchProtocol | np.ndarray | None = None, + state: Any | None = None, model: torch.nn.Module | None = None, - **kwargs: Any, ) -> ModelOutputBatchProtocol: """Compute action over the given batch data. @@ -121,6 +120,10 @@ def forward( ... ) + :param batch: + :param state: optional hidden state (for RNNs) + :param model: if not passed will use `self.model`. Typically used to pass + the lagged target network instead of using the current model. :return: A :class:`~tianshou.data.Batch` which has 3 keys: * ``act`` the action. @@ -130,12 +133,11 @@ def forward( if model is None: model = self.model obs = batch.obs + mask = obs.mask # TODO: this is convoluted! See also other places where this is done. 
- obs_next = obs.obs if hasattr(obs, "obs") else obs - action_values_BA, hidden_BH = model(obs_next, state=state, info=batch.info) - q = self.compute_q_value(action_values_BA, getattr(obs, "mask", None)) - if self.max_action_num is None: - self.max_action_num = q.shape[1] + obs_arr = obs.obs if hasattr(obs, "obs") else obs + action_values_BA, hidden_BH = model(obs_arr, state=state, info=batch.info) + q = self.compute_q_value(action_values_BA, mask) act_B = to_numpy(q.argmax(dim=1)) result = Batch(logits=action_values_BA, act=act_B, state=hidden_BH) return cast(ModelOutputBatchProtocol, result) @@ -158,10 +160,9 @@ def add_exploration_noise( if isinstance(act, np.ndarray) and not np.isclose(eps, 0.0): batch_size = len(act) rand_mask = np.random.rand(batch_size) < eps - assert ( - self.max_action_num is not None - ), "Can't call this method before max_action_num was set in first forward" - q = np.random.rand(batch_size, self.max_action_num) # [0, 1] + self.action_space = cast(Discrete, self.action_space) # for mypy + action_num = int(self.action_space.n) + q = np.random.rand(batch_size, action_num) # [0, 1] if hasattr(batch.obs, "mask"): q += batch.obs.mask rand_act = q.argmax(axis=1) diff --git a/tianshou/policy/modelfree/fqf.py b/tianshou/policy/modelfree/fqf.py index c709056b9..ee9fb25ff 100644 --- a/tianshou/policy/modelfree/fqf.py +++ b/tianshou/policy/modelfree/fqf.py @@ -95,9 +95,6 @@ def forward( # type: ignore q = DiscreteQLearningPolicy.compute_q_value( self, weighted_logits.sum(2), getattr(obs, "mask", None) ) - if self.max_action_num is None: # type: ignore - # TODO: see same thing in DQNPolicy! Also reduce code duplication. 
- self.max_action_num = q.shape[1] act = to_numpy(q.max(dim=1)[1]) result = Batch( logits=logits, diff --git a/tianshou/policy/modelfree/iqn.py b/tianshou/policy/modelfree/iqn.py index dec4d87a0..9a7dbb658 100644 --- a/tianshou/policy/modelfree/iqn.py +++ b/tianshou/policy/modelfree/iqn.py @@ -95,9 +95,6 @@ def forward( info=batch.info, ) q = self.compute_q_value(logits, getattr(obs, "mask", None)) - if self.max_action_num is None: # type: ignore - # TODO: see same thing in DQNPolicy! - self.max_action_num = q.shape[1] act = to_numpy(q.max(dim=1)[1]) result = Batch(logits=logits, act=act, state=hidden, taus=taus) return cast(QuantileRegressionBatchProtocol, result) diff --git a/tianshou/policy/modelfree/pg.py b/tianshou/policy/modelfree/pg.py index 0dd4b9685..e0216dc42 100644 --- a/tianshou/policy/modelfree/pg.py +++ b/tianshou/policy/modelfree/pg.py @@ -2,7 +2,7 @@ import warnings from collections.abc import Callable from dataclasses import dataclass -from typing import Any, Literal, TypeVar, cast +from typing import Literal, TypeVar, cast import gymnasium as gym import numpy as np @@ -31,6 +31,7 @@ from tianshou.policy.optim import OptimizerFactory from tianshou.utils import RunningMeanStd from tianshou.utils.net.common import ( + ActorForwardInterface, ContinuousActorProbabilisticInterface, DiscreteActorInterface, ) @@ -73,7 +74,9 @@ class ActorPolicyProbabilistic(Policy): def __init__( self, *, - actor: ContinuousActorProbabilisticInterface | DiscreteActorInterface, + actor: ContinuousActorProbabilisticInterface + | DiscreteActorInterface + | ActorForwardInterface, dist_fn: TDistFnDiscrOrCont, deterministic_eval: bool = False, action_space: gym.Space, @@ -162,7 +165,6 @@ def forward( self, batch: ObsBatchProtocol, state: dict | BatchProtocol | np.ndarray | None = None, - **kwargs: Any, ) -> DistBatchProtocol: """Compute action over the given batch data by applying the actor. 
@@ -170,7 +172,6 @@ def forward( Returns a new object representing the processed batch data (contrary to other methods that modify the input batch inplace). """ - # TODO - ALGO: marked for algorithm refactoring action_dist_input_BD, hidden_BH = self.actor(batch.obs, state=state, info=batch.info) # in the case that self.action_type == "discrete", the dist should always be Categorical, and D=A # therefore action_dist_input_BD is equivalent to logits_BA @@ -192,7 +193,7 @@ class DiscreteActorPolicy(ActorPolicyProbabilistic): def __init__( self, *, - actor: DiscreteActorInterface, + actor: DiscreteActorInterface | ActorForwardInterface, dist_fn: TDistFnDiscrete = dist_fn_categorical_from_logits, deterministic_eval: bool = False, action_space: gym.Space, diff --git a/tianshou/utils/net/common.py b/tianshou/utils/net/common.py index 1516bc936..bbd05ab38 100644 --- a/tianshou/utils/net/common.py +++ b/tianshou/utils/net/common.py @@ -7,8 +7,8 @@ from gymnasium import spaces from torch import nn -from tianshou.data.batch import Batch, BatchProtocol -from tianshou.data.types import RecurrentStateBatch +from tianshou.data.batch import Batch +from tianshou.data.types import RecurrentStateBatch, TObs from tianshou.utils.space_info import ActionSpaceInfo from tianshou.utils.torch_utils import torch_device @@ -181,24 +181,47 @@ def forward(self, obs: np.ndarray | torch.Tensor) -> torch.Tensor: TRecurrentState = TypeVar("TRecurrentState", bound=Any) -class PolicyForwardInterface(Generic[TRecurrentState], ABC): - """Defines the `forward` interface for neural networks used in policies.""" +class ActorForwardInterface(Generic[TRecurrentState], nn.Module, ABC): + """Defines the `forward` interface for neural networks used as actors in policies. + + Note that for DQN-like algorithms the critic is used as an actor (since actions + are computed from it), see e.g. :class:`~DiscreteActor`. 
+ """ @abstractmethod def forward( self, - obs: np.ndarray | torch.Tensor, + obs: TObs, state: TRecurrentState | None = None, info: dict[str, Any] | None = None, - ) -> tuple[torch.Tensor, TRecurrentState | None]: - pass + ) -> tuple[torch.Tensor | Sequence[torch.Tensor], TRecurrentState | None]: + """ + The main method for tianshou to compute action representations (such as actions, inputs of distributions, Q-values, etc) + from env observations. + Implementations will always make use of the preprocess_net as the first processing step. + :param obs: the observations from the environment as retrieved from `ObsBatchProtocol.obs`. + If the environment is a dict env, this will be an instance of Batch, otherwise it will be an array (or tensor if your env returns tensors). + :param state: the hidden state of the RNN, if applicable + :param info: the info object from the environment step + :return: a tuple (action_repr, hidden_state), where action_repr is either an actual action for the environment or + a representation from which it can be retrieved/sampled (e.g., mean and std for a Gaussian distribution), + and hidden_state is the new hidden state of the RNN, if applicable. + """ + + +class Actor(Generic[T], ModuleWithVectorOutput, ActorForwardInterface[T], ABC): + @abstractmethod + def get_preprocess_net(self) -> ModuleWithVectorOutput: + """Typically a first part of the network that preprocesses the input into a latent representation. -class NetBase(ModuleWithVectorOutput, PolicyForwardInterface[TRecurrentState], ABC): - """Base class for NNs used in policies which produce vector outputs.""" + E.g., a CNN (often used in atari examples). We need this method to be able to + share latent representation with other networks (e.g., critic) within an Algorithm. + Networks that don't have this can use nn.Identity() as a preprocess net (see :class:`RandomActor`). + """ -class Net(NetBase[Any]): +class MLPActor(Actor[Any]): """Wrapper of MLP to support more specific DRL usage. 
For advanced usage (how to customize the network), please refer to @@ -298,12 +321,15 @@ def __init__( self.Q = Q self.V = V + def get_preprocess_net(self) -> ModuleWithVectorOutput: + return ModuleWithVectorOutput.from_module(nn.Identity(), self.output_dim) + def forward( self, - obs: np.ndarray | torch.Tensor, - state: Any = None, + obs: TObs, + state: T | None = None, info: dict[str, Any] | None = None, - ) -> tuple[torch.Tensor, Any]: + ) -> tuple[torch.Tensor, T | Any]: """Mapping: obs -> flatten (inside MLP)-> logits. :param obs: @@ -327,7 +353,7 @@ def forward( return logits, state -class Recurrent(NetBase[RecurrentStateBatch]): +class Recurrent(Actor[RecurrentStateBatch]): """Simple Recurrent network based on LSTM. For advanced usage (how to customize the network), please refer to @@ -353,9 +379,12 @@ def __init__( self.fc1 = nn.Linear(int(np.prod(state_shape)), hidden_layer_size) self.fc2 = nn.Linear(hidden_layer_size, output_dim) + def get_preprocess_net(self) -> ModuleWithVectorOutput: + return ModuleWithVectorOutput.from_module(nn.Identity(), self.output_dim) + def forward( self, - obs: np.ndarray | torch.Tensor, + obs: TObs, state: RecurrentStateBatch | None = None, info: dict[str, Any] | None = None, ) -> tuple[torch.Tensor, RecurrentStateBatch]: @@ -436,7 +465,7 @@ class DataParallelNet(nn.Module): Tensor. If the input is a nested dictionary, the user should create a similar class to do the same thing. - :param nn.Module net: the network to be distributed in different GPUs. + :param net: the network to be distributed in different GPUs. 
""" def __init__(self, net: nn.Module) -> None: @@ -445,13 +474,33 @@ def __init__(self, net: nn.Module) -> None: def forward( self, - obs: np.ndarray | torch.Tensor, + obs: TObs, *args: Any, **kwargs: Any, ) -> tuple[Any, Any]: if not isinstance(obs, torch.Tensor): obs = torch.as_tensor(obs, dtype=torch.float32) - return self.net(obs=obs.cuda(), *args, **kwargs) # noqa: B026 + obs = obs.cuda() + return self.net(obs, *args, **kwargs) + + +# The same functionality as DataParallelNet +# The duplication is worth it because the PolicyForwardInterface is so important +class PolicyForwardDataParallelWrapper(ActorForwardInterface): + def __init__(self, net: ActorForwardInterface) -> None: + super().__init__() + self.net = nn.DataParallel(net) + + def forward( + self, + obs: TObs, + state: TRecurrentState | None = None, + info: dict[str, Any] | None = None, + ) -> tuple[torch.Tensor, TRecurrentState | None]: + if not isinstance(obs, torch.Tensor): + obs = torch.as_tensor(obs, dtype=torch.float32) + obs = obs.cuda() + return self.net(obs, state=state, info=info) class EnsembleLinear(nn.Module): @@ -489,7 +538,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return x -class BranchingNet(nn.Module, PolicyForwardInterface): +class BranchingActor(ActorForwardInterface): """Branching dual Q network. 
Network for the BranchingDQNPolicy, it uses a common network module, a value module @@ -596,10 +645,10 @@ def __init__( def forward( self, - obs: np.ndarray | torch.Tensor, - state: Any = None, + obs: TObs, + state: T | None = None, info: dict[str, Any] | None = None, - ) -> tuple[torch.Tensor, Any]: + ) -> tuple[torch.Tensor, T | None]: """Mapping: obs -> model -> logits.""" common_out = self.common(obs) value_out = self.value(common_out) @@ -652,7 +701,7 @@ def preprocess_obs(obs: Batch | dict | torch.Tensor | np.ndarray) -> torch.Tenso @no_type_check def decorator_fn(net_class): class new_net_class(net_class): - def forward(self, obs: np.ndarray | torch.Tensor, *args, **kwargs) -> Any: + def forward(self, obs: TObs, *args, **kwargs) -> Any: return super().forward(preprocess_obs(obs), *args, **kwargs) return new_net_class @@ -660,35 +709,6 @@ def forward(self, obs: np.ndarray | torch.Tensor, *args, **kwargs) -> Any: return decorator_fn, new_state_shape -class Actor(ModuleWithVectorOutput, ABC): - @abstractmethod - def get_preprocess_net(self) -> ModuleWithVectorOutput: - """Typically a first part of the network that preprocesses the input into a latent representation. - E.g., a CNN (often used in atari examples). We need this method to be able to - share latent representation with other networks (e.g., critic) within an Algorithm. - Networks that don't have this can use nn.Identity() as a preprocess net (see :class:`RandomActor`). - """ - - @abstractmethod - def forward( - self, - obs: np.ndarray | torch.Tensor, - state: T | None = None, - info: dict[str, Any] | None = None, - ) -> tuple[np.ndarray | torch.Tensor | Sequence[torch.Tensor], T | None]: - """ - The main method for tianshou to compute actions from env observations. - Implementations will always make use of the preprocess_net as the first processing step. 
- - :param obs: the observation from the environment - :param state: the hidden state of the RNN, if applicable - :param info: the info object from the environment step - :return: a tuple (action_repr, hidden_state), where action_repr is either an actual action for the environment or - a representation from which it can be retrieved/sampled (e.g., mean and std for a Gaussian distribution), - and hidden_state is the new hidden state of the RNN, if applicable. - """ - - class ContinuousActorProbabilisticInterface(Actor, ABC): """Marker interface for probabilistic actors defined by users (outside of Tianshou code).""" @@ -737,22 +757,26 @@ def is_discrete(self) -> bool: def forward( self, - obs: np.ndarray | torch.Tensor | BatchProtocol, - state: Any | None = None, + obs: TObs, + state: T | None = None, info: dict[str, Any] | None = None, - ) -> tuple[np.ndarray, Any | None]: + ) -> tuple[torch.Tensor, T | None]: batch_size = len(obs) if isinstance(self.action_space, spaces.Box): action = np.stack([self.action_space.sample() for _ in range(batch_size)]) else: # Discrete Actors currently return an n-dimensional array of probabilities for each action action = 1 / self.action_space.n * np.ones((batch_size, self.action_space.n)) - return action, state + return torch.Tensor(action), state - def compute_action_batch(self, obs: np.ndarray | torch.Tensor | BatchProtocol) -> np.ndarray: + def compute_action_batch(self, obs: TObs) -> torch.Tensor: if self.is_discrete: # Different from forward which returns discrete probabilities, see comment there assert isinstance(self.action_space, spaces.Discrete) # for mypy - return np.random.randint(low=0, high=self.action_space.n, size=len(obs)) + return torch.Tensor(np.random.randint(low=0, high=self.action_space.n, size=len(obs))) else: return self.forward(obs)[0] + + +class NetBase: + pass diff --git a/tianshou/utils/net/continuous.py b/tianshou/utils/net/continuous.py index fc06bc455..89adc9281 100644 --- 
a/tianshou/utils/net/continuous.py +++ b/tianshou/utils/net/continuous.py @@ -8,6 +8,7 @@ from sensai.util.pickle import setstate from torch import nn +from tianshou.data.types import TObs from tianshou.utils.net.common import ( MLP, Actor, @@ -71,7 +72,7 @@ def get_output_dim(self) -> int: def forward( self, - obs: np.ndarray | torch.Tensor, + obs: TObs, state: T | None = None, info: dict[str, Any] | None = None, ) -> tuple[torch.Tensor, T | None]: @@ -227,11 +228,10 @@ def get_preprocess_net(self) -> ModuleWithVectorOutput: def forward( self, - obs: np.ndarray | torch.Tensor, + obs: TObs, state: T | None = None, info: dict[str, Any] | None = None, ) -> tuple[tuple[torch.Tensor, torch.Tensor], T | None]: - """Mapping: obs -> logits -> (mu, sigma).""" if info is None: info = {} logits, hidden = self.preprocess(obs, state) diff --git a/tianshou/utils/net/discrete.py b/tianshou/utils/net/discrete.py index e1405be1f..53569fa33 100644 --- a/tianshou/utils/net/discrete.py +++ b/tianshou/utils/net/discrete.py @@ -7,6 +7,7 @@ from torch import nn from tianshou.data import Batch, to_torch +from tianshou.data.types import TObs from tianshou.utils.net.common import ( MLP, DiscreteActorInterface, @@ -68,7 +69,7 @@ def get_preprocess_net(self) -> ModuleWithVectorOutput: def forward( self, - obs: np.ndarray | torch.Tensor, + obs: TObs, state: T | None = None, info: dict[str, Any] | None = None, ) -> tuple[torch.Tensor, T | None]: @@ -117,11 +118,12 @@ def __init__( input_dim = preprocess_net.get_output_dim() self.last = MLP(input_dim=input_dim, output_dim=last_size, hidden_sizes=hidden_sizes) - # TODO: make a proper interface! 
- def forward(self, obs: np.ndarray | torch.Tensor, **kwargs: Any) -> torch.Tensor: + def forward( + self, obs: TObs, state: T | None = None, info: dict[str, Any] | None = None + ) -> torch.Tensor: """Mapping: s_B -> V(s)_B.""" # TODO: don't use this mechanism for passing state - logits, _ = self.preprocess(obs, state=kwargs.get("state", None)) + logits, _ = self.preprocess(obs, state=state) return self.last(logits) From 5ced88228660ab5bfd39ecd2d66ae0c2e418189f Mon Sep 17 00:00:00 2001 From: Michael Panchenko Date: Fri, 16 May 2025 18:27:46 +0200 Subject: [PATCH 187/230] Minor fix in type and getattr (needs explicit None) --- examples/discrete/discrete_dqn.py | 2 +- tianshou/algorithm/modelfree/dqn.py | 2 +- tianshou/algorithm/random.py | 5 +++-- 3 files changed, 5 insertions(+), 4 deletions(-) diff --git a/examples/discrete/discrete_dqn.py b/examples/discrete/discrete_dqn.py index 9ea6e92a5..37bf9fea3 100644 --- a/examples/discrete/discrete_dqn.py +++ b/examples/discrete/discrete_dqn.py @@ -40,7 +40,7 @@ def main() -> None: policy = DiscreteQLearningPolicy( model=net, action_space=env.action_space, eps_training=eps_train, eps_inference=eps_test ) - algorithm: ts.policy.DQN = ts.policy.DQN( + algorithm = ts.algorithm.DQN( policy=policy, optim=optim, gamma=gamma, diff --git a/tianshou/algorithm/modelfree/dqn.py b/tianshou/algorithm/modelfree/dqn.py index 7b4b7c902..57142d2d6 100644 --- a/tianshou/algorithm/modelfree/dqn.py +++ b/tianshou/algorithm/modelfree/dqn.py @@ -133,7 +133,7 @@ def forward( if model is None: model = self.model obs = batch.obs - mask = obs.mask + mask = getattr(obs, "mask", None) # TODO: this is convoluted! See also other places where this is done. 
obs_arr = obs.obs if hasattr(obs, "obs") else obs action_values_BA, hidden_BH = model(obs_arr, state=state, info=batch.info) diff --git a/tianshou/algorithm/random.py b/tianshou/algorithm/random.py index b374ef301..2d66040fa 100644 --- a/tianshou/algorithm/random.py +++ b/tianshou/algorithm/random.py @@ -3,7 +3,8 @@ import gymnasium as gym import numpy as np -from tianshou.algorithm.algorithm_base import OffPolicyAlgorithm, Policy, TrainingStats +from tianshou.algorithm.algorithm_base import OffPolicyAlgorithm, TrainingStats +from tianshou.algorithm.algorithm_base import Policy as BasePolicy from tianshou.data import Batch from tianshou.data.batch import BatchProtocol from tianshou.data.types import ActBatchProtocol, ObsBatchProtocol, RolloutBatchProtocol @@ -19,7 +20,7 @@ class MARLRandomDiscreteMaskedOffPolicyAlgorithm(OffPolicyAlgorithm): It randomly chooses an action from the legal actions (according to the given mask). """ - class Policy(Policy): + class Policy(BasePolicy): """A random agent used in multi-agent learning. It randomly chooses an action from the legal actions. 
From a4877ed602368fc56fb2154c7e6b23659be186fd Mon Sep 17 00:00:00 2001 From: Dominik Jain Date: Fri, 16 May 2025 19:58:28 +0200 Subject: [PATCH 188/230] Relax determinism tests: * Require only core message equivalence (network parameter hashes) for the test to pass * Allow to ignore certain messages on a per-test level --- test/continuous/test_redq.py | 7 +- test/determinism_test.py | 20 +++++- tianshou/utils/determinism.py | 120 ++++++++++++++++++++-------------- 3 files changed, 95 insertions(+), 52 deletions(-) diff --git a/test/continuous/test_redq.py b/test/continuous/test_redq.py index b685113c6..c2d80c587 100644 --- a/test/continuous/test_redq.py +++ b/test/continuous/test_redq.py @@ -176,4 +176,9 @@ def stop_fn(mean_rewards: float) -> bool: def test_redq_determinism() -> None: main_fn = lambda args: test_redq(args, enable_assertions=False) - AlgorithmDeterminismTest("continuous_redq", main_fn, get_args()).run() + ignored_messages = [ + "Params[actor_old]" + ] # actor_old only present in v1 (due to flawed inheritance) + AlgorithmDeterminismTest( + "continuous_redq", main_fn, get_args(), ignored_messages=ignored_messages + ).run() diff --git a/test/determinism_test.py b/test/determinism_test.py index fa6e9babe..7d6a9f1b1 100644 --- a/test/determinism_test.py +++ b/test/determinism_test.py @@ -1,5 +1,5 @@ from argparse import Namespace -from collections.abc import Callable +from collections.abc import Callable, Sequence from pathlib import Path from typing import Any @@ -49,6 +49,13 @@ class AlgorithmDeterminismTest: Enable this when running on the "old" branch and you want to prepare the snapshots for a comparison with the "new" branch. """ + PASS_IF_CORE_MESSAGES_UNCHANGED = True + """ + whether to pass the test if only the core messages are unchanged. + If this is False, then the full log is required to be equivalent, whereas if it is True, + only the core messages need to be equivalent. 
+ The core messages test whether the algorithm produces the same network parameters. + """ def __init__( self, @@ -56,6 +63,7 @@ def __init__( main_fn: Callable[[Namespace], Any], args: Namespace, is_offline: bool = False, + ignored_messages: Sequence[str] = (), ): """ :param name: the (unique!) name of the test @@ -64,10 +72,13 @@ def __init__( for the test) :param is_offline: whether the algorithm being tested is an offline algorithm and therefore does not configure the number of training environments (`training_num`) + :param ignored_messages: message fragments to ignore in the trace log (if any) """ self.determinism_test = TraceDeterminismTest( base_path=Path(__file__).parent / "resources" / "determinism", log_filename="determinism_tests.log", + core_messages=["Params"], + ignored_messages=ignored_messages, ) self.name = name @@ -104,4 +115,9 @@ def run(self, update_snapshot: bool = False) -> None: self.main_fn(self.args) log = trace.get_log() - self.determinism_test.check(log, self.name, create_reference_result=update_snapshot) + self.determinism_test.check( + log, + self.name, + create_reference_result=update_snapshot, + pass_if_core_messages_unchanged=self.PASS_IF_CORE_MESSAGES_UNCHANGED, + ) diff --git a/tianshou/utils/determinism.py b/tianshou/utils/determinism.py index 11bb325bb..77149f4a5 100644 --- a/tianshou/utils/determinism.py +++ b/tianshou/utils/determinism.py @@ -141,14 +141,19 @@ def filter_messages( self, required_messages: Sequence[str] = (), optional_messages: Sequence[str] = (), + ignored_messages: Sequence[str] = (), ) -> "TraceLog": """ - Reduces the set of log messages to a set of core messages that indicate that the fundamental - trace is still the same (same actions, same states, same images). + Applies inclusion and or exclusion filtering to the log messages. + If either `required_messages` or `optional_messages` is empty, inclusion filtering is applied. + If `ignored_messages` is empty, exclusion filtering is applied. 
+ If both inclusion and exclusion filtering are applied, the exclusion filtering takes precedence. - :param required_messages: message substrings to filter for; each message is required to appear at least once + :param required_messages: required message substrings to filter for; each message is required to appear at least once (triggering exception otherwise) :param optional_messages: additional messages fragments to filter for; these are not required + :param ignored_messages: message fragments that result in exclusion; takes precedence over + `required_messages` and `optional_messages` :return: the result with reduced log messages """ import numpy as np @@ -156,11 +161,17 @@ def filter_messages( required_message_counters = np.zeros(len(required_messages)) def retain_line(line: str) -> bool: - for i, main_message in enumerate(required_messages): - if main_message in line: - required_message_counters[i] += 1 - return True - return any(add_message in line for add_message in optional_messages) + for ignored_message in ignored_messages: + if ignored_message in line: + return False + if required_messages or optional_messages: + for i, main_message in enumerate(required_messages): + if main_message in line: + required_message_counters[i] += 1 + return True + return any(add_message in line for add_message in optional_messages) + else: + return True lines = [] for line in self.log_lines: @@ -241,16 +252,20 @@ def __init__( self, base_path: Path, core_messages: Sequence[str] = (), + ignored_messages: Sequence[str] = (), log_filename: str | None = None, ) -> None: """ :param base_path: the directory where the reference results are stored (will be created if necessary) :param core_messages: message fragments that make up the core of a trace; if empty, all messages are considered core + :param ignored_messages: message fragments to ignore in the trace log (if any); takes precedence over + `core_messages` :param log_filename: the name of the log file to which results are to 
be written (if any) """ base_path.mkdir(parents=True, exist_ok=True) self.base_path = base_path self.core_messages = core_messages + self.ignored_messages = ignored_messages self.log_filename = log_filename @dataclass(kw_only=True) @@ -263,6 +278,7 @@ def check( current_log: TraceLog, name: str, create_reference_result: bool = False, + pass_if_core_messages_unchanged: bool = False, ) -> None: """ Checks the given log against the reference result for the given name. @@ -285,8 +301,12 @@ def check( ) reference_log = reference_result.log - current_log_reduced = current_log.reduce_log_to_messages() - reference_log_reduced = reference_log.reduce_log_to_messages() + current_log_reduced = current_log.reduce_log_to_messages().filter_messages( + ignored_messages=self.ignored_messages + ) + reference_log_reduced = reference_log.reduce_log_to_messages().filter_messages( + ignored_messages=self.ignored_messages + ) results: list[tuple[TraceLog, str]] = [ (reference_log_reduced, "expected"), @@ -312,55 +332,57 @@ def check( result_main_messages = current_log_reduced reference_result_main_messages = reference_log_reduced - status_passed = True logs_equivalent = current_log_reduced.get_full_log() == reference_log_reduced.get_full_log() if logs_equivalent: status_passed = True status_message = "OK" else: - status_passed = False - - # save files for comparison - files = [] - for r, suffix in results: - path = os.path.abspath(f"determinism_{name}_{suffix}.txt") - r.save_log(path) - files.append(path) - - paths_str = "\n".join(files) - main_message = ( - f"Please inspect the changes by diffing the log files:\n{paths_str}\n" - f"If the changes are OK, enable the `create_reference_result` flag temporarily, " - "rerun the test and then commit the updated reference file.\n\nHere's the first part of the diff:\n" - ) - - # compute diff and add to message - num_diff_lines_to_show = 30 - for i, line in enumerate( - difflib.unified_diff( - reference_log_reduced.log_lines, - 
current_log_reduced.log_lines, - fromfile="expected.txt", - tofile="current.txt", - lineterm="", - ), - ): - if i == num_diff_lines_to_show: - break - main_message += line + "\n" - - core_messages_changed_only = ( + core_messages_unchanged = ( len(self.core_messages) > 0 and result_main_messages.get_full_log() == reference_result_main_messages.get_full_log() ) - if core_messages_changed_only: - status_message = ( - "The behaviour log has changed, but the core messages are still the same (so this " - f"probably isn't an issue). {main_message}" - ) + status_passed = core_messages_unchanged and pass_if_core_messages_unchanged + + if status_passed: + status_message = "OK (core messages unchanged)" else: - status_message = f"The behaviour log has changed; even the core messages are different. {main_message}" + # save files for comparison + files = [] + for r, suffix in results: + path = os.path.abspath(f"determinism_{name}_{suffix}.txt") + r.save_log(path) + files.append(path) + + paths_str = "\n".join(files) + main_message = ( + f"Please inspect the changes by diffing the log files:\n{paths_str}\n" + f"If the changes are OK, enable the `create_reference_result` flag temporarily, " + "rerun the test and then commit the updated reference file.\n\nHere's the first part of the diff:\n" + ) + + # compute diff and add to message + num_diff_lines_to_show = 30 + for i, line in enumerate( + difflib.unified_diff( + reference_log_reduced.log_lines, + current_log_reduced.log_lines, + fromfile="expected.txt", + tofile="current.txt", + lineterm="", + ), + ): + if i == num_diff_lines_to_show: + break + main_message += line + "\n" + + if core_messages_unchanged: + status_message = ( + "The behaviour log has changed, but the core messages are still the same (so this " + f"probably isn't an issue). {main_message}" + ) + else: + status_message = f"The behaviour log has changed; even the core messages are different. 
{main_message}" # write log message if self.log_filename: From ada0eaa013f005c871f329d8b8707d455d05b878 Mon Sep 17 00:00:00 2001 From: Dominik Jain Date: Fri, 16 May 2025 20:23:27 +0200 Subject: [PATCH 189/230] v2: Transfer recent algorithm parameter changes to high-level API --- tianshou/highlevel/algorithm.py | 2 +- tianshou/highlevel/params/algorithm_params.py | 54 ++++++++----------- 2 files changed, 24 insertions(+), 32 deletions(-) diff --git a/tianshou/highlevel/algorithm.py b/tianshou/highlevel/algorithm.py index 6f4dc5220..9a52f1366 100644 --- a/tianshou/highlevel/algorithm.py +++ b/tianshou/highlevel/algorithm.py @@ -675,7 +675,7 @@ def _create_policy( return self._create_policy_from_args( SACPolicy, params, - ["exploration_noise", "deterministic_eval", "action_scaling", "action_bound_method"], + ["exploration_noise", "deterministic_eval", "action_scaling"], actor=actor, action_space=envs.get_action_space(), observation_space=envs.get_observation_space(), diff --git a/tianshou/highlevel/params/algorithm_params.py b/tianshou/highlevel/params/algorithm_params.py index e6b1e45e7..b408e0205 100644 --- a/tianshou/highlevel/params/algorithm_params.py +++ b/tianshou/highlevel/params/algorithm_params.py @@ -235,6 +235,13 @@ class ParamsMixinActionScaling(GetParamTransformersProtocol): strategies. Should be disabled if the actor model already produces outputs in the correct range. """ + + def _get_param_transformers(self) -> list[ParamTransformer]: + return [ParamTransformerActionScaling("action_scaling")] + + +@dataclass(kw_only=True) +class ParamsMixinActionScalingAndBounding(ParamsMixinActionScaling): action_bound_method: Literal["clip", "tanh"] | None = "clip" """ the method used for bounding actions in continuous action spaces @@ -253,9 +260,6 @@ class ParamsMixinActionScaling(GetParamTransformersProtocol): Typically used together with `action_scaling=True`. 
""" - def _get_param_transformers(self) -> list[ParamTransformer]: - return [ParamTransformerActionScaling("action_scaling")] - @dataclass(kw_only=True) class ParamsMixinExplorationNoise(GetParamTransformersProtocol): @@ -337,13 +341,13 @@ class ParamsMixinDeterministicEval: class OnPolicyAlgorithmParams( Params, ParamsMixinGamma, - ParamsMixinActionScaling, + ParamsMixinActionScalingAndBounding, ParamsMixinSingleModel, ParamsMixinDeterministicEval, ): def _get_param_transformers(self) -> list[ParamTransformer]: transformers = super()._get_param_transformers() - transformers.extend(ParamsMixinActionScaling._get_param_transformers(self)) + transformers.extend(ParamsMixinActionScalingAndBounding._get_param_transformers(self)) transformers.extend(ParamsMixinSingleModel._get_param_transformers(self)) return transformers @@ -499,7 +503,7 @@ class PPOParams(A2CParams): @dataclass(kw_only=True) -class NPGParams(ReinforceParams, ParamsMixinGeneralAdvantageEstimation): +class NPGParams(ActorCriticOnPolicyParams, ParamsMixinGeneralAdvantageEstimation): optim_critic_iters: int = 5 """ the number of optimization steps performed on the critic network for each policy (actor) update. @@ -621,7 +625,7 @@ class SACParams(_SACParams, ParamsMixinExplorationNoise, ParamsMixinActionScalin def _get_param_transformers(self) -> list[ParamTransformer]: transformers = super()._get_param_transformers() transformers.extend(ParamsMixinExplorationNoise._get_param_transformers(self)) - transformers.extend(ParamsMixinActionScaling._get_param_transformers(self)) + transformers.extend(ParamsMixinActionScalingAndBounding._get_param_transformers(self)) return transformers @@ -647,21 +651,6 @@ class QLearningOffPolicyParams( Typically set between 100-10000 for DQN variants, with exact values depending on environment complexity. 
""" - return_scaling: bool = False - """ - flag indicating whether to enable scaling of estimated returns by - dividing them by their running standard deviation without centering the mean. - This reduces the magnitude variation of advantages across different episodes while - preserving their signs and relative ordering. - The use of running statistics (rather than batch-specific scaling) means that early - training experiences may be scaled differently than later ones as the statistics evolve. - When enabled, this improves training stability in environments with highly variable - reward scales and makes the algorithm less sensitive to learning rate settings. - However, it may reduce the algorithm's ability to distinguish between episodes with - different absolute return magnitudes. - Best used in environments where the relative ordering of actions is more important - than the absolute scale of returns. - """ eps_training: float = 0.0 """ the epsilon value for epsilon-greedy exploration during training. @@ -698,13 +687,16 @@ class DQNParams(QLearningOffPolicyParams): from the target network. Note: Double Q-learning will only be effective when a target network is used (target_update_freq > 0). """ - clip_loss_grad: bool = False + huber_loss_delta: float | None = None """ - flag indicating whether to use the Huber loss instead of the MSE loss for the TD error. - If True, uses the Huber loss as described in the Nature DQN paper (nature14236), which limits the influence - of outliers. Unlike the MSE loss where the gradients grow linearly with the error magnitude, the Huber + controls whether to use the Huber loss instead of the MSE loss for the TD error and the threshold for + the Huber loss. + If None, the MSE loss is used. + If not None, uses the Huber loss as described in the Nature DQN paper (nature14236) with the given delta, + which limits the influence of outliers. 
+ Unlike the MSE loss where the gradients grow linearly with the error magnitude, the Huber loss causes the gradients to plateau at a constant value for large errors, providing more stable training. - If False, uses the standard MSE loss where the gradient magnitude continues to scale with the error size. + NOTE: The magnitude of delta should depend on the scale of the returns obtained in the environment. """ def _get_param_transformers(self) -> list[ParamTransformer]: @@ -738,7 +730,7 @@ class DDPGParams( ParamsMixinGamma, ParamsMixinActorAndCritic, ParamsMixinExplorationNoise, - ParamsMixinActionScaling, + ParamsMixinActionScalingAndBounding, ParamsMixinNStepReturnHorizon, ParamsMixinTau, ): @@ -746,7 +738,7 @@ def _get_param_transformers(self) -> list[ParamTransformer]: transformers = super()._get_param_transformers() transformers.extend(ParamsMixinActorAndCritic._get_param_transformers(self)) transformers.extend(ParamsMixinExplorationNoise._get_param_transformers(self)) - transformers.extend(ParamsMixinActionScaling._get_param_transformers(self)) + transformers.extend(ParamsMixinActionScalingAndBounding._get_param_transformers(self)) return transformers @@ -806,7 +798,7 @@ class TD3Params( ParamsMixinGamma, ParamsMixinActorAndDualCritics, ParamsMixinExplorationNoise, - ParamsMixinActionScaling, + ParamsMixinActionScalingAndBounding, ParamsMixinNStepReturnHorizon, ParamsMixinTau, ): @@ -845,7 +837,7 @@ def _get_param_transformers(self) -> list[ParamTransformer]: transformers = super()._get_param_transformers() transformers.extend(ParamsMixinActorAndDualCritics._get_param_transformers(self)) transformers.extend(ParamsMixinExplorationNoise._get_param_transformers(self)) - transformers.extend(ParamsMixinActionScaling._get_param_transformers(self)) + transformers.extend(ParamsMixinActionScalingAndBounding._get_param_transformers(self)) transformers.append(ParamTransformerFloatEnvParamFactory("policy_noise")) 
transformers.append(ParamTransformerFloatEnvParamFactory("noise_clip")) return transformers From b2fd31fd3e32bbe2469a4e825f87a4a8b4534186 Mon Sep 17 00:00:00 2001 From: Dominik Jain Date: Fri, 16 May 2025 20:27:26 +0200 Subject: [PATCH 190/230] v2: Relax discrete_sac determinism test (to account for v1 inheritance flaws) --- test/discrete/test_discrete_sac.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/test/discrete/test_discrete_sac.py b/test/discrete/test_discrete_sac.py index cb747675c..e7690ca7b 100644 --- a/test/discrete/test_discrete_sac.py +++ b/test/discrete/test_discrete_sac.py @@ -161,4 +161,9 @@ def stop_fn(mean_rewards: float) -> bool: def test_discrete_sac_determinism() -> None: main_fn = lambda args: test_discrete_sac(args, enable_assertions=False) - AlgorithmDeterminismTest("discrete_sac", main_fn, get_args()).run() + ignored_messages = [ + "Params[actor_old]", # actor_old only present in v1 (due to flawed inheritance) + ] + AlgorithmDeterminismTest( + "discrete_sac", main_fn, get_args(), ignored_messages=ignored_messages + ).run() From a06c6ae387d81bf79ade0d5941ecaee4cbcab1ca Mon Sep 17 00:00:00 2001 From: Dominik Jain Date: Fri, 16 May 2025 23:12:12 +0200 Subject: [PATCH 191/230] v2: Establish backward compatibility with persisted v1 buffers --- tianshou/data/buffer/__init__.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/tianshou/data/buffer/__init__.py b/tianshou/data/buffer/__init__.py index e69de29bb..3a90d3bb8 100644 --- a/tianshou/data/buffer/__init__.py +++ b/tianshou/data/buffer/__init__.py @@ -0,0 +1,10 @@ +def _backward_compatibility(): + import sys + + from . 
import buffer_base + + # backward compatibility with persisted buffers from v1 for determinism tests + sys.modules["tianshou.data.buffer.base"] = buffer_base + + +_backward_compatibility() From fd3d19492cebd9550dc056bf089541b27f78141f Mon Sep 17 00:00:00 2001 From: Michael Panchenko Date: Fri, 16 May 2025 23:27:48 +0200 Subject: [PATCH 192/230] v2: renamed many params --- CHANGELOG.md | 2 + README.md | 37 +- docs/01_tutorials/00_dqn.rst | 6 +- docs/01_tutorials/04_tictactoe.rst | 18 +- docs/02_notebooks/L0_overview.ipynb | 496 ++++++------- docs/02_notebooks/L6_Trainer.ipynb | 562 +++++++-------- docs/02_notebooks/L7_Experiment.ipynb | 678 +++++++++--------- examples/atari/README.md | 24 +- examples/atari/atari_c51.py | 20 +- examples/atari/atari_dqn.py | 20 +- examples/atari/atari_dqn_hl.py | 12 +- examples/atari/atari_fqf.py | 20 +- examples/atari/atari_iqn.py | 20 +- examples/atari/atari_iqn_hl.py | 12 +- examples/atari/atari_ppo.py | 36 +- examples/atari/atari_ppo_hl.py | 24 +- examples/atari/atari_qrdqn.py | 20 +- examples/atari/atari_rainbow.py | 20 +- examples/atari/atari_sac.py | 22 +- examples/atari/atari_sac_hl.py | 20 +- examples/box2d/acrobot_dualdqn.py | 18 +- examples/box2d/bipedal_bdq.py | 18 +- examples/box2d/bipedal_hardcore_sac.py | 16 +- examples/box2d/lunarlander_dqn.py | 18 +- examples/box2d/mcc_sac.py | 16 +- examples/discrete/discrete_dqn.py | 8 +- examples/discrete/discrete_dqn_hl.py | 2 +- examples/inverse/irl_gail.py | 34 +- examples/modelbased/README.md | 6 +- examples/mujoco/README.md | 228 +++--- examples/mujoco/fetch_her_ddpg.py | 24 +- examples/mujoco/mujoco_a2c.py | 30 +- examples/mujoco/mujoco_a2c_hl.py | 20 +- examples/mujoco/mujoco_ddpg.py | 18 +- examples/mujoco/mujoco_ddpg_hl.py | 12 +- examples/mujoco/mujoco_npg.py | 38 +- examples/mujoco/mujoco_npg_hl.py | 32 +- examples/mujoco/mujoco_ppo.py | 34 +- examples/mujoco/mujoco_ppo_hl.py | 24 +- examples/mujoco/mujoco_ppo_hl_multi.py | 2 +- examples/mujoco/mujoco_redq.py | 18 +- 
examples/mujoco/mujoco_redq_hl.py | 12 +- examples/mujoco/mujoco_reinforce.py | 30 +- examples/mujoco/mujoco_reinforce_hl.py | 20 +- examples/mujoco/mujoco_sac.py | 18 +- examples/mujoco/mujoco_sac_hl.py | 12 +- examples/mujoco/mujoco_td3.py | 18 +- examples/mujoco/mujoco_td3_hl.py | 12 +- examples/mujoco/mujoco_trpo.py | 34 +- examples/mujoco/mujoco_trpo_hl.py | 24 +- examples/offline/atari_bcq.py | 6 +- examples/offline/atari_cql.py | 6 +- examples/offline/atari_crr.py | 6 +- examples/offline/atari_il.py | 6 +- examples/offline/d4rl_bcq.py | 8 +- examples/offline/d4rl_cql.py | 8 +- examples/offline/d4rl_il.py | 8 +- examples/offline/d4rl_td3_bc.py | 8 +- examples/vizdoom/README.md | 4 +- examples/vizdoom/env.py | 6 +- examples/vizdoom/vizdoom_c51.py | 18 +- examples/vizdoom/vizdoom_ppo.py | 34 +- test/continuous/test_ddpg.py | 18 +- test/continuous/test_npg.py | 34 +- test/continuous/test_ppo.py | 32 +- test/continuous/test_redq.py | 16 +- test/continuous/test_sac_with_il.py | 26 +- test/continuous/test_td3.py | 16 +- test/continuous/test_trpo.py | 30 +- test/determinism_test.py | 6 +- test/discrete/test_a2c_with_il.py | 34 +- test/discrete/test_bdqn.py | 20 +- test/discrete/test_c51.py | 18 +- test/discrete/test_discrete_sac.py | 16 +- test/discrete/test_dqn.py | 18 +- test/discrete/test_drqn.py | 18 +- test/discrete/test_fqf.py | 18 +- test/discrete/test_iqn.py | 18 +- test/discrete/test_ppo_discrete.py | 28 +- test/discrete/test_qrdqn.py | 18 +- test/discrete/test_rainbow.py | 18 +- test/discrete/test_reinforce.py | 24 +- test/highlevel/test_experiment_builder.py | 10 +- test/modelbased/test_dqn_icm.py | 18 +- test/modelbased/test_ppo_icm.py | 28 +- test/modelbased/test_psrl.py | 16 +- test/offline/gather_cartpole_data.py | 18 +- test/offline/gather_pendulum_data.py | 16 +- test/offline/test_bcq.py | 8 +- test/offline/test_cql.py | 8 +- test/offline/test_discrete_bcq.py | 8 +- test/offline/test_discrete_cql.py | 8 +- test/offline/test_discrete_crr.py | 8 +- 
test/offline/test_gail.py | 28 +- test/offline/test_td3_bc.py | 8 +- test/pettingzoo/pistonball.py | 18 +- test/pettingzoo/pistonball_continuous.py | 32 +- test/pettingzoo/tic_tac_toe.py | 18 +- tianshou/algorithm/modelfree/npg.py | 14 +- tianshou/algorithm/modelfree/ppo.py | 4 +- tianshou/algorithm/modelfree/trpo.py | 8 +- tianshou/env/atari/atari_wrapper.py | 18 +- tianshou/highlevel/config.py | 10 +- tianshou/highlevel/params/algorithm_params.py | 6 +- tianshou/trainer/trainer.py | 13 +- 105 files changed, 1884 insertions(+), 1848 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index f2e3160da..9ca3729e9 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -109,7 +109,9 @@ Developers: `LRSchedulerFactory`). The parameter `lr_scheduler` has thus been removed from all algorithm constructors. * The flag `updating` has been removed (no internal usage, general usefulness questionable). + * Removed `max_action_num`, instead read it off from `action_space` * Parameter changes: + * `actor_step_size` -> `trust_region_size` in NP * `discount_factor` -> `gamma` (was already used internally almost everywhere) * `reward_normalization` -> `return_standardization` or `return_scaling` (more precise naming) or removed (was actually unsupported by Q-learning algorithms) * `return_standardization` in `Reinforce` and `DiscreteCRR` (as it applies standardization of returns) diff --git a/README.md b/README.md index f7ae9d75e..cc14117a6 100644 --- a/README.md +++ b/README.md @@ -12,7 +12,6 @@ 1. Convenient high-level interfaces for applications of RL (training an implemented algorithm on a custom environment). 1. 
Large scope: online (on- and off-policy) and offline RL, experimental support for multi-agent RL (MARL), experimental support for model-based RL, and more - Unlike other reinforcement learning libraries, which may have complex codebases, unfriendly high-level APIs, or are not optimized for speed, Tianshou provides a high-performance, modularized framework and user-friendly interfaces for building deep reinforcement learning agents. One more aspect that sets Tianshou apart is its @@ -183,15 +182,17 @@ Atari and MuJoCo benchmark results can be found in the [examples/atari/](example ### Algorithm Abstraction Reinforcement learning algorithms are build on abstractions for - * on-policy algorithms (`OnPolicyAlgorithm`), - * off-policy algorithms (`OffPolicyAlgorithm`), and - * offline algorithms (`OfflineAlgorithm`), + +- on-policy algorithms (`OnPolicyAlgorithm`), +- off-policy algorithms (`OffPolicyAlgorithm`), and +- offline algorithms (`OfflineAlgorithm`), all of which clearly separate the core algorithm from the training process and the respective environment interactions. In each case, the implementation of an algorithm necessarily involves only the implementation of methods for - * pre-processing a batch of data, augmenting it with necessary information/sufficient statistics for learning (`_preprocess_batch`), - * updating model parameters based on an augmented batch of data (`_update_with_batch`). + +- pre-processing a batch of data, augmenting it with necessary information/sufficient statistics for learning (`_preprocess_batch`), +- updating model parameters based on an augmented batch of data (`_update_with_batch`). The implementation of these methods suffices for a new algorithm to be applicable within Tianshou, making experimentation with new approaches particularly straightforward. 
@@ -249,12 +250,12 @@ experiment = ( ), OffPolicyTrainingConfig( num_epochs=10, - step_per_epoch=10000, + epoch_num_steps=10000, batch_size=64, num_train_envs=10, num_test_envs=100, buffer_size=20000, - step_per_collect=10, + collection_step_num_env_steps=10, update_per_step=1 / 10, ), ) @@ -288,10 +289,10 @@ The experiment builder takes three arguments: - the training configuration, which controls fundamental training parameters, such as the total number of epochs we run the experiment for (`num_epochs=10`) and the number of environment steps each epoch shall consist of - (`step_per_epoch=10000`). + (`epoch_num_steps=10000`). Every epoch consists of a series of data collection (rollout) steps and training steps. - The parameter `step_per_collect` controls the amount of data that is + The parameter `collection_step_num_env_steps` controls the amount of data that is collected in each collection step and after each collection step, we perform a training step, applying a gradient-based update based on a sample of data (`batch_size=64`) taken from the buffer of data that has been @@ -299,10 +300,10 @@ The experiment builder takes three arguments: We then proceed to configure some of the parameters of the DQN algorithm itself and of the neural network model we want to use. -A DQN-specific detail is the way in which we control the epsilon parameter for -exploration. -We want to use random exploration during rollouts for training (`eps_training`), -but we don't when evaluating the agent's performance in the test environments +A DQN-specific detail is the way in which we control the epsilon parameter for +exploration. +We want to use random exploration during rollouts for training (`eps_training`), +but we don't when evaluating the agent's performance in the test environments (`eps_inference`). Find the script in [examples/discrete/discrete_dqn_hl.py](examples/discrete/discrete_dqn_hl.py). 
@@ -340,7 +341,7 @@ train_num, test_num = 10, 100 gamma, n_step, target_freq = 0.9, 3, 320 buffer_size = 20000 eps_train, eps_test = 0.1, 0.05 -step_per_epoch, step_per_collect = 10000, 10 +epoch_num_steps, collection_step_num_env_steps = 10000, 10 ``` Initialize the logger: @@ -400,11 +401,11 @@ result = ts.trainer.OffPolicyTrainer( train_collector=train_collector, test_collector=test_collector, max_epoch=epoch, - step_per_epoch=step_per_epoch, - step_per_collect=step_per_collect, + epoch_num_steps=epoch_num_steps, + collection_step_num_env_steps=collection_step_num_env_steps, episode_per_test=test_num, batch_size=batch_size, - update_per_step=1 / step_per_collect, + update_per_step=1 / collection_step_num_env_steps, train_fn=lambda epoch, env_step: policy.set_eps_training(eps_train), test_fn=lambda epoch, env_step: policy.set_eps_training(eps_test), stop_fn=lambda mean_rewards: mean_rewards >= env.spec.reward_threshold, diff --git a/docs/01_tutorials/00_dqn.rst b/docs/01_tutorials/00_dqn.rst index bb73d4c52..c77030800 100644 --- a/docs/01_tutorials/00_dqn.rst +++ b/docs/01_tutorials/00_dqn.rst @@ -198,7 +198,7 @@ reaches the stop condition ``stop_fn`` on test collector. Since DQN is an off-po policy=policy, train_collector=train_collector, test_collector=test_collector, - max_epoch=10, step_per_epoch=10000, step_per_collect=10, + max_epoch=10, epoch_num_steps=10000, collection_step_num_env_steps=10, update_per_step=0.1, episode_per_test=100, batch_size=64, train_fn=lambda epoch, env_step: policy.set_eps(0.1), test_fn=lambda epoch, env_step: policy.set_eps(0.05), @@ -209,8 +209,8 @@ reaches the stop condition ``stop_fn`` on test collector. Since DQN is an off-po The meaning of each parameter is as follows (full description can be found at :class:`~tianshou.trainer.OffpolicyTrainer`): * ``max_epoch``: The maximum of epochs for training. 
The training process might be finished before reaching the ``max_epoch``; -* ``step_per_epoch``: The number of environment step (a.k.a. transition) collected per epoch; -* ``step_per_collect``: The number of transition the collector would collect before the network update. For example, the code above means "collect 10 transitions and do one policy network update"; +* ``epoch_num_steps``: The number of environment step (a.k.a. transition) collected per epoch; +* ``collection_step_num_env_steps``: The number of transition the collector would collect before the network update. For example, the code above means "collect 10 transitions and do one policy network update"; * ``episode_per_test``: The number of episodes for one policy evaluation. * ``batch_size``: The batch size of sample data, which is going to feed in the policy network. * ``train_fn``: A function receives the current number of epoch and step index, and performs some operations at the beginning of training in this epoch. For example, the code above means "reset the epsilon to 0.1 in DQN before training". 
diff --git a/docs/01_tutorials/04_tictactoe.rst b/docs/01_tutorials/04_tictactoe.rst index 300d17019..22d712024 100644 --- a/docs/01_tutorials/04_tictactoe.rst +++ b/docs/01_tutorials/04_tictactoe.rst @@ -224,15 +224,15 @@ The explanation of each Tianshou class/function will be deferred to their first parser.add_argument('--n-step', type=int, default=3) parser.add_argument('--target-update-freq', type=int, default=320) parser.add_argument('--epoch', type=int, default=50) - parser.add_argument('--step-per-epoch', type=int, default=1000) - parser.add_argument('--step-per-collect', type=int, default=10) + parser.add_argument('--epoch_num_steps', type=int, default=1000) + parser.add_argument('--collection_step_num_env_steps', type=int, default=10) parser.add_argument('--update-per-step', type=float, default=0.1) - parser.add_argument('--batch-size', type=int, default=64) + parser.add_argument('--batch_size', type=int, default=64) parser.add_argument( '--hidden-sizes', type=int, nargs='*', default=[128, 128, 128, 128] ) - parser.add_argument('--training-num', type=int, default=10) - parser.add_argument('--test-num', type=int, default=10) + parser.add_argument('--num_train_envs', type=int, default=10) + parser.add_argument('--num_test_envs', type=int, default=10) parser.add_argument('--logdir', type=str, default='log') parser.add_argument('--render', type=float, default=0.1) parser.add_argument( @@ -356,7 +356,7 @@ With the above preparation, we are close to the first learned agent. The followi ) -> Tuple[dict, BasePolicy]: # ======== environment setup ========= - train_envs = DummyVectorEnv([get_env for _ in range(args.training_num)]) + train_envs = DummyVectorEnv([get_env for _ in range(args.num_train_envs)]) test_envs = DummyVectorEnv([get_env for _ in range(args.test_num)]) # seed np.random.seed(args.seed) @@ -378,7 +378,7 @@ With the above preparation, we are close to the first learned agent. 
The followi ) test_collector = Collector(policy, test_envs, exploration_noise=True) # policy.set_eps(1) - train_collector.collect(n_step=args.batch_size * args.training_num) + train_collector.collect(n_step=args.batch_size * args.num_train_envs) # ======== tensorboard logging setup ========= log_path = os.path.join(args.logdir, 'tic_tac_toe', 'dqn') @@ -416,8 +416,8 @@ With the above preparation, we are close to the first learned agent. The followi train_collector, test_collector, args.epoch, - args.step_per_epoch, - args.step_per_collect, + args.epoch_num_steps, + args.collection_step_num_env_steps, args.test_num, args.batch_size, train_fn=train_fn, diff --git a/docs/02_notebooks/L0_overview.ipynb b/docs/02_notebooks/L0_overview.ipynb index 4c983deb3..0d9495839 100644 --- a/docs/02_notebooks/L0_overview.ipynb +++ b/docs/02_notebooks/L0_overview.ipynb @@ -1,250 +1,250 @@ { - "cells": [ - { - "cell_type": "markdown", - "metadata": { - "editable": true, - "id": "r7aE6Rq3cAEE", - "slideshow": { - "slide_type": "" + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "editable": true, + "id": "r7aE6Rq3cAEE", + "slideshow": { + "slide_type": "" + }, + "tags": [] + }, + "source": [ + "# Overview\n", + "To begin, ensure you have Tianshou and the Gym environment installed by executing the following commands. This tutorials will always keep up with the latest version of Tianshou since they also serve as a test for the latest version. For users on older versions of Tianshou, please consult the [documentation](https://tianshou.readthedocs.io/en/latest/) corresponding to your version..\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "1_mLTSEIcY2c" + }, + "source": [ + "## Run the code" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "IcFNmCjYeIIU" + }, + "source": [ + "Below is a short script that use a certain DRL algorithm (PPO) to solve the classic CartPole-v1\n", + "problem in Gym. 
Simply run it and **don't worry** if you can't understand the code very well. That is\n", + "exactly what this tutorial is for.\n", + "\n", + "If the script ends normally, you will see the evaluation result printed out before the first\n", + "epoch is finished." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "editable": true, + "slideshow": { + "slide_type": "" + }, + "tags": [ + "hide-cell", + "remove-output" + ] + }, + "outputs": [], + "source": [ + "%%capture\n", + "\n", + "import gymnasium as gym\n", + "import torch\n", + "\n", + "from tianshou.algorithm import PPOPolicy\n", + "from tianshou.data import Collector, CollectStats, VectorReplayBuffer\n", + "from tianshou.env import DummyVectorEnv\n", + "from tianshou.trainer import OnpolicyTrainer\n", + "from tianshou.utils.net.common import ActorCritic, MLPActor\n", + "from tianshou.utils.net.discrete import Actor, Critic\n", + "\n", + "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "tags": [ + "hide-output" + ] + }, + "outputs": [], + "source": [ + "# environments\n", + "env = gym.make(\"CartPole-v1\")\n", + "train_envs = DummyVectorEnv([lambda: gym.make(\"CartPole-v1\") for _ in range(20)])\n", + "test_envs = DummyVectorEnv([lambda: gym.make(\"CartPole-v1\") for _ in range(10)])\n", + "\n", + "# model & optimizer\n", + "assert env.observation_space.shape is not None # for mypy\n", + "net = MLPActor(state_shape=env.observation_space.shape, hidden_sizes=[64, 64], device=device)\n", + "\n", + "assert isinstance(env.action_space, gym.spaces.Discrete) # for mypy\n", + "actor = Actor(preprocess_net=net, action_shape=env.action_space.n, device=device).to(device)\n", + "critic = Critic(preprocess_net=net, device=device).to(device)\n", + "actor_critic = ActorCritic(actor, critic)\n", + "optim = torch.optim.Adam(actor_critic.parameters(), lr=0.0003)\n", + "\n", + "# PPO policy\n", + "dist = 
torch.distributions.Categorical\n", + "policy: PPOPolicy = PPOPolicy(\n", + " actor=actor,\n", + " critic=critic,\n", + " optim=optim,\n", + " dist_fn=dist,\n", + " action_space=env.action_space,\n", + " action_scaling=False,\n", + ")\n", + "\n", + "# collector\n", + "train_collector = Collector[CollectStats](\n", + " policy,\n", + " train_envs,\n", + " VectorReplayBuffer(20000, len(train_envs)),\n", + ")\n", + "test_collector = Collector[CollectStats](policy, test_envs)\n", + "\n", + "# trainer\n", + "train_result = OnpolicyTrainer(\n", + " policy=policy,\n", + " batch_size=256,\n", + " train_collector=train_collector,\n", + " test_collector=test_collector,\n", + " max_epoch=10,\n", + " epoch_num_steps=50000,\n", + " update_step_num_repetitions=10,\n", + " episode_per_test=10,\n", + " collection_step_num_env_steps=2000,\n", + " stop_fn=lambda mean_reward: mean_reward >= 195,\n", + ").run()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "editable": true, + "slideshow": { + "slide_type": "" + }, + "tags": [] + }, + "outputs": [], + "source": [ + "train_result.pprint_asdict()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "G9YEQptYvCgx", + "outputId": "2a9b5b22-be50-4bb7-ae93-af7e65e7442a" + }, + "outputs": [], + "source": [ + "# Let's watch its performance!\n", + "policy.eval()\n", + "eval_result = test_collector.collect(n_episode=3, render=False)\n", + "print(f\"Final reward: {eval_result.returns.mean()}, length: {eval_result.lens.mean()}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "xFYlcPo8fpPU" + }, + "source": [ + "## Tutorial Introduction\n", + "\n", + "A common DRL experiment as is shown above may require many components to work together. The agent, the\n", + "environment (possibly parallelized ones), the replay buffer and the trainer all work together to complete a\n", + "training task.\n", + "\n", + "
\n", + "\n", + "\n", + "
\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "kV_uOyimj-bk" + }, + "source": [ + "In Tianshou, all of these main components are factored out as different building blocks, which you\n", + "can use to create your own algorithm and finish your own experiment.\n", + "\n", + "Building blocks may include:\n", + "- Batch\n", + "- Replay Buffer\n", + "- Vectorized Environment Wrapper\n", + "- Policy (the agent and the training algorithm)\n", + "- Data Collector\n", + "- Trainer\n", + "- Logger\n", + "\n", + "\n", + "These notebooks tutorials will guide you through all the modules one by one." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "S0mNKwH9i6Ek" + }, + "source": [ + "## Further reading" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "M3NPSUnAov4L" + }, + "source": [ + "### What if I am not familiar with the PPO algorithm itself?\n", + "As for the DRL algorithms themselves, we will refer you to the [Spinning up documentation](https://spinningup.openai.com/en/latest/algorithms/ppo.html), where they provide\n", + "plenty of resources and guides if you want to study the DRL algorithms. In Tianshou's tutorials, we will\n", + "focus on the usages of different modules, but not the algorithms themselves." + ] + } + ], + "metadata": { + "accelerator": "GPU", + "colab": { + "provenance": [] + }, + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.5" + } }, - "tags": [] - }, - "source": [ - "# Overview\n", - "To begin, ensure you have Tianshou and the Gym environment installed by executing the following commands. 
This tutorials will always keep up with the latest version of Tianshou since they also serve as a test for the latest version. For users on older versions of Tianshou, please consult the [documentation](https://tianshou.readthedocs.io/en/latest/) corresponding to your version..\n" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "1_mLTSEIcY2c" - }, - "source": [ - "## Run the code" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "IcFNmCjYeIIU" - }, - "source": [ - "Below is a short script that use a certain DRL algorithm (PPO) to solve the classic CartPole-v1\n", - "problem in Gym. Simply run it and **don't worry** if you can't understand the code very well. That is\n", - "exactly what this tutorial is for.\n", - "\n", - "If the script ends normally, you will see the evaluation result printed out before the first\n", - "epoch is finished." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, - "tags": [ - "hide-cell", - "remove-output" - ] - }, - "outputs": [], - "source": [ - "%%capture\n", - "\n", - "import gymnasium as gym\n", - "import torch\n", - "\n", - "from tianshou.algorithm import PPOPolicy\n", - "from tianshou.data import Collector, CollectStats, VectorReplayBuffer\n", - "from tianshou.env import DummyVectorEnv\n", - "from tianshou.trainer import OnpolicyTrainer\n", - "from tianshou.utils.net.common import ActorCritic, MLPActor\n", - "from tianshou.utils.net.discrete import Actor, Critic\n", - "\n", - "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "tags": [ - "hide-output" - ] - }, - "outputs": [], - "source": [ - "# environments\n", - "env = gym.make(\"CartPole-v1\")\n", - "train_envs = DummyVectorEnv([lambda: gym.make(\"CartPole-v1\") for _ in range(20)])\n", - "test_envs = DummyVectorEnv([lambda: gym.make(\"CartPole-v1\") for _ in 
range(10)])\n", - "\n", - "# model & optimizer\n", - "assert env.observation_space.shape is not None # for mypy\n", - "net = MLPActor(state_shape=env.observation_space.shape, hidden_sizes=[64, 64], device=device)\n", - "\n", - "assert isinstance(env.action_space, gym.spaces.Discrete) # for mypy\n", - "actor = Actor(preprocess_net=net, action_shape=env.action_space.n, device=device).to(device)\n", - "critic = Critic(preprocess_net=net, device=device).to(device)\n", - "actor_critic = ActorCritic(actor, critic)\n", - "optim = torch.optim.Adam(actor_critic.parameters(), lr=0.0003)\n", - "\n", - "# PPO policy\n", - "dist = torch.distributions.Categorical\n", - "policy: PPOPolicy = PPOPolicy(\n", - " actor=actor,\n", - " critic=critic,\n", - " optim=optim,\n", - " dist_fn=dist,\n", - " action_space=env.action_space,\n", - " action_scaling=False,\n", - ")\n", - "\n", - "# collector\n", - "train_collector = Collector[CollectStats](\n", - " policy,\n", - " train_envs,\n", - " VectorReplayBuffer(20000, len(train_envs)),\n", - ")\n", - "test_collector = Collector[CollectStats](policy, test_envs)\n", - "\n", - "# trainer\n", - "train_result = OnpolicyTrainer(\n", - " policy=policy,\n", - " batch_size=256,\n", - " train_collector=train_collector,\n", - " test_collector=test_collector,\n", - " max_epoch=10,\n", - " step_per_epoch=50000,\n", - " repeat_per_collect=10,\n", - " episode_per_test=10,\n", - " step_per_collect=2000,\n", - " stop_fn=lambda mean_reward: mean_reward >= 195,\n", - ").run()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, - "tags": [] - }, - "outputs": [], - "source": [ - "train_result.pprint_asdict()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "G9YEQptYvCgx", - "outputId": "2a9b5b22-be50-4bb7-ae93-af7e65e7442a" - }, - "outputs": [], - "source": [ - "# Let's watch 
its performance!\n", - "policy.eval()\n", - "eval_result = test_collector.collect(n_episode=3, render=False)\n", - "print(f\"Final reward: {eval_result.returns.mean()}, length: {eval_result.lens.mean()}\")" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "xFYlcPo8fpPU" - }, - "source": [ - "## Tutorial Introduction\n", - "\n", - "A common DRL experiment as is shown above may require many components to work together. The agent, the\n", - "environment (possibly parallelized ones), the replay buffer and the trainer all work together to complete a\n", - "training task.\n", - "\n", - "
\n", - "\n", - "\n", - "
\n" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "kV_uOyimj-bk" - }, - "source": [ - "In Tianshou, all of these main components are factored out as different building blocks, which you\n", - "can use to create your own algorithm and finish your own experiment.\n", - "\n", - "Building blocks may include:\n", - "- Batch\n", - "- Replay Buffer\n", - "- Vectorized Environment Wrapper\n", - "- Policy (the agent and the training algorithm)\n", - "- Data Collector\n", - "- Trainer\n", - "- Logger\n", - "\n", - "\n", - "These notebooks tutorials will guide you through all the modules one by one." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "S0mNKwH9i6Ek" - }, - "source": [ - "## Further reading" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "M3NPSUnAov4L" - }, - "source": [ - "### What if I am not familiar with the PPO algorithm itself?\n", - "As for the DRL algorithms themselves, we will refer you to the [Spinning up documentation](https://spinningup.openai.com/en/latest/algorithms/ppo.html), where they provide\n", - "plenty of resources and guides if you want to study the DRL algorithms. In Tianshou's tutorials, we will\n", - "focus on the usages of different modules, but not the algorithms themselves." 
- ] - } - ], - "metadata": { - "accelerator": "GPU", - "colab": { - "provenance": [] - }, - "kernelspec": { - "display_name": "Python 3 (ipykernel)", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.11.5" - } - }, - "nbformat": 4, - "nbformat_minor": 4 -} + "nbformat": 4, + "nbformat_minor": 4 +} \ No newline at end of file diff --git a/docs/02_notebooks/L6_Trainer.ipynb b/docs/02_notebooks/L6_Trainer.ipynb index ffa18168b..1bd8671e4 100644 --- a/docs/02_notebooks/L6_Trainer.ipynb +++ b/docs/02_notebooks/L6_Trainer.ipynb @@ -1,283 +1,283 @@ { - "cells": [ - { - "cell_type": "markdown", - "metadata": { - "id": "S3-tJZy35Ck_" - }, - "source": [ - "# Trainer\n", - "Trainer is the highest-level encapsulation in Tianshou. It controls the training loop and the evaluation method. It also controls the interaction between the Collector and the Policy, with the ReplayBuffer serving as the media.\n", - "\n", - "
\n", - "\n", - "
\n", - "\n", - "\n" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "ifsEQMzZ6mmz" - }, - "source": [ - "## Usages\n", - "In Tianshou v0.5.1, there are three types of Trainer. They are designed to be used in on-policy training, off-policy training and offline training respectively. We will use on-policy trainer as an example and leave the other two for further reading." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "XfsuU2AAE52C" - }, - "source": [ - "### Pseudocode\n", - "
\n", - "\n", - "
\n", - "\n", - "For the on-policy trainer, the main difference is that we clear the buffer after Line 10." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "Hcp_o0CCFz12" - }, - "source": [ - "### Training without trainer\n", - "As we have learned the usages of the Collector and the Policy, it's possible that we write our own training logic.\n", - "\n", - "First, let us create the instances of Environment, ReplayBuffer, Policy and Collector." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "editable": true, - "id": "do-xZ-8B7nVH", - "slideshow": { - "slide_type": "" + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "S3-tJZy35Ck_" + }, + "source": [ + "# Trainer\n", + "Trainer is the highest-level encapsulation in Tianshou. It controls the training loop and the evaluation method. It also controls the interaction between the Collector and the Policy, with the ReplayBuffer serving as the media.\n", + "\n", + "
\n", + "\n", + "
\n", + "\n", + "\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "ifsEQMzZ6mmz" + }, + "source": [ + "## Usages\n", + "In Tianshou v0.5.1, there are three types of Trainer. They are designed to be used in on-policy training, off-policy training and offline training respectively. We will use on-policy trainer as an example and leave the other two for further reading." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "XfsuU2AAE52C" + }, + "source": [ + "### Pseudocode\n", + "
\n", + "\n", + "
\n", + "\n", + "For the on-policy trainer, the main difference is that we clear the buffer after Line 10." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "Hcp_o0CCFz12" + }, + "source": [ + "### Training without trainer\n", + "As we have learned the usages of the Collector and the Policy, it's possible that we write our own training logic.\n", + "\n", + "First, let us create the instances of Environment, ReplayBuffer, Policy and Collector." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "editable": true, + "id": "do-xZ-8B7nVH", + "slideshow": { + "slide_type": "" + }, + "tags": [ + "hide-cell", + "remove-output" + ] + }, + "outputs": [], + "source": [ + "%%capture\n", + "\n", + "import gymnasium as gym\n", + "import torch\n", + "\n", + "from tianshou.algorithm import PGPolicy\n", + "from tianshou.data import Collector, CollectStats, VectorReplayBuffer\n", + "from tianshou.env import DummyVectorEnv\n", + "from tianshou.trainer import OnpolicyTrainer\n", + "from tianshou.utils.net.common import Net\n", + "from tianshou.utils.net.discrete import Actor\n", + "from tianshou.utils.torch_utils import policy_within_training_step, torch_train_mode" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "train_env_num = 4\n", + "# Since REINFORCE is an on-policy algorithm, we don't need a very large buffer size\n", + "buffer_size = 2000\n", + "\n", + "\n", + "# Create the environments, used for training and evaluation\n", + "env = gym.make(\"CartPole-v1\")\n", + "test_envs = DummyVectorEnv([lambda: gym.make(\"CartPole-v1\") for _ in range(2)])\n", + "train_envs = DummyVectorEnv([lambda: gym.make(\"CartPole-v1\") for _ in range(train_env_num)])\n", + "\n", + "# Create the Policy instance\n", + "assert env.observation_space.shape is not None\n", + "net = Net(\n", + " env.observation_space.shape,\n", + " hidden_sizes=[\n", + " 16,\n", + " ],\n", + ")\n", + "\n", + 
"assert isinstance(env.action_space, gym.spaces.Discrete)\n", + "actor = Actor(net, env.action_space.n)\n", + "optim = torch.optim.Adam(actor.parameters(), lr=0.001)\n", + "\n", + "# We choose to use REINFORCE algorithm, also known as Policy Gradient\n", + "policy: PGPolicy = PGPolicy(\n", + " actor=actor,\n", + " optim=optim,\n", + " dist_fn=torch.distributions.Categorical,\n", + " action_space=env.action_space,\n", + " action_scaling=False,\n", + ")\n", + "\n", + "# Create the replay buffer and the collector\n", + "replayBuffer = VectorReplayBuffer(buffer_size, train_env_num)\n", + "test_collector = Collector[CollectStats](policy, test_envs)\n", + "train_collector = Collector[CollectStats](policy, train_envs, replayBuffer)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "wiEGiBgQIiFM" + }, + "source": [ + "Now, we can try training our policy network. The logic is simple. We collect some data into the buffer and then we use the data to train our policy." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "JMUNPN5SI_kd", + "outputId": "7d68323c-0322-4b82-dafb-7c7f63e7a26d" + }, + "outputs": [], + "source": [ + "train_collector.reset()\n", + "train_envs.reset()\n", + "test_collector.reset()\n", + "test_envs.reset()\n", + "replayBuffer.reset()\n", + "\n", + "n_episode = 10\n", + "for _i in range(n_episode):\n", + " # for test collector, we set the wrapped torch module to evaluation mode\n", + " # by default, the policy object itself is not within the training step\n", + " with torch_train_mode(policy, enabled=False):\n", + " evaluation_result = test_collector.collect(n_episode=n_episode)\n", + " print(f\"Evaluation mean episodic reward is: {evaluation_result.returns.mean()}\")\n", + " # for collecting data for training, the policy object should be within the training step\n", + " # (affecting e.g. 
whether the policy is stochastic or deterministic)\n", + " with policy_within_training_step(policy):\n", + " train_collector.collect(n_step=2000)\n", + " # 0 means taking all data stored in train_collector.buffer\n", + " # for updating the policy, the wrapped torch module should be in training mode\n", + " with torch_train_mode(policy):\n", + " policy.update(sample_size=None, buffer=train_collector.buffer, batch_size=512, repeat=1)\n", + " train_collector.reset_buffer(keep_statistics=True)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "QXBHIBckMs_2" + }, + "source": [ + "The evaluation reward doesn't seem to improve. That is simply because we haven't trained it for enough time. Plus, the network size is too small and REINFORCE algorithm is actually not very stable. Don't worry, we will solve this problem in the end. Still we get some idea on how to start a training loop." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "p-7U_cwgF5Ej" + }, + "source": [ + "### Training with trainer\n", + "The trainer does almost the same thing. The only difference is that it has considered many details and is more modular." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "vcvw9J8RNtFE", + "outputId": "b483fa8b-2a57-4051-a3d0-6d8162d948c5", + "tags": [ + "remove-output" + ] + }, + "outputs": [], + "source": [ + "train_collector.reset()\n", + "train_envs.reset()\n", + "test_collector.reset()\n", + "test_envs.reset()\n", + "replayBuffer.reset()\n", + "\n", + "result = OnpolicyTrainer(\n", + " policy=policy,\n", + " train_collector=train_collector,\n", + " test_collector=test_collector,\n", + " max_epoch=10,\n", + " epoch_num_steps=1,\n", + " update_step_num_repetitions=1,\n", + " episode_per_test=10,\n", + " collection_step_num_env_steps=2000,\n", + " batch_size=512,\n", + ").run()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "result.pprint_asdict()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "_j3aUJZQ7nml" + }, + "source": [ + "## Further Reading\n", + "### Logger usages\n", + "Tianshou provides experiment loggers that are both tensorboard- and wandb-compatible. It also has a BaseLogger Class which allows you to self-define your own logger. 
Check the [documentation](https://tianshou.org/en/master/03_api/utils/logger/base.html#tianshou.utils.logger.base.BaseLogger) for details.\n", + "\n", + "### Learn more about the APIs of Trainers\n", + "[documentation](https://tianshou.org/en/master/03_api/trainer/index.html)" + ] + } + ], + "metadata": { + "colab": { + "collapsed_sections": [ + "S3-tJZy35Ck_", + "XfsuU2AAE52C", + "p-7U_cwgF5Ej", + "_j3aUJZQ7nml" + ], + "provenance": [] + }, + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.4" + } }, - "tags": [ - "hide-cell", - "remove-output" - ] - }, - "outputs": [], - "source": [ - "%%capture\n", - "\n", - "import gymnasium as gym\n", - "import torch\n", - "\n", - "from tianshou.algorithm import PGPolicy\n", - "from tianshou.data import Collector, CollectStats, VectorReplayBuffer\n", - "from tianshou.env import DummyVectorEnv\n", - "from tianshou.trainer import OnpolicyTrainer\n", - "from tianshou.utils.net.common import Net\n", - "from tianshou.utils.net.discrete import Actor\n", - "from tianshou.utils.torch_utils import policy_within_training_step, torch_train_mode" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "train_env_num = 4\n", - "# Since REINFORCE is an on-policy algorithm, we don't need a very large buffer size\n", - "buffer_size = 2000\n", - "\n", - "\n", - "# Create the environments, used for training and evaluation\n", - "env = gym.make(\"CartPole-v1\")\n", - "test_envs = DummyVectorEnv([lambda: gym.make(\"CartPole-v1\") for _ in range(2)])\n", - "train_envs = DummyVectorEnv([lambda: gym.make(\"CartPole-v1\") for _ in range(train_env_num)])\n", - "\n", - "# Create the Policy 
instance\n", - "assert env.observation_space.shape is not None\n", - "net = Net(\n", - " env.observation_space.shape,\n", - " hidden_sizes=[\n", - " 16,\n", - " ],\n", - ")\n", - "\n", - "assert isinstance(env.action_space, gym.spaces.Discrete)\n", - "actor = Actor(net, env.action_space.n)\n", - "optim = torch.optim.Adam(actor.parameters(), lr=0.001)\n", - "\n", - "# We choose to use REINFORCE algorithm, also known as Policy Gradient\n", - "policy: PGPolicy = PGPolicy(\n", - " actor=actor,\n", - " optim=optim,\n", - " dist_fn=torch.distributions.Categorical,\n", - " action_space=env.action_space,\n", - " action_scaling=False,\n", - ")\n", - "\n", - "# Create the replay buffer and the collector\n", - "replayBuffer = VectorReplayBuffer(buffer_size, train_env_num)\n", - "test_collector = Collector[CollectStats](policy, test_envs)\n", - "train_collector = Collector[CollectStats](policy, train_envs, replayBuffer)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "wiEGiBgQIiFM" - }, - "source": [ - "Now, we can try training our policy network. The logic is simple. We collect some data into the buffer and then we use the data to train our policy." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "JMUNPN5SI_kd", - "outputId": "7d68323c-0322-4b82-dafb-7c7f63e7a26d" - }, - "outputs": [], - "source": [ - "train_collector.reset()\n", - "train_envs.reset()\n", - "test_collector.reset()\n", - "test_envs.reset()\n", - "replayBuffer.reset()\n", - "\n", - "n_episode = 10\n", - "for _i in range(n_episode):\n", - " # for test collector, we set the wrapped torch module to evaluation mode\n", - " # by default, the policy object itself is not within the training step\n", - " with torch_train_mode(policy, enabled=False):\n", - " evaluation_result = test_collector.collect(n_episode=n_episode)\n", - " print(f\"Evaluation mean episodic reward is: {evaluation_result.returns.mean()}\")\n", - " # for collecting data for training, the policy object should be within the training step\n", - " # (affecting e.g. whether the policy is stochastic or deterministic)\n", - " with policy_within_training_step(policy):\n", - " train_collector.collect(n_step=2000)\n", - " # 0 means taking all data stored in train_collector.buffer\n", - " # for updating the policy, the wrapped torch module should be in training mode\n", - " with torch_train_mode(policy):\n", - " policy.update(sample_size=None, buffer=train_collector.buffer, batch_size=512, repeat=1)\n", - " train_collector.reset_buffer(keep_statistics=True)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "QXBHIBckMs_2" - }, - "source": [ - "The evaluation reward doesn't seem to improve. That is simply because we haven't trained it for enough time. Plus, the network size is too small and REINFORCE algorithm is actually not very stable. Don't worry, we will solve this problem in the end. Still we get some idea on how to start a training loop." 
- ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "p-7U_cwgF5Ej" - }, - "source": [ - "### Training with trainer\n", - "The trainer does almost the same thing. The only difference is that it has considered many details and is more modular." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "vcvw9J8RNtFE", - "outputId": "b483fa8b-2a57-4051-a3d0-6d8162d948c5", - "tags": [ - "remove-output" - ] - }, - "outputs": [], - "source": [ - "train_collector.reset()\n", - "train_envs.reset()\n", - "test_collector.reset()\n", - "test_envs.reset()\n", - "replayBuffer.reset()\n", - "\n", - "result = OnpolicyTrainer(\n", - " policy=policy,\n", - " train_collector=train_collector,\n", - " test_collector=test_collector,\n", - " max_epoch=10,\n", - " step_per_epoch=1,\n", - " repeat_per_collect=1,\n", - " episode_per_test=10,\n", - " step_per_collect=2000,\n", - " batch_size=512,\n", - ").run()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "tags": [] - }, - "outputs": [], - "source": [ - "result.pprint_asdict()" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "_j3aUJZQ7nml" - }, - "source": [ - "## Further Reading\n", - "### Logger usages\n", - "Tianshou provides experiment loggers that are both tensorboard- and wandb-compatible. It also has a BaseLogger Class which allows you to self-define your own logger. 
Check the [documentation](https://tianshou.org/en/master/03_api/utils/logger/base.html#tianshou.utils.logger.base.BaseLogger) for details.\n", - "\n", - "### Learn more about the APIs of Trainers\n", - "[documentation](https://tianshou.org/en/master/03_api/trainer/index.html)" - ] - } - ], - "metadata": { - "colab": { - "collapsed_sections": [ - "S3-tJZy35Ck_", - "XfsuU2AAE52C", - "p-7U_cwgF5Ej", - "_j3aUJZQ7nml" - ], - "provenance": [] - }, - "kernelspec": { - "display_name": "Python 3 (ipykernel)", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.11.4" - } - }, - "nbformat": 4, - "nbformat_minor": 4 -} + "nbformat": 4, + "nbformat_minor": 4 +} \ No newline at end of file diff --git a/docs/02_notebooks/L7_Experiment.ipynb b/docs/02_notebooks/L7_Experiment.ipynb index 9da33c98b..0f256f510 100644 --- a/docs/02_notebooks/L7_Experiment.ipynb +++ b/docs/02_notebooks/L7_Experiment.ipynb @@ -1,341 +1,341 @@ { - "cells": [ - { - "cell_type": "markdown", - "metadata": { - "id": "_UaXOSRjDUF9" - }, - "source": [ - "# Experiment\n", - "Finally, we can assemble building blocks that we have came across in previous tutorials to conduct our first DRL experiment. In this experiment, we will use [PPO](https://arxiv.org/abs/1707.06347) algorithm to solve the classic CartPole task in Gym." 
- ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "2QRbCJvDHNAd" - }, - "source": [ - "## Experiment\n", - "To conduct this experiment, we need the following building blocks.\n", - "\n", - "\n", - "* Two vectorized environments, one for training and one for evaluation\n", - "* A PPO agent\n", - "* A replay buffer to store transition data\n", - "* Two collectors to manage the data collecting process, one for training and one for evaluation\n", - "* A trainer to manage the training loop\n", - "\n", - "
\n", - "\n", - "\n", - "
\n", - "\n", - "Let us do this step by step." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "-Hh4E6i0Hj0I" - }, - "source": [ - "## Preparation\n", - "Firstly, install Tianshou if you haven't installed it before." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "7E4EhiBeHxD5" - }, - "source": [ - "Import libraries we might need later." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "editable": true, - "id": "ao9gWJDiHgG-", - "tags": [ - "hide-cell", - "remove-output" - ] - }, - "outputs": [], - "source": [ - "%%capture\n", - "\n", - "import gymnasium as gym\n", - "import torch\n", - "\n", - "from tianshou.algorithm import PPOPolicy\n", - "from tianshou.data import Collector, CollectStats, VectorReplayBuffer\n", - "from tianshou.env import DummyVectorEnv\n", - "from tianshou.trainer import OnpolicyTrainer\n", - "from tianshou.utils.net.common import ActorCritic, MLPActor\n", - "from tianshou.utils.net.discrete import Actor, Critic\n", - "\n", - "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "QnRg5y7THRYw" - }, - "source": [ - "## Environment" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "YZERKCGtH8W1" - }, - "source": [ - "We create two vectorized environments both for training and testing. Since the execution time of CartPole is extremely short, there is no need to use multi-process wrappers and we simply use DummyVectorEnv." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "Mpuj5PFnDKVS" - }, - "outputs": [], - "source": [ - "env = gym.make(\"CartPole-v1\")\n", - "train_envs = DummyVectorEnv([lambda: gym.make(\"CartPole-v1\") for _ in range(20)])\n", - "test_envs = DummyVectorEnv([lambda: gym.make(\"CartPole-v1\") for _ in range(10)])" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "BJtt_Ya8DTAh" - }, - "source": [ - "## Policy\n", - "Next we need to initialize our PPO policy. PPO is an actor-critic-style on-policy algorithm, so we have to define the actor and the critic in PPO first.\n", - "\n", - "The actor is a neural network that shares the same network head with the critic. Both networks' input is the environment observation. The output of the actor is the action and the output of the critic is a single value, representing the value of the current policy.\n", - "\n", - "Luckily, Tianshou already provides basic network modules that we can use in this experiment." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "_Vy8uPWXP4m_" - }, - "outputs": [], - "source": [ - "# net is the shared head of the actor and the critic\n", - "assert env.observation_space.shape is not None # for mypy\n", - "assert isinstance(env.action_space, gym.spaces.Discrete) # for mypy\n", - "net = MLPActor(state_shape=env.observation_space.shape, hidden_sizes=[64, 64], device=device)\n", - "actor = Actor(preprocess_net=net, action_shape=env.action_space.n, device=device).to(device)\n", - "critic = Critic(preprocess_net=net, device=device).to(device)\n", - "actor_critic = ActorCritic(actor=actor, critic=critic)\n", - "\n", - "# optimizer of the actor and the critic\n", - "optim = torch.optim.Adam(actor_critic.parameters(), lr=0.0003)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "Lh2-hwE5Dn9I" - }, - "source": [ - "Once we have defined the actor, the critic and the optimizer, we can use them to construct our PPO agent. CartPole is a discrete action space problem, so the distribution of our action space can be a categorical distribution." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "OiJ2GkT0Qnbr" - }, - "outputs": [], - "source": [ - "dist = torch.distributions.Categorical\n", - "policy: PPOPolicy = PPOPolicy(\n", - " actor=actor,\n", - " critic=critic,\n", - " optim=optim,\n", - " dist_fn=dist,\n", - " action_space=env.action_space,\n", - " deterministic_eval=True,\n", - " action_scaling=False,\n", - ")" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "okxfj6IEQ-r8" - }, - "source": [ - "`deterministic_eval=True` means that we want to sample actions during training but we would like to always use the best action in evaluation. No randomness included." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "n5XAAbuBZarO" - }, - "source": [ - "## Collector\n", - "We can set up the collectors now. 
Train collector is used to collect and store training data, so an additional replay buffer has to be passed in." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "ezwz0qerZhQM" - }, - "outputs": [], - "source": [ - "train_collector = Collector[CollectStats](\n", - " policy=policy,\n", - " env=train_envs,\n", - " buffer=VectorReplayBuffer(20000, len(train_envs)),\n", - ")\n", - "test_collector = Collector[CollectStats](policy=policy, env=test_envs)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "ZaoPxOd2hm0b" - }, - "source": [ - "We use `VectorReplayBuffer` here because it's more efficient to collaborate with vectorized environments, you can simply consider `VectorReplayBuffer` as a a list of ordinary replay buffers." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "qBoE9pLUiC-8" - }, - "source": [ - "## Trainer\n", - "Finally, we can use the trainer to help us set up the training loop." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "_UaXOSRjDUF9" + }, + "source": [ + "# Experiment\n", + "Finally, we can assemble building blocks that we have come across in previous tutorials to conduct our first DRL experiment. In this experiment, we will use the [PPO](https://arxiv.org/abs/1707.06347) algorithm to solve the classic CartPole task in Gym." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "2QRbCJvDHNAd" + }, + "source": [ + "## Experiment\n", + "To conduct this experiment, we need the following building blocks.\n", + "\n", + "\n", + "* Two vectorized environments, one for training and one for evaluation\n", + "* A PPO agent\n", + "* A replay buffer to store transition data\n", + "* Two collectors to manage the data collecting process, one for training and one for evaluation\n", + "* A trainer to manage the training loop\n", + "\n", + "
\n", + "\n", + "\n", + "
\n", + "\n", + "Let us do this step by step." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "-Hh4E6i0Hj0I" + }, + "source": [ + "## Preparation\n", + "Firstly, install Tianshou if you haven't installed it before." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "7E4EhiBeHxD5" + }, + "source": [ + "Import libraries we might need later." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "editable": true, + "id": "ao9gWJDiHgG-", + "tags": [ + "hide-cell", + "remove-output" + ] + }, + "outputs": [], + "source": [ + "%%capture\n", + "\n", + "import gymnasium as gym\n", + "import torch\n", + "\n", + "from tianshou.algorithm import PPOPolicy\n", + "from tianshou.data import Collector, CollectStats, VectorReplayBuffer\n", + "from tianshou.env import DummyVectorEnv\n", + "from tianshou.trainer import OnpolicyTrainer\n", + "from tianshou.utils.net.common import ActorCritic, MLPActor\n", + "from tianshou.utils.net.discrete import Actor, Critic\n", + "\n", + "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "QnRg5y7THRYw" + }, + "source": [ + "## Environment" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "YZERKCGtH8W1" + }, + "source": [ + "We create two vectorized environments both for training and testing. Since the execution time of CartPole is extremely short, there is no need to use multi-process wrappers and we simply use DummyVectorEnv." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "Mpuj5PFnDKVS" + }, + "outputs": [], + "source": [ + "env = gym.make(\"CartPole-v1\")\n", + "train_envs = DummyVectorEnv([lambda: gym.make(\"CartPole-v1\") for _ in range(20)])\n", + "test_envs = DummyVectorEnv([lambda: gym.make(\"CartPole-v1\") for _ in range(10)])" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "BJtt_Ya8DTAh" + }, + "source": [ + "## Policy\n", + "Next we need to initialize our PPO policy. PPO is an actor-critic-style on-policy algorithm, so we have to define the actor and the critic in PPO first.\n", + "\n", + "The actor is a neural network that shares the same network head with the critic. Both networks' input is the environment observation. The output of the actor is the action and the output of the critic is a single value, representing the value of the current policy.\n", + "\n", + "Luckily, Tianshou already provides basic network modules that we can use in this experiment." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "_Vy8uPWXP4m_" + }, + "outputs": [], + "source": [ + "# net is the shared head of the actor and the critic\n", + "assert env.observation_space.shape is not None # for mypy\n", + "assert isinstance(env.action_space, gym.spaces.Discrete) # for mypy\n", + "net = MLPActor(state_shape=env.observation_space.shape, hidden_sizes=[64, 64], device=device)\n", + "actor = Actor(preprocess_net=net, action_shape=env.action_space.n, device=device).to(device)\n", + "critic = Critic(preprocess_net=net, device=device).to(device)\n", + "actor_critic = ActorCritic(actor=actor, critic=critic)\n", + "\n", + "# optimizer of the actor and the critic\n", + "optim = torch.optim.Adam(actor_critic.parameters(), lr=0.0003)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "Lh2-hwE5Dn9I" + }, + "source": [ + "Once we have defined the actor, the critic and the optimizer, we can use them to construct our PPO agent. CartPole is a discrete action space problem, so the distribution of our action space can be a categorical distribution." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "OiJ2GkT0Qnbr" + }, + "outputs": [], + "source": [ + "dist = torch.distributions.Categorical\n", + "policy: PPOPolicy = PPOPolicy(\n", + " actor=actor,\n", + " critic=critic,\n", + " optim=optim,\n", + " dist_fn=dist,\n", + " action_space=env.action_space,\n", + " deterministic_eval=True,\n", + " action_scaling=False,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "okxfj6IEQ-r8" + }, + "source": [ + "`deterministic_eval=True` means that we want to sample actions during training but we would like to always use the best action in evaluation. No randomness included." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "n5XAAbuBZarO" + }, + "source": [ + "## Collector\n", + "We can set up the collectors now. 
Train collector is used to collect and store training data, so an additional replay buffer has to be passed in." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "ezwz0qerZhQM" + }, + "outputs": [], + "source": [ + "train_collector = Collector[CollectStats](\n", + " policy=policy,\n", + " env=train_envs,\n", + " buffer=VectorReplayBuffer(20000, len(train_envs)),\n", + ")\n", + "test_collector = Collector[CollectStats](policy=policy, env=test_envs)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "ZaoPxOd2hm0b" + }, + "source": [ + "We use `VectorReplayBuffer` here because it works more efficiently with vectorized environments; you can simply think of `VectorReplayBuffer` as a list of ordinary replay buffers." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "qBoE9pLUiC-8" + }, + "source": [ + "## Trainer\n", + "Finally, we can use the trainer to help us set up the training loop." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "editable": true, + "id": "i45EDnpxQ8gu", + "outputId": "b1666b88-0bfa-4340-868e-58611872d988", + "tags": [ + "remove-output" + ] + }, + "outputs": [], + "source": [ + "result = OnpolicyTrainer(\n", + " policy=policy,\n", + " train_collector=train_collector,\n", + " test_collector=test_collector,\n", + " max_epoch=10,\n", + " epoch_num_steps=50000,\n", + " update_step_num_repetitions=10,\n", + " episode_per_test=10,\n", + " batch_size=256,\n", + " collection_step_num_env_steps=2000,\n", + " stop_fn=lambda mean_reward: mean_reward >= 195,\n", + ").run()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "ckgINHE2iTFR" + }, + "source": [ + "## Results\n", + "Print the training result."
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "tJCPgmiyiaaX", + "outputId": "40123ae3-3365-4782-9563-46c43812f10f", + "tags": [] + }, + "outputs": [], + "source": [ + "result.pprint_asdict()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "A-MJ9avMibxN" + }, + "source": [ + "We can also test our trained agent." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "mnMANFcciiAQ", + "outputId": "6febcc1e-7265-4a75-c9dd-34e29a3e5d21" + }, + "outputs": [], + "source": [ + "# Let's watch its performance!\n", + "policy.eval()\n", + "result = test_collector.collect(n_episode=1, render=False)\n", + "print(f\"Final episode reward: {result.returns.mean()}, length: {result.lens.mean()}\")" + ] + } + ], + "metadata": { + "colab": { + "provenance": [], + "toc_visible": true + }, + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.5" + } }, - "editable": true, - "id": "i45EDnpxQ8gu", - "outputId": "b1666b88-0bfa-4340-868e-58611872d988", - "tags": [ - "remove-output" - ] - }, - "outputs": [], - "source": [ - "result = OnpolicyTrainer(\n", - " policy=policy,\n", - " train_collector=train_collector,\n", - " test_collector=test_collector,\n", - " max_epoch=10,\n", - " step_per_epoch=50000,\n", - " repeat_per_collect=10,\n", - " episode_per_test=10,\n", - " batch_size=256,\n", - " step_per_collect=2000,\n", - " stop_fn=lambda mean_reward: mean_reward >= 195,\n", - ").run()" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "ckgINHE2iTFR" - }, - "source": [ - "## 
Results\n", - "Print the training result." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "tJCPgmiyiaaX", - "outputId": "40123ae3-3365-4782-9563-46c43812f10f", - "tags": [] - }, - "outputs": [], - "source": [ - "result.pprint_asdict()" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "A-MJ9avMibxN" - }, - "source": [ - "We can also test our trained agent." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "mnMANFcciiAQ", - "outputId": "6febcc1e-7265-4a75-c9dd-34e29a3e5d21" - }, - "outputs": [], - "source": [ - "# Let's watch its performance!\n", - "policy.eval()\n", - "result = test_collector.collect(n_episode=1, render=False)\n", - "print(f\"Final episode reward: {result.returns.mean()}, length: {result.lens.mean()}\")" - ] - } - ], - "metadata": { - "colab": { - "provenance": [], - "toc_visible": true - }, - "kernelspec": { - "display_name": "Python 3 (ipykernel)", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.11.5" - } - }, - "nbformat": 4, - "nbformat_minor": 4 -} + "nbformat": 4, + "nbformat_minor": 4 +} \ No newline at end of file diff --git a/examples/atari/README.md b/examples/atari/README.md index 62e58487b..ae80141d8 100644 --- a/examples/atari/README.md +++ b/examples/atari/README.md @@ -24,13 +24,13 @@ One epoch here is equal to 100,000 env step, 100 epochs stand for 10M. 
| task | best reward | reward curve | parameters | time cost | | --------------------------- | ----------- | ------------------------------------- | ------------------------------------------------------------ | ------------------- | -| PongNoFrameskip-v4 | 20 | ![](results/dqn/Pong_rew.png) | `python3 atari_dqn.py --task "PongNoFrameskip-v4" --batch-size 64` | ~30 min (~15 epoch) | -| BreakoutNoFrameskip-v4 | 316 | ![](results/dqn/Breakout_rew.png) | `python3 atari_dqn.py --task "BreakoutNoFrameskip-v4" --test-num 100` | 3~4h (100 epoch) | -| EnduroNoFrameskip-v4 | 670 | ![](results/dqn/Enduro_rew.png) | `python3 atari_dqn.py --task "EnduroNoFrameskip-v4 " --test-num 100` | 3~4h (100 epoch) | -| QbertNoFrameskip-v4 | 7307 | ![](results/dqn/Qbert_rew.png) | `python3 atari_dqn.py --task "QbertNoFrameskip-v4" --test-num 100` | 3~4h (100 epoch) | -| MsPacmanNoFrameskip-v4 | 2107 | ![](results/dqn/MsPacman_rew.png) | `python3 atari_dqn.py --task "MsPacmanNoFrameskip-v4" --test-num 100` | 3~4h (100 epoch) | -| SeaquestNoFrameskip-v4 | 2088 | ![](results/dqn/Seaquest_rew.png) | `python3 atari_dqn.py --task "SeaquestNoFrameskip-v4" --test-num 100` | 3~4h (100 epoch) | -| SpaceInvadersNoFrameskip-v4 | 812.2 | ![](results/dqn/SpaceInvader_rew.png) | `python3 atari_dqn.py --task "SpaceInvadersNoFrameskip-v4" --test-num 100` | 3~4h (100 epoch) | +| PongNoFrameskip-v4 | 20 | ![](results/dqn/Pong_rew.png) | `python3 atari_dqn.py --task "PongNoFrameskip-v4" --batch_size 64` | ~30 min (~15 epoch) | +| BreakoutNoFrameskip-v4 | 316 | ![](results/dqn/Breakout_rew.png) | `python3 atari_dqn.py --task "BreakoutNoFrameskip-v4" --num_test_envs 100` | 3~4h (100 epoch) | +| EnduroNoFrameskip-v4 | 670 | ![](results/dqn/Enduro_rew.png) | `python3 atari_dqn.py --task "EnduroNoFrameskip-v4 " --num_test_envs 100` | 3~4h (100 epoch) | +| QbertNoFrameskip-v4 | 7307 | ![](results/dqn/Qbert_rew.png) | `python3 atari_dqn.py --task "QbertNoFrameskip-v4" --num_test_envs 100` | 3~4h (100 epoch) | +| 
MsPacmanNoFrameskip-v4 | 2107 | ![](results/dqn/MsPacman_rew.png) | `python3 atari_dqn.py --task "MsPacmanNoFrameskip-v4" --num_test_envs 100` | 3~4h (100 epoch) | +| SeaquestNoFrameskip-v4 | 2088 | ![](results/dqn/Seaquest_rew.png) | `python3 atari_dqn.py --task "SeaquestNoFrameskip-v4" --num_test_envs 100` | 3~4h (100 epoch) | +| SpaceInvadersNoFrameskip-v4 | 812.2 | ![](results/dqn/SpaceInvader_rew.png) | `python3 atari_dqn.py --task "SpaceInvadersNoFrameskip-v4" --num_test_envs 100` | 3~4h (100 epoch) | Note: The `eps_train_final` and `eps_test` in the original DQN paper is 0.1 and 0.01, but [some works](https://github.com/google/dopamine/tree/master/baselines) found that smaller eps helps improve the performance. Also, a large batchsize (say 64 instead of 32) will help faster convergence but will slow down the training speed. @@ -42,7 +42,7 @@ One epoch here is equal to 100,000 env step, 100 epochs stand for 10M. | task | best reward | reward curve | parameters | | --------------------------- | ----------- | ------------------------------------- | ------------------------------------------------------------ | -| PongNoFrameskip-v4 | 20 | ![](results/c51/Pong_rew.png) | `python3 atari_c51.py --task "PongNoFrameskip-v4" --batch-size 64` | +| PongNoFrameskip-v4 | 20 | ![](results/c51/Pong_rew.png) | `python3 atari_c51.py --task "PongNoFrameskip-v4" --batch_size 64` | | BreakoutNoFrameskip-v4 | 536.6 | ![](results/c51/Breakout_rew.png) | `python3 atari_c51.py --task "BreakoutNoFrameskip-v4" --n-step 1` | | EnduroNoFrameskip-v4 | 1032 | ![](results/c51/Enduro_rew.png) | `python3 atari_c51.py --task "EnduroNoFrameskip-v4 " ` | | QbertNoFrameskip-v4 | 16245 | ![](results/c51/Qbert_rew.png) | `python3 atari_c51.py --task "QbertNoFrameskip-v4"` | @@ -58,7 +58,7 @@ One epoch here is equal to 100,000 env step, 100 epochs stand for 10M. 
| task | best reward | reward curve | parameters | | --------------------------- | ----------- | ------------------------------------- | ------------------------------------------------------------ | -| PongNoFrameskip-v4 | 20 | ![](results/qrdqn/Pong_rew.png) | `python3 atari_qrdqn.py --task "PongNoFrameskip-v4" --batch-size 64` | +| PongNoFrameskip-v4 | 20 | ![](results/qrdqn/Pong_rew.png) | `python3 atari_qrdqn.py --task "PongNoFrameskip-v4" --batch_size 64` | | BreakoutNoFrameskip-v4 | 409.2 | ![](results/qrdqn/Breakout_rew.png) | `python3 atari_qrdqn.py --task "BreakoutNoFrameskip-v4" --n-step 1` | | EnduroNoFrameskip-v4 | 1055.9 | ![](results/qrdqn/Enduro_rew.png) | `python3 atari_qrdqn.py --task "EnduroNoFrameskip-v4"` | | QbertNoFrameskip-v4 | 14990 | ![](results/qrdqn/Qbert_rew.png) | `python3 atari_qrdqn.py --task "QbertNoFrameskip-v4"` | @@ -72,7 +72,7 @@ One epoch here is equal to 100,000 env step, 100 epochs stand for 10M. | task | best reward | reward curve | parameters | | --------------------------- | ----------- | ------------------------------------- | ------------------------------------------------------------ | -| PongNoFrameskip-v4 | 20.3 | ![](results/iqn/Pong_rew.png) | `python3 atari_iqn.py --task "PongNoFrameskip-v4" --batch-size 64` | +| PongNoFrameskip-v4 | 20.3 | ![](results/iqn/Pong_rew.png) | `python3 atari_iqn.py --task "PongNoFrameskip-v4" --batch_size 64` | | BreakoutNoFrameskip-v4 | 496.7 | ![](results/iqn/Breakout_rew.png) | `python3 atari_iqn.py --task "BreakoutNoFrameskip-v4" --n-step 1` | | EnduroNoFrameskip-v4 | 1545 | ![](results/iqn/Enduro_rew.png) | `python3 atari_iqn.py --task "EnduroNoFrameskip-v4"` | | QbertNoFrameskip-v4 | 15342.5 | ![](results/iqn/Qbert_rew.png) | `python3 atari_iqn.py --task "QbertNoFrameskip-v4"` | @@ -86,7 +86,7 @@ One epoch here is equal to 100,000 env step, 100 epochs stand for 10M. 
| task | best reward | reward curve | parameters | | --------------------------- | ----------- | ------------------------------------- | ------------------------------------------------------------ | -| PongNoFrameskip-v4 | 20.7 | ![](results/fqf/Pong_rew.png) | `python3 atari_fqf.py --task "PongNoFrameskip-v4" --batch-size 64` | +| PongNoFrameskip-v4 | 20.7 | ![](results/fqf/Pong_rew.png) | `python3 atari_fqf.py --task "PongNoFrameskip-v4" --batch_size 64` | | BreakoutNoFrameskip-v4 | 517.3 | ![](results/fqf/Breakout_rew.png) | `python3 atari_fqf.py --task "BreakoutNoFrameskip-v4" --n-step 1` | | EnduroNoFrameskip-v4 | 2240.5 | ![](results/fqf/Enduro_rew.png) | `python3 atari_fqf.py --task "EnduroNoFrameskip-v4"` | | QbertNoFrameskip-v4 | 16172.5 | ![](results/fqf/Qbert_rew.png) | `python3 atari_fqf.py --task "QbertNoFrameskip-v4"` | @@ -100,7 +100,7 @@ One epoch here is equal to 100,000 env step, 100 epochs stand for 10M. | task | best reward | reward curve | parameters | | --------------------------- | ----------- | ------------------------------------- | ------------------------------------------------------------ | -| PongNoFrameskip-v4 | 21 | ![](results/rainbow/Pong_rew.png) | `python3 atari_rainbow.py --task "PongNoFrameskip-v4" --batch-size 64` | +| PongNoFrameskip-v4 | 21 | ![](results/rainbow/Pong_rew.png) | `python3 atari_rainbow.py --task "PongNoFrameskip-v4" --batch_size 64` | | BreakoutNoFrameskip-v4 | 684.6 | ![](results/rainbow/Breakout_rew.png) | `python3 atari_rainbow.py --task "BreakoutNoFrameskip-v4" --n-step 1` | | EnduroNoFrameskip-v4 | 1625.9 | ![](results/rainbow/Enduro_rew.png) | `python3 atari_rainbow.py --task "EnduroNoFrameskip-v4"` | | QbertNoFrameskip-v4 | 16192.5 | ![](results/rainbow/Qbert_rew.png) | `python3 atari_rainbow.py --task "QbertNoFrameskip-v4"` | diff --git a/examples/atari/atari_c51.py b/examples/atari/atari_c51.py index b24ca107b..537d5bdb8 100644 --- a/examples/atari/atari_c51.py +++ b/examples/atari/atari_c51.py @@ 
-22,7 +22,7 @@ def get_args() -> argparse.Namespace: parser = argparse.ArgumentParser() parser.add_argument("--task", type=str, default="PongNoFrameskip-v4") parser.add_argument("--seed", type=int, default=0) - parser.add_argument("--scale-obs", type=int, default=0) + parser.add_argument("--scale_obs", type=int, default=0) parser.add_argument("--eps-test", type=float, default=0.005) parser.add_argument("--eps-train", type=float, default=1.0) parser.add_argument("--eps-train-final", type=float, default=0.05) @@ -35,12 +35,12 @@ def get_args() -> argparse.Namespace: parser.add_argument("--n-step", type=int, default=3) parser.add_argument("--target-update-freq", type=int, default=500) parser.add_argument("--epoch", type=int, default=100) - parser.add_argument("--step-per-epoch", type=int, default=100000) - parser.add_argument("--step-per-collect", type=int, default=10) + parser.add_argument("--epoch_num_steps", type=int, default=100000) + parser.add_argument("--collection_step_num_env_steps", type=int, default=10) parser.add_argument("--update-per-step", type=float, default=0.1) - parser.add_argument("--batch-size", type=int, default=32) - parser.add_argument("--training-num", type=int, default=10) - parser.add_argument("--test-num", type=int, default=10) + parser.add_argument("--batch_size", type=int, default=32) + parser.add_argument("--num_train_envs", type=int, default=10) + parser.add_argument("--num_test_envs", type=int, default=10) parser.add_argument("--logdir", type=str, default="log") parser.add_argument("--render", type=float, default=0.0) parser.add_argument( @@ -72,7 +72,7 @@ def main(args: argparse.Namespace = get_args()) -> None: env, train_envs, test_envs = make_atari_env( args.task, args.seed, - args.training_num, + args.num_train_envs, args.test_num, scale=args.scale_obs, frame_stack=args.frames_stack, @@ -200,15 +200,15 @@ def watch() -> None: # test train_collector and start filling replay buffer train_collector.reset() - 
train_collector.collect(n_step=args.batch_size * args.training_num) + train_collector.collect(n_step=args.batch_size * args.num_train_envs) # trainer result = algorithm.run_training( OffPolicyTrainerParams( train_collector=train_collector, test_collector=test_collector, max_epochs=args.epoch, - epoch_num_steps=args.step_per_epoch, - collection_step_num_env_steps=args.step_per_collect, + epoch_num_steps=args.epoch_num_steps, + collection_step_num_env_steps=args.collection_step_num_env_steps, test_step_num_episodes=args.test_num, batch_size=args.batch_size, train_fn=train_fn, diff --git a/examples/atari/atari_dqn.py b/examples/atari/atari_dqn.py index df046018b..e23d238f8 100644 --- a/examples/atari/atari_dqn.py +++ b/examples/atari/atari_dqn.py @@ -24,7 +24,7 @@ def get_args() -> argparse.Namespace: parser = argparse.ArgumentParser() parser.add_argument("--task", type=str, default="PongNoFrameskip-v4") parser.add_argument("--seed", type=int, default=0) - parser.add_argument("--scale-obs", type=int, default=0) + parser.add_argument("--scale_obs", type=int, default=0) parser.add_argument("--eps-test", type=float, default=0.005) parser.add_argument("--eps-train", type=float, default=1.0) parser.add_argument("--eps-train-final", type=float, default=0.05) @@ -34,12 +34,12 @@ def get_args() -> argparse.Namespace: parser.add_argument("--n-step", type=int, default=3) parser.add_argument("--target-update-freq", type=int, default=500) parser.add_argument("--epoch", type=int, default=100) - parser.add_argument("--step-per-epoch", type=int, default=100000) - parser.add_argument("--step-per-collect", type=int, default=10) + parser.add_argument("--epoch_num_steps", type=int, default=100000) + parser.add_argument("--collection_step_num_env_steps", type=int, default=10) parser.add_argument("--update-per-step", type=float, default=0.1) - parser.add_argument("--batch-size", type=int, default=32) - parser.add_argument("--training-num", type=int, default=10) - 
parser.add_argument("--test-num", type=int, default=10) + parser.add_argument("--batch_size", type=int, default=32) + parser.add_argument("--num_train_envs", type=int, default=10) + parser.add_argument("--num_test_envs", type=int, default=10) parser.add_argument("--logdir", type=str, default="log") parser.add_argument("--render", type=float, default=0.0) parser.add_argument( @@ -89,7 +89,7 @@ def main(args: argparse.Namespace = get_args()) -> None: env, train_envs, test_envs = make_atari_env( args.task, args.seed, - args.training_num, + args.num_train_envs, args.test_num, scale=args.scale_obs, frame_stack=args.frames_stack, @@ -242,7 +242,7 @@ def watch() -> None: # test train_collector and start filling replay buffer train_collector.reset() - train_collector.collect(n_step=args.batch_size * args.training_num) + train_collector.collect(n_step=args.batch_size * args.num_train_envs) # train result = algorithm.run_training( @@ -250,8 +250,8 @@ def watch() -> None: train_collector=train_collector, test_collector=test_collector, max_epochs=args.epoch, - epoch_num_steps=args.step_per_epoch, - collection_step_num_env_steps=args.step_per_collect, + epoch_num_steps=args.epoch_num_steps, + collection_step_num_env_steps=args.collection_step_num_env_steps, test_step_num_episodes=args.test_num, batch_size=args.batch_size, train_fn=train_fn, diff --git a/examples/atari/atari_dqn_hl.py b/examples/atari/atari_dqn_hl.py index 25dc84c72..1bdcc15bd 100644 --- a/examples/atari/atari_dqn_hl.py +++ b/examples/atari/atari_dqn_hl.py @@ -38,11 +38,11 @@ def main( n_step: int = 3, target_update_freq: int = 500, epoch: int = 100, - step_per_epoch: int = 100000, - step_per_collect: int = 10, + epoch_num_steps: int = 100000, + collection_step_num_env_steps: int = 10, update_per_step: float = 0.1, batch_size: int = 32, - training_num: int = 10, + num_train_envs: int = 10, test_num: int = 10, frames_stack: int = 4, icm_lr_scale: float = 0.0, @@ -53,12 +53,12 @@ def main( training_config = 
OffPolicyTrainingConfig( max_epochs=epoch, - epoch_num_steps=step_per_epoch, + epoch_num_steps=epoch_num_steps, batch_size=batch_size, - num_train_envs=training_num, + num_train_envs=num_train_envs, num_test_envs=test_num, buffer_size=buffer_size, - step_per_collect=step_per_collect, + collection_step_num_env_steps=collection_step_num_env_steps, update_step_num_gradient_steps_per_sample=update_per_step, replay_buffer_stack_num=frames_stack, replay_buffer_ignore_obs_next=True, diff --git a/examples/atari/atari_fqf.py b/examples/atari/atari_fqf.py index 49fe491b4..63f098b34 100644 --- a/examples/atari/atari_fqf.py +++ b/examples/atari/atari_fqf.py @@ -23,7 +23,7 @@ def get_args() -> argparse.Namespace: parser = argparse.ArgumentParser() parser.add_argument("--task", type=str, default="PongNoFrameskip-v4") parser.add_argument("--seed", type=int, default=3128) - parser.add_argument("--scale-obs", type=int, default=0) + parser.add_argument("--scale_obs", type=int, default=0) parser.add_argument("--eps-test", type=float, default=0.005) parser.add_argument("--eps-train", type=float, default=1.0) parser.add_argument("--eps-train-final", type=float, default=0.05) @@ -38,12 +38,12 @@ def get_args() -> argparse.Namespace: parser.add_argument("--n-step", type=int, default=3) parser.add_argument("--target-update-freq", type=int, default=500) parser.add_argument("--epoch", type=int, default=100) - parser.add_argument("--step-per-epoch", type=int, default=100000) - parser.add_argument("--step-per-collect", type=int, default=10) + parser.add_argument("--epoch_num_steps", type=int, default=100000) + parser.add_argument("--collection_step_num_env_steps", type=int, default=10) parser.add_argument("--update-per-step", type=float, default=0.1) - parser.add_argument("--batch-size", type=int, default=32) - parser.add_argument("--training-num", type=int, default=10) - parser.add_argument("--test-num", type=int, default=10) + parser.add_argument("--batch_size", type=int, default=32) + 
parser.add_argument("--num_train_envs", type=int, default=10) + parser.add_argument("--num_test_envs", type=int, default=10) parser.add_argument("--logdir", type=str, default="log") parser.add_argument("--render", type=float, default=0.0) parser.add_argument( @@ -75,7 +75,7 @@ def main(args: argparse.Namespace = get_args()) -> None: env, train_envs, test_envs = make_atari_env( args.task, args.seed, - args.training_num, + args.num_train_envs, args.test_num, scale=args.scale_obs, frame_stack=args.frames_stack, @@ -216,7 +216,7 @@ def watch() -> None: # test train_collector and start filling replay buffer train_collector.reset() - train_collector.collect(n_step=args.batch_size * args.training_num) + train_collector.collect(n_step=args.batch_size * args.num_train_envs) # train result = algorithm.run_training( @@ -224,8 +224,8 @@ def watch() -> None: train_collector=train_collector, test_collector=test_collector, max_epochs=args.epoch, - epoch_num_steps=args.step_per_epoch, - collection_step_num_env_steps=args.step_per_collect, + epoch_num_steps=args.epoch_num_steps, + collection_step_num_env_steps=args.collection_step_num_env_steps, test_step_num_episodes=args.test_num, batch_size=args.batch_size, train_fn=train_fn, diff --git a/examples/atari/atari_iqn.py b/examples/atari/atari_iqn.py index 591484f8f..12ab60420 100644 --- a/examples/atari/atari_iqn.py +++ b/examples/atari/atari_iqn.py @@ -23,7 +23,7 @@ def get_args() -> argparse.Namespace: parser = argparse.ArgumentParser() parser.add_argument("--task", type=str, default="PongNoFrameskip-v4") parser.add_argument("--seed", type=int, default=1234) - parser.add_argument("--scale-obs", type=int, default=0) + parser.add_argument("--scale_obs", type=int, default=0) parser.add_argument("--eps-test", type=float, default=0.005) parser.add_argument("--eps-train", type=float, default=1.0) parser.add_argument("--eps-train-final", type=float, default=0.05) @@ -38,12 +38,12 @@ def get_args() -> argparse.Namespace: 
parser.add_argument("--n-step", type=int, default=3) parser.add_argument("--target-update-freq", type=int, default=500) parser.add_argument("--epoch", type=int, default=100) - parser.add_argument("--step-per-epoch", type=int, default=100000) - parser.add_argument("--step-per-collect", type=int, default=10) + parser.add_argument("--epoch_num_steps", type=int, default=100000) + parser.add_argument("--collection_step_num_env_steps", type=int, default=10) parser.add_argument("--update-per-step", type=float, default=0.1) - parser.add_argument("--batch-size", type=int, default=32) - parser.add_argument("--training-num", type=int, default=10) - parser.add_argument("--test-num", type=int, default=10) + parser.add_argument("--batch_size", type=int, default=32) + parser.add_argument("--num_train_envs", type=int, default=10) + parser.add_argument("--num_test_envs", type=int, default=10) parser.add_argument("--logdir", type=str, default="log") parser.add_argument("--render", type=float, default=0.0) parser.add_argument( @@ -75,7 +75,7 @@ def main(args: argparse.Namespace = get_args()) -> None: env, train_envs, test_envs = make_atari_env( args.task, args.seed, - args.training_num, + args.num_train_envs, args.test_num, scale=args.scale_obs, frame_stack=args.frames_stack, @@ -210,7 +210,7 @@ def watch() -> None: # test train_collector and start filling replay buffer train_collector.reset() - train_collector.collect(n_step=args.batch_size * args.training_num) + train_collector.collect(n_step=args.batch_size * args.num_train_envs) # train result = algorithm.run_training( @@ -218,8 +218,8 @@ def watch() -> None: train_collector=train_collector, test_collector=test_collector, max_epochs=args.epoch, - epoch_num_steps=args.step_per_epoch, - collection_step_num_env_steps=args.step_per_collect, + epoch_num_steps=args.epoch_num_steps, + collection_step_num_env_steps=args.collection_step_num_env_steps, test_step_num_episodes=args.test_num, batch_size=args.batch_size, train_fn=train_fn, 
diff --git a/examples/atari/atari_iqn_hl.py b/examples/atari/atari_iqn_hl.py index fc1a23156..86de5cbce 100644 --- a/examples/atari/atari_iqn_hl.py +++ b/examples/atari/atari_iqn_hl.py @@ -39,11 +39,11 @@ def main( n_step: int = 3, target_update_freq: int = 500, epoch: int = 100, - step_per_epoch: int = 100000, - step_per_collect: int = 10, + epoch_num_steps: int = 100000, + collection_step_num_env_steps: int = 10, update_per_step: float = 0.1, batch_size: int = 32, - training_num: int = 10, + num_train_envs: int = 10, test_num: int = 10, frames_stack: int = 4, ) -> None: @@ -51,12 +51,12 @@ def main( training_config = OffPolicyTrainingConfig( max_epochs=epoch, - epoch_num_steps=step_per_epoch, + epoch_num_steps=epoch_num_steps, batch_size=batch_size, - num_train_envs=training_num, + num_train_envs=num_train_envs, num_test_envs=test_num, buffer_size=buffer_size, - step_per_collect=step_per_collect, + collection_step_num_env_steps=collection_step_num_env_steps, update_step_num_gradient_steps_per_sample=update_per_step, replay_buffer_stack_num=frames_stack, replay_buffer_ignore_obs_next=True, diff --git a/examples/atari/atari_ppo.py b/examples/atari/atari_ppo.py index 9fd340d99..b34e61e2d 100644 --- a/examples/atari/atari_ppo.py +++ b/examples/atari/atari_ppo.py @@ -34,19 +34,19 @@ def get_args() -> argparse.Namespace: parser = argparse.ArgumentParser() parser.add_argument("--task", type=str, default="PongNoFrameskip-v4") parser.add_argument("--seed", type=int, default=4213) - parser.add_argument("--scale-obs", type=int, default=1) + parser.add_argument("--scale_obs", type=int, default=1) parser.add_argument("--buffer-size", type=int, default=100000) parser.add_argument("--lr", type=float, default=2.5e-4) parser.add_argument("--gamma", type=float, default=0.99) parser.add_argument("--epoch", type=int, default=100) - parser.add_argument("--step-per-epoch", type=int, default=100000) - parser.add_argument("--step-per-collect", type=int, default=1000) - 
parser.add_argument("--repeat-per-collect", type=int, default=4) - parser.add_argument("--batch-size", type=int, default=256) + parser.add_argument("--epoch_num_steps", type=int, default=100000) + parser.add_argument("--collection_step_num_env_steps", type=int, default=1000) + parser.add_argument("--update_step_num_repetitions", type=int, default=4) + parser.add_argument("--batch_size", type=int, default=256) parser.add_argument("--hidden-size", type=int, default=512) - parser.add_argument("--training-num", type=int, default=10) - parser.add_argument("--test-num", type=int, default=10) - parser.add_argument("--rew-norm", type=int, default=False) + parser.add_argument("--num_train_envs", type=int, default=10) + parser.add_argument("--num_test_envs", type=int, default=10) + parser.add_argument("--return_scaling", type=int, default=False) parser.add_argument("--vf-coef", type=float, default=0.25) parser.add_argument("--ent-coef", type=float, default=0.01) parser.add_argument("--gae-lambda", type=float, default=0.95) @@ -55,7 +55,7 @@ def get_args() -> argparse.Namespace: parser.add_argument("--eps-clip", type=float, default=0.1) parser.add_argument("--dual-clip", type=float, default=None) parser.add_argument("--value-clip", type=int, default=1) - parser.add_argument("--norm-adv", type=int, default=1) + parser.add_argument("--advantage_normalization", type=int, default=1) parser.add_argument("--recompute-adv", type=int, default=0) parser.add_argument("--logdir", type=str, default="log") parser.add_argument("--render", type=float, default=0.0) @@ -106,7 +106,7 @@ def main(args: argparse.Namespace = get_args()) -> None: env, train_envs, test_envs = make_atari_env( args.task, args.seed, - args.training_num, + args.num_train_envs, args.test_num, scale=0, frame_stack=args.frames_stack, @@ -141,8 +141,8 @@ def main(args: argparse.Namespace = get_args()) -> None: optim.with_lr_scheduler_factory( LRSchedulerFactoryLinear( max_epochs=args.epoch, - 
epoch_num_steps=args.step_per_epoch, - collection_step_num_env_steps=args.step_per_collect, + epoch_num_steps=args.epoch_num_steps, + collection_step_num_env_steps=args.collection_step_num_env_steps, ) ) @@ -160,11 +160,11 @@ def main(args: argparse.Namespace = get_args()) -> None: max_grad_norm=args.max_grad_norm, vf_coef=args.vf_coef, ent_coef=args.ent_coef, - return_scaling=args.rew_norm, + return_scaling=args.return_scaling, eps_clip=args.eps_clip, value_clip=args.value_clip, dual_clip=args.dual_clip, - advantage_normalization=args.norm_adv, + advantage_normalization=args.advantage_normalization, recompute_advantage=args.recompute_adv, ).to(args.device) if args.icm_lr_scale > 0: @@ -272,7 +272,7 @@ def watch() -> None: # test train_collector and start filling replay buffer train_collector.reset() - train_collector.collect(n_step=args.batch_size * args.training_num) + train_collector.collect(n_step=args.batch_size * args.num_train_envs) # train result = algorithm.run_training( @@ -280,11 +280,11 @@ def watch() -> None: train_collector=train_collector, test_collector=test_collector, max_epochs=args.epoch, - epoch_num_steps=args.step_per_epoch, - update_step_num_repetitions=args.repeat_per_collect, + epoch_num_steps=args.epoch_num_steps, + update_step_num_repetitions=args.update_step_num_repetitions, test_step_num_episodes=args.test_num, batch_size=args.batch_size, - collection_step_num_env_steps=args.step_per_collect, + collection_step_num_env_steps=args.collection_step_num_env_steps, stop_fn=stop_fn, save_best_fn=save_best_fn, logger=logger, diff --git a/examples/atari/atari_ppo_hl.py b/examples/atari/atari_ppo_hl.py index 2f1fdcb7a..a071acf49 100644 --- a/examples/atari/atari_ppo_hl.py +++ b/examples/atari/atari_ppo_hl.py @@ -31,14 +31,14 @@ def main( lr: float = 2.5e-4, gamma: float = 0.99, epoch: int = 100, - step_per_epoch: int = 100000, - step_per_collect: int = 1000, - repeat_per_collect: int = 4, + epoch_num_steps: int = 100000, + 
collection_step_num_env_steps: int = 1000, + update_step_num_repetitions: int = 4, batch_size: int = 256, hidden_sizes: Sequence[int] = (512,), - training_num: int = 10, + num_train_envs: int = 10, test_num: int = 10, - rew_norm: bool = False, + return_scaling: bool = False, vf_coef: float = 0.25, ent_coef: float = 0.01, gae_lambda: float = 0.95, @@ -47,7 +47,7 @@ def main( eps_clip: float = 0.1, dual_clip: float | None = None, value_clip: bool = True, - norm_adv: bool = True, + advantage_normalization: bool = True, recompute_adv: bool = False, frames_stack: int = 4, save_buffer_name: str | None = None, # TODO add support in high-level API? @@ -59,13 +59,13 @@ def main( training_config = OnPolicyTrainingConfig( max_epochs=epoch, - epoch_num_steps=step_per_epoch, + epoch_num_steps=epoch_num_steps, batch_size=batch_size, - num_train_envs=training_num, + num_train_envs=num_train_envs, num_test_envs=test_num, buffer_size=buffer_size, - step_per_collect=step_per_collect, - update_step_num_repetitions=repeat_per_collect, + collection_step_num_env_steps=collection_step_num_env_steps, + update_step_num_repetitions=update_step_num_repetitions, replay_buffer_stack_num=frames_stack, replay_buffer_ignore_obs_next=True, replay_buffer_save_only_last_obs=True, @@ -85,12 +85,12 @@ def main( PPOParams( gamma=gamma, gae_lambda=gae_lambda, - return_scaling=rew_norm, + return_scaling=return_scaling, ent_coef=ent_coef, vf_coef=vf_coef, max_grad_norm=max_grad_norm, value_clip=value_clip, - advantage_normalization=norm_adv, + advantage_normalization=advantage_normalization, eps_clip=eps_clip, dual_clip=dual_clip, recompute_advantage=recompute_adv, diff --git a/examples/atari/atari_qrdqn.py b/examples/atari/atari_qrdqn.py index fc1e1a125..c56c81d26 100644 --- a/examples/atari/atari_qrdqn.py +++ b/examples/atari/atari_qrdqn.py @@ -22,7 +22,7 @@ def get_args() -> argparse.Namespace: parser = argparse.ArgumentParser() parser.add_argument("--task", type=str, default="PongNoFrameskip-v4") 
parser.add_argument("--seed", type=int, default=0) - parser.add_argument("--scale-obs", type=int, default=0) + parser.add_argument("--scale_obs", type=int, default=0) parser.add_argument("--eps-test", type=float, default=0.005) parser.add_argument("--eps-train", type=float, default=1.0) parser.add_argument("--eps-train-final", type=float, default=0.05) @@ -33,12 +33,12 @@ def get_args() -> argparse.Namespace: parser.add_argument("--n-step", type=int, default=3) parser.add_argument("--target-update-freq", type=int, default=500) parser.add_argument("--epoch", type=int, default=100) - parser.add_argument("--step-per-epoch", type=int, default=100000) - parser.add_argument("--step-per-collect", type=int, default=10) + parser.add_argument("--epoch_num_steps", type=int, default=100000) + parser.add_argument("--collection_step_num_env_steps", type=int, default=10) parser.add_argument("--update-per-step", type=float, default=0.1) - parser.add_argument("--batch-size", type=int, default=32) - parser.add_argument("--training-num", type=int, default=10) - parser.add_argument("--test-num", type=int, default=10) + parser.add_argument("--batch_size", type=int, default=32) + parser.add_argument("--num_train_envs", type=int, default=10) + parser.add_argument("--num_test_envs", type=int, default=10) parser.add_argument("--logdir", type=str, default="log") parser.add_argument("--render", type=float, default=0.0) parser.add_argument( @@ -70,7 +70,7 @@ def main(args: argparse.Namespace = get_args()) -> None: env, train_envs, test_envs = make_atari_env( args.task, args.seed, - args.training_num, + args.num_train_envs, args.test_num, scale=args.scale_obs, frame_stack=args.frames_stack, @@ -204,7 +204,7 @@ def watch() -> None: # test train_collector and start filling replay buffer train_collector.reset() - train_collector.collect(n_step=args.batch_size * args.training_num) + train_collector.collect(n_step=args.batch_size * args.num_train_envs) # train result = algorithm.run_training( @@ 
-212,8 +212,8 @@ def watch() -> None: train_collector=train_collector, test_collector=test_collector, max_epochs=args.epoch, - epoch_num_steps=args.step_per_epoch, - collection_step_num_env_steps=args.step_per_collect, + epoch_num_steps=args.epoch_num_steps, + collection_step_num_env_steps=args.collection_step_num_env_steps, test_step_num_episodes=args.test_num, batch_size=args.batch_size, train_fn=train_fn, diff --git a/examples/atari/atari_rainbow.py b/examples/atari/atari_rainbow.py index c86555ecf..af69814ee 100644 --- a/examples/atari/atari_rainbow.py +++ b/examples/atari/atari_rainbow.py @@ -27,7 +27,7 @@ def get_args() -> argparse.Namespace: parser = argparse.ArgumentParser() parser.add_argument("--task", type=str, default="PongNoFrameskip-v4") parser.add_argument("--seed", type=int, default=0) - parser.add_argument("--scale-obs", type=int, default=0) + parser.add_argument("--scale_obs", type=int, default=0) parser.add_argument("--eps-test", type=float, default=0.005) parser.add_argument("--eps-train", type=float, default=1.0) parser.add_argument("--eps-train-final", type=float, default=0.05) @@ -49,12 +49,12 @@ def get_args() -> argparse.Namespace: parser.add_argument("--n-step", type=int, default=3) parser.add_argument("--target-update-freq", type=int, default=500) parser.add_argument("--epoch", type=int, default=100) - parser.add_argument("--step-per-epoch", type=int, default=100000) - parser.add_argument("--step-per-collect", type=int, default=10) + parser.add_argument("--epoch_num_steps", type=int, default=100000) + parser.add_argument("--collection_step_num_env_steps", type=int, default=10) parser.add_argument("--update-per-step", type=float, default=0.1) - parser.add_argument("--batch-size", type=int, default=32) - parser.add_argument("--training-num", type=int, default=10) - parser.add_argument("--test-num", type=int, default=10) + parser.add_argument("--batch_size", type=int, default=32) + parser.add_argument("--num_train_envs", type=int, 
default=10) + parser.add_argument("--num_test_envs", type=int, default=10) parser.add_argument("--logdir", type=str, default="log") parser.add_argument("--render", type=float, default=0.0) parser.add_argument( @@ -86,7 +86,7 @@ def test_rainbow(args: argparse.Namespace = get_args()) -> None: env, train_envs, test_envs = make_atari_env( args.task, args.seed, - args.training_num, + args.num_train_envs, args.test_num, scale=args.scale_obs, frame_stack=args.frames_stack, @@ -247,7 +247,7 @@ def watch() -> None: # test train_collector and start filling replay buffer train_collector.reset() - train_collector.collect(n_step=args.batch_size * args.training_num) + train_collector.collect(n_step=args.batch_size * args.num_train_envs) # train result = algorithm.run_training( @@ -255,8 +255,8 @@ def watch() -> None: train_collector=train_collector, test_collector=test_collector, max_epochs=args.epoch, - epoch_num_steps=args.step_per_epoch, - collection_step_num_env_steps=args.step_per_collect, + epoch_num_steps=args.epoch_num_steps, + collection_step_num_env_steps=args.collection_step_num_env_steps, test_step_num_episodes=args.test_num, batch_size=args.batch_size, train_fn=train_fn, diff --git a/examples/atari/atari_sac.py b/examples/atari/atari_sac.py index 7e1774424..752f2888c 100644 --- a/examples/atari/atari_sac.py +++ b/examples/atari/atari_sac.py @@ -28,7 +28,7 @@ def get_args() -> argparse.Namespace: parser = argparse.ArgumentParser() parser.add_argument("--task", type=str, default="PongNoFrameskip-v4") parser.add_argument("--seed", type=int, default=4213) - parser.add_argument("--scale-obs", type=int, default=0) + parser.add_argument("--scale_obs", type=int, default=0) parser.add_argument("--buffer-size", type=int, default=100000) parser.add_argument("--actor-lr", type=float, default=1e-5) parser.add_argument("--critic-lr", type=float, default=1e-5) @@ -39,14 +39,14 @@ def get_args() -> argparse.Namespace: parser.add_argument("--auto-alpha", action="store_true", 
default=False) parser.add_argument("--alpha-lr", type=float, default=3e-4) parser.add_argument("--epoch", type=int, default=100) - parser.add_argument("--step-per-epoch", type=int, default=100000) - parser.add_argument("--step-per-collect", type=int, default=10) + parser.add_argument("--epoch_num_steps", type=int, default=100000) + parser.add_argument("--collection_step_num_env_steps", type=int, default=10) parser.add_argument("--update-per-step", type=float, default=0.1) - parser.add_argument("--batch-size", type=int, default=64) + parser.add_argument("--batch_size", type=int, default=64) parser.add_argument("--hidden-size", type=int, default=512) - parser.add_argument("--training-num", type=int, default=10) - parser.add_argument("--test-num", type=int, default=10) - parser.add_argument("--rew-norm", type=int, default=False) + parser.add_argument("--num_train_envs", type=int, default=10) + parser.add_argument("--num_test_envs", type=int, default=10) + parser.add_argument("--return_scaling", type=int, default=False) parser.add_argument("--logdir", type=str, default="log") parser.add_argument("--render", type=float, default=0.0) parser.add_argument( @@ -96,7 +96,7 @@ def test_discrete_sac(args: argparse.Namespace = get_args()) -> None: env, train_envs, test_envs = make_atari_env( args.task, args.seed, - args.training_num, + args.num_train_envs, args.test_num, scale=args.scale_obs, frame_stack=args.frames_stack, @@ -258,7 +258,7 @@ def watch() -> None: # test train_collector and start filling replay buffer train_collector.reset() - train_collector.collect(n_step=args.batch_size * args.training_num) + train_collector.collect(n_step=args.batch_size * args.num_train_envs) # train result = algorithm.run_training( @@ -266,8 +266,8 @@ def watch() -> None: train_collector=train_collector, test_collector=test_collector, max_epochs=args.epoch, - epoch_num_steps=args.step_per_epoch, - collection_step_num_env_steps=args.step_per_collect, + epoch_num_steps=args.epoch_num_steps, 
+ collection_step_num_env_steps=args.collection_step_num_env_steps, test_step_num_episodes=args.test_num, batch_size=args.batch_size, stop_fn=stop_fn, diff --git a/examples/atari/atari_sac_hl.py b/examples/atari/atari_sac_hl.py index ddbe98840..e8adbfb6b 100644 --- a/examples/atari/atari_sac_hl.py +++ b/examples/atari/atari_sac_hl.py @@ -37,12 +37,12 @@ def main( auto_alpha: bool = False, alpha_lr: float = 3e-4, epoch: int = 100, - step_per_epoch: int = 100000, - step_per_collect: int = 10, + epoch_num_steps: int = 100000, + collection_step_num_env_steps: int = 10, update_per_step: float = 0.1, batch_size: int = 64, hidden_sizes: Sequence[int] = (512,), - training_num: int = 10, + num_train_envs: int = 10, test_num: int = 10, frames_stack: int = 4, icm_lr_scale: float = 0.0, @@ -53,13 +53,13 @@ def main( training_config = OffPolicyTrainingConfig( max_epochs=epoch, - epoch_num_steps=step_per_epoch, + epoch_num_steps=epoch_num_steps, update_step_num_gradient_steps_per_sample=update_per_step, batch_size=batch_size, - num_train_envs=training_num, + num_train_envs=num_train_envs, num_test_envs=test_num, buffer_size=buffer_size, - step_per_collect=step_per_collect, + collection_step_num_env_steps=collection_step_num_env_steps, replay_buffer_stack_num=frames_stack, replay_buffer_ignore_obs_next=True, replay_buffer_save_only_last_obs=True, @@ -82,9 +82,11 @@ def main( critic2_lr=critic_lr, gamma=gamma, tau=tau, - alpha=AutoAlphaFactoryDefault(lr=alpha_lr, target_entropy_coefficient=0.98) - if auto_alpha - else alpha, + alpha=( + AutoAlphaFactoryDefault(lr=alpha_lr, target_entropy_coefficient=0.98) + if auto_alpha + else alpha + ), n_step_return_horizon=n_step, ), ) diff --git a/examples/box2d/acrobot_dualdqn.py b/examples/box2d/acrobot_dualdqn.py index 8c38ab3f8..17993c820 100644 --- a/examples/box2d/acrobot_dualdqn.py +++ b/examples/box2d/acrobot_dualdqn.py @@ -31,15 +31,15 @@ def get_args() -> argparse.Namespace: parser.add_argument("--n-step", type=int, default=3) 
parser.add_argument("--target-update-freq", type=int, default=320) parser.add_argument("--epoch", type=int, default=10) - parser.add_argument("--step-per-epoch", type=int, default=100000) - parser.add_argument("--step-per-collect", type=int, default=100) + parser.add_argument("--epoch_num_steps", type=int, default=100000) + parser.add_argument("--collection_step_num_env_steps", type=int, default=100) parser.add_argument("--update-per-step", type=float, default=0.01) - parser.add_argument("--batch-size", type=int, default=64) + parser.add_argument("--batch_size", type=int, default=64) parser.add_argument("--hidden-sizes", type=int, nargs="*", default=[128]) parser.add_argument("--dueling-q-hidden-sizes", type=int, nargs="*", default=[128, 128]) parser.add_argument("--dueling-v-hidden-sizes", type=int, nargs="*", default=[128, 128]) - parser.add_argument("--training-num", type=int, default=10) - parser.add_argument("--test-num", type=int, default=10) + parser.add_argument("--num_train_envs", type=int, default=10) + parser.add_argument("--num_test_envs", type=int, default=10) parser.add_argument("--logdir", type=str, default="log") parser.add_argument("--render", type=float, default=0.0) parser.add_argument( @@ -58,7 +58,7 @@ def test_dqn(args: argparse.Namespace = get_args()) -> None: args.action_shape = space_info.action_info.action_shape # train_envs = gym.make(args.task) # you can also use tianshou.env.SubprocVectorEnv - train_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.training_num)]) + train_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.num_train_envs)]) # test_envs = gym.make(args.task) test_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.test_num)]) # seed @@ -98,7 +98,7 @@ def test_dqn(args: argparse.Namespace = get_args()) -> None: ) test_collector = Collector[CollectStats](algorithm, test_envs, exploration_noise=True) train_collector.reset() - 
train_collector.collect(n_step=args.batch_size * args.training_num) + train_collector.collect(n_step=args.batch_size * args.num_train_envs) # log log_path = os.path.join(args.logdir, args.task, "dqn") writer = SummaryWriter(log_path) @@ -130,8 +130,8 @@ def train_fn(epoch: int, env_step: int) -> None: train_collector=train_collector, test_collector=test_collector, max_epochs=args.epoch, - epoch_num_steps=args.step_per_epoch, - collection_step_num_env_steps=args.step_per_collect, + epoch_num_steps=args.epoch_num_steps, + collection_step_num_env_steps=args.collection_step_num_env_steps, test_step_num_episodes=args.test_num, batch_size=args.batch_size, update_step_num_gradient_steps_per_sample=args.update_per_step, diff --git a/examples/box2d/bipedal_bdq.py b/examples/box2d/bipedal_bdq.py index a987feaa5..e75f142a7 100644 --- a/examples/box2d/bipedal_bdq.py +++ b/examples/box2d/bipedal_bdq.py @@ -38,12 +38,12 @@ def get_args() -> argparse.Namespace: parser.add_argument("--gamma", type=float, default=0.99) parser.add_argument("--target-update-freq", type=int, default=1000) parser.add_argument("--epoch", type=int, default=25) - parser.add_argument("--step-per-epoch", type=int, default=80000) - parser.add_argument("--step-per-collect", type=int, default=16) + parser.add_argument("--epoch_num_steps", type=int, default=80000) + parser.add_argument("--collection_step_num_env_steps", type=int, default=16) parser.add_argument("--update-per-step", type=float, default=0.0625) - parser.add_argument("--batch-size", type=int, default=512) - parser.add_argument("--training-num", type=int, default=20) - parser.add_argument("--test-num", type=int, default=10) + parser.add_argument("--batch_size", type=int, default=512) + parser.add_argument("--num_train_envs", type=int, default=20) + parser.add_argument("--num_test_envs", type=int, default=10) # other parser.add_argument("--logdir", type=str, default="log") parser.add_argument("--render", type=float, default=0.0) @@ -77,7 +77,7 @@ 
def run_bdq(args: argparse.Namespace = get_args()) -> None: train_envs = SubprocVectorEnv( [ lambda: ContinuousToDiscrete(gym.make(args.task), args.action_per_branch) - for _ in range(args.training_num) + for _ in range(args.num_train_envs) ], ) # test_envs = ContinuousToDiscrete(gym.make(args.task), args.action_per_branch) @@ -123,7 +123,7 @@ def run_bdq(args: argparse.Namespace = get_args()) -> None: ) test_collector = Collector[CollectStats](algorithm, test_envs, exploration_noise=False) train_collector.reset() - train_collector.collect(n_step=args.batch_size * args.training_num) + train_collector.collect(n_step=args.batch_size * args.num_train_envs) # log current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S") log_path = os.path.join(args.logdir, "bdq", args.task, current_time) @@ -148,8 +148,8 @@ def train_fn(epoch: int, env_step: int) -> None: # exp decay train_collector=train_collector, test_collector=test_collector, max_epochs=args.epoch, - epoch_num_steps=args.step_per_epoch, - collection_step_num_env_steps=args.step_per_collect, + epoch_num_steps=args.epoch_num_steps, + collection_step_num_env_steps=args.collection_step_num_env_steps, test_step_num_episodes=args.test_num, batch_size=args.batch_size, update_step_num_gradient_steps_per_sample=args.update_per_step, diff --git a/examples/box2d/bipedal_hardcore_sac.py b/examples/box2d/bipedal_hardcore_sac.py index c87618019..d98e8613b 100644 --- a/examples/box2d/bipedal_hardcore_sac.py +++ b/examples/box2d/bipedal_hardcore_sac.py @@ -35,13 +35,13 @@ def get_args() -> argparse.Namespace: parser.add_argument("--auto-alpha", type=int, default=1) parser.add_argument("--alpha-lr", type=float, default=3e-4) parser.add_argument("--epoch", type=int, default=100) - parser.add_argument("--step-per-epoch", type=int, default=100000) - parser.add_argument("--step-per-collect", type=int, default=10) + parser.add_argument("--epoch_num_steps", type=int, default=100000) + 
parser.add_argument("--collection_step_num_env_steps", type=int, default=10) parser.add_argument("--update-per-step", type=float, default=0.1) - parser.add_argument("--batch-size", type=int, default=128) + parser.add_argument("--batch_size", type=int, default=128) parser.add_argument("--hidden-sizes", type=int, nargs="*", default=[128, 128]) - parser.add_argument("--training-num", type=int, default=10) - parser.add_argument("--test-num", type=int, default=10) + parser.add_argument("--num_train_envs", type=int, default=10) + parser.add_argument("--num_test_envs", type=int, default=10) parser.add_argument("--logdir", type=str, default="log") parser.add_argument("--render", type=float, default=0.0) parser.add_argument("--n-step", type=int, default=4) @@ -93,7 +93,7 @@ def test_sac_bipedal(args: argparse.Namespace = get_args()) -> None: args.action_shape = space_info.action_info.action_shape args.max_action = space_info.action_info.max_action train_envs = SubprocVectorEnv( - [lambda: Wrapper(gym.make(args.task)) for _ in range(args.training_num)], + [lambda: Wrapper(gym.make(args.task)) for _ in range(args.num_train_envs)], ) # test_envs = gym.make(args.task) test_envs = SubprocVectorEnv( @@ -195,8 +195,8 @@ def stop_fn(mean_rewards: float) -> bool: train_collector=train_collector, test_collector=test_collector, max_epochs=args.epoch, - epoch_num_steps=args.step_per_epoch, - collection_step_num_env_steps=args.step_per_collect, + epoch_num_steps=args.epoch_num_steps, + collection_step_num_env_steps=args.collection_step_num_env_steps, test_step_num_episodes=args.test_num, batch_size=args.batch_size, update_step_num_gradient_steps_per_sample=args.update_per_step, diff --git a/examples/box2d/lunarlander_dqn.py b/examples/box2d/lunarlander_dqn.py index 556949fe0..d32105fc4 100644 --- a/examples/box2d/lunarlander_dqn.py +++ b/examples/box2d/lunarlander_dqn.py @@ -32,15 +32,15 @@ def get_args() -> argparse.Namespace: parser.add_argument("--n-step", type=int, default=4) 
parser.add_argument("--target-update-freq", type=int, default=500) parser.add_argument("--epoch", type=int, default=10) - parser.add_argument("--step-per-epoch", type=int, default=80000) - parser.add_argument("--step-per-collect", type=int, default=16) + parser.add_argument("--epoch_num_steps", type=int, default=80000) + parser.add_argument("--collection_step_num_env_steps", type=int, default=16) parser.add_argument("--update-per-step", type=float, default=0.0625) - parser.add_argument("--batch-size", type=int, default=128) + parser.add_argument("--batch_size", type=int, default=128) parser.add_argument("--hidden-sizes", type=int, nargs="*", default=[128, 128]) parser.add_argument("--dueling-q-hidden-sizes", type=int, nargs="*", default=[128, 128]) parser.add_argument("--dueling-v-hidden-sizes", type=int, nargs="*", default=[128, 128]) - parser.add_argument("--training-num", type=int, default=16) - parser.add_argument("--test-num", type=int, default=10) + parser.add_argument("--num_train_envs", type=int, default=16) + parser.add_argument("--num_test_envs", type=int, default=10) parser.add_argument("--logdir", type=str, default="log") parser.add_argument("--render", type=float, default=0.0) parser.add_argument( @@ -60,7 +60,7 @@ def test_dqn(args: argparse.Namespace = get_args()) -> None: args.max_action = space_info.action_info.max_action # train_envs = gym.make(args.task) # you can also use tianshou.env.SubprocVectorEnv - train_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.training_num)]) + train_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.num_train_envs)]) # test_envs = gym.make(args.task) test_envs = SubprocVectorEnv([lambda: gym.make(args.task) for _ in range(args.test_num)]) # seed @@ -100,7 +100,7 @@ def test_dqn(args: argparse.Namespace = get_args()) -> None: ) test_collector = Collector[CollectStats](algorithm, test_envs, exploration_noise=True) train_collector.reset() - 
train_collector.collect(n_step=args.batch_size * args.training_num) + train_collector.collect(n_step=args.batch_size * args.num_train_envs) # log log_path = os.path.join(args.logdir, args.task, "dqn") writer = SummaryWriter(log_path) @@ -127,8 +127,8 @@ def train_fn(epoch: int, env_step: int) -> None: # exp decay train_collector=train_collector, test_collector=test_collector, max_epochs=args.epoch, - epoch_num_steps=args.step_per_epoch, - collection_step_num_env_steps=args.step_per_collect, + epoch_num_steps=args.epoch_num_steps, + collection_step_num_env_steps=args.collection_step_num_env_steps, test_step_num_episodes=args.test_num, batch_size=args.batch_size, update_step_num_gradient_steps_per_sample=args.update_per_step, diff --git a/examples/box2d/mcc_sac.py b/examples/box2d/mcc_sac.py index 4e9e1e370..3fd57599c 100644 --- a/examples/box2d/mcc_sac.py +++ b/examples/box2d/mcc_sac.py @@ -35,13 +35,13 @@ def get_args() -> argparse.Namespace: parser.add_argument("--auto_alpha", type=int, default=1) parser.add_argument("--alpha", type=float, default=0.2) parser.add_argument("--epoch", type=int, default=20) - parser.add_argument("--step-per-epoch", type=int, default=12000) - parser.add_argument("--step-per-collect", type=int, default=5) + parser.add_argument("--epoch_num_steps", type=int, default=12000) + parser.add_argument("--collection_step_num_env_steps", type=int, default=5) parser.add_argument("--update-per-step", type=float, default=0.2) - parser.add_argument("--batch-size", type=int, default=128) + parser.add_argument("--batch_size", type=int, default=128) parser.add_argument("--hidden-sizes", type=int, nargs="*", default=[128, 128]) - parser.add_argument("--training-num", type=int, default=5) - parser.add_argument("--test-num", type=int, default=10) + parser.add_argument("--num_train_envs", type=int, default=5) + parser.add_argument("--num_test_envs", type=int, default=10) parser.add_argument("--logdir", type=str, default="log") 
parser.add_argument("--render", type=float, default=0.0) parser.add_argument( @@ -59,7 +59,7 @@ def test_sac(args: argparse.Namespace = get_args()) -> None: args.action_shape = space_info.action_info.action_shape args.max_action = space_info.action_info.max_action # train_envs = gym.make(args.task) - train_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.training_num)]) + train_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.num_train_envs)]) # test_envs = gym.make(args.task) test_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.test_num)]) # seed @@ -144,8 +144,8 @@ def stop_fn(mean_rewards: float) -> bool: train_collector=train_collector, test_collector=test_collector, max_epochs=args.epoch, - epoch_num_steps=args.step_per_epoch, - collection_step_num_env_steps=args.step_per_collect, + epoch_num_steps=args.epoch_num_steps, + collection_step_num_env_steps=args.collection_step_num_env_steps, test_step_num_episodes=args.test_num, batch_size=args.batch_size, update_step_num_gradient_steps_per_sample=args.update_per_step, diff --git a/examples/discrete/discrete_dqn.py b/examples/discrete/discrete_dqn.py index 37bf9fea3..f6b171397 100644 --- a/examples/discrete/discrete_dqn.py +++ b/examples/discrete/discrete_dqn.py @@ -16,7 +16,7 @@ def main() -> None: gamma, n_step, target_freq = 0.9, 3, 320 buffer_size = 20000 eps_train, eps_test = 0.1, 0.05 - step_per_epoch, step_per_collect = 10000, 10 + epoch_num_steps, collection_step_num_env_steps = 10000, 10 logger = ts.utils.TensorboardLogger(SummaryWriter("log/dqn")) # TensorBoard is supported! 
# For other loggers, see https://tianshou.readthedocs.io/en/master/tutorials/logger.html @@ -72,11 +72,11 @@ def stop_fn(mean_rewards: float) -> bool: train_collector=train_collector, test_collector=test_collector, max_epochs=epoch, - epoch_num_steps=step_per_epoch, - collection_step_num_env_steps=step_per_collect, + epoch_num_steps=epoch_num_steps, + collection_step_num_env_steps=collection_step_num_env_steps, test_step_num_episodes=test_num, batch_size=batch_size, - update_step_num_gradient_steps_per_sample=1 / step_per_collect, + update_step_num_gradient_steps_per_sample=1 / collection_step_num_env_steps, stop_fn=stop_fn, logger=logger, test_in_train=True, diff --git a/examples/discrete/discrete_dqn_hl.py b/examples/discrete/discrete_dqn_hl.py index 59ead1e26..0ba102f2b 100644 --- a/examples/discrete/discrete_dqn_hl.py +++ b/examples/discrete/discrete_dqn_hl.py @@ -34,7 +34,7 @@ def main() -> None: num_train_envs=10, num_test_envs=100, buffer_size=20000, - step_per_collect=10, + collection_step_num_env_steps=10, update_step_num_gradient_steps_per_sample=1 / 10, ), ) diff --git a/examples/inverse/irl_gail.py b/examples/inverse/irl_gail.py index 331800b0f..fccb6453e 100644 --- a/examples/inverse/irl_gail.py +++ b/examples/inverse/irl_gail.py @@ -59,15 +59,15 @@ def get_args() -> argparse.Namespace: parser.add_argument("--disc-lr", type=float, default=2.5e-5) parser.add_argument("--gamma", type=float, default=0.99) parser.add_argument("--epoch", type=int, default=100) - parser.add_argument("--step-per-epoch", type=int, default=30000) - parser.add_argument("--step-per-collect", type=int, default=2048) - parser.add_argument("--repeat-per-collect", type=int, default=10) + parser.add_argument("--epoch_num_steps", type=int, default=30000) + parser.add_argument("--collection_step_num_env_steps", type=int, default=2048) + parser.add_argument("--update_step_num_repetitions", type=int, default=10) parser.add_argument("--disc-update-num", type=int, default=2) - 
parser.add_argument("--batch-size", type=int, default=64) - parser.add_argument("--training-num", type=int, default=64) - parser.add_argument("--test-num", type=int, default=10) + parser.add_argument("--batch_size", type=int, default=64) + parser.add_argument("--num_train_envs", type=int, default=64) + parser.add_argument("--num_test_envs", type=int, default=10) # ppo special - parser.add_argument("--rew-norm", type=int, default=True) + parser.add_argument("--return_scaling", type=int, default=True) # In theory, `vf-coef` will not make any difference if using Adam optimizer. parser.add_argument("--vf-coef", type=float, default=0.25) parser.add_argument("--ent-coef", type=float, default=0.001) @@ -78,7 +78,7 @@ def get_args() -> argparse.Namespace: parser.add_argument("--eps-clip", type=float, default=0.2) parser.add_argument("--dual-clip", type=float, default=None) parser.add_argument("--value-clip", type=int, default=0) - parser.add_argument("--norm-adv", type=int, default=0) + parser.add_argument("--advantage_normalization", type=int, default=0) parser.add_argument("--recompute-adv", type=int, default=1) parser.add_argument("--logdir", type=str, default="log") parser.add_argument("--render", type=float, default=0.0) @@ -108,7 +108,7 @@ def test_gail(args: argparse.Namespace = get_args()) -> None: print("Action range:", args.min_action, args.max_action) # train_envs = gym.make(args.task) train_envs = SubprocVectorEnv( - [lambda: NoRewardEnv(gym.make(args.task)) for _ in range(args.training_num)], + [lambda: NoRewardEnv(gym.make(args.task)) for _ in range(args.num_train_envs)], ) train_envs = VectorEnvNormObs(train_envs) # test_envs = gym.make(args.task) @@ -173,8 +173,8 @@ def test_gail(args: argparse.Namespace = get_args()) -> None: optim.with_lr_scheduler_factory( LRSchedulerFactoryLinear( max_epochs=args.epoch, - epoch_num_steps=args.step_per_epoch, - collection_step_num_env_steps=args.step_per_collect, + epoch_num_steps=args.epoch_num_steps, + 
collection_step_num_env_steps=args.collection_step_num_env_steps, ) ) @@ -224,11 +224,11 @@ def dist(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution: max_grad_norm=args.max_grad_norm, vf_coef=args.vf_coef, ent_coef=args.ent_coef, - return_scaling=args.rew_norm, + return_scaling=args.return_scaling, eps_clip=args.eps_clip, value_clip=args.value_clip, dual_clip=args.dual_clip, - advantage_normalization=args.norm_adv, + advantage_normalization=args.advantage_normalization, recompute_advantage=args.recompute_adv, ) @@ -239,7 +239,7 @@ def dist(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution: # collector buffer: ReplayBuffer - if args.training_num > 1: + if args.num_train_envs > 1: buffer = VectorReplayBuffer(args.buffer_size, len(train_envs)) else: buffer = ReplayBuffer(args.buffer_size) @@ -263,11 +263,11 @@ def save_best_fn(policy: Algorithm) -> None: train_collector=train_collector, test_collector=test_collector, max_epochs=args.epoch, - epoch_num_steps=args.step_per_epoch, - update_step_num_repetitions=args.repeat_per_collect, + epoch_num_steps=args.epoch_num_steps, + update_step_num_repetitions=args.update_step_num_repetitions, test_step_num_episodes=args.test_num, batch_size=args.batch_size, - collection_step_num_env_steps=args.step_per_collect, + collection_step_num_env_steps=args.collection_step_num_env_steps, save_best_fn=save_best_fn, logger=logger, test_in_train=False, diff --git a/examples/modelbased/README.md b/examples/modelbased/README.md index c3563f629..879847ac4 100644 --- a/examples/modelbased/README.md +++ b/examples/modelbased/README.md @@ -1,7 +1,7 @@ # PSRL -`NChain-v0`: `python3 psrl.py --task NChain-v0 --step-per-epoch 10 --rew-mean-prior 0 --rew-std-prior 1` +`NChain-v0`: `python3 psrl.py --task NChain-v0 --epoch_num_steps 10 --rew-mean-prior 0 --rew-std-prior 1` -`FrozenLake-v0`: `python3 psrl.py --task FrozenLake-v0 --step-per-epoch 1000 --rew-mean-prior 0 --rew-std-prior 1 --add-done-loop --epoch 20` 
+`FrozenLake-v0`: `python3 psrl.py --task FrozenLake-v0 --epoch_num_steps 1000 --rew-mean-prior 0 --rew-std-prior 1 --add-done-loop --epoch 20` -`Taxi-v3`: `python3 psrl.py --task Taxi-v3 --step-per-epoch 1000 --rew-mean-prior 0 --rew-std-prior 2 --epoch 20` +`Taxi-v3`: `python3 psrl.py --task Taxi-v3 --epoch_num_steps 1000 --rew-mean-prior 0 --rew-std-prior 2 --epoch 20` diff --git a/examples/mujoco/README.md b/examples/mujoco/README.md index f633273a7..3e6ceee29 100644 --- a/examples/mujoco/README.md +++ b/examples/mujoco/README.md @@ -3,15 +3,16 @@ We benchmarked Tianshou algorithm implementations in 9 out of 13 environments from the MuJoCo Gym task suite[[1]](#footnote1). For each supported algorithm and supported mujoco environments, we provide: + - Default hyperparameters used for benchmark and scripts to reproduce the benchmark; - A comparison of performance (or code level details) with other open source implementations or classic papers; - Graphs and raw data that can be used for research purposes[[2]](#footnote2); - Log details obtained during training[[2]](#footnote2); - Pretrained agents[[2]](#footnote2); - Some hints on how to tune the algorithm. - Supported algorithms are listed below: + - [Deep Deterministic Policy Gradient (DDPG)](https://arxiv.org/pdf/1509.02971.pdf), [commit id](https://github.com/thu-ml/tianshou/tree/e605bdea942b408126ef4fbc740359773259c9ec) - [Twin Delayed DDPG (TD3)](https://arxiv.org/pdf/1802.09477.pdf), [commit id](https://github.com/thu-ml/tianshou/tree/e605bdea942b408126ef4fbc740359773259c9ec) - [Soft Actor-Critic (SAC)](https://arxiv.org/pdf/1812.05905.pdf), [commit id](https://github.com/thu-ml/tianshou/tree/e605bdea942b408126ef4fbc740359773259c9ec) @@ -79,62 +80,64 @@ For pretrained agents, detailed graphs (single agent, single game) and log detai ### Notes -1. 
In offpolicy algorithms (DDPG, TD3, SAC), the shared hyperparameters are almost the same, and unless otherwise stated, hyperparameters are consistent with those used for benchmark in SpinningUp's implementations (e.g. we use batchsize 256 in DDPG/TD3/SAC while SpinningUp use 100. Minor difference also lies with `start-timesteps`, data loop method `step_per_collect`, method to deal with/bootstrap truncated steps because of timelimit and unfinished/collecting episodes (contribute to performance improvement), etc.). +1. In offpolicy algorithms (DDPG, TD3, SAC), the shared hyperparameters are almost the same, and unless otherwise stated, hyperparameters are consistent with those used for benchmark in SpinningUp's implementations (e.g. we use batchsize 256 in DDPG/TD3/SAC while SpinningUp use 100. Minor difference also lies with `start-timesteps`, data loop method `collection_step_num_env_steps`, method to deal with/bootstrap truncated steps because of timelimit and unfinished/collecting episodes (contribute to performance improvement), etc.). 2. By comparison to both classic literature and open source implementations (e.g., SpinningUp)[[1]](#footnote1)[[2]](#footnote2), Tianshou's implementations of DDPG, TD3, and SAC are roughly at-parity with or better than the best reported results for these algorithms, so you can definitely use Tianshou's benchmark for research purposes. 3. We didn't compare offpolicy algorithms to OpenAI baselines [benchmark](https://github.com/openai/baselines/blob/master/benchmarks_mujoco1M.htm), because for now it seems that they haven't provided benchmark for offpolicy algorithms, but in [SpinningUp docs](https://spinningup.openai.com/en/latest/spinningup/bench.html) they stated that "SpinningUp implementations of DDPG, TD3, and SAC are roughly at-parity with the best-reported results for these algorithms", so we think lack of comparisons with OpenAI baselines is okay. 
### DDPG | Environment | Tianshou (1M) | [Spinning Up (PyTorch)](https://spinningup.openai.com/en/latest/spinningup/bench.html) | [TD3 paper (DDPG)](https://arxiv.org/abs/1802.09477) | [TD3 paper (OurDDPG)](https://arxiv.org/abs/1802.09477) | -| :--------------------: | :---------------: | :----------------------------------------------------------: | :--------------------------------------------------: | :-----------------------------------------------------: | -| Ant | 990.4±4.3 | ~840 | **1005.3** | 888.8 | -| HalfCheetah | **11718.7±465.6** | ~11000 | 3305.6 | 8577.3 | -| Hopper | **2197.0±971.6** | ~1800 | **2020.5** | 1860.0 | -| Walker2d | 1400.6±905.0 | ~1950 | 1843.6 | **3098.1** | -| Swimmer | **144.1±6.5** | ~137 | N | N | -| Humanoid | **177.3±77.6** | N | N | N | -| Reacher | **-3.3±0.3** | N | -6.51 | -4.01 | -| InvertedPendulum | **1000.0±0.0** | N | **1000.0** | **1000.0** | -| InvertedDoublePendulum | 8364.3±2778.9 | N | **9355.5** | 8370.0 | +| :--------------------: | :---------------: | :------------------------------------------------------------------------------------: | :--------------------------------------------------: | :-----------------------------------------------------: | +| Ant | 990.4±4.3 | ~840 | **1005.3** | 888.8 | +| HalfCheetah | **11718.7±465.6** | ~11000 | 3305.6 | 8577.3 | +| Hopper | **2197.0±971.6** | ~1800 | **2020.5** | 1860.0 | +| Walker2d | 1400.6±905.0 | ~1950 | 1843.6 | **3098.1** | +| Swimmer | **144.1±6.5** | ~137 | N | N | +| Humanoid | **177.3±77.6** | N | N | N | +| Reacher | **-3.3±0.3** | N | -6.51 | -4.01 | +| InvertedPendulum | **1000.0±0.0** | N | **1000.0** | **1000.0** | +| InvertedDoublePendulum | 8364.3±2778.9 | N | **9355.5** | 8370.0 | \* details[[4]](#footnote4)[[5]](#footnote5)[[6]](#footnote6) ### TD3 | Environment | Tianshou (1M) | [Spinning Up (PyTorch)](https://spinningup.openai.com/en/latest/spinningup/bench.html) | [TD3 paper](https://arxiv.org/abs/1802.09477) | -| :--------------------: | 
:---------------: | :----------------------------------------------------------: | :-------------------------------------------: | -| Ant | **5116.4±799.9** | ~3800 | 4372.4±1000.3 | -| HalfCheetah | **10201.2±772.8** | ~9750 | 9637.0±859.1 | -| Hopper | 3472.2±116.8 | ~2860 | **3564.1±114.7** | -| Walker2d | 3982.4±274.5 | ~4000 | **4682.8±539.6** | -| Swimmer | **104.2±34.2** | ~78 | N | -| Humanoid | **5189.5±178.5** | N | N | -| Reacher | **-2.7±0.2** | N | -3.6±0.6 | -| InvertedPendulum | **1000.0±0.0** | N | **1000.0±0.0** | -| InvertedDoublePendulum | **9349.2±14.3** | N | **9337.5±15.0** | +| :--------------------: | :---------------: | :------------------------------------------------------------------------------------: | :-------------------------------------------: | +| Ant | **5116.4±799.9** | ~3800 | 4372.4±1000.3 | +| HalfCheetah | **10201.2±772.8** | ~9750 | 9637.0±859.1 | +| Hopper | 3472.2±116.8 | ~2860 | **3564.1±114.7** | +| Walker2d | 3982.4±274.5 | ~4000 | **4682.8±539.6** | +| Swimmer | **104.2±34.2** | ~78 | N | +| Humanoid | **5189.5±178.5** | N | N | +| Reacher | **-2.7±0.2** | N | -3.6±0.6 | +| InvertedPendulum | **1000.0±0.0** | N | **1000.0±0.0** | +| InvertedDoublePendulum | **9349.2±14.3** | N | **9337.5±15.0** | \* details[[4]](#footnote4)[[5]](#footnote5)[[6]](#footnote6) #### Hints for TD3 + 1. TD3's learning rate is set to 3e-4 while it is 1e-3 for DDPG/SAC. However, there is NO enough evidence to support our choice of such hyperparameters (we simply choose them because SpinningUp do so) and you can try playing with those hyperparameters to see if you can improve performance. Do tell us if you can! 
### SAC | Environment | Tianshou (1M) | [Spinning Up (PyTorch)](https://spinningup.openai.com/en/latest/spinningup/bench.html) | [SAC paper](https://arxiv.org/abs/1801.01290) | -| :--------------------: | :----------------: | :----------------------------------------------------------: | :-------------------------------------------: | -| Ant | **5850.2±475.7** | ~3980 | ~3720 | -| HalfCheetah | **12138.8±1049.3** | ~11520 | ~10400 | -| Hopper | **3542.2±51.5** | ~3150 | ~3370 | -| Walker2d | **5007.0±251.5** | ~4250 | ~3740 | -| Swimmer | **44.4±0.5** | ~41.7 | N | -| Humanoid | **5488.5±81.2** | N | ~5200 | -| Reacher | **-2.6±0.2** | N | N | -| InvertedPendulum | **1000.0±0.0** | N | N | -| InvertedDoublePendulum | **9359.5±0.4** | N | N | +| :--------------------: | :----------------: | :------------------------------------------------------------------------------------: | :-------------------------------------------: | +| Ant | **5850.2±475.7** | ~3980 | ~3720 | +| HalfCheetah | **12138.8±1049.3** | ~11520 | ~10400 | +| Hopper | **3542.2±51.5** | ~3150 | ~3370 | +| Walker2d | **5007.0±251.5** | ~4250 | ~3740 | +| Swimmer | **44.4±0.5** | ~41.7 | N | +| Humanoid | **5488.5±81.2** | N | ~5200 | +| Reacher | **-2.6±0.2** | N | N | +| InvertedPendulum | **1000.0±0.0** | N | N | +| InvertedDoublePendulum | **9359.5±0.4** | N | N | \* details[[4]](#footnote4)[[5]](#footnote5) #### Hints for SAC + 1. SAC's start-timesteps is set to 10000 by default while it is 25000 is DDPG/TD3. However, there is NO enough evidence to support our choice of such hyperparameters (we simply choose them because SpinningUp do so) and you can try playing with those hyperparameters to see if you can improve performance. Do tell us if you can! 2. DO NOT share the same network with two critic networks. 3. The sigma (of the Gaussian policy) should be conditioned on input. 
@@ -143,6 +146,7 @@ For pretrained agents, detailed graphs (single agent, single game) and log detai ## Onpolicy Algorithms ### Notes + 1. In A2C and PPO, unless otherwise stated, most hyperparameters are consistent with those used for benchmark in [ikostrikov/pytorch-a2c-ppo-acktr-gail](https://github.com/ikostrikov/pytorch-a2c-ppo-acktr-gail). 2. Gernally speaking, by comparison to both classic literature and open source implementations (e.g., OPENAI Baselines)[[1]](#footnote1)[[2]](#footnote2), Tianshou's implementations of REINFORCE, A2C, PPO are better than the best reported results for these algorithms, so you can definitely use Tianshou's benchmark for research purposes. @@ -160,18 +164,17 @@ For pretrained agents, detailed graphs (single agent, single game) and log detai | InvertedPendulum | **1000.0±0.0** | | InvertedDoublePendulum | **7726.2±1287.3** | - | Environment | Tianshou (3M) | [Spinning Up (VPG PyTorch)](https://spinningup.openai.com/en/latest/spinningup/bench_vpg.html)[[7]](#footnote7) | -| :--------------------: | :---------------: | :----------------------------------------------------------: | -| Ant | **474.9+-133.5** | ~5 | -| HalfCheetah | **884.0+-41.0** | ~600 | -| Hopper | 395.8+-64.5* | **~800** | -| Walker2d | 412.0+-52.4 | **~460** | -| Swimmer | 35.3+-1.4 | **~51** | -| Humanoid | **438.2+-47.8** | N | -| Reacher | **-10.5+-0.7** | N | -| InvertedPendulum | **999.2+-2.4** | N | -| InvertedDoublePendulum | **1059.7+-307.7** | N | +| :--------------------: | :---------------: | :------------------------------------------------------------------------------------------------------------------------: | +| Ant | **474.9+-133.5** | ~5 | +| HalfCheetah | **884.0+-41.0** | ~600 | +| Hopper | 395.8+-64.5\* | **~800** | +| Walker2d | 412.0+-52.4 | **~460** | +| Swimmer | 35.3+-1.4 | **~51** | +| Humanoid | **438.2+-47.8** | N | +| Reacher | **-10.5+-0.7** | N | +| InvertedPendulum | **999.2+-2.4** | N | +| InvertedDoublePendulum | 
**1059.7+-307.7** | N | \* details[[4]](#footnote4)[[5]](#footnote5) @@ -188,35 +191,35 @@ For pretrained agents, detailed graphs (single agent, single game) and log detai ### A2C | Environment | Tianshou (3M) | [Spinning Up (PyTorch)](https://spinningup.openai.com/en/latest/spinningup/bench_vpg.html) | -| :--------------------: | :----------------: | :----------------------------------------------------------: | -| Ant | **5236.8+-236.7** | ~5 | -| HalfCheetah | **2377.3+-1363.7** | ~600 | -| Hopper | **1608.6+-529.5** | ~800 | -| Walker2d | **1805.4+-1055.9** | ~460 | -| Swimmer | 40.2+-1.8 | **~51** | -| Humanoid | **5316.6+-554.8** | N | -| Reacher | **-5.2+-0.5** | N | -| InvertedPendulum | **1000.0+-0.0** | N | -| InvertedDoublePendulum | **9351.3+-12.8** | N | +| :--------------------: | :----------------: | :----------------------------------------------------------------------------------------: | +| Ant | **5236.8+-236.7** | ~5 | +| HalfCheetah | **2377.3+-1363.7** | ~600 | +| Hopper | **1608.6+-529.5** | ~800 | +| Walker2d | **1805.4+-1055.9** | ~460 | +| Swimmer | 40.2+-1.8 | **~51** | +| Humanoid | **5316.6+-554.8** | N | +| Reacher | **-5.2+-0.5** | N | +| InvertedPendulum | **1000.0+-0.0** | N | +| InvertedDoublePendulum | **9351.3+-12.8** | N | | Environment | Tianshou (1M) | [PPO paper](https://arxiv.org/abs/1707.06347) A2C | [PPO paper](https://arxiv.org/abs/1707.06347) A2C + Trust Region | -| :--------------------: | :----------------: | :-----------------------------------------------: | :----------------------------------------------------------: | -| Ant | **3485.4+-433.1** | N | N | -| HalfCheetah | **1829.9+-1068.3** | ~1000 | ~930 | -| Hopper | **1253.2+-458.0** | ~900 | ~1220 | -| Walker2d | **1091.6+-709.2** | ~850 | ~700 | -| Swimmer | **36.6+-2.1** | ~31 | **~36** | -| Humanoid | **1726.0+-1070.1** | N | N | -| Reacher | **-6.7+-2.3** | ~-24 | ~-27 | -| InvertedPendulum | **1000.0+-0.0** | **~1000** | **~1000** | -| 
InvertedDoublePendulum | **9257.7+-277.4** | ~7100 | ~8100 | +| :--------------------: | :----------------: | :-----------------------------------------------: | :--------------------------------------------------------------: | +| Ant | **3485.4+-433.1** | N | N | +| HalfCheetah | **1829.9+-1068.3** | ~1000 | ~930 | +| Hopper | **1253.2+-458.0** | ~900 | ~1220 | +| Walker2d | **1091.6+-709.2** | ~850 | ~700 | +| Swimmer | **36.6+-2.1** | ~31 | **~36** | +| Humanoid | **1726.0+-1070.1** | N | N | +| Reacher | **-6.7+-2.3** | ~-24 | ~-27 | +| InvertedPendulum | **1000.0+-0.0** | **~1000** | **~1000** | +| InvertedDoublePendulum | **9257.7+-277.4** | ~7100 | ~8100 | \* details[[4]](#footnote4)[[5]](#footnote5) #### Hints for A2C 1. We choose `clip` action method in A2C instead of `tanh` option as used in REINFORCE simply to be consistent with original implementation. `tanh` may be better or equally well but we didn't have a try. -2. (Initial) learning rate, lr\_decay, `step-per-collect` and `training-num` affect the performance of A2C to a great extend. These 4 hyperparameters also affect each other and should be tuned together. We have done full scale ablation studies on these 4 hyperparameters (more than 800 agents have been trained). Below are our findings. +2. (Initial) learning rate, lr_decay, `step-per-collect` and `training-num` affect the performance of A2C to a great extend. These 4 hyperparameters also affect each other and should be tuned together. We have done full scale ablation studies on these 4 hyperparameters (more than 800 agents have been trained). Below are our findings. 3. `step-per-collect` / `training-num` are equal to `bootstrap-lenghth`, which is the max length of an "episode" used in GAE estimator and 80/16=5 in default settings. When `bootstrap-lenghth` is small, (maybe) because GAE can look forward at most 5 steps and use bootstrap strategy very often, the critic is less well-trained leading the actor to a not very high score. 
However, if we increase `step-per-collect` to increase `bootstrap-lenghth` (e.g. 256/16=16), actor/critic will be updated less often, resulting in low sample efficiency and slow training process. To conclude, If you don't restrict env timesteps, you can try using larger `bootstrap-lenghth` and train with more steps to get a better converged score. Train slower, achieve higher. 4. The learning rate 7e-4 with decay strategy is appropriate for `step-per-collect=80` and `training-num=16`. But if you use a larger `step-per-collect`(e.g. 256 - 2048), 7e-4 is a little bit small for `lr` because each update will have more data, less noise and thus smaller deviation in this case. So it is more appropriate to use a higher learning rate (e.g. 1e-3) to boost performance in this setting. If plotting results arise fast in early stages and become unstable later, consider lr decay first before decreasing lr. 5. `max-grad-norm` didn't really help in our experiments. We simply keep it for consistency with other open-source implementations (e.g. SB3). 
@@ -227,58 +230,60 @@ For pretrained agents, detailed graphs (single agent, single game) and log detai ### PPO | Environment | Tianshou (1M) | [ikostrikov/pytorch-a2c-ppo-acktr-gail](https://github.com/ikostrikov/pytorch-a2c-ppo-acktr-gail) | [PPO paper](https://arxiv.org/pdf/1707.06347.pdf) | [OpenAI Baselines](https://github.com/openai/baselines/blob/master/benchmarks_mujoco1M.htm) | [Spinning Up (PyTorch)](https://spinningup.openai.com/en/latest/spinningup/bench_ppo.html) | -| :--------------------: | :----------------: | :----------------------------------------------------------: | :-----------------------------------------------: | :----------------------------------------------------------: | :----------------------------------------------------------: | -| Ant | **3258.4+-1079.3** | N | N | N | ~650 | -| HalfCheetah | **5783.9+-1244.0** | ~3120 | ~1800 | ~1700 | ~1670 | -| Hopper | **2609.3+-700.8** | ~2300 | ~2330 | ~2400 | ~1850 | -| Walker2d | 3588.5+-756.6 | **~4000** | ~3460 | ~3510 | ~1230 | -| Swimmer | 66.7+-99.1 | N | ~108 | ~111 | **~120** | -| Humanoid | **787.1+-193.5** | N | N | N | N | -| Reacher | **-4.1+-0.3** | ~-5 | ~-7 | ~-6 | N | -| InvertedPendulum | **1000.0+-0.0** | N | **~1000** | ~940 | N | -| InvertedDoublePendulum | **9231.3+-270.4** | N | ~8000 | ~7350 | N | +| :--------------------: | :----------------: | :-----------------------------------------------------------------------------------------------: | :-----------------------------------------------: | :-----------------------------------------------------------------------------------------: | :----------------------------------------------------------------------------------------: | +| Ant | **3258.4+-1079.3** | N | N | N | ~650 | +| HalfCheetah | **5783.9+-1244.0** | ~3120 | ~1800 | ~1700 | ~1670 | +| Hopper | **2609.3+-700.8** | ~2300 | ~2330 | ~2400 | ~1850 | +| Walker2d | 3588.5+-756.6 | **~4000** | ~3460 | ~3510 | ~1230 | +| Swimmer | 66.7+-99.1 | N | ~108 | ~111 | 
**~120** | +| Humanoid | **787.1+-193.5** | N | N | N | N | +| Reacher | **-4.1+-0.3** | ~-5 | ~-7 | ~-6 | N | +| InvertedPendulum | **1000.0+-0.0** | N | **~1000** | ~940 | N | +| InvertedDoublePendulum | **9231.3+-270.4** | N | ~8000 | ~7350 | N | | Environment | Tianshou (3M) | [Spinning Up (PyTorch)](https://spinningup.openai.com/en/latest/spinningup/bench_ppo.html) | -| :--------------------: | :----------------: | :----------------------------------------------------------: | -| Ant | **4079.3+-880.2** | ~3000 | -| HalfCheetah | **7337.4+-1508.2** | ~3130 | -| Hopper | **3127.7+-413.0** | ~2460 | -| Walker2d | **4895.6+-704.3** | ~2600 | -| Swimmer | 81.4+-96.0 | **~120** | -| Humanoid | **1359.7+-572.7** | N | -| Reacher | **-3.7+-0.3** | N | -| InvertedPendulum | **1000.0+-0.0** | N | -| InvertedDoublePendulum | **9231.3+-270.4** | N | +| :--------------------: | :----------------: | :----------------------------------------------------------------------------------------: | +| Ant | **4079.3+-880.2** | ~3000 | +| HalfCheetah | **7337.4+-1508.2** | ~3130 | +| Hopper | **3127.7+-413.0** | ~2460 | +| Walker2d | **4895.6+-704.3** | ~2600 | +| Swimmer | 81.4+-96.0 | **~120** | +| Humanoid | **1359.7+-572.7** | N | +| Reacher | **-3.7+-0.3** | N | +| InvertedPendulum | **1000.0+-0.0** | N | +| InvertedDoublePendulum | **9231.3+-270.4** | N | \* details[[4]](#footnote4)[[5]](#footnote5) #### Hints for PPO + 1. Following [Andrychowicz, Marcin, et al](https://arxiv.org/abs/2006.05990) Sec 3.5, we use "recompute advantage" strategy, which contributes a lot to our SOTA benchmark. However, I personally don't quite agree with their explanation about why "recompute advantage" helps. They stated that it's because old strategy "makes it impossible to compute advantages as the temporal structure is broken", but PPO's update equation is designed to learn from slightly-outdated advantages. 
I think the only reason "recompute advantage" works is that it update the critic several times rather than just one time per update, which leads to a better value function estimation. 2. We have done full scale ablation studies of PPO algorithm's hyperparameters. Here are our findings: In Mujoco settings, `value-clip` and `norm-adv` may help a litte bit in some games (e.g. `norm-adv` helps stabilize training in InvertedPendulum-v2), but they make no difference to overall performance. So in our benchmark we do not use such tricks. We validate that setting `ent-coef` to 0.0 rather than 0.01 will increase overall performance in mujoco environments. `max-grad-norm` still offers no help for PPO algorithm, but we still keep it for consistency. 3. [Andrychowicz, Marcin, et al](https://arxiv.org/abs/2006.05990)'s work indicates that using `gae-lambda` 0.9 and changing policy network's width based on which game you play (e.g. use [16, 16] `hidden-sizes` for `actor` network in HalfCheetah and [256, 256] for Ant) may help boost performance. Our ablation studies say otherwise: both options may lead to equal or lower performance overall in our experiments. We are not confident about this claim because we didn't change learning rate and other maybe-correlated factors in our experiments. So if you want, you can still have a try. -4. `batch-size` 128 and 64 (default) work equally well. Changing `training-num` alone slightly (maybe in range [8, 128]) won't affect performance. For bound action method, both `clip` and `tanh` work quite well. +4. `batch-size` 128 and 64 (default) work equally well. Changing `training-num` alone slightly (maybe in range [8, 128]) won't affect performance. For bound action method, both `clip` and `tanh` work quite well. 5. In OPENAI implementations of PPO, they multiply value loss with a factor of 0.5 for no good reason (see this [issue](https://github.com/openai/baselines/issues/445#issuecomment-777988738)). 
We do not do so and therefore make our `vf-coef` 0.25 (half of standard 0.5). However, since value loss is only used to optimize `critic` network, setting different `vf-coef` should in theory make no difference if using Adam optimizer. - + ### TRPO | Environment | Tianshou (1M) | [ACKTR paper](https://arxiv.org/pdf/1708.05144.pdf) | [PPO paper](https://arxiv.org/pdf/1707.06347.pdf) | [OpenAI Baselines](https://github.com/openai/baselines/blob/master/benchmarks_mujoco1M.htm) | [Spinning Up (Tensorflow)](https://spinningup.openai.com/en/latest/spinningup/bench.html) | -| :--------------------: | :---------------: | :-------------------------------------------------: | :-----------------------------------------------: | :----------------------------------------------------------: | :----------------------------------------------------------: | -| Ant | **2866.7±707.9** | ~0 | N | N | ~150 | -| HalfCheetah | **4471.2±804.9** | ~400 | ~0 | ~1350 | ~850 | -| Hopper | 2046.0±1037.9 | ~1400 | ~2100 | **~2200** | ~1200 | -| Walker2d | **3826.7±782.7** | ~550 | ~1100 | ~2350 | ~600 | -| Swimmer | 40.9±19.6 | ~40 | **~121** | ~95 | ~85 | -| Humanoid | **810.1±126.1** | N | N | N | N | -| Reacher | **-5.1±0.8** | -8 | ~-115 | **~-5** | N | -| InvertedPendulum | **1000.0±0.0** | **~1000** | **~1000** | ~910 | N | -| InvertedDoublePendulum | **8435.2±1073.3** | ~800 | ~200 | ~7000 | N | +| :--------------------: | :---------------: | :-------------------------------------------------: | :-----------------------------------------------: | :-----------------------------------------------------------------------------------------: | :---------------------------------------------------------------------------------------: | +| Ant | **2866.7±707.9** | ~0 | N | N | ~150 | +| HalfCheetah | **4471.2±804.9** | ~400 | ~0 | ~1350 | ~850 | +| Hopper | 2046.0±1037.9 | ~1400 | ~2100 | **~2200** | ~1200 | +| Walker2d | **3826.7±782.7** | ~550 | ~1100 | ~2350 | ~600 | +| Swimmer | 40.9±19.6 | 
~40 | **~121** | ~95 | ~85 | +| Humanoid | **810.1±126.1** | N | N | N | N | +| Reacher | **-5.1±0.8** | -8 | ~-115 | **~-5** | N | +| InvertedPendulum | **1000.0±0.0** | **~1000** | **~1000** | ~910 | N | +| InvertedDoublePendulum | **8435.2±1073.3** | ~800 | ~200 | ~7000 | N | \* details[[4]](#footnote4)[[5]](#footnote5) #### Hints for TRPO + 1. We have tried `step-per-collect` in (80, 1024, 2048, 4096), and `training-num` in (4, 16, 32, 64), and found out 1024 for `step-per-collect` (same as OpenAI Baselines) and smaller `training-num` (below 16) are good choices. Set `training-num` to 4 is actually better but we still use 16 considering the boost of training speed. 2. Advantage normalization is a standard trick in TRPO, but we found it of minor help, just like in PPO. -3. Larger `optim-critic-iters` (than 5, as used in OpenAI Baselines) helps in most environments. Smaller lr and lr\_decay strategy also help a tiny little bit for performance. +3. Larger `optim-critic-iters` (than 5, as used in OpenAI Baselines) helps in most environments. Smaller lr and lr_decay strategy also help a tiny little bit for performance. 4. `gae-lambda` 0.98 and 0.95 work equally well. 5. We use GAE returns (GAE advantage + value) as the target of critic network when updating, while people usually tend to use reward to go (lambda = 0.) as target. We found that they work equally well although using GAE returns is a little bit inaccurate (biased) by math. 6. Empirically, Swimmer-v3 usually requires larger bootstrap lengths and learning rate. Humanoid-v3 and InvertedPendulum-v2, however, are on the opposite. @@ -302,33 +307,36 @@ For pretrained agents, detailed graphs (single agent, single game) and log detai \* details[[4]](#footnote4)[[5]](#footnote5) #### Hints for NPG + 1. All shared hyperparameters are exactly the same as TRPO, regarding how similar these two algorithms are. 2. 
We found different games in Mujoco may require quite different `actor-step-size`: Reacher/Swimmer are insensitive to step-size in range (0.1~1.0), while InvertedDoublePendulum / InvertedPendulum / Humanoid are quite sensitive to step size, and even 0.1 is too large. Other games may require `actor-step-size` in range (0.1~0.4), but aren't that sensitive in general. ## Others ### HER -| Environment | DDPG without HER | DDPG with HER | -| :--------------------: | :--------------: | :--------------: | -| FetchReach | -49.9±0.2. | **-17.6±21.7** | + +| Environment | DDPG without HER | DDPG with HER | +| :---------: | :--------------: | :------------: | +| FetchReach | -49.9±0.2 | **-17.6±21.7** | #### Hints for HER -1. The HER technique is proposed for solving task-based environments, so it cannot be compared with non-task-based mujoco benchmarks. The environment used in this evaluation is ``FetchReach-v3`` which requires an extra [installation](https://github.com/Farama-Foundation/Gymnasium-Robotics). -2. Simple hyperparameters optimizations are done for both settings, DDPG with and without HER. However, since *DDPG without HER* failed in every experiment, the best hyperparameters for *DDPG with HER* are used in the evaluation of both settings. -3. The scores are the mean reward ± 1 standard deviation of 16 seeds. The minimum reward for ``FetchReach-v3`` is -50 which we can imply that *DDPG without HER* performs as good as a random policy. *DDPG with HER* although has a better mean reward, the standard deviation is quite high. This is because in this setting, the agent will either fail completely (-50 reward) or successfully learn the task (close to 0 reward). This means that the agent successfully learned in about 70% of the 16 seeds. + +1. The HER technique was proposed for solving task-based environments, so it cannot be compared with non-task-based mujoco benchmarks.
The environment used in this evaluation is `FetchReach-v3`, which requires an extra [installation](https://github.com/Farama-Foundation/Gymnasium-Robotics). +2. Simple hyperparameter optimizations were done for both settings, DDPG with and without HER. However, since _DDPG without HER_ failed in every experiment, the best hyperparameters for _DDPG with HER_ were used in the evaluation of both settings. +3. The scores are the mean reward ± 1 standard deviation over 16 seeds. The minimum reward for `FetchReach-v3` is -50, from which we can infer that _DDPG without HER_ performs no better than a random policy. Although _DDPG with HER_ has a better mean reward, its standard deviation is quite high. This is because in this setting, the agent will either fail completely (-50 reward) or successfully learn the task (close to 0 reward). This means that the agent successfully learned in about 70% of the 16 seeds. ## Note -[1] Supported environments include HalfCheetah-v3, Hopper-v3, Swimmer-v3, Walker2d-v3, Ant-v3, Humanoid-v3, Reacher-v2, InvertedPendulum-v2 and InvertedDoublePendulum-v2. Pusher, Thrower, Striker and HumanoidStandup are not supported because they are not commonly seen in literatures. +[1] Supported environments include HalfCheetah-v3, Hopper-v3, Swimmer-v3, Walker2d-v3, Ant-v3, Humanoid-v3, Reacher-v2, InvertedPendulum-v2 and InvertedDoublePendulum-v2. Pusher, Thrower, Striker and HumanoidStandup are not supported because they are not commonly seen in the literature. -[2] Pretrained agents, detailed graphs (single agent, single game) and log details can all be found at [Google Drive](https://drive.google.com/drive/folders/1IycImzTmWcyEeD38viea5JHoboC4zmNP?usp=share_link). +[2] Pretrained agents, detailed graphs (single agent, single game) and log details can all be found at [Google Drive](https://drive.google.com/drive/folders/1IycImzTmWcyEeD38viea5JHoboC4zmNP?usp=share_link).
-[3] We used the latest version of all mujoco environments in gym (0.17.3 with mujoco==2.0.2.13), but it's not often the case with other benchmarks. Please check for details yourself in the original paper. (Different version's outcomes are usually similar, though) +[3] We used the latest version of all mujoco environments in gym (0.17.3 with mujoco==2.0.2.13), but this is often not the case for other benchmarks. Please check the original papers for details. (Outcomes for different versions are usually similar, though.) -[4] ~ means the number is approximated from the graph because accurate numbers is not provided in the paper. N means graphs not provided. +[4] ~ means the number is approximated from the graph because accurate numbers are not provided in the paper. N means graphs are not provided. -[5] Reward metric: The meaning of the table value is the max average return over 10 trails (different seeds) ± a single standard deviation over trails. Each trial is averaged on another 10 test seeds. Only the first 1M steps data will be considered, if not otherwise stated. The shaded region on the graph also represents a single standard deviation. It is the same as [TD3 evaluation method](https://github.com/sfujim/TD3/issues/34). +[5] Reward metric: The table value is the max average return over 10 trials (different seeds) ± a single standard deviation over trials. Each trial is averaged over another 10 test seeds. Only data from the first 1M steps is considered, unless otherwise stated. The shaded region on the graph also represents a single standard deviation. It is the same as [TD3 evaluation method](https://github.com/sfujim/TD3/issues/34). -[6] In TD3 paper, shaded region represents only half of standard deviation.
-[7] Comparing Tianshou's REINFORCE algorithm with SpinningUp's VPG is quite unfair because SpinningUp's VPG uses a generative advantage estimator (GAE) which requires a dnn value predictor (critic network), which makes so called "VPG" more like A2C (advantage actor critic) algorithm. Even so, you can see that we are roughly at-parity with each other even if tianshou's REINFORCE do not use a critic or GAE. +[7] Comparing Tianshou's REINFORCE algorithm with SpinningUp's VPG is quite unfair because SpinningUp's VPG uses a generalized advantage estimator (GAE), which requires a DNN value predictor (critic network) and makes the so-called "VPG" more like the A2C (advantage actor-critic) algorithm. Even so, you can see that we are roughly at parity with each other even though Tianshou's REINFORCE does not use a critic or GAE. diff --git a/examples/mujoco/fetch_her_ddpg.py b/examples/mujoco/fetch_her_ddpg.py index d5a5af132..84953c705 100644 --- a/examples/mujoco/fetch_her_ddpg.py +++ b/examples/mujoco/fetch_her_ddpg.py @@ -45,16 +45,16 @@ def get_args() -> argparse.Namespace: parser.add_argument("--exploration-noise", type=float, default=0.1) parser.add_argument("--start-timesteps", type=int, default=25000) parser.add_argument("--epoch", type=int, default=10) - parser.add_argument("--step-per-epoch", type=int, default=5000) - parser.add_argument("--step-per-collect", type=int, default=1) + parser.add_argument("--epoch_num_steps", type=int, default=5000) + parser.add_argument("--collection_step_num_env_steps", type=int, default=1) parser.add_argument("--update-per-step", type=int, default=1) parser.add_argument("--n-step", type=int, default=1) - parser.add_argument("--batch-size", type=int, default=512) + parser.add_argument("--batch_size", type=int, default=512) parser.add_argument("--replay-buffer", type=str, default="her", choices=["normal", "her"]) parser.add_argument("--her-horizon", type=int, default=50) parser.add_argument("--her-future-k", type=int, default=8) -
parser.add_argument("--training-num", type=int, default=1) - parser.add_argument("--test-num", type=int, default=10) + parser.add_argument("--num_train_envs", type=int, default=1) + parser.add_argument("--num_test_envs", type=int, default=10) parser.add_argument("--logdir", type=str, default="log") parser.add_argument("--render", type=float, default=0.0) parser.add_argument( @@ -82,12 +82,12 @@ def get_args() -> argparse.Namespace: def make_fetch_env( task: str, - training_num: int, + num_train_envs: int, test_num: int, ) -> tuple[gym.Env, BaseVectorEnv, BaseVectorEnv]: env = TruncatedAsTerminated(gym.make(task)) train_envs = ShmemVectorEnv( - [lambda: TruncatedAsTerminated(gym.make(task)) for _ in range(training_num)], + [lambda: TruncatedAsTerminated(gym.make(task)) for _ in range(num_train_envs)], ) test_envs = ShmemVectorEnv( [lambda: TruncatedAsTerminated(gym.make(task)) for _ in range(test_num)], @@ -117,7 +117,7 @@ def test_ddpg(args: argparse.Namespace = get_args()) -> None: config_dict=vars(args), ) - env, train_envs, test_envs = make_fetch_env(args.task, args.training_num, args.test_num) + env, train_envs, test_envs = make_fetch_env(args.task, args.num_train_envs, args.test_num) # The method HER works with goal-based environments if not isinstance(env.observation_space, gym.spaces.Dict): raise ValueError( @@ -196,12 +196,12 @@ def compute_reward_fn(ag: np.ndarray, g: np.ndarray) -> np.ndarray: buffer: VectorReplayBuffer | ReplayBuffer | HERReplayBuffer | HERVectorReplayBuffer if args.replay_buffer == "normal": - if args.training_num > 1: + if args.num_train_envs > 1: buffer = VectorReplayBuffer(args.buffer_size, len(train_envs)) else: buffer = ReplayBuffer(args.buffer_size) else: - if args.training_num > 1: + if args.num_train_envs > 1: buffer = HERVectorReplayBuffer( args.buffer_size, len(train_envs), @@ -231,8 +231,8 @@ def save_best_fn(policy: Algorithm) -> None: train_collector=train_collector, test_collector=test_collector, max_epochs=args.epoch, - 
epoch_num_steps=args.step_per_epoch, - collection_step_num_env_steps=args.step_per_collect, + epoch_num_steps=args.epoch_num_steps, + collection_step_num_env_steps=args.collection_step_num_env_steps, test_step_num_episodes=args.test_num, batch_size=args.batch_size, save_best_fn=save_best_fn, diff --git a/examples/mujoco/mujoco_a2c.py b/examples/mujoco/mujoco_a2c.py index 339b551e1..7a9ae036b 100755 --- a/examples/mujoco/mujoco_a2c.py +++ b/examples/mujoco/mujoco_a2c.py @@ -31,15 +31,15 @@ def get_args() -> argparse.Namespace: parser.add_argument("--lr", type=float, default=7e-4) parser.add_argument("--gamma", type=float, default=0.99) parser.add_argument("--epoch", type=int, default=100) - parser.add_argument("--step-per-epoch", type=int, default=30000) - parser.add_argument("--step-per-collect", type=int, default=80) - parser.add_argument("--repeat-per-collect", type=int, default=1) + parser.add_argument("--epoch_num_steps", type=int, default=30000) + parser.add_argument("--collection_step_num_env_steps", type=int, default=80) + parser.add_argument("--update_step_num_repetitions", type=int, default=1) # batch-size >> step-per-collect means calculating all data in one singe forward. 
- parser.add_argument("--batch-size", type=int, default=None) - parser.add_argument("--training-num", type=int, default=16) - parser.add_argument("--test-num", type=int, default=10) + parser.add_argument("--batch_size", type=int, default=None) + parser.add_argument("--num_train_envs", type=int, default=16) + parser.add_argument("--num_test_envs", type=int, default=10) # a2c special - parser.add_argument("--rew-norm", type=int, default=True) + parser.add_argument("--return_scaling", type=int, default=True) parser.add_argument("--vf-coef", type=float, default=0.5) parser.add_argument("--ent-coef", type=float, default=0.01) parser.add_argument("--gae-lambda", type=float, default=0.95) @@ -75,7 +75,7 @@ def main(args: argparse.Namespace = get_args()) -> None: env, train_envs, test_envs = make_mujoco_env( args.task, args.seed, - args.training_num, + args.num_train_envs, args.test_num, obs_norm=True, ) @@ -131,8 +131,8 @@ def main(args: argparse.Namespace = get_args()) -> None: optim.with_lr_scheduler_factory( LRSchedulerFactoryLinear( max_epochs=args.epoch, - epoch_num_steps=args.step_per_epoch, - collection_step_num_env_steps=args.step_per_collect, + epoch_num_steps=args.epoch_num_steps, + collection_step_num_env_steps=args.collection_step_num_env_steps, ) ) @@ -156,7 +156,7 @@ def dist(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution: max_grad_norm=args.max_grad_norm, vf_coef=args.vf_coef, ent_coef=args.ent_coef, - return_scaling=args.rew_norm, + return_scaling=args.return_scaling, ) # load a previous policy @@ -169,7 +169,7 @@ def dist(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution: # collector buffer: VectorReplayBuffer | ReplayBuffer - if args.training_num > 1: + if args.num_train_envs > 1: buffer = VectorReplayBuffer(args.buffer_size, len(train_envs)) else: buffer = ReplayBuffer(args.buffer_size) @@ -208,11 +208,11 @@ def save_best_fn(policy: Algorithm) -> None: train_collector=train_collector, test_collector=test_collector, 
max_epochs=args.epoch, - epoch_num_steps=args.step_per_epoch, - update_step_num_repetitions=args.repeat_per_collect, + epoch_num_steps=args.epoch_num_steps, + update_step_num_repetitions=args.update_step_num_repetitions, test_step_num_episodes=args.test_num, batch_size=args.batch_size, - collection_step_num_env_steps=args.step_per_collect, + collection_step_num_env_steps=args.collection_step_num_env_steps, save_best_fn=save_best_fn, logger=logger, test_in_train=False, diff --git a/examples/mujoco/mujoco_a2c_hl.py b/examples/mujoco/mujoco_a2c_hl.py index 1a85c52fc..903b4651b 100644 --- a/examples/mujoco/mujoco_a2c_hl.py +++ b/examples/mujoco/mujoco_a2c_hl.py @@ -27,13 +27,13 @@ def main( lr: float = 7e-4, gamma: float = 0.99, epoch: int = 100, - step_per_epoch: int = 30000, - step_per_collect: int = 80, - repeat_per_collect: int = 1, + epoch_num_steps: int = 30000, + collection_step_num_env_steps: int = 80, + update_step_num_repetitions: int = 1, batch_size: int = 16, - training_num: int = 16, + num_train_envs: int = 16, test_num: int = 10, - rew_norm: bool = True, + return_scaling: bool = True, vf_coef: float = 0.5, ent_coef: float = 0.01, gae_lambda: float = 0.95, @@ -45,13 +45,13 @@ def main( training_config = OnPolicyTrainingConfig( max_epochs=epoch, - epoch_num_steps=step_per_epoch, + epoch_num_steps=epoch_num_steps, batch_size=batch_size, - num_train_envs=training_num, + num_train_envs=num_train_envs, num_test_envs=test_num, buffer_size=buffer_size, - step_per_collect=step_per_collect, - update_step_num_repetitions=repeat_per_collect, + collection_step_num_env_steps=collection_step_num_env_steps, + update_step_num_repetitions=update_step_num_repetitions, ) env_factory = MujocoEnvFactory( @@ -68,7 +68,7 @@ def main( gamma=gamma, gae_lambda=gae_lambda, action_bound_method=bound_action_method, - return_scaling=rew_norm, + return_scaling=return_scaling, ent_coef=ent_coef, vf_coef=vf_coef, max_grad_norm=max_grad_norm, diff --git a/examples/mujoco/mujoco_ddpg.py 
b/examples/mujoco/mujoco_ddpg.py index c2c769f32..97d48e7eb 100755 --- a/examples/mujoco/mujoco_ddpg.py +++ b/examples/mujoco/mujoco_ddpg.py @@ -34,13 +34,13 @@ def get_args() -> argparse.Namespace: parser.add_argument("--exploration-noise", type=float, default=0.1) parser.add_argument("--start-timesteps", type=int, default=25000) parser.add_argument("--epoch", type=int, default=200) - parser.add_argument("--step-per-epoch", type=int, default=5000) - parser.add_argument("--step-per-collect", type=int, default=1) + parser.add_argument("--epoch_num_steps", type=int, default=5000) + parser.add_argument("--collection_step_num_env_steps", type=int, default=1) parser.add_argument("--update-per-step", type=int, default=1) parser.add_argument("--n-step", type=int, default=1) - parser.add_argument("--batch-size", type=int, default=256) - parser.add_argument("--training-num", type=int, default=1) - parser.add_argument("--test-num", type=int, default=10) + parser.add_argument("--batch_size", type=int, default=256) + parser.add_argument("--num_train_envs", type=int, default=1) + parser.add_argument("--num_test_envs", type=int, default=10) parser.add_argument("--logdir", type=str, default="log") parser.add_argument("--render", type=float, default=0.0) parser.add_argument( @@ -70,7 +70,7 @@ def main(args: argparse.Namespace = get_args()) -> None: env, train_envs, test_envs = make_mujoco_env( args.task, args.seed, - args.training_num, + args.num_train_envs, args.test_num, obs_norm=False, ) @@ -122,7 +122,7 @@ def main(args: argparse.Namespace = get_args()) -> None: # collector buffer: VectorReplayBuffer | ReplayBuffer - if args.training_num > 1: + if args.num_train_envs > 1: buffer = VectorReplayBuffer(args.buffer_size, len(train_envs)) else: buffer = ReplayBuffer(args.buffer_size) @@ -162,8 +162,8 @@ def save_best_fn(policy: Algorithm) -> None: train_collector=train_collector, test_collector=test_collector, max_epochs=args.epoch, - epoch_num_steps=args.step_per_epoch, - 
collection_step_num_env_steps=args.step_per_collect, + epoch_num_steps=args.epoch_num_steps, + collection_step_num_env_steps=args.collection_step_num_env_steps, test_step_num_episodes=args.test_num, batch_size=args.batch_size, save_best_fn=save_best_fn, diff --git a/examples/mujoco/mujoco_ddpg_hl.py b/examples/mujoco/mujoco_ddpg_hl.py index 7795347f8..0eac73b61 100644 --- a/examples/mujoco/mujoco_ddpg_hl.py +++ b/examples/mujoco/mujoco_ddpg_hl.py @@ -28,24 +28,24 @@ def main( exploration_noise: float = 0.1, start_timesteps: int = 25000, epoch: int = 200, - step_per_epoch: int = 5000, - step_per_collect: int = 1, + epoch_num_steps: int = 5000, + collection_step_num_env_steps: int = 1, update_per_step: int = 1, n_step: int = 1, batch_size: int = 256, - training_num: int = 1, + num_train_envs: int = 1, test_num: int = 10, ) -> None: log_name = os.path.join(task, "ddpg", str(experiment_config.seed), datetime_tag()) training_config = OffPolicyTrainingConfig( max_epochs=epoch, - epoch_num_steps=step_per_epoch, + epoch_num_steps=epoch_num_steps, batch_size=batch_size, - num_train_envs=training_num, + num_train_envs=num_train_envs, num_test_envs=test_num, buffer_size=buffer_size, - step_per_collect=step_per_collect, + collection_step_num_env_steps=collection_step_num_env_steps, update_step_num_gradient_steps_per_sample=update_per_step, start_timesteps=start_timesteps, start_timesteps_random=True, diff --git a/examples/mujoco/mujoco_npg.py b/examples/mujoco/mujoco_npg.py index 4585a1b51..87be75b89 100755 --- a/examples/mujoco/mujoco_npg.py +++ b/examples/mujoco/mujoco_npg.py @@ -36,23 +36,23 @@ def get_args() -> argparse.Namespace: parser.add_argument("--lr", type=float, default=1e-3) parser.add_argument("--gamma", type=float, default=0.99) parser.add_argument("--epoch", type=int, default=100) - parser.add_argument("--step-per-epoch", type=int, default=30000) - parser.add_argument("--step-per-collect", type=int, default=1024) - parser.add_argument("--repeat-per-collect", 
type=int, default=1) + parser.add_argument("--epoch_num_steps", type=int, default=30000) + parser.add_argument("--collection_step_num_env_steps", type=int, default=1024) + parser.add_argument("--update_step_num_repetitions", type=int, default=1) # batch-size >> step-per-collect means calculating all data in one singe forward. - parser.add_argument("--batch-size", type=int, default=None) - parser.add_argument("--training-num", type=int, default=16) - parser.add_argument("--test-num", type=int, default=10) + parser.add_argument("--batch_size", type=int, default=None) + parser.add_argument("--num_train_envs", type=int, default=16) + parser.add_argument("--num_test_envs", type=int, default=10) # npg special - parser.add_argument("--rew-norm", type=int, default=True) + parser.add_argument("--return_scaling", type=int, default=True) parser.add_argument("--gae-lambda", type=float, default=0.95) parser.add_argument("--bound-action-method", type=str, default="clip") parser.add_argument("--lr-decay", type=int, default=True) parser.add_argument("--logdir", type=str, default="log") parser.add_argument("--render", type=float, default=0.0) - parser.add_argument("--norm-adv", type=int, default=1) + parser.add_argument("--advantage_normalization", type=int, default=1) parser.add_argument("--optim-critic-iters", type=int, default=20) - parser.add_argument("--actor-step-size", type=float, default=0.1) + parser.add_argument("--trust_region_size", type=float, default=0.1) parser.add_argument( "--device", type=str, @@ -80,7 +80,7 @@ def main(args: argparse.Namespace = get_args()) -> None: env, train_envs, test_envs = make_mujoco_env( args.task, args.seed, - args.training_num, + args.num_train_envs, args.test_num, obs_norm=True, ) @@ -129,8 +129,8 @@ def main(args: argparse.Namespace = get_args()) -> None: optim.with_lr_scheduler_factory( LRSchedulerFactoryLinear( max_epochs=args.epoch, - epoch_num_steps=args.step_per_epoch, - collection_step_num_env_steps=args.step_per_collect, + 
epoch_num_steps=args.epoch_num_steps, + collection_step_num_env_steps=args.collection_step_num_env_steps, ) ) @@ -151,10 +151,10 @@ def dist(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution: optim=optim, gamma=args.gamma, gae_lambda=args.gae_lambda, - return_scaling=args.rew_norm, - advantage_normalization=args.norm_adv, + return_scaling=args.return_scaling, + advantage_normalization=args.advantage_normalization, optim_critic_iters=args.optim_critic_iters, - actor_step_size=args.actor_step_size, + trust_region_size=args.trust_region_size, ) # load a previous policy @@ -167,7 +167,7 @@ def dist(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution: # collector buffer: VectorReplayBuffer | ReplayBuffer - if args.training_num > 1: + if args.num_train_envs > 1: buffer = VectorReplayBuffer(args.buffer_size, len(train_envs)) else: buffer = ReplayBuffer(args.buffer_size) @@ -206,11 +206,11 @@ def save_best_fn(policy: Algorithm) -> None: train_collector=train_collector, test_collector=test_collector, max_epochs=args.epoch, - epoch_num_steps=args.step_per_epoch, - update_step_num_repetitions=args.repeat_per_collect, + epoch_num_steps=args.epoch_num_steps, + update_step_num_repetitions=args.update_step_num_repetitions, test_step_num_episodes=args.test_num, batch_size=args.batch_size, - collection_step_num_env_steps=args.step_per_collect, + collection_step_num_env_steps=args.collection_step_num_env_steps, save_best_fn=save_best_fn, logger=logger, test_in_train=False, diff --git a/examples/mujoco/mujoco_npg_hl.py b/examples/mujoco/mujoco_npg_hl.py index 7099fcbe0..a73010d22 100644 --- a/examples/mujoco/mujoco_npg_hl.py +++ b/examples/mujoco/mujoco_npg_hl.py @@ -26,31 +26,31 @@ def main( lr: float = 1e-3, gamma: float = 0.99, epoch: int = 100, - step_per_epoch: int = 30000, - step_per_collect: int = 1024, - repeat_per_collect: int = 1, + epoch_num_steps: int = 30000, + collection_step_num_env_steps: int = 1024, + update_step_num_repetitions: int = 1, 
batch_size: int = 16, - training_num: int = 16, - test_num: int = 10, - rew_norm: bool = True, + num_train_envs: int = 16, + num_test_envs: int = 10, + return_scaling: bool = True, gae_lambda: float = 0.95, bound_action_method: Literal["clip", "tanh"] = "clip", lr_decay: bool = True, - norm_adv: bool = True, + advantage_normalization: bool = True, optim_critic_iters: int = 20, - actor_step_size: float = 0.1, + trust_region_size: float = 0.1, ) -> None: log_name = os.path.join(task, "npg", str(experiment_config.seed), datetime_tag()) training_config = OnPolicyTrainingConfig( max_epochs=epoch, - epoch_num_steps=step_per_epoch, + epoch_num_steps=epoch_num_steps, batch_size=batch_size, - num_train_envs=training_num, - num_test_envs=test_num, + num_train_envs=num_train_envs, + num_test_envs=num_test_envs, buffer_size=buffer_size, - step_per_collect=step_per_collect, - update_step_num_repetitions=repeat_per_collect, + collection_step_num_env_steps=collection_step_num_env_steps, + update_step_num_repetitions=update_step_num_repetitions, ) env_factory = MujocoEnvFactory( @@ -67,10 +67,10 @@ def main( gamma=gamma, gae_lambda=gae_lambda, action_bound_method=bound_action_method, - return_standardization=rew_norm, - advantage_normalization=norm_adv, + return_scaling=return_scaling, + advantage_normalization=advantage_normalization, optim_critic_iters=optim_critic_iters, - actor_step_size=actor_step_size, + trust_region_size=trust_region_size, lr=lr, lr_scheduler=LRSchedulerFactoryFactoryLinear(training_config) if lr_decay else None, ), diff --git a/examples/mujoco/mujoco_ppo.py b/examples/mujoco/mujoco_ppo.py index 5c584b8c6..5feedd1ff 100755 --- a/examples/mujoco/mujoco_ppo.py +++ b/examples/mujoco/mujoco_ppo.py @@ -31,14 +31,14 @@ def get_args() -> argparse.Namespace: parser.add_argument("--lr", type=float, default=3e-4) parser.add_argument("--gamma", type=float, default=0.99) parser.add_argument("--epoch", type=int, default=100) - parser.add_argument("--step-per-epoch", 
type=int, default=30000) - parser.add_argument("--step-per-collect", type=int, default=2048) - parser.add_argument("--repeat-per-collect", type=int, default=10) - parser.add_argument("--batch-size", type=int, default=64) - parser.add_argument("--training-num", type=int, default=8) - parser.add_argument("--test-num", type=int, default=10) + parser.add_argument("--epoch_num_steps", type=int, default=30000) + parser.add_argument("--collection_step_num_env_steps", type=int, default=2048) + parser.add_argument("--update_step_num_repetitions", type=int, default=10) + parser.add_argument("--batch_size", type=int, default=64) + parser.add_argument("--num_train_envs", type=int, default=8) + parser.add_argument("--num_test_envs", type=int, default=10) # ppo special - parser.add_argument("--rew-norm", type=int, default=True) + parser.add_argument("--return_scaling", type=int, default=True) # In theory, `vf-coef` will not make any difference if using Adam optimizer. parser.add_argument("--vf-coef", type=float, default=0.25) parser.add_argument("--ent-coef", type=float, default=0.0) @@ -49,7 +49,7 @@ def get_args() -> argparse.Namespace: parser.add_argument("--eps-clip", type=float, default=0.2) parser.add_argument("--dual-clip", type=float, default=None) parser.add_argument("--value-clip", type=int, default=0) - parser.add_argument("--norm-adv", type=int, default=0) + parser.add_argument("--advantage_normalization", type=int, default=0) parser.add_argument("--recompute-adv", type=int, default=1) parser.add_argument("--logdir", type=str, default="log") parser.add_argument("--render", type=float, default=0.0) @@ -80,7 +80,7 @@ def main(args: argparse.Namespace = get_args()) -> None: env, train_envs, test_envs = make_mujoco_env( args.task, args.seed, - args.training_num, + args.num_train_envs, args.test_num, obs_norm=True, ) @@ -132,8 +132,8 @@ def main(args: argparse.Namespace = get_args()) -> None: optim.with_lr_scheduler_factory( LRSchedulerFactoryLinear( 
max_epochs=args.epoch, - epoch_num_steps=args.step_per_epoch, - collection_step_num_env_steps=args.step_per_collect, + epoch_num_steps=args.epoch_num_steps, + collection_step_num_env_steps=args.collection_step_num_env_steps, ) ) @@ -157,11 +157,11 @@ def dist(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution: max_grad_norm=args.max_grad_norm, vf_coef=args.vf_coef, ent_coef=args.ent_coef, - return_scaling=args.rew_norm, + return_scaling=args.return_scaling, eps_clip=args.eps_clip, value_clip=args.value_clip, dual_clip=args.dual_clip, - advantage_normalization=args.norm_adv, + advantage_normalization=args.advantage_normalization, recompute_advantage=args.recompute_adv, ) @@ -175,7 +175,7 @@ def dist(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution: # collector buffer: VectorReplayBuffer | ReplayBuffer - if args.training_num > 1: + if args.num_train_envs > 1: buffer = VectorReplayBuffer(args.buffer_size, len(train_envs)) else: buffer = ReplayBuffer(args.buffer_size) @@ -214,11 +214,11 @@ def save_best_fn(policy: Algorithm) -> None: train_collector=train_collector, test_collector=test_collector, max_epochs=args.epoch, - epoch_num_steps=args.step_per_epoch, - update_step_num_repetitions=args.repeat_per_collect, + epoch_num_steps=args.epoch_num_steps, + update_step_num_repetitions=args.update_step_num_repetitions, test_step_num_episodes=args.test_num, batch_size=args.batch_size, - collection_step_num_env_steps=args.step_per_collect, + collection_step_num_env_steps=args.collection_step_num_env_steps, save_best_fn=save_best_fn, logger=logger, test_in_train=False, diff --git a/examples/mujoco/mujoco_ppo_hl.py b/examples/mujoco/mujoco_ppo_hl.py index b146012bd..43162a577 100644 --- a/examples/mujoco/mujoco_ppo_hl.py +++ b/examples/mujoco/mujoco_ppo_hl.py @@ -26,13 +26,13 @@ def main( lr: float = 3e-4, gamma: float = 0.99, epoch: int = 100, - step_per_epoch: int = 30000, - step_per_collect: int = 2048, - repeat_per_collect: int = 10, + 
epoch_num_steps: int = 30000, + collection_step_num_env_steps: int = 2048, + update_step_num_repetitions: int = 10, batch_size: int = 64, - training_num: int = 10, + num_train_envs: int = 10, test_num: int = 10, - rew_norm: bool = True, + return_scaling: bool = True, vf_coef: float = 0.25, ent_coef: float = 0.0, gae_lambda: float = 0.95, @@ -42,20 +42,20 @@ def main( eps_clip: float = 0.2, dual_clip: float | None = None, value_clip: bool = False, - norm_adv: bool = False, + advantage_normalization: bool = False, recompute_adv: bool = True, ) -> None: log_name = os.path.join(task, "ppo", str(experiment_config.seed), datetime_tag()) training_config = OnPolicyTrainingConfig( max_epochs=epoch, - epoch_num_steps=step_per_epoch, + epoch_num_steps=epoch_num_steps, batch_size=batch_size, - num_train_envs=training_num, + num_train_envs=num_train_envs, num_test_envs=test_num, buffer_size=buffer_size, - step_per_collect=step_per_collect, - update_step_num_repetitions=repeat_per_collect, + collection_step_num_env_steps=collection_step_num_env_steps, + update_step_num_repetitions=update_step_num_repetitions, ) env_factory = MujocoEnvFactory( @@ -72,12 +72,12 @@ def main( gamma=gamma, gae_lambda=gae_lambda, action_bound_method=bound_action_method, - return_scaling=rew_norm, + return_scaling=return_scaling, ent_coef=ent_coef, vf_coef=vf_coef, max_grad_norm=max_grad_norm, value_clip=value_clip, - advantage_normalization=norm_adv, + advantage_normalization=advantage_normalization, eps_clip=eps_clip, dual_clip=dual_clip, recompute_advantage=recompute_adv, diff --git a/examples/mujoco/mujoco_ppo_hl_multi.py b/examples/mujoco/mujoco_ppo_hl_multi.py index da0c050a0..ae1c42762 100644 --- a/examples/mujoco/mujoco_ppo_hl_multi.py +++ b/examples/mujoco/mujoco_ppo_hl_multi.py @@ -66,7 +66,7 @@ def main( num_test_envs=5, test_step_num_episodes=5, buffer_size=4096, - step_per_collect=2048, + collection_step_num_env_steps=2048, update_step_num_repetitions=1, ) diff --git 
a/examples/mujoco/mujoco_redq.py b/examples/mujoco/mujoco_redq.py index cba5d9d4f..8cfffbe53 100755 --- a/examples/mujoco/mujoco_redq.py +++ b/examples/mujoco/mujoco_redq.py @@ -38,14 +38,14 @@ def get_args() -> argparse.Namespace: parser.add_argument("--alpha-lr", type=float, default=3e-4) parser.add_argument("--start-timesteps", type=int, default=10000) parser.add_argument("--epoch", type=int, default=200) - parser.add_argument("--step-per-epoch", type=int, default=5000) - parser.add_argument("--step-per-collect", type=int, default=1) + parser.add_argument("--epoch_num_steps", type=int, default=5000) + parser.add_argument("--collection_step_num_env_steps", type=int, default=1) parser.add_argument("--update-per-step", type=int, default=20) parser.add_argument("--n-step", type=int, default=1) - parser.add_argument("--batch-size", type=int, default=256) + parser.add_argument("--batch_size", type=int, default=256) parser.add_argument("--target-mode", type=str, choices=("min", "mean"), default="min") - parser.add_argument("--training-num", type=int, default=1) - parser.add_argument("--test-num", type=int, default=10) + parser.add_argument("--num_train_envs", type=int, default=1) + parser.add_argument("--num_test_envs", type=int, default=10) parser.add_argument("--logdir", type=str, default="log") parser.add_argument("--render", type=float, default=0.0) parser.add_argument( @@ -75,7 +75,7 @@ def main(args: argparse.Namespace = get_args()) -> None: env, train_envs, test_envs = make_mujoco_env( args.task, args.seed, - args.training_num, + args.num_train_envs, args.test_num, obs_norm=False, ) @@ -147,7 +147,7 @@ def linear(x: int, y: int) -> EnsembleLinear: # collector buffer: VectorReplayBuffer | ReplayBuffer - if args.training_num > 1: + if args.num_train_envs > 1: buffer = VectorReplayBuffer(args.buffer_size, len(train_envs)) else: buffer = ReplayBuffer(args.buffer_size) @@ -187,8 +187,8 @@ def save_best_fn(policy: Algorithm) -> None: train_collector=train_collector, 
test_collector=test_collector, max_epochs=args.epoch, - epoch_num_steps=args.step_per_epoch, - collection_step_num_env_steps=args.step_per_collect, + epoch_num_steps=args.epoch_num_steps, + collection_step_num_env_steps=args.collection_step_num_env_steps, test_step_num_episodes=args.test_num, batch_size=args.batch_size, save_best_fn=save_best_fn, diff --git a/examples/mujoco/mujoco_redq_hl.py b/examples/mujoco/mujoco_redq_hl.py index f0c6e2c04..144028567 100644 --- a/examples/mujoco/mujoco_redq_hl.py +++ b/examples/mujoco/mujoco_redq_hl.py @@ -33,25 +33,25 @@ def main( alpha_lr: float = 3e-4, start_timesteps: int = 10000, epoch: int = 200, - step_per_epoch: int = 5000, - step_per_collect: int = 1, + epoch_num_steps: int = 5000, + collection_step_num_env_steps: int = 1, update_per_step: int = 20, n_step: int = 1, batch_size: int = 256, target_mode: Literal["mean", "min"] = "min", - training_num: int = 1, + num_train_envs: int = 1, test_num: int = 10, ) -> None: log_name = os.path.join(task, "redq", str(experiment_config.seed), datetime_tag()) training_config = OffPolicyTrainingConfig( max_epochs=epoch, - epoch_num_steps=step_per_epoch, + epoch_num_steps=epoch_num_steps, batch_size=batch_size, - num_train_envs=training_num, + num_train_envs=num_train_envs, num_test_envs=test_num, buffer_size=buffer_size, - step_per_collect=step_per_collect, + collection_step_num_env_steps=collection_step_num_env_steps, update_step_num_gradient_steps_per_sample=update_per_step, start_timesteps=start_timesteps, start_timesteps_random=True, diff --git a/examples/mujoco/mujoco_reinforce.py b/examples/mujoco/mujoco_reinforce.py index f30dfbdcd..c3ce826ac 100755 --- a/examples/mujoco/mujoco_reinforce.py +++ b/examples/mujoco/mujoco_reinforce.py @@ -31,15 +31,15 @@ def get_args() -> argparse.Namespace: parser.add_argument("--lr", type=float, default=1e-3) parser.add_argument("--gamma", type=float, default=0.99) parser.add_argument("--epoch", type=int, default=100) - 
parser.add_argument("--step-per-epoch", type=int, default=30000) - parser.add_argument("--step-per-collect", type=int, default=2048) - parser.add_argument("--repeat-per-collect", type=int, default=1) + parser.add_argument("--epoch_num_steps", type=int, default=30000) + parser.add_argument("--collection_step_num_env_steps", type=int, default=2048) + parser.add_argument("--update_step_num_repetitions", type=int, default=1) # batch-size >> step-per-collect means calculating all data in one singe forward. - parser.add_argument("--batch-size", type=int, default=None) - parser.add_argument("--training-num", type=int, default=10) - parser.add_argument("--test-num", type=int, default=10) + parser.add_argument("--batch_size", type=int, default=None) + parser.add_argument("--num_train_envs", type=int, default=10) + parser.add_argument("--num_test_envs", type=int, default=10) # reinforce special - parser.add_argument("--rew-norm", type=int, default=True) + parser.add_argument("--return_scaling", type=int, default=True) # "clip" option also works well. 
parser.add_argument("--action-bound-method", type=str, default="tanh") parser.add_argument("--lr-decay", type=int, default=True) @@ -72,7 +72,7 @@ def main(args: argparse.Namespace = get_args()) -> None: env, train_envs, test_envs = make_mujoco_env( args.task, args.seed, - args.training_num, + args.num_train_envs, args.test_num, obs_norm=True, ) @@ -115,8 +115,8 @@ def main(args: argparse.Namespace = get_args()) -> None: optim.with_lr_scheduler_factory( LRSchedulerFactoryLinear( max_epochs=args.epoch, - epoch_num_steps=args.step_per_epoch, - collection_step_num_env_steps=args.step_per_collect, + epoch_num_steps=args.epoch_num_steps, + collection_step_num_env_steps=args.collection_step_num_env_steps, ) ) @@ -135,7 +135,7 @@ def dist(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution: policy=policy, optim=optim, gamma=args.gamma, - return_standardization=args.rew_norm, + return_standardization=args.return_scaling, ) # load a previous policy @@ -148,7 +148,7 @@ def dist(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution: # collector buffer: VectorReplayBuffer | ReplayBuffer - if args.training_num > 1: + if args.num_train_envs > 1: buffer = VectorReplayBuffer(args.buffer_size, len(train_envs)) else: buffer = ReplayBuffer(args.buffer_size) @@ -187,11 +187,11 @@ def save_best_fn(policy: Algorithm) -> None: train_collector=train_collector, test_collector=test_collector, max_epochs=args.epoch, - epoch_num_steps=args.step_per_epoch, - update_step_num_repetitions=args.repeat_per_collect, + epoch_num_steps=args.epoch_num_steps, + update_step_num_repetitions=args.update_step_num_repetitions, test_step_num_episodes=args.test_num, batch_size=args.batch_size, - collection_step_num_env_steps=args.step_per_collect, + collection_step_num_env_steps=args.collection_step_num_env_steps, save_best_fn=save_best_fn, logger=logger, test_in_train=False, diff --git a/examples/mujoco/mujoco_reinforce_hl.py b/examples/mujoco/mujoco_reinforce_hl.py index 
21508e2e1..9fb499262 100644 --- a/examples/mujoco/mujoco_reinforce_hl.py +++ b/examples/mujoco/mujoco_reinforce_hl.py @@ -26,13 +26,13 @@ def main( lr: float = 1e-3, gamma: float = 0.99, epoch: int = 100, - step_per_epoch: int = 30000, - step_per_collect: int = 2048, - repeat_per_collect: int = 1, + epoch_num_steps: int = 30000, + collection_step_num_env_steps: int = 2048, + update_step_num_repetitions: int = 1, batch_size: int | None = None, - training_num: int = 10, + num_train_envs: int = 10, test_num: int = 10, - rew_norm: bool = True, + return_scaling: bool = True, action_bound_method: Literal["clip", "tanh"] = "tanh", lr_decay: bool = True, ) -> None: @@ -40,13 +40,13 @@ def main( training_config = OnPolicyTrainingConfig( max_epochs=epoch, - epoch_num_steps=step_per_epoch, + epoch_num_steps=epoch_num_steps, batch_size=batch_size, - num_train_envs=training_num, + num_train_envs=num_train_envs, num_test_envs=test_num, buffer_size=buffer_size, - step_per_collect=step_per_collect, - update_step_num_repetitions=repeat_per_collect, + collection_step_num_env_steps=collection_step_num_env_steps, + update_step_num_repetitions=update_step_num_repetitions, ) env_factory = MujocoEnvFactory( @@ -62,7 +62,7 @@ def main( ReinforceParams( gamma=gamma, action_bound_method=action_bound_method, - return_standardization=rew_norm, + return_standardization=return_scaling, lr=lr, lr_scheduler=LRSchedulerFactoryFactoryLinear(training_config) if lr_decay else None, ), diff --git a/examples/mujoco/mujoco_sac.py b/examples/mujoco/mujoco_sac.py index 04e9ba559..ac8ca9daa 100755 --- a/examples/mujoco/mujoco_sac.py +++ b/examples/mujoco/mujoco_sac.py @@ -35,13 +35,13 @@ def get_args() -> argparse.Namespace: parser.add_argument("--alpha-lr", type=float, default=3e-4) parser.add_argument("--start-timesteps", type=int, default=10000) parser.add_argument("--epoch", type=int, default=200) - parser.add_argument("--step-per-epoch", type=int, default=5000) - 
parser.add_argument("--step-per-collect", type=int, default=1) + parser.add_argument("--epoch_num_steps", type=int, default=5000) + parser.add_argument("--collection_step_num_env_steps", type=int, default=1) parser.add_argument("--update-per-step", type=int, default=1) parser.add_argument("--n-step", type=int, default=1) - parser.add_argument("--batch-size", type=int, default=256) - parser.add_argument("--training-num", type=int, default=1) - parser.add_argument("--test-num", type=int, default=10) + parser.add_argument("--batch_size", type=int, default=256) + parser.add_argument("--num_train_envs", type=int, default=1) + parser.add_argument("--num_test_envs", type=int, default=10) parser.add_argument("--logdir", type=str, default="log") parser.add_argument("--render", type=float, default=0.0) parser.add_argument( @@ -71,7 +71,7 @@ def main(args: argparse.Namespace = get_args()) -> None: env, train_envs, test_envs = make_mujoco_env( args.task, args.seed, - args.training_num, + args.num_train_envs, args.test_num, obs_norm=False, ) @@ -140,7 +140,7 @@ def main(args: argparse.Namespace = get_args()) -> None: # collector buffer: VectorReplayBuffer | ReplayBuffer - if args.training_num > 1: + if args.num_train_envs > 1: buffer = VectorReplayBuffer(args.buffer_size, len(train_envs)) else: buffer = ReplayBuffer(args.buffer_size) @@ -180,8 +180,8 @@ def save_best_fn(policy: Algorithm) -> None: train_collector=train_collector, test_collector=test_collector, max_epochs=args.epoch, - epoch_num_steps=args.step_per_epoch, - collection_step_num_env_steps=args.step_per_collect, + epoch_num_steps=args.epoch_num_steps, + collection_step_num_env_steps=args.collection_step_num_env_steps, test_step_num_episodes=args.test_num, batch_size=args.batch_size, save_best_fn=save_best_fn, diff --git a/examples/mujoco/mujoco_sac_hl.py b/examples/mujoco/mujoco_sac_hl.py index 3a7d0dd78..ada5bf542 100644 --- a/examples/mujoco/mujoco_sac_hl.py +++ b/examples/mujoco/mujoco_sac_hl.py @@ -30,24 +30,24 
@@ def main( alpha_lr: float = 3e-4, start_timesteps: int = 10000, epoch: int = 200, - step_per_epoch: int = 5000, - step_per_collect: int = 1, + epoch_num_steps: int = 5000, + collection_step_num_env_steps: int = 1, update_per_step: int = 1, n_step: int = 1, batch_size: int = 256, - training_num: int = 1, + num_train_envs: int = 1, test_num: int = 10, ) -> None: log_name = os.path.join(task, "sac", str(experiment_config.seed), datetime_tag()) training_config = OffPolicyTrainingConfig( max_epochs=epoch, - epoch_num_steps=step_per_epoch, - num_train_envs=training_num, + epoch_num_steps=epoch_num_steps, + num_train_envs=num_train_envs, num_test_envs=test_num, buffer_size=buffer_size, batch_size=batch_size, - step_per_collect=step_per_collect, + collection_step_num_env_steps=collection_step_num_env_steps, update_step_num_gradient_steps_per_sample=update_per_step, start_timesteps=start_timesteps, start_timesteps_random=True, diff --git a/examples/mujoco/mujoco_td3.py b/examples/mujoco/mujoco_td3.py index 0a521a9c0..a1ee04987 100755 --- a/examples/mujoco/mujoco_td3.py +++ b/examples/mujoco/mujoco_td3.py @@ -37,13 +37,13 @@ def get_args() -> argparse.Namespace: parser.add_argument("--update-actor-freq", type=int, default=2) parser.add_argument("--start-timesteps", type=int, default=25000) parser.add_argument("--epoch", type=int, default=200) - parser.add_argument("--step-per-epoch", type=int, default=5000) - parser.add_argument("--step-per-collect", type=int, default=1) + parser.add_argument("--epoch_num_steps", type=int, default=5000) + parser.add_argument("--collection_step_num_env_steps", type=int, default=1) parser.add_argument("--update-per-step", type=int, default=1) parser.add_argument("--n-step", type=int, default=1) - parser.add_argument("--batch-size", type=int, default=256) - parser.add_argument("--training-num", type=int, default=1) - parser.add_argument("--test-num", type=int, default=10) + parser.add_argument("--batch_size", type=int, default=256) + 
parser.add_argument("--num_train_envs", type=int, default=1) + parser.add_argument("--num_test_envs", type=int, default=10) parser.add_argument("--logdir", type=str, default="log") parser.add_argument("--render", type=float, default=0.0) parser.add_argument( @@ -73,7 +73,7 @@ def main(args: argparse.Namespace = get_args()) -> None: env, train_envs, test_envs = make_mujoco_env( args.task, args.seed, - args.training_num, + args.num_train_envs, args.test_num, obs_norm=False, ) @@ -141,7 +141,7 @@ def main(args: argparse.Namespace = get_args()) -> None: # collector buffer: VectorReplayBuffer | ReplayBuffer - if args.training_num > 1: + if args.num_train_envs > 1: buffer = VectorReplayBuffer(args.buffer_size, len(train_envs)) else: buffer = ReplayBuffer(args.buffer_size) @@ -181,8 +181,8 @@ def save_best_fn(policy: Algorithm) -> None: train_collector=train_collector, test_collector=test_collector, max_epochs=args.epoch, - epoch_num_steps=args.step_per_epoch, - collection_step_num_env_steps=args.step_per_collect, + epoch_num_steps=args.epoch_num_steps, + collection_step_num_env_steps=args.collection_step_num_env_steps, test_step_num_episodes=args.test_num, batch_size=args.batch_size, save_best_fn=save_best_fn, diff --git a/examples/mujoco/mujoco_td3_hl.py b/examples/mujoco/mujoco_td3_hl.py index 8fcbe8168..3d304775a 100644 --- a/examples/mujoco/mujoco_td3_hl.py +++ b/examples/mujoco/mujoco_td3_hl.py @@ -35,24 +35,24 @@ def main( update_actor_freq: int = 2, start_timesteps: int = 25000, epoch: int = 200, - step_per_epoch: int = 5000, - step_per_collect: int = 1, + epoch_num_steps: int = 5000, + collection_step_num_env_steps: int = 1, update_per_step: int = 1, n_step: int = 1, batch_size: int = 256, - training_num: int = 1, + num_train_envs: int = 1, test_num: int = 10, ) -> None: log_name = os.path.join(task, "td3", str(experiment_config.seed), datetime_tag()) training_config = TrainingConfig( max_epochs=epoch, - epoch_num_steps=step_per_epoch, - 
num_train_envs=training_num, + epoch_num_steps=epoch_num_steps, + num_train_envs=num_train_envs, num_test_envs=test_num, buffer_size=buffer_size, batch_size=batch_size, - collection_step_num_env_steps=step_per_collect, + collection_step_num_env_steps=collection_step_num_env_steps, update_per_step=update_per_step, start_timesteps=start_timesteps, start_timesteps_random=True, diff --git a/examples/mujoco/mujoco_trpo.py b/examples/mujoco/mujoco_trpo.py index 238e2d53e..a8790bfcc 100755 --- a/examples/mujoco/mujoco_trpo.py +++ b/examples/mujoco/mujoco_trpo.py @@ -36,22 +36,22 @@ def get_args() -> argparse.Namespace: parser.add_argument("--lr", type=float, default=1e-3) parser.add_argument("--gamma", type=float, default=0.99) parser.add_argument("--epoch", type=int, default=100) - parser.add_argument("--step-per-epoch", type=int, default=30000) - parser.add_argument("--step-per-collect", type=int, default=1024) - parser.add_argument("--repeat-per-collect", type=int, default=1) + parser.add_argument("--epoch_num_steps", type=int, default=30000) + parser.add_argument("--collection_step_num_env_steps", type=int, default=1024) + parser.add_argument("--update_step_num_repetitions", type=int, default=1) # batch-size >> step-per-collect means calculating all data in one singe forward. 
- parser.add_argument("--batch-size", type=int, default=None) - parser.add_argument("--training-num", type=int, default=16) - parser.add_argument("--test-num", type=int, default=10) + parser.add_argument("--batch_size", type=int, default=None) + parser.add_argument("--num_train_envs", type=int, default=16) + parser.add_argument("--num_test_envs", type=int, default=10) # trpo special - parser.add_argument("--rew-norm", type=int, default=True) + parser.add_argument("--return_scaling", type=int, default=True) parser.add_argument("--gae-lambda", type=float, default=0.95) # TODO tanh support parser.add_argument("--bound-action-method", type=str, default="clip") parser.add_argument("--lr-decay", type=int, default=True) parser.add_argument("--logdir", type=str, default="log") parser.add_argument("--render", type=float, default=0.0) - parser.add_argument("--norm-adv", type=int, default=1) + parser.add_argument("--advantage_normalization", type=int, default=1) parser.add_argument("--optim-critic-iters", type=int, default=20) parser.add_argument("--max-kl", type=float, default=0.01) parser.add_argument("--backtrack-coeff", type=float, default=0.8) @@ -83,7 +83,7 @@ def test_trpo(args: argparse.Namespace = get_args()) -> None: env, train_envs, test_envs = make_mujoco_env( args.task, args.seed, - args.training_num, + args.num_train_envs, args.test_num, obs_norm=True, ) @@ -132,8 +132,8 @@ def test_trpo(args: argparse.Namespace = get_args()) -> None: optim.with_lr_scheduler_factory( LRSchedulerFactoryLinear( max_epochs=args.epoch, - epoch_num_steps=args.step_per_epoch, - collection_step_num_env_steps=args.step_per_collect, + epoch_num_steps=args.epoch_num_steps, + collection_step_num_env_steps=args.collection_step_num_env_steps, ) ) @@ -154,8 +154,8 @@ def dist(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution: optim=optim, gamma=args.gamma, gae_lambda=args.gae_lambda, - return_scaling=args.rew_norm, - advantage_normalization=args.norm_adv, + 
return_scaling=args.return_scaling, + advantage_normalization=args.advantage_normalization, optim_critic_iters=args.optim_critic_iters, max_kl=args.max_kl, backtrack_coeff=args.backtrack_coeff, @@ -172,7 +172,7 @@ def dist(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution: # collector buffer: VectorReplayBuffer | ReplayBuffer - if args.training_num > 1: + if args.num_train_envs > 1: buffer = VectorReplayBuffer(args.buffer_size, len(train_envs)) else: buffer = ReplayBuffer(args.buffer_size) @@ -211,11 +211,11 @@ def save_best_fn(policy: Algorithm) -> None: train_collector=train_collector, test_collector=test_collector, max_epochs=args.epoch, - epoch_num_steps=args.step_per_epoch, - update_step_num_repetitions=args.repeat_per_collect, + epoch_num_steps=args.epoch_num_steps, + update_step_num_repetitions=args.update_step_num_repetitions, test_step_num_episodes=args.test_num, batch_size=args.batch_size, - collection_step_num_env_steps=args.step_per_collect, + collection_step_num_env_steps=args.collection_step_num_env_steps, save_best_fn=save_best_fn, logger=logger, test_in_train=False, diff --git a/examples/mujoco/mujoco_trpo_hl.py b/examples/mujoco/mujoco_trpo_hl.py index e1a9bb4cd..f167f058e 100644 --- a/examples/mujoco/mujoco_trpo_hl.py +++ b/examples/mujoco/mujoco_trpo_hl.py @@ -26,17 +26,17 @@ def main( lr: float = 1e-3, gamma: float = 0.99, epoch: int = 100, - step_per_epoch: int = 30000, - step_per_collect: int = 1024, - repeat_per_collect: int = 1, + epoch_num_steps: int = 30000, + collection_step_num_env_steps: int = 1024, + update_step_num_repetitions: int = 1, batch_size: int = 16, - training_num: int = 16, + num_train_envs: int = 16, test_num: int = 10, - rew_norm: bool = True, + return_scaling: bool = True, gae_lambda: float = 0.95, bound_action_method: Literal["clip", "tanh"] = "clip", lr_decay: bool = True, - norm_adv: bool = True, + advantage_normalization: bool = True, optim_critic_iters: int = 20, max_kl: float = 0.01, backtrack_coeff: 
float = 0.8, @@ -46,13 +46,13 @@ def main( training_config = OnPolicyTrainingConfig( max_epochs=epoch, - epoch_num_steps=step_per_epoch, + epoch_num_steps=epoch_num_steps, batch_size=batch_size, - num_train_envs=training_num, + num_train_envs=num_train_envs, num_test_envs=test_num, buffer_size=buffer_size, - step_per_collect=step_per_collect, - update_step_num_repetitions=repeat_per_collect, + collection_step_num_env_steps=collection_step_num_env_steps, + update_step_num_repetitions=update_step_num_repetitions, ) env_factory = MujocoEnvFactory( @@ -69,8 +69,8 @@ def main( gamma=gamma, gae_lambda=gae_lambda, action_bound_method=bound_action_method, - return_standardization=rew_norm, - advantage_normalization=norm_adv, + return_standardization=return_scaling, + advantage_normalization=advantage_normalization, optim_critic_iters=optim_critic_iters, max_kl=max_kl, backtrack_coeff=backtrack_coeff, diff --git a/examples/offline/atari_bcq.py b/examples/offline/atari_bcq.py index a9f478dc3..eca748847 100644 --- a/examples/offline/atari_bcq.py +++ b/examples/offline/atari_bcq.py @@ -37,11 +37,11 @@ def get_args() -> argparse.Namespace: parser.add_argument("--imitation-logits-penalty", type=float, default=0.01) parser.add_argument("--epoch", type=int, default=100) parser.add_argument("--update-per-epoch", type=int, default=10000) - parser.add_argument("--batch-size", type=int, default=32) + parser.add_argument("--batch_size", type=int, default=32) parser.add_argument("--hidden-sizes", type=int, nargs="*", default=[512]) - parser.add_argument("--test-num", type=int, default=10) + parser.add_argument("--num_test_envs", type=int, default=10) parser.add_argument("--frames-stack", type=int, default=4) - parser.add_argument("--scale-obs", type=int, default=0) + parser.add_argument("--scale_obs", type=int, default=0) parser.add_argument("--logdir", type=str, default="log") parser.add_argument("--render", type=float, default=0.0) parser.add_argument("--resume-path", type=str, 
default=None) diff --git a/examples/offline/atari_cql.py b/examples/offline/atari_cql.py index 73fd01d58..9c0dcdd9b 100644 --- a/examples/offline/atari_cql.py +++ b/examples/offline/atari_cql.py @@ -37,11 +37,11 @@ def get_args() -> argparse.Namespace: parser.add_argument("--min-q-weight", type=float, default=10.0) parser.add_argument("--epoch", type=int, default=100) parser.add_argument("--update-per-epoch", type=int, default=10000) - parser.add_argument("--batch-size", type=int, default=32) + parser.add_argument("--batch_size", type=int, default=32) parser.add_argument("--hidden-sizes", type=int, nargs="*", default=[512]) - parser.add_argument("--test-num", type=int, default=10) + parser.add_argument("--num_test_envs", type=int, default=10) parser.add_argument("--frames-stack", type=int, default=4) - parser.add_argument("--scale-obs", type=int, default=0) + parser.add_argument("--scale_obs", type=int, default=0) parser.add_argument("--logdir", type=str, default="log") parser.add_argument("--render", type=float, default=0.0) parser.add_argument("--resume-path", type=str, default=None) diff --git a/examples/offline/atari_crr.py b/examples/offline/atari_crr.py index 4c95e5c86..87672399d 100644 --- a/examples/offline/atari_crr.py +++ b/examples/offline/atari_crr.py @@ -38,11 +38,11 @@ def get_args() -> argparse.Namespace: parser.add_argument("--target-update-freq", type=int, default=500) parser.add_argument("--epoch", type=int, default=100) parser.add_argument("--update-per-epoch", type=int, default=10000) - parser.add_argument("--batch-size", type=int, default=32) + parser.add_argument("--batch_size", type=int, default=32) parser.add_argument("--hidden-sizes", type=int, nargs="*", default=[512]) - parser.add_argument("--test-num", type=int, default=10) + parser.add_argument("--num_test_envs", type=int, default=10) parser.add_argument("--frames-stack", type=int, default=4) - parser.add_argument("--scale-obs", type=int, default=0) + parser.add_argument("--scale_obs", 
type=int, default=0) parser.add_argument("--logdir", type=str, default="log") parser.add_argument("--render", type=float, default=0.0) parser.add_argument("--resume-path", type=str, default=None) diff --git a/examples/offline/atari_il.py b/examples/offline/atari_il.py index 5481b61f3..1c97e25bc 100644 --- a/examples/offline/atari_il.py +++ b/examples/offline/atari_il.py @@ -32,10 +32,10 @@ def get_args() -> argparse.Namespace: parser.add_argument("--lr", type=float, default=0.0001) parser.add_argument("--epoch", type=int, default=100) parser.add_argument("--update-per-epoch", type=int, default=10000) - parser.add_argument("--batch-size", type=int, default=32) - parser.add_argument("--test-num", type=int, default=10) + parser.add_argument("--batch_size", type=int, default=32) + parser.add_argument("--num_test_envs", type=int, default=10) parser.add_argument("--frames-stack", type=int, default=4) - parser.add_argument("--scale-obs", type=int, default=0) + parser.add_argument("--scale_obs", type=int, default=0) parser.add_argument("--logdir", type=str, default="log") parser.add_argument("--render", type=float, default=0.0) parser.add_argument("--resume-path", type=str, default=None) diff --git a/examples/offline/d4rl_bcq.py b/examples/offline/d4rl_bcq.py index c17365e62..b7e1b2673 100644 --- a/examples/offline/d4rl_bcq.py +++ b/examples/offline/d4rl_bcq.py @@ -35,10 +35,10 @@ def get_args() -> argparse.Namespace: parser.add_argument("--critic-lr", type=float, default=1e-3) parser.add_argument("--start-timesteps", type=int, default=10000) parser.add_argument("--epoch", type=int, default=200) - parser.add_argument("--step-per-epoch", type=int, default=5000) + parser.add_argument("--epoch_num_steps", type=int, default=5000) parser.add_argument("--n-step", type=int, default=3) - parser.add_argument("--batch-size", type=int, default=256) - parser.add_argument("--test-num", type=int, default=10) + parser.add_argument("--batch_size", type=int, default=256) + 
parser.add_argument("--num_test_envs", type=int, default=10) parser.add_argument("--logdir", type=str, default="log") parser.add_argument("--render", type=float, default=1 / 35) @@ -215,7 +215,7 @@ def watch() -> None: buffer=replay_buffer, test_collector=test_collector, max_epochs=args.epoch, - epoch_num_steps=args.step_per_epoch, + epoch_num_steps=args.epoch_num_steps, test_step_num_episodes=args.test_num, batch_size=args.batch_size, save_best_fn=save_best_fn, diff --git a/examples/offline/d4rl_cql.py b/examples/offline/d4rl_cql.py index 0b86f11b7..7a13c77ee 100644 --- a/examples/offline/d4rl_cql.py +++ b/examples/offline/d4rl_cql.py @@ -106,7 +106,7 @@ def get_args() -> argparse.Namespace: help="The number of epochs to train for.", ) parser.add_argument( - "--step-per-epoch", + "--epoch_num_steps", type=int, default=5000, help="The number of steps per epoch.", @@ -118,7 +118,7 @@ def get_args() -> argparse.Namespace: help="The number of steps to use for N-step TD learning.", ) parser.add_argument( - "--batch-size", + "--batch_size", type=int, default=256, help="The batch size for training.", @@ -167,7 +167,7 @@ def get_args() -> argparse.Namespace: help="The frequency of evaluation.", ) parser.add_argument( - "--test-num", + "--num_test_envs", type=int, default=10, help="The number of episodes to evaluate for.", @@ -355,7 +355,7 @@ def watch() -> None: buffer=replay_buffer, test_collector=test_collector, max_epochs=args.epoch, - epoch_num_steps=args.step_per_epoch, + epoch_num_steps=args.epoch_num_steps, test_step_num_episodes=args.test_num, batch_size=args.batch_size, save_best_fn=save_best_fn, diff --git a/examples/offline/d4rl_il.py b/examples/offline/d4rl_il.py index e2076d85a..7b3ba59aa 100644 --- a/examples/offline/d4rl_il.py +++ b/examples/offline/d4rl_il.py @@ -34,9 +34,9 @@ def get_args() -> argparse.Namespace: parser.add_argument("--hidden-sizes", type=int, nargs="*", default=[256, 256]) parser.add_argument("--lr", type=float, default=1e-4) 
parser.add_argument("--epoch", type=int, default=200) - parser.add_argument("--step-per-epoch", type=int, default=5000) - parser.add_argument("--batch-size", type=int, default=256) - parser.add_argument("--test-num", type=int, default=10) + parser.add_argument("--epoch_num_steps", type=int, default=5000) + parser.add_argument("--batch_size", type=int, default=256) + parser.add_argument("--num_test_envs", type=int, default=10) parser.add_argument("--logdir", type=str, default="log") parser.add_argument("--render", type=float, default=1 / 35) parser.add_argument("--gamma", default=0.99) @@ -158,7 +158,7 @@ def watch() -> None: buffer=replay_buffer, test_collector=test_collector, max_epochs=args.epoch, - epoch_num_steps=args.step_per_epoch, + epoch_num_steps=args.epoch_num_steps, test_step_num_episodes=args.test_num, batch_size=args.batch_size, save_best_fn=save_best_fn, diff --git a/examples/offline/d4rl_td3_bc.py b/examples/offline/d4rl_td3_bc.py index 37e15758f..c2da56413 100644 --- a/examples/offline/d4rl_td3_bc.py +++ b/examples/offline/d4rl_td3_bc.py @@ -35,9 +35,9 @@ def get_args() -> argparse.Namespace: parser.add_argument("--actor-lr", type=float, default=3e-4) parser.add_argument("--critic-lr", type=float, default=3e-4) parser.add_argument("--epoch", type=int, default=200) - parser.add_argument("--step-per-epoch", type=int, default=5000) + parser.add_argument("--epoch_num_steps", type=int, default=5000) parser.add_argument("--n-step", type=int, default=3) - parser.add_argument("--batch-size", type=int, default=256) + parser.add_argument("--batch_size", type=int, default=256) parser.add_argument("--alpha", type=float, default=2.5) parser.add_argument("--exploration-noise", type=float, default=0.1) @@ -49,7 +49,7 @@ def get_args() -> argparse.Namespace: parser.add_argument("--norm-obs", type=int, default=1) parser.add_argument("--eval-freq", type=int, default=1) - parser.add_argument("--test-num", type=int, default=10) + parser.add_argument("--num_test_envs", 
type=int, default=10) parser.add_argument("--logdir", type=str, default="log") parser.add_argument("--render", type=float, default=1 / 35) parser.add_argument( @@ -206,7 +206,7 @@ def watch() -> None: buffer=replay_buffer, test_collector=test_collector, max_epochs=args.epoch, - epoch_num_steps=args.step_per_epoch, + epoch_num_steps=args.epoch_num_steps, test_step_num_episodes=args.test_num, batch_size=args.batch_size, save_best_fn=save_best_fn, diff --git a/examples/vizdoom/README.md b/examples/vizdoom/README.md index ca151f19b..3a46aaf86 100644 --- a/examples/vizdoom/README.md +++ b/examples/vizdoom/README.md @@ -39,13 +39,13 @@ D4 can reach 700+ reward. Here is the result: To evaluate an agent's performance: ```bash -python3 vizdoom_c51.py --test-num 100 --resume-path policy.pth --watch --task {D1_basic|D3_battle|D4_battle2} +python3 vizdoom_c51.py --num_test_envs 100 --resume-path policy.pth --watch --task {D1_basic|D3_battle|D4_battle2} ``` To save `.lmp` files for recording: ```bash -python3 vizdoom_c51.py --save-lmp --test-num 100 --resume-path policy.pth --watch --task {D1_basic|D3_battle|D4_battle2} +python3 vizdoom_c51.py --save-lmp --num_test_envs 100 --resume-path policy.pth --watch --task {D1_basic|D3_battle|D4_battle2} ``` it will store `lmp` file in `lmps/` directory. 
To watch these `lmp` files (for example, d3 lmp): diff --git a/examples/vizdoom/env.py b/examples/vizdoom/env.py index 2869acd1a..295193408 100644 --- a/examples/vizdoom/env.py +++ b/examples/vizdoom/env.py @@ -133,7 +133,7 @@ def make_vizdoom_env( res: tuple[int], save_lmp: bool = False, seed: int | None = None, - training_num: int = 10, + num_train_envs: int = 10, test_num: int = 10, ) -> tuple[Env, ShmemVectorEnv, ShmemVectorEnv]: cpu_count = os.cpu_count() @@ -154,7 +154,7 @@ def make_vizdoom_env( frame_skip=frame_skip, stack_num=res[0], seed=seed, - num_envs=training_num, + num_envs=num_train_envs, reward_config=reward_config, use_combined_action=True, max_episode_steps=2625, @@ -176,7 +176,7 @@ def make_vizdoom_env( cfg_path = f"maps/{task}.cfg" env = Env(cfg_path, frame_skip, res) train_envs = ShmemVectorEnv( - [lambda: Env(cfg_path, frame_skip, res) for _ in range(training_num)], + [lambda: Env(cfg_path, frame_skip, res) for _ in range(num_train_envs)], ) test_envs = ShmemVectorEnv( [lambda: Env(cfg_path, frame_skip, res, save_lmp) for _ in range(test_num)], diff --git a/examples/vizdoom/vizdoom_c51.py b/examples/vizdoom/vizdoom_c51.py index 676b467e0..6f225d3c3 100644 --- a/examples/vizdoom/vizdoom_c51.py +++ b/examples/vizdoom/vizdoom_c51.py @@ -34,12 +34,12 @@ def get_args() -> argparse.Namespace: parser.add_argument("--n-step", type=int, default=3) parser.add_argument("--target-update-freq", type=int, default=500) parser.add_argument("--epoch", type=int, default=300) - parser.add_argument("--step-per-epoch", type=int, default=100000) - parser.add_argument("--step-per-collect", type=int, default=10) + parser.add_argument("--epoch_num_steps", type=int, default=100000) + parser.add_argument("--collection_step_num_env_steps", type=int, default=10) parser.add_argument("--update-per-step", type=float, default=0.1) - parser.add_argument("--batch-size", type=int, default=64) - parser.add_argument("--training-num", type=int, default=10) - 
parser.add_argument("--test-num", type=int, default=10) + parser.add_argument("--batch_size", type=int, default=64) + parser.add_argument("--num_train_envs", type=int, default=10) + parser.add_argument("--num_test_envs", type=int, default=10) parser.add_argument("--logdir", type=str, default="log") parser.add_argument("--render", type=float, default=0.0) parser.add_argument( @@ -82,7 +82,7 @@ def test_c51(args: argparse.Namespace = get_args()) -> None: (args.frames_stack, 84, 84), args.save_lmp, args.seed, - args.training_num, + args.num_train_envs, args.test_num, ) args.state_shape = env.observation_space.shape @@ -202,15 +202,15 @@ def watch() -> None: # test train_collector and start filling replay buffer train_collector.reset() - train_collector.collect(n_step=args.batch_size * args.training_num) + train_collector.collect(n_step=args.batch_size * args.num_train_envs) # train result = algorithm.run_training( OffPolicyTrainerParams( train_collector=train_collector, test_collector=test_collector, max_epochs=args.epoch, - epoch_num_steps=args.step_per_epoch, - collection_step_num_env_steps=args.step_per_collect, + epoch_num_steps=args.epoch_num_steps, + collection_step_num_env_steps=args.collection_step_num_env_steps, test_step_num_episodes=args.test_num, batch_size=args.batch_size, train_fn=train_fn, diff --git a/examples/vizdoom/vizdoom_ppo.py b/examples/vizdoom/vizdoom_ppo.py index 26ac04a9f..2cb67d08b 100644 --- a/examples/vizdoom/vizdoom_ppo.py +++ b/examples/vizdoom/vizdoom_ppo.py @@ -33,14 +33,14 @@ def get_args() -> argparse.Namespace: parser.add_argument("--lr", type=float, default=0.00002) parser.add_argument("--gamma", type=float, default=0.99) parser.add_argument("--epoch", type=int, default=300) - parser.add_argument("--step-per-epoch", type=int, default=100000) - parser.add_argument("--step-per-collect", type=int, default=1000) - parser.add_argument("--repeat-per-collect", type=int, default=4) - parser.add_argument("--batch-size", type=int, 
default=256) + parser.add_argument("--epoch_num_steps", type=int, default=100000) + parser.add_argument("--collection_step_num_env_steps", type=int, default=1000) + parser.add_argument("--update_step_num_repetitions", type=int, default=4) + parser.add_argument("--batch_size", type=int, default=256) parser.add_argument("--hidden-size", type=int, default=512) - parser.add_argument("--training-num", type=int, default=10) - parser.add_argument("--test-num", type=int, default=10) - parser.add_argument("--rew-norm", type=int, default=False) + parser.add_argument("--num_train_envs", type=int, default=10) + parser.add_argument("--num_test_envs", type=int, default=10) + parser.add_argument("--return_scaling", type=int, default=False) parser.add_argument("--vf-coef", type=float, default=0.5) parser.add_argument("--ent-coef", type=float, default=0.01) parser.add_argument("--gae-lambda", type=float, default=0.95) @@ -49,7 +49,7 @@ def get_args() -> argparse.Namespace: parser.add_argument("--eps-clip", type=float, default=0.2) parser.add_argument("--dual-clip", type=float, default=None) parser.add_argument("--value-clip", type=int, default=0) - parser.add_argument("--norm-adv", type=int, default=1) + parser.add_argument("--advantage_normalization", type=int, default=1) parser.add_argument("--recompute-adv", type=int, default=0) parser.add_argument("--logdir", type=str, default="log") parser.add_argument("--render", type=float, default=0.0) @@ -111,7 +111,7 @@ def test_ppo(args: argparse.Namespace = get_args()) -> None: (args.frames_stack, 84, 84), args.save_lmp, args.seed, - args.training_num, + args.num_train_envs, args.test_num, ) args.state_shape = env.observation_space.shape @@ -140,8 +140,8 @@ def test_ppo(args: argparse.Namespace = get_args()) -> None: optim.with_lr_scheduler_factory( LRSchedulerFactoryLinear( max_epochs=args.epoch, - epoch_num_steps=args.step_per_epoch, - collection_step_num_env_steps=args.step_per_collect, + epoch_num_steps=args.epoch_num_steps, + 
collection_step_num_env_steps=args.collection_step_num_env_steps, ) ) @@ -165,11 +165,11 @@ def dist(logits: torch.Tensor) -> Categorical: max_grad_norm=args.max_grad_norm, vf_coef=args.vf_coef, ent_coef=args.ent_coef, - return_scaling=args.rew_norm, + return_scaling=args.return_scaling, eps_clip=args.eps_clip, value_clip=args.value_clip, dual_clip=args.dual_clip, - advantage_normalization=args.norm_adv, + advantage_normalization=args.advantage_normalization, recompute_advantage=args.recompute_adv, ).to(args.device) if args.icm_lr_scale > 0: @@ -276,7 +276,7 @@ def watch() -> None: # test train_collector and start filling replay buffer train_collector.reset() - train_collector.collect(n_step=args.batch_size * args.training_num) + train_collector.collect(n_step=args.batch_size * args.num_train_envs) # train result = algorithm.run_training( @@ -284,11 +284,11 @@ def watch() -> None: train_collector=train_collector, test_collector=test_collector, max_epochs=args.epoch, - epoch_num_steps=args.step_per_epoch, - update_step_num_repetitions=args.repeat_per_collect, + epoch_num_steps=args.epoch_num_steps, + update_step_num_repetitions=args.update_step_num_repetitions, test_step_num_episodes=args.test_num, batch_size=args.batch_size, - collection_step_num_env_steps=args.step_per_collect, + collection_step_num_env_steps=args.collection_step_num_env_steps, stop_fn=stop_fn, save_best_fn=save_best_fn, logger=logger, diff --git a/test/continuous/test_ddpg.py b/test/continuous/test_ddpg.py index 790bfea08..69675115d 100644 --- a/test/continuous/test_ddpg.py +++ b/test/continuous/test_ddpg.py @@ -33,16 +33,16 @@ def get_args() -> argparse.Namespace: parser.add_argument("--tau", type=float, default=0.005) parser.add_argument("--exploration-noise", type=float, default=0.1) parser.add_argument("--epoch", type=int, default=5) - parser.add_argument("--step-per-epoch", type=int, default=20000) - parser.add_argument("--step-per-collect", type=int, default=8) + 
parser.add_argument("--epoch_num_steps", type=int, default=20000) + parser.add_argument("--collection_step_num_env_steps", type=int, default=8) parser.add_argument("--update-per-step", type=float, default=0.125) - parser.add_argument("--batch-size", type=int, default=128) + parser.add_argument("--batch_size", type=int, default=128) parser.add_argument("--hidden-sizes", type=int, nargs="*", default=[128, 128]) - parser.add_argument("--training-num", type=int, default=8) - parser.add_argument("--test-num", type=int, default=100) + parser.add_argument("--num_train_envs", type=int, default=8) + parser.add_argument("--num_test_envs", type=int, default=100) parser.add_argument("--logdir", type=str, default="log") parser.add_argument("--render", type=float, default=0.0) - parser.add_argument("--rew-norm", action="store_true", default=False) + parser.add_argument("--return_scaling", action="store_true", default=False) parser.add_argument("--n-step", type=int, default=3) parser.add_argument( "--device", @@ -64,7 +64,7 @@ def test_ddpg(args: argparse.Namespace = get_args(), enable_assertions: bool = T args.task, env.spec.reward_threshold if env.spec else None, ) - train_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.training_num)]) + train_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.num_train_envs)]) test_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.test_num)]) # seed @@ -129,8 +129,8 @@ def stop_fn(mean_rewards: float) -> bool: train_collector=train_collector, test_collector=test_collector, max_epochs=args.epoch, - epoch_num_steps=args.step_per_epoch, - collection_step_num_env_steps=args.step_per_collect, + epoch_num_steps=args.epoch_num_steps, + collection_step_num_env_steps=args.collection_step_num_env_steps, test_step_num_episodes=args.test_num, batch_size=args.batch_size, update_step_num_gradient_steps_per_sample=args.update_per_step, diff --git a/test/continuous/test_npg.py 
b/test/continuous/test_npg.py index 8490b199a..e45acea50 100644 --- a/test/continuous/test_npg.py +++ b/test/continuous/test_npg.py @@ -31,13 +31,15 @@ def get_args() -> argparse.Namespace: parser.add_argument("--lr", type=float, default=1e-3) parser.add_argument("--gamma", type=float, default=0.95) parser.add_argument("--epoch", type=int, default=5) - parser.add_argument("--step-per-epoch", type=int, default=50000) - parser.add_argument("--step-per-collect", type=int, default=2048) - parser.add_argument("--repeat-per-collect", type=int, default=2) # theoretically it should be 1 - parser.add_argument("--batch-size", type=int, default=99999) + parser.add_argument("--epoch_num_steps", type=int, default=50000) + parser.add_argument("--collection_step_num_env_steps", type=int, default=2048) + parser.add_argument( + "--update_step_num_repetitions", type=int, default=2 + ) # theoretically it should be 1 + parser.add_argument("--batch_size", type=int, default=99999) parser.add_argument("--hidden-sizes", type=int, nargs="*", default=[64, 64]) - parser.add_argument("--training-num", type=int, default=16) - parser.add_argument("--test-num", type=int, default=10) + parser.add_argument("--num_train_envs", type=int, default=16) + parser.add_argument("--num_test_envs", type=int, default=10) parser.add_argument("--logdir", type=str, default="log") parser.add_argument("--render", type=float, default=0.0) parser.add_argument( @@ -47,10 +49,10 @@ def get_args() -> argparse.Namespace: ) # npg special parser.add_argument("--gae-lambda", type=float, default=0.95) - parser.add_argument("--rew-norm", type=int, default=1) - parser.add_argument("--norm-adv", type=int, default=1) + parser.add_argument("--return_scaling", type=int, default=1) + parser.add_argument("--advantage_normalization", type=int, default=1) parser.add_argument("--optim-critic-iters", type=int, default=5) - parser.add_argument("--actor-step-size", type=float, default=0.5) + parser.add_argument("--trust_region_size", 
type=float, default=0.5) return parser.parse_known_args()[0] @@ -68,7 +70,7 @@ def test_npg(args: argparse.Namespace = get_args(), enable_assertions: bool = Tr args.task, env.spec.reward_threshold if env.spec else None, ) - train_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.training_num)]) + train_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.num_train_envs)]) test_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.test_num)]) # seed @@ -117,11 +119,11 @@ def dist(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution: critic=critic, optim=AdamOptimizerFactory(lr=args.lr), gamma=args.gamma, - return_scaling=args.rew_norm, - advantage_normalization=args.norm_adv, + return_scaling=args.return_scaling, + advantage_normalization=args.advantage_normalization, gae_lambda=args.gae_lambda, optim_critic_iters=args.optim_critic_iters, - actor_step_size=args.actor_step_size, + trust_region_size=args.trust_region_size, ) # collector train_collector = Collector[CollectStats]( @@ -147,11 +149,11 @@ def stop_fn(mean_rewards: float) -> bool: train_collector=train_collector, test_collector=test_collector, max_epochs=args.epoch, - epoch_num_steps=args.step_per_epoch, - update_step_num_repetitions=args.repeat_per_collect, + epoch_num_steps=args.epoch_num_steps, + update_step_num_repetitions=args.update_step_num_repetitions, test_step_num_episodes=args.test_num, batch_size=args.batch_size, - collection_step_num_env_steps=args.step_per_collect, + collection_step_num_env_steps=args.collection_step_num_env_steps, stop_fn=stop_fn, save_best_fn=save_best_fn, logger=logger, diff --git a/test/continuous/test_ppo.py b/test/continuous/test_ppo.py index 73778d8db..b7d44ab54 100644 --- a/test/continuous/test_ppo.py +++ b/test/continuous/test_ppo.py @@ -30,13 +30,13 @@ def get_args() -> argparse.Namespace: parser.add_argument("--lr", type=float, default=1e-3) parser.add_argument("--gamma", type=float, default=0.95) 
parser.add_argument("--epoch", type=int, default=5) - parser.add_argument("--step-per-epoch", type=int, default=150000) - parser.add_argument("--episode-per-collect", type=int, default=16) - parser.add_argument("--repeat-per-collect", type=int, default=2) - parser.add_argument("--batch-size", type=int, default=128) + parser.add_argument("--epoch_num_steps", type=int, default=150000) + parser.add_argument("--collection_step_num_episodes", type=int, default=16) + parser.add_argument("--update_step_num_repetitions", type=int, default=2) + parser.add_argument("--batch_size", type=int, default=128) parser.add_argument("--hidden-sizes", type=int, nargs="*", default=[64, 64]) - parser.add_argument("--training-num", type=int, default=16) - parser.add_argument("--test-num", type=int, default=100) + parser.add_argument("--num_train_envs", type=int, default=16) + parser.add_argument("--num_test_envs", type=int, default=100) parser.add_argument("--logdir", type=str, default="log") parser.add_argument("--render", type=float, default=0.0) parser.add_argument( @@ -50,10 +50,10 @@ def get_args() -> argparse.Namespace: parser.add_argument("--eps-clip", type=float, default=0.2) parser.add_argument("--max-grad-norm", type=float, default=0.5) parser.add_argument("--gae-lambda", type=float, default=0.95) - parser.add_argument("--rew-norm", type=int, default=1) + parser.add_argument("--return_scaling", type=int, default=1) parser.add_argument("--dual-clip", type=float, default=None) parser.add_argument("--value-clip", type=int, default=1) - parser.add_argument("--norm-adv", type=int, default=1) + parser.add_argument("--advantage_normalization", type=int, default=1) parser.add_argument("--recompute-adv", type=int, default=0) parser.add_argument("--resume", action="store_true") parser.add_argument("--save-interval", type=int, default=4) @@ -74,8 +74,8 @@ def test_ppo(args: argparse.Namespace = get_args(), enable_assertions: bool = Tr args.task, env.spec.reward_threshold if env.spec else 
None, ) - train_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.training_num)]) - test_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.test_num)]) + train_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.num_train_envs)]) + test_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.num_test_envs)]) # seed np.random.seed(args.seed) @@ -119,8 +119,8 @@ def dist(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution: eps_clip=args.eps_clip, vf_coef=args.vf_coef, ent_coef=args.ent_coef, - return_scaling=args.rew_norm, - advantage_normalization=args.norm_adv, + return_scaling=args.return_scaling, + advantage_normalization=args.advantage_normalization, recompute_advantage=args.recompute_adv, dual_clip=args.dual_clip, value_clip=args.value_clip, @@ -172,11 +172,11 @@ def save_checkpoint_fn(epoch: int, env_step: int, gradient_step: int) -> str: train_collector=train_collector, test_collector=test_collector, max_epochs=args.epoch, - epoch_num_steps=args.step_per_epoch, - update_step_num_repetitions=args.repeat_per_collect, - test_step_num_episodes=args.test_num, + epoch_num_steps=args.epoch_num_steps, + update_step_num_repetitions=args.update_step_num_repetitions, + test_step_num_episodes=args.num_test_envs, batch_size=args.batch_size, - collection_step_num_episodes=args.episode_per_collect, + collection_step_num_episodes=args.collection_step_num_episodes, collection_step_num_env_steps=None, stop_fn=stop_fn, save_best_fn=save_best_fn, diff --git a/test/continuous/test_redq.py b/test/continuous/test_redq.py index c2d80c587..ad52dbea0 100644 --- a/test/continuous/test_redq.py +++ b/test/continuous/test_redq.py @@ -39,15 +39,15 @@ def get_args() -> argparse.Namespace: parser.add_argument("--alpha-lr", type=float, default=3e-4) parser.add_argument("--start-timesteps", type=int, default=1000) parser.add_argument("--epoch", type=int, default=5) - parser.add_argument("--step-per-epoch", 
type=int, default=5000) - parser.add_argument("--step-per-collect", type=int, default=1) + parser.add_argument("--epoch_num_steps", type=int, default=5000) + parser.add_argument("--collection_step_num_env_steps", type=int, default=1) parser.add_argument("--update-per-step", type=int, default=3) parser.add_argument("--n-step", type=int, default=1) - parser.add_argument("--batch-size", type=int, default=64) + parser.add_argument("--batch_size", type=int, default=64) parser.add_argument("--target-mode", type=str, choices=("min", "mean"), default="min") parser.add_argument("--hidden-sizes", type=int, nargs="*", default=[64, 64]) - parser.add_argument("--training-num", type=int, default=8) - parser.add_argument("--test-num", type=int, default=100) + parser.add_argument("--num_train_envs", type=int, default=8) + parser.add_argument("--num_test_envs", type=int, default=100) parser.add_argument("--logdir", type=str, default="log") parser.add_argument("--render", type=float, default=0.0) parser.add_argument( @@ -73,7 +73,7 @@ def test_redq(args: argparse.Namespace = get_args(), enable_assertions: bool = T ) # you can also use tianshou.env.SubprocVectorEnv # train_envs = gym.make(args.task) - train_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.training_num)]) + train_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.num_train_envs)]) # test_envs = gym.make(args.task) test_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.test_num)]) # seed @@ -158,8 +158,8 @@ def stop_fn(mean_rewards: float) -> bool: train_collector=train_collector, test_collector=test_collector, max_epochs=args.epoch, - epoch_num_steps=args.step_per_epoch, - collection_step_num_env_steps=args.step_per_collect, + epoch_num_steps=args.epoch_num_steps, + collection_step_num_env_steps=args.collection_step_num_env_steps, test_step_num_episodes=args.test_num, batch_size=args.batch_size, 
update_step_num_gradient_steps_per_sample=args.update_per_step, diff --git a/test/continuous/test_sac_with_il.py b/test/continuous/test_sac_with_il.py index fa7fff9c2..b06a47107 100644 --- a/test/continuous/test_sac_with_il.py +++ b/test/continuous/test_sac_with_il.py @@ -45,15 +45,15 @@ def get_args() -> argparse.Namespace: parser.add_argument("--auto-alpha", type=int, default=1) parser.add_argument("--alpha-lr", type=float, default=3e-4) parser.add_argument("--epoch", type=int, default=10) - parser.add_argument("--step-per-epoch", type=int, default=24000) + parser.add_argument("--epoch_num_steps", type=int, default=24000) parser.add_argument("--il-step-per-epoch", type=int, default=500) - parser.add_argument("--step-per-collect", type=int, default=10) + parser.add_argument("--collection_step_num_env_steps", type=int, default=10) parser.add_argument("--update-per-step", type=float, default=0.1) - parser.add_argument("--batch-size", type=int, default=128) + parser.add_argument("--batch_size", type=int, default=128) parser.add_argument("--hidden-sizes", type=int, nargs="*", default=[128, 128]) parser.add_argument("--imitation-hidden-sizes", type=int, nargs="*", default=[128, 128]) - parser.add_argument("--training-num", type=int, default=10) - parser.add_argument("--test-num", type=int, default=100) + parser.add_argument("--num_train_envs", type=int, default=10) + parser.add_argument("--num_test_envs", type=int, default=100) parser.add_argument("--logdir", type=str, default="log") parser.add_argument("--render", type=float, default=0.0) parser.add_argument("--n-step", type=int, default=3) @@ -71,10 +71,10 @@ def test_sac_with_il( skip_il: bool = False, ) -> None: # if you want to use python vector env, please refer to other test scripts - # train_envs = env = envpool.make_gymnasium(args.task, num_envs=args.training_num, seed=args.seed) + # train_envs = env = envpool.make_gymnasium(args.task, num_envs=args.num_train_envs, seed=args.seed) # test_envs = 
envpool.make_gymnasium(args.task, num_envs=args.test_num, seed=args.seed) env = gym.make(args.task) - train_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.training_num)]) + train_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.num_train_envs)]) test_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.test_num)]) space_info = SpaceInfo.from_env(env) args.state_shape = space_info.observation_info.obs_shape @@ -91,7 +91,7 @@ def test_sac_with_il( np.random.seed(args.seed) torch.manual_seed(args.seed) train_envs.seed(args.seed) - test_envs.seed(args.seed + args.training_num) + test_envs.seed(args.seed + args.num_train_envs) # model net = MLPActor(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes) @@ -163,8 +163,8 @@ def stop_fn(mean_rewards: float) -> bool: train_collector=train_collector, test_collector=test_collector, max_epochs=args.epoch, - epoch_num_steps=args.step_per_epoch, - collection_step_num_env_steps=args.step_per_collect, + epoch_num_steps=args.epoch_num_steps, + collection_step_num_env_steps=args.collection_step_num_env_steps, test_step_num_episodes=args.test_num, batch_size=args.batch_size, update_step_num_gradient_steps_per_sample=args.update_per_step, @@ -205,7 +205,7 @@ def stop_fn(mean_rewards: float) -> bool: optim=optim, ) il_test_env = gym.make(args.task) - il_test_env.reset(seed=args.seed + args.training_num + args.test_num) + il_test_env.reset(seed=args.seed + args.num_train_envs + args.test_num) il_test_collector = Collector[CollectStats]( il_algorithm, # envpool.make_gymnasium(args.task, num_envs=args.test_num, seed=args.seed), @@ -217,8 +217,8 @@ def stop_fn(mean_rewards: float) -> bool: train_collector=train_collector, test_collector=il_test_collector, max_epochs=args.epoch, - epoch_num_steps=args.il_step_per_epoch, - collection_step_num_env_steps=args.step_per_collect, + epoch_num_steps=args.il_epoch_num_steps, + 
collection_step_num_env_steps=args.collection_step_num_env_steps, test_step_num_episodes=args.test_num, batch_size=args.batch_size, stop_fn=stop_fn, diff --git a/test/continuous/test_td3.py b/test/continuous/test_td3.py index b5497dcbd..59f1f8aaa 100644 --- a/test/continuous/test_td3.py +++ b/test/continuous/test_td3.py @@ -36,13 +36,13 @@ def get_args() -> argparse.Namespace: parser.add_argument("--noise-clip", type=float, default=0.5) parser.add_argument("--update-actor-freq", type=int, default=2) parser.add_argument("--epoch", type=int, default=5) - parser.add_argument("--step-per-epoch", type=int, default=20000) - parser.add_argument("--step-per-collect", type=int, default=8) + parser.add_argument("--epoch_num_steps", type=int, default=20000) + parser.add_argument("--collection_step_num_env_steps", type=int, default=8) parser.add_argument("--update-per-step", type=float, default=0.125) - parser.add_argument("--batch-size", type=int, default=128) + parser.add_argument("--batch_size", type=int, default=128) parser.add_argument("--hidden-sizes", type=int, nargs="*", default=[128, 128]) - parser.add_argument("--training-num", type=int, default=8) - parser.add_argument("--test-num", type=int, default=100) + parser.add_argument("--num_train_envs", type=int, default=8) + parser.add_argument("--num_test_envs", type=int, default=100) parser.add_argument("--logdir", type=str, default="log") parser.add_argument("--render", type=float, default=0.0) parser.add_argument("--n-step", type=int, default=3) @@ -68,7 +68,7 @@ def test_td3(args: argparse.Namespace = get_args(), enable_assertions: bool = Tr ) # you can also use tianshou.env.SubprocVectorEnv # train_envs = gym.make(args.task) - train_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.training_num)]) + train_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.num_train_envs)]) # test_envs = gym.make(args.task) test_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in 
range(args.test_num)]) # seed @@ -145,8 +145,8 @@ def stop_fn(mean_rewards: float) -> bool: train_collector=train_collector, test_collector=test_collector, max_epochs=args.epoch, - epoch_num_steps=args.step_per_epoch, - collection_step_num_env_steps=args.step_per_collect, + epoch_num_steps=args.epoch_num_steps, + collection_step_num_env_steps=args.collection_step_num_env_steps, test_step_num_episodes=args.test_num, batch_size=args.batch_size, update_step_num_gradient_steps_per_sample=args.update_per_step, diff --git a/test/continuous/test_trpo.py b/test/continuous/test_trpo.py index f6b90bc62..73a29c68c 100644 --- a/test/continuous/test_trpo.py +++ b/test/continuous/test_trpo.py @@ -31,13 +31,15 @@ def get_args() -> argparse.Namespace: parser.add_argument("--lr", type=float, default=1e-3) parser.add_argument("--gamma", type=float, default=0.95) parser.add_argument("--epoch", type=int, default=5) - parser.add_argument("--step-per-epoch", type=int, default=50000) - parser.add_argument("--step-per-collect", type=int, default=2048) - parser.add_argument("--repeat-per-collect", type=int, default=2) # theoretically it should be 1 - parser.add_argument("--batch-size", type=int, default=99999) + parser.add_argument("--epoch_num_steps", type=int, default=50000) + parser.add_argument("--collection_step_num_env_steps", type=int, default=2048) + parser.add_argument( + "--update_step_num_repetitions", type=int, default=2 + ) # theoretically it should be 1 + parser.add_argument("--batch_size", type=int, default=99999) parser.add_argument("--hidden-sizes", type=int, nargs="*", default=[64, 64]) - parser.add_argument("--training-num", type=int, default=16) - parser.add_argument("--test-num", type=int, default=10) + parser.add_argument("--num_train_envs", type=int, default=16) + parser.add_argument("--num_test_envs", type=int, default=10) parser.add_argument("--logdir", type=str, default="log") parser.add_argument("--render", type=float, default=0.0) parser.add_argument( @@ -47,8 
+49,8 @@ def get_args() -> argparse.Namespace: ) # trpo special parser.add_argument("--gae-lambda", type=float, default=0.95) - parser.add_argument("--rew-norm", type=int, default=1) - parser.add_argument("--norm-adv", type=int, default=1) + parser.add_argument("--return_scaling", type=int, default=1) + parser.add_argument("--advantage_normalization", type=int, default=1) parser.add_argument("--optim-critic-iters", type=int, default=5) parser.add_argument("--max-kl", type=float, default=0.005) parser.add_argument("--backtrack-coeff", type=float, default=0.8) @@ -70,7 +72,7 @@ def test_trpo(args: argparse.Namespace = get_args(), enable_assertions: bool = T ) # you can also use tianshou.env.SubprocVectorEnv # train_envs = gym.make(args.task) - train_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.training_num)]) + train_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.num_train_envs)]) # test_envs = gym.make(args.task) test_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.test_num)]) # seed @@ -117,8 +119,8 @@ def dist(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution: critic=critic, optim=optim, gamma=args.gamma, - return_scaling=args.rew_norm, - advantage_normalization=args.norm_adv, + return_scaling=args.return_scaling, + advantage_normalization=args.advantage_normalization, gae_lambda=args.gae_lambda, optim_critic_iters=args.optim_critic_iters, max_kl=args.max_kl, @@ -149,11 +151,11 @@ def stop_fn(mean_rewards: float) -> bool: train_collector=train_collector, test_collector=test_collector, max_epochs=args.epoch, - epoch_num_steps=args.step_per_epoch, - update_step_num_repetitions=args.repeat_per_collect, + epoch_num_steps=args.epoch_num_steps, + update_step_num_repetitions=args.update_step_num_repetitions, test_step_num_episodes=args.test_num, batch_size=args.batch_size, - collection_step_num_env_steps=args.step_per_collect, + 
collection_step_num_env_steps=args.collection_step_num_env_steps, stop_fn=stop_fn, save_best_fn=save_best_fn, logger=logger, diff --git a/test/determinism_test.py b/test/determinism_test.py index 7d6a9f1b1..f71bcb6e5 100644 --- a/test/determinism_test.py +++ b/test/determinism_test.py @@ -71,7 +71,7 @@ def __init__( :param args: the arguments to be passed to the main function (some of which are overridden for the test) :param is_offline: whether the algorithm being tested is an offline algorithm and therefore - does not configure the number of training environments (`training_num`) + does not configure the number of training environments (`num_train_envs`) :param ignored_messages: message fragments to ignore in the trace log (if any) """ self.determinism_test = TraceDeterminismTest( @@ -89,10 +89,10 @@ def set(attr: str, value: Any) -> None: setattr(args, attr, value) set("epoch", 3) - set("step_per_epoch", 100) + set("epoch_num_steps", 100) set("device", "cpu") if not is_offline: - set("training_num", 1) + set("num_train_envs", 1) set("test_num", 1) self.args = args diff --git a/test/discrete/test_a2c_with_il.py b/test/discrete/test_a2c_with_il.py index 86c9da16d..95f25e0ac 100644 --- a/test/discrete/test_a2c_with_il.py +++ b/test/discrete/test_a2c_with_il.py @@ -35,17 +35,17 @@ def get_args() -> argparse.Namespace: parser.add_argument("--il-lr", type=float, default=1e-3) parser.add_argument("--gamma", type=float, default=0.9) parser.add_argument("--epoch", type=int, default=10) - parser.add_argument("--step-per-epoch", type=int, default=50000) + parser.add_argument("--epoch_num_steps", type=int, default=50000) parser.add_argument("--il-step-per-epoch", type=int, default=1000) - parser.add_argument("--episode-per-collect", type=int, default=16) - parser.add_argument("--step-per-collect", type=int, default=16) + parser.add_argument("--collection_step_num_episodes", type=int, default=16) + parser.add_argument("--collection_step_num_env_steps", type=int, default=16) 
parser.add_argument("--update-per-step", type=float, default=1 / 16) - parser.add_argument("--repeat-per-collect", type=int, default=1) - parser.add_argument("--batch-size", type=int, default=64) + parser.add_argument("--update_step_num_repetitions", type=int, default=1) + parser.add_argument("--batch_size", type=int, default=64) parser.add_argument("--hidden-sizes", type=int, nargs="*", default=[64, 64]) parser.add_argument("--imitation-hidden-sizes", type=int, nargs="*", default=[128]) - parser.add_argument("--training-num", type=int, default=16) - parser.add_argument("--test-num", type=int, default=100) + parser.add_argument("--num_train_envs", type=int, default=16) + parser.add_argument("--num_test_envs", type=int, default=100) parser.add_argument("--logdir", type=str, default="log") parser.add_argument("--render", type=float, default=0.0) parser.add_argument( @@ -58,7 +58,7 @@ def get_args() -> argparse.Namespace: parser.add_argument("--ent-coef", type=float, default=0.0) parser.add_argument("--max-grad-norm", type=float, default=None) parser.add_argument("--gae-lambda", type=float, default=1.0) - parser.add_argument("--rew-norm", action="store_true", default=False) + parser.add_argument("--return_scaling", action="store_true", default=False) return parser.parse_known_args()[0] @@ -75,7 +75,7 @@ def test_a2c_with_il( train_envs = env = envpool.make( args.task, env_type="gymnasium", - num_envs=args.training_num, + num_envs=args.num_train_envs, seed=args.seed, ) test_envs = envpool.make( @@ -86,7 +86,9 @@ def test_a2c_with_il( ) else: env = gym.make(args.task) - train_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.training_num)]) + train_envs = DummyVectorEnv( + [lambda: gym.make(args.task) for _ in range(args.num_train_envs)] + ) test_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.test_num)]) train_envs.seed(args.seed) test_envs.seed(args.seed) @@ -116,7 +118,7 @@ def test_a2c_with_il( vf_coef=args.vf_coef, 
ent_coef=args.ent_coef, max_grad_norm=args.max_grad_norm, - return_scaling=args.rew_norm, + return_scaling=args.return_scaling, ) # collector train_collector = Collector[CollectStats]( @@ -144,11 +146,11 @@ def stop_fn(mean_rewards: float) -> bool: train_collector=train_collector, test_collector=test_collector, max_epochs=args.epoch, - epoch_num_steps=args.step_per_epoch, - update_step_num_repetitions=args.repeat_per_collect, + epoch_num_steps=args.epoch_num_steps, + update_step_num_repetitions=args.update_step_num_repetitions, test_step_num_episodes=args.test_num, batch_size=args.batch_size, - collection_step_num_episodes=args.episode_per_collect, + collection_step_num_episodes=args.collection_step_num_episodes, collection_step_num_env_steps=None, stop_fn=stop_fn, save_best_fn=save_best_fn, @@ -200,8 +202,8 @@ def stop_fn(mean_rewards: float) -> bool: train_collector=train_collector, test_collector=il_test_collector, max_epochs=args.epoch, - epoch_num_steps=args.il_step_per_epoch, - collection_step_num_env_steps=args.step_per_collect, + epoch_num_steps=args.il_epoch_num_steps, + collection_step_num_env_steps=args.collection_step_num_env_steps, test_step_num_episodes=args.test_num, batch_size=args.batch_size, stop_fn=stop_fn, diff --git a/test/discrete/test_bdqn.py b/test/discrete/test_bdqn.py index 4e568551d..6716265a0 100644 --- a/test/discrete/test_bdqn.py +++ b/test/discrete/test_bdqn.py @@ -35,12 +35,12 @@ def get_args() -> argparse.Namespace: parser.add_argument("--gamma", type=float, default=0.9) parser.add_argument("--target-update-freq", type=int, default=200) parser.add_argument("--epoch", type=int, default=10) - parser.add_argument("--step-per-epoch", type=int, default=80000) - parser.add_argument("--step-per-collect", type=int, default=10) + parser.add_argument("--epoch_num_steps", type=int, default=80000) + parser.add_argument("--collection_step_num_env_steps", type=int, default=10) parser.add_argument("--update-per-step", type=float, default=0.1) - 
parser.add_argument("--batch-size", type=int, default=128) - parser.add_argument("--training-num", type=int, default=10) - parser.add_argument("--test-num", type=int, default=10) + parser.add_argument("--batch_size", type=int, default=128) + parser.add_argument("--num_train_envs", type=int, default=10) + parser.add_argument("--num_test_envs", type=int, default=10) parser.add_argument("--logdir", type=str, default="log") parser.add_argument("--render", type=float, default=0.0) parser.add_argument( @@ -76,7 +76,7 @@ def test_bdq(args: argparse.Namespace = get_args(), enable_assertions: bool = Tr train_envs = DummyVectorEnv( [ lambda: ContinuousToDiscrete(gym.make(args.task), args.action_per_branch) - for _ in range(args.training_num) + for _ in range(args.num_train_envs) ], ) test_envs = DummyVectorEnv( @@ -117,7 +117,7 @@ def test_bdq(args: argparse.Namespace = get_args(), enable_assertions: bool = Tr train_collector = Collector[CollectStats]( algorithm, train_envs, - VectorReplayBuffer(args.buffer_size, args.training_num), + VectorReplayBuffer(args.buffer_size, args.num_train_envs), exploration_noise=True, ) test_collector = Collector[CollectStats](algorithm, test_envs, exploration_noise=False) @@ -125,7 +125,7 @@ def test_bdq(args: argparse.Namespace = get_args(), enable_assertions: bool = Tr # initial data collection with policy_within_training_step(policy): train_collector.reset() - train_collector.collect(n_step=args.batch_size * args.training_num) + train_collector.collect(n_step=args.batch_size * args.num_train_envs) def train_fn(epoch: int, env_step: int) -> None: # exp decay eps = max(args.eps_train * (1 - args.eps_decay) ** env_step, args.eps_test) @@ -140,8 +140,8 @@ def stop_fn(mean_rewards: float) -> bool: train_collector=train_collector, test_collector=test_collector, max_epochs=args.epoch, - epoch_num_steps=args.step_per_epoch, - collection_step_num_env_steps=args.step_per_collect, + epoch_num_steps=args.epoch_num_steps, + 
collection_step_num_env_steps=args.collection_step_num_env_steps, test_step_num_episodes=args.test_num, batch_size=args.batch_size, update_step_num_gradient_steps_per_sample=args.update_per_step, diff --git a/test/discrete/test_c51.py b/test/discrete/test_c51.py index 291736db8..24bc27a74 100644 --- a/test/discrete/test_c51.py +++ b/test/discrete/test_c51.py @@ -43,13 +43,13 @@ def get_args() -> argparse.Namespace: parser.add_argument("--n-step", type=int, default=3) parser.add_argument("--target-update-freq", type=int, default=320) parser.add_argument("--epoch", type=int, default=10) - parser.add_argument("--step-per-epoch", type=int, default=8000) - parser.add_argument("--step-per-collect", type=int, default=8) + parser.add_argument("--epoch_num_steps", type=int, default=8000) + parser.add_argument("--collection_step_num_env_steps", type=int, default=8) parser.add_argument("--update-per-step", type=float, default=0.125) - parser.add_argument("--batch-size", type=int, default=64) + parser.add_argument("--batch_size", type=int, default=64) parser.add_argument("--hidden-sizes", type=int, nargs="*", default=[128, 128, 128, 128]) - parser.add_argument("--training-num", type=int, default=8) - parser.add_argument("--test-num", type=int, default=100) + parser.add_argument("--num_train_envs", type=int, default=8) + parser.add_argument("--num_test_envs", type=int, default=100) parser.add_argument("--logdir", type=str, default="log") parser.add_argument("--render", type=float, default=0.0) parser.add_argument("--prioritized-replay", action="store_true", default=False) @@ -77,7 +77,7 @@ def test_c51(args: argparse.Namespace = get_args(), enable_assertions: bool = Tr args.task, env.spec.reward_threshold if env.spec else None, ) - train_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.training_num)]) + train_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.num_train_envs)]) test_envs = DummyVectorEnv([lambda: gym.make(args.task) for 
_ in range(args.test_num)]) # seed @@ -132,7 +132,7 @@ def test_c51(args: argparse.Namespace = get_args(), enable_assertions: bool = Tr # initial data collection with policy_within_training_step(policy): train_collector.reset() - train_collector.collect(n_step=args.batch_size * args.training_num) + train_collector.collect(n_step=args.batch_size * args.num_train_envs) # logger log_path = os.path.join(args.logdir, args.task, "c51") @@ -193,8 +193,8 @@ def save_checkpoint_fn(epoch: int, env_step: int, gradient_step: int) -> str: train_collector=train_collector, test_collector=test_collector, max_epochs=args.epoch, - epoch_num_steps=args.step_per_epoch, - collection_step_num_env_steps=args.step_per_collect, + epoch_num_steps=args.epoch_num_steps, + collection_step_num_env_steps=args.collection_step_num_env_steps, test_step_num_episodes=args.test_num, batch_size=args.batch_size, update_step_num_gradient_steps_per_sample=args.update_per_step, diff --git a/test/discrete/test_discrete_sac.py b/test/discrete/test_discrete_sac.py index e7690ca7b..f94405d05 100644 --- a/test/discrete/test_discrete_sac.py +++ b/test/discrete/test_discrete_sac.py @@ -37,13 +37,13 @@ def get_args() -> argparse.Namespace: parser.add_argument("--alpha", type=float, default=0.05) parser.add_argument("--auto-alpha", action="store_true", default=False) parser.add_argument("--epoch", type=int, default=5) - parser.add_argument("--step-per-epoch", type=int, default=10000) - parser.add_argument("--step-per-collect", type=int, default=10) + parser.add_argument("--epoch_num_steps", type=int, default=10000) + parser.add_argument("--collection_step_num_env_steps", type=int, default=10) parser.add_argument("--update-per-step", type=float, default=0.1) - parser.add_argument("--batch-size", type=int, default=64) + parser.add_argument("--batch_size", type=int, default=64) parser.add_argument("--hidden-sizes", type=int, nargs="*", default=[64, 64]) - parser.add_argument("--training-num", type=int, default=10) - 
parser.add_argument("--test-num", type=int, default=100) + parser.add_argument("--num_train_envs", type=int, default=10) + parser.add_argument("--num_test_envs", type=int, default=100) parser.add_argument("--logdir", type=str, default="log") parser.add_argument("--render", type=float, default=0.0) parser.add_argument("--n-step", type=int, default=3) @@ -73,7 +73,7 @@ def test_discrete_sac( env.spec.reward_threshold if env.spec else None, ) - train_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.training_num)]) + train_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.num_train_envs)]) test_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.test_num)]) # seed np.random.seed(args.seed) @@ -143,8 +143,8 @@ def stop_fn(mean_rewards: float) -> bool: train_collector=train_collector, test_collector=test_collector, max_epochs=args.epoch, - epoch_num_steps=args.step_per_epoch, - collection_step_num_env_steps=args.step_per_collect, + epoch_num_steps=args.epoch_num_steps, + collection_step_num_env_steps=args.collection_step_num_env_steps, test_step_num_episodes=args.test_num, batch_size=args.batch_size, stop_fn=stop_fn, diff --git a/test/discrete/test_dqn.py b/test/discrete/test_dqn.py index 9b53f9668..de1fac1a1 100644 --- a/test/discrete/test_dqn.py +++ b/test/discrete/test_dqn.py @@ -39,13 +39,13 @@ def get_args() -> argparse.Namespace: parser.add_argument("--n-step", type=int, default=3) parser.add_argument("--target-update-freq", type=int, default=320) parser.add_argument("--epoch", type=int, default=20) - parser.add_argument("--step-per-epoch", type=int, default=10000) - parser.add_argument("--step-per-collect", type=int, default=10) + parser.add_argument("--epoch_num_steps", type=int, default=10000) + parser.add_argument("--collection_step_num_env_steps", type=int, default=10) parser.add_argument("--update-per-step", type=float, default=0.1) - parser.add_argument("--batch-size", type=int, default=64) + 
parser.add_argument("--batch_size", type=int, default=64) parser.add_argument("--hidden-sizes", type=int, nargs="*", default=[128, 128, 128, 128]) - parser.add_argument("--training-num", type=int, default=10) - parser.add_argument("--test-num", type=int, default=100) + parser.add_argument("--num_train_envs", type=int, default=10) + parser.add_argument("--num_test_envs", type=int, default=100) parser.add_argument("--logdir", type=str, default="log") parser.add_argument("--render", type=float, default=0.0) parser.add_argument("--prioritized-replay", action="store_true", default=False) @@ -71,7 +71,7 @@ def test_dqn(args: argparse.Namespace = get_args(), enable_assertions: bool = Tr args.task, env.spec.reward_threshold if env.spec else None, ) - train_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.training_num)]) + train_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.num_train_envs)]) test_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.test_num)]) # seed @@ -123,7 +123,7 @@ def test_dqn(args: argparse.Namespace = get_args(), enable_assertions: bool = Tr # initial data collection with policy_within_training_step(policy): train_collector.reset() - train_collector.collect(n_step=args.batch_size * args.training_num) + train_collector.collect(n_step=args.batch_size * args.num_train_envs) # logger log_path = os.path.join(args.logdir, args.task, "dqn") @@ -152,8 +152,8 @@ def train_fn(epoch: int, env_step: int) -> None: train_collector=train_collector, test_collector=test_collector, max_epochs=args.epoch, - epoch_num_steps=args.step_per_epoch, - collection_step_num_env_steps=args.step_per_collect, + epoch_num_steps=args.epoch_num_steps, + collection_step_num_env_steps=args.collection_step_num_env_steps, test_step_num_episodes=args.test_num, batch_size=args.batch_size, update_step_num_gradient_steps_per_sample=args.update_per_step, diff --git a/test/discrete/test_drqn.py b/test/discrete/test_drqn.py index 
77712e6c9..1c2bbb021 100644 --- a/test/discrete/test_drqn.py +++ b/test/discrete/test_drqn.py @@ -34,13 +34,13 @@ def get_args() -> argparse.Namespace: parser.add_argument("--n-step", type=int, default=3) parser.add_argument("--target-update-freq", type=int, default=320) parser.add_argument("--epoch", type=int, default=5) - parser.add_argument("--step-per-epoch", type=int, default=20000) + parser.add_argument("--epoch_num_steps", type=int, default=20000) parser.add_argument("--update-per-step", type=float, default=1 / 16) - parser.add_argument("--step-per-collect", type=int, default=16) - parser.add_argument("--batch-size", type=int, default=128) + parser.add_argument("--collection_step_num_env_steps", type=int, default=16) + parser.add_argument("--batch_size", type=int, default=128) parser.add_argument("--layer-num", type=int, default=2) - parser.add_argument("--training-num", type=int, default=16) - parser.add_argument("--test-num", type=int, default=100) + parser.add_argument("--num_train_envs", type=int, default=16) + parser.add_argument("--num_test_envs", type=int, default=100) parser.add_argument("--logdir", type=str, default="log") parser.add_argument("--render", type=float, default=0.0) parser.add_argument( @@ -65,7 +65,7 @@ def test_drqn(args: argparse.Namespace = get_args(), enable_assertions: bool = T ) # train_envs = gym.make(args.task) # you can also use tianshou.env.SubprocVectorEnv - train_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.training_num)]) + train_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.num_train_envs)]) # test_envs = gym.make(args.task) test_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.test_num)]) # seed @@ -108,7 +108,7 @@ def test_drqn(args: argparse.Namespace = get_args(), enable_assertions: bool = T # initial data collection with policy_within_training_step(policy): train_collector.reset() - train_collector.collect(n_step=args.batch_size * 
args.training_num) + train_collector.collect(n_step=args.batch_size * args.num_train_envs) # log log_path = os.path.join(args.logdir, args.task, "drqn") @@ -127,8 +127,8 @@ def stop_fn(mean_rewards: float) -> bool: train_collector=train_collector, test_collector=test_collector, max_epochs=args.epoch, - epoch_num_steps=args.step_per_epoch, - collection_step_num_env_steps=args.step_per_collect, + epoch_num_steps=args.epoch_num_steps, + collection_step_num_env_steps=args.collection_step_num_env_steps, test_step_num_episodes=args.test_num, batch_size=args.batch_size, update_step_num_gradient_steps_per_sample=args.update_per_step, diff --git a/test/discrete/test_fqf.py b/test/discrete/test_fqf.py index 1d352e0aa..cb4f9b580 100644 --- a/test/discrete/test_fqf.py +++ b/test/discrete/test_fqf.py @@ -44,13 +44,13 @@ def get_args() -> argparse.Namespace: parser.add_argument("--n-step", type=int, default=3) parser.add_argument("--target-update-freq", type=int, default=320) parser.add_argument("--epoch", type=int, default=10) - parser.add_argument("--step-per-epoch", type=int, default=10000) - parser.add_argument("--step-per-collect", type=int, default=10) + parser.add_argument("--epoch_num_steps", type=int, default=10000) + parser.add_argument("--collection_step_num_env_steps", type=int, default=10) parser.add_argument("--update-per-step", type=float, default=0.1) - parser.add_argument("--batch-size", type=int, default=64) + parser.add_argument("--batch_size", type=int, default=64) parser.add_argument("--hidden-sizes", type=int, nargs="*", default=[64, 64, 64]) - parser.add_argument("--training-num", type=int, default=10) - parser.add_argument("--test-num", type=int, default=100) + parser.add_argument("--num_train_envs", type=int, default=10) + parser.add_argument("--num_test_envs", type=int, default=100) parser.add_argument("--logdir", type=str, default="log") parser.add_argument("--render", type=float, default=0.0) parser.add_argument("--prioritized-replay", 
action="store_true", default=False) @@ -76,7 +76,7 @@ def test_fqf(args: argparse.Namespace = get_args(), enable_assertions: bool = Tr args.task, env.spec.reward_threshold if env.spec else None, ) - train_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.training_num)]) + train_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.num_train_envs)]) test_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.test_num)]) # seed @@ -138,7 +138,7 @@ def test_fqf(args: argparse.Namespace = get_args(), enable_assertions: bool = Tr # initial data collection with policy_within_training_step(policy): train_collector.reset() - train_collector.collect(n_step=args.batch_size * args.training_num) + train_collector.collect(n_step=args.batch_size * args.num_train_envs) # logger log_path = os.path.join(args.logdir, args.task, "fqf") @@ -167,8 +167,8 @@ def train_fn(epoch: int, env_step: int) -> None: train_collector=train_collector, test_collector=test_collector, max_epochs=args.epoch, - epoch_num_steps=args.step_per_epoch, - collection_step_num_env_steps=args.step_per_collect, + epoch_num_steps=args.epoch_num_steps, + collection_step_num_env_steps=args.collection_step_num_env_steps, test_step_num_episodes=args.test_num, batch_size=args.batch_size, train_fn=train_fn, diff --git a/test/discrete/test_iqn.py b/test/discrete/test_iqn.py index de8689a8c..035d01fea 100644 --- a/test/discrete/test_iqn.py +++ b/test/discrete/test_iqn.py @@ -44,13 +44,13 @@ def get_args() -> argparse.Namespace: parser.add_argument("--n-step", type=int, default=3) parser.add_argument("--target-update-freq", type=int, default=320) parser.add_argument("--epoch", type=int, default=10) - parser.add_argument("--step-per-epoch", type=int, default=10000) - parser.add_argument("--step-per-collect", type=int, default=10) + parser.add_argument("--epoch_num_steps", type=int, default=10000) + parser.add_argument("--collection_step_num_env_steps", type=int, 
default=10) parser.add_argument("--update-per-step", type=float, default=0.1) - parser.add_argument("--batch-size", type=int, default=64) + parser.add_argument("--batch_size", type=int, default=64) parser.add_argument("--hidden-sizes", type=int, nargs="*", default=[64, 64, 64]) - parser.add_argument("--training-num", type=int, default=10) - parser.add_argument("--test-num", type=int, default=100) + parser.add_argument("--num_train_envs", type=int, default=10) + parser.add_argument("--num_test_envs", type=int, default=100) parser.add_argument("--logdir", type=str, default="log") parser.add_argument("--render", type=float, default=0.0) parser.add_argument("--prioritized-replay", action="store_true", default=False) @@ -76,7 +76,7 @@ def test_iqn(args: argparse.Namespace = get_args(), enable_assertions: bool = Tr args.task, env.spec.reward_threshold if env.spec else None, ) - train_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.training_num)]) + train_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.num_train_envs)]) test_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.test_num)]) # seed @@ -134,7 +134,7 @@ def test_iqn(args: argparse.Namespace = get_args(), enable_assertions: bool = Tr # initial data collection with policy_within_training_step(policy): train_collector.reset() - train_collector.collect(n_step=args.batch_size * args.training_num) + train_collector.collect(n_step=args.batch_size * args.num_train_envs) # logger log_path = os.path.join(args.logdir, args.task, "iqn") @@ -163,8 +163,8 @@ def train_fn(epoch: int, env_step: int) -> None: train_collector=train_collector, test_collector=test_collector, max_epochs=args.epoch, - epoch_num_steps=args.step_per_epoch, - collection_step_num_env_steps=args.step_per_collect, + epoch_num_steps=args.epoch_num_steps, + collection_step_num_env_steps=args.collection_step_num_env_steps, test_step_num_episodes=args.test_num, batch_size=args.batch_size, 
train_fn=train_fn, diff --git a/test/discrete/test_ppo_discrete.py b/test/discrete/test_ppo_discrete.py index b2d09bb96..d18cb3d89 100644 --- a/test/discrete/test_ppo_discrete.py +++ b/test/discrete/test_ppo_discrete.py @@ -35,13 +35,13 @@ def get_args() -> argparse.Namespace: parser.add_argument("--lr", type=float, default=3e-4) parser.add_argument("--gamma", type=float, default=0.99) parser.add_argument("--epoch", type=int, default=10) - parser.add_argument("--step-per-epoch", type=int, default=50000) - parser.add_argument("--step-per-collect", type=int, default=2000) - parser.add_argument("--repeat-per-collect", type=int, default=10) - parser.add_argument("--batch-size", type=int, default=64) + parser.add_argument("--epoch_num_steps", type=int, default=50000) + parser.add_argument("--collection_step_num_env_steps", type=int, default=2000) + parser.add_argument("--update_step_num_repetitions", type=int, default=10) + parser.add_argument("--batch_size", type=int, default=64) parser.add_argument("--hidden-sizes", type=int, nargs="*", default=[64, 64]) - parser.add_argument("--training-num", type=int, default=20) - parser.add_argument("--test-num", type=int, default=100) + parser.add_argument("--num_train_envs", type=int, default=20) + parser.add_argument("--num_test_envs", type=int, default=100) parser.add_argument("--logdir", type=str, default="log") parser.add_argument("--render", type=float, default=0.0) parser.add_argument( @@ -55,8 +55,8 @@ def get_args() -> argparse.Namespace: parser.add_argument("--eps-clip", type=float, default=0.2) parser.add_argument("--max-grad-norm", type=float, default=0.5) parser.add_argument("--gae-lambda", type=float, default=0.95) - parser.add_argument("--rew-norm", type=int, default=0) - parser.add_argument("--norm-adv", type=int, default=0) + parser.add_argument("--return_scaling", type=int, default=0) + parser.add_argument("--advantage_normalization", type=int, default=0) parser.add_argument("--recompute-adv", type=int, 
default=0) parser.add_argument("--dual-clip", type=float, default=None) parser.add_argument("--value-clip", type=int, default=0) @@ -76,7 +76,7 @@ def test_ppo(args: argparse.Namespace = get_args(), enable_assertions: bool = Tr ) # train_envs = gym.make(args.task) # you can also use tianshou.env.SubprocVectorEnv - train_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.training_num)]) + train_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.num_train_envs)]) # test_envs = gym.make(args.task) test_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.test_num)]) # seed @@ -120,10 +120,10 @@ def test_ppo(args: argparse.Namespace = get_args(), enable_assertions: bool = Tr vf_coef=args.vf_coef, ent_coef=args.ent_coef, gae_lambda=args.gae_lambda, - return_scaling=args.rew_norm, + return_scaling=args.return_scaling, dual_clip=args.dual_clip, value_clip=args.value_clip, - advantage_normalization=args.norm_adv, + advantage_normalization=args.advantage_normalization, recompute_advantage=args.recompute_adv, ) # collector @@ -150,11 +150,11 @@ def stop_fn(mean_rewards: float) -> bool: train_collector=train_collector, test_collector=test_collector, max_epochs=args.epoch, - epoch_num_steps=args.step_per_epoch, - update_step_num_repetitions=args.repeat_per_collect, + epoch_num_steps=args.epoch_num_steps, + update_step_num_repetitions=args.update_step_num_repetitions, test_step_num_episodes=args.test_num, batch_size=args.batch_size, - collection_step_num_env_steps=args.step_per_collect, + collection_step_num_env_steps=args.collection_step_num_env_steps, stop_fn=stop_fn, save_best_fn=save_best_fn, logger=logger, diff --git a/test/discrete/test_qrdqn.py b/test/discrete/test_qrdqn.py index a8e9c9db4..840402daf 100644 --- a/test/discrete/test_qrdqn.py +++ b/test/discrete/test_qrdqn.py @@ -39,13 +39,13 @@ def get_args() -> argparse.Namespace: parser.add_argument("--n-step", type=int, default=3) 
parser.add_argument("--target-update-freq", type=int, default=320) parser.add_argument("--epoch", type=int, default=10) - parser.add_argument("--step-per-epoch", type=int, default=10000) - parser.add_argument("--step-per-collect", type=int, default=10) + parser.add_argument("--epoch_num_steps", type=int, default=10000) + parser.add_argument("--collection_step_num_env_steps", type=int, default=10) parser.add_argument("--update-per-step", type=float, default=0.1) - parser.add_argument("--batch-size", type=int, default=64) + parser.add_argument("--batch_size", type=int, default=64) parser.add_argument("--hidden-sizes", type=int, nargs="*", default=[128, 128, 128, 128]) - parser.add_argument("--training-num", type=int, default=10) - parser.add_argument("--test-num", type=int, default=100) + parser.add_argument("--num_train_envs", type=int, default=10) + parser.add_argument("--num_test_envs", type=int, default=100) parser.add_argument("--logdir", type=str, default="log") parser.add_argument("--render", type=float, default=0.0) parser.add_argument("--prioritized-replay", action="store_true", default=False) @@ -77,7 +77,7 @@ def test_qrdqn(args: argparse.Namespace = get_args(), enable_assertions: bool = ) # train_envs = gym.make(args.task) # you can also use tianshou.env.SubprocVectorEnv - train_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.training_num)]) + train_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.num_train_envs)]) # test_envs = gym.make(args.task) test_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.test_num)]) # seed @@ -129,7 +129,7 @@ def test_qrdqn(args: argparse.Namespace = get_args(), enable_assertions: bool = # initial data collection with policy_within_training_step(policy): train_collector.reset() - train_collector.collect(n_step=args.batch_size * args.training_num) + train_collector.collect(n_step=args.batch_size * args.num_train_envs) # logger log_path = 
os.path.join(args.logdir, args.task, "qrdqn") @@ -158,8 +158,8 @@ def train_fn(epoch: int, env_step: int) -> None: train_collector=train_collector, test_collector=test_collector, max_epochs=args.epoch, - epoch_num_steps=args.step_per_epoch, - collection_step_num_env_steps=args.step_per_collect, + epoch_num_steps=args.epoch_num_steps, + collection_step_num_env_steps=args.collection_step_num_env_steps, test_step_num_episodes=args.test_num, batch_size=args.batch_size, train_fn=train_fn, diff --git a/test/discrete/test_rainbow.py b/test/discrete/test_rainbow.py index 9ac35345c..9e64d40d6 100644 --- a/test/discrete/test_rainbow.py +++ b/test/discrete/test_rainbow.py @@ -44,13 +44,13 @@ def get_args() -> argparse.Namespace: parser.add_argument("--n-step", type=int, default=3) parser.add_argument("--target-update-freq", type=int, default=320) parser.add_argument("--epoch", type=int, default=10) - parser.add_argument("--step-per-epoch", type=int, default=8000) - parser.add_argument("--step-per-collect", type=int, default=8) + parser.add_argument("--epoch_num_steps", type=int, default=8000) + parser.add_argument("--collection_step_num_env_steps", type=int, default=8) parser.add_argument("--update-per-step", type=float, default=0.125) - parser.add_argument("--batch-size", type=int, default=64) + parser.add_argument("--batch_size", type=int, default=64) parser.add_argument("--hidden-sizes", type=int, nargs="*", default=[128, 128, 128, 128]) - parser.add_argument("--training-num", type=int, default=8) - parser.add_argument("--test-num", type=int, default=100) + parser.add_argument("--num_train_envs", type=int, default=8) + parser.add_argument("--num_test_envs", type=int, default=100) parser.add_argument("--logdir", type=str, default="log") parser.add_argument("--render", type=float, default=0.0) parser.add_argument("--prioritized-replay", action="store_true", default=False) @@ -83,7 +83,7 @@ def test_rainbow(args: argparse.Namespace = get_args(), enable_assertions: bool ) # 
train_envs = gym.make(args.task) # you can also use tianshou.env.SubprocVectorEnv - train_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.training_num)]) + train_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.num_train_envs)]) # test_envs = gym.make(args.task) test_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.test_num)]) # seed @@ -142,7 +142,7 @@ def noisy_linear(x: int, y: int) -> NoisyLinear: # initial data collection with policy_within_training_step(policy): train_collector.reset() - train_collector.collect(n_step=args.batch_size * args.training_num) + train_collector.collect(n_step=args.batch_size * args.num_train_envs) # logger log_path = os.path.join(args.logdir, args.task, "rainbow") @@ -212,8 +212,8 @@ def save_checkpoint_fn(epoch: int, env_step: int, gradient_step: int) -> str: train_collector=train_collector, test_collector=test_collector, max_epochs=args.epoch, - epoch_num_steps=args.step_per_epoch, - collection_step_num_env_steps=args.step_per_collect, + epoch_num_steps=args.epoch_num_steps, + collection_step_num_env_steps=args.collection_step_num_env_steps, test_step_num_episodes=args.test_num, batch_size=args.batch_size, update_step_num_gradient_steps_per_sample=args.update_per_step, diff --git a/test/discrete/test_reinforce.py b/test/discrete/test_reinforce.py index ae9b60c01..68b66e4be 100644 --- a/test/discrete/test_reinforce.py +++ b/test/discrete/test_reinforce.py @@ -29,16 +29,16 @@ def get_args() -> argparse.Namespace: parser.add_argument("--lr", type=float, default=1e-3) parser.add_argument("--gamma", type=float, default=0.95) parser.add_argument("--epoch", type=int, default=10) - parser.add_argument("--step-per-epoch", type=int, default=40000) - parser.add_argument("--episode-per-collect", type=int, default=8) - parser.add_argument("--repeat-per-collect", type=int, default=2) - parser.add_argument("--batch-size", type=int, default=64) + 
parser.add_argument("--epoch_num_steps", type=int, default=40000) + parser.add_argument("--collection_step_num_episodes", type=int, default=8) + parser.add_argument("--update_step_num_repetitions", type=int, default=2) + parser.add_argument("--batch_size", type=int, default=64) parser.add_argument("--hidden-sizes", type=int, nargs="*", default=[64, 64]) - parser.add_argument("--training-num", type=int, default=8) - parser.add_argument("--test-num", type=int, default=100) + parser.add_argument("--num_train_envs", type=int, default=8) + parser.add_argument("--num_test_envs", type=int, default=100) parser.add_argument("--logdir", type=str, default="log") parser.add_argument("--render", type=float, default=0.0) - parser.add_argument("--rew-norm", type=int, default=1) + parser.add_argument("--return_scaling", type=int, default=1) parser.add_argument( "--device", type=str, @@ -58,7 +58,7 @@ def test_reinforce(args: argparse.Namespace = get_args(), enable_assertions: boo args.task, env.spec.reward_threshold if env.spec else None, ) - train_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.training_num)]) + train_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.num_train_envs)]) test_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.test_num)]) # seed @@ -86,7 +86,7 @@ def test_reinforce(args: argparse.Namespace = get_args(), enable_assertions: boo policy=policy, optim=optim, gamma=args.gamma, - return_standardization=args.rew_norm, + return_standardization=args.return_scaling, ) for m in net.modules(): if isinstance(m, torch.nn.Linear): @@ -118,11 +118,11 @@ def stop_fn(mean_rewards: float) -> bool: train_collector=train_collector, test_collector=test_collector, max_epochs=args.epoch, - epoch_num_steps=args.step_per_epoch, - update_step_num_repetitions=args.repeat_per_collect, + epoch_num_steps=args.epoch_num_steps, + update_step_num_repetitions=args.update_step_num_repetitions, 
test_step_num_episodes=args.test_num, batch_size=args.batch_size, - collection_step_num_episodes=args.episode_per_collect, + collection_step_num_episodes=args.collection_step_num_episodes, collection_step_num_env_steps=None, stop_fn=stop_fn, save_best_fn=save_best_fn, diff --git a/test/highlevel/test_experiment_builder.py b/test/highlevel/test_experiment_builder.py index a0ec22f2a..b02438787 100644 --- a/test/highlevel/test_experiment_builder.py +++ b/test/highlevel/test_experiment_builder.py @@ -28,21 +28,21 @@ def create_training_config( builder_cls: type[ExperimentBuilder], num_epochs: int = 1, - step_per_epoch: int = 100, + epoch_num_steps: int = 100, num_train_envs: int = 2, num_test_envs: int = 2, ) -> OffPolicyTrainingConfig | OnPolicyTrainingConfig: if issubclass(builder_cls, OffPolicyExperimentBuilder): return OffPolicyTrainingConfig( max_epochs=num_epochs, - epoch_num_steps=step_per_epoch, + epoch_num_steps=epoch_num_steps, num_train_envs=num_train_envs, num_test_envs=num_test_envs, ) elif issubclass(builder_cls, OnPolicyExperimentBuilder): return OnPolicyTrainingConfig( max_epochs=num_epochs, - epoch_num_steps=step_per_epoch, + epoch_num_steps=epoch_num_steps, num_train_envs=num_train_envs, num_test_envs=num_test_envs, ) @@ -69,7 +69,7 @@ def test_experiment_builder_continuous_default_params(builder_cls: type[Experime training_config = create_training_config( builder_cls, num_epochs=1, - step_per_epoch=100, + epoch_num_steps=100, num_train_envs=2, num_test_envs=2, ) @@ -100,7 +100,7 @@ def test_experiment_builder_discrete_default_params(builder_cls: type[Experiment training_config = create_training_config( builder_cls, num_epochs=1, - step_per_epoch=100, + epoch_num_steps=100, num_train_envs=2, num_test_envs=2, ) diff --git a/test/modelbased/test_dqn_icm.py b/test/modelbased/test_dqn_icm.py index 5b87311a1..70edd7e06 100644 --- a/test/modelbased/test_dqn_icm.py +++ b/test/modelbased/test_dqn_icm.py @@ -36,13 +36,13 @@ def get_args() -> 
argparse.Namespace: parser.add_argument("--n-step", type=int, default=3) parser.add_argument("--target-update-freq", type=int, default=320) parser.add_argument("--epoch", type=int, default=20) - parser.add_argument("--step-per-epoch", type=int, default=10000) - parser.add_argument("--step-per-collect", type=int, default=10) + parser.add_argument("--epoch_num_steps", type=int, default=10000) + parser.add_argument("--collection_step_num_env_steps", type=int, default=10) parser.add_argument("--update-per-step", type=float, default=0.1) - parser.add_argument("--batch-size", type=int, default=64) + parser.add_argument("--batch_size", type=int, default=64) parser.add_argument("--hidden-sizes", type=int, nargs="*", default=[128, 128, 128, 128]) - parser.add_argument("--training-num", type=int, default=10) - parser.add_argument("--test-num", type=int, default=100) + parser.add_argument("--num_train_envs", type=int, default=10) + parser.add_argument("--num_test_envs", type=int, default=100) parser.add_argument("--logdir", type=str, default="log") parser.add_argument("--render", type=float, default=0.0) parser.add_argument("--prioritized-replay", action="store_true", default=False) @@ -88,7 +88,7 @@ def test_dqn_icm(args: argparse.Namespace = get_args()) -> None: args.task, env.spec.reward_threshold if env.spec else None, ) - train_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.training_num)]) + train_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.num_train_envs)]) test_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.test_num)]) # seed @@ -163,7 +163,7 @@ def test_dqn_icm(args: argparse.Namespace = get_args()) -> None: ) test_collector = Collector[CollectStats](icm_algorithm, test_envs, exploration_noise=True) train_collector.reset() - train_collector.collect(n_step=args.batch_size * args.training_num) + train_collector.collect(n_step=args.batch_size * args.num_train_envs) # log log_path = 
str(os.path.join(args.logdir, args.task, "dqn_icm")) @@ -192,8 +192,8 @@ def train_fn(epoch: int, env_step: int) -> None: train_collector=train_collector, test_collector=test_collector, max_epochs=args.epoch, - epoch_num_steps=args.step_per_epoch, - collection_step_num_env_steps=args.step_per_collect, + epoch_num_steps=args.epoch_num_steps, + collection_step_num_env_steps=args.collection_step_num_env_steps, test_step_num_episodes=args.test_num, batch_size=args.batch_size, update_step_num_gradient_steps_per_sample=args.update_per_step, diff --git a/test/modelbased/test_ppo_icm.py b/test/modelbased/test_ppo_icm.py index c6dbb5d51..58d91b4c0 100644 --- a/test/modelbased/test_ppo_icm.py +++ b/test/modelbased/test_ppo_icm.py @@ -34,13 +34,13 @@ def get_args() -> argparse.Namespace: parser.add_argument("--lr", type=float, default=3e-4) parser.add_argument("--gamma", type=float, default=0.99) parser.add_argument("--epoch", type=int, default=10) - parser.add_argument("--step-per-epoch", type=int, default=50000) - parser.add_argument("--step-per-collect", type=int, default=2000) - parser.add_argument("--repeat-per-collect", type=int, default=10) - parser.add_argument("--batch-size", type=int, default=64) + parser.add_argument("--epoch_num_steps", type=int, default=50000) + parser.add_argument("--collection_step_num_env_steps", type=int, default=2000) + parser.add_argument("--update_step_num_repetitions", type=int, default=10) + parser.add_argument("--batch_size", type=int, default=64) parser.add_argument("--hidden-sizes", type=int, nargs="*", default=[64, 64]) - parser.add_argument("--training-num", type=int, default=20) - parser.add_argument("--test-num", type=int, default=100) + parser.add_argument("--num_train_envs", type=int, default=20) + parser.add_argument("--num_test_envs", type=int, default=100) parser.add_argument("--logdir", type=str, default="log") parser.add_argument("--render", type=float, default=0.0) parser.add_argument( @@ -54,8 +54,8 @@ def get_args() -> 
argparse.Namespace: parser.add_argument("--eps-clip", type=float, default=0.2) parser.add_argument("--max-grad-norm", type=float, default=0.5) parser.add_argument("--gae-lambda", type=float, default=0.95) - parser.add_argument("--rew-norm", type=int, default=0) - parser.add_argument("--norm-adv", type=int, default=0) + parser.add_argument("--return_scaling", type=int, default=0) + parser.add_argument("--advantage_normalization", type=int, default=0) parser.add_argument("--recompute-adv", type=int, default=0) parser.add_argument("--dual-clip", type=float, default=None) parser.add_argument("--value-clip", type=int, default=0) @@ -94,7 +94,7 @@ def test_ppo(args: argparse.Namespace = get_args()) -> None: env.spec.reward_threshold if env.spec else None, ) - train_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.training_num)]) + train_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.num_train_envs)]) test_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.test_num)]) # seed @@ -135,10 +135,10 @@ def test_ppo(args: argparse.Namespace = get_args()) -> None: vf_coef=args.vf_coef, ent_coef=args.ent_coef, gae_lambda=args.gae_lambda, - return_scaling=args.rew_norm, + return_scaling=args.return_scaling, dual_clip=args.dual_clip, value_clip=args.value_clip, - advantage_normalization=args.norm_adv, + advantage_normalization=args.advantage_normalization, recompute_advantage=args.recompute_adv, ) @@ -191,11 +191,11 @@ def stop_fn(mean_rewards: float) -> bool: train_collector=train_collector, test_collector=test_collector, max_epochs=args.epoch, - epoch_num_steps=args.step_per_epoch, - update_step_num_repetitions=args.repeat_per_collect, + epoch_num_steps=args.epoch_num_steps, + update_step_num_repetitions=args.update_step_num_repetitions, test_step_num_episodes=args.test_num, batch_size=args.batch_size, - collection_step_num_env_steps=args.step_per_collect, + 
collection_step_num_env_steps=args.collection_step_num_env_steps, stop_fn=stop_fn, save_best_fn=save_best_fn, logger=logger, diff --git a/test/modelbased/test_psrl.py b/test/modelbased/test_psrl.py index 1dc55d81a..56e81e14c 100644 --- a/test/modelbased/test_psrl.py +++ b/test/modelbased/test_psrl.py @@ -25,10 +25,10 @@ def get_args() -> argparse.Namespace: parser.add_argument("--seed", type=int, default=1) parser.add_argument("--buffer-size", type=int, default=50000) parser.add_argument("--epoch", type=int, default=5) - parser.add_argument("--step-per-epoch", type=int, default=1000) - parser.add_argument("--episode-per-collect", type=int, default=1) - parser.add_argument("--training-num", type=int, default=1) - parser.add_argument("--test-num", type=int, default=10) + parser.add_argument("--epoch_num_steps", type=int, default=1000) + parser.add_argument("--collection_step_num_episodes", type=int, default=1) + parser.add_argument("--num_train_envs", type=int, default=1) + parser.add_argument("--num_test_envs", type=int, default=10) parser.add_argument("--logdir", type=str, default="log") parser.add_argument("--render", type=float, default=0.0) parser.add_argument("--rew-mean-prior", type=float, default=0.0) @@ -50,7 +50,9 @@ def get_args() -> argparse.Namespace: reason="EnvPool is not installed. If on linux, please install it (e.g. 
as poetry extra)", ) def test_psrl(args: argparse.Namespace = get_args()) -> None: - train_envs = env = envpool.make_gymnasium(args.task, num_envs=args.training_num, seed=args.seed) + train_envs = env = envpool.make_gymnasium( + args.task, num_envs=args.num_train_envs, seed=args.seed + ) test_envs = envpool.make_gymnasium(args.task, num_envs=args.test_num, seed=args.seed) if args.reward_threshold is None: default_reward_threshold = {"NChain-v0": 3400} @@ -117,11 +119,11 @@ def stop_fn(mean_rewards: float) -> bool: train_collector=train_collector, test_collector=test_collector, max_epochs=args.epoch, - epoch_num_steps=args.step_per_epoch, + epoch_num_steps=args.epoch_num_steps, update_step_num_repetitions=1, test_step_num_episodes=args.test_num, batch_size=0, - collection_step_num_episodes=args.episode_per_collect, + collection_step_num_episodes=args.collection_step_num_episodes, collection_step_num_env_steps=None, stop_fn=stop_fn, logger=logger, diff --git a/test/offline/gather_cartpole_data.py b/test/offline/gather_cartpole_data.py index 8b8b31559..934affc62 100644 --- a/test/offline/gather_cartpole_data.py +++ b/test/offline/gather_cartpole_data.py @@ -42,13 +42,13 @@ def get_args() -> argparse.Namespace: parser.add_argument("--n-step", type=int, default=3) parser.add_argument("--target-update-freq", type=int, default=320) parser.add_argument("--epoch", type=int, default=10) - parser.add_argument("--step-per-epoch", type=int, default=10000) - parser.add_argument("--step-per-collect", type=int, default=10) + parser.add_argument("--epoch_num_steps", type=int, default=10000) + parser.add_argument("--collection_step_num_env_steps", type=int, default=10) parser.add_argument("--update-per-step", type=float, default=0.1) - parser.add_argument("--batch-size", type=int, default=64) + parser.add_argument("--batch_size", type=int, default=64) parser.add_argument("--hidden-sizes", type=int, nargs="*", default=[128, 128, 128, 128]) - parser.add_argument("--training-num", 
type=int, default=10) - parser.add_argument("--test-num", type=int, default=100) + parser.add_argument("--num_train_envs", type=int, default=10) + parser.add_argument("--num_test_envs", type=int, default=100) parser.add_argument("--logdir", type=str, default="log") parser.add_argument("--render", type=float, default=0.0) parser.add_argument("--prioritized-replay", action="store_true", default=False) @@ -80,7 +80,7 @@ def gather_data() -> VectorReplayBuffer | PrioritizedVectorReplayBuffer: ) # train_envs = gym.make(args.task) # you can also use tianshou.env.SubprocVectorEnv - train_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.training_num)]) + train_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.num_train_envs)]) # test_envs = gym.make(args.task) test_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.test_num)]) # seed @@ -127,7 +127,7 @@ def gather_data() -> VectorReplayBuffer | PrioritizedVectorReplayBuffer: train_collector.reset() test_collector = Collector[CollectStats](algorithm, test_envs, exploration_noise=True) test_collector.reset() - train_collector.collect(n_step=args.batch_size * args.training_num) + train_collector.collect(n_step=args.batch_size * args.num_train_envs) # log log_path = os.path.join(args.logdir, args.task, "qrdqn") writer = SummaryWriter(log_path) @@ -155,8 +155,8 @@ def train_fn(epoch: int, env_step: int) -> None: train_collector=train_collector, test_collector=test_collector, max_epochs=args.epoch, - epoch_num_steps=args.step_per_epoch, - collection_step_num_env_steps=args.step_per_collect, + epoch_num_steps=args.epoch_num_steps, + collection_step_num_env_steps=args.collection_step_num_env_steps, test_step_num_episodes=args.test_num, batch_size=args.batch_size, train_fn=train_fn, diff --git a/test/offline/gather_pendulum_data.py b/test/offline/gather_pendulum_data.py index 2875fae7a..f433794e3 100644 --- a/test/offline/gather_pendulum_data.py +++ 
b/test/offline/gather_pendulum_data.py @@ -34,11 +34,11 @@ def get_args() -> argparse.Namespace: parser.add_argument("--actor-lr", type=float, default=1e-3) parser.add_argument("--critic-lr", type=float, default=1e-3) parser.add_argument("--epoch", type=int, default=7) - parser.add_argument("--step-per-epoch", type=int, default=8000) - parser.add_argument("--batch-size", type=int, default=256) - parser.add_argument("--training-num", type=int, default=10) - parser.add_argument("--test-num", type=int, default=10) - parser.add_argument("--step-per-collect", type=int, default=10) + parser.add_argument("--epoch_num_steps", type=int, default=8000) + parser.add_argument("--batch_size", type=int, default=256) + parser.add_argument("--num_train_envs", type=int, default=10) + parser.add_argument("--num_test_envs", type=int, default=10) + parser.add_argument("--collection_step_num_env_steps", type=int, default=10) parser.add_argument("--update-per-step", type=float, default=0.125) parser.add_argument("--logdir", type=str, default="log") parser.add_argument("--render", type=float, default=0.0) @@ -83,7 +83,7 @@ def gather_data() -> VectorReplayBuffer: ) # you can also use tianshou.env.SubprocVectorEnv # train_envs = gym.make(args.task) - train_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.training_num)]) + train_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.num_train_envs)]) # test_envs = gym.make(args.task) test_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.test_num)]) # seed @@ -151,8 +151,8 @@ def stop_fn(mean_rewards: float) -> bool: train_collector=train_collector, test_collector=test_collector, max_epochs=args.epoch, - epoch_num_steps=args.step_per_epoch, - collection_step_num_env_steps=args.step_per_collect, + epoch_num_steps=args.epoch_num_steps, + collection_step_num_env_steps=args.collection_step_num_env_steps, test_step_num_episodes=args.test_num, batch_size=args.batch_size, 
update_step_num_gradient_steps_per_sample=args.update_per_step, diff --git a/test/offline/test_bcq.py b/test/offline/test_bcq.py index 08cb1ef57..4f72a3035 100644 --- a/test/offline/test_bcq.py +++ b/test/offline/test_bcq.py @@ -31,9 +31,9 @@ def get_args() -> argparse.Namespace: parser.add_argument("--actor-lr", type=float, default=1e-3) parser.add_argument("--critic-lr", type=float, default=1e-3) parser.add_argument("--epoch", type=int, default=5) - parser.add_argument("--step-per-epoch", type=int, default=500) - parser.add_argument("--batch-size", type=int, default=32) - parser.add_argument("--test-num", type=int, default=10) + parser.add_argument("--epoch_num_steps", type=int, default=500) + parser.add_argument("--batch_size", type=int, default=32) + parser.add_argument("--num_test_envs", type=int, default=10) parser.add_argument("--logdir", type=str, default="log") parser.add_argument("--render", type=float, default=1 / 35) @@ -193,7 +193,7 @@ def watch() -> None: buffer=buffer, test_collector=test_collector, max_epochs=args.epoch, - epoch_num_steps=args.step_per_epoch, + epoch_num_steps=args.epoch_num_steps, test_step_num_episodes=args.test_num, batch_size=args.batch_size, save_best_fn=save_best_fn, diff --git a/test/offline/test_cql.py b/test/offline/test_cql.py index 0777ec81d..38cb65286 100644 --- a/test/offline/test_cql.py +++ b/test/offline/test_cql.py @@ -36,9 +36,9 @@ def get_args() -> argparse.Namespace: parser.add_argument("--cql-alpha-lr", type=float, default=1e-3) parser.add_argument("--start-timesteps", type=int, default=10000) parser.add_argument("--epoch", type=int, default=5) - parser.add_argument("--step-per-epoch", type=int, default=500) + parser.add_argument("--epoch_num_steps", type=int, default=500) parser.add_argument("--n-step", type=int, default=3) - parser.add_argument("--batch-size", type=int, default=64) + parser.add_argument("--batch_size", type=int, default=64) parser.add_argument("--tau", type=float, default=0.005) 
parser.add_argument("--temperature", type=float, default=1.0) @@ -48,7 +48,7 @@ def get_args() -> argparse.Namespace: parser.add_argument("--gamma", type=float, default=0.99) parser.add_argument("--eval-freq", type=int, default=1) - parser.add_argument("--test-num", type=int, default=10) + parser.add_argument("--num_test_envs", type=int, default=10) parser.add_argument("--logdir", type=str, default="log") parser.add_argument("--render", type=float, default=1 / 35) parser.add_argument( @@ -187,7 +187,7 @@ def stop_fn(mean_rewards: float) -> bool: buffer=buffer, test_collector=test_collector, max_epochs=args.epoch, - epoch_num_steps=args.step_per_epoch, + epoch_num_steps=args.epoch_num_steps, test_step_num_episodes=args.test_num, batch_size=args.batch_size, save_best_fn=save_best_fn, diff --git a/test/offline/test_discrete_bcq.py b/test/offline/test_discrete_bcq.py index 2b263a9c3..565af0981 100644 --- a/test/offline/test_discrete_bcq.py +++ b/test/offline/test_discrete_bcq.py @@ -39,10 +39,10 @@ def get_args() -> argparse.Namespace: parser.add_argument("--unlikely-action-threshold", type=float, default=0.6) parser.add_argument("--imitation-logits-penalty", type=float, default=0.01) parser.add_argument("--epoch", type=int, default=5) - parser.add_argument("--step-per-epoch", type=int, default=2000) - parser.add_argument("--batch-size", type=int, default=64) + parser.add_argument("--epoch_num_steps", type=int, default=2000) + parser.add_argument("--batch_size", type=int, default=64) parser.add_argument("--hidden-sizes", type=int, nargs="*", default=[64, 64]) - parser.add_argument("--test-num", type=int, default=100) + parser.add_argument("--num_test_envs", type=int, default=100) parser.add_argument("--logdir", type=str, default="log") parser.add_argument("--render", type=float, default=0.0) parser.add_argument("--load-buffer-name", type=str, default=expert_file_name()) @@ -161,7 +161,7 @@ def save_checkpoint_fn(epoch: int, env_step: int, gradient_step: int) -> str: 
buffer=buffer, test_collector=test_collector, max_epochs=args.epoch, - epoch_num_steps=args.step_per_epoch, + epoch_num_steps=args.epoch_num_steps, test_step_num_episodes=args.test_num, batch_size=args.batch_size, stop_fn=stop_fn, diff --git a/test/offline/test_discrete_cql.py b/test/offline/test_discrete_cql.py index 390c56d6c..622949d46 100644 --- a/test/offline/test_discrete_cql.py +++ b/test/offline/test_discrete_cql.py @@ -38,10 +38,10 @@ def get_args() -> argparse.Namespace: parser.add_argument("--target-update-freq", type=int, default=500) parser.add_argument("--min-q-weight", type=float, default=10.0) parser.add_argument("--epoch", type=int, default=5) - parser.add_argument("--step-per-epoch", type=int, default=1000) - parser.add_argument("--batch-size", type=int, default=32) + parser.add_argument("--epoch_num_steps", type=int, default=1000) + parser.add_argument("--batch_size", type=int, default=32) parser.add_argument("--hidden-sizes", type=int, nargs="*", default=[64]) - parser.add_argument("--test-num", type=int, default=100) + parser.add_argument("--num_test_envs", type=int, default=100) parser.add_argument("--logdir", type=str, default="log") parser.add_argument("--render", type=float, default=0.0) parser.add_argument("--load-buffer-name", type=str, default=expert_file_name()) @@ -130,7 +130,7 @@ def stop_fn(mean_rewards: float) -> bool: buffer=buffer, test_collector=test_collector, max_epochs=args.epoch, - epoch_num_steps=args.step_per_epoch, + epoch_num_steps=args.epoch_num_steps, test_step_num_episodes=args.test_num, batch_size=args.batch_size, stop_fn=stop_fn, diff --git a/test/offline/test_discrete_crr.py b/test/offline/test_discrete_crr.py index 20de78e86..60e33d003 100644 --- a/test/offline/test_discrete_crr.py +++ b/test/offline/test_discrete_crr.py @@ -36,10 +36,10 @@ def get_args() -> argparse.Namespace: parser.add_argument("--n-step", type=int, default=3) parser.add_argument("--target-update-freq", type=int, default=320) 
parser.add_argument("--epoch", type=int, default=5) - parser.add_argument("--step-per-epoch", type=int, default=1000) - parser.add_argument("--batch-size", type=int, default=64) + parser.add_argument("--epoch_num_steps", type=int, default=1000) + parser.add_argument("--batch_size", type=int, default=64) parser.add_argument("--hidden-sizes", type=int, nargs="*", default=[64, 64]) - parser.add_argument("--test-num", type=int, default=100) + parser.add_argument("--num_test_envs", type=int, default=100) parser.add_argument("--logdir", type=str, default="log") parser.add_argument("--render", type=float, default=0.0) parser.add_argument("--load-buffer-name", type=str, default=expert_file_name()) @@ -131,7 +131,7 @@ def stop_fn(mean_rewards: float) -> bool: buffer=buffer, test_collector=test_collector, max_epochs=args.epoch, - epoch_num_steps=args.step_per_epoch, + epoch_num_steps=args.epoch_num_steps, test_step_num_episodes=args.test_num, batch_size=args.batch_size, stop_fn=stop_fn, diff --git a/test/offline/test_gail.py b/test/offline/test_gail.py index dfade5ceb..153ae7ff4 100644 --- a/test/offline/test_gail.py +++ b/test/offline/test_gail.py @@ -32,14 +32,14 @@ def get_args() -> argparse.Namespace: parser.add_argument("--disc-lr", type=float, default=5e-4) parser.add_argument("--gamma", type=float, default=0.95) parser.add_argument("--epoch", type=int, default=5) - parser.add_argument("--step-per-epoch", type=int, default=150000) - parser.add_argument("--episode-per-collect", type=int, default=16) - parser.add_argument("--repeat-per-collect", type=int, default=2) + parser.add_argument("--epoch_num_steps", type=int, default=150000) + parser.add_argument("--collection_step_num_episodes", type=int, default=16) + parser.add_argument("--update_step_num_repetitions", type=int, default=2) parser.add_argument("--disc-update-num", type=int, default=2) - parser.add_argument("--batch-size", type=int, default=128) + parser.add_argument("--batch_size", type=int, default=128) 
parser.add_argument("--hidden-sizes", type=int, nargs="*", default=[64, 64]) - parser.add_argument("--training-num", type=int, default=16) - parser.add_argument("--test-num", type=int, default=100) + parser.add_argument("--num_train_envs", type=int, default=16) + parser.add_argument("--num_test_envs", type=int, default=100) parser.add_argument("--logdir", type=str, default="log") parser.add_argument("--render", type=float, default=0.0) parser.add_argument( @@ -53,10 +53,10 @@ def get_args() -> argparse.Namespace: parser.add_argument("--eps-clip", type=float, default=0.2) parser.add_argument("--max-grad-norm", type=float, default=0.5) parser.add_argument("--gae-lambda", type=float, default=0.95) - parser.add_argument("--rew-norm", type=int, default=1) + parser.add_argument("--return_scaling", type=int, default=1) parser.add_argument("--dual-clip", type=float, default=None) parser.add_argument("--value-clip", type=int, default=1) - parser.add_argument("--norm-adv", type=int, default=1) + parser.add_argument("--advantage_normalization", type=int, default=1) parser.add_argument("--recompute-adv", type=int, default=0) parser.add_argument("--resume", action="store_true") parser.add_argument("--save-interval", type=int, default=4) @@ -84,7 +84,7 @@ def test_gail(args: argparse.Namespace = get_args(), enable_assertions: bool = T args.state_shape = space_info.observation_info.obs_shape args.action_shape = space_info.action_info.action_shape args.max_action = space_info.action_info.max_action - train_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.training_num)]) + train_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.num_train_envs)]) test_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.test_num)]) # seed np.random.seed(args.seed) @@ -151,8 +151,8 @@ def dist(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution: eps_clip=args.eps_clip, vf_coef=args.vf_coef, ent_coef=args.ent_coef, - 
return_scaling=args.rew_norm, - advantage_normalization=args.norm_adv, + return_scaling=args.return_scaling, + advantage_normalization=args.advantage_normalization, recompute_advantage=args.recompute_adv, dual_clip=args.dual_clip, value_clip=args.value_clip, @@ -204,11 +204,11 @@ def save_checkpoint_fn(epoch: int, env_step: int, gradient_step: int) -> str: train_collector=train_collector, test_collector=test_collector, max_epochs=args.epoch, - epoch_num_steps=args.step_per_epoch, - update_step_num_repetitions=args.repeat_per_collect, + epoch_num_steps=args.epoch_num_steps, + update_step_num_repetitions=args.update_step_num_repetitions, test_step_num_episodes=args.test_num, batch_size=args.batch_size, - collection_step_num_episodes=args.episode_per_collect, + collection_step_num_episodes=args.collection_step_num_episodes, collection_step_num_env_steps=None, stop_fn=stop_fn, save_best_fn=save_best_fn, diff --git a/test/offline/test_td3_bc.py b/test/offline/test_td3_bc.py index b8b412ed3..4f0c74934 100644 --- a/test/offline/test_td3_bc.py +++ b/test/offline/test_td3_bc.py @@ -33,9 +33,9 @@ def get_args() -> argparse.Namespace: parser.add_argument("--actor-lr", type=float, default=1e-3) parser.add_argument("--critic-lr", type=float, default=1e-3) parser.add_argument("--epoch", type=int, default=5) - parser.add_argument("--step-per-epoch", type=int, default=500) + parser.add_argument("--epoch_num_steps", type=int, default=500) parser.add_argument("--n-step", type=int, default=3) - parser.add_argument("--batch-size", type=int, default=64) + parser.add_argument("--batch_size", type=int, default=64) parser.add_argument("--alpha", type=float, default=2.5) parser.add_argument("--exploration-noise", type=float, default=0.1) parser.add_argument("--policy-noise", type=float, default=0.2) @@ -45,7 +45,7 @@ def get_args() -> argparse.Namespace: parser.add_argument("--gamma", type=float, default=0.99) parser.add_argument("--eval-freq", type=int, default=1) - 
parser.add_argument("--test-num", type=int, default=10) + parser.add_argument("--num_test_envs", type=int, default=10) parser.add_argument("--logdir", type=str, default="log") parser.add_argument("--render", type=float, default=1 / 35) parser.add_argument( @@ -177,7 +177,7 @@ def stop_fn(mean_rewards: float) -> bool: buffer=buffer, test_collector=test_collector, max_epochs=args.epoch, - epoch_num_steps=args.step_per_epoch, + epoch_num_steps=args.epoch_num_steps, test_step_num_episodes=args.test_num, batch_size=args.batch_size, save_best_fn=save_best_fn, diff --git a/test/pettingzoo/pistonball.py b/test/pettingzoo/pistonball.py index ae695f75a..71f17407e 100644 --- a/test/pettingzoo/pistonball.py +++ b/test/pettingzoo/pistonball.py @@ -42,13 +42,13 @@ def get_parser() -> argparse.ArgumentParser: parser.add_argument("--n-step", type=int, default=100) parser.add_argument("--target-update-freq", type=int, default=320) parser.add_argument("--epoch", type=int, default=3) - parser.add_argument("--step-per-epoch", type=int, default=500) - parser.add_argument("--step-per-collect", type=int, default=10) + parser.add_argument("--epoch_num_steps", type=int, default=500) + parser.add_argument("--collection_step_num_env_steps", type=int, default=10) parser.add_argument("--update-per-step", type=float, default=0.1) - parser.add_argument("--batch-size", type=int, default=100) + parser.add_argument("--batch_size", type=int, default=100) parser.add_argument("--hidden-sizes", type=int, nargs="*", default=[64, 64]) - parser.add_argument("--training-num", type=int, default=10) - parser.add_argument("--test-num", type=int, default=10) + parser.add_argument("--num_train_envs", type=int, default=10) + parser.add_argument("--num_test_envs", type=int, default=10) parser.add_argument("--logdir", type=str, default="log") parser.add_argument("--render", type=float, default=0.0) @@ -127,7 +127,7 @@ def train_agent( agents: list[OffPolicyAlgorithm] | None = None, optims: 
list[torch.optim.Optimizer] | None = None, ) -> tuple[InfoStats, Algorithm]: - train_envs = DummyVectorEnv([get_env for _ in range(args.training_num)]) + train_envs = DummyVectorEnv([get_env for _ in range(args.num_train_envs)]) test_envs = DummyVectorEnv([get_env for _ in range(args.test_num)]) # seed np.random.seed(args.seed) @@ -146,7 +146,7 @@ def train_agent( ) test_collector = Collector[CollectStats](marl_algorithm, test_envs, exploration_noise=True) train_collector.reset() - train_collector.collect(n_step=args.batch_size * args.training_num) + train_collector.collect(n_step=args.batch_size * args.num_train_envs) # log log_path = os.path.join(args.logdir, "pistonball", "dqn") writer = SummaryWriter(log_path) @@ -168,8 +168,8 @@ def reward_metric(rews: np.ndarray) -> np.ndarray: train_collector=train_collector, test_collector=test_collector, max_epochs=args.epoch, - epoch_num_steps=args.step_per_epoch, - collection_step_num_env_steps=args.step_per_collect, + epoch_num_steps=args.epoch_num_steps, + collection_step_num_env_steps=args.collection_step_num_env_steps, test_step_num_episodes=args.test_num, batch_size=args.batch_size, stop_fn=stop_fn, diff --git a/test/pettingzoo/pistonball_continuous.py b/test/pettingzoo/pistonball_continuous.py index def27cf14..87f73f104 100644 --- a/test/pettingzoo/pistonball_continuous.py +++ b/test/pettingzoo/pistonball_continuous.py @@ -93,15 +93,15 @@ def get_parser() -> argparse.ArgumentParser: parser.add_argument("--n-step", type=int, default=100) parser.add_argument("--target-update-freq", type=int, default=320) parser.add_argument("--epoch", type=int, default=5) - parser.add_argument("--step-per-epoch", type=int, default=500) - parser.add_argument("--step-per-collect", type=int, default=10) - parser.add_argument("--episode-per-collect", type=int, default=16) - parser.add_argument("--repeat-per-collect", type=int, default=2) + parser.add_argument("--epoch_num_steps", type=int, default=500) + 
parser.add_argument("--collection_step_num_env_steps", type=int, default=10) + parser.add_argument("--collection_step_num_episodes", type=int, default=16) + parser.add_argument("--update_step_num_repetitions", type=int, default=2) parser.add_argument("--update-per-step", type=float, default=0.1) - parser.add_argument("--batch-size", type=int, default=32) + parser.add_argument("--batch_size", type=int, default=32) parser.add_argument("--hidden-sizes", type=int, nargs="*", default=[64, 64]) - parser.add_argument("--training-num", type=int, default=10) - parser.add_argument("--test-num", type=int, default=10) + parser.add_argument("--num_train_envs", type=int, default=10) + parser.add_argument("--num_test_envs", type=int, default=10) parser.add_argument("--logdir", type=str, default="log") parser.add_argument( @@ -121,10 +121,10 @@ def get_parser() -> argparse.ArgumentParser: parser.add_argument("--eps-clip", type=float, default=0.2) parser.add_argument("--max-grad-norm", type=float, default=0.5) parser.add_argument("--gae-lambda", type=float, default=0.95) - parser.add_argument("--rew-norm", type=int, default=1) + parser.add_argument("--return_scaling", type=int, default=1) parser.add_argument("--dual-clip", type=float, default=None) parser.add_argument("--value-clip", type=int, default=1) - parser.add_argument("--norm-adv", type=int, default=1) + parser.add_argument("--advantage_normalization", type=int, default=1) parser.add_argument("--recompute-adv", type=int, default=0) parser.add_argument("--resume", action="store_true") parser.add_argument("--save-interval", type=int, default=4) @@ -209,8 +209,8 @@ def dist(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution: eps_clip=args.eps_clip, vf_coef=args.vf_coef, ent_coef=args.ent_coef, - return_scaling=args.rew_norm, - advantage_normalization=args.norm_adv, + return_scaling=args.return_scaling, + advantage_normalization=args.advantage_normalization, recompute_advantage=args.recompute_adv, # 
dual_clip=args.dual_clip, # dual clip cause monotonically increasing log_std :) @@ -233,7 +233,7 @@ def train_agent( agents: list[OnPolicyAlgorithm] | None = None, optims: list[torch.optim.Optimizer] | None = None, ) -> tuple[InfoStats, Algorithm]: - train_envs = DummyVectorEnv([get_env for _ in range(args.training_num)]) + train_envs = DummyVectorEnv([get_env for _ in range(args.num_train_envs)]) test_envs = DummyVectorEnv([get_env for _ in range(args.test_num)]) # seed np.random.seed(args.seed) @@ -251,7 +251,7 @@ def train_agent( exploration_noise=False, # True ) test_collector = Collector[CollectStats](marl_algorithm, test_envs) - # train_collector.collect(n_step=args.batch_size * args.training_num, reset_before_collect=True) + # train_collector.collect(n_step=args.batch_size * args.num_train_envs, reset_before_collect=True) # log log_path = os.path.join(args.logdir, "pistonball", "dqn") writer = SummaryWriter(log_path) @@ -273,11 +273,11 @@ def reward_metric(rews: np.ndarray) -> np.ndarray: train_collector=train_collector, test_collector=test_collector, max_epochs=args.epoch, - epoch_num_steps=args.step_per_epoch, - update_step_num_repetitions=args.repeat_per_collect, + epoch_num_steps=args.epoch_num_steps, + update_step_num_repetitions=args.update_step_num_repetitions, test_step_num_episodes=args.test_num, batch_size=args.batch_size, - collection_step_num_episodes=args.episode_per_collect, + collection_step_num_episodes=args.collection_step_num_episodes, collection_step_num_env_steps=None, stop_fn=stop_fn, save_best_fn=save_best_fn, diff --git a/test/pettingzoo/tic_tac_toe.py b/test/pettingzoo/tic_tac_toe.py index f7f2a79f8..d99778be5 100644 --- a/test/pettingzoo/tic_tac_toe.py +++ b/test/pettingzoo/tic_tac_toe.py @@ -47,13 +47,13 @@ def get_parser() -> argparse.ArgumentParser: parser.add_argument("--n-step", type=int, default=3) parser.add_argument("--target-update-freq", type=int, default=320) parser.add_argument("--epoch", type=int, default=50) - 
parser.add_argument("--step-per-epoch", type=int, default=1000) - parser.add_argument("--step-per-collect", type=int, default=10) + parser.add_argument("--epoch_num_steps", type=int, default=1000) + parser.add_argument("--collection_step_num_env_steps", type=int, default=10) parser.add_argument("--update-per-step", type=float, default=0.1) - parser.add_argument("--batch-size", type=int, default=64) + parser.add_argument("--batch_size", type=int, default=64) parser.add_argument("--hidden-sizes", type=int, nargs="*", default=[128, 128, 128, 128]) - parser.add_argument("--training-num", type=int, default=10) - parser.add_argument("--test-num", type=int, default=10) + parser.add_argument("--num_train_envs", type=int, default=10) + parser.add_argument("--num_test_envs", type=int, default=10) parser.add_argument("--logdir", type=str, default="log") parser.add_argument("--render", type=float, default=0.1) parser.add_argument( @@ -161,7 +161,7 @@ def train_agent( agent_opponent: OffPolicyAlgorithm | None = None, optim: OptimizerFactory | None = None, ) -> tuple[InfoStats, OffPolicyAlgorithm]: - train_envs = DummyVectorEnv([get_env for _ in range(args.training_num)]) + train_envs = DummyVectorEnv([get_env for _ in range(args.num_train_envs)]) test_envs = DummyVectorEnv([get_env for _ in range(args.test_num)]) # seed np.random.seed(args.seed) @@ -185,7 +185,7 @@ def train_agent( ) test_collector = Collector[CollectStats](marl_algorithm, test_envs, exploration_noise=True) train_collector.reset() - train_collector.collect(n_step=args.batch_size * args.training_num) + train_collector.collect(n_step=args.batch_size * args.num_train_envs) # log log_path = os.path.join(args.logdir, "tic_tac_toe", "dqn") writer = SummaryWriter(log_path) @@ -213,8 +213,8 @@ def reward_metric(rews: np.ndarray) -> np.ndarray: train_collector=train_collector, test_collector=test_collector, max_epochs=args.epoch, - epoch_num_steps=args.step_per_epoch, - 
collection_step_num_env_steps=args.step_per_collect, + epoch_num_steps=args.epoch_num_steps, + collection_step_num_env_steps=args.collection_step_num_env_steps, test_step_num_episodes=args.test_num, batch_size=args.batch_size, stop_fn=stop_fn, diff --git a/tianshou/algorithm/modelfree/npg.py b/tianshou/algorithm/modelfree/npg.py index 93f327e0f..71b170520 100644 --- a/tianshou/algorithm/modelfree/npg.py +++ b/tianshou/algorithm/modelfree/npg.py @@ -37,7 +37,7 @@ def __init__( critic: torch.nn.Module | ContinuousCritic | DiscreteCritic, optim: OptimizerFactory, optim_critic_iters: int = 5, - actor_step_size: float = 0.5, + trust_region_size: float = 0.5, advantage_normalization: bool = True, gae_lambda: float = 0.95, max_batchsize: int = 256, @@ -56,7 +56,9 @@ def __init__( training. Lower values maintain a more even learning pace between policy and value function but may lead to less reliable advantage estimates. Typically set between 1 and 10, depending on the complexity of the value function. - :param actor_step_size: the scalar multiplier for policy updates in the natural gradient direction. + :param trust_region_size: the parameter delta - a scalar multiplier for policy updates in the natural gradient direction. + The mathematical meaning is the trust region size, which is the maximum KL divergence + allowed between the old and new policy distributions. Controls how far the policy parameters move in the calculated direction during each update. Higher values allow for faster learning but may cause instability or policy deterioration; lower values provide more stable but slower learning. 
Unlike @@ -112,9 +114,9 @@ def __init__( gamma=gamma, return_scaling=return_scaling, ) - self.norm_adv = advantage_normalization + self.advantage_normalization = advantage_normalization self.optim_critic_iters = optim_critic_iters - self.actor_step_size = actor_step_size + self.trust_region_size = trust_region_size # adjusts Hessian-vector product calculation for numerical stability self._damping = 0.1 @@ -131,7 +133,7 @@ def _preprocess_batch( for minibatch in batch.split(self.max_batchsize, shuffle=False, merge_last=True): old_log_prob.append(self.policy(minibatch).dist.log_prob(minibatch.act)) batch.logp_old = torch.cat(old_log_prob, dim=0) - if self.norm_adv: + if self.advantage_normalization: batch.adv = (batch.adv - batch.adv.mean()) / batch.adv.std() return batch @@ -169,7 +171,7 @@ def _update_with_batch( # type: ignore[override] flat_params = torch.cat( [param.data.view(-1) for param in self.policy.actor.parameters()], ) - new_flat_params = flat_params + self.actor_step_size * search_direction + new_flat_params = flat_params + self.trust_region_size * search_direction self._set_from_flat_params(self.policy.actor, new_flat_params) new_dist = self.policy(minibatch).dist kl = kl_divergence(old_dist, new_dist).mean() diff --git a/tianshou/algorithm/modelfree/ppo.py b/tianshou/algorithm/modelfree/ppo.py index a1395eddc..b6a3e8aa6 100644 --- a/tianshou/algorithm/modelfree/ppo.py +++ b/tianshou/algorithm/modelfree/ppo.py @@ -140,7 +140,7 @@ def __init__( self.eps_clip = eps_clip self.dual_clip = dual_clip self.value_clip = value_clip - self.norm_adv = advantage_normalization + self.advantage_normalization = advantage_normalization self.recompute_adv = recompute_advantage def _preprocess_batch( @@ -181,7 +181,7 @@ def _update_with_batch( # type: ignore[override] # calculate loss for actor advantages = minibatch.adv dist = self.policy(minibatch).dist - if self.norm_adv: + if self.advantage_normalization: mean, std = advantages.mean(), advantages.std() advantages = 
(advantages - mean) / (std + self._eps) # per-batch norm ratios = (dist.log_prob(minibatch.act) - minibatch.logp_old).exp().float() diff --git a/tianshou/algorithm/modelfree/trpo.py b/tianshou/algorithm/modelfree/trpo.py index 5b7372143..952ef679f 100644 --- a/tianshou/algorithm/modelfree/trpo.py +++ b/tianshou/algorithm/modelfree/trpo.py @@ -33,7 +33,7 @@ def __init__( backtrack_coeff: float = 0.8, max_backtracks: int = 10, optim_critic_iters: int = 5, - actor_step_size: float = 0.5, + trust_region_size: float = 0.5, advantage_normalization: bool = True, gae_lambda: float = 0.95, max_batchsize: int = 256, @@ -56,7 +56,9 @@ def __init__( training. Lower values maintain a more even learning pace between policy and value function but may lead to less reliable advantage estimates. Typically set between 1 and 10, depending on the complexity of the value function. - :param actor_step_size: the scalar multiplier for policy updates in the natural gradient direction. + :param trust_region_size: the parameter delta - a scalar multiplier for policy updates in the natural gradient direction. + The mathematical meaning is the trust region size, which is the maximum KL divergence + allowed between the old and new policy distributions. Controls how far the policy parameters move in the calculated direction during each update. Higher values allow for faster learning but may cause instability or policy deterioration; lower values provide more stable but slower learning. 
Unlike @@ -107,7 +109,7 @@ def __init__( critic=critic, optim=optim, optim_critic_iters=optim_critic_iters, - actor_step_size=actor_step_size, + trust_region_size=trust_region_size, advantage_normalization=advantage_normalization, gae_lambda=gae_lambda, max_batchsize=max_batchsize, diff --git a/tianshou/env/atari/atari_wrapper.py b/tianshou/env/atari/atari_wrapper.py index de375eaa3..5d2761a99 100644 --- a/tianshou/env/atari/atari_wrapper.py +++ b/tianshou/env/atari/atari_wrapper.py @@ -352,9 +352,15 @@ def wrap_deepmind( env = MaxAndSkipEnv(env, skip=4) assert hasattr(env.unwrapped, "get_action_meanings") # for mypy - wrapped_env: MaxAndSkipEnv | EpisodicLifeEnv | FireResetEnv | WarpFrame | ScaledFloatFrame | ClipRewardEnv | FrameStack = ( - env - ) + wrapped_env: ( + MaxAndSkipEnv + | EpisodicLifeEnv + | FireResetEnv + | WarpFrame + | ScaledFloatFrame + | ClipRewardEnv + | FrameStack + ) = env if episode_life: wrapped_env = EpisodicLifeEnv(wrapped_env) if "FIRE" in env.unwrapped.get_action_meanings(): @@ -373,7 +379,7 @@ def wrap_deepmind( def make_atari_env( task: str, seed: int, - training_num: int, + num_train_envs: int, test_num: int, scale: int | bool = False, frame_stack: int = 4, @@ -384,8 +390,8 @@ def make_atari_env( :return: a tuple of (single env, training envs, test envs). 
""" - env_factory = AtariEnvFactory(task, seed, seed + training_num, frame_stack, scale=bool(scale)) - envs = env_factory.create_envs(training_num, test_num) + env_factory = AtariEnvFactory(task, seed, seed + num_train_envs, frame_stack, scale=bool(scale)) + envs = env_factory.create_envs(num_train_envs, test_num) return envs.env, envs.train_envs, envs.test_envs diff --git a/tianshou/highlevel/config.py b/tianshou/highlevel/config.py index 223dfe725..22ea5c7d2 100644 --- a/tianshou/highlevel/config.py +++ b/tianshou/highlevel/config.py @@ -17,7 +17,7 @@ class TrainingConfig(ToStringMixin): epoch consists of a number of training steps and one test step, where each training step * [for the online case] collects environment steps/transitions (**collection step**), - adding them to the (replay) buffer (see :attr:`collection_step_num_env_steps` and :attr:`episode_per_collect`) + adding them to the (replay) buffer (see :attr:`collection_step_num_env_steps` and :attr:`collection_step_num_episodes`) * performs an **update step** via the RL algorithm being used, which can involve one or more actual gradient updates, depending on the algorithm @@ -67,7 +67,7 @@ class TrainingConfig(ToStringMixin): the number of environment steps/transitions to collect in each collection step before the network update within each training step. - This is mutually exclusive with :attr:`episode_per_collect`, and one of the two must be set. + This is mutually exclusive with :attr:`collection_step_num_episodes`, and one of the two must be set. 
Note that the exact number can be reached only if this is a multiple of the number of training environments being used, as each training environment will produce the same @@ -167,7 +167,9 @@ def __post_init__(self) -> None: ] ) == 1 - ), ("Only one of `collection_step_num_env_steps` and `episode_per_collect` can be set.",) + ), ( + "Only one of `collection_step_num_env_steps` and `collection_step_num_episodes` can be set.", + ) @dataclass(kw_only=True) @@ -177,7 +179,7 @@ class OnlineTrainingConfig(TrainingConfig): the number of environment steps/transitions to collect in each collection step before the network update within each training step. - This is mutually exclusive with :attr:`episode_per_collect`, and one of the two must be set. + This is mutually exclusive with :attr:`collection_step_num_episodes`, and one of the two must be set. Note that the exact number can be reached only if this is a multiple of the number of training environments being used, as each training environment will produce the same diff --git a/tianshou/highlevel/params/algorithm_params.py b/tianshou/highlevel/params/algorithm_params.py index b408e0205..88fc7a060 100644 --- a/tianshou/highlevel/params/algorithm_params.py +++ b/tianshou/highlevel/params/algorithm_params.py @@ -514,9 +514,11 @@ class NPGParams(ActorCriticOnPolicyParams, ParamsMixinGeneralAdvantageEstimation function but may lead to less reliable advantage estimates. Typically set between 1 and 10, depending on the complexity of the value function. """ - actor_step_size: float = 0.5 + trust_region_size: float = 0.5 """ - the scalar multiplier for policy updates in the natural gradient direction. + the parameter delta - a scalar multiplier for policy updates in the natural gradient direction. + The mathematical meaning is the trust region size, which is the maximum KL divergence + allowed between the old and new policy distributions. Controls how far the policy parameters move in the calculated direction during each update. 
Higher values allow for faster learning but may cause instability or policy deterioration; lower values provide more stable but slower learning. Unlike diff --git a/tianshou/trainer/trainer.py b/tianshou/trainer/trainer.py index 0eabc8543..e31150acf 100644 --- a/tianshou/trainer/trainer.py +++ b/tianshou/trainer/trainer.py @@ -21,6 +21,7 @@ of the policy. Optionally, the performance result can be used to determine whether training shall stop early (see :attr:`TrainerParams.stop_fn`). """ + import logging import time from abc import ABC, abstractmethod @@ -84,15 +85,15 @@ class TrainerParams(ToStringMixin): Training may be stopped early if the stop criterion is met (see :attr:`stop_fn`). For online training, the number of training steps in each epoch is indirectly determined by - :attr:`step_per_epoch`: As many training steps will be performed as are required in - order to reach :attr:`step_per_epoch` total steps in the training environments. + :attr:`epoch_num_steps`: As many training steps will be performed as are required in + order to reach :attr:`epoch_num_steps` total steps in the training environments. Specifically, if the number of transitions collected per step is `c` (see - :attr:`collection_step_num_env_steps`) and :attr:`step_per_epoch` is set to `s`, then the number + :attr:`collection_step_num_env_steps`) and :attr:`epoch_num_steps` is set to `s`, then the number of training steps per epoch is `ceil(s / c)`. Therefore, if `max_epochs = e`, the total number of environment steps taken during training can be computed as `e * ceil(s / c) * c`. - For offline training, the number of training steps per epoch is equal to :attr:`step_per_epoch`. + For offline training, the number of training steps per epoch is equal to :attr:`epoch_num_steps`. 
""" epoch_num_steps: int = 30000 @@ -475,7 +476,7 @@ class _TrainingStepResult(ABC): def get_steps_in_epoch_advancement(self) -> int: """ :return: the number of steps that were done within the epoch, where the concrete semantics - of what a step is depend on the type of algorithm. See docstring of `TrainerParams.step_per_epoch`. + of what a step is depend on the type of algorithm. See docstring of `TrainerParams.epoch_num_steps`. """ @abstractmethod @@ -548,7 +549,7 @@ def execute_epoch(self) -> EpochStats: self._epoch += 1 TraceLogger.log(log, lambda: f"Epoch #{self._epoch} start") - # perform the required number of steps for the epoch (`step_per_epoch`) + # perform the required number of steps for the epoch (`epoch_num_steps`) steps_done_in_this_epoch = 0 train_collect_stats, training_stats = None, None with self._pbar( From 90ec2bd1c26de01934bdfaf95576480e30f372c6 Mon Sep 17 00:00:00 2001 From: Michael Panchenko Date: Fri, 16 May 2025 23:34:29 +0200 Subject: [PATCH 193/230] v1: Removed unused bash script --- examples/mujoco/run_experiments.sh | 12 ------------ 1 file changed, 12 deletions(-) delete mode 100755 examples/mujoco/run_experiments.sh diff --git a/examples/mujoco/run_experiments.sh b/examples/mujoco/run_experiments.sh deleted file mode 100755 index b175fe7c9..000000000 --- a/examples/mujoco/run_experiments.sh +++ /dev/null @@ -1,12 +0,0 @@ -#!/bin/bash - -LOGDIR="results" -TASK=$1 -ALGO=$2 - -echo "Experiments started." -for seed in $(seq 0 9) -do - python mujoco_${ALGO}.py --task $TASK --epoch 200 --seed $seed --logdir $LOGDIR > ${TASK}_`date '+%m-%d-%H-%M-%S'`_seed_$seed.txt 2>&1 & -done -echo "Experiments ended." 
From ba26f8b0b89611f42641ad6ac4a9aacdbc74ada4 Mon Sep 17 00:00:00 2001 From: Dominik Jain Date: Fri, 16 May 2025 23:38:37 +0200 Subject: [PATCH 194/230] v2: Replace hyphenated args in argparsers with snake case args (such that they correspond to the Python identifiers) --- examples/atari/atari_c51.py | 30 ++++++------- examples/atari/atari_dqn.py | 30 ++++++------- examples/atari/atari_fqf.py | 34 +++++++------- examples/atari/atari_iqn.py | 34 +++++++------- examples/atari/atari_ppo.py | 38 ++++++++-------- examples/atari/atari_qrdqn.py | 26 +++++------ examples/atari/atari_rainbow.py | 44 +++++++++---------- examples/atari/atari_sac.py | 32 +++++++------- examples/box2d/acrobot_dualdqn.py | 18 ++++---- examples/box2d/bipedal_bdq.py | 20 ++++----- examples/box2d/bipedal_hardcore_sac.py | 18 ++++---- examples/box2d/lunarlander_dqn.py | 18 ++++---- examples/box2d/mcc_sac.py | 12 ++--- examples/inverse/irl_gail.py | 32 +++++++------- examples/mujoco/analysis.py | 2 +- examples/mujoco/fetch_her_ddpg.py | 28 ++++++------ examples/mujoco/mujoco_a2c.py | 22 +++++----- examples/mujoco/mujoco_ddpg.py | 22 +++++----- examples/mujoco/mujoco_npg.py | 18 ++++---- examples/mujoco/mujoco_ppo.py | 30 ++++++------- examples/mujoco/mujoco_redq.py | 30 ++++++------- examples/mujoco/mujoco_reinforce.py | 14 +++--- examples/mujoco/mujoco_sac.py | 24 +++++----- examples/mujoco/mujoco_td3.py | 28 ++++++------ examples/mujoco/mujoco_trpo.py | 24 +++++----- examples/mujoco/plotter.py | 18 ++++---- examples/mujoco/tools.py | 4 +- examples/offline/atari_bcq.py | 28 ++++++------ examples/offline/atari_cql.py | 26 +++++------ examples/offline/atari_crr.py | 26 +++++------ examples/offline/atari_il.py | 16 +++---- .../offline/convert_rl_unplugged_atari.py | 10 ++--- examples/offline/d4rl_bcq.py | 24 +++++----- examples/offline/d4rl_cql.py | 34 +++++++------- examples/offline/d4rl_il.py | 10 ++--- examples/offline/d4rl_td3_bc.py | 30 ++++++------- examples/vizdoom/vizdoom_c51.py | 34 
+++++++------- examples/vizdoom/vizdoom_ppo.py | 42 +++++++++--------- test/continuous/test_ddpg.py | 16 +++---- test/continuous/test_npg.py | 10 ++--- test/continuous/test_ppo.py | 24 +++++----- test/continuous/test_redq.py | 26 +++++------ test/continuous/test_sac_with_il.py | 24 +++++----- test/continuous/test_td3.py | 22 +++++----- test/continuous/test_trpo.py | 16 +++---- test/discrete/test_a2c_with_il.py | 22 +++++----- test/discrete/test_bdqn.py | 22 +++++----- test/discrete/test_c51.py | 26 +++++------ test/discrete/test_discrete_sac.py | 18 ++++---- test/discrete/test_dqn.py | 18 ++++---- test/discrete/test_drqn.py | 18 ++++---- test/discrete/test_fqf.py | 26 +++++------ test/discrete/test_iqn.py | 26 +++++------ test/discrete/test_ppo_discrete.py | 22 +++++----- test/discrete/test_qrdqn.py | 20 ++++----- test/discrete/test_rainbow.py | 30 ++++++------- test/discrete/test_reinforce.py | 6 +-- test/modelbased/test_dqn_icm.py | 24 +++++----- test/modelbased/test_ppo_icm.py | 28 ++++++------ test/modelbased/test_psrl.py | 10 ++--- test/offline/gather_cartpole_data.py | 22 +++++----- test/offline/gather_pendulum_data.py | 22 +++++----- test/offline/test_bcq.py | 16 +++---- test/offline/test_cql.py | 30 ++++++------- test/offline/test_discrete_bcq.py | 18 ++++---- test/offline/test_discrete_cql.py | 16 +++---- test/offline/test_discrete_crr.py | 10 ++--- test/offline/test_gail.py | 30 ++++++------- test/offline/test_td3_bc.py | 24 +++++----- test/pettingzoo/pistonball.py | 16 +++---- test/pettingzoo/pistonball_continuous.py | 34 +++++++------- test/pettingzoo/tic_tac_toe.py | 22 +++++----- 72 files changed, 822 insertions(+), 822 deletions(-) diff --git a/examples/atari/atari_c51.py b/examples/atari/atari_c51.py index 537d5bdb8..2f2482a3a 100644 --- a/examples/atari/atari_c51.py +++ b/examples/atari/atari_c51.py @@ -23,21 +23,21 @@ def get_args() -> argparse.Namespace: parser.add_argument("--task", type=str, default="PongNoFrameskip-v4") 
parser.add_argument("--seed", type=int, default=0) parser.add_argument("--scale_obs", type=int, default=0) - parser.add_argument("--eps-test", type=float, default=0.005) - parser.add_argument("--eps-train", type=float, default=1.0) - parser.add_argument("--eps-train-final", type=float, default=0.05) - parser.add_argument("--buffer-size", type=int, default=100000) + parser.add_argument("--eps_test", type=float, default=0.005) + parser.add_argument("--eps_train", type=float, default=1.0) + parser.add_argument("--eps_train_final", type=float, default=0.05) + parser.add_argument("--buffer_size", type=int, default=100000) parser.add_argument("--lr", type=float, default=0.0001) parser.add_argument("--gamma", type=float, default=0.99) - parser.add_argument("--num-atoms", type=int, default=51) - parser.add_argument("--v-min", type=float, default=-10.0) - parser.add_argument("--v-max", type=float, default=10.0) - parser.add_argument("--n-step", type=int, default=3) - parser.add_argument("--target-update-freq", type=int, default=500) + parser.add_argument("--num_atoms", type=int, default=51) + parser.add_argument("--v_min", type=float, default=-10.0) + parser.add_argument("--v_max", type=float, default=10.0) + parser.add_argument("--n_step", type=int, default=3) + parser.add_argument("--target_update_freq", type=int, default=500) parser.add_argument("--epoch", type=int, default=100) parser.add_argument("--epoch_num_steps", type=int, default=100000) parser.add_argument("--collection_step_num_env_steps", type=int, default=10) - parser.add_argument("--update-per-step", type=float, default=0.1) + parser.add_argument("--update_per_step", type=float, default=0.1) parser.add_argument("--batch_size", type=int, default=32) parser.add_argument("--num_train_envs", type=int, default=10) parser.add_argument("--num_test_envs", type=int, default=10) @@ -48,23 +48,23 @@ def get_args() -> argparse.Namespace: type=str, default="cuda" if torch.cuda.is_available() else "cpu", ) - 
parser.add_argument("--frames-stack", type=int, default=4) - parser.add_argument("--resume-path", type=str, default=None) - parser.add_argument("--resume-id", type=str, default=None) + parser.add_argument("--frames_stack", type=int, default=4) + parser.add_argument("--resume_path", type=str, default=None) + parser.add_argument("--resume_id", type=str, default=None) parser.add_argument( "--logger", type=str, default="tensorboard", choices=["tensorboard", "wandb"], ) - parser.add_argument("--wandb-project", type=str, default="atari.benchmark") + parser.add_argument("--wandb_project", type=str, default="atari.benchmark") parser.add_argument( "--watch", default=False, action="store_true", help="watch the play of pre-trained policy only", ) - parser.add_argument("--save-buffer-name", type=str, default=None) + parser.add_argument("--save_buffer_name", type=str, default=None) return parser.parse_args() diff --git a/examples/atari/atari_dqn.py b/examples/atari/atari_dqn.py index e23d238f8..a40f60162 100644 --- a/examples/atari/atari_dqn.py +++ b/examples/atari/atari_dqn.py @@ -25,18 +25,18 @@ def get_args() -> argparse.Namespace: parser.add_argument("--task", type=str, default="PongNoFrameskip-v4") parser.add_argument("--seed", type=int, default=0) parser.add_argument("--scale_obs", type=int, default=0) - parser.add_argument("--eps-test", type=float, default=0.005) - parser.add_argument("--eps-train", type=float, default=1.0) - parser.add_argument("--eps-train-final", type=float, default=0.05) - parser.add_argument("--buffer-size", type=int, default=100000) + parser.add_argument("--eps_test", type=float, default=0.005) + parser.add_argument("--eps_train", type=float, default=1.0) + parser.add_argument("--eps_train_final", type=float, default=0.05) + parser.add_argument("--buffer_size", type=int, default=100000) parser.add_argument("--lr", type=float, default=0.0001) parser.add_argument("--gamma", type=float, default=0.99) - parser.add_argument("--n-step", type=int, 
default=3) - parser.add_argument("--target-update-freq", type=int, default=500) + parser.add_argument("--n_step", type=int, default=3) + parser.add_argument("--target_update_freq", type=int, default=500) parser.add_argument("--epoch", type=int, default=100) parser.add_argument("--epoch_num_steps", type=int, default=100000) parser.add_argument("--collection_step_num_env_steps", type=int, default=10) - parser.add_argument("--update-per-step", type=float, default=0.1) + parser.add_argument("--update_per_step", type=float, default=0.1) parser.add_argument("--batch_size", type=int, default=32) parser.add_argument("--num_train_envs", type=int, default=10) parser.add_argument("--num_test_envs", type=int, default=10) @@ -47,37 +47,37 @@ def get_args() -> argparse.Namespace: type=str, default="cuda" if torch.cuda.is_available() else "cpu", ) - parser.add_argument("--frames-stack", type=int, default=4) - parser.add_argument("--resume-path", type=str, default=None) - parser.add_argument("--resume-id", type=str, default=None) + parser.add_argument("--frames_stack", type=int, default=4) + parser.add_argument("--resume_path", type=str, default=None) + parser.add_argument("--resume_id", type=str, default=None) parser.add_argument( "--logger", type=str, default="tensorboard", choices=["tensorboard", "wandb"], ) - parser.add_argument("--wandb-project", type=str, default="atari.benchmark") + parser.add_argument("--wandb_project", type=str, default="atari.benchmark") parser.add_argument( "--watch", default=False, action="store_true", help="watch the play of pre-trained policy only", ) - parser.add_argument("--save-buffer-name", type=str, default=None) + parser.add_argument("--save_buffer_name", type=str, default=None) parser.add_argument( - "--icm-lr-scale", + "--icm_lr_scale", type=float, default=0.0, help="use intrinsic curiosity module with this lr scale", ) parser.add_argument( - "--icm-reward-scale", + "--icm_reward_scale", type=float, default=0.01, help="scaling factor for 
intrinsic curiosity reward", ) parser.add_argument( - "--icm-forward-loss-weight", + "--icm_forward_loss_weight", type=float, default=0.2, help="weight for the forward model loss in ICM", diff --git a/examples/atari/atari_fqf.py b/examples/atari/atari_fqf.py index 63f098b34..d0c768997 100644 --- a/examples/atari/atari_fqf.py +++ b/examples/atari/atari_fqf.py @@ -24,23 +24,23 @@ def get_args() -> argparse.Namespace: parser.add_argument("--task", type=str, default="PongNoFrameskip-v4") parser.add_argument("--seed", type=int, default=3128) parser.add_argument("--scale_obs", type=int, default=0) - parser.add_argument("--eps-test", type=float, default=0.005) - parser.add_argument("--eps-train", type=float, default=1.0) - parser.add_argument("--eps-train-final", type=float, default=0.05) - parser.add_argument("--buffer-size", type=int, default=100000) + parser.add_argument("--eps_test", type=float, default=0.005) + parser.add_argument("--eps_train", type=float, default=1.0) + parser.add_argument("--eps_train_final", type=float, default=0.05) + parser.add_argument("--buffer_size", type=int, default=100000) parser.add_argument("--lr", type=float, default=5e-5) - parser.add_argument("--fraction-lr", type=float, default=2.5e-9) + parser.add_argument("--fraction_lr", type=float, default=2.5e-9) parser.add_argument("--gamma", type=float, default=0.99) - parser.add_argument("--num-fractions", type=int, default=32) - parser.add_argument("--num-cosines", type=int, default=64) - parser.add_argument("--ent-coef", type=float, default=10.0) - parser.add_argument("--hidden-sizes", type=int, nargs="*", default=[512]) - parser.add_argument("--n-step", type=int, default=3) - parser.add_argument("--target-update-freq", type=int, default=500) + parser.add_argument("--num_fractions", type=int, default=32) + parser.add_argument("--num_cosines", type=int, default=64) + parser.add_argument("--ent_coef", type=float, default=10.0) + parser.add_argument("--hidden_sizes", type=int, nargs="*", 
default=[512]) + parser.add_argument("--n_step", type=int, default=3) + parser.add_argument("--target_update_freq", type=int, default=500) parser.add_argument("--epoch", type=int, default=100) parser.add_argument("--epoch_num_steps", type=int, default=100000) parser.add_argument("--collection_step_num_env_steps", type=int, default=10) - parser.add_argument("--update-per-step", type=float, default=0.1) + parser.add_argument("--update_per_step", type=float, default=0.1) parser.add_argument("--batch_size", type=int, default=32) parser.add_argument("--num_train_envs", type=int, default=10) parser.add_argument("--num_test_envs", type=int, default=10) @@ -51,23 +51,23 @@ def get_args() -> argparse.Namespace: type=str, default="cuda" if torch.cuda.is_available() else "cpu", ) - parser.add_argument("--frames-stack", type=int, default=4) - parser.add_argument("--resume-path", type=str, default=None) - parser.add_argument("--resume-id", type=str, default=None) + parser.add_argument("--frames_stack", type=int, default=4) + parser.add_argument("--resume_path", type=str, default=None) + parser.add_argument("--resume_id", type=str, default=None) parser.add_argument( "--logger", type=str, default="tensorboard", choices=["tensorboard", "wandb"], ) - parser.add_argument("--wandb-project", type=str, default="atari.benchmark") + parser.add_argument("--wandb_project", type=str, default="atari.benchmark") parser.add_argument( "--watch", default=False, action="store_true", help="watch the play of pre-trained policy only", ) - parser.add_argument("--save-buffer-name", type=str, default=None) + parser.add_argument("--save_buffer_name", type=str, default=None) return parser.parse_args() diff --git a/examples/atari/atari_iqn.py b/examples/atari/atari_iqn.py index 12ab60420..f489057bb 100644 --- a/examples/atari/atari_iqn.py +++ b/examples/atari/atari_iqn.py @@ -24,23 +24,23 @@ def get_args() -> argparse.Namespace: parser.add_argument("--task", type=str, default="PongNoFrameskip-v4") 
parser.add_argument("--seed", type=int, default=1234) parser.add_argument("--scale_obs", type=int, default=0) - parser.add_argument("--eps-test", type=float, default=0.005) - parser.add_argument("--eps-train", type=float, default=1.0) - parser.add_argument("--eps-train-final", type=float, default=0.05) - parser.add_argument("--buffer-size", type=int, default=100000) + parser.add_argument("--eps_test", type=float, default=0.005) + parser.add_argument("--eps_train", type=float, default=1.0) + parser.add_argument("--eps_train_final", type=float, default=0.05) + parser.add_argument("--buffer_size", type=int, default=100000) parser.add_argument("--lr", type=float, default=0.0001) parser.add_argument("--gamma", type=float, default=0.99) - parser.add_argument("--sample-size", type=int, default=32) - parser.add_argument("--online-sample-size", type=int, default=8) - parser.add_argument("--target-sample-size", type=int, default=8) - parser.add_argument("--num-cosines", type=int, default=64) - parser.add_argument("--hidden-sizes", type=int, nargs="*", default=[512]) - parser.add_argument("--n-step", type=int, default=3) - parser.add_argument("--target-update-freq", type=int, default=500) + parser.add_argument("--sample_size", type=int, default=32) + parser.add_argument("--online_sample_size", type=int, default=8) + parser.add_argument("--target_sample_size", type=int, default=8) + parser.add_argument("--num_cosines", type=int, default=64) + parser.add_argument("--hidden_sizes", type=int, nargs="*", default=[512]) + parser.add_argument("--n_step", type=int, default=3) + parser.add_argument("--target_update_freq", type=int, default=500) parser.add_argument("--epoch", type=int, default=100) parser.add_argument("--epoch_num_steps", type=int, default=100000) parser.add_argument("--collection_step_num_env_steps", type=int, default=10) - parser.add_argument("--update-per-step", type=float, default=0.1) + parser.add_argument("--update_per_step", type=float, default=0.1) 
parser.add_argument("--batch_size", type=int, default=32) parser.add_argument("--num_train_envs", type=int, default=10) parser.add_argument("--num_test_envs", type=int, default=10) @@ -51,23 +51,23 @@ def get_args() -> argparse.Namespace: type=str, default="cuda" if torch.cuda.is_available() else "cpu", ) - parser.add_argument("--frames-stack", type=int, default=4) - parser.add_argument("--resume-path", type=str, default=None) - parser.add_argument("--resume-id", type=str, default=None) + parser.add_argument("--frames_stack", type=int, default=4) + parser.add_argument("--resume_path", type=str, default=None) + parser.add_argument("--resume_id", type=str, default=None) parser.add_argument( "--logger", type=str, default="tensorboard", choices=["tensorboard", "wandb"], ) - parser.add_argument("--wandb-project", type=str, default="atari.benchmark") + parser.add_argument("--wandb_project", type=str, default="atari.benchmark") parser.add_argument( "--watch", default=False, action="store_true", help="watch the play of pre-trained policy only", ) - parser.add_argument("--save-buffer-name", type=str, default=None) + parser.add_argument("--save_buffer_name", type=str, default=None) return parser.parse_args() diff --git a/examples/atari/atari_ppo.py b/examples/atari/atari_ppo.py index b34e61e2d..d49e85258 100644 --- a/examples/atari/atari_ppo.py +++ b/examples/atari/atari_ppo.py @@ -35,7 +35,7 @@ def get_args() -> argparse.Namespace: parser.add_argument("--task", type=str, default="PongNoFrameskip-v4") parser.add_argument("--seed", type=int, default=4213) parser.add_argument("--scale_obs", type=int, default=1) - parser.add_argument("--buffer-size", type=int, default=100000) + parser.add_argument("--buffer_size", type=int, default=100000) parser.add_argument("--lr", type=float, default=2.5e-4) parser.add_argument("--gamma", type=float, default=0.99) parser.add_argument("--epoch", type=int, default=100) @@ -43,20 +43,20 @@ def get_args() -> argparse.Namespace: 
parser.add_argument("--collection_step_num_env_steps", type=int, default=1000) parser.add_argument("--update_step_num_repetitions", type=int, default=4) parser.add_argument("--batch_size", type=int, default=256) - parser.add_argument("--hidden-size", type=int, default=512) + parser.add_argument("--hidden_size", type=int, default=512) parser.add_argument("--num_train_envs", type=int, default=10) parser.add_argument("--num_test_envs", type=int, default=10) parser.add_argument("--return_scaling", type=int, default=False) - parser.add_argument("--vf-coef", type=float, default=0.25) - parser.add_argument("--ent-coef", type=float, default=0.01) - parser.add_argument("--gae-lambda", type=float, default=0.95) - parser.add_argument("--lr-decay", type=int, default=True) - parser.add_argument("--max-grad-norm", type=float, default=0.5) - parser.add_argument("--eps-clip", type=float, default=0.1) - parser.add_argument("--dual-clip", type=float, default=None) - parser.add_argument("--value-clip", type=int, default=1) + parser.add_argument("--vf_coef", type=float, default=0.25) + parser.add_argument("--ent_coef", type=float, default=0.01) + parser.add_argument("--gae_lambda", type=float, default=0.95) + parser.add_argument("--lr_decay", type=int, default=True) + parser.add_argument("--max_grad_norm", type=float, default=0.5) + parser.add_argument("--eps_clip", type=float, default=0.1) + parser.add_argument("--dual_clip", type=float, default=None) + parser.add_argument("--value_clip", type=int, default=1) parser.add_argument("--advantage_normalization", type=int, default=1) - parser.add_argument("--recompute-adv", type=int, default=0) + parser.add_argument("--recompute_adv", type=int, default=0) parser.add_argument("--logdir", type=str, default="log") parser.add_argument("--render", type=float, default=0.0) parser.add_argument( @@ -64,37 +64,37 @@ def get_args() -> argparse.Namespace: type=str, default="cuda" if torch.cuda.is_available() else "cpu", ) - 
parser.add_argument("--frames-stack", type=int, default=4) - parser.add_argument("--resume-path", type=str, default=None) - parser.add_argument("--resume-id", type=str, default=None) + parser.add_argument("--frames_stack", type=int, default=4) + parser.add_argument("--resume_path", type=str, default=None) + parser.add_argument("--resume_id", type=str, default=None) parser.add_argument( "--logger", type=str, default="tensorboard", choices=["tensorboard", "wandb"], ) - parser.add_argument("--wandb-project", type=str, default="atari.benchmark") + parser.add_argument("--wandb_project", type=str, default="atari.benchmark") parser.add_argument( "--watch", default=False, action="store_true", help="watch the play of pre-trained policy only", ) - parser.add_argument("--save-buffer-name", type=str, default=None) + parser.add_argument("--save_buffer_name", type=str, default=None) parser.add_argument( - "--icm-lr-scale", + "--icm_lr_scale", type=float, default=0.0, help="use intrinsic curiosity module with this lr scale", ) parser.add_argument( - "--icm-reward-scale", + "--icm_reward_scale", type=float, default=0.01, help="scaling factor for intrinsic curiosity reward", ) parser.add_argument( - "--icm-forward-loss-weight", + "--icm_forward_loss_weight", type=float, default=0.2, help="weight for the forward model loss in ICM", diff --git a/examples/atari/atari_qrdqn.py b/examples/atari/atari_qrdqn.py index c56c81d26..2b2c84890 100644 --- a/examples/atari/atari_qrdqn.py +++ b/examples/atari/atari_qrdqn.py @@ -23,19 +23,19 @@ def get_args() -> argparse.Namespace: parser.add_argument("--task", type=str, default="PongNoFrameskip-v4") parser.add_argument("--seed", type=int, default=0) parser.add_argument("--scale_obs", type=int, default=0) - parser.add_argument("--eps-test", type=float, default=0.005) - parser.add_argument("--eps-train", type=float, default=1.0) - parser.add_argument("--eps-train-final", type=float, default=0.05) - parser.add_argument("--buffer-size", type=int, 
default=100000) + parser.add_argument("--eps_test", type=float, default=0.005) + parser.add_argument("--eps_train", type=float, default=1.0) + parser.add_argument("--eps_train_final", type=float, default=0.05) + parser.add_argument("--buffer_size", type=int, default=100000) parser.add_argument("--lr", type=float, default=0.0001) parser.add_argument("--gamma", type=float, default=0.99) - parser.add_argument("--num-quantiles", type=int, default=200) - parser.add_argument("--n-step", type=int, default=3) - parser.add_argument("--target-update-freq", type=int, default=500) + parser.add_argument("--num_quantiles", type=int, default=200) + parser.add_argument("--n_step", type=int, default=3) + parser.add_argument("--target_update_freq", type=int, default=500) parser.add_argument("--epoch", type=int, default=100) parser.add_argument("--epoch_num_steps", type=int, default=100000) parser.add_argument("--collection_step_num_env_steps", type=int, default=10) - parser.add_argument("--update-per-step", type=float, default=0.1) + parser.add_argument("--update_per_step", type=float, default=0.1) parser.add_argument("--batch_size", type=int, default=32) parser.add_argument("--num_train_envs", type=int, default=10) parser.add_argument("--num_test_envs", type=int, default=10) @@ -46,23 +46,23 @@ def get_args() -> argparse.Namespace: type=str, default="cuda" if torch.cuda.is_available() else "cpu", ) - parser.add_argument("--frames-stack", type=int, default=4) - parser.add_argument("--resume-path", type=str, default=None) - parser.add_argument("--resume-id", type=str, default=None) + parser.add_argument("--frames_stack", type=int, default=4) + parser.add_argument("--resume_path", type=str, default=None) + parser.add_argument("--resume_id", type=str, default=None) parser.add_argument( "--logger", type=str, default="tensorboard", choices=["tensorboard", "wandb"], ) - parser.add_argument("--wandb-project", type=str, default="atari.benchmark") + parser.add_argument("--wandb_project", 
type=str, default="atari.benchmark") parser.add_argument( "--watch", default=False, action="store_true", help="watch the play of pre-trained policy only", ) - parser.add_argument("--save-buffer-name", type=str, default=None) + parser.add_argument("--save_buffer_name", type=str, default=None) return parser.parse_args() diff --git a/examples/atari/atari_rainbow.py b/examples/atari/atari_rainbow.py index af69814ee..2e6622dce 100644 --- a/examples/atari/atari_rainbow.py +++ b/examples/atari/atari_rainbow.py @@ -28,30 +28,30 @@ def get_args() -> argparse.Namespace: parser.add_argument("--task", type=str, default="PongNoFrameskip-v4") parser.add_argument("--seed", type=int, default=0) parser.add_argument("--scale_obs", type=int, default=0) - parser.add_argument("--eps-test", type=float, default=0.005) - parser.add_argument("--eps-train", type=float, default=1.0) - parser.add_argument("--eps-train-final", type=float, default=0.05) - parser.add_argument("--buffer-size", type=int, default=100000) + parser.add_argument("--eps_test", type=float, default=0.005) + parser.add_argument("--eps_train", type=float, default=1.0) + parser.add_argument("--eps_train_final", type=float, default=0.05) + parser.add_argument("--buffer_size", type=int, default=100000) parser.add_argument("--lr", type=float, default=0.0000625) parser.add_argument("--gamma", type=float, default=0.99) - parser.add_argument("--num-atoms", type=int, default=51) - parser.add_argument("--v-min", type=float, default=-10.0) - parser.add_argument("--v-max", type=float, default=10.0) - parser.add_argument("--noisy-std", type=float, default=0.1) - parser.add_argument("--no-dueling", action="store_true", default=False) - parser.add_argument("--no-noisy", action="store_true", default=False) - parser.add_argument("--no-priority", action="store_true", default=False) + parser.add_argument("--num_atoms", type=int, default=51) + parser.add_argument("--v_min", type=float, default=-10.0) + parser.add_argument("--v_max", 
type=float, default=10.0) + parser.add_argument("--noisy_std", type=float, default=0.1) + parser.add_argument("--no_dueling", action="store_true", default=False) + parser.add_argument("--no_noisy", action="store_true", default=False) + parser.add_argument("--no_priority", action="store_true", default=False) parser.add_argument("--alpha", type=float, default=0.5) parser.add_argument("--beta", type=float, default=0.4) - parser.add_argument("--beta-final", type=float, default=1.0) - parser.add_argument("--beta-anneal-step", type=int, default=5000000) - parser.add_argument("--no-weight-norm", action="store_true", default=False) - parser.add_argument("--n-step", type=int, default=3) - parser.add_argument("--target-update-freq", type=int, default=500) + parser.add_argument("--beta_final", type=float, default=1.0) + parser.add_argument("--beta_anneal_step", type=int, default=5000000) + parser.add_argument("--no_weight_norm", action="store_true", default=False) + parser.add_argument("--n_step", type=int, default=3) + parser.add_argument("--target_update_freq", type=int, default=500) parser.add_argument("--epoch", type=int, default=100) parser.add_argument("--epoch_num_steps", type=int, default=100000) parser.add_argument("--collection_step_num_env_steps", type=int, default=10) - parser.add_argument("--update-per-step", type=float, default=0.1) + parser.add_argument("--update_per_step", type=float, default=0.1) parser.add_argument("--batch_size", type=int, default=32) parser.add_argument("--num_train_envs", type=int, default=10) parser.add_argument("--num_test_envs", type=int, default=10) @@ -62,23 +62,23 @@ def get_args() -> argparse.Namespace: type=str, default="cuda" if torch.cuda.is_available() else "cpu", ) - parser.add_argument("--frames-stack", type=int, default=4) - parser.add_argument("--resume-path", type=str, default=None) - parser.add_argument("--resume-id", type=str, default=None) + parser.add_argument("--frames_stack", type=int, default=4) + 
parser.add_argument("--resume_path", type=str, default=None) + parser.add_argument("--resume_id", type=str, default=None) parser.add_argument( "--logger", type=str, default="tensorboard", choices=["tensorboard", "wandb"], ) - parser.add_argument("--wandb-project", type=str, default="atari.benchmark") + parser.add_argument("--wandb_project", type=str, default="atari.benchmark") parser.add_argument( "--watch", default=False, action="store_true", help="watch the play of pre-trained policy only", ) - parser.add_argument("--save-buffer-name", type=str, default=None) + parser.add_argument("--save_buffer_name", type=str, default=None) return parser.parse_args() diff --git a/examples/atari/atari_sac.py b/examples/atari/atari_sac.py index 752f2888c..faac23bf2 100644 --- a/examples/atari/atari_sac.py +++ b/examples/atari/atari_sac.py @@ -29,21 +29,21 @@ def get_args() -> argparse.Namespace: parser.add_argument("--task", type=str, default="PongNoFrameskip-v4") parser.add_argument("--seed", type=int, default=4213) parser.add_argument("--scale_obs", type=int, default=0) - parser.add_argument("--buffer-size", type=int, default=100000) - parser.add_argument("--actor-lr", type=float, default=1e-5) - parser.add_argument("--critic-lr", type=float, default=1e-5) + parser.add_argument("--buffer_size", type=int, default=100000) + parser.add_argument("--actor_lr", type=float, default=1e-5) + parser.add_argument("--critic_lr", type=float, default=1e-5) parser.add_argument("--gamma", type=float, default=0.99) - parser.add_argument("--n-step", type=int, default=3) + parser.add_argument("--n_step", type=int, default=3) parser.add_argument("--tau", type=float, default=0.005) parser.add_argument("--alpha", type=float, default=0.05) - parser.add_argument("--auto-alpha", action="store_true", default=False) - parser.add_argument("--alpha-lr", type=float, default=3e-4) + parser.add_argument("--auto_alpha", action="store_true", default=False) + parser.add_argument("--alpha_lr", type=float, 
default=3e-4) parser.add_argument("--epoch", type=int, default=100) parser.add_argument("--epoch_num_steps", type=int, default=100000) parser.add_argument("--collection_step_num_env_steps", type=int, default=10) - parser.add_argument("--update-per-step", type=float, default=0.1) + parser.add_argument("--update_per_step", type=float, default=0.1) parser.add_argument("--batch_size", type=int, default=64) - parser.add_argument("--hidden-size", type=int, default=512) + parser.add_argument("--hidden_size", type=int, default=512) parser.add_argument("--num_train_envs", type=int, default=10) parser.add_argument("--num_test_envs", type=int, default=10) parser.add_argument("--return_scaling", type=int, default=False) @@ -54,37 +54,37 @@ def get_args() -> argparse.Namespace: type=str, default="cuda" if torch.cuda.is_available() else "cpu", ) - parser.add_argument("--frames-stack", type=int, default=4) - parser.add_argument("--resume-path", type=str, default=None) - parser.add_argument("--resume-id", type=str, default=None) + parser.add_argument("--frames_stack", type=int, default=4) + parser.add_argument("--resume_path", type=str, default=None) + parser.add_argument("--resume_id", type=str, default=None) parser.add_argument( "--logger", type=str, default="tensorboard", choices=["tensorboard", "wandb"], ) - parser.add_argument("--wandb-project", type=str, default="atari.benchmark") + parser.add_argument("--wandb_project", type=str, default="atari.benchmark") parser.add_argument( "--watch", default=False, action="store_true", help="watch the play of pre-trained policy only", ) - parser.add_argument("--save-buffer-name", type=str, default=None) + parser.add_argument("--save_buffer_name", type=str, default=None) parser.add_argument( - "--icm-lr-scale", + "--icm_lr_scale", type=float, default=0.0, help="use intrinsic curiosity module with this lr scale", ) parser.add_argument( - "--icm-reward-scale", + "--icm_reward_scale", type=float, default=0.01, help="scaling factor for 
intrinsic curiosity reward", ) parser.add_argument( - "--icm-forward-loss-weight", + "--icm_forward_loss_weight", type=float, default=0.2, help="weight for the forward model loss in ICM", diff --git a/examples/box2d/acrobot_dualdqn.py b/examples/box2d/acrobot_dualdqn.py index 17993c820..60fce546d 100644 --- a/examples/box2d/acrobot_dualdqn.py +++ b/examples/box2d/acrobot_dualdqn.py @@ -23,21 +23,21 @@ def get_args() -> argparse.Namespace: parser = argparse.ArgumentParser() parser.add_argument("--task", type=str, default="Acrobot-v1") parser.add_argument("--seed", type=int, default=0) - parser.add_argument("--eps-test", type=float, default=0.05) - parser.add_argument("--eps-train", type=float, default=0.5) - parser.add_argument("--buffer-size", type=int, default=20000) + parser.add_argument("--eps_test", type=float, default=0.05) + parser.add_argument("--eps_train", type=float, default=0.5) + parser.add_argument("--buffer_size", type=int, default=20000) parser.add_argument("--lr", type=float, default=1e-3) parser.add_argument("--gamma", type=float, default=0.95) - parser.add_argument("--n-step", type=int, default=3) - parser.add_argument("--target-update-freq", type=int, default=320) + parser.add_argument("--n_step", type=int, default=3) + parser.add_argument("--target_update_freq", type=int, default=320) parser.add_argument("--epoch", type=int, default=10) parser.add_argument("--epoch_num_steps", type=int, default=100000) parser.add_argument("--collection_step_num_env_steps", type=int, default=100) - parser.add_argument("--update-per-step", type=float, default=0.01) + parser.add_argument("--update_per_step", type=float, default=0.01) parser.add_argument("--batch_size", type=int, default=64) - parser.add_argument("--hidden-sizes", type=int, nargs="*", default=[128]) - parser.add_argument("--dueling-q-hidden-sizes", type=int, nargs="*", default=[128, 128]) - parser.add_argument("--dueling-v-hidden-sizes", type=int, nargs="*", default=[128, 128]) + 
parser.add_argument("--hidden_sizes", type=int, nargs="*", default=[128]) + parser.add_argument("--dueling_q_hidden_sizes", type=int, nargs="*", default=[128, 128]) + parser.add_argument("--dueling_v_hidden_sizes", type=int, nargs="*", default=[128, 128]) parser.add_argument("--num_train_envs", type=int, default=10) parser.add_argument("--num_test_envs", type=int, default=10) parser.add_argument("--logdir", type=str, default="log") diff --git a/examples/box2d/bipedal_bdq.py b/examples/box2d/bipedal_bdq.py index e75f142a7..522bac6c7 100644 --- a/examples/box2d/bipedal_bdq.py +++ b/examples/box2d/bipedal_bdq.py @@ -24,23 +24,23 @@ def get_args() -> argparse.Namespace: # task parser.add_argument("--task", type=str, default="BipedalWalker-v3") # network architecture - parser.add_argument("--common-hidden-sizes", type=int, nargs="*", default=[512, 256]) - parser.add_argument("--action-hidden-sizes", type=int, nargs="*", default=[128]) - parser.add_argument("--value-hidden-sizes", type=int, nargs="*", default=[128]) - parser.add_argument("--action-per-branch", type=int, default=25) + parser.add_argument("--common_hidden_sizes", type=int, nargs="*", default=[512, 256]) + parser.add_argument("--action_hidden_sizes", type=int, nargs="*", default=[128]) + parser.add_argument("--value_hidden_sizes", type=int, nargs="*", default=[128]) + parser.add_argument("--action_per_branch", type=int, default=25) # training hyperparameters parser.add_argument("--seed", type=int, default=0) - parser.add_argument("--eps-test", type=float, default=0.0) - parser.add_argument("--eps-train", type=float, default=0.73) - parser.add_argument("--eps-decay", type=float, default=5e-6) - parser.add_argument("--buffer-size", type=int, default=100000) + parser.add_argument("--eps_test", type=float, default=0.0) + parser.add_argument("--eps_train", type=float, default=0.73) + parser.add_argument("--eps_decay", type=float, default=5e-6) + parser.add_argument("--buffer_size", type=int, default=100000) 
parser.add_argument("--lr", type=float, default=1e-4) parser.add_argument("--gamma", type=float, default=0.99) - parser.add_argument("--target-update-freq", type=int, default=1000) + parser.add_argument("--target_update_freq", type=int, default=1000) parser.add_argument("--epoch", type=int, default=25) parser.add_argument("--epoch_num_steps", type=int, default=80000) parser.add_argument("--collection_step_num_env_steps", type=int, default=16) - parser.add_argument("--update-per-step", type=float, default=0.0625) + parser.add_argument("--update_per_step", type=float, default=0.0625) parser.add_argument("--batch_size", type=int, default=512) parser.add_argument("--num_train_envs", type=int, default=20) parser.add_argument("--num_test_envs", type=int, default=10) diff --git a/examples/box2d/bipedal_hardcore_sac.py b/examples/box2d/bipedal_hardcore_sac.py index d98e8613b..a1a1e90cb 100644 --- a/examples/box2d/bipedal_hardcore_sac.py +++ b/examples/box2d/bipedal_hardcore_sac.py @@ -26,31 +26,31 @@ def get_args() -> argparse.Namespace: parser = argparse.ArgumentParser() parser.add_argument("--task", type=str, default="BipedalWalkerHardcore-v3") parser.add_argument("--seed", type=int, default=0) - parser.add_argument("--buffer-size", type=int, default=1000000) - parser.add_argument("--actor-lr", type=float, default=3e-4) - parser.add_argument("--critic-lr", type=float, default=1e-3) + parser.add_argument("--buffer_size", type=int, default=1000000) + parser.add_argument("--actor_lr", type=float, default=3e-4) + parser.add_argument("--critic_lr", type=float, default=1e-3) parser.add_argument("--gamma", type=float, default=0.99) parser.add_argument("--tau", type=float, default=0.005) parser.add_argument("--alpha", type=float, default=0.1) - parser.add_argument("--auto-alpha", type=int, default=1) - parser.add_argument("--alpha-lr", type=float, default=3e-4) + parser.add_argument("--auto_alpha", type=int, default=1) + parser.add_argument("--alpha_lr", type=float, 
default=3e-4) parser.add_argument("--epoch", type=int, default=100) parser.add_argument("--epoch_num_steps", type=int, default=100000) parser.add_argument("--collection_step_num_env_steps", type=int, default=10) - parser.add_argument("--update-per-step", type=float, default=0.1) + parser.add_argument("--update_per_step", type=float, default=0.1) parser.add_argument("--batch_size", type=int, default=128) - parser.add_argument("--hidden-sizes", type=int, nargs="*", default=[128, 128]) + parser.add_argument("--hidden_sizes", type=int, nargs="*", default=[128, 128]) parser.add_argument("--num_train_envs", type=int, default=10) parser.add_argument("--num_test_envs", type=int, default=10) parser.add_argument("--logdir", type=str, default="log") parser.add_argument("--render", type=float, default=0.0) - parser.add_argument("--n-step", type=int, default=4) + parser.add_argument("--n_step", type=int, default=4) parser.add_argument( "--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu", ) - parser.add_argument("--resume-path", type=str, default=None) + parser.add_argument("--resume_path", type=str, default=None) return parser.parse_args() diff --git a/examples/box2d/lunarlander_dqn.py b/examples/box2d/lunarlander_dqn.py index d32105fc4..99b9bcbc2 100644 --- a/examples/box2d/lunarlander_dqn.py +++ b/examples/box2d/lunarlander_dqn.py @@ -24,21 +24,21 @@ def get_args() -> argparse.Namespace: # the parameters are found by Optuna parser.add_argument("--task", type=str, default="LunarLander-v2") parser.add_argument("--seed", type=int, default=0) - parser.add_argument("--eps-test", type=float, default=0.01) - parser.add_argument("--eps-train", type=float, default=0.73) - parser.add_argument("--buffer-size", type=int, default=100000) + parser.add_argument("--eps_test", type=float, default=0.01) + parser.add_argument("--eps_train", type=float, default=0.73) + parser.add_argument("--buffer_size", type=int, default=100000) parser.add_argument("--lr", type=float, 
default=0.013) parser.add_argument("--gamma", type=float, default=0.99) - parser.add_argument("--n-step", type=int, default=4) - parser.add_argument("--target-update-freq", type=int, default=500) + parser.add_argument("--n_step", type=int, default=4) + parser.add_argument("--target_update_freq", type=int, default=500) parser.add_argument("--epoch", type=int, default=10) parser.add_argument("--epoch_num_steps", type=int, default=80000) parser.add_argument("--collection_step_num_env_steps", type=int, default=16) - parser.add_argument("--update-per-step", type=float, default=0.0625) + parser.add_argument("--update_per_step", type=float, default=0.0625) parser.add_argument("--batch_size", type=int, default=128) - parser.add_argument("--hidden-sizes", type=int, nargs="*", default=[128, 128]) - parser.add_argument("--dueling-q-hidden-sizes", type=int, nargs="*", default=[128, 128]) - parser.add_argument("--dueling-v-hidden-sizes", type=int, nargs="*", default=[128, 128]) + parser.add_argument("--hidden_sizes", type=int, nargs="*", default=[128, 128]) + parser.add_argument("--dueling_q_hidden_sizes", type=int, nargs="*", default=[128, 128]) + parser.add_argument("--dueling_v_hidden_sizes", type=int, nargs="*", default=[128, 128]) parser.add_argument("--num_train_envs", type=int, default=16) parser.add_argument("--num_test_envs", type=int, default=10) parser.add_argument("--logdir", type=str, default="log") diff --git a/examples/box2d/mcc_sac.py b/examples/box2d/mcc_sac.py index 3fd57599c..b2c8265a8 100644 --- a/examples/box2d/mcc_sac.py +++ b/examples/box2d/mcc_sac.py @@ -25,10 +25,10 @@ def get_args() -> argparse.Namespace: parser = argparse.ArgumentParser() parser.add_argument("--task", type=str, default="MountainCarContinuous-v0") parser.add_argument("--seed", type=int, default=1626) - parser.add_argument("--buffer-size", type=int, default=50000) - parser.add_argument("--actor-lr", type=float, default=3e-4) - parser.add_argument("--critic-lr", type=float, default=3e-4) 
- parser.add_argument("--alpha-lr", type=float, default=3e-4) + parser.add_argument("--buffer_size", type=int, default=50000) + parser.add_argument("--actor_lr", type=float, default=3e-4) + parser.add_argument("--critic_lr", type=float, default=3e-4) + parser.add_argument("--alpha_lr", type=float, default=3e-4) parser.add_argument("--noise_std", type=float, default=1.2) parser.add_argument("--gamma", type=float, default=0.99) parser.add_argument("--tau", type=float, default=0.005) @@ -37,9 +37,9 @@ def get_args() -> argparse.Namespace: parser.add_argument("--epoch", type=int, default=20) parser.add_argument("--epoch_num_steps", type=int, default=12000) parser.add_argument("--collection_step_num_env_steps", type=int, default=5) - parser.add_argument("--update-per-step", type=float, default=0.2) + parser.add_argument("--update_per_step", type=float, default=0.2) parser.add_argument("--batch_size", type=int, default=128) - parser.add_argument("--hidden-sizes", type=int, nargs="*", default=[128, 128]) + parser.add_argument("--hidden_sizes", type=int, nargs="*", default=[128, 128]) parser.add_argument("--num_train_envs", type=int, default=5) parser.add_argument("--num_test_envs", type=int, default=10) parser.add_argument("--logdir", type=str, default="log") diff --git a/examples/inverse/irl_gail.py b/examples/inverse/irl_gail.py index fccb6453e..e16f32483 100644 --- a/examples/inverse/irl_gail.py +++ b/examples/inverse/irl_gail.py @@ -52,34 +52,34 @@ def get_args() -> argparse.Namespace: parser = argparse.ArgumentParser() parser.add_argument("--task", type=str, default="HalfCheetah-v2") parser.add_argument("--seed", type=int, default=0) - parser.add_argument("--expert-data-task", type=str, default="halfcheetah-expert-v2") - parser.add_argument("--buffer-size", type=int, default=4096) - parser.add_argument("--hidden-sizes", type=int, nargs="*", default=[64, 64]) + parser.add_argument("--expert_data_task", type=str, default="halfcheetah-expert-v2") + 
parser.add_argument("--buffer_size", type=int, default=4096) + parser.add_argument("--hidden_sizes", type=int, nargs="*", default=[64, 64]) parser.add_argument("--lr", type=float, default=3e-4) - parser.add_argument("--disc-lr", type=float, default=2.5e-5) + parser.add_argument("--disc_lr", type=float, default=2.5e-5) parser.add_argument("--gamma", type=float, default=0.99) parser.add_argument("--epoch", type=int, default=100) parser.add_argument("--epoch_num_steps", type=int, default=30000) parser.add_argument("--collection_step_num_env_steps", type=int, default=2048) parser.add_argument("--update_step_num_repetitions", type=int, default=10) - parser.add_argument("--disc-update-num", type=int, default=2) + parser.add_argument("--disc_update_num", type=int, default=2) parser.add_argument("--batch_size", type=int, default=64) parser.add_argument("--num_train_envs", type=int, default=64) parser.add_argument("--num_test_envs", type=int, default=10) # ppo special parser.add_argument("--return_scaling", type=int, default=True) # In theory, `vf-coef` will not make any difference if using Adam optimizer. 
- parser.add_argument("--vf-coef", type=float, default=0.25) - parser.add_argument("--ent-coef", type=float, default=0.001) - parser.add_argument("--gae-lambda", type=float, default=0.95) - parser.add_argument("--bound-action-method", type=str, default="clip") - parser.add_argument("--lr-decay", type=int, default=True) - parser.add_argument("--max-grad-norm", type=float, default=0.5) - parser.add_argument("--eps-clip", type=float, default=0.2) - parser.add_argument("--dual-clip", type=float, default=None) - parser.add_argument("--value-clip", type=int, default=0) + parser.add_argument("--vf_coef", type=float, default=0.25) + parser.add_argument("--ent_coef", type=float, default=0.001) + parser.add_argument("--gae_lambda", type=float, default=0.95) + parser.add_argument("--bound_action_method", type=str, default="clip") + parser.add_argument("--lr_decay", type=int, default=True) + parser.add_argument("--max_grad_norm", type=float, default=0.5) + parser.add_argument("--eps_clip", type=float, default=0.2) + parser.add_argument("--dual_clip", type=float, default=None) + parser.add_argument("--value_clip", type=int, default=0) parser.add_argument("--advantage_normalization", type=int, default=0) - parser.add_argument("--recompute-adv", type=int, default=1) + parser.add_argument("--recompute_adv", type=int, default=1) parser.add_argument("--logdir", type=str, default="log") parser.add_argument("--render", type=float, default=0.0) parser.add_argument( @@ -87,7 +87,7 @@ def get_args() -> argparse.Namespace: type=str, default="cuda" if torch.cuda.is_available() else "cpu", ) - parser.add_argument("--resume-path", type=str, default=None) + parser.add_argument("--resume_path", type=str, default=None) parser.add_argument( "--watch", default=False, diff --git a/examples/mujoco/analysis.py b/examples/mujoco/analysis.py index b881cdd34..3bd40f4ad 100755 --- a/examples/mujoco/analysis.py +++ b/examples/mujoco/analysis.py @@ -89,7 +89,7 @@ def numerical_analysis(root_dir: str | 
PathLike, xlim: float, norm: bool = False default=1000000, help="x-axis limitation (default: 1000000)", ) - parser.add_argument("--root-dir", type=str) + parser.add_argument("--root_dir", type=str) parser.add_argument( "--norm", action="store_true", diff --git a/examples/mujoco/fetch_her_ddpg.py b/examples/mujoco/fetch_her_ddpg.py index 84953c705..bd5848d51 100644 --- a/examples/mujoco/fetch_her_ddpg.py +++ b/examples/mujoco/fetch_her_ddpg.py @@ -36,23 +36,23 @@ def get_args() -> argparse.Namespace: parser = argparse.ArgumentParser() parser.add_argument("--task", type=str, default="FetchReach-v2") parser.add_argument("--seed", type=int, default=0) - parser.add_argument("--buffer-size", type=int, default=100000) - parser.add_argument("--hidden-sizes", type=int, nargs="*", default=[256, 256]) - parser.add_argument("--actor-lr", type=float, default=1e-3) - parser.add_argument("--critic-lr", type=float, default=3e-3) + parser.add_argument("--buffer_size", type=int, default=100000) + parser.add_argument("--hidden_sizes", type=int, nargs="*", default=[256, 256]) + parser.add_argument("--actor_lr", type=float, default=1e-3) + parser.add_argument("--critic_lr", type=float, default=3e-3) parser.add_argument("--gamma", type=float, default=0.99) parser.add_argument("--tau", type=float, default=0.005) - parser.add_argument("--exploration-noise", type=float, default=0.1) - parser.add_argument("--start-timesteps", type=int, default=25000) + parser.add_argument("--exploration_noise", type=float, default=0.1) + parser.add_argument("--start_timesteps", type=int, default=25000) parser.add_argument("--epoch", type=int, default=10) parser.add_argument("--epoch_num_steps", type=int, default=5000) parser.add_argument("--collection_step_num_env_steps", type=int, default=1) - parser.add_argument("--update-per-step", type=int, default=1) - parser.add_argument("--n-step", type=int, default=1) + parser.add_argument("--update_per_step", type=int, default=1) + parser.add_argument("--n_step", 
type=int, default=1) parser.add_argument("--batch_size", type=int, default=512) - parser.add_argument("--replay-buffer", type=str, default="her", choices=["normal", "her"]) - parser.add_argument("--her-horizon", type=int, default=50) - parser.add_argument("--her-future-k", type=int, default=8) + parser.add_argument("--replay_buffer", type=str, default="her", choices=["normal", "her"]) + parser.add_argument("--her_horizon", type=int, default=50) + parser.add_argument("--her_future_k", type=int, default=8) parser.add_argument("--num_train_envs", type=int, default=1) parser.add_argument("--num_test_envs", type=int, default=10) parser.add_argument("--logdir", type=str, default="log") @@ -62,15 +62,15 @@ def get_args() -> argparse.Namespace: type=str, default="cuda" if torch.cuda.is_available() else "cpu", ) - parser.add_argument("--resume-path", type=str, default=None) - parser.add_argument("--resume-id", type=str, default=None) + parser.add_argument("--resume_path", type=str, default=None) + parser.add_argument("--resume_id", type=str, default=None) parser.add_argument( "--logger", type=str, default="tensorboard", choices=["tensorboard", "wandb"], ) - parser.add_argument("--wandb-project", type=str, default="HER-benchmark") + parser.add_argument("--wandb_project", type=str, default="HER-benchmark") parser.add_argument( "--watch", default=False, diff --git a/examples/mujoco/mujoco_a2c.py b/examples/mujoco/mujoco_a2c.py index 7a9ae036b..ef50ef0e6 100755 --- a/examples/mujoco/mujoco_a2c.py +++ b/examples/mujoco/mujoco_a2c.py @@ -26,8 +26,8 @@ def get_args() -> argparse.Namespace: parser = argparse.ArgumentParser() parser.add_argument("--task", type=str, default="Ant-v4") parser.add_argument("--seed", type=int, default=0) - parser.add_argument("--buffer-size", type=int, default=4096) - parser.add_argument("--hidden-sizes", type=int, nargs="*", default=[64, 64]) + parser.add_argument("--buffer_size", type=int, default=4096) + parser.add_argument("--hidden_sizes", type=int, 
nargs="*", default=[64, 64]) parser.add_argument("--lr", type=float, default=7e-4) parser.add_argument("--gamma", type=float, default=0.99) parser.add_argument("--epoch", type=int, default=100) @@ -40,12 +40,12 @@ def get_args() -> argparse.Namespace: parser.add_argument("--num_test_envs", type=int, default=10) # a2c special parser.add_argument("--return_scaling", type=int, default=True) - parser.add_argument("--vf-coef", type=float, default=0.5) - parser.add_argument("--ent-coef", type=float, default=0.01) - parser.add_argument("--gae-lambda", type=float, default=0.95) - parser.add_argument("--bound-action-method", type=str, default="clip") - parser.add_argument("--lr-decay", type=int, default=True) - parser.add_argument("--max-grad-norm", type=float, default=0.5) + parser.add_argument("--vf_coef", type=float, default=0.5) + parser.add_argument("--ent_coef", type=float, default=0.01) + parser.add_argument("--gae_lambda", type=float, default=0.95) + parser.add_argument("--bound_action_method", type=str, default="clip") + parser.add_argument("--lr_decay", type=int, default=True) + parser.add_argument("--max_grad_norm", type=float, default=0.5) parser.add_argument("--logdir", type=str, default="log") parser.add_argument("--render", type=float, default=0.0) parser.add_argument( @@ -53,15 +53,15 @@ def get_args() -> argparse.Namespace: type=str, default="cuda" if torch.cuda.is_available() else "cpu", ) - parser.add_argument("--resume-path", type=str, default=None) - parser.add_argument("--resume-id", type=str, default=None) + parser.add_argument("--resume_path", type=str, default=None) + parser.add_argument("--resume_id", type=str, default=None) parser.add_argument( "--logger", type=str, default="tensorboard", choices=["tensorboard", "wandb"], ) - parser.add_argument("--wandb-project", type=str, default="mujoco.benchmark") + parser.add_argument("--wandb_project", type=str, default="mujoco.benchmark") parser.add_argument( "--watch", default=False, diff --git 
a/examples/mujoco/mujoco_ddpg.py b/examples/mujoco/mujoco_ddpg.py index 97d48e7eb..c64bd4329 100755 --- a/examples/mujoco/mujoco_ddpg.py +++ b/examples/mujoco/mujoco_ddpg.py @@ -25,19 +25,19 @@ def get_args() -> argparse.Namespace: parser = argparse.ArgumentParser() parser.add_argument("--task", type=str, default="Ant-v4") parser.add_argument("--seed", type=int, default=0) - parser.add_argument("--buffer-size", type=int, default=1000000) - parser.add_argument("--hidden-sizes", type=int, nargs="*", default=[256, 256]) - parser.add_argument("--actor-lr", type=float, default=1e-3) - parser.add_argument("--critic-lr", type=float, default=1e-3) + parser.add_argument("--buffer_size", type=int, default=1000000) + parser.add_argument("--hidden_sizes", type=int, nargs="*", default=[256, 256]) + parser.add_argument("--actor_lr", type=float, default=1e-3) + parser.add_argument("--critic_lr", type=float, default=1e-3) parser.add_argument("--gamma", type=float, default=0.99) parser.add_argument("--tau", type=float, default=0.005) - parser.add_argument("--exploration-noise", type=float, default=0.1) - parser.add_argument("--start-timesteps", type=int, default=25000) + parser.add_argument("--exploration_noise", type=float, default=0.1) + parser.add_argument("--start_timesteps", type=int, default=25000) parser.add_argument("--epoch", type=int, default=200) parser.add_argument("--epoch_num_steps", type=int, default=5000) parser.add_argument("--collection_step_num_env_steps", type=int, default=1) - parser.add_argument("--update-per-step", type=int, default=1) - parser.add_argument("--n-step", type=int, default=1) + parser.add_argument("--update_per_step", type=int, default=1) + parser.add_argument("--n_step", type=int, default=1) parser.add_argument("--batch_size", type=int, default=256) parser.add_argument("--num_train_envs", type=int, default=1) parser.add_argument("--num_test_envs", type=int, default=10) @@ -48,15 +48,15 @@ def get_args() -> argparse.Namespace: type=str, 
default="cuda" if torch.cuda.is_available() else "cpu", ) - parser.add_argument("--resume-path", type=str, default=None) - parser.add_argument("--resume-id", type=str, default=None) + parser.add_argument("--resume_path", type=str, default=None) + parser.add_argument("--resume_id", type=str, default=None) parser.add_argument( "--logger", type=str, default="tensorboard", choices=["tensorboard", "wandb"], ) - parser.add_argument("--wandb-project", type=str, default="mujoco.benchmark") + parser.add_argument("--wandb_project", type=str, default="mujoco.benchmark") parser.add_argument( "--watch", default=False, diff --git a/examples/mujoco/mujoco_npg.py b/examples/mujoco/mujoco_npg.py index 87be75b89..a4215f948 100755 --- a/examples/mujoco/mujoco_npg.py +++ b/examples/mujoco/mujoco_npg.py @@ -26,9 +26,9 @@ def get_args() -> argparse.Namespace: parser = argparse.ArgumentParser() parser.add_argument("--task", type=str, default="Ant-v4") parser.add_argument("--seed", type=int, default=0) - parser.add_argument("--buffer-size", type=int, default=4096) + parser.add_argument("--buffer_size", type=int, default=4096) parser.add_argument( - "--hidden-sizes", + "--hidden_sizes", type=int, nargs="*", default=[64, 64], @@ -45,28 +45,28 @@ def get_args() -> argparse.Namespace: parser.add_argument("--num_test_envs", type=int, default=10) # npg special parser.add_argument("--return_scaling", type=int, default=True) - parser.add_argument("--gae-lambda", type=float, default=0.95) - parser.add_argument("--bound-action-method", type=str, default="clip") - parser.add_argument("--lr-decay", type=int, default=True) + parser.add_argument("--gae_lambda", type=float, default=0.95) + parser.add_argument("--bound_action_method", type=str, default="clip") + parser.add_argument("--lr_decay", type=int, default=True) parser.add_argument("--logdir", type=str, default="log") parser.add_argument("--render", type=float, default=0.0) parser.add_argument("--advantage_normalization", type=int, default=1) - 
parser.add_argument("--optim-critic-iters", type=int, default=20) + parser.add_argument("--optim_critic_iters", type=int, default=20) parser.add_argument("--trust_region_size", type=float, default=0.1) parser.add_argument( "--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu", ) - parser.add_argument("--resume-path", type=str, default=None) - parser.add_argument("--resume-id", type=str, default=None) + parser.add_argument("--resume_path", type=str, default=None) + parser.add_argument("--resume_id", type=str, default=None) parser.add_argument( "--logger", type=str, default="tensorboard", choices=["tensorboard", "wandb"], ) - parser.add_argument("--wandb-project", type=str, default="mujoco.benchmark") + parser.add_argument("--wandb_project", type=str, default="mujoco.benchmark") parser.add_argument( "--watch", default=False, diff --git a/examples/mujoco/mujoco_ppo.py b/examples/mujoco/mujoco_ppo.py index 5feedd1ff..c691aa9e3 100755 --- a/examples/mujoco/mujoco_ppo.py +++ b/examples/mujoco/mujoco_ppo.py @@ -26,8 +26,8 @@ def get_args() -> argparse.Namespace: parser = argparse.ArgumentParser() parser.add_argument("--task", type=str, default="Ant-v4") parser.add_argument("--seed", type=int, default=0) - parser.add_argument("--buffer-size", type=int, default=4096) - parser.add_argument("--hidden-sizes", type=int, nargs="*", default=[64, 64]) + parser.add_argument("--buffer_size", type=int, default=4096) + parser.add_argument("--hidden_sizes", type=int, nargs="*", default=[64, 64]) parser.add_argument("--lr", type=float, default=3e-4) parser.add_argument("--gamma", type=float, default=0.99) parser.add_argument("--epoch", type=int, default=100) @@ -40,17 +40,17 @@ def get_args() -> argparse.Namespace: # ppo special parser.add_argument("--return_scaling", type=int, default=True) # In theory, `vf-coef` will not make any difference if using Adam optimizer. 
- parser.add_argument("--vf-coef", type=float, default=0.25) - parser.add_argument("--ent-coef", type=float, default=0.0) - parser.add_argument("--gae-lambda", type=float, default=0.95) - parser.add_argument("--bound-action-method", type=str, default="clip") - parser.add_argument("--lr-decay", type=int, default=True) - parser.add_argument("--max-grad-norm", type=float, default=0.5) - parser.add_argument("--eps-clip", type=float, default=0.2) - parser.add_argument("--dual-clip", type=float, default=None) - parser.add_argument("--value-clip", type=int, default=0) + parser.add_argument("--vf_coef", type=float, default=0.25) + parser.add_argument("--ent_coef", type=float, default=0.0) + parser.add_argument("--gae_lambda", type=float, default=0.95) + parser.add_argument("--bound_action_method", type=str, default="clip") + parser.add_argument("--lr_decay", type=int, default=True) + parser.add_argument("--max_grad_norm", type=float, default=0.5) + parser.add_argument("--eps_clip", type=float, default=0.2) + parser.add_argument("--dual_clip", type=float, default=None) + parser.add_argument("--value_clip", type=int, default=0) parser.add_argument("--advantage_normalization", type=int, default=0) - parser.add_argument("--recompute-adv", type=int, default=1) + parser.add_argument("--recompute_adv", type=int, default=1) parser.add_argument("--logdir", type=str, default="log") parser.add_argument("--render", type=float, default=0.0) parser.add_argument( @@ -58,15 +58,15 @@ def get_args() -> argparse.Namespace: type=str, default="cuda" if torch.cuda.is_available() else "cpu", ) - parser.add_argument("--resume-path", type=str, default=None) - parser.add_argument("--resume-id", type=str, default=None) + parser.add_argument("--resume_path", type=str, default=None) + parser.add_argument("--resume_id", type=str, default=None) parser.add_argument( "--logger", type=str, default="tensorboard", choices=["tensorboard", "wandb"], ) - parser.add_argument("--wandb-project", type=str, 
default="mujoco.benchmark") + parser.add_argument("--wandb_project", type=str, default="mujoco.benchmark") parser.add_argument( "--watch", default=False, diff --git a/examples/mujoco/mujoco_redq.py b/examples/mujoco/mujoco_redq.py index 8cfffbe53..2b1762659 100755 --- a/examples/mujoco/mujoco_redq.py +++ b/examples/mujoco/mujoco_redq.py @@ -25,25 +25,25 @@ def get_args() -> argparse.Namespace: parser = argparse.ArgumentParser() parser.add_argument("--task", type=str, default="Ant-v4") parser.add_argument("--seed", type=int, default=0) - parser.add_argument("--buffer-size", type=int, default=1000000) - parser.add_argument("--hidden-sizes", type=int, nargs="*", default=[256, 256]) - parser.add_argument("--ensemble-size", type=int, default=10) - parser.add_argument("--subset-size", type=int, default=2) - parser.add_argument("--actor-lr", type=float, default=1e-3) - parser.add_argument("--critic-lr", type=float, default=1e-3) + parser.add_argument("--buffer_size", type=int, default=1000000) + parser.add_argument("--hidden_sizes", type=int, nargs="*", default=[256, 256]) + parser.add_argument("--ensemble_size", type=int, default=10) + parser.add_argument("--subset_size", type=int, default=2) + parser.add_argument("--actor_lr", type=float, default=1e-3) + parser.add_argument("--critic_lr", type=float, default=1e-3) parser.add_argument("--gamma", type=float, default=0.99) parser.add_argument("--tau", type=float, default=0.005) parser.add_argument("--alpha", type=float, default=0.2) - parser.add_argument("--auto-alpha", default=False, action="store_true") - parser.add_argument("--alpha-lr", type=float, default=3e-4) - parser.add_argument("--start-timesteps", type=int, default=10000) + parser.add_argument("--auto_alpha", default=False, action="store_true") + parser.add_argument("--alpha_lr", type=float, default=3e-4) + parser.add_argument("--start_timesteps", type=int, default=10000) parser.add_argument("--epoch", type=int, default=200) 
parser.add_argument("--epoch_num_steps", type=int, default=5000) parser.add_argument("--collection_step_num_env_steps", type=int, default=1) - parser.add_argument("--update-per-step", type=int, default=20) - parser.add_argument("--n-step", type=int, default=1) + parser.add_argument("--update_per_step", type=int, default=20) + parser.add_argument("--n_step", type=int, default=1) parser.add_argument("--batch_size", type=int, default=256) - parser.add_argument("--target-mode", type=str, choices=("min", "mean"), default="min") + parser.add_argument("--target_mode", type=str, choices=("min", "mean"), default="min") parser.add_argument("--num_train_envs", type=int, default=1) parser.add_argument("--num_test_envs", type=int, default=10) parser.add_argument("--logdir", type=str, default="log") @@ -53,15 +53,15 @@ def get_args() -> argparse.Namespace: type=str, default="cuda" if torch.cuda.is_available() else "cpu", ) - parser.add_argument("--resume-path", type=str, default=None) - parser.add_argument("--resume-id", type=str, default=None) + parser.add_argument("--resume_path", type=str, default=None) + parser.add_argument("--resume_id", type=str, default=None) parser.add_argument( "--logger", type=str, default="tensorboard", choices=["tensorboard", "wandb"], ) - parser.add_argument("--wandb-project", type=str, default="mujoco.benchmark") + parser.add_argument("--wandb_project", type=str, default="mujoco.benchmark") parser.add_argument( "--watch", default=False, diff --git a/examples/mujoco/mujoco_reinforce.py b/examples/mujoco/mujoco_reinforce.py index c3ce826ac..b05ecbcc3 100755 --- a/examples/mujoco/mujoco_reinforce.py +++ b/examples/mujoco/mujoco_reinforce.py @@ -26,8 +26,8 @@ def get_args() -> argparse.Namespace: parser = argparse.ArgumentParser() parser.add_argument("--task", type=str, default="Ant-v4") parser.add_argument("--seed", type=int, default=0) - parser.add_argument("--buffer-size", type=int, default=4096) - parser.add_argument("--hidden-sizes", type=int, 
nargs="*", default=[64, 64]) + parser.add_argument("--buffer_size", type=int, default=4096) + parser.add_argument("--hidden_sizes", type=int, nargs="*", default=[64, 64]) parser.add_argument("--lr", type=float, default=1e-3) parser.add_argument("--gamma", type=float, default=0.99) parser.add_argument("--epoch", type=int, default=100) @@ -41,8 +41,8 @@ def get_args() -> argparse.Namespace: # reinforce special parser.add_argument("--return_scaling", type=int, default=True) # "clip" option also works well. - parser.add_argument("--action-bound-method", type=str, default="tanh") - parser.add_argument("--lr-decay", type=int, default=True) + parser.add_argument("--action_bound_method", type=str, default="tanh") + parser.add_argument("--lr_decay", type=int, default=True) parser.add_argument("--logdir", type=str, default="log") parser.add_argument("--render", type=float, default=0.0) parser.add_argument( @@ -50,15 +50,15 @@ def get_args() -> argparse.Namespace: type=str, default="cuda" if torch.cuda.is_available() else "cpu", ) - parser.add_argument("--resume-path", type=str, default=None) - parser.add_argument("--resume-id", type=str, default=None) + parser.add_argument("--resume_path", type=str, default=None) + parser.add_argument("--resume_id", type=str, default=None) parser.add_argument( "--logger", type=str, default="tensorboard", choices=["tensorboard", "wandb"], ) - parser.add_argument("--wandb-project", type=str, default="mujoco.benchmark") + parser.add_argument("--wandb_project", type=str, default="mujoco.benchmark") parser.add_argument( "--watch", default=False, diff --git a/examples/mujoco/mujoco_sac.py b/examples/mujoco/mujoco_sac.py index ac8ca9daa..1b9de80c0 100755 --- a/examples/mujoco/mujoco_sac.py +++ b/examples/mujoco/mujoco_sac.py @@ -24,21 +24,21 @@ def get_args() -> argparse.Namespace: parser = argparse.ArgumentParser() parser.add_argument("--task", type=str, default="Ant-v4") parser.add_argument("--seed", type=int, default=0) - 
parser.add_argument("--buffer-size", type=int, default=1000000) - parser.add_argument("--hidden-sizes", type=int, nargs="*", default=[256, 256]) - parser.add_argument("--actor-lr", type=float, default=1e-3) - parser.add_argument("--critic-lr", type=float, default=1e-3) + parser.add_argument("--buffer_size", type=int, default=1000000) + parser.add_argument("--hidden_sizes", type=int, nargs="*", default=[256, 256]) + parser.add_argument("--actor_lr", type=float, default=1e-3) + parser.add_argument("--critic_lr", type=float, default=1e-3) parser.add_argument("--gamma", type=float, default=0.99) parser.add_argument("--tau", type=float, default=0.005) parser.add_argument("--alpha", type=float, default=0.2) - parser.add_argument("--auto-alpha", default=False, action="store_true") - parser.add_argument("--alpha-lr", type=float, default=3e-4) - parser.add_argument("--start-timesteps", type=int, default=10000) + parser.add_argument("--auto_alpha", default=False, action="store_true") + parser.add_argument("--alpha_lr", type=float, default=3e-4) + parser.add_argument("--start_timesteps", type=int, default=10000) parser.add_argument("--epoch", type=int, default=200) parser.add_argument("--epoch_num_steps", type=int, default=5000) parser.add_argument("--collection_step_num_env_steps", type=int, default=1) - parser.add_argument("--update-per-step", type=int, default=1) - parser.add_argument("--n-step", type=int, default=1) + parser.add_argument("--update_per_step", type=int, default=1) + parser.add_argument("--n_step", type=int, default=1) parser.add_argument("--batch_size", type=int, default=256) parser.add_argument("--num_train_envs", type=int, default=1) parser.add_argument("--num_test_envs", type=int, default=10) @@ -49,15 +49,15 @@ def get_args() -> argparse.Namespace: type=str, default="cuda" if torch.cuda.is_available() else "cpu", ) - parser.add_argument("--resume-path", type=str, default=None) - parser.add_argument("--resume-id", type=str, default=None) + 
parser.add_argument("--resume_path", type=str, default=None) + parser.add_argument("--resume_id", type=str, default=None) parser.add_argument( "--logger", type=str, default="tensorboard", choices=["tensorboard", "wandb"], ) - parser.add_argument("--wandb-project", type=str, default="mujoco.benchmark") + parser.add_argument("--wandb_project", type=str, default="mujoco.benchmark") parser.add_argument( "--watch", default=False, diff --git a/examples/mujoco/mujoco_td3.py b/examples/mujoco/mujoco_td3.py index a1ee04987..7c2df9aab 100755 --- a/examples/mujoco/mujoco_td3.py +++ b/examples/mujoco/mujoco_td3.py @@ -25,22 +25,22 @@ def get_args() -> argparse.Namespace: parser = argparse.ArgumentParser() parser.add_argument("--task", type=str, default="Ant-v4") parser.add_argument("--seed", type=int, default=0) - parser.add_argument("--buffer-size", type=int, default=1000000) - parser.add_argument("--hidden-sizes", type=int, nargs="*", default=[256, 256]) - parser.add_argument("--actor-lr", type=float, default=3e-4) - parser.add_argument("--critic-lr", type=float, default=3e-4) + parser.add_argument("--buffer_size", type=int, default=1000000) + parser.add_argument("--hidden_sizes", type=int, nargs="*", default=[256, 256]) + parser.add_argument("--actor_lr", type=float, default=3e-4) + parser.add_argument("--critic_lr", type=float, default=3e-4) parser.add_argument("--gamma", type=float, default=0.99) parser.add_argument("--tau", type=float, default=0.005) - parser.add_argument("--exploration-noise", type=float, default=0.1) - parser.add_argument("--policy-noise", type=float, default=0.2) - parser.add_argument("--noise-clip", type=float, default=0.5) - parser.add_argument("--update-actor-freq", type=int, default=2) - parser.add_argument("--start-timesteps", type=int, default=25000) + parser.add_argument("--exploration_noise", type=float, default=0.1) + parser.add_argument("--policy_noise", type=float, default=0.2) + parser.add_argument("--noise_clip", type=float, default=0.5) 
+ parser.add_argument("--update_actor_freq", type=int, default=2) + parser.add_argument("--start_timesteps", type=int, default=25000) parser.add_argument("--epoch", type=int, default=200) parser.add_argument("--epoch_num_steps", type=int, default=5000) parser.add_argument("--collection_step_num_env_steps", type=int, default=1) - parser.add_argument("--update-per-step", type=int, default=1) - parser.add_argument("--n-step", type=int, default=1) + parser.add_argument("--update_per_step", type=int, default=1) + parser.add_argument("--n_step", type=int, default=1) parser.add_argument("--batch_size", type=int, default=256) parser.add_argument("--num_train_envs", type=int, default=1) parser.add_argument("--num_test_envs", type=int, default=10) @@ -51,15 +51,15 @@ def get_args() -> argparse.Namespace: type=str, default="cuda" if torch.cuda.is_available() else "cpu", ) - parser.add_argument("--resume-path", type=str, default=None) - parser.add_argument("--resume-id", type=str, default=None) + parser.add_argument("--resume_path", type=str, default=None) + parser.add_argument("--resume_id", type=str, default=None) parser.add_argument( "--logger", type=str, default="tensorboard", choices=["tensorboard", "wandb"], ) - parser.add_argument("--wandb-project", type=str, default="mujoco.benchmark") + parser.add_argument("--wandb_project", type=str, default="mujoco.benchmark") parser.add_argument( "--watch", default=False, diff --git a/examples/mujoco/mujoco_trpo.py b/examples/mujoco/mujoco_trpo.py index a8790bfcc..1f8cbab2a 100755 --- a/examples/mujoco/mujoco_trpo.py +++ b/examples/mujoco/mujoco_trpo.py @@ -26,9 +26,9 @@ def get_args() -> argparse.Namespace: parser = argparse.ArgumentParser() parser.add_argument("--task", type=str, default="Ant-v4") parser.add_argument("--seed", type=int, default=0) - parser.add_argument("--buffer-size", type=int, default=4096) + parser.add_argument("--buffer_size", type=int, default=4096) parser.add_argument( - "--hidden-sizes", + 
"--hidden_sizes", type=int, nargs="*", default=[64, 64], @@ -45,31 +45,31 @@ def get_args() -> argparse.Namespace: parser.add_argument("--num_test_envs", type=int, default=10) # trpo special parser.add_argument("--return_scaling", type=int, default=True) - parser.add_argument("--gae-lambda", type=float, default=0.95) + parser.add_argument("--gae_lambda", type=float, default=0.95) # TODO tanh support - parser.add_argument("--bound-action-method", type=str, default="clip") - parser.add_argument("--lr-decay", type=int, default=True) + parser.add_argument("--bound_action_method", type=str, default="clip") + parser.add_argument("--lr_decay", type=int, default=True) parser.add_argument("--logdir", type=str, default="log") parser.add_argument("--render", type=float, default=0.0) parser.add_argument("--advantage_normalization", type=int, default=1) - parser.add_argument("--optim-critic-iters", type=int, default=20) - parser.add_argument("--max-kl", type=float, default=0.01) - parser.add_argument("--backtrack-coeff", type=float, default=0.8) - parser.add_argument("--max-backtracks", type=int, default=10) + parser.add_argument("--optim_critic_iters", type=int, default=20) + parser.add_argument("--max_kl", type=float, default=0.01) + parser.add_argument("--backtrack_coeff", type=float, default=0.8) + parser.add_argument("--max_backtracks", type=int, default=10) parser.add_argument( "--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu", ) - parser.add_argument("--resume-path", type=str, default=None) - parser.add_argument("--resume-id", type=str, default=None) + parser.add_argument("--resume_path", type=str, default=None) + parser.add_argument("--resume_id", type=str, default=None) parser.add_argument( "--logger", type=str, default="tensorboard", choices=["tensorboard", "wandb"], ) - parser.add_argument("--wandb-project", type=str, default="mujoco.benchmark") + parser.add_argument("--wandb_project", type=str, default="mujoco.benchmark") 
parser.add_argument( "--watch", default=False, diff --git a/examples/mujoco/plotter.py b/examples/mujoco/plotter.py index 5e2f9e016..cb64efc27 100755 --- a/examples/mujoco/plotter.py +++ b/examples/mujoco/plotter.py @@ -180,13 +180,13 @@ def plot_figure( if __name__ == "__main__": parser = argparse.ArgumentParser(description="plotter") parser.add_argument( - "--fig-length", + "--fig_length", type=int, default=6, help="matplotlib figure length (default: 6)", ) parser.add_argument( - "--fig-width", + "--fig_width", type=int, default=6, help="matplotlib figure width (default: 6)", @@ -212,7 +212,7 @@ def plot_figure( parser.add_argument("--xlabel", default="Timesteps", help="matplotlib figure xlabel") parser.add_argument("--ylabel", default="Episode Reward", help="matplotlib figure ylabel") parser.add_argument( - "--shaded-std", + "--shaded_std", action="store_true", help="shaded region corresponding to standard deviation of the group", ) @@ -227,35 +227,35 @@ def plot_figure( help="whether to share y axis within multiple sub-figures", ) parser.add_argument( - "--legend-outside", + "--legend_outside", action="store_true", help="place the legend outside of the figure", ) parser.add_argument("--xlim", type=int, default=None, help="x-axis limitation (default: None)") - parser.add_argument("--root-dir", default="./", help="root dir (default: ./)") + parser.add_argument("--root_dir", default="./", help="root dir (default: ./)") parser.add_argument( - "--file-pattern", + "--file_pattern", type=str, default=r".*/test_rew_\d+seeds.csv$", help="regular expression to determine whether or not to include target csv " "file, default to including all test_rew_{num}seeds.csv file under rootdir", ) parser.add_argument( - "--group-pattern", + "--group_pattern", type=str, default=r"(/|^)\w*?\-v(\d|$)", help="regular expression to group files in sub-figure, default to grouping " 'according to env_name dir, "" means no grouping', ) parser.add_argument( - "--legend-pattern", + 
"--legend_pattern", type=str, default=r".*", help="regular expression to extract legend from csv file path, default to " "using file path as legend name.", ) parser.add_argument("--show", action="store_true", help="show figure") - parser.add_argument("--output-path", type=str, help="figure save path", default="./figure.png") + parser.add_argument("--output_path", type=str, help="figure save path", default="./figure.png") parser.add_argument("--dpi", type=int, default=200, help="figure dpi (default: 200)") args = parser.parse_args() file_lists = find_all_files(args.root_dir, re.compile(args.file_pattern)) diff --git a/examples/mujoco/tools.py b/examples/mujoco/tools.py index e0db8162b..3777ce98c 100755 --- a/examples/mujoco/tools.py +++ b/examples/mujoco/tools.py @@ -128,11 +128,11 @@ def merge_csv( help="Re-generate all csv files instead of using existing one.", ) parser.add_argument( - "--remove-zero", + "--remove_zero", action="store_true", help="Remove the data point of env_step == 0.", ) - parser.add_argument("--root-dir", type=str) + parser.add_argument("--root_dir", type=str) args = parser.parse_args() csv_files = convert_tfevents_to_csv(args.root_dir, args.refresh) diff --git a/examples/offline/atari_bcq.py b/examples/offline/atari_bcq.py index eca748847..d4f6015ca 100644 --- a/examples/offline/atari_bcq.py +++ b/examples/offline/atari_bcq.py @@ -28,44 +28,44 @@ def get_args() -> argparse.Namespace: parser = argparse.ArgumentParser() parser.add_argument("--task", type=str, default="PongNoFrameskip-v4") parser.add_argument("--seed", type=int, default=1626) - parser.add_argument("--eps-test", type=float, default=0.001) + parser.add_argument("--eps_test", type=float, default=0.001) parser.add_argument("--lr", type=float, default=6.25e-5) parser.add_argument("--gamma", type=float, default=0.99) - parser.add_argument("--n-step", type=int, default=1) - parser.add_argument("--target-update-freq", type=int, default=8000) - 
parser.add_argument("--unlikely-action-threshold", type=float, default=0.3) - parser.add_argument("--imitation-logits-penalty", type=float, default=0.01) + parser.add_argument("--n_step", type=int, default=1) + parser.add_argument("--target_update_freq", type=int, default=8000) + parser.add_argument("--unlikely_action_threshold", type=float, default=0.3) + parser.add_argument("--imitation_logits_penalty", type=float, default=0.01) parser.add_argument("--epoch", type=int, default=100) - parser.add_argument("--update-per-epoch", type=int, default=10000) + parser.add_argument("--update_per_epoch", type=int, default=10000) parser.add_argument("--batch_size", type=int, default=32) - parser.add_argument("--hidden-sizes", type=int, nargs="*", default=[512]) + parser.add_argument("--hidden_sizes", type=int, nargs="*", default=[512]) parser.add_argument("--num_test_envs", type=int, default=10) - parser.add_argument("--frames-stack", type=int, default=4) + parser.add_argument("--frames_stack", type=int, default=4) parser.add_argument("--scale_obs", type=int, default=0) parser.add_argument("--logdir", type=str, default="log") parser.add_argument("--render", type=float, default=0.0) - parser.add_argument("--resume-path", type=str, default=None) - parser.add_argument("--resume-id", type=str, default=None) + parser.add_argument("--resume_path", type=str, default=None) + parser.add_argument("--resume_id", type=str, default=None) parser.add_argument( "--logger", type=str, default="tensorboard", choices=["tensorboard", "wandb"], ) - parser.add_argument("--wandb-project", type=str, default="offline_atari.benchmark") + parser.add_argument("--wandb_project", type=str, default="offline_atari.benchmark") parser.add_argument( "--watch", default=False, action="store_true", help="watch the play of pre-trained policy only", ) - parser.add_argument("--log-interval", type=int, default=100) + parser.add_argument("--log_interval", type=int, default=100) parser.add_argument( - 
"--load-buffer-name", + "--load_buffer_name", type=str, default="./expert_DQN_PongNoFrameskip-v4.hdf5", ) - parser.add_argument("--buffer-from-rl-unplugged", action="store_true", default=False) + parser.add_argument("--buffer_from_rl_unplugged", action="store_true", default=False) parser.add_argument( "--device", type=str, diff --git a/examples/offline/atari_cql.py b/examples/offline/atari_cql.py index 9c0dcdd9b..d9a62f056 100644 --- a/examples/offline/atari_cql.py +++ b/examples/offline/atari_cql.py @@ -31,41 +31,41 @@ def get_args() -> argparse.Namespace: parser.add_argument("--seed", type=int, default=1626) parser.add_argument("--lr", type=float, default=0.0001) parser.add_argument("--gamma", type=float, default=0.99) - parser.add_argument("--num-quantiles", type=int, default=200) - parser.add_argument("--n-step", type=int, default=1) - parser.add_argument("--target-update-freq", type=int, default=500) - parser.add_argument("--min-q-weight", type=float, default=10.0) + parser.add_argument("--num_quantiles", type=int, default=200) + parser.add_argument("--n_step", type=int, default=1) + parser.add_argument("--target_update_freq", type=int, default=500) + parser.add_argument("--min_q_weight", type=float, default=10.0) parser.add_argument("--epoch", type=int, default=100) - parser.add_argument("--update-per-epoch", type=int, default=10000) + parser.add_argument("--update_per_epoch", type=int, default=10000) parser.add_argument("--batch_size", type=int, default=32) - parser.add_argument("--hidden-sizes", type=int, nargs="*", default=[512]) + parser.add_argument("--hidden_sizes", type=int, nargs="*", default=[512]) parser.add_argument("--num_test_envs", type=int, default=10) - parser.add_argument("--frames-stack", type=int, default=4) + parser.add_argument("--frames_stack", type=int, default=4) parser.add_argument("--scale_obs", type=int, default=0) parser.add_argument("--logdir", type=str, default="log") parser.add_argument("--render", type=float, default=0.0) - 
parser.add_argument("--resume-path", type=str, default=None) - parser.add_argument("--resume-id", type=str, default=None) + parser.add_argument("--resume_path", type=str, default=None) + parser.add_argument("--resume_id", type=str, default=None) parser.add_argument( "--logger", type=str, default="tensorboard", choices=["tensorboard", "wandb"], ) - parser.add_argument("--wandb-project", type=str, default="offline_atari.benchmark") + parser.add_argument("--wandb_project", type=str, default="offline_atari.benchmark") parser.add_argument( "--watch", default=False, action="store_true", help="watch the play of pre-trained policy only", ) - parser.add_argument("--log-interval", type=int, default=100) + parser.add_argument("--log_interval", type=int, default=100) parser.add_argument( - "--load-buffer-name", + "--load_buffer_name", type=str, default="./expert_DQN_PongNoFrameskip-v4.hdf5", ) - parser.add_argument("--buffer-from-rl-unplugged", action="store_true", default=False) + parser.add_argument("--buffer_from_rl_unplugged", action="store_true", default=False) parser.add_argument( "--device", type=str, diff --git a/examples/offline/atari_crr.py b/examples/offline/atari_crr.py index 87672399d..57da6b6ba 100644 --- a/examples/offline/atari_crr.py +++ b/examples/offline/atari_crr.py @@ -31,42 +31,42 @@ def get_args() -> argparse.Namespace: parser.add_argument("--seed", type=int, default=1626) parser.add_argument("--lr", type=float, default=0.0001) parser.add_argument("--gamma", type=float, default=0.99) - parser.add_argument("--policy-improvement-mode", type=str, default="exp") - parser.add_argument("--ratio-upper-bound", type=float, default=20.0) + parser.add_argument("--policy_improvement_mode", type=str, default="exp") + parser.add_argument("--ratio_upper_bound", type=float, default=20.0) parser.add_argument("--beta", type=float, default=1.0) - parser.add_argument("--min-q-weight", type=float, default=10.0) - parser.add_argument("--target-update-freq", type=int, 
default=500) + parser.add_argument("--min_q_weight", type=float, default=10.0) + parser.add_argument("--target_update_freq", type=int, default=500) parser.add_argument("--epoch", type=int, default=100) - parser.add_argument("--update-per-epoch", type=int, default=10000) + parser.add_argument("--update_per_epoch", type=int, default=10000) parser.add_argument("--batch_size", type=int, default=32) - parser.add_argument("--hidden-sizes", type=int, nargs="*", default=[512]) + parser.add_argument("--hidden_sizes", type=int, nargs="*", default=[512]) parser.add_argument("--num_test_envs", type=int, default=10) - parser.add_argument("--frames-stack", type=int, default=4) + parser.add_argument("--frames_stack", type=int, default=4) parser.add_argument("--scale_obs", type=int, default=0) parser.add_argument("--logdir", type=str, default="log") parser.add_argument("--render", type=float, default=0.0) - parser.add_argument("--resume-path", type=str, default=None) - parser.add_argument("--resume-id", type=str, default=None) + parser.add_argument("--resume_path", type=str, default=None) + parser.add_argument("--resume_id", type=str, default=None) parser.add_argument( "--logger", type=str, default="tensorboard", choices=["tensorboard", "wandb"], ) - parser.add_argument("--wandb-project", type=str, default="offline_atari.benchmark") + parser.add_argument("--wandb_project", type=str, default="offline_atari.benchmark") parser.add_argument( "--watch", default=False, action="store_true", help="watch the play of pre-trained policy only", ) - parser.add_argument("--log-interval", type=int, default=100) + parser.add_argument("--log_interval", type=int, default=100) parser.add_argument( - "--load-buffer-name", + "--load_buffer_name", type=str, default="./expert_DQN_PongNoFrameskip-v4.hdf5", ) - parser.add_argument("--buffer-from-rl-unplugged", action="store_true", default=False) + parser.add_argument("--buffer_from_rl_unplugged", action="store_true", default=False) parser.add_argument( 
"--device", type=str, diff --git a/examples/offline/atari_il.py b/examples/offline/atari_il.py index 1c97e25bc..342d1c3eb 100644 --- a/examples/offline/atari_il.py +++ b/examples/offline/atari_il.py @@ -31,35 +31,35 @@ def get_args() -> argparse.Namespace: parser.add_argument("--seed", type=int, default=1626) parser.add_argument("--lr", type=float, default=0.0001) parser.add_argument("--epoch", type=int, default=100) - parser.add_argument("--update-per-epoch", type=int, default=10000) + parser.add_argument("--update_per_epoch", type=int, default=10000) parser.add_argument("--batch_size", type=int, default=32) parser.add_argument("--num_test_envs", type=int, default=10) - parser.add_argument("--frames-stack", type=int, default=4) + parser.add_argument("--frames_stack", type=int, default=4) parser.add_argument("--scale_obs", type=int, default=0) parser.add_argument("--logdir", type=str, default="log") parser.add_argument("--render", type=float, default=0.0) - parser.add_argument("--resume-path", type=str, default=None) - parser.add_argument("--resume-id", type=str, default=None) + parser.add_argument("--resume_path", type=str, default=None) + parser.add_argument("--resume_id", type=str, default=None) parser.add_argument( "--logger", type=str, default="tensorboard", choices=["tensorboard", "wandb"], ) - parser.add_argument("--wandb-project", type=str, default="offline_atari.benchmark") + parser.add_argument("--wandb_project", type=str, default="offline_atari.benchmark") parser.add_argument( "--watch", default=False, action="store_true", help="watch the play of pre-trained policy only", ) - parser.add_argument("--log-interval", type=int, default=100) + parser.add_argument("--log_interval", type=int, default=100) parser.add_argument( - "--load-buffer-name", + "--load_buffer_name", type=str, default="./expert_DQN_PongNoFrameskip-v4.hdf5", ) - parser.add_argument("--buffer-from-rl-unplugged", action="store_true", default=False) + 
parser.add_argument("--buffer_from_rl_unplugged", action="store_true", default=False) parser.add_argument( "--device", type=str, diff --git a/examples/offline/convert_rl_unplugged_atari.py b/examples/offline/convert_rl_unplugged_atari.py index 1afd721a5..d999ad330 100755 --- a/examples/offline/convert_rl_unplugged_atari.py +++ b/examples/offline/convert_rl_unplugged_atari.py @@ -266,25 +266,25 @@ def main(args: Namespace) -> None: parser = ArgumentParser(usage=__doc__) parser.add_argument("--task", required=True, help="Name of the Atari game.") parser.add_argument( - "--run-id", + "--run_id", type=int, default=1, help="Run id to download and convert. Value in [1..5].", ) parser.add_argument( - "--shard-id", + "--shard_id", type=int, default=0, help="Shard id to download and convert. Value in [0..99].", ) - parser.add_argument("--total-num-shards", type=int, default=100, help="Total number of shards.") + parser.add_argument("--total_num_shards", type=int, default=100, help="Total number of shards.") parser.add_argument( - "--dataset-dir", + "--dataset_dir", default=os.path.expanduser("~/.rl_unplugged/datasets"), help="Directory for converted hdf5 files.", ) parser.add_argument( - "--cache-dir", + "--cache_dir", default=os.path.expanduser("~/.rl_unplugged/cache"), help="Directory for downloaded original datasets.", ) diff --git a/examples/offline/d4rl_bcq.py b/examples/offline/d4rl_bcq.py index b7e1b2673..11399683d 100644 --- a/examples/offline/d4rl_bcq.py +++ b/examples/offline/d4rl_bcq.py @@ -28,23 +28,23 @@ def get_args() -> argparse.Namespace: parser = argparse.ArgumentParser() parser.add_argument("--task", type=str, default="HalfCheetah-v2") parser.add_argument("--seed", type=int, default=0) - parser.add_argument("--expert-data-task", type=str, default="halfcheetah-expert-v2") - parser.add_argument("--buffer-size", type=int, default=1000000) - parser.add_argument("--hidden-sizes", type=int, nargs="*", default=[256, 256]) - parser.add_argument("--actor-lr", 
type=float, default=1e-3) - parser.add_argument("--critic-lr", type=float, default=1e-3) - parser.add_argument("--start-timesteps", type=int, default=10000) + parser.add_argument("--expert_data_task", type=str, default="halfcheetah-expert-v2") + parser.add_argument("--buffer_size", type=int, default=1000000) + parser.add_argument("--hidden_sizes", type=int, nargs="*", default=[256, 256]) + parser.add_argument("--actor_lr", type=float, default=1e-3) + parser.add_argument("--critic_lr", type=float, default=1e-3) + parser.add_argument("--start_timesteps", type=int, default=10000) parser.add_argument("--epoch", type=int, default=200) parser.add_argument("--epoch_num_steps", type=int, default=5000) - parser.add_argument("--n-step", type=int, default=3) + parser.add_argument("--n_step", type=int, default=3) parser.add_argument("--batch_size", type=int, default=256) parser.add_argument("--num_test_envs", type=int, default=10) parser.add_argument("--logdir", type=str, default="log") parser.add_argument("--render", type=float, default=1 / 35) - parser.add_argument("--vae-hidden-sizes", type=int, nargs="*", default=[512, 512]) + parser.add_argument("--vae_hidden_sizes", type=int, nargs="*", default=[512, 512]) # default to 2 * action_dim - parser.add_argument("--latent-dim", type=int) + parser.add_argument("--latent_dim", type=int) parser.add_argument("--gamma", default=0.99) parser.add_argument("--tau", default=0.005) # Weighting for Clipped Double Q-learning in BCQ @@ -56,15 +56,15 @@ def get_args() -> argparse.Namespace: type=str, default="cuda" if torch.cuda.is_available() else "cpu", ) - parser.add_argument("--resume-path", type=str, default=None) - parser.add_argument("--resume-id", type=str, default=None) + parser.add_argument("--resume_path", type=str, default=None) + parser.add_argument("--resume_id", type=str, default=None) parser.add_argument( "--logger", type=str, default="tensorboard", choices=["tensorboard", "wandb"], ) - parser.add_argument("--wandb-project", 
type=str, default="offline_d4rl.benchmark") + parser.add_argument("--wandb_project", type=str, default="offline_d4rl.benchmark") parser.add_argument( "--watch", default=False, diff --git a/examples/offline/d4rl_cql.py b/examples/offline/d4rl_cql.py index 7a13c77ee..8260b837c 100644 --- a/examples/offline/d4rl_cql.py +++ b/examples/offline/d4rl_cql.py @@ -39,32 +39,32 @@ def get_args() -> argparse.Namespace: help="The random seed to use.", ) parser.add_argument( - "--expert-data-task", + "--expert_data_task", type=str, default="hopper-expert-v2", help="The name of the OpenAI Gym environment to use for expert data collection.", ) parser.add_argument( - "--buffer-size", + "--buffer_size", type=int, default=1000000, help="The size of the replay buffer.", ) parser.add_argument( - "--hidden-sizes", + "--hidden_sizes", type=int, nargs="*", default=[256, 256], help="The list of hidden sizes for the neural networks.", ) parser.add_argument( - "--actor-lr", + "--actor_lr", type=float, default=1e-4, help="The learning rate for the actor network.", ) parser.add_argument( - "--critic-lr", + "--critic_lr", type=float, default=3e-4, help="The learning rate for the critic network.", @@ -76,25 +76,25 @@ def get_args() -> argparse.Namespace: help="The weight of the entropy term in the loss function.", ) parser.add_argument( - "--auto-alpha", + "--auto_alpha", default=True, action="store_true", help="Whether to use automatic entropy tuning.", ) parser.add_argument( - "--alpha-lr", + "--alpha_lr", type=float, default=1e-4, help="The learning rate for the entropy tuning.", ) parser.add_argument( - "--cql-alpha-lr", + "--cql_alpha_lr", type=float, default=3e-4, help="The learning rate for the CQL entropy tuning.", ) parser.add_argument( - "--start-timesteps", + "--start_timesteps", type=int, default=10000, help="The number of timesteps before starting to train.", @@ -112,7 +112,7 @@ def get_args() -> argparse.Namespace: help="The number of steps per epoch.", ) parser.add_argument( - 
"--n-step", + "--n_step", type=int, default=3, help="The number of steps to use for N-step TD learning.", @@ -136,13 +136,13 @@ def get_args() -> argparse.Namespace: help="The temperature for the Boltzmann policy.", ) parser.add_argument( - "--cql-weight", + "--cql_weight", type=float, default=1.0, help="The weight of the CQL loss term.", ) parser.add_argument( - "--with-lagrange", + "--with_lagrange", type=bool, default=True, help="Whether to use the Lagrange multiplier for CQL.", @@ -154,14 +154,14 @@ def get_args() -> argparse.Namespace: help="Whether to use calibration for CQL.", ) parser.add_argument( - "--lagrange-threshold", + "--lagrange_threshold", type=float, default=10.0, help="The Lagrange multiplier threshold for CQL.", ) parser.add_argument("--gamma", type=float, default=0.99, help="The discount factor") parser.add_argument( - "--eval-freq", + "--eval_freq", type=int, default=1, help="The frequency of evaluation.", @@ -191,13 +191,13 @@ def get_args() -> argparse.Namespace: help="The device to train on (cpu or cuda).", ) parser.add_argument( - "--resume-path", + "--resume_path", type=str, default=None, help="The path to the checkpoint to resume from.", ) parser.add_argument( - "--resume-id", + "--resume_id", type=str, default=None, help="The ID of the checkpoint to resume from.", @@ -208,7 +208,7 @@ def get_args() -> argparse.Namespace: default="tensorboard", choices=["tensorboard", "wandb"], ) - parser.add_argument("--wandb-project", type=str, default="offline_d4rl.benchmark") + parser.add_argument("--wandb_project", type=str, default="offline_d4rl.benchmark") parser.add_argument( "--watch", default=False, diff --git a/examples/offline/d4rl_il.py b/examples/offline/d4rl_il.py index 7b3ba59aa..c6b620016 100644 --- a/examples/offline/d4rl_il.py +++ b/examples/offline/d4rl_il.py @@ -30,8 +30,8 @@ def get_args() -> argparse.Namespace: parser = argparse.ArgumentParser() parser.add_argument("--task", type=str, default="HalfCheetah-v2") 
parser.add_argument("--seed", type=int, default=0) - parser.add_argument("--expert-data-task", type=str, default="halfcheetah-expert-v2") - parser.add_argument("--hidden-sizes", type=int, nargs="*", default=[256, 256]) + parser.add_argument("--expert_data_task", type=str, default="halfcheetah-expert-v2") + parser.add_argument("--hidden_sizes", type=int, nargs="*", default=[256, 256]) parser.add_argument("--lr", type=float, default=1e-4) parser.add_argument("--epoch", type=int, default=200) parser.add_argument("--epoch_num_steps", type=int, default=5000) @@ -45,15 +45,15 @@ def get_args() -> argparse.Namespace: type=str, default="cuda" if torch.cuda.is_available() else "cpu", ) - parser.add_argument("--resume-path", type=str, default=None) - parser.add_argument("--resume-id", type=str, default=None) + parser.add_argument("--resume_path", type=str, default=None) + parser.add_argument("--resume_id", type=str, default=None) parser.add_argument( "--logger", type=str, default="tensorboard", choices=["tensorboard", "wandb"], ) - parser.add_argument("--wandb-project", type=str, default="offline_d4rl.benchmark") + parser.add_argument("--wandb_project", type=str, default="offline_d4rl.benchmark") parser.add_argument( "--watch", default=False, diff --git a/examples/offline/d4rl_td3_bc.py b/examples/offline/d4rl_td3_bc.py index c2da56413..1d5a69798 100644 --- a/examples/offline/d4rl_td3_bc.py +++ b/examples/offline/d4rl_td3_bc.py @@ -29,26 +29,26 @@ def get_args() -> argparse.Namespace: parser = argparse.ArgumentParser() parser.add_argument("--task", type=str, default="HalfCheetah-v2") parser.add_argument("--seed", type=int, default=0) - parser.add_argument("--expert-data-task", type=str, default="halfcheetah-expert-v2") - parser.add_argument("--buffer-size", type=int, default=1000000) - parser.add_argument("--hidden-sizes", type=int, nargs="*", default=[256, 256]) - parser.add_argument("--actor-lr", type=float, default=3e-4) - parser.add_argument("--critic-lr", type=float, 
default=3e-4) + parser.add_argument("--expert_data_task", type=str, default="halfcheetah-expert-v2") + parser.add_argument("--buffer_size", type=int, default=1000000) + parser.add_argument("--hidden_sizes", type=int, nargs="*", default=[256, 256]) + parser.add_argument("--actor_lr", type=float, default=3e-4) + parser.add_argument("--critic_lr", type=float, default=3e-4) parser.add_argument("--epoch", type=int, default=200) parser.add_argument("--epoch_num_steps", type=int, default=5000) - parser.add_argument("--n-step", type=int, default=3) + parser.add_argument("--n_step", type=int, default=3) parser.add_argument("--batch_size", type=int, default=256) parser.add_argument("--alpha", type=float, default=2.5) - parser.add_argument("--exploration-noise", type=float, default=0.1) - parser.add_argument("--policy-noise", type=float, default=0.2) - parser.add_argument("--noise-clip", type=float, default=0.5) - parser.add_argument("--update-actor-freq", type=int, default=2) + parser.add_argument("--exploration_noise", type=float, default=0.1) + parser.add_argument("--policy_noise", type=float, default=0.2) + parser.add_argument("--noise_clip", type=float, default=0.5) + parser.add_argument("--update_actor_freq", type=int, default=2) parser.add_argument("--tau", type=float, default=0.005) parser.add_argument("--gamma", type=float, default=0.99) - parser.add_argument("--norm-obs", type=int, default=1) + parser.add_argument("--norm_obs", type=int, default=1) - parser.add_argument("--eval-freq", type=int, default=1) + parser.add_argument("--eval_freq", type=int, default=1) parser.add_argument("--num_test_envs", type=int, default=10) parser.add_argument("--logdir", type=str, default="log") parser.add_argument("--render", type=float, default=1 / 35) @@ -57,15 +57,15 @@ def get_args() -> argparse.Namespace: type=str, default="cuda" if torch.cuda.is_available() else "cpu", ) - parser.add_argument("--resume-path", type=str, default=None) - parser.add_argument("--resume-id", 
type=str, default=None) + parser.add_argument("--resume_path", type=str, default=None) + parser.add_argument("--resume_id", type=str, default=None) parser.add_argument( "--logger", type=str, default="tensorboard", choices=["tensorboard", "wandb"], ) - parser.add_argument("--wandb-project", type=str, default="offline_d4rl.benchmark") + parser.add_argument("--wandb_project", type=str, default="offline_d4rl.benchmark") parser.add_argument( "--watch", default=False, diff --git a/examples/vizdoom/vizdoom_c51.py b/examples/vizdoom/vizdoom_c51.py index 6f225d3c3..3e29cf2ed 100644 --- a/examples/vizdoom/vizdoom_c51.py +++ b/examples/vizdoom/vizdoom_c51.py @@ -22,21 +22,21 @@ def get_args() -> argparse.Namespace: parser = argparse.ArgumentParser() parser.add_argument("--task", type=str, default="D1_basic") parser.add_argument("--seed", type=int, default=0) - parser.add_argument("--eps-test", type=float, default=0.005) - parser.add_argument("--eps-train", type=float, default=1.0) - parser.add_argument("--eps-train-final", type=float, default=0.05) - parser.add_argument("--buffer-size", type=int, default=2000000) + parser.add_argument("--eps_test", type=float, default=0.005) + parser.add_argument("--eps_train", type=float, default=1.0) + parser.add_argument("--eps_train_final", type=float, default=0.05) + parser.add_argument("--buffer_size", type=int, default=2000000) parser.add_argument("--lr", type=float, default=0.0001) parser.add_argument("--gamma", type=float, default=0.99) - parser.add_argument("--num-atoms", type=int, default=51) - parser.add_argument("--v-min", type=float, default=-10.0) - parser.add_argument("--v-max", type=float, default=10.0) - parser.add_argument("--n-step", type=int, default=3) - parser.add_argument("--target-update-freq", type=int, default=500) + parser.add_argument("--num_atoms", type=int, default=51) + parser.add_argument("--v_min", type=float, default=-10.0) + parser.add_argument("--v_max", type=float, default=10.0) + 
parser.add_argument("--n_step", type=int, default=3) + parser.add_argument("--target_update_freq", type=int, default=500) parser.add_argument("--epoch", type=int, default=300) parser.add_argument("--epoch_num_steps", type=int, default=100000) parser.add_argument("--collection_step_num_env_steps", type=int, default=10) - parser.add_argument("--update-per-step", type=float, default=0.1) + parser.add_argument("--update_per_step", type=float, default=0.1) parser.add_argument("--batch_size", type=int, default=64) parser.add_argument("--num_train_envs", type=int, default=10) parser.add_argument("--num_test_envs", type=int, default=10) @@ -47,17 +47,17 @@ def get_args() -> argparse.Namespace: type=str, default="cuda" if torch.cuda.is_available() else "cpu", ) - parser.add_argument("--frames-stack", type=int, default=4) - parser.add_argument("--skip-num", type=int, default=4) - parser.add_argument("--resume-path", type=str, default=None) - parser.add_argument("--resume-id", type=str, default=None) + parser.add_argument("--frames_stack", type=int, default=4) + parser.add_argument("--skip_num", type=int, default=4) + parser.add_argument("--resume_path", type=str, default=None) + parser.add_argument("--resume_id", type=str, default=None) parser.add_argument( "--logger", type=str, default="tensorboard", choices=["tensorboard", "wandb"], ) - parser.add_argument("--wandb-project", type=str, default="vizdoom.benchmark") + parser.add_argument("--wandb_project", type=str, default="vizdoom.benchmark") parser.add_argument( "--watch", default=False, @@ -65,12 +65,12 @@ def get_args() -> argparse.Namespace: help="watch the play of pre-trained policy only", ) parser.add_argument( - "--save-lmp", + "--save_lmp", default=False, action="store_true", help="save lmp file for replay whole episode", ) - parser.add_argument("--save-buffer-name", type=str, default=None) + parser.add_argument("--save_buffer_name", type=str, default=None) return parser.parse_args() diff --git 
a/examples/vizdoom/vizdoom_ppo.py b/examples/vizdoom/vizdoom_ppo.py index 2cb67d08b..e3353be30 100644 --- a/examples/vizdoom/vizdoom_ppo.py +++ b/examples/vizdoom/vizdoom_ppo.py @@ -29,7 +29,7 @@ def get_args() -> argparse.Namespace: parser = argparse.ArgumentParser() parser.add_argument("--task", type=str, default="D1_basic") parser.add_argument("--seed", type=int, default=0) - parser.add_argument("--buffer-size", type=int, default=100000) + parser.add_argument("--buffer_size", type=int, default=100000) parser.add_argument("--lr", type=float, default=0.00002) parser.add_argument("--gamma", type=float, default=0.99) parser.add_argument("--epoch", type=int, default=300) @@ -37,20 +37,20 @@ def get_args() -> argparse.Namespace: parser.add_argument("--collection_step_num_env_steps", type=int, default=1000) parser.add_argument("--update_step_num_repetitions", type=int, default=4) parser.add_argument("--batch_size", type=int, default=256) - parser.add_argument("--hidden-size", type=int, default=512) + parser.add_argument("--hidden_size", type=int, default=512) parser.add_argument("--num_train_envs", type=int, default=10) parser.add_argument("--num_test_envs", type=int, default=10) parser.add_argument("--return_scaling", type=int, default=False) - parser.add_argument("--vf-coef", type=float, default=0.5) - parser.add_argument("--ent-coef", type=float, default=0.01) - parser.add_argument("--gae-lambda", type=float, default=0.95) - parser.add_argument("--lr-decay", type=int, default=True) - parser.add_argument("--max-grad-norm", type=float, default=0.5) - parser.add_argument("--eps-clip", type=float, default=0.2) - parser.add_argument("--dual-clip", type=float, default=None) - parser.add_argument("--value-clip", type=int, default=0) + parser.add_argument("--vf_coef", type=float, default=0.5) + parser.add_argument("--ent_coef", type=float, default=0.01) + parser.add_argument("--gae_lambda", type=float, default=0.95) + parser.add_argument("--lr_decay", type=int, 
default=True) + parser.add_argument("--max_grad_norm", type=float, default=0.5) + parser.add_argument("--eps_clip", type=float, default=0.2) + parser.add_argument("--dual_clip", type=float, default=None) + parser.add_argument("--value_clip", type=int, default=0) parser.add_argument("--advantage_normalization", type=int, default=1) - parser.add_argument("--recompute-adv", type=int, default=0) + parser.add_argument("--recompute_adv", type=int, default=0) parser.add_argument("--logdir", type=str, default="log") parser.add_argument("--render", type=float, default=0.0) parser.add_argument( @@ -58,17 +58,17 @@ def get_args() -> argparse.Namespace: type=str, default="cuda" if torch.cuda.is_available() else "cpu", ) - parser.add_argument("--frames-stack", type=int, default=4) - parser.add_argument("--skip-num", type=int, default=4) - parser.add_argument("--resume-path", type=str, default=None) - parser.add_argument("--resume-id", type=str, default=None) + parser.add_argument("--frames_stack", type=int, default=4) + parser.add_argument("--skip_num", type=int, default=4) + parser.add_argument("--resume_path", type=str, default=None) + parser.add_argument("--resume_id", type=str, default=None) parser.add_argument( "--logger", type=str, default="tensorboard", choices=["tensorboard", "wandb"], ) - parser.add_argument("--wandb-project", type=str, default="vizdoom.benchmark") + parser.add_argument("--wandb_project", type=str, default="vizdoom.benchmark") parser.add_argument( "--watch", default=False, @@ -76,26 +76,26 @@ def get_args() -> argparse.Namespace: help="watch the play of pre-trained policy only", ) parser.add_argument( - "--save-lmp", + "--save_lmp", default=False, action="store_true", help="save lmp file for replay whole episode", ) - parser.add_argument("--save-buffer-name", type=str, default=None) + parser.add_argument("--save_buffer_name", type=str, default=None) parser.add_argument( - "--icm-lr-scale", + "--icm_lr_scale", type=float, default=0.0, help="use 
intrinsic curiosity module with this lr scale", ) parser.add_argument( - "--icm-reward-scale", + "--icm_reward_scale", type=float, default=0.01, help="scaling factor for intrinsic curiosity reward", ) parser.add_argument( - "--icm-forward-loss-weight", + "--icm_forward_loss_weight", type=float, default=0.2, help="weight for the forward model loss in ICM", diff --git a/test/continuous/test_ddpg.py b/test/continuous/test_ddpg.py index 69675115d..f48cf90cc 100644 --- a/test/continuous/test_ddpg.py +++ b/test/continuous/test_ddpg.py @@ -24,26 +24,26 @@ def get_args() -> argparse.Namespace: parser = argparse.ArgumentParser() parser.add_argument("--task", type=str, default="Pendulum-v1") - parser.add_argument("--reward-threshold", type=float, default=None) + parser.add_argument("--reward_threshold", type=float, default=None) parser.add_argument("--seed", type=int, default=0) - parser.add_argument("--buffer-size", type=int, default=20000) - parser.add_argument("--actor-lr", type=float, default=1e-4) - parser.add_argument("--critic-lr", type=float, default=1e-3) + parser.add_argument("--buffer_size", type=int, default=20000) + parser.add_argument("--actor_lr", type=float, default=1e-4) + parser.add_argument("--critic_lr", type=float, default=1e-3) parser.add_argument("--gamma", type=float, default=0.99) parser.add_argument("--tau", type=float, default=0.005) - parser.add_argument("--exploration-noise", type=float, default=0.1) + parser.add_argument("--exploration_noise", type=float, default=0.1) parser.add_argument("--epoch", type=int, default=5) parser.add_argument("--epoch_num_steps", type=int, default=20000) parser.add_argument("--collection_step_num_env_steps", type=int, default=8) - parser.add_argument("--update-per-step", type=float, default=0.125) + parser.add_argument("--update_per_step", type=float, default=0.125) parser.add_argument("--batch_size", type=int, default=128) - parser.add_argument("--hidden-sizes", type=int, nargs="*", default=[128, 128]) + 
parser.add_argument("--hidden_sizes", type=int, nargs="*", default=[128, 128]) parser.add_argument("--num_train_envs", type=int, default=8) parser.add_argument("--num_test_envs", type=int, default=100) parser.add_argument("--logdir", type=str, default="log") parser.add_argument("--render", type=float, default=0.0) parser.add_argument("--return_scaling", action="store_true", default=False) - parser.add_argument("--n-step", type=int, default=3) + parser.add_argument("--n_step", type=int, default=3) parser.add_argument( "--device", type=str, diff --git a/test/continuous/test_npg.py b/test/continuous/test_npg.py index e45acea50..7d43a153d 100644 --- a/test/continuous/test_npg.py +++ b/test/continuous/test_npg.py @@ -25,9 +25,9 @@ def get_args() -> argparse.Namespace: parser = argparse.ArgumentParser() parser.add_argument("--task", type=str, default="Pendulum-v1") - parser.add_argument("--reward-threshold", type=float, default=None) + parser.add_argument("--reward_threshold", type=float, default=None) parser.add_argument("--seed", type=int, default=1) - parser.add_argument("--buffer-size", type=int, default=50000) + parser.add_argument("--buffer_size", type=int, default=50000) parser.add_argument("--lr", type=float, default=1e-3) parser.add_argument("--gamma", type=float, default=0.95) parser.add_argument("--epoch", type=int, default=5) @@ -37,7 +37,7 @@ def get_args() -> argparse.Namespace: "--update_step_num_repetitions", type=int, default=2 ) # theoretically it should be 1 parser.add_argument("--batch_size", type=int, default=99999) - parser.add_argument("--hidden-sizes", type=int, nargs="*", default=[64, 64]) + parser.add_argument("--hidden_sizes", type=int, nargs="*", default=[64, 64]) parser.add_argument("--num_train_envs", type=int, default=16) parser.add_argument("--num_test_envs", type=int, default=10) parser.add_argument("--logdir", type=str, default="log") @@ -48,10 +48,10 @@ def get_args() -> argparse.Namespace: default="cuda" if torch.cuda.is_available() 
else "cpu", ) # npg special - parser.add_argument("--gae-lambda", type=float, default=0.95) + parser.add_argument("--gae_lambda", type=float, default=0.95) parser.add_argument("--return_scaling", type=int, default=1) parser.add_argument("--advantage_normalization", type=int, default=1) - parser.add_argument("--optim-critic-iters", type=int, default=5) + parser.add_argument("--optim_critic_iters", type=int, default=5) parser.add_argument("--trust_region_size", type=float, default=0.5) return parser.parse_known_args()[0] diff --git a/test/continuous/test_ppo.py b/test/continuous/test_ppo.py index b7d44ab54..e177967e0 100644 --- a/test/continuous/test_ppo.py +++ b/test/continuous/test_ppo.py @@ -24,9 +24,9 @@ def get_args() -> argparse.Namespace: parser = argparse.ArgumentParser() parser.add_argument("--task", type=str, default="Pendulum-v1") - parser.add_argument("--reward-threshold", type=float, default=None) + parser.add_argument("--reward_threshold", type=float, default=None) parser.add_argument("--seed", type=int, default=1) - parser.add_argument("--buffer-size", type=int, default=20000) + parser.add_argument("--buffer_size", type=int, default=20000) parser.add_argument("--lr", type=float, default=1e-3) parser.add_argument("--gamma", type=float, default=0.95) parser.add_argument("--epoch", type=int, default=5) @@ -34,7 +34,7 @@ def get_args() -> argparse.Namespace: parser.add_argument("--collection_step_num_episodes", type=int, default=16) parser.add_argument("--update_step_num_repetitions", type=int, default=2) parser.add_argument("--batch_size", type=int, default=128) - parser.add_argument("--hidden-sizes", type=int, nargs="*", default=[64, 64]) + parser.add_argument("--hidden_sizes", type=int, nargs="*", default=[64, 64]) parser.add_argument("--num_train_envs", type=int, default=16) parser.add_argument("--num_test_envs", type=int, default=100) parser.add_argument("--logdir", type=str, default="log") @@ -45,18 +45,18 @@ def get_args() -> argparse.Namespace: 
default="cuda" if torch.cuda.is_available() else "cpu", ) # ppo special - parser.add_argument("--vf-coef", type=float, default=0.25) - parser.add_argument("--ent-coef", type=float, default=0.0) - parser.add_argument("--eps-clip", type=float, default=0.2) - parser.add_argument("--max-grad-norm", type=float, default=0.5) - parser.add_argument("--gae-lambda", type=float, default=0.95) + parser.add_argument("--vf_coef", type=float, default=0.25) + parser.add_argument("--ent_coef", type=float, default=0.0) + parser.add_argument("--eps_clip", type=float, default=0.2) + parser.add_argument("--max_grad_norm", type=float, default=0.5) + parser.add_argument("--gae_lambda", type=float, default=0.95) parser.add_argument("--return_scaling", type=int, default=1) - parser.add_argument("--dual-clip", type=float, default=None) - parser.add_argument("--value-clip", type=int, default=1) + parser.add_argument("--dual_clip", type=float, default=None) + parser.add_argument("--value_clip", type=int, default=1) parser.add_argument("--advantage_normalization", type=int, default=1) - parser.add_argument("--recompute-adv", type=int, default=0) + parser.add_argument("--recompute_adv", type=int, default=0) parser.add_argument("--resume", action="store_true") - parser.add_argument("--save-interval", type=int, default=4) + parser.add_argument("--save_interval", type=int, default=4) return parser.parse_known_args()[0] diff --git a/test/continuous/test_redq.py b/test/continuous/test_redq.py index ad52dbea0..c813ed12f 100644 --- a/test/continuous/test_redq.py +++ b/test/continuous/test_redq.py @@ -25,27 +25,27 @@ def get_args() -> argparse.Namespace: parser = argparse.ArgumentParser() parser.add_argument("--task", type=str, default="Pendulum-v1") - parser.add_argument("--reward-threshold", type=float, default=None) + parser.add_argument("--reward_threshold", type=float, default=None) parser.add_argument("--seed", type=int, default=0) - parser.add_argument("--buffer-size", type=int, default=20000) - 
parser.add_argument("--ensemble-size", type=int, default=4) - parser.add_argument("--subset-size", type=int, default=2) - parser.add_argument("--actor-lr", type=float, default=1e-4) - parser.add_argument("--critic-lr", type=float, default=1e-3) + parser.add_argument("--buffer_size", type=int, default=20000) + parser.add_argument("--ensemble_size", type=int, default=4) + parser.add_argument("--subset_size", type=int, default=2) + parser.add_argument("--actor_lr", type=float, default=1e-4) + parser.add_argument("--critic_lr", type=float, default=1e-3) parser.add_argument("--gamma", type=float, default=0.99) parser.add_argument("--tau", type=float, default=0.005) parser.add_argument("--alpha", type=float, default=0.2) - parser.add_argument("--auto-alpha", action="store_true", default=False) - parser.add_argument("--alpha-lr", type=float, default=3e-4) - parser.add_argument("--start-timesteps", type=int, default=1000) + parser.add_argument("--auto_alpha", action="store_true", default=False) + parser.add_argument("--alpha_lr", type=float, default=3e-4) + parser.add_argument("--start_timesteps", type=int, default=1000) parser.add_argument("--epoch", type=int, default=5) parser.add_argument("--epoch_num_steps", type=int, default=5000) parser.add_argument("--collection_step_num_env_steps", type=int, default=1) - parser.add_argument("--update-per-step", type=int, default=3) - parser.add_argument("--n-step", type=int, default=1) + parser.add_argument("--update_per_step", type=int, default=3) + parser.add_argument("--n_step", type=int, default=1) parser.add_argument("--batch_size", type=int, default=64) - parser.add_argument("--target-mode", type=str, choices=("min", "mean"), default="min") - parser.add_argument("--hidden-sizes", type=int, nargs="*", default=[64, 64]) + parser.add_argument("--target_mode", type=str, choices=("min", "mean"), default="min") + parser.add_argument("--hidden_sizes", type=int, nargs="*", default=[64, 64]) parser.add_argument("--num_train_envs", 
type=int, default=8) parser.add_argument("--num_test_envs", type=int, default=100) parser.add_argument("--logdir", type=str, default="log") diff --git a/test/continuous/test_sac_with_il.py b/test/continuous/test_sac_with_il.py index b06a47107..409b1644f 100644 --- a/test/continuous/test_sac_with_il.py +++ b/test/continuous/test_sac_with_il.py @@ -33,30 +33,30 @@ def get_args() -> argparse.Namespace: parser = argparse.ArgumentParser() parser.add_argument("--task", type=str, default="Pendulum-v1") - parser.add_argument("--reward-threshold", type=float, default=None) + parser.add_argument("--reward_threshold", type=float, default=None) parser.add_argument("--seed", type=int, default=1) - parser.add_argument("--buffer-size", type=int, default=20000) - parser.add_argument("--actor-lr", type=float, default=1e-3) - parser.add_argument("--critic-lr", type=float, default=1e-3) - parser.add_argument("--il-lr", type=float, default=1e-3) + parser.add_argument("--buffer_size", type=int, default=20000) + parser.add_argument("--actor_lr", type=float, default=1e-3) + parser.add_argument("--critic_lr", type=float, default=1e-3) + parser.add_argument("--il_lr", type=float, default=1e-3) parser.add_argument("--gamma", type=float, default=0.99) parser.add_argument("--tau", type=float, default=0.005) parser.add_argument("--alpha", type=float, default=0.2) - parser.add_argument("--auto-alpha", type=int, default=1) - parser.add_argument("--alpha-lr", type=float, default=3e-4) + parser.add_argument("--auto_alpha", type=int, default=1) + parser.add_argument("--alpha_lr", type=float, default=3e-4) parser.add_argument("--epoch", type=int, default=10) parser.add_argument("--epoch_num_steps", type=int, default=24000) - parser.add_argument("--il-step-per-epoch", type=int, default=500) + parser.add_argument("--il_step_per_epoch", type=int, default=500) parser.add_argument("--collection_step_num_env_steps", type=int, default=10) - parser.add_argument("--update-per-step", type=float, default=0.1) 
+ parser.add_argument("--update_per_step", type=float, default=0.1) parser.add_argument("--batch_size", type=int, default=128) - parser.add_argument("--hidden-sizes", type=int, nargs="*", default=[128, 128]) - parser.add_argument("--imitation-hidden-sizes", type=int, nargs="*", default=[128, 128]) + parser.add_argument("--hidden_sizes", type=int, nargs="*", default=[128, 128]) + parser.add_argument("--imitation_hidden_sizes", type=int, nargs="*", default=[128, 128]) parser.add_argument("--num_train_envs", type=int, default=10) parser.add_argument("--num_test_envs", type=int, default=100) parser.add_argument("--logdir", type=str, default="log") parser.add_argument("--render", type=float, default=0.0) - parser.add_argument("--n-step", type=int, default=3) + parser.add_argument("--n_step", type=int, default=3) parser.add_argument( "--device", type=str, diff --git a/test/continuous/test_td3.py b/test/continuous/test_td3.py index 59f1f8aaa..f6476f282 100644 --- a/test/continuous/test_td3.py +++ b/test/continuous/test_td3.py @@ -24,28 +24,28 @@ def get_args() -> argparse.Namespace: parser = argparse.ArgumentParser() parser.add_argument("--task", type=str, default="Pendulum-v1") - parser.add_argument("--reward-threshold", type=float, default=None) + parser.add_argument("--reward_threshold", type=float, default=None) parser.add_argument("--seed", type=int, default=1) - parser.add_argument("--buffer-size", type=int, default=20000) - parser.add_argument("--actor-lr", type=float, default=1e-4) - parser.add_argument("--critic-lr", type=float, default=1e-3) + parser.add_argument("--buffer_size", type=int, default=20000) + parser.add_argument("--actor_lr", type=float, default=1e-4) + parser.add_argument("--critic_lr", type=float, default=1e-3) parser.add_argument("--gamma", type=float, default=0.99) parser.add_argument("--tau", type=float, default=0.005) - parser.add_argument("--exploration-noise", type=float, default=0.1) - parser.add_argument("--policy-noise", type=float, 
default=0.2) - parser.add_argument("--noise-clip", type=float, default=0.5) - parser.add_argument("--update-actor-freq", type=int, default=2) + parser.add_argument("--exploration_noise", type=float, default=0.1) + parser.add_argument("--policy_noise", type=float, default=0.2) + parser.add_argument("--noise_clip", type=float, default=0.5) + parser.add_argument("--update_actor_freq", type=int, default=2) parser.add_argument("--epoch", type=int, default=5) parser.add_argument("--epoch_num_steps", type=int, default=20000) parser.add_argument("--collection_step_num_env_steps", type=int, default=8) - parser.add_argument("--update-per-step", type=float, default=0.125) + parser.add_argument("--update_per_step", type=float, default=0.125) parser.add_argument("--batch_size", type=int, default=128) - parser.add_argument("--hidden-sizes", type=int, nargs="*", default=[128, 128]) + parser.add_argument("--hidden_sizes", type=int, nargs="*", default=[128, 128]) parser.add_argument("--num_train_envs", type=int, default=8) parser.add_argument("--num_test_envs", type=int, default=100) parser.add_argument("--logdir", type=str, default="log") parser.add_argument("--render", type=float, default=0.0) - parser.add_argument("--n-step", type=int, default=3) + parser.add_argument("--n_step", type=int, default=3) parser.add_argument( "--device", type=str, diff --git a/test/continuous/test_trpo.py b/test/continuous/test_trpo.py index 73a29c68c..671e258f0 100644 --- a/test/continuous/test_trpo.py +++ b/test/continuous/test_trpo.py @@ -25,9 +25,9 @@ def get_args() -> argparse.Namespace: parser = argparse.ArgumentParser() parser.add_argument("--task", type=str, default="Pendulum-v1") - parser.add_argument("--reward-threshold", type=float, default=None) + parser.add_argument("--reward_threshold", type=float, default=None) parser.add_argument("--seed", type=int, default=1) - parser.add_argument("--buffer-size", type=int, default=50000) + parser.add_argument("--buffer_size", type=int, 
default=50000) parser.add_argument("--lr", type=float, default=1e-3) parser.add_argument("--gamma", type=float, default=0.95) parser.add_argument("--epoch", type=int, default=5) @@ -37,7 +37,7 @@ def get_args() -> argparse.Namespace: "--update_step_num_repetitions", type=int, default=2 ) # theoretically it should be 1 parser.add_argument("--batch_size", type=int, default=99999) - parser.add_argument("--hidden-sizes", type=int, nargs="*", default=[64, 64]) + parser.add_argument("--hidden_sizes", type=int, nargs="*", default=[64, 64]) parser.add_argument("--num_train_envs", type=int, default=16) parser.add_argument("--num_test_envs", type=int, default=10) parser.add_argument("--logdir", type=str, default="log") @@ -48,13 +48,13 @@ def get_args() -> argparse.Namespace: default="cuda" if torch.cuda.is_available() else "cpu", ) # trpo special - parser.add_argument("--gae-lambda", type=float, default=0.95) + parser.add_argument("--gae_lambda", type=float, default=0.95) parser.add_argument("--return_scaling", type=int, default=1) parser.add_argument("--advantage_normalization", type=int, default=1) - parser.add_argument("--optim-critic-iters", type=int, default=5) - parser.add_argument("--max-kl", type=float, default=0.005) - parser.add_argument("--backtrack-coeff", type=float, default=0.8) - parser.add_argument("--max-backtracks", type=int, default=10) + parser.add_argument("--optim_critic_iters", type=int, default=5) + parser.add_argument("--max_kl", type=float, default=0.005) + parser.add_argument("--backtrack_coeff", type=float, default=0.8) + parser.add_argument("--max_backtracks", type=int, default=10) return parser.parse_known_args()[0] diff --git a/test/discrete/test_a2c_with_il.py b/test/discrete/test_a2c_with_il.py index 95f25e0ac..e30d7ede1 100644 --- a/test/discrete/test_a2c_with_il.py +++ b/test/discrete/test_a2c_with_il.py @@ -28,22 +28,22 @@ def get_args() -> argparse.Namespace: parser = argparse.ArgumentParser() parser.add_argument("--task", type=str, 
default="CartPole-v1") - parser.add_argument("--reward-threshold", type=float, default=None) + parser.add_argument("--reward_threshold", type=float, default=None) parser.add_argument("--seed", type=int, default=1) - parser.add_argument("--buffer-size", type=int, default=20000) + parser.add_argument("--buffer_size", type=int, default=20000) parser.add_argument("--lr", type=float, default=1e-3) - parser.add_argument("--il-lr", type=float, default=1e-3) + parser.add_argument("--il_lr", type=float, default=1e-3) parser.add_argument("--gamma", type=float, default=0.9) parser.add_argument("--epoch", type=int, default=10) parser.add_argument("--epoch_num_steps", type=int, default=50000) - parser.add_argument("--il-step-per-epoch", type=int, default=1000) + parser.add_argument("--il_step_per_epoch", type=int, default=1000) parser.add_argument("--collection_step_num_episodes", type=int, default=16) parser.add_argument("--collection_step_num_env_steps", type=int, default=16) - parser.add_argument("--update-per-step", type=float, default=1 / 16) + parser.add_argument("--update_per_step", type=float, default=1 / 16) parser.add_argument("--update_step_num_repetitions", type=int, default=1) parser.add_argument("--batch_size", type=int, default=64) - parser.add_argument("--hidden-sizes", type=int, nargs="*", default=[64, 64]) - parser.add_argument("--imitation-hidden-sizes", type=int, nargs="*", default=[128]) + parser.add_argument("--hidden_sizes", type=int, nargs="*", default=[64, 64]) + parser.add_argument("--imitation_hidden_sizes", type=int, nargs="*", default=[128]) parser.add_argument("--num_train_envs", type=int, default=16) parser.add_argument("--num_test_envs", type=int, default=100) parser.add_argument("--logdir", type=str, default="log") @@ -54,10 +54,10 @@ def get_args() -> argparse.Namespace: default="cuda" if torch.cuda.is_available() else "cpu", ) # a2c special - parser.add_argument("--vf-coef", type=float, default=0.5) - parser.add_argument("--ent-coef", 
type=float, default=0.0) - parser.add_argument("--max-grad-norm", type=float, default=None) - parser.add_argument("--gae-lambda", type=float, default=1.0) + parser.add_argument("--vf_coef", type=float, default=0.5) + parser.add_argument("--ent_coef", type=float, default=0.0) + parser.add_argument("--max_grad_norm", type=float, default=None) + parser.add_argument("--gae_lambda", type=float, default=1.0) parser.add_argument("--return_scaling", action="store_true", default=False) return parser.parse_known_args()[0] diff --git a/test/discrete/test_bdqn.py b/test/discrete/test_bdqn.py index 6716265a0..224d9a4e3 100644 --- a/test/discrete/test_bdqn.py +++ b/test/discrete/test_bdqn.py @@ -19,25 +19,25 @@ def get_args() -> argparse.Namespace: parser = argparse.ArgumentParser() # task parser.add_argument("--task", type=str, default="Pendulum-v1") - parser.add_argument("--reward-threshold", type=float, default=None) + parser.add_argument("--reward_threshold", type=float, default=None) # network architecture - parser.add_argument("--common-hidden-sizes", type=int, nargs="*", default=[64, 64]) - parser.add_argument("--action-hidden-sizes", type=int, nargs="*", default=[64]) - parser.add_argument("--value-hidden-sizes", type=int, nargs="*", default=[64]) - parser.add_argument("--action-per-branch", type=int, default=40) + parser.add_argument("--common_hidden_sizes", type=int, nargs="*", default=[64, 64]) + parser.add_argument("--action_hidden_sizes", type=int, nargs="*", default=[64]) + parser.add_argument("--value_hidden_sizes", type=int, nargs="*", default=[64]) + parser.add_argument("--action_per_branch", type=int, default=40) # training hyperparameters parser.add_argument("--seed", type=int, default=1626) - parser.add_argument("--eps-test", type=float, default=0.01) - parser.add_argument("--eps-train", type=float, default=0.76) - parser.add_argument("--eps-decay", type=float, default=1e-4) - parser.add_argument("--buffer-size", type=int, default=20000) + 
parser.add_argument("--eps_test", type=float, default=0.01) + parser.add_argument("--eps_train", type=float, default=0.76) + parser.add_argument("--eps_decay", type=float, default=1e-4) + parser.add_argument("--buffer_size", type=int, default=20000) parser.add_argument("--lr", type=float, default=1e-3) parser.add_argument("--gamma", type=float, default=0.9) - parser.add_argument("--target-update-freq", type=int, default=200) + parser.add_argument("--target_update_freq", type=int, default=200) parser.add_argument("--epoch", type=int, default=10) parser.add_argument("--epoch_num_steps", type=int, default=80000) parser.add_argument("--collection_step_num_env_steps", type=int, default=10) - parser.add_argument("--update-per-step", type=float, default=0.1) + parser.add_argument("--update_per_step", type=float, default=0.1) parser.add_argument("--batch_size", type=int, default=128) parser.add_argument("--num_train_envs", type=int, default=10) parser.add_argument("--num_test_envs", type=int, default=10) diff --git a/test/discrete/test_c51.py b/test/discrete/test_c51.py index 24bc27a74..7795a54f9 100644 --- a/test/discrete/test_c51.py +++ b/test/discrete/test_c51.py @@ -30,29 +30,29 @@ def get_args() -> argparse.Namespace: parser = argparse.ArgumentParser() parser.add_argument("--task", type=str, default="CartPole-v1") - parser.add_argument("--reward-threshold", type=float, default=None) + parser.add_argument("--reward_threshold", type=float, default=None) parser.add_argument("--seed", type=int, default=1626) - parser.add_argument("--eps-test", type=float, default=0.05) - parser.add_argument("--eps-train", type=float, default=0.1) - parser.add_argument("--buffer-size", type=int, default=20000) + parser.add_argument("--eps_test", type=float, default=0.05) + parser.add_argument("--eps_train", type=float, default=0.1) + parser.add_argument("--buffer_size", type=int, default=20000) parser.add_argument("--lr", type=float, default=1e-3) parser.add_argument("--gamma", type=float, 
default=0.9) - parser.add_argument("--num-atoms", type=int, default=51) - parser.add_argument("--v-min", type=float, default=-10.0) - parser.add_argument("--v-max", type=float, default=10.0) - parser.add_argument("--n-step", type=int, default=3) - parser.add_argument("--target-update-freq", type=int, default=320) + parser.add_argument("--num_atoms", type=int, default=51) + parser.add_argument("--v_min", type=float, default=-10.0) + parser.add_argument("--v_max", type=float, default=10.0) + parser.add_argument("--n_step", type=int, default=3) + parser.add_argument("--target_update_freq", type=int, default=320) parser.add_argument("--epoch", type=int, default=10) parser.add_argument("--epoch_num_steps", type=int, default=8000) parser.add_argument("--collection_step_num_env_steps", type=int, default=8) - parser.add_argument("--update-per-step", type=float, default=0.125) + parser.add_argument("--update_per_step", type=float, default=0.125) parser.add_argument("--batch_size", type=int, default=64) - parser.add_argument("--hidden-sizes", type=int, nargs="*", default=[128, 128, 128, 128]) + parser.add_argument("--hidden_sizes", type=int, nargs="*", default=[128, 128, 128, 128]) parser.add_argument("--num_train_envs", type=int, default=8) parser.add_argument("--num_test_envs", type=int, default=100) parser.add_argument("--logdir", type=str, default="log") parser.add_argument("--render", type=float, default=0.0) - parser.add_argument("--prioritized-replay", action="store_true", default=False) + parser.add_argument("--prioritized_replay", action="store_true", default=False) parser.add_argument("--alpha", type=float, default=0.6) parser.add_argument("--beta", type=float, default=0.4) parser.add_argument("--resume", action="store_true") @@ -61,7 +61,7 @@ def get_args() -> argparse.Namespace: type=str, default="cuda" if torch.cuda.is_available() else "cpu", ) - parser.add_argument("--save-interval", type=int, default=4) + parser.add_argument("--save_interval", type=int, 
default=4) return parser.parse_known_args()[0] diff --git a/test/discrete/test_discrete_sac.py b/test/discrete/test_discrete_sac.py index f94405d05..5cd8ae305 100644 --- a/test/discrete/test_discrete_sac.py +++ b/test/discrete/test_discrete_sac.py @@ -26,27 +26,27 @@ def get_args() -> argparse.Namespace: parser = argparse.ArgumentParser() parser.add_argument("--task", type=str, default="CartPole-v1") - parser.add_argument("--reward-threshold", type=float, default=None) + parser.add_argument("--reward_threshold", type=float, default=None) parser.add_argument("--seed", type=int, default=1) - parser.add_argument("--buffer-size", type=int, default=20000) - parser.add_argument("--actor-lr", type=float, default=1e-4) - parser.add_argument("--critic-lr", type=float, default=1e-3) - parser.add_argument("--alpha-lr", type=float, default=3e-4) + parser.add_argument("--buffer_size", type=int, default=20000) + parser.add_argument("--actor_lr", type=float, default=1e-4) + parser.add_argument("--critic_lr", type=float, default=1e-3) + parser.add_argument("--alpha_lr", type=float, default=3e-4) parser.add_argument("--gamma", type=float, default=0.95) parser.add_argument("--tau", type=float, default=0.005) parser.add_argument("--alpha", type=float, default=0.05) - parser.add_argument("--auto-alpha", action="store_true", default=False) + parser.add_argument("--auto_alpha", action="store_true", default=False) parser.add_argument("--epoch", type=int, default=5) parser.add_argument("--epoch_num_steps", type=int, default=10000) parser.add_argument("--collection_step_num_env_steps", type=int, default=10) - parser.add_argument("--update-per-step", type=float, default=0.1) + parser.add_argument("--update_per_step", type=float, default=0.1) parser.add_argument("--batch_size", type=int, default=64) - parser.add_argument("--hidden-sizes", type=int, nargs="*", default=[64, 64]) + parser.add_argument("--hidden_sizes", type=int, nargs="*", default=[64, 64]) 
parser.add_argument("--num_train_envs", type=int, default=10) parser.add_argument("--num_test_envs", type=int, default=100) parser.add_argument("--logdir", type=str, default="log") parser.add_argument("--render", type=float, default=0.0) - parser.add_argument("--n-step", type=int, default=3) + parser.add_argument("--n_step", type=int, default=3) parser.add_argument( "--device", type=str, diff --git a/test/discrete/test_dqn.py b/test/discrete/test_dqn.py index de1fac1a1..654ff07c8 100644 --- a/test/discrete/test_dqn.py +++ b/test/discrete/test_dqn.py @@ -29,26 +29,26 @@ def get_args() -> argparse.Namespace: parser = argparse.ArgumentParser() parser.add_argument("--task", type=str, default="CartPole-v1") - parser.add_argument("--reward-threshold", type=float, default=None) + parser.add_argument("--reward_threshold", type=float, default=None) parser.add_argument("--seed", type=int, default=1626) - parser.add_argument("--eps-test", type=float, default=0.05) - parser.add_argument("--eps-train", type=float, default=0.1) - parser.add_argument("--buffer-size", type=int, default=20000) + parser.add_argument("--eps_test", type=float, default=0.05) + parser.add_argument("--eps_train", type=float, default=0.1) + parser.add_argument("--buffer_size", type=int, default=20000) parser.add_argument("--lr", type=float, default=1e-3) parser.add_argument("--gamma", type=float, default=0.9) - parser.add_argument("--n-step", type=int, default=3) - parser.add_argument("--target-update-freq", type=int, default=320) + parser.add_argument("--n_step", type=int, default=3) + parser.add_argument("--target_update_freq", type=int, default=320) parser.add_argument("--epoch", type=int, default=20) parser.add_argument("--epoch_num_steps", type=int, default=10000) parser.add_argument("--collection_step_num_env_steps", type=int, default=10) - parser.add_argument("--update-per-step", type=float, default=0.1) + parser.add_argument("--update_per_step", type=float, default=0.1) 
parser.add_argument("--batch_size", type=int, default=64) - parser.add_argument("--hidden-sizes", type=int, nargs="*", default=[128, 128, 128, 128]) + parser.add_argument("--hidden_sizes", type=int, nargs="*", default=[128, 128, 128, 128]) parser.add_argument("--num_train_envs", type=int, default=10) parser.add_argument("--num_test_envs", type=int, default=100) parser.add_argument("--logdir", type=str, default="log") parser.add_argument("--render", type=float, default=0.0) - parser.add_argument("--prioritized-replay", action="store_true", default=False) + parser.add_argument("--prioritized_replay", action="store_true", default=False) parser.add_argument("--alpha", type=float, default=0.6) parser.add_argument("--beta", type=float, default=0.4) parser.add_argument( diff --git a/test/discrete/test_drqn.py b/test/discrete/test_drqn.py index 1c2bbb021..95f250bbd 100644 --- a/test/discrete/test_drqn.py +++ b/test/discrete/test_drqn.py @@ -23,22 +23,22 @@ def get_args() -> argparse.Namespace: parser = argparse.ArgumentParser() parser.add_argument("--task", type=str, default="CartPole-v1") - parser.add_argument("--reward-threshold", type=float, default=None) + parser.add_argument("--reward_threshold", type=float, default=None) parser.add_argument("--seed", type=int, default=1) - parser.add_argument("--eps-test", type=float, default=0.05) - parser.add_argument("--eps-train", type=float, default=0.1) - parser.add_argument("--buffer-size", type=int, default=20000) - parser.add_argument("--stack-num", type=int, default=4) + parser.add_argument("--eps_test", type=float, default=0.05) + parser.add_argument("--eps_train", type=float, default=0.1) + parser.add_argument("--buffer_size", type=int, default=20000) + parser.add_argument("--stack_num", type=int, default=4) parser.add_argument("--lr", type=float, default=1e-3) parser.add_argument("--gamma", type=float, default=0.95) - parser.add_argument("--n-step", type=int, default=3) - parser.add_argument("--target-update-freq", 
type=int, default=320) + parser.add_argument("--n_step", type=int, default=3) + parser.add_argument("--target_update_freq", type=int, default=320) parser.add_argument("--epoch", type=int, default=5) parser.add_argument("--epoch_num_steps", type=int, default=20000) - parser.add_argument("--update-per-step", type=float, default=1 / 16) + parser.add_argument("--update_per_step", type=float, default=1 / 16) parser.add_argument("--collection_step_num_env_steps", type=int, default=16) parser.add_argument("--batch_size", type=int, default=128) - parser.add_argument("--layer-num", type=int, default=2) + parser.add_argument("--layer_num", type=int, default=2) parser.add_argument("--num_train_envs", type=int, default=16) parser.add_argument("--num_test_envs", type=int, default=100) parser.add_argument("--logdir", type=str, default="log") diff --git a/test/discrete/test_fqf.py b/test/discrete/test_fqf.py index cb4f9b580..b0b92d126 100644 --- a/test/discrete/test_fqf.py +++ b/test/discrete/test_fqf.py @@ -30,30 +30,30 @@ def get_args() -> argparse.Namespace: parser = argparse.ArgumentParser() parser.add_argument("--task", type=str, default="CartPole-v1") - parser.add_argument("--reward-threshold", type=float, default=None) + parser.add_argument("--reward_threshold", type=float, default=None) parser.add_argument("--seed", type=int, default=1) - parser.add_argument("--eps-test", type=float, default=0.05) - parser.add_argument("--eps-train", type=float, default=0.1) - parser.add_argument("--buffer-size", type=int, default=20000) + parser.add_argument("--eps_test", type=float, default=0.05) + parser.add_argument("--eps_train", type=float, default=0.1) + parser.add_argument("--buffer_size", type=int, default=20000) parser.add_argument("--lr", type=float, default=3e-3) - parser.add_argument("--fraction-lr", type=float, default=2.5e-9) + parser.add_argument("--fraction_lr", type=float, default=2.5e-9) parser.add_argument("--gamma", type=float, default=0.9) - 
parser.add_argument("--num-fractions", type=int, default=32) - parser.add_argument("--num-cosines", type=int, default=64) - parser.add_argument("--ent-coef", type=float, default=10.0) - parser.add_argument("--n-step", type=int, default=3) - parser.add_argument("--target-update-freq", type=int, default=320) + parser.add_argument("--num_fractions", type=int, default=32) + parser.add_argument("--num_cosines", type=int, default=64) + parser.add_argument("--ent_coef", type=float, default=10.0) + parser.add_argument("--n_step", type=int, default=3) + parser.add_argument("--target_update_freq", type=int, default=320) parser.add_argument("--epoch", type=int, default=10) parser.add_argument("--epoch_num_steps", type=int, default=10000) parser.add_argument("--collection_step_num_env_steps", type=int, default=10) - parser.add_argument("--update-per-step", type=float, default=0.1) + parser.add_argument("--update_per_step", type=float, default=0.1) parser.add_argument("--batch_size", type=int, default=64) - parser.add_argument("--hidden-sizes", type=int, nargs="*", default=[64, 64, 64]) + parser.add_argument("--hidden_sizes", type=int, nargs="*", default=[64, 64, 64]) parser.add_argument("--num_train_envs", type=int, default=10) parser.add_argument("--num_test_envs", type=int, default=100) parser.add_argument("--logdir", type=str, default="log") parser.add_argument("--render", type=float, default=0.0) - parser.add_argument("--prioritized-replay", action="store_true", default=False) + parser.add_argument("--prioritized_replay", action="store_true", default=False) parser.add_argument("--alpha", type=float, default=0.6) parser.add_argument("--beta", type=float, default=0.4) parser.add_argument( diff --git a/test/discrete/test_iqn.py b/test/discrete/test_iqn.py index 035d01fea..c0ad64700 100644 --- a/test/discrete/test_iqn.py +++ b/test/discrete/test_iqn.py @@ -30,30 +30,30 @@ def get_args() -> argparse.Namespace: parser = argparse.ArgumentParser() parser.add_argument("--task", 
type=str, default="CartPole-v1") - parser.add_argument("--reward-threshold", type=float, default=None) + parser.add_argument("--reward_threshold", type=float, default=None) parser.add_argument("--seed", type=int, default=0) - parser.add_argument("--eps-test", type=float, default=0.05) - parser.add_argument("--eps-train", type=float, default=0.1) - parser.add_argument("--buffer-size", type=int, default=20000) + parser.add_argument("--eps_test", type=float, default=0.05) + parser.add_argument("--eps_train", type=float, default=0.1) + parser.add_argument("--buffer_size", type=int, default=20000) parser.add_argument("--lr", type=float, default=3e-3) parser.add_argument("--gamma", type=float, default=0.9) - parser.add_argument("--sample-size", type=int, default=32) - parser.add_argument("--online-sample-size", type=int, default=8) - parser.add_argument("--target-sample-size", type=int, default=8) - parser.add_argument("--num-cosines", type=int, default=64) - parser.add_argument("--n-step", type=int, default=3) - parser.add_argument("--target-update-freq", type=int, default=320) + parser.add_argument("--sample_size", type=int, default=32) + parser.add_argument("--online_sample_size", type=int, default=8) + parser.add_argument("--target_sample_size", type=int, default=8) + parser.add_argument("--num_cosines", type=int, default=64) + parser.add_argument("--n_step", type=int, default=3) + parser.add_argument("--target_update_freq", type=int, default=320) parser.add_argument("--epoch", type=int, default=10) parser.add_argument("--epoch_num_steps", type=int, default=10000) parser.add_argument("--collection_step_num_env_steps", type=int, default=10) - parser.add_argument("--update-per-step", type=float, default=0.1) + parser.add_argument("--update_per_step", type=float, default=0.1) parser.add_argument("--batch_size", type=int, default=64) - parser.add_argument("--hidden-sizes", type=int, nargs="*", default=[64, 64, 64]) + parser.add_argument("--hidden_sizes", type=int, 
nargs="*", default=[64, 64, 64]) parser.add_argument("--num_train_envs", type=int, default=10) parser.add_argument("--num_test_envs", type=int, default=100) parser.add_argument("--logdir", type=str, default="log") parser.add_argument("--render", type=float, default=0.0) - parser.add_argument("--prioritized-replay", action="store_true", default=False) + parser.add_argument("--prioritized_replay", action="store_true", default=False) parser.add_argument("--alpha", type=float, default=0.6) parser.add_argument("--beta", type=float, default=0.4) parser.add_argument( diff --git a/test/discrete/test_ppo_discrete.py b/test/discrete/test_ppo_discrete.py index d18cb3d89..d8a6e61f9 100644 --- a/test/discrete/test_ppo_discrete.py +++ b/test/discrete/test_ppo_discrete.py @@ -29,9 +29,9 @@ def get_args() -> argparse.Namespace: parser = argparse.ArgumentParser() parser.add_argument("--task", type=str, default="CartPole-v1") - parser.add_argument("--reward-threshold", type=float, default=None) + parser.add_argument("--reward_threshold", type=float, default=None) parser.add_argument("--seed", type=int, default=1626) - parser.add_argument("--buffer-size", type=int, default=20000) + parser.add_argument("--buffer_size", type=int, default=20000) parser.add_argument("--lr", type=float, default=3e-4) parser.add_argument("--gamma", type=float, default=0.99) parser.add_argument("--epoch", type=int, default=10) @@ -39,7 +39,7 @@ def get_args() -> argparse.Namespace: parser.add_argument("--collection_step_num_env_steps", type=int, default=2000) parser.add_argument("--update_step_num_repetitions", type=int, default=10) parser.add_argument("--batch_size", type=int, default=64) - parser.add_argument("--hidden-sizes", type=int, nargs="*", default=[64, 64]) + parser.add_argument("--hidden_sizes", type=int, nargs="*", default=[64, 64]) parser.add_argument("--num_train_envs", type=int, default=20) parser.add_argument("--num_test_envs", type=int, default=100) parser.add_argument("--logdir", type=str, 
default="log") @@ -50,16 +50,16 @@ def get_args() -> argparse.Namespace: default="cuda" if torch.cuda.is_available() else "cpu", ) # ppo special - parser.add_argument("--vf-coef", type=float, default=0.5) - parser.add_argument("--ent-coef", type=float, default=0.0) - parser.add_argument("--eps-clip", type=float, default=0.2) - parser.add_argument("--max-grad-norm", type=float, default=0.5) - parser.add_argument("--gae-lambda", type=float, default=0.95) + parser.add_argument("--vf_coef", type=float, default=0.5) + parser.add_argument("--ent_coef", type=float, default=0.0) + parser.add_argument("--eps_clip", type=float, default=0.2) + parser.add_argument("--max_grad_norm", type=float, default=0.5) + parser.add_argument("--gae_lambda", type=float, default=0.95) parser.add_argument("--return_scaling", type=int, default=0) parser.add_argument("--advantage_normalization", type=int, default=0) - parser.add_argument("--recompute-adv", type=int, default=0) - parser.add_argument("--dual-clip", type=float, default=None) - parser.add_argument("--value-clip", type=int, default=0) + parser.add_argument("--recompute_adv", type=int, default=0) + parser.add_argument("--dual_clip", type=float, default=None) + parser.add_argument("--value_clip", type=int, default=0) return parser.parse_known_args()[0] diff --git a/test/discrete/test_qrdqn.py b/test/discrete/test_qrdqn.py index 840402daf..de67f7797 100644 --- a/test/discrete/test_qrdqn.py +++ b/test/discrete/test_qrdqn.py @@ -28,27 +28,27 @@ def get_args() -> argparse.Namespace: parser = argparse.ArgumentParser() parser.add_argument("--task", type=str, default="CartPole-v1") - parser.add_argument("--reward-threshold", type=float, default=None) + parser.add_argument("--reward_threshold", type=float, default=None) parser.add_argument("--seed", type=int, default=1) - parser.add_argument("--eps-test", type=float, default=0.05) - parser.add_argument("--eps-train", type=float, default=0.1) - parser.add_argument("--buffer-size", type=int, 
default=20000) + parser.add_argument("--eps_test", type=float, default=0.05) + parser.add_argument("--eps_train", type=float, default=0.1) + parser.add_argument("--buffer_size", type=int, default=20000) parser.add_argument("--lr", type=float, default=1e-3) parser.add_argument("--gamma", type=float, default=0.9) - parser.add_argument("--num-quantiles", type=int, default=200) - parser.add_argument("--n-step", type=int, default=3) - parser.add_argument("--target-update-freq", type=int, default=320) + parser.add_argument("--num_quantiles", type=int, default=200) + parser.add_argument("--n_step", type=int, default=3) + parser.add_argument("--target_update_freq", type=int, default=320) parser.add_argument("--epoch", type=int, default=10) parser.add_argument("--epoch_num_steps", type=int, default=10000) parser.add_argument("--collection_step_num_env_steps", type=int, default=10) - parser.add_argument("--update-per-step", type=float, default=0.1) + parser.add_argument("--update_per_step", type=float, default=0.1) parser.add_argument("--batch_size", type=int, default=64) - parser.add_argument("--hidden-sizes", type=int, nargs="*", default=[128, 128, 128, 128]) + parser.add_argument("--hidden_sizes", type=int, nargs="*", default=[128, 128, 128, 128]) parser.add_argument("--num_train_envs", type=int, default=10) parser.add_argument("--num_test_envs", type=int, default=100) parser.add_argument("--logdir", type=str, default="log") parser.add_argument("--render", type=float, default=0.0) - parser.add_argument("--prioritized-replay", action="store_true", default=False) + parser.add_argument("--prioritized_replay", action="store_true", default=False) parser.add_argument("--alpha", type=float, default=0.6) parser.add_argument("--beta", type=float, default=0.4) parser.add_argument( diff --git a/test/discrete/test_rainbow.py b/test/discrete/test_rainbow.py index 9e64d40d6..d326a4f17 100644 --- a/test/discrete/test_rainbow.py +++ b/test/discrete/test_rainbow.py @@ -30,40 +30,40 @@ def 
get_args() -> argparse.Namespace: parser = argparse.ArgumentParser() parser.add_argument("--task", type=str, default="CartPole-v1") - parser.add_argument("--reward-threshold", type=float, default=None) + parser.add_argument("--reward_threshold", type=float, default=None) parser.add_argument("--seed", type=int, default=1626) - parser.add_argument("--eps-test", type=float, default=0.05) - parser.add_argument("--eps-train", type=float, default=0.1) - parser.add_argument("--buffer-size", type=int, default=20000) + parser.add_argument("--eps_test", type=float, default=0.05) + parser.add_argument("--eps_train", type=float, default=0.1) + parser.add_argument("--buffer_size", type=int, default=20000) parser.add_argument("--lr", type=float, default=1e-3) parser.add_argument("--gamma", type=float, default=0.9) - parser.add_argument("--num-atoms", type=int, default=51) - parser.add_argument("--v-min", type=float, default=-10.0) - parser.add_argument("--v-max", type=float, default=10.0) - parser.add_argument("--noisy-std", type=float, default=0.1) - parser.add_argument("--n-step", type=int, default=3) - parser.add_argument("--target-update-freq", type=int, default=320) + parser.add_argument("--num_atoms", type=int, default=51) + parser.add_argument("--v_min", type=float, default=-10.0) + parser.add_argument("--v_max", type=float, default=10.0) + parser.add_argument("--noisy_std", type=float, default=0.1) + parser.add_argument("--n_step", type=int, default=3) + parser.add_argument("--target_update_freq", type=int, default=320) parser.add_argument("--epoch", type=int, default=10) parser.add_argument("--epoch_num_steps", type=int, default=8000) parser.add_argument("--collection_step_num_env_steps", type=int, default=8) - parser.add_argument("--update-per-step", type=float, default=0.125) + parser.add_argument("--update_per_step", type=float, default=0.125) parser.add_argument("--batch_size", type=int, default=64) - parser.add_argument("--hidden-sizes", type=int, nargs="*", 
default=[128, 128, 128, 128]) + parser.add_argument("--hidden_sizes", type=int, nargs="*", default=[128, 128, 128, 128]) parser.add_argument("--num_train_envs", type=int, default=8) parser.add_argument("--num_test_envs", type=int, default=100) parser.add_argument("--logdir", type=str, default="log") parser.add_argument("--render", type=float, default=0.0) - parser.add_argument("--prioritized-replay", action="store_true", default=False) + parser.add_argument("--prioritized_replay", action="store_true", default=False) parser.add_argument("--alpha", type=float, default=0.6) parser.add_argument("--beta", type=float, default=0.4) - parser.add_argument("--beta-final", type=float, default=1.0) + parser.add_argument("--beta_final", type=float, default=1.0) parser.add_argument("--resume", action="store_true") parser.add_argument( "--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu", ) - parser.add_argument("--save-interval", type=int, default=4) + parser.add_argument("--save_interval", type=int, default=4) return parser.parse_known_args()[0] diff --git a/test/discrete/test_reinforce.py b/test/discrete/test_reinforce.py index 68b66e4be..dd3376558 100644 --- a/test/discrete/test_reinforce.py +++ b/test/discrete/test_reinforce.py @@ -23,9 +23,9 @@ def get_args() -> argparse.Namespace: parser = argparse.ArgumentParser() parser.add_argument("--task", type=str, default="CartPole-v1") - parser.add_argument("--reward-threshold", type=float, default=None) + parser.add_argument("--reward_threshold", type=float, default=None) parser.add_argument("--seed", type=int, default=1) - parser.add_argument("--buffer-size", type=int, default=20000) + parser.add_argument("--buffer_size", type=int, default=20000) parser.add_argument("--lr", type=float, default=1e-3) parser.add_argument("--gamma", type=float, default=0.95) parser.add_argument("--epoch", type=int, default=10) @@ -33,7 +33,7 @@ def get_args() -> argparse.Namespace: 
parser.add_argument("--collection_step_num_episodes", type=int, default=8) parser.add_argument("--update_step_num_repetitions", type=int, default=2) parser.add_argument("--batch_size", type=int, default=64) - parser.add_argument("--hidden-sizes", type=int, nargs="*", default=[64, 64]) + parser.add_argument("--hidden_sizes", type=int, nargs="*", default=[64, 64]) parser.add_argument("--num_train_envs", type=int, default=8) parser.add_argument("--num_test_envs", type=int, default=100) parser.add_argument("--logdir", type=str, default="log") diff --git a/test/modelbased/test_dqn_icm.py b/test/modelbased/test_dqn_icm.py index 70edd7e06..d44c003cc 100644 --- a/test/modelbased/test_dqn_icm.py +++ b/test/modelbased/test_dqn_icm.py @@ -26,26 +26,26 @@ def get_args() -> argparse.Namespace: parser = argparse.ArgumentParser() parser.add_argument("--task", type=str, default="CartPole-v1") - parser.add_argument("--reward-threshold", type=float, default=None) + parser.add_argument("--reward_threshold", type=float, default=None) parser.add_argument("--seed", type=int, default=1626) - parser.add_argument("--eps-test", type=float, default=0.05) - parser.add_argument("--eps-train", type=float, default=0.1) - parser.add_argument("--buffer-size", type=int, default=20000) + parser.add_argument("--eps_test", type=float, default=0.05) + parser.add_argument("--eps_train", type=float, default=0.1) + parser.add_argument("--buffer_size", type=int, default=20000) parser.add_argument("--lr", type=float, default=1e-3) parser.add_argument("--gamma", type=float, default=0.9) - parser.add_argument("--n-step", type=int, default=3) - parser.add_argument("--target-update-freq", type=int, default=320) + parser.add_argument("--n_step", type=int, default=3) + parser.add_argument("--target_update_freq", type=int, default=320) parser.add_argument("--epoch", type=int, default=20) parser.add_argument("--epoch_num_steps", type=int, default=10000) parser.add_argument("--collection_step_num_env_steps", 
type=int, default=10) - parser.add_argument("--update-per-step", type=float, default=0.1) + parser.add_argument("--update_per_step", type=float, default=0.1) parser.add_argument("--batch_size", type=int, default=64) - parser.add_argument("--hidden-sizes", type=int, nargs="*", default=[128, 128, 128, 128]) + parser.add_argument("--hidden_sizes", type=int, nargs="*", default=[128, 128, 128, 128]) parser.add_argument("--num_train_envs", type=int, default=10) parser.add_argument("--num_test_envs", type=int, default=100) parser.add_argument("--logdir", type=str, default="log") parser.add_argument("--render", type=float, default=0.0) - parser.add_argument("--prioritized-replay", action="store_true", default=False) + parser.add_argument("--prioritized_replay", action="store_true", default=False) parser.add_argument("--alpha", type=float, default=0.6) parser.add_argument("--beta", type=float, default=0.4) parser.add_argument( @@ -54,19 +54,19 @@ def get_args() -> argparse.Namespace: default="cuda" if torch.cuda.is_available() else "cpu", ) parser.add_argument( - "--lr-scale", + "--lr_scale", type=float, default=1.0, help="use intrinsic curiosity module with this lr scale", ) parser.add_argument( - "--reward-scale", + "--reward_scale", type=float, default=0.01, help="scaling factor for intrinsic curiosity reward", ) parser.add_argument( - "--forward-loss-weight", + "--forward_loss_weight", type=float, default=0.2, help="weight for the forward model loss in ICM", diff --git a/test/modelbased/test_ppo_icm.py b/test/modelbased/test_ppo_icm.py index 58d91b4c0..c0ebb723a 100644 --- a/test/modelbased/test_ppo_icm.py +++ b/test/modelbased/test_ppo_icm.py @@ -28,9 +28,9 @@ def get_args() -> argparse.Namespace: parser = argparse.ArgumentParser() parser.add_argument("--task", type=str, default="CartPole-v1") - parser.add_argument("--reward-threshold", type=float, default=None) + parser.add_argument("--reward_threshold", type=float, default=None) parser.add_argument("--seed", 
type=int, default=1626) - parser.add_argument("--buffer-size", type=int, default=20000) + parser.add_argument("--buffer_size", type=int, default=20000) parser.add_argument("--lr", type=float, default=3e-4) parser.add_argument("--gamma", type=float, default=0.99) parser.add_argument("--epoch", type=int, default=10) @@ -38,7 +38,7 @@ def get_args() -> argparse.Namespace: parser.add_argument("--collection_step_num_env_steps", type=int, default=2000) parser.add_argument("--update_step_num_repetitions", type=int, default=10) parser.add_argument("--batch_size", type=int, default=64) - parser.add_argument("--hidden-sizes", type=int, nargs="*", default=[64, 64]) + parser.add_argument("--hidden_sizes", type=int, nargs="*", default=[64, 64]) parser.add_argument("--num_train_envs", type=int, default=20) parser.add_argument("--num_test_envs", type=int, default=100) parser.add_argument("--logdir", type=str, default="log") @@ -49,30 +49,30 @@ def get_args() -> argparse.Namespace: default="cuda" if torch.cuda.is_available() else "cpu", ) # ppo special - parser.add_argument("--vf-coef", type=float, default=0.5) - parser.add_argument("--ent-coef", type=float, default=0.0) - parser.add_argument("--eps-clip", type=float, default=0.2) - parser.add_argument("--max-grad-norm", type=float, default=0.5) - parser.add_argument("--gae-lambda", type=float, default=0.95) + parser.add_argument("--vf_coef", type=float, default=0.5) + parser.add_argument("--ent_coef", type=float, default=0.0) + parser.add_argument("--eps_clip", type=float, default=0.2) + parser.add_argument("--max_grad_norm", type=float, default=0.5) + parser.add_argument("--gae_lambda", type=float, default=0.95) parser.add_argument("--return_scaling", type=int, default=0) parser.add_argument("--advantage_normalization", type=int, default=0) - parser.add_argument("--recompute-adv", type=int, default=0) - parser.add_argument("--dual-clip", type=float, default=None) - parser.add_argument("--value-clip", type=int, default=0) + 
parser.add_argument("--recompute_adv", type=int, default=0) + parser.add_argument("--dual_clip", type=float, default=None) + parser.add_argument("--value_clip", type=int, default=0) parser.add_argument( - "--lr-scale", + "--lr_scale", type=float, default=1.0, help="use intrinsic curiosity module with this lr scale", ) parser.add_argument( - "--reward-scale", + "--reward_scale", type=float, default=0.01, help="scaling factor for intrinsic curiosity reward", ) parser.add_argument( - "--forward-loss-weight", + "--forward_loss_weight", type=float, default=0.2, help="weight for the forward model loss in ICM", diff --git a/test/modelbased/test_psrl.py b/test/modelbased/test_psrl.py index 56e81e14c..a84718075 100644 --- a/test/modelbased/test_psrl.py +++ b/test/modelbased/test_psrl.py @@ -21,9 +21,9 @@ def get_args() -> argparse.Namespace: parser = argparse.ArgumentParser() parser.add_argument("--task", type=str, default="NChain-v0") - parser.add_argument("--reward-threshold", type=float, default=None) + parser.add_argument("--reward_threshold", type=float, default=None) parser.add_argument("--seed", type=int, default=1) - parser.add_argument("--buffer-size", type=int, default=50000) + parser.add_argument("--buffer_size", type=int, default=50000) parser.add_argument("--epoch", type=int, default=5) parser.add_argument("--epoch_num_steps", type=int, default=1000) parser.add_argument("--collection_step_num_episodes", type=int, default=1) @@ -31,11 +31,11 @@ def get_args() -> argparse.Namespace: parser.add_argument("--num_test_envs", type=int, default=10) parser.add_argument("--logdir", type=str, default="log") parser.add_argument("--render", type=float, default=0.0) - parser.add_argument("--rew-mean-prior", type=float, default=0.0) - parser.add_argument("--rew-std-prior", type=float, default=1.0) + parser.add_argument("--rew_mean_prior", type=float, default=0.0) + parser.add_argument("--rew_std_prior", type=float, default=1.0) parser.add_argument("--gamma", type=float, 
default=0.99) parser.add_argument("--eps", type=float, default=0.01) - parser.add_argument("--add-done-loop", action="store_true", default=False) + parser.add_argument("--add_done_loop", action="store_true", default=False) parser.add_argument( "--logger", type=str, diff --git a/test/offline/gather_cartpole_data.py b/test/offline/gather_cartpole_data.py index 934affc62..583e2cf68 100644 --- a/test/offline/gather_cartpole_data.py +++ b/test/offline/gather_cartpole_data.py @@ -31,30 +31,30 @@ def expert_file_name() -> str: def get_args() -> argparse.Namespace: parser = argparse.ArgumentParser() parser.add_argument("--task", type=str, default="CartPole-v1") - parser.add_argument("--reward-threshold", type=float, default=None) + parser.add_argument("--reward_threshold", type=float, default=None) parser.add_argument("--seed", type=int, default=1) - parser.add_argument("--eps-test", type=float, default=0.05) - parser.add_argument("--eps-train", type=float, default=0.1) - parser.add_argument("--buffer-size", type=int, default=20000) + parser.add_argument("--eps_test", type=float, default=0.05) + parser.add_argument("--eps_train", type=float, default=0.1) + parser.add_argument("--buffer_size", type=int, default=20000) parser.add_argument("--lr", type=float, default=1e-3) parser.add_argument("--gamma", type=float, default=0.9) - parser.add_argument("--num-quantiles", type=int, default=200) - parser.add_argument("--n-step", type=int, default=3) - parser.add_argument("--target-update-freq", type=int, default=320) + parser.add_argument("--num_quantiles", type=int, default=200) + parser.add_argument("--n_step", type=int, default=3) + parser.add_argument("--target_update_freq", type=int, default=320) parser.add_argument("--epoch", type=int, default=10) parser.add_argument("--epoch_num_steps", type=int, default=10000) parser.add_argument("--collection_step_num_env_steps", type=int, default=10) - parser.add_argument("--update-per-step", type=float, default=0.1) + 
parser.add_argument("--update_per_step", type=float, default=0.1) parser.add_argument("--batch_size", type=int, default=64) - parser.add_argument("--hidden-sizes", type=int, nargs="*", default=[128, 128, 128, 128]) + parser.add_argument("--hidden_sizes", type=int, nargs="*", default=[128, 128, 128, 128]) parser.add_argument("--num_train_envs", type=int, default=10) parser.add_argument("--num_test_envs", type=int, default=100) parser.add_argument("--logdir", type=str, default="log") parser.add_argument("--render", type=float, default=0.0) - parser.add_argument("--prioritized-replay", action="store_true", default=False) + parser.add_argument("--prioritized_replay", action="store_true", default=False) parser.add_argument("--alpha", type=float, default=0.6) parser.add_argument("--beta", type=float, default=0.4) - parser.add_argument("--save-buffer-name", type=str, default=expert_file_name()) + parser.add_argument("--save_buffer_name", type=str, default=expert_file_name()) parser.add_argument( "--device", type=str, diff --git a/test/offline/gather_pendulum_data.py b/test/offline/gather_pendulum_data.py index f433794e3..967be29d7 100644 --- a/test/offline/gather_pendulum_data.py +++ b/test/offline/gather_pendulum_data.py @@ -27,19 +27,19 @@ def expert_file_name() -> str: def get_args() -> argparse.Namespace: parser = argparse.ArgumentParser() parser.add_argument("--task", type=str, default="Pendulum-v1") - parser.add_argument("--reward-threshold", type=float, default=None) + parser.add_argument("--reward_threshold", type=float, default=None) parser.add_argument("--seed", type=int, default=0) - parser.add_argument("--buffer-size", type=int, default=20000) - parser.add_argument("--hidden-sizes", type=int, nargs="*", default=[128, 128]) - parser.add_argument("--actor-lr", type=float, default=1e-3) - parser.add_argument("--critic-lr", type=float, default=1e-3) + parser.add_argument("--buffer_size", type=int, default=20000) + parser.add_argument("--hidden_sizes", type=int, 
nargs="*", default=[128, 128]) + parser.add_argument("--actor_lr", type=float, default=1e-3) + parser.add_argument("--critic_lr", type=float, default=1e-3) parser.add_argument("--epoch", type=int, default=7) parser.add_argument("--epoch_num_steps", type=int, default=8000) parser.add_argument("--batch_size", type=int, default=256) parser.add_argument("--num_train_envs", type=int, default=10) parser.add_argument("--num_test_envs", type=int, default=10) parser.add_argument("--collection_step_num_env_steps", type=int, default=10) - parser.add_argument("--update-per-step", type=float, default=0.125) + parser.add_argument("--update_per_step", type=float, default=0.125) parser.add_argument("--logdir", type=str, default="log") parser.add_argument("--render", type=float, default=0.0) parser.add_argument("--gamma", default=0.99) @@ -49,7 +49,7 @@ def get_args() -> argparse.Namespace: type=str, default="cuda" if torch.cuda.is_available() else "cpu", ) - parser.add_argument("--resume-path", type=str, default=None) + parser.add_argument("--resume_path", type=str, default=None) parser.add_argument( "--watch", default=False, @@ -58,10 +58,10 @@ def get_args() -> argparse.Namespace: ) # sac: parser.add_argument("--alpha", type=float, default=0.2) - parser.add_argument("--auto-alpha", type=int, default=1) - parser.add_argument("--alpha-lr", type=float, default=3e-4) - parser.add_argument("--n-step", type=int, default=3) - parser.add_argument("--save-buffer-name", type=str, default=expert_file_name()) + parser.add_argument("--auto_alpha", type=int, default=1) + parser.add_argument("--alpha_lr", type=float, default=3e-4) + parser.add_argument("--n_step", type=int, default=3) + parser.add_argument("--save_buffer_name", type=str, default=expert_file_name()) return parser.parse_known_args()[0] diff --git a/test/offline/test_bcq.py b/test/offline/test_bcq.py index 4f72a3035..ad5a3f411 100644 --- a/test/offline/test_bcq.py +++ b/test/offline/test_bcq.py @@ -25,11 +25,11 @@ def get_args() 
-> argparse.Namespace: parser = argparse.ArgumentParser() parser.add_argument("--task", type=str, default="Pendulum-v1") - parser.add_argument("--reward-threshold", type=float, default=None) + parser.add_argument("--reward_threshold", type=float, default=None) parser.add_argument("--seed", type=int, default=0) - parser.add_argument("--hidden-sizes", type=int, nargs="*", default=[64]) - parser.add_argument("--actor-lr", type=float, default=1e-3) - parser.add_argument("--critic-lr", type=float, default=1e-3) + parser.add_argument("--hidden_sizes", type=int, nargs="*", default=[64]) + parser.add_argument("--actor_lr", type=float, default=1e-3) + parser.add_argument("--critic_lr", type=float, default=1e-3) parser.add_argument("--epoch", type=int, default=5) parser.add_argument("--epoch_num_steps", type=int, default=500) parser.add_argument("--batch_size", type=int, default=32) @@ -37,7 +37,7 @@ def get_args() -> argparse.Namespace: parser.add_argument("--logdir", type=str, default="log") parser.add_argument("--render", type=float, default=1 / 35) - parser.add_argument("--vae-hidden-sizes", type=int, nargs="*", default=[32, 32]) + parser.add_argument("--vae_hidden_sizes", type=int, nargs="*", default=[32, 32]) # default to 2 * action_dim parser.add_argument("--latent_dim", type=int, default=None) parser.add_argument("--gamma", default=0.99) @@ -51,15 +51,15 @@ def get_args() -> argparse.Namespace: type=str, default="cuda" if torch.cuda.is_available() else "cpu", ) - parser.add_argument("--resume-path", type=str, default=None) + parser.add_argument("--resume_path", type=str, default=None) parser.add_argument( "--watch", default=False, action="store_true", help="watch the play of pre-trained policy only", ) - parser.add_argument("--load-buffer-name", type=str, default=expert_file_name()) - parser.add_argument("--show-progress", action="store_true") + parser.add_argument("--load_buffer_name", type=str, default=expert_file_name()) + parser.add_argument("--show_progress", 
action="store_true") return parser.parse_known_args()[0] diff --git a/test/offline/test_cql.py b/test/offline/test_cql.py index 38cb65286..dcc7d53ad 100644 --- a/test/offline/test_cql.py +++ b/test/offline/test_cql.py @@ -25,29 +25,29 @@ def get_args() -> argparse.Namespace: parser = argparse.ArgumentParser() parser.add_argument("--task", type=str, default="Pendulum-v1") - parser.add_argument("--reward-threshold", type=float, default=None) + parser.add_argument("--reward_threshold", type=float, default=None) parser.add_argument("--seed", type=int, default=1) - parser.add_argument("--hidden-sizes", type=int, nargs="*", default=[64, 64]) - parser.add_argument("--actor-lr", type=float, default=1e-3) - parser.add_argument("--critic-lr", type=float, default=1e-3) + parser.add_argument("--hidden_sizes", type=int, nargs="*", default=[64, 64]) + parser.add_argument("--actor_lr", type=float, default=1e-3) + parser.add_argument("--critic_lr", type=float, default=1e-3) parser.add_argument("--alpha", type=float, default=0.2) - parser.add_argument("--auto-alpha", default=True, action="store_true") - parser.add_argument("--alpha-lr", type=float, default=1e-3) - parser.add_argument("--cql-alpha-lr", type=float, default=1e-3) - parser.add_argument("--start-timesteps", type=int, default=10000) + parser.add_argument("--auto_alpha", default=True, action="store_true") + parser.add_argument("--alpha_lr", type=float, default=1e-3) + parser.add_argument("--cql_alpha_lr", type=float, default=1e-3) + parser.add_argument("--start_timesteps", type=int, default=10000) parser.add_argument("--epoch", type=int, default=5) parser.add_argument("--epoch_num_steps", type=int, default=500) - parser.add_argument("--n-step", type=int, default=3) + parser.add_argument("--n_step", type=int, default=3) parser.add_argument("--batch_size", type=int, default=64) parser.add_argument("--tau", type=float, default=0.005) parser.add_argument("--temperature", type=float, default=1.0) - 
parser.add_argument("--cql-weight", type=float, default=1.0) - parser.add_argument("--with-lagrange", type=bool, default=True) - parser.add_argument("--lagrange-threshold", type=float, default=10.0) + parser.add_argument("--cql_weight", type=float, default=1.0) + parser.add_argument("--with_lagrange", type=bool, default=True) + parser.add_argument("--lagrange_threshold", type=float, default=10.0) parser.add_argument("--gamma", type=float, default=0.99) - parser.add_argument("--eval-freq", type=int, default=1) + parser.add_argument("--eval_freq", type=int, default=1) parser.add_argument("--num_test_envs", type=int, default=10) parser.add_argument("--logdir", type=str, default="log") parser.add_argument("--render", type=float, default=1 / 35) @@ -56,14 +56,14 @@ def get_args() -> argparse.Namespace: type=str, default="cuda" if torch.cuda.is_available() else "cpu", ) - parser.add_argument("--resume-path", type=str, default=None) + parser.add_argument("--resume_path", type=str, default=None) parser.add_argument( "--watch", default=False, action="store_true", help="watch the play of pre-trained policy only", ) - parser.add_argument("--load-buffer-name", type=str, default=expert_file_name()) + parser.add_argument("--load_buffer_name", type=str, default=expert_file_name()) return parser.parse_known_args()[0] diff --git a/test/offline/test_discrete_bcq.py b/test/offline/test_discrete_bcq.py index 565af0981..c04fbacab 100644 --- a/test/offline/test_discrete_bcq.py +++ b/test/offline/test_discrete_bcq.py @@ -29,30 +29,30 @@ def get_args() -> argparse.Namespace: parser = argparse.ArgumentParser() parser.add_argument("--task", type=str, default="CartPole-v1") - parser.add_argument("--reward-threshold", type=float, default=None) + parser.add_argument("--reward_threshold", type=float, default=None) parser.add_argument("--seed", type=int, default=1626) - parser.add_argument("--eps-test", type=float, default=0.001) + parser.add_argument("--eps_test", type=float, default=0.001) 
parser.add_argument("--lr", type=float, default=3e-4) parser.add_argument("--gamma", type=float, default=0.99) - parser.add_argument("--n-step", type=int, default=3) - parser.add_argument("--target-update-freq", type=int, default=320) - parser.add_argument("--unlikely-action-threshold", type=float, default=0.6) - parser.add_argument("--imitation-logits-penalty", type=float, default=0.01) + parser.add_argument("--n_step", type=int, default=3) + parser.add_argument("--target_update_freq", type=int, default=320) + parser.add_argument("--unlikely_action_threshold", type=float, default=0.6) + parser.add_argument("--imitation_logits_penalty", type=float, default=0.01) parser.add_argument("--epoch", type=int, default=5) parser.add_argument("--epoch_num_steps", type=int, default=2000) parser.add_argument("--batch_size", type=int, default=64) - parser.add_argument("--hidden-sizes", type=int, nargs="*", default=[64, 64]) + parser.add_argument("--hidden_sizes", type=int, nargs="*", default=[64, 64]) parser.add_argument("--num_test_envs", type=int, default=100) parser.add_argument("--logdir", type=str, default="log") parser.add_argument("--render", type=float, default=0.0) - parser.add_argument("--load-buffer-name", type=str, default=expert_file_name()) + parser.add_argument("--load_buffer_name", type=str, default=expert_file_name()) parser.add_argument( "--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu", ) parser.add_argument("--resume", action="store_true") - parser.add_argument("--save-interval", type=int, default=4) + parser.add_argument("--save_interval", type=int, default=4) return parser.parse_known_args()[0] diff --git a/test/offline/test_discrete_cql.py b/test/offline/test_discrete_cql.py index 622949d46..aa867a140 100644 --- a/test/offline/test_discrete_cql.py +++ b/test/offline/test_discrete_cql.py @@ -28,23 +28,23 @@ def get_args() -> argparse.Namespace: parser = argparse.ArgumentParser() parser.add_argument("--task", type=str, 
default="CartPole-v1") - parser.add_argument("--reward-threshold", type=float, default=None) + parser.add_argument("--reward_threshold", type=float, default=None) parser.add_argument("--seed", type=int, default=1626) - parser.add_argument("--eps-test", type=float, default=0.001) + parser.add_argument("--eps_test", type=float, default=0.001) parser.add_argument("--lr", type=float, default=3e-3) parser.add_argument("--gamma", type=float, default=0.99) - parser.add_argument("--num-quantiles", type=int, default=200) - parser.add_argument("--n-step", type=int, default=3) - parser.add_argument("--target-update-freq", type=int, default=500) - parser.add_argument("--min-q-weight", type=float, default=10.0) + parser.add_argument("--num_quantiles", type=int, default=200) + parser.add_argument("--n_step", type=int, default=3) + parser.add_argument("--target_update_freq", type=int, default=500) + parser.add_argument("--min_q_weight", type=float, default=10.0) parser.add_argument("--epoch", type=int, default=5) parser.add_argument("--epoch_num_steps", type=int, default=1000) parser.add_argument("--batch_size", type=int, default=32) - parser.add_argument("--hidden-sizes", type=int, nargs="*", default=[64]) + parser.add_argument("--hidden_sizes", type=int, nargs="*", default=[64]) parser.add_argument("--num_test_envs", type=int, default=100) parser.add_argument("--logdir", type=str, default="log") parser.add_argument("--render", type=float, default=0.0) - parser.add_argument("--load-buffer-name", type=str, default=expert_file_name()) + parser.add_argument("--load_buffer_name", type=str, default=expert_file_name()) parser.add_argument( "--device", type=str, diff --git a/test/offline/test_discrete_crr.py b/test/offline/test_discrete_crr.py index 60e33d003..71a2e23d1 100644 --- a/test/offline/test_discrete_crr.py +++ b/test/offline/test_discrete_crr.py @@ -29,20 +29,20 @@ def get_args() -> argparse.Namespace: parser = argparse.ArgumentParser() parser.add_argument("--task", type=str, 
default="CartPole-v1") - parser.add_argument("--reward-threshold", type=float, default=None) + parser.add_argument("--reward_threshold", type=float, default=None) parser.add_argument("--seed", type=int, default=1626) parser.add_argument("--lr", type=float, default=7e-4) parser.add_argument("--gamma", type=float, default=0.99) - parser.add_argument("--n-step", type=int, default=3) - parser.add_argument("--target-update-freq", type=int, default=320) + parser.add_argument("--n_step", type=int, default=3) + parser.add_argument("--target_update_freq", type=int, default=320) parser.add_argument("--epoch", type=int, default=5) parser.add_argument("--epoch_num_steps", type=int, default=1000) parser.add_argument("--batch_size", type=int, default=64) - parser.add_argument("--hidden-sizes", type=int, nargs="*", default=[64, 64]) + parser.add_argument("--hidden_sizes", type=int, nargs="*", default=[64, 64]) parser.add_argument("--num_test_envs", type=int, default=100) parser.add_argument("--logdir", type=str, default="log") parser.add_argument("--render", type=float, default=0.0) - parser.add_argument("--load-buffer-name", type=str, default=expert_file_name()) + parser.add_argument("--load_buffer_name", type=str, default=expert_file_name()) parser.add_argument( "--device", type=str, diff --git a/test/offline/test_gail.py b/test/offline/test_gail.py index 153ae7ff4..f4df597dc 100644 --- a/test/offline/test_gail.py +++ b/test/offline/test_gail.py @@ -25,19 +25,19 @@ def get_args() -> argparse.Namespace: parser = argparse.ArgumentParser() parser.add_argument("--task", type=str, default="Pendulum-v1") - parser.add_argument("--reward-threshold", type=float, default=None) + parser.add_argument("--reward_threshold", type=float, default=None) parser.add_argument("--seed", type=int, default=1) - parser.add_argument("--buffer-size", type=int, default=20000) + parser.add_argument("--buffer_size", type=int, default=20000) parser.add_argument("--lr", type=float, default=1e-3) - 
parser.add_argument("--disc-lr", type=float, default=5e-4) + parser.add_argument("--disc_lr", type=float, default=5e-4) parser.add_argument("--gamma", type=float, default=0.95) parser.add_argument("--epoch", type=int, default=5) parser.add_argument("--epoch_num_steps", type=int, default=150000) parser.add_argument("--collection_step_num_episodes", type=int, default=16) parser.add_argument("--update_step_num_repetitions", type=int, default=2) - parser.add_argument("--disc-update-num", type=int, default=2) + parser.add_argument("--disc_update_num", type=int, default=2) parser.add_argument("--batch_size", type=int, default=128) - parser.add_argument("--hidden-sizes", type=int, nargs="*", default=[64, 64]) + parser.add_argument("--hidden_sizes", type=int, nargs="*", default=[64, 64]) parser.add_argument("--num_train_envs", type=int, default=16) parser.add_argument("--num_test_envs", type=int, default=100) parser.add_argument("--logdir", type=str, default="log") @@ -48,19 +48,19 @@ def get_args() -> argparse.Namespace: default="cuda" if torch.cuda.is_available() else "cpu", ) # ppo special - parser.add_argument("--vf-coef", type=float, default=0.25) - parser.add_argument("--ent-coef", type=float, default=0.0) - parser.add_argument("--eps-clip", type=float, default=0.2) - parser.add_argument("--max-grad-norm", type=float, default=0.5) - parser.add_argument("--gae-lambda", type=float, default=0.95) + parser.add_argument("--vf_coef", type=float, default=0.25) + parser.add_argument("--ent_coef", type=float, default=0.0) + parser.add_argument("--eps_clip", type=float, default=0.2) + parser.add_argument("--max_grad_norm", type=float, default=0.5) + parser.add_argument("--gae_lambda", type=float, default=0.95) parser.add_argument("--return_scaling", type=int, default=1) - parser.add_argument("--dual-clip", type=float, default=None) - parser.add_argument("--value-clip", type=int, default=1) + parser.add_argument("--dual_clip", type=float, default=None) + 
parser.add_argument("--value_clip", type=int, default=1) parser.add_argument("--advantage_normalization", type=int, default=1) - parser.add_argument("--recompute-adv", type=int, default=0) + parser.add_argument("--recompute_adv", type=int, default=0) parser.add_argument("--resume", action="store_true") - parser.add_argument("--save-interval", type=int, default=4) - parser.add_argument("--load-buffer-name", type=str, default=expert_file_name()) + parser.add_argument("--save_interval", type=int, default=4) + parser.add_argument("--load_buffer_name", type=str, default=expert_file_name()) return parser.parse_known_args()[0] diff --git a/test/offline/test_td3_bc.py b/test/offline/test_td3_bc.py index 4f0c74934..99f7e15bc 100644 --- a/test/offline/test_td3_bc.py +++ b/test/offline/test_td3_bc.py @@ -27,24 +27,24 @@ def get_args() -> argparse.Namespace: parser = argparse.ArgumentParser() parser.add_argument("--task", type=str, default="Pendulum-v1") - parser.add_argument("--reward-threshold", type=float, default=None) + parser.add_argument("--reward_threshold", type=float, default=None) parser.add_argument("--seed", type=int, default=0) - parser.add_argument("--hidden-sizes", type=int, nargs="*", default=[64, 64]) - parser.add_argument("--actor-lr", type=float, default=1e-3) - parser.add_argument("--critic-lr", type=float, default=1e-3) + parser.add_argument("--hidden_sizes", type=int, nargs="*", default=[64, 64]) + parser.add_argument("--actor_lr", type=float, default=1e-3) + parser.add_argument("--critic_lr", type=float, default=1e-3) parser.add_argument("--epoch", type=int, default=5) parser.add_argument("--epoch_num_steps", type=int, default=500) - parser.add_argument("--n-step", type=int, default=3) + parser.add_argument("--n_step", type=int, default=3) parser.add_argument("--batch_size", type=int, default=64) parser.add_argument("--alpha", type=float, default=2.5) - parser.add_argument("--exploration-noise", type=float, default=0.1) - 
parser.add_argument("--policy-noise", type=float, default=0.2) - parser.add_argument("--noise-clip", type=float, default=0.5) - parser.add_argument("--update-actor-freq", type=int, default=2) + parser.add_argument("--exploration_noise", type=float, default=0.1) + parser.add_argument("--policy_noise", type=float, default=0.2) + parser.add_argument("--noise_clip", type=float, default=0.5) + parser.add_argument("--update_actor_freq", type=int, default=2) parser.add_argument("--tau", type=float, default=0.005) parser.add_argument("--gamma", type=float, default=0.99) - parser.add_argument("--eval-freq", type=int, default=1) + parser.add_argument("--eval_freq", type=int, default=1) parser.add_argument("--num_test_envs", type=int, default=10) parser.add_argument("--logdir", type=str, default="log") parser.add_argument("--render", type=float, default=1 / 35) @@ -53,14 +53,14 @@ def get_args() -> argparse.Namespace: type=str, default="cuda" if torch.cuda.is_available() else "cpu", ) - parser.add_argument("--resume-path", type=str, default=None) + parser.add_argument("--resume_path", type=str, default=None) parser.add_argument( "--watch", default=False, action="store_true", help="watch the play of pre-trained policy only", ) - parser.add_argument("--load-buffer-name", type=str, default=expert_file_name()) + parser.add_argument("--load_buffer_name", type=str, default=expert_file_name()) return parser.parse_known_args()[0] diff --git a/test/pettingzoo/pistonball.py b/test/pettingzoo/pistonball.py index 71f17407e..b94668b5d 100644 --- a/test/pettingzoo/pistonball.py +++ b/test/pettingzoo/pistonball.py @@ -23,9 +23,9 @@ def get_parser() -> argparse.ArgumentParser: parser = argparse.ArgumentParser() parser.add_argument("--seed", type=int, default=1626) - parser.add_argument("--eps-test", type=float, default=0.05) - parser.add_argument("--eps-train", type=float, default=0.1) - parser.add_argument("--buffer-size", type=int, default=2000) + parser.add_argument("--eps_test", 
type=float, default=0.05) + parser.add_argument("--eps_train", type=float, default=0.1) + parser.add_argument("--buffer_size", type=int, default=2000) parser.add_argument("--lr", type=float, default=1e-3) parser.add_argument( "--gamma", @@ -34,19 +34,19 @@ def get_parser() -> argparse.ArgumentParser: help="a smaller gamma favors earlier win", ) parser.add_argument( - "--n-pistons", + "--n_pistons", type=int, default=3, help="Number of pistons(agents) in the env", ) - parser.add_argument("--n-step", type=int, default=100) - parser.add_argument("--target-update-freq", type=int, default=320) + parser.add_argument("--n_step", type=int, default=100) + parser.add_argument("--target_update_freq", type=int, default=320) parser.add_argument("--epoch", type=int, default=3) parser.add_argument("--epoch_num_steps", type=int, default=500) parser.add_argument("--collection_step_num_env_steps", type=int, default=10) - parser.add_argument("--update-per-step", type=float, default=0.1) + parser.add_argument("--update_per_step", type=float, default=0.1) parser.add_argument("--batch_size", type=int, default=100) - parser.add_argument("--hidden-sizes", type=int, nargs="*", default=[64, 64]) + parser.add_argument("--hidden_sizes", type=int, nargs="*", default=[64, 64]) parser.add_argument("--num_train_envs", type=int, default=10) parser.add_argument("--num_test_envs", type=int, default=10) parser.add_argument("--logdir", type=str, default="log") diff --git a/test/pettingzoo/pistonball_continuous.py b/test/pettingzoo/pistonball_continuous.py index 87f73f104..30efcdd01 100644 --- a/test/pettingzoo/pistonball_continuous.py +++ b/test/pettingzoo/pistonball_continuous.py @@ -74,9 +74,9 @@ def forward( def get_parser() -> argparse.ArgumentParser: parser = argparse.ArgumentParser() parser.add_argument("--seed", type=int, default=1626) - parser.add_argument("--eps-test", type=float, default=0.05) - parser.add_argument("--eps-train", type=float, default=0.1) - 
parser.add_argument("--buffer-size", type=int, default=2000) + parser.add_argument("--eps_test", type=float, default=0.05) + parser.add_argument("--eps_train", type=float, default=0.1) + parser.add_argument("--buffer_size", type=int, default=2000) parser.add_argument("--lr", type=float, default=1e-3) parser.add_argument( "--gamma", @@ -85,21 +85,21 @@ def get_parser() -> argparse.ArgumentParser: help="a smaller gamma favors earlier win", ) parser.add_argument( - "--n-pistons", + "--n_pistons", type=int, default=3, help="Number of pistons(agents) in the env", ) - parser.add_argument("--n-step", type=int, default=100) - parser.add_argument("--target-update-freq", type=int, default=320) + parser.add_argument("--n_step", type=int, default=100) + parser.add_argument("--target_update_freq", type=int, default=320) parser.add_argument("--epoch", type=int, default=5) parser.add_argument("--epoch_num_steps", type=int, default=500) parser.add_argument("--collection_step_num_env_steps", type=int, default=10) parser.add_argument("--collection_step_num_episodes", type=int, default=16) parser.add_argument("--update_step_num_repetitions", type=int, default=2) - parser.add_argument("--update-per-step", type=float, default=0.1) + parser.add_argument("--update_per_step", type=float, default=0.1) parser.add_argument("--batch_size", type=int, default=32) - parser.add_argument("--hidden-sizes", type=int, nargs="*", default=[64, 64]) + parser.add_argument("--hidden_sizes", type=int, nargs="*", default=[64, 64]) parser.add_argument("--num_train_envs", type=int, default=10) parser.add_argument("--num_test_envs", type=int, default=10) parser.add_argument("--logdir", type=str, default="log") @@ -116,18 +116,18 @@ def get_parser() -> argparse.ArgumentParser: default="cuda" if torch.cuda.is_available() else "cpu", ) # ppo special - parser.add_argument("--vf-coef", type=float, default=0.25) - parser.add_argument("--ent-coef", type=float, default=0.0) - parser.add_argument("--eps-clip", 
type=float, default=0.2) - parser.add_argument("--max-grad-norm", type=float, default=0.5) - parser.add_argument("--gae-lambda", type=float, default=0.95) + parser.add_argument("--vf_coef", type=float, default=0.25) + parser.add_argument("--ent_coef", type=float, default=0.0) + parser.add_argument("--eps_clip", type=float, default=0.2) + parser.add_argument("--max_grad_norm", type=float, default=0.5) + parser.add_argument("--gae_lambda", type=float, default=0.95) parser.add_argument("--return_scaling", type=int, default=1) - parser.add_argument("--dual-clip", type=float, default=None) - parser.add_argument("--value-clip", type=int, default=1) + parser.add_argument("--dual_clip", type=float, default=None) + parser.add_argument("--value_clip", type=int, default=1) parser.add_argument("--advantage_normalization", type=int, default=1) - parser.add_argument("--recompute-adv", type=int, default=0) + parser.add_argument("--recompute_adv", type=int, default=0) parser.add_argument("--resume", action="store_true") - parser.add_argument("--save-interval", type=int, default=4) + parser.add_argument("--save_interval", type=int, default=4) parser.add_argument("--render", type=float, default=0.0) return parser diff --git a/test/pettingzoo/tic_tac_toe.py b/test/pettingzoo/tic_tac_toe.py index d99778be5..4e7146199 100644 --- a/test/pettingzoo/tic_tac_toe.py +++ b/test/pettingzoo/tic_tac_toe.py @@ -34,9 +34,9 @@ def get_env(render_mode: str | None = None) -> PettingZooEnv: def get_parser() -> argparse.ArgumentParser: parser = argparse.ArgumentParser() parser.add_argument("--seed", type=int, default=1626) - parser.add_argument("--eps-test", type=float, default=0.05) - parser.add_argument("--eps-train", type=float, default=0.1) - parser.add_argument("--buffer-size", type=int, default=20000) + parser.add_argument("--eps_test", type=float, default=0.05) + parser.add_argument("--eps_train", type=float, default=0.1) + parser.add_argument("--buffer_size", type=int, default=20000) 
parser.add_argument("--lr", type=float, default=1e-4) parser.add_argument( "--gamma", @@ -44,20 +44,20 @@ def get_parser() -> argparse.ArgumentParser: default=0.9, help="a smaller gamma favors earlier win", ) - parser.add_argument("--n-step", type=int, default=3) - parser.add_argument("--target-update-freq", type=int, default=320) + parser.add_argument("--n_step", type=int, default=3) + parser.add_argument("--target_update_freq", type=int, default=320) parser.add_argument("--epoch", type=int, default=50) parser.add_argument("--epoch_num_steps", type=int, default=1000) parser.add_argument("--collection_step_num_env_steps", type=int, default=10) - parser.add_argument("--update-per-step", type=float, default=0.1) + parser.add_argument("--update_per_step", type=float, default=0.1) parser.add_argument("--batch_size", type=int, default=64) - parser.add_argument("--hidden-sizes", type=int, nargs="*", default=[128, 128, 128, 128]) + parser.add_argument("--hidden_sizes", type=int, nargs="*", default=[128, 128, 128, 128]) parser.add_argument("--num_train_envs", type=int, default=10) parser.add_argument("--num_test_envs", type=int, default=10) parser.add_argument("--logdir", type=str, default="log") parser.add_argument("--render", type=float, default=0.1) parser.add_argument( - "--win-rate", + "--win_rate", type=float, default=0.6, help="the expected winning rate: Optimal policy can get 0.7", @@ -69,19 +69,19 @@ def get_parser() -> argparse.ArgumentParser: help="no training, watch the play of pre-trained models", ) parser.add_argument( - "--agent-id", + "--agent_id", type=int, default=2, help="the learned agent plays as the agent_id-th player. 
Choices are 1 and 2.", ) parser.add_argument( - "--resume-path", + "--resume_path", type=str, default="", help="the path of agent pth file for resuming from a pre-trained agent", ) parser.add_argument( - "--opponent-path", + "--opponent_path", type=str, default="", help="the path of opponent agent pth file for resuming from a pre-trained agent", From 0fa36cd19f6027a47ea7dab98633f652a69c2364 Mon Sep 17 00:00:00 2001 From: Michael Panchenko Date: Sat, 17 May 2025 00:00:53 +0200 Subject: [PATCH 195/230] v2: Better handling of max_action in actor --- test/continuous/test_redq.py | 1 - test/offline/gather_pendulum_data.py | 1 - test/offline/test_bcq.py | 4 ++-- tianshou/algorithm/modelfree/reinforce.py | 29 +++++++++++++---------- 4 files changed, 19 insertions(+), 16 deletions(-) diff --git a/test/continuous/test_redq.py b/test/continuous/test_redq.py index c813ed12f..a43c43098 100644 --- a/test/continuous/test_redq.py +++ b/test/continuous/test_redq.py @@ -64,7 +64,6 @@ def test_redq(args: argparse.Namespace = get_args(), enable_assertions: bool = T space_info = SpaceInfo.from_env(env) args.state_shape = space_info.observation_info.obs_shape args.action_shape = space_info.action_info.action_shape - args.max_action = space_info.action_info.max_action if args.reward_threshold is None: default_reward_threshold = {"Pendulum-v0": -250, "Pendulum-v1": -250} args.reward_threshold = default_reward_threshold.get( diff --git a/test/offline/gather_pendulum_data.py b/test/offline/gather_pendulum_data.py index 967be29d7..2e4513e3e 100644 --- a/test/offline/gather_pendulum_data.py +++ b/test/offline/gather_pendulum_data.py @@ -73,7 +73,6 @@ def gather_data() -> VectorReplayBuffer: space_info = SpaceInfo.from_env(env) args.state_shape = space_info.observation_info.obs_shape args.action_shape = space_info.action_info.action_shape - args.max_action = space_info.action_info.max_action if args.reward_threshold is None: default_reward_threshold = {"Pendulum-v0": -250, "Pendulum-v1": 
-250} diff --git a/test/offline/test_bcq.py b/test/offline/test_bcq.py index ad5a3f411..fc24f7cca 100644 --- a/test/offline/test_bcq.py +++ b/test/offline/test_bcq.py @@ -104,7 +104,7 @@ def test_bcq(args: argparse.Namespace = get_args(), enable_assertions: bool = Tr output_dim=args.action_dim, hidden_sizes=args.hidden_sizes, ) - actor = Perturbation(preprocess_net=net_a, max_action=args.max_action, phi=args.phi).to( + actor_perturbation = Perturbation(preprocess_net=net_a, max_action=args.max_action, phi=args.phi).to( args.device, ) actor_optim = AdamOptimizerFactory(lr=args.actor_lr) @@ -141,7 +141,7 @@ def test_bcq(args: argparse.Namespace = get_args(), enable_assertions: bool = Tr vae_optim = AdamOptimizerFactory() policy = BCQPolicy( - actor_perturbation=actor, + actor_perturbation=actor_perturbation, critic=critic, vae=vae, action_space=env.action_space, diff --git a/tianshou/algorithm/modelfree/reinforce.py b/tianshou/algorithm/modelfree/reinforce.py index 505eb2d68..9cc6f2323 100644 --- a/tianshou/algorithm/modelfree/reinforce.py +++ b/tianshou/algorithm/modelfree/reinforce.py @@ -148,14 +148,20 @@ def __init__( action_scaling=action_scaling, action_bound_method=action_bound_method, ) - if action_scaling and not np.isclose(actor.max_action, 1.0): - warnings.warn( - "action_scaling and action_bound_method are only intended " - "to deal with unbounded model action space, but find actor model " - f"bound action space with max_action={actor.max_action}. " - "Consider using unbounded=True option of the actor model, " - "or set action_scaling to False and action_bound_method to None.", - ) + if action_scaling: + try: + max_action = float(actor.max_action) # type: ignore + if np.isclose(max_action, 1.0): + warnings.warn( + "action_scaling and action_bound_method are only intended " + "to deal with unbounded model action space, but find actor model " + f"bound action space with max_action={actor.max_action}. 
" + "Consider using unbounded=True option of the actor model, " + "or set action_scaling to False and action_bound_method to None.", + ) + except: + pass + self.actor = actor self.dist_fn = dist_fn self._eps = 1e-8 @@ -286,7 +292,7 @@ def add_discounted_returns( should be marked by done flag, unfinished (or collecting) episodes will be recognized by buffer.unfinished_index(). :param buffer: the corresponding replay buffer. - :param numpy.ndarray indices: tell batch's location in buffer, batch is equal + :param indices: tell batch's location in buffer, batch is equal to buffer[indices]. """ v_s_ = np.full(indices.shape, self.ret_rms.mean) @@ -306,8 +312,7 @@ def add_discounted_returns( self.ret_rms.update(unnormalized_returns) else: batch.returns = unnormalized_returns - batch: BatchWithReturnsProtocol - return batch + return cast(BatchWithReturnsProtocol, batch) class Reinforce(OnPolicyAlgorithm[ActorPolicyProbabilistic]): @@ -316,7 +321,7 @@ class Reinforce(OnPolicyAlgorithm[ActorPolicyProbabilistic]): def __init__( self, *, - policy: TActorPolicy, + policy: ActorPolicyProbabilistic, gamma: float = 0.99, return_standardization: bool = False, optim: OptimizerFactory, From 2ab33f48664b3e7966802e5fc9ea3d9581547b2a Mon Sep 17 00:00:00 2001 From: Dominik Jain Date: Sat, 17 May 2025 00:24:17 +0200 Subject: [PATCH 196/230] v2: Rename trainer parameter reward_metric -> multi_agent_return_reduction --- test/pettingzoo/pistonball.py | 2 +- test/pettingzoo/tic_tac_toe.py | 2 +- tianshou/trainer/trainer.py | 16 ++++++++-------- 3 files changed, 10 insertions(+), 10 deletions(-) diff --git a/test/pettingzoo/pistonball.py b/test/pettingzoo/pistonball.py index b94668b5d..7f7857824 100644 --- a/test/pettingzoo/pistonball.py +++ b/test/pettingzoo/pistonball.py @@ -177,7 +177,7 @@ def reward_metric(rews: np.ndarray) -> np.ndarray: update_step_num_gradient_steps_per_sample=args.update_per_step, logger=logger, test_in_train=False, - reward_metric=reward_metric, + 
multi_agent_return_reduction=reward_metric, ) ) return result, marl_algorithm diff --git a/test/pettingzoo/tic_tac_toe.py b/test/pettingzoo/tic_tac_toe.py index 4e7146199..099be67a5 100644 --- a/test/pettingzoo/tic_tac_toe.py +++ b/test/pettingzoo/tic_tac_toe.py @@ -222,7 +222,7 @@ def reward_metric(rews: np.ndarray) -> np.ndarray: update_step_num_gradient_steps_per_sample=args.update_per_step, logger=logger, test_in_train=False, - reward_metric=reward_metric, + multi_agent_return_reduction=reward_metric, ) ) diff --git a/tianshou/trainer/trainer.py b/tianshou/trainer/trainer.py index e31150acf..8c5301df2 100644 --- a/tianshou/trainer/trainer.py +++ b/tianshou/trainer/trainer.py @@ -164,13 +164,13 @@ class TrainerParams(ToStringMixin): which is given in :attr:`logger`. """ - reward_metric: Callable[[np.ndarray], np.ndarray] | None = None + multi_agent_return_reduction: Callable[[np.ndarray], np.ndarray] | None = None """ a function with signature - ``f(rewards: np.ndarray with shape (num_episode, agent_num)) -> np.ndarray with shape (num_episode,)``, - which is used in multi-agent RL. We need to return a single scalar for each episode's result + ``f(returns: np.ndarray with shape (num_episode, agent_num)) -> np.ndarray with shape (num_episode,)``, + which is used in multi-agent RL. We need to return a single scalar for each episode's return to monitor training in the multi-agent RL setting. This function specifies what is the desired metric, - e.g., the reward of agent 1 or the average reward over all agents. + e.g., the return achieved by agent 1 or the average return over all agents. 
""" logger: BaseLogger | None = None @@ -637,8 +637,8 @@ def _collect_test_episodes( if self.params.test_fn: self.params.test_fn(self._epoch, self._env_step) result = collector.collect(n_episode=self.params.test_step_num_episodes) - if self.params.reward_metric: # TODO: move into collector - rew = self.params.reward_metric(result.returns) + if self.params.multi_agent_return_reduction: + rew = self.params.multi_agent_return_reduction(result.returns) result.returns = rew result.returns_stat = SequenceSummaryStats.from_sequence(rew) if self._logger and self._env_step is not None: @@ -933,8 +933,8 @@ def _collect_training_data(self) -> CollectStats: if collect_stats.n_collected_episodes > 0: assert collect_stats.returns_stat is not None # for mypy assert collect_stats.lens_stat is not None # for mypy - if self.params.reward_metric: # TODO: move inside collector - rew = self.params.reward_metric(collect_stats.returns) + if self.params.multi_agent_return_reduction: + rew = self.params.multi_agent_return_reduction(collect_stats.returns) collect_stats.returns = rew collect_stats.returns_stat = SequenceSummaryStats.from_sequence(rew) From cf1a34d1b1f249787abda188ff27946c4329965d Mon Sep 17 00:00:00 2001 From: Michael Panchenko Date: Sat, 17 May 2025 00:30:29 +0200 Subject: [PATCH 197/230] v2: block comment --- tianshou/trainer/trainer.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/tianshou/trainer/trainer.py b/tianshou/trainer/trainer.py index e31150acf..1d6f0b5a2 100644 --- a/tianshou/trainer/trainer.py +++ b/tianshou/trainer/trainer.py @@ -1089,15 +1089,13 @@ def _update_step( # just for logging, no functional role self._policy_update_time += training_stat.train_time - # Note 1: this is the main difference to the off-policy trainer! 
- # The second difference is that batches of data are sampled without replacement - # during training, whereas in off-policy or offline training, the batches are - # sampled with replacement (and potentially custom prioritization). # Note 2: in the policy-update we modify the buffer, which is not very clean. # currently the modification will erase previous samples but keep things like - # _ep_rew and _ep_len. This means that such quantities can no longer be computed + # _ep_rew and _ep_len (b/c keep_statistics=True). This is needed since the collection might have stopped + # in the middle of an episode and in the next collect iteration we need these numbers to compute correct + # return and episode length values. With the current code structure, this means that after an update and buffer reset + # such quantities can no longer be computed # from samples still contained in the buffer, which is also not clean - # TODO: improve this situation self.params.train_collector.reset_buffer(keep_statistics=True) # The step is the number of mini-batches used for the update, so essentially From 91b9120c8e5d85d39f0152c3578e95d4b67affc7 Mon Sep 17 00:00:00 2001 From: Michael Panchenko Date: Sat, 17 May 2025 00:33:16 +0200 Subject: [PATCH 198/230] v2: minor, renamed kwarg --- tianshou/trainer/trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tianshou/trainer/trainer.py b/tianshou/trainer/trainer.py index 1d6f0b5a2..3dd4c28a1 100644 --- a/tianshou/trainer/trainer.py +++ b/tianshou/trainer/trainer.py @@ -1072,7 +1072,7 @@ class OnPolicyTrainer(OnlineTrainer[OnPolicyAlgorithm, OnPolicyTrainerParams]): def _update_step( self, - result: CollectStatsBase | None = None, + collect_stats: CollectStatsBase | None = None, ) -> TrainingStats: """Perform one on-policy update by passing the entire buffer to the algorithm's update method.""" assert self.params.train_collector is not None From 54c6ebae772ca78e06fb4c6d758976c43075686f Mon Sep 17 00:00:00 2001 From: 
Dominik Jain Date: Sat, 17 May 2025 00:37:55 +0200 Subject: [PATCH 199/230] v2: Remove some TODOs --- tianshou/trainer/trainer.py | 8 -------- 1 file changed, 8 deletions(-) diff --git a/tianshou/trainer/trainer.py b/tianshou/trainer/trainer.py index 8c5301df2..310ff3286 100644 --- a/tianshou/trainer/trainer.py +++ b/tianshou/trainer/trainer.py @@ -699,8 +699,6 @@ def _test_step( def _training_step(self) -> _TrainingStepResult: """Performs one training step.""" - # TODO: move moving average computation and logging into its own logger - # TODO: maybe think about a command line logger instead of always printing data dict def _update_moving_avg_stats_and_log_update_data(self, update_stat: TrainingStats) -> None: """Log losses, update moving average stats, and also modify the smoothed_loss in update_stat.""" cur_losses_dict = update_stat.get_loss_stats_dict() @@ -1018,7 +1016,6 @@ class OffPolicyTrainer(OnlineTrainer[OffPolicyAlgorithm, OffPolicyTrainerParams] def _update_step( self, - # TODO: this is the only implementation where collect_stats is actually needed. Maybe change interface? collect_stats: CollectStatsBase, ) -> TrainingStats: """Perform `update_step_num_gradient_steps_per_sample * n_collected_steps` gradient steps by sampling @@ -1076,7 +1073,6 @@ def _update_step( ) -> TrainingStats: """Perform one on-policy update by passing the entire buffer to the algorithm's update method.""" assert self.params.train_collector is not None - # TODO: add logging like in off-policy. Iteration over minibatches currently happens in the algorithms themselves. log.info( f"Performing on-policy update on buffer of length {len(self.params.train_collector.buffer)}", ) @@ -1089,10 +1085,6 @@ def _update_step( # just for logging, no functional role self._policy_update_time += training_stat.train_time - # Note 1: this is the main difference to the off-policy trainer! 
- # The second difference is that batches of data are sampled without replacement - # during training, whereas in off-policy or offline training, the batches are - # sampled with replacement (and potentially custom prioritization). # Note 2: in the policy-update we modify the buffer, which is not very clean. # currently the modification will erase previous samples but keep things like # _ep_rew and _ep_len. This means that such quantities can no longer be computed From efa64a3d0ee886cd0eebef0c53218bf18d943ee6 Mon Sep 17 00:00:00 2001 From: Michael Panchenko Date: Sat, 17 May 2025 00:45:05 +0200 Subject: [PATCH 200/230] v2: Minor typefix --- tianshou/trainer/trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tianshou/trainer/trainer.py b/tianshou/trainer/trainer.py index 3ceec87ee..16683f747 100644 --- a/tianshou/trainer/trainer.py +++ b/tianshou/trainer/trainer.py @@ -812,7 +812,7 @@ class OnlineTrainer( def __init__( self, algorithm: TAlgorithm, - params: OnlineTrainerParams, + params: TOnlineTrainerParams, ): super().__init__(algorithm, params) self._env_episode = 0 From 50ef0c96b919beca2345fccb3a692d32c56533e0 Mon Sep 17 00:00:00 2001 From: Michael Panchenko Date: Sat, 17 May 2025 00:50:39 +0200 Subject: [PATCH 201/230] v2: removed unused kwarg --- tianshou/algorithm/algorithm_base.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/tianshou/algorithm/algorithm_base.py b/tianshou/algorithm/algorithm_base.py index 2619b344a..eed5e4b49 100644 --- a/tianshou/algorithm/algorithm_base.py +++ b/tianshou/algorithm/algorithm_base.py @@ -723,7 +723,6 @@ def compute_nstep_return( target_q_fn: Callable[[ReplayBuffer, np.ndarray], torch.Tensor], gamma: float = 0.99, n_step: int = 1, - return_scaling: bool = False, ) -> BatchWithReturnsProtocol: r""" Computes the n-step return for Q-learning targets, adds it to the batch and returns the resulting batch. 
@@ -749,8 +748,6 @@ def compute_nstep_return( Typically set between 0.9 and 0.99 for most reinforcement learning tasks :param n_step: the number of estimation step, should be an int greater than 0. - :param return_scaling: whether to standardise returns to Normal(0, 1); - supported is currently suspended! :return: a Batch. The result will be stored in `batch.returns` as a torch.Tensor with the same shape as target_q_fn's return tensor. """ From 2972b1319840e148abfb8628e9f13dc83cca9aaa Mon Sep 17 00:00:00 2001 From: Michael Panchenko Date: Sat, 17 May 2025 01:01:05 +0200 Subject: [PATCH 202/230] v2: Comments, typos, minor renaming --- test/offline/test_bcq.py | 4 +++- tianshou/algorithm/modelfree/reinforce.py | 6 +++--- tianshou/trainer/trainer.py | 2 +- 3 files changed, 7 insertions(+), 5 deletions(-) diff --git a/test/offline/test_bcq.py b/test/offline/test_bcq.py index fc24f7cca..eeafcf922 100644 --- a/test/offline/test_bcq.py +++ b/test/offline/test_bcq.py @@ -104,7 +104,9 @@ def test_bcq(args: argparse.Namespace = get_args(), enable_assertions: bool = Tr output_dim=args.action_dim, hidden_sizes=args.hidden_sizes, ) - actor_perturbation = Perturbation(preprocess_net=net_a, max_action=args.max_action, phi=args.phi).to( + actor_perturbation = Perturbation( + preprocess_net=net_a, max_action=args.max_action, phi=args.phi + ).to( args.device, ) actor_optim = AdamOptimizerFactory(lr=args.actor_lr) diff --git a/tianshou/algorithm/modelfree/reinforce.py b/tianshou/algorithm/modelfree/reinforce.py index 9cc6f2323..17949b977 100644 --- a/tianshou/algorithm/modelfree/reinforce.py +++ b/tianshou/algorithm/modelfree/reinforce.py @@ -150,16 +150,16 @@ def __init__( ) if action_scaling: try: - max_action = float(actor.max_action) # type: ignore + max_action = float(actor.max_action) # type: ignore if np.isclose(max_action, 1.0): warnings.warn( "action_scaling and action_bound_method are only intended " - "to deal with unbounded model action space, but find actor model " + "to 
deal with unbounded model action space, but found actor model " f"bound action space with max_action={actor.max_action}. " "Consider using unbounded=True option of the actor model, " "or set action_scaling to False and action_bound_method to None.", ) - except: + except BaseException: pass self.actor = actor diff --git a/tianshou/trainer/trainer.py b/tianshou/trainer/trainer.py index 16683f747..642bf9a9c 100644 --- a/tianshou/trainer/trainer.py +++ b/tianshou/trainer/trainer.py @@ -1089,7 +1089,7 @@ def _update_step( # currently the modification will erase previous samples but keep things like # _ep_rew and _ep_len (b/c keep_statistics=True). This is needed since the collection might have stopped # in the middle of an episode and in the next collect iteration we need these numbers to compute correct - # return and episode length values. With the current code structure, this means that after an update and buffer reset + # return and episode length values. With the current code structure, this means that after an update and buffer reset # such quantities can no longer be computed # from samples still contained in the buffer, which is also not clean self.params.train_collector.reset_buffer(keep_statistics=True) From 03123510aae879626e7926f8a0cffce9bccccc97 Mon Sep 17 00:00:00 2001 From: Michael Panchenko Date: Sat, 17 May 2025 01:02:38 +0200 Subject: [PATCH 203/230] v2: add_exploration_noise - raise error on wrong type instead of doing nothing --- tianshou/algorithm/modelfree/bdqn.py | 14 +++++++++----- tianshou/algorithm/modelfree/dqn.py | 11 ++++++++--- 2 files changed, 17 insertions(+), 8 deletions(-) diff --git a/tianshou/algorithm/modelfree/bdqn.py b/tianshou/algorithm/modelfree/bdqn.py index fddfe3f42..a173d852b 100644 --- a/tianshou/algorithm/modelfree/bdqn.py +++ b/tianshou/algorithm/modelfree/bdqn.py @@ -1,4 +1,4 @@ -from typing import Any, cast +from typing import cast import gymnasium as gym import numpy as np @@ -65,7 +65,6 @@ def forward( batch: 
ObsBatchProtocol, state: dict | BatchProtocol | np.ndarray | None = None, model: torch.nn.Module | None = None, - **kwargs: Any, ) -> ModelOutputBatchProtocol: if model is None: model = self.model @@ -84,8 +83,9 @@ def add_exploration_noise( batch: ObsBatchProtocol, ) -> TArrOrActBatch: eps = self.eps_training if self.is_within_training_step else self.eps_inference - # TODO: This looks problematic; the non-array case is silently ignored - if isinstance(act, np.ndarray) and not np.isclose(eps, 0.0): + if not np.isclose(eps, 0.0): + return act + if isinstance(act, np.ndarray): bsz = len(act) rand_mask = np.random.rand(bsz) < eps rand_act = np.random.randint( @@ -96,7 +96,11 @@ def add_exploration_noise( if hasattr(batch.obs, "mask"): rand_act += batch.obs.mask act[rand_mask] = rand_act[rand_mask] - return act + return act + else: + raise NotImplementedError( + f"Currently only numpy arrays are supported, got {type(act)=}." + ) class BDQN(QLearningOffPolicyAlgorithm[BDQNPolicy]): diff --git a/tianshou/algorithm/modelfree/dqn.py b/tianshou/algorithm/modelfree/dqn.py index 57142d2d6..182427f91 100644 --- a/tianshou/algorithm/modelfree/dqn.py +++ b/tianshou/algorithm/modelfree/dqn.py @@ -156,8 +156,10 @@ def add_exploration_noise( batch: ObsBatchProtocol, ) -> TArrOrActBatch: eps = self.eps_training if self.is_within_training_step else self.eps_inference - # TODO: This looks problematic; the non-array case is silently ignored - if isinstance(act, np.ndarray) and not np.isclose(eps, 0.0): + eps = self.eps_training if self.is_within_training_step else self.eps_inference + if not np.isclose(eps, 0.0): + return act + if isinstance(act, np.ndarray): batch_size = len(act) rand_mask = np.random.rand(batch_size) < eps self.action_space = cast(Discrete, self.action_space) # for mypy @@ -167,7 +169,10 @@ def add_exploration_noise( q += batch.obs.mask rand_act = q.argmax(axis=1) act[rand_mask] = rand_act[rand_mask] - return act + return act + raise NotImplementedError( + 
f"Currently only numpy array is supported for action, but got {type(act)}" + ) TDQNPolicy = TypeVar("TDQNPolicy", bound=DiscreteQLearningPolicy) From 2f95cec3b422371f255467676d26affadd575817 Mon Sep 17 00:00:00 2001 From: Dominik Jain Date: Sat, 17 May 2025 01:03:32 +0200 Subject: [PATCH 204/230] v2: Remove TODOs --- tianshou/algorithm/modelfree/discrete_sac.py | 1 - tianshou/algorithm/modelfree/fqf.py | 2 -- tianshou/algorithm/modelfree/reinforce.py | 1 - tianshou/algorithm/modelfree/trpo.py | 3 +-- 4 files changed, 1 insertion(+), 6 deletions(-) diff --git a/tianshou/algorithm/modelfree/discrete_sac.py b/tianshou/algorithm/modelfree/discrete_sac.py index ae1d1318f..99cf2c01c 100644 --- a/tianshou/algorithm/modelfree/discrete_sac.py +++ b/tianshou/algorithm/modelfree/discrete_sac.py @@ -28,7 +28,6 @@ class DiscreteSACTrainingStats(SACTrainingStats): TDiscreteSACTrainingStats = TypeVar("TDiscreteSACTrainingStats", bound=DiscreteSACTrainingStats) -# TODO: This is a vanilla discrete actor policy; we may not need this "specific" class. class DiscreteSACPolicy(Policy): def __init__( self, diff --git a/tianshou/algorithm/modelfree/fqf.py b/tianshou/algorithm/modelfree/fqf.py index 8aba2e048..8ebc6fa93 100644 --- a/tianshou/algorithm/modelfree/fqf.py +++ b/tianshou/algorithm/modelfree/fqf.py @@ -116,8 +116,6 @@ def __init__( optim: OptimizerFactory, fraction_optim: OptimizerFactory, gamma: float = 0.99, - # TODO: used as num_quantiles in QRDQNPolicy, but num_fractions in FQFPolicy. - # Rename? Or at least explain what happens here. 
num_fractions: int = 32, ent_coef: float = 0.0, n_step_return_horizon: int = 1, diff --git a/tianshou/algorithm/modelfree/reinforce.py b/tianshou/algorithm/modelfree/reinforce.py index 17949b977..567c5123e 100644 --- a/tianshou/algorithm/modelfree/reinforce.py +++ b/tianshou/algorithm/modelfree/reinforce.py @@ -81,7 +81,6 @@ def __init__( deterministic_eval: bool = False, action_space: gym.Space, observation_space: gym.Space | None = None, - # TODO: why change the default from the base? action_scaling: bool = True, action_bound_method: Literal["clip", "tanh"] | None = "clip", ) -> None: diff --git a/tianshou/algorithm/modelfree/trpo.py b/tianshou/algorithm/modelfree/trpo.py index 952ef679f..a52c1ac44 100644 --- a/tianshou/algorithm/modelfree/trpo.py +++ b/tianshou/algorithm/modelfree/trpo.py @@ -132,7 +132,7 @@ def _update_with_batch( # type: ignore[override] for minibatch in batch.split(split_batch_size, merge_last=True): # optimize actor # direction: calculate villia gradient - dist = self.policy(minibatch).dist # TODO could come from batch + dist = self.policy(minibatch).dist ratio = (dist.log_prob(minibatch.act) - minibatch.logp_old).exp().float() ratio = ratio.reshape(ratio.size(0), -1).transpose(0, 1) actor_loss = -(ratio * minibatch.adv).mean() @@ -191,7 +191,6 @@ def _update_with_batch( # type: ignore[override] ) # optimize critic - # TODO: remove type-ignore once the top-level type-ignore is removed for _ in range(self.optim_critic_iters): value = self.critic(minibatch.obs).flatten() vf_loss = F.mse_loss(minibatch.returns, value) From 0d38340e0bb8154258c6919a0d243d2b74ec65a3 Mon Sep 17 00:00:00 2001 From: Dominik Jain Date: Sat, 17 May 2025 02:08:53 +0200 Subject: [PATCH 205/230] v2: Fix mypy issues --- tianshou/algorithm/modelfree/bdqn.py | 2 +- tianshou/algorithm/modelfree/dqn.py | 3 +-- tianshou/algorithm/modelfree/reinforce.py | 2 +- tianshou/data/buffer/__init__.py | 2 +- 4 files changed, 4 insertions(+), 5 deletions(-) diff --git 
a/tianshou/algorithm/modelfree/bdqn.py b/tianshou/algorithm/modelfree/bdqn.py index a173d852b..0869169c8 100644 --- a/tianshou/algorithm/modelfree/bdqn.py +++ b/tianshou/algorithm/modelfree/bdqn.py @@ -96,7 +96,7 @@ def add_exploration_noise( if hasattr(batch.obs, "mask"): rand_act += batch.obs.mask act[rand_mask] = rand_act[rand_mask] - return act + return act # type: ignore[return-value] else: raise NotImplementedError( f"Currently only numpy arrays are supported, got {type(act)=}." diff --git a/tianshou/algorithm/modelfree/dqn.py b/tianshou/algorithm/modelfree/dqn.py index 182427f91..8118325c9 100644 --- a/tianshou/algorithm/modelfree/dqn.py +++ b/tianshou/algorithm/modelfree/dqn.py @@ -155,7 +155,6 @@ def add_exploration_noise( act: TArrOrActBatch, batch: ObsBatchProtocol, ) -> TArrOrActBatch: - eps = self.eps_training if self.is_within_training_step else self.eps_inference eps = self.eps_training if self.is_within_training_step else self.eps_inference if not np.isclose(eps, 0.0): return act @@ -169,7 +168,7 @@ def add_exploration_noise( q += batch.obs.mask rand_act = q.argmax(axis=1) act[rand_mask] = rand_act[rand_mask] - return act + return act # type: ignore[return-value] raise NotImplementedError( f"Currently only numpy array is supported for action, but got {type(act)}" ) diff --git a/tianshou/algorithm/modelfree/reinforce.py b/tianshou/algorithm/modelfree/reinforce.py index 567c5123e..7fcc2f19a 100644 --- a/tianshou/algorithm/modelfree/reinforce.py +++ b/tianshou/algorithm/modelfree/reinforce.py @@ -149,7 +149,7 @@ def __init__( ) if action_scaling: try: - max_action = float(actor.max_action) # type: ignore + max_action = float(actor.max_action) if np.isclose(max_action, 1.0): warnings.warn( "action_scaling and action_bound_method are only intended " diff --git a/tianshou/data/buffer/__init__.py b/tianshou/data/buffer/__init__.py index 3a90d3bb8..0609fddf8 100644 --- a/tianshou/data/buffer/__init__.py +++ b/tianshou/data/buffer/__init__.py @@ -1,4 +1,4 @@ 
-def _backward_compatibility(): +def _backward_compatibility() -> None: import sys from . import buffer_base From d059481c9eda13c501c471313ead1eb7f1accbeb Mon Sep 17 00:00:00 2001 From: Michael Panchenko Date: Sat, 17 May 2025 14:51:03 +0200 Subject: [PATCH 206/230] v1: improvement in doc-build commands --- docs/autogen_rst.py | 4 +- docs/spelling_wordlist.txt | 294 ------------------------------------- pyproject.toml | 5 +- 3 files changed, 4 insertions(+), 299 deletions(-) delete mode 100644 docs/spelling_wordlist.txt diff --git a/docs/autogen_rst.py b/docs/autogen_rst.py index b1a8b18d9..1477b2117 100644 --- a/docs/autogen_rst.py +++ b/docs/autogen_rst.py @@ -114,8 +114,8 @@ def make_rst(src_root, rst_root, clean=False, overwrite=False, package_prefix="" for f in files_in_dir if os.path.isdir(os.path.join(root, dirname, f)) and not f.startswith("_") ] - if not module_names: - log.debug(f"Skipping {dirname} as it does not contain any .py files") + if not module_names and not "__init__.py" in files_in_dir: + log.debug(f"Skipping {dirname} as it does not contain any modules or __init__.py") continue package_qualname = f"{base_package_qualname}.{dirname}" package_index_rst_path = os.path.join( diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt deleted file mode 100644 index 40aa69970..000000000 --- a/docs/spelling_wordlist.txt +++ /dev/null @@ -1,294 +0,0 @@ -tianshou -arXiv -tanh -lr -logits -env -envs -optim -eps -timelimit -TimeLimit -envpool -EnvPool -maxsize -timestep -timesteps -numpy -ndarray -stackoverflow -tensorboard -state_dict -len -tac -fqf -iqn -qrdqn -rl -offpolicy -onpolicy -quantile -quantiles -dqn -param -async -subprocess -deque -nn -equ -cql -fn -boolean -pre -np -cuda -rnn -rew -pre -perceptron -bsz -dataset -mujoco -jit -nstep -preprocess -preprocessing -repo -ReLU -namespace -recv -th -utils -NaN -linesearch -hyperparameters -pseudocode -entropies -nn -config -cpu -rms -debias -indice -regularizer -miniblock -modularize 
-serializable -softmax -vectorized -optimizers -undiscounted -submodule -subclasses -submodules -tfevent -dirichlet -docstring -webpage -formatter -num -py -pythonic -中文文档位于 -conda -miniconda -Amir -Andreas -Antonoglou -Beattie -Bellemare -Charles -Daan -Demis -Dharshan -Fidjeland -Georg -Hassabis -Helen -Ioannis -Kavukcuoglu -King -Koray -Kumaran -Legg -Mnih -Ostrovski -Petersen -Riedmiller -Rusu -Sadik -Shane -Stig -Veness -Volodymyr -Wierstra -Lillicrap -Pritzel -Heess -Erez -Yuval -Tassa -Schulman -Filip -Wolski -Prafulla -Dhariwal -Radford -Oleg -Klimov -Kaichao -Jiayi -Weng -Duburcq -Huayu -Yi -Su -Strens -Ornstein -Uhlenbeck -mse -gail -airl -ppo -Jupyter -Colab -Colaboratory -IPendulum -Reacher -Runtime -Nvidia -Enduro -Qbert -Seaquest -subnets -subprocesses -isort -yapf -pydocstyle -Args -tuples -tuple -Multi -multi -parameterized -Proximal -metadata -GPU -Dopamine -builtin -params -inplace -deepcopy -Gaussian -stdout -parallelization -minibatch -minibatches -MLP -backpropagation -dataclass -superset -subtype -subdirectory -picklable -ShmemVectorEnv -Github -wandb -jupyter -img -src -parallelized -infty -venv -venvs -subproc -bcq -highlevel -icm -modelbased -td -psrl -ddpg -npg -tf -trpo -crr -pettingzoo -multidiscrete -vecbuf -prio -colab -segtree -multiagent -mapolicy -sensai -sensAI -docstrings -superclass -iterable -functools -str -sklearn -attr -bc -redq -modelfree -bdq -util -logp -autogenerated -subpackage -subpackages -recurse -rollout -rollouts -prepend -prepends -dict -dicts -pytorch -tensordict -onwards -Dominik -Tsinghua -Tianshou -appliedAI -macOS -joblib -master -Panchenko -BA -BH -BO -BD -configs -postfix -backend -rliable -hl -v_s -v_s_ -obs -obs_next -dtype -iqm -kwarg -entrypoint -interquantile -init -kwarg -kwargs -autocompletion -codebase -indexable -sliceable -gaussian -logprob -monte -carlo -subclass -subclassing -dist -dists -subbuffer -subbuffers diff --git a/pyproject.toml b/pyproject.toml index df37e9a06..fd1004878 100644 --- 
a/pyproject.toml +++ b/pyproject.toml @@ -228,13 +228,12 @@ _poetry_sort = "poetry sort" clean-nbs = "python docs/nbstripout.py" format = ["_ruff_format", "_ruff_format_nb", "_black_format", "_poetry_install_sort_plugin", "_poetry_sort"] _autogen_rst = "python docs/autogen_rst.py" -_sphinx_build = "sphinx-build -W -b html docs docs/_build" +_sphinx_build = "sphinx-build -b html docs docs/_build -W --keep-going" _jb_generate_toc = "python docs/create_toc.py" _jb_generate_config = "jupyter-book config sphinx docs/" doc-clean = "rm -rf docs/_build" doc-generate-files = ["_autogen_rst", "_jb_generate_toc", "_jb_generate_config"] -doc-spellcheck = "sphinx-build -W -b spelling docs docs/_build" -doc-build = ["doc-generate-files", "doc-spellcheck", "_sphinx_build"] +doc-build = ["doc-generate-files", "_sphinx_build"] _mypy = "mypy tianshou test examples" _mypy_nb = "nbqa mypy docs" type-check = ["_mypy", "_mypy_nb"] From 04290c81e91bdca00e98f97302673c27cc08aa61 Mon Sep 17 00:00:00 2001 From: Michael Panchenko Date: Sat, 17 May 2025 14:51:51 +0200 Subject: [PATCH 207/230] v2: docs - removed outdated documents, fixed remaning --- docs/01_tutorials/00_dqn.rst | 2 +- docs/01_tutorials/01_concepts.rst | 10 +- docs/02_notebooks/L0_overview.ipynb | 250 ---- docs/02_notebooks/L1_Batch.ipynb | 4 +- docs/02_notebooks/L2_Buffer.ipynb | 4 +- .../L3_Vectorized__Environment.ipynb | 4 +- docs/02_notebooks/L4_GAE.ipynb | 265 +++++ docs/02_notebooks/L4_Policy.ipynb | 1009 ----------------- docs/02_notebooks/L5_Collector.ipynb | 27 +- docs/02_notebooks/L6_Trainer.ipynb | 283 ----- docs/02_notebooks/L7_Experiment.ipynb | 341 ------ 11 files changed, 284 insertions(+), 1915 deletions(-) delete mode 100644 docs/02_notebooks/L0_overview.ipynb create mode 100644 docs/02_notebooks/L4_GAE.ipynb delete mode 100644 docs/02_notebooks/L4_Policy.ipynb delete mode 100644 docs/02_notebooks/L6_Trainer.ipynb delete mode 100644 docs/02_notebooks/L7_Experiment.ipynb diff --git a/docs/01_tutorials/00_dqn.rst 
b/docs/01_tutorials/00_dqn.rst index bb73d4c52..77611fdf1 100644 --- a/docs/01_tutorials/00_dqn.rst +++ b/docs/01_tutorials/00_dqn.rst @@ -188,7 +188,7 @@ The main function of collector is the collect function, which can be summarized Train Policy with a Trainer --------------------------- -Tianshou provides :class:`~tianshou.trainer.OnpolicyTrainer`, :class:`~tianshou.trainer.OffpolicyTrainer`, +Tianshou provides :class:`~tianshou.trainer.OnPolicyTrainer`, :class:`~tianshou.trainer.OffpolicyTrainer`, and :class:`~tianshou.trainer.OfflineTrainer`. The trainer will automatically stop training when the policy reaches the stop condition ``stop_fn`` on test collector. Since DQN is an off-policy algorithm, we use the :class:`~tianshou.trainer.OffpolicyTrainer` as follows: diff --git a/docs/01_tutorials/01_concepts.rst b/docs/01_tutorials/01_concepts.rst index aa244b137..0f381262e 100644 --- a/docs/01_tutorials/01_concepts.rst +++ b/docs/01_tutorials/01_concepts.rst @@ -339,8 +339,6 @@ Thus, we need a time-related interface for calculating the 2-step return. :meth: This code does not consider the done flag, so it may not work very well. It shows two ways to get :math:`s_{t + 2}` from the replay buffer easily in :meth:`~tianshou.algorithm.BasePolicy.process_fn`. -For other method, you can check out :doc:`/03_api/policy/index`. We give the usage of policy class a high-level explanation in :ref:`pseudocode`. - Collector --------- @@ -384,10 +382,10 @@ Once you have a collector and a policy, you can start writing the training metho Tianshou has three types of trainer: :func:`~tianshou.trainer.onpolicy_trainer` for on-policy algorithms such as Policy Gradient, :func:`~tianshou.trainer.offpolicy_trainer` for off-policy algorithms such as DQN, and :func:`~tianshou.trainer.offline_trainer` for offline algorithms such as BCQ. Please check out :doc:`/03_api/trainer/index` for the usage. 
-We also provide the corresponding iterator-based trainer classes :class:`~tianshou.trainer.OnpolicyTrainer`, :class:`~tianshou.trainer.OffpolicyTrainer`, :class:`~tianshou.trainer.OfflineTrainer` to facilitate users writing more flexible training logic: +We also provide the corresponding iterator-based trainer classes :class:`~tianshou.trainer.OnPolicyTrainer`, :class:`~tianshou.trainer.OffpolicyTrainer`, :class:`~tianshou.trainer.OfflineTrainer` to facilitate users writing more flexible training logic: :: - trainer = OnpolicyTrainer(...) + trainer = OnPolicyTrainer(...) for epoch, epoch_stat, info in trainer: print(f"Epoch: {epoch}") print(epoch_stat) @@ -399,8 +397,8 @@ We also provide the corresponding iterator-based trainer classes :class:`~tiansh # or even iterate on several trainers at the same time - trainer1 = OnpolicyTrainer(...) - trainer2 = OnpolicyTrainer(...) + trainer1 = OnPolicyTrainer(...) + trainer2 = OnPolicyTrainer(...) for result1, result2, ... in zip(trainer1, trainer2, ...): compare_results(result1, result2, ...) diff --git a/docs/02_notebooks/L0_overview.ipynb b/docs/02_notebooks/L0_overview.ipynb deleted file mode 100644 index 4c983deb3..000000000 --- a/docs/02_notebooks/L0_overview.ipynb +++ /dev/null @@ -1,250 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": { - "editable": true, - "id": "r7aE6Rq3cAEE", - "slideshow": { - "slide_type": "" - }, - "tags": [] - }, - "source": [ - "# Overview\n", - "To begin, ensure you have Tianshou and the Gym environment installed by executing the following commands. This tutorials will always keep up with the latest version of Tianshou since they also serve as a test for the latest version. 
For users on older versions of Tianshou, please consult the [documentation](https://tianshou.readthedocs.io/en/latest/) corresponding to your version..\n" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "1_mLTSEIcY2c" - }, - "source": [ - "## Run the code" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "IcFNmCjYeIIU" - }, - "source": [ - "Below is a short script that use a certain DRL algorithm (PPO) to solve the classic CartPole-v1\n", - "problem in Gym. Simply run it and **don't worry** if you can't understand the code very well. That is\n", - "exactly what this tutorial is for.\n", - "\n", - "If the script ends normally, you will see the evaluation result printed out before the first\n", - "epoch is finished." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, - "tags": [ - "hide-cell", - "remove-output" - ] - }, - "outputs": [], - "source": [ - "%%capture\n", - "\n", - "import gymnasium as gym\n", - "import torch\n", - "\n", - "from tianshou.algorithm import PPOPolicy\n", - "from tianshou.data import Collector, CollectStats, VectorReplayBuffer\n", - "from tianshou.env import DummyVectorEnv\n", - "from tianshou.trainer import OnpolicyTrainer\n", - "from tianshou.utils.net.common import ActorCritic, MLPActor\n", - "from tianshou.utils.net.discrete import Actor, Critic\n", - "\n", - "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "tags": [ - "hide-output" - ] - }, - "outputs": [], - "source": [ - "# environments\n", - "env = gym.make(\"CartPole-v1\")\n", - "train_envs = DummyVectorEnv([lambda: gym.make(\"CartPole-v1\") for _ in range(20)])\n", - "test_envs = DummyVectorEnv([lambda: gym.make(\"CartPole-v1\") for _ in range(10)])\n", - "\n", - "# model & optimizer\n", - "assert env.observation_space.shape is not None # for mypy\n", - "net = 
MLPActor(state_shape=env.observation_space.shape, hidden_sizes=[64, 64], device=device)\n", - "\n", - "assert isinstance(env.action_space, gym.spaces.Discrete) # for mypy\n", - "actor = Actor(preprocess_net=net, action_shape=env.action_space.n, device=device).to(device)\n", - "critic = Critic(preprocess_net=net, device=device).to(device)\n", - "actor_critic = ActorCritic(actor, critic)\n", - "optim = torch.optim.Adam(actor_critic.parameters(), lr=0.0003)\n", - "\n", - "# PPO policy\n", - "dist = torch.distributions.Categorical\n", - "policy: PPOPolicy = PPOPolicy(\n", - " actor=actor,\n", - " critic=critic,\n", - " optim=optim,\n", - " dist_fn=dist,\n", - " action_space=env.action_space,\n", - " action_scaling=False,\n", - ")\n", - "\n", - "# collector\n", - "train_collector = Collector[CollectStats](\n", - " policy,\n", - " train_envs,\n", - " VectorReplayBuffer(20000, len(train_envs)),\n", - ")\n", - "test_collector = Collector[CollectStats](policy, test_envs)\n", - "\n", - "# trainer\n", - "train_result = OnpolicyTrainer(\n", - " policy=policy,\n", - " batch_size=256,\n", - " train_collector=train_collector,\n", - " test_collector=test_collector,\n", - " max_epoch=10,\n", - " step_per_epoch=50000,\n", - " repeat_per_collect=10,\n", - " episode_per_test=10,\n", - " step_per_collect=2000,\n", - " stop_fn=lambda mean_reward: mean_reward >= 195,\n", - ").run()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, - "tags": [] - }, - "outputs": [], - "source": [ - "train_result.pprint_asdict()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "G9YEQptYvCgx", - "outputId": "2a9b5b22-be50-4bb7-ae93-af7e65e7442a" - }, - "outputs": [], - "source": [ - "# Let's watch its performance!\n", - "policy.eval()\n", - "eval_result = test_collector.collect(n_episode=3, render=False)\n", - 
"print(f\"Final reward: {eval_result.returns.mean()}, length: {eval_result.lens.mean()}\")" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "xFYlcPo8fpPU" - }, - "source": [ - "## Tutorial Introduction\n", - "\n", - "A common DRL experiment as is shown above may require many components to work together. The agent, the\n", - "environment (possibly parallelized ones), the replay buffer and the trainer all work together to complete a\n", - "training task.\n", - "\n", - "
\n", - "\n", - "\n", - "
\n" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "kV_uOyimj-bk" - }, - "source": [ - "In Tianshou, all of these main components are factored out as different building blocks, which you\n", - "can use to create your own algorithm and finish your own experiment.\n", - "\n", - "Building blocks may include:\n", - "- Batch\n", - "- Replay Buffer\n", - "- Vectorized Environment Wrapper\n", - "- Policy (the agent and the training algorithm)\n", - "- Data Collector\n", - "- Trainer\n", - "- Logger\n", - "\n", - "\n", - "These notebooks tutorials will guide you through all the modules one by one." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "S0mNKwH9i6Ek" - }, - "source": [ - "## Further reading" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "M3NPSUnAov4L" - }, - "source": [ - "### What if I am not familiar with the PPO algorithm itself?\n", - "As for the DRL algorithms themselves, we will refer you to the [Spinning up documentation](https://spinningup.openai.com/en/latest/algorithms/ppo.html), where they provide\n", - "plenty of resources and guides if you want to study the DRL algorithms. In Tianshou's tutorials, we will\n", - "focus on the usages of different modules, but not the algorithms themselves." 
- ] - } - ], - "metadata": { - "accelerator": "GPU", - "colab": { - "provenance": [] - }, - "kernelspec": { - "display_name": "Python 3 (ipykernel)", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.11.5" - } - }, - "nbformat": 4, - "nbformat_minor": 4 -} diff --git a/docs/02_notebooks/L1_Batch.ipynb b/docs/02_notebooks/L1_Batch.ipynb index 4e56c4a1c..d40869287 100644 --- a/docs/02_notebooks/L1_Batch.ipynb +++ b/docs/02_notebooks/L1_Batch.ipynb @@ -31,8 +31,6 @@ }, "outputs": [], "source": [ - "%%capture\n", - "\n", "import pickle\n", "\n", "import numpy as np\n", @@ -401,7 +399,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.11.5" + "version": "3.11.4" } }, "nbformat": 4, diff --git a/docs/02_notebooks/L2_Buffer.ipynb b/docs/02_notebooks/L2_Buffer.ipynb index 892aaff4f..4f51abca5 100644 --- a/docs/02_notebooks/L2_Buffer.ipynb +++ b/docs/02_notebooks/L2_Buffer.ipynb @@ -6,8 +6,6 @@ "metadata": {}, "outputs": [], "source": [ - "%%capture\n", - "\n", "import pickle\n", "\n", "import numpy as np\n", @@ -421,7 +419,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.11.5" + "version": "3.11.4" } }, "nbformat": 4, diff --git a/docs/02_notebooks/L3_Vectorized__Environment.ipynb b/docs/02_notebooks/L3_Vectorized__Environment.ipynb index 374ee2ba8..19e5489a2 100644 --- a/docs/02_notebooks/L3_Vectorized__Environment.ipynb +++ b/docs/02_notebooks/L3_Vectorized__Environment.ipynb @@ -47,8 +47,6 @@ }, "outputs": [], "source": [ - "%%capture\n", - "\n", "import time\n", "\n", "import gymnasium as gym\n", @@ -223,7 +221,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.11.7" + "version": 
"3.11.4" } }, "nbformat": 4, diff --git a/docs/02_notebooks/L4_GAE.ipynb b/docs/02_notebooks/L4_GAE.ipynb new file mode 100644 index 000000000..8393d6f92 --- /dev/null +++ b/docs/02_notebooks/L4_GAE.ipynb @@ -0,0 +1,265 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "QJ5krjrcbuiA" + }, + "source": [ + "# Notes on Generalized Advantage Estimation\n", + "\n", + "\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "UPVl5LBEWJ0t" + }, + "source": [ + "## How to compute GAE on your own?\n", + "(Note that for this reading you need to understand the calculation of [GAE](https://arxiv.org/abs/1506.02438) advantage first)\n", + "\n", + "In terms of code implementation, perhaps the most difficult and annoying part is computing GAE advantage. Just now, we use the `self.compute_episodic_return()` method inherited from `BasePolicy` to save us from all those troubles. However, it is still important that we know the details behind this.\n", + "\n", + "To compute GAE advantage, the usage of `self.compute_episodic_return()` may go like:" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "D34GlVvPNz08", + "outputId": "43a4e5df-59b5-4e4a-c61c-e69090810215" + }, + "source": [ + "```python\n", + "batch, indices = dummy_buffer.sample(0) # 0 means sampling all the data from the buffer\n", + "returns, advantage = Algorithm.compute_episodic_return(\n", + " batch=batch,\n", + " buffer=dummy_buffer,\n", + " indices=indices,\n", + " v_s_=np.zeros(10),\n", + " v_s=np.zeros(10),\n", + " gamma=1.0,\n", + " gae_lambda=1.0,\n", + ")\n", + "```" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "In the code above, we sample all the 10 data in the buffer and try to compute the GAE advantage. However, the way the returns are computed here might be a bit misleading. 
In fact, the last episode is unfinished, but its last step saved in the batch is treated as a terminal state, since it assumes that there are no future rewards. The episode is not terminated yet, it is truncated, so the agent could still get rewards in the future. Terminated and truncated episodes should indeed be treated differently.\n", + "The return of a step is the (discounted) sum of the future rewards from that step until the end of the episode. \n", + "\\begin{equation}\n", + "R_{t}=\\sum_{k=t}^{T} \\gamma^{k-t} r_{k}\n", + "\\end{equation}\n", + "Thus, at the last step of a terminated episode the return is equal to the reward of that step, since there are no future states.\n", + "\\begin{equation}\n", + "R_{T,terminated}=r_{T}\n", + "\\end{equation}\n", + "\n", + "However, if the episode was truncated, the return at the last step is usually better represented by the estimated value of that state, which is the expected return from that state onwards.\n", + "\\begin{align*}\n", + "R_{T,truncated}=V^{\\pi}\\left(s_{T}\\right) \\quad & \\text{or} \\quad R_{T,truncated}=Q^{\\pi}(s_{T},a_{T})\n", + "\\end{align*}\n", + "Moreover, if the next state was also observed (but not its reward), then an even better estimate would be the reward of the last step plus the discounted value of the next state.\n", + "\\begin{align*}\n", + "R_{T,truncated}=r_T+\\gamma V^{\\pi}\\left(s_{T+1}\\right)\n", + "\\end{align*}" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "h_5Dt6XwQLXV" + }, + "source": [ + "\n", + "As we know, we need to estimate the value function of every observation to compute GAE advantage. So `v_s` is the value of `batch.obs`, and `v_s_` is the value of `batch.obs_next`.
This is usually computed by:\n", + "\n", + "`v_s = critic(batch.obs)`,\n", + "\n", + "`v_s_ = critic(batch.obs_next)`,\n", + "\n", + "where both `v_s` and `v_s_` are 10-dimensional arrays and `critic` is usually a neural network.\n", + "\n", + "After we've got all those values, GAE can be computed following the equation below." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "ooHNIICGUO19" + }, + "source": [ + "\\begin{aligned}\n", + "\\hat{A}_{t}^{\\mathrm{GAE}(\\gamma, \\lambda)} :=& \\sum_{l=0}^{\\infty}(\\gamma \\lambda)^{l} \\delta_{t+l}^{V}\n", + "\\end{aligned}\n", + "\n", + "where\n", + "\n", + "\\begin{equation}\n", + "\\delta_{t}^{V}=r_{t}+\\gamma V\\left(s_{t+1}\\right)-V\\left(s_{t}\\right)\n", + "\\end{equation}\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "eV6XZaouU7EV" + }, + "source": [ + "Unfortunately, if you follow this equation, which is taken from the paper, you will probably get slightly lower performance than expected. There are at least 3 \"bugs\" in this equation." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "FCxD9gNNVYbd" + }, + "source": [ + "**First**, Gym always returns an `obs_next`, even if this is already the last step. If the episode has terminated, the value of that next state is exactly 0, and you should not let the neural network estimate it." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "rNZNUNgQVvRJ", + "outputId": "44354595-c25a-4da8-b4d8-cffa31ac4b7d" + }, + "source": [ + "```python\n", + "# Assume v_s_ is got by calling critic(batch.obs_next)\n", + "v_s_ = np.ones(10)\n", + "v_s_ *= ~batch.done\n", + "```" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "2EtMi18QWXTN" + }, + "source": [ + "After the fix above, we will perhaps get a more accurate estimate.\n", + "\n", + "**Secondly**, you must know when to stop bootstrapping. Usually we stop bootstrapping when we meet a `done` flag.
However, in the buffer above, the last (10th) step is not marked by done=True, because the collection has not finished. We must know all those unfinished steps so that we know when to stop bootstrapping.\n", + "\n", + "Luckily, this can be done with the help of the buffer, because buffers in Tianshou not only store data, but also help you manage data trajectories." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "saluvX4JU6bC", + "outputId": "2994d178-2f33-40a0-a6e4-067916b0b5c5" + }, + "source": [ + "```python\n", + "unfinished_indexes = dummy_buffer.unfinished_index()\n", + "done_indexes = np.where(batch.done)[0]\n", + "stop_bootstrap_ids = np.concatenate([unfinished_indexes, done_indexes])\n", + "```" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "qp6vVE4dYWv1" + }, + "source": [ + "**Thirdly**, there are some special indexes which are marked with a done flag, yet the value of their obs_next should not be zero. This is again because done does not differentiate between terminated and truncated. These steps are usually the last step of an episode that stopped not because the agent can no longer get any rewards (value=0), but because the episode was too long, so we had to truncate it. These kinds of steps are always marked with `info['TimeLimit.truncated']=True` in Gym."
+ ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "tWkqXRJfZTvV" + }, + "source": [ + "As a result, we need to rewrite the line above\n", + "\n", + "`v_s_ *= ~batch.done`" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "kms-QtxKZe-M" + }, + "source": [ + "to\n", + "\n", + "```\n", + "mask = batch.info['TimeLimit.truncated'] | (~batch.done)\n", + "v_s_ *= mask\n", + "\n", + "```\n", + "\n", + "\n", + "\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "u_aPPoKraBu6" + }, + "source": [ + "## Summary\n", + "If you are already bored by now, simply remember that Tianshou can handle all these little details so that you can focus on the algorithm itself. Just call `Algorithm.compute_episodic_return()`.\n", + "\n", + "If you are still interested, we recommend checking Appendix C of this [paper](https://arxiv.org/abs/2107.14171v2) and the implementation of `Algorithm.value_mask()` and `Algorithm.compute_episodic_return()` for details." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "2cPnUXRBWKD9" + }, + "source": [ + "
\n", + "\n", + "
\n", + "
\n", + "\n", + "
" + ] + } + ], + "metadata": { + "colab": { + "provenance": [] + }, + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.7" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/docs/02_notebooks/L4_Policy.ipynb b/docs/02_notebooks/L4_Policy.ipynb deleted file mode 100644 index fec0c5cfe..000000000 --- a/docs/02_notebooks/L4_Policy.ipynb +++ /dev/null @@ -1,1009 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": { - "id": "PNM9wqstBSY_" - }, - "source": [ - "# Policy\n", - "In reinforcement learning, the agent interacts with environments to improve itself. In this tutorial we will concentrate on the agent part. In Tianshou, both the agent and the core DRL algorithm are implemented in the Policy module. Tianshou provides more than 20 Policy modules, each representing one DRL algorithm. See supported algorithms [here](https://tianshou.readthedocs.io/en/master/03_api/policy/index.html).\n", - "\n", - "
\n", - "\n", - "\n", - " The agents interacting with the environment \n", - "
\n", - "\n", - "All Policy modules inherit from a BasePolicy Class and share the same interface." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "ZqdHYdoJJS51" - }, - "source": [ - "## Creating your own Policy\n", - "We will use the simple PGPolicy, also called REINFORCE algorithm Policy, to show the implementation of a Policy Module. The Policy we implement here will be a scaled-down version of [PGPolicy](https://tianshou.readthedocs.io/en/master/03_api/policy/modelfree/pg.html) in Tianshou." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "editable": true, - "id": "cDlSjASbJmy-", - "slideshow": { - "slide_type": "" - }, - "tags": [ - "hide-cell", - "remove-output" - ] - }, - "outputs": [], - "source": [ - "%%capture\n", - "\n", - "from typing import Any, cast\n", - "\n", - "import gymnasium as gym\n", - "import numpy as np\n", - "import torch\n", - "\n", - "from tianshou.algorithm import BasePolicy\n", - "from tianshou.algorithm.modelfree.pg import (\n", - " PGTrainingStats,\n", - " TDistFnDiscrOrCont,\n", - " TPGTrainingStats,\n", - ")\n", - "from tianshou.data import (\n", - " Batch,\n", - " ReplayBuffer,\n", - " SequenceSummaryStats,\n", - " to_torch,\n", - " to_torch_as,\n", - ")\n", - "from tianshou.data.batch import BatchProtocol\n", - "from tianshou.data.types import (\n", - " BatchWithReturnsProtocol,\n", - " DistBatchProtocol,\n", - " ObsBatchProtocol,\n", - " RolloutBatchProtocol,\n", - ")\n", - "from tianshou.utils import RunningMeanStd\n", - "from tianshou.utils.net.common import Net\n", - "from tianshou.utils.net.discrete import Actor\n", - "from tianshou.utils.torch_utils import policy_within_training_step, torch_train_mode" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Protocols\n", - "Note: as we learned in tutorial [L1_batch](https://tianshou.readthedocs.io/en/master/02_notebooks/L1_Batch.html#), Tianshou uses `Batch` to store data. 
`Batch` is a dataclass that can store any data you want. In order to have more control about what kind of batch data is expected and produced in each processing step we use protocols. \n", - "For example, `BatchWithReturnsProtocol` specifies that the batch should have fields `obs`, `act`, `rew`, `done`, `obs_next`, `info` and `returns`. This is not only for type checking, but also for IDE support.\n", - "To learn more about protocols, please refer to the official documentation ([PEP 544](https://www.python.org/dev/peps/pep-0544/)) or to mypy documentation ([Protocols](https://mypy.readthedocs.io/en/stable/protocols.html)).\n" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Initialization\n", - "Firstly we create the `PGPolicy` by inheriting from `BasePolicy` in Tianshou." - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "```python\n", - "class PGPolicy(BasePolicy):\n", - " \"\"\"Implementation of REINFORCE algorithm.\"\"\"\n", - "\n", - " def __init__(self) -> None:\n", - " super().__init__(\n", - " action_space=action_space,\n", - " observation_space=observation_space,\n", - " )\n", - "```" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "qc1RnIBbLCDN" - }, - "source": [ - "The Policy Module mainly does two things:\n", - "\n", - "1. `policy.forward()` receives observation and other information (stored in a Batch) from the environment and returns a new Batch containing the next action and other information.\n", - "2. `policy.update()` receives training data sampled from the replay buffer and updates the policy network. It returns a dataclass containing logging details.\n", - "\n", - "
\n", - "\n", - "\n", - " policy.forward() and policy.update() \n", - "
\n", - "\n", - "We also need to take care of the following things:\n", - "\n", - "1. Since Tianshou is a **Deep** RL libraries, there should be a policy network and a Torch optimizer in our Policy Module.\n", - "2. In Tianshou's BasePolicy, `Policy.update()` first calls `Policy.process_fn()` to \n", - "preprocess training data and computes quantities like episodic returns (gradient free), \n", - "then it will call `Policy.learn()` to perform the back-propagation.\n", - "3. Each Policy is accompanied by a dedicated implementation of `TrainingStats` , which store details of each training step.\n", - "\n", - "This is how we get the implementation below." - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "```python\n", - "class PGPolicy(BasePolicy[TPGTrainingStats]):\n", - " \"\"\"Implementation of REINFORCE algorithm.\"\"\"\n", - "\n", - " def __init__(\n", - " self, \n", - " actor: torch.nn.Module, \n", - " optim: torch.optim.Optimizer, \n", - " action_space: gym.Space\n", - " ):\n", - " super().__init__(\n", - " action_space=action_space,\n", - " observation_space=observation_space\n", - " )\n", - " self.actor = model\n", - " self.optim = optim\n", - "\n", - " def process_fn(\n", - " self, \n", - " batch: RolloutBatchProtocol, \n", - " buffer: ReplayBuffer, \n", - " indices: np.ndarray\n", - " ) -> BatchWithReturnsProtocol:\n", - " \"\"\"Compute the discounted returns for each transition.\"\"\"\n", - " return batch\n", - "\n", - " def forward(\n", - " self, \n", - " batch: ObsBatchProtocol,\n", - " state: dict | BatchProtocol | np.ndarray | None = None,\n", - " **kwargs: Any\n", - " ) -> DistBatchProtocol:\n", - " \"\"\"Compute action over the given batch data.\"\"\"\n", - " act = None\n", - " return Batch(act=act)\n", - "\n", - " def learn(\n", - " self,\n", - " batch: BatchWithReturnsProtocol, \n", - " batch_size: int | None, \n", - " repeat: int,\n", - " *args: Any,\n", - " **kwargs: Any,\n", - " ) -> TPGTrainingStats:\n", - " 
\"\"\"Perform the back-propagation.\"\"\"\n", - " return PGTrainingStats(loss=loss_summary_stat)\n", - "```" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "tjtqjt8WRY5e" - }, - "source": [ - "### Policy.forward()\n", - "According to the equation of REINFORCE algorithm in Spinning Up's [documentation](https://spinningup.openai.com/en/latest/algorithms/vpg.html), we need to map the observation to an action distribution in action space using the neural network (`self.actor`).\n", - "\n", - "
\n", - "\n", - "
\n", - "\n", - "Let us suppose the action space is discrete, and the distribution is a simple categorical distribution." - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "```python\n", - "def forward(\n", - " self,\n", - " batch: ObsBatchProtocol,\n", - " state: dict | BatchProtocol | np.ndarray | None = None,\n", - " **kwargs: Any,\n", - ") -> DistBatchProtocol:\n", - " \"\"\"Compute action over the given batch data.\"\"\"\n", - " logits, hidden = self.actor(batch.obs, state=state)\n", - " dist = self.dist_fn(logits)\n", - " act = dist.sample()\n", - " result = Batch(logits=logits, act=act, state=hidden, dist=dist)\n", - " return result\n", - "```" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "CultfOeuTx2V" - }, - "source": [ - "### Policy.process_fn()\n", - "Now that we have defined our actor, if given training data we can set up a loss function and optimize our neural network. However, before that, we must first calculate episodic returns for every step in our training data to construct the REINFORCE loss function.\n", - "\n", - "Calculating episodic return is not hard, given `ReplayBuffer.next()` allows us to access every reward to go in an episode. 
A more convenient way would be to simply use the built-in method `BasePolicy.compute_episodic_return()` inherited from BasePolicy.\n" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "```python\n", - "def process_fn(\n", - " self,\n", - " batch: RolloutBatchProtocol,\n", - " buffer: ReplayBuffer,\n", - " indices: np.ndarray,\n", - ") -> BatchWithReturnsProtocol:\n", - " \"\"\"Compute the discounted returns for each transition.\"\"\"\n", - " v_s_ = np.full(indices.shape, self.ret_rms.mean)\n", - " returns, _ = self.compute_episodic_return(batch, buffer, indices, v_s_=v_s_, gamma=0.99, gae_lambda=1.0)\n", - " batch.returns = returns\n", - " return batch\n", - "```" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "XA8OF4GnWWr5" - }, - "source": [ - "`BasePolicy.compute_episodic_return()` could also be used to compute [GAE](https://arxiv.org/abs/1506.02438). Another similar method is `BasePolicy.compute_nstep_return()`. Check the [source code](https://github.com/thu-ml/tianshou/blob/6fc68578127387522424460790cbcb32a2bd43c4/tianshou/policy/base.py#L304) for more details." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "7UsdzNaOXPpC" - }, - "source": [ - "### Policy.learn()\n", - "Data batch returned by `Policy.process_fn()` will flow into `Policy.learn()`. Finally,\n", - "we can construct our loss function and perform the back-propagation. 
The method \n", - "should look something like this:" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "```python\n", - "def learn(\n", - " self,\n", - " batch: BatchWithReturnsProtocol,\n", - " batch_size: int | None,\n", - " repeat: int,\n", - " *args: Any,\n", - " **kwargs: Any,\n", - ") -> TPGTrainingStats:\n", - " \"\"\"Perform the back-propagation.\"\"\"\n", - " losses = []\n", - " split_batch_size = batch_size or -1\n", - " for _ in range(repeat):\n", - " for minibatch in batch.split(split_batch_size, merge_last=True):\n", - " self.optim.zero_grad()\n", - " result = self(minibatch)\n", - " dist = result.dist\n", - " act = to_torch_as(minibatch.act, result.act)\n", - " ret = to_torch(minibatch.returns, torch.float, result.act.device)\n", - " log_prob = dist.log_prob(act).reshape(len(ret), -1).transpose(0, 1)\n", - " loss = -(log_prob * ret).mean()\n", - " loss.backward()\n", - " self.optim.step()\n", - " losses.append(loss.item())\n", - " loss_summary_stat = SequenceSummaryStats.from_sequence(losses)\n", - "\n", - " return PGTrainingStats(loss=loss_summary_stat)\n", - "```" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "1BtuV2W0YJTi" - }, - "source": [ - "## Implementation\n", - "Now we can assemble the methods and form a PGPolicy. The outputs of\n", - "`learn` will be collected to a dedicated dataclass." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "class PGPolicy(BasePolicy[TPGTrainingStats]):\n", - " \"\"\"Implementation of REINFORCE algorithm.\"\"\"\n", - "\n", - " def __init__(\n", - " self,\n", - " *,\n", - " actor: torch.nn.Module,\n", - " optim: torch.optim.Optimizer,\n", - " dist_fn: TDistFnDiscrOrCont,\n", - " action_space: gym.Space,\n", - " discount_factor: float = 0.99,\n", - " observation_space: gym.Space | None = None,\n", - " ) -> None:\n", - " super().__init__(\n", - " action_space=action_space,\n", - " observation_space=observation_space,\n", - " )\n", - " self.actor = actor\n", - " self.optim = optim\n", - " self.dist_fn = dist_fn\n", - " assert 0.0 <= discount_factor <= 1.0, \"discount factor should be in [0, 1]\"\n", - " self.gamma = discount_factor\n", - " self.ret_rms = RunningMeanStd()\n", - "\n", - " def process_fn(\n", - " self,\n", - " batch: RolloutBatchProtocol,\n", - " buffer: ReplayBuffer,\n", - " indices: np.ndarray,\n", - " ) -> BatchWithReturnsProtocol:\n", - " \"\"\"Compute the discounted returns (Monte Carlo estimates) for each transition.\n", - "\n", - " They are added to the batch under the field `returns`.\n", - " Note: this function will modify the input batch!\n", - " \"\"\"\n", - " v_s_ = np.full(indices.shape, self.ret_rms.mean)\n", - " # use a function inherited from BasePolicy to compute returns\n", - " # gae_lambda = 1.0 means we use Monte Carlo estimate\n", - " batch.returns, _ = self.compute_episodic_return(\n", - " batch,\n", - " buffer,\n", - " indices,\n", - " v_s_=v_s_,\n", - " gamma=self.gamma,\n", - " gae_lambda=1.0,\n", - " )\n", - " batch: BatchWithReturnsProtocol\n", - " return batch\n", - "\n", - " def forward(\n", - " self,\n", - " batch: ObsBatchProtocol,\n", - " state: dict | BatchProtocol | np.ndarray | None = None,\n", - " **kwargs: Any,\n", - " ) -> DistBatchProtocol:\n", - " \"\"\"Compute action over the given batch data by applying 
the actor.\n", - "\n", - " Will sample from the dist_fn, if appropriate.\n", - " Returns a new object representing the processed batch data\n", - " (contrary to other methods that modify the input batch inplace).\n", - " \"\"\"\n", - " logits, hidden = self.actor(batch.obs, state=state)\n", - "\n", - " if isinstance(logits, tuple):\n", - " dist = self.dist_fn(*logits)\n", - " else:\n", - " dist = self.dist_fn(logits)\n", - "\n", - " act = dist.sample()\n", - " return cast(DistBatchProtocol, Batch(logits=logits, act=act, state=hidden, dist=dist))\n", - "\n", - " def learn( # type: ignore\n", - " self,\n", - " batch: BatchWithReturnsProtocol,\n", - " batch_size: int | None,\n", - " repeat: int,\n", - " *args: Any,\n", - " **kwargs: Any,\n", - " ) -> TPGTrainingStats:\n", - " losses = []\n", - " split_batch_size = batch_size or -1\n", - " for _ in range(repeat):\n", - " for minibatch in batch.split(split_batch_size, merge_last=True):\n", - " self.optim.zero_grad()\n", - " result = self(minibatch)\n", - " dist = result.dist\n", - " act = to_torch_as(minibatch.act, result.act)\n", - " ret = to_torch(minibatch.returns, torch.float, result.act.device)\n", - " log_prob = dist.log_prob(act).reshape(len(ret), -1).transpose(0, 1)\n", - " loss = -(log_prob * ret).mean()\n", - " loss.backward()\n", - " self.optim.step()\n", - " losses.append(loss.item())\n", - "\n", - " loss_summary_stat = SequenceSummaryStats.from_sequence(losses)\n", - "\n", - " return PGTrainingStats(loss=loss_summary_stat) # type: ignore[return-value]" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "xlPAbh0lKti8" - }, - "source": [ - "## Use the policy\n", - "Note that `BasePolicy` itself inherits from `torch.nn.Module`. As a result, you can consider all Policy modules as a Torch Module. They share similar APIs.\n", - "\n", - "Firstly we will initialize a new PGPolicy." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "JkLFA9Z1KjuX" - }, - "outputs": [], - "source": [ - "state_shape = 4\n", - "action_shape = 2\n", - "# Usually taken from an env by using env.action_space\n", - "action_space = gym.spaces.Box(low=-1, high=1, shape=(2,))\n", - "net = Net(state_shape, hidden_sizes=[16, 16], device=\"cpu\")\n", - "actor = Actor(net, action_shape, device=\"cpu\").to(\"cpu\")\n", - "optim = torch.optim.Adam(actor.parameters(), lr=0.0003)\n", - "dist_fn = torch.distributions.Categorical\n", - "\n", - "policy: BasePolicy\n", - "policy = PGPolicy(actor=actor, optim=optim, dist_fn=dist_fn, action_space=action_space)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "LAo_0t2fekUD" - }, - "source": [ - "PGPolicy shares same APIs with the Torch Module." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "UiuTc8RhJiEi", - "outputId": "9b5bc54c-6303-45f3-ba81-2216a44931e8" - }, - "outputs": [], - "source": [ - "print(policy)\n", - "print(\"========================================\")\n", - "for param in policy.parameters():\n", - " print(param.shape)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "-RCrsttYgAG-" - }, - "source": [ - "### Making decision\n", - "Given a batch of observations, the policy can return a batch of actions and other data." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "0jkBb6AAgUla", - "outputId": "37948844-cdd8-4567-9481-89453c80a157" - }, - "outputs": [], - "source": [ - "obs_batch = Batch(obs=np.ones(shape=(256, 4)))\n", - "dist_batch = policy(obs_batch) # forward() method is called\n", - "print(\"Next action for each observation: \\n\", dist_batch.act)\n", - "print(\"Dsitribution: \\n\", dist_batch.dist)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "swikhnuDfKep" - }, - "source": [ - "### Save and Load models\n", - "Naturally, Tianshou Policy can be saved and loaded like a normal Torch Network." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "tYOoWM_OJRnA" - }, - "outputs": [], - "source": [ - "torch.save(policy.state_dict(), \"policy.pth\")\n", - "assert policy.load_state_dict(torch.load(\"policy.pth\"))" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "gp8PzOYsg5z-" - }, - "source": [ - "### Algorithm Updating\n", - "We have to collect some data and save them in the ReplayBuffer before updating our agent(policy). Typically we use collector to collect data, but we leave this part till later when we have learned the Collector in Tianshou. For now we generate some **fake** data." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "XrrPxOUAYShR" - }, - "source": [ - "#### Generating fake data\n", - "Firstly, we need to \"pretend\" that we are using the \"Policy\" to collect data. We plan to collect 10 data so that we can update our algorithm." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "a14CmzSfYh5C", - "outputId": "aaf45a1f-5e21-4bc8-cbe3-8ce798258af0" - }, - "outputs": [], - "source": [ - "dummy_buffer = ReplayBuffer(size=10)\n", - "print(dummy_buffer)\n", - "print(f\"maxsize: {dummy_buffer.maxsize}, data length: {len(dummy_buffer)}\")\n", - "env = gym.make(\"CartPole-v1\")" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "8S94cV7yZITR" - }, - "source": [ - "Now we are pretending to collect the first episode. The first episode ends at step 3 (perhaps because we are performing too badly)." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "a_mtvbmBZbfs" - }, - "outputs": [], - "source": [ - "obs, info = env.reset()\n", - "for i in range(3):\n", - " act = policy(Batch(obs=obs[np.newaxis, :])).act.item()\n", - " obs_next, rew, _, truncated, info = env.step(act)\n", - " # pretend ending at step 3\n", - " terminated = i == 2\n", - " info[\"id\"] = i\n", - " dummy_buffer.add(\n", - " Batch(\n", - " obs=obs,\n", - " act=act,\n", - " rew=rew,\n", - " terminated=terminated,\n", - " truncated=truncated,\n", - " obs_next=obs_next,\n", - " info=info,\n", - " ),\n", - " )\n", - " obs = obs_next" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "print(dummy_buffer)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "pkxq4gu9bGkt" - }, - "source": [ - "Now we are pretending to collect the second episode. At step 7 the second episode still doesn't end, but we are unwilling to wait, so we stop collecting to update the algorithm." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "pAoKe02ybG68" - }, - "outputs": [], - "source": [ - "obs, info = env.reset()\n", - "for i in range(3, 10):\n", - " # For retrieving actions to be used for training, we set the policy to training mode,\n", - " # but the wrapped torch module should be in eval mode.\n", - " with policy_within_training_step(policy), torch_train_mode(policy, enabled=False):\n", - " act = policy(Batch(obs=obs[np.newaxis, :])).act.item()\n", - " obs_next, rew, _, truncated, info = env.step(act)\n", - " # pretend this episode never end\n", - " terminated = False\n", - " info[\"id\"] = i\n", - " dummy_buffer.add(\n", - " Batch(\n", - " obs=obs,\n", - " act=act,\n", - " rew=rew,\n", - " terminated=terminated,\n", - " truncated=truncated,\n", - " obs_next=obs_next,\n", - " info=info,\n", - " ),\n", - " )\n", - " obs = obs_next" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "MKM6aWMucv-M" - }, - "source": [ - "Our replay buffer looks like this now." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "CSJEEWOqXdTU", - "outputId": "2b3bb75c-f219-4e82-ca78-0ea6173a91f9" - }, - "outputs": [], - "source": [ - "print(dummy_buffer)\n", - "print(f\"maxsize: {dummy_buffer.maxsize}, data length: {len(dummy_buffer)}\")" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "55VWhWpkdfEb" - }, - "source": [ - "#### Updates\n", - "Now we have got a replay buffer with 10 data steps in it. We can call `Policy.update()` to train.\n", - "\n", - "However, we need to manually set the torch module to training mode prior to that, \n", - "and also declare that we are within a training step. Tianshou Trainers will take care of that automatically,\n", - "but users need to consider it when calling `.update` outside of the trainer." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "i_O1lJDWdeoc", - "outputId": "b154741a-d6dc-46cb-898f-6e84fa14e5a7" - }, - "outputs": [], - "source": [ - "# 0 means sample all data from the buffer\n", - "\n", - "# For updating the policy, the policy should be in training mode\n", - "# and the wrapped torch module should also be in training mode (unlike when collecting data).\n", - "with policy_within_training_step(policy), torch_train_mode(policy):\n", - " policy.update(sample_size=0, buffer=dummy_buffer, batch_size=10, repeat=6).pprint_asdict()" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "QJ5krjrcbuiA" - }, - "source": [ - "## Further Reading\n", - "\n", - "\n" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "pmWi3HuXWcV8" - }, - "source": [ - "### Pre-defined Networks\n", - "Tianshou provides numerous pre-defined networks usually used in DRL so that you don't have to bother yourself. Check this [documentation](https://tianshou.readthedocs.io/en/master/03_api/utils/net/index.html) for details." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "UPVl5LBEWJ0t" - }, - "source": [ - "### How to compute GAE on your own?\n", - "(Note that for this reading you need to understand the calculation of [GAE](https://arxiv.org/abs/1506.02438) advantage first)\n", - "\n", - "In terms of code implementation, perhaps the most difficult and annoying part is computing GAE advantage. Just now, we use the `self.compute_episodic_return()` method inherited from `BasePolicy` to save us from all those troubles. 
However, it is still important that we know the details behind this.\n", - "\n", - "To compute GAE advantage, the usage of `self.compute_episodic_return()` may go like:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "D34GlVvPNz08", - "outputId": "43a4e5df-59b5-4e4a-c61c-e69090810215" - }, - "outputs": [], - "source": [ - "batch, indices = dummy_buffer.sample(0) # 0 means sampling all the data from the buffer\n", - "returns, advantage = BasePolicy.compute_episodic_return(\n", - " batch=batch,\n", - " buffer=dummy_buffer,\n", - " indices=indices,\n", - " v_s_=np.zeros(10),\n", - " v_s=np.zeros(10),\n", - " gamma=1.0,\n", - " gae_lambda=1.0,\n", - ")\n", - "print(f\"{batch.rew=}\")\n", - "print(f\"{batch.done=}\")\n", - "print(f\"{returns=}\")\n", - "print(f\"{advantage=}\")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "In the code above, we sample all the 10 data in the buffer and try to compute the GAE advantage. However, the way the returns are computed here might be a bit misleading. In fact, the last episode is unfinished, but its last step saved in the batch is treated as a terminal state, since it assumes that there are no future rewards. The episode is not terminated yet, it is truncated, so the agent could still get rewards in the future. Terminated and truncated episodes should indeed be treated differently.\n", - "The return of a step is the (discounted) sum of the future rewards from that step until the end of the episode. 
\n", - "\\begin{equation}\n", - "R_{t}=\\sum_{k=t}^{T} \\gamma^{k-t} r_{k}\n", - "\\end{equation}\n", - "Thus, at the last step of a terminated episode, the return is equal to the reward at that step, since there are no future states.\n", - "\\begin{equation}\n", - "R_{T,terminated}=r_{T}\n", - "\\end{equation}\n", - "\n", - "However, if the episode was truncated, the return at the last step is usually better represented by the estimated value of that state, which is the expected return from that state onwards.\n", - "\\begin{align*}\n", - "R_{T,truncated}=V^{\\pi}\\left(s_{T}\\right) \\quad & \\text{or} \\quad R_{T,truncated}=Q^{\\pi}(s_{T},a_{T})\n", - "\\end{align*}\n", - "Moreover, if the next state was also observed (but not its reward), then an even better estimate would be the reward of the last step plus the discounted value of the next state.\n", - "\\begin{align*}\n", - "R_{T,truncated}=r_T+\\gamma V^{\\pi}\\left(s_{T+1}\\right)\n", - "\\end{align*}" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "h_5Dt6XwQLXV" - }, - "source": [ - "\n", - "As we know, we need to estimate the value function of every observation to compute the GAE advantage. So `v_s` holds the value of `batch.obs`, and `v_s_` holds the value of `batch.obs_next`. These are usually computed by:\n", - "\n", - "`v_s = critic(batch.obs)`,\n", - "\n", - "`v_s_ = critic(batch.obs_next)`,\n", - "\n", - "where both `v_s` and `v_s_` are 10-dimensional arrays and `critic` is usually a neural network.\n", - "\n", - "After we've got all those values, GAE can be computed following the equation below."
- ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "ooHNIICGUO19" - }, - "source": [ - "\\begin{aligned}\n", - "\\hat{A}_{t}^{\\mathrm{GAE}(\\gamma, \\lambda)}: =& \\sum_{l=0}^{\\infty}(\\gamma \\lambda)^{l} \\delta_{t+l}^{V}\n", - "\\end{aligned}\n", - "\n", - "where\n", - "\n", - "\\begin{equation}\n", - "\\delta_{t}^{V} \\quad=-V\\left(s_{t}\\right)+r_{t}+\\gamma V\\left(s_{t+1}\\right)\n", - "\\end{equation}\n" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "eV6XZaouU7EV" - }, - "source": [ - "Unfortunately, if you follow this equation, which is taken from the paper, you probably will get a slightly lower performance than you expected. There are at least 3 \"bugs\" in this equation." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "FCxD9gNNVYbd" - }, - "source": [ - "**First** is that Gym always returns you a `obs_next` even if this is already the last step. The value of this timestep is exactly 0 and you should not let the neural network estimate it." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "rNZNUNgQVvRJ", - "outputId": "44354595-c25a-4da8-b4d8-cffa31ac4b7d" - }, - "outputs": [], - "source": [ - "# Assume v_s_ is got by calling critic(batch.obs_next)\n", - "v_s_ = np.ones(10)\n", - "v_s_ *= ~batch.done\n", - "print(f\"{v_s_=}\")" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "2EtMi18QWXTN" - }, - "source": [ - "After the fix above, we will perhaps get a more accurate estimate.\n", - "\n", - "**Secondly**, you must know when to stop bootstrapping. Usually we stop bootstrapping when we meet a `done` flag. However, in the buffer above, the last (10th) step is not marked by done=True, because the collecting has not finished. 
We must identify all those unfinished steps so that we know when to stop bootstrapping.\n", - "\n", - "Luckily, this can be done with the help of the buffer, because buffers in Tianshou not only store data, but also help you manage data trajectories." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "saluvX4JU6bC", - "outputId": "2994d178-2f33-40a0-a6e4-067916b0b5c5" - }, - "outputs": [], - "source": [ - "unfinished_indexes = dummy_buffer.unfinished_index()\n", - "print(\"unfinished_indexes: \", unfinished_indexes)\n", - "done_indexes = np.where(batch.done)[0]\n", - "print(\"done_indexes: \", done_indexes)\n", - "stop_bootstrap_ids = np.concatenate([unfinished_indexes, done_indexes])\n", - "print(\"stop_bootstrap_ids: \", stop_bootstrap_ids)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "qp6vVE4dYWv1" - }, - "source": [ - "**Thirdly**, there are some special indexes which are marked with a done flag, but whose value for obs_next should not be zero. Again, this is because done does not differentiate between terminated and truncated. These are usually the last steps of episodes that stopped not because the agent can no longer get any rewards (value=0), but because the episode was too long and had to be truncated. These kinds of steps are always marked with `info['TimeLimit.truncated']=True` in Gym."
- ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "tWkqXRJfZTvV" - }, - "source": [ - "As a result, we need to rewrite the line above\n", - "\n", - "`v_s_ *= ~batch.done`" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "kms-QtxKZe-M" - }, - "source": [ - "to\n", - "\n", - "```\n", - "mask = batch.info['TimeLimit.truncated'] | (~batch.done)\n", - "v_s_ *= mask\n", - "\n", - "```\n", - "\n", - "\n", - "\n" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "u_aPPoKraBu6" - }, - "source": [ - "### Summary\n", - "If you are already bored by now, simply remember that Tianshou can handle all these little details for you so that you can focus on the algorithm itself. Just call `BasePolicy.compute_episodic_return()`.\n", - "\n", - "If you are still interested, we recommend checking Appendix C in this [paper](https://arxiv.org/abs/2107.14171v2) and the implementation of `BasePolicy.value_mask()` and `BasePolicy.compute_episodic_return()` for details." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "2cPnUXRBWKD9" - }, - "source": [ - "
\n", - "\n", - "
\n", - "
\n", - "\n", - "
" - ] - } - ], - "metadata": { - "colab": { - "provenance": [] - }, - "kernelspec": { - "display_name": "Python 3 (ipykernel)", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.11.7" - } - }, - "nbformat": 4, - "nbformat_minor": 4 -} diff --git a/docs/02_notebooks/L5_Collector.ipynb b/docs/02_notebooks/L5_Collector.ipynb index d7aaa9fb3..36fd4572d 100644 --- a/docs/02_notebooks/L5_Collector.ipynb +++ b/docs/02_notebooks/L5_Collector.ipynb @@ -8,10 +8,6 @@ "source": [ "# Collector\n", "From its literal meaning, we can easily know that the Collector in Tianshou is used to collect training data. More specifically, the Collector controls the interaction between Policy (agent) and the environment. It also helps save the interaction data into the ReplayBuffer and returns episode statistics.\n", - "\n", - "
\n", - "\n", - "
\n", "\n" ] }, @@ -53,16 +49,14 @@ }, "outputs": [], "source": [ - "%%capture\n", - "\n", "import gymnasium as gym\n", "import torch\n", "\n", - "from tianshou.algorithm import PGPolicy\n", + "from tianshou.algorithm.modelfree.reinforce import ActorPolicyProbabilistic\n", "from tianshou.data import Collector, CollectStats, VectorReplayBuffer\n", "from tianshou.env import DummyVectorEnv\n", - "from tianshou.utils.net.common import Net\n", - "from tianshou.utils.net.discrete import Actor" + "from tianshou.utils.net.common import MLPActor\n", + "from tianshou.utils.net.discrete import DiscreteActor" ] }, { @@ -71,25 +65,26 @@ "metadata": {}, "outputs": [], "source": [ + "from tianshou.algorithm.optim import AdamOptimizerFactory\n", + "\n", + "\n", "env = gym.make(\"CartPole-v1\")\n", "test_envs = DummyVectorEnv([lambda: gym.make(\"CartPole-v1\") for _ in range(2)])\n", "\n", "# model\n", "assert env.observation_space.shape is not None # for mypy\n", - "net = Net(\n", - " env.observation_space.shape,\n", + "preprocess_net = MLPActor(\n", + " state_shape=env.observation_space.shape,\n", " hidden_sizes=[\n", " 16,\n", " ],\n", ")\n", "\n", "assert isinstance(env.action_space, gym.spaces.Discrete) # for mypy\n", - "actor = Actor(net, env.action_space.n)\n", - "optim = torch.optim.Adam(actor.parameters(), lr=0.0003)\n", + "actor = DiscreteActor(preprocess_net=preprocess_net, action_shape=env.action_space.n)\n", "\n", - "policy: PGPolicy = PGPolicy(\n", + "policy = ActorPolicyProbabilistic(\n", " actor=actor,\n", - " optim=optim,\n", " dist_fn=torch.distributions.Categorical,\n", " action_space=env.action_space,\n", " action_scaling=False,\n", @@ -270,7 +265,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.11.5" + "version": "3.11.4" } }, "nbformat": 4, diff --git a/docs/02_notebooks/L6_Trainer.ipynb b/docs/02_notebooks/L6_Trainer.ipynb deleted file mode 100644 index ffa18168b..000000000 --- 
a/docs/02_notebooks/L6_Trainer.ipynb +++ /dev/null @@ -1,283 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": { - "id": "S3-tJZy35Ck_" - }, - "source": [ - "# Trainer\n", - "Trainer is the highest-level encapsulation in Tianshou. It controls the training loop and the evaluation method. It also controls the interaction between the Collector and the Policy, with the ReplayBuffer serving as the medium.\n", - "\n", - "
\n", - "\n", - "
\n", - "\n", - "\n" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "ifsEQMzZ6mmz" - }, - "source": [ - "## Usages\n", - "In Tianshou v0.5.1, there are three types of Trainer. They are designed to be used in on-policy training, off-policy training and offline training respectively. We will use on-policy trainer as an example and leave the other two for further reading." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "XfsuU2AAE52C" - }, - "source": [ - "### Pseudocode\n", - "
\n", - "\n", - "
\n", - "\n", - "For the on-policy trainer, the main difference is that we clear the buffer after Line 10." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "Hcp_o0CCFz12" - }, - "source": [ - "### Training without trainer\n", - "As we have learned the usages of the Collector and the Policy, it's possible that we write our own training logic.\n", - "\n", - "First, let us create the instances of Environment, ReplayBuffer, Policy and Collector." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "editable": true, - "id": "do-xZ-8B7nVH", - "slideshow": { - "slide_type": "" - }, - "tags": [ - "hide-cell", - "remove-output" - ] - }, - "outputs": [], - "source": [ - "%%capture\n", - "\n", - "import gymnasium as gym\n", - "import torch\n", - "\n", - "from tianshou.algorithm import PGPolicy\n", - "from tianshou.data import Collector, CollectStats, VectorReplayBuffer\n", - "from tianshou.env import DummyVectorEnv\n", - "from tianshou.trainer import OnpolicyTrainer\n", - "from tianshou.utils.net.common import Net\n", - "from tianshou.utils.net.discrete import Actor\n", - "from tianshou.utils.torch_utils import policy_within_training_step, torch_train_mode" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "train_env_num = 4\n", - "# Since REINFORCE is an on-policy algorithm, we don't need a very large buffer size\n", - "buffer_size = 2000\n", - "\n", - "\n", - "# Create the environments, used for training and evaluation\n", - "env = gym.make(\"CartPole-v1\")\n", - "test_envs = DummyVectorEnv([lambda: gym.make(\"CartPole-v1\") for _ in range(2)])\n", - "train_envs = DummyVectorEnv([lambda: gym.make(\"CartPole-v1\") for _ in range(train_env_num)])\n", - "\n", - "# Create the Policy instance\n", - "assert env.observation_space.shape is not None\n", - "net = Net(\n", - " env.observation_space.shape,\n", - " hidden_sizes=[\n", - " 16,\n", - " ],\n", - ")\n", - "\n", - 
"assert isinstance(env.action_space, gym.spaces.Discrete)\n", - "actor = Actor(net, env.action_space.n)\n", - "optim = torch.optim.Adam(actor.parameters(), lr=0.001)\n", - "\n", - "# We choose to use REINFORCE algorithm, also known as Policy Gradient\n", - "policy: PGPolicy = PGPolicy(\n", - " actor=actor,\n", - " optim=optim,\n", - " dist_fn=torch.distributions.Categorical,\n", - " action_space=env.action_space,\n", - " action_scaling=False,\n", - ")\n", - "\n", - "# Create the replay buffer and the collector\n", - "replayBuffer = VectorReplayBuffer(buffer_size, train_env_num)\n", - "test_collector = Collector[CollectStats](policy, test_envs)\n", - "train_collector = Collector[CollectStats](policy, train_envs, replayBuffer)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "wiEGiBgQIiFM" - }, - "source": [ - "Now, we can try training our policy network. The logic is simple. We collect some data into the buffer and then we use the data to train our policy." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "JMUNPN5SI_kd", - "outputId": "7d68323c-0322-4b82-dafb-7c7f63e7a26d" - }, - "outputs": [], - "source": [ - "train_collector.reset()\n", - "train_envs.reset()\n", - "test_collector.reset()\n", - "test_envs.reset()\n", - "replayBuffer.reset()\n", - "\n", - "n_episode = 10\n", - "for _i in range(n_episode):\n", - " # for test collector, we set the wrapped torch module to evaluation mode\n", - " # by default, the policy object itself is not within the training step\n", - " with torch_train_mode(policy, enabled=False):\n", - " evaluation_result = test_collector.collect(n_episode=n_episode)\n", - " print(f\"Evaluation mean episodic reward is: {evaluation_result.returns.mean()}\")\n", - " # for collecting data for training, the policy object should be within the training step\n", - " # (affecting e.g. 
whether the policy is stochastic or deterministic)\n", - " with policy_within_training_step(policy):\n", - " train_collector.collect(n_step=2000)\n", - " # 0 means taking all data stored in train_collector.buffer\n", - " # for updating the policy, the wrapped torch module should be in training mode\n", - " with torch_train_mode(policy):\n", - " policy.update(sample_size=None, buffer=train_collector.buffer, batch_size=512, repeat=1)\n", - " train_collector.reset_buffer(keep_statistics=True)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "QXBHIBckMs_2" - }, - "source": [ - "The evaluation reward doesn't seem to improve. That is simply because we haven't trained it for enough time. Plus, the network size is too small and REINFORCE algorithm is actually not very stable. Don't worry, we will solve this problem in the end. Still we get some idea on how to start a training loop." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "p-7U_cwgF5Ej" - }, - "source": [ - "### Training with trainer\n", - "The trainer does almost the same thing. The only difference is that it has considered many details and is more modular." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "vcvw9J8RNtFE", - "outputId": "b483fa8b-2a57-4051-a3d0-6d8162d948c5", - "tags": [ - "remove-output" - ] - }, - "outputs": [], - "source": [ - "train_collector.reset()\n", - "train_envs.reset()\n", - "test_collector.reset()\n", - "test_envs.reset()\n", - "replayBuffer.reset()\n", - "\n", - "result = OnpolicyTrainer(\n", - " policy=policy,\n", - " train_collector=train_collector,\n", - " test_collector=test_collector,\n", - " max_epoch=10,\n", - " step_per_epoch=1,\n", - " repeat_per_collect=1,\n", - " episode_per_test=10,\n", - " step_per_collect=2000,\n", - " batch_size=512,\n", - ").run()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "tags": [] - }, - "outputs": [], - "source": [ - "result.pprint_asdict()" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "_j3aUJZQ7nml" - }, - "source": [ - "## Further Reading\n", - "### Logger usages\n", - "Tianshou provides experiment loggers that are both tensorboard- and wandb-compatible. It also has a BaseLogger Class which allows you to self-define your own logger. 
Check the [documentation](https://tianshou.org/en/master/03_api/utils/logger/base.html#tianshou.utils.logger.base.BaseLogger) for details.\n", - "\n", - "### Learn more about the APIs of Trainers\n", - "[documentation](https://tianshou.org/en/master/03_api/trainer/index.html)" - ] - } - ], - "metadata": { - "colab": { - "collapsed_sections": [ - "S3-tJZy35Ck_", - "XfsuU2AAE52C", - "p-7U_cwgF5Ej", - "_j3aUJZQ7nml" - ], - "provenance": [] - }, - "kernelspec": { - "display_name": "Python 3 (ipykernel)", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.11.4" - } - }, - "nbformat": 4, - "nbformat_minor": 4 -} diff --git a/docs/02_notebooks/L7_Experiment.ipynb b/docs/02_notebooks/L7_Experiment.ipynb deleted file mode 100644 index 9da33c98b..000000000 --- a/docs/02_notebooks/L7_Experiment.ipynb +++ /dev/null @@ -1,341 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": { - "id": "_UaXOSRjDUF9" - }, - "source": [ - "# Experiment\n", - "Finally, we can assemble building blocks that we have came across in previous tutorials to conduct our first DRL experiment. In this experiment, we will use [PPO](https://arxiv.org/abs/1707.06347) algorithm to solve the classic CartPole task in Gym." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "2QRbCJvDHNAd" - }, - "source": [ - "## Experiment\n", - "To conduct this experiment, we need the following building blocks.\n", - "\n", - "\n", - "* Two vectorized environments, one for training and one for evaluation\n", - "* A PPO agent\n", - "* A replay buffer to store transition data\n", - "* Two collectors to manage the data collecting process, one for training and one for evaluation\n", - "* A trainer to manage the training loop\n", - "\n", - "
\n", - "\n", - "\n", - "
\n", - "\n", - "Let us do this step by step." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "-Hh4E6i0Hj0I" - }, - "source": [ - "## Preparation\n", - "Firstly, install Tianshou if you haven't installed it before." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "7E4EhiBeHxD5" - }, - "source": [ - "Import libraries we might need later." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "editable": true, - "id": "ao9gWJDiHgG-", - "tags": [ - "hide-cell", - "remove-output" - ] - }, - "outputs": [], - "source": [ - "%%capture\n", - "\n", - "import gymnasium as gym\n", - "import torch\n", - "\n", - "from tianshou.algorithm import PPOPolicy\n", - "from tianshou.data import Collector, CollectStats, VectorReplayBuffer\n", - "from tianshou.env import DummyVectorEnv\n", - "from tianshou.trainer import OnpolicyTrainer\n", - "from tianshou.utils.net.common import ActorCritic, MLPActor\n", - "from tianshou.utils.net.discrete import Actor, Critic\n", - "\n", - "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "QnRg5y7THRYw" - }, - "source": [ - "## Environment" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "YZERKCGtH8W1" - }, - "source": [ - "We create two vectorized environments both for training and testing. Since the execution time of CartPole is extremely short, there is no need to use multi-process wrappers and we simply use DummyVectorEnv." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "Mpuj5PFnDKVS" - }, - "outputs": [], - "source": [ - "env = gym.make(\"CartPole-v1\")\n", - "train_envs = DummyVectorEnv([lambda: gym.make(\"CartPole-v1\") for _ in range(20)])\n", - "test_envs = DummyVectorEnv([lambda: gym.make(\"CartPole-v1\") for _ in range(10)])" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "BJtt_Ya8DTAh" - }, - "source": [ - "## Policy\n", - "Next we need to initialize our PPO policy. PPO is an actor-critic-style on-policy algorithm, so we have to define the actor and the critic in PPO first.\n", - "\n", - "The actor is a neural network that shares the same network head with the critic. Both networks' input is the environment observation. The output of the actor is the action and the output of the critic is a single value, representing the value of the current policy.\n", - "\n", - "Luckily, Tianshou already provides basic network modules that we can use in this experiment." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "_Vy8uPWXP4m_" - }, - "outputs": [], - "source": [ - "# net is the shared head of the actor and the critic\n", - "assert env.observation_space.shape is not None # for mypy\n", - "assert isinstance(env.action_space, gym.spaces.Discrete) # for mypy\n", - "net = MLPActor(state_shape=env.observation_space.shape, hidden_sizes=[64, 64], device=device)\n", - "actor = Actor(preprocess_net=net, action_shape=env.action_space.n, device=device).to(device)\n", - "critic = Critic(preprocess_net=net, device=device).to(device)\n", - "actor_critic = ActorCritic(actor=actor, critic=critic)\n", - "\n", - "# optimizer of the actor and the critic\n", - "optim = torch.optim.Adam(actor_critic.parameters(), lr=0.0003)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "Lh2-hwE5Dn9I" - }, - "source": [ - "Once we have defined the actor, the critic and the optimizer, we can use them to construct our PPO agent. CartPole is a discrete action space problem, so the distribution of our action space can be a categorical distribution." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "OiJ2GkT0Qnbr" - }, - "outputs": [], - "source": [ - "dist = torch.distributions.Categorical\n", - "policy: PPOPolicy = PPOPolicy(\n", - " actor=actor,\n", - " critic=critic,\n", - " optim=optim,\n", - " dist_fn=dist,\n", - " action_space=env.action_space,\n", - " deterministic_eval=True,\n", - " action_scaling=False,\n", - ")" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "okxfj6IEQ-r8" - }, - "source": [ - "`deterministic_eval=True` means that we want to sample actions during training but we would like to always use the best action in evaluation. No randomness included." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "n5XAAbuBZarO" - }, - "source": [ - "## Collector\n", - "We can set up the collectors now. 
Train collector is used to collect and store training data, so an additional replay buffer has to be passed in." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "ezwz0qerZhQM" - }, - "outputs": [], - "source": [ - "train_collector = Collector[CollectStats](\n", - " policy=policy,\n", - " env=train_envs,\n", - " buffer=VectorReplayBuffer(20000, len(train_envs)),\n", - ")\n", - "test_collector = Collector[CollectStats](policy=policy, env=test_envs)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "ZaoPxOd2hm0b" - }, - "source": [ - "We use `VectorReplayBuffer` here because it's more efficient to collaborate with vectorized environments, you can simply consider `VectorReplayBuffer` as a a list of ordinary replay buffers." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "qBoE9pLUiC-8" - }, - "source": [ - "## Trainer\n", - "Finally, we can use the trainer to help us set up the training loop." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "editable": true, - "id": "i45EDnpxQ8gu", - "outputId": "b1666b88-0bfa-4340-868e-58611872d988", - "tags": [ - "remove-output" - ] - }, - "outputs": [], - "source": [ - "result = OnpolicyTrainer(\n", - " policy=policy,\n", - " train_collector=train_collector,\n", - " test_collector=test_collector,\n", - " max_epoch=10,\n", - " step_per_epoch=50000,\n", - " repeat_per_collect=10,\n", - " episode_per_test=10,\n", - " batch_size=256,\n", - " step_per_collect=2000,\n", - " stop_fn=lambda mean_reward: mean_reward >= 195,\n", - ").run()" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "ckgINHE2iTFR" - }, - "source": [ - "## Results\n", - "Print the training result." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "tJCPgmiyiaaX", - "outputId": "40123ae3-3365-4782-9563-46c43812f10f", - "tags": [] - }, - "outputs": [], - "source": [ - "result.pprint_asdict()" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "A-MJ9avMibxN" - }, - "source": [ - "We can also test our trained agent." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "mnMANFcciiAQ", - "outputId": "6febcc1e-7265-4a75-c9dd-34e29a3e5d21" - }, - "outputs": [], - "source": [ - "# Let's watch its performance!\n", - "policy.eval()\n", - "result = test_collector.collect(n_episode=1, render=False)\n", - "print(f\"Final episode reward: {result.returns.mean()}, length: {result.lens.mean()}\")" - ] - } - ], - "metadata": { - "colab": { - "provenance": [], - "toc_visible": true - }, - "kernelspec": { - "display_name": "Python 3 (ipykernel)", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.11.5" - } - }, - "nbformat": 4, - "nbformat_minor": 4 -} From 0d622f3717f08c56bb7028eba79b36ef584c72e1 Mon Sep 17 00:00:00 2001 From: Michael Panchenko Date: Sat, 17 May 2025 14:57:27 +0200 Subject: [PATCH 208/230] v1: minor improvement in doc-build command --- pyproject.toml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index fd1004878..9abb7fac0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -231,9 +231,9 @@ _autogen_rst = "python docs/autogen_rst.py" _sphinx_build = "sphinx-build -b html docs docs/_build -W --keep-going" _jb_generate_toc = "python docs/create_toc.py" _jb_generate_config = 
"jupyter-book config sphinx docs/" -doc-clean = "rm -rf docs/_build" +doc-clean = "rm -rf docs/_build docs/03_api" doc-generate-files = ["_autogen_rst", "_jb_generate_toc", "_jb_generate_config"] -doc-build = ["doc-generate-files", "_sphinx_build"] +doc-build = ["doc-clean", "doc-generate-files", "_sphinx_build"] _mypy = "mypy tianshou test examples" _mypy_nb = "nbqa mypy docs" type-check = ["_mypy", "_mypy_nb"] From e9fe650f16c423b5cb8be3f42f9fd05eb28ad76b Mon Sep 17 00:00:00 2001 From: Michael Panchenko Date: Sat, 17 May 2025 14:59:27 +0200 Subject: [PATCH 209/230] v2: minor fixes in docstrings, doc build runs through --- tianshou/algorithm/modelfree/ddpg.py | 4 ++-- tianshou/algorithm/modelfree/sac.py | 8 ++++---- tianshou/algorithm/modelfree/td3.py | 4 ++-- tianshou/env/atari/atari_network.py | 2 +- tianshou/highlevel/params/alpha.py | 4 ++-- tianshou/policy/__init__.py | 0 tianshou/trainer/trainer.py | 23 ++++++++++++++++------- tianshou/utils/lagged_network.py | 8 ++++---- tianshou/utils/torch_utils.py | 4 ++-- 9 files changed, 33 insertions(+), 24 deletions(-) create mode 100644 tianshou/policy/__init__.py diff --git a/tianshou/algorithm/modelfree/ddpg.py b/tianshou/algorithm/modelfree/ddpg.py index 038d647b5..0e44bc4ad 100644 --- a/tianshou/algorithm/modelfree/ddpg.py +++ b/tianshou/algorithm/modelfree/ddpg.py @@ -231,8 +231,8 @@ def __init__( :param critic: the critic network. For continuous action spaces: (s, a -> Q(s, a)). For discrete action spaces: (s -> ). - NOTE: The default implementation of `_target_q_compute_value` assumes - a continuous action space; override this method if using discrete actions. + **NOTE**: The default implementation of `_target_q_compute_value` assumes + a continuous action space; override this method if using discrete actions. :param critic_optim: the optimizer factory for the critic network. :param tau: the soft update coefficient for target networks, controlling the rate at which target networks track the learned networks. 
diff --git a/tianshou/algorithm/modelfree/sac.py b/tianshou/algorithm/modelfree/sac.py index 4ad7fd195..a63583445 100644 --- a/tianshou/algorithm/modelfree/sac.py +++ b/tianshou/algorithm/modelfree/sac.py @@ -179,10 +179,10 @@ class AutoAlpha(torch.nn.Module, Alpha): def __init__(self, target_entropy: float, log_alpha: float, optim: OptimizerFactory): """ :param target_entropy: the target entropy value. - For discrete action spaces, it is usually -log(|A|) for a balance between stochasticity - and determinism or -log(1/|A|)=log(|A|) for maximum stochasticity or, more generally, - lambda*log(|A|), e.g. with lambda close to 1 (e.g. 0.98) for pronounced stochasticity. - For continuous action spaces, it is usually -dim(A) for a balance between stochasticity + For discrete action spaces, it is usually `-log(|A|)` for a balance between stochasticity + and determinism or `-log(1/|A|)=log(|A|)` for maximum stochasticity or, more generally, + `lambda*log(|A|)`, e.g. with `lambda` close to 1 (e.g. 0.98) for pronounced stochasticity. + For continuous action spaces, it is usually `-dim(A)` for a balance between stochasticity and determinism, with similar generalizations as for discrete action spaces. :param log_alpha: the (initial) value of the log of the entropy regularization coefficient alpha. :param optim: the factory with which to create the optimizer for `log_alpha`. diff --git a/tianshou/algorithm/modelfree/td3.py b/tianshou/algorithm/modelfree/td3.py index a6e857c0b..4c273cc63 100644 --- a/tianshou/algorithm/modelfree/td3.py +++ b/tianshou/algorithm/modelfree/td3.py @@ -55,8 +55,8 @@ def __init__( :param policy_optim: the optimizer factory for the policy's model. :param critic: the first critic network. For continuous action spaces: (s, a -> Q(s, a)). - NOTE: The default implementation of `_target_q_compute_value` assumes - a continuous action space; override this method if using discrete actions. 
+ **NOTE**: The default implementation of `_target_q_compute_value` assumes + a continuous action space; override this method if using discrete actions. :param critic_optim: the optimizer factory for the first critic network. :param critic2: the second critic network (analogous functionality to the first). If None, copy the first critic (via deepcopy). diff --git a/tianshou/env/atari/atari_network.py b/tianshou/env/atari/atari_network.py index 32ec838e1..ab4693a74 100644 --- a/tianshou/env/atari/atari_network.py +++ b/tianshou/env/atari/atari_network.py @@ -5,6 +5,7 @@ import torch from torch import nn +from tianshou.algorithm.modelfree.reinforce import TDistFnDiscrOrCont from tianshou.data import Batch from tianshou.data.types import TObs from tianshou.highlevel.env import Environments @@ -17,7 +18,6 @@ IntermediateModuleFactory, ) from tianshou.highlevel.params.dist_fn import DistributionFunctionFactoryCategorical -from tianshou.policy.modelfree.reinforce import TDistFnDiscrOrCont from tianshou.utils.net.common import Actor, ModuleWithVectorOutput from tianshou.utils.net.discrete import DiscreteActor, NoisyLinear from tianshou.utils.torch_utils import torch_device diff --git a/tianshou/highlevel/params/alpha.py b/tianshou/highlevel/params/alpha.py index 28f13511e..406fda2d8 100644 --- a/tianshou/highlevel/params/alpha.py +++ b/tianshou/highlevel/params/alpha.py @@ -30,8 +30,8 @@ def __init__( """ :param lr: the learning rate for the optimizer of the alpha parameter :param target_entropy_coefficient: the coefficient with which to multiply the target entropy; - The base value being scaled is dim(A) for continuous action spaces and log(|A|) for discrete action spaces, - i.e. with the default coefficient -1, we obtain -dim(A) and -log(dim(A)) for continuous and discrete action + The base value being scaled is `dim(A)` for continuous action spaces and `log(|A|)` for discrete action spaces, + i.e. 
with the default coefficient `-1`, we obtain `-dim(A)` and `-log(|A|)` for continuous and discrete action spaces respectively, which gives a reasonable trade-off between exploration and exploitation. For decidedly stochastic exploration, you can use a positive value closer to 1 (e.g. 0.98); 1.0 would give full entropy exploration. diff --git a/tianshou/policy/__init__.py b/tianshou/policy/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tianshou/trainer/trainer.py b/tianshou/trainer/trainer.py index 0eabc8543..f265a9af6 100644 --- a/tianshou/trainer/trainer.py +++ b/tianshou/trainer/trainer.py @@ -3,21 +3,30 @@ specific network updating logic to perform the actual gradient updates. Training is structured as follows (hierarchical glossary): -- **epoch**: The outermost iteration level of the training loop. Each epoch consists of a number of training steps - and one test step (see :attr:`TrainerParams.max_epoch` for a detailed explanation): - - **training step**: A training step performs the steps necessary in order to apply a single update of the neural + +- **epoch**: the outermost iteration level of the training loop. Each epoch consists of a number of training steps + and one test step (see :attr:`TrainerParams.max_epoch` for a detailed explanation). + + - **training step**: a training step performs the steps necessary in order to apply a single update of the neural network components as defined by the underlying RL algorithm (:class:`Algorithm`). This involves the following sub-steps: + - for online learning algorithms: + - **collection step**: collecting environment steps/transitions to be used for training. - - (potentially) a test step (see below) if the early stopping criterion is satisfied based on + + - (Potentially) a test step (see below) if the early stopping criterion is satisfied based on the data collected (see :attr:`OnlineTrainerParams.test_in_train`).
+ - **update step**: applying the actual gradient updates using the RL algorithm. - The update is based on either ... + The update is based on either: + - data from only the preceding collection step (on-policy learning), - data from the collection step and previously collected data (off-policy learning), or - data from the user-provided replay buffer (offline learning). - For offline learning algorithms, a training step is thus equivalent to an update step. - - **test step**: Collects test episodes from dedicated test environments which are used to evaluate the performance + + For offline learning algorithms, a training step is thus equivalent to an update step. + + - **test step**: collects test episodes from dedicated test environments which are used to evaluate the performance of the policy. Optionally, the performance result can be used to determine whether training shall stop early (see :attr:`TrainerParams.stop_fn`). """ diff --git a/tianshou/utils/lagged_network.py b/tianshou/utils/lagged_network.py index 3f5146580..37a2aa71b 100644 --- a/tianshou/utils/lagged_network.py +++ b/tianshou/utils/lagged_network.py @@ -23,10 +23,10 @@ class EvalModeModuleWrapper(torch.nn.Module): A wrapper around a torch.nn.Module that forces the module to eval mode. The wrapped module supports only the forward method, attribute access is not supported. - NOTE: It is *not* recommended to support attribute/method access beyond this via `__getattr__`, - because torch.nn.Module already heavily relies on `__getattr__` to provides its own attribute access. - Overriding it naively will cause problems! - But it's also not necessary for our use cases; forward is enough. + **NOTE**: It is *not* recommended to support attribute/method access beyond this via `__getattr__`, + because torch.nn.Module already heavily relies on `__getattr__` to provide its own attribute access. + Overriding it naively will cause problems! + But it's also not necessary for our use cases; forward is enough.
""" def __init__(self, m: torch.nn.Module): diff --git a/tianshou/utils/torch_utils.py b/tianshou/utils/torch_utils.py index f675502d3..a4e15ba6c 100644 --- a/tianshou/utils/torch_utils.py +++ b/tianshou/utils/torch_utils.py @@ -8,7 +8,7 @@ from torch import nn if TYPE_CHECKING: - from tianshou.algorithm.algorithm_base import Policy + from tianshou.algorithm import algorithm_base @contextmanager @@ -23,7 +23,7 @@ def torch_train_mode(module: nn.Module, enabled: bool = True) -> Iterator[None]: @contextmanager -def policy_within_training_step(policy: "Policy", enabled: bool = True) -> Iterator[None]: +def policy_within_training_step(policy: "algorithm_base.Policy", enabled: bool = True) -> Iterator[None]: """Temporarily switch to `policy.is_within_training_step=enabled`. Enabling this ensures that the policy is able to adapt its behavior, From 62bd07fa41297bec8c56818cdc05ec81028b175f Mon Sep 17 00:00:00 2001 From: Dominik Jain Date: Sat, 17 May 2025 15:01:40 +0200 Subject: [PATCH 210/230] v2: Fix import --- tianshou/env/atari/atari_network.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tianshou/env/atari/atari_network.py b/tianshou/env/atari/atari_network.py index 32ec838e1..ab4693a74 100644 --- a/tianshou/env/atari/atari_network.py +++ b/tianshou/env/atari/atari_network.py @@ -5,6 +5,7 @@ import torch from torch import nn +from tianshou.algorithm.modelfree.reinforce import TDistFnDiscrOrCont from tianshou.data import Batch from tianshou.data.types import TObs from tianshou.highlevel.env import Environments @@ -17,7 +18,6 @@ IntermediateModuleFactory, ) from tianshou.highlevel.params.dist_fn import DistributionFunctionFactoryCategorical -from tianshou.policy.modelfree.reinforce import TDistFnDiscrOrCont from tianshou.utils.net.common import Actor, ModuleWithVectorOutput from tianshou.utils.net.discrete import DiscreteActor, NoisyLinear from tianshou.utils.torch_utils import torch_device From 3ff942336a7f7328a8970df93473732f1d983b11 Mon Sep 17 
00:00:00 2001 From: Dominik Jain Date: Sat, 17 May 2025 15:02:27 +0200 Subject: [PATCH 211/230] v2: Rename BranchingActor back to BranchingNet --- examples/box2d/bipedal_bdq.py | 4 ++-- test/discrete/test_bdqn.py | 4 ++-- tianshou/algorithm/modelfree/bdqn.py | 6 +++--- tianshou/utils/net/common.py | 2 +- 4 files changed, 8 insertions(+), 8 deletions(-) diff --git a/examples/box2d/bipedal_bdq.py b/examples/box2d/bipedal_bdq.py index 522bac6c7..d1afcd7f5 100644 --- a/examples/box2d/bipedal_bdq.py +++ b/examples/box2d/bipedal_bdq.py @@ -16,7 +16,7 @@ from tianshou.env import ContinuousToDiscrete, SubprocVectorEnv from tianshou.trainer import OffPolicyTrainerParams from tianshou.utils import TensorboardLogger -from tianshou.utils.net.common import BranchingActor +from tianshou.utils.net.common import BranchingNet def get_args() -> argparse.Namespace: @@ -93,7 +93,7 @@ def run_bdq(args: argparse.Namespace = get_args()) -> None: train_envs.seed(args.seed) test_envs.seed(args.seed) # model - net = BranchingActor( + net = BranchingNet( state_shape=args.state_shape, num_branches=args.num_branches, action_per_branch=args.action_per_branch, diff --git a/test/discrete/test_bdqn.py b/test/discrete/test_bdqn.py index 224d9a4e3..1bf93cd87 100644 --- a/test/discrete/test_bdqn.py +++ b/test/discrete/test_bdqn.py @@ -11,7 +11,7 @@ from tianshou.data import Collector, CollectStats, VectorReplayBuffer from tianshou.env import ContinuousToDiscrete, DummyVectorEnv from tianshou.trainer import OffPolicyTrainerParams -from tianshou.utils.net.common import BranchingActor +from tianshou.utils.net.common import BranchingNet from tianshou.utils.torch_utils import policy_within_training_step @@ -92,7 +92,7 @@ def test_bdq(args: argparse.Namespace = get_args(), enable_assertions: bool = Tr train_envs.seed(args.seed) test_envs.seed(args.seed) # model - net = BranchingActor( + net = BranchingNet( state_shape=args.state_shape, num_branches=args.num_branches, 
action_per_branch=args.action_per_branch, diff --git a/tianshou/algorithm/modelfree/bdqn.py b/tianshou/algorithm/modelfree/bdqn.py index 0869169c8..4b09698d0 100644 --- a/tianshou/algorithm/modelfree/bdqn.py +++ b/tianshou/algorithm/modelfree/bdqn.py @@ -21,16 +21,16 @@ ObsBatchProtocol, RolloutBatchProtocol, ) -from tianshou.utils.net.common import BranchingActor +from tianshou.utils.net.common import BranchingNet mark_used(ActBatchProtocol) -class BDQNPolicy(DiscreteQLearningPolicy[BranchingActor]): +class BDQNPolicy(DiscreteQLearningPolicy[BranchingNet]): def __init__( self, *, - model: BranchingActor, + model: BranchingNet, action_space: gym.spaces.Discrete, observation_space: gym.Space | None = None, eps_training: float = 0.0, diff --git a/tianshou/utils/net/common.py b/tianshou/utils/net/common.py index 44ca504a3..49d093a12 100644 --- a/tianshou/utils/net/common.py +++ b/tianshou/utils/net/common.py @@ -538,7 +538,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return x -class BranchingActor(ActorForwardInterface): +class BranchingNet(ActorForwardInterface): """Branching dual Q network. Network for the BranchingDQNPolicy, it uses a common network module, a value module From 105043e0a7842bf58c8a37746b48a557b6559283 Mon Sep 17 00:00:00 2001 From: Dominik Jain Date: Sat, 17 May 2025 15:24:23 +0200 Subject: [PATCH 212/230] v2: Fix renamed class reference --- CHANGELOG.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 9ca3729e9..96c81ceac 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -207,7 +207,7 @@ Developers: * Moved Atari helper modules `atari_network` and `atari_wrapper` to the library under `tianshou.env.atari`. 
* Fix issues pertaining to the torch device assignment of network components (#810): * Remove 'device' member (and the corresponding constructor argument) from the following classes: - `BranchingActor`, `C51Net`, `ContinuousActorDeterministic`, `ContinuousActorProbabilistic`, `ContinuousCritic`, + `BranchingNet`, `C51Net`, `ContinuousActorDeterministic`, `ContinuousActorProbabilistic`, `ContinuousCritic`, `DiscreteActor`, `DiscreteCritic`, `DQNet`, `FullQuantileFunction`, `ImplicitQuantileNetwork`, `IntrinsicCuriosityModule`, `MLPActor`, `MLP`, `Perturbation`, `QRDQNet`, `Rainbow`, `Recurrent`, `RecurrentActorProb`, `RecurrentCritic`, `VAE` From 03132cc04cfe5c755baf441617bc7f52e2ae061c Mon Sep 17 00:00:00 2001 From: Dominik Jain Date: Sat, 17 May 2025 16:49:59 +0200 Subject: [PATCH 213/230] v2: Improve neural network class hierarchy --- CHANGELOG.md | 8 ++- README.md | 4 +- docs/02_notebooks/L5_Collector.ipynb | 16 +++-- docs/autogen_rst.py | 2 +- examples/atari/atari_ppo.py | 6 +- examples/atari/atari_rainbow.py | 4 +- examples/box2d/acrobot_dualdqn.py | 4 +- examples/box2d/bipedal_hardcore_sac.py | 8 +-- examples/box2d/lunarlander_dqn.py | 4 +- examples/box2d/mcc_sac.py | 8 +-- examples/discrete/discrete_dqn.py | 4 +- examples/inverse/irl_gail.py | 12 ++-- examples/mujoco/fetch_her_ddpg.py | 6 +- examples/mujoco/mujoco_a2c.py | 10 +-- examples/mujoco/mujoco_ddpg.py | 6 +- examples/mujoco/mujoco_npg.py | 10 +-- examples/mujoco/mujoco_ppo.py | 10 +-- examples/mujoco/mujoco_redq.py | 6 +- examples/mujoco/mujoco_reinforce.py | 8 +-- examples/mujoco/mujoco_sac.py | 8 +-- examples/mujoco/mujoco_td3.py | 8 +-- examples/mujoco/mujoco_trpo.py | 10 +-- examples/offline/d4rl_bcq.py | 6 +- examples/offline/d4rl_cql.py | 8 +-- examples/offline/d4rl_il.py | 4 +- examples/offline/d4rl_td3_bc.py | 8 +-- examples/vizdoom/vizdoom_ppo.py | 4 +- test/base/test_policy.py | 12 ++-- test/base/test_utils.py | 12 ++-- test/continuous/test_ddpg.py | 6 +- test/continuous/test_npg.py | 10 
+-- test/continuous/test_ppo.py | 10 +-- test/continuous/test_redq.py | 6 +- test/continuous/test_sac_with_il.py | 10 +-- test/continuous/test_td3.py | 8 +-- test/continuous/test_trpo.py | 10 +-- test/discrete/test_a2c_with_il.py | 10 +-- test/discrete/test_c51.py | 4 +- test/discrete/test_discrete_sac.py | 8 +-- test/discrete/test_dqn.py | 4 +- test/discrete/test_fqf.py | 4 +- test/discrete/test_iqn.py | 4 +- test/discrete/test_ppo_discrete.py | 12 ++-- test/discrete/test_qrdqn.py | 4 +- test/discrete/test_rainbow.py | 4 +- test/discrete/test_reinforce.py | 8 +-- test/modelbased/test_dqn_icm.py | 4 +- test/modelbased/test_ppo_icm.py | 8 +-- test/offline/gather_cartpole_data.py | 4 +- test/offline/gather_pendulum_data.py | 6 +- test/offline/test_bcq.py | 4 +- test/offline/test_cql.py | 6 +- test/offline/test_discrete_bcq.py | 4 +- test/offline/test_discrete_cql.py | 4 +- test/offline/test_discrete_crr.py | 4 +- test/offline/test_gail.py | 12 ++-- test/offline/test_td3_bc.py | 8 +-- test/pettingzoo/pistonball.py | 4 +- test/pettingzoo/pistonball_continuous.py | 4 +- test/pettingzoo/tic_tac_toe.py | 4 +- tianshou/algorithm/imitation/gail.py | 4 +- tianshou/algorithm/modelfree/a2c.py | 8 +-- tianshou/algorithm/modelfree/c51.py | 4 +- tianshou/algorithm/modelfree/ddpg.py | 4 +- tianshou/algorithm/modelfree/dqn.py | 4 +- tianshou/algorithm/modelfree/npg.py | 4 +- tianshou/algorithm/modelfree/ppo.py | 4 +- tianshou/algorithm/modelfree/reinforce.py | 22 +++---- tianshou/algorithm/modelfree/trpo.py | 4 +- tianshou/env/atari/atari_network.py | 22 +++---- tianshou/highlevel/algorithm.py | 6 +- tianshou/highlevel/module/actor.py | 8 +-- tianshou/highlevel/module/critic.py | 8 +-- tianshou/utils/net/common.py | 74 +++++++++++++---------- tianshou/utils/net/continuous.py | 8 +-- tianshou/utils/net/discrete.py | 4 +- tianshou/utils/torch_utils.py | 4 +- 77 files changed, 308 insertions(+), 298 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 96c81ceac..9347e06cf 
100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -183,7 +183,7 @@ Developers: * Detailed optimizer configuration (analogous to the procedural API) is now possible: * All optimizers can be configured in the respective algorithm-specific `Params` object by using - `OptimizerFactoryFactory` instances as parameter values (e.g. for `optim`, `actor_optim`, `critic_optim`, etc.). + `OptimizerFactoryFactory` instances as parameter values (e.g. `optim`, `actor_optim`, `critic_optim`, etc.). * Learning rate schedulers remain separate parameters and now use `LRSchedulerFactoryFactory` instances. The respective parameter names now use the suffix `lr_scheduler` instead of `lr_scheduler_factory` (as the precise nature need not be reflected in the name; brevity is preferable). @@ -218,6 +218,12 @@ Developers: dimension as an argument were changed to use `ModuleWithVectorOutput`. * The high-level API class `IntermediateModule` can now provide a `ModuleWithVectorOutput` instance (via adaptation if necessary). +* The class hierarchy of supporting `nn.Module` implementations was cleaned up: + * With the fundamental base classes `ActionReprNet` and `ActionReprNetWithVectorOutput`, we established a + well-defined interface for the most commonly used `forward` signature in Tianshou's algorithms & policies. + * Some network classes were renamed: + * `ScaledObsInputModule` -> `ScaledObsInputActionReprNet` + * `Rainbow` -> `RainbowNet` * All modules containing base classes were renamed from `base` to a more descriptive name, rendering file names unique. diff --git a/README.md b/README.md index cc14117a6..573f1caae 100644 --- a/README.md +++ b/README.md @@ -362,14 +362,14 @@ test_envs = ts.env.DummyVectorEnv([lambda: gym.make(task) for _ in range(test_nu Create the network as well as its optimizer: ```python -from tianshou.utils.net.common import MLPActor +from tianshou.utils.net.common import Net # Note: You can easily define other networks.
# See https://tianshou.readthedocs.io/en/master/01_tutorials/00_dqn.html#build-the-network env = gym.make(task, render_mode="human") state_shape = env.observation_space.shape or env.observation_space.n action_shape = env.action_space.shape or env.action_space.n -net = MLPActor( +net = Net( state_shape=state_shape, action_shape=action_shape, hidden_sizes=[128, 128, 128] ) diff --git a/docs/02_notebooks/L5_Collector.ipynb b/docs/02_notebooks/L5_Collector.ipynb index 36fd4572d..a52dd25eb 100644 --- a/docs/02_notebooks/L5_Collector.ipynb +++ b/docs/02_notebooks/L5_Collector.ipynb @@ -52,11 +52,7 @@ "import gymnasium as gym\n", "import torch\n", "\n", - "from tianshou.algorithm.modelfree.reinforce import ActorPolicyProbabilistic\n", - "from tianshou.data import Collector, CollectStats, VectorReplayBuffer\n", - "from tianshou.env import DummyVectorEnv\n", - "from tianshou.utils.net.common import MLPActor\n", - "from tianshou.utils.net.discrete import DiscreteActor" + "from tianshou.algorithm.modelfree.reinforce import ProbabilisticActorPolicy" ] }, { @@ -65,15 +61,17 @@ "metadata": {}, "outputs": [], "source": [ - "from tianshou.algorithm.optim import AdamOptimizerFactory\n", - "\n", + "from tianshou.data import Collector, CollectStats, VectorReplayBuffer\n", + "from tianshou.env import DummyVectorEnv\n", + "from tianshou.utils.net.common import Net\n", + "from tianshou.utils.net.discrete import DiscreteActor\n", "\n", "env = gym.make(\"CartPole-v1\")\n", "test_envs = DummyVectorEnv([lambda: gym.make(\"CartPole-v1\") for _ in range(2)])\n", "\n", "# model\n", "assert env.observation_space.shape is not None # for mypy\n", - "preprocess_net = MLPActor(\n", + "preprocess_net = Net(\n", " state_shape=env.observation_space.shape,\n", " hidden_sizes=[\n", " 16,\n", @@ -83,7 +81,7 @@ "assert isinstance(env.action_space, gym.spaces.Discrete) # for mypy\n", "actor = DiscreteActor(preprocess_net=preprocess_net, action_shape=env.action_space.n)\n", "\n", - "policy = 
ActorPolicyProbabilistic(\n", + "policy = ProbabilisticActorPolicy(\n", " actor=actor,\n", " dist_fn=torch.distributions.Categorical,\n", " action_space=env.action_space,\n", diff --git a/docs/autogen_rst.py b/docs/autogen_rst.py index 1477b2117..93e2a0954 100644 --- a/docs/autogen_rst.py +++ b/docs/autogen_rst.py @@ -114,7 +114,7 @@ def make_rst(src_root, rst_root, clean=False, overwrite=False, package_prefix="" for f in files_in_dir if os.path.isdir(os.path.join(root, dirname, f)) and not f.startswith("_") ] - if not module_names and not "__init__.py" in files_in_dir: + if not module_names and "__init__.py" not in files_in_dir: log.debug(f"Skipping {dirname} as it does not contain any modules or __init__.py") continue package_qualname = f"{base_package_qualname}.{dirname}" diff --git a/examples/atari/atari_ppo.py b/examples/atari/atari_ppo.py index d49e85258..0d1aa385a 100644 --- a/examples/atari/atari_ppo.py +++ b/examples/atari/atari_ppo.py @@ -17,7 +17,7 @@ from tianshou.data import Collector, CollectStats, VectorReplayBuffer from tianshou.env.atari.atari_network import ( DQNet, - ScaledObsInputModule, + ScaledObsInputActionReprNet, layer_init, ) from tianshou.env.atari.atari_wrapper import make_atari_env @@ -121,7 +121,7 @@ def main(args: argparse.Namespace = get_args()) -> None: torch.manual_seed(args.seed) # define model c, h, w = args.state_shape - net: ScaledObsInputModule | DQNet + net: ScaledObsInputActionReprNet | DQNet net = DQNet( c=c, h=h, @@ -132,7 +132,7 @@ def main(args: argparse.Namespace = get_args()) -> None: layer_init=layer_init, ) if args.scale_obs: - net = ScaledObsInputModule(net) + net = ScaledObsInputActionReprNet(net) actor = DiscreteActor(preprocess_net=net, action_shape=args.action_shape, softmax_output=False) critic = DiscreteCritic(preprocess_net=net) optim = AdamOptimizerFactory(lr=args.lr, eps=1e-5) diff --git a/examples/atari/atari_rainbow.py b/examples/atari/atari_rainbow.py index 2e6622dce..005b14fb2 100644 --- 
a/examples/atari/atari_rainbow.py +++ b/examples/atari/atari_rainbow.py @@ -17,7 +17,7 @@ PrioritizedVectorReplayBuffer, VectorReplayBuffer, ) -from tianshou.env.atari.atari_network import Rainbow +from tianshou.env.atari.atari_network import RainbowNet from tianshou.env.atari.atari_wrapper import make_atari_env from tianshou.highlevel.logger import LoggerFactoryDefault from tianshou.trainer import OffPolicyTrainerParams @@ -103,7 +103,7 @@ def test_rainbow(args: argparse.Namespace = get_args()) -> None: # define model c, h, w = args.state_shape - net = Rainbow( + net = RainbowNet( c=c, h=h, w=w, diff --git a/examples/box2d/acrobot_dualdqn.py b/examples/box2d/acrobot_dualdqn.py index 60fce546d..81bbc24fc 100644 --- a/examples/box2d/acrobot_dualdqn.py +++ b/examples/box2d/acrobot_dualdqn.py @@ -15,7 +15,7 @@ from tianshou.env import DummyVectorEnv from tianshou.trainer import OffPolicyTrainerParams from tianshou.utils import TensorboardLogger -from tianshou.utils.net.common import MLPActor +from tianshou.utils.net.common import Net from tianshou.utils.space_info import SpaceInfo @@ -69,7 +69,7 @@ def test_dqn(args: argparse.Namespace = get_args()) -> None: # model Q_param = {"hidden_sizes": args.dueling_q_hidden_sizes} V_param = {"hidden_sizes": args.dueling_v_hidden_sizes} - net = MLPActor( + net = Net( state_shape=args.state_shape, action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, diff --git a/examples/box2d/bipedal_hardcore_sac.py b/examples/box2d/bipedal_hardcore_sac.py index a1a1e90cb..8687d8c3e 100644 --- a/examples/box2d/bipedal_hardcore_sac.py +++ b/examples/box2d/bipedal_hardcore_sac.py @@ -17,7 +17,7 @@ from tianshou.env import SubprocVectorEnv from tianshou.trainer import OffPolicyTrainerParams from tianshou.utils import TensorboardLogger -from tianshou.utils.net.common import MLPActor +from tianshou.utils.net.common import Net from tianshou.utils.net.continuous import ContinuousActorProbabilistic, ContinuousCritic from 
tianshou.utils.space_info import SpaceInfo @@ -110,7 +110,7 @@ def test_sac_bipedal(args: argparse.Namespace = get_args()) -> None: test_envs.seed(args.seed) # model - net_a = MLPActor(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes) + net_a = Net(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes) actor = ContinuousActorProbabilistic( preprocess_net=net_a, action_shape=args.action_shape, @@ -118,7 +118,7 @@ def test_sac_bipedal(args: argparse.Namespace = get_args()) -> None: ).to(args.device) actor_optim = AdamOptimizerFactory(lr=args.actor_lr) - net_c1 = MLPActor( + net_c1 = Net( state_shape=args.state_shape, action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, @@ -127,7 +127,7 @@ def test_sac_bipedal(args: argparse.Namespace = get_args()) -> None: critic1 = ContinuousCritic(preprocess_net=net_c1).to(args.device) critic1_optim = AdamOptimizerFactory(lr=args.critic_lr) - net_c2 = MLPActor( + net_c2 = Net( state_shape=args.state_shape, action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, diff --git a/examples/box2d/lunarlander_dqn.py b/examples/box2d/lunarlander_dqn.py index 99b9bcbc2..39b7f9b94 100644 --- a/examples/box2d/lunarlander_dqn.py +++ b/examples/box2d/lunarlander_dqn.py @@ -15,7 +15,7 @@ from tianshou.env import DummyVectorEnv, SubprocVectorEnv from tianshou.trainer import OffPolicyTrainerParams from tianshou.utils import TensorboardLogger -from tianshou.utils.net.common import MLPActor +from tianshou.utils.net.common import Net from tianshou.utils.space_info import SpaceInfo @@ -71,7 +71,7 @@ def test_dqn(args: argparse.Namespace = get_args()) -> None: # model Q_param = {"hidden_sizes": args.dueling_q_hidden_sizes} V_param = {"hidden_sizes": args.dueling_v_hidden_sizes} - net = MLPActor( + net = Net( state_shape=args.state_shape, action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, diff --git a/examples/box2d/mcc_sac.py b/examples/box2d/mcc_sac.py index b2c8265a8..77bc0d5d7 100644 --- 
a/examples/box2d/mcc_sac.py +++ b/examples/box2d/mcc_sac.py @@ -16,7 +16,7 @@ from tianshou.exploration import OUNoise from tianshou.trainer import OffPolicyTrainerParams from tianshou.utils import TensorboardLogger -from tianshou.utils.net.common import MLPActor +from tianshou.utils.net.common import Net from tianshou.utils.net.continuous import ContinuousActorProbabilistic, ContinuousCritic from tianshou.utils.space_info import SpaceInfo @@ -68,12 +68,12 @@ def test_sac(args: argparse.Namespace = get_args()) -> None: train_envs.seed(args.seed) test_envs.seed(args.seed) # model - net = MLPActor(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes) + net = Net(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes) actor = ContinuousActorProbabilistic( preprocess_net=net, action_shape=args.action_shape, unbounded=True ).to(args.device) actor_optim = AdamOptimizerFactory(lr=args.actor_lr) - net_c1 = MLPActor( + net_c1 = Net( state_shape=args.state_shape, action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, @@ -81,7 +81,7 @@ def test_sac(args: argparse.Namespace = get_args()) -> None: ) critic1 = ContinuousCritic(preprocess_net=net_c1).to(args.device) critic1_optim = AdamOptimizerFactory(lr=args.critic_lr) - net_c2 = MLPActor( + net_c2 = Net( state_shape=args.state_shape, action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, diff --git a/examples/discrete/discrete_dqn.py b/examples/discrete/discrete_dqn.py index f6b171397..03fab73a3 100644 --- a/examples/discrete/discrete_dqn.py +++ b/examples/discrete/discrete_dqn.py @@ -25,7 +25,7 @@ def main() -> None: train_envs = ts.env.DummyVectorEnv([lambda: gym.make(task) for _ in range(train_num)]) test_envs = ts.env.DummyVectorEnv([lambda: gym.make(task) for _ in range(test_num)]) - from tianshou.utils.net.common import MLPActor + from tianshou.utils.net.common import Net # Note: You can easily define other networks. 
# See https://tianshou.readthedocs.io/en/master/01_tutorials/00_dqn.html#build-the-network @@ -34,7 +34,7 @@ def main() -> None: space_info = SpaceInfo.from_env(env) state_shape = space_info.observation_info.obs_shape action_shape = space_info.action_info.action_shape - net = MLPActor(state_shape=state_shape, action_shape=action_shape, hidden_sizes=[128, 128, 128]) + net = Net(state_shape=state_shape, action_shape=action_shape, hidden_sizes=[128, 128, 128]) optim = AdamOptimizerFactory(lr=lr) policy = DiscreteQLearningPolicy( diff --git a/examples/inverse/irl_gail.py b/examples/inverse/irl_gail.py index e16f32483..13ec86685 100644 --- a/examples/inverse/irl_gail.py +++ b/examples/inverse/irl_gail.py @@ -16,7 +16,7 @@ from tianshou.algorithm import GAIL from tianshou.algorithm.algorithm_base import Algorithm -from tianshou.algorithm.modelfree.reinforce import ActorPolicyProbabilistic +from tianshou.algorithm.modelfree.reinforce import ProbabilisticActorPolicy from tianshou.algorithm.optim import AdamOptimizerFactory, LRSchedulerFactoryLinear from tianshou.data import ( Batch, @@ -29,7 +29,7 @@ from tianshou.env import SubprocVectorEnv, VectorEnvNormObs from tianshou.trainer import OnPolicyTrainerParams from tianshou.utils import TensorboardLogger -from tianshou.utils.net.common import MLPActor +from tianshou.utils.net.common import Net from tianshou.utils.net.continuous import ContinuousActorProbabilistic, ContinuousCritic from tianshou.utils.space_info import SpaceInfo @@ -122,7 +122,7 @@ def test_gail(args: argparse.Namespace = get_args()) -> None: train_envs.seed(args.seed) test_envs.seed(args.seed) # model - net_a = MLPActor( + net_a = Net( state_shape=args.state_shape, hidden_sizes=args.hidden_sizes, activation=nn.Tanh, @@ -132,7 +132,7 @@ def test_gail(args: argparse.Namespace = get_args()) -> None: action_shape=args.action_shape, unbounded=True, ).to(args.device) - net_c = MLPActor( + net_c = Net( state_shape=args.state_shape, hidden_sizes=args.hidden_sizes, 
activation=nn.Tanh, @@ -154,7 +154,7 @@ def test_gail(args: argparse.Namespace = get_args()) -> None: optim = AdamOptimizerFactory(lr=args.lr) # discriminator - net_d = MLPActor( + net_d = Net( state_shape=args.state_shape, action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, @@ -204,7 +204,7 @@ def dist(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution: ) print("dataset loaded") - policy = ActorPolicyProbabilistic( + policy = ProbabilisticActorPolicy( actor=actor, dist_fn=dist, action_scaling=True, diff --git a/examples/mujoco/fetch_her_ddpg.py b/examples/mujoco/fetch_her_ddpg.py index bd5848d51..fadb7df2f 100644 --- a/examples/mujoco/fetch_her_ddpg.py +++ b/examples/mujoco/fetch_her_ddpg.py @@ -27,7 +27,7 @@ from tianshou.algorithm.modelfree.ddpg import ContinuousDeterministicPolicy from tianshou.algorithm.optim import AdamOptimizerFactory from tianshou.trainer import OffPolicyTrainerParams -from tianshou.utils.net.common import MLPActor, get_dict_state_decorator +from tianshou.utils.net.common import Net, get_dict_state_decorator from tianshou.utils.net.continuous import ContinuousActorDeterministic, ContinuousCritic from tianshou.utils.space_info import ActionSpaceInfo @@ -149,7 +149,7 @@ def test_ddpg(args: argparse.Namespace = get_args()) -> None: state_shape=args.state_shape, keys=["observation", "achieved_goal", "desired_goal"], ) - net_a = dict_state_dec(MLPActor)( + net_a = dict_state_dec(Net)( flat_state_shape, hidden_sizes=args.hidden_sizes, device=args.device, @@ -161,7 +161,7 @@ def test_ddpg(args: argparse.Namespace = get_args()) -> None: device=args.device, ).to(args.device) actor_optim = AdamOptimizerFactory(lr=args.actor_lr) - net_c = dict_state_dec(MLPActor)( + net_c = dict_state_dec(Net)( flat_state_shape, action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, diff --git a/examples/mujoco/mujoco_a2c.py b/examples/mujoco/mujoco_a2c.py index ef50ef0e6..40d917e5b 100755 --- a/examples/mujoco/mujoco_a2c.py +++ 
b/examples/mujoco/mujoco_a2c.py @@ -13,12 +13,12 @@ from tianshou.algorithm import A2C from tianshou.algorithm.algorithm_base import Algorithm -from tianshou.algorithm.modelfree.reinforce import ActorPolicyProbabilistic +from tianshou.algorithm.modelfree.reinforce import ProbabilisticActorPolicy from tianshou.algorithm.optim import LRSchedulerFactoryLinear, RMSpropOptimizerFactory from tianshou.data import Collector, CollectStats, ReplayBuffer, VectorReplayBuffer from tianshou.highlevel.logger import LoggerFactoryDefault from tianshou.trainer import OnPolicyTrainerParams -from tianshou.utils.net.common import ActorCritic, MLPActor +from tianshou.utils.net.common import ActorCritic, Net from tianshou.utils.net.continuous import ContinuousActorProbabilistic, ContinuousCritic @@ -89,7 +89,7 @@ def main(args: argparse.Namespace = get_args()) -> None: np.random.seed(args.seed) torch.manual_seed(args.seed) # model - net_a = MLPActor( + net_a = Net( state_shape=args.state_shape, hidden_sizes=args.hidden_sizes, activation=nn.Tanh, @@ -99,7 +99,7 @@ def main(args: argparse.Namespace = get_args()) -> None: action_shape=args.action_shape, unbounded=True, ).to(args.device) - net_c = MLPActor( + net_c = Net( state_shape=args.state_shape, hidden_sizes=args.hidden_sizes, activation=nn.Tanh, @@ -140,7 +140,7 @@ def dist(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution: loc, scale = loc_scale return Independent(Normal(loc, scale), 1) - policy = ActorPolicyProbabilistic( + policy = ProbabilisticActorPolicy( actor=actor, dist_fn=dist, action_scaling=True, diff --git a/examples/mujoco/mujoco_ddpg.py b/examples/mujoco/mujoco_ddpg.py index c64bd4329..ca0f04bf6 100755 --- a/examples/mujoco/mujoco_ddpg.py +++ b/examples/mujoco/mujoco_ddpg.py @@ -17,7 +17,7 @@ from tianshou.exploration import GaussianNoise from tianshou.highlevel.logger import LoggerFactoryDefault from tianshou.trainer import OffPolicyTrainerParams -from tianshou.utils.net.common import MLPActor +from 
tianshou.utils.net.common import Net from tianshou.utils.net.continuous import ContinuousActorDeterministic, ContinuousCritic @@ -85,14 +85,14 @@ def main(args: argparse.Namespace = get_args()) -> None: np.random.seed(args.seed) torch.manual_seed(args.seed) # model - net_a = MLPActor(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes) + net_a = Net(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes) actor = ContinuousActorDeterministic( preprocess_net=net_a, action_shape=args.action_shape, max_action=args.max_action ).to( args.device, ) actor_optim = AdamOptimizerFactory(lr=args.actor_lr) - net_c = MLPActor( + net_c = Net( state_shape=args.state_shape, action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, diff --git a/examples/mujoco/mujoco_npg.py b/examples/mujoco/mujoco_npg.py index a4215f948..06265e42f 100755 --- a/examples/mujoco/mujoco_npg.py +++ b/examples/mujoco/mujoco_npg.py @@ -13,12 +13,12 @@ from tianshou.algorithm import NPG from tianshou.algorithm.algorithm_base import Algorithm -from tianshou.algorithm.modelfree.reinforce import ActorPolicyProbabilistic +from tianshou.algorithm.modelfree.reinforce import ProbabilisticActorPolicy from tianshou.algorithm.optim import AdamOptimizerFactory, LRSchedulerFactoryLinear from tianshou.data import Collector, CollectStats, ReplayBuffer, VectorReplayBuffer from tianshou.highlevel.logger import LoggerFactoryDefault from tianshou.trainer import OnPolicyTrainerParams -from tianshou.utils.net.common import MLPActor +from tianshou.utils.net.common import Net from tianshou.utils.net.continuous import ContinuousActorProbabilistic, ContinuousCritic @@ -94,7 +94,7 @@ def main(args: argparse.Namespace = get_args()) -> None: np.random.seed(args.seed) torch.manual_seed(args.seed) # model - net_a = MLPActor( + net_a = Net( state_shape=args.state_shape, hidden_sizes=args.hidden_sizes, activation=nn.Tanh, @@ -104,7 +104,7 @@ def main(args: argparse.Namespace = get_args()) -> None: 
action_shape=args.action_shape, unbounded=True, ).to(args.device) - net_c = MLPActor( + net_c = Net( state_shape=args.state_shape, hidden_sizes=args.hidden_sizes, activation=nn.Tanh, @@ -138,7 +138,7 @@ def dist(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution: loc, scale = loc_scale return Independent(Normal(loc, scale), 1) - policy = ActorPolicyProbabilistic( + policy = ProbabilisticActorPolicy( actor=actor, dist_fn=dist, action_scaling=True, diff --git a/examples/mujoco/mujoco_ppo.py b/examples/mujoco/mujoco_ppo.py index c691aa9e3..60b5fd5d7 100755 --- a/examples/mujoco/mujoco_ppo.py +++ b/examples/mujoco/mujoco_ppo.py @@ -13,12 +13,12 @@ from tianshou.algorithm import PPO from tianshou.algorithm.algorithm_base import Algorithm -from tianshou.algorithm.modelfree.reinforce import ActorPolicyProbabilistic +from tianshou.algorithm.modelfree.reinforce import ProbabilisticActorPolicy from tianshou.algorithm.optim import AdamOptimizerFactory, LRSchedulerFactoryLinear from tianshou.data import Collector, CollectStats, ReplayBuffer, VectorReplayBuffer from tianshou.highlevel.logger import LoggerFactoryDefault from tianshou.trainer import OnPolicyTrainerParams -from tianshou.utils.net.common import ActorCritic, MLPActor +from tianshou.utils.net.common import ActorCritic, Net from tianshou.utils.net.continuous import ContinuousActorProbabilistic, ContinuousCritic @@ -94,7 +94,7 @@ def main(args: argparse.Namespace = get_args()) -> None: np.random.seed(args.seed) torch.manual_seed(args.seed) # model - net_a = MLPActor( + net_a = Net( state_shape=args.state_shape, hidden_sizes=args.hidden_sizes, activation=nn.Tanh, @@ -104,7 +104,7 @@ def main(args: argparse.Namespace = get_args()) -> None: action_shape=args.action_shape, unbounded=True, ).to(args.device) - net_c = MLPActor( + net_c = Net( state_shape=args.state_shape, hidden_sizes=args.hidden_sizes, activation=nn.Tanh, @@ -141,7 +141,7 @@ def dist(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution: 
loc, scale = loc_scale return Independent(Normal(loc, scale), 1) - policy = ActorPolicyProbabilistic( + policy = ProbabilisticActorPolicy( actor=actor, dist_fn=dist, action_scaling=True, diff --git a/examples/mujoco/mujoco_redq.py b/examples/mujoco/mujoco_redq.py index 2b1762659..b8d472ed1 100755 --- a/examples/mujoco/mujoco_redq.py +++ b/examples/mujoco/mujoco_redq.py @@ -17,7 +17,7 @@ from tianshou.data import Collector, CollectStats, ReplayBuffer, VectorReplayBuffer from tianshou.highlevel.logger import LoggerFactoryDefault from tianshou.trainer import OffPolicyTrainerParams -from tianshou.utils.net.common import EnsembleLinear, MLPActor +from tianshou.utils.net.common import EnsembleLinear, Net from tianshou.utils.net.continuous import ContinuousActorProbabilistic, ContinuousCritic @@ -89,7 +89,7 @@ def main(args: argparse.Namespace = get_args()) -> None: np.random.seed(args.seed) torch.manual_seed(args.seed) # model - net_a = MLPActor(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes) + net_a = Net(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes) actor = ContinuousActorProbabilistic( preprocess_net=net_a, action_shape=args.action_shape, @@ -101,7 +101,7 @@ def main(args: argparse.Namespace = get_args()) -> None: def linear(x: int, y: int) -> EnsembleLinear: return EnsembleLinear(args.ensemble_size, x, y) - net_c = MLPActor( + net_c = Net( state_shape=args.state_shape, action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, diff --git a/examples/mujoco/mujoco_reinforce.py b/examples/mujoco/mujoco_reinforce.py index b05ecbcc3..e7758bcdb 100755 --- a/examples/mujoco/mujoco_reinforce.py +++ b/examples/mujoco/mujoco_reinforce.py @@ -13,12 +13,12 @@ from tianshou.algorithm import Reinforce from tianshou.algorithm.algorithm_base import Algorithm -from tianshou.algorithm.modelfree.reinforce import ActorPolicyProbabilistic +from tianshou.algorithm.modelfree.reinforce import ProbabilisticActorPolicy from tianshou.algorithm.optim import 
AdamOptimizerFactory, LRSchedulerFactoryLinear from tianshou.data import Collector, CollectStats, ReplayBuffer, VectorReplayBuffer from tianshou.highlevel.logger import LoggerFactoryDefault from tianshou.trainer import OnPolicyTrainerParams -from tianshou.utils.net.common import MLPActor +from tianshou.utils.net.common import Net from tianshou.utils.net.continuous import ContinuousActorProbabilistic @@ -86,7 +86,7 @@ def main(args: argparse.Namespace = get_args()) -> None: np.random.seed(args.seed) torch.manual_seed(args.seed) # model - net_a = MLPActor( + net_a = Net( state_shape=args.state_shape, hidden_sizes=args.hidden_sizes, activation=nn.Tanh, @@ -124,7 +124,7 @@ def dist(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution: loc, scale = loc_scale return Independent(Normal(loc, scale), 1) - policy = ActorPolicyProbabilistic( + policy = ProbabilisticActorPolicy( actor=actor, dist_fn=dist, action_space=env.action_space, diff --git a/examples/mujoco/mujoco_sac.py b/examples/mujoco/mujoco_sac.py index 1b9de80c0..d9849fb5e 100755 --- a/examples/mujoco/mujoco_sac.py +++ b/examples/mujoco/mujoco_sac.py @@ -16,7 +16,7 @@ from tianshou.data import Collector, CollectStats, ReplayBuffer, VectorReplayBuffer from tianshou.highlevel.logger import LoggerFactoryDefault from tianshou.trainer import OffPolicyTrainerParams -from tianshou.utils.net.common import MLPActor +from tianshou.utils.net.common import Net from tianshou.utils.net.continuous import ContinuousActorProbabilistic, ContinuousCritic @@ -85,7 +85,7 @@ def main(args: argparse.Namespace = get_args()) -> None: np.random.seed(args.seed) torch.manual_seed(args.seed) # model - net_a = MLPActor(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes) + net_a = Net(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes) actor = ContinuousActorProbabilistic( preprocess_net=net_a, action_shape=args.action_shape, @@ -93,13 +93,13 @@ def main(args: argparse.Namespace = get_args()) -> None: 
conditioned_sigma=True, ).to(args.device) actor_optim = AdamOptimizerFactory(lr=args.actor_lr) - net_c1 = MLPActor( + net_c1 = Net( state_shape=args.state_shape, action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, concat=True, ) - net_c2 = MLPActor( + net_c2 = Net( state_shape=args.state_shape, action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, diff --git a/examples/mujoco/mujoco_td3.py b/examples/mujoco/mujoco_td3.py index 7c2df9aab..f62a177c0 100755 --- a/examples/mujoco/mujoco_td3.py +++ b/examples/mujoco/mujoco_td3.py @@ -17,7 +17,7 @@ from tianshou.exploration import GaussianNoise from tianshou.highlevel.logger import LoggerFactoryDefault from tianshou.trainer import OffPolicyTrainerParams -from tianshou.utils.net.common import MLPActor +from tianshou.utils.net.common import Net from tianshou.utils.net.continuous import ContinuousActorDeterministic, ContinuousCritic @@ -90,20 +90,20 @@ def main(args: argparse.Namespace = get_args()) -> None: np.random.seed(args.seed) torch.manual_seed(args.seed) # model - net_a = MLPActor(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes) + net_a = Net(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes) actor = ContinuousActorDeterministic( preprocess_net=net_a, action_shape=args.action_shape, max_action=args.max_action ).to( args.device, ) actor_optim = AdamOptimizerFactory(lr=args.actor_lr) - net_c1 = MLPActor( + net_c1 = Net( state_shape=args.state_shape, action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, concat=True, ) - net_c2 = MLPActor( + net_c2 = Net( state_shape=args.state_shape, action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, diff --git a/examples/mujoco/mujoco_trpo.py b/examples/mujoco/mujoco_trpo.py index 1f8cbab2a..707f174f2 100755 --- a/examples/mujoco/mujoco_trpo.py +++ b/examples/mujoco/mujoco_trpo.py @@ -13,12 +13,12 @@ from tianshou.algorithm import TRPO from tianshou.algorithm.algorithm_base import Algorithm -from 
tianshou.algorithm.modelfree.reinforce import ActorPolicyProbabilistic +from tianshou.algorithm.modelfree.reinforce import ProbabilisticActorPolicy from tianshou.algorithm.optim import AdamOptimizerFactory, LRSchedulerFactoryLinear from tianshou.data import Collector, CollectStats, ReplayBuffer, VectorReplayBuffer from tianshou.highlevel.logger import LoggerFactoryDefault from tianshou.trainer import OnPolicyTrainerParams -from tianshou.utils.net.common import MLPActor +from tianshou.utils.net.common import Net from tianshou.utils.net.continuous import ContinuousActorProbabilistic, ContinuousCritic @@ -97,7 +97,7 @@ def test_trpo(args: argparse.Namespace = get_args()) -> None: np.random.seed(args.seed) torch.manual_seed(args.seed) # model - net_a = MLPActor( + net_a = Net( state_shape=args.state_shape, hidden_sizes=args.hidden_sizes, activation=nn.Tanh, @@ -107,7 +107,7 @@ def test_trpo(args: argparse.Namespace = get_args()) -> None: action_shape=args.action_shape, unbounded=True, ).to(args.device) - net_c = MLPActor( + net_c = Net( state_shape=args.state_shape, hidden_sizes=args.hidden_sizes, activation=nn.Tanh, @@ -141,7 +141,7 @@ def dist(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution: loc, scale = loc_scale return Independent(Normal(loc, scale), 1) - policy = ActorPolicyProbabilistic( + policy = ProbabilisticActorPolicy( actor=actor, dist_fn=dist, action_scaling=True, diff --git a/examples/offline/d4rl_bcq.py b/examples/offline/d4rl_bcq.py index 11399683d..53ff6d2ce 100644 --- a/examples/offline/d4rl_bcq.py +++ b/examples/offline/d4rl_bcq.py @@ -19,7 +19,7 @@ from tianshou.env import SubprocVectorEnv from tianshou.trainer import OfflineTrainerParams from tianshou.utils import TensorboardLogger, WandbLogger -from tianshou.utils.net.common import MLP, MLPActor +from tianshou.utils.net.common import MLP, Net from tianshou.utils.net.continuous import VAE, ContinuousCritic, Perturbation from tianshou.utils.space_info import SpaceInfo @@ -109,13 
+109,13 @@ def test_bcq() -> None: ) actor_optim = AdamOptimizerFactory(lr=args.actor_lr) - net_c1 = MLPActor( + net_c1 = Net( state_shape=args.state_shape, action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, concat=True, ) - net_c2 = MLPActor( + net_c2 = Net( state_shape=args.state_shape, action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, diff --git a/examples/offline/d4rl_cql.py b/examples/offline/d4rl_cql.py index 8260b837c..cbe699495 100644 --- a/examples/offline/d4rl_cql.py +++ b/examples/offline/d4rl_cql.py @@ -19,7 +19,7 @@ from tianshou.env import SubprocVectorEnv from tianshou.trainer import OfflineTrainerParams from tianshou.utils import TensorboardLogger, WandbLogger -from tianshou.utils.net.common import MLPActor +from tianshou.utils.net.common import Net from tianshou.utils.net.continuous import ContinuousActorProbabilistic, ContinuousCritic from tianshou.utils.space_info import SpaceInfo @@ -245,7 +245,7 @@ def test_cql() -> None: # model # actor network - net_a = MLPActor( + net_a = Net( state_shape=args.state_shape, action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, @@ -259,13 +259,13 @@ def test_cql() -> None: actor_optim = AdamOptimizerFactory(lr=args.actor_lr) # critic network - net_c1 = MLPActor( + net_c1 = Net( state_shape=args.state_shape, action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, concat=True, ) - net_c2 = MLPActor( + net_c2 = Net( state_shape=args.state_shape, action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, diff --git a/examples/offline/d4rl_il.py b/examples/offline/d4rl_il.py index c6b620016..e3993f244 100644 --- a/examples/offline/d4rl_il.py +++ b/examples/offline/d4rl_il.py @@ -21,7 +21,7 @@ from tianshou.env import SubprocVectorEnv from tianshou.trainer import OfflineTrainerParams from tianshou.utils import TensorboardLogger, WandbLogger -from tianshou.utils.net.common import MLPActor +from tianshou.utils.net.common import Net from tianshou.utils.net.continuous 
import ContinuousActorDeterministic from tianshou.utils.space_info import SpaceInfo @@ -86,7 +86,7 @@ def test_il() -> None: test_envs.seed(args.seed) # model - net = MLPActor( + net = Net( state_shape=args.state_shape, action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, diff --git a/examples/offline/d4rl_td3_bc.py b/examples/offline/d4rl_td3_bc.py index 1d5a69798..8a571606b 100644 --- a/examples/offline/d4rl_td3_bc.py +++ b/examples/offline/d4rl_td3_bc.py @@ -20,7 +20,7 @@ from tianshou.exploration import GaussianNoise from tianshou.trainer import OfflineTrainerParams from tianshou.utils import TensorboardLogger, WandbLogger -from tianshou.utils.net.common import MLPActor +from tianshou.utils.net.common import Net from tianshou.utils.net.continuous import ContinuousActorDeterministic, ContinuousCritic from tianshou.utils.space_info import SpaceInfo @@ -104,7 +104,7 @@ def test_td3_bc() -> None: # model # actor network - net_a = MLPActor( + net_a = Net( state_shape=args.state_shape, hidden_sizes=args.hidden_sizes, ) @@ -116,13 +116,13 @@ def test_td3_bc() -> None: actor_optim = AdamOptimizerFactory(lr=args.actor_lr) # critic network - net_c1 = MLPActor( + net_c1 = Net( state_shape=args.state_shape, action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, concat=True, ) - net_c2 = MLPActor( + net_c2 = Net( state_shape=args.state_shape, action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, diff --git a/examples/vizdoom/vizdoom_ppo.py b/examples/vizdoom/vizdoom_ppo.py index e3353be30..68efc1eb6 100644 --- a/examples/vizdoom/vizdoom_ppo.py +++ b/examples/vizdoom/vizdoom_ppo.py @@ -12,7 +12,7 @@ from tianshou.algorithm import PPO from tianshou.algorithm.algorithm_base import Algorithm from tianshou.algorithm.modelbased.icm import ICMOnPolicyWrapper -from tianshou.algorithm.modelfree.reinforce import ActorPolicyProbabilistic +from tianshou.algorithm.modelfree.reinforce import ProbabilisticActorPolicy from tianshou.algorithm.optim import 
AdamOptimizerFactory, LRSchedulerFactoryLinear from tianshou.data import Collector, CollectStats, VectorReplayBuffer from tianshou.env.atari.atari_network import DQNet @@ -149,7 +149,7 @@ def dist(logits: torch.Tensor) -> Categorical: return Categorical(logits=logits) # define policy and algorithm - policy = ActorPolicyProbabilistic( + policy = ProbabilisticActorPolicy( actor=actor, dist_fn=dist, action_scaling=False, diff --git a/test/base/test_policy.py b/test/base/test_policy.py index 96a550398..940ffd0b2 100644 --- a/test/base/test_policy.py +++ b/test/base/test_policy.py @@ -9,10 +9,10 @@ RandomActionPolicy, episode_mc_return_to_go, ) -from tianshou.algorithm.modelfree.reinforce import ActorPolicyProbabilistic +from tianshou.algorithm.modelfree.reinforce import ProbabilisticActorPolicy from tianshou.algorithm.optim import AdamOptimizerFactory from tianshou.data import Batch -from tianshou.utils.net.common import MLPActor +from tianshou.utils.net.common import Net from tianshou.utils.net.continuous import ContinuousActorProbabilistic, ContinuousCritic from tianshou.utils.net.discrete import DiscreteActor @@ -38,7 +38,7 @@ def algorithm(request: pytest.FixtureRequest) -> PPO: if action_type == "continuous": action_space = gym.spaces.Box(low=-1, high=1, shape=(3,)) actor = ContinuousActorProbabilistic( - preprocess_net=MLPActor( + preprocess_net=Net( state_shape=obs_shape, hidden_sizes=[64, 64], action_shape=action_space.shape ), action_shape=action_space.shape, @@ -51,7 +51,7 @@ def dist_fn(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution: elif action_type == "discrete": action_space = gym.spaces.Discrete(3) actor = DiscreteActor( - preprocess_net=MLPActor( + preprocess_net=Net( state_shape=obs_shape, hidden_sizes=[64, 64], action_shape=action_space.n ), action_shape=action_space.n, @@ -61,13 +61,13 @@ def dist_fn(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution: raise ValueError(f"Unknown action type: {action_type}") critic = 
ContinuousCritic( - preprocess_net=MLPActor(state_shape=obs_shape, hidden_sizes=[64, 64]), + preprocess_net=Net(state_shape=obs_shape, hidden_sizes=[64, 64]), ) optim = AdamOptimizerFactory(lr=1e-3) algorithm: PPO - policy = ActorPolicyProbabilistic( + policy = ProbabilisticActorPolicy( actor=actor, dist_fn=dist_fn, action_space=action_space, diff --git a/test/base/test_utils.py b/test/base/test_utils.py index cc448e5da..6c992a165 100644 --- a/test/base/test_utils.py +++ b/test/base/test_utils.py @@ -9,7 +9,7 @@ from tianshou.exploration import GaussianNoise, OUNoise from tianshou.utils import MovAvg, RunningMeanStd -from tianshou.utils.net.common import MLP, MLPActor +from tianshou.utils.net.common import MLP, Net from tianshou.utils.net.continuous import RecurrentActorProb, RecurrentCritic from tianshou.utils.torch_utils import create_uniform_action_dist, torch_train_mode @@ -62,7 +62,7 @@ def test_net() -> None: action_shape = (5,) data = torch.rand([bsz, *state_shape]) expect_output_shape = [bsz, *action_shape] - net = MLPActor( + net = Net( state_shape=state_shape, action_shape=action_shape, hidden_sizes=[128, 128], @@ -73,7 +73,7 @@ def test_net() -> None: assert str(net).count("LayerNorm") == 2 assert str(net).count("ReLU") == 0 Q_param = V_param = {"hidden_sizes": [128, 128]} - net = MLPActor( + net = Net( state_shape=state_shape, action_shape=action_shape, hidden_sizes=[128, 128], @@ -81,13 +81,11 @@ def test_net() -> None: ) assert list(net(data)[0].shape) == expect_output_shape # concat - net = MLPActor( - state_shape=state_shape, action_shape=action_shape, hidden_sizes=[128], concat=True - ) + net = Net(state_shape=state_shape, action_shape=action_shape, hidden_sizes=[128], concat=True) data = torch.rand([bsz, int(np.prod(state_shape)) + int(np.prod(action_shape))]) expect_output_shape = [bsz, 128] assert list(net(data)[0].shape) == expect_output_shape - net = MLPActor( + net = Net( state_shape=state_shape, action_shape=action_shape, hidden_sizes=[128], 
diff --git a/test/continuous/test_ddpg.py b/test/continuous/test_ddpg.py index f48cf90cc..4139d1926 100644 --- a/test/continuous/test_ddpg.py +++ b/test/continuous/test_ddpg.py @@ -16,7 +16,7 @@ from tianshou.exploration import GaussianNoise from tianshou.trainer import OffPolicyTrainerParams from tianshou.utils import TensorboardLogger -from tianshou.utils.net.common import MLPActor +from tianshou.utils.net.common import Net from tianshou.utils.net.continuous import ContinuousActorDeterministic, ContinuousCritic from tianshou.utils.space_info import SpaceInfo @@ -74,13 +74,13 @@ def test_ddpg(args: argparse.Namespace = get_args(), enable_assertions: bool = T test_envs.seed(args.seed) # model - net = MLPActor(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes) + net = Net(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes) actor = ContinuousActorDeterministic( preprocess_net=net, action_shape=args.action_shape, max_action=args.max_action ).to( args.device, ) - net = MLPActor( + net = Net( state_shape=args.state_shape, action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, diff --git a/test/continuous/test_npg.py b/test/continuous/test_npg.py index 7d43a153d..93ab52afd 100644 --- a/test/continuous/test_npg.py +++ b/test/continuous/test_npg.py @@ -11,13 +11,13 @@ from tianshou.algorithm import NPG from tianshou.algorithm.algorithm_base import Algorithm -from tianshou.algorithm.modelfree.reinforce import ActorPolicyProbabilistic +from tianshou.algorithm.modelfree.reinforce import ProbabilisticActorPolicy from tianshou.algorithm.optim import AdamOptimizerFactory from tianshou.data import Collector, CollectStats, VectorReplayBuffer from tianshou.env import DummyVectorEnv from tianshou.trainer import OnPolicyTrainerParams from tianshou.utils import TensorboardLogger -from tianshou.utils.net.common import MLPActor +from tianshou.utils.net.common import Net from tianshou.utils.net.continuous import ContinuousActorProbabilistic, ContinuousCritic 
from tianshou.utils.space_info import SpaceInfo @@ -80,7 +80,7 @@ def test_npg(args: argparse.Namespace = get_args(), enable_assertions: bool = Tr test_envs.seed(args.seed) # model - net = MLPActor( + net = Net( state_shape=args.state_shape, hidden_sizes=args.hidden_sizes, activation=nn.Tanh, @@ -89,7 +89,7 @@ def test_npg(args: argparse.Namespace = get_args(), enable_assertions: bool = Tr preprocess_net=net, action_shape=args.action_shape, unbounded=True ).to(args.device) critic = ContinuousCritic( - preprocess_net=MLPActor( + preprocess_net=Net( state_shape=args.state_shape, hidden_sizes=args.hidden_sizes, activation=nn.Tanh, @@ -108,7 +108,7 @@ def dist(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution: loc, scale = loc_scale return Independent(Normal(loc, scale), 1) - policy = ActorPolicyProbabilistic( + policy = ProbabilisticActorPolicy( actor=actor, dist_fn=dist, action_space=env.action_space, diff --git a/test/continuous/test_ppo.py b/test/continuous/test_ppo.py index e177967e0..686fa86d7 100644 --- a/test/continuous/test_ppo.py +++ b/test/continuous/test_ppo.py @@ -10,13 +10,13 @@ from tianshou.algorithm import PPO from tianshou.algorithm.algorithm_base import Algorithm -from tianshou.algorithm.modelfree.reinforce import ActorPolicyProbabilistic +from tianshou.algorithm.modelfree.reinforce import ProbabilisticActorPolicy from tianshou.algorithm.optim import AdamOptimizerFactory from tianshou.data import Collector, CollectStats, VectorReplayBuffer from tianshou.env import DummyVectorEnv from tianshou.trainer import OnPolicyTrainerParams from tianshou.utils import TensorboardLogger -from tianshou.utils.net.common import ActorCritic, MLPActor +from tianshou.utils.net.common import ActorCritic, Net from tianshou.utils.net.continuous import ContinuousActorProbabilistic, ContinuousCritic from tianshou.utils.space_info import SpaceInfo @@ -84,12 +84,12 @@ def test_ppo(args: argparse.Namespace = get_args(), enable_assertions: bool = Tr 
test_envs.seed(args.seed) # model - net = MLPActor(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes) + net = Net(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes) actor = ContinuousActorProbabilistic( preprocess_net=net, action_shape=args.action_shape, unbounded=True ).to(args.device) critic = ContinuousCritic( - preprocess_net=MLPActor(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes), + preprocess_net=Net(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes), ).to(args.device) actor_critic = ActorCritic(actor, critic) # orthogonal initialization @@ -105,7 +105,7 @@ def dist(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution: loc, scale = loc_scale return Independent(Normal(loc, scale), 1) - policy = ActorPolicyProbabilistic( + policy = ProbabilisticActorPolicy( actor=actor, dist_fn=dist, action_space=env.action_space, diff --git a/test/continuous/test_redq.py b/test/continuous/test_redq.py index a43c43098..9ae03529c 100644 --- a/test/continuous/test_redq.py +++ b/test/continuous/test_redq.py @@ -17,7 +17,7 @@ from tianshou.env import DummyVectorEnv from tianshou.trainer import OffPolicyTrainerParams from tianshou.utils import TensorboardLogger -from tianshou.utils.net.common import EnsembleLinear, MLPActor +from tianshou.utils.net.common import EnsembleLinear, Net from tianshou.utils.net.continuous import ContinuousActorProbabilistic, ContinuousCritic from tianshou.utils.space_info import SpaceInfo @@ -81,7 +81,7 @@ def test_redq(args: argparse.Namespace = get_args(), enable_assertions: bool = T train_envs.seed(args.seed) test_envs.seed(args.seed) # model - net = MLPActor(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes) + net = Net(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes) actor = ContinuousActorProbabilistic( preprocess_net=net, action_shape=args.action_shape, @@ -93,7 +93,7 @@ def test_redq(args: argparse.Namespace = get_args(), enable_assertions: bool = T def linear(x: int, 
y: int) -> nn.Module: return EnsembleLinear(args.ensemble_size, x, y) - net_c = MLPActor( + net_c = Net( state_shape=args.state_shape, action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, diff --git a/test/continuous/test_sac_with_il.py b/test/continuous/test_sac_with_il.py index 409b1644f..31fd973c1 100644 --- a/test/continuous/test_sac_with_il.py +++ b/test/continuous/test_sac_with_il.py @@ -16,7 +16,7 @@ from tianshou.env import DummyVectorEnv from tianshou.trainer import OffPolicyTrainerParams from tianshou.utils import TensorboardLogger -from tianshou.utils.net.common import MLPActor +from tianshou.utils.net.common import Net from tianshou.utils.net.continuous import ( ContinuousActorDeterministic, ContinuousActorProbabilistic, @@ -94,12 +94,12 @@ def test_sac_with_il( test_envs.seed(args.seed + args.num_train_envs) # model - net = MLPActor(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes) + net = Net(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes) actor = ContinuousActorProbabilistic( preprocess_net=net, action_shape=args.action_shape, unbounded=True ).to(args.device) actor_optim = AdamOptimizerFactory(lr=args.actor_lr) - net_c1 = MLPActor( + net_c1 = Net( state_shape=args.state_shape, action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, @@ -107,7 +107,7 @@ def test_sac_with_il( ) critic1 = ContinuousCritic(preprocess_net=net_c1).to(args.device) critic1_optim = AdamOptimizerFactory(lr=args.critic_lr) - net_c2 = MLPActor( + net_c2 = Net( state_shape=args.state_shape, action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, @@ -184,7 +184,7 @@ def stop_fn(mean_rewards: float) -> bool: # here we define an imitation collector with a trivial policy if args.task.startswith("Pendulum"): args.reward_threshold -= 50 # lower the goal - il_net = MLPActor( + il_net = Net( state_shape=args.state_shape, hidden_sizes=args.imitation_hidden_sizes, ) diff --git a/test/continuous/test_td3.py b/test/continuous/test_td3.py 
index f6476f282..60678f5c5 100644 --- a/test/continuous/test_td3.py +++ b/test/continuous/test_td3.py @@ -16,7 +16,7 @@ from tianshou.exploration import GaussianNoise from tianshou.trainer import OffPolicyTrainerParams from tianshou.utils import TensorboardLogger -from tianshou.utils.net.common import MLPActor +from tianshou.utils.net.common import Net from tianshou.utils.net.continuous import ContinuousActorDeterministic, ContinuousCritic from tianshou.utils.space_info import SpaceInfo @@ -77,14 +77,14 @@ def test_td3(args: argparse.Namespace = get_args(), enable_assertions: bool = Tr train_envs.seed(args.seed) test_envs.seed(args.seed) # model - net = MLPActor(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes) + net = Net(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes) actor = ContinuousActorDeterministic( preprocess_net=net, action_shape=args.action_shape, max_action=args.max_action ).to( args.device, ) actor_optim = AdamOptimizerFactory(lr=args.actor_lr) - net_c1 = MLPActor( + net_c1 = Net( state_shape=args.state_shape, action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, @@ -92,7 +92,7 @@ def test_td3(args: argparse.Namespace = get_args(), enable_assertions: bool = Tr ) critic1 = ContinuousCritic(preprocess_net=net_c1).to(args.device) critic1_optim = AdamOptimizerFactory(lr=args.critic_lr) - net_c2 = MLPActor( + net_c2 = Net( state_shape=args.state_shape, action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, diff --git a/test/continuous/test_trpo.py b/test/continuous/test_trpo.py index 671e258f0..855a53def 100644 --- a/test/continuous/test_trpo.py +++ b/test/continuous/test_trpo.py @@ -11,13 +11,13 @@ from tianshou.algorithm import TRPO from tianshou.algorithm.algorithm_base import Algorithm -from tianshou.algorithm.modelfree.reinforce import ActorPolicyProbabilistic +from tianshou.algorithm.modelfree.reinforce import ProbabilisticActorPolicy from tianshou.algorithm.optim import AdamOptimizerFactory from 
tianshou.data import Collector, CollectStats, VectorReplayBuffer from tianshou.env import DummyVectorEnv from tianshou.trainer import OnPolicyTrainerParams from tianshou.utils import TensorboardLogger -from tianshou.utils.net.common import MLPActor +from tianshou.utils.net.common import Net from tianshou.utils.net.continuous import ContinuousActorProbabilistic, ContinuousCritic from tianshou.utils.space_info import SpaceInfo @@ -81,7 +81,7 @@ def test_trpo(args: argparse.Namespace = get_args(), enable_assertions: bool = T train_envs.seed(args.seed) test_envs.seed(args.seed) # model - net = MLPActor( + net = Net( state_shape=args.state_shape, hidden_sizes=args.hidden_sizes, activation=nn.Tanh, @@ -90,7 +90,7 @@ def test_trpo(args: argparse.Namespace = get_args(), enable_assertions: bool = T preprocess_net=net, action_shape=args.action_shape, unbounded=True ).to(args.device) critic = ContinuousCritic( - preprocess_net=MLPActor( + preprocess_net=Net( state_shape=args.state_shape, hidden_sizes=args.hidden_sizes, activation=nn.Tanh, @@ -109,7 +109,7 @@ def dist(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution: loc, scale = loc_scale return Independent(Normal(loc, scale), 1) - policy = ActorPolicyProbabilistic( + policy = ProbabilisticActorPolicy( actor=actor, dist_fn=dist, action_space=env.action_space, diff --git a/test/discrete/test_a2c_with_il.py b/test/discrete/test_a2c_with_il.py index e30d7ede1..ff255c096 100644 --- a/test/discrete/test_a2c_with_il.py +++ b/test/discrete/test_a2c_with_il.py @@ -10,13 +10,13 @@ from tianshou.algorithm import A2C, Algorithm, OffPolicyImitationLearning from tianshou.algorithm.imitation.imitation_base import ImitationPolicy -from tianshou.algorithm.modelfree.reinforce import ActorPolicyProbabilistic +from tianshou.algorithm.modelfree.reinforce import ProbabilisticActorPolicy from tianshou.algorithm.optim import AdamOptimizerFactory from tianshou.data import Collector, CollectStats, VectorReplayBuffer from tianshou.env 
import DummyVectorEnv from tianshou.trainer import OffPolicyTrainerParams, OnPolicyTrainerParams from tianshou.utils import TensorboardLogger -from tianshou.utils.net.common import MLPActor +from tianshou.utils.net.common import Net from tianshou.utils.net.discrete import DiscreteActor, DiscreteCritic try: @@ -98,12 +98,12 @@ def test_a2c_with_il( default_reward_threshold = {"CartPole-v1": 195} args.reward_threshold = default_reward_threshold.get(args.task, env.spec.reward_threshold) # model - net = MLPActor(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes) + net = Net(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes) actor = DiscreteActor(preprocess_net=net, action_shape=args.action_shape).to(args.device) critic = DiscreteCritic(preprocess_net=net).to(args.device) optim = AdamOptimizerFactory(lr=args.lr) dist = torch.distributions.Categorical - policy = ActorPolicyProbabilistic( + policy = ProbabilisticActorPolicy( actor=actor, dist_fn=dist, action_scaling=isinstance(env.action_space, Box), @@ -168,7 +168,7 @@ def stop_fn(mean_rewards: float) -> bool: # here we define an imitation collector with a trivial policy # if args.task == 'CartPole-v1': # env.spec.reward_threshold = 190 # lower the goal - net = MLPActor(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes) + net = Net(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes) actor = DiscreteActor(preprocess_net=net, action_shape=args.action_shape).to(args.device) optim = AdamOptimizerFactory(lr=args.il_lr) il_policy = ImitationPolicy( diff --git a/test/discrete/test_c51.py b/test/discrete/test_c51.py index 7795a54f9..ed828d0ae 100644 --- a/test/discrete/test_c51.py +++ b/test/discrete/test_c51.py @@ -22,7 +22,7 @@ from tianshou.env import DummyVectorEnv from tianshou.trainer import OffPolicyTrainerParams from tianshou.utils import TensorboardLogger -from tianshou.utils.net.common import MLPActor +from tianshou.utils.net.common import Net from tianshou.utils.space_info 
import SpaceInfo from tianshou.utils.torch_utils import policy_within_training_step @@ -87,7 +87,7 @@ def test_c51(args: argparse.Namespace = get_args(), enable_assertions: bool = Tr test_envs.seed(args.seed) # model - net = MLPActor( + net = Net( state_shape=args.state_shape, action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, diff --git a/test/discrete/test_discrete_sac.py b/test/discrete/test_discrete_sac.py index 5cd8ae305..dcec3863f 100644 --- a/test/discrete/test_discrete_sac.py +++ b/test/discrete/test_discrete_sac.py @@ -18,7 +18,7 @@ from tianshou.env import DummyVectorEnv from tianshou.trainer import OffPolicyTrainerParams from tianshou.utils import TensorboardLogger -from tianshou.utils.net.common import MLPActor +from tianshou.utils.net.common import Net from tianshou.utils.net.discrete import DiscreteActor, DiscreteCritic from tianshou.utils.space_info import SpaceInfo @@ -83,15 +83,15 @@ def test_discrete_sac( # model obs_dim = space_info.observation_info.obs_dim action_dim = space_info.action_info.action_dim - net = MLPActor(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes) + net = Net(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes) actor = DiscreteActor( preprocess_net=net, action_shape=args.action_shape, softmax_output=False ).to(args.device) actor_optim = AdamOptimizerFactory(lr=args.actor_lr) - net_c1 = MLPActor(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes) + net_c1 = Net(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes) critic1 = DiscreteCritic(preprocess_net=net_c1, last_size=action_dim).to(args.device) critic1_optim = AdamOptimizerFactory(lr=args.critic_lr) - net_c2 = MLPActor(state_shape=obs_dim, hidden_sizes=args.hidden_sizes) + net_c2 = Net(state_shape=obs_dim, hidden_sizes=args.hidden_sizes) critic2 = DiscreteCritic(preprocess_net=net_c2, last_size=action_dim).to(args.device) critic2_optim = AdamOptimizerFactory(lr=args.critic_lr) diff --git a/test/discrete/test_dqn.py 
b/test/discrete/test_dqn.py index 654ff07c8..947fa5e66 100644 --- a/test/discrete/test_dqn.py +++ b/test/discrete/test_dqn.py @@ -21,7 +21,7 @@ from tianshou.env import DummyVectorEnv from tianshou.trainer import OffPolicyTrainerParams from tianshou.utils import TensorboardLogger -from tianshou.utils.net.common import MLPActor +from tianshou.utils.net.common import Net from tianshou.utils.space_info import SpaceInfo from tianshou.utils.torch_utils import policy_within_training_step @@ -82,7 +82,7 @@ def test_dqn(args: argparse.Namespace = get_args(), enable_assertions: bool = Tr # Q_param = V_param = {"hidden_sizes": [128]} # model - net = MLPActor( + net = Net( state_shape=args.state_shape, action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, diff --git a/test/discrete/test_fqf.py b/test/discrete/test_fqf.py index b0b92d126..c2fd72531 100644 --- a/test/discrete/test_fqf.py +++ b/test/discrete/test_fqf.py @@ -21,7 +21,7 @@ from tianshou.env import DummyVectorEnv from tianshou.trainer import OffPolicyTrainerParams from tianshou.utils import TensorboardLogger -from tianshou.utils.net.common import MLPActor +from tianshou.utils.net.common import Net from tianshou.utils.net.discrete import FractionProposalNetwork, FullQuantileFunction from tianshou.utils.space_info import SpaceInfo from tianshou.utils.torch_utils import policy_within_training_step @@ -86,7 +86,7 @@ def test_fqf(args: argparse.Namespace = get_args(), enable_assertions: bool = Tr test_envs.seed(args.seed) # model - feature_net = MLPActor( + feature_net = Net( state_shape=args.state_shape, action_shape=args.hidden_sizes[-1], hidden_sizes=args.hidden_sizes[:-1], diff --git a/test/discrete/test_iqn.py b/test/discrete/test_iqn.py index c0ad64700..bcbe8dee8 100644 --- a/test/discrete/test_iqn.py +++ b/test/discrete/test_iqn.py @@ -21,7 +21,7 @@ from tianshou.env import DummyVectorEnv from tianshou.trainer import OffPolicyTrainerParams from tianshou.utils import TensorboardLogger -from 
tianshou.utils.net.common import MLPActor +from tianshou.utils.net.common import Net from tianshou.utils.net.discrete import ImplicitQuantileNetwork from tianshou.utils.space_info import SpaceInfo from tianshou.utils.torch_utils import policy_within_training_step @@ -86,7 +86,7 @@ def test_iqn(args: argparse.Namespace = get_args(), enable_assertions: bool = Tr test_envs.seed(args.seed) # model - feature_net = MLPActor( + feature_net = Net( state_shape=args.state_shape, action_shape=args.hidden_sizes[-1], hidden_sizes=args.hidden_sizes[:-1], diff --git a/test/discrete/test_ppo_discrete.py b/test/discrete/test_ppo_discrete.py index d8a6e61f9..23f60ed87 100644 --- a/test/discrete/test_ppo_discrete.py +++ b/test/discrete/test_ppo_discrete.py @@ -16,11 +16,11 @@ from tianshou.trainer import OnPolicyTrainerParams from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import ( + ActionReprNet, + ActionReprNetDataParallelWrapper, ActorCritic, - ActorForwardInterface, DataParallelNet, - MLPActor, - PolicyForwardDataParallelWrapper, + Net, ) from tianshou.utils.net.discrete import DiscreteActor, DiscreteCritic from tianshou.utils.space_info import SpaceInfo @@ -85,11 +85,11 @@ def test_ppo(args: argparse.Namespace = get_args(), enable_assertions: bool = Tr train_envs.seed(args.seed) test_envs.seed(args.seed) # model - net = MLPActor(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes) + net = Net(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes) critic: DiscreteCritic | DataParallelNet - actor: ActorForwardInterface + actor: ActionReprNet if torch.cuda.is_available(): - actor = PolicyForwardDataParallelWrapper( + actor = ActionReprNetDataParallelWrapper( DiscreteActor(preprocess_net=net, action_shape=args.action_shape).to(args.device) ) critic = DataParallelNet(DiscreteCritic(preprocess_net=net).to(args.device)) diff --git a/test/discrete/test_qrdqn.py b/test/discrete/test_qrdqn.py index de67f7797..c9ddcd985 100644 --- 
a/test/discrete/test_qrdqn.py +++ b/test/discrete/test_qrdqn.py @@ -20,7 +20,7 @@ from tianshou.env import DummyVectorEnv from tianshou.trainer import OffPolicyTrainerParams from tianshou.utils import TensorboardLogger -from tianshou.utils.net.common import MLPActor +from tianshou.utils.net.common import Net from tianshou.utils.space_info import SpaceInfo from tianshou.utils.torch_utils import policy_within_training_step @@ -86,7 +86,7 @@ def test_qrdqn(args: argparse.Namespace = get_args(), enable_assertions: bool = train_envs.seed(args.seed) test_envs.seed(args.seed) # model - net = MLPActor( + net = Net( state_shape=args.state_shape, action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, diff --git a/test/discrete/test_rainbow.py b/test/discrete/test_rainbow.py index d326a4f17..418780908 100644 --- a/test/discrete/test_rainbow.py +++ b/test/discrete/test_rainbow.py @@ -21,7 +21,7 @@ from tianshou.env import DummyVectorEnv from tianshou.trainer import OffPolicyTrainerParams from tianshou.utils import TensorboardLogger -from tianshou.utils.net.common import MLPActor +from tianshou.utils.net.common import Net from tianshou.utils.net.discrete import NoisyLinear from tianshou.utils.space_info import SpaceInfo from tianshou.utils.torch_utils import policy_within_training_step @@ -96,7 +96,7 @@ def noisy_linear(x: int, y: int) -> NoisyLinear: return NoisyLinear(x, y, args.noisy_std) # model - net = MLPActor( + net = Net( state_shape=args.state_shape, action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, diff --git a/test/discrete/test_reinforce.py b/test/discrete/test_reinforce.py index dd3376558..e234d30fa 100644 --- a/test/discrete/test_reinforce.py +++ b/test/discrete/test_reinforce.py @@ -10,13 +10,13 @@ from tianshou.algorithm import Reinforce from tianshou.algorithm.algorithm_base import Algorithm -from tianshou.algorithm.modelfree.reinforce import ActorPolicyProbabilistic +from tianshou.algorithm.modelfree.reinforce import 
ProbabilisticActorPolicy from tianshou.algorithm.optim import AdamOptimizerFactory from tianshou.data import Collector, CollectStats, VectorReplayBuffer from tianshou.env import DummyVectorEnv from tianshou.trainer import OnPolicyTrainerParams from tianshou.utils import TensorboardLogger -from tianshou.utils.net.common import MLPActor +from tianshou.utils.net.common import Net from tianshou.utils.space_info import SpaceInfo @@ -68,7 +68,7 @@ def test_reinforce(args: argparse.Namespace = get_args(), enable_assertions: boo test_envs.seed(args.seed) # model - net = MLPActor( + net = Net( state_shape=args.state_shape, action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, @@ -76,7 +76,7 @@ def test_reinforce(args: argparse.Namespace = get_args(), enable_assertions: boo ).to(args.device) optim = AdamOptimizerFactory(lr=args.lr) dist_fn = torch.distributions.Categorical - policy = ActorPolicyProbabilistic( + policy = ProbabilisticActorPolicy( actor=net, dist_fn=dist_fn, action_space=env.action_space, diff --git a/test/modelbased/test_dqn_icm.py b/test/modelbased/test_dqn_icm.py index d44c003cc..38640edfe 100644 --- a/test/modelbased/test_dqn_icm.py +++ b/test/modelbased/test_dqn_icm.py @@ -18,7 +18,7 @@ from tianshou.env import DummyVectorEnv from tianshou.trainer import OffPolicyTrainerParams from tianshou.utils import TensorboardLogger -from tianshou.utils.net.common import MLP, MLPActor +from tianshou.utils.net.common import MLP, Net from tianshou.utils.net.discrete import IntrinsicCuriosityModule from tianshou.utils.space_info import SpaceInfo @@ -99,7 +99,7 @@ def test_dqn_icm(args: argparse.Namespace = get_args()) -> None: # Q_param = V_param = {"hidden_sizes": [128]} # model - net = MLPActor( + net = Net( state_shape=args.state_shape, action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, diff --git a/test/modelbased/test_ppo_icm.py b/test/modelbased/test_ppo_icm.py index c0ebb723a..556bde6b6 100644 --- a/test/modelbased/test_ppo_icm.py +++ 
b/test/modelbased/test_ppo_icm.py @@ -10,13 +10,13 @@ from tianshou.algorithm import PPO from tianshou.algorithm.algorithm_base import Algorithm from tianshou.algorithm.modelbased.icm import ICMOnPolicyWrapper -from tianshou.algorithm.modelfree.reinforce import ActorPolicyProbabilistic +from tianshou.algorithm.modelfree.reinforce import ProbabilisticActorPolicy from tianshou.algorithm.optim import AdamOptimizerFactory from tianshou.data import Collector, CollectStats, VectorReplayBuffer from tianshou.env import DummyVectorEnv from tianshou.trainer import OnPolicyTrainerParams from tianshou.utils import TensorboardLogger -from tianshou.utils.net.common import MLP, ActorCritic, MLPActor +from tianshou.utils.net.common import MLP, ActorCritic, Net from tianshou.utils.net.discrete import ( DiscreteActor, DiscreteCritic, @@ -104,7 +104,7 @@ def test_ppo(args: argparse.Namespace = get_args()) -> None: test_envs.seed(args.seed) # model - net = MLPActor(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes) + net = Net(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes) actor = DiscreteActor(preprocess_net=net, action_shape=args.action_shape).to(args.device) critic = DiscreteCritic(preprocess_net=net).to(args.device) actor_critic = ActorCritic(actor, critic) @@ -118,7 +118,7 @@ def test_ppo(args: argparse.Namespace = get_args()) -> None: # base algorithm: PPO optim = AdamOptimizerFactory(lr=args.lr) dist = torch.distributions.Categorical - policy = ActorPolicyProbabilistic( + policy = ProbabilisticActorPolicy( actor=actor, dist_fn=dist, action_scaling=isinstance(env.action_space, Box), diff --git a/test/offline/gather_cartpole_data.py b/test/offline/gather_cartpole_data.py index 583e2cf68..cda914cee 100644 --- a/test/offline/gather_cartpole_data.py +++ b/test/offline/gather_cartpole_data.py @@ -20,7 +20,7 @@ from tianshou.env import DummyVectorEnv from tianshou.trainer import OffPolicyTrainerParams from tianshou.utils import TensorboardLogger -from 
tianshou.utils.net.common import MLPActor +from tianshou.utils.net.common import Net from tianshou.utils.space_info import SpaceInfo @@ -89,7 +89,7 @@ def gather_data() -> VectorReplayBuffer | PrioritizedVectorReplayBuffer: train_envs.seed(args.seed) test_envs.seed(args.seed) # model - net = MLPActor( + net = Net( state_shape=args.state_shape, action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, diff --git a/test/offline/gather_pendulum_data.py b/test/offline/gather_pendulum_data.py index 2e4513e3e..77b4ff4ee 100644 --- a/test/offline/gather_pendulum_data.py +++ b/test/offline/gather_pendulum_data.py @@ -15,7 +15,7 @@ from tianshou.env import DummyVectorEnv from tianshou.trainer import OffPolicyTrainerParams from tianshou.utils import TensorboardLogger -from tianshou.utils.net.common import MLPActor +from tianshou.utils.net.common import Net from tianshou.utils.net.continuous import ContinuousActorProbabilistic, ContinuousCritic from tianshou.utils.space_info import SpaceInfo @@ -91,14 +91,14 @@ def gather_data() -> VectorReplayBuffer: train_envs.seed(args.seed) test_envs.seed(args.seed) # model - net = MLPActor(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes) + net = Net(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes) actor = ContinuousActorProbabilistic( preprocess_net=net, action_shape=args.action_shape, unbounded=True, ).to(args.device) actor_optim = AdamOptimizerFactory(lr=args.actor_lr) - net_c = MLPActor( + net_c = Net( state_shape=args.state_shape, action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, diff --git a/test/offline/test_bcq.py b/test/offline/test_bcq.py index eeafcf922..1d5ce433c 100644 --- a/test/offline/test_bcq.py +++ b/test/offline/test_bcq.py @@ -17,7 +17,7 @@ from tianshou.env import DummyVectorEnv from tianshou.trainer import OfflineTrainerParams from tianshou.utils import TensorboardLogger -from tianshou.utils.net.common import MLP, MLPActor +from tianshou.utils.net.common import MLP, Net 
from tianshou.utils.net.continuous import VAE, ContinuousCritic, Perturbation from tianshou.utils.space_info import SpaceInfo @@ -111,7 +111,7 @@ def test_bcq(args: argparse.Namespace = get_args(), enable_assertions: bool = Tr ) actor_optim = AdamOptimizerFactory(lr=args.actor_lr) - net_c = MLPActor( + net_c = Net( state_shape=args.state_shape, action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, diff --git a/test/offline/test_cql.py b/test/offline/test_cql.py index dcc7d53ad..803d99c4f 100644 --- a/test/offline/test_cql.py +++ b/test/offline/test_cql.py @@ -17,7 +17,7 @@ from tianshou.env import DummyVectorEnv from tianshou.trainer import OfflineTrainerParams from tianshou.utils import TensorboardLogger -from tianshou.utils.net.common import MLPActor +from tianshou.utils.net.common import Net from tianshou.utils.net.continuous import ContinuousActorProbabilistic, ContinuousCritic from tianshou.utils.space_info import SpaceInfo @@ -105,7 +105,7 @@ def test_cql(args: argparse.Namespace = get_args(), enable_assertions: bool = Tr # model # actor network - net_a = MLPActor( + net_a = Net( state_shape=args.state_shape, action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, @@ -119,7 +119,7 @@ def test_cql(args: argparse.Namespace = get_args(), enable_assertions: bool = Tr actor_optim = AdamOptimizerFactory(lr=args.actor_lr) # critic network - net_c = MLPActor( + net_c = Net( state_shape=args.state_shape, action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, diff --git a/test/offline/test_discrete_bcq.py b/test/offline/test_discrete_bcq.py index c04fbacab..1874a36fa 100644 --- a/test/offline/test_discrete_bcq.py +++ b/test/offline/test_discrete_bcq.py @@ -21,7 +21,7 @@ from tianshou.env import DummyVectorEnv from tianshou.trainer import OfflineTrainerParams from tianshou.utils import TensorboardLogger -from tianshou.utils.net.common import MLPActor +from tianshou.utils.net.common import Net from tianshou.utils.net.discrete import 
DiscreteActor from tianshou.utils.space_info import SpaceInfo @@ -80,7 +80,7 @@ def test_discrete_bcq( test_envs.seed(args.seed) # model - net = MLPActor(state_shape=args.state_shape, action_shape=args.hidden_sizes[0]) + net = Net(state_shape=args.state_shape, action_shape=args.hidden_sizes[0]) policy_net = DiscreteActor( preprocess_net=net, action_shape=args.action_shape, diff --git a/test/offline/test_discrete_cql.py b/test/offline/test_discrete_cql.py index aa867a140..7fa959d11 100644 --- a/test/offline/test_discrete_cql.py +++ b/test/offline/test_discrete_cql.py @@ -21,7 +21,7 @@ from tianshou.env import DummyVectorEnv from tianshou.trainer import OfflineTrainerParams from tianshou.utils import TensorboardLogger -from tianshou.utils.net.common import MLPActor +from tianshou.utils.net.common import Net from tianshou.utils.space_info import SpaceInfo @@ -77,7 +77,7 @@ def test_discrete_cql( test_envs.seed(args.seed) # model - net = MLPActor( + net = Net( state_shape=args.state_shape, action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, diff --git a/test/offline/test_discrete_crr.py b/test/offline/test_discrete_crr.py index 71a2e23d1..aea9afce7 100644 --- a/test/offline/test_discrete_crr.py +++ b/test/offline/test_discrete_crr.py @@ -21,7 +21,7 @@ from tianshou.env import DummyVectorEnv from tianshou.trainer import OfflineTrainerParams from tianshou.utils import TensorboardLogger -from tianshou.utils.net.common import MLPActor +from tianshou.utils.net.common import Net from tianshou.utils.net.discrete import DiscreteActor, DiscreteCritic from tianshou.utils.space_info import SpaceInfo @@ -75,7 +75,7 @@ def test_discrete_crr( test_envs.seed(args.seed) # model and algorithm - net = MLPActor(state_shape=args.state_shape, action_shape=args.hidden_sizes[0]) + net = Net(state_shape=args.state_shape, action_shape=args.hidden_sizes[0]) actor = DiscreteActor( preprocess_net=net, action_shape=args.action_shape, diff --git a/test/offline/test_gail.py 
b/test/offline/test_gail.py index f4df597dc..56e9d9d5d 100644 --- a/test/offline/test_gail.py +++ b/test/offline/test_gail.py @@ -11,13 +11,13 @@ from torch.utils.tensorboard import SummaryWriter from tianshou.algorithm import GAIL, Algorithm -from tianshou.algorithm.modelfree.reinforce import ActorPolicyProbabilistic +from tianshou.algorithm.modelfree.reinforce import ProbabilisticActorPolicy from tianshou.algorithm.optim import AdamOptimizerFactory from tianshou.data import Collector, CollectStats, VectorReplayBuffer from tianshou.env import DummyVectorEnv from tianshou.trainer import OnPolicyTrainerParams from tianshou.utils import TensorboardLogger -from tianshou.utils.net.common import ActorCritic, MLPActor +from tianshou.utils.net.common import ActorCritic, Net from tianshou.utils.net.continuous import ContinuousActorProbabilistic, ContinuousCritic from tianshou.utils.space_info import SpaceInfo @@ -92,7 +92,7 @@ def test_gail(args: argparse.Namespace = get_args(), enable_assertions: bool = T train_envs.seed(args.seed) test_envs.seed(args.seed) # model - net = MLPActor(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes) + net = Net(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes) actor = ContinuousActorProbabilistic( preprocess_net=net, action_shape=args.action_shape, @@ -101,7 +101,7 @@ def test_gail(args: argparse.Namespace = get_args(), enable_assertions: bool = T args.device, ) critic = ContinuousCritic( - preprocess_net=MLPActor(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes), + preprocess_net=Net(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes), ).to(args.device) actor_critic = ActorCritic(actor, critic) # orthogonal initialization @@ -112,7 +112,7 @@ def test_gail(args: argparse.Namespace = get_args(), enable_assertions: bool = T optim = AdamOptimizerFactory(lr=args.lr) # discriminator disc_net = ContinuousCritic( - preprocess_net=MLPActor( + preprocess_net=Net( state_shape=args.state_shape, 
action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, @@ -133,7 +133,7 @@ def dist(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution: loc, scale = loc_scale return Independent(Normal(loc, scale), 1) - policy = ActorPolicyProbabilistic( + policy = ProbabilisticActorPolicy( actor=actor, dist_fn=dist, action_space=env.action_space, diff --git a/test/offline/test_td3_bc.py b/test/offline/test_td3_bc.py index 99f7e15bc..9c1ed615b 100644 --- a/test/offline/test_td3_bc.py +++ b/test/offline/test_td3_bc.py @@ -19,7 +19,7 @@ from tianshou.exploration import GaussianNoise from tianshou.trainer import OfflineTrainerParams from tianshou.utils import TensorboardLogger -from tianshou.utils.net.common import MLPActor +from tianshou.utils.net.common import Net from tianshou.utils.net.continuous import ContinuousActorDeterministic, ContinuousCritic from tianshou.utils.space_info import SpaceInfo @@ -96,7 +96,7 @@ def test_td3_bc(args: argparse.Namespace = get_args(), enable_assertions: bool = test_envs.seed(args.seed) # actor network - net_a = MLPActor( + net_a = Net( state_shape=args.state_shape, hidden_sizes=args.hidden_sizes, ) @@ -108,13 +108,13 @@ def test_td3_bc(args: argparse.Namespace = get_args(), enable_assertions: bool = actor_optim = AdamOptimizerFactory(lr=args.actor_lr) # critic networks - net_c1 = MLPActor( + net_c1 = Net( state_shape=args.state_shape, action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, concat=True, ) - net_c2 = MLPActor( + net_c2 = Net( state_shape=args.state_shape, action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, diff --git a/test/pettingzoo/pistonball.py b/test/pettingzoo/pistonball.py index 7f7857824..e1dc3f5d4 100644 --- a/test/pettingzoo/pistonball.py +++ b/test/pettingzoo/pistonball.py @@ -17,7 +17,7 @@ from tianshou.env.pettingzoo_env import PettingZooEnv from tianshou.trainer import OffPolicyTrainerParams from tianshou.utils import TensorboardLogger -from tianshou.utils.net.common import 
MLPActor +from tianshou.utils.net.common import Net def get_parser() -> argparse.ArgumentParser: @@ -96,7 +96,7 @@ def get_agents( optims = [] for _ in range(args.n_pistons): # model - net = MLPActor( + net = Net( state_shape=args.state_shape, action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, diff --git a/test/pettingzoo/pistonball_continuous.py b/test/pettingzoo/pistonball_continuous.py index 30efcdd01..5b5fa1555 100644 --- a/test/pettingzoo/pistonball_continuous.py +++ b/test/pettingzoo/pistonball_continuous.py @@ -13,7 +13,7 @@ from tianshou.algorithm import PPO, Algorithm from tianshou.algorithm.algorithm_base import OnPolicyAlgorithm -from tianshou.algorithm.modelfree.reinforce import ActorPolicyProbabilistic +from tianshou.algorithm.modelfree.reinforce import ProbabilisticActorPolicy from tianshou.algorithm.multiagent.marl import MultiAgentOnPolicyAlgorithm from tianshou.algorithm.optim import AdamOptimizerFactory from tianshou.data import Collector, CollectStats, VectorReplayBuffer @@ -193,7 +193,7 @@ def dist(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution: loc, scale = loc_scale return Independent(Normal(loc, scale), 1) - policy = ActorPolicyProbabilistic( + policy = ProbabilisticActorPolicy( actor=actor, dist_fn=dist, action_space=env.action_space, diff --git a/test/pettingzoo/tic_tac_toe.py b/test/pettingzoo/tic_tac_toe.py index 099be67a5..cebabd4c1 100644 --- a/test/pettingzoo/tic_tac_toe.py +++ b/test/pettingzoo/tic_tac_toe.py @@ -24,7 +24,7 @@ from tianshou.env.pettingzoo_env import PettingZooEnv from tianshou.trainer import OffPolicyTrainerParams from tianshou.utils import TensorboardLogger -from tianshou.utils.net.common import MLPActor +from tianshou.utils.net.common import Net def get_env(render_mode: str | None = None) -> PettingZooEnv: @@ -115,7 +115,7 @@ def get_agents( args.action_shape = env.action_space.shape or int(env.action_space.n) if agent_learn is None: # model - net = MLPActor( + net = Net( 
state_shape=args.state_shape, action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, diff --git a/tianshou/algorithm/imitation/gail.py b/tianshou/algorithm/imitation/gail.py index 3d93ae8e9..648358d2c 100644 --- a/tianshou/algorithm/imitation/gail.py +++ b/tianshou/algorithm/imitation/gail.py @@ -6,7 +6,7 @@ from tianshou.algorithm.modelfree.a2c import A2CTrainingStats from tianshou.algorithm.modelfree.ppo import PPO -from tianshou.algorithm.modelfree.reinforce import ActorPolicyProbabilistic +from tianshou.algorithm.modelfree.reinforce import ProbabilisticActorPolicy from tianshou.algorithm.optim import OptimizerFactory from tianshou.data import ( ReplayBuffer, @@ -34,7 +34,7 @@ class GAIL(PPO): def __init__( self, *, - policy: ActorPolicyProbabilistic, + policy: ProbabilisticActorPolicy, critic: torch.nn.Module | ContinuousCritic | DiscreteCritic, optim: OptimizerFactory, expert_buffer: ReplayBuffer, diff --git a/tianshou/algorithm/modelfree/a2c.py b/tianshou/algorithm/modelfree/a2c.py index 91d7cbe9a..f26414f7e 100644 --- a/tianshou/algorithm/modelfree/a2c.py +++ b/tianshou/algorithm/modelfree/a2c.py @@ -10,7 +10,7 @@ OnPolicyAlgorithm, TrainingStats, ) -from tianshou.algorithm.modelfree.reinforce import ActorPolicyProbabilistic +from tianshou.algorithm.modelfree.reinforce import ProbabilisticActorPolicy from tianshou.algorithm.optim import OptimizerFactory from tianshou.data import ReplayBuffer, SequenceSummaryStats, to_torch_as from tianshou.data.types import BatchWithAdvantagesProtocol, RolloutBatchProtocol @@ -29,13 +29,13 @@ class A2CTrainingStats(TrainingStats): gradient_steps: int -class ActorCriticOnPolicyAlgorithm(OnPolicyAlgorithm[ActorPolicyProbabilistic], ABC): +class ActorCriticOnPolicyAlgorithm(OnPolicyAlgorithm[ProbabilisticActorPolicy], ABC): """Abstract base class for actor-critic algorithms that use generalized advantage estimation (GAE).""" def __init__( self, *, - policy: ActorPolicyProbabilistic, + policy: ProbabilisticActorPolicy, 
critic: torch.nn.Module | ContinuousCritic | DiscreteCritic, optim: OptimizerFactory, optim_include_actor: bool, @@ -157,7 +157,7 @@ class A2C(ActorCriticOnPolicyAlgorithm): def __init__( self, *, - policy: ActorPolicyProbabilistic, + policy: ProbabilisticActorPolicy, critic: torch.nn.Module | ContinuousCritic | DiscreteCritic, optim: OptimizerFactory, vf_coef: float = 0.5, diff --git a/tianshou/algorithm/modelfree/c51.py b/tianshou/algorithm/modelfree/c51.py index ed865dbf6..8ca11f37d 100644 --- a/tianshou/algorithm/modelfree/c51.py +++ b/tianshou/algorithm/modelfree/c51.py @@ -10,13 +10,13 @@ from tianshou.algorithm.optim import OptimizerFactory from tianshou.data import Batch, ReplayBuffer from tianshou.data.types import RolloutBatchProtocol -from tianshou.utils.net.common import MLPActor +from tianshou.utils.net.common import Net class C51Policy(DiscreteQLearningPolicy): def __init__( self, - model: torch.nn.Module | MLPActor, + model: torch.nn.Module | Net, action_space: gym.spaces.Space, observation_space: gym.Space | None = None, num_atoms: int = 51, diff --git a/tianshou/algorithm/modelfree/ddpg.py b/tianshou/algorithm/modelfree/ddpg.py index 0e44bc4ad..2520fd8d1 100644 --- a/tianshou/algorithm/modelfree/ddpg.py +++ b/tianshou/algorithm/modelfree/ddpg.py @@ -29,7 +29,7 @@ ) from tianshou.exploration import BaseNoise, GaussianNoise from tianshou.utils.net.continuous import ( - ContinuousActorDeterministicInterface, + AbstractContinuousActorDeterministic, ContinuousCritic, ) @@ -117,7 +117,7 @@ class ContinuousDeterministicPolicy(ContinuousPolicyWithExplorationNoise): def __init__( self, *, - actor: ContinuousActorDeterministicInterface, + actor: AbstractContinuousActorDeterministic, exploration_noise: BaseNoise | Literal["default"] | None = None, action_space: gym.Space, observation_space: gym.Space | None = None, diff --git a/tianshou/algorithm/modelfree/dqn.py b/tianshou/algorithm/modelfree/dqn.py index 8118325c9..a7b6be867 100644 --- 
a/tianshou/algorithm/modelfree/dqn.py +++ b/tianshou/algorithm/modelfree/dqn.py @@ -28,11 +28,11 @@ RolloutBatchProtocol, ) from tianshou.utils.lagged_network import EvalModeModuleWrapper -from tianshou.utils.net.common import MLPActor +from tianshou.utils.net.common import Net mark_used(ActBatchProtocol) -TModel = TypeVar("TModel", bound=torch.nn.Module | MLPActor) +TModel = TypeVar("TModel", bound=torch.nn.Module | Net) log = logging.getLogger(__name__) diff --git a/tianshou/algorithm/modelfree/npg.py b/tianshou/algorithm/modelfree/npg.py index 71b170520..637850031 100644 --- a/tianshou/algorithm/modelfree/npg.py +++ b/tianshou/algorithm/modelfree/npg.py @@ -9,7 +9,7 @@ from tianshou.algorithm.algorithm_base import TrainingStats from tianshou.algorithm.modelfree.a2c import ActorCriticOnPolicyAlgorithm -from tianshou.algorithm.modelfree.reinforce import ActorPolicyProbabilistic +from tianshou.algorithm.modelfree.reinforce import ProbabilisticActorPolicy from tianshou.algorithm.optim import OptimizerFactory from tianshou.data import ReplayBuffer, SequenceSummaryStats, to_torch_as from tianshou.data.types import BatchWithAdvantagesProtocol, RolloutBatchProtocol @@ -33,7 +33,7 @@ class NPG(ActorCriticOnPolicyAlgorithm): def __init__( self, *, - policy: ActorPolicyProbabilistic, + policy: ProbabilisticActorPolicy, critic: torch.nn.Module | ContinuousCritic | DiscreteCritic, optim: OptimizerFactory, optim_critic_iters: int = 5, diff --git a/tianshou/algorithm/modelfree/ppo.py b/tianshou/algorithm/modelfree/ppo.py index b6a3e8aa6..ede1d3418 100644 --- a/tianshou/algorithm/modelfree/ppo.py +++ b/tianshou/algorithm/modelfree/ppo.py @@ -5,7 +5,7 @@ from tianshou.algorithm import A2C from tianshou.algorithm.modelfree.a2c import A2CTrainingStats -from tianshou.algorithm.modelfree.reinforce import ActorPolicyProbabilistic +from tianshou.algorithm.modelfree.reinforce import ProbabilisticActorPolicy from tianshou.algorithm.optim import OptimizerFactory from tianshou.data import 
ReplayBuffer, SequenceSummaryStats, to_torch_as from tianshou.data.types import LogpOldProtocol, RolloutBatchProtocol @@ -19,7 +19,7 @@ class PPO(A2C): def __init__( self, *, - policy: ActorPolicyProbabilistic, + policy: ProbabilisticActorPolicy, critic: torch.nn.Module | ContinuousCritic | DiscreteCritic, optim: OptimizerFactory, eps_clip: float = 0.2, diff --git a/tianshou/algorithm/modelfree/reinforce.py b/tianshou/algorithm/modelfree/reinforce.py index 7fcc2f19a..60fc91cae 100644 --- a/tianshou/algorithm/modelfree/reinforce.py +++ b/tianshou/algorithm/modelfree/reinforce.py @@ -31,9 +31,9 @@ ) from tianshou.utils import RunningMeanStd from tianshou.utils.net.common import ( - ActorForwardInterface, - ContinuousActorProbabilisticInterface, - DiscreteActorInterface, + AbstractContinuousActorProbabilistic, + AbstractDiscreteActor, + ActionReprNet, ) from tianshou.utils.net.discrete import dist_fn_categorical_from_logits @@ -65,7 +65,7 @@ class SimpleLossTrainingStats(TrainingStats): loss: float -class ActorPolicyProbabilistic(Policy): +class ProbabilisticActorPolicy(Policy): """ A policy that outputs (representations of) probability distributions from which actions can be sampled. 
@@ -74,9 +74,7 @@ class ActorPolicyProbabilistic(Policy): def __init__( self, *, - actor: ContinuousActorProbabilisticInterface - | DiscreteActorInterface - | ActorForwardInterface, + actor: AbstractContinuousActorProbabilistic | AbstractDiscreteActor | ActionReprNet, dist_fn: TDistFnDiscrOrCont, deterministic_eval: bool = False, action_space: gym.Space, @@ -194,11 +192,11 @@ def forward( return cast(DistBatchProtocol, result) -class DiscreteActorPolicy(ActorPolicyProbabilistic): +class DiscreteActorPolicy(ProbabilisticActorPolicy): def __init__( self, *, - actor: DiscreteActorInterface | ActorForwardInterface, + actor: AbstractDiscreteActor | ActionReprNet, dist_fn: TDistFnDiscrete = dist_fn_categorical_from_logits, deterministic_eval: bool = False, action_space: gym.Space, @@ -245,7 +243,7 @@ def __init__( ) -TActorPolicy = TypeVar("TActorPolicy", bound=ActorPolicyProbabilistic) +TActorPolicy = TypeVar("TActorPolicy", bound=ProbabilisticActorPolicy) class DiscountedReturnComputation: @@ -314,13 +312,13 @@ def add_discounted_returns( return cast(BatchWithReturnsProtocol, batch) -class Reinforce(OnPolicyAlgorithm[ActorPolicyProbabilistic]): +class Reinforce(OnPolicyAlgorithm[ProbabilisticActorPolicy]): """Implementation of the REINFORCE (a.k.a. 
vanilla policy gradient) algorithm.""" def __init__( self, *, - policy: ActorPolicyProbabilistic, + policy: ProbabilisticActorPolicy, gamma: float = 0.99, return_standardization: bool = False, optim: OptimizerFactory, diff --git a/tianshou/algorithm/modelfree/trpo.py b/tianshou/algorithm/modelfree/trpo.py index a52c1ac44..450fdde54 100644 --- a/tianshou/algorithm/modelfree/trpo.py +++ b/tianshou/algorithm/modelfree/trpo.py @@ -7,7 +7,7 @@ from tianshou.algorithm import NPG from tianshou.algorithm.modelfree.npg import NPGTrainingStats -from tianshou.algorithm.modelfree.reinforce import ActorPolicyProbabilistic +from tianshou.algorithm.modelfree.reinforce import ProbabilisticActorPolicy from tianshou.algorithm.optim import OptimizerFactory from tianshou.data import SequenceSummaryStats from tianshou.data.types import BatchWithAdvantagesProtocol @@ -26,7 +26,7 @@ class TRPO(NPG): def __init__( self, *, - policy: ActorPolicyProbabilistic, + policy: ProbabilisticActorPolicy, critic: torch.nn.Module | ContinuousCritic | DiscreteCritic, optim: OptimizerFactory, max_kl: float = 0.01, diff --git a/tianshou/env/atari/atari_network.py b/tianshou/env/atari/atari_network.py index ab4693a74..79b862c4c 100644 --- a/tianshou/env/atari/atari_network.py +++ b/tianshou/env/atari/atari_network.py @@ -18,7 +18,9 @@ IntermediateModuleFactory, ) from tianshou.highlevel.params.dist_fn import DistributionFunctionFactoryCategorical -from tianshou.utils.net.common import Actor, ModuleWithVectorOutput +from tianshou.utils.net.common import ( + ActionReprNetWithVectorOutput, +) from tianshou.utils.net.discrete import DiscreteActor, NoisyLinear from tianshou.utils.torch_utils import torch_device @@ -33,15 +35,12 @@ def layer_init(layer: nn.Module, std: float = np.sqrt(2), bias_const: float = 0. 
T = TypeVar("T") -class ScaledObsInputModule(Actor): - def __init__(self, module: Actor, denom: float = 255.0) -> None: +class ScaledObsInputActionReprNet(ActionReprNetWithVectorOutput): + def __init__(self, module: ActionReprNetWithVectorOutput, denom: float = 255.0) -> None: super().__init__(module.get_output_dim()) self.module = module self.denom = denom - def get_preprocess_net(self) -> ModuleWithVectorOutput: - return self.module.get_preprocess_net() - def forward( self, obs: TObs, @@ -58,7 +57,7 @@ def forward( return self.module.forward(scaled_obs, state, info) -class DQNet(Actor[Any]): +class DQNet(ActionReprNetWithVectorOutput[Any]): """Reference: Human-level control through deep reinforcement learning. For advanced usage (how to customize the network), please refer to @@ -112,9 +111,6 @@ def __init__( super().__init__(output_dim) self.net = net - def get_preprocess_net(self) -> ModuleWithVectorOutput: - return ModuleWithVectorOutput.from_module(nn.Identity(), self.output_dim) - def forward( self, obs: TObs, @@ -163,7 +159,7 @@ def forward( return obs, state -class Rainbow(DQNet): +class RainbowNet(DQNet): """Reference: Rainbow: Combining Improvements in Deep Reinforcement Learning. 
For advanced usage (how to customize the network), please refer to @@ -273,7 +269,7 @@ def create_module(self, envs: Environments, device: TDevice) -> DiscreteActor: action_shape = envs.get_action_shape() if isinstance(action_shape, np.int64): action_shape = int(action_shape) - net: DQNet | ScaledObsInputModule + net: DQNet | ScaledObsInputActionReprNet net = DQNet( c=c, h=h, @@ -284,7 +280,7 @@ def create_module(self, envs: Environments, device: TDevice) -> DiscreteActor: layer_init=layer_init, ) if self.scale_obs: - net = ScaledObsInputModule(net) + net = ScaledObsInputActionReprNet(net) return DiscreteActor( preprocess_net=net, action_shape=envs.get_action_shape(), diff --git a/tianshou/highlevel/algorithm.py b/tianshou/highlevel/algorithm.py index 9a52f1366..3e22e5320 100644 --- a/tianshou/highlevel/algorithm.py +++ b/tianshou/highlevel/algorithm.py @@ -32,7 +32,7 @@ from tianshou.algorithm.modelfree.dqn import DiscreteQLearningPolicy from tianshou.algorithm.modelfree.iqn import IQNPolicy from tianshou.algorithm.modelfree.redq import REDQPolicy -from tianshou.algorithm.modelfree.reinforce import ActorPolicyProbabilistic +from tianshou.algorithm.modelfree.reinforce import ProbabilisticActorPolicy from tianshou.algorithm.modelfree.sac import SACPolicy from tianshou.data import Collector, ReplayBuffer, VectorReplayBuffer from tianshou.data.collector import BaseCollector, CollectStats @@ -310,7 +310,7 @@ def _create_algorithm(self, envs: Environments, device: TDevice) -> Reinforce: dist_fn = self.actor_factory.create_dist_fn(envs) assert dist_fn is not None policy = self._create_policy_from_args( - ActorPolicyProbabilistic, + ProbabilisticActorPolicy, kwargs, ["action_scaling", "action_bound_method", "deterministic_eval"], actor=actor, @@ -368,7 +368,7 @@ def _create_kwargs(self, envs: Environments, device: TDevice) -> dict[str, Any]: def _create_algorithm(self, envs: Environments, device: TDevice) -> TAlgorithm: params = self._create_kwargs(envs, device) policy = 
self._create_policy_from_args( - ActorPolicyProbabilistic, + ProbabilisticActorPolicy, params, [ "actor", diff --git a/tianshou/highlevel/module/actor.py b/tianshou/highlevel/module/actor.py index aa3f7b233..83ae08878 100644 --- a/tianshou/highlevel/module/actor.py +++ b/tianshou/highlevel/module/actor.py @@ -26,9 +26,9 @@ from tianshou.utils.net import continuous, discrete from tianshou.utils.net.common import ( Actor, - MLPActor, ModuleType, ModuleWithVectorOutput, + Net, ) @@ -151,7 +151,7 @@ def __init__(self, hidden_sizes: Sequence[int], activation: ModuleType = nn.ReLU self.activation = activation def create_module(self, envs: Environments, device: TDevice) -> Actor: - net_a = MLPActor( + net_a = Net( state_shape=envs.get_observation_shape(), hidden_sizes=self.hidden_sizes, activation=self.activation, @@ -187,7 +187,7 @@ def __init__( self.activation = activation def create_module(self, envs: Environments, device: TDevice) -> Actor: - net_a = MLPActor( + net_a = Net( state_shape=envs.get_observation_shape(), hidden_sizes=self.hidden_sizes, activation=self.activation, @@ -222,7 +222,7 @@ def __init__( self.activation = activation def create_module(self, envs: Environments, device: TDevice) -> Actor: - net_a = MLPActor( + net_a = Net( state_shape=envs.get_observation_shape(), hidden_sizes=self.hidden_sizes, activation=self.activation, diff --git a/tianshou/highlevel/module/critic.py b/tianshou/highlevel/module/critic.py index f76ab6be5..54596be12 100644 --- a/tianshou/highlevel/module/critic.py +++ b/tianshou/highlevel/module/critic.py @@ -9,7 +9,7 @@ from tianshou.highlevel.module.actor import ActorFuture from tianshou.highlevel.module.core import TDevice, init_linear_orthogonal from tianshou.utils.net import continuous -from tianshou.utils.net.common import Actor, EnsembleLinear, MLPActor, ModuleType +from tianshou.utils.net.common import Actor, EnsembleLinear, ModuleType, Net from tianshou.utils.net.continuous import ContinuousCritic from 
tianshou.utils.net.discrete import DiscreteCritic @@ -91,7 +91,7 @@ def create_module( discrete_last_size_use_action_shape: bool = False, ) -> nn.Module: action_shape = envs.get_action_shape() if use_action else 0 - net_c = MLPActor( + net_c = Net( state_shape=envs.get_observation_shape(), action_shape=action_shape, hidden_sizes=self.hidden_sizes, @@ -116,7 +116,7 @@ def create_module( discrete_last_size_use_action_shape: bool = False, ) -> nn.Module: action_shape = envs.get_action_shape() if use_action else 0 - net_c = MLPActor( + net_c = Net( state_shape=envs.get_observation_shape(), action_shape=action_shape, hidden_sizes=self.hidden_sizes, @@ -239,7 +239,7 @@ def linear_layer(x: int, y: int) -> EnsembleLinear: return EnsembleLinear(ensemble_size, x, y) action_shape = envs.get_action_shape() if use_action else 0 - net_c = MLPActor( + net_c = Net( state_shape=envs.get_observation_shape(), action_shape=action_shape, hidden_sizes=self.hidden_sizes, diff --git a/tianshou/utils/net/common.py b/tianshou/utils/net/common.py index 49d093a12..b6da01b5f 100644 --- a/tianshou/utils/net/common.py +++ b/tianshou/utils/net/common.py @@ -181,11 +181,17 @@ def forward(self, obs: np.ndarray | torch.Tensor) -> torch.Tensor: TRecurrentState = TypeVar("TRecurrentState", bound=Any) -class ActorForwardInterface(Generic[TRecurrentState], nn.Module, ABC): - """Defines the `forward` interface for neural networks used as actors in policies. - - Note that for DQN-like algorithms the critic is used as an actor (since actions - are computed from it), see e.g. :class:`~DiscreteActor`. +class ActionReprNet(Generic[TRecurrentState], nn.Module, ABC): + """Abstract base class for neural networks used to compute action-related + representations from environment observations, which defines the + signature of the forward method. 
+ + An action-related representation can be a number of things, including: + * a distribution over actions in a discrete action space in the form of a vector of + unnormalized log probabilities (called "logits" in PyTorch jargon) + * the Q-values of all actions in a discrete action space + * the parameters of a distribution (e.g., mean and std. dev. for a Gaussian distribution) + over actions in a continuous action space """ @abstractmethod @@ -201,7 +207,8 @@ def forward( Implementations will always make use of the preprocess_net as the first processing step. :param obs: the observations from the environment as retrieved from `ObsBatchProtocol.obs`. - If the environment is a dict env, this will be an instance of Batch, otherwise it will be an array (or tensor if your env returns tensors). + If the environment is a dict env, this will be an instance of Batch, otherwise it will be an array (or tensor if your + env returns tensors). :param state: the hidden state of the RNN, if applicable :param info: the info object from the environment step :return: a tuple (action_repr, hidden_state), where action_repr is either an actual action for the environment or @@ -210,22 +217,34 @@ def forward( """ -class Actor(Generic[T], ModuleWithVectorOutput, ActorForwardInterface[T], ABC): +class ActionReprNetWithVectorOutput(Generic[T], ActionReprNet[T], ModuleWithVectorOutput): + """A neural network for the computation of action-related representations which outputs + a vector of a known size. + """ + + def __init__(self, output_dim: int) -> None: + super().__init__(output_dim) + + +class Actor(Generic[T], ActionReprNetWithVectorOutput[T], ABC): @abstractmethod def get_preprocess_net(self) -> ModuleWithVectorOutput: - """Typically a first part of the network that preprocesses the input into a latent representation. + """Returns the network component that is used for pre-processing, i.e. 
+ the component which produces a latent representation, which then is transformed + into the final output. + This is, therefore, the first part of the network which processes the input. + For example, a CNN is often used in Atari examples. - E.g., a CNN (often used in atari examples). We need this method to be able to - share latent representation with other networks (e.g., critic) within an Algorithm. - Networks that don't have this can use nn.Identity() as a preprocess net (see :class:`RandomActor`). - """ + We need this method to be able to share latent representation computations with + other networks (e.g. critics) within an algorithm. + Actors that do not have a pre-processing stage can return nn.Identity() + (see :class:`RandomActor` for an example). + """ -class MLPActor(Actor[Any]): - """Wrapper of MLP to support more specific DRL usage. - For advanced usage (how to customize the network), please refer to - :ref:`build_the_network`. +class Net(ActionReprNetWithVectorOutput[Any]): + """A multi-layer perceptron which outputs an action-related representation. :param state_shape: int or a sequence of int of the shape of state. :param action_shape: int or a sequence of int of the shape of action. @@ -321,9 +340,6 @@ def __init__( self.Q = Q self.V = V - def get_preprocess_net(self) -> ModuleWithVectorOutput: - return ModuleWithVectorOutput.from_module(nn.Identity(), self.output_dim) - def forward( self, obs: TObs, @@ -353,7 +369,7 @@ def forward( return logits, state -class Recurrent(Actor[RecurrentStateBatch]): +class Recurrent(ActionReprNetWithVectorOutput[RecurrentStateBatch]): """Simple Recurrent network based on LSTM. 
For advanced usage (how to customize the network), please refer to @@ -485,9 +501,9 @@ def forward( # The same functionality as DataParallelNet -# The duplication is worth it because the PolicyForwardInterface is so important -class PolicyForwardDataParallelWrapper(ActorForwardInterface): - def __init__(self, net: ActorForwardInterface) -> None: +# The duplication is worth it because the ActionReprNet abstraction is so important +class ActionReprNetDataParallelWrapper(ActionReprNet): + def __init__(self, net: ActionReprNet) -> None: super().__init__() self.net = nn.DataParallel(net) @@ -538,7 +554,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return x -class BranchingNet(ActorForwardInterface): +class BranchingNet(ActionReprNet): """Branching dual Q network. Network for the BranchingDQNPolicy, it uses a common network module, a value module @@ -709,11 +725,11 @@ def forward(self, obs: TObs, *args, **kwargs) -> Any: return decorator_fn, new_state_shape -class ContinuousActorProbabilisticInterface(Actor, ABC): +class AbstractContinuousActorProbabilistic(Actor, ABC): """Type bound for probabilistic actors which output distribution parameters for continuous action spaces.""" -class DiscreteActorInterface(Actor, ABC): +class AbstractDiscreteActor(Actor, ABC): """ Type bound for discrete actors. @@ -731,7 +747,7 @@ class DiscreteActorInterface(Actor, ABC): """ -class RandomActor(ContinuousActorProbabilisticInterface, DiscreteActorInterface): +class RandomActor(AbstractContinuousActorProbabilistic, AbstractDiscreteActor): """An actor that returns random actions. For continuous action spaces, forward returns a batch of random actions sampled from the action space. 
@@ -787,7 +803,3 @@ def compute_action_batch(self, obs: TObs) -> torch.Tensor: return torch.Tensor(np.random.randint(low=0, high=self.action_space.n, size=len(obs))) else: return self.forward(obs)[0] - - -class NetBase: - pass diff --git a/tianshou/utils/net/continuous.py b/tianshou/utils/net/continuous.py index 89adc9281..83cddd049 100644 --- a/tianshou/utils/net/continuous.py +++ b/tianshou/utils/net/continuous.py @@ -11,8 +11,8 @@ from tianshou.data.types import TObs from tianshou.utils.net.common import ( MLP, + AbstractContinuousActorProbabilistic, Actor, - ContinuousActorProbabilisticInterface, ModuleWithVectorOutput, TActionShape, TLinearLayer, @@ -25,11 +25,11 @@ T = TypeVar("T") -class ContinuousActorDeterministicInterface(Actor, ABC): +class AbstractContinuousActorDeterministic(Actor, ABC): """Marker interface for continuous deterministic actors (DDPG like).""" -class ContinuousActorDeterministic(ContinuousActorDeterministicInterface): +class ContinuousActorDeterministic(AbstractContinuousActorDeterministic): """Actor network that directly outputs actions for continuous action space. Used primarily in DDPG and its variants. @@ -175,7 +175,7 @@ def forward( return self.last(obs) -class ContinuousActorProbabilistic(ContinuousActorProbabilisticInterface): +class ContinuousActorProbabilistic(AbstractContinuousActorProbabilistic): """Simple actor network that outputs `mu` and `sigma` to be used as input for a `dist_fn` (typically, a Gaussian). Used primarily in SAC, PPO and variants thereof. For deterministic policies, see :class:`~Actor`. 
diff --git a/tianshou/utils/net/discrete.py b/tianshou/utils/net/discrete.py index 69f3b1753..8f2e11f1e 100644 --- a/tianshou/utils/net/discrete.py +++ b/tianshou/utils/net/discrete.py @@ -10,7 +10,7 @@ from tianshou.data.types import TObs from tianshou.utils.net.common import ( MLP, - DiscreteActorInterface, + AbstractDiscreteActor, ModuleWithVectorOutput, TActionShape, ) @@ -24,7 +24,7 @@ def dist_fn_categorical_from_logits(logits: torch.Tensor) -> torch.distributions return torch.distributions.Categorical(logits=logits) -class DiscreteActor(DiscreteActorInterface): +class DiscreteActor(AbstractDiscreteActor): """ Generic discrete actor which uses a preprocessing network to generate a latent representation which is subsequently passed to an MLP to compute the output. diff --git a/tianshou/utils/torch_utils.py b/tianshou/utils/torch_utils.py index a4e15ba6c..6273a41b4 100644 --- a/tianshou/utils/torch_utils.py +++ b/tianshou/utils/torch_utils.py @@ -23,7 +23,9 @@ def torch_train_mode(module: nn.Module, enabled: bool = True) -> Iterator[None]: @contextmanager -def policy_within_training_step(policy: "algorithm_base.Policy", enabled: bool = True) -> Iterator[None]: +def policy_within_training_step( + policy: "algorithm_base.Policy", enabled: bool = True +) -> Iterator[None]: """Temporarily switch to `policy.is_within_training_step=enabled`. 
Enabling this ensures that the policy is able to adapt its behavior, From 75cfaf0f859f7794bd0a6e97412432f045e503db Mon Sep 17 00:00:00 2001 From: Dominik Jain Date: Mon, 19 May 2025 12:11:40 +0200 Subject: [PATCH 214/230] v2: Add mock import for cv2 (used in atari_wrapper) --- docs/_config.yml | 3 +++ 1 file changed, 3 insertions(+) diff --git a/docs/_config.yml b/docs/_config.yml index a0bb290a2..fce609211 100644 --- a/docs/_config.yml +++ b/docs/_config.yml @@ -103,6 +103,9 @@ sphinx: config : # key-value pairs to directly over-ride the Sphinx configuration autodoc_typehints_format: "short" autodoc_member_order: "bysource" + autodoc_mock_imports: + # mock imports for optional dependencies (e.g. dependencies of atari/atari_wrapper) + - cv2 autoclass_content: "both" autodoc_default_options: show-inheritance: True From 52092181423f2203f1d00f50c770bf8b2e8c06cc Mon Sep 17 00:00:00 2001 From: Dominik Jain Date: Mon, 19 May 2025 13:34:00 +0200 Subject: [PATCH 215/230] v2: Fix argument references: test_num -> num_test_envs --- docs/01_tutorials/04_tictactoe.rst | 4 ++-- examples/atari/atari_c51.py | 6 +++--- examples/atari/atari_dqn.py | 6 +++--- examples/atari/atari_fqf.py | 6 +++--- examples/atari/atari_iqn.py | 6 +++--- examples/atari/atari_ppo.py | 6 +++--- examples/atari/atari_qrdqn.py | 6 +++--- examples/atari/atari_rainbow.py | 6 +++--- examples/atari/atari_sac.py | 6 +++--- examples/box2d/acrobot_dualdqn.py | 6 +++--- examples/box2d/bipedal_bdq.py | 6 +++--- examples/box2d/bipedal_hardcore_sac.py | 6 +++--- examples/box2d/lunarlander_dqn.py | 6 +++--- examples/box2d/mcc_sac.py | 6 +++--- examples/inverse/irl_gail.py | 6 +++--- examples/mujoco/fetch_her_ddpg.py | 6 +++--- examples/mujoco/mujoco_a2c.py | 6 +++--- examples/mujoco/mujoco_ddpg.py | 6 +++--- examples/mujoco/mujoco_npg.py | 6 +++--- examples/mujoco/mujoco_ppo.py | 6 +++--- examples/mujoco/mujoco_redq.py | 6 +++--- examples/mujoco/mujoco_reinforce.py | 6 +++--- examples/mujoco/mujoco_sac.py | 6 +++--- 
examples/mujoco/mujoco_td3.py | 6 +++--- examples/mujoco/mujoco_trpo.py | 6 +++--- examples/offline/atari_bcq.py | 6 +++--- examples/offline/atari_cql.py | 6 +++--- examples/offline/atari_crr.py | 6 +++--- examples/offline/atari_il.py | 6 +++--- examples/offline/d4rl_bcq.py | 6 +++--- examples/offline/d4rl_cql.py | 6 +++--- examples/offline/d4rl_il.py | 6 +++--- examples/offline/d4rl_td3_bc.py | 6 +++--- examples/vizdoom/vizdoom_c51.py | 6 +++--- examples/vizdoom/vizdoom_ppo.py | 6 +++--- test/continuous/test_ddpg.py | 4 ++-- test/continuous/test_npg.py | 4 ++-- test/continuous/test_redq.py | 4 ++-- test/continuous/test_sac_with_il.py | 12 ++++++------ test/continuous/test_td3.py | 4 ++-- test/continuous/test_trpo.py | 4 ++-- test/discrete/test_a2c_with_il.py | 12 ++++++------ test/discrete/test_bdqn.py | 4 ++-- test/discrete/test_c51.py | 4 ++-- test/discrete/test_discrete_sac.py | 4 ++-- test/discrete/test_dqn.py | 4 ++-- test/discrete/test_drqn.py | 4 ++-- test/discrete/test_fqf.py | 4 ++-- test/discrete/test_iqn.py | 4 ++-- test/discrete/test_ppo_discrete.py | 4 ++-- test/discrete/test_qrdqn.py | 4 ++-- test/discrete/test_rainbow.py | 4 ++-- test/discrete/test_reinforce.py | 4 ++-- test/modelbased/test_dqn_icm.py | 4 ++-- test/modelbased/test_ppo_icm.py | 4 ++-- test/modelbased/test_psrl.py | 4 ++-- test/offline/gather_cartpole_data.py | 4 ++-- test/offline/gather_pendulum_data.py | 4 ++-- test/offline/test_bcq.py | 4 ++-- test/offline/test_cql.py | 4 ++-- test/offline/test_discrete_bcq.py | 4 ++-- test/offline/test_discrete_cql.py | 4 ++-- test/offline/test_discrete_crr.py | 4 ++-- test/offline/test_gail.py | 4 ++-- test/offline/test_td3_bc.py | 4 ++-- test/pettingzoo/pistonball.py | 4 ++-- test/pettingzoo/pistonball_continuous.py | 4 ++-- test/pettingzoo/tic_tac_toe.py | 4 ++-- 68 files changed, 178 insertions(+), 178 deletions(-) diff --git a/docs/01_tutorials/04_tictactoe.rst b/docs/01_tutorials/04_tictactoe.rst index 22d712024..cba591fbb 100644 --- 
a/docs/01_tutorials/04_tictactoe.rst +++ b/docs/01_tutorials/04_tictactoe.rst @@ -357,7 +357,7 @@ With the above preparation, we are close to the first learned agent. The followi # ======== environment setup ========= train_envs = DummyVectorEnv([get_env for _ in range(args.num_train_envs)]) - test_envs = DummyVectorEnv([get_env for _ in range(args.test_num)]) + test_envs = DummyVectorEnv([get_env for _ in range(args.num_test_envs)]) # seed np.random.seed(args.seed) torch.manual_seed(args.seed) @@ -418,7 +418,7 @@ With the above preparation, we are close to the first learned agent. The followi args.epoch, args.epoch_num_steps, args.collection_step_num_env_steps, - args.test_num, + args.num_test_envs, args.batch_size, train_fn=train_fn, test_fn=test_fn, diff --git a/examples/atari/atari_c51.py b/examples/atari/atari_c51.py index 2f2482a3a..481562855 100644 --- a/examples/atari/atari_c51.py +++ b/examples/atari/atari_c51.py @@ -73,7 +73,7 @@ def main(args: argparse.Namespace = get_args()) -> None: args.task, args.seed, args.num_train_envs, - args.test_num, + args.num_test_envs, scale=args.scale_obs, frame_stack=args.frames_stack, ) @@ -191,7 +191,7 @@ def watch() -> None: else: print("Testing agent ...") test_collector.reset() - result = test_collector.collect(n_episode=args.test_num, render=args.render) + result = test_collector.collect(n_episode=args.num_test_envs, render=args.render) result.pprint_asdict() if args.watch: @@ -209,7 +209,7 @@ def watch() -> None: max_epochs=args.epoch, epoch_num_steps=args.epoch_num_steps, collection_step_num_env_steps=args.collection_step_num_env_steps, - test_step_num_episodes=args.test_num, + test_step_num_episodes=args.num_test_envs, batch_size=args.batch_size, train_fn=train_fn, stop_fn=stop_fn, diff --git a/examples/atari/atari_dqn.py b/examples/atari/atari_dqn.py index a40f60162..c516520ea 100644 --- a/examples/atari/atari_dqn.py +++ b/examples/atari/atari_dqn.py @@ -90,7 +90,7 @@ def main(args: argparse.Namespace = 
get_args()) -> None: args.task, args.seed, args.num_train_envs, - args.test_num, + args.num_test_envs, scale=args.scale_obs, frame_stack=args.frames_stack, ) @@ -233,7 +233,7 @@ def watch() -> None: else: print("Testing agent ...") test_collector.reset() - result = test_collector.collect(n_episode=args.test_num, render=args.render) + result = test_collector.collect(n_episode=args.num_test_envs, render=args.render) result.pprint_asdict() if args.watch: @@ -252,7 +252,7 @@ def watch() -> None: max_epochs=args.epoch, epoch_num_steps=args.epoch_num_steps, collection_step_num_env_steps=args.collection_step_num_env_steps, - test_step_num_episodes=args.test_num, + test_step_num_episodes=args.num_test_envs, batch_size=args.batch_size, train_fn=train_fn, stop_fn=stop_fn, diff --git a/examples/atari/atari_fqf.py b/examples/atari/atari_fqf.py index d0c768997..f6b1c1deb 100644 --- a/examples/atari/atari_fqf.py +++ b/examples/atari/atari_fqf.py @@ -76,7 +76,7 @@ def main(args: argparse.Namespace = get_args()) -> None: args.task, args.seed, args.num_train_envs, - args.test_num, + args.num_test_envs, scale=args.scale_obs, frame_stack=args.frames_stack, ) @@ -207,7 +207,7 @@ def watch() -> None: else: print("Testing agent ...") test_collector.reset() - result = test_collector.collect(n_episode=args.test_num, render=args.render) + result = test_collector.collect(n_episode=args.num_test_envs, render=args.render) result.pprint_asdict() if args.watch: @@ -226,7 +226,7 @@ def watch() -> None: max_epochs=args.epoch, epoch_num_steps=args.epoch_num_steps, collection_step_num_env_steps=args.collection_step_num_env_steps, - test_step_num_episodes=args.test_num, + test_step_num_episodes=args.num_test_envs, batch_size=args.batch_size, train_fn=train_fn, test_fn=test_fn, diff --git a/examples/atari/atari_iqn.py b/examples/atari/atari_iqn.py index f489057bb..b89738e53 100644 --- a/examples/atari/atari_iqn.py +++ b/examples/atari/atari_iqn.py @@ -76,7 +76,7 @@ def main(args: argparse.Namespace = 
get_args()) -> None: args.task, args.seed, args.num_train_envs, - args.test_num, + args.num_test_envs, scale=args.scale_obs, frame_stack=args.frames_stack, ) @@ -201,7 +201,7 @@ def watch() -> None: else: print("Testing agent ...") test_collector.reset() - result = test_collector.collect(n_episode=args.test_num, render=args.render) + result = test_collector.collect(n_episode=args.num_test_envs, render=args.render) result.pprint_asdict() if args.watch: @@ -220,7 +220,7 @@ def watch() -> None: max_epochs=args.epoch, epoch_num_steps=args.epoch_num_steps, collection_step_num_env_steps=args.collection_step_num_env_steps, - test_step_num_episodes=args.test_num, + test_step_num_episodes=args.num_test_envs, batch_size=args.batch_size, train_fn=train_fn, stop_fn=stop_fn, diff --git a/examples/atari/atari_ppo.py b/examples/atari/atari_ppo.py index 0d1aa385a..16d72fb15 100644 --- a/examples/atari/atari_ppo.py +++ b/examples/atari/atari_ppo.py @@ -107,7 +107,7 @@ def main(args: argparse.Namespace = get_args()) -> None: args.task, args.seed, args.num_train_envs, - args.test_num, + args.num_test_envs, scale=0, frame_stack=args.frames_stack, ) @@ -263,7 +263,7 @@ def watch() -> None: else: print("Testing agent ...") test_collector.reset() - result = test_collector.collect(n_episode=args.test_num, render=args.render) + result = test_collector.collect(n_episode=args.num_test_envs, render=args.render) result.pprint_asdict() if args.watch: @@ -282,7 +282,7 @@ def watch() -> None: max_epochs=args.epoch, epoch_num_steps=args.epoch_num_steps, update_step_num_repetitions=args.update_step_num_repetitions, - test_step_num_episodes=args.test_num, + test_step_num_episodes=args.num_test_envs, batch_size=args.batch_size, collection_step_num_env_steps=args.collection_step_num_env_steps, stop_fn=stop_fn, diff --git a/examples/atari/atari_qrdqn.py b/examples/atari/atari_qrdqn.py index 2b2c84890..99706ed0e 100644 --- a/examples/atari/atari_qrdqn.py +++ b/examples/atari/atari_qrdqn.py @@ -71,7 
+71,7 @@ def main(args: argparse.Namespace = get_args()) -> None: args.task, args.seed, args.num_train_envs, - args.test_num, + args.num_test_envs, scale=args.scale_obs, frame_stack=args.frames_stack, ) @@ -195,7 +195,7 @@ def watch() -> None: else: print("Testing agent ...") test_collector.reset() - result = test_collector.collect(n_episode=args.test_num, render=args.render) + result = test_collector.collect(n_episode=args.num_test_envs, render=args.render) result.pprint_asdict() if args.watch: @@ -214,7 +214,7 @@ def watch() -> None: max_epochs=args.epoch, epoch_num_steps=args.epoch_num_steps, collection_step_num_env_steps=args.collection_step_num_env_steps, - test_step_num_episodes=args.test_num, + test_step_num_episodes=args.num_test_envs, batch_size=args.batch_size, train_fn=train_fn, stop_fn=stop_fn, diff --git a/examples/atari/atari_rainbow.py b/examples/atari/atari_rainbow.py index 005b14fb2..27f13c3d7 100644 --- a/examples/atari/atari_rainbow.py +++ b/examples/atari/atari_rainbow.py @@ -87,7 +87,7 @@ def test_rainbow(args: argparse.Namespace = get_args()) -> None: args.task, args.seed, args.num_train_envs, - args.test_num, + args.num_test_envs, scale=args.scale_obs, frame_stack=args.frames_stack, ) @@ -238,7 +238,7 @@ def watch() -> None: else: print("Testing agent ...") test_collector.reset() - result = test_collector.collect(n_episode=args.test_num, render=args.render) + result = test_collector.collect(n_episode=args.num_test_envs, render=args.render) result.pprint_asdict() if args.watch: @@ -257,7 +257,7 @@ def watch() -> None: max_epochs=args.epoch, epoch_num_steps=args.epoch_num_steps, collection_step_num_env_steps=args.collection_step_num_env_steps, - test_step_num_episodes=args.test_num, + test_step_num_episodes=args.num_test_envs, batch_size=args.batch_size, train_fn=train_fn, stop_fn=stop_fn, diff --git a/examples/atari/atari_sac.py b/examples/atari/atari_sac.py index faac23bf2..047c6b435 100644 --- a/examples/atari/atari_sac.py +++ 
b/examples/atari/atari_sac.py @@ -97,7 +97,7 @@ def test_discrete_sac(args: argparse.Namespace = get_args()) -> None: args.task, args.seed, args.num_train_envs, - args.test_num, + args.num_test_envs, scale=args.scale_obs, frame_stack=args.frames_stack, ) @@ -249,7 +249,7 @@ def watch() -> None: else: print("Testing agent ...") test_collector.reset() - result = test_collector.collect(n_episode=args.test_num, render=args.render) + result = test_collector.collect(n_episode=args.num_test_envs, render=args.render) result.pprint_asdict() if args.watch: @@ -268,7 +268,7 @@ def watch() -> None: max_epochs=args.epoch, epoch_num_steps=args.epoch_num_steps, collection_step_num_env_steps=args.collection_step_num_env_steps, - test_step_num_episodes=args.test_num, + test_step_num_episodes=args.num_test_envs, batch_size=args.batch_size, stop_fn=stop_fn, save_best_fn=save_best_fn, diff --git a/examples/box2d/acrobot_dualdqn.py b/examples/box2d/acrobot_dualdqn.py index 81bbc24fc..97e6c56d1 100644 --- a/examples/box2d/acrobot_dualdqn.py +++ b/examples/box2d/acrobot_dualdqn.py @@ -60,7 +60,7 @@ def test_dqn(args: argparse.Namespace = get_args()) -> None: # you can also use tianshou.env.SubprocVectorEnv train_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.num_train_envs)]) # test_envs = gym.make(args.task) - test_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.test_num)]) + test_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.num_test_envs)]) # seed np.random.seed(args.seed) torch.manual_seed(args.seed) @@ -132,7 +132,7 @@ def train_fn(epoch: int, env_step: int) -> None: max_epochs=args.epoch, epoch_num_steps=args.epoch_num_steps, collection_step_num_env_steps=args.collection_step_num_env_steps, - test_step_num_episodes=args.test_num, + test_step_num_episodes=args.num_test_envs, batch_size=args.batch_size, update_step_num_gradient_steps_per_sample=args.update_per_step, train_fn=train_fn, @@ -149,7 +149,7 @@ def 
train_fn(epoch: int, env_step: int) -> None: # Let's watch its performance! test_envs.seed(args.seed) test_collector.reset() - collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render) + collector_stats = test_collector.collect(n_episode=args.num_test_envs, render=args.render) print(collector_stats) diff --git a/examples/box2d/bipedal_bdq.py b/examples/box2d/bipedal_bdq.py index d1afcd7f5..41882525b 100644 --- a/examples/box2d/bipedal_bdq.py +++ b/examples/box2d/bipedal_bdq.py @@ -84,7 +84,7 @@ def run_bdq(args: argparse.Namespace = get_args()) -> None: test_envs = SubprocVectorEnv( [ lambda: ContinuousToDiscrete(gym.make(args.task), args.action_per_branch) - for _ in range(args.test_num) + for _ in range(args.num_test_envs) ], ) # seed @@ -150,7 +150,7 @@ def train_fn(epoch: int, env_step: int) -> None: # exp decay max_epochs=args.epoch, epoch_num_steps=args.epoch_num_steps, collection_step_num_env_steps=args.collection_step_num_env_steps, - test_step_num_episodes=args.test_num, + test_step_num_episodes=args.num_test_envs, batch_size=args.batch_size, update_step_num_gradient_steps_per_sample=args.update_per_step, stop_fn=stop_fn, @@ -168,7 +168,7 @@ def train_fn(epoch: int, env_step: int) -> None: # exp decay policy.set_eps_training(args.eps_test) test_envs.seed(args.seed) test_collector.reset() - collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render) + collector_stats = test_collector.collect(n_episode=args.num_test_envs, render=args.render) print(collector_stats) diff --git a/examples/box2d/bipedal_hardcore_sac.py b/examples/box2d/bipedal_hardcore_sac.py index 8687d8c3e..f27d72017 100644 --- a/examples/box2d/bipedal_hardcore_sac.py +++ b/examples/box2d/bipedal_hardcore_sac.py @@ -99,7 +99,7 @@ def test_sac_bipedal(args: argparse.Namespace = get_args()) -> None: test_envs = SubprocVectorEnv( [ lambda: Wrapper(gym.make(args.task), reward_scale=1, rm_done=False) - for _ in range(args.test_num) + for _ in 
range(args.num_test_envs) ], ) @@ -197,7 +197,7 @@ def stop_fn(mean_rewards: float) -> bool: max_epochs=args.epoch, epoch_num_steps=args.epoch_num_steps, collection_step_num_env_steps=args.collection_step_num_env_steps, - test_step_num_episodes=args.test_num, + test_step_num_episodes=args.num_test_envs, batch_size=args.batch_size, update_step_num_gradient_steps_per_sample=args.update_per_step, test_in_train=False, @@ -212,7 +212,7 @@ def stop_fn(mean_rewards: float) -> bool: # Let's watch its performance! test_envs.seed(args.seed) test_collector.reset() - collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render) + collector_stats = test_collector.collect(n_episode=args.num_test_envs, render=args.render) print(collector_stats) diff --git a/examples/box2d/lunarlander_dqn.py b/examples/box2d/lunarlander_dqn.py index 39b7f9b94..8ecbda311 100644 --- a/examples/box2d/lunarlander_dqn.py +++ b/examples/box2d/lunarlander_dqn.py @@ -62,7 +62,7 @@ def test_dqn(args: argparse.Namespace = get_args()) -> None: # you can also use tianshou.env.SubprocVectorEnv train_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.num_train_envs)]) # test_envs = gym.make(args.task) - test_envs = SubprocVectorEnv([lambda: gym.make(args.task) for _ in range(args.test_num)]) + test_envs = SubprocVectorEnv([lambda: gym.make(args.task) for _ in range(args.num_test_envs)]) # seed np.random.seed(args.seed) torch.manual_seed(args.seed) @@ -129,7 +129,7 @@ def train_fn(epoch: int, env_step: int) -> None: # exp decay max_epochs=args.epoch, epoch_num_steps=args.epoch_num_steps, collection_step_num_env_steps=args.collection_step_num_env_steps, - test_step_num_episodes=args.test_num, + test_step_num_episodes=args.num_test_envs, batch_size=args.batch_size, update_step_num_gradient_steps_per_sample=args.update_per_step, stop_fn=stop_fn, @@ -146,7 +146,7 @@ def train_fn(epoch: int, env_step: int) -> None: # exp decay # Let's watch its performance! 
test_envs.seed(args.seed) test_collector.reset() - collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render) + collector_stats = test_collector.collect(n_episode=args.num_test_envs, render=args.render) print(collector_stats) diff --git a/examples/box2d/mcc_sac.py b/examples/box2d/mcc_sac.py index 77bc0d5d7..031b759fb 100644 --- a/examples/box2d/mcc_sac.py +++ b/examples/box2d/mcc_sac.py @@ -61,7 +61,7 @@ def test_sac(args: argparse.Namespace = get_args()) -> None: # train_envs = gym.make(args.task) train_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.num_train_envs)]) # test_envs = gym.make(args.task) - test_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.test_num)]) + test_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.num_test_envs)]) # seed np.random.seed(args.seed) torch.manual_seed(args.seed) @@ -146,7 +146,7 @@ def stop_fn(mean_rewards: float) -> bool: max_epochs=args.epoch, epoch_num_steps=args.epoch_num_steps, collection_step_num_env_steps=args.collection_step_num_env_steps, - test_step_num_episodes=args.test_num, + test_step_num_episodes=args.num_test_envs, batch_size=args.batch_size, update_step_num_gradient_steps_per_sample=args.update_per_step, stop_fn=stop_fn, @@ -162,7 +162,7 @@ def stop_fn(mean_rewards: float) -> bool: # Let's watch its performance! 
test_envs.seed(args.seed) test_collector.reset() - collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render) + collector_stats = test_collector.collect(n_episode=args.num_test_envs, render=args.render) print(collector_stats) diff --git a/examples/inverse/irl_gail.py b/examples/inverse/irl_gail.py index 13ec86685..22d35c4b7 100644 --- a/examples/inverse/irl_gail.py +++ b/examples/inverse/irl_gail.py @@ -112,7 +112,7 @@ def test_gail(args: argparse.Namespace = get_args()) -> None: ) train_envs = VectorEnvNormObs(train_envs) # test_envs = gym.make(args.task) - test_envs = SubprocVectorEnv([lambda: gym.make(args.task) for _ in range(args.test_num)]) + test_envs = SubprocVectorEnv([lambda: gym.make(args.task) for _ in range(args.num_test_envs)]) test_envs = VectorEnvNormObs(test_envs, update_obs_rms=False) test_envs.set_obs_rms(train_envs.get_obs_rms()) @@ -265,7 +265,7 @@ def save_best_fn(policy: Algorithm) -> None: max_epochs=args.epoch, epoch_num_steps=args.epoch_num_steps, update_step_num_repetitions=args.update_step_num_repetitions, - test_step_num_episodes=args.test_num, + test_step_num_episodes=args.num_test_envs, batch_size=args.batch_size, collection_step_num_env_steps=args.collection_step_num_env_steps, save_best_fn=save_best_fn, @@ -278,7 +278,7 @@ def save_best_fn(policy: Algorithm) -> None: # Let's watch its performance! 
test_envs.seed(args.seed) test_collector.reset() - collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render) + collector_stats = test_collector.collect(n_episode=args.num_test_envs, render=args.render) print(collector_stats) diff --git a/examples/mujoco/fetch_her_ddpg.py b/examples/mujoco/fetch_her_ddpg.py index fadb7df2f..c77eaec97 100644 --- a/examples/mujoco/fetch_her_ddpg.py +++ b/examples/mujoco/fetch_her_ddpg.py @@ -117,7 +117,7 @@ def test_ddpg(args: argparse.Namespace = get_args()) -> None: config_dict=vars(args), ) - env, train_envs, test_envs = make_fetch_env(args.task, args.num_train_envs, args.test_num) + env, train_envs, test_envs = make_fetch_env(args.task, args.num_train_envs, args.num_test_envs) # The method HER works with goal-based environments if not isinstance(env.observation_space, gym.spaces.Dict): raise ValueError( @@ -233,7 +233,7 @@ def save_best_fn(policy: Algorithm) -> None: max_epochs=args.epoch, epoch_num_steps=args.epoch_num_steps, collection_step_num_env_steps=args.collection_step_num_env_steps, - test_step_num_episodes=args.test_num, + test_step_num_episodes=args.num_test_envs, batch_size=args.batch_size, save_best_fn=save_best_fn, logger=logger, @@ -246,7 +246,7 @@ def save_best_fn(policy: Algorithm) -> None: # Let's watch its performance! 
test_envs.seed(args.seed) test_collector.reset() - collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render) + collector_stats = test_collector.collect(n_episode=args.num_test_envs, render=args.render) collector_stats.pprint_asdict() diff --git a/examples/mujoco/mujoco_a2c.py b/examples/mujoco/mujoco_a2c.py index 40d917e5b..1abaa6123 100755 --- a/examples/mujoco/mujoco_a2c.py +++ b/examples/mujoco/mujoco_a2c.py @@ -76,7 +76,7 @@ def main(args: argparse.Namespace = get_args()) -> None: args.task, args.seed, args.num_train_envs, - args.test_num, + args.num_test_envs, obs_norm=True, ) args.state_shape = env.observation_space.shape or env.observation_space.n @@ -210,7 +210,7 @@ def save_best_fn(policy: Algorithm) -> None: max_epochs=args.epoch, epoch_num_steps=args.epoch_num_steps, update_step_num_repetitions=args.update_step_num_repetitions, - test_step_num_episodes=args.test_num, + test_step_num_episodes=args.num_test_envs, batch_size=args.batch_size, collection_step_num_env_steps=args.collection_step_num_env_steps, save_best_fn=save_best_fn, @@ -223,7 +223,7 @@ def save_best_fn(policy: Algorithm) -> None: # Let's watch its performance! 
test_envs.seed(args.seed) test_collector.reset() - collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render) + collector_stats = test_collector.collect(n_episode=args.num_test_envs, render=args.render) print(collector_stats) diff --git a/examples/mujoco/mujoco_ddpg.py b/examples/mujoco/mujoco_ddpg.py index ca0f04bf6..f3ae4e968 100755 --- a/examples/mujoco/mujoco_ddpg.py +++ b/examples/mujoco/mujoco_ddpg.py @@ -71,7 +71,7 @@ def main(args: argparse.Namespace = get_args()) -> None: args.task, args.seed, args.num_train_envs, - args.test_num, + args.num_test_envs, obs_norm=False, ) args.state_shape = env.observation_space.shape or env.observation_space.n @@ -164,7 +164,7 @@ def save_best_fn(policy: Algorithm) -> None: max_epochs=args.epoch, epoch_num_steps=args.epoch_num_steps, collection_step_num_env_steps=args.collection_step_num_env_steps, - test_step_num_episodes=args.test_num, + test_step_num_episodes=args.num_test_envs, batch_size=args.batch_size, save_best_fn=save_best_fn, logger=logger, @@ -177,7 +177,7 @@ def save_best_fn(policy: Algorithm) -> None: # Let's watch its performance! 
test_envs.seed(args.seed) test_collector.reset() - collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render) + collector_stats = test_collector.collect(n_episode=args.num_test_envs, render=args.render) print(collector_stats) diff --git a/examples/mujoco/mujoco_npg.py b/examples/mujoco/mujoco_npg.py index 06265e42f..6beac506d 100755 --- a/examples/mujoco/mujoco_npg.py +++ b/examples/mujoco/mujoco_npg.py @@ -81,7 +81,7 @@ def main(args: argparse.Namespace = get_args()) -> None: args.task, args.seed, args.num_train_envs, - args.test_num, + args.num_test_envs, obs_norm=True, ) args.state_shape = env.observation_space.shape or env.observation_space.n @@ -208,7 +208,7 @@ def save_best_fn(policy: Algorithm) -> None: max_epochs=args.epoch, epoch_num_steps=args.epoch_num_steps, update_step_num_repetitions=args.update_step_num_repetitions, - test_step_num_episodes=args.test_num, + test_step_num_episodes=args.num_test_envs, batch_size=args.batch_size, collection_step_num_env_steps=args.collection_step_num_env_steps, save_best_fn=save_best_fn, @@ -221,7 +221,7 @@ def save_best_fn(policy: Algorithm) -> None: # Let's watch its performance! 
test_envs.seed(args.seed) test_collector.reset() - collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render) + collector_stats = test_collector.collect(n_episode=args.num_test_envs, render=args.render) print(collector_stats) diff --git a/examples/mujoco/mujoco_ppo.py b/examples/mujoco/mujoco_ppo.py index 60b5fd5d7..06b0ac904 100755 --- a/examples/mujoco/mujoco_ppo.py +++ b/examples/mujoco/mujoco_ppo.py @@ -81,7 +81,7 @@ def main(args: argparse.Namespace = get_args()) -> None: args.task, args.seed, args.num_train_envs, - args.test_num, + args.num_test_envs, obs_norm=True, ) args.state_shape = env.observation_space.shape or env.observation_space.n @@ -216,7 +216,7 @@ def save_best_fn(policy: Algorithm) -> None: max_epochs=args.epoch, epoch_num_steps=args.epoch_num_steps, update_step_num_repetitions=args.update_step_num_repetitions, - test_step_num_episodes=args.test_num, + test_step_num_episodes=args.num_test_envs, batch_size=args.batch_size, collection_step_num_env_steps=args.collection_step_num_env_steps, save_best_fn=save_best_fn, @@ -229,7 +229,7 @@ def save_best_fn(policy: Algorithm) -> None: # Let's watch its performance! 
test_envs.seed(args.seed) test_collector.reset() - collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render) + collector_stats = test_collector.collect(n_episode=args.num_test_envs, render=args.render) print(collector_stats) diff --git a/examples/mujoco/mujoco_redq.py b/examples/mujoco/mujoco_redq.py index b8d472ed1..91f1c56ae 100755 --- a/examples/mujoco/mujoco_redq.py +++ b/examples/mujoco/mujoco_redq.py @@ -76,7 +76,7 @@ def main(args: argparse.Namespace = get_args()) -> None: args.task, args.seed, args.num_train_envs, - args.test_num, + args.num_test_envs, obs_norm=False, ) args.state_shape = env.observation_space.shape or env.observation_space.n @@ -189,7 +189,7 @@ def save_best_fn(policy: Algorithm) -> None: max_epochs=args.epoch, epoch_num_steps=args.epoch_num_steps, collection_step_num_env_steps=args.collection_step_num_env_steps, - test_step_num_episodes=args.test_num, + test_step_num_episodes=args.num_test_envs, batch_size=args.batch_size, save_best_fn=save_best_fn, logger=logger, @@ -202,7 +202,7 @@ def save_best_fn(policy: Algorithm) -> None: # Let's watch its performance! 
test_envs.seed(args.seed) test_collector.reset() - collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render) + collector_stats = test_collector.collect(n_episode=args.num_test_envs, render=args.render) print(collector_stats) diff --git a/examples/mujoco/mujoco_reinforce.py b/examples/mujoco/mujoco_reinforce.py index e7758bcdb..81f3e527f 100755 --- a/examples/mujoco/mujoco_reinforce.py +++ b/examples/mujoco/mujoco_reinforce.py @@ -73,7 +73,7 @@ def main(args: argparse.Namespace = get_args()) -> None: args.task, args.seed, args.num_train_envs, - args.test_num, + args.num_test_envs, obs_norm=True, ) args.state_shape = env.observation_space.shape or env.observation_space.n @@ -189,7 +189,7 @@ def save_best_fn(policy: Algorithm) -> None: max_epochs=args.epoch, epoch_num_steps=args.epoch_num_steps, update_step_num_repetitions=args.update_step_num_repetitions, - test_step_num_episodes=args.test_num, + test_step_num_episodes=args.num_test_envs, batch_size=args.batch_size, collection_step_num_env_steps=args.collection_step_num_env_steps, save_best_fn=save_best_fn, @@ -202,7 +202,7 @@ def save_best_fn(policy: Algorithm) -> None: # Let's watch its performance! 
test_envs.seed(args.seed) test_collector.reset() - collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render) + collector_stats = test_collector.collect(n_episode=args.num_test_envs, render=args.render) print(collector_stats) diff --git a/examples/mujoco/mujoco_sac.py b/examples/mujoco/mujoco_sac.py index d9849fb5e..16db8ebf5 100755 --- a/examples/mujoco/mujoco_sac.py +++ b/examples/mujoco/mujoco_sac.py @@ -72,7 +72,7 @@ def main(args: argparse.Namespace = get_args()) -> None: args.task, args.seed, args.num_train_envs, - args.test_num, + args.num_test_envs, obs_norm=False, ) args.state_shape = env.observation_space.shape or env.observation_space.n @@ -182,7 +182,7 @@ def save_best_fn(policy: Algorithm) -> None: max_epochs=args.epoch, epoch_num_steps=args.epoch_num_steps, collection_step_num_env_steps=args.collection_step_num_env_steps, - test_step_num_episodes=args.test_num, + test_step_num_episodes=args.num_test_envs, batch_size=args.batch_size, save_best_fn=save_best_fn, logger=logger, @@ -195,7 +195,7 @@ def save_best_fn(policy: Algorithm) -> None: # Let's watch its performance! 
test_envs.seed(args.seed) test_collector.reset() - collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render) + collector_stats = test_collector.collect(n_episode=args.num_test_envs, render=args.render) print(collector_stats) diff --git a/examples/mujoco/mujoco_td3.py b/examples/mujoco/mujoco_td3.py index f62a177c0..030cfa05a 100755 --- a/examples/mujoco/mujoco_td3.py +++ b/examples/mujoco/mujoco_td3.py @@ -74,7 +74,7 @@ def main(args: argparse.Namespace = get_args()) -> None: args.task, args.seed, args.num_train_envs, - args.test_num, + args.num_test_envs, obs_norm=False, ) args.state_shape = env.observation_space.shape or env.observation_space.n @@ -183,7 +183,7 @@ def save_best_fn(policy: Algorithm) -> None: max_epochs=args.epoch, epoch_num_steps=args.epoch_num_steps, collection_step_num_env_steps=args.collection_step_num_env_steps, - test_step_num_episodes=args.test_num, + test_step_num_episodes=args.num_test_envs, batch_size=args.batch_size, save_best_fn=save_best_fn, logger=logger, @@ -196,7 +196,7 @@ def save_best_fn(policy: Algorithm) -> None: # Let's watch its performance! 
test_envs.seed(args.seed) test_collector.reset() - collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render) + collector_stats = test_collector.collect(n_episode=args.num_test_envs, render=args.render) print(collector_stats) diff --git a/examples/mujoco/mujoco_trpo.py b/examples/mujoco/mujoco_trpo.py index 707f174f2..59c6aa005 100755 --- a/examples/mujoco/mujoco_trpo.py +++ b/examples/mujoco/mujoco_trpo.py @@ -84,7 +84,7 @@ def test_trpo(args: argparse.Namespace = get_args()) -> None: args.task, args.seed, args.num_train_envs, - args.test_num, + args.num_test_envs, obs_norm=True, ) args.state_shape = env.observation_space.shape or env.observation_space.n @@ -213,7 +213,7 @@ def save_best_fn(policy: Algorithm) -> None: max_epochs=args.epoch, epoch_num_steps=args.epoch_num_steps, update_step_num_repetitions=args.update_step_num_repetitions, - test_step_num_episodes=args.test_num, + test_step_num_episodes=args.num_test_envs, batch_size=args.batch_size, collection_step_num_env_steps=args.collection_step_num_env_steps, save_best_fn=save_best_fn, @@ -226,7 +226,7 @@ def save_best_fn(policy: Algorithm) -> None: # Let's watch its performance! 
test_envs.seed(args.seed) test_collector.reset() - collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render) + collector_stats = test_collector.collect(n_episode=args.num_test_envs, render=args.render) print(collector_stats) diff --git a/examples/offline/atari_bcq.py b/examples/offline/atari_bcq.py index d4f6015ca..8a5ec1ecf 100644 --- a/examples/offline/atari_bcq.py +++ b/examples/offline/atari_bcq.py @@ -80,7 +80,7 @@ def main(args: argparse.Namespace = get_args()) -> None: args.task, args.seed, 1, - args.test_num, + args.num_test_envs, scale=args.scale_obs, frame_stack=args.frames_stack, ) @@ -191,7 +191,7 @@ def watch() -> None: test_envs.seed(args.seed) print("Testing agent ...") test_collector.reset() - result = test_collector.collect(n_episode=args.test_num, render=args.render) + result = test_collector.collect(n_episode=args.num_test_envs, render=args.render) result.pprint_asdict() if args.watch: @@ -204,7 +204,7 @@ def watch() -> None: test_collector=test_collector, max_epochs=args.epoch, epoch_num_steps=args.update_per_epoch, - test_step_num_episodes=args.test_num, + test_step_num_episodes=args.num_test_envs, batch_size=args.batch_size, stop_fn=stop_fn, save_best_fn=save_best_fn, diff --git a/examples/offline/atari_cql.py b/examples/offline/atari_cql.py index d9a62f056..620b97882 100644 --- a/examples/offline/atari_cql.py +++ b/examples/offline/atari_cql.py @@ -80,7 +80,7 @@ def main(args: argparse.Namespace = get_args()) -> None: args.task, args.seed, 1, - args.test_num, + args.num_test_envs, scale=args.scale_obs, frame_stack=args.frames_stack, ) @@ -177,7 +177,7 @@ def watch() -> None: test_envs.seed(args.seed) print("Testing agent ...") test_collector.reset() - result = test_collector.collect(n_episode=args.test_num, render=args.render) + result = test_collector.collect(n_episode=args.num_test_envs, render=args.render) result.pprint_asdict() if args.watch: @@ -190,7 +190,7 @@ def watch() -> None: 
test_collector=test_collector, max_epochs=args.epoch, epoch_num_steps=args.update_per_epoch, - test_step_num_episodes=args.test_num, + test_step_num_episodes=args.num_test_envs, batch_size=args.batch_size, stop_fn=stop_fn, save_best_fn=save_best_fn, diff --git a/examples/offline/atari_crr.py b/examples/offline/atari_crr.py index 57da6b6ba..bd2ff45a6 100644 --- a/examples/offline/atari_crr.py +++ b/examples/offline/atari_crr.py @@ -81,7 +81,7 @@ def main(args: argparse.Namespace = get_args()) -> None: args.task, args.seed, 1, - args.test_num, + args.num_test_envs, scale=args.scale_obs, frame_stack=args.frames_stack, ) @@ -191,7 +191,7 @@ def watch() -> None: test_envs.seed(args.seed) print("Testing agent ...") test_collector.reset() - result = test_collector.collect(n_episode=args.test_num, render=args.render) + result = test_collector.collect(n_episode=args.num_test_envs, render=args.render) result.pprint_asdict() if args.watch: @@ -204,7 +204,7 @@ def watch() -> None: test_collector=test_collector, max_epochs=args.epoch, epoch_num_steps=args.update_per_epoch, - test_step_num_episodes=args.test_num, + test_step_num_episodes=args.num_test_envs, batch_size=args.batch_size, stop_fn=stop_fn, save_best_fn=save_best_fn, diff --git a/examples/offline/atari_il.py b/examples/offline/atari_il.py index 342d1c3eb..0819aed6c 100644 --- a/examples/offline/atari_il.py +++ b/examples/offline/atari_il.py @@ -74,7 +74,7 @@ def test_il(args: argparse.Namespace = get_args()) -> None: args.task, args.seed, 1, - args.test_num, + args.num_test_envs, scale=args.scale_obs, frame_stack=args.frames_stack, ) @@ -156,7 +156,7 @@ def watch() -> None: test_envs.seed(args.seed) print("Testing agent ...") test_collector.reset() - result = test_collector.collect(n_episode=args.test_num, render=args.render) + result = test_collector.collect(n_episode=args.num_test_envs, render=args.render) result.pprint_asdict() if args.watch: @@ -169,7 +169,7 @@ def watch() -> None: test_collector=test_collector, 
max_epochs=args.epoch, epoch_num_steps=args.update_per_epoch, - test_step_num_episodes=args.test_num, + test_step_num_episodes=args.num_test_envs, batch_size=args.batch_size, stop_fn=stop_fn, save_best_fn=save_best_fn, diff --git a/examples/offline/d4rl_bcq.py b/examples/offline/d4rl_bcq.py index 53ff6d2ce..1fd1c51be 100644 --- a/examples/offline/d4rl_bcq.py +++ b/examples/offline/d4rl_bcq.py @@ -91,7 +91,7 @@ def test_bcq() -> None: print("Max_action", args.max_action) # test_envs = gym.make(args.task) - test_envs = SubprocVectorEnv([lambda: gym.make(args.task) for _ in range(args.test_num)]) + test_envs = SubprocVectorEnv([lambda: gym.make(args.task) for _ in range(args.num_test_envs)]) # seed np.random.seed(args.seed) torch.manual_seed(args.seed) @@ -216,7 +216,7 @@ def watch() -> None: test_collector=test_collector, max_epochs=args.epoch, epoch_num_steps=args.epoch_num_steps, - test_step_num_episodes=args.test_num, + test_step_num_episodes=args.num_test_envs, batch_size=args.batch_size, save_best_fn=save_best_fn, logger=logger, @@ -229,7 +229,7 @@ def watch() -> None: # Let's watch its performance! 
test_envs.seed(args.seed) test_collector.reset() - collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render) + collector_stats = test_collector.collect(n_episode=args.num_test_envs, render=args.render) print(collector_stats) diff --git a/examples/offline/d4rl_cql.py b/examples/offline/d4rl_cql.py index cbe699495..51e965193 100644 --- a/examples/offline/d4rl_cql.py +++ b/examples/offline/d4rl_cql.py @@ -237,7 +237,7 @@ def test_cql() -> None: print("Max_action", args.max_action) # test_envs = gym.make(args.task) - test_envs = SubprocVectorEnv([lambda: gym.make(args.task) for _ in range(args.test_num)]) + test_envs = SubprocVectorEnv([lambda: gym.make(args.task) for _ in range(args.num_test_envs)]) # seed np.random.seed(args.seed) torch.manual_seed(args.seed) @@ -356,7 +356,7 @@ def watch() -> None: test_collector=test_collector, max_epochs=args.epoch, epoch_num_steps=args.epoch_num_steps, - test_step_num_episodes=args.test_num, + test_step_num_episodes=args.num_test_envs, batch_size=args.batch_size, save_best_fn=save_best_fn, logger=logger, @@ -369,7 +369,7 @@ def watch() -> None: # Let's watch its performance! 
test_envs.seed(args.seed) test_collector.reset() - collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render) + collector_stats = test_collector.collect(n_episode=args.num_test_envs, render=args.render) print(collector_stats) diff --git a/examples/offline/d4rl_il.py b/examples/offline/d4rl_il.py index e3993f244..3b05e5f1a 100644 --- a/examples/offline/d4rl_il.py +++ b/examples/offline/d4rl_il.py @@ -79,7 +79,7 @@ def test_il() -> None: args.action_dim = space_info.action_info.action_dim print("Max_action", args.max_action) - test_envs = SubprocVectorEnv([lambda: gym.make(args.task) for _ in range(args.test_num)]) + test_envs = SubprocVectorEnv([lambda: gym.make(args.task) for _ in range(args.num_test_envs)]) # seed np.random.seed(args.seed) torch.manual_seed(args.seed) @@ -159,7 +159,7 @@ def watch() -> None: test_collector=test_collector, max_epochs=args.epoch, epoch_num_steps=args.epoch_num_steps, - test_step_num_episodes=args.test_num, + test_step_num_episodes=args.num_test_envs, batch_size=args.batch_size, save_best_fn=save_best_fn, logger=logger, @@ -172,7 +172,7 @@ def watch() -> None: # Let's watch its performance! 
test_envs.seed(args.seed) test_collector.reset() - collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render) + collector_stats = test_collector.collect(n_episode=args.num_test_envs, render=args.render) print(collector_stats) diff --git a/examples/offline/d4rl_td3_bc.py b/examples/offline/d4rl_td3_bc.py index 8a571606b..d0a4d42ae 100644 --- a/examples/offline/d4rl_td3_bc.py +++ b/examples/offline/d4rl_td3_bc.py @@ -93,7 +93,7 @@ def test_td3_bc() -> None: print("Max_action", args.max_action) test_envs: BaseVectorEnv - test_envs = SubprocVectorEnv([lambda: gym.make(args.task) for _ in range(args.test_num)]) + test_envs = SubprocVectorEnv([lambda: gym.make(args.task) for _ in range(args.num_test_envs)]) if args.norm_obs: test_envs = VectorEnvNormObs(test_envs, update_obs_rms=False) @@ -207,7 +207,7 @@ def watch() -> None: test_collector=test_collector, max_epochs=args.epoch, epoch_num_steps=args.epoch_num_steps, - test_step_num_episodes=args.test_num, + test_step_num_episodes=args.num_test_envs, batch_size=args.batch_size, save_best_fn=save_best_fn, logger=logger, @@ -220,7 +220,7 @@ def watch() -> None: # Let's watch its performance! 
test_envs.seed(args.seed) test_collector.reset() - collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render) + collector_stats = test_collector.collect(n_episode=args.num_test_envs, render=args.render) print(collector_stats) diff --git a/examples/vizdoom/vizdoom_c51.py b/examples/vizdoom/vizdoom_c51.py index 3e29cf2ed..941a61c3c 100644 --- a/examples/vizdoom/vizdoom_c51.py +++ b/examples/vizdoom/vizdoom_c51.py @@ -83,7 +83,7 @@ def test_c51(args: argparse.Namespace = get_args()) -> None: args.save_lmp, args.seed, args.num_train_envs, - args.test_num, + args.num_test_envs, ) args.state_shape = env.observation_space.shape args.action_shape = env.action_space.shape or env.action_space.n @@ -193,7 +193,7 @@ def watch() -> None: else: print("Testing agent ...") test_collector.reset() - result = test_collector.collect(n_episode=args.test_num, render=args.render) + result = test_collector.collect(n_episode=args.num_test_envs, render=args.render) result.pprint_asdict() if args.watch: @@ -211,7 +211,7 @@ def watch() -> None: max_epochs=args.epoch, epoch_num_steps=args.epoch_num_steps, collection_step_num_env_steps=args.collection_step_num_env_steps, - test_step_num_episodes=args.test_num, + test_step_num_episodes=args.num_test_envs, batch_size=args.batch_size, train_fn=train_fn, stop_fn=stop_fn, diff --git a/examples/vizdoom/vizdoom_ppo.py b/examples/vizdoom/vizdoom_ppo.py index 68efc1eb6..499f548a1 100644 --- a/examples/vizdoom/vizdoom_ppo.py +++ b/examples/vizdoom/vizdoom_ppo.py @@ -112,7 +112,7 @@ def test_ppo(args: argparse.Namespace = get_args()) -> None: args.save_lmp, args.seed, args.num_train_envs, - args.test_num, + args.num_test_envs, ) args.state_shape = env.observation_space.shape args.action_shape = env.action_space.n @@ -267,7 +267,7 @@ def watch() -> None: else: print("Testing agent ...") test_collector.reset() - result = test_collector.collect(n_episode=args.test_num, render=args.render) + result = 
test_collector.collect(n_episode=args.num_test_envs, render=args.render) result.pprint_asdict() if args.watch: @@ -286,7 +286,7 @@ def watch() -> None: max_epochs=args.epoch, epoch_num_steps=args.epoch_num_steps, update_step_num_repetitions=args.update_step_num_repetitions, - test_step_num_episodes=args.test_num, + test_step_num_episodes=args.num_test_envs, batch_size=args.batch_size, collection_step_num_env_steps=args.collection_step_num_env_steps, stop_fn=stop_fn, diff --git a/test/continuous/test_ddpg.py b/test/continuous/test_ddpg.py index 4139d1926..325b4ce23 100644 --- a/test/continuous/test_ddpg.py +++ b/test/continuous/test_ddpg.py @@ -65,7 +65,7 @@ def test_ddpg(args: argparse.Namespace = get_args(), enable_assertions: bool = T env.spec.reward_threshold if env.spec else None, ) train_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.num_train_envs)]) - test_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.test_num)]) + test_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.num_test_envs)]) # seed np.random.seed(args.seed) @@ -131,7 +131,7 @@ def stop_fn(mean_rewards: float) -> bool: max_epochs=args.epoch, epoch_num_steps=args.epoch_num_steps, collection_step_num_env_steps=args.collection_step_num_env_steps, - test_step_num_episodes=args.test_num, + test_step_num_episodes=args.num_test_envs, batch_size=args.batch_size, update_step_num_gradient_steps_per_sample=args.update_per_step, stop_fn=stop_fn, diff --git a/test/continuous/test_npg.py b/test/continuous/test_npg.py index 93ab52afd..e2ce35cd0 100644 --- a/test/continuous/test_npg.py +++ b/test/continuous/test_npg.py @@ -71,7 +71,7 @@ def test_npg(args: argparse.Namespace = get_args(), enable_assertions: bool = Tr env.spec.reward_threshold if env.spec else None, ) train_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.num_train_envs)]) - test_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in 
range(args.test_num)]) + test_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.num_test_envs)]) # seed np.random.seed(args.seed) @@ -151,7 +151,7 @@ def stop_fn(mean_rewards: float) -> bool: max_epochs=args.epoch, epoch_num_steps=args.epoch_num_steps, update_step_num_repetitions=args.update_step_num_repetitions, - test_step_num_episodes=args.test_num, + test_step_num_episodes=args.num_test_envs, batch_size=args.batch_size, collection_step_num_env_steps=args.collection_step_num_env_steps, stop_fn=stop_fn, diff --git a/test/continuous/test_redq.py b/test/continuous/test_redq.py index dd0065ec3..5a2538033 100644 --- a/test/continuous/test_redq.py +++ b/test/continuous/test_redq.py @@ -74,7 +74,7 @@ def test_redq(args: argparse.Namespace = get_args(), enable_assertions: bool = T # train_envs = gym.make(args.task) train_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.num_train_envs)]) # test_envs = gym.make(args.task) - test_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.test_num)]) + test_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.num_test_envs)]) # seed np.random.seed(args.seed) torch.manual_seed(args.seed) @@ -159,7 +159,7 @@ def stop_fn(mean_rewards: float) -> bool: max_epochs=args.epoch, epoch_num_steps=args.epoch_num_steps, collection_step_num_env_steps=args.collection_step_num_env_steps, - test_step_num_episodes=args.test_num, + test_step_num_episodes=args.num_test_envs, batch_size=args.batch_size, update_step_num_gradient_steps_per_sample=args.update_per_step, stop_fn=stop_fn, diff --git a/test/continuous/test_sac_with_il.py b/test/continuous/test_sac_with_il.py index 31fd973c1..0a0440d1d 100644 --- a/test/continuous/test_sac_with_il.py +++ b/test/continuous/test_sac_with_il.py @@ -72,10 +72,10 @@ def test_sac_with_il( ) -> None: # if you want to use python vector env, please refer to other test scripts # train_envs = env = envpool.make_gymnasium(args.task, 
num_envs=args.num_train_envs, seed=args.seed) - # test_envs = envpool.make_gymnasium(args.task, num_envs=args.test_num, seed=args.seed) + # test_envs = envpool.make_gymnasium(args.task, num_envs=args.num_test_envs, seed=args.seed) env = gym.make(args.task) train_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.num_train_envs)]) - test_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.test_num)]) + test_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.num_test_envs)]) space_info = SpaceInfo.from_env(env) args.state_shape = space_info.observation_info.obs_shape args.action_shape = space_info.action_info.action_shape @@ -165,7 +165,7 @@ def stop_fn(mean_rewards: float) -> bool: max_epochs=args.epoch, epoch_num_steps=args.epoch_num_steps, collection_step_num_env_steps=args.collection_step_num_env_steps, - test_step_num_episodes=args.test_num, + test_step_num_episodes=args.num_test_envs, batch_size=args.batch_size, update_step_num_gradient_steps_per_sample=args.update_per_step, stop_fn=stop_fn, @@ -205,10 +205,10 @@ def stop_fn(mean_rewards: float) -> bool: optim=optim, ) il_test_env = gym.make(args.task) - il_test_env.reset(seed=args.seed + args.num_train_envs + args.test_num) + il_test_env.reset(seed=args.seed + args.num_train_envs + args.num_test_envs) il_test_collector = Collector[CollectStats]( il_algorithm, - # envpool.make_gymnasium(args.task, num_envs=args.test_num, seed=args.seed), + # envpool.make_gymnasium(args.task, num_envs=args.num_test_envs, seed=args.seed), il_test_env, ) train_collector.reset() @@ -219,7 +219,7 @@ def stop_fn(mean_rewards: float) -> bool: max_epochs=args.epoch, epoch_num_steps=args.il_epoch_num_steps, collection_step_num_env_steps=args.collection_step_num_env_steps, - test_step_num_episodes=args.test_num, + test_step_num_episodes=args.num_test_envs, batch_size=args.batch_size, stop_fn=stop_fn, save_best_fn=save_best_fn, diff --git a/test/continuous/test_td3.py 
b/test/continuous/test_td3.py index 60678f5c5..df532aed4 100644 --- a/test/continuous/test_td3.py +++ b/test/continuous/test_td3.py @@ -70,7 +70,7 @@ def test_td3(args: argparse.Namespace = get_args(), enable_assertions: bool = Tr # train_envs = gym.make(args.task) train_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.num_train_envs)]) # test_envs = gym.make(args.task) - test_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.test_num)]) + test_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.num_test_envs)]) # seed np.random.seed(args.seed) torch.manual_seed(args.seed) @@ -147,7 +147,7 @@ def stop_fn(mean_rewards: float) -> bool: max_epochs=args.epoch, epoch_num_steps=args.epoch_num_steps, collection_step_num_env_steps=args.collection_step_num_env_steps, - test_step_num_episodes=args.test_num, + test_step_num_episodes=args.num_test_envs, batch_size=args.batch_size, update_step_num_gradient_steps_per_sample=args.update_per_step, stop_fn=stop_fn, diff --git a/test/continuous/test_trpo.py b/test/continuous/test_trpo.py index 855a53def..b5e24ad30 100644 --- a/test/continuous/test_trpo.py +++ b/test/continuous/test_trpo.py @@ -74,7 +74,7 @@ def test_trpo(args: argparse.Namespace = get_args(), enable_assertions: bool = T # train_envs = gym.make(args.task) train_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.num_train_envs)]) # test_envs = gym.make(args.task) - test_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.test_num)]) + test_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.num_test_envs)]) # seed np.random.seed(args.seed) torch.manual_seed(args.seed) @@ -153,7 +153,7 @@ def stop_fn(mean_rewards: float) -> bool: max_epochs=args.epoch, epoch_num_steps=args.epoch_num_steps, update_step_num_repetitions=args.update_step_num_repetitions, - test_step_num_episodes=args.test_num, + test_step_num_episodes=args.num_test_envs, 
batch_size=args.batch_size, collection_step_num_env_steps=args.collection_step_num_env_steps, stop_fn=stop_fn, diff --git a/test/discrete/test_a2c_with_il.py b/test/discrete/test_a2c_with_il.py index ff255c096..923205d7f 100644 --- a/test/discrete/test_a2c_with_il.py +++ b/test/discrete/test_a2c_with_il.py @@ -81,7 +81,7 @@ def test_a2c_with_il( test_envs = envpool.make( args.task, env_type="gymnasium", - num_envs=args.test_num, + num_envs=args.num_test_envs, seed=args.seed, ) else: @@ -89,7 +89,7 @@ def test_a2c_with_il( train_envs = DummyVectorEnv( [lambda: gym.make(args.task) for _ in range(args.num_train_envs)] ) - test_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.test_num)]) + test_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.num_test_envs)]) train_envs.seed(args.seed) test_envs.seed(args.seed) args.state_shape = env.observation_space.shape or env.observation_space.n @@ -148,7 +148,7 @@ def stop_fn(mean_rewards: float) -> bool: max_epochs=args.epoch, epoch_num_steps=args.epoch_num_steps, update_step_num_repetitions=args.update_step_num_repetitions, - test_step_num_episodes=args.test_num, + test_step_num_episodes=args.num_test_envs, batch_size=args.batch_size, collection_step_num_episodes=args.collection_step_num_episodes, collection_step_num_env_steps=None, @@ -183,12 +183,12 @@ def stop_fn(mean_rewards: float) -> bool: il_env = envpool.make( args.task, env_type="gymnasium", - num_envs=args.test_num, + num_envs=args.num_test_envs, seed=args.seed, ) else: il_env = DummyVectorEnv( - [lambda: gym.make(args.task) for _ in range(args.test_num)], + [lambda: gym.make(args.task) for _ in range(args.num_test_envs)], ) il_env.seed(args.seed) @@ -204,7 +204,7 @@ def stop_fn(mean_rewards: float) -> bool: max_epochs=args.epoch, epoch_num_steps=args.il_epoch_num_steps, collection_step_num_env_steps=args.collection_step_num_env_steps, - test_step_num_episodes=args.test_num, + test_step_num_episodes=args.num_test_envs, 
batch_size=args.batch_size, stop_fn=stop_fn, save_best_fn=save_best_fn, diff --git a/test/discrete/test_bdqn.py b/test/discrete/test_bdqn.py index 1bf93cd87..ef766707c 100644 --- a/test/discrete/test_bdqn.py +++ b/test/discrete/test_bdqn.py @@ -82,7 +82,7 @@ def test_bdq(args: argparse.Namespace = get_args(), enable_assertions: bool = Tr test_envs = DummyVectorEnv( [ lambda: ContinuousToDiscrete(gym.make(args.task), args.action_per_branch) - for _ in range(args.test_num) + for _ in range(args.num_test_envs) ], ) @@ -142,7 +142,7 @@ def stop_fn(mean_rewards: float) -> bool: max_epochs=args.epoch, epoch_num_steps=args.epoch_num_steps, collection_step_num_env_steps=args.collection_step_num_env_steps, - test_step_num_episodes=args.test_num, + test_step_num_episodes=args.num_test_envs, batch_size=args.batch_size, update_step_num_gradient_steps_per_sample=args.update_per_step, train_fn=train_fn, diff --git a/test/discrete/test_c51.py b/test/discrete/test_c51.py index ed828d0ae..5e0977e6d 100644 --- a/test/discrete/test_c51.py +++ b/test/discrete/test_c51.py @@ -78,7 +78,7 @@ def test_c51(args: argparse.Namespace = get_args(), enable_assertions: bool = Tr env.spec.reward_threshold if env.spec else None, ) train_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.num_train_envs)]) - test_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.test_num)]) + test_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.num_test_envs)]) # seed np.random.seed(args.seed) @@ -195,7 +195,7 @@ def save_checkpoint_fn(epoch: int, env_step: int, gradient_step: int) -> str: max_epochs=args.epoch, epoch_num_steps=args.epoch_num_steps, collection_step_num_env_steps=args.collection_step_num_env_steps, - test_step_num_episodes=args.test_num, + test_step_num_episodes=args.num_test_envs, batch_size=args.batch_size, update_step_num_gradient_steps_per_sample=args.update_per_step, train_fn=train_fn, diff --git a/test/discrete/test_discrete_sac.py 
b/test/discrete/test_discrete_sac.py index dcec3863f..ffd59afa5 100644 --- a/test/discrete/test_discrete_sac.py +++ b/test/discrete/test_discrete_sac.py @@ -74,7 +74,7 @@ def test_discrete_sac( ) train_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.num_train_envs)]) - test_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.test_num)]) + test_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.num_test_envs)]) # seed np.random.seed(args.seed) torch.manual_seed(args.seed) @@ -145,7 +145,7 @@ def stop_fn(mean_rewards: float) -> bool: max_epochs=args.epoch, epoch_num_steps=args.epoch_num_steps, collection_step_num_env_steps=args.collection_step_num_env_steps, - test_step_num_episodes=args.test_num, + test_step_num_episodes=args.num_test_envs, batch_size=args.batch_size, stop_fn=stop_fn, save_best_fn=save_best_fn, diff --git a/test/discrete/test_dqn.py b/test/discrete/test_dqn.py index 947fa5e66..092763e5b 100644 --- a/test/discrete/test_dqn.py +++ b/test/discrete/test_dqn.py @@ -72,7 +72,7 @@ def test_dqn(args: argparse.Namespace = get_args(), enable_assertions: bool = Tr env.spec.reward_threshold if env.spec else None, ) train_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.num_train_envs)]) - test_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.test_num)]) + test_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.num_test_envs)]) # seed np.random.seed(args.seed) @@ -154,7 +154,7 @@ def train_fn(epoch: int, env_step: int) -> None: max_epochs=args.epoch, epoch_num_steps=args.epoch_num_steps, collection_step_num_env_steps=args.collection_step_num_env_steps, - test_step_num_episodes=args.test_num, + test_step_num_episodes=args.num_test_envs, batch_size=args.batch_size, update_step_num_gradient_steps_per_sample=args.update_per_step, train_fn=train_fn, diff --git a/test/discrete/test_drqn.py b/test/discrete/test_drqn.py index 95f250bbd..89b6185f8 
100644 --- a/test/discrete/test_drqn.py +++ b/test/discrete/test_drqn.py @@ -67,7 +67,7 @@ def test_drqn(args: argparse.Namespace = get_args(), enable_assertions: bool = T # you can also use tianshou.env.SubprocVectorEnv train_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.num_train_envs)]) # test_envs = gym.make(args.task) - test_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.test_num)]) + test_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.num_test_envs)]) # seed np.random.seed(args.seed) torch.manual_seed(args.seed) @@ -129,7 +129,7 @@ def stop_fn(mean_rewards: float) -> bool: max_epochs=args.epoch, epoch_num_steps=args.epoch_num_steps, collection_step_num_env_steps=args.collection_step_num_env_steps, - test_step_num_episodes=args.test_num, + test_step_num_episodes=args.num_test_envs, batch_size=args.batch_size, update_step_num_gradient_steps_per_sample=args.update_per_step, stop_fn=stop_fn, diff --git a/test/discrete/test_fqf.py b/test/discrete/test_fqf.py index c2fd72531..19fb5768b 100644 --- a/test/discrete/test_fqf.py +++ b/test/discrete/test_fqf.py @@ -77,7 +77,7 @@ def test_fqf(args: argparse.Namespace = get_args(), enable_assertions: bool = Tr env.spec.reward_threshold if env.spec else None, ) train_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.num_train_envs)]) - test_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.test_num)]) + test_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.num_test_envs)]) # seed np.random.seed(args.seed) @@ -169,7 +169,7 @@ def train_fn(epoch: int, env_step: int) -> None: max_epochs=args.epoch, epoch_num_steps=args.epoch_num_steps, collection_step_num_env_steps=args.collection_step_num_env_steps, - test_step_num_episodes=args.test_num, + test_step_num_episodes=args.num_test_envs, batch_size=args.batch_size, train_fn=train_fn, stop_fn=stop_fn, diff --git a/test/discrete/test_iqn.py 
b/test/discrete/test_iqn.py index bcbe8dee8..0dadeb6bd 100644 --- a/test/discrete/test_iqn.py +++ b/test/discrete/test_iqn.py @@ -77,7 +77,7 @@ def test_iqn(args: argparse.Namespace = get_args(), enable_assertions: bool = Tr env.spec.reward_threshold if env.spec else None, ) train_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.num_train_envs)]) - test_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.test_num)]) + test_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.num_test_envs)]) # seed np.random.seed(args.seed) @@ -165,7 +165,7 @@ def train_fn(epoch: int, env_step: int) -> None: max_epochs=args.epoch, epoch_num_steps=args.epoch_num_steps, collection_step_num_env_steps=args.collection_step_num_env_steps, - test_step_num_episodes=args.test_num, + test_step_num_episodes=args.num_test_envs, batch_size=args.batch_size, train_fn=train_fn, stop_fn=stop_fn, diff --git a/test/discrete/test_ppo_discrete.py b/test/discrete/test_ppo_discrete.py index 23f60ed87..cb1e31c9f 100644 --- a/test/discrete/test_ppo_discrete.py +++ b/test/discrete/test_ppo_discrete.py @@ -78,7 +78,7 @@ def test_ppo(args: argparse.Namespace = get_args(), enable_assertions: bool = Tr # you can also use tianshou.env.SubprocVectorEnv train_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.num_train_envs)]) # test_envs = gym.make(args.task) - test_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.test_num)]) + test_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.num_test_envs)]) # seed np.random.seed(args.seed) torch.manual_seed(args.seed) @@ -152,7 +152,7 @@ def stop_fn(mean_rewards: float) -> bool: max_epochs=args.epoch, epoch_num_steps=args.epoch_num_steps, update_step_num_repetitions=args.update_step_num_repetitions, - test_step_num_episodes=args.test_num, + test_step_num_episodes=args.num_test_envs, batch_size=args.batch_size, 
collection_step_num_env_steps=args.collection_step_num_env_steps, stop_fn=stop_fn, diff --git a/test/discrete/test_qrdqn.py b/test/discrete/test_qrdqn.py index c9ddcd985..f44562b2b 100644 --- a/test/discrete/test_qrdqn.py +++ b/test/discrete/test_qrdqn.py @@ -79,7 +79,7 @@ def test_qrdqn(args: argparse.Namespace = get_args(), enable_assertions: bool = # you can also use tianshou.env.SubprocVectorEnv train_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.num_train_envs)]) # test_envs = gym.make(args.task) - test_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.test_num)]) + test_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.num_test_envs)]) # seed np.random.seed(args.seed) torch.manual_seed(args.seed) @@ -160,7 +160,7 @@ def train_fn(epoch: int, env_step: int) -> None: max_epochs=args.epoch, epoch_num_steps=args.epoch_num_steps, collection_step_num_env_steps=args.collection_step_num_env_steps, - test_step_num_episodes=args.test_num, + test_step_num_episodes=args.num_test_envs, batch_size=args.batch_size, train_fn=train_fn, stop_fn=stop_fn, diff --git a/test/discrete/test_rainbow.py b/test/discrete/test_rainbow.py index 418780908..4666d2299 100644 --- a/test/discrete/test_rainbow.py +++ b/test/discrete/test_rainbow.py @@ -85,7 +85,7 @@ def test_rainbow(args: argparse.Namespace = get_args(), enable_assertions: bool # you can also use tianshou.env.SubprocVectorEnv train_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.num_train_envs)]) # test_envs = gym.make(args.task) - test_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.test_num)]) + test_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.num_test_envs)]) # seed np.random.seed(args.seed) torch.manual_seed(args.seed) @@ -214,7 +214,7 @@ def save_checkpoint_fn(epoch: int, env_step: int, gradient_step: int) -> str: max_epochs=args.epoch, epoch_num_steps=args.epoch_num_steps, 
collection_step_num_env_steps=args.collection_step_num_env_steps, - test_step_num_episodes=args.test_num, + test_step_num_episodes=args.num_test_envs, batch_size=args.batch_size, update_step_num_gradient_steps_per_sample=args.update_per_step, train_fn=train_fn, diff --git a/test/discrete/test_reinforce.py b/test/discrete/test_reinforce.py index e234d30fa..df0e1cf53 100644 --- a/test/discrete/test_reinforce.py +++ b/test/discrete/test_reinforce.py @@ -59,7 +59,7 @@ def test_reinforce(args: argparse.Namespace = get_args(), enable_assertions: boo env.spec.reward_threshold if env.spec else None, ) train_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.num_train_envs)]) - test_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.test_num)]) + test_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.num_test_envs)]) # seed np.random.seed(args.seed) @@ -120,7 +120,7 @@ def stop_fn(mean_rewards: float) -> bool: max_epochs=args.epoch, epoch_num_steps=args.epoch_num_steps, update_step_num_repetitions=args.update_step_num_repetitions, - test_step_num_episodes=args.test_num, + test_step_num_episodes=args.num_test_envs, batch_size=args.batch_size, collection_step_num_episodes=args.collection_step_num_episodes, collection_step_num_env_steps=None, diff --git a/test/modelbased/test_dqn_icm.py b/test/modelbased/test_dqn_icm.py index 38640edfe..e343fd586 100644 --- a/test/modelbased/test_dqn_icm.py +++ b/test/modelbased/test_dqn_icm.py @@ -89,7 +89,7 @@ def test_dqn_icm(args: argparse.Namespace = get_args()) -> None: env.spec.reward_threshold if env.spec else None, ) train_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.num_train_envs)]) - test_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.test_num)]) + test_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.num_test_envs)]) # seed np.random.seed(args.seed) @@ -194,7 +194,7 @@ def train_fn(epoch: int, 
env_step: int) -> None: max_epochs=args.epoch, epoch_num_steps=args.epoch_num_steps, collection_step_num_env_steps=args.collection_step_num_env_steps, - test_step_num_episodes=args.test_num, + test_step_num_episodes=args.num_test_envs, batch_size=args.batch_size, update_step_num_gradient_steps_per_sample=args.update_per_step, train_fn=train_fn, diff --git a/test/modelbased/test_ppo_icm.py b/test/modelbased/test_ppo_icm.py index 556bde6b6..1d99bf05c 100644 --- a/test/modelbased/test_ppo_icm.py +++ b/test/modelbased/test_ppo_icm.py @@ -95,7 +95,7 @@ def test_ppo(args: argparse.Namespace = get_args()) -> None: ) train_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.num_train_envs)]) - test_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.test_num)]) + test_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.num_test_envs)]) # seed np.random.seed(args.seed) @@ -193,7 +193,7 @@ def stop_fn(mean_rewards: float) -> bool: max_epochs=args.epoch, epoch_num_steps=args.epoch_num_steps, update_step_num_repetitions=args.update_step_num_repetitions, - test_step_num_episodes=args.test_num, + test_step_num_episodes=args.num_test_envs, batch_size=args.batch_size, collection_step_num_env_steps=args.collection_step_num_env_steps, stop_fn=stop_fn, diff --git a/test/modelbased/test_psrl.py b/test/modelbased/test_psrl.py index a84718075..0a3efe8fc 100644 --- a/test/modelbased/test_psrl.py +++ b/test/modelbased/test_psrl.py @@ -53,7 +53,7 @@ def test_psrl(args: argparse.Namespace = get_args()) -> None: train_envs = env = envpool.make_gymnasium( args.task, num_envs=args.num_train_envs, seed=args.seed ) - test_envs = envpool.make_gymnasium(args.task, num_envs=args.test_num, seed=args.seed) + test_envs = envpool.make_gymnasium(args.task, num_envs=args.num_test_envs, seed=args.seed) if args.reward_threshold is None: default_reward_threshold = {"NChain-v0": 3400} args.reward_threshold = default_reward_threshold.get(args.task, 
env.spec.reward_threshold) @@ -121,7 +121,7 @@ def stop_fn(mean_rewards: float) -> bool: max_epochs=args.epoch, epoch_num_steps=args.epoch_num_steps, update_step_num_repetitions=1, - test_step_num_episodes=args.test_num, + test_step_num_episodes=args.num_test_envs, batch_size=0, collection_step_num_episodes=args.collection_step_num_episodes, collection_step_num_env_steps=None, diff --git a/test/offline/gather_cartpole_data.py b/test/offline/gather_cartpole_data.py index cda914cee..aed2b9428 100644 --- a/test/offline/gather_cartpole_data.py +++ b/test/offline/gather_cartpole_data.py @@ -82,7 +82,7 @@ def gather_data() -> VectorReplayBuffer | PrioritizedVectorReplayBuffer: # you can also use tianshou.env.SubprocVectorEnv train_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.num_train_envs)]) # test_envs = gym.make(args.task) - test_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.test_num)]) + test_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.num_test_envs)]) # seed np.random.seed(args.seed) torch.manual_seed(args.seed) @@ -157,7 +157,7 @@ def train_fn(epoch: int, env_step: int) -> None: max_epochs=args.epoch, epoch_num_steps=args.epoch_num_steps, collection_step_num_env_steps=args.collection_step_num_env_steps, - test_step_num_episodes=args.test_num, + test_step_num_episodes=args.num_test_envs, batch_size=args.batch_size, train_fn=train_fn, stop_fn=stop_fn, diff --git a/test/offline/gather_pendulum_data.py b/test/offline/gather_pendulum_data.py index 77b4ff4ee..8e7b56b4f 100644 --- a/test/offline/gather_pendulum_data.py +++ b/test/offline/gather_pendulum_data.py @@ -84,7 +84,7 @@ def gather_data() -> VectorReplayBuffer: # train_envs = gym.make(args.task) train_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.num_train_envs)]) # test_envs = gym.make(args.task) - test_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.test_num)]) + test_envs = 
DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.num_test_envs)]) # seed np.random.seed(args.seed) torch.manual_seed(args.seed) @@ -152,7 +152,7 @@ def stop_fn(mean_rewards: float) -> bool: max_epochs=args.epoch, epoch_num_steps=args.epoch_num_steps, collection_step_num_env_steps=args.collection_step_num_env_steps, - test_step_num_episodes=args.test_num, + test_step_num_episodes=args.num_test_envs, batch_size=args.batch_size, update_step_num_gradient_steps_per_sample=args.update_per_step, save_best_fn=save_best_fn, diff --git a/test/offline/test_bcq.py b/test/offline/test_bcq.py index 1d5ce433c..a33442a49 100644 --- a/test/offline/test_bcq.py +++ b/test/offline/test_bcq.py @@ -91,7 +91,7 @@ def test_bcq(args: argparse.Namespace = get_args(), enable_assertions: bool = Tr ) # test_envs = gym.make(args.task) - test_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.test_num)]) + test_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.num_test_envs)]) # seed np.random.seed(args.seed) @@ -196,7 +196,7 @@ def watch() -> None: test_collector=test_collector, max_epochs=args.epoch, epoch_num_steps=args.epoch_num_steps, - test_step_num_episodes=args.test_num, + test_step_num_episodes=args.num_test_envs, batch_size=args.batch_size, save_best_fn=save_best_fn, stop_fn=stop_fn, diff --git a/test/offline/test_cql.py b/test/offline/test_cql.py index 803d99c4f..d320b2bfb 100644 --- a/test/offline/test_cql.py +++ b/test/offline/test_cql.py @@ -97,7 +97,7 @@ def test_cql(args: argparse.Namespace = get_args(), enable_assertions: bool = Tr ) # test_envs = gym.make(args.task) - test_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.test_num)]) + test_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.num_test_envs)]) # seed np.random.seed(args.seed) torch.manual_seed(args.seed) @@ -188,7 +188,7 @@ def stop_fn(mean_rewards: float) -> bool: test_collector=test_collector, max_epochs=args.epoch, 
epoch_num_steps=args.epoch_num_steps, - test_step_num_episodes=args.test_num, + test_step_num_episodes=args.num_test_envs, batch_size=args.batch_size, save_best_fn=save_best_fn, stop_fn=stop_fn, diff --git a/test/offline/test_discrete_bcq.py b/test/offline/test_discrete_bcq.py index 1874a36fa..5c48ea017 100644 --- a/test/offline/test_discrete_bcq.py +++ b/test/offline/test_discrete_bcq.py @@ -72,7 +72,7 @@ def test_discrete_bcq( args.task, env.spec.reward_threshold if env.spec else None, ) - test_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.test_num)]) + test_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.num_test_envs)]) # seed np.random.seed(args.seed) @@ -162,7 +162,7 @@ def save_checkpoint_fn(epoch: int, env_step: int, gradient_step: int) -> str: test_collector=test_collector, max_epochs=args.epoch, epoch_num_steps=args.epoch_num_steps, - test_step_num_episodes=args.test_num, + test_step_num_episodes=args.num_test_envs, batch_size=args.batch_size, stop_fn=stop_fn, save_best_fn=save_best_fn, diff --git a/test/offline/test_discrete_cql.py b/test/offline/test_discrete_cql.py index 7fa959d11..12c0af017 100644 --- a/test/offline/test_discrete_cql.py +++ b/test/offline/test_discrete_cql.py @@ -69,7 +69,7 @@ def test_discrete_cql( args.task, env.spec.reward_threshold if env.spec else None, ) - test_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.test_num)]) + test_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.num_test_envs)]) # seed np.random.seed(args.seed) @@ -131,7 +131,7 @@ def stop_fn(mean_rewards: float) -> bool: test_collector=test_collector, max_epochs=args.epoch, epoch_num_steps=args.epoch_num_steps, - test_step_num_episodes=args.test_num, + test_step_num_episodes=args.num_test_envs, batch_size=args.batch_size, stop_fn=stop_fn, save_best_fn=save_best_fn, diff --git a/test/offline/test_discrete_crr.py b/test/offline/test_discrete_crr.py index aea9afce7..b547cc3d5 
100644 --- a/test/offline/test_discrete_crr.py +++ b/test/offline/test_discrete_crr.py @@ -67,7 +67,7 @@ def test_discrete_crr( args.task, env.spec.reward_threshold if env.spec else None, ) - test_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.test_num)]) + test_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.num_test_envs)]) # seed np.random.seed(args.seed) @@ -132,7 +132,7 @@ def stop_fn(mean_rewards: float) -> bool: test_collector=test_collector, max_epochs=args.epoch, epoch_num_steps=args.epoch_num_steps, - test_step_num_episodes=args.test_num, + test_step_num_episodes=args.num_test_envs, batch_size=args.batch_size, stop_fn=stop_fn, save_best_fn=save_best_fn, diff --git a/test/offline/test_gail.py b/test/offline/test_gail.py index 56e9d9d5d..a54fb5d3d 100644 --- a/test/offline/test_gail.py +++ b/test/offline/test_gail.py @@ -85,7 +85,7 @@ def test_gail(args: argparse.Namespace = get_args(), enable_assertions: bool = T args.action_shape = space_info.action_info.action_shape args.max_action = space_info.action_info.max_action train_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.num_train_envs)]) - test_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.test_num)]) + test_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.num_test_envs)]) # seed np.random.seed(args.seed) torch.manual_seed(args.seed) @@ -206,7 +206,7 @@ def save_checkpoint_fn(epoch: int, env_step: int, gradient_step: int) -> str: max_epochs=args.epoch, epoch_num_steps=args.epoch_num_steps, update_step_num_repetitions=args.update_step_num_repetitions, - test_step_num_episodes=args.test_num, + test_step_num_episodes=args.num_test_envs, batch_size=args.batch_size, collection_step_num_episodes=args.collection_step_num_episodes, collection_step_num_env_steps=None, diff --git a/test/offline/test_td3_bc.py b/test/offline/test_td3_bc.py index 9c1ed615b..dfe4d6b70 100644 --- 
a/test/offline/test_td3_bc.py +++ b/test/offline/test_td3_bc.py @@ -88,7 +88,7 @@ def test_td3_bc(args: argparse.Namespace = get_args(), enable_assertions: bool = args.state_dim = space_info.action_info.action_dim args.action_dim = space_info.observation_info.obs_dim - test_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.test_num)]) + test_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.num_test_envs)]) # seed np.random.seed(args.seed) @@ -178,7 +178,7 @@ def stop_fn(mean_rewards: float) -> bool: test_collector=test_collector, max_epochs=args.epoch, epoch_num_steps=args.epoch_num_steps, - test_step_num_episodes=args.test_num, + test_step_num_episodes=args.num_test_envs, batch_size=args.batch_size, save_best_fn=save_best_fn, stop_fn=stop_fn, diff --git a/test/pettingzoo/pistonball.py b/test/pettingzoo/pistonball.py index e1dc3f5d4..6b6f6edcd 100644 --- a/test/pettingzoo/pistonball.py +++ b/test/pettingzoo/pistonball.py @@ -128,7 +128,7 @@ def train_agent( optims: list[torch.optim.Optimizer] | None = None, ) -> tuple[InfoStats, Algorithm]: train_envs = DummyVectorEnv([get_env for _ in range(args.num_train_envs)]) - test_envs = DummyVectorEnv([get_env for _ in range(args.test_num)]) + test_envs = DummyVectorEnv([get_env for _ in range(args.num_test_envs)]) # seed np.random.seed(args.seed) torch.manual_seed(args.seed) @@ -170,7 +170,7 @@ def reward_metric(rews: np.ndarray) -> np.ndarray: max_epochs=args.epoch, epoch_num_steps=args.epoch_num_steps, collection_step_num_env_steps=args.collection_step_num_env_steps, - test_step_num_episodes=args.test_num, + test_step_num_episodes=args.num_test_envs, batch_size=args.batch_size, stop_fn=stop_fn, save_best_fn=save_best_fn, diff --git a/test/pettingzoo/pistonball_continuous.py b/test/pettingzoo/pistonball_continuous.py index 5b5fa1555..5f0fb5839 100644 --- a/test/pettingzoo/pistonball_continuous.py +++ b/test/pettingzoo/pistonball_continuous.py @@ -234,7 +234,7 @@ def train_agent( 
optims: list[torch.optim.Optimizer] | None = None, ) -> tuple[InfoStats, Algorithm]: train_envs = DummyVectorEnv([get_env for _ in range(args.num_train_envs)]) - test_envs = DummyVectorEnv([get_env for _ in range(args.test_num)]) + test_envs = DummyVectorEnv([get_env for _ in range(args.num_test_envs)]) # seed np.random.seed(args.seed) torch.manual_seed(args.seed) @@ -275,7 +275,7 @@ def reward_metric(rews: np.ndarray) -> np.ndarray: max_epochs=args.epoch, epoch_num_steps=args.epoch_num_steps, update_step_num_repetitions=args.update_step_num_repetitions, - test_step_num_episodes=args.test_num, + test_step_num_episodes=args.num_test_envs, batch_size=args.batch_size, collection_step_num_episodes=args.collection_step_num_episodes, collection_step_num_env_steps=None, diff --git a/test/pettingzoo/tic_tac_toe.py b/test/pettingzoo/tic_tac_toe.py index cebabd4c1..de9cdeccc 100644 --- a/test/pettingzoo/tic_tac_toe.py +++ b/test/pettingzoo/tic_tac_toe.py @@ -162,7 +162,7 @@ def train_agent( optim: OptimizerFactory | None = None, ) -> tuple[InfoStats, OffPolicyAlgorithm]: train_envs = DummyVectorEnv([get_env for _ in range(args.num_train_envs)]) - test_envs = DummyVectorEnv([get_env for _ in range(args.test_num)]) + test_envs = DummyVectorEnv([get_env for _ in range(args.num_test_envs)]) # seed np.random.seed(args.seed) torch.manual_seed(args.seed) @@ -215,7 +215,7 @@ def reward_metric(rews: np.ndarray) -> np.ndarray: max_epochs=args.epoch, epoch_num_steps=args.epoch_num_steps, collection_step_num_env_steps=args.collection_step_num_env_steps, - test_step_num_episodes=args.test_num, + test_step_num_episodes=args.num_test_envs, batch_size=args.batch_size, stop_fn=stop_fn, save_best_fn=save_best_fn, From f68b6f56752ad3dbe6bf8e7e708a079068ab649d Mon Sep 17 00:00:00 2001 From: Dominik Jain Date: Mon, 19 May 2025 13:42:25 +0200 Subject: [PATCH 216/230] v2: Disable determinism tests for CI --- test/determinism_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git 
a/test/determinism_test.py b/test/determinism_test.py index f71bcb6e5..754c4a6c9 100644 --- a/test/determinism_test.py +++ b/test/determinism_test.py @@ -39,7 +39,7 @@ class AlgorithmDeterminismTest: 3. Inspect determinism_tests.log """ - ENABLED = True + ENABLED = False """ whether determinism tests are enabled. """ From 556224ed083c0f2d1e6582e83996456654228a72 Mon Sep 17 00:00:00 2001 From: Dominik Jain Date: Mon, 19 May 2025 21:30:05 +0200 Subject: [PATCH 217/230] v2: Update parameter names (mainly test_num -> num_test_envs) --- examples/atari/atari_dqn_hl.py | 4 ++-- examples/atari/atari_iqn_hl.py | 4 ++-- examples/atari/atari_ppo_hl.py | 4 ++-- examples/atari/atari_sac_hl.py | 4 ++-- examples/discrete/discrete_dqn.py | 10 +++++----- examples/mujoco/fetch_her_ddpg.py | 4 ++-- examples/mujoco/mujoco_a2c_hl.py | 4 ++-- examples/mujoco/mujoco_ddpg_hl.py | 4 ++-- examples/mujoco/mujoco_ppo_hl.py | 4 ++-- examples/mujoco/mujoco_redq_hl.py | 4 ++-- examples/mujoco/mujoco_reinforce_hl.py | 4 ++-- examples/mujoco/mujoco_sac_hl.py | 4 ++-- examples/mujoco/mujoco_td3_hl.py | 12 ++++++------ examples/mujoco/mujoco_trpo_hl.py | 4 ++-- examples/vizdoom/env.py | 8 ++++---- test/continuous/test_sac_with_il.py | 2 +- test/determinism_test.py | 2 +- test/discrete/test_a2c_with_il.py | 2 +- tianshou/env/atari/atari_wrapper.py | 4 ++-- 19 files changed, 44 insertions(+), 44 deletions(-) diff --git a/examples/atari/atari_dqn_hl.py b/examples/atari/atari_dqn_hl.py index 3fb6822e0..eaa213db1 100644 --- a/examples/atari/atari_dqn_hl.py +++ b/examples/atari/atari_dqn_hl.py @@ -43,7 +43,7 @@ def main( update_per_step: float = 0.1, batch_size: int = 32, num_train_envs: int = 10, - test_num: int = 10, + num_test_envs: int = 10, frames_stack: int = 4, icm_lr_scale: float = 0.0, icm_reward_scale: float = 0.01, @@ -56,7 +56,7 @@ def main( epoch_num_steps=epoch_num_steps, batch_size=batch_size, num_train_envs=num_train_envs, - num_test_envs=test_num, + num_test_envs=num_test_envs, 
buffer_size=buffer_size, collection_step_num_env_steps=collection_step_num_env_steps, update_step_num_gradient_steps_per_sample=update_per_step, diff --git a/examples/atari/atari_iqn_hl.py b/examples/atari/atari_iqn_hl.py index 818675d49..ec884ba9c 100644 --- a/examples/atari/atari_iqn_hl.py +++ b/examples/atari/atari_iqn_hl.py @@ -44,7 +44,7 @@ def main( update_per_step: float = 0.1, batch_size: int = 32, num_train_envs: int = 10, - test_num: int = 10, + num_test_envs: int = 10, frames_stack: int = 4, ) -> None: log_name = os.path.join(task, "iqn", str(experiment_config.seed), datetime_tag()) @@ -54,7 +54,7 @@ def main( epoch_num_steps=epoch_num_steps, batch_size=batch_size, num_train_envs=num_train_envs, - num_test_envs=test_num, + num_test_envs=num_test_envs, buffer_size=buffer_size, collection_step_num_env_steps=collection_step_num_env_steps, update_step_num_gradient_steps_per_sample=update_per_step, diff --git a/examples/atari/atari_ppo_hl.py b/examples/atari/atari_ppo_hl.py index 0066df305..9079fe844 100644 --- a/examples/atari/atari_ppo_hl.py +++ b/examples/atari/atari_ppo_hl.py @@ -37,7 +37,7 @@ def main( batch_size: int = 256, hidden_sizes: Sequence[int] = (512,), num_train_envs: int = 10, - test_num: int = 10, + num_test_envs: int = 10, return_scaling: bool = False, vf_coef: float = 0.25, ent_coef: float = 0.01, @@ -62,7 +62,7 @@ def main( epoch_num_steps=epoch_num_steps, batch_size=batch_size, num_train_envs=num_train_envs, - num_test_envs=test_num, + num_test_envs=num_test_envs, buffer_size=buffer_size, collection_step_num_env_steps=collection_step_num_env_steps, update_step_num_repetitions=update_step_num_repetitions, diff --git a/examples/atari/atari_sac_hl.py b/examples/atari/atari_sac_hl.py index a7e798415..97ab6a1f1 100644 --- a/examples/atari/atari_sac_hl.py +++ b/examples/atari/atari_sac_hl.py @@ -43,7 +43,7 @@ def main( batch_size: int = 64, hidden_sizes: Sequence[int] = (512,), num_train_envs: int = 10, - test_num: int = 10, + num_test_envs: 
int = 10, frames_stack: int = 4, icm_lr_scale: float = 0.0, icm_reward_scale: float = 0.01, @@ -57,7 +57,7 @@ def main( update_step_num_gradient_steps_per_sample=update_per_step, batch_size=batch_size, num_train_envs=num_train_envs, - num_test_envs=test_num, + num_test_envs=num_test_envs, buffer_size=buffer_size, collection_step_num_env_steps=collection_step_num_env_steps, replay_buffer_stack_num=frames_stack, diff --git a/examples/discrete/discrete_dqn.py b/examples/discrete/discrete_dqn.py index 03fab73a3..05604b626 100644 --- a/examples/discrete/discrete_dqn.py +++ b/examples/discrete/discrete_dqn.py @@ -12,7 +12,7 @@ def main() -> None: task = "CartPole-v1" lr, epoch, batch_size = 1e-3, 10, 64 - train_num, test_num = 10, 100 + num_train_envs, num_test_envs = 10, 100 gamma, n_step, target_freq = 0.9, 3, 320 buffer_size = 20000 eps_train, eps_test = 0.1, 0.05 @@ -22,8 +22,8 @@ def main() -> None: # For other loggers, see https://tianshou.readthedocs.io/en/master/tutorials/logger.html # You can also try SubprocVectorEnv, which will use parallelization - train_envs = ts.env.DummyVectorEnv([lambda: gym.make(task) for _ in range(train_num)]) - test_envs = ts.env.DummyVectorEnv([lambda: gym.make(task) for _ in range(test_num)]) + train_envs = ts.env.DummyVectorEnv([lambda: gym.make(task) for _ in range(num_train_envs)]) + test_envs = ts.env.DummyVectorEnv([lambda: gym.make(task) for _ in range(num_test_envs)]) from tianshou.utils.net.common import Net @@ -50,7 +50,7 @@ def main() -> None: train_collector = ts.data.Collector[CollectStats]( algorithm, train_envs, - ts.data.VectorReplayBuffer(buffer_size, train_num), + ts.data.VectorReplayBuffer(buffer_size, num_train_envs), exploration_noise=True, ) test_collector = ts.data.Collector[CollectStats]( @@ -74,7 +74,7 @@ def stop_fn(mean_rewards: float) -> bool: max_epochs=epoch, epoch_num_steps=epoch_num_steps, collection_step_num_env_steps=collection_step_num_env_steps, - test_step_num_episodes=test_num, + 
test_step_num_episodes=num_test_envs, batch_size=batch_size, update_step_num_gradient_steps_per_sample=1 / collection_step_num_env_steps, stop_fn=stop_fn, diff --git a/examples/mujoco/fetch_her_ddpg.py b/examples/mujoco/fetch_her_ddpg.py index c77eaec97..05fef72c6 100644 --- a/examples/mujoco/fetch_her_ddpg.py +++ b/examples/mujoco/fetch_her_ddpg.py @@ -83,14 +83,14 @@ def get_args() -> argparse.Namespace: def make_fetch_env( task: str, num_train_envs: int, - test_num: int, + num_test_envs: int, ) -> tuple[gym.Env, BaseVectorEnv, BaseVectorEnv]: env = TruncatedAsTerminated(gym.make(task)) train_envs = ShmemVectorEnv( [lambda: TruncatedAsTerminated(gym.make(task)) for _ in range(num_train_envs)], ) test_envs = ShmemVectorEnv( - [lambda: TruncatedAsTerminated(gym.make(task)) for _ in range(test_num)], + [lambda: TruncatedAsTerminated(gym.make(task)) for _ in range(num_test_envs)], ) return env, train_envs, test_envs diff --git a/examples/mujoco/mujoco_a2c_hl.py b/examples/mujoco/mujoco_a2c_hl.py index 15573c90d..6a94ef110 100644 --- a/examples/mujoco/mujoco_a2c_hl.py +++ b/examples/mujoco/mujoco_a2c_hl.py @@ -32,7 +32,7 @@ def main( update_step_num_repetitions: int = 1, batch_size: int = 16, num_train_envs: int = 16, - test_num: int = 10, + num_test_envs: int = 10, return_scaling: bool = True, vf_coef: float = 0.5, ent_coef: float = 0.01, @@ -48,7 +48,7 @@ def main( epoch_num_steps=epoch_num_steps, batch_size=batch_size, num_train_envs=num_train_envs, - num_test_envs=test_num, + num_test_envs=num_test_envs, buffer_size=buffer_size, collection_step_num_env_steps=collection_step_num_env_steps, update_step_num_repetitions=update_step_num_repetitions, diff --git a/examples/mujoco/mujoco_ddpg_hl.py b/examples/mujoco/mujoco_ddpg_hl.py index 43afd139e..414faa145 100644 --- a/examples/mujoco/mujoco_ddpg_hl.py +++ b/examples/mujoco/mujoco_ddpg_hl.py @@ -34,7 +34,7 @@ def main( n_step: int = 1, batch_size: int = 256, num_train_envs: int = 1, - test_num: int = 10, + 
num_test_envs: int = 10, ) -> None: log_name = os.path.join(task, "ddpg", str(experiment_config.seed), datetime_tag()) @@ -43,7 +43,7 @@ def main( epoch_num_steps=epoch_num_steps, batch_size=batch_size, num_train_envs=num_train_envs, - num_test_envs=test_num, + num_test_envs=num_test_envs, buffer_size=buffer_size, collection_step_num_env_steps=collection_step_num_env_steps, update_step_num_gradient_steps_per_sample=update_per_step, diff --git a/examples/mujoco/mujoco_ppo_hl.py b/examples/mujoco/mujoco_ppo_hl.py index 19ed339ab..d402044cd 100644 --- a/examples/mujoco/mujoco_ppo_hl.py +++ b/examples/mujoco/mujoco_ppo_hl.py @@ -31,7 +31,7 @@ def main( update_step_num_repetitions: int = 10, batch_size: int = 64, num_train_envs: int = 10, - test_num: int = 10, + num_test_envs: int = 10, return_scaling: bool = True, vf_coef: float = 0.25, ent_coef: float = 0.0, @@ -52,7 +52,7 @@ def main( epoch_num_steps=epoch_num_steps, batch_size=batch_size, num_train_envs=num_train_envs, - num_test_envs=test_num, + num_test_envs=num_test_envs, buffer_size=buffer_size, collection_step_num_env_steps=collection_step_num_env_steps, update_step_num_repetitions=update_step_num_repetitions, diff --git a/examples/mujoco/mujoco_redq_hl.py b/examples/mujoco/mujoco_redq_hl.py index 0537535d2..deb7270da 100644 --- a/examples/mujoco/mujoco_redq_hl.py +++ b/examples/mujoco/mujoco_redq_hl.py @@ -40,7 +40,7 @@ def main( batch_size: int = 256, target_mode: Literal["mean", "min"] = "min", num_train_envs: int = 1, - test_num: int = 10, + num_test_envs: int = 10, ) -> None: log_name = os.path.join(task, "redq", str(experiment_config.seed), datetime_tag()) @@ -49,7 +49,7 @@ def main( epoch_num_steps=epoch_num_steps, batch_size=batch_size, num_train_envs=num_train_envs, - num_test_envs=test_num, + num_test_envs=num_test_envs, buffer_size=buffer_size, collection_step_num_env_steps=collection_step_num_env_steps, update_step_num_gradient_steps_per_sample=update_per_step, diff --git 
a/examples/mujoco/mujoco_reinforce_hl.py b/examples/mujoco/mujoco_reinforce_hl.py index 963f53002..5edf6fc55 100644 --- a/examples/mujoco/mujoco_reinforce_hl.py +++ b/examples/mujoco/mujoco_reinforce_hl.py @@ -31,7 +31,7 @@ def main( update_step_num_repetitions: int = 1, batch_size: int | None = None, num_train_envs: int = 10, - test_num: int = 10, + num_test_envs: int = 10, return_scaling: bool = True, action_bound_method: Literal["clip", "tanh"] = "tanh", lr_decay: bool = True, @@ -43,7 +43,7 @@ def main( epoch_num_steps=epoch_num_steps, batch_size=batch_size, num_train_envs=num_train_envs, - num_test_envs=test_num, + num_test_envs=num_test_envs, buffer_size=buffer_size, collection_step_num_env_steps=collection_step_num_env_steps, update_step_num_repetitions=update_step_num_repetitions, diff --git a/examples/mujoco/mujoco_sac_hl.py b/examples/mujoco/mujoco_sac_hl.py index da3e77655..84b51ed25 100644 --- a/examples/mujoco/mujoco_sac_hl.py +++ b/examples/mujoco/mujoco_sac_hl.py @@ -36,7 +36,7 @@ def main( n_step: int = 1, batch_size: int = 256, num_train_envs: int = 1, - test_num: int = 10, + num_test_envs: int = 10, ) -> None: log_name = os.path.join(task, "sac", str(experiment_config.seed), datetime_tag()) @@ -44,7 +44,7 @@ def main( max_epochs=epoch, epoch_num_steps=epoch_num_steps, num_train_envs=num_train_envs, - num_test_envs=test_num, + num_test_envs=num_test_envs, buffer_size=buffer_size, batch_size=batch_size, collection_step_num_env_steps=collection_step_num_env_steps, diff --git a/examples/mujoco/mujoco_td3_hl.py b/examples/mujoco/mujoco_td3_hl.py index 8a0483441..56898319e 100644 --- a/examples/mujoco/mujoco_td3_hl.py +++ b/examples/mujoco/mujoco_td3_hl.py @@ -8,7 +8,7 @@ from sensai.util.logging import datetime_tag from examples.mujoco.mujoco_env import MujocoEnvFactory -from tianshou.highlevel.config import TrainingConfig +from tianshou.highlevel.config import OffPolicyTrainingConfig from tianshou.highlevel.experiment import ( ExperimentConfig, 
TD3ExperimentBuilder, @@ -37,23 +37,23 @@ def main( epoch: int = 200, epoch_num_steps: int = 5000, collection_step_num_env_steps: int = 1, - update_per_step: int = 1, + update_step_num_gradient_steps_per_sample: int = 1, n_step: int = 1, batch_size: int = 256, num_train_envs: int = 1, - test_num: int = 10, + num_test_envs: int = 10, ) -> None: log_name = os.path.join(task, "td3", str(experiment_config.seed), datetime_tag()) - training_config = TrainingConfig( + training_config = OffPolicyTrainingConfig( max_epochs=epoch, epoch_num_steps=epoch_num_steps, num_train_envs=num_train_envs, - num_test_envs=test_num, + num_test_envs=num_test_envs, buffer_size=buffer_size, batch_size=batch_size, collection_step_num_env_steps=collection_step_num_env_steps, - update_per_step=update_per_step, + update_step_num_gradient_steps_per_sample=update_step_num_gradient_steps_per_sample, start_timesteps=start_timesteps, start_timesteps_random=True, ) diff --git a/examples/mujoco/mujoco_trpo_hl.py b/examples/mujoco/mujoco_trpo_hl.py index afefff8c6..73c501ae8 100644 --- a/examples/mujoco/mujoco_trpo_hl.py +++ b/examples/mujoco/mujoco_trpo_hl.py @@ -31,7 +31,7 @@ def main( update_step_num_repetitions: int = 1, batch_size: int = 16, num_train_envs: int = 16, - test_num: int = 10, + num_test_envs: int = 10, return_scaling: bool = True, gae_lambda: float = 0.95, bound_action_method: Literal["clip", "tanh"] = "clip", @@ -49,7 +49,7 @@ def main( epoch_num_steps=epoch_num_steps, batch_size=batch_size, num_train_envs=num_train_envs, - num_test_envs=test_num, + num_test_envs=num_test_envs, buffer_size=buffer_size, collection_step_num_env_steps=collection_step_num_env_steps, update_step_num_repetitions=update_step_num_repetitions, diff --git a/examples/vizdoom/env.py b/examples/vizdoom/env.py index 295193408..f8dcc8816 100644 --- a/examples/vizdoom/env.py +++ b/examples/vizdoom/env.py @@ -134,11 +134,11 @@ def make_vizdoom_env( save_lmp: bool = False, seed: int | None = None, num_train_envs: int = 
10, - test_num: int = 10, + num_test_envs: int = 10, ) -> tuple[Env, ShmemVectorEnv, ShmemVectorEnv]: cpu_count = os.cpu_count() if cpu_count is not None: - test_num = min(cpu_count - 1, test_num) + num_test_envs = min(cpu_count - 1, num_test_envs) if envpool is not None: task_id = "".join([i.capitalize() for i in task.split("_")]) + "-v1" lmp_save_dir = "lmps/" if save_lmp else "" @@ -166,7 +166,7 @@ def make_vizdoom_env( stack_num=res[0], lmp_save_dir=lmp_save_dir, seed=seed, - num_envs=test_num, + num_envs=num_test_envs, reward_config=reward_config, use_combined_action=True, max_episode_steps=2625, @@ -179,7 +179,7 @@ def make_vizdoom_env( [lambda: Env(cfg_path, frame_skip, res) for _ in range(num_train_envs)], ) test_envs = ShmemVectorEnv( - [lambda: Env(cfg_path, frame_skip, res, save_lmp) for _ in range(test_num)], + [lambda: Env(cfg_path, frame_skip, res, save_lmp) for _ in range(num_test_envs)], ) train_envs.seed(seed) test_envs.seed(seed) diff --git a/test/continuous/test_sac_with_il.py b/test/continuous/test_sac_with_il.py index 0a0440d1d..5f68630fc 100644 --- a/test/continuous/test_sac_with_il.py +++ b/test/continuous/test_sac_with_il.py @@ -217,7 +217,7 @@ def stop_fn(mean_rewards: float) -> bool: train_collector=train_collector, test_collector=il_test_collector, max_epochs=args.epoch, - epoch_num_steps=args.il_epoch_num_steps, + epoch_num_steps=args.epoch_num_steps, collection_step_num_env_steps=args.collection_step_num_env_steps, test_step_num_episodes=args.num_test_envs, batch_size=args.batch_size, diff --git a/test/determinism_test.py b/test/determinism_test.py index 754c4a6c9..6a5deb566 100644 --- a/test/determinism_test.py +++ b/test/determinism_test.py @@ -93,7 +93,7 @@ def set(attr: str, value: Any) -> None: set("device", "cpu") if not is_offline: set("num_train_envs", 1) - set("test_num", 1) + set("num_test_envs", 1) self.args = args self.main_fn = main_fn diff --git a/test/discrete/test_a2c_with_il.py b/test/discrete/test_a2c_with_il.py index 
923205d7f..e865b0c7c 100644 --- a/test/discrete/test_a2c_with_il.py +++ b/test/discrete/test_a2c_with_il.py @@ -202,7 +202,7 @@ def stop_fn(mean_rewards: float) -> bool: train_collector=train_collector, test_collector=il_test_collector, max_epochs=args.epoch, - epoch_num_steps=args.il_epoch_num_steps, + epoch_num_steps=args.epoch_num_steps, collection_step_num_env_steps=args.collection_step_num_env_steps, test_step_num_episodes=args.num_test_envs, batch_size=args.batch_size, diff --git a/tianshou/env/atari/atari_wrapper.py b/tianshou/env/atari/atari_wrapper.py index 10f5599bd..320ca820a 100644 --- a/tianshou/env/atari/atari_wrapper.py +++ b/tianshou/env/atari/atari_wrapper.py @@ -380,7 +380,7 @@ def make_atari_env( task: str, seed: int, num_train_envs: int, - test_num: int, + num_test_envs: int, scale: int | bool = False, frame_stack: int = 4, ) -> tuple[Env, BaseVectorEnv, BaseVectorEnv]: @@ -391,7 +391,7 @@ def make_atari_env( :return: a tuple of (single env, training envs, test envs). 
""" env_factory = AtariEnvFactory(task, frame_stack, scale=bool(scale)) - envs = env_factory.create_envs(num_train_envs, test_num, seed=seed) + envs = env_factory.create_envs(num_train_envs, num_test_envs, seed=seed) return envs.env, envs.train_envs, envs.test_envs From d5960cf371c57dfd78e153558d81b0d870f2e065 Mon Sep 17 00:00:00 2001 From: Dominik Jain Date: Tue, 20 May 2025 03:13:15 +0200 Subject: [PATCH 218/230] v2: Fix logic error introduced in commit 03123510 --- tianshou/algorithm/modelfree/bdqn.py | 2 +- tianshou/algorithm/modelfree/dqn.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/tianshou/algorithm/modelfree/bdqn.py b/tianshou/algorithm/modelfree/bdqn.py index 4b09698d0..9ab9b34c3 100644 --- a/tianshou/algorithm/modelfree/bdqn.py +++ b/tianshou/algorithm/modelfree/bdqn.py @@ -83,7 +83,7 @@ def add_exploration_noise( batch: ObsBatchProtocol, ) -> TArrOrActBatch: eps = self.eps_training if self.is_within_training_step else self.eps_inference - if not np.isclose(eps, 0.0): + if np.isclose(eps, 0.0): return act if isinstance(act, np.ndarray): bsz = len(act) diff --git a/tianshou/algorithm/modelfree/dqn.py b/tianshou/algorithm/modelfree/dqn.py index a7b6be867..530de6ef0 100644 --- a/tianshou/algorithm/modelfree/dqn.py +++ b/tianshou/algorithm/modelfree/dqn.py @@ -156,7 +156,7 @@ def add_exploration_noise( batch: ObsBatchProtocol, ) -> TArrOrActBatch: eps = self.eps_training if self.is_within_training_step else self.eps_inference - if not np.isclose(eps, 0.0): + if np.isclose(eps, 0.0): return act if isinstance(act, np.ndarray): batch_size = len(act) From 6c3abb0cbc18988ea3d56f101f4308d642035842 Mon Sep 17 00:00:00 2001 From: Dominik Jain Date: Tue, 20 May 2025 12:41:13 +0200 Subject: [PATCH 219/230] v2: Handle nested algorithms in Algorithm.state_dict --- tianshou/algorithm/algorithm_base.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tianshou/algorithm/algorithm_base.py b/tianshou/algorithm/algorithm_base.py 
index eed5e4b49..50884d646 100644 --- a/tianshou/algorithm/algorithm_base.py +++ b/tianshou/algorithm/algorithm_base.py @@ -522,8 +522,9 @@ def state_dict(self, *args, destination=None, prefix="", keep_vars=False): # ty d = super().state_dict(*args, destination=destination, prefix=prefix, keep_vars=keep_vars) # add optimizer states - assert self._STATE_DICT_KEY_OPTIMIZERS not in d - d[self._STATE_DICT_KEY_OPTIMIZERS] = [o.state_dict() for o in self._optimizers] + opt_key = prefix + self._STATE_DICT_KEY_OPTIMIZERS + assert opt_key not in d + d[opt_key] = [o.state_dict() for o in self._optimizers] return d From 78d52ed580978221133169d858676b3a0b16a619 Mon Sep 17 00:00:00 2001 From: Dominik Jain Date: Tue, 20 May 2025 13:23:38 +0200 Subject: [PATCH 220/230] v2: Update identifier names (policy -> algorithm) --- tianshou/highlevel/algorithm.py | 16 ++++++++-------- tianshou/highlevel/experiment.py | 4 ++-- tianshou/highlevel/persistence.py | 4 ++-- tianshou/highlevel/world.py | 2 +- 4 files changed, 13 insertions(+), 13 deletions(-) diff --git a/tianshou/highlevel/algorithm.py b/tianshou/highlevel/algorithm.py index 3e22e5320..9282d1981 100644 --- a/tianshou/highlevel/algorithm.py +++ b/tianshou/highlevel/algorithm.py @@ -175,15 +175,15 @@ def _create_algorithm(self, envs: Environments, device: TDevice) -> Algorithm: pass def create_algorithm(self, envs: Environments, device: TDevice) -> Algorithm: - policy = self._create_algorithm(envs, device) + algorithm = self._create_algorithm(envs, device) if self.algorithm_wrapper_factory is not None: - policy = self.algorithm_wrapper_factory.create_wrapped_algorithm( - policy, + algorithm = self.algorithm_wrapper_factory.create_wrapped_algorithm( + algorithm, envs, self.optim_factory, device, ) - return policy + return algorithm @abstractmethod def create_trainer(self, world: World, policy_persistence: PolicyPersistence) -> Trainer: @@ -198,7 +198,7 @@ def create_trainer( ) -> OnPolicyTrainer: training_config = 
self.training_config callbacks = self.trainer_callbacks - context = TrainingContext(world.policy, world.envs, world.logger) + context = TrainingContext(world.algorithm, world.envs, world.logger) train_fn = ( callbacks.epoch_train_callback.get_trainer_fn(context) if callbacks.epoch_train_callback @@ -214,7 +214,7 @@ def create_trainer( if callbacks.epoch_stop_callback else None ) - algorithm = cast(OnPolicyAlgorithm, world.policy) + algorithm = cast(OnPolicyAlgorithm, world.algorithm) assert world.train_collector is not None return algorithm.create_trainer( OnPolicyTrainerParams( @@ -246,7 +246,7 @@ def create_trainer( ) -> OffPolicyTrainer: training_config = self.training_config callbacks = self.trainer_callbacks - context = TrainingContext(world.policy, world.envs, world.logger) + context = TrainingContext(world.algorithm, world.envs, world.logger) train_fn = ( callbacks.epoch_train_callback.get_trainer_fn(context) if callbacks.epoch_train_callback @@ -262,7 +262,7 @@ def create_trainer( if callbacks.epoch_stop_callback else None ) - algorithm = cast(OffPolicyAlgorithm, world.policy) + algorithm = cast(OffPolicyAlgorithm, world.algorithm) assert world.train_collector is not None return algorithm.create_trainer( OffPolicyTrainerParams( diff --git a/tianshou/highlevel/experiment.py b/tianshou/highlevel/experiment.py index 9b38f16e0..bcd933100 100644 --- a/tianshou/highlevel/experiment.py +++ b/tianshou/highlevel/experiment.py @@ -350,7 +350,7 @@ def create_experiment_world( # create context object with all relevant instances (except trainer; added later) world = World( envs=envs, - policy=policy, + algorithm=policy, train_collector=train_collector, test_collector=test_collector, logger=logger, @@ -448,7 +448,7 @@ def run( log.info("Watching agent performance") self._watch_agent( self.config.watch_num_episodes, - world.policy, + world.algorithm, world.envs.watch_env, self.config.watch_render, ) diff --git a/tianshou/highlevel/persistence.py 
b/tianshou/highlevel/persistence.py index 2758f5066..1c38d3602 100644 --- a/tianshou/highlevel/persistence.py +++ b/tianshou/highlevel/persistence.py @@ -141,10 +141,10 @@ def save_checkpoint_fn(epoch: int, env_step: int, gradient_step: int) -> str: match self.mode: case self.Mode.POLICY_STATE_DICT: log.info(f"Saving policy state dictionary in {path}") - torch.save(world.policy.state_dict(), path) + torch.save(world.algorithm.state_dict(), path) case self.Mode.POLICY: log.info(f"Saving policy object in {path}") - torch.save(world.policy, path) + torch.save(world.algorithm, path) case _: raise NotImplementedError if self.additional_persistence is not None: diff --git a/tianshou/highlevel/world.py b/tianshou/highlevel/world.py index 7f3521773..98ac46dea 100644 --- a/tianshou/highlevel/world.py +++ b/tianshou/highlevel/world.py @@ -15,7 +15,7 @@ class World: """Container for instances and configuration items that are relevant to an experiment.""" envs: "Environments" - policy: "Algorithm" + algorithm: "Algorithm" train_collector: Optional["BaseCollector"] = None test_collector: Optional["BaseCollector"] = None logger: "TLogger" From 989ecc67aa1da3261da4467bc5e5b5964b52dd08 Mon Sep 17 00:00:00 2001 From: Dominik Jain Date: Tue, 20 May 2025 13:26:01 +0200 Subject: [PATCH 221/230] v2: Rename hl module: policy_wrapper -> algorithm_wrapper --- examples/atari/atari_dqn_hl.py | 2 +- examples/atari/atari_ppo_hl.py | 2 +- examples/atari/atari_sac_hl.py | 2 +- tianshou/highlevel/algorithm.py | 2 +- tianshou/highlevel/experiment.py | 2 +- .../params/{policy_wrapper.py => algorithm_wrapper.py} | 0 6 files changed, 5 insertions(+), 5 deletions(-) rename tianshou/highlevel/params/{policy_wrapper.py => algorithm_wrapper.py} (100%) diff --git a/examples/atari/atari_dqn_hl.py b/examples/atari/atari_dqn_hl.py index eaa213db1..4310deff7 100644 --- a/examples/atari/atari_dqn_hl.py +++ b/examples/atari/atari_dqn_hl.py @@ -16,7 +16,7 @@ ExperimentConfig, ) from 
tianshou.highlevel.params.algorithm_params import DQNParams -from tianshou.highlevel.params.policy_wrapper import ( +from tianshou.highlevel.params.algorithm_wrapper import ( AlgorithmWrapperFactoryIntrinsicCuriosity, ) from tianshou.highlevel.trainer import ( diff --git a/examples/atari/atari_ppo_hl.py b/examples/atari/atari_ppo_hl.py index 9079fe844..dbbd6f7a9 100644 --- a/examples/atari/atari_ppo_hl.py +++ b/examples/atari/atari_ppo_hl.py @@ -18,7 +18,7 @@ ) from tianshou.highlevel.params.algorithm_params import PPOParams from tianshou.highlevel.params.lr_scheduler import LRSchedulerFactoryFactoryLinear -from tianshou.highlevel.params.policy_wrapper import ( +from tianshou.highlevel.params.algorithm_wrapper import ( AlgorithmWrapperFactoryIntrinsicCuriosity, ) diff --git a/examples/atari/atari_sac_hl.py b/examples/atari/atari_sac_hl.py index 97ab6a1f1..4b1376466 100644 --- a/examples/atari/atari_sac_hl.py +++ b/examples/atari/atari_sac_hl.py @@ -18,7 +18,7 @@ ) from tianshou.highlevel.params.algorithm_params import DiscreteSACParams from tianshou.highlevel.params.alpha import AutoAlphaFactoryDefault -from tianshou.highlevel.params.policy_wrapper import ( +from tianshou.highlevel.params.algorithm_wrapper import ( AlgorithmWrapperFactoryIntrinsicCuriosity, ) diff --git a/tianshou/highlevel/algorithm.py b/tianshou/highlevel/algorithm.py index 9282d1981..a05dc8ff3 100644 --- a/tianshou/highlevel/algorithm.py +++ b/tianshou/highlevel/algorithm.py @@ -69,7 +69,7 @@ TD3Params, TRPOParams, ) -from tianshou.highlevel.params.policy_wrapper import AlgorithmWrapperFactory +from tianshou.highlevel.params.algorithm_wrapper import AlgorithmWrapperFactory from tianshou.highlevel.persistence import PolicyPersistence from tianshou.highlevel.trainer import TrainerCallbacks, TrainingContext from tianshou.highlevel.world import World diff --git a/tianshou/highlevel/experiment.py b/tianshou/highlevel/experiment.py index bcd933100..d97b92250 100644 --- 
a/tianshou/highlevel/experiment.py +++ b/tianshou/highlevel/experiment.py @@ -101,7 +101,7 @@ TD3Params, TRPOParams, ) -from tianshou.highlevel.params.policy_wrapper import AlgorithmWrapperFactory +from tianshou.highlevel.params.algorithm_wrapper import AlgorithmWrapperFactory from tianshou.highlevel.persistence import ( PersistenceGroup, PolicyPersistence, diff --git a/tianshou/highlevel/params/policy_wrapper.py b/tianshou/highlevel/params/algorithm_wrapper.py similarity index 100% rename from tianshou/highlevel/params/policy_wrapper.py rename to tianshou/highlevel/params/algorithm_wrapper.py From 3fb51cd3cdd4f02ab3c266cd49144b765b8d8201 Mon Sep 17 00:00:00 2001 From: Dominik Jain Date: Tue, 20 May 2025 13:29:53 +0200 Subject: [PATCH 222/230] v2: HL: Move optim module to params package --- examples/atari/atari_ppo_hl.py | 2 +- examples/atari/atari_sac_hl.py | 2 +- examples/mujoco/mujoco_a2c_hl.py | 2 +- tianshou/highlevel/algorithm.py | 2 +- tianshou/highlevel/experiment.py | 8 ++++---- tianshou/highlevel/params/algorithm_params.py | 2 +- tianshou/highlevel/params/algorithm_wrapper.py | 2 +- tianshou/highlevel/params/alpha.py | 2 +- tianshou/highlevel/{ => params}/optim.py | 0 9 files changed, 11 insertions(+), 11 deletions(-) rename tianshou/highlevel/{ => params}/optim.py (100%) diff --git a/examples/atari/atari_ppo_hl.py b/examples/atari/atari_ppo_hl.py index dbbd6f7a9..393040e54 100644 --- a/examples/atari/atari_ppo_hl.py +++ b/examples/atari/atari_ppo_hl.py @@ -17,10 +17,10 @@ PPOExperimentBuilder, ) from tianshou.highlevel.params.algorithm_params import PPOParams -from tianshou.highlevel.params.lr_scheduler import LRSchedulerFactoryFactoryLinear from tianshou.highlevel.params.algorithm_wrapper import ( AlgorithmWrapperFactoryIntrinsicCuriosity, ) +from tianshou.highlevel.params.lr_scheduler import LRSchedulerFactoryFactoryLinear def main( diff --git a/examples/atari/atari_sac_hl.py b/examples/atari/atari_sac_hl.py index 4b1376466..b21ed5e44 100644 --- 
a/examples/atari/atari_sac_hl.py +++ b/examples/atari/atari_sac_hl.py @@ -17,10 +17,10 @@ ExperimentConfig, ) from tianshou.highlevel.params.algorithm_params import DiscreteSACParams -from tianshou.highlevel.params.alpha import AutoAlphaFactoryDefault from tianshou.highlevel.params.algorithm_wrapper import ( AlgorithmWrapperFactoryIntrinsicCuriosity, ) +from tianshou.highlevel.params.alpha import AutoAlphaFactoryDefault def main( diff --git a/examples/mujoco/mujoco_a2c_hl.py b/examples/mujoco/mujoco_a2c_hl.py index 6a94ef110..6922a1209 100644 --- a/examples/mujoco/mujoco_a2c_hl.py +++ b/examples/mujoco/mujoco_a2c_hl.py @@ -14,9 +14,9 @@ A2CExperimentBuilder, ExperimentConfig, ) -from tianshou.highlevel.optim import OptimizerFactoryFactoryRMSprop from tianshou.highlevel.params.algorithm_params import A2CParams from tianshou.highlevel.params.lr_scheduler import LRSchedulerFactoryFactoryLinear +from tianshou.highlevel.params.optim import OptimizerFactoryFactoryRMSprop def main( diff --git a/tianshou/highlevel/algorithm.py b/tianshou/highlevel/algorithm.py index a05dc8ff3..deaa2cd35 100644 --- a/tianshou/highlevel/algorithm.py +++ b/tianshou/highlevel/algorithm.py @@ -50,7 +50,6 @@ TDevice, ) from tianshou.highlevel.module.critic import CriticEnsembleFactory, CriticFactory -from tianshou.highlevel.optim import OptimizerFactoryFactory from tianshou.highlevel.params.algorithm_params import ( A2CParams, DDPGParams, @@ -70,6 +69,7 @@ TRPOParams, ) from tianshou.highlevel.params.algorithm_wrapper import AlgorithmWrapperFactory +from tianshou.highlevel.params.optim import OptimizerFactoryFactory from tianshou.highlevel.persistence import PolicyPersistence from tianshou.highlevel.trainer import TrainerCallbacks, TrainingContext from tianshou.highlevel.world import World diff --git a/tianshou/highlevel/experiment.py b/tianshou/highlevel/experiment.py index d97b92250..2828f818b 100644 --- a/tianshou/highlevel/experiment.py +++ b/tianshou/highlevel/experiment.py @@ -83,10 +83,6 
@@ ) from tianshou.highlevel.module.intermediate import IntermediateModuleFactory from tianshou.highlevel.module.special import ImplicitQuantileNetworkFactory -from tianshou.highlevel.optim import ( - OptimizerFactoryFactory, - OptimizerFactoryFactoryAdam, -) from tianshou.highlevel.params.algorithm_params import ( A2CParams, DDPGParams, @@ -102,6 +98,10 @@ TRPOParams, ) from tianshou.highlevel.params.algorithm_wrapper import AlgorithmWrapperFactory +from tianshou.highlevel.params.optim import ( + OptimizerFactoryFactory, + OptimizerFactoryFactoryAdam, +) from tianshou.highlevel.persistence import ( PersistenceGroup, PolicyPersistence, diff --git a/tianshou/highlevel/params/algorithm_params.py b/tianshou/highlevel/params/algorithm_params.py index 88fc7a060..4c1c81dad 100644 --- a/tianshou/highlevel/params/algorithm_params.py +++ b/tianshou/highlevel/params/algorithm_params.py @@ -8,11 +8,11 @@ from tianshou.exploration import BaseNoise from tianshou.highlevel.env import Environments from tianshou.highlevel.module.core import TDevice -from tianshou.highlevel.optim import OptimizerFactoryFactory from tianshou.highlevel.params.alpha import AutoAlphaFactory from tianshou.highlevel.params.env_param import EnvValueFactory, FloatEnvValueFactory from tianshou.highlevel.params.lr_scheduler import LRSchedulerFactoryFactory from tianshou.highlevel.params.noise import NoiseFactory +from tianshou.highlevel.params.optim import OptimizerFactoryFactory @dataclass(kw_only=True) diff --git a/tianshou/highlevel/params/algorithm_wrapper.py b/tianshou/highlevel/params/algorithm_wrapper.py index 9ff957d49..a5c287fd4 100644 --- a/tianshou/highlevel/params/algorithm_wrapper.py +++ b/tianshou/highlevel/params/algorithm_wrapper.py @@ -10,7 +10,7 @@ from tianshou.highlevel.env import Environments from tianshou.highlevel.module.core import TDevice from tianshou.highlevel.module.intermediate import IntermediateModuleFactory -from tianshou.highlevel.optim import OptimizerFactoryFactory +from 
tianshou.highlevel.params.optim import OptimizerFactoryFactory from tianshou.utils.net.discrete import IntrinsicCuriosityModule TAlgorithmOut = TypeVar("TAlgorithmOut", bound=Algorithm) diff --git a/tianshou/highlevel/params/alpha.py b/tianshou/highlevel/params/alpha.py index f1a36db3e..61c86cf24 100644 --- a/tianshou/highlevel/params/alpha.py +++ b/tianshou/highlevel/params/alpha.py @@ -6,7 +6,7 @@ from tianshou.algorithm.modelfree.sac import Alpha, AutoAlpha from tianshou.highlevel.env import Environments from tianshou.highlevel.module.core import TDevice -from tianshou.highlevel.optim import OptimizerFactoryFactory +from tianshou.highlevel.params.optim import OptimizerFactoryFactory class AutoAlphaFactory(ToStringMixin, ABC): diff --git a/tianshou/highlevel/optim.py b/tianshou/highlevel/params/optim.py similarity index 100% rename from tianshou/highlevel/optim.py rename to tianshou/highlevel/params/optim.py From 6ebb6def7d21d28e9091fc183b7e99ae93616471 Mon Sep 17 00:00:00 2001 From: Dominik Jain Date: Tue, 20 May 2025 15:25:39 +0200 Subject: [PATCH 223/230] v2: Add issue references --- CHANGELOG.md | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 7f1048f80..48de4525a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -28,7 +28,7 @@ Developers: * The interface has been streamlined with improved naming of functions/parameters and limiting the public interface to purely the methods and attributes a user should reasonably access. * Further changes potentially affecting usage: - * We dropped the iterator semantics: Method `__next__` has been replaced by `execute_epoch`. + * We dropped the iterator semantics: Method `__next__` has been replaced by `execute_epoch`. #913 * We no longer report outdated statistics (e.g. 
on rewards/returns when a training step does not collect any full episodes) * See also "Issues resolved" below (as issue resolution can result in usage changes) @@ -94,16 +94,16 @@ Developers: * `MARLRandomPolicy` -> `MARLRandomDiscreteMaskedOffPolicyAlgorithm` For the respective subtype of `Policy` to use, see the respective algorithm class' constructor. * Interface changes/improvements: - * Core methods have been renamed (and removed from the public interface): + * Core methods have been renamed (and removed from the public interface; #898): * `process_fn` -> `_preprocess_batch` * `post_process_fn` -> `_postprocess_batch` * `learn` -> `_update_with_batch` - * The updating interface has been cleaned up: + * The updating interface has been cleaned up (#949): * Functions `update` and `_update_with_batch` (formerly `learn`) no longer have `*args` and `**kwargs`. * Instead, the interfaces for the offline, off-policy and on-policy cases are properly differentiated. * New method `run_training`: The `Algorithm` abstraction can now directly initiate the learning process via this method. * `Algorithms` no longer require `torch.optim.Optimizer` instances and instead require `OptimizerFactory` - instances, which create the actual optimizers internally. + instances, which create the actual optimizers internally. #959 The new `OptimizerFactory` abstraction simultaneously handles the creation of learning rate schedulers for the optimizers created (via method `with_lr_scheduler_factory` and accompanying factory abstraction `LRSchedulerFactory`). 
@@ -196,12 +196,12 @@ Developers: ### Peripheral Changes -* The `Actor` classes have been renamed for clarity: +* The `Actor` classes have been renamed for clarity (#1091): * `BaseActor` -> `Actor` * `continuous.ActorProb` -> `ContinuousActorProbabilistic` * `coninuous.Actor` -> `ContinuousActorDeterministic` * `discrete.Actor` -> `DiscreteActor` -* The `Critic` classes have been renamed for clarity: +* The `Critic` classes have been renamed for clarity (#1091): * `continuous.Critic` -> `ContinuousCritic` * `discrete.Critic` -> `DiscreteCritic` * Moved Atari helper modules `atari_network` and `atari_wrapper` to the library under `tianshou.env.atari`. @@ -218,9 +218,9 @@ Developers: dimension as an argument were changed to use `ModuleWithVectorOutput`. * The high-level API class `IntermediateModule` can now provide a `ModuleWithVectorOutput` instance (via adaptation if necessary). -* The class hierarchy of supporting `nn.Module` implementations was cleaned up: +* The class hierarchy of supporting `nn.Module` implementations was cleaned up (#1091): * With the fundamental base classes `ActionReprNet` and `ActionReprNetWithVectorOutput`, we etablished a - well-defined interface for the most commonly used `forward` interface in Tianshou's algorithms & policies. + well-defined interface for the most commonly used `forward` interface in Tianshou's algorithms & policies. 
#948 * Some network classes were renamed: * `ScaledObsInputModule` -> `ScaledObsInputActionReprNet` * `Rainbow` -> `RainbowNet` From 87cc6faf797168ccb3e8eca0fdb654b683ae9e13 Mon Sep 17 00:00:00 2001 From: Dominik Jain Date: Thu, 3 Jul 2025 14:00:21 +0200 Subject: [PATCH 224/230] v2: Set version to 2.0.0b1 --- pyproject.toml | 2 +- tianshou/__init__.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 7e8634da0..2b76846b1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "poetry.core.masonry.api" [tool.poetry] name = "tianshou" -version = "1.2.0" +version = "2.0.0b1" description = "A Library for Deep Reinforcement Learning" authors = ["TSAIL "] license = "MIT" diff --git a/tianshou/__init__.py b/tianshou/__init__.py index bee88ccca..3c719b1eb 100644 --- a/tianshou/__init__.py +++ b/tianshou/__init__.py @@ -2,7 +2,7 @@ # NOTE: Import order is important to avoid circular import errors! from tianshou import data, env, exploration, algorithm, trainer, utils -__version__ = "1.2.0" +__version__ = "2.0.0b1" def _register_log_config_callback() -> None: From 6bf4f031a8e50e7760af63b9d7acc7e2f9b37cad Mon Sep 17 00:00:00 2001 From: Michael Panchenko Date: Tue, 8 Jul 2025 13:36:58 +0200 Subject: [PATCH 225/230] v2: adjusted dqn.rst to reflect the new API --- docs/01_tutorials/00_dqn.rst | 101 +++++++++++++++++++++-------------- 1 file changed, 61 insertions(+), 40 deletions(-) diff --git a/docs/01_tutorials/00_dqn.rst b/docs/01_tutorials/00_dqn.rst index 79cd2d903..3c28e7163 100644 --- a/docs/01_tutorials/00_dqn.rst +++ b/docs/01_tutorials/00_dqn.rst @@ -129,10 +129,15 @@ Tianshou supports any user-defined PyTorch networks and optimizers. 
Yet, of cour logits = self.model(obs.view(batch, -1)) return logits, state - state_shape = env.observation_space.shape or env.observation_space.n - action_shape = env.action_space.shape or env.action_space.n - net = MLPActor(state_shape, action_shape) - optim = torch.optim.Adam(net.parameters(), lr=1e-3) + from tianshou.utils.net.common import Net + from tianshou.utils.space_info import SpaceInfo + from tianshou.algorithm.optim import AdamOptimizerFactory + + space_info = SpaceInfo.from_env(env) + state_shape = space_info.observation_info.obs_shape + action_shape = space_info.action_info.action_shape + net = Net(state_shape=state_shape, action_shape=action_shape, hidden_sizes=[128, 128, 128, 128]) + optim = AdamOptimizerFactory(lr=1e-3) You can also use pre-defined MLP networks in :mod:`~tianshou.utils.net.common`, :mod:`~tianshou.utils.net.discrete`, and :mod:`~tianshou.utils.net.continuous`. The rules of self-defined networks are: @@ -150,13 +155,22 @@ Setup Policy We use the defined ``net`` and ``optim`` above, with extra policy hyper-parameters, to define a policy. Here we define a DQN policy with a target network: :: - policy = ts.policy.DQNPolicy( + from tianshou.algorithm.modelfree.dqn import DiscreteQLearningPolicy + from tianshou.algorithm import DQN + + policy = DiscreteQLearningPolicy( model=net, - optim=optim, action_space=env.action_space, - discount_factor=0.9, - estimation_step=3, - target_update_freq=320 + observation_space=env.observation_space, + eps_training=0.1, + eps_inference=0.05, + ) + algorithm = DQN( + policy=policy, + optim=optim, + gamma=0.9, + n_step_return_horizon=3, + target_update_freq=320, ) @@ -170,8 +184,11 @@ The following code shows how to set up a collector in practice. 
It is worth noti :: - train_collector = ts.data.Collector(policy, train_envs, ts.data.VectorReplayBuffer(20000, 10), exploration_noise=True) - test_collector = ts.data.Collector(policy, test_envs, exploration_noise=True) + from tianshou.data import Collector, CollectStats, VectorReplayBuffer + + buf = VectorReplayBuffer(20000, buffer_num=len(train_envs)) + train_collector = Collector[CollectStats](algorithm, train_envs, buf, exploration_noise=True) + test_collector = Collector[CollectStats](algorithm, test_envs, exploration_noise=True) The main function of collector is the collect function, which can be summarized in the following lines: @@ -194,17 +211,29 @@ reaches the stop condition ``stop_fn`` on test collector. Since DQN is an off-po :class:`~tianshou.trainer.OffpolicyTrainer` as follows: :: - result = ts.trainer.OffpolicyTrainer( - policy=policy, - train_collector=train_collector, - test_collector=test_collector, - max_epoch=10, epoch_num_steps=10000, collection_step_num_env_steps=10, - update_per_step=0.1, episode_per_test=100, batch_size=64, - train_fn=lambda epoch, env_step: policy.set_eps(0.1), - test_fn=lambda epoch, env_step: policy.set_eps(0.05), - stop_fn=lambda mean_rewards: mean_rewards >= env.spec.reward_threshold - ).run() - print(f'Finished training! Use {result["duration"]}') + from tianshou.trainer import OffPolicyTrainerParams + + def train_fn(epoch, env_step): + policy.set_eps_training(0.1) + + def stop_fn(mean_rewards): + return mean_rewards >= env.spec.reward_threshold + + result = algorithm.run_training( + OffPolicyTrainerParams( + train_collector=train_collector, + test_collector=test_collector, + max_epochs=10, + epoch_num_steps=10000, + collection_step_num_env_steps=10, + test_step_num_episodes=100, + batch_size=64, + update_step_num_gradient_steps_per_sample=0.1, + train_fn=train_fn, + stop_fn=stop_fn, + ) + ) + print(f'Finished training! 
Use {result.duration}') The meaning of each parameter is as follows (full description can be found at :class:`~tianshou.trainer.OffpolicyTrainer`): @@ -232,18 +261,10 @@ The returned result is a dictionary as follows: :: { - 'train_step': 9246, - 'train_episode': 504.0, - 'train_time/collector': '0.65s', - 'train_time/model': '1.97s', - 'train_speed': '3518.79 step/s', - 'test_step': 49112, - 'test_episode': 400.0, - 'test_time': '1.38s', - 'test_speed': '35600.52 step/s', - 'best_reward': 199.03, - 'duration': '4.01s' - } + TrainingResult object with attributes like: + best_reward: 199.03 + duration: 4.01s + And other training statistics It shows that within approximately 4 seconds, we finished training a DQN agent on CartPole. The mean returns over 100 consecutive episodes is 199.03. @@ -265,8 +286,8 @@ Watch the Agent's Performance :: policy.eval() - policy.set_eps(0.05) - collector = ts.data.Collector(policy, env, exploration_noise=True) + policy.set_eps_inference(0.05) + collector = ts.data.Collector(algorithm, env, exploration_noise=True) collector.collect(n_episode=1, render=1 / 35) If you'd like to manually see the action generated by a well-trained agent: @@ -289,24 +310,24 @@ Tianshou supports user-defined training code. Here is the code snippet: # pre-collect at least 5000 transitions with random action before training train_collector.collect(n_step=5000, random=True) - policy.set_eps(0.1) + policy.set_eps_training(0.1) for i in range(int(1e6)): # total step collect_result = train_collector.collect(n_step=10) # once if the collected episodes' mean returns reach the threshold, # or every 1000 steps, we test it on test_collector if collect_result['rews'].mean() >= env.spec.reward_threshold or i % 1000 == 0: - policy.set_eps(0.05) + policy.set_eps_inference(0.05) result = test_collector.collect(n_episode=100) if result['rews'].mean() >= env.spec.reward_threshold: print(f'Finished training! 
Test mean returns: {result["rews"].mean()}') break else: # back to training eps - policy.set_eps(0.1) + policy.set_eps_training(0.1) # train policy with a sampled batch data from buffer - losses = policy.update(64, train_collector.buffer) + losses = algorithm.update(64, train_collector.buffer) For further usage, you can refer to the :doc:`/01_tutorials/07_cheatsheet`. From 16270cbcfc1b2bb9aa5cded4a4d7c72c22117418 Mon Sep 17 00:00:00 2001 From: Michael Panchenko Date: Mon, 14 Jul 2025 11:30:24 +0200 Subject: [PATCH 226/230] v2: Docs. Improved concepts_rst, mentioned that parts of the docs are outdated --- README.md | 6 +- docs/01_tutorials/01_concepts.rst | 167 ++++++++++++++-------------- docs/01_tutorials/07_cheatsheet.rst | 2 + docs/02_notebooks/0_intro.md | 2 + 4 files changed, 89 insertions(+), 88 deletions(-) diff --git a/README.md b/README.md index 573f1caae..02ef2f052 100644 --- a/README.md +++ b/README.md @@ -148,9 +148,11 @@ If no errors are reported, you have successfully installed Tianshou. ## Documentation -Tutorials and API documentation are hosted on [tianshou.readthedocs.io](https://tianshou.readthedocs.io/). +Find example scripts in the [test/]( https://github.com/thu-ml/tianshou/blob/master/test) and [examples/](https://github.com/thu-ml/tianshou/blob/master/examples) folders. -Find example scripts in the [test/](https://github.com/thu-ml/tianshou/blob/master/test) and [examples/](https://github.com/thu-ml/tianshou/blob/master/examples) folders. +Tutorials and API documentation are hosted on [tianshou.readthedocs.io](https://tianshou.readthedocs.io/). +**Important**: The documentation is currently being updated to reflect the changes in Tianshou v2.0.0. Not all features are documented yet, and some parts are outdated (they are marked as such). The documentation will be fully updated when +the v2.0.0 release is finalized. ## Why Tianshou? 
diff --git a/docs/01_tutorials/01_concepts.rst b/docs/01_tutorials/01_concepts.rst index 0f381262e..4204672b7 100644 --- a/docs/01_tutorials/01_concepts.rst +++ b/docs/01_tutorials/01_concepts.rst @@ -1,18 +1,22 @@ Basic concepts in Tianshou ========================== -Tianshou splits a Reinforcement Learning agent training procedure into these parts: trainer, collector, policy, and data buffer. The general control flow can be described as: +Tianshou splits a Reinforcement Learning agent training procedure into these parts: algorithm, trainer, collector, policy, a data buffer and batches from the buffer. +The algorithm encapsulates the specific RL learning method (e.g., DQN, PPO), which contains a policy and defines how to update it. -.. image:: /_static/images/concepts_arch.png - :align: center - :height: 300 +.. + The general control flow can be described as: + .. image:: /_static/images/concepts_arch.png + :align: center + :height: 300 -Here is a more detailed description, where ``Env`` is the environment and ``Model`` is the neural network: -.. image:: /_static/images/concepts_arch2.png - :align: center - :height: 300 + Here is a more detailed description, where ``Env`` is the environment and ``Model`` is the neural network: + + .. image:: /_static/images/concepts_arch2.png + :align: center + :height: 300 Batch @@ -220,19 +224,28 @@ The following code snippet illustrates the usage, including: Tianshou provides other type of data buffer such as :class:`~tianshou.data.PrioritizedReplayBuffer` (based on Segment Tree and ``numpy.ndarray``) and :class:`~tianshou.data.VectorReplayBuffer` (add different episodes' data but without losing chronological order). Check out :class:`~tianshou.data.ReplayBuffer` for more detail. -Policy ------- +Algorithm and Policy +-------------------- + +Tianshou's RL framework is built around two key abstractions: :class:`~tianshou.algorithm.Algorithm` and :class:`~tianshou.algorithm.Policy`. 
+ +**Algorithm**: The core abstraction that encapsulates a complete RL learning method (e.g., DQN, PPO, SAC). Each algorithm contains a policy and defines how to update it using training data. All algorithm classes inherit from :class:`~tianshou.algorithm.Algorithm`. + +An algorithm class typically has the following parts: -Tianshou aims to modularize RL algorithms. It comes into several classes of policies in Tianshou. All of the policy classes must inherit :class:`~tianshou.algorithm.BasePolicy`. +* :meth:`~tianshou.algorithm.Algorithm.__init__`: initialize the algorithm with a policy and optimization configuration; +* :meth:`~tianshou.algorithm.Algorithm._preprocess_batch`: pre-process data from the replay buffer (e.g., compute n-step returns); +* :meth:`~tianshou.algorithm.Algorithm._update_with_batch`: the algorithm-specific network update logic; +* :meth:`~tianshou.algorithm.Algorithm._postprocess_batch`: post-process the batch data (e.g., update prioritized replay buffer weights); +* :meth:`~tianshou.algorithm.Algorithm.create_trainer`: create the appropriate trainer for this algorithm; -A policy class typically has the following parts: +**Policy**: Represents the mapping from observations to actions. Policy classes inherit from :class:`~tianshou.algorithm.Policy`. -* :meth:`~tianshou.algorithm.BasePolicy.__init__`: initialize the policy, including copying the target network and so on; -* :meth:`~tianshou.algorithm.BasePolicy.forward`: compute action with given observation; -* :meth:`~tianshou.algorithm.BasePolicy.process_fn`: pre-process data from the replay buffer; -* :meth:`~tianshou.algorithm.BasePolicy.learn`: update policy with a given batch of data. -* :meth:`~tianshou.algorithm.BasePolicy.post_process_fn`: update the buffer with a given batch of data. -* :meth:`~tianshou.algorithm.BasePolicy.update`: the main interface for training. 
This function samples data from buffer, pre-process data (such as computing n-step return), learn with the data, and finally post-process the data (such as updating prioritized replay buffer); in short, ``process_fn -> learn -> post_process_fn``. +A policy class typically provides: + +* :meth:`~tianshou.algorithm.Policy.forward`: compute action distribution or Q-values given observations; +* :meth:`~tianshou.algorithm.Policy.compute_action`: get concrete actions from observations for environment interaction; +* :meth:`~tianshou.algorithm.Policy.map_action`: transform raw network outputs to environment action space; .. _policy_state: @@ -245,22 +258,10 @@ During the training process, the policy has two main states: training state and The meaning of training and testing state is obvious: the agent interacts with environment, collects training data and performs update, that's training state; the testing state is to evaluate the performance of the current policy during training process. As for the collecting state, it is defined as interacting with environments and collecting training data into the buffer; -we define the updating state as performing a model update by :meth:`~tianshou.algorithm.BasePolicy.update` during training process. - - -In order to distinguish these states, you can check the policy state by ``policy.training`` and ``policy.updating``. The state setting is as follows: +we define the updating state as performing a model update by the algorithm's update methods during training process. 
-+-----------------------------------+-----------------+-----------------+ -| State for policy | policy.training | policy.updating | -+================+==================+=================+=================+ -| | Collecting state | True | False | -| Training state +------------------+-----------------+-----------------+ -| | Updating state | True | True | -+----------------+------------------+-----------------+-----------------+ -| Testing state | False | False | -+-----------------------------------+-----------------+-----------------+ - -``policy.updating`` is helpful to distinguish the different exploration state, for example, in DQN we don't have to use epsilon-greedy in a pure network update, so ``policy.updating`` is helpful for setting epsilon in this case. +The collection of data from the env may differ in training and in inference (for example, in training one may add exploration noise, or sample from the predicted action distribution instead of taking its mode). The switch between the different collection strategies in training and inference is controlled by ``policy.is_within_training_step``, see also the docstring of it +for more details. policy.forward @@ -282,15 +283,17 @@ For example, if you try to use your policy to evaluate one episode (and don't wa act = policy(batch).act[0] # policy.forward return a batch, use ".act" to extract the action obs, rew, done, info = env.step(act) +For inference, it is recommended to use the shortcut method :meth:`~tianshou.algorithm.Policy.compute_action` to compute the action directly from the observation. + Here, ``Batch(obs=[obs])`` will automatically create the 0-dimension to be the batch-size. Otherwise, the network cannot determine the batch-size. .. _process_fn: -policy.process_fn -^^^^^^^^^^^^^^^^^ +Algorithm Preprocessing and N-step Returns +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -The ``process_fn`` function computes some variables that depends on time-series. 
For example, compute the N-step or GAE returns. +The algorithm handles data preprocessing, including computing variables that depend on time-series such as N-step or GAE returns. This functionality is implemented in :meth:`~tianshou.algorithm.Algorithm._preprocess_batch` and the static methods :meth:`~tianshou.algorithm.Algorithm.compute_nstep_return` and :meth:`~tianshou.algorithm.Algorithm.compute_episodic_return`. Take 2-step return DQN as an example. The 2-step return DQN compute each transition's return as: @@ -304,40 +307,19 @@ where :math:`\gamma` is the discount factor, :math:`\gamma \in [0, 1]`. Here is # pseudocode, cannot work obs = env.reset() buffer = Buffer(size=10000) - agent = DQN() + algorithm = DQN(...) for i in range(int(1e6)): - act = agent.compute_action(obs) + act = algorithm.policy.compute_action(obs) obs_next, rew, done, _ = env.step(act) buffer.store(obs, act, obs_next, rew, done) obs = obs_next if i % 1000 == 0: - b_s, b_a, b_s_, b_r, b_d = buffer.get(size=64) - # compute 2-step returns. How? - b_ret = compute_2_step_return(buffer, b_r, b_d, ...) - # update DQN policy - agent.update(b_s, b_a, b_s_, b_r, b_d, b_ret) - -Thus, we need a time-related interface for calculating the 2-step return. :meth:`~tianshou.algorithm.BasePolicy.process_fn` finishes this work by providing the replay buffer, the sample index, and the sample batch data. Since we store all the data in the order of time, you can simply compute the 2-step return as: -:: + # algorithm handles sampling, preprocessing, and updating + algorithm.update(sample_size=64, buffer=buffer) - class DQN_2step(BasePolicy): - """some code""" +The algorithm's :meth:`~tianshou.algorithm.Algorithm._preprocess_batch` method automatically handles n-step return computation by calling :meth:`~tianshou.algorithm.Algorithm.compute_nstep_return`, which provides the replay buffer, sample indices, and batch data. 
Since we store all the data in the order of time, the n-step return can be computed efficiently using the buffer's temporal structure. - def process_fn(self, batch, buffer, indices): - buffer_len = len(buffer) - batch_2 = buffer[(indices + 2) % buffer_len] - # this will return a batch data where batch_2.obs is s_t+2 - # we can also get s_t+2 through: - # batch_2_obs = buffer.obs[(indices + 2) % buffer_len] - # in short, buffer.obs[i] is equal to buffer[i].obs, but the former is more effecient. - Q = self(batch_2, eps=0) # shape: [batchsize, action_shape] - maxQ = Q.max(dim=-1) - batch.returns = batch.rew \ - + self._gamma * buffer.rew[(indices + 1) % buffer_len] \ - + self._gamma ** 2 * maxQ - return batch - -This code does not consider the done flag, so it may not work very well. It shows two ways to get :math:`s_{t + 2}` from the replay buffer easily in :meth:`~tianshou.algorithm.BasePolicy.process_fn`. +For custom preprocessing logic, you can override :meth:`~tianshou.algorithm.Algorithm._preprocess_batch` in your algorithm subclass. The method receives the sampled batch, buffer, and indices, allowing you to add computed values like returns, advantages, or other algorithm-specific preprocessing steps. Collector @@ -378,29 +360,33 @@ There is also another type of collector :class:`~tianshou.data.AsyncCollector` w Trainer ------- -Once you have a collector and a policy, you can start writing the training method for your RL agent. Trainer, to be honest, is a simple wrapper. It helps you save energy for writing the training loop. You can also construct your own trainer: :ref:`customized_trainer`. +Once you have an algorithm and a collector, you can start the training process. The trainer orchestrates the training loop and calls upon the algorithm's specific network updating logic. Each algorithm creates its appropriate trainer type through the :meth:`~tianshou.algorithm.Algorithm.create_trainer` method. 
-Tianshou has three types of trainer: :func:`~tianshou.trainer.onpolicy_trainer` for on-policy algorithms such as Policy Gradient, :func:`~tianshou.trainer.offpolicy_trainer` for off-policy algorithms such as DQN, and :func:`~tianshou.trainer.offline_trainer` for offline algorithms such as BCQ. Please check out :doc:`/03_api/trainer/index` for the usage. +Tianshou has three main trainer classes: :class:`~tianshou.trainer.OnPolicyTrainer` for on-policy algorithms such as Policy Gradient, :class:`~tianshou.trainer.OffPolicyTrainer` for off-policy algorithms such as DQN, and :class:`~tianshou.trainer.OfflineTrainer` for offline algorithms such as BCQ. -We also provide the corresponding iterator-based trainer classes :class:`~tianshou.trainer.OnPolicyTrainer`, :class:`~tianshou.trainer.OffpolicyTrainer`, :class:`~tianshou.trainer.OfflineTrainer` to facilitate users writing more flexible training logic: +The typical workflow is: :: - trainer = OnPolicyTrainer(...) - for epoch, epoch_stat, info in trainer: - print(f"Epoch: {epoch}") - print(epoch_stat) - print(info) - do_something_with_policy() - query_something_about_policy() - make_a_plot_with(epoch_stat) - display(info) + # Create algorithm with policy + algorithm = DQN(policy=policy, optim=optimizer_factory, ...) + + # Create trainer parameters + params = OffPolicyTrainerParams( + max_epochs=100, + epoch_num_steps=1000, + train_collector=train_collector, + test_collector=test_collector, + ... + ) + + # Run training (trainer is created automatically) + result = algorithm.run_training(params) - # or even iterate on several trainers at the same time +You can also create trainers manually for more control: +:: - trainer1 = OnPolicyTrainer(...) - trainer2 = OnPolicyTrainer(...) - for result1, result2, ... in zip(trainer1, trainer2, ...): - compare_results(result1, result2, ...) + trainer = algorithm.create_trainer(params) + result = trainer.run() .. 
_pseudocode: @@ -414,22 +400,31 @@ We give a high-level explanation through the pseudocode used in section :ref:`pr # pseudocode, cannot work # methods in tianshou obs = env.reset() buffer = Buffer(size=10000) # buffer = tianshou.data.ReplayBuffer(size=10000) - agent = DQN() # policy.__init__(...) + algorithm = DQN(policy=policy, ...) # algorithm.__init__(...) for i in range(int(1e6)): # done in trainer - act = agent.compute_action(obs) # act = policy(batch, ...).act + act = algorithm.policy.compute_action(obs) # act = policy.compute_action(obs) obs_next, rew, done, _ = env.step(act) # collector.collect(...) buffer.store(obs, act, obs_next, rew, done) # collector.collect(...) obs = obs_next # collector.collect(...) if i % 1000 == 0: # done in trainer - # the following is done in policy.update(batch_size, buffer) + # the following is done in algorithm.update(batch_size, buffer) b_s, b_a, b_s_, b_r, b_d = buffer.get(size=64) # batch, indices = buffer.sample(batch_size) # compute 2-step returns. How? - b_ret = compute_2_step_return(buffer, b_r, b_d, ...) # policy.process_fn(batch, buffer, indices) + b_ret = compute_2_step_return(buffer, b_r, b_d, ...) # algorithm._preprocess_batch(batch, buffer, indices) # update DQN policy - agent.update(b_s, b_a, b_s_, b_r, b_d, b_ret) # policy.learn(batch, ...) + algorithm.update(b_s, b_a, b_s_, b_r, b_d, b_ret) # algorithm._update_with_batch(batch) Conclusion ---------- -So far, we go through the overall framework of Tianshou. Really simple, isn't it? +So far, we've covered the overall framework of Tianshou with its new architecture centered around the Algorithm abstraction. 
The key components are: + +- **Algorithm**: Encapsulates the complete RL learning method, containing a policy and defining how to update it +- **Policy**: Handles the mapping from observations to actions +- **Collector**: Manages environment interaction and data collection +- **Trainer**: Orchestrates the training loop and calls the algorithm's update logic +- **Buffer**: Stores and manages experience data +- **Batch**: A flexible data structure for passing data between components. Batches are collected into the buffer by the Collector and sampled from it by the Algorithm, which uses them for learning. + +This modular design cleanly separates concerns while maintaining the flexibility to implement various RL algorithms. diff --git a/docs/01_tutorials/07_cheatsheet.rst b/docs/01_tutorials/07_cheatsheet.rst index 0215efe42..fc747d66f 100644 --- a/docs/01_tutorials/07_cheatsheet.rst +++ b/docs/01_tutorials/07_cheatsheet.rst @@ -1,6 +1,8 @@ Cheat Sheet =========== +**IMPORTANT**: The content here has not yet been adjusted to the v2 version of Tianshou. It is partially outdated and will be updated soon. + This page shows some code snippets of how to use Tianshou to develop new algorithms / apply algorithms to new scenarios. diff --git a/docs/02_notebooks/0_intro.md b/docs/02_notebooks/0_intro.md index e4d839c2a..e68b36e63 100644 --- a/docs/02_notebooks/0_intro.md +++ b/docs/02_notebooks/0_intro.md @@ -5,3 +5,5 @@ directly in colab, or download them and run them locally. They will guide you step by step to show you how the most basic modules in Tianshou work and how they collaborate with each other to conduct a classic DRL experiment. + +**IMPORTANT**: The notebooks are not yet adjusted to the v2 version of Tianshou! Their content is partly outdated and will be updated soon. From 3d5ab5f54c28ba8ee588adb9d415bfe0cb449e6c Mon Sep 17 00:00:00 2001 From: Michael Panchenko Date: Mon, 14 Jul 2025 14:14:43 +0200 Subject: [PATCH 227/230] v2: Docs.
Updated readme and concepts_rst to use v2 structure policies --- README.md | 74 ++++++++++++++++++++----------- docs/01_tutorials/01_concepts.rst | 2 +- 2 files changed, 48 insertions(+), 28 deletions(-) diff --git a/README.md b/README.md index 02ef2f052..0ccb4590e 100644 --- a/README.md +++ b/README.md @@ -361,10 +361,13 @@ train_envs = ts.env.DummyVectorEnv([lambda: gym.make(task) for _ in range(train_ test_envs = ts.env.DummyVectorEnv([lambda: gym.make(task) for _ in range(test_num)]) ``` -Create the network as well as its optimizer: +Create the network, policy, and algorithm: ```python from tianshou.utils.net.common import Net +from tianshou.algorithm import DQN +from tianshou.algorithm.modelfree.dqn import DiscreteQLearningPolicy +from tianshou.algorithm.optim import AdamOptimizerFactory # Note: You can easily define other networks. # See https://tianshou.readthedocs.io/en/master/01_tutorials/00_dqn.html#build-the-network @@ -375,44 +378,61 @@ net = Net( state_shape=state_shape, action_shape=action_shape, hidden_sizes=[128, 128, 128] ) -optim = torch.optim.Adam(net.parameters(), lr=lr) -``` - -Set up the policy and collectors: -```python -policy = ts.policy.DQN( +policy = DiscreteQLearningPolicy( model=net, - optim=optim, - discount_factor=gamma, action_space=env.action_space, - estimation_step=n_step, + eps_training=eps_train, + eps_inference=eps_test +) + +# Create the algorithm with the policy and optimizer factory +algorithm = DQN( + policy=policy, + optim=AdamOptimizerFactory(lr=lr), + gamma=gamma, + n_step_return_horizon=n_step, target_update_freq=target_freq ) +``` + +Set up the collectors: + +```python train_collector = ts.data.Collector(policy, train_envs, ts.data.VectorReplayBuffer(buffer_size, train_num), exploration_noise=True) test_collector = ts.data.Collector(policy, test_envs, exploration_noise=True) # because DQN uses epsilon-greedy method ``` -Let's train it: +Let's train it using the algorithm: ```python -result = 
ts.trainer.OffPolicyTrainer( - policy=policy, - train_collector=train_collector, - test_collector=test_collector, - max_epoch=epoch, - epoch_num_steps=epoch_num_steps, - collection_step_num_env_steps=collection_step_num_env_steps, - episode_per_test=test_num, - batch_size=batch_size, - update_per_step=1 / collection_step_num_env_steps, - train_fn=lambda epoch, env_step: policy.set_eps_training(eps_train), - test_fn=lambda epoch, env_step: policy.set_eps_training(eps_test), - stop_fn=lambda mean_rewards: mean_rewards >= env.spec.reward_threshold, - logger=logger, -).run() +from tianshou.highlevel.config import OffPolicyTrainingConfig + +# Create training configuration +training_config = OffPolicyTrainingConfig( + max_epochs=epoch, + epoch_num_steps=epoch_num_steps, + batch_size=batch_size, + num_train_envs=train_num, + num_test_envs=test_num, + buffer_size=buffer_size, + collection_step_num_env_steps=collection_step_num_env_steps, + update_step_num_gradient_steps_per_sample=1 / collection_step_num_env_steps, + test_step_num_episodes=test_num, +) + +# Run training (trainer is created automatically by the algorithm) +result = algorithm.run_training( + training_config=training_config, + train_collector=train_collector, + test_collector=test_collector, + logger=logger, + train_fn=lambda epoch, env_step: policy.set_eps_training(eps_train), + test_fn=lambda epoch, env_step: policy.set_eps_inference(eps_test), + stop_fn=lambda mean_rewards: mean_rewards >= env.spec.reward_threshold, +) print(f"Finished training in {result.timing.total_time} seconds") ``` @@ -427,7 +447,7 @@ Watch the agent with 35 FPS: ```python policy.eval() -policy.set_eps_training(eps_test) +policy.set_eps_inference(eps_test) collector = ts.data.Collector(policy, env, exploration_noise=True) collector.collect(n_episode=1, render=1 / 35) ``` diff --git a/docs/01_tutorials/01_concepts.rst b/docs/01_tutorials/01_concepts.rst index 
4204672b7..28b0dc276 100644 --- a/docs/01_tutorials/01_concepts.rst +++
b/docs/01_tutorials/01_concepts.rst @@ -332,7 +332,7 @@ The :class:`~tianshou.data.Collector` enables the policy to interact with differ The general explanation is listed in :ref:`pseudocode`. Other usages of collector are listed in :class:`~tianshou.data.Collector` documentation. Here are some example usages: :: - policy = PGPolicy(...) # or other policies if you wish + policy = DiscreteQLearningPolicy(...) # or other policies if you wish env = gym.make("CartPole-v1") replay_buffer = ReplayBuffer(size=10000) From 4215eafc95079d13f0d9fd0a3e7ae90d9d7f2d13 Mon Sep 17 00:00:00 2001 From: Dominik Jain Date: Mon, 14 Jul 2025 16:00:11 +0200 Subject: [PATCH 228/230] v2: Changelog: Add information on changes pertaining to lagged networks --- CHANGELOG.md | 38 ++++++++++++++++++++++++++++++++++++-- 1 file changed, 36 insertions(+), 2 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 3d8e11f1f..4f2b7c8ef 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -22,11 +22,14 @@ Developers: and offline learning: The base class is no longer a "God" class (formerly `BaseTrainer`) which does it all; logic and functionality has moved to the respective subclasses (`OnPolicyTrainer`, `OffPolicyTrainer` and `OfflineTrainer`, with `OnlineTrainer` being introduced as a base class for the two former specialisations). + * The trainers now use configuration objects with central documentation (which has been greatly improved to enhance clarity and usability in general); every type of trainer now has a dedicated configuration class which provides precisely the options that are applicable. + * The interface has been streamlined with improved naming of functions/parameters and limiting the public interface to purely the methods and attributes a user should reasonably access. + * Further changes potentially affecting usage: * We dropped the iterator semantics: Method `__next__` has been replaced by `execute_epoch`. #913 * We no longer report outdated statistics (e.g. 
on rewards/returns when a training step does not collect any full @@ -41,11 +44,13 @@ Developers: differentiated and makes the use of callback functions (`train_fn`, `test_fn`) unnecessary if only constants are to be set. * The setter method `set_eps` has been replaced with `set_eps_training` and `set_eps_inference` accordingly. + * Further internal changes unlikely to affect usage: * Module `trainer.utils` was removed and the functions therein where moved to class `Trainer` * The two places that collected and evaluated test episodes (`_test_in_train` and `_reset`) in addition to `_test_step` were unified to use `_test_step` (with some minor parametrisation) and now log the results of the test step accordingly. + * Issues resolved: * Methods `run` and `reset`: Parameter `reset_prior_to_run` of `run` was never respected if it was set to `False`, because the implementation of `__iter__` (now removed) would call `reset` regardless - and calling `reset` @@ -59,7 +64,8 @@ Developers: This is an inconsistency which has been resolved. * The `gradient_step` counter was flawed (as it made assumptions about the underlying algorithms, which were not valid). It has been replaced with an update step counter. - Members of `InfoStats` and parameters of `Logger` (and subclasses) were changed accordingly. + Members of `InfoStats` and parameters of `Logger` (and subclasses) were changed accordingly. + * Migration information at a glance: * Training parameters are now passed via instances of configuration objects instead of directly as keyword arguments: `OnPolicyTrainerParams`, `OffPolicyTrainerParams`, `OfflineTrainerParams`. @@ -83,16 +89,21 @@ Developers: ### Algorithms and Policies * We now conceptually differentiate between the learning algorithm and the policy being optimised: + * The abstraction `BasePolicy` is thus replaced by `Algorithm` and `Policy`, and the package was renamed from `tianshou.policy` to `tianshou.algorithm`. 
+ * Migration information: The instantiation of a policy is replaced by the instantiation of an `Algorithm`, which is passed a `Policy`. In most cases, the former policy class name `Policy` is replaced by algorithm class ``; exceptions are noted below. + * `ImitationPolicy` -> `OffPolicyImitationLearning`, `OfflineImitationLearning` * `PGPolicy` -> `Reinforce` * `MultiAgentPolicyManager` -> `MultiAgentOnPolicyAlgorithm`, `MultiAgentOffPolicyAlgorithm` * `MARLRandomPolicy` -> `MARLRandomDiscreteMaskedOffPolicyAlgorithm` + For the respective subtype of `Policy` to use, see the respective algorithm class' constructor. + * Interface changes/improvements: * Core methods have been renamed (and removed from the public interface; #898): * `process_fn` -> `_preprocess_batch` @@ -120,7 +131,9 @@ Developers: * `clip_grad` -> `max_grad_norm` (for consistency) * `clip_loss_grad` -> `huber_loss_delta` (allowing to control not only the use of the Huber loss but also its essential parameter) * `estimation_step` -> `n_step_return_horizon` (more precise naming) + * Internal design improvements: + * Introduced an abstraction for the alpha parameter (coefficient of the entropy term) in `SAC`, `DiscreteSAC` and other algorithms. * Class hierarchy: @@ -130,11 +143,14 @@ Developers: * The (auto-)updating logic is now completely encapsulated, reducing the complexity of the algorithms. * Implementations for continuous and discrete cases now share the same abstraction, making the codebase more consistent while preserving the original functionality. + * Introduced a policy base class `ContinuousPolicyWithExplorationNoise` which encapsulates noise generation - for continuous action spaces (e.g. relevant to `DDPG`, `SAC` and `REDQ`). + for continuous action spaces (e.g. relevant to `DDPG`, `SAC` and `REDQ`). 
+ * Multi-agent RL methods are now differentiated by the type of the sub-algorithms being employed (`MultiAgentOnPolicyAlgorithm`, `MultiAgentOffPolicyAlgorithm`), which renders all interfaces clean. Helper class `MARLDispatcher` has been factored out to manage the dispatching of data to the respective agents. + * Algorithms now internally use a wrapper (`Algorithm.Optimizer`) around the optimizers; creation is handled by method `_create_optimizer`. * This facilitates backpropagation steps with gradient clipping. @@ -142,6 +158,23 @@ Developers: optimizers' states are handled alongside the model parameters when calling `state_dict` or `load_state_dict` on the `Algorithm` instance. Special handling of the restoration of optimizers' state dicts was thus removed from examples and tests. + + * Lagged networks (target networks) are now conveniently handled via the new algorithm mixins + `LaggedNetworkPolyakUpdateAlgorithmMixin` and `LaggedNetworkFullUpdateAlgorithmMixin`. + Using these mixins, + + * a lagged network can simply be added by calling `_add_lagged_network` + * the torch method `train` no longer needs to be overridden to ensure that the target networks + are never set to train mode/remain in eval mode (which was prone to errors), + * a method which updates all target networks with their source networks is automatically + provided and does not need to be implemented specifically for every algorithm + (`_update_lagged_network_weights`). + + All classes which make use of lagged networks were updated to use these mixins, simplifying + the implementations and reducing the potential for implementation errors. + (In the BCQ implementation, the VAE network was not correctly handled, but due to the way + in which examples were structured, it did not result in an error.)
+ * Fixed issues in the class hierarchy (particularly critical violations of the Liskov substitution principle): * Introduced base classes (to retain factorization without abusive inheritance): * `ActorCriticOnPolicyAlgorithm` @@ -187,6 +220,7 @@ Developers: * Learning rate schedulers remain separate parameters and now use `LRSchedulerFactoryFactory` instances. The respective parameter names now use the suffix `lr_scheduler` instead of `lr_scheduler_factory` (as the precise nature need not be reflected in the name; brevity is preferable). + * `SamplingConfig` is replaced by `TrainingConfig` and subclasses differentiating off-policy and on-policy cases appropriately (`OnPolicyTrainingConfig`, `OffPolicyTrainingConfig`). * The `test_in_train` parameter is now exposed (default False). From 8a49660d92c07367d0522797b2084c18faca9aef Mon Sep 17 00:00:00 2001 From: Dominik Jain Date: Mon, 14 Jul 2025 19:30:31 +0200 Subject: [PATCH 229/230] v2: Fix typo in docstring --- tianshou/trainer/trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tianshou/trainer/trainer.py b/tianshou/trainer/trainer.py index 13319c953..adec7e066 100644 --- a/tianshou/trainer/trainer.py +++ b/tianshou/trainer/trainer.py @@ -327,7 +327,7 @@ class OfflineTrainerParams(TrainerParams): batch_size: int = 64 """ - the the number of environment steps/transitions to sample from the buffer for a gradient update. + the number of environment steps/transitions to sample from the buffer for a gradient update. 
""" From c74cc1794e3c9b64b1aa30a335f701ecddceb02d Mon Sep 17 00:00:00 2001 From: Dominik Jain Date: Mon, 14 Jul 2025 22:53:45 +0200 Subject: [PATCH 230/230] v2: Update examples in README --- README.md | 106 ++++++++++++++++++++++++++---------------------------- 1 file changed, 50 insertions(+), 56 deletions(-) diff --git a/README.md b/README.md index 0ccb4590e..4eb9d2154 100644 --- a/README.md +++ b/README.md @@ -215,65 +215,60 @@ In the high-level API, the basis for an RL experiment is an `ExperimentBuilder` with which we can build the experiment we then seek to run. Since we want to use DQN, we use the specialization `DQNExperimentBuilder`. -As imports, we need only the experiment builder itself, the environment factory -and some configuration classes: +The high-level API provides largely declarative semantics, i.e. the code is +almost exclusively concerned with configuration that controls what to do +(rather than how to do it). ```python from tianshou.highlevel.config import OffPolicyTrainingConfig from tianshou.highlevel.env import ( - EnvFactoryRegistered, - VectorEnvType, + EnvFactoryRegistered, + VectorEnvType, ) from tianshou.highlevel.experiment import DQNExperimentBuilder, ExperimentConfig from tianshou.highlevel.params.algorithm_params import DQNParams from tianshou.highlevel.trainer import ( - EpochStopCallbackRewardThreshold, + EpochStopCallbackRewardThreshold, ) -``` -The high-level API provides largely declarative semantics, i.e. the code is -almost exclusively concerned with configuration that controls what to do -(rather than how to do it). 
- -```python experiment = ( - DQNExperimentBuilder( - EnvFactoryRegistered( - task="CartPole-v1", - venv_type=VectorEnvType.DUMMY, - train_seed=0, - test_seed=10, - ), - ExperimentConfig( - persistence_enabled=False, - watch=True, - watch_render=1 / 35, - watch_num_episodes=100, - ), - OffPolicyTrainingConfig( - num_epochs=10, - epoch_num_steps=10000, - batch_size=64, - num_train_envs=10, - num_test_envs=100, - buffer_size=20000, - collection_step_num_env_steps=10, - update_per_step=1 / 10, - ), - ) - .with_dqn_params( - DQNParams( - lr=1e-3, - discount_factor=0.9, - n_step_return_horizon=3, - target_update_freq=320, - eps_training=0.3, - eps_inference=0.0, - ), - ) - .with_model_factory_default(hidden_sizes=(64, 64)) - .with_epoch_stop_callback(EpochStopCallbackRewardThreshold(195)) - .build() + DQNExperimentBuilder( + EnvFactoryRegistered( + task="CartPole-v1", + venv_type=VectorEnvType.DUMMY, + train_seed=0, + test_seed=10, + ), + ExperimentConfig( + persistence_enabled=False, + watch=True, + watch_render=1 / 35, + watch_num_episodes=100, + ), + OffPolicyTrainingConfig( + max_epochs=10, + epoch_num_steps=10000, + batch_size=64, + num_train_envs=10, + num_test_envs=100, + buffer_size=20000, + collection_step_num_env_steps=10, + update_step_num_gradient_steps_per_sample=1 / 10, + ), + ) + .with_dqn_params( + DQNParams( + lr=1e-3, + gamma=0.9, + n_step_return_horizon=3, + target_update_freq=320, + eps_training=0.3, + eps_inference=0.0, + ), + ) + .with_model_factory_default(hidden_sizes=(64, 64)) + .with_epoch_stop_callback(EpochStopCallbackRewardThreshold(195)) + .build() ) experiment.run() ``` @@ -300,13 +295,13 @@ The experiment builder takes three arguments: of data (`batch_size=64`) taken from the buffer of data that has been collected. For further details, see the documentation of configuration class. -We then proceed to configure some of the parameters of the DQN algorithm itself -and of the neural network model we want to use. 
-A DQN-specific detail is the way in which we control the epsilon parameter for -exploration. +We then proceed to configure some of the parameters of the DQN algorithm itself: +For instance, we control the epsilon parameter for exploration. We want to use random exploration during rollouts for training (`eps_training`), but we don't when evaluating the agent's performance in the test environments (`eps_inference`). +Furthermore, we configure model parameters of the network for the Q function, +parametrising the number and sizes of the hidden layers of the default MLP factory. Find the script in [examples/discrete/discrete_dqn_hl.py](examples/discrete/discrete_dqn_hl.py). Here's a run (with the training time cut short): @@ -317,7 +312,7 @@ Here's a run (with the training time cut short): Find many further applications of the high-level API in the `examples/` folder; look for scripts ending with `_hl.py`. -Note that most of these examples require the extra package `argparse` +Note that most of these examples require the extra `argparse` (install it by adding `--extras argparse` when invoking poetry). ### Procedural API @@ -325,7 +320,7 @@ Note that most of these examples require the extra package `argparse` Let us now consider an analogous example in the procedural API. Find the full script in [examples/discrete/discrete_dqn.py](https://github.com/thu-ml/tianshou/blob/master/examples/discrete/discrete_dqn.py). -First, import some relevant packages: +First, import the relevant packages: ```python import gymnasium as gym @@ -334,7 +329,7 @@ from torch.utils.tensorboard import SummaryWriter import tianshou as ts ``` -Define some hyper-parameters: +Define hyper-parameters: ```python task = 'CartPole-v1' @@ -350,7 +345,6 @@ Initialize the logger: ```python logger = ts.utils.TensorboardLogger(SummaryWriter('log/dqn')) -# For other loggers, see https://tianshou.readthedocs.io/en/master/01_tutorials/05_logger.html ``` Make environments: